Repository: spark
Updated Branches:
  refs/heads/master 2b6e1ce6e -> 24544fbce


[SPARK-3594] [PySpark] [SQL] take more rows to infer schema or sampling

This patch infers the schema for an RDD whose first row contains empty values
(None, [], {}). It inspects the first 100 rows and merges their types into the
schema, also merging the fields of StructTypes together. If there is still a
NullType in the schema, it shows a warning telling the user to try again with
sampling.

If a samplingRatio is specified, the schema is inferred from all the rows that remain after sampling.

Also adds a samplingRatio parameter to jsonFile() and jsonRDD().

Author: Davies Liu <davies....@gmail.com>
Author: Davies Liu <dav...@databricks.com>

Closes #2716 from davies/infer and squashes the following commits:

e678f6d [Davies Liu] Merge branch 'master' of github.com:apache/spark into infer
34b5c63 [Davies Liu] Merge branch 'master' of github.com:apache/spark into infer
567dc60 [Davies Liu] update docs
9767b27 [Davies Liu] Merge branch 'master' into infer
e48d7fb [Davies Liu] fix tests
29e94d5 [Davies Liu] let NullType inherit from PrimitiveType
ee5d524 [Davies Liu] Merge branch 'master' of github.com:apache/spark into infer
540d1d5 [Davies Liu] merge fields for StructType
f93fd84 [Davies Liu] add more tests
3603e00 [Davies Liu] take more rows to infer schema, or infer the schema by sampling the RDD


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/24544fbc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/24544fbc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/24544fbc

Branch: refs/heads/master
Commit: 24544fbce05665ab4999a1fe5aac434d29cd912c
Parents: 2b6e1ce
Author: Davies Liu <davies....@gmail.com>
Authored: Mon Nov 3 13:17:09 2014 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Mon Nov 3 13:17:09 2014 -0800

----------------------------------------------------------------------
 python/pyspark/sql.py                           | 196 ++++++++++++-------
 python/pyspark/tests.py                         |  19 ++
 .../spark/sql/catalyst/types/dataTypes.scala    |   2 +-
 3 files changed, 148 insertions(+), 69 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/24544fbc/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 98e41f8..675df08 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -109,6 +109,15 @@ class PrimitiveType(DataType):
         return self is other
 
 
+class NullType(PrimitiveType):
+
+    """Spark SQL NullType
+
+    The data type representing None, used for the types which has not
+    been inferred.
+    """
+
+
 class StringType(PrimitiveType):
 
     """Spark SQL StringType
@@ -331,7 +340,7 @@ class StructField(DataType):
 
     """
 
-    def __init__(self, name, dataType, nullable, metadata=None):
+    def __init__(self, name, dataType, nullable=True, metadata=None):
         """Creates a StructField
         :param name: the name of this field.
         :param dataType: the data type of this field.
@@ -484,6 +493,7 @@ def _parse_datatype_json_value(json_value):
 
 # Mapping Python types to Spark SQL DataType
 _type_mappings = {
+    type(None): NullType,
     bool: BooleanType,
     int: IntegerType,
     long: LongType,
@@ -500,22 +510,22 @@ _type_mappings = {
 
 def _infer_type(obj):
     """Infer the DataType from obj"""
-    if obj is None:
-        raise ValueError("Can not infer type for None")
-
     dataType = _type_mappings.get(type(obj))
     if dataType is not None:
         return dataType()
 
     if isinstance(obj, dict):
-        if not obj:
-            raise ValueError("Can not infer type for empty dict")
-        key, value = obj.iteritems().next()
-        return MapType(_infer_type(key), _infer_type(value), True)
+        for key, value in obj.iteritems():
+            if key is not None and value is not None:
+                return MapType(_infer_type(key), _infer_type(value), True)
+        else:
+            return MapType(NullType(), NullType(), True)
     elif isinstance(obj, (list, array)):
-        if not obj:
-            raise ValueError("Can not infer type for empty list/array")
-        return ArrayType(_infer_type(obj[0]), True)
+        for v in obj:
+            if v is not None:
+                return ArrayType(_infer_type(obj[0]), True)
+        else:
+            return ArrayType(NullType(), True)
     else:
         try:
             return _infer_schema(obj)
@@ -548,60 +558,93 @@ def _infer_schema(row):
     return StructType(fields)
 
 
-def _create_converter(obj, dataType):
+def _has_nulltype(dt):
+    """ Return whether there is NullType in `dt` or not """
+    if isinstance(dt, StructType):
+        return any(_has_nulltype(f.dataType) for f in dt.fields)
+    elif isinstance(dt, ArrayType):
+        return _has_nulltype((dt.elementType))
+    elif isinstance(dt, MapType):
+        return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType)
+    else:
+        return isinstance(dt, NullType)
+
+
+def _merge_type(a, b):
+    if isinstance(a, NullType):
+        return b
+    elif isinstance(b, NullType):
+        return a
+    elif type(a) is not type(b):
+        # TODO: type cast (such as int -> long)
+        raise TypeError("Can not merge type %s and %s" % (a, b))
+
+    # same type
+    if isinstance(a, StructType):
+        nfs = dict((f.name, f.dataType) for f in b.fields)
+        fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType())))
+                  for f in a.fields]
+        names = set([f.name for f in fields])
+        for n in nfs:
+            if n not in names:
+                fields.append(StructField(n, nfs[n]))
+        return StructType(fields)
+
+    elif isinstance(a, ArrayType):
+        return ArrayType(_merge_type(a.elementType, b.elementType), True)
+
+    elif isinstance(a, MapType):
+        return MapType(_merge_type(a.keyType, b.keyType),
+                       _merge_type(a.valueType, b.valueType),
+                       True)
+    else:
+        return a
+
+
+def _create_converter(dataType):
     """Create an converter to drop the names of fields in obj """
     if isinstance(dataType, ArrayType):
-        conv = _create_converter(obj[0], dataType.elementType)
+        conv = _create_converter(dataType.elementType)
         return lambda row: map(conv, row)
 
     elif isinstance(dataType, MapType):
-        value = obj.values()[0]
-        conv = _create_converter(value, dataType.valueType)
+        conv = _create_converter(dataType.valueType)
         return lambda row: dict((k, conv(v)) for k, v in row.iteritems())
 
+    elif isinstance(dataType, NullType):
+        return lambda x: None
+
     elif not isinstance(dataType, StructType):
         return lambda x: x
 
     # dataType must be StructType
     names = [f.name for f in dataType.fields]
+    converters = [_create_converter(f.dataType) for f in dataType.fields]
+
+    def convert_struct(obj):
+        if obj is None:
+            return
+
+        if isinstance(obj, tuple):
+            if hasattr(obj, "fields"):
+                d = dict(zip(obj.fields, obj))
+            if hasattr(obj, "__FIELDS__"):
+                d = dict(zip(obj.__FIELDS__, obj))
+            elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
+                d = dict(obj)
+            else:
+                raise ValueError("unexpected tuple: %s" % obj)
 
-    if isinstance(obj, dict):
-        conv = lambda o: tuple(o.get(n) for n in names)
-
-    elif isinstance(obj, tuple):
-        if hasattr(obj, "_fields"):  # namedtuple
-            conv = tuple
-        elif hasattr(obj, "__FIELDS__"):
-            conv = tuple
-        elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
-            conv = lambda o: tuple(v for k, v in o)
+        elif isinstance(obj, dict):
+            d = obj
+        elif hasattr(obj, "__dict__"):  # object
+            d = obj.__dict__
         else:
-            raise ValueError("unexpected tuple")
+            raise ValueError("Unexpected obj: %s" % obj)
 
-    elif hasattr(obj, "__dict__"):  # object
-        conv = lambda o: [o.__dict__.get(n, None) for n in names]
+        return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
 
-    if all(isinstance(f.dataType, PrimitiveType) for f in dataType.fields):
-        return conv
-
-    row = conv(obj)
-    convs = [_create_converter(v, f.dataType)
-             for v, f in zip(row, dataType.fields)]
-
-    def nested_conv(row):
-        return tuple(f(v) for f, v in zip(convs, conv(row)))
-
-    return nested_conv
-
-
-def _drop_schema(rows, schema):
-    """ all the names of fields, becoming tuples"""
-    iterator = iter(rows)
-    row = iterator.next()
-    converter = _create_converter(row, schema)
-    yield converter(row)
-    for i in iterator:
-        yield converter(i)
+    return convert_struct
 
 
 _BRACKETS = {'(': ')', '[': ']', '{': '}'}
@@ -713,7 +756,7 @@ def _infer_schema_type(obj, dataType):
         return _infer_type(obj)
 
     if not obj:
-        raise ValueError("Can not infer type from empty value")
+        return NullType()
 
     if isinstance(dataType, ArrayType):
         eType = _infer_schema_type(obj[0], dataType.elementType)
@@ -1049,18 +1092,20 @@ class SQLContext(object):
                                       self._sc._javaAccumulator,
                                       returnType.json())
 
-    def inferSchema(self, rdd):
+    def inferSchema(self, rdd, samplingRatio=None):
         """Infer and apply a schema to an RDD of L{Row}.
 
-        We peek at the first row of the RDD to determine the fields' names
-        and types. Nested collections are supported, which include array,
-        dict, list, Row, tuple, namedtuple, or object.
+        When samplingRatio is specified, the schema is inferred by looking
+        at the types of each row in the sampled dataset. Otherwise, the
+        first 100 rows of the RDD are inspected. Nested collections are
+        supported, which can include array, dict, list, Row, tuple,
+        namedtuple, or object.
 
-        All the rows in `rdd` should have the same type with the first one,
-        or it will cause runtime exceptions.
+        Each row could be L{pyspark.sql.Row} object or namedtuple or objects.
+        Using top level dicts is deprecated, as dict is used to represent Maps.
 
-        Each row could be L{pyspark.sql.Row} object or namedtuple or objects,
-        using dict is deprecated.
+        If a single column has multiple distinct inferred types, it may cause
+        runtime exceptions.
 
         >>> rdd = sc.parallelize(
         ...     [Row(field1=1, field2="row1"),
@@ -1097,8 +1142,23 @@ class SQLContext(object):
             warnings.warn("Using RDD of dict to inferSchema is deprecated,"
                           "please use pyspark.sql.Row instead")
 
-        schema = _infer_schema(first)
-        rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema))
+        if samplingRatio is None:
+            schema = _infer_schema(first)
+            if _has_nulltype(schema):
+                for row in rdd.take(100)[1:]:
+                    schema = _merge_type(schema, _infer_schema(row))
+                    if not _has_nulltype(schema):
+                        break
+                else:
+                    warnings.warn("Some of types cannot be determined by the "
+                                  "first 100 rows, please try again with sampling")
+        else:
+            if samplingRatio > 0.99:
+                rdd = rdd.sample(False, float(samplingRatio))
+            schema = rdd.map(_infer_schema).reduce(_merge_type)
+
+        converter = _create_converter(schema)
+        rdd = rdd.map(converter)
         return self.applySchema(rdd, schema)
 
     def applySchema(self, rdd, schema):
@@ -1219,7 +1279,7 @@ class SQLContext(object):
         jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD()
         return SchemaRDD(jschema_rdd, self)
 
-    def jsonFile(self, path, schema=None):
+    def jsonFile(self, path, schema=None, samplingRatio=1.0):
         """
         Loads a text file storing one JSON object per line as a
         L{SchemaRDD}.
@@ -1227,8 +1287,8 @@ class SQLContext(object):
         If the schema is provided, applies the given schema to this
         JSON dataset.
 
-        Otherwise, it goes through the entire dataset once to determine
-        the schema.
+        Otherwise, it samples the dataset with ratio `samplingRatio` to
+        determine the schema.
 
         >>> import tempfile, shutil
         >>> jsonFile = tempfile.mkdtemp()
@@ -1274,20 +1334,20 @@ class SQLContext(object):
         [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
         """
         if schema is None:
-            srdd = self._ssql_ctx.jsonFile(path)
+            srdd = self._ssql_ctx.jsonFile(path, samplingRatio)
         else:
             scala_datatype = self._ssql_ctx.parseDataType(schema.json())
             srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
         return SchemaRDD(srdd.toJavaSchemaRDD(), self)
 
-    def jsonRDD(self, rdd, schema=None):
+    def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
         """Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
 
         If the schema is provided, applies the given schema to this
         JSON dataset.
 
-        Otherwise, it goes through the entire dataset once to determine
-        the schema.
+        Otherwise, it samples the dataset with ratio `samplingRatio` to
+        determine the schema.
 
         >>> srdd1 = sqlCtx.jsonRDD(json)
         >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
@@ -1344,7 +1404,7 @@ class SQLContext(object):
         keyed._bypass_serializer = True
         jrdd = keyed._jrdd.map(self._jvm.BytesToString())
         if schema is None:
-            srdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
+            srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio)
         else:
             scala_datatype = self._ssql_ctx.parseDataType(schema.json())
             srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)

http://git-wip-us.apache.org/repos/asf/spark/blob/24544fbc/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 253a471..68fd756 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -796,6 +796,25 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(1.0, row.c)
         self.assertEqual("2", row.d)
 
+    def test_infer_schema(self):
+        d = [Row(l=[], d={}),
+             Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
+        rdd = self.sc.parallelize(d)
+        srdd = self.sqlCtx.inferSchema(rdd)
+        self.assertEqual([], srdd.map(lambda r: r.l).first())
+        self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect())
+        srdd.registerTempTable("test")
+        result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
+        self.assertEqual(1, result.first()[0])
+
+        srdd2 = self.sqlCtx.inferSchema(rdd, 1.0)
+        self.assertEqual(srdd.schema(), srdd2.schema())
+        self.assertEqual({}, srdd2.map(lambda r: r.d).first())
+        self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect())
+        srdd2.registerTempTable("test2")
+        result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
+        self.assertEqual(1, result.first()[0])
+
     def test_convert_row_to_dict(self):
         row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
         self.assertEqual(1, row.asDict()['l'][0].a)

http://git-wip-us.apache.org/repos/asf/spark/blob/24544fbc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index cc5015a..e1b5992 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -213,7 +213,7 @@ trait PrimitiveType extends DataType {
 }
 
 object PrimitiveType {
-  private val nonDecimals = Seq(DateType, TimestampType, BinaryType) ++ NativeType.all
+  private val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all
  private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap
 
   /** Given the string representation of a type, return its DataType */


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to