#python #tensorflow
#python #тензорный поток
Вопрос:
Следующее не работает из-за формы tf.where()
. Есть ли хороший способ исправить это?
Я хочу, чтобы значения tensor_y
where tensor_x
выполняли условие (например, == значение ). Важно, что тензоры имеют batch_dims = 1
.
tensor_x = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
tensor_y = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
trues = tf.where(tensor_y ==1)
new_tensor = tf.gather(tensor_y, axis=-1, indices = trues,batch_dims=1)
То, что я делаю сейчас, работает *, но, я думаю, это не так эффективно:
new_tensor = tf.stack([tf.gather(tensor_y[i,:], tf.where(tensor_x[i,:] == 1)) for i in range(tensor_x.shape[0])])
* иногда (я не знаю, при каких условиях) Я получаю его ошибку:
Формы всех входных данных должны совпадать: values[0].shape = [3,1] != values[1].shape = [6,1] [Op:Pack] name: stack
Ответ №1:
Это то, что вам нужно?
tensor_x = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
tensor_y = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
new_tensor = tensor_y[(tensor_x==1)]
Комментарии:
1. Это работает в среде tensorflow. Но за пределами этой среды я получаю эту ошибку:
Value passed to parameter 'begin' has DataType bool not in list of allowed values: int32, int64