You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2014/07/29 21:31:43 UTC

git commit: [SPARK-2674] [SQL] [PySpark] support datetime type for SchemaRDD

Repository: spark
Updated Branches:
  refs/heads/master e3643485d -> f0d880e28


[SPARK-2674] [SQL] [PySpark] support datetime type for SchemaRDD

Python datetime and time objects are converted into java.util.Calendar during serialization; inferSchema() then converts them into java.sql.Timestamp.

In javaToPython(), a Timestamp is converted into a Calendar, which becomes a Python datetime after pickling.

Author: Davies Liu <da...@gmail.com>

Closes #1601 from davies/date and squashes the following commits:

f0599b0 [Davies Liu] remove tests for sets and tuple in sql, fix list of list
c9d607a [Davies Liu] convert datetype for runtime
709d40d [Davies Liu] remove brackets
96db384 [Davies Liu] support datetime type for SchemaRDD


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f0d880e2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f0d880e2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f0d880e2

Branch: refs/heads/master
Commit: f0d880e288eba97c86dceb1b5edab4f3a935943b
Parents: e364348
Author: Davies Liu <da...@gmail.com>
Authored: Tue Jul 29 12:31:39 2014 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Tue Jul 29 12:31:39 2014 -0700

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala |  4 +-
 python/pyspark/sql.py                           | 22 +++++-----
 .../scala/org/apache/spark/sql/SQLContext.scala | 40 +++++++++++++++--
 .../scala/org/apache/spark/sql/SchemaRDD.scala  | 46 ++++++++------------
 4 files changed, 68 insertions(+), 44 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f0d880e2/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index d87783e..0d8453f 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -550,11 +550,11 @@ private[spark] object PythonRDD extends Logging {
   def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
     pyRDD.rdd.mapPartitions { iter =>
       val unpickle = new Unpickler
-      // TODO: Figure out why flatMap is necessay for pyspark
       iter.flatMap { row =>
         unpickle.loads(row) match {
+          // in case of objects are pickled in batch mode
           case objs: java.util.ArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
-          // Incase the partition doesn't have a collection
+          // not in batch mode
           case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
         }
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/f0d880e2/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index cb83e89..a6b3277 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -47,12 +47,14 @@ class SQLContext:
             ...
         ValueError:...
 
-        >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
-        ... "boolean" : True}])
+        >>> from datetime import datetime
+        >>> allTypes = sc.parallelize([{"int": 1, "string": "string", "double": 1.0, "long": 1L,
+        ... "boolean": True, "time": datetime(2010, 1, 1, 1, 1, 1), "dict": {"a": 1},
+        ... "list": [1, 2, 3]}])
         >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
-        ... x.boolean))
+        ... x.boolean, x.time, x.dict["a"], x.list))
         >>> srdd.collect()[0]
-        (1, u'string', 1.0, 1, True)
+        (1, u'string', 1.0, 1, True, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, [1, 2, 3])
         """
         self._sc = sparkContext
         self._jsc = self._sc._jsc
@@ -88,13 +90,13 @@ class SQLContext:
 
         >>> from array import array
         >>> srdd = sqlCtx.inferSchema(nestedRdd1)
-        >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
-        ...                    {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
+        >>> srdd.collect() == [{"f1" : [1, 2], "f2" : {"row1" : 1.0}},
+        ...                    {"f1" : [2, 3], "f2" : {"row2" : 2.0}}]
         True
 
         >>> srdd = sqlCtx.inferSchema(nestedRdd2)
-        >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
-        ...                    {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
+        >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : [1, 2]},
+        ...                    {"f1" : [[2, 3], [3, 4]], "f2" : [2, 3]}]
         True
         """
         if (rdd.__class__ is SchemaRDD):
@@ -509,8 +511,8 @@ def _test():
         {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}},
         {"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}])
     globs['nestedRdd2'] = sc.parallelize([
-        {"f1": [[1, 2], [2, 3]], "f2": set([1, 2]), "f3": (1, 2)},
-        {"f1": [[2, 3], [3, 4]], "f2": set([2, 3]), "f3": (2, 3)}])
+        {"f1": [[1, 2], [2, 3]], "f2": [1, 2]},
+        {"f1": [[2, 3], [3, 4]], "f2": [2, 3]}])
     (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
     globs['sc'].stop()
     if failure_count:

http://git-wip-us.apache.org/repos/asf/spark/blob/f0d880e2/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 4abd899..c178dad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -352,8 +352,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
       case c: java.lang.Long => LongType
       case c: java.lang.Double => DoubleType
       case c: java.lang.Boolean => BooleanType
+      case c: java.math.BigDecimal => DecimalType
+      case c: java.sql.Timestamp => TimestampType
+      case c: java.util.Calendar => TimestampType
       case c: java.util.List[_] => ArrayType(typeFor(c.head))
-      case c: java.util.Set[_] => ArrayType(typeFor(c.head))
       case c: java.util.Map[_, _] =>
         val (key, value) = c.head
         MapType(typeFor(key), typeFor(value))
@@ -362,11 +364,43 @@ class SQLContext(@transient val sparkContext: SparkContext)
         ArrayType(typeFor(elem))
       case c => throw new Exception(s"Object of type $c cannot be used")
     }
-    val schema = rdd.first().map { case (fieldName, obj) =>
+    val firstRow = rdd.first()
+    val schema = firstRow.map { case (fieldName, obj) =>
       AttributeReference(fieldName, typeFor(obj), true)()
     }.toSeq
 
-    val rowRdd = rdd.mapPartitions { iter =>
+    def needTransform(obj: Any): Boolean = obj match {
+      case c: java.util.List[_] => true
+      case c: java.util.Map[_, _] => true
+      case c if c.getClass.isArray => true
+      case c: java.util.Calendar => true
+      case c => false
+    }
+
+    // convert JList, JArray into Seq, convert JMap into Map
+    // convert Calendar into Timestamp
+    def transform(obj: Any): Any = obj match {
+      case c: java.util.List[_] => c.map(transform).toSeq
+      case c: java.util.Map[_, _] => c.map {
+        case (key, value) => (key, transform(value))
+      }.toMap
+      case c if c.getClass.isArray =>
+        c.asInstanceOf[Array[_]].map(transform).toSeq
+      case c: java.util.Calendar =>
+        new java.sql.Timestamp(c.getTime().getTime())
+      case c => c
+    }
+
+    val need = firstRow.exists {case (key, value) => needTransform(value)}
+    val transformed = if (need) {
+      rdd.mapPartitions { iter =>
+        iter.map {
+          m => m.map {case (key, value) => (key, transform(value))}
+        }
+      }
+    } else rdd
+
+    val rowRdd = transformed.mapPartitions { iter =>
       iter.map { map =>
         new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/f0d880e2/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 31d27bb..019ff9d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
-import org.apache.spark.sql.catalyst.types.{ArrayType, BooleanType, StructType}
+import org.apache.spark.sql.catalyst.types.{DataType, ArrayType, BooleanType, StructType, MapType}
 import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
 import org.apache.spark.api.java.JavaRDD
 
@@ -376,39 +376,27 @@ class SchemaRDD(
    * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
    */
   private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+    def toJava(obj: Any, dataType: DataType): Any = dataType match {
+      case struct: StructType => rowToMap(obj.asInstanceOf[Row], struct)
+      case array: ArrayType => obj match {
+        case seq: Seq[Any] => seq.map(x => toJava(x, array.elementType)).asJava
+        case list: JList[_] => list.map(x => toJava(x, array.elementType)).asJava
+        case arr if arr != null && arr.getClass.isArray =>
+          arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
+        case other => other
+      }
+      case mt: MapType => obj.asInstanceOf[Map[_, _]].map {
+        case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
+      }.asJava
+      // Pyrolite can handle Timestamp
+      case other => obj
+    }
     def rowToMap(row: Row, structType: StructType): JMap[String, Any] = {
       val fields = structType.fields.map(field => (field.name, field.dataType))
       val map: JMap[String, Any] = new java.util.HashMap
       row.zip(fields).foreach {
-        case (obj, (attrName, dataType)) =>
-          dataType match {
-            case struct: StructType => map.put(attrName, rowToMap(obj.asInstanceOf[Row], struct))
-            case array @ ArrayType(struct: StructType) =>
-              val arrayValues = obj match {
-                case seq: Seq[Any] =>
-                  seq.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava
-                case list: JList[_] =>
-                  list.map(element => rowToMap(element.asInstanceOf[Row], struct))
-                case set: JSet[_] =>
-                  set.map(element => rowToMap(element.asInstanceOf[Row], struct))
-                case arr if arr != null && arr.getClass.isArray =>
-                  arr.asInstanceOf[Array[Any]].map {
-                    element => rowToMap(element.asInstanceOf[Row], struct)
-                  }
-                case other => other
-              }
-              map.put(attrName, arrayValues)
-            case array: ArrayType => {
-              val arrayValues = obj match {
-                case seq: Seq[Any] => seq.asJava
-                case other => other
-              }
-              map.put(attrName, arrayValues)
-            }
-            case other => map.put(attrName, obj)
-          }
+        case (obj, (attrName, dataType)) => map.put(attrName, toJava(obj, dataType))
       }
-
       map
     }