Python — Различные регулярные/аналитические функции

#python #numpy #scipy

Вопрос:

Для выполнения производной я разработал следующий код:

 import matplotlib.pyplot as plt
import numpy as np
from math import *

xi = jnp.linspace(-3,3)

def f(x):
  a = x**3 5
  return a


g1i = jax.vmap(jax.grad(f))(xi)
g2i = jax.vmap(jax.grad(jax.grad(f)))(xi)
g3i = jax.vmap(jax.grad(jax.grad(jax.grad(f))))(xi)
plt.plot(xi,yi,  label = "f")
plt.plot(xi,g1i, label = "f'")
plt.plot(xi,g2i, label = "f''")
plt.plot(xi,g3i, label = "f'''")
plt.legend()
 

Этот код работает, но теперь я заинтересован в применении следующего кода для вычисления первой производной цены вызова по отношению к базовому активу (т. Е. дельте), пытаясь выполнить следующее, но это не работает:

 import scipy.stats as si
import sympy as sy
import sys
xi = jnp.linspace(1,1.5)
def analytical_call(s0):
    T=1.
    q=0.
    r=0.
    k=1.
    sigma=0.4
    Kt = k*exp((q-r)*T)
    d = (log(Kt/s0) (sigma**2)/2*T)/sigma
    result = (Kt * si.norm.cdf((d / sqrt(T)), 0.0, 1.0)  - s0 * si.norm.cdf(((d - sigma * T) / sqrt(T)), 0.0, 1.0)  ) * exp(-q * T)   exp(-q * T) * (s0 - Kt)
    return result
print(analytical_call(1))

g1i = jax.vmap(jax.grad(analytical_call))(xi)
g2i = jax.vmap(jax.grad(jax.grad(analytical_call)))(xi)
plt.plot(xi,yi,  label = "f")
plt.plot(xi,g1i, label = "f'")
plt.legend()
 

У вас есть какие-нибудь намеки? Заранее спасибо!

Комментарии:

1. delta Внутри вашей функции нет analytical_call , поэтому неясно, какую переменную вы хотите различать. Ты имеешь в виду s0 вместо этого? Обратите также внимание, что вы не можете смешивать методы scipy.stats и sympy с jax.

2. Да, я имею в виду дифференцировать вызов в отношении потока s0, определенного в коде как «xi» @joni

3. Scipy используется только для вычисления d1.. Как я могу решить эту проблему? Потому что мне нужна чувствительность вызова по отношению к потоку подложек, используя метод AAD

Ответ №1:

Как уже упоминалось в комментариях, вы не можете использовать методы вне библиотеки jax, такие как scipy.stats.norm.cdf . Используйте jax.scipy.stats вместо этого. Аналогично, замените exp и sqrt их эквивалентами jax jnp.exp и jnp.sqrt :

 from jax import jit, grad, vmap
import jax.numpy as jnp
from jax.scipy.stats.norm import cdf

def analytical_call(s0):
    T, q, r, k, sigma = 1.0, 0.0, 0.0, 1.0, 0.4
    Kt = k*jnp.exp((q-r)*T)
    d = (jnp.log(Kt/s0) (sigma**2)/2*T)/sigma
    result = (Kt * cdf((d / jnp.sqrt(T)), 0.0, 1.0)  - s0 * cdf(((d - sigma * T) / jnp.sqrt(T)), 0.0, 1.0)  ) * jnp.exp(-q * T)   jnp.exp(-q * T) * (s0 - Kt)
    return result

g = vmap(grad(analytical_call))
h = vmap(grad(grad(analytical_call)))
xi = jnp.linspace(1,1.5)
 

Затем вы можете оценить g(xi) и h(xi) .

Комментарии:

1. Я отредактировал вопрос только для того, чтобы прикрепить код о том, что я пробовал

2. @John_maddon Конечно. Однако, чтобы все было ясно, я бы рекомендовал опубликовать отдельный вопрос вместо редактирования вашего предыдущего поста.