#python-3.x #tensorflow #machine-learning #keras #deep-learning
#python-3.x #тензорный поток #машинное обучение #keras #глубокое обучение
Вопрос:
Чтение из TensorArray:
def __init__(self, size):
self.obs_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)
self.obs2_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)
self.act_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)
self.rew_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)
self.done_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)
def get_sample(self, batch_size):
idxs = tf.random.uniform(shape=[batch_size], maxval=self.size, dtype=tf.int32)
tf.print(idxs)
return self.obs_buf.gather(indices=idxs), # HERE IS THE ISSUE
self.act_buf.gather(indices=idxs),
self.rew_buf.gather(indices=idxs),
self.obs2_buf.gather(indices=idxs),
self.done_buf.gather(indices=idxs)
Использование:
@tf.function
def train(self, rpm, batch_size, gradient_steps):
for gradient_step in tf.range(1, gradient_steps 1):
obs, act, rew, next_obs, done = rpm.get_sample(batch_size)
with tf.GradientTape() as tape:
...
Проблема:
Трассировка (последний последний вызов): File «. main.py «, строка 130, в файле rl_training.train() «C:UsersuserDocumentsProjectsrl-toolkitrl_training.py «, строка 129, в файле train self._rpm, self.batch_size, self.gradient_steps, logging_wandb=файл self.logging_wandb «C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythoneagerdef_function.py» , строка 828, ввызовите result = self ._call(* аргументы, ** kwds) Файл «C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythoneagerdef_function.py» , строка 871, в _call self._initialize(аргументы, kwds, add_initializers_to=инициализаторы) Файл «C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythoneagerdef_function.py» , строка 726, в _initialize *аргументы, **kwds)) Файл «C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythoneagerfunction.py» , строка 2969, в файле _get_concrete_function_internal_garbage_collected graph_function, _ = self._maybe_define_function(аргументы, kwargs) «C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythoneagerfunction.py» , строка 3361, в файле _maybe_define_function graph_function = self._create_graph_function(аргументы, kwargs) «C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythoneagerfunction.py» , строка 3206, в _create_graph_function capture_by_value=self._capture_by_value), файл»C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythonframeworkfunc_graph.py», строка 990, в файле func_graph_from_py_func func_outputs = python_func(* func_args, **func_kwargs) «C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythoneagerdef_function.py» , строка 634, в wrapped_fn out = weak_wrapped_fn().завернутый(* аргументы, ** kwds) Файл «C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythoneagerfunction.py» , строка 3887, в bound_method_wrapper возвращает файл wrapped_fn(*args, **kwargs) «C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythonframeworkfunc_graph.py» , строка 977, в обертке вызывает e.ag_error_metadata.to_exception(e) tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: в пользовательском коде:
C:UsersuserDocumentsProjectsrl-toolkitpolicysacsac.py:183 update *
obs, act, rew, next_obs, done = rpm.get_sample(batch_size)
C:UsersuserDocumentsProjectsrl-toolkitutilsreplay_buffer.py:39 __call__ *
return self.obs_buf.gather(indices=idxs), self.act_buf.gather(indices=idxs), self.rew_buf.gather(indices=idxs), self.obs2_buf.gather(indices=idxs), self.done_buf.gather(indices=idxs)
C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythonopstensor_array_ops.py:1190 gather **
return self._implementation.gather(indices, name=name)
C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythonopstensor_array_ops.py:861 gather
return array_ops.stack([self._maybe_zero(i) for i in indices])
C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythonframeworkops.py:505 __iter__
self._disallow_iteration()
C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythonframeworkops.py:498 _disallow_iteration
self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
C:UsersuserAppDataLocalProgramsPythonPython36libsite-packagestensorflowpythonframeworkops.py:476 _disallow_when_autograph_enabled
" indicate you are trying to use an unsupported feature.".format(task))
OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
Почему я не могу использовать TensorArray в этом контексте? И какие у меня есть альтернативы?
Комментарии:
1. делает ли это github.com/tensorflow/tensorflow/issues/31952 помочь ?
2. Извините, но нет, потому что у меня проблемы с tf. TensorArray.gather() не tf.gather() ….. 0 решение в этом случае не работает.
Ответ №1:
Решаемая здесь. Необходимо использовать tf.Variable вместо tf.TensorArray.