From 7275dbca0583d8d00de5a062a74e93662150a071 Mon Sep 17 00:00:00 2001 From: Cole MacKenzie Date: Fri, 31 Jan 2025 21:52:45 -0800 Subject: [PATCH] Fix RwLock across await boundary by switching to DashMap --- Cargo.toml | 4 +++- torii-auth-email/src/lib.rs | 1 - torii-auth-oidc/src/lib.rs | 1 - torii-core/Cargo.toml | 2 ++ torii-core/src/plugin.rs | 28 +++++++++++++++------------- 5 files changed, 20 insertions(+), 16 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fa17869..2754d5c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,8 @@ license = "MIT" [workspace.dependencies] async-trait = "0.1" +dashmap = "6.1" +futures = "0.3" regex = "1" serde = "1" thiserror = "2" @@ -19,7 +21,7 @@ uuid = { version = "1", features = ["v4"] } sqlx = { version = "0.8", features = [ "runtime-tokio", - "sqlite", # TODO: Remove sqlite feature + "sqlite", # TODO: Remove sqlite dependency "chrono", "uuid", ] } diff --git a/torii-auth-email/src/lib.rs b/torii-auth-email/src/lib.rs index 97f56da..2c90848 100644 --- a/torii-auth-email/src/lib.rs +++ b/torii-auth-email/src/lib.rs @@ -5,7 +5,6 @@ //! Password is hashed using the `password_auth` crate using argon2. mod migrations; - use async_trait::async_trait; use migrations::AddPasswordHashColumn; use password_auth::{generate_hash, verify_password}; diff --git a/torii-auth-oidc/src/lib.rs b/torii-auth-oidc/src/lib.rs index 2f46cb1..eb5780c 100644 --- a/torii-auth-oidc/src/lib.rs +++ b/torii-auth-oidc/src/lib.rs @@ -1,4 +1,3 @@ - use async_trait::async_trait; use migrations::CreateOidcTables; use sqlx::{Pool, Sqlite}; diff --git a/torii-core/Cargo.toml b/torii-core/Cargo.toml index 58b4aef..d6786f9 100644 --- a/torii-core/Cargo.toml +++ b/torii-core/Cargo.toml @@ -8,7 +8,9 @@ license.workspace = true [dependencies] async-trait.workspace = true chrono = "0.4.39" +dashmap = "6.1" downcast-rs = "2.0.1" +futures.workspace = true sqlx.workspace = true thiserror.workspace = true tracing.workspace = true diff --git a/torii-core/src/plugin.rs b/torii-core/src/plugin.rs index dad3364..373dd22 100644 --- a/torii-core/src/plugin.rs +++ b/torii-core/src/plugin.rs @@ -1,9 +1,9 @@ use async_trait::async_trait; +use dashmap::DashMap; use downcast_rs::{impl_downcast, DowncastSync}; use sqlx::{Pool, Row, Sqlite}; use std::any::{Any, TypeId}; -use std::collections::HashMap; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use crate::error::Error; use crate::migration::PluginMigration; @@ -61,7 +61,7 @@ impl_downcast!(sync Plugin); /// Manages a collection of plugins. pub struct PluginManager { - pub plugins: RwLock>>, + pub plugins: DashMap>, } impl Default for PluginManager { @@ -73,22 +73,21 @@ impl Default for PluginManager { impl PluginManager { pub fn new() -> Self { Self { - plugins: RwLock::new(HashMap::new()), + plugins: DashMap::new(), } } /// Get a plugin by type. pub fn get_plugin(&self) -> Option> { - let plugins = self.plugins.read().unwrap(); - let plugin = plugins.get(&TypeId::of::())?; - plugin.clone().downcast_arc::().ok() + let plugin = self.plugins.get(&TypeId::of::())?; + plugin.value().clone().downcast_arc::().ok() } /// Register a new plugin. pub fn register(&self, plugin: T) { let plugin = Arc::new(plugin); let type_id = TypeId::of::(); - self.plugins.write().unwrap().insert(type_id, plugin); + self.plugins.insert(type_id, plugin); tracing::info!( "Registered plugin: {}", self.get_plugin::().unwrap().name() @@ -98,9 +97,9 @@ impl PluginManager { /// Setup all registered plugins. This should be called before any authentication /// is attempted. pub async fn setup(&self, pool: &Pool) -> Result<(), Error> { - for plugin in self.plugins.read().unwrap().values() { - plugin.setup(pool).await?; - tracing::info!("Setup plugin: {}", plugin.name()); + for plugin in self.plugins.iter() { + plugin.value().setup(pool).await?; + tracing::info!("Setup plugin: {}", plugin.value().name()); } Ok(()) } @@ -202,9 +201,12 @@ impl PluginManager { self.init_migration_table(pool).await?; self.init_user_table(pool).await?; - for plugin in self.plugins.read().unwrap().values() { - let applied = self.get_applied_migrations(pool, plugin.name()).await?; + for plugin in self.plugins.iter() { + let applied = self + .get_applied_migrations(pool, plugin.value().name()) + .await?; let pending = plugin + .value() .migrations() .into_iter() .filter(|m| !applied.contains(&m.version()));