From 52c0ce278ca2735e60b4e636795f120a46f17ee4 Mon Sep 17 00:00:00 2001
From: James Stevenson
Date: Tue, 19 Nov 2024 16:28:55 -0500
Subject: [PATCH] feat: don't write duplicate rows to DB (#11)

---
Note (review): line structure and indentation were reconstructed from a
whitespace-collapsed copy of this patch. Context-line indentation (notably in
rust/src/sqlite.rs) and the `Box<dyn std::error::Error>` return type are
assumed -- confirm against the tree, or apply with
`git apply --ignore-whitespace` if a hunk fails to match. The blob ids on the
`index` lines predate the review edits to the hunks.

 rust/src/load.rs   | 31 ++++++++++++++++++++++++++++---
 rust/src/sqlite.rs |  3 ++-
 tests/test_load.py | 15 +++++++++++++++
 3 files changed, 45 insertions(+), 4 deletions(-)

diff --git a/rust/src/load.rs b/rust/src/load.rs
index 21f7106..e765dba 100644
--- a/rust/src/load.rs
+++ b/rust/src/load.rs
@@ -9,7 +9,7 @@ use noodles_vcf::{
     variant::record::info::{self, field::Value as InfoValue},
 };
 use pyo3::{exceptions, prelude::*};
-use sqlx::SqlitePool;
+use sqlx::{error::DatabaseError, sqlite::SqliteError, SqlitePool};
 use std::path::PathBuf;
 use std::time::Instant;
 use tokio::{
@@ -19,12 +19,37 @@ use tokio::{
 
+/// SQLite extended result code `SQLITE_CONSTRAINT_UNIQUE`, in the string form that sqlx's
+/// `DatabaseError::code` reports it.
+const SQLITE_CONSTRAINT_UNIQUE: &str = "2067";
+
+/// Insert `db_row` into `vrs_locations` over a connection acquired from `pool`.
+///
+/// A row rejected by the table's `UNIQUE(vrs_id,chr,pos)` constraint is logged and skipped,
+/// so loading the same input twice is not an error.
+///
+/// # Errors
+///
+/// Returns any other failure from acquiring the connection or executing the insert.
 async fn load_allele(db_row: DbRow, pool: &SqlitePool) -> Result<(), Box<dyn std::error::Error>> {
     let mut conn = pool.acquire().await?;
-    sqlx::query("INSERT INTO vrs_locations (vrs_id, chr, pos) VALUES (?, ?, ?)")
+    let result = sqlx::query("INSERT INTO vrs_locations (vrs_id, chr, pos) VALUES (?, ?, ?)")
         .bind(db_row.vrs_id)
         .bind(db_row.chr)
         .bind(db_row.pos)
         .execute(&mut *conn)
-        .await?;
+        .await;
+    if let Err(err) = result {
+        // Only a SQLite unique-constraint violation is benign; everything else propagates.
+        let is_duplicate = err
+            .as_database_error()
+            .and_then(|db_error| db_error.try_downcast_ref::<SqliteError>())
+            .and_then(|sqlite_error| sqlite_error.code())
+            .map_or(false, |code| code == SQLITE_CONSTRAINT_UNIQUE);
+        if is_duplicate {
+            error!("duplicate");
+            return Ok(());
+        }
+        return Err(err.into());
+    }
     Ok(())
 }
 
diff --git a/rust/src/sqlite.rs b/rust/src/sqlite.rs
index 62224a1..76f0367 100644
--- a/rust/src/sqlite.rs
+++ b/rust/src/sqlite.rs
@@ -24,7 +24,8 @@ pub async fn setup_db(db_url: &str) -> Result<(), Error> {
             id INTEGER PRIMARY KEY AUTOINCREMENT,
             vrs_id TEXT NOT NULL,
             chr TEXT NOT NULL,
-            pos INTEGER NOT NULL
+            pos INTEGER NOT NULL,
+            UNIQUE(vrs_id,chr,pos)
         );",
     )
     .execute(&db)
diff --git a/tests/test_load.py b/tests/test_load.py
index d2778a9..704c0e7 100644
--- a/tests/test_load.py
+++ b/tests/test_load.py
@@ -71,3 +71,18 @@ def test_non_block_gzip(fixture_dir: Path, temp_dir: Path):
     temp_db = temp_dir / "tmp.db"
     with pytest.raises(OSError, match="invalid BGZF header"):
         load.load_vcf(fixture_dir / "input_not_bgzip.vcf.gz", temp_db)
+
+
+def test_load_redundant_rows(fixture_dir: Path, temp_dir: Path):
+    """Loading the same VCF twice must not duplicate rows."""
+    input_file = fixture_dir / "input.vcf"
+    temp_db = temp_dir / "tmp.db"
+    load.load_vcf(input_file, temp_db)
+    load.load_vcf(input_file, temp_db)
+    conn = sqlite3.connect(temp_db)
+    try:
+        results = conn.execute("SELECT * FROM vrs_locations").fetchall()
+    finally:
+        # Close explicitly instead of leaving the connection open for the rest of the run.
+        conn.close()
+    assert len(results) == 10