This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 08c078319fa7 [SPARK-54617][PYTHON][SQL] Enable Arrow Grouped Iter
Aggregate UDF registration for SQL
08c078319fa7 is described below
commit 08c078319fa7f08b0f7d534f9760d802e4b0033b
Author: Yicong-Huang <[email protected]>
AuthorDate: Tue Dec 9 14:39:11 2025 +0800
[SPARK-54617][PYTHON][SQL] Enable Arrow Grouped Iter Aggregate UDF
registration for SQL
### What changes were proposed in this pull request?
This PR enables Arrow grouped iter aggregate UDFs to be registered and used
in SQL queries. Previously, Arrow iter aggregate UDFs could only be used via
the DataFrame API, not in SQL.
The main change is adding `SQL_GROUPED_AGG_ARROW_ITER_UDF` to the allowed
eval types in the `UDFRegistration.register()` methods of both the classic
(`pyspark/sql/udf.py`) and Spark Connect (`pyspark/sql/connect/udf.py`)
clients, along with new test cases.
### Why are the changes needed?
Arrow iter aggregate UDFs provide a memory-efficient way to perform grouped
aggregations by processing data in batches iteratively. However, they could
only be used via the DataFrame API, not in SQL queries. This limitation
prevented users from using these UDFs in SQL-based workflows.
### Does this PR introduce _any_ user-facing change?
Yes. Users can now register Arrow grouped iter aggregate UDFs and use them
in SQL queries.
Example:
```python
from typing import Iterator

import pyarrow as pa
import pyarrow.compute as pc

from pyspark.sql.functions import arrow_udf


# The `@` is required: without it the decorator is created and discarded,
# and `arrow_mean_iter` stays a plain Python function instead of a UDF.
@arrow_udf("double")
def arrow_mean_iter(it: Iterator[pa.Array]) -> float:
    """Compute the mean of one column, consuming the group's Arrow batches one at a time."""
    sum_val = 0.0
    cnt = 0
    for v in it:
        # pc.sum yields a null scalar for an empty or all-null batch
        # (default min_count=1), so guard before accumulating.
        s = pc.sum(v).as_py()
        if s is not None:
            sum_val += s
        # Exclude nulls from the denominator, as SQL AVG does.
        cnt += len(v) - v.null_count
    return sum_val / cnt if cnt > 0 else 0.0


# Now this works (assumes a temp view `test_table` with columns `id` and `v`):
spark.udf.register("arrow_mean_iter", arrow_mean_iter)
spark.sql(
    "SELECT id, arrow_mean_iter(v) AS mean FROM test_table GROUP BY id"
).show()
```
### How was this patch tested?
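Added test cases covering:
- Single-column Arrow iter aggregate UDF in SQL
- Multi-column Arrow iter aggregate UDF in SQL

Also updated the expected `INVALID_UDF_EVAL_TYPE` error message in
`test_pandas_grouped_map.py` to include `SQL_GROUPED_AGG_ARROW_ITER_UDF`.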
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #53357 from Yicong-Huang/SPARK-54617/feat/arrow-iter-agg-udf-sql.
Authored-by: Yicong-Huang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/connect/udf.py | 4 +-
.../sql/tests/arrow/test_arrow_udf_grouped_agg.py | 72 ++++++++++++++++++++++
.../sql/tests/pandas/test_pandas_grouped_map.py | 3 +-
python/pyspark/sql/udf.py | 4 +-
4 files changed, 80 insertions(+), 3 deletions(-)
diff --git a/python/pyspark/sql/connect/udf.py
b/python/pyspark/sql/connect/udf.py
index 64506731badb..a5257ac9d09e 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -295,6 +295,7 @@ class UDFRegistration:
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
+ PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
]:
raise PySparkTypeError(
errorClass="INVALID_UDF_EVAL_TYPE",
@@ -302,7 +303,8 @@ class UDFRegistration:
"eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, "
"SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_ARROW_UDF, "
"SQL_SCALAR_PANDAS_ITER_UDF,
SQL_SCALAR_ARROW_ITER_UDF, "
- "SQL_GROUPED_AGG_PANDAS_UDF or
SQL_GROUPED_AGG_ARROW_UDF"
+ "SQL_GROUPED_AGG_PANDAS_UDF, SQL_GROUPED_AGG_ARROW_UDF
or "
+ "SQL_GROUPED_AGG_ARROW_ITER_UDF"
},
)
self.sparkSession._client.register_udf(
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
index 844c7f111db4..dd953ad3b973 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
@@ -1212,6 +1212,78 @@ class GroupedAggArrowUDFTestsMixin:
group2_result["result"]["sum"], 2.0, places=5, msg="Group 2
should sum to 2.0"
)
+ def test_iterator_grouped_agg_sql_single_column(self):
+ """
+ Test iterator API for grouped aggregation with single column in SQL.
+ """
+ import pyarrow as pa
+
+ @arrow_udf("double")
+ def arrow_mean_iter(it: Iterator[pa.Array]) -> float:
+ sum_val = 0.0
+ cnt = 0
+ for v in it:
+ assert isinstance(v, pa.Array)
+ sum_val += pa.compute.sum(v).as_py()
+ cnt += len(v)
+ return sum_val / cnt if cnt > 0 else 0.0
+
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+ )
+
+ with self.tempView("test_table"), self.temp_func("arrow_mean_iter"):
+ df.createOrReplaceTempView("test_table")
+ self.spark.udf.register("arrow_mean_iter", arrow_mean_iter)
+
+ # Test SQL query with GROUP BY
+ result_sql = self.spark.sql(
+ "SELECT id, arrow_mean_iter(v) as mean FROM test_table GROUP
BY id ORDER BY id"
+ )
+ expected =
df.groupby("id").agg(sf.mean(df["v"]).alias("mean")).sort("id").collect()
+
+ self.assertEqual(expected, result_sql.collect())
+
+ def test_iterator_grouped_agg_sql_multiple_columns(self):
+ """
+ Test iterator API for grouped aggregation with multiple columns in SQL.
+ """
+ import pyarrow as pa
+
+ @arrow_udf("double")
+ def arrow_weighted_mean_iter(it: Iterator[Tuple[pa.Array, pa.Array]])
-> float:
+ weighted_sum = 0.0
+ weight = 0.0
+ for v, w in it:
+ assert isinstance(v, pa.Array)
+ assert isinstance(w, pa.Array)
+ weighted_sum += pa.compute.sum(pa.compute.multiply(v,
w)).as_py()
+ weight += pa.compute.sum(w).as_py()
+ return weighted_sum / weight if weight > 0 else 0.0
+
+ df = self.spark.createDataFrame(
+ [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2,
10.0, 3.0)],
+ ("id", "v", "w"),
+ )
+
+ with self.tempView("test_table"),
self.temp_func("arrow_weighted_mean_iter"):
+ df.createOrReplaceTempView("test_table")
+ self.spark.udf.register("arrow_weighted_mean_iter",
arrow_weighted_mean_iter)
+
+ # Test SQL query with GROUP BY and multiple columns
+ result_sql = self.spark.sql(
+ "SELECT id, arrow_weighted_mean_iter(v, w) as wm "
+ "FROM test_table GROUP BY id ORDER BY id"
+ )
+
+ # Expected weighted means:
+ # Group 1: (1.0*1.0 + 2.0*2.0) / (1.0 + 2.0) = 5.0 / 3.0
+ # Group 2: (3.0*1.0 + 5.0*2.0 + 10.0*3.0) / (1.0 + 2.0 + 3.0) =
43.0 / 6.0
+ expected = [Row(id=1, wm=5.0 / 3.0), Row(id=2, wm=43.0 / 6.0)]
+
+ actual_results = result_sql.collect()
+ self.assertEqual(actual_results, expected)
+
class GroupedAggArrowUDFTests(GroupedAggArrowUDFTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index fb18c5f062b8..2f9e4bbd8ddd 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -212,7 +212,8 @@ class ApplyInPandasTestsMixin:
"eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, "
"SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_ARROW_UDF, "
"SQL_SCALAR_PANDAS_ITER_UDF, SQL_SCALAR_ARROW_ITER_UDF, "
- "SQL_GROUPED_AGG_PANDAS_UDF or SQL_GROUPED_AGG_ARROW_UDF"
+ "SQL_GROUPED_AGG_PANDAS_UDF, SQL_GROUPED_AGG_ARROW_UDF or "
+ "SQL_GROUPED_AGG_ARROW_ITER_UDF"
},
)
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 75bcb66efdb8..14d3f92f053d 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -681,6 +681,7 @@ class UDFRegistration:
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
+ PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
]:
raise PySparkTypeError(
errorClass="INVALID_UDF_EVAL_TYPE",
@@ -688,7 +689,8 @@ class UDFRegistration:
"eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, "
"SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_ARROW_UDF, "
"SQL_SCALAR_PANDAS_ITER_UDF,
SQL_SCALAR_ARROW_ITER_UDF, "
- "SQL_GROUPED_AGG_PANDAS_UDF or
SQL_GROUPED_AGG_ARROW_UDF"
+ "SQL_GROUPED_AGG_PANDAS_UDF, SQL_GROUPED_AGG_ARROW_UDF
or "
+ "SQL_GROUPED_AGG_ARROW_ITER_UDF"
},
)
source_udf = _create_udf(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]