From 050c700d37859bde34505988da6840b457f1024d Mon Sep 17 00:00:00 2001 From: Miguel Piedrafita Date: Thu, 14 Dec 2023 21:01:23 +0100 Subject: [PATCH] Add increment method --- ensemble/src/lib.rs | 24 ++++++++++++++++++++++++ ensemble/src/query.rs | 25 +++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/ensemble/src/lib.rs b/ensemble/src/lib.rs index 1e47364..74630c7 100644 --- a/ensemble/src/lib.rs +++ b/ensemble/src/lib.rs @@ -158,6 +158,7 @@ pub trait Model: DeserializeOwned + Serialize + Sized + Send + Sync + Debug + De Self::query().with(eager_load) } + /// Load a relationship for the model. fn load + Send>( &mut self, relation: T, @@ -173,6 +174,29 @@ pub trait Model: DeserializeOwned + Serialize + Sized + Send + Sync + Debug + De } } + fn increment( + &mut self, + column: &str, + amount: u64, + ) -> impl Future> + Send { + async move { + let rows_affected = Self::query() + .r#where( + Self::PRIMARY_KEY, + "=", + value::for_db(self.primary_key()).unwrap(), + ) + .increment(column, amount) + .await?; + + if rows_affected != 1 { + return Err(Error::UniqueViolation); + } + + Ok(()) + } + } + /// Convert the model to a JSON value. /// /// # Panics diff --git a/ensemble/src/query.rs b/ensemble/src/query.rs index 4dbe667..631722d 100644 --- a/ensemble/src/query.rs +++ b/ensemble/src/query.rs @@ -421,6 +421,31 @@ impl Builder { Ok(rbs::from_value(result.last_insert_id)?) } + /// Increment a column's value by a given amount. Returns the number of affected rows. + /// + /// # Errors + /// + /// Returns an error if the query fails, or if a connection to the database cannot be established. + pub async fn increment(self, column: &str, amount: u64) -> Result { + let mut conn = connection::get().await?; + let (sql, mut bindings) = ( + format!( + "UPDATE {} SET {column} = {column} + ? {}", + self.table, + self.to_sql(Type::Update) + ), + self.get_bindings(), + ); + bindings.insert(0, amount.into()); + + tracing::debug!(sql = sql.as_str(), bindings = ?bindings, "Executing UPDATE SQL query for increment"); + + conn.exec(&sql, bindings) + .await + .map_err(|e| Error::Database(e.to_string())) + .map(|r| r.rows_affected) + } + /// Update records in the database. Returns the number of affected rows. /// /// # Errors