#python #pytorch #object-detection
Вопрос:
При запуске
`net = load_net('../input/wheat-effdet5-fold0-best-checkpoint/fold0-best-all-states.bin')`
в этом ноутбуке Kaggle https://www.kaggle.com/shonenkov/inference-efficientdet/notebook я получаю следующую ошибку:
RuntimeError: Error(s) in loading state_dict for EfficientDet:
size mismatch for class_net.predict.conv_pw.weight: copying a param with shape torch.Size([9, 288, 1, 1]) from checkpoint, the shape in current model is torch.Size([810, 288, 1, 1]).
size mismatch for class_net.predict.conv_pw.bias: copying a param with shape torch.Size([9]) from checkpoint, the shape in current model is torch.Size([810]).
def load_net(checkpoint_path):
    """Build a single-class EfficientDet-D5, load fine-tuned weights, wrap for eval.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load`` whose
            top-level mapping stores the model weights under the
            ``'model_state_dict'`` key.

    Returns:
        The network wrapped in ``DetBenchEval``, switched to eval mode and
        moved to the GPU via ``.cuda()``.

    Raises:
        RuntimeError: raised by ``load_state_dict`` when the checkpoint's
            tensor shapes do not match the model built here.
        KeyError: if the checkpoint has no ``'model_state_dict'`` entry.
    """
    config = get_efficientdet_config('tf_efficientdet_d5')
    # Set the class count BEFORE building the model, so the head that
    # EfficientDet creates by default is already sized for one class.
    # The reported error shows the default head with 810 output channels
    # (presumably 90 classes x 9 anchors) versus 9 in the checkpoint
    # (1 class x 9 anchors). TODO confirm against the installed effdet version.
    config.num_classes = 1
    net = EfficientDet(config, pretrained_backbone=False)
    # NOTE(review): image_size is still changed after construction, as in the
    # original code, so the rest of the network is built from the stock D5
    # config. Presumably this mirrors how the checkpoint's model was built — verify.
    config.image_size = 512
    # Replace the classification head with one using the batch-norm
    # eps/momentum chosen in the notebook.
    net.class_net = HeadNet(
        config,
        num_outputs=config.num_classes,
        norm_kwargs=dict(eps=.001, momentum=.01),
    )
    # Load onto the CPU first. The file may hold more than the model weights
    # (only 'model_state_dict' is used below), and this avoids placing all of
    # it on the GPU. The model itself is moved to the GPU at the end.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    # Drop the rest of the checkpoint eagerly to free memory.
    del checkpoint
    gc.collect()
    net = DetBenchEval(net, config)
    net.eval()
    return net.cuda()
# Load the fold-0 best checkpoint. The path must be a single string literal;
# splitting a quoted string across lines is a syntax error.
net = load_net(
    '../input/wheat-effdet5-fold0-best-checkpoint/fold0-best-all-states.bin'
)