我有一个函数(实际上是一个类,但为了简单起见,我们假设它是一个函数),它使用 PyTorch 中存在的多个 NumPy 操作,例如
np.add
我也想要该函数的 PyTorch 版本。我试图避免重复我的代码,所以我想知道:
有没有办法让我在 NumPy 和 PyTorch 之间动态地来回切换函数的执行,而不需要重复实现?
举一个玩具示例,假设我的函数是:
def foo_numpy(x: np.ndarray, y: np.ndarray) -> np.ndarray:
return np.add(x, y)
我可以定义一个 PyTorch 等效项:
def foo_torch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return torch.add(x, y)
我可以以某种方式定义一个函数吗:
def foo(x, y, mode: str = 'numpy'):
if mode == 'numpy':
return np.add(x, y)
elif mode == 'torch':
return torch.add(x, y)
else:
raise ValueError
不需要 if-else 语句?
编辑:像下面这样的东西怎么样?
def foo(x, y, mode: str = 'numpy'):
if mode == 'numpy':
lib = np
elif mode == 'torch':
lib = torch
else:
raise ValueError
return lib.add(x, y)
您可以使用布尔 (bool) 值来代替使用字符串来表示要使用的模式,即 False (0) 表示 NumPy,True (1) 表示 PyTorch。然后可以使用三元运算符来进一步缩小 if 语句。
def foo(x, y, mode: bool = 0):
lib = torch if mode else np
return lib.add(x, y)
如果你想在课堂上的两者之间来回切换,你可以做类似的事情
class Example:
def __init__(self):
self._mode = True
def switchMode(self):
self._mode = !self._mode
def foo(self, x, y):
lib = torch if self._mode else np
return lib.add(x, y)
import array_api_compat
def foo(x, y):
xp = array_api_compat.array_namespace(x, y)
return xp.add(x, y)
import numpy as np
import torch
print(foo(torch.tensor(1), torch.tensor(2)))
print(foo(np.array(1), np.array(2)))
$ python test.py
tensor(3)
3
ivy.Array
中的库),但您必须非常小心,因为它不是FOSS。import ivy
from ivy import Array, NativeArray
ivy.set_array_mode(False) # Without this, the function will return ivy.array
def foo(x: Array | NativeArray, y: Array | NativeArray) -> Array:
return ivy.add(x, y)
import numpy as np
import torch
print(foo(torch.tensor(1), torch.tensor(2)))
print(foo(np.array(1), np.array(2)))
$ python test2.py
tensor(3)
3