You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2015/07/30 07:30:53 UTC

spark git commit: [SPARK-9116] [SQL] [PYSPARK] support Python only UDT in __main__

Repository: spark
Updated Branches:
  refs/heads/master f5dd11339 -> e044705b4


[SPARK-9116] [SQL] [PYSPARK] support Python only UDT in __main__

This also makes it possible to create a Python UDT without a paired Scala one, which is important for Python users.

cc mengxr JoshRosen

Author: Davies Liu <da...@databricks.com>

Closes #7453 from davies/class_in_main and squashes the following commits:

4dfd5e1 [Davies Liu] add tests for Python and Scala UDT
793d9b2 [Davies Liu] Merge branch 'master' of github.com:apache/spark into class_in_main
dc65f19 [Davies Liu] address comment
a9a3c40 [Davies Liu] Merge branch 'master' of github.com:apache/spark into class_in_main
a86e1fc [Davies Liu] fix serialization
ad528ba [Davies Liu] Merge branch 'master' of github.com:apache/spark into class_in_main
63f52ef [Davies Liu] fix pylint check
655b8a9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into class_in_main
316a394 [Davies Liu] support Python UDT with UTF
0bcb3ef [Davies Liu] fix bug in mllib
de986d6 [Davies Liu] fix test
83d65ac [Davies Liu] fix bug in StructType
55bb86e [Davies Liu] support Python UDT in __main__ (without Scala one)


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e044705b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e044705b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e044705b

Branch: refs/heads/master
Commit: e044705b4402f86d0557ecd146f3565388c7eeb4
Parents: f5dd113
Author: Davies Liu <da...@databricks.com>
Authored: Wed Jul 29 22:30:49 2015 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Wed Jul 29 22:30:49 2015 -0700

----------------------------------------------------------------------
 pylintrc                                        |   2 +-
 python/pyspark/cloudpickle.py                   |  38 ++++++-
 python/pyspark/shuffle.py                       |   2 +-
 python/pyspark/sql/context.py                   | 108 +++++++++++-------
 python/pyspark/sql/tests.py                     | 112 +++++++++++++++++--
 python/pyspark/sql/types.py                     |  78 ++++++-------
 .../org/apache/spark/sql/types/DataType.scala   |   9 ++
 .../spark/sql/types/UserDefinedType.scala       |  29 +++++
 .../apache/spark/sql/execution/pythonUDFs.scala |   1 -
 9 files changed, 286 insertions(+), 93 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e044705b/pylintrc
----------------------------------------------------------------------
diff --git a/pylintrc b/pylintrc
index 0617759..6a67577 100644
--- a/pylintrc
+++ b/pylintrc
@@ -84,7 +84,7 @@ enable=
 # If you would like to improve the code quality of pyspark, remove any of these disabled errors
 # run ./dev/lint-python and see if the errors raised by pylint can be fixed.
 
-disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument
 ,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable
+disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument
 ,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable,undefined-loop-variable
 
 
 [REPORTS]

http://git-wip-us.apache.org/repos/asf/spark/blob/e044705b/python/pyspark/cloudpickle.py
----------------------------------------------------------------------
diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py
index 9ef9307..3b64798 100644
--- a/python/pyspark/cloudpickle.py
+++ b/python/pyspark/cloudpickle.py
@@ -350,7 +350,26 @@ class CloudPickler(Pickler):
             if new_override:
                 d['__new__'] = obj.__new__
 
-            self.save_reduce(typ, (obj.__name__, obj.__bases__, d), obj=obj)
+            self.save(_load_class)
+            self.save_reduce(typ, (obj.__name__, obj.__bases__, {"__doc__": obj.__doc__}), obj=obj)
+            d.pop('__doc__', None)
+            # handle property and staticmethod
+            dd = {}
+            for k, v in d.items():
+                if isinstance(v, property):
+                    k = ('property', k)
+                    v = (v.fget, v.fset, v.fdel, v.__doc__)
+                elif isinstance(v, staticmethod) and hasattr(v, '__func__'):
+                    k = ('staticmethod', k)
+                    v = v.__func__
+                elif isinstance(v, classmethod) and hasattr(v, '__func__'):
+                    k = ('classmethod', k)
+                    v = v.__func__
+                dd[k] = v
+            self.save(dd)
+            self.write(pickle.TUPLE2)
+            self.write(pickle.REDUCE)
+
         else:
             raise pickle.PicklingError("Can't pickle %r" % obj)
 
@@ -708,6 +727,23 @@ def _make_skel_func(code, closures, base_globals = None):
                               None, None, closure)
 
 
+def _load_class(cls, d):
+    """
+    Loads additional properties into class `cls`.
+    """
+    for k, v in d.items():
+        if isinstance(k, tuple):
+            typ, k = k
+            if typ == 'property':
+                v = property(*v)
+            elif typ == 'staticmethod':
+                v = staticmethod(v)
+            elif typ == 'classmethod':
+                v = classmethod(v)
+        setattr(cls, k, v)
+    return cls
+
+
 """Constructors for 3rd party libraries
 Note: These can never be renamed due to client compatibility issues"""
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e044705b/python/pyspark/shuffle.py
----------------------------------------------------------------------
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 8fb71ba..b8118bd 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -606,7 +606,7 @@ class ExternalList(object):
         if not os.path.exists(d):
             os.makedirs(d)
         p = os.path.join(d, str(id(self)))
-        self._file = open(p, "wb+", 65536)
+        self._file = open(p, "w+b", 65536)
         self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024)
         os.unlink(p)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e044705b/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index abb6522..917de24 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -277,6 +277,66 @@ class SQLContext(object):
 
         return self.createDataFrame(rdd, schema)
 
+    def _createFromRDD(self, rdd, schema, samplingRatio):
+        """
+        Create an RDD for DataFrame from an existing RDD, returns the RDD and schema.
+        """
+        if schema is None or isinstance(schema, (list, tuple)):
+            struct = self._inferSchema(rdd, samplingRatio)
+            converter = _create_converter(struct)
+            rdd = rdd.map(converter)
+            if isinstance(schema, (list, tuple)):
+                for i, name in enumerate(schema):
+                    struct.fields[i].name = name
+                    struct.names[i] = name
+            schema = struct
+
+        elif isinstance(schema, StructType):
+            # take the first few rows to verify schema
+            rows = rdd.take(10)
+            for row in rows:
+                _verify_type(row, schema)
+
+        else:
+            raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
+
+        # convert python objects to sql data
+        rdd = rdd.map(schema.toInternal)
+        return rdd, schema
+
+    def _createFromLocal(self, data, schema):
+        """
+        Create an RDD for DataFrame from a list or pandas.DataFrame, returns
+        the RDD and schema.
+        """
+        if has_pandas and isinstance(data, pandas.DataFrame):
+            if schema is None:
+                schema = [str(x) for x in data.columns]
+            data = [r.tolist() for r in data.to_records(index=False)]
+
+        # make sure data can be consumed multiple times
+        if not isinstance(data, list):
+            data = list(data)
+
+        if schema is None or isinstance(schema, (list, tuple)):
+            struct = self._inferSchemaFromList(data)
+            if isinstance(schema, (list, tuple)):
+                for i, name in enumerate(schema):
+                    struct.fields[i].name = name
+                    struct.names[i] = name
+            schema = struct
+
+        elif isinstance(schema, StructType):
+            for row in data:
+                _verify_type(row, schema)
+
+        else:
+            raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
+
+        # convert python objects to sql data
+        data = [schema.toInternal(row) for row in data]
+        return self._sc.parallelize(data), schema
+
     @since(1.3)
     @ignore_unicode_prefix
     def createDataFrame(self, data, schema=None, samplingRatio=None):
@@ -340,49 +400,15 @@ class SQLContext(object):
         if isinstance(data, DataFrame):
             raise TypeError("data is already a DataFrame")
 
-        if has_pandas and isinstance(data, pandas.DataFrame):
-            if schema is None:
-                schema = [str(x) for x in data.columns]
-            data = [r.tolist() for r in data.to_records(index=False)]
-
-        if not isinstance(data, RDD):
-            if not isinstance(data, list):
-                data = list(data)
-            try:
-                # data could be list, tuple, generator ...
-                rdd = self._sc.parallelize(data)
-            except Exception:
-                raise TypeError("cannot create an RDD from type: %s" % type(data))
+        if isinstance(data, RDD):
+            rdd, schema = self._createFromRDD(data, schema, samplingRatio)
         else:
-            rdd = data
-
-        if schema is None or isinstance(schema, (list, tuple)):
-            if isinstance(data, RDD):
-                struct = self._inferSchema(rdd, samplingRatio)
-            else:
-                struct = self._inferSchemaFromList(data)
-            if isinstance(schema, (list, tuple)):
-                for i, name in enumerate(schema):
-                    struct.fields[i].name = name
-            schema = struct
-            converter = _create_converter(schema)
-            rdd = rdd.map(converter)
-
-        elif isinstance(schema, StructType):
-            # take the first few rows to verify schema
-            rows = rdd.take(10)
-            for row in rows:
-                _verify_type(row, schema)
-
-        else:
-            raise TypeError("schema should be StructType or list or None")
-
-        # convert python objects to sql data
-        rdd = rdd.map(schema.toInternal)
-
+            rdd, schema = self._createFromLocal(data, schema)
         jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
-        df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
-        return DataFrame(df, self)
+        jdf = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
+        df = DataFrame(jdf, self)
+        df._schema = schema
+        return df
 
     @since(1.3)
     def registerDataFrameAsTable(self, df, tableName):

http://git-wip-us.apache.org/repos/asf/spark/blob/e044705b/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 5aa6135..ebd3ea8 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -75,7 +75,7 @@ class ExamplePointUDT(UserDefinedType):
 
     @classmethod
     def module(cls):
-        return 'pyspark.tests'
+        return 'pyspark.sql.tests'
 
     @classmethod
     def scalaUDT(cls):
@@ -106,10 +106,45 @@ class ExamplePoint:
         return "(%s,%s)" % (self.x, self.y)
 
     def __eq__(self, other):
-        return isinstance(other, ExamplePoint) and \
+        return isinstance(other, self.__class__) and \
             other.x == self.x and other.y == self.y
 
 
+class PythonOnlyUDT(UserDefinedType):
+    """
+    User-defined type (UDT) for PythonOnlyPoint.
+    """
+
+    @classmethod
+    def sqlType(self):
+        return ArrayType(DoubleType(), False)
+
+    @classmethod
+    def module(cls):
+        return '__main__'
+
+    def serialize(self, obj):
+        return [obj.x, obj.y]
+
+    def deserialize(self, datum):
+        return PythonOnlyPoint(datum[0], datum[1])
+
+    @staticmethod
+    def foo():
+        pass
+
+    @property
+    def props(self):
+        return {}
+
+
+class PythonOnlyPoint(ExamplePoint):
+    """
+    An example class to demonstrate UDT in only Python
+    """
+    __UDT__ = PythonOnlyUDT()
+
+
 class DataTypeTests(unittest.TestCase):
     # regression test for SPARK-6055
     def test_data_type_eq(self):
@@ -395,10 +430,39 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(1, row.asDict()["l"][0].a)
         self.assertEqual(1.0, row.asDict()['d']['key'].c)
 
+    def test_udt(self):
+        from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type
+        from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
+
+        def check_datatype(datatype):
+            pickled = pickle.loads(pickle.dumps(datatype))
+            assert datatype == pickled
+            scala_datatype = self.sqlCtx._ssql_ctx.parseDataType(datatype.json())
+            python_datatype = _parse_datatype_json_string(scala_datatype.json())
+            assert datatype == python_datatype
+
+        check_datatype(ExamplePointUDT())
+        structtype_with_udt = StructType([StructField("label", DoubleType(), False),
+                                          StructField("point", ExamplePointUDT(), False)])
+        check_datatype(structtype_with_udt)
+        p = ExamplePoint(1.0, 2.0)
+        self.assertEqual(_infer_type(p), ExamplePointUDT())
+        _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
+        self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT()))
+
+        check_datatype(PythonOnlyUDT())
+        structtype_with_udt = StructType([StructField("label", DoubleType(), False),
+                                          StructField("point", PythonOnlyUDT(), False)])
+        check_datatype(structtype_with_udt)
+        p = PythonOnlyPoint(1.0, 2.0)
+        self.assertEqual(_infer_type(p), PythonOnlyUDT())
+        _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
+        self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
+
     def test_infer_schema_with_udt(self):
         from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-        df = self.sc.parallelize([row]).toDF()
+        df = self.sqlCtx.createDataFrame([row])
         schema = df.schema
         field = [f for f in schema.fields if f.name == "point"][0]
         self.assertEqual(type(field.dataType), ExamplePointUDT)
@@ -406,36 +470,66 @@ class SQLTests(ReusedPySparkTestCase):
         point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
         self.assertEqual(point, ExamplePoint(1.0, 2.0))
 
+        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+        df = self.sqlCtx.createDataFrame([row])
+        schema = df.schema
+        field = [f for f in schema.fields if f.name == "point"][0]
+        self.assertEqual(type(field.dataType), PythonOnlyUDT)
+        df.registerTempTable("labeled_point")
+        point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
+        self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
+
     def test_apply_schema_with_udt(self):
         from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = (1.0, ExamplePoint(1.0, 2.0))
-        rdd = self.sc.parallelize([row])
         schema = StructType([StructField("label", DoubleType(), False),
                              StructField("point", ExamplePointUDT(), False)])
-        df = rdd.toDF(schema)
+        df = self.sqlCtx.createDataFrame([row], schema)
         point = df.head().point
         self.assertEquals(point, ExamplePoint(1.0, 2.0))
 
+        row = (1.0, PythonOnlyPoint(1.0, 2.0))
+        schema = StructType([StructField("label", DoubleType(), False),
+                             StructField("point", PythonOnlyUDT(), False)])
+        df = self.sqlCtx.createDataFrame([row], schema)
+        point = df.head().point
+        self.assertEquals(point, PythonOnlyPoint(1.0, 2.0))
+
     def test_udf_with_udt(self):
         from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-        df = self.sc.parallelize([row]).toDF()
+        df = self.sqlCtx.createDataFrame([row])
         self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
         udf = UserDefinedFunction(lambda p: p.y, DoubleType())
         self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
         udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
         self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
 
+        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+        df = self.sqlCtx.createDataFrame([row])
+        self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
+        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
+        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+        udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
+        self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
+
     def test_parquet_with_udt(self):
-        from pyspark.sql.tests import ExamplePoint
+        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-        df0 = self.sc.parallelize([row]).toDF()
+        df0 = self.sqlCtx.createDataFrame([row])
         output_dir = os.path.join(self.tempdir.name, "labeled_point")
-        df0.saveAsParquetFile(output_dir)
+        df0.write.parquet(output_dir)
         df1 = self.sqlCtx.parquetFile(output_dir)
         point = df1.head().point
         self.assertEquals(point, ExamplePoint(1.0, 2.0))
 
+        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+        df0 = self.sqlCtx.createDataFrame([row])
+        df0.write.parquet(output_dir, mode='overwrite')
+        df1 = self.sqlCtx.parquetFile(output_dir)
+        point = df1.head().point
+        self.assertEquals(point, PythonOnlyPoint(1.0, 2.0))
+
     def test_column_operators(self):
         ci = self.df.key
         cs = self.df.value

http://git-wip-us.apache.org/repos/asf/spark/blob/e044705b/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 8859308..0976aea 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -22,6 +22,7 @@ import datetime
 import calendar
 import json
 import re
+import base64
 from array import array
 
 if sys.version >= "3":
@@ -31,6 +32,8 @@ if sys.version >= "3":
 from py4j.protocol import register_input_converter
 from py4j.java_gateway import JavaClass
 
+from pyspark.serializers import CloudPickleSerializer
+
 __all__ = [
     "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
     "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
@@ -458,7 +461,7 @@ class StructType(DataType):
             self.names = [f.name for f in fields]
             assert all(isinstance(f, StructField) for f in fields),\
                 "fields should be a list of StructField"
-        self._needSerializeFields = None
+        self._needSerializeAnyField = any(f.needConversion() for f in self.fields)
 
     def add(self, field, data_type=None, nullable=True, metadata=None):
         """
@@ -501,6 +504,7 @@ class StructType(DataType):
                 data_type_f = data_type
             self.fields.append(StructField(field, data_type_f, nullable, metadata))
             self.names.append(field)
+        self._needSerializeAnyField = any(f.needConversion() for f in self.fields)
         return self
 
     def simpleString(self):
@@ -526,10 +530,7 @@ class StructType(DataType):
         if obj is None:
             return
 
-        if self._needSerializeFields is None:
-            self._needSerializeFields = any(f.needConversion() for f in self.fields)
-
-        if self._needSerializeFields:
+        if self._needSerializeAnyField:
             if isinstance(obj, dict):
                 return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields))
             elif isinstance(obj, (tuple, list)):
@@ -550,7 +551,10 @@ class StructType(DataType):
         if isinstance(obj, Row):
             # it's already converted by pickler
             return obj
-        values = [f.dataType.fromInternal(v) for f, v in zip(self.fields, obj)]
+        if self._needSerializeAnyField:
+            values = [f.fromInternal(v) for f, v in zip(self.fields, obj)]
+        else:
+            values = obj
         return _create_row(self.names, values)
 
 
@@ -581,9 +585,10 @@ class UserDefinedType(DataType):
     @classmethod
     def scalaUDT(cls):
         """
-        The class name of the paired Scala UDT.
+        The class name of the paired Scala UDT (could be '', if there
+        is no corresponding one).
         """
-        raise NotImplementedError("UDT must have a paired Scala UDT.")
+        return ''
 
     def needConversion(self):
         return True
@@ -622,22 +627,37 @@ class UserDefinedType(DataType):
         return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)
 
     def jsonValue(self):
-        schema = {
-            "type": "udt",
-            "class": self.scalaUDT(),
-            "pyClass": "%s.%s" % (self.module(), type(self).__name__),
-            "sqlType": self.sqlType().jsonValue()
-        }
+        if self.scalaUDT():
+            assert self.module() != '__main__', 'UDT in __main__ cannot work with ScalaUDT'
+            schema = {
+                "type": "udt",
+                "class": self.scalaUDT(),
+                "pyClass": "%s.%s" % (self.module(), type(self).__name__),
+                "sqlType": self.sqlType().jsonValue()
+            }
+        else:
+            ser = CloudPickleSerializer()
+            b = ser.dumps(type(self))
+            schema = {
+                "type": "udt",
+                "pyClass": "%s.%s" % (self.module(), type(self).__name__),
+                "serializedClass": base64.b64encode(b).decode('utf8'),
+                "sqlType": self.sqlType().jsonValue()
+            }
         return schema
 
     @classmethod
     def fromJson(cls, json):
-        pyUDT = json["pyClass"]
+        pyUDT = str(json["pyClass"])
         split = pyUDT.rfind(".")
         pyModule = pyUDT[:split]
         pyClass = pyUDT[split+1:]
         m = __import__(pyModule, globals(), locals(), [pyClass])
-        UDT = getattr(m, pyClass)
+        if not hasattr(m, pyClass):
+            s = base64.b64decode(json['serializedClass'].encode('utf-8'))
+            UDT = CloudPickleSerializer().loads(s)
+        else:
+            UDT = getattr(m, pyClass)
         return UDT()
 
     def __eq__(self, other):
@@ -696,11 +716,6 @@ def _parse_datatype_json_string(json_string):
     >>> complex_maptype = MapType(complex_structtype,
     ...                           complex_arraytype, False)
     >>> check_datatype(complex_maptype)
-
-    >>> check_datatype(ExamplePointUDT())
-    >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
-    ...                                   StructField("point", ExamplePointUDT(), False)])
-    >>> check_datatype(structtype_with_udt)
     """
     return _parse_datatype_json_value(json.loads(json_string))
 
@@ -752,10 +767,6 @@ if sys.version < "3":
 
 def _infer_type(obj):
     """Infer the DataType from obj
-
-    >>> p = ExamplePoint(1.0, 2.0)
-    >>> _infer_type(p)
-    ExamplePointUDT
     """
     if obj is None:
         return NullType()
@@ -1090,11 +1101,6 @@ def _verify_type(obj, dataType):
     Traceback (most recent call last):
         ...
     ValueError:...
-    >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
-    >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL
-    Traceback (most recent call last):
-        ...
-    ValueError:...
     """
     # all objects are nullable
     if obj is None:
@@ -1259,18 +1265,12 @@ register_input_converter(DateConverter())
 def _test():
     import doctest
     from pyspark.context import SparkContext
-    # let doctest run in pyspark.sql.types, so DataTypes can be picklable
-    import pyspark.sql.types
-    from pyspark.sql import Row, SQLContext
-    from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
-    globs = pyspark.sql.types.__dict__.copy()
+    from pyspark.sql import SQLContext
+    globs = globals()
     sc = SparkContext('local[4]', 'PythonTest')
     globs['sc'] = sc
     globs['sqlContext'] = SQLContext(sc)
-    globs['ExamplePoint'] = ExamplePoint
-    globs['ExamplePointUDT'] = ExamplePointUDT
-    (failure_count, test_count) = doctest.testmod(
-        pyspark.sql.types, globs=globs, optionflags=doctest.ELLIPSIS)
+    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
     globs['sc'].stop()
     if failure_count:
         exit(-1)

http://git-wip-us.apache.org/repos/asf/spark/blob/e044705b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 591fb26..f4428c2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -142,12 +142,21 @@ object DataType {
     ("type", JString("struct"))) =>
       StructType(fields.map(parseStructField))
 
+    // Scala/Java UDT
     case JSortedObject(
     ("class", JString(udtClass)),
     ("pyClass", _),
     ("sqlType", _),
     ("type", JString("udt"))) =>
       Utils.classForName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]]
+
+    // Python UDT
+    case JSortedObject(
+    ("pyClass", JString(pyClass)),
+    ("serializedClass", JString(serialized)),
+    ("sqlType", v: JValue),
+    ("type", JString("udt"))) =>
+        new PythonUserDefinedType(parseDataType(v), pyClass, serialized)
   }
 
   private def parseStructField(json: JValue): StructField = json match {

http://git-wip-us.apache.org/repos/asf/spark/blob/e044705b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index e47cfb4..4305903 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -45,6 +45,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
   /** Paired Python UDT class, if exists. */
   def pyUDT: String = null
 
+  /** Serialized Python UDT class, if exists. */
+  def serializedPyClass: String = null
+
   /**
    * Convert the user type to a SQL datum
    *
@@ -82,3 +85,29 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
   override private[sql] def acceptsType(dataType: DataType) =
     this.getClass == dataType.getClass
 }
+
+/**
+ * ::DeveloperApi::
+ * The user defined type in Python.
+ *
+ * Note: This can only be accessed via a Python UDF, or accessed as a serialized object.
+ */
+private[sql] class PythonUserDefinedType(
+    val sqlType: DataType,
+    override val pyUDT: String,
+    override val serializedPyClass: String) extends UserDefinedType[Any] {
+
+  /* The serialization is handled by UDT class in Python */
+  override def serialize(obj: Any): Any = obj
+  override def deserialize(datam: Any): Any = datam
+
+  /* There is no Java class for Python UDT */
+  override def userClass: java.lang.Class[Any] = null
+
+  override private[sql] def jsonValue: JValue = {
+    ("type" -> "udt") ~
+      ("pyClass" -> pyUDT) ~
+      ("serializedClass" -> serializedPyClass) ~
+      ("sqlType" -> sqlType.jsonValue)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e044705b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index ec084a2..3c38916 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -267,7 +267,6 @@ object EvaluatePython {
           pickler.save(row.values(i))
           i += 1
         }
-        row.values.foreach(pickler.save)
         out.write(Opcodes.TUPLE)
         out.write(Opcodes.REDUCE)
       }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org