Loading page…
Rust walkthroughs
Loading page…
The async-trait crate provides a procedural macro that enables async functions in trait definitions. In standard Rust, traits cannot have async functions because the compiler doesn't support async in traits natively (this is being worked on). The async-trait macro works around this by boxing the futures, converting async functions to return Pin<Box<dyn Future>>. This allows you to define traits with async methods and implement them for types.
Key points:
- `#[async_trait]` — attribute macro for trait definitions
- Async methods are rewritten to return `Pin<Box<dyn Future>>`
- `Send` bound — use `#[async_trait]` for `Send` futures
- `#[async_trait(?Send)]` for non-`Send` futures

# Cargo.toml
[dependencies]
async-trait = "0.1"

use async_trait::async_trait;
/// Asynchronous user-storage interface.
///
/// The `#[async_trait]` attribute is what allows the `async fn`
/// declarations inside this trait.
#[async_trait]
pub trait Database {
/// Asynchronously looks up the user stored under `id`; `None` leaves room
/// for "no such user".
async fn get_user(&self, id: u32) -> Option<String>;
/// Asynchronously stores `name` under `id`. The meaning of the returned
/// `bool` is up to the implementor (presumably success — confirm).
async fn save_user(&self, id: u32, name: &str) -> bool;
}
/// Stand-in database with no backing storage.
struct MyDatabase;

#[async_trait]
impl Database for MyDatabase {
    /// Always succeeds, synthesizing a `"User <id>"` label for any `id`.
    async fn get_user(&self, id: u32) -> Option<String> {
        let label = format!("User {}", id);
        Some(label)
    }

    /// Prints what would be saved and unconditionally reports `true`.
    async fn save_user(&self, id: u32, name: &str) -> bool {
        println!("Saving user {}: {}", id, name);
        let saved = true;
        saved
    }
}
/// Looks up user 1 through the `Database` trait and prints the result.
#[tokio::main]
async fn main() {
    let database = MyDatabase;
    let fetched = database.get_user(1).await;
    println!("Got user: {:?}", fetched);
}

use async_trait::async_trait;
// Define trait with async methods
/// Asynchronous URL-fetching interface.
#[async_trait]
trait Fetcher {
/// Fetches a single `url`, yielding the response text or an error message.
async fn fetch(&self, url: &str) -> Result<String, String>;
/// Fetches every entry of `urls`, returning one `Result` per URL.
async fn fetch_all(&self, urls: &[&str]) -> Vec<Result<String, String>>;
}
// Simple implementation
/// Stateless fetcher that fakes an HTTP round trip.
struct HttpFetcher;

#[async_trait]
impl Fetcher for HttpFetcher {
    /// Sleeps 100 ms in place of a real request, then returns a canned
    /// response that names `url`.
    async fn fetch(&self, url: &str) -> Result<String, String> {
        let simulated_latency = tokio::time::Duration::from_millis(100);
        tokio::time::sleep(simulated_latency).await;
        let body = format!("Response from {}", url);
        Ok(body)
    }

    /// Fetches the URLs one after another, keeping results in input order.
    async fn fetch_all(&self, urls: &[&str]) -> Vec<Result<String, String>> {
        let mut collected = Vec::new();
        for &target in urls.iter() {
            let outcome = self.fetch(target).await;
            collected.push(outcome);
        }
        collected
    }
}
// Implementation with state
/// Fetcher that remembers the responses it has already produced.
struct CachingFetcher {
    /// Previously produced responses, shared behind an async mutex.
    cache: std::sync::Arc<tokio::sync::Mutex<Vec<String>>>,
}

impl CachingFetcher {
    /// Creates a fetcher whose cache starts out empty.
    fn new() -> Self {
        let empty = tokio::sync::Mutex::new(Vec::new());
        Self {
            cache: std::sync::Arc::new(empty),
        }
    }
}
#[async_trait]
impl Fetcher for CachingFetcher {
    /// Returns the response for `url`, prefixed with `"Cached: "` when the
    /// same response was produced before.
    ///
    /// A cache hit requires the stored response to equal the response for
    /// `url` exactly. A substring test would let `https://a.com` hit an entry
    /// cached for `https://a.com/other`.
    async fn fetch(&self, url: &str) -> Result<String, String> {
        let response = format!("Response from {}", url);

        // One guard covers both the lookup and the insert, so two concurrent
        // callers cannot both miss and then store the same response twice.
        let mut cache = self.cache.lock().await;
        if cache.iter().any(|cached| *cached == response) {
            return Ok(format!("Cached: {}", response));
        }

        cache.push(response.clone());
        Ok(response)
    }

    /// Fetches the URLs sequentially, keeping results in input order.
    async fn fetch_all(&self, urls: &[&str]) -> Vec<Result<String, String>> {
        let mut results = Vec::with_capacity(urls.len());
        for &url in urls {
            results.push(self.fetch(url).await);
        }
        results
    }
}
/// Exercises `HttpFetcher` with one URL and then with a batch of two.
#[tokio::main]
async fn main() {
    let http = HttpFetcher;

    let single = http.fetch("https://example.com").await;
    println!("Result: {:?}", single);

    let targets = ["https://a.com", "https://b.com"];
    let batch = http.fetch_all(&targets).await;
    println!("Results: {:?}", batch);
}

use async_trait::async_trait;
/// Asynchronous text transformation, offered for borrowed and owned input.
#[async_trait]
trait Processor {
/// Transforms borrowed `data` into a new `String`.
async fn process(&self, data: &str) -> String;
/// Same contract as `process`, but takes ownership of `data`.
async fn process_owned(&self, data: String) -> String;
}
/// Asynchronous factory; shows an associated type used in an async signature.
#[async_trait]
trait Factory {
/// The type of value this factory produces.
type Output;
/// Asynchronously produces one `Self::Output`.
async fn create(&self) -> Self::Output;
}
// Trait with lifetime
/// Asynchronous parser whose output borrows from its input for `'a`.
#[async_trait]
trait Parser<'a> {
/// Splits `input` into pieces that live as long as `input` itself.
async fn parse(&self, input: &'a str) -> Vec<&'a str>;
}
/// Parser that tokenizes on whitespace.
struct StringParser;

#[async_trait]
impl<'a> Parser<'a> for StringParser {
    /// Sleeps 10 ms to mimic async work, then returns the
    /// whitespace-separated words of `input`.
    async fn parse(&self, input: &'a str) -> Vec<&'a str> {
        let pause = tokio::time::Duration::from_millis(10);
        tokio::time::sleep(pause).await;

        let words: Vec<&'a str> = input.split_whitespace().collect();
        words
    }
}
/// Processor that upper-cases whatever it is given.
struct UppercaseProcessor;

#[async_trait]
impl Processor for UppercaseProcessor {
    /// Upper-cases borrowed text.
    async fn process(&self, data: &str) -> String {
        str::to_uppercase(data)
    }

    /// Upper-cases owned text.
    async fn process_owned(&self, data: String) -> String {
        data.as_str().to_uppercase()
    }
}
/// Runs the upper-casing processor and the whitespace parser once each.
#[tokio::main]
async fn main() {
    let shouted = UppercaseProcessor.process("hello world").await;
    println!("Processed: {}", shouted);

    let tokenizer = StringParser;
    let tokens = tokenizer.parse("hello world from rust").await;
    println!("Words: {:?}", tokens);
}

use async_trait::async_trait;
/// Asynchronous read-only repository with default helper methods.
#[async_trait]
trait Repository {
    /// Looks up the value stored under `id`.
    async fn get(&self, id: u32) -> Option<String>;

    /// Returns every stored value.
    async fn get_all(&self) -> Vec<String>;

    /// Default method: the value for `id`, or the literal `"default"` when
    /// `get` finds nothing.
    async fn get_or_default(&self, id: u32) -> String {
        let found = self.get(id).await;
        match found {
            Some(value) => value,
            None => "default".to_string(),
        }
    }

    /// Default method: whether `get` finds anything for `id`.
    async fn exists(&self, id: u32) -> bool {
        let found = self.get(id).await;
        matches!(found, Some(_))
    }
}
/// Repository backed by an in-memory map behind an async mutex.
struct MemoryRepository {
    /// Stored values keyed by id.
    data: std::sync::Arc<tokio::sync::Mutex<std::collections::HashMap<u32, String>>>,
}

impl MemoryRepository {
    /// Creates a repository with no entries.
    fn new() -> Self {
        let empty = std::collections::HashMap::new();
        let guarded = tokio::sync::Mutex::new(empty);
        Self {
            data: std::sync::Arc::new(guarded),
        }
    }
}
#[async_trait]
impl Repository for MemoryRepository {
    /// Clones out the value stored under `id`, if there is one.
    async fn get(&self, id: u32) -> Option<String> {
        let entries = self.data.lock().await;
        let hit = entries.get(&id);
        hit.cloned()
    }

    /// Clones out every stored value (in the map's iteration order).
    async fn get_all(&self) -> Vec<String> {
        let entries = self.data.lock().await;
        let mut snapshot = Vec::new();
        for value in entries.values() {
            snapshot.push(value.clone());
        }
        snapshot
    }
}
/// Calls the trait's default methods on an empty repository.
#[tokio::main]
async fn main() {
    let repository = MemoryRepository::new();

    // Both calls go through default methods defined on the trait itself.
    let fallback = repository.get_or_default(1).await;
    println!("Result: {}", fallback);

    let present = repository.exists(1).await;
    println!("Exists: {}", present);
}

use async_trait::async_trait;
// Default: futures must be Send (required for multi-threaded runtimes)
/// Service trait declared with plain `#[async_trait]` and with
/// `Send + Sync` supertraits on implementors.
#[async_trait]
pub trait AsyncService: Send + Sync {
/// Asynchronously transforms owned `data` into a new `String`.
async fn process(&self, data: String) -> String;
}
// Non-Send futures (single-threaded contexts)
/// Service trait declared with `#[async_trait(?Send)]`, i.e. without the
/// `Send` requirement on the futures its methods return.
#[async_trait(?Send)]
pub trait LocalService {
/// Asynchronously transforms owned `data` into a new `String`.
async fn process(&self, data: String) -> String;
}
// Send implementation
/// Field-less service used to demonstrate the default (`Send`) mode.
struct ThreadSafeService;

#[async_trait]
impl AsyncService for ThreadSafeService {
    /// Upper-cases `data`.
    async fn process(&self, data: String) -> String {
        // This can run on any thread
        let shouted = data.to_uppercase();
        shouted
    }
}
// Non-Send implementation
struct LocalService {
// Rc is not Send
data: std::rc::Rc<String>,
}
#[async_trait(?Send)]
impl LocalService for LocalService {
async fn process(&self, data: String) -> String {
format!("{} - {}", *self.data, data)
}
}
// When to use ?Send:
// - Single-threaded runtimes
// - Using non-Send types like Rc<T>
// - Interacting with non-thread-safe C libraries
// - WASM targets
/// Runs the `Send` service once and prints its output.
#[tokio::main]
async fn main() {
    let input = "hello".to_string();
    let output = ThreadSafeService.process(input).await;
    println!("Result: {}", output);
}

use async_trait::async_trait;
/// Asynchronous request handler; the example uses it through `dyn Handler`.
#[async_trait]
trait Handler {
/// Consumes `request` and asynchronously produces the reply text.
async fn handle(&self, request: String) -> String;
}
/// Handler that repeats the request back with an `"Echo: "` prefix.
struct EchoHandler;

#[async_trait]
impl Handler for EchoHandler {
    async fn handle(&self, request: String) -> String {
        let reply = format!("Echo: {}", request);
        reply
    }
}
/// Handler that replies with the request upper-cased.
struct UppercaseHandler;

#[async_trait]
impl Handler for UppercaseHandler {
    async fn handle(&self, request: String) -> String {
        request.as_str().to_uppercase()
    }
}
/// Handler that replies with the request's characters in reverse order.
struct ReverseHandler;

#[async_trait]
impl Handler for ReverseHandler {
    async fn handle(&self, request: String) -> String {
        let mut reversed = String::new();
        for ch in request.chars().rev() {
            reversed.push(ch);
        }
        reversed
    }
}
// Use trait object
/// Dispatches `request` to `handler` through dynamic dispatch and returns
/// whatever the handler replies.
async fn process_request(handler: &dyn Handler, request: String) -> String {
    let pending = handler.handle(request);
    pending.await
}
// Multiple handlers
/// Feeds a copy of `request` to every handler in turn and collects the
/// replies in handler order.
async fn process_with_all(handlers: &[Box<dyn Handler>], request: String) -> Vec<String> {
    let mut replies = Vec::new();
    for entry in handlers.iter() {
        let copy = request.clone();
        let reply = entry.handle(copy).await;
        replies.push(reply);
    }
    replies
}
/// Uses `Handler` as a borrowed trait object, a boxed one, and a list of
/// boxed ones.
#[tokio::main]
async fn main() {
    // Single trait object
    let borrowed: &dyn Handler = &EchoHandler;
    let first = borrowed.handle(String::from("hello")).await;
    println!("Result: {}", first);

    // Boxed trait object
    let owned: Box<dyn Handler> = Box::new(EchoHandler);
    let second = owned.handle(String::from("world")).await;
    println!("Boxed result: {}", second);

    // Multiple handlers
    let mut chain: Vec<Box<dyn Handler>> = Vec::new();
    chain.push(Box::new(EchoHandler));
    chain.push(Box::new(UppercaseHandler));
    chain.push(Box::new(ReverseHandler));

    let replies = process_with_all(chain.as_slice(), String::from("hello")).await;
    println!("Results: {:?}", replies);
}

use async_trait::async_trait;
/// Asynchronous keyed store whose item and error types are chosen by the
/// implementor through associated types.
#[async_trait]
trait Store {
/// The kind of value kept in the store.
type Item;
/// The error reported by failed operations.
type Error;
/// Asynchronously stores `item`.
async fn save(&self, item: Self::Item) -> Result<(), Self::Error>;
/// Asynchronously retrieves the item identified by `id`.
async fn load(&self, id: &str) -> Result<Self::Item, Self::Error>;
/// Asynchronously removes the item identified by `id`.
async fn delete(&self, id: &str) -> Result<(), Self::Error>;
}
/// A user record kept in the store.
#[derive(Debug, Clone)]
struct User {
// Used as the map key when the user is saved.
id: String,
// Display name of the user.
name: String,
}
/// Error type for store operations; wraps a human-readable message.
#[derive(Debug)]
struct StoreError(String);
/// In-memory `User` storage keyed by user id.
struct UserStore {
    /// The stored users, shared behind an async mutex.
    users: std::sync::Arc<tokio::sync::Mutex<std::collections::HashMap<String, User>>>,
}

impl UserStore {
    /// Creates a store that holds no users yet.
    fn new() -> Self {
        let empty = std::collections::HashMap::new();
        let guarded = tokio::sync::Mutex::new(empty);
        Self {
            users: std::sync::Arc::new(guarded),
        }
    }
}
#[async_trait]
impl Store for UserStore {
    type Item = User;
    type Error = StoreError;

    /// Inserts `item` under its own id, replacing any previous entry.
    async fn save(&self, item: Self::Item) -> Result<(), Self::Error> {
        let mut table = self.users.lock().await;
        let key = item.id.clone();
        table.insert(key, item);
        Ok(())
    }

    /// Returns a clone of the user stored under `id`, or a "not found" error.
    async fn load(&self, id: &str) -> Result<Self::Item, Self::Error> {
        let table = self.users.lock().await;
        match table.get(id) {
            Some(found) => Ok(found.clone()),
            None => Err(StoreError(format!("User {} not found", id))),
        }
    }

    /// Removes the user stored under `id`, or reports a "not found" error.
    async fn delete(&self, id: &str) -> Result<(), Self::Error> {
        let mut table = self.users.lock().await;
        match table.remove(id) {
            Some(_removed) => Ok(()),
            None => Err(StoreError(format!("User {} not found", id))),
        }
    }
}
/// Saves, loads and deletes one user, then shows the failed load afterwards.
#[tokio::main]
async fn main() {
    let users = UserStore::new();
    let alice = User {
        id: String::from("1"),
        name: String::from("Alice"),
    };

    users.save(alice.clone()).await.unwrap();
    println!("Saved user: {:?}", alice);

    let fetched = users.load("1").await.unwrap();
    println!("Loaded user: {:?}", fetched);

    users.delete("1").await.unwrap();
    println!("Deleted user");

    let after_delete = users.load("1").await;
    println!("Load after delete: {:?}", after_delete);
}

use async_trait::async_trait;
/// Asynchronous counterpart of a `From`-style conversion: builds `Self`
/// from a `T`.
#[async_trait]
trait AsyncFrom<T>: Sized {
/// Asynchronously converts `value` into `Self`.
async fn from(value: T) -> Self;
}
/// Asynchronous counterpart of an `Into`-style conversion: turns `self`
/// into a `T`.
#[async_trait]
trait AsyncInto<T> {
/// Consumes `self` and asynchronously produces a `T`.
async fn into_async(self) -> T;
}
// Generic trait implementation
/// Blanket implementation: every `T` gets `into_async` for each target `U`
/// that implements `AsyncFrom<T>`.
///
/// Both `U` and `T` are additionally bounded by `Send` here.
#[async_trait]
impl<T, U> AsyncInto<U> for T
where
U: AsyncFrom<T> + Send,
T: Send,
{
// Hands `self` to the target type's `from` and awaits the result.
async fn into_async(self) -> U {
U::from(self).await
}
}
// Concrete implementations
/// Wrapper around an integer; the source type of the conversion.
struct Number(i32);

/// Wrapper around the textual form of a number; the conversion target.
struct StringNumber(String);

#[async_trait]
impl AsyncFrom<Number> for StringNumber {
    /// Sleeps 10 ms to mimic async work, then renders the integer as text.
    async fn from(value: Number) -> Self {
        let pause = tokio::time::Duration::from_millis(10);
        tokio::time::sleep(pause).await;

        let Number(raw) = value;
        Self(raw.to_string())
    }
}
// Generic async function
/// Asynchronous mapping from an input type `T` to an output type `U`.
#[async_trait]
trait AsyncMapper<T, U> {
/// Asynchronously maps `input` to a `U`.
async fn map(&self, input: T) -> U;
}
/// Mapper that doubles an `i32`.
struct Doubler;

#[async_trait]
impl AsyncMapper<i32, i32> for Doubler {
    async fn map(&self, input: i32) -> i32 {
        let doubled = 2 * input;
        doubled
    }
}
/// Mapper that renders an `i32` as decimal text.
struct Stringifier;

#[async_trait]
impl AsyncMapper<i32, String> for Stringifier {
    async fn map(&self, input: i32) -> String {
        ToString::to_string(&input)
    }
}
#[tokio::main]
async fn main() {
let number = Number(42);
let string_number: StringNumber = AsyncFrom::from(number).await;
println!("String number: {}", string_number.0);
let doubler = Doubler;
let doubled = doubler.map(21).await;
println!("Doubled: {}", doubled);
let stringifier = Stringifier;
let string = stringifier.map(42).await;
println!("String: {}", string);
}use async_trait::async_trait;
use std::error::Error;
use std::fmt;
/// Minimal error type carrying only a message.
#[derive(Debug)]
struct MyError(String);

impl fmt::Display for MyError {
    /// Writes the wrapped message verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let MyError(message) = self;
        f.write_str(message)
    }
}

// No source or backtrace to expose, so the trait's defaults are enough.
impl Error for MyError {}
/// A fallible asynchronous operation with implementor-chosen output and
/// error types.
#[async_trait]
pub trait AsyncOperation {
/// Value produced on success.
type Output;
/// Error produced on failure; must implement `Error` and be `Send`.
type Error: Error + Send;
/// Asynchronously runs the operation once.
async fn execute(&self) -> Result<Self::Output, Self::Error>;
}
/// An asynchronous operation that is run with a retry budget.
#[async_trait]
pub trait AsyncRetry {
/// Asynchronously runs the operation given a budget of `max_retries`,
/// yielding a success message or a `MyError`.
async fn execute_with_retry(&self, max_retries: u32) -> Result<String, MyError>;
}
/// Operation that fails a configurable number of times before succeeding.
struct FlakyOperation {
    /// Number of attempts made so far, shared behind an async mutex.
    attempt: std::sync::Arc<tokio::sync::Mutex<u32>>,
    /// Attempts numbered up to and including this value fail.
    fail_until: u32,
}

impl FlakyOperation {
    /// Creates an operation with a zeroed attempt counter.
    fn new(fail_until: u32) -> Self {
        let counter = tokio::sync::Mutex::new(0);
        Self {
            attempt: std::sync::Arc::new(counter),
            fail_until,
        }
    }
}
#[async_trait]
impl AsyncRetry for FlakyOperation {
    /// Keeps attempting until an attempt number exceeds `fail_until`, or
    /// until more than `max_retries` attempts have failed.
    ///
    /// Waits 10 ms between a failed attempt and the next one.
    async fn execute_with_retry(&self, max_retries: u32) -> Result<String, MyError> {
        let mut failures = 0;
        loop {
            // The guard lives only inside this block, so the lock is released
            // before the sleep below.
            let current = {
                let mut counter = self.attempt.lock().await;
                *counter += 1;
                *counter
            };

            if current > self.fail_until {
                return Ok(format!("Success on attempt {}", current));
            }

            failures += 1;
            if failures > max_retries {
                return Err(MyError(format!("Failed after {} retries", max_retries)));
            }

            let backoff = tokio::time::Duration::from_millis(10);
            tokio::time::sleep(backoff).await;
        }
    }
}
/// Runs an operation that fails twice, with a budget of five retries.
#[tokio::main]
async fn main() {
    let flaky = FlakyOperation::new(2);
    let outcome = flaky.execute_with_retry(5).await;
    match outcome {
        Err(e) => println!("Error: {}", e),
        Ok(result) => println!("Success: {}", result),
    }
}

use async_trait::async_trait;
use std::sync::Arc;
/// Asynchronous CRUD-style repository over entities of type `T`, keyed by
/// `u64` ids and reporting errors as `String`s.
#[async_trait]
pub trait Repository<T> {
/// Looks up one entity by `id`; `Ok(None)` leaves room for "not present".
async fn find_by_id(&self, id: u64) -> Result<Option<T>, String>;
/// Returns every stored entity.
async fn find_all(&self) -> Result<Vec<T>, String>;
/// Stores `entity` and returns an entity of the same type.
async fn save(&self, entity: T) -> Result<T, String>;
/// Deletes by `id`; the `bool` presumably reports whether something was
/// removed (confirm with implementors).
async fn delete(&self, id: u64) -> Result<bool, String>;
}
/// A user entity stored in the repository.
#[derive(Debug, Clone, PartialEq)]
pub struct User {
// Numeric identifier; the example's `main` passes 0 for new users.
pub id: u64,
// Display name.
pub name: String,
// Contact e-mail address.
pub email: String,
}
/// Generic in-memory repository with its own id counter.
pub struct InMemoryRepository<T> {
    /// Stored entities keyed by id.
    data: Arc<tokio::sync::Mutex<std::collections::HashMap<u64, T>>>,
    /// The next id to hand out.
    next_id: Arc<tokio::sync::Mutex<u64>>,
}

impl<T> InMemoryRepository<T>
where
    T: Clone + Send + 'static,
{
    /// Creates an empty repository whose id counter starts at 1.
    pub fn new() -> Self {
        let data = Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new()));
        let next_id = Arc::new(tokio::sync::Mutex::new(1));
        Self { data, next_id }
    }
}
#[async_trait]
impl Repository<User> for InMemoryRepository<User> {
    /// Returns a clone of the user stored under `id`, if any.
    async fn find_by_id(&self, id: u64) -> Result<Option<User>, String> {
        let data = self.data.lock().await;
        Ok(data.get(&id).cloned())
    }

    /// Returns clones of all stored users (in the map's iteration order).
    async fn find_all(&self) -> Result<Vec<User>, String> {
        let data = self.data.lock().await;
        Ok(data.values().cloned().collect())
    }

    /// Inserts or replaces `entity` and returns it as stored.
    ///
    /// An `id` of 0 means "assign the next free id". A non-zero id is kept
    /// as given, and the counter is moved past it so that a later
    /// auto-assigned id cannot silently overwrite this entity.
    async fn save(&self, mut entity: User) -> Result<User, String> {
        {
            // Scoped so the counter lock is released before the data lock
            // is taken.
            let mut next_id = self.next_id.lock().await;
            if entity.id == 0 {
                entity.id = *next_id;
                *next_id += 1;
            } else if entity.id >= *next_id {
                *next_id = entity.id.saturating_add(1);
            }
        }

        let mut data = self.data.lock().await;
        data.insert(entity.id, entity.clone());
        Ok(entity)
    }

    /// Removes the user stored under `id`; `true` if one was present.
    async fn delete(&self, id: u64) -> Result<bool, String> {
        let mut data = self.data.lock().await;
        Ok(data.remove(&id).is_some())
    }
}
/// Creates two users, then looks up, lists and deletes through the
/// repository.
#[tokio::main]
async fn main() {
    let users = InMemoryRepository::<User>::new();

    // Create users
    let alice_draft = User {
        id: 0,
        name: String::from("Alice"),
        email: String::from("alice@example.com"),
    };
    let alice = users.save(alice_draft).await.unwrap();

    let bob_draft = User {
        id: 0,
        name: String::from("Bob"),
        email: String::from("bob@example.com"),
    };
    let bob = users.save(bob_draft).await.unwrap();
    println!("Created users: {:?} {:?}", alice, bob);

    // Find by ID
    let lookup = users.find_by_id(1).await.unwrap();
    println!("Found: {:?}", lookup);

    // Find all
    let everyone = users.find_all().await.unwrap();
    println!("All users: {:?}", everyone);

    // Delete
    let was_removed = users.delete(1).await.unwrap();
    println!("Deleted: {}", was_removed);

    let leftover = users.find_all().await.unwrap();
    println!("Remaining: {:?}", leftover);
}

use async_trait::async_trait;
use std::collections::HashMap;
/// Plugin interface with an async lifecycle; implementors must be
/// `Send + Sync`.
#[async_trait]
pub trait Plugin: Send + Sync {
/// Name of the plugin (synchronous).
fn name(&self) -> &str;
/// Asynchronous setup step; takes `&mut self`.
async fn initialize(&mut self) -> Result<(), String>;
/// Asynchronously turns `input` into an output string or an error message.
async fn process(&self, input: &str) -> Result<String, String>;
/// Asynchronous teardown step.
async fn shutdown(&self) -> Result<(), String>;
}
/// Plugin that returns its input unchanged.
struct EchoPlugin;

#[async_trait]
impl Plugin for EchoPlugin {
    fn name(&self) -> &str {
        "echo"
    }

    /// Announces initialization on stdout; cannot fail.
    async fn initialize(&mut self) -> Result<(), String> {
        println!("Initializing echo plugin");
        Ok(())
    }

    /// Copies `input` into an owned `String`.
    async fn process(&self, input: &str) -> Result<String, String> {
        let echoed = String::from(input);
        Ok(echoed)
    }

    /// Announces shutdown on stdout; cannot fail.
    async fn shutdown(&self) -> Result<(), String> {
        println!("Shutting down echo plugin");
        Ok(())
    }
}
/// Plugin that upper-cases its input.
struct UpperPlugin;

#[async_trait]
impl Plugin for UpperPlugin {
    fn name(&self) -> &str {
        "upper"
    }

    /// Announces initialization on stdout; cannot fail.
    async fn initialize(&mut self) -> Result<(), String> {
        println!("Initializing upper plugin");
        Ok(())
    }

    /// Returns `input` converted to upper case.
    async fn process(&self, input: &str) -> Result<String, String> {
        let shouted = str::to_uppercase(input);
        Ok(shouted)
    }

    /// Announces shutdown on stdout; cannot fail.
    async fn shutdown(&self) -> Result<(), String> {
        println!("Shutting down upper plugin");
        Ok(())
    }
}
/// Owns the registered plugins as boxed trait objects.
struct PluginManager {
// Registered plugins, keyed by the name each plugin reports.
plugins: HashMap<String, Box<dyn Plugin>>,
}
impl PluginManager {
    /// Creates a manager with no plugins registered.
    fn new() -> Self {
        Self {
            plugins: HashMap::new(),
        }
    }

    /// Initializes `plugin` and, if that succeeds, stores it under its name.
    ///
    /// An initialization error is returned to the caller and the plugin is
    /// not stored.
    async fn register(&mut self, mut plugin: Box<dyn Plugin>) -> Result<(), String> {
        let name = plugin.name().to_string();
        plugin.initialize().await?;
        self.plugins.insert(name, plugin);
        Ok(())
    }

    /// Runs `input` through the plugin registered as `plugin_name`.
    ///
    /// Fails with a "not found" message when no such plugin is registered,
    /// otherwise returns whatever the plugin returns.
    async fn process(&self, plugin_name: &str, input: &str) -> Result<String, String> {
        let plugin = self
            .plugins
            .get(plugin_name)
            .ok_or_else(|| format!("Plugin {} not found", plugin_name))?;
        plugin.process(input).await
    }

    /// Shuts down every registered plugin.
    ///
    /// Shutdown stays best-effort: a failing plugin does not stop the others
    /// from being shut down. Its error is reported on stderr instead of
    /// being discarded silently.
    async fn shutdown_all(&self) {
        for (name, plugin) in &self.plugins {
            if let Err(reason) = plugin.shutdown().await {
                eprintln!("Plugin {} failed to shut down: {}", name, reason);
            }
        }
    }
}
#[tokio::main]
async fn main() {
let mut manager = PluginManager::new();
// Register plugins
manager.register(Box::new(EchoPlugin)).await.unwrap();
manager.register(Box::new(UpperPlugin)).await.unwrap();
// Use plugins
let echo_result = manager.process("echo", "hello").await.unwrap();
println!("Echo: {}", echo_result);
let upper_result = manager.process("upper", "hello").await.unwrap();
println!("Upper: {}", upper_result);
// Shutdown
manager.shutdown_all().await;
}use async_trait::async_trait;
/// Simplified HTTP-style request passed to handlers and middleware.
#[derive(Debug, Clone)]
pub struct Request {
// Request method, e.g. "GET" or "POST" as used in `main`.
pub method: String,
// Request path; the router uses it as the lookup key.
pub path: String,
// Request body text.
pub body: String,
// Header name/value pairs.
pub headers: std::collections::HashMap<String, String>,
}
/// Simplified HTTP-style response produced by handlers.
#[derive(Debug, Clone)]
pub struct Response {
    /// Numeric status code.
    pub status: u16,
    /// Response body text.
    pub body: String,
    /// Header name/value pairs.
    pub headers: std::collections::HashMap<String, String>,
}

impl Response {
    /// Builds a 200 response with the given body and no headers.
    fn ok(body: impl Into<String>) -> Self {
        let headers = std::collections::HashMap::new();
        let body = body.into();
        Self {
            status: 200,
            body,
            headers,
        }
    }

    /// Builds a 404 response with the body `"Not Found"` and no headers.
    fn not_found() -> Self {
        Self {
            status: 404,
            body: String::from("Not Found"),
            headers: Default::default(),
        }
    }

    /// Builds a 200 response with the given body and a single
    /// `Content-Type: application/json` header.
    fn json(body: impl Into<String>) -> Self {
        let content_type = (
            String::from("Content-Type"),
            String::from("application/json"),
        );
        Self {
            status: 200,
            body: body.into(),
            headers: std::iter::once(content_type).collect(),
        }
    }
}
/// Asynchronous request handler; implementors must be `Send + Sync`.
#[async_trait]
pub trait Handler: Send + Sync {
/// Consumes `request` and asynchronously produces a `Response`.
async fn handle(&self, request: Request) -> Response;
}
/// Asynchronous middleware that wraps another handler; implementors must be
/// `Send + Sync`.
#[async_trait]
pub trait Middleware: Send + Sync {
/// Handles `request`, with `next` available as the handler to delegate to.
async fn handle(&self, request: Request, next: &dyn Handler) -> Response;
}
// Concrete handlers
/// Handler that answers every request with a fixed greeting.
struct HomeHandler;

#[async_trait]
impl Handler for HomeHandler {
    /// Ignores the request and returns a 200 response saying `"Welcome!"`.
    async fn handle(&self, _request: Request) -> Response {
        let greeting = "Welcome!";
        Response::ok(greeting)
    }
}
/// Handler that reports the request's method and body as a JSON object.
struct ApiHandler;

/// Escapes `raw` so it can be placed between double quotes in a JSON
/// document: quotes, backslashes and control characters are escaped, and
/// everything else is copied through unchanged.
fn escape_json_string(raw: &str) -> String {
    let mut escaped = String::with_capacity(raw.len());
    for ch in raw.chars() {
        match ch {
            '"' => escaped.push_str("\\\""),
            '\\' => escaped.push_str("\\\\"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            // Remaining control characters have no short escape form.
            c if (c as u32) < 0x20 => escaped.push_str(&format!("\\u{:04x}", c as u32)),
            c => escaped.push(c),
        }
    }
    escaped
}

#[async_trait]
impl Handler for ApiHandler {
    /// Returns `{"method": "...", "body": "..."}` as a JSON response.
    ///
    /// Both values are escaped first. Interpolating them raw produces
    /// invalid JSON as soon as the body contains a quote, which is the case
    /// for a JSON request body.
    async fn handle(&self, request: Request) -> Response {
        Response::json(format!(
            "{{\"method\": \"{}\", \"body\": \"{}\"}}",
            escape_json_string(&request.method),
            escape_json_string(&request.body)
        ))
    }
}
// Middleware
struct LoggingMiddleware;
#[async_trait]
impl Middleware for LoggingMiddleware {
async fn handle(&self, request: Request, next: &dyn Handler) -> Response {
println!("Request: {} {}", request.method, request.path);
let response = next.handle(request).await;
println!("Response: {}", response.status);
response
}
}
/// Middleware that only lets through requests presenting the expected key.
struct AuthMiddleware {
    /// Value the `Authorization` header has to equal.
    api_key: String,
}

#[async_trait]
impl Middleware for AuthMiddleware {
    /// Delegates to `next` when the `Authorization` header equals the
    /// configured key; otherwise answers 401 `"Unauthorized"` itself.
    async fn handle(&self, request: Request, next: &dyn Handler) -> Response {
        let authorized = request.headers.get("Authorization") == Some(&self.api_key);
        if authorized {
            return next.handle(request).await;
        }

        Response {
            status: 401,
            body: "Unauthorized".to_string(),
            headers: std::collections::HashMap::new(),
        }
    }
}
// Router
/// Dispatch table from request paths to handlers.
struct Router {
    /// Registered handlers keyed by exact path.
    routes: std::collections::HashMap<String, Box<dyn Handler>>,
}

impl Router {
    /// Creates a router with no routes.
    fn new() -> Self {
        let routes = std::collections::HashMap::new();
        Self { routes }
    }

    /// Registers `handler` for `path`, replacing any handler already there.
    fn add_route(&mut self, path: &str, handler: Box<dyn Handler>) {
        let key = String::from(path);
        self.routes.insert(key, handler);
    }
}
#[async_trait]
impl Handler for Router {
    /// Forwards the request to the handler registered for its exact path,
    /// or answers with a 404 response when there is none.
    async fn handle(&self, request: Request) -> Response {
        if let Some(target) = self.routes.get(&request.path) {
            target.handle(request).await
        } else {
            Response::not_found()
        }
    }
}
/// Builds a two-route router and sends one request to each route.
#[tokio::main]
async fn main() {
    let mut app = Router::new();
    app.add_route("/", Box::new(HomeHandler));
    app.add_route("/api", Box::new(ApiHandler));

    let home_request = Request {
        method: String::from("GET"),
        path: String::from("/"),
        body: String::new(),
        headers: std::collections::HashMap::new(),
    };
    let home_response = app.handle(home_request).await;
    println!("Response: {:?}", home_response);

    let api_request = Request {
        method: String::from("POST"),
        path: String::from("/api"),
        body: String::from("{\"name\": \"test\"}"),
        headers: std::collections::HashMap::new(),
    };
    let api_response = app.handle(api_request).await;
    println!("API Response: {:?}", api_response);
}

// The #[async_trait] macro transforms this:
use async_trait::async_trait;
/// Smallest possible async trait, used as the "before" half of the
/// expansion illustration.
#[async_trait]
trait Example {
/// Asynchronously produces a `String`.
async fn method(&self) -> String;
}
// Into something like this:
/// Hand-written approximation of what the macro turns `Example` into (the
/// "after" half of the illustration).
trait ExampleExpanded {
// A regular, non-async method whose return value is a pinned, boxed,
// `Send` future that yields `String`.
fn method<'life0, 'async_trait>(
&'life0 self,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = String> + Send + 'async_trait>
>
where
// Both the borrow of `self` and `Self` itself must outlive the lifetime
// attached to the returned future.
'life0: 'async_trait,
Self: 'async_trait;
}
// Key points:
// 1. Async fn becomes a regular fn returning Pin<Box<dyn Future>>
// 2. The future is heap-allocated (boxed)
// 3. Send bound is added (for multi-threaded runtimes)
// 4. Lifetimes are automatically handled
// This is why:
// - There's a small runtime cost (boxing)
// - Futures are Send by default
// - You can use ?Send for non-Send futures
// - It works with complex lifetimes

// Rust is adding native support for async in traits
// Currently unstable, but will eventually replace async-trait
// Future syntax (stabilization in progress):
// trait AsyncTrait {
// async fn method(&self) -> String;
// }
//
// impl AsyncTrait for MyType {
// async fn method(&self) -> String {
// "hello".to_string()
// }
// }
// Until then, async-trait remains the recommended approach

- Use `#[async_trait]` to define async methods in traits
- Apply `#[async_trait]` to both trait definition and implementations
- Futures are `Send` by default (required for multi-threaded runtimes)
- Use `#[async_trait(?Send)]` for non-`Send` futures (single-threaded contexts)
- Trait objects (`dyn Trait`) work with async-trait
- Add `Send + Sync` bounds for thread-safe trait objects