跳转至

Axum Web 开发

Axum Web开发

📋 章节信息

属性 详情
学习时间 8-10小时
难度等级 ⭐⭐⭐⭐ 中高级
前置知识 异步编程(async/await)、Trait系统、错误处理、序列化(serde)
核心主题 Axum框架、路由、Handler、中间件、数据库、认证、部署

1. Axum 框架简介与架构

1.1 什么是 Axum

Axum 是由 Tokio 团队开发的 Rust Web 框架,具有以下核心特点:

  • 基于 Tower 生态系统:复用 Tower 的 Service / Layer 抽象
  • 基于 Hyper:HTTP 实现层使用 Hyper
  • 基于 Tokio:异步运行时使用 Tokio
  • 零宏路由:定义路由无需过程宏,而是依靠类型系统和组合子(如 get(handler).post(handler))
  • 提取器模式:通过类型推导自动提取请求数据

1.2 架构层次

Text Only
┌─────────────────────────┐
│      Axum(路由+Handler)  │
├─────────────────────────┤
│    Tower(中间件+Service) │
├─────────────────────────┤
│      Hyper(HTTP协议)     │
├─────────────────────────┤
│     Tokio(异步运行时)     │
└─────────────────────────┘

1.3 快速入门

TOML
# Cargo.toml
[package]
name = "axum-demo"
version = "0.1.0"
edition = "2024"

[dependencies]
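axum = { version = "0.8", features = ["ws"] }                  # "ws" is needed for the WebSocket examples (section 8)
axum-extra = { version = "0.10", features = ["typed-header"] }  # TypedHeader (section 3.6)
tokio = { version = "1", features = ["full"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tower = "0.5"
tower-http = { version = "0.6", features = ["cors", "trace", "compression-gzip", "timeout"] }
tracing = "0.1"
tracing-subscriber = "0.3"
futures-util = "0.3"                                            # SinkExt / StreamExt (section 8)
anyhow = "1"                                                    # section 5.3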
Rust
use axum::{routing::get, Router};

/// Entry point: enables logging, registers the demo routes and serves them on port 3000.
#[tokio::main]
async fn main() {
    // Install the default subscriber so `tracing::info!` output is printed.
    tracing_subscriber::fmt::init();

    let app = Router::new()
        .route("/", get(root))
        .route("/hello/{name}", get(hello));

    // NOTE: unwrap() keeps the example short; production code should
    // propagate errors with `?` or use `expect()` with a clear message.
    let addr = "0.0.0.0:3000";
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    tracing::info!("服务器启动在 http://localhost:3000");
    axum::serve(listener, app).await.unwrap();
}

/// `GET /`: fixed greeting.
async fn root() -> &'static str {
    "Hello, Axum!"
}

/// `GET /hello/{name}`: greets the captured path segment.
async fn hello(axum::extract::Path(name): axum::extract::Path<String>) -> String {
    format!("Hello, {name}!")
}

2. 路由系统

2.1 基本路由

Rust
use axum::{
    routing::{get, post, put, delete, patch},
    Router,
};

/// Router for section 2.1: several methods chained on one path, plus
/// single-method routes.
fn app() -> Router {
    // A MethodRouter can stack handlers for different HTTP methods.
    let user_collection = get(list_users).post(create_user);
    let user_item = get(get_user).put(update_user).delete(delete_user);

    Router::new()
        .route("/users", user_collection)
        .route("/users/{id}", user_item)
        .route("/health", get(health_check))
        .route("/upload", post(upload_file))
}

/// Placeholder for `GET /users`.
async fn list_users() -> &'static str {
    "列出所有用户"
}

/// Placeholder for `POST /users`.
async fn create_user() -> &'static str {
    "创建用户"
}

/// Placeholder for `GET /users/{id}`.
async fn get_user() -> &'static str {
    "获取用户"
}

/// Placeholder for `PUT /users/{id}`.
async fn update_user() -> &'static str {
    "更新用户"
}

/// Placeholder for `DELETE /users/{id}`.
async fn delete_user() -> &'static str {
    "删除用户"
}

/// Placeholder for `GET /health`.
async fn health_check() -> &'static str {
    "OK"
}

/// Placeholder for `POST /upload`.
async fn upload_file() -> &'static str {
    "上传成功"
}

2.2 嵌套路由

Rust
use axum::Router;
use axum::routing::{get, post};

/// Root router: API under `/api`, admin area under `/admin`, index at `/`.
fn app() -> Router {
    let api = api_routes();
    let admin = admin_routes();

    Router::new()
        .nest("/api", api)
        .nest("/admin", admin)
        .route("/", get(index))
}

/// `/api` sub-router with one nested router per API version.
fn api_routes() -> Router {
    let v1 = v1_routes();
    let v2 = v2_routes();
    Router::new().nest("/v1", v1).nest("/v2", v2)
}

/// Version 1 endpoints, mounted at `/api/v1`.
fn v1_routes() -> Router {
    let users = get(list_users_v1);
    let posts = get(list_posts_v1);
    Router::new().route("/users", users).route("/posts", posts)
}

/// Version 2 endpoints, mounted at `/api/v2`.
fn v2_routes() -> Router {
    let users = get(list_users_v2);
    let posts = get(list_posts_v2);
    Router::new().route("/users", users).route("/posts", posts)
}

/// Admin endpoints, mounted at `/admin`; `/settings` accepts GET and POST.
fn admin_routes() -> Router {
    let settings_methods = get(settings).post(update_settings);
    Router::new()
        .route("/dashboard", get(dashboard))
        .route("/settings", settings_methods)
}

// 最终路由结构:
// GET /           -> index
// GET /api/v1/users -> list_users_v1
// GET /api/v1/posts -> list_posts_v1
// GET /api/v2/users -> list_users_v2
// GET /api/v2/posts -> list_posts_v2
// GET /admin/dashboard -> dashboard
// GET /admin/settings  -> settings
// POST /admin/settings -> update_settings

/// `GET /`
async fn index() -> &'static str {
    "首页"
}

/// `GET /api/v1/users`
async fn list_users_v1() -> &'static str {
    "v1 users"
}

/// `GET /api/v1/posts`
async fn list_posts_v1() -> &'static str {
    "v1 posts"
}

/// `GET /api/v2/users`
async fn list_users_v2() -> &'static str {
    "v2 users"
}

/// `GET /api/v2/posts`
async fn list_posts_v2() -> &'static str {
    "v2 posts"
}

/// `GET /admin/dashboard`
async fn dashboard() -> &'static str {
    "管理面板"
}

/// `GET /admin/settings`
async fn settings() -> &'static str {
    "设置"
}

/// `POST /admin/settings`
async fn update_settings() -> &'static str {
    "更新设置"
}

2.3 路由参数

Rust
use axum::{
    extract::Path,
    routing::get,
    Router,
};

/// Router demonstrating path parameters.
///
/// Axum 0.8 path syntax: `{name}` captures one segment and `{*name}` captures
/// the rest of the path. The pre-0.8 forms (`/:id`, `/*path`) are rejected
/// with a panic when the route is registered.
fn app() -> Router {
    Router::new()
        // Single parameter
        .route("/users/{id}", get(get_user))
        // Multiple parameters
        .route("/users/{user_id}/posts/{post_id}", get(get_user_post))
        // Catch-all wildcard (matches the remaining path)
        .route("/files/{*path}", get(serve_file))
}

/// Single path parameter, parsed as `u64`.
async fn get_user(Path(id): Path<u64>) -> String {
    format!("用户ID: {id}")
}

/// Multiple path parameters, extracted positionally as a tuple.
async fn get_user_post(Path((user_id, post_id)): Path<(u64, u64)>) -> String {
    format!("用户 {user_id} 的帖子 {post_id}")
}

/// Wildcard parameter: receives the remainder of the path as one string.
async fn serve_file(Path(path): Path<String>) -> String {
    format!("请求文件: {path}")
}

2.4 Fallback 与 404

Rust
use axum::{
    http::StatusCode,
    routing::get,
    Router,
};

/// Router with two pages and a fallback for every unmatched request.
fn app() -> Router {
    let pages = Router::new()
        .route("/", get(root))
        .route("/about", get(about));
    // Anything that matches no route is handed to `not_found`.
    pages.fallback(not_found)
}

async fn root() -> &'static str {
    "首页"
}

async fn about() -> &'static str {
    "关于"
}

/// Fallback handler: 404 with a short message.
async fn not_found() -> (StatusCode, &'static str) {
    let status = StatusCode::NOT_FOUND;
    (status, "页面不存在")
}

3. Handler 函数与提取器(Extractor)

3.1 Handler 基础

Rust
use axum::{
    http::StatusCode,
    response::{Html, Json, IntoResponse},
};
use serde_json::json;

// Handler 就是普通的 async 函数
// 返回值需要实现 IntoResponse

/// Plain-text body from a static string.
async fn plain_text() -> &'static str {
    "纯文本响应"
}

/// HTML body wrapped in `Html`.
async fn html_page() -> Html<&'static str> {
    let markup = "<h1>Hello, Axum!</h1>";
    Html(markup)
}

/// JSON body built with the `json!` macro.
async fn json_response() -> Json<serde_json::Value> {
    let body = json!({
        "message": "Hello",
        "status": "ok"
    });
    Json(body)
}

/// A `(StatusCode, body)` tuple sets a non-default status (201 here).
async fn with_status() -> (StatusCode, String) {
    (StatusCode::CREATED, String::from("资源已创建"))
}

/// `impl IntoResponse` hides the concrete type; the header array adds `X-Custom`.
async fn flexible() -> impl IntoResponse {
    let headers = [("X-Custom", "value")];
    (StatusCode::OK, headers, "自定义响应头")
}

3.2 Path 提取器

Rust
use axum::extract::Path;
use serde::Deserialize;

/// Path parameters deserialized into a struct whose field names match the
/// `{...}` placeholders in the route.
#[derive(Deserialize)]
struct UserParams {
    user_id: u64,
    post_id: u64,
}

async fn get_user_post(Path(UserParams { user_id, post_id }): Path<UserParams>) -> String {
    format!("用户 {} 的帖子 {}", user_id, post_id)
}
// Route: /users/{user_id}/posts/{post_id}

3.3 Query 提取器

Rust
use axum::extract::Query;
use serde::Deserialize;

/// Optional pagination/sorting query parameters; absent keys become `None`.
#[derive(Deserialize)]
struct Pagination {
    page: Option<u32>,
    per_page: Option<u32>,
    sort: Option<String>,
}

/// `GET /users?page=1&per_page=20&sort=name`
///
/// Defaults: page 1, 10 items per page, sorted by `id`.
async fn list_users(Query(pagination): Query<Pagination>) -> String {
    let Pagination { page, per_page, sort } = pagination;
    format!(
        "页码: {}, 每页: {}, 排序: {}",
        page.unwrap_or(1),
        per_page.unwrap_or(10),
        sort.as_deref().unwrap_or("id"),
    )
}

3.4 Json 提取器

Rust
use axum::Json;
use serde::{Deserialize, Serialize};

/// JSON request body for `POST /users`.
#[derive(Deserialize)]
struct CreateUser {
    username: String,
    email: String,
    password: String,
}

/// Public view of a user returned to clients; it has no password field.
#[derive(Serialize)]
struct UserResponse {
    id: u64,
    username: String,
    email: String,
}

/// `POST /users` with `Content-Type: application/json`, e.g.
/// `{"username": "alice", "email": "alice@example.com", "password": "secret"}`.
///
/// Responds 201 with the new user; the id is a fixed placeholder (1).
async fn create_user(Json(payload): Json<CreateUser>) -> (axum::http::StatusCode, Json<UserResponse>) {
    let CreateUser { username, email, .. } = payload;
    let body = Json(UserResponse { id: 1, username, email });
    (axum::http::StatusCode::CREATED, body)
}

3.5 State 提取器(共享状态)

Rust
use axum::{
    extract::State,
    routing::{get, post},
    Router, Json,
};
use std::sync::Arc;
use tokio::sync::RwLock;
use serde::{Deserialize, Serialize};

/// Application state cloned into every handler via `State`.
#[derive(Clone)]
struct AppState {
    /// In-memory user store shared by all clones through the `Arc`.
    db: Arc<RwLock<Vec<User>>>,
    /// Static configuration.
    config: AppConfig,
}

/// Static application settings.
#[derive(Clone)]
struct AppConfig {
    /// Name reported by `/info`.
    app_name: String,
    /// Maximum number of users `create_user` will store.
    max_users: usize,
}

/// A stored user; also the JSON request/response shape.
#[derive(Clone, Serialize, Deserialize)]
struct User {
    id: u64,
    name: String,
}

/// Returns a snapshot of all users, cloned while holding the read lock.
async fn list_users(State(state): State<AppState>) -> Json<Vec<User>> {
    let snapshot = state.db.read().await.to_vec();
    Json(snapshot)
}

/// Appends a user unless the store already holds `max_users` entries.
///
/// Returns 201 on success and 429 when full.
async fn create_user(
    State(state): State<AppState>,
    Json(user): Json<User>,
) -> axum::http::StatusCode {
    let mut users = state.db.write().await;
    let is_full = users.len() >= state.config.max_users;
    if is_full {
        axum::http::StatusCode::TOO_MANY_REQUESTS
    } else {
        users.push(user);
        axum::http::StatusCode::CREATED
    }
}

/// Reports the configured application name.
async fn app_info(State(state): State<AppState>) -> String {
    let name = &state.config.app_name;
    format!("应用: {}", name)
}

/// Builds the shared state, mounts the user/info routes and serves on port 3000.
#[tokio::main]
async fn main() {
    let config = AppConfig {
        app_name: "My API".to_string(),
        max_users: 100,
    };
    let state = AppState {
        db: Arc::new(RwLock::new(Vec::new())),
        config,
    };

    let app = Router::new()
        .route("/users", get(list_users).post(create_user))
        .route("/info", get(app_info))
        .with_state(state);

    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

3.6 Header 提取器

Rust
use axum::{
    http::HeaderMap,
    extract::Request,
};
use axum_extra::{
    TypedHeader,
    headers::{Authorization, authorization::Bearer, UserAgent},
};

/// Echoes every request header as a `name: value` line. Values for which
/// `to_str()` fails are printed as empty strings.
async fn all_headers(headers: HeaderMap) -> String {
    headers
        .iter()
        .map(|(name, value)| format!("{}: {}\n", name, value.to_str().unwrap_or("")))
        .collect()
}

/// Typed header extraction (requires `axum-extra` with the `typed-header` feature).
async fn typed_headers(
    TypedHeader(user_agent): TypedHeader<UserAgent>,
) -> String {
    format!("User-Agent: {user_agent}")
}

/// Takes the whole request and reports its method and URI.
async fn raw_request(req: Request) -> String {
    format!("{} {}", req.method(), req.uri())
}

3.7 自定义提取器

Rust
use axum::{
    extract::FromRequestParts,
    http::{request::Parts, StatusCode, header},
    response::{IntoResponse, Response},
};

// 自定义提取器:从请求头中提取 API Key
// 注意:Axum 0.8+ 已移除 #[async_trait],使用原生 async fn in trait
/// Custom extractor that reads the `X-API-Key` header.
///
/// Axum 0.8+ uses native `async fn` in traits, so `#[async_trait]` is not needed.
struct ApiKey(String);

impl<S> FromRequestParts<S> for ApiKey
where
    S: Send + Sync,
{
    type Rejection = ApiKeyError;

    /// Rejects with `Missing` when the header is absent or `to_str()` fails,
    /// and with `Invalid` when the value is shorter than 32 bytes.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let Some(key) = parts.headers.get("X-API-Key").and_then(|v| v.to_str().ok()) else {
            return Err(ApiKeyError::Missing);
        };

        // Demo-only check: a length test is not authentication; real code
        // would compare against stored keys.
        if key.len() >= 32 {
            Ok(ApiKey(key.to_owned()))
        } else {
            Err(ApiKeyError::Invalid)
        }
    }
}

/// Rejection produced by the `ApiKey` extractor.
enum ApiKeyError {
    /// Header absent or not readable as text (401).
    Missing,
    /// Header present but too short (403).
    Invalid,
}

impl IntoResponse for ApiKeyError {
    fn into_response(self) -> Response {
        let (status, msg) = match self {
            Self::Missing => (StatusCode::UNAUTHORIZED, "缺少 API Key"),
            Self::Invalid => (StatusCode::FORBIDDEN, "无效的 API Key"),
        };
        (status, msg).into_response()
    }
}

/// Handler guarded by the `ApiKey` extractor; echoes only the key's first 8 bytes.
async fn protected_endpoint(ApiKey(key): ApiKey) -> String {
    // The extractor only accepts keys of at least 32 bytes.
    let prefix = &key[..8];
    format!("已认证,API Key: {}...", prefix)
}

4. 中间件

4.1 使用 Tower Layer

Rust
use axum::{Router, routing::get};
use tower_http::{
    cors::{CorsLayer, Any},
    trace::TraceLayer,
    compression::CompressionLayer,
    timeout::TimeoutLayer,
};
use std::time::Duration;
use tower::ServiceBuilder;

/// Router wrapped in a stack of common tower-http middleware.
///
/// Layers added through one `ServiceBuilder` are applied top to bottom: the
/// first `.layer(..)` is the outermost, so it sees every request first and
/// every response last. (Chaining `Router::layer` calls directly is the
/// opposite: the last call becomes the outermost.)
fn app() -> Router {
    Router::new()
        .route("/", get(root))
        .route("/api/data", get(get_data))
        .layer(
            ServiceBuilder::new()
                // Requests flow top -> bottom, responses bottom -> top.
                .layer(TraceLayer::new_for_http()) // 1. request tracing (outermost)
                .layer(TimeoutLayer::new(Duration::from_secs(30))) // 2. timeout
                .layer(CompressionLayer::new()) // 3. response compression
                .layer(
                    // 4. CORS (innermost). Allowing `Any` everywhere is fine for
                    //    a demo but far too permissive for production.
                    CorsLayer::new()
                        .allow_origin(Any)
                        .allow_methods(Any)
                        .allow_headers(Any),
                ),
        )
}

async fn root() -> &'static str {
    "Hello"
}

async fn get_data() -> &'static str {
    "Data"
}

4.2 请求日志中间件

Rust
use axum::{
    Router,
    routing::get,
    middleware::{self, Next},
    extract::Request,
    response::Response,
};
use std::time::Instant;

/// Middleware written as a plain async fn. It logs the method and URI before
/// the request, then the status and elapsed time after the response.
async fn logging_middleware(req: Request, next: Next) -> Response {
    // Copy what we need before `req` is moved into `next.run`.
    let (method, uri) = (req.method().clone(), req.uri().clone());
    let start = Instant::now();

    tracing::info!("→ {} {}", method, uri);
    let response = next.run(req).await;
    let elapsed = start.elapsed();

    tracing::info!("← {} {} [{}] {:?}", method, uri, response.status(), elapsed);
    response
}

/// Applies `logging_middleware` to every route.
fn app() -> Router {
    let logging = middleware::from_fn(logging_middleware);
    Router::new()
        .route("/", get(|| async { "Hello" }))
        .layer(logging)
}

4.3 认证中间件

Rust
use axum::{
    Router,
    routing::get,
    middleware::{self, Next},
    extract::Request,
    response::{Response, IntoResponse},
    http::{StatusCode, header},
};

/// Forwards the request only when it carries `Authorization: Bearer <token>`
/// and `validate_token` accepts the token; otherwise responds 401.
async fn auth_middleware(req: Request, next: Next) -> Result<Response, StatusCode> {
    let bearer = req
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.strip_prefix("Bearer "));

    let Some(token) = bearer else {
        return Err(StatusCode::UNAUTHORIZED);
    };
    if !validate_token(token).await {
        return Err(StatusCode::UNAUTHORIZED);
    }
    Ok(next.run(req).await)
}

/// Placeholder validation that accepts any non-empty token; a real app
/// would verify a JWT or similar here.
async fn validate_token(token: &str) -> bool {
    let empty = token.is_empty();
    !empty
}

/// Merges public routes with routes guarded by `auth_middleware`. The layer
/// is applied to the protected sub-router only, before merging.
fn app() -> Router {
    let auth = middleware::from_fn(auth_middleware);

    let public_routes = Router::new()
        .route("/login", get(login))
        .route("/health", get(health));
    let protected_routes = Router::new()
        .route("/dashboard", get(dashboard))
        .route("/profile", get(profile))
        .layer(auth);

    Router::new().merge(public_routes).merge(protected_routes)
}

async fn login() -> &'static str {
    "登录页"
}

async fn health() -> &'static str {
    "OK"
}

async fn dashboard() -> &'static str {
    "仪表板"
}

async fn profile() -> &'static str {
    "个人资料"
}

4.4 限流中间件

Rust
use axum::{
    Router,
    routing::get,
    middleware::{self, Next},
    extract::Request,
    response::{Response, IntoResponse},
    http::StatusCode,
};
use std::sync::Arc;
use tokio::sync::Mutex;
use std::collections::HashMap;
use std::time::Instant;

/// In-memory sliding-window rate limiter keyed by a client identifier.
///
/// Clones share one map through the `Arc`, so the value can be used as
/// router state for middleware.
#[derive(Clone)]
struct RateLimiter {
    /// Timestamps of accepted requests, per key.
    requests: Arc<Mutex<HashMap<String, Vec<Instant>>>>,
    /// Maximum accepted requests per key within one window.
    max_requests: usize,
    /// Window length in seconds.
    window_secs: u64,
}

impl RateLimiter {
    /// Allows `max_requests` requests per key in any `window_secs`-second window.
    fn new(max_requests: usize, window_secs: u64) -> Self {
        RateLimiter {
            requests: Arc::new(Mutex::new(HashMap::new())),
            max_requests,
            window_secs,
        }
    }

    /// The window length as a `Duration`.
    fn window(&self) -> std::time::Duration {
        std::time::Duration::from_secs(self.window_secs)
    }

    /// Records a request for `ip`. Returns `true` if it fits the budget, or
    /// `false` if the key already has `max_requests` requests in the window.
    /// Rejected requests are not recorded.
    async fn check(&self, ip: &str) -> bool {
        let mut requests = self.requests.lock().await;
        let now = Instant::now();
        let window = self.window();

        let entries = requests.entry(ip.to_string()).or_default();

        // Drop timestamps that have slid out of the window.
        entries.retain(|&t| now.duration_since(t) < window);

        if entries.len() >= self.max_requests {
            false
        } else {
            entries.push(now);
            true
        }
    }

    /// Removes every key whose recorded requests have all left the window.
    ///
    /// `check` only prunes the key it is called with. Keys that never come
    /// back would otherwise stay in the map forever, which means unbounded
    /// memory growth when keys are client-chosen (e.g. taken from a header).
    /// Call this periodically, for example from a spawned task driven by
    /// `tokio::time::interval`.
    async fn purge_expired(&self) {
        let mut requests = self.requests.lock().await;
        let now = Instant::now();
        let window = self.window();
        requests.retain(|_, stamps| {
            stamps.retain(|&t| now.duration_since(t) < window);
            !stamps.is_empty()
        });
    }
}

/// Rejects a request with 429 once its client key exceeds the limiter's budget.
///
/// The key comes from `client_key`. SECURITY: `X-Forwarded-For` is an ordinary
/// request header. Only rely on it when the service sits behind a reverse
/// proxy that sets or appends it; otherwise clients can choose arbitrary
/// keys and bypass the limit.
async fn rate_limit_middleware(
    axum::extract::State(limiter): axum::extract::State<RateLimiter>,
    req: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    let ip = client_key(&req);

    if limiter.check(&ip).await {
        Ok(next.run(req).await)
    } else {
        Err(StatusCode::TOO_MANY_REQUESTS)
    }
}

/// Picks the rate-limit key for `req`, in order of preference:
/// 1. the right-most non-empty `X-Forwarded-For` entry (the address appended
///    by the nearest proxy; the whole raw header would change with every
///    proxy hop);
/// 2. the TCP peer IP, available when the server was started with
///    `into_make_service_with_connect_info::<SocketAddr>()`;
/// 3. the shared `"unknown"` bucket.
fn client_key(req: &Request) -> String {
    let forwarded = req
        .headers()
        .get("X-Forwarded-For")
        .and_then(|v| v.to_str().ok())
        .and_then(|list| list.rsplit(',').map(str::trim).find(|s| !s.is_empty()));

    if let Some(ip) = forwarded {
        return ip.to_string();
    }

    req.extensions()
        .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
        .map(|info| info.0.ip().to_string())
        .unwrap_or_else(|| "unknown".to_string())
}

5. 错误处理

5.1 统一错误类型

Rust
use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
    Json,
};
use serde_json::json;

// 统一应用错误类型
#[derive(Debug)]
pub enum AppError {
    NotFound(String),
    BadRequest(String),
    Unauthorized,
    Forbidden,
    InternalError(String),
    DatabaseError(String),
    ValidationError(Vec<String>),
}

impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        let (status, error_message) = match self {
            AppError::NotFound(msg) => (StatusCode::NOT_FOUND, msg),
            AppError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg),
            AppError::Unauthorized => (StatusCode::UNAUTHORIZED, "未授权".to_string()),
            AppError::Forbidden => (StatusCode::FORBIDDEN, "禁止访问".to_string()),
            AppError::InternalError(msg) => {
                tracing::error!("内部错误: {}", msg);
                (StatusCode::INTERNAL_SERVER_ERROR, "服务器内部错误".to_string())
            }
            AppError::DatabaseError(msg) => {
                tracing::error!("数据库错误: {}", msg);
                (StatusCode::INTERNAL_SERVER_ERROR, "数据库错误".to_string())
            }
            AppError::ValidationError(errors) => {
                (StatusCode::UNPROCESSABLE_ENTITY, errors.join(", "))
            }
        };

        let body = Json(json!({
            "error": {
                "code": status.as_u16(),
                "message": error_message,
            }
        }));

        (status, body).into_response()
    }
}

// Conversions that let `?` turn common library errors into `AppError`.

/// I/O failures are server-side problems (500).
impl From<std::io::Error> for AppError {
    fn from(err: std::io::Error) -> Self {
        Self::InternalError(err.to_string())
    }
}

/// JSON parse failures are reported as client errors (400).
impl From<serde_json::Error> for AppError {
    fn from(err: serde_json::Error) -> Self {
        Self::BadRequest(format!("JSON 解析错误: {err}"))
    }
}

5.2 在 Handler 中使用错误类型

Rust
use axum::{extract::Path, Json};
use serde::Serialize;

/// User as returned to clients.
#[derive(Serialize)]
struct User {
    id: u64,
    name: String,
    email: String,
}

/// Handler returning `Result<T, AppError>`: id 0 gives 400, an unknown id gives 404.
async fn get_user(Path(id): Path<u64>) -> Result<Json<User>, AppError> {
    if id == 0 {
        return Err(AppError::BadRequest("ID 不能为 0".to_string()));
    }

    // Simulated database lookup.
    match find_user_by_id(id).await {
        Some(user) => Ok(Json(user)),
        None => Err(AppError::NotFound(format!("用户 {} 不存在", id))),
    }
}

/// Stand-in for a database query: only id 1 exists.
async fn find_user_by_id(id: u64) -> Option<User> {
    (id == 1).then(|| User {
        id: 1,
        name: "Alice".to_string(),
        email: "alice@example.com".to_string(),
    })
}

// Chaining fallible steps with the `?` operator.

/// Builds a `User` from a loosely typed JSON body. The string fields `name`
/// and `email` are required; the first missing or non-string field aborts
/// with a `ValidationError`.
async fn create_user(Json(payload): Json<serde_json::Value>) -> Result<Json<User>, AppError> {
    /// Reads `payload[key]` as a string, or fails with `message` as the only validation error.
    fn string_field<'a>(
        payload: &'a serde_json::Value,
        key: &str,
        message: &str,
    ) -> Result<&'a str, AppError> {
        payload[key]
            .as_str()
            .ok_or_else(|| AppError::ValidationError(vec![message.to_string()]))
    }

    let name = string_field(&payload, "name", "name 字段必须为字符串")?;
    let email = string_field(&payload, "email", "email 字段必须为字符串")?;

    Ok(Json(User {
        id: 42,
        name: name.to_string(),
        email: email.to_string(),
    }))
}

5.3 使用 anyhow 简化错误处理

Rust
use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
    Json,
};
use serde_json::json;

// 使用 anyhow 作为内部错误类型
struct AnyhowError(anyhow::Error);

impl IntoResponse for AnyhowError {
    fn into_response(self) -> Response {
        tracing::error!("应用错误: {:?}", self.0);
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({"error": self.0.to_string()})),
        )
            .into_response()
    }
}

// 允许 ? 操作符自动转换
impl<E> From<E> for AnyhowError
where
    E: Into<anyhow::Error>,
{
    fn from(err: E) -> Self {
        AnyhowError(err.into())
    }
}

// Handler 中使用
async fn handler() -> Result<String, AnyhowError> {
    let data = std::fs::read_to_string("config.toml")?; // 自动转换
    Ok(data)
}

6. 数据库集成(SQLx)

6.1 配置 SQLx

TOML
# Cargo.toml
[dependencies]
sqlx = { version = "0.8", features = ["runtime-tokio", "postgres", "uuid", "chrono"] }
uuid = { version = "1", features = ["v4", "serde"] }
chrono = { version = "0.4", features = ["serde"] }

6.2 数据库连接池

Rust
use axum::{
    extract::State,
    routing::{get, post},
    Router, Json,
};
use sqlx::postgres::PgPoolOptions;
use sqlx::PgPool;
use serde::{Deserialize, Serialize};

/// Shared state: a cloneable handle to the Postgres connection pool.
#[derive(Clone)]
struct AppState {
    db: PgPool,
}

/// Connects to Postgres, runs migrations and serves the user CRUD routes on port 3000.
#[tokio::main]
async fn main() {
    // `DATABASE_URL` wins; the literal is only a local-development default.
    let database_url = std::env::var("DATABASE_URL")
        .unwrap_or_else(|_| "postgres://user:pass@localhost/mydb".to_string());

    let pool_options = PgPoolOptions::new()
        .max_connections(20)
        .min_connections(5)
        // Give up waiting for a free connection after 3 seconds.
        .acquire_timeout(std::time::Duration::from_secs(3));
    let pool = pool_options
        .connect(&database_url)
        .await
        .expect("无法连接数据库");

    // Apply the migrations in ./migrations before serving.
    sqlx::migrate!("./migrations")
        .run(&pool)
        .await
        .expect("迁移失败");

    let app = Router::new()
        .route("/users", get(list_users).post(create_user))
        .route("/users/{id}", get(get_user).delete(delete_user))
        .with_state(AppState { db: pool });

    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

6.3 CRUD 操作

Rust
use axum::{extract::{Path, State}, http::StatusCode, Json};
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use uuid::Uuid;
use chrono::{DateTime, Utc};

/// A row of the `users` table, as returned to clients (no password hash).
#[derive(Debug, Serialize, FromRow)]
struct User {
    id: Uuid,
    username: String,
    email: String,
    created_at: DateTime<Utc>,
}

/// JSON body for creating a user.
#[derive(Debug, Deserialize)]
struct CreateUserRequest {
    username: String,
    email: String,
}

// CREATE
async fn create_user(
    State(state): State<AppState>,
    Json(payload): Json<CreateUserRequest>,
) -> Result<(StatusCode, Json<User>), AppError> {
    let user = sqlx::query_as::<_, User>(
        r#"
        INSERT INTO users (id, username, email, created_at)
        VALUES ($1, $2, $3, $4)
        RETURNING id, username, email, created_at
        "#,
    )
    .bind(Uuid::new_v4())
    .bind(&payload.username)
    .bind(&payload.email)
    .bind(Utc::now())
    .fetch_one(&state.db)
    .await
    .map_err(|e| AppError::DatabaseError(e.to_string()))?;

    Ok((StatusCode::CREATED, Json(user)))
}

// READ (List)
async fn list_users(
    State(state): State<AppState>,
) -> Result<Json<Vec<User>>, AppError> {
    let users = sqlx::query_as::<_, User>("SELECT id, username, email, created_at FROM users ORDER BY created_at DESC")
        .fetch_all(&state.db)
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?;

    Ok(Json(users))
}

// READ (Single)
async fn get_user(
    State(state): State<AppState>,
    Path(id): Path<Uuid>,
) -> Result<Json<User>, AppError> {
    let user = sqlx::query_as::<_, User>(
        "SELECT id, username, email, created_at FROM users WHERE id = $1",
    )
    .bind(id)
    .fetch_optional(&state.db)
    .await
    .map_err(|e| AppError::DatabaseError(e.to_string()))?
    .ok_or_else(|| AppError::NotFound(format!("用户 {} 不存在", id)))?;

    Ok(Json(user))
}

// DELETE
async fn delete_user(
    State(state): State<AppState>,
    Path(id): Path<Uuid>,
) -> Result<StatusCode, AppError> {
    let result = sqlx::query("DELETE FROM users WHERE id = $1")
        .bind(id)
        .execute(&state.db)
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?;

    if result.rows_affected() == 0 {
        Err(AppError::NotFound(format!("用户 {} 不存在", id)))
    } else {
        Ok(StatusCode::NO_CONTENT)
    }
}

6.4 数据库迁移

SQL
-- migrations/001_create_users.sql
CREATE TABLE IF NOT EXISTS users (
    id UUID PRIMARY KEY,
    username VARCHAR(255) NOT NULL UNIQUE,
    email VARCHAR(255) NOT NULL UNIQUE,
    -- Nullable: the create-user INSERT in section 6.3 does not write this
    -- column, and NOT NULL without a default would make every insert fail.
    -- Store a password hash (e.g. Argon2/bcrypt) here once registration
    -- accepts passwords.
    password_hash VARCHAR(255),
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- No separate indexes on username/email: each UNIQUE constraint above
-- already creates a unique index on its column.

7. 认证与授权

7.1 JWT 认证

TOML
# Cargo.toml
[dependencies]
jsonwebtoken = "9"
Rust
use jsonwebtoken::{encode, decode, Header, Validation, EncodingKey, DecodingKey};
use serde::{Deserialize, Serialize};
use chrono::{Utc, Duration};

/// JWT payload. `exp` and `iat` hold Unix timestamps in seconds.
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
    sub: String,  // user id
    exp: usize,   // expiry time
    iat: usize,   // issued-at time
    role: String, // role
}

/// Signing secret and token lifetime for HS256 JWTs.
struct JwtConfig {
    secret: String,
    expiration_hours: i64,
}

impl JwtConfig {
    /// Reads the secret from `JWT_SECRET`; tokens last 24 hours.
    ///
    /// SECURITY: when `JWT_SECRET` is unset this falls back to the
    /// hard-coded `"dev-secret-key"`, and anyone who knows it can mint valid
    /// tokens. Always set `JWT_SECRET` outside local development.
    fn new() -> Self {
        let secret = match std::env::var("JWT_SECRET") {
            Ok(value) => value,
            Err(_) => "dev-secret-key".to_string(),
        };
        JwtConfig {
            secret,
            expiration_hours: 24,
        }
    }

    /// Signs a token for `user_id` with `role`, valid for `expiration_hours`.
    fn create_token(&self, user_id: &str, role: &str) -> Result<String, jsonwebtoken::errors::Error> {
        let issued_at = Utc::now();
        let expires_at = issued_at + Duration::hours(self.expiration_hours);
        let claims = Claims {
            sub: user_id.to_owned(),
            exp: expires_at.timestamp() as usize,
            iat: issued_at.timestamp() as usize,
            role: role.to_owned(),
        };

        let key = EncodingKey::from_secret(self.secret.as_bytes());
        encode(&Header::default(), &claims, &key)
    }

    /// Checks the signature and default validation rules, then returns the claims.
    fn verify_token(&self, token: &str) -> Result<Claims, jsonwebtoken::errors::Error> {
        let key = DecodingKey::from_secret(self.secret.as_bytes());
        decode::<Claims>(token, &key, &Validation::default()).map(|data| data.claims)
    }
}

7.2 认证提取器

Rust
use axum::{
    extract::FromRequestParts,
    http::{request::Parts, StatusCode, header},
    response::{IntoResponse, Response},
    Json,
};
use serde_json::json;

/// Caller identity produced by the `AuthUser` extractor from a verified JWT.
#[derive(Debug, Clone)]
pub struct AuthUser {
    /// The token's `sub` claim.
    pub user_id: String,
    /// The token's `role` claim.
    pub role: String,
}

// Axum 0.8+ supports native `async fn` in traits; `#[async_trait]` is not needed.
impl<S> FromRequestParts<S> for AuthUser
where
    S: Send + Sync,
{
    type Rejection = AuthError;

    /// Reads `Authorization: Bearer <jwt>` and verifies the token.
    ///
    /// A missing header, or one where `to_str()` fails, gives `MissingToken`.
    /// A non-Bearer scheme or a token that fails verification gives `InvalidToken`.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let header_value = parts
            .headers
            .get(header::AUTHORIZATION)
            .and_then(|v| v.to_str().ok())
            .ok_or(AuthError::MissingToken)?;

        let token = header_value
            .strip_prefix("Bearer ")
            .ok_or(AuthError::InvalidToken)?;

        // NOTE: the config (and therefore JWT_SECRET) is re-read on every request.
        let claims = JwtConfig::new()
            .verify_token(token)
            .map_err(|_| AuthError::InvalidToken)?;

        Ok(Self {
            user_id: claims.sub,
            role: claims.role,
        })
    }
}

/// Authentication/authorization failures, rendered as `{"error": "<message>"}`.
#[derive(Debug)]
pub enum AuthError {
    MissingToken,
    InvalidToken,
    InsufficientPermission,
}

impl IntoResponse for AuthError {
    fn into_response(self) -> Response {
        let status = match self {
            Self::MissingToken | Self::InvalidToken => StatusCode::UNAUTHORIZED,
            Self::InsufficientPermission => StatusCode::FORBIDDEN,
        };
        let msg = match self {
            Self::MissingToken => "缺少认证 Token",
            Self::InvalidToken => "无效的 Token",
            Self::InsufficientPermission => "权限不足",
        };
        (status, Json(json!({"error": msg}))).into_response()
    }
}

7.3 登录接口

Rust
use axum::{Json, http::StatusCode};
use serde::{Deserialize, Serialize};

/// Credentials posted to the login endpoint.
#[derive(Deserialize)]
struct LoginRequest {
    username: String,
    password: String,
}

/// Successful login payload: the token, its scheme, and its lifetime in seconds.
#[derive(Serialize)]
struct LoginResponse {
    token: String,
    token_type: String,
    expires_in: i64,
}

/// Exchanges valid credentials for a signed JWT.
///
/// Bad credentials give `AppError::Unauthorized`; a signing failure gives
/// `AppError::InternalError`.
async fn login(Json(payload): Json<LoginRequest>) -> Result<Json<LoginResponse>, AppError> {
    // Check the credentials (a real app would look up the user and verify a password hash).
    let (user_id, role) = verify_credentials(&payload.username, &payload.password)
        .await
        .ok_or(AppError::Unauthorized)?;

    let jwt_config = JwtConfig::new();
    let token = jwt_config
        .create_token(&user_id, &role)
        .map_err(|e| AppError::InternalError(e.to_string()))?;

    Ok(Json(LoginResponse {
        token,
        token_type: "Bearer".to_string(),
        // Derived from the same setting that sets the token's `exp` claim,
        // so the advertised lifetime cannot drift from the real one.
        expires_in: jwt_config.expiration_hours * 3600,
    }))
}

async fn verify_credentials(username: &str, password: &str) -> Option<(String, String)> {
    // 示例:实际应查询数据库
    if username == "admin" && password == "password" {
        Some(("user-1".to_string(), "admin".to_string()))
    } else {
        None
    }
}

// 受保护的路由使用 AuthUser 提取器
async fn protected(user: AuthUser) -> String {
    format!("欢迎,用户 {},角色: {}", user.user_id, user.role)
}

// 仅管理员访问
async fn admin_only(user: AuthUser) -> Result<String, AuthError> {
    if user.role != "admin" {
        return Err(AuthError::InsufficientPermission);
    }
    Ok("管理员面板".to_string())
}

8. WebSocket 支持

8.1 基本 WebSocket

Rust
use axum::{
    extract::ws::{Message, WebSocket, WebSocketUpgrade},
    response::Response,
    routing::get,
    Router,
};
use futures_util::{SinkExt, StreamExt};

/// Router exposing the echo WebSocket endpoint at `/ws`.
fn app() -> Router {
    let ws_route = get(ws_handler);
    Router::new().route("/ws", ws_route)
}

/// Completes the HTTP upgrade and passes the socket to `handle_socket`.
async fn ws_handler(ws: WebSocketUpgrade) -> Response {
    let on_connect = handle_socket;
    ws.on_upgrade(on_connect)
}

/// Echo loop for one WebSocket connection.
///
/// Sends a greeting, then echoes text frames (prefixed with "收到: ") and
/// binary frames until the client sends Close, the stream ends, or an I/O
/// error occurs.
async fn handle_socket(mut socket: WebSocket) {
    // axum 0.8: `Message::Text` wraps `Utf8Bytes`, so convert with `.into()`.
    if socket
        .send(Message::Text("欢迎连接 WebSocket!".into()))
        .await
        .is_err()
    {
        return; // client already disconnected
    }

    // Message loop; a receive error ends the connection.
    while let Some(Ok(msg)) = socket.recv().await {
        let reply = match msg {
            // Echo text back with a prefix.
            Message::Text(text) => Message::Text(format!("收到: {}", text.as_str()).into()),
            // Echo binary data unchanged.
            Message::Binary(data) => Message::Binary(data),
            Message::Close(_) => return,
            // Ping frames are answered automatically by axum's WebSocket
            // implementation; replying with our own Pong would duplicate it.
            _ => continue,
        };

        if socket.send(reply).await.is_err() {
            return;
        }
    }
}

8.2 WebSocket 广播(聊天室)

Rust
use axum::{
    extract::{
        ws::{Message, WebSocket, WebSocketUpgrade},
        State,
    },
    response::Response,
    routing::get,
    Router,
};
use futures_util::{SinkExt, StreamExt};
use std::sync::Arc;
use tokio::sync::broadcast;

/// Chat state: the sending half of the broadcast channel. Each connection
/// creates its own receiver with `subscribe()`.
#[derive(Clone)]
struct ChatState {
    tx: broadcast::Sender<String>,
}

/// Builds the chat router around a broadcast channel of capacity 100.
fn app() -> Router {
    // The initial receiver is not needed; connections subscribe on demand.
    let (tx, _rx) = broadcast::channel(100);

    Router::new()
        .route("/chat", get(chat_handler))
        .with_state(ChatState { tx })
}

/// Upgrades the request and runs `handle_chat` with a clone of the shared state.
async fn chat_handler(
    ws: WebSocketUpgrade,
    State(state): State<ChatState>,
) -> Response {
    let session = move |socket| handle_chat(socket, state);
    ws.on_upgrade(session)
}

async fn handle_chat(socket: WebSocket, state: ChatState) {
    let (mut sender, mut receiver) = socket.split();

    // 订阅广播频道
    let mut rx = state.tx.subscribe();

    // 任务1:接收广播消息并发送给客户端
    let mut send_task = tokio::spawn(async move {
        while let Ok(msg) = rx.recv().await {
            if sender.send(Message::Text(msg)).await.is_err() {
                break;
            }
        }
    });

    // 任务2:接收客户端消息并广播
    let tx = state.tx.clone();
    let mut recv_task = tokio::spawn(async move {
        while let Some(Ok(Message::Text(text))) = receiver.next().await {
            let _ = tx.send(text);
        }
    });

    // 等待任一任务结束
    tokio::select! {
        _ = &mut send_task => recv_task.abort(),
        _ = &mut recv_task => send_task.abort(),
    }
}

9. 测试

9.1 集成测试

Rust
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt; // provides `oneshot`

    /// Router under test: a plain-text root, a typed path parameter and a JSON endpoint.
    fn test_app() -> Router {
        let user = get(|Path(id): Path<u64>| async move {
            format!("User {}", id)
        });
        let status = get(|| async {
            Json(serde_json::json!({"status": "ok"}))
        });

        Router::new()
            .route("/", get(|| async { "Hello" }))
            .route("/users/{id}", user)
            .route("/json", status)
    }

    /// Sends a body-less GET for `uri` to a fresh `test_app()` and returns the response.
    async fn send_get(uri: &str) -> axum::response::Response {
        let request = Request::builder()
            .uri(uri)
            .body(Body::empty())
            .unwrap();
        test_app().oneshot(request).await.unwrap()
    }

    /// Buffers the whole response body into memory.
    async fn body_bytes(response: axum::response::Response) -> axum::body::Bytes {
        response.into_body().collect().await.unwrap().to_bytes()
    }

    #[tokio::test]
    async fn test_root() {
        let response = send_get("/").await;
        assert_eq!(response.status(), StatusCode::OK);

        let body = body_bytes(response).await;
        assert_eq!(&body[..], b"Hello");
    }

    #[tokio::test]
    async fn test_user_path() {
        let response = send_get("/users/42").await;
        assert_eq!(response.status(), StatusCode::OK);

        let body = body_bytes(response).await;
        assert_eq!(&body[..], b"User 42");
    }

    #[tokio::test]
    async fn test_json_response() {
        let response = send_get("/json").await;
        assert_eq!(response.status(), StatusCode::OK);

        let body = body_bytes(response).await;
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["status"], "ok");
    }

    #[tokio::test]
    async fn test_not_found() {
        let response = send_get("/nonexistent").await;
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }
}

9.2 测试 POST 请求

Rust
#[cfg(test)]
mod post_tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode, header},
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    #[tokio::test]
    async fn test_create_user() {
        /// Stub create handler: replies 201 Created with id 1 and the posted `name`
        /// (or "unknown" when it is absent or not a string).
        async fn create_stub(
            Json(payload): Json<serde_json::Value>,
        ) -> (StatusCode, Json<serde_json::Value>) {
            let name = payload
                .get("name")
                .and_then(serde_json::Value::as_str)
                .unwrap_or("unknown");
            let created = serde_json::json!({
                "id": 1,
                "name": name
            });
            (StatusCode::CREATED, Json(created))
        }

        let request = Request::builder()
            .method("POST")
            .uri("/users")
            .header(header::CONTENT_TYPE, "application/json")
            .body(Body::from(r#"{"name": "Alice"}"#))
            .unwrap();

        let response = Router::new()
            .route("/users", post(create_stub))
            .oneshot(request)
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::CREATED);

        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(json["name"], "Alice");
    }
}

9.3 使用 TestClient(辅助工具)

Rust
// 自定义测试辅助模块
#[cfg(test)]
mod test_helpers {
    use axum::{body::Body, Router};
    use http_body_util::BodyExt;
    use axum::http::{Request, Response, StatusCode, header};
    use tower::ServiceExt;

    /// Drives a `Router` in-process, one request per call, without opening a socket.
    pub struct TestClient {
        app: Router,
    }

    impl TestClient {
        /// Wraps `app`. The router is cloned for every request, so the client is reusable.
        pub fn new(app: Router) -> Self {
            Self { app }
        }

        /// Sends a body-less GET request to `uri`.
        pub async fn get(&self, uri: &str) -> TestResponse {
            let request = Request::builder()
                .uri(uri)
                .body(Body::empty())
                .unwrap();
            self.dispatch(request).await
        }

        /// Sends `body`, serialized as JSON, in a POST request to `uri`.
        pub async fn post_json(&self, uri: &str, body: serde_json::Value) -> TestResponse {
            let payload = serde_json::to_string(&body).unwrap();
            let request = Request::builder()
                .method("POST")
                .uri(uri)
                .header(header::CONTENT_TYPE, "application/json")
                .body(Body::from(payload))
                .unwrap();
            self.dispatch(request).await
        }

        /// Runs one request through a fresh clone of the router. Panics on failure,
        /// which is acceptable for a test helper.
        async fn dispatch(&self, request: Request<Body>) -> TestResponse {
            let router = self.app.clone();
            TestResponse(router.oneshot(request).await.unwrap())
        }
    }

    /// Response wrapper with convenience readers for the status and the body.
    pub struct TestResponse(Response<Body>);

    impl TestResponse {
        /// HTTP status of the response.
        pub fn status(&self) -> StatusCode {
            self.0.status()
        }

        /// Parses the body as JSON, panicking if it is not valid JSON.
        pub async fn json(self) -> serde_json::Value {
            let bytes = self.collect_body().await;
            serde_json::from_slice(&bytes).unwrap()
        }

        /// Returns the body as a UTF-8 string, panicking on invalid UTF-8.
        pub async fn text(self) -> String {
            let bytes = self.collect_body().await;
            String::from_utf8(bytes.to_vec()).unwrap()
        }

        /// Buffers the whole body into memory.
        async fn collect_body(self) -> axum::body::Bytes {
            self.0.into_body().collect().await.unwrap().to_bytes()
        }
    }
}

10. 部署

10.1 Docker 多阶段构建

Docker
# Dockerfile

# Stage 1: build
FROM rust:1.85-bookworm AS builder

WORKDIR /app

# Copy only the manifests first and build once with a placeholder main.rs.
# Compiled dependencies then live in their own cached layer, which is rebuilt
# only when Cargo.toml / Cargo.lock change.
COPY Cargo.toml Cargo.lock ./
RUN mkdir src && echo "fn main() {}" > src/main.rs
RUN cargo build --release
RUN rm -rf src

# Copy the real sources and build again.
# The migrations directory is presumably needed at compile time by
# sqlx::migrate!("./migrations") in main.rs.
COPY src/ src/
COPY migrations/ migrations/
RUN touch src/main.rs  # bump the timestamp so cargo rebuilds instead of reusing the placeholder build
RUN cargo build --release

# Stage 2: runtime (slim image)
FROM debian:bookworm-slim

# CA certificates so outbound TLS connections can be verified.
RUN apt-get update && apt-get install -y \
    ca-certificates \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Copy only the compiled binary from the build stage.
COPY --from=builder /app/target/release/axum-demo /app/axum-demo

# Run as an unprivileged system user instead of root.
RUN useradd -r -s /bin/false appuser
USER appuser

EXPOSE 3000

ENV RUST_LOG=info

CMD ["./axum-demo"]

10.2 docker-compose

YAML
# docker-compose.yml
# Note: the top-level `version` key is obsolete in Docker Compose V2, so it is omitted.

services:
  app:
    build: .
    ports:
      - "3000:3000"
    environment:
      # Demo values only: supply real credentials and the JWT secret from outside
      # version control (e.g. an env file or a secrets store).
      - DATABASE_URL=postgres://user:password@db:5432/myapp
      - JWT_SECRET=your-secret-key
      - RUST_LOG=info
    depends_on:
      db:
        # Start the app only after the db healthcheck below passes,
        # not merely after the container has started.
        condition: service_healthy

  db:
    image: postgres:16-alpine
    environment:
      POSTGRES_USER: user
      POSTGRES_PASSWORD: password
      POSTGRES_DB: myapp
    volumes:
      # Named volume so the data survives container re-creation.
      - pgdata:/var/lib/postgresql/data
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U user -d myapp"]
      interval: 5s
      timeout: 5s
      retries: 5

volumes:
  pgdata:

10.3 静态编译(musl)

Docker
# Fully static build (no system library dependencies)
FROM rust:1.85-bookworm AS builder

# musl-tools supplies the musl C toolchain used when linking for the musl target.
RUN apt-get update && apt-get install -y musl-tools
RUN rustup target add x86_64-unknown-linux-musl

WORKDIR /app
# NOTE(review): this copies the whole context in one layer, so any source change
# rebuilds all dependencies. The placeholder-main caching trick from the
# multi-stage Dockerfile could be reused here.
COPY . .

RUN cargo build --release --target x86_64-unknown-linux-musl

# Minimal runtime image: an empty base with only the binary (and CA bundle) added.
# No USER instruction is set, so the process runs as root inside the container.
FROM scratch

COPY --from=builder /app/target/x86_64-unknown-linux-musl/release/axum-demo /axum-demo

# Copy the CA bundle if the app makes outbound HTTPS requests
# (scratch contains no certificates).
COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/

EXPOSE 3000
ENTRYPOINT ["/axum-demo"]

10.4 优雅关闭

Rust
use axum::{routing::get, Router};
use tokio::signal;

/// Serves two trivial routes on port 3000 and stops gracefully once `shutdown_signal` resolves.
#[tokio::main]
async fn main() {
    tracing_subscriber::fmt::init();

    // A greeting route and a liveness probe.
    let root = get(|| async { "Hello" });
    let health = get(|| async { "OK" });
    let app = Router::new().route("/", root).route("/health", health);

    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    tracing::info!("服务器启动在 http://0.0.0.0:3000");

    // Serve until the shutdown future completes.
    let server = axum::serve(listener, app).with_graceful_shutdown(shutdown_signal());
    server.await.unwrap();

    tracing::info!("服务器已优雅关闭");
}

/// Completes on Ctrl+C or, on Unix, on SIGTERM, logging which one arrived.
///
/// Panics if a signal handler cannot be registered.
async fn shutdown_signal() {
    let interrupt = async {
        signal::ctrl_c().await.expect("无法注册 Ctrl+C 处理器");
    };

    // SIGTERM is only available on Unix; elsewhere this branch never completes.
    #[cfg(unix)]
    let sigterm = async {
        let mut stream = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("无法注册 SIGTERM 处理器");
        stream.recv().await;
    };

    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();

    tokio::select! {
        () = interrupt => tracing::info!("收到 Ctrl+C"),
        () = sigterm => tracing::info!("收到 SIGTERM"),
    }
}

11. 完整 REST API 项目结构

11.1 推荐目录结构

Text Only
my-api/
├── Cargo.toml
├── Dockerfile
├── docker-compose.yml
├── .env
├── migrations/
│   ├── 001_create_users.sql
│   └── 002_create_posts.sql
├── src/
│   ├── main.rs            # 入口:服务器启动
│   ├── config.rs           # 配置管理
│   ├── error.rs            # 统一错误类型
│   ├── db.rs               # 数据库连接
│   ├── routes/
│   │   ├── mod.rs          # 路由聚合
│   │   ├── users.rs        # 用户路由
│   │   └── posts.rs        # 帖子路由
│   ├── handlers/
│   │   ├── mod.rs
│   │   ├── users.rs        # 用户 Handler
│   │   └── posts.rs        # 帖子 Handler
│   ├── models/
│   │   ├── mod.rs
│   │   ├── user.rs         # 用户模型
│   │   └── post.rs         # 帖子模型
│   ├── middleware/
│   │   ├── mod.rs
│   │   ├── auth.rs         # 认证中间件
│   │   └── logging.rs      # 日志中间件
│   └── extractors/
│       ├── mod.rs
│       └── auth.rs          # 认证提取器
└── tests/
    └── api/
        ├── mod.rs
        ├── users_test.rs
        └── posts_test.rs

11.2 配置管理

Rust
// src/config.rs
use serde::Deserialize;

/// Runtime configuration, deserialized from environment variables by `Config::from_env`.
///
/// `Debug` is implemented by hand so that secrets never end up in logs or panic messages:
/// `database_url` (which typically embeds the DB password) and `jwt_secret` are redacted.
#[derive(Clone, Deserialize)]
pub struct Config {
    /// Database connection string, e.g. `postgres://user:password@db:5432/myapp`.
    pub database_url: String,
    /// Secret used by the JWT authentication code.
    pub jwt_secret: String,
    /// Host/interface to bind; combined with `server_port` by `server_addr`.
    pub server_host: String,
    /// TCP port to listen on.
    pub server_port: u16,
    /// Log filter directive (e.g. `info`) handed to the tracing subscriber at startup.
    pub rust_log: String,
}

impl std::fmt::Debug for Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Config")
            .field("database_url", &"<redacted>")
            .field("jwt_secret", &"<redacted>")
            .field("server_host", &self.server_host)
            .field("server_port", &self.server_port)
            .field("rust_log", &self.rust_log)
            .finish()
    }
}

impl Config {
    /// Loads `.env` (if present) into the process environment, then deserializes
    /// a `Config` from the environment variables.
    pub fn from_env() -> Result<Self, envy::Error> {
        // A missing .env file is not an error: real environment variables still apply.
        let _ = dotenvy::dotenv();
        envy::from_env::<Self>()
    }

    /// Returns the bind address in `host:port` form.
    pub fn server_addr(&self) -> String {
        let Self { server_host, server_port, .. } = self;
        format!("{server_host}:{server_port}")
    }
}

11.3 主入口

Rust
// src/main.rs
mod config;
mod db;
mod error;
mod routes;
mod handlers;
mod models;
mod middleware;
mod extractors;

use config::Config;
use axum::Router;
use tower_http::trace::TraceLayer;

/// State shared with handlers through axum's `State` extractor (hence `Clone`).
#[derive(Clone)]
pub struct AppState {
    /// Postgres connection pool, created in `main` via `db::create_pool`.
    pub db: sqlx::PgPool,
    /// Configuration loaded from the environment at startup.
    pub config: Config,
}

/// Application entry point.
///
/// Loads the config, sets up logging, connects to the database and runs pending
/// migrations, then serves the API until `shutdown_signal` resolves.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let config = Config::from_env()?;

    // NOTE: `with_env_filter` requires tracing-subscriber's `env-filter` feature.
    tracing_subscriber::fmt()
        .with_env_filter(&config.rust_log)
        .init();

    // Connect and bring the schema up to date before accepting traffic.
    let pool = db::create_pool(&config.database_url).await?;
    sqlx::migrate!("./migrations").run(&pool).await?;

    // Compute the bind address first so `config` can be moved into the state without a clone.
    let addr = config.server_addr();
    let state = AppState { db: pool, config };

    let app = Router::new()
        .merge(routes::all_routes())
        .layer(TraceLayer::new_for_http())
        .with_state(state);

    let listener = tokio::net::TcpListener::bind(&addr).await?;
    tracing::info!("服务器启动在 http://{}", addr);

    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await?;
    Ok(())
}

async fn shutdown_signal() {
    tokio::signal::ctrl_c().await.ok();
    tracing::info!("正在关闭服务器...");
}

11.4 路由模块

Rust
// src/routes/mod.rs
use axum::{
    routing::{get, post, put, delete},
    Router,
    middleware,
};
use crate::{AppState, handlers, middleware as mw};

/// Combines the public and the authenticated route groups into one router.
pub fn all_routes() -> Router<AppState> {
    let public = public_routes();
    let protected = protected_routes();
    Router::new().merge(public).merge(protected)
}

/// Routes reachable without authentication: health check, login and registration.
fn public_routes() -> Router<AppState> {
    /// Liveness probe.
    async fn health() -> &'static str {
        "OK"
    }

    Router::new()
        .route("/health", get(health))
        .route("/api/login", post(handlers::users::login))
        .route("/api/register", post(handlers::users::register))
}

/// User and post CRUD routes, wrapped in the `mw::auth::require_auth` middleware.
fn protected_routes() -> Router<AppState> {
    // Method routers for the collection and item endpoints.
    let user_item = get(handlers::users::get_by_id)
        .put(handlers::users::update)
        .delete(handlers::users::delete);
    let post_collection = get(handlers::posts::list).post(handlers::posts::create);
    let post_item = get(handlers::posts::get_by_id)
        .put(handlers::posts::update)
        .delete(handlers::posts::delete);

    Router::new()
        .route("/api/users", get(handlers::users::list))
        .route("/api/users/{id}", user_item)
        .route("/api/posts", post_collection)
        .route("/api/posts/{id}", post_item)
        .layer(middleware::from_fn(mw::auth::require_auth))
}

12. 面试题

Q1: 为什么选择 Axum?它与 Actix-web、Rocket 有什么区别?

A1:

Rust
// Core advantages of Axum (illustrative snippet: the `let` below belongs inside a function).

// 1. Built on the Tower ecosystem: reuses a large set of ready-made middleware.
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
use tower_http::compression::CompressionLayer;
// These layers work with any Tower-compatible framework, not just Axum.

// 2. Macro-free routing: routes are plain method calls, with no procedural macros.
let app = Router::new()
    .route("/users", get(list_users).post(create_user));
// Compare Rocket: routes are declared with attribute macros such as #[get("/users")].
// Compare Actix-web: #[get("/users")] is common but optional; routes can also be
// registered without macros, e.g. App::new().route("/users", web::get().to(handler)).

// 3. Extractor pattern: type-driven request parsing.
async fn handler(
    Path(id): Path<u64>,       // extracted from the URL path
    Query(q): Query<Params>,   // extracted from the query string
    Json(body): Json<Body>,    // extracted from the request body
) -> impl IntoResponse {
    // Parameter order is NOT free. Extractors that only read request parts (Path, Query,
    // headers, State) may appear in any order, but the one that consumes the body
    // (Json here) must be the last parameter.
    todo!()
}

// 4. Maintained by the Tokio team, with tight integration with the Tokio runtime.
use axum::extract::{Path, Query, Json};
use axum::response::IntoResponse;

// Placeholder types for illustration only. Query/Json need types that implement
// serde::Deserialize, so real code would derive it on these.
struct Params;
struct Body;

async fn list_users() {}
async fn create_user() {}
特性 Axum Actix-web Rocket
异步运行时 Tokio Tokio(多个单线程运行时) Tokio
宏依赖 不需要 可选(也支持 web::get().to(handler) 形式) 需要
中间件生态 Tower 自有 自有
维护团队 Tokio 团队 社区 个人
性能 极高 极高 高

Q2: 解释 Axum 的 Extractor 模式如何工作?

A2:

Rust
use axum::{
    extract::{FromRequest, FromRequestParts, Request},
    http::request::Parts,
    async_trait,
};

// Extractor 的核心是两个 Trait:

// 1. FromRequestParts - 从请求的"部件"(头、URL等)提取,不消耗 body
// Path, Query, Headers, State 都实现了这个
#[async_trait]
trait FromRequestPartsExample<S>: Sized {
    type Rejection; // 提取失败时的错误类型
    async fn from_request_parts(
        parts: &mut Parts,
        state: &S,
    ) -> Result<Self, Self::Rejection>;
}

// 2. FromRequest - 从完整请求提取,可以消耗 body
// Json, String, Bytes 实现了这个
#[async_trait]
trait FromRequestExample<S>: Sized {
    type Rejection;
    async fn from_request(
        req: Request,
        state: &S,
    ) -> Result<Self, Self::Rejection>;
}

// Handler 的参数列表就是一系列 Extractor
// Axum 按顺序调用它们,最后一个可以消耗 body
async fn handler(
    // FromRequestParts extractors (不消耗body,可以有多个)
    axum::extract::Path(id): axum::extract::Path<u64>,
    axum::extract::Query(q): axum::extract::Query<std::collections::HashMap<String, String>>,
    headers: axum::http::HeaderMap,
    // FromRequest extractor (消耗body,只能有一个,必须放最后)
    axum::Json(body): axum::Json<serde_json::Value>,
) -> String {
    format!("id={}, query={:?}, body={}", id, q, body)
}

Q3: Axum 中如何实现统一错误处理?

A3:

Rust
use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
    Json,
};
use serde_json::json;

// 核心:实现 IntoResponse trait
/// Application error type. Its `IntoResponse` impl lets handlers return it directly.
#[derive(Debug)]
enum AppError {
    /// Rendered as 404 Not Found with the message in the JSON body.
    NotFound(String),
    /// Rendered as 400 Bad Request with the message in the JSON body.
    BadRequest(String),
    /// Rendered as 500; the underlying error is only logged, never sent to the client.
    Internal(anyhow::Error),
}

impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        let (status, message) = match self {
            AppError::NotFound(msg) => (StatusCode::NOT_FOUND, msg),
            AppError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg),
            AppError::Internal(err) => {
                // 不暴露内部错误给客户端
                tracing::error!("内部错误: {:?}", err);
                (StatusCode::INTERNAL_SERVER_ERROR, "服务器错误".into())
            }
        };

        (status, Json(json!({"error": message}))).into_response()
    }
}

// 实现 From<T> 使 ? 操作符自动转换
impl From<sqlx::Error> for AppError {
    fn from(err: sqlx::Error) -> Self {
        match err {
            sqlx::Error::RowNotFound => AppError::NotFound("记录不存在".into()),
            _ => AppError::Internal(err.into()),
        }
    }
}

// Handler 中使用 Result<T, AppError>
/// Example handler returning `Result<_, AppError>`.
async fn handler() -> Result<Json<serde_json::Value>, AppError> {
    // With the `From<sqlx::Error>` impl, `?` converts database errors automatically, e.g.
    // let user = sqlx::query_as("SELECT ...").fetch_one(&db).await?;
    let payload = json!({"status": "ok"});
    Ok(Json(payload))
}

Q4: 如何在 Axum 中管理共享状态?

A4:

Rust
use axum::{extract::State, Router, routing::get};
use std::sync::Arc;
use tokio::sync::RwLock;

// 方案1:State 提取器(推荐)
/// State handed to handlers through the `State` extractor; it must implement `Clone`.
#[derive(Clone)]
struct AppState {
    // Immutable configuration: stored directly.
    db_url: String,
    // Mutable state: Arc + lock, so every clone of the state shares the same counter.
    counter: Arc<RwLock<u64>>,
    // Connection pools are already reference-counted internally (cloning shares them):
    // db: sqlx::PgPool,
}

/// Returns the current counter value formatted as `count: N`.
async fn handler(State(state): State<AppState>) -> String {
    let guard = state.counter.read().await;
    let current: u64 = *guard;
    format!("count: {}", current)
}

/// Builds the router with a shared `AppState`.
///
/// `State` requires the state type to implement `Clone`. Large state should be wrapped in
/// `Arc` so that clones stay cheap instead of deep-copying.
fn app() -> Router {
    let counter = Arc::new(RwLock::new(0));
    let state = AppState {
        db_url: "postgres://...".into(),
        counter,
    };
    Router::new().route("/", get(handler)).with_state(state)
}

// 方案2:Extension(不推荐,已弱化)
// .layer(Extension(shared_data))

// 关键区别:
// State - 编译时类型检查,如果漏了 with_state 会编译报错
// Extension - 运行时检查,漏加 layer 不会编译报错,而是请求时被拒绝(返回 500 错误)

Q5: Axum 中间件的执行顺序是怎样的?

A5:

Rust
use axum::{Router, routing::get, middleware};

// Layers follow the "onion model": each layer wraps everything inside it.
// With Router::layer, every call wraps everything added BEFORE it,
// so the layer added LAST ends up outermost and runs first.
// Requests travel from the outside in; responses travel from the inside out.

let app = Router::new()
    .route("/", get(handler))
    .layer(layer_a)  // innermost (closest to the handler)
    .layer(layer_b)  // middle
    .layer(layer_c); // outermost (added last, sees the request first)

// Request flow:  layer_c → layer_b → layer_a → handler
// Response flow: handler → layer_a → layer_b → layer_c

// ServiceBuilder composes in top-to-bottom order: the first layer listed is outermost.
use tower::ServiceBuilder;

let app = Router::new()
    .route("/", get(handler))
    .layer(
        ServiceBuilder::new()
            .layer(layer_a)  // outermost: runs first
            .layer(layer_b)  // runs second
            .layer(layer_c)  // runs last, closest to the handler
    );

// Request flow: layer_a → layer_b → layer_c → handler
// With ServiceBuilder the written order matches the execution order, which is more intuitive.

async fn handler() -> &'static str { "Hello" }
// Note: the code above is illustrative. layer_a/b/c must be real Layer instances,
// and the `let` statements belong inside a function.

Q6: Axum 如何实现 WebSocket?与 HTTP 有什么不同?

A6:

Rust
use axum::{
    extract::ws::{Message, WebSocket, WebSocketUpgrade},
    response::Response,
    routing::get,
    Router,
};
use futures_util::{SinkExt, StreamExt};

// WebSocket 是 HTTP 升级协议
// 1. 客户端发送 HTTP GET + Upgrade: websocket 头
// 2. 服务器返回 101 Switching Protocols
// 3. 连接升级为全双工 WebSocket

/// Handles the HTTP upgrade handshake and returns the upgrade response.
/// `handle_socket` runs once the connection has been upgraded.
async fn ws_handler(upgrade: WebSocketUpgrade) -> Response {
    upgrade.on_upgrade(handle_socket)
}

/// Full-duplex demo: one task prints incoming text frames while this task sends a
/// heartbeat Ping every 30 seconds.
///
/// The connection ends when the reader finishes (client closed or a receive error) or a
/// Ping cannot be sent. The reader task is aborted on exit so it never outlives the
/// connection. Previously the heartbeat only noticed a closed client at its next Ping,
/// up to 30s later, and the reader task was never stopped.
async fn handle_socket(socket: WebSocket) {
    // Split into independent sink (tx) and stream (rx) halves.
    let (mut tx, mut rx) = socket.split();

    // Reading happens in its own task, concurrently with the heartbeat below.
    let mut reader = tokio::spawn(async move {
        while let Some(Ok(msg)) = rx.next().await {
            match msg {
                // `as_str()` works for both String (axum 0.7) and Utf8Bytes (axum 0.8).
                Message::Text(t) => println!("收到: {}", t.as_str()),
                Message::Close(_) => break,
                _ => {}
            }
        }
    });

    // Heartbeat: an empty Ping now, then every 30 seconds, until the reader finishes.
    loop {
        // `Default::default()` is an empty payload for both Vec<u8> (0.7) and Bytes (0.8).
        if tx.send(Message::Ping(Default::default())).await.is_err() {
            break;
        }
        tokio::select! {
            _ = &mut reader => break,
            _ = tokio::time::sleep(std::time::Duration::from_secs(30)) => {}
        }
    }

    // No-op if the reader has already finished.
    reader.abort();
}

// 关键区别:
// HTTP: 请求-响应模式,无状态
// WebSocket: 全双工,有状态,长连接

✅ 学习检查清单

Axum 基础

  • 能搭建基本的 Axum 项目
  • 理解 Axum/Tower/Hyper/Tokio 的层次关系
  • 掌握路由定义(方法路由、嵌套路由、通配符)

Handler 与 Extractor

  • 熟练使用 Path / Query / Json / State 提取器
  • 理解 FromRequest 与 FromRequestParts 的区别
  • 能实现自定义 Extractor

中间件

  • 理解 Tower Layer / Service 抽象
  • 能使用 tower-http 常用中间件
  • 能编写自定义中间件函数
  • 理解中间件的执行顺序(洋葱模型)

错误处理

  • 实现统一的 AppError 类型
  • 正确实现 IntoResponse trait
  • 熟练使用 ? 操作符配合 From 转换

数据库

  • 掌握 SQLx 连接池配置
  • 能实现完整的 CRUD 操作
  • 了解数据库迁移

认证

  • 实现 JWT 认证流程
  • 编写认证 Extractor
  • 区分公开路由和受保护路由

WebSocket

  • 理解 WebSocket 协议升级过程
  • 实现基本的 WebSocket 通信
  • 能实现广播/聊天室模式

测试与部署

  • 编写集成测试(tower::ServiceExt::oneshot)
  • 编写 Dockerfile(多阶段构建)
  • 实现优雅关闭

📌 核心要点: Axum 的设计哲学是"组合优于继承"——通过 Tower 的 Service/Layer 抽象和类型驱动的 Extractor 模式,用编译期类型检查取代运行时错误,让 Web 开发既安全又高效。