Как объединить GridSearchCV от sklearn и Keras с параллельными заданиями без утечки памяти

#python #tensorflow #keras #scikit-learn #parallel-processing

Вопрос:

Я обучаю тысячи мелких сетей с использованием tensorflow-cpu (keras) и gridsearchcv, используя несколько заданий (n_jobs=-1). Прямо сейчас, когда tensorflow создает экземпляр одного графика для всего выполнения, и память расходуется. Я бы хотел, чтобы каждая созданная модель была удалена после прогнозирования результатов в наборе данных проверки, так как меня интересуют только показатели, полученные в процессе перекрестной проверки.

Это код для создания нейронной сети

 
def create_model(n_neurons:int=2,
                 input_size:int=16,
                 activation:str='tanh') -> keras.models.Sequential:
    """
    Create MLP dense network

    Arguments
    ---------
    n_neurons: int, default=2
        The number of neurons to be used in the hidden layer
    input_size: int, default=16
        The number of neurons in the input layers
    activation: str, default='tanh'


    Returns
    -------
    A `Sequential` model object
    """
    # I WOULD LIKE TO CALL tf.keras.backend.clear_session() here
    # As gc strategy
    model = keras.models.Sequential()
    # add the input layers
    model.add(keras.Input(shape=(input_size,)))
    model.add(keras.layers.LayerNormalization())
    # add hidden layers
    model.add(keras.layers.Dense(n_neurons,
                                 name="hidden_layer",
                                 activation=activation))
    # add output layer
    model.add(keras.layers.Dense(1, activation=activation))
    model.compile(loss='mse',
                  optimizer="rmsprop",
                  metrics=['accuracy',
                           keras.metrics.Precision(name='precision'),
                           keras.metrics.Recall(name='recall')])
    return model
 

Обратите внимание, что я мог бы использовать сеанс clear_session в начале области build_fn. И это сработало бы, если бы я использовал один процесс. Однако меня интересует параллельное выполнение GridSearchCV, и этот метод удаляет глобальные переменные из tensorflow, поэтому график уничтожается для всего процесса в разные временные рамки, нарушая результаты.

вот метод обучения

 def mlp_train(dataset: str,
              X: pd.DataFrame,
              y: pd.DataFrame,
              params:dict ={}) -> pd.DataFrame:
    """
     Train a MLP network using backpropagation
    """
    n_init = params.get('n_init', 1)
    max_neurons = params.get('max_neurons', 15)
    min_neurons = params.get('min_neurons', 2)
    n_epochs = params.get('n_epochs', 100)
    patience = params.get('patience', 10)
    validation_split = params.get('validation_split', .2)

    ## Setup for network
    # Guarantee randomness of input data
    X = shuffle(X)
    input_size = X.shape[1]
    y = y[X.index]
    cv = X.shape[0]

    early_stopping = EarlyStopping(monitor='val_acc', patience=100)
    callbacks = [early_stopping,]

    model = KerasClassifier(build_fn=create_model,
                            epochs=n_epochs,
                            validation_split=validation_split,
                            input_size=input_size,
                            batch_size=cv//2,
                            verbose=0)

    n_neurons = get_n_neurons(min_neurons, max_neurons, n_init)

    # CV
    param_grid = dict(
        n_neurons=n_neurons
    )
    # Multi threads
    with parallel_backend("multiprocessing", n_jobs=-1):
        grid = GridSearchCV(
            estimator=model,
            param_grid=param_grid,
            scoring='f1',
            refit=False,
            verbose=2,
            return_train_score=False,
            cv=cv)

        grid_search_obj = grid.fit(X.values.astype('float32'),
                                   y.values.astype('float32'),
                                   callbacks=callbacks)

        loo_results = pd.DataFrame(grid_search_obj.cv_results_)
  
    # Avoid memory leak from keras. how ever all runs for GridSearchCV are store in memory
    keras.backend.clear_session()
 

Cheers