我有一个火炬模型,它只包含一个 conv1d(旨在用于实现 STFT 傅里叶变换)。 该模型在 Torch 中运行良好,并在 Python 中使用 torch.jit.load 进行跟踪。
当我尝试通过打字稿 (https://www.npmjs.com/package/react-native-pytorch-core) 在 iOS 上使用带有 libtorch 的模型时,我没有得到预期的输出。
第一个输出通道是正确的(即在这种情况下,它等于前 2048 个样本与卷积的点积),但其余输出通道应与沿信号(及时)滑动的内核相对应,与第一个!)
在 Python / torch 中...
import torch
class Model(torch.nn.Module):
def __init__(self):
n_fft = 2048
hop_length = 512
self.conv = torch.nn.Conv1d(in_channels=1, out_channels=n_fft // 2 + 1,
kernel_size=n_fft, stride=hop_length, padding=0, dilation=1,
groups=1, bias=False)
def forward(self, x):
return self.conv(x)
model = Model();
torch.jit.script(model)._save_for_lite_interpreter('model.ptl')
In inference, react native typescript
import { torch } from 'react-native-pytorch-core';
let x = torch.linspace(0,1,96000).unsqueeze(0);
model.forward(x).then((e) => {
console.log(e.shape) // this the correct shape [1, 184, 1025] [batch, time, bin]
let data = e.data();
data[0,0,0] == data[0,0,1] // this is false, as expected
data[0,0,0] == data[0,1,0] // this is true, for any time step. NOT expected, if the convolution kernel sliding is correct
});