无法在 Rust 中建立 websocket 连接

问题描述 投票:0回答:1
/// Starts the HTTP/WebSocket server.
///
/// Loads `.env`, builds the shared service container and the connected-client registry, then
/// serves on `127.0.0.1:$PORT` (default `8080`) until the server future returns an error.
///
/// No `LocalSet` is used here on purpose. hyper 0.14 spawns every accepted connection with
/// `tokio::spawn`, so request handlers never run inside a `LocalSet` owned by `main`. Wrapping
/// the server in `LocalSet::run_until` only made it look as if they did.
#[tokio::main]
async fn main() {
    println!("Starting Tokio runtime...");
    dotenv().ok();

    let services = Arc::new(services::new_services());
    let connected_clients: ConnectedClients = Arc::new(Mutex::new(HashMap::new()));

    // One service per accepted connection; each gets its own handles to the shared state.
    let make_svc = make_service_fn(|_conn| {
        let services = Arc::clone(&services);
        let clients = Arc::clone(&connected_clients);
        async move {
            Ok::<_, hyper::Error>(service_fn(move |req: Request<Body>| {
                handle_request(req, services.clone(), clients.clone())
            }))
        }
    });

    let port: u16 = env::var("PORT")
        .unwrap_or_else(|_| "8080".to_string())
        .parse()
        .expect("PORT must be a valid number");
    let addr = std::net::SocketAddr::from(([127, 0, 0, 1], port));

    // `try_bind` reports an occupied or forbidden port instead of panicking like `bind` does.
    let server = match Server::try_bind(&addr) {
        Ok(builder) => builder.serve(make_svc),
        Err(e) => {
            eprintln!("Failed to bind {}: {}", addr, e);
            std::process::exit(1);
        }
    };

    println!("Server listening on {}", addr);
    if let Err(e) = server.await {
        eprintln!("Server error: {}", e);
    }
}


async fn handle_request(
    mut req: Request<Body>, 
    _services: Arc<Services>, 
    clients: ConnectedClients
) -> Result<Response<Body>, hyper::Error> {
    if req.uri().path() == "/ws" {
        println!("/ws path entered");

        // Check if it's a WebSocket request
        if !is_websocket_request(&req) {
            return Ok(Response::builder()
                .status(StatusCode::BAD_REQUEST)
                .body(Body::from("Not a WebSocket request"))
                .unwrap());
        }
        println!("ITS A WS CONNECTION!");
        // Extract token from query parameters
        let token = extract_token_from_query(req.uri());
        let (mut user_id, 
            mut org_id, 
            mut stores, 
            mut first_name, 
            mut last_name, 
            mut is_admin, 
            mut route,
            mut device_id
        ) = (
            String::new(),
            String::new(),
            Vec::new(),
            String::new(),
            String::new(),
            false,
            String::new(),
            None,
        );

        match token {
            Some(ref t) => {
                // Token validation logic
                println!("Entered into match token");
                let secret_key = std::env::var("JWT_SECRET").unwrap_or_else(|_| "default_secret".to_string());
                let decoding_key = DecodingKey::from_secret(secret_key.as_ref());
                let validation = Validation::default();

                
                match decode::<Claims>(t, &decoding_key, &validation) {
                    Ok(token_data) => {        
                        let claims = token_data.claims;
                      
                        // Validate organization claims
                        if let Some(ref organization) = claims.organization {
                            if organization.id.is_none() || organization.stores.is_none() {
                                println!("ERROR1");
                                return Ok(Response::builder()
                                    .status(StatusCode::UNAUTHORIZED)
                                    .body(Body::from("Invalid token: missing organization data"))
                                    .unwrap());
                            }
                        } else {
                            println!("ERROR2");
                            return Ok(Response::builder()
                                .status(StatusCode::UNAUTHORIZED)
                                .body(Body::from("Invalid token: missing organization"))
                                .unwrap());
                        }

                        // Validate user claims
                        if let Some(ref user) = claims.user {
                            if user.id.is_none() || user.firstName.is_none() || user.lastName.is_none() {
                                println!("ERROR3");
                                return Ok(Response::builder()
                                    .status(StatusCode::UNAUTHORIZED)
                                    .body(Body::from("Invalid token: missing user data"))
                                    .unwrap());
                            }
                        } else {
                            println!("ERROR4");
                            return Ok(Response::builder()
                                .status(StatusCode::UNAUTHORIZED)
                                .body(Body::from("Invalid token: missing user"))
                                .unwrap());
                        }

                        is_admin = claims.organization.as_ref()
                            .and_then(|o| o.privileges.as_ref())
                            .map_or(false, |p| p.admin.unwrap_or(false));

                        // Move these variables to the broader scope
                        org_id = claims.organization.as_ref()
                            .and_then(|o| o.id.clone())
                            .unwrap_or_else(|| "default_org_id".to_string());

                        stores = claims.organization.as_ref()
                            .and_then(|o| o.stores.clone())
                            .unwrap_or_default();

                        route = claims.organization.as_ref()
                            .and_then(|o| o.route.clone())
                            .unwrap_or_default();

                        user_id = claims.user.as_ref()
                            .and_then(|u| u.id.clone())
                            .unwrap_or_else(|| "default_user_id".to_string());

                        first_name = claims.user.as_ref()
                            .and_then(|u| u.firstName.clone())
                            .unwrap_or_else(|| "default_first_name".to_string());

                        last_name = claims.user.as_ref()
                            .and_then(|u| u.lastName.clone())
                            .unwrap_or_else(|| "default_last_name".to_string());

                        device_id = claims.device_id;

                        println!("Token validated successfully");

                        // Now you can access the variables like `user_id`, `org_id`, etc.

                        let user_interface = UserAlias::new(
                            user_id, 
                            org_id, 
                            stores, 
                            first_name, 
                            last_name, 
                            is_admin, 
                            route,
                        );
                        // let user_connection = Connection::new(
                        //     user_id, 
                        //     user_id, 
                        //     // conn, 
                        //     device_id, 
                        //     None, 
                        //     None,
                        // );
                        
                    }
                    Err(e) => {
                        println!("FAILED 1 {:}", e);
                        return Ok(Response::builder()
                            .status(StatusCode::UNAUTHORIZED)
                            .body(Body::from("Invalid token"))
                            .unwrap());
                    }
                }
            }
            None => {
                println!("FAILED");
                return Ok(Response::builder()
                    .status(StatusCode::BAD_REQUEST)
                    .body(Body::from("Missing token in query parameters"))
                    .unwrap());
            }
        }

        // Get the WebSocket key from headers
        let ws_key = req.headers().get(SEC_WEBSOCKET_KEY).and_then(|v| v.to_str().ok()).unwrap();

        // Generate accept key
        let accept_key = derive_accept_key(ws_key.as_bytes());

        println!("WebSocket Key: {:?}", ws_key);
        println!("Accept Key: {}", accept_key);

        let mut response = Response::builder()
            .status(StatusCode::SWITCHING_PROTOCOLS)
            .header(CONNECTION, "Upgrade")
            .header(UPGRADE, "websocket")
            .header(SEC_WEBSOCKET_ACCEPT, accept_key);

        // Add protocol if present
        if let Some(protocol) = req.headers().get(SEC_WEBSOCKET_PROTOCOL) {
            response = response.header(SEC_WEBSOCKET_PROTOCOL, protocol);
        }

        let response = response.body(Body::empty()).unwrap();
        println!("response : {:#?}", response);

        
    // Perform WebSocket upgrade and wrapping outside `tokio::spawn`
   
    let device_id_clone = device_id.clone(); 
    // let user_id_clone = user_id.clone();
    // let user_id_clone1 = user_id.clone();
    // let local = LocalSet::new(); 

    // let local: tokio::task::JoinHandle<()> = 
    let local = LocalSet::new();    

    println!("local {:#?}", local);
    if tokio::runtime::Handle::try_current().is_err() {
        println!("No active Tokio runtime");
    }
    println!("Incoming request: {:?}", req);
   
    
    
    local.spawn_local(async  move{
        println!("Entered into local spawn");
        match on(req).await {
            Ok(upgraded) => {
                println!("Connection upgraded");
    
                // Create WebSocket stream directly
                let ws_stream: WebSocketStream<Upgraded> = WebSocketStream::from_raw_socket(
                    upgraded,
                    Role::Server,
                    Some(WebSocketConfig::default())
                ).await;

               

                let shared_ws_stream = Arc::new(RwLock::new(ws_stream));
                let shared_ws_stream_clone = Arc::clone(&shared_ws_stream);
               
                // Wrap the WebSocket stream in an Arc<RwLock> for safe multi-threaded access
                // let ws_stream_clone = ws_stream.clone();
               
                let user_connection: Connection = Connection::new(
                    None, 
                    None, 
                    shared_ws_stream,
                    device_id.clone(), 
                    None, 
                    None,
                );
    
                println!("WebSocket stream created");
    
                // Ensure you handle the WebSocket stream correctly here
            
            if let Err(e) = handle_websocket_stream(shared_ws_stream_clone, clients).await {
                eprintln!("WebSocket handler error: {}", e);
            }
            }
            Err(e) => eprintln!("Failed to upgrade connection: {}", e),
        }
    });
   
    Ok(response)
} else {
    Ok(Response::new(Body::from("404 not found!")))
}
}

我尝试在 Rust 中建立 WebSocket 连接，但连接建立后 1–2 毫秒就突然断开。以下是终端日志：

Starting Tokio runtime...
Tokio runtime is active.
LocalSet created.
Server listening on 127.0.0.1:8080
/ws path entered
WebSocket request validation:
  Upgrade header present: true
  Connection upgrade: true
  Has WebSocket key: true
  Has correct version: true
ITS A WS CONNECTION!
Entered into match token
Token validated successfully
CALLED
WebSocket Key: "rTNBeHrwfawHZ9uhlw3M1Q=="
Accept Key: pJhyS7Aq49YTSEt5997ylXf8Z7g=
response : Response {
    status: 101,
    version: HTTP/1.1,
    headers: {
        "connection": "Upgrade",
        "upgrade": "websocket",
        "sec-websocket-accept": "pJhyS7Aq49YTSEt5997ylXf8Z7g=",
    },
    body: Body(
        Empty,
    ),
}
local LocalSet
Incoming request: Request { method: GET, uri: /ws?token=JWT_TOKEN, version: HTTP/1.1, headers: {"sec-websocket-version": "13", "sec-websocket-key": "rTNBeHrwfawHZ9uhlw3M1Q==", "connection": "Upgrade", "upgrade": "websocket", "sec-websocket-extensions": "permessage-deflate; client_max_window_bits", "host": "localhost:8080"}, body: Body(Empty) }

代码没有进入 `local.spawn_local` 块，在我看来这就是连接无法建立的原因。如果有人知道如何解决这个问题，请分享。谢谢！

rust websocket rust-tokio
1个回答
0
投票

A

`LocalSet` 在您真正以某种方式驱动它（例如 `.await` 它、调用 `run_until` 或 `block_on`）之前不会做任何事情。如果 `LocalSet` 本身没有被驱动，它就不会开始执行通过 `local.spawn_local` 提交给它的 future。这就是代码没有运行的原因：您把 future 交给了 `LocalSet`，然后就把这个 `LocalSet` 丢弃了。这和写下 `let foo = async move { ... };`，却从不 `.await` `foo`，然后奇怪异步块里的代码为什么不运行，没有太大区别。

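例如，您可以在 `Ok(response)` 之前添加 `local.await;`。请注意，这会阻塞外层函数，直到 `LocalSet` 中的 future 全部完成。从您（又长又复杂的）示例代码中看不出这是否真的是正确的解决方案，但它至少会让代码运行起来，从而解决您看到的主要问题。

需要特别注意：`hyper::upgrade::on(req)` 只有在 hyper 把 101 响应写回客户端之后才会完成，而这个响应要等 `handle_request` 返回之后才会发出。因此，在返回响应之前 `.await` 这个 `LocalSet` 会造成死锁。此外，跨 `.await` 持有 `!Send` 的 `LocalSet`，也会让 `handle_request` 的 future 不再满足 hyper 对 `Send` 的要求。更稳妥的做法分三步：先调用 `hyper::upgrade::on(&mut req)` 取得升级句柄；再用 `tokio::spawn` 在独立任务中等待升级完成；最后立即返回 101 响应。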

也不清楚您是否真的需要 `LocalSet`，或者是否可以改用其他机制（例如 `tokio::spawn` 或其他 future 组合器）。

© www.soinside.com 2019 - 2024. All rights reserved.