#python #tensorflow #tensorflow2.0
#python #tensorflow #tensorflow2.0
Вопрос:
В TensorFlow 1.x для обновления тензора я бы использовал tf.scatter_update
, чтобы обновить только соответствующую часть тензора.
Как мы можем сделать то же самое в TF 2.0?
Ответ №1:
Вы можете использовать tf.tensor_scatter_nd_update()
:
import tensorflow as tf
import numpy as np
tensor = tf.convert_to_tensor(np.ones((2, 2)), dtype=tf.float32)
indices = tf.constant([[0, 0]])
updates = tf.constant([0.0])
tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
# array([[0., 1.],
# [1., 1.]], dtype=float32)
Комментарии:
1. На самом деле это не то, что мне нужно, это создает новый объект. Есть ли какой-либо способ использовать ссылку на тот же объект? Если я хочу обновить некоторые переменные памяти, такие как буфер воспроизведения, tensorflow возвращает ошибку: «Операции вне кода построения функции передается тензор «Graph» «.
2. Я нашел ответ на свой собственный вопрос: tensor.assign(tf.tensor_scatter_nd_update(…. работает.
3. @gpernelle В вашем вопросе вы на самом деле хотели обновить переменную, а не тензорный объект (это то, что tf.scatter_update сделал в 1.x). Таким образом, ответ Влада является правильным при обновлении тензоров. Ваш ответ верен для обновления переменных, хотя я не знаю, почему эта функциональность исчезла в 2.x.