Skip to content

Commit

Permalink
feat: don't write duplicate rows to DB (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
jsstevenson authored Nov 19, 2024
1 parent 8645846 commit 52c0ce2
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 4 deletions.
21 changes: 18 additions & 3 deletions rust/src/load.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use noodles_vcf::{
variant::record::info::{self, field::Value as InfoValue},
};
use pyo3::{exceptions, prelude::*};
use sqlx::SqlitePool;
use sqlx::{error::DatabaseError, sqlite::SqliteError, SqlitePool};
use std::path::PathBuf;
use std::time::Instant;
use tokio::{
Expand All @@ -19,12 +19,27 @@ use tokio::{

/// Insert a single `DbRow` into the `vrs_locations` table.
///
/// Acquires a connection from `pool` and binds `db_row.vrs_id`, `db_row.chr`, and
/// `db_row.pos` to the three placeholders of the `INSERT` statement.
///
/// # Errors
///
/// Returns an error if a connection cannot be acquired, or if the insert fails for any
/// reason other than SQLite error code `"2067"`. An insert that fails with that code is
/// logged and reported as `Ok(())`.
async fn load_allele(db_row: DbRow, pool: &SqlitePool) -> Result<(), Box<dyn std::error::Error>> {
let mut conn = pool.acquire().await?;
// NOTE(review): this page shows a diff without +/- markers. The next line appears to be
// the removed pre-change statement and the `let result = ...` line after it its
// replacement -- confirm against the repository.
sqlx::query("INSERT INTO vrs_locations (vrs_id, chr, pos) VALUES (?, ?, ?)")
let result = sqlx::query("INSERT INTO vrs_locations (vrs_id, chr, pos) VALUES (?, ?, ?)")
.bind(db_row.vrs_id)
.bind(db_row.chr)
.bind(db_row.pos)
.execute(&mut *conn)
// NOTE(review): likewise, `.await?;` appears to be the removed line and `.await;` its
// replacement, so the outcome is kept in `result` instead of being propagated with `?`.
.await?;
.await;
if let Err(err) = result {
// Errors that are not database errors, or not SQLite errors, skip the duplicate check
// and are returned by the `return Err(...)` below.
if let Some(db_error) = err.as_database_error() {
if let Some(sqlite_error) = db_error.try_downcast_ref::<SqliteError>() {
// "2067" is SQLite's extended result code SQLITE_CONSTRAINT_UNIQUE, i.e. the insert
// violated a UNIQUE constraint. A missing code counts as "not a duplicate".
if sqlite_error
.code()
.map(|code| code == "2067")
.unwrap_or(false)
{
// Duplicate row: log it and return success rather than an error.
error!("duplicate");
return Ok(());
}
}
}
// Any other failure is boxed and handed back to the caller.
return Err(err.into());
}
Ok(())
}

Expand Down
3 changes: 2 additions & 1 deletion rust/src/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ pub async fn setup_db(db_url: &str) -> Result<(), Error> {
id INTEGER PRIMARY KEY AUTOINCREMENT,
vrs_id TEXT NOT NULL,
chr TEXT NOT NULL,
pos INTEGER NOT NULL
pos INTEGER NOT NULL,
UNIQUE(vrs_id,chr,pos)
);",
)
.execute(&db)
Expand Down
10 changes: 10 additions & 0 deletions tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,13 @@ def test_non_block_gzip(fixture_dir: Path, temp_dir: Path):
temp_db = temp_dir / "tmp.db"
with pytest.raises(OSError, match="invalid BGZF header"):
load.load_vcf(fixture_dir / "input_not_bgzip.vcf.gz", temp_db)


def test_load_redundant_rows(fixture_dir: Path, temp_dir: Path):
    """Loading the same VCF twice must not add duplicate rows to ``vrs_locations``.

    The row count after two identical loads is expected to be 10
    (presumably the number of rows one load of ``input.vcf`` produces -- confirm
    against the fixture).
    """
    input_file = fixture_dir / "input.vcf"
    temp_db = temp_dir / "tmp.db"
    # Load the identical input twice into the same database file.
    load.load_vcf(input_file, temp_db)
    load.load_vcf(input_file, temp_db)
    conn = sqlite3.connect(temp_db)
    try:
        results = conn.execute("SELECT * FROM vrs_locations").fetchall()
    finally:
        # Always release the handle on the temp DB, even if the query raises.
        conn.close()
    assert len(results) == 10

0 comments on commit 52c0ce2

Please sign in to comment.