Создание пользовательской функции потерь в tensorflow с циклом

#tensorflow #model #loss-function

#tensorflow #Модель #функция потерь

Вопрос:

Я хотел бы создать пользовательскую функцию потерь для модели tensorflow, используя y_true и y_pred, но я получил следующую ошибку: ValueError: невозможно вывести num из shape (None, 1) это моя пользовательская метрика:

 def custom_metric(y_true,y_pred):

    y_true = float(y_true)
    y_pred = float(y_pred)
    y_true = tf.unstack(y_true)
    y_pred = tf.unstack(y_pred)

    sqr_pred_error = K.square(y_true - y_pred)
    sqr_y_true = K.square(y_true)
    r = []
    for i in y_true:
        if sqr_pred_error[i] < sqr_y_true[i] or sqr_pred_error[i] == sqr_y_true[i]:
            result = 1
            print("result: 1")
        else:
            result = 0
            print("result: 0")
        r.append(result)
    r = tf.stack(r)

    return  K.sum(r)/K.shape(r)
  

Ответ №1:

Вероятно, вам там не нужен цикл. Похоже, вам просто нужна куча 0 и 1.

  • 1 — Если sqr_pred_error <= sqr_y_true
  • 0 — еще

Затем вы можете сделать следующее.

 def custom_metric(y_true,y_pred):

    y_true = tf.cast(y_true, 'float32')
    y_pred = tf.cast(y_pred, 'float32')
    
    sqr_pred_error = K.square(y_true - y_pred)
    sqr_y_true = K.square(y_true)

    res = tf.where(sqr_pred_error<=sqr_y_true, tf.ones_like(y_true), tf.zeros_like(y_true))
    return  K.mean(res)
  

Комментарии:

1. Она работает, когда я использую ее как метрику, но в качестве потери я получаю следующее сообщение ошибка: ошибка значения: градиенты не указаны ни для одной переменной: [‘dense / kernel:0’, ‘dense / bias: 0’, ‘dense_1 / kernel:0’, ‘dense_1 / bias: 0’, ‘dense_2/ kernel:0’, ‘dense_2/ bias:0’, ‘dense_3/ kernel:0’, ‘dense_3 /bias:0’]. есть идеи?

2. @JasonGreffier, ну, это имеет смысл. Потому что посмотрите, как работает tf.where. Он собирает кучу единиц и нулей, которые не имеют ничего общего с самим прогнозом. Таким образом, градиента не будет. Я посмотрю, можно ли это спасти, чтобы компенсировать потерю fn

3. Все, что я пробовал, не удалось, если у вас есть успех, дайте мне знать

4. @JasonGreffier, да, у меня еще не было возможности. Постараюсь посмотреть сегодня

5. Я не видел вашего последнего комментария. Я попробую это с помощью if sqr_pred_error <= sqr_y_true тогда sqr_pred_error еще 0 завтра.