Определение функции потерь в pytorch

#python #pytorch #loss-function

#python #pytorch #функция потерь

Вопрос:

Я должен определить функцию потерь huber, которая заключается в следующем:введите описание изображения здесь

Это мой код

 def huber(a, b): 
   res = (((a-b)[abs(a-b) < 1]) ** 2 / 2).sum()
   res  = ((abs(a-b)[abs(a-b) >= 1]) - 0.5).sum()
   res = res / torch.numel(a)
   return res
  

»’

тем не менее, он не работает должным образом. У вас есть какие-либо идеи, что не так?

Комментарии:

1. Что вы имеете в виду it is not working properly ? Это математическая корректность или какая-то проблема pytorch ?

Ответ №1:

Функция потерь Huber уже существует в PyTorch под именем torch.nn.SmoothL1Loss .

Перейдите по этой ссылке https://pytorch.org/docs/stable/generated/torch.nn.SmoothL1Loss.html для большего!