This is an automated email from the ASF dual-hosted git repository.
xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9054aa287601 [SPARK-50406][PYTHON][TESTS] Improve pyspark.sql.tests.test_udtf
9054aa287601 is described below
commit 9054aa287601ab0596f6ea2a43c9176b848c06fc
Author: Xinrong Meng <[email protected]>
AuthorDate: Tue Nov 26 14:26:14 2024 +0800
[SPARK-50406][PYTHON][TESTS] Improve pyspark.sql.tests.test_udtf
### What changes were proposed in this pull request?
Improve pyspark.sql.tests.test_udtf by:
- extracting `udtf_for_table_argument` for code reuse
- using `assertDataFrameEqual` universally
### Why are the changes needed?
Code cleanup.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Test changes only.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48949 from xinrong-meng/impr_udtf_test.
Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Xinrong Meng <[email protected]>
---
python/pyspark/sql/tests/test_udtf.py | 212 +++++++++++++++-------------------
1 file changed, 91 insertions(+), 121 deletions(-)
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index f3f993fc6a78..206cfd7dc488 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -74,8 +74,7 @@ class BaseUDTFTestsMixin:
yield "hello", "world"
func = udtf(TestUDTF, returnType="c1: string, c2: string")
- rows = func().collect()
- self.assertEqual(rows, [Row(c1="hello", c2="world")])
+ assertDataFrameEqual(func(), [Row(c1="hello", c2="world")])
def test_udtf_yield_single_row_col(self):
class TestUDTF:
@@ -83,8 +82,7 @@ class BaseUDTFTestsMixin:
yield a,
func = udtf(TestUDTF, returnType="a: int")
- rows = func(lit(1)).collect()
- self.assertEqual(rows, [Row(a=1)])
+ assertDataFrameEqual(func(lit(1)), [Row(a=1)])
def test_udtf_yield_multi_cols(self):
class TestUDTF:
@@ -92,8 +90,7 @@ class BaseUDTFTestsMixin:
yield a, a + 1
func = udtf(TestUDTF, returnType="a: int, b: int")
- rows = func(lit(1)).collect()
- self.assertEqual(rows, [Row(a=1, b=2)])
+ assertDataFrameEqual(func(lit(1)), [Row(a=1, b=2)])
def test_udtf_yield_multi_rows(self):
class TestUDTF:
@@ -102,8 +99,7 @@ class BaseUDTFTestsMixin:
yield a + 1,
func = udtf(TestUDTF, returnType="a: int")
- rows = func(lit(1)).collect()
- self.assertEqual(rows, [Row(a=1), Row(a=2)])
+ assertDataFrameEqual(func(lit(1)), [Row(a=1), Row(a=2)])
def test_udtf_yield_multi_row_col(self):
class TestUDTF:
@@ -113,8 +109,8 @@ class BaseUDTFTestsMixin:
yield a, b, b - a
func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
- rows = func(lit(1), lit(2)).collect()
- self.assertEqual(rows, [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1),
Row(a=1, b=2, c=1)])
+ res = func(lit(1), lit(2))
+ assertDataFrameEqual(res, [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1),
Row(a=1, b=2, c=1)])
def test_udtf_decorator(self):
@udtf(returnType="a: int, b: int")
@@ -122,8 +118,7 @@ class BaseUDTFTestsMixin:
def eval(self, a: int):
yield a, a + 1
- rows = TestUDTF(lit(1)).collect()
- self.assertEqual(rows, [Row(a=1, b=2)])
+ assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=1, b=2)])
def test_udtf_registration(self):
class TestUDTF:
@@ -135,9 +130,7 @@ class BaseUDTFTestsMixin:
func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
self.spark.udtf.register("testUDTF", func)
df = self.spark.sql("SELECT * FROM testUDTF(1, 2)")
- self.assertEqual(
- df.collect(), [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1), Row(a=1,
b=2, c=1)]
- )
+ assertDataFrameEqual(df, [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1),
Row(a=1, b=2, c=1)])
def test_udtf_with_lateral_join(self):
class TestUDTF:
@@ -150,10 +143,17 @@ class BaseUDTFTestsMixin:
df = self.spark.sql(
"SELECT f.* FROM values (0, 1), (1, 2) t(a, b), LATERAL
testUDTF(a, b) f"
)
+ schema = StructType(
+ [
+ StructField("a", IntegerType(), True),
+ StructField("b", IntegerType(), True),
+ StructField("c", IntegerType(), True),
+ ]
+ )
expected = self.spark.createDataFrame(
- [(0, 1, 1), (0, 1, -1), (1, 2, 3), (1, 2, -1)], schema=["a", "b",
"c"]
+ [(0, 1, 1), (0, 1, -1), (1, 2, 3), (1, 2, -1)], schema=schema
)
- self.assertEqual(df.collect(), expected.collect())
+ assertDataFrameEqual(df, expected)
def test_udtf_eval_with_return_stmt(self):
class TestUDTF:
@@ -161,8 +161,8 @@ class BaseUDTFTestsMixin:
return [(a, a + 1), (b, b + 1)]
func = udtf(TestUDTF, returnType="a: int, b: int")
- rows = func(lit(1), lit(2)).collect()
- self.assertEqual(rows, [Row(a=1, b=2), Row(a=2, b=3)])
+ res = func(lit(1), lit(2))
+ assertDataFrameEqual(res, [Row(a=1, b=2), Row(a=2, b=3)])
def test_udtf_eval_returning_non_tuple(self):
@udtf(returnType="a: int")
@@ -217,14 +217,14 @@ class BaseUDTFTestsMixin:
def eval(self, a: int):
...
- self.assertEqual(TestUDTF(lit(1)).collect(), [])
+ assertDataFrameEqual(TestUDTF(lit(1)), [])
@udtf(returnType="a: int")
class TestUDTF:
def eval(self, a: int):
return
- self.assertEqual(TestUDTF(lit(1)).collect(), [])
+ assertDataFrameEqual(TestUDTF(lit(1)), [])
def test_udtf_with_conditional_return(self):
class TestUDTF:
@@ -234,8 +234,8 @@ class BaseUDTFTestsMixin:
func = udtf(TestUDTF, returnType="a: int")
self.spark.udtf.register("test_udtf", func)
- self.assertEqual(
- self.spark.sql("SELECT * FROM range(0, 8) JOIN LATERAL
test_udtf(id)").collect(),
+ assertDataFrameEqual(
+ self.spark.sql("SELECT * FROM range(0, 8) JOIN LATERAL
test_udtf(id)"),
[Row(id=6, a=6), Row(id=7, a=7)],
)
@@ -254,9 +254,9 @@ class BaseUDTFTestsMixin:
yield a,
yield None,
- self.assertEqual(TestUDTF(lit(1)).collect(), [Row(a=1), Row(a=None)])
+ assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=1), Row(a=None)])
df = self.spark.createDataFrame([(0, 1), (1, 2)], schema=["a", "b"])
- self.assertEqual(TestUDTF(lit(1)).join(df, "a", "inner").collect(),
[Row(a=1, b=2)])
+ assertDataFrameEqual(TestUDTF(lit(1)).join(df, "a", "inner"),
[Row(a=1, b=2)])
assertDataFrameEqual(
TestUDTF(lit(1)).join(df, "a", "left"), [Row(a=None, b=None),
Row(a=1, b=2)]
)
@@ -267,10 +267,10 @@ class BaseUDTFTestsMixin:
def eval(self, a: int):
yield a,
- self.assertEqual(TestUDTF(lit(None)).collect(), [Row(a=None)])
+ assertDataFrameEqual(TestUDTF(lit(None)), [Row(a=None)])
self.spark.udtf.register("testUDTF", TestUDTF)
df = self.spark.sql("SELECT * FROM testUDTF(null)")
- self.assertEqual(df.collect(), [Row(a=None)])
+ assertDataFrameEqual(df, [Row(a=None)])
# These are expected error message substrings to be used in test cases
below.
tooManyPositionalArguments = "too many positional arguments"
@@ -366,8 +366,8 @@ class BaseUDTFTestsMixin:
def eval(self, a: int):
yield a, a + 1, self.key
- rows = TestUDTF(lit(1)).collect()
- self.assertEqual(rows, [Row(a=1, b=2, c="test")])
+ res = TestUDTF(lit(1))
+ assertDataFrameEqual(res, [Row(a=1, b=2, c="test")])
def test_udtf_terminate(self):
@udtf(returnType="key: string, value: float")
@@ -385,8 +385,8 @@ class BaseUDTFTestsMixin:
yield "count", float(self._count)
yield "avg", self._sum / self._count
- self.assertEqual(
- TestUDTF(lit(1)).collect(),
+ assertDataFrameEqual(
+ TestUDTF(lit(1)),
[Row(key="input", value=1), Row(key="count", value=1.0),
Row(key="avg", value=1.0)],
)
@@ -395,8 +395,8 @@ class BaseUDTFTestsMixin:
"SELECT id, key, value FROM range(0, 10, 1, 2), "
"LATERAL test_udtf(id) WHERE key != 'input'"
)
- self.assertEqual(
- df.collect(),
+ assertDataFrameEqual(
+ df,
[
Row(id=4, key="count", value=5.0),
Row(id=4, key="avg", value=2.0),
@@ -608,10 +608,8 @@ class BaseUDTFTestsMixin:
yield f"{person.name}: {person.age}",
self.spark.udtf.register("test_udtf", TestUDTF)
- self.assertEqual(
- self.spark.sql(
- "select * from test_udtf(named_struct('name', 'Alice', 'age',
1))"
- ).collect(),
+ assertDataFrameEqual(
+ self.spark.sql("select * from test_udtf(named_struct('name',
'Alice', 'age', 1))"),
[Row(x="Alice: 1")],
)
@@ -634,8 +632,8 @@ class BaseUDTFTestsMixin:
yield str(m),
self.spark.udtf.register("test_udtf", TestUDTF)
- self.assertEqual(
- self.spark.sql("select * from test_udtf(map('key',
'value'))").collect(),
+ assertDataFrameEqual(
+ self.spark.sql("select * from test_udtf(map('key', 'value'))"),
[Row(x="{'key': 'value'}")],
)
@@ -645,7 +643,7 @@ class BaseUDTFTestsMixin:
def eval(self, x: int):
yield {"a": x, "b": x + 1},
- self.assertEqual(TestUDTF(lit(1)).collect(), [Row(x=Row(a=1, b=2))])
+ assertDataFrameEqual(TestUDTF(lit(1)), [Row(x=Row(a=1, b=2))])
def test_udtf_with_array_output_types(self):
@udtf(returnType="x: array<int>")
@@ -653,7 +651,7 @@ class BaseUDTFTestsMixin:
def eval(self, x: int):
yield [x, x + 1, x + 2],
- self.assertEqual(TestUDTF(lit(1)).collect(), [Row(x=[1, 2, 3])])
+ assertDataFrameEqual(TestUDTF(lit(1)), [Row(x=[1, 2, 3])])
def test_udtf_with_map_output_types(self):
@udtf(returnType="x: map<int,string>")
@@ -661,7 +659,7 @@ class BaseUDTFTestsMixin:
def eval(self, x: int):
yield {x: str(x)},
- self.assertEqual(TestUDTF(lit(1)).collect(), [Row(x={1: "1"})])
+ assertDataFrameEqual(TestUDTF(lit(1)), [Row(x={1: "1"})])
def test_udtf_with_empty_output_types(self):
@udtf(returnType=StructType())
@@ -1019,17 +1017,21 @@ class BaseUDTFTestsMixin:
)
def test_udtf_with_table_argument_query(self):
+ func = self.udtf_for_table_argument()
+ self.spark.udtf.register("test_udtf", func)
+ assertDataFrameEqual(
+ self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id FROM
range(0, 8)))"),
+ [Row(a=6), Row(a=7)],
+ )
+
+ def udtf_for_table_argument(self):
class TestUDTF:
def eval(self, row: Row):
if row["id"] > 5:
yield row["id"],
func = udtf(TestUDTF, returnType="a: int")
- self.spark.udtf.register("test_udtf", func)
- self.assertEqual(
- self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id FROM
range(0, 8)))").collect(),
- [Row(a=6), Row(a=7)],
- )
+ return func
def test_udtf_with_int_and_table_argument_query(self):
class TestUDTF:
@@ -1039,26 +1041,19 @@ class BaseUDTFTestsMixin:
func = udtf(TestUDTF, returnType="a: int")
self.spark.udtf.register("test_udtf", func)
- self.assertEqual(
- self.spark.sql(
- "SELECT * FROM test_udtf(5, TABLE (SELECT id FROM range(0,
8)))"
- ).collect(),
+ assertDataFrameEqual(
+ self.spark.sql("SELECT * FROM test_udtf(5, TABLE (SELECT id FROM
range(0, 8)))"),
[Row(a=6), Row(a=7)],
)
def test_udtf_with_table_argument_identifier(self):
- class TestUDTF:
- def eval(self, row: Row):
- if row["id"] > 5:
- yield row["id"],
-
- func = udtf(TestUDTF, returnType="a: int")
+ func = self.udtf_for_table_argument()
self.spark.udtf.register("test_udtf", func)
with self.tempView("v"):
self.spark.sql("CREATE OR REPLACE TEMPORARY VIEW v as SELECT id
FROM range(0, 8)")
- self.assertEqual(
- self.spark.sql("SELECT * FROM test_udtf(TABLE (v))").collect(),
+ assertDataFrameEqual(
+ self.spark.sql("SELECT * FROM test_udtf(TABLE (v))"),
[Row(a=6), Row(a=7)],
)
@@ -1073,44 +1068,29 @@ class BaseUDTFTestsMixin:
with self.tempView("v"):
self.spark.sql("CREATE OR REPLACE TEMPORARY VIEW v as SELECT id
FROM range(0, 8)")
- self.assertEqual(
- self.spark.sql("SELECT * FROM test_udtf(5, TABLE
(v))").collect(),
+ assertDataFrameEqual(
+ self.spark.sql("SELECT * FROM test_udtf(5, TABLE (v))"),
[Row(a=6), Row(a=7)],
)
def test_udtf_with_table_argument_unknown_identifier(self):
- class TestUDTF:
- def eval(self, row: Row):
- if row["id"] > 5:
- yield row["id"],
-
- func = udtf(TestUDTF, returnType="a: int")
+ func = self.udtf_for_table_argument()
self.spark.udtf.register("test_udtf", func)
with self.assertRaisesRegex(AnalysisException,
"TABLE_OR_VIEW_NOT_FOUND"):
self.spark.sql("SELECT * FROM test_udtf(TABLE (v))").collect()
def test_udtf_with_table_argument_malformed_query(self):
- class TestUDTF:
- def eval(self, row: Row):
- if row["id"] > 5:
- yield row["id"],
-
- func = udtf(TestUDTF, returnType="a: int")
+ func = self.udtf_for_table_argument()
self.spark.udtf.register("test_udtf", func)
with self.assertRaisesRegex(AnalysisException,
"TABLE_OR_VIEW_NOT_FOUND"):
self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT * FROM
v))").collect()
def test_udtf_with_table_argument_cte_inside(self):
- class TestUDTF:
- def eval(self, row: Row):
- if row["id"] > 5:
- yield row["id"],
-
- func = udtf(TestUDTF, returnType="a: int")
+ func = self.udtf_for_table_argument()
self.spark.udtf.register("test_udtf", func)
- self.assertEqual(
+ assertDataFrameEqual(
self.spark.sql(
"""
SELECT * FROM test_udtf(TABLE (
@@ -1120,19 +1100,14 @@ class BaseUDTFTestsMixin:
SELECT * FROM t
))
"""
- ).collect(),
+ ),
[Row(a=6), Row(a=7)],
)
def test_udtf_with_table_argument_cte_outside(self):
- class TestUDTF:
- def eval(self, row: Row):
- if row["id"] > 5:
- yield row["id"],
-
- func = udtf(TestUDTF, returnType="a: int")
+ func = self.udtf_for_table_argument()
self.spark.udtf.register("test_udtf", func)
- self.assertEqual(
+ assertDataFrameEqual(
self.spark.sql(
"""
WITH t AS (
@@ -1140,11 +1115,11 @@ class BaseUDTFTestsMixin:
)
SELECT * FROM test_udtf(TABLE (SELECT id FROM t))
"""
- ).collect(),
+ ),
[Row(a=6), Row(a=7)],
)
- self.assertEqual(
+ assertDataFrameEqual(
self.spark.sql(
"""
WITH t AS (
@@ -1152,28 +1127,23 @@ class BaseUDTFTestsMixin:
)
SELECT * FROM test_udtf(TABLE (t))
"""
- ).collect(),
+ ),
[Row(a=6), Row(a=7)],
)
# TODO(SPARK-44233): Fix the subquery resolution.
@unittest.skip("Fails to resolve the subquery.")
def test_udtf_with_table_argument_lateral_join(self):
- class TestUDTF:
- def eval(self, row: Row):
- if row["id"] > 5:
- yield row["id"],
-
- func = udtf(TestUDTF, returnType="a: int")
+ func = self.udtf_for_table_argument()
self.spark.udtf.register("test_udtf", func)
- self.assertEqual(
+ assertDataFrameEqual(
self.spark.sql(
"""
SELECT * FROM
range(0, 8) AS t,
LATERAL test_udtf(TABLE (t))
"""
- ).collect(),
+ ),
[Row(a=6), Row(a=7)],
)
@@ -1198,8 +1168,8 @@ class BaseUDTFTestsMixin:
self.spark.sql(query).collect()
with
self.sql_conf({"spark.sql.tvf.allowMultipleTableArguments.enabled": True}):
- self.assertEqual(
- self.spark.sql(query).collect(),
+ assertDataFrameEqual(
+ self.spark.sql(query),
[
Row(a=0, b=0),
Row(a=1, b=0),
@@ -2539,8 +2509,10 @@ class BaseUDTFTestsMixin:
yield i, v.toJson()
self.spark.udtf.register("test_udtf", TestUDTF)
- rows = self.spark.sql('select i, s from
test_udtf(parse_json(\'{"a":"b"}\'))').collect()
- self.assertEqual(rows, [Row(i=n, s='{"a":"b"}') for n in range(10)])
+ assertDataFrameEqual(
+ self.spark.sql('select i, s from
test_udtf(parse_json(\'{"a":"b"}\'))'),
+ [Row(i=n, s='{"a":"b"}') for n in range(10)],
+ )
def test_udtf_with_nested_variant_input(self):
# struct<variant>
@@ -2551,10 +2523,10 @@ class BaseUDTFTestsMixin:
yield i, v["v"].toJson()
self.spark.udtf.register("test_udtf_struct", TestUDTFStruct)
- rows = self.spark.sql(
+ res = self.spark.sql(
"select i, s from test_udtf_struct(named_struct('v',
parse_json('{\"a\":\"c\"}')))"
- ).collect()
- self.assertEqual(rows, [Row(i=n, s='{"a":"c"}') for n in range(10)])
+ )
+ assertDataFrameEqual(res, [Row(i=n, s='{"a":"c"}') for n in range(10)])
# array<variant>
@udtf(returnType="i int, s: string")
@@ -2564,10 +2536,8 @@ class BaseUDTFTestsMixin:
yield i, v[0].toJson()
self.spark.udtf.register("test_udtf_array", TestUDTFArray)
- rows = self.spark.sql(
- 'select i, s from
test_udtf_array(array(parse_json(\'{"a":"d"}\')))'
- ).collect()
- self.assertEqual(rows, [Row(i=n, s='{"a":"d"}') for n in range(10)])
+ res = self.spark.sql('select i, s from
test_udtf_array(array(parse_json(\'{"a":"d"}\')))')
+ assertDataFrameEqual(res, [Row(i=n, s='{"a":"d"}') for n in range(10)])
# map<string, variant>
@udtf(returnType="i int, s: string")
@@ -2577,10 +2547,10 @@ class BaseUDTFTestsMixin:
yield i, v["v"].toJson()
self.spark.udtf.register("test_udtf_map", TestUDTFMap)
- rows = self.spark.sql(
+ res = self.spark.sql(
"select i, s from test_udtf_map(map('v',
parse_json('{\"a\":\"e\"}')))"
- ).collect()
- self.assertEqual(rows, [Row(i=n, s='{"a":"e"}') for n in range(10)])
+ )
+ assertDataFrameEqual(res, [Row(i=n, s='{"a":"e"}') for n in range(10)])
def test_udtf_with_variant_output(self):
@udtf(returnType="i int, v: variant")
@@ -2591,8 +2561,8 @@ class BaseUDTFTestsMixin:
yield i, VariantVal(bytes([2, 1, 0, 0, 2, 5, 97 + i]),
bytes([1, 1, 0, 1, 97]))
self.spark.udtf.register("test_udtf", TestUDTF)
- rows = self.spark.sql("select i, to_json(v) from
test_udtf(8)").collect()
- self.assertEqual(rows, [Row(i=n, s=f'{{"a":"{chr(97 + n)}"}}') for n
in range(8)])
+ res = self.spark.sql("select i, to_json(v) from test_udtf(8)")
+ assertDataFrameEqual(res, [Row(i=n, s=f'{{"a":"{chr(97 + n)}"}}') for
n in range(8)])
def test_udtf_with_nested_variant_output(self):
# struct<variant>
@@ -2606,8 +2576,8 @@ class BaseUDTFTestsMixin:
}
self.spark.udtf.register("test_udtf_struct", TestUDTFStruct)
- rows = self.spark.sql("select i, to_json(v.v1) from
test_udtf_struct(8)").collect()
- self.assertEqual(rows, [Row(i=n, s=f'{{"a":"{chr(97 + n)}"}}') for n
in range(8)])
+ res = self.spark.sql("select i, to_json(v.v1) from
test_udtf_struct(8)")
+ assertDataFrameEqual(res, [Row(i=n, s=f'{{"a":"{chr(97 + n)}"}}') for
n in range(8)])
# array<variant>
@udtf(returnType="i int, v: array<variant>")
@@ -2620,8 +2590,8 @@ class BaseUDTFTestsMixin:
]
self.spark.udtf.register("test_udtf_array", TestUDTFArray)
- rows = self.spark.sql("select i, to_json(v[0]) from
test_udtf_array(8)").collect()
- self.assertEqual(rows, [Row(i=n, s=f'{{"a":"{chr(98 + n)}"}}') for n
in range(8)])
+ res = self.spark.sql("select i, to_json(v[0]) from test_udtf_array(8)")
+ assertDataFrameEqual(res, [Row(i=n, s=f'{{"a":"{chr(98 + n)}"}}') for
n in range(8)])
# map<string, variant>
@udtf(returnType="i int, v: map<string, variant>")
@@ -2634,8 +2604,8 @@ class BaseUDTFTestsMixin:
}
self.spark.udtf.register("test_udtf_struct", TestUDTFStruct)
- rows = self.spark.sql("select i, to_json(v['v1']) from
test_udtf_struct(8)").collect()
- self.assertEqual(rows, [Row(i=n, s=f'{{"a":"{chr(99 + n)}"}}') for n
in range(8)])
+ res = self.spark.sql("select i, to_json(v['v1']) from
test_udtf_struct(8)")
+ assertDataFrameEqual(res, [Row(i=n, s=f'{{"a":"{chr(99 + n)}"}}') for
n in range(8)])
class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]