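I may be stupid, but I've been staring at this problem for days and can't figure it out. Middleware in axum has changed so much recently that 90% of what's on the internet is outdated. The docs don't help me either, probably because I'm new to Rust and am missing something. I'm using axum 0.7.5. Anyway, here's the code: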
main.rs:
#[tokio::main]
async fn main() {
    // Initialize environment variables from .env file
    dotenv::dotenv().ok();
    if let Err(e) = run().await {
        eprintln!("Application error: {:?}", e);
    }
}
lib.rs:
use std::{error::Error, net::SocketAddr, sync::Arc};

pub async fn run() -> Result<(), Box<dyn Error>> {
    // Establish the database connection pool
    let db_pool = Arc::new(db::connection::establish_connection().await?);
    // Create the application router and add the connection pool as an extension
    let app = create_routes(db_pool.clone());
    // Use the ? operator to propagate errors instead of unwrapping
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
    axum::serve(
        listener,
        app.into_make_service_with_connect_info::<SocketAddr>(),
    )
    .await?;
    Ok(())
}
state.rs:
use std::sync::Arc;

#[derive(Clone)]
pub struct AppState {
    pub access_token_secret: Arc<String>,
    pub refresh_token_secret: Arc<String>,
}
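As far as I can tell, the Extension(app_state.clone()) layer clones the state into every request, and auth_middleware clones it once more with .cloned(). The std-only sketch below reuses the AppState shape from state.rs with made-up secret values to show that such clones only copy the Arc pointers. It also checks that the type meets the Clone + Send + Sync + 'static bounds that http 1.x request extensions place on stored values; assert_extension_bounds is a hypothetical helper, not part of axum or http.

```rust
use std::sync::Arc;

// Same shape as AppState in state.rs; the secret values below are made up.
#[derive(Clone)]
pub struct AppState {
    pub access_token_secret: Arc<String>,
    pub refresh_token_secret: Arc<String>,
}

/// Hypothetical helper: compiles only if T meets the bounds that
/// http 1.x `Extensions::insert` requires of extension values.
fn assert_extension_bounds<T: Clone + Send + Sync + 'static>() {}

fn main() {
    assert_extension_bounds::<AppState>();

    let state = AppState {
        access_token_secret: Arc::new("access-secret".to_string()),
        refresh_token_secret: Arc::new("refresh-secret".to_string()),
    };

    // What happens per request: cloning copies the Arc pointers, not the strings.
    let per_request = state.clone();
    assert!(Arc::ptr_eq(&state.access_token_secret, &per_request.access_token_secret));
    assert!(Arc::ptr_eq(&state.refresh_token_secret, &per_request.refresh_token_secret));

    // One count for `state`, one for `per_request`.
    assert_eq!(Arc::strong_count(&state.access_token_secret), 2);
}
```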
routes/mod.rs:
pub mod analytics;
pub mod auth;
pub mod legal;
pub mod misc;
pub mod users;

use std::{env, sync::Arc};

use axum::middleware::from_fn;
use axum::{Extension, Router};
use axum_client_ip::SecureClientIpSource;
use sqlx::MySqlPool;

use crate::{
    middleware::{auth_middleware::auth_middleware, input_validation::input_validation},
    state::AppState,
};

pub fn create_routes(pool: Arc<MySqlPool>) -> Router {
    // Load the token secrets from environment variables
    let access_token_secret =
        env::var("ACCESS_TOKEN_SECRET").expect("ACCESS_TOKEN_SECRET must be set");
    let refresh_token_secret =
        env::var("REFRESH_TOKEN_SECRET").expect("REFRESH_TOKEN_SECRET must be set");

    let app_state = AppState {
        access_token_secret: Arc::new(access_token_secret),
        refresh_token_secret: Arc::new(refresh_token_secret),
    };

    Router::new()
        .nest("/v1/users", users::routes())
        .nest("/v1/auth", auth::routes(pool.clone()))
        .nest("/v1/legal", legal::routes(pool.clone()))
        .nest("/v1", misc::routes())
        .layer(from_fn(input_validation))
        .layer(from_fn(auth_middleware))
        .layer(Extension(app_state.clone()))
        .layer(SecureClientIpSource::ConnectInfo.into_extension())
}
middleware/input_validation.rs (this one works):
use axum::{
    body::to_bytes, body::Body, http::Request, http::StatusCode, middleware::Next,
    response::Response,
};
use hyper::header::{HeaderValue, CONTENT_TYPE};
use serde_json::Value;
use std::collections::HashMap;
use url::form_urlencoded;

use crate::validation::validate_fields::validate_fields;

const MAX_BODY_SIZE: usize = 1024 * 1024; // 1 MB

pub async fn input_validation(req: Request<Body>, next: Next) -> Result<Response, StatusCode> {
    let mut fields = HashMap::new();
    // Extract query parameters
    if let Some(query) = req.uri().query() {
        for (key, value) in form_urlencoded::parse(query.as_bytes()) {
            fields.insert(key.into_owned(), value.into_owned());
        }
    }
    // Extract content type before moving the request
    let content_type = req.headers().get(CONTENT_TYPE).cloned();
    // Split the request into parts
    let (parts, body) = req.into_parts();
    // Extract body if it's JSON
    if content_type == Some(HeaderValue::from_static("application/json")) {
        // Attempt to read the whole body with a size limit
        let whole_body = match to_bytes(body, MAX_BODY_SIZE).await {
            Ok(bytes) => bytes,
            Err(_) => return Err(StatusCode::BAD_REQUEST),
        };
        // Check if the body exceeds the maximum size
        if whole_body.len() > MAX_BODY_SIZE {
            return Err(StatusCode::PAYLOAD_TOO_LARGE);
        }
        // Parse the JSON body
        let json_body: Value = match serde_json::from_slice(&whole_body) {
            Ok(json) => json,
            Err(_) => return Err(StatusCode::BAD_REQUEST),
        };
        // Extract fields from the JSON object
        if let Some(object) = json_body.as_object() {
            for (key, value) in object.iter() {
                fields.insert(key.clone(), value.to_string());
            }
        }
        // Validate all extracted fields and clean them
        match validate_fields(&fields) {
            Ok(cleaned_fields) => {
                // Optionally reconstruct JSON body with cleaned values
                let new_body = match serde_json::to_string(&cleaned_fields) {
                    Ok(json_str) => json_str,
                    Err(_) => return Err(StatusCode::BAD_REQUEST),
                };
                // Reconstruct the request with the cleaned body
                let new_request = Request::from_parts(parts, Body::from(new_body));
                return Ok(next.run(new_request).await);
            }
            Err(ref err) if err == "Invalid credentials." => Err(StatusCode::UNAUTHORIZED),
            Err(_) => Err(StatusCode::BAD_REQUEST),
        }
    } else {
        // Validate all extracted fields for non-JSON bodies
        match validate_fields(&fields) {
            Ok(_) => {
                // Reconstruct the request with the original body
                let original_request = Request::from_parts(parts, body);
                return Ok(next.run(original_request).await);
            }
            Err(ref err) if err == "Invalid credentials." => Err(StatusCode::UNAUTHORIZED),
            Err(_) => Err(StatusCode::BAD_REQUEST),
        }
    }
}
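A side note on the size check in input_validation: as far as I know, axum 0.7's to_bytes(body, limit) already returns an error once the body grows past limit. That means the later whole_body.len() > MAX_BODY_SIZE branch cannot fire, and an oversized JSON body gets 400 from the to_bytes match arm rather than 413. The std-only sketch below reproduces that read-with-limit behaviour using std::io::Read::take. read_limited is a hypothetical helper, not an axum API; in the real middleware one would more likely map the to_bytes error itself to StatusCode::PAYLOAD_TOO_LARGE.

```rust
use std::io::{self, Read};

const MAX_BODY_SIZE: usize = 1024 * 1024; // 1 MB, same limit as the middleware

/// Hypothetical helper: read at most `limit` bytes and fail if the source has more.
/// Reading `limit + 1` bytes is enough to detect an oversized body
/// without buffering all of it.
fn read_limited<R: Read>(reader: R, limit: usize) -> io::Result<Vec<u8>> {
    let mut buf = Vec::new();
    reader.take(limit as u64 + 1).read_to_end(&mut buf)?;
    if buf.len() > limit {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "body exceeds limit"));
    }
    Ok(buf)
}

fn main() {
    let small = vec![b'a'; 10];
    let big = vec![b'a'; MAX_BODY_SIZE + 1];

    assert_eq!(read_limited(&small[..], MAX_BODY_SIZE).unwrap().len(), 10);
    // The oversized input is rejected by the read itself, so a successful
    // result can never be longer than MAX_BODY_SIZE.
    assert!(read_limited(&big[..], MAX_BODY_SIZE).is_err());
}
```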
middleware/auth_validation.rs (this one doesn't work):
use crate::{services::token_services::decode_token::decode_token, state::AppState};
use axum::{
    body::Body,
    http::{Request, StatusCode},
    middleware::Next,
    response::Response,
};

pub async fn auth_middleware(mut req: Request<Body>, next: Next) -> Result<Response, StatusCode> {
    // Get the state from request extensions
    let state = req
        .extensions()
        .get::<AppState>()
        .cloned()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
    // Get the request path
    let path = req.uri().path();
    // Define public paths that bypass authentication
    let public_paths = [
        "/v1/auth/login",
        "/v1/auth/register",
        // Add more public routes if necessary
    ];
    // If the path is public, skip authentication
    if public_paths.contains(&path) {
        return Ok(next.run(req).await);
    }
    // Extract the "x-access-token" header
    let token = req
        .headers()
        .get("x-access-token")
        .and_then(|t| t.to_str().ok())
        .ok_or(StatusCode::UNAUTHORIZED)?;
    // Decode and validate the token using your existing decode_token function
    match decode_token(
        token,
        true, // Indicates that we're validating an access token
        &state.access_token_secret,
        &state.refresh_token_secret,
    ) {
        Ok(token_content) => {
            // Insert the TokenContent into request extensions for access in handlers
            req.extensions_mut().insert(token_content);
            Ok(next.run(req).await)
        }
        Err(_) => Err(StatusCode::UNAUTHORIZED),
    }
}
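The public-path bypass in auth_middleware is an exact string comparison against a fixed list. Pulling that check into a hypothetical requires_auth function (not part of the original code) lets it be tested on its own and makes the consequence visible: a trailing slash or any other variation of a listed path still goes through token validation.

```rust
/// Paths that skip authentication, copied from auth_middleware.
const PUBLIC_PATHS: &[&str] = &["/v1/auth/login", "/v1/auth/register"];

/// Hypothetical helper: true when a request to `path` must carry a valid
/// access token. The comparison is exact, so "/v1/auth/login/" is NOT public.
fn requires_auth(path: &str) -> bool {
    !PUBLIC_PATHS.contains(&path)
}

fn main() {
    assert!(!requires_auth("/v1/auth/login"));
    assert!(!requires_auth("/v1/auth/register"));
    assert!(requires_auth("/v1/users/me"));
    assert!(requires_auth("/v1/auth/login/"));
}
```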
rust-analyzer only complains about routes/mod.rs. I get this compile error:
error[E0277]: the trait bound `axum::middleware::FromFn<fn(hyper::Request<axum::body::Body>, Next) -> impl Future<Output = Result<Response<axum::body::Body>, StatusCode>> {auth_middleware}, (), Route, _>: tower_service::Service<hyper::Request<axum::body::Body>>` is not satisfied
--> src/routes/mod.rs:40:16
|
40 | .layer(from_fn(auth_middleware))
| ----- ^^^^^^^^^^^^^^^^^^^^^^^^ the trait `tower_service::Service<hyper::Request<axum::body::Body>>` is not implemented for `axum::middleware::FromFn<fn(hyper::Request<axum::body::Body>, Next) -> impl Future<Output = Result<Response<axum::body::Body>, StatusCode>> {auth_middleware}, (), Route, _>`
| |
| required by a bound introduced by this call
|
= help: the following other types implement trait `tower_service::Service<Request>`:
axum::middleware::FromFn<F, S, I, (T1, T2)>
axum::middleware::FromFn<F, S, I, (T1, T2, T3)>
axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4)>
axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5)>
axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6)>
axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6, T7)>
axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6, T7, T8)>
axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6, T7, T8, T9)>
and 8 others
note: required by a bound in `Router::<S>::layer`
--> /home/admin/.cargo/registry/src/index.crates.io-6f17d22bba15001f/axum-0.7.7/src/routing/mod.rs:279:21
|
276 | pub fn layer<L>(self, layer: L) -> Router<S>
| ----- required by a bound in this associated function
...
279 | L::Service: Service<Request> + Clone + Send + 'static,
| ^^^^^^^^^^^^^^^^ required by this bound in `Router::<S>::layer`
For more information about this error, try `rustc --explain E0277`.
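Any and all suggestions are welcome. I've spent days trying to figure this out. I've tried good old Google, reading the docs, and lots of AI bots: ChatGPT 4, 4o, o1-mini, o1-preview, Claude, Perplexity, etc. Nothing helped. I just went from one problem to the next, and I believe my current code is close to working but is missing something.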
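You asked for suggestions, so, given the limited information, here is my guess: token_content is !Send, so the future from_fn builds is !Send at the .await point, and therefore FromFn<...> no longer implements tower's Service. You can falsify this explanation by showing that token_content is Send: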
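// Put this inside auth_middleware, right after token_content is bound.
// The bound must be Send (not Sync) for the check to test what its name says.
fn is_send<T: Send>(_t: T) {}
is_send(token_content.clone());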
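One way to test this kind of guess without reasoning about each value by hand is to put the Send requirement on the whole future, so the compiler itself reports which value is held across the .await. The std-only sketch below uses two hypothetical helpers: assert_send (on a future value) and assert_send_future (on the async fn itself, so no arguments need to be built). It shows that an Arc held across an await keeps the future Send, and that an Rc dropped in an inner block before the await does not affect it. For the middleware, the second helper would presumably take a bound like F: Fn(Request<Body>, Next) -> Fut with Fut: Send, which avoids constructing a Next; that axum-specific variant is an untested assumption.

```rust
use std::future::Future;
use std::rc::Rc;
use std::sync::Arc;

/// Hypothetical value-level probe: compiles only if `T` is Send.
fn assert_send<T: Send>(_value: &T) {}

/// Hypothetical function-level probe: compiles only if calling `f`
/// yields a Send future. No future or argument has to be built by hand.
fn assert_send_future<F, Fut>(_f: F)
where
    F: Fn() -> Fut,
    Fut: Future + Send,
{
}

async fn yield_point() {}

// An Arc (Send) is held across the await, so the future is still Send.
async fn holds_arc_across_await() -> usize {
    let shared = Arc::new(42usize);
    yield_point().await;
    *shared
}

// An Rc (!Send) is created and dropped inside an inner block *before* the
// await, so it is not part of the suspended state and the future stays Send.
async fn drops_rc_before_await() -> usize {
    let n = {
        let local = Rc::new(7usize);
        *local
    };
    yield_point().await;
    n
}

fn main() {
    assert_send(&holds_arc_across_await());
    assert_send(&drops_rc_before_await());
    assert_send_future(holds_arc_across_await);
    assert_send_future(drops_rc_before_await);
    // Moving `local` out of the inner block, so it is still alive at
    // `yield_point().await`, makes both probes on drops_rc_before_await
    // fail to compile, and the error points at the Rc held across the await.
}
```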