allisonwang-db commented on code in PR #53400:
URL: https://github.com/apache/spark/pull/53400#discussion_r2603732635


##########
python/pyspark/sql/tests/test_udaf.py:
##########
@@ -0,0 +1,472 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql import Row
+from pyspark.sql.functions import udaf
+from pyspark.sql.udaf import Aggregator, UserDefinedAggregateFunction
+from pyspark.errors import PySparkTypeError, PySparkNotImplementedError
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import assertDataFrameEqual
+
+
+class UDAFTestsMixin:
+    def test_udaf_basic_sum(self):
+        """Test basic sum UDAF with df.agg()."""
+
+        class MySum(Aggregator):
+            def zero(self):
+                return 0
+
+            def reduce(self, buffer, value):
+                if value is not None:
+                    return buffer + value
+                return buffer
+
+            def merge(self, buffer1, buffer2):
+                return buffer1 + buffer2
+
+            def finish(self, reduction):
+                return reduction
+
+        df = self.spark.createDataFrame([(1,), (2,), (3,), (4,), (5,)], ["value"])
+        sum_udaf = udaf(MySum(), "bigint", "MySum")
+
+        result = df.agg(sum_udaf(df.value))
+        assertDataFrameEqual(result, [Row(**{"MySum(value)": 15})])
+
+    def test_udaf_with_groupby(self):
+        """Test UDAF with groupBy().agg()."""
+
+        class MySum(Aggregator):
+            def zero(self):
+                return 0
+
+            def reduce(self, buffer, value):
+                if value is not None:
+                    return buffer + value
+                return buffer
+
+            def merge(self, buffer1, buffer2):
+                return buffer1 + buffer2
+
+            def finish(self, reduction):
+                return reduction
+
+        df = self.spark.createDataFrame(
+            [(1, "a"), (2, "a"), (3, "b"), (4, "b"), (5, "b")], ["value", 
"group"]
+        )
+        sum_udaf = udaf(MySum(), "bigint", "MySum")
+
+        # Test using groupBy().agg()
+        result = df.groupBy("group").agg(sum_udaf(df.value))
+        assertDataFrameEqual(
+            result.sort("group"),
+            [
+                Row(group="a", **{"MySum(value)": 3}),
+                Row(group="b", **{"MySum(value)": 12}),
+            ],
+        )
+
+    def test_udaf_average(self):
+        """Test average UDAF with df.agg()."""
+
+        class MyAverage(Aggregator):
+            def zero(self):
+                return (0.0, 0)  # (sum, count)
+
+            def reduce(self, buffer, value):
+                if value is not None:
+                    return (buffer[0] + value, buffer[1] + 1)
+                return buffer
+
+            def merge(self, buffer1, buffer2):
+                return (buffer1[0] + buffer2[0], buffer1[1] + buffer2[1])
+
+            def finish(self, reduction):
+                if reduction[1] == 0:
+                    return None
+                return reduction[0] / reduction[1]
+
+        df = self.spark.createDataFrame([(10,), (20,), (30,)], ["value"])
+        avg_udaf = udaf(MyAverage(), "double", "MyAverage")
+
+        result = df.agg(avg_udaf(df.value))
+        assertDataFrameEqual(result, [Row(**{"MyAverage(value)": 20.0})], 
checkRowOrder=False)
+
+    def test_udaf_max(self):
+        """Test max UDAF with df.agg()."""
+
+        class MyMax(Aggregator):
+            def zero(self):
+                return None
+
+            def reduce(self, buffer, value):
+                if value is not None:
+                    if buffer is None:
+                        return value
+                    return max(buffer, value)
+                return buffer
+
+            def merge(self, buffer1, buffer2):
+                if buffer1 is None:
+                    return buffer2
+                if buffer2 is None:
+                    return buffer1
+                return max(buffer1, buffer2)
+
+            def finish(self, reduction):
+                return reduction
+
+        df = self.spark.createDataFrame([(5,), (10,), (3,), (8,)], ["value"])
+        max_udaf = udaf(MyMax(), "bigint", "MyMax")
+
+        result = df.agg(max_udaf(df.value))
+        assertDataFrameEqual(result, [Row(**{"MyMax(value)": 10})])
+
+    def test_udaf_with_nulls(self):
+        """Test UDAF handling null values with df.agg()."""
+
+        class MySum(Aggregator):
+            def zero(self):
+                return 0
+
+            def reduce(self, buffer, value):
+                if value is not None:
+                    return buffer + value
+                return buffer
+
+            def merge(self, buffer1, buffer2):
+                return buffer1 + buffer2
+
+            def finish(self, reduction):
+                return reduction
+
+        df = self.spark.createDataFrame([(1,), (None,), (3,), (None,), (5,)], ["value"])
+        sum_udaf = udaf(MySum(), "bigint", "MySum")
+
+        result = df.agg(sum_udaf(df.value))
+        assertDataFrameEqual(result, [Row(**{"MySum(value)": 9})])  # 1+3+5 = 
9 (nulls ignored)
+
+    def test_udaf_empty_dataframe(self):
+        """Test UDAF with empty DataFrame using df.agg()."""
+
+        class MySum(Aggregator):
+            def zero(self):
+                return 0
+
+            def reduce(self, buffer, value):
+                if value is not None:
+                    return buffer + value
+                return buffer
+
+            def merge(self, buffer1, buffer2):
+                return buffer1 + buffer2
+
+            def finish(self, reduction):
+                return reduction
+
+        from pyspark.sql.types import StructType, StructField, LongType
+
+        schema = StructType([StructField("value", LongType(), True)])
+        df = self.spark.createDataFrame([], schema)
+        sum_udaf = udaf(MySum(), "bigint", "MySum")
+
+        result = df.agg(sum_udaf(df.value))
+        assertDataFrameEqual(result, [Row(**{"MySum(value)": 0})])
+
+    def test_udaf_aggregator_interface(self):
+        """Test that Aggregator interface is properly defined."""
+
+        class MySum(Aggregator):
+            def zero(self):
+                return 0
+
+            def reduce(self, buffer, value):
+                if value is not None:
+                    return buffer + value
+                return buffer
+
+            def merge(self, buffer1, buffer2):
+                return buffer1 + buffer2
+
+            def finish(self, reduction):
+                return reduction
+
+        agg = MySum()
+
+        # Test zero
+        zero_val = agg.zero()
+        self.assertEqual(zero_val, 0)
+
+        # Test reduce
+        buffer = agg.reduce(agg.zero(), 5)
+        self.assertEqual(buffer, 5)
+        buffer = agg.reduce(buffer, 10)
+        self.assertEqual(buffer, 15)
+
+        # Test merge
+        merged = agg.merge(5, 10)
+        self.assertEqual(merged, 15)
+
+        # Test finish
+        result = agg.finish(merged)
+        self.assertEqual(result, 15)
+
+    def test_udaf_creation(self):
+        """Test UDAF creation with udaf() function."""
+
+        class MySum(Aggregator):
+            def zero(self):
+                return 0
+
+            def reduce(self, buffer, value):
+                if value is not None:
+                    return buffer + value
+                return buffer
+
+            def merge(self, buffer1, buffer2):
+                return buffer1 + buffer2
+
+            def finish(self, reduction):
+                return reduction
+
+        sum_udaf = udaf(MySum(), "bigint")
+        self.assertIsInstance(sum_udaf, UserDefinedAggregateFunction)
+        self.assertEqual(sum_udaf._name, "MySum")
+
+        # Test with custom name
+        custom_udaf = udaf(MySum(), "bigint", "CustomSum")
+        self.assertEqual(custom_udaf._name, "CustomSum")
+
+    def test_udaf_invalid_aggregator(self):
+        """Test that invalid aggregator raises error."""
+        with self.assertRaises(PySparkTypeError):
+            udaf("not an aggregator", "bigint")  # type: ignore
+
+    def test_udaf_invalid_return_type(self):
+        """Test that invalid return type raises error."""
+
+        class MySum(Aggregator):
+            def zero(self):
+                return 0
+
+            def reduce(self, buffer, value):
+                if value is not None:
+                    return buffer + value
+                return buffer
+
+            def merge(self, buffer1, buffer2):
+                return buffer1 + buffer2
+
+            def finish(self, reduction):
+                return reduction
+
+        with self.assertRaises(PySparkTypeError):
+            udaf(MySum(), 123)  # type: ignore  # Not a DataType or string
+
+    def test_udaf_multi_column_not_supported(self):
+        """Test that multi-column UDAF is not yet supported."""
+
+        class MySum(Aggregator):
+            def zero(self):
+                return 0
+
+            def reduce(self, buffer, value):
+                if value is not None:
+                    return buffer + value
+                return buffer
+
+            def merge(self, buffer1, buffer2):
+                return buffer1 + buffer2
+
+            def finish(self, reduction):
+                return reduction
+
+        df = self.spark.createDataFrame([(1, 2), (3, 4)], ["a", "b"])
+        sum_udaf = udaf(MySum(), "bigint")
+
+        with self.assertRaises(PySparkNotImplementedError):
+            df.agg(sum_udaf(df.a), sum_udaf(df.b))
+
+    def test_udaf_large_dataset(self):
+        """Test UDAF with large dataset (20000 rows) distributed across 
partitions."""
+
+        class MySum(Aggregator):
+            def zero(self):
+                return 0
+
+            def reduce(self, buffer, value):
+                if value is not None:
+                    return buffer + value
+                return buffer
+
+            def merge(self, buffer1, buffer2):
+                return buffer1 + buffer2
+
+            def finish(self, reduction):
+                return reduction
+
+        # Create a large dataset with 20000 rows
+        # Use repartition to ensure data is distributed across multiple partitions
+        data = [(i,) for i in range(1, 20001)]  # 1 to 20000
+        df = self.spark.createDataFrame(data, ["value"])
+
+        # Repartition to ensure data is distributed across partitions
+        # This ensures each partition has data
+        num_partitions = max(4, self.spark.sparkContext.defaultParallelism)
+        df = df.repartition(num_partitions)
+
+        # Verify data is distributed
+        actual_partitions = df.rdd.getNumPartitions()
+        self.assertGreater(
+            actual_partitions, 1, "Data should be distributed across multiple 
partitions"
+        )
+
+        # Calculate expected sum: 1 + 2 + ... + 20000 = 20000 * 20001 / 2 = 200010000
+        expected_sum = 20000 * 20001 // 2
+
+        sum_udaf = udaf(MySum(), "bigint", "MySum")
+        result = df.agg(sum_udaf(df.value))
+        assertDataFrameEqual(result, [Row(**{"MySum(value)": expected_sum})])
+
+    def test_udaf_large_dataset_average(self):
+        """Test average UDAF with large dataset using df.agg()."""
+
+        class MyAverage(Aggregator):
+            def zero(self):
+                return (0.0, 0)  # (sum, count)
+
+            def reduce(self, buffer, value):
+                if value is not None:
+                    return (buffer[0] + value, buffer[1] + 1)
+                return buffer
+
+            def merge(self, buffer1, buffer2):
+                return (buffer1[0] + buffer2[0], buffer1[1] + buffer2[1])
+
+            def finish(self, reduction):
+                if reduction[1] == 0:
+                    return None
+                return reduction[0] / reduction[1]
+
+        # Create a large dataset with 20000 rows
+        data = [(float(i),) for i in range(1, 20001)]  # 1.0 to 20000.0
+        df = self.spark.createDataFrame(data, ["value"])
+
+        # Repartition to ensure data is distributed
+        num_partitions = max(4, self.spark.sparkContext.defaultParallelism)
+        df = df.repartition(num_partitions)
+
+        # Expected average: (1 + 2 + ... + 20000) / 20000 = 20001 / 2 = 10000.5
+        expected_avg = 20001.0 / 2.0
+
+        avg_udaf = udaf(MyAverage(), "double", "MyAverage")
+        result = df.agg(avg_udaf(df.value))
+        assertDataFrameEqual(
+            result, [Row(**{"MyAverage(value)": expected_avg})], 
checkRowOrder=False
+        )
+
+    def test_udaf_column_attributes(self):
+        """Test that UDAF Column has correct attributes."""
+
+        class MySum(Aggregator):
+            def zero(self):
+                return 0
+
+            def reduce(self, buffer, value):
+                return buffer + (value or 0)
+
+            def merge(self, buffer1, buffer2):
+                return buffer1 + buffer2
+
+            def finish(self, reduction):
+                return reduction
+
+        df = self.spark.createDataFrame([(1,), (2,)], ["value"])
+        sum_udaf = udaf(MySum(), "bigint", "MySum")
+
+        # Test that __call__ returns a Column with UDAF attributes
+        udaf_col = sum_udaf(df.value)
+        self.assertTrue(hasattr(udaf_col, "_udaf_func"))
+        self.assertTrue(hasattr(udaf_col, "_udaf_col"))
+        self.assertEqual(udaf_col._udaf_func, sum_udaf)  # type: ignore[attr-defined]
+
+    def test_udaf_multiple_udaf_not_supported(self):
+        """Test that multiple UDAFs in agg() raises error."""
+
+        class MySum(Aggregator):
+            def zero(self):
+                return 0
+
+            def reduce(self, buffer, value):
+                return buffer + (value or 0)
+
+            def merge(self, buffer1, buffer2):
+                return buffer1 + buffer2
+
+            def finish(self, reduction):
+                return reduction
+
+        df = self.spark.createDataFrame([(1,), (2,)], ["value"])
+        sum_udaf1 = udaf(MySum(), "bigint", "MySum1")
+        sum_udaf2 = udaf(MySum(), "bigint", "MySum2")
+
+        with self.assertRaises(PySparkNotImplementedError):
+            df.agg(sum_udaf1(df.value), sum_udaf2(df.value))
+
+    def test_udaf_mixed_with_other_agg_not_supported(self):
+        """Test that mixing UDAF with other aggregate functions raises 
error."""
+
+        class MySum(Aggregator):

Review Comment:
   Can we add some tests for more complex data structures, like dictionaries? A rough sketch of what that could look like is below.
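   Something along these lines might work, assuming the new `Aggregator`/`udaf` API accepts `dict` buffers and a `map<string,bigint>` return type (untested sketch, the class and test names are only illustrative):

   ```python
   class MyCountByValue(Aggregator):
       def zero(self):
           # Assumes a plain dict is a valid aggregation buffer: value -> count
           return {}

       def reduce(self, buffer, value):
           if value is not None:
               buffer = dict(buffer)  # avoid mutating the incoming buffer
               buffer[value] = buffer.get(value, 0) + 1
           return buffer

       def merge(self, buffer1, buffer2):
           merged = dict(buffer1)
           for k, v in buffer2.items():
               merged[k] = merged.get(k, 0) + v
           return merged

       def finish(self, reduction):
           return reduction

   df = self.spark.createDataFrame([("a",), ("b",), ("a",)], ["value"])
   # Assumes map return types are supported by udaf(); adjust if they are not yet.
   count_udaf = udaf(MyCountByValue(), "map<string,bigint>", "MyCountByValue")
   result = df.agg(count_udaf(df.value))
   assertDataFrameEqual(result, [Row(**{"MyCountByValue(value)": {"a": 2, "b": 1}})])
   ```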




