This is an automated email from the ASF dual-hosted git repository.
jshao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gravitino.git
The following commit(s) were added to refs/heads/main by this push:
new 4f11b2f292 [#5732] feat(client-python): Support transforms expression (#7851)
4f11b2f292 is described below
commit 4f11b2f292dd6526db15f5668d5e15489b8266e8
Author: George T. C. Lai <[email protected]>
AuthorDate: Thu Aug 7 11:34:52 2025 +0800
[#5732] feat(client-python): Support transforms expression (#7851)
### What changes were proposed in this pull request?
Support transform expressions by implementing Python counterparts of the
following Java classes:
- Transform.java
- Transforms.java
### Why are the changes needed?
To support transform expressions in the Python client.
Fix: #5732
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests
---------
Signed-off-by: George T. C. Lai <[email protected]>
---
.../api/expressions/transforms/__init__.py | 16 +
.../api/expressions/transforms/transform.py | 93 ++++
.../api/expressions/transforms/transforms.py | 487 +++++++++++++++++++++
.../tests/unittests/rel/test_transforms.py | 217 +++++++++
4 files changed, 813 insertions(+)
diff --git a/clients/client-python/gravitino/api/expressions/transforms/__init__.py b/clients/client-python/gravitino/api/expressions/transforms/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/clients/client-python/gravitino/api/expressions/transforms/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/clients/client-python/gravitino/api/expressions/transforms/transform.py b/clients/client-python/gravitino/api/expressions/transforms/transform.py
new file mode 100644
index 0000000000..ad47be5255
--- /dev/null
+++ b/clients/client-python/gravitino/api/expressions/transforms/transform.py
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from abc import ABC, abstractmethod
+from typing import List
+
+from gravitino.api.expressions.expression import Expression
+from gravitino.api.expressions.named_reference import NamedReference
+from gravitino.api.expressions.partitions.partition import Partition
+from gravitino.api.expressions.partitions.partitions import Partitions
+
+
class Transform(Expression, ABC):
    """A transform function in the public logical expression API.

    A transform derives a new value from its arguments. For instance,
    ``date(ts)`` yields a date from the timestamp column ``ts``: the
    transform's name is ``"date"`` and its only argument is a reference to
    the ``ts`` column.
    """

    @abstractmethod
    def name(self) -> str:
        """Return the name of the transform function.

        Returns:
            str: The transform function name.
        """

    @abstractmethod
    def arguments(self) -> List[Expression]:
        """Return the arguments handed to the transform function.

        Returns:
            List[Expression]: The transform function's arguments.
        """

    def children(self) -> List[Expression]:
        # The child expressions of a transform are exactly its arguments.
        return self.arguments()

    def assignments(self) -> List[Partition]:
        """Return the partitions pre-assigned by this partitioning.

        Only ``Transforms.ListTransform`` and ``Transforms.RangeTransform``
        carry assignments; every other transform keeps this default.

        Returns:
            List[Partition]: The pre-assigned partitions, which default to
            ``Partitions.EMPTY_PARTITIONS``.
        """
        return Partitions.EMPTY_PARTITIONS
+
+
class SingleFieldTransform(Transform):
    """Base class for transforms that operate on exactly one field."""

    def __init__(self, ref: NamedReference):
        self.ref = ref

    def field_name(self) -> List[str]:
        """Get the referenced field name as a list of name parts.

        Returns:
            List[str]: The referenced field name as a list of string parts.
        """
        return self.ref.field_name()

    def references(self) -> List[NamedReference]:
        return [self.ref]

    def arguments(self) -> List[Expression]:
        return [self.ref]

    def __eq__(self, other: object) -> bool:
        """Compare two single-field transforms.

        Two transforms are equal only when they are of the same concrete
        class and reference the same field. Checking just the shared base
        class would make e.g. a year transform and a month transform on the
        same column compare equal.
        """
        if self is other:
            return True
        if not isinstance(other, SingleFieldTransform) or type(self) is not type(
            other
        ):
            return False
        return self.ref == other.ref

    def __hash__(self) -> int:
        # Fold the concrete class into the hash so it agrees with __eq__ and
        # different transforms on the same field do not share a hash.
        return hash((type(self), self.ref))
diff --git a/clients/client-python/gravitino/api/expressions/transforms/transforms.py b/clients/client-python/gravitino/api/expressions/transforms/transforms.py
new file mode 100644
index 0000000000..c32935659f
--- /dev/null
+++ b/clients/client-python/gravitino/api/expressions/transforms/transforms.py
@@ -0,0 +1,487 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import ClassVar, List, Optional, Union, overload
+
+from gravitino.api.expressions.expression import Expression
+from gravitino.api.expressions.literals.literal import Literal
+from gravitino.api.expressions.literals.literals import Literals
+from gravitino.api.expressions.named_reference import NamedReference
+from gravitino.api.expressions.partitions.list_partition import ListPartition
+from gravitino.api.expressions.partitions.partition import Partition
+from gravitino.api.expressions.partitions.range_partition import RangePartition
+from gravitino.api.expressions.transforms.transform import (
+ SingleFieldTransform,
+ Transform,
+)
+
+
class Transforms(Transform):
    """Helper methods to create logical transforms to pass into Apache
    Gravitino."""

    #: An empty list of transforms.
    EMPTY_TRANSFORM: ClassVar[List[Transform]] = []
    #: The name of the identity transform, which returns the input value.
    NAME_OF_IDENTITY: ClassVar[str] = "identity"
    #: The name of the year transform, which returns the year of the input
    #: value.
    NAME_OF_YEAR: ClassVar[str] = "year"
    #: The name of the month transform, which returns the month of the input
    #: value.
    NAME_OF_MONTH: ClassVar[str] = "month"
    #: The name of the day transform, which returns the day of the input value.
    NAME_OF_DAY: ClassVar[str] = "day"
    #: The name of the hour transform, which returns the hour of the input
    #: value.
    NAME_OF_HOUR: ClassVar[str] = "hour"
    #: The name of the bucket transform, which returns the bucket of the input
    #: value.
    NAME_OF_BUCKET: ClassVar[str] = "bucket"
    #: The name of the truncate transform, which returns the input value
    #: truncated to a given width.
    NAME_OF_TRUNCATE: ClassVar[str] = "truncate"
    #: The name of the list transform, which includes multiple fields in a
    #: list.
    NAME_OF_LIST: ClassVar[str] = "list"
    #: The name of the range transform, which returns the range of the input
    #: value.
    NAME_OF_RANGE: ClassVar[str] = "range"

    @staticmethod
    def _to_named_reference(
        field_name: Union[str, List[str]],
    ) -> NamedReference:
        """Build a field reference from a single name or a list of name parts.

        A plain string is treated as a one-part field name, so ``"ts"`` and
        ``["ts"]`` produce the same reference.
        """
        return NamedReference.field(
            [field_name] if isinstance(field_name, str) else field_name
        )

    @staticmethod
    @overload
    def identity(field_name: List[str]) -> "Transforms.IdentityTransform": ...

    @staticmethod
    @overload
    def identity(field_name: str) -> "Transforms.IdentityTransform": ...

    @staticmethod
    def identity(
        field_name: Union[str, List[str]],
    ) -> "Transforms.IdentityTransform":
        """Create a transform that returns the input value.

        Args:
            field_name (Union[str, List[str]]): The field to transform, either
                a single field name or a list of name parts.

        Returns:
            Transforms.IdentityTransform: The created transform.
        """
        return Transforms.IdentityTransform(
            Transforms._to_named_reference(field_name)
        )

    @staticmethod
    @overload
    def year(field_name: List[str]) -> "Transforms.YearTransform": ...

    @staticmethod
    @overload
    def year(field_name: str) -> "Transforms.YearTransform": ...

    @staticmethod
    def year(field_name: Union[str, List[str]]) -> "Transforms.YearTransform":
        """Create a transform that returns the year of the input value.

        Args:
            field_name (Union[str, List[str]]): The field to transform, either
                a single field name or a list of name parts.

        Returns:
            Transforms.YearTransform: The created transform.
        """
        return Transforms.YearTransform(Transforms._to_named_reference(field_name))

    @staticmethod
    @overload
    def month(field_name: List[str]) -> "Transforms.MonthTransform": ...

    @staticmethod
    @overload
    def month(field_name: str) -> "Transforms.MonthTransform": ...

    @staticmethod
    def month(field_name: Union[str, List[str]]) -> "Transforms.MonthTransform":
        """Create a transform that returns the month of the input value.

        Args:
            field_name (Union[str, List[str]]): The field to transform, either
                a single field name or a list of name parts.

        Returns:
            Transforms.MonthTransform: The created transform.
        """
        return Transforms.MonthTransform(
            Transforms._to_named_reference(field_name)
        )

    @staticmethod
    @overload
    def day(field_name: List[str]) -> "Transforms.DayTransform": ...

    @staticmethod
    @overload
    def day(field_name: str) -> "Transforms.DayTransform": ...

    @staticmethod
    def day(field_name: Union[str, List[str]]) -> "Transforms.DayTransform":
        """Create a transform that returns the day of the input value.

        Args:
            field_name (Union[str, List[str]]): The field to transform, either
                a single field name or a list of name parts.

        Returns:
            Transforms.DayTransform: The created transform.
        """
        return Transforms.DayTransform(Transforms._to_named_reference(field_name))

    @staticmethod
    @overload
    def hour(field_name: List[str]) -> "Transforms.HourTransform": ...

    @staticmethod
    @overload
    def hour(field_name: str) -> "Transforms.HourTransform": ...

    @staticmethod
    def hour(field_name: Union[str, List[str]]) -> "Transforms.HourTransform":
        """Create a transform that returns the hour of the input value.

        Args:
            field_name (Union[str, List[str]]): The field to transform, either
                a single field name or a list of name parts.

        Returns:
            Transforms.HourTransform: The created transform.
        """
        return Transforms.HourTransform(Transforms._to_named_reference(field_name))

    @staticmethod
    def bucket(
        num_buckets: int, *field_names: Union[str, List[str]]
    ) -> "Transforms.BucketTransform":
        """Create a transform that returns the bucket of the input value.

        Args:
            num_buckets (int): The number of buckets to use.
            *field_names (Union[str, List[str]]): The fields to transform;
                each one is a single field name or a list of name parts.

        Returns:
            Transforms.BucketTransform: The created transform.
        """
        return Transforms.BucketTransform(
            num_buckets=Literals.integer_literal(value=num_buckets),
            fields=[Transforms._to_named_reference(name) for name in field_names],
        )

    @staticmethod
    @overload
    def truncate(
        width: int, field_name: List[str]
    ) -> "Transforms.TruncateTransform": ...

    @staticmethod
    @overload
    def truncate(
        width: int, field_name: str
    ) -> "Transforms.TruncateTransform": ...

    @staticmethod
    def truncate(
        width: int, field_name: Union[str, List[str]]
    ) -> "Transforms.TruncateTransform":
        """Create a transform that truncates the input value to the given
        width.

        Args:
            width (int): The width to truncate to.
            field_name (Union[str, List[str]]): The field to transform, either
                a single field name or a list of name parts.

        Returns:
            Transforms.TruncateTransform: The created transform.
        """
        return Transforms.TruncateTransform(
            width=Literals.integer_literal(value=width),
            field=Transforms._to_named_reference(field_name),
        )

    @staticmethod
    def apply(
        name: str, arguments: List[Expression]
    ) -> "Transforms.ApplyTransform":
        """Create a transform that applies a function to the input value.

        Args:
            name (str): The name of the function to apply.
            arguments (List[Expression]): The arguments to the function.

        Returns:
            Transforms.ApplyTransform: The created transform.
        """
        return Transforms.ApplyTransform(name=name, arguments=arguments)

    @staticmethod
    def list(
        *field_names: Union[str, List[str]],
        assignments: Optional[List[ListPartition]] = None,
    ) -> "Transforms.ListTransform":
        """Create a transform that includes multiple fields in a list.

        Args:
            *field_names (Union[str, List[str]]): The fields to include in the
                list; each one is a single field name or a list of name parts.
            assignments (Optional[List[ListPartition]], optional): The
                preassigned list partitions. Defaults to `None` (no
                assignments).

        Returns:
            Transforms.ListTransform: The created transform.
        """
        return Transforms.ListTransform(
            fields=[Transforms._to_named_reference(name) for name in field_names],
            assignments=[] if assignments is None else assignments,
        )

    @staticmethod
    def range(
        field_name: Union[str, List[str]],
        assignments: Optional[List[RangePartition]] = None,
    ) -> "Transforms.RangeTransform":
        """Create a transform that returns the range of the input value with
        preassigned range partitions.

        Args:
            field_name (Union[str, List[str]]): The field to transform, either
                a single field name or a list of name parts.
            assignments (Optional[List[RangePartition]], optional): The
                preassigned range partitions. Defaults to `None` (no
                assignments).

        Returns:
            Transforms.RangeTransform: The created transform.
        """
        return Transforms.RangeTransform(
            field=Transforms._to_named_reference(field_name),
            assignments=[] if assignments is None else assignments,
        )

    class IdentityTransform(SingleFieldTransform):
        """A transform that returns the input value."""

        def name(self) -> str:
            return Transforms.NAME_OF_IDENTITY

    class YearTransform(SingleFieldTransform):
        """A transform that returns the year of the input value."""

        def name(self) -> str:
            return Transforms.NAME_OF_YEAR

    class MonthTransform(SingleFieldTransform):
        """A transform that returns the month of the input value."""

        def name(self) -> str:
            return Transforms.NAME_OF_MONTH

    class DayTransform(SingleFieldTransform):
        """A transform that returns the day of the input value."""

        def name(self) -> str:
            return Transforms.NAME_OF_DAY

    class HourTransform(SingleFieldTransform):
        """A transform that returns the hour of the input value."""

        def name(self) -> str:
            return Transforms.NAME_OF_HOUR

    class BucketTransform(Transform):
        """A transform that returns the bucket of the input value."""

        def __init__(self, num_buckets: Literal[int], fields: List[NamedReference]):
            # The bucket count is stored as a literal expression because it is
            # also reported as the first argument of the transform.
            self.num_buckets_ = num_buckets
            self.fields = fields

        def name(self) -> str:
            return Transforms.NAME_OF_BUCKET

        def num_buckets(self) -> int:
            """Get the number of buckets.

            Returns:
                int: The value held by the bucket-count literal.
            """
            return self.num_buckets_.value()

        def field_names(self) -> List[List[str]]:
            """Get the names of the bucketed fields.

            Returns:
                List[List[str]]: One list of name parts per field.
            """
            return [field.field_name() for field in self.fields]

        def arguments(self) -> List[Expression]:
            return [self.num_buckets_, *self.fields]

        def __eq__(self, value: object) -> bool:
            if self is value:
                return True
            return (
                isinstance(value, Transforms.BucketTransform)
                and self.num_buckets_ == value.num_buckets_
                and self.fields == value.fields
            )

        def __hash__(self) -> int:
            return hash((self.num_buckets_, *self.fields))

    class TruncateTransform(Transform):
        """A transform that truncates the input value to the given width."""

        def __init__(self, width: Literal[int], field: NamedReference):
            # The width is stored as a literal expression because it is also
            # reported as the first argument of the transform.
            self.width_ = width
            self.field = field

        def name(self) -> str:
            return Transforms.NAME_OF_TRUNCATE

        def width(self) -> int:
            """Get the width to truncate to.

            Returns:
                int: The width to truncate to.
            """
            return self.width_.value()

        def field_name(self) -> List[str]:
            """Get the name of the field to transform.

            Returns:
                List[str]: The field name as a list of name parts.
            """
            return self.field.field_name()

        def arguments(self) -> List[Expression]:
            return [self.width_, self.field]

        def __eq__(self, value: object) -> bool:
            if self is value:
                return True
            return (
                isinstance(value, Transforms.TruncateTransform)
                and self.width_ == value.width_
                and self.field == value.field
            )

        def __hash__(self) -> int:
            return hash((self.width_, self.field))

    class ApplyTransform(Transform):
        """A transform that applies a named function to its arguments."""

        def __init__(self, name: str, arguments: List[Expression]):
            self.name_ = name
            self.arguments_ = arguments

        def name(self) -> str:
            return self.name_

        def arguments(self) -> List[Expression]:
            return self.arguments_

        def __eq__(self, value: object) -> bool:
            if self is value:
                return True
            return (
                isinstance(value, Transforms.ApplyTransform)
                and self.name_ == value.name_
                and self.arguments_ == value.arguments_
            )

        def __hash__(self) -> int:
            return 31 * hash(self.name_) + hash(tuple(self.arguments_))

    class ListTransform(Transform):
        """A transform that includes multiple fields in a list."""

        def __init__(
            self,
            fields: List[NamedReference],
            assignments: Optional[List[ListPartition]] = None,
        ):
            self._fields = fields
            self._assignments = [] if assignments is None else assignments

        def field_names(self) -> List[List[str]]:
            """Get the names of the fields included in the list.

            Returns:
                List[List[str]]: One list of name parts per field.
            """
            return [field.field_name() for field in self._fields]

        def name(self) -> str:
            return Transforms.NAME_OF_LIST

        def arguments(self) -> List[Expression]:
            return self._fields

        def assignments(self) -> List[Partition]:
            return self._assignments

        def __eq__(self, value: object) -> bool:
            # Equality considers only the fields; assignments are ignored.
            if not isinstance(value, Transforms.ListTransform):
                return False
            return self is value or self._fields == value.arguments()

        def __hash__(self) -> int:
            return hash(tuple(self._fields))

    class RangeTransform(Transform):
        """A transform that returns the range of the input value."""

        def __init__(
            self,
            field: NamedReference,
            assignments: Optional[List[RangePartition]] = None,
        ):
            self._field = field
            self._assignments = [] if assignments is None else assignments

        def name(self) -> str:
            return Transforms.NAME_OF_RANGE

        def field_name(self) -> List[str]:
            """Get the name of the field to transform.

            Returns:
                List[str]: The field name as a list of name parts.
            """
            return self._field.field_name()

        def arguments(self) -> List[Expression]:
            return [self._field]

        def assignments(self) -> List[Partition]:
            return self._assignments

        def __eq__(self, value: object) -> bool:
            # Equality considers only the field name; assignments are ignored.
            if not isinstance(value, Transforms.RangeTransform):
                return False
            return self is value or self.field_name() == value.field_name()

        def __hash__(self) -> int:
            # Hash the same key __eq__ compares so equal objects always hash
            # alike.
            return hash(tuple(self.field_name()))
diff --git a/clients/client-python/tests/unittests/rel/test_transforms.py b/clients/client-python/tests/unittests/rel/test_transforms.py
new file mode 100644
index 0000000000..495c88775b
--- /dev/null
+++ b/clients/client-python/tests/unittests/rel/test_transforms.py
@@ -0,0 +1,217 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from datetime import date, time
+from itertools import combinations
+
+from gravitino.api.expressions.literals.literals import Literals
+from gravitino.api.expressions.named_reference import NamedReference
+from gravitino.api.expressions.partitions.partitions import Partitions
+from gravitino.api.expressions.transforms.transforms import Transforms
+
+
class TestTransforms(unittest.TestCase):
    """Unit tests for the factory helpers and transform classes in Transforms."""

    @classmethod
    def setUpClass(cls):
        # Single-field transform classes mapped to the factory that builds them.
        cls._temporal_transforms = {
            Transforms.IdentityTransform: Transforms.identity,
            Transforms.YearTransform: Transforms.year,
            Transforms.MonthTransform: Transforms.month,
            Transforms.DayTransform: Transforms.day,
            Transforms.HourTransform: Transforms.hour,
        }

        # The same classes mapped to the name each one must report.
        cls._transform_names = {
            Transforms.IdentityTransform: Transforms.NAME_OF_IDENTITY,
            Transforms.YearTransform: Transforms.NAME_OF_YEAR,
            Transforms.MonthTransform: Transforms.NAME_OF_MONTH,
            Transforms.DayTransform: Transforms.NAME_OF_DAY,
            Transforms.HourTransform: Transforms.NAME_OF_HOUR,
        }

        # Literals of assorted types, paired up as ApplyTransform arguments.
        cls._sample_literals = {
            Literals.integer_literal(value=1),
            Literals.float_literal(value=1.0),
            Literals.string_literal(value="dummy_string"),
            Literals.boolean_literal(value=True),
            Literals.byte_literal(value="dummy_byte"),
            Literals.date_literal(value=date(year=2025, month=7, day=29)),
            Literals.time_literal(value=time(hour=10, minute=30, second=0)),
            Literals.long_literal(value=1),
        }

    def test_temporal_transforms(self):
        column = "dummy_field"
        expected_ref = NamedReference.field([column])
        for expected_cls, factory in self._temporal_transforms.items():
            from_str = factory(field_name=column)
            from_list = factory(field_name=[column])
            # Equal transforms must hash alike, so the second entry replaces
            # the first under a single key.
            by_transform = {from_str: 1, from_list: 2}

            for built in (from_str, from_list):
                self.assertIsInstance(built, expected_cls)

            self.assertEqual(from_str.name(), self._transform_names[expected_cls])
            self.assertEqual(from_str.arguments(), [expected_ref])
            self.assertEqual(from_str.assignments(), Partitions.EMPTY_PARTITIONS)
            self.assertEqual(from_str.children(), from_str.arguments())
            self.assertEqual(from_str.field_name(), [column])
            self.assertEqual(from_str.references(), [expected_ref])

            self.assertEqual(from_str, from_list)
            self.assertFalse(from_str == column)
            self.assertEqual(len(by_transform), 1)
            self.assertEqual(by_transform[from_str], 2)

    def test_bucket_transform(self):
        names = [["dummy_field"], [f"dummy_field_{i}" for i in range(2)]]
        buckets = 10
        first = Transforms.bucket(buckets, *names)
        second = Transforms.bucket(buckets, *names)
        by_transform = {first: 1, second: 2}

        self.assertIsInstance(first, Transforms.BucketTransform)
        self.assertEqual(first.name(), Transforms.NAME_OF_BUCKET)
        self.assertEqual(first.num_buckets(), buckets)
        self.assertListEqual(first.field_names(), names)
        expected_args = [Literals.integer_literal(buckets), *first.fields]
        self.assertListEqual(first.arguments(), expected_args)

        self.assertEqual(first, first)
        self.assertIsNot(first, second)
        self.assertEqual(first, second)
        self.assertEqual(len(by_transform), 1)
        self.assertEqual(by_transform[first], 2)

    def test_truncate_transform(self):
        column = "dummy_field"
        width = 10
        from_str = Transforms.truncate(width, column)
        from_list = Transforms.truncate(width, [column])
        by_transform = {from_str: 1, from_list: 2}

        for built in (from_str, from_list):
            self.assertIsInstance(built, Transforms.TruncateTransform)
        self.assertEqual(from_str.name(), Transforms.NAME_OF_TRUNCATE)
        self.assertEqual(from_str.width(), width)
        self.assertListEqual(from_str.field_name(), [column])
        expected_args = [Literals.integer_literal(width), from_str.field]
        self.assertListEqual(from_str.arguments(), expected_args)

        self.assertEqual(from_str, from_str)
        self.assertIsNot(from_str, from_list)
        self.assertEqual(from_str, from_list)
        self.assertEqual(len(by_transform), 1)
        self.assertEqual(by_transform[from_str], 2)

    def test_apply_transform(self):
        func_name = "dummy_function"
        arity = 2
        for pair in combinations(self._sample_literals, arity):
            args = list(pair)
            first = Transforms.apply(name=func_name, arguments=args)
            second = Transforms.apply(name=func_name, arguments=args)
            by_transform = {first: 1, second: 2}

            for built in (first, second):
                self.assertIsInstance(built, Transforms.ApplyTransform)
            self.assertEqual(first.name(), func_name)
            self.assertListEqual(first.arguments(), args)

            self.assertEqual(first, first)
            self.assertIsNot(first, second)
            self.assertEqual(first, second)
            self.assertEqual(len(by_transform), 1)
            self.assertEqual(by_transform[first], 2)

    def test_list_transform(self):
        # Two one-part fields versus a single two-part field: different keys.
        two_fields = Transforms.list(["createTime"], ["city"])
        one_nested_field = Transforms.list(["createTime", "city"], assignments=[])
        by_transform = {two_fields: 1, one_nested_field: 2}

        for built in (two_fields, one_nested_field):
            self.assertIsInstance(built, Transforms.ListTransform)
        self.assertEqual(two_fields.name(), Transforms.NAME_OF_LIST)
        expected_args = [
            NamedReference.field(field_name=parts)
            for parts in two_fields.field_names()
        ]
        self.assertListEqual(two_fields.arguments(), expected_args)
        self.assertListEqual(two_fields.field_names(), [["createTime"], ["city"]])
        self.assertListEqual(
            one_nested_field.field_names(), [["createTime", "city"]]
        )
        self.assertListEqual(two_fields.assignments(), [])
        self.assertListEqual(
            two_fields.assignments(), one_nested_field.assignments()
        )

        self.assertNotEqual(two_fields, one_nested_field)
        self.assertFalse(two_fields == "")
        self.assertEqual(len(by_transform), 2)
        self.assertEqual(by_transform[two_fields], 1)
        self.assertEqual(by_transform[one_nested_field], 2)

    def test_range_transform(self):
        single_part = Transforms.range(["createTime"])
        two_parts = Transforms.range(["createTime", "city"], assignments=[])
        by_transform = {single_part: 1, two_parts: 2}

        for built in (single_part, two_parts):
            self.assertIsInstance(built, Transforms.RangeTransform)
        self.assertEqual(single_part.name(), Transforms.NAME_OF_RANGE)
        self.assertListEqual(
            single_part.arguments(),
            [NamedReference.field(field_name=single_part.field_name())],
        )
        self.assertListEqual(single_part.field_name(), ["createTime"])
        self.assertListEqual(two_parts.field_name(), ["createTime", "city"])
        self.assertListEqual(single_part.assignments(), [])
        self.assertListEqual(single_part.assignments(), two_parts.assignments())

        self.assertNotEqual(single_part, two_parts)
        self.assertFalse(single_part == "")
        self.assertEqual(len(by_transform), 2)
        self.assertEqual(by_transform[single_part], 1)
        self.assertEqual(by_transform[two_parts], 2)