#python #numpy #scipy
Вопрос:
Для выполнения производной я разработал следующий код:
import matplotlib.pyplot as plt
import numpy as np
from math import *
xi = jnp.linspace(-3,3)
def f(x):
a = x**3 5
return a
g1i = jax.vmap(jax.grad(f))(xi)
g2i = jax.vmap(jax.grad(jax.grad(f)))(xi)
g3i = jax.vmap(jax.grad(jax.grad(jax.grad(f))))(xi)
plt.plot(xi,yi, label = "f")
plt.plot(xi,g1i, label = "f'")
plt.plot(xi,g2i, label = "f''")
plt.plot(xi,g3i, label = "f'''")
plt.legend()
Этот код работает, но теперь я заинтересован в применении следующего кода для вычисления первой производной цены вызова по отношению к базовому активу (т. Е. дельте), пытаясь выполнить следующее, но это не работает:
import scipy.stats as si
import sympy as sy
import sys
xi = jnp.linspace(1,1.5)
def analytical_call(s0):
T=1.
q=0.
r=0.
k=1.
sigma=0.4
Kt = k*exp((q-r)*T)
d = (log(Kt/s0) (sigma**2)/2*T)/sigma
result = (Kt * si.norm.cdf((d / sqrt(T)), 0.0, 1.0) - s0 * si.norm.cdf(((d - sigma * T) / sqrt(T)), 0.0, 1.0) ) * exp(-q * T) exp(-q * T) * (s0 - Kt)
return result
print(analytical_call(1))
g1i = jax.vmap(jax.grad(analytical_call))(xi)
g2i = jax.vmap(jax.grad(jax.grad(analytical_call)))(xi)
plt.plot(xi,yi, label = "f")
plt.plot(xi,g1i, label = "f'")
plt.legend()
У вас есть какие-нибудь намеки? Заранее спасибо!
Комментарии:
1.
delta
Внутри вашей функции нетanalytical_call
, поэтому неясно, какую переменную вы хотите различать. Ты имеешь в видуs0
вместо этого? Обратите также внимание, что вы не можете смешивать методы scipy.stats и sympy с jax.2. Да, я имею в виду дифференцировать вызов в отношении потока s0, определенного в коде как «xi» @joni
3. Scipy используется только для вычисления d1.. Как я могу решить эту проблему? Потому что мне нужна чувствительность вызова по отношению к потоку подложек, используя метод AAD
Ответ №1:
Как уже упоминалось в комментариях, вы не можете использовать методы вне библиотеки jax, такие как scipy.stats.norm.cdf
. Используйте jax.scipy.stats
вместо этого. Аналогично, замените exp
и sqrt
их эквивалентами jax jnp.exp
и jnp.sqrt
:
from jax import jit, grad, vmap
import jax.numpy as jnp
from jax.scipy.stats.norm import cdf
def analytical_call(s0):
T, q, r, k, sigma = 1.0, 0.0, 0.0, 1.0, 0.4
Kt = k*jnp.exp((q-r)*T)
d = (jnp.log(Kt/s0) (sigma**2)/2*T)/sigma
result = (Kt * cdf((d / jnp.sqrt(T)), 0.0, 1.0) - s0 * cdf(((d - sigma * T) / jnp.sqrt(T)), 0.0, 1.0) ) * jnp.exp(-q * T) jnp.exp(-q * T) * (s0 - Kt)
return result
g = vmap(grad(analytical_call))
h = vmap(grad(grad(analytical_call)))
xi = jnp.linspace(1,1.5)
Затем вы можете оценить g(xi)
и h(xi)
.
Комментарии:
1. Я отредактировал вопрос только для того, чтобы прикрепить код о том, что я пробовал
2. @John_maddon Конечно. Однако, чтобы все было ясно, я бы рекомендовал опубликовать отдельный вопрос вместо редактирования вашего предыдущего поста.