Эффективный вызов функций из списка модулей tf.вдоль оси

#python #tensorflow

Вопрос:

Я ищу эффективный способ вызова функций из нескольких tf.Modules по оси ввода тензора (TensorFlow 2). Код будет вызываться с высокой частотой, поэтому производительность вызывает большую озабоченность.

Рассмотрим следующий упрощенный пример, где CompositeModule.predict находится функция интереса:

 class PredictorModule(tf.Module):
   def __init__(self, params):
      self.params = params

   @tf.function
   def predict(self, input_value): # input value shape: [ S ]
      # calculate something complicated based on self.params and input_value
      return result  # result shape [ S ]


class CompositeModule(tf.Module):
   def __init__(self, params_list):
      self.predictors = [ PredictorModule(params) for params in params_list ]

   @tf.function
   def predict(self, batched_input_value):  # batched_input_value shape [ B, S ]
     # THIS IS THE IMPORTANT FUNCTION
     # What I basically want to do is:
     result = []
     for i, predictor in enumerate(self.predictors):
        result.append( predictor.predict(batched_input_value[i,:]) )  # input and output shape [ S ]
     return tf.stack(result)  # return shape [ B, S ]       
 

Я уже пробовал использовать

 result = tf.map_fn(
            lambda nested: self.predictors[nested[0]].predict(nested[1]),
            [tf.range(self.num_modes), input_values]
)
 

однако я не могу индексировать self.predictors с помощью тензора, и я не думаю, что смогу оценить его в массив numpy внутри a tf.function . У меня такое чувство, что должен быть эффективный способ сделать это, но мне трудно найти хороший способ.

Обратите внимание, что одна фундаментальная проблема возникает из self.predictors -за того, что вы являетесь списком. Есть ли более разумный способ определения self.predictors , который позволяет нам индексировать его с помощью целочисленного тензора?