Rust walkthroughs
How does axum::Extension share application state across handlers without explicit parameters?

Extension uses Axum's request extension mechanism: every HTTP request carries a type-keyed map (conceptually a HashMap<TypeId, Box<dyn Any>>), and handlers extract shared state from it through the FromRequestParts trait. You insert the state at the router level with .layer(Extension(state)), and it becomes available to any handler that lists Extension<YourType> among its parameters. This gives you dependency injection without threading references through every intermediate function signature.
use axum::{
Extension,
Router,
routing::get,
};
use std::sync::Arc;
// Shared state type
#[derive(Clone)]
struct AppState {
db_pool: String, // In reality: sqlx::PgPool
api_key: String,
}
async fn handler(Extension(state): Extension<Arc<AppState>>) -> String {
format!("API Key: {}", state.api_key)
}
#[tokio::main]
async fn main() {
let state = Arc::new(AppState {
db_pool: "postgres://...".to_string(),
api_key: "secret-key".to_string(),
});
let app: Router = Router::new()
.route("/api", get(handler))
.layer(Extension(state));
// The Extension layer makes Arc<AppState> available to handlers
// Run server: axum::serve(listener, app) on axum 0.7+
// (axum::Server::bind(...).serve(app.into_make_service()) on axum 0.6)
}

Extension wraps your state and makes it extractable via handler parameters.
use axum::{Extension, Router, routing::get};
use std::sync::Arc;
// Simplified internal structure (conceptual):
//
// HTTP Request contains:
// struct Request<B> {
// method: Method,
// uri: Uri,
// headers: HeaderMap,
// version: Version,
// body: B,
// extensions: Extensions, // <- Extension data stored here
// }
//
// Extensions is approximately:
// pub struct Extensions {
// map: HashMap<TypeId, Box<dyn Any>>,
// }
#[derive(Clone)]
struct Database {
url: String,
}
async fn db_handler(Extension(db): Extension<Database>) -> String {
    // FromRequestParts impl for Extension<T> (conceptually):
    // 1. Reads the Extensions map from the request parts
    // 2. Looks up the entry keyed by TypeId::of::<T>()
    // 3. Downcasts the type-erased value back to &T
    // 4. Returns a clone of T (or a rejection if it is missing)
    db.url
}
#[tokio::main]
async fn main() {
let db = Database { url: "postgres://localhost".to_string() };
// .layer(Extension(db)) adds middleware that:
// 1. Clones db for each request (or uses Arc)
// 2. Inserts db into request.extensions
// 3. Passes modified request to inner layers/handler
let app: Router = Router::new()
.route("/db", get(db_handler))
.layer(Extension(db));
}

Extensions are stored by TypeId in a type-erased map, allowing multiple extension types.
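To make the type-erased storage concrete, here is a minimal standard-library sketch of a map keyed by TypeId. The names TypeMap and type_map_demo are hypothetical, and the real http::Extensions type differs in its details (for example in the bounds it places on stored values); the sketch only illustrates the insert, lookup, and downcast steps described above.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

/// Hypothetical stand-in for the request's extension map: one slot per concrete type.
#[derive(Default)]
struct TypeMap {
    map: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}

impl TypeMap {
    /// Stores `value` under its TypeId, returning the previous value of the same type, if any.
    fn insert<T: Send + Sync + 'static>(&mut self, value: T) -> Option<T> {
        self.map
            .insert(TypeId::of::<T>(), Box::new(value))
            .and_then(|old| old.downcast::<T>().ok())
            .map(|boxed| *boxed)
    }

    /// Looks up the slot for `T` and downcasts it back to a typed reference.
    fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        self.map
            .get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
    }
}

struct Database {
    url: String,
}

struct Config {
    max_connections: usize,
}

/// Inserts two unrelated types and reads them back; a type that was never inserted is absent.
fn type_map_demo() -> (String, usize, bool) {
    let mut extensions = TypeMap::default();
    extensions.insert(Database { url: "postgres://localhost".to_string() });
    extensions.insert(Config { max_connections: 10 });

    let url = extensions.get::<Database>().map(|db| db.url.clone()).unwrap_or_default();
    let max = extensions.get::<Config>().map(|c| c.max_connections).unwrap_or(0);
    let missing = extensions.get::<u64>().is_none();
    (url, max, missing)
}
```

Because the key is the type itself, inserting a second value of the same type replaces the first. This is why distinct wrapper types (newtypes) are the usual way to store two values that would otherwise share a type.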
use axum::{Extension, Router, routing::get};
use std::sync::Arc;
#[derive(Clone)]
struct Database {
url: String,
}
#[derive(Clone)]
struct Config {
max_connections: usize,
}
#[derive(Clone)]
struct AuthService {
secret: String,
}
// Handlers can request multiple extensions
async fn complex_handler(
Extension(db): Extension<Database>,
Extension(config): Extension<Config>,
Extension(auth): Extension<AuthService>,
) -> String {
format!(
"DB: {}, Max Conn: {}, Auth: {}",
db.url, config.max_connections, auth.secret
)
}
// Or extract just the ones needed
async fn simple_handler(
Extension(config): Extension<Config>,
) -> String {
format!("Max connections: {}", config.max_connections)
}
#[tokio::main]
async fn main() {
let db = Database { url: "postgres://...".to_string() };
let config = Config { max_connections: 10 };
let auth = AuthService { secret: "secret".to_string() };
// Add multiple extensions
let app: Router = Router::new()
.route("/complex", get(complex_handler))
.route("/simple", get(simple_handler))
.layer(Extension(db))
.layer(Extension(config))
.layer(Extension(auth));
// Each handler extracts only what it needs
}

Multiple Extension layers add different types, all accessible via Extension<T>.
use axum::{Extension, Router, routing::get};
use std::sync::Arc;
use tokio::sync::RwLock;
struct AppState {
counter: RwLock<u64>,
}
async fn increment(
Extension(state): Extension<Arc<AppState>>,
) -> String {
let mut counter = state.counter.write().await;
*counter += 1;
format!("Counter: {}", *counter)
}
async fn read_counter(
Extension(state): Extension<Arc<AppState>>,
) -> String {
let counter = state.counter.read().await;
format!("Counter: {}", *counter)
}
#[tokio::main]
async fn main() {
let state = Arc::new(AppState {
counter: RwLock::new(0),
});
// Arc ensures all requests share the same AppState instance
let app: Router = Router::new()
.route("/increment", get(increment))
.route("/read", get(read_counter))
.layer(Extension(state));
}

Wrap state in Arc to share a single instance across all requests.
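The sharing behaviour can be checked without a web server. The following sketch is an illustration under stated assumptions, not the handler code above: it swaps in std::sync::RwLock and OS threads for tokio's RwLock and concurrent requests, and the function name shared_counter_demo is hypothetical. Each worker holds its own Arc handle, yet all of them update the same counter.

```rust
use std::sync::{Arc, RwLock};
use std::thread;

struct AppState {
    counter: RwLock<u64>,
}

/// Spawns `workers` threads that each increment the shared counter `increments` times,
/// then returns the final value.
fn shared_counter_demo(workers: usize, increments: u64) -> u64 {
    let state = Arc::new(AppState { counter: RwLock::new(0) });

    let handles: Vec<_> = (0..workers)
        .map(|_| {
            // Cloning the Arc copies a pointer; every worker sees the same AppState.
            let state = Arc::clone(&state);
            thread::spawn(move || {
                for _ in 0..increments {
                    *state.counter.write().unwrap() += 1;
                }
            })
        })
        .collect();

    for handle in handles {
        handle.join().unwrap();
    }

    let total = *state.counter.read().unwrap();
    total
}
```

Calling shared_counter_demo(4, 100) returns 400: no increment is lost, because the write lock serializes access to the single shared value.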
use axum::{
    extract::FromRequestParts,
    http::{request::Parts, StatusCode},
    Extension,
};
use std::sync::Arc;

// Extension<T> implements FromRequestParts (and therefore FromRequest),
// so it can be used as a handler argument or inside other extractors.
#[derive(Clone)]
struct User {
    name: String,
}

// Custom extractor that uses Extension internally
struct AuthenticatedUser {
    name: String,
}

// Written for axum 0.8, where extractor traits accept a plain async fn;
// on axum 0.6/0.7, add #[async_trait] to the impl.
impl<S> FromRequestParts<S> for AuthenticatedUser
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, &'static str);

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        // Reuse the Extension extractor to fetch Arc<User> from the request parts
        let Extension(user) = Extension::<Arc<User>>::from_request_parts(parts, state)
            .await
            .map_err(|_| (StatusCode::UNAUTHORIZED, "Unauthorized"))?;
        Ok(AuthenticatedUser { name: user.name.clone() })
    }
}

async fn protected(user: AuthenticatedUser) -> String {
    format!("Hello, {}", user.name)
}

Extension implements FromRequestParts, enabling extraction in handler parameters.
use axum::{routing::get, Extension, Router};
use tower::ServiceBuilder;

#[derive(Clone)]
struct RequestId(String);

#[derive(Clone)]
struct Logger;

async fn handler(
    Extension(request_id): Extension<RequestId>,
    Extension(_logger): Extension<Logger>,
) -> String {
    request_id.0
}

#[tokio::main]
async fn main() {
    let request_id = RequestId("req-123".to_string());
    let logger = Logger;

    // .layer() only wraps routes that were added BEFORE the call.
    // With repeated .layer() calls, the last one added is outermost and sees the request first.
    let app: Router = Router::new()
        .route("/handler", get(handler))
        .layer(Extension(request_id.clone())) // innermost: runs just before the handler
        .layer(Extension(logger.clone())); // outermost: runs first

    // By the time the handler runs, request.extensions contains:
    // - TypeId(RequestId) -> RequestId { ... }
    // - TypeId(Logger) -> Logger
    // The relative order of two Extension layers with different types does not change the result.

    // Wrong order: the route is added after the layer, so the handler
    // never sees Extension<Logger> and extraction is rejected:
    // Router::new()
    //     .layer(Extension(logger))
    //     .route("/handler", get(handler))

    // ServiceBuilder applies layers top to bottom, which is often easier to read:
    let app_with_middleware: Router = Router::new()
        .route("/handler", get(handler))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(request_id))
                .layer(Extension(logger)),
        );
}

Extensions are added to requests as they flow through the middleware stack.
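The wrapping order of repeated .layer() calls can be modelled with plain closures. This is only a sketch of the ordering: Service, layer, and layer_order_demo are hypothetical names, and none of Tower's actual traits are involved. Each call to layer wraps whatever service already exists, so the layer added last ends up outermost and runs first.

```rust
/// A "service" here is just a boxed function that records its name in a log.
type Service = Box<dyn Fn(&mut Vec<&'static str>)>;

/// Wraps `inner` so that `name` is recorded before the inner service runs.
fn layer(name: &'static str, inner: Service) -> Service {
    Box::new(move |log: &mut Vec<&'static str>| {
        log.push(name);
        inner(log);
    })
}

/// Builds handler -> first -> second (in the order they would be added)
/// and returns the order in which they actually run.
fn layer_order_demo() -> Vec<&'static str> {
    let handler: Service = Box::new(|log: &mut Vec<&'static str>| log.push("handler"));

    // Analogous to `.layer(first).layer(second)`: each call wraps what is already there.
    let with_first = layer("first", handler);
    let with_second = layer("second", with_first);

    let mut log = Vec::new();
    with_second(&mut log);
    log
}
```

The returned log is ["second", "first", "handler"]. For two Extension layers carrying different types this ordering has no visible effect, but it matters as soon as one middleware depends on a value inserted by another.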
use axum::{Extension, Router, routing::get, extract::State};
#[derive(Clone)]
struct AppState {
db: String,
}
// Approach 1: Extension
async fn with_extension(
Extension(state): Extension<AppState>,
) -> String {
state.db.clone()
}
// Approach 2: State
async fn with_state(State(state): State<AppState>) -> String {
state.db.clone()
}
#[tokio::main]
async fn main() {
let state = AppState { db: "postgres://...".to_string() };
// Extension approach
let app_extension: Router = Router::new()
.route("/ext", get(with_extension))
.layer(Extension(state.clone()));
// State approach
let app_state: Router = Router::new()
.route("/state", get(with_state))
.with_state(state);
// Key differences:
// - Extension: inserted by a Tower layer or middleware; any number of types;
//   a missing type is only detected at runtime (the request is rejected)
// - State: set with with_state; one state type per router (sub-states via FromRef);
//   a missing state is a compile error
// Both require the stored type to be Clone
}

State is router-specific; Extension works with the broader Tower middleware ecosystem.
use axum::{Extension, Router, routing::get};
use std::sync::Arc;
// Extension requires Clone because:
// 1. Each request needs access to the state
// 2. The extension layer clones the value for each request
// 3. This is why Arc is often used for non-trivial types
#[derive(Clone)]
struct Config {
api_key: String,
}
// A type that is not Clone won't work:
// struct DatabaseConnection { /* no Clone impl */ }
// Extension<DatabaseConnection> would fail to compile;
// wrap it in Arc instead: Extension<Arc<DatabaseConnection>>
async fn handler(Extension(config): Extension<Config>) -> String {
config.api_key.clone()
}
#[tokio::main]
async fn main() {
let config = Config { api_key: "secret".to_string() };
// This clones config for each request
// If config is expensive to clone, use Arc
let app: Router = Router::new()
.route("/api", get(handler))
.layer(Extension(config));
// Better for non-trivial types:
let config_arc = Arc::new(Config { api_key: "secret".to_string() });
let app_better: Router = Router::new()
// Closure handlers must return a future, hence `async move`
.route("/api", get(|Extension(config): Extension<Arc<Config>>| async move {
config.api_key.clone()
}))
.layer(Extension(config_arc));
}

Extension clones for each request; wrap expensive types in Arc.
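A small standard-library sketch shows why the Arc wrapper makes the per-request clone cheap. The names clone_per_request and arc_clone_demo are hypothetical; the sketch only models "one clone per request" and does not involve axum. Every clone is a new handle to the same allocation, so the String inside Config is never copied.

```rust
use std::sync::Arc;

struct Config {
    api_key: String,
}

/// Models the per-request clone: hands out `requests` handles to one shared Config.
fn clone_per_request(shared: &Arc<Config>, requests: usize) -> Vec<Arc<Config>> {
    (0..requests).map(|_| Arc::clone(shared)).collect()
}

/// Returns (number of live handles, whether all handles point at the same allocation,
/// the key as seen through one of the cloned handles).
fn arc_clone_demo() -> (usize, bool, String) {
    let shared = Arc::new(Config { api_key: "secret".to_string() });
    let handles = clone_per_request(&shared, 3);

    let same_allocation = handles.iter().all(|h| Arc::ptr_eq(h, &shared));
    let key_seen_by_request = handles[0].api_key.clone();
    (Arc::strong_count(&shared), same_allocation, key_seen_by_request)
}
```

With three simulated requests the strong count is 4 (the original plus three clones) and every handle passes Arc::ptr_eq against the original, which confirms that only the reference count changed.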
use axum::{
    extract::Request,
    middleware::{self, Next},
    response::Response,
    routing::get,
    Extension, Router,
};

#[derive(Clone)]
struct RequestContext {
    request_id: String,
}

// Custom middleware that adds an extension
// (signature expected by middleware::from_fn on axum 0.7+)
async fn context_middleware(mut request: Request, next: Next) -> Response {
    // Generate a request ID (it could also come from a header);
    // requires the uuid crate with the "v4" feature
    let request_id = uuid::Uuid::new_v4().to_string();
    // Add context to request extensions
    request.extensions_mut().insert(RequestContext { request_id });
    // Call the rest of the stack
    next.run(request).await
}

async fn handler(Extension(ctx): Extension<RequestContext>) -> String {
    ctx.request_id
}

#[tokio::main]
async fn main() {
    let app: Router = Router::new()
        .route("/handler", get(handler))
        .layer(middleware::from_fn(context_middleware));
    // Middleware added the extension, handler extracted it
}

Middleware can add extensions directly via request.extensions_mut().insert().
use axum::{extract::Request, middleware::Next, response::Response, Extension};

// Some extensions are app-scoped (shared across requests),
// others are request-scoped (created anew for each request)
#[derive(Clone)]
struct AppConfig {
    version: String, // App-scoped: same for all requests
}

// Extension<T> requires T: Clone even for request-scoped values,
// because the extractor hands the handler a clone of the stored value
#[derive(Clone)]
struct RequestId {
    id: String, // Request-scoped: unique per request
}

// RequestId is not added with .layer(Extension(..)), which would give every
// request the same value. Instead, generate it in middleware:
async fn request_id_middleware(mut request: Request, next: Next) -> Response {
    // Create a new RequestId for this request
    let request_id = RequestId { id: uuid::Uuid::new_v4().to_string() };
    request.extensions_mut().insert(request_id);
    next.run(request).await
}

// Handlers can then extract RequestId
async fn handler(
    Extension(config): Extension<AppConfig>, // Shared across requests
    Extension(req_id): Extension<RequestId>, // Unique per request
) -> String {
    format!("Version: {}, Request: {}", config.version, req_id.id)
}

App-scoped state is added once with .layer(Extension(..)), typically wrapping an Arc<T>; request-scoped values are inserted by middleware for each request.
use axum::{
    extract::{rejection::ExtensionRejection, FromRequestParts},
    http::{request::Parts, StatusCode},
    Extension,
};
use std::sync::Arc;

#[derive(Clone)]
struct Database;

// If the extension is missing, the Extension extractor rejects the request
// and axum responds with 500 Internal Server Error
async fn db_handler(Extension(_db): Extension<Arc<Database>>) -> String {
    "Connected".to_string()
}

// Custom error handling: wrap the extractor in Result to inspect the rejection
async fn safe_handler(
    extension_result: Result<Extension<Arc<Database>>, ExtensionRejection>,
) -> Result<String, StatusCode> {
    match extension_result {
        Ok(Extension(_db)) => Ok("Connected".to_string()),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

// Or use a custom extractor:
struct RequiredDatabase(Arc<Database>);

impl<S> FromRequestParts<S> for RequiredDatabase
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, &'static str);

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let Extension(db) = Extension::<Arc<Database>>::from_request_parts(parts, state)
            .await
            .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database not configured"))?;
        Ok(RequiredDatabase(db))
    }
}

Missing extensions return ExtensionRejection; custom extractors can provide better error messages.
use axum::{
Extension, Router, routing::{get, post},
Json,
};
use serde::Serialize;
use std::sync::Arc;
use tokio::sync::RwLock;
#[derive(Clone)]
struct AppConfig {
app_name: String,
version: String,
}
struct AppState {
config: AppConfig,
// In reality: database pools, caches, etc.
}
#[derive(Serialize)]
struct StatusResponse {
app_name: String,
version: String,
active_requests: u64,
}
async fn get_status(
Extension(state): Extension<Arc<AppState>>,
) -> Json<StatusResponse> {
Json(StatusResponse {
app_name: state.config.app_name.clone(),
version: state.config.version.clone(),
active_requests: 0, // Would track with atomic counter
})
}
async fn health_check(
Extension(config): Extension<AppConfig>,
) -> &'static str {
// Can also extract just the config if AppState has multiple fields
"OK"
}
#[tokio::main]
async fn main() {
let config = AppConfig {
app_name: "MyApp".to_string(),
version: "1.0.0".to_string(),
};
let state = Arc::new(AppState { config: config.clone() });
let app: Router = Router::new()
.route("/status", get(get_status))
.route("/health", get(health_check))
.layer(Extension(state)) // Arc<AppState> available to all routes
.layer(Extension(config)); // AppConfig also available
// Both extensions are available to any handler
}

Multiple extension types can be layered; handlers extract what they need.
Extension mechanism flow:
┌─────────────────────────────────────────────────────────────┐
│ Incoming Request │
└─────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ Extension Layer (.layer(Extension(state))) │
│ 1. Clones state (or uses Arc clone) │
│ 2. Inserts into request.extensions by TypeId │
└─────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ Handler Function │
│ fn handler(Extension(state): Extension<T>) -> ... │
│ │
│ FromRequestParts implementation: │
│ 1. Access request.extensions │
│ 2. Look up by TypeId::of::<T>() │
│ 3. Downcast to T │
│ 4. Return T (or rejection if missing) │
└─────────────────────────────────────────────────────────────┘
Extension vs State:
| Aspect | Extension | State |
|--------|-----------|-------|
| Origin | Tower middleware layer | Router method |
| Quantity | Multiple types | Single type per router |
| Flexibility | Works with any Tower layer | Router-specific |
| Type safety | Dynamic (rejection on missing) | Static (must be provided) |
| Middleware compatibility | Full | Limited |
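Key characteristics:

- Values are stored in the request's extension map, keyed by TypeId
- The stored type must be Clone (use Arc for efficiency)
- Extraction fails with a rejection when Extension<T> is not found

When to use Extension:

- Values inserted by middleware, such as request-scoped data
- Several independent types that handlers pick from as needed
- Shared application state (wrapped in Arc)

When to use State:

- A single application state type per router, where a missing state should be caught at compile time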
Key insight: Extension is Axum's way of doing dependency injection through the request extension mechanism. Rather than threading shared state through every function parameter, you add it to the router as middleware, and the FromRequestParts implementation for Extension<T> extracts a clone of it from the request's extension map. The Clone bound is enforced at compile time, but whether a value of the requested type was actually inserted is only checked at runtime: a missing extension produces a rejection rather than a compile error. The TypeId-keyed storage allows multiple independent extension types to coexist. This pattern decouples handler logic from state management: handlers declare what they need via Extension<T>, and the router configuration is responsible for satisfying those dependencies.