Repository: spark Updated Branches: refs/heads/branch-1.1 814934da6 -> 91d0effb3
[SPARK-6055] [PySpark] fix incorrect DataType.__eq__ (for 1.1) The __eq__ of DataType is not correct, and the class cache is not used correctly (a created class cannot be found by its dataType), so it will create lots of classes (saved in _cached_cls) that are never released. Also, all instances of the same DataType have the same hash code, so there will be many objects in a dict with the same hash code, ending in a hash attack; it is very slow to access this dict (depending on the implementation of CPython). Author: Davies Liu <[email protected]> Closes #4810 from davies/leak3 and squashes the following commits: 48d643d [Davies Liu] Update sql.py 968a28c [Davies Liu] fix __eq__ of singleton ac9db57 [Davies Liu] fix tests f748114 [Davies Liu] fix incorrect DataType.__eq__ Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/91d0effb Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/91d0effb Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/91d0effb Branch: refs/heads/branch-1.1 Commit: 91d0effb32f741292b76608661ede302b72d8cc1 Parents: 814934d Author: Davies Liu <[email protected]> Authored: Fri Feb 27 20:06:03 2015 -0800 Committer: Josh Rosen <[email protected]> Committed: Fri Feb 27 20:06:03 2015 -0800 ---------------------------------------------------------------------- python/pyspark/sql.py | 43 +++++++++++++++++++++---------------------- 1 file changed, 21 insertions(+), 22 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/91d0effb/python/pyspark/sql.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 07b39c9..b6bb0a0 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -24,6 +24,7 @@ import decimal import datetime import keyword import warnings +import weakref from array import array from operator import itemgetter @@ -55,8 +56,7 @@ class DataType(object): 
return hash(str(self)) def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.__dict__ == other.__dict__) + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ def __ne__(self, other): return not self.__eq__(other) @@ -80,10 +80,6 @@ class PrimitiveType(DataType): __metaclass__ = PrimitiveTypeSingleton - def __eq__(self, other): - # because they should be the same object - return self is other - class StringType(PrimitiveType): @@ -192,9 +188,9 @@ class ArrayType(DataType): :param elementType: the data type of elements. :param containsNull: indicates whether the list contains None values. - >>> ArrayType(StringType) == ArrayType(StringType, True) + >>> ArrayType(StringType()) == ArrayType(StringType(), True) True - >>> ArrayType(StringType, False) == ArrayType(StringType) + >>> ArrayType(StringType(), False) == ArrayType(StringType()) False """ self.elementType = elementType @@ -229,11 +225,11 @@ class MapType(DataType): :param valueContainsNull: indicates whether values contains null values. - >>> (MapType(StringType, IntegerType) - ... == MapType(StringType, IntegerType, True)) + >>> (MapType(StringType(), IntegerType()) + ... == MapType(StringType(), IntegerType(), True)) True - >>> (MapType(StringType, IntegerType, False) - ... == MapType(StringType, FloatType)) + >>> (MapType(StringType(), IntegerType(), False) + ... == MapType(StringType(), FloatType())) False """ self.keyType = keyType @@ -267,11 +263,11 @@ class StructField(DataType): :param nullable: indicates whether values of this field can be null. - >>> (StructField("f1", StringType, True) - ... == StructField("f1", StringType, True)) + >>> (StructField("f1", StringType(), True) + ... == StructField("f1", StringType(), True)) True - >>> (StructField("f1", StringType, True) - ... == StructField("f2", StringType, True)) + >>> (StructField("f1", StringType(), True) + ... 
== StructField("f2", StringType(), True)) False """ self.name = name @@ -295,13 +291,13 @@ class StructType(DataType): def __init__(self, fields): """Creates a StructType - >>> struct1 = StructType([StructField("f1", StringType, True)]) - >>> struct2 = StructType([StructField("f1", StringType, True)]) + >>> struct1 = StructType([StructField("f1", StringType(), True)]) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) >>> struct1 == struct2 True - >>> struct1 = StructType([StructField("f1", StringType, True)]) - >>> struct2 = StructType([StructField("f1", StringType, True), - ... [StructField("f2", IntegerType, False)]]) + >>> struct1 = StructType([StructField("f1", StringType(), True)]) + >>> struct2 = StructType([StructField("f1", StringType(), True), + ... StructField("f2", IntegerType(), False)]) >>> struct1 == struct2 False """ @@ -343,6 +339,9 @@ _all_primitive_types = dict((k, v) for k, v in globals().iteritems() def _parse_datatype_string(datatype_string): """Parses the given data type string. + >>> import pickle + >>> LongType() == pickle.loads(pickle.dumps(LongType())) + True >>> def check_datatype(datatype): ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype)) ... python_datatype = _parse_datatype_string( @@ -751,7 +750,7 @@ def _verify_type(obj, dataType): _verify_type(v, f.dataType) -_cached_cls = {} +_cached_cls = weakref.WeakValueDictionary() def _restore_object(dataType, obj): --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
