Loading page…
Rust walkthroughs
Loading page…
The async-trait crate provides a procedural macro that allows you to define async methods in traits. An `async fn` desugars to a function returning an anonymous `impl Future` type. For that reason, async methods in traits were not supported on stable Rust until 1.75, and even the native version stabilized there cannot be called through `dyn Trait` objects. The async-trait macro works around this by rewriting each method to return a boxed future (`Pin<Box<dyn Future + Send>>`), so the trait works with both generics and trait objects. This is essential for building async abstractions, dependency injection patterns, and trait-based polymorphism in async Rust code.
Key concepts:
- `async fn` methods in traits
- `Pin<Box<dyn Future>>` as the underlying return type
- `dyn Trait` for dynamic dispatch

# Cargo.toml
[dependencies]
async-trait = "0.1"
tokio = { version = "1", features = ["full"] }use async_trait::async_trait;
/// Read-side user lookup, implemented by a demo backend below.
#[async_trait]
pub trait Database {
    /// Fetches the record text for user `id`, or `None`.
    async fn get_user(&self, id: u64) -> Option<String>;
}

/// Stand-in "Postgres" backend: no connection is made, and every id resolves to `"User {id}"`.
pub struct PostgresDb;

#[async_trait]
impl Database for PostgresDb {
    async fn get_user(&self, id: u64) -> Option<String> {
        // Simulate async database lookup by building the row text directly.
        let row = format!("User {}", id);
        Some(row)
    }
}
/// Entry point: asks the demo backend for user 1 and prints it when present.
#[tokio::main]
async fn main() {
    let backend = PostgresDb;
    let lookup = backend.get_user(1).await;
    if let Some(name) = lookup {
        println!("Found: {}", name);
    }
}
use async_trait::async_trait;
#[async_trait]
trait Fetcher {
    /// Retrieves the body behind `url`; errors are boxed so any `Send + Sync` error type fits.
    async fn fetch(&self, url: &str) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
}

/// Fake HTTP fetcher: waits 100 ms and echoes the URL instead of performing a request.
struct HttpFetcher;

#[async_trait]
impl Fetcher for HttpFetcher {
    async fn fetch(&self, url: &str) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
        // Simulate HTTP request latency.
        let latency = tokio::time::Duration::from_millis(100);
        tokio::time::sleep(latency).await;
        let body = format!("Response from {}", url);
        Ok(body)
    }
}
/// Entry point: fetches one URL and debug-prints the resulting `Result`.
#[tokio::main]
async fn main() {
    let client = HttpFetcher;
    let target = "https://example.com";
    let response = client.fetch(target).await;
    println!("{:?}", response);
}
use async_trait::async_trait;
/// Minimal async key/value cache with string keys and values.
#[async_trait]
pub trait Cache {
    /// Looks up `key`, returning an owned copy of the value if one is stored.
    async fn get(&self, key: &str) -> Option<String>;
    /// Stores `value` under `key`; the returned flag is a success indicator.
    async fn set(&self, key: &str, value: &str) -> bool;
    /// Removes `key`; the returned flag reports the outcome of the removal.
    async fn delete(&self, key: &str) -> bool;
}
/// In-memory stand-in for a Redis cache used by this example.
///
/// Entries live in a `HashMap` behind a `tokio::sync::RwLock`, so lookups share the lock
/// while inserts and removals take it exclusively.
pub struct RedisCache {
    // Simulated storage
    store: std::sync::Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
}

impl RedisCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self {
            store: std::sync::Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
        }
    }
}

// A zero-argument `new()` should be paired with `Default` (clippy `new_without_default`),
// so the cache also works with `Default::default()`, `unwrap_or_default`, etc.
impl Default for RedisCache {
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl Cache for RedisCache {
    async fn get(&self, key: &str) -> Option<String> {
        // Shared lock; clone the value out so the caller gets an owned copy.
        let guard = self.store.read().await;
        guard.get(key).cloned()
    }

    async fn set(&self, key: &str, value: &str) -> bool {
        // Insert or overwrite. This in-memory store cannot fail, so success is always reported.
        self.store
            .write()
            .await
            .insert(key.to_owned(), value.to_owned());
        true
    }

    async fn delete(&self, key: &str) -> bool {
        // `true` only when an entry was actually present and removed.
        let mut guard = self.store.write().await;
        let removed = guard.remove(key);
        removed.is_some()
    }
}
/// Entry point: stores one key, reads it back, deletes it, and reads it again.
#[tokio::main]
async fn main() {
    let cache = RedisCache::new();
    let key = "name";
    cache.set(key, "Alice").await;
    let before = cache.get(key).await;
    println!("Get: {:?}", before);
    cache.delete(key).await;
    let after = cache.get(key).await;
    println!("After delete: {:?}", after);
}
use async_trait::async_trait;
/// Async repository over entities addressed by `u64` ids.
#[async_trait]
pub trait Repository {
    /// The stored record type.
    type Entity;
    /// Error type; bounded so it can be boxed and moved across threads.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Loads the entity with the given id.
    async fn find_by_id(&self, id: u64) -> Result<Self::Entity, Self::Error>;
    /// Persists `entity`.
    async fn save(&self, entity: &Self::Entity) -> Result<(), Self::Error>;
    /// Removes the entity with the given id.
    async fn delete(&self, id: u64) -> Result<(), Self::Error>;
}
/// A user record as stored by the in-memory repository.
///
/// `PartialEq`/`Eq` are derived so records can be compared directly, e.g. in tests
/// asserting that `find_by_id` returns what `save` stored.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct User {
    pub id: u64,
    pub name: String,
    pub email: String,
}

/// Error returned by the user repository, e.g. when no user exists for a requested id.
///
/// It is a field-less unit struct, so it is also cheap to `Copy` and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UserNotFoundError;

impl std::fmt::Display for UserNotFoundError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "User not found")
    }
}

impl std::error::Error for UserNotFoundError {}
/// Repository that keeps users in a `HashMap` keyed by id, guarded by a `tokio::sync::RwLock`.
pub struct InMemoryUserRepo {
    users: std::sync::Arc<tokio::sync::RwLock<std::collections::HashMap<u64, User>>>,
}

impl InMemoryUserRepo {
    /// Creates an empty repository.
    pub fn new() -> Self {
        Self {
            users: std::sync::Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
        }
    }
}

// A zero-argument `new()` should be paired with `Default` (clippy `new_without_default`).
impl Default for InMemoryUserRepo {
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl Repository for InMemoryUserRepo {
    type Entity = User;
    type Error = UserNotFoundError;

    /// Returns a clone of the stored user, or `UserNotFoundError` if the id is absent.
    async fn find_by_id(&self, id: u64) -> Result<User, UserNotFoundError> {
        let users = self.users.read().await;
        match users.get(&id) {
            Some(user) => Ok(user.clone()),
            None => Err(UserNotFoundError),
        }
    }

    /// Inserts or replaces the user under `entity.id`; never fails.
    async fn save(&self, entity: &User) -> Result<(), UserNotFoundError> {
        let mut users = self.users.write().await;
        let stored = entity.clone();
        users.insert(stored.id, stored);
        Ok(())
    }

    /// Removes the user, failing with `UserNotFoundError` if nothing was stored under `id`.
    async fn delete(&self, id: u64) -> Result<(), UserNotFoundError> {
        let mut users = self.users.write().await;
        if users.remove(&id).is_some() {
            Ok(())
        } else {
            Err(UserNotFoundError)
        }
    }
}
/// Entry point: saves one user, then reads it back by id and prints it.
#[tokio::main]
async fn main() {
    let repo = InMemoryUserRepo::new();
    let alice = User {
        id: 1,
        name: String::from("Alice"),
        email: String::from("alice@example.com"),
    };
    repo.save(&alice).await.unwrap();
    let found = repo.find_by_id(alice.id).await.unwrap();
    println!("Found: {:?}", found);
}
use async_trait::async_trait;
/// A notification channel; `Send + Sync` so implementors can be used as shared trait objects.
#[async_trait]
pub trait Notifier: Send + Sync {
    /// Delivers `message`, returning a human-readable error string on failure.
    async fn notify(&self, message: &str) -> Result<(), String>;
}

/// Simulated e-mail channel (50 ms delay, then prints).
pub struct EmailNotifier;

#[async_trait]
impl Notifier for EmailNotifier {
    async fn notify(&self, message: &str) -> Result<(), String> {
        let send_delay = tokio::time::Duration::from_millis(50);
        tokio::time::sleep(send_delay).await;
        println!("Email sent: {}", message);
        Ok(())
    }
}

/// Simulated SMS channel (30 ms delay, then prints).
pub struct SmsNotifier;

#[async_trait]
impl Notifier for SmsNotifier {
    async fn notify(&self, message: &str) -> Result<(), String> {
        let send_delay = tokio::time::Duration::from_millis(30);
        tokio::time::sleep(send_delay).await;
        println!("SMS sent: {}", message);
        Ok(())
    }
}
// Use dynamic dispatch
/// Sends `message` through any [`Notifier`] trait object and prints whether it succeeded.
///
/// Takes `&dyn Notifier` instead of `&Box<dyn Notifier>` (clippy `borrowed_box`). Callers
/// that hold a `Box<dyn Notifier>` can still pass `&boxed`, because deref coercion turns
/// `&Box<dyn Notifier>` into `&dyn Notifier`. Callers with a plain `&EmailNotifier` no longer
/// need to allocate a box first.
async fn send_notification(notifier: &dyn Notifier, message: &str) {
    match notifier.notify(message).await {
        Ok(()) => println!("Notification sent successfully"),
        Err(e) => println!("Failed to send: {}", e),
    }
}
/// Entry point: sends one message through each channel, e-mail first, then SMS.
#[tokio::main]
async fn main() {
    let channels: [(Box<dyn Notifier>, &str); 2] = [
        (Box::new(EmailNotifier), "Hello via email"),
        (Box::new(SmsNotifier), "Hello via SMS"),
    ];
    for (notifier, message) in &channels {
        send_notification(notifier, message).await;
    }
}
use async_trait::async_trait;
/// Async transformation of an input `T` into a `String`, or an error message.
#[async_trait]
pub trait Processor<T> {
    async fn process(&self, item: T) -> Result<String, String>;
}

/// Upper-cases its input after a short simulated delay.
pub struct UpperProcessor;

#[async_trait]
impl Processor<String> for UpperProcessor {
    async fn process(&self, item: String) -> Result<String, String> {
        let pause = tokio::time::Duration::from_millis(10);
        tokio::time::sleep(pause).await;
        let shouted = item.to_uppercase();
        Ok(shouted)
    }
}

/// Reports the byte length of its input after a short simulated delay.
pub struct LengthProcessor;

#[async_trait]
impl Processor<String> for LengthProcessor {
    async fn process(&self, item: String) -> Result<String, String> {
        let pause = tokio::time::Duration::from_millis(10);
        tokio::time::sleep(pause).await;
        let byte_len = item.len();
        Ok(format!("Length: {}", byte_len))
    }
}
/// Entry point: runs each processor once and debug-prints the results in order.
#[tokio::main]
async fn main() {
    let upper = UpperProcessor;
    let length = LengthProcessor;
    let shouted = upper.process(String::from("hello")).await;
    println!("{:?}", shouted);
    let measured = length.process(String::from("hello world")).await;
    println!("{:?}", measured);
}
use async_trait::async_trait;
/// Async logging sink: implementors supply `write`, and the level helpers are provided.
#[async_trait]
pub trait Logger: Send + Sync {
    // Required method
    /// Emits one fully formatted line.
    async fn write(&self, message: &str) -> Result<(), std::io::Error>;

    // Default implementations built on `write`; each prefixes a level tag.
    async fn log_info(&self, message: &str) -> Result<(), std::io::Error> {
        let line = format!("[INFO] {}", message);
        self.write(&line).await
    }

    async fn log_error(&self, message: &str) -> Result<(), std::io::Error> {
        let line = format!("[ERROR] {}", message);
        self.write(&line).await
    }

    async fn log_debug(&self, message: &str) -> Result<(), std::io::Error> {
        let line = format!("[DEBUG] {}", message);
        self.write(&line).await
    }
}

/// Logger that prints each line to stdout and always reports success.
pub struct ConsoleLogger;

#[async_trait]
impl Logger for ConsoleLogger {
    async fn write(&self, message: &str) -> Result<(), std::io::Error> {
        println!("{}", message);
        Ok(())
    }
    // log_info, log_error and log_debug fall back to the trait's default bodies.
}
/// Entry point: logs one line at each level, propagating any I/O error.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let console = ConsoleLogger;
    // Each call goes through a default trait method that adds the level tag.
    console.log_info("Application started").await?;
    console.log_error("Something went wrong").await?;
    console.log_debug("Debug information").await?;
    Ok(())
}
use async_trait::async_trait;
// When using async traits across thread boundaries, ensure Send + Sync
/// A unit of async work; `Send + Sync` lets boxed tasks be moved into `tokio::spawn`.
#[async_trait]
pub trait AsyncTask: Send + Sync {
    /// Runs the task, yielding a success message or an error message.
    async fn run(&self) -> Result<String, String>;
}
/// A task that (in this demo) pretends to download from `url`.
pub struct DownloadTask {
    url: String,
}

impl DownloadTask {
    /// Builds a task for `url`, copying it into an owned `String`.
    pub fn new(url: &str) -> Self {
        let url = String::from(url);
        DownloadTask { url }
    }
}
#[async_trait]
impl AsyncTask for DownloadTask {
    async fn run(&self) -> Result<String, String> {
        // Simulate download latency; no network access happens.
        let latency = tokio::time::Duration::from_millis(100);
        tokio::time::sleep(latency).await;
        let summary = format!("Downloaded from {}", self.url);
        Ok(summary)
    }
}
/// Spawns every task on the Tokio runtime, then awaits them in submission order and
/// prints each outcome.
///
/// All tasks are spawned (collected eagerly) before the first one is awaited, so they run
/// concurrently. `Join error` is printed when awaiting the spawned task itself fails
/// (for example, if it panicked).
async fn execute_tasks(tasks: Vec<Box<dyn AsyncTask>>) {
    let handles: Vec<_> = tasks
        .into_iter()
        .map(|task| tokio::spawn(async move { task.run().await }))
        .collect();

    for handle in handles {
        let outcome = handle.await;
        match outcome {
            Ok(Ok(result)) => println!("Success: {}", result),
            Ok(Err(e)) => println!("Task error: {}", e),
            Err(e) => println!("Join error: {}", e),
        }
    }
}
/// Entry point: builds one download task per URL and runs them all concurrently.
#[tokio::main]
async fn main() {
    let urls = ["https://example.com/file1", "https://example.com/file2"];
    let tasks: Vec<Box<dyn AsyncTask>> = urls
        .iter()
        .map(|url| Box::new(DownloadTask::new(url)) as Box<dyn AsyncTask>)
        .collect();
    execute_tasks(tasks).await;
}
use async_trait::async_trait;
/// Async source of bytes.
#[async_trait]
pub trait Reader {
    /// Returns the bytes currently available.
    async fn read(&self) -> Result<Vec<u8>, std::io::Error>;
}

/// Async sink of bytes.
#[async_trait]
pub trait Writer {
    /// Accepts `data` for writing.
    async fn write(&self, data: &[u8]) -> Result<(), std::io::Error>;
}

/// Async resource that can be closed.
#[async_trait]
pub trait Closer {
    /// Releases the resource.
    async fn close(&self) -> Result<(), std::io::Error>;
}
// Implement all three for a single struct
/// File-like handle whose contents live in memory; `path` is kept only for display.
pub struct FileHandler {
    path: String,
    content: std::sync::Arc<tokio::sync::RwLock<Vec<u8>>>,
}

impl FileHandler {
    /// Creates an empty handler labelled with `path`; nothing touches the filesystem.
    pub fn new(path: &str) -> Self {
        let buffer = tokio::sync::RwLock::new(Vec::new());
        Self {
            path: path.to_owned(),
            content: std::sync::Arc::new(buffer),
        }
    }
}
#[async_trait]
impl Reader for FileHandler {
    async fn read(&self) -> Result<Vec<u8>, std::io::Error> {
        // Copy everything written so far out from under the shared lock.
        let snapshot = self.content.read().await.clone();
        Ok(snapshot)
    }
}

#[async_trait]
impl Writer for FileHandler {
    async fn write(&self, data: &[u8]) -> Result<(), std::io::Error> {
        // Append under the exclusive lock; the in-memory buffer cannot fail.
        self.content.write().await.extend_from_slice(data);
        Ok(())
    }
}

#[async_trait]
impl Closer for FileHandler {
    async fn close(&self) -> Result<(), std::io::Error> {
        // Nothing to release for the in-memory buffer; just report the path.
        println!("Closing file: {}", self.path);
        Ok(())
    }
}
// Generic function accepting any type with all three traits
/// Writes `b"Hello"`, reads the contents back, prints the byte count, then closes `file`.
///
/// Stops at and returns the first I/O error from any step.
async fn process_file<T>(file: &T) -> Result<(), std::io::Error>
where
    T: Reader + Writer + Closer,
{
    file.write(b"Hello").await?;
    let bytes = file.read().await?;
    let count = bytes.len();
    println!("Read {} bytes", count);
    file.close().await?;
    Ok(())
}
/// Entry point: runs `process_file` on an in-memory handler labelled `/tmp/test.txt`.
#[tokio::main]
async fn main() -> Result<(), std::io::Error> {
    let handler = FileHandler::new("/tmp/test.txt");
    process_file(&handler).await?;
    Ok(())
}
use async_trait::async_trait;
/// User lookup abstraction, so callers can swap the real backend for a mock.
#[async_trait]
pub trait UserService {
    async fn get_user(&self, id: u64) -> Result<String, String>;
}

// Real implementation
pub struct RealUserService;

#[async_trait]
impl UserService for RealUserService {
    async fn get_user(&self, id: u64) -> Result<String, String> {
        // Simulate database call with a 100 ms round trip.
        let round_trip = tokio::time::Duration::from_millis(100);
        tokio::time::sleep(round_trip).await;
        Ok(format!("User {}", id))
    }
}

// Mock implementation for testing
pub struct MockUserService {
    /// Canned value returned from every `get_user` call.
    pub response: String,
}

#[async_trait]
impl UserService for MockUserService {
    async fn get_user(&self, _id: u64) -> Result<String, String> {
        // No delay for tests; the id is ignored.
        let canned = self.response.clone();
        Ok(canned)
    }
}
// Function using dependency injection
/// Looks up user 1 through the injected service and formats a greeting or an error line.
async fn greet_user(service: &dyn UserService) -> String {
    let lookup = service.get_user(1).await;
    lookup.map_or_else(
        |e| format!("Error: {}", e),
        |user| format!("Hello, {}!", user),
    )
}
#[tokio::main]
async fn main() {
// Use real service
let real = RealUserService;
println!("{}", greet_user(&real).await);
// Use mock for testing
let mock = MockUserService {
response: "Test User".to_string(),
};
println!("{}", greet_user(&mock).await);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The mock's canned response should flow straight into the greeting.
    #[tokio::test]
    async fn test_greet_user() {
        let stub = MockUserService {
            response: String::from("Alice"),
        };
        assert_eq!(greet_user(&stub).await, "Hello, Alice!");
    }
}
use async_trait::async_trait;
/// Async HTTP GET abstraction.
///
/// The `Clone` supertrait means this trait is used through generics only; it cannot be
/// turned into a `dyn HttpClient` object.
#[async_trait]
pub trait HttpClient: Clone + Send + Sync + 'static {
    async fn get(&self, url: &str) -> Result<String, String>;
}

/// Simulated client; construct it through [`ReqwestClient::builder`].
#[derive(Clone)]
pub struct ReqwestClient {
    timeout_ms: u64,
}

impl ReqwestClient {
    /// Starts a builder with the default settings.
    pub fn builder() -> ClientBuilder {
        Default::default()
    }
}

#[async_trait]
impl HttpClient for ReqwestClient {
    async fn get(&self, url: &str) -> Result<String, String> {
        // The simulated request waits for the full configured timeout before answering.
        let wait_ms = self.timeout_ms;
        tokio::time::sleep(tokio::time::Duration::from_millis(wait_ms)).await;
        Ok(format!("Response from {} (timeout: {}ms)", url, wait_ms))
    }
}
/// Builder for [`ReqwestClient`]; start from `ReqwestClient::builder()` or `ClientBuilder::default()`.
///
/// The setters take `self` by value and return the updated builder. They are `#[must_use]`,
/// so a call like `builder.timeout(10);` whose result is discarded (which silently drops
/// the configured builder) produces a compiler warning.
#[derive(Debug, Clone)]
pub struct ClientBuilder {
    timeout_ms: u64,
}

impl Default for ClientBuilder {
    /// Starts with a 5000 ms timeout.
    fn default() -> Self {
        Self { timeout_ms: 5000 }
    }
}

impl ClientBuilder {
    /// Sets the request timeout in milliseconds.
    #[must_use]
    pub fn timeout(mut self, ms: u64) -> Self {
        self.timeout_ms = ms;
        self
    }

    /// Finishes configuration and produces the client.
    #[must_use]
    pub fn build(self) -> ReqwestClient {
        ReqwestClient {
            timeout_ms: self.timeout_ms,
        }
    }
}
/// Entry point: builds a client with a 3000 ms timeout and issues one GET.
#[tokio::main]
async fn main() {
    let client = ReqwestClient::builder().timeout(3000).build();
    let response = client.get("https://example.com").await;
    println!("{:?}", response);
}
use async_trait::async_trait;
/// One layer of a request pipeline; wrapper layers delegate to an inner `Middleware`.
#[async_trait]
pub trait Middleware: Send + Sync {
    async fn handle(&self, request: &str) -> Result<String, String>;
}
/// Decorator that prints each request and its result around an inner middleware.
pub struct LoggingMiddleware<M> {
    /// The wrapped layer that actually handles requests.
    inner: M,
}

impl<M> LoggingMiddleware<M> {
    /// Wraps `inner`.
    pub fn new(inner: M) -> Self {
        LoggingMiddleware { inner }
    }
}
#[async_trait]
impl<M> Middleware for LoggingMiddleware<M>
where
    M: Middleware + Send + Sync,
{
    /// Logs the request, delegates to the inner layer, logs the outcome, and returns it unchanged.
    async fn handle(&self, request: &str) -> Result<String, String> {
        println!("[LOG] Request: {}", request);
        let outcome = self.inner.handle(request).await;
        println!("[LOG] Result: {:?}", outcome);
        outcome
    }
}
/// Gatekeeper layer that forwards only authorized requests to `inner`.
pub struct AuthMiddleware<M> {
    inner: M,
    // NOTE(review): stored, but the `Middleware` impl in this example never reads it.
    token: String,
}

impl<M> AuthMiddleware<M> {
    /// Wraps `inner`, keeping an owned copy of `token`.
    pub fn new(inner: M, token: &str) -> Self {
        let token = token.to_owned();
        AuthMiddleware { inner, token }
    }
}
#[async_trait]
impl<M: Middleware + Send + Sync> Middleware for AuthMiddleware<M> {
    /// Forwards requests that start with `auth:` to the inner layer, with the prefix removed.
    /// Every other request is rejected with `Err("Unauthorized")`.
    ///
    /// NOTE(review): `self.token` is never compared against anything. Any request that
    /// merely carries the `auth:` prefix is let through. This keeps the original demo
    /// behaviour; a real implementation must validate the supplied credential.
    async fn handle(&self, request: &str) -> Result<String, String> {
        // `strip_prefix` tests for and removes the prefix in one step, replacing the
        // former `starts_with` + `strip_prefix(..).unwrap()` pair (clippy `manual_strip`).
        match request.strip_prefix("auth:") {
            Some(auth_request) => self.inner.handle(auth_request).await,
            None => Err("Unauthorized".to_string()),
        }
    }
}
/// Innermost layer: pretends to do 10 ms of work and echoes the request.
pub struct Handler;

#[async_trait]
impl Middleware for Handler {
    async fn handle(&self, request: &str) -> Result<String, String> {
        let work = tokio::time::Duration::from_millis(10);
        tokio::time::sleep(work).await;
        let reply = format!("Processed: {}", request);
        Ok(reply)
    }
}
/// Entry point: stacks logging -> auth -> handler and sends one rejected and one accepted request.
#[tokio::main]
async fn main() {
    // Stack middlewares
    let pipeline = LoggingMiddleware::new(AuthMiddleware::new(Handler, "secret-token"));
    // No "auth:" prefix, so the auth layer rejects it.
    let rejected = pipeline.handle("hello").await;
    println!("{:?}", rejected);
    // Carries the prefix, so it reaches the handler.
    let accepted = pipeline.handle("auth:hello").await;
    println!("{:?}", accepted);
}
use async_trait::async_trait;
use std::sync::Arc;
/// Async database abstraction; `Send + Sync` so it can be shared as `Arc<dyn Database>`.
#[async_trait]
pub trait Database: Send + Sync {
    /// Runs a query and returns the resulting rows.
    async fn query(&self, sql: &str) -> Result<Vec<Row>, DbError>;
    /// Runs a statement and returns a count (presumably affected rows; confirm per backend).
    async fn execute(&self, sql: &str) -> Result<u64, DbError>;
    /// Starts a transaction; the handle may borrow `self`, hence the `'_` bound.
    async fn begin_transaction(&self) -> Result<Box<dyn Transaction + '_>, DbError>;
}

/// Transaction handle; both methods take `Box<Self>` by value, so the handle is consumed.
#[async_trait]
pub trait Transaction: Send {
    async fn commit(self: Box<Self>) -> Result<(), DbError>;
    async fn rollback(self: Box<Self>) -> Result<(), DbError>;
}
/// One result row: an ordered list of `(column, value)` string pairs.
#[derive(Debug, Clone)]
pub struct Row {
    pub columns: Vec<(String, String)>,
}

/// Database error carrying a human-readable message.
#[derive(Debug)]
pub struct DbError(pub String);

impl std::fmt::Display for DbError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let DbError(message) = self;
        write!(f, "DbError: {}", message)
    }
}

impl std::error::Error for DbError {}
// Mock implementation
/// In-memory `Database` used by the example; it starts with no rows.
pub struct MockDatabase {
    data: Arc<tokio::sync::RwLock<Vec<Row>>>,
}

impl MockDatabase {
    /// Creates an empty mock database.
    pub fn new() -> Self {
        Self {
            data: Arc::new(tokio::sync::RwLock::new(Vec::new())),
        }
    }
}

// A zero-argument `new()` should be paired with `Default` (clippy `new_without_default`).
impl Default for MockDatabase {
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl Database for MockDatabase {
    /// Ignores the SQL text and returns a copy of every stored row.
    async fn query(&self, _sql: &str) -> Result<Vec<Row>, DbError> {
        let rows = self.data.read().await.clone();
        Ok(rows)
    }

    /// Ignores the SQL text and always reports `1`.
    async fn execute(&self, _sql: &str) -> Result<u64, DbError> {
        Ok(1)
    }

    /// Hands out a transaction that borrows this database.
    async fn begin_transaction(&self) -> Result<Box<dyn Transaction + '_>, DbError> {
        let tx = MockTransaction {
            db: self,
            committed: false,
        };
        Ok(Box::new(tx))
    }
}
/// Transaction handle produced by `MockDatabase::begin_transaction`.
///
/// NOTE(review): neither `db` nor `committed` is read anywhere in this example, and
/// `commit` does not update `committed`. They look like placeholders for a fuller mock.
pub struct MockTransaction<'a> {
    db: &'a MockDatabase,
    committed: bool,
}

#[async_trait]
impl<'a> Transaction for MockTransaction<'a> {
    /// Prints a confirmation; always succeeds.
    async fn commit(self: Box<Self>) -> Result<(), DbError> {
        println!("Transaction committed");
        Ok(())
    }

    /// Prints a confirmation; always succeeds.
    async fn rollback(self: Box<Self>) -> Result<(), DbError> {
        println!("Transaction rolled back");
        Ok(())
    }
}
// Service using the abstraction
/// Application service that depends only on the `Database` trait object, so any backend can be injected.
pub struct UserService {
    db: Arc<dyn Database>,
}

impl UserService {
    /// Wraps a shared database handle.
    pub fn new(db: Arc<dyn Database>) -> Self {
        UserService { db }
    }

    /// Runs `SELECT * FROM users` against the injected database.
    pub async fn get_users(&self) -> Result<Vec<Row>, DbError> {
        let sql = "SELECT * FROM users";
        self.db.query(sql).await
    }
}
#[tokio::main]
async fn main() {
let db = Arc::new(MockDatabase::new());
let service = UserService::new(db);
let users = service.get_users().await;
println!("Users: {:?}", users);
}#[async_trait] attribute on both trait definition and implementationPin<Box<dyn Future>> under the hoodSend + Sync bounds for thread-safe trait objectsBox<dyn Trait> for dynamic dispatchasync-trait remains useful for compatibility