#tensorflow #model #loss-function
#tensorflow #Модель #функция потерь
Вопрос:
Я хотел бы создать пользовательскую функцию потерь для модели tensorflow, используя y_true и y_pred, но я получил следующую ошибку: ValueError: невозможно вывести num из shape (None, 1) это моя пользовательская метрика:
def custom_metric(y_true,y_pred):
y_true = float(y_true)
y_pred = float(y_pred)
y_true = tf.unstack(y_true)
y_pred = tf.unstack(y_pred)
sqr_pred_error = K.square(y_true - y_pred)
sqr_y_true = K.square(y_true)
r = []
for i in y_true:
if sqr_pred_error[i] < sqr_y_true[i] or sqr_pred_error[i] == sqr_y_true[i]:
result = 1
print("result: 1")
else:
result = 0
print("result: 0")
r.append(result)
r = tf.stack(r)
return K.sum(r)/K.shape(r)
Ответ №1:
Вероятно, вам там не нужен цикл. Похоже, вам просто нужна куча 0 и 1.
1
— Еслиsqr_pred_error <= sqr_y_true
0
— еще
Затем вы можете сделать следующее.
def custom_metric(y_true,y_pred):
y_true = tf.cast(y_true, 'float32')
y_pred = tf.cast(y_pred, 'float32')
sqr_pred_error = K.square(y_true - y_pred)
sqr_y_true = K.square(y_true)
res = tf.where(sqr_pred_error<=sqr_y_true, tf.ones_like(y_true), tf.zeros_like(y_true))
return K.mean(res)
Комментарии:
1. Она работает, когда я использую ее как метрику, но в качестве потери я получаю следующее сообщение ошибка: ошибка значения: градиенты не указаны ни для одной переменной: [‘dense / kernel:0’, ‘dense / bias: 0’, ‘dense_1 / kernel:0’, ‘dense_1 / bias: 0’, ‘dense_2/ kernel:0’, ‘dense_2/ bias:0’, ‘dense_3/ kernel:0’, ‘dense_3 /bias:0’]. есть идеи?
2. @JasonGreffier, ну, это имеет смысл. Потому что посмотрите, как работает tf.where. Он собирает кучу единиц и нулей, которые не имеют ничего общего с самим прогнозом. Таким образом, градиента не будет. Я посмотрю, можно ли это спасти, чтобы компенсировать потерю fn
3. Все, что я пробовал, не удалось, если у вас есть успех, дайте мне знать
4. @JasonGreffier, да, у меня еще не было возможности. Постараюсь посмотреть сегодня
5. Я не видел вашего последнего комментария. Я попробую это с помощью if sqr_pred_error <= sqr_y_true тогда sqr_pred_error еще 0 завтра.