Как избежать ошибок ООМ при повторном обучении и прогнозировании в TensorFlow?

#python #tensorflow #out-of-memory

#python #tensorflow #нехватка памяти

Вопрос:

У меня есть некоторый код в TensorFlow, который берет базовую модель, настраивает (обучает) ее некоторым данным, а затем использует модель для predict() использования некоторых других данных. Все это инкапсулировано в main() метод модуля и отлично работает.

Однако, когда я запускаю этот код в цикле над разными базовыми моделями, я получаю ООМ, например, после 7 базовых моделей. Ожидается ли это? Я ожидал бы, что Python очищается после каждого main() вызова. Разве TensorFlow этого не делает? Как я могу заставить это сделать?

Редактировать: вот MWE, показывающий не сбои ООМ, а увеличение потребления памяти:

 import gc
import os

import numpy as np
import psutil
import tensorflow as tf

tf.get_logger().setLevel("ERROR")  # Suppress "tf.function retracing" warnings
process = psutil.Process(os.getpid())
for i in range(100):
    (model := tf.keras.applications.mobilenet.MobileNet()).compile(loss="mse")
    history = model.fit(
        x=(x := tf.zeros((1, *model.input.shape[1:]))),
        y=(y := tf.zeros((1, *model.output.shape[1:]))),
        verbose=0,
    )
    prediction = model.predict(x)
    _ = gc.collect()
    # tf.keras.backend.clear_session()
    print(f"rss {i}: {process.memory_info().rss >> 20} MB")
  

На моем компьютере (CPU) он печатает

 rss 0: 374 MB
rss 1: 438 MB
rss 2: 478 MB
rss 3: 517 MB
rss 4: 554 MB
rss 5: 588 MB
rss 6: 634 MB
rss 7: 669 MB
rss 8: 686 MB
rss 9: 726 MB
...
rss 30: 1386 MB
rss 31: 1413 MB
rss 32: 1445 MB
rss 33: 1476 MB
rss 34: 1506 MB
rss 35: 1536 MB
rss 36: 1568 MB
rss 37: 1597 MB
rss 38: 1630 MB
rss 39: 1662 MB
...
  

С tf.keras.backend.clear_session() раскомментированным это лучше, но еще не идеально:

 rss 0: 374 MB
rss 1: 420 MB
rss 2: 418 MB
rss 3: 450 MB
rss 4: 447 MB
rss 5: 469 MB
rss 6: 469 MB
rss 7: 475 MB
rss 8: 487 MB
rss 9: 494 MB
...
rss 40: 519 MB
rss 41: 516 MB
rss 42: 517 MB
rss 43: 520 MB
rss 44: 519 MB
rss 45: 519 MB
rss 46: 521 MB
rss 47: 517 MB
rss 48: 521 MB
rss 49: 521 MB
...
rss 90: 531 MB
rss 91: 531 MB
rss 92: 531 MB
rss 93: 531 MB
rss 94: 532 MB
rss 95: 532 MB
rss 96: 533 MB
rss 97: 534 MB
rss 98: 533 MB
rss 99: 533 MB
  

Переключение порядка gc.collect() и tf.keras.backend.clear_session() также не помогло.

Комментарии:

1. Пожалуйста, опубликуйте минимальный рабочий пример или хотя бы наброски кода.

2. Потенциально связанный, я просто пробую это: github.com/tensorflow/tensorflow/issues/14181

3. Я подал github.com/tensorflow/tensorflow/issues/43702

4. Столкнувшись с аналогичной проблемой, было найдено лучшее решение, чем: tf.keras.backend.clear_session();gc.collect()