import jax
import jax.numpy as jnp
from jax import grad, jacobian
from scipy.optimize import minimize
from scipy.interpolate import BSpline
jax.config.update("jax_enable_x64", True)
@jax.jit
def bspline(t_values, knots, coefficients, degree):
"""
Generate a B-spline curve from given knots and coefficients.
Parameters:
- t_values: Array of parameter values where the spline is evaluated.
- knots: Knot vector as a 1D numpy array.
- coefficients: Control points as a 1D numpy array of shape (n,).
- degree: Degree of the B-spline (e.g., 3 for cubic B-splines).
Returns:
- A numpy array of shape (num_points,) representing the B-spline curve.
"""
def basis_function(i, k, t, knots):
"""Compute the basis function recursively."""
if k == 0:
return jnp.where((knots[i] <= t) & (t < knots[i + 1]), 1.0, 0.0)
else:
denom1 = knots[i + k] - knots[i]
denom2 = knots[i + k + 1] - knots[i + 1]
term1 = (t - knots[i]) / denom1 * basis_function(i, k - 1, t, knots) if denom1 != 0 else 0
term2 = (knots[i + k + 1] - t) / denom2 * basis_function(i + 1, k - 1, t, knots) if denom2 != 0 else 0
return term1 + term2
# Compute the B-spline curve points
curve_points = jnp.zeros(len(t_values))
for i in range(len(coefficients)):
v = basis_function(i, degree, t_values, knots)
curve_points = curve_points + v * coefficients[i]
return curve_points
我有以下错误:
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
我尝试过的是:咨询JAX关于此错误的官方文件(在此处可用
),我修改了
basis_function
def basis_function(i, k, t, knots):
"""Compute the basis function recursively."""
return jnp.where(
k == 0,
jnp.where((knots[i] <= t) & (t < knots[i + 1]), 1.0, 0.0),
jnp.where(
(knots[i + k] - knots[i]) != 0,
(t - knots[i]) / (knots[i + k] - knots[i]) * basis_function(i, k - 1, t, knots),
0
) +
jnp.where(
(knots[i + k + 1] - knots[i + 1]) != 0,
(knots[i + k + 1] - t) / (knots[i + k + 1] - knots[i + 1]) * basis_function(i + 1, k - 1, t, knots),
0
)
)
避免直接布尔检查:
RecursionError: maximum recursion depth exceeded in comparison
:
@jit
这个递归问题似乎是由于以前没有出现的应用。
不幸的是,您不能在递归是基于追踪条件的JAX中使用递归方法。您要么必须使用静态条件的Python控制流来编写递归,要么必须使用非恢复方法重写。在您的情况下,只要
degree
在呼叫站点上是静态的,那么第一个选择似乎是可行的。在这种情况下,您可以通过以这种方式重新定义第一个功能来解决问题:
from functools import partial
@partial(jax.jit, static_argnames=['degree'])
def bspline(t_values, knots, coefficients, degree):
...