Многомерные производные высшего порядка в jax

#python #jax

#python #jax

Вопрос:

Я не понимаю, как вычислять многомерные производные высшего порядка в jax.

Например, как вы вычисляете d ^ 2f / dx dy для

 def f(x, y):
     return jnp.sin(jnp.dot(x, y.T))
 

где x, y в R ^ n, n> = 1?

Я экспериментировал с jax.jvp и jax.partial , но у меня не было никакого успеха.

Ответ №1:

Поскольку x and y имеет векторное значение и f(x, y) является скаляром, я полагаю, вы можете вычислить, что вам нужно, объединив функции jax.jacfwd and jax.jacrev с соответствующими аргументами:

 import jax.numpy as jnp
from jax import jacfwd, jacrev

def f(x, y):
     return jnp.sin(jnp.dot(x, y.T))

d2f_dxdy = jacfwd(jacrev(f, argnums=1), argnums=0)
  
x = jnp.arange(4.0)
y = jnp.ones(4)

print(d2f_dxdy(x, y))

# DeviceArray([[0.96017027, 0.        , 0.        , 0.        ],
#              [0.2794155 , 1.2395858 , 0.2794155 , 0.2794155 ],
#              [0.558831  , 0.558831  , 1.5190012 , 0.558831  ],
#              [0.83824646, 0.83824646, 0.83824646, 1.7984167 ]],
#             dtype=float32)