TypeError:使用 JAX custom_vjp 迭代 scqubits 中的 SymPy 表达式时,“添加”对象不可迭代

问题描述 投票:0回答:1

我正在尝试修改 scqubits Python 包以使用 JAX 的 custom_vjp 进行可微分编程。这样做时,我遇到了以下错误:

类型错误:“添加”对象不可迭代

这是重现该问题的最小示例:

import sympy as sm
x, y, z = sm.symbols('x y z')

expr = x**2 + 2*y + z

try:
    for term in expr:
        print(term)
except TypeError as e:
    print(f"Error: {e}")

此代码输出: 错误:“添加”对象不可迭代

在 scqubits 代码库中,特别是在 scqubits/core/Circuit_routines.py 中,有一个迭代 SymPy 表达式的函数:

def _constants_in_subsys(self, H_sys: sm.Expr, constants_expr: sm.Expr) -> sm.Expr:
    """
    Returns an expression of constants that belong to the subsystem with the
    Hamiltonian H_sys

    Parameters
    ----------
    H_sys:
        Subsystem Hamiltonian

    Returns
    -------
        Expression of constants belonging to the subsystem
    """
    constant_expr = 0
    subsys_free_symbols = set(H_sys.free_symbols)
    constant_terms = constants_expr.copy()
    for term in constant_terms:
        if set(term.free_symbols) & subsys_free_symbols == set(term.free_symbols):
            constant_expr += term
    return constant_expr

当我使用 JAX 的 custom_vjp 运行此函数时,它会引发相同的 TypeError。

我的问题是:

直接迭代 SymPy 表达式(如 expr 中的 for term)是否不正确? 为什么这会引发 TypeError 并指出“Add”对象不可迭代? 如何修改代码以正确迭代 SymPy 表达式的术语,尤其是在使用 JAX 的上下文中? 附加信息:

SymPy 版本:(请在此处插入您的 SymPy 版本,例如 1.9) JAX 版本:(请在此处插入您的 JAX 版本,例如 0.2.25) scqubits 版本:(如果适用) 我尝试用 expr.args 替换迭代,如下所示:

for term in expr.args:
    print(term)

这可以正常工作,但我不确定这是否是最好或最可靠的方法,特别是在与 JAX 集成时。

任何见解或建议将不胜感激!

python sympy jax
1个回答
0
投票

似乎您在

scqubits
中发现了一个错误。我建议在该存储库中打开一个问题。显然,该功能未经测试。

在 sympy 中,您无法迭代表达式。但是,您可以迭代表达式的参数,就像上面发现的那样。

请注意,任何 sympy 表达式都会公开

args
属性(如示例中的加法、乘法、幂、函数或导数等)

import sympy as sm
x, y, z = sm.symbols('x y z')

expr = x**2 + 2*y + z

try:
    for term in expr.args:
        print(term)
except TypeError as e:
    print(f"Error: {e}")
# z
# x**2
# 2*y
© www.soinside.com 2019 - 2024. All rights reserved.