This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5633e9312137 [SPARK-52288][PS] Avoid INVALID_ARRAY_INDEX in `split`/`rsplit` when ANSI mode is on
5633e9312137 is described below

commit 5633e9312137ed648609023e6af5ccae56b88986
Author: Xinrong Meng <[email protected]>
AuthorDate: Fri Jun 6 11:22:39 2025 -0700

    [SPARK-52288][PS] Avoid INVALID_ARRAY_INDEX in `split`/`rsplit` when ANSI mode is on
    
    ### What changes were proposed in this pull request?
    Use `F.try_element_at` instead of direct array indexing in `split`/`rsplit` with `expand=True`, so that out-of-bounds positions return NULL instead of raising INVALID_ARRAY_INDEX when ANSI mode is on.
    
    ### Why are the changes needed?
    Ensure the pandas API on Spark works correctly when ANSI mode is enabled.
    Part of https://issues.apache.org/jira/browse/SPARK-52169.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. `split`/`rsplit` no longer fail with INVALID_ARRAY_INDEX when ANSI mode is on.
    
    ```py
    >>> spark.conf.get("spark.sql.ansi.enabled")
    'true'
    >>> import pandas as pd
    >>> pser = pd.Series(["hello-world", "short"])
    >>> psser = ps.from_pandas(pser)
    ```
    
    FROM
    ```py
    >>> psser.str.split("-", n=1, expand=True)
    25/05/28 14:52:10 ERROR Executor: Exception in task 10.0 in stage 2.0 (TID 15)
    org.apache.spark.SparkArrayIndexOutOfBoundsException: [INVALID_ARRAY_INDEX] The index 1 is out of bounds. The array has 1 elements. Use the SQL function `get()` to tolerate accessing element at invalid index and return NULL instead. SQLSTATE: 22003
    == DataFrame ==
    "__getitem__" was called from
    <stdin>:1
    ...
    ```
    TO
    ```py
    >>> psser.str.split("-", n=1, expand=True)
           0      1
    0  hello  world
    1  short   None
    ```
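
    The error hint above mentions the SQL function `get()`; this patch achieves the same tolerance with `F.try_element_at`, which returns NULL for out-of-bounds positions even under ANSI mode. A quick SQL-level illustration (a sketch, assuming an active ANSI-enabled `spark` session; note that `try_element_at` is 1-based while the `[...]` subscript is 0-based, hence the `i + 1` in the patch):

    ```py
    >>> spark.sql("SELECT try_element_at(split('short', '-'), 2) AS v").collect()
    [Row(v=None)]
    >>> # The 0-based subscript fails instead under ANSI mode:
    >>> spark.sql("SELECT split('short', '-')[1] AS v").collect()  # raises INVALID_ARRAY_INDEX
    ```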
    
    ### How was this patch tested?
    Unit tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #51006 from xinrong-meng/arr_idx_enable.
    
    Authored-by: Xinrong Meng <[email protected]>
    Signed-off-by: Takuya Ueshin <[email protected]>
---
 python/pyspark/pandas/strings.py                        | 17 +++++++++++++++--
 .../pyspark/pandas/tests/series/test_string_ops_adv.py  |  7 ++++---
 2 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/pandas/strings.py b/python/pyspark/pandas/strings.py
index 7e572bd1fae3..dc1544d8be39 100644
--- a/python/pyspark/pandas/strings.py
+++ b/python/pyspark/pandas/strings.py
@@ -32,6 +32,7 @@ from typing import (
 import numpy as np
 import pandas as pd
 
+from pyspark.pandas.utils import is_ansi_mode_enabled
 from pyspark.sql.types import StringType, BinaryType, ArrayType, LongType, MapType
 from pyspark.sql import functions as F
 from pyspark.sql.functions import pandas_udf
@@ -2031,7 +2032,13 @@ class StringMethods:
         if expand:
             psdf = psser.to_frame()
             scol = psdf._internal.data_spark_columns[0]
-            spark_columns = [scol[i].alias(str(i)) for i in range(n + 1)]
+            spark_session = self._data._internal.spark_frame.sparkSession
+            if is_ansi_mode_enabled(spark_session):
+                spark_columns = [
+                    F.try_element_at(scol, F.lit(i + 1)).alias(str(i)) for i in range(n + 1)
+                ]
+            else:
+                spark_columns = [scol[i].alias(str(i)) for i in range(n + 1)]
             column_labels = [(i,) for i in range(n + 1)]
             internal = psdf._internal.with_new_columns(
                 spark_columns,
@@ -2178,7 +2185,13 @@ class StringMethods:
         if expand:
             psdf = psser.to_frame()
             scol = psdf._internal.data_spark_columns[0]
-            spark_columns = [scol[i].alias(str(i)) for i in range(n + 1)]
+            spark_session = self._data._internal.spark_frame.sparkSession
+            if is_ansi_mode_enabled(spark_session):
+                spark_columns = [
+                    F.try_element_at(scol, F.lit(i + 1)).alias(str(i)) for i in range(n + 1)
+                ]
+            else:
+                spark_columns = [scol[i].alias(str(i)) for i in range(n + 1)]
             column_labels = [(i,) for i in range(n + 1)]
             internal = psdf._internal.with_new_columns(
                 spark_columns,
diff --git a/python/pyspark/pandas/tests/series/test_string_ops_adv.py b/python/pyspark/pandas/tests/series/test_string_ops_adv.py
index e00252110dae..b0e4c69a35ea 100644
--- a/python/pyspark/pandas/tests/series/test_string_ops_adv.py
+++ b/python/pyspark/pandas/tests/series/test_string_ops_adv.py
@@ -22,7 +22,6 @@ import re
 from pyspark import pandas as ps
 from pyspark.testing.pandasutils import PandasOnSparkTestCase
 from pyspark.testing.sqlutils import SQLTestUtils
-from pyspark.testing.utils import is_ansi_mode_test, ansi_mode_not_supported_message
 
 
 class SeriesStringOpsAdvMixin:
@@ -174,7 +173,6 @@ class SeriesStringOpsAdvMixin:
         self.check_func(lambda x: x.str.slice_replace(stop=2, repl="X"))
         self.check_func(lambda x: x.str.slice_replace(start=1, stop=3, repl="X"))
 
-    @unittest.skipIf(is_ansi_mode_test, ansi_mode_not_supported_message)
     def test_string_split(self):
         self.check_func_on_series(lambda x: repr(x.str.split()), self.pser[:-1])
         self.check_func_on_series(lambda x: repr(x.str.split(r"p*")), self.pser[:-1])
@@ -185,7 +183,8 @@ class SeriesStringOpsAdvMixin:
         with self.assertRaises(NotImplementedError):
             self.check_func(lambda x: x.str.split(expand=True))
 
-    @unittest.skipIf(is_ansi_mode_test, ansi_mode_not_supported_message)
+        self.check_func_on_series(lambda x: repr(x.str.split("-", n=1, expand=True)), pser)
+
     def test_string_rsplit(self):
         self.check_func_on_series(lambda x: repr(x.str.rsplit()), self.pser[:-1])
         self.check_func_on_series(lambda x: repr(x.str.rsplit(r"p*")), self.pser[:-1])
@@ -196,6 +195,8 @@ class SeriesStringOpsAdvMixin:
         with self.assertRaises(NotImplementedError):
             self.check_func(lambda x: x.str.rsplit(expand=True))
 
+        self.check_func_on_series(lambda x: repr(x.str.rsplit("-", n=1, expand=True)), pser)
+
     def test_string_translate(self):
         m = str.maketrans({"a": "X", "e": "Y", "i": None})
         self.check_func(lambda x: x.str.translate(m))
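
For readers who want to apply the same pattern outside this patch, here is a minimal, self-contained sketch (not the committed code; the session setup, DataFrame, and column name below are invented for illustration, and `F.try_element_at` requires Spark 3.5+):

```py
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.ansi.enabled", "true")

df = spark.createDataFrame([("hello-world",), ("short",)], ["s"])
arr = F.split(F.col("s"), "-", 2)  # limit=2 corresponds to n=1 in pandas split

# Direct 0-based indexing (arr[1]) raises INVALID_ARRAY_INDEX under ANSI
# mode when the array is too short; F.try_element_at (1-based, hence the
# i + 1 in the patch) returns NULL instead.
df.select(
    F.try_element_at(arr, F.lit(1)).alias("0"),
    F.try_element_at(arr, F.lit(2)).alias("1"),
).show()
# "hello-world" -> ("hello", "world"); "short" -> ("short", NULL)
```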

