Skip to content

Commit 08573d1

Browse files
SophieTech88xunliu
andauthored
[#5730] feat(client-python): Add sorts expression (#5879)
### What changes were proposed in this pull request? Implement sorts expression in python client, add unit test. ### Why are the changes needed? We need to support the sorts expressions in python client Fix: #5730 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Need to pass all unit tests. --------- Co-authored-by: Xun <liuxun@apache.org> Co-authored-by: Xun <xun@datastrato.com>
1 parent cb9cdd8 commit 08573d1

File tree

6 files changed

+381
-0
lines changed

6 files changed

+381
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from enum import Enum
18+
19+
20+
class NullOrdering(Enum):
21+
"""A null order used in sorting expressions."""
22+
23+
NULLS_FIRST: str = "nulls_first"
24+
"""Nulls appear before non-nulls. For ascending order, this means nulls appear at the beginning."""
25+
26+
NULLS_LAST: str = "nulls_last"
27+
"""Nulls appear after non-nulls. For ascending order, this means nulls appear at the end."""
28+
29+
def __str__(self) -> str:
30+
if self == NullOrdering.NULLS_FIRST:
31+
return "nulls_first"
32+
if self == NullOrdering.NULLS_LAST:
33+
return "nulls_last"
34+
35+
raise ValueError(f"Unexpected null order: {self}")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from enum import Enum
18+
from gravitino.api.expressions.sorts.null_ordering import NullOrdering
19+
20+
21+
class SortDirection(Enum):
22+
"""A sort direction used in sorting expressions.
23+
Each direction has a default null ordering that is implied if no null ordering is specified explicitly.
24+
"""
25+
26+
ASCENDING = ("asc", NullOrdering.NULLS_FIRST)
27+
"""Ascending sort direction. Nulls appear first. For ascending order, this means nulls appear at the beginning."""
28+
29+
DESCENDING = ("desc", NullOrdering.NULLS_LAST)
30+
"""Descending sort direction. Nulls appear last. For ascending order, this means nulls appear at the end."""
31+
32+
def __init__(self, direction: str, default_null_ordering: NullOrdering):
33+
self._direction = direction
34+
self._default_null_ordering = default_null_ordering
35+
36+
def direction(self) -> str:
37+
return self._direction
38+
39+
def default_null_ordering(self) -> NullOrdering:
40+
"""
41+
Returns the default null ordering to use if no null ordering is specified explicitly.
42+
43+
Returns:
44+
NullOrdering: The default null ordering.
45+
"""
46+
return self._default_null_ordering
47+
48+
def __str__(self) -> str:
49+
if self == SortDirection.ASCENDING:
50+
return SortDirection.ASCENDING.direction()
51+
if self == SortDirection.DESCENDING:
52+
return SortDirection.DESCENDING.direction()
53+
54+
raise ValueError(f"Unexpected sort direction: {self}")
55+
56+
@staticmethod
57+
def from_string(direction: str):
58+
"""
59+
Returns the SortDirection from the string representation.
60+
61+
Args:
62+
direction: The string representation of the sort direction.
63+
64+
Returns:
65+
SortDirection: The corresponding SortDirection.
66+
"""
67+
direction = direction.lower()
68+
if direction == SortDirection.ASCENDING.direction():
69+
return SortDirection.ASCENDING
70+
if direction == SortDirection.DESCENDING.direction():
71+
return SortDirection.DESCENDING
72+
73+
raise ValueError(f"Unexpected sort direction: {direction}")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from abc import ABC, abstractmethod
18+
from typing import List
19+
20+
from gravitino.api.expressions.expression import Expression
21+
from gravitino.api.expressions.sorts.null_ordering import NullOrdering
22+
from gravitino.api.expressions.sorts.sort_direction import SortDirection
23+
24+
25+
class SortOrder(Expression, ABC):
26+
"""Represents a sort order in the public expression API."""
27+
28+
@abstractmethod
29+
def expression(self) -> Expression:
30+
"""Returns the sort expression."""
31+
pass
32+
33+
@abstractmethod
34+
def direction(self) -> SortDirection:
35+
"""Returns the sort direction."""
36+
pass
37+
38+
@abstractmethod
39+
def null_ordering(self) -> NullOrdering:
40+
"""Returns the null ordering."""
41+
pass
42+
43+
def children(self) -> List[Expression]:
44+
"""Returns the children expressions of this sort order."""
45+
return [self.expression()]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from typing import List
18+
19+
from gravitino.api.expressions.expression import Expression
20+
from gravitino.api.expressions.sorts.null_ordering import NullOrdering
21+
from gravitino.api.expressions.sorts.sort_direction import SortDirection
22+
from gravitino.api.expressions.sorts.sort_order import SortOrder
23+
24+
25+
class SortImpl(SortOrder):
26+
27+
def __init__(
28+
self,
29+
expression: Expression,
30+
direction: SortDirection,
31+
null_ordering: NullOrdering,
32+
):
33+
"""Initialize the SortImpl object."""
34+
self._expression = expression
35+
self._direction = direction
36+
self._null_ordering = null_ordering
37+
38+
def expression(self) -> Expression:
39+
return self._expression
40+
41+
def direction(self) -> SortDirection:
42+
return self._direction
43+
44+
def null_ordering(self) -> NullOrdering:
45+
return self._null_ordering
46+
47+
def __eq__(self, other: object) -> bool:
48+
"""Check if two SortImpl instances are equal."""
49+
if not isinstance(other, SortImpl):
50+
return False
51+
return (
52+
self.expression() == other.expression()
53+
and self.direction() == other.direction()
54+
and self.null_ordering() == other.null_ordering()
55+
)
56+
57+
def __hash__(self) -> int:
58+
"""Generate a hash for a SortImpl instance."""
59+
return hash((self.expression(), self.direction(), self.null_ordering()))
60+
61+
def __str__(self) -> str:
62+
"""Provide a string representation of the SortImpl object."""
63+
return (
64+
f"SortImpl(expression={self._expression}, "
65+
f"direction={self._direction}, null_ordering={self._null_ordering})"
66+
)
67+
68+
69+
class SortOrders:
70+
"""Helper methods to create SortOrders to pass into Apache Gravitino."""
71+
72+
# NONE is used to indicate that there is no sort order
73+
NONE: List[SortOrder] = []
74+
75+
@staticmethod
76+
def ascending(expression: Expression) -> SortImpl:
77+
"""Creates a sort order with ascending direction and nulls first."""
78+
return SortOrders.of(expression, SortDirection.ASCENDING)
79+
80+
@staticmethod
81+
def descending(expression: Expression) -> SortImpl:
82+
"""Creates a sort order with descending direction and nulls last."""
83+
return SortOrders.of(expression, SortDirection.DESCENDING)
84+
85+
@staticmethod
86+
def of(
87+
expression: Expression,
88+
direction: SortDirection,
89+
null_ordering: NullOrdering = None,
90+
) -> SortImpl:
91+
"""Creates a sort order with the given direction and optionally specified null ordering."""
92+
if null_ordering is None:
93+
null_ordering = direction.default_null_ordering()
94+
return SortImpl(expression, direction, null_ordering)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
import unittest
18+
from unittest.mock import MagicMock
19+
20+
from gravitino.api.expressions.function_expression import FunctionExpression
21+
from gravitino.api.expressions.named_reference import NamedReference
22+
from gravitino.api.expressions.sorts.sort_direction import SortDirection
23+
from gravitino.api.expressions.sorts.null_ordering import NullOrdering
24+
from gravitino.api.expressions.sorts.sort_orders import SortImpl, SortOrders
25+
from gravitino.api.expressions.expression import Expression
26+
27+
28+
class TestSortOrder(unittest.TestCase):
29+
def test_sort_direction_from_string(self):
30+
self.assertEqual(SortDirection.from_string("asc"), SortDirection.ASCENDING)
31+
self.assertEqual(SortDirection.from_string("desc"), SortDirection.DESCENDING)
32+
with self.assertRaises(ValueError):
33+
SortDirection.from_string("invalid")
34+
35+
def test_null_ordering(self):
36+
self.assertEqual(str(NullOrdering.NULLS_FIRST), "nulls_first")
37+
self.assertEqual(str(NullOrdering.NULLS_LAST), "nulls_last")
38+
39+
def test_sort_impl_initialization(self):
40+
mock_expression = MagicMock(spec=Expression)
41+
sort_impl = SortImpl(
42+
expression=mock_expression,
43+
direction=SortDirection.ASCENDING,
44+
null_ordering=NullOrdering.NULLS_FIRST,
45+
)
46+
self.assertEqual(sort_impl.expression(), mock_expression)
47+
self.assertEqual(sort_impl.direction(), SortDirection.ASCENDING)
48+
self.assertEqual(sort_impl.null_ordering(), NullOrdering.NULLS_FIRST)
49+
50+
def test_sort_impl_equality(self):
51+
mock_expression1 = MagicMock(spec=Expression)
52+
mock_expression2 = MagicMock(spec=Expression)
53+
54+
sort_impl1 = SortImpl(
55+
expression=mock_expression1,
56+
direction=SortDirection.ASCENDING,
57+
null_ordering=NullOrdering.NULLS_FIRST,
58+
)
59+
sort_impl2 = SortImpl(
60+
expression=mock_expression1,
61+
direction=SortDirection.ASCENDING,
62+
null_ordering=NullOrdering.NULLS_FIRST,
63+
)
64+
sort_impl3 = SortImpl(
65+
expression=mock_expression2,
66+
direction=SortDirection.ASCENDING,
67+
null_ordering=NullOrdering.NULLS_FIRST,
68+
)
69+
70+
self.assertEqual(sort_impl1, sort_impl2)
71+
self.assertNotEqual(sort_impl1, sort_impl3)
72+
73+
def test_sort_orders(self):
74+
mock_expression = MagicMock(spec=Expression)
75+
ascending_order = SortOrders.ascending(mock_expression)
76+
self.assertEqual(ascending_order.direction(), SortDirection.ASCENDING)
77+
self.assertEqual(ascending_order.null_ordering(), NullOrdering.NULLS_FIRST)
78+
79+
descending_order = SortOrders.descending(mock_expression)
80+
self.assertEqual(descending_order.direction(), SortDirection.DESCENDING)
81+
self.assertEqual(descending_order.null_ordering(), NullOrdering.NULLS_LAST)
82+
83+
def test_sort_impl_string_representation(self):
84+
mock_expression = MagicMock(spec=Expression)
85+
sort_impl = SortImpl(
86+
expression=mock_expression,
87+
direction=SortDirection.ASCENDING,
88+
null_ordering=NullOrdering.NULLS_FIRST,
89+
)
90+
expected_str = (
91+
f"SortImpl(expression={mock_expression}, "
92+
f"direction=asc, null_ordering=nulls_first)"
93+
)
94+
self.assertEqual(str(sort_impl), expected_str)
95+
96+
def test_sort_order(self):
97+
field_reference = NamedReference.field(["field1"])
98+
sort_order = SortOrders.of(
99+
field_reference, SortDirection.ASCENDING, NullOrdering.NULLS_FIRST
100+
)
101+
102+
self.assertEqual(NullOrdering.NULLS_FIRST, sort_order.null_ordering())
103+
self.assertEqual(SortDirection.ASCENDING, sort_order.direction())
104+
self.assertIsInstance(sort_order.expression(), NamedReference)
105+
self.assertEqual(["field1"], sort_order.expression().field_name())
106+
107+
date = FunctionExpression.of("date", NamedReference.field(["b"]))
108+
sort_order = SortOrders.of(
109+
date, SortDirection.DESCENDING, NullOrdering.NULLS_LAST
110+
)
111+
self.assertEqual(NullOrdering.NULLS_LAST, sort_order.null_ordering())
112+
self.assertEqual(SortDirection.DESCENDING, sort_order.direction())
113+
114+
self.assertIsInstance(sort_order.expression(), FunctionExpression)
115+
self.assertEqual("date", sort_order.expression().function_name())
116+
self.assertEqual(
117+
["b"], sort_order.expression().arguments()[0].references()[0].field_name()
118+
)

0 commit comments

Comments
 (0)