Skip to content

Commit

Permalink
lazy schema IO plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
tmct committed Oct 14, 2024
1 parent 597a10e commit d369704
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 8 deletions.
15 changes: 13 additions & 2 deletions crates/polars-lazy/src/frame/python.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use polars_core::prelude::*;
use pyo3::PyObject;

use crate::prelude::*;

impl LazyFrame {
Expand All @@ -9,7 +8,7 @@ impl LazyFrame {
options: PythonOptions {
// Should be a python function that returns a generator
scan_fn: Some(scan_fn.into()),
schema: Arc::new(schema),
schema: PySchemaSource::SchemaRef(Arc::new(schema)),
python_source: if pyarrow {
PythonScanSource::Pyarrow
} else {
Expand All @@ -20,4 +19,16 @@ impl LazyFrame {
}
.into()
}

pub fn scan_from_python_functions(schema_fn: PyObject, scan_fn: PyObject) -> Self {
DslPlan::PythonScan {
options: PythonOptions {
schema: PySchemaSource::PythonFunction(schema_fn.into()),
scan_fn: Some(scan_fn.into()),
python_source: PythonScanSource::IOPluginDeferredSchema,
..Default::default()
},
}
.into()
}
}
13 changes: 12 additions & 1 deletion crates/polars-plan/src/plans/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,16 @@ pub struct LogicalPlanUdfOptions {
pub fmt_str: &'static str,
}

#[derive(Clone, PartialEq, Eq, Debug, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg(feature = "python")]
pub enum PySchemaSource {
#[default]
SchemaRef(SchemaRef),
PythonFunction(Py),
}


#[derive(Clone, PartialEq, Eq, Debug, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg(feature = "python")]
Expand All @@ -254,7 +264,7 @@ pub struct PythonOptions {
/// The generator should produce Polars DataFrame's.
pub scan_fn: Option<PythonFunction>,
/// Schema of the file.
pub schema: SchemaRef,
pub schema: PySchemaSource,
/// Schema the reader will produce when the file is read.
pub output_schema: Option<SchemaRef>,
// Projected column names.
Expand All @@ -275,6 +285,7 @@ pub enum PythonScanSource {
Cuda,
#[default]
IOPlugin,
IOPluginDeferredSchema,
}

#[derive(Clone, PartialEq, Eq, Debug, Default)]
Expand Down
9 changes: 8 additions & 1 deletion crates/polars-python/src/lazyframe/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use polars_plan::plans::ScanSources;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::{PyDict, PyList};

use super::PyLazyFrame;
use crate::error::PyPolarsErr;
use crate::expr::ToExprs;
Expand Down Expand Up @@ -403,6 +402,14 @@ impl PyLazyFrame {
Ok(LazyFrame::scan_from_python_function(schema, scan_fn, pyarrow).into())
}

#[staticmethod]
fn scan_from_python_function_deferred_schema(
schema_fn: PyObject,
scan_fn: PyObject,
) -> PyResult<Self> {
Ok(LazyFrame::scan_from_python_functions(schema_fn, scan_fn).into())
}

#[staticmethod]
fn scan_from_python_function_pl_schema(
schema: Vec<(PyBackedStr, Wrap<DataType>)>,
Expand Down
5 changes: 2 additions & 3 deletions crates/polars-python/src/lazyframe/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ use std::sync::{Arc, Mutex};
use polars::prelude::PolarsError;
use polars_plan::plans::{to_aexpr, Context, IR};
use polars_plan::prelude::expr_ir::ExprIR;
use polars_plan::prelude::{AExpr, PythonOptions, PythonScanSource};
use polars_plan::prelude::{AExpr, PySchemaSource, PythonOptions, PythonScanSource};
use polars_utils::arena::{Arena, Node};
use pyo3::prelude::*;
use pyo3::types::PyList;

use super::visitor::{expr_nodes, nodes};
use super::PyLazyFrame;
use crate::error::PyPolarsErr;
Expand Down Expand Up @@ -164,7 +163,7 @@ impl NodeTraverser {
let ir = IR::PythonScan {
options: PythonOptions {
scan_fn: Some(function.into()),
schema,
schema: PySchemaSource::SchemaRef(schema),
output_schema: None,
with_columns: None,
python_source: PythonScanSource::Cuda,
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/io/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def register_io_source(
callable: Callable[
[list[str] | None, Expr | None, int | None, int | None], Iterator[DataFrame]
],
schema: SchemaDict,
schema: SchemaDict | Callable,
) -> LazyFrame:
"""
Register your IO plugin and initialize a LazyFrame.
Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/unit/io/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ def scan_my_source() -> pl.LazyFrame:
return register_io_source(my_source, schema=schema)


def scan_my_source_deferred_schema() -> pl.LazyFrame:
def collect_schema() -> pl.Schema:
return pl.Schema({"a": pl.Int64(), "b": pl.Int64()})

return register_io_source(my_source, schema=collect_schema)


def test_my_source() -> None:
assert_frame_equal(
scan_my_source().collect(), pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]})
Expand All @@ -52,3 +59,7 @@ def test_my_source() -> None:
assert_frame_equal(
scan_my_source().select("a").collect(), pl.DataFrame({"a": [1, 2, 3]})
)

assert_frame_equal(
scan_my_source_deferred_schema().collect(), pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]})
)

0 comments on commit d369704

Please sign in to comment.