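Axum Web Development
📋 Chapter Info
| Attribute | Details |
|---|---|
| Study time | 8-10 hours |
| Difficulty | ⭐⭐⭐⭐ Upper-intermediate |
| Prerequisites | Async programming (async/await), the trait system, error handling, serialization (serde) |
| Core topics | The Axum framework, routing, handlers, middleware, databases, authentication, deployment |
1. Axum Framework Overview and Architecture
1.1 What Is Axum
Axum is a Rust web framework developed by the Tokio team. Its core characteristics are:
- Built on the Tower ecosystem: it reuses Tower's Service / Layer abstractions
- Built on Hyper: Hyper provides the HTTP implementation
- Built on Tokio: Tokio provides the async runtime
- Macro-free routing: routes are built from ordinary types and combinators rather than procedural macros
- Extractor pattern: request data is extracted automatically, driven by the types of a handler's arguments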
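A toy model in plain Rust, with no axum dependency, makes the extractor idea concrete. It is a sketch, not axum's implementation. `ToyRequest`, `FromToyRequest`, `PathId`, `UserAgentHeader` and `call_handler` are hypothetical names. The handler only declares the argument types it needs. The "framework" builds each argument from the request before calling the handler, and rejects the request if any extraction fails.

```rust
use std::collections::HashMap;

// Hypothetical, simplified request: only what the toy extractors need.
struct ToyRequest {
    path_params: HashMap<String, String>,
    headers: HashMap<String, String>,
}

// Similar in spirit to axum's FromRequestParts: build Self from the request or reject.
trait FromToyRequest: Sized {
    fn from_request(req: &ToyRequest) -> Result<Self, String>;
}

struct PathId(u64);
struct UserAgentHeader(String);

impl FromToyRequest for PathId {
    fn from_request(req: &ToyRequest) -> Result<Self, String> {
        let raw = req.path_params.get("id").ok_or("missing path param `id`")?;
        raw.parse().map(PathId).map_err(|_| format!("`{raw}` is not a u64"))
    }
}

impl FromToyRequest for UserAgentHeader {
    fn from_request(req: &ToyRequest) -> Result<Self, String> {
        req.headers
            .get("user-agent")
            .cloned()
            .map(UserAgentHeader)
            .ok_or_else(|| "missing User-Agent header".to_string())
    }
}

// The "framework" side: run every extractor first, call the handler only if all succeed.
fn call_handler<A, B, R>(req: &ToyRequest, handler: impl Fn(A, B) -> R) -> Result<R, String>
where
    A: FromToyRequest,
    B: FromToyRequest,
{
    let a = A::from_request(req)?;
    let b = B::from_request(req)?;
    Ok(handler(a, b))
}

// A handler simply declares which pieces of the request it wants, by type.
fn show_user(PathId(id): PathId, UserAgentHeader(ua): UserAgentHeader) -> String {
    format!("user {id} via {ua}")
}
```

Axum applies the same idea through its `FromRequestParts` / `FromRequest` traits. It does so asynchronously and for handlers with many arguments, so this two-argument version is only an illustration.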
1.2 Architecture Layers
Text Only
┌────────────────────────────┐
│ Axum (routing + handlers)  │
├────────────────────────────┤
│ Tower (middleware/Service) │
├────────────────────────────┤
│ Hyper (HTTP protocol)      │
├────────────────────────────┤
│ Tokio (async runtime)      │
└────────────────────────────┘
1.3 Quick Start
TOML
# Cargo.toml
[package]
name = "axum-demo"
version = "0.1.0"
edition = "2024"
[dependencies]
axum = "0.8"
tokio = { version = "1", features = ["full"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tower = { version = "0.5", features = ["util"] }
tower-http = { version = "0.6", features = ["cors", "trace", "compression-gzip", "timeout"] }
tracing = "0.1"
tracing-subscriber = "0.3"
Rust
use axum::{extract::Path, routing::get, Router};
#[tokio::main]
async fn main() {
    // Initialize logging
    tracing_subscriber::fmt::init();
    // Build the router
    let app = Router::new()
        .route("/", get(root))
        .route("/hello/{name}", get(hello));
    // Start the server
    // ⚠️ unwrap() keeps the tutorial short; production code should use ? or expect() with proper error handling
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    tracing::info!("Server listening on http://localhost:3000");
    axum::serve(listener, app).await.unwrap();
}
async fn root() -> &'static str {
    "Hello, Axum!"
}
async fn hello(Path(name): Path<String>) -> String {
    format!("Hello, {}!", name)
}
2. Routing
2.1 Basic Routes
Rust
use axum::{
    routing::{get, post},
    Router,
};
fn app() -> Router {
    Router::new()
        // Several HTTP methods on the same path (.put/.delete are MethodRouter methods)
        .route("/users", get(list_users).post(create_user))
        .route("/users/{id}", get(get_user).put(update_user).delete(delete_user))
        // Single-method routes
        .route("/health", get(health_check))
        .route("/upload", post(upload_file))
}
async fn list_users() -> &'static str { "List all users" }
async fn create_user() -> &'static str { "Create a user" }
async fn get_user() -> &'static str { "Get a user" }
async fn update_user() -> &'static str { "Update a user" }
async fn delete_user() -> &'static str { "Delete a user" }
async fn health_check() -> &'static str { "OK" }
async fn upload_file() -> &'static str { "Upload succeeded" }
2.2 Nested Routes
Rust
use axum::routing::get;
use axum::Router;
fn app() -> Router {
    Router::new()
        .nest("/api", api_routes())
        .nest("/admin", admin_routes())
        .route("/", get(index))
}
fn api_routes() -> Router {
    Router::new()
        .nest("/v1", v1_routes())
        .nest("/v2", v2_routes())
}
fn v1_routes() -> Router {
    Router::new()
        .route("/users", get(list_users_v1))
        .route("/posts", get(list_posts_v1))
}
fn v2_routes() -> Router {
    Router::new()
        .route("/users", get(list_users_v2))
        .route("/posts", get(list_posts_v2))
}
fn admin_routes() -> Router {
    Router::new()
        .route("/dashboard", get(dashboard))
        .route("/settings", get(settings).post(update_settings))
}
// Resulting routes:
// GET  /                -> index
// GET  /api/v1/users    -> list_users_v1
// GET  /api/v1/posts    -> list_posts_v1
// GET  /api/v2/users    -> list_users_v2
// GET  /api/v2/posts    -> list_posts_v2
// GET  /admin/dashboard -> dashboard
// GET  /admin/settings  -> settings
// POST /admin/settings  -> update_settings
async fn index() -> &'static str { "Home" }
async fn list_users_v1() -> &'static str { "v1 users" }
async fn list_posts_v1() -> &'static str { "v1 posts" }
async fn list_users_v2() -> &'static str { "v2 users" }
async fn list_posts_v2() -> &'static str { "v2 posts" }
async fn dashboard() -> &'static str { "Admin dashboard" }
async fn settings() -> &'static str { "Settings" }
async fn update_settings() -> &'static str { "Settings updated" }
2.3 Route Parameters
Rust
use axum::{
    extract::Path,
    routing::get,
    Router,
};
fn app() -> Router {
    Router::new()
        // A single parameter
        .route("/users/{id}", get(get_user))
        // Multiple parameters
        .route("/users/{user_id}/posts/{post_id}", get(get_user_post))
        // Wildcard that matches the rest of the path: Axum 0.8 syntax is {*name};
        // the 0.7-style "/files/*path" is rejected when the router is built
        .route("/files/{*path}", get(serve_file))
}
// A single path parameter
async fn get_user(Path(id): Path<u64>) -> String {
    format!("User ID: {}", id)
}
// Multiple path parameters, extracted as a tuple in route order
async fn get_user_post(Path((user_id, post_id)): Path<(u64, u64)>) -> String {
    format!("User {}, post {}", user_id, post_id)
}
// Wildcard path
async fn serve_file(Path(path): Path<String>) -> String {
    format!("Requested file: {}", path)
}
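The matcher below is a simplified, hypothetical sketch of how `{name}` and `{*name}` segments capture values from a request path. `match_route` is an illustrative helper, not an axum API. Axum's real router also handles route priority and conflict detection, which this sketch ignores.

```rust
use std::collections::HashMap;

/// Hypothetical helper: match `path` against a template such as
/// "/users/{user_id}/posts/{post_id}" or "/files/{*path}".
/// Returns the captured parameters, or None if the path does not match.
fn match_route(template: &str, path: &str) -> Option<HashMap<String, String>> {
    let mut params = HashMap::new();
    let mut path_segs = path.trim_start_matches('/').split('/');
    for tpl_seg in template.trim_start_matches('/').split('/') {
        // Catch-all segment: capture everything that is left (must be non-empty).
        if let Some(name) = tpl_seg.strip_prefix("{*").and_then(|s| s.strip_suffix('}')) {
            let rest: Vec<&str> = path_segs.by_ref().collect();
            if rest.is_empty() || rest == [""] {
                return None;
            }
            params.insert(name.to_string(), rest.join("/"));
            return Some(params);
        }
        // Path shorter than the template: no match.
        let seg = path_segs.next()?;
        if let Some(name) = tpl_seg.strip_prefix('{').and_then(|s| s.strip_suffix('}')) {
            if seg.is_empty() {
                return None; // a named parameter must capture something
            }
            params.insert(name.to_string(), seg.to_string());
        } else if tpl_seg != seg {
            return None; // literal segment mismatch
        }
    }
    // Leftover segments mean the path is longer than the template.
    if path_segs.next().is_some() { None } else { Some(params) }
}
```

For example, `match_route("/files/{*path}", "/files/docs/a.txt")` captures `path = "docs/a.txt"`. `"/users/{id}"` does not match `"/users/1/extra"`, because only a catch-all may consume more than one segment.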
2.4 Fallback and 404
Rust
use axum::{
    http::StatusCode,
    routing::get,
    Router,
};
fn app() -> Router {
    Router::new()
        .route("/", get(root))
        .route("/about", get(about))
        // Catch-all handler for requests that match no route
        .fallback(not_found)
}
async fn root() -> &'static str { "Home" }
async fn about() -> &'static str { "About" }
async fn not_found() -> (StatusCode, &'static str) {
    (StatusCode::NOT_FOUND, "Page not found")
}
3. Handler Functions and Extractors
3.1 Handler Basics
Rust
use axum::{
    http::StatusCode,
    response::{Html, IntoResponse, Json},
};
use serde_json::json;
// A handler is an ordinary async fn: its arguments are extractors (zero or more)
// and its return type must implement IntoResponse
// Return a string
async fn plain_text() -> &'static str {
    "Plain-text response"
}
// Return HTML
async fn html_page() -> Html<&'static str> {
    Html("<h1>Hello, Axum!</h1>")
}
// Return JSON
async fn json_response() -> Json<serde_json::Value> {
    Json(json!({
        "message": "Hello",
        "status": "ok"
    }))
}
// Return a status code plus a body
async fn with_status() -> (StatusCode, String) {
    (StatusCode::CREATED, "Resource created".to_string())
}
// Return impl IntoResponse (flexible): status, extra headers and a body
async fn flexible() -> impl IntoResponse {
    (StatusCode::OK, [("X-Custom", "value")], "Response with a custom header")
}
3.2 The Path Extractor
Rust
use axum::extract::Path;
use serde::Deserialize;
// Deserialize path parameters into a struct; field names must match the names in the route
#[derive(Deserialize)]
struct UserParams {
    user_id: u64,
    post_id: u64,
}
// Route: /users/{user_id}/posts/{post_id}
async fn get_user_post(Path(params): Path<UserParams>) -> String {
    format!("User {}, post {}", params.user_id, params.post_id)
}
3.3 The Query Extractor
Rust
use axum::extract::Query;
use serde::Deserialize;
#[derive(Deserialize)]
struct Pagination {
    page: Option<u32>,
    per_page: Option<u32>,
    sort: Option<String>,
}
// GET /users?page=1&per_page=20&sort=name
// Option fields may be left out of the query string; defaults are applied below
async fn list_users(Query(pagination): Query<Pagination>) -> String {
    let page = pagination.page.unwrap_or(1);
    let per_page = pagination.per_page.unwrap_or(10);
    let sort = pagination.sort.unwrap_or_else(|| "id".to_string());
    format!("page: {}, per page: {}, sort: {}", page, per_page, sort)
}
3.4 The Json Extractor
Rust
use axum::{http::StatusCode, Json};
use serde::{Deserialize, Serialize};
#[derive(Deserialize)]
struct CreateUser {
    username: String,
    email: String,
    password: String,
}
#[derive(Serialize)]
struct UserResponse {
    id: u64,
    username: String,
    email: String,
}
// POST /users
// Content-Type: application/json
// {"username": "alice", "email": "alice@example.com", "password": "secret"}
// A wrong Content-Type, malformed JSON or a body that does not fit CreateUser is rejected automatically
async fn create_user(Json(payload): Json<CreateUser>) -> (StatusCode, Json<UserResponse>) {
    // The password is deliberately not echoed back in the response
    let user = UserResponse {
        id: 1,
        username: payload.username,
        email: payload.email,
    };
    (StatusCode::CREATED, Json(user))
}
3.5 The State Extractor (Shared State)
Rust
use axum::{
    extract::State,
    http::StatusCode,
    routing::get,
    Json, Router,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;
// Application state: it is cloned for every request, so keep it cheap to clone (e.g. wrap data in Arc)
#[derive(Clone)]
struct AppState {
    db: Arc<RwLock<Vec<User>>>,
    config: AppConfig,
}
#[derive(Clone)]
struct AppConfig {
    app_name: String,
    max_users: usize,
}
#[derive(Clone, Serialize, Deserialize)]
struct User {
    id: u64,
    name: String,
}
async fn list_users(State(state): State<AppState>) -> Json<Vec<User>> {
    let users = state.db.read().await;
    Json(users.clone())
}
async fn create_user(
    State(state): State<AppState>,
    Json(user): Json<User>,
) -> StatusCode {
    let mut users = state.db.write().await;
    if users.len() >= state.config.max_users {
        return StatusCode::TOO_MANY_REQUESTS;
    }
    users.push(user);
    StatusCode::CREATED
}
async fn app_info(State(state): State<AppState>) -> String {
    format!("App: {}", state.config.app_name)
}
#[tokio::main]
async fn main() {
    let state = AppState {
        db: Arc::new(RwLock::new(Vec::new())),
        config: AppConfig {
            app_name: "My API".to_string(),
            max_users: 100,
        },
    };
    let app = Router::new()
        .route("/users", get(list_users).post(create_user))
        .route("/info", get(app_info))
        .with_state(state);
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}
3.6 Header Extractors
Rust
// TypedHeader lives in axum-extra and needs its "typed-header" feature, e.g.
// axum-extra = { version = "0.10", features = ["typed-header"] }
use axum::{extract::Request, http::HeaderMap};
use axum_extra::{headers::UserAgent, TypedHeader};
// Read all request headers
async fn all_headers(headers: HeaderMap) -> String {
    let mut result = String::new();
    for (key, value) in &headers {
        // Values that are not visible ASCII are shown as empty strings
        result.push_str(&format!("{}: {}\n", key, value.to_str().unwrap_or("")));
    }
    result
}
// Typed headers (require axum-extra); a missing header rejects the request
async fn typed_headers(TypedHeader(user_agent): TypedHeader<UserAgent>) -> String {
    format!("User-Agent: {}", user_agent)
}
// Access the whole Request directly (it consumes the body, so it must be the last argument)
async fn raw_request(req: Request) -> String {
    let method = req.method().clone();
    let uri = req.uri().clone();
    format!("{} {}", method, uri)
}
3.7 Custom Extractors
Rust
use axum::{
    extract::FromRequestParts,
    http::{request::Parts, StatusCode},
    response::{IntoResponse, Response},
};
// Custom extractor: read an API key from a request header
// Note: Axum 0.8+ dropped #[async_trait] and uses native async fn in traits
struct ApiKey(String);
impl<S> FromRequestParts<S> for ApiKey
where
    S: Send + Sync,
{
    type Rejection = ApiKeyError;
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let api_key = parts
            .headers
            .get("X-API-Key")
            .and_then(|v| v.to_str().ok())
            .ok_or(ApiKeyError::Missing)?;
        // Toy validation that only checks the length; a real service would look the key up
        if api_key.len() < 32 {
            return Err(ApiKeyError::Invalid);
        }
        Ok(ApiKey(api_key.to_string()))
    }
}
enum ApiKeyError {
    Missing,
    Invalid,
}
impl IntoResponse for ApiKeyError {
    fn into_response(self) -> Response {
        let (status, msg) = match self {
            ApiKeyError::Missing => (StatusCode::UNAUTHORIZED, "Missing API key"),
            ApiKeyError::Invalid => (StatusCode::FORBIDDEN, "Invalid API key"),
        };
        (status, msg).into_response()
    }
}
// Using the custom extractor
async fn protected_endpoint(ApiKey(key): ApiKey) -> String {
    // Slicing is safe: to_str() guarantees ASCII and the key is at least 32 bytes long
    format!("Authenticated, API key: {}...", &key[..8])
}
4. Middleware
4.1 Using Tower Layers
Rust
// Needs the tower-http features "cors", "trace", "compression-gzip" (or another compression-*) and "timeout"
use axum::{routing::get, Router};
use std::time::Duration;
use tower::ServiceBuilder;
use tower_http::{
    compression::CompressionLayer,
    cors::{Any, CorsLayer},
    timeout::TimeoutLayer,
    trace::TraceLayer,
};
fn app() -> Router {
    Router::new()
        .route("/", get(root))
        .route("/api/data", get(get_data))
        .layer(
            ServiceBuilder::new()
                // ServiceBuilder applies layers top to bottom on the way in:
                // the first layer added is the outermost and sees the request first
                .layer(TraceLayer::new_for_http()) // 1. request tracing
                .layer(TimeoutLayer::new(Duration::from_secs(30))) // 2. timeout
                .layer(CompressionLayer::new()) // 3. response compression
                .layer(
                    // 4. CORS (innermost); allowing Any is permissive, suitable for development
                    CorsLayer::new()
                        .allow_origin(Any)
                        .allow_methods(Any)
                        .allow_headers(Any),
                ),
        )
}
async fn root() -> &'static str { "Hello" }
async fn get_data() -> &'static str { "Data" }
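A toy "onion" model in plain Rust can check the ordering rule. `Handler`, `wrap` and `build` are hypothetical stand-ins for Tower's Service and Layer, reduced to closures that write to a log. As with `ServiceBuilder`, the first name passed to `build` becomes the outermost wrapper.

```rust
// A "service" here is just a closure that records what it did.
type Handler = Box<dyn Fn(&mut Vec<String>)>;

// Wrap an inner service: log on the way in, call inner, log on the way out.
fn wrap(name: &'static str, inner: Handler) -> Handler {
    Box::new(move |log: &mut Vec<String>| {
        log.push(format!("{name}: request in"));
        inner(&mut *log);
        log.push(format!("{name}: response out"));
    })
}

/// Builds the stack the way ServiceBuilder orders it:
/// names[0] ends up outermost, the last name sits right around the handler.
fn build(names: &[&'static str]) -> Handler {
    let mut svc: Handler = Box::new(|log: &mut Vec<String>| log.push("handler".to_string()));
    // Wrap from the innermost (last) layer outwards.
    for &name in names.iter().rev() {
        svc = wrap(name, svc);
    }
    svc
}
```

Calling `build(&["Trace", "Timeout", "Compression", "Cors"])` and running the result logs "Trace: request in" first, reaches "handler" after "Cors: request in", and logs "Trace: response out" last. Requests pass through the layers in the listed order, and responses come back in reverse. Chaining `.layer(...)` calls directly on a `Router` behaves the other way round: each later call wraps the earlier ones, so the layer added last runs first.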
4.2 Request-Logging Middleware
Rust
use axum::{
    extract::Request,
    middleware::{self, Next},
    response::Response,
    routing::get,
    Router,
};
use std::time::Instant;
// A middleware written as a plain async function
async fn logging_middleware(req: Request, next: Next) -> Response {
    let method = req.method().clone();
    let uri = req.uri().clone();
    let start = Instant::now();
    tracing::info!("→ {} {}", method, uri);
    // Pass the request on to the next layer or the handler
    let response = next.run(req).await;
    let duration = start.elapsed();
    tracing::info!(
        "← {} {} [{}] {:?}",
        method,
        uri,
        response.status(),
        duration
    );
    response
}
fn app() -> Router {
    Router::new()
        .route("/", get(|| async { "Hello" }))
        .layer(middleware::from_fn(logging_middleware))
}
4.3 Authentication Middleware
Rust
use axum::{
    extract::Request,
    http::{header, StatusCode},
    middleware::{self, Next},
    response::Response,
    routing::get,
    Router,
};
async fn auth_middleware(req: Request, next: Next) -> Result<Response, StatusCode> {
    // Expect "Authorization: Bearer <token>"; anything else is rejected with 401
    let token = req
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.strip_prefix("Bearer "))
        .ok_or(StatusCode::UNAUTHORIZED)?;
    if validate_token(token).await {
        Ok(next.run(req).await)
    } else {
        Err(StatusCode::UNAUTHORIZED)
    }
}
async fn validate_token(token: &str) -> bool {
    // A real application would verify a JWT or similar here
    !token.is_empty()
}
fn app() -> Router {
    let public_routes = Router::new()
        .route("/login", get(login))
        .route("/health", get(health));
    // The layer wraps only this sub-router, so the public routes skip authentication
    let protected_routes = Router::new()
        .route("/dashboard", get(dashboard))
        .route("/profile", get(profile))
        .layer(middleware::from_fn(auth_middleware));
    Router::new()
        .merge(public_routes)
        .merge(protected_routes)
}
async fn login() -> &'static str { "Login page" }
async fn health() -> &'static str { "OK" }
async fn dashboard() -> &'static str { "Dashboard" }
async fn profile() -> &'static str { "Profile" }
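The header parsing in `auth_middleware` can be pulled out into a small pure function and tested on its own. `bearer_token` is a hypothetical helper. Like the middleware, it matches the `Bearer ` prefix case-sensitively. HTTP actually defines authentication scheme names as case-insensitive, so a stricter version might compare the scheme with `eq_ignore_ascii_case`.

```rust
/// Hypothetical helper mirroring the check in `auth_middleware`:
/// returns the token only for a well-formed, non-empty `Bearer <token>` value.
fn bearer_token(header_value: Option<&str>) -> Option<&str> {
    let token = header_value?.strip_prefix("Bearer ")?.trim();
    if token.is_empty() { None } else { Some(token) }
}
```

With the parsing isolated like this, the middleware reduces to "parse, validate, forward or reject". The edge cases (missing header, wrong scheme, empty token) can be unit-tested without building a `Router`.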
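4.4 Rate-Limiting Middleware
Rust
use axum::{
    extract::{Request, State},
    http::StatusCode,
    middleware::{self, Next},
    response::Response,
    routing::get,
    Router,
};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
// Sliding-window limiter: at most max_requests per window_secs for each client.
// Keys are never evicted here; a production limiter needs periodic cleanup.
#[derive(Clone)]
struct RateLimiter {
    requests: Arc<Mutex<HashMap<String, Vec<Instant>>>>,
    max_requests: usize,
    window_secs: u64,
}
impl RateLimiter {
    fn new(max_requests: usize, window_secs: u64) -> Self {
        RateLimiter {
            requests: Arc::new(Mutex::new(HashMap::new())),
            max_requests,
            window_secs,
        }
    }
    async fn check(&self, ip: &str) -> bool {
        let mut requests = self.requests.lock().await;
        let now = Instant::now();
        let window = Duration::from_secs(self.window_secs);
        let entries = requests.entry(ip.to_string()).or_default();
        // Drop timestamps that have left the window
        entries.retain(|&t| now.duration_since(t) < window);
        if entries.len() >= self.max_requests {
            false
        } else {
            entries.push(now);
            true
        }
    }
}
async fn rate_limit_middleware(
    State(limiter): State<RateLimiter>,
    req: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    // Caution: X-Forwarded-For is client-controlled unless a trusted proxy sets it,
    // and it may hold a comma-separated list; only the first entry is used here
    let ip = req
        .headers()
        .get("X-Forwarded-For")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.split(',').next())
        .map(|v| v.trim().to_string())
        .unwrap_or_else(|| "unknown".to_string());
    if limiter.check(&ip).await {
        Ok(next.run(req).await)
    } else {
        Err(StatusCode::TOO_MANY_REQUESTS)
    }
}
// A middleware that takes State must be registered with from_fn_with_state
fn app() -> Router {
    // At most 60 requests per 60 seconds per client
    let limiter = RateLimiter::new(60, 60);
    Router::new()
        .route("/", get(|| async { "Hello" }))
        .layer(middleware::from_fn_with_state(limiter, rate_limit_middleware))
}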
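A synchronous variant of the same sliding-window algorithm is easier to test. Because the current time is passed in as a parameter, the test does not need to sleep. `SlidingWindow` is a hypothetical type. It swaps the original's `Vec` + `retain` for a `VecDeque`: timestamps arrive in order, so expired ones can be popped from the front.

```rust
use std::collections::{HashMap, VecDeque};
use std::time::{Duration, Instant};

/// Hypothetical, synchronous sliding-window limiter.
struct SlidingWindow {
    max_requests: usize,
    window: Duration,
    hits: HashMap<String, VecDeque<Instant>>,
}

impl SlidingWindow {
    fn new(max_requests: usize, window: Duration) -> Self {
        Self { max_requests, window, hits: HashMap::new() }
    }

    /// Returns true (and records the hit) if `key` is still under its limit at time `now`.
    fn check(&mut self, key: &str, now: Instant) -> bool {
        let entries = self.hits.entry(key.to_string()).or_default();
        // Timestamps are pushed in order, so expired ones are always at the front.
        while let Some(&oldest) = entries.front() {
            if now.duration_since(oldest) >= self.window {
                entries.pop_front();
            } else {
                break;
            }
        }
        if entries.len() >= self.max_requests {
            false
        } else {
            entries.push_back(now);
            true
        }
    }
}
```

With a limit of 2 per 10 seconds, a third request from the same key at t0+2s is refused. A request from another key at the same moment is still allowed. The first key is admitted again at t0+10s, once its oldest hit has left the window.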
5. Error Handling
5.1 A Unified Error Type
Rust
use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
    Json,
};
use serde_json::json;
// One error type for the whole application
#[derive(Debug)]
pub enum AppError {
    NotFound(String),
    BadRequest(String),
    Unauthorized,
    Forbidden,
    InternalError(String),
    DatabaseError(String),
    ValidationError(Vec<String>),
}
impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        let (status, error_message) = match self {
            AppError::NotFound(msg) => (StatusCode::NOT_FOUND, msg),
            AppError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg),
            AppError::Unauthorized => (StatusCode::UNAUTHORIZED, "Unauthorized".to_string()),
            AppError::Forbidden => (StatusCode::FORBIDDEN, "Forbidden".to_string()),
            // Internal details are logged but never sent to the client
            AppError::InternalError(msg) => {
                tracing::error!("Internal error: {}", msg);
                (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error".to_string())
            }
            AppError::DatabaseError(msg) => {
                tracing::error!("Database error: {}", msg);
                (StatusCode::INTERNAL_SERVER_ERROR, "Database error".to_string())
            }
            AppError::ValidationError(errors) => {
                (StatusCode::UNPROCESSABLE_ENTITY, errors.join(", "))
            }
        };
        let body = Json(json!({
            "error": {
                "code": status.as_u16(),
                "message": error_message,
            }
        }));
        (status, body).into_response()
    }
}
// Automatic conversions from other error types, so ? works inside handlers
impl From<std::io::Error> for AppError {
    fn from(err: std::io::Error) -> Self {
        AppError::InternalError(err.to_string())
    }
}
impl From<serde_json::Error> for AppError {
    fn from(err: serde_json::Error) -> Self {
        AppError::BadRequest(format!("JSON parse error: {}", err))
    }
}
5.2 Using the Error Type in Handlers
Rust
use axum::{extract::Path, Json};
use serde::Serialize;
#[derive(Serialize)]
struct User {
    id: u64,
    name: String,
    email: String,
}
// A handler returning Result<T, AppError>: Ok becomes the response, Err goes through AppError::into_response
async fn get_user(Path(id): Path<u64>) -> Result<Json<User>, AppError> {
    if id == 0 {
        return Err(AppError::BadRequest("ID must not be 0".to_string()));
    }
    // Simulated database lookup
    let user = find_user_by_id(id)
        .await
        .ok_or_else(|| AppError::NotFound(format!("User {} not found", id)))?;
    Ok(Json(user))
}
async fn find_user_by_id(id: u64) -> Option<User> {
    if id == 1 {
        Some(User {
            id: 1,
            name: "Alice".to_string(),
            email: "alice@example.com".to_string(),
        })
    } else {
        None
    }
}
// Chaining checks with the ? operator
async fn create_user(Json(payload): Json<serde_json::Value>) -> Result<Json<User>, AppError> {
    let name = payload["name"]
        .as_str()
        .ok_or_else(|| AppError::ValidationError(vec!["name must be a string".to_string()]))?;
    let email = payload["email"]
        .as_str()
        .ok_or_else(|| AppError::ValidationError(vec!["email must be a string".to_string()]))?;
    let user = User {
        id: 42,
        name: name.to_string(),
        email: email.to_string(),
    };
    Ok(Json(user))
}
5.3 Simplifying Error Handling with anyhow
Rust
// Requires the anyhow crate, e.g. anyhow = "1"
use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
    Json,
};
use serde_json::json;
// Wrap anyhow::Error as the internal error type
struct AnyhowError(anyhow::Error);
impl IntoResponse for AnyhowError {
    fn into_response(self) -> Response {
        // Log the full error chain, but return only a generic message so internals do not leak
        tracing::error!("Application error: {:?}", self.0);
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({"error": "Internal server error"})),
        )
            .into_response()
    }
}
// Lets ? convert any error that can become an anyhow::Error
impl<E> From<E> for AnyhowError
where
    E: Into<anyhow::Error>,
{
    fn from(err: E) -> Self {
        AnyhowError(err.into())
    }
}
// Use in a handler
async fn handler() -> Result<String, AnyhowError> {
    // tokio::fs avoids blocking the async runtime; the io::Error is converted by ?
    let data = tokio::fs::read_to_string("config.toml").await?;
    Ok(data)
}
6. Database Integration (SQLx)
6.1 Configuring SQLx
TOML
# Cargo.toml
[dependencies]
sqlx = { version = "0.8", features = ["runtime-tokio", "postgres", "uuid", "chrono"] }
uuid = { version = "1", features = ["v4", "serde"] }
chrono = { version = "0.4", features = ["serde"] }
6.2 The Connection Pool
Rust
use axum::{routing::get, Router};
use sqlx::postgres::PgPoolOptions;
use sqlx::PgPool;
// The handlers list_users, create_user, get_user and delete_user are defined in 6.3
#[derive(Clone)]
struct AppState {
    // PgPool is a reference-counted handle, so cloning the state is cheap
    db: PgPool,
}
#[tokio::main]
async fn main() {
    // Read the database URL from the environment (the fallback is for local development only)
    let database_url = std::env::var("DATABASE_URL")
        .unwrap_or_else(|_| "postgres://user:pass@localhost/mydb".to_string());
    // Create the connection pool
    let pool = PgPoolOptions::new()
        .max_connections(20)
        .min_connections(5)
        .acquire_timeout(std::time::Duration::from_secs(3))
        .connect(&database_url)
        .await
        .expect("failed to connect to the database");
    // Run migrations (sqlx::migrate! embeds ./migrations into the binary at compile time)
    sqlx::migrate!("./migrations")
        .run(&pool)
        .await
        .expect("migration failed");
    let state = AppState { db: pool };
    let app = Router::new()
        .route("/users", get(list_users).post(create_user))
        .route("/users/{id}", get(get_user).delete(delete_user))
        .with_state(state);
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}
6.3 CRUD Operations
Rust
use axum::{extract::{Path, State}, http::StatusCode, Json};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use uuid::Uuid;
// AppState comes from 6.2 and AppError from 5.1
#[derive(Debug, Serialize, FromRow)]
struct User {
    id: Uuid,
    username: String,
    email: String,
    created_at: DateTime<Utc>,
}
#[derive(Debug, Deserialize)]
struct CreateUserRequest {
    username: String,
    email: String,
}
// CREATE
async fn create_user(
    State(state): State<AppState>,
    Json(payload): Json<CreateUserRequest>,
) -> Result<(StatusCode, Json<User>), AppError> {
    // Values are bound as parameters ($1..$4), never interpolated into the SQL string
    let user = sqlx::query_as::<_, User>(
        r#"
        INSERT INTO users (id, username, email, created_at)
        VALUES ($1, $2, $3, $4)
        RETURNING id, username, email, created_at
        "#,
    )
    .bind(Uuid::new_v4())
    .bind(&payload.username)
    .bind(&payload.email)
    .bind(Utc::now())
    .fetch_one(&state.db)
    .await
    .map_err(|e| AppError::DatabaseError(e.to_string()))?;
    Ok((StatusCode::CREATED, Json(user)))
}
// READ (list)
async fn list_users(
    State(state): State<AppState>,
) -> Result<Json<Vec<User>>, AppError> {
    let users = sqlx::query_as::<_, User>(
        "SELECT id, username, email, created_at FROM users ORDER BY created_at DESC",
    )
    .fetch_all(&state.db)
    .await
    .map_err(|e| AppError::DatabaseError(e.to_string()))?;
    Ok(Json(users))
}
// READ (single): fetch_optional yields Ok(None) when no row matches
async fn get_user(
    State(state): State<AppState>,
    Path(id): Path<Uuid>,
) -> Result<Json<User>, AppError> {
    let user = sqlx::query_as::<_, User>(
        "SELECT id, username, email, created_at FROM users WHERE id = $1",
    )
    .bind(id)
    .fetch_optional(&state.db)
    .await
    .map_err(|e| AppError::DatabaseError(e.to_string()))?
    .ok_or_else(|| AppError::NotFound(format!("User {} not found", id)))?;
    Ok(Json(user))
}
// DELETE
async fn delete_user(
    State(state): State<AppState>,
    Path(id): Path<Uuid>,
) -> Result<StatusCode, AppError> {
    let result = sqlx::query("DELETE FROM users WHERE id = $1")
        .bind(id)
        .execute(&state.db)
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?;
    if result.rows_affected() == 0 {
        Err(AppError::NotFound(format!("User {} not found", id)))
    } else {
        Ok(StatusCode::NO_CONTENT)
    }
}
6.4 Database Migrations
SQL
-- migrations/001_create_users.sql
CREATE TABLE IF NOT EXISTS users (
    id UUID PRIMARY KEY,
    username VARCHAR(255) NOT NULL UNIQUE,
    email VARCHAR(255) NOT NULL UNIQUE,
    -- Nullable so that the INSERT in 6.3, which does not handle passwords, succeeds
    password_hash VARCHAR(255),
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- No separate indexes on username/email are needed:
-- PostgreSQL already creates a unique index for each UNIQUE constraint
7. Authentication and Authorization
7.1 JWT Authentication
Rust
// Requires the jsonwebtoken crate (e.g. version 9) and chrono
use chrono::{Duration, Utc};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
    sub: String,  // user ID
    exp: usize,   // expiration time (Unix timestamp, seconds)
    iat: usize,   // issued-at time (Unix timestamp, seconds)
    role: String, // role
}
struct JwtConfig {
    secret: String,
    expiration_hours: i64,
}
impl JwtConfig {
    fn new() -> Self {
        JwtConfig {
            // The fallback secret is for local development only; always set JWT_SECRET in production
            secret: std::env::var("JWT_SECRET").unwrap_or_else(|_| "dev-secret-key".to_string()),
            expiration_hours: 24,
        }
    }
    // Issue a token (Header::default() uses HS256)
    fn create_token(&self, user_id: &str, role: &str) -> Result<String, jsonwebtoken::errors::Error> {
        let now = Utc::now();
        let claims = Claims {
            sub: user_id.to_string(),
            exp: (now + Duration::hours(self.expiration_hours)).timestamp() as usize,
            iat: now.timestamp() as usize,
            role: role.to_string(),
        };
        encode(
            &Header::default(),
            &claims,
            &EncodingKey::from_secret(self.secret.as_bytes()),
        )
    }
    // Verify a token: checks the signature and, with Validation::default(), the exp claim
    fn verify_token(&self, token: &str) -> Result<Claims, jsonwebtoken::errors::Error> {
        let token_data = decode::<Claims>(
            token,
            &DecodingKey::from_secret(self.secret.as_bytes()),
            &Validation::default(),
        )?;
        Ok(token_data.claims)
    }
}
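The `iat`/`exp` arithmetic and the expiry check can be sketched without any crate. `issue_times`, `is_expired` and `unix_now` are hypothetical helpers. `leeway_secs` models the clock-skew tolerance that jsonwebtoken's `Validation` applies through its `leeway` field; the exact comparison inside the crate is an assumption here. With `expiration_hours = 24`, `exp - iat` is 86,400 seconds, which is the lifetime a login response should report as `expires_in`.

```rust
use std::time::{SystemTime, UNIX_EPOCH};

/// Hypothetical helper: compute (iat, exp) in Unix seconds for a token issued at `now_secs`.
fn issue_times(now_secs: u64, expiration_hours: u64) -> (u64, u64) {
    let iat = now_secs;
    let exp = now_secs + expiration_hours * 3600;
    (iat, exp)
}

/// Hypothetical check: a token counts as expired once `now` is past `exp`
/// by more than the allowed leeway (tolerance for clock skew between servers).
fn is_expired(exp: u64, now_secs: u64, leeway_secs: u64) -> bool {
    now_secs > exp + leeway_secs
}

/// Current Unix time in seconds.
fn unix_now() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock is set before 1970")
        .as_secs()
}
```

Keeping `expires_in` derived from the same `expiration_hours` value avoids the response and the token's `exp` claim drifting apart when the configuration changes.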
7.2 An Authentication Extractor
Rust
use axum::{
    extract::FromRequestParts,
    http::{header, request::Parts, StatusCode},
    response::{IntoResponse, Response},
    Json,
};
use serde_json::json;
// Information about the authenticated user
#[derive(Debug, Clone)]
pub struct AuthUser {
    pub user_id: String,
    pub role: String,
}
// Axum 0.8+ uses native async fn in traits; #[async_trait] is not needed
impl<S> FromRequestParts<S> for AuthUser
where
    S: Send + Sync,
{
    type Rejection = AuthError;
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        // Take the Bearer token from the Authorization header
        let auth_header = parts
            .headers
            .get(header::AUTHORIZATION)
            .and_then(|v| v.to_str().ok())
            .ok_or(AuthError::MissingToken)?;
        let token = auth_header
            .strip_prefix("Bearer ")
            .ok_or(AuthError::InvalidToken)?;
        // JwtConfig is defined in 7.1
        let jwt_config = JwtConfig::new();
        let claims = jwt_config
            .verify_token(token)
            .map_err(|_| AuthError::InvalidToken)?;
        Ok(AuthUser {
            user_id: claims.sub,
            role: claims.role,
        })
    }
}
#[derive(Debug)]
pub enum AuthError {
    MissingToken,
    InvalidToken,
    InsufficientPermission,
}
impl IntoResponse for AuthError {
    fn into_response(self) -> Response {
        let (status, msg) = match self {
            AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "Missing authentication token"),
            AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
            AuthError::InsufficientPermission => (StatusCode::FORBIDDEN, "Insufficient permissions"),
        };
        (status, Json(json!({"error": msg}))).into_response()
    }
}
7.3 The Login Endpoint
Rust
use axum::Json;
use serde::{Deserialize, Serialize};
#[derive(Deserialize)]
struct LoginRequest {
    username: String,
    password: String,
}
#[derive(Serialize)]
struct LoginResponse {
    token: String,
    token_type: String,
    expires_in: i64,
}
async fn login(Json(payload): Json<LoginRequest>) -> Result<Json<LoginResponse>, AppError> {
    // Verify credentials (a real app would query the database and check a password hash)
    let (user_id, role) = verify_credentials(&payload.username, &payload.password)
        .await
        .ok_or(AppError::Unauthorized)?;
    let jwt_config = JwtConfig::new();
    let token = jwt_config
        .create_token(&user_id, &role)
        .map_err(|e| AppError::InternalError(e.to_string()))?;
    Ok(Json(LoginResponse {
        token,
        token_type: "Bearer".to_string(),
        // Lifetime in seconds, derived from the same setting that produces the exp claim
        expires_in: jwt_config.expiration_hours * 3600,
    }))
}
async fn verify_credentials(username: &str, password: &str) -> Option<(String, String)> {
    // Demo only: hard-coded credentials; a real app would query the database
    if username == "admin" && password == "password" {
        Some(("user-1".to_string(), "admin".to_string()))
    } else {
        None
    }
}
// Protected routes simply take AuthUser as an argument
async fn protected(user: AuthUser) -> String {
    format!("Welcome, user {}, role: {}", user.user_id, user.role)
}
// Admins only
async fn admin_only(user: AuthUser) -> Result<String, AuthError> {
    if user.role != "admin" {
        return Err(AuthError::InsufficientPermission);
    }
    Ok("Admin panel".to_string())
}
8. WebSocket Support
8.1 A Basic WebSocket
Rust
// Requires axum's "ws" feature: axum = { version = "0.8", features = ["ws"] }
use axum::{
    extract::ws::{Message, WebSocket, WebSocketUpgrade},
    response::Response,
    routing::get,
    Router,
};
fn app() -> Router {
    Router::new()
        .route("/ws", get(ws_handler))
}
// Handle the HTTP upgrade to WebSocket
async fn ws_handler(ws: WebSocketUpgrade) -> Response {
    ws.on_upgrade(handle_socket)
}
// Handle one WebSocket connection
async fn handle_socket(mut socket: WebSocket) {
    // Send a welcome message (in Axum 0.8, Message::Text holds Utf8Bytes, hence .into())
    if socket
        .send(Message::Text("Welcome to the WebSocket!".into()))
        .await
        .is_err()
    {
        return; // the client has disconnected
    }
    // Message loop
    while let Some(msg) = socket.recv().await {
        let msg = match msg {
            Ok(msg) => msg,
            Err(_) => return, // connection error
        };
        match msg {
            Message::Text(text) => {
                // Echo the message back
                let reply = format!("Received: {}", text.as_str());
                if socket.send(Message::Text(reply.into())).await.is_err() {
                    return;
                }
            }
            Message::Binary(data) => {
                // Echo binary data unchanged
                if socket.send(Message::Binary(data)).await.is_err() {
                    return;
                }
            }
            Message::Close(_) => return,
            // Axum answers Ping frames automatically, so Ping/Pong need no manual handling
            _ => {}
        }
    }
}
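8.2 WebSocket Broadcast (Chat Room)
Rust
// Requires futures-util (for split, SinkExt, StreamExt) and axum's "ws" feature
use axum::{
    extract::{
        ws::{Message, WebSocket, WebSocketUpgrade},
        State,
    },
    response::Response,
    routing::get,
    Router,
};
use futures_util::{SinkExt, StreamExt};
use tokio::sync::broadcast;
#[derive(Clone)]
struct ChatState {
    tx: broadcast::Sender<String>,
}
fn app() -> Router {
    // Capacity 100: receivers that fall further behind get RecvError::Lagged
    let (tx, _rx) = broadcast::channel(100);
    let state = ChatState { tx };
    Router::new()
        .route("/chat", get(chat_handler))
        .with_state(state)
}
async fn chat_handler(
    ws: WebSocketUpgrade,
    State(state): State<ChatState>,
) -> Response {
    ws.on_upgrade(move |socket| handle_chat(socket, state))
}
async fn handle_chat(socket: WebSocket, state: ChatState) {
    // Split the socket into independent send and receive halves
    let (mut sender, mut receiver) = socket.split();
    // Subscribe to the broadcast channel
    let mut rx = state.tx.subscribe();
    // Task 1: forward broadcast messages to this client
    let mut send_task = tokio::spawn(async move {
        loop {
            match rx.recv().await {
                Ok(msg) => {
                    if sender.send(Message::Text(msg.into())).await.is_err() {
                        break;
                    }
                }
                // Skip messages this client missed instead of disconnecting it
                Err(broadcast::error::RecvError::Lagged(_)) => continue,
                Err(broadcast::error::RecvError::Closed) => break,
            }
        }
    });
    // Task 2: broadcast this client's text messages to everyone
    let tx = state.tx.clone();
    let mut recv_task = tokio::spawn(async move {
        while let Some(Ok(msg)) = receiver.next().await {
            match msg {
                // send only fails when there are no subscribers, which is safe to ignore
                Message::Text(text) => {
                    let _ = tx.send(text.as_str().to_owned());
                }
                Message::Close(_) => break,
                _ => {}
            }
        }
    });
    // When either task finishes, abort the other
    tokio::select! {
        _ = &mut send_task => recv_task.abort(),
        _ = &mut recv_task => send_task.abort(),
    }
}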
9. Testing
9.1 Integration Tests
Rust
// Dev-dependencies: http-body-util, plus tower with the "util" feature (for ServiceExt::oneshot)
#[cfg(test)]
mod tests {
    use axum::{
        body::Body,
        extract::Path,
        http::{Request, StatusCode},
        routing::get,
        Json, Router,
    };
    use http_body_util::BodyExt; // for `collect`
    use tower::ServiceExt; // for `oneshot`
    fn test_app() -> Router {
        Router::new()
            .route("/", get(|| async { "Hello" }))
            .route("/users/{id}", get(|Path(id): Path<u64>| async move {
                format!("User {}", id)
            }))
            .route("/json", get(|| async {
                Json(serde_json::json!({"status": "ok"}))
            }))
    }
    #[tokio::test]
    async fn test_root() {
        let app = test_app();
        // oneshot sends one request straight into the router; no network socket is involved
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"Hello");
    }
    #[tokio::test]
    async fn test_user_path() {
        let app = test_app();
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/users/42")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"User 42");
    }
    #[tokio::test]
    async fn test_json_response() {
        let app = test_app();
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/json")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["status"], "ok");
    }
    #[tokio::test]
    async fn test_not_found() {
        // Without a fallback, unmatched paths return 404
        let app = test_app();
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/nonexistent")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }
}
9.2 Testing POST Requests
Rust
#[cfg(test)]
mod post_tests {
    use axum::{
        body::Body,
        http::{header, Request, StatusCode},
        routing::post,
        Json, Router,
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt;
    #[tokio::test]
    async fn test_create_user() {
        let app = Router::new()
            .route("/users", post(|Json(body): Json<serde_json::Value>| async move {
                let name = body["name"].as_str().unwrap_or("unknown");
                (StatusCode::CREATED, Json(serde_json::json!({
                    "id": 1,
                    "name": name
                })))
            }));
        // The Content-Type header is required; without it the Json extractor rejects the request
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/users")
                    .header(header::CONTENT_TYPE, "application/json")
                    .body(Body::from(r#"{"name": "Alice"}"#))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::CREATED);
        let body = response.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["name"], "Alice");
    }
}
9.3 Using a TestClient (Helper Utility)
Rust
// A hand-written test helper module (not part of axum)
#[cfg(test)]
mod test_helpers {
    use axum::{
        body::Body,
        http::{header, Request, Response, StatusCode},
        Router,
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt;
    pub struct TestClient {
        app: Router,
    }
    impl TestClient {
        pub fn new(app: Router) -> Self {
            TestClient { app }
        }
        // oneshot consumes the service, so every request uses a clone of the router
        pub async fn get(&self, uri: &str) -> TestResponse {
            let response = self
                .app
                .clone()
                .oneshot(
                    Request::builder()
                        .uri(uri)
                        .body(Body::empty())
                        .unwrap(),
                )
                .await
                .unwrap();
            TestResponse(response)
        }
        pub async fn post_json(&self, uri: &str, body: serde_json::Value) -> TestResponse {
            let response = self
                .app
                .clone()
                .oneshot(
                    Request::builder()
                        .method("POST")
                        .uri(uri)
                        .header(header::CONTENT_TYPE, "application/json")
                        .body(Body::from(serde_json::to_string(&body).unwrap()))
                        .unwrap(),
                )
                .await
                .unwrap();
            TestResponse(response)
        }
    }
    pub struct TestResponse(Response<Body>);
    impl TestResponse {
        pub fn status(&self) -> StatusCode {
            self.0.status()
        }
        pub async fn json(self) -> serde_json::Value {
            let body = self.0.into_body().collect().await.unwrap().to_bytes();
            serde_json::from_slice(&body).unwrap()
        }
        pub async fn text(self) -> String {
            let body = self.0.into_body().collect().await.unwrap().to_bytes();
            String::from_utf8(body.to_vec()).unwrap()
        }
    }
}
10. Deployment
10.1 Multi-Stage Docker Build
Docker
# Dockerfile
# Stage 1: build
FROM rust:1.85-bookworm AS builder
WORKDIR /app
# Copy the manifests first so the dependency build is cached in its own layer
COPY Cargo.toml Cargo.lock ./
RUN mkdir src && echo "fn main() {}" > src/main.rs
RUN cargo build --release
RUN rm -rf src
# Copy the real sources and build
COPY src/ src/
# sqlx::migrate! embeds the migrations at compile time, so they must exist in the build stage
COPY migrations/ migrations/
# Bump the timestamp so Cargo rebuilds main.rs instead of reusing the dummy build
RUN touch src/main.rs
RUN cargo build --release
# Stage 2: runtime (slim image)
FROM debian:bookworm-slim
RUN apt-get update && apt-get install -y \
    ca-certificates \
    && rm -rf /var/lib/apt/lists/*
WORKDIR /app
# Copy the binary from the build stage
COPY --from=builder /app/target/release/axum-demo /app/axum-demo
# Run as a non-root user
RUN useradd -r -s /bin/false appuser
USER appuser
EXPOSE 3000
ENV RUST_LOG=info
CMD ["./axum-demo"]
10.2 docker-compose
YAML
# docker-compose.yml
# Note: the top-level version field is obsolete in Docker Compose V2 and can be omitted
# The credentials and JWT secret below are placeholders for local use only
services:
  app:
    build: .
    ports:
      - "3000:3000"
    environment:
      - DATABASE_URL=postgres://user:password@db:5432/myapp
      - JWT_SECRET=your-secret-key
      - RUST_LOG=info
    depends_on:
      db:
        condition: service_healthy
  db:
    image: postgres:16-alpine
    environment:
      POSTGRES_USER: user
      POSTGRES_PASSWORD: password
      POSTGRES_DB: myapp
    volumes:
      - pgdata:/var/lib/postgresql/data
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U user -d myapp"]
      interval: 5s
      timeout: 5s
      retries: 5
volumes:
  pgdata:
10.3 Static Compilation (musl)
Docker
# Fully static build (no system library dependencies at runtime)
FROM rust:1.85-bookworm AS builder
RUN apt-get update && apt-get install -y musl-tools
RUN rustup target add x86_64-unknown-linux-musl
WORKDIR /app
COPY . .
RUN cargo build --release --target x86_64-unknown-linux-musl
# Minimal runtime image
FROM scratch
COPY --from=builder /app/target/x86_64-unknown-linux-musl/release/axum-demo /axum-demo
# Copy the CA certificates if the app makes outbound HTTPS/TLS connections
COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
EXPOSE 3000
ENTRYPOINT ["/axum-demo"]
10.4 Graceful Shutdown
Rust
use axum::{routing::get, Router};
use tokio::signal;
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
let app = Router::new()
.route("/", get(|| async { "Hello" }))
.route("/health", get(|| async { "OK" }));
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
tracing::info!("server listening on http://0.0.0.0:3000");
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.await
.unwrap();
tracing::info!("server shut down gracefully");
}
async fn shutdown_signal() {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install SIGTERM handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => tracing::info!("received Ctrl+C"),
_ = terminate => tracing::info!("received SIGTERM"),
}
}
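What "graceful" buys you: once the shutdown future completes, the server stops accepting new connections but lets requests already in progress finish. This is also what lets docker stop, which sends SIGTERM and only force-kills after a grace period, end the container cleanly. The std-only sketch below models that behaviour with threads; Event, serve and the millisecond "work" values are hypothetical and say nothing about axum's internals.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::{self, Receiver};
use std::sync::Arc;
use std::thread;
use std::time::Duration;

/// Hypothetical events reaching the server loop, in arrival order.
enum Event {
    /// A request whose handler "works" for this many milliseconds.
    Request(u64),
    /// The shutdown signal (Ctrl+C / SIGTERM in the real server).
    Shutdown,
}

/// Accepts requests until the shutdown event, then stops accepting and
/// waits for every in-flight request. Returns (accepted, completed).
fn serve(events: Receiver<Event>) -> (usize, usize) {
    let completed = Arc::new(AtomicUsize::new(0));
    let mut in_flight = Vec::new();

    for event in events {
        match event {
            Event::Request(work_ms) => {
                let completed = Arc::clone(&completed);
                in_flight.push(thread::spawn(move || {
                    thread::sleep(Duration::from_millis(work_ms)); // simulated handler
                    completed.fetch_add(1, Ordering::SeqCst);
                }));
            }
            // Stop accepting: anything queued after this is never handled
            Event::Shutdown => break,
        }
    }

    let accepted = in_flight.len();
    // Drain phase: join all in-flight requests before returning
    for handle in in_flight {
        handle.join().expect("request thread panicked");
    }
    (accepted, completed.load(Ordering::SeqCst))
}

fn main() {
    let (tx, rx) = mpsc::channel();
    for event in [Event::Request(50), Event::Request(10), Event::Shutdown, Event::Request(5)] {
        tx.send(event).expect("server loop is gone");
    }
    drop(tx);
    let (accepted, completed) = serve(rx);
    println!("accepted {accepted}, completed {completed}");
}
```

Running main prints "accepted 2, completed 2": the request queued after the shutdown event is never accepted, while both in-flight requests still finish.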
11. Complete REST API Project Structure
11.1 Recommended Directory Layout
Text Only
my-api/
├── Cargo.toml
├── Dockerfile
├── docker-compose.yml
├── .env
├── migrations/
│ ├── 001_create_users.sql
│ └── 002_create_posts.sql
├── src/
│ ├── main.rs # entry point: starts the server
│ ├── config.rs # configuration management
│ ├── error.rs # unified error type
│ ├── db.rs # database connection
│ ├── routes/
│ │ ├── mod.rs # route aggregation
│ │ ├── users.rs # user routes
│ │ └── posts.rs # post routes
│ ├── handlers/
│ │ ├── mod.rs
│ │ ├── users.rs # user handlers
│ │ └── posts.rs # post handlers
│ ├── models/
│ │ ├── mod.rs
│ │ ├── user.rs # user model
│ │ └── post.rs # post model
│ ├── middleware/
│ │ ├── mod.rs
│ │ ├── auth.rs # authentication middleware
│ │ └── logging.rs # logging middleware
│ └── extractors/
│ ├── mod.rs
│ └── auth.rs # authentication extractor
└── tests/
└── api/
├── main.rs # test crate root (Cargo auto-discovers tests/*/main.rs, not mod.rs)
├── users_test.rs
└── posts_test.rs
11.2 Configuration
Rust
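// src/config.rs (uses the envy and dotenvy crates)
use serde::Deserialize;
// envy fills each field from the matching upper-case variable, e.g. DATABASE_URL -> database_url
#[derive(Debug, Clone, Deserialize)]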
pub struct Config {
pub database_url: String,
pub jwt_secret: String,
pub server_host: String,
pub server_port: u16,
pub rust_log: String,
}
impl Config {
pub fn from_env() -> Result<Self, envy::Error> {
// Load .env if present; .ok() ignores a missing file
dotenvy::dotenv().ok();
envy::from_env()
}
pub fn server_addr(&self) -> String {
format!("{}:{}", self.server_host, self.server_port)
}
}
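envy reads each field from the upper-case variable of the same name and fails when a required variable is missing or does not parse into the field type. The std-only sketch below imitates that behaviour without the crates so both failure modes are easy to see; from_vars is a hypothetical helper, not part of envy, and the struct mirrors the Config above.

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    pub database_url: String,
    pub jwt_secret: String,
    pub server_host: String,
    pub server_port: u16,
    pub rust_log: String,
}

impl Config {
    /// Hypothetical stand-in for envy::from_env: every field comes from the
    /// upper-case variable of the same name, and all fields are required.
    pub fn from_vars(vars: &HashMap<String, String>) -> Result<Self, String> {
        let get = |key: &str| {
            vars.get(key)
                .cloned()
                .ok_or_else(|| format!("missing environment variable {key}"))
        };
        let port = get("SERVER_PORT")?;
        Ok(Config {
            database_url: get("DATABASE_URL")?,
            jwt_secret: get("JWT_SECRET")?,
            server_host: get("SERVER_HOST")?,
            // Parsing into u16 rejects values such as "abc" or "70000"
            server_port: port
                .parse::<u16>()
                .map_err(|e| format!("invalid SERVER_PORT {port:?}: {e}"))?,
            rust_log: get("RUST_LOG")?,
        })
    }

    pub fn server_addr(&self) -> String {
        format!("{}:{}", self.server_host, self.server_port)
    }
}

fn main() {
    // A real binary would call dotenvy first, then read the process environment
    let vars: HashMap<String, String> = std::env::vars().collect();
    match Config::from_vars(&vars) {
        Ok(config) => println!("would listen on {}", config.server_addr()),
        Err(err) => eprintln!("config error: {err}"),
    }
}
```

Failing at startup with a named variable is the point of the design: a misconfigured container exits immediately instead of serving requests with a half-initialized config.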
11.3 Entry Point
Rust
// src/main.rs
mod config;
mod db;
mod error;
mod routes;
mod handlers;
mod models;
mod middleware;
mod extractors;
use config::Config;
use axum::Router;
use tower_http::trace::TraceLayer;
#[derive(Clone)]
pub struct AppState {
pub db: sqlx::PgPool,
pub config: Config,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
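// Load configuration
let config = Config::from_env()?;
// Initialize logging (with_env_filter requires tracing-subscriber's "env-filter" feature)
tracing_subscriber::fmt()
.with_env_filter(&config.rust_log)
.init();
// Connect to the database and apply pending migrations
let db = db::create_pool(&config.database_url).await?;
sqlx::migrate!("./migrations").run(&db).await?;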
let state = AppState {
db,
config: config.clone(),
};
// Build the application
let app = Router::new()
.merge(routes::all_routes())
.layer(TraceLayer::new_for_http())
.with_state(state);
// Start the server
let addr = config.server_addr();
let listener = tokio::net::TcpListener::bind(&addr).await?;
tracing::info!("server listening on http://{}", addr);
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.await?;
Ok(())
}
async fn shutdown_signal() {
tokio::signal::ctrl_c().await.ok();
tracing::info!("shutting down the server...");
}
11.4 Route Module
Rust
// src/routes/mod.rs
use axum::{
routing::{get, post}, // put/delete are chained as methods below, so only get/post are imported
Router,
middleware,
};
use crate::{AppState, handlers, middleware as mw};
pub fn all_routes() -> Router<AppState> {
Router::new()
.merge(public_routes())
.merge(protected_routes())
}
fn public_routes() -> Router<AppState> {
Router::new()
.route("/health", get(|| async { "OK" }))
.route("/api/login", post(handlers::users::login))
.route("/api/register", post(handlers::users::register))
}
fn protected_routes() -> Router<AppState> {
Router::new()
.route("/api/users", get(handlers::users::list))
.route("/api/users/{id}", get(handlers::users::get_by_id)
.put(handlers::users::update)
.delete(handlers::users::delete))
.route("/api/posts", get(handlers::posts::list)
.post(handlers::posts::create))
.route("/api/posts/{id}", get(handlers::posts::get_by_id)
.put(handlers::posts::update)
.delete(handlers::posts::delete))
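// route_layer (rather than layer) runs require_auth only for requests that match
// one of these routes, so unknown paths still get 404 instead of 401
.route_layer(middleware::from_fn(mw::auth::require_auth))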
}
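The public/protected split can be modelled without axum to show what a client observes: public paths are answered directly, protected paths must pass the auth check first, and paths that match no route get 404 without the auth check ever running. This is a std-only sketch under stated assumptions: dispatch, under and the Bearer-prefix test are hypothetical stand-ins for the router and require_auth, which in the real code would verify a JWT.

```rust
/// Paths served without authentication (mirrors public_routes()).
const PUBLIC: &[&str] = &["/health", "/api/login", "/api/register"];
/// Path prefixes behind the auth middleware (mirrors protected_routes()).
const PROTECTED: &[&str] = &["/api/users", "/api/posts"];

/// True if `path` is `prefix` itself or a sub-path such as "/api/users/42".
fn under(path: &str, prefix: &str) -> bool {
    path == prefix || path.strip_prefix(prefix).is_some_and(|rest| rest.starts_with('/'))
}

/// Hypothetical dispatcher returning (status, which route group answered).
fn dispatch(path: &str, authorization: Option<&str>) -> (u16, &'static str) {
    if PUBLIC.contains(&path) {
        return (200, "public");
    }
    if PROTECTED.iter().any(|prefix| under(path, prefix)) {
        // Stand-in for require_auth: only checks for a non-empty Bearer token
        return match authorization.and_then(|h| h.strip_prefix("Bearer ")) {
            Some(token) if !token.is_empty() => (200, "protected"),
            _ => (401, "unauthorized"),
        };
    }
    // No route matched, so the auth check never runs and this stays a 404
    (404, "not found")
}

fn main() {
    let requests = [
        ("/health", None),
        ("/api/users/7", None),
        ("/api/users/7", Some("Bearer abc")),
        ("/nope", None),
    ];
    for (path, auth) in requests {
        println!("{path} -> {:?}", dispatch(path, auth));
    }
}
```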
12. Interview Questions
Q1: Why choose Axum? How does it differ from Actix-web and Rocket?
A1:
Rust
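// Axum's core strengths:
// 1. Built on Tower: a large set of ready-made middleware can be reused
// (CompressionLayer needs a tower-http compression feature, e.g. "compression-gzip")
use tower_http::compression::CompressionLayer;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
// These middleware work with any Tower-compatible framework
use axum::{
extract::{Path, Query},
response::IntoResponse,
routing::{get, put},
Json, Router,
};
use serde::Deserialize;
// 2. Macro-free routing: no procedural macros required
fn app() -> Router {
Router::new()
.route("/users", get(list_users).post(create_user))
.route("/users/{id}", put(handler))
.layer(TraceLayer::new_for_http())
.layer(CorsLayer::permissive())
.layer(CompressionLayer::new())
}
// Compare Rocket: routes are declared with attribute macros such as #[get("/users")]
// Compare Actix-web: #[get("/users")] macros exist, but
// App::new().route("/users", web::get().to(list_users)) works without them
// 3. Extractor pattern: type-driven request parsing
async fn handler(
Path(id): Path<u64>, // from the URL path
Query(q): Query<Params>, // from the query string
Json(body): Json<Body>, // from the request body
) -> impl IntoResponse {
// Argument order is mostly free, with one rule: an extractor that
// consumes the body (such as Json) must be the last argument
format!("id={id}, page={:?}, name={}", q.page, body.name)
}
// 4. Maintained by the Tokio team and tightly integrated with the Tokio runtime
#[derive(Deserialize)]
struct Params {
page: Option<u32>,
}
#[derive(Deserialize)]
struct Body {
name: String,
}
async fn list_users() -> &'static str { "list users" }
async fn create_user() -> &'static str { "create user" }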
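| Feature | Axum | Actix-web | Rocket |
|---|---|---|---|
| Async runtime | Tokio | Tokio (one single-threaded runtime per worker thread) | Tokio |
| Routing macros | None | Optional (attribute macros or builder API) | Required (route attributes) |
| Middleware ecosystem | Tower | Its own | Its own |
| Maintainers | Tokio team | Community | Mainly its original author |
| Performance | Very high | Very high | High |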
Q2: How does Axum's extractor pattern work?
A2:
Rust
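use axum::{
extract::{Path, Query, Request},
http::{request::Parts, HeaderMap},
response::IntoResponse,
Json,
};
use std::collections::HashMap;
use std::future::Future;
// Extractors rest on two traits. Below are simplified copies of their axum 0.8 shape
// (0.8 uses native impl Future return types; the #[async_trait] macro is no longer used)
// 1. FromRequestParts: extracts from the request "parts" (URI, headers, extensions, ...)
// and does not consume the body. Path, Query, HeaderMap and State implement it
trait FromRequestPartsExample<S>: Sized {
type Rejection: IntoResponse; // error response returned when extraction fails
fn from_request_parts(
parts: &mut Parts,
state: &S,
) -> impl Future<Output = Result<Self, Self::Rejection>> + Send;
}
// 2. FromRequest: extracts from the whole request and may consume the body.
// Json, String and Bytes implement it
trait FromRequestExample<S>: Sized {
type Rejection: IntoResponse;
fn from_request(
req: Request,
state: &S,
) -> impl Future<Output = Result<Self, Self::Rejection>> + Send;
}
// A handler's argument list is a sequence of extractors.
// Axum runs them in order and only the last one may consume the body;
// if one fails, its Rejection is returned instead of calling the handler
async fn handler(
// FromRequestParts extractors (body untouched, any number of them)
Path(id): Path<u64>,
Query(q): Query<HashMap<String, String>>,
headers: HeaderMap,
// FromRequest extractor (consumes the body: at most one, and it must be last)
Json(body): Json<serde_json::Value>,
) -> String {
format!("id={}, query={:?}, headers={}, body={}", id, q, headers.len(), body)
}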
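The ordering rule follows from ownership: parts extractors only borrow the request, while the body extractor takes it by value, after which nothing else can read it. The std-only sketch below mimics how a multi-argument handler call is assembled; Parts, Request, FromParts, FromReq and call_handler are simplified hypothetical stand-ins rather than axum's types, and an Err plays the role of a Rejection.

```rust
use std::collections::HashMap;

/// Simplified request "parts": everything except the body.
struct Parts {
    path_id: u64,
    headers: HashMap<String, String>,
}

/// Simplified request: parts plus a body that can be consumed only once.
struct Request {
    parts: Parts,
    body: String,
}

/// Like FromRequestParts: borrows the parts, so any number may run.
trait FromParts: Sized {
    fn from_parts(parts: &Parts) -> Result<Self, String>;
}

/// Like FromRequest: takes the whole request by value and may consume the body.
trait FromReq: Sized {
    fn from_req(req: Request) -> Result<Self, String>;
}

struct Id(u64);
impl FromParts for Id {
    fn from_parts(parts: &Parts) -> Result<Self, String> {
        Ok(Id(parts.path_id))
    }
}

struct UserAgent(String);
impl FromParts for UserAgent {
    fn from_parts(parts: &Parts) -> Result<Self, String> {
        parts
            .headers
            .get("user-agent")
            .cloned()
            .map(UserAgent)
            .ok_or_else(|| "missing user-agent header".to_string())
    }
}

struct Body(String);
impl FromReq for Body {
    fn from_req(req: Request) -> Result<Self, String> {
        Ok(Body(req.body))
    }
}

/// Roughly what calling handler(Id, UserAgent, Body) involves: run the parts
/// extractors first, then hand the whole request to the last extractor.
/// The first Err short-circuits, like a Rejection becoming the response.
fn call_handler(req: Request) -> Result<String, String> {
    let Id(id) = Id::from_parts(&req.parts)?;
    let UserAgent(ua) = UserAgent::from_parts(&req.parts)?;
    let Body(body) = Body::from_req(req)?; // req is moved here: no extractor can follow
    Ok(format!("id={id}, ua={ua}, body={body}"))
}

fn main() {
    let headers = HashMap::from([("user-agent".to_string(), "curl".to_string())]);
    let req = Request { parts: Parts { path_id: 7, headers }, body: "{}".to_string() };
    println!("{:?}", call_handler(req));
}
```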
Q3: How do you implement unified error handling in Axum?
A3:
Rust
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde_json::json;
// The key: implement the IntoResponse trait
#[derive(Debug)]
enum AppError {
NotFound(String),
BadRequest(String),
Internal(anyhow::Error),
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
let (status, message) = match self {
AppError::NotFound(msg) => (StatusCode::NOT_FOUND, msg),
AppError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg),
AppError::Internal(err) => {
// Do not leak internal error details to the client
tracing::error!("internal error: {:?}", err);
(StatusCode::INTERNAL_SERVER_ERROR, "internal server error".into())
}
};
(status, Json(json!({"error": message}))).into_response()
}
}
// Implementing From<T> lets the ? operator convert errors automatically
impl From<sqlx::Error> for AppError {
fn from(err: sqlx::Error) -> Self {
match err {
sqlx::Error::RowNotFound => AppError::NotFound("record not found".into()),
_ => AppError::Internal(err.into()),
}
}
}
// Handlers return Result<T, AppError>
async fn handler() -> Result<Json<serde_json::Value>, AppError> {
// ? converts sqlx::Error into AppError through the From impl above
// let user = sqlx::query_as("SELECT ...").fetch_one(&db).await?;
Ok(Json(json!({"status": "ok"})))
}
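Stripped of the web types, the pattern is two conversions: From turns a lower-level error into AppError at each ?, and one function turns AppError into a status code plus body. The std-only sketch below shows both; DbError stands in for sqlx::Error, the (u16, String) tuple stands in for Response, and find_user is a hypothetical lookup. The body is built with format! and does no JSON escaping, which serde_json would handle in the real code.

```rust
/// Hypothetical stand-in for sqlx::Error.
#[derive(Debug)]
enum DbError {
    RowNotFound,
    Connection(String),
}

#[derive(Debug)]
enum AppError {
    NotFound(String),
    BadRequest(String),
    Internal(String),
}

// What ? uses to convert DbError into AppError
impl From<DbError> for AppError {
    fn from(err: DbError) -> Self {
        match err {
            DbError::RowNotFound => AppError::NotFound("record not found".into()),
            DbError::Connection(detail) => AppError::Internal(detail),
        }
    }
}

impl AppError {
    /// Stand-in for IntoResponse: (HTTP status, JSON body).
    fn into_response(self) -> (u16, String) {
        let (status, message) = match self {
            AppError::NotFound(msg) => (404, msg),
            AppError::BadRequest(msg) => (400, msg),
            AppError::Internal(detail) => {
                // Log the detail server-side, return only a generic message
                eprintln!("internal error: {detail}");
                (500, "internal server error".to_string())
            }
        };
        (status, format!(r#"{{"error":"{message}"}}"#))
    }
}

/// Hypothetical data access: id 1 exists, id 0 simulates a broken connection.
fn find_user(id: u64) -> Result<String, DbError> {
    match id {
        1 => Ok("alice".to_string()),
        0 => Err(DbError::Connection("pool timed out".to_string())),
        _ => Err(DbError::RowNotFound),
    }
}

/// Handler-shaped function returning Result<_, AppError>.
fn get_user(id: u64) -> Result<String, AppError> {
    if id > 1_000_000 {
        return Err(AppError::BadRequest("id out of range".to_string()));
    }
    let name = find_user(id)?; // DbError -> AppError via From
    Ok(name)
}

fn main() {
    for id in [1, 2, 0] {
        match get_user(id) {
            Ok(name) => println!("200 {name}"),
            Err(err) => println!("{:?}", err.into_response()),
        }
    }
}
```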
Q4: How do you manage shared state in Axum?
A4:
Rust
use axum::{extract::State, Router, routing::get};
use std::sync::Arc;
use tokio::sync::RwLock;
// Option 1: the State extractor (recommended)
#[derive(Clone)]
struct AppState {
// Immutable configuration: store it directly
db_url: String,
// Mutable state: wrap it in Arc + a lock
counter: Arc<RwLock<u64>>,
// Connection pool: already Arc-based internally, so cloning it shares the same pool
// db: sqlx::PgPool,
}
async fn handler(State(state): State<AppState>) -> String {
let count = state.counter.read().await;
format!("count: {}", *count)
}
fn app() -> Router {
let state = AppState {
db_url: "postgres://...".into(),
counter: Arc::new(RwLock::new(0)),
};
Router::new()
.route("/", get(handler))
.with_state(state)
// State requires Clone: the state is cloned each time a handler extracts it,
// so wrap large state in Arc to keep that clone cheap
}
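// Option 2: Extension (still supported, but State is preferred for application state)
// .layer(Extension(shared_data))
// Key difference:
// State: checked at compile time; a Router that still needs state cannot be served until with_state is called
// Extension: checked at runtime; if the extension is missing, the request fails with 500 Internal Server Error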
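Why State needs Clone, and why the counter sits behind Arc: every handler invocation gets its own clone of the state, so plain fields are copied while an Arc is only reference-counted and all clones see the same counter. The std-only sketch below uses threads in place of Tokio tasks and std::sync::RwLock in place of tokio::sync::RwLock; handle_request, simulate and clone_semantics are hypothetical helpers.

```rust
use std::sync::{Arc, RwLock};
use std::thread;

#[derive(Clone)]
struct AppState {
    db_url: String,            // deep-copied on every clone
    counter: Arc<RwLock<u64>>, // a clone only bumps the reference count
}

/// Hypothetical handler body: increments the shared counter.
fn handle_request(state: AppState) -> u64 {
    let mut count = state.counter.write().expect("lock poisoned");
    *count += 1;
    *count
}

/// Runs `requests` concurrent "requests", each with its own clone of the state.
fn simulate(requests: usize) -> u64 {
    let state = AppState {
        db_url: "postgres://...".to_string(),
        counter: Arc::new(RwLock::new(0)),
    };
    let workers: Vec<_> = (0..requests)
        .map(|_| {
            let per_request = state.clone();
            thread::spawn(move || handle_request(per_request))
        })
        .collect();
    for worker in workers {
        worker.join().expect("worker panicked");
    }
    // Bind first so the read guard is dropped before `state`
    let total = *state.counter.read().expect("lock poisoned");
    total
}

/// What a clone shares: (Arc strong count, whether db_url got its own buffer).
fn clone_semantics() -> (usize, bool) {
    let state = AppState {
        db_url: "postgres://...".to_string(),
        counter: Arc::new(RwLock::new(0)),
    };
    let copy = state.clone();
    let url_copied = state.db_url.as_ptr() != copy.db_url.as_ptr();
    (Arc::strong_count(&state.counter), url_copied)
}

fn main() {
    println!("count after 8 requests: {}", simulate(8));
    println!("{:?}", clone_semantics());
}
```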
Q5: In what order does Axum middleware run?
A5:
Rust
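use axum::{
extract::Request,
middleware::{self, Next},
response::Response,
routing::get,
Router,
};
use tower::ServiceBuilder;
// Three tiny middleware functions that log when a request passes through them
async fn layer_a(req: Request, next: Next) -> Response { println!("layer_a"); next.run(req).await }
async fn layer_b(req: Request, next: Next) -> Response { println!("layer_b"); next.run(req).await }
async fn layer_c(req: Request, next: Next) -> Response { println!("layer_c"); next.run(req).await }
// Layers follow the "onion model": each Router::layer call wraps everything
// added before it, so the LAST layer added is the OUTERMOST.
// Requests travel from the outside in; responses travel from the inside out.
fn with_router_layer() -> Router {
Router::new()
.route("/", get(handler))
.layer(middleware::from_fn(layer_a)) // innermost (closest to the handler)
.layer(middleware::from_fn(layer_b)) // middle
.layer(middleware::from_fn(layer_c)) // outermost
// Request flow: layer_c → layer_b → layer_a → handler
// Response flow: handler → layer_a → layer_b → layer_c
}
// ServiceBuilder applies its layers top to bottom, which reads more naturally
fn with_service_builder() -> Router {
Router::new()
.route("/", get(handler))
.layer(
ServiceBuilder::new()
.layer(middleware::from_fn(layer_a)) // runs first (outermost)
.layer(middleware::from_fn(layer_b))
.layer(middleware::from_fn(layer_c)), // runs last, closest to the handler
)
// Request flow: layer_a → layer_b → layer_c → handler
}
async fn handler() -> &'static str { "Hello" }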
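The ordering rule is easy to check with a toy model. With Router::layer, each call wraps the service built so far, so the last layer added is the first to see the request; a ServiceBuilder lists its layers outermost first. In the std-only sketch below, MiniRouter and its methods are hypothetical: layer imitates Router::layer and layers_in_order imitates the ServiceBuilder reading order.

```rust
/// A service takes a trace log and returns a response body.
type Service = Box<dyn Fn(&mut Vec<String>) -> String>;

/// Hypothetical toy router, used only to demonstrate layer ordering.
struct MiniRouter {
    service: Service,
}

impl MiniRouter {
    fn new(handler: fn() -> String) -> Self {
        MiniRouter {
            service: Box::new(move |log: &mut Vec<String>| {
                log.push("handler".to_string());
                handler()
            }),
        }
    }

    /// Like Router::layer: wraps everything added so far.
    fn layer(self, name: &'static str) -> Self {
        let inner = self.service;
        MiniRouter {
            service: Box::new(move |log: &mut Vec<String>| {
                log.push(format!("{name}:req")); // on the way in
                let response = inner(&mut *log);
                log.push(format!("{name}:resp")); // on the way out
                response
            }),
        }
    }

    /// Like ServiceBuilder: the first name listed becomes the outermost layer.
    fn layers_in_order(mut self, names: &[&'static str]) -> Self {
        for &name in names.iter().rev() {
            self = self.layer(name);
        }
        self
    }

    fn call(&self, log: &mut Vec<String>) -> String {
        (self.service)(log)
    }
}

fn main() {
    let app = MiniRouter::new(|| "Hello".to_string()).layer("a").layer("b").layer("c");
    let mut log = Vec::new();
    app.call(&mut log);
    println!("{}", log.join(" -> "));
}
```

Running main prints "c:req -> b:req -> a:req -> handler -> a:resp -> b:resp -> c:resp", the same onion order as three Router::layer calls.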
Q6: How does Axum implement WebSockets, and how do they differ from HTTP?
A6:
Rust
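// Requires axum's "ws" feature and the futures-util crate, e.g.
// axum = { version = "0.8", features = ["ws"] }
use axum::{
body::Bytes,
extract::ws::{Message, WebSocket, WebSocketUpgrade},
response::Response,
routing::get,
Router,
};
use futures_util::{SinkExt, StreamExt};
// WebSocket starts life as an HTTP upgrade:
// 1. The client sends an HTTP GET with an Upgrade: websocket header
// 2. The server replies 101 Switching Protocols
// 3. The connection becomes a full-duplex WebSocket
fn app() -> Router {
Router::new().route("/ws", get(ws_handler))
}
async fn ws_handler(ws: WebSocketUpgrade) -> Response {
// The WebSocketUpgrade extractor handles the protocol upgrade;
// on_upgrade runs the callback once the upgrade has completed
ws.on_upgrade(handle_socket)
}
async fn handle_socket(socket: WebSocket) {
// WebSocket is full duplex: split it into a sending half and a receiving half
let (mut tx, mut rx) = socket.split();
// Reading and writing can run in different tasks
tokio::spawn(async move {
while let Some(Ok(msg)) = rx.next().await {
match msg {
// In axum 0.8, text frames carry Utf8Bytes
Message::Text(t) => println!("received: {}", t.as_str()),
Message::Close(_) => break,
_ => {}
}
}
});
// Send a heartbeat ping every 30 seconds; stop once sending fails (peer gone)
loop {
// In axum 0.8, Ping carries Bytes rather than Vec<u8>
if tx.send(Message::Ping(Bytes::new())).await.is_err() {
break;
}
tokio::time::sleep(std::time::Duration::from_secs(30)).await;
}
}
// Key differences:
// HTTP: request-response, stateless
// WebSocket: full duplex, stateful, long-lived connection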
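Step 2 of the handshake carries a small proof: per RFC 6455, the 101 response includes a Sec-WebSocket-Accept header computed as base64(SHA-1(Sec-WebSocket-Key + a fixed GUID)), which shows the client that the server really understood the upgrade. WebSocketUpgrade and the underlying tungstenite library do this for you; the std-only sketch below re-implements it only to make the protocol step concrete. The hand-rolled sha1 and base64 functions are for illustration, not production use.

```rust
/// Fixed GUID from RFC 6455, section 1.3.
const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

/// Minimal SHA-1 (FIPS 180-4), for illustration only.
fn sha1(data: &[u8]) -> [u8; 20] {
    let mut h: [u32; 5] = [0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0];
    // Padding: 0x80, zeros up to 56 mod 64, then the bit length as big-endian u64
    let mut msg = data.to_vec();
    let bit_len = (data.len() as u64).wrapping_mul(8);
    msg.push(0x80);
    while msg.len() % 64 != 56 {
        msg.push(0);
    }
    msg.extend_from_slice(&bit_len.to_be_bytes());

    for block in msg.chunks(64) {
        // Message schedule: 16 words from the block, expanded to 80
        let mut w = [0u32; 80];
        for i in 0..16 {
            w[i] = u32::from_be_bytes([
                block[4 * i],
                block[4 * i + 1],
                block[4 * i + 2],
                block[4 * i + 3],
            ]);
        }
        for i in 16..80 {
            w[i] = (w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16]).rotate_left(1);
        }
        let [mut a, mut b, mut c, mut d, mut e] = h;
        for (i, &word) in w.iter().enumerate() {
            let (f, k): (u32, u32) = match i {
                0..=19 => ((b & c) | (!b & d), 0x5A827999),
                20..=39 => (b ^ c ^ d, 0x6ED9EBA1),
                40..=59 => ((b & c) | (b & d) | (c & d), 0x8F1BBCDC),
                _ => (b ^ c ^ d, 0xCA62C1D6),
            };
            let temp = a
                .rotate_left(5)
                .wrapping_add(f)
                .wrapping_add(e)
                .wrapping_add(k)
                .wrapping_add(word);
            e = d;
            d = c;
            c = b.rotate_left(30);
            b = a;
            a = temp;
        }
        for (hi, v) in h.iter_mut().zip([a, b, c, d, e]) {
            *hi = hi.wrapping_add(v);
        }
    }
    let mut out = [0u8; 20];
    for (i, v) in h.iter().enumerate() {
        out[4 * i..4 * i + 4].copy_from_slice(&v.to_be_bytes());
    }
    out
}

/// Standard base64 with padding (RFC 4648).
fn base64(bytes: &[u8]) -> String {
    const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    let mut out = String::new();
    for chunk in bytes.chunks(3) {
        let b0 = chunk[0] as u32;
        let b1 = chunk.get(1).copied().unwrap_or(0) as u32;
        let b2 = chunk.get(2).copied().unwrap_or(0) as u32;
        let n = (b0 << 16) | (b1 << 8) | b2;
        out.push(TABLE[((n >> 18) & 63) as usize] as char);
        out.push(TABLE[((n >> 12) & 63) as usize] as char);
        out.push(if chunk.len() > 1 { TABLE[((n >> 6) & 63) as usize] as char } else { '=' });
        out.push(if chunk.len() > 2 { TABLE[(n & 63) as usize] as char } else { '=' });
    }
    out
}

/// Value the server returns in the Sec-WebSocket-Accept header.
fn accept_key(client_key: &str) -> String {
    let mut input = client_key.as_bytes().to_vec();
    input.extend_from_slice(WS_GUID.as_bytes());
    base64(&sha1(&input))
}

fn main() {
    // Sample key from RFC 6455, section 1.3
    println!("{}", accept_key("dGhlIHNhbXBsZSBub25jZQ=="));
}
```

For the RFC's sample key this prints s3pPLMBiTxaQ9kYGzzhZRbK+xOo=, the accept value given in RFC 6455.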
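✅ Learning Checklist
Axum Basics
- Set up a basic Axum project
- Understand how Axum, Tower, Hyper and Tokio are layered
- Define routes (method routing, nested routes, wildcards)
Handlers and Extractors
- Use the Path / Query / Json / State extractors fluently
- Understand the difference between FromRequest and FromRequestParts
- Implement a custom extractor
Middleware
- Understand the Tower Layer / Service abstraction
- Use common tower-http middleware
- Write custom middleware functions
- Understand middleware execution order (the onion model)
Error Handling
- Implement a unified AppError type
- Implement the IntoResponse trait correctly
- Use the ? operator together with From conversions
Database
- Configure an SQLx connection pool
- Implement full CRUD operations
- Understand database migrations
Authentication
- Implement a JWT authentication flow
- Write an authentication extractor
- Separate public routes from protected routes
WebSocket
- Understand the WebSocket protocol upgrade
- Implement basic WebSocket communication
- Implement a broadcast / chat-room pattern
Testing and Deployment
- Write integration tests (tower::ServiceExt::oneshot)
- Write a Dockerfile (multi-stage build)
- Implement graceful shutdown
📌 Key takeaway: Axum's design philosophy is "composition over inheritance". Through Tower's Service/Layer abstractions and the type-driven extractor pattern, it catches many mistakes at compile time rather than at runtime, making web development both safe and efficient.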