Tensorflow сообщает «Ошибка типа: список тензоров, когда ожидается одиночный тензор» при использовании tf.cond()

#python #tensorflow

#python #tensorflow

Вопрос:

Я использую Tensorflow для кодирования модели. Часть моего условного выражения, такого как:

 new_shape = tf.cond(tf.equal(tf.shape(src_shape)[0], 2), lambda: src_shape, lambda: tf.constant([1, src_shape[0]]))
  

и src_shape является результатом tf.shape() .

Он сообщает TypeError: List of Tensors when single Tensor expected . Я знаю, что это потому, что tf.constant([1, src_shape[0]]) это список тензоров, но я не знаю, как реализовать мой код законным способом.

Я попытался удалить tf.constant() подобное

 new_shape = tf.cond(tf.equal(tf.shape(src_shape)[0], 2), lambda: src_shape, lambda: [1, src_shape[0]])
  

но он сообщает ValueError: Incompatible return values of true_fn and false_fn: The two structures don't have the same nested structure.

Ответ №1:

Одним из способов было бы использовать tf.stack, который объединяет список тензоров ранга R в один тензор ранга (R 1).

 lambda: tf.stack([1, src_shape[0]], axis=0)
  

Другим решением было бы использовать tf.concat с использованием правильных команд tf.reshape.

Ответ №2:

Я пробовал, что tf.convert_to_tensor([1, src_shape[0]]) работает. Это альтернативное решение.