#colors #pytorch #patch
Вопрос:
У меня есть это изображение 32×32 в качестве образца (разрядность 32 бита).:
и я хотел бы разделить его на четыре части. Это мой код:
import os from PIL import Image from torch.utils.data import Dataset import torch import torchvision from matplotlib import pyplot as plt from torchvision import transforms import numpy as np main_dir = './dataset' images = './images/' patch_size = 16 stride = 16 def DivideInPatches(dataset, size, stride): patches = [] for i in dataset: patches.append(i.unfold(1, size, stride).unfold(2, size, stride).reshape((-1, 1, size, stride))) return patches class CustomDataSet(Dataset): def __init__(self, main_dir, transform): self.main_dir = main_dir self.transform = transform self.all_imgs = os.listdir(main_dir) def __len__(self): return len(self.all_imgs) def __getitem__(self, idx): img_loc = os.path.join(self.main_dir, self.all_imgs[idx]) image = Image.open(img_loc).convert('L') tensor_image = self.transform(image) return tensor_image transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[-1],std=[2]) # Normalize data between (0,1) ]) dataset = CustomDataSet(main_dir, transform) fig = plt.figure() plt.imshow(np.transpose(dataset.__getitem__(0), (1, 2, 0)), cmap='gray') fig.savefig(images 'pyplot_image.png') plt.close('all') torchvision.utils.save_image(dataset.__getitem__(0), images 'image.png') print(dataset.__getitem__(0)) channel = dataset[0].shape[0] patches = DivideInPatches(dataset, patch_size, stride) patches = torch.stack(patches).reshape(-1, channel, patch_size, patch_size) for i in range(0, 4): torchvision.utils.save_image(patches[i], images 'patch_' str(i) '.png') print(patches[i]) fig = plt.figure() plt.imshow(np.transpose(patches[i], (1, 2, 0)), cmap='gray') fig.savefig(images 'pyplot_patch_' str(i) '.png') plt.close('all')
Это отпечаток оригинального изображения:
tensor([[[0.5000, 0.5000, 0.5000, ..., 0.7490, 0.7490, 0.7490], [0.5000, 0.5000, 0.5000, ..., 0.7490, 0.7490, 0.7490], [0.5000, 0.5000, 0.5000, ..., 0.7490, 0.7490, 0.7490], ..., [0.8824, 0.8824, 0.8824, ..., 1.0000, 1.0000, 1.0000], [0.8824, 0.8824, 0.8824, ..., 1.0000, 1.0000, 1.0000], [0.8824, 0.8824, 0.8824, ..., 1.0000, 1.0000, 1.0000]]])
Печать показывает, что исходные цвета сдвинуты на 0,5, на самом деле графики неверны. Более того, когда я рисую патчи с помощью pyplot, я получаю только черные изображения. Что я делаю не так?