#python #neural-network #pytorch #loss-function #backpropagation
Вопрос:
Я тренирую простую нейронную сеть с помощью Pytorch. Мои входные данные что-то вроде
[10.2, nan] [10.0, 5.0] [nan, 3.2]
Где первый индекс всегда вдвое превышает второй индекс. Я могу обучить нейронную сеть, подавая примеры один за другим, такие как (10.0, 0), (5.0, 1), где 0 и 1 соответствуют их индексам.
Однако я хочу обучать массивы и выводить целые массивы. Так что моя нейронная сеть, например, будет предсказывать [10.2, 5.1] для первой записи.
Моя текущая проблема заключается в том, что потеря является тензором([nan]). Таким образом, функция loss.backward() не работает со следующей ошибкой:
RuntimeError Traceback (most recent call last) lt;ipython-input-62-03be43016acagt; in lt;modulegt;() 43 if training: 44 scheduler.step() ---gt; 45 loss.backward() 46 optimizer.step() 47 lr_history.extend(scheduler.get_lr()) 1 frames /usr/local/lib/python3.7/dist-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs) 253 create_graph=create_graph, 254 inputs=inputs) --gt; 255 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs) 256 257 def register_hook(self, hook): /usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs) 147 Variable._execution_engine.run_backward( 148 tensors, grad_tensors_, retain_graph, create_graph, inputs, --gt; 149 allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag 150 151 RuntimeError: Found dtype Double but expected Float
Любые советы о том, как изменить функцию loss.backward () (или как создать свою собственную), чтобы она игнорировала функцию nan?
Заранее спасибо!
Комментарии:
1. Почему бы вам не исправить свои ярлыки, если вы знаете, что между этими двумя классами существует такая связь? Важны ли НАНЫ?
2. Извините, я не понимаю. Что вы предлагаете? Мне нужно ввести весь массив в качестве входных данных (например, [10, nan]).
3. Да, но ваши ярлыки подразумеваются? Вы можете просто сделать значения nan их фактическими значениями, поскольку они имеют определенную взаимосвязь. Тогда вам не нужно вносить изменения в свой код.