Rust 中的残差神经网络与 tch-rs

问题描述 投票:0回答:1

我正在尝试使用 tch-rs (PyTorch) 在 Rust 中实现前馈残差神经网络。

到目前为止,这是我的代码:

fn res_block(vs: &nn::Path) -> impl ModuleT {
    let mut default = ConvConfigND::default();
    default.padding = 1;
    let conv1 = nn::conv1d(vs, NUM_HIDDEN, NUM_HIDDEN, 3, default);
    let bn1 = batch_norm1d(vs, NUM_HIDDEN, BatchNormConfig::default());
    let conv2 = nn::conv1d(vs, NUM_HIDDEN, NUM_HIDDEN, 3, default);
    let bn2 = batch_norm1d(vs, NUM_HIDDEN, BatchNormConfig::default());
    nn::func_t(|x,train| {
        let mut residual = Tensor::new();
        x.clone(&residual);
        let x = bn1.forward_t(&conv1.forward(x),train).relu();
        let x = bn2.forward_t(&conv2.forward(&x),train);
        //let x = x + residual;
        return x.relu();
    })

当我编译此代码时,出现此错误:

error[E0277]: `*mut torch_sys::C_tensor` cannot be shared between threads safely
  --> src/nn.rs:19:16
   |
19 |     // Your code here
   |                ^^^^^^^ `*mut torch_sys::C_tensor` cannot be shared between threads safely
   |
   = help: within `BatchNorm`, the trait `Sync` is not implemented for `*mut torch_sys::C_tensor`
   = note: required for `&BatchNorm` to implement `Send`

当我将forward_t行放入func_t中时,会发生此问题。 我该如何进行这项工作? 我也尝试使用顺序网络,但它们无法进一步传递残差变量。有办法让它发挥作用吗?或者我需要做别的事情吗? 谢谢!

machine-learning rust pytorch
1个回答
0
投票

根据您分享的内容,我猜测

nn::func_t
是异步或并行。

.await
中断保存的类型或在线程之间共享的类型必须具有
Sync + Send
特征以保证内存安全。

无法在线程之间安全地传递任何

*mut
类型。

根据您想要做什么以及

*mut C_tensor
的起源,有一个解决方案。

如果您同意 new 创建新数据,则可以

.clone()
它,但如果您想就地修改相同的数据,则需要以线程安全的方式包装它,例如使用
Arc<Mutex<T>>

© www.soinside.com 2019 - 2024. All rights reserved.