This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 829045efbe82 [SPARK-46163][PS] DataFrame.update parameters filter_func and errors
829045efbe82 is described below

commit 829045efbe820fba017ca108b7dab153425156b0
Author: Devin Petersohn <[email protected]>
AuthorDate: Thu Feb 19 10:10:36 2026 +0900

    [SPARK-46163][PS] DataFrame.update parameters filter_func and errors
    
    ### What changes were proposed in this pull request?
    
    Add the `filter_func` and `errors` parameters to `DataFrame.update`, which previously accepted only `other`, `join`, and `overwrite`.
    
    ### Why are the changes needed?
    
    To add the `filter_func` and `errors` parameters that `DataFrame.update` was missing (previously tracked by a TODO for SPARK-46163).
    
    ### Does this PR introduce _any_ user-facing change?
    
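    Yes. `DataFrame.update` now accepts two new parameters:
    - `filter_func` restricts updates to original values for which it returns True.
    - `errors='raise'` raises a `ValueError` when both DataFrames contain non-NA data in the same place.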
    
    ### How was this patch tested?
    
    Added unit tests to `python/pyspark/pandas/tests/computation/test_combine.py`; verified by CI.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Co-authored-by: Claude Sonnet 4.5
    
    Closes #54287 from devin-petersohn/devin/update_params.
    
    Authored-by: Devin Petersohn <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/pandas/frame.py                     |  81 ++++++++++++++--
 .../pandas/tests/computation/test_combine.py       | 107 +++++++++++++++++++++
 2 files changed, 181 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 62156d068c5d..4cbf4a8a4530 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -9132,12 +9132,21 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         )
         return DataFrame(internal)
 
-    # TODO(SPARK-46163): add 'filter_func' and 'errors' parameter
-    def update(self, other: "DataFrame", join: str = "left", overwrite: bool = 
True) -> None:
+    def update(
+        self,
+        other: "DataFrame",
+        join: str = "left",
+        overwrite: bool = True,
+        filter_func: Optional[Callable[[Any], bool]] = None,
+        errors: str = "ignore",
+    ) -> None:
         """
         Modify in place using non-NA values from another DataFrame.
         Aligns on indices. There is no return value.
 
+        .. note:: When ``errors='raise'``, this method forces materialization 
to check
+            for overlapping non-NA data, which may impact performance on large 
datasets.
+
         Parameters
         ----------
         other : DataFrame, or Series
@@ -9149,10 +9158,23 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             * True: overwrite original DataFrame's values with values from 
`other`.
             * False: only update values that are NA in the original DataFrame.
 
+        filter_func : callable(1d-array) -> bool 1d-array, optional
+            Can choose to replace values other than NA. Return True for values
+            which should be updated. Applied to original DataFrame's values.
+        errors : {'ignore', 'raise'}, default 'ignore'
+            If 'raise', will raise a ValueError if the DataFrame and other both
+            contain non-NA data in the same place.
+
         Returns
         -------
         None : method directly changes calling object
 
+        Raises
+        ------
+        ValueError
+            If errors='raise' and overlapping non-NA data is detected.
+            If errors is not 'ignore' or 'raise'.
+
         See Also
         --------
         DataFrame.merge : For column(s)-on-columns(s) operations.
@@ -9204,9 +9226,22 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         0  1    4.0
         1  2  500.0
         2  3    6.0
+
+        Using filter_func to selectively update values:
+
+        >>> df = ps.DataFrame({'A': [1, 2, 3], 'B': [400, 500, 600]})
+        >>> new_df = ps.DataFrame({'B': [4, 5, 6]})
+        >>> df.update(new_df, filter_func=lambda x: x > 450)
+        >>> df.sort_index()
+           A    B
+        0  1  400
+        1  2    5
+        2  3    6
         """
         if join != "left":
             raise NotImplementedError("Only left join is supported")
+        if errors not in ("ignore", "raise"):
+            raise ValueError("errors must be either 'ignore' or 'raise'")
 
         if isinstance(other, ps.Series):
             other = other.to_frame()
@@ -9218,6 +9253,28 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             other[update_columns], rsuffix="_new"
         )._internal.resolved_copy.spark_frame
 
+        if errors == "raise" and update_columns:
+            from pyspark.sql.types import BooleanType
+
+            any_overlap = F.lit(False)
+            for column_labels in update_columns:
+                column_name = 
self._internal.spark_column_name_for(column_labels)
+                old_col = scol_for(update_sdf, column_name)
+                new_col = scol_for(
+                    update_sdf, 
other._internal.spark_column_name_for(column_labels) + "_new"
+                )
+
+                overlap = old_col.isNotNull() & new_col.isNotNull()
+                if filter_func is not None:
+                    overlap = overlap & pandas_udf(  # type: 
ignore[call-overload]
+                        filter_func, BooleanType()
+                    )(old_col)
+
+                any_overlap = any_overlap | overlap
+
+            if update_sdf.select(F.max(F.when(any_overlap, 
1).otherwise(0))).first()[0]:
+                raise ValueError("Data overlaps.")
+
         data_fields = self._internal.data_fields.copy()
         for column_labels in update_columns:
             column_name = self._internal.spark_column_name_for(column_labels)
@@ -9225,14 +9282,24 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             new_col = scol_for(
                 update_sdf, 
other._internal.spark_column_name_for(column_labels) + "_new"
             )
-            if overwrite:
-                update_sdf = update_sdf.withColumn(
-                    column_name, F.when(new_col.isNull(), 
old_col).otherwise(new_col)
+
+            if filter_func is not None:
+                from pyspark.sql.types import BooleanType
+
+                mask = pandas_udf(filter_func, BooleanType())(old_col)  # 
type: ignore[call-overload]
+                updated_col = (
+                    F.when(new_col.isNull() | mask.isNull() | ~mask, 
old_col).otherwise(new_col)
+                    if overwrite
+                    else F.when(old_col.isNull() & mask, 
new_col).otherwise(old_col)
                 )
             else:
-                update_sdf = update_sdf.withColumn(
-                    column_name, F.when(old_col.isNull(), 
new_col).otherwise(old_col)
+                updated_col = (
+                    F.when(new_col.isNull(), old_col).otherwise(new_col)
+                    if overwrite
+                    else F.when(old_col.isNull(), new_col).otherwise(old_col)
                 )
+
+            update_sdf = update_sdf.withColumn(column_name, updated_col)
             data_fields[self._internal.column_labels.index(column_labels)] = 
None
         sdf = update_sdf.select(
             *[scol_for(update_sdf, col) for col in 
self._internal.spark_column_names],
diff --git a/python/pyspark/pandas/tests/computation/test_combine.py 
b/python/pyspark/pandas/tests/computation/test_combine.py
index f9a09c94b4fc..b94e49cfa1a6 100644
--- a/python/pyspark/pandas/tests/computation/test_combine.py
+++ b/python/pyspark/pandas/tests/computation/test_combine.py
@@ -658,6 +658,113 @@ class FrameCombineMixin:
             left_psdf.sort_values(by=[("X", "A"), ("X", "B")]),
         )
 
+    def test_update_with_filter_func(self):
+        # Test filter_func parameter
+        left_pdf = pd.DataFrame({"A": [1, 2, 3, 4], "B": [10, 20, 30, 40]})
+        right_pdf = pd.DataFrame({"B": [100, 200, 300, 400]})
+
+        left_psdf = ps.from_pandas(left_pdf)
+        right_psdf = ps.from_pandas(right_pdf)
+
+        # Only update values > 25
+        left_pdf.update(right_pdf, filter_func=lambda x: x > 25)
+        left_psdf.update(right_psdf, filter_func=lambda x: x > 25)
+
+        self.assert_eq(left_pdf.sort_index(), left_psdf.sort_index())
+
+    def test_update_filter_func_overwrite_false(self):
+        # Test filter_func with overwrite=False
+        left_pdf = pd.DataFrame({"A": [1, 2, 3], "B": [None, 20, None]})
+        right_pdf = pd.DataFrame({"B": [100, 200, 300]})
+
+        left_psdf = ps.from_pandas(left_pdf)
+        right_psdf = ps.from_pandas(right_pdf)
+
+        # Only update where new value > 150 (and old is null)
+        left_pdf.update(right_pdf, overwrite=False, filter_func=lambda x: x > 
150)
+        left_psdf.update(right_psdf, overwrite=False, filter_func=lambda x: x 
> 150)
+
+        self.assert_eq(left_pdf.sort_index(), left_psdf.sort_index())
+
+    def test_update_errors_raise_with_overlap(self):
+        # Test that errors='raise' raises ValueError on overlap
+        left_psdf = ps.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
+        right_psdf = ps.DataFrame({"B": [100, 200, 300]})
+
+        # Should raise because both have non-null values
+        with self.assertRaisesRegex(ValueError, "Data overlaps."):
+            left_psdf.update(right_psdf, errors="raise")
+
+    def test_update_errors_raise_no_overlap(self):
+        # Test that errors='raise' works when no overlap
+        left_pdf = pd.DataFrame({"A": [1, 2, 3], "B": [None, None, 30]})
+        right_pdf = pd.DataFrame({"B": [100, 200, None]})
+
+        left_psdf = ps.from_pandas(left_pdf)
+        right_psdf = ps.from_pandas(right_pdf)
+
+        left_pdf.update(right_pdf, errors="raise")
+        left_psdf.update(right_psdf, errors="raise")
+
+        self.assert_eq(left_pdf.sort_index(), left_psdf.sort_index())
+
+    def test_update_errors_invalid_value(self):
+        # Test that invalid errors parameter raises ValueError
+        left_psdf = ps.DataFrame({"A": [1, 2, 3]})
+        right_psdf = ps.DataFrame({"A": [4, 5, 6]})
+
+        with self.assertRaisesRegex(ValueError, "errors must be either 
'ignore' or 'raise'"):
+            left_psdf.update(right_psdf, errors="invalid")
+
+    def test_update_filter_func_and_errors_raise(self):
+        # Test combination of filter_func and errors='raise'
+        left_psdf = ps.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
+        right_psdf = ps.DataFrame({"B": [100, 200, 300]})
+
+        # Filter only values < 25 - should find overlaps at positions 0 and 1
+        with self.assertRaisesRegex(ValueError, "Data overlaps."):
+            left_psdf.update(right_psdf, filter_func=lambda x: x < 25, 
errors="raise")
+
+        # Filter only values > 100 - no overlaps since no original values > 100
+        left_psdf2 = ps.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
+        right_psdf2 = ps.DataFrame({"B": [100, 200, 300]})
+
+        # Should not raise - no values in original DataFrame match filter
+        left_psdf2.update(right_psdf2, filter_func=lambda x: x > 100, 
errors="raise")
+
+    def test_update_filter_func_all_false(self):
+        # Test filter_func that returns all False
+        left_pdf = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
+        right_pdf = pd.DataFrame({"B": [100, 200, 300]})
+
+        left_psdf = ps.from_pandas(left_pdf.copy())
+        right_psdf = ps.from_pandas(right_pdf)
+
+        # Filter that matches nothing
+        original_left_pdf = left_pdf.copy()
+        original_left_psdf = left_psdf.copy()
+
+        left_pdf.update(right_pdf, filter_func=lambda x: x > 1000)
+        left_psdf.update(right_psdf, filter_func=lambda x: x > 1000)
+
+        # DataFrame should be unchanged
+        self.assert_eq(left_pdf.sort_index(), original_left_pdf.sort_index())
+        self.assert_eq(left_psdf.sort_index(), original_left_psdf.sort_index())
+
+    def test_update_filter_func_with_nulls(self):
+        # Test filter_func handling of null values
+        left_pdf = pd.DataFrame({"A": [1, 2, 3], "B": [None, 20, None]})
+        right_pdf = pd.DataFrame({"B": [100, 200, 300]})
+
+        left_psdf = ps.from_pandas(left_pdf)
+        right_psdf = ps.from_pandas(right_pdf)
+
+        # Filter values > 10 (nulls will not match)
+        left_pdf.update(right_pdf, filter_func=lambda x: x > 10)
+        left_psdf.update(right_psdf, filter_func=lambda x: x > 10)
+
+        self.assert_eq(left_pdf.sort_index(), left_psdf.sort_index())
+
 
 class FrameCombineTests(
     FrameCombineMixin,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to