This is an automated email from the ASF dual-hosted git repository. ueshin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new b3d5bc0c109 [SPARK-45362][PYTHON] Project out PARTITION BY expressions before Python UDTF 'eval' method consumes them b3d5bc0c109 is described below commit b3d5bc0c10908aa66510844eaabc43b6764dd7c0 Author: Daniel Tenedorio <daniel.tenedo...@databricks.com> AuthorDate: Thu Sep 28 14:02:46 2023 -0700 [SPARK-45362][PYTHON] Project out PARTITION BY expressions before Python UDTF 'eval' method consumes them ### What changes were proposed in this pull request? This PR projects out PARTITION BY expressions before Python UDTF 'eval' method consumes them. Before this PR, if a query included this `PARTITION BY` clause: ``` SELECT * FROM udtf((SELECT a, b FROM TABLE t) PARTITION BY (c, d)) ``` Then the `eval` method received four columns in each row: `a, b, c, d`. After this PR, the `eval` method only receives two columns: `a, b`, as expected. ### Why are the changes needed? This makes the Python UDTF `TABLE` columns consistently match what the `eval` method receives, as expected. ### Does this PR introduce _any_ user-facing change? Yes, see above. ### How was this patch tested? This PR adds new unit tests. ### Was this patch authored or co-authored using generative AI tooling? No Closes #43156 from dtenedor/project-out-partition-exprs. Authored-by: Daniel Tenedorio <daniel.tenedo...@databricks.com> Signed-off-by: Takuya UESHIN <ues...@databricks.com> --- python/pyspark/sql/tests/test_udtf.py | 12 ++++++++++++ python/pyspark/worker.py | 31 +++++++++++++++++++++++++++---- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 97d5190a506..a1d82056c50 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -2009,6 +2009,10 @@ class BaseUDTFTestsMixin: self._partition_col = None def eval(self, row: Row): + # Make sure that the PARTITION BY expressions were projected out. + assert len(row.asDict().items()) == 2 + assert "partition_col" in row + assert "input" in row self._sum += row["input"] if self._partition_col is not None and self._partition_col != row["partition_col"]: # Make sure that all values of the partitioning column are the same @@ -2092,6 +2096,10 @@ class BaseUDTFTestsMixin: self._partition_col = None def eval(self, row: Row, partition_col: str): + # Make sure that the PARTITION BY and ORDER BY expressions were projected out. + assert len(row.asDict().items()) == 2 + assert "partition_col" in row + assert "input" in row # Make sure that all values of the partitioning column are the same # for each row consumed by this method for this instance of the class. if self._partition_col is not None and self._partition_col != row[partition_col]: @@ -2247,6 +2255,10 @@ class BaseUDTFTestsMixin: ) def eval(self, row: Row): + # Make sure that the PARTITION BY and ORDER BY expressions were projected out. + assert len(row.asDict().items()) == 2 + assert "partition_col" in row + assert "input" in row # Make sure that all values of the partitioning column are the same # for each row consumed by this method for this instance of the class. if self._partition_col is not None and self._partition_col != row["partition_col"]: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 77481704979..4cffb02a64a 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -51,7 +51,14 @@ from pyspark.sql.pandas.serializers import ( ApplyInPandasWithStateSerializer, ) from pyspark.sql.pandas.types import to_arrow_type -from pyspark.sql.types import BinaryType, Row, StringType, StructType, _parse_datatype_json_string +from pyspark.sql.types import ( + BinaryType, + Row, + StringType, + StructType, + _create_row, + _parse_datatype_json_string, +) from pyspark.util import fail_on_stopiteration, handle_worker_exception from pyspark import shuffle from pyspark.errors import PySparkRuntimeError, PySparkTypeError @@ -735,7 +742,12 @@ def read_udtf(pickleSer, infile, eval_type): yield row self._udtf = self._create_udtf() if self._udtf.eval is not None: - result = self._udtf.eval(*args, **kwargs) + # Filter the arguments to exclude projected PARTITION BY values added by Catalyst. + filtered_args = [self._remove_partition_by_exprs(arg) for arg in args] + filtered_kwargs = { + key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items() + } + result = self._udtf.eval(*filtered_args, **filtered_kwargs) if result is not None: for row in result: yield row @@ -752,10 +764,9 @@ def read_udtf(pickleSer, infile, eval_type): prev_table_arg = self._get_table_arg(self._prev_arguments) cur_partitions_args = [] prev_partitions_args = [] - for i in partition_child_indexes: + for i in self._partition_child_indexes: cur_partitions_args.append(cur_table_arg[i]) prev_partitions_args.append(prev_table_arg[i]) - self._prev_arguments = arguments result = any(k != v for k, v in zip(cur_partitions_args, prev_partitions_args)) self._prev_arguments = arguments return result @@ -763,6 +774,18 @@ def read_udtf(pickleSer, infile, eval_type): def _get_table_arg(self, inputs: list) -> Row: return [x for x in inputs if type(x) is Row][0] + def _remove_partition_by_exprs(self, arg: Any) -> Any: + if isinstance(arg, Row): + new_row_keys = [] + new_row_values = [] + for i, (key, value) in enumerate(zip(arg.__fields__, arg)): + if i not in self._partition_child_indexes: + new_row_keys.append(key) + new_row_values.append(value) + return _create_row(new_row_keys, new_row_values) + else: + return arg + # Instantiate the UDTF class. try: if len(partition_child_indexes) > 0: --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org