Как добавить параметр на основе примера в пользовательскую функцию потери keras?

#python #tensorflow #keras #loss-function

Вопрос:

Я хочу иметь пользовательскую функцию потерь в keras, которая имеет параметр, отличающийся для каждого обучающего примера.

 from keras import backend as K

def my_mse_loss_b(b):
     def mseb(y_true, y_pred):
         return K.mean(K.square(y_pred - y_true))   b
     return mseb
 

Я читал здесь, что y_true и y_pred всегда передаются в функцию потерь, поэтому вам нужно создать функцию-оболочку.

 model.compile(loss=my_mse_loss_b(df.iloc[:,2]), optimizer='adam', metrics=['accuracy'])
 

Проблема в том, что, когда я подстраиваюсь под модель, возникает ошибка, поскольку функция предполагает, что передаваемые параметры будут такими же длинными, как и пакет. Я, с другой стороны, хочу, чтобы в каждом примере был свой параметр.

 tensorflow.python.framework.errors_impl.InvalidArgumentError:  Incompatible shapes: [20] vs. [10000]
     [[node gradients/loss_2/dense_3_loss/mseb/weighted_loss/mul_grad/BroadcastGradientArgs (defined at C:Usersflis1Miniconda3envsAutomatelibsite-packagestensorflow_corepythonframeworkops.py:1751) ]] [Op:__inference_keras_scratch_graph_1129]
Function call stack:
keras_scratch_graph
 

Несовместимые формы, говорится в нем. 20-это размер пакета, а 10000-размер моего набора данных поезда и размер всех параметров.

Я могу подогнать модель, если добавляемый мной параметр соответствует размеру пакета, но, как я уже сказал, я хочу, чтобы параметр передавался на основе примера.

Ответ №1:

В вашем случае, поскольку ваш параметр b тесно связан с его обучающим примером, имело бы смысл сделать его частью основной истины. Вы можете переписать свою функцию потерь следующим образом:

 def mseb(y_true, y_pred):
    y_t, b = y_true[0], y_true[1]
    return K.mean(K.square(y_pred - y_t))   b
 

а затем обучите свою модель с помощью

 model.compile(loss=mseb)
b = df.iloc[:,2]
model.fit(X,(y,b))
 

Комментарии:

1. Спасибо. Я сам это понял, единственная проблема, которую я вижу, заключается в том, что вам нужно настроить все метрические функции.