Skip to content

Commit

Permalink
address review
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt711 committed Mar 4, 2025
1 parent 9150d41 commit c2d3273
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 19 deletions.
8 changes: 4 additions & 4 deletions crates/polars-python/src/functions/utils.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use polars::prelude::_set_check_length;
use polars_core::config::get_engine_affinity as _get_engine_affinity;
use polars_core::config::get_engine_affinity;
use pyo3::prelude::*;

#[pyfunction]
pub fn check_length(check: bool) {
unsafe { _set_check_length(check) }
}

#[pyfunction]
pub fn get_engine_affinity() -> PyResult<Option<String>> {
Ok(Some(_get_engine_affinity()))
#[pyfunction(name="get_engine_affinity")]
pub fn py_get_engine_affinity() -> PyResult<String> {
Ok(get_engine_affinity())
}
14 changes: 13 additions & 1 deletion py-polars/polars/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,7 +1494,19 @@ def set_engine_affinity(
│ 2 ┆ 5 │
│ 3 ┆ 6 │
└─────┴─────┘
"""
Raises
------
ValueError: if engine is not recognised.
"""
if engine not in {
"cpu",
"gpu",
"streaming",
"none",
}:
msg = 'engine must be one of "cpu", "gpu", "streaming", or None.'
raise ValueError(msg)
if engine is None:
os.environ.pop("POLARS_ENGINE_AFFINITY", None)
else:
Expand Down
26 changes: 13 additions & 13 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,7 @@ def _gpu_engine_callback(
new_streaming: bool,
_eager: bool,
) -> Callable[[Any, int | None], None] | None:
is_gpu = (
(is_config_obj := isinstance(engine, GPUEngine))
or engine == "gpu"
or get_engine_affinity() == "gpu"
)
is_gpu = (is_config_obj := isinstance(engine, GPUEngine)) or engine == "gpu"
if not (is_config_obj or engine in ("cpu", "gpu")):
msg = f"Invalid engine argument {engine=}"
raise ValueError(msg)
Expand Down Expand Up @@ -1678,7 +1674,7 @@ def profile(
truncate_nodes: int = 0,
figsize: tuple[int, int] = (18, 8),
streaming: bool = False,
engine: EngineType = "cpu",
engine: EngineType = None,
_check_order: bool = True,
**_kwargs: Any,
) -> tuple[DataFrame, DataFrame]:
Expand Down Expand Up @@ -1724,7 +1720,9 @@ def profile(
Run parts of the query in a streaming fashion (this is in an alpha state)
engine
Select the engine used to process the query, optional.
If set to `"cpu"` (default), the query is run using the
If set to `None` (default), polars will attempt to run the query
using the engine set from the `POLARS_ENGINE_AFFINITY` environment
variable. If set to `"cpu"`, the query is run using the
polars CPU engine. If set to `"gpu"`, the GPU engine is
used. Fine-grained control over the GPU engine, for
example which device to use on a system with multiple
Expand Down Expand Up @@ -1783,6 +1781,8 @@ def profile(
):
error_msg = f"profile() got an unexpected keyword argument '{k}'"
raise TypeError(error_msg)
if engine is None:
engine = get_engine_affinity()
if no_optimization:
predicate_pushdown = False
projection_pushdown = False
Expand Down Expand Up @@ -1876,7 +1876,7 @@ def collect(
collapse_joins: bool = True,
no_optimization: bool = False,
streaming: bool = False,
engine: None | EngineType = "cpu",
engine: None | EngineType = None,
background: Literal[True],
_eager: bool = False,
_check_order: bool = True,
Expand All @@ -1898,7 +1898,7 @@ def collect(
collapse_joins: bool = True,
no_optimization: bool = False,
streaming: bool = False,
engine: None | EngineType = "cpu",
engine: None | EngineType = None,
background: Literal[False] = False,
_check_order: bool = True,
_eager: bool = False,
Expand All @@ -1919,7 +1919,7 @@ def collect(
collapse_joins: bool = True,
no_optimization: bool = False,
streaming: bool = False,
engine: None | EngineType = "cpu",
engine: None | EngineType = None,
background: bool = False,
_check_order: bool = True,
_eager: bool = False,
Expand Down Expand Up @@ -1967,7 +1967,9 @@ def collect(
mode.
engine
Select the engine used to process the query, optional.
If set to `"cpu"` (default), the query is run using the
If set to `None` (default), polars will attempt to run the query
using the engine set from the `POLARS_ENGINE_AFFINITY` environment
variable. If set to `"cpu"`, the query is run using the
polars CPU engine. If set to `"gpu"`, the GPU engine is
used. Fine-grained control over the GPU engine, for
example which device to use on a system with multiple
Expand Down Expand Up @@ -2082,8 +2084,6 @@ def collect(

if engine is None:
engine = get_engine_affinity()
if engine is None:
engine = "cpu"

new_streaming = (
_kwargs.get("new_streaming", False) or get_engine_affinity() == "streaming"
Expand Down
2 changes: 1 addition & 1 deletion py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ fn polars(py: Python, m: &Bound<PyModule>) -> PyResult<()> {
// Functions: other
m.add_wrapped(wrap_pyfunction!(functions::check_length))
.unwrap();
m.add_wrapped(wrap_pyfunction!(functions::get_engine_affinity))
m.add_wrapped(wrap_pyfunction!(functions::py_get_engine_affinity))
.unwrap();

#[cfg(feature = "sql")]
Expand Down

0 comments on commit c2d3273

Please sign in to comment.