This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6ef15f9d075 [SPARK-38822][PYSPARK] Raise indexError when insert loc is
out of bounds
6ef15f9d075 is described below
commit 6ef15f9d075bf1735131f22b890043829c623f7f
Author: Yikun Jiang <[email protected]>
AuthorDate: Wed Apr 13 10:42:46 2022 +0900
[SPARK-38822][PYSPARK] Raise indexError when insert loc is out of bounds
### What changes were proposed in this pull request?
Since pandas 1.4.0, pandas uses `numpy.insert`, and `numpy.insert` raises
`IndexError` when the insert loc is out of bounds.
Related changes:
- pandas 1.4.0+ uses `numpy.insert`:
https://github.com/pandas-dev/pandas/commit/c021d33ecf0e096a186edb731964767e9288a875
- Since numpy 1.8 (nearly 10 years ago,
https://github.com/numpy/numpy/commit/908e06c3c465434023649b0ca522836580c5cfdc),
[`out-of-bound indices will generate an error.`](https://numpy.org/devdocs/release/1.8.0-notes.html#changes)
### Why are the changes needed?
To follow pandas behavior.
### Does this PR introduce _any_ user-facing change?
Yes. An `IndexError` is now raised when the insert loc is out of bounds,
following pandas behavior.
### How was this patch tested?
Unit tests (UT).
Closes #36115 from Yikun/SPARK-38822.
Authored-by: Yikun Jiang <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/pandas/indexes/base.py | 7 +++---
python/pyspark/pandas/indexes/multi.py | 17 +++----------
python/pyspark/pandas/tests/indexes/test_base.py | 32 ++++++++++++++++++------
python/pyspark/pandas/tests/test_utils.py | 13 ++++++++++
python/pyspark/pandas/utils.py | 19 ++++++++++++++
5 files changed, 62 insertions(+), 26 deletions(-)
diff --git a/python/pyspark/pandas/indexes/base.py
b/python/pyspark/pandas/indexes/base.py
index fd1c2dff032..c8be0b436fa 100644
--- a/python/pyspark/pandas/indexes/base.py
+++ b/python/pyspark/pandas/indexes/base.py
@@ -73,6 +73,7 @@ from pyspark.pandas.utils import (
scol_for,
verify_temp_column_name,
validate_bool_kwarg,
+ validate_index_loc,
ERROR_MESSAGE_CANNOT_COMBINE,
log_advice,
)
@@ -2544,10 +2545,8 @@ class Index(IndexOpsMixin):
>>> psidx.insert(-3, 100)
Int64Index([1, 2, 100, 3, 4, 5], dtype='int64')
"""
- if loc < 0:
- length = len(self)
- loc = loc + length
- loc = 0 if loc < 0 else loc
+ validate_index_loc(self, loc)
+ loc = loc + len(self) if loc < 0 else loc
index_name = self._internal.index_spark_column_names[0]
sdf_before = self.to_frame(name=index_name)[:loc]._to_spark()
diff --git a/python/pyspark/pandas/indexes/multi.py
b/python/pyspark/pandas/indexes/multi.py
index e3aea0d075d..43ced2f4541 100644
--- a/python/pyspark/pandas/indexes/multi.py
+++ b/python/pyspark/pandas/indexes/multi.py
@@ -38,6 +38,7 @@ from pyspark.pandas.utils import (
name_like_string,
scol_for,
verify_temp_column_name,
+ validate_index_loc,
)
from pyspark.pandas.internal import (
InternalField,
@@ -1108,20 +1109,8 @@ class MultiIndex(Index):
('c', 'z')],
)
"""
- length = len(self)
- if loc < 0:
- loc = loc + length
- if loc < 0:
- raise IndexError(
- "index {} is out of bounds for axis 0 with size {}".format(
- (loc - length), length
- )
- )
- else:
- if loc > length:
- raise IndexError(
- "index {} is out of bounds for axis 0 with size
{}".format(loc, length)
- )
+ validate_index_loc(self, loc)
+ loc = loc + len(self) if loc < 0 else loc
index_name: List[Label] = [(name,) for name in
self._internal.index_spark_column_names]
sdf_before = self.to_frame(name=index_name)[:loc]._to_spark()
diff --git a/python/pyspark/pandas/tests/indexes/test_base.py
b/python/pyspark/pandas/tests/indexes/test_base.py
index 3e03bbc028c..de138b58c68 100644
--- a/python/pyspark/pandas/tests/indexes/test_base.py
+++ b/python/pyspark/pandas/tests/indexes/test_base.py
@@ -2191,32 +2191,48 @@ class IndexesTest(ComparisonTestBase, TestUtils):
psidx = ps.from_pandas(pidx)
self.assert_eq(pidx.insert(1, 100), psidx.insert(1, 100))
self.assert_eq(pidx.insert(-1, 100), psidx.insert(-1, 100))
- self.assert_eq(pidx.insert(100, 100), psidx.insert(100, 100))
- self.assert_eq(pidx.insert(-100, 100), psidx.insert(-100, 100))
+ err_msg = "index 100 is out of bounds for axis 0 with size 3"
+ with self.assertRaisesRegex(IndexError, err_msg):
+ psidx.insert(100, 100)
+ err_msg = "index -100 is out of bounds for axis 0 with size 3"
+ with self.assertRaisesRegex(IndexError, err_msg):
+ psidx.insert(-100, 100)
# Floating
pidx = pd.Index([1.0, 2.0, 3.0], name="Koalas")
psidx = ps.from_pandas(pidx)
self.assert_eq(pidx.insert(1, 100.0), psidx.insert(1, 100.0))
self.assert_eq(pidx.insert(-1, 100.0), psidx.insert(-1, 100.0))
- self.assert_eq(pidx.insert(100, 100.0), psidx.insert(100, 100.0))
- self.assert_eq(pidx.insert(-100, 100.0), psidx.insert(-100, 100.0))
+ err_msg = "index 100 is out of bounds for axis 0 with size 3"
+ with self.assertRaisesRegex(IndexError, err_msg):
+ psidx.insert(100, 100)
+ err_msg = "index -100 is out of bounds for axis 0 with size 3"
+ with self.assertRaisesRegex(IndexError, err_msg):
+ psidx.insert(-100, 100)
# String
pidx = pd.Index(["a", "b", "c"], name="Koalas")
psidx = ps.from_pandas(pidx)
self.assert_eq(pidx.insert(1, "x"), psidx.insert(1, "x"))
self.assert_eq(pidx.insert(-1, "x"), psidx.insert(-1, "x"))
- self.assert_eq(pidx.insert(100, "x"), psidx.insert(100, "x"))
- self.assert_eq(pidx.insert(-100, "x"), psidx.insert(-100, "x"))
+ err_msg = "index 100 is out of bounds for axis 0 with size 3"
+ with self.assertRaisesRegex(IndexError, err_msg):
+ psidx.insert(100, "x")
+ err_msg = "index -100 is out of bounds for axis 0 with size 3"
+ with self.assertRaisesRegex(IndexError, err_msg):
+ psidx.insert(-100, "x")
# Boolean
pidx = pd.Index([True, False, True, False], name="Koalas")
psidx = ps.from_pandas(pidx)
self.assert_eq(pidx.insert(1, True), psidx.insert(1, True))
self.assert_eq(pidx.insert(-1, True), psidx.insert(-1, True))
- self.assert_eq(pidx.insert(100, True), psidx.insert(100, True))
- self.assert_eq(pidx.insert(-100, True), psidx.insert(-100, True))
+ err_msg = "index 100 is out of bounds for axis 0 with size 4"
+ with self.assertRaisesRegex(IndexError, err_msg):
+ psidx.insert(100, True)
+ err_msg = "index -100 is out of bounds for axis 0 with size 4"
+ with self.assertRaisesRegex(IndexError, err_msg):
+ psidx.insert(-100, True)
# MultiIndex
pmidx = pd.MultiIndex.from_tuples(
diff --git a/python/pyspark/pandas/tests/test_utils.py
b/python/pyspark/pandas/tests/test_utils.py
index b601c695476..11f560c6f55 100644
--- a/python/pyspark/pandas/tests/test_utils.py
+++ b/python/pyspark/pandas/tests/test_utils.py
@@ -17,10 +17,12 @@
import pandas as pd
+from pyspark.pandas.indexes.base import Index
from pyspark.pandas.utils import (
lazy_property,
validate_arguments_and_invoke_function,
validate_bool_kwarg,
+ validate_index_loc,
validate_mode,
)
from pyspark.testing.pandasutils import PandasOnSparkTestCase
@@ -92,6 +94,17 @@ class UtilsTest(PandasOnSparkTestCase, SQLTestUtils):
with self.assertRaises(ValueError):
validate_mode("r")
+ def test_validate_index_loc(self):
+ psidx = Index([1, 2, 3])
+ validate_index_loc(psidx, -1)
+ validate_index_loc(psidx, -3)
+ err_msg = "index 4 is out of bounds for axis 0 with size 3"
+ with self.assertRaisesRegex(IndexError, err_msg):
+ validate_index_loc(psidx, 4)
+ err_msg = "index -4 is out of bounds for axis 0 with size 3"
+ with self.assertRaisesRegex(IndexError, err_msg):
+ validate_index_loc(psidx, -4)
+
class TestClassForLazyProp:
def __init__(self):
diff --git a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py
index a61ea7d19b3..0a7831b94ff 100644
--- a/python/pyspark/pandas/utils.py
+++ b/python/pyspark/pandas/utils.py
@@ -49,6 +49,7 @@ from pyspark.pandas.spark import functions as SF
from pyspark.pandas.typedef.typehints import as_spark_type
if TYPE_CHECKING:
+ from pyspark.pandas.indexes.base import Index
from pyspark.pandas.base import IndexOpsMixin
from pyspark.pandas.frame import DataFrame
from pyspark.pandas.internal import InternalFrame
@@ -975,6 +976,24 @@ def log_advice(message: str) -> None:
warnings.warn(message, PandasAPIOnSparkAdviceWarning)
+def validate_index_loc(index: "Index", loc: int) -> None:
+ """
+ Raises IndexError if index is out of bounds
+ """
+ length = len(index)
+ if loc < 0:
+ loc = loc + length
+ if loc < 0:
+ raise IndexError(
+ "index {} is out of bounds for axis 0 with size
{}".format((loc - length), length)
+ )
+ else:
+ if loc > length:
+ raise IndexError(
+ "index {} is out of bounds for axis 0 with size
{}".format(loc, length)
+ )
+
+
def _test() -> None:
import os
import doctest
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]