
refactor: Refactor code for re-use by streaming NDJSON source #21520

Merged
merged 2 commits on Mar 1, 2025
Changes from 1 commit
59 changes: 42 additions & 17 deletions crates/polars-io/src/ndjson/core.rs
@@ -253,15 +253,7 @@ impl<'a> CoreJsonReader<'a> {

fn count(mut self) -> PolarsResult<usize> {
let bytes = self.reader_bytes.take().unwrap();
let n_threads = self.n_threads.unwrap_or(POOL.current_num_threads());
let file_chunks = get_file_chunks_json(bytes.as_ref(), n_threads);

let iter = file_chunks.par_iter().map(|(start_pos, stop_at_nbytes)| {
let bytes = &bytes[*start_pos..*stop_at_nbytes];
let iter = json_lines(bytes);
iter.count()
});
Ok(POOL.install(|| iter.sum()))
Ok(super::count_rows_par(&bytes))
}

fn parse_json(&mut self, mut n_threads: usize, bytes: &[u8]) -> PolarsResult<DataFrame> {
@@ -304,13 +296,11 @@ impl<'a> CoreJsonReader<'a> {
file_chunks
.into_par_iter()
.map(|(start_pos, stop_at_nbytes)| {
let mut buffers = init_buffers(&self.schema, capacity, self.ignore_errors)?;
parse_lines(&bytes[start_pos..stop_at_nbytes], &mut buffers)?;
let mut local_df = DataFrame::new(
buffers
.into_values()
.map(|buf| buf.into_series().into_column())
.collect::<_>(),
let mut local_df = parse_ndjson(
&bytes[start_pos..stop_at_nbytes],
Some(capacity),
&self.schema,
self.ignore_errors,
)?;

let prepredicate_height = local_df.height() as IdxSize;
@@ -394,7 +384,7 @@ struct Scratch {
buffers: simd_json::Buffers,
}

fn json_lines(bytes: &[u8]) -> impl Iterator<Item = &[u8]> {
pub fn json_lines(bytes: &[u8]) -> impl Iterator<Item = &[u8]> {
// This previously used `serde_json`'s `RawValue` to deserialize chunks without really deserializing them.
// However, this convenience comes at a cost. serde_json allocates and parses and does UTF-8 validation, all
// things we don't need since we use simd_json for them. Also, `serde_json::StreamDeserializer` has a more
@@ -417,6 +407,41 @@ fn parse_lines(bytes: &[u8], buffers: &mut PlIndexMap<BufferKey, Buffer>) -> Pol
Ok(())
}

pub fn parse_ndjson(
bytes: &[u8],
n_rows_hint: Option<usize>,
schema: &Schema,
ignore_errors: bool,
) -> PolarsResult<DataFrame> {
let capacity = n_rows_hint.unwrap_or_else(|| {
// Default to the total length divided by the max length of the first and last non-empty lines, or 1 if there is no non-empty line.
bytes
.split(|&c| c == b'\n')
.find(|x| !x.is_empty())
.map_or(1, |x| {
bytes.len().div_ceil(
x.len().max(
bytes
.rsplit(|&c| c == b'\n')
.find(|x| !x.is_empty())
.unwrap()
.len(),
),
)
})
});

let mut buffers = init_buffers(schema, capacity, ignore_errors)?;
parse_lines(bytes, &mut buffers)?;

DataFrame::new(
buffers
.into_values()
.map(|buf| buf.into_series().into_column())
.collect::<_>(),
)
}
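
// Illustrative sketch, not part of the diff: a worked example of the default
// capacity heuristic above. `example_capacity_heuristic` and its numbers are
// hypothetical and exist only to show the arithmetic.
fn example_capacity_heuristic() -> usize {
    // Two 8-byte NDJSON lines, each followed by a newline: 18 bytes in total.
    let bytes: &[u8] = b"{\"a\": 1}\n{\"a\": 2}\n";
    let first_line_len = 8usize; // first non-empty line, excluding the newline
    let last_line_len = 8usize; // last non-empty line, excluding the newline
    // Estimate = total length divided by the longer boundary line, rounded up:
    // 18.div_ceil(8) == 3, a slight overestimate that is fine for a capacity hint.
    bytes.len().div_ceil(first_line_len.max(last_line_len))
}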

/// Find the nearest next line position.
/// Does not check for new line characters embedded in String fields.
/// This just looks for `}\n`
295 changes: 295 additions & 0 deletions crates/polars-io/src/ndjson/mod.rs
@@ -1,7 +1,10 @@
use core::json_lines;
use std::num::NonZeroUsize;

use arrow::array::StructArray;
use polars_core::prelude::*;
use polars_core::POOL;
use rayon::iter::{IntoParallelIterator, ParallelIterator};

pub(crate) mod buffer;
pub mod core;
@@ -19,3 +22,295 @@ pub fn infer_schema<R: std::io::BufRead>(
.collect();
Ok(schema)
}

/// Statistics for a chunk of text used for NDJSON parsing.
#[derive(Debug, Clone, PartialEq)]
struct ChunkStats {
non_empty_rows: usize,
/// Set to None if the chunk was empty.
last_newline_offset: Option<usize>,
/// Used when counting rows.
has_leading_empty_line: bool,
has_non_empty_remainder: bool,
}

impl ChunkStats {
/// Assumes that:
/// * There is no quoting of newline characters (unlike CSV)
/// * We do not count empty lines (successive newlines, or lines containing only whitespace / tab)
fn from_chunk(chunk: &[u8]) -> Self {
// Notes: Offsets are right-to-left in reverse mode.
let first_newline_offset = memchr::memchr(b'\n', chunk);
let last_newline_offset = memchr::memrchr(b'\n', chunk);

let has_leading_empty_line =
first_newline_offset.is_some_and(|i| json_lines(&chunk[..i]).next().is_none());
let has_non_empty_remainder =
json_lines(&chunk[last_newline_offset.map_or(0, |i| 1 + i)..chunk.len()])
.next()
.is_some();

let mut non_empty_rows = if first_newline_offset.is_some() && !has_leading_empty_line {
1
} else {
0
};

if first_newline_offset.is_some() {
let range = first_newline_offset.unwrap() + 1..last_newline_offset.unwrap() + 1;
non_empty_rows += json_lines(&chunk[range]).count()
}

Self {
non_empty_rows,
has_leading_empty_line,
last_newline_offset,
has_non_empty_remainder,
}
}

/// Reduction state for counting rows.
///
/// Note: `rhs` should be from the chunk immediately after `slf`, otherwise the results will be
/// incorrect.
pub fn reduce_count_rows(slf: &Self, rhs: &Self) -> Self {
let mut non_empty_rows = slf.non_empty_rows + rhs.non_empty_rows;

if slf.has_non_empty_remainder && rhs.has_leading_empty_line {
non_empty_rows += 1;
}

ChunkStats {
non_empty_rows,
last_newline_offset: rhs.last_newline_offset,
has_leading_empty_line: slf.has_leading_empty_line,
has_non_empty_remainder: rhs.has_non_empty_remainder
|| (rhs.last_newline_offset.is_none() && slf.has_non_empty_remainder),
}
}

/// The non-empty row count of this chunk assuming it is the last chunk (adds 1 if there is a
/// non-empty remainder).
pub fn non_empty_row_count_as_last_chunk(&self) -> usize {
self.non_empty_rows + self.has_non_empty_remainder as usize
}
}
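
// Illustrative sketch, not part of the diff: how two adjacent `ChunkStats` combine
// when a row straddles the chunk boundary. `example_reduce_across_boundary` and the
// chosen split point are hypothetical.
fn example_reduce_across_boundary() -> usize {
    let full: &[u8] = b"{\"a\": 1}\n{\"a\": 2}\n";
    // Split mid-row: the left chunk ends in a non-empty remainder ("{\"a"), which is
    // only completed by bytes the right chunk sees, so neither side double-counts it.
    let (left, right) = full.split_at(12);
    let combined = ChunkStats::reduce_count_rows(
        &ChunkStats::from_chunk(left),
        &ChunkStats::from_chunk(right),
    );
    // Both rows are accounted for exactly once.
    combined.non_empty_row_count_as_last_chunk() // == 2
}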

/// Count the number of rows. The slice passed must represent the entire file. This will
/// potentially parallelize using rayon.
///
/// This does not check if the lines are valid NDJSON - it assumes that is the case.
pub fn count_rows_par(full_bytes: &[u8]) -> usize {
_count_rows_impl(
full_bytes,
std::env::var("POLARS_FORCE_NDJSON_CHUNK_SIZE")
.ok()
.and_then(|x| x.parse::<usize>().ok()),
)
}

/// Count the number of rows. The slice passed must represent the entire file.
/// This does not check if the lines are valid NDJSON - it assumes that is the case.
pub fn count_rows(full_bytes: &[u8]) -> usize {
json_lines(full_bytes).count()
}
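
// Usage sketch, not part of the diff: both entry points return the same count;
// `count_rows_par` decides internally whether splitting the input is worthwhile.
fn example_count_usage() -> (usize, usize) {
    let bytes: &[u8] = b"{\"a\": 1}\n{\"a\": 2}\n";
    (count_rows(bytes), count_rows_par(bytes)) // == (2, 2)
}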

/// Split out from `count_rows_par` so that tests can force a specific chunk size.
fn _count_rows_impl(full_bytes: &[u8], force_chunk_size: Option<usize>) -> usize {
let min_chunk_size = if cfg!(debug_assertions) { 0 } else { 16 * 1024 };

// Row count does not have a parsing dependency between threads, so we can just split into
// the same number of chunks as threads.
let chunk_size = force_chunk_size.unwrap_or(
full_bytes
.len()
.div_ceil(POOL.current_num_threads())
.max(min_chunk_size),
);

if full_bytes.is_empty() {
return 0;
}

let n_chunks = full_bytes.len().div_ceil(chunk_size);

if n_chunks > 1 {
let identity = ChunkStats::from_chunk(&[]);
let acc_stats = POOL.install(|| {
(0..n_chunks)
.into_par_iter()
.map(|i| {
ChunkStats::from_chunk(
&full_bytes[i * chunk_size
..(1 + i).saturating_mul(chunk_size).min(full_bytes.len())],
)
})
.reduce(
|| identity.clone(),
|l, r| ChunkStats::reduce_count_rows(&l, &r),
)
});

acc_stats.non_empty_row_count_as_last_chunk()
} else {
count_rows(full_bytes)
}
}
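
// Illustrative sketch, not part of the diff: the chunk sizing above with hypothetical
// numbers. On a 16-thread pool, a 1 MiB file splits into 16 chunks of 64 KiB, while a
// 100 KiB file is capped by the 16 KiB release-build floor at 7 chunks rather than 16.
fn example_chunk_sizing() -> (usize, usize) {
    let file_len = 100 * 1024usize;
    let threads = 16usize;
    let min_chunk_size = 16 * 1024usize;
    // ceil(100 KiB / 16) = 6400 bytes, raised to the 16 KiB floor.
    let chunk_size = file_len.div_ceil(threads).max(min_chunk_size);
    // ceil(100 KiB / 16 KiB) = 7 chunks.
    let n_chunks = file_len.div_ceil(chunk_size);
    (chunk_size, n_chunks)
}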

#[cfg(test)]
mod tests {
use super::ChunkStats;

#[test]
fn test_chunk_stats() {
let bytes = r#"
{"a": 1}
{"a": 2}
"#
.as_bytes();

assert_eq!(
ChunkStats::from_chunk(bytes),
ChunkStats {
non_empty_rows: 2,
last_newline_offset: Some(18),
has_leading_empty_line: true,
has_non_empty_remainder: false,
}
);

assert_eq!(
ChunkStats::from_chunk(&bytes[..bytes.len() - 3]),
ChunkStats {
non_empty_rows: 1,
last_newline_offset: Some(9),
has_leading_empty_line: true,
has_non_empty_remainder: true,
}
);

assert_eq!(super::_count_rows_impl(&[], Some(1)), 0);
assert_eq!(super::_count_rows_impl(bytes, Some(1)), 2);
assert_eq!(super::_count_rows_impl(bytes, Some(3)), 2);
assert_eq!(super::_count_rows_impl(bytes, Some(5)), 2);
assert_eq!(super::_count_rows_impl(bytes, Some(7)), 2);
assert_eq!(super::_count_rows_impl(bytes, Some(bytes.len())), 2);

assert_eq!(super::count_rows_par(&[]), 0);

assert_eq!(
ChunkStats::from_chunk(&[]),
ChunkStats {
non_empty_rows: 0,
last_newline_offset: None,
has_leading_empty_line: false,
has_non_empty_remainder: false,
}
);

// Single-chars

assert_eq!(
ChunkStats::from_chunk(b"\n"),
ChunkStats {
non_empty_rows: 0,
last_newline_offset: Some(0),
has_leading_empty_line: true,
has_non_empty_remainder: false,
}
);

assert_eq!(
ChunkStats::from_chunk(b"a"),
ChunkStats {
non_empty_rows: 0,
last_newline_offset: None,
has_leading_empty_line: false,
has_non_empty_remainder: true,
}
);

assert_eq!(
ChunkStats::from_chunk(b" "),
ChunkStats {
non_empty_rows: 0,
last_newline_offset: None,
has_leading_empty_line: false,
has_non_empty_remainder: false,
}
);

// Double-char combinations

assert_eq!(
ChunkStats::from_chunk(b"a\n"),
ChunkStats {
non_empty_rows: 1,
last_newline_offset: Some(1),
has_leading_empty_line: false,
has_non_empty_remainder: false,
}
);

assert_eq!(
ChunkStats::from_chunk(b" \n"),
ChunkStats {
non_empty_rows: 0,
last_newline_offset: Some(1),
has_leading_empty_line: true,
has_non_empty_remainder: false,
}
);

assert_eq!(
ChunkStats::from_chunk(b"a "),
ChunkStats {
non_empty_rows: 0,
last_newline_offset: None,
has_leading_empty_line: false,
has_non_empty_remainder: true,
}
);
}

#[test]
fn test_chunk_stats_whitespace() {
let space_char = ' ';
let tab_char = '\t';
// This is not valid JSON, but we simply need to test that ChunkStats only counts lines
// containing at least 1 non-whitespace character.
let bytes = format!(
"
abc

abc

{tab_char}
{space_char}{space_char}{space_char}

abc{space_char}

"
);
let bytes = bytes.as_bytes();

assert_eq!(
ChunkStats::from_chunk(bytes),
ChunkStats {
non_empty_rows: 3,
last_newline_offset: Some(28),
has_leading_empty_line: true,
has_non_empty_remainder: false,
}
);
}

#[test]
fn test_count_rows() {
let bytes = r#"{"text": "\"hello", "id": 1}
{"text": "\"hello", "id": 1} "#
.as_bytes();

assert_eq!(super::count_rows_par(bytes), 2);
}
}