#python #nlp #pytorch #torchtext
Вопрос:
Я хотел обучить RNN задаче анализа настроений, для этой задачи я использовал набор данных IMDB, предоставленный torchtext, который содержит 50000 обзоров фильмов и является итератором python. Я использовал split=('train', 'test')
.
Сначала я построил вокабуляр, используя torchtext.vocab.Vocab
и маркируя каждое предложение, а затем выполнил нумерацию.
Чтобы дополнить последовательность до той же длины, которую я использовал torch.nn.utils.rnn.pad_sequence
, а также использовал collate_fn
вместе с batch_sampler
. Затем я загрузил данные с помощью torch.utils.data. DataLoader
Реализация сети RNN в порядке, но загрузчик данных исчерпан после одной эпохи, как вы можете видеть на изображении, прикрепленном ниже.
Правильно ли я использую подход для загрузки этого повторяющегося набора данных? и почему загрузчик данных исчерпан после одной эпохи и как мне преодолеть эту проблему.
Пожалуйста, обратитесь к общей записной книжке colab, если вы хотите увидеть мою реализацию.
пс. Я следил за официальным списком изменений torchtext с github
Вы можете найти мою реализацию здесь
Ответ №1:
Решение состоит в том, чтобы использовать torchtext.data.functional.to_map_style_dataset(iter_data)
(официальный документ) для преобразования набора данных в стиле итерации в набор данных в стиле карты.
Подобный этому:
from torchtext.data.functional import to_map_style_dataset
train_iter = IMDB(split='train')
train_dataset = to_map_style_dataset(train_iter) #Map-style dataset
а затем сделайте загрузчик данных.
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=64, collate_fn=collate_fn)
Почему это происходит?
Я использую соглашение об именах в приведенном выше примере, чтобы объяснить.
Переход train_iter
к Dataloader
набору данных в итеративном стиле означает, что он не __getitem__
реализован. В нем есть только __iter__
и __next__
дандерс — что делает его повторяемым.
Поэтому , если я передам итерируемый в Dataloader
, загрузчик данных остановится после возникновения StopIteration
исключения, которое будет выдано __next__
dunder набора данных в стиле итерируемого( train_iter
в данном случае), когда набор данных(итерируемый) будет исчерпан.
Поэтому мы использовали to_map_style_dataset
функцию для преобразования итеративного стиля в набор данных в стиле карты. Это достигается за счет реализации __getitem__
dunder и, таким образом Dataloader
, по умолчанию используются индексы для получения элементов из набора данных.
Другим возможным способом сделать то же самое также может быть
Если я буду использовать набор данных в итерационном стиле-мне нужно создавать Dataloader
объект в каждую эпоху. Таким образом, после каждой эпохи новый объект dataloader будет запускаться с самого начала в цикле for.
Для лучшего понимания различий и вариантов использования наборов данных в стиле итераций и в стиле карт в Pytorch обратитесь к этому https://yizhepku.github.io/2020/12/26/dataloader.html
Комментарии:
1. Дайте мне знать, если я ответил на ваш вопрос. Также, пожалуйста, предложите отредактировать, если вы считаете, что мое понимание неверно.