Как использовать tf.gather в сочетании с tf.where

#python #tensorflow

#python #тензорный поток

Вопрос:

Следующее не работает из-за формы tf.where() . Есть ли хороший способ исправить это?

Я хочу, чтобы значения tensor_y where tensor_x выполняли условие (например, == значение ). Важно, что тензоры имеют batch_dims = 1 .

 tensor_x = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
tensor_y = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
trues = tf.where(tensor_y ==1)


new_tensor = tf.gather(tensor_y, axis=-1, indices = trues,batch_dims=1)


  

То, что я делаю сейчас, работает *, но, я думаю, это не так эффективно:

 new_tensor = tf.stack([tf.gather(tensor_y[i,:], tf.where(tensor_x[i,:] == 1)) for i in range(tensor_x.shape[0])])

  

* иногда (я не знаю, при каких условиях) Я получаю его ошибку:

Формы всех входных данных должны совпадать: values[0].shape = [3,1] != values[1].shape = [6,1] [Op:Pack] name: stack

Ответ №1:

Это то, что вам нужно?

 tensor_x = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
tensor_y = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)

new_tensor = tensor_y[(tensor_x==1)]
  

Комментарии:

1. Это работает в среде tensorflow. Но за пределами этой среды я получаю эту ошибку: Value passed to parameter 'begin' has DataType bool not in list of allowed values: int32, int64