#python #tensorflow
Вопрос:
Я пытаюсь поместить изображение с 3 каналами в tensorflow 2.0 и получаю ошибку типа данных
TensorFlow TypeError: Value passed to parameter input has DataType uint8 not in list of allowed values: float16, float32
Это моя модель:
def conv_block(number_of_filters, kernel_size, strides=(1, 1), padding='SAME', activation=tf.nn.relu):
return tf.keras.layers.Conv2D(
filters=number_of_filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
activation=activation)
class ImageSearchModel(object):
def __init__(self, learning_rate, image_size, number_of_classes):
tf.compat.v1.reset_default_graph()
model = tf.keras.models.Sequential()
#convolutional layers
model.add(conv_block(number_of_filters=64, kernel_size=(3,3)))
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='SAME'))
model.add(conv_block(number_of_filters=128, kernel_size=(3,3)))
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='SAME'))
model.add(conv_block(number_of_filters=256, kernel_size=(5,5)))
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='SAME'))
model.add(conv_block(number_of_filters=512, kernel_size=(5,5)))
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='SAME'))
#flattening
model.add(tf.keras.layers.Flatten())
#dense
model.add(dense_block(units=128))
model.add(tf.keras.layers.Dropout(rate=0.2))
model.add(dense_block(units=256))
model.add(tf.keras.layers.Dropout(rate=0.2))
model.add(dense_block(units=512))
model.add(tf.keras.layers.Dropout(rate=0.2))
model.add(dense_block(units=1024))
model.add(tf.keras.layers.Dropout(rate=0.2))
#output
model.add(dense_block(units=number_of_classes, activation=tf.nn.softmax))
# compile the model
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
self.model = model
def Train(self, X_train, y_train, X_test, y_test):
for i, s in enumerate(X_train):
X_train[i] = tf.image.convert_image_dtype(s, dtype=tf.float16)
for i, s in enumerate(X_test):
X_test[i] = tf.image.convert_image_dtype(s, dtype=tf.float16)
self.model.fit(X_train, y_train, epochs=1, batch_size=32, verbose=0)
loss, acc = self.model.evaluate(X_test, y_test, verbose=0)
print('Test Accuracy: %.3f' % acc)
Это мой абонент:
print(f"There are {len(classes)} classes: {classes}")
model = ImageSearchModel(0.001, (32, 32), len(classes))
model.Train(X_train, y_train, X_test, y_test)
Это моя обработка изображений:
def image_loader(image_path, image_size):
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = cv2.resize(image, image_size, cv2.INTER_CUBIC)
return image
def dataset_preprocessing(dataset_path, image_size):
images = []
labels = []
image_paths = []
for image_name in os.listdir(dataset_path):
label = image_name.split(".")[0]
label = label.split("_", 2)[1]
image_path = os.path.join(dataset_path, image_name)
images.append(image_loader(image_path, image_size))
image_paths.append(image_path)
labels.append(label)
classes = unique(labels)
assert len(images) == len(labels), f"{len(images)} != {len(labels)}"
return np.array(images), np.array(labels), np.array(classes)
Вот как выглядит мой вклад:
Как правильно подключить 3 канала к CNN?
Ответ №1:
Вы очень близки: вам нужно разделить целочисленные значения на 256 (2^8 от uint8
), чтобы получить ожидаемые значения с плавающей точкой. float16
должно быть достаточно точности, чтобы дать вам полезную модель обучения.
Комментарии:
1. Это скорее деление на
255
(максимальное значение uint8), если мы хотим правильно масштабировать значения пикселей между 0 и 1.2. Это зависит от вашего заданного диапазона. Поскольку мы пытаемся преобразовать непосредственно в другую двоичную форму, деление на 256 дает гораздо более чистый перевод значения.
3. В любом случае это не так уж важно, но тот факт, что данные не будут оптимально масштабироваться (уменьшены на 1/256), должен больше беспокоить процесс обучения. Обычно мы масштабируем значения пикселей на 255 .
4. Где мне это сделать? Извините, вы можете записать это в код? Я попытался использовать convert_image_dtype, но, похоже, это не работает. Я думаю, что я просто не очень хорошо знаком с синтаксисом, обновил свой код, чтобы отразить это.
5. @BillSoftwareEngineer Это cifar-10?