#python #pytorch
#python #pytorch
Вопрос:
Я очень новичок в PyTorch, и я столкнулся с ошибкой «Тензор индекса должен иметь то же количество измерений, что и входной тензор» при запуске моей нейронной сети. Это происходит, когда я вызываю экземпляр torch.gather(). Может ли кто-нибудь помочь мне разобраться в torch.gather () и объяснить причину этой ошибки?
Вот код, в котором возникает ошибка:
def learn(batch, optim, net, target_net, gamma, global_step, target_update):
my_loss = []
optim.zero_grad()
state, action, next_state, reward, done, next_action = batch
qval = net(state.float())
loss_a = torch.gather(qval, 3, action.view(-1,1,1,1)).squeeze() #Error happens here!
loss_b = reward gamma * torch.max(target_net(next_state.float()).cuda(), dim=3).values * (1 - done.int())
loss_val = torch.sum(( torch.abs(loss_a-loss_b) ))
loss_val /= 128
my_loss.append(loss_val.item())
loss_val.backward()
optim.step()
if global_step % target_update == 0:
target_network.load_state_dict(q_network.state_dict())
На случай, если это полезно, вот пакетная функция, которая создает пакет, из которого исходит действие:
def sample_batch(memory,batch_size):
indices = np.random.randint(0,len(memory), (batch_size,))
state = torch.stack([memory[i][0] for i in indices])
action = torch.tensor([memory[i][1] for i in indices], dtype = torch.long)
next_state = torch.stack([memory[i][2] for i in indices])
reward = torch.tensor([memory[i][3] for i in indices], dtype = torch.float)
done = torch.tensor([memory[i][4] for i in indices], dtype = torch.float)
next_action = torch.tensor([memory[i][5] for i in indices], dtype = torch.long)
return state,action,next_state,reward,done,next_action
Когда я распечатываю разные формы ‘qvals’, ‘action’ и ‘action.view (-1,1,1,1)’, это результат:
qval torch.Size([10, 225])
act view torch.Size([10, 1, 1, 1])
action shape torch.Size([10])
Приветствуется любое объяснение того, что вызывает эту ошибку! Я хочу больше понять, что происходит в коде, а также как устранить проблему. Спасибо!
Ответ №1:
Torch.gather описан здесь. Если мы возьмем ваш код, эта строка
torch.gather(qval, 3, action.view(-1,1,1,1))
эквивалентно
act_view = action.view(10,1,1,1)
out = torch.zeros_like(act_view)
for i in range(10):
for j in range(1):
for k in range(1):
for p in range(1):
out[i,j,k,p] = qval[i,j,k, act_view[i,j,k,p]]
return out
что, очевидно, имеет очень мало смысла. В частности, qval
не является 4-D и, следовательно, не может быть проиндексирован подобным образом. Количество for
циклов определяется формой ваших входных тензоров, и все они должны иметь одинаковое количество измерений, чтобы это работало (кстати, об этом говорит ваша ошибка). Здесь qval
это 2D и act_view
это 4D.
Я не уверен, что вы хотели с этим сделать, но если вы можете объяснить свою цель и удалить весь бесполезный материал в вашем примере (в основном код, связанный с обучением и обратной обработкой), чтобы получить минимально воспроизводимый пример, я мог бы помочь вам найти правильный способ сделать это 🙂