#pytorch
#pytorch
Вопрос:
Когда я пытаюсь загрузить свою сохраненную модель, мне нужно импортировать ее класс. Например:
from module import Net
torch.load('saved_model.pth')
Есть ли какой-нибудь способ избежать этого импорта?. Например, сохранение модели с помощью class или что-то еще?
Комментарии:
1. Где вы используете
Net
?2. Я загружаю модель в другой модуль, а не там, где была описана сеть
Ответ №1:
Если вы хотите просто загрузить модель в известный объект nn.Module, такой как net
вы можете использовать torch.load_state_dict('saved_model.pth')
. Если вы хотите сохранить всю модель, чтобы кто-то другой мог ее использовать, вам придется использовать для маринования:
import pickle
net = Net()
with open('saved_model.pth', 'w') as filehandler:
pickle.dump(net, filehandler)
для загрузки:
with open('saved_model.pth', 'w') as filehandler:
net = pickle.load(filehandler)
Однако настоятельно рекомендуется не использовать pickle, так как это может сохранить пользовательские данные на вашем компьютере / среде и привести к тому, что они не будут работать на чужом компьютере. Если вам действительно необходимо использовать pickle, возможно, стоит посмотреть, можете ли вы отделить класс от нейронной сети, сохранить класс в файле pickle и параметры в файле torch.
Надеюсь, это поможет, а не просто то, что вы знаете.