Nx 张量的映射切片

问题描述 投票:0回答:0

如果我有一个需要特定形状的 Nx 张量的函数,并且我有一个更大的张量,其中包含该形状的切片,是否有一种有效的方法可以将某些函数映射到这些切片上?

我想到的具体功能是

Nx.Random.uniform_split/4
,它需要一个形状正好是
{2}
的键。在纯功能代码中,您可以想象创建一个键列表,然后调用
Enum.map/2
,比如

def random_values(seed) do
  key0 = Nx.Random.key(seed)
  keys = 0..5 |> Enum.map(&Nx.Random.fold_in(key0, &1))
  Enum.map(keys, &Nx.Random.uniform_split(&1, 0, 1))
end

但请注意,

keys
是恰好 6 个形状为
{2}
的张量的列表;创建一个形状为
{6, 2}
的张量会更有效。事实上
Nx.Random.fold_in/2
支持创建这个矩阵

defn random_values(key0)
  keys = Nx.Random.fold_in(key0, Nx.iota({6)))
  # Nx.shape(keys) ==> `{6, 2}`

现在我希望能够在这个矩阵中的每个单独的向量上调用

Nx.Random.uniform_split/4
,得到某种标量集合,然后
Nx.concatenate/2
一起成为一个结果向量。有没有有效的方法来做到这一点?

Nx.map/3
仅适用于单个元素,不适用于切片,因此这里不是一个选择。
Nx.Defn.Kernel.while/4
循环是可能的,但笨拙,并且似乎涉及预分配结果张量并将单个值放入其中,因此这不是最快的事情。我可以写一个普通的递归函数,但那更慢。

我没有 GPU。 EXLA 目前无法在我的系统上运行,所以我使用的是默认后端,我知道这也不是最快的。我确实有志于有一天在 EXLA 下运行它,我认为这意味着最大限度地减少数据进出

defn
函数的次数。我不认为涉及的矩阵很大,但它们也不是微不足道的。作为比较,一个测试用例创建一个 131x131 矩阵并在 7-8 秒内运行这样的代码;如果我不想控制按键,我可以调用
Nx.Random.uniform_split(..., shape: {131, 131})
并立即有效地返回响应。

elixir
© www.soinside.com 2019 - 2024. All rights reserved.