Posted to commits@spark.apache.org by jo...@apache.org on 2015/02/28 05:06:28 UTC

spark git commit: [SPARK-6055] [PySpark] fix incorrect DataType.__eq__ (for 1.1)

Repository: spark
Updated Branches:
  refs/heads/branch-1.1 814934da6 -> 91d0effb3


[SPARK-6055] [PySpark] fix incorrect DataType.__eq__ (for 1.1)

The __eq__ of DataType is not correct, and the class cache is not used correctly (a created class cannot be found again by its dataType), so it keeps creating new classes (saved in _cached_cls) that are never released.
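
As context (not part of the commit), a minimal sketch, with hypothetical classes IdentityEq and ValueEq, of why an identity-based __eq__ defeats a dict-based class cache once instances cross a pickle boundary, while a value-based __eq__ (comparing __dict__, as this patch does) lets the cache hit:

    # Sketch only, not the Spark code. Both classes hash by class name,
    # so dict lookups reach the right bucket; only equality differs.
    import pickle

    class IdentityEq(object):
        def __hash__(self):
            return hash(self.__class__.__name__)
        def __eq__(self, other):
            return self is other          # identity: a pickled copy never matches

    class ValueEq(object):
        def __hash__(self):
            return hash(self.__class__.__name__)
        def __eq__(self, other):
            return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    cache = {IdentityEq(): "cached class"}
    key = pickle.loads(pickle.dumps(list(cache)[0]))
    print(key in cache)                   # False -> cache miss, a new class per lookup

    cache = {ValueEq(): "cached class"}
    key = pickle.loads(pickle.dumps(list(cache)[0]))
    print(key in cache)                   # True -> the cached class is reused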

Also, all instances of the same DataType have the same hash code, so a dict can end up with many objects sharing one hash code (effectively a hash-collision attack), which makes access to that dict very slow (depending on the CPython implementation).
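
Again for context only, a small sketch with a hypothetical SameHash class showing how keys that all share one hash code degrade dict access from amortized O(1) to a linear scan of the collision chain:

    # Sketch only: every key collides, so each insert and lookup has to
    # compare against the other colliding entries one by one.
    import time

    class SameHash(object):
        def __init__(self, n):
            self.n = n
        def __hash__(self):
            return 1                      # every instance collides
        def __eq__(self, other):
            return isinstance(other, SameHash) and self.n == other.n

    keys = [SameHash(i) for i in range(2000)]
    d = dict.fromkeys(keys)               # quadratic number of __eq__ calls

    start = time.time()
    for k in keys:
        _ = d[k]                          # each lookup walks the collision chain
    print("lookups took %.2fs" % (time.time() - start))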

Author: Davies Liu <da...@databricks.com>

Closes #4810 from davies/leak3 and squashes the following commits:

48d643d [Davies Liu] Update sql.py
968a28c [Davies Liu] fix __eq__ of singleton
ac9db57 [Davies Liu] fix tests
f748114 [Davies Liu] fix incorrect DataType.__eq__


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/91d0effb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/91d0effb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/91d0effb

Branch: refs/heads/branch-1.1
Commit: 91d0effb32f741292b76608661ede302b72d8cc1
Parents: 814934d
Author: Davies Liu <da...@databricks.com>
Authored: Fri Feb 27 20:06:03 2015 -0800
Committer: Josh Rosen <jo...@databricks.com>
Committed: Fri Feb 27 20:06:03 2015 -0800

----------------------------------------------------------------------
 python/pyspark/sql.py | 43 +++++++++++++++++++++----------------------
 1 file changed, 21 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/91d0effb/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 07b39c9..b6bb0a0 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -24,6 +24,7 @@ import decimal
 import datetime
 import keyword
 import warnings
+import weakref
 from array import array
 from operator import itemgetter
 
@@ -55,8 +56,7 @@ class DataType(object):
         return hash(str(self))
 
     def __eq__(self, other):
-        return (isinstance(other, self.__class__) and
-                self.__dict__ == other.__dict__)
+        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
 
     def __ne__(self, other):
         return not self.__eq__(other)
@@ -80,10 +80,6 @@ class PrimitiveType(DataType):
 
     __metaclass__ = PrimitiveTypeSingleton
 
-    def __eq__(self, other):
-        # because they should be the same object
-        return self is other
-
 
 class StringType(PrimitiveType):
 
@@ -192,9 +188,9 @@ class ArrayType(DataType):
         :param elementType: the data type of elements.
         :param containsNull: indicates whether the list contains None values.
 
-        >>> ArrayType(StringType) == ArrayType(StringType, True)
+        >>> ArrayType(StringType()) == ArrayType(StringType(), True)
         True
-        >>> ArrayType(StringType, False) == ArrayType(StringType)
+        >>> ArrayType(StringType(), False) == ArrayType(StringType())
         False
         """
         self.elementType = elementType
@@ -229,11 +225,11 @@ class MapType(DataType):
         :param valueContainsNull: indicates whether values contains
         null values.
 
-        >>> (MapType(StringType, IntegerType)
-        ...        == MapType(StringType, IntegerType, True))
+        >>> (MapType(StringType(), IntegerType())
+        ...        == MapType(StringType(), IntegerType(), True))
         True
-        >>> (MapType(StringType, IntegerType, False)
-        ...        == MapType(StringType, FloatType))
+        >>> (MapType(StringType(), IntegerType(), False)
+        ...        == MapType(StringType(), FloatType()))
         False
         """
         self.keyType = keyType
@@ -267,11 +263,11 @@ class StructField(DataType):
         :param nullable: indicates whether values of this field
                          can be null.
 
-        >>> (StructField("f1", StringType, True)
-        ...      == StructField("f1", StringType, True))
+        >>> (StructField("f1", StringType(), True)
+        ...      == StructField("f1", StringType(), True))
         True
-        >>> (StructField("f1", StringType, True)
-        ...      == StructField("f2", StringType, True))
+        >>> (StructField("f1", StringType(), True)
+        ...      == StructField("f2", StringType(), True))
         False
         """
         self.name = name
@@ -295,13 +291,13 @@ class StructType(DataType):
     def __init__(self, fields):
         """Creates a StructType
 
-        >>> struct1 = StructType([StructField("f1", StringType, True)])
-        >>> struct2 = StructType([StructField("f1", StringType, True)])
+        >>> struct1 = StructType([StructField("f1", StringType(), True)])
+        >>> struct2 = StructType([StructField("f1", StringType(), True)])
         >>> struct1 == struct2
         True
-        >>> struct1 = StructType([StructField("f1", StringType, True)])
-        >>> struct2 = StructType([StructField("f1", StringType, True),
-        ...   [StructField("f2", IntegerType, False)]])
+        >>> struct1 = StructType([StructField("f1", StringType(), True)])
+        >>> struct2 = StructType([StructField("f1", StringType(), True),
+        ...                       StructField("f2", IntegerType(), False)])
         >>> struct1 == struct2
         False
         """
@@ -343,6 +339,9 @@ _all_primitive_types = dict((k, v) for k, v in globals().iteritems()
 def _parse_datatype_string(datatype_string):
     """Parses the given data type string.
 
+    >>> import pickle
+    >>> LongType() == pickle.loads(pickle.dumps(LongType()))
+    True
     >>> def check_datatype(datatype):
     ...     scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype))
     ...     python_datatype = _parse_datatype_string(
@@ -751,7 +750,7 @@ def _verify_type(obj, dataType):
             _verify_type(v, f.dataType)
 
 
-_cached_cls = {}
+_cached_cls = weakref.WeakValueDictionary()
 
 
 def _restore_object(dataType, obj):

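For context on the _cached_cls change above: a minimal sketch (make_row_class and the schema string are hypothetical, not the PySpark API) of how weakref.WeakValueDictionary lets dynamically created classes be garbage-collected once nothing else references them, instead of accumulating forever in a plain dict:

    import gc
    import weakref

    cache = weakref.WeakValueDictionary()

    def make_row_class(schema):
        # Classes created with type() are weakly referenceable, so the
        # cache entry vanishes when the last strong reference is dropped.
        cls = type("Row", (object,), {"schema": schema})
        cache[schema] = cls
        return cls

    cls = make_row_class("struct<f1:string>")
    print(len(cache))   # 1 while a strong reference exists
    del cls
    gc.collect()        # CPython frees it promptly; gc.collect() for safety
    print(len(cache))   # 0 -- the cached class was released
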

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org