为了通过应用@JIT来优化优化算法中使用的现有函数,我遇到了一些问题。运行以下功能时: 导入JAX 导入jax.numpy作为jnp ...

问题描述 投票:0回答:1

import jax import jax.numpy as jnp from jax import grad, jacobian from scipy.optimize import minimize from scipy.interpolate import BSpline jax.config.update("jax_enable_x64", True) @jax.jit def bspline(t_values, knots, coefficients, degree): """ Generate a B-spline curve from given knots and coefficients. Parameters: - t_values: Array of parameter values where the spline is evaluated. - knots: Knot vector as a 1D numpy array. - coefficients: Control points as a 1D numpy array of shape (n,). - degree: Degree of the B-spline (e.g., 3 for cubic B-splines). Returns: - A numpy array of shape (num_points,) representing the B-spline curve. """ def basis_function(i, k, t, knots): """Compute the basis function recursively.""" if k == 0: return jnp.where((knots[i] <= t) & (t < knots[i + 1]), 1.0, 0.0) else: denom1 = knots[i + k] - knots[i] denom2 = knots[i + k + 1] - knots[i + 1] term1 = (t - knots[i]) / denom1 * basis_function(i, k - 1, t, knots) if denom1 != 0 else 0 term2 = (knots[i + k + 1] - t) / denom2 * basis_function(i + 1, k - 1, t, knots) if denom2 != 0 else 0 return term1 + term2 # Compute the B-spline curve points curve_points = jnp.zeros(len(t_values)) for i in range(len(coefficients)): v = basis_function(i, degree, t_values, knots) curve_points = curve_points + v * coefficients[i] return curve_points

我有以下错误:

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
我尝试过的是:

咨询JAX关于此错误的官方文件(在此处可用
),我修改了
basis_function

def basis_function(i, k, t, knots): """Compute the basis function recursively.""" return jnp.where( k == 0, jnp.where((knots[i] <= t) & (t < knots[i + 1]), 1.0, 0.0), jnp.where( (knots[i + k] - knots[i]) != 0, (t - knots[i]) / (knots[i + k] - knots[i]) * basis_function(i, k - 1, t, knots), 0 ) + jnp.where( (knots[i + k + 1] - knots[i + 1]) != 0, (knots[i + k + 1] - t) / (knots[i + k + 1] - knots[i + 1]) * basis_function(i + 1, k - 1, t, knots), 0 ) )

避免直接布尔检查: RecursionError: maximum recursion depth exceeded in comparison

,但是,现在我遇到了一个

recursionError

@jit 这个递归问题似乎是由于以前没有出现的应用。

不幸的是,您不能在递归是基于追踪条件的JAX中使用递归方法。您要么必须使用静态条件的Python控制流来编写递归,要么必须使用非恢复方法重写。
在您的情况下,只要

degree

在呼叫站点上是静态的,那么第一个选择似乎是可行的。在这种情况下,您可以通过以这种方式重新定义第一个功能来解决问题:

from functools import partial @partial(jax.jit, static_argnames=['degree']) def bspline(t_values, knots, coefficients, degree): ...
python recursion jit jax
1个回答
0
投票
认为,尽管JAX跟踪将展开所有此类递归,因此这可能最终会产生一个长期的程序,从而导致漫长的编译时间。

最新问题
© www.soinside.com 2019 - 2025. All rights reserved.