Skip to content

Commit

Permalink
fix: Check enum categories when reading csv
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Mar 6, 2025
1 parent 9c4dc9f commit 9037dca
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,48 @@ impl CategoricalChunkedBuilder {
}
}

/// Looks up the category index stored for `s` in the local mapping.
///
/// `h` must be the hash of `s` as computed by the caller. Returns
/// `Some(idx)` when an occupied entry whose category string equals `s`
/// is found, and `None` when the entry is vacant. A vacant entry is
/// never filled here.
fn try_get_cat_idx(&mut self, s: &str, h: u64) -> Option<u32> {
    // Borrow the read-only fields up front so the closures below do not
    // need to capture `self` while `local_mapping` is mutably borrowed.
    let categories = &self.categories;
    let hasher = &self.local_hasher;

    // SAFETY: index in hashmap are within bounds of categories
    let entry = unsafe {
        self.local_mapping.entry(
            h,
            // Equality: compare the stored index's category string with `s`.
            |idx| categories.value_unchecked(*idx as usize) == s,
            // Re-hash: hash the category string a stored index points to.
            |idx| hasher.hash_one(categories.value_unchecked(*idx as usize)),
        )
    };

    if let HTEntry::Occupied(slot) = entry {
        Some(*slot.get())
    } else {
        None
    }
}

/// Pushes the index of the category `s`, which must already be known.
///
/// Categories can be registered up front with `register_value`, or via `append`.
///
/// # Errors
///
/// Returns a `ComputeError` when `s` is not found in the category state;
/// nothing is pushed in that case.
#[inline]
pub fn try_append_value(&mut self, s: &str) -> PolarsResult<()> {
    let hash = self.local_hasher.hash_one(s);
    match self.try_get_cat_idx(s, hash) {
        Some(cat_idx) => {
            self.cat_builder.push(Some(cat_idx));
            Ok(())
        },
        None => Err(polars_err!(ComputeError: "category {} doesn't exist in Enum dtype", s)),
    }
}

/// Appends an optional value without ever creating a new category.
///
/// `None` is forwarded to `append_null`; `Some(s)` is forwarded to
/// `try_append_value`, whose error is returned unchanged when `s` is not
/// an existing category.
#[inline]
pub fn try_append(&mut self, opt_s: Option<&str>) -> PolarsResult<()> {
    if let Some(s) = opt_s {
        return self.try_append_value(s);
    }
    self.append_null();
    Ok(())
}

/// Registers a value to a categorical index without pushing it.
/// Returns the index and if the value was new.
#[inline]
Expand Down
15 changes: 10 additions & 5 deletions crates/polars-io/src/csv/read/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,6 @@ impl CategoricalField {
self.builder.append_null();
return Ok(());
}

if validate_utf8(bytes) {
if needs_escaping {
polars_ensure!(bytes.len() > 1, ComputeError: "invalid csv file\n\nField `{}` is not properly escaped.", std::str::from_utf8(bytes).map_err(to_compute_err)?);
Expand All @@ -329,13 +328,19 @@ impl CategoricalField {
// SAFETY:
// just did utf8 check
let key = unsafe { std::str::from_utf8_unchecked(&self.escape_scratch) };
self.builder.append_value(key);
if self.is_enum {
self.builder.try_append_value(key)?;
} else {
self.builder.append_value(key);
}
} else {
// SAFETY:
// just did utf8 check
unsafe {
self.builder
.append_value(std::str::from_utf8_unchecked(bytes))
let key = unsafe { std::str::from_utf8_unchecked(bytes) };
if self.is_enum {
self.builder.try_append_value(key)?
} else {
self.builder.append_value(key)
}
}
} else if ignore_errors {
Expand Down
14 changes: 14 additions & 0 deletions py-polars/tests/unit/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2526,3 +2526,17 @@ def test_header_only_column_selection_17173() -> None:
result = pl.read_csv(io.StringIO(csv), columns=["B"])
expected = pl.Series("B", [], pl.String()).to_frame()
assert_frame_equal(result, expected)


def test_csv_enum_raise() -> None:
    """Reading a CSV value that is not an Enum category must raise ComputeError."""
    # "baz" is deliberately absent from the Enum categories.
    enum_dtype = pl.Enum(["foo", "bar"])
    expected_msg = "category baz doesn't exist in Enum dtype"

    with io.StringIO("col\nfoo\nbaz\n") as csv:
        with pytest.raises(pl.exceptions.ComputeError, match=expected_msg):
            pl.read_csv(csv, schema={"col": enum_dtype})

0 comments on commit 9037dca

Please sign in to comment.