TensorFlow: как реализовать функцию потерь для каждого класса для двоичной классификации

#machine-learning #neural-network #tensorflow

#машинное обучение #нейронная сеть #tensorflow

Вопрос:

У меня есть два класса: положительный (1) и отрицательный (0).

Набор данных очень несбалансированный, поэтому на данный момент мои мини-пакеты содержат в основном 0. Фактически, многие пакеты будут содержать только 0. Я хотел поэкспериментировать с отдельной стоимостью для положительных и отрицательных примеров; см. Код ниже.

Проблема с моим кодом заключается в том, что я получаю много nan , потому что список bound_index будет пустым. Какой элегантный способ решить эту проблему?

 def calc_loss_debug(logits, labels):
  logits = tf.reshape(logits, [-1])
  labels = tf.reshape(labels, [-1])
  index_bound = tf.where(tf.equal(labels, tf.constant(1, dtype=tf.float32)))
  index_unbound = tf.where(tf.equal(labels, tf.constant(0, dtype=tf.float32)))
  entropies = tf.nn.sigmoid_cross_entropy_with_logits(logits, labels)
  entropies_bound = tf.gather(entropies, index_bound)
  entropies_unbound = tf.gather(entropies, index_unbound)
  loss_bound = tf.reduce_mean(entropies_bound)
  loss_unbound = tf.reduce_mean(entropies_unbound)
  

Комментарии:

1. Nan происходит потому, что я беру среднее значение пустого списка (entropies_bound будет пустым)

Ответ №1:

Поскольку у вас есть метки 0 и 1, вы можете легко избежать tf.where такой конструкции

 labels = ...
entropies = ...
labels_complement = tf.constant(1.0, dtype=tf.float32) - labels
entropy_ones = tf.reduce_sum(tf.mul(labels, entropies))
entropy_zeros = tf.reduce_sum(tf.mul(labels_complement, entropies))
  

Чтобы получить среднюю потерю, вам нужно разделить на количество 0 и 1 в пакете, которое можно легко вычислить как

 num_ones = tf.reduce_sum(labels)
num_zeros = tf.reduce_sum(labels_complement)
  

Конечно, вам все равно нужно избегать деления на 0, когда в пакете нет единиц. Я бы предложил использовать tf.cond(tf.equal(num_ones, 0), ...) .

Комментарии:

1. Спасибо. Разве 0 в tf.conf не должно быть tf.constant(0)?