Как выполнять операции tensordot после перестановки

#pytorch #permute #tensordot

#pytorch #перестановка #tensordot

Вопрос:

У меня есть 2 тензора, A и B:

 A = torch.randn([32,128,64,12],dtype=torch.float64)
B = torch.randn([64,12,64,12],dtype=torch.float64)
C = torch.tensordot(A,B,([2,3],[0,1]))
D = C.permute(0,2,1,3) # shape:[32,64,128,12]
 

тензор D получается из операций «tensordot -> permute». Как я могу реализовать новую операцию f(), чтобы выполнить операцию tensordot после f(), например:

 A_2 = f(A)
B_2 = f(B)
D = torch.tensordot(A_2,B_2)
 

Ответ №1:

Рассматривали ли вы возможность использования torch.einsum , которая является очень гибкой?

 D = torch.einsum('ijab,abkl->ikjl', A, B)
 

Проблема в tensordot том, что он выводит все измерения A перед измерениями, B и то, что вы ищете (при перестановке), заключается в «чередовании» измерений из A и B .

Комментарии:

1. Да! Наконец-то я действительно использую «torch.einsum».