我想用python编写以下复杂的函数
scipy.integrate
:
其中 \phi 和 \Phi 分别是标准法线的 pdf 和 cdf,总和 i != 1,2 超过 *theta 中的 theta,而乘积 s != 1,i,2 超过 *theta 中的 theta *theta 不是 theta_i
为了让事情变得更简单,我通过定义更多函数将问题分解为几个更简单的部分:
这是我尝试过的:
import scipy.integrate as integrate
from scipy.integrate import nquad
from scipy.stats import norm
import math
import numpy as np
def f(v, u, t1, ti, t2, *theta):
prod = 1
for t in theta:
prod = prod * (1 - norm.cdf(u + t2 - t))
return norm.pdf(v) * norm.cdf(v + ti - t1) * prod
def g(v, u, t1, t2, *theta):
S = 0
for ti in theta:
S = S + integrate.quad(f, -np.inf, u + t2 - ti, args=(u, t1, ti, t2, *theta))[0]
return S * norm.pdf(u)
def P(t1, t2, *theta):
return integrate.quad(g, -np.inf, np.inf, args=(t1, t2, *theta))[0]
但它不起作用,因为我认为在 P 的定义中,积分是对 w.r.t 进行积分。第一个变量,即 v,但我们应该对 u 进行积分,但我不知道如何做到这一点。
有什么快速的方法吗?
为了检查,以下是正确的 P 会产生的结果:
P(0.2, 0.1, 0.3, 0.4) = 0.08856347190679764
P(0.2, 0.1, 0.4, 0.3, 0.5) = 0.06094233268837703
我做的第一件事是将
prod
移到 f
之外,因为它不是 v
的函数。我意识到您的 prod
包含 ti
术语,因此我添加了对 ts
(从 t
重命名)等于 ti
的检查。
您的函数也有太多变量。
g
是 u
和 theta 的函数,而 f
是 v
和 theta 的函数,因此它们不需要 u
和 v
。
这不是最有效的代码,但它可以工作。
from scipy.integrate import quad
from scipy.stats import norm
import numpy as np
def f(v, t1, ti):
return norm.pdf(v) * norm.cdf(v + ti - t1)
def g(u, t1, t2, *theta):
S = 0
for ti in theta:
prod = 1
for ts in theta:
if ts != ti:
prod = prod * (1 - norm.cdf(u + t2 - ts))
S = S + prod*quad(f, -np.inf, u + t2 - ti, args=(t1, ti))[0]
return S * norm.pdf(u)
def P(t1, t2, *theta):
return quad(g, -np.inf, np.inf, args=(t1, t2, *theta))[0]
print(P(0.2, 0.1, 0.3, 0.4)) # 0.08856347187084088
print(P(0.2, 0.1, 0.4, 0.3, 0.5)) # 0.060942332693157915