как сохранить model.state_dict() в временном var для последующего использования?

#python #pytorch

Вопрос:

Я попытался временно сохранить диктат состояния моей модели в переменной и хотел позже восстановить его в своей модели, но содержимое этой переменной автоматически менялось по мере обновления модели.

Есть минимальный пример:

 import torch as t
import torch.nn as nn
from torch.optim import Adam


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(3, 2)

    def forward(self, x):
        return self.fc(x)


net = Net()
loss_fc = nn.MSELoss()
optimizer = Adam(net.parameters())

weights = net.state_dict()
print(weights)

x = t.rand((5, 3))
y = t.rand((5, 2))
loss = loss_fc(net(x), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()

print(weights)
 

Я думал, что два выхода будут одинаковыми, но я получил (выходы могут измениться из-за случайной инициализации)

 OrderedDict([('fc.weight', tensor([[-0.5557,  0.0544, -0.2277],
        [-0.0793,  0.4334, -0.1548]])), ('fc.bias', tensor([-0.2204,  0.2846]))])
OrderedDict([('fc.weight', tensor([[-0.5547,  0.0554, -0.2267],
        [-0.0783,  0.4344, -0.1538]])), ('fc.bias', tensor([-0.2194,  0.2856]))])
 

Содержание weights изменилось, что так странно.

Я также пробовал .copy() и t.no_grad() как следует, но они не помогли.

 with t.no_grad():
    weights = net.state_dict().copy()
 

Да, я знаю , что могу сохранить диктант состояния с помощью t.save() , но я просто хочу выяснить, что произошло в предыдущем примере.

Я использую Python 3.8.5 и Pytorch 1.8.1

Спасибо за любую помощь.

Ответ №1:

Вот как OrderedDict это работает. Вот более простой пример:

 from collections import OrderedDict

# a mutable variable
l = [1,2,3]

# an OrderedDict with an entry pointing to that mutable variable
x = OrderedDict([("a", l)])

# if you change the list
l[1] = 20

# the change is reflected in the OrderedDict
print(x)
# >> OrderedDict([('a', [1, 20, 3])])
 

Если вы хотите избежать этого, вам придется сделать deepcopy , а не мелкое copy :

 from copy import deepcopy
x2 = deepcopy(x)

print(x2)
# >> OrderedDict([('a', [1, 20, 3])])

# now, if you change the list
l[2] = 30

# you do not change your copy
print(x2)
# >> OrderedDict([('a', [1, 20, 3])])

# but you keep changing the original dict
print(x)
# >> OrderedDict([('a', [1, 20, 30])])
 

Поскольку Tensor это также изменчиво, в вашем случае ожидается такое же поведение. Поэтому вы можете использовать:

 from copy import deepcopy

weights = deepcopy(net.state_dict())
 

Комментарии:

1. Большое спасибо! Это помогает. Я думал .copy() , что это сработает, но он просто скопировал запись, указывающую на этот тензор.

2. @Vvvvvv рад это знать! Пожалуйста, подумайте о том, чтобы проголосовать и/или отметить как ответ, если это было полезно