This is an automated email from the ASF dual-hosted git repository.
jshao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gravitino.git
The following commit(s) were added to refs/heads/main by this push:
new 8763c22d7e [#5199] feat(client-python): add distribution DTO (#8185)
8763c22d7e is described below
commit 8763c22d7ec1ca1e4df801d3f75f6f7cddba0e4b
Author: George T. C. Lai <[email protected]>
AuthorDate: Thu Aug 21 14:56:59 2025 +0800
[#5199] feat(client-python): add distribution DTO (#8185)
### What changes were proposed in this pull request?
This PR implements the following class, corresponding to its counterpart in
the Java client.
- DistributionDTO
**NOTE** that this PR also refactors the `FieldReferenceDTO` Builder method
`with_column_name()` to take a single column name (`str`) instead of a list,
conforming to [the one in the Java
client](https://github.com/apache/gravitino/blob/main/common/src/main/java/org/apache/gravitino/dto/rel/expressions/FieldReferenceDTO.java#L77).
### Why are the changes needed?
We need to support table partitioning, bucketing, sort ordering, and
indexes.
#5199
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests
---------
Signed-off-by: George T. C. Lai <[email protected]>
---
.../gravitino/dto/rel/distribution_dto.py | 112 +++++++++++++++++++++
.../dto/rel/expressions/field_reference_dto.py | 6 +-
.../unittests/dto/rel/test_distribution_dto.py | 102 +++++++++++++++++++
.../unittests/dto/rel/test_field_reference_dto.py | 2 +-
.../tests/unittests/dto/rel/test_function_arg.py | 2 +-
5 files changed, 219 insertions(+), 5 deletions(-)
diff --git a/clients/client-python/gravitino/dto/rel/distribution_dto.py b/clients/client-python/gravitino/dto/rel/distribution_dto.py
new file mode 100644
index 0000000000..3bbb1339bc
--- /dev/null
+++ b/clients/client-python/gravitino/dto/rel/distribution_dto.py
@@ -0,0 +1,112 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import List
+
+from gravitino.api.expressions.distributions.distribution import Distribution
+from gravitino.api.expressions.distributions.strategy import Strategy
+from gravitino.dto.rel.column_dto import ColumnDTO
+from gravitino.dto.rel.expressions.function_arg import FunctionArg
+from gravitino.utils.precondition import Precondition
+
+
+class DistributionDTO(Distribution):
+    """Data transfer object representing distribution information.
+
+    Attributes:
+        NONE (DistributionDTO):
+            No distribution: ``Strategy.NONE``, 0 buckets and no arguments.
+    """
+
+    NONE: "DistributionDTO"
+
+    def __init__(self, strategy: Strategy, number: int, args: List[FunctionArg]):
+        """Create a distribution; a non-``Strategy`` strategy falls back to HASH.
+
+        Raises:
+            IllegalArgumentException: If ``number`` < -1 or ``args`` is None.
+        """
+        Precondition.check_argument(
+            number >= -1, "bucketNum must be greater than or equal -1"
+        )
+        Precondition.check_argument(args is not None, "expressions cannot be null")
+        self._strategy = strategy if isinstance(strategy, Strategy) else Strategy.HASH
+        self._number = number
+        self._args = args
+
+    def args(self) -> List[FunctionArg]:
+        """Returns the arguments of the distribution.
+
+        Returns:
+            List[FunctionArg]: The arguments of the distribution.
+        """
+        return self._args
+
+    def strategy(self) -> Strategy:
+        """Returns the strategy of the distribution.
+
+        Returns:
+            Strategy: The strategy of the distribution.
+        """
+        return self._strategy
+
+    def number(self) -> int:
+        """Returns the number of buckets.
+
+        Returns:
+            int: The number of buckets.
+        """
+        return self._number
+
+    def expressions(self) -> List[FunctionArg]:
+        """Returns the expressions of the distribution (same list as ``args()``).
+
+        Returns:
+            List[FunctionArg]: The expressions of the distribution.
+        """
+        return self._args
+
+    def validate(self, columns: List[ColumnDTO]) -> None:
+        """Validates every distribution argument against ``columns``.
+
+        Args:
+            columns (List[ColumnDTO]): The columns to be validated.
+
+        Raises:
+            IllegalArgumentException: If the distribution is invalid.
+        """
+        for expression in self._args:
+            expression.validate(columns)
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, DistributionDTO):
+            return False
+        return self is other or (
+            self._number == other.number()
+            and self._args == other.args()
+            and self._strategy is other.strategy()
+        )
+
+    def __hash__(self) -> int:
+        result = hash(tuple(self._args))
+        result = 31 * result + self._number
+        # Parenthesized so a falsy strategy only zeroes its own term, not the hash.
+        result = 31 * result + (hash(self._strategy) if self._strategy else 0)
+        return result
+
+
+DistributionDTO.NONE = DistributionDTO(Strategy.NONE, 0, FunctionArg.EMPTY_ARGS)
diff --git a/clients/client-python/gravitino/dto/rel/expressions/field_reference_dto.py b/clients/client-python/gravitino/dto/rel/expressions/field_reference_dto.py
index 2e60df5f9e..633a216376 100644
--- a/clients/client-python/gravitino/dto/rel/expressions/field_reference_dto.py
+++ b/clients/client-python/gravitino/dto/rel/expressions/field_reference_dto.py
@@ -71,17 +71,17 @@ class FieldReferenceDTO(NamedReference, FunctionArg):
self._field_name = field_name
return self
- def with_column_name(self, column_name: List[str]) -> FieldReferenceDTO.Builder:
+ def with_column_name(self, column_name: str) -> FieldReferenceDTO.Builder:
"""Set the column name for the field reference.
Args:
- column_name (List[str]): The column name.
+ column_name (str): The column name.
Returns:
FieldReferenceDTO.Builder: The builder.
"""
- self._field_name = column_name
+ self._field_name = [column_name]
return self
def build(self) -> FieldReferenceDTO:
diff --git a/clients/client-python/tests/unittests/dto/rel/test_distribution_dto.py b/clients/client-python/tests/unittests/dto/rel/test_distribution_dto.py
new file mode 100644
index 0000000000..fe88afa0e1
--- /dev/null
+++ b/clients/client-python/tests/unittests/dto/rel/test_distribution_dto.py
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+from gravitino.api.expressions.distributions.strategy import Strategy
+from gravitino.api.types.types import Types
+from gravitino.dto.rel.column_dto import ColumnDTO
+from gravitino.dto.rel.distribution_dto import DistributionDTO
+from gravitino.dto.rel.expressions.field_reference_dto import FieldReferenceDTO
+from gravitino.exceptions.base import IllegalArgumentException
+
+
+class TestDistributionDTO(unittest.TestCase):
+    def test_distribution_dto_no_distribution(self):
+        no_distribution = DistributionDTO.NONE
+
+        self.assertEqual(no_distribution.number(), 0)
+        self.assertIs(no_distribution.strategy(), Strategy.NONE)
+        self.assertListEqual(no_distribution.args(), [])
+
+    def test_distribution_dto_illegal_init(self):
+        self.assertRaisesRegex(
+            IllegalArgumentException,
+            "bucketNum must be greater than or equal -1",
+            DistributionDTO,
+            Strategy.HASH,
+            -2,
+            [],
+        )
+        self.assertRaisesRegex(
+            IllegalArgumentException,
+            "expressions cannot be null",
+            DistributionDTO,
+            Strategy.RANGE,
+            1,
+            None,
+        )
+
+    def test_distribution_dto_equal(self):
+        args = [FieldReferenceDTO.builder().with_column_name("score").build()]
+        dto = DistributionDTO(Strategy.RANGE, 4, [])
+        self.assertEqual(dto, dto)
+        self.assertFalse(dto == "invalid_dto")
+        self.assertFalse(dto == DistributionDTO(Strategy.HASH, 4, []))
+        self.assertFalse(dto == DistributionDTO(Strategy.RANGE, 5, []))
+        self.assertFalse(
+            dto
+            == DistributionDTO(
+                Strategy.RANGE,
+                4,
+                args,
+            )
+        )
+        self.assertTrue(dto == DistributionDTO(Strategy.RANGE, 4, []))
+
+    def test_distribution_dto_hash(self):
+        dto1 = DistributionDTO(Strategy.RANGE, 4, [])
+        dto2 = DistributionDTO(Strategy.RANGE, 5, [])
+        dto_dict = {dto1: 1, dto2: 2}
+        # An equal-but-distinct instance must hash alike and hit the same key.
+        dto1_copy = DistributionDTO(Strategy.RANGE, 4, [])
+        self.assertEqual(hash(dto1), hash(dto1_copy))
+        dto_dict[dto1_copy] = 3
+        self.assertEqual(len(dto_dict), 2)
+        self.assertEqual(dto_dict[dto1], 3)
+
+    def test_distribution_dto_init(self):
+        column = (
+            ColumnDTO.builder()
+            .with_name("dummy_col")
+            .with_data_type(Types.StringType.get())
+            .build()
+        )
+        args = [FieldReferenceDTO.builder().with_column_name("score").build()]
+
+        dto = DistributionDTO(None, 0, [])
+        self.assertEqual(dto.number(), 0)
+        self.assertIs(dto.strategy(), Strategy.HASH)
+        dto.validate([])
+
+        dto = DistributionDTO(Strategy.RANGE, 4, args)
+        self.assertEqual(dto.number(), 4)
+        self.assertIs(dto.strategy(), Strategy.RANGE)
+        self.assertListEqual(dto.args(), args)
+        self.assertListEqual(dto.expressions(), args)
+        with self.assertRaises(IllegalArgumentException):
+            dto.validate(columns=[column])
diff --git a/clients/client-python/tests/unittests/dto/rel/test_field_reference_dto.py b/clients/client-python/tests/unittests/dto/rel/test_field_reference_dto.py
index 402698f306..230a88ad67 100644
--- a/clients/client-python/tests/unittests/dto/rel/test_field_reference_dto.py
+++ b/clients/client-python/tests/unittests/dto/rel/test_field_reference_dto.py
@@ -64,6 +64,6 @@ class TestFieldReferenceDTO(unittest.TestCase):
self.assertIsInstance(dto, FieldReferenceDTO)
self.assertEqual(dto.field_name(), ["field_name"])
- dto =
FieldReferenceDTO.builder().with_column_name(["field_name"]).build()
+ dto =
FieldReferenceDTO.builder().with_column_name("field_name").build()
self.assertIsInstance(dto, FieldReferenceDTO)
self.assertEqual(dto.field_name(), ["field_name"])
diff --git a/clients/client-python/tests/unittests/dto/rel/test_function_arg.py b/clients/client-python/tests/unittests/dto/rel/test_function_arg.py
index 03e4059f8d..83ebb413b5 100644
--- a/clients/client-python/tests/unittests/dto/rel/test_function_arg.py
+++ b/clients/client-python/tests/unittests/dto/rel/test_function_arg.py
@@ -56,7 +56,7 @@ class TestFunctionArg(unittest.TestCase):
literal_dto.validate(columns=self._columns)
field_ref_dto = (
-
FieldReferenceDTO.builder().with_column_name(self._column_names).build()
+
FieldReferenceDTO.builder().with_column_name(self._column_names[0]).build()
)
field_ref_dto.validate(columns=self._columns)