DCgAN генерирует только шумы

#deep-learning #computer-vision #generative-adversarial-network #dcgan

#глубокое обучение #компьютерное зрение #генеративный-состязательный-сетевой #dcgan

Вопрос:

Я новичок в GANs, и у меня возникли проблемы с обучением DCGAN на mnist, когда я работал с GAN с линейными слоями, все было хорошо, и генератор генерирует довольно хорошие изображения. Но когда я начал работать со сверточным GAN, генератор генерировал только шум. Кто-нибудь знает, как исправить эту проблему?

Сгенерированные изображения

Потеря дискриминатора и потеря генератора

Вот моя Нейронная сеть и мой цикл обучения,

Нейронная сеть :

Дискриминатор

 class Discriminator(nn.Module):
    def __init__(self, ch, F):
        super(Discriminator, self).__init__()

        def block(in_ch,out_ch,k,s,p,final=False,bn=True):
            block = []
            block.append(nn.Conv2d(in_ch,out_ch,kernel_size=k,stride=s,padding=p))
            if not final and bn:
                block.append(nn.BatchNorm2d(out_ch))
                block.append(nn.LeakyReLU(0.2))
            elif not bn and not final:
                block.append(nn.LeakyReLU(0.2))
            elif final and not bn:
                block.append(nn.Sigmoid())
            return block
    
        self.D = nn.Sequential(
            *block(ch,F,k=4,s=2,p=1,bn=False),
            *block(F,F*2,k=4,s=2,p=1),
            *block(F*2,F*4,k=4,s=2,p=1),
            *block(F*4,F*8,k=4,s=2,p=1),
            *block(F*8,1,k=4,s=2,p=0,final=True,bn=False)
        )

    def forward(self,x): return self.D(x)
 

Генератор

 class Generator(nn.Module):
    def __init__(self, ch_noise, ch_img, features_g):
        super(Generator, self).__init__()

        def block(in_ch,out_ch,k,s,p,final=False):
            block = []
            block.append(nn.ConvTranspose2d(in_ch,out_ch,kernel_size=k,stride=s,padding=p))
            if not final:
                block.append(nn.BatchNorm2d(out_ch))
                block.append(nn.ReLU())
            if final:
                block.append(nn.Tanh())
            return block

        self.G = nn.Sequential(
            *block(ch_noise, features_g*16, k=4,s=1,p=0),
            *block(features_g*16, features_g*8, k=4, s=2, p=1),
            *block(features_g*8, features_g*4, k=4, s=2, p=1),
            *block(features_g*4, features_g*2, k=4, s=2, p=1),
            *block(features_g*2, 1, k=4, s=2, p=1,final=True)
        )
    def forward(self,z): return self.G(z)
 

Цикл обучения:

 loss_G, loss_D = [],[]
for i in range(epochs):
    D.train()
    G.train()
    st = time.time()
    for idx, (img, _ ) in enumerate(mnist):
        img = img.to(device)
        ## Discriminator ##
        D.zero_grad(set_to_none=True)
        lable = torch.ones(bs,device=device)*0.9
        pred = D(img).reshape(-1)
        loss_d_real = criterion(pred,lable)

        z = torch.randn(img.shape[0],ch_z, 1,1,device=device)
        fake_img = G(z)
        lable = torch.ones(bs,device=device)*0.1
        pred = D(fake_img.detach()).reshape(-1)
        loss_d_fake = criterion(pred,lable)
        D_loss = loss_d_real   loss_d_fake
        D_loss.backward()
        optim_d.step()
        ## Generator ##
        G.zero_grad(True)
        lable = torch.randn(bs,device=device)
        pred = D(fake_img).reshape(-1)
        G_loss = criterion(pred,lable)
        G_loss.backward()
        optim_g.step()
        ## printing on terminal
        if idx % 100 == 0:
            print(f'nBatches done : {idx}/{len(mnist)}')
            print(f'Loss_D : {D_loss.item():.4f}tLoss_G : {G_loss.item():.4f}')
    et = time.time()
    print(f'nEpoch : {i 1}n{time_cal(st,et)}')
    G.eval()
    with torch.no_grad():
        fake_image = G(fixed_noise)
        save_image(fake_image[:25],fp=f'{path_to_img}/{i 1}_fake.png',nrow=5,normalize=True)

    loss_G.append(G_loss.item())
    loss_D.append(D_loss.item())
 

Вот ссылка на мой блокнот Colab, в котором я работал.