#python #python-3.x #machine-learning #pytorch #pytorch-lightning
Вопрос:
Я пытаюсь использовать Catalyst
для обучения созданной мной нейронной сети на PyTorch.
Однако при первом запуске кода в Jupyter я всегда получаю AssertionError
без внятного объяснения. Когда я запускаю ячейку во второй раз, всё, похоже, работает нормально. Как такое может происходить и как это исправить?
from catalyst import dl

# Catalyst infers each loader's role (train / valid / infer) from the *prefix*
# of its key in the ``loaders`` dict: IRunner.on_loader_start evaluates
# ``loader_key.startswith("valid")`` and ``loader_key.startswith("infer")``
# (and, presumably, ``startswith("train")`` for ``is_train_loader``) and then
# asserts that at least one of them is true. The original key "val" matches
# none of these prefixes, which is exactly what raised the bare
# ``AssertionError``. Use the conventional "train" / "valid" keys instead.
runner = dl.SupervisedRunner(
    input_key="features",
    output_key="logits",
    target_key="targets",
    loss_key="loss",
)
runner.train(
    model=net,
    criterion=criterion,
    optimizer=optimizer,
    # Keys must start with "train" / "valid" / "infer" (see note above).
    loaders={"train": train, "valid": val},
    num_epochs=5,
    callbacks=[
        dl.AccuracyCallback(
            input_key="logits", target_key="targets", topk_args=(1, 3, 5)
        ),
        dl.ConfusionMatrixCallback(
            input_key="logits", target_key="targets", num_classes=6
        ),
    ],
    logdir="./logs",
    # Must name the validation key used in ``loaders`` above.
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    verbose=True,
    load_best_on_end=True,
    seed=42,
)
Полная трассировка ошибки (traceback):
AssertionError Traceback (most recent call last) <ipython-input-15-4eae1ab61f41> in <module> 2 3 runner = dl.SupervisedRunner(input_key="features", output_key="logits", target_key="targets", loss_key="loss") ----> 4 runner.train( 5 model= net, 6 criterion= criterion, ~/venv/lib/python3.9/site-packages/catalyst/runners/runner.py in train(self, loaders, model, engine, trial, criterion, optimizer, scheduler, callbacks, loggers, seed, hparams, num_epochs, logdir, valid_loader, valid_metric, minimize_valid_metric, verbose, timeit, check, overfit, load_best_on_end, fp16, amp, apex, ddp) 513 self._load_best_on_end = load_best_on_end 514 # run --> 515 self.run() 516 517 @torch.no_grad() ~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in run(self) 852 self.exception = ex 853 self._run_event("on_experiment_end") --> 854 self._run_event("on_exception") 855 return self 856 ~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_event(self, event) 786 getattr(callback, event)(self) 787 if _has_str_intersections(event, ("_end", "_exception")): --> 788 getattr(self, event)(self) 789 790 @abstractmethod ~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in on_exception(self, runner) 778 def on_exception(self, runner: "IRunner"): 779 """Event handler.""" --> 780 raise self.exception 781 782 def _run_event(self, event: str) -> None: ~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in run(self) 848 """ 849 try: --> 850 self._run_experiment() 851 except (Exception, KeyboardInterrupt) as ex: 852 self.exception = ex ~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_experiment(self) 838 self._run_event("on_experiment_start") 839 for self.stage_key in self.stages: --> 840 self.engine.spawn(self._run_stage) 841 self._run_event("on_experiment_end") 842 ~/venv/lib/python3.9/site-packages/catalyst/core/engine.py in spawn(self, fn, *args, **kwargs) 136 wrapped function (if needed). 
137 """ --> 138 return fn(*args, **kwargs) 139 140 def setup_process(self, rank: int = -1, world_size: int = 1): ~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_stage(self, rank, world_size) 829 self._run_event("on_stage_start") 830 while self.stage_epoch_step < self.stage_epoch_len: --> 831 self._run_epoch() 832 if self.need_early_stop: 833 self.need_early_stop = False ~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_epoch(self) 822 self._run_event("on_epoch_start") 823 for self.loader_key, self.loader in self.loaders.items(): --> 824 self._run_loader() 825 self._run_event("on_epoch_end") 826 ~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_loader(self) 809 # as it was noted in docs: 810 # https://pytorch.org/docs/stable/notes/amp_examples.html#typical-mixed-precision-training --> 811 self._run_event("on_loader_start") 812 with torch.set_grad_enabled(self.is_train_loader): 813 for self.loader_batch_step, self.batch in enumerate(self.loader): ~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_event(self, event) 782 def _run_event(self, event: str) -> None: 783 if _has_str_intersections(event, ("_start",)): --> 784 getattr(self, event)(self) 785 for callback in self.callbacks.values(): 786 getattr(callback, event)(self) ~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in on_loader_start(self, runner) 707 self.is_valid_loader: bool = self.loader_key.startswith("valid") 708 self.is_infer_loader: bool = self.loader_key.startswith("infer") --> 709 assert self.is_train_loader or self.is_valid_loader or self.is_infer_loader 710 self.loader_batch_size: int = _get_batch_size(self.loader) 711 self.loader_batch_len: int = len(self.loader) AssertionError: