не удалось загрузить модель pytorch для оценки

#python #pytorch #inference

Вопрос:

У меня .pth сохранена модель, и я пытаюсь загрузить ее, чтобы сделать вывод, используя следующий код

 model = GatherModel()
model.load_state_dict(torch.load('/content/CIGIN/weights/cigin.tar'))
 

и я получаю эту ошибку, показанную ниже. почему я получаю это.

 ---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-8-3bff0e426886> in <module>()
----> 1 model.load_state_dict(torch.load('/content/CIGIN/weights/cigin.tar'))

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1405         if len(error_msgs) > 0:
   1406             raise RuntimeError('Error(s) in loading state_dict for {}:nt{}'.format(
-> 1407                                self.__class__.__name__, "nt".join(error_msgs)))
   1408         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1409 

RuntimeError: Error(s) in loading state_dict for GatherModel:
    Missing key(s) in state_dict: "lin0.weight", "lin0.bias", "set2set.lstm.weight_ih_l0", "set2set.lstm.weight_hh_l0", "set2set.lstm.bias_ih_l0", "set2set.lstm.bias_hh_l0", "message_layer.weight", "message_layer.bias", "conv.bias", "conv.edge_func.0.weight", "conv.edge_func.0.bias", "conv.edge_func.2.weight", "conv.edge_func.2.bias". 
    Unexpected key(s) in state_dict: "solute_pass.U_0.weight", "solute_pass.U_0.bias", "solute_pass.U_1.weight", "solute_pass.U_1.bias", "solute_pass.U_2.weight", "solute_pass.U_2.bias", "solute_pass.M_0.weight", "solute_pass.M_0.bias", "solute_pass.M_1.weight", "solute_pass.M_1.bias", "solute_pass.M_2.weight", "solute_pass.M_2.bias", "solvent_pass.U_0.weight", "solvent_pass.U_0.bias", "solvent_pass.U_1.weight", "solvent_pass.U_1.bias", "solvent_pass.U_2.weight", "solvent_pass.U_2.bias", "solvent_pass.M_0.weight", "solvent_pass.M_0.bias", "solvent_pass.M_1.weight", "solvent_pass.M_1.bias", "solvent_pass.M_2.weight", "solvent_pass.M_2.bias", "lstm_solute.weight_ih_l0", "lstm_solute.weight_hh_l0", "lstm_solute.bias_ih_l0", "lstm_solute.bias_hh_l0", "lstm_solvent.weight_ih_l0", "lstm_solvent.weight_hh_l0", "lstm_solvent.bias_ih_l0", "lstm_solvent.bias_hh_l0", "lstm_gather_solute.weight_ih_l0", "lstm_gather_solute.weight_hh_l0", "lstm_gather_solute.bias_ih_l0", "lstm_gather_solute.bias_hh_l0", "lstm_gather_solvent.weight_ih_l0", "lstm_gather_solvent.weight_hh_l0", "lstm_gather_solvent.bias_ih_l0", "lstm_gather_solvent.bias_hh_l0", "first_layer.weight", "first_layer.bias", "second_layer.weight", "second_layer.bias", "third_layer.weight", "third_layer.bias", "fourth_layer.weight", "fourth_layer.bias". 
 

Я пробовал использовать strict=False в state_dict, но я получаю эту ошибку

 _IncompatibleKeys(missing_keys=['lin0.weight', 'lin0.bias', 'set2set.lstm.weight_ih_l0', 'set2set.lstm.weight_hh_l0', 'set2set.lstm.bias_ih_l0', 'set2set.lstm.bias_hh_l0', 'message_layer.weight', 'message_layer.bias', 'conv.bias', 'conv.edge_func.0.weight', 'conv.edge_func.0.bias', 'conv.edge_func.2.weight', 'conv.edge_func.2.bias'], unexpected_keys=['solute_pass.U_0.weight', 'solute_pass.U_0.bias', 'solute_pass.U_1.weight', 'solute_pass.U_1.bias', 'solute_pass.U_2.weight', 'solute_pass.U_2.bias', 'solute_pass.M_0.weight', 'solute_pass.M_0.bias', 'solute_pass.M_1.weight', 'solute_pass.M_1.bias', 'solute_pass.M_2.weight', 'solute_pass.M_2.bias', 'solvent_pass.U_0.weight', 'solvent_pass.U_0.bias', 'solvent_pass.U_1.weight', 'solvent_pass.U_1.bias', 'solvent_pass.U_2.weight', 'solvent_pass.U_2.bias', 'solvent_pass.M_0.weight', 'solvent_pass.M_0.bias', 'solvent_pass.M_1.weight', 'solvent_pass.M_1.bias', 'solvent_pass.M_2.weight', 'solvent_pass.M_2.bias', 'lstm_solute.weight_ih_l0', 'lstm_solute.weight_hh_l0', 'lstm_solute.bias_ih_l0', 'lstm_solute.bias_hh_l0', 'lstm_solvent.weight_ih_l0', 'lstm_solvent.weight_hh_l0', 'lstm_solvent.bias_ih_l0', 'lstm_solvent.bias_hh_l0', 'lstm_gather_solute.weight_ih_l0', 'lstm_gather_solute.weight_hh_l0', 'lstm_gather_solute.bias_ih_l0', 'lstm_gather_solute.bias_hh_l0', 'lstm_gather_solvent.weight_ih_l0', 'lstm_gather_solvent.weight_hh_l0', 'lstm_gather_solvent.bias_ih_l0', 'lstm_gather_solvent.bias_hh_l0', 'first_layer.weight', 'first_layer.bias', 'second_layer.weight', 'second_layer.bias', 'third_layer.weight', 'third_layer.bias', 'fourth_layer.weight', 'fourth_layer.bias'])
 

Ответ №1:

Ошибка в основном говорит о том , что существуют веса, определенные используемой вами архитектурой, которых нет в state_dict , а также есть веса, которые не определены архитектурой, но присутствуют в state_dict . Вы уверены, что все, что определяется GatherModel() , является той же архитектурой, которая была создана state_dict в первую очередь? Потому что эта ошибка указывает на то, что ответ отрицательный.