Github user viirya commented on a diff in the pull request:
https://github.com/apache/spark/pull/19147#discussion_r137945084
--- Diff: python/pyspark/sql/tests.py ---
@@ -3122,6 +3124,147 @@ def test_filtered_frame(self):
self.assertTrue(pdf.empty)
+@unittest.skipIf(not _have_arrow, "Arrow not installed")
+class VectorizedUDFTests(ReusedPySparkTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ ReusedPySparkTestCase.setUpClass()
+ cls.spark = SparkSession(cls.sc)
+
+ @classmethod
+ def tearDownClass(cls):
+ ReusedPySparkTestCase.tearDownClass()
+ cls.spark.stop()
+
+ def test_vectorized_udf_basic(self):
+ df = self.spark.range(10).select(
+ col('id').cast('string').alias('str'),
+ col('id').cast('int').alias('int'),
+ col('id').alias('long'),
+ col('id').cast('float').alias('float'),
+ col('id').cast('double').alias('double'),
+ col('id').cast('boolean').alias('bool'))
+ f = lambda x: x
+ str_f = pandas_udf(f, StringType())
+ int_f = pandas_udf(f, IntegerType())
+ long_f = pandas_udf(f, LongType())
+ float_f = pandas_udf(f, FloatType())
+ double_f = pandas_udf(f, DoubleType())
+ bool_f = pandas_udf(f, BooleanType())
+ res = df.select(str_f(col('str')), int_f(col('int')),
+ long_f(col('long')), float_f(col('float')),
+ double_f(col('double')), bool_f(col('bool')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_boolean(self):
+ data = [(True,), (True,), (None,), (False,)]
+ schema = StructType().add("bool", BooleanType())
+ df = self.spark.createDataFrame(data, schema)
+ bool_f = pandas_udf(lambda x: x, BooleanType())
+ res = df.select(bool_f(col('bool')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_byte(self):
+ data = [(None,), (2,), (3,), (4,)]
+ schema = StructType().add("byte", ByteType())
+ df = self.spark.createDataFrame(data, schema)
+ byte_f = pandas_udf(lambda x: x, ByteType())
+ res = df.select(byte_f(col('byte')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_short(self):
+ data = [(None,), (2,), (3,), (4,)]
+ schema = StructType().add("short", ShortType())
+ df = self.spark.createDataFrame(data, schema)
+ short_f = pandas_udf(lambda x: x, ShortType())
+ res = df.select(short_f(col('short')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_int(self):
+ data = [(None,), (2,), (3,), (4,)]
+ schema = StructType().add("int", IntegerType())
+ df = self.spark.createDataFrame(data, schema)
+ int_f = pandas_udf(lambda x: x, IntegerType())
+ res = df.select(int_f(col('int')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_long(self):
+ data = [(None,), (2,), (3,), (4,)]
+ schema = StructType().add("long", LongType())
+ df = self.spark.createDataFrame(data, schema)
+ long_f = pandas_udf(lambda x: x, LongType())
+ res = df.select(long_f(col('long')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_float(self):
+ data = [(3.0,), (5.0,), (-1.0,), (None,)]
+ schema = StructType().add("float", FloatType())
+ df = self.spark.createDataFrame(data, schema)
+ float_f = pandas_udf(lambda x: x, FloatType())
+ res = df.select(float_f(col('float')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_double(self):
+ data = [(3.0,), (5.0,), (-1.0,), (None,)]
+ schema = StructType().add("double", DoubleType())
+ df = self.spark.createDataFrame(data, schema)
+ double_f = pandas_udf(lambda x: x, DoubleType())
+ res = df.select(double_f(col('double')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_string(self):
+ data = [("foo",), (None,), ("bar",), ("bar",)]
+ schema = StructType().add("str", StringType())
+ df = self.spark.createDataFrame(data, schema)
+ str_f = pandas_udf(lambda x: x, StringType())
+ res = df.select(str_f(col('str')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_zero_parameter(self):
+ import pandas as pd
+ df = self.spark.range(100000)
+ f0 = pandas_udf(lambda size: pd.Series(1).repeat(size), LongType())
+ res = df.select(f0())
+ self.assertEquals(df.select(lit(1)).collect(), res.collect())
+
+ def test_vectorized_udf_datatype_string(self):
+ import pandas as pd
+ df = self.spark.range(100000)
+ f0 = pandas_udf(lambda size: pd.Series(1).repeat(size), "long")
+ res = df.select(f0())
+ self.assertEquals(df.select(lit(1)).collect(), res.collect())
+
+ def test_vectorized_udf_complex(self):
+ df = self.spark.range(10).select(
+ col('id').cast('int').alias('a'),
+ col('id').cast('int').alias('b'),
+ col('id').cast('double').alias('c'))
+ add = pandas_udf(lambda x, y: x + y, IntegerType())
+ power2 = pandas_udf(lambda x: 2 ** x, IntegerType())
+ mul = pandas_udf(lambda x, y: x * y, DoubleType())
+ res = df.select(add(col('a'), col('b')), power2(col('a')), mul(col('b'), col('c')))
+ expected = df.select(expr('a + b'), expr('power(2, a)'), expr('b * c'))
+ self.assertEquals(expected.collect(), res.collect())
+
+ def test_vectorized_udf_exception(self):
+ df = self.spark.range(10)
+ raise_exception = pandas_udf(lambda x: x * (1 / 0), LongType())
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(Exception, 'division( or modulo)? by zero'):
+ df.select(raise_exception(col('id'))).collect()
+
+ def test_vectorized_udf_invalid_length(self):
+ import pandas as pd
+ df = self.spark.range(10)
+ raise_exception = pandas_udf(lambda size: pd.Series(1), LongType())
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(
+ Exception,
+ 'The length of returned value should be the same as input value'):
+ df.select(raise_exception()).collect()
--- End diff ---
Also add a test for mixing a regular udf and a vectorized udf?
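Just as a rough sketch of what such a test could look like (not part of this diff; the test name and assertions are only illustrative, and it assumes the `pandas_udf`, `col`, `expr` and `LongType` names already used in the tests above):

    def test_vectorized_udf_mixed_with_udf(self):
        from pyspark.sql.functions import udf
        df = self.spark.range(10)
        # apply a plain row-at-a-time udf and a vectorized (pandas) udf in the same projection
        plain_f = udf(lambda x: x + 1, LongType())
        vectorized_f = pandas_udf(lambda x: x + 1, LongType())
        res = df.select(plain_f(col('id')), vectorized_f(col('id')))
        expected = df.select(expr('id + 1'), expr('id + 1'))
        self.assertEquals(expected.collect(), res.collect())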