#python #machine-learning #deep-learning #neural-network #pytorch
#python #машинное обучение #глубокое обучение #нейронная сеть #pytorch
Вопрос:
Я работаю над многоклассовой классификацией (4 класса) для языковой задачи и использую модель BERT для задачи классификации. Я следую этому блогу в качестве ссылки. Возвращается моя точно настроенная модель BERT nn.LogSoftmax(dim=1)
.
Мои данные довольно несбалансированы, поэтому я использовал sklearn.utils.class_weight.compute_class_weight
для вычисления весов классов и использовал веса внутри потери.
class_weights = compute_class_weight('balanced', np.unique(train_labels), train_labels)
weights= torch.tensor(class_weights,dtype=torch.float)
cross_entropy = nn.NLLLoss(weight=weights)
Мои результаты были не очень хорошими, поэтому я подумал об экспериментировании Focal Loss
и получил код для фокусных потерь.
class FocalLoss(nn.Module):
def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.logits = logits
self.reduce = reduce
def forward(self, inputs, targets):
BCE_loss = nn.CrossEntropyLoss()(inputs, targets)
pt = torch.exp(-BCE_loss)
F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
if self.reduce:
return torch.mean(F_loss)
else:
return F_loss
Теперь у меня есть 3 вопроса. Первое и самое важное
- Должен ли я использовать вес класса с фокусными потерями?
- Если мне нужно реализовать веса внутри этого
Focal Loss
, могу ли я использоватьweights
параметры внутриnn.CrossEntropyLoss()
- Если эта реализация неверна, каким должен быть правильный код для этого, включая веса (если возможно)
Комментарии:
1. подождите, если ваши данные несбалансированы, почему вы выбрали здесь «сбалансированный»? Я довольно смущен
compute_class_weight('balanced', np.unique(train_labels), train_labels)
2.@MonaJalal
balanced
означает присвоение веса класса в соответствии с количеством образцов, присутствующих в классе? Не так ли? Как указано в этой документации , если «сбалансированный», веса классов будут заданы n_samples / (n_classes * np.bincount(y)) .
Ответ №1:
Вы можете найти ответы на свои вопросы следующим образом:
- Потеря фокуса автоматически обрабатывает дисбаланс классов, следовательно, веса не требуются для потери фокуса. Альфа- и гамма-факторы регулируют дисбаланс классов в уравнении фокальных потерь.
- Нет необходимости в дополнительных весах, поскольку фокальные потери обрабатывают их с использованием альфа- и гамма-модулирующих коэффициентов
- Упомянутая вами реализация верна в соответствии с формулой фокусных потерь, но у меня возникли проблемы с тем, чтобы моя модель сходилась с этой версией, поэтому я использовал следующую реализацию из mmdetection framework
pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
pt = (1 - pred_sigmoid) * target pred_sigmoid * (1 - target)
focal_weight = (alpha * target (1 - alpha) *
(1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy_with_logits(
pred, target, reduction='none') * focal_weight
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
Вы также можете поэкспериментировать с другой доступной версией с фокусными потерями
Комментарии:
1. предположим, что это softmax? его мультикласс
Ответ №2:
Я думаю, что OP уже получил бы свой ответ. Я пишу это для других людей, которые могут задуматься над этим.
Существует одна проблема в реализации OPs с фокусными потерями:
F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
В этой строке одно и то же alpha
значение умножается на каждую вероятность вывода класса, т.Е. ( pt
) . Кроме того, код не показывает, как мы получаем pt
. Очень хорошую реализацию фокусных потерь можно найти здесь. Но эта реализация предназначена только для двоичной классификации, как она есть alpha
, и 1-alpha
для двух классов в self.alpha
тензоре.
В случае многоклассовой классификации или классификации с несколькими метками self.alpha
тензор должен содержать количество элементов, равное общему количеству меток. Значениями могут быть обратная частота меток меток или обратная нормализованная частота меток (просто будьте осторожны с метками, которые имеют 0 в качестве частоты).
Ответ №3:
Я думаю, что реализация в вашем вопросе неверна. Альфа — это вес класса.
В перекрестной энтропии вес класса равен alpha_t, как показано в следующем выражении:
вы видите, что это alpha_t, а не alpha.
При фокусных потерях значение fomular равно
и мы можем видеть из этой популярной реализации Pytorch, что альфа действует так же, как и вес класса.
Ссылки: