Будут ли tf.gradients проходить через tf.cond?

#tensorflow

#tensorflow

Вопрос:

Я хотел бы создать пару рекуррентных нейронных сетей, скажем, NN1 и NN2, где NN2 воспроизводит свои выходные данные с предыдущего временного шага и не обновляет свои веса на текущем временном шаге всякий раз, когда NN1 выводит значение, отличное от предыдущего временного шага.

Для этого я планировал использовать tf.cond() вместе с tf.stop_gradients() . Однако во всех игрушечных примерах, которые я запускал, я не могу заставить tf.gradients() пройти через tf.cond() : tf.gradients() просто возвращает [None] .

Вот простой игрушечный пример:

 import tensorflow as tf

x = tf.constant(5)
y = tf.constant(3)

mult = tf.multiply(x, y)
cond = tf.cond(pred = tf.constant(True),
               true_fn = lambda: mult,
               false_fn = lambda: mult)

grad = tf.gradients(cond, x) # Returns [None]
  

Вот еще один простой игрушечный пример, где я определяю true_fn и false_fn в tf.cond() (по-прежнему без кубиков):

 import tensorflow as tf

x = tf.constant(5)
y = tf.constant(3)
z = tf.constant(8)

cond = tf.cond(pred = x < y,
               true_fn = lambda: tf.add(x, z),
               false_fn = lambda: tf.square(y))

tf.gradients(cond, z) # Returns [None]
  

Изначально я думал, что градиент должен проходить через оба true_fn и и false_fn , но очевидно, что градиент вообще не течет. Это ожидаемое поведение градиентов, вычисленных через tf.cond() ? Может ли быть способ обойти эту проблему?

Ответ №1:

Да, градиенты будут проходить через tf.cond() . Вам просто нужно использовать значения с плавающей запятой вместо целых чисел и (предпочтительно) использовать переменные вместо констант:

 
import tensorflow as tf

x = tf.Variable(5.0, dtype=tf.float32)
y = tf.Variable(6.0, dtype=tf.float32)
z = tf.Variable(8.0, dtype=tf.float32)

cond = tf.cond(pred = x < y,
               true_fn = lambda: tf.add(x, z),
               false_fn = lambda: tf.square(y))

op = tf.gradients(cond, z) 
# Returns [<tf.Tensor 'gradients_1/cond_1/Add/Switch_1_grad/cond_grad:0' shape=() dtype=float32>]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(op)) # [1.0]