我正在
Equinox
中实现一系列类,以便能够对类参数进行导数。大多数时候,用户将实例化类A
并使用fn
函数来生成一些数据,其细节并不重要。然而,在我们对梯度感兴趣的情况下,用 sigmoid 函数表示 param_c
是有益的,以确保它保持在 (0,1) 范围内。但是,我不希望用户注意到这样做时类的行为方式有所不同。因此,我实现了另一个类 A_sigmoid
,其中 param_c
作为 property
,并使用 A_abstract
来确保两个类都继承 fn
方法,该方法将在其逻辑中调用 param_c
。虽然我可以简单地让用户用 A_sigmoid
而不是 _param_c_sigmoid
实例化 param_c
对象,但我不想强迫用户必须做出这种区分。相反,我希望它们传递相同的 kwargs
字典,无论类如何,并在幕后进行转换。我还想这样做,以便在创建新的 A
时可以简单地传递一个可选标志来指示程序使用代码的 sigmoid 版本。为此,我实现了以下 MWE:
class A_abstract(eqx.Module):
param_a: jax.Array
param_b: jax.Array
param_c: eqx.AbstractVar[jax.Array]
def fn(self,*args,**kwargs):
pass
class A_sigmoid(A_abstract):
_param_c_sigmoid: jax.Array
@property
def param_c(self):
return 1 / (1 + jnp.exp(-self._param_c_sigmoid))
class A(A_abstract):
param_c: jax.Array
def __new__(cls, **kwargs):
sigmoid_flag = kwargs.pop('use_sigmoid_c',False)
if sigmoid_flag == True:
param_c = kwargs.pop('param_c')
_param_c_sigmoid = jnp.log(param_c / (1 - param_c))
kwargs['_param_c_sigmoid'] = _param_c_sigmoid
instance = A_sigmoid.__new__(A_sigmoid)
instance.__init__(**kwargs)
print(type(instance))
return instance
else:
return super(A,cls).__new__(cls)
classA = A(param_a = 1.,param_b = 2.,param_c = 0.5,use_sigmoid_c=True)
print(type(classA))
当在
instance
方法中调用 A_sigmoid
时,代码正确地表明 print
具有类型 __new__
。但是,当我打印 type(classA)
时,它的类型为 A
并且没有属性 param_c
,尽管它确实有 _param_c_sigmoid
的值。为什么会这样呢?我在使用 __new__
时是否遗漏了导致此错误的某些内容?虽然我知道原则上工厂是执行此操作的最佳方式,但还有其他类型的 B
、C
等类型不需要 sigmoid 实现,我想与 A
的行为方式完全相同,使它们能够轻松交换。因此,我不希望使用某些自定义方法来实例化 A
,这与调用其他类上的默认构造函数不同。
我在具有以下软件包版本的 Jupyter 笔记本上运行此程序:
Python : 3.12.4
IPython : 8.30.0
ipykernel : 6.29.5
jupyter_client : 8.6.3
jupyter_core : 5.7.2
如果您使用的是普通课程,那么您所做的事情是完全合理的:
class A_abstract:
pass
class A_sigmoid(A_abstract):
pass
class A(A_abstract):
def __new__(cls, flag, **kwds):
if flag:
instance = A_sigmoid.__new__(A_sigmoid)
else:
instance = super().__new__(cls)
instance.__init__(**kwds)
return instance
print(type(A(True))) # <class '__main__.A_sigmoid'>
但是,
eqx.Module
包含一堆元类逻辑,它覆盖了__new__
的工作方式,这似乎与您正在制作的__new__
覆盖相冲突。请注意,这里唯一的区别是 A_abstract
继承自 eqx.Module
,结果是 A
而不是 A_sigmoid
:
import equinox as eqx
class A_abstract(eqx.Module):
pass
class A_sigmoid(A_abstract):
pass
class A(A_abstract):
def __new__(cls, flag, **kwds):
if flag:
instance = A_sigmoid.__new__(A_sigmoid)
else:
instance = super().__new__(cls)
instance.__init__(**kwds)
return instance
print(type(A(True))) # <class '__main__.A'>
我研究了几分钟,无法找到此变化的确切原因,但无法确定它。
如果您尝试通过类初始化和用
__new__
中的其他类替换类来执行一些复杂的操作,则必须对其进行修改以与 Equinox 已经执行的复杂替换一起使用。