-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Check Python version when deserializing UDFs #19175
Changes from 6 commits
dbd410c
5110db6
1c03027
4adab97
ef0ceaa
c6895c1
513083b
6dd7865
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,7 @@ pub static mut CALL_COLUMNS_UDF_PYTHON: Option< | |
pub static mut CALL_DF_UDF_PYTHON: Option< | ||
fn(s: DataFrame, lambda: &PyObject) -> PolarsResult<DataFrame>, | ||
> = None; | ||
#[cfg(feature = "serde")] | ||
pub(super) const MAGIC_BYTE_MARK: &[u8] = "PLPYUDF".as_bytes(); | ||
|
||
#[derive(Clone, Debug)] | ||
|
@@ -60,7 +61,7 @@ impl Serialize for PythonFunction { | |
Python::with_gil(|py| { | ||
let pickle = PyModule::import_bound(py, "cloudpickle") | ||
.or_else(|_| PyModule::import_bound(py, "pickle")) | ||
.expect("Unable to import 'cloudpickle' or 'pickle'") | ||
.expect("unable to import 'cloudpickle' or 'pickle'") | ||
.getattr("dumps") | ||
.unwrap(); | ||
|
||
|
@@ -86,9 +87,8 @@ impl<'a> Deserialize<'a> for PythonFunction { | |
let bytes = Vec::<u8>::deserialize(deserializer)?; | ||
|
||
Python::with_gil(|py| { | ||
let pickle = PyModule::import_bound(py, "cloudpickle") | ||
.or_else(|_| PyModule::import_bound(py, "pickle")) | ||
.expect("Unable to import 'pickle'") | ||
let pickle = PyModule::import_bound(py, "pickle") | ||
.expect("unable to import 'pickle'") | ||
.getattr("loads") | ||
.unwrap(); | ||
let arg = (PyBytes::new_bound(py, &bytes),); | ||
|
@@ -125,19 +125,36 @@ impl PythonUdfExpression { | |
|
||
#[cfg(feature = "serde")] | ||
pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn ColumnsUdf>> { | ||
// Handle byte mark | ||
debug_assert!(buf.starts_with(MAGIC_BYTE_MARK)); | ||
// skip header | ||
let buf = &buf[MAGIC_BYTE_MARK.len()..]; | ||
|
||
// Handle pickle metadata | ||
let use_cloudpickle = buf[0]; | ||
if use_cloudpickle != 0 { | ||
let ser_py_version = buf[1]; | ||
let cur_py_version = get_python_minor_version(); | ||
polars_ensure!( | ||
ser_py_version == cur_py_version, | ||
InvalidOperation: | ||
"current Python version (3.{}) does not match the Python version used to serialize the UDF (3.{})", | ||
cur_py_version, | ||
ser_py_version | ||
); | ||
} | ||
let buf = &buf[2..]; | ||
|
||
// Load UDF metadata | ||
let mut reader = Cursor::new(buf); | ||
let (output_type, is_elementwise, returns_scalar): (Option<DataType>, bool, bool) = | ||
ciborium::de::from_reader(&mut reader).map_err(map_err)?; | ||
|
||
let remainder = &buf[reader.position() as usize..]; | ||
|
||
// Load UDF | ||
Python::with_gil(|py| { | ||
let pickle = PyModule::import_bound(py, "cloudpickle") | ||
.or_else(|_| PyModule::import_bound(py, "pickle")) | ||
.expect("Unable to import 'pickle'") | ||
let pickle = PyModule::import_bound(py, "pickle") | ||
.expect("unable to import 'pickle'") | ||
.getattr("loads") | ||
.unwrap(); | ||
let arg = (PyBytes::new_bound(py, remainder),); | ||
|
@@ -189,26 +206,45 @@ impl ColumnsUdf for PythonUdfExpression { | |
|
||
#[cfg(feature = "serde")] | ||
fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> { | ||
// Write byte marks | ||
buf.extend_from_slice(MAGIC_BYTE_MARK); | ||
ciborium::ser::into_writer( | ||
&( | ||
self.output_type.clone(), | ||
self.is_elementwise, | ||
self.returns_scalar, | ||
), | ||
&mut *buf, | ||
) | ||
.unwrap(); | ||
|
||
Python::with_gil(|py| { | ||
let pickle = PyModule::import_bound(py, "cloudpickle") | ||
.or_else(|_| PyModule::import_bound(py, "pickle")) | ||
.expect("Unable to import 'pickle'") | ||
// Try pickle to serialize the UDF, otherwise fall back to cloudpickle. | ||
let pickle = PyModule::import_bound(py, "pickle") | ||
.expect("unable to import 'pickle'") | ||
.getattr("dumps") | ||
.unwrap(); | ||
let dumped = pickle | ||
.call1((self.python_function.clone(),)) | ||
.map_err(from_pyerr)?; | ||
let pickle_result = pickle.call1((self.python_function.clone(),)); | ||
let (dumped, use_cloudpickle, py_version) = match pickle_result { | ||
Ok(dumped) => (dumped, false, 0), | ||
Err(_) => { | ||
let cloudpickle = PyModule::import_bound(py, "cloudpickle") | ||
.map_err(from_pyerr)? | ||
.getattr("dumps") | ||
.unwrap(); | ||
let dumped = cloudpickle | ||
.call1((self.python_function.clone(),)) | ||
.map_err(from_pyerr)?; | ||
(dumped, true, get_python_minor_version()) | ||
}, | ||
}; | ||
|
||
// Write pickle metadata | ||
buf.extend_from_slice(&[use_cloudpickle as u8, py_version]); | ||
|
||
// Write UDF metadata | ||
ciborium::ser::into_writer( | ||
&( | ||
self.output_type.clone(), | ||
self.is_elementwise, | ||
self.returns_scalar, | ||
), | ||
&mut *buf, | ||
) | ||
.unwrap(); | ||
|
||
// Write UDF | ||
let dumped = dumped.extract::<PyBackedBytes>().unwrap(); | ||
buf.extend_from_slice(&dumped); | ||
Ok(()) | ||
|
@@ -298,3 +334,17 @@ impl Expr { | |
} | ||
} | ||
} | ||
|
||
/// Get the minor Python version from the `sys` module. | ||
fn get_python_minor_version() -> u8 { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we cache this behind a lazylock? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, done! Unfortunately, now I can no longer patch the |
||
Python::with_gil(|py| { | ||
PyModule::import_bound(py, "sys") | ||
.unwrap() | ||
.getattr("version_info") | ||
.unwrap() | ||
.getattr("minor") | ||
.unwrap() | ||
.extract() | ||
.unwrap() | ||
}) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cloudpickle
simply re-exportspickle.loads
, so trying to import cloudpickle during deserialization is a waste.