我正在尝试使用 tch-rs (PyTorch) 在 Rust 中实现前馈残差神经网络。
到目前为止,这是我的代码:
fn res_block(vs: &nn::Path) -> impl ModuleT {
let mut default = ConvConfigND::default();
default.padding = 1;
let conv1 = nn::conv1d(vs, NUM_HIDDEN, NUM_HIDDEN, 3, default);
let bn1 = batch_norm1d(vs, NUM_HIDDEN, BatchNormConfig::default());
let conv2 = nn::conv1d(vs, NUM_HIDDEN, NUM_HIDDEN, 3, default);
let bn2 = batch_norm1d(vs, NUM_HIDDEN, BatchNormConfig::default());
nn::func_t(|x,train| {
let mut residual = Tensor::new();
x.clone(&residual);
let x = bn1.forward_t(&conv1.forward(x),train).relu();
let x = bn2.forward_t(&conv2.forward(&x),train);
//let x = x + residual;
return x.relu();
})
当我编译此代码时,出现此错误:
error[E0277]: `*mut torch_sys::C_tensor` cannot be shared between threads safely
--> src/nn.rs:19:16
|
19 | // Your code here
| ^^^^^^^ `*mut torch_sys::C_tensor` cannot be shared between threads safely
|
= help: within `BatchNorm`, the trait `Sync` is not implemented for `*mut torch_sys::C_tensor`
= note: required for `&BatchNorm` to implement `Send`
当我将forward_t行放入func_t中时,会发生此问题。 我该如何进行这项工作? 我也尝试使用顺序网络,但它们无法进一步传递残差变量。有办法让它发挥作用吗?或者我需要做别的事情吗? 谢谢!
根据您分享的内容,我猜测
nn::func_t
是异步或并行。
跨
.await
中断保存的类型或在线程之间共享的类型必须具有 Sync + Send
特征以保证内存安全。
无法在线程之间安全地传递任何
*mut
类型。
根据您想要做什么以及
*mut C_tensor
的起源,有一个解决方案。
如果您同意 new 创建新数据,则可以
.clone()
它,但如果您想就地修改相同的数据,则需要以线程安全的方式包装它,例如使用 Arc<Mutex<T>>
。