Skip to content

Commit

Permalink
Refactor get_user functions to return Result<Option<User>>
Browse files Browse the repository at this point in the history
This way plugins can react to a user existing or not without returning
an error via `?`.
  • Loading branch information
cmackenzie1 committed Feb 3, 2025
1 parent 16ebad6 commit b13ea4c
Show file tree
Hide file tree
Showing 11 changed files with 232 additions and 140 deletions.
6 changes: 5 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ check: fmt lint test
update:
@cargo update

# Generate documentation
docs:
@cargo doc --all-features --no-deps --open

# Help command to list all available commands
help:
@echo "Available commands:"
Expand All @@ -61,4 +65,4 @@ help:
@echo "${BLUE}make release${RESET} - Build for release"
@echo "${BLUE}make check${RESET} - Run all checks (format, lint, test)"
@echo "${BLUE}make update${RESET} - Update dependencies"

@echo "${BLUE}make docs${RESET} - Generate documentation"
13 changes: 9 additions & 4 deletions torii-auth-email/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ impl EmailPasswordPlugin {

// TODO: Wrap this in a transaction

if storage.get_user_by_email(email).await.is_ok() {
if let Some(_user) = storage
.get_user_by_email(email)
.await
.map_err(|_| Error::InternalServerError)?
{
tracing::debug!(email = %email, "User already exists");
return Err(Error::UserAlreadyExists);
}
Expand Down Expand Up @@ -83,7 +87,8 @@ impl EmailPasswordPlugin {
let user = storage
.get_user_by_email(email)
.await
.map_err(|_| Error::UserNotFound)?; // TODO: Should we return a different error that doesn't leak information about the user?
.map_err(|_| Error::UserNotFound)?
.ok_or_else(|| Error::UserNotFound)?;

let hash = storage
.get_password_hash(&user.id)
Expand Down Expand Up @@ -233,7 +238,7 @@ mod tests {
.create_user(&*user_storage, "not-an-email", "password")
.await;

assert!(result.is_err());
assert!(matches!(result, Err(Error::InvalidEmailFormat)));

Ok(())
}
Expand All @@ -248,7 +253,7 @@ mod tests {
.create_user(&*user_storage, "test@example.com", "123")
.await;

assert!(result.is_err());
assert!(matches!(result, Err(Error::WeakPassword)));

Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion torii-auth-oidc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ serde.workspace = true
sqlx.workspace = true
tracing.workspace = true
uuid.workspace = true
torii-storage-sqlite = { path = "../torii-storage-sqlite", version = "0.1.0" }

[dev-dependencies]
tokio.workspace = true
axum = { version = "0.8", features = ["macros"] }
axum-extra = { version = "0.10", features = ["cookie"] }
tracing-subscriber.workspace = true
torii-storage-sqlite = { path = "../torii-storage-sqlite", version = "0.1.0" }

[[example]]
name = "google"
Expand Down
14 changes: 9 additions & 5 deletions torii-auth-oidc/examples/google/google.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ struct QueryParams {

#[derive(Clone)]
struct AppState {
pool: Pool<Sqlite>,
user_storage: Arc<SqliteStorage>,
session_storage: Arc<SqliteStorage>,
plugin_manager: Arc<PluginManager<SqliteStorage, SqliteStorage>>,
}

Expand All @@ -30,7 +31,8 @@ async fn login_handler(State(state): State<AppState>, jar: CookieJar) -> (Cookie
let plugin = state.plugin_manager.get_plugin::<OIDCPlugin>().unwrap();
let auth_flow = plugin
.begin_auth(
&state.pool,
&*state.user_storage,
&*state.session_storage,
"http://localhost:4000/auth/google/callback".to_string(),
)
.await
Expand Down Expand Up @@ -61,7 +63,8 @@ async fn callback_handler(
let plugin = state.plugin_manager.get_plugin::<OIDCPlugin>().unwrap();
let user = plugin
.callback(
&state.pool,
&*state.user_storage,
&*state.session_storage,
&AuthFlowCallback {
csrf_state: params.state,
nonce_key: nonce_key.to_string(),
Expand All @@ -87,7 +90,7 @@ async fn main() {
user_storage.migrate().await.unwrap();
session_storage.migrate().await.unwrap();

let mut plugin_manager = PluginManager::new(user_storage, session_storage);
let mut plugin_manager = PluginManager::new(user_storage.clone(), session_storage.clone());
plugin_manager.register(OIDCPlugin::new(
"google".to_string(),
std::env::var("GOOGLE_CLIENT_ID").expect("GOOGLE_CLIENT_ID must be set"),
Expand All @@ -101,7 +104,8 @@ async fn main() {
.route("/auth/google/login", get(login_handler))
.route("/auth/google/callback", get(callback_handler))
.with_state(AppState {
pool,
user_storage: user_storage.clone(),
session_storage: session_storage.clone(),
plugin_manager: Arc::new(plugin_manager),
});

Expand Down
157 changes: 54 additions & 103 deletions torii-auth-oidc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use openidconnect::{
AuthorizationCode, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce, RedirectUrl, Scope,
TokenResponse,
};
use sqlx::{Pool, Row, Sqlite};
use torii_core::{Error, Plugin, SessionStorage, User, UserStorage};
use torii_core::{Error, NewUser, Plugin, SessionStorage, User, UserId, UserStorage};
use torii_storage_sqlite::{OIDCAccount, OIDCStorage};
use uuid::Uuid;

/// The core OIDC plugin struct, responsible for handling OIDC authentication flow.
Expand Down Expand Up @@ -69,7 +69,9 @@ impl OIDCPlugin {
redirect_uri,
}
}
}

impl OIDCPlugin {
/// Begin the authentication process by generating a new CSRF state and redirecting the user to the provider's authorization URL.
///
/// This method is the first step in the OIDC authorization code flow. It will:
Expand All @@ -92,9 +94,10 @@ impl OIDCPlugin {
/// * The provider metadata discovery fails
/// * The nonce cannot be stored in the database
/// * The HTTP client cannot be created
pub async fn begin_auth(
pub async fn begin_auth<U: OIDCStorage, S: SessionStorage>(
&self,
pool: &Pool<Sqlite>,
user_storage: &U,
_session_storage: &S,
redirect_uri: String,
) -> Result<AuthFlowBegin, Error> {
let http_client = openidconnect::reqwest::ClientBuilder::new()
Expand Down Expand Up @@ -133,11 +136,9 @@ impl OIDCPlugin {

// Store nonce in database
let nonce_key = Uuid::new_v4().to_string();
sqlx::query("INSERT INTO nonces (id, value, expires_at) VALUES (?, ?, ?)")
.bind(nonce_key.clone())
.bind(nonce.secret().to_string())
.bind(Utc::now() + Duration::hours(1))
.execute(pool)
let expires_at = Utc::now() + Duration::hours(1);
user_storage
.save_nonce(&nonce_key, &nonce.secret().to_string(), &expires_at)
.await
.map_err(|_| Error::InternalServerError)?;

Expand Down Expand Up @@ -165,13 +166,12 @@ impl OIDCPlugin {
///
/// # Returns
/// Returns a [`User`] struct containing the user's information.
pub async fn callback(
pub async fn callback<U: OIDCStorage, S: SessionStorage>(
&self,
pool: &Pool<Sqlite>,
user_storage: &U,
_session_storage: &S,
auth_flow: &AuthFlowCallback,
) -> Result<User, Error> {
let mut tx = pool.begin().await.map_err(|_| Error::InternalServerError)?;

// Create http client for async requests
// TODO: move to builder
let http_client = openidconnect::reqwest::ClientBuilder::new()
Expand Down Expand Up @@ -224,26 +224,14 @@ impl OIDCPlugin {
nonce_key = ?auth_flow.nonce_key,
"Attempting to get nonce from database"
);
let nonce = sqlx::query("SELECT value FROM nonces WHERE id = ? AND expires_at > ? LIMIT 1")
.bind(auth_flow.nonce_key.to_string())
.bind(Utc::now())
.fetch_optional(&mut *tx)
let nonce = user_storage
.get_nonce(&auth_flow.nonce_key)
.await
.map_err(|e| {
tracing::error!(
nonce_key = auth_flow.nonce_key.clone(),
error = ?e,
"Unable to get nonce from database"
);
Error::InternalServerError
})?;
.map_err(|_| Error::InternalServerError)?;

let nonce: String = match nonce {
Some(nonce) => nonce.get("value"),
None => {
tracing::error!("Nonce not found in database");
return Err(Error::InvalidCredentials);
}
let nonce = match nonce {
Some(nonce) => nonce.to_string(),
None => return Err(Error::InvalidCredentials),
};

// Verify id token
Expand All @@ -270,84 +258,47 @@ impl OIDCPlugin {
);

// Check if user exists in database by email
let user = sqlx::query(
r#"SELECT id, name, email, email_verified_at, created_at, updated_at
FROM users
WHERE id = (
SELECT user_id
FROM oidc_accounts
WHERE provider = ? AND subject = ?
LIMIT 1
)"#,
)
.bind(&self.provider)
.bind(&subject)
.fetch_optional(&mut *tx)
.await
.map_err(|_| Error::InternalServerError)?;
let oidc_account = user_storage
.get_oidc_account_by_provider_and_subject(&self.provider, &subject)
.await
.map_err(|_| Error::InternalServerError)?;

let user = match user {
Some(user) => {
tracing::info!(user.email = ?email, "User found in database");
user
}
None => {
tracing::info!("User not found in database, creating user");
// User does not exist, create user
sqlx::query("INSERT INTO users (id, email, name) VALUES (?, ?, ?)")
.bind(Uuid::new_v4().to_string())
.bind(email.as_str())
.bind(name.as_str())
.execute(&mut *tx)
.await
.map_err(|e| {
tracing::error!(
error = ?e,
"Unable to create user in database"
);
Error::InternalServerError
})?;

sqlx::query(
r#"SELECT id, name, email, email_verified_at, created_at, updated_at FROM users WHERE email = ?"#,
)
.bind(email.as_str())
.fetch_one(&mut *tx)
if let Some(oidc_account) = oidc_account {
tracing::info!(
user_id = ?oidc_account.user_id,
"User already exists in database"
);

let user = user_storage
.get_user(&oidc_account.user_id)
.await
.map_err(|e| {
tracing::error!(
error = ?e,
"Unable to get user from database"
);
Error::InternalServerError
})?
}
};
.map_err(|_| Error::InternalServerError)?
.ok_or_else(|| Error::UserNotFound)?;

let user = User {
id: user.get("id"),
name: user.get("name"),
email: user.get("email"),
email_verified_at: user.get("email_verified_at"),
created_at: user.get("created_at"),
updated_at: user.get("updated_at"),
};
tracing::info!(user = ?user, "User created in database");
// The user has already been created, so we can return them immediately
return Ok(user);
}

// Create user if they don't exist
let user = user_storage
.create_user(&NewUser {
id: UserId::new_random(),
email: email.to_string(),
})
.await
.map_err(|_| Error::InternalServerError)?;

// Create link between user and provider
sqlx::query("INSERT INTO oidc_accounts (user_id, provider, subject) VALUES (?, ?, ?)")
.bind(&user.id)
.bind(&self.provider)
.bind(&subject)
.execute(&mut *tx)
user_storage
.create_oidc_account(&OIDCAccount {
user_id: user.id.to_string(),
provider: self.provider.clone(),
subject: subject.clone(),
created_at: Utc::now(),
updated_at: Utc::now(),
})
.await
.map_err(|e| {
tracing::error!(
error = ?e,
"Unable to create link between user and provider"
);
Error::InternalServerError
})?;
.map_err(|_| Error::InternalServerError)?;

tracing::info!(
user_id = ?user.id,
Expand Down
2 changes: 1 addition & 1 deletion torii-core/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

Core functionality for the torii project. All plugins are built on top of this library and are responsible for handling the specific details of each authentication method.

Plugins may use the core functionality to handle common tasks such as database migrations, user management, and session management, but are otherwise free to implement the logic in any way they want.
Plugins may use the core functionality for user management, and session management, but are otherwise free to implement the logic or storage in any way they want.

## Users

Expand Down
4 changes: 1 addition & 3 deletions torii-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
//!
//! This module contains the core functionality for the torii project.
//!
//! It includes the core user and session structs, as well as the plugin system including migrations.
//! It includes the core user and session structs, as well as the plugin system.
//!
//! The core module is designed to be used as a dependency for plugins and is not intended to be used directly by application code.
//!
//! See [`User`] for the core user struct, [`Session`] for the core session struct, and [`Plugin`] for the plugin system.
//!
//! If your plugin requires migrations to the database, you can use the [`Migration`] struct to define the migrations.
//!
pub mod error;
pub mod plugin;
pub mod session;
Expand Down
Loading

0 comments on commit b13ea4c

Please sign in to comment.