Сортировка тензора Pytorch по трассировке

#sorting #pytorch #vectorization #tensor

#сортировка #pytorch #векторизация #тензор

Вопрос:

У меня есть тензор pytorch в форме (100,64,22,3,3), и я хотел бы отсортировать вдоль axis = 0 по трассировке (3,3) компонентов. Приведенный ниже код работает, но он очень медленный из-за циклов for . Есть ли способ векторизовать операцию, чтобы ускорить ее?

 x=torch.rand(100,64,22,3,3)
x_sorted=torch.zeros((x.shape[0],x.shape[1],x.shape[2],x.shape[3],x.shape[4]))
            for i in range(x.shape[0]):
              #compute tensorized trace
              trace=new=torch.diagonal(x[i], dim1=-2, dim2=-1).sum(-1) 
              #Sort the trace
              trace_values,trace_ind=torch.sort(trace,dim=0,descending=True)
              for j in range(x_sorted.shape[1]):
                for k in range(x_sorted.shape[2]):
                  x_sorted[i,j,k]=x[i,trace_ind[j,k],k]
  
 

Ответ №1:

Попробуйте:

 tensor = torch.tensor(np.random.rand(100,64, 3, 3))

orders = torch.argsort(torch.einsum('ijkk->ijk', tensor).sum(-1), axis=0)
orders.shape

tensor[orders, torch.arange(s.shape[1])[None, :]]
 

Комментарии:

1. В последней строке должно быть s.shape[1] быть orders.shape[1] . Это выдает ошибку несоответствия формы.