Пользовательская функция активации Keras с дополнительным параметром / аргументом

#python #tensorflow #machine-learning #keras

#python #тензорный поток #машинное обучение #keras #глубокое обучение

Вопрос:

Как я могу определить функцию активации в keras, которая принимает дополнительные аргументы. Моя первоначальная пользовательская функция активации — это функция, которая генерирует точки на основе полинома N степеней. Входные данные — это коэффициенты для полинома. Это выглядит так:

 def poly_transfer(x):
    a = np.arange(0, 1.05, 0.05)
    b = []
    for i in range(x.shape[1]):
        b.append(a**i)
    b = np.asarray(b)
    b = b.astype(np.float32)
    c = matmul(x,b)
    return c
 

Теперь я хочу установить длину вывода извне функции. Примерно так:

 def poly_transfer(x, lenght):
    a = np.arange(0, lenght   0.05, 0.05)
    b = []
    for i in range(x.shape[1]):
        b.append(a**i)
    b = np.asarray(b)
    b = b.astype(np.float32)
    c = matmul(x,b)
    return c
 

Как я могу реализовать эту функциональность и как я могу ее использовать?
На данный момент:

 speed_out = Lambda(poly_transfer)(speed_concat_layer)
 

Как я и предполагал:

 speed_out = Lambda(poly_transfer(lenght=lenght))(speed_concat_layer)
 

Ответ №1:

вы можете просто сделать это таким образом…

 X = np.random.uniform(0,1, (100,10))
y = np.random.uniform(0,1, (100,))

def poly_transfer(x, lenght):

    a = np.arange(0, lenght   0.05, 0.05)

    b = []
    for i in range(x.shape[1]):
        b.append(a**i)

    b = tf.constant(np.asarray(b), dtype=tf.float32)
    c = tf.matmul(x, b)

    return c

inp = Input((10,))
poly = Lambda(lambda x: poly_transfer(x, lenght=1))(inp)
out = Dense(1)(poly)

model = Model(inp, out)
model.compile('adam', 'mse')
model.fit(X, y, epochs=3)
 

Ответ №2:

Вы можете использовать для выполнения функции functools.partial :

 from functools import partial

poly_transfer_set_length = partial(poly_transfer, lenght=lenght)
speed_out = Lambda(poly_transfer_set_length)(speed_concat_layer)
 

или используйте lambda функцию:

 speed_out = Lambda(lambda x: poly_transfer(x, lenght=lenght))(speed_concat_layer)