Тензор индекса должен иметь то же количество измерений, что и ошибка входного тензора, возникшая при использовании torch.gather()

#python #pytorch

#python #pytorch

Вопрос:

Я очень новичок в PyTorch, и я столкнулся с ошибкой «Тензор индекса должен иметь то же количество измерений, что и входной тензор» при запуске моей нейронной сети. Это происходит, когда я вызываю экземпляр torch.gather(). Может ли кто-нибудь помочь мне разобраться в torch.gather () и объяснить причину этой ошибки?

Вот код, в котором возникает ошибка:

   def learn(batch, optim, net, target_net, gamma, global_step, target_update):
      my_loss = []
      optim.zero_grad()
  
      state, action, next_state, reward, done, next_action = batch
      qval = net(state.float())
  
      loss_a = torch.gather(qval, 3, action.view(-1,1,1,1)).squeeze() #Error happens here!

      loss_b = reward   gamma * torch.max(target_net(next_state.float()).cuda(), dim=3).values * (1 - done.int())
      loss_val = torch.sum(( torch.abs(loss_a-loss_b) ))
      loss_val /= 128
      my_loss.append(loss_val.item())
      loss_val.backward()
      optim.step()
      if global_step % target_update == 0:
          target_network.load_state_dict(q_network.state_dict())
  

На случай, если это полезно, вот пакетная функция, которая создает пакет, из которого исходит действие:

 def sample_batch(memory,batch_size):
    
    indices = np.random.randint(0,len(memory), (batch_size,))

    state = torch.stack([memory[i][0] for i in indices]) 
    action = torch.tensor([memory[i][1] for i in indices], dtype = torch.long)
    next_state = torch.stack([memory[i][2] for i in indices])
    reward = torch.tensor([memory[i][3] for i in indices], dtype = torch.float)
    done = torch.tensor([memory[i][4] for i in indices], dtype = torch.float)
    next_action = torch.tensor([memory[i][5] for i in indices], dtype = torch.long)

    return state,action,next_state,reward,done,next_action
  

Когда я распечатываю разные формы ‘qvals’, ‘action’ и ‘action.view (-1,1,1,1)’, это результат:

 qval torch.Size([10, 225])
act view torch.Size([10, 1, 1, 1])
action shape  torch.Size([10])
  

Приветствуется любое объяснение того, что вызывает эту ошибку! Я хочу больше понять, что происходит в коде, а также как устранить проблему. Спасибо!

Ответ №1:

Torch.gather описан здесь. Если мы возьмем ваш код, эта строка

 torch.gather(qval, 3, action.view(-1,1,1,1))
  

эквивалентно

 act_view = action.view(10,1,1,1)
out = torch.zeros_like(act_view)
for i in range(10):
    for j in range(1):
         for k in range(1):
              for p in range(1):
                   out[i,j,k,p] = qval[i,j,k, act_view[i,j,k,p]]
return out
  

что, очевидно, имеет очень мало смысла. В частности, qval не является 4-D и, следовательно, не может быть проиндексирован подобным образом. Количество for циклов определяется формой ваших входных тензоров, и все они должны иметь одинаковое количество измерений, чтобы это работало (кстати, об этом говорит ваша ошибка). Здесь qval это 2D и act_view это 4D.

Я не уверен, что вы хотели с этим сделать, но если вы можете объяснить свою цель и удалить весь бесполезный материал в вашем примере (в основном код, связанный с обучением и обратной обработкой), чтобы получить минимально воспроизводимый пример, я мог бы помочь вам найти правильный способ сделать это 🙂