Repository: spark Updated Branches: refs/heads/master ba84bcb2c -> 10f2b6fa0
[SPARK-23555][PYTHON] Add BinaryType support for Arrow in Python ## What changes were proposed in this pull request? Adding `BinaryType` support for Arrow in pyspark, conditional on using pyarrow >= 0.10.0. Earlier versions will continue to raise a TypeError. ## How was this patch tested? Additional unit tests in pyspark for code paths that use Arrow for createDataFrame, toPandas, and scalar pandas_udfs. Closes #20725 from BryanCutler/arrow-binary-type-support-SPARK-23555. Authored-by: Bryan Cutler <cutl...@gmail.com> Signed-off-by: Bryan Cutler <cutl...@gmail.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/10f2b6fa Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/10f2b6fa Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/10f2b6fa Branch: refs/heads/master Commit: 10f2b6fa05f3d977f3b6099fcd94c5c0cd97a0cb Parents: ba84bcb Author: Bryan Cutler <cutl...@gmail.com> Authored: Fri Aug 17 22:14:42 2018 -0700 Committer: Bryan Cutler <cutl...@gmail.com> Committed: Fri Aug 17 22:14:42 2018 -0700 ---------------------------------------------------------------------- python/pyspark/sql/tests.py | 66 +++++++++++++++++++++++++++++++++------- python/pyspark/sql/types.py | 15 +++++++++ 2 files changed, 70 insertions(+), 11 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/10f2b6fa/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 91ed600..00d7e18 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4050,6 +4050,8 @@ class ArrowTests(ReusedSQLTestCase): def setUpClass(cls): from datetime import date, datetime from decimal import Decimal + from distutils.version import LooseVersion + import pyarrow as pa ReusedSQLTestCase.setUpClass() # Synchronize default timezone between Python and Java @@ -4078,6 +4080,13 @@ class ArrowTests(ReusedSQLTestCase): (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"), date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + # TODO: remove version check once minimum pyarrow version is 0.10.0 + if LooseVersion("0.10.0") <= LooseVersion(pa.__version__): + cls.schema.add(StructField("9_binary_t", BinaryType(), True)) + cls.data[0] = cls.data[0] + (bytearray(b"a"),) + cls.data[1] = cls.data[1] + (bytearray(b"bb"),) + cls.data[2] = cls.data[2] + (bytearray(b"ccc"),) + @classmethod def tearDownClass(cls): del os.environ["TZ"] @@ -4115,12 +4124,23 @@ class ArrowTests(ReusedSQLTestCase): self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) def test_toPandas_fallback_disabled(self): + from distutils.version import LooseVersion + import pyarrow as pa + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported type'): df.toPandas() + # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + schema = StructType([StructField("binary", BinaryType(), True)]) + df = self.spark.createDataFrame([(None,)], schema=schema) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported type.*BinaryType'): + df.toPandas() + def test_null_conversion(self): df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + self.data) @@ -4232,19 +4252,22 @@ class ArrowTests(ReusedSQLTestCase): def test_createDataFrame_with_incorrect_schema(self): pdf = self.create_pandas_data_frame() - wrong_schema = StructType(list(reversed(self.schema))) + fields = list(self.schema) + fields[0], fields[7] = fields[7], fields[0] # swap str with timestamp + wrong_schema = StructType(fields) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, ".*No cast.*string.*timestamp.*"): self.spark.createDataFrame(pdf, schema=wrong_schema) def test_createDataFrame_with_names(self): pdf = self.create_pandas_data_frame() + new_names = list(map(str, range(len(self.schema.fieldNames())))) # Test that schema as a list of column names gets applied - df = self.spark.createDataFrame(pdf, schema=list('abcdefgh')) - self.assertEquals(df.schema.fieldNames(), list('abcdefgh')) + df = self.spark.createDataFrame(pdf, schema=list(new_names)) + self.assertEquals(df.schema.fieldNames(), new_names) # Test that schema as tuple of column names gets applied - df = self.spark.createDataFrame(pdf, schema=tuple('abcdefgh')) - self.assertEquals(df.schema.fieldNames(), list('abcdefgh')) + df = self.spark.createDataFrame(pdf, schema=tuple(new_names)) + self.assertEquals(df.schema.fieldNames(), new_names) def test_createDataFrame_column_name_encoding(self): import pandas as pd @@ -4331,13 +4354,22 @@ class ArrowTests(ReusedSQLTestCase): self.assertEqual(df.collect(), [Row(a={u'a': 1})]) def test_createDataFrame_fallback_disabled(self): + from distutils.version import LooseVersion import pandas as pd + import pyarrow as pa with QuietTest(self.sc): with self.assertRaisesRegexp(TypeError, 'Unsupported type'): self.spark.createDataFrame( pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>") + # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + with QuietTest(self.sc): + with self.assertRaisesRegexp(TypeError, 'Unsupported type.*BinaryType'): + self.spark.createDataFrame( + pd.DataFrame([[{'a': b'aaa'}]]), "a: binary") + # Regression test for SPARK-23314 def test_timestamp_dst(self): import pandas as pd @@ -4729,6 +4761,24 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): bool_f(col('bool'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_null_binary(self): + from distutils.version import LooseVersion + import pyarrow as pa + from pyspark.sql.functions import pandas_udf, col + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + with QuietTest(self.sc): + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*BinaryType'): + pandas_udf(lambda x: x, BinaryType()) + else: + data = [(bytearray(b"a"),), (None,), (bytearray(b"bb"),), (bytearray(b"ccc"),)] + schema = StructType().add("binary", BinaryType()) + df = self.spark.createDataFrame(data, schema) + str_f = pandas_udf(lambda x: x, BinaryType()) + res = df.select(str_f(col('binary'))) + self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_array_type(self): from pyspark.sql.functions import pandas_udf, col data = [([1, 2],), ([3, 4],)] @@ -4835,12 +4885,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): 'Invalid returnType.*scalar Pandas UDF.*MapType'): pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) - with QuietTest(self.sc): - with self.assertRaisesRegexp( - NotImplementedError, - 'Invalid returnType.*scalar Pandas UDF.*BinaryType'): - pandas_udf(lambda x: x, BinaryType()) - def test_vectorized_udf_dates(self): from pyspark.sql.functions import pandas_udf, col from datetime import date http://git-wip-us.apache.org/repos/asf/spark/blob/10f2b6fa/python/pyspark/sql/types.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 214d8fe..0b61707 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1578,6 +1578,7 @@ register_input_converter(DateConverter()) def to_arrow_type(dt): """ Convert Spark data type to pyarrow type """ + from distutils.version import LooseVersion import pyarrow as pa if type(dt) == BooleanType: arrow_type = pa.bool_() @@ -1597,6 +1598,12 @@ def to_arrow_type(dt): arrow_type = pa.decimal128(dt.precision, dt.scale) elif type(dt) == StringType: arrow_type = pa.string() + elif type(dt) == BinaryType: + # TODO: remove version check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + raise TypeError("Unsupported type in conversion to Arrow: " + str(dt) + + "\nPlease install pyarrow >= 0.10.0 for BinaryType support.") + arrow_type = pa.binary() elif type(dt) == DateType: arrow_type = pa.date32() elif type(dt) == TimestampType: @@ -1623,6 +1630,8 @@ def to_arrow_schema(schema): def from_arrow_type(at): """ Convert pyarrow type to Spark data type. """ + from distutils.version import LooseVersion + import pyarrow as pa import pyarrow.types as types if types.is_boolean(at): spark_type = BooleanType() @@ -1642,6 +1651,12 @@ def from_arrow_type(at): spark_type = DecimalType(precision=at.precision, scale=at.scale) elif types.is_string(at): spark_type = StringType() + elif types.is_binary(at): + # TODO: remove version check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + raise TypeError("Unsupported type in conversion from Arrow: " + str(at) + + "\nPlease install pyarrow >= 0.10.0 for BinaryType support.") + spark_type = BinaryType() elif types.is_date32(at): spark_type = DateType() elif types.is_timestamp(at): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org