This is an automated email from the ASF dual-hosted git repository.

jshao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gravitino.git


The following commit(s) were added to refs/heads/main by this push:
     new 8763c22d7e [#5199] feat(client-python): add distribution DTO (#8185)
8763c22d7e is described below

commit 8763c22d7ec1ca1e4df801d3f75f6f7cddba0e4b
Author: George T. C. Lai <[email protected]>
AuthorDate: Thu Aug 21 14:56:59 2025 +0800

    [#5199] feat(client-python): add distribution DTO (#8185)
    
    ### What changes were proposed in this pull request?
    
    This PR is aimed at implementing the following classes corresponding to
    the Java client.
    
    - DistributionDTO
    
    **NOTE** that a refactor of `FieldReferenceDTO` in its Builder method
    `with_column_name()` is included in this PR as well. The refactor
    aims to conform to [the one in the Java
    client](https://github.com/apache/gravitino/blob/main/common/src/main/java/org/apache/gravitino/dto/rel/expressions/FieldReferenceDTO.java#L77).
    
    ### Why are the changes needed?
    
    We need to support table partitioning, bucketing, sort ordering, and
    indexes.
    
    #5199
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit tests
    
    ---------
    
    Signed-off-by: George T. C. Lai <[email protected]>
---
 .../gravitino/dto/rel/distribution_dto.py          | 112 +++++++++++++++++++++
 .../dto/rel/expressions/field_reference_dto.py     |   6 +-
 .../unittests/dto/rel/test_distribution_dto.py     | 102 +++++++++++++++++++
 .../unittests/dto/rel/test_field_reference_dto.py  |   2 +-
 .../tests/unittests/dto/rel/test_function_arg.py   |   2 +-
 5 files changed, 219 insertions(+), 5 deletions(-)

diff --git a/clients/client-python/gravitino/dto/rel/distribution_dto.py 
b/clients/client-python/gravitino/dto/rel/distribution_dto.py
new file mode 100644
index 0000000000..3bbb1339bc
--- /dev/null
+++ b/clients/client-python/gravitino/dto/rel/distribution_dto.py
@@ -0,0 +1,112 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import List
+
+from gravitino.api.expressions.distributions.distribution import Distribution
+from gravitino.api.expressions.distributions.strategy import Strategy
+from gravitino.dto.rel.column_dto import ColumnDTO
+from gravitino.dto.rel.expressions.function_arg import FunctionArg
+from gravitino.utils.precondition import Precondition
+
+
+class DistributionDTO(Distribution):
+    """Data transfer object representing distribution information.
+
+    Attributes:
+        NONE (DistributionDTO):
+            A DistributionDTO instance that represents no distribution.
+    """
+
+    NONE: "DistributionDTO"
+
+    def __init__(
+        self,
+        strategy: Strategy,
+        number: int,
+        args: List[FunctionArg],
+    ):
+        Precondition.check_argument(
+            number >= -1, "bucketNum must be greater than or equal -1"
+        )
+        Precondition.check_argument(args is not None, "expressions cannot be null")
+        self._strategy = strategy if isinstance(strategy, Strategy) else Strategy.HASH
+        self._number = number
+        self._args = args
+
+    def args(self) -> List[FunctionArg]:
+        """Returns the arguments of the distribution.
+
+        Returns:
+            List[FunctionArg]: The arguments of the distribution.
+        """
+        return self._args
+
+    def strategy(self) -> Strategy:
+        """Returns the strategy of the distribution.
+
+        Returns:
+            Strategy: The strategy of the distribution.
+        """
+        return self._strategy
+
+    def number(self) -> int:
+        """Returns the number of buckets.
+
+        Returns:
+            int: The number of buckets.
+        """
+        return self._number
+
+    def expressions(self) -> List[FunctionArg]:
+        """Returns the expressions of the distribution.
+
+        Returns:
+            List[FunctionArg]: The same list that `args()` returns.
+        """
+        return self._args
+
+    def validate(self, columns: List[ColumnDTO]) -> None:
+        """Validates every argument of the distribution against the columns.
+
+        Args:
+            columns (List[ColumnDTO]): The columns to be validated.
+
+        Raises:
+            IllegalArgumentException: If the distribution is invalid.
+        """
+
+        for expression in self._args:
+            expression.validate(columns)
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, DistributionDTO):
+            return False
+        return self is other or (
+            self._number == other.number()
+            and self._args == other.args()
+            and self._strategy is other.strategy()
+        )
+
+    def __hash__(self) -> int:
+        result = hash(tuple(self._args))
+        result = 31 * result + self._number
+        result = 31 * result + (hash(self._strategy) if self._strategy else 0)
+        return result
+
+
+DistributionDTO.NONE = DistributionDTO(Strategy.NONE, 0, FunctionArg.EMPTY_ARGS)
diff --git 
a/clients/client-python/gravitino/dto/rel/expressions/field_reference_dto.py 
b/clients/client-python/gravitino/dto/rel/expressions/field_reference_dto.py
index 2e60df5f9e..633a216376 100644
--- a/clients/client-python/gravitino/dto/rel/expressions/field_reference_dto.py
+++ b/clients/client-python/gravitino/dto/rel/expressions/field_reference_dto.py
@@ -71,17 +71,17 @@ class FieldReferenceDTO(NamedReference, FunctionArg):
             self._field_name = field_name
             return self
 
-        def with_column_name(self, column_name: List[str]) -> 
FieldReferenceDTO.Builder:
+        def with_column_name(self, column_name: str) -> 
FieldReferenceDTO.Builder:
             """Set the column name for the field reference.
 
             Args:
-                column_name (List[str]): The column name.
+                column_name (str): The column name.
 
             Returns:
                 FieldReferenceDTO.Builder: The builder.
             """
 
-            self._field_name = column_name
+            self._field_name = [column_name]
             return self
 
         def build(self) -> FieldReferenceDTO:
diff --git 
a/clients/client-python/tests/unittests/dto/rel/test_distribution_dto.py 
b/clients/client-python/tests/unittests/dto/rel/test_distribution_dto.py
new file mode 100644
index 0000000000..fe88afa0e1
--- /dev/null
+++ b/clients/client-python/tests/unittests/dto/rel/test_distribution_dto.py
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+from gravitino.api.expressions.distributions.strategy import Strategy
+from gravitino.api.types.types import Types
+from gravitino.dto.rel.column_dto import ColumnDTO
+from gravitino.dto.rel.distribution_dto import DistributionDTO
+from gravitino.dto.rel.expressions.field_reference_dto import FieldReferenceDTO
+from gravitino.exceptions.base import IllegalArgumentException
+
+
+class TestDistributionDTO(unittest.TestCase):
+    def test_distribution_dto_no_distribution(self):
+        no_distribution = DistributionDTO.NONE
+
+        self.assertEqual(no_distribution.number(), 0)
+        self.assertIs(no_distribution.strategy(), Strategy.NONE)
+        self.assertListEqual(no_distribution.args(), [])
+
+    def test_distribution_dto_illegal_init(self):
+        self.assertRaisesRegex(
+            IllegalArgumentException,
+            "bucketNum must be greater than or equal -1",
+            DistributionDTO,
+            Strategy.HASH,
+            -2,  # below the allowed minimum of -1
+            [],
+        )
+        self.assertRaisesRegex(
+            IllegalArgumentException,
+            "expressions cannot be null",
+            DistributionDTO,
+            Strategy.RANGE,
+            1,
+            None,  # args must not be None
+        )
+
+    def test_distribution_dto_equal(self):
+        args = [FieldReferenceDTO.builder().with_column_name("score").build()]
+        dto = DistributionDTO(Strategy.RANGE, 4, [])
+        self.assertEqual(dto, dto)  # same instance
+        self.assertFalse(dto == "invalid_dto")  # not a DistributionDTO
+        self.assertFalse(dto == DistributionDTO(Strategy.HASH, 4, []))  # strategy
+        self.assertFalse(dto == DistributionDTO(Strategy.RANGE, 5, []))  # number
+        self.assertFalse(
+            dto
+            == DistributionDTO(
+                Strategy.RANGE,
+                4,
+                args,  # only the args differ
+            )
+        )
+        self.assertTrue(dto == DistributionDTO(Strategy.RANGE, 4, []))  # all match
+
+    def test_distribution_dto_hash(self):
+        dto1 = DistributionDTO(Strategy.RANGE, 4, [])
+        dto2 = DistributionDTO(Strategy.RANGE, 5, [])
+        dto_dict = {
+            dto1: 1,
+            dto2: 2,
+        }
+        dto_dict[dto1] = 3  # overwrites the existing dto1 entry
+        self.assertEqual(len(dto_dict), 2)
+        self.assertEqual(dto_dict[dto1], 3)
+
+    def test_distribution_dto_init(self):
+        column = (
+            ColumnDTO.builder()
+            .with_name("dummy_col")
+            .with_data_type(Types.StringType.get())
+            .build()
+        )
+        args = [FieldReferenceDTO.builder().with_column_name("score").build()]
+
+        dto = DistributionDTO(None, 0, [])  # non-Strategy falls back to HASH
+        self.assertEqual(dto.number(), 0)
+        self.assertIs(dto.strategy(), Strategy.HASH)
+        dto.validate([])  # no args, so nothing is validated
+
+        dto = DistributionDTO(Strategy.RANGE, 4, args)
+        self.assertEqual(dto.number(), 4)
+        self.assertIs(dto.strategy(), Strategy.RANGE)
+        self.assertListEqual(dto.args(), args)
+        self.assertListEqual(dto.expressions(), args)
+        with self.assertRaises(IllegalArgumentException):
+            dto.validate(columns=[column])  # presumably: "score" is not a column
diff --git 
a/clients/client-python/tests/unittests/dto/rel/test_field_reference_dto.py 
b/clients/client-python/tests/unittests/dto/rel/test_field_reference_dto.py
index 402698f306..230a88ad67 100644
--- a/clients/client-python/tests/unittests/dto/rel/test_field_reference_dto.py
+++ b/clients/client-python/tests/unittests/dto/rel/test_field_reference_dto.py
@@ -64,6 +64,6 @@ class TestFieldReferenceDTO(unittest.TestCase):
         self.assertIsInstance(dto, FieldReferenceDTO)
         self.assertEqual(dto.field_name(), ["field_name"])
 
-        dto = 
FieldReferenceDTO.builder().with_column_name(["field_name"]).build()
+        dto = 
FieldReferenceDTO.builder().with_column_name("field_name").build()
         self.assertIsInstance(dto, FieldReferenceDTO)
         self.assertEqual(dto.field_name(), ["field_name"])
diff --git a/clients/client-python/tests/unittests/dto/rel/test_function_arg.py 
b/clients/client-python/tests/unittests/dto/rel/test_function_arg.py
index 03e4059f8d..83ebb413b5 100644
--- a/clients/client-python/tests/unittests/dto/rel/test_function_arg.py
+++ b/clients/client-python/tests/unittests/dto/rel/test_function_arg.py
@@ -56,7 +56,7 @@ class TestFunctionArg(unittest.TestCase):
         literal_dto.validate(columns=self._columns)
 
         field_ref_dto = (
-            
FieldReferenceDTO.builder().with_column_name(self._column_names).build()
+            
FieldReferenceDTO.builder().with_column_name(self._column_names[0]).build()
         )
         field_ref_dto.validate(columns=self._columns)
 

Reply via email to