This is an automated email from the ASF dual-hosted git repository.

jshao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gravitino.git


The following commit(s) were added to refs/heads/main by this push:
     new 4f11b2f292 [#5732] feat(client-python): Support transforms expression 
(#7851)
4f11b2f292 is described below

commit 4f11b2f292dd6526db15f5668d5e15489b8266e8
Author: George T. C. Lai <[email protected]>
AuthorDate: Thu Aug 7 11:34:52 2025 +0800

    [#5732] feat(client-python): Support transforms expression (#7851)
    
    ### What changes were proposed in this pull request?
    
    Support transforms expression by implementing the following Java
    classes:
    
    - Transform.java
    - Transforms.java
    
    ### Why are the changes needed?
    
    To support transforms expression.
    
    Fix: #5732
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit tests
    
    ---------
    
    Signed-off-by: George T. C. Lai <[email protected]>
---
 .../api/expressions/transforms/__init__.py         |  16 +
 .../api/expressions/transforms/transform.py        |  93 ++++
 .../api/expressions/transforms/transforms.py       | 487 +++++++++++++++++++++
 .../tests/unittests/rel/test_transforms.py         | 217 +++++++++
 4 files changed, 813 insertions(+)

diff --git 
a/clients/client-python/gravitino/api/expressions/transforms/__init__.py 
b/clients/client-python/gravitino/api/expressions/transforms/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/clients/client-python/gravitino/api/expressions/transforms/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git 
a/clients/client-python/gravitino/api/expressions/transforms/transform.py 
b/clients/client-python/gravitino/api/expressions/transforms/transform.py
new file mode 100644
index 0000000000..ad47be5255
--- /dev/null
+++ b/clients/client-python/gravitino/api/expressions/transforms/transform.py
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from abc import ABC, abstractmethod
+from typing import List
+
+from gravitino.api.expressions.expression import Expression
+from gravitino.api.expressions.named_reference import NamedReference
+from gravitino.api.expressions.partitions.partition import Partition
+from gravitino.api.expressions.partitions.partitions import Partitions
+
+
+class Transform(Expression, ABC):
+    """Represents a transform function in the public logical expression API.
+
+    For example, the transform date(ts) is used to derive a date value from a 
timestamp column.
+    The transform name is "date" and its argument is a reference to the "ts" 
column.
+    """
+
+    @abstractmethod
+    def name(self) -> str:
+        """Gets the transform function name.
+
+        Returns:
+            str: The transform function name.
+        """
+        pass
+
+    @abstractmethod
+    def arguments(self) -> List[Expression]:
+        """Gets the arguments passed to the transform function.
+
+        Returns:
+            List[Expression]: The arguments passed to the transform function.
+        """
+        pass
+
+    def assignments(self) -> List[Partition]:
+        """Gets the preassigned partitions in the partitioning.
+
+        Currently, only `Transforms.ListTransform` and 
`Transforms.RangeTransform` need to deal with
+        assignments.
+
+        Returns:
+            List[Partition]: The preassigned partitions in the partitioning.
+        """
+        return Partitions.EMPTY_PARTITIONS
+
+    def children(self) -> List[Expression]:
+        return self.arguments()
+
+
+class SingleFieldTransform(Transform):
+    """Base class for transforms on a single field."""
+
+    def __init__(self, ref: NamedReference):
+        self.ref = ref
+
+    def field_name(self) -> List[str]:
+        """Gets the referenced field name as a list of string parts.
+
+        Returns:
+            List[str]: The referenced field name as a list of string parts.
+        """
+        return self.ref.field_name()
+
+    def references(self) -> List[NamedReference]:
+        return [self.ref]
+
+    def arguments(self) -> List[Expression]:
+        return [self.ref]
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, SingleFieldTransform):
+            return False
+        return self.ref == other.ref
+
+    def __hash__(self) -> int:
+        return hash(self.ref)
diff --git 
a/clients/client-python/gravitino/api/expressions/transforms/transforms.py 
b/clients/client-python/gravitino/api/expressions/transforms/transforms.py
new file mode 100644
index 0000000000..c32935659f
--- /dev/null
+++ b/clients/client-python/gravitino/api/expressions/transforms/transforms.py
@@ -0,0 +1,487 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import ClassVar, List, Optional, Union, overload
+
+from gravitino.api.expressions.expression import Expression
+from gravitino.api.expressions.literals.literal import Literal
+from gravitino.api.expressions.literals.literals import Literals
+from gravitino.api.expressions.named_reference import NamedReference
+from gravitino.api.expressions.partitions.list_partition import ListPartition
+from gravitino.api.expressions.partitions.partition import Partition
+from gravitino.api.expressions.partitions.range_partition import RangePartition
+from gravitino.api.expressions.transforms.transform import (
+    SingleFieldTransform,
+    Transform,
+)
+
+
+class Transforms(Transform):
+    """Helper methods to create logical transforms to pass into Apache 
Gravitino."""
+
+    EMPTY_TRANSFORM: ClassVar[List[Transform]] = []
+    """An empty list of transforms."""
+    NAME_OF_IDENTITY: ClassVar[str] = "identity"
+    """The name of the identity transform."""
+    NAME_OF_YEAR: ClassVar[str] = "year"
+    """The name of the year transform. The year transform returns the year of 
the input value."""
+    NAME_OF_MONTH: ClassVar[str] = "month"
+    """The name of the month transform. The month transform returns the month 
of the input value."""
+    NAME_OF_DAY: ClassVar[str] = "day"
+    """The name of the day transform. The day transform returns the day of the 
input value."""
+    NAME_OF_HOUR: ClassVar[str] = "hour"
+    """The name of the hour transform. The hour transform returns the hour of 
the input value."""
+    NAME_OF_BUCKET: ClassVar[str] = "bucket"
+    """The name of the bucket transform. The bucket transform returns the 
bucket of the input value."""
+    NAME_OF_TRUNCATE: ClassVar[str] = "truncate"
+    """The name of the truncate transform. The truncate transform returns the 
truncated value of the input value."""
+    NAME_OF_LIST: ClassVar[str] = "list"
+    """The name of the list transform. The list transform includes multiple 
fields in a list."""
+    NAME_OF_RANGE: ClassVar[str] = "range"
+    """The name of the range transform. The range transform returns the range 
of the input value."""
+
+    @staticmethod
+    @overload
+    def identity(field_name: List[str]) -> "Transforms.IdentityTransform": ...
+
+    @staticmethod
+    @overload
+    def identity(field_name: str) -> "Transforms.IdentityTransform": ...
+
+    @staticmethod
+    def identity(field_name: Union[str, List[str]]) -> 
"Transforms.IdentityTransform":
+        """Create a transform that returns the input value.
+
+        Args:
+            field_name (Union[str, List[str]]):
+                The field name(s) to transform. Can be a list of field names 
or a single field name.
+        Returns:
+            Transforms.IdentityTransform: The created transform
+        """
+
+        return Transforms.IdentityTransform(
+            NamedReference.field(
+                [field_name] if isinstance(field_name, str) else field_name
+            )
+        )
+
+    @staticmethod
+    @overload
+    def year(field_name: List[str]) -> "Transforms.YearTransform": ...
+
+    @staticmethod
+    @overload
+    def year(field_name: str) -> "Transforms.YearTransform": ...
+
+    @staticmethod
+    def year(field_name: Union[str, List[str]]) -> "Transforms.YearTransform":
+        """Create a transform that returns the year of the input value.
+
+        Args:
+            field_name (Union[str, List[str]]):
+                The field name(s) to transform. Can be a list of field names 
or a single field name.
+        Returns:
+            Transforms.YearTransform: The created transform
+        """
+
+        return Transforms.YearTransform(
+            NamedReference.field(
+                [field_name] if isinstance(field_name, str) else field_name
+            )
+        )
+
+    @staticmethod
+    @overload
+    def month(field_name: List[str]) -> "Transforms.MonthTransform": ...
+
+    @staticmethod
+    @overload
+    def month(field_name: str) -> "Transforms.MonthTransform": ...
+
+    @staticmethod
+    def month(field_name: Union[str, List[str]]) -> 
"Transforms.MonthTransform":
+        """Create a transform that returns the month of the input value.
+
+        Args:
+            field_name (Union[str, List[str]]):
+                The field name(s) to transform. Can be a list of field names 
or a single field name.
+        Returns:
+            Transforms.MonthTransform: The created transform
+        """
+
+        return Transforms.MonthTransform(
+            NamedReference.field(
+                [field_name] if isinstance(field_name, str) else field_name
+            )
+        )
+
+    @staticmethod
+    @overload
+    def day(field_name: List[str]) -> "Transforms.DayTransform": ...
+
+    @staticmethod
+    @overload
+    def day(field_name: str) -> "Transforms.DayTransform": ...
+
+    @staticmethod
+    def day(field_name: Union[str, List[str]]) -> "Transforms.DayTransform":
+        """Create a transform that returns the day of the input value.
+
+        Args:
+            field_name (Union[str, List[str]]):
+                The field name(s) to transform. Can be a list of field names 
or a single field name.
+        Returns:
+            Transforms.DayTransform: The created transform
+        """
+
+        return Transforms.DayTransform(
+            NamedReference.field(
+                [field_name] if isinstance(field_name, str) else field_name
+            )
+        )
+
+    @staticmethod
+    @overload
+    def hour(field_name: List[str]) -> "Transforms.HourTransform": ...
+
+    @staticmethod
+    @overload
+    def hour(field_name: str) -> "Transforms.HourTransform": ...
+
+    @staticmethod
+    def hour(field_name: Union[str, List[str]]) -> "Transforms.HourTransform":
+        """Create a transform that returns the hour of the input value.
+
+        Args:
+            field_name (Union[str, List[str]]):
+                The field name(s) to transform. Can be a list of field names 
or a single field name.
+        Returns:
+            Transforms.HourTransform: The created transform
+        """
+
+        return Transforms.HourTransform(
+            NamedReference.field(
+                [field_name] if isinstance(field_name, str) else field_name
+            )
+        )
+
+    @staticmethod
+    def bucket(
+        num_buckets: int, *field_names: List[str]
+    ) -> "Transforms.BucketTransform":
+        """Create a transform that returns the bucket of the input value.
+
+        Args:
+            num_buckets (int): The number of buckets to use
+            *field_names (List[str]): The field names to transform
+
+        Returns:
+            Transforms.BucketTransform: The created transform
+        """
+        return Transforms.BucketTransform(
+            num_buckets=Literals.integer_literal(value=num_buckets),
+            fields=[NamedReference.field(names) for names in field_names],
+        )
+
+    @staticmethod
+    @overload
+    def truncate(
+        width: int, field_name: List[str]
+    ) -> "Transforms.TruncateTransform": ...
+
+    @staticmethod
+    @overload
+    def truncate(width: int, field_name: str) -> 
"Transforms.TruncateTransform": ...
+
+    @staticmethod
+    def truncate(
+        width: int, field_name: Union[str, List[str]]
+    ) -> "Transforms.TruncateTransform":
+        """Create a transform that returns the truncated value of the input 
value with the given width.
+
+        Args:
+            width (int): The width to truncate to
+            field_name (Union[str, List[str]]): The column/field name to 
transform
+
+        Returns:
+            Transforms.TruncateTransform: The created transform
+        """
+        return Transforms.TruncateTransform(
+            width=Literals.integer_literal(value=width),
+            field=NamedReference.field(
+                [field_name] if isinstance(field_name, str) else field_name
+            ),
+        )
+
+    @staticmethod
+    def apply(name: str, arguments: List[Expression]) -> 
"Transforms.ApplyTransform":
+        """Create a transform that applies a function to the input value.
+
+        Args:
+            name (str): The name of the function to apply
+            arguments (List[Expression]): The arguments to the function
+
+        Returns:
+            Transforms.ApplyTransform: The created transform
+        """
+
+        return Transforms.ApplyTransform(name=name, arguments=arguments)
+
+    @staticmethod
+    def list(
+        *field_names: List[str], assignments: Optional[List[ListPartition]] = 
None
+    ) -> "Transforms.ListTransform":
+        """Create a transform that includes multiple fields in a list.
+
+        Args:
+            *field_names (List[str]):
+                The field names to include in the list
+            assignments (Optional[List[ListPartition]]):
+                The preassigned list partitions
+
+        Returns:
+            Transforms.ListTransform: The created transform
+        """
+        return Transforms.ListTransform(
+            fields=[
+                NamedReference.field(field_name=field_name)
+                for field_name in field_names
+            ],
+            assignments=[] if assignments is None else assignments,
+        )
+
+    @staticmethod
+    def range(
+        field_name: List[str], assignments: Optional[List[RangePartition]] = 
None
+    ) -> "Transforms.RangeTransform":
+        """Create a transform that returns the range of the input value with 
preassigned range partitions.
+
+        Args:
+            field_name (List[str]):
+                The field name to transform
+            assignments (Optional[List[RangePartition]], optional):
+                The preassigned range partitions. Defaults to `None`.
+
+        Returns:
+            Transforms.RangeTransform: The created transform
+        """
+        return Transforms.RangeTransform(
+            field=NamedReference.field(field_name=field_name),
+            assignments=[] if assignments is None else assignments,
+        )
+
+    class IdentityTransform(SingleFieldTransform):
+        """A transform that returns the input value."""
+
+        def name(self) -> str:
+            return Transforms.NAME_OF_IDENTITY
+
+    class YearTransform(SingleFieldTransform):
+        """A transform that returns the year of the input value."""
+
+        def name(self) -> str:
+            return Transforms.NAME_OF_YEAR
+
+    class MonthTransform(SingleFieldTransform):
+        """A transform that returns the month of the input value."""
+
+        def name(self) -> str:
+            return Transforms.NAME_OF_MONTH
+
+    class DayTransform(SingleFieldTransform):
+        """A transform that returns the day of the input value."""
+
+        def name(self) -> str:
+            return Transforms.NAME_OF_DAY
+
+    class HourTransform(SingleFieldTransform):
+        """A transform that returns the hour of the input value."""
+
+        def name(self) -> str:
+            return Transforms.NAME_OF_HOUR
+
+    class BucketTransform(Transform):
+        """A transform that returns the bucket of the input value."""
+
+        def __init__(self, num_buckets: Literal[int], fields: 
List[NamedReference]):
+            self.num_buckets_ = num_buckets
+            self.fields = fields
+
+        def name(self) -> str:
+            return Transforms.NAME_OF_BUCKET
+
+        def num_buckets(self) -> int:
+            return self.num_buckets_.value()
+
+        def field_names(self) -> List[List[str]]:
+            return [field.field_name() for field in self.fields]
+
+        def arguments(self) -> List[Expression]:
+            return [self.num_buckets_, *self.fields]
+
+        def __eq__(self, value: object) -> bool:
+            if self is value:
+                return True
+            return (
+                isinstance(value, Transforms.BucketTransform)
+                and self.num_buckets_ == value.num_buckets_
+                and self.fields == value.fields
+            )
+
+        def __hash__(self) -> int:
+            return hash((self.num_buckets_, *self.fields))
+
+    class TruncateTransform(Transform):
+        """A transform that returns the truncated value of the input value 
with the given width."""
+
+        def __init__(self, width: Literal[int], field: NamedReference):
+            self.width_ = width
+            self.field = field
+
+        def name(self) -> str:
+            return Transforms.NAME_OF_TRUNCATE
+
+        def width(self) -> int:
+            """Gets the width to truncate to.
+
+            Returns:
+                int: The width to truncate to.
+            """
+
+            return self.width_.value()
+
+        def field_name(self) -> List[str]:
+            """Gets the field name to transform.
+
+            Returns:
+                List[str]: The field name to transform.
+            """
+
+            return self.field.field_name()
+
+        def arguments(self) -> List[Expression]:
+            return [self.width_, self.field]
+
+        def __eq__(self, value: object) -> bool:
+            if self is value:
+                return True
+            return (
+                isinstance(value, Transforms.TruncateTransform)
+                and self.width_ == value.width_
+                and self.field == value.field
+            )
+
+        def __hash__(self) -> int:
+            return hash((self.width_, self.field))
+
+    class ApplyTransform(Transform):
+        """A transform that applies a function to the input value."""
+
+        def __init__(self, name: str, arguments: List[Expression]):
+            self.name_ = name
+            self.arguments_ = arguments
+
+        def name(self) -> str:
+            return self.name_
+
+        def arguments(self) -> List[Expression]:
+            return self.arguments_
+
+        def __eq__(self, value: object) -> bool:
+            if self is value:
+                return True
+            return (
+                isinstance(value, Transforms.ApplyTransform)
+                and self.name_ == value.name_
+                and self.arguments_ == value.arguments_
+            )
+
+        def __hash__(self) -> int:
+            return 31 * hash(self.name_) + hash(tuple(self.arguments_))
+
+    class ListTransform(Transform):
+        """A transform that includes multiple fields in a list."""
+
+        def __init__(
+            self,
+            fields: List[NamedReference],
+            assignments: Optional[List[ListPartition]] = None,
+        ):
+            self._fields = fields
+            self._assignments = [] if assignments is None else assignments
+
+        def field_names(self) -> List[List[str]]:
+            """Gets the field names to include in the list.
+
+            Returns:
+                List[List[str]]: The field names to include in the list.
+            """
+            return [field.field_name() for field in self._fields]
+
+        def name(self) -> str:
+            return Transforms.NAME_OF_LIST
+
+        def arguments(self) -> List[Expression]:
+            return self._fields
+
+        def assignments(self) -> List[Partition]:
+            return self._assignments
+
+        def __eq__(self, value: object) -> bool:
+            if not isinstance(value, Transforms.ListTransform):
+                return False
+            return self is value or self._fields == value.arguments()
+
+        def __hash__(self) -> int:
+            return hash(tuple(self._fields))
+
+    class RangeTransform(Transform):
+        """A transform that returns the range of the input value."""
+
+        def __init__(
+            self,
+            field: NamedReference,
+            assignments: Optional[List[RangePartition]] = None,
+        ):
+            self._field = field
+            self._assignments = [] if assignments is None else assignments
+
+        def name(self) -> str:
+            return Transforms.NAME_OF_RANGE
+
+        def field_name(self) -> List[str]:
+            """Gets the field name to transform.
+
+            Returns:
+                List[str]: The field name to transform.
+            """
+
+            return self._field.field_name()
+
+        def arguments(self) -> List[Expression]:
+            return [self._field]
+
+        def assignments(self) -> List[Partition]:
+            return self._assignments
+
+        def __eq__(self, value: object) -> bool:
+            if not isinstance(value, Transforms.RangeTransform):
+                return False
+            return self is value or self.field_name() == value.field_name()
+
+        def __hash__(self) -> int:
+            return hash(self._field)
diff --git a/clients/client-python/tests/unittests/rel/test_transforms.py 
b/clients/client-python/tests/unittests/rel/test_transforms.py
new file mode 100644
index 0000000000..495c88775b
--- /dev/null
+++ b/clients/client-python/tests/unittests/rel/test_transforms.py
@@ -0,0 +1,217 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from datetime import date, time
+from itertools import combinations
+
+from gravitino.api.expressions.literals.literals import Literals
+from gravitino.api.expressions.named_reference import NamedReference
+from gravitino.api.expressions.partitions.partitions import Partitions
+from gravitino.api.expressions.transforms.transforms import Transforms
+
+
+class TestTransforms(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls._temporal_transforms = {
+            Transforms.IdentityTransform: Transforms.identity,
+            Transforms.YearTransform: Transforms.year,
+            Transforms.MonthTransform: Transforms.month,
+            Transforms.DayTransform: Transforms.day,
+            Transforms.HourTransform: Transforms.hour,
+        }
+
+        cls._transform_names = {
+            Transforms.IdentityTransform: Transforms.NAME_OF_IDENTITY,
+            Transforms.YearTransform: Transforms.NAME_OF_YEAR,
+            Transforms.MonthTransform: Transforms.NAME_OF_MONTH,
+            Transforms.DayTransform: Transforms.NAME_OF_DAY,
+            Transforms.HourTransform: Transforms.NAME_OF_HOUR,
+        }
+
+        cls._sample_literals = {
+            Literals.integer_literal(value=1),
+            Literals.float_literal(value=1.0),
+            Literals.string_literal(value="dummy_string"),
+            Literals.boolean_literal(value=True),
+            Literals.byte_literal(value="dummy_byte"),
+            Literals.date_literal(value=date(year=2025, month=7, day=29)),
+            Literals.time_literal(value=time(hour=10, minute=30, second=0)),
+            Literals.long_literal(value=1),
+        }
+
+    def test_temporal_transforms(self):
+        field_name = "dummy_field"
+        ref = NamedReference.field([field_name])
+        for trans_cls, trans_func in self._temporal_transforms.items():
+            trans_from_str = trans_func(field_name=field_name)
+            trans_from_list = trans_func(field_name=[field_name])
+            trans_dict = {trans_from_str: 1, trans_from_list: 2}
+
+            self.assertIsInstance(trans_from_str, trans_cls)
+            self.assertIsInstance(trans_from_list, trans_cls)
+
+            self.assertEqual(trans_from_str.name(), 
self._transform_names[trans_cls])
+            self.assertEqual(trans_from_str.arguments(), [ref])
+            self.assertEqual(trans_from_str.assignments(), 
Partitions.EMPTY_PARTITIONS)
+            self.assertEqual(trans_from_str.children(), 
trans_from_str.arguments())
+            self.assertEqual(trans_from_str.field_name(), [field_name])
+            self.assertEqual(trans_from_str.references(), [ref])
+
+            self.assertEqual(trans_from_str, trans_from_list)
+            self.assertFalse(trans_from_str == field_name)
+            self.assertEqual(len(trans_dict), 1)
+            self.assertEqual(trans_dict[trans_from_str], 2)
+
+    def test_bucket_transform(self):
+        field_names = [["dummy_field"], [f"dummy_field_{i}" for i in range(2)]]
+        num_buckets = 10
+        bucket_transform = Transforms.bucket(num_buckets, *field_names)
+        twin_bucket_transform = Transforms.bucket(num_buckets, *field_names)
+        bucket_trans_dict = {
+            bucket_transform: 1,
+            twin_bucket_transform: 2,
+        }
+
+        self.assertIsInstance(bucket_transform, Transforms.BucketTransform)
+        self.assertEqual(bucket_transform.name(), Transforms.NAME_OF_BUCKET)
+        self.assertEqual(bucket_transform.num_buckets(), num_buckets)
+        self.assertListEqual(bucket_transform.field_names(), field_names)
+        self.assertListEqual(
+            bucket_transform.arguments(),
+            [Literals.integer_literal(num_buckets), *bucket_transform.fields],
+        )
+        self.assertEqual(bucket_transform, bucket_transform)
+        self.assertIsNot(bucket_transform, twin_bucket_transform)
+        self.assertEqual(bucket_transform, twin_bucket_transform)
+        self.assertEqual(len(bucket_trans_dict), 1)
+        self.assertEqual(bucket_trans_dict[bucket_transform], 2)
+
+    def test_truncate_transform(self):
+        field_name = "dummy_field"
+        width = 10
+        truncate_transform_str = Transforms.truncate(width, field_name)
+        truncate_transform_list = Transforms.truncate(width, [field_name])
+        truncate_trans_dict = {
+            truncate_transform_str: 1,
+            truncate_transform_list: 2,
+        }
+
+        self.assertIsInstance(truncate_transform_str, 
Transforms.TruncateTransform)
+        self.assertIsInstance(truncate_transform_list, 
Transforms.TruncateTransform)
+        self.assertEqual(truncate_transform_str.name(), 
Transforms.NAME_OF_TRUNCATE)
+        self.assertEqual(truncate_transform_str.width(), width)
+        self.assertListEqual(truncate_transform_str.field_name(), [field_name])
+        self.assertListEqual(
+            truncate_transform_str.arguments(),
+            [Literals.integer_literal(width), truncate_transform_str.field],
+        )
+        self.assertEqual(truncate_transform_str, truncate_transform_str)
+        self.assertIsNot(truncate_transform_str, truncate_transform_list)
+        self.assertEqual(truncate_transform_str, truncate_transform_list)
+        self.assertEqual(len(truncate_trans_dict), 1)
+        self.assertEqual(truncate_trans_dict[truncate_transform_str], 2)
+
+    def test_apply_transform(self):
+        name = "dummy_function"
+        num_args = 2
+        for comb in combinations(self._sample_literals, num_args):
+            arguments = list(comb)
+            apply_transform = Transforms.apply(name=name, arguments=arguments)
+            twin_apply_transform = Transforms.apply(name=name, 
arguments=arguments)
+            apply_trans_dict = {
+                apply_transform: 1,
+                twin_apply_transform: 2,
+            }
+
+            self.assertIsInstance(apply_transform, Transforms.ApplyTransform)
+            self.assertIsInstance(twin_apply_transform, 
Transforms.ApplyTransform)
+            self.assertEqual(apply_transform.name(), name)
+            self.assertListEqual(apply_transform.arguments(), arguments)
+
+            self.assertEqual(apply_transform, apply_transform)
+            self.assertIsNot(apply_transform, twin_apply_transform)
+            self.assertEqual(apply_transform, twin_apply_transform)
+            self.assertEqual(len(apply_trans_dict), 1)
+            self.assertEqual(apply_trans_dict[apply_transform], 2)
+
+    def test_list_transform(self):
+        list_transform = Transforms.list(["createTime"], ["city"])
+        list_transform_with_assignments = Transforms.list(
+            ["createTime", "city"], assignments=[]
+        )
+        trans_dict = {
+            list_transform: 1,
+            list_transform_with_assignments: 2,
+        }
+        self.assertIsInstance(list_transform, Transforms.ListTransform)
+        self.assertIsInstance(list_transform_with_assignments, 
Transforms.ListTransform)
+        self.assertEqual(list_transform.name(), Transforms.NAME_OF_LIST)
+        self.assertListEqual(
+            list_transform.arguments(),
+            [
+                NamedReference.field(field_name=field_name)
+                for field_name in list_transform.field_names()
+            ],
+        )
+        self.assertListEqual(list_transform.field_names(), [["createTime"], 
["city"]])
+        self.assertListEqual(
+            list_transform_with_assignments.field_names(), [["createTime", 
"city"]]
+        )
+        self.assertListEqual(list_transform.assignments(), [])
+        self.assertListEqual(
+            list_transform.assignments(), 
list_transform_with_assignments.assignments()
+        )
+        self.assertNotEqual(list_transform, list_transform_with_assignments)
+        self.assertFalse(list_transform == "")
+        self.assertEqual(len(trans_dict), 2)
+        self.assertEqual(trans_dict[list_transform], 1)
+        self.assertEqual(trans_dict[list_transform_with_assignments], 2)
+
+    def test_range_transform(self):
+        range_transform = Transforms.range(["createTime"])
+        range_transform_with_assignments = Transforms.range(
+            ["createTime", "city"], assignments=[]
+        )
+        trans_dict = {
+            range_transform: 1,
+            range_transform_with_assignments: 2,
+        }
+        self.assertIsInstance(range_transform, Transforms.RangeTransform)
+        self.assertIsInstance(
+            range_transform_with_assignments, Transforms.RangeTransform
+        )
+        self.assertEqual(range_transform.name(), Transforms.NAME_OF_RANGE)
+        self.assertListEqual(
+            range_transform.arguments(),
+            [NamedReference.field(field_name=range_transform.field_name())],
+        )
+        self.assertListEqual(range_transform.field_name(), ["createTime"])
+        self.assertListEqual(
+            range_transform_with_assignments.field_name(), ["createTime", 
"city"]
+        )
+        self.assertListEqual(range_transform.assignments(), [])
+        self.assertListEqual(
+            range_transform.assignments(),
+            range_transform_with_assignments.assignments(),
+        )
+        self.assertNotEqual(range_transform, range_transform_with_assignments)
+        self.assertFalse(range_transform == "")
+        self.assertEqual(len(trans_dict), 2)
+        self.assertEqual(trans_dict[range_transform], 1)
+        self.assertEqual(trans_dict[range_transform_with_assignments], 2)

Reply via email to