Pytorch — пользовательский загрузчик данных работает вечно

#pytorch

#pytorch

Вопрос:

 class TripletImageLoader(torch.utils.data.Dataset):
    def __init__(self):
        self.data = [0]*10000000

    def __getitem__(self, index):
        pid = os.getpid() % WORKER_SIZE
        # My code here only uses pid, doesnt use index

        return torch.tensor(batch.data), torch.tensor(batch.label)

    def __len__(self):
        return len(self.data)
  

Мне нужно, чтобы мой загрузчик данных работал вечно. Прямо сейчас он всегда завершается после нажатия 10000000 или любого другого максимального целого размера. Как мне заставить это работать вечно, меня не волнует «индекс», я его не использую. Я просто использую рабочие возможности этого класса

Комментарии:

1. Что значит загрузчик данных работает вечно? загрузчик данных обычно используется для загрузки данных и создания мини-пакетов. Почему вы хотите запускать его вечно?

2. мне нужно, чтобы он несколько раз повторял одну и ту же партию

3. Да, мы, в общем, делаем то же самое. Просто запустите цикл по количеству обучающих эпох.

Ответ №1:

Поскольку вам нужно обучать один и тот же пакет в течение нескольких итераций, следующий скелет кода должен работать для вас.

 def train(args, data_loader):
    for idx, ex in enumerate(data_loader):
        # iterate over each mini-batches
        # add your code

def validate(args, data_loader):
     with torch.no_grad():
        for idx, ex in enumerate(data_loader):
            # iterate over each mini-batches
            # add your code

# args = dict() containing required parameters
for epoch in range(start_epoch, args.num_epochs):
    # train_loader = data loader for the training data
    train(args, train_loader)
  

Вы можете использовать загрузчик данных следующим образом.

 class ReaderDataset(Dataset):
    def __init__(self, examples):
        # examples = a list of examples
        # add your code

    def __len__(self):
        # return total dataset size

    def __getitem__(self, index):
        # write your code to return each batch item

train_dataset = ReaderDataset(train_examples)
train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.data_workers,
        collate_fn=batchify,
        pin_memory=args.cuda,
        drop_last=args.parallel
    )
# batchify is a custom function to prepare the mini-batches