Пользовательская функция потерь в TensorFlow 2: работа с отсутствием в размерности пакета

#python #tensorflow #keras

Вопрос:

Я обучаю модель, которая вводит и выводит изображения одинаковой формы (H, W, C) в цветовом пространстве RGB.

Моя функция потери-это MSE над этими изображениями, но в другом цветовом пространстве.
Преобразование цветового пространства определяется transform_space функцией, которая принимает и возвращает одно изображение.

Я наследую tf.keras.losses.Loss , чтобы осуществить эту потерю.
Метод call , однако, снимает изображения не по одному, а пачками форм (None, H, W, C) .
Проблема в том, что первым измерением этого пакета является None .

Я пытался повторить эти пакеты, но получил ошибку iterating over tf.Tensor is not allowed .
Итак, как я должен рассчитать свои потери?

Причины, по которым я не могу использовать новое цветовое пространство в качестве ввода и вывода для своей модели:

  • модель использует один из предварительно tf.keras.applications обученных, который работает с RGB
  • обратное преобразование невозможно выполнить, потому что часть информации теряется во время преобразования

Я использую tf.distribute.MirroredStrategy , если это имеет значение.

 # Takes an image of shape (H, W, C),
# converts it to a new color space
# and returns a new image with shape (H, W, C)
def transform_space(image):
  # ...color space transformation...
  return image_in_a_new_color_space

class MyCustomLoss(tf.keras.losses.Loss):

  def __init__(self):
    super().__init__()

    # The loss function is defined this way
    # due to the fact that I use "tf.distribute.MirroredStrategy"
    mse = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
    self.loss_fn = lambda true, pred: tf.math.reduce_mean(mse(true, pred))

  def call(self, true_batch, pred_batch):

    # Since shape of true/pred_batch is (None, H, W, C)
    # and transform_space expects shape (H, W, C)
    # the following transformations are impossible:
    true_batch_transformed = transform_space(true_batch)
    pred_batch_transformed = transform_space(pred_batch)

    return self.loss_fn(true_batch_transformed, pred_batch_transformed)
 

Ответ №1:

Пакетирование в основном жестко закодировано в дизайне TF. Это лучший способ использовать ресурсы графического процессора для быстрого запуска моделей глубокого обучения. Циклическое выполнение настоятельно не рекомендуется в TF по той же причине — весь смысл использования TF заключается в векторизации: параллельном выполнении множества вычислений.

Можно нарушить эти проектные предположения. Но на самом деле правильный способ сделать это-реализовать преобразование векторизованным способом (например, сделать transform_space() пакетами).

К вашему сведению, TF изначально поддерживает преобразования YUV, YUQ и HSV в пакете tf.image, если вы использовали один из них. Или вы можете посмотреть на источник там и посмотреть, сможете ли вы адаптировать его к своим потребностям.

В любом случае, чтобы делать то, что вы хотите, но с потенциально серьезным снижением производительности, вы хотите использовать tf.map_fn.

 true_batch_transformed = tf.map_fn(transform_space, true_batch)
pred_batch_transformed = tf.map_fn(transform_space, pred_batch)