Определите тензор констант для tf.while_loop

#python #tensorflow #tensorflow1.15

#питон #тензорный поток #тензорный поток1.15

Вопрос:

Я хочу каким-то образом сохранить список констант, tf.while_loop которые могут поддерживать следующие функции

  1. Я могу читать и записывать (несколько раз) постоянное значение индекса
  2. Я могу запустить tf.cond его, проверив его значение по индексу против некоторой константы

TensorArray не будет работать здесь, так как он не поддерживает перезапись. Какие еще у меня есть варианты?

Комментарии:

1. Я предлагаю вам перейти на TensorFlow2, так как большинство людей, которые будут вас читать, вероятно, используют TF2

2. Ответ полезен?

Ответ №1:

Вы можете просто определить нормальный Tensor и обновить его tf.tensor_scatter_nd_update следующим образом:

 %tensorflow_version 1.x  import tensorflow as tf  data = tf.constant([1, 1, 1, 0, 1, 0, 1, 1, 0, 0], dtype=tf.float32) data_tensor = tf.zeros_like(data) tensor_size = data_tensor.shape[0]  init_state = (0, data_tensor) condition = lambda i, _: i lt; tensor_size  def custom_body(i, tensor):  special_index = 3 # index for which a value should be changed  new_value = 8  tensor = tf.where(tf.equal(i, special_index),   tf.tensor_scatter_nd_update(tensor, [[special_index]], [new_value]),  tf.tensor_scatter_nd_update(tensor, [[i]], [data[i]*2]))   return i   1, tensor   body = lambda i, tensor: (custom_body(i, tensor)) _, final_result = tf.while_loop(condition, body, init_state)  with tf.Session() as sess:  final_result_values = final_result.eval()  print(final_result_values)  
 [2. 2. 2. 8. 2. 0. 2. 2. 0. 0.]