Как я могу заставить модель возобновить обучение с той эпохи, на которой она остановилась?

#python #tensorflow #machine-learning #keras

Вопрос:

Я тренирую двух автокодеров для Deepfake, и ему нужно пройти цикл в 150 000 эпох. Я остановил его на 10 000, но я хочу, чтобы он мог возобновить обучение с той эпохи, на которой он остановился. Есть ли способ сделать это?

 train_setA = video.loading_images(setA_path)/255.0
train_setB = video.loading_images(setB_path)/255.0


train_setA  = train_setB.mean( axis=(0,1,2) ) - train_setA.mean( axis=(0,1,2) )


batch_size = int(len(os.listdir(setA_path))/20)

print( "press 'q' to stop training and save model" )

for epoch in range(1000000):
    batch_size = 64
    warped_A, target_A = train_util.training_data( train_setA, batch_size )
    warped_B, target_B = train_util.training_data( train_setB, batch_size )

    loss_A = aeA.train_on_batch( warped_A, target_A )
    loss_B = aeB.train_on_batch( warped_B, target_B )
    print( loss_A, loss_B )
    print('Current epoch no... '   str(epoch))

    if epoch % 100 == 0:
        save_model_weights()
        print('Model weights saved')
        test_A = target_A[0:14]
        test_B = target_B[0:14]

    figure_A = np.stack([
        test_A,
        aeA.predict( test_A ),
        aeB.predict( test_A ),
        ], axis=1 )
    figure_B = np.stack([
        test_B,
        aeB.predict( test_B ),
        aeA.predict( test_B ),
        ], axis=1 )

    figure = np.concatenate( [ figure_A, figure_B ], axis=0 )
    figure = figure.reshape( (4,7)   figure.shape[1:] )
    figure = train_util.stack_images( figure )

    figure = np.clip( figure * 255, 0, 255 ).astype('uint8')

    cv2.imshow( "", figure )
    key = cv2.waitKey(1)
    if key == ord('q'):
        save_model_weights()
        exit()
 

Ответ №1:

Я расскажу вам более подробно, что я знаю об этой теме в «Керасе».

Если вы сохраняете веса после каждой эпохи (например, контрольная точка модели), вы можете загрузить сохраненные веса.

Например:

Сохранить:

 weight_save_callback = ModelCheckpoint('/path/to/weights.{epoch:02d}-{val_loss:.2f}.hdf5', monitor='val_loss', save_best_only=False) # or True(Best result)
model.fit(X_train,y_train,batch_size=batch_size,nb_epoch=nb_epoch,callbacks=[weight_save_callback]) 
 

Загрузка:

 model = Sequential()
model.add(...)
model.load('path/to/weights.hf5') 
 

Важно, чтобы модели были одинаковыми.

Поскольку в некоторых оптимизаторах некоторые из их внутренних значений (например, learning rate ) устанавливаются с использованием текущего значения «эпохи», или даже у вас могут быть (пользовательские) обратные вызовы , которые зависят от текущего epoch , initial_epoch позволяет указать начальное epoch значение для начала при обучении. Это в основном необходимо, когда вы обучили свою модель для некоторых эпох, и после сохранения вы хотите загрузить ее и возобновить обучение еще для нескольких эпох, не нарушая состояние объектов, которые зависят от эпохи (например, оптимизатор). Поэтому вы должны установить initial_epoch значение = меньше, чем общее количество эпох(i.e. мы обучали модель, например, в течение 20 эпох и epochs = 40, а затем все возобновится, как если бы вы изначально обучали модель в течение 20 epochs за одну тренировку. Однако обратите внимание, что при использовании встроенных оптимизаторов Keras вам не нужно использовать initial_epoch , так как они хранят и обновляют свое состояние внутренне (без учета значения текущей эпохи), и при сохранении модели также будет сохранено состояние оптимизатора.

Я надеюсь, что помог тебе

Комментарии:

1. Наконец-то кто-то, кто действительно читает вопрос

2. @Tehnorobot Мне нужно было бы использовать контрольную точку модели, где в файле, где были скомпилированы два автокодера, или в файле, где находится код для их обучения.

3. @pasho_6798 Путь, который вы указали при назначении ModelCheckpoint . Я отредактировал свой ответ, вы можете увидеть его в примере.

4. @Tehnorobot спасибо, это мне очень помогло, также у меня есть небольшая проблема, есть ли способ использовать ModelCheckpoint и установить начальную эпоху без необходимости использовать model.fit , но вместо этого сохранить model.train_on_batch . Кажется, я не могу придумать, как продолжать использовать train_on_batch , так как в нем нет callbacks параметра like fit . Кстати, большое вам спасибо за помощь

5. @pasho_6798 Я могу предложить посмотреть здесь: github.com/keras-team/keras/issues/485