Repository: spark
Updated Branches:
  refs/heads/master ba84bcb2c -> 10f2b6fa0


[SPARK-23555][PYTHON] Add BinaryType support for Arrow in Python

## What changes were proposed in this pull request?

Adding `BinaryType` support for Arrow in pyspark, conditional on using pyarrow >= 0.10.0. Earlier versions will continue to raise a TypeError.
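
A minimal sketch of what this enables (assuming a local SparkSession with Arrow execution enabled and pyarrow >= 0.10.0 installed; the column name and data below are illustrative, not from this patch):

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, BinaryType

spark = SparkSession.builder.master("local[1]").getOrCreate()
# Use Arrow when converting between Spark DataFrames and pandas
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

schema = StructType([StructField("b", BinaryType(), True)])
df = spark.createDataFrame([(bytearray(b"abc"),), (None,)], schema)

# With pyarrow >= 0.10.0 the binary column round-trips through Arrow;
# with older pyarrow versions the Arrow path raises a TypeError instead.
pdf = df.toPandas()
df2 = spark.createDataFrame(pdf, schema)
```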

## How was this patch tested?

Additional unit tests in pyspark for code paths that use Arrow for 
createDataFrame, toPandas, and scalar pandas_udfs.
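
For the scalar pandas_udf path, a sketch of the kind of round-trip the new test exercises (names and data here are illustrative, not the exact test fixtures):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, col
from pyspark.sql.types import BinaryType

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(bytearray(b"a"),), (None,)], "b: binary")

# Identity scalar pandas_udf over a binary column; with pyarrow < 0.10.0,
# creating a pandas_udf with a BinaryType return type raises NotImplementedError.
identity = pandas_udf(lambda s: s, BinaryType())
df.select(identity(col("b"))).show()
```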

Closes #20725 from BryanCutler/arrow-binary-type-support-SPARK-23555.

Authored-by: Bryan Cutler <cutl...@gmail.com>
Signed-off-by: Bryan Cutler <cutl...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/10f2b6fa
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/10f2b6fa
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/10f2b6fa

Branch: refs/heads/master
Commit: 10f2b6fa05f3d977f3b6099fcd94c5c0cd97a0cb
Parents: ba84bcb
Author: Bryan Cutler <cutl...@gmail.com>
Authored: Fri Aug 17 22:14:42 2018 -0700
Committer: Bryan Cutler <cutl...@gmail.com>
Committed: Fri Aug 17 22:14:42 2018 -0700

----------------------------------------------------------------------
 python/pyspark/sql/tests.py | 66 +++++++++++++++++++++++++++++++++-------
 python/pyspark/sql/types.py | 15 +++++++++
 2 files changed, 70 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/10f2b6fa/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 91ed600..00d7e18 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -4050,6 +4050,8 @@ class ArrowTests(ReusedSQLTestCase):
     def setUpClass(cls):
         from datetime import date, datetime
         from decimal import Decimal
+        from distutils.version import LooseVersion
+        import pyarrow as pa
         ReusedSQLTestCase.setUpClass()
 
         # Synchronize default timezone between Python and Java
@@ -4078,6 +4080,13 @@ class ArrowTests(ReusedSQLTestCase):
                     (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"),
                      date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
 
+        # TODO: remove version check once minimum pyarrow version is 0.10.0
+        if LooseVersion("0.10.0") <= LooseVersion(pa.__version__):
+            cls.schema.add(StructField("9_binary_t", BinaryType(), True))
+            cls.data[0] = cls.data[0] + (bytearray(b"a"),)
+            cls.data[1] = cls.data[1] + (bytearray(b"bb"),)
+            cls.data[2] = cls.data[2] + (bytearray(b"ccc"),)
+
     @classmethod
     def tearDownClass(cls):
         del os.environ["TZ"]
@@ -4115,12 +4124,23 @@ class ArrowTests(ReusedSQLTestCase):
                    self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
 
     def test_toPandas_fallback_disabled(self):
+        from distutils.version import LooseVersion
+        import pyarrow as pa
+
         schema = StructType([StructField("map", MapType(StringType(), 
IntegerType()), True)])
         df = self.spark.createDataFrame([(None,)], schema=schema)
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(Exception, 'Unsupported type'):
                 df.toPandas()
 
+        # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0
+        if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
+            schema = StructType([StructField("binary", BinaryType(), True)])
+            df = self.spark.createDataFrame([(None,)], schema=schema)
+            with QuietTest(self.sc):
+                with self.assertRaisesRegexp(Exception, 'Unsupported type.*BinaryType'):
+                    df.toPandas()
+
     def test_null_conversion(self):
        df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
                                              self.data)
@@ -4232,19 +4252,22 @@ class ArrowTests(ReusedSQLTestCase):
 
     def test_createDataFrame_with_incorrect_schema(self):
         pdf = self.create_pandas_data_frame()
-        wrong_schema = StructType(list(reversed(self.schema)))
+        fields = list(self.schema)
+        fields[0], fields[7] = fields[7], fields[0]  # swap str with timestamp
+        wrong_schema = StructType(fields)
         with QuietTest(self.sc):
            with self.assertRaisesRegexp(Exception, ".*No cast.*string.*timestamp.*"):
                 self.spark.createDataFrame(pdf, schema=wrong_schema)
 
     def test_createDataFrame_with_names(self):
         pdf = self.create_pandas_data_frame()
+        new_names = list(map(str, range(len(self.schema.fieldNames()))))
         # Test that schema as a list of column names gets applied
-        df = self.spark.createDataFrame(pdf, schema=list('abcdefgh'))
-        self.assertEquals(df.schema.fieldNames(), list('abcdefgh'))
+        df = self.spark.createDataFrame(pdf, schema=list(new_names))
+        self.assertEquals(df.schema.fieldNames(), new_names)
         # Test that schema as tuple of column names gets applied
-        df = self.spark.createDataFrame(pdf, schema=tuple('abcdefgh'))
-        self.assertEquals(df.schema.fieldNames(), list('abcdefgh'))
+        df = self.spark.createDataFrame(pdf, schema=tuple(new_names))
+        self.assertEquals(df.schema.fieldNames(), new_names)
 
     def test_createDataFrame_column_name_encoding(self):
         import pandas as pd
@@ -4331,13 +4354,22 @@ class ArrowTests(ReusedSQLTestCase):
                     self.assertEqual(df.collect(), [Row(a={u'a': 1})])
 
     def test_createDataFrame_fallback_disabled(self):
+        from distutils.version import LooseVersion
         import pandas as pd
+        import pyarrow as pa
 
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(TypeError, 'Unsupported type'):
                 self.spark.createDataFrame(
                     pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
 
+        # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0
+        if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
+            with QuietTest(self.sc):
+            with self.assertRaisesRegexp(TypeError, 'Unsupported type.*BinaryType'):
+                    self.spark.createDataFrame(
+                        pd.DataFrame([[{'a': b'aaa'}]]), "a: binary")
+
     # Regression test for SPARK-23314
     def test_timestamp_dst(self):
         import pandas as pd
@@ -4729,6 +4761,24 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
                         bool_f(col('bool')))
         self.assertEquals(df.collect(), res.collect())
 
+    def test_vectorized_udf_null_binary(self):
+        from distutils.version import LooseVersion
+        import pyarrow as pa
+        from pyspark.sql.functions import pandas_udf, col
+        if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
+            with QuietTest(self.sc):
+                with self.assertRaisesRegexp(
+                        NotImplementedError,
+                        'Invalid returnType.*scalar Pandas UDF.*BinaryType'):
+                    pandas_udf(lambda x: x, BinaryType())
+        else:
+            data = [(bytearray(b"a"),), (None,), (bytearray(b"bb"),), (bytearray(b"ccc"),)]
+            schema = StructType().add("binary", BinaryType())
+            df = self.spark.createDataFrame(data, schema)
+            str_f = pandas_udf(lambda x: x, BinaryType())
+            res = df.select(str_f(col('binary')))
+            self.assertEquals(df.collect(), res.collect())
+
     def test_vectorized_udf_array_type(self):
         from pyspark.sql.functions import pandas_udf, col
         data = [([1, 2],), ([3, 4],)]
@@ -4835,12 +4885,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
                     'Invalid returnType.*scalar Pandas UDF.*MapType'):
                 pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
 
-        with QuietTest(self.sc):
-            with self.assertRaisesRegexp(
-                    NotImplementedError,
-                    'Invalid returnType.*scalar Pandas UDF.*BinaryType'):
-                pandas_udf(lambda x: x, BinaryType())
-
     def test_vectorized_udf_dates(self):
         from pyspark.sql.functions import pandas_udf, col
         from datetime import date

http://git-wip-us.apache.org/repos/asf/spark/blob/10f2b6fa/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 214d8fe..0b61707 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1578,6 +1578,7 @@ register_input_converter(DateConverter())
 def to_arrow_type(dt):
     """ Convert Spark data type to pyarrow type
     """
+    from distutils.version import LooseVersion
     import pyarrow as pa
     if type(dt) == BooleanType:
         arrow_type = pa.bool_()
@@ -1597,6 +1598,12 @@ def to_arrow_type(dt):
         arrow_type = pa.decimal128(dt.precision, dt.scale)
     elif type(dt) == StringType:
         arrow_type = pa.string()
+    elif type(dt) == BinaryType:
+        # TODO: remove version check once minimum pyarrow version is 0.10.0
+        if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
+            raise TypeError("Unsupported type in conversion to Arrow: " + str(dt) +
+                            "\nPlease install pyarrow >= 0.10.0 for BinaryType support.")
+        arrow_type = pa.binary()
     elif type(dt) == DateType:
         arrow_type = pa.date32()
     elif type(dt) == TimestampType:
@@ -1623,6 +1630,8 @@ def to_arrow_schema(schema):
 def from_arrow_type(at):
     """ Convert pyarrow type to Spark data type.
     """
+    from distutils.version import LooseVersion
+    import pyarrow as pa
     import pyarrow.types as types
     if types.is_boolean(at):
         spark_type = BooleanType()
@@ -1642,6 +1651,12 @@ def from_arrow_type(at):
         spark_type = DecimalType(precision=at.precision, scale=at.scale)
     elif types.is_string(at):
         spark_type = StringType()
+    elif types.is_binary(at):
+        # TODO: remove version check once minimum pyarrow version is 0.10.0
+        if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
+            raise TypeError("Unsupported type in conversion from Arrow: " + str(at) +
+                            "\nPlease install pyarrow >= 0.10.0 for BinaryType support.")
+        spark_type = BinaryType()
     elif types.is_date32(at):
         spark_type = DateType()
     elif types.is_timestamp(at):

