#python #numpy #tensorflow
#python #numpy #тензорный поток
Вопрос:
Здравствуйте, я изучаю GAN и глубокое обучение, и в целом, когда я работал с этим, я делал с массивами изображений NumPy, но для этого домашнего задания я получаю данные с помощью tfds следующим образом:
test_split, valid_split, train_split = tfds.Split.TRAIN.subsplit([10, 15, 75])
test_set_raw = tfds.load('cats_vs_dogs', split=test_split, as_supervised=True)
valid_set_raw = tfds.load('cats_vs_dogs', split=valid_split, as_supervised=True)
train_set_raw = tfds.load('cats_vs_dogs', split=train_split, as_supervised=True)
Проблема в том, что я хочу выполнить увеличение данных с помощью нескольких из этих примеров, но я не могу получить доступ к каждому изображению на них _OptionsDataset
только с помощью take(), но я хочу повторить это, чтобы увеличить данные для каждого изображения и добавить эти новостные изображения.
Я мог бы сделать это с помощью NumPy и двух массивов, но я понятия не имею, как это можно сделать с _OptionsDataset
.
Возможно ли это? Как я могу это сделать?, возможно ли преобразовать _OptionsDataset
в массив NumPy и снова преобразовать массив NumPy в _OptionsDataset?
Спасибо
Ответ №1:
tf.image
имеет кучу случайных преобразований, которые вы можете использовать, вам не нужен Numpy. Вот пример. Мне пришлось выбирать разделения немного по-другому, так как у меня другая версия. Вот документация для tf.image
.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import tensorflow_datasets as tfds
[train_set_raw] = tfds.load('cats_vs_dogs', split=['train[:100]'], as_supervised=True)
def augment(tensor):
tensor = tf.cast(x=tensor, dtype=tf.float32)
tensor = tf.image.rgb_to_grayscale(images=tensor)
tensor = tf.image.resize(images=tensor, size=(96, 96))
tensor = tf.divide(x=tensor, y=tf.constant(255.))
tensor = tf.image.random_flip_left_right(image=tensor)
tensor = tf.image.random_brightness(image=tensor, max_delta=2e-1)
tensor = tf.image.random_crop(value=tensor, size=(64, 64, 1))
return tensor
train_set_raw = train_set_raw.shuffle(128).map(lambda x, y: (augment(x), y)).batch(16)
import matplotlib.pyplot as plt
plt.imshow((next(iter(train_set_raw))[0][0][..., 0].numpy()*255).astype(int))
plt.show()
Комментарии:
1. Обратите внимание, что изображение здесь на самом деле в оттенках серого. Он желтоватый, потому что это фильтр matplotlib по умолчанию