Как я могу выразить эту пользовательскую функцию потерь в tensorflow?

#python #tensorflow #machine-learning #keras #pytorch

#python #tensorflow #машинное обучение #keras #pytorch

Вопрос:

У меня есть функция потерь, которая удовлетворяет моим потребностям, но только в PyTorch. Мне нужно внедрить ее в мой код TensorFlow, но, хотя большая часть его может быть тривиально «переведена», я застрял на определенной строке:

 y_hat[:, torch.arange(N), torch.arange(N)] = torch.finfo(y_hat.dtype).max  # to be "1" after sigmoid
 

Вы можете увидеть весь код ниже, и он действительно довольно прост, за исключением этой строки:

 def get_loss(y_hat, y):
 # No loss on diagonal
 B, N, _ = y_hat.shape
 y_hat[:, torch.arange(N), torch.arange(N)] = torch.finfo(y_hat.dtype).max  # to be "1" after sigmoid

 # calc loss
 loss = F.binary_cross_entropy_with_logits(y_hat, y)  # cross entropy

 y_hat = torch.sigmoid(y_hat)
 tp = (y_hat * y).sum(dim=(1, 2))
 fn = ((1. - y_hat) * y).sum(dim=(1, 2))
 fp = (y_hat * (1. - y)).sum(dim=(1, 2))
 loss = loss - ((2 * tp) / (2 * tp   fp   fn   1e-10)).sum()  # fscore

return loss
 

До сих пор я придумал следующее:

 def get_loss(y_hat, y):
 loss = tf.keras.losses.BinaryCrossentropy()(y_hat,y)  # cross entropy (but no logits)


 y_hat = tf.math.sigmoid(y_hat)

 tp = tf.math.reduce_sum(tf.multiply(y_hat, y),[1,2])
 fn = tf.math.reduce_sum((y - tf.multiply(y_hat, y)),[1,2])
 fp = tf.math.reduce_sum((y_hat -tf.multiply(y_hat,y)),[1,2])
 loss = loss - ((2 * tp) / tf.math.reduce_sum((2 * tp   fp   fn   1e-10)))  # fscore

return loss
 

итак, мои вопросы сводятся к:

  • Что torch.finfo() делает и как выразить это в TensorFlow?
  • y_hat.dtype Просто возвращает тип данных?

Ответ №1:

1. Что делает torch.finfo() и как выразить это в TensorFlow?

.finfo() предоставляет удобный способ получить машинные ограничения для типов с плавающей запятой. Эта функция доступна в Numpy, Torch, а также в Tensorflow experimental.

.finfo().max возвращает максимально возможное число, которое можно представить как этот dtype.

ПРИМЕЧАНИЕ: Существует также .iinfo() для целочисленных типов.

Вот несколько примеров finfo и iinfo в действии.

 print('FLOATS')
print('float16',torch.finfo(torch.float16).max)
print('float32',torch.finfo(torch.float32).max)
print('float64',torch.finfo(torch.float64).max)
print('')
print('INTEGERS')
print('int16',torch.iinfo(torch.int16).max)
print('int32',torch.iinfo(torch.int32).max)
print('int64',torch.iinfo(torch.int64).max)
 
 FLOATS
float16 65504.0
float32 3.4028234663852886e 38
float64 1.7976931348623157e 308

INTEGERS
int16 32767
int32 2147483647
int64 9223372036854775807
 

Если вы хотите реализовать это в tensorflow, вы можете использовать tf.experimental.numpy.finfo для решения этой проблемы.

 print(tf.experimental.numpy.finfo(tf.float32))
print('Max ->',tf.experimental.numpy.finfo(tf.float32).max)  #<---- THIS IS WHAT YOU WANT
 
 Machine parameters for float32
---------------------------------------------------------------
precision =   6   resolution = 1.0000000e-06
machep =    -23   eps =        1.1920929e-07
negep =     -24   epsneg =     5.9604645e-08
minexp =   -126   tiny =       1.1754944e-38
maxexp =    128   max =        3.4028235e 38
nexp =        8   min =        -max
---------------------------------------------------------------

Max -> 3.4028235e 38
 

2. Возвращает ли y_hat.dtype просто тип данных?

ДА.

В torch он вернет torch.float32 или что-то в этом роде. В Tensorflow он вернет tf.float32 или что-то в этом роде.

Комментарии:

1. Спасибо. итак, в основном строка устанавливает для всех диагональных элементов максимальное значение типа данных? И код просто возвращает это значение. Поскольку tensorflow, похоже, использует только tf.float32, могу ли я безопасно просто использовать tf.float32.max вместо finfo и dtype?

2. исправьте, если вы уверены, что dtype равен float32. тогда tf.float32.max будет тот же результат.

3. Установка для нее максимального значения даст значение 1 по диагонали после sigmoid . Я думаю, что это то, что они пытаются сделать в первой функции потерь.

4. Итак, вы знаете, почему это было не только y_hat[:, torch.arange(N), torch.arange(N)] = 1 в реализации PyTorch?