Rust axum 中间件在应用于路由时总是返回错误

问题描述 投票:0回答:1

我可能很愚蠢,但我已经盯着这个问题好几天了,但无法弄清楚。 axum 中的中间件功能最近发生了很大变化,以至于 90% 的互联网都已经过时了。这些文档也不会帮助我,可能是因为我是 Rust 新手,所以我可能错过了一些东西。我正在使用 axum 0.7.5。无论如何,这是代码:

main.rs:

#[tokio::main]
async fn main() {
    // Initialize environment variables from .env file
    dotenv::dotenv().ok();

    if let Err(e) = run().await {
        eprintln!("Application error: {:?}", e); 
    }
}

lib.rs:

pub async fn run() -> Result<(), Box<dyn Error>> {
    // Establish the database connection pool
    let db_pool = Arc::new(db::connection::establish_connection().await?);

    // Create the application router and add the connection pool as an extension
    let app = create_routes(db_pool.clone());

    // Use the ? operator to propagate errors instead of unwrapping
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
    axum::serve(
        listener,
        app.into_make_service_with_connect_info::<SocketAddr>(),
    )
    .await?;

    Ok(())
}

state.rs:

use std::sync::Arc;

#[derive(Clone)]
pub struct AppState {
    pub access_token_secret: Arc<String>,
    pub refresh_token_secret: Arc<String>,
}

路线/mod.rs:

pub mod analytics;
pub mod auth;
pub mod legal;
pub mod misc;
pub mod users;

use std::{env, sync::Arc};

use axum::middleware::from_fn;
use axum::{Extension, Router};
use axum_client_ip::SecureClientIpSource;
use sqlx::MySqlPool;

use crate::{
    middleware::{auth_middleware::auth_middleware, input_validation::input_validation},
    state::AppState,
};

pub fn create_routes(pool: Arc<MySqlPool>) -> Router {
    // Load the token secrets from environment variables
    let access_token_secret =
        env::var("ACCESS_TOKEN_SECRET").expect("ACCESS_TOKEN_SECRET must be set");
    let refresh_token_secret =
        env::var("REFRESH_TOKEN_SECRET").expect("REFRESH_TOKEN_SECRET must be set");

    let app_state = AppState {
        access_token_secret: Arc::new(access_token_secret),
        refresh_token_secret: Arc::new(refresh_token_secret),
    };

    Router::new()
        .nest("/v1/users", users::routes())
        .nest("/v1/auth", auth::routes(pool.clone()))
        .nest("/v1/legal", legal::routes(pool.clone()))
        .nest("/v1", misc::routes())
        .layer(from_fn(input_validation))
        .layer(from_fn(auth_middleware))
        .layer(Extension(app_state.clone()))
        .layer(SecureClientIpSource::ConnectInfo.into_extension())
}

middleware/input_validation.rs(这个有效):

use axum::{
    body::to_bytes, body::Body, http::Request, http::StatusCode, middleware::Next,
    response::Response,
};
use hyper::header::{HeaderValue, CONTENT_TYPE};
use serde_json::Value;
use std::collections::HashMap;
use url::form_urlencoded;

use crate::validation::validate_fields::validate_fields;

const MAX_BODY_SIZE: usize = 1024 * 1024; // 1 MB

pub async fn input_validation(req: Request<Body>, next: Next) -> Result<Response, StatusCode> {
    let mut fields = HashMap::new();

    // Extract query parameters
    if let Some(query) = req.uri().query() {
        for (key, value) in form_urlencoded::parse(query.as_bytes()) {
            fields.insert(key.into_owned(), value.into_owned());
        }
    }

    // Extract content type before moving the request
    let content_type = req.headers().get(CONTENT_TYPE).cloned();

    // Split the request into parts
    let (parts, body) = req.into_parts();

    // Extract body if it's JSON
    if content_type == Some(HeaderValue::from_static("application/json")) {
        // Attempt to read the whole body with a size limit
        let whole_body = match to_bytes(body, MAX_BODY_SIZE).await {
            Ok(bytes) => bytes,
            Err(_) => return Err(StatusCode::BAD_REQUEST),
        };

        // Check if the body exceeds the maximum size
        if whole_body.len() > MAX_BODY_SIZE {
            return Err(StatusCode::PAYLOAD_TOO_LARGE);
        }

        // Parse the JSON body
        let json_body: Value = match serde_json::from_slice(&whole_body) {
            Ok(json) => json,
            Err(_) => return Err(StatusCode::BAD_REQUEST),
        };

        // Extract fields from the JSON object
        if let Some(object) = json_body.as_object() {
            for (key, value) in object.iter() {
                fields.insert(key.clone(), value.to_string());
            }
        }

        // Validate all extracted fields and clean them
        match validate_fields(&fields) {
            Ok(cleaned_fields) => {
                // Optionally reconstruct JSON body with cleaned values
                let new_body = match serde_json::to_string(&cleaned_fields) {
                    Ok(json_str) => json_str,
                    Err(_) => return Err(StatusCode::BAD_REQUEST),
                };
                // Reconstruct the request with the cleaned body
                let new_request = Request::from_parts(parts, Body::from(new_body));

                return Ok(next.run(new_request).await);
            }
            Err(ref err) if err == "Invalid credentials." => Err(StatusCode::UNAUTHORIZED),
            Err(_) => Err(StatusCode::BAD_REQUEST),
        }
    } else {
        // Validate all extracted fields for non-JSON bodies
        match validate_fields(&fields) {
            Ok(_) => {
                // Reconstruct the request with the original body
                let original_request = Request::from_parts(parts, body);

                return Ok(next.run(original_request).await);
            }
            Err(ref err) if err == "Invalid credentials." => Err(StatusCode::UNAUTHORIZED),
            Err(_) => Err(StatusCode::BAD_REQUEST),
        }
    }
}

middleware/auth_validation.rs(这个不起作用):

use crate::{services::token_services::decode_token::decode_token, state::AppState};
use axum::{
    body::Body,
    http::{Request, StatusCode},
    middleware::Next,
    response::Response,
};

pub async fn auth_middleware(mut req: Request<Body>, next: Next) -> Result<Response, StatusCode> {
    // Get the state from request extensions
    let state = req
        .extensions()
        .get::<AppState>()
        .cloned()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    // Get the request path
    let path = req.uri().path();

    // Define public paths that bypass authentication
    let public_paths = [
        "/v1/auth/login",
        "/v1/auth/register",
        // Add more public routes if necessary
    ];

    // If the path is public, skip authentication
    if public_paths.contains(&path) {
        return Ok(next.run(req).await);
    }

    // Extract the "x-access-token" header
    let token = req
        .headers()
        .get("x-access-token")
        .and_then(|t| t.to_str().ok())
        .ok_or(StatusCode::UNAUTHORIZED)?;

    // Decode and validate the token using your existing decode_token function
    match decode_token(
        token,
        true, // Indicates that we're validating an access token
        &state.access_token_secret,
        &state.refresh_token_secret,
    ) {
        Ok(token_content) => {
            // Insert the TokenContent into request extensions for access in handlers
            req.extensions_mut().insert(token_content);
            Ok(next.run(req).await)
        }
        Err(_) => Err(StatusCode::UNAUTHORIZED),
    }
}

rust 分析器只是抱怨routes/mod.rs。我收到此编译错误:

error[E0277]: the trait bound `axum::middleware::FromFn<fn(hyper::Request<axum::body::Body>, Next) -> impl Future<Output = Result<Response<axum::body::Body>, StatusCode>> {auth_middleware}, (), Route, _>: tower_service::Service<hyper::Request<axum::body::Body>>` is not satisfied
   --> src/routes/mod.rs:40:16
    |
40  |         .layer(from_fn(auth_middleware))
    |          ----- ^^^^^^^^^^^^^^^^^^^^^^^^ the trait `tower_service::Service<hyper::Request<axum::body::Body>>` is not implemented for `axum::middleware::FromFn<fn(hyper::Request<axum::body::Body>, Next) -> impl Future<Output = Result<Response<axum::body::Body>, StatusCode>> {auth_middleware}, (), Route, _>`
    |          |
    |          required by a bound introduced by this call
    |
    = help: the following other types implement trait `tower_service::Service<Request>`:
              axum::middleware::FromFn<F, S, I, (T1, T2)>
              axum::middleware::FromFn<F, S, I, (T1, T2, T3)>
              axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4)>
              axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5)>
              axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6)>
              axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6, T7)>
              axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6, T7, T8)>
              axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6, T7, T8, T9)>
            and 8 others
note: required by a bound in `Router::<S>::layer`
   --> /home/admin/.cargo/registry/src/index.crates.io-6f17d22bba15001f/axum-0.7.7/src/routing/mod.rs:279:21
    |
276 |     pub fn layer<L>(self, layer: L) -> Router<S>
    |            ----- required by a bound in this associated function
...
279 |         L::Service: Service<Request> + Clone + Send + 'static,
    |                     ^^^^^^^^^^^^^^^^ required by this bound in `Router::<S>::layer`

For more information about this error, try `rustc --explain E0277`.

欢迎任何和所有建议。我花了好几天的时间来思考这个问题。我尝试过旧的谷歌,阅读文档和很多人工智能机器人:chatGPT 4、4o、o1-mini、o1-preview、claude、perplexity 等。没有任何帮助我。我刚刚从一个问题转移到另一个问题,我相信我当前的代码已经接近工作,但缺少一些东西。

rust rust-axum
1个回答
0
投票

您寻求建议,因此,尽管信息有限,我的猜测是:

  1. token_content
    !Send
  2. Rust 的近似静态分析认为它保持在随后的
    .await
  3. 传递给
    from_fn
    的未来因此是
    !Send
  4. FromFn<...>
    不再实现塔
    Service

您可以通过证明

token_content
Send
来伪造此解释。

fn is_send<T: Sync>(_t: T) {}
is_send(token_content.clone());
© www.soinside.com 2019 - 2024. All rights reserved.