You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jo...@apache.org on 2015/02/28 05:07:48 UTC

spark git commit: [SPARK-6055] [PySpark] fix incorrect __eq__ of DataType

Repository: spark
Updated Branches:
  refs/heads/master 8c468a660 -> e0e64ba4b


[SPARK-6055] [PySpark] fix incorrect __eq__ of DataType

The __eq__ of DataType is not correct, so the class cache is not used correctly (a created class cannot be found by its dataType). As a result, lots of classes are created (saved in _cached_cls) and never released.

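Also, all instances of the same DataType have the same hash code, yet they may not compare equal (for example, after being pickled and unpickled). This puts many distinct objects with the same hash code into one dict, degrading into the equivalent of a hash-collision attack. Accessing that dict then becomes very slow (how slow depends on the CPython dict implementation).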

This PR also improves the performance of inferSchema (by avoiding unnecessary conversion of objects).

cc pwendell  JoshRosen

Author: Davies Liu <da...@databricks.com>

Closes #4808 from davies/leak and squashes the following commits:

6a322a4 [Davies Liu] tests refactor
3da44fc [Davies Liu] fix __eq__ of Singleton
534ac90 [Davies Liu] add more checks
46999dc [Davies Liu] fix tests
d9ae973 [Davies Liu] fix memory leak in sql


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e0e64ba4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e0e64ba4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e0e64ba4

Branch: refs/heads/master
Commit: e0e64ba4b1b8eb72e856286f756c65fa22ab0a36
Parents: 8c468a6
Author: Davies Liu <da...@databricks.com>
Authored: Fri Feb 27 20:07:17 2015 -0800
Committer: Josh Rosen <jo...@databricks.com>
Committed: Fri Feb 27 20:07:17 2015 -0800

----------------------------------------------------------------------
 python/pyspark/sql/context.py   |  90 +-------------------------
 python/pyspark/sql/dataframe.py |   4 +-
 python/pyspark/sql/tests.py     |   9 +++
 python/pyspark/sql/types.py     | 120 +++++++++++++++++++++--------------
 4 files changed, 86 insertions(+), 137 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e0e64ba4/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 5d7aeb6..795ef0d 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -17,7 +17,6 @@
 
 import warnings
 import json
-from array import array
 from itertools import imap
 
 from py4j.protocol import Py4JError
@@ -25,7 +24,7 @@ from py4j.java_collections import MapConverter
 
 from pyspark.rdd import RDD, _prepare_for_python_RDD
 from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
-from pyspark.sql.types import StringType, StructType, _verify_type, \
+from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
     _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
 from pyspark.sql.dataframe import DataFrame
 
@@ -620,93 +619,6 @@ class HiveContext(SQLContext):
         return self._jvm.HiveContext(self._jsc.sc())
 
 
-def _create_row(fields, values):
-    row = Row(*values)
-    row.__FIELDS__ = fields
-    return row
-
-
-class Row(tuple):
-
-    """
-    A row in L{DataFrame}. The fields in it can be accessed like attributes.
-
-    Row can be used to create a row object by using named arguments,
-    the fields will be sorted by names.
-
-    >>> row = Row(name="Alice", age=11)
-    >>> row
-    Row(age=11, name='Alice')
-    >>> row.name, row.age
-    ('Alice', 11)
-
-    Row also can be used to create another Row like class, then it
-    could be used to create Row objects, such as
-
-    >>> Person = Row("name", "age")
-    >>> Person
-    <Row(name, age)>
-    >>> Person("Alice", 11)
-    Row(name='Alice', age=11)
-    """
-
-    def __new__(self, *args, **kwargs):
-        if args and kwargs:
-            raise ValueError("Can not use both args "
-                             "and kwargs to create Row")
-        if args:
-            # create row class or objects
-            return tuple.__new__(self, args)
-
-        elif kwargs:
-            # create row objects
-            names = sorted(kwargs.keys())
-            values = tuple(kwargs[n] for n in names)
-            row = tuple.__new__(self, values)
-            row.__FIELDS__ = names
-            return row
-
-        else:
-            raise ValueError("No args or kwargs")
-
-    def asDict(self):
-        """
-        Return as an dict
-        """
-        if not hasattr(self, "__FIELDS__"):
-            raise TypeError("Cannot convert a Row class into dict")
-        return dict(zip(self.__FIELDS__, self))
-
-    # let obect acs like class
-    def __call__(self, *args):
-        """create new Row object"""
-        return _create_row(self, args)
-
-    def __getattr__(self, item):
-        if item.startswith("__"):
-            raise AttributeError(item)
-        try:
-            # it will be slow when it has many fields,
-            # but this will not be used in normal cases
-            idx = self.__FIELDS__.index(item)
-            return self[idx]
-        except IndexError:
-            raise AttributeError(item)
-
-    def __reduce__(self):
-        if hasattr(self, "__FIELDS__"):
-            return (_create_row, (self.__FIELDS__, tuple(self)))
-        else:
-            return tuple.__reduce__(self)
-
-    def __repr__(self):
-        if hasattr(self, "__FIELDS__"):
-            return "Row(%s)" % ", ".join("%s=%r" % (k, v)
-                                         for k, v in zip(self.__FIELDS__, self))
-        else:
-            return "<Row(%s)>" % ", ".join(self)
-
-
 def _test():
     import doctest
     from pyspark.context import SparkContext

http://git-wip-us.apache.org/repos/asf/spark/blob/e0e64ba4/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index aec9901..5c3b737 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1025,10 +1025,12 @@ class Column(object):
             ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
             jdt = ssql_ctx.parseDataType(dataType.json())
             jc = self._jc.cast(jdt)
+        else:
+            raise TypeError("unexpected type: %s" % type(dataType))
         return Column(jc)
 
     def __repr__(self):
-        return 'Column<%s>' % self._jdf.toString().encode('utf8')
+        return 'Column<%s>' % self._jc.toString().encode('utf8')
 
 
 def _test():

http://git-wip-us.apache.org/repos/asf/spark/blob/e0e64ba4/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 83899ad..2720439 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -24,6 +24,7 @@ import sys
 import pydoc
 import shutil
 import tempfile
+import pickle
 
 import py4j
 
@@ -88,6 +89,14 @@ class ExamplePoint:
             other.x == self.x and other.y == self.y
 
 
+class DataTypeTests(unittest.TestCase):
+    # regression test for SPARK-6055
+    def test_data_type_eq(self):
+        lt = LongType()
+        lt2 = pickle.loads(pickle.dumps(LongType()))
+        self.assertEquals(lt, lt2)
+
+
 class SQLTests(ReusedPySparkTestCase):
 
     @classmethod

http://git-wip-us.apache.org/repos/asf/spark/blob/e0e64ba4/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 0f5dc2b..31a861e 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -21,6 +21,7 @@ import keyword
 import warnings
 import json
 import re
+import weakref
 from array import array
 from operator import itemgetter
 
@@ -42,8 +43,7 @@ class DataType(object):
         return hash(str(self))
 
     def __eq__(self, other):
-        return (isinstance(other, self.__class__) and
-                self.__dict__ == other.__dict__)
+        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
 
     def __ne__(self, other):
         return not self.__eq__(other)
@@ -64,6 +64,8 @@ class DataType(object):
                           sort_keys=True)
 
 
+# This singleton pattern does not work with pickle, you will get
+# another object after pickle and unpickle
 class PrimitiveTypeSingleton(type):
 
     """Metaclass for PrimitiveType"""
@@ -82,10 +84,6 @@ class PrimitiveType(DataType):
 
     __metaclass__ = PrimitiveTypeSingleton
 
-    def __eq__(self, other):
-        # because they should be the same object
-        return self is other
-
 
 class NullType(PrimitiveType):
 
@@ -242,11 +240,12 @@ class ArrayType(DataType):
         :param elementType: the data type of elements.
         :param containsNull: indicates whether the list contains None values.
 
-        >>> ArrayType(StringType) == ArrayType(StringType, True)
+        >>> ArrayType(StringType()) == ArrayType(StringType(), True)
         True
-        >>> ArrayType(StringType, False) == ArrayType(StringType)
+        >>> ArrayType(StringType(), False) == ArrayType(StringType())
         False
         """
+        assert isinstance(elementType, DataType), "elementType should be DataType"
         self.elementType = elementType
         self.containsNull = containsNull
 
@@ -292,13 +291,15 @@ class MapType(DataType):
         :param valueContainsNull: indicates whether values contains
         null values.
 
-        >>> (MapType(StringType, IntegerType)
-        ...        == MapType(StringType, IntegerType, True))
+        >>> (MapType(StringType(), IntegerType())
+        ...        == MapType(StringType(), IntegerType(), True))
         True
-        >>> (MapType(StringType, IntegerType, False)
-        ...        == MapType(StringType, FloatType))
+        >>> (MapType(StringType(), IntegerType(), False)
+        ...        == MapType(StringType(), FloatType()))
         False
         """
+        assert isinstance(keyType, DataType), "keyType should be DataType"
+        assert isinstance(valueType, DataType), "valueType should be DataType"
         self.keyType = keyType
         self.valueType = valueType
         self.valueContainsNull = valueContainsNull
@@ -348,13 +349,14 @@ class StructField(DataType):
                          to simple type that can be serialized to JSON
                          automatically
 
-        >>> (StructField("f1", StringType, True)
-        ...      == StructField("f1", StringType, True))
+        >>> (StructField("f1", StringType(), True)
+        ...      == StructField("f1", StringType(), True))
         True
-        >>> (StructField("f1", StringType, True)
-        ...      == StructField("f2", StringType, True))
+        >>> (StructField("f1", StringType(), True)
+        ...      == StructField("f2", StringType(), True))
         False
         """
+        assert isinstance(dataType, DataType), "dataType should be DataType"
         self.name = name
         self.dataType = dataType
         self.nullable = nullable
@@ -393,16 +395,17 @@ class StructType(DataType):
     def __init__(self, fields):
         """Creates a StructType
 
-        >>> struct1 = StructType([StructField("f1", StringType, True)])
-        >>> struct2 = StructType([StructField("f1", StringType, True)])
+        >>> struct1 = StructType([StructField("f1", StringType(), True)])
+        >>> struct2 = StructType([StructField("f1", StringType(), True)])
         >>> struct1 == struct2
         True
-        >>> struct1 = StructType([StructField("f1", StringType, True)])
-        >>> struct2 = StructType([StructField("f1", StringType, True),
-        ...   [StructField("f2", IntegerType, False)]])
+        >>> struct1 = StructType([StructField("f1", StringType(), True)])
+        >>> struct2 = StructType([StructField("f1", StringType(), True),
+        ...     StructField("f2", IntegerType(), False)])
         >>> struct1 == struct2
         False
         """
+        assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType"
         self.fields = fields
 
     def simpleString(self):
@@ -505,20 +508,24 @@ _all_complex_types = dict((v.typeName(), v)
 
 def _parse_datatype_json_string(json_string):
     """Parses the given data type JSON string.
+    >>> import pickle
     >>> def check_datatype(datatype):
+    ...     pickled = pickle.loads(pickle.dumps(datatype))
+    ...     assert datatype == pickled
     ...     scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
     ...     python_datatype = _parse_datatype_json_string(scala_datatype.json())
-    ...     return datatype == python_datatype
-    >>> all(check_datatype(cls()) for cls in _all_primitive_types.values())
-    True
+    ...     assert datatype == python_datatype
+    >>> for cls in _all_primitive_types.values():
+    ...     check_datatype(cls())
+
     >>> # Simple ArrayType.
     >>> simple_arraytype = ArrayType(StringType(), True)
     >>> check_datatype(simple_arraytype)
-    True
+
     >>> # Simple MapType.
     >>> simple_maptype = MapType(StringType(), LongType())
     >>> check_datatype(simple_maptype)
-    True
+
     >>> # Simple StructType.
     >>> simple_structtype = StructType([
     ...     StructField("a", DecimalType(), False),
@@ -526,7 +533,7 @@ def _parse_datatype_json_string(json_string):
     ...     StructField("c", LongType(), True),
     ...     StructField("d", BinaryType(), False)])
     >>> check_datatype(simple_structtype)
-    True
+
     >>> # Complex StructType.
     >>> complex_structtype = StructType([
     ...     StructField("simpleArray", simple_arraytype, True),
@@ -535,22 +542,20 @@ def _parse_datatype_json_string(json_string):
     ...     StructField("boolean", BooleanType(), False),
     ...     StructField("withMeta", DoubleType(), False, {"name": "age"})])
     >>> check_datatype(complex_structtype)
-    True
+
     >>> # Complex ArrayType.
     >>> complex_arraytype = ArrayType(complex_structtype, True)
     >>> check_datatype(complex_arraytype)
-    True
+
     >>> # Complex MapType.
     >>> complex_maptype = MapType(complex_structtype,
     ...                           complex_arraytype, False)
     >>> check_datatype(complex_maptype)
-    True
+
     >>> check_datatype(ExamplePointUDT())
-    True
     >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
     ...                                   StructField("point", ExamplePointUDT(), False)])
     >>> check_datatype(structtype_with_udt)
-    True
     """
     return _parse_datatype_json_value(json.loads(json_string))
 
@@ -786,8 +791,24 @@ def _merge_type(a, b):
         return a
 
 
+def _need_converter(dataType):
+    if isinstance(dataType, StructType):
+        return True
+    elif isinstance(dataType, ArrayType):
+        return _need_converter(dataType.elementType)
+    elif isinstance(dataType, MapType):
+        return _need_converter(dataType.keyType) or _need_converter(dataType.valueType)
+    elif isinstance(dataType, NullType):
+        return True
+    else:
+        return False
+
+
 def _create_converter(dataType):
     """Create an converter to drop the names of fields in obj """
+    if not _need_converter(dataType):
+        return lambda x: x
+
     if isinstance(dataType, ArrayType):
         conv = _create_converter(dataType.elementType)
         return lambda row: map(conv, row)
@@ -806,13 +827,17 @@ def _create_converter(dataType):
     # dataType must be StructType
     names = [f.name for f in dataType.fields]
     converters = [_create_converter(f.dataType) for f in dataType.fields]
+    convert_fields = any(_need_converter(f.dataType) for f in dataType.fields)
 
     def convert_struct(obj):
         if obj is None:
             return
 
         if isinstance(obj, (tuple, list)):
-            return tuple(conv(v) for v, conv in zip(obj, converters))
+            if convert_fields:
+                return tuple(conv(v) for v, conv in zip(obj, converters))
+            else:
+                return tuple(obj)
 
         if isinstance(obj, dict):
             d = obj
@@ -821,7 +846,10 @@ def _create_converter(dataType):
         else:
             raise ValueError("Unexpected obj: %s" % obj)
 
-        return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
+        if convert_fields:
+            return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
+        else:
+            return tuple([d.get(name) for name in names])
 
     return convert_struct
 
@@ -871,20 +899,20 @@ def _parse_field_abstract(s):
     Parse a field in schema abstract
 
     >>> _parse_field_abstract("a")
-    StructField(a,None,true)
+    StructField(a,NullType,true)
     >>> _parse_field_abstract("b(c d)")
-    StructField(b,StructType(...c,None,true),StructField(d...
+    StructField(b,StructType(...c,NullType,true),StructField(d...
     >>> _parse_field_abstract("a[]")
-    StructField(a,ArrayType(None,true),true)
+    StructField(a,ArrayType(NullType,true),true)
     >>> _parse_field_abstract("a{[]}")
-    StructField(a,MapType(None,ArrayType(None,true),true),true)
+    StructField(a,MapType(NullType,ArrayType(NullType,true),true),true)
     """
     if set(_BRACKETS.keys()) & set(s):
         idx = min((s.index(c) for c in _BRACKETS if c in s))
         name = s[:idx]
         return StructField(name, _parse_schema_abstract(s[idx:]), True)
     else:
-        return StructField(s, None, True)
+        return StructField(s, NullType(), True)
 
 
 def _parse_schema_abstract(s):
@@ -898,11 +926,11 @@ def _parse_schema_abstract(s):
     >>> _parse_schema_abstract("c{} d{a b}")
     StructType...c,MapType...d,MapType...a...b...
     >>> _parse_schema_abstract("a b(t)").fields[1]
-    StructField(b,StructType(List(StructField(t,None,true))),true)
+    StructField(b,StructType(List(StructField(t,NullType,true))),true)
     """
     s = s.strip()
     if not s:
-        return
+        return NullType()
 
     elif s.startswith('('):
         return _parse_schema_abstract(s[1:-1])
@@ -911,7 +939,7 @@ def _parse_schema_abstract(s):
         return ArrayType(_parse_schema_abstract(s[1:-1]), True)
 
     elif s.startswith('{'):
-        return MapType(None, _parse_schema_abstract(s[1:-1]))
+        return MapType(NullType(), _parse_schema_abstract(s[1:-1]))
 
     parts = _split_schema_abstract(s)
     fields = [_parse_field_abstract(p) for p in parts]
@@ -931,7 +959,7 @@ def _infer_schema_type(obj, dataType):
     >>> _infer_schema_type(row, schema)
     StructType...a,ArrayType...b,MapType(StringType,...c,LongType...
     """
-    if dataType is None:
+    if dataType is NullType():
         return _infer_type(obj)
 
     if not obj:
@@ -1037,8 +1065,7 @@ def _verify_type(obj, dataType):
         for v, f in zip(obj, dataType.fields):
             _verify_type(v, f.dataType)
 
-
-_cached_cls = {}
+_cached_cls = weakref.WeakValueDictionary()
 
 
 def _restore_object(dataType, obj):
@@ -1233,8 +1260,7 @@ class Row(tuple):
         elif kwargs:
             # create row objects
             names = sorted(kwargs.keys())
-            values = tuple(kwargs[n] for n in names)
-            row = tuple.__new__(self, values)
+            row = tuple.__new__(self, [kwargs[n] for n in names])
             row.__FIELDS__ = names
             return row
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org