#deep-learning #computer-vision #generative-adversarial-network #dcgan
#глубокое обучение #компьютерное зрение #генеративный-состязательный-сетевой #dcgan
Вопрос:
Я новичок в GANs, и у меня возникли проблемы с обучением DCGAN на mnist, когда я работал с GAN с линейными слоями, все было хорошо, и генератор генерирует довольно хорошие изображения. Но когда я начал работать со сверточным GAN, генератор генерировал только шум. Кто-нибудь знает, как исправить эту проблему?
Потеря дискриминатора и потеря генератора
Вот моя Нейронная сеть и мой цикл обучения,
Нейронная сеть :
Дискриминатор
class Discriminator(nn.Module):
def __init__(self, ch, F):
super(Discriminator, self).__init__()
def block(in_ch,out_ch,k,s,p,final=False,bn=True):
block = []
block.append(nn.Conv2d(in_ch,out_ch,kernel_size=k,stride=s,padding=p))
if not final and bn:
block.append(nn.BatchNorm2d(out_ch))
block.append(nn.LeakyReLU(0.2))
elif not bn and not final:
block.append(nn.LeakyReLU(0.2))
elif final and not bn:
block.append(nn.Sigmoid())
return block
self.D = nn.Sequential(
*block(ch,F,k=4,s=2,p=1,bn=False),
*block(F,F*2,k=4,s=2,p=1),
*block(F*2,F*4,k=4,s=2,p=1),
*block(F*4,F*8,k=4,s=2,p=1),
*block(F*8,1,k=4,s=2,p=0,final=True,bn=False)
)
def forward(self,x): return self.D(x)
Генератор
class Generator(nn.Module):
def __init__(self, ch_noise, ch_img, features_g):
super(Generator, self).__init__()
def block(in_ch,out_ch,k,s,p,final=False):
block = []
block.append(nn.ConvTranspose2d(in_ch,out_ch,kernel_size=k,stride=s,padding=p))
if not final:
block.append(nn.BatchNorm2d(out_ch))
block.append(nn.ReLU())
if final:
block.append(nn.Tanh())
return block
self.G = nn.Sequential(
*block(ch_noise, features_g*16, k=4,s=1,p=0),
*block(features_g*16, features_g*8, k=4, s=2, p=1),
*block(features_g*8, features_g*4, k=4, s=2, p=1),
*block(features_g*4, features_g*2, k=4, s=2, p=1),
*block(features_g*2, 1, k=4, s=2, p=1,final=True)
)
def forward(self,z): return self.G(z)
Цикл обучения:
loss_G, loss_D = [],[]
for i in range(epochs):
D.train()
G.train()
st = time.time()
for idx, (img, _ ) in enumerate(mnist):
img = img.to(device)
## Discriminator ##
D.zero_grad(set_to_none=True)
lable = torch.ones(bs,device=device)*0.9
pred = D(img).reshape(-1)
loss_d_real = criterion(pred,lable)
z = torch.randn(img.shape[0],ch_z, 1,1,device=device)
fake_img = G(z)
lable = torch.ones(bs,device=device)*0.1
pred = D(fake_img.detach()).reshape(-1)
loss_d_fake = criterion(pred,lable)
D_loss = loss_d_real loss_d_fake
D_loss.backward()
optim_d.step()
## Generator ##
G.zero_grad(True)
lable = torch.randn(bs,device=device)
pred = D(fake_img).reshape(-1)
G_loss = criterion(pred,lable)
G_loss.backward()
optim_g.step()
## printing on terminal
if idx % 100 == 0:
print(f'nBatches done : {idx}/{len(mnist)}')
print(f'Loss_D : {D_loss.item():.4f}tLoss_G : {G_loss.item():.4f}')
et = time.time()
print(f'nEpoch : {i 1}n{time_cal(st,et)}')
G.eval()
with torch.no_grad():
fake_image = G(fixed_noise)
save_image(fake_image[:25],fp=f'{path_to_img}/{i 1}_fake.png',nrow=5,normalize=True)
loss_G.append(G_loss.item())
loss_D.append(D_loss.item())