#python #machine-learning #pytorch
#python #машинное обучение #pytorch
Вопрос:
Я пытаюсь загрузить предварительно подготовленную модель с помощью torch.load.
Я получаю следующую ошибку:
ModuleNotFoundError: No module named 'utils'
Я проверил, что используемый мной путь правильный, открыв его из командной строки. Что может быть причиной этого?
Вот мой код:
import torch
import sys
PATH = './gan.pth'
model = torch.load(PATH)
model.eval()
Редактировать:
Весь стек ошибок:
Traceback (most recent call last):
File "load.py", line 6, in <module>
model = torch.load(PATH)
File "C:Usersuseranaconda3envspytorch-flasklibsite-packagestorchserialization.py", line 595, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File "C:Usersuseranaconda3envspytorch-flasklibsite-packagestorchserialization.py", line 774, in _legacy_load
result = unpickler.load()
ModuleNotFoundError: No module named 'utils'
Комментарии:
1. Вы уверены, что можете запустить этот конкретный файл (без
model = torch.load(PATH); model.eval()
)? Какие-либо другие зависимости (импорт), которые вы не упомянули?2. @Ivan Да. Файл запускается без последних двух строк.
3. Можете ли вы предоставить нам весь стек ошибок?
4. @Ivan добавлено в нижней части поста
5. Есть ли у вас какие-либо идеи о том, как этот
.pth
файл был сохранен (какая функция использовалась)?
Ответ №1:
РЕДАКТИРОВАТЬ этот ответ не дает ответа на вопрос, но решает другую проблему в данном коде
в .pth
файле хранятся только параметры модели, а не сама модель. Когда вы хотите загрузить модель, вам понадобится .pt/-h
файл и код python вашего класса модели. Затем вы можете загрузить его следующим образом:
# your model
class YourModel(nn.Modules):
def __init__(self):
super(YourModel, self).__init__()
. . .
def forward(self, x):
. . .
# the pytorch save-file in which you stored your trained model
model_file = "<your path>"
model = Model()
model = model.load_state_dict(torch.load(model_file))
model.eval()
Комментарии:
1. Хотя это верно, это не дает ответа на вопрос. Проблема, похоже, в
model = torch.load(PATH)
. Мы не можем догадаться, что внутри этого.pth
…2. да, я только что заметил, что ошибка не возникает из-за этого
Ответ №2:
У меня была такая же точная ошибка, и мне было интересно, в чем проблема. Оказывается, проблема в том, что данные, сохраненные с torch.load()
помощью модуля utils
, необходимы.
Пример:
from utils import some_function
model = some_function()
torch.save(model)
При сохранении с помощью torch в данном примере он распознает, что модуль utils использовался для получения желаемых данных. Таким образом, при загрузке файла ‘.pth’ вам необходимо импортировать тот же модуль utils
.