#python-3.x #tensorflow
#python-3.x #тензорный поток
Вопрос:
Я хочу пропустить некоторые данные, которые имеют определенные метки (например, if label
>= 7 или другие). Мой код здесь:
true = tf.constant(True)
less_op = tf.less(label, tf.constant(delimiter))
label = tf.cast(
tf.slice(record_bytes, [0], [label_bytes]), tf.int32)
tf.cond(less_op, lambda: true, lambda: true)
и в 4-й строке у меня ошибка: ValueError: Shapes (1,) and () are not compatible
. Мое предположение, что это вызвано less_op (если я заменю его true
кодом, работает). Также я выяснил, что есть некоторая проблема с label
: code less_op = tf.less(tf.constant(1), tf.constant(delimiter))
работает отлично.
Ответ №1:
Tensorflow ожидает, что он будет иметь форму None или [], а не (1,) . Это странное поведение, которое должно быть исправлено в моем варианте, потому что tf.less возвращает тензор shape (1,), а не shape () .
Измените это:
tf.cond(less_op, lambda: true, lambda: true)
к этому:
tf.cond(tf.reshape(less_op,[]), lambda: true, lambda: true)
Комментарии:
1. tf.less возвращает тот же ранг, что и входные данные, поэтому
sess.run(tf.less(5, 4))
имеет shape ()