Как использовать веса классов с фокусными потерями в PyTorch для несбалансированного набора данных для многоклассовой классификации

#python #machine-learning #deep-learning #neural-network #pytorch

#python #машинное обучение #глубокое обучение #нейронная сеть #pytorch

Вопрос:

Я работаю над многоклассовой классификацией (4 класса) для языковой задачи и использую модель BERT для задачи классификации. Я следую этому блогу в качестве ссылки. Возвращается моя точно настроенная модель BERT nn.LogSoftmax(dim=1) .

Мои данные довольно несбалансированы, поэтому я использовал sklearn.utils.class_weight.compute_class_weight для вычисления весов классов и использовал веса внутри потери.

 class_weights = compute_class_weight('balanced', np.unique(train_labels), train_labels)
weights= torch.tensor(class_weights,dtype=torch.float)
cross_entropy  = nn.NLLLoss(weight=weights) 

  

Мои результаты были не очень хорошими, поэтому я подумал об экспериментировании Focal Loss и получил код для фокусных потерь.

 class FocalLoss(nn.Module):
  def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
    super(FocalLoss, self).__init__()
    self.alpha = alpha
    self.gamma = gamma
    self.logits = logits
    self.reduce = reduce

  def forward(self, inputs, targets):
    BCE_loss = nn.CrossEntropyLoss()(inputs, targets)

    pt = torch.exp(-BCE_loss)
    F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

    if self.reduce:
      return torch.mean(F_loss)
    else:
      return F_loss
  

Теперь у меня есть 3 вопроса. Первое и самое важное

  1. Должен ли я использовать вес класса с фокусными потерями?
  2. Если мне нужно реализовать веса внутри этого Focal Loss , могу ли я использовать weights параметры внутри nn.CrossEntropyLoss()
  3. Если эта реализация неверна, каким должен быть правильный код для этого, включая веса (если возможно)

Комментарии:

1. подождите, если ваши данные несбалансированы, почему вы выбрали здесь «сбалансированный»? Я довольно смущен compute_class_weight('balanced', np.unique(train_labels), train_labels)

2.@MonaJalal balanced означает присвоение веса класса в соответствии с количеством образцов, присутствующих в классе? Не так ли? Как указано в этой документации , если «сбалансированный», веса классов будут заданы n_samples / (n_classes * np.bincount(y)) .

Ответ №1:

Вы можете найти ответы на свои вопросы следующим образом:

  1. Потеря фокуса автоматически обрабатывает дисбаланс классов, следовательно, веса не требуются для потери фокуса. Альфа- и гамма-факторы регулируют дисбаланс классов в уравнении фокальных потерь.
  2. Нет необходимости в дополнительных весах, поскольку фокальные потери обрабатывают их с использованием альфа- и гамма-модулирующих коэффициентов
  3. Упомянутая вами реализация верна в соответствии с формулой фокусных потерь, но у меня возникли проблемы с тем, чтобы моя модель сходилась с этой версией, поэтому я использовал следующую реализацию из mmdetection framework
     pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    pt = (1 - pred_sigmoid) * target   pred_sigmoid * (1 - target)
    focal_weight = (alpha * target   (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss
  

Вы также можете поэкспериментировать с другой доступной версией с фокусными потерями

Комментарии:

1. предположим, что это softmax? его мультикласс

Ответ №2:

Я думаю, что OP уже получил бы свой ответ. Я пишу это для других людей, которые могут задуматься над этим.

Существует одна проблема в реализации OPs с фокусными потерями:

  1. F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

В этой строке одно и то же alpha значение умножается на каждую вероятность вывода класса, т.Е. ( pt ) . Кроме того, код не показывает, как мы получаем pt . Очень хорошую реализацию фокусных потерь можно найти здесь. Но эта реализация предназначена только для двоичной классификации, как она есть alpha , и 1-alpha для двух классов в self.alpha тензоре.

В случае многоклассовой классификации или классификации с несколькими метками self.alpha тензор должен содержать количество элементов, равное общему количеству меток. Значениями могут быть обратная частота меток меток или обратная нормализованная частота меток (просто будьте осторожны с метками, которые имеют 0 в качестве частоты).

Ответ №3:

Я думаю, что реализация в вашем вопросе неверна. Альфа — это вес класса.

В перекрестной энтропии вес класса равен alpha_t, как показано в следующем выражении:

введите описание изображения здесь

вы видите, что это alpha_t, а не alpha.

При фокусных потерях значение fomular равно
введите описание изображения здесь

и мы можем видеть из этой популярной реализации Pytorch, что альфа действует так же, как и вес класса.

Ссылки:

  1. https://amaarora.github.io/2020/06/29/FocalLoss.html#alpha-and-gamma
  2. https://github.com/clcarwin/focal_loss_pytorch