This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8f662fc2d0ea [SPARK-48555][SQL][PYTHON][CONNECT] Support using Columns as parameters for several functions in pyspark/scala
8f662fc2d0ea is described below
commit 8f662fc2d0ea9472feb948ee6e2786eaf92339b7
Author: Ron Serruya <[email protected]>
AuthorDate: Mon Jun 17 09:31:08 2024 +0900
[SPARK-48555][SQL][PYTHON][CONNECT] Support using Columns as parameters for several functions in pyspark/scala
### What changes were proposed in this pull request?
https://issues.apache.org/jira/browse/SPARK-48555
For pyspark, added the ability to use the "Column" type or column names for the parameters of the following functions:
- array_remove
- array_position
- map_contains_key
- substring
For scala, added the ability to use the "Column" type as parameters of the `substring` function.
This functionality already exists in the SQL syntax:
```
select array_remove(col1, col2) from values (array(1,2,3), 2)
```
However, it isn't possible to do the same in python:
```python
df.select(F.array_remove(F.col("col1"), F.col("col2")))
```
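With this change, the equivalent DataFrame calls work in pyspark. A minimal sketch, assuming a local SparkSession; the DataFrame contents and column names here are only illustrative:
```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

# array_remove with a Column as the element to remove
df = spark.createDataFrame([([1, 2, 3], 2)], ["col1", "col2"])
df.select(F.array_remove(F.col("col1"), F.col("col2"))).show()

# substring with Column-typed pos/len (mirrors the doctest added below)
df2 = spark.createDataFrame([("Spark", 2, 3)], ["s", "p", "l"])
df2.select(F.substring(df2.s, df2.p, df2.l).alias("sub")).show()
```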
Note that in scala the functions other than `substring` already accepted Column params (or rather, they accept `Any` and pass whatever the param is to `lit`, so it ends up working), so I only needed to change `substring` on the scala side.
### Why are the changes needed?
To align the scala/python API with the SQL one.
### Does this PR introduce _any_ user-facing change?
Yes, added new overloaded functions in scala and changed type
hints/docstrings in python.
### How was this patch tested?
Added doctests for the python changes, and tests in the scala test suites,
then tested both manually and using the CI.
### Was this patch authored or co-authored using generative AI tooling?
No.
### Notes:
- I opened the related JIRA ticket, but it looks like I don't have the option to assign it to myself, so if that is required and any reviewer has permission to do it, I'd appreciate it.
- The "Build" workflow passed successfully, but the "report tes results"
one didn't due to some authorization issue, I see that this is the same for
many other open PRs right now so I assume its ok.
- For the python changes, I tried to follow the convention used by other functions (such as `array_contains` or `when`) of using `value._jc if isinstance(value, Column) else value`.
- I'm not really familiar with the `connect` functions, but it seems like on the python side they already supported the use of columns, so no extra changes were needed there.
- For the scala side, this is the first time I'm touching scala; I think I covered it all, as I tried to match similar changes done in [a similar PR](https://github.com/apache/spark/pull/46045).
- The same issue also exists for `substring_index`; however, I wasn't able to fix it the same way I did for `substring`. Calling it with a `lit` for the `count` arg worked, but using a `col` errored with a `NumberFormatError` for "project_value_3". I assume the error is related to trying to parse the Int
[here](https://github.com/apache/spark/blob/7cba1ab4d6acef4e9d73a8e6018b0902aac3a18d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala#L1
[...]
The contribution is my original work and I license the work to the project
under the project’s open source license.
Closes #46901 from Ronserruya/support_columns_in_pyspark_functions.
Authored-by: Ron Serruya <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../scala/org/apache/spark/sql/functions.scala | 13 ++++
.../apache/spark/sql/PlanGenerationTestSuite.scala | 4 ++
.../function_substring_using_columns.explain | 2 +
.../function_substring_with_columns.explain | 2 +
.../queries/function_substring_using_columns.json | 33 ++++++++++
.../function_substring_using_columns.proto.bin | Bin 0 -> 192 bytes
python/pyspark/sql/functions/builtin.py | 72 +++++++++++++++++++--
.../scala/org/apache/spark/sql/functions.scala | 13 ++++
.../apache/spark/sql/StringFunctionsSuite.scala | 5 ++
9 files changed, 138 insertions(+), 6 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index c287e3469108..eae239a25589 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4276,6 +4276,19 @@ object functions {
def substring(str: Column, pos: Int, len: Int): Column =
Column.fn("substring", str, lit(pos), lit(len))
+ /**
+ * Substring starts at `pos` and is of length `len` when str is String type or returns the slice
+ * of byte array that starts at `pos` in byte and is of length `len` when str is Binary type
+ *
+ * @note
+ * The position is not zero based, but 1 based index.
+ *
+ * @group string_funcs
+ * @since 4.0.0
+ */
+ def substring(str: Column, pos: Column, len: Column): Column =
+ Column.fn("substring", str, pos, len)
+
/**
* Returns the substring from string str before count occurrences of the delimiter delim. If
* count is positive, everything the left of the final delimiter (counting from left) is
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index aee287f4bbb3..77be7c5de04a 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -1780,6 +1780,10 @@ class PlanGenerationTestSuite
fn.substring(fn.col("g"), 4, 5)
}
+ functionTest("substring using columns") {
+ fn.substring(fn.col("g"), fn.col("a"), fn.col("b"))
+ }
+
functionTest("substring_index") {
fn.substring_index(fn.col("g"), ";", 5)
}
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_substring_using_columns.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_substring_using_columns.explain
new file mode 100644
index 000000000000..3050d15d9754
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_substring_using_columns.explain
@@ -0,0 +1,2 @@
+Project [substring(g#0, a#0, cast(b#0 as int)) AS substring(g, a, b)#0]
++- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_substring_with_columns.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_substring_with_columns.explain
new file mode 100644
index 000000000000..fe07244fc9ce
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_substring_with_columns.explain
@@ -0,0 +1,2 @@
+Project [substring(g#0, 4, 5) AS substring(g, 4, 5)#0]
++- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_substring_using_columns.json b/connector/connect/common/src/test/resources/query-tests/queries/function_substring_using_columns.json
new file mode 100644
index 000000000000..ba28b1c7f570
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/function_substring_using_columns.json
@@ -0,0 +1,33 @@
+{
+ "common": {
+ "planId": "1"
+ },
+ "project": {
+ "input": {
+ "common": {
+ "planId": "0"
+ },
+ "localRelation": {
+ "schema":
"struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
+ }
+ },
+ "expressions": [{
+ "unresolvedFunction": {
+ "functionName": "substring",
+ "arguments": [{
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "g"
+ }
+ }, {
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "a"
+ }
+ }, {
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "b"
+ }
+ }]
+ }
+ }]
+ }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_substring_using_columns.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_substring_using_columns.proto.bin
new file mode 100644
index 000000000000..f14b44ef5a50
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_substring_using_columns.proto.bin differ
diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index 07dfcaf2e2b7..2edbc9f5abe1 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -10915,7 +10915,9 @@ def sentences(
@_try_remote_functions
-def substring(str: "ColumnOrName", pos: int, len: int) -> Column:
+def substring(
+ str: "ColumnOrName", pos: Union["ColumnOrName", int], len:
Union["ColumnOrName", int]
+) -> Column:
"""
Substring starts at `pos` and is of length `len` when str is String type or
returns the slice of byte array that starts at `pos` in byte and is of length `len`
@@ -10934,11 +10936,14 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column:
----------
str : :class:`~pyspark.sql.Column` or str
target column to work on.
- pos : int
+ pos : :class:`~pyspark.sql.Column` or str or int
starting position in str.
- len : int
+ len : :class:`~pyspark.sql.Column` or str or int
length of chars.
+ .. versionchanged:: 4.0.0
+ `pos` and `len` now also accept Columns or names of Columns.
+
Returns
-------
:class:`~pyspark.sql.Column`
@@ -10949,9 +10954,18 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column:
>>> df = spark.createDataFrame([('abcd',)], ['s',])
>>> df.select(substring(df.s, 1, 2).alias('s')).collect()
[Row(s='ab')]
+ >>> df = spark.createDataFrame([('Spark', 2, 3)], ['s', 'p', 'l'])
+ >>> df.select(substring(df.s, 2, df.l).alias('s')).collect()
+ [Row(s='par')]
+ >>> df.select(substring(df.s, df.p, 3).alias('s')).collect()
+ [Row(s='par')]
+ >>> df.select(substring(df.s, df.p, df.l).alias('s')).collect()
+ [Row(s='par')]
"""
from pyspark.sql.classic.column import _to_java_column
+ pos = _to_java_column(lit(pos) if isinstance(pos, int) else pos)
+ len = _to_java_column(lit(len) if isinstance(len, int) else len)
return _invoke_function("substring", _to_java_column(str), pos, len)
@@ -13969,7 +13983,10 @@ def array_position(col: "ColumnOrName", value: Any) -> Column:
col : :class:`~pyspark.sql.Column` or str
target column to work on.
value : Any
- value to look for.
+ value or a :class:`~pyspark.sql.Column` expression to look for.
+
+ .. versionchanged:: 4.0.0
+ `value` now also accepts a Column type.
Returns
-------
@@ -14034,9 +14051,22 @@ def array_position(col: "ColumnOrName", value: Any) -> Column:
+-----------------------+
| 3|
+-----------------------+
+
+ Example 6: Finding the position of a column's value in an array of integers
+
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([([10, 20, 30], 20)], ['data', 'col'])
+ >>> df.select(sf.array_position(df.data, df.col)).show()
+ +-------------------------+
+ |array_position(data, col)|
+ +-------------------------+
+ | 2|
+ +-------------------------+
+
"""
from pyspark.sql.classic.column import _to_java_column
+ value = _to_java_column(value) if isinstance(value, Column) else value
return _invoke_function("array_position", _to_java_column(col), value)
@@ -14402,7 +14432,10 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column:
col : :class:`~pyspark.sql.Column` or str
name of column containing array
element :
- element to be removed from the array
+ element or a :class:`~pyspark.sql.Column` expression to be removed from the array
+
+ .. versionchanged:: 4.0.0
+ `element` now also accepts a Column type.
Returns
-------
@@ -14470,9 +14503,21 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column:
+---------------------+
| []|
+---------------------+
+
+ Example 6: Removing a column's value from a simple array
+
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([([1, 2, 3, 1, 1], 1)], ['data', 'col'])
+ >>> df.select(sf.array_remove(df.data, df.col)).show()
+ +-----------------------+
+ |array_remove(data, col)|
+ +-----------------------+
+ | [2, 3]|
+ +-----------------------+
"""
from pyspark.sql.classic.column import _to_java_column
+ element = _to_java_column(element) if isinstance(element, Column) else element
return _invoke_function("array_remove", _to_java_column(col), element)
@@ -17237,7 +17282,10 @@ def map_contains_key(col: "ColumnOrName", value: Any) -> Column:
col : :class:`~pyspark.sql.Column` or str
The name of the column or an expression that represents the map.
value :
- A literal value.
+ A literal value, or a :class:`~pyspark.sql.Column` expression.
+
+ .. versionchanged:: 4.0.0
+ `value` now also accepts a Column type.
Returns
-------
@@ -17267,9 +17315,21 @@ def map_contains_key(col: "ColumnOrName", value: Any) -> Column:
+--------------------------+
| false|
+--------------------------+
+
+ Example 3: Check for key using a column
+
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data, 1 as key")
+ >>> df.select(sf.map_contains_key("data", sf.col("key"))).show()
+ +---------------------------+
+ |map_contains_key(data, key)|
+ +---------------------------+
+ | true|
+ +---------------------------+
"""
from pyspark.sql.classic.column import _to_java_column
+ value = _to_java_column(value) if isinstance(value, Column) else value
return _invoke_function("map_contains_key", _to_java_column(col), value)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 52733611e42a..882918eb78c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4234,6 +4234,19 @@ object functions {
def substring(str: Column, pos: Int, len: Int): Column =
Column.fn("substring", str, lit(pos), lit(len))
+ /**
+ * Substring starts at `pos` and is of length `len` when str is String type or
+ * returns the slice of byte array that starts at `pos` in byte and is of length `len`
+ * when str is Binary type
+ *
+ * @note The position is not zero based, but 1 based index.
+ *
+ * @group string_funcs
+ * @since 4.0.0
+ */
+ def substring(str: Column, pos: Column, len: Column): Column =
+ Column.fn("substring", str, pos, len)
+
/**
* Returns the substring from string str before count occurrences of the delimiter delim.
* If count is positive, everything the left of the final delimiter (counting from left) is
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 3fc0b572d80b..31c1cac9fb71 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -332,6 +332,11 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
// scalastyle:on
}
+ test("string substring function using columns") {
+ val df = Seq(("Spark", 2, 3)).toDF("a", "b", "c")
+ checkAnswer(df.select(substring($"a", $"b", $"c")), Row("par"))
+ }
+
test("string encode/decode function") {
val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116)
// scalastyle:off
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]