#python #pytorch #vectorization #torchvision
#python #pytorch #векторизация #torchvision
Вопрос:
Предположим, у меня есть пакет изображений в виде тензора, например:
images = torch.zeros(64, 3, 1024, 1024)
Теперь я хочу выбрать исправление из каждого из этих изображений. Все исправления имеют одинаковый размер, но имеют разные начальные позиции для каждого изображения в пакете.
size_x = 100
size_y = 100
start_x = torch.zeros(64)
start_y = torch.zeros(64)
Я могу добиться желаемого результата таким образом:
result = []
for i in range(arr.shape[0]):
result.append(arr[i, :, start_x[i]:start_x[i] size_x, start_y[i]:start_y[i] size_y])
result = torch.stack(result, dim=0)
Вопрос в том, можно ли сделать то же самое быстрее, без цикла? Возможно, существует какая-то форма расширенного индексирования или функция PyTorch, которая может это сделать?
Ответ №1:
Вы можете использовать torch.take
, чтобы избавиться от цикла for. Но сначала с помощью этой функции должен быть создан массив индексов
def convert_inds(img_a,img_b,patch_a,patch_b,start_x,start_y):
all_patches = np.zeros((len(start_x),3,patch_a,patch_b))
patch_src = np.zeros((patch_a,patch_b))
inds_src = np.arange(patch_b)
patch_src[:] = inds_src
for ind,info in enumerate(zip(start_x,start_y)):
x,y = info
if x patch_a 1 > img_a: return False
if y patch_b 1 > img_b: return False
start_ind = img_b * x y
end_ind = img_b * (x patch_a -1) y
col_src = np.linspace(start_ind,end_ind,patch_b)[:,None]
all_patches[ind,:] = patch_src col_src
return all_patches.astype(np.int)
Как вы можете видеть, эта функция по существу создает индексы для каждого патча, который вы хотите вырезать. С помощью этой функции проблема может быть легко решена путем
size_x = 100
size_y = 100
start_x = torch.zeros(64)
start_y = torch.zeros(64)
images = torch.zeros(64, 3, 1024, 1024)
selected_inds = convert_inds(1024,1024,100,100,start_x,start_y)
selected_inds = torch.tensor(selected_inds)
res = torch.take(images,selected_inds)
Обновить
Наблюдение OP верно, описанный выше подход не быстрее, чем наивный подход. Чтобы избежать создания индексов каждый раз, вот еще одно решение, основанное на unfold
Сначала создайте тензор всех возможных исправлений
# create all possible patches
all_patches = images.unfold(2,size_x,1).unfold(3,size_y,1)
Затем нарежьте нужные исправления из all_patches
img_ind = torch.arange(images.shape[0])
selected_patches = all_patches[img_ind,:,start_x,start_y,:,:]
Комментарии:
1. Спасибо за ответ! Однако моей целью было ускорить код (предположительно, путем устранения цикла), а не просто избавиться от цикла как такового. Поскольку ваше решение требует построения индексов каждый раз, когда нужно выбирать исправления с новым набором позиций исправлений, я не думаю, что это будет быстрее, чем наивное решение.
2. @DLunin Добро пожаловать 🙂 Ваше наблюдение верно, я обновил свой пост решением, используя
unfold
, идея состоит в том, чтобы сначала создать тензор всех возможных исправлений, а затем нарезать его непосредственно сstart_x
помощью andstart_y
.