Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Check Python version when deserializing UDFs #19175

Merged
merged 8 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 73 additions & 23 deletions crates/polars-plan/src/dsl/python_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub static mut CALL_COLUMNS_UDF_PYTHON: Option<
/// Optional function-pointer slot for applying a Python callable to a `DataFrame`.
///
/// The stored function takes the frame by value together with a reference to the
/// Python object and returns a `PolarsResult<DataFrame>`. The slot starts out as `None`.
///
/// NOTE(review): presumably assigned once by the Python bindings before any UDF
/// runs — confirm against the code that writes this `static mut`.
pub static mut CALL_DF_UDF_PYTHON: Option<
fn(s: DataFrame, lambda: &PyObject) -> PolarsResult<DataFrame>,
> = None;
/// Byte prefix that marks a buffer as a serialized Python UDF.
///
/// It is written at the very start of the buffer when a Python UDF expression is
/// serialized, and skipped again before the rest of the payload is decoded.
#[cfg(feature = "serde")]
pub(super) const MAGIC_BYTE_MARK: &[u8] = b"PLPYUDF";

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -60,7 +61,7 @@ impl Serialize for PythonFunction {
Python::with_gil(|py| {
let pickle = PyModule::import_bound(py, "cloudpickle")
.or_else(|_| PyModule::import_bound(py, "pickle"))
.expect("Unable to import 'cloudpickle' or 'pickle'")
.expect("unable to import 'cloudpickle' or 'pickle'")
.getattr("dumps")
.unwrap();

Expand All @@ -86,9 +87,8 @@ impl<'a> Deserialize<'a> for PythonFunction {
let bytes = Vec::<u8>::deserialize(deserializer)?;

Python::with_gil(|py| {
let pickle = PyModule::import_bound(py, "cloudpickle")
.or_else(|_| PyModule::import_bound(py, "pickle"))
.expect("Unable to import 'pickle'")
Comment on lines -89 to -91
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cloudpickle simply re-exports pickle.loads, so trying to import cloudpickle during deserialization is a waste.

let pickle = PyModule::import_bound(py, "pickle")
.expect("unable to import 'pickle'")
.getattr("loads")
.unwrap();
let arg = (PyBytes::new_bound(py, &bytes),);
Expand Down Expand Up @@ -125,19 +125,36 @@ impl PythonUdfExpression {

#[cfg(feature = "serde")]
pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn ColumnsUdf>> {
// Handle byte mark
debug_assert!(buf.starts_with(MAGIC_BYTE_MARK));
// skip header
let buf = &buf[MAGIC_BYTE_MARK.len()..];

// Handle pickle metadata
let use_cloudpickle = buf[0];
if use_cloudpickle != 0 {
let ser_py_version = buf[1];
let cur_py_version = get_python_minor_version();
polars_ensure!(
ser_py_version == cur_py_version,
InvalidOperation:
"current Python version (3.{}) does not match the Python version used to serialize the UDF (3.{})",
cur_py_version,
ser_py_version
);
}
let buf = &buf[2..];

// Load UDF metadata
let mut reader = Cursor::new(buf);
let (output_type, is_elementwise, returns_scalar): (Option<DataType>, bool, bool) =
ciborium::de::from_reader(&mut reader).map_err(map_err)?;

let remainder = &buf[reader.position() as usize..];

// Load UDF
Python::with_gil(|py| {
let pickle = PyModule::import_bound(py, "cloudpickle")
.or_else(|_| PyModule::import_bound(py, "pickle"))
.expect("Unable to import 'pickle'")
let pickle = PyModule::import_bound(py, "pickle")
.expect("unable to import 'pickle'")
.getattr("loads")
.unwrap();
let arg = (PyBytes::new_bound(py, remainder),);
Expand Down Expand Up @@ -189,26 +206,45 @@ impl ColumnsUdf for PythonUdfExpression {

#[cfg(feature = "serde")]
fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
// Write byte marks
buf.extend_from_slice(MAGIC_BYTE_MARK);
ciborium::ser::into_writer(
&(
self.output_type.clone(),
self.is_elementwise,
self.returns_scalar,
),
&mut *buf,
)
.unwrap();

Python::with_gil(|py| {
let pickle = PyModule::import_bound(py, "cloudpickle")
.or_else(|_| PyModule::import_bound(py, "pickle"))
.expect("Unable to import 'pickle'")
// Try pickle to serialize the UDF, otherwise fall back to cloudpickle.
let pickle = PyModule::import_bound(py, "pickle")
.expect("unable to import 'pickle'")
.getattr("dumps")
.unwrap();
let dumped = pickle
.call1((self.python_function.clone(),))
.map_err(from_pyerr)?;
let pickle_result = pickle.call1((self.python_function.clone(),));
let (dumped, use_cloudpickle, py_version) = match pickle_result {
Ok(dumped) => (dumped, false, 0),
Err(_) => {
let cloudpickle = PyModule::import_bound(py, "cloudpickle")
.map_err(from_pyerr)?
.getattr("dumps")
.unwrap();
let dumped = cloudpickle
.call1((self.python_function.clone(),))
.map_err(from_pyerr)?;
(dumped, true, get_python_minor_version())
},
};

// Write pickle metadata
buf.extend_from_slice(&[use_cloudpickle as u8, py_version]);

// Write UDF metadata
ciborium::ser::into_writer(
&(
self.output_type.clone(),
self.is_elementwise,
self.returns_scalar,
),
&mut *buf,
)
.unwrap();

// Write UDF
let dumped = dumped.extract::<PyBackedBytes>().unwrap();
buf.extend_from_slice(&dumped);
Ok(())
Expand Down Expand Up @@ -298,3 +334,17 @@ impl Expr {
}
}
}

/// Get the minor Python version from the `sys` module.
fn get_python_minor_version() -> u8 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we cache this behind a `LazyLock`?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, done!

Unfortunately, now that the value is cached, I can no longer patch `sys.version_info` from the Python side, so we can no longer really test this. If you have an idea how to test it, let me know. Otherwise, we'll just have to trust the code 😬

Python::with_gil(|py| {
PyModule::import_bound(py, "sys")
.unwrap()
.getattr("version_info")
.unwrap()
.getattr("minor")
.unwrap()
.extract()
.unwrap()
})
}
62 changes: 60 additions & 2 deletions py-polars/tests/unit/lazyframe/test_serde.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

import io
from typing import TYPE_CHECKING
import sys
from typing import TYPE_CHECKING, NamedTuple

import pytest
from hypothesis import example, given

import polars as pl
from polars.exceptions import ComputeError
from polars.exceptions import ComputeError, InvalidOperationError
from polars.testing import assert_frame_equal
from polars.testing.parametric import dataframes

Expand Down Expand Up @@ -116,3 +117,60 @@ def test_lf_serde_scan(tmp_path: Path) -> None:
result = pl.LazyFrame.deserialize(io.BytesIO(ser))
assert_frame_equal(result, lf)
assert_frame_equal(result.collect(), df)


class MockVersionInfo(NamedTuple):
    """
    Stand-in for ``sys.version_info`` that reports the next minor version.

    All five fields of the real ``sys.version_info`` are mirrored, so code that
    reads ``releaselevel`` or ``serial`` while this object is patched in keeps
    working. Only ``minor`` differs from the running interpreter.
    """

    major: int = sys.version_info.major
    # Bumped by one so that version checks see a different interpreter.
    minor: int = sys.version_info.minor + 1
    micro: int = sys.version_info.micro
    releaselevel: str = sys.version_info.releaselevel
    serial: int = sys.version_info.serial


@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning")
def test_lf_serde_lambda_version_specific(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A serialized lambda UDF loads on the same Python version and is rejected on another."""
    increment = pl.col("a").map_elements(lambda x: x + 1, return_dtype=pl.Int64)
    payload = pl.LazyFrame({"a": [1, 2, 3]}).select(increment).serialize()
    expected = pl.LazyFrame({"a": [2, 3, 4]})

    # Unpatched interpreter: the round trip succeeds.
    roundtripped = pl.LazyFrame.deserialize(io.BytesIO(payload))
    assert_frame_equal(roundtripped, expected)

    # Pretend to run on the next minor version: loading must be refused.
    monkeypatch.setattr(sys, "version_info", MockVersionInfo())
    version_mismatch = pytest.raises(
        InvalidOperationError,
        match="does not match the Python version used to serialize the UDF",
    )
    with version_mismatch:
        pl.LazyFrame.deserialize(io.BytesIO(payload)).collect()


def custom_function(x: pl.Series) -> pl.Series:
    """Named, module-level UDF for the serde tests: returns ``x + 1``."""
    incremented = x + 1
    return incremented


@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning")
def test_lf_serde_function_version_specific(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A serialized named-function UDF loads even when the Python minor version differs."""
    batch_udf = pl.col("a").map_batches(custom_function, return_dtype=pl.Int64)
    payload = pl.LazyFrame({"a": [1, 2, 3]}).select(batch_udf).serialize()
    expected = pl.LazyFrame({"a": [2, 3, 4]})

    # Unpatched interpreter: the round trip succeeds.
    assert_frame_equal(pl.LazyFrame.deserialize(io.BytesIO(payload)), expected)

    # Mocked next minor version: the same payload must still load and evaluate.
    monkeypatch.setattr(sys, "version_info", MockVersionInfo())
    assert_frame_equal(pl.LazyFrame.deserialize(io.BytesIO(payload)), expected)
Loading