This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8f662fc2d0ea [SPARK-48555][SQL][PYTHON][CONNECT] Support using Columns as parameters for several functions in pyspark/scala
8f662fc2d0ea is described below

commit 8f662fc2d0ea9472feb948ee6e2786eaf92339b7
Author: Ron Serruya <[email protected]>
AuthorDate: Mon Jun 17 09:31:08 2024 +0900

    [SPARK-48555][SQL][PYTHON][CONNECT] Support using Columns as parameters for several functions in pyspark/scala
    
    ### What changes were proposed in this pull request?
    
    https://issues.apache.org/jira/browse/SPARK-48555
    For PySpark, added the ability to use the "Column" type or column names for the parameters of the following functions:
    
    - array_remove
    - array_position
    - map_contains_key
    - substring
    
    For Scala, added the ability to use the "Column" type as the parameters of the `substring` function.
    
    This functionality already exists in the SQL syntax:
    ```
    select array_remove(col1, col2) from values (array(1,2,3), 2)
    ```
    
    However, it isn't possible to do the same in Python:
    ```python
    df.select(F.array_remove(F.col("col1"), F.col("col2")))
    ```
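
    With this change, calls like the ones below now work in PySpark. This is a minimal usage sketch based on the doctests added in this PR (it simply creates or reuses a local `SparkSession`):

    ```python
    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.getOrCreate()

    # substring: `pos` and `len` may now be Columns (or column names) as well as ints
    df = spark.createDataFrame([("Spark", 2, 3)], ["s", "p", "l"])
    df.select(F.substring(df.s, df.p, df.l).alias("s")).show()   # 'par'

    # array_remove / array_position: the value may now be a Column
    df2 = spark.createDataFrame([([1, 2, 3, 1], 1)], ["data", "col"])
    df2.select(F.array_remove(df2.data, df2.col)).show()         # [2, 3]
    df2.select(F.array_position(df2.data, df2.col)).show()       # 1

    # map_contains_key: the key may now be a Column
    df3 = spark.sql("SELECT map(1, 'a', 2, 'b') AS data, 1 AS key")
    df3.select(F.map_contains_key("data", F.col("key"))).show()  # true
    ```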
    
    Note that in Scala, the functions other than `substring` already accepted Column params (or rather, they accept `Any` and pass whatever the param is to `lit`, so it ends up working), so I only needed to change `substring` on the Scala side.
    
    ### Why are the changes needed?
    To align the Scala/Python API with the SQL one.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, added new overloaded functions in Scala and changed type hints/docstrings in Python.
    
    ### How was this patch tested?
    Added doctests for the Python changes and tests in the Scala test suites, then tested both manually and using the CI.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    ### Notes:
    
    - I opened the related JIRA ticket, but it looks like I don't have the option to assign it to myself, so if that is required and any reviewer has the permissions for it, I'd appreciate it.
    - The "Build" workflow passed successfully, but the "report test results" one didn't due to an authorization issue; I see the same failure on many other open PRs right now, so I assume it's OK.
    - For the Python changes, I tried to follow the convention used by other functions (such as `array_contains` or `when`) of using `value._jc if isinstance(value, Column) else value` (see the sketch after this list).
    - I'm not really familiar with the `connect` functions, but it seems that on the Python side they already supported the use of Columns, so no extra changes were needed there.
    - For the Scala side, this is the first time I'm touching Scala; I think I covered it all, as I tried to match the changes done in [a similar PR](https://github.com/apache/spark/pull/46045).
    - The same issue also exists for `substring_index`; however, I wasn't able to fix it the same way I did for `substring`. Calling it with a `lit` for the `count` arg worked, but using a `col` failed with a `NumberFormatError` for "project_value_3". I assume the error is related to trying to parse the Int [here](https://github.com/apache/spark/blob/7cba1ab4d6acef4e9d73a8e6018b0902aac3a18d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala#L1
 [...]
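
    For reference, a rough sketch of that conversion convention (hedged; the actual changes live in `python/pyspark/sql/functions/builtin.py` and are shown in the diff below, using the internal `_to_java_column` helper):

    ```python
    # Rough sketch only: `_prepare_value` is a hypothetical name, not PySpark API.
    from pyspark.sql import Column

    def _prepare_value(value):
        # If the caller passed a Column, hand its JVM counterpart to the JVM
        # function; plain Python literals are passed through unchanged.
        from pyspark.sql.classic.column import _to_java_column  # PySpark internal
        return _to_java_column(value) if isinstance(value, Column) else value
    ```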
    
    The contribution is my original work and I license the work to the project 
under the project’s open source license.
    
    Closes #46901 from Ronserruya/support_columns_in_pyspark_functions.
    
    Authored-by: Ron Serruya <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../scala/org/apache/spark/sql/functions.scala     |  13 ++++
 .../apache/spark/sql/PlanGenerationTestSuite.scala |   4 ++
 .../function_substring_using_columns.explain       |   2 +
 .../function_substring_with_columns.explain        |   2 +
 .../queries/function_substring_using_columns.json  |  33 ++++++++++
 .../function_substring_using_columns.proto.bin     | Bin 0 -> 192 bytes
 python/pyspark/sql/functions/builtin.py            |  72 +++++++++++++++++++--
 .../scala/org/apache/spark/sql/functions.scala     |  13 ++++
 .../apache/spark/sql/StringFunctionsSuite.scala    |   5 ++
 9 files changed, 138 insertions(+), 6 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index c287e3469108..eae239a25589 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4276,6 +4276,19 @@ object functions {
   def substring(str: Column, pos: Int, len: Int): Column =
     Column.fn("substring", str, lit(pos), lit(len))
 
+  /**
+   * Substring starts at `pos` and is of length `len` when str is String type or returns the slice
+   * of byte array that starts at `pos` in byte and is of length `len` when str is Binary type
+   *
+   * @note
+   *   The position is not zero based, but 1 based index.
+   *
+   * @group string_funcs
+   * @since 4.0.0
+   */
+  def substring(str: Column, pos: Column, len: Column): Column =
+    Column.fn("substring", str, pos, len)
+
   /**
    * Returns the substring from string str before count occurrences of the delimiter delim. If
    * count is positive, everything the left of the final delimiter (counting from left) is
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index aee287f4bbb3..77be7c5de04a 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -1780,6 +1780,10 @@ class PlanGenerationTestSuite
     fn.substring(fn.col("g"), 4, 5)
   }
 
+  functionTest("substring using columns") {
+    fn.substring(fn.col("g"), fn.col("a"), fn.col("b"))
+  }
+
   functionTest("substring_index") {
     fn.substring_index(fn.col("g"), ";", 5)
   }
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_substring_using_columns.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_substring_using_columns.explain
new file mode 100644
index 000000000000..3050d15d9754
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_substring_using_columns.explain
@@ -0,0 +1,2 @@
+Project [substring(g#0, a#0, cast(b#0 as int)) AS substring(g, a, b)#0]
++- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_substring_with_columns.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_substring_with_columns.explain
new file mode 100644
index 000000000000..fe07244fc9ce
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_substring_with_columns.explain
@@ -0,0 +1,2 @@
+Project [substring(g#0, 4, 5) AS substring(g, 4, 5)#0]
++- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_substring_using_columns.json b/connector/connect/common/src/test/resources/query-tests/queries/function_substring_using_columns.json
new file mode 100644
index 000000000000..ba28b1c7f570
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/function_substring_using_columns.json
@@ -0,0 +1,33 @@
+{
+  "common": {
+    "planId": "1"
+  },
+  "project": {
+    "input": {
+      "common": {
+        "planId": "0"
+      },
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
+      }
+    },
+    "expressions": [{
+      "unresolvedFunction": {
+        "functionName": "substring",
+        "arguments": [{
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "g"
+          }
+        }, {
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "a"
+          }
+        }, {
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "b"
+          }
+        }]
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_substring_using_columns.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_substring_using_columns.proto.bin
new file mode 100644
index 000000000000..f14b44ef5a50
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_substring_using_columns.proto.bin differ
diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index 07dfcaf2e2b7..2edbc9f5abe1 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -10915,7 +10915,9 @@ def sentences(
 
 
 @_try_remote_functions
-def substring(str: "ColumnOrName", pos: int, len: int) -> Column:
+def substring(
+    str: "ColumnOrName", pos: Union["ColumnOrName", int], len: Union["ColumnOrName", int]
+) -> Column:
     """
     Substring starts at `pos` and is of length `len` when str is String type or
     returns the slice of byte array that starts at `pos` in byte and is of length `len`
@@ -10934,11 +10936,14 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column:
     ----------
     str : :class:`~pyspark.sql.Column` or str
         target column to work on.
-    pos : int
+    pos : :class:`~pyspark.sql.Column` or str or int
         starting position in str.
-    len : int
+    len : :class:`~pyspark.sql.Column` or str or int
         length of chars.
 
+        .. versionchanged:: 4.0.0
+            `pos` and `len` now also accept Columns or names of Columns.
+
     Returns
     -------
     :class:`~pyspark.sql.Column`
@@ -10949,9 +10954,18 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column:
     >>> df = spark.createDataFrame([('abcd',)], ['s',])
     >>> df.select(substring(df.s, 1, 2).alias('s')).collect()
     [Row(s='ab')]
+    >>> df = spark.createDataFrame([('Spark', 2, 3)], ['s', 'p', 'l'])
+    >>> df.select(substring(df.s, 2, df.l).alias('s')).collect()
+    [Row(s='par')]
+    >>> df.select(substring(df.s, df.p, 3).alias('s')).collect()
+    [Row(s='par')]
+    >>> df.select(substring(df.s, df.p, df.l).alias('s')).collect()
+    [Row(s='par')]
     """
     from pyspark.sql.classic.column import _to_java_column
 
+    pos = _to_java_column(lit(pos) if isinstance(pos, int) else pos)
+    len = _to_java_column(lit(len) if isinstance(len, int) else len)
     return _invoke_function("substring", _to_java_column(str), pos, len)
 
 
@@ -13969,7 +13983,10 @@ def array_position(col: "ColumnOrName", value: Any) -> Column:
     col : :class:`~pyspark.sql.Column` or str
         target column to work on.
     value : Any
-        value to look for.
+        value or a :class:`~pyspark.sql.Column` expression to look for.
+
+        .. versionchanged:: 4.0.0
+            `value` now also accepts a Column type.
 
     Returns
     -------
@@ -14034,9 +14051,22 @@ def array_position(col: "ColumnOrName", value: Any) -> Column:
     +-----------------------+
     |                      3|
     +-----------------------+
+
+    Example 6: Finding the position of a column's value in an array of integers
+
+    >>> from pyspark.sql import functions as sf
+    >>> df = spark.createDataFrame([([10, 20, 30], 20)], ['data', 'col'])
+    >>> df.select(sf.array_position(df.data, df.col)).show()
+    +-------------------------+
+    |array_position(data, col)|
+    +-------------------------+
+    |                        2|
+    +-------------------------+
+
     """
     from pyspark.sql.classic.column import _to_java_column
 
+    value = _to_java_column(value) if isinstance(value, Column) else value
     return _invoke_function("array_position", _to_java_column(col), value)
 
 
@@ -14402,7 +14432,10 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column:
     col : :class:`~pyspark.sql.Column` or str
         name of column containing array
     element :
-        element to be removed from the array
+        element or a :class:`~pyspark.sql.Column` expression to be removed from the array
+
+        .. versionchanged:: 4.0.0
+            `element` now also accepts a Column type.
 
     Returns
     -------
@@ -14470,9 +14503,21 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column:
     +---------------------+
     |                   []|
     +---------------------+
+
+    Example 6: Removing a column's value from a simple array
+
+    >>> from pyspark.sql import functions as sf
+    >>> df = spark.createDataFrame([([1, 2, 3, 1, 1], 1)], ['data', 'col'])
+    >>> df.select(sf.array_remove(df.data, df.col)).show()
+    +-----------------------+
+    |array_remove(data, col)|
+    +-----------------------+
+    |                 [2, 3]|
+    +-----------------------+
     """
     from pyspark.sql.classic.column import _to_java_column
 
+    element = _to_java_column(element) if isinstance(element, Column) else element
     return _invoke_function("array_remove", _to_java_column(col), element)
 
 
@@ -17237,7 +17282,10 @@ def map_contains_key(col: "ColumnOrName", value: Any) -> Column:
     col : :class:`~pyspark.sql.Column` or str
         The name of the column or an expression that represents the map.
     value :
-        A literal value.
+        A literal value, or a :class:`~pyspark.sql.Column` expression.
+
+        .. versionchanged:: 4.0.0
+            `value` now also accepts a Column type.
 
     Returns
     -------
@@ -17267,9 +17315,21 @@ def map_contains_key(col: "ColumnOrName", value: Any) -> Column:
     +--------------------------+
     |                     false|
     +--------------------------+
+
+    Example 3: Check for key using a column
+
+    >>> from pyspark.sql import functions as sf
+    >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data, 1 as key")
+    >>> df.select(sf.map_contains_key("data", sf.col("key"))).show()
+    +---------------------------+
+    |map_contains_key(data, key)|
+    +---------------------------+
+    |                       true|
+    +---------------------------+
     """
     from pyspark.sql.classic.column import _to_java_column
 
+    value = _to_java_column(value) if isinstance(value, Column) else value
     return _invoke_function("map_contains_key", _to_java_column(col), value)
 
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 52733611e42a..882918eb78c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4234,6 +4234,19 @@ object functions {
   def substring(str: Column, pos: Int, len: Int): Column =
     Column.fn("substring", str, lit(pos), lit(len))
 
+  /**
+   * Substring starts at `pos` and is of length `len` when str is String type or
+   * returns the slice of byte array that starts at `pos` in byte and is of length `len`
+   * when str is Binary type
+   *
+   * @note The position is not zero based, but 1 based index.
+   *
+   * @group string_funcs
+   * @since 4.0.0
+   */
+  def substring(str: Column, pos: Column, len: Column): Column =
+    Column.fn("substring", str, pos, len)
+
   /**
    * Returns the substring from string str before count occurrences of the delimiter delim.
    * If count is positive, everything the left of the final delimiter (counting from left) is
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 3fc0b572d80b..31c1cac9fb71 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -332,6 +332,11 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
     // scalastyle:on
   }
 
+  test("string substring function using columns") {
+    val df = Seq(("Spark", 2, 3)).toDF("a", "b", "c")
+    checkAnswer(df.select(substring($"a", $"b", $"c")), Row("par"))
+  }
+
   test("string encode/decode function") {
     val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116)
     // scalastyle:off


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
