How do I use async traits with async-trait in Rust?

Walkthrough

The async-trait crate provides a macro that allows you to define async functions in traits. In Rust, async functions in traits have historically been challenging because async functions return opaque impl Future types, which couldn't be expressed in trait definitions. The async_trait macro solves this by boxing the futures, making them concrete types. While Rust 1.75+ now supports native async fn in traits, async-trait remains useful for complex scenarios, trait objects (dyn Trait), and compatibility with older Rust versions.

Key concepts:

  1. Macro placement — #[async_trait] goes on both trait and impl blocks
  2. Boxed futures — async methods return Pin<Box<dyn Future>>
  3. Trait objects — enables dyn Trait with async methods
  4. Send requirement — futures are Send by default, use ?Send for non-Send
  5. Native alternative — Rust 1.75+ supports async fn in traits natively (with limitations)

Code Example

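# Cargo.toml
[dependencies]
async-trait = "0.1"
tokio = { version = "1", features = ["full"] }  # runtime used by #[tokio::main] and tokio::time::sleep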
use async_trait::async_trait;
 
/// A minimal data-access abstraction with a single async lookup.
#[async_trait]
trait Database {
    /// Looks up the user with the given `id`; `None` signals no match.
    async fn get_user(&self, id: u64) -> Option<String>;
}

/// Stand-in database that fabricates a name for any id it is asked about.
struct MockDb;

#[async_trait]
impl Database for MockDb {
    async fn get_user(&self, id: u64) -> Option<String> {
        // The mock never misses: every id maps to a synthesized name.
        let name = format!("User {}", id);
        Some(name)
    }
}
 
#[tokio::main]
async fn main() {
    let db = MockDb;
    // Await the boxed future first, then print only on a hit.
    let lookup = db.get_user(42).await;
    if let Some(found) = lookup {
        println!("Found: {}", found);
    }
}

Basic Async Trait Definition

use async_trait::async_trait;
 
// `#[async_trait]` goes on the trait definition...
#[async_trait]
pub trait Reader {
    /// Produces this reader's contents asynchronously.
    async fn read(&self) -> String;

    /// Plain synchronous methods can sit next to async ones; this one has a
    /// default body that implementors inherit.
    fn sync_info(&self) -> &str {
        "reader"
    }
}

/// Reader that pretends to load a file located at `path`.
struct FileReader {
    path: String,
}

// ...and on every impl block of that trait.
#[async_trait]
impl Reader for FileReader {
    async fn read(&self) -> String {
        // A short sleep stands in for real asynchronous file I/O.
        let simulated_io = std::time::Duration::from_millis(10);
        tokio::time::sleep(simulated_io).await;
        format!("Contents of {}", self.path)
    }
}

/// Reader backed by a string already held in memory.
struct MemoryReader {
    data: String,
}

#[async_trait]
impl Reader for MemoryReader {
    async fn read(&self) -> String {
        // Nothing to wait for; hand back an owned copy of the buffer.
        let snapshot = self.data.clone();
        snapshot
    }
}
 
#[tokio::main]
async fn main() {
    let file_reader = FileReader {
        path: String::from("test.txt"),
    };
    let mem_reader = MemoryReader {
        data: String::from("memory data"),
    };

    // Both concrete readers are driven through the same trait method.
    let from_file = file_reader.read().await;
    println!("File: {}", from_file);

    let from_memory = mem_reader.read().await;
    println!("Memory: {}", from_memory);
}

Multiple Async Methods

use async_trait::async_trait;
 
/// CRUD-style storage of strings keyed by numeric id; every operation is async.
#[async_trait]
trait Repository {
    async fn find_by_id(&self, id: u64) -> Option<String>;
    async fn find_all(&self) -> Vec<String>;
    async fn save(&mut self, id: u64, value: &str) -> bool;
    async fn delete(&mut self, id: u64) -> bool;
}

/// Repository backed by a plain `HashMap`; the mutating methods take `&mut self`.
struct InMemoryRepo {
    data: std::collections::HashMap<u64, String>,
}

#[async_trait]
impl Repository for InMemoryRepo {
    async fn find_by_id(&self, id: u64) -> Option<String> {
        simulated_latency().await;
        self.data.get(&id).cloned()
    }

    async fn find_all(&self) -> Vec<String> {
        simulated_latency().await;
        // Order follows the HashMap's iteration order.
        self.data.values().cloned().collect()
    }

    async fn save(&mut self, id: u64, value: &str) -> bool {
        simulated_latency().await;
        // Insert-or-overwrite; any previous value is discarded and success is always reported.
        let _previous = self.data.insert(id, value.to_owned());
        true
    }

    async fn delete(&mut self, id: u64) -> bool {
        simulated_latency().await;
        // `true` only when an entry actually existed.
        let removed = self.data.remove(&id);
        removed.is_some()
    }
}

/// Sleeps for one millisecond to imitate a database round-trip.
async fn simulated_latency() {
    tokio::time::sleep(std::time::Duration::from_millis(1)).await;
}
 
#[tokio::main]
async fn main() {
    let mut repo = InMemoryRepo {
        data: std::collections::HashMap::new(),
    };

    // Seed two records; `save` always reports success, so its flag is ignored.
    for (id, name) in [(1, "Alice"), (2, "Bob")] {
        repo.save(id, name).await;
    }

    // Single lookup by id.
    if let Some(found) = repo.find_by_id(1).await {
        println!("Found: {}", found);
    }

    // Everything currently stored.
    let everyone = repo.find_all().await;
    println!("All: {:?}", everyone);

    // Removal reports whether the id was present.
    let removed = repo.delete(1).await;
    if removed {
        println!("Deleted user 1");
    }
}

Using Trait Objects (dyn Trait)

use async_trait::async_trait;
 
/// Turns a textual request into a textual response.
#[async_trait]
trait Handler {
    async fn handle(&self, request: &str) -> String;
}

/// Handler that echoes the request to stdout before acknowledging it.
struct LogHandler;

#[async_trait]
impl Handler for LogHandler {
    async fn handle(&self, request: &str) -> String {
        println!("Logging: {}", request);
        format!("Logged: {}", request)
    }
}

/// Handler that prepends a fixed `prefix` to each request.
struct TransformHandler {
    prefix: String,
}

#[async_trait]
impl Handler for TransformHandler {
    async fn handle(&self, request: &str) -> String {
        let TransformHandler { prefix } = self;
        format!("{}: {}", prefix, request)
    }
}
 
/// Runs one request through `handler`, which may be any `Handler` behind a
/// trait object, and prints the response.
async fn process_request(handler: &dyn Handler, request: &str) {
    println!("Result: {}", handler.handle(request).await);
}
 
/// An ordered list of boxed handlers that all see the same request.
///
/// Storing `dyn Handler + Send + Sync` lets different concrete handler types
/// share one `Vec` while keeping the chain itself `Send + Sync`.
struct HandlerChain {
    handlers: Vec<Box<dyn Handler + Send + Sync>>,
}

impl HandlerChain {
    /// Creates an empty chain.
    fn new() -> Self {
        HandlerChain {
            handlers: Vec::new(),
        }
    }

    /// Appends `handler` to the end of the chain.
    fn add<H: Handler + Send + Sync + 'static>(&mut self, handler: H) {
        let boxed: Box<dyn Handler + Send + Sync> = Box::new(handler);
        self.handlers.push(boxed);
    }

    /// Sends `request` to each handler in insertion order, awaiting each one
    /// before starting the next.
    async fn process_all(&self, request: &str) {
        for current in self.handlers.iter() {
            let response = current.handle(request).await;
            println!("  -> {}", response);
        }
    }
}
 
#[tokio::main]
async fn main() {
    // Single handler via trait object
    let handler: &dyn Handler = &LogHandler;
    process_request(handler, "test request").await;
    
    // Multiple handlers in a collection
    let mut chain = HandlerChain::new();
    chain.add(LogHandler);
    chain.add(TransformHandler { prefix: "PREFIX".to_string() });
    
    println!("\nProcessing through chain:");
    chain.process_all("hello").await;
}

Send and Non-Send Futures

use async_trait::async_trait;
 
// Unless told otherwise, `#[async_trait]` makes every generated future `Send`,
// which is what `tokio::spawn` needs to run the work on another worker thread.
#[async_trait]
trait AsyncService: Send + Sync {
    async fn process(&self, data: &str) -> String;
}

/// A stateless service whose futures are `Send`.
struct MyService;

#[async_trait]
impl AsyncService for MyService {
    async fn process(&self, data: &str) -> String {
        let mut output = String::from("Processed: ");
        output.push_str(data);
        output
    }
}

// `?Send` drops the `Send` bound, so the futures may hold `Rc`s, `RefCell` borrows, etc.
#[async_trait(?Send)]
trait LocalService {
    async fn local_process(&self) -> String;
}

/// Holds an `Rc`, which is neither `Send` nor `Sync`. A future that borrows
/// this struct therefore cannot be `Send`, hence `?Send` on the trait and impl.
struct LocalHandler {
    data: std::rc::Rc<String>,
}

#[async_trait(?Send)]
impl LocalService for LocalHandler {
    async fn local_process(&self) -> String {
        // Copy the shared text out; the `Rc` itself stays where it is.
        self.data.as_str().to_owned()
    }
}
 
#[tokio::main]
async fn main() {
    // The Send service can be moved into a spawned task.
    let service = MyService;
    let task = tokio::spawn(async move { service.process("hello").await });
    let spawned_result = task.await.unwrap();
    println!("Result: {}", spawned_result);

    // The Rc-holding service is awaited directly inside main's own future and
    // is never handed to `tokio::spawn`.
    let local = LocalHandler {
        data: std::rc::Rc::new(String::from("local data")),
    };
    let local_result = local.local_process().await;
    println!("Local result: {}", local_result);
}

Default Implementations

use async_trait::async_trait;
 
/// Byte-oriented async cache with string convenience methods layered on top.
///
/// Implementors supply only `get_raw` and `set_raw`; the remaining methods
/// have default bodies written purely in terms of those two.
#[async_trait]
trait Cache {
    /// Required: returns the raw bytes stored under `key`, if any.
    async fn get_raw(&self, key: &str) -> Option<Vec<u8>>;

    /// Required: writes `value` under `key`.
    async fn set_raw(&self, key: &str, value: Vec<u8>);

    /// Reads `key` as UTF-8 text. A missing key and bytes that are not valid
    /// UTF-8 both yield `None`.
    async fn get_string(&self, key: &str) -> Option<String> {
        self.get_raw(key)
            .await
            .and_then(|bytes| String::from_utf8(bytes).ok())
    }

    /// Stores `value` as its UTF-8 bytes.
    async fn set_string(&self, key: &str, value: &str) {
        self.set_raw(key, value.as_bytes().to_vec()).await;
    }

    /// Returns the string cached under `key`. On a miss, writes `default` under
    /// `key` and returns it.
    ///
    /// A key holding non-UTF-8 bytes counts as a miss and is overwritten.
    async fn get_or_default(&self, key: &str, default: &str) -> String {
        if let Some(value) = self.get_string(key).await {
            return value;
        }
        self.set_string(key, default).await;
        default.to_string()
    }
}
 
/// In-memory `Cache` whose map sits behind an async `RwLock`.
///
/// The lock provides the interior mutability that `set_raw(&self, ..)`
/// needs. Without it, writes could not be recorded and `get_or_default`
/// would never actually cache anything.
struct MemoryCache {
    data: tokio::sync::RwLock<std::collections::HashMap<String, Vec<u8>>>,
}

#[async_trait]
impl Cache for MemoryCache {
    async fn get_raw(&self, key: &str) -> Option<Vec<u8>> {
        // Simulated backend latency.
        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
        // The read guard is a temporary released at the end of this expression.
        self.data.read().await.get(key).cloned()
    }

    async fn set_raw(&self, key: &str, value: Vec<u8>) {
        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
        self.data.write().await.insert(key.to_string(), value);
    }
}

#[tokio::main]
async fn main() {
    let cache = MemoryCache {
        data: tokio::sync::RwLock::new(std::collections::HashMap::new()),
    };

    // First call misses, so the default method stores the fallback and returns it...
    let value = cache.get_or_default("missing-key", "default-value").await;
    println!("Got: {}", value);

    // ...and the stored value is now visible through another default method.
    let cached = cache.get_string("missing-key").await;
    println!("Cached: {:?}", cached);
}

Generic Async Traits

use async_trait::async_trait;
 
/// Generic async storage: one trait, implemented once per element type `T`.
#[async_trait]
trait Store<T> {
    /// Fetches the item with `id`, if present.
    async fn get(&self, id: u64) -> Option<T>;
    /// Accepts `item` and returns an id (the impls below simply echo `item`'s id).
    async fn save(&self, item: T) -> u64;
}

#[derive(Debug, Clone)]
struct User {
    id: u64,
    name: String,
}

#[derive(Debug, Clone)]
struct Product {
    id: u64,
    name: String,
    price: f64,
}

/// `Store<User>` over a `HashMap`; `next_id` is carried but not read here.
struct UserStore {
    users: std::collections::HashMap<u64, User>,
    next_id: u64,
}

#[async_trait]
impl Store<User> for UserStore {
    async fn get(&self, id: u64) -> Option<User> {
        simulated_store_delay().await;
        self.users.get(&id).cloned()
    }

    async fn save(&self, item: User) -> u64 {
        simulated_store_delay().await;
        // Nothing is written; only the item's own id is returned.
        let User { id, .. } = item;
        id
    }
}

/// `Store<Product>` over a `HashMap`.
struct ProductStore {
    products: std::collections::HashMap<u64, Product>,
}

#[async_trait]
impl Store<Product> for ProductStore {
    async fn get(&self, id: u64) -> Option<Product> {
        simulated_store_delay().await;
        self.products.get(&id).cloned()
    }

    async fn save(&self, item: Product) -> u64 {
        simulated_store_delay().await;
        let Product { id, .. } = item;
        id
    }
}

/// One-millisecond sleep standing in for real storage latency.
async fn simulated_store_delay() {
    tokio::time::sleep(std::time::Duration::from_millis(1)).await;
}
 
#[tokio::main]
async fn main() {
    let user_store = UserStore {
        users: std::collections::HashMap::new(),
        next_id: 1,
    };

    let product_store = ProductStore {
        products: std::collections::HashMap::new(),
    };

    // Generic function that works with any Store<T>; `T` is inferred from the
    // single `Store<_>` impl of the concrete store passed in.
    async fn get_item<T, S: Store<T>>(store: &S, id: u64) -> Option<T> {
        store.get(id).await
    }

    // Both stores are empty, so each lookup prints `None`. The point is that one
    // generic function serves both element types (previously `product_store`
    // was built but never used).
    let user = get_item(&user_store, 1).await;
    println!("User: {:?}", user);

    let product = get_item(&product_store, 1).await;
    println!("Product: {:?}", product);
}

Returning Results from Async Traits

use async_trait::async_trait;
use std::io;
 
/// Whole-file async read/write, reporting failures as `io::Error`.
#[async_trait]
trait AsyncFile {
    async fn read_all(&self) -> io::Result<String>;
    async fn write_all(&mut self, content: &str) -> io::Result<()>;
}

/// In-memory file double; setting `should_fail` makes every operation error.
struct MockFile {
    content: String,
    should_fail: bool,
}

#[async_trait]
impl AsyncFile for MockFile {
    async fn read_all(&self) -> io::Result<String> {
        tokio::time::sleep(std::time::Duration::from_millis(1)).await;

        // Injected failure short-circuits before the content is touched.
        if self.should_fail {
            return Err(io::Error::new(io::ErrorKind::Other, "mock error"));
        }
        Ok(self.content.clone())
    }

    async fn write_all(&mut self, content: &str) -> io::Result<()> {
        tokio::time::sleep(std::time::Duration::from_millis(1)).await;

        if self.should_fail {
            return Err(io::Error::new(io::ErrorKind::Other, "mock error"));
        }
        // Replace, not append.
        self.content = content.to_string();
        Ok(())
    }
}
 
/// Copies all of `src`'s contents into `dst`.
///
/// # Errors
/// Returns the first I/O error. If reading `src` fails, `dst` is never written.
async fn copy_file<S, D>(src: &S, dst: &mut D) -> io::Result<()>
where
    S: AsyncFile,
    D: AsyncFile,
{
    let content = src.read_all().await?;
    dst.write_all(&content).await
}
 
#[tokio::main]
async fn main() {
    let mut src = MockFile {
        content: String::from("Hello, World!"),
        should_fail: false,
    };
    let mut dst = MockFile {
        content: String::new(),
        should_fail: false,
    };

    // Happy path: content flows from `src` into `dst`.
    let first_attempt = copy_file(&src, &mut dst).await;
    match first_attempt {
        Ok(()) => println!("Copy succeeded: {}", dst.content),
        Err(e) => println!("Copy failed: {}", e),
    }

    // Failure path: make the source's read fail.
    src.should_fail = true;
    let second_attempt = copy_file(&src, &mut dst).await;
    match second_attempt {
        Ok(()) => println!("Copy succeeded"),
        Err(e) => println!("Copy failed as expected: {}", e),
    }
}

HTTP Client Trait Example

use async_trait::async_trait;
 
/// Minimal async HTTP client abstraction; errors are reported as plain strings.
#[async_trait]
pub trait HttpClient {
    async fn get(&self, url: &str) -> Result<String, String>;
    async fn post(&self, url: &str, body: &str) -> Result<String, String>;
}

/// Placeholder for a reqwest-backed client (a real version needs `reqwest` in Cargo.toml).
pub struct ReqwestClient;

#[async_trait]
impl HttpClient for ReqwestClient {
    async fn get(&self, url: &str) -> Result<String, String> {
        // A real version must convert reqwest's error type into this trait's
        // `String` error before `?` can be used, e.g.:
        //   reqwest::get(url).await.map_err(|e| e.to_string())?
        //       .text().await.map_err(|e| e.to_string())
        let response = format!("GET response from {}", url);
        Ok(response)
    }

    async fn post(&self, url: &str, body: &str) -> Result<String, String> {
        let response = format!("POST {} with body: {}", url, body);
        Ok(response)
    }
}

/// Test double that serves canned GET responses keyed by full URL.
pub struct MockHttpClient {
    responses: std::collections::HashMap<String, String>,
}

#[async_trait]
impl HttpClient for MockHttpClient {
    async fn get(&self, url: &str) -> Result<String, String> {
        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
        match self.responses.get(url) {
            Some(canned) => Ok(canned.clone()),
            None => Err(format!("No mock for {}", url)),
        }
    }

    async fn post(&self, url: &str, body: &str) -> Result<String, String> {
        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
        // POSTs always succeed in the mock.
        let echoed = format!("Mock POST to {} with {}", url, body);
        Ok(echoed)
    }
}
 
/// High-level API wrapper that is generic over the HTTP transport, so tests can
/// plug in `MockHttpClient` and production code a real client.
pub struct ApiService<T: HttpClient> {
    client: T,
    base_url: String,
}

impl<T: HttpClient> ApiService<T> {
    /// Creates a service whose request URLs all start with `base_url`.
    pub fn new(client: T, base_url: &str) -> Self {
        Self {
            client,
            base_url: base_url.to_string(),
        }
    }

    /// GETs `{base_url}/users/{id}`.
    pub async fn fetch_user(&self, id: u64) -> Result<String, String> {
        let url = format!("{}/users/{}", self.base_url, id);
        self.client.get(&url).await
    }

    /// POSTs a `{"name": ...}` JSON body to `{base_url}/users`.
    ///
    /// `name` is JSON-escaped, so quotes, backslashes or control characters in
    /// it cannot break out of the string literal or inject extra fields.
    pub async fn create_user(&self, name: &str) -> Result<String, String> {
        let url = format!("{}/users", self.base_url);
        let body = format!("{{\"name\": \"{}\"}}", escape_json_string(name));
        self.client.post(&url, &body).await
    }
}

/// Escapes `raw` for use inside a JSON string literal (RFC 8259, section 7):
/// `"` and `\` are backslash-escaped, and control characters below U+0020 use
/// short or `\uXXXX` escapes.
fn escape_json_string(raw: &str) -> String {
    let mut escaped = String::with_capacity(raw.len());
    for ch in raw.chars() {
        match ch {
            '"' => escaped.push_str("\\\""),
            '\\' => escaped.push_str("\\\\"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            c if u32::from(c) < 0x20 => {
                escaped.push_str(&format!("\\u{:04x}", u32::from(c)));
            }
            c => escaped.push(c),
        }
    }
    escaped
}
 
#[tokio::main]
async fn main() {
    // Canned responses keyed by full URL stand in for a real server.
    let mut responses = std::collections::HashMap::new();
    responses.insert(
        "https://api.example.com/users/1".to_string(),
        "{\"id\": 1, \"name\": \"Alice\"}".to_string(),
    );
    let service = ApiService::new(MockHttpClient { responses }, "https://api.example.com");

    let fetched = service.fetch_user(1).await;
    match fetched {
        Ok(user) => println!("User: {}", user),
        Err(e) => println!("Error: {}", e),
    }

    let created = service.create_user("Bob").await;
    match created {
        Ok(result) => println!("Created: {}", result),
        Err(e) => println!("Error: {}", e),
    }
}

Database Repository Pattern

use async_trait::async_trait;
 
/// A user record as handled by the repository layer.
#[derive(Debug, Clone)]
pub struct User {
    pub id: u64,
    pub username: String,
    pub email: String,
}

/// Async persistence operations for `User` records.
#[async_trait]
pub trait UserRepository {
    async fn find_by_id(&self, id: u64) -> Option<User>;
    async fn find_by_username(&self, username: &str) -> Option<User>;
    async fn save(&self, user: &User) -> bool;
    async fn delete(&self, id: u64) -> bool;
}

/// In-memory implementation for testing.
///
/// The map is wrapped in `Arc<tokio::sync::RwLock<..>>`, so the `&self` trait
/// methods can still mutate it.
pub struct InMemoryUserRepository {
    users: std::sync::Arc<tokio::sync::RwLock<std::collections::HashMap<u64, User>>>,
}

impl InMemoryUserRepository {
    /// Creates an empty repository.
    pub fn new() -> Self {
        Self {
            users: std::sync::Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
        }
    }
}

// A public type with an argument-less `new()` should also implement `Default`
// (clippy: `new_without_default`), so it works with `..Default::default()`,
// `unwrap_or_default`, and generic `T: Default` code.
impl Default for InMemoryUserRepository {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl UserRepository for InMemoryUserRepository {
    async fn find_by_id(&self, id: u64) -> Option<User> {
        let users = self.users.read().await;
        users.get(&id).cloned()
    }

    async fn find_by_username(&self, username: &str) -> Option<User> {
        // Linear scan over all stored users.
        let users = self.users.read().await;
        users.values().find(|u| u.username == username).cloned()
    }

    async fn save(&self, user: &User) -> bool {
        // Insert-or-overwrite by id; always reports success.
        let mut users = self.users.write().await;
        users.insert(user.id, user.clone());
        true
    }

    async fn delete(&self, id: u64) -> bool {
        // `true` only when a user with `id` existed.
        let mut users = self.users.write().await;
        users.remove(&id).is_some()
    }
}
 
/// Application-level service over any `UserRepository` implementation.
pub struct UserService<R: UserRepository> {
    repo: R,
}

impl<R: UserRepository> UserService<R> {
    /// Wraps `repo`.
    pub fn new(repo: R) -> Self {
        UserService { repo }
    }

    /// Looks a user up by id.
    pub async fn get_user(&self, id: u64) -> Option<User> {
        self.repo.find_by_id(id).await
    }

    /// Builds a `User`, hands it to the repository, and returns it.
    ///
    /// The repository's success flag is not consulted.
    pub async fn create_user(&self, id: u64, username: &str, email: &str) -> User {
        let username = String::from(username);
        let email = String::from(email);
        let user = User { id, username, email };
        let _saved = self.repo.save(&user).await;
        user
    }

    /// Looks a user up by username.
    pub async fn find_by_username(&self, username: &str) -> Option<User> {
        self.repo.find_by_username(username).await
    }
}
 
#[tokio::main]
async fn main() {
    let service = UserService::new(InMemoryUserRepository::new());

    // Create users.
    let alice = service.create_user(1, "alice", "alice@example.com").await;
    let bob = service.create_user(2, "bob", "bob@example.com").await;
    for created in [&alice, &bob] {
        println!("Created: {:?}", created);
    }

    // Look them up by id and by username.
    if let Some(by_id) = service.get_user(1).await {
        println!("Found by ID: {:?}", by_id);
    }
    if let Some(by_name) = service.find_by_username("bob").await {
        println!("Found by username: {:?}", by_name);
    }
}

Comparing Native vs async-trait

// Rust 1.75+ supports native async fn in traits:
// 
// trait NativeAsync {
//     async fn do_something(&self) -> i32;
// }
//
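// However, native async traits have limitations:
// 1. A trait with native `async fn` is not dyn-compatible, so `dyn Trait` cannot be used
// 2. Callers cannot require the returned futures to be `Send` (needed for
//    `tokio::spawn`) without extra machinery such as the `trait-variant` crate
// 3. They require Rust 1.75 or newer
//
// async-trait works around these by boxing futures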
 
use async_trait::async_trait;
 
/// A trait whose async method stays callable through `dyn` because the
/// macro boxes the returned future.
#[async_trait]
trait WithAsyncTrait {
    async fn process(&self) -> i32;
}

struct ImplWithAsyncTrait;

#[async_trait]
impl WithAsyncTrait for ImplWithAsyncTrait {
    async fn process(&self) -> i32 {
        // Fixed value; the example is about the signature, not the result.
        42
    }
}

/// Calls `process` through a trait object and returns its boxed future.
///
/// The spelled-out return type matches the macro-generated signature, which
/// is why this works through `dyn`.
fn use_dyn_trait(
    obj: &dyn WithAsyncTrait,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = i32> + Send + '_>> {
    WithAsyncTrait::process(obj)
}
 
fn main() {
    // This example shows the difference conceptually
    println!("Native async traits are available in Rust 1.75+");
    println!("But async-trait is still useful for:");
    println!("  - Trait objects (dyn Trait)");
    println!("  - Complex lifetime scenarios");
    println!("  - Older Rust versions");
}

Middleware Pattern

use async_trait::async_trait;
 
/// One link in a request-processing chain. Each middleware may inspect or alter
/// the request, call `next` to continue, or return early with its own response.
#[async_trait]
pub trait Middleware: Send + Sync {
    async fn handle(&self, request: Request, next: Next<'_>) -> Response;
}

#[derive(Debug, Clone)]
pub struct Request {
    pub path: String,
    pub headers: Vec<(String, String)>,
}

#[derive(Debug, Clone)]
pub struct Response {
    pub status: u16,
    pub body: String,
}

/// Cursor into the middleware slice: `current` indexes the next one to run.
pub struct Next<'a> {
    middlewares: &'a [Box<dyn Middleware>],
    current: usize,
}

impl<'a> Next<'a> {
    /// Hands `request` to the next middleware. Once every middleware has run,
    /// returns the terminal `200 OK` response.
    pub async fn run(mut self, request: Request) -> Response {
        // Copy the slice reference out so the borrow does not go through `self`.
        let chain = self.middlewares;
        match chain.get(self.current) {
            Some(middleware) => {
                // Advance the cursor before passing ourselves down the chain.
                self.current += 1;
                middleware.handle(request, self).await
            }
            None => Response {
                status: 200,
                body: "OK".to_string(),
            },
        }
    }
}
 
/// Prints the request path on the way in and the status code on the way out.
pub struct LoggingMiddleware;

#[async_trait]
impl Middleware for LoggingMiddleware {
    async fn handle(&self, request: Request, next: Next<'_>) -> Response {
        let path = &request.path;
        println!("Request: {}", path);
        let outcome = next.run(request).await;
        println!("Response: {}", outcome.status);
        outcome
    }
}
 
/// Rejects requests lacking an `Authorization: Bearer <api_key>` header with a
/// `401 Unauthorized` response. Authorized requests go on down the chain.
pub struct AuthMiddleware {
    pub api_key: String,
}

#[async_trait]
impl Middleware for AuthMiddleware {
    // `next` is consumed by `run(self, ..)` and never mutated, so it is not `mut`
    // (the original `mut next` triggered an `unused_mut` warning).
    async fn handle(&self, request: Request, next: Next<'_>) -> Response {
        // Build the expected value once, not once per header inspected.
        let expected = format!("Bearer {}", self.api_key);
        // HTTP field names are case-insensitive (RFC 9110 section 5.1); the value is compared exactly.
        let has_auth = request
            .headers
            .iter()
            .any(|(name, value)| name.eq_ignore_ascii_case("Authorization") && *value == expected);

        if has_auth {
            next.run(request).await
        } else {
            // Short-circuit: downstream middleware never sees the request.
            Response {
                status: 401,
                body: "Unauthorized".to_string(),
            }
        }
    }
}
 
/// Appends a fixed `(name, value)` header to every request before passing it on.
pub struct HeaderMiddleware {
    pub name: String,
    pub value: String,
}

#[async_trait]
impl Middleware for HeaderMiddleware {
    async fn handle(&self, mut request: Request, next: Next<'_>) -> Response {
        let header = (self.name.clone(), self.value.clone());
        request.headers.push(header);
        next.run(request).await
    }
}
 
#[tokio::main]
async fn main() {
    // Order matters: logging wraps auth, and auth guards the header injector.
    let middlewares: Vec<Box<dyn Middleware>> = vec![
        Box::new(LoggingMiddleware),
        Box::new(AuthMiddleware {
            api_key: String::from("secret"),
        }),
        Box::new(HeaderMiddleware {
            name: String::from("X-Request-Id"),
            value: String::from("123"),
        }),
    ];

    let request = Request {
        path: String::from("/api/users"),
        headers: vec![(String::from("Authorization"), String::from("Bearer secret"))],
    };

    // Start the chain at the first middleware.
    let response = Next {
        middlewares: &middlewares,
        current: 0,
    }
    .run(request)
    .await;
    println!("Final response: {:?}", response);
}

Behind the Scenes

// What async_trait actually does:
//
// Before (with the macro):
// #[async_trait]
// trait Example {
//     async fn foo(&self) -> i32;
// }
//
// After (what the macro generates):
// trait Example {
//     fn foo<'life0, 'async_trait>(
//         &'life0 self,
//     ) -> Pin<Box<dyn Future<Output = i32> + Send + 'async_trait>>
//     where
//         'life0: 'async_trait,
//         Self: 'async_trait;
// }
//
// The key transformations:
// 1. async fn -> a plain fn returning Pin<Box<dyn Future>>
// 2. the original method body is moved into `Box::pin(async move { ... })`
// 3. Futures are boxed so every implementation returns the same nameable type
// 4. Send bound is added by default (removable with ?Send)
 
use async_trait::async_trait;
 
#[async_trait]
trait Example {
    async fn foo(&self) -> i32;
}
 
struct MyExample;
 
#[async_trait]
impl Example for MyExample {
    async fn foo(&self) -> i32 {
        42
    }
}
 
fn main() {
    // The macro handles all the complexity
    println!("async_trait boxes futures to make them concrete types");
    println!("This allows using async methods in traits");
}

Summary

  • #[async_trait] enables async methods in traits by boxing futures
  • Apply the macro to both the trait definition and impl blocks
  • By default, futures are Send — use #[async_trait(?Send)] for non-Send futures
  • Enables using dyn Trait with async methods, which native async fn in traits does not support (such traits are not dyn-compatible)
  • Perfect for: database repositories, HTTP clients, middleware, service abstractions
  • Great for testing: easily swap real and mock implementations
  • Supports default implementations for async methods
  • Works with generic traits (trait Store<T>)
  • Rust 1.75+ has native async fn in traits, but async-trait remains useful for trait objects and complex scenarios
  • Trade-off: every call performs one heap allocation and a dynamic dispatch. This is usually negligible, but it can add up in hot paths.
  • Use when you need trait objects (dyn Trait) with async methods