#pytorch #pytorch-dataloader
#пыторч #pytorch-загрузчик данных
Вопрос:
Я пытаюсь выполнить цикл через объект загрузки данных. Тем не менее, я продолжаю получать torch.cat() проблема, когда я просматриваю только этот конкретный список.
listOfData содержит все torch_geometric.data.данные.Объекты данных.
Следующий код-это то, как я создал объект DataLoader.
np.random.seed(seed) np.random.shuffle(listOfData) train_loader = DataLoader(listOfData[:int((length0.80))], batch_size=batch_size_train) test_loader = DataLoader(listOfData[int((length0.80)):], batch_size=batch_size_test)
Затем, когда я попытаюсь выполнить цикл загрузки данных.
for i in train_loader: print(i)
Он распечатывает несколько графиков, а затем выдает эту ошибку:
Batch(x=[462, 300], edge_index=[2, 223], y=[13], batch=[462], ptr=[14]) Batch(x=[501, 300], edge_index=[2, 247], y=[13], batch=[501], ptr=[14]) Batch(x=[764, 300], edge_index=[2, 370], y=[13], batch=[764], ptr=[14]) Batch(x=[490, 300], edge_index=[2, 236], y=[13], batch=[490], ptr=[14]) Batch(x=[353, 300], edge_index=[2, 169], y=[13], batch=[353], ptr=[14]) Batch(x=[452, 300], edge_index=[2, 215], y=[13], batch=[452], ptr=[14]) Batch(x=[375, 300], edge_index=[2, 161], y=[13], batch=[375], ptr=[14]) Batch(x=[622, 300], edge_index=[2, 336], y=[13], batch=[622], ptr=[14]) Batch(x=[355, 300], edge_index=[2, 177], y=[13], batch=[355], ptr=[14]) Batch(x=[506, 300], edge_index=[2, 132], y=[13], batch=[506], ptr=[14]) Batch(x=[486, 300], edge_index=[2, 176], y=[13], batch=[486], ptr=[14]) Batch(x=[534, 300], edge_index=[2, 266], y=[13], batch=[534], ptr=[14]) Batch(x=[540, 300], edge_index=[2, 252], y=[13], batch=[540], ptr=[14]) Batch(x=[560, 300], edge_index=[2, 247], y=[13], batch=[560], ptr=[14]) Batch(x=[600, 300], edge_index=[2, 269], y=[13], batch=[600], ptr=[14]) Batch(x=[486, 300], edge_index=[2, 220], y=[13], batch=[486], ptr=[14]) Batch(x=[228, 300], edge_index=[2, 88], y=[13], batch=[228], ptr=[14]) Batch(x=[473, 300], edge_index=[2, 191], y=[13], batch=[473], ptr=[14]) Batch(x=[322, 300], edge_index=[2, 142], y=[13], batch=[322], ptr=[14]) RuntimeError Traceback (most recent call last) /tmp/ipykernel_4277/2753664710.py in ----gt; 1 for i in train_loader: 2 print(i) /opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in next(self) 519 if self._sampler_iter is None: 520 self._reset() --gt; 521 data = self._next_data() 522 self._num_yielded = 1 523 if self._dataset_kind == _DatasetKind.Iterable and /opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self) 559 def _next_data(self): 560 index = self._next_index() # may raise StopIteration --gt; 561 data = self._dataset_fetcher.fetch(index) # may raise StopIteration 562 if self._pin_memory: 563 data = _utils.pin_memory.pin_memory(data) /opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index) 45 else: 46 data = self.dataset[possibly_batched_index] ---gt; 47 return self.collate_fn(data) /opt/conda/lib/python3.7/site-packages/torch_geometric/loader/dataloader.py in call(self, batch) 37 38 def call(self, batch): ---gt; 39 return self.collate(batch) 40 41 /opt/conda/lib/python3.7/site-packages/torch_geometric/loader/dataloader.py in collate(self, batch) 18 if isinstance(elem, Data) or isinstance(elem, HeteroData): 19 return Batch.from_data_list(batch, self.follow_batch, ---gt; 20 self.exclude_keys) 21 elif isinstance(elem, torch.Tensor): 22 return default_collate(batch) /opt/conda/lib/python3.7/site-packages/torch_geometric/data/batch.py in from_data_list(cls, data_list, follow_batch, exclude_keys) 67 add_batch=True, 68 follow_batch=follow_batch, ---gt; 69 exclude_keys=exclude_keys, 70 ) 71 /opt/conda/lib/python3.7/site-packages/torch_geometric/data/collate.py in collate(cls, data_list, increment, add_batch, follow_batch, exclude_keys) 75 # Collate attributes into a unified representation: 76 value, slices, incs = _collate(attr, values, data_list, stores, ---gt; 77 increment) 78 79 out_store[attr] = value /opt/conda/lib/python3.7/site-packages/torch_geometric/data/collate.py in _collate(key, values, data_list, stores, increment) 146 incs = None 147 --gt; 148 value = torch.cat(values, dim=cat_dim or 0) 149 return value, slices, incs 150 RuntimeError: torch.cat(): Tensors must have same number of dimensions: got 1 and 2
Я очень застрял, если бы кто-нибудь мог мне помочь, пожалуйста.
Environment PyTorch version: (torch.__version__): 1.9.1 cu111 OS (e.g., Linux): Jupyter Notebook Python version (e.g., 3.9): 3.7 How you installed PyTorch and PyG (conda, pip, source): pip
Как я импортировал:
from torch_geometric.loader import DataLoader