This is an automated email from the ASF dual-hosted git repository.
liuxun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gravitino.git
The following commit(s) were added to refs/heads/main by this push:
new 567376448 [#5729] feat(client-python): Add distribution expression in
Python client (#5833)
567376448 is described below
commit 5673764482d653774900da3f6326464b9b28aede
Author: SophieTech88 <[email protected]>
AuthorDate: Thu Dec 12 21:06:49 2024 -0600
[#5729] feat(client-python): Add distribution expression in Python client
(#5833)
### What changes were proposed in this pull request?
Implement distribution expressions in the Python client and add unit tests.
### Why are the changes needed?
We need to support distribution expressions in the Python client.
Fix: #5729
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added unit tests for the distribution expressions; all unit tests need to pass.
---------
Co-authored-by: Xun <[email protected]>
---
.../api/expressions/distributions/distribution.py | 65 +++++++++++
.../api/expressions/distributions/distributions.py | 129 +++++++++++++++++++++
.../api/expressions/distributions/strategy.py | 52 +++++++++
.../gravitino/api/expressions/expression.py | 10 +-
.../client-python/tests/unittests/rel/__init__.py | 16 +++
.../tests/unittests/rel/test_distributions.py | 114 ++++++++++++++++++
.../tests/unittests/{ => rel}/test_expressions.py | 0
.../{ => rel}/test_function_expression.py | 0
.../tests/unittests/{ => rel}/test_literals.py | 0
.../tests/unittests/{ => rel}/test_types.py | 0
10 files changed, 381 insertions(+), 5 deletions(-)
diff --git
a/clients/client-python/gravitino/api/expressions/distributions/distribution.py
b/clients/client-python/gravitino/api/expressions/distributions/distribution.py
new file mode 100644
index 000000000..f0d26e39a
--- /dev/null
+++
b/clients/client-python/gravitino/api/expressions/distributions/distribution.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from abc import abstractmethod
+from typing import List
+
+from gravitino.api.expressions.distributions.strategy import Strategy
+from gravitino.api.expressions.expression import Expression
+
+
+class Distribution(Expression):
+    """An interface that defines how data is distributed across partitions."""
+
+    @abstractmethod
+    def strategy(self) -> Strategy:
+        """Return the distribution strategy."""
+
+    @abstractmethod
+    def number(self) -> int:
+        """Return the number of buckets/distributions. For example, if the strategy
+        is HASH and the number is 10, the data is distributed across 10 buckets."""
+
+    @abstractmethod
+    def expressions(self) -> List[Expression]:
+        """Return the expressions passed to the distribution function."""
+
+    def children(self) -> List[Expression]:
+        """Return the child expressions, which are the distribution expressions."""
+        return self.expressions()
+
+    def equals(self, other: "Distribution") -> bool:
+        """
+        Compare this distribution with another object by value.
+
+        Args:
+            other (Distribution): The object to compare with.
+
+        Returns:
+            bool: True if `other` is a Distribution with the same strategy,
+            number and expressions; False otherwise (None and objects of any
+            other type included).
+        """
+        # Anything that is not a Distribution (None included) is simply unequal;
+        # without this guard the accessor calls below raise AttributeError.
+        if not isinstance(other, Distribution):
+            return False
+
+        return (
+            self.strategy() == other.strategy()
+            and self.number() == other.number()
+            and self.expressions() == other.expressions()
+        )
diff --git
a/clients/client-python/gravitino/api/expressions/distributions/distributions.py
b/clients/client-python/gravitino/api/expressions/distributions/distributions.py
new file mode 100644
index 000000000..a4d4bbd96
--- /dev/null
+++
b/clients/client-python/gravitino/api/expressions/distributions/distributions.py
@@ -0,0 +1,129 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import List
+
+from gravitino.api.expressions.distributions.strategy import Strategy
+from gravitino.api.expressions.distributions.distribution import Distribution
+from gravitino.api.expressions.expression import Expression
+from gravitino.api.expressions.named_reference import NamedReference
+
+
+class DistributionImpl(Distribution):
+    """Plain value implementation of `Distribution`."""
+
+    _strategy: Strategy
+    _number: int
+    _expressions: List[Expression]
+
+    def __init__(self, strategy: Strategy, number: int, expressions: List[Expression]):
+        self._strategy, self._number = strategy, number
+        # Stored by reference; the list is not copied.
+        self._expressions = expressions
+
+    def strategy(self) -> Strategy:
+        return self._strategy
+
+    def number(self) -> int:
+        return self._number
+
+    def expressions(self) -> List[Expression]:
+        return self._expressions
+
+    def __str__(self) -> str:
+        return f"DistributionImpl(strategy={self._strategy}, number={self._number}, expressions={self._expressions})"
+
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, DistributionImpl):
+            head = self._strategy == other.strategy() and self._number == other.number()
+            return head and self._expressions == other.expressions()
+        return False
+
+    def __hash__(self) -> int:
+        members = (self._strategy, self._number, tuple(self._expressions))
+        return hash(members)
+
+
+class Distributions:
+    """Helper methods and shared constants for creating distributions."""
+    # Fresh lists: aliasing Expression.EMPTY_EXPRESSION would share one mutable list.
+    NONE: Distribution = DistributionImpl(Strategy.NONE, 0, [])
+    """NONE is used to indicate that there is no distribution."""
+    HASH: Distribution = DistributionImpl(Strategy.HASH, 0, [])
+    """List bucketing strategy hash, TODO: #1505 Separate the bucket number from the Distribution."""
+    RANGE: Distribution = DistributionImpl(Strategy.RANGE, 0, [])
+    """List bucketing strategy range, TODO: #1505 Separate the bucket number from the Distribution."""
+
+    @staticmethod
+    def even(number: int, *expressions: Expression) -> Distribution:
+        """
+        Create a distribution that spreads the data evenly across the buckets.
+
+        :param number: The number of buckets.
+        :param expressions: The expressions to distribute by.
+        :return: The created even distribution.
+        """
+        return Distributions.of(Strategy.EVEN, number, *expressions)
+
+    @staticmethod
+    def hash(number: int, *expressions: Expression) -> Distribution:
+        """
+        Create a distribution that hashes the data across the buckets.
+
+        :param number: The number of buckets.
+        :param expressions: The expressions to distribute by.
+        :return: The created hash distribution.
+        """
+        return Distributions.of(Strategy.HASH, number, *expressions)
+
+    @staticmethod
+    def of(strategy: Strategy, number: int, *expressions: Expression) -> Distribution:
+        """
+        Create a distribution with the given strategy.
+
+        :param strategy: The strategy to use.
+        :param number: The number of buckets.
+        :param expressions: The expressions to distribute by.
+        :return: The created distribution.
+        """
+        return DistributionImpl(strategy, number, list(expressions))
+
+    @staticmethod
+    def fields(
+        strategy: Strategy, number: int, *field_names: List[str]
+    ) -> Distribution:
+        """
+        Create a distribution on columns, like distribute by (a) or (a, b). For
+        function expressions such as (func(a), b), build them and use `of`.
+
+        NOTE: a, b, c are column names; each is passed as a list of strings.
+
+        SQL syntax: distribute by hash(a, b) buckets 5
+        fields(Strategy.HASH, 5, ["a"], ["b"])
+
+        SQL syntax: distribute by hash(a, b, c) buckets 10
+        fields(Strategy.HASH, 10, ["a"], ["b"], ["c"])
+
+        SQL syntax: distribute by EVEN(a) buckets 128
+        fields(Strategy.EVEN, 128, ["a"])
+
+        :param strategy: The strategy to use.
+        :param number: The number of buckets.
+        :param field_names: The field names to distribute by.
+        :return: The created distribution.
+        """
+        expressions = [NamedReference.field(name) for name in field_names]
+        return Distributions.of(strategy, number, *expressions)
diff --git
a/clients/client-python/gravitino/api/expressions/distributions/strategy.py
b/clients/client-python/gravitino/api/expressions/distributions/strategy.py
new file mode 100644
index 000000000..0ac03a1c2
--- /dev/null
+++ b/clients/client-python/gravitino/api/expressions/distributions/strategy.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from enum import Enum
+
+
+class Strategy(Enum):
+    """
+    An enum that defines the distribution strategy.
+
+    The following strategies are supported:
+
+    - NONE: No distribution strategy; allocation is left to the underlying system.
+    - HASH: Uses the hash value of the expression to distribute data.
+    - RANGE: Uses the specified range of the expression to distribute data.
+    - EVEN: Distributes data evenly across partitions.
+    """
+
+    NONE = "NONE"
+    HASH = "HASH"
+    RANGE = "RANGE"
+    EVEN = "EVEN"
+
+    @staticmethod
+    def get_by_name(name: str) -> "Strategy":
+        """
+        Resolve a strategy from its case-insensitive name.
+
+        :param name: The strategy name; "RANDOM" is accepted as an alias of EVEN.
+        :raises ValueError: If the name matches no strategy.
+        """
+        lookup = {**Strategy.__members__, "RANDOM": Strategy.EVEN}
+        found = lookup.get(name.upper())
+        if found is None:
+            raise ValueError(
+                f"Invalid distribution strategy: {name}. Valid values are: {[s.value for s in Strategy]}"
+            )
+        return found
diff --git a/clients/client-python/gravitino/api/expressions/expression.py
b/clients/client-python/gravitino/api/expressions/expression.py
index 41669042c..185db2ef4 100644
--- a/clients/client-python/gravitino/api/expressions/expression.py
+++ b/clients/client-python/gravitino/api/expressions/expression.py
@@ -17,7 +17,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, List
if TYPE_CHECKING:
from gravitino.api.expressions.named_reference import NamedReference
@@ -26,23 +26,23 @@ if TYPE_CHECKING:
class Expression(ABC):
"""Base class of the public logical expression API."""
- EMPTY_EXPRESSION: list[Expression] = []
+ EMPTY_EXPRESSION: List[Expression] = []
"""
`EMPTY_EXPRESSION` is only used as an input when the default `children`
method builds the result.
"""
- EMPTY_NAMED_REFERENCE: list[NamedReference] = []
+ EMPTY_NAMED_REFERENCE: List[NamedReference] = []
"""
`EMPTY_NAMED_REFERENCE` is only used as an input when the default
`references` method builds
the result array to avoid repeatedly allocating an empty array.
"""
@abstractmethod
- def children(self) -> list[Expression]:
+ def children(self) -> List[Expression]:
"""Returns a list of the children of this node. Children should not
change."""
pass
- def references(self) -> list[NamedReference]:
+ def references(self) -> List[NamedReference]:
"""Returns a list of fields or columns that are referenced by this
expression."""
ref_set: set[NamedReference] = set()
diff --git a/clients/client-python/tests/unittests/rel/__init__.py
b/clients/client-python/tests/unittests/rel/__init__.py
new file mode 100644
index 000000000..13a83393a
--- /dev/null
+++ b/clients/client-python/tests/unittests/rel/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/clients/client-python/tests/unittests/rel/test_distributions.py
b/clients/client-python/tests/unittests/rel/test_distributions.py
new file mode 100644
index 000000000..a9e0637c5
--- /dev/null
+++ b/clients/client-python/tests/unittests/rel/test_distributions.py
@@ -0,0 +1,114 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from typing import List
+
+from gravitino.api.expressions.distributions.distributions import (
+ DistributionImpl,
+ Distributions,
+)
+from gravitino.api.expressions.distributions.strategy import Strategy
+from gravitino.api.expressions.expression import Expression
+
+
+class MockExpression(Expression):
+ """Childless Expression stub used as a stand-in operand by the tests below."""
+
+ def children(self) -> List[Expression]:
+ return Expression.EMPTY_EXPRESSION
+
+
+class TestDistributions(unittest.TestCase):
+
+    def setUp(self):
+        # Two distinct stand-in operands shared by the tests.
+        self.expr1 = MockExpression()
+        self.expr2 = MockExpression()
+
+    def test_none_distribution(self):
+        # The NONE constant: no strategy, no buckets, no expressions.
+        distribution = Distributions.NONE
+        self.assertEqual(distribution.strategy(), Strategy.NONE)
+        self.assertEqual(distribution.number(), 0)
+        self.assertEqual(distribution.expressions(), Expression.EMPTY_EXPRESSION)
+
+    def test_hash_distribution(self):
+        # The HASH constant carries the strategy only.
+        distribution = Distributions.HASH
+        self.assertEqual(distribution.strategy(), Strategy.HASH)
+        self.assertEqual(distribution.number(), 0)
+        self.assertEqual(distribution.expressions(), Expression.EMPTY_EXPRESSION)
+
+    def test_range_distribution(self):
+        # The RANGE constant carries the strategy only.
+        distribution = Distributions.RANGE
+        self.assertEqual(distribution.strategy(), Strategy.RANGE)
+        self.assertEqual(distribution.number(), 0)
+        self.assertEqual(distribution.expressions(), Expression.EMPTY_EXPRESSION)
+
+    def test_even_distribution(self):
+        # even() keeps the bucket number and the expressions in call order.
+        distribution = Distributions.even(5, self.expr1, self.expr2)
+        self.assertEqual(distribution.strategy(), Strategy.EVEN)
+        self.assertEqual(distribution.number(), 5)
+        self.assertEqual(distribution.expressions(), [self.expr1, self.expr2])
+
+    def test_hash_distribution_with_multiple_expressions(self):
+        # hash() keeps the bucket number and the expressions in call order.
+        distribution = Distributions.hash(10, self.expr1, self.expr2)
+        self.assertEqual(distribution.strategy(), Strategy.HASH)
+        self.assertEqual(distribution.number(), 10)
+        self.assertEqual(distribution.expressions(), [self.expr1, self.expr2])
+
+    def test_of_distribution(self):
+        # of() accepts an arbitrary strategy.
+        distribution = Distributions.of(Strategy.RANGE, 20, self.expr1)
+        self.assertEqual(distribution.strategy(), Strategy.RANGE)
+        self.assertEqual(distribution.number(), 20)
+        self.assertEqual(distribution.expressions(), [self.expr1])
+
+    def test_fields_distribution(self):
+        # fields() turns each field-name list into one expression.
+        distribution = Distributions.fields(Strategy.HASH, 5, ["a", "b", "c"])
+        self.assertEqual(distribution.strategy(), Strategy.HASH)
+        self.assertEqual(distribution.number(), 5)
+        # A single field-name list was passed, so exactly one expression must
+        # come back; "more than zero" would not catch a wrong count.
+        self.assertEqual(len(distribution.expressions()), 1)
+
+    def test_distribution_equals(self):
+        # Equality is by value: same strategy, number and expressions.
+        distribution1 = DistributionImpl(Strategy.EVEN, 5, [self.expr1])
+        distribution2 = DistributionImpl(Strategy.EVEN, 5, [self.expr1])
+        distribution3 = DistributionImpl(Strategy.HASH, 10, [self.expr2])
+
+        self.assertEqual(distribution1, distribution2)
+        self.assertNotEqual(distribution1, distribution3)
+        self.assertTrue(distribution1.equals(distribution2))
+        self.assertFalse(distribution1.equals(distribution3))
+
+    def test_distribution_hash(self):
+        # Hashing must agree with equality.
+        distribution1 = DistributionImpl(Strategy.HASH, 5, [self.expr1])
+        distribution2 = DistributionImpl(Strategy.HASH, 5, [self.expr1])
+        distribution3 = DistributionImpl(Strategy.RANGE, 5, [self.expr1])
+
+        # Equal distributions must hash equally.
+        self.assertEqual(hash(distribution1), hash(distribution2))
+        # A different strategy is expected to change the hash.
+        self.assertNotEqual(hash(distribution1), hash(distribution3))
diff --git a/clients/client-python/tests/unittests/test_expressions.py
b/clients/client-python/tests/unittests/rel/test_expressions.py
similarity index 100%
rename from clients/client-python/tests/unittests/test_expressions.py
rename to clients/client-python/tests/unittests/rel/test_expressions.py
diff --git a/clients/client-python/tests/unittests/test_function_expression.py
b/clients/client-python/tests/unittests/rel/test_function_expression.py
similarity index 100%
rename from clients/client-python/tests/unittests/test_function_expression.py
rename to clients/client-python/tests/unittests/rel/test_function_expression.py
diff --git a/clients/client-python/tests/unittests/test_literals.py
b/clients/client-python/tests/unittests/rel/test_literals.py
similarity index 100%
rename from clients/client-python/tests/unittests/test_literals.py
rename to clients/client-python/tests/unittests/rel/test_literals.py
diff --git a/clients/client-python/tests/unittests/test_types.py
b/clients/client-python/tests/unittests/rel/test_types.py
similarity index 100%
rename from clients/client-python/tests/unittests/test_types.py
rename to clients/client-python/tests/unittests/rel/test_types.py