#python #tensorflow
Вопрос:
Я ищу эффективный способ вызова функций из нескольких tf.Modules
по оси ввода тензора (TensorFlow 2). Код будет вызываться с высокой частотой, поэтому производительность вызывает большую озабоченность.
Рассмотрим следующий упрощенный пример, где CompositeModule.predict
находится функция интереса:
class PredictorModule(tf.Module):
def __init__(self, params):
self.params = params
@tf.function
def predict(self, input_value): # input value shape: [ S ]
# calculate something complicated based on self.params and input_value
return result # result shape [ S ]
class CompositeModule(tf.Module):
def __init__(self, params_list):
self.predictors = [ PredictorModule(params) for params in params_list ]
@tf.function
def predict(self, batched_input_value): # batched_input_value shape [ B, S ]
# THIS IS THE IMPORTANT FUNCTION
# What I basically want to do is:
result = []
for i, predictor in enumerate(self.predictors):
result.append( predictor.predict(batched_input_value[i,:]) ) # input and output shape [ S ]
return tf.stack(result) # return shape [ B, S ]
Я уже пробовал использовать
result = tf.map_fn(
lambda nested: self.predictors[nested[0]].predict(nested[1]),
[tf.range(self.num_modes), input_values]
)
однако я не могу индексировать self.predictors
с помощью тензора, и я не думаю, что смогу оценить его в массив numpy внутри a tf.function
. У меня такое чувство, что должен быть эффективный способ сделать это, но мне трудно найти хороший способ.
Обратите внимание, что одна фундаментальная проблема возникает из self.predictors
-за того, что вы являетесь списком. Есть ли более разумный способ определения self.predictors
, который позволяет нам индексировать его с помощью целочисленного тензора?