Как изменить loss.backward() в Pytorch для учета np.nan?

#python #neural-network #pytorch #loss-function #backpropagation

Вопрос:

Я тренирую простую нейронную сеть с помощью Pytorch. Мои входные данные что-то вроде

 [10.2, nan] [10.0, 5.0] [nan, 3.2]  

Где первый индекс всегда вдвое превышает второй индекс. Я могу обучить нейронную сеть, подавая примеры один за другим, такие как (10.0, 0), (5.0, 1), где 0 и 1 соответствуют их индексам.

Однако я хочу обучать массивы и выводить целые массивы. Так что моя нейронная сеть, например, будет предсказывать [10.2, 5.1] для первой записи.

Моя текущая проблема заключается в том, что потеря является тензором([nan]). Таким образом, функция loss.backward() не работает со следующей ошибкой:

 RuntimeError Traceback (most recent call last) lt;ipython-input-62-03be43016acagt; in lt;modulegt;()  43 if training:  44 scheduler.step() ---gt; 45 loss.backward()  46 optimizer.step()  47 lr_history.extend(scheduler.get_lr())  1 frames /usr/local/lib/python3.7/dist-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)  253 create_graph=create_graph,  254 inputs=inputs) --gt; 255 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)  256   257 def register_hook(self, hook):  /usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)  147 Variable._execution_engine.run_backward(  148 tensors, grad_tensors_, retain_graph, create_graph, inputs, --gt; 149 allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag  150   151   RuntimeError: Found dtype Double but expected Float  

Любые советы о том, как изменить функцию loss.backward () (или как создать свою собственную), чтобы она игнорировала функцию nan?

Заранее спасибо!

Комментарии:

1. Почему бы вам не исправить свои ярлыки, если вы знаете, что между этими двумя классами существует такая связь? Важны ли НАНЫ?

2. Извините, я не понимаю. Что вы предлагаете? Мне нужно ввести весь массив в качестве входных данных (например, [10, nan]).

3. Да, но ваши ярлыки подразумеваются? Вы можете просто сделать значения nan их фактическими значениями, поскольку они имеют определенную взаимосвязь. Тогда вам не нужно вносить изменения в свой код.