Github user BryanCutler commented on a diff in the pull request:
https://github.com/apache/spark/pull/21650#discussion_r202867741
--- Diff: python/pyspark/sql/tests.py ---
@@ -5060,6 +5049,144 @@ def test_type_annotation(self):
df = self.spark.range(1).select(pandas_udf(f=_locals['noop'],
returnType='bigint')('id'))
self.assertEqual(df.first()[0], 0)
+ def test_mixed_udf(self):
+ import pandas as pd
+ from pyspark.sql.functions import udf, pandas_udf
+
+ df = self.spark.range(0, 1).toDF('v')
+
+ # Test mixture of multiple UDFs and Pandas UDFs
+
+ @udf('int')
+ def f1(x):
+ assert type(x) == int
+ return x + 1
+
+ @pandas_udf('int')
+ def f2(x):
+ assert type(x) == pd.Series
+ return x + 10
+
+ @udf('int')
+ def f3(x):
+ assert type(x) == int
+ return x + 100
+
+ @pandas_udf('int')
+ def f4(x):
+ assert type(x) == pd.Series
+ return x + 1000
+
+ # Test mixed udfs in a single projection
+ df1 = df.withColumn('f1', f1(df['v']))
+ df1 = df1.withColumn('f2', f2(df1['v']))
+ df1 = df1.withColumn('f3', f3(df1['v']))
+ df1 = df1.withColumn('f4', f4(df1['v']))
+ df1 = df1.withColumn('f2_f1', f2(df1['f1']))
+ df1 = df1.withColumn('f3_f1', f3(df1['f1']))
+ df1 = df1.withColumn('f4_f1', f4(df1['f1']))
+ df1 = df1.withColumn('f3_f2', f3(df1['f2']))
+ df1 = df1.withColumn('f4_f2', f4(df1['f2']))
+ df1 = df1.withColumn('f4_f3', f4(df1['f3']))
+ df1 = df1.withColumn('f3_f2_f1', f3(df1['f2_f1']))
+ df1 = df1.withColumn('f4_f2_f1', f4(df1['f2_f1']))
+ df1 = df1.withColumn('f4_f3_f1', f4(df1['f3_f1']))
+ df1 = df1.withColumn('f4_f3_f2', f4(df1['f3_f2']))
+ df1 = df1.withColumn('f4_f3_f2_f1', f4(df1['f3_f2_f1']))
+
+ # Test mixed udfs in a single expression
+ df2 = df.withColumn('f1', f1(df['v']))
+ df2 = df2.withColumn('f2', f2(df['v']))
+ df2 = df2.withColumn('f3', f3(df['v']))
+ df2 = df2.withColumn('f4', f4(df['v']))
+ df2 = df2.withColumn('f2_f1', f2(f1(df['v'])))
+ df2 = df2.withColumn('f3_f1', f3(f1(df['v'])))
+ df2 = df2.withColumn('f4_f1', f4(f1(df['v'])))
+ df2 = df2.withColumn('f3_f2', f3(f2(df['v'])))
+ df2 = df2.withColumn('f4_f2', f4(f2(df['v'])))
+ df2 = df2.withColumn('f4_f3', f4(f3(df['v'])))
+ df2 = df2.withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
+ df2 = df2.withColumn('f4_f2_f1', f4(f2(f1(df['v']))))
+ df2 = df2.withColumn('f4_f3_f1', f4(f3(f1(df['v']))))
+ df2 = df2.withColumn('f4_f3_f2', f4(f3(f2(df['v']))))
+ df2 = df2.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v'])))))
+
+ # expected result
+ df3 = df.withColumn('f1', df['v'] + 1)
+ df3 = df3.withColumn('f2', df['v'] + 10)
+ df3 = df3.withColumn('f3', df['v'] + 100)
+ df3 = df3.withColumn('f4', df['v'] + 1000)
+ df3 = df3.withColumn('f2_f1', df['v'] + 11)
+ df3 = df3.withColumn('f3_f1', df['v'] + 101)
+ df3 = df3.withColumn('f4_f1', df['v'] + 1001)
+ df3 = df3.withColumn('f3_f2', df['v'] + 110)
+ df3 = df3.withColumn('f4_f2', df['v'] + 1010)
+ df3 = df3.withColumn('f4_f3', df['v'] + 1100)
+ df3 = df3.withColumn('f3_f2_f1', df['v'] + 111)
+ df3 = df3.withColumn('f4_f2_f1', df['v'] + 1011)
+ df3 = df3.withColumn('f4_f3_f1', df['v'] + 1101)
+ df3 = df3.withColumn('f4_f3_f2', df['v'] + 1110)
+ df3 = df3.withColumn('f4_f3_f2_f1', df['v'] + 1111)
+
+ self.assertEquals(df3.collect(), df1.collect())
+ self.assertEquals(df3.collect(), df2.collect())
+
+ def test_mixed_udf_and_sql(self):
+ import pandas as pd
+ from pyspark.sql.functions import udf, pandas_udf
+
+ df = self.spark.range(0, 1).toDF('v')
+
+ # Test mixture of UDFs, Pandas UDFs and SQL expression.
+
+ @udf('int')
+ def f1(x):
+ assert type(x) == int
+ return x + 1
+
+ def f2(x):
+ return x + 10
+
+ @pandas_udf('int')
+ def f3(x):
+ assert type(x) == pd.Series
+ return x + 100
+
+ df1 = df.withColumn('f1', f1(df['v']))
+ df1 = df1.withColumn('f2', f2(df['v']))
+ df1 = df1.withColumn('f3', f3(df['v']))
+ df1 = df1.withColumn('f1_f2', f1(f2(df['v'])))
+ df1 = df1.withColumn('f1_f3', f1(f3(df['v'])))
+ df1 = df1.withColumn('f2_f1', f2(f1(df['v'])))
+ df1 = df1.withColumn('f2_f3', f2(f3(df['v'])))
+ df1 = df1.withColumn('f3_f1', f3(f1(df['v'])))
+ df1 = df1.withColumn('f3_f2', f3(f2(df['v'])))
+ df1 = df1.withColumn('f1_f2_f3', f1(f2(f3(df['v']))))
+ df1 = df1.withColumn('f1_f3_f2', f1(f3(f2(df['v']))))
+ df1 = df1.withColumn('f2_f1_f3', f2(f1(f3(df['v']))))
+ df1 = df1.withColumn('f2_f3_f1', f2(f3(f1(df['v']))))
+ df1 = df1.withColumn('f3_f1_f2', f3(f1(f2(df['v']))))
+ df1 = df1.withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
+
+ # expected result
+ df2 = df.withColumn('f1', df['v'] + 1)
+ df2 = df2.withColumn('f2', df['v'] + 10)
+ df2 = df2.withColumn('f3', df['v'] + 100)
+ df2 = df2.withColumn('f1_f2', df['v'] + 11)
+ df2 = df2.withColumn('f1_f3', df['v'] + 101)
+ df2 = df2.withColumn('f2_f1', df['v'] + 11)
+ df2 = df2.withColumn('f2_f3', df['v'] + 110)
+ df2 = df2.withColumn('f3_f1', df['v'] + 101)
+ df2 = df2.withColumn('f3_f2', df['v'] + 110)
+ df2 = df2.withColumn('f1_f2_f3', df['v'] + 111)
+ df2 = df2.withColumn('f1_f3_f2', df['v'] + 111)
+ df2 = df2.withColumn('f2_f1_f3', df['v'] + 111)
+ df2 = df2.withColumn('f2_f3_f1', df['v'] + 111)
+ df2 = df2.withColumn('f3_f1_f2', df['v'] + 111)
+ df2 = df2.withColumn('f3_f2_f1', df['v'] + 111)
+
+ self.assertEquals(df2.collect(), df1.collect())
--- End diff --
I think it would be better to combine this test with the one above and
construct it as a list of cases that you could loop over instead of so many
blocks of `withColumn`s. Something like
```
class TestCase():
def __init__(self, col_name, col_expected, col_projection,
col_udf_expression, col_sql_expression):
...
cases = [
TestCase('f4_f3_f2_f1', df['v'] + 1111, f4(df1['f3_f2_f1']),
f4(f3(f2(f1(df['v'])))), f4(f3(f1(df['v']) + 10)))
...]
expected_df = df
for case in cases:
expected_df = expected_df.withColumn(case.col_name, case.col_expected)
....
self.assertEquals(expected_df.collect(), projection_df.collect())
```
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]