Дифференцируемое аффинное преобразование на участках изображений в pytorch

#python #pytorch #affinetransform

Вопрос:

У меня есть тензор ограничивающих рамок объекта, например, с формой [10,4], которые соответствуют пакету изображений, например, с формой [2,3,64,64] и матрицами преобразования для каждого объекта с формой [10,6] и вектором, который определяет, какой индекс объекта принадлежит какому изображению. Я хотел бы применить аффинные преобразования к участкам изображений и заменить эти участки после применения преобразований. Сейчас я делаю это с помощью цикла for, но то, как я это делаю, невозможно отличить (я получаю ошибку операции на месте от pytorch). Я хотел знать, есть ли другой способ сделать это. например, через grid_sample ?

Вот мой текущий код:

 for obj_num in range(obj_vecs.shape[0]): #batch_size
    im_id = obj_to_img[obj_num]
    x1, y1, x2, y2 = boxes_pred[obj_num]
    im_patch = img[im_id, :, x1:x2, y1:y2]
    im_patch = im_patch[None, :, :, :]
    img[im_id, :, x1:x2, y1:y2] = self.VITAE.stn(im_patch, theta_mean[obj_num], inverse=False)[0]
 

Ответ №1:

Существует несколько способов выполнения дифференцируемых культур в PyTorch.

Давайте рассмотрим минимальный пример в 2D:

 >>> x1, y1, x2, y2 = torch.randint(0, 9, (4,))
(tensor(7), tensor(3), tensor(5), tensor(6))

>>> x = torch.randint(0, 100, (9,9), dtype=float, requires_grad=True)
tensor([[18., 34., 28., 41.,  1., 14., 77., 75., 23.],
        [62., 33., 64., 41., 16., 70., 47., 45., 19.],
        [20., 69.,  5., 51.,  1., 16., 20., 63., 52.],
        [51., 25.,  8., 30., 40., 67., 41., 27., 33.],
        [36.,  6., 95., 53., 69., 84., 51., 42., 71.],
        [46., 72., 88., 82., 71., 75., 86., 36., 15.],
        [66., 19., 58., 50., 91., 28.,  7., 83.,  4.],
        [94., 50., 34., 34., 92., 45., 48., 97., 76.],
        [80., 34., 19., 13., 77., 77., 51., 15., 13.]], dtype=torch.float64,
       requires_grad=True)
 

Учитывая x1 , x2 (соответственно y1 , y2 границы индекса участка для измерения высоты (соответственно. измерение ширины). Вы можете получить сетку координат, соответствующую вашему исправлению, используя комбинацию torch.arange и torch.meshgrid :

 >>> sorted_range = lambda a, b: torch.arange(a, b) if b >= a else torch.arange(b, a)
>>> xi, yi = sorted_range(x1, x2), sorted_range(y1, y2)
(tensor([3, 4, 5, 6]), tensor([5]))

>>> i, j = torch.meshgrid(xi, yi)
(tensor([[3],
         [4],
         [5],
         [6]]), 
 tensor([[5],
         [5],
         [5],
         [5]]))
 

С помощью этой настройки вы можете извлекать и заменять исправления x .

  1. Вы можете извлечь патч путем x прямой индексации:
     >>> patch = x[i, j].reshape(len(xi), len(yi))
    tensor([[67.],
            [84.],
            [75.],
            [28.]], dtype=torch.float64, grad_fn=<ViewBackward>)
     

    Вот маска для иллюстрации:

     tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 1., 0., 0., 0.],
            [0., 0., 0., 0., 0., 1., 0., 0., 0.],
            [0., 0., 0., 0., 0., 1., 0., 0., 0.],
            [0., 0., 0., 0., 0., 1., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=torch.float64,
    grad_fn=<IndexPutBackward>)
     
  2. Вы можете заменить значения в x результате некоторого преобразования в патче, используя torch.Tensor.index_put :
     >>> values = 2*patch
     tensor([[134.],
             [168.],
             [150.],
             [ 56.]], dtype=torch.float64, grad_fn=<MulBackward0>)
    
    >>> x.index_put(indices=(i, j), values=values)
    tensor([[ 18.,  34.,  28.,  41.,   1.,  14.,  77.,  75.,  23.],
            [ 62.,  33.,  64.,  41.,  16.,  70.,  47.,  45.,  19.],
            [ 20.,  69.,   5.,  51.,   1.,  16.,  20.,  63.,  52.],
            [ 51.,  25.,   8.,  30.,  40., 134.,  41.,  27.,  33.],
            [ 36.,   6.,  95.,  53.,  69., 168.,  51.,  42.,  71.],
            [ 46.,  72.,  88.,  82.,  71., 150.,  86.,  36.,  15.],
            [ 66.,  19.,  58.,  50.,  91.,  56.,   7.,  83.,   4.],
            [ 94.,  50.,  34.,  34.,  92.,  45.,  48.,  97.,  76.],
            [ 80.,  34.,  19.,  13.,  77.,  77.,  51.,  15.,  13.]],
        dtype=torch.float64, grad_fn=<IndexPutBackward>)
     

Комментарии:

1. Спасибо! однако клонирование im_patch перед отправкой его в STN сработало!