Почему я не могу использовать TensorArray.gather() в @tf.function?

#python-3.x #tensorflow #machine-learning #keras #deep-learning

#python-3.x #тензорный поток #машинное обучение #keras #глубокое обучение

Вопрос:

Чтение из TensorArray:

 def __init__(self, size):
    self.obs_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)
    self.obs2_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)
    self.act_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)
    self.rew_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)
    self.done_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)

def get_sample(self, batch_size):
        idxs = tf.random.uniform(shape=[batch_size], maxval=self.size, dtype=tf.int32)
        tf.print(idxs)
        return self.obs_buf.gather(indices=idxs),          # HERE IS THE ISSUE
               self.act_buf.gather(indices=idxs),     
               self.rew_buf.gather(indices=idxs),     
               self.obs2_buf.gather(indices=idxs),    
               self.done_buf.gather(indices=idxs)
 

Использование:

 @tf.function
def train(self, rpm, batch_size, gradient_steps):
    for gradient_step in tf.range(1, gradient_steps   1):
        obs, act, rew, next_obs, done = rpm.get_sample(batch_size)

        with tf.GradientTape() as tape:
        ...
 

Проблема:

Трассировка (последний последний вызов): File «. main.py «, строка 130, в файле rl_training.train() «C:UsersuserDocumentsProjectsrl-toolkitrl_training.py «, строка 129, в файле train self._rpm, self.batch_size, self.gradient_steps, logging_wandb=файл self.logging_wandb «C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythoneagerdef_function.py» , строка 828, ввызовите result = self ._call(* аргументы, ** kwds) Файл «C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythoneagerdef_function.py» , строка 871, в _call self._initialize(аргументы, kwds, add_initializers_to=инициализаторы) Файл «C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythoneagerdef_function.py» , строка 726, в _initialize *аргументы, **kwds)) Файл «C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythoneagerfunction.py» , строка 2969, в файле _get_concrete_function_internal_garbage_collected graph_function, _ = self._maybe_define_function(аргументы, kwargs) «C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythoneagerfunction.py» , строка 3361, в файле _maybe_define_function graph_function = self._create_graph_function(аргументы, kwargs) «C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythoneagerfunction.py» , строка 3206, в _create_graph_function capture_by_value=self._capture_by_value), файл»C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythonframeworkfunc_graph.py», строка 990, в файле func_graph_from_py_func func_outputs = python_func(* func_args, **func_kwargs) «C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythoneagerdef_function.py» , строка 634, в wrapped_fn out = weak_wrapped_fn().завернутый(* аргументы, ** kwds) Файл «C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythoneagerfunction.py» , строка 3887, в bound_method_wrapper возвращает файл wrapped_fn(*args, **kwargs) «C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythonframeworkfunc_graph.py» , строка 977, в обертке вызывает e.ag_error_metadata.to_exception(e) tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: в пользовательском коде:

 C:UsersuserDocumentsProjectsrl-toolkitpolicysacsac.py:183 update  *
    obs, act, rew, next_obs, done = rpm.get_sample(batch_size)
C:UsersuserDocumentsProjectsrl-toolkitutilsreplay_buffer.py:39 __call__  *
    return self.obs_buf.gather(indices=idxs),                    self.act_buf.gather(indices=idxs),                    self.rew_buf.gather(indices=idxs),                    self.obs2_buf.gather(indices=idxs),                   self.done_buf.gather(indices=idxs)
C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythonopstensor_array_ops.py:1190 gather  **
    return self._implementation.gather(indices, name=name)
C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythonopstensor_array_ops.py:861 gather
    return array_ops.stack([self._maybe_zero(i) for i in indices])
C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythonframeworkops.py:505 __iter__
    self._disallow_iteration()
C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythonframeworkops.py:498 _disallow_iteration
    self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythonframeworkops.py:476 _disallow_when_autograph_enabled
    " indicate you are trying to use an unsupported feature.".format(task))

OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
 

Почему я не могу использовать TensorArray в этом контексте? И какие у меня есть альтернативы?

Комментарии:

1. делает ли это github.com/tensorflow/tensorflow/issues/31952 помочь ?

2. Извините, но нет, потому что у меня проблемы с tf. TensorArray.gather() не tf.gather() ….. 0 решение в этом случае не работает.

Ответ №1:

Решаемая здесь. Необходимо использовать tf.Variable вместо tf.TensorArray.