根据参数中的可选标志返回不同的类,无需工厂

问题描述 投票:0回答:1

我正在

Equinox
中实现一系列类,以便能够对类参数进行导数。大多数时候,用户将实例化类
A
并使用
fn
函数来生成一些数据,其细节并不重要。然而,在我们对梯度感兴趣的情况下,用 sigmoid 函数表示
param_c
是有益的,以确保它保持在 (0,1) 范围内。但是,我不希望用户注意到这样做时类的行为方式有所不同。因此,我实现了另一个类
A_sigmoid
,其中
param_c
作为
property
,并使用
A_abstract
来确保两个类都继承
fn
方法,该方法将在其逻辑中调用
param_c
。虽然我可以简单地让用户用
A_sigmoid
而不是
_param_c_sigmoid
实例化
param_c
对象,但我不想强迫用户必须做出这种区分。相反,我希望它们传递相同的
kwargs
字典,无论类如何,并在幕后进行转换。我还想这样做,以便在创建新的
A
时可以简单地传递一个可选标志来指示程序使用代码的 sigmoid 版本。为此,我实现了以下 MWE:

class A_abstract(eqx.Module):
    param_a: jax.Array
    param_b: jax.Array
    param_c: eqx.AbstractVar[jax.Array]
    
    def fn(self,*args,**kwargs):
        pass

class A_sigmoid(A_abstract):
    _param_c_sigmoid: jax.Array

    @property
    def param_c(self):
        return 1 / (1 + jnp.exp(-self._param_c_sigmoid))

class A(A_abstract):
    param_c: jax.Array

    def __new__(cls, **kwargs):
        sigmoid_flag = kwargs.pop('use_sigmoid_c',False)
        if sigmoid_flag == True:
            param_c = kwargs.pop('param_c')
            _param_c_sigmoid = jnp.log(param_c / (1 - param_c))
            kwargs['_param_c_sigmoid'] = _param_c_sigmoid
            instance = A_sigmoid.__new__(A_sigmoid)
            instance.__init__(**kwargs)
            print(type(instance))
            return instance
        else:
            return super(A,cls).__new__(cls)

classA = A(param_a = 1.,param_b = 2.,param_c = 0.5,use_sigmoid_c=True)
print(type(classA))

当在

instance
方法中调用
A_sigmoid
时,代码正确地表明
print
具有类型
__new__
。但是,当我打印
type(classA)
时,它的类型为
A
并且没有属性
param_c
,尽管它确实有
_param_c_sigmoid
的值。为什么会这样呢?我在使用
__new__
时是否遗漏了导致此错误的某些内容?虽然我知道原则上工厂是执行此操作的最佳方式,但还有其他类型的
B
C
等类型不需要 sigmoid 实现,我想与
A
的行为方式完全相同,使它们能够轻松交换。因此,我不希望使用某些自定义方法来实例化
A
,这与调用其他类上的默认构造函数不同。

我在具有以下软件包版本的 Jupyter 笔记本上运行此程序:

Python           : 3.12.4
IPython          : 8.30.0
ipykernel        : 6.29.5
jupyter_client   : 8.6.3
jupyter_core     : 5.7.2
python instantiation jax python-class
1个回答
0
投票

如果您使用的是普通课程,那么您所做的事情是完全合理的:

class A_abstract:
  pass

class A_sigmoid(A_abstract):
  pass

class A(A_abstract):
  def __new__(cls, flag, **kwds):
    if flag:
      instance = A_sigmoid.__new__(A_sigmoid)
    else:
      instance = super().__new__(cls)
    instance.__init__(**kwds)
    return instance

print(type(A(True))) # <class '__main__.A_sigmoid'>

但是,

eqx.Module
包含一堆元类逻辑,它覆盖了
__new__
的工作方式,这似乎与您正在制作的
__new__
覆盖相冲突。请注意,这里唯一的区别是
A_abstract
继承自
eqx.Module
,结果是
A
而不是
A_sigmoid
:

import equinox as eqx

class A_abstract(eqx.Module):
  pass

class A_sigmoid(A_abstract):
  pass

class A(A_abstract):
  def __new__(cls, flag, **kwds):
    if flag:
      instance = A_sigmoid.__new__(A_sigmoid)
    else:
      instance = super().__new__(cls)
    instance.__init__(**kwds)
    return instance

print(type(A(True))) # <class '__main__.A'>

我研究了几分钟,无法找到此变化的确切原因,但无法确定它。

如果您尝试通过类初始化和用

__new__
中的其他类替换类来执行一些复杂的操作,则必须对其进行修改以与 Equinox 已经执行的复杂替换一起使用。

© www.soinside.com 2019 - 2024. All rights reserved.