Итерируемый набор данных исчерпывается после одной эпохи

#python #nlp #pytorch #torchtext

Вопрос:

Я хотел обучить RNN задаче анализа настроений, для этой задачи я использовал набор данных IMDB, предоставленный torchtext, который содержит 50000 обзоров фильмов и является итератором python. Я использовал split=('train', 'test') .

Сначала я построил вокабуляр, используя torchtext.vocab.Vocab и маркируя каждое предложение, а затем выполнил нумерацию.

Чтобы дополнить последовательность до той же длины, которую я использовал torch.nn.utils.rnn.pad_sequence , а также использовал collate_fn вместе с batch_sampler . Затем я загрузил данные с помощью torch.utils.data. DataLoader

Реализация сети RNN в порядке, но загрузчик данных исчерпан после одной эпохи, как вы можете видеть на изображении, прикрепленном ниже.

Правильно ли я использую подход для загрузки этого повторяющегося набора данных? и почему загрузчик данных исчерпан после одной эпохи и как мне преодолеть эту проблему.

Пожалуйста, обратитесь к общей записной книжке colab, если вы хотите увидеть мою реализацию.

пс. Я следил за официальным списком изменений torchtext с github

Вы можете найти мою реализацию здесь

Загрузчик данных исчерпан после одной эпохи

Ответ №1:

Решение состоит в том, чтобы использовать torchtext.data.functional.to_map_style_dataset(iter_data) (официальный документ) для преобразования набора данных в стиле итерации в набор данных в стиле карты.

Подобный этому:

 from torchtext.data.functional import to_map_style_dataset
train_iter = IMDB(split='train')
train_dataset = to_map_style_dataset(train_iter)  #Map-style dataset
 

а затем сделайте загрузчик данных.

 from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=64, collate_fn=collate_fn)
 

Почему это происходит?

Я использую соглашение об именах в приведенном выше примере, чтобы объяснить.

Переход train_iter к Dataloader набору данных в итеративном стиле означает, что он не __getitem__ реализован. В нем есть только __iter__ и __next__ дандерс — что делает его повторяемым.

Поэтому , если я передам итерируемый в Dataloader , загрузчик данных остановится после возникновения StopIteration исключения, которое будет выдано __next__ dunder набора данных в стиле итерируемого( train_iter в данном случае), когда набор данных(итерируемый) будет исчерпан.

Поэтому мы использовали to_map_style_dataset функцию для преобразования итеративного стиля в набор данных в стиле карты. Это достигается за счет реализации __getitem__ dunder и, таким образом Dataloader , по умолчанию используются индексы для получения элементов из набора данных.

Другим возможным способом сделать то же самое также может быть

Если я буду использовать набор данных в итерационном стиле-мне нужно создавать Dataloader объект в каждую эпоху. Таким образом, после каждой эпохи новый объект dataloader будет запускаться с самого начала в цикле for.

Для лучшего понимания различий и вариантов использования наборов данных в стиле итераций и в стиле карт в Pytorch обратитесь к этому https://yizhepku.github.io/2020/12/26/dataloader.html

Комментарии:

1. Дайте мне знать, если я ответил на ваш вопрос. Также, пожалуйста, предложите отредактировать, если вы считаете, что мое понимание неверно.