libtorch Conv1D 不能在信号长度维度上运行

问题描述 投票:0回答:0

我有一个火炬模型,它只包含一个 conv1d(旨在用于实现 STFT 傅里叶变换)。 该模型在 Torch 中运行良好,并在 Python 中使用 torch.jit.load 进行跟踪。

当我尝试通过打字稿 (https://www.npmjs.com/package/react-native-pytorch-core) 在 iOS 上使用带有 libtorch 的模型时,我没有得到预期的输出。

第一个输出通道是正确的(即在这种情况下,它等于前 2048 个样本与卷积的点积),但其余输出通道应与沿信号(及时)滑动的内核相对应,与第一个!)

在 Python / torch 中...

import torch


class Model(torch.nn.Module):
    def __init__(self):
        n_fft = 2048
        hop_length = 512
        self.conv = torch.nn.Conv1d(in_channels=1, out_channels=n_fft // 2 + 1,
            kernel_size=n_fft, stride=hop_length, padding=0, dilation=1,
            groups=1, bias=False)
    
    def forward(self, x):
        return self.conv(x)

model = Model();
torch.jit.script(model)._save_for_lite_interpreter('model.ptl')

In inference, react native typescript


import { torch } from 'react-native-pytorch-core';

let x = torch.linspace(0,1,96000).unsqueeze(0);

model.forward(x).then((e) => {
    console.log(e.shape) // this the correct shape [1, 184, 1025] [batch, time, bin]

    let data = e.data();

    data[0,0,0] == data[0,0,1] // this is false, as expected

    data[0,0,0] == data[0,1,0] // this is true, for any time step. NOT expected, if the convolution kernel sliding is correct

});
python typescript react-native torch libtorch
© www.soinside.com 2019 - 2024. All rights reserved.