Jax 与 numpy 生成 Heston 路径

问题描述 投票:0回答:1

我有这个Python代码(来自QuantPy),它使用numpy在Heston模型下生成股票路径。 我正在尝试将其转换为使用 Jax。 由于某种原因,numpy 版本的运行时间约为 2 秒,而 Jax 版本的运行时间约为 45 秒。 如果有人能指出原因并提出任何改进以使 Jax 运行得更快,我将非常感激。

# Parameters
S0 = 100.0          # initial asset price
K = 64              # strike price
T = 1.0                # time in years
r = 0.05               # risk-free rate
N = 252                # number of time steps in simulation
M = 100000               # number of simulations

# Heston dependent parameters
kappa = 2              # rate of mean reversion of variance under risk-neutral dynamics
theta = 0.05        # long-term mean of variance under risk-neutral dynamics
v0 = 0.05           # initial variance under risk-neutral dynamics
rho = -0.5              # correlation between returns and variances under risk-neutral dynamics
sigma = 0.3            # volatility of volatility

def heston_model_sim(S0, v0, rho, kappa, theta, sigma, r, T, N, M):

    # initialise other parameters
    dt = T/N
    mu = np.array([0,0])
    cov = np.array([[1,rho],
                    [rho,1]])
    # arrays for storing prices and variances
    S = np.full(shape=(N+1,M), fill_value=S0)
    v = np.full(shape=(N+1,M), fill_value=v0)
    # sampling correlated brownian motions under risk-neutral measure
    Z = np.random.multivariate_normal(mu, cov, (N,M))
    for i in range(1,N+1):
        S[i] = S[i-1] * np.exp( (r - 0.5*v[i-1])*dt + np.sqrt(v[i-1] * dt) * Z[i-1,:,0] )
        v[i] = np.maximum(v[i-1] + kappa*(theta-v[i-1])*dt + sigma*np.sqrt(v[i-1]*dt)*Z[i-1,:,1],0)
    
    return S, v

def heston_model_sim_jax(S0, v0, rho, kappa, theta, sigma, r, T, N, M):

    # Initialize other parameters
    dt = T / N
    mu = jnp.array([0, 0])
    cov = jnp.array([[1, rho], [rho, 1]])
    
    # Arrays for storing prices and variances
    S = jnp.full((N+1, M), S0)
    v = jnp.full((N+1, M), v0)
    
    # Sampling correlated Brownian motions under risk-neutral measure
    key = random.PRNGKey(0)
    Z = random.multivariate_normal(key, mean=mu, cov=cov, shape=(N, M))
    
    for i in range(1, N + 1):
        S = S.at[i].set(S[i-1] * jnp.exp((r - 0.5 * v[i-1]) * dt + jnp.sqrt(v[i-1] * dt) * Z[i-1, :, 0]))
        v = v.at[i].set(jnp.maximum(v[i-1] + kappa * (theta - v[i-1]) * dt + sigma * jnp.sqrt(v[i-1] * dt) * Z[i-1, :, 1], 0))
    
    return S, v

我已经阅读了 Jax 文档和其他人在线关于 Jax 与 numpy 的特定功能的问题,但我找不到任何可以帮助我理解我的功能的内容。

我想知道这是否与赋值有关:也许 S[i] = ... 比 S = S.at[i].set(...) 更快。

python numpy montecarlo quantitative-finance jax
1个回答
0
投票

为了比较 JAX 和 NumPy 的性能,您应该记住常见问题解答中的一般讨论:JAX 比 NumPy 更快吗?。特别是:

总而言之:如果您正在对 CPU 上的各个数组操作进行微基准测试,您通常可以预期 NumPy 的性能优于 JAX,因为它的每次操作调度开销较低。如果您在 GPU 或 TPU 上运行代码,或者在 CPU 上对更复杂的 JIT 编译操作序列进行基准测试,则通常可以预期 JAX 的性能优于 NumPy。

您的代码似乎是 CPU 上的非 jit 编译的数组操作序列,这正是我们期望 NumPy 比 JAX 更快的机制。

您可以通过将函数包装在

jax.jit
中(使用标记为静态的适当参数)来改进 JAX 运行时,但您可能会发现由于在代码中使用
for
循环,编译时间非常慢。您可以通过切换到 XLA 友好的迭代来解决这个问题,例如
fori_loop
(请参阅 JAX 锐利位:控制流 进行一些讨论),但即便如此,我也不认为 CPU 上的 JAX 会比 NumPy 快得多操作:问题是您的操作正在执行许多具有严格顺序依赖性的小操作,因此 XLA 编译器无法进行使其他 JAX 程序快速运行的并行化。

© www.soinside.com 2019 - 2024. All rights reserved.