#python #tensorflow #tensorflow2.0
Вопрос:
Мне просто было интересно, есть ли лучший способ обновления тензора в tf2. Допустим, у меня есть tensor_a = tf.ones(4,5,5)
(batch_size, H, W), и я хотел бы заменить все значения второго образца нулями( index=1
). Вот как мне удается это делать без использования режима нетерпеливого выполнения:
tensor_a = tf.ones([4,5,5]) tensor_b = tf.zeros([1,5,5]) index=1 tensor_a = tf.concat([tensor_a[:index], tensor_b, tensor_a[index 1:]], axis=0)
Я знаю, что существует функция tf.tensor_scatter_nd_update (), но я не знаком с сетчатыми сетками, и, на мой взгляд, они выглядят немного уродливо для простой операции назначения среза. Также в некоторых случаях было бы удобно обновлять срезы со многими индексами (например, с образцами 0,1 и 2 до нулей) одновременно.
Ответ №1:
Операции тензорного потока иногда немного запутанны.
import tensorflow as tf tensor = tf.ones([4, 5, 5]) tensor = tf.tensor_scatter_nd_update( tensor, [[1]], tf.zeros_like(tf.gather(tensor, [1])) )
lt;tf.Tensor: shape=(4, 5, 5), dtype=float32, numpy= array([[[1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.]], [[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.]], [[1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.]], [[1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.]]], dtype=float32)gt;
Комментарии:
1. Аааа, это было просто с помощью tf.tensor_scatter_nd_update(). Я не был знаком с параметром indexes и отказался от попыток использовать concat. Спасибо!