我正在使用 Sympy 进行一些符号计算,但我需要定义一些自定义对象。到目前为止我所做的方法是通过定义一个新类来利用 Sympy 的符号
class CustomSymbol(sym.Symbol):
def __init__(self, *args, **kwargs):
super().__init__(args, kwargs)
self.custom_attribute = True
这样我就可以重用很多代码。尽管如此,我还是必须重写一些函数并定义自定义的 Add、Mul、Pow 等...... 现在,问题如下:
from sympy import *
class CustomSymbol(Symbol):
def __pow__(self, other):
return CustomPow(self, other)
class CustomPow(Pow):
pass
a = CustomSymbol('a')
x = a**2 * a**3
type(x)
此代码返回
x
的类型是 Sympy 的 Pow,而不是 CustomPow。本质上是因为,在内部代码中,Sympy 本质上是计算 Pow(a, 2+3)
。
事实上,Sympy 使用这种方法,而不是更灵活的方法,例如计算 a**(2+3)
并让 a
决定如何管理幂运算,这让我认为也许有一种预期的方法来定义自定义符号。
所以,我的问题是:应该如何创建 CustomSymbol 类?对于我写的问题,是否有一种解决方法,不需要重写整个 Sympy 库?
要使解析器和简化器与您的自定义符号一起使用,您必须创建一个自定义表达式类,并让其他自定义类继承它。您必须重写核心运算符中的许多方法:Add、Mul、Pow。最值得注意的是
Add.flatten
和 Mul.flatten
类方法负责排序和收集术语。然后你就有了 _eval_subs
、_eval_refine
、_eval_expand_*
分别用于 subs、refine 和 expand 提示 的功能。
下面是
CustomSymbol
的代码,其行为(几乎)与 Symbol
完全相同。想象一下 ...
被替换为方法的源代码。我导入了源代码副本使用的方法和类。最后,我用自定义的覆盖了 Add
、Mul
、Pow
,以便在方法中使用它们进行评估。
from typing import Callable
from collections import defaultdict
from sympy import *
from sympy.core.parameters import global_parameters
from sympy.core.add import _addsort
from sympy.core.mul import NC_Marker, _mulsort, _unevaluated_Mul, _keep_coeff
from sympy.core.logic import fuzzy_bool, fuzzy_not
from sympy.utilities.misc import as_int
class CustomExpr(Expr):
def __add__(self, other):
return CustomAdd(self, other)
def __radd__(self, other):
return CustomAdd(other, self)
def __mul__(self, other):
return CustomMul(self, other)
def __rmul__(self, other):
return CustomMul(other, self)
def __pow__(self, other):
return CustomPow(self, other)
class CustomSymbol(CustomExpr, Symbol):
pass
class CustomAdd(CustomExpr, Add):
@classmethod
def flatten(cls, seq):
...
def __neg__(self):
...
class CustomMul(CustomExpr, Mul):
@classmethod
def flatten(cls, seq):
...
def _eval_power(self, e):
...
@staticmethod
def _expandsums(sums):
...
def _eval_expand_mul(self, **hints):
...
def matches(self, expr, repl_dict=None, old=False):
...
@staticmethod
def _matches_match_wilds(dictionary, wildcard_ind, nodes, targets):
...
@staticmethod
def _combine_inverse(lhs, rhs):
...
def _eval_subs(self, old, new):
...
class CustomPow(CustomExpr, Pow):
def _eval_refine(self, assumptions):
...
def _eval_subs(self, old, new):
...
def _eval_expand_power_exp(self, **hints):
...
def _eval_expand_power_base(self, **hints):
...
def _eval_expand_multinomial(self, **hints):
...
Add = CustomAdd
Mul = CustomMul
Pow = CustomPow
当某些方法使用
self.func
而不是硬编码 Add
、Pow
或 Mul
时,不需要重写。正如你所看到的 CustomAdd
只覆盖了 2 个方法。您可能想要覆盖其他方法或删除某些方法。例如,要让 subs
使用基本替换,您可以这样做。
class CustomPow(CustomExpr, Pow):
# This removes the method that it inherited from Pow
_eval_subs = property()
不幸的是,没有一种简单的方法来子类化核心运算符,因为它们是如此交织在一起。