#python #machine-learning #deep-learning #tensorflow2.0 #tpu
#python #машинное обучение #глубокое обучение #tensorflow2.0 #tpu
Вопрос:
def get_training_dataset():
lst = [flip, rotate, color]
dataset = load_dataset(TRAINING_FILENAMES, labeled=True)
for data_augment in lst:
dataset = dataset.map(data_augment, num_parallel_calls=AUTOTUNE)
tf.squeeze(dataset, axis=0)
print(dataset.shape)
dataset = dataset.map(flip, num_parallel_calls=AUTOTUNE)
dataset = dataset.repeat()
dataset = dataset.shuffle(2048)
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.prefetch(AUTOTUNE)
return dataset
Формат набора данных — ParallelMapDataset, но когда я пытаюсь использовать этот код, который добавил одно измерение в набор данных, например, моя форма (512, 512, 3), но она возвращает форму (1, 512, 512, 3).
Функция переворачивания, поворота и цвета выглядит следующим образом:
def flip(image, label):
print("random flip")
image = tf.image.random_flip_left_right(image)
image = tf.image.random_flip_up_down(image)
return image, label
def color(image, label):
print("random color")
image = tf.image.random_saturation(image, 0.6, 1.6)
image = tf.image.random_brightness(image, 0.05)
image = tf.image.random_contrast(image, 0.7, 1.3)
return image, label
def rotate(image, label):
print("random rotate")
image = tf.image.rot90(image, tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)),
return image, label
Комментарии:
1. это из-за
batch
2. @NicolasGervais это значит, что я должен его удалить? итак, как я могу получить пакет?
dataset.batch(BATCH_SIZE)
3. Пакет добавляет измерения с индексом 0