Уменьшение размера ParallelMapDataset

#python #machine-learning #deep-learning #tensorflow2.0 #tpu

#python #машинное обучение #глубокое обучение #tensorflow2.0 #tpu

Вопрос:

 def get_training_dataset():
    lst = [flip, rotate, color]
    dataset = load_dataset(TRAINING_FILENAMES, labeled=True)  
    for data_augment in lst:
        dataset = dataset.map(data_augment, num_parallel_calls=AUTOTUNE) 
    tf.squeeze(dataset, axis=0)
    print(dataset.shape)
    dataset = dataset.map(flip, num_parallel_calls=AUTOTUNE)  
    dataset = dataset.repeat()
    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTOTUNE)
    return dataset
 

Формат набора данных — ParallelMapDataset, но когда я пытаюсь использовать этот код, который добавил одно измерение в набор данных, например, моя форма (512, 512, 3), но она возвращает форму (1, 512, 512, 3).
Функция переворачивания, поворота и цвета выглядит следующим образом:

 def flip(image, label):
    print("random flip")
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    return image, label

def color(image, label):
    print("random color")
    image = tf.image.random_saturation(image, 0.6, 1.6)
    image = tf.image.random_brightness(image, 0.05)
    image = tf.image.random_contrast(image, 0.7, 1.3)
    return image, label

def rotate(image, label):
    print("random rotate")
    image = tf.image.rot90(image, tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)),
    return image, label
 

Комментарии:

1. это из-за batch

2. @NicolasGervais это значит, что я должен его удалить? итак, как я могу получить пакет? dataset.batch(BATCH_SIZE)

3. Пакет добавляет измерения с индексом 0