This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 829045efbe82 [SPARK-46163][PS] DataFrame.update parameters filter_func and errors
829045efbe82 is described below
commit 829045efbe820fba017ca108b7dab153425156b0
Author: Devin Petersohn <[email protected]>
AuthorDate: Thu Feb 19 10:10:36 2026 +0900
[SPARK-46163][PS] DataFrame.update parameters filter_func and errors
### What changes were proposed in this pull request?
DataFrame.update parameters filter_func and errors
### Why are the changes needed?
To add missing parameters to `update` function
### Does this PR introduce _any_ user-facing change?
Yes, new parameter implementation
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
Co-authored-by: Claude Sonnet 4.5
Closes #54287 from devin-petersohn/devin/update_params.
Authored-by: Devin Petersohn <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/pandas/frame.py | 81 ++++++++++++++--
.../pandas/tests/computation/test_combine.py | 107 +++++++++++++++++++++
2 files changed, 181 insertions(+), 7 deletions(-)
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 62156d068c5d..4cbf4a8a4530 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -9132,12 +9132,21 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
)
return DataFrame(internal)
- # TODO(SPARK-46163): add 'filter_func' and 'errors' parameter
-    def update(self, other: "DataFrame", join: str = "left", overwrite: bool = True) -> None:
+ def update(
+ self,
+ other: "DataFrame",
+ join: str = "left",
+ overwrite: bool = True,
+ filter_func: Optional[Callable[[Any], bool]] = None,
+ errors: str = "ignore",
+ ) -> None:
"""
Modify in place using non-NA values from another DataFrame.
Aligns on indices. There is no return value.
+        .. note:: When ``errors='raise'``, this method forces materialization to check
+            for overlapping non-NA data, which may impact performance on large datasets.
+
Parameters
----------
other : DataFrame, or Series
@@ -9149,10 +9158,23 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
* True: overwrite original DataFrame's values with values from
`other`.
* False: only update values that are NA in the original DataFrame.
+ filter_func : callable(1d-array) -> bool 1d-array, optional
+ Can choose to replace values other than NA. Return True for values
+ which should be updated. Applied to original DataFrame's values.
+ errors : {'ignore', 'raise'}, default 'ignore'
+ If 'raise', will raise a ValueError if the DataFrame and other both
+ contain non-NA data in the same place.
+
Returns
-------
None : method directly changes calling object
+ Raises
+ ------
+ ValueError
+ If errors='raise' and overlapping non-NA data is detected.
+ If errors is not 'ignore' or 'raise'.
+
See Also
--------
DataFrame.merge : For column(s)-on-columns(s) operations.
@@ -9204,9 +9226,22 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
0 1 4.0
1 2 500.0
2 3 6.0
+
+ Using filter_func to selectively update values:
+
+ >>> df = ps.DataFrame({'A': [1, 2, 3], 'B': [400, 500, 600]})
+ >>> new_df = ps.DataFrame({'B': [4, 5, 6]})
+ >>> df.update(new_df, filter_func=lambda x: x > 450)
+ >>> df.sort_index()
+ A B
+ 0 1 400
+ 1 2 5
+ 2 3 6
"""
if join != "left":
raise NotImplementedError("Only left join is supported")
+ if errors not in ("ignore", "raise"):
+ raise ValueError("errors must be either 'ignore' or 'raise'")
if isinstance(other, ps.Series):
other = other.to_frame()
@@ -9218,6 +9253,28 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
other[update_columns], rsuffix="_new"
)._internal.resolved_copy.spark_frame
+        if errors == "raise" and update_columns:
+            from pyspark.sql.types import BooleanType
+
+            any_overlap = F.lit(False)
+            for column_labels in update_columns:
+                column_name = self._internal.spark_column_name_for(column_labels)
+                old_col = scol_for(update_sdf, column_name)
+                new_col = scol_for(
+                    update_sdf, other._internal.spark_column_name_for(column_labels) + "_new"
+                )
+
+                overlap = old_col.isNotNull() & new_col.isNotNull()
+                if filter_func is not None:
+                    overlap = overlap & pandas_udf(  # type: ignore[call-overload]
+                        filter_func, BooleanType()
+                    )(old_col)
+
+                any_overlap = any_overlap | overlap
+
+            if update_sdf.select(F.max(F.when(any_overlap, 1).otherwise(0))).first()[0]:
+                raise ValueError("Data overlaps.")
+
data_fields = self._internal.data_fields.copy()
for column_labels in update_columns:
column_name = self._internal.spark_column_name_for(column_labels)
@@ -9225,14 +9282,24 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
new_col = scol_for(
update_sdf,
other._internal.spark_column_name_for(column_labels) + "_new"
)
-            if overwrite:
-                update_sdf = update_sdf.withColumn(
-                    column_name, F.when(new_col.isNull(), old_col).otherwise(new_col)
+
+            if filter_func is not None:
+                from pyspark.sql.types import BooleanType
+
+                mask = pandas_udf(filter_func, BooleanType())(old_col)  # type: ignore[call-overload]
+                updated_col = (
+                    F.when(new_col.isNull() | mask.isNull() | ~mask, old_col).otherwise(new_col)
+                    if overwrite
+                    else F.when(old_col.isNull() & mask, new_col).otherwise(old_col)
                 )
             else:
-                update_sdf = update_sdf.withColumn(
-                    column_name, F.when(old_col.isNull(), new_col).otherwise(old_col)
+                updated_col = (
+                    F.when(new_col.isNull(), old_col).otherwise(new_col)
+                    if overwrite
+                    else F.when(old_col.isNull(), new_col).otherwise(old_col)
                 )
+
+            update_sdf = update_sdf.withColumn(column_name, updated_col)
             data_fields[self._internal.column_labels.index(column_labels)] = None
sdf = update_sdf.select(
*[scol_for(update_sdf, col) for col in
self._internal.spark_column_names],
diff --git a/python/pyspark/pandas/tests/computation/test_combine.py b/python/pyspark/pandas/tests/computation/test_combine.py
index f9a09c94b4fc..b94e49cfa1a6 100644
--- a/python/pyspark/pandas/tests/computation/test_combine.py
+++ b/python/pyspark/pandas/tests/computation/test_combine.py
@@ -658,6 +658,113 @@ class FrameCombineMixin:
left_psdf.sort_values(by=[("X", "A"), ("X", "B")]),
)
+ def test_update_with_filter_func(self):
+ # Test filter_func parameter
+ left_pdf = pd.DataFrame({"A": [1, 2, 3, 4], "B": [10, 20, 30, 40]})
+ right_pdf = pd.DataFrame({"B": [100, 200, 300, 400]})
+
+ left_psdf = ps.from_pandas(left_pdf)
+ right_psdf = ps.from_pandas(right_pdf)
+
+ # Only update values > 25
+ left_pdf.update(right_pdf, filter_func=lambda x: x > 25)
+ left_psdf.update(right_psdf, filter_func=lambda x: x > 25)
+
+ self.assert_eq(left_pdf.sort_index(), left_psdf.sort_index())
+
+ def test_update_filter_func_overwrite_false(self):
+ # Test filter_func with overwrite=False
+ left_pdf = pd.DataFrame({"A": [1, 2, 3], "B": [None, 20, None]})
+ right_pdf = pd.DataFrame({"B": [100, 200, 300]})
+
+ left_psdf = ps.from_pandas(left_pdf)
+ right_psdf = ps.from_pandas(right_pdf)
+
+ # Only update where new value > 150 (and old is null)
+        left_pdf.update(right_pdf, overwrite=False, filter_func=lambda x: x > 150)
+        left_psdf.update(right_psdf, overwrite=False, filter_func=lambda x: x > 150)
+
+ self.assert_eq(left_pdf.sort_index(), left_psdf.sort_index())
+
+ def test_update_errors_raise_with_overlap(self):
+ # Test that errors='raise' raises ValueError on overlap
+ left_psdf = ps.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
+ right_psdf = ps.DataFrame({"B": [100, 200, 300]})
+
+ # Should raise because both have non-null values
+ with self.assertRaisesRegex(ValueError, "Data overlaps."):
+ left_psdf.update(right_psdf, errors="raise")
+
+ def test_update_errors_raise_no_overlap(self):
+ # Test that errors='raise' works when no overlap
+ left_pdf = pd.DataFrame({"A": [1, 2, 3], "B": [None, None, 30]})
+ right_pdf = pd.DataFrame({"B": [100, 200, None]})
+
+ left_psdf = ps.from_pandas(left_pdf)
+ right_psdf = ps.from_pandas(right_pdf)
+
+ left_pdf.update(right_pdf, errors="raise")
+ left_psdf.update(right_psdf, errors="raise")
+
+ self.assert_eq(left_pdf.sort_index(), left_psdf.sort_index())
+
+ def test_update_errors_invalid_value(self):
+ # Test that invalid errors parameter raises ValueError
+ left_psdf = ps.DataFrame({"A": [1, 2, 3]})
+ right_psdf = ps.DataFrame({"A": [4, 5, 6]})
+
+        with self.assertRaisesRegex(ValueError, "errors must be either 'ignore' or 'raise'"):
+ left_psdf.update(right_psdf, errors="invalid")
+
+ def test_update_filter_func_and_errors_raise(self):
+ # Test combination of filter_func and errors='raise'
+ left_psdf = ps.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
+ right_psdf = ps.DataFrame({"B": [100, 200, 300]})
+
+ # Filter only values < 25 - should find overlaps at positions 0 and 1
+ with self.assertRaisesRegex(ValueError, "Data overlaps."):
+            left_psdf.update(right_psdf, filter_func=lambda x: x < 25, errors="raise")
+
+ # Filter only values > 100 - no overlaps since no original values > 100
+ left_psdf2 = ps.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
+ right_psdf2 = ps.DataFrame({"B": [100, 200, 300]})
+
+ # Should not raise - no values in original DataFrame match filter
+        left_psdf2.update(right_psdf2, filter_func=lambda x: x > 100, errors="raise")
+
+ def test_update_filter_func_all_false(self):
+ # Test filter_func that returns all False
+ left_pdf = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
+ right_pdf = pd.DataFrame({"B": [100, 200, 300]})
+
+ left_psdf = ps.from_pandas(left_pdf.copy())
+ right_psdf = ps.from_pandas(right_pdf)
+
+ # Filter that matches nothing
+ original_left_pdf = left_pdf.copy()
+ original_left_psdf = left_psdf.copy()
+
+ left_pdf.update(right_pdf, filter_func=lambda x: x > 1000)
+ left_psdf.update(right_psdf, filter_func=lambda x: x > 1000)
+
+ # DataFrame should be unchanged
+ self.assert_eq(left_pdf.sort_index(), original_left_pdf.sort_index())
+ self.assert_eq(left_psdf.sort_index(), original_left_psdf.sort_index())
+
+ def test_update_filter_func_with_nulls(self):
+ # Test filter_func handling of null values
+ left_pdf = pd.DataFrame({"A": [1, 2, 3], "B": [None, 20, None]})
+ right_pdf = pd.DataFrame({"B": [100, 200, 300]})
+
+ left_psdf = ps.from_pandas(left_pdf)
+ right_psdf = ps.from_pandas(right_pdf)
+
+ # Filter values > 10 (nulls will not match)
+ left_pdf.update(right_pdf, filter_func=lambda x: x > 10)
+ left_psdf.update(right_psdf, filter_func=lambda x: x > 10)
+
+ self.assert_eq(left_pdf.sort_index(), left_psdf.sort_index())
+
class FrameCombineTests(
FrameCombineMixin,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]