В чем проблема в моем создании Softmax с нуля в Pytorch

#python-3.x #pytorch #loss-function #softmax

#python-3.x #пыторч #функция потери #softmax

Вопрос:

Я прочитал этот пост и попытался самостоятельно создать softmax. Вот этот код

 import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
import numpy as np

#============================ get the dataset =========================

mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

batch_size = 256
num_workers = 0  

train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)



#============================     train      =========================
num_inputs = 28 * 28
num_outputs = 10
epochs = 5
lr = 0.05

# Initi the Weight and bia
W = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_outputs)), dtype=torch.float)
b = torch.zeros(num_outputs, dtype=torch.float)
W.requires_grad_(requires_grad = True)
b.requires_grad_(requires_grad=True)

# softmax function
def softmax(X):
    X = X.exp()
    den = X.sum(dim=1, keepdim=True)
    return X / den  

# loss
def cross_entropy(y_hat, y):
    return - torch.log(y_hat.gather(1, y.view(-1, 1))).sum()

# accuracy function
def accuracy(y_hat, y):
    return (y_hat.argmax(dim=1) == y).float().mean().item()
    

for epoch in range(epochs):

    train_loss_sum = 0.0
    train_acc_sum = 0.0
    n_train = 0

    for X, y in train_iter:
        # X.shape: [256, 1, 28, 28]
        # y.shape: [256]
        
        # flatten the X into [256, 28*28]
        X = X.flatten(start_dim=1)  
        y_pred = softmax(torch.mm(X, W)   b)
        
        loss = cross_entropy(y_pred, y)
       
        loss.backward()

        W.data = W.data - lr * W.grad
        b.data = b.data - lr* b.grad

        W.grad.zero_()
        b.grad.zero_()

        train_loss_sum  = loss.item() 

        train_acc_sum  = accuracy(y_pred, y)
        n_train  = y.shape[0]
    
    # evaluate the Test
   
    test_acc, n_test = 0.0, 0
    with torch.no_grad():

        for X_test, y_test in test_iter:
            X_test = X_test.flatten(start_dim=1) 
            y_test_pred = softmax(torch.mm(X_test, W)   b)
            test_acc  = accuracy(y_test_pred, y_test)
            n_test  = y_test.shape[0]

    print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch   1, train_loss_sum/n_train , train_acc_sum / n_train, test_acc / n_test))

 

Сравните с оригинальным постом, здесь я перехожу

 def cross_entropy(y_hat, y):
    return - torch.log(y_hat.gather(1, y.view(-1, 1)))

 

в

 def cross_entropy(y_hat, y):
    return - torch.log(y_hat.gather(1, y.view(-1, 1))).sum()
 

Так backward как нужен скаляр.

Тем не менее, мои результаты таковы

 epoch 1, loss nan, train acc 0.000, test acc 0.000
epoch 2, loss nan, train acc 0.000, test acc 0.000
epoch 3, loss nan, train acc 0.000, test acc 0.000
epoch 4, loss nan, train acc 0.000, test acc 0.000
epoch 5, loss nan, train acc 0.000, test acc 0.000
 

Есть какие-нибудь идеи?

Спасибо.

Ответ №1:

Изменить:

 def cross_entropy(y_hat, y):
    return - torch.log(y_hat.gather(1, y.view(-1, 1))).sum()
 

Для:

 def cross_entropy(y_hat, y):
    return - torch.log(y_hat[range(len(y_hat)), y]   1e-8).sum()
 

Выходные данные должны быть примерно такими:

 epoch 1, loss 9.2651, train acc 0.002, test acc 0.002
epoch 2, loss 7.8493, train acc 0.002, test acc 0.002
epoch 3, loss 6.6875, train acc 0.002, test acc 0.003
epoch 4, loss 6.0928, train acc 0.003, test acc 0.003
epoch 5, loss 5.1277, train acc 0.003, test acc 0.003
 

И имейте в виду, что проблема nan также может возникнуть X = X.exp() из softmax(X) -за того, что, когда X слишком большой, тогда exp() будет вывод inf , когда это произойдет, вы можете попытаться обрезать X перед использованием exp()