我有这个Python代码(来自QuantPy),它使用numpy在Heston模型下生成股票路径。 我正在尝试将其转换为使用 Jax。 由于某种原因,numpy 版本的运行时间约为 2 秒,而 Jax 版本的运行时间约为 45 秒。 如果有人能指出原因并提出任何改进以使 Jax 运行得更快,我将非常感激。
# Parameters
S0 = 100.0 # initial asset price
K = 64 # strike price
T = 1.0 # time in years
r = 0.05 # risk-free rate
N = 252 # number of time steps in simulation
M = 100000 # number of simulations
# Heston dependent parameters
kappa = 2 # rate of mean reversion of variance under risk-neutral dynamics
theta = 0.05 # long-term mean of variance under risk-neutral dynamics
v0 = 0.05 # initial variance under risk-neutral dynamics
rho = -0.5 # correlation between returns and variances under risk-neutral dynamics
sigma = 0.3 # volatility of volatility
def heston_model_sim(S0, v0, rho, kappa, theta, sigma, r, T, N, M):
# initialise other parameters
dt = T/N
mu = np.array([0,0])
cov = np.array([[1,rho],
[rho,1]])
# arrays for storing prices and variances
S = np.full(shape=(N+1,M), fill_value=S0)
v = np.full(shape=(N+1,M), fill_value=v0)
# sampling correlated brownian motions under risk-neutral measure
Z = np.random.multivariate_normal(mu, cov, (N,M))
for i in range(1,N+1):
S[i] = S[i-1] * np.exp( (r - 0.5*v[i-1])*dt + np.sqrt(v[i-1] * dt) * Z[i-1,:,0] )
v[i] = np.maximum(v[i-1] + kappa*(theta-v[i-1])*dt + sigma*np.sqrt(v[i-1]*dt)*Z[i-1,:,1],0)
return S, v
def heston_model_sim_jax(S0, v0, rho, kappa, theta, sigma, r, T, N, M):
# Initialize other parameters
dt = T / N
mu = jnp.array([0, 0])
cov = jnp.array([[1, rho], [rho, 1]])
# Arrays for storing prices and variances
S = jnp.full((N+1, M), S0)
v = jnp.full((N+1, M), v0)
# Sampling correlated Brownian motions under risk-neutral measure
key = random.PRNGKey(0)
Z = random.multivariate_normal(key, mean=mu, cov=cov, shape=(N, M))
for i in range(1, N + 1):
S = S.at[i].set(S[i-1] * jnp.exp((r - 0.5 * v[i-1]) * dt + jnp.sqrt(v[i-1] * dt) * Z[i-1, :, 0]))
v = v.at[i].set(jnp.maximum(v[i-1] + kappa * (theta - v[i-1]) * dt + sigma * jnp.sqrt(v[i-1] * dt) * Z[i-1, :, 1], 0))
return S, v
我已经阅读了 Jax 文档和其他人在线关于 Jax 与 numpy 的特定功能的问题,但我找不到任何可以帮助我理解我的功能的内容。
我想知道这是否与赋值有关:也许 S[i] = ... 比 S = S.at[i].set(...) 更快。
为了比较 JAX 和 NumPy 的性能,您应该记住常见问题解答中的一般讨论:JAX 比 NumPy 更快吗?。特别是:
总而言之:如果您正在对 CPU 上的各个数组操作进行微基准测试,您通常可以预期 NumPy 的性能优于 JAX,因为它的每次操作调度开销较低。如果您在 GPU 或 TPU 上运行代码,或者在 CPU 上对更复杂的 JIT 编译操作序列进行基准测试,则通常可以预期 JAX 的性能优于 NumPy。
您的代码似乎是 CPU 上的非 jit 编译的数组操作序列,这正是我们期望 NumPy 比 JAX 更快的机制。
您可以通过将函数包装在
jax.jit
中(使用标记为静态的适当参数)来改进 JAX 运行时,但您可能会发现由于在代码中使用 for
循环,编译时间非常慢。您可以通过切换到 XLA 友好的迭代来解决这个问题,例如 fori_loop
(请参阅 JAX 锐利位:控制流 进行一些讨论),但即便如此,我也不认为 CPU 上的 JAX 会比 NumPy 快得多操作:问题是您的操作正在执行许多具有严格顺序依赖性的小操作,因此 XLA 编译器无法进行使其他 JAX 程序快速运行的并行化。