This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ef4be07fdad9 [SPARK-50220][PYTHON] Support listagg in PySpark
ef4be07fdad9 is described below
commit ef4be07fdad9c8078e22d4f3f068fee1b81cf967
Author: Mikhail Nikoliukin <[email protected]>
AuthorDate: Wed Dec 25 17:13:59 2024 +0800
[SPARK-50220][PYTHON] Support listagg in PySpark
### What changes were proposed in this pull request?
Added the new function `listagg` to PySpark.
Follow-up of https://github.com/apache/spark/pull/48748.
### Why are the changes needed?
Allows writing queries with `listagg` using native PySpark functions,
e.g., `df.select(F.listagg(df.value, ",").alias("r"))`; see the sketch below.
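A minimal sketch (assuming an active `SparkSession` bound to `spark`; the data and column name are illustrative):

```python
from pyspark.sql import functions as F

df = spark.createDataFrame([("a",), ("b",), (None,), ("c",)], ["value"])
# Nulls are skipped; the remaining values are joined with the delimiter.
df.select(F.listagg(df.value, ",").alias("r")).show()
# +-----+
# |    r|
# +-----+
# |a,b,c|
# +-----+
```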
### Does this PR introduce _any_ user-facing change?
Yes, new functions `listagg` and `listagg_distinct` (with aliases
`string_agg` and `string_agg_distinct`) in PySpark; a short sketch of the
variants follows.
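A sketch of the distinct variant and the aliases (same assumptions as the example above):

```python
from pyspark.sql import functions as F

df = spark.createDataFrame([("a",), ("b",), ("b",)], ["value"])
# The *_distinct variants drop duplicate values before concatenating.
df.select(F.listagg_distinct(df.value, ",").alias("r")).collect()[0].r  # 'a,b'
# string_agg is a direct alias of listagg and keeps duplicates.
df.select(F.string_agg(df.value, ",").alias("r")).collect()[0].r  # 'a,b,b'
```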
### How was this patch tested?
Unit tests; the touched suites can be run locally as sketched below.
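For example (assuming a full Spark checkout; `python/run-tests` is the usual dev script for running individual PySpark suites):

    python/run-tests --testnames 'pyspark.sql.tests.test_functions'
    python/run-tests --testnames 'pyspark.sql.tests.connect.test_connect_function'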
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: GitHub Copilot
Closes #49231 from mikhailnik-db/SPARK-50220-listagg-for-pyspark.
Authored-by: Mikhail Nikoliukin <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../source/reference/pyspark.sql/functions.rst | 4 +
python/pyspark/sql/connect/functions/builtin.py | 58 ++++
python/pyspark/sql/functions/__init__.py | 4 +
python/pyspark/sql/functions/builtin.py | 308 +++++++++++++++++++++
.../sql/tests/connect/test_connect_function.py | 4 +
python/pyspark/sql/tests/test_functions.py | 73 ++++-
6 files changed, 444 insertions(+), 7 deletions(-)
diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst
index 430e353dd701..a1ba153110f1 100644
--- a/python/docs/source/reference/pyspark.sql/functions.rst
+++ b/python/docs/source/reference/pyspark.sql/functions.rst
@@ -451,6 +451,8 @@ Aggregate Functions
kurtosis
last
last_value
+ listagg
+ listagg_distinct
max
max_by
mean
@@ -476,6 +478,8 @@ Aggregate Functions
stddev
stddev_pop
stddev_samp
+ string_agg
+ string_agg_distinct
sum
sum_distinct
try_avg
diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py
index f52cdffb84b7..f13eeab12dd3 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -1064,6 +1064,64 @@ def collect_set(col: "ColumnOrName") -> Column:
collect_set.__doc__ = pysparkfuncs.collect_set.__doc__
+def listagg(col: "ColumnOrName", delimiter: Optional[Union[Column, str, bytes]] = None) -> Column:
+ if delimiter is None:
+ return _invoke_function_over_columns("listagg", col)
+ else:
+ return _invoke_function_over_columns("listagg", col, lit(delimiter))
+
+
+listagg.__doc__ = pysparkfuncs.listagg.__doc__
+
+
+def listagg_distinct(
+ col: "ColumnOrName", delimiter: Optional[Union[Column, str, bytes]] = None
+) -> Column:
+ from pyspark.sql.connect.column import Column as ConnectColumn
+
+ args = [col]
+ if delimiter is not None:
+ args += [lit(delimiter)]
+
+ _exprs = [_to_col(c)._expr for c in args]
+ return ConnectColumn(
+ UnresolvedFunction("listagg", _exprs, is_distinct=True) # type:
ignore[arg-type]
+ )
+
+
+listagg_distinct.__doc__ = pysparkfuncs.listagg_distinct.__doc__
+
+
+def string_agg(
+ col: "ColumnOrName", delimiter: Optional[Union[Column, str, bytes]] = None
+) -> Column:
+ if delimiter is None:
+ return _invoke_function_over_columns("string_agg", col)
+ else:
+ return _invoke_function_over_columns("string_agg", col, lit(delimiter))
+
+
+string_agg.__doc__ = pysparkfuncs.string_agg.__doc__
+
+
+def string_agg_distinct(
+ col: "ColumnOrName", delimiter: Optional[Union[Column, str, bytes]] = None
+) -> Column:
+ from pyspark.sql.connect.column import Column as ConnectColumn
+
+ args = [col]
+ if delimiter is not None:
+ args += [lit(delimiter)]
+
+ _exprs = [_to_col(c)._expr for c in args]
+ return ConnectColumn(
+ UnresolvedFunction("string_agg", _exprs, is_distinct=True) # type:
ignore[arg-type]
+ )
+
+
+string_agg_distinct.__doc__ = pysparkfuncs.string_agg_distinct.__doc__
+
+
def corr(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
return _invoke_function_over_columns("corr", col1, col2)
diff --git a/python/pyspark/sql/functions/__init__.py b/python/pyspark/sql/functions/__init__.py
index 98db2a7b091d..fc0120bc681d 100644
--- a/python/pyspark/sql/functions/__init__.py
+++ b/python/pyspark/sql/functions/__init__.py
@@ -364,6 +364,8 @@ __all__ = [ # noqa: F405
"kurtosis",
"last",
"last_value",
+ "listagg",
+ "listagg_distinct",
"max",
"max_by",
"mean",
@@ -389,6 +391,8 @@ __all__ = [ # noqa: F405
"stddev",
"stddev_pop",
"stddev_samp",
+ "string_agg",
+ "string_agg_distinct",
"sum",
"sum_distinct",
"try_avg",
diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index 4b4c164055ea..7b14598a0ef4 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -1851,6 +1851,314 @@ def sum_distinct(col: "ColumnOrName") -> Column:
return _invoke_function_over_columns("sum_distinct", col)
+@_try_remote_functions
+def listagg(col: "ColumnOrName", delimiter: Optional[Union[Column, str, bytes]] = None) -> Column:
+ """
+ Aggregate function: returns the concatenation of non-null input values,
+ separated by the delimiter.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or column name
+ target column to compute on.
+ delimiter : :class:`~pyspark.sql.Column`, literal string or bytes, optional
+ the delimiter to separate the values. The default value is None.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ the column for computed results.
+
+ Examples
+ --------
+ Example 1: Using listagg function
+
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([('a',), ('b',), (None,), ('c',)], ['strings'])
+ >>> df.select(sf.listagg('strings')).show()
+ +----------------------+
+ |listagg(strings, NULL)|
+ +----------------------+
+ | abc|
+ +----------------------+
+
+ Example 2: Using listagg function with a delimiter
+
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([('a',), ('b',), (None,), ('c',)], ['strings'])
+ >>> df.select(sf.listagg('strings', ', ')).show()
+ +--------------------+
+ |listagg(strings, , )|
+ +--------------------+
+ | a, b, c|
+ +--------------------+
+
+ Example 3: Using listagg function with a binary column and delimiter
+
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([(b'\x01',), (b'\x02',), (None,), (b'\x03',)], ['bytes'])
+ >>> df.select(sf.listagg('bytes', b'\x42')).show()
+ +---------------------+
+ |listagg(bytes, X'42')|
+ +---------------------+
+ | [01 42 02 42 03]|
+ +---------------------+
+
+ Example 4: Using listagg function on a column with all None values
+
+ >>> from pyspark.sql import functions as sf
+ >>> from pyspark.sql.types import StructType, StructField, StringType
+ >>> schema = StructType([StructField("strings", StringType(), True)])
+ >>> df = spark.createDataFrame([(None,), (None,), (None,), (None,)], schema=schema)
+ >>> df.select(sf.listagg('strings')).show()
+ +----------------------+
+ |listagg(strings, NULL)|
+ +----------------------+
+ | NULL|
+ +----------------------+
+ """
+ if delimiter is None:
+ return _invoke_function_over_columns("listagg", col)
+ else:
+ return _invoke_function_over_columns("listagg", col, lit(delimiter))
+
+
+@_try_remote_functions
+def listagg_distinct(
+ col: "ColumnOrName", delimiter: Optional[Union[Column, str, bytes]] = None
+) -> Column:
+ """
+ Aggregate function: returns the concatenation of distinct non-null input values,
+ separated by the delimiter.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or column name
+ target column to compute on.
+ delimiter : :class:`~pyspark.sql.Column`, literal string or bytes, optional
+ the delimiter to separate the values. The default value is None.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ the column for computed results.
+
+ Examples
+ --------
+ Example 1: Using listagg_distinct function
+
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([('a',), ('b',), (None,), ('c',), ('b',)], ['strings'])
+ >>> df.select(sf.listagg_distinct('strings')).show()
+ +-------------------------------+
+ |listagg(DISTINCT strings, NULL)|
+ +-------------------------------+
+ | abc|
+ +-------------------------------+
+
+ Example 2: Using listagg_distinct function with a delimiter
+
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([('a',), ('b',), (None,), ('c',), ('b',)], ['strings'])
+ >>> df.select(sf.listagg_distinct('strings', ', ')).show()
+ +-----------------------------+
+ |listagg(DISTINCT strings, , )|
+ +-----------------------------+
+ | a, b, c|
+ +-----------------------------+
+
+ Example 3: Using listagg_distinct function with a binary column and delimiter
+
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([(b'\x01',), (b'\x02',), (None,), (b'\x03',), (b'\x02',)],
+ ... ['bytes'])
+ >>> df.select(sf.listagg_distinct('bytes', b'\x42')).show()
+ +------------------------------+
+ |listagg(DISTINCT bytes, X'42')|
+ +------------------------------+
+ | [01 42 02 42 03]|
+ +------------------------------+
+
+ Example 4: Using listagg_distinct function on a column with all None values
+
+ >>> from pyspark.sql import functions as sf
+ >>> from pyspark.sql.types import StructType, StructField, StringType
+ >>> schema = StructType([StructField("strings", StringType(), True)])
+ >>> df = spark.createDataFrame([(None,), (None,), (None,), (None,)], schema=schema)
+ >>> df.select(sf.listagg_distinct('strings')).show()
+ +-------------------------------+
+ |listagg(DISTINCT strings, NULL)|
+ +-------------------------------+
+ | NULL|
+ +-------------------------------+
+ """
+ if delimiter is None:
+ return _invoke_function_over_columns("listagg_distinct", col)
+ else:
+ return _invoke_function_over_columns("listagg_distinct", col,
lit(delimiter))
+
+
+@_try_remote_functions
+def string_agg(
+ col: "ColumnOrName", delimiter: Optional[Union[Column, str, bytes]] = None
+) -> Column:
+ """
+ Aggregate function: returns the concatenation of non-null input values,
+ separated by the delimiter.
+
+ An alias of :func:`listagg`.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or column name
+ target column to compute on.
+ delimiter : :class:`~pyspark.sql.Column`, literal string or bytes, optional
+ the delimiter to separate the values. The default value is None.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ the column for computed results.
+
+ Examples
+ --------
+ Example 1: Using string_agg function
+
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([('a',), ('b',), (None,), ('c',)], ['strings'])
+ >>> df.select(sf.string_agg('strings')).show()
+ +-------------------------+
+ |string_agg(strings, NULL)|
+ +-------------------------+
+ | abc|
+ +-------------------------+
+
+ Example 2: Using string_agg function with a delimiter
+
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([('a',), ('b',), (None,), ('c',)], ['strings'])
+ >>> df.select(sf.string_agg('strings', ', ')).show()
+ +-----------------------+
+ |string_agg(strings, , )|
+ +-----------------------+
+ | a, b, c|
+ +-----------------------+
+
+ Example 3: Using string_agg function with a binary column and delimiter
+
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([(b'\x01',), (b'\x02',), (None,), (b'\x03',)], ['bytes'])
+ >>> df.select(sf.string_agg('bytes', b'\x42')).show()
+ +------------------------+
+ |string_agg(bytes, X'42')|
+ +------------------------+
+ | [01 42 02 42 03]|
+ +------------------------+
+
+ Example 4: Using string_agg function on a column with all None values
+
+ >>> from pyspark.sql import functions as sf
+ >>> from pyspark.sql.types import StructType, StructField, StringType
+ >>> schema = StructType([StructField("strings", StringType(), True)])
+ >>> df = spark.createDataFrame([(None,), (None,), (None,), (None,)], schema=schema)
+ >>> df.select(sf.string_agg('strings')).show()
+ +-------------------------+
+ |string_agg(strings, NULL)|
+ +-------------------------+
+ | NULL|
+ +-------------------------+
+ """
+ if delimiter is None:
+ return _invoke_function_over_columns("string_agg", col)
+ else:
+ return _invoke_function_over_columns("string_agg", col, lit(delimiter))
+
+
+@_try_remote_functions
+def string_agg_distinct(
+ col: "ColumnOrName", delimiter: Optional[Union[Column, str, bytes]] = None
+) -> Column:
+ """
+ Aggregate function: returns the concatenation of distinct non-null input values,
+ separated by the delimiter.
+
+ An alias of :func:`listagg_distinct`.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or column name
+ target column to compute on.
+ delimiter : :class:`~pyspark.sql.Column`, literal string or bytes, optional
+ the delimiter to separate the values. The default value is None.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ the column for computed results.
+
+ Examples
+ --------
+ Example 1: Using string_agg_distinct function
+
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([('a',), ('b',), (None,), ('c',), ('b',)], ['strings'])
+ >>> df.select(sf.string_agg_distinct('strings')).show()
+ +----------------------------------+
+ |string_agg(DISTINCT strings, NULL)|
+ +----------------------------------+
+ | abc|
+ +----------------------------------+
+
+ Example 2: Using string_agg_distinct function with a delimiter
+
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([('a',), ('b',), (None,), ('c',), ('b',)], ['strings'])
+ >>> df.select(sf.string_agg_distinct('strings', ', ')).show()
+ +--------------------------------+
+ |string_agg(DISTINCT strings, , )|
+ +--------------------------------+
+ | a, b, c|
+ +--------------------------------+
+
+ Example 3: Using string_agg_distinct function with a binary column and delimiter
+
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([(b'\x01',), (b'\x02',), (None,), (b'\x03',), (b'\x02',)],
+ ... ['bytes'])
+ >>> df.select(sf.string_agg_distinct('bytes', b'\x42')).show()
+ +---------------------------------+
+ |string_agg(DISTINCT bytes, X'42')|
+ +---------------------------------+
+ | [01 42 02 42 03]|
+ +---------------------------------+
+
+ Example 4: Using string_agg_distinct function on a column with all None values
+
+ >>> from pyspark.sql import functions as sf
+ >>> from pyspark.sql.types import StructType, StructField, StringType
+ >>> schema = StructType([StructField("strings", StringType(), True)])
+ >>> df = spark.createDataFrame([(None,), (None,), (None,), (None,)], schema=schema)
+ >>> df.select(sf.string_agg_distinct('strings')).show()
+ +----------------------------------+
+ |string_agg(DISTINCT strings, NULL)|
+ +----------------------------------+
+ | NULL|
+ +----------------------------------+
+ """
+ if delimiter is None:
+ return _invoke_function_over_columns("string_agg_distinct", col)
+ else:
+ return _invoke_function_over_columns("string_agg_distinct", col,
lit(delimiter))
+
+
@_try_remote_functions
def product(col: "ColumnOrName") -> Column:
"""
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index b7a02efcd5e2..d1e255830529 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -590,6 +590,10 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, S
(CF.avg, SF.avg),
(CF.collect_list, SF.collect_list),
(CF.collect_set, SF.collect_set),
+ (CF.listagg, SF.listagg),
+ (CF.listagg_distinct, SF.listagg_distinct),
+ (CF.string_agg, SF.string_agg),
+ (CF.string_agg_distinct, SF.string_agg_distinct),
(CF.count, SF.count),
(CF.first, SF.first),
(CF.kurtosis, SF.kurtosis),
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 4607d5d3411f..39db72b235bf 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -30,6 +30,7 @@ from pyspark.sql import Row, Window, functions as F, types
from pyspark.sql.avro.functions import from_avro, to_avro
from pyspark.sql.column import Column
from pyspark.sql.functions.builtin import nullifzero, randstr, uniform, zeroifnull
+from pyspark.sql.types import StructType, StructField, StringType
from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils
from pyspark.testing.utils import have_numpy, assertDataFrameEqual
@@ -83,13 +84,7 @@ class FunctionsTestsMixin:
missing_in_py = jvm_fn_set.difference(py_fn_set)
# Functions that we expect to be missing in python until they are added to pyspark
- expected_missing_in_py = {
- # TODO(SPARK-50220): listagg functions will soon be added and removed from this list
- "listagg_distinct",
- "listagg",
- "string_agg",
- "string_agg_distinct",
- }
+ expected_missing_in_py = set()
self.assertEqual(
expected_missing_in_py, missing_in_py, "Missing functions in pyspark not as expected"
@@ -1145,6 +1140,70 @@ class FunctionsTestsMixin:
["1", "2", "2", "2"],
)
+ def test_listagg_functions(self):
+ df = self.spark.createDataFrame(
+ [(1, "1"), (2, "2"), (None, None), (1, "2")], ["key", "value"]
+ )
+ df_with_bytes = self.spark.createDataFrame(
+ [(b"\x01",), (b"\x02",), (None,), (b"\x03",), (b"\x02",)],
["bytes"]
+ )
+ df_with_nulls = self.spark.createDataFrame(
+ [(None,), (None,), (None,), (None,), (None,)],
+ StructType([StructField("nulls", StringType(), True)]),
+ )
+ # listagg and string_agg are aliases
+ for listagg_ref in [F.listagg, F.string_agg]:
+ self.assertEqual(df.select(listagg_ref(df.key).alias("r")).collect()[0].r, "121")
+ self.assertEqual(df.select(listagg_ref(df.value).alias("r")).collect()[0].r, "122")
+ self.assertEqual(
+ df.select(listagg_ref(df.value, ",").alias("r")).collect()[0].r, "1,2,2"
+ )
+ self.assertEqual(
+ df_with_bytes.select(listagg_ref(df_with_bytes.bytes, b"\x42").alias("r"))
+ .collect()[0]
+ .r,
+ b"\x01\x42\x02\x42\x03\x42\x02",
+ )
+ self.assertEqual(
+ df_with_nulls.select(listagg_ref(df_with_nulls.nulls).alias("r")).collect()[0].r,
+ None,
+ )
+
+ def test_listagg_distinct_functions(self):
+ df = self.spark.createDataFrame(
+ [(1, "1"), (2, "2"), (None, None), (1, "2")], ["key", "value"]
+ )
+ df_with_bytes = self.spark.createDataFrame(
+ [(b"\x01",), (b"\x02",), (None,), (b"\x03",), (b"\x02",)],
["bytes"]
+ )
+ df_with_nulls = self.spark.createDataFrame(
+ [(None,), (None,), (None,), (None,), (None,)],
+ StructType([StructField("nulls", StringType(), True)]),
+ )
+ # listagg_distinct and string_agg_distinct are aliases
+ for listagg_distinct_ref in [F.listagg_distinct, F.string_agg_distinct]:
+ self.assertEqual(
+ df.select(listagg_distinct_ref(df.key).alias("r")).collect()[0].r, "12"
+ )
+ self.assertEqual(
+ df.select(listagg_distinct_ref(df.value).alias("r")).collect()[0].r, "12"
+ )
+ self.assertEqual(
+ df.select(listagg_distinct_ref(df.value, ",").alias("r")).collect()[0].r, "1,2"
+ )
+ self.assertEqual(
+ df_with_bytes.select(listagg_distinct_ref(df_with_bytes.bytes, b"\x42").alias("r"))
+ .collect()[0]
+ .r,
+ b"\x01\x42\x02\x42\x03",
+ )
+ self.assertEqual(
+ df_with_nulls.select(listagg_distinct_ref(df_with_nulls.nulls).alias("r"))
+ .collect()[0]
+ .r,
+ None,
+ )
+
def test_datetime_functions(self):
df = self.spark.range(1).selectExpr("'2017-01-22' as dateCol")
parse_result = df.select(F.to_date(F.col("dateCol"))).first()