我编写了一个 Python 类,用于计算光线在复杂介质(用参数 medium_model = [A, B, C] 建模)中从一点(init_point)传播到另一点(term_point)所需的时间。
要计算光传输时间,我们需要两件事:起点和终点,以及光线的初始发射角(launch_angle)。我有一个光线 z 坐标的表达式,它以 x 和 y 坐标(我们将二者合并为单一的“x”坐标,x = sqrt(x**2 + y**2))、launch_angle 以及起点和终点表示;因此,只需最小化已知的终点 z 坐标与不同 launch_angle 取值所给出的 z 坐标之间的差异,即可确定发射角。
from dataclasses import dataclass
import warnings
import numpy as np
# Suppress warnings (sometimes there is overflow in _calculate_z_coord, etc.)
# NOTE: the filter is installed when this module is imported and, having no
# module restriction, applies to every RuntimeWarning raised in the process.
warnings.filterwarnings("ignore", category=RuntimeWarning)
@dataclass
class RayTracer:
    """Compute light transit times between pairs of points in a medium.

    The medium is described by three scalars ``A, B, C`` (``medium_model``)
    that enter every formula below through ``exp(C * z)``.
    NOTE(review): presumably these parametrise an exponential
    refractive-index profile in z -- confirm against the derivation.
    """

    medium_model: np.ndarray
    SPEED_OF_LIGHT = 299792458  # Speed of light in m/s

    def _calculate_z_coord(
        self,
        x: np.ndarray,
        launch_angle: np.ndarray,
        x0: np.ndarray,
        z0: np.ndarray,
    ) -> np.ndarray:
        """Return the ray's z coordinate at horizontal position ``x``.

        All arguments broadcast against each other with the usual NumPy
        rules, so a column of positions combined with a row of angles
        yields a 2-D table of z values.

        Args:
            x: Horizontal coordinate at which to evaluate the ray.
            launch_angle: Launch angle(s) in radians.
            x0: Horizontal coordinate of the starting point.
            z0: z coordinate of the starting point.

        Returns:
            The z coordinate of the ray; entries are NaN/inf wherever the
            closed-form expression overflows or has no real value.
        """
        A, B, C = self.medium_model
        exp_Cz0 = np.exp(C * z0)
        beta = (A - B * exp_Cz0) * np.cos(launch_angle)

        term1 = A**2 - beta**2
        sqrt_A2_beta2 = np.sqrt(term1)
        K = C * sqrt_A2_beta2 / beta

        term2 = A * B * exp_Cz0
        sqrt_term = np.sqrt(term1 + 2 * term2 + B**2 * exp_Cz0**2)

        # Integration constant that makes the ray pass through (x0, z0).
        log_arg = term1 + term2 + sqrt_A2_beta2 * sqrt_term
        t = (
            sqrt_A2_beta2 * C * x0 - beta * C * z0 + beta * np.log(log_arg)
        ) / (sqrt_A2_beta2 * C)

        exp_Kx = np.exp(K * x)
        exp_Kt = np.exp(K * t)
        log_term_num = 2 * term1 * np.exp(K * (t + x))
        log_term_den = (
            beta**2 * B**2 * exp_Kx**2 - 2 * A * B * exp_Kt + exp_Kt**2
        )
        return (1 / C) * np.log(log_term_num / log_term_den)

    def _find_launch_angle(
        self,
        init_points: np.ndarray,
        term_points: np.ndarray,
        num_steps: int = 1000,
    ) -> np.ndarray:
        """Pick, per ray, the launch angle whose path ends closest to the target.

        A uniform grid of ``num_steps`` angles on ``[-pi, pi]`` is evaluated
        for every ray and the angle minimising the squared z mismatch at the
        terminal point is returned. The search builds intermediate arrays of
        shape ``(len(init_points), num_steps)``.

        Args:
            init_points: Array whose columns 0, 1, 2 are x, y, z of each
                starting point.
            term_points: Array of the same layout holding the end points.
            num_steps: Number of candidate angles in the grid.

        Returns:
            One launch angle per ray. A ray for which every candidate angle
            produces NaN gets NaN instead of aborting the whole batch.
        """
        x0 = np.hypot(init_points[:, 0], init_points[:, 1])
        x1 = np.hypot(term_points[:, 0], term_points[:, 1])

        # Coarse grid search only; there is no refinement step.
        launch_angles = np.linspace(-np.pi, np.pi, num_steps)

        # Overflow / invalid values are expected for some angles; silence
        # them locally rather than relying on a process-wide warning filter.
        with np.errstate(all="ignore"):
            z_coords = self._calculate_z_coord(
                x1[:, np.newaxis],
                launch_angles,
                x0[:, np.newaxis],
                init_points[:, 2, np.newaxis],
            )
            term_z = term_points[:, 2][:, np.newaxis]
            objective_values = (z_coords - term_z) ** 2

        # np.nanargmin raises ValueError as soon as one row is all-NaN, which
        # would discard the results of every other ray. Treat NaN as +inf
        # (exactly what nanargmin does internally) and flag hopeless rows.
        invalid = np.isnan(objective_values)
        unsolved = invalid.all(axis=1)
        best_indices = np.argmin(
            np.where(invalid, np.inf, objective_values), axis=1
        )
        best_angles = launch_angles[best_indices]
        best_angles[unsolved] = np.nan
        return best_angles

    def transit_time(
        self, init_points: np.ndarray, term_points: np.ndarray
    ) -> np.ndarray:
        """Calculate the transit time for each (init, term) pair of points.

        Args:
            init_points: Array whose columns 0, 1, 2 are x, y, z of each
                starting point.
            term_points: Array of the same layout holding the end points.

        Returns:
            One value per ray: the difference of the closed-form time
            expression between the two z coordinates, divided by
            ``SPEED_OF_LIGHT``. NaN where no launch angle could be found.
            NOTE(review): the result is in seconds only if the coordinates
            are in metres -- confirm with callers.
        """
        A, B, C = self.medium_model

        # Vectorized launch angle search
        launch_angles = self._find_launch_angle(init_points, term_points)

        exp_Cz_init = np.exp(C * init_points[:, 2])
        beta = np.abs((A - B * exp_Cz_init) * np.cos(launch_angles))
        sqrt_A2_beta2 = np.sqrt(A**2 - beta**2)
        K = C * sqrt_A2_beta2 / beta

        def time_expression(
            z: np.ndarray, beta: np.ndarray, K: np.ndarray
        ) -> np.ndarray:
            """Evaluate the antiderivative of the travel time at height z."""
            exp_Cz = np.exp(C * z)
            t = np.sqrt((A + B * exp_Cz - beta) / (A + B * exp_Cz + beta))
            alpha = np.sqrt((A - beta) / (A + beta))
            log_expr = np.log(np.abs((t - alpha) / (t + alpha)))
            return A * np.sqrt((C**2 + K**2) / (C**2 * K**2)) * log_expr

        time_diff = time_expression(term_points[:, 2], beta, K) - time_expression(
            init_points[:, 2], beta, K
        )
        return time_diff / self.SPEED_OF_LIGHT
# Example usage
if __name__ == "__main__":
    import time

    ray_tracer = RayTracer(np.array([1.78, 0.454, 0.0132]))

    # Generate random init_points and term_points
    num_points = 100000
    init_points = np.random.uniform(low=-50, high=50, size=(num_points, 3))
    term_points = np.random.uniform(low=-50, high=50, size=(num_points, 3))

    # A simple but rough measure of the execution time, for reference.
    # perf_counter() is monotonic and high-resolution, so the measured
    # interval cannot be distorted by system clock adjustments.
    start_time = time.perf_counter()
    transit_times = ray_tracer.transit_time(init_points, term_points)
    end_time = time.perf_counter()

    # Print results
    print("Transit times:", transit_times)
    print(f"Elapsed time: {end_time - start_time:.6f} seconds")
代码已矢量化。RayTracer 的初始化方式如下:
# A, B, C are the three medium-model parameters; the instance is named
# `ray_tracer` to match the name used when calling transit_time().
ray_tracer = RayTracer(np.array([A, B, C]))
transit_time() 函数是矢量化的,接受起点数组和终点数组作为参数,例如:
# Build num_points random start and end points, one (x, y, z) row each,
# with every coordinate drawn uniformly between -50 and 50.
init_points = np.random.uniform(low=-50, high=50, size=(num_points, 3))
term_points = np.random.uniform(low=-50, high=50, size=(num_points, 3))
# One call handles every (init, term) row pair.
transit_times = ray_tracer.transit_time(init_points, term_points)
在某些应用中,我希望能够快速计算数千、数万甚至数十万条光线的传播时间。虽然此代码已经相当快,但对于我的某些应用来说还不够快。我已经用尽了我熟悉的所有技巧(例如矢量化),并且显著缩短了平均执行时间,但仍然不足以满足我的需要。
我的要求如下:需要缩短 transit_time() 的执行时间;只能使用 NumPy 和 SciPy,无法使用任何其他外部库。如何进一步缩短这段代码的执行时间?
您可以使用多线程:NumPy 在执行计算时会释放 GIL,因此确实可以从中获得一定的加速。不过,NumPy 程序通常会先受限于内存带宽,而不是 CPU 核心的计算能力,因此不要期望太大的加速(至少不要指望达到完全的并行)。
from multiprocessing import cpu_count
from multiprocessing.pool import ThreadPool

# Query the core count once and keep the imports outside the timed region.
num_workers = cpu_count()

start_time = time.perf_counter()
with ThreadPool(num_workers) as pool:
    # np.array_split accepts row counts that are not an exact multiple of
    # the worker count, whereas np.split raises ValueError in that case.
    chunk_times = pool.starmap(
        ray_tracer.transit_time,
        zip(
            np.array_split(init_points, num_workers),
            np.array_split(term_points, num_workers),
        ),
    )
# starmap preserves chunk order, so concatenating restores the same
# one-value-per-ray array that a single transit_time() call returns.
transit_times = np.concatenate(chunk_times)
end_time = time.perf_counter()
在这台 4 核机器上,执行时间从 34.382188 秒缩短到 9.272425 秒,但实际收益取决于内核数量和内存带宽。