TensorFlow 2.0 возвращает неожиданный вывод на dtype = int32 с помощью GradientTape

#tensorflow

#tensorflow

Вопрос:

Следующий код должен вывести градиент y = x * x для x = 2, то есть значение 4. Однако код выводит значение None при использовании TensorFlow 2.0.0-alpha0. Когда определение x изменяется на использование tf.float32 вместо tf.int32 , как показано в следующем фрагменте, вывод изменяется на правильное значение 4. Существует ли какая-либо документация, разъясняющая требование, чтобы тип данных был числом с плавающей запятой для корректной работы GradientTape в этом сценарии?

 print(tf.__version__)

x = tf.constant(2, dtype=tf.int32)

with tf.GradientTape() as tape:
  tape.watch(x)
  y = x ** 2
  print(tape.gradient(y, x))
  

выводит:

 2.0.0-alpha0
None
  

Обратите внимание на изменение на tf.float32 в следующем фрагменте:

 print(tf.__version__)

x = tf.constant(2, dtype=tf.float32)

with tf.GradientTape() as tape:
  tape.watch(x)
  y = x ** 2
  print(tape.gradient(y, x))
  

выводит:

 2.0.0-alpha0
tf.Tensor(4.0, shape=(), dtype=float32)
  

Ответ №1:

Причина в том, что tf.gradient не распространяется градиенты через целочисленные тензоры. На это была ссылка в этом выпуске github:

https://github.com/tensorflow/tensorflow/issues/20524