#pytorch
Вопрос:
Я создаю свой собственный набор данных:
class MyDataset(Dataset): def __init__(self, folders): self.folders = folders def __len__(self): return len(self.folders) def __getitem__(self, item): pos_file_list = glob(self.folders[item] "/*") positive_img = pos_file_list[1] positive_img = mpimg.imread(positive_img) positive_img = np.transpose(positive_img, (2,0,1)) # positive_img have the type: lt;class 'numpy.ndarray'gt;, shape: (3, 128, 128) return positive_img
И я использую его с:
batch_size = 128 train_ds = MyDataset(train_folder_list) oTrainDL = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2) for i, imgs in enumerate(oTrainDL): break
Я получаю следующую гарантию:
UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:189.) return default_collate([torch.as_tensor(b) for b in batch])
Почему я получаю гарантийное сообщение ? Как я могу это исправить ?
Ответ №1:
изменение с return positive_img
на:
return torch.tensor(positive_img)