#python #tensorflow
#python #tensorflow
Вопрос:
Предполагается, что у меня есть список тензоров tensorflow, я хочу динамически добавлять дополнительный тензор в этот список при определенных условиях. например, если максимальное скалярное произведение между каждым тензором в списке и этим дополнительным тензором больше 0, то этот дополнительный тензор добавляется в список. Вот код:
lists = []
for i in xrange(10):
a = tf.get_variable(name=str(i), shape=[3], dtype=tf.float32)
lists.append(a)
итак, прямо сейчас у нас есть список из 10 тензоров, каждый тензор имеет форму [3].
for j in xrange(11, 30):
b = tf.get_variable(name=str(j), shape=[3, 1], dtype=tf.float32)
c = tf.stack(lists)
e = tf.cond(tf.reduce_max(tf.reshape(lists, shape=[-1]), axis=0)>0.00, lambda: tf.stack(lists.append(tf.reshape(b, [-1]))), lambda: c)
lists = tf.unstack(e)
Однако у этого кода есть несколько проблем, прежде всего,
TypeError: 'NoneType' object has no attribute '__getitem__'
Это потому, что tf.stack(lists.append(tf.reshape(b, [-1])))
, lists.append(tf.reshape(b, [-1]))
является ‘нетипичным’.
Вторая проблема заключается в том, что даже если эта часть работает, то lists = tf.unstack(e)
имеет ошибку, потому что ValueError: Cannot infer num from shape (?, 3)
из-за tf.unstack()
не может работать с непереводимыми измерениями.
Ребята, пожалуйста, не могли бы вы научить меня, как реализовать эту функцию? Спасибо
Ответ №1:
Итак, у вас здесь как минимум две разные проблемы.
Первая проблема: я не понимаю, что reshape
вы делаете. Я бы использовал tensordot
вместо этого. И я бы не преобразовывал тензор обратно в список, если это не нужно.
Например:
c = tf.stack(lists) # shape [10,3]
for j in range(11, 30):
b = tf.get_variable(name=str(j), shape=[1, 3], dtype=tf.float32)
d = tf.tensordot(b, c, axes=[1,1]) # shape [1,10]
c = tf.cond(tf.reduce_max(d) > 0.00, lambda: tf.concat([c, b], 0), lambda: c) # shape [?,3]
Вторая проблема: преобразуйте тензор с непереводимыми измерениями в список. Есть много вопросов и ответов по этим темам:
http://www.google.com/search ?q=tensorflow unstack может не работать с непереводимыми измерениями
Надеюсь, это поможет.