Repository: spark
Updated Branches:
  refs/heads/branch-2.4 82990e5ef -> 426c2bd35


[SPARK-23401][PYTHON][TESTS] Add more data types for PandasUDFTests

## What changes were proposed in this pull request?
Add more data types to the grouped map Pandas UDF tests for PySpark SQL
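
For context, the class under test exercises grouped map Pandas UDFs: functions that receive one pandas DataFrame per group and must return a pandas DataFrame matching a declared schema. A minimal sketch of the pattern being tested (assuming a running `SparkSession` named `spark`; the column names here are illustrative, not taken from the patch):

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StructType, StructField, LongType, DoubleType

# Declared output schema; a grouped map UDF must return a pandas
# DataFrame whose columns match it.
schema = StructType([
    StructField('id', LongType()),
    StructField('v', DoubleType()),
])

@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def double_v(pdf):
    # pdf is the pandas DataFrame for one group.
    return pdf.assign(v=pdf.v * 2)

df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], schema)
df.groupby('id').apply(double_v).show()
```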

## How was this patch tested?
manual tests

Closes #22568 from AlexanderKoryagin/new_types_for_pandas_udf_tests.

Lead-authored-by: Aleksandr Koriagin <aleksandr_koria...@epam.com>
Co-authored-by: hyukjinkwon <gurwls...@apache.org>
Co-authored-by: Alexander Koryagin <alexanderkorya...@users.noreply.github.com>
Signed-off-by: hyukjinkwon <gurwls...@apache.org>
(cherry picked from commit 30f5d0f2ddfe56266ea81e4255f9b4f373dab237)
Signed-off-by: hyukjinkwon <gurwls...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/426c2bd3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/426c2bd3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/426c2bd3

Branch: refs/heads/branch-2.4
Commit: 426c2bd35937add1a26e77d2f2879f0e3f0c2f45
Parents: 82990e5
Author: Aleksandr Koriagin <aleksandr_koria...@epam.com>
Authored: Mon Oct 1 17:18:45 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Mon Oct 1 17:19:00 2018 +0800

----------------------------------------------------------------------
 python/pyspark/sql/tests.py | 107 +++++++++++++++++++++++++++++----------
 1 file changed, 79 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/426c2bd3/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index dece1da..690035a 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -5478,32 +5478,81 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
             .withColumn("v", explode(col('vs'))).drop('vs')
 
     def test_supported_types(self):
-        from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
-        df = self.data.withColumn("arr", array(col("id")))
+        from decimal import Decimal
+        from distutils.version import LooseVersion
+        import pyarrow as pa
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
 
-        # Different forms of group map pandas UDF, results of these are the same
+        values = [
+            1, 2, 3,
+            4, 5, 1.1,
+            2.2, Decimal(1.123),
+            [1, 2, 2], True, 'hello'
+        ]
+        output_fields = [
+            ('id', IntegerType()), ('byte', ByteType()), ('short', ShortType()),
+            ('int', IntegerType()), ('long', LongType()), ('float', FloatType()),
+            ('double', DoubleType()), ('decim', DecimalType(10, 3)),
+            ('array', ArrayType(IntegerType())), ('bool', BooleanType()), ('str', StringType())
+        ]
 
-        output_schema = StructType(
-            [StructField('id', LongType()),
-             StructField('v', IntegerType()),
-             StructField('arr', ArrayType(LongType())),
-             StructField('v1', DoubleType()),
-             StructField('v2', LongType())])
+        # TODO: Add BinaryType to variables above once minimum pyarrow version is 0.10.0
+        if LooseVersion(pa.__version__) >= LooseVersion("0.10.0"):
+            values.append(bytearray([0x01, 0x02]))
+            output_fields.append(('bin', BinaryType()))
 
+        output_schema = StructType([StructField(*x) for x in output_fields])
+        df = self.spark.createDataFrame([values], schema=output_schema)
+
+        # Different forms of group map pandas UDF, results of these are the same
         udf1 = pandas_udf(
-            lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
+            lambda pdf: pdf.assign(
+                byte=pdf.byte * 2,
+                short=pdf.short * 2,
+                int=pdf.int * 2,
+                long=pdf.long * 2,
+                float=pdf.float * 2,
+                double=pdf.double * 2,
+                decim=pdf.decim * 2,
+                bool=False if pdf.bool else True,
+                str=pdf.str + 'there',
+                array=pdf.array,
+            ),
             output_schema,
             PandasUDFType.GROUPED_MAP
         )
 
         udf2 = pandas_udf(
-            lambda _, pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
+            lambda _, pdf: pdf.assign(
+                byte=pdf.byte * 2,
+                short=pdf.short * 2,
+                int=pdf.int * 2,
+                long=pdf.long * 2,
+                float=pdf.float * 2,
+                double=pdf.double * 2,
+                decim=pdf.decim * 2,
+                bool=False if pdf.bool else True,
+                str=pdf.str + 'there',
+                array=pdf.array,
+            ),
             output_schema,
             PandasUDFType.GROUPED_MAP
         )
 
         udf3 = pandas_udf(
-            lambda key, pdf: pdf.assign(id=key[0], v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
+            lambda key, pdf: pdf.assign(
+                id=key[0],
+                byte=pdf.byte * 2,
+                short=pdf.short * 2,
+                int=pdf.int * 2,
+                long=pdf.long * 2,
+                float=pdf.float * 2,
+                double=pdf.double * 2,
+                decim=pdf.decim * 2,
+                bool=False if pdf.bool else True,
+                str=pdf.str + 'there',
+                array=pdf.array,
+            ),
             output_schema,
             PandasUDFType.GROUPED_MAP
         )
@@ -5667,24 +5716,26 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
                    pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR))
 
     def test_unsupported_types(self):
+        from distutils.version import LooseVersion
+        import pyarrow as pa
         from pyspark.sql.functions import pandas_udf, PandasUDFType
-        schema = StructType(
-            [StructField("id", LongType(), True),
-             StructField("map", MapType(StringType(), IntegerType()), True)])
-        with QuietTest(self.sc):
-            with self.assertRaisesRegexp(
-                    NotImplementedError,
-                    'Invalid returnType.*grouped map Pandas UDF.*MapType'):
-                pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)
 
-        schema = StructType(
-            [StructField("id", LongType(), True),
-             StructField("arr_ts", ArrayType(TimestampType()), True)])
-        with QuietTest(self.sc):
-            with self.assertRaisesRegexp(
-                    NotImplementedError,
-                    'Invalid returnType.*grouped map Pandas UDF.*ArrayType.*TimestampType'):
-                pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)
+        common_err_msg = 'Invalid returnType.*grouped map Pandas UDF.*'
+        unsupported_types = [
+            StructField('map', MapType(StringType(), IntegerType())),
+            StructField('arr_ts', ArrayType(TimestampType())),
+            StructField('null', NullType()),
+        ]
+
+        # TODO: Remove this if-statement once minimum pyarrow version is 0.10.0
+        if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
+            unsupported_types.append(StructField('bin', BinaryType()))
+
+        for unsupported_type in unsupported_types:
+            schema = StructType([StructField('id', LongType(), True), unsupported_type])
+            with QuietTest(self.sc):
+                with self.assertRaisesRegexp(NotImplementedError, common_err_msg):
+                    pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)
 
     # Regression test for SPARK-23314
     def test_timestamp_dst(self):

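A note on the pyarrow gating above: binary columns only convert through Arrow from pyarrow 0.10.0 onward, so the two tests move BinaryType between the supported and unsupported sides based on the installed version. A self-contained sketch of that symmetric gate (the baseline list contents are illustrative):

```python
from distutils.version import LooseVersion

import pyarrow as pa
from pyspark.sql.types import BinaryType, NullType, StringType

supported = [StringType()]   # illustrative baseline
unsupported = [NullType()]   # illustrative baseline

# Mirror the patch: pyarrow >= 0.10.0 can round-trip binary columns
# through Arrow, so BinaryType changes sides at that version boundary.
if LooseVersion(pa.__version__) >= LooseVersion("0.10.0"):
    supported.append(BinaryType())
else:
    unsupported.append(BinaryType())

print('supported:', supported)
print('unsupported:', unsupported)
```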
