tensorflow потребляет слишком много памяти при использовании tf.map_fn

#python-3.x #tensorflow

Вопрос:

У меня есть следующий код:

 @tf.function  def c(self, point_instances, instances, alpha):  def inner_comp(j):  print(j.shape)  print(tf.transpose(instances).shape)  print(alpha.shape)  print(tf.linalg.matvec(tf.transpose(instances), alpha).shape)  tmp = tf.tensordot(j, tf.linalg.matvec(tf.transpose(instances), alpha), 1)  print(tmp.shape)  return tmp   return tf.reduce_max(tf.abs(tf.map_fn(inner_comp, point_instances)))  

Я заметил, что при tf.map_fn завершении второй итерации он погибает из-за нехватки памяти. Таким образом, вывод консоли будет выглядеть следующим образом:

 (245245,) (245245, 460) (460,) (245245,) ()  (245245,) (245245, 1040) (1040,) (245245,) ()   Killed  

Хотя мои данные являются крупномасштабными, результат inner_comp всегда скалярный. Я пытался выделить 200 ГБ памяти, но все равно ее убивают. В чем может быть причина?