Posted to commits@spark.apache.org by da...@apache.org on 2015/05/29 23:13:39 UTC

[1/2] spark git commit: [SPARK-7899] [PYSPARK] Fix Python 3 pyspark/sql/types module conflict

Repository: spark
Updated Branches:
  refs/heads/master 5f48e5c33 -> 1c5b19827


http://git-wip-us.apache.org/repos/asf/spark/blob/1c5b1982/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
new file mode 100644
index 0000000..9e7e9f0
--- /dev/null
+++ b/python/pyspark/sql/types.py
@@ -0,0 +1,1306 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+import decimal
+import time
+import datetime
+import keyword
+import warnings
+import json
+import re
+import weakref
+from array import array
+from operator import itemgetter
+
+if sys.version >= "3":
+    long = int
+    unicode = str
+
+from py4j.protocol import register_input_converter
+from py4j.java_gateway import JavaClass
+
+__all__ = [
+    "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
+    "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
+    "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"]
+
+
+class DataType(object):
+    """Base class for data types."""
+
+    def __repr__(self):
+        return self.__class__.__name__
+
+    def __hash__(self):
+        return hash(str(self))
+
+    def __eq__(self, other):
+        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    @classmethod
+    def typeName(cls):
+        return cls.__name__[:-4].lower()
+
+    def simpleString(self):
+        return self.typeName()
+
+    def jsonValue(self):
+        return self.typeName()
+
+    def json(self):
+        return json.dumps(self.jsonValue(),
+                          separators=(',', ':'),
+                          sort_keys=True)
+
+
+# This singleton pattern does not work with pickle; you will get
+# another object back after pickling and unpickling.
+class DataTypeSingleton(type):
+    """Metaclass for DataType"""
+
+    _instances = {}
+
+    def __call__(cls):
+        if cls not in cls._instances:
+            cls._instances[cls] = super(DataTypeSingleton, cls).__call__()
+        return cls._instances[cls]
+
+
+class NullType(DataType):
+    """Null type.
+
+    The data type representing None, used for the types that cannot be inferred.
+    """
+
+    __metaclass__ = DataTypeSingleton
+
+
+class AtomicType(DataType):
+    """An internal type used to represent everything that is not
+    null, UDTs, arrays, structs, and maps."""
+
+    __metaclass__ = DataTypeSingleton
+
+
+class NumericType(AtomicType):
+    """Numeric data types.
+    """
+
+
+class IntegralType(NumericType):
+    """Integral data types.
+    """
+
+
+class FractionalType(NumericType):
+    """Fractional data types.
+    """
+
+
+class StringType(AtomicType):
+    """String data type.
+    """
+
+
+class BinaryType(AtomicType):
+    """Binary (byte array) data type.
+    """
+
+
+class BooleanType(AtomicType):
+    """Boolean data type.
+    """
+
+
+class DateType(AtomicType):
+    """Date (datetime.date) data type.
+    """
+
+
+class TimestampType(AtomicType):
+    """Timestamp (datetime.datetime) data type.
+    """
+
+
+class DecimalType(FractionalType):
+    """Decimal (decimal.Decimal) data type.
+    """
+
+    def __init__(self, precision=None, scale=None):
+        self.precision = precision
+        self.scale = scale
+        self.hasPrecisionInfo = precision is not None
+
+    def simpleString(self):
+        if self.hasPrecisionInfo:
+            return "decimal(%d,%d)" % (self.precision, self.scale)
+        else:
+            return "decimal(10,0)"
+
+    def jsonValue(self):
+        if self.hasPrecisionInfo:
+            return "decimal(%d,%d)" % (self.precision, self.scale)
+        else:
+            return "decimal"
+
+    def __repr__(self):
+        if self.hasPrecisionInfo:
+            return "DecimalType(%d,%d)" % (self.precision, self.scale)
+        else:
+            return "DecimalType()"
+
+
+class DoubleType(FractionalType):
+    """Double data type, representing double precision floats.
+    """
+
+
+class FloatType(FractionalType):
+    """Float data type, representing single precision floats.
+    """
+
+
+class ByteType(IntegralType):
+    """Byte data type, i.e. a signed integer in a single byte.
+    """
+    def simpleString(self):
+        return 'tinyint'
+
+
+class IntegerType(IntegralType):
+    """Int data type, i.e. a signed 32-bit integer.
+    """
+    def simpleString(self):
+        return 'int'
+
+
+class LongType(IntegralType):
+    """Long data type, i.e. a signed 64-bit integer.
+
+    If the values are beyond the range of [-9223372036854775808, 9223372036854775807],
+    please use :class:`DecimalType`.
+    """
+    def simpleString(self):
+        return 'bigint'
+
+
+class ShortType(IntegralType):
+    """Short data type, i.e. a signed 16-bit integer.
+    """
+    def simpleString(self):
+        return 'smallint'
+
+
+class ArrayType(DataType):
+    """Array data type.
+
+    :param elementType: :class:`DataType` of each element in the array.
+    :param containsNull: boolean, whether the array can contain null (None) values.
+    """
+
+    def __init__(self, elementType, containsNull=True):
+        """
+        >>> ArrayType(StringType()) == ArrayType(StringType(), True)
+        True
+        >>> ArrayType(StringType(), False) == ArrayType(StringType())
+        False
+        """
+        assert isinstance(elementType, DataType), "elementType should be DataType"
+        self.elementType = elementType
+        self.containsNull = containsNull
+
+    def simpleString(self):
+        return 'array<%s>' % self.elementType.simpleString()
+
+    def __repr__(self):
+        return "ArrayType(%s,%s)" % (self.elementType,
+                                     str(self.containsNull).lower())
+
+    def jsonValue(self):
+        return {"type": self.typeName(),
+                "elementType": self.elementType.jsonValue(),
+                "containsNull": self.containsNull}
+
+    @classmethod
+    def fromJson(cls, json):
+        return ArrayType(_parse_datatype_json_value(json["elementType"]),
+                         json["containsNull"])
+
+
+class MapType(DataType):
+    """Map data type.
+
+    :param keyType: :class:`DataType` of the keys in the map.
+    :param valueType: :class:`DataType` of the values in the map.
+    :param valueContainsNull: indicates whether values can contain null (None) values.
+
+    Keys in a map data type are not allowed to be null (None).
+    """
+
+    def __init__(self, keyType, valueType, valueContainsNull=True):
+        """
+        >>> (MapType(StringType(), IntegerType())
+        ...        == MapType(StringType(), IntegerType(), True))
+        True
+        >>> (MapType(StringType(), IntegerType(), False)
+        ...        == MapType(StringType(), FloatType()))
+        False
+        """
+        assert isinstance(keyType, DataType), "keyType should be DataType"
+        assert isinstance(valueType, DataType), "valueType should be DataType"
+        self.keyType = keyType
+        self.valueType = valueType
+        self.valueContainsNull = valueContainsNull
+
+    def simpleString(self):
+        return 'map<%s,%s>' % (self.keyType.simpleString(), self.valueType.simpleString())
+
+    def __repr__(self):
+        return "MapType(%s,%s,%s)" % (self.keyType, self.valueType,
+                                      str(self.valueContainsNull).lower())
+
+    def jsonValue(self):
+        return {"type": self.typeName(),
+                "keyType": self.keyType.jsonValue(),
+                "valueType": self.valueType.jsonValue(),
+                "valueContainsNull": self.valueContainsNull}
+
+    @classmethod
+    def fromJson(cls, json):
+        return MapType(_parse_datatype_json_value(json["keyType"]),
+                       _parse_datatype_json_value(json["valueType"]),
+                       json["valueContainsNull"])
+
+
+class StructField(DataType):
+    """A field in :class:`StructType`.
+
+    :param name: string, name of the field.
+    :param dataType: :class:`DataType` of the field.
+    :param nullable: boolean, whether the field can be null (None) or not.
+    :param metadata: a dict from string to simple type that can be serialized to JSON automatically
+    """
+
+    def __init__(self, name, dataType, nullable=True, metadata=None):
+        """
+        >>> (StructField("f1", StringType(), True)
+        ...      == StructField("f1", StringType(), True))
+        True
+        >>> (StructField("f1", StringType(), True)
+        ...      == StructField("f2", StringType(), True))
+        False
+        """
+        assert isinstance(dataType, DataType), "dataType should be DataType"
+        self.name = name
+        self.dataType = dataType
+        self.nullable = nullable
+        self.metadata = metadata or {}
+
+    def simpleString(self):
+        return '%s:%s' % (self.name, self.dataType.simpleString())
+
+    def __repr__(self):
+        return "StructField(%s,%s,%s)" % (self.name, self.dataType,
+                                          str(self.nullable).lower())
+
+    def jsonValue(self):
+        return {"name": self.name,
+                "type": self.dataType.jsonValue(),
+                "nullable": self.nullable,
+                "metadata": self.metadata}
+
+    @classmethod
+    def fromJson(cls, json):
+        return StructField(json["name"],
+                           _parse_datatype_json_value(json["type"]),
+                           json["nullable"],
+                           json["metadata"])
+
+
+class StructType(DataType):
+    """Struct type, consisting of a list of :class:`StructField`.
+
+    This is the data type representing a :class:`Row`.
+    """
+
+    def __init__(self, fields):
+        """
+        >>> struct1 = StructType([StructField("f1", StringType(), True)])
+        >>> struct2 = StructType([StructField("f1", StringType(), True)])
+        >>> struct1 == struct2
+        True
+        >>> struct1 = StructType([StructField("f1", StringType(), True)])
+        >>> struct2 = StructType([StructField("f1", StringType(), True),
+        ...     StructField("f2", IntegerType(), False)])
+        >>> struct1 == struct2
+        False
+        """
+        assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType"
+        self.fields = fields
+
+    def simpleString(self):
+        return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields))
+
+    def __repr__(self):
+        return ("StructType(List(%s))" %
+                ",".join(str(field) for field in self.fields))
+
+    def jsonValue(self):
+        return {"type": self.typeName(),
+                "fields": [f.jsonValue() for f in self.fields]}
+
+    @classmethod
+    def fromJson(cls, json):
+        return StructType([StructField.fromJson(f) for f in json["fields"]])
+
+
+class UserDefinedType(DataType):
+    """User-defined type (UDT).
+
+    .. note:: WARN: Spark Internal Use Only
+    """
+
+    @classmethod
+    def typeName(cls):
+        return cls.__name__.lower()
+
+    @classmethod
+    def sqlType(cls):
+        """
+        Underlying SQL storage type for this UDT.
+        """
+        raise NotImplementedError("UDT must implement sqlType().")
+
+    @classmethod
+    def module(cls):
+        """
+        The Python module of the UDT.
+        """
+        raise NotImplementedError("UDT must implement module().")
+
+    @classmethod
+    def scalaUDT(cls):
+        """
+        The class name of the paired Scala UDT.
+        """
+        raise NotImplementedError("UDT must have a paired Scala UDT.")
+
+    def serialize(self, obj):
+        """
+        Converts a user-type object into a SQL datum.
+        """
+        raise NotImplementedError("UDT must implement serialize().")
+
+    def deserialize(self, datum):
+        """
+        Converts a SQL datum into a user-type object.
+        """
+        raise NotImplementedError("UDT must implement deserialize().")
+
+    def simpleString(self):
+        return 'udt'
+
+    def json(self):
+        return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)
+
+    def jsonValue(self):
+        schema = {
+            "type": "udt",
+            "class": self.scalaUDT(),
+            "pyClass": "%s.%s" % (self.module(), type(self).__name__),
+            "sqlType": self.sqlType().jsonValue()
+        }
+        return schema
+
+    @classmethod
+    def fromJson(cls, json):
+        pyUDT = json["pyClass"]
+        split = pyUDT.rfind(".")
+        pyModule = pyUDT[:split]
+        pyClass = pyUDT[split+1:]
+        m = __import__(pyModule, globals(), locals(), [pyClass])
+        UDT = getattr(m, pyClass)
+        return UDT()
+
+    def __eq__(self, other):
+        return type(self) == type(other)
+
+
+_atomic_types = [StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType,
+                 ByteType, ShortType, IntegerType, LongType, DateType, TimestampType]
+_all_atomic_types = dict((t.typeName(), t) for t in _atomic_types)
+_all_complex_types = dict((v.typeName(), v)
+                          for v in [ArrayType, MapType, StructType])
+
+
+def _parse_datatype_json_string(json_string):
+    """Parses the given data type JSON string.
+    >>> import pickle
+    >>> def check_datatype(datatype):
+    ...     pickled = pickle.loads(pickle.dumps(datatype))
+    ...     assert datatype == pickled
+    ...     scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json())
+    ...     python_datatype = _parse_datatype_json_string(scala_datatype.json())
+    ...     assert datatype == python_datatype
+    >>> for cls in _all_atomic_types.values():
+    ...     check_datatype(cls())
+
+    >>> # Simple ArrayType.
+    >>> simple_arraytype = ArrayType(StringType(), True)
+    >>> check_datatype(simple_arraytype)
+
+    >>> # Simple MapType.
+    >>> simple_maptype = MapType(StringType(), LongType())
+    >>> check_datatype(simple_maptype)
+
+    >>> # Simple StructType.
+    >>> simple_structtype = StructType([
+    ...     StructField("a", DecimalType(), False),
+    ...     StructField("b", BooleanType(), True),
+    ...     StructField("c", LongType(), True),
+    ...     StructField("d", BinaryType(), False)])
+    >>> check_datatype(simple_structtype)
+
+    >>> # Complex StructType.
+    >>> complex_structtype = StructType([
+    ...     StructField("simpleArray", simple_arraytype, True),
+    ...     StructField("simpleMap", simple_maptype, True),
+    ...     StructField("simpleStruct", simple_structtype, True),
+    ...     StructField("boolean", BooleanType(), False),
+    ...     StructField("withMeta", DoubleType(), False, {"name": "age"})])
+    >>> check_datatype(complex_structtype)
+
+    >>> # Complex ArrayType.
+    >>> complex_arraytype = ArrayType(complex_structtype, True)
+    >>> check_datatype(complex_arraytype)
+
+    >>> # Complex MapType.
+    >>> complex_maptype = MapType(complex_structtype,
+    ...                           complex_arraytype, False)
+    >>> check_datatype(complex_maptype)
+
+    >>> check_datatype(ExamplePointUDT())
+    >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
+    ...                                   StructField("point", ExamplePointUDT(), False)])
+    >>> check_datatype(structtype_with_udt)
+    """
+    return _parse_datatype_json_value(json.loads(json_string))
+
+
+_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)")
+
+
+def _parse_datatype_json_value(json_value):
+    if not isinstance(json_value, dict):
+        if json_value in _all_atomic_types.keys():
+            return _all_atomic_types[json_value]()
+        elif json_value == 'decimal':
+            return DecimalType()
+        elif _FIXED_DECIMAL.match(json_value):
+            m = _FIXED_DECIMAL.match(json_value)
+            return DecimalType(int(m.group(1)), int(m.group(2)))
+        else:
+            raise ValueError("Could not parse datatype: %s" % json_value)
+    else:
+        tpe = json_value["type"]
+        if tpe in _all_complex_types:
+            return _all_complex_types[tpe].fromJson(json_value)
+        elif tpe == 'udt':
+            return UserDefinedType.fromJson(json_value)
+        else:
+            raise ValueError("not supported type: %s" % tpe)
+
+
+# Mapping Python types to Spark SQL DataType
+_type_mappings = {
+    type(None): NullType,
+    bool: BooleanType,
+    int: LongType,
+    float: DoubleType,
+    str: StringType,
+    bytearray: BinaryType,
+    decimal.Decimal: DecimalType,
+    datetime.date: DateType,
+    datetime.datetime: TimestampType,
+    datetime.time: TimestampType,
+}
+
+if sys.version < "3":
+    _type_mappings.update({
+        unicode: StringType,
+        long: LongType,
+    })
+
+
+def _infer_type(obj):
+    """Infer the DataType from obj
+
+    >>> p = ExamplePoint(1.0, 2.0)
+    >>> _infer_type(p)
+    ExamplePointUDT
+    """
+    if obj is None:
+        return NullType()
+
+    if hasattr(obj, '__UDT__'):
+        return obj.__UDT__
+
+    dataType = _type_mappings.get(type(obj))
+    if dataType is not None:
+        return dataType()
+
+    if isinstance(obj, dict):
+        for key, value in obj.items():
+            if key is not None and value is not None:
+                return MapType(_infer_type(key), _infer_type(value), True)
+        else:
+            return MapType(NullType(), NullType(), True)
+    elif isinstance(obj, (list, array)):
+        for v in obj:
+            if v is not None:
+                return ArrayType(_infer_type(obj[0]), True)
+        else:
+            return ArrayType(NullType(), True)
+    else:
+        try:
+            return _infer_schema(obj)
+        except TypeError:
+            raise TypeError("not supported type: %s" % type(obj))
+
+
+def _infer_schema(row):
+    """Infer the schema from dict/namedtuple/object"""
+    if isinstance(row, dict):
+        items = sorted(row.items())
+
+    elif isinstance(row, (tuple, list)):
+        if hasattr(row, "__fields__"):  # Row
+            items = zip(row.__fields__, tuple(row))
+        elif hasattr(row, "_fields"):  # namedtuple
+            items = zip(row._fields, tuple(row))
+        else:
+            names = ['_%d' % i for i in range(1, len(row) + 1)]
+            items = zip(names, row)
+
+    elif hasattr(row, "__dict__"):  # object
+        items = sorted(row.__dict__.items())
+
+    else:
+        raise TypeError("Can not infer schema for type: %s" % type(row))
+
+    fields = [StructField(k, _infer_type(v), True) for k, v in items]
+    return StructType(fields)
+
+
+def _need_python_to_sql_conversion(dataType):
+    """
+    Checks whether we need python to sql conversion for the given type.
+    For now, only UDTs need this conversion.
+
+    >>> _need_python_to_sql_conversion(DoubleType())
+    False
+    >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
+    ...                       StructField("values", ArrayType(DoubleType(), False), False)])
+    >>> _need_python_to_sql_conversion(schema0)
+    False
+    >>> _need_python_to_sql_conversion(ExamplePointUDT())
+    True
+    >>> schema1 = ArrayType(ExamplePointUDT(), False)
+    >>> _need_python_to_sql_conversion(schema1)
+    True
+    >>> schema2 = StructType([StructField("label", DoubleType(), False),
+    ...                       StructField("point", ExamplePointUDT(), False)])
+    >>> _need_python_to_sql_conversion(schema2)
+    True
+    """
+    if isinstance(dataType, StructType):
+        return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields])
+    elif isinstance(dataType, ArrayType):
+        return _need_python_to_sql_conversion(dataType.elementType)
+    elif isinstance(dataType, MapType):
+        return _need_python_to_sql_conversion(dataType.keyType) or \
+            _need_python_to_sql_conversion(dataType.valueType)
+    elif isinstance(dataType, UserDefinedType):
+        return True
+    else:
+        return False
+
+
+def _python_to_sql_converter(dataType):
+    """
+    Returns a converter that converts a Python object into a SQL datum for the given type.
+
+    >>> conv = _python_to_sql_converter(DoubleType())
+    >>> conv(1.0)
+    1.0
+    >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False))
+    >>> conv([1.0, 2.0])
+    [1.0, 2.0]
+    >>> conv = _python_to_sql_converter(ExamplePointUDT())
+    >>> conv(ExamplePoint(1.0, 2.0))
+    [1.0, 2.0]
+    >>> schema = StructType([StructField("label", DoubleType(), False),
+    ...                      StructField("point", ExamplePointUDT(), False)])
+    >>> conv = _python_to_sql_converter(schema)
+    >>> conv((1.0, ExamplePoint(1.0, 2.0)))
+    (1.0, [1.0, 2.0])
+    """
+    if not _need_python_to_sql_conversion(dataType):
+        return lambda x: x
+
+    if isinstance(dataType, StructType):
+        names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
+        converters = [_python_to_sql_converter(t) for t in types]
+
+        def converter(obj):
+            if isinstance(obj, dict):
+                return tuple(c(obj.get(n)) for n, c in zip(names, converters))
+            elif isinstance(obj, tuple):
+                if hasattr(obj, "__fields__") or hasattr(obj, "_fields"):
+                    return tuple(c(v) for c, v in zip(converters, obj))
+                elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):  # k-v pairs
+                    d = dict(obj)
+                    return tuple(c(d.get(n)) for n, c in zip(names, converters))
+                else:
+                    return tuple(c(v) for c, v in zip(converters, obj))
+            else:
+                raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
+        return converter
+    elif isinstance(dataType, ArrayType):
+        element_converter = _python_to_sql_converter(dataType.elementType)
+        return lambda a: [element_converter(v) for v in a]
+    elif isinstance(dataType, MapType):
+        key_converter = _python_to_sql_converter(dataType.keyType)
+        value_converter = _python_to_sql_converter(dataType.valueType)
+        return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
+    elif isinstance(dataType, UserDefinedType):
+        return lambda obj: dataType.serialize(obj)
+    else:
+        raise ValueError("Unexpected type %r" % dataType)
+
+
+def _has_nulltype(dt):
+    """ Return whether there is NullType in `dt` or not """
+    if isinstance(dt, StructType):
+        return any(_has_nulltype(f.dataType) for f in dt.fields)
+    elif isinstance(dt, ArrayType):
+        return _has_nulltype((dt.elementType))
+    elif isinstance(dt, MapType):
+        return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType)
+    else:
+        return isinstance(dt, NullType)
+
+
+def _merge_type(a, b):
+    if isinstance(a, NullType):
+        return b
+    elif isinstance(b, NullType):
+        return a
+    elif type(a) is not type(b):
+        # TODO: type cast (such as int -> long)
+        raise TypeError("Can not merge type %s and %s" % (type(a), type(b)))
+
+    # same type
+    if isinstance(a, StructType):
+        nfs = dict((f.name, f.dataType) for f in b.fields)
+        fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType())))
+                  for f in a.fields]
+        names = set([f.name for f in fields])
+        for n in nfs:
+            if n not in names:
+                fields.append(StructField(n, nfs[n]))
+        return StructType(fields)
+
+    elif isinstance(a, ArrayType):
+        return ArrayType(_merge_type(a.elementType, b.elementType), True)
+
+    elif isinstance(a, MapType):
+        return MapType(_merge_type(a.keyType, b.keyType),
+                       _merge_type(a.valueType, b.valueType),
+                       True)
+    else:
+        return a
+
+
+def _need_converter(dataType):
+    if isinstance(dataType, StructType):
+        return True
+    elif isinstance(dataType, ArrayType):
+        return _need_converter(dataType.elementType)
+    elif isinstance(dataType, MapType):
+        return _need_converter(dataType.keyType) or _need_converter(dataType.valueType)
+    elif isinstance(dataType, NullType):
+        return True
+    else:
+        return False
+
+
+def _create_converter(dataType):
+    """Create an converter to drop the names of fields in obj """
+    if not _need_converter(dataType):
+        return lambda x: x
+
+    if isinstance(dataType, ArrayType):
+        conv = _create_converter(dataType.elementType)
+        return lambda row: [conv(v) for v in row]
+
+    elif isinstance(dataType, MapType):
+        kconv = _create_converter(dataType.keyType)
+        vconv = _create_converter(dataType.valueType)
+        return lambda row: dict((kconv(k), vconv(v)) for k, v in row.items())
+
+    elif isinstance(dataType, NullType):
+        return lambda x: None
+
+    elif not isinstance(dataType, StructType):
+        return lambda x: x
+
+    # dataType must be StructType
+    names = [f.name for f in dataType.fields]
+    converters = [_create_converter(f.dataType) for f in dataType.fields]
+    convert_fields = any(_need_converter(f.dataType) for f in dataType.fields)
+
+    def convert_struct(obj):
+        if obj is None:
+            return
+
+        if isinstance(obj, (tuple, list)):
+            if convert_fields:
+                return tuple(conv(v) for v, conv in zip(obj, converters))
+            else:
+                return tuple(obj)
+
+        if isinstance(obj, dict):
+            d = obj
+        elif hasattr(obj, "__dict__"):  # object
+            d = obj.__dict__
+        else:
+            raise TypeError("Unexpected obj type: %s" % type(obj))
+
+        if convert_fields:
+            return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
+        else:
+            return tuple([d.get(name) for name in names])
+
+    return convert_struct
+
+
+_BRACKETS = {'(': ')', '[': ']', '{': '}'}
+
+
+def _split_schema_abstract(s):
+    """
+    split the schema abstract into fields
+
+    >>> _split_schema_abstract("a b  c")
+    ['a', 'b', 'c']
+    >>> _split_schema_abstract("a(a b)")
+    ['a(a b)']
+    >>> _split_schema_abstract("a b[] c{a b}")
+    ['a', 'b[]', 'c{a b}']
+    >>> _split_schema_abstract(" ")
+    []
+    """
+
+    r = []
+    w = ''
+    brackets = []
+    for c in s:
+        if c == ' ' and not brackets:
+            if w:
+                r.append(w)
+            w = ''
+        else:
+            w += c
+            if c in _BRACKETS:
+                brackets.append(c)
+            elif c in _BRACKETS.values():
+                if not brackets or c != _BRACKETS[brackets.pop()]:
+                    raise ValueError("unexpected " + c)
+
+    if brackets:
+        raise ValueError("brackets not closed: %s" % brackets)
+    if w:
+        r.append(w)
+    return r
+
+
+def _parse_field_abstract(s):
+    """
+    Parse a field in schema abstract
+
+    >>> _parse_field_abstract("a")
+    StructField(a,NullType,true)
+    >>> _parse_field_abstract("b(c d)")
+    StructField(b,StructType(...c,NullType,true),StructField(d...
+    >>> _parse_field_abstract("a[]")
+    StructField(a,ArrayType(NullType,true),true)
+    >>> _parse_field_abstract("a{[]}")
+    StructField(a,MapType(NullType,ArrayType(NullType,true),true),true)
+    """
+    if set(_BRACKETS.keys()) & set(s):
+        idx = min((s.index(c) for c in _BRACKETS if c in s))
+        name = s[:idx]
+        return StructField(name, _parse_schema_abstract(s[idx:]), True)
+    else:
+        return StructField(s, NullType(), True)
+
+
+def _parse_schema_abstract(s):
+    """
+    parse abstract into schema
+
+    >>> _parse_schema_abstract("a b  c")
+    StructType...a...b...c...
+    >>> _parse_schema_abstract("a[b c] b{}")
+    StructType...a,ArrayType...b...c...b,MapType...
+    >>> _parse_schema_abstract("c{} d{a b}")
+    StructType...c,MapType...d,MapType...a...b...
+    >>> _parse_schema_abstract("a b(t)").fields[1]
+    StructField(b,StructType(List(StructField(t,NullType,true))),true)
+    """
+    s = s.strip()
+    if not s:
+        return NullType()
+
+    elif s.startswith('('):
+        return _parse_schema_abstract(s[1:-1])
+
+    elif s.startswith('['):
+        return ArrayType(_parse_schema_abstract(s[1:-1]), True)
+
+    elif s.startswith('{'):
+        return MapType(NullType(), _parse_schema_abstract(s[1:-1]))
+
+    parts = _split_schema_abstract(s)
+    fields = [_parse_field_abstract(p) for p in parts]
+    return StructType(fields)
+
+
+def _infer_schema_type(obj, dataType):
+    """
+    Fill the dataType with types inferred from obj
+
+    >>> schema = _parse_schema_abstract("a b c d")
+    >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
+    >>> _infer_schema_type(row, schema)
+    StructType...LongType...DoubleType...StringType...DateType...
+    >>> row = [[1], {"key": (1, 2.0)}]
+    >>> schema = _parse_schema_abstract("a[] b{c d}")
+    >>> _infer_schema_type(row, schema)
+    StructType...a,ArrayType...b,MapType(StringType,...c,LongType...
+    """
+    if isinstance(dataType, NullType):
+        return _infer_type(obj)
+
+    if not obj:
+        return NullType()
+
+    if isinstance(dataType, ArrayType):
+        eType = _infer_schema_type(obj[0], dataType.elementType)
+        return ArrayType(eType, True)
+
+    elif isinstance(dataType, MapType):
+        k, v = next(iter(obj.items()))
+        return MapType(_infer_schema_type(k, dataType.keyType),
+                       _infer_schema_type(v, dataType.valueType))
+
+    elif isinstance(dataType, StructType):
+        fs = dataType.fields
+        assert len(fs) == len(obj), \
+            "Obj(%s) have different length with fields(%s)" % (obj, fs)
+        fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True)
+                  for o, f in zip(obj, fs)]
+        return StructType(fields)
+
+    else:
+        raise TypeError("Unexpected dataType: %s" % type(dataType))
+
+
+_acceptable_types = {
+    BooleanType: (bool,),
+    ByteType: (int, long),
+    ShortType: (int, long),
+    IntegerType: (int, long),
+    LongType: (int, long),
+    FloatType: (float,),
+    DoubleType: (float,),
+    DecimalType: (decimal.Decimal,),
+    StringType: (str, unicode),
+    BinaryType: (bytearray,),
+    DateType: (datetime.date, datetime.datetime),
+    TimestampType: (datetime.datetime,),
+    ArrayType: (list, tuple, array),
+    MapType: (dict,),
+    StructType: (tuple, list),
+}
+
+
+def _verify_type(obj, dataType):
+    """
+    Verify the type of obj against dataType, raise an exception if
+    they do not match.
+
+    >>> _verify_type(None, StructType([]))
+    >>> _verify_type("", StringType())
+    >>> _verify_type(0, LongType())
+    >>> _verify_type(list(range(3)), ArrayType(ShortType()))
+    >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+        ...
+    TypeError:...
+    >>> _verify_type({}, MapType(StringType(), IntegerType()))
+    >>> _verify_type((), StructType([]))
+    >>> _verify_type([], StructType([]))
+    >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+        ...
+    ValueError:...
+    >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
+    >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+        ...
+    ValueError:...
+    """
+    # all objects are nullable
+    if obj is None:
+        return
+
+    if isinstance(dataType, UserDefinedType):
+        if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
+            raise ValueError("%r is not an instance of type %r" % (obj, dataType))
+        _verify_type(dataType.serialize(obj), dataType.sqlType())
+        return
+
+    _type = type(dataType)
+    assert _type in _acceptable_types, "unknown datatype: %s" % dataType
+
+    # subclasses of these types cannot be deserialized in the JVM
+    if type(obj) not in _acceptable_types[_type]:
+        raise TypeError("%s can not accept object in type %s"
+                        % (dataType, type(obj)))
+
+    if isinstance(dataType, ArrayType):
+        for i in obj:
+            _verify_type(i, dataType.elementType)
+
+    elif isinstance(dataType, MapType):
+        for k, v in obj.items():
+            _verify_type(k, dataType.keyType)
+            _verify_type(v, dataType.valueType)
+
+    elif isinstance(dataType, StructType):
+        if len(obj) != len(dataType.fields):
+            raise ValueError("Length of object (%d) does not match with "
+                             "length of fields (%d)" % (len(obj), len(dataType.fields)))
+        for v, f in zip(obj, dataType.fields):
+            _verify_type(v, f.dataType)
+
+_cached_cls = weakref.WeakValueDictionary()
+
+
+def _restore_object(dataType, obj):
+    """ Restore object during unpickling. """
+    # use id(dataType) as key to speed up lookup in dict
+    # Because of batched pickling, dataType will be the
+    # same object in most cases.
+    k = id(dataType)
+    cls = _cached_cls.get(k)
+    if cls is None or cls.__datatype is not dataType:
+        # use dataType as the key to avoid creating multiple classes
+        cls = _cached_cls.get(dataType)
+        if cls is None:
+            cls = _create_cls(dataType)
+            _cached_cls[dataType] = cls
+        cls.__datatype = dataType
+        _cached_cls[k] = cls
+    return cls(obj)
+
+
+def _create_object(cls, v):
+    """ Create an customized object with class `cls`. """
+    # datetime.date would be deserialized as datetime.datetime
+    # from java type, so we need to set it back.
+    if cls is datetime.date and isinstance(v, datetime.datetime):
+        return v.date()
+    return cls(v) if v is not None else v
+
+
+def _create_getter(dt, i):
+    """ Create a getter for item `i` with schema """
+    cls = _create_cls(dt)
+
+    def getter(self):
+        return _create_object(cls, self[i])
+
+    return getter
+
+
+def _has_struct_or_date(dt):
+    """Return whether `dt` is or has StructType/DateType in it"""
+    if isinstance(dt, StructType):
+        return True
+    elif isinstance(dt, ArrayType):
+        return _has_struct_or_date(dt.elementType)
+    elif isinstance(dt, MapType):
+        return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType)
+    elif isinstance(dt, DateType):
+        return True
+    elif isinstance(dt, UserDefinedType):
+        return True
+    return False
+
+
+def _create_properties(fields):
+    """Create properties according to fields"""
+    ps = {}
+    for i, f in enumerate(fields):
+        name = f.name
+        if (name.startswith("__") and name.endswith("__")
+                or keyword.iskeyword(name)):
+            warnings.warn("field name %s can not be accessed in Python,"
+                          "use position to access it instead" % name)
+        if _has_struct_or_date(f.dataType):
+            # delay creating object until accessing it
+            getter = _create_getter(f.dataType, i)
+        else:
+            getter = itemgetter(i)
+        ps[name] = property(getter)
+    return ps
+
+
+def _create_cls(dataType):
+    """
+    Create a class from the given dataType
+
+    The created class is similar to namedtuple, but can have nested schema.
+
+    >>> schema = _parse_schema_abstract("a b c")
+    >>> row = (1, 1.0, "str")
+    >>> schema = _infer_schema_type(row, schema)
+    >>> obj = _create_cls(schema)(row)
+    >>> import pickle
+    >>> pickle.loads(pickle.dumps(obj))
+    Row(a=1, b=1.0, c='str')
+
+    >>> row = [[1], {"key": (1, 2.0)}]
+    >>> schema = _parse_schema_abstract("a[] b{c d}")
+    >>> schema = _infer_schema_type(row, schema)
+    >>> obj = _create_cls(schema)(row)
+    >>> pickle.loads(pickle.dumps(obj))
+    Row(a=[1], b={'key': Row(c=1, d=2.0)})
+    >>> pickle.loads(pickle.dumps(obj.a))
+    [1]
+    >>> pickle.loads(pickle.dumps(obj.b))
+    {'key': Row(c=1, d=2.0)}
+    """
+
+    if isinstance(dataType, ArrayType):
+        cls = _create_cls(dataType.elementType)
+
+        def List(l):
+            if l is None:
+                return
+            return [_create_object(cls, v) for v in l]
+
+        return List
+
+    elif isinstance(dataType, MapType):
+        kcls = _create_cls(dataType.keyType)
+        vcls = _create_cls(dataType.valueType)
+
+        def Dict(d):
+            if d is None:
+                return
+            return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items())
+
+        return Dict
+
+    elif isinstance(dataType, DateType):
+        return datetime.date
+
+    elif isinstance(dataType, UserDefinedType):
+        return lambda datum: dataType.deserialize(datum)
+
+    elif not isinstance(dataType, StructType):
+        # no wrapper for atomic types
+        return lambda x: x
+
+    class Row(tuple):
+
+        """ Row in DataFrame """
+        __datatype = dataType
+        __fields__ = tuple(f.name for f in dataType.fields)
+        __slots__ = ()
+
+        # create property for fast access
+        locals().update(_create_properties(dataType.fields))
+
+        def asDict(self):
+            """ Return as a dict """
+            return dict((n, getattr(self, n)) for n in self.__fields__)
+
+        def __repr__(self):
+            # call collect __repr__ for nested objects
+            return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n))
+                                          for n in self.__fields__))
+
+        def __reduce__(self):
+            return (_restore_object, (self.__datatype, tuple(self)))
+
+    return Row
+
+
+def _create_row(fields, values):
+    row = Row(*values)
+    row.__fields__ = fields
+    return row
+
+
+class Row(tuple):
+
+    """
+    A row in L{DataFrame}. The fields in it can be accessed like attributes.
+
+    Row can be used to create a row object by using named arguments;
+    the fields will be sorted by name.
+
+    >>> row = Row(name="Alice", age=11)
+    >>> row
+    Row(age=11, name='Alice')
+    >>> row.name, row.age
+    ('Alice', 11)
+
+    Row can also be used to create another Row-like class, which can
+    then be used to create Row objects, such as
+
+    >>> Person = Row("name", "age")
+    >>> Person
+    <Row(name, age)>
+    >>> Person("Alice", 11)
+    Row(name='Alice', age=11)
+    """
+
+    def __new__(self, *args, **kwargs):
+        if args and kwargs:
+            raise ValueError("Can not use both args "
+                             "and kwargs to create Row")
+        if args:
+            # create row class or objects
+            return tuple.__new__(self, args)
+
+        elif kwargs:
+            # create row objects
+            names = sorted(kwargs.keys())
+            row = tuple.__new__(self, [kwargs[n] for n in names])
+            row.__fields__ = names
+            return row
+
+        else:
+            raise ValueError("No args or kwargs")
+
+    def asDict(self):
+        """
+        Return as a dict
+        """
+        if not hasattr(self, "__fields__"):
+            raise TypeError("Cannot convert a Row class into dict")
+        return dict(zip(self.__fields__, self))
+
+    # let the object act like a class
+    def __call__(self, *args):
+        """create new Row object"""
+        return _create_row(self, args)
+
+    def __getattr__(self, item):
+        if item.startswith("__"):
+            raise AttributeError(item)
+        try:
+            # it will be slow when it has many fields,
+            # but this will not be used in normal cases
+            idx = self.__fields__.index(item)
+            return self[idx]
+        except IndexError:
+            raise AttributeError(item)
+        except ValueError:
+            raise AttributeError(item)
+
+    def __reduce__(self):
+        """Returns a tuple so Python knows how to pickle Row."""
+        if hasattr(self, "__fields__"):
+            return (_create_row, (self.__fields__, tuple(self)))
+        else:
+            return tuple.__reduce__(self)
+
+    def __repr__(self):
+        """Printable representation of Row used in Python REPL."""
+        if hasattr(self, "__fields__"):
+            return "Row(%s)" % ", ".join("%s=%r" % (k, v)
+                                         for k, v in zip(self.__fields__, tuple(self)))
+        else:
+            return "<Row(%s)>" % ", ".join(self)
+
+
+class DateConverter(object):
+    def can_convert(self, obj):
+        return isinstance(obj, datetime.date)
+
+    def convert(self, obj, gateway_client):
+        Date = JavaClass("java.sql.Date", gateway_client)
+        return Date.valueOf(obj.strftime("%Y-%m-%d"))
+
+
+class DatetimeConverter(object):
+    def can_convert(self, obj):
+        return isinstance(obj, datetime.datetime)
+
+    def convert(self, obj, gateway_client):
+        Timestamp = JavaClass("java.sql.Timestamp", gateway_client)
+        return Timestamp(int(time.mktime(obj.timetuple())) * 1000 + obj.microsecond // 1000)
+
+
+# datetime is a subclass of date, so we should register DatetimeConverter first
+register_input_converter(DatetimeConverter())
+register_input_converter(DateConverter())
+
+
+def _test():
+    import doctest
+    from pyspark.context import SparkContext
+    # let doctest run in pyspark.sql.types, so DataTypes can be picklable
+    import pyspark.sql.types
+    from pyspark.sql import Row, SQLContext
+    from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
+    globs = pyspark.sql.types.__dict__.copy()
+    sc = SparkContext('local[4]', 'PythonTest')
+    globs['sc'] = sc
+    globs['sqlContext'] = SQLContext(sc)
+    globs['ExamplePoint'] = ExamplePoint
+    globs['ExamplePointUDT'] = ExamplePointUDT
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.types, globs=globs, optionflags=doctest.ELLIPSIS)
+    globs['sc'].stop()
+    if failure_count:
+        exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
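
For reference, a minimal sketch of how the module added above can be exercised once it is importable as `pyspark.sql.types` (no SparkContext is needed for schema construction; the column names are only illustrative):

    from pyspark.sql.types import (StructType, StructField, StringType,
                                   IntegerType, _parse_datatype_json_string)

    # build a simple two-column schema
    schema = StructType([StructField("name", StringType(), True),
                         StructField("age", IntegerType(), False)])
    print(schema.simpleString())   # struct<name:string,age:int>

    # the JSON form round-trips back to an equal DataType
    assert _parse_datatype_json_string(schema.json()) == schema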

http://git-wip-us.apache.org/repos/asf/spark/blob/1c5b1982/python/run-tests
----------------------------------------------------------------------
diff --git a/python/run-tests b/python/run-tests
index ffde2fb..fcfb495 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -57,54 +57,54 @@ function run_test() {
 
 function run_core_tests() {
     echo "Run core tests ..."
-    run_test "pyspark/rdd.py"
-    run_test "pyspark/context.py"
-    run_test "pyspark/conf.py"
-    PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
-    PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
-    run_test "pyspark/serializers.py"
-    run_test "pyspark/profiler.py"
-    run_test "pyspark/shuffle.py"
-    run_test "pyspark/tests.py"
+    run_test "pyspark.rdd"
+    run_test "pyspark.context"
+    run_test "pyspark.conf"
+    run_test "pyspark.broadcast"
+    run_test "pyspark.accumulators"
+    run_test "pyspark.serializers"
+    run_test "pyspark.profiler"
+    run_test "pyspark.shuffle"
+    run_test "pyspark.tests"
 }
 
 function run_sql_tests() {
     echo "Run sql tests ..."
-    run_test "pyspark/sql/_types.py"
-    run_test "pyspark/sql/context.py"
-    run_test "pyspark/sql/column.py"
-    run_test "pyspark/sql/dataframe.py"
-    run_test "pyspark/sql/group.py"
-    run_test "pyspark/sql/functions.py"
-    run_test "pyspark/sql/tests.py"
+    run_test "pyspark.sql.types"
+    run_test "pyspark.sql.context"
+    run_test "pyspark.sql.column"
+    run_test "pyspark.sql.dataframe"
+    run_test "pyspark.sql.group"
+    run_test "pyspark.sql.functions"
+    run_test "pyspark.sql.tests"
 }
 
 function run_mllib_tests() {
     echo "Run mllib tests ..."
-    run_test "pyspark/mllib/classification.py"
-    run_test "pyspark/mllib/clustering.py"
-    run_test "pyspark/mllib/evaluation.py"
-    run_test "pyspark/mllib/feature.py"
-    run_test "pyspark/mllib/fpm.py"
-    run_test "pyspark/mllib/linalg.py"
-    run_test "pyspark/mllib/rand.py"
-    run_test "pyspark/mllib/recommendation.py"
-    run_test "pyspark/mllib/regression.py"
-    run_test "pyspark/mllib/stat/_statistics.py"
-    run_test "pyspark/mllib/tree.py"
-    run_test "pyspark/mllib/util.py"
-    run_test "pyspark/mllib/tests.py"
+    run_test "pyspark.mllib.classification"
+    run_test "pyspark.mllib.clustering"
+    run_test "pyspark.mllib.evaluation"
+    run_test "pyspark.mllib.feature"
+    run_test "pyspark.mllib.fpm"
+    run_test "pyspark.mllib.linalg"
+    run_test "pyspark.mllib.random"
+    run_test "pyspark.mllib.recommendation"
+    run_test "pyspark.mllib.regression"
+    run_test "pyspark.mllib.stat._statistics"
+    run_test "pyspark.mllib.tree"
+    run_test "pyspark.mllib.util"
+    run_test "pyspark.mllib.tests"
 }
 
 function run_ml_tests() {
     echo "Run ml tests ..."
-    run_test "pyspark/ml/feature.py"
-    run_test "pyspark/ml/classification.py"
-    run_test "pyspark/ml/recommendation.py"
-    run_test "pyspark/ml/regression.py"
-    run_test "pyspark/ml/tuning.py"
-    run_test "pyspark/ml/tests.py"
-    run_test "pyspark/ml/evaluation.py"
+    run_test "pyspark.ml.feature"
+    run_test "pyspark.ml.classification"
+    run_test "pyspark.ml.recommendation"
+    run_test "pyspark.ml.regression"
+    run_test "pyspark.ml.tuning"
+    run_test "pyspark.ml.tests"
+    run_test "pyspark.ml.evaluation"
 }
 
 function run_streaming_tests() {
@@ -124,8 +124,8 @@ function run_streaming_tests() {
     done
 
     export PYSPARK_SUBMIT_ARGS="--jars ${KAFKA_ASSEMBLY_JAR} pyspark-shell"
-    run_test "pyspark/streaming/util.py"
-    run_test "pyspark/streaming/tests.py"
+    run_test "pyspark.streaming.util"
+    run_test "pyspark.streaming.tests"
 }
 
 echo "Running PySpark tests. Output is in python/$LOG_FILE."


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


[2/2] spark git commit: [SPARK-7899] [PYSPARK] Fix Python 3 pyspark/sql/types module conflict

Posted by da...@apache.org.
[SPARK-7899] [PYSPARK] Fix Python 3 pyspark/sql/types module conflict

This PR makes the types module in `pyspark/sql/types` work with pylint static analysis by removing the hack that dynamically renamed the `pyspark/sql/_types` module to `pyspark/sql/types`.
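
The removed renaming followed the same `sys.modules` aliasing pattern that this commit also deletes from `python/pyspark/mllib/__init__.py` (see the diff below). A minimal sketch of that pattern inside a package `__init__.py`, with generic names rather than the exact deleted lines:

    import sys
    from . import _private_name as public   # module lives in _private_name.py
    modname = __name__ + '.public'           # the public dotted name, e.g. "pyspark.sql.types"
    public.__name__ = modname                # make the module report the public name
    sys.modules[modname] = public            # so "import package.public" resolves to it
    del modname, sys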

Tests are now loaded using `$PYSPARK_DRIVER_PYTHON -m module` rather than `$PYSPARK_DRIVER_PYTHON module.py`. The old method adds the location of `module.py` to `sys.path`, so this change prevents accidental use of relative paths in Python.
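
The difference matters because of how CPython sets up `sys.path`: running a file by path puts that file's directory first on `sys.path`, so a `types.py` in `pyspark/sql/` could shadow the standard-library `types` module for the whole interpreter, while `-m` keeps that directory off `sys.path` and resolves imports against the normal package layout. A small illustration (the helper module name is hypothetical):

    # pyspark/sql/show_path.py -- hypothetical helper
    import sys
    print(sys.path[0])

    # $ python python/pyspark/sql/show_path.py
    #   -> ".../python/pyspark/sql"  (the file's directory, which can shadow stdlib names)
    # $ python -m pyspark.sql.show_path
    #   -> the current working directory (shown as "" in many Python versions)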

Author: Michael Nazario <mn...@palantir.com>

Closes #6439 from mnazario/feature/SPARK-7899 and squashes the following commits:

366ef30 [Michael Nazario] Remove hack on random.py
bb8b04d [Michael Nazario] Make doctests consistent with other tests
6ee4f75 [Michael Nazario] Change test scripts to use "-m"
673528f [Michael Nazario] Move _types back to types


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1c5b1982
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1c5b1982
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1c5b1982

Branch: refs/heads/master
Commit: 1c5b19827a091b5aba69a967600e7ca35ed3bcfd
Parents: 5f48e5c
Author: Michael Nazario <mn...@palantir.com>
Authored: Fri May 29 14:13:44 2015 -0700
Committer: Davies Liu <da...@databricks.com>
Committed: Fri May 29 14:13:44 2015 -0700

----------------------------------------------------------------------
 bin/pyspark                      |    6 +-
 python/pyspark/accumulators.py   |    4 +
 python/pyspark/mllib/__init__.py |    8 -
 python/pyspark/mllib/rand.py     |  409 -----------
 python/pyspark/mllib/random.py   |  409 +++++++++++
 python/pyspark/sql/__init__.py   |   12 -
 python/pyspark/sql/_types.py     | 1306 ---------------------------------
 python/pyspark/sql/types.py      | 1306 +++++++++++++++++++++++++++++++++
 python/run-tests                 |   76 +-
 9 files changed, 1758 insertions(+), 1778 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1c5b1982/bin/pyspark
----------------------------------------------------------------------
diff --git a/bin/pyspark b/bin/pyspark
index 8acad61..7cb19c5 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -90,11 +90,7 @@ if [[ -n "$SPARK_TESTING" ]]; then
   unset YARN_CONF_DIR
   unset HADOOP_CONF_DIR
   export PYTHONHASHSEED=0
-  if [[ -n "$PYSPARK_DOC_TEST" ]]; then
-    exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1
-  else
-    exec "$PYSPARK_DRIVER_PYTHON" $1
-  fi
+  exec "$PYSPARK_DRIVER_PYTHON" -m $1
   exit
 fi
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1c5b1982/python/pyspark/accumulators.py
----------------------------------------------------------------------
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 0d21a13..adca90d 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -261,3 +261,7 @@ def _start_update_server():
     thread.daemon = True
     thread.start()
     return server
+
+if __name__ == "__main__":
+    import doctest
+    doctest.testmod()

http://git-wip-us.apache.org/repos/asf/spark/blob/1c5b1982/python/pyspark/mllib/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py
index 07507b2..b11aed2 100644
--- a/python/pyspark/mllib/__init__.py
+++ b/python/pyspark/mllib/__init__.py
@@ -28,11 +28,3 @@ if numpy.version.version < '1.4':
 
 __all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random',
            'recommendation', 'regression', 'stat', 'tree', 'util']
-
-import sys
-from . import rand as random
-modname = __name__ + '.random'
-random.__name__ = modname
-random.RandomRDDs.__module__ = modname
-sys.modules[modname] = random
-del modname, sys

http://git-wip-us.apache.org/repos/asf/spark/blob/1c5b1982/python/pyspark/mllib/rand.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/rand.py b/python/pyspark/mllib/rand.py
deleted file mode 100644
index 06fbc0e..0000000
--- a/python/pyspark/mllib/rand.py
+++ /dev/null
@@ -1,409 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""
-Python package for random data generation.
-"""
-
-from functools import wraps
-
-from pyspark.mllib.common import callMLlibFunc
-
-
-__all__ = ['RandomRDDs', ]
-
-
-def toArray(f):
-    @wraps(f)
-    def func(sc, *a, **kw):
-        rdd = f(sc, *a, **kw)
-        return rdd.map(lambda vec: vec.toArray())
-    return func
-
-
-class RandomRDDs(object):
-    """
-    Generator methods for creating RDDs comprised of i.i.d samples from
-    some distribution.
-    """
-
-    @staticmethod
-    def uniformRDD(sc, size, numPartitions=None, seed=None):
-        """
-        Generates an RDD comprised of i.i.d. samples from the
-        uniform distribution U(0.0, 1.0).
-
-        To transform the distribution in the generated RDD from U(0.0, 1.0)
-        to U(a, b), use
-        C{RandomRDDs.uniformRDD(sc, n, p, seed)\
-          .map(lambda v: a + (b - a) * v)}
-
-        :param sc: SparkContext used to create the RDD.
-        :param size: Size of the RDD.
-        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
-        :param seed: Random seed (default: a random long integer).
-        :return: RDD of float comprised of i.i.d. samples ~ `U(0.0, 1.0)`.
-
-        >>> x = RandomRDDs.uniformRDD(sc, 100).collect()
-        >>> len(x)
-        100
-        >>> max(x) <= 1.0 and min(x) >= 0.0
-        True
-        >>> RandomRDDs.uniformRDD(sc, 100, 4).getNumPartitions()
-        4
-        >>> parts = RandomRDDs.uniformRDD(sc, 100, seed=4).getNumPartitions()
-        >>> parts == sc.defaultParallelism
-        True
-        """
-        return callMLlibFunc("uniformRDD", sc._jsc, size, numPartitions, seed)
-
-    @staticmethod
-    def normalRDD(sc, size, numPartitions=None, seed=None):
-        """
-        Generates an RDD comprised of i.i.d. samples from the standard normal
-        distribution.
-
-        To transform the distribution in the generated RDD from standard normal
-        to some other normal N(mean, sigma^2), use
-        C{RandomRDDs.normal(sc, n, p, seed)\
-          .map(lambda v: mean + sigma * v)}
-
-        :param sc: SparkContext used to create the RDD.
-        :param size: Size of the RDD.
-        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
-        :param seed: Random seed (default: a random long integer).
-        :return: RDD of float comprised of i.i.d. samples ~ N(0.0, 1.0).
-
-        >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1)
-        >>> stats = x.stats()
-        >>> stats.count()
-        1000
-        >>> abs(stats.mean() - 0.0) < 0.1
-        True
-        >>> abs(stats.stdev() - 1.0) < 0.1
-        True
-        """
-        return callMLlibFunc("normalRDD", sc._jsc, size, numPartitions, seed)
-
-    @staticmethod
-    def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None):
-        """
-        Generates an RDD comprised of i.i.d. samples from the log normal
-        distribution with the input mean and standard distribution.
-
-        :param sc: SparkContext used to create the RDD.
-        :param mean: mean for the log Normal distribution
-        :param std: std for the log Normal distribution
-        :param size: Size of the RDD.
-        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
-        :param seed: Random seed (default: a random long integer).
-        :return: RDD of float comprised of i.i.d. samples ~ log N(mean, std).
-
-        >>> from math import sqrt, exp
-        >>> mean = 0.0
-        >>> std = 1.0
-        >>> expMean = exp(mean + 0.5 * std * std)
-        >>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std))
-        >>> x = RandomRDDs.logNormalRDD(sc, mean, std, 1000, seed=2)
-        >>> stats = x.stats()
-        >>> stats.count()
-        1000
-        >>> abs(stats.mean() - expMean) < 0.5
-        True
-        >>> from math import sqrt
-        >>> abs(stats.stdev() - expStd) < 0.5
-        True
-        """
-        return callMLlibFunc("logNormalRDD", sc._jsc, float(mean), float(std),
-                             size, numPartitions, seed)
-
-    @staticmethod
-    def poissonRDD(sc, mean, size, numPartitions=None, seed=None):
-        """
-        Generates an RDD comprised of i.i.d. samples from the Poisson
-        distribution with the input mean.
-
-        :param sc: SparkContext used to create the RDD.
-        :param mean: Mean, or lambda, for the Poisson distribution.
-        :param size: Size of the RDD.
-        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
-        :param seed: Random seed (default: a random long integer).
-        :return: RDD of float comprised of i.i.d. samples ~ Pois(mean).
-
-        >>> mean = 100.0
-        >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2)
-        >>> stats = x.stats()
-        >>> stats.count()
-        1000
-        >>> abs(stats.mean() - mean) < 0.5
-        True
-        >>> from math import sqrt
-        >>> abs(stats.stdev() - sqrt(mean)) < 0.5
-        True
-        """
-        return callMLlibFunc("poissonRDD", sc._jsc, float(mean), size, numPartitions, seed)
-
-    @staticmethod
-    def exponentialRDD(sc, mean, size, numPartitions=None, seed=None):
-        """
-        Generates an RDD comprised of i.i.d. samples from the Exponential
-        distribution with the input mean.
-
-        :param sc: SparkContext used to create the RDD.
-        :param mean: Mean, or 1 / lambda, for the Exponential distribution.
-        :param size: Size of the RDD.
-        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
-        :param seed: Random seed (default: a random long integer).
-        :return: RDD of float comprised of i.i.d. samples ~ Exp(mean).
-
-        >>> mean = 2.0
-        >>> x = RandomRDDs.exponentialRDD(sc, mean, 1000, seed=2)
-        >>> stats = x.stats()
-        >>> stats.count()
-        1000
-        >>> abs(stats.mean() - mean) < 0.5
-        True
-        >>> from math import sqrt
-        >>> abs(stats.stdev() - sqrt(mean)) < 0.5
-        True
-        """
-        return callMLlibFunc("exponentialRDD", sc._jsc, float(mean), size, numPartitions, seed)
-
-    @staticmethod
-    def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None):
-        """
-        Generates an RDD comprised of i.i.d. samples from the Gamma
-        distribution with the input shape and scale.
-
-        :param sc: SparkContext used to create the RDD.
-        :param shape: shape (> 0) parameter for the Gamma distribution
-        :param scale: scale (> 0) parameter for the Gamma distribution
-        :param size: Size of the RDD.
-        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
-        :param seed: Random seed (default: a random long integer).
-        :return: RDD of float comprised of i.i.d. samples ~ Gamma(shape, scale).
-
-        >>> from math import sqrt
-        >>> shape = 1.0
-        >>> scale = 2.0
-        >>> expMean = shape * scale
-        >>> expStd = sqrt(shape * scale * scale)
-        >>> x = RandomRDDs.gammaRDD(sc, shape, scale, 1000, seed=2)
-        >>> stats = x.stats()
-        >>> stats.count()
-        1000
-        >>> abs(stats.mean() - expMean) < 0.5
-        True
-        >>> abs(stats.stdev() - expStd) < 0.5
-        True
-        """
-        return callMLlibFunc("gammaRDD", sc._jsc, float(shape),
-                             float(scale), size, numPartitions, seed)
-
-    @staticmethod
-    @toArray
-    def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None):
-        """
-        Generates an RDD comprised of vectors containing i.i.d. samples drawn
-        from the uniform distribution U(0.0, 1.0).
-
-        :param sc: SparkContext used to create the RDD.
-        :param numRows: Number of Vectors in the RDD.
-        :param numCols: Number of elements in each Vector.
-        :param numPartitions: Number of partitions in the RDD.
-        :param seed: Seed for the RNG that generates the seed for the generator in each partition.
-        :return: RDD of Vector with vectors containing i.i.d. samples ~ `U(0.0, 1.0)`.
-
-        >>> import numpy as np
-        >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect())
-        >>> mat.shape
-        (10, 10)
-        >>> mat.max() <= 1.0 and mat.min() >= 0.0
-        True
-        >>> RandomRDDs.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions()
-        4
-        """
-        return callMLlibFunc("uniformVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed)
-
-    @staticmethod
-    @toArray
-    def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None):
-        """
-        Generates an RDD comprised of vectors containing i.i.d. samples drawn
-        from the standard normal distribution.
-
-        :param sc: SparkContext used to create the RDD.
-        :param numRows: Number of Vectors in the RDD.
-        :param numCols: Number of elements in each Vector.
-        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
-        :param seed: Random seed (default: a random long integer).
-        :return: RDD of Vector with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`.
-
-        >>> import numpy as np
-        >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1).collect())
-        >>> mat.shape
-        (100, 100)
-        >>> abs(mat.mean() - 0.0) < 0.1
-        True
-        >>> abs(mat.std() - 1.0) < 0.1
-        True
-        """
-        return callMLlibFunc("normalVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed)
-
-    @staticmethod
-    @toArray
-    def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed=None):
-        """
-        Generates an RDD comprised of vectors containing i.i.d. samples drawn
-        from the log normal distribution.
-
-        :param sc: SparkContext used to create the RDD.
-        :param mean: Mean of the log normal distribution
-        :param std: Standard Deviation of the log normal distribution
-        :param numRows: Number of Vectors in the RDD.
-        :param numCols: Number of elements in each Vector.
-        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
-        :param seed: Random seed (default: a random long integer).
-        :return: RDD of Vector with vectors containing i.i.d. samples ~ log `N(mean, std)`.
-
-        >>> import numpy as np
-        >>> from math import sqrt, exp
-        >>> mean = 0.0
-        >>> std = 1.0
-        >>> expMean = exp(mean + 0.5 * std * std)
-        >>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std))
-        >>> m = RandomRDDs.logNormalVectorRDD(sc, mean, std, 100, 100, seed=1).collect()
-        >>> mat = np.matrix(m)
-        >>> mat.shape
-        (100, 100)
-        >>> abs(mat.mean() - expMean) < 0.1
-        True
-        >>> abs(mat.std() - expStd) < 0.1
-        True
-        """
-        return callMLlibFunc("logNormalVectorRDD", sc._jsc, float(mean), float(std),
-                             numRows, numCols, numPartitions, seed)
-
-    @staticmethod
-    @toArray
-    def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None):
-        """
-        Generates an RDD comprised of vectors containing i.i.d. samples drawn
-        from the Poisson distribution with the input mean.
-
-        :param sc: SparkContext used to create the RDD.
-        :param mean: Mean, or lambda, for the Poisson distribution.
-        :param numRows: Number of Vectors in the RDD.
-        :param numCols: Number of elements in each Vector.
-        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`)
-        :param seed: Random seed (default: a random long integer).
-        :return: RDD of Vector with vectors containing i.i.d. samples ~ Pois(mean).
-
-        >>> import numpy as np
-        >>> mean = 100.0
-        >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1)
-        >>> mat = np.mat(rdd.collect())
-        >>> mat.shape
-        (100, 100)
-        >>> abs(mat.mean() - mean) < 0.5
-        True
-        >>> from math import sqrt
-        >>> abs(mat.std() - sqrt(mean)) < 0.5
-        True
-        """
-        return callMLlibFunc("poissonVectorRDD", sc._jsc, float(mean), numRows, numCols,
-                             numPartitions, seed)
-
-    @staticmethod
-    @toArray
-    def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None):
-        """
-        Generates an RDD comprised of vectors containing i.i.d. samples drawn
-        from the Exponential distribution with the input mean.
-
-        :param sc: SparkContext used to create the RDD.
-        :param mean: Mean, or 1 / lambda, for the Exponential distribution.
-        :param numRows: Number of Vectors in the RDD.
-        :param numCols: Number of elements in each Vector.
-        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`)
-        :param seed: Random seed (default: a random long integer).
-        :return: RDD of Vector with vectors containing i.i.d. samples ~ Exp(mean).
-
-        >>> import numpy as np
-        >>> mean = 0.5
-        >>> rdd = RandomRDDs.exponentialVectorRDD(sc, mean, 100, 100, seed=1)
-        >>> mat = np.mat(rdd.collect())
-        >>> mat.shape
-        (100, 100)
-        >>> abs(mat.mean() - mean) < 0.5
-        True
-        >>> from math import sqrt
-        >>> abs(mat.std() - sqrt(mean)) < 0.5
-        True
-        """
-        return callMLlibFunc("exponentialVectorRDD", sc._jsc, float(mean), numRows, numCols,
-                             numPartitions, seed)
-
-    @staticmethod
-    @toArray
-    def gammaVectorRDD(sc, shape, scale, numRows, numCols, numPartitions=None, seed=None):
-        """
-        Generates an RDD comprised of vectors containing i.i.d. samples drawn
-        from the Gamma distribution.
-
-        :param sc: SparkContext used to create the RDD.
-        :param shape: Shape (> 0) of the Gamma distribution
-        :param scale: Scale (> 0) of the Gamma distribution
-        :param numRows: Number of Vectors in the RDD.
-        :param numCols: Number of elements in each Vector.
-        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
-        :param seed: Random seed (default: a random long integer).
-        :return: RDD of Vector with vectors containing i.i.d. samples ~ Gamma(shape, scale).
-
-        >>> import numpy as np
-        >>> from math import sqrt
-        >>> shape = 1.0
-        >>> scale = 2.0
-        >>> expMean = shape * scale
-        >>> expStd = sqrt(shape * scale * scale)
-        >>> mat = np.matrix(RandomRDDs.gammaVectorRDD(sc, shape, scale, 100, 100, seed=1).collect())
-        >>> mat.shape
-        (100, 100)
-        >>> abs(mat.mean() - expMean) < 0.1
-        True
-        >>> abs(mat.std() - expStd) < 0.1
-        True
-        """
-        return callMLlibFunc("gammaVectorRDD", sc._jsc, float(shape), float(scale),
-                             numRows, numCols, numPartitions, seed)
-
-
-def _test():
-    import doctest
-    from pyspark.context import SparkContext
-    globs = globals().copy()
-    # The small batch size here ensures that we see multiple batches,
-    # even in these small test examples:
-    globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2)
-    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
-    globs['sc'].stop()
-    if failure_count:
-        exit(-1)
-
-
-if __name__ == "__main__":
-    _test()
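
The docstrings above recommend getting non-standard parameters by mapping over the unit generators rather than passing extra arguments. Below is a minimal sketch of that pattern against the module re-added later in this commit as pyspark.mllib.random; it assumes no SparkContext is already running, and the app name and the values of a, b, mu and sigma are purely illustrative.

    from pyspark import SparkContext
    from pyspark.mllib.random import RandomRDDs

    sc = SparkContext("local[2]", "random-rdd-sketch")  # illustrative app name

    # U(0.0, 1.0) -> U(a, b): shift and scale each sample.
    a, b = -1.0, 3.0
    uniform_ab = RandomRDDs.uniformRDD(sc, 1000, seed=42).map(lambda v: a + (b - a) * v)

    # N(0.0, 1.0) -> N(mu, sigma^2): same idea with the normal generator.
    mu, sigma = 5.0, 2.0
    normal_mu_sigma = RandomRDDs.normalRDD(sc, 1000, seed=42).map(lambda v: mu + sigma * v)

    print(uniform_ab.stats())       # StatCounter: count, mean, stdev, max, min
    print(normal_mu_sigma.stats())
    sc.stop()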

http://git-wip-us.apache.org/repos/asf/spark/blob/1c5b1982/python/pyspark/mllib/random.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py
new file mode 100644
index 0000000..06fbc0e
--- /dev/null
+++ b/python/pyspark/mllib/random.py
@@ -0,0 +1,409 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Python package for random data generation.
+"""
+
+from functools import wraps
+
+from pyspark.mllib.common import callMLlibFunc
+
+
+__all__ = ['RandomRDDs', ]
+
+
+def toArray(f):
+    @wraps(f)
+    def func(sc, *a, **kw):
+        rdd = f(sc, *a, **kw)
+        return rdd.map(lambda vec: vec.toArray())
+    return func
+
+
+class RandomRDDs(object):
+    """
+    Generator methods for creating RDDs comprised of i.i.d. samples from
+    some distribution.
+    """
+
+    @staticmethod
+    def uniformRDD(sc, size, numPartitions=None, seed=None):
+        """
+        Generates an RDD comprised of i.i.d. samples from the
+        uniform distribution U(0.0, 1.0).
+
+        To transform the distribution in the generated RDD from U(0.0, 1.0)
+        to U(a, b), use
+        C{RandomRDDs.uniformRDD(sc, n, p, seed)\
+          .map(lambda v: a + (b - a) * v)}
+
+        :param sc: SparkContext used to create the RDD.
+        :param size: Size of the RDD.
+        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+        :param seed: Random seed (default: a random long integer).
+        :return: RDD of float comprised of i.i.d. samples ~ `U(0.0, 1.0)`.
+
+        >>> x = RandomRDDs.uniformRDD(sc, 100).collect()
+        >>> len(x)
+        100
+        >>> max(x) <= 1.0 and min(x) >= 0.0
+        True
+        >>> RandomRDDs.uniformRDD(sc, 100, 4).getNumPartitions()
+        4
+        >>> parts = RandomRDDs.uniformRDD(sc, 100, seed=4).getNumPartitions()
+        >>> parts == sc.defaultParallelism
+        True
+        """
+        return callMLlibFunc("uniformRDD", sc._jsc, size, numPartitions, seed)
+
+    @staticmethod
+    def normalRDD(sc, size, numPartitions=None, seed=None):
+        """
+        Generates an RDD comprised of i.i.d. samples from the standard normal
+        distribution.
+
+        To transform the distribution in the generated RDD from standard normal
+        to some other normal N(mean, sigma^2), use
+        C{RandomRDDs.normalRDD(sc, n, p, seed)\
+          .map(lambda v: mean + sigma * v)}
+
+        :param sc: SparkContext used to create the RDD.
+        :param size: Size of the RDD.
+        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+        :param seed: Random seed (default: a random long integer).
+        :return: RDD of float comprised of i.i.d. samples ~ N(0.0, 1.0).
+
+        >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1)
+        >>> stats = x.stats()
+        >>> stats.count()
+        1000
+        >>> abs(stats.mean() - 0.0) < 0.1
+        True
+        >>> abs(stats.stdev() - 1.0) < 0.1
+        True
+        """
+        return callMLlibFunc("normalRDD", sc._jsc, size, numPartitions, seed)
+
+    @staticmethod
+    def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None):
+        """
+        Generates an RDD comprised of i.i.d. samples from the log normal
+        distribution with the input mean and standard deviation.
+
+        :param sc: SparkContext used to create the RDD.
+        :param mean: mean for the log Normal distribution
+        :param std: std for the log Normal distribution
+        :param size: Size of the RDD.
+        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+        :param seed: Random seed (default: a random long integer).
+        :return: RDD of float comprised of i.i.d. samples ~ log N(mean, std).
+
+        >>> from math import sqrt, exp
+        >>> mean = 0.0
+        >>> std = 1.0
+        >>> expMean = exp(mean + 0.5 * std * std)
+        >>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std))
+        >>> x = RandomRDDs.logNormalRDD(sc, mean, std, 1000, seed=2)
+        >>> stats = x.stats()
+        >>> stats.count()
+        1000
+        >>> abs(stats.mean() - expMean) < 0.5
+        True
+        >>> from math import sqrt
+        >>> abs(stats.stdev() - expStd) < 0.5
+        True
+        """
+        return callMLlibFunc("logNormalRDD", sc._jsc, float(mean), float(std),
+                             size, numPartitions, seed)
+
+    @staticmethod
+    def poissonRDD(sc, mean, size, numPartitions=None, seed=None):
+        """
+        Generates an RDD comprised of i.i.d. samples from the Poisson
+        distribution with the input mean.
+
+        :param sc: SparkContext used to create the RDD.
+        :param mean: Mean, or lambda, for the Poisson distribution.
+        :param size: Size of the RDD.
+        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+        :param seed: Random seed (default: a random long integer).
+        :return: RDD of float comprised of i.i.d. samples ~ Pois(mean).
+
+        >>> mean = 100.0
+        >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2)
+        >>> stats = x.stats()
+        >>> stats.count()
+        1000
+        >>> abs(stats.mean() - mean) < 0.5
+        True
+        >>> from math import sqrt
+        >>> abs(stats.stdev() - sqrt(mean)) < 0.5
+        True
+        """
+        return callMLlibFunc("poissonRDD", sc._jsc, float(mean), size, numPartitions, seed)
+
+    @staticmethod
+    def exponentialRDD(sc, mean, size, numPartitions=None, seed=None):
+        """
+        Generates an RDD comprised of i.i.d. samples from the Exponential
+        distribution with the input mean.
+
+        :param sc: SparkContext used to create the RDD.
+        :param mean: Mean, or 1 / lambda, for the Exponential distribution.
+        :param size: Size of the RDD.
+        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+        :param seed: Random seed (default: a random long integer).
+        :return: RDD of float comprised of i.i.d. samples ~ Exp(mean).
+
+        >>> mean = 2.0
+        >>> x = RandomRDDs.exponentialRDD(sc, mean, 1000, seed=2)
+        >>> stats = x.stats()
+        >>> stats.count()
+        1000
+        >>> abs(stats.mean() - mean) < 0.5
+        True
+        >>> from math import sqrt
+        >>> abs(stats.stdev() - sqrt(mean)) < 0.5
+        True
+        """
+        return callMLlibFunc("exponentialRDD", sc._jsc, float(mean), size, numPartitions, seed)
+
+    @staticmethod
+    def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None):
+        """
+        Generates an RDD comprised of i.i.d. samples from the Gamma
+        distribution with the input shape and scale.
+
+        :param sc: SparkContext used to create the RDD.
+        :param shape: shape (> 0) parameter for the Gamma distribution
+        :param scale: scale (> 0) parameter for the Gamma distribution
+        :param size: Size of the RDD.
+        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+        :param seed: Random seed (default: a random long integer).
+        :return: RDD of float comprised of i.i.d. samples ~ Gamma(shape, scale).
+
+        >>> from math import sqrt
+        >>> shape = 1.0
+        >>> scale = 2.0
+        >>> expMean = shape * scale
+        >>> expStd = sqrt(shape * scale * scale)
+        >>> x = RandomRDDs.gammaRDD(sc, shape, scale, 1000, seed=2)
+        >>> stats = x.stats()
+        >>> stats.count()
+        1000
+        >>> abs(stats.mean() - expMean) < 0.5
+        True
+        >>> abs(stats.stdev() - expStd) < 0.5
+        True
+        """
+        return callMLlibFunc("gammaRDD", sc._jsc, float(shape),
+                             float(scale), size, numPartitions, seed)
+
+    @staticmethod
+    @toArray
+    def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None):
+        """
+        Generates an RDD comprised of vectors containing i.i.d. samples drawn
+        from the uniform distribution U(0.0, 1.0).
+
+        :param sc: SparkContext used to create the RDD.
+        :param numRows: Number of Vectors in the RDD.
+        :param numCols: Number of elements in each Vector.
+        :param numPartitions: Number of partitions in the RDD.
+        :param seed: Seed for the RNG that generates the seed for the generator in each partition.
+        :return: RDD of Vector with vectors containing i.i.d. samples ~ `U(0.0, 1.0)`.
+
+        >>> import numpy as np
+        >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect())
+        >>> mat.shape
+        (10, 10)
+        >>> mat.max() <= 1.0 and mat.min() >= 0.0
+        True
+        >>> RandomRDDs.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions()
+        4
+        """
+        return callMLlibFunc("uniformVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed)
+
+    @staticmethod
+    @toArray
+    def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None):
+        """
+        Generates an RDD comprised of vectors containing i.i.d. samples drawn
+        from the standard normal distribution.
+
+        :param sc: SparkContext used to create the RDD.
+        :param numRows: Number of Vectors in the RDD.
+        :param numCols: Number of elements in each Vector.
+        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+        :param seed: Random seed (default: a random long integer).
+        :return: RDD of Vector with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`.
+
+        >>> import numpy as np
+        >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1).collect())
+        >>> mat.shape
+        (100, 100)
+        >>> abs(mat.mean() - 0.0) < 0.1
+        True
+        >>> abs(mat.std() - 1.0) < 0.1
+        True
+        """
+        return callMLlibFunc("normalVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed)
+
+    @staticmethod
+    @toArray
+    def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed=None):
+        """
+        Generates an RDD comprised of vectors containing i.i.d. samples drawn
+        from the log normal distribution.
+
+        :param sc: SparkContext used to create the RDD.
+        :param mean: Mean of the log normal distribution
+        :param std: Standard Deviation of the log normal distribution
+        :param numRows: Number of Vectors in the RDD.
+        :param numCols: Number of elements in each Vector.
+        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+        :param seed: Random seed (default: a random long integer).
+        :return: RDD of Vector with vectors containing i.i.d. samples ~ log `N(mean, std)`.
+
+        >>> import numpy as np
+        >>> from math import sqrt, exp
+        >>> mean = 0.0
+        >>> std = 1.0
+        >>> expMean = exp(mean + 0.5 * std * std)
+        >>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std))
+        >>> m = RandomRDDs.logNormalVectorRDD(sc, mean, std, 100, 100, seed=1).collect()
+        >>> mat = np.matrix(m)
+        >>> mat.shape
+        (100, 100)
+        >>> abs(mat.mean() - expMean) < 0.1
+        True
+        >>> abs(mat.std() - expStd) < 0.1
+        True
+        """
+        return callMLlibFunc("logNormalVectorRDD", sc._jsc, float(mean), float(std),
+                             numRows, numCols, numPartitions, seed)
+
+    @staticmethod
+    @toArray
+    def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None):
+        """
+        Generates an RDD comprised of vectors containing i.i.d. samples drawn
+        from the Poisson distribution with the input mean.
+
+        :param sc: SparkContext used to create the RDD.
+        :param mean: Mean, or lambda, for the Poisson distribution.
+        :param numRows: Number of Vectors in the RDD.
+        :param numCols: Number of elements in each Vector.
+        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`)
+        :param seed: Random seed (default: a random long integer).
+        :return: RDD of Vector with vectors containing i.i.d. samples ~ Pois(mean).
+
+        >>> import numpy as np
+        >>> mean = 100.0
+        >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1)
+        >>> mat = np.mat(rdd.collect())
+        >>> mat.shape
+        (100, 100)
+        >>> abs(mat.mean() - mean) < 0.5
+        True
+        >>> from math import sqrt
+        >>> abs(mat.std() - sqrt(mean)) < 0.5
+        True
+        """
+        return callMLlibFunc("poissonVectorRDD", sc._jsc, float(mean), numRows, numCols,
+                             numPartitions, seed)
+
+    @staticmethod
+    @toArray
+    def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None):
+        """
+        Generates an RDD comprised of vectors containing i.i.d. samples drawn
+        from the Exponential distribution with the input mean.
+
+        :param sc: SparkContext used to create the RDD.
+        :param mean: Mean, or 1 / lambda, for the Exponential distribution.
+        :param numRows: Number of Vectors in the RDD.
+        :param numCols: Number of elements in each Vector.
+        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`)
+        :param seed: Random seed (default: a random long integer).
+        :return: RDD of Vector with vectors containing i.i.d. samples ~ Exp(mean).
+
+        >>> import numpy as np
+        >>> mean = 0.5
+        >>> rdd = RandomRDDs.exponentialVectorRDD(sc, mean, 100, 100, seed=1)
+        >>> mat = np.mat(rdd.collect())
+        >>> mat.shape
+        (100, 100)
+        >>> abs(mat.mean() - mean) < 0.5
+        True
+        >>> from math import sqrt
+        >>> abs(mat.std() - sqrt(mean)) < 0.5
+        True
+        """
+        return callMLlibFunc("exponentialVectorRDD", sc._jsc, float(mean), numRows, numCols,
+                             numPartitions, seed)
+
+    @staticmethod
+    @toArray
+    def gammaVectorRDD(sc, shape, scale, numRows, numCols, numPartitions=None, seed=None):
+        """
+        Generates an RDD comprised of vectors containing i.i.d. samples drawn
+        from the Gamma distribution.
+
+        :param sc: SparkContext used to create the RDD.
+        :param shape: Shape (> 0) of the Gamma distribution
+        :param scale: Scale (> 0) of the Gamma distribution
+        :param numRows: Number of Vectors in the RDD.
+        :param numCols: Number of elements in each Vector.
+        :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+        :param seed: Random seed (default: a random long integer).
+        :return: RDD of Vector with vectors containing i.i.d. samples ~ Gamma(shape, scale).
+
+        >>> import numpy as np
+        >>> from math import sqrt
+        >>> shape = 1.0
+        >>> scale = 2.0
+        >>> expMean = shape * scale
+        >>> expStd = sqrt(shape * scale * scale)
+        >>> mat = np.matrix(RandomRDDs.gammaVectorRDD(sc, shape, scale, 100, 100, seed=1).collect())
+        >>> mat.shape
+        (100, 100)
+        >>> abs(mat.mean() - expMean) < 0.1
+        True
+        >>> abs(mat.std() - expStd) < 0.1
+        True
+        """
+        return callMLlibFunc("gammaVectorRDD", sc._jsc, float(shape), float(scale),
+                             numRows, numCols, numPartitions, seed)
+
+
+def _test():
+    import doctest
+    from pyspark.context import SparkContext
+    globs = globals().copy()
+    # The small batch size here ensures that we see multiple batches,
+    # even in these small test examples:
+    globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2)
+    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+    globs['sc'].stop()
+    if failure_count:
+        exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
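
The toArray decorator defined near the top of this module maps every generated vector through vec.toArray(), so the *VectorRDD methods above return RDDs of NumPy arrays rather than MLlib Vector objects. A short sketch of what that means for a caller, assuming an active SparkContext named sc (as in the doctests) and NumPy installed:

    import numpy as np
    from pyspark.mllib.random import RandomRDDs

    rows = RandomRDDs.normalVectorRDD(sc, numRows=100, numCols=10, seed=7)

    first = rows.first()
    assert isinstance(first, np.ndarray)   # toArray() already applied per element
    assert first.shape == (10,)            # one row of numCols samples

    # Column means should sit near 0.0 for the standard normal generator.
    col_means = np.asarray(rows.collect()).mean(axis=0)
    print(col_means)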

http://git-wip-us.apache.org/repos/asf/spark/blob/1c5b1982/python/pyspark/sql/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index 8fee92a..726d288 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -50,18 +50,6 @@ def since(version):
         return f
     return deco
 
-# fix the module name conflict for Python 3+
-import sys
-from . import _types as types
-modname = __name__ + '.types'
-types.__name__ = modname
-# update the __module__ for all objects, make them picklable
-for v in types.__dict__.values():
-    if hasattr(v, "__module__") and v.__module__.endswith('._types'):
-        v.__module__ = modname
-sys.modules[modname] = types
-del modname, sys
-
 from pyspark.sql.types import Row
 from pyspark.sql.context import SQLContext, HiveContext
 from pyspark.sql.column import Column
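
The block removed above existed to work around the Python 3 module-name conflict by exposing pyspark.sql._types as pyspark.sql.types and rewriting __module__ on its classes so they stayed picklable. With the file renamed back to types.py in this commit, a plain import behaves the same on Python 2 and 3. A minimal sketch of what now works without the shim; the schema itself is illustrative:

    import pickle
    from pyspark.sql.types import StructType, StructField, StringType, IntegerType

    schema = StructType([
        StructField("name", StringType(), True),
        StructField("age", IntegerType(), True),
    ])

    # The classes live in pyspark.sql.types directly, so no __module__
    # rewriting is needed for pickle to locate them again on load.
    assert StructType.__module__ == "pyspark.sql.types"
    assert pickle.loads(pickle.dumps(schema)) == schema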

http://git-wip-us.apache.org/repos/asf/spark/blob/1c5b1982/python/pyspark/sql/_types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/_types.py
deleted file mode 100644
index 9e7e9f0..0000000
--- a/python/pyspark/sql/_types.py
+++ /dev/null
@@ -1,1306 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import sys
-import decimal
-import time
-import datetime
-import keyword
-import warnings
-import json
-import re
-import weakref
-from array import array
-from operator import itemgetter
-
-if sys.version >= "3":
-    long = int
-    unicode = str
-
-from py4j.protocol import register_input_converter
-from py4j.java_gateway import JavaClass
-
-__all__ = [
-    "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
-    "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
-    "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"]
-
-
-class DataType(object):
-    """Base class for data types."""
-
-    def __repr__(self):
-        return self.__class__.__name__
-
-    def __hash__(self):
-        return hash(str(self))
-
-    def __eq__(self, other):
-        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
-
-    def __ne__(self, other):
-        return not self.__eq__(other)
-
-    @classmethod
-    def typeName(cls):
-        return cls.__name__[:-4].lower()
-
-    def simpleString(self):
-        return self.typeName()
-
-    def jsonValue(self):
-        return self.typeName()
-
-    def json(self):
-        return json.dumps(self.jsonValue(),
-                          separators=(',', ':'),
-                          sort_keys=True)
-
-
-# This singleton pattern does not work with pickle, you will get
-# another object after pickle and unpickle
-class DataTypeSingleton(type):
-    """Metaclass for DataType"""
-
-    _instances = {}
-
-    def __call__(cls):
-        if cls not in cls._instances:
-            cls._instances[cls] = super(DataTypeSingleton, cls).__call__()
-        return cls._instances[cls]
-
-
-class NullType(DataType):
-    """Null type.
-
-    The data type representing None, used for the types that cannot be inferred.
-    """
-
-    __metaclass__ = DataTypeSingleton
-
-
-class AtomicType(DataType):
-    """An internal type used to represent everything that is not
-    null, UDTs, arrays, structs, and maps."""
-
-    __metaclass__ = DataTypeSingleton
-
-
-class NumericType(AtomicType):
-    """Numeric data types.
-    """
-
-
-class IntegralType(NumericType):
-    """Integral data types.
-    """
-
-
-class FractionalType(NumericType):
-    """Fractional data types.
-    """
-
-
-class StringType(AtomicType):
-    """String data type.
-    """
-
-
-class BinaryType(AtomicType):
-    """Binary (byte array) data type.
-    """
-
-
-class BooleanType(AtomicType):
-    """Boolean data type.
-    """
-
-
-class DateType(AtomicType):
-    """Date (datetime.date) data type.
-    """
-
-
-class TimestampType(AtomicType):
-    """Timestamp (datetime.datetime) data type.
-    """
-
-
-class DecimalType(FractionalType):
-    """Decimal (decimal.Decimal) data type.
-    """
-
-    def __init__(self, precision=None, scale=None):
-        self.precision = precision
-        self.scale = scale
-        self.hasPrecisionInfo = precision is not None
-
-    def simpleString(self):
-        if self.hasPrecisionInfo:
-            return "decimal(%d,%d)" % (self.precision, self.scale)
-        else:
-            return "decimal(10,0)"
-
-    def jsonValue(self):
-        if self.hasPrecisionInfo:
-            return "decimal(%d,%d)" % (self.precision, self.scale)
-        else:
-            return "decimal"
-
-    def __repr__(self):
-        if self.hasPrecisionInfo:
-            return "DecimalType(%d,%d)" % (self.precision, self.scale)
-        else:
-            return "DecimalType()"
-
-
-class DoubleType(FractionalType):
-    """Double data type, representing double precision floats.
-    """
-
-
-class FloatType(FractionalType):
-    """Float data type, representing single precision floats.
-    """
-
-
-class ByteType(IntegralType):
-    """Byte data type, i.e. a signed integer in a single byte.
-    """
-    def simpleString(self):
-        return 'tinyint'
-
-
-class IntegerType(IntegralType):
-    """Int data type, i.e. a signed 32-bit integer.
-    """
-    def simpleString(self):
-        return 'int'
-
-
-class LongType(IntegralType):
-    """Long data type, i.e. a signed 64-bit integer.
-
-    If the values are beyond the range of [-9223372036854775808, 9223372036854775807],
-    please use :class:`DecimalType`.
-    """
-    def simpleString(self):
-        return 'bigint'
-
-
-class ShortType(IntegralType):
-    """Short data type, i.e. a signed 16-bit integer.
-    """
-    def simpleString(self):
-        return 'smallint'
-
-
-class ArrayType(DataType):
-    """Array data type.
-
-    :param elementType: :class:`DataType` of each element in the array.
-    :param containsNull: boolean, whether the array can contain null (None) values.
-    """
-
-    def __init__(self, elementType, containsNull=True):
-        """
-        >>> ArrayType(StringType()) == ArrayType(StringType(), True)
-        True
-        >>> ArrayType(StringType(), False) == ArrayType(StringType())
-        False
-        """
-        assert isinstance(elementType, DataType), "elementType should be DataType"
-        self.elementType = elementType
-        self.containsNull = containsNull
-
-    def simpleString(self):
-        return 'array<%s>' % self.elementType.simpleString()
-
-    def __repr__(self):
-        return "ArrayType(%s,%s)" % (self.elementType,
-                                     str(self.containsNull).lower())
-
-    def jsonValue(self):
-        return {"type": self.typeName(),
-                "elementType": self.elementType.jsonValue(),
-                "containsNull": self.containsNull}
-
-    @classmethod
-    def fromJson(cls, json):
-        return ArrayType(_parse_datatype_json_value(json["elementType"]),
-                         json["containsNull"])
-
-
-class MapType(DataType):
-    """Map data type.
-
-    :param keyType: :class:`DataType` of the keys in the map.
-    :param valueType: :class:`DataType` of the values in the map.
-    :param valueContainsNull: indicates whether values can contain null (None) values.
-
-    Keys in a map data type are not allowed to be null (None).
-    """
-
-    def __init__(self, keyType, valueType, valueContainsNull=True):
-        """
-        >>> (MapType(StringType(), IntegerType())
-        ...        == MapType(StringType(), IntegerType(), True))
-        True
-        >>> (MapType(StringType(), IntegerType(), False)
-        ...        == MapType(StringType(), FloatType()))
-        False
-        """
-        assert isinstance(keyType, DataType), "keyType should be DataType"
-        assert isinstance(valueType, DataType), "valueType should be DataType"
-        self.keyType = keyType
-        self.valueType = valueType
-        self.valueContainsNull = valueContainsNull
-
-    def simpleString(self):
-        return 'map<%s,%s>' % (self.keyType.simpleString(), self.valueType.simpleString())
-
-    def __repr__(self):
-        return "MapType(%s,%s,%s)" % (self.keyType, self.valueType,
-                                      str(self.valueContainsNull).lower())
-
-    def jsonValue(self):
-        return {"type": self.typeName(),
-                "keyType": self.keyType.jsonValue(),
-                "valueType": self.valueType.jsonValue(),
-                "valueContainsNull": self.valueContainsNull}
-
-    @classmethod
-    def fromJson(cls, json):
-        return MapType(_parse_datatype_json_value(json["keyType"]),
-                       _parse_datatype_json_value(json["valueType"]),
-                       json["valueContainsNull"])
-
-
-class StructField(DataType):
-    """A field in :class:`StructType`.
-
-    :param name: string, name of the field.
-    :param dataType: :class:`DataType` of the field.
-    :param nullable: boolean, whether the field can be null (None) or not.
-    :param metadata: a dict from string to simple type that can be serialized to JSON automatically
-    """
-
-    def __init__(self, name, dataType, nullable=True, metadata=None):
-        """
-        >>> (StructField("f1", StringType(), True)
-        ...      == StructField("f1", StringType(), True))
-        True
-        >>> (StructField("f1", StringType(), True)
-        ...      == StructField("f2", StringType(), True))
-        False
-        """
-        assert isinstance(dataType, DataType), "dataType should be DataType"
-        self.name = name
-        self.dataType = dataType
-        self.nullable = nullable
-        self.metadata = metadata or {}
-
-    def simpleString(self):
-        return '%s:%s' % (self.name, self.dataType.simpleString())
-
-    def __repr__(self):
-        return "StructField(%s,%s,%s)" % (self.name, self.dataType,
-                                          str(self.nullable).lower())
-
-    def jsonValue(self):
-        return {"name": self.name,
-                "type": self.dataType.jsonValue(),
-                "nullable": self.nullable,
-                "metadata": self.metadata}
-
-    @classmethod
-    def fromJson(cls, json):
-        return StructField(json["name"],
-                           _parse_datatype_json_value(json["type"]),
-                           json["nullable"],
-                           json["metadata"])
-
-
-class StructType(DataType):
-    """Struct type, consisting of a list of :class:`StructField`.
-
-    This is the data type representing a :class:`Row`.
-    """
-
-    def __init__(self, fields):
-        """
-        >>> struct1 = StructType([StructField("f1", StringType(), True)])
-        >>> struct2 = StructType([StructField("f1", StringType(), True)])
-        >>> struct1 == struct2
-        True
-        >>> struct1 = StructType([StructField("f1", StringType(), True)])
-        >>> struct2 = StructType([StructField("f1", StringType(), True),
-        ...     StructField("f2", IntegerType(), False)])
-        >>> struct1 == struct2
-        False
-        """
-        assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType"
-        self.fields = fields
-
-    def simpleString(self):
-        return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields))
-
-    def __repr__(self):
-        return ("StructType(List(%s))" %
-                ",".join(str(field) for field in self.fields))
-
-    def jsonValue(self):
-        return {"type": self.typeName(),
-                "fields": [f.jsonValue() for f in self.fields]}
-
-    @classmethod
-    def fromJson(cls, json):
-        return StructType([StructField.fromJson(f) for f in json["fields"]])
-
-
-class UserDefinedType(DataType):
-    """User-defined type (UDT).
-
-    .. note:: WARN: Spark Internal Use Only
-    """
-
-    @classmethod
-    def typeName(cls):
-        return cls.__name__.lower()
-
-    @classmethod
-    def sqlType(cls):
-        """
-        Underlying SQL storage type for this UDT.
-        """
-        raise NotImplementedError("UDT must implement sqlType().")
-
-    @classmethod
-    def module(cls):
-        """
-        The Python module of the UDT.
-        """
-        raise NotImplementedError("UDT must implement module().")
-
-    @classmethod
-    def scalaUDT(cls):
-        """
-        The class name of the paired Scala UDT.
-        """
-        raise NotImplementedError("UDT must have a paired Scala UDT.")
-
-    def serialize(self, obj):
-        """
-        Converts a user-type object into a SQL datum.
-        """
-        raise NotImplementedError("UDT must implement serialize().")
-
-    def deserialize(self, datum):
-        """
-        Converts a SQL datum into a user-type object.
-        """
-        raise NotImplementedError("UDT must implement deserialize().")
-
-    def simpleString(self):
-        return 'udt'
-
-    def json(self):
-        return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)
-
-    def jsonValue(self):
-        schema = {
-            "type": "udt",
-            "class": self.scalaUDT(),
-            "pyClass": "%s.%s" % (self.module(), type(self).__name__),
-            "sqlType": self.sqlType().jsonValue()
-        }
-        return schema
-
-    @classmethod
-    def fromJson(cls, json):
-        pyUDT = json["pyClass"]
-        split = pyUDT.rfind(".")
-        pyModule = pyUDT[:split]
-        pyClass = pyUDT[split+1:]
-        m = __import__(pyModule, globals(), locals(), [pyClass])
-        UDT = getattr(m, pyClass)
-        return UDT()
-
-    def __eq__(self, other):
-        return type(self) == type(other)
-
-
-_atomic_types = [StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType,
-                 ByteType, ShortType, IntegerType, LongType, DateType, TimestampType]
-_all_atomic_types = dict((t.typeName(), t) for t in _atomic_types)
-_all_complex_types = dict((v.typeName(), v)
-                          for v in [ArrayType, MapType, StructType])
-
-
-def _parse_datatype_json_string(json_string):
-    """Parses the given data type JSON string.
-    >>> import pickle
-    >>> def check_datatype(datatype):
-    ...     pickled = pickle.loads(pickle.dumps(datatype))
-    ...     assert datatype == pickled
-    ...     scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json())
-    ...     python_datatype = _parse_datatype_json_string(scala_datatype.json())
-    ...     assert datatype == python_datatype
-    >>> for cls in _all_atomic_types.values():
-    ...     check_datatype(cls())
-
-    >>> # Simple ArrayType.
-    >>> simple_arraytype = ArrayType(StringType(), True)
-    >>> check_datatype(simple_arraytype)
-
-    >>> # Simple MapType.
-    >>> simple_maptype = MapType(StringType(), LongType())
-    >>> check_datatype(simple_maptype)
-
-    >>> # Simple StructType.
-    >>> simple_structtype = StructType([
-    ...     StructField("a", DecimalType(), False),
-    ...     StructField("b", BooleanType(), True),
-    ...     StructField("c", LongType(), True),
-    ...     StructField("d", BinaryType(), False)])
-    >>> check_datatype(simple_structtype)
-
-    >>> # Complex StructType.
-    >>> complex_structtype = StructType([
-    ...     StructField("simpleArray", simple_arraytype, True),
-    ...     StructField("simpleMap", simple_maptype, True),
-    ...     StructField("simpleStruct", simple_structtype, True),
-    ...     StructField("boolean", BooleanType(), False),
-    ...     StructField("withMeta", DoubleType(), False, {"name": "age"})])
-    >>> check_datatype(complex_structtype)
-
-    >>> # Complex ArrayType.
-    >>> complex_arraytype = ArrayType(complex_structtype, True)
-    >>> check_datatype(complex_arraytype)
-
-    >>> # Complex MapType.
-    >>> complex_maptype = MapType(complex_structtype,
-    ...                           complex_arraytype, False)
-    >>> check_datatype(complex_maptype)
-
-    >>> check_datatype(ExamplePointUDT())
-    >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
-    ...                                   StructField("point", ExamplePointUDT(), False)])
-    >>> check_datatype(structtype_with_udt)
-    """
-    return _parse_datatype_json_value(json.loads(json_string))
-
-
-_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)")
-
-
-def _parse_datatype_json_value(json_value):
-    if not isinstance(json_value, dict):
-        if json_value in _all_atomic_types.keys():
-            return _all_atomic_types[json_value]()
-        elif json_value == 'decimal':
-            return DecimalType()
-        elif _FIXED_DECIMAL.match(json_value):
-            m = _FIXED_DECIMAL.match(json_value)
-            return DecimalType(int(m.group(1)), int(m.group(2)))
-        else:
-            raise ValueError("Could not parse datatype: %s" % json_value)
-    else:
-        tpe = json_value["type"]
-        if tpe in _all_complex_types:
-            return _all_complex_types[tpe].fromJson(json_value)
-        elif tpe == 'udt':
-            return UserDefinedType.fromJson(json_value)
-        else:
-            raise ValueError("not supported type: %s" % tpe)
-
-
-# Mapping Python types to Spark SQL DataType
-_type_mappings = {
-    type(None): NullType,
-    bool: BooleanType,
-    int: LongType,
-    float: DoubleType,
-    str: StringType,
-    bytearray: BinaryType,
-    decimal.Decimal: DecimalType,
-    datetime.date: DateType,
-    datetime.datetime: TimestampType,
-    datetime.time: TimestampType,
-}
-
-if sys.version < "3":
-    _type_mappings.update({
-        unicode: StringType,
-        long: LongType,
-    })
-
-
-def _infer_type(obj):
-    """Infer the DataType from obj
-
-    >>> p = ExamplePoint(1.0, 2.0)
-    >>> _infer_type(p)
-    ExamplePointUDT
-    """
-    if obj is None:
-        return NullType()
-
-    if hasattr(obj, '__UDT__'):
-        return obj.__UDT__
-
-    dataType = _type_mappings.get(type(obj))
-    if dataType is not None:
-        return dataType()
-
-    if isinstance(obj, dict):
-        for key, value in obj.items():
-            if key is not None and value is not None:
-                return MapType(_infer_type(key), _infer_type(value), True)
-        else:
-            return MapType(NullType(), NullType(), True)
-    elif isinstance(obj, (list, array)):
-        for v in obj:
-            if v is not None:
-                return ArrayType(_infer_type(obj[0]), True)
-        else:
-            return ArrayType(NullType(), True)
-    else:
-        try:
-            return _infer_schema(obj)
-        except TypeError:
-            raise TypeError("not supported type: %s" % type(obj))
-
-
-def _infer_schema(row):
-    """Infer the schema from dict/namedtuple/object"""
-    if isinstance(row, dict):
-        items = sorted(row.items())
-
-    elif isinstance(row, (tuple, list)):
-        if hasattr(row, "__fields__"):  # Row
-            items = zip(row.__fields__, tuple(row))
-        elif hasattr(row, "_fields"):  # namedtuple
-            items = zip(row._fields, tuple(row))
-        else:
-            names = ['_%d' % i for i in range(1, len(row) + 1)]
-            items = zip(names, row)
-
-    elif hasattr(row, "__dict__"):  # object
-        items = sorted(row.__dict__.items())
-
-    else:
-        raise TypeError("Can not infer schema for type: %s" % type(row))
-
-    fields = [StructField(k, _infer_type(v), True) for k, v in items]
-    return StructType(fields)
-
-
-def _need_python_to_sql_conversion(dataType):
-    """
-    Checks whether we need python to sql conversion for the given type.
-    For now, only UDTs need this conversion.
-
-    >>> _need_python_to_sql_conversion(DoubleType())
-    False
-    >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
-    ...                       StructField("values", ArrayType(DoubleType(), False), False)])
-    >>> _need_python_to_sql_conversion(schema0)
-    False
-    >>> _need_python_to_sql_conversion(ExamplePointUDT())
-    True
-    >>> schema1 = ArrayType(ExamplePointUDT(), False)
-    >>> _need_python_to_sql_conversion(schema1)
-    True
-    >>> schema2 = StructType([StructField("label", DoubleType(), False),
-    ...                       StructField("point", ExamplePointUDT(), False)])
-    >>> _need_python_to_sql_conversion(schema2)
-    True
-    """
-    if isinstance(dataType, StructType):
-        return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields])
-    elif isinstance(dataType, ArrayType):
-        return _need_python_to_sql_conversion(dataType.elementType)
-    elif isinstance(dataType, MapType):
-        return _need_python_to_sql_conversion(dataType.keyType) or \
-            _need_python_to_sql_conversion(dataType.valueType)
-    elif isinstance(dataType, UserDefinedType):
-        return True
-    else:
-        return False
-
-
-def _python_to_sql_converter(dataType):
-    """
-    Returns a converter that converts a Python object into a SQL datum for the given type.
-
-    >>> conv = _python_to_sql_converter(DoubleType())
-    >>> conv(1.0)
-    1.0
-    >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False))
-    >>> conv([1.0, 2.0])
-    [1.0, 2.0]
-    >>> conv = _python_to_sql_converter(ExamplePointUDT())
-    >>> conv(ExamplePoint(1.0, 2.0))
-    [1.0, 2.0]
-    >>> schema = StructType([StructField("label", DoubleType(), False),
-    ...                      StructField("point", ExamplePointUDT(), False)])
-    >>> conv = _python_to_sql_converter(schema)
-    >>> conv((1.0, ExamplePoint(1.0, 2.0)))
-    (1.0, [1.0, 2.0])
-    """
-    if not _need_python_to_sql_conversion(dataType):
-        return lambda x: x
-
-    if isinstance(dataType, StructType):
-        names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
-        converters = [_python_to_sql_converter(t) for t in types]
-
-        def converter(obj):
-            if isinstance(obj, dict):
-                return tuple(c(obj.get(n)) for n, c in zip(names, converters))
-            elif isinstance(obj, tuple):
-                if hasattr(obj, "__fields__") or hasattr(obj, "_fields"):
-                    return tuple(c(v) for c, v in zip(converters, obj))
-                elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):  # k-v pairs
-                    d = dict(obj)
-                    return tuple(c(d.get(n)) for n, c in zip(names, converters))
-                else:
-                    return tuple(c(v) for c, v in zip(converters, obj))
-            else:
-                raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
-        return converter
-    elif isinstance(dataType, ArrayType):
-        element_converter = _python_to_sql_converter(dataType.elementType)
-        return lambda a: [element_converter(v) for v in a]
-    elif isinstance(dataType, MapType):
-        key_converter = _python_to_sql_converter(dataType.keyType)
-        value_converter = _python_to_sql_converter(dataType.valueType)
-        return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
-    elif isinstance(dataType, UserDefinedType):
-        return lambda obj: dataType.serialize(obj)
-    else:
-        raise ValueError("Unexpected type %r" % dataType)
-
-
-def _has_nulltype(dt):
-    """ Return whether there is NullType in `dt` or not """
-    if isinstance(dt, StructType):
-        return any(_has_nulltype(f.dataType) for f in dt.fields)
-    elif isinstance(dt, ArrayType):
-        return _has_nulltype(dt.elementType)
-    elif isinstance(dt, MapType):
-        return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType)
-    else:
-        return isinstance(dt, NullType)
-
-
-def _merge_type(a, b):
-    if isinstance(a, NullType):
-        return b
-    elif isinstance(b, NullType):
-        return a
-    elif type(a) is not type(b):
-        # TODO: type cast (such as int -> long)
-        raise TypeError("Can not merge type %s and %s" % (type(a), type(b)))
-
-    # same type
-    if isinstance(a, StructType):
-        nfs = dict((f.name, f.dataType) for f in b.fields)
-        fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType())))
-                  for f in a.fields]
-        names = set([f.name for f in fields])
-        for n in nfs:
-            if n not in names:
-                fields.append(StructField(n, nfs[n]))
-        return StructType(fields)
-
-    elif isinstance(a, ArrayType):
-        return ArrayType(_merge_type(a.elementType, b.elementType), True)
-
-    elif isinstance(a, MapType):
-        return MapType(_merge_type(a.keyType, b.keyType),
-                       _merge_type(a.valueType, b.valueType),
-                       True)
-    else:
-        return a
-
-
-def _need_converter(dataType):
-    if isinstance(dataType, StructType):
-        return True
-    elif isinstance(dataType, ArrayType):
-        return _need_converter(dataType.elementType)
-    elif isinstance(dataType, MapType):
-        return _need_converter(dataType.keyType) or _need_converter(dataType.valueType)
-    elif isinstance(dataType, NullType):
-        return True
-    else:
-        return False
-
-
-def _create_converter(dataType):
-    """Create an converter to drop the names of fields in obj """
-    if not _need_converter(dataType):
-        return lambda x: x
-
-    if isinstance(dataType, ArrayType):
-        conv = _create_converter(dataType.elementType)
-        return lambda row: [conv(v) for v in row]
-
-    elif isinstance(dataType, MapType):
-        kconv = _create_converter(dataType.keyType)
-        vconv = _create_converter(dataType.valueType)
-        return lambda row: dict((kconv(k), vconv(v)) for k, v in row.items())
-
-    elif isinstance(dataType, NullType):
-        return lambda x: None
-
-    elif not isinstance(dataType, StructType):
-        return lambda x: x
-
-    # dataType must be StructType
-    names = [f.name for f in dataType.fields]
-    converters = [_create_converter(f.dataType) for f in dataType.fields]
-    convert_fields = any(_need_converter(f.dataType) for f in dataType.fields)
-
-    def convert_struct(obj):
-        if obj is None:
-            return
-
-        if isinstance(obj, (tuple, list)):
-            if convert_fields:
-                return tuple(conv(v) for v, conv in zip(obj, converters))
-            else:
-                return tuple(obj)
-
-        if isinstance(obj, dict):
-            d = obj
-        elif hasattr(obj, "__dict__"):  # object
-            d = obj.__dict__
-        else:
-            raise TypeError("Unexpected obj type: %s" % type(obj))
-
-        if convert_fields:
-            return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
-        else:
-            return tuple([d.get(name) for name in names])
-
-    return convert_struct
-
-
-_BRACKETS = {'(': ')', '[': ']', '{': '}'}
-
-
-def _split_schema_abstract(s):
-    """
-    split the schema abstract into fields
-
-    >>> _split_schema_abstract("a b  c")
-    ['a', 'b', 'c']
-    >>> _split_schema_abstract("a(a b)")
-    ['a(a b)']
-    >>> _split_schema_abstract("a b[] c{a b}")
-    ['a', 'b[]', 'c{a b}']
-    >>> _split_schema_abstract(" ")
-    []
-    """
-
-    r = []
-    w = ''
-    brackets = []
-    for c in s:
-        if c == ' ' and not brackets:
-            if w:
-                r.append(w)
-            w = ''
-        else:
-            w += c
-            if c in _BRACKETS:
-                brackets.append(c)
-            elif c in _BRACKETS.values():
-                if not brackets or c != _BRACKETS[brackets.pop()]:
-                    raise ValueError("unexpected " + c)
-
-    if brackets:
-        raise ValueError("brackets not closed: %s" % brackets)
-    if w:
-        r.append(w)
-    return r
-
-
-def _parse_field_abstract(s):
-    """
-    Parse a field in schema abstract
-
-    >>> _parse_field_abstract("a")
-    StructField(a,NullType,true)
-    >>> _parse_field_abstract("b(c d)")
-    StructField(b,StructType(...c,NullType,true),StructField(d...
-    >>> _parse_field_abstract("a[]")
-    StructField(a,ArrayType(NullType,true),true)
-    >>> _parse_field_abstract("a{[]}")
-    StructField(a,MapType(NullType,ArrayType(NullType,true),true),true)
-    """
-    if set(_BRACKETS.keys()) & set(s):
-        idx = min((s.index(c) for c in _BRACKETS if c in s))
-        name = s[:idx]
-        return StructField(name, _parse_schema_abstract(s[idx:]), True)
-    else:
-        return StructField(s, NullType(), True)
-
-
-def _parse_schema_abstract(s):
-    """
-    parse abstract into schema
-
-    >>> _parse_schema_abstract("a b  c")
-    StructType...a...b...c...
-    >>> _parse_schema_abstract("a[b c] b{}")
-    StructType...a,ArrayType...b...c...b,MapType...
-    >>> _parse_schema_abstract("c{} d{a b}")
-    StructType...c,MapType...d,MapType...a...b...
-    >>> _parse_schema_abstract("a b(t)").fields[1]
-    StructField(b,StructType(List(StructField(t,NullType,true))),true)
-    """
-    s = s.strip()
-    if not s:
-        return NullType()
-
-    elif s.startswith('('):
-        return _parse_schema_abstract(s[1:-1])
-
-    elif s.startswith('['):
-        return ArrayType(_parse_schema_abstract(s[1:-1]), True)
-
-    elif s.startswith('{'):
-        return MapType(NullType(), _parse_schema_abstract(s[1:-1]))
-
-    parts = _split_schema_abstract(s)
-    fields = [_parse_field_abstract(p) for p in parts]
-    return StructType(fields)
-
-
-def _infer_schema_type(obj, dataType):
-    """
-    Fill the dataType with types inferred from obj
-
-    >>> schema = _parse_schema_abstract("a b c d")
-    >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
-    >>> _infer_schema_type(row, schema)
-    StructType...LongType...DoubleType...StringType...DateType...
-    >>> row = [[1], {"key": (1, 2.0)}]
-    >>> schema = _parse_schema_abstract("a[] b{c d}")
-    >>> _infer_schema_type(row, schema)
-    StructType...a,ArrayType...b,MapType(StringType,...c,LongType...
-    """
-    if isinstance(dataType, NullType):
-        return _infer_type(obj)
-
-    if not obj:
-        return NullType()
-
-    if isinstance(dataType, ArrayType):
-        eType = _infer_schema_type(obj[0], dataType.elementType)
-        return ArrayType(eType, True)
-
-    elif isinstance(dataType, MapType):
-        k, v = next(iter(obj.items()))
-        return MapType(_infer_schema_type(k, dataType.keyType),
-                       _infer_schema_type(v, dataType.valueType))
-
-    elif isinstance(dataType, StructType):
-        fs = dataType.fields
-        assert len(fs) == len(obj), \
-            "Obj(%s) have different length with fields(%s)" % (obj, fs)
-        fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True)
-                  for o, f in zip(obj, fs)]
-        return StructType(fields)
-
-    else:
-        raise TypeError("Unexpected dataType: %s" % type(dataType))
-
-
-_acceptable_types = {
-    BooleanType: (bool,),
-    ByteType: (int, long),
-    ShortType: (int, long),
-    IntegerType: (int, long),
-    LongType: (int, long),
-    FloatType: (float,),
-    DoubleType: (float,),
-    DecimalType: (decimal.Decimal,),
-    StringType: (str, unicode),
-    BinaryType: (bytearray,),
-    DateType: (datetime.date, datetime.datetime),
-    TimestampType: (datetime.datetime,),
-    ArrayType: (list, tuple, array),
-    MapType: (dict,),
-    StructType: (tuple, list),
-}
-
-
-def _verify_type(obj, dataType):
-    """
-    Verify the type of obj against dataType, raising an exception
-    if they do not match.
-
-    >>> _verify_type(None, StructType([]))
-    >>> _verify_type("", StringType())
-    >>> _verify_type(0, LongType())
-    >>> _verify_type(list(range(3)), ArrayType(ShortType()))
-    >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL
-    Traceback (most recent call last):
-        ...
-    TypeError:...
-    >>> _verify_type({}, MapType(StringType(), IntegerType()))
-    >>> _verify_type((), StructType([]))
-    >>> _verify_type([], StructType([]))
-    >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL
-    Traceback (most recent call last):
-        ...
-    ValueError:...
-    >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
-    >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL
-    Traceback (most recent call last):
-        ...
-    ValueError:...
-    """
-    # all objects are nullable
-    if obj is None:
-        return
-
-    if isinstance(dataType, UserDefinedType):
-        if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
-            raise ValueError("%r is not an instance of type %r" % (obj, dataType))
-        _verify_type(dataType.serialize(obj), dataType.sqlType())
-        return
-
-    _type = type(dataType)
-    assert _type in _acceptable_types, "unknown datatype: %s" % dataType
-
-    # subclasses of these types can not be deserialized in the JVM
-    if type(obj) not in _acceptable_types[_type]:
-        raise TypeError("%s can not accept object in type %s"
-                        % (dataType, type(obj)))
-
-    if isinstance(dataType, ArrayType):
-        for i in obj:
-            _verify_type(i, dataType.elementType)
-
-    elif isinstance(dataType, MapType):
-        for k, v in obj.items():
-            _verify_type(k, dataType.keyType)
-            _verify_type(v, dataType.valueType)
-
-    elif isinstance(dataType, StructType):
-        if len(obj) != len(dataType.fields):
-            raise ValueError("Length of object (%d) does not match with "
-                             "length of fields (%d)" % (len(obj), len(dataType.fields)))
-        for v, f in zip(obj, dataType.fields):
-            _verify_type(v, f.dataType)
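One detail worth calling out, sketched below under the same import assumption: the check compares exact types, so subclasses are rejected; for example bool, although a subclass of int, does not pass for LongType.

from pyspark.sql.types import LongType, _verify_type

_verify_type(1, LongType())         # passes silently
try:
    _verify_type(True, LongType())  # type(True) is bool, which is not in (int, long)
    print("accepted")
except TypeError as e:
    print("rejected: %s" % e)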
-
-_cached_cls = weakref.WeakValueDictionary()
-
-
-def _restore_object(dataType, obj):
-    """ Restore object during unpickling. """
-    # use id(dataType) as key to speed up lookup in dict
-    # Because of batched pickling, dataType will be the
-    # same object in most cases.
-    k = id(dataType)
-    cls = _cached_cls.get(k)
-    if cls is None or cls.__datatype is not dataType:
-        # use dataType as key to avoid creating multiple classes
-        cls = _cached_cls.get(dataType)
-        if cls is None:
-            cls = _create_cls(dataType)
-            _cached_cls[dataType] = cls
-        cls.__datatype = dataType
-        _cached_cls[k] = cls
-    return cls(obj)
-
-
-def _create_object(cls, v):
-    """ Create an customized object with class `cls`. """
-    # datetime.date would be deserialized as datetime.datetime
-    # from java type, so we need to set it back.
-    if cls is datetime.date and isinstance(v, datetime.datetime):
-        return v.date()
-    return cls(v) if v is not None else v
-
-
-def _create_getter(dt, i):
-    """ Create a getter for item `i` with schema """
-    cls = _create_cls(dt)
-
-    def getter(self):
-        return _create_object(cls, self[i])
-
-    return getter
-
-
-def _has_struct_or_date(dt):
-    """Return whether `dt` is or has StructType/DateType in it"""
-    if isinstance(dt, StructType):
-        return True
-    elif isinstance(dt, ArrayType):
-        return _has_struct_or_date(dt.elementType)
-    elif isinstance(dt, MapType):
-        return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType)
-    elif isinstance(dt, DateType):
-        return True
-    elif isinstance(dt, UserDefinedType):
-        return True
-    return False
-
-
-def _create_properties(fields):
-    """Create properties according to fields"""
-    ps = {}
-    for i, f in enumerate(fields):
-        name = f.name
-        if (name.startswith("__") and name.endswith("__")
-                or keyword.iskeyword(name)):
-            warnings.warn("field name %s can not be accessed in Python,"
-                          "use position to access it instead" % name)
-        if _has_struct_or_date(f.dataType):
-            # delay creating the object until it is accessed
-            getter = _create_getter(f.dataType, i)
-        else:
-            getter = itemgetter(i)
-        ps[name] = property(getter)
-    return ps
-
-
-def _create_cls(dataType):
-    """
-    Create a class from a dataType
-
-    The created class is similar to namedtuple, but can have a nested schema.
-
-    >>> schema = _parse_schema_abstract("a b c")
-    >>> row = (1, 1.0, "str")
-    >>> schema = _infer_schema_type(row, schema)
-    >>> obj = _create_cls(schema)(row)
-    >>> import pickle
-    >>> pickle.loads(pickle.dumps(obj))
-    Row(a=1, b=1.0, c='str')
-
-    >>> row = [[1], {"key": (1, 2.0)}]
-    >>> schema = _parse_schema_abstract("a[] b{c d}")
-    >>> schema = _infer_schema_type(row, schema)
-    >>> obj = _create_cls(schema)(row)
-    >>> pickle.loads(pickle.dumps(obj))
-    Row(a=[1], b={'key': Row(c=1, d=2.0)})
-    >>> pickle.loads(pickle.dumps(obj.a))
-    [1]
-    >>> pickle.loads(pickle.dumps(obj.b))
-    {'key': Row(c=1, d=2.0)}
-    """
-
-    if isinstance(dataType, ArrayType):
-        cls = _create_cls(dataType.elementType)
-
-        def List(l):
-            if l is None:
-                return
-            return [_create_object(cls, v) for v in l]
-
-        return List
-
-    elif isinstance(dataType, MapType):
-        kcls = _create_cls(dataType.keyType)
-        vcls = _create_cls(dataType.valueType)
-
-        def Dict(d):
-            if d is None:
-                return
-            return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items())
-
-        return Dict
-
-    elif isinstance(dataType, DateType):
-        return datetime.date
-
-    elif isinstance(dataType, UserDefinedType):
-        return lambda datum: dataType.deserialize(datum)
-
-    elif not isinstance(dataType, StructType):
-        # no wrapper for atomic types
-        return lambda x: x
-
-    class Row(tuple):
-
-        """ Row in DataFrame """
-        __datatype = dataType
-        __fields__ = tuple(f.name for f in dataType.fields)
-        __slots__ = ()
-
-        # create properties for fast access
-        locals().update(_create_properties(dataType.fields))
-
-        def asDict(self):
-            """ Return as a dict """
-            return dict((n, getattr(self, n)) for n in self.__fields__)
-
-        def __repr__(self):
-            # accessing the fields creates nested objects lazily and calls their __repr__
-            return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n))
-                                          for n in self.__fields__))
-
-        def __reduce__(self):
-            return (_restore_object, (self.__datatype, tuple(self)))
-
-    return Row
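A small check of the non-struct branches, again assuming the private helpers are importable: atomic types get no wrapper, while DateType values that come back from the JVM as datetime.datetime are converted to dates by _create_object.

import datetime
from pyspark.sql.types import DateType, StringType, _create_cls, _create_object

print(_create_cls(StringType())("plain"))  # atomic types pass through unchanged

date_cls = _create_cls(DateType())         # this is just datetime.date
print(_create_object(date_cls, datetime.datetime(2015, 5, 29, 23, 13)))  # 2015-05-29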
-
-
-def _create_row(fields, values):
-    row = Row(*values)
-    row.__fields__ = fields
-    return row
-
-
-class Row(tuple):
-
-    """
-    A row in L{DataFrame}. The fields in it can be accessed like attributes.
-
-    Row can be used to create a row object by using named arguments;
-    the fields will be sorted by name.
-
-    >>> row = Row(name="Alice", age=11)
-    >>> row
-    Row(age=11, name='Alice')
-    >>> row.name, row.age
-    ('Alice', 11)
-
-    Row can also be used to create another Row-like class, which can
-    then be used to create Row objects, for example:
-
-    >>> Person = Row("name", "age")
-    >>> Person
-    <Row(name, age)>
-    >>> Person("Alice", 11)
-    Row(name='Alice', age=11)
-    """
-
-    def __new__(self, *args, **kwargs):
-        if args and kwargs:
-            raise ValueError("Can not use both args "
-                             "and kwargs to create Row")
-        if args:
-            # create row class or objects
-            return tuple.__new__(self, args)
-
-        elif kwargs:
-            # create row objects
-            names = sorted(kwargs.keys())
-            row = tuple.__new__(self, [kwargs[n] for n in names])
-            row.__fields__ = names
-            return row
-
-        else:
-            raise ValueError("No args or kwargs")
-
-    def asDict(self):
-        """
-        Return as a dict
-        """
-        if not hasattr(self, "__fields__"):
-            raise TypeError("Cannot convert a Row class into dict")
-        return dict(zip(self.__fields__, self))
-
-    # let the object act like a class
-    def __call__(self, *args):
-        """create new Row object"""
-        return _create_row(self, args)
-
-    def __getattr__(self, item):
-        if item.startswith("__"):
-            raise AttributeError(item)
-        try:
-            # this can be slow when there are many fields,
-            # but it is not used in normal cases
-            idx = self.__fields__.index(item)
-            return self[idx]
-        except IndexError:
-            raise AttributeError(item)
-        except ValueError:
-            raise AttributeError(item)
-
-    def __reduce__(self):
-        """Returns a tuple so Python knows how to pickle Row."""
-        if hasattr(self, "__fields__"):
-            return (_create_row, (self.__fields__, tuple(self)))
-        else:
-            return tuple.__reduce__(self)
-
-    def __repr__(self):
-        """Printable representation of Row used in Python REPL."""
-        if hasattr(self, "__fields__"):
-            return "Row(%s)" % ", ".join("%s=%r" % (k, v)
-                                         for k, v in zip(self.__fields__, tuple(self)))
-        else:
-            return "<Row(%s)>" % ", ".join(self)
-
-
-class DateConverter(object):
-    def can_convert(self, obj):
-        return isinstance(obj, datetime.date)
-
-    def convert(self, obj, gateway_client):
-        Date = JavaClass("java.sql.Date", gateway_client)
-        return Date.valueOf(obj.strftime("%Y-%m-%d"))
-
-
-class DatetimeConverter(object):
-    def can_convert(self, obj):
-        return isinstance(obj, datetime.datetime)
-
-    def convert(self, obj, gateway_client):
-        Timestamp = JavaClass("java.sql.Timestamp", gateway_client)
-        return Timestamp(int(time.mktime(obj.timetuple())) * 1000 + obj.microsecond // 1000)
-
-
-# datetime is a subclass of date, so we must register DatetimeConverter first
-register_input_converter(DatetimeConverter())
-register_input_converter(DateConverter())
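The registration order matters because a datetime.datetime instance also satisfies DateConverter.can_convert; a quick standalone check of that subclass relationship:

import datetime

d = datetime.datetime(2015, 5, 29, 23, 13, 39)
print(isinstance(d, datetime.date))      # True  -- DateConverter would also match
print(isinstance(d, datetime.datetime))  # True  -- so DatetimeConverter must come first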
-
-
-def _test():
-    import doctest
-    from pyspark.context import SparkContext
-    # let doctest run in pyspark.sql.types, so DataTypes can be picklable
-    import pyspark.sql.types
-    from pyspark.sql import Row, SQLContext
-    from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
-    globs = pyspark.sql.types.__dict__.copy()
-    sc = SparkContext('local[4]', 'PythonTest')
-    globs['sc'] = sc
-    globs['sqlContext'] = SQLContext(sc)
-    globs['ExamplePoint'] = ExamplePoint
-    globs['ExamplePointUDT'] = ExamplePointUDT
-    (failure_count, test_count) = doctest.testmod(
-        pyspark.sql.types, globs=globs, optionflags=doctest.ELLIPSIS)
-    globs['sc'].stop()
-    if failure_count:
-        exit(-1)
-
-
-if __name__ == "__main__":
-    _test()

