Почему мой GAN не создает больше хороших изображений после определенного момента?

#python-3.x #machine-learning #deep-learning #tf.keras #generative-adversarial-network

#python-3.x #машинное обучение #глубокое обучение #tf.keras #генеративный-состязательный-сетевой

Вопрос:

Вопрос

Я обучал gan генерировать человеческие лица. Примерно за 500 эпох он научился генерировать такие изображения:

введите описание изображения здесь

Что ж, теперь это изображение не так уж плохо. Мы видим лицо в центре изображения.

Затем я обучал его более 1000 эпох, и он ничему не научился. Он по-прежнему генерировал изображения того же типа, что и показано выше. Почему это было? Почему мой gan не научился создавать еще лучшие изображения?

Код для моделей

Вот код дискриминатора:

     def define_discriminator(in_shape=(64, 64, 3)):
        Model = Sequential([
                Conv2D(32, (3, 3), padding='same', input_shape=in_shape),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
                MaxPooling2D(pool_size=(2,2)),
                Dropout(0.2),

                Conv2D(64, (3,3), padding='same'),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
                MaxPooling2D(pool_size=(2,2)),
                Dropout(0.3),

                Conv2D(128, (3,3), padding='same'),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
                MaxPooling2D(pool_size=(2,2)),
                Dropout(0.3),

                Conv2D(256, (3,3), padding='same'),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
                MaxPooling2D(pool_size=(2,2)),
                Dropout(0.4),

                Flatten(),

                Dense(1, activation='sigmoid')
])
        opt = Adam(lr=0.00002)
        Model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])

        return Model
  

Вот код генератора и GAN:

 def define_generator(in_shape=100):
    Model = Sequential([
                Dense(256*8*8, input_dim=in_shape),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
                Reshape((8, 8, 256)),

                Conv2DTranspose(256, (3,3), strides=(2,2), padding='same'),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),

                Conv2DTranspose(64, (3,3), strides=(2,2), padding='same'),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),

                Conv2DTranspose(3, (4, 4), strides=(2,2), padding='same', activation='sigmoid')
    ])
    return Model

def define_gan(d_model, g_model):
    d_model.trainable = False
    model = Sequential([
                g_model,
                d_model
    ])
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model
  

Entire Reproducible Code

 from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, BatchNormalization
from tensorflow.keras.layers import Dropout, Flatten, Dense, Conv2DTranspose
from tensorflow.keras.layers import MaxPooling2D, Activation, Reshape, LeakyReLU
from tensorflow.keras.datasets import mnist
from tensorflow.keras.optimizers import Adam
from numpy import ones
from numpy import zeros
from numpy.random import rand
from numpy.random import randint
from numpy.random import randn
from numpy import vstack
from numpy import array
import os
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras.preprocessing.image import img_to_array
from matplotlib import pyplot


def load_data(filepath):
    image_array = []
    n = 0
    for fold in os.listdir(filepath):
      if fold != 'wiki.mat':
        if n > 1:
            break
        for img in os.listdir(os.path.join(filepath, fold)):
            image = load_img(filepath   fold    '/'  img, target_size=(64, 64))
            img_array = img_to_array(image)
            img_array = img_array.astype('float32')
            img_array = img_array / 255.0
            image_array.append(img_array)
        n  = 1
    return array(image_array)
def generate_latent_points(n_samples, latent_dim=100):
    latent_points = randn(n_samples * latent_dim)
    latent_points = latent_points.reshape(n_samples, latent_dim)
    return latent_points

def generate_real_samples(n_samples, dataset):
    ix = randint(0, dataset.shape[0], n_samples)
    x = dataset[ix]
    y = ones((n_samples, 1))
    return x, y

def generate_fake_samples(g_model, n_samples):
    latent_points = generate_latent_points(n_samples)
    x = g_model.predict(latent_points)
    y = zeros((n_samples, 1))
    return x, y

def save_plot(examples, epoch, n=10):
    # plot images
    for i in range(n * n):
        # define subplot
        pyplot.subplot(n, n, 1   i)
        # turn off axis
        pyplot.axis('off')
        # plot raw pixel data
    pyplot.imshow(examples[i, :, :, 0])
    # save plot to file
    filename = 'generated_plot_ed.png' % (epoch 1)
    pyplot.savefig(filename)
    pyplot.close()

def summarize_performance(d_model, g_model, gan_model, dataset, epoch, n_samples=100):
    real_x, real_y = generate_real_samples(n_samples, dataset)
    _, d_real_acc = d_model.evaluate(real_x, real_y)
    fake_x, fake_y = generate_fake_samples(g_model, n_samples)
    _, d_fake_acc = d_model.evaluate(fake_x, fake_y)

    latent_points, y = generate_latent_points(n_samples), ones((n_samples, 1))
    gan_loss = gan_model.evaluate(latent_points, y)

    print('Epoch %d, acc_real=%.3d, acc_fake=%.3f, gan_loss=%.3f' % (epoch, d_real_acc, d_fake_acc, gan_loss))

save_plot(fake_x, epoch)
filename = 'Genarator_Model % d' % (epoch   1)
g_model.save(filename)

def train(d_model, g_model, gan_model, dataset, epochs=200):
    batch_size = 64
    half_batch = int(batch_size / 2)
    batch_per_epoch = int(dataset.shape[0] / batch_size)
    for epoch in range(epochs):
        for i in range(batch_per_epoch):
            real_x, real_y = generate_real_samples(half_batch, dataset)
            _, d_real_acc = d_model.train_on_batch(real_x, real_y)
            fake_x, fake_y = generate_fake_samples(g_model, half_batch)
            _, d_fake_acc = d_model.train_on_batch(fake_x, fake_y)

            latent_points, y = generate_latent_points(batch_size), ones((batch_size, 1))
            gan_loss = gan_model.train_on_batch(latent_points, y)

            print('Epoch %d, acc_real=%.3d, acc_fake=%.3f, gan_loss=%.3f' % (epoch, d_real_acc, d_fake_acc, gan_loss))
        if (epoch % 2) == 0:
            summarize_performance(d_model, g_model, gan_model, dataset, epoch)

dataset = load_data(filepath) # filepath is not defined since every person will have seperate filepath

discriminator_model = define_discriminator()
generator_model = define_generator()
gan_model = define_gan(discriminator_model, generator_model)

train(discriminator_model, generator_model, gan_model, dataset)
  

Набор данных

Если вы хотите, вот набор данных.