#nlp #pytorch #bert-language-model #tpu #pytorch-dataloader
#nlp #pytorch #bert-language-model #tpu #pytorch-загрузчик данных
Вопрос:
Я пытаюсь настроить свою модель контроля качества на основе bert (PyTorch) с помощью Tpu v3-8, предоставленной Kaggle. В процессе проверки я использовал параллельный загрузчик для одновременного прогнозирования на 8 ядрах. Но после этого я не знаю, что мне делать, чтобы собрать все результаты обратно из каждого ядра (и в правильном порядке, соответствующем dataset), чтобы рассчитать общую оценку EM и F1. Кто-нибудь может помочь? Код:
def _run():
MAX_LEN = 192 # maximum text length in the batch (cannot have too high due to memory constraints)
BATCH_SIZE = 16 # batch size (cannot have too high due to memory constraints)
EPOCHS = 2 # number of epochs
train_sampler = torch.utils.data.distributed.DistributedSampler(
tokenized_datasets['train'],
num_replicas=xm.xrt_world_size(), # tell PyTorch how many devices (TPU cores) we are using for training
rank=xm.get_ordinal(), # tell PyTorch which device (core) we are on currently
shuffle=True
)
train_data_loader = torch.utils.data.DataLoader(
tokenized_datasets['train'],
batch_size=BATCH_SIZE,
sampler=train_sampler,
drop_last=True,
num_workers=0,
)
valid_sampler = torch.utils.data.distributed.DistributedSampler(
tokenized_datasets['validation'],
num_replicas=xm.xrt_world_size(),
rank=xm.get_ordinal(),
shuffle=False
)
valid_data_loader = torch.utils.data.DataLoader(
tokenized_datasets['validation'],
batch_size=BATCH_SIZE,
sampler=valid_sampler,
drop_last=False,
num_workers=0
)
device = xm.xla_device() # device (single TPU core)
model = model.to(device) # put model onto the TPU core
xm.master_print('done loading model')
xm.master_print(xm.xrt_world_size(),'as size')
lr = 0.5e-5 * xm.xrt_world_size()
optimizer = AdamW(model.parameters(), lr=lr) # define our optimizer
for epoch in range(EPOCHS):
gc.collect()
# use ParallelLoader (provided by PyTorch XLA) for TPU-core-specific dataloading:
para_loader = pl.ParallelLoader(train_data_loader, [device])
xm.master_print('parallel loader created... training now')
gc.collect()
call training loop:
train_loop_fn(para_loader.per_device_loader(device), model, optimizer, device, scheduler=None)
del para_loader
model.eval()
para_loader = pl.ParallelLoader(valid_data_loader, [device])
gc.collect()
model.eval()
# call evaluation loop
print("call evaluation loop")
start_logits, end_logits = eval_loop_fn(para_loader.per_device_loader(device), model, device)