Repository: spark Updated Branches: refs/heads/branch-2.4 82990e5ef -> 426c2bd35
[SPARK-23401][PYTHON][TESTS] Add more data types for PandasUDFTests ## What changes were proposed in this pull request? Add more data types for Pandas UDF Tests for PySpark SQL ## How was this patch tested? manual tests Closes #22568 from AlexanderKoryagin/new_types_for_pandas_udf_tests. Lead-authored-by: Aleksandr Koriagin <aleksandr_koria...@epam.com> Co-authored-by: hyukjinkwon <gurwls...@apache.org> Co-authored-by: Alexander Koryagin <alexanderkorya...@users.noreply.github.com> Signed-off-by: hyukjinkwon <gurwls...@apache.org> (cherry picked from commit 30f5d0f2ddfe56266ea81e4255f9b4f373dab237) Signed-off-by: hyukjinkwon <gurwls...@apache.org> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/426c2bd3 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/426c2bd3 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/426c2bd3 Branch: refs/heads/branch-2.4 Commit: 426c2bd35937add1a26e77d2f2879f0e3f0c2f45 Parents: 82990e5 Author: Aleksandr Koriagin <aleksandr_koria...@epam.com> Authored: Mon Oct 1 17:18:45 2018 +0800 Committer: hyukjinkwon <gurwls...@apache.org> Committed: Mon Oct 1 17:19:00 2018 +0800 ---------------------------------------------------------------------- python/pyspark/sql/tests.py | 107 +++++++++++++++++++++++++++++---------- 1 file changed, 79 insertions(+), 28 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/426c2bd3/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index dece1da..690035a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -5478,32 +5478,81 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase): .withColumn("v", explode(col('vs'))).drop('vs') def test_supported_types(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col - df = self.data.withColumn("arr", array(col("id"))) + from decimal import Decimal + from distutils.version import LooseVersion + import pyarrow as pa + from pyspark.sql.functions import pandas_udf, PandasUDFType - # Different forms of group map pandas UDF, results of these are the same + values = [ + 1, 2, 3, + 4, 5, 1.1, + 2.2, Decimal(1.123), + [1, 2, 2], True, 'hello' + ] + output_fields = [ + ('id', IntegerType()), ('byte', ByteType()), ('short', ShortType()), + ('int', IntegerType()), ('long', LongType()), ('float', FloatType()), + ('double', DoubleType()), ('decim', DecimalType(10, 3)), + ('array', ArrayType(IntegerType())), ('bool', BooleanType()), ('str', StringType()) + ] - output_schema = StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('arr', ArrayType(LongType())), - StructField('v1', DoubleType()), - StructField('v2', LongType())]) + # TODO: Add BinaryType to variables above once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) >= LooseVersion("0.10.0"): + values.append(bytearray([0x01, 0x02])) + output_fields.append(('bin', BinaryType())) + output_schema = StructType([StructField(*x) for x in output_fields]) + df = self.spark.createDataFrame([values], schema=output_schema) + + # Different forms of group map pandas UDF, results of these are the same udf1 = pandas_udf( - lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + lambda pdf: pdf.assign( + byte=pdf.byte * 2, + short=pdf.short * 2, + int=pdf.int * 2, + long=pdf.long * 2, + float=pdf.float * 2, + double=pdf.double * 2, + decim=pdf.decim * 2, + bool=False if pdf.bool else True, + str=pdf.str + 'there', + array=pdf.array, + ), output_schema, PandasUDFType.GROUPED_MAP ) udf2 = pandas_udf( - lambda _, pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + lambda _, pdf: pdf.assign( + byte=pdf.byte * 2, + short=pdf.short * 2, + int=pdf.int * 2, + long=pdf.long * 2, + float=pdf.float * 2, + double=pdf.double * 2, + decim=pdf.decim * 2, + bool=False if pdf.bool else True, + str=pdf.str + 'there', + array=pdf.array, + ), output_schema, PandasUDFType.GROUPED_MAP ) udf3 = pandas_udf( - lambda key, pdf: pdf.assign(id=key[0], v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + lambda key, pdf: pdf.assign( + id=key[0], + byte=pdf.byte * 2, + short=pdf.short * 2, + int=pdf.int * 2, + long=pdf.long * 2, + float=pdf.float * 2, + double=pdf.double * 2, + decim=pdf.decim * 2, + bool=False if pdf.bool else True, + str=pdf.str + 'there', + array=pdf.array, + ), output_schema, PandasUDFType.GROUPED_MAP ) @@ -5667,24 +5716,26 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase): pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR)) def test_unsupported_types(self): + from distutils.version import LooseVersion + import pyarrow as pa from pyspark.sql.functions import pandas_udf, PandasUDFType - schema = StructType( - [StructField("id", LongType(), True), - StructField("map", MapType(StringType(), IntegerType()), True)]) - with QuietTest(self.sc): - with self.assertRaisesRegexp( - NotImplementedError, - 'Invalid returnType.*grouped map Pandas UDF.*MapType'): - pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP) - schema = StructType( - [StructField("id", LongType(), True), - StructField("arr_ts", ArrayType(TimestampType()), True)]) - with QuietTest(self.sc): - with self.assertRaisesRegexp( - NotImplementedError, - 'Invalid returnType.*grouped map Pandas UDF.*ArrayType.*TimestampType'): - pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP) + common_err_msg = 'Invalid returnType.*grouped map Pandas UDF.*' + unsupported_types = [ + StructField('map', MapType(StringType(), IntegerType())), + StructField('arr_ts', ArrayType(TimestampType())), + StructField('null', NullType()), + ] + + # TODO: Remove this if-statement once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + unsupported_types.append(StructField('bin', BinaryType())) + + for unsupported_type in unsupported_types: + schema = StructType([StructField('id', LongType(), True), unsupported_type]) + with QuietTest(self.sc): + with self.assertRaisesRegexp(NotImplementedError, common_err_msg): + pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP) # Regression test for SPARK-23314 def test_timestamp_dst(self): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org