Github user icexelloss commented on a diff in the pull request:
https://github.com/apache/spark/pull/18732#discussion_r143740157
--- Diff: python/pyspark/sql/tests.py ---
@@ -3376,6 +3376,151 @@ def test_vectorized_udf_empty_partition(self):
res = df.select(f(col('id')))
self.assertEquals(df.collect(), res.collect())
+ def test_vectorized_udf_varargs(self):
+ from pyspark.sql.functions import pandas_udf, col
+ df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
+ f = pandas_udf(lambda *v: v[0], LongType())
+ res = df.select(f(col('id')))
+ self.assertEquals(df.collect(), res.collect())
+
+
+@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+class GroupbyApplyTests(ReusedPySparkTestCase):
+ @classmethod
+ def setUpClass(cls):
+ ReusedPySparkTestCase.setUpClass()
+ cls.spark = SparkSession(cls.sc)
+
+ @classmethod
+ def tearDownClass(cls):
+ ReusedPySparkTestCase.tearDownClass()
+ cls.spark.stop()
+
+ def assertFramesEqual(self, expected, result):
+ msg = ("DataFrames are not equal: " +
+ ("\n\nExpected:\n%s\n%s" % (expected, expected.dtypes)) +
+ ("\n\nResult:\n%s\n%s" % (result, result.dtypes)))
+ self.assertTrue(expected.equals(result), msg=msg)
+
+ @property
+ def data(self):
+ from pyspark.sql.functions import array, explode, col, lit
+ return self.spark.range(10).toDF('id') \
+ .withColumn("vs", array([lit(i) for i in range(20, 30)])) \
+ .withColumn("v", explode(col('vs'))).drop('vs')
+
+ def test_simple(self):
+ from pyspark.sql.functions import pandas_udf
+ df = self.data
+
+ foo_udf = pandas_udf(
+ lambda df: df.assign(v1=df.v * df.id * 1.0, v2=df.v + df.id),
+ StructType(
+ [StructField('id', LongType()),
+ StructField('v', IntegerType()),
+ StructField('v1', DoubleType()),
+ StructField('v2', LongType())]))
+
+ result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
+ expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
+ self.assertFramesEqual(expected, result)
+
+ def test_decorator(self):
+ from pyspark.sql.functions import pandas_udf
+ df = self.data
+
+ @pandas_udf(StructType(
+ [StructField('id', LongType()),
+ StructField('v', IntegerType()),
+ StructField('v1', DoubleType()),
+ StructField('v2', LongType())]))
+ def foo(df):
+ return df.assign(v1=df.v * df.id * 1.0, v2=df.v + df.id)
+
+ result = df.groupby('id').apply(foo).sort('id').toPandas()
+ expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True)
+ self.assertFramesEqual(expected, result)
+
+ def test_coerce(self):
+ from pyspark.sql.functions import pandas_udf
+ df = self.data
+
+ foo = pandas_udf(
+ lambda df: df,
+ StructType([StructField('id', LongType()), StructField('v', DoubleType())]))
+
+ result = df.groupby('id').apply(foo).sort('id').toPandas()
+ expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True)
+ expected = expected.assign(v=expected.v.astype('float64'))
+ self.assertFramesEqual(expected, result)
+
+ def test_complex_groupby(self):
+ from pyspark.sql.functions import pandas_udf, col
+ df = self.data
+
+ @pandas_udf(StructType(
+ [StructField('id', LongType()),
+ StructField('v', IntegerType()),
+ StructField('norm', DoubleType())]))
+ def normalize(pdf):
+ v = pdf.v
+ return pdf.assign(norm=(v - v.mean()) / v.std())
+
+ result = df.groupby(col('id') % 2 == 0).apply(normalize).sort('id', 'v').toPandas()
+ pdf = df.toPandas()
+ expected = pdf.groupby(pdf['id'] % 2 == 0).apply(normalize.func)
+ expected = expected.sort_values(['id', 'v']).reset_index(drop=True)
+ expected = expected.assign(norm=expected.norm.astype('float64'))
+ self.assertFramesEqual(expected, result)
+
+ def test_empty_groupby(self):
+ from pyspark.sql.functions import pandas_udf, col
+ df = self.data
+
+ @pandas_udf(StructType(
+ [StructField('id', LongType()),
+ StructField('v', IntegerType()),
+ StructField('norm', DoubleType())]))
+ def normalize(pdf):
+ v = pdf.v
+ return pdf.assign(norm=(v - v.mean()) / v.std())
+
+ result = df.groupby().apply(normalize).sort('id', 'v').toPandas()
+ pdf = df.toPandas()
+ expected = normalize.func(pdf)
+ expected = expected.sort_values(['id', 'v']).reset_index(drop=True)
+ expected = expected.assign(norm=expected.norm.astype('float64'))
+ self.assertFramesEqual(expected, result)
+
+ def test_wrong_return_type(self):
+ from pyspark.sql.functions import pandas_udf
+ df = self.data
+
+ foo = pandas_udf(
+ lambda df: df,
--- End diff --
Fixed.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]