#python #jax
#python #jax
Вопрос:
Я не понимаю, как вычислять многомерные производные высшего порядка в jax.
Например, как вы вычисляете d ^ 2f / dx dy для
def f(x, y):
return jnp.sin(jnp.dot(x, y.T))
где x, y в R ^ n, n> = 1?
Я экспериментировал с jax.jvp
и jax.partial
, но у меня не было никакого успеха.
Ответ №1:
Поскольку x
and y
имеет векторное значение и f(x, y)
является скаляром, я полагаю, вы можете вычислить, что вам нужно, объединив функции jax.jacfwd
and jax.jacrev
с соответствующими аргументами:
import jax.numpy as jnp
from jax import jacfwd, jacrev
def f(x, y):
return jnp.sin(jnp.dot(x, y.T))
d2f_dxdy = jacfwd(jacrev(f, argnums=1), argnums=0)
x = jnp.arange(4.0)
y = jnp.ones(4)
print(d2f_dxdy(x, y))
# DeviceArray([[0.96017027, 0. , 0. , 0. ],
# [0.2794155 , 1.2395858 , 0.2794155 , 0.2794155 ],
# [0.558831 , 0.558831 , 1.5190012 , 0.558831 ],
# [0.83824646, 0.83824646, 0.83824646, 1.7984167 ]],
# dtype=float32)