Как ограничить накладные расходы на производительность, вызванные Tensor.to () в Факеле?

#python #performance #neural-network #pytorch

#python #Производительность #нейронная сеть #pytorch

Вопрос:

Я работаю над сетью отслеживания объектов в реальном времени на основе Tracktor и платформы mmdetection. Я имитирую постоянный поток входных данных, запуская его с размером пакета 1. Я заметил очень большое замедление, вызванное постоянным использованием метода Torch .to(), который, насколько мне известно, я не контролирую.

Какие методы можно использовать для ограничения накладных расходов, вызванных передачей данных с центрального процессора на графический процессор?

Вот профилирование Pycharm с выделенными важными ссылками:

введите описание изображения здесь

И код, соответствующий последнему блоку перед .to():

 def grid_anchors(self, featmap_sizes, device='cuda'):
    """Generate grid anchors in multiple feature levels.

    Args:
        featmap_sizes (list[tuple]): List of feature map sizes in
            multiple feature levels.
        device (str): Device where the anchors will be put on.

    Return:
        list[torch.Tensor]: Anchors in multiple feature levels. 
            The sizes of each tensor should be [N, 4], where 
            N = width * height * num_base_anchors, width and height 
            are the sizes of the corresponding feature lavel, 
            num_base_anchors is the number of anchors for that level.
    """
    assert self.num_levels == len(featmap_sizes)
    multi_level_anchors = []
    for i in range(self.num_levels):
        anchors = self.single_level_grid_anchors(
            self.base_anchors[i].to(device),
            featmap_sizes[i],
            self.strides[i],
            device=device)
        multi_level_anchors.append(anchors)
    return multi_level_anchors
  

И основной цикл, который загружает изображения и запускает отслеживание:

 from torch.utils.data import DataLoader
from tqdm import tqdm

data_loader = DataLoader(seq, batch_size=1, shuffle=False)
for i, frame in enumerate(tqdm(data_loader)):
    if len(seq) * tracktor['frame_split'][0] <= i <= len(seq) * tracktor['frame_split'][1]:
        with torch.no_grad():
            tracker.step(frame)
        num_frames  = 1