#python #tensor
Вопрос:
DaCe поставляется с синтаксисом, который позволяет пользователю определять карту, под картой находится набор задач, определенный пользователями для вычислений. В качестве вычисления мы можем выполнять скалярные операции, такие как сложение целых чисел.
Однако, если мы выполним тензорную(матричную) операцию внутри этого набора задач, например, A@B или A B, где A и B-две матрицы, DaCe выдаст нам ошибку.
Вот конкретный фрагмент кода, который вызывает ошибку. Вы можете попробовать запустить этот пример.
@dace.program
def fusion(A: dace.float32[10, 20], B: dace.float32[10, 20],
out: dace.float32[1]):
tmp = dace.define_local([10, 20], dtype=A.dtype)
tmp_2 = dace.define_local([10, 20], dtype=A.dtype)
for i, j in dace.map[0:10, 0:20]:
with dace.tasklet:
a << (A A)[i, j]
b >> tmp[i, j]
b = a * a
Посмотрите на строку № 16, где есть матричное сложение A A. Я считаю, что это и есть причина ошибки.
Поэтому я хотел бы знать, допускается ли в такой ситуации работа тензора? Если это разрешено, могу ли я знать правильный синтаксис для написания этого? Если нет, я хотел бы знать, почему это не поддерживается?
Ответ №1:
with dace.tasklet
Оператор представляет собой низкоуровневый SDFG API, в котором элементы графика записываются напрямую (т. е. мемлеты и тасклеты). Поскольку тасклеты принимают доступ только к отдельным массивам или потокам, код действительно завершается ошибкой в строке #16 (поскольку в узле доступа нет массива «A A»).
Поскольку тасклеты также преобразуются непосредственно в сгенерированный код, написание в нем тензорной операции, подобной @
оператору, также не сработает, поскольку все должно входить/выходить из тасклета через явные мемлеты.
Чтобы код в вопросе был правильным, все, что нужно сделать, — это не объявлять явный тасклет:
@dace.program
def option1(A: dace.float32[10, 20], B: dace.float32[10, 20]):
for i, j in dace.map[0:10, 0:20]:
B[i, j] = A[i, j] A[i, j]
Это также может быть сокращено до крупнозернистых тензорных операций:
@dace.program
def option2(A: dace.float32[10, 20], B: dace.float32[10, 20]):
B[:] = A A # or "tmp = A A", this will automatically create a new array for tmp