Repository: spark
Updated Branches:
  refs/heads/branch-1.2 17b7cc733 -> 576fc54e5


[SPARK-6055] [PySpark] fix incorrect DataType.__eq__ (for 1.2)

DataType.__eq__ is not correct, so the class cache is not used correctly (a
created class cannot be found again by its dataType). As a result, lots of
classes are created (saved in _cached_cls) and never released.

Also, all instances of the same DataType have the same hash code, so a dict can
end up holding many objects with the same hash code. This degrades into the
equivalent of a hash-collision attack, making access to that dict very slow
(depending on the CPython implementation).

This PR also improves the performance of inferSchema (by avoiding unnecessary
conversion of objects).

Author: Davies Liu <[email protected]>

Closes #4809 from davies/leak2 and squashes the following commits:

65c222f [Davies Liu] Update sql.py
9b4dadc [Davies Liu] fix __eq__ of singleton
b576107 [Davies Liu] fix tests
6c2909a [Davies Liu] fix incorrect DataType.__eq__


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/576fc54e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/576fc54e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/576fc54e

Branch: refs/heads/branch-1.2
Commit: 576fc54e5c154fc28af1a732a6bea452d0a5cabb
Parents: 17b7cc7
Author: Davies Liu <[email protected]>
Authored: Fri Feb 27 20:04:16 2015 -0800
Committer: Josh Rosen <[email protected]>
Committed: Fri Feb 27 20:04:16 2015 -0800

----------------------------------------------------------------------
 python/pyspark/sql.py | 67 ++++++++++++++++++++++++++++++----------------
 1 file changed, 44 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/576fc54e/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index aa5af1b..4410925 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -36,6 +36,7 @@ import keyword
 import warnings
 import json
 import re
+import weakref
 from array import array
 from operator import itemgetter
 from itertools import imap
@@ -68,8 +69,7 @@ class DataType(object):
         return hash(str(self))
 
     def __eq__(self, other):
-        return (isinstance(other, self.__class__) and
-                self.__dict__ == other.__dict__)
+        return isinstance(other, self.__class__) and self.__dict__ == 
other.__dict__
 
     def __ne__(self, other):
         return not self.__eq__(other)
@@ -105,10 +105,6 @@ class PrimitiveType(DataType):
 
     __metaclass__ = PrimitiveTypeSingleton
 
-    def __eq__(self, other):
-        # because they should be the same object
-        return self is other
-
 
 class NullType(PrimitiveType):
 
@@ -251,9 +247,9 @@ class ArrayType(DataType):
         :param elementType: the data type of elements.
         :param containsNull: indicates whether the list contains None values.
 
-        >>> ArrayType(StringType) == ArrayType(StringType, True)
+        >>> ArrayType(StringType()) == ArrayType(StringType(), True)
         True
-        >>> ArrayType(StringType, False) == ArrayType(StringType)
+        >>> ArrayType(StringType(), False) == ArrayType(StringType())
         False
         """
         self.elementType = elementType
@@ -298,11 +294,11 @@ class MapType(DataType):
         :param valueContainsNull: indicates whether values contains
         null values.
 
-        >>> (MapType(StringType, IntegerType)
-        ...        == MapType(StringType, IntegerType, True))
+        >>> (MapType(StringType(), IntegerType())
+        ...        == MapType(StringType(), IntegerType(), True))
         True
-        >>> (MapType(StringType, IntegerType, False)
-        ...        == MapType(StringType, FloatType))
+        >>> (MapType(StringType(), IntegerType(), False)
+        ...        == MapType(StringType(), FloatType()))
         False
         """
         self.keyType = keyType
@@ -351,11 +347,11 @@ class StructField(DataType):
                          to simple type that can be serialized to JSON
                          automatically
 
-        >>> (StructField("f1", StringType, True)
-        ...      == StructField("f1", StringType, True))
+        >>> (StructField("f1", StringType(), True)
+        ...      == StructField("f1", StringType(), True))
         True
-        >>> (StructField("f1", StringType, True)
-        ...      == StructField("f2", StringType, True))
+        >>> (StructField("f1", StringType(), True)
+        ...      == StructField("f2", StringType(), True))
         False
         """
         self.name = name
@@ -393,13 +389,13 @@ class StructType(DataType):
     def __init__(self, fields):
         """Creates a StructType
 
-        >>> struct1 = StructType([StructField("f1", StringType, True)])
-        >>> struct2 = StructType([StructField("f1", StringType, True)])
+        >>> struct1 = StructType([StructField("f1", StringType(), True)])
+        >>> struct2 = StructType([StructField("f1", StringType(), True)])
         >>> struct1 == struct2
         True
-        >>> struct1 = StructType([StructField("f1", StringType, True)])
-        >>> struct2 = StructType([StructField("f1", StringType, True),
-        ...   [StructField("f2", IntegerType, False)]])
+        >>> struct1 = StructType([StructField("f1", StringType(), True)])
+        >>> struct2 = StructType([StructField("f1", StringType(), True),
+        ...                       StructField("f2", IntegerType(), False)])
         >>> struct1 == struct2
         False
         """
@@ -499,6 +495,10 @@ _all_complex_types = dict((v.typeName(), v)
 
 def _parse_datatype_json_string(json_string):
     """Parses the given data type JSON string.
+
+    >>> import pickle
+    >>> LongType() == pickle.loads(pickle.dumps(LongType()))
+    True
     >>> def check_datatype(datatype):
     ...     scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
     ...     python_datatype = 
_parse_datatype_json_string(scala_datatype.json())
@@ -781,8 +781,25 @@ def _merge_type(a, b):
         return a
 
 
+def _need_converter(dataType):
+    if isinstance(dataType, StructType):
+        return True
+    elif isinstance(dataType, ArrayType):
+        return _need_converter(dataType.elementType)
+    elif isinstance(dataType, MapType):
+        return _need_converter(dataType.keyType) or 
_need_converter(dataType.valueType)
+    elif isinstance(dataType, NullType):
+        return True
+    else:
+        return False
+
+
 def _create_converter(dataType):
     """Create an converter to drop the names of fields in obj """
+
+    if not _need_converter(dataType):
+        return lambda x: x
+
     if isinstance(dataType, ArrayType):
         conv = _create_converter(dataType.elementType)
         return lambda row: map(conv, row)
@@ -800,6 +817,7 @@ def _create_converter(dataType):
     # dataType must be StructType
     names = [f.name for f in dataType.fields]
     converters = [_create_converter(f.dataType) for f in dataType.fields]
+    convert_fields = any(_need_converter(f.dataType) for f in dataType.fields)
 
     def convert_struct(obj):
         if obj is None:
@@ -822,7 +840,10 @@ def _create_converter(dataType):
         else:
             raise ValueError("Unexpected obj: %s" % obj)
 
-        return tuple([conv(d.get(name)) for name, conv in zip(names, 
converters)])
+        if convert_fields:
+            return tuple([conv(d.get(name)) for name, conv in zip(names, 
converters)])
+        else:
+            return tuple([d.get(name) for name in names])
 
     return convert_struct
 
@@ -1039,7 +1060,7 @@ def _verify_type(obj, dataType):
             _verify_type(v, f.dataType)
 
 
-_cached_cls = {}
+_cached_cls = weakref.WeakValueDictionary()
 
 
 def _restore_object(dataType, obj):


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to