Формы (1,) и () несовместимы с оператором cond

#python-3.x #tensorflow

#python-3.x #тензорный поток

Вопрос:

Я хочу пропустить некоторые данные, которые имеют определенные метки (например, if label >= 7 или другие). Мой код здесь:

 true = tf.constant(True)
less_op = tf.less(label, tf.constant(delimiter))
label = tf.cast(
    tf.slice(record_bytes, [0], [label_bytes]), tf.int32)
tf.cond(less_op, lambda: true, lambda: true)
  

и в 4-й строке у меня ошибка: ValueError: Shapes (1,) and () are not compatible . Мое предположение, что это вызвано less_op (если я заменю его true кодом, работает). Также я выяснил, что есть некоторая проблема с label : code less_op = tf.less(tf.constant(1), tf.constant(delimiter)) работает отлично.

Ответ №1:

Tensorflow ожидает, что он будет иметь форму None или [], а не (1,) . Это странное поведение, которое должно быть исправлено в моем варианте, потому что tf.less возвращает тензор shape (1,), а не shape () .

Измените это:

 tf.cond(less_op, lambda: true, lambda: true)
  

к этому:

 tf.cond(tf.reshape(less_op,[]), lambda: true, lambda: true)
  

Комментарии:

1. tf.less возвращает тот же ранг, что и входные данные, поэтому sess.run(tf.less(5, 4)) имеет shape ()