Есть ли какой-нибудь способ сохранить модели pytorch с их классами

#pytorch

#pytorch

Вопрос:

Когда я пытаюсь загрузить свою сохраненную модель, мне нужно импортировать ее класс. Например:

 from module import Net
torch.load('saved_model.pth')
 

Есть ли какой-нибудь способ избежать этого импорта?. Например, сохранение модели с помощью class или что-то еще?

Комментарии:

1. Где вы используете Net ?

2. Я загружаю модель в другой модуль, а не там, где была описана сеть

Ответ №1:

Если вы хотите просто загрузить модель в известный объект nn.Module, такой как net вы можете использовать torch.load_state_dict('saved_model.pth') . Если вы хотите сохранить всю модель, чтобы кто-то другой мог ее использовать, вам придется использовать для маринования:

 import pickle
 
net = Net()
with open('saved_model.pth', 'w') as filehandler: 
    pickle.dump(net, filehandler)
 

для загрузки:

 with open('saved_model.pth', 'w') as filehandler:
    net = pickle.load(filehandler)
 

Однако настоятельно рекомендуется не использовать pickle, так как это может сохранить пользовательские данные на вашем компьютере / среде и привести к тому, что они не будут работать на чужом компьютере. Если вам действительно необходимо использовать pickle, возможно, стоит посмотреть, можете ли вы отделить класс от нейронной сети, сохранить класс в файле pickle и параметры в файле torch.

Надеюсь, это поможет, а не просто то, что вы знаете.