Что заставляет модуль nn Pytorch возвращать 1 или 0 для argmax этого массива?

#python #arrays #pytorch

Вопрос:

У меня есть массив, содержащий четыре случайных числа, которые функция argmax должна возвращать 0, 1, 2 или 3, но когда argmax вызывается из модели nn.Module, это всегда 0 или 1. Я просто хотел бы знать, как и почему он всегда получает 1 или 0 из четырех чисел в массиве.

Ниже у меня есть модуль nn.и сравнение случайного массива len 3, вычисленного внутри и снаружи модели с использованием функции act ( Net.act ).

 from torch import nn
import torch

import random


class Network(nn.Module):
    def __init__(self):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(4, 64),
            nn.Tanh(),
            nn.Linear(64, 2))

    def forward(self, x):
        return self.net(x)

    def act(self, obs):
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device)
        q_values = self(obs_t.unsqueeze(0))
        max_q_index = torch.argmax(q_values, dim=1)[0]
        action = max_q_index.detach().item()

        return action


Net = Network()
 

Это параллельное сравнение случайного массива len 3 внутри и снаружи функции act в nn.Модуле (Net).

 for _ in range(20):
    z = np.array([random.uniform(-1, 1) for _ in range(4)])

    obs_t = torch.as_tensor(z, dtype=torch.float32, device=device)
    q_values = (obs_t.unsqueeze(0))
    max_q_index = torch.argmax(q_values, dim=1)[0]
    action = max_q_index.detach().item()

    print(action, Net.act(z))
 

Выход

 3 0
0 0
2 0
0 0
2 0
0 0
1 0
2 0
3 0
0 1
0 0
1 1
3 1
3 0
1 0
0 1
3 0
3 0
0 0
1 0
 

Комментарии:

1. self.net выходной слой имеет 2 записи (определенные с nn.Linear(64, 2) помощью), поэтому argmax там может выводиться только 0 или 1…

Ответ №1:

self(obs_t.unsqueeze(0)) возвращает матрицу из 2 столбцов, поскольку последний слой вашей модели ( nn.Linear(64, 2) ) определен для вывода двух столбцов. max_q_index содержит индекс столбца с наибольшим значением в каждой строке выходных данных модели (индекс столбца, потому dim=1 что ). Поскольку существует только 2 столбца, max_q_index они могут иметь только значения 0 или 1.

Ответ №2:

если вы хотите, чтобы ваши выходные данные совпадали, измените q_values = self(obs_t.unsqueeze(0)) q_values = obs_t.unsqueeze(0) act метод на in
Причина self(obs_t.unsqueeze(0)) в том, что на самом деле вызовет метод forward , который имеет 2 выходных нейрона, следовательно, вывод будет либо 0 or 1 .
Вот вывод, если я перейду q_values = self(obs_t.unsqueeze(0)) к q_values = obs_t.unsqueeze(0) act методу in

 1 1
1 1
3 3
0 0
2 2
2 2
1 1
0 0
1 1
3 3
3 3
1 1
3 3
3 3
3 3
0 0
3 3
1 1
1 1
2 2
 

Чего ты хотел