Repository: spark
Updated Branches:
  refs/heads/branch-2.0 2b32a442d -> 356a359de


[SPARK-16700][PYSPARK][SQL] create DataFrame from dict/Row with schema

In 2.0, we verify the data type of every row against the schema for safety,
but at a performance cost; this PR makes that verification optional.
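
As a minimal sketch (not part of the patch, assuming an existing SparkSession
bound to `spark`), callers who already trust their data can skip the per-row
check:

    from pyspark.sql.types import IntegerType, StringType, StructField, StructType

    schema = StructType([StructField("name", StringType(), True),
                         StructField("age", IntegerType(), True)])
    data = [("Alice", 1), ("Bob", 2)]
    # verifySchema=True (the default) keeps the 2.0 behavior of checking
    # every row; verifySchema=False skips the check and its cost.
    df = spark.createDataFrame(data, schema, verifySchema=False)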

When verifying data types for a StructType, not all of the types supported by
schema inference (for example, dict) were accepted; this PR fixes that so the
two code paths are consistent.
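
A minimal sketch of the now-accepted case (again assuming a SparkSession bound
to `spark`): dict records are matched to StructType fields by key, missing
keys become null, and verification no longer rejects them.

    from pyspark.sql.types import IntegerType, StringType, StructType

    schema = StructType().add("b", StringType()).add("a", IntegerType())
    # Each dict is converted by field name; absent keys become null.
    df = spark.createDataFrame([{"a": 1}, {"b": "coffee"}], schema)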

For a Row object created with named arguments, the fields are sorted by name,
so their order may differ from the order in the provided schema; this PR fixes
that by ignoring the field order in this case and matching fields by name
instead.
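
For example (a sketch under the same assumptions), a Row built from keyword
arguments stores its fields in name-sorted order, and verification now matches
them to the schema by name rather than by position:

    from pyspark.sql import Row
    from pyspark.sql.types import IntegerType, StringType, StructType

    schema = StructType().add("b", StringType()).add("a", IntegerType())
    # Row(a=..., b=...) sorts its fields as (a, b), which differs from the
    # schema order (b, a); fields are matched by name during verification.
    rows = [Row(a=i, b=str(i)) for i in range(10)]
    df = spark.createDataFrame(rows, schema)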

Created regression tests for them.

Author: Davies Liu <dav...@databricks.com>

Closes #14469 from davies/py_dict.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/356a359d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/356a359d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/356a359d

Branch: refs/heads/branch-2.0
Commit: 356a359de038e2e9d4d0cb7c0c5b493f7036d7c3
Parents: 2b32a44
Author: Davies Liu <dav...@databricks.com>
Authored: Mon Aug 15 12:41:27 2016 -0700
Committer: Davies Liu <davies....@gmail.com>
Committed: Thu Aug 25 09:42:43 2016 -0700

----------------------------------------------------------------------
 python/pyspark/sql/context.py |  8 ++++++--
 python/pyspark/sql/session.py | 29 +++++++++++++----------------
 python/pyspark/sql/tests.py   | 16 ++++++++++++++++
 python/pyspark/sql/types.py   | 37 +++++++++++++++++++++++++++----------
 4 files changed, 62 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/356a359d/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index afb9b54..8cdf371 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -215,7 +215,7 @@ class SQLContext(object):
 
     @since(1.3)
     @ignore_unicode_prefix
-    def createDataFrame(self, data, schema=None, samplingRatio=None):
+    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
         """
         Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
 
@@ -247,6 +247,7 @@ class SQLContext(object):
             ``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`.
             We can also use ``int`` as a short name for :class:`pyspark.sql.types.IntegerType`.
         :param samplingRatio: the sample ratio of rows used for inferring
+        :param verifySchema: verify data types of every row against schema.
         :return: :class:`DataFrame`
 
         .. versionchanged:: 2.0
@@ -255,6 +256,9 @@ class SQLContext(object):
            If it's not a :class:`pyspark.sql.types.StructType`, it will be wrapped into a
            :class:`pyspark.sql.types.StructType` and each record will also be wrapped into a tuple.
 
+        .. versionchanged:: 2.0.1
+           Added verifySchema.
+
         >>> l = [('Alice', 1)]
         >>> sqlContext.createDataFrame(l).collect()
         [Row(_1=u'Alice', _2=1)]
@@ -302,7 +306,7 @@ class SQLContext(object):
             ...
         Py4JJavaError: ...
         """
-        return self.sparkSession.createDataFrame(data, schema, samplingRatio)
+        return self.sparkSession.createDataFrame(data, schema, samplingRatio, verifySchema)
 
     @since(1.3)
     def registerDataFrameAsTable(self, df, tableName):

http://git-wip-us.apache.org/repos/asf/spark/blob/356a359d/python/pyspark/sql/session.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 10bd89b..d8627ce 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -384,17 +384,15 @@ class SparkSession(object):
 
         if schema is None or isinstance(schema, (list, tuple)):
             struct = self._inferSchemaFromList(data)
+            converter = _create_converter(struct)
+            data = map(converter, data)
             if isinstance(schema, (list, tuple)):
                 for i, name in enumerate(schema):
                     struct.fields[i].name = name
                     struct.names[i] = name
             schema = struct
 
-        elif isinstance(schema, StructType):
-            for row in data:
-                _verify_type(row, schema)
-
-        else:
+        elif not isinstance(schema, StructType):
             raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
 
         # convert python objects to sql data
@@ -403,7 +401,7 @@ class SparkSession(object):
 
     @since(2.0)
     @ignore_unicode_prefix
-    def createDataFrame(self, data, schema=None, samplingRatio=None):
+    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
         """
         Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
 
@@ -434,13 +432,11 @@ class SparkSession(object):
             ``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`. We can also use
             ``int`` as a short name for ``IntegerType``.
         :param samplingRatio: the sample ratio of rows used for inferring
+        :param verifySchema: verify data types of every row against schema.
         :return: :class:`DataFrame`
 
-        .. versionchanged:: 2.0
-           The ``schema`` parameter can be a :class:`pyspark.sql.types.DataType` or a
-           :class:`pyspark.sql.types.StringType` after 2.0. If it's not a
-           :class:`pyspark.sql.types.StructType`, it will be wrapped into a
-           :class:`pyspark.sql.types.StructType` and each record will also be wrapped into a tuple.
+        .. versionchanged:: 2.0.1
+           Added verifySchema.
 
         >>> l = [('Alice', 1)]
         >>> spark.createDataFrame(l).collect()
@@ -505,17 +501,18 @@ class SparkSession(object):
                 schema = [str(x) for x in data.columns]
             data = [r.tolist() for r in data.to_records(index=False)]
 
+        verify_func = _verify_type if verifySchema else lambda _, t: True
         if isinstance(schema, StructType):
             def prepare(obj):
-                _verify_type(obj, schema)
+                verify_func(obj, schema)
                 return obj
         elif isinstance(schema, DataType):
-            datatype = schema
+            dataType = schema
+            schema = StructType().add("value", schema)
 
             def prepare(obj):
-                _verify_type(obj, datatype)
-                return (obj, )
-            schema = StructType().add("value", datatype)
+                verify_func(obj, dataType)
+                return obj,
         else:
             if isinstance(schema, list):
                 schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]

http://git-wip-us.apache.org/repos/asf/spark/blob/356a359d/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 87dbb50..520b09d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -411,6 +411,22 @@ class SQLTests(ReusedPySparkTestCase):
         df3 = self.spark.createDataFrame(rdd, df.schema)
         self.assertEqual(10, df3.count())
 
+    def test_apply_schema_to_dict_and_rows(self):
+        schema = StructType().add("b", StringType()).add("a", IntegerType())
+        input = [{"a": 1}, {"b": "coffee"}]
+        rdd = self.sc.parallelize(input)
+        for verify in [False, True]:
+            df = self.spark.createDataFrame(input, schema, verifySchema=verify)
+            df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
+            self.assertEqual(df.schema, df2.schema)
+
+            rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
+            df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
+            self.assertEqual(10, df3.count())
+            input = [Row(a=x, b=str(x)) for x in range(10)]
+            df4 = self.spark.createDataFrame(input, schema, verifySchema=verify)
+            self.assertEqual(10, df4.count())
+
     def test_create_dataframe_schema_mismatch(self):
         input = [Row(a=1)]
         rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))

http://git-wip-us.apache.org/repos/asf/spark/blob/356a359d/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 1ca4bbc..b765472 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -582,6 +582,8 @@ class StructType(DataType):
         else:
             if isinstance(obj, dict):
                 return tuple(obj.get(n) for n in self.names)
+            elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
+                return tuple(obj[n] for n in self.names)
             elif isinstance(obj, (list, tuple)):
                 return tuple(obj)
             elif hasattr(obj, "__dict__"):
@@ -1243,7 +1245,7 @@ _acceptable_types = {
     TimestampType: (datetime.datetime,),
     ArrayType: (list, tuple, array),
     MapType: (dict,),
-    StructType: (tuple, list),
+    StructType: (tuple, list, dict),
 }
 
 
@@ -1314,10 +1316,10 @@ def _verify_type(obj, dataType, nullable=True):
     assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj)
 
     if _type is StructType:
-        if not isinstance(obj, (tuple, list)):
-            raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj)))
+        # check the type and fields later
+        pass
     else:
-        # subclass of them can not be fromInternald in JVM
+        # subclass of them can not be fromInternal in JVM
         if type(obj) not in _acceptable_types[_type]:
             raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj)))
 
@@ -1343,11 +1345,25 @@ def _verify_type(obj, dataType, nullable=True):
             _verify_type(v, dataType.valueType, dataType.valueContainsNull)
 
     elif isinstance(dataType, StructType):
-        if len(obj) != len(dataType.fields):
-            raise ValueError("Length of object (%d) does not match with "
-                             "length of fields (%d)" % (len(obj), len(dataType.fields)))
-        for v, f in zip(obj, dataType.fields):
-            _verify_type(v, f.dataType, f.nullable)
+        if isinstance(obj, dict):
+            for f in dataType.fields:
+                _verify_type(obj.get(f.name), f.dataType, f.nullable)
+        elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
+            # the order in obj could be different than dataType.fields
+            for f in dataType.fields:
+                _verify_type(obj[f.name], f.dataType, f.nullable)
+        elif isinstance(obj, (tuple, list)):
+            if len(obj) != len(dataType.fields):
+                raise ValueError("Length of object (%d) does not match with "
+                                 "length of fields (%d)" % (len(obj), len(dataType.fields)))
+            for v, f in zip(obj, dataType.fields):
+                _verify_type(v, f.dataType, f.nullable)
+        elif hasattr(obj, "__dict__"):
+            d = obj.__dict__
+            for f in dataType.fields:
+                _verify_type(d.get(f.name), f.dataType, f.nullable)
+        else:
+            raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj)))
 
 
 # This is used to unpickle a Row from JVM
@@ -1410,6 +1426,7 @@ class Row(tuple):
             names = sorted(kwargs.keys())
             row = tuple.__new__(self, [kwargs[n] for n in names])
             row.__fields__ = names
+            row.__from_dict__ = True
             return row
 
         else:
@@ -1485,7 +1502,7 @@ class Row(tuple):
             raise AttributeError(item)
 
     def __setattr__(self, key, value):
-        if key != '__fields__':
+        if key != '__fields__' and key != "__from_dict__":
             raise Exception("Row is read-only")
         self.__dict__[key] = value
 

