This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9054aa287601 [SPARK-50406][PYTHON][TESTS] Improve pyspark.sql.tests.test_udtf
9054aa287601 is described below

commit 9054aa287601ab0596f6ea2a43c9176b848c06fc
Author: Xinrong Meng <[email protected]>
AuthorDate: Tue Nov 26 14:26:14 2024 +0800

    [SPARK-50406][PYTHON][TESTS] Improve pyspark.sql.tests.test_udtf
    
    ### What changes were proposed in this pull request?
    Improve pyspark.sql.tests.test_udtf by:
    - extracting a `udtf_for_table_argument` helper for code reuse
    - using `assertDataFrameEqual` consistently in place of `collect()` + `assertEqual`
    
    ### Why are the changes needed?
    Code cleanup.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Test changes only.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #48949 from xinrong-meng/impr_udtf_test.
    
    Authored-by: Xinrong Meng <[email protected]>
    Signed-off-by: Xinrong Meng <[email protected]>
---
 python/pyspark/sql/tests/test_udtf.py | 212 +++++++++++++++-------------------
 1 file changed, 91 insertions(+), 121 deletions(-)

diff --git a/python/pyspark/sql/tests/test_udtf.py 
b/python/pyspark/sql/tests/test_udtf.py
index f3f993fc6a78..206cfd7dc488 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -74,8 +74,7 @@ class BaseUDTFTestsMixin:
                 yield "hello", "world"
 
         func = udtf(TestUDTF, returnType="c1: string, c2: string")
-        rows = func().collect()
-        self.assertEqual(rows, [Row(c1="hello", c2="world")])
+        assertDataFrameEqual(func(), [Row(c1="hello", c2="world")])
 
     def test_udtf_yield_single_row_col(self):
         class TestUDTF:
@@ -83,8 +82,7 @@ class BaseUDTFTestsMixin:
                 yield a,
 
         func = udtf(TestUDTF, returnType="a: int")
-        rows = func(lit(1)).collect()
-        self.assertEqual(rows, [Row(a=1)])
+        assertDataFrameEqual(func(lit(1)), [Row(a=1)])
 
     def test_udtf_yield_multi_cols(self):
         class TestUDTF:
@@ -92,8 +90,7 @@ class BaseUDTFTestsMixin:
                 yield a, a + 1
 
         func = udtf(TestUDTF, returnType="a: int, b: int")
-        rows = func(lit(1)).collect()
-        self.assertEqual(rows, [Row(a=1, b=2)])
+        assertDataFrameEqual(func(lit(1)), [Row(a=1, b=2)])
 
     def test_udtf_yield_multi_rows(self):
         class TestUDTF:
@@ -102,8 +99,7 @@ class BaseUDTFTestsMixin:
                 yield a + 1,
 
         func = udtf(TestUDTF, returnType="a: int")
-        rows = func(lit(1)).collect()
-        self.assertEqual(rows, [Row(a=1), Row(a=2)])
+        assertDataFrameEqual(func(lit(1)), [Row(a=1), Row(a=2)])
 
     def test_udtf_yield_multi_row_col(self):
         class TestUDTF:
@@ -113,8 +109,8 @@ class BaseUDTFTestsMixin:
                 yield a, b, b - a
 
         func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
-        rows = func(lit(1), lit(2)).collect()
-        self.assertEqual(rows, [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1), 
Row(a=1, b=2, c=1)])
+        res = func(lit(1), lit(2))
+        assertDataFrameEqual(res, [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1), 
Row(a=1, b=2, c=1)])
 
     def test_udtf_decorator(self):
         @udtf(returnType="a: int, b: int")
@@ -122,8 +118,7 @@ class BaseUDTFTestsMixin:
             def eval(self, a: int):
                 yield a, a + 1
 
-        rows = TestUDTF(lit(1)).collect()
-        self.assertEqual(rows, [Row(a=1, b=2)])
+        assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=1, b=2)])
 
     def test_udtf_registration(self):
         class TestUDTF:
@@ -135,9 +130,7 @@ class BaseUDTFTestsMixin:
         func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
         self.spark.udtf.register("testUDTF", func)
         df = self.spark.sql("SELECT * FROM testUDTF(1, 2)")
-        self.assertEqual(
-            df.collect(), [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1), Row(a=1, 
b=2, c=1)]
-        )
+        assertDataFrameEqual(df, [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1), 
Row(a=1, b=2, c=1)])
 
     def test_udtf_with_lateral_join(self):
         class TestUDTF:
@@ -150,10 +143,17 @@ class BaseUDTFTestsMixin:
         df = self.spark.sql(
             "SELECT f.* FROM values (0, 1), (1, 2) t(a, b), LATERAL 
testUDTF(a, b) f"
         )
+        schema = StructType(
+            [
+                StructField("a", IntegerType(), True),
+                StructField("b", IntegerType(), True),
+                StructField("c", IntegerType(), True),
+            ]
+        )
         expected = self.spark.createDataFrame(
-            [(0, 1, 1), (0, 1, -1), (1, 2, 3), (1, 2, -1)], schema=["a", "b", 
"c"]
+            [(0, 1, 1), (0, 1, -1), (1, 2, 3), (1, 2, -1)], schema=schema
         )
-        self.assertEqual(df.collect(), expected.collect())
+        assertDataFrameEqual(df, expected)
 
     def test_udtf_eval_with_return_stmt(self):
         class TestUDTF:
@@ -161,8 +161,8 @@ class BaseUDTFTestsMixin:
                 return [(a, a + 1), (b, b + 1)]
 
         func = udtf(TestUDTF, returnType="a: int, b: int")
-        rows = func(lit(1), lit(2)).collect()
-        self.assertEqual(rows, [Row(a=1, b=2), Row(a=2, b=3)])
+        res = func(lit(1), lit(2))
+        assertDataFrameEqual(res, [Row(a=1, b=2), Row(a=2, b=3)])
 
     def test_udtf_eval_returning_non_tuple(self):
         @udtf(returnType="a: int")
@@ -217,14 +217,14 @@ class BaseUDTFTestsMixin:
             def eval(self, a: int):
                 ...
 
-        self.assertEqual(TestUDTF(lit(1)).collect(), [])
+        assertDataFrameEqual(TestUDTF(lit(1)), [])
 
         @udtf(returnType="a: int")
         class TestUDTF:
             def eval(self, a: int):
                 return
 
-        self.assertEqual(TestUDTF(lit(1)).collect(), [])
+        assertDataFrameEqual(TestUDTF(lit(1)), [])
 
     def test_udtf_with_conditional_return(self):
         class TestUDTF:
@@ -234,8 +234,8 @@ class BaseUDTFTestsMixin:
 
         func = udtf(TestUDTF, returnType="a: int")
         self.spark.udtf.register("test_udtf", func)
-        self.assertEqual(
-            self.spark.sql("SELECT * FROM range(0, 8) JOIN LATERAL 
test_udtf(id)").collect(),
+        assertDataFrameEqual(
+            self.spark.sql("SELECT * FROM range(0, 8) JOIN LATERAL 
test_udtf(id)"),
             [Row(id=6, a=6), Row(id=7, a=7)],
         )
 
@@ -254,9 +254,9 @@ class BaseUDTFTestsMixin:
                 yield a,
                 yield None,
 
-        self.assertEqual(TestUDTF(lit(1)).collect(), [Row(a=1), Row(a=None)])
+        assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=1), Row(a=None)])
         df = self.spark.createDataFrame([(0, 1), (1, 2)], schema=["a", "b"])
-        self.assertEqual(TestUDTF(lit(1)).join(df, "a", "inner").collect(), 
[Row(a=1, b=2)])
+        assertDataFrameEqual(TestUDTF(lit(1)).join(df, "a", "inner"), 
[Row(a=1, b=2)])
         assertDataFrameEqual(
             TestUDTF(lit(1)).join(df, "a", "left"), [Row(a=None, b=None), 
Row(a=1, b=2)]
         )
@@ -267,10 +267,10 @@ class BaseUDTFTestsMixin:
             def eval(self, a: int):
                 yield a,
 
-        self.assertEqual(TestUDTF(lit(None)).collect(), [Row(a=None)])
+        assertDataFrameEqual(TestUDTF(lit(None)), [Row(a=None)])
         self.spark.udtf.register("testUDTF", TestUDTF)
         df = self.spark.sql("SELECT * FROM testUDTF(null)")
-        self.assertEqual(df.collect(), [Row(a=None)])
+        assertDataFrameEqual(df, [Row(a=None)])
 
     # These are expected error message substrings to be used in test cases 
below.
     tooManyPositionalArguments = "too many positional arguments"
@@ -366,8 +366,8 @@ class BaseUDTFTestsMixin:
             def eval(self, a: int):
                 yield a, a + 1, self.key
 
-        rows = TestUDTF(lit(1)).collect()
-        self.assertEqual(rows, [Row(a=1, b=2, c="test")])
+        res = TestUDTF(lit(1))
+        assertDataFrameEqual(res, [Row(a=1, b=2, c="test")])
 
     def test_udtf_terminate(self):
         @udtf(returnType="key: string, value: float")
@@ -385,8 +385,8 @@ class BaseUDTFTestsMixin:
                 yield "count", float(self._count)
                 yield "avg", self._sum / self._count
 
-        self.assertEqual(
-            TestUDTF(lit(1)).collect(),
+        assertDataFrameEqual(
+            TestUDTF(lit(1)),
             [Row(key="input", value=1), Row(key="count", value=1.0), 
Row(key="avg", value=1.0)],
         )
 
@@ -395,8 +395,8 @@ class BaseUDTFTestsMixin:
             "SELECT id, key, value FROM range(0, 10, 1, 2), "
             "LATERAL test_udtf(id) WHERE key != 'input'"
         )
-        self.assertEqual(
-            df.collect(),
+        assertDataFrameEqual(
+            df,
             [
                 Row(id=4, key="count", value=5.0),
                 Row(id=4, key="avg", value=2.0),
@@ -608,10 +608,8 @@ class BaseUDTFTestsMixin:
                 yield f"{person.name}: {person.age}",
 
         self.spark.udtf.register("test_udtf", TestUDTF)
-        self.assertEqual(
-            self.spark.sql(
-                "select * from test_udtf(named_struct('name', 'Alice', 'age', 
1))"
-            ).collect(),
+        assertDataFrameEqual(
+            self.spark.sql("select * from test_udtf(named_struct('name', 
'Alice', 'age', 1))"),
             [Row(x="Alice: 1")],
         )
 
@@ -634,8 +632,8 @@ class BaseUDTFTestsMixin:
                 yield str(m),
 
         self.spark.udtf.register("test_udtf", TestUDTF)
-        self.assertEqual(
-            self.spark.sql("select * from test_udtf(map('key', 
'value'))").collect(),
+        assertDataFrameEqual(
+            self.spark.sql("select * from test_udtf(map('key', 'value'))"),
             [Row(x="{'key': 'value'}")],
         )
 
@@ -645,7 +643,7 @@ class BaseUDTFTestsMixin:
             def eval(self, x: int):
                 yield {"a": x, "b": x + 1},
 
-        self.assertEqual(TestUDTF(lit(1)).collect(), [Row(x=Row(a=1, b=2))])
+        assertDataFrameEqual(TestUDTF(lit(1)), [Row(x=Row(a=1, b=2))])
 
     def test_udtf_with_array_output_types(self):
         @udtf(returnType="x: array<int>")
@@ -653,7 +651,7 @@ class BaseUDTFTestsMixin:
             def eval(self, x: int):
                 yield [x, x + 1, x + 2],
 
-        self.assertEqual(TestUDTF(lit(1)).collect(), [Row(x=[1, 2, 3])])
+        assertDataFrameEqual(TestUDTF(lit(1)), [Row(x=[1, 2, 3])])
 
     def test_udtf_with_map_output_types(self):
         @udtf(returnType="x: map<int,string>")
@@ -661,7 +659,7 @@ class BaseUDTFTestsMixin:
             def eval(self, x: int):
                 yield {x: str(x)},
 
-        self.assertEqual(TestUDTF(lit(1)).collect(), [Row(x={1: "1"})])
+        assertDataFrameEqual(TestUDTF(lit(1)), [Row(x={1: "1"})])
 
     def test_udtf_with_empty_output_types(self):
         @udtf(returnType=StructType())
@@ -1019,17 +1017,21 @@ class BaseUDTFTestsMixin:
         )
 
     def test_udtf_with_table_argument_query(self):
+        func = self.udtf_for_table_argument()
+        self.spark.udtf.register("test_udtf", func)
+        assertDataFrameEqual(
+            self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id FROM 
range(0, 8)))"),
+            [Row(a=6), Row(a=7)],
+        )
+
+    def udtf_for_table_argument(self):
         class TestUDTF:
             def eval(self, row: Row):
                 if row["id"] > 5:
                     yield row["id"],
 
         func = udtf(TestUDTF, returnType="a: int")
-        self.spark.udtf.register("test_udtf", func)
-        self.assertEqual(
-            self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id FROM 
range(0, 8)))").collect(),
-            [Row(a=6), Row(a=7)],
-        )
+        return func
 
     def test_udtf_with_int_and_table_argument_query(self):
         class TestUDTF:
@@ -1039,26 +1041,19 @@ class BaseUDTFTestsMixin:
 
         func = udtf(TestUDTF, returnType="a: int")
         self.spark.udtf.register("test_udtf", func)
-        self.assertEqual(
-            self.spark.sql(
-                "SELECT * FROM test_udtf(5, TABLE (SELECT id FROM range(0, 
8)))"
-            ).collect(),
+        assertDataFrameEqual(
+            self.spark.sql("SELECT * FROM test_udtf(5, TABLE (SELECT id FROM 
range(0, 8)))"),
             [Row(a=6), Row(a=7)],
         )
 
     def test_udtf_with_table_argument_identifier(self):
-        class TestUDTF:
-            def eval(self, row: Row):
-                if row["id"] > 5:
-                    yield row["id"],
-
-        func = udtf(TestUDTF, returnType="a: int")
+        func = self.udtf_for_table_argument()
         self.spark.udtf.register("test_udtf", func)
 
         with self.tempView("v"):
             self.spark.sql("CREATE OR REPLACE TEMPORARY VIEW v as SELECT id 
FROM range(0, 8)")
-            self.assertEqual(
-                self.spark.sql("SELECT * FROM test_udtf(TABLE (v))").collect(),
+            assertDataFrameEqual(
+                self.spark.sql("SELECT * FROM test_udtf(TABLE (v))"),
                 [Row(a=6), Row(a=7)],
             )
 
@@ -1073,44 +1068,29 @@ class BaseUDTFTestsMixin:
 
         with self.tempView("v"):
             self.spark.sql("CREATE OR REPLACE TEMPORARY VIEW v as SELECT id 
FROM range(0, 8)")
-            self.assertEqual(
-                self.spark.sql("SELECT * FROM test_udtf(5, TABLE 
(v))").collect(),
+            assertDataFrameEqual(
+                self.spark.sql("SELECT * FROM test_udtf(5, TABLE (v))"),
                 [Row(a=6), Row(a=7)],
             )
 
     def test_udtf_with_table_argument_unknown_identifier(self):
-        class TestUDTF:
-            def eval(self, row: Row):
-                if row["id"] > 5:
-                    yield row["id"],
-
-        func = udtf(TestUDTF, returnType="a: int")
+        func = self.udtf_for_table_argument()
         self.spark.udtf.register("test_udtf", func)
 
         with self.assertRaisesRegex(AnalysisException, 
"TABLE_OR_VIEW_NOT_FOUND"):
             self.spark.sql("SELECT * FROM test_udtf(TABLE (v))").collect()
 
     def test_udtf_with_table_argument_malformed_query(self):
-        class TestUDTF:
-            def eval(self, row: Row):
-                if row["id"] > 5:
-                    yield row["id"],
-
-        func = udtf(TestUDTF, returnType="a: int")
+        func = self.udtf_for_table_argument()
         self.spark.udtf.register("test_udtf", func)
 
         with self.assertRaisesRegex(AnalysisException, 
"TABLE_OR_VIEW_NOT_FOUND"):
             self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT * FROM 
v))").collect()
 
     def test_udtf_with_table_argument_cte_inside(self):
-        class TestUDTF:
-            def eval(self, row: Row):
-                if row["id"] > 5:
-                    yield row["id"],
-
-        func = udtf(TestUDTF, returnType="a: int")
+        func = self.udtf_for_table_argument()
         self.spark.udtf.register("test_udtf", func)
-        self.assertEqual(
+        assertDataFrameEqual(
             self.spark.sql(
                 """
                 SELECT * FROM test_udtf(TABLE (
@@ -1120,19 +1100,14 @@ class BaseUDTFTestsMixin:
                   SELECT * FROM t
                 ))
                 """
-            ).collect(),
+            ),
             [Row(a=6), Row(a=7)],
         )
 
     def test_udtf_with_table_argument_cte_outside(self):
-        class TestUDTF:
-            def eval(self, row: Row):
-                if row["id"] > 5:
-                    yield row["id"],
-
-        func = udtf(TestUDTF, returnType="a: int")
+        func = self.udtf_for_table_argument()
         self.spark.udtf.register("test_udtf", func)
-        self.assertEqual(
+        assertDataFrameEqual(
             self.spark.sql(
                 """
                 WITH t AS (
@@ -1140,11 +1115,11 @@ class BaseUDTFTestsMixin:
                 )
                 SELECT * FROM test_udtf(TABLE (SELECT id FROM t))
                 """
-            ).collect(),
+            ),
             [Row(a=6), Row(a=7)],
         )
 
-        self.assertEqual(
+        assertDataFrameEqual(
             self.spark.sql(
                 """
                 WITH t AS (
@@ -1152,28 +1127,23 @@ class BaseUDTFTestsMixin:
                 )
                 SELECT * FROM test_udtf(TABLE (t))
                 """
-            ).collect(),
+            ),
             [Row(a=6), Row(a=7)],
         )
 
     # TODO(SPARK-44233): Fix the subquery resolution.
     @unittest.skip("Fails to resolve the subquery.")
     def test_udtf_with_table_argument_lateral_join(self):
-        class TestUDTF:
-            def eval(self, row: Row):
-                if row["id"] > 5:
-                    yield row["id"],
-
-        func = udtf(TestUDTF, returnType="a: int")
+        func = self.udtf_for_table_argument()
         self.spark.udtf.register("test_udtf", func)
-        self.assertEqual(
+        assertDataFrameEqual(
             self.spark.sql(
                 """
                 SELECT * FROM
                   range(0, 8) AS t,
                   LATERAL test_udtf(TABLE (t))
                 """
-            ).collect(),
+            ),
             [Row(a=6), Row(a=7)],
         )
 
@@ -1198,8 +1168,8 @@ class BaseUDTFTestsMixin:
                 self.spark.sql(query).collect()
 
         with 
self.sql_conf({"spark.sql.tvf.allowMultipleTableArguments.enabled": True}):
-            self.assertEqual(
-                self.spark.sql(query).collect(),
+            assertDataFrameEqual(
+                self.spark.sql(query),
                 [
                     Row(a=0, b=0),
                     Row(a=1, b=0),
@@ -2539,8 +2509,10 @@ class BaseUDTFTestsMixin:
                     yield i, v.toJson()
 
         self.spark.udtf.register("test_udtf", TestUDTF)
-        rows = self.spark.sql('select i, s from 
test_udtf(parse_json(\'{"a":"b"}\'))').collect()
-        self.assertEqual(rows, [Row(i=n, s='{"a":"b"}') for n in range(10)])
+        assertDataFrameEqual(
+            self.spark.sql('select i, s from 
test_udtf(parse_json(\'{"a":"b"}\'))'),
+            [Row(i=n, s='{"a":"b"}') for n in range(10)],
+        )
 
     def test_udtf_with_nested_variant_input(self):
         # struct<variant>
@@ -2551,10 +2523,10 @@ class BaseUDTFTestsMixin:
                     yield i, v["v"].toJson()
 
         self.spark.udtf.register("test_udtf_struct", TestUDTFStruct)
-        rows = self.spark.sql(
+        res = self.spark.sql(
             "select i, s from test_udtf_struct(named_struct('v', 
parse_json('{\"a\":\"c\"}')))"
-        ).collect()
-        self.assertEqual(rows, [Row(i=n, s='{"a":"c"}') for n in range(10)])
+        )
+        assertDataFrameEqual(res, [Row(i=n, s='{"a":"c"}') for n in range(10)])
 
         # array<variant>
         @udtf(returnType="i int, s: string")
@@ -2564,10 +2536,8 @@ class BaseUDTFTestsMixin:
                     yield i, v[0].toJson()
 
         self.spark.udtf.register("test_udtf_array", TestUDTFArray)
-        rows = self.spark.sql(
-            'select i, s from 
test_udtf_array(array(parse_json(\'{"a":"d"}\')))'
-        ).collect()
-        self.assertEqual(rows, [Row(i=n, s='{"a":"d"}') for n in range(10)])
+        res = self.spark.sql('select i, s from 
test_udtf_array(array(parse_json(\'{"a":"d"}\')))')
+        assertDataFrameEqual(res, [Row(i=n, s='{"a":"d"}') for n in range(10)])
 
         # map<string, variant>
         @udtf(returnType="i int, s: string")
@@ -2577,10 +2547,10 @@ class BaseUDTFTestsMixin:
                     yield i, v["v"].toJson()
 
         self.spark.udtf.register("test_udtf_map", TestUDTFMap)
-        rows = self.spark.sql(
+        res = self.spark.sql(
             "select i, s from test_udtf_map(map('v', 
parse_json('{\"a\":\"e\"}')))"
-        ).collect()
-        self.assertEqual(rows, [Row(i=n, s='{"a":"e"}') for n in range(10)])
+        )
+        assertDataFrameEqual(res, [Row(i=n, s='{"a":"e"}') for n in range(10)])
 
     def test_udtf_with_variant_output(self):
         @udtf(returnType="i int, v: variant")
@@ -2591,8 +2561,8 @@ class BaseUDTFTestsMixin:
                     yield i, VariantVal(bytes([2, 1, 0, 0, 2, 5, 97 + i]), 
bytes([1, 1, 0, 1, 97]))
 
         self.spark.udtf.register("test_udtf", TestUDTF)
-        rows = self.spark.sql("select i, to_json(v) from 
test_udtf(8)").collect()
-        self.assertEqual(rows, [Row(i=n, s=f'{{"a":"{chr(97 + n)}"}}') for n 
in range(8)])
+        res = self.spark.sql("select i, to_json(v) from test_udtf(8)")
+        assertDataFrameEqual(res, [Row(i=n, s=f'{{"a":"{chr(97 + n)}"}}') for 
n in range(8)])
 
     def test_udtf_with_nested_variant_output(self):
         # struct<variant>
@@ -2606,8 +2576,8 @@ class BaseUDTFTestsMixin:
                     }
 
         self.spark.udtf.register("test_udtf_struct", TestUDTFStruct)
-        rows = self.spark.sql("select i, to_json(v.v1) from 
test_udtf_struct(8)").collect()
-        self.assertEqual(rows, [Row(i=n, s=f'{{"a":"{chr(97 + n)}"}}') for n 
in range(8)])
+        res = self.spark.sql("select i, to_json(v.v1) from 
test_udtf_struct(8)")
+        assertDataFrameEqual(res, [Row(i=n, s=f'{{"a":"{chr(97 + n)}"}}') for 
n in range(8)])
 
         # array<variant>
         @udtf(returnType="i int, v: array<variant>")
@@ -2620,8 +2590,8 @@ class BaseUDTFTestsMixin:
                     ]
 
         self.spark.udtf.register("test_udtf_array", TestUDTFArray)
-        rows = self.spark.sql("select i, to_json(v[0]) from 
test_udtf_array(8)").collect()
-        self.assertEqual(rows, [Row(i=n, s=f'{{"a":"{chr(98 + n)}"}}') for n 
in range(8)])
+        res = self.spark.sql("select i, to_json(v[0]) from test_udtf_array(8)")
+        assertDataFrameEqual(res, [Row(i=n, s=f'{{"a":"{chr(98 + n)}"}}') for 
n in range(8)])
 
         # map<string, variant>
         @udtf(returnType="i int, v: map<string, variant>")
@@ -2634,8 +2604,8 @@ class BaseUDTFTestsMixin:
                     }
 
         self.spark.udtf.register("test_udtf_struct", TestUDTFStruct)
-        rows = self.spark.sql("select i, to_json(v['v1']) from 
test_udtf_struct(8)").collect()
-        self.assertEqual(rows, [Row(i=n, s=f'{{"a":"{chr(99 + n)}"}}') for n 
in range(8)])
+        res = self.spark.sql("select i, to_json(v['v1']) from 
test_udtf_struct(8)")
+        assertDataFrameEqual(res, [Row(i=n, s=f'{{"a":"{chr(99 + n)}"}}') for 
n in range(8)])
 
 
 class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to