Извлечение объединенного вывода из модели трансформатора Huggingface потребляет слишком много памяти

#memory-management #pytorch #ram #bert-language-model #huggingface-transformers

Вопрос:

Я пытаюсь использовать свою тонко настроенную модель Дистилберта для извлечения объединенных выходных данных моего набора данных. Иначе говоря, я пытаюсь извлечь скрытое состояние последнего слоя токена » [CLS]».

Я написал эту функцию для выполнения этой задачи:

 def getPooledOutputs(model, encoded_dataset, batch_size = 128):
  pooled_outputs = []
  print("total number of iters ", len(encoded_dataset['input_ids'])//batch_size   1)
  
  with torch.no_grad():
    for i in range(len(encoded_dataset['input_ids'])//batch_size   1):
      print(i)
      up_to = i*batch_size   batch_size
      if len(encoded_dataset['input_ids']) < up_to:
        up_to = len(encoded_dataset['input_ids'])
      input_ids = th.LongTensor(encoded_dataset['input_ids'][i*batch_size:up_to])
      attention_mask = th.LongTensor(encoded_dataset['attention_mask'][i*batch_size:up_to])

      start = time.time()
      embeddings = model.distilbert(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)[0][:,0] # Pooled output
      end = time.time()
      print("time for inference took:", end - start)
      pooled_outputs.extend(embeddings)

  return pooled_outputs
 

И я называю это так:

 train_set_embeddings = getPooledOutputs(model, chunked_encoded_dataset['train'])
 

Это, кажется, работает хорошо и дает мне желаемые результаты, за исключением того, что к моменту завершения выполнения кода программа израсходовала почти все 25 ГБ оперативной памяти!! Когда общий объем оперативной памяти, используемой для вывода функции (train_set_embeddings), составляет около 22 МБ.

Объем оперативной памяти, как правило, увеличивается к концу цикла, в котором он будет увеличиваться с пары ГБ до 18 ГБ.

Вот скриншот использования оперативной памяти, на котором вы можете увидеть всплески к концу запуска:

введите описание изображения здесь

Вот скриншот журнала выполнения на случай, если это может помочь: введите описание изображения здесь