Расхождение KL переходит в Байесовскую сверточную нейронную сеть

#pytorch #conv-neural-network #nan #bayesian #bayesian-networks

Вопрос:

Я пытаюсь реализовать байесовскую сверточную нейронную сеть, используя Pytorch на Python 3.7. Я в основном ориентируюсь на реализацию Шридхара. При запуске моего CNN с нормализованными данными и данными MNIST расхождение KL составляет NaN после пары итераций. Я уже реализовал линейные слои таким же образом, и они отлично работали.

Я нормализовал данные следующим образом:

 train_loader = torch.utils.data.DataLoader(datasets.MNIST('./mnist', train=True, download=True, 
               transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
               ])), batch_size=BATCH_SIZE, shuffle=True, **LOADER_KWARGS)
eval_loader = torch.utils.data.DataLoader(datasets.MNIST('./mnist', train=False, download=True, 
              transform=transforms.Compose([
               transforms.ToTensor(),
               transforms.Normalize((0.1307,), (0.3081,))
              ])), batch_size=EVAL_BATCH_SIZE, shuffle=False, **LOADER_KWARGS)
 

Моя реализация Conv-слоя выглядит следующим образом:

 class BayesianConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, prior_sigma, kernel_size, stride=1, padding=0, dilation=1, groups=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normal = torch.distributions.Normal(0,1)

        # conv-parameters
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding 
        self.dilation = dilation 
        self.groups = groups

        # Weight parameters
        self.weight_mu = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size).uniform_(0, 0.1))
        self.weight_rho = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size).uniform_(-3,0.1))
        self.weight_sigma = 0
        self.weight = 0

        # Bias parameters
        self.bias_mu = nn.Parameter(torch.Tensor(out_channels).uniform_(0, 0.1))
        self.bias_rho = nn.Parameter(torch.Tensor(out_channels).uniform_(-3,0.1))
        self.bias_sigma = 0
        self.bias = 0

        # prior
        self.prior_sigma = prior_sigma

    def forward(self, input, sample=False, calculate_log_probs=False): 
        # compute sigma out of rho: sigma = log(1 e^rho)
        self.weight_sigma = torch.log1p(torch.exp(self.weight_rho))
        self.bias_sigma = torch.log1p(torch.exp(self.bias_rho))

        # sampling process -> use local reparameterization trick
        activations_mu = F.conv2d(input.to(DEVICE), self.weight_mu, self.bias_mu, self.stride, self.padding, self.dilation, self.groups)  
        activations_sigma = torch.sqrt(1e-16   F.conv2d((input**2).to(DEVICE), self.weight_sigma**2, self.bias_sigma**2, self.stride, self.padding, self.dilation, self.groups)) 
        activation_epsilon = Variable(self.weight_mu.data.new(activations_sigma.size()).normal_(mean=0, std=1))

        outputs = activations_mu   activations_sigma * activation_epsilon 


        if self.training or calculate_log_probs:
            self.kl_div = 0.5 * ((2 * torch.log(self.prior_sigma / self.weight_sigma) - 1   (self.weight_sigma / self.prior_sigma).pow(2)   ((0 - self.weight_mu) / self.prior_sigma).pow(2)).sum() 
                              (2 * torch.log(0.1 / self.bias_sigma) - 1   (self.bias_sigma / 0.1).pow(2)   ((0 - self.bias_mu) / 0.1).pow(2)).sum())
    
        return outputs
 

Реализация соответствующей Conv-сети выглядит следующим образом:

 class BayesianConvNetwork(nn.Module):
    # Set up network by definining layers
    def __init__(self):
        super().__init__()
        self.conv1 = layers.BayesianConv2d(1, 24, prior_sigma=0.1, kernel_size = (5,5), padding=2) 
        self.pool1 = nn.MaxPool2d(kernel_size=3,stride=2, padding=1) 
        self.conv2 = layers.BayesianConv2d(24, 48, prior_sigma=0.1, kernel_size = (5,5), padding=2) 
        self.pool2 = nn.MaxPool2d(kernel_size=3,stride=2, padding=1) 
        self.conv3 = layers.BayesianConv2d(48, 64, prior_sigma=0.1, kernel_size = (5,5), padding=2)   
        self.pool3 = nn.MaxPool2d(kernel_size=3,stride=2, padding=1)
        self.fcl1 = layers.BayesianLinearWithLocalReparamTrick(4*4*64, 256, prior_sigma=0.1)
        self.fcl2 = layers.BayesianLinearWithLocalReparamTrick(256, 10, prior_sigma=0.1)

    # define forward function by assigning corresponding activation functions to layers
    def forward(self, x, sample=False):
        x = F.relu(self.conv1(x, sample))          
        x = self.pool1(x)
        x = F.relu(self.conv2(x, sample))          
        x = self.pool2(x)
        x = F.relu(self.conv3(x, sample))          
        x = self.pool3(x)
        x = x.view(-1, 4*4*64)
        x = F.relu(self.fcl1(x, sample))         
        x = F.log_softmax(self.fcl2(x, sample), dim=1) 
        return x
    
    # summing up KL-divergences to obtain overall KL-divergence-value
    def total_kl_div(self):
        return (self.conv1.kl_div   self.conv2.kl_div   self.conv3.kl_div   self.fcl1.kl_div   self.fcl2.kl_div) 
    
    # sampling prediction: perform prediction for each of the "different networks" that result from the weight distributions
    def sample_elbo(self, input, target, batch_idx, nmbr_batches, samples=SAMPLES):
        outputs = torch.zeros(samples, target.shape[0], CLASSES).to(DEVICE)
        kl_divs = torch.zeros(samples).to(DEVICE)
        for i in range(samples):                       # sample through networks
            outputs[i] = self(input, sample=True)      # perform prediction
            kl_divs[i] = self.total_kl_div()           # calculate total kl_div of the network
        kl_div = kl_divs.mean()                        # compute mean kl_div from all samples
        negative_log_likelihood = F.nll_loss(outputs.mean(0), target, size_average=False)
        loss = kl_weighting * kl_div   negative_log_likelihood
        return loss
 

Кто-нибудь сталкивался с такой же проблемой или знает, как ее решить?

Заранее большое спасибо!

Ответ №1:

Я понял, что это, похоже, проблема с SGD-оптимизатором. Использование Адама в качестве оптимизатора решило проблему, хотя я не знаю причины этого. Если у кого-нибудь есть ответ на вопрос, почему это работает с Adam, но не с SGD, не стесняйтесь комментировать.