Ошибка выполнения: одна из переменных, необходимых для вычисления градиента, была изменена с помощью операции на месте. Не могу понять, почему?

#pytorch #huggingface-transformers

Вопрос:

Пример использования: я обучаю модель 1 (используя train1) с определенной функцией потерь, которая включает тензор A. Я накапливаю убытки, а затем хочу выполнить обновление. Затем я тренирую вторую модель 2 (train2), в которой я хочу рассчитать градиенты wrt A, используя потери, рассчитанные в train2. Таким образом, я добавляю потерю 1 к потере 2.

 #reproduce error
from transformers import BertModel, BertForMaskedLM, BertConfig, EncoderDecoderModel
import torch
import torch.nn.functional as F
model1 = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints
model2 = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints


optimizer1 = torch.optim.Adam(model1.parameters(), lr=0.001)
A=torch.rand(1, requires_grad=True)
optimizer3 = torch.optim.SGD([A], lr=0.1)

en_input=torch.tensor([[1,2], [3,4]])
en_masks=torch.tensor([[0,0], [0,0]])
de_output=torch.tensor([[3,1], [4,2]])
de_masks=torch.tensor([[0,0], [0,0]])
lm_labels=torch.tensor([[5,7], [6,8]])

torch.autograd.set_detect_anomaly(True)

def train1():
  acc=torch.zeros(1)
  for i in range(2):
    optimizer1.zero_grad()
    out = model1(input_ids=en_input, attention_mask=en_masks, decoder_input_ids=de_output, 
                        decoder_attention_mask=de_masks, labels=lm_labels.clone())
          

    prediction_scores = out[1]
    predictions = F.log_softmax(prediction_scores, dim=2)
    p=((predictions.sum() - de_output.sum())*A).sum()
    p=torch.unsqueeze(p, dim=0)
    acc = torch.cat((p,acc)) # accumulating the loss 

  loss=acc.sum()
  A.retain_grad()
  loss.backward(retain_graph=True) 
  optimizer1.step() 
  return loss


def train2(loss1):
  for i in range (2):
     optimizer3.zero_grad()
     output = model2(input_ids=en_input, attention_mask=en_masks, 
                               decoder_input_ids=de_output, 
                      decoder_attention_mask=de_masks, labels=lm_labels.clone())
        
     prediction_scores_ = output[1]
     predictions_= F.log_softmax(prediction_scores_, dim=2)
     loss2=((predictions_.sum() - de_output.sum())).sum() loss1 # want to calculate gradients wrt A
     A.retain_grad()
     loss2.backward(inputs=[A], retain_graph=True) 
     optimizer3.step() #update A based on calculated gradients

loss1=train1()
train2(loss1)

 

Если это правильный метод, я не понимаю, что не так в моем коде? Если это неправильно, я был бы признателен, если бы кто-нибудь указал мне правильное направление

Трассировка ошибок

 /usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py:147: UserWarning: Error detected in MulBackward0. Traceback of forward call that caused the error:
  File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/usr/local/lib/python3.7/dist-packages/traitlets/config/application.py", line 845, in launch_instance
    app.start()
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelapp.py", line 499, in start
    self.io_loop.start()
  File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 132, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.7/asyncio/base_events.py", line 541, in run_forever
    self._run_once()
  File "/usr/lib/python3.7/asyncio/base_events.py", line 1786, in _run_once
    handle._run()
  File "/usr/lib/python3.7/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
  File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 122, in _handle_events
    handler_func(fileobj, events)
  File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 451, in _handle_events
    self._handle_recv()
  File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 480, in _handle_recv
    self._run_callback(callback, msg)
  File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 434, in _run_callback
    callback(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
    handler(stream, idents, msg)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/zmqshell.py", line 537, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2718, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2822, in run_ast_nodes
    if self.run_code(code, result):
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-66-c603f915c713>", line 78, in <module>
    loss1=train1()
  File "<ipython-input-66-c603f915c713>", line 25, in train1
    p=((predictions.sum() - de_output.sum())*A).sum()
 (Triggered internally at  /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:104.)
  allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-66-c603f915c713> in <module>()
     77 for i in range(2):
     78   loss1=train1()
---> 79   train2(loss1)

2 frames
<ipython-input-66-c603f915c713> in train2(loss1)
     69     print(A.grad)
     70     #loss2.grad(inputs=A,outputs=A, only_inputs=True)
---> 71     loss2.backward(inputs=[A],retain_graph=True) #calculates gradients # retain_graph=True #list(dec.parameters())
     72     print(A.grad)
     73     # torch.nn.utils.clip_grad_norm_(model1.parameters(), 1.0)

/usr/local/lib/python3.7/dist-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    243                 create_graph=create_graph,
    244                 inputs=inputs)
--> 245         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    246 
    247     def register_hook(self, hook):

/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    145     Variable._execution_engine.run_backward(
    146         tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 147         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
    148 
    149 

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1]] is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
 

Комментарии:

1. Пожалуйста, опубликуйте полную трассировку стека ошибок (отредактируйте свой вопрос, чтобы добавить его).

2. @cronoik Я добавил трассировку стека ошибок.

3. Не уверен, что я полностью понимаю это, но разве вам не нужно отсоединить loss1? Т. е. train2(loss1.detach())

4. Почему от меня потребуется отсоединить loss1 @cronoik? Не могли бы вы, пожалуйста, объяснить. Я думаю, что если я отсоединю его, то градиенты wrt A не будут рассчитываться, так как график вычислений будет нарушен из-за потерь 1.