#pytorch #pytorch-lightning
Вопрос:
Я знаю, что мы можем использовать профилировщик факела с тензорной доской, используя что-то вроде этого:
with torch.profiler.profile(
schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/resnet18'),
record_shapes=True,
with_stack=True
) as prof:
for step, batch_data in enumerate(train_loader):
if step >= (1 1 3) * 2:
break
train(batch_data)
prof.step() # Need to call this at the end of each step to notify profiler of steps' boundary.
Он отлично работает с pytorch, но проблема в том, что я должен использовать pytorch lightning, и если я включу это в свой шаг обучения, он просто не создаст файл журнала и не создаст запись для профилировщика. Все, что я получаю lightning_logs
, — это то, что не является результатом профилировщика. Я не смог найти ничего в документах о lightning_profiler и tensorboard, так что у кого-нибудь есть какие-либо идеи?
Вот как выглядит моя функция обучения:
def training_step(self, train_batch, batch_idx):
with torch.profiler.profile(
activities=[ProfilerActivity.CPU],
schedule=torch.profiler.schedule(
wait=1,
warmup=1,
active=2,
repeat=1),
with_stack=True,
on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'),
) as profiler:
x, y = train_batch
x = x.float()
logits = self.forward(x)
loss = self.loss_fn(logits, y)
profiler.step()
return loss
Ответ №1:
Вам вообще не нужно использовать raw torch.profiler
. В документах Lightning есть целая страница, посвященная профилированию ..
.. и это так же просто, как передать флаг тренера под названием profiler
» как
# other profilers are "simple", "advanced" etc
trainer = pl.Trainer(profiler="pytorch")
Кроме того, установите TensorBoardLogger
в качестве предпочтительного регистратора, как вы обычно делаете
trainer = pl.Trainer(profiler="pytorch", logger=TensorBoardLogger(..))
Комментарии:
1. Но я хочу использовать профилировщик факела, чтобы отслеживать каждую выполняемую операцию. Я пробовал использовать
trainer = pl.Trainer(profiler="pytorch", logger=TensorBoardLogger(..))
, но профилировщик не отображается на тензорной доске2. @Madara я предполагаю, что вы не установили
torch_tb_profiler
, как упоминалось здесь3. Сработало ли это @Madara ?
4. Да, это сделал @ayandas 🙂