XLA rng-bit-generator занимает слишком много памяти

#tensorflow #pytorch #tpu #xla

#tensorflow #pytorch #tpu #xla

Вопрос:

XLA выделяет 4G памяти для этого тензора. Размер которого, похоже, зависит от размера пакета. Для меня это не имеет смысла, похоже, он не является частью графика модели, который будет храниться в HBM. Я использую TPUv3.

Я не использую никаких случайных операций, кроме инициализации модели. Более того, я объявил bfloat16 для всех весов, но это тензор u32.

   Largest program allocations in hbm:

  1. Size: 4.00G
     Shape: u32[128,8,1024,1024]{3,2,1,0:T(8,128)}
     Unpadded size: 4.00G
     XLA label: %rng-bit-generator = (u32[2,128]{1,0:T(2,128)}, u32[128,8,1024,1024]{3,2,1,0:T(8,128)}) rng-bit-generator(u32[2,128]{1,0:T(2,128)} %fusion.2446), algorithm=rng_default
     Allocation type: HLO temp
     ==========================
  

Что может быть причиной вышеуказанного выделения? Я использую pixelsnail из: https://github.com/kamenbliznashki/pixel_models

Вопросы:

  • Почему этот тензор имеет тип u32 , когда все мои определения веса / модели (включая глобальный флаг окружающей среды) используют BF16?
  • Почему используется rng-bit-generator?

Комментарии:

1. Выглядит разумно: 4 ГБ = 128 * 8 * 1024 * 1024 * 4

2. @Andrey это правильно, без сомнения. Однако я не могу найти, соответствует ли он какому-либо компоненту моей модели. Более того, rng-bit-generator еще более неясен относительно того, зачем он нужен. Также он использует u32, а не указанный мной bfloat16.