использование rnn_cell внутри tf.при получении ValueError: две структуры не имеют одинакового количества элементов

#python #tensorflow #recurrent-neural-network

#python #tensorflow #рекуррентная нейронная сеть

Вопрос:

Учитывая data = tf.placeholder(tf.float32, [2, None, 3]) (batch_size * time_step * feature_size), в идеале я хочу tf.unstack(data, axis = 1) получить несколько тензоров, каждый из которых имеет форму [2,3] , чтобы позже передать их в rnn с циклом for, например

 for rnn_input in rnn_inputs:
    state = rnn_cell(rnn_input, state)
  

Использование высокоуровневого API, такого как tf.nn.dynamic_rnn, исключено, поэтому я создаю обходной путь, подобный

 import tensorflow as tf


data = tf.placeholder(tf.float32, [2, None, 3])

step_number = tf.placeholder(tf.int32, None)

loop_counter_inital = tf.constant(0)

initi_state = tf.zeros([2,3], tf.float32)

def while_condition(loop_counter, rnn_states):
    return loop_counter < step_number

def while_body(loop_counter, rnn_states):
    loop_counter_current = loop_counter

    current_states = tf.gather_nd(data, tf.stack([tf.range(0, 2), tf.zeros([2], tf.int32) loop_counter_current], axis=1))     

    cell = tf.nn.rnn_cell.BasicRNNCell(3)

    rnn_states = cell(current_states, rnn_states)

    return [loop_counter_current, rnn_states]


_, _states = tf.while_loop(while_condition, while_body, 
                   loop_vars=[loop_counter_inital, initi_state], 
                   shape_invariants=[loop_counter_inital.shape, tf.TensorShape([2, 3])])

with tf.Session() as sess:    

    sess.run(tf.global_variables_initializer())

    print (sess.run(_states, feed_dict={data:[[[3,1,6],[4,1,2]],[[5,8,1],[0,5,2]]], step_number:2 }))
  

Идея состоит в том, чтобы перебирать каждую строку в каждом из 2D тензоров data , чтобы получить функции для каждого временного шага. Я получил ошибку

 First structure (2 elements): [<tf.Tensor 'while/Identity:0' shape=() dtype=int32>, <tf.Tensor 'while/Identity_1:0' shape=(2, 3) dtype=float32>]

Second structure (3 elements): [<tf.Tensor 'while/Identity:0' shape=() dtype=int32>, (<tf.Tensor 'while/basic_rnn_cell/Tanh:0' shape=(2, 3) dtype=float32>, <tf.Tensor 'while/basic_rnn_cell/Tanh:0' shape=(2, 3) dtype=float32>)]
  

Кажется, есть несколько связанных сообщений. На самом деле ничего не сработало. Кто-нибудь может помочь?

Ответ №1:

Вам нужно знать, что каждая BasicRNNCell будет реализована call() с подписью (output, next_state) = call(input, state) . Это означает, что ваш результат представляет собой список фигур ((?,unit),(?,unit)) . Итак, вам нужно сделать следующее.

 rnn_states = cell(current_states, rnn_states)[1]
  

Здесь вы также допустили ошибку. Вы забыли добавить 1 к loop_counter_current .

 return [loop_counter_current 1, rnn_states]
  

Добавить

Первая структура представляет начальное значение переданного вами параметра loop_vars , который содержит начальные значения loop_counter_inital и initi_state . Таким образом, его структура соответствует следующему.

 [
<tf.Tensor 'while/Identity:0' shape=() dtype=int32>  #---> loop_counter_inital
, <tf.Tensor 'while/Identity_1:0' shape=(2, 3) dtype=float32>  #---> initi_state
]
  

Вторая структура представляет параметр loop_vars после цикла. Его результаты соответствуют следующему, основанному на предыдущих ошибках.

 [
<tf.Tensor 'while/Identity:0' shape=() dtype=int32>  #---> loop_counter_inital
, (<tf.Tensor 'while/basic_rnn_cell/Tanh:0' shape=(2, 3) dtype=float32>  #---> output
, <tf.Tensor 'while/basic_rnn_cell/Tanh:0' shape=(2, 3) dtype=float32>)  #---> initi_state
]
  

Комментарии:

1. Хорошо, спасибо, в любом случае, к чему относится первая, вторая структура?

2. @user1935724 Я добавил это к ответу.