出于某些原因,我需要在PyTorch中实现Runge-Kutta4方法(因此,不,我将不使用scipy.odeint
)。我尝试过,在最简单的测试用例上得到了奇怪的结果,用x(0)= 1求解x'= x(解析解决方案:x = exp(t))。基本上,随着我减少时间步长,我无法获得数值误差。我可以使用更简单的Euler方法来做到这一点,但是不能使用Runge-Kutta 4方法来做到这一点,这使我怀疑这里存在浮点问题(也许我错过了一些从双精度到单精度的隐藏转换)?
import torch import numpy as np import matplotlib.pyplot as plt def Euler(f, IC, time_grid): y0 = torch.tensor([IC]) time_grid = time_grid.to(y0[0]) values = y0 for i in range(0, time_grid.shape[0] - 1): t_i = time_grid[i] t_next = time_grid[i+1] y_i = values[i] dt = t_next - t_i dy = f(t_i, y_i) * dt y_next = y_i + dy y_next = y_next.unsqueeze(0) values = torch.cat((values, y_next), dim=0) return values def RungeKutta4(f, IC, time_grid): y0 = torch.tensor([IC]) time_grid = time_grid.to(y0[0]) values = y0 for i in range(0, time_grid.shape[0] - 1): t_i = time_grid[i] t_next = time_grid[i+1] y_i = values[i] dt = t_next - t_i dtd2 = 0.5 * dt f1 = f(t_i, y_i) f2 = f(t_i + dtd2, y_i + dtd2 * f1) f3 = f(t_i + dtd2, y_i + dtd2 * f2) f4 = f(t_next, y_i + dt * f3) dy = 1/6 * dt * (f1 + 2 * (f2 + f3) +f4) y_next = y_i + dy y_next = y_next.unsqueeze(0) values = torch.cat((values, y_next), dim=0) return values # differential equation def f(T, X): return X # initial condition IC = 1. # integration interval def integration_interval(steps, ND=1): return torch.linspace(0, ND, steps) # analytical solution def analytical_solution(t_range): return np.exp(t_range) # test a numerical method def test_method(method, t_range, analytical_solution): numerical_solution = method(f, IC, t_range) L_inf_err = torch.dist(numerical_solution, analytical_solution, float('inf')) return L_inf_err if __name__ == '__main__': Euler_error = np.array([0.,0.,0.]) RungeKutta4_error = np.array([0.,0.,0.]) indices = np.arange(1, Euler_error.shape[0]+1) n_steps = np.power(10, indices) for i, n in np.ndenumerate(n_steps): t_range = integration_interval(steps=n) solution = analytical_solution(t_range) Euler_error[i] = test_method(Euler, t_range, solution).numpy() RungeKutta4_error[i] = test_method(RungeKutta4, t_range, solution).numpy() plots_path = "./plots" a = plt.figure() plt.xscale('log') plt.yscale('log') plt.plot(n_steps, Euler_error, label="Euler error", linestyle='-') plt.plot(n_steps, RungeKutta4_error, label="RungeKutta 4 error", linestyle='-.') plt.legend() plt.savefig(plots_path + "/errors.png")
结果:
如您所见,Euler方法收敛(缓慢,如一阶方法所预期)。但是,随着时间步长越来越小,Runge-Kutta4方法不
收敛。错误首先下降,然后再次上升。这是什么问题?出于某种原因,我需要在PyTorch中实现Runge-Kutta4方法(因此,不,我将不使用scipy.odeint)。我尝试过,但在最简单的测试用例上却得到了怪异的结果,用x(0)= 1(...
原因确实是浮点精度问题。 torch
默认为单精度,因此,一旦truncation error变得足够小,总误差就基本上由roundoff error确定,并通过增加步数<=>进一步减少截断误差不会导致总误差的减少。