Как отфильтровать элементы по определенному измерению

#numpy #torch

#numpy #факел

Вопрос:

Допустим, тензор a имеет форму (128, 20, 10) . Я хотел бы отфильтровать этот тензор в тензор b формы (128,19,10) на основе условия: в каждой (20, 10) матрице есть одна строка, которую я хотел бы удалить, где сумма столбцов равна нулю. Как мне сделать это с помощью нарезки?

Я должен быть в состоянии сделать что-то вроде:

 mask = a.abs().sum(dim=2) != 0
a = a[mask]
  

Но это дает мне неправильную форму.

Комментарии:

1. Это все, что мне нужно для изменения формы после? a = a[mask].view(128,19,10) ?

2. Форма равна 1d, но правильные ли значения?