TensorFlow 2.0: как обновить тензоры?

#python #tensorflow #tensorflow2.0

#python #tensorflow #tensorflow2.0

Вопрос:

В TensorFlow 1.x для обновления тензора я бы использовал tf.scatter_update , чтобы обновить только соответствующую часть тензора.

Как мы можем сделать то же самое в TF 2.0?

Ответ №1:

Вы можете использовать tf.tensor_scatter_nd_update() :

 import tensorflow as tf
import numpy as np 

tensor = tf.convert_to_tensor(np.ones((2, 2)), dtype=tf.float32)
indices = tf.constant([[0, 0]])
updates = tf.constant([0.0])

tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
# array([[0., 1.],
#        [1., 1.]], dtype=float32)
  

Комментарии:

1. На самом деле это не то, что мне нужно, это создает новый объект. Есть ли какой-либо способ использовать ссылку на тот же объект? Если я хочу обновить некоторые переменные памяти, такие как буфер воспроизведения, tensorflow возвращает ошибку: «Операции вне кода построения функции передается тензор «Graph» «.

2. Я нашел ответ на свой собственный вопрос: tensor.assign(tf.tensor_scatter_nd_update(…. работает.

3. @gpernelle В вашем вопросе вы на самом деле хотели обновить переменную, а не тензорный объект (это то, что tf.scatter_update сделал в 1.x). Таким образом, ответ Влада является правильным при обновлении тензоров. Ваш ответ верен для обновления переменных, хотя я не знаю, почему эта функциональность исчезла в 2.x.