Tensorflow: ООМ при слишком большом размере пакета

#tensorflow #tensorflow2.0

#тензорный поток #tensorflow2.0

Вопрос:

Мой скрипт терпит неудачу из-за слишком высокого использования памяти. Когда я уменьшаю размер пакета, это работает.

 @tf.function(autograph=not DEBUG)
def step(prev_state, input_b):
    input_b = tf.reshape(input_b, shape=[1,input_b.shape[0]])
    state = FastALIFStateTuple(v=prev_state[0], z=prev_state[1], b=prev_state[2], r=prev_state[3])
    new_b = self.decay_b * state.b   (tf.ones(shape=[self.units],dtype=tf.float32) - self.decay_b) * state.z
    thr = self.thr   new_b * self.beta
    z = state.z
    i_in = tf.matmul(input_b, W_in)
    i_rec = tf.matmul(z, W_rec)
    i_t = i_in   i_rec
    I_reset = z * thr * self.dt
    new_v = self._decay * state.v   (1 - self._decay) * i_t - I_reset
    # Spike generation
    is_refractory = tf.greater(state.r, .1)
    zeros_like_spikes = tf.zeros_like(z)
    new_z = tf.where(is_refractory, zeros_like_spikes, self.compute_z(new_v, thr))
    new_r = tf.clip_by_value(state.r   self.n_refractory * new_z - 1,
                            0., float(self.n_refractory))
    return [new_v, new_z, new_b, new_r]

@tf.function(autograph=not DEBUG)
def evolve_single(inputs):
    accumulated_state = tf.scan(step, inputs, initializer=state0)
    Z = tf.squeeze(accumulated_state[1]) # -> [T,units]
    if self.model_settings['avg_spikes']:
        Z = tf.reshape(tf.reduce_mean(Z, axis=0), shape=(1,-1))
    out = tf.matmul(Z, W_out)   b_out
    return out # - [BS,Num_labels]

# # - Using a simple loop
# out_store = []
# for i in range(fingerprint_3d.shape[0]):
#     out_store.append(tf.squeeze(evolve_single(fingerprint_3d[i,:,:])))

# return tf.reshape(out_store, shape=[fingerprint_3d.shape[0],self.d_out])

final_out = tf.squeeze(tf.map_fn(evolve_single, fingerprint_3d)) # -> [BS,T,self.units]
return final_out
  

Этот фрагмент кода находится внутри tf.function, но я опустил его, так как не думаю, что это актуально.
Как видно, я запускаю код на fingerprint_3d тензоре, который имеет размерность [BatchSize,Time,InputDimension], например [50,100,20] . Когда я запускаю это с BatchSize < 10, все работает нормально, хотя tf.scan для этого уже используется много памяти.
Когда я теперь выполняю код для пакета размером 50, внезапно я получаю ООМ, хотя я выполняю его итеративно (здесь прокомментировано).
Как я должен выполнить этот код, чтобы размер пакета не имел значения?
Возможно, tensorflow распараллеливает мой цикл for, чтобы он выполнялся одновременно над несколькими пакетами?

Еще один не связанный с этим вопрос заключается в следующем: какую функцию вместо tf.scan я должен использовать, если я хочу накопить только одну переменную состояния, по сравнению со случаем tf.scan , когда она просто накапливает все переменные состояния? Или это возможно с tf.scan помощью?

Ответ №1:

Как упоминалось в обсуждениях здесь, tf.foldl, tf.foldr и tf.scan требуют отслеживания всех значений для всех итераций, что необходимо для вычислений, таких как градиенты. Я не знаю никаких способов решения этой проблемы; тем не менее, мне также было бы интересно, есть ли у кого-нибудь лучший ответ, чем у меня.

Комментарии:

1. У меня есть другая версия кода, которая не использует быстрое выполнение, и я могу легко работать с размером пакета 50, что странно. Он использует функцию keras.backend.rnn() . Я пытался использовать его, но это также не решает проблему (он использует tf.while).

Ответ №2:

Когда я использовал

 @tf.function
def get_loss_and_gradients():
    with tf.GradientTape(persistent=False) as tape:
        logits, spikes = rnn.call(fingerprint_input=graz_dict["train_input"], W_in=W_in, W_rec=W_rec, W_out=W_out, b_out=b_out)
        loss = loss_normal(tf.cast(graz_dict["train_groundtruth"],dtype=tf.int32), logits)
    gradients = tape.gradient(loss, [W_in,W_rec,W_out,b_out])
    return loss, logits, spikes, gradients
  

это работает.
Когда я удаляю @tf.function декоратор, память взрывается. Поэтому действительно важно, чтобы tensorflow мог создавать график для ваших вычислений.