How do I build web APIs with Axum?

Walkthrough

Axum is a modern, ergonomic web framework built on Hyper, Tower, and Tokio. It provides a clean API for building web applications with a focus on modularity and composability. Axum leverages Rust's type system to ensure correctness at compile time.

Key features:

  1. Extractors — automatically extract and validate data from requests
  2. Routing — declarative route definitions with method matching
  3. Middleware — composable layers via Tower services
  4. State sharing — thread-safe state access through Arc and extractors
  5. JSON handling — seamless serialization with serde integration

Axum's extractor system is its standout feature—parameters, bodies, headers, and state are all extracted type-safely.

Code Example

# Cargo.toml
[dependencies]
axum = "0.8"        # 0.8+ is required for the `/users/{id}` path syntax used below (0.7 used `:id`)
tokio = { version = "1", features = ["full"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
thiserror = "2"     # `#[derive(Error)]` in the Error Handling section

[dev-dependencies]
tower = { version = "0.5", features = ["util"] }  # `ServiceExt::oneshot` in the tests
use axum::{
    extract::{Path, Query, State},
    http::StatusCode,
    response::{IntoResponse, Json},
    routing::{get, post, put, delete},
    Router,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
 
// ===== Domain Types =====
 
/// A user record, as stored in the in-memory `Db` and serialized to JSON by the handlers.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct User {
    /// Unique id; also the key this user is stored under in `Db`.
    id: u64,
    name: String,
    email: String,
    /// `true` for newly created users; can be changed through `UpdateUser`.
    active: bool,
}
 
/// JSON body accepted by `create_user`.
/// The server assigns `id` itself and always starts the user as `active: true`.
#[derive(Debug, Deserialize)]
struct CreateUser {
    name: String,
    email: String,
}
 
/// JSON body accepted by `update_user` (a partial update).
/// Each field that is `Some` overwrites the stored value; `None` fields are left unchanged.
#[derive(Debug, Deserialize)]
struct UpdateUser {
    name: Option<String>,
    email: Option<String>,
    active: Option<bool>,
}
 
/// Query-string parameters for `list_users` (`?page=..&per_page=..`).
#[derive(Debug, Deserialize)]
struct Pagination {
    /// 1-based page number; falls back to `default_page()` when omitted.
    #[serde(default = "default_page")]
    page: u32,
    /// Maximum number of users returned per page; falls back to `default_per_page()` when omitted.
    #[serde(default = "default_per_page")]
    per_page: u32,
}
 
/// Serde default for `Pagination::page` when the query omits it (pages are 1-based).
fn default_page() -> u32 {
    const FIRST_PAGE: u32 = 1;
    FIRST_PAGE
}

/// Serde default for `Pagination::per_page` when the query omits it.
fn default_per_page() -> u32 {
    const PAGE_SIZE: u32 = 10;
    PAGE_SIZE
}
 
// ===== Application State =====
 
/// In-memory user store keyed by user id, shared behind `Arc` and guarded by tokio's
/// async `RwLock` (handlers acquire it with `.read().await` / `.write().await`).
/// Contents live only in process memory and are not persisted.
type Db = Arc<RwLock<HashMap<u64, User>>>;

/// Shared application state, extracted in handlers via `State<AppState>`.
/// The derived `Clone` only clones the `Arc`, so every clone refers to the same store.
#[derive(Clone)]
struct AppState {
    db: Db,
}
 
// ===== Handlers =====
 
/// `GET /` — plain-text greeting used as a liveness check for the API.
async fn root() -> &'static str {
    const WELCOME: &str = "Welcome to the User API!";
    WELCOME
}
 
/// `GET /users?page=..&per_page=..` — returns one page of users as a JSON array.
///
/// Pages are 1-based; `page=0` is treated like `page=1`. Users are ordered by id so
/// pagination is stable: `HashMap` iteration order is unspecified and may change
/// whenever the map is modified, which could otherwise repeat or skip users across pages.
async fn list_users(
    State(state): State<AppState>,
    Query(pagination): Query<Pagination>,
) -> impl IntoResponse {
    let per_page = pagination.per_page as usize;
    // Saturating math: `page - 1` underflows for page=0 (a panic in debug builds),
    // and a large `page * per_page` could overflow.
    let offset = (pagination.page.saturating_sub(1) as usize).saturating_mul(per_page);

    let db = state.db.read().await;
    let mut users: Vec<&User> = db.values().collect();
    users.sort_unstable_by_key(|user| user.id);

    // Clone the selected users so the response owns its data: returning `&User`
    // would borrow from the read guard, which is dropped when this function returns.
    let page: Vec<User> = users
        .into_iter()
        .skip(offset)
        .take(per_page)
        .cloned()
        .collect();

    Json(page)
}
 
async fn get_user(
    State(state): State<AppState>,
    Path(id): Path<u64>,
) -> impl IntoResponse {
    let db = state.db.read().await;
    
    match db.get(&id) {
        Some(user) => (StatusCode::OK, Json(Some(user.clone()))),
        None => (StatusCode::NOT_FOUND, Json(None)),
    }
}
 
async fn create_user(
    State(state): State<AppState>,
    Json(payload): Json<CreateUser>,
) -> impl IntoResponse {
    let mut db = state.db.write().await;
    
    let id = db.keys().max().unwrap_or(&0) + 1;
    let user = User {
        id,
        name: payload.name,
        email: payload.email,
        active: true,
    };
    
    db.insert(id, user.clone());
    
    (StatusCode::CREATED, Json(user))
}
 
async fn update_user(
    State(state): State<AppState>,
    Path(id): Path<u64>,
    Json(payload): Json<UpdateUser>,
) -> impl IntoResponse {
    let mut db = state.db.write().await;
    
    match db.get_mut(&id) {
        Some(user) => {
            if let Some(name) = payload.name {
                user.name = name;
            }
            if let Some(email) = payload.email {
                user.email = email;
            }
            if let Some(active) = payload.active {
                user.active = active;
            }
            (StatusCode::OK, Json(Some(user.clone())))
        }
        None => (StatusCode::NOT_FOUND, Json(None)),
    }
}
 
/// `DELETE /users/{id}` — removes a user and echoes the removed record.
///
/// Responds `200 OK` with the deleted user, or `404 Not Found` with a `null` JSON body.
async fn delete_user(
    State(state): State<AppState>,
    Path(id): Path<u64>,
) -> impl IntoResponse {
    let removed = state.db.write().await.remove(&id);
    let status = if removed.is_some() {
        StatusCode::OK
    } else {
        StatusCode::NOT_FOUND
    };
    (status, Json(removed))
}
 
// ===== Main =====
 
#[tokio::main]
async fn main() {
    // Initialize state with seed data
    let db = Arc::new(RwLock::new(HashMap::new()));
    db.write().await.insert(1, User {
        id: 1,
        name: "Alice".to_string(),
        email: "alice@example.com".to_string(),
        active: true,
    });
    
    let state = AppState { db };
 
    // Build router
    let app = Router::new()
        .route("/", get(root))
        .route("/users", get(list_users).post(create_user))
        .route("/users/{id}", get(get_user).put(update_user).delete(delete_user))
        .with_state(state);
 
    // Start server
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000").await.unwrap();
    println!("Server running on http://127.0.0.1:3000");
    
    axum::serve(listener, app).await.unwrap();
}

Error Handling

use axum::{
    response::{IntoResponse, Response},
    Json,
};
use serde_json::json;
use thiserror::Error;
 
#[derive(Error, Debug)]
enum ApiError {
    #[error("User not found")]
    NotFound,
    #[error("Invalid input: {0}")
    InvalidInput(String),
    #[error("Internal server error")
    Internal,
}
 
impl IntoResponse for ApiError {
    /// Renders the error as `{"error": <message>, "status": <code>}` with the matching status.
    fn into_response(self) -> Response {
        let status = match &self {
            ApiError::NotFound => StatusCode::NOT_FOUND,
            ApiError::InvalidInput(_) => StatusCode::BAD_REQUEST,
            ApiError::Internal => StatusCode::INTERNAL_SERVER_ERROR,
        };

        // InvalidInput exposes only the caller-supplied detail (without the
        // "Invalid input: " prefix); the other variants use their Display text.
        let message = match self {
            ApiError::InvalidInput(detail) => detail,
            other @ (ApiError::NotFound | ApiError::Internal) => other.to_string(),
        };

        let payload = json!({
            "error": message,
            "status": status.as_u16(),
        });
        (status, Json(payload)).into_response()
    }
}
 
// Handler returning Result
async fn get_user_or_error(
    State(state): State<AppState>,
    Path(id): Path<u64>,
) -> Result<Json<User>, ApiError> {
    let db = state.db.read().await;
    db.get(&id)
        .cloned()
        .map(Json)
        .ok_or(ApiError::NotFound)
}

Middleware with Tower

use axum::middleware::{self, Next};
use axum::response::Response;
use std::time::Instant;
 
/// Middleware that prints one line when a request arrives and one when its response
/// is ready, including the time spent in the inner handler stack.
async fn logging_middleware(
    request: axum::extract::Request,
    next: Next,
) -> Response {
    // Capture method and path up front: `request` is moved into `next.run` below.
    let (method, path) = (request.method().clone(), request.uri().path().to_owned());
    let started = Instant::now();

    println!("--> {} {}", method, path);
    let response = next.run(request).await;
    println!("<-- {} {} ( {:?} )", method, path, started.elapsed());

    response
}
 
// Apply to router.
// NOTE: this is a fragment, not a standalone item — a top-level `let` is not valid
// Rust. It belongs inside `main`, where `state` is the `AppState` built there and
// `root` is the handler defined earlier.
// NOTE(review): per axum's `Router::layer` docs, a layer wraps only the routes added
// before the `.layer(...)` call — confirm against your axum version.
let app = Router::new()
    .route("/", get(root))
    .layer(middleware::from_fn(logging_middleware))
    .with_state(state);

Testing Axum Handlers

#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::{to_bytes, Body};
    use axum::http::Request;
    use tower::ServiceExt; // for `oneshot` (needs tower's "util" feature)

    /// `GET /` succeeds and returns the plain-text greeting.
    #[tokio::test]
    async fn test_root() {
        let app = Router::new().route("/", get(root));

        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let body = to_bytes(response.into_body(), usize::MAX).await.unwrap();
        assert_eq!(&body[..], b"Welcome to the User API!");
    }

    /// `GET /users/{id}` for an unknown id responds with 404.
    #[tokio::test]
    async fn test_get_missing_user_returns_404() {
        let state = AppState { db: Db::default() };
        let app = Router::new()
            .route("/users/{id}", get(get_user))
            .with_state(state);

        let response = app
            .oneshot(Request::builder().uri("/users/42").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }
}

Summary

  • Define handlers as async functions that return impl IntoResponse — Json, String, and tuples of (StatusCode, Body) all work
  • Use extractors like Path<T>, Query<T>, Json<T>, and State<T> to declaratively parse request data
  • Router::new().route("/path", get(handler).post(handler)) defines RESTful endpoints
  • Path parameters use {name} syntax (axum 0.8+; earlier versions used :name); extract with Path(id): Path<u64> or Path(params): Path<HashMap<String, String>>
  • Share application state via .with_state(state) and extract with State(state): State<AppState>
  • Implement IntoResponse for custom error types to return structured error JSON
  • Add middleware with .layer(middleware::from_fn(my_middleware)) — Tower provides many built-in layers
  • Test handlers using app.oneshot(request) from tower::ServiceExt (requires tower's util feature)
  • Axum integrates seamlessly with serde for JSON serialization and tokio for async runtime