RuntimeError: Error(s) in loading state_dict for EfficientDet

#python #pytorch #object-detection

Question:

I get this error when running

    net = load_net('../input/wheat-effdet5-fold0-best-checkpoint/fold0-best-all-states.bin')

in this notebook: https://www.kaggle.com/shonenkov/inference-efficientdet/notebook

    RuntimeError: Error(s) in loading state_dict for EfficientDet:
    size mismatch for class_net.predict.conv_pw.weight: copying a param with shape torch.Size([9, 288, 1, 1]) from checkpoint, the shape in current model is torch.Size([810, 288, 1, 1]).
    size mismatch for class_net.predict.conv_pw.bias: copying a param with shape torch.Size([9]) from checkpoint, the shape in current model is torch.Size([810]).
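
The two sizes in the message fit the usual layout of the class head's final convolution, which outputs num_classes × num_anchors channels. A minimal sketch of that arithmetic, assuming effdet's default anchors (3 aspect ratios × 3 scales, so 9 anchors per location) and the default 90 COCO classes of tf_efficientdet_d5; `class_head_channels` is a hypothetical helper, not part of effdet:

```python
def class_head_channels(num_classes, num_aspect_ratios=3, num_scales=3):
    """Output channels of the class head's final conv (hypothetical helper).

    Assumes one score per class for every anchor at each feature-map
    location, with num_anchors = num_aspect_ratios * num_scales.
    """
    num_anchors = num_aspect_ratios * num_scales
    return num_classes * num_anchors


# Checkpoint head: trained with num_classes = 1 (wheat heads only).
print(class_head_channels(1))   # 9
# Current model head: default COCO config with 90 classes.
print(class_head_channels(90))  # 810
```

Under those assumptions, the checkpoint contains a 1-class head, while the model receiving the weights still has a 90-class head.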
 
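    import gc

    import torch
    from effdet import get_efficientdet_config, EfficientDet, DetBenchEval
    from effdet.efficientdet import HeadNet


    def load_net(checkpoint_path):
        config = get_efficientdet_config('tf_efficientdet_d5')
        net = EfficientDet(config, pretrained_backbone=False)

        # Replace the default class head with a single-class head
        config.num_classes = 1
        config.image_size = 512
        net.class_net = HeadNet(
            config,
            num_outputs=config.num_classes,
            norm_kwargs=dict(eps=.001, momentum=.01),
        )

        checkpoint = torch.load(checkpoint_path)
        net.load_state_dict(checkpoint['model_state_dict'])

        # Drop the reference to the loaded checkpoint and let GC reclaim it
        del checkpoint
        gc.collect()

        # Wrap the model in effdet's inference bench, which turns raw outputs into detections
        net = DetBenchEval(net, config)
        net.eval()
        return net.cuda()


    net = load_net('../input/wheat-effdet5-fold0-best-checkpoint/fold0-best-all-states.bin')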
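
To see every offending parameter at once rather than reading them off the traceback, the two state dicts can be compared key by key. A minimal sketch; `find_shape_mismatches` and the `Fake` stand-in are hypothetical names, and `Fake` exists only so the example runs without PyTorch or effdet:

```python
from collections import namedtuple


def find_shape_mismatches(checkpoint_state, model_state):
    """Return (name, checkpoint_shape, model_shape) for every key present
    in both state dicts whose tensor shapes differ.

    Works with any mapping whose values expose a `.shape` attribute,
    e.g. torch tensors (torch.Size is a tuple subclass).
    """
    mismatches = []
    for name, ckpt_tensor in checkpoint_state.items():
        if name not in model_state:
            continue  # missing/unexpected keys are a separate problem
        ckpt_shape = tuple(ckpt_tensor.shape)
        model_shape = tuple(model_state[name].shape)
        if ckpt_shape != model_shape:
            mismatches.append((name, ckpt_shape, model_shape))
    return mismatches


# Hypothetical stand-in for tensors: only the .shape attribute is used.
Fake = namedtuple("Fake", "shape")

ckpt = {
    "class_net.predict.conv_pw.weight": Fake((9, 288, 1, 1)),
    "class_net.predict.conv_pw.bias": Fake((9,)),
}
model = {
    "class_net.predict.conv_pw.weight": Fake((810, 288, 1, 1)),
    "class_net.predict.conv_pw.bias": Fake((810,)),
}

for name, ckpt_shape, model_shape in find_shape_mismatches(ckpt, model):
    print(f"{name}: checkpoint {ckpt_shape} vs model {model_shape}")
```

With the real objects from the code above, the call would presumably be `find_shape_mismatches(checkpoint['model_state_dict'], net.state_dict())`, placed just before `net.load_state_dict(...)`.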