#python #tensorflow
#python #tensorflow
Вопрос:
Я использую этот модуль факторизации матрицы WALS в TensorFlow. После установки оценки я пытаюсь сохранить модель, используя метод export_savedmodel(), но я не могу предоставить правильный аргумент serving_input_fn . Код здесь:
from tensorflow.contrib.factorization.python.ops import wals as wals_lib
# dense input array that shows the user X item interactions
# here set as dummy array
dense_array = np.ones((10,10))
num_rows, num_cols = dense_array.shape
emebedding_dim = 5 # manually setting hidden factor
factorizer = wals_lib.WALSMatrixFactorization(num_rows, num_cols, embedding_dim, max_sweeps=10)
# this generate_input_fn() is not shown here but it's a copy of
# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/contrib/factorization/python/ops/wals_test.py#L82
input_fn, _, _ = generate_input_fn(np_matrix=dense_array, batch_size=32, mode=model_fn.ModeKeys.TRAIN)
factorizer.fit(input_fn, steps=10)
# MY PROBLEM IS HERE
# How to define the correct serving input function?
factorizer.export_savedmodel('path/to/save/model', serving_input_fn=???)
Сложность здесь в том, что модуль WALS, я полагаю, использует более старую парадигму TensorFlow, где serving_input_fn
аргумент ожидает вызываемую функцию, которая возвращает an InputFnOps
. Однако более обновленные оценки, такие как эта, ожидают функцию, которая возвращает tf.estimator.export.ServingInputReceiver
или tf.estimator.export.TensorServingInputReceiver
. Я признаю, что я еще не полностью владею функциями ввода TensorFlow, но буду признателен за любую помощь в моем конкретном случае использования сохранения моей оценки WALS. Спасибо!
Ответ №1:
Вы правы в отношении устаревших частей Tensorflow 1.x. Для WALSMatrixFactorization вам serving_input_fn
необходимо вернуть InputFnOps
объект. Поэтому правильная функция ввода будет:
def serving_input_receiver_fn():
# some example input
receiver_tensors = {'my_input': tf.placeholder(dtype=tf.string, shape=[None, 1], name='foo')}
# some example feature that completely ignores the input
features = {
WALSMatrixFactorization.INPUT_ROWS: tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1., 2.], dense_shape=[3, 4]),
WALSMatrixFactorization.INPUT_COLS: tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[3., 4.], dense_shape=[3, 4]),
WALSMatrixFactorization.PROJECT_ROW: tf.constant(True),
}
return tf.contrib.learn.utils.InputFnOps(features, None, receiver_tensors)