Posted to commits@spark.apache.org by jo...@apache.org on 2015/04/12 07:33:41 UTC

spark git commit: [SPARK-6677] [SQL] [PySpark] fix cached classes

Repository: spark
Updated Branches:
  refs/heads/master 0cc8fcb4c -> 5d8f7b9e8


[SPARK-6677] [SQL] [PySpark] fix cached classes

It's possible to have two DataType objects with the same id (memory address) at different times, so we should check the cached class to verify that it was generated from the given datatype.
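
As a minimal illustration of why keying a cache on id() alone is unsafe (CPython-specific; the bare DataType class below is just a stand-in for pyspark.sql.types.DataType):

    class DataType(object):
        pass

    old_id = id(DataType())      # the temporary is freed right away
    fresh = DataType()           # may be allocated at the freed address
    print(old_id == id(fresh))   # often True in CPython, not guaranteed

A cache keyed only on id() could hand back the class that was built for the dead object.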

This PR also changes `__FIELDS__` and `__DATATYPE__` to lowercase to match Python code style.
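
For example, with this patch applied, the field names of a Row built from keyword arguments live in the lowercase attribute (assuming a PySpark build that contains this change):

    from pyspark.sql import Row

    r = Row(name="Alice", age=1)
    print(r.__fields__)   # ['age', 'name'] -- kwargs are sorted by name
    print(r.asDict())     # {'age': 1, 'name': 'Alice'} (key order may vary)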

Author: Davies Liu <da...@databricks.com>

Closes #5445 from davies/fix_type_cache and squashes the following commits:

63b3238 [Davies Liu] typo
47bdede [Davies Liu] fix cached classes


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5d8f7b9e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5d8f7b9e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5d8f7b9e

Branch: refs/heads/master
Commit: 5d8f7b9e87e8066d54717a1a78b06e8531d8b0d4
Parents: 0cc8fcb
Author: Davies Liu <da...@databricks.com>
Authored: Sat Apr 11 22:33:23 2015 -0700
Committer: Josh Rosen <jo...@databricks.com>
Committed: Sat Apr 11 22:33:23 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/types.py | 39 ++++++++++++++++++++-------------------
 1 file changed, 20 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5d8f7b9e/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 7e0124b..ef76d84 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -567,8 +567,8 @@ def _infer_schema(row):
     elif isinstance(row, (tuple, list)):
         if hasattr(row, "_fields"):  # namedtuple
             items = zip(row._fields, tuple(row))
-        elif hasattr(row, "__FIELDS__"):  # Row
-            items = zip(row.__FIELDS__, tuple(row))
+        elif hasattr(row, "__fields__"):  # Row
+            items = zip(row.__fields__, tuple(row))
         else:
             names = ['_%d' % i for i in range(1, len(row) + 1)]
             items = zip(names, row)
@@ -647,7 +647,7 @@ def _python_to_sql_converter(dataType):
             if isinstance(obj, dict):
                 return tuple(c(obj.get(n)) for n, c in zip(names, converters))
             elif isinstance(obj, tuple):
-                if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"):
+                if hasattr(obj, "_fields") or hasattr(obj, "__fields__"):
                     return tuple(c(v) for c, v in zip(converters, obj))
                 elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):  # k-v pairs
                     d = dict(obj)
@@ -997,12 +997,13 @@ def _restore_object(dataType, obj):
     # same object in most cases.
     k = id(dataType)
     cls = _cached_cls.get(k)
-    if cls is None:
+    if cls is None or cls.__datatype is not dataType:
         # use dataType as key to avoid creating multiple classes
         cls = _cached_cls.get(dataType)
         if cls is None:
             cls = _create_cls(dataType)
             _cached_cls[dataType] = cls
+        cls.__datatype = dataType
         _cached_cls[k] = cls
     return cls(obj)
 
@@ -1119,8 +1120,8 @@ def _create_cls(dataType):
     class Row(tuple):
 
         """ Row in DataFrame """
-        __DATATYPE__ = dataType
-        __FIELDS__ = tuple(f.name for f in dataType.fields)
+        __datatype = dataType
+        __fields__ = tuple(f.name for f in dataType.fields)
         __slots__ = ()
 
         # create property for fast access
@@ -1128,22 +1129,22 @@ def _create_cls(dataType):
 
         def asDict(self):
             """ Return as a dict """
-            return dict((n, getattr(self, n)) for n in self.__FIELDS__)
+            return dict((n, getattr(self, n)) for n in self.__fields__)
 
         def __repr__(self):
             # call collect __repr__ for nested objects
             return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n))
-                                          for n in self.__FIELDS__))
+                                          for n in self.__fields__))
 
         def __reduce__(self):
-            return (_restore_object, (self.__DATATYPE__, tuple(self)))
+            return (_restore_object, (self.__datatype, tuple(self)))
 
     return Row
 
 
 def _create_row(fields, values):
     row = Row(*values)
-    row.__FIELDS__ = fields
+    row.__fields__ = fields
     return row
 
 
@@ -1183,7 +1184,7 @@ class Row(tuple):
             # create row objects
             names = sorted(kwargs.keys())
             row = tuple.__new__(self, [kwargs[n] for n in names])
-            row.__FIELDS__ = names
+            row.__fields__ = names
             return row
 
         else:
@@ -1193,11 +1194,11 @@ class Row(tuple):
         """
         Return as a dict
         """
-        if not hasattr(self, "__FIELDS__"):
+        if not hasattr(self, "__fields__"):
             raise TypeError("Cannot convert a Row class into dict")
-        return dict(zip(self.__FIELDS__, self))
+        return dict(zip(self.__fields__, self))
 
-    # let obect acs like class
+    # let object act like a class
     def __call__(self, *args):
         """create new Row object"""
         return _create_row(self, args)
@@ -1208,21 +1209,21 @@ class Row(tuple):
         try:
             # it will be slow when it has many fields,
             # but this will not be used in normal cases
-            idx = self.__FIELDS__.index(item)
+            idx = self.__fields__.index(item)
             return self[idx]
         except IndexError:
             raise AttributeError(item)
 
     def __reduce__(self):
-        if hasattr(self, "__FIELDS__"):
-            return (_create_row, (self.__FIELDS__, tuple(self)))
+        if hasattr(self, "__fields__"):
+            return (_create_row, (self.__fields__, tuple(self)))
         else:
             return tuple.__reduce__(self)
 
     def __repr__(self):
-        if hasattr(self, "__FIELDS__"):
+        if hasattr(self, "__fields__"):
             return "Row(%s)" % ", ".join("%s=%r" % (k, v)
-                                         for k, v in zip(self.__FIELDS__, self))
+                                         for k, v in zip(self.__fields__, tuple(self)))
         else:
             return "<Row(%s)>" % ", ".join(self)
 
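Stripped of the surrounding module, the fixed lookup is a double-keyed cache: a fast path keyed by id(), a slow path keyed by the DataType value (DataType instances in this module are hashable and compare by value), and a back-reference on the cached class to detect a recycled id. A standalone sketch of the pattern, with make_cls as a hypothetical stand-in for _create_cls:

    _cached_cls = {}

    def lookup_cls(datatype):
        k = id(datatype)                     # fast path: memory address
        cls = _cached_cls.get(k)
        if cls is None or cls._datatype is not datatype:  # id was recycled
            cls = _cached_cls.get(datatype)  # slow path: keyed by value
            if cls is None:
                cls = make_cls(datatype)     # hypothetical class factory
                _cached_cls[datatype] = cls
            cls._datatype = datatype         # guards the id-based entry
            _cached_cls[k] = cls
        return cls

The `is not` identity check is what makes the recycled-id case safe: an equal DataType value at the same address still forces re-validation against the value-keyed entry.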


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org