From f2952a5019f236fd8a3b3808abe56c6d0a828158 Mon Sep 17 00:00:00 2001 From: Miguel Piedrafita Date: Wed, 13 Dec 2023 03:30:14 +0100 Subject: [PATCH] switch model and relationship away from async_trait --- ensemble/src/connection.rs | 116 +- ensemble/src/lib.rs | 361 +++-- ensemble/src/migrations/migrator.rs | 462 +++--- ensemble/src/migrations/mod.rs | 64 +- ensemble/src/migrations/schema/column.rs | 406 ++--- ensemble/src/migrations/schema/command.rs | 168 +- ensemble/src/migrations/schema/mod.rs | 492 +++--- ensemble/src/query.rs | 1426 ++++++++--------- ensemble/src/relationships/belongs_to.rs | 169 +- ensemble/src/relationships/belongs_to_many.rs | 221 ++- ensemble/src/relationships/has_many.rs | 285 ++-- ensemble/src/relationships/has_one.rs | 165 +- ensemble/src/relationships/mod.rs | 182 +-- ensemble/src/types/datetime.rs | 391 ++--- ensemble/src/types/hashed.rs | 108 +- ensemble/src/types/json.rs | 132 +- ensemble/src/types/uuid.rs | 86 +- ensemble/src/value/de.rs | 1186 +++++++------- ensemble/src/value/mod.rs | 6 +- ensemble/src/value/ser.rs | 962 +++++------ ensemble_derive/src/column/field.rs | 114 +- ensemble_derive/src/column/mod.rs | 352 ++-- ensemble_derive/src/lib.rs | 74 +- ensemble_derive/src/model/default/mod.rs | 74 +- ensemble_derive/src/model/field.rs | 480 +++--- ensemble_derive/src/model/mod.rs | 509 +++--- ensemble_derive/src/model/serde.rs | 526 +++--- examples/user/src/main.rs | 24 +- rustfmt.toml | 8 + test_suite/tests/derive.rs | 2 +- 30 files changed, 4783 insertions(+), 4768 deletions(-) create mode 100644 rustfmt.toml diff --git a/ensemble/src/connection.rs b/ensemble/src/connection.rs index f1fea0e..c9e86a4 100644 --- a/ensemble/src/connection.rs +++ b/ensemble/src/connection.rs @@ -13,12 +13,12 @@ static DB_POOL: OnceLock = OnceLock::new(); #[derive(Debug, thiserror::Error)] pub enum SetupError { - #[error("The provided database URL is invalid.")] - UrlError(#[from] rbatis::Error), + #[error("The provided database URL is invalid.")] + UrlError(#[from] rbatis::Error), - #[cfg(any(feature = "mysql", feature = "postgres"))] - #[error("The database pool has already been initialized.")] - AlreadyInitialized, + #[cfg(any(feature = "mysql", feature = "postgres"))] + #[error("The database pool has already been initialized.")] + AlreadyInitialized, } /// Sets up the database pool. @@ -28,44 +28,44 @@ pub enum SetupError { /// Returns an error if the database pool has already been initialized, or if the provided database URL is invalid. #[cfg(any(feature = "mysql", feature = "postgres"))] pub fn setup(database_url: &str) -> Result<(), SetupError> { - let rb = RBatis::new(); - - #[cfg(feature = "mysql")] - tracing::info!( - database_url = database_url, - "Setting up MySQL database pool..." - ); - #[cfg(feature = "postgres")] - tracing::info!( - database_url = database_url, - "Setting up PostgreSQL database pool..." - ); - - #[cfg(feature = "mysql")] - rb.init_option::( - MysqlDriver {}, - MySqlConnectOptions::from_str(database_url)?, - )?; - #[cfg(feature = "postgres")] - rb.init_option::( - PgDriver {}, - PgConnectOptions::from_str(database_url)?, - )?; - - DB_POOL - .set(rb) - .map_err(|_| SetupError::AlreadyInitialized)?; - - Ok(()) + let rb = RBatis::new(); + + #[cfg(feature = "mysql")] + tracing::info!( + database_url = database_url, + "Setting up MySQL database pool..." + ); + #[cfg(feature = "postgres")] + tracing::info!( + database_url = database_url, + "Setting up PostgreSQL database pool..." 
+ ); + + #[cfg(feature = "mysql")] + rb.init_option::( + MysqlDriver {}, + MySqlConnectOptions::from_str(database_url)?, + )?; + #[cfg(feature = "postgres")] + rb.init_option::( + PgDriver {}, + PgConnectOptions::from_str(database_url)?, + )?; + + DB_POOL + .set(rb) + .map_err(|_| SetupError::AlreadyInitialized)?; + + Ok(()) } #[derive(Debug, thiserror::Error)] pub enum ConnectError { - #[error("The database pool has not been initialized.")] - NotInitialized, + #[error("The database pool has not been initialized.")] + NotInitialized, - #[error("An error occurred while connecting to the database.")] - Connection(#[from] rbatis::Error), + #[error("An error occurred while connecting to the database.")] + Connection(#[from] rbatis::Error), } /// Returns a connection to the database. Used internally by `ensemble` models. @@ -74,37 +74,37 @@ pub enum ConnectError { /// /// Returns an error if the database pool has not been initialized, or if an error occurs while connecting to the database. pub async fn get() -> Result { - match DB_POOL.get() { - None => Err(ConnectError::NotInitialized), - Some(rb) => Ok(rb.get_pool()?.get().await?), - } + match DB_POOL.get() { + None => Err(ConnectError::NotInitialized), + Some(rb) => Ok(rb.get_pool()?.get().await?), + } } #[cfg(any(feature = "mysql", feature = "postgres"))] pub enum Database { - MySQL, - PostgreSQL, + MySQL, + PostgreSQL, } #[cfg(any(feature = "mysql", feature = "postgres"))] impl Database { - pub const fn is_mysql(&self) -> bool { - matches!(self, Self::MySQL) - } + pub const fn is_mysql(&self) -> bool { + matches!(self, Self::MySQL) + } - pub const fn is_postgres(&self) -> bool { - matches!(self, Self::PostgreSQL) - } + pub const fn is_postgres(&self) -> bool { + matches!(self, Self::PostgreSQL) + } } #[cfg(any(feature = "mysql", feature = "postgres"))] pub const fn which_db() -> Database { - #[cfg(all(feature = "mysql", feature = "postgres"))] - panic!("Both the `mysql` and `postgres` features are enabled. Please enable only one of them."); - - if cfg!(feature = "mysql") { - Database::MySQL - } else { - Database::PostgreSQL - } + #[cfg(all(feature = "mysql", feature = "postgres"))] + panic!("Both the `mysql` and `postgres` features are enabled. 
Please enable only one of them."); + + if cfg!(feature = "mysql") { + Database::MySQL + } else { + Database::PostgreSQL + } } diff --git a/ensemble/src/lib.rs b/ensemble/src/lib.rs index c6e8065..1e47364 100644 --- a/ensemble/src/lib.rs +++ b/ensemble/src/lib.rs @@ -19,8 +19,9 @@ pub use serde_json; use query::{Builder, EagerLoad}; use serde::{de::DeserializeOwned, Serialize}; use std::{ - collections::HashMap, - fmt::{Debug, Display}, + collections::HashMap, + fmt::{Debug, Display}, + future::Future, }; mod connection; @@ -35,195 +36,201 @@ pub use ensemble_derive::Model; #[derive(Debug, thiserror::Error)] pub enum Error { - #[error(transparent)] - Connection(#[from] ConnectError), + #[error(transparent)] + Connection(#[from] ConnectError), - #[cfg(feature = "validator")] - #[error(transparent)] - Validation(#[from] validator::ValidationErrors), + #[cfg(feature = "validator")] + #[error(transparent)] + Validation(#[from] validator::ValidationErrors), - #[error("{0}")] - Database(String), + #[error("{0}")] + Database(String), - #[error("The {0} field is required.")] - Required(&'static str), + #[error("The {0} field is required.")] + Required(&'static str), - #[error("Failed to serialize model.")] - Serialization(#[from] rbs::value::ext::Error), + #[error("Failed to serialize model.")] + Serialization(#[from] rbs::value::ext::Error), - #[error("The model could not be found.")] - NotFound, + #[error("The model could not be found.")] + NotFound, - #[error("The unique constraint was violated.")] - UniqueViolation, + #[error("The unique constraint was violated.")] + UniqueViolation, - #[error("The query is invalid.")] - InvalidQuery, + #[error("The query is invalid.")] + InvalidQuery, } -#[async_trait] pub trait Model: DeserializeOwned + Serialize + Sized + Send + Sync + Debug + Default { - /// The type of the primary key for the model. - type PrimaryKey: Display - + DeserializeOwned - + Serialize - + PartialEq - + Default - + Clone - + Send - + Sync; - - /// The name of the model. - const NAME: &'static str; - - /// The name of the table for the model - const TABLE_NAME: &'static str; - - /// The name of the primary key field for the model. - const PRIMARY_KEY: &'static str; - - /// Returns the value of the model's primary key. - fn primary_key(&self) -> &Self::PrimaryKey; - - /// Get all of the models from the database. - /// - /// # Errors - /// - /// Returns an error if the query fails, or if a connection to the database cannot be established. - async fn all() -> Result, Error> { - Self::query().get().await - } - - /// Find a model by its primary key. - /// - /// # Errors - /// - /// Returns an error if the model cannot be found, or if a connection to the database cannot be established. - async fn find(key: Self::PrimaryKey) -> Result; - - /// Insert a new model into the database. - /// - /// # Errors - /// - /// Returns an error if the model cannot be inserted, or if a connection to the database cannot be established. - async fn create(self) -> Result; - - /// Update the model in the database. - /// - /// # Errors - /// - /// Returns an error if the model cannot be updated, or if a connection to the database cannot be established. - async fn save(&mut self) -> Result<(), Error>; - - /// Delete the model from the database. - /// - /// # Errors - /// - /// Returns an error if the model cannot be deleted, or if a connection to the database cannot be established. 
- async fn delete(mut self) -> Result<(), Error> { - let rows_affected = Self::query() - .r#where( - Self::PRIMARY_KEY, - "=", - value::for_db(self.primary_key()).unwrap(), - ) - .delete() - .await?; - - if rows_affected != 1 { - return Err(Error::UniqueViolation); - } - - Ok(()) - } - - /// Reload a fresh model instance from the database. - /// - /// # Errors - /// Returns an error if the model cannot be retrieved, or if a connection to the database cannot be established. - async fn fresh(&self) -> Result; - - /// Begin querying the model. - #[must_use] - fn query() -> Builder { - Builder::new(Self::TABLE_NAME.to_string()) - } - - /// Begin querying a model with eager loading. - fn with>(eager_load: T) -> Builder { - Self::query().with(eager_load) - } - - async fn load + Send>(&mut self, relation: T) -> Result<(), Error> { - for relation in relation.into().list() { - let rows = self.eager_load(&relation, &[&self]).get_rows().await?; - - self.fill_relation(&relation, &rows)?; - } - - Ok(()) - } - - /// Convert the model to a JSON value. - /// - /// # Panics - /// - /// Panics if the model cannot be converted to JSON. Since Ensemble manually implement Serialize, this should never happen. - #[cfg(feature = "json")] - fn json(&self) -> serde_json::Value { - serde_json::to_value(self).unwrap() - } - - /// Eager load a relationship for a set of models. - /// This method is used internally by Ensemble, and should not be called directly. - #[doc(hidden)] - fn eager_load(&self, relation: &str, related: &[&Self]) -> Builder; - - /// Fill a relationship for a set of models. - /// This method is used internally by Ensemble, and should not be called directly. - #[doc(hidden)] - fn fill_relation( - &mut self, - relation: &str, - related: &[HashMap], - ) -> Result<(), Error>; + /// The type of the primary key for the model. + type PrimaryKey: Display + + DeserializeOwned + + Serialize + + PartialEq + + Default + + Clone + + Send + + Sync; + + /// The name of the model. + const NAME: &'static str; + + /// The name of the table for the model + const TABLE_NAME: &'static str; + + /// The name of the primary key field for the model. + const PRIMARY_KEY: &'static str; + + /// Returns the value of the model's primary key. + fn primary_key(&self) -> &Self::PrimaryKey; + + /// Get all of the models from the database. + /// + /// # Errors + /// + /// Returns an error if the query fails, or if a connection to the database cannot be established. + #[must_use] + fn all() -> impl Future, Error>> + Send { + async { Self::query().get().await } + } + + /// Find a model by its primary key. + /// + /// # Errors + /// + /// Returns an error if the model cannot be found, or if a connection to the database cannot be established. + fn find(key: Self::PrimaryKey) -> impl Future> + Send; + + /// Insert a new model into the database. + /// + /// # Errors + /// + /// Returns an error if the model cannot be inserted, or if a connection to the database cannot be established. + fn create(self) -> impl Future> + Send; + + /// Update the model in the database. + /// + /// # Errors + /// + /// Returns an error if the model cannot be updated, or if a connection to the database cannot be established. + fn save(&mut self) -> impl Future> + Send; + + /// Delete the model from the database. + /// + /// # Errors + /// + /// Returns an error if the model cannot be deleted, or if a connection to the database cannot be established. 
+ #[allow(unused_mut)] + fn delete(mut self) -> impl Future> + Send { + async move { + let rows_affected = Self::query() + .r#where( + Self::PRIMARY_KEY, + "=", + value::for_db(self.primary_key()).unwrap(), + ) + .delete() + .await?; + + if rows_affected != 1 { + return Err(Error::UniqueViolation); + } + + Ok(()) + } + } + + /// Reload a fresh model instance from the database. + /// + /// # Errors + /// Returns an error if the model cannot be retrieved, or if a connection to the database cannot be established. + fn fresh(&self) -> impl Future> + Send; + + /// Begin querying the model. + #[must_use] + fn query() -> Builder { + Builder::new(Self::TABLE_NAME.to_string()) + } + + /// Begin querying a model with eager loading. + fn with>(eager_load: T) -> Builder { + Self::query().with(eager_load) + } + + fn load + Send>( + &mut self, + relation: T, + ) -> impl Future> + Send { + async move { + for relation in relation.into().list() { + let rows = self.eager_load(&relation, &[&self]).get_rows().await?; + + self.fill_relation(&relation, &rows)?; + } + + Ok(()) + } + } + + /// Convert the model to a JSON value. + /// + /// # Panics + /// + /// Panics if the model cannot be converted to JSON. Since Ensemble manually implement Serialize, this should never happen. + #[cfg(feature = "json")] + fn json(&self) -> serde_json::Value { + serde_json::to_value(self).unwrap() + } + + /// Eager load a relationship for a set of models. + /// This method is used internally by Ensemble, and should not be called directly. + #[doc(hidden)] + fn eager_load(&self, relation: &str, related: &[&Self]) -> Builder; + + /// Fill a relationship for a set of models. + /// This method is used internally by Ensemble, and should not be called directly. + #[doc(hidden)] + fn fill_relation( + &mut self, + relation: &str, + related: &[HashMap], + ) -> Result<(), Error>; } -#[async_trait] pub trait Collection { - /// Eager load a relationship for a collection of models. - /// - /// # Errors - /// - /// Returns an error if any of the models fail to load, or if a connection to the database cannot be established. - async fn load(&mut self, relation: T) -> Result<(), Error> - where - T: Into + Send + Sync + Clone; - - /// Convert the collection to a JSON value. - /// - /// # Panics - /// - /// Panics if the collection cannot be converted to JSON. Since models manually implement Serialize, this should never happen. - #[cfg(feature = "json")] - fn json(&self) -> serde_json::Value; + /// Eager load a relationship for a collection of models. + /// + /// # Errors + /// + /// Returns an error if any of the models fail to load, or if a connection to the database cannot be established. + fn load(&mut self, relation: T) -> impl Future> + Send + where + T: Into + Send + Sync + Clone; + + /// Convert the collection to a JSON value. + /// + /// # Panics + /// + /// Panics if the collection cannot be converted to JSON. Since models manually implement Serialize, this should never happen. 
+ #[cfg(feature = "json")] + fn json(&self) -> serde_json::Value; } -#[async_trait] impl Collection for &mut Vec { - async fn load(&mut self, relation: U) -> Result<(), Error> - where - U: Into + Send + Sync + Clone, - { - for model in self.iter_mut() { - model.load(relation.clone()).await?; - } - - Ok(()) - } - - #[cfg(feature = "json")] - fn json(&self) -> serde_json::Value { - serde_json::to_value(self).unwrap() - } + async fn load(&mut self, relation: U) -> Result<(), Error> + where + U: Into + Send + Sync + Clone, + { + for model in self.iter_mut() { + model.load(relation.clone()).await?; + } + + Ok(()) + } + + #[cfg(feature = "json")] + fn json(&self) -> serde_json::Value { + serde_json::to_value(self).unwrap() + } } diff --git a/ensemble/src/migrations/migrator.rs b/ensemble/src/migrations/migrator.rs index 4d1ca3c..a73dcd3 100644 --- a/ensemble/src/migrations/migrator.rs +++ b/ensemble/src/migrations/migrator.rs @@ -5,260 +5,260 @@ use tokio::sync::Mutex; use super::{Error, Migration}; use crate::{ - connection::{self, Connection}, - value, + connection::{self, Connection}, + value, }; pub static MIGRATE_CONN: Mutex> = Mutex::const_new(None); /// The migration runner. pub struct Migrator { - batch: u64, - connection: Connection, - state: Vec, - migrations: Vec<(String, Box)>, + batch: u64, + connection: Connection, + state: Vec, + migrations: Vec<(String, Box)>, } impl Migrator { - /// Creates a new [`Migrator`]. - /// - /// # Errors - /// - /// Returns an error if a connection to the database cannot be established, or if the migrations cannot be retrieved. - pub async fn new() -> Result { - let mut conn = connection::get().await?; - let state = Self::get_state(&mut conn).await?; - let batch = state - .iter() - .map(|m| m.batch) - .max() - .unwrap_or_default() - .saturating_add(1); - - tracing::debug!( - batch = batch, - state = ?state, - "Loaded migration state from database." - ); - - Ok(Self { - state, - batch, - connection: conn, - migrations: Vec::new(), - }) - } - - /// Registers a migration. - /// - /// # Panics - /// - /// Panics if a migration with the same name has already been registered. - pub fn register(&mut self, name: String, migration: Box) { - tracing::trace!("Registered migration [{name}]"); - - assert!( - !self.migrations.iter().any(|(n, _)| n == &name), - "A migration with the name [{name}] has already been registered." - ); - - self.migrations.push((name, migration)); - } - - /// Returns a list of migrations that have been run. - #[must_use] - pub fn status(&self) -> Vec { - self.state.clone() - } - - /// Returns a list of migrations that have not been run. - #[must_use] - pub fn pending(&self) -> HashMap<&str, &dyn Migration> { - self.migrations - .iter() - .filter(|(name, _)| !self.state.iter().any(|m| &m.migration == name)) - .map(|(name, migration)| (name.as_str(), migration.as_ref())) - .collect() - } - - /// Runs the migrations. - /// - /// # Errors - /// - /// Returns an error if the migrations fail, or if a connection to the database cannot be established. - pub async fn run(mut self) -> Result<(), Error> { - for (name, migration) in &self.migrations { - if self.state.iter().any(|m| &m.migration == name) { - tracing::trace!("Skipping migration [{name}], since it's already been run."); - continue; - } - - tracing::trace!("Running migration [{name}]."); - - self.connection - .exec("begin", vec![]) - .await - .map_err(|e| Error::Database(e.to_string()))?; - - MIGRATE_CONN - .try_lock() - .map_err(|_| Error::Lock)? 
- .replace(self.connection); - - let migration_result = migration.up().await; - - self.connection = MIGRATE_CONN - .try_lock() - .map_err(|_| Error::Lock)? - .take() - .ok_or(Error::Lock)?; - - if let Err(e) = migration_result { - self.connection - .exec("rollback", vec![]) - .await - .map_err(|e| Error::Database(e.to_string()))?; - - tracing::debug!("Rolled back changes for migration [{name}]."); - - return Err(e); - } - - self.connection - .exec( - "insert into migrations (migration, batch) values (?, ?)", - vec![value::for_db(name)?, value::for_db(self.batch)?], - ) - .await - .map_err(|e| Error::Database(e.to_string()))?; - - self.connection - .exec("commit", vec![]) - .await - .map_err(|e| Error::Database(e.to_string()))?; - - self.state.push(StoredMigration { - id: 0, - batch: self.batch, - migration: name.to_string(), - }); - - tracing::info!("Successfully ran migration [{name}]."); - } - - Ok(()) - } - - /// Rolls back the last `n` batches of migrations. - /// - /// # Errors - /// - /// Returns an error if the migrations fail, or if a connection to the database cannot be established. - pub async fn rollback(mut self, batches: u64) -> Result<(), Error> { - let migrations = self - .state - .into_iter() - .filter(|m| m.batch >= self.batch.saturating_sub(batches)) - .rev(); - - for record in migrations { - let (name, migration) = self - .migrations - .iter() - .find(|(name, _)| name == &record.migration) - .ok_or_else(|| Error::NotFound(record.migration.clone()))?; - - self.connection - .exec("begin", vec![]) - .await - .map_err(|e| Error::Database(e.to_string()))?; - - MIGRATE_CONN - .try_lock() - .map_err(|_| Error::Lock)? - .replace(self.connection); - - let migration_result = migration.down().await; - - self.connection = MIGRATE_CONN - .try_lock() - .map_err(|_| Error::Lock)? - .take() - .ok_or(Error::Lock)?; - - if let Err(e) = migration_result { - self.connection - .exec("rollback", vec![]) - .await - .map_err(|e| Error::Database(e.to_string()))?; - - tracing::debug!("Re-applied changes for migration [{name}]."); - - return Err(e); - } - - self.connection - .exec( - "delete from migrations where id = ?", - vec![Value::U64(record.id)], - ) - .await - .map_err(|e| Error::Database(e.to_string()))?; - - self.connection - .exec("commit", vec![]) - .await - .map_err(|e| Error::Database(e.to_string()))?; - - tracing::info!("Successfully rolled back migration [{name}]."); - } - - Ok(()) - } - - async fn get_state(conn: &mut Connection) -> Result, Error> { - let sql = migrations_table_query(); - - tracing::debug!(sql = sql, "Running CREATE TABLE IF NOT EXISTS SQL query"); - - conn.exec(sql, vec![]) - .await - .map_err(|e| Error::Database(e.to_string()))?; - - Ok(conn - .get_values("select * from migrations", vec![]) - .await - .map_err(|e| Error::Database(e.to_string()))? - .into_iter() - .map(from_value) - .collect::, _>>()?) - } + /// Creates a new [`Migrator`]. + /// + /// # Errors + /// + /// Returns an error if a connection to the database cannot be established, or if the migrations cannot be retrieved. + pub async fn new() -> Result { + let mut conn = connection::get().await?; + let state = Self::get_state(&mut conn).await?; + let batch = state + .iter() + .map(|m| m.batch) + .max() + .unwrap_or_default() + .saturating_add(1); + + tracing::debug!( + batch = batch, + state = ?state, + "Loaded migration state from database." + ); + + Ok(Self { + state, + batch, + connection: conn, + migrations: Vec::new(), + }) + } + + /// Registers a migration. 
+ /// + /// # Panics + /// + /// Panics if a migration with the same name has already been registered. + pub fn register(&mut self, name: String, migration: Box) { + tracing::trace!("Registered migration [{name}]"); + + assert!( + !self.migrations.iter().any(|(n, _)| n == &name), + "A migration with the name [{name}] has already been registered." + ); + + self.migrations.push((name, migration)); + } + + /// Returns a list of migrations that have been run. + #[must_use] + pub fn status(&self) -> Vec { + self.state.clone() + } + + /// Returns a list of migrations that have not been run. + #[must_use] + pub fn pending(&self) -> HashMap<&str, &dyn Migration> { + self.migrations + .iter() + .filter(|(name, _)| !self.state.iter().any(|m| &m.migration == name)) + .map(|(name, migration)| (name.as_str(), migration.as_ref())) + .collect() + } + + /// Runs the migrations. + /// + /// # Errors + /// + /// Returns an error if the migrations fail, or if a connection to the database cannot be established. + pub async fn run(mut self) -> Result<(), Error> { + for (name, migration) in &self.migrations { + if self.state.iter().any(|m| &m.migration == name) { + tracing::trace!("Skipping migration [{name}], since it's already been run."); + continue; + } + + tracing::trace!("Running migration [{name}]."); + + self.connection + .exec("begin", vec![]) + .await + .map_err(|e| Error::Database(e.to_string()))?; + + MIGRATE_CONN + .try_lock() + .map_err(|_| Error::Lock)? + .replace(self.connection); + + let migration_result = migration.up().await; + + self.connection = MIGRATE_CONN + .try_lock() + .map_err(|_| Error::Lock)? + .take() + .ok_or(Error::Lock)?; + + if let Err(e) = migration_result { + self.connection + .exec("rollback", vec![]) + .await + .map_err(|e| Error::Database(e.to_string()))?; + + tracing::debug!("Rolled back changes for migration [{name}]."); + + return Err(e); + } + + self.connection + .exec( + "insert into migrations (migration, batch) values (?, ?)", + vec![value::for_db(name)?, value::for_db(self.batch)?], + ) + .await + .map_err(|e| Error::Database(e.to_string()))?; + + self.connection + .exec("commit", vec![]) + .await + .map_err(|e| Error::Database(e.to_string()))?; + + self.state.push(StoredMigration { + id: 0, + batch: self.batch, + migration: name.to_string(), + }); + + tracing::info!("Successfully ran migration [{name}]."); + } + + Ok(()) + } + + /// Rolls back the last `n` batches of migrations. + /// + /// # Errors + /// + /// Returns an error if the migrations fail, or if a connection to the database cannot be established. + pub async fn rollback(mut self, batches: u64) -> Result<(), Error> { + let migrations = self + .state + .into_iter() + .filter(|m| m.batch >= self.batch.saturating_sub(batches)) + .rev(); + + for record in migrations { + let (name, migration) = self + .migrations + .iter() + .find(|(name, _)| name == &record.migration) + .ok_or_else(|| Error::NotFound(record.migration.clone()))?; + + self.connection + .exec("begin", vec![]) + .await + .map_err(|e| Error::Database(e.to_string()))?; + + MIGRATE_CONN + .try_lock() + .map_err(|_| Error::Lock)? + .replace(self.connection); + + let migration_result = migration.down().await; + + self.connection = MIGRATE_CONN + .try_lock() + .map_err(|_| Error::Lock)? 
+ .take() + .ok_or(Error::Lock)?; + + if let Err(e) = migration_result { + self.connection + .exec("rollback", vec![]) + .await + .map_err(|e| Error::Database(e.to_string()))?; + + tracing::debug!("Re-applied changes for migration [{name}]."); + + return Err(e); + } + + self.connection + .exec( + "delete from migrations where id = ?", + vec![Value::U64(record.id)], + ) + .await + .map_err(|e| Error::Database(e.to_string()))?; + + self.connection + .exec("commit", vec![]) + .await + .map_err(|e| Error::Database(e.to_string()))?; + + tracing::info!("Successfully rolled back migration [{name}]."); + } + + Ok(()) + } + + async fn get_state(conn: &mut Connection) -> Result, Error> { + let sql = migrations_table_query(); + + tracing::debug!(sql = sql, "Running CREATE TABLE IF NOT EXISTS SQL query"); + + conn.exec(sql, vec![]) + .await + .map_err(|e| Error::Database(e.to_string()))?; + + Ok(conn + .get_values("select * from migrations", vec![]) + .await + .map_err(|e| Error::Database(e.to_string()))? + .into_iter() + .map(from_value) + .collect::, _>>()?) + } } #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] pub struct StoredMigration { - pub id: u64, - pub batch: u64, - pub migration: String, + pub id: u64, + pub batch: u64, + pub migration: String, } const fn migrations_table_query() -> &'static str { - use crate::connection::Database; + use crate::connection::Database; - match connection::which_db() { - Database::MySQL => { - "create table if not exists migrations ( + match connection::which_db() { + Database::MySQL => { + "create table if not exists migrations ( id int unsigned not null auto_increment primary key, migration varchar(255) not null unique, batch int not null )" - } - Database::PostgreSQL => { - "create table if not exists migrations ( + }, + Database::PostgreSQL => { + "create table if not exists migrations ( id serial primary key, migration varchar(255) not null unique, batch int not null )" - } - } + }, + } } diff --git a/ensemble/src/migrations/mod.rs b/ensemble/src/migrations/mod.rs index c504f55..cdd7215 100644 --- a/ensemble/src/migrations/mod.rs +++ b/ensemble/src/migrations/mod.rs @@ -1,7 +1,7 @@ use async_trait::async_trait; -use std::fmt::Debug; use crate::connection::ConnectError; +use std::fmt::Debug; #[cfg(any(feature = "mysql", feature = "postgres"))] pub use {migrator::Migrator, schema::Schema}; @@ -16,29 +16,29 @@ pub mod schema; /// Errors that can occur while running migrations. #[derive(Debug, thiserror::Error)] pub enum Error { - /// An error occurred while connecting to the database. - #[error("Failed to connect to database.")] - Connection(#[from] ConnectError), + /// An error occurred while connecting to the database. + #[error("Failed to connect to database.")] + Connection(#[from] ConnectError), - /// An error occurred while running a migration. - #[error("{0}")] - Database(String), + /// An error occurred while running a migration. + #[error("{0}")] + Database(String), - /// The migration could not be found. - #[error("Could not locate the {0} migration.")] - NotFound(String), + /// The migration could not be found. + #[error("Could not locate the {0} migration.")] + NotFound(String), - /// There was an internal error with the migrations system. - #[error("Failed to receive column in schema.")] - SendColumn, + /// There was an internal error with the migrations system. + #[error("Failed to receive column in schema.")] + SendColumn, - /// One of the migrations locked the connection. 
- #[error("Failed to obtain connection")] - Lock, + /// One of the migrations locked the connection. + #[error("Failed to obtain connection")] + Lock, - /// The migration data could not be decoded. - #[error("Failed to deserialize migration data.")] - Decode(#[from] rbs::Error), + /// The migration data could not be decoded. + #[error("Failed to deserialize migration data.")] + Decode(#[from] rbs::Error), } /// Accepts a list of structs that implement the [`Migration`] trait, and runs them. @@ -57,20 +57,20 @@ macro_rules! migrate { }; } -#[async_trait] /// A trait for defining migrations. +#[async_trait] pub trait Migration: Sync + Send { - /// Runs the migration. - /// - /// # Errors - /// - /// Returns an error if the migration fails, or if a connection to the database cannot be established. - async fn up(&self) -> Result<(), Error>; + /// Runs the migration. + /// + /// # Errors + /// + /// Returns an error if the migration fails, or if a connection to the database cannot be established. + async fn up(&self) -> Result<(), Error>; - /// Reverts the migration. - /// - /// # Errors - /// - /// Returns an error if the migration fails, or if a connection to the database cannot be established. - async fn down(&self) -> Result<(), Error>; + /// Reverts the migration. + /// + /// # Errors + /// + /// Returns an error if the migration fails, or if a connection to the database cannot be established. + async fn down(&self) -> Result<(), Error>; } diff --git a/ensemble/src/migrations/schema/column.rs b/ensemble/src/migrations/schema/column.rs index ee60626..f1c4efd 100644 --- a/ensemble/src/migrations/schema/column.rs +++ b/ensemble/src/migrations/schema/column.rs @@ -5,228 +5,228 @@ use std::{fmt::Display, sync::mpsc}; use super::Schemable; use crate::{ - connection::{self, Database}, - value, + connection::{self, Database}, + value, }; #[derive(Debug, Clone, PartialEq, Eq)] pub enum Type { - Json, - Uuid, - Text, - Boolean, - Timestamp, - BigInteger, - String(u32), - Enum(String, Vec), + Json, + Uuid, + Text, + Boolean, + Timestamp, + BigInteger, + String(u32), + Enum(String, Vec), } impl Display for Type { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Json => f.write_str("json"), - Self::Uuid => f.write_str("uuid"), - Self::Text => f.write_str("text"), - Self::Boolean => f.write_str("boolean"), - Self::BigInteger => f.write_str("bigint"), - Self::Timestamp => f.write_str("timestamp"), - Self::String(size) => { - let value = format!("varchar({size})"); - f.write_str(&value) - } - Self::Enum(name, values) => { - let value = match connection::which_db() { - Database::MySQL => format!( - "enum({})", - values - .iter() - .map(|v| format!("'{}'", v.replace('\'', "\\'"))) - .join(", ") - ), - Database::PostgreSQL => format!( - "varchar(255) check({name} in ({}))", - values - .iter() - .map(|v| format!("'{}'", v.replace('\'', "\\'"))) - .join(", ") - ), - }; - f.write_str(&value) - } - } - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Json => f.write_str("json"), + Self::Uuid => f.write_str("uuid"), + Self::Text => f.write_str("text"), + Self::Boolean => f.write_str("boolean"), + Self::BigInteger => f.write_str("bigint"), + Self::Timestamp => f.write_str("timestamp"), + Self::String(size) => { + let value = format!("varchar({size})"); + f.write_str(&value) + }, + Self::Enum(name, values) => { + let value = match connection::which_db() { + Database::MySQL => format!( + "enum({})", + values + .iter() + 
.map(|v| format!("'{}'", v.replace('\'', "\\'"))) + .join(", ") + ), + Database::PostgreSQL => format!( + "varchar(255) check({name} in ({}))", + values + .iter() + .map(|v| format!("'{}'", v.replace('\'', "\\'"))) + .join(", ") + ), + }; + f.write_str(&value) + }, + } + } } /// A column in a table. #[derive(Debug, Clone, Column)] #[allow(clippy::struct_excessive_bools, dead_code)] pub struct Column { - /// The name of the column. - #[builder(init)] - name: String, - /// The type of the column. - #[builder(init)] - r#type: Type, - #[cfg(feature = "mysql")] - /// Place the column "after" another column - after: Option, - /// Set INTEGER columns as auto-increment (primary key) - #[builder(rename = "increments", type = Type::BigInteger, needs = [primary, unique])] - auto_increment: bool, - /// Automatically generate UUIDs for the column - #[builder(type = Type::Uuid)] - uuid: bool, - /// Add a comment to the column - comment: Option, - /// Specify a "default" value for the column - #[builder(skip)] - default: Option, - /// Add an index - index: Option, - /// Allow NULL values to be inserted into the column - nullable: bool, - /// Add a primary index - primary: bool, - /// Add a unique index - unique: bool, - /// Specify a "collation" for the column - collation: Option, - /// Set the INTEGER column as UNSIGNED - #[cfg(feature = "mysql")] - #[builder(type = Type::BigInteger)] - unsigned: bool, - /// Set the TIMESTAMP column to use CURRENT_TIMESTAMP as default value - #[builder(type = Type::Timestamp)] - use_current: bool, - /// Set the TIMESTAMP column to use CURRENT_TIMESTAMP when updating - #[cfg(feature = "mysql")] - #[builder(type = Type::Timestamp)] - use_current_on_update: bool, - - /// The channel to send the column to when it is dropped. - #[builder(init)] - tx: Option>, + /// The name of the column. + #[builder(init)] + name: String, + /// The type of the column. + #[builder(init)] + r#type: Type, + #[cfg(feature = "mysql")] + /// Place the column "after" another column + after: Option, + /// Set INTEGER columns as auto-increment (primary key) + #[builder(rename = "increments", type = Type::BigInteger, needs = [primary, unique])] + auto_increment: bool, + /// Automatically generate UUIDs for the column + #[builder(type = Type::Uuid)] + uuid: bool, + /// Add a comment to the column + comment: Option, + /// Specify a "default" value for the column + #[builder(skip)] + default: Option, + /// Add an index + index: Option, + /// Allow NULL values to be inserted into the column + nullable: bool, + /// Add a primary index + primary: bool, + /// Add a unique index + unique: bool, + /// Specify a "collation" for the column + collation: Option, + /// Set the INTEGER column as UNSIGNED + #[cfg(feature = "mysql")] + #[builder(type = Type::BigInteger)] + unsigned: bool, + /// Set the TIMESTAMP column to use CURRENT_TIMESTAMP as default value + #[builder(type = Type::Timestamp)] + use_current: bool, + /// Set the TIMESTAMP column to use CURRENT_TIMESTAMP when updating + #[cfg(feature = "mysql")] + #[builder(type = Type::Timestamp)] + use_current_on_update: bool, + + /// The channel to send the column to when it is dropped. 
+ #[builder(init)] + tx: Option>, } impl Column { - /// Specify a "default" value for the column - pub fn default(mut self, default: T) -> Self { - let value = if self.r#type == Type::Json { - Value::String(serde_json::to_string(&default).unwrap()) - } else { - value::for_db(default).unwrap() - }; - - if let Type::Enum(_, values) = &self.r#type { - assert!( - values.contains(&value.as_str().unwrap_or_default().to_string()), - "default value must be one of the enum values" - ); - } - - self.default = Some(value); - - self - } - - pub(crate) fn to_sql(&self) -> String { - let db_type = if connection::which_db().is_postgres() - && self.r#type == Type::BigInteger - && self.auto_increment - { - "bigserial".to_string() - } else { - self.r#type.to_string() - }; - - let mut sql = format!("{} {db_type}", self.name); - - #[cfg(feature = "mysql")] - if self.unsigned { - sql.push_str(" unsigned"); - } - - if self.nullable { - sql.push_str(" NULL"); - } else { - sql.push_str(" NOT NULL"); - } - - #[cfg(feature = "mysql")] - if let Some(after) = &self.after { - sql.push_str(&format!(" AFTER {after}")); - } - - if let Some(comment) = &self.comment { - sql.push_str(&format!(" COMMENT {comment}")); - } - - if let Some(collation) = &self.collation { - sql.push_str(&format!(" COLLATE {collation}")); - } - - if let Some(default) = &self.default { - if self.r#type == Type::Json { - sql.push_str(&format!(" DEFAULT '{}'", default.as_str().unwrap())); - } else { - sql.push_str(&format!(" DEFAULT {default}")); - } - } - - if self.uuid { - assert!( - self.default.is_none(), - "cannot set a default valud and automatically generate UUIDs at the same time" - ); - - #[cfg(feature = "mysql")] - sql.push_str(" DEFAULT (UUID())"); - - #[cfg(feature = "postgres")] - sql.push_str(" DEFAULT (gen_random_uuid())"); - } - - if self.auto_increment { - #[cfg(feature = "mysql")] - sql.push_str(" AUTO_INCREMENT"); - } - - if let Some(index) = &self.index { - sql.push_str(&format!(" INDEX {index}")); - } - - if self.primary { - sql.push_str(" PRIMARY KEY"); - } - - if self.unique { - sql.push_str(" UNIQUE"); - } - - if self.use_current { - #[cfg(feature = "mysql")] - sql.push_str(" DEFAULT CURRENT_TIMESTAMP"); - - #[cfg(feature = "postgres")] - sql.push_str(" DEFAULT now()"); - } - - #[cfg(feature = "mysql")] - if self.use_current_on_update { - sql.push_str(" ON UPDATE CURRENT_TIMESTAMP"); - } - - sql - } + /// Specify a "default" value for the column + pub fn default(mut self, default: T) -> Self { + let value = if self.r#type == Type::Json { + Value::String(serde_json::to_string(&default).unwrap()) + } else { + value::for_db(default).unwrap() + }; + + if let Type::Enum(_, values) = &self.r#type { + assert!( + values.contains(&value.as_str().unwrap_or_default().to_string()), + "default value must be one of the enum values" + ); + } + + self.default = Some(value); + + self + } + + pub(crate) fn to_sql(&self) -> String { + let db_type = if connection::which_db().is_postgres() + && self.r#type == Type::BigInteger + && self.auto_increment + { + "bigserial".to_string() + } else { + self.r#type.to_string() + }; + + let mut sql = format!("{} {db_type}", self.name); + + #[cfg(feature = "mysql")] + if self.unsigned { + sql.push_str(" unsigned"); + } + + if self.nullable { + sql.push_str(" NULL"); + } else { + sql.push_str(" NOT NULL"); + } + + #[cfg(feature = "mysql")] + if let Some(after) = &self.after { + sql.push_str(&format!(" AFTER {after}")); + } + + if let Some(comment) = &self.comment { + sql.push_str(&format!(" COMMENT 
{comment}")); + } + + if let Some(collation) = &self.collation { + sql.push_str(&format!(" COLLATE {collation}")); + } + + if let Some(default) = &self.default { + if self.r#type == Type::Json { + sql.push_str(&format!(" DEFAULT '{}'", default.as_str().unwrap())); + } else { + sql.push_str(&format!(" DEFAULT {default}")); + } + } + + if self.uuid { + assert!( + self.default.is_none(), + "cannot set a default valud and automatically generate UUIDs at the same time" + ); + + #[cfg(feature = "mysql")] + sql.push_str(" DEFAULT (UUID())"); + + #[cfg(feature = "postgres")] + sql.push_str(" DEFAULT (gen_random_uuid())"); + } + + if self.auto_increment { + #[cfg(feature = "mysql")] + sql.push_str(" AUTO_INCREMENT"); + } + + if let Some(index) = &self.index { + sql.push_str(&format!(" INDEX {index}")); + } + + if self.primary { + sql.push_str(" PRIMARY KEY"); + } + + if self.unique { + sql.push_str(" UNIQUE"); + } + + if self.use_current { + #[cfg(feature = "mysql")] + sql.push_str(" DEFAULT CURRENT_TIMESTAMP"); + + #[cfg(feature = "postgres")] + sql.push_str(" DEFAULT now()"); + } + + #[cfg(feature = "mysql")] + if self.use_current_on_update { + sql.push_str(" ON UPDATE CURRENT_TIMESTAMP"); + } + + sql + } } // Incredibly cursed impl that basically recreates PHP's `__destruct` magic method. // If you're mad about this, go use sqlx or something idk. impl Drop for Column { - fn drop(&mut self) { - if let Some(tx) = self.tx.take() { - tx.send(Schemable::Column(self.clone())).unwrap(); - drop(tx); - } - } + fn drop(&mut self) { + if let Some(tx) = self.tx.take() { + tx.send(Schemable::Column(self.clone())).unwrap(); + drop(tx); + } + } } diff --git a/ensemble/src/migrations/schema/command.rs b/ensemble/src/migrations/schema/command.rs index 567e020..58c4e7d 100644 --- a/ensemble/src/migrations/schema/command.rs +++ b/ensemble/src/migrations/schema/command.rs @@ -8,50 +8,50 @@ use super::Schemable; #[derive(Debug)] pub struct Command { - pub(crate) inline_sql: String, - pub(crate) post_sql: Option, + pub(crate) inline_sql: String, + pub(crate) post_sql: Option, } /// A foreign key constraint. #[derive(Debug, Clone, Column)] #[allow(dead_code)] pub struct ForeignIndex { - #[builder(init)] - column: String, - #[builder(init)] - origin_table: String, - /// The name of the foreign index. - name: Option, - /// The name of the column in the foreign table. - #[builder(rename = "references")] - foreign_column: Option, - /// The name of the foreign table. - #[builder(rename = "on")] - table: String, - /// The action to take when the foreign row is deleted. - #[builder(into)] - on_delete: Option, - /// The action to take when the foreign row is updated. - #[builder(into)] - on_update: Option, - - #[builder(init)] - tx: Option>, + #[builder(init)] + column: String, + #[builder(init)] + origin_table: String, + /// The name of the foreign index. + name: Option, + /// The name of the column in the foreign table. + #[builder(rename = "references")] + foreign_column: Option, + /// The name of the foreign table. + #[builder(rename = "on")] + table: String, + /// The action to take when the foreign row is deleted. + #[builder(into)] + on_delete: Option, + /// The action to take when the foreign row is updated. 
+ #[builder(into)] + on_update: Option, + + #[builder(init)] + tx: Option>, } impl ForeignIndex { - fn to_sql(&self) -> (String, Option) { - let foreign_column = &self - .foreign_column - .as_ref() - .expect("failed to build index: foreign column must be specified"); - - let index_name = self.name.as_ref().map_or_else( - || format!("{}_{}_foreign", self.origin_table, self.column), - ToString::to_string, - ); - - let mut sql = match connection::which_db() { + fn to_sql(&self) -> (String, Option) { + let foreign_column = &self + .foreign_column + .as_ref() + .expect("failed to build index: foreign column must be specified"); + + let index_name = self.name.as_ref().map_or_else( + || format!("{}_{}_foreign", self.origin_table, self.column), + ToString::to_string, + ); + + let mut sql = match connection::which_db() { Database::MySQL => format!( "KEY {index_name} ({}), CONSTRAINT {index_name} FOREIGN KEY ({}) REFERENCES {}({foreign_column})", self.column, self.column, self.table, ), @@ -61,69 +61,69 @@ impl ForeignIndex { ) }; - if let Some(on_delete) = &self.on_delete { - sql.push_str(&format!(" ON DELETE {on_delete}")); - } - - if let Some(on_update) = &self.on_update { - sql.push_str(&format!(" ON UPDATE {on_update}")); - } - - match connection::which_db() { - Database::MySQL => (sql, None), - Database::PostgreSQL => ( - sql, - Some(format!( - "CREATE INDEX {index_name} ON {}({});", - self.origin_table, self.column - )), - ), - } - } + if let Some(on_delete) = &self.on_delete { + sql.push_str(&format!(" ON DELETE {on_delete}")); + } + + if let Some(on_update) = &self.on_update { + sql.push_str(&format!(" ON UPDATE {on_update}")); + } + + match connection::which_db() { + Database::MySQL => (sql, None), + Database::PostgreSQL => ( + sql, + Some(format!( + "CREATE INDEX {index_name} ON {}({});", + self.origin_table, self.column + )), + ), + } + } } // Incredibly cursed impl that basically recreates PHP's `__destruct` magic method. // If you're mad about this, go use sqlx or something idk. 
impl Drop for ForeignIndex { - fn drop(&mut self) { - if let Some(tx) = self.tx.take() { - let (inline_sql, post_sql) = self.to_sql(); - - tx.send(Schemable::Command(Command { - inline_sql, - post_sql, - })) - .unwrap(); - drop(tx); - } - } + fn drop(&mut self) { + if let Some(tx) = self.tx.take() { + let (inline_sql, post_sql) = self.to_sql(); + + tx.send(Schemable::Command(Command { + inline_sql, + post_sql, + })) + .unwrap(); + drop(tx); + } + } } #[derive(Debug, Clone, Copy)] pub enum OnAction { - Restrict, - Cascade, - SetNull, + Restrict, + Cascade, + SetNull, } impl Display for OnAction { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Cascade => write!(f, "CASCADE"), - Self::SetNull => write!(f, "SET NULL"), - Self::Restrict => write!(f, "RESTRICT"), - } - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Cascade => write!(f, "CASCADE"), + Self::SetNull => write!(f, "SET NULL"), + Self::Restrict => write!(f, "RESTRICT"), + } + } } #[allow(clippy::fallible_impl_from)] impl From<&str> for OnAction { - fn from(s: &str) -> Self { - match s.to_uppercase().as_str() { - "CASCADE" => Self::Cascade, - "SET NULL" => Self::SetNull, - "RESTRICT" => Self::Restrict, - _ => panic!("invalid action"), - } - } + fn from(s: &str) -> Self { + match s.to_uppercase().as_str() { + "CASCADE" => Self::Cascade, + "SET NULL" => Self::SetNull, + "RESTRICT" => Self::Restrict, + _ => panic!("invalid action"), + } + } } diff --git a/ensemble/src/migrations/schema/mod.rs b/ensemble/src/migrations/schema/mod.rs index 461bca7..4cf0bc1 100644 --- a/ensemble/src/migrations/schema/mod.rs +++ b/ensemble/src/migrations/schema/mod.rs @@ -4,8 +4,8 @@ use rbs::Value; use std::{any::type_name, sync::mpsc}; use self::{ - column::{Column, Type}, - command::{Command, ForeignIndex}, + column::{Column, Type}, + command::{Command, ForeignIndex}, }; use super::{migrator::MIGRATE_CONN, Error}; use crate::{connection, Model}; @@ -16,258 +16,258 @@ mod command; pub struct Schema {} pub enum Schemable { - Column(Column), - Command(Command), + Column(Column), + Command(Command), } impl Schema { - /// Creates a new table. - /// - /// # Errors - /// - /// Returns an error if the table cannot be created, or if a connection to the database cannot be established. - #[allow(clippy::unused_async)] - pub async fn create(table_name: &str, callback: F) -> Result<(), Error> - where - F: FnOnce(&mut Table) + Send, - { - let (columns, commands) = Self::get_schema(table_name.to_string(), callback)?; - let mut conn_lock = MIGRATE_CONN.try_lock().map_err(|_| Error::Lock)?; - let mut conn = conn_lock.take().ok_or(Error::Lock)?; - - let sql = format!( - "CREATE TABLE {} ({}) {}; {}", - table_name, - columns - .iter() - .map(Column::to_sql) - .chain(commands.iter().map(|cmd| cmd.inline_sql.clone())) - .join(", "), - if connection::which_db().is_mysql() { - "ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci" - } else { - "" - }, - commands - .iter() - .filter_map(|cmd| cmd.post_sql.clone()) - .join("\n") - ); - - tracing::debug!(sql = sql.as_str(), "Running CREATE TABLE SQL query"); - let query_result = conn.exec(&sql, vec![]).await; - - conn_lock.replace(conn); - drop(conn_lock); - - match query_result { - Ok(_) => Ok(()), - Err(e) => Err(Error::Database(e.to_string())), - } - } - - /// Drops a table. - /// - /// # Errors - /// - /// Returns an error if the table cannot be dropped, or if a connection to the database cannot be established. 
- pub async fn drop(table_name: &str) -> Result<(), Error> { - let mut conn_lock = MIGRATE_CONN.try_lock().map_err(|_| Error::Lock)?; - let mut conn = conn_lock.take().ok_or(Error::Lock)?; - - let (sql, bindings) = ("DROP TABLE ?", vec![Value::String(table_name.to_string())]); - - tracing::debug!(sql = sql, bindings = ?bindings, "Running DROP TABLE SQL query"); - let query_result = conn.exec(sql, bindings).await; - - conn_lock.replace(conn); - drop(conn_lock); - - match query_result { - Ok(_) => Ok(()), - Err(e) => Err(Error::Database(e.to_string())), - } - } - - fn get_schema(table_name: String, callback: F) -> Result<(Vec, Vec), Error> - where - F: FnOnce(&mut Table), - { - let (tx, rx) = mpsc::channel(); - let mut table = Table { - name: table_name, - sender: Some(tx), - }; - - let ret = std::thread::spawn(move || { - let mut schema = vec![]; - - while let Ok(part) = rx.recv() { - schema.push(part); - } - - schema - }); - - callback(&mut table); - drop(table.sender.take()); - - let schema = ret.join().map_err(|_| Error::SendColumn)?; - - Ok(schema - .into_iter() - .map(|part| match part { - Schemable::Column(col) => Either::Left(col), - Schemable::Command(cmd) => Either::Right(cmd), - }) - .partition_map(|part| part)) - } + /// Creates a new table. + /// + /// # Errors + /// + /// Returns an error if the table cannot be created, or if a connection to the database cannot be established. + #[allow(clippy::unused_async)] + pub async fn create(table_name: &str, callback: F) -> Result<(), Error> + where + F: FnOnce(&mut Table) + Send, + { + let (columns, commands) = Self::get_schema(table_name.to_string(), callback)?; + let mut conn_lock = MIGRATE_CONN.try_lock().map_err(|_| Error::Lock)?; + let mut conn = conn_lock.take().ok_or(Error::Lock)?; + + let sql = format!( + "CREATE TABLE {} ({}) {}; {}", + table_name, + columns + .iter() + .map(Column::to_sql) + .chain(commands.iter().map(|cmd| cmd.inline_sql.clone())) + .join(", "), + if connection::which_db().is_mysql() { + "ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci" + } else { + "" + }, + commands + .iter() + .filter_map(|cmd| cmd.post_sql.clone()) + .join("\n") + ); + + tracing::debug!(sql = sql.as_str(), "Running CREATE TABLE SQL query"); + let query_result = conn.exec(&sql, vec![]).await; + + conn_lock.replace(conn); + drop(conn_lock); + + match query_result { + Ok(_) => Ok(()), + Err(e) => Err(Error::Database(e.to_string())), + } + } + + /// Drops a table. + /// + /// # Errors + /// + /// Returns an error if the table cannot be dropped, or if a connection to the database cannot be established. 
+ pub async fn drop(table_name: &str) -> Result<(), Error> { + let mut conn_lock = MIGRATE_CONN.try_lock().map_err(|_| Error::Lock)?; + let mut conn = conn_lock.take().ok_or(Error::Lock)?; + + let (sql, bindings) = ("DROP TABLE ?", vec![Value::String(table_name.to_string())]); + + tracing::debug!(sql = sql, bindings = ?bindings, "Running DROP TABLE SQL query"); + let query_result = conn.exec(sql, bindings).await; + + conn_lock.replace(conn); + drop(conn_lock); + + match query_result { + Ok(_) => Ok(()), + Err(e) => Err(Error::Database(e.to_string())), + } + } + + fn get_schema(table_name: String, callback: F) -> Result<(Vec, Vec), Error> + where + F: FnOnce(&mut Table), + { + let (tx, rx) = mpsc::channel(); + let mut table = Table { + name: table_name, + sender: Some(tx), + }; + + let ret = std::thread::spawn(move || { + let mut schema = vec![]; + + while let Ok(part) = rx.recv() { + schema.push(part); + } + + schema + }); + + callback(&mut table); + drop(table.sender.take()); + + let schema = ret.join().map_err(|_| Error::SendColumn)?; + + Ok(schema + .into_iter() + .map(|part| match part { + Schemable::Column(col) => Either::Left(col), + Schemable::Command(cmd) => Either::Right(cmd), + }) + .partition_map(|part| part)) + } } #[derive(Debug)] pub struct Table { - name: String, - sender: Option>, + name: String, + sender: Option>, } impl Table { - /// Creates a primary key incrementing integer column called `id`. - pub fn id(&mut self) -> Column { - let column = Column::new("id".to_string(), Type::BigInteger, self.sender.clone()) - .primary(true) - .increments(true); - - #[cfg(feature = "mysql")] - { - column.unsigned(true) - } - - #[cfg(not(feature = "mysql"))] - { - column - } - } - - /// Create a primary key UUID column called `id`. - pub fn uuid(&mut self) -> Column { - Column::new("id".to_string(), Type::Uuid, self.sender.clone()) - .uuid(true) - .primary(true) - } - - /// Create a new big integer (8-byte) column on the table. - pub fn integer(&mut self, name: &str) -> Column { - Column::new(name.to_string(), Type::BigInteger, self.sender.clone()) - } - - /// Create a new json column on the table. - pub fn json(&mut self, name: &str) -> Column { - Column::new(name.to_string(), Type::Json, self.sender.clone()) - } - - /// Create a new string column on the table. - pub fn string(&mut self, name: &str) -> Column { - Column::new(name.to_string(), Type::String(255), self.sender.clone()) - } - - /// Create a new boolean column on the table. - pub fn boolean(&mut self, name: &str) -> Column { - Column::new(name.to_string(), Type::Boolean, self.sender.clone()) - } - - /// Create a new text column on the table. - pub fn text(&mut self, name: &str) -> Column { - Column::new(name.to_string(), Type::Text, self.sender.clone()) - } - - /// Create a new timestamp column on the table. - pub fn timestamp(&mut self, name: &str) -> Column { - Column::new(name.to_string(), Type::Timestamp, self.sender.clone()) - } - - /// Specify a foreign key for the table. - pub fn foreign(&mut self, column: &str) -> ForeignIndex { - ForeignIndex::new(column.to_string(), self.name.clone(), self.sender.clone()) - } - - #[cfg(feature = "mysql")] - /// Create a new enum column on the table. - pub fn r#enum(&mut self, name: &str, values: &[&str]) -> Column { - Column::new( - name.to_string(), - Type::Enum( - name.to_string(), - values.iter().map(ToString::to_string).collect(), - ), - self.sender.clone(), - ) - } - - /// Create a foreign ID column for the given model. 
- pub fn foreign_id_for(&mut self) -> ForeignIndex { - let column = format!("{}_{}", M::NAME, M::PRIMARY_KEY).to_snake_case(); - - if ["u64", "u32", "u16", "u8", "usize"].contains(&type_name::()) { - #[allow(unused_variables)] - let column = Column::new(column.clone(), Type::BigInteger, self.sender.clone()); - - #[cfg(feature = "mysql")] - { - column.unsigned(true); - }; - } else { - Column::new(column.clone(), Type::String(255), self.sender.clone()); - } - - let index = ForeignIndex::new(column, self.name.clone(), self.sender.clone()); - index.on(M::TABLE_NAME).references(M::PRIMARY_KEY) - } - - /// Create a foreign ID column for the given model. - pub fn foreign_id(&mut self, name: &str) -> ForeignIndex { - #[allow(unused_variables)] - let column = Column::new(name.to_string(), Type::BigInteger, self.sender.clone()); - - #[cfg(feature = "mysql")] - { - column.unsigned(true); - }; - - let index = ForeignIndex::new(name.to_string(), self.name.clone(), self.sender.clone()); - - // if the column name is of the form `resource_id`, we extract and set the table name and foreign column name - if let Some((resource, column)) = name.split_once('_') { - index.on(&resource.to_plural()).references(column) - } else { - index - } - } - - /// Create a foreign UUID column for the given model. - pub fn foreign_uuid(&mut self, name: &str) -> ForeignIndex { - Column::new(name.to_string(), Type::Uuid, self.sender.clone()).uuid(true); - let index = ForeignIndex::new(name.to_string(), self.name.clone(), self.sender.clone()); - - // if the column name is of the form `resource_id`, we extract and set the table name and foreign column name - if let Some((resource, column)) = name.split_once('_') { - index.on(&resource.to_plural()).references(column) - } else { - index - } - } - - /// Add nullable creation and update timestamps to the table. - pub fn timestamps(&mut self) { - self.timestamp("created_at") - .nullable(true) - .use_current(true); - - #[allow(unused_variables)] - let updated_at = self.timestamp("updated_at").nullable(true); - - #[cfg(feature = "mysql")] - { - updated_at.use_current_on_update(true); - } - } + /// Creates a primary key incrementing integer column called `id`. + pub fn id(&mut self) -> Column { + let column = Column::new("id".to_string(), Type::BigInteger, self.sender.clone()) + .primary(true) + .increments(true); + + #[cfg(feature = "mysql")] + { + column.unsigned(true) + } + + #[cfg(not(feature = "mysql"))] + { + column + } + } + + /// Create a primary key UUID column called `id`. + pub fn uuid(&mut self) -> Column { + Column::new("id".to_string(), Type::Uuid, self.sender.clone()) + .uuid(true) + .primary(true) + } + + /// Create a new big integer (8-byte) column on the table. + pub fn integer(&mut self, name: &str) -> Column { + Column::new(name.to_string(), Type::BigInteger, self.sender.clone()) + } + + /// Create a new json column on the table. + pub fn json(&mut self, name: &str) -> Column { + Column::new(name.to_string(), Type::Json, self.sender.clone()) + } + + /// Create a new string column on the table. + pub fn string(&mut self, name: &str) -> Column { + Column::new(name.to_string(), Type::String(255), self.sender.clone()) + } + + /// Create a new boolean column on the table. + pub fn boolean(&mut self, name: &str) -> Column { + Column::new(name.to_string(), Type::Boolean, self.sender.clone()) + } + + /// Create a new text column on the table. 
+ pub fn text(&mut self, name: &str) -> Column { + Column::new(name.to_string(), Type::Text, self.sender.clone()) + } + + /// Create a new timestamp column on the table. + pub fn timestamp(&mut self, name: &str) -> Column { + Column::new(name.to_string(), Type::Timestamp, self.sender.clone()) + } + + /// Specify a foreign key for the table. + pub fn foreign(&mut self, column: &str) -> ForeignIndex { + ForeignIndex::new(column.to_string(), self.name.clone(), self.sender.clone()) + } + + #[cfg(feature = "mysql")] + /// Create a new enum column on the table. + pub fn r#enum(&mut self, name: &str, values: &[&str]) -> Column { + Column::new( + name.to_string(), + Type::Enum( + name.to_string(), + values.iter().map(ToString::to_string).collect(), + ), + self.sender.clone(), + ) + } + + /// Create a foreign ID column for the given model. + pub fn foreign_id_for(&mut self) -> ForeignIndex { + let column = format!("{}_{}", M::NAME, M::PRIMARY_KEY).to_snake_case(); + + if ["u64", "u32", "u16", "u8", "usize"].contains(&type_name::()) { + #[allow(unused_variables)] + let column = Column::new(column.clone(), Type::BigInteger, self.sender.clone()); + + #[cfg(feature = "mysql")] + { + column.unsigned(true); + }; + } else { + Column::new(column.clone(), Type::String(255), self.sender.clone()); + } + + let index = ForeignIndex::new(column, self.name.clone(), self.sender.clone()); + index.on(M::TABLE_NAME).references(M::PRIMARY_KEY) + } + + /// Create a foreign ID column for the given model. + pub fn foreign_id(&mut self, name: &str) -> ForeignIndex { + #[allow(unused_variables)] + let column = Column::new(name.to_string(), Type::BigInteger, self.sender.clone()); + + #[cfg(feature = "mysql")] + { + column.unsigned(true); + }; + + let index = ForeignIndex::new(name.to_string(), self.name.clone(), self.sender.clone()); + + // if the column name is of the form `resource_id`, we extract and set the table name and foreign column name + if let Some((resource, column)) = name.split_once('_') { + index.on(&resource.to_plural()).references(column) + } else { + index + } + } + + /// Create a foreign UUID column for the given model. + pub fn foreign_uuid(&mut self, name: &str) -> ForeignIndex { + Column::new(name.to_string(), Type::Uuid, self.sender.clone()).uuid(true); + let index = ForeignIndex::new(name.to_string(), self.name.clone(), self.sender.clone()); + + // if the column name is of the form `resource_id`, we extract and set the table name and foreign column name + if let Some((resource, column)) = name.split_once('_') { + index.on(&resource.to_plural()).references(column) + } else { + index + } + } + + /// Add nullable creation and update timestamps to the table. + pub fn timestamps(&mut self) { + self.timestamp("created_at") + .nullable(true) + .use_current(true); + + #[allow(unused_variables)] + let updated_at = self.timestamp("updated_at").nullable(true); + + #[cfg(feature = "mysql")] + { + updated_at.use_current_on_update(true); + } + } } diff --git a/ensemble/src/query.rs b/ensemble/src/query.rs index 3b9db29..4dbe667 100644 --- a/ensemble/src/query.rs +++ b/ensemble/src/query.rs @@ -2,8 +2,8 @@ use itertools::Itertools; use rbs::Value; use serde::Serialize; use std::{ - collections::{HashMap, HashSet}, - fmt::Display, + collections::{HashMap, HashSet}, + fmt::Display, }; use crate::{connection, value, Error, Model}; @@ -11,826 +11,826 @@ use crate::{connection, value, Error, Model}; /// The Query Builder. 
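// Illustrative aside, not part of this patch: taken together, the column
// helpers in the schema hunk above compose into a fairly Laravel-flavored
// table definition. A hypothetical callback sketch (the table and column names
// are invented, and whatever `create`-style entry point hands out the
// `&mut Table` is assumed rather than shown in this hunk):
fn users_table(table: &mut Table) {
    table.id();                       // auto-incrementing BIGINT primary key
    table.string("name");             // VARCHAR(255)
    table.string("email");
    table.boolean("is_admin");
    table.json("settings");
    table.text("bio").nullable(true); // nullable TEXT
    table.timestamps();               // nullable created_at / updated_at
}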
#[derive(Debug)] pub struct Builder { - table: String, - join: Vec, - order: Vec, - limit: Option, - offset: Option, - r#where: Vec, - eager_load: HashSet, + table: String, + join: Vec, + order: Vec, + limit: Option, + offset: Option, + r#where: Vec, + eager_load: HashSet, } impl Builder { - pub(crate) fn new(table: String) -> Self { - Self { - table, - limit: None, - offset: None, - join: vec![], - order: vec![], - r#where: vec![], - eager_load: HashSet::new(), - } - } - - /// Execute a raw SQL query and return the results. - /// - /// # Safety - /// - /// This method is unsafe because it allows for arbitrary SQL to be executed, which can lead to SQL injection. - /// It is recommended to build queries using the methods provided by the query builder instead. - /// - /// # Errors - /// - /// Returns an error if the query fails, or if a connection to the database cannot be established. - pub async unsafe fn raw_sql(sql: &str, bindings: Vec) -> Result, Error> { - let mut conn = connection::get().await?; - - conn.get_values(sql, bindings) - .await - .map_err(|e| Error::Database(e.to_string())) - } - - /// Set the table which the query is targeting. - #[must_use] - pub fn from(mut self, table: &str) -> Self { - self.table = table.to_string(); - self - } - - /// Apply the given callback to the builder if the provided condition is true. - #[must_use] - pub fn when(mut self, condition: bool, r#fn: impl FnOnce(Self) -> Self) -> Self { - if condition { - self = r#fn(self); - } - - self - } - - /// Apply the given callback to the builder if the provided [`Option`] is `Some`. - #[must_use] - pub fn when_some(mut self, value: Option, r#fn: impl FnOnce(Self, T) -> Self) -> Self { - if let Some(value) = value { - self = r#fn(self, value); - } - - self - } - - /// Add a basic where clause to the query. - /// - /// # Panics - /// - /// Panics if the provided value cannot be serialized. - #[must_use] - pub fn r#where(mut self, column: &str, operator: Op, value: T) -> Self - where - Op: Into, - T: serde::Serialize, - { - self.r#where.push(WhereClause::Simple(Where { - boolean: Boolean::And, - operator: operator.into(), - column: column.to_string(), - value: Some(value::for_db(value).unwrap()), - })); - - self - } - - /// Set the "limit" value of the query. - #[must_use] - pub const fn limit(mut self, take: usize) -> Self { - self.limit = Some(take); - self - } - - /// Set the "offset" value of the query. - #[must_use] - pub const fn offset(mut self, skip: usize) -> Self { - self.offset = Some(skip); - self - } - - /// Set the relationships that should be eager loaded. - #[must_use] - pub fn with>(mut self, relations: T) -> Self { - self.eager_load.extend(relations.into().list()); - - self - } - - /// Add an "or where" clause to the query. - /// - /// # Panics - /// - /// Panics if this is the first where clause. - #[must_use] - pub fn or_where(mut self, column: &str, op: Op, value: T) -> Self - where - T: Into, - Op: Into, - { - assert!( - !self.r#where.is_empty(), - "Cannot use or_where without a where clause." - ); - - self.r#where.push(WhereClause::Simple(Where { - operator: op.into(), - boolean: Boolean::Or, - value: Some(value.into()), - column: column.to_string(), - })); - - self - } - - /// Add a "where not null" clause to the query. 
- #[must_use] - pub fn where_not_null(mut self, column: &str) -> Self { - self.r#where.push(WhereClause::Simple(Where { - value: None, - boolean: Boolean::And, - column: column.to_string(), - operator: Operator::NotNull, - })); - - self - } - - // Add a "where in" clause to the query. - #[must_use] - pub fn where_in(mut self, column: &str, values: Vec) -> Self - where - T: Into, - { - self.r#where.push(WhereClause::Simple(Where { - boolean: Boolean::And, - operator: Operator::In, - column: column.to_string(), - value: Some(Value::Array(values.into_iter().map(Into::into).collect())), - })); - - self - } - - /// Add a "where is null" clause to the query. - #[must_use] - pub fn where_null(mut self, column: &str) -> Self { - self.r#where.push(WhereClause::Simple(Where { - value: None, - boolean: Boolean::And, - column: column.to_string(), - operator: Operator::IsNull, - })); - - self - } - - /// Add an inner join to the query. - #[must_use] - pub fn join>( - mut self, - column: &str, - first: &str, - op: Op, - second: &str, - ) -> Self { - self.join.push(Join { - operator: op.into(), - first: first.to_string(), - column: column.to_string(), - r#type: JoinType::Inner, - second: second.to_string(), - }); - - self - } - - /// Add an "order by" clause to the query. - #[must_use] - pub fn order_by>(mut self, column: &str, direction: Dir) -> Self { - self.order.push(Order { - column: column.to_string(), - direction: direction.into(), - }); - - self - } - - /// Logically group a set of where clauses. - #[must_use] - pub fn where_group(mut self, r#fn: impl FnOnce(Self) -> Self) -> Self { - let builder = r#fn(Self::new(self.table.clone())); - - self.r#where - .push(WhereClause::Group(builder.r#where, Boolean::And)); - - self - } - - /// Get the SQL representation of the query. - #[must_use] - pub fn to_sql(&self, r#type: Type) -> String { - let mut sql = match r#type { - Type::Update => String::new(), // handled in update() - Type::Delete => format!("DELETE FROM {}", self.table), - Type::Select => format!("SELECT * FROM {}", self.table), - Type::Count => format!("SELECT COUNT(*) FROM {}", self.table), - }; - - if !self.join.is_empty() { - for join in &self.join { - sql.push_str(&format!( - " {} {} ON {} {} {}", - join.r#type, join.column, join.first, join.operator, join.second - )); - } - } - - if !self.r#where.is_empty() { - sql.push_str(" WHERE "); - - for (i, where_clause) in self.r#where.iter().enumerate() { - sql.push_str(&where_clause.to_sql(i != self.r#where.len() - 1)); - } - } - - if !self.order.is_empty() { - sql.push_str(" ORDER BY "); - - sql.push_str( - &self - .order - .iter() - .map(|order| format!("{} {}", order.column, order.direction)) - .join(", "), - ); - } - - if let Some(take) = self.limit { - sql.push_str(&format!(" LIMIT {take}")); - } - - if let Some(skip) = self.offset { - sql.push_str(&format!(" OFFSET {skip}")); - } - - sql - } - - /// Get the current query value bindings. - #[must_use] - pub fn get_bindings(&self) -> Vec { - self.r#where - .iter() - .flat_map(WhereClause::get_bindings) - .collect() - } - - /// Retrieve the number of records that match the query constraints. - /// - /// # Errors - /// - /// Returns an error if the query fails, or if a connection to the database cannot be established. 
- pub async fn count(self) -> Result { - let mut conn = connection::get().await?; - - let values = conn - .get_values(&self.to_sql(Type::Count), self.get_bindings()) - .await - .map_err(|e| Error::Database(e.to_string()))?; - - values.first().and_then(Value::as_u64).ok_or_else(|| { - Error::Serialization(rbs::value::ext::Error::Syntax( - "Failed to parse count value".to_string(), - )) - }) - } - - /// Execute the query and return the first result. - /// - /// # Errors - /// - /// Returns an error if the query fails, or if a connection to the database cannot be established. - pub async fn first(mut self) -> Result, Error> { - self.limit = Some(1); - let values = self.get::().await?; - - Ok(values.into_iter().next()) - } - - /// Execute the query and return the results. - /// - /// # Errors - /// - /// Returns an error if the query fails, or if a connection to the database cannot be established. - pub async fn get(self) -> Result, Error> { - let mut models = self - ._get() - .await? - .into_iter() - .map(value::from::) - .collect::, rbs::Error>>()?; - - if models.is_empty() || self.eager_load.is_empty() { - return Ok(models); - } - - let model = M::default(); - for relation in self.eager_load { - tracing::trace!( - "Eager loading {relation} relation for {} models", - models.len() - ); - - let rows = model - .eager_load(&relation, models.iter().collect::>().as_slice()) - .get_rows() - .await?; - - for model in &mut models { - model.fill_relation(&relation, &rows)?; - } - } - - Ok(models) - } - - /// Execute the query and return the results as a vector of rows. - /// - /// # Errors - /// - /// Returns an error if the query fails, or if a connection to the database cannot be established. - pub(crate) async fn get_rows(&self) -> Result>, Error> { - let values = self - ._get() - .await? - .into_iter() - .map(|v| { - let Value::Map(map) = v else { unreachable!() }; - - map.into_iter() - .map(|(k, v)| (k.into_string().unwrap_or_else(|| unreachable!()), v)) - .collect() - }) - .collect(); - - Ok(values) - } - - /// Insert a new record into the database. Returns the ID of the inserted record, if applicable. - /// - /// # Errors - /// - /// Returns an error if the query fails, or if a connection to the database cannot be established. - pub async fn insert serde::Deserialize<'de>, T: Into + Send>( - &self, - columns: T, - ) -> Result { - if self.limit.is_some() - || !self.join.is_empty() - || !self.order.is_empty() - || !self.r#where.is_empty() - { - return Err(Error::InvalidQuery); - } - - let mut conn = connection::get().await?; - let values: Vec<(String, Value)> = columns.into().0; - - let (sql, bindings) = ( - format!( - "INSERT INTO {} ({}) VALUES ({})", - self.table, - values.iter().map(|(column, _)| column).join(", "), - values.iter().map(|_| "?").join(", ") - ), - values.into_iter().map(|(_, value)| value).collect(), - ); - - tracing::debug!(sql = sql.as_str(), bindings = ?bindings, "Executing INSERT SQL query"); - - let result = conn - .exec(&sql, bindings) - .await - .map_err(|e| Error::Database(e.to_string()))?; - - Ok(rbs::from_value(result.last_insert_id)?) - } - - /// Update records in the database. Returns the number of affected rows. - /// - /// # Errors - /// - /// Returns an error if the query fails, or if a connection to the database cannot be established. 
- pub async fn update + Send>(self, values: T) -> Result { - let mut conn = connection::get().await?; - let values: Vec<(String, Value)> = values.into().0; - - let (sql, bindings) = ( - format!( - "UPDATE {} SET {} {}", - self.table, - values - .iter() - .map(|(column, _)| format!("{column} = ?")) - .join(", "), - self.to_sql(Type::Update) - ), - values - .iter() - .map(|(_, value)| value.clone()) - .chain(self.get_bindings()) - .collect(), - ); - - tracing::debug!(sql = sql.as_str(), bindings = ?bindings, "Executing UPDATE SQL query"); - - conn.exec(&sql, bindings) - .await - .map_err(|e| Error::Database(e.to_string())) - .map(|r| r.rows_affected) - } - - /// Delete records from the database. Returns the number of affected rows. - /// - /// # Errors - /// - /// Returns an error if the query fails, or if a connection to the database cannot be established. - pub async fn delete(self) -> Result { - let mut conn = connection::get().await?; - let (sql, bindings) = (self.to_sql(Type::Delete), self.get_bindings()); - - tracing::debug!(sql = sql.as_str(), bindings = ?bindings, "Executing DELETE SQL query"); - - conn.exec(&sql, bindings) - .await - .map_err(|e| Error::Database(e.to_string())) - .map(|r| r.rows_affected) - } - - /// Run a truncate statement on the table. Returns the number of affected rows. - /// - /// # Errors - /// - /// Returns an error if the query fails, or if a connection to the database cannot be established. - pub async fn truncate(self) -> Result { - let mut conn = connection::get().await?; - let sql = format!("TRUNCATE TABLE {}", self.table); - - tracing::debug!(sql = sql.as_str(), "Executing TRUNCATE SQL query"); - - conn.exec(&sql, vec![]) - .await - .map_err(|e| Error::Database(e.to_string())) - .map(|r| r.rows_affected) - } + pub(crate) fn new(table: String) -> Self { + Self { + table, + limit: None, + offset: None, + join: vec![], + order: vec![], + r#where: vec![], + eager_load: HashSet::new(), + } + } + + /// Execute a raw SQL query and return the results. + /// + /// # Safety + /// + /// This method is unsafe because it allows for arbitrary SQL to be executed, which can lead to SQL injection. + /// It is recommended to build queries using the methods provided by the query builder instead. + /// + /// # Errors + /// + /// Returns an error if the query fails, or if a connection to the database cannot be established. + pub async unsafe fn raw_sql(sql: &str, bindings: Vec) -> Result, Error> { + let mut conn = connection::get().await?; + + conn.get_values(sql, bindings) + .await + .map_err(|e| Error::Database(e.to_string())) + } + + /// Set the table which the query is targeting. + #[must_use] + pub fn from(mut self, table: &str) -> Self { + self.table = table.to_string(); + self + } + + /// Apply the given callback to the builder if the provided condition is true. + #[must_use] + pub fn when(mut self, condition: bool, r#fn: impl FnOnce(Self) -> Self) -> Self { + if condition { + self = r#fn(self); + } + + self + } + + /// Apply the given callback to the builder if the provided [`Option`] is `Some`. + #[must_use] + pub fn when_some(mut self, value: Option, r#fn: impl FnOnce(Self, T) -> Self) -> Self { + if let Some(value) = value { + self = r#fn(self, value); + } + + self + } + + /// Add a basic where clause to the query. + /// + /// # Panics + /// + /// Panics if the provided value cannot be serialized. 
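// Illustrative aside, not part of this patch: `when` / `when_some` let callers
// branch a query fluently instead of reassigning the builder. A hypothetical
// sketch (the `User` model and the filter values are invented):
fn user_query(only_active: bool, search: Option<String>) -> Builder {
    User::query()
        .when(only_active, |q| q.r#where("active", "=", true))
        .when_some(search, |q, term| q.r#where("name", "like", format!("%{term}%")))
}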
+ #[must_use] + pub fn r#where(mut self, column: &str, operator: Op, value: T) -> Self + where + Op: Into, + T: serde::Serialize, + { + self.r#where.push(WhereClause::Simple(Where { + boolean: Boolean::And, + operator: operator.into(), + column: column.to_string(), + value: Some(value::for_db(value).unwrap()), + })); + + self + } + + /// Set the "limit" value of the query. + #[must_use] + pub const fn limit(mut self, take: usize) -> Self { + self.limit = Some(take); + self + } + + /// Set the "offset" value of the query. + #[must_use] + pub const fn offset(mut self, skip: usize) -> Self { + self.offset = Some(skip); + self + } + + /// Set the relationships that should be eager loaded. + #[must_use] + pub fn with>(mut self, relations: T) -> Self { + self.eager_load.extend(relations.into().list()); + + self + } + + /// Add an "or where" clause to the query. + /// + /// # Panics + /// + /// Panics if this is the first where clause. + #[must_use] + pub fn or_where(mut self, column: &str, op: Op, value: T) -> Self + where + T: Into, + Op: Into, + { + assert!( + !self.r#where.is_empty(), + "Cannot use or_where without a where clause." + ); + + self.r#where.push(WhereClause::Simple(Where { + operator: op.into(), + boolean: Boolean::Or, + value: Some(value.into()), + column: column.to_string(), + })); + + self + } + + /// Add a "where not null" clause to the query. + #[must_use] + pub fn where_not_null(mut self, column: &str) -> Self { + self.r#where.push(WhereClause::Simple(Where { + value: None, + boolean: Boolean::And, + column: column.to_string(), + operator: Operator::NotNull, + })); + + self + } + + // Add a "where in" clause to the query. + #[must_use] + pub fn where_in(mut self, column: &str, values: Vec) -> Self + where + T: Into, + { + self.r#where.push(WhereClause::Simple(Where { + boolean: Boolean::And, + operator: Operator::In, + column: column.to_string(), + value: Some(Value::Array(values.into_iter().map(Into::into).collect())), + })); + + self + } + + /// Add a "where is null" clause to the query. + #[must_use] + pub fn where_null(mut self, column: &str) -> Self { + self.r#where.push(WhereClause::Simple(Where { + value: None, + boolean: Boolean::And, + column: column.to_string(), + operator: Operator::IsNull, + })); + + self + } + + /// Add an inner join to the query. + #[must_use] + pub fn join>( + mut self, + column: &str, + first: &str, + op: Op, + second: &str, + ) -> Self { + self.join.push(Join { + operator: op.into(), + first: first.to_string(), + column: column.to_string(), + r#type: JoinType::Inner, + second: second.to_string(), + }); + + self + } + + /// Add an "order by" clause to the query. + #[must_use] + pub fn order_by>(mut self, column: &str, direction: Dir) -> Self { + self.order.push(Order { + column: column.to_string(), + direction: direction.into(), + }); + + self + } + + /// Logically group a set of where clauses. + #[must_use] + pub fn where_group(mut self, r#fn: impl FnOnce(Self) -> Self) -> Self { + let builder = r#fn(Self::new(self.table.clone())); + + self.r#where + .push(WhereClause::Group(builder.r#where, Boolean::And)); + + self + } + + /// Get the SQL representation of the query. 
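// Illustrative aside, not part of this patch: a sketch of how the clause
// methods above compose, and roughly the SQL that `to_sql` / `get_bindings`
// below would produce for it. The `User` model and its columns are invented
// for the example.
fn filtered_users() -> Builder {
    // Roughly: SELECT * FROM users
    //          WHERE role = ? AND (age >= ?) AND (banned_at IS NULL)
    //          ORDER BY created_at DESC LIMIT 10
    // with bindings ["admin", 18].
    User::query()
        .r#where("role", "=", "admin")
        .where_group(|q| q.r#where("age", ">=", 18).where_null("banned_at"))
        .order_by("created_at", "desc")
        .limit(10)
}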
+ #[must_use] + pub fn to_sql(&self, r#type: Type) -> String { + let mut sql = match r#type { + Type::Update => String::new(), // handled in update() + Type::Delete => format!("DELETE FROM {}", self.table), + Type::Select => format!("SELECT * FROM {}", self.table), + Type::Count => format!("SELECT COUNT(*) FROM {}", self.table), + }; + + if !self.join.is_empty() { + for join in &self.join { + sql.push_str(&format!( + " {} {} ON {} {} {}", + join.r#type, join.column, join.first, join.operator, join.second + )); + } + } + + if !self.r#where.is_empty() { + sql.push_str(" WHERE "); + + for (i, where_clause) in self.r#where.iter().enumerate() { + sql.push_str(&where_clause.to_sql(i != self.r#where.len() - 1)); + } + } + + if !self.order.is_empty() { + sql.push_str(" ORDER BY "); + + sql.push_str( + &self + .order + .iter() + .map(|order| format!("{} {}", order.column, order.direction)) + .join(", "), + ); + } + + if let Some(take) = self.limit { + sql.push_str(&format!(" LIMIT {take}")); + } + + if let Some(skip) = self.offset { + sql.push_str(&format!(" OFFSET {skip}")); + } + + sql + } + + /// Get the current query value bindings. + #[must_use] + pub fn get_bindings(&self) -> Vec { + self.r#where + .iter() + .flat_map(WhereClause::get_bindings) + .collect() + } + + /// Retrieve the number of records that match the query constraints. + /// + /// # Errors + /// + /// Returns an error if the query fails, or if a connection to the database cannot be established. + pub async fn count(self) -> Result { + let mut conn = connection::get().await?; + + let values = conn + .get_values(&self.to_sql(Type::Count), self.get_bindings()) + .await + .map_err(|e| Error::Database(e.to_string()))?; + + values.first().and_then(Value::as_u64).ok_or_else(|| { + Error::Serialization(rbs::value::ext::Error::Syntax( + "Failed to parse count value".to_string(), + )) + }) + } + + /// Execute the query and return the first result. + /// + /// # Errors + /// + /// Returns an error if the query fails, or if a connection to the database cannot be established. + pub async fn first(mut self) -> Result, Error> { + self.limit = Some(1); + let values = self.get::().await?; + + Ok(values.into_iter().next()) + } + + /// Execute the query and return the results. + /// + /// # Errors + /// + /// Returns an error if the query fails, or if a connection to the database cannot be established. + pub async fn get(self) -> Result, Error> { + let mut models = self + ._get() + .await? + .into_iter() + .map(value::from::) + .collect::, rbs::Error>>()?; + + if models.is_empty() || self.eager_load.is_empty() { + return Ok(models); + } + + let model = M::default(); + for relation in self.eager_load { + tracing::trace!( + "Eager loading {relation} relation for {} models", + models.len() + ); + + let rows = model + .eager_load(&relation, models.iter().collect::>().as_slice()) + .get_rows() + .await?; + + for model in &mut models { + model.fill_relation(&relation, &rows)?; + } + } + + Ok(models) + } + + /// Execute the query and return the results as a vector of rows. + /// + /// # Errors + /// + /// Returns an error if the query fails, or if a connection to the database cannot be established. + pub(crate) async fn get_rows(&self) -> Result>, Error> { + let values = self + ._get() + .await? 
+ .into_iter() + .map(|v| { + let Value::Map(map) = v else { unreachable!() }; + + map.into_iter() + .map(|(k, v)| (k.into_string().unwrap_or_else(|| unreachable!()), v)) + .collect() + }) + .collect(); + + Ok(values) + } + + /// Insert a new record into the database. Returns the ID of the inserted record, if applicable. + /// + /// # Errors + /// + /// Returns an error if the query fails, or if a connection to the database cannot be established. + pub async fn insert serde::Deserialize<'de>, T: Into + Send>( + &self, + columns: T, + ) -> Result { + if self.limit.is_some() + || !self.join.is_empty() + || !self.order.is_empty() + || !self.r#where.is_empty() + { + return Err(Error::InvalidQuery); + } + + let mut conn = connection::get().await?; + let values: Vec<(String, Value)> = columns.into().0; + + let (sql, bindings) = ( + format!( + "INSERT INTO {} ({}) VALUES ({})", + self.table, + values.iter().map(|(column, _)| column).join(", "), + values.iter().map(|_| "?").join(", ") + ), + values.into_iter().map(|(_, value)| value).collect(), + ); + + tracing::debug!(sql = sql.as_str(), bindings = ?bindings, "Executing INSERT SQL query"); + + let result = conn + .exec(&sql, bindings) + .await + .map_err(|e| Error::Database(e.to_string()))?; + + Ok(rbs::from_value(result.last_insert_id)?) + } + + /// Update records in the database. Returns the number of affected rows. + /// + /// # Errors + /// + /// Returns an error if the query fails, or if a connection to the database cannot be established. + pub async fn update + Send>(self, values: T) -> Result { + let mut conn = connection::get().await?; + let values: Vec<(String, Value)> = values.into().0; + + let (sql, bindings) = ( + format!( + "UPDATE {} SET {} {}", + self.table, + values + .iter() + .map(|(column, _)| format!("{column} = ?")) + .join(", "), + self.to_sql(Type::Update) + ), + values + .iter() + .map(|(_, value)| value.clone()) + .chain(self.get_bindings()) + .collect(), + ); + + tracing::debug!(sql = sql.as_str(), bindings = ?bindings, "Executing UPDATE SQL query"); + + conn.exec(&sql, bindings) + .await + .map_err(|e| Error::Database(e.to_string())) + .map(|r| r.rows_affected) + } + + /// Delete records from the database. Returns the number of affected rows. + /// + /// # Errors + /// + /// Returns an error if the query fails, or if a connection to the database cannot be established. + pub async fn delete(self) -> Result { + let mut conn = connection::get().await?; + let (sql, bindings) = (self.to_sql(Type::Delete), self.get_bindings()); + + tracing::debug!(sql = sql.as_str(), bindings = ?bindings, "Executing DELETE SQL query"); + + conn.exec(&sql, bindings) + .await + .map_err(|e| Error::Database(e.to_string())) + .map(|r| r.rows_affected) + } + + /// Run a truncate statement on the table. Returns the number of affected rows. + /// + /// # Errors + /// + /// Returns an error if the query fails, or if a connection to the database cannot be established. 
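// Illustrative aside, not part of this patch: the write helpers above accept
// anything convertible into `Columns`, such as a `Vec<(&str, T)>`. A
// hypothetical sketch (the `User` model, its table and the literal values are
// invented):
async fn write_sketch() -> Result<(), Error> {
    // INSERT INTO users (name, email) VALUES (?, ?), returning the new id.
    let id: u64 = User::query()
        .insert(vec![("name", "Ada"), ("email", "ada@example.com")])
        .await?;

    // UPDATE users SET email = ? WHERE id = ?
    User::query()
        .r#where("id", "=", id)
        .update(vec![("email", "ada@example.org")])
        .await?;

    // DELETE FROM users WHERE id = ?
    User::query().r#where("id", "=", id).delete().await?;

    Ok(())
}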
+ pub async fn truncate(self) -> Result { + let mut conn = connection::get().await?; + let sql = format!("TRUNCATE TABLE {}", self.table); + + tracing::debug!(sql = sql.as_str(), "Executing TRUNCATE SQL query"); + + conn.exec(&sql, vec![]) + .await + .map_err(|e| Error::Database(e.to_string())) + .map(|r| r.rows_affected) + } } impl Builder { - async fn _get(&self) -> Result, Error> { - let mut conn = connection::get().await?; - let (sql, bindings) = (self.to_sql(Type::Select), self.get_bindings()); + async fn _get(&self) -> Result, Error> { + let mut conn = connection::get().await?; + let (sql, bindings) = (self.to_sql(Type::Select), self.get_bindings()); - tracing::debug!(sql = sql.as_str(), bindings = ?bindings, "Executing SELECT SQL query"); + tracing::debug!(sql = sql.as_str(), bindings = ?bindings, "Executing SELECT SQL query"); - let values = conn - .get_values(&sql, bindings) - .await - .map_err(|s| Error::Database(s.to_string()))?; + let values = conn + .get_values(&sql, bindings) + .await + .map_err(|s| Error::Database(s.to_string()))?; - Ok(values) - } + Ok(values) + } } pub enum EagerLoad { - Single(String), - Multiple(Vec), + Single(String), + Multiple(Vec), } impl EagerLoad { - #[must_use] - pub fn list(self) -> Vec { - match self { - Self::Single(value) => vec![value], - Self::Multiple(value) => value, - } - } + #[must_use] + pub fn list(self) -> Vec { + match self { + Self::Single(value) => vec![value], + Self::Multiple(value) => value, + } + } } impl From<&str> for EagerLoad { - fn from(value: &str) -> Self { - Self::Single(value.to_string()) - } + fn from(value: &str) -> Self { + Self::Single(value.to_string()) + } } impl From> for EagerLoad { - fn from(value: Vec<&str>) -> Self { - Self::Multiple(value.iter().map(ToString::to_string).collect()) - } + fn from(value: Vec<&str>) -> Self { + Self::Multiple(value.iter().map(ToString::to_string).collect()) + } } pub struct Columns(Vec<(String, Value)>); #[allow(clippy::fallible_impl_from)] impl From for Columns { - fn from(value: Value) -> Self { - match value { - Value::Map(map) => Self( - map.into_iter() - .map(|(column, value)| (column.into_string().unwrap(), value)) - .collect(), - ), - _ => panic!("The provided value is not a map."), - } - } + fn from(value: Value) -> Self { + match value { + Value::Map(map) => Self( + map.into_iter() + .map(|(column, value)| (column.into_string().unwrap(), value)) + .collect(), + ), + _ => panic!("The provided value is not a map."), + } + } } impl From> for Columns { - fn from(values: Vec<(&str, T)>) -> Self { - Self( - values - .iter() - .map(|(column, value)| ((*column).to_string(), value::for_db(value).unwrap())) - .collect(), - ) - } + fn from(values: Vec<(&str, T)>) -> Self { + Self( + values + .iter() + .map(|(column, value)| ((*column).to_string(), value::for_db(value).unwrap())) + .collect(), + ) + } } impl From<&[(&str, T)]> for Columns { - fn from(values: &[(&str, T)]) -> Self { - Self( - values - .iter() - .map(|(column, value)| ((*column).to_string(), value::for_db(value).unwrap())) - .collect(), - ) - } + fn from(values: &[(&str, T)]) -> Self { + Self( + values + .iter() + .map(|(column, value)| ((*column).to_string(), value::for_db(value).unwrap())) + .collect(), + ) + } } /// Available sort directions. 
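// Illustrative aside, not part of this patch: the `From` impls on `EagerLoad`
// above are what let `with` take either a single relation name or a list. A
// hypothetical sketch (the `Post` model and its `comments` / `author`
// relations are invented):
async fn eager_sketch() -> Result<(), Error> {
    // Both forms end up as an `EagerLoad` and are resolved during `get`.
    let _one = Post::query().with("comments").get::<Post>().await?;
    let _many = Post::query().with(vec!["comments", "author"]).get::<Post>().await?;
    Ok(())
}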
#[derive(Debug)] pub enum Direction { - Ascending, - Descending, + Ascending, + Descending, } impl Display for Direction { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Ascending => write!(f, "ASC"), - Self::Descending => write!(f, "DESC"), - } - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Ascending => write!(f, "ASC"), + Self::Descending => write!(f, "DESC"), + } + } } impl From for Direction { - fn from(value: String) -> Self { - value.as_str().into() - } + fn from(value: String) -> Self { + value.as_str().into() + } } #[allow(clippy::fallible_impl_from)] impl From<&str> for Direction { - fn from(value: &str) -> Self { - match value.to_uppercase().as_str() { - "ASC" | "ASCENDING" => Self::Ascending, - "DESC" | "DESCENDING" => Self::Descending, - - _ => panic!("Invalid direction {value}"), - } - } + fn from(value: &str) -> Self { + match value.to_uppercase().as_str() { + "ASC" | "ASCENDING" => Self::Ascending, + "DESC" | "DESCENDING" => Self::Descending, + + _ => panic!("Invalid direction {value}"), + } + } } /// An order clause. #[derive(Debug)] struct Order { - column: String, - direction: Direction, + column: String, + direction: Direction, } /// Available join types. #[derive(Debug)] enum JoinType { - /// The `INNER JOIN` type. - Inner, + /// The `INNER JOIN` type. + Inner, } impl Display for JoinType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Inner => write!(f, "INNER JOIN"), - } - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Inner => write!(f, "INNER JOIN"), + } + } } #[derive(Debug, Clone, Copy)] pub enum Type { - Count, - Select, - Update, - Delete, + Count, + Select, + Update, + Delete, } /// A join clause. 
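// Illustrative aside, not part of this patch: only inner joins are modelled so
// far, and `to_sql` above renders them as `INNER JOIN <table> ON <first> <op>
// <second>`. A hypothetical sketch (the tables and columns are invented):
fn posts_with_authors() -> Builder {
    // Roughly: SELECT * FROM posts
    //          INNER JOIN users ON posts.user_id = users.id
    //          WHERE published = ?
    Post::query()
        .join("users", "posts.user_id", "=", "users.id")
        .r#where("published", "=", true)
}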
#[derive(Debug)] struct Join { - column: String, - first: String, - second: String, - r#type: JoinType, - operator: Operator, + column: String, + first: String, + second: String, + r#type: JoinType, + operator: Operator, } #[derive(Debug)] enum WhereClause { - Simple(Where), - Group(Vec, Boolean), + Simple(Where), + Group(Vec, Boolean), } impl WhereClause { - fn to_sql(&self, add_boolean: bool) -> String { - match self { - Self::Simple(where_clause) => where_clause.to_sql(add_boolean), - Self::Group(where_clauses, boolean) => { - let mut sql = String::new(); - - for (i, where_clause) in where_clauses.iter().enumerate() { - sql.push_str(&format!("({})", where_clause.to_sql(false))); - - if i != where_clauses.len() - 1 { - sql.push_str(" AND "); - } - } - - if add_boolean { - format!("{boolean} {sql}") - } else { - sql - } - } - } - } - - fn get_bindings(&self) -> Vec { - match self { - Self::Simple(where_clause) => where_clause - .value - .clone() - .into_iter() - .flat_map(|v| match v { - Value::Array(array) => array, - _ => vec![v], - }) - .collect(), - Self::Group(where_clauses, _) => { - where_clauses.iter().flat_map(Self::get_bindings).collect() - } - } - } + fn to_sql(&self, add_boolean: bool) -> String { + match self { + Self::Simple(where_clause) => where_clause.to_sql(add_boolean), + Self::Group(where_clauses, boolean) => { + let mut sql = String::new(); + + for (i, where_clause) in where_clauses.iter().enumerate() { + sql.push_str(&format!("({})", where_clause.to_sql(false))); + + if i != where_clauses.len() - 1 { + sql.push_str(" AND "); + } + } + + if add_boolean { + format!("{boolean} {sql}") + } else { + sql + } + }, + } + } + + fn get_bindings(&self) -> Vec { + match self { + Self::Simple(where_clause) => where_clause + .value + .clone() + .into_iter() + .flat_map(|v| match v { + Value::Array(array) => array, + _ => vec![v], + }) + .collect(), + Self::Group(where_clauses, _) => { + where_clauses.iter().flat_map(Self::get_bindings).collect() + }, + } + } } /// A where clause. #[derive(Debug)] struct Where { - column: String, - boolean: Boolean, - operator: Operator, - value: Option, + column: String, + boolean: Boolean, + operator: Operator, + value: Option, } impl Where { - fn to_sql(&self, add_boolean: bool) -> String { - let sql = format!( - "{} {} {}", - self.column, - self.operator, - self.value.as_ref().map_or_else(String::new, |value| { - value.as_array().map_or_else( - || "?".to_string(), - |value| format!("({})", value.iter().map(|_| "?").join(", ")), - ) - }) - ); - - if add_boolean { - format!("{sql} {} ", self.boolean) - } else { - sql - } - } + fn to_sql(&self, add_boolean: bool) -> String { + let sql = format!( + "{} {} {}", + self.column, + self.operator, + self.value.as_ref().map_or_else(String::new, |value| { + value.as_array().map_or_else( + || "?".to_string(), + |value| format!("({})", value.iter().map(|_| "?").join(", ")), + ) + }) + ); + + if add_boolean { + format!("{sql} {} ", self.boolean) + } else { + sql + } + } } /// Available operators for where clauses. #[derive(Debug)] pub enum Operator { - /// The `IN` operator. - In, - /// The `LIKE` operator. - Like, - /// The `NOT IN` operator. - NotIn, - /// The `=` operator. - Equals, - /// The `IS NULL` operator. - IsNull, - /// The `IS NOT NULL` operator. - NotNull, - /// The `BETWEEN` operator. - Between, - /// The `NOT LIKE` operator. - NotLike, - /// The `<` operator. - LessThan, - /// The `<>` operator. - NotEquals, - /// The `NOT BETWEEN` operator. - NotBetween, - /// The `>` operator. 
- GreaterThan, - /// The `<=` operator. - LessOrEqual, - /// The `>=` operator. - GreaterOrEqual, + /// The `IN` operator. + In, + /// The `LIKE` operator. + Like, + /// The `NOT IN` operator. + NotIn, + /// The `=` operator. + Equals, + /// The `IS NULL` operator. + IsNull, + /// The `IS NOT NULL` operator. + NotNull, + /// The `BETWEEN` operator. + Between, + /// The `NOT LIKE` operator. + NotLike, + /// The `<` operator. + LessThan, + /// The `<>` operator. + NotEquals, + /// The `NOT BETWEEN` operator. + NotBetween, + /// The `>` operator. + GreaterThan, + /// The `<=` operator. + LessOrEqual, + /// The `>=` operator. + GreaterOrEqual, } impl Display for Operator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - Self::In => "IN", - Self::Equals => "=", - Self::Like => "LIKE", - Self::LessThan => "<", - Self::NotIn => "NOT IN", - Self::NotEquals => "<>", - Self::GreaterThan => ">", - Self::LessOrEqual => "<=", - Self::IsNull => "IS NULL", - Self::Between => "BETWEEN", - Self::NotLike => "NOT LIKE", - Self::GreaterOrEqual => ">=", - Self::NotNull => "IS NOT NULL", - Self::NotBetween => "NOT BETWEEN", - } - ) - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::In => "IN", + Self::Equals => "=", + Self::Like => "LIKE", + Self::LessThan => "<", + Self::NotIn => "NOT IN", + Self::NotEquals => "<>", + Self::GreaterThan => ">", + Self::LessOrEqual => "<=", + Self::IsNull => "IS NULL", + Self::Between => "BETWEEN", + Self::NotLike => "NOT LIKE", + Self::GreaterOrEqual => ">=", + Self::NotNull => "IS NOT NULL", + Self::NotBetween => "NOT BETWEEN", + } + ) + } } impl From for Operator { - fn from(value: String) -> Self { - value.as_str().into() - } + fn from(value: String) -> Self { + value.as_str().into() + } } impl From for Operator { - fn from(value: char) -> Self { - value.to_string().into() - } + fn from(value: char) -> Self { + value.to_string().into() + } } #[allow(clippy::fallible_impl_from)] impl From<&str> for Operator { - fn from(value: &str) -> Self { - match value.to_uppercase().as_str() { - "IN" => Self::In, - "=" => Self::Equals, - "LIKE" => Self::Like, - "<" => Self::LessThan, - "NOT IN" => Self::NotIn, - "!=" => Self::NotEquals, - ">" => Self::GreaterThan, - "<=" => Self::LessOrEqual, - "BETWEEN" => Self::Between, - "NOT LIKE" => Self::NotLike, - ">=" => Self::GreaterOrEqual, - "NOT BETWEEN" => Self::NotBetween, - - _ => panic!("Invalid operator {value}"), - } - } + fn from(value: &str) -> Self { + match value.to_uppercase().as_str() { + "IN" => Self::In, + "=" => Self::Equals, + "LIKE" => Self::Like, + "<" => Self::LessThan, + "NOT IN" => Self::NotIn, + "!=" => Self::NotEquals, + ">" => Self::GreaterThan, + "<=" => Self::LessOrEqual, + "BETWEEN" => Self::Between, + "NOT LIKE" => Self::NotLike, + ">=" => Self::GreaterOrEqual, + "NOT BETWEEN" => Self::NotBetween, + + _ => panic!("Invalid operator {value}"), + } + } } #[derive(Debug)] enum Boolean { - And, - Or, + And, + Or, } impl Display for Boolean { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Or => write!(f, "OR"), - Self::And => write!(f, "AND"), - } - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Or => write!(f, "OR"), + Self::And => write!(f, "AND"), + } + } } impl AsRef for Builder { - fn as_ref(&self) -> &Self { - self - } + fn as_ref(&self) -> &Self { + self + } } diff --git 
a/ensemble/src/relationships/belongs_to.rs b/ensemble/src/relationships/belongs_to.rs index 40448dc..299bf8a 100644 --- a/ensemble/src/relationships/belongs_to.rs +++ b/ensemble/src/relationships/belongs_to.rs @@ -35,109 +35,108 @@ use crate::{query::Builder, value::serializing_for_db, Error, Model}; /// ``` #[derive(Clone, Default)] pub struct BelongsTo { - local_key: String, - relation: Status, - _local: std::marker::PhantomData, - /// The value of the local model's related key. - pub value: Related::PrimaryKey, + local_key: String, + relation: Status, + _local: std::marker::PhantomData, + /// The value of the local model's related key. + pub value: Related::PrimaryKey, } -#[async_trait::async_trait] impl Relationship for BelongsTo { - type Value = Related; - type Key = Related::PrimaryKey; - type RelatedKey = Option; - - fn build(value: Self::Key, local_key: Self::RelatedKey) -> Self { - let local_key = local_key.unwrap_or_else(|| Related::PRIMARY_KEY.to_snake_case()); - - Self { - value, - local_key, - relation: Status::initial(), - _local: std::marker::PhantomData, - } - } - - fn query(&self) -> Builder { - Related::query() - .r#where( - &format!("{}.{}", Related::TABLE_NAME, self.local_key), - "=", - self.value.clone(), - ) - .limit(1) - } - - /// Get the related model. - async fn get(&mut self) -> Result<&mut Self::Value, Error> { - if self.relation.is_none() { - let relation = self.query().first().await?.ok_or(Error::NotFound)?; - - self.relation = Status::Fetched(Some(relation)); - } - - Ok(self.relation.as_mut().unwrap()) - } - - fn is_loaded(&self) -> bool { - self.relation.is_loaded() - } - - fn eager_query(&self, related: Vec) -> Builder { - Related::query() - .r#where( - &format!("{}.{}", Related::TABLE_NAME, self.local_key), - "in", - related, - ) - .limit(1) - } - - fn r#match(&mut self, related: &[HashMap]) -> Result<(), Error> { - let related = find_related(related, &self.local_key, &self.value, true)?; - - self.relation = Status::Fetched(related.into_iter().next()); - - Ok(()) - } + type Value = Related; + type Key = Related::PrimaryKey; + type RelatedKey = Option; + + fn build(value: Self::Key, local_key: Self::RelatedKey) -> Self { + let local_key = local_key.unwrap_or_else(|| Related::PRIMARY_KEY.to_snake_case()); + + Self { + value, + local_key, + relation: Status::initial(), + _local: std::marker::PhantomData, + } + } + + fn query(&self) -> Builder { + Related::query() + .r#where( + &format!("{}.{}", Related::TABLE_NAME, self.local_key), + "=", + self.value.clone(), + ) + .limit(1) + } + + /// Get the related model. 
+ async fn get(&mut self) -> Result<&mut Self::Value, Error> { + if self.relation.is_none() { + let relation = self.query().first().await?.ok_or(Error::NotFound)?; + + self.relation = Status::Fetched(Some(relation)); + } + + Ok(self.relation.as_mut().unwrap()) + } + + fn is_loaded(&self) -> bool { + self.relation.is_loaded() + } + + fn eager_query(&self, related: Vec) -> Builder { + Related::query() + .r#where( + &format!("{}.{}", Related::TABLE_NAME, self.local_key), + "in", + related, + ) + .limit(1) + } + + fn r#match(&mut self, related: &[HashMap]) -> Result<(), Error> { + let related = find_related(related, &self.local_key, &self.value, true)?; + + self.relation = Status::Fetched(related.into_iter().next()); + + Ok(()) + } } impl Debug for BelongsTo { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.relation.fmt(f) - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.relation.fmt(f) + } } impl Serialize for BelongsTo { - fn serialize(&self, serializer: S) -> Result { - if serializing_for_db::() { - if self.value == Default::default() { - return serializer.serialize_none(); - } + fn serialize(&self, serializer: S) -> Result { + if serializing_for_db::() { + if self.value == Default::default() { + return serializer.serialize_none(); + } - return self.value.serialize(serializer); - } + return self.value.serialize(serializer); + } - self.relation.serialize(serializer) - } + self.relation.serialize(serializer) + } } impl PartialEq for BelongsTo { - fn eq(&self, other: &Related) -> bool { - &self.value == other.primary_key() - } + fn eq(&self, other: &Related) -> bool { + &self.value == other.primary_key() + } } #[cfg(feature = "schema")] impl schemars::JsonSchema - for BelongsTo + for BelongsTo { - fn schema_name() -> String { - >::schema_name() - } + fn schema_name() -> String { + >::schema_name() + } - fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { - gen.subschema_for::>() - } + fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { + gen.subschema_for::>() + } } diff --git a/ensemble/src/relationships/belongs_to_many.rs b/ensemble/src/relationships/belongs_to_many.rs index efd1600..023d716 100644 --- a/ensemble/src/relationships/belongs_to_many.rs +++ b/ensemble/src/relationships/belongs_to_many.rs @@ -34,132 +34,131 @@ use crate::{query::Builder, value::serializing_for_db, Error, Model}; /// ``` #[derive(Clone, Default)] pub struct BelongsToMany { - local_key: String, - foreign_key: String, - pivot_table: String, - relation: Status>, - _local: std::marker::PhantomData, - /// The value of the local model's primary key. - pub value: Related::PrimaryKey, + local_key: String, + foreign_key: String, + pivot_table: String, + relation: Status>, + _local: std::marker::PhantomData, + /// The value of the local model's primary key. 
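// Illustrative aside, not part of this patch: with the `BelongsTo` impl above,
// the related model is fetched lazily on first access and cached in
// `relation`. A hypothetical sketch (the `Post` / `User` models and the
// `author` field are invented):
async fn author_sketch(post: &mut Post, user: &User) -> Result<(), Error> {
    // The first call runs roughly `SELECT * FROM users WHERE users.id = ? LIMIT 1`;
    // later calls reuse the cached value.
    let author = post.author.get().await?;
    println!("written by {}", author.name);

    // `PartialEq<Related>` compares the stored key against the candidate's
    // primary key without touching the database.
    if post.author == *user {
        println!("written by this user");
    }
    Ok(())
}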
+ pub value: Related::PrimaryKey, } -#[async_trait::async_trait] impl Relationship for BelongsToMany { - type Value = Vec; - type Key = Related::PrimaryKey; - type RelatedKey = (Option, Option, Option); - - fn build(value: Self::Key, (pivot_table, foreign_key, local_key): Self::RelatedKey) -> Self { - let pivot_table = pivot_table.unwrap_or_else(|| { - let mut names = [Local::NAME.to_string(), Related::NAME.to_string()]; - names.sort(); - names.join("_").to_snake_case() - }); - - let foreign_key = foreign_key.unwrap_or_else(|| { - format!("{}_{}", Related::NAME.to_snake_case(), Related::PRIMARY_KEY).to_snake_case() - }); - - let local_key = local_key.unwrap_or_else(|| { - format!("{}_{}", Local::NAME.to_snake_case(), Local::PRIMARY_KEY).to_snake_case() - }); - - Self { - value, - local_key, - foreign_key, - pivot_table, - relation: Status::initial(), - _local: std::marker::PhantomData, - } - } - - fn query(&self) -> Builder { - Related::query() - .from(Related::TABLE_NAME) - .join( - &self.pivot_table, - &format!("{}.{}", Related::TABLE_NAME, Related::PRIMARY_KEY), - "=", - &format!("{}.{}", self.pivot_table, self.foreign_key), - ) - .r#where( - &format!("{}.{}", self.pivot_table, self.local_key), - "=", - self.value.clone(), - ) - } - - async fn get(&mut self) -> Result<&mut Self::Value, Error> { - if self.relation.is_none() { - let relation = self.query().get().await?; - - self.relation = Status::Fetched(Some(relation)); - } - - Ok(self.relation.as_mut().unwrap()) - } - - fn is_loaded(&self) -> bool { - self.relation.is_loaded() - } - - fn eager_query(&self, related: Vec) -> Builder { - Related::query() - .from(Related::TABLE_NAME) - .join( - &self.pivot_table, - &format!("{}.{}", Related::TABLE_NAME, Related::PRIMARY_KEY), - "=", - &format!("{}.{}", self.pivot_table, self.foreign_key), - ) - .r#where( - &format!("{}.{}", self.pivot_table, self.local_key), - "in", - related, - ) - } - - fn r#match(&mut self, related: &[HashMap]) -> Result<(), Error> { - let related = find_related(related, &self.foreign_key, &self.value, false)?; - - if !related.is_empty() { - self.relation = Status::Fetched(Some(related)); - } - - Ok(()) - } + type Value = Vec; + type Key = Related::PrimaryKey; + type RelatedKey = (Option, Option, Option); + + fn build(value: Self::Key, (pivot_table, foreign_key, local_key): Self::RelatedKey) -> Self { + let pivot_table = pivot_table.unwrap_or_else(|| { + let mut names = [Local::NAME.to_string(), Related::NAME.to_string()]; + names.sort(); + names.join("_").to_snake_case() + }); + + let foreign_key = foreign_key.unwrap_or_else(|| { + format!("{}_{}", Related::NAME.to_snake_case(), Related::PRIMARY_KEY).to_snake_case() + }); + + let local_key = local_key.unwrap_or_else(|| { + format!("{}_{}", Local::NAME.to_snake_case(), Local::PRIMARY_KEY).to_snake_case() + }); + + Self { + value, + local_key, + foreign_key, + pivot_table, + relation: Status::initial(), + _local: std::marker::PhantomData, + } + } + + fn query(&self) -> Builder { + Related::query() + .from(Related::TABLE_NAME) + .join( + &self.pivot_table, + &format!("{}.{}", Related::TABLE_NAME, Related::PRIMARY_KEY), + "=", + &format!("{}.{}", self.pivot_table, self.foreign_key), + ) + .r#where( + &format!("{}.{}", self.pivot_table, self.local_key), + "=", + self.value.clone(), + ) + } + + async fn get(&mut self) -> Result<&mut Self::Value, Error> { + if self.relation.is_none() { + let relation = self.query().get().await?; + + self.relation = Status::Fetched(Some(relation)); + } + + Ok(self.relation.as_mut().unwrap()) + 
} + + fn is_loaded(&self) -> bool { + self.relation.is_loaded() + } + + fn eager_query(&self, related: Vec) -> Builder { + Related::query() + .from(Related::TABLE_NAME) + .join( + &self.pivot_table, + &format!("{}.{}", Related::TABLE_NAME, Related::PRIMARY_KEY), + "=", + &format!("{}.{}", self.pivot_table, self.foreign_key), + ) + .r#where( + &format!("{}.{}", self.pivot_table, self.local_key), + "in", + related, + ) + } + + fn r#match(&mut self, related: &[HashMap]) -> Result<(), Error> { + let related = find_related(related, &self.foreign_key, &self.value, false)?; + + if !related.is_empty() { + self.relation = Status::Fetched(Some(related)); + } + + Ok(()) + } } impl Debug for BelongsToMany { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.relation.fmt(f) - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.relation.fmt(f) + } } impl Serialize for BelongsToMany { - fn serialize(&self, serializer: S) -> Result { - if serializing_for_db::() { - if self.value == Default::default() { - return serializer.serialize_none(); - } + fn serialize(&self, serializer: S) -> Result { + if serializing_for_db::() { + if self.value == Default::default() { + return serializer.serialize_none(); + } - return self.value.serialize(serializer); - } + return self.value.serialize(serializer); + } - self.relation.serialize(serializer) - } + self.relation.serialize(serializer) + } } #[cfg(feature = "schema")] impl schemars::JsonSchema - for BelongsToMany + for BelongsToMany { - fn schema_name() -> String { - >>::schema_name() - } + fn schema_name() -> String { + >>::schema_name() + } - fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { - gen.subschema_for::>>() - } + fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { + gen.subschema_for::>>() + } } diff --git a/ensemble/src/relationships/has_many.rs b/ensemble/src/relationships/has_many.rs index ddbc076..066a68e 100644 --- a/ensemble/src/relationships/has_many.rs +++ b/ensemble/src/relationships/has_many.rs @@ -5,9 +5,9 @@ use std::{collections::HashMap, fmt::Debug}; use super::{find_related, Relationship, Status}; use crate::{ - query::Builder, - value::{self, serializing_for_db}, - Error, Model, + query::Builder, + value::{self, serializing_for_db}, + Error, Model, }; /// ## A One to Many relationship. @@ -41,164 +41,163 @@ use crate::{ /// ``` #[derive(Clone, Default)] pub struct HasMany { - foreign_key: String, - relation: Status>, - /// The value of the local model's primary key. - pub value: Local::PrimaryKey, + foreign_key: String, + relation: Status>, + /// The value of the local model's primary key. + pub value: Local::PrimaryKey, } -#[async_trait::async_trait] impl Relationship for HasMany { - type Value = Vec; - type Key = Local::PrimaryKey; - type RelatedKey = Option; - - fn build(value: Self::Key, foreign_key: Self::RelatedKey) -> Self { - let foreign_key = foreign_key.unwrap_or_else(|| { - format!("{}_{}", Local::NAME.to_snake_case(), Local::PRIMARY_KEY).to_snake_case() - }); - - Self { - value, - foreign_key, - relation: Status::initial(), - } - } - - fn query(&self) -> Builder { - Related::query() - .r#where( - &format!("{}.{}", Related::TABLE_NAME, self.foreign_key), - "=", - self.value.clone(), - ) - .where_not_null(&format!("{}.{}", Related::TABLE_NAME, self.foreign_key)) - } - - /// Get the related models. 
- /// - /// # Errors - /// - /// Returns an error if the model cannot be retrieved, or if a connection to the database cannot be established. - async fn get(&mut self) -> Result<&mut Self::Value, Error> { - if self.relation.is_none() { - let relation = self.query().get().await?; - - self.relation = Status::Fetched(Some(relation)); - } - - Ok(self.relation.as_mut().unwrap()) - } - - fn is_loaded(&self) -> bool { - self.relation.is_loaded() - } - - fn eager_query(&self, related: Vec) -> Builder { - Related::query() - .r#where( - &format!("{}.{}", Related::TABLE_NAME, self.foreign_key), - "in", - related, - ) - .where_not_null(&format!("{}.{}", Related::TABLE_NAME, self.foreign_key)) - } - - fn r#match(&mut self, related: &[HashMap]) -> Result<(), Error> { - let related = find_related(related, &self.foreign_key, &self.value, false)?; - - if !related.is_empty() { - self.relation = Status::Fetched(Some(related)); - } - - Ok(()) - } + type Value = Vec; + type Key = Local::PrimaryKey; + type RelatedKey = Option; + + fn build(value: Self::Key, foreign_key: Self::RelatedKey) -> Self { + let foreign_key = foreign_key.unwrap_or_else(|| { + format!("{}_{}", Local::NAME.to_snake_case(), Local::PRIMARY_KEY).to_snake_case() + }); + + Self { + value, + foreign_key, + relation: Status::initial(), + } + } + + fn query(&self) -> Builder { + Related::query() + .r#where( + &format!("{}.{}", Related::TABLE_NAME, self.foreign_key), + "=", + self.value.clone(), + ) + .where_not_null(&format!("{}.{}", Related::TABLE_NAME, self.foreign_key)) + } + + /// Get the related models. + /// + /// # Errors + /// + /// Returns an error if the model cannot be retrieved, or if a connection to the database cannot be established. + async fn get(&mut self) -> Result<&mut Self::Value, Error> { + if self.relation.is_none() { + let relation = self.query().get().await?; + + self.relation = Status::Fetched(Some(relation)); + } + + Ok(self.relation.as_mut().unwrap()) + } + + fn is_loaded(&self) -> bool { + self.relation.is_loaded() + } + + fn eager_query(&self, related: Vec) -> Builder { + Related::query() + .r#where( + &format!("{}.{}", Related::TABLE_NAME, self.foreign_key), + "in", + related, + ) + .where_not_null(&format!("{}.{}", Related::TABLE_NAME, self.foreign_key)) + } + + fn r#match(&mut self, related: &[HashMap]) -> Result<(), Error> { + let related = find_related(related, &self.foreign_key, &self.value, false)?; + + if !related.is_empty() { + self.relation = Status::Fetched(Some(related)); + } + + Ok(()) + } } impl HasMany { - /// Create a new `Related` model. - /// - /// ## Errors - /// - /// Returns an error if the model cannot be inserted, or if a connection to the database cannot be established. - /// - /// ## Example - /// - /// ```rust - /// # use ensemble::{Model, relationships::HasMany}; - /// # #[derive(Debug, Model, Clone)] - /// # struct Comment { - /// # id: u64, - /// # content: String, - /// # } - /// # #[derive(Debug, Model, Clone)] - /// # struct Post { - /// # id: u64, - /// # comments: HasMany - /// # } - /// # async fn call() -> Result<(), ensemble::Error> { - /// let mut post = Post::find(1).await?; - /// - /// let comment = post.comments.create(Comment { - /// id: 1, - /// content: "Hello, world!".to_string(), - /// }).await?; - /// # Ok(()) - /// # } - pub async fn create(&mut self, related: Related) -> Result - where - Related: Clone, - { - let Value::Map(mut value) = rbs::to_value(related)? 
else { - return Err(Error::Serialization(rbs::Error::Syntax( - "Expected a map".to_string(), - ))); - }; - - value.insert( - Value::String(self.foreign_key.clone()), - value::for_db(&self.value)?, - ); - - let result = Related::create(rbs::from_value(Value::Map(value))?).await?; - - if let Status::Fetched(Some(relation)) = &mut self.relation { - relation.push(result.clone()); - } - - Ok(result) - } + /// Create a new `Related` model. + /// + /// ## Errors + /// + /// Returns an error if the model cannot be inserted, or if a connection to the database cannot be established. + /// + /// ## Example + /// + /// ```rust + /// # use ensemble::{Model, relationships::HasMany}; + /// # #[derive(Debug, Model, Clone)] + /// # struct Comment { + /// # id: u64, + /// # content: String, + /// # } + /// # #[derive(Debug, Model, Clone)] + /// # struct Post { + /// # id: u64, + /// # comments: HasMany + /// # } + /// # async fn call() -> Result<(), ensemble::Error> { + /// let mut post = Post::find(1).await?; + /// + /// let comment = post.comments.create(Comment { + /// id: 1, + /// content: "Hello, world!".to_string(), + /// }).await?; + /// # Ok(()) + /// # } + pub async fn create(&mut self, related: Related) -> Result + where + Related: Clone, + { + let Value::Map(mut value) = rbs::to_value(related)? else { + return Err(Error::Serialization(rbs::Error::Syntax( + "Expected a map".to_string(), + ))); + }; + + value.insert( + Value::String(self.foreign_key.clone()), + value::for_db(&self.value)?, + ); + + let result = Related::create(rbs::from_value(Value::Map(value))?).await?; + + if let Status::Fetched(Some(relation)) = &mut self.relation { + relation.push(result.clone()); + } + + Ok(result) + } } impl Debug for HasMany { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.relation.fmt(f) - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.relation.fmt(f) + } } impl Serialize for HasMany { - fn serialize(&self, serializer: S) -> Result { - if serializing_for_db::() { - if self.value == Default::default() { - return serializer.serialize_none(); - } + fn serialize(&self, serializer: S) -> Result { + if serializing_for_db::() { + if self.value == Default::default() { + return serializer.serialize_none(); + } - return self.value.serialize(serializer); - } + return self.value.serialize(serializer); + } - self.relation.serialize(serializer) - } + self.relation.serialize(serializer) + } } #[cfg(feature = "schema")] impl schemars::JsonSchema - for HasMany + for HasMany { - fn schema_name() -> String { - >>::schema_name() - } + fn schema_name() -> String { + >>::schema_name() + } - fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { - gen.subschema_for::>>() - } + fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { + gen.subschema_for::>>() + } } diff --git a/ensemble/src/relationships/has_one.rs b/ensemble/src/relationships/has_one.rs index 25dc35b..d970662 100644 --- a/ensemble/src/relationships/has_one.rs +++ b/ensemble/src/relationships/has_one.rs @@ -35,104 +35,103 @@ use crate::{query::Builder, value::serializing_for_db, Error, Model}; /// ``` #[derive(Clone, Default)] pub struct HasOne { - foreign_key: String, - relation: Status, - /// The value of the local model's primary key. - pub value: Local::PrimaryKey, + foreign_key: String, + relation: Status, + /// The value of the local model's primary key. 
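// Illustrative aside, not part of this patch: when no keys are passed, the
// relationship builders derive them from the model names and primary keys.
// Worked out for hypothetical `User` / `Post` / `Role` / `Profile` models with
// `PRIMARY_KEY = "id"`:
//
//   HasMany<User, Post> / HasOne<User, Profile> -> foreign key "user_id"
//   BelongsTo<Post, User>                       -> local key   "id"
//   BelongsToMany<User, Role>                   -> pivot table "role_user"
//                                                  (names sorted, joined by "_"),
//                                                  foreign key "role_id",
//                                                  local key   "user_id"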
+ pub value: Local::PrimaryKey, } -#[async_trait::async_trait] impl Relationship for HasOne { - type Value = Related; - type Key = Local::PrimaryKey; - type RelatedKey = Option; - - fn build(value: Self::Key, foreign_key: Self::RelatedKey) -> Self { - let foreign_key = foreign_key.unwrap_or_else(|| { - format!("{}_{}", Local::NAME.to_snake_case(), Local::PRIMARY_KEY).to_snake_case() - }); - - Self { - value, - foreign_key, - relation: Status::initial(), - } - } - - fn query(&self) -> Builder { - Related::query() - .r#where( - &format!("{}.{}", Related::TABLE_NAME, self.foreign_key), - "=", - self.value.clone(), - ) - .where_not_null(&format!("{}.{}", Related::TABLE_NAME, self.foreign_key)) - .limit(1) - } - - async fn get(&mut self) -> Result<&mut Self::Value, Error> { - if self.relation.is_none() { - let relation = self.query().first().await?.ok_or(Error::NotFound)?; - - self.relation = Status::Fetched(Some(relation)); - } - - Ok(self.relation.as_mut().unwrap()) - } - - fn is_loaded(&self) -> bool { - self.relation.is_loaded() - } - - fn eager_query(&self, related: Vec) -> Builder { - Related::query() - .r#where( - &format!("{}.{}", Related::TABLE_NAME, self.foreign_key), - "in", - related, - ) - .where_not_null(&format!("{}.{}", Related::TABLE_NAME, self.foreign_key)) - .limit(1) - } - - fn r#match(&mut self, related: &[HashMap]) -> Result<(), Error> { - let related = find_related(related, &self.foreign_key, &self.value, true)?; - - self.relation = Status::Fetched(related.into_iter().next()); - - Ok(()) - } + type Value = Related; + type Key = Local::PrimaryKey; + type RelatedKey = Option; + + fn build(value: Self::Key, foreign_key: Self::RelatedKey) -> Self { + let foreign_key = foreign_key.unwrap_or_else(|| { + format!("{}_{}", Local::NAME.to_snake_case(), Local::PRIMARY_KEY).to_snake_case() + }); + + Self { + value, + foreign_key, + relation: Status::initial(), + } + } + + fn query(&self) -> Builder { + Related::query() + .r#where( + &format!("{}.{}", Related::TABLE_NAME, self.foreign_key), + "=", + self.value.clone(), + ) + .where_not_null(&format!("{}.{}", Related::TABLE_NAME, self.foreign_key)) + .limit(1) + } + + async fn get(&mut self) -> Result<&mut Self::Value, Error> { + if self.relation.is_none() { + let relation = self.query().first().await?.ok_or(Error::NotFound)?; + + self.relation = Status::Fetched(Some(relation)); + } + + Ok(self.relation.as_mut().unwrap()) + } + + fn is_loaded(&self) -> bool { + self.relation.is_loaded() + } + + fn eager_query(&self, related: Vec) -> Builder { + Related::query() + .r#where( + &format!("{}.{}", Related::TABLE_NAME, self.foreign_key), + "in", + related, + ) + .where_not_null(&format!("{}.{}", Related::TABLE_NAME, self.foreign_key)) + .limit(1) + } + + fn r#match(&mut self, related: &[HashMap]) -> Result<(), Error> { + let related = find_related(related, &self.foreign_key, &self.value, true)?; + + self.relation = Status::Fetched(related.into_iter().next()); + + Ok(()) + } } impl Debug for HasOne { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.relation.fmt(f) - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.relation.fmt(f) + } } impl Serialize for HasOne { - fn serialize(&self, serializer: S) -> Result { - if serializing_for_db::() { - if self.value == Default::default() { - return serializer.serialize_none(); - } + fn serialize(&self, serializer: S) -> Result { + if serializing_for_db::() { + if self.value == Default::default() { + return serializer.serialize_none(); + } - 
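`HasOne` behaves the same way but caps the query at one row and maps a missing row to `Error::NotFound` instead of an empty `Vec`. A hedged sketch of the caller's side; the `User`/`Profile` models and the `user_id` column are invented for illustration.

```rust
use ensemble::{relationships::{HasOne, Relationship}, Model};

#[derive(Debug, Model, Clone)]
struct Profile {
    id: u64,
    user_id: u64, // assumed foreign-key column
    bio: String,
}

#[derive(Debug, Model, Clone)]
struct User {
    id: u64,
    profile: HasOne<User, Profile>,
}

async fn show_bio(mut user: User) -> Result<(), ensemble::Error> {
    // Same builder as `HasMany`, but constrained with `.limit(1)`, and a missing
    // row surfaces as `Error::NotFound` rather than an empty collection.
    match user.profile.get().await {
        Ok(profile) => println!("{}", profile.bio),
        Err(ensemble::Error::NotFound) => println!("no profile yet"),
        Err(e) => return Err(e),
    }

    Ok(())
}
```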
return self.value.serialize(serializer); - } + return self.value.serialize(serializer); + } - self.relation.serialize(serializer) - } + self.relation.serialize(serializer) + } } #[cfg(feature = "schema")] impl schemars::JsonSchema - for HasOne + for HasOne { - fn schema_name() -> String { - >::schema_name() - } + fn schema_name() -> String { + >::schema_name() + } - fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { - gen.subschema_for::>() - } + fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { + gen.subschema_for::>() + } } diff --git a/ensemble/src/relationships/mod.rs b/ensemble/src/relationships/mod.rs index a9bcef2..4aed929 100644 --- a/ensemble/src/relationships/mod.rs +++ b/ensemble/src/relationships/mod.rs @@ -6,8 +6,11 @@ mod belongs_to_many; mod has_many; mod has_one; -use std::ops::Deref; -use std::{collections::HashMap, ops::DerefMut}; +use std::{ + collections::HashMap, + future::Future, + ops::{Deref, DerefMut}, +}; use crate::{query::Builder, value, Error, Model}; @@ -18,117 +21,116 @@ pub use has_one::HasOne; use rbs::Value; /// A relationship between two models. -#[async_trait::async_trait] pub trait Relationship { - /// The provided input for the relationship. - type RelatedKey; - - /// The type of the primary key for the model. - type Key; - - /// The return type of the relationship. - type Value; - - /// Get the related model. - /// - /// # Errors - /// - /// Returns an error if the model cannot be retrieved, or if a connection to the database cannot be established. - async fn get(&mut self) -> Result<&mut Self::Value, Error>; - - /// Whether the relationship has been loaded. - fn is_loaded(&self) -> bool; - - /// Get the query builder for the relationship. - /// - /// # Errors - /// - /// Returns an error if serialization fails when building the query. - fn query(&self) -> Builder; - - #[doc(hidden)] - /// Get the query builder for eager loading the relationship. Not intended to be used directly. - fn eager_query(&self, related: Vec) -> Builder; - - #[doc(hidden)] - /// Match the eagerly loaded results to their parents. Not intended to be used directly. - fn r#match(&mut self, related: &[HashMap]) -> Result<(), Error>; - - #[doc(hidden)] - /// Create an instance of the relationship. Not intended to be used directly. - fn build(value: Self::Key, related_key: Self::RelatedKey) -> Self; + /// The provided input for the relationship. + type RelatedKey; + + /// The type of the primary key for the model. + type Key; + + /// The return type of the relationship. + type Value; + + /// Get the related model. + /// + /// # Errors + /// + /// Returns an error if the model cannot be retrieved, or if a connection to the database cannot be established. + fn get(&mut self) -> impl Future> + Send; + + /// Whether the relationship has been loaded. + fn is_loaded(&self) -> bool; + + /// Get the query builder for the relationship. + /// + /// # Errors + /// + /// Returns an error if serialization fails when building the query. + fn query(&self) -> Builder; + + #[doc(hidden)] + /// Get the query builder for eager loading the relationship. Not intended to be used directly. + fn eager_query(&self, related: Vec) -> Builder; + + #[doc(hidden)] + /// Match the eagerly loaded results to their parents. Not intended to be used directly. + fn r#match(&mut self, related: &[HashMap]) -> Result<(), Error>; + + #[doc(hidden)] + /// Create an instance of the relationship. Not intended to be used directly. 
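The trait change above replaces `#[async_trait::async_trait]` with a return-position `impl Future<Output = ...> + Send` (stable since Rust 1.75), so implementors keep writing plain `async fn` and callers keep writing `.await`, just without the boxed futures. A minimal standalone sketch of the same pattern using a toy `Fetch` trait that is not part of ensemble:

```rust
use std::future::Future;

// Same shape as `Relationship::get` after this patch: the trait declares the
// future type (including the `Send` bound) instead of relying on `#[async_trait]`.
trait Fetch {
    type Value;

    fn get(&mut self) -> impl Future<Output = Result<&mut Self::Value, String>> + Send;
}

struct Lazy {
    cached: Option<u64>,
}

impl Fetch for Lazy {
    type Value = u64;

    // A plain `async fn` satisfies the declared signature as long as its future
    // is `Send`, which is exactly how `HasMany`/`HasOne` implement `get` above.
    async fn get(&mut self) -> Result<&mut u64, String> {
        if self.cached.is_none() {
            self.cached = Some(42); // pretend this was a database round-trip
        }
        Ok(self.cached.as_mut().unwrap())
    }
}

// Call sites are unchanged: `lazy.get().await?` as before.
```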
+ fn build(value: Self::Key, related_key: Self::RelatedKey) -> Self; } #[derive(Debug, Clone, PartialEq, Eq)] enum Status { - Initial(Option), - Fetched(Option), + Initial(Option), + Fetched(Option), } impl Status { - const fn initial() -> Self { - Self::Initial(None) - } - - const fn is_loaded(&self) -> bool { - match self { - Self::Initial(_) => false, - Self::Fetched(_) => true, - } - } + const fn initial() -> Self { + Self::Initial(None) + } + + const fn is_loaded(&self) -> bool { + match self { + Self::Initial(_) => false, + Self::Fetched(_) => true, + } + } } impl Default for Status { - fn default() -> Self { - Self::initial() - } + fn default() -> Self { + Self::initial() + } } impl Deref for Status { - type Target = Option; + type Target = Option; - fn deref(&self) -> &Self::Target { - match self { - Self::Initial(value) | Self::Fetched(value) => value, - } - } + fn deref(&self) -> &Self::Target { + match self { + Self::Initial(value) | Self::Fetched(value) => value, + } + } } impl DerefMut for Status { - fn deref_mut(&mut self) -> &mut Self::Target { - match self { - Self::Initial(value) | Self::Fetched(value) => value, - } - } + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + Self::Initial(value) | Self::Fetched(value) => value, + } + } } impl serde::Serialize for Status { - fn serialize(&self, serializer: S) -> Result { - match self { - Self::Initial(_) | Self::Fetched(None) => serializer.serialize_none(), - Self::Fetched(Some(ref value)) => value.serialize(serializer), - } - } + fn serialize(&self, serializer: S) -> Result { + match self { + Self::Initial(_) | Self::Fetched(None) => serializer.serialize_none(), + Self::Fetched(Some(ref value)) => value.serialize(serializer), + } + } } fn find_related( - related: &[HashMap], - foreign_key: &str, - value: T, - wants_one: bool, + related: &[HashMap], + foreign_key: &str, + value: T, + wants_one: bool, ) -> Result, Error> { - let value = value::for_db(value)?; - - let related = related - .iter() - .filter(|model| { - model - .get(foreign_key) - .is_some_and(|v| v.to_string() == value.to_string()) - }) - .take(if wants_one { 1 } else { usize::MAX }) - .map(|model| value::from::(value::for_db(model).unwrap())) - .collect::, _>>()?; - - Ok(related) + let value = value::for_db(value)?; + + let related = related + .iter() + .filter(|model| { + model + .get(foreign_key) + .is_some_and(|v| v.to_string() == value.to_string()) + }) + .take(if wants_one { 1 } else { usize::MAX }) + .map(|model| value::from::(value::for_db(model).unwrap())) + .collect::, _>>()?; + + Ok(related) } diff --git a/ensemble/src/types/datetime.rs b/ensemble/src/types/datetime.rs index d02ca26..ca3f46c 100644 --- a/ensemble/src/types/datetime.rs +++ b/ensemble/src/types/datetime.rs @@ -1,10 +1,11 @@ use rbs::Value; -use serde::de::Error; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use std::fmt::{Debug, Display, Formatter}; -use std::ops::{Add, Deref, DerefMut, Sub}; -use std::str::FromStr; -use std::time::{Duration, SystemTime}; +use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer}; +use std::{ + fmt::{Debug, Display, Formatter}, + ops::{Add, Deref, DerefMut, Sub}, + str::FromStr, + time::{Duration, SystemTime}, +}; /// A date and time value, used for storing timestamps in the database. 
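Referring back to the `Status` enum earlier in `relationships/mod.rs`: it is a private two-state cache that derefs to the inner `Option`, which is what lets `get()` distinguish "never fetched" from "fetched and empty". A test-style sketch that could sit next to the enum; the generic parameter is reconstructed here since the type is crate-internal.

```rust
#[test]
fn status_loads_once() {
    // Fresh relations start out unloaded and empty.
    let mut status: Status<Vec<u64>> = Status::initial();
    assert!(!status.is_loaded());
    assert!(status.is_none()); // via Deref<Target = Option<Vec<u64>>>

    // After a fetch, `is_loaded` flips even though the shape is still an Option.
    status = Status::Fetched(Some(vec![1, 2, 3]));
    assert!(status.is_loaded());
    assert_eq!(status.as_ref().map(|v| v.len()), Some(3));
}
```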
#[derive(Clone, Eq, PartialEq, Hash)] @@ -12,242 +13,242 @@ use std::time::{Duration, SystemTime}; pub struct DateTime(pub fastdate::DateTime); impl Display for DateTime { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "DateTime({})", self.0) - } + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "DateTime({})", self.0) + } } impl Serialize for DateTime { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - serializer.serialize_newtype_struct("DateTime", &self.0) - } + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_newtype_struct("DateTime", &self.0) + } } impl Debug for DateTime { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "DateTime({})", self.0) - } + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "DateTime({})", self.0) + } } impl<'de> Deserialize<'de> for DateTime { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let v = Value::deserialize(deserializer)?; - - match v { - Value::I32(u) => Ok(Self(fastdate::DateTime::from_timestamp_millis(i64::from( - u, - )))), - Value::U32(u) => Ok(Self(fastdate::DateTime::from_timestamp_millis(i64::from( - u, - )))), - Value::I64(u) => Ok(Self(fastdate::DateTime::from_timestamp_millis(u))), - Value::U64(u) => Ok(Self(fastdate::DateTime::from_timestamp_millis( - i64::try_from(u).map_err(|e| D::Error::custom(e.to_string()))?, - ))), - Value::String(s) => Ok({ - Self( - fastdate::DateTime::from_str(&s) - .map_err(|e| D::Error::custom(e.to_string()))?, - ) - }), - _ => Err(D::Error::custom( - &format!("unsupported type DateTime({v})",), - )), - } - } + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let v = Value::deserialize(deserializer)?; + + match v { + Value::I32(u) => Ok(Self(fastdate::DateTime::from_timestamp_millis(i64::from( + u, + )))), + Value::U32(u) => Ok(Self(fastdate::DateTime::from_timestamp_millis(i64::from( + u, + )))), + Value::I64(u) => Ok(Self(fastdate::DateTime::from_timestamp_millis(u))), + Value::U64(u) => Ok(Self(fastdate::DateTime::from_timestamp_millis( + i64::try_from(u).map_err(|e| D::Error::custom(e.to_string()))?, + ))), + Value::String(s) => Ok({ + Self( + fastdate::DateTime::from_str(&s) + .map_err(|e| D::Error::custom(e.to_string()))?, + ) + }), + _ => Err(D::Error::custom( + &format!("unsupported type DateTime({v})",), + )), + } + } } impl Deref for DateTime { - type Target = fastdate::DateTime; + type Target = fastdate::DateTime; - fn deref(&self) -> &Self::Target { - &self.0 - } + fn deref(&self) -> &Self::Target { + &self.0 + } } impl DerefMut for DateTime { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } } impl DateTime { - #[must_use] - pub fn now() -> Self { - Self(fastdate::DateTime::now()) - } - - #[must_use] - pub fn utc() -> Self { - Self(fastdate::DateTime::utc()) - } - - #[must_use] - pub fn from_timestamp(sec: i64) -> Self { - Self(fastdate::DateTime::from_timestamp(sec)) - } - - #[must_use] - pub fn from_timestamp_millis(ms: i64) -> Self { - Self(fastdate::DateTime::from_timestamp_millis(ms)) - } - - #[must_use] - pub fn from_timestamp_nano(nano: i128) -> Self { - Self(fastdate::DateTime::from_timestamp_nano(nano)) - } - - #[must_use] - pub fn from_system_time(s: SystemTime, offset: i32) -> Self { - Self(fastdate::DateTime::from_system_time(s, offset)) - } + #[must_use] + pub fn now() 
-> Self { + Self(fastdate::DateTime::now()) + } + + #[must_use] + pub fn utc() -> Self { + Self(fastdate::DateTime::utc()) + } + + #[must_use] + pub fn from_timestamp(sec: i64) -> Self { + Self(fastdate::DateTime::from_timestamp(sec)) + } + + #[must_use] + pub fn from_timestamp_millis(ms: i64) -> Self { + Self(fastdate::DateTime::from_timestamp_millis(ms)) + } + + #[must_use] + pub fn from_timestamp_nano(nano: i128) -> Self { + Self(fastdate::DateTime::from_timestamp_nano(nano)) + } + + #[must_use] + pub fn from_system_time(s: SystemTime, offset: i32) -> Self { + Self(fastdate::DateTime::from_system_time(s, offset)) + } } impl Sub for DateTime { - type Output = Duration; + type Output = Duration; - fn sub(self, rhs: Self) -> Self::Output { - self.0 - rhs.0 - } + fn sub(self, rhs: Self) -> Self::Output { + self.0 - rhs.0 + } } impl Add for DateTime { - type Output = Self; + type Output = Self; - fn add(self, rhs: Duration) -> Self::Output { - Self(self.0.add(rhs)) - } + fn add(self, rhs: Duration) -> Self::Output { + Self(self.0.add(rhs)) + } } impl Sub for DateTime { - type Output = Self; + type Output = Self; - fn sub(self, rhs: Duration) -> Self::Output { - Self(self.0.sub(rhs)) - } + fn sub(self, rhs: Duration) -> Self::Output { + Self(self.0.sub(rhs)) + } } impl FromStr for DateTime { - type Err = rbs::Error; + type Err = rbs::Error; - fn from_str(s: &str) -> Result { - Ok(Self( - fastdate::DateTime::from_str(s).map_err(|e| rbs::Error::Syntax(e.to_string()))?, - )) - } + fn from_str(s: &str) -> Result { + Ok(Self( + fastdate::DateTime::from_str(s).map_err(|e| rbs::Error::Syntax(e.to_string()))?, + )) + } } impl Default for DateTime { - fn default() -> Self { - Self(fastdate::DateTime::from_timestamp(0)) - } + fn default() -> Self { + Self(fastdate::DateTime::from_timestamp(0)) + } } impl From for Value { - fn from(arg: DateTime) -> Self { - Self::Ext("DateTime", Box::new(Self::String(arg.0.to_string()))) - } + fn from(arg: DateTime) -> Self { + Self::Ext("DateTime", Box::new(Self::String(arg.0.to_string()))) + } } #[cfg(feature = "schema")] impl schemars::JsonSchema for DateTime { - fn is_referenceable() -> bool { - false - } - - fn schema_name() -> String { - String::from("date-time") - } - - fn json_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { - schemars::schema::SchemaObject { - instance_type: Some(schemars::schema::InstanceType::String.into()), - format: Some("date-time".to_owned()), - ..Default::default() - } - .into() - } + fn is_referenceable() -> bool { + false + } + + fn schema_name() -> String { + String::from("date-time") + } + + fn json_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { + schemars::schema::SchemaObject { + instance_type: Some(schemars::schema::InstanceType::String.into()), + format: Some("date-time".to_owned()), + ..Default::default() + } + .into() + } } #[cfg(test)] mod test { - use super::*; - use std::str::FromStr; - - #[test] - fn test_ser_de() { - let dt = DateTime::now(); - let v = serde_json::to_value(&dt).unwrap(); - let new_dt: DateTime = serde_json::from_value(v).unwrap(); - assert_eq!(new_dt, dt); - } - - #[test] - fn test_de() { - let dt = DateTime::from_str("2023-10-21T00:15:00.9233333+08:00").unwrap(); - - let v = serde_json::to_value(&dt).unwrap(); - let new_dt: DateTime = serde_json::from_value(v).unwrap(); - assert_eq!(new_dt, dt); - } - - #[test] - fn test_de2() { - let dt = vec![DateTime::from_str("2023-10-21T00:15:00.9233333+08:00").unwrap()]; - let v = 
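Taken together, the `DateTime` impls above give a thin wrapper over `fastdate::DateTime` with an epoch `Default`, `FromStr` parsing, and `Duration` arithmetic. A small usage sketch, assuming the type is re-exported as `ensemble::types::DateTime` the same way `types::Hashed` is in the doc examples:

```rust
use ensemble::types::DateTime;
use std::{str::FromStr, time::Duration};

fn main() {
    // `Default` is the Unix epoch, and `from_timestamp` takes seconds.
    let epoch = DateTime::from_timestamp(0);
    assert_eq!(epoch, DateTime::default());

    // Strings parse through `FromStr`, backed by fastdate.
    let _parsed = DateTime::from_str("2023-10-21T00:15:00.9233333+08:00").unwrap();

    // `Add`/`Sub` work with `Duration`, and `DateTime - DateTime` yields a `Duration`.
    let later = epoch.clone() + Duration::from_secs(60);
    assert_eq!(later - epoch, Duration::from_secs(60));

    let _stamp = DateTime::now();
}
```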
serde_json::to_value(&dt).unwrap(); - - let new_dt: Vec = serde_json::from_value(v).unwrap(); - assert_eq!(new_dt, dt); - } - - #[test] - fn test_de3() { - let dt = vec![DateTime::from_str("2023-10-21T00:15:00.9233333+08:00").unwrap()]; - let v = rbs::to_value!(&dt); - let new_dt: Vec = rbs::from_value(v).unwrap(); - assert_eq!(new_dt, dt); - } - - #[test] - fn test_de4() { - let dt = DateTime::from_str("2023-10-21T00:15:00.9233333+08:00").unwrap(); - let v = rbs::to_value!(&dt.unix_timestamp_millis()); - let new_dt: DateTime = rbs::from_value(v).unwrap(); - assert_eq!( - new_dt, - DateTime::from_str("2023-10-20T16:15:00.923Z").unwrap() - ); - } - - #[test] - fn test_de5() { - let dt = DateTime::from_str("2023-10-21T00:15:00.9233333+08:00").unwrap(); - let v = serde_json::to_value(dt.unix_timestamp_millis()).unwrap(); - let new_dt: DateTime = serde_json::from_value(v).unwrap(); - assert_eq!( - new_dt, - DateTime::from_str("2023-10-20T16:15:00.923Z").unwrap() - ); - } - - #[test] - fn test_default() { - let dt = DateTime::default(); - - assert_eq!(dt.to_string(), "DateTime(1970-01-01T00:00:00Z)"); - } - - #[test] - fn test_format() { - let dt = DateTime::default(); - let s = dt.format("YYYY-MM-DD/hh/mm/ss"); - - assert_eq!(s, "1970-1-1/0/0/0"); - } + use super::*; + use std::str::FromStr; + + #[test] + fn test_ser_de() { + let dt = DateTime::now(); + let v = serde_json::to_value(&dt).unwrap(); + let new_dt: DateTime = serde_json::from_value(v).unwrap(); + assert_eq!(new_dt, dt); + } + + #[test] + fn test_de() { + let dt = DateTime::from_str("2023-10-21T00:15:00.9233333+08:00").unwrap(); + + let v = serde_json::to_value(&dt).unwrap(); + let new_dt: DateTime = serde_json::from_value(v).unwrap(); + assert_eq!(new_dt, dt); + } + + #[test] + fn test_de2() { + let dt = vec![DateTime::from_str("2023-10-21T00:15:00.9233333+08:00").unwrap()]; + let v = serde_json::to_value(&dt).unwrap(); + + let new_dt: Vec = serde_json::from_value(v).unwrap(); + assert_eq!(new_dt, dt); + } + + #[test] + fn test_de3() { + let dt = vec![DateTime::from_str("2023-10-21T00:15:00.9233333+08:00").unwrap()]; + let v = rbs::to_value!(&dt); + let new_dt: Vec = rbs::from_value(v).unwrap(); + assert_eq!(new_dt, dt); + } + + #[test] + fn test_de4() { + let dt = DateTime::from_str("2023-10-21T00:15:00.9233333+08:00").unwrap(); + let v = rbs::to_value!(&dt.unix_timestamp_millis()); + let new_dt: DateTime = rbs::from_value(v).unwrap(); + assert_eq!( + new_dt, + DateTime::from_str("2023-10-20T16:15:00.923Z").unwrap() + ); + } + + #[test] + fn test_de5() { + let dt = DateTime::from_str("2023-10-21T00:15:00.9233333+08:00").unwrap(); + let v = serde_json::to_value(dt.unix_timestamp_millis()).unwrap(); + let new_dt: DateTime = serde_json::from_value(v).unwrap(); + assert_eq!( + new_dt, + DateTime::from_str("2023-10-20T16:15:00.923Z").unwrap() + ); + } + + #[test] + fn test_default() { + let dt = DateTime::default(); + + assert_eq!(dt.to_string(), "DateTime(1970-01-01T00:00:00Z)"); + } + + #[test] + fn test_format() { + let dt = DateTime::default(); + let s = dt.format("YYYY-MM-DD/hh/mm/ss"); + + assert_eq!(s, "1970-1-1/0/0/0"); + } } diff --git a/ensemble/src/types/hashed.rs b/ensemble/src/types/hashed.rs index 2126dce..e69d508 100644 --- a/ensemble/src/types/hashed.rs +++ b/ensemble/src/types/hashed.rs @@ -5,94 +5,94 @@ use std::{fmt::Debug, ops::Deref}; /// A wrapper around a value that has been hashed with SHA-256. 
#[derive(Clone, Eq, Default)] pub struct Hashed { - hash: String, - _marker: std::marker::PhantomData, + hash: String, + _marker: std::marker::PhantomData, } impl Hashed { - /// Create a new `Hashed` value from the given value. - /// - /// # Example - /// - /// ``` - /// # use ensemble::types::Hashed; - /// let hashed = Hashed::new("hello world"); - /// # assert_eq!(hashed, "hello world") - /// ``` - pub fn new(value: T) -> Self { - Self { - hash: digest(value), - _marker: std::marker::PhantomData, - } - } + /// Create a new `Hashed` value from the given value. + /// + /// # Example + /// + /// ``` + /// # use ensemble::types::Hashed; + /// let hashed = Hashed::new("hello world"); + /// # assert_eq!(hashed, "hello world") + /// ``` + pub fn new(value: T) -> Self { + Self { + hash: digest(value), + _marker: std::marker::PhantomData, + } + } } impl Deref for Hashed { - type Target = String; + type Target = String; - fn deref(&self) -> &Self::Target { - &self.hash - } + fn deref(&self) -> &Self::Target { + &self.hash + } } impl From for Hashed { - fn from(value: T) -> Self { - Self::new(value) - } + fn from(value: T) -> Self { + Self::new(value) + } } impl From> for String { - fn from(val: Hashed) -> Self { - val.hash - } + fn from(val: Hashed) -> Self { + val.hash + } } impl Debug for Hashed { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.hash.fmt(f) - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.hash.fmt(f) + } } impl PartialEq for Hashed { - fn eq(&self, other: &Self) -> bool { - self.hash == other.hash - } + fn eq(&self, other: &Self) -> bool { + self.hash == other.hash + } } impl PartialEq for Hashed { - fn eq(&self, other: &String) -> bool { - self.hash == digest(other) - } + fn eq(&self, other: &String) -> bool { + self.hash == digest(other) + } } impl PartialEq<&str> for Hashed { - fn eq(&self, other: &&str) -> bool { - self.hash == digest(*other) - } + fn eq(&self, other: &&str) -> bool { + self.hash == digest(*other) + } } impl Serialize for Hashed { - fn serialize(&self, serializer: S) -> Result { - self.hash.serialize(serializer) - } + fn serialize(&self, serializer: S) -> Result { + self.hash.serialize(serializer) + } } impl<'de, T: Sha256Digest> Deserialize<'de> for Hashed { - fn deserialize>(deserializer: D) -> Result { - Ok(Self { - _marker: std::marker::PhantomData, - hash: String::deserialize(deserializer)?, - }) - } + fn deserialize>(deserializer: D) -> Result { + Ok(Self { + _marker: std::marker::PhantomData, + hash: String::deserialize(deserializer)?, + }) + } } #[cfg(feature = "schema")] impl schemars::JsonSchema for Hashed { - fn schema_name() -> String { - String::schema_name() - } + fn schema_name() -> String { + String::schema_name() + } - fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { - gen.subschema_for::() - } + fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { + gen.subschema_for::() + } } diff --git a/ensemble/src/types/json.rs b/ensemble/src/types/json.rs index 2d93961..714d83d 100644 --- a/ensemble/src/types/json.rs +++ b/ensemble/src/types/json.rs @@ -2,8 +2,8 @@ use schemars::JsonSchema; use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize}; use serde_json::Value; use std::{ - ops::{Deref, DerefMut}, - str::FromStr, + ops::{Deref, DerefMut}, + str::FromStr, }; #[derive(Clone, Eq, PartialEq, Hash, Debug)] @@ -12,101 +12,101 @@ pub struct Json(pub T); #[allow(clippy::module_name_repetitions)] pub 
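A short usage sketch for `Hashed` as defined above: the constructor digests immediately, equality against plaintext re-hashes the right-hand side, and `Deref` exposes the hex digest (64 characters for SHA-256). The password literal is made up.

```rust
use ensemble::types::Hashed;

fn main() {
    // The plaintext is hashed on construction and never stored.
    let password = Hashed::new("hunter2");

    // `PartialEq<&str>` / `PartialEq<String>` hash the right-hand side first,
    // so comparisons against the plaintext work...
    assert!(password == "hunter2");
    assert!(password != "wrong-guess");

    // ...while `Deref<Target = String>` exposes the hex digest itself.
    assert_eq!(password.len(), 64); // SHA-256 rendered as hex
}
```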
trait ToJson { - type Target: Serialize + DeserializeOwned; + type Target: Serialize + DeserializeOwned; - fn to_json(self) -> Json; + fn to_json(self) -> Json; } impl FromStr for Json { - type Err = serde_json::Error; + type Err = serde_json::Error; - fn from_str(s: &str) -> Result { - Ok(Self(serde_json::from_str(s)?)) - } + fn from_str(s: &str) -> Result { + Ok(Self(serde_json::from_str(s)?)) + } } impl From for Json { - fn from(value: Value) -> Self { - Self(value) - } + fn from(value: Value) -> Self { + Self(value) + } } impl ToJson for T { - type Target = T; + type Target = T; - fn to_json(self) -> Json { - Json(self) - } + fn to_json(self) -> Json { + Json(self) + } } impl Serialize for Json { - fn serialize(&self, serializer: S) -> Result { - use serde::ser::Error; - if std::any::type_name::() == std::any::type_name::() { - serializer.serialize_newtype_struct( - "Json", - &serde_json::to_string(&self.0).map_err(|e| Error::custom(e.to_string()))?, - ) - } else { - self.0.serialize(serializer) - } - } + fn serialize(&self, serializer: S) -> Result { + use serde::ser::Error; + if std::any::type_name::() == std::any::type_name::() { + serializer.serialize_newtype_struct( + "Json", + &serde_json::to_string(&self.0).map_err(|e| Error::custom(e.to_string()))?, + ) + } else { + self.0.serialize(serializer) + } + } } impl<'de, T: Serialize + DeserializeOwned> Deserialize<'de> for Json { - fn deserialize>(deserializer: D) -> Result { - use serde::de::Error; - if std::any::type_name::() == std::any::type_name::() { - let mut v = rbs::Value::deserialize(deserializer)?; - if let rbs::Value::Ext(_ty, buf) = v { - v = *buf; - } - - let js; - if let rbs::Value::Binary(buf) = v { - js = String::from_utf8(buf).map_err(|e| Error::custom(e.to_string()))?; - } else if let rbs::Value::String(buf) = v { - js = buf; - } else { - js = v.to_string(); - } - - Ok(Self( - serde_json::from_str(&js).map_err(|e| Error::custom(e.to_string()))?, - )) - } else { - Ok(Self(T::deserialize(deserializer)?)) - } - } + fn deserialize>(deserializer: D) -> Result { + use serde::de::Error; + if std::any::type_name::() == std::any::type_name::() { + let mut v = rbs::Value::deserialize(deserializer)?; + if let rbs::Value::Ext(_ty, buf) = v { + v = *buf; + } + + let js; + if let rbs::Value::Binary(buf) = v { + js = String::from_utf8(buf).map_err(|e| Error::custom(e.to_string()))?; + } else if let rbs::Value::String(buf) = v { + js = buf; + } else { + js = v.to_string(); + } + + Ok(Self( + serde_json::from_str(&js).map_err(|e| Error::custom(e.to_string()))?, + )) + } else { + Ok(Self(T::deserialize(deserializer)?)) + } + } } impl Deref for Json { - type Target = T; + type Target = T; - fn deref(&self) -> &Self::Target { - &self.0 - } + fn deref(&self) -> &Self::Target { + &self.0 + } } impl DerefMut for Json { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } } impl Default for Json { - fn default() -> Self { - Self(T::default()) - } + fn default() -> Self { + Self(T::default()) + } } #[cfg(feature = "schema")] impl schemars::JsonSchema for Json { - fn schema_name() -> String { - T::schema_name() - } + fn schema_name() -> String { + T::schema_name() + } - fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { - T::json_schema(gen) - } + fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { + T::json_schema(gen) + } } diff --git a/ensemble/src/types/uuid.rs 
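`Json<T>` plus the blanket `ToJson` impl above wrap any `Serialize + DeserializeOwned` value for a JSON column, while serializing as the plain inner value whenever the serializer is not the database one. A sketch, assuming `Json` and `ToJson` are re-exported from `ensemble::types` like the other types in this patch; `serde_json` is used here only to show the non-database path.

```rust
use ensemble::types::{Json, ToJson};
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
struct Metadata {
    tags: Vec<String>,
}

fn main() {
    // Any `Serialize + DeserializeOwned` value can be wrapped for a JSON column.
    let meta: Json<Metadata> = Metadata { tags: vec!["rust".into()] }.to_json();

    // `Deref`/`DerefMut` pass straight through to the inner value.
    assert_eq!(meta.tags.len(), 1);

    // Outside the database path, it serializes as the inner value itself.
    let as_string = serde_json::to_string(&meta).unwrap();
    assert_eq!(as_string, r#"{"tags":["rust"]}"#);
}
```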
b/ensemble/src/types/uuid.rs index 6891da8..fa67a7e 100644 --- a/ensemble/src/types/uuid.rs +++ b/ensemble/src/types/uuid.rs @@ -1,7 +1,7 @@ use std::{ - fmt::{Debug, Display, Formatter}, - ops::{Deref, DerefMut}, - str::FromStr, + fmt::{Debug, Display, Formatter}, + ops::{Deref, DerefMut}, + str::FromStr, }; use rbs::Value; @@ -12,73 +12,73 @@ use serde::Deserializer; pub struct Uuid(uuid::Uuid); impl<'de> serde::Deserialize<'de> for Uuid { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - Ok(Self(uuid::Uuid::deserialize(deserializer)?)) - } + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + Ok(Self(uuid::Uuid::deserialize(deserializer)?)) + } } impl Display for Uuid { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } } impl Debug for Uuid { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "Uuid({})", self.0) - } + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Uuid({})", self.0) + } } impl From for Value { - fn from(uuid: Uuid) -> Self { - Self::Ext("Uuid", Box::new(Self::String(uuid.0.to_string()))) - } + fn from(uuid: Uuid) -> Self { + Self::Ext("Uuid", Box::new(Self::String(uuid.0.to_string()))) + } } impl FromStr for Uuid { - type Err = uuid::Error; + type Err = uuid::Error; - fn from_str(s: &str) -> Result { - Ok(Self(s.parse()?)) - } + fn from_str(s: &str) -> Result { + Ok(Self(s.parse()?)) + } } impl Uuid { - #[must_use] - pub fn new() -> Self { - Self(uuid::Uuid::new_v4()) - } + #[must_use] + pub fn new() -> Self { + Self(uuid::Uuid::new_v4()) + } - #[must_use] - pub const fn nil() -> Self { - Self(uuid::Uuid::nil()) - } + #[must_use] + pub const fn nil() -> Self { + Self(uuid::Uuid::nil()) + } } impl Deref for Uuid { - type Target = uuid::Uuid; + type Target = uuid::Uuid; - fn deref(&self) -> &Self::Target { - &self.0 - } + fn deref(&self) -> &Self::Target { + &self.0 + } } impl DerefMut for Uuid { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } } #[cfg(feature = "schema")] impl schemars::JsonSchema for Uuid { - fn schema_name() -> String { - uuid::Uuid::schema_name() - } + fn schema_name() -> String { + uuid::Uuid::schema_name() + } - fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { - uuid::Uuid::json_schema(gen) - } + fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { + uuid::Uuid::json_schema(gen) + } } diff --git a/ensemble/src/value/de.rs b/ensemble/src/value/de.rs index 845a0ad..6634eaf 100644 --- a/ensemble/src/value/de.rs +++ b/ensemble/src/value/de.rs @@ -1,667 +1,667 @@ use std::{ - fmt::{self, Debug}, - vec::IntoIter, + fmt::{self, Debug}, + vec::IntoIter, }; use rbs::{value::map::ValueMap, Value}; use serde::{ - de::{self, IntoDeserializer, Unexpected, Visitor}, - forward_to_deserialize_any, Deserialize, Deserializer, + de::{self, IntoDeserializer, Unexpected, Visitor}, + forward_to_deserialize_any, Deserialize, Deserializer, }; #[inline] pub fn deserialize_value<'de, T: Deserialize<'de>>(val: rbs::Value) -> Result { - Deserialize::deserialize(ValueDeserializer(val)) + Deserialize::deserialize(ValueDeserializer(val)) } #[repr(transparent)] struct ValueDeserializer(rbs::Value); trait ValueBase<'de>: Deserializer<'de, Error = rbs::Error> { - type Item: ValueBase<'de>; - 
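For completeness, a usage sketch of the `Uuid` newtype from `types/uuid.rs` above; whether it sits behind a cargo feature is not visible in this hunk, so treat the import path as an assumption.

```rust
use ensemble::types::Uuid;
use std::str::FromStr;

fn main() {
    // `new()` generates a random v4 UUID; `nil()` is the all-zero constant.
    let id = Uuid::new();
    println!("{id}");              // `Display` prints the bare UUID
    println!("{:?}", Uuid::nil()); // `Debug` prints `Uuid(0000...)`

    // `FromStr` parses back, and `Deref` exposes the inner `uuid::Uuid` API.
    let parsed = Uuid::from_str(&id.to_string()).unwrap();
    assert_eq!(parsed.get_version_num(), 4);
}
```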
type MapDeserializer: Deserializer<'de>; - type Iter: ExactSizeIterator; - type MapIter: Iterator; + type Item: ValueBase<'de>; + type MapDeserializer: Deserializer<'de>; + type Iter: ExactSizeIterator; + type MapIter: Iterator; - fn is_null(&self) -> bool; - fn unexpected(&self) -> Unexpected<'_>; + fn is_null(&self) -> bool; + fn unexpected(&self) -> Unexpected<'_>; - fn into_iter(self) -> Result; - fn into_map_iter(self) -> Result; + fn into_iter(self) -> Result; + fn into_map_iter(self) -> Result; } impl<'de> ValueBase<'de> for Value { - type Item = ValueDeserializer; - type Iter = IntoIter; - type MapIter = IntoIter<(Self::Item, Self::Item)>; - type MapDeserializer = MapDeserializer; - - #[inline] - fn is_null(&self) -> bool { - matches!(self, Self::Null) - } - - #[inline] - fn into_iter(self) -> Result { - match self { - Self::Array(v) => Ok(v - .into_iter() - .map(ValueDeserializer) - .collect::>() - .into_iter()), - other => Err(other.into()), - } - } - - #[inline] - fn into_map_iter(self) -> Result { - match self { - Self::Map(v) => Ok(v - .0 - .into_iter() - .map(|(k, v)| (ValueDeserializer(k), ValueDeserializer(v))) - .collect::>() - .into_iter()), - other => Err(other.into()), - } - } - - #[cold] - fn unexpected(&self) -> Unexpected<'_> { - match *self { - Self::Null => Unexpected::Unit, - Self::Map(..) => Unexpected::Map, - Self::F64(v) => Unexpected::Float(v), - Self::Bool(v) => Unexpected::Bool(v), - Self::I64(v) => Unexpected::Signed(v), - Self::U64(v) => Unexpected::Unsigned(v), - Self::Ext(..) | Self::Array(..) => Unexpected::Seq, - Self::F32(v) => Unexpected::Float(f64::from(v)), - Self::I32(v) => Unexpected::Signed(i64::from(v)), - Self::Binary(ref v) => Unexpected::Bytes(v), - Self::U32(v) => Unexpected::Unsigned(u64::from(v)), - Self::String(ref v) => Unexpected::Bytes(v.as_bytes()), - } - } + type Item = ValueDeserializer; + type Iter = IntoIter; + type MapIter = IntoIter<(Self::Item, Self::Item)>; + type MapDeserializer = MapDeserializer; + + #[inline] + fn is_null(&self) -> bool { + matches!(self, Self::Null) + } + + #[inline] + fn into_iter(self) -> Result { + match self { + Self::Array(v) => Ok(v + .into_iter() + .map(ValueDeserializer) + .collect::>() + .into_iter()), + other => Err(other.into()), + } + } + + #[inline] + fn into_map_iter(self) -> Result { + match self { + Self::Map(v) => Ok(v + .0 + .into_iter() + .map(|(k, v)| (ValueDeserializer(k), ValueDeserializer(v))) + .collect::>() + .into_iter()), + other => Err(other.into()), + } + } + + #[cold] + fn unexpected(&self) -> Unexpected<'_> { + match *self { + Self::Null => Unexpected::Unit, + Self::Map(..) => Unexpected::Map, + Self::F64(v) => Unexpected::Float(v), + Self::Bool(v) => Unexpected::Bool(v), + Self::I64(v) => Unexpected::Signed(v), + Self::U64(v) => Unexpected::Unsigned(v), + Self::Ext(..) | Self::Array(..) 
=> Unexpected::Seq, + Self::F32(v) => Unexpected::Float(f64::from(v)), + Self::I32(v) => Unexpected::Signed(i64::from(v)), + Self::Binary(ref v) => Unexpected::Bytes(v), + Self::U32(v) => Unexpected::Unsigned(u64::from(v)), + Self::String(ref v) => Unexpected::Bytes(v.as_bytes()), + } + } } impl<'de> ValueBase<'de> for ValueDeserializer { - type Item = Self; - type Iter = IntoIter; - type MapIter = IntoIter<(Self::Item, Self::Item)>; - type MapDeserializer = MapDeserializer; - - #[inline] - fn is_null(&self) -> bool { - self.0.is_null() - } - - #[inline] - fn into_iter(self) -> Result { - match self.0 { - Value::Array(v) => Ok(v - .into_iter() - .map(ValueDeserializer) - .collect::>() - .into_iter()), - other => Err(other.into()), - } - } - - #[inline] - fn into_map_iter(self) -> Result { - match self.0 { - Value::Map(v) => Ok(v - .0 - .into_iter() - .map(|(k, v)| (Self(k), Self(v))) - .collect::>() - .into_iter()), - other => Err(other.into()), - } - } - - #[cold] - fn unexpected(&self) -> Unexpected<'_> { - match self.0 { - Value::Null => Unexpected::Unit, - Value::Map(..) => Unexpected::Map, - Value::F64(v) => Unexpected::Float(v), - Value::I64(v) => Unexpected::Signed(v), - Value::Bool(v) => Unexpected::Bool(v), - Value::U64(v) => Unexpected::Unsigned(v), - Value::Ext(..) | Value::Array(..) => Unexpected::Seq, - Value::F32(v) => Unexpected::Float(f64::from(v)), - Value::I32(v) => Unexpected::Signed(i64::from(v)), - Value::Binary(ref v) => Unexpected::Bytes(v), - Value::U32(v) => Unexpected::Unsigned(u64::from(v)), - Value::String(ref v) => Unexpected::Bytes(v.as_bytes()), - } - } + type Item = Self; + type Iter = IntoIter; + type MapIter = IntoIter<(Self::Item, Self::Item)>; + type MapDeserializer = MapDeserializer; + + #[inline] + fn is_null(&self) -> bool { + self.0.is_null() + } + + #[inline] + fn into_iter(self) -> Result { + match self.0 { + Value::Array(v) => Ok(v + .into_iter() + .map(ValueDeserializer) + .collect::>() + .into_iter()), + other => Err(other.into()), + } + } + + #[inline] + fn into_map_iter(self) -> Result { + match self.0 { + Value::Map(v) => Ok(v + .0 + .into_iter() + .map(|(k, v)| (Self(k), Self(v))) + .collect::>() + .into_iter()), + other => Err(other.into()), + } + } + + #[cold] + fn unexpected(&self) -> Unexpected<'_> { + match self.0 { + Value::Null => Unexpected::Unit, + Value::Map(..) => Unexpected::Map, + Value::F64(v) => Unexpected::Float(v), + Value::I64(v) => Unexpected::Signed(v), + Value::Bool(v) => Unexpected::Bool(v), + Value::U64(v) => Unexpected::Unsigned(v), + Value::Ext(..) | Value::Array(..) 
=> Unexpected::Seq, + Value::F32(v) => Unexpected::Float(f64::from(v)), + Value::I32(v) => Unexpected::Signed(i64::from(v)), + Value::Binary(ref v) => Unexpected::Bytes(v), + Value::U32(v) => Unexpected::Unsigned(u64::from(v)), + Value::String(ref v) => Unexpected::Bytes(v.as_bytes()), + } + } } impl From for ValueDeserializer { - #[inline] - fn from(value: Value) -> Self { - Self(value) - } + #[inline] + fn from(value: Value) -> Self { + Self(value) + } } impl From for Value { - #[inline] - fn from(value: ValueDeserializer) -> Self { - value.0 - } + #[inline] + fn from(value: ValueDeserializer) -> Self { + value.0 + } } impl<'de> Deserialize<'de> for ValueDeserializer { - #[inline] - #[allow(clippy::too_many_lines)] - fn deserialize(de: D) -> Result - where - D: de::Deserializer<'de>, - { - struct ValueVisitor; - - impl<'de> serde::de::Visitor<'de> for ValueVisitor { - type Value = ValueDeserializer; - - #[cold] - fn expecting(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - "any valid MessagePack value".fmt(fmt) - } - - #[inline] - fn visit_some(self, de: D) -> Result - where - D: de::Deserializer<'de>, - { - Deserialize::deserialize(de) - } - - #[inline] - fn visit_none(self) -> Result { - Ok(Value::Null.into()) - } - - #[inline] - fn visit_unit(self) -> Result { - Ok(Value::Null.into()) - } - - #[inline] - fn visit_bool(self, value: bool) -> Result { - Ok(Value::Bool(value).into()) - } - - fn visit_u32(self, v: u32) -> Result { - Ok(Value::U32(v).into()) - } - - #[inline] - fn visit_u64(self, value: u64) -> Result { - Ok(Value::U64(value).into()) - } - - fn visit_i32(self, v: i32) -> Result { - Ok(Value::I32(v).into()) - } - - #[inline] - fn visit_i64(self, value: i64) -> Result { - Ok(Value::I64(value).into()) - } - - #[inline] - fn visit_f32(self, value: f32) -> Result { - Ok(Value::F32(value).into()) - } - - #[inline] - fn visit_f64(self, value: f64) -> Result { - Ok(Value::F64(value).into()) - } - - #[inline] - fn visit_string(self, value: String) -> Result { - Ok(Value::String(value).into()) - } - - #[inline] - fn visit_str(self, value: &str) -> Result { - self.visit_string(String::from(value)) - } - - #[inline] - fn visit_seq>( - self, - mut visitor: V, - ) -> Result { - let mut vec = { - visitor - .size_hint() - .map_or_else(Vec::new, Vec::with_capacity) - }; - while let Some(elem) = visitor.next_element::()? { - vec.push(elem.into()); - } - Ok(Value::Array(vec).into()) - } - - #[inline] - fn visit_bytes(self, v: &[u8]) -> Result { - Ok(Value::Binary(v.to_owned()).into()) - } - - #[inline] - fn visit_byte_buf(self, v: Vec) -> Result { - Ok(Value::Binary(v).into()) - } - - #[inline] - fn visit_map>( - self, - mut visitor: V, - ) -> Result { - let mut pairs = { - visitor - .size_hint() - .map_or_else(Vec::new, Vec::with_capacity) - }; - while let Some(key) = visitor.next_key::()? 
{ - let val = visitor.next_value::()?; - pairs.push((key.into(), val.into())); - } - - Ok(Value::Map(ValueMap(pairs)).into()) - } - - fn visit_newtype_struct>( - self, - deserializer: D, - ) -> Result { - deserializer.deserialize_newtype_struct("", self) - } - } - - de.deserialize_any(ValueVisitor) - } + #[inline] + #[allow(clippy::too_many_lines)] + fn deserialize(de: D) -> Result + where + D: de::Deserializer<'de>, + { + struct ValueVisitor; + + impl<'de> serde::de::Visitor<'de> for ValueVisitor { + type Value = ValueDeserializer; + + #[cold] + fn expecting(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + "any valid MessagePack value".fmt(fmt) + } + + #[inline] + fn visit_some(self, de: D) -> Result + where + D: de::Deserializer<'de>, + { + Deserialize::deserialize(de) + } + + #[inline] + fn visit_none(self) -> Result { + Ok(Value::Null.into()) + } + + #[inline] + fn visit_unit(self) -> Result { + Ok(Value::Null.into()) + } + + #[inline] + fn visit_bool(self, value: bool) -> Result { + Ok(Value::Bool(value).into()) + } + + fn visit_u32(self, v: u32) -> Result { + Ok(Value::U32(v).into()) + } + + #[inline] + fn visit_u64(self, value: u64) -> Result { + Ok(Value::U64(value).into()) + } + + fn visit_i32(self, v: i32) -> Result { + Ok(Value::I32(v).into()) + } + + #[inline] + fn visit_i64(self, value: i64) -> Result { + Ok(Value::I64(value).into()) + } + + #[inline] + fn visit_f32(self, value: f32) -> Result { + Ok(Value::F32(value).into()) + } + + #[inline] + fn visit_f64(self, value: f64) -> Result { + Ok(Value::F64(value).into()) + } + + #[inline] + fn visit_string(self, value: String) -> Result { + Ok(Value::String(value).into()) + } + + #[inline] + fn visit_str(self, value: &str) -> Result { + self.visit_string(String::from(value)) + } + + #[inline] + fn visit_seq>( + self, + mut visitor: V, + ) -> Result { + let mut vec = { + visitor + .size_hint() + .map_or_else(Vec::new, Vec::with_capacity) + }; + while let Some(elem) = visitor.next_element::()? { + vec.push(elem.into()); + } + Ok(Value::Array(vec).into()) + } + + #[inline] + fn visit_bytes(self, v: &[u8]) -> Result { + Ok(Value::Binary(v.to_owned()).into()) + } + + #[inline] + fn visit_byte_buf(self, v: Vec) -> Result { + Ok(Value::Binary(v).into()) + } + + #[inline] + fn visit_map>( + self, + mut visitor: V, + ) -> Result { + let mut pairs = { + visitor + .size_hint() + .map_or_else(Vec::new, Vec::with_capacity) + }; + while let Some(key) = visitor.next_key::()? 
{ + let val = visitor.next_value::()?; + pairs.push((key.into(), val.into())); + } + + Ok(Value::Map(ValueMap(pairs)).into()) + } + + fn visit_newtype_struct>( + self, + deserializer: D, + ) -> Result { + deserializer.deserialize_newtype_struct("", self) + } + } + + de.deserialize_any(ValueVisitor) + } } impl<'de> Deserializer<'de> for ValueDeserializer { - type Error = rbs::Error; - - fn deserialize_any(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - match self.into() { - Value::Null => visitor.visit_unit(), - Value::I32(v) => visitor.visit_i32(v), - Value::I64(v) => visitor.visit_i64(v), - Value::U32(v) => visitor.visit_u32(v), - Value::U64(v) => visitor.visit_u64(v), - Value::F32(v) => visitor.visit_f32(v), - Value::F64(v) => visitor.visit_f64(v), - Value::Bool(v) => visitor.visit_bool(v), - Value::String(v) => visitor.visit_string(v), - Value::Binary(v) => visitor.visit_byte_buf(v), - Value::Array(v) => { - let len = v.len(); - let mut de = SeqDeserializer { - iter: v.into_iter().map(ValueDeserializer), - }; - let seq = visitor.visit_seq(&mut de)?; - if de.iter.len() == 0 { - Ok(seq) - } else { - Err(de::Error::invalid_length(len, &"fewer elements in array")) - } - } - Value::Map(v) => { - let len = v.len(); - let mut de = MapDeserializer { - val: None, - iter: v.0.into_iter().map(|(k, v)| (Self(k), Self(v))), - }; - let map = visitor.visit_map(&mut de)?; - if de.iter.len() == 0 { - Ok(map) - } else { - Err(de::Error::invalid_length(len, &"fewer elements in map")) - } - } - Value::Ext(_tag, data) => Deserializer::deserialize_any(Self(*data), visitor), - } - } - - #[inline] - fn deserialize_option(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - if self.0.is_null() { - visitor.visit_none() - } else { - visitor.visit_some(self) - } - } - - #[inline] - fn deserialize_enum( - self, - _name: &str, - _variants: &'static [&'static str], - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - match self.0 { - Value::String(variant) => visitor.visit_enum(variant.into_deserializer()), - Value::Array(iter) => { - let mut iter = iter.into_iter(); - if !(iter.len() == 1 || iter.len() == 2) { - return Err(de::Error::invalid_length( - iter.len(), - &"array with one or two elements", - )); - } - - let id = match iter.next() { - Some(id) => deserialize_value(id)?, - None => { - return Err(de::Error::invalid_value( - Unexpected::Seq, - &"array with one or two elements", - )); - } - }; - - visitor.visit_enum(EnumDeserializer { - id, - value: iter.next(), - }) - } - other => Err(de::Error::invalid_type( - other.unexpected(), - &"string, array, map or int", - )), - } - } - - #[inline] - fn deserialize_newtype_struct( - self, - _name: &'static str, - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - visitor.visit_newtype_struct(self) - } - - #[inline] - fn deserialize_unit_struct( - self, - _name: &'static str, - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - match self.0 { - Value::Array(iter) => { - let iter = iter.into_iter(); - - if iter.len() == 0 { - visitor.visit_unit() - } else { - Err(de::Error::invalid_type(Unexpected::Seq, &"empty array")) - } - } - other => Err(de::Error::invalid_type(other.unexpected(), &"empty array")), - } - } - - forward_to_deserialize_any! 
{ - bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit seq - bytes byte_buf map tuple_struct struct - identifier tuple ignored_any - } + type Error = rbs::Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + match self.into() { + Value::Null => visitor.visit_unit(), + Value::I32(v) => visitor.visit_i32(v), + Value::I64(v) => visitor.visit_i64(v), + Value::U32(v) => visitor.visit_u32(v), + Value::U64(v) => visitor.visit_u64(v), + Value::F32(v) => visitor.visit_f32(v), + Value::F64(v) => visitor.visit_f64(v), + Value::Bool(v) => visitor.visit_bool(v), + Value::String(v) => visitor.visit_string(v), + Value::Binary(v) => visitor.visit_byte_buf(v), + Value::Array(v) => { + let len = v.len(); + let mut de = SeqDeserializer { + iter: v.into_iter().map(ValueDeserializer), + }; + let seq = visitor.visit_seq(&mut de)?; + if de.iter.len() == 0 { + Ok(seq) + } else { + Err(de::Error::invalid_length(len, &"fewer elements in array")) + } + }, + Value::Map(v) => { + let len = v.len(); + let mut de = MapDeserializer { + val: None, + iter: v.0.into_iter().map(|(k, v)| (Self(k), Self(v))), + }; + let map = visitor.visit_map(&mut de)?; + if de.iter.len() == 0 { + Ok(map) + } else { + Err(de::Error::invalid_length(len, &"fewer elements in map")) + } + }, + Value::Ext(_tag, data) => Deserializer::deserialize_any(Self(*data), visitor), + } + } + + #[inline] + fn deserialize_option(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + if self.0.is_null() { + visitor.visit_none() + } else { + visitor.visit_some(self) + } + } + + #[inline] + fn deserialize_enum( + self, + _name: &str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + match self.0 { + Value::String(variant) => visitor.visit_enum(variant.into_deserializer()), + Value::Array(iter) => { + let mut iter = iter.into_iter(); + if !(iter.len() == 1 || iter.len() == 2) { + return Err(de::Error::invalid_length( + iter.len(), + &"array with one or two elements", + )); + } + + let id = match iter.next() { + Some(id) => deserialize_value(id)?, + None => { + return Err(de::Error::invalid_value( + Unexpected::Seq, + &"array with one or two elements", + )); + }, + }; + + visitor.visit_enum(EnumDeserializer { + id, + value: iter.next(), + }) + }, + other => Err(de::Error::invalid_type( + other.unexpected(), + &"string, array, map or int", + )), + } + } + + #[inline] + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + #[inline] + fn deserialize_unit_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + match self.0 { + Value::Array(iter) => { + let iter = iter.into_iter(); + + if iter.len() == 0 { + visitor.visit_unit() + } else { + Err(de::Error::invalid_type(Unexpected::Seq, &"empty array")) + } + }, + other => Err(de::Error::invalid_type(other.unexpected(), &"empty array")), + } + } + + forward_to_deserialize_any! 
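One behavior of `deserialize_any` above that is easy to miss: `Value::Ext` wrappers, like the ones produced by `From<Uuid> for Value` and `From<DateTime> for Value`, are stripped and their payload is deserialized directly. A test-style sketch that could live in `value/de.rs` (the tag and payload here are invented):

```rust
#[cfg(test)]
mod ext_tests {
    use super::deserialize_value;
    use rbs::Value;

    #[test]
    fn ext_wrappers_are_transparent() {
        // The `Value::Ext(_tag, data)` arm forwards the boxed payload, so an
        // Ext-wrapped string deserializes as a plain String.
        let wrapped = Value::Ext("Uuid", Box::new(Value::String("not-a-real-uuid".into())));

        let unwrapped: String = deserialize_value(wrapped).unwrap();
        assert_eq!(unwrapped, "not-a-real-uuid");
    }
}
```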
{ + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit seq + bytes byte_buf map tuple_struct struct + identifier tuple ignored_any + } } struct SeqDeserializer { - iter: I, + iter: I, } impl<'de, I, U> de::SeqAccess<'de> for SeqDeserializer where - I: Iterator, - U: Deserializer<'de, Error = rbs::Error>, + I: Iterator, + U: Deserializer<'de, Error = rbs::Error>, { - type Error = rbs::Error; - - fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> - where - T: de::DeserializeSeed<'de>, - { - self.iter - .next() - .map_or_else(|| Ok(None), |val| seed.deserialize(val).map(Some)) - } + type Error = rbs::Error; + + fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: de::DeserializeSeed<'de>, + { + self.iter + .next() + .map_or_else(|| Ok(None), |val| seed.deserialize(val).map(Some)) + } } impl<'de, I, U> Deserializer<'de> for SeqDeserializer where - I: ExactSizeIterator, - U: Deserializer<'de, Error = rbs::Error>, + I: ExactSizeIterator, + U: Deserializer<'de, Error = rbs::Error>, { - type Error = rbs::Error; - - #[inline] - fn deserialize_any(mut self, visitor: V) -> Result - where - V: Visitor<'de>, - { - let len = self.iter.len(); - if len == 0 { - visitor.visit_unit() - } else { - let value = visitor.visit_seq(&mut self)?; - - if self.iter.len() == 0 { - Ok(value) - } else { - Err(de::Error::invalid_length(len, &"fewer elements in array")) - } - } - } - - forward_to_deserialize_any! { - bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit option - seq bytes byte_buf map unit_struct newtype_struct - tuple_struct struct identifier tuple enum ignored_any - } + type Error = rbs::Error; + + #[inline] + fn deserialize_any(mut self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let len = self.iter.len(); + if len == 0 { + visitor.visit_unit() + } else { + let value = visitor.visit_seq(&mut self)?; + + if self.iter.len() == 0 { + Ok(value) + } else { + Err(de::Error::invalid_length(len, &"fewer elements in array")) + } + } + } + + forward_to_deserialize_any! 
{ + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit option + seq bytes byte_buf map unit_struct newtype_struct + tuple_struct struct identifier tuple enum ignored_any + } } struct MapDeserializer { - iter: I, - val: Option, + iter: I, + val: Option, } impl<'de, I, U> de::MapAccess<'de> for MapDeserializer where - I: Iterator, - U: ValueBase<'de>, + I: Iterator, + U: ValueBase<'de>, { - type Error = rbs::Error; - - fn next_key_seed(&mut self, seed: T) -> Result, Self::Error> - where - T: de::DeserializeSeed<'de>, - { - match self.iter.next() { - Some((key, val)) => { - self.val = Some(val); - seed.deserialize(key).map(Some) - } - None => Ok(None), - } - } - - fn next_value_seed(&mut self, seed: T) -> Result - where - T: de::DeserializeSeed<'de>, - { - Option::take(&mut self.val).map_or_else( - || Err(de::Error::custom("value is missing")), - |val| seed.deserialize(val), - ) - } + type Error = rbs::Error; + + fn next_key_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: de::DeserializeSeed<'de>, + { + match self.iter.next() { + Some((key, val)) => { + self.val = Some(val); + seed.deserialize(key).map(Some) + }, + None => Ok(None), + } + } + + fn next_value_seed(&mut self, seed: T) -> Result + where + T: de::DeserializeSeed<'de>, + { + Option::take(&mut self.val).map_or_else( + || Err(de::Error::custom("value is missing")), + |val| seed.deserialize(val), + ) + } } impl<'de, I, U> Deserializer<'de> for MapDeserializer where - U: ValueBase<'de>, - I: Iterator, + U: ValueBase<'de>, + I: Iterator, { - type Error = rbs::Error; - - #[inline] - fn deserialize_any(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_map(self) - } - - forward_to_deserialize_any! { - bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit option - seq bytes byte_buf map unit_struct newtype_struct - tuple_struct struct identifier tuple enum ignored_any - } + type Error = rbs::Error; + + #[inline] + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_map(self) + } + + forward_to_deserialize_any! { + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit option + seq bytes byte_buf map unit_struct newtype_struct + tuple_struct struct identifier tuple enum ignored_any + } } struct EnumDeserializer { - id: u32, - value: Option, + id: u32, + value: Option, } impl<'de, U: ValueBase<'de>> de::EnumAccess<'de> for EnumDeserializer { - type Error = rbs::Error; - type Variant = VariantDeserializer; - - fn variant_seed>( - self, - seed: V, - ) -> Result<(V::Value, Self::Variant), Self::Error> { - let variant = self.id.into_deserializer(); - let visitor = VariantDeserializer { value: self.value }; - seed.deserialize(variant).map(|v| (v, visitor)) - } + type Error = rbs::Error; + type Variant = VariantDeserializer; + + fn variant_seed>( + self, + seed: V, + ) -> Result<(V::Value, Self::Variant), Self::Error> { + let variant = self.id.into_deserializer(); + let visitor = VariantDeserializer { value: self.value }; + seed.deserialize(variant).map(|v| (v, visitor)) + } } struct VariantDeserializer { - value: Option, + value: Option, } impl<'de, U: ValueBase<'de>> de::VariantAccess<'de> for VariantDeserializer { - type Error = rbs::Error; - - fn unit_variant(self) -> Result<(), Self::Error> { - // Can accept only [u32]. - self.value.map_or(Ok(()), |v| match v.into_iter() { - Ok(ref v) if v.len() == 0 => Ok(()), - Ok(..) 
=> Err(de::Error::invalid_value(Unexpected::Seq, &"empty array")), - Err(v) => Err(de::Error::invalid_value(v.unexpected(), &"empty array")), - }) - } - - fn newtype_variant_seed(self, seed: T) -> Result - where - T: de::DeserializeSeed<'de>, - { - // Can accept both [u32, T...] and [u32, [T]] cases. - match self.value { - Some(v) => match v.into_iter() { - Ok(mut iter) => { - if iter.len() > 1 { - seed.deserialize(SeqDeserializer { iter }) - } else { - let val = match iter.next() { - Some(val) => seed.deserialize(val), - None => { - return Err(de::Error::invalid_value( - Unexpected::Seq, - &"array with one element", - )); - } - }; - - if iter.next().is_some() { - Err(de::Error::invalid_value( - Unexpected::Seq, - &"array with one element", - )) - } else { - val - } - } - } - Err(v) => seed.deserialize(v), - }, - None => Err(de::Error::invalid_type( - Unexpected::UnitVariant, - &"newtype variant", - )), - } - } - - fn tuple_variant(self, _len: usize, visitor: V) -> Result - where - V: Visitor<'de>, - { - // Can accept [u32, [T...]]. - self.value.map_or_else( - || { - Err(de::Error::invalid_type( - Unexpected::UnitVariant, - &"tuple variant", - )) - }, - |v| match v.into_iter() { - Ok(v) => Deserializer::deserialize_any(SeqDeserializer { iter: v }, visitor), - Err(v) => Err(de::Error::invalid_type(v.unexpected(), &"tuple variant")), - }, - ) - } - - fn struct_variant( - self, - _fields: &'static [&'static str], - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - self.value.map_or_else( - || { - Err(de::Error::invalid_type( - Unexpected::UnitVariant, - &"struct variant", - )) - }, - |v| match v.into_iter() { - Ok(iter) => Deserializer::deserialize_any(SeqDeserializer { iter }, visitor), - Err(v) => match v.into_map_iter() { - Ok(iter) => { - Deserializer::deserialize_any(MapDeserializer { iter, val: None }, visitor) - } - Err(v) => Err(de::Error::invalid_type(v.unexpected(), &"struct variant")), - }, - }, - ) - } + type Error = rbs::Error; + + fn unit_variant(self) -> Result<(), Self::Error> { + // Can accept only [u32]. + self.value.map_or(Ok(()), |v| match v.into_iter() { + Ok(ref v) if v.len() == 0 => Ok(()), + Ok(..) => Err(de::Error::invalid_value(Unexpected::Seq, &"empty array")), + Err(v) => Err(de::Error::invalid_value(v.unexpected(), &"empty array")), + }) + } + + fn newtype_variant_seed(self, seed: T) -> Result + where + T: de::DeserializeSeed<'de>, + { + // Can accept both [u32, T...] and [u32, [T]] cases. + match self.value { + Some(v) => match v.into_iter() { + Ok(mut iter) => { + if iter.len() > 1 { + seed.deserialize(SeqDeserializer { iter }) + } else { + let val = match iter.next() { + Some(val) => seed.deserialize(val), + None => { + return Err(de::Error::invalid_value( + Unexpected::Seq, + &"array with one element", + )); + }, + }; + + if iter.next().is_some() { + Err(de::Error::invalid_value( + Unexpected::Seq, + &"array with one element", + )) + } else { + val + } + } + }, + Err(v) => seed.deserialize(v), + }, + None => Err(de::Error::invalid_type( + Unexpected::UnitVariant, + &"newtype variant", + )), + } + } + + fn tuple_variant(self, _len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + // Can accept [u32, [T...]]. 
+ self.value.map_or_else( + || { + Err(de::Error::invalid_type( + Unexpected::UnitVariant, + &"tuple variant", + )) + }, + |v| match v.into_iter() { + Ok(v) => Deserializer::deserialize_any(SeqDeserializer { iter: v }, visitor), + Err(v) => Err(de::Error::invalid_type(v.unexpected(), &"tuple variant")), + }, + ) + } + + fn struct_variant( + self, + _fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.value.map_or_else( + || { + Err(de::Error::invalid_type( + Unexpected::UnitVariant, + &"struct variant", + )) + }, + |v| match v.into_iter() { + Ok(iter) => Deserializer::deserialize_any(SeqDeserializer { iter }, visitor), + Err(v) => match v.into_map_iter() { + Ok(iter) => { + Deserializer::deserialize_any(MapDeserializer { iter, val: None }, visitor) + }, + Err(v) => Err(de::Error::invalid_type(v.unexpected(), &"struct variant")), + }, + }, + ) + } } diff --git a/ensemble/src/value/mod.rs b/ensemble/src/value/mod.rs index d70533f..112722b 100644 --- a/ensemble/src/value/mod.rs +++ b/ensemble/src/value/mod.rs @@ -12,7 +12,7 @@ mod ser; /// /// Returns an error if serialization fails. pub fn for_db(value: T) -> Result { - fast_serialize(value) + fast_serialize(value) } /// Deserialize a model from the database. @@ -21,9 +21,9 @@ pub fn for_db(value: T) -> Result { /// /// Returns an error if deserialization fails. pub(crate) fn from(value: rbs::Value) -> Result { - deserialize_value::(value) + deserialize_value::(value) } pub(crate) fn serializing_for_db() -> bool { - std::any::type_name::() == std::any::type_name::() + std::any::type_name::() == std::any::type_name::() } diff --git a/ensemble/src/value/ser.rs b/ensemble/src/value/ser.rs index 329deca..ff2ddaf 100644 --- a/ensemble/src/value/ser.rs +++ b/ensemble/src/value/ser.rs @@ -2,533 +2,533 @@ use rbs::{value::map::ValueMap, Value}; use serde::{ser, Serialize}; pub fn fast_serialize(mut value: T) -> Result { - let type_name = std::any::type_name::(); - if type_name == std::any::type_name::() { - let addr = std::ptr::addr_of_mut!(value); - let v = unsafe { &mut *addr.cast() }; - return Ok(std::mem::take(v)); - } - if type_name == std::any::type_name::<&Value>() { - let addr = std::ptr::addr_of!(value); - return Ok(unsafe { *addr.cast::<&Value>() }.clone()); - } - if type_name == std::any::type_name::<&&Value>() { - let addr = std::ptr::addr_of!(value); - return Ok(unsafe { **addr.cast::<&&Value>() }.clone()); - } - value.serialize(Serializer) + let type_name = std::any::type_name::(); + if type_name == std::any::type_name::() { + let addr = std::ptr::addr_of_mut!(value); + let v = unsafe { &mut *addr.cast() }; + return Ok(std::mem::take(v)); + } + if type_name == std::any::type_name::<&Value>() { + let addr = std::ptr::addr_of!(value); + return Ok(unsafe { *addr.cast::<&Value>() }.clone()); + } + if type_name == std::any::type_name::<&&Value>() { + let addr = std::ptr::addr_of!(value); + return Ok(unsafe { **addr.cast::<&&Value>() }.clone()); + } + value.serialize(Serializer) } struct Serializer; impl serde::Serializer for Serializer { - type Ok = rbs::Value; - type Error = rbs::Error; - - type SerializeSeq = SerializeVec; - type SerializeTuple = SerializeVec; - type SerializeMap = DefaultSerializeMap; - type SerializeTupleStruct = SerializeVec; - type SerializeStruct = DefaultSerializeMap; - type SerializeStructVariant = DefaultSerializeMap; - type SerializeTupleVariant = SerializeTupleVariant; - - #[inline] - fn serialize_bool(self, val: bool) -> Result { - Ok(Value::Bool(val)) - } - - 
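// The narrow integer types below are widened to `Value::I32` / `Value::U32`, and
// `char` is serialized by way of a one-character `String`, since those are the
// representations the value type carries.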
#[inline] - fn serialize_i8(self, val: i8) -> Result { - Ok(Value::I32(i32::from(val))) - } - - #[inline] - fn serialize_i16(self, val: i16) -> Result { - Ok(Value::I32(i32::from(val))) - } - - #[inline] - fn serialize_i32(self, val: i32) -> Result { - Ok(Value::I32(val)) - } - - #[inline] - fn serialize_i64(self, val: i64) -> Result { - Ok(Value::I64(val)) - } - - #[inline] - fn serialize_u8(self, val: u8) -> Result { - Ok(Value::U32(u32::from(val))) - } - - #[inline] - fn serialize_u16(self, val: u16) -> Result { - Ok(Value::U32(u32::from(val))) - } - - #[inline] - fn serialize_u32(self, val: u32) -> Result { - Ok(Value::U32(val)) - } - - #[inline] - fn serialize_u64(self, val: u64) -> Result { - Ok(Value::U64(val)) - } - - #[inline] - fn serialize_f32(self, val: f32) -> Result { - Ok(Value::F32(val)) - } - - #[inline] - fn serialize_f64(self, val: f64) -> Result { - Ok(Value::F64(val)) - } - - #[inline] - fn serialize_char(self, val: char) -> Result { - let mut buf = String::new(); - buf.push(val); - self.serialize_str(&buf) - } - - #[inline] - fn serialize_str(self, val: &str) -> Result { - Ok(Value::String(val.into())) - } - - #[inline] - fn serialize_bytes(self, val: &[u8]) -> Result { - Ok(Value::Binary(val.into())) - } - - #[inline] - fn serialize_unit(self) -> Result { - Ok(Value::Null) - } - - #[inline] - fn serialize_unit_struct(self, _name: &'static str) -> Result { - self.serialize_unit() - } - - #[inline] - fn serialize_unit_variant( - self, - _name: &'static str, - _idx: u32, - variant: &'static str, - ) -> Result { - self.serialize_str(variant) - } - - #[inline] - fn serialize_newtype_struct( - self, - name: &'static str, - value: &T, - ) -> Result { - Ok(Value::Ext(name, Box::new(fast_serialize(value)?))) - } - - fn serialize_newtype_variant( - self, - name: &'static str, - _idx: u32, - variant: &'static str, - _value: &T, - ) -> Result { - Err(rbs::Error::Syntax(format!( - "Ensemble does not support enums with values: {name}::{variant}", - ))) - } - - #[inline] - fn serialize_none(self) -> Result { - self.serialize_unit() - } - - #[inline] - fn serialize_some(self, value: &T) -> Result { - fast_serialize(value) - } - - fn serialize_seq(self, len: Option) -> Result { - let se = SerializeVec { - vec: Vec::with_capacity(len.unwrap_or(0)), - }; - Ok(se) - } - - fn serialize_tuple(self, len: usize) -> Result { - self.serialize_seq(Some(len)) - } - - fn serialize_tuple_struct( - self, - _name: &'static str, - len: usize, - ) -> Result { - self.serialize_tuple(len) - } - - fn serialize_tuple_variant( - self, - _name: &'static str, - idx: u32, - _variant: &'static str, - len: usize, - ) -> Result { - let se = SerializeTupleVariant { - idx, - vec: Vec::with_capacity(len), - }; - Ok(se) - } - - fn serialize_map(self, len: Option) -> Result { - let se = DefaultSerializeMap { - next_key: None, - map: Vec::with_capacity(len.unwrap_or(0)), - }; - Ok(se) - } - - #[inline] - fn serialize_struct( - self, - _name: &'static str, - len: usize, - ) -> Result { - let se = DefaultSerializeMap { - next_key: None, - map: Vec::with_capacity(len), - }; - Ok(se) - } - - #[inline] - fn serialize_struct_variant( - self, - _name: &'static str, - _idx: u32, - _variant: &'static str, - len: usize, - ) -> Result { - let se = DefaultSerializeMap { - map: Vec::with_capacity(len), - next_key: None, - }; - Ok(se) - } + type Ok = rbs::Value; + type Error = rbs::Error; + + type SerializeSeq = SerializeVec; + type SerializeTuple = SerializeVec; + type SerializeMap = DefaultSerializeMap; + type 
SerializeTupleStruct = SerializeVec; + type SerializeStruct = DefaultSerializeMap; + type SerializeStructVariant = DefaultSerializeMap; + type SerializeTupleVariant = SerializeTupleVariant; + + #[inline] + fn serialize_bool(self, val: bool) -> Result { + Ok(Value::Bool(val)) + } + + #[inline] + fn serialize_i8(self, val: i8) -> Result { + Ok(Value::I32(i32::from(val))) + } + + #[inline] + fn serialize_i16(self, val: i16) -> Result { + Ok(Value::I32(i32::from(val))) + } + + #[inline] + fn serialize_i32(self, val: i32) -> Result { + Ok(Value::I32(val)) + } + + #[inline] + fn serialize_i64(self, val: i64) -> Result { + Ok(Value::I64(val)) + } + + #[inline] + fn serialize_u8(self, val: u8) -> Result { + Ok(Value::U32(u32::from(val))) + } + + #[inline] + fn serialize_u16(self, val: u16) -> Result { + Ok(Value::U32(u32::from(val))) + } + + #[inline] + fn serialize_u32(self, val: u32) -> Result { + Ok(Value::U32(val)) + } + + #[inline] + fn serialize_u64(self, val: u64) -> Result { + Ok(Value::U64(val)) + } + + #[inline] + fn serialize_f32(self, val: f32) -> Result { + Ok(Value::F32(val)) + } + + #[inline] + fn serialize_f64(self, val: f64) -> Result { + Ok(Value::F64(val)) + } + + #[inline] + fn serialize_char(self, val: char) -> Result { + let mut buf = String::new(); + buf.push(val); + self.serialize_str(&buf) + } + + #[inline] + fn serialize_str(self, val: &str) -> Result { + Ok(Value::String(val.into())) + } + + #[inline] + fn serialize_bytes(self, val: &[u8]) -> Result { + Ok(Value::Binary(val.into())) + } + + #[inline] + fn serialize_unit(self) -> Result { + Ok(Value::Null) + } + + #[inline] + fn serialize_unit_struct(self, _name: &'static str) -> Result { + self.serialize_unit() + } + + #[inline] + fn serialize_unit_variant( + self, + _name: &'static str, + _idx: u32, + variant: &'static str, + ) -> Result { + self.serialize_str(variant) + } + + #[inline] + fn serialize_newtype_struct( + self, + name: &'static str, + value: &T, + ) -> Result { + Ok(Value::Ext(name, Box::new(fast_serialize(value)?))) + } + + fn serialize_newtype_variant( + self, + name: &'static str, + _idx: u32, + variant: &'static str, + _value: &T, + ) -> Result { + Err(rbs::Error::Syntax(format!( + "Ensemble does not support enums with values: {name}::{variant}", + ))) + } + + #[inline] + fn serialize_none(self) -> Result { + self.serialize_unit() + } + + #[inline] + fn serialize_some(self, value: &T) -> Result { + fast_serialize(value) + } + + fn serialize_seq(self, len: Option) -> Result { + let se = SerializeVec { + vec: Vec::with_capacity(len.unwrap_or(0)), + }; + Ok(se) + } + + fn serialize_tuple(self, len: usize) -> Result { + self.serialize_seq(Some(len)) + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + len: usize, + ) -> Result { + self.serialize_tuple(len) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + idx: u32, + _variant: &'static str, + len: usize, + ) -> Result { + let se = SerializeTupleVariant { + idx, + vec: Vec::with_capacity(len), + }; + Ok(se) + } + + fn serialize_map(self, len: Option) -> Result { + let se = DefaultSerializeMap { + next_key: None, + map: Vec::with_capacity(len.unwrap_or(0)), + }; + Ok(se) + } + + #[inline] + fn serialize_struct( + self, + _name: &'static str, + len: usize, + ) -> Result { + let se = DefaultSerializeMap { + next_key: None, + map: Vec::with_capacity(len), + }; + Ok(se) + } + + #[inline] + fn serialize_struct_variant( + self, + _name: &'static str, + _idx: u32, + _variant: &'static str, + len: usize, + ) -> Result { + let se = 
DefaultSerializeMap { + map: Vec::with_capacity(len), + next_key: None, + }; + Ok(se) + } } pub struct SerializeVec { - vec: Vec, + vec: Vec, } pub struct SerializeTupleVariant { - idx: u32, - vec: Vec, + idx: u32, + vec: Vec, } pub struct DefaultSerializeMap { - map: Vec<(Value, Value)>, - next_key: Option, + map: Vec<(Value, Value)>, + next_key: Option, } pub struct SerializeStructVariant { - idx: u32, - vec: Vec, + idx: u32, + vec: Vec, } impl ser::SerializeSeq for SerializeVec { - type Ok = Value; - type Error = rbs::Error; - - #[inline] - fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> { - self.vec.push(fast_serialize(value)?); - Ok(()) - } - - #[inline] - fn end(self) -> Result { - Ok(Value::Array(self.vec)) - } + type Ok = Value; + type Error = rbs::Error; + + #[inline] + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> { + self.vec.push(fast_serialize(value)?); + Ok(()) + } + + #[inline] + fn end(self) -> Result { + Ok(Value::Array(self.vec)) + } } impl ser::SerializeTuple for SerializeVec { - type Ok = Value; - type Error = rbs::Error; - - #[inline] - fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> { - ser::SerializeSeq::serialize_element(self, value) - } - - #[inline] - fn end(self) -> Result { - ser::SerializeSeq::end(self) - } + type Ok = Value; + type Error = rbs::Error; + + #[inline] + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> { + ser::SerializeSeq::serialize_element(self, value) + } + + #[inline] + fn end(self) -> Result { + ser::SerializeSeq::end(self) + } } impl ser::SerializeTupleStruct for SerializeVec { - type Ok = Value; - type Error = rbs::Error; - - #[inline] - fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> { - ser::SerializeSeq::serialize_element(self, value) - } - - #[inline] - fn end(self) -> Result { - ser::SerializeSeq::end(self) - } + type Ok = Value; + type Error = rbs::Error; + + #[inline] + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> { + ser::SerializeSeq::serialize_element(self, value) + } + + #[inline] + fn end(self) -> Result { + ser::SerializeSeq::end(self) + } } impl ser::SerializeTupleVariant for SerializeTupleVariant { - type Ok = Value; - type Error = rbs::Error; - - #[inline] - fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> { - self.vec.push(fast_serialize(value)?); - Ok(()) - } - - #[inline] - fn end(self) -> Result { - Ok(Value::Array(vec![ - Value::from(self.idx), - Value::Array(self.vec), - ])) - } + type Ok = Value; + type Error = rbs::Error; + + #[inline] + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> { + self.vec.push(fast_serialize(value)?); + Ok(()) + } + + #[inline] + fn end(self) -> Result { + Ok(Value::Array(vec![ + Value::from(self.idx), + Value::Array(self.vec), + ])) + } } impl ser::SerializeMap for DefaultSerializeMap { - type Ok = Value; - type Error = rbs::Error; - - #[inline] - fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> { - self.next_key = Some(fast_serialize(key)?); - Ok(()) - } - - fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> { - let key = self - .next_key - .take() - .expect("`serialize_value` called before `serialize_key`"); - self.map.push((key, fast_serialize(value)?)); - Ok(()) - } - - #[inline] - fn end(self) -> Result { - Ok(Value::Map(ValueMap(self.map))) - } + type Ok = Value; + type Error = rbs::Error; + + #[inline] + fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> { + 
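// serde drives map serialization as alternating `serialize_key` / `serialize_value`
// calls: remember the serialized key here so `serialize_value` can pair it with the
// matching value (it panics if called before a key has been provided).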
self.next_key = Some(fast_serialize(key)?); + Ok(()) + } + + fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> { + let key = self + .next_key + .take() + .expect("`serialize_value` called before `serialize_key`"); + self.map.push((key, fast_serialize(value)?)); + Ok(()) + } + + #[inline] + fn end(self) -> Result { + Ok(Value::Map(ValueMap(self.map))) + } } impl ser::SerializeStruct for DefaultSerializeMap { - type Ok = Value; - type Error = rbs::Error; - - fn serialize_field( - &mut self, - key: &'static str, - value: &T, - ) -> Result<(), Self::Error> { - self.map - .push((Value::String(key.to_string()), fast_serialize(value)?)); - Ok(()) - } - - fn end(self) -> Result { - Ok(Value::Map(ValueMap(self.map))) - } + type Ok = Value; + type Error = rbs::Error; + + fn serialize_field( + &mut self, + key: &'static str, + value: &T, + ) -> Result<(), Self::Error> { + self.map + .push((Value::String(key.to_string()), fast_serialize(value)?)); + Ok(()) + } + + fn end(self) -> Result { + Ok(Value::Map(ValueMap(self.map))) + } } impl ser::SerializeStructVariant for DefaultSerializeMap { - type Ok = Value; - type Error = rbs::Error; - - fn serialize_field( - &mut self, - key: &'static str, - value: &T, - ) -> Result<(), Self::Error> { - self.map - .push((Value::String(key.to_string()), fast_serialize(value)?)); - Ok(()) - } - - fn end(self) -> Result { - Ok(Value::Map(ValueMap(self.map))) - } + type Ok = Value; + type Error = rbs::Error; + + fn serialize_field( + &mut self, + key: &'static str, + value: &T, + ) -> Result<(), Self::Error> { + self.map + .push((Value::String(key.to_string()), fast_serialize(value)?)); + Ok(()) + } + + fn end(self) -> Result { + Ok(Value::Map(ValueMap(self.map))) + } } impl ser::SerializeStruct for SerializeVec { - type Ok = Value; - type Error = rbs::Error; - - #[inline] - fn serialize_field( - &mut self, - _key: &'static str, - value: &T, - ) -> Result<(), Self::Error> { - ser::SerializeSeq::serialize_element(self, value) - } - - #[inline] - fn end(self) -> Result { - ser::SerializeSeq::end(self) - } + type Ok = Value; + type Error = rbs::Error; + + #[inline] + fn serialize_field( + &mut self, + _key: &'static str, + value: &T, + ) -> Result<(), Self::Error> { + ser::SerializeSeq::serialize_element(self, value) + } + + #[inline] + fn end(self) -> Result { + ser::SerializeSeq::end(self) + } } impl ser::SerializeStructVariant for SerializeStructVariant { - type Ok = Value; - type Error = rbs::Error; - - #[inline] - fn serialize_field( - &mut self, - _key: &'static str, - value: &T, - ) -> Result<(), Self::Error> { - self.vec.push(fast_serialize(value)?); - Ok(()) - } - - #[inline] - fn end(self) -> Result { - Ok(Value::Array(vec![ - Value::from(self.idx), - Value::Array(self.vec), - ])) - } + type Ok = Value; + type Error = rbs::Error; + + #[inline] + fn serialize_field( + &mut self, + _key: &'static str, + value: &T, + ) -> Result<(), Self::Error> { + self.vec.push(fast_serialize(value)?); + Ok(()) + } + + #[inline] + fn end(self) -> Result { + Ok(Value::Array(vec![ + Value::from(self.idx), + Value::Array(self.vec), + ])) + } } #[cfg(test)] mod tests { - use crate::types::{DateTime, Hashed, Json, Uuid}; - - use super::*; - use serde::{Deserialize, Serialize}; - use serde_json::json; - - #[derive(Debug, PartialEq, Serialize, Deserialize)] - struct Test { - a: i32, - b: String, - c: Vec, - } - - #[test] - fn test_serialize() { - let test = Test { - a: 1, - b: "test".to_string(), - c: vec![1, 2, 3], - }; - - assert_eq!( - fast_serialize(test).unwrap(), - 
rbs::to_value! { - "a" : 1, - "b" : "test", - "c" : [1u32, 2u32, 3u32], - } - ); - } - - #[derive(Debug, PartialEq, Serialize, Deserialize)] - enum Status { - Ok, - Error, - ThirdThing, - } - - #[test] - fn test_serialize_enum() { - assert_eq!(fast_serialize(Status::Ok).unwrap(), rbs::to_value!("Ok")); - assert_eq!( - fast_serialize(Status::Error).unwrap(), - rbs::to_value!("Error") - ); - assert_eq!( - fast_serialize(Status::ThirdThing).unwrap(), - rbs::to_value!("ThirdThing") - ); - } - - #[derive(Debug, PartialEq, Serialize, Deserialize)] - #[serde(rename_all = "snake_case")] - enum StatusV2 { - Ok, - Error, - ThirdThing, - } - - #[test] - fn test_serialize_enum_with_custom_config() { - assert_eq!(fast_serialize(StatusV2::Ok).unwrap(), rbs::to_value!("ok")); - assert_eq!( - fast_serialize(StatusV2::Error).unwrap(), - rbs::to_value!("error") - ); - assert_eq!( - fast_serialize(StatusV2::ThirdThing).unwrap(), - rbs::to_value!("third_thing") - ); - } - - #[test] - fn properly_serializes_datetime() { - let datetime = DateTime::now(); - - assert_eq!( - fast_serialize(&datetime).unwrap(), - Value::Ext("DateTime", Box::new(rbs::to_value!(datetime.0))) - ); - } - - #[test] - fn properly_serializes_uuid() { - let uuid = Uuid::new(); - - assert_eq!( - fast_serialize(&uuid).unwrap(), - Value::Ext("Uuid", Box::new(Value::String(uuid.to_string()))) - ); - } - - #[test] - fn properly_serializes_hashed() { - let hashed = Hashed::new("hello-world"); - - assert_eq!( - fast_serialize(&hashed).unwrap(), - Value::String(hashed.to_string()) - ); - } - - #[test] - fn properly_serializes_json() { - let json = Json(json!({ - "hello": "world", - "foo": "bar", - })); - - assert_eq!( - fast_serialize(&json).unwrap(), - Value::Ext("Json", Box::new(Value::String(json.to_string()))) - ); - } + use crate::types::{DateTime, Hashed, Json, Uuid}; + + use super::*; + use serde::{Deserialize, Serialize}; + use serde_json::json; + + #[derive(Debug, PartialEq, Serialize, Deserialize)] + struct Test { + a: i32, + b: String, + c: Vec, + } + + #[test] + fn test_serialize() { + let test = Test { + a: 1, + b: "test".to_string(), + c: vec![1, 2, 3], + }; + + assert_eq!( + fast_serialize(test).unwrap(), + rbs::to_value! 
{ + "a" : 1, + "b" : "test", + "c" : [1u32, 2u32, 3u32], + } + ); + } + + #[derive(Debug, PartialEq, Serialize, Deserialize)] + enum Status { + Ok, + Error, + ThirdThing, + } + + #[test] + fn test_serialize_enum() { + assert_eq!(fast_serialize(Status::Ok).unwrap(), rbs::to_value!("Ok")); + assert_eq!( + fast_serialize(Status::Error).unwrap(), + rbs::to_value!("Error") + ); + assert_eq!( + fast_serialize(Status::ThirdThing).unwrap(), + rbs::to_value!("ThirdThing") + ); + } + + #[derive(Debug, PartialEq, Serialize, Deserialize)] + #[serde(rename_all = "snake_case")] + enum StatusV2 { + Ok, + Error, + ThirdThing, + } + + #[test] + fn test_serialize_enum_with_custom_config() { + assert_eq!(fast_serialize(StatusV2::Ok).unwrap(), rbs::to_value!("ok")); + assert_eq!( + fast_serialize(StatusV2::Error).unwrap(), + rbs::to_value!("error") + ); + assert_eq!( + fast_serialize(StatusV2::ThirdThing).unwrap(), + rbs::to_value!("third_thing") + ); + } + + #[test] + fn properly_serializes_datetime() { + let datetime = DateTime::now(); + + assert_eq!( + fast_serialize(&datetime).unwrap(), + Value::Ext("DateTime", Box::new(rbs::to_value!(datetime.0))) + ); + } + + #[test] + fn properly_serializes_uuid() { + let uuid = Uuid::new(); + + assert_eq!( + fast_serialize(&uuid).unwrap(), + Value::Ext("Uuid", Box::new(Value::String(uuid.to_string()))) + ); + } + + #[test] + fn properly_serializes_hashed() { + let hashed = Hashed::new("hello-world"); + + assert_eq!( + fast_serialize(&hashed).unwrap(), + Value::String(hashed.to_string()) + ); + } + + #[test] + fn properly_serializes_json() { + let json = Json(json!({ + "hello": "world", + "foo": "bar", + })); + + assert_eq!( + fast_serialize(&json).unwrap(), + Value::Ext("Json", Box::new(Value::String(json.to_string()))) + ); + } } diff --git a/ensemble_derive/src/column/field.rs b/ensemble_derive/src/column/field.rs index 8ca6ff7..5aaf2a8 100644 --- a/ensemble_derive/src/column/field.rs +++ b/ensemble_derive/src/column/field.rs @@ -3,89 +3,89 @@ use quote::ToTokens; use syn::{spanned::Spanned, Attribute, Expr, FieldsNamed, Lit}; pub struct Fields { - ast: FieldsNamed, - pub fields: Vec, + ast: FieldsNamed, + pub fields: Vec, } impl Fields { - pub fn separate(&self) -> (Vec<&Field>, Vec<&Field>) { - self.fields.iter().partition(|f| f.attr.init) - } + pub fn separate(&self) -> (Vec<&Field>, Vec<&Field>) { + self.fields.iter().partition(|f| f.attr.init) + } } pub struct Field { - pub attr: Attr, - ast: syn::Field, - pub ty: syn::Type, - pub ident: syn::Ident, - pub doc: Option, + pub attr: Attr, + ast: syn::Field, + pub ty: syn::Type, + pub ident: syn::Ident, + pub doc: Option, } #[derive(ExtractAttributes, Default)] #[deluxe(attributes(builder), default)] pub struct Attr { - pub skip: bool, - pub init: bool, - pub into: bool, - pub needs: Option, - #[deluxe(rename = type, append)] - pub types: Vec, - pub rename: Option, + pub skip: bool, + pub init: bool, + pub into: bool, + pub needs: Option, + #[deluxe(rename = type, append)] + pub types: Vec, + pub rename: Option, } impl Field { - pub fn new(mut field: syn::Field) -> Self { - let ident = field.ident.clone().unwrap(); - let attr = Attr::extract_attributes(&mut field.attrs).unwrap(); + pub fn new(mut field: syn::Field) -> Self { + let ident = field.ident.clone().unwrap(); + let attr = Attr::extract_attributes(&mut field.attrs).unwrap(); - Self { - attr, - ident, - ty: field.ty.clone(), - doc: Self::get_doc(&field.attrs), - ast: field, - } - } + Self { + attr, + ident, + ty: field.ty.clone(), + doc: 
Self::get_doc(&field.attrs), + ast: field, + } + } - fn get_doc(attrs: &[Attribute]) -> Option { - attrs - .iter() - .find(|attr| attr.meta.path().is_ident("doc")) - .and_then(|attr| { - attr.meta.require_name_value().ok().and_then(|meta| { - let Expr::Lit(lit) = &meta.value else { - return None; - }; + fn get_doc(attrs: &[Attribute]) -> Option { + attrs + .iter() + .find(|attr| attr.meta.path().is_ident("doc")) + .and_then(|attr| { + attr.meta.require_name_value().ok().and_then(|meta| { + let Expr::Lit(lit) = &meta.value else { + return None; + }; - match &lit.lit { - Lit::Str(s) => Some(s.value()), - _ => None, - } - }) - }) - } + match &lit.lit { + Lit::Str(s) => Some(s.value()), + _ => None, + } + }) + }) + } - pub fn span(&self) -> proc_macro2::Span { - self.ast.span() - } + pub fn span(&self) -> proc_macro2::Span { + self.ast.span() + } } impl ToTokens for Field { - fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { - self.ast.to_tokens(tokens); - } + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + self.ast.to_tokens(tokens); + } } impl ToTokens for Fields { - fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { - self.ast.to_tokens(tokens); - } + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + self.ast.to_tokens(tokens); + } } impl From for Fields { - fn from(ast: FieldsNamed) -> Self { - let fields = ast.named.iter().map(|f| Field::new(f.clone())).collect(); + fn from(ast: FieldsNamed) -> Self { + let fields = ast.named.iter().map(|f| Field::new(f.clone())).collect(); - Self { ast, fields } - } + Self { ast, fields } + } } diff --git a/ensemble_derive/src/column/mod.rs b/ensemble_derive/src/column/mod.rs index 9fb4e7d..f923bab 100644 --- a/ensemble_derive/src/column/mod.rs +++ b/ensemble_derive/src/column/mod.rs @@ -7,198 +7,200 @@ use self::field::{Field, Fields}; mod field; pub fn r#impl(ast: &DeriveInput) -> syn::Result { - let syn::Data::Struct(r#struct) = &ast.data else { - return Err(syn::Error::new_spanned( - ast, - "Column derive only supports structs", - )); - }; - - let syn::Fields::Named(struct_fields) = &r#struct.fields else { - return Err(syn::Error::new_spanned( - ast, - "Column derive only supports named fields", - )); - }; - - let fields = Fields::from(struct_fields.clone()); - - let new_impl = impl_new(&fields); - let set_impls = impl_set(&fields)?; - - let name = &ast.ident; - let gen = quote! { - impl #name { - #new_impl - #set_impls - } - }; - - Ok(gen) + let syn::Data::Struct(r#struct) = &ast.data else { + return Err(syn::Error::new_spanned( + ast, + "Column derive only supports structs", + )); + }; + + let syn::Fields::Named(struct_fields) = &r#struct.fields else { + return Err(syn::Error::new_spanned( + ast, + "Column derive only supports named fields", + )); + }; + + let fields = Fields::from(struct_fields.clone()); + + let new_impl = impl_new(&fields); + let set_impls = impl_set(&fields)?; + + let name = &ast.ident; + let gen = quote! { + impl #name { + #new_impl + #set_impls + } + }; + + Ok(gen) } fn impl_new(fields: &Fields) -> TokenStream { - let (init, _) = fields.separate(); - - let init_types = init.iter().map(|f| { - let ty = &f.ty; - let iden = &f.ident; - - quote_spanned!(f.span()=> #iden: #ty) - }); - - let construct = fields.fields.iter().map(|f| { - let iden = &f.ident; - - if f.attr.init { - quote_spanned!(f.span()=> #iden) - } else { - quote_spanned!(f.span()=> #iden: Default::default()) - } - }); - - quote! 
{ - pub fn new(#(#init_types),*) -> Self { - Self { - #(#construct),* - } - } - } + let (init, _) = fields.separate(); + + let init_types = init.iter().map(|f| { + let ty = &f.ty; + let iden = &f.ident; + + quote_spanned!(f.span()=> #iden: #ty) + }); + + let construct = fields.fields.iter().map(|f| { + let iden = &f.ident; + + if f.attr.init { + quote_spanned!(f.span()=> #iden) + } else { + quote_spanned!(f.span()=> #iden: Default::default()) + } + }); + + quote! { + pub fn new(#(#init_types),*) -> Self { + Self { + #(#construct),* + } + } + } } fn impl_set(fields: &Fields) -> syn::Result { - let (_, not_init) = fields.separate(); - - not_init - .iter() - .filter(|f| !f.attr.skip) - .map(|f| { - let ty = &f.ty; - let option = get_option_inner(ty); - let ty = option.unwrap_or(ty); - let is_string = ty.to_token_stream().to_string() == "String"; - let ty = if is_string { - quote_spanned! {f.span()=> &str } - } else { - quote_spanned! {f.span()=> #ty } - }; - let fn_constrain = if f.attr.into { - quote_spanned! {f.span()=> > } - } else { - TokenStream::new() - }; - let fn_ty = if f.attr.into { - quote_spanned! {f.span()=> T } - } else { - quote_spanned! {f.span()=> #ty } - }; - let iden = &f.ident; - let assign = build_assign(f, is_string, option.is_some()); - let doc = f.doc.as_ref().map_or_else(TokenStream::new, |doc| quote_spanned! {f.span()=> #[doc = #doc] }); - - - let only_types = &f.attr.types; - let alias = f - .attr - .rename - .as_ref() - .map_or(iden.clone(), |s| Ident::new(s, iden.span())); - - let types_constraint = if only_types.is_empty() { - TokenStream::new() - } else { - quote_spanned! { - f.span()=> if !matches!(self.r#type, #(#only_types)|*) { - panic!("{} is not a valid option for {} columns.", stringify!(#iden), self.r#type); - } - } - }; - - let needs = build_needs(f)?; - - Ok(quote_spanned! {f.span()=> - #[allow(clippy::return_self_not_must_use, clippy::must_use_candidate)] - #doc - pub fn #alias #fn_constrain (mut self, #iden: #fn_ty) -> Self { - #types_constraint - #needs - - self.#iden = #assign; - self - } - }) - }) - .collect() + let (_, not_init) = fields.separate(); + + not_init + .iter() + .filter(|f| !f.attr.skip) + .map(|f| { + let ty = &f.ty; + let option = get_option_inner(ty); + let ty = option.unwrap_or(ty); + let is_string = ty.to_token_stream().to_string() == "String"; + let ty = if is_string { + quote_spanned! {f.span()=> &str } + } else { + quote_spanned! {f.span()=> #ty } + }; + let fn_constrain = if f.attr.into { + quote_spanned! {f.span()=> > } + } else { + TokenStream::new() + }; + let fn_ty = if f.attr.into { + quote_spanned! {f.span()=> T } + } else { + quote_spanned! {f.span()=> #ty } + }; + let iden = &f.ident; + let assign = build_assign(f, is_string, option.is_some()); + let doc = f.doc.as_ref().map_or_else( + TokenStream::new, + |doc| quote_spanned! {f.span()=> #[doc = #doc] }, + ); + + let only_types = &f.attr.types; + let alias = f + .attr + .rename + .as_ref() + .map_or(iden.clone(), |s| Ident::new(s, iden.span())); + + let types_constraint = if only_types.is_empty() { + TokenStream::new() + } else { + quote_spanned! { + f.span()=> if !matches!(self.r#type, #(#only_types)|*) { + panic!("{} is not a valid option for {} columns.", stringify!(#iden), self.r#type); + } + } + }; + + let needs = build_needs(f)?; + + Ok(quote_spanned! 
{f.span()=> + #[allow(clippy::return_self_not_must_use, clippy::must_use_candidate)] + #doc + pub fn #alias #fn_constrain (mut self, #iden: #fn_ty) -> Self { + #types_constraint + #needs + + self.#iden = #assign; + self + } + }) + }) + .collect() } fn build_assign(field: &Field, is_string: bool, is_option: bool) -> TokenStream { - let iden = &field.ident; - - let assign = if field.attr.into { - quote_spanned! {field.span()=> #iden.into() } - } else if is_string { - quote_spanned! {field.span()=> #iden.to_string() } - } else { - quote_spanned! {field.span()=> #iden } - }; - - if is_option { - quote_spanned! {field.span()=> Some(#assign) } - } else { - quote_spanned! {field.span()=> #assign } - } + let iden = &field.ident; + + let assign = if field.attr.into { + quote_spanned! {field.span()=> #iden.into() } + } else if is_string { + quote_spanned! {field.span()=> #iden.to_string() } + } else { + quote_spanned! {field.span()=> #iden } + }; + + if is_option { + quote_spanned! {field.span()=> Some(#assign) } + } else { + quote_spanned! {field.span()=> #assign } + } } fn build_needs(field: &Field) -> syn::Result { - let Some(needs) = &field.attr.needs else { - return Ok(TokenStream::new()); - }; - - let Expr::Array(array) = needs else { - return Err(syn::Error::new_spanned( - needs, - "needs must be a path expression", - )); - }; - - let mut tokens = TokenStream::new(); - let segments = &array.elems; - - for (i, segment) in segments.iter().enumerate() { - let ident = &segment; - - if i == 0 { - tokens.extend(quote_spanned! {segment.span()=> !self.#ident }); - } else { - tokens.extend(quote_spanned! {segment.span()=> && !self.#ident }); - } - } - - let iden = &field.ident; - Ok(quote_spanned! {field.span()=> - if #tokens { - panic!("{} requires one of {} to be set.", stringify!(#iden), stringify!(#needs)); - } - }) + let Some(needs) = &field.attr.needs else { + return Ok(TokenStream::new()); + }; + + let Expr::Array(array) = needs else { + return Err(syn::Error::new_spanned( + needs, + "needs must be a path expression", + )); + }; + + let mut tokens = TokenStream::new(); + let segments = &array.elems; + + for (i, segment) in segments.iter().enumerate() { + let ident = &segment; + + if i == 0 { + tokens.extend(quote_spanned! {segment.span()=> !self.#ident }); + } else { + tokens.extend(quote_spanned! {segment.span()=> && !self.#ident }); + } + } + + let iden = &field.ident; + Ok(quote_spanned! 
{field.span()=> + if #tokens { + panic!("{} requires one of {} to be set.", stringify!(#iden), stringify!(#needs)); + } + }) } fn get_option_inner(r#type: &Type) -> Option<&Type> { - let Type::Path(path) = r#type else { - return None; - // path.path.segments.first().unwrap().ident == "Option" - }; + let Type::Path(path) = r#type else { + return None; + // path.path.segments.first().unwrap().ident == "Option" + }; - if path.path.segments.first().unwrap().ident != "Option" { - return None; - } + if path.path.segments.first().unwrap().ident != "Option" { + return None; + } - let PathArguments::AngleBracketed(args) = &path.path.segments.first().unwrap().arguments else { - return None; - }; + let PathArguments::AngleBracketed(args) = &path.path.segments.first().unwrap().arguments else { + return None; + }; - let GenericArgument::Type(ty) = args.args.first().unwrap() else { - return None; - }; + let GenericArgument::Type(ty) = args.args.first().unwrap() else { + return None; + }; - Some(ty) + Some(ty) } diff --git a/ensemble_derive/src/lib.rs b/ensemble_derive/src/lib.rs index f52ddc4..53db55e 100644 --- a/ensemble_derive/src/lib.rs +++ b/ensemble_derive/src/lib.rs @@ -8,58 +8,58 @@ mod model; #[proc_macro_derive(Model, attributes(ensemble, model, validate))] pub fn derive_model(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let mut ast = parse_macro_input!(input as DeriveInput); - let opts = match deluxe::extract_attributes(&mut ast) { - Ok(opts) => opts, - Err(e) => return e.into_compile_error().into(), - }; + let mut ast = parse_macro_input!(input as DeriveInput); + let opts = match deluxe::extract_attributes(&mut ast) { + Ok(opts) => opts, + Err(e) => return e.into_compile_error().into(), + }; - model::r#impl(&ast, opts) - .unwrap_or_else(syn::Error::into_compile_error) - .into() + model::r#impl(&ast, opts) + .unwrap_or_else(syn::Error::into_compile_error) + .into() } #[proc_macro_derive(Column, attributes(builder))] pub fn derive_column(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let ast = parse_macro_input!(input as DeriveInput); + let ast = parse_macro_input!(input as DeriveInput); - column::r#impl(&ast) - .unwrap_or_else(syn::Error::into_compile_error) - .into() + column::r#impl(&ast) + .unwrap_or_else(syn::Error::into_compile_error) + .into() } #[derive(Clone, Copy)] pub(crate) enum Relationship { - HasOne, - HasMany, - BelongsTo, - BelongsToMany, + HasOne, + HasMany, + BelongsTo, + BelongsToMany, } impl Display for Relationship { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - Self::HasOne => "HasOne", - Self::HasMany => "HasMany", - Self::BelongsTo => "BelongsTo", - Self::BelongsToMany => "BelongsToMany", - } - ) - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::HasOne => "HasOne", + Self::HasMany => "HasMany", + Self::BelongsTo => "BelongsTo", + Self::BelongsToMany => "BelongsToMany", + } + ) + } } #[allow(clippy::fallible_impl_from)] impl From for Relationship { - fn from(value: String) -> Self { - match value.as_str() { - "HasOne" => Self::HasOne, - "HasMany" => Self::HasMany, - "BelongsTo" => Self::BelongsTo, - "BelongsToMany" => Self::BelongsToMany, - _ => panic!("Unknown relationship found."), - } - } + fn from(value: String) -> Self { + match value.as_str() { + "HasOne" => Self::HasOne, + "HasMany" => Self::HasMany, + "BelongsTo" => Self::BelongsTo, + "BelongsToMany" => Self::BelongsToMany, + _ => 
panic!("Unknown relationship found."), + } + } } diff --git a/ensemble_derive/src/model/default/mod.rs b/ensemble_derive/src/model/default/mod.rs index 85a05dd..caf88dc 100644 --- a/ensemble_derive/src/model/default/mod.rs +++ b/ensemble_derive/src/model/default/mod.rs @@ -7,53 +7,53 @@ use super::field::Fields; #[derive(Debug, Default)] pub enum Value { - #[default] - Default, - Expr(Expr), + #[default] + Default, + Expr(Expr), } impl ParseMetaItem for Value { - fn parse_meta_item(input: ParseStream, _mode: ParseMode) -> syn::Result { - Ok(Self::Expr(input.parse::()?)) - } + fn parse_meta_item(input: ParseStream, _mode: ParseMode) -> syn::Result { + Ok(Self::Expr(input.parse::()?)) + } - fn parse_meta_item_flag(_: Span) -> syn::Result { - Ok(Self::Default) - } + fn parse_meta_item_flag(_: Span) -> syn::Result { + Ok(Self::Default) + } } #[derive(Debug, ParseMetaItem, Default)] #[deluxe(default)] pub struct Options { - pub uuid: bool, - pub created_at: bool, - pub updated_at: bool, - pub incrementing: Option, - #[deluxe(rename = default)] - pub value: Option, + pub uuid: bool, + pub created_at: bool, + pub updated_at: bool, + pub incrementing: Option, + #[deluxe(rename = default)] + pub value: Option, } pub fn r#impl(name: &Ident, fields: &Fields) -> syn::Result { - let mut defaults = vec![]; - let primary_key = fields.primary_key()?; - - for field in &fields.fields { - let ident = &field.ident; - let default = field - .default(name, primary_key)? - .unwrap_or_else(|| quote_spanned! { field.span() => Default::default() }); - - defaults.push(quote_spanned! { field.span() => #ident: #default }); - } - - Ok(quote! { - #[automatically_derived] - impl core::default::Default for #name { - fn default() -> Self { - Self { - #(#defaults,)* - } - } - } - }) + let mut defaults = vec![]; + let primary_key = fields.primary_key()?; + + for field in &fields.fields { + let ident = &field.ident; + let default = field + .default(name, primary_key)? + .unwrap_or_else(|| quote_spanned! { field.span() => Default::default() }); + + defaults.push(quote_spanned! { field.span() => #ident: #default }); + } + + Ok(quote! 
{ + #[automatically_derived] + impl core::default::Default for #name { + fn default() -> Self { + Self { + #(#defaults,)* + } + } + } + }) } diff --git a/ensemble_derive/src/model/field.rs b/ensemble_derive/src/model/field.rs index 818e2d3..9b40113 100644 --- a/ensemble_derive/src/model/field.rs +++ b/ensemble_derive/src/model/field.rs @@ -10,191 +10,191 @@ use crate::Relationship; use super::default::{self, Value}; pub struct Fields { - ast: FieldsNamed, - pub fields: Vec, + ast: FieldsNamed, + pub fields: Vec, } pub struct Field { - pub attr: Attr, - ast: syn::Field, - pub ty: syn::Type, - pub ident: syn::Ident, - pub has_validation: bool, + pub attr: Attr, + ast: syn::Field, + pub ty: syn::Type, + pub ident: syn::Ident, + pub has_validation: bool, } #[derive(Debug, ExtractAttributes, Default)] #[deluxe(attributes(validate), default)] struct ValidationAttr { - #[deluxe(rest)] - rules: HashMap, + #[deluxe(rest)] + rules: HashMap, } #[allow(clippy::struct_excessive_bools)] #[derive(ExtractAttributes, Default)] #[deluxe(attributes(model), default)] pub struct Attr { - #[cfg(feature = "json")] - pub hide: bool, - #[cfg(feature = "json")] - pub show: bool, - pub primary: bool, - pub column: Option, - pub local_key: Option, - pub foreign_key: Option, - pub pivot_table: Option, - - #[deluxe(flatten)] - pub default: default::Options, - #[deluxe(skip)] - pub used_in_relationship: bool, + #[cfg(feature = "json")] + pub hide: bool, + #[cfg(feature = "json")] + pub show: bool, + pub primary: bool, + pub column: Option, + pub local_key: Option, + pub foreign_key: Option, + pub pivot_table: Option, + + #[deluxe(flatten)] + pub default: default::Options, + #[deluxe(skip)] + pub used_in_relationship: bool, } impl Field { - pub fn new(mut field: syn::Field) -> Self { - let ident = field.ident.clone().unwrap(); - let mut attr = Attr::extract_attributes(&mut field.attrs).unwrap(); - let validation = ValidationAttr::extract_attributes(&mut field.attrs).unwrap(); - - #[cfg(feature = "json")] - { - attr.hide |= ident == "password"; - } - attr.default.created_at |= ident == "created_at"; - attr.default.updated_at |= ident == "updated_at"; - - Self { - attr, - ident, - ty: field.ty.clone(), - ast: field, - has_validation: !validation.rules.is_empty(), - } - } - - pub fn span(&self) -> proc_macro2::Span { - self.ast.span() - } - - pub fn default(&self, name: &Ident, primary_key: &Self) -> syn::Result> { - let attrs = &self.attr.default; - let is_primary = primary_key.ident == self.ident; - let is_u64 = self.ty.to_token_stream().to_string() == "u64"; - - Ok(if let Some(default) = &attrs.value { - match default { - Value::Expr(expr) => Some(quote_spanned! { self.span() => #expr }), - Value::Default => Some(quote_spanned! { self.span() => Default::default() }), - } - } else if attrs.uuid { - let Type::Path(ty) = &self.ty else { - return Err(syn::Error::new_spanned( - self, - "Field must be of type ensemble::types::Uuid", - )); - }; - - if ty.path.segments.last().unwrap().ident != "Uuid" { - return Err(syn::Error::new_spanned( - ty, - "Field must be of type ensemble::types::Uuid", - )); - } - - Some(quote_spanned! { self.span() => <#ty>::new() }) - } else if attrs.incrementing.unwrap_or(is_primary && is_u64) { - Some(quote_spanned! { self.span() => 0 }) - } else if attrs.created_at || attrs.updated_at { - let Type::Path(ty) = &self.ty else { - return Err(syn::Error::new_spanned( - &self.ty, - "Field must be of type ensemble::types::DateTime", - )); - }; - - Some(quote_spanned! 
{ self.span() => <#ty>::now() }) - } else if let Some((relationship_type, related, _)) = self.relationship(primary_key) { - let relationship_ident = Ident::new(&relationship_type.to_string(), self.span()); - let foreign_key = self.foreign_key(relationship_type, &related); - - if self.attr.column == Some(self.ident.to_string()) { - return Err(syn::Error::new_spanned( - self, - "You cannot name a relationship field the same as the column it references.", - )); - } - - Some( - quote_spanned! { self.span() => <#relationship_ident<#name, #related>>::build(Default::default(), #foreign_key) }, - ) - } else if self.ty.to_token_stream().to_string().starts_with("Option") { - Some(quote_spanned! { self.span() => None }) - } else { - None - }) - } - - pub(crate) fn foreign_key( - &self, - relationship_type: Relationship, - related: &Ident, - ) -> TokenStream { - match relationship_type { - Relationship::BelongsToMany => { - let local_key = wrap_option(self.attr.local_key.clone()); - let pivot_table = wrap_option(self.attr.pivot_table.clone()); - let foreign_key = wrap_option(self.attr.foreign_key.clone()); - - quote_spanned! {self.span()=> (#pivot_table, #foreign_key, #local_key) } - } - Relationship::BelongsTo => { - quote_spanned! {self.span()=> Some(#related::PRIMARY_KEY.to_string()) } - } - _ => wrap_option(self.attr.foreign_key.clone()), - } - } - - pub fn has_relationship(&self) -> bool { - let Type::Path(ty) = &self.ty else { - return false; - }; - - let Some(ty) = ty.path.segments.first() else { - return false; - }; - - ["HasOne", "HasMany", "BelongsTo", "BelongsToMany"].contains(&ty.ident.to_string().as_str()) - } - - pub(crate) fn relationship( - &self, - primary_key: &Self, - ) -> Option<(Relationship, Ident, (String, TokenStream))> { - let Type::Path(ty) = &self.ty else { - return None; - }; - - let Some(ty) = ty.path.segments.first() else { - return None; - }; - - let relationship_type = ty.ident.to_string(); - if !["HasOne", "HasMany", "BelongsTo", "BelongsToMany"] - .contains(&relationship_type.as_str()) - { - return None; - } - let relationship_type: Relationship = relationship_type.into(); - - let PathArguments::AngleBracketed(ty) = &ty.arguments else { - panic!("Expected generic argument"); - }; - let GenericArgument::Type(Type::Path(ty)) = ty.args.last().unwrap() else { - panic!("Expected generic argument"); - }; - - let related = &ty.path.segments.first().unwrap().ident; - - let value_key = match relationship_type { + pub fn new(mut field: syn::Field) -> Self { + let ident = field.ident.clone().unwrap(); + let mut attr = Attr::extract_attributes(&mut field.attrs).unwrap(); + let validation = ValidationAttr::extract_attributes(&mut field.attrs).unwrap(); + + #[cfg(feature = "json")] + { + attr.hide |= ident == "password"; + } + attr.default.created_at |= ident == "created_at"; + attr.default.updated_at |= ident == "updated_at"; + + Self { + attr, + ident, + ty: field.ty.clone(), + ast: field, + has_validation: !validation.rules.is_empty(), + } + } + + pub fn span(&self) -> proc_macro2::Span { + self.ast.span() + } + + pub fn default(&self, name: &Ident, primary_key: &Self) -> syn::Result> { + let attrs = &self.attr.default; + let is_primary = primary_key.ident == self.ident; + let is_u64 = self.ty.to_token_stream().to_string() == "u64"; + + Ok(if let Some(default) = &attrs.value { + match default { + Value::Expr(expr) => Some(quote_spanned! { self.span() => #expr }), + Value::Default => Some(quote_spanned! 
{ self.span() => Default::default() }), + } + } else if attrs.uuid { + let Type::Path(ty) = &self.ty else { + return Err(syn::Error::new_spanned( + self, + "Field must be of type ensemble::types::Uuid", + )); + }; + + if ty.path.segments.last().unwrap().ident != "Uuid" { + return Err(syn::Error::new_spanned( + ty, + "Field must be of type ensemble::types::Uuid", + )); + } + + Some(quote_spanned! { self.span() => <#ty>::new() }) + } else if attrs.incrementing.unwrap_or(is_primary && is_u64) { + Some(quote_spanned! { self.span() => 0 }) + } else if attrs.created_at || attrs.updated_at { + let Type::Path(ty) = &self.ty else { + return Err(syn::Error::new_spanned( + &self.ty, + "Field must be of type ensemble::types::DateTime", + )); + }; + + Some(quote_spanned! { self.span() => <#ty>::now() }) + } else if let Some((relationship_type, related, _)) = self.relationship(primary_key) { + let relationship_ident = Ident::new(&relationship_type.to_string(), self.span()); + let foreign_key = self.foreign_key(relationship_type, &related); + + if self.attr.column == Some(self.ident.to_string()) { + return Err(syn::Error::new_spanned( + self, + "You cannot name a relationship field the same as the column it references.", + )); + } + + Some( + quote_spanned! { self.span() => <#relationship_ident<#name, #related>>::build(Default::default(), #foreign_key) }, + ) + } else if self.ty.to_token_stream().to_string().starts_with("Option") { + Some(quote_spanned! { self.span() => None }) + } else { + None + }) + } + + pub(crate) fn foreign_key( + &self, + relationship_type: Relationship, + related: &Ident, + ) -> TokenStream { + match relationship_type { + Relationship::BelongsToMany => { + let local_key = wrap_option(self.attr.local_key.clone()); + let pivot_table = wrap_option(self.attr.pivot_table.clone()); + let foreign_key = wrap_option(self.attr.foreign_key.clone()); + + quote_spanned! {self.span()=> (#pivot_table, #foreign_key, #local_key) } + }, + Relationship::BelongsTo => { + quote_spanned! 
{self.span()=> Some(#related::PRIMARY_KEY.to_string()) } + }, + _ => wrap_option(self.attr.foreign_key.clone()), + } + } + + pub fn has_relationship(&self) -> bool { + let Type::Path(ty) = &self.ty else { + return false; + }; + + let Some(ty) = ty.path.segments.first() else { + return false; + }; + + ["HasOne", "HasMany", "BelongsTo", "BelongsToMany"].contains(&ty.ident.to_string().as_str()) + } + + pub(crate) fn relationship( + &self, + primary_key: &Self, + ) -> Option<(Relationship, Ident, (String, TokenStream))> { + let Type::Path(ty) = &self.ty else { + return None; + }; + + let Some(ty) = ty.path.segments.first() else { + return None; + }; + + let relationship_type = ty.ident.to_string(); + if !["HasOne", "HasMany", "BelongsTo", "BelongsToMany"] + .contains(&relationship_type.as_str()) + { + return None; + } + let relationship_type: Relationship = relationship_type.into(); + + let PathArguments::AngleBracketed(ty) = &ty.arguments else { + panic!("Expected generic argument"); + }; + let GenericArgument::Type(Type::Path(ty)) = ty.args.last().unwrap() else { + panic!("Expected generic argument"); + }; + + let related = &ty.path.segments.first().unwrap().ident; + + let value_key = match relationship_type { Relationship::BelongsToMany | Relationship::HasOne | Relationship::HasMany => ( primary_key.ident.to_string(), primary_key.ident.to_token_stream(), @@ -208,97 +208,97 @@ impl Field { ), }; - Some((relationship_type, related.clone(), value_key)) - } + Some((relationship_type, related.clone(), value_key)) + } } impl ToTokens for Field { - fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { - self.ast.to_tokens(tokens); - } + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + self.ast.to_tokens(tokens); + } } impl Fields { - pub fn should_validate(&self) -> bool { - self.fields.iter().any(|f| f.has_validation) - } - - pub fn primary_key(&self) -> syn::Result<&Field> { - let mut primary = None; - let mut id_field = None; - - for field in &self.fields { - if field.attr.primary { - if primary.is_some() { - return Err(syn::Error::new_spanned( - field, - "Only one field can be marked as primary", - )); - } - - primary = Some(field); - } else if field.ident == "id" { - id_field = Some(field); - } - } - - primary.or(id_field).ok_or_else(|| { - syn::Error::new_spanned( + pub fn should_validate(&self) -> bool { + self.fields.iter().any(|f| f.has_validation) + } + + pub fn primary_key(&self) -> syn::Result<&Field> { + let mut primary = None; + let mut id_field = None; + + for field in &self.fields { + if field.attr.primary { + if primary.is_some() { + return Err(syn::Error::new_spanned( + field, + "Only one field can be marked as primary", + )); + } + + primary = Some(field); + } else if field.ident == "id" { + id_field = Some(field); + } + } + + primary.or(id_field).ok_or_else(|| { + syn::Error::new_spanned( self, "No primary key found. 
Either mark a field with `#[model(primary)]` or name it `id`.", ) - }) - } - - pub fn relationships(&self) -> Vec<&Field> { - self.fields - .iter() - .filter(|f| f.has_relationship()) - .collect() - } - - pub fn mark_relationship_keys(&mut self) -> syn::Result<()> { - let primary_key = self.primary_key()?; - let relationship_keys = self - .relationships() - .iter() - .filter_map(|f| f.relationship(primary_key)) - .map(|(_, _, (key, _))| key) - .collect::>(); - - self.fields - .iter_mut() - .filter(|f| relationship_keys.contains(&f.ident.to_string())) - .for_each(|f| { - f.attr.used_in_relationship = true; - }); - - Ok(()) - } + }) + } + + pub fn relationships(&self) -> Vec<&Field> { + self.fields + .iter() + .filter(|f| f.has_relationship()) + .collect() + } + + pub fn mark_relationship_keys(&mut self) -> syn::Result<()> { + let primary_key = self.primary_key()?; + let relationship_keys = self + .relationships() + .iter() + .filter_map(|f| f.relationship(primary_key)) + .map(|(_, _, (key, _))| key) + .collect::>(); + + self.fields + .iter_mut() + .filter(|f| relationship_keys.contains(&f.ident.to_string())) + .for_each(|f| { + f.attr.used_in_relationship = true; + }); + + Ok(()) + } } impl ToTokens for Fields { - fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { - self.ast.to_tokens(tokens); - } + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + self.ast.to_tokens(tokens); + } } impl TryFrom for Fields { - type Error = syn::Error; - fn try_from(ast: FieldsNamed) -> Result { - let fields = ast.named.iter().map(|f| Field::new(f.clone())).collect(); + type Error = syn::Error; + fn try_from(ast: FieldsNamed) -> Result { + let fields = ast.named.iter().map(|f| Field::new(f.clone())).collect(); - let mut fields = Self { ast, fields }; + let mut fields = Self { ast, fields }; - fields.mark_relationship_keys()?; + fields.mark_relationship_keys()?; - Ok(fields) - } + Ok(fields) + } } fn wrap_option(option: Option) -> TokenStream { - option.map_or_else( - || quote! { None }, - |value| quote! { Some(#value.to_string()) }, - ) + option.map_or_else( + || quote! { None }, + |value| quote! 
{ Some(#value.to_string()) }, + ) } diff --git a/ensemble_derive/src/model/mod.rs b/ensemble_derive/src/model/mod.rs index 2938d4b..137973f 100644 --- a/ensemble_derive/src/model/mod.rs +++ b/ensemble_derive/src/model/mod.rs @@ -16,93 +16,92 @@ mod serde; #[derive(ExtractAttributes, Default)] #[deluxe(attributes(ensemble), default)] pub struct Opts { - #[deluxe(rename = table)] - table_name: Option, + #[deluxe(rename = table)] + table_name: Option, } pub fn r#impl(ast: &DeriveInput, opts: Opts) -> syn::Result { - let syn::Data::Struct(r#struct) = &ast.data else { - return Err(syn::Error::new_spanned( - ast, - "Model derive only supports structs", - )); - }; - - let syn::Fields::Named(struct_fields) = &r#struct.fields else { - return Err(syn::Error::new_spanned( - ast, - "Model derive only supports named fields", - )); - }; - - let fields = Fields::try_from(struct_fields.clone())?; - let primary_key = fields.primary_key()?; - - let find_impl = impl_find(primary_key); - let fresh_impl = impl_fresh(primary_key); - let eager_load_impl = impl_eager_load(&fields); - let save_impl = impl_save(&fields, primary_key); - let primary_key_impl = impl_primary_key(primary_key); - let fill_relation_impl = impl_fill_relation(&fields); - let serde_impl = serde::r#impl(&ast.ident, &fields)?; - let default_impl = default::r#impl(&ast.ident, &fields)?; - let create_impl = impl_create(&ast.ident, &fields, primary_key); - let relationships_impl = impl_relationships(&ast.ident, &fields)?; - let table_name_impl = impl_table_name(&ast.ident.to_string(), opts.table_name); - - let name = &ast.ident; - let primary_key_type = &primary_key.ty; - let gen = quote! { - const _: () = { - use ::ensemble::relationships::Relationship; - #[automatically_derived] - #[ensemble::async_trait] - impl Model for #name { - type PrimaryKey = #primary_key_type; - const NAME: &'static str = stringify!(#name); - - #save_impl - #find_impl - #fresh_impl - #create_impl - #table_name_impl - #eager_load_impl - #primary_key_impl - #fill_relation_impl - } - #serde_impl - #default_impl - #relationships_impl - }; - }; - - Ok(gen) + let syn::Data::Struct(r#struct) = &ast.data else { + return Err(syn::Error::new_spanned( + ast, + "Model derive only supports structs", + )); + }; + + let syn::Fields::Named(struct_fields) = &r#struct.fields else { + return Err(syn::Error::new_spanned( + ast, + "Model derive only supports named fields", + )); + }; + + let fields = Fields::try_from(struct_fields.clone())?; + let primary_key = fields.primary_key()?; + + let find_impl = impl_find(primary_key); + let fresh_impl = impl_fresh(primary_key); + let eager_load_impl = impl_eager_load(&fields); + let save_impl = impl_save(&fields, primary_key); + let primary_key_impl = impl_primary_key(primary_key); + let fill_relation_impl = impl_fill_relation(&fields); + let serde_impl = serde::r#impl(&ast.ident, &fields)?; + let default_impl = default::r#impl(&ast.ident, &fields)?; + let create_impl = impl_create(&ast.ident, &fields, primary_key); + let relationships_impl = impl_relationships(&ast.ident, &fields)?; + let table_name_impl = impl_table_name(&ast.ident.to_string(), opts.table_name); + + let name = &ast.ident; + let primary_key_type = &primary_key.ty; + let gen = quote! 
{ + const _: () = { + use ::ensemble::relationships::Relationship; + #[automatically_derived] + impl Model for #name { + type PrimaryKey = #primary_key_type; + const NAME: &'static str = stringify!(#name); + + #save_impl + #find_impl + #fresh_impl + #create_impl + #table_name_impl + #eager_load_impl + #primary_key_impl + #fill_relation_impl + } + #serde_impl + #default_impl + #relationships_impl + }; + }; + + Ok(gen) } fn impl_fill_relation(fields: &Fields) -> TokenStream { - let relationships = fields.relationships(); - - let fill_relation = relationships.iter().map(|field| { - let ident = &field.ident; - - quote_spanned! {field.span() => - stringify!(#ident) => self.#ident.r#match(related), - } - }); - - quote! { - fn fill_relation(&mut self, relation: &str, related: &[::std::collections::HashMap<::std::string::String, ::ensemble::rbs::Value>]) -> Result<(), ::ensemble::Error> { - match relation { - #(#fill_relation)* - _ => panic!("Model does not have a {relation} relation"), - } - } - } + let relationships = fields.relationships(); + + let fill_relation = relationships.iter().map(|field| { + let ident = &field.ident; + + quote_spanned! {field.span() => + stringify!(#ident) => self.#ident.r#match(related), + } + }); + + quote! { + fn fill_relation(&mut self, relation: &str, related: &[::std::collections::HashMap<::std::string::String, ::ensemble::rbs::Value>]) -> Result<(), ::ensemble::Error> { + match relation { + #(#fill_relation)* + _ => panic!("Model does not have a {relation} relation"), + } + } + } } fn impl_eager_load(fields: &Fields) -> TokenStream { - let relationships = fields.relationships(); + let relationships = fields.relationships(); - let eager_loads = relationships.iter().map(|field| { + let eager_loads = relationships.iter().map(|field| { let ident = &field.ident; quote_spanned! {field.span() => @@ -110,206 +109,206 @@ fn impl_eager_load(fields: &Fields) -> TokenStream { } }); - quote! { - #[allow(clippy::cloned_instead_of_copied)] - fn eager_load(&self, relation: &str, related: &[&Self]) -> ::ensemble::query::Builder { - match relation { - #(#eager_loads)* - _ => panic!("Model does not have a {relation} relation"), - } - } - } + quote! { + #[allow(clippy::cloned_instead_of_copied)] + fn eager_load(&self, relation: &str, related: &[&Self]) -> ::ensemble::query::Builder { + match relation { + #(#eager_loads)* + _ => panic!("Model does not have a {relation} relation"), + } + } + } } fn impl_fresh(primary_key: &Field) -> TokenStream { - let ident = &primary_key.ident; + let ident = &primary_key.ident; - quote! { - async fn fresh(&self) -> Result { - Self::find(self.#ident.clone()).await - } - } + quote! { + async fn fresh(&self) -> Result { + Self::find(self.#ident.clone()).await + } + } } fn impl_relationships(name: &Ident, fields: &Fields) -> syn::Result { - let primary_key = fields.primary_key()?; - let relationships = fields.relationships(); - - if relationships.is_empty() { - return Ok(TokenStream::new()); - } - - let impls = relationships.iter().map(|f| { - let ident = &f.ident; - let (r#type, related, _) = f.relationship(primary_key).unwrap(); - let return_type = match r#type { - Relationship::HasMany | Relationship::BelongsToMany => { - quote! { ::std::vec::Vec<#related> } - } - Relationship::HasOne | Relationship::BelongsTo => { - quote! { #related } - } - }; - - quote_spanned! {f.span() => - #[allow(dead_code)] - pub async fn #ident(&mut self) -> Result<&mut #return_type, ::ensemble::Error> { - self.#ident.get().await - } - } - }); - - Ok(quote! 
{ - impl #name { - #(#impls)* - } - }) + let primary_key = fields.primary_key()?; + let relationships = fields.relationships(); + + if relationships.is_empty() { + return Ok(TokenStream::new()); + } + + let impls = relationships.iter().map(|f| { + let ident = &f.ident; + let (r#type, related, _) = f.relationship(primary_key).unwrap(); + let return_type = match r#type { + Relationship::HasMany | Relationship::BelongsToMany => { + quote! { ::std::vec::Vec<#related> } + }, + Relationship::HasOne | Relationship::BelongsTo => { + quote! { #related } + }, + }; + + quote_spanned! {f.span() => + #[allow(dead_code)] + pub async fn #ident(&mut self) -> Result<&mut #return_type, ::ensemble::Error> { + self.#ident.get().await + } + } + }); + + Ok(quote! { + impl #name { + #(#impls)* + } + }) } fn impl_save(fields: &Fields, primary_key: &Field) -> TokenStream { - let ident = &primary_key.ident; - let run_validation = if fields.should_validate() { - quote! { - self.validate()?; - } - } else { - TokenStream::new() - }; - let update_timestamp = fields - .fields - .iter() - .filter(|f| f.attr.default.updated_at) - .map(|field| { - let ident = &field.ident; - - quote_spanned! {field.span() => - self.#ident = ::ensemble::types::DateTime::now(); - } - }) - .collect::(); - - quote! { - async fn save(&mut self) -> Result<(), ::ensemble::Error> { - #update_timestamp - #run_validation - - let rows_affected = Self::query() - .r#where(Self::PRIMARY_KEY, "=", &self.#ident) - .update(::ensemble::value::for_db(self)?) - .await?; - - if rows_affected != 1 { - return Err(::ensemble::Error::UniqueViolation); - } - - Ok(()) - } - } + let ident = &primary_key.ident; + let run_validation = if fields.should_validate() { + quote! { + self.validate()?; + } + } else { + TokenStream::new() + }; + let update_timestamp = fields + .fields + .iter() + .filter(|f| f.attr.default.updated_at) + .map(|field| { + let ident = &field.ident; + + quote_spanned! {field.span() => + self.#ident = ::ensemble::types::DateTime::now(); + } + }) + .collect::(); + + quote! { + async fn save(&mut self) -> Result<(), ::ensemble::Error> { + #update_timestamp + #run_validation + + let rows_affected = Self::query() + .r#where(Self::PRIMARY_KEY, "=", &self.#ident) + .update(::ensemble::value::for_db(self)?) + .await?; + + if rows_affected != 1 { + return Err(::ensemble::Error::UniqueViolation); + } + + Ok(()) + } + } } fn impl_find(primary_key: &Field) -> TokenStream { - let ident = &primary_key.ident; - - quote! { - async fn find(#ident: Self::PrimaryKey) -> Result { - Self::query() - .r#where(Self::PRIMARY_KEY, "=", ::ensemble::value::for_db(#ident)?) - .first() - .await? - .ok_or(::ensemble::Error::NotFound) - } - } + let ident = &primary_key.ident; + + quote! { + async fn find(#ident: Self::PrimaryKey) -> Result { + Self::query() + .r#where(Self::PRIMARY_KEY, "=", ::ensemble::value::for_db(#ident)?) + .first() + .await? + .ok_or(::ensemble::Error::NotFound) + } + } } fn impl_create(name: &Ident, fields: &Fields, primary_key: &Field) -> TokenStream { - let is_primary_u64 = (&primary_key.ty).into_token_stream().to_string() == "u64"; - - let required = fields - .fields - .iter() - .filter(|f| { - f.default(name, primary_key) - .map(|o| o.is_none()) - .unwrap_or(false) - }) - .map(|field| { - let ty = &field.ty; - let ident = &field.ident; - - quote_spanned! {field.span() => - if self.#ident == <#ty>::default() { - return Err(::ensemble::Error::Required(stringify!(#ident))); - } - } - }); - - let run_validation = if fields.should_validate() { - quote! 
{ - self.validate()?; - } - } else { - TokenStream::new() - }; - - let update_timestamps = fields - .fields - .iter() - .filter(|f| f.attr.default.created_at || f.attr.default.updated_at) - .map(|field| { - let ident = &field.ident; - - quote_spanned! {field.span() => - self.#ident = ::ensemble::types::DateTime::now(); - } - }); - - let insert_and_return = if primary_key - .attr - .default - .incrementing - .unwrap_or(is_primary_u64) - { - let primary_key = &primary_key.ident; - quote! { - self.#primary_key = Self::query().insert(::ensemble::value::for_db(&self)?).await?; - - Ok(self) - } - } else { - quote! { - Self::query().insert(::ensemble::value::for_db(&self)?).await?; - - Ok(self) - } - }; - - quote! { - async fn create(mut self) -> Result { - #(#update_timestamps)* - #run_validation - #(#required)* - #insert_and_return - } - } + let is_primary_u64 = (&primary_key.ty).into_token_stream().to_string() == "u64"; + + let required = fields + .fields + .iter() + .filter(|f| { + f.default(name, primary_key) + .map(|o| o.is_none()) + .unwrap_or(false) + }) + .map(|field| { + let ty = &field.ty; + let ident = &field.ident; + + quote_spanned! {field.span() => + if self.#ident == <#ty>::default() { + return Err(::ensemble::Error::Required(stringify!(#ident))); + } + } + }); + + let run_validation = if fields.should_validate() { + quote! { + self.validate()?; + } + } else { + TokenStream::new() + }; + + let update_timestamps = fields + .fields + .iter() + .filter(|f| f.attr.default.created_at || f.attr.default.updated_at) + .map(|field| { + let ident = &field.ident; + + quote_spanned! {field.span() => + self.#ident = ::ensemble::types::DateTime::now(); + } + }); + + let insert_and_return = if primary_key + .attr + .default + .incrementing + .unwrap_or(is_primary_u64) + { + let primary_key = &primary_key.ident; + quote! { + self.#primary_key = Self::query().insert(::ensemble::value::for_db(&self)?).await?; + + Ok(self) + } + } else { + quote! { + Self::query().insert(::ensemble::value::for_db(&self)?).await?; + + Ok(self) + } + }; + + quote! { + async fn create(mut self) -> Result { + #(#update_timestamps)* + #run_validation + #(#required)* + #insert_and_return + } + } } fn impl_primary_key(primary_key: &Field) -> TokenStream { - let ident = &primary_key.ident; + let ident = &primary_key.ident; - quote! { - const PRIMARY_KEY: &'static str = stringify!(#ident); + quote! { + const PRIMARY_KEY: &'static str = stringify!(#ident); - fn primary_key(&self) -> &Self::PrimaryKey { - &self.#ident - } - } + fn primary_key(&self) -> &Self::PrimaryKey { + &self.#ident + } + } } fn impl_table_name(struct_name: &str, custom_name: Option) -> TokenStream { - let table_name = - custom_name.unwrap_or_else(|| pluralize(&struct_name.to_snake_case(), 2, false)); + let table_name = + custom_name.unwrap_or_else(|| pluralize(&struct_name.to_snake_case(), 2, false)); - quote! { - const TABLE_NAME: &'static str = #table_name; - } + quote! 
{ + const TABLE_NAME: &'static str = #table_name; + } } diff --git a/ensemble_derive/src/model/serde.rs b/ensemble_derive/src/model/serde.rs index ac0c7b5..07e8148 100644 --- a/ensemble_derive/src/model/serde.rs +++ b/ensemble_derive/src/model/serde.rs @@ -7,214 +7,214 @@ use super::field::Fields; use crate::Relationship; pub fn r#impl(name: &Ident, fields: &Fields) -> syn::Result { - let mut serde = impl_serialize(name, fields)?; - serde.extend(impl_deserialize(name, fields)); + let mut serde = impl_serialize(name, fields)?; + serde.extend(impl_deserialize(name, fields)); - Ok(serde) + Ok(serde) } pub fn impl_serialize(name: &Ident, fields: &Fields) -> syn::Result { - let count = fields.fields.len(); - let primary_key = fields.primary_key()?; - - let serialize_for_db = fields.fields.iter().filter_map(|field| { - let ident = &field.ident; - let column = field - .attr - .column - .as_ref() - .map_or(field.ident.clone(), |v| Ident::new(v, field.span())); - - let Some((relationship_type, _, (_, key_expr))) = field.relationship(primary_key) else { - return Some(quote_spanned! {field.span()=> - state.serialize_field(stringify!(#column), &self.#ident)?; - }); - }; - - match relationship_type { - Relationship::BelongsTo => {} - _ => return None, - }; - - Some(quote_spanned! {field.span()=> { - let key: &'static str = #key_expr.leak(); - state.serialize_field(key, &self.#ident)?; - }}) - }); - - let general_serialize = fields.fields.iter().filter_map(|field| { - #[cfg(feature = "json")] - if field.attr.hide && !field.attr.show { - return None; - } - - let ident = &field.ident; - let column = field - .attr - .column - .as_ref() - .map_or(field.ident.clone(), |v| Ident::new(v, field.span())); - - Some(if field.has_relationship() { - quote_spanned! {field.span()=> - if self.#ident.is_loaded() { - state.serialize_field(stringify!(#column), &self.#ident)?; - } - } - } else { - quote_spanned! {field.span()=> - state.serialize_field(stringify!(#column), &self.#ident)?; - } - }) - }); - - let serialize_fields = quote! { - // ugly hack to figure out if we're serializing for rbs. might break in future (or previous) versions of rust. - if ::std::any::type_name::() == ::std::any::type_name::<::ensemble::rbs::Error>() { - #(#serialize_for_db)* - } else { - #(#general_serialize)* - } - }; - - Ok(quote! { - const _: () = { - use ::ensemble::Inflector; - use ::ensemble::serde::ser::SerializeStruct; - #[automatically_derived] - impl ::ensemble::serde::Serialize for #name { - fn serialize(&self, serializer: S) -> Result { - let mut state = serializer.serialize_struct(stringify!(#name), #count)?; - #serialize_fields - state.end() - } - } - }; - }) + let count = fields.fields.len(); + let primary_key = fields.primary_key()?; + + let serialize_for_db = fields.fields.iter().filter_map(|field| { + let ident = &field.ident; + let column = field + .attr + .column + .as_ref() + .map_or(field.ident.clone(), |v| Ident::new(v, field.span())); + + let Some((relationship_type, _, (_, key_expr))) = field.relationship(primary_key) else { + return Some(quote_spanned! {field.span()=> + state.serialize_field(stringify!(#column), &self.#ident)?; + }); + }; + + match relationship_type { + Relationship::BelongsTo => {}, + _ => return None, + }; + + Some(quote_spanned! 
{field.span()=> { + let key: &'static str = #key_expr.leak(); + state.serialize_field(key, &self.#ident)?; + }}) + }); + + let general_serialize = fields.fields.iter().filter_map(|field| { + #[cfg(feature = "json")] + if field.attr.hide && !field.attr.show { + return None; + } + + let ident = &field.ident; + let column = field + .attr + .column + .as_ref() + .map_or(field.ident.clone(), |v| Ident::new(v, field.span())); + + Some(if field.has_relationship() { + quote_spanned! {field.span()=> + if self.#ident.is_loaded() { + state.serialize_field(stringify!(#column), &self.#ident)?; + } + } + } else { + quote_spanned! {field.span()=> + state.serialize_field(stringify!(#column), &self.#ident)?; + } + }) + }); + + let serialize_fields = quote! { + // ugly hack to figure out if we're serializing for rbs. might break in future (or previous) versions of rust. + if ::std::any::type_name::() == ::std::any::type_name::<::ensemble::rbs::Error>() { + #(#serialize_for_db)* + } else { + #(#general_serialize)* + } + }; + + Ok(quote! { + const _: () = { + use ::ensemble::Inflector; + use ::ensemble::serde::ser::SerializeStruct; + #[automatically_derived] + impl ::ensemble::serde::Serialize for #name { + fn serialize(&self, serializer: S) -> Result { + let mut state = serializer.serialize_struct(stringify!(#name), #count)?; + #serialize_fields + state.end() + } + } + }; + }) } pub fn impl_deserialize(name: &Ident, fields: &Fields) -> syn::Result { - let visitor_name = Ident::new( - &format!("__{}", format!("{name} Visitor").to_class_case()), - name.span(), - ); - let enum_key = &fields - .fields - .iter() - .filter_map(|f| { - if f.has_relationship() { - return None; - } - - Some(Ident::new(&f.ident.to_string().to_class_case(), f.span())) - }) - .collect::>(); - - let column = &fields - .fields - .iter() - .filter_map(|f| { - if f.has_relationship() { - return None; - } - - Some( - f.attr - .column - .as_ref() - .map_or(f.ident.clone(), |v| Ident::new(v, f.span())), - ) - }) - .collect::>(); - - let field_deserialize = field_deserialize(column, enum_key); - let visitor_deserialize = visitor_deserialize(name, &visitor_name, fields, column, enum_key)?; - - Ok(quote! { - const _: () = { - use ensemble::Inflector; - use ::ensemble::serde as _serde; - use _serde::de::IntoDeserializer; - use ensemble::relationships::Relationship; - - #[automatically_derived] - impl<'de> _serde::Deserialize<'de> for #name { - fn deserialize>(deserializer: D) -> Result { - enum Field { #(#enum_key,)* Other(String) }; - #field_deserialize - - struct #visitor_name; - #visitor_deserialize - - const FIELDS: &'static [&'static str] = &[#(stringify!(#column)),*]; - - deserializer.deserialize_struct(stringify!(#name), FIELDS, #visitor_name {}) - } - } - }; - }) + let visitor_name = Ident::new( + &format!("__{}", format!("{name} Visitor").to_class_case()), + name.span(), + ); + let enum_key = &fields + .fields + .iter() + .filter_map(|f| { + if f.has_relationship() { + return None; + } + + Some(Ident::new(&f.ident.to_string().to_class_case(), f.span())) + }) + .collect::>(); + + let column = &fields + .fields + .iter() + .filter_map(|f| { + if f.has_relationship() { + return None; + } + + Some( + f.attr + .column + .as_ref() + .map_or(f.ident.clone(), |v| Ident::new(v, f.span())), + ) + }) + .collect::>(); + + let field_deserialize = field_deserialize(column, enum_key); + let visitor_deserialize = visitor_deserialize(name, &visitor_name, fields, column, enum_key)?; + + Ok(quote! 
{ + const _: () = { + use ensemble::Inflector; + use ::ensemble::serde as _serde; + use _serde::de::IntoDeserializer; + use ensemble::relationships::Relationship; + + #[automatically_derived] + impl<'de> _serde::Deserialize<'de> for #name { + fn deserialize>(deserializer: D) -> Result { + enum Field { #(#enum_key,)* Other(String) }; + #field_deserialize + + struct #visitor_name; + #visitor_deserialize + + const FIELDS: &'static [&'static str] = &[#(stringify!(#column)),*]; + + deserializer.deserialize_struct(stringify!(#name), FIELDS, #visitor_name {}) + } + } + }; + }) } fn field_deserialize(column: &Rc<[Ident]>, enum_key: &Rc<[Ident]>) -> TokenStream { - let expecting_str = column - .iter() - .map(|f| format!("`{f}`")) - .collect::>() - .join(" or "); - - quote! { - impl<'de> _serde::Deserialize<'de> for Field { - fn deserialize>(deserializer: D) -> Result { - struct FieldVisitor; - - impl<'de> _serde::de::Visitor<'de> for FieldVisitor { - type Value = Field; - - fn expecting(&self, formatter: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { - formatter.write_str(#expecting_str) - } - - fn visit_str(self, value: &str) -> Result { - match value { - #(stringify!(#column) => Ok(Field::#enum_key),)* - _ => { - Ok(Field::Other(::std::string::ToString::to_string(value))) - }, - } - } - } - - deserializer.deserialize_identifier(FieldVisitor) - } - } - } + let expecting_str = column + .iter() + .map(|f| format!("`{f}`")) + .collect::>() + .join(" or "); + + quote! { + impl<'de> _serde::Deserialize<'de> for Field { + fn deserialize>(deserializer: D) -> Result { + struct FieldVisitor; + + impl<'de> _serde::de::Visitor<'de> for FieldVisitor { + type Value = Field; + + fn expecting(&self, formatter: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + formatter.write_str(#expecting_str) + } + + fn visit_str(self, value: &str) -> Result { + match value { + #(stringify!(#column) => Ok(Field::#enum_key),)* + _ => { + Ok(Field::Other(::std::string::ToString::to_string(value))) + }, + } + } + } + + deserializer.deserialize_identifier(FieldVisitor) + } + } + } } #[allow(clippy::too_many_lines)] fn visitor_deserialize( - name: &Ident, - visitor_name: &Ident, - fields: &Fields, - column: &Rc<[Ident]>, - enum_key: &Rc<[Ident]>, + name: &Ident, + visitor_name: &Ident, + fields: &Fields, + column: &Rc<[Ident]>, + enum_key: &Rc<[Ident]>, ) -> syn::Result { - let primary_key = fields.primary_key()?; - let key = &fields - .fields - .iter() - .filter(|f| !f.has_relationship()) - .map(|f| &f.ident) - .collect::>(); - - let needs_collect = fields.fields.iter().any(|f| { - let Some((relationship_type, _, _)) = f.relationship(primary_key) else { - return false; - }; - - matches!(relationship_type, Relationship::BelongsTo) - }); - - let required_checks = fields.fields.iter().filter_map(|f| { + let primary_key = fields.primary_key()?; + let key = &fields + .fields + .iter() + .filter(|f| !f.has_relationship()) + .map(|f| &f.ident) + .collect::>(); + + let needs_collect = fields.fields.iter().any(|f| { + let Some((relationship_type, ..)) = f.relationship(primary_key) else { + return false; + }; + + matches!(relationship_type, Relationship::BelongsTo) + }); + + let required_checks = fields.fields.iter().filter_map(|f| { let ident = &f.ident; let column = f .attr @@ -231,17 +231,17 @@ fn visitor_deserialize( Some(quote_spanned! {f.span()=> let #ident: #ty = #ident.ok_or_else(|| _serde::de::Error::missing_field(stringify!(#column)))?; }) }); - let ensure_no_leftovers = if needs_collect { - quote! 
{ - if let Some(key) = __collect.keys().next() { - return Err(_serde::de::Error::unknown_field(&key, FIELDS)); - } - } - } else { - TokenStream::new() - }; - - let model_keys = fields.fields.iter().map(|f| { + let ensure_no_leftovers = if needs_collect { + quote! { + if let Some(key) = __collect.keys().next() { + return Err(_serde::de::Error::unknown_field(&key, FIELDS)); + } + } + } else { + TokenStream::new() + }; + + let model_keys = fields.fields.iter().map(|f| { let ident = &f.ident; let ty = &f.ty; @@ -276,63 +276,63 @@ fn visitor_deserialize( quote_spanned! {f.span()=> #ident: <#relationship_ident<#name, #related>>::build(#key_ident.clone(), #foreign_key) } }); - let build_model = quote! { - let __model = #name { #(#model_keys),* }; - #ensure_no_leftovers - Ok(__model) - }; - - let init_collect = if needs_collect { - quote! { - let mut __collect = ::std::collections::HashMap::::new(); - } - } else { - TokenStream::new() - }; - - let handle_unknown_field = if needs_collect { - quote! { - __collect.insert(name, map.next_value()?); - } - } else { - quote! { - return Err(_serde::de::Error::unknown_field(&name, FIELDS)); - } - }; - - Ok(quote! { - #[allow(clippy::clone_on_copy, clippy::redundant_clone)] - impl<'de> _serde::de::Visitor<'de> for #visitor_name { - type Value = #name; - - fn expecting(&self, formatter: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { - formatter.write_str(&format!("struct {}", stringify!(#name))) - } - - fn visit_map>(self, mut map: V) -> Result<#name, V::Error> { - #(let mut #key = None;)* - #init_collect - - while let Some(key) = map.next_key()? { - match key { - #( - Field::#enum_key => { - if #key.is_some() { - return Err(_serde::de::Error::duplicate_field(stringify!(#column))); - } - #key = Some(map.next_value()?); - }, - )* - Field::Other(name) => { - #handle_unknown_field - } - } - } - - #(#required_checks)* - - #build_model - } - } - }) + let build_model = quote! { + let __model = #name { #(#model_keys),* }; + #ensure_no_leftovers + Ok(__model) + }; + + let init_collect = if needs_collect { + quote! { + let mut __collect = ::std::collections::HashMap::::new(); + } + } else { + TokenStream::new() + }; + + let handle_unknown_field = if needs_collect { + quote! { + __collect.insert(name, map.next_value()?); + } + } else { + quote! { + return Err(_serde::de::Error::unknown_field(&name, FIELDS)); + } + }; + + Ok(quote! { + #[allow(clippy::clone_on_copy, clippy::redundant_clone)] + impl<'de> _serde::de::Visitor<'de> for #visitor_name { + type Value = #name; + + fn expecting(&self, formatter: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + formatter.write_str(&format!("struct {}", stringify!(#name))) + } + + fn visit_map>(self, mut map: V) -> Result<#name, V::Error> { + #(let mut #key = None;)* + #init_collect + + while let Some(key) = map.next_key()? 
{ + match key { + #( + Field::#enum_key => { + if #key.is_some() { + return Err(_serde::de::Error::duplicate_field(stringify!(#column))); + } + #key = Some(map.next_value()?); + }, + )* + Field::Other(name) => { + #handle_unknown_field + } + } + } + + #(#required_checks)* + + #build_model + } + } + }) } diff --git a/examples/user/src/main.rs b/examples/user/src/main.rs index 62b5737..26610fb 100644 --- a/examples/user/src/main.rs +++ b/examples/user/src/main.rs @@ -1,26 +1,26 @@ use ensemble::{ - types::{DateTime, Hashed}, - Model, + types::{DateTime, Hashed}, + Model, }; use std::env; use validator::Validate; #[derive(Debug, Model, Validate)] pub struct User { - pub id: u64, - pub name: String, - pub email: String, - pub password: Hashed, - pub created_at: DateTime, - pub updated_at: DateTime, + pub id: u64, + pub name: String, + pub email: String, + pub password: Hashed, + pub created_at: DateTime, + pub updated_at: DateTime, } #[tokio::main] async fn main() { - ensemble::setup(&env::var("DATABASE_URL").expect("DATABASE_URL must be set")) - .expect("Failed to set up database pool."); + ensemble::setup(&env::var("DATABASE_URL").expect("DATABASE_URL must be set")) + .expect("Failed to set up database pool."); - let users = User::all().await.unwrap(); + let users = User::all().await.unwrap(); - dbg!(users); + dbg!(users); } diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..5c94ef6 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,8 @@ +tab_spaces = 4 +hard_tabs = true +edition = "2021" +use_try_shorthand = true +imports_granularity = "Crate" +use_field_init_shorthand = true +condense_wildcard_suffixes = true +match_block_trailing_comma = true diff --git a/test_suite/tests/derive.rs b/test_suite/tests/derive.rs index cef32fd..24a29c8 100644 --- a/test_suite/tests/derive.rs +++ b/test_suite/tests/derive.rs @@ -1,3 +1,3 @@ mod derive { - automod::dir!("tests/derive"); + automod::dir!("tests/derive"); }
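For reviewers following the generated code above: the derive can drop `#[ensemble::async_trait]` because `async fn` is now written directly in traits and trait impls (stable since Rust 1.75). Below is a minimal, standalone sketch of that pattern only; the `Model` trait, `User` struct, `String` error type, and stub body are illustrative stand-ins, not ensemble's real API, and it assumes a tokio runtime.

// Illustrative only: why the generated impl no longer needs #[async_trait].
// Assumes Rust 1.75+ and a tokio dependency with the "macros"/"rt" features.
trait Model: Sized {
    type PrimaryKey;
    const PRIMARY_KEY: &'static str;

    // Native async trait method; callers simply `.await` it.
    async fn find(key: Self::PrimaryKey) -> Result<Self, String>;
}

struct User {
    id: u64,
    name: String,
}

impl Model for User {
    type PrimaryKey = u64;
    const PRIMARY_KEY: &'static str = "id";

    async fn find(id: u64) -> Result<Self, String> {
        // The real derive routes this through Self::query() (see impl_find);
        // a stub keeps the sketch self-contained.
        Ok(Self { id, name: String::from("example") })
    }
}

#[tokio::main]
async fn main() {
    let user = User::find(1).await.unwrap();
    println!("found {} ({} = {})", user.name, User::PRIMARY_KEY, user.id);
}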
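Similarly, the accessors emitted by impl_relationships are plain `async fn`s on the model that forward to the relationship wrapper's `get()`. A rough, standalone stand-in for that lazy-loading shape follows; `LazyMany`, the `String` error, and the empty query result are assumptions for illustration, not ensemble's `HasMany` type, whose real return type is `Result<&mut Vec<#related>, ::ensemble::Error>`.

// Illustrative stand-in for the lazy relationship wrapper: the generated
// `pub async fn posts(&mut self)` forwards to get(), which presumably loads
// and caches the related rows on first use.
struct LazyMany<T> {
    cached: Option<Vec<T>>,
}

impl<T> LazyMany<T> {
    fn new() -> Self {
        Self { cached: None }
    }

    async fn get(&mut self) -> Result<&mut Vec<T>, String> {
        if self.cached.is_none() {
            // A real implementation would run the relationship query here.
            self.cached = Some(Vec::new());
        }
        Ok(self.cached.as_mut().expect("populated above"))
    }
}

#[allow(dead_code)]
struct Post;

struct User {
    posts: LazyMany<Post>,
}

impl User {
    // Roughly what impl_relationships expands to for a has-many field.
    async fn posts(&mut self) -> Result<&mut Vec<Post>, String> {
        self.posts.get().await
    }
}

#[tokio::main]
async fn main() {
    let mut user = User {
        posts: LazyMany::new(),
    };
    let posts = user.posts().await.unwrap();
    println!("{} related posts loaded", posts.len());
}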
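The custom Deserialize in serde.rs follows the standard serde field-enum + visitor pattern, with an extra `Field::Other(String)` arm for keys that don't map to a plain column. A hand-written, heavily simplified equivalent is sketched below: two plain columns, no relationships, no column renaming, and no `__collect` map. The `User`/`id`/`name` names and the `serde_json` call in `main` are assumptions made only to keep the example runnable.

// Simplified hand-written version of the pattern impl_deserialize generates.
// Assumes serde and serde_json as dependencies.
use std::fmt;

use serde::de::{self, Deserializer, MapAccess, Visitor};

#[derive(Debug)]
struct User {
    id: u64,
    name: String,
}

impl<'de> serde::Deserialize<'de> for User {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        const FIELDS: &[&str] = &["id", "name"];

        // Mirrors the generated `enum Field { #(#enum_key,)* Other(String) }`.
        enum Field {
            Id,
            Name,
            Other(String),
        }

        impl<'de> serde::Deserialize<'de> for Field {
            fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
                struct FieldVisitor;

                impl<'de> Visitor<'de> for FieldVisitor {
                    type Value = Field;

                    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                        f.write_str("`id` or `name`")
                    }

                    fn visit_str<E: de::Error>(self, value: &str) -> Result<Field, E> {
                        Ok(match value {
                            "id" => Field::Id,
                            "name" => Field::Name,
                            other => Field::Other(other.to_string()),
                        })
                    }
                }

                deserializer.deserialize_identifier(FieldVisitor)
            }
        }

        struct UserVisitor;

        impl<'de> Visitor<'de> for UserVisitor {
            type Value = User;

            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                f.write_str("struct User")
            }

            fn visit_map<V: MapAccess<'de>>(self, mut map: V) -> Result<User, V::Error> {
                let mut id = None;
                let mut name = None;

                while let Some(key) = map.next_key()? {
                    match key {
                        Field::Id => {
                            if id.is_some() {
                                return Err(de::Error::duplicate_field("id"));
                            }
                            id = Some(map.next_value()?);
                        },
                        Field::Name => {
                            if name.is_some() {
                                return Err(de::Error::duplicate_field("name"));
                            }
                            name = Some(map.next_value()?);
                        },
                        // The derive either errors here or, when a BelongsTo key is
                        // expected, stashes the value in a __collect map instead.
                        Field::Other(unknown) => {
                            return Err(de::Error::unknown_field(&unknown, FIELDS));
                        },
                    }
                }

                Ok(User {
                    id: id.ok_or_else(|| de::Error::missing_field("id"))?,
                    name: name.ok_or_else(|| de::Error::missing_field("name"))?,
                })
            }
        }

        deserializer.deserialize_struct("User", FIELDS, UserVisitor)
    }
}

fn main() {
    let user: User = serde_json::from_str(r#"{"id": 1, "name": "Ada"}"#).unwrap();
    println!("{user:?}");
}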