Skip to content

Commit

Permalink
fix many errors
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite committed Mar 6, 2025
1 parent 9aacc3b commit e3d21ad
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 30 deletions.
7 changes: 4 additions & 3 deletions crates/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -744,9 +744,10 @@ impl LazyFrame {
}

match engine {
Engine::Auto | Engine::InMemory | Engine::Gpu => (),
#[cfg(feature = "new_streaming")]
Engine::Streaming => self = self.with_new_streaming(true),
Engine::OldStreaming => self = self.with_streaming(true),
_ => {},
}

match engine {
Expand Down Expand Up @@ -1208,7 +1209,7 @@ impl LazyFrame {
}

match engine {
Engine::Auto | Engine::Streaming => {
Engine::Auto | Engine::Streaming => feature_gated!("new_streaming", {
let mut new_stream_lazy = self.clone();
new_stream_lazy.opt_state |= OptFlags::NEW_STREAMING;
new_stream_lazy.opt_state &= !OptFlags::STREAMING;
Expand All @@ -1227,7 +1228,7 @@ impl LazyFrame {
.map(|_| ());
drop(string_cache_hold);
result
},
}),
_ if matches!(payload, SinkType::Partition { .. }) => Err(polars_err!(
InvalidOperation: "partition sinks are not supported on for the '{}' engine",
engine.into_static_str()
Expand Down
6 changes: 3 additions & 3 deletions docs/source/src/python/user-guide/concepts/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
.group_by("species")
.agg(pl.col("sepal_width").mean())
)
df = q1.collect(engine='streaming')
df = q1.collect(engine="streaming")
# --8<-- [end:streaming]

# --8<-- [start:example]
print(q1.explain(engine='streaming'))
print(q1.explain(engine="streaming"))

# --8<-- [end:example]

Expand All @@ -23,5 +23,5 @@
pl.col("sepal_length").mean().over("species")
)

print(q2.explain(engine='streaming'))
print(q2.explain(engine="streaming"))
# --8<-- [end:example2]
2 changes: 1 addition & 1 deletion py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3248,7 +3248,7 @@ def fetch(
comm_subexpr_elim=comm_subexpr_elim,
cluster_with_columns=cluster_with_columns,
collapse_joins=collapse_joins,
engine="old-streaming" if streaming else "in-memory",
streaming=streaming,
)

def _fetch(
Expand Down
36 changes: 17 additions & 19 deletions py-polars/tests/unit/io/test_multiscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,8 @@ def test_multiscan_projection(
base_projection[::-1],
]:
assert_frame_equal(
scan(multiscan_path, **args)
.collect(engine="streaming") # type: ignore[call-overload]
.select(projection),
scan(multiscan_path, **args).select(projection).collect(engine="streaming"), # type: ignore[call-overload]
scan(multiscan_path, **args).collect(engine="streaming").select(projection),
scan(multiscan_path, **args).select(projection).collect(engine="streaming"),
)

for remove in range(len(base_projection)):
Expand All @@ -149,11 +147,11 @@ def test_multiscan_projection(
print(projection)
assert_frame_equal(
scan(multiscan_path, **args)
.collect(engine="streaming") # type: ignore[call-overload]
.collect(engine="streaming")
.select(projection),
scan(multiscan_path, **args)
.select(projection)
.collect(engine="streaming"), # type: ignore[call-overload]
.collect(engine="streaming"),
)


Expand Down Expand Up @@ -188,7 +186,7 @@ def test_multiscan_hive_predicate(
write(b, b_path)
write(c, c_path)

full = scan(multiscan_path).collect(engine="streaming") # type: ignore[call-overload]
full = scan(multiscan_path).collect(engine="streaming")
full_ri = full.with_row_index("ri", 42)

last_pred = None
Expand All @@ -209,15 +207,15 @@ def test_multiscan_hive_predicate(
last_pred = pred
assert_frame_equal(
full.filter(pred),
scan(multiscan_path).filter(pred).collect(engine="streaming"), # type: ignore[call-overload]
scan(multiscan_path).filter(pred).collect(engine="streaming"),
)

assert_frame_equal(
full_ri.filter(pred),
scan(multiscan_path)
.with_row_index("ri", 42)
.filter(pred)
.collect(engine="streaming"), # type: ignore[call-overload]
.collect(engine="streaming"),
)
except Exception as _:
print(last_pred)
Expand Down Expand Up @@ -337,7 +335,7 @@ def test_schema_mismatch_type_mismatch(
pl.exceptions.SchemaError,
match="data type mismatch for column xyz_col: expected: i64, found: str",
):
q.collect(engine="streaming") # type: ignore[call-overload]
q.collect(engine="streaming")


@pytest.mark.parametrize(
Expand Down Expand Up @@ -377,7 +375,7 @@ def test_schema_mismatch_order_mismatch(
q = scan(multiscan_path)

with pytest.raises(pl.exceptions.SchemaError):
q.collect(engine="streaming") # type: ignore[call-overload]
q.collect(engine="streaming")


@pytest.mark.parametrize(
Expand All @@ -403,7 +401,7 @@ def test_multiscan_head(
f.seek(0)

assert_frame_equal(
scan([a, b]).head(5).collect(engine="streaming"), # type: ignore[call-overload]
scan([a, b]).head(5).collect(engine="streaming"),
pl.Series("c1", range(5)).to_frame(),
)

Expand Down Expand Up @@ -431,7 +429,7 @@ def test_multiscan_tail(
f.seek(0)

assert_frame_equal(
scan([a, b]).tail(5).collect(engine="streaming"), # type: ignore[call-overload]
scan([a, b]).tail(5).collect(engine="streaming"),
pl.Series("c1", range(5, 10)).to_frame(),
)

Expand Down Expand Up @@ -469,22 +467,22 @@ def test_multiscan_slice_middle(
] + expected_series

assert_frame_equal(
scan(fs).slice(offset, 17).collect(engine="streaming"), # type: ignore[call-overload]
scan(fs).slice(offset, 17).collect(engine="streaming"),
pl.DataFrame(expected_series),
)
assert_frame_equal(
scan(fs, row_index_name="ri").slice(offset, 17).collect(engine="streaming"), # type: ignore[call-overload]
scan(fs, row_index_name="ri").slice(offset, 17).collect(engine="streaming"),
pl.DataFrame(ri_expected_series),
)

# Negative slices
offset = -(13 * 7 - offset)
assert_frame_equal(
scan(fs).slice(offset, 17).collect(engine="streaming"), # type: ignore[call-overload]
scan(fs).slice(offset, 17).collect(engine="streaming"),
pl.DataFrame(expected_series),
)
assert_frame_equal(
scan(fs, row_index_name="ri").slice(offset, 17).collect(engine="streaming"), # type: ignore[call-overload]
scan(fs, row_index_name="ri").slice(offset, 17).collect(engine="streaming"),
pl.DataFrame(ri_expected_series),
)

Expand Down Expand Up @@ -517,7 +515,7 @@ def test_multiscan_slice_parametric(

assert_frame_equal(
scan(ref).slice(offset, length).collect(),
scan(fs).slice(offset, length).collect(engine="streaming"), # type: ignore[call-overload]
scan(fs).slice(offset, length).collect(engine="streaming"),
)

ref.seek(0)
Expand All @@ -530,7 +528,7 @@ def test_multiscan_slice_parametric(
.collect(),
scan(fs, row_index_name="ri", row_index_offset=42)
.slice(offset, length)
.collect(engine="streaming"), # type: ignore[call-overload]
.collect(engine="streaming"),
)


Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/io/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,7 @@ def test_scan_csv_bytesio_memory_usage(
assert (
pl.scan_csv(f)
.filter(pl.col("mydata") == 999_999)
.collect(engine="streaming" if streaming else "in-memory") # type: ignore[call-overload]
.collect(engine="streaming" if streaming else "in-memory")
.item()
== 999_999
)
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/operations/test_merge_sorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
@pytest.mark.parametrize("streaming", [False, True])
def test_merge_sorted(streaming: bool) -> None:
assert_frame_equal(
lf.collect(engine="streaming" if streaming else "in-memory"), # type: ignore[call-overload]
lf.collect(engine="streaming" if streaming else "in-memory"),
expected,
)

Expand Down Expand Up @@ -114,7 +114,7 @@ def test_merge_sorted_unbalanced(size: int, ra: list[int]) -> None:
)

lf = lhs.lazy().merge_sorted(rhs.lazy(), "a")
df = lf.collect(engine="streaming") # type: ignore[call-overload]
df = lf.collect(engine="streaming")

nulls_last = ra[0] is not None

Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/streaming/test_streaming_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,5 +529,5 @@ def test_streaming_group_by_all_null_21593() -> None:
}
)

out = df.lazy().group_by(pl.all()).min().collect(engine="streaming") # type: ignore[call-overload]
out = df.lazy().group_by(pl.all()).min().collect(engine="streaming")
assert_frame_equal(df, out, check_row_order=False)

0 comments on commit e3d21ad

Please sign in to comment.