Сообщение об ошибке в Python с дифференциацией

#python #numpy #error-handling #jax

Вопрос:

Я вычисляю эти производные, используя подход Монтекарло для общего варианта вызова. Меня интересует эта комбинированная производная (как в отношении S, так и в отношении Сигмы). Делая это с помощью алгоритмического дифференцирования, я получаю ошибку, которую можно увидеть в конце страницы. Каким может быть возможное решение? Просто чтобы объяснить кое-что относительно кода, я собираюсь приложить формулу, используемую для вычисления «X» в приведенном ниже коде:

введите описание изображения здесь

 from jax import jit, grad, vmap
import jax.numpy as jnp
from jax import random
Underlying_asset = jnp.linspace(1.1,1.4,100)
volatilities = jnp.linspace(0.5,0.6,100)
def second_derivative_mc(S,vol):
    N = 100
    j,T,q,r,k = 10000,1.,0,0,1.
    S0 = jnp.array([S]).T #(Nx1) vector underlying asset
    C = jnp.identity(N)*vol    #matrix of volatilities with 0 outside diagonal 
    e = jnp.array([jnp.full(j,1.)])#(1xj) vector of "1"
    Rand = np.random.RandomState()
    Rand.seed(10)
    U= Rand.normal(0,1,(N,j)) #Random number for Brownian Motion
    sigma2 = jnp.array([vol**2]).T #Vector of variance Nx1

    first = jnp.dot(sigma2,e) #First part equation
    second = jnp.dot(C,U)     #Second part equation

    X = -0.5*first jnp.sqrt(T)*second

    St = jnp.exp(X)*S0

    P = jnp.maximum(St-k,0)
    payoff = jnp.average(P, axis=-1)*jnp.exp(-q*T)
    return payoff 


greek = vmap(grad(grad(second_derivative_mc, argnums=1), argnums=0)(Underlying_asset,volatilities)
 

Это сообщение об ошибке:

 > UnfilteredStackTrace                      Traceback (most recent call
> last) <ipython-input-78-0cc1da97ae0c> in <module>()
>      25 
> ---> 26 greek = vmap(grad(grad(second_derivative_mc, argnums=1), argnums=0))(Underlying_asset,volatilities)
> 
> 18 frames UnfilteredStackTrace: TypeError: Gradient only defined for
> scalar-output functions. Output had shape: (100,).
 

Трассировка стека ниже исключает внутренние кадры JAX.
Предыдущее-это исходное исключение, которое произошло без изменений.


Вышеуказанное исключение было прямой причиной следующего исключения:

 > TypeError                                 Traceback (most recent call
> last) /usr/local/lib/python3.7/dist-packages/jax/_src/api.py in
> _check_scalar(x)
>     894     if isinstance(aval, ShapedArray):
>     895       if aval.shape != ():
> --> 896         raise TypeError(msg(f"had shape: {aval.shape}"))
>     897     else:
>     898       raise TypeError(msg(f"had abstract value {aval}"))

> TypeError: Gradient only defined for scalar-output functions. Output had shape: (100,).
 

Ответ №1:

Как указано в сообщении об ошибке, градиенты могут быть вычислены только для функций, возвращающих скаляр. Ваша функция возвращает вектор:

 print(len(second_derivative_mc(1.1, 0.5)))
# 100
 

Для векторнозначных функций можно вычислить якобиан (который аналогичен многомерному градиенту). Это то, что ты имел в виду?

 from jax import jacobian
greek = vmap(jacobian(jacobian(second_derivative_mc, argnums=1), argnums=0))(Underlying_asset,volatilities)
 

Кроме того, это не то, о чем вы спрашивали, но функция, описанная выше, вероятно, не будет работать так, как вы намереваетесь, даже если вы решите проблему в вопросе. RandomState Объекты Numpy отслеживают состояние, и поэтому, как правило, неправильно работают с преобразованиями jax , такими как grad , jit vmap , и т.д., Для которых требуется код без побочных эффектов (см. Вычисления с отслеживанием состояния в JAX). Вы можете попробовать использовать jax.random вместо этого; дополнительную информацию см. в разделе JAX: Случайные числа.

Комментарии:

1. Большое спасибо за предложение. Для векторнозначных функций вы можете вычислить якобиан (который аналогичен многомерному градиенту). Это то, что ты имел в виду? —> Да, в том-то и дело, спасибо! Могу я задать вам еще один (последний) вопрос? В этой ситуации результатом является массив, содержащий все нули, можно ли с помощью некоторой операции вычислить производную, с ненулевым значением?

2. Если функция имеет ненулевой градиент, градиент будет ненулевым. Пример: grad(lambda x: x)(0.0) ненулевое значение, несмотря на то, что сама функция возвращает ноль.

3. да, я согласен с вами, но, по вашему мнению, можно ли переписать эту функцию таким образом, чтобы я мог вычислить эту комбинированную производную?

4. Я не совсем понимаю, о чем ты спрашиваешь… может быть, откроем еще один вопрос с более подробной информацией?

5. да, может быть, так будет лучше, я собираюсь это сделать