Использование учебника Tensorflow по сегментации изображений для разного количества классов

#python #tensorflow #keras #deep-learning #semantic-segmentation

Вопрос:

Я пытаюсь выполнить семантическую сегментацию с помощью учебника по сегментации изображений по Tensorflow, в котором они используют 3 класса, и я пытаюсь использовать 19 классов, используя маски PNG из набора данных Berkeley DeepDrive. Я изменил каналы ВЫВОДА на 19, но пиксели по-прежнему классифицируются на 3 класса. Я изменил размер изображений до 128х128.

Есть идеи, почему?

Результат, который я получаю без подготовки.

Вот основная часть кода, который я использую —

 
base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)

layer_names = [
    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project',      # 4x4
]
base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)

down_stack.trainable = False

up_stack = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
    pix2pix.upsample(64, 3),   # 32x32 -> 64x64
]

#main UNet model

def unet_model(output_channels):
  inputs = tf.keras.layers.Input(shape=[128, 128, 3])

  # downsampling
  skips = down_stack(inputs)
  x = skips[-1]
  skips = reversed(skips[:-1])

  # upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    concat = tf.keras.layers.Concatenate()
    x = concat([x, skip])

  # last layer of the model
  last = tf.keras.layers.Conv2DTranspose(
      output_channels, 3, strides=2,
      padding='same') #64x64 -> 128x128
  x = last(x)
  return tf.keras.Model(inputs=inputs, outputs=x)

model = unet_model(OUTPUT_CHANNELS)
model.compile(optimizer='adam', loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True), metrics=['accuracy'])