在 Numpy 和 PyTorch 之间切换函数/类实现:?

问题描述 投票:0回答:2

我有一个函数(实际上是一个类,但为了简单起见,我们假设它是一个函数),它使用 PyTorch 中存在的多个 NumPy 操作,例如

np.add
我也想要该函数的 PyTorch 版本。我试图避免重复我的代码,所以我想知道:

有没有办法让我在 NumPy 和 PyTorch 之间动态地来回切换函数的执行,而不需要重复实现?

举一个玩具示例,假设我的函数是:

def foo_numpy(x: np.ndarray, y: np.ndarray) -> np.ndarray:
  return np.add(x, y)

我可以定义一个 PyTorch 等效项:

def foo_torch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
  return torch.add(x, y)

我可以以某种方式定义一个函数吗:

def foo(x, y, mode: str = 'numpy'):
  if mode == 'numpy':
    return np.add(x, y)
  elif mode == 'torch':
    return torch.add(x, y)
  else:
    raise ValueError

不需要 if-else 语句?

编辑:像下面这样的东西怎么样?

def foo(x, y, mode: str = 'numpy'):
  if mode == 'numpy':
    lib = np
  elif mode == 'torch':
    lib = torch
  else:
    raise ValueError
  return lib.add(x, y)
python numpy pytorch
2个回答
0
投票

您可以使用布尔 (bool) 值来代替使用字符串来表示要使用的模式,即 False (0) 表示 NumPy,True (1) 表示 PyTorch。然后可以使用三元运算符来进一步缩小 if 语句。

def foo(x, y, mode: bool = 0):
    lib = torch if mode else np
    return lib.add(x, y) 

如果你想在课堂上的两者之间来回切换,你可以做类似的事情

class Example:

    def __init__(self):
        self._mode = True
    
    def switchMode(self):
        self._mode = !self._mode

    def foo(self, x, y):
        lib = torch if self._mode else np
        return lib.add(x, y) 

0
投票
  1. 您正在寻找的正是array-api-compat
import array_api_compat

def foo(x, y):
    xp = array_api_compat.array_namespace(x, y)
    return xp.add(x, y)
import numpy as np
import torch

print(foo(torch.tensor(1), torch.tensor(2)))
print(foo(np.array(1), np.array(2)))
$ python test.py
tensor(3)
3
  1. 您还可以使用ivy(一个主要用于将各种数组包装在自己的
    ivy.Array
    中的库),但您必须非常小心,因为它不是FOSS
import ivy
from ivy import Array, NativeArray

ivy.set_array_mode(False) # Without this, the function will return ivy.array

def foo(x: Array | NativeArray, y: Array | NativeArray) -> Array:
    return ivy.add(x, y)
import numpy as np
import torch

print(foo(torch.tensor(1), torch.tensor(2)))
print(foo(np.array(1), np.array(2)))
$ python test2.py
tensor(3)
3
© www.soinside.com 2019 - 2024. All rights reserved.