环境:
macOS 索诺玛版本 14.0 (M1 mac) Rust 版本 1.65.0
我想做的事: 我想在多线程之间共享带有 [u8;128] 元素数组的 vec。 分享时我要执行的要求如下
下面是我写的代码,不过这段代码读起来可以,但是有一个问题就是写的没有体现出来。 如果我运行此代码,然后在执行它的计算机上运行以下命令一次
nc -v 本地主机 50051
[[0u8; 128],[1u8; 128],[2u8; 128]]将输出
。 到目前为止这是正确的,但第二次运行的数据输出与第一次运行相同。 我的意图是第二个元素将输出具有 3 个填充的数据,如下所示,因为我正在第一次运行中更新数据。
[[0u8; 128],[3u8; 128],[2u8; 128]]
我猜测我对 Arc 的使用是错误的,它实际上是传递 SharedData 的克隆,而不是对 SharedData 的引用,但我不知道如何识别这一点。 我该如何修复代码以使其按我的预期工作?
main.rs:
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::RwLock;
use std::time::Duration;
use tokio_task_pool::Pool;
struct SharedData {
data: Arc<RwLock<Vec<[u8; 128]>>>
}
impl SharedData {
fn new(data: RwLock<Vec<[u8; 128]>>) -> Self {
Self {
data: Arc::new(data)
}
}
fn update(&self, index: usize, update_data: [u8; 128]) {
let read_guard_for_array = self.data.read().unwrap();
let write_lock = RwLock::new((*read_guard_for_array)[index]);
let mut write_guard_for_item = write_lock.write().unwrap();
*write_guard_for_item = update_data;
}
}
fn socket_to_async_tcplistener(s: socket2::Socket) -> std::io::Result<tokio::net::TcpListener> {
std::net::TcpListener::from(s).try_into()
}
async fn process(mut stream: tokio::net::TcpStream, db_arc: Arc<SharedData>) {
let read_guard = db_arc.data.read().unwrap();
println!("In process() read: {:?}", *read_guard);
db_arc.update(1, [3u8; 128]);
}
async fn serve(_: usize, tcplistener_arc: Arc<tokio::net::TcpListener>, db_arc: Arc<SharedData>) {
let task_pool_capacity = 10;
let task_pool = Pool::bounded(task_pool_capacity)
.with_spawn_timeout(Duration::from_secs(300))
.with_run_timeout(Duration::from_secs(300));
loop {
let (stream, _) = tcplistener_arc.as_ref().accept().await.unwrap();
let db_arc_clone = db_arc.clone();
task_pool.spawn(async move {
process(stream, db_arc_clone).await;
}).await.unwrap();
}
}
#[tokio::main]
async fn main() {
let addr: std::net::SocketAddr = "0.0.0.0:50051".parse().unwrap();
let soc2 = socket2::Socket::new(
match addr {
SocketAddr::V4(_) => socket2::Domain::IPV4,
SocketAddr::V6(_) => socket2::Domain::IPV6,
},
socket2::Type::STREAM,
Some(socket2::Protocol::TCP)
).unwrap();
soc2.set_reuse_address(true).unwrap();
soc2.set_reuse_port(true).unwrap();
soc2.set_nonblocking(true).unwrap();
soc2.bind(&addr.into()).unwrap();
soc2.listen(8192).unwrap();
let tcp_listener = Arc::new(socket_to_async_tcplistener(soc2).unwrap());
let mut vec = vec![
[0u8; 128],
[1u8; 128],
[2u8; 128],
];
let share_db = Arc::new(SharedData::new(RwLock::new(vec)));
let mut handlers = Vec::new();
for i in 0..num_cpus::get() - 1 {
let cloned_listener = Arc::clone(&tcp_listener);
let db_arc = share_db.clone();
let h = std::thread::spawn(move || {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap()
.block_on(serve(i, cloned_listener, db_arc));
});
handlers.push(h);
}
for h in handlers {
h.join().unwrap();
}
}
Cargo.toml:
[package]
name = "tokio-test"
version = "0.1.0"
edition = "2021"
[dependencies]
log = "0.4.20"
env_logger = "0.10.0"
tokio = { version = "1.34.0", features = ["full"] }
tokio-stream = { version = "0.1.14", features = ["net"] }
serde = { version = "1.0.193", features = ["derive"] }
serde_yaml = "0.9.27"
serde_derive = "1.0.193"
mio = {version="0.8.9", features=["net", "os-poll", "os-ext"]}
num_cpus = "1.16.0"
socket2 = { version="0.5.5", features = ["all"]}
array-macro = "2.1.8"
tokio-task-pool = "0.1.5"
argparse = "0.2.2"
fn update(&self, index: usize, update_data: [u8; 128]) {
let read_guard_for_array = self.data.read().unwrap();
let write_lock = RwLock::new((*read_guard_for_array)[index]);
这会创建数据的副本并将其包装在无用的
RwLock
中(无用是因为该副本始终保存在单个线程中。
let mut write_guard_for_item = write_lock.write().unwrap();
*write_guard_for_item = update_data;
}
这会修改副本,然后在函数结束时立即将其丢弃。
相反,您需要锁定已有的
RwLock
:
fn update(&self, index: usize, update_data: [u8; 128]) {
let mut write_guard = self.data.write().unwrap();
write_guard[index] = update_data;
}
请注意,无法仅获得特定项目的写锁和整个数组的读锁:读锁和写锁必须与相同的数据相关。这意味着您还需要释放读锁才能更新:
async fn process(mut stream: tokio::net::TcpStream, db_arc: Arc<SharedData>) {
let read_guard = db_arc.data.read().unwrap();
println!("In process() read: {:?}", *read_guard);
drop (read_guard);
db_arc.update(1, [3u8; 128]);
}