diff --git a/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py b/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py index 8073b2b12..f689dcf05 100644 --- a/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py +++ b/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py @@ -8,6 +8,9 @@ import pytest from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor +from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import ( + GlobalSubstreamCursor, +) from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import ( PerPartitionCursor, PerPartitionKeySerializer, @@ -715,3 +718,63 @@ def test_per_partition_state_when_set_initial_global_state( }, ] assert cursor.get_stream_state()["states"] == expected_state + + +def test_per_partition_cursor_partition_router_extra_fields( + mocked_cursor_factory, mocked_partition_router +): + first_partition = {"first_partition_key": "first_partition_value"} + mocked_partition_router.stream_slices.return_value = [ + StreamSlice( + partition=first_partition, cursor_slice={}, extra_fields={"extra_field": "extra_value"} + ), + ] + cursor = ( + MockedCursorBuilder() + .with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]) + .build() + ) + + mocked_cursor_factory.create.return_value = cursor + cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router) + + cursor.set_initial_state({"states": [{"partition": first_partition, "cursor": CURSOR_STATE}]}) + slices = list(cursor.stream_slices()) + + assert slices[0].extra_fields == {"extra_field": "extra_value"} + assert slices == [ + StreamSlice( + partition={"first_partition_key": "first_partition_value"}, + cursor_slice={CURSOR_SLICE_FIELD: "first slice cursor value"}, + extra_fields={"extra_field": "extra_value"}, + ) + ] + + +def test_global_cursor_partition_router_extra_fields( + mocked_cursor_factory, mocked_partition_router +): + first_partition = {"first_partition_key": "first_partition_value"} + mocked_partition_router.stream_slices.return_value = [ + StreamSlice( + partition=first_partition, cursor_slice={}, extra_fields={"extra_field": "extra_value"} + ), + ] + cursor = ( + MockedCursorBuilder() + .with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]) + .build() + ) + + global_cursor = GlobalSubstreamCursor(cursor, mocked_partition_router) + + slices = list(global_cursor.stream_slices()) + + assert slices[0].extra_fields == {"extra_field": "extra_value"} + assert slices == [ + StreamSlice( + partition=first_partition, + cursor_slice={CURSOR_SLICE_FIELD: "first slice cursor value"}, + extra_fields={"extra_field": "extra_value"}, + ) + ]