tensorflow: метод обратного вызова `on_train_batch_end` медленный по сравнению со временем выполнения пакета

#tensorflow #keras #conv-neural-network

#tensorflow #keras #conv-нейронная сеть

Вопрос:

Я пытаюсь создать модель CNN с использованием RandomSearch, но она очень медленная и выдает эту ошибку tensorflow:Callback method on_train_batch_end is slow compared to the batch time . Я запускаю свой код в Google colab с аппаратным ускорением, установленным на gpu. это мой код

 def model_builder(hp):
    model=Sequential([
        Conv2D(filters=hp.Int('conv_1_filter',min_value=32,max_value=128,step=32),
               kernel_size=hp.Int('conv_1_filter',min_value=2,max_value=3,step=1),
               activation='relu',
               padding='same',
               input_shape=(200,200,3)),
        MaxPooling2D(pool_size=(2,2),strides=(2,2)),
        
        Conv2D(filters=hp.Int('conv_2_filter',min_value=32,max_value=128,step=32),
               kernel_size=hp.Int('conv_2_filter',min_value=2,max_value=3,step=1),
               padding='same',
               activation='relu'),
        MaxPooling2D(pool_size=(2,2),strides=(2,2)),
        
        Flatten(),
        
        Dense(units=hp.Int('dense_1_units',min_value=32,max_value=512,step=128),
              activation='relu'),
        
        Dense(units=10,
              activation='softmax')
               
    ])
    
    model.compile(optimizer=Adam(hp.Choice('learning_rate',values=[1e-1,1e-3,3e-2])),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model
 

затем произвольный поиск и подгонка

 tuner=RandomSearch(model_builder,
                   objective='val_accuracy',
                   max_trials=2,
                   directory='projects',
                   project_name='Hypercars CNN'
                  )
tuner.search(X_train,Y_train,epochs=2,validation_split=0.2)
 

Ответ №1:

Это вызвано тем, что другие операции, которые выполняются в конце каждого пакета, потребляют больше времени, чем сам пакет. Возможно, у вас действительно небольшие пакеты, т.Е. Любая операция, которая медленнее по сравнению с вашими исходными пакетами.

Increasing the batch size следует решить эту проблему, или вы можете use_mutiprocessing = True model.fit() ввести и выбрать соответствующее количество рабочих для более эффективной генерации обучающих пакетов.

Два потока обсуждают эту проблему:

  1. Поток 1
  2. Поток 2

Комментарии:

1. use_mutiprocessing = True

Ответ №2:

use_multiprocessing = True может помочь удалить это предупреждение, но появляется другое предупреждение, связанное с использованием многопроцессорной обработки в tf2.