#pytorch #mnist #generative-adversarial-network #log-likelihood
#пыторч #mnist #генеративная-состязательная-сеть #логарифмическая вероятность
Вопрос:
Я пытаюсь переосмыслить оригинальную статью GAN Яна Гудфеллоу и др. И мне нужно показать, что моя реализация достигает тех же или аналогичных результатов, что и авторы. Но я не уверен, как оценить этот показатель. Я взглянул на их реализацию, но получил несколько забавных результатов. В документе они сообщают о 225 — 2 в MNIST для этого показателя, в то время как результаты, которые я получаю, приведены ниже -400000000. Я подумал, что, возможно, модель плохая, но она генерирует действительно хорошие изображения цифр MNIST.
Может кто-нибудь сказать мне, что я делаю не так?
Ниже приведен код, который я использовал. Я скопировал часть кода из официальной реализации.
допустимое примечание: переменными являются изображения, взятые из набора данных MNIST.
def get_nll(x, parzen, batch_size=10):
"""
Credit: Yann N. Dauphin
"""
inds = range(x.shape[0])
n_batches = int(numpy.ceil(float(len(inds)) / batch_size))
print("N batches:", n_batches)
times = []
nlls = []
for i in range(n_batches):
begin = time.time()
nll = parzen(x[inds[i::n_batches]])
end = time.time()
times.append(end-begin)
nlls.extend(nll)
if i % 10 == 0:
print(i, numpy.mean(times), numpy.mean(nlls))
return numpy.array(nlls)
def log_mean_exp(a):
"""
Credit: Yann N. Dauphin
"""
max_ = a.max(1)
return max_ T.log(T.exp(a - max_.dimshuffle(0, 'x')).mean(1))
def cross_validate_sigma(samples, data, sigmas, batch_size):
lls = []
for sigma in sigmas:
print("Sigma:", sigma)
parzen = theano_parzen(samples, sigma)
tmp = get_nll(data, parzen, batch_size = batch_size)
lls.append(numpy.asarray(tmp).mean())
del parzen
gc.collect()
ind = numpy.argmax(lls)
print(max(lls))
return sigmas[ind]
noise = torch.randn((10000, 100), device=device)
gen_model.eval()
gan_out = gen_model(noise)
sigma_range = numpy.logspace(-1., 0., num=10)
sigma = cross_validate_sigma(gan_out.reshape(10000,-1), valid[0:10000], sigma_range, 100)