AssertionError в Catalyst SupervisedRunner при первом запуске

#python #python-3.x #machine-learning #pytorch #pytorch-lightning

Вопрос:

Я пытаюсь обучить с помощью Catalyst собственную нейронную сеть на PyTorch. Однако при первом запуске кода в Jupyter я всегда получаю AssertionError без какого-либо понятного пояснения. Когда я запускаю ячейку во второй раз, всё, похоже, работает нормально. Почему так происходит и как это исправить?

 from catalyst import dl  runner = dl.SupervisedRunner(input_key="features", output_key="logits", target_key="targets", loss_key="loss") runner.train(  model= net,  criterion= criterion,  optimizer= optimizer,  loaders= {"training": train, "val": val},   num_epochs=5,  callbacks=[   dl.AccuracyCallback(input_key="logits", target_key="targets",   topk_args=(1, 3, 5)),  dl.ConfusionMatrixCallback(input_key="logits",   target_key="targets",   num_classes=6),  ],  logdir="./logs",  valid_loader="val",   valid_metric="loss",  minimize_valid_metric=True,  verbose=True,  load_best_on_end=True,  seed= 42  )   

Полная трассировка ошибки:

AssertionError                            Traceback (most recent call last)
<ipython-input-15-4eae1ab61f41> in <module>
      2 
      3 runner = dl.SupervisedRunner(input_key="features", output_key="logits", target_key="targets", loss_key="loss")
----> 4 runner.train(
      5     model= net,
      6     criterion= criterion,

~/venv/lib/python3.9/site-packages/catalyst/runners/runner.py in train(self, loaders, model, engine, trial, criterion, optimizer, scheduler, callbacks, loggers, seed, hparams, num_epochs, logdir, valid_loader, valid_metric, minimize_valid_metric, verbose, timeit, check, overfit, load_best_on_end, fp16, amp, apex, ddp)
    513         self._load_best_on_end = load_best_on_end
    514         # run
--> 515         self.run()
    516 
    517     @torch.no_grad()

~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in run(self)
    852             self.exception = ex
    853             self._run_event("on_experiment_end")
--> 854             self._run_event("on_exception")
    855         return self
    856 

~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_event(self, event)
    786             getattr(callback, event)(self)
    787         if _has_str_intersections(event, ("_end", "_exception")):
--> 788             getattr(self, event)(self)
    789 
    790     @abstractmethod

~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in on_exception(self, runner)
    778     def on_exception(self, runner: "IRunner"):
    779         """Event handler."""
--> 780         raise self.exception
    781 
    782     def _run_event(self, event: str) -> None:

~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in run(self)
    848         """
    849         try:
--> 850             self._run_experiment()
    851         except (Exception, KeyboardInterrupt) as ex:
    852             self.exception = ex

~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_experiment(self)
    838         self._run_event("on_experiment_start")
    839         for self.stage_key in self.stages:
--> 840             self.engine.spawn(self._run_stage)
    841         self._run_event("on_experiment_end")
    842 

~/venv/lib/python3.9/site-packages/catalyst/core/engine.py in spawn(self, fn, *args, **kwargs)
    136             wrapped function (if needed).
    137         """
--> 138         return fn(*args, **kwargs)
    139 
    140     def setup_process(self, rank: int = -1, world_size: int = 1):

~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_stage(self, rank, world_size)
    829         self._run_event("on_stage_start")
    830         while self.stage_epoch_step < self.stage_epoch_len:
--> 831             self._run_epoch()
    832             if self.need_early_stop:
    833                 self.need_early_stop = False

~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_epoch(self)
    822         self._run_event("on_epoch_start")
    823         for self.loader_key, self.loader in self.loaders.items():
--> 824             self._run_loader()
    825         self._run_event("on_epoch_end")
    826 

~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_loader(self)
    809         # as it was noted in docs:
    810         # https://pytorch.org/docs/stable/notes/amp_examples.html#typical-mixed-precision-training
--> 811         self._run_event("on_loader_start")
    812         with torch.set_grad_enabled(self.is_train_loader):
    813             for self.loader_batch_step, self.batch in enumerate(self.loader):

~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_event(self, event)
    782     def _run_event(self, event: str) -> None:
    783         if _has_str_intersections(event, ("_start",)):
--> 784             getattr(self, event)(self)
    785         for callback in self.callbacks.values():
    786             getattr(callback, event)(self)

~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in on_loader_start(self, runner)
    707         self.is_valid_loader: bool = self.loader_key.startswith("valid")
    708         self.is_infer_loader: bool = self.loader_key.startswith("infer")
--> 709         assert self.is_train_loader or self.is_valid_loader or self.is_infer_loader
    710         self.loader_batch_size: int = _get_batch_size(self.loader)
    711         self.loader_batch_len: int = len(self.loader)

AssertionError: