Github user viirya commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19147#discussion_r137945084
  
    --- Diff: python/pyspark/sql/tests.py ---
    @@ -3122,6 +3124,147 @@ def test_filtered_frame(self):
             self.assertTrue(pdf.empty)
     
     
    +@unittest.skipIf(not _have_arrow, "Arrow not installed")
    +class VectorizedUDFTests(ReusedPySparkTestCase):
    +
    +    @classmethod
    +    def setUpClass(cls):
    +        ReusedPySparkTestCase.setUpClass()
    +        cls.spark = SparkSession(cls.sc)
    +
    +    @classmethod
    +    def tearDownClass(cls):
    +        ReusedPySparkTestCase.tearDownClass()
    +        cls.spark.stop()
    +
    +    def test_vectorized_udf_basic(self):
    +        df = self.spark.range(10).select(
    +            col('id').cast('string').alias('str'),
    +            col('id').cast('int').alias('int'),
    +            col('id').alias('long'),
    +            col('id').cast('float').alias('float'),
    +            col('id').cast('double').alias('double'),
    +            col('id').cast('boolean').alias('bool'))
    +        f = lambda x: x
    +        str_f = pandas_udf(f, StringType())
    +        int_f = pandas_udf(f, IntegerType())
    +        long_f = pandas_udf(f, LongType())
    +        float_f = pandas_udf(f, FloatType())
    +        double_f = pandas_udf(f, DoubleType())
    +        bool_f = pandas_udf(f, BooleanType())
    +        res = df.select(str_f(col('str')), int_f(col('int')),
    +                        long_f(col('long')), float_f(col('float')),
    +                        double_f(col('double')), bool_f(col('bool')))
    +        self.assertEquals(df.collect(), res.collect())
    +
    +    def test_vectorized_udf_null_boolean(self):
    +        data = [(True,), (True,), (None,), (False,)]
    +        schema = StructType().add("bool", BooleanType())
    +        df = self.spark.createDataFrame(data, schema)
    +        bool_f = pandas_udf(lambda x: x, BooleanType())
    +        res = df.select(bool_f(col('bool')))
    +        self.assertEquals(df.collect(), res.collect())
    +
    +    def test_vectorized_udf_null_byte(self):
    +        data = [(None,), (2,), (3,), (4,)]
    +        schema = StructType().add("byte", ByteType())
    +        df = self.spark.createDataFrame(data, schema)
    +        byte_f = pandas_udf(lambda x: x, ByteType())
    +        res = df.select(byte_f(col('byte')))
    +        self.assertEquals(df.collect(), res.collect())
    +
    +    def test_vectorized_udf_null_short(self):
    +        data = [(None,), (2,), (3,), (4,)]
    +        schema = StructType().add("short", ShortType())
    +        df = self.spark.createDataFrame(data, schema)
    +        short_f = pandas_udf(lambda x: x, ShortType())
    +        res = df.select(short_f(col('short')))
    +        self.assertEquals(df.collect(), res.collect())
    +
    +    def test_vectorized_udf_null_int(self):
    +        data = [(None,), (2,), (3,), (4,)]
    +        schema = StructType().add("int", IntegerType())
    +        df = self.spark.createDataFrame(data, schema)
    +        int_f = pandas_udf(lambda x: x, IntegerType())
    +        res = df.select(int_f(col('int')))
    +        self.assertEquals(df.collect(), res.collect())
    +
    +    def test_vectorized_udf_null_long(self):
    +        data = [(None,), (2,), (3,), (4,)]
    +        schema = StructType().add("long", LongType())
    +        df = self.spark.createDataFrame(data, schema)
    +        long_f = pandas_udf(lambda x: x, LongType())
    +        res = df.select(long_f(col('long')))
    +        self.assertEquals(df.collect(), res.collect())
    +
    +    def test_vectorized_udf_null_float(self):
    +        data = [(3.0,), (5.0,), (-1.0,), (None,)]
    +        schema = StructType().add("float", FloatType())
    +        df = self.spark.createDataFrame(data, schema)
    +        float_f = pandas_udf(lambda x: x, FloatType())
    +        res = df.select(float_f(col('float')))
    +        self.assertEquals(df.collect(), res.collect())
    +
    +    def test_vectorized_udf_null_double(self):
    +        data = [(3.0,), (5.0,), (-1.0,), (None,)]
    +        schema = StructType().add("double", DoubleType())
    +        df = self.spark.createDataFrame(data, schema)
    +        double_f = pandas_udf(lambda x: x, DoubleType())
    +        res = df.select(double_f(col('double')))
    +        self.assertEquals(df.collect(), res.collect())
    +
    +    def test_vectorized_udf_null_string(self):
    +        data = [("foo",), (None,), ("bar",), ("bar",)]
    +        schema = StructType().add("str", StringType())
    +        df = self.spark.createDataFrame(data, schema)
    +        str_f = pandas_udf(lambda x: x, StringType())
    +        res = df.select(str_f(col('str')))
    +        self.assertEquals(df.collect(), res.collect())
    +
    +    def test_vectorized_udf_zero_parameter(self):
    +        import pandas as pd
    +        df = self.spark.range(100000)
    +        f0 = pandas_udf(lambda size: pd.Series(1).repeat(size), LongType())
    +        res = df.select(f0())
    +        self.assertEquals(df.select(lit(1)).collect(), res.collect())
    +
    +    def test_vectorized_udf_datatype_string(self):
    +        import pandas as pd
    +        df = self.spark.range(100000)
    +        f0 = pandas_udf(lambda size: pd.Series(1).repeat(size), "long")
    +        res = df.select(f0())
    +        self.assertEquals(df.select(lit(1)).collect(), res.collect())
    +
    +    def test_vectorized_udf_complex(self):
    +        df = self.spark.range(10).select(
    +            col('id').cast('int').alias('a'),
    +            col('id').cast('int').alias('b'),
    +            col('id').cast('double').alias('c'))
    +        add = pandas_udf(lambda x, y: x + y, IntegerType())
    +        power2 = pandas_udf(lambda x: 2 ** x, IntegerType())
    +        mul = pandas_udf(lambda x, y: x * y, DoubleType())
    +        res = df.select(add(col('a'), col('b')), power2(col('a')), mul(col('b'), col('c')))
    +        expected = df.select(expr('a + b'), expr('power(2, a)'), expr('b * c'))
    +        self.assertEquals(expected.collect(), res.collect())
    +
    +    def test_vectorized_udf_exception(self):
    +        df = self.spark.range(10)
    +        raise_exception = pandas_udf(lambda x: x * (1 / 0), LongType())
    +        with QuietTest(self.sc):
    +            with self.assertRaisesRegexp(Exception, 'division( or modulo)? by zero'):
    +                df.select(raise_exception(col('id'))).collect()
    +
    +    def test_vectorized_udf_invalid_length(self):
    +        import pandas as pd
    +        df = self.spark.range(10)
    +        raise_exception = pandas_udf(lambda size: pd.Series(1), LongType())
    +        with QuietTest(self.sc):
    +            with self.assertRaisesRegexp(
    +                    Exception,
    +                    'The length of returned value should be the same as input value'):
    +                df.select(raise_exception()).collect()
    --- End diff --
    
    Should we also add a test that mixes a regular udf with a vectorized udf?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to