Skip to content

Commit

Permalink
update migrate Transaction and AsyncTransaction execute functions (#346)
Browse files Browse the repository at this point in the history
This allows avoiding double iteration when calling migrate with the grouped option active
  • Loading branch information
jxs authored Aug 3, 2024
1 parent b09e3cf commit 068858e
Show file tree
Hide file tree
Showing 10 changed files with 81 additions and 39 deletions.
2 changes: 1 addition & 1 deletion refinery_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ description = "This crate should not be used directly, it is internally related
license = "MIT OR Apache-2.0"
documentation = "https://docs.rs/refinery/"
repository = "https://github.com/rust-db/refinery"
edition = "2018"
edition = "2021"

[features]
default = []
Expand Down
10 changes: 8 additions & 2 deletions refinery_core/src/drivers/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ use std::convert::Infallible;
impl Transaction for Config {
type Error = Infallible;

fn execute(&mut self, _queries: &[&str]) -> Result<usize, Self::Error> {
fn execute<'a, T: Iterator<Item = &'a str>>(
&mut self,
_queries: T,
) -> Result<usize, Self::Error> {
Ok(0)
}
}
Expand All @@ -33,7 +36,10 @@ impl Query<Vec<Migration>> for Config {
impl AsyncTransaction for Config {
type Error = Infallible;

async fn execute(&mut self, _queries: &[&str]) -> Result<usize, Self::Error> {
async fn execute<'a, T: Iterator<Item = &'a str> + Send>(
&mut self,
_queries: T,
) -> Result<usize, Self::Error> {
Ok(0)
}
}
Expand Down
14 changes: 10 additions & 4 deletions refinery_core/src/drivers/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,13 @@ fn query_applied_migrations(
impl Transaction for Conn {
type Error = MError;

fn execute(&mut self, queries: &[&str]) -> Result<usize, Self::Error> {
fn execute<'a, T: Iterator<Item = &'a str>>(
&mut self,
queries: T,
) -> Result<usize, Self::Error> {
let mut transaction = self.start_transaction(get_tx_opts())?;
let mut count = 0;
for query in queries.iter() {
for query in queries {
transaction.query_iter(query)?;
count += 1;
}
Expand All @@ -58,11 +61,14 @@ impl Transaction for Conn {
impl Transaction for PooledConn {
type Error = MError;

fn execute(&mut self, queries: &[&str]) -> Result<usize, Self::Error> {
fn execute<'a, T: Iterator<Item = &'a str>>(
&mut self,
queries: T,
) -> Result<usize, Self::Error> {
let mut transaction = self.start_transaction(get_tx_opts())?;
let mut count = 0;

for query in queries.iter() {
for query in queries {
transaction.query_iter(query)?;
count += 1;
}
Expand Down
7 changes: 5 additions & 2 deletions refinery_core/src/drivers/mysql_async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,18 @@ async fn query_applied_migrations<'a>(
impl AsyncTransaction for Pool {
type Error = MError;

async fn execute(&mut self, queries: &[&str]) -> Result<usize, Self::Error> {
async fn execute<'a, T: Iterator<Item = &'a str> + Send>(
&mut self,
queries: T,
) -> Result<usize, Self::Error> {
let mut conn = self.get_conn().await?;
let mut options = TxOpts::new();
options.with_isolation_level(Some(IsolationLevel::ReadCommitted));

let mut transaction = conn.start_transaction(options).await?;
let mut count = 0;
for query in queries {
transaction.query_drop(*query).await?;
transaction.query_drop(query).await?;
count += 1;
}
transaction.commit().await?;
Expand Down
7 changes: 5 additions & 2 deletions refinery_core/src/drivers/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,13 @@ fn query_applied_migrations(
impl Transaction for PgClient {
type Error = PgError;

fn execute(&mut self, queries: &[&str]) -> Result<usize, Self::Error> {
fn execute<'a, T: Iterator<Item = &'a str>>(
&mut self,
queries: T,
) -> Result<usize, Self::Error> {
let mut transaction = PgClient::transaction(self)?;
let mut count = 0;
for query in queries.iter() {
for query in queries {
PgTransaction::batch_execute(&mut transaction, query)?;
count += 1;
}
Expand Down
7 changes: 5 additions & 2 deletions refinery_core/src/drivers/rusqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ fn query_applied_migrations(

impl Transaction for RqlConnection {
type Error = RqlError;
fn execute(&mut self, queries: &[&str]) -> Result<usize, Self::Error> {
fn execute<'a, T: Iterator<Item = &'a str>>(
&mut self,
queries: T,
) -> Result<usize, Self::Error> {
let transaction = self.transaction()?;
let mut count = 0;
for query in queries.iter() {
for query in queries {
transaction.execute_batch(query)?;
count += 1;
}
Expand Down
7 changes: 5 additions & 2 deletions refinery_core/src/drivers/tiberius.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,16 @@ where
{
type Error = Error;

async fn execute(&mut self, queries: &[&str]) -> Result<usize, Self::Error> {
async fn execute<'a, T: Iterator<Item = &'a str> + Send>(
&mut self,
queries: T,
) -> Result<usize, Self::Error> {
// Tiberius doesn't support transactions, see https://github.com/prisma/tiberius/issues/28
self.simple_query("BEGIN TRAN T1;").await?;
let mut count = 0;
for query in queries {
// Drop the returning `QueryStream<'a>` to avoid compiler complaning regarding lifetimes
if let Err(err) = self.simple_query(*query).await.map(drop) {
if let Err(err) = self.simple_query(query).await.map(drop) {
if let Err(err) = self.simple_query("ROLLBACK TRAN T1").await {
log::error!("could not ROLLBACK transaction, {}", err);
}
Expand Down
5 changes: 4 additions & 1 deletion refinery_core/src/drivers/tokio_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ async fn query_applied_migrations(
impl AsyncTransaction for Client {
type Error = PgError;

async fn execute(&mut self, queries: &[&str]) -> Result<usize, Self::Error> {
async fn execute<'a, T: Iterator<Item = &'a str> + Send>(
&mut self,
queries: T,
) -> Result<usize, Self::Error> {
let transaction = self.transaction().await?;
let mut count = 0;
for query in queries {
Expand Down
29 changes: 19 additions & 10 deletions refinery_core/src/traits/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@ use crate::traits::{
use crate::{Error, Migration, Report, Target};

use async_trait::async_trait;
use std::ops::Deref;
use std::string::ToString;

#[async_trait]
pub trait AsyncTransaction {
type Error: std::error::Error + Send + Sync + 'static;

async fn execute(&mut self, query: &[&str]) -> Result<usize, Self::Error>;
async fn execute<'a, T: Iterator<Item = &'a str> + Send>(
&mut self,
queries: T,
) -> Result<usize, Self::Error>;
}

#[async_trait]
Expand Down Expand Up @@ -43,10 +47,13 @@ async fn migrate<T: AsyncTransaction>(
migration.set_applied();
let update_query = insert_migration_query(&migration, migration_table_name);
transaction
.execute(&[
migration.sql().as_ref().expect("sql must be Some!"),
&update_query,
])
.execute(
[
migration.sql().as_ref().expect("sql must be Some!"),
update_query.as_str(),
]
.into_iter(),
)
.await
.migration_err(
&format!("error applying migration {}", migration),
Expand Down Expand Up @@ -105,10 +112,10 @@ async fn migrate_grouped<T: AsyncTransaction>(
);
}

let refs: Vec<&str> = grouped_migrations.iter().map(AsRef::as_ref).collect();
let refs = grouped_migrations.iter().map(AsRef::as_ref);

transaction
.execute(refs.as_ref())
.execute(refs)
.await
.migration_err("error applying migrations", None)?;

Expand Down Expand Up @@ -164,9 +171,11 @@ where
target: Target,
migration_table_name: &str,
) -> Result<Report, Error> {
self.execute(&[&Self::assert_migrations_table_query(migration_table_name)])
.await
.migration_err("error asserting migrations table", None)?;
self.execute(
[Self::assert_migrations_table_query(migration_table_name).as_str()].into_iter(),
)
.await
.migration_err("error asserting migrations table", None)?;

let applied_migrations = self
.query(
Expand Down
32 changes: 19 additions & 13 deletions refinery_core/src/traits/sync.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::ops::Deref;

use crate::error::WrapMigrationError;
use crate::traits::{
insert_migration_query, verify_migrations, ASSERT_MIGRATIONS_TABLE_QUERY,
Expand All @@ -8,7 +10,10 @@ use crate::{Error, Migration, Report, Target};
pub trait Transaction {
type Error: std::error::Error + Send + Sync + 'static;

fn execute(&mut self, queries: &[&str]) -> Result<usize, Self::Error>;
fn execute<'a, T: Iterator<Item = &'a str>>(
&mut self,
queries: T,
) -> Result<usize, Self::Error>;
}

pub trait Query<T>: Transaction {
Expand All @@ -20,7 +25,7 @@ pub fn migrate<T: Transaction>(
migrations: Vec<Migration>,
target: Target,
migration_table_name: &str,
batched: bool,
grouped: bool,
) -> Result<Report, Error> {
let mut migration_batch = Vec::new();
let mut applied_migrations = Vec::new();
Expand Down Expand Up @@ -49,7 +54,7 @@ pub fn migrate<T: Transaction>(
migration_batch.push(insert_migration);
}

match (target, batched) {
match (target, grouped) {
(Target::Fake | Target::FakeVersion(_), _) => {
log::info!("not going to apply any migration as fake flag is enabled");
}
Expand All @@ -68,16 +73,14 @@ pub fn migrate<T: Transaction>(
}
};

let refs: Vec<&str> = migration_batch.iter().map(AsRef::as_ref).collect();

if batched {
if grouped {
transaction
.execute(refs.as_ref())
.execute(migration_batch.iter().map(Deref::deref))
.migration_err("error applying migrations", None)?;
} else {
for (i, update) in refs.iter().enumerate() {
for (i, update) in migration_batch.into_iter().enumerate() {
transaction
.execute(&[update])
.execute([update.as_str()].into_iter())
.migration_err("error applying update", Some(&applied_migrations[0..i / 2]))?;
}
}
Expand All @@ -92,10 +95,13 @@ where
fn assert_migrations_table(&mut self, migration_table_name: &str) -> Result<usize, Error> {
// Needed cause some database vendors like Mssql have a non sql standard way of checking the migrations table,
// thou on this case it's just to be consistent with the async trait `AsyncMigrate`
self.execute(&[ASSERT_MIGRATIONS_TABLE_QUERY
.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
.as_str()])
.migration_err("error asserting migrations table", None)
self.execute(
[ASSERT_MIGRATIONS_TABLE_QUERY
.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
.as_str()]
.into_iter(),
)
.migration_err("error asserting migrations table", None)
}

fn get_last_applied_migration(
Expand Down

0 comments on commit 068858e

Please sign in to comment.