Skip to content

Commit

Permalink
Fix RwLock across await boundary by switching to DashMap
Browse files Browse the repository at this point in the history
  • Loading branch information
cmackenzie1 committed Feb 1, 2025
1 parent 57fdb5e commit 7275dbc
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 16 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ license = "MIT"

[workspace.dependencies]
async-trait = "0.1"
dashmap = "6.1"
futures = "0.3"
regex = "1"
serde = "1"
thiserror = "2"
Expand All @@ -19,7 +21,7 @@ uuid = { version = "1", features = ["v4"] }

sqlx = { version = "0.8", features = [
"runtime-tokio",
"sqlite", # TODO: Remove sqlite feature
"sqlite", # TODO: Remove sqlite dependency
"chrono",
"uuid",
] }
1 change: 0 additions & 1 deletion torii-auth-email/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
//! Password is hashed using the `password_auth` crate using argon2.
mod migrations;


use async_trait::async_trait;
use migrations::AddPasswordHashColumn;
use password_auth::{generate_hash, verify_password};
Expand Down
1 change: 0 additions & 1 deletion torii-auth-oidc/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

use async_trait::async_trait;
use migrations::CreateOidcTables;
use sqlx::{Pool, Sqlite};
Expand Down
2 changes: 2 additions & 0 deletions torii-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ license.workspace = true
[dependencies]
async-trait.workspace = true
chrono = "0.4.39"
dashmap = "6.1"
downcast-rs = "2.0.1"
futures.workspace = true
sqlx.workspace = true
thiserror.workspace = true
tracing.workspace = true
Expand Down
28 changes: 15 additions & 13 deletions torii-core/src/plugin.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use async_trait::async_trait;
use dashmap::DashMap;
use downcast_rs::{impl_downcast, DowncastSync};
use sqlx::{Pool, Row, Sqlite};
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::sync::Arc;

use crate::error::Error;
use crate::migration::PluginMigration;
Expand Down Expand Up @@ -61,7 +61,7 @@ impl_downcast!(sync Plugin);

/// Manages a collection of plugins.
pub struct PluginManager {
pub plugins: RwLock<HashMap<TypeId, Arc<dyn Plugin>>>,
pub plugins: DashMap<TypeId, Arc<dyn Plugin>>,
}

impl Default for PluginManager {
Expand All @@ -73,22 +73,21 @@ impl Default for PluginManager {
impl PluginManager {
pub fn new() -> Self {
Self {
plugins: RwLock::new(HashMap::new()),
plugins: DashMap::new(),
}
}

/// Get a plugin by type.
pub fn get_plugin<T: Plugin + 'static>(&self) -> Option<Arc<T>> {
let plugins = self.plugins.read().unwrap();
let plugin = plugins.get(&TypeId::of::<T>())?;
plugin.clone().downcast_arc::<T>().ok()
let plugin = self.plugins.get(&TypeId::of::<T>())?;
plugin.value().clone().downcast_arc::<T>().ok()
}

/// Register a new plugin.
pub fn register<T: Plugin + 'static>(&self, plugin: T) {
let plugin = Arc::new(plugin);
let type_id = TypeId::of::<T>();
self.plugins.write().unwrap().insert(type_id, plugin);
self.plugins.insert(type_id, plugin);
tracing::info!(
"Registered plugin: {}",
self.get_plugin::<T>().unwrap().name()
Expand All @@ -98,9 +97,9 @@ impl PluginManager {
/// Setup all registered plugins. This should be called before any authentication
/// is attempted.
pub async fn setup(&self, pool: &Pool<Sqlite>) -> Result<(), Error> {
for plugin in self.plugins.read().unwrap().values() {
plugin.setup(pool).await?;
tracing::info!("Setup plugin: {}", plugin.name());
for plugin in self.plugins.iter() {
plugin.value().setup(pool).await?;
tracing::info!("Setup plugin: {}", plugin.value().name());
}
Ok(())
}
Expand Down Expand Up @@ -202,9 +201,12 @@ impl PluginManager {
self.init_migration_table(pool).await?;
self.init_user_table(pool).await?;

for plugin in self.plugins.read().unwrap().values() {
let applied = self.get_applied_migrations(pool, plugin.name()).await?;
for plugin in self.plugins.iter() {
let applied = self
.get_applied_migrations(pool, plugin.value().name())
.await?;
let pending = plugin
.value()
.migrations()
.into_iter()
.filter(|m| !applied.contains(&m.version()));
Expand Down

0 comments on commit 7275dbc

Please sign in to comment.