You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2015/12/21 21:47:12 UTC

spark git commit: [SPARK-12321][SQL] JSON format for TreeNode (use reflection)

Repository: spark
Updated Branches:
  refs/heads/master 474eb21a3 -> 7634fe951


[SPARK-12321][SQL] JSON format for TreeNode (use reflection)

An alternative solution to https://github.com/apache/spark/pull/10295: instead of implementing a JSON format for every logical/physical plan and expression, use reflection to implement it once in `TreeNode`.

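Here I use a pre-order traversal to flatten a plan tree into a list of plan nodes, and add an extra field `num-children` to each node, so that the tree can be reconstructed from the list. A field that holds one of the node's own children is encoded as that child's index among the node's children (e.g. `"child" : 0`), while an expression tree held in a plan field is serialized as its own flattened list (e.g. each inner list under `"order"` below).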

Example JSON:

Logical plan tree:
```
[ {
  "class" : "org.apache.spark.sql.catalyst.plans.logical.Sort",
  "num-children" : 1,
  "order" : [ [ {
    "class" : "org.apache.spark.sql.catalyst.expressions.SortOrder",
    "num-children" : 1,
    "child" : 0,
    "direction" : "Ascending"
  }, {
    "class" : "org.apache.spark.sql.catalyst.expressions.AttributeReference",
    "num-children" : 0,
    "name" : "i",
    "dataType" : "integer",
    "nullable" : true,
    "metadata" : { },
    "exprId" : {
      "id" : 10,
      "jvmId" : "cd1313c7-3f66-4ed7-a320-7d91e4633ac6"
    },
    "qualifiers" : [ ]
  } ] ],
  "global" : false,
  "child" : 0
}, {
  "class" : "org.apache.spark.sql.catalyst.plans.logical.Project",
  "num-children" : 1,
  "projectList" : [ [ {
    "class" : "org.apache.spark.sql.catalyst.expressions.Alias",
    "num-children" : 1,
    "child" : 0,
    "name" : "i",
    "exprId" : {
      "id" : 10,
      "jvmId" : "cd1313c7-3f66-4ed7-a320-7d91e4633ac6"
    },
    "qualifiers" : [ ]
  }, {
    "class" : "org.apache.spark.sql.catalyst.expressions.Add",
    "num-children" : 2,
    "left" : 0,
    "right" : 1
  }, {
    "class" : "org.apache.spark.sql.catalyst.expressions.AttributeReference",
    "num-children" : 0,
    "name" : "a",
    "dataType" : "integer",
    "nullable" : true,
    "metadata" : { },
    "exprId" : {
      "id" : 0,
      "jvmId" : "cd1313c7-3f66-4ed7-a320-7d91e4633ac6"
    },
    "qualifiers" : [ ]
  }, {
    "class" : "org.apache.spark.sql.catalyst.expressions.Literal",
    "num-children" : 0,
    "value" : "1",
    "dataType" : "integer"
  } ], [ {
    "class" : "org.apache.spark.sql.catalyst.expressions.Alias",
    "num-children" : 1,
    "child" : 0,
    "name" : "j",
    "exprId" : {
      "id" : 11,
      "jvmId" : "cd1313c7-3f66-4ed7-a320-7d91e4633ac6"
    },
    "qualifiers" : [ ]
  }, {
    "class" : "org.apache.spark.sql.catalyst.expressions.Multiply",
    "num-children" : 2,
    "left" : 0,
    "right" : 1
  }, {
    "class" : "org.apache.spark.sql.catalyst.expressions.AttributeReference",
    "num-children" : 0,
    "name" : "a",
    "dataType" : "integer",
    "nullable" : true,
    "metadata" : { },
    "exprId" : {
      "id" : 0,
      "jvmId" : "cd1313c7-3f66-4ed7-a320-7d91e4633ac6"
    },
    "qualifiers" : [ ]
  }, {
    "class" : "org.apache.spark.sql.catalyst.expressions.Literal",
    "num-children" : 0,
    "value" : "2",
    "dataType" : "integer"
  } ] ],
  "child" : 0
}, {
  "class" : "org.apache.spark.sql.catalyst.plans.logical.LocalRelation",
  "num-children" : 0,
  "output" : [ [ {
    "class" : "org.apache.spark.sql.catalyst.expressions.AttributeReference",
    "num-children" : 0,
    "name" : "a",
    "dataType" : "integer",
    "nullable" : true,
    "metadata" : { },
    "exprId" : {
      "id" : 0,
      "jvmId" : "cd1313c7-3f66-4ed7-a320-7d91e4633ac6"
    },
    "qualifiers" : [ ]
  } ] ],
  "data" : [ ]
} ]
```

Author: Wenchen Fan <we...@databricks.com>

Closes #10311 from cloud-fan/toJson-reflection.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7634fe95
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7634fe95
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7634fe95

Branch: refs/heads/master
Commit: 7634fe9511e1a8fb94979624b1b617b495b48ad3
Parents: 474eb21
Author: Wenchen Fan <we...@databricks.com>
Authored: Mon Dec 21 12:47:07 2015 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Mon Dec 21 12:47:07 2015 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    | 114 ++++----
 .../expressions/aggregate/interfaces.scala      |   1 -
 .../sql/catalyst/expressions/literals.scala     |  41 +++
 .../catalyst/expressions/namedExpressions.scala |   4 +
 .../spark/sql/catalyst/plans/QueryPlan.scala    |   2 +
 .../spark/sql/catalyst/trees/TreeNode.scala     | 258 ++++++++++++++++++-
 .../org/apache/spark/sql/types/DataType.scala   |   6 +-
 .../spark/sql/execution/ExistingRDD.scala       |   4 +-
 .../columnar/InMemoryColumnarTableScan.scala    |   6 +-
 .../scala/org/apache/spark/sql/QueryTest.scala  | 102 +++++++-
 .../apache/spark/sql/UserDefinedTypeSuite.scala |   5 +
 .../spark/sql/hive/HiveMetastoreCatalog.scala   |   2 +
 .../hive/execution/ScriptTransformation.scala   |   2 +-
 13 files changed, 472 insertions(+), 75 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7634fe95/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index c1b1d5c..cc9e6af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -68,7 +68,7 @@ object ScalaReflection extends ScalaReflection {
             val TypeRef(_, _, Seq(elementType)) = tpe
             arrayClassFor(elementType)
           case other =>
-            val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
+            val clazz = getClassFromType(tpe)
             ObjectType(clazz)
         }
     }
@@ -321,29 +321,11 @@ object ScalaReflection extends ScalaReflection {
           keyData :: valueData :: Nil)
 
       case t if t <:< localTypeOf[Product] =>
-        val formalTypeArgs = t.typeSymbol.asClass.typeParams
-        val TypeRef(_, _, actualTypeArgs) = t
-        val constructorSymbol = t.member(nme.CONSTRUCTOR)
-        val params = if (constructorSymbol.isMethod) {
-          constructorSymbol.asMethod.paramss
-        } else {
-          // Find the primary constructor, and use its parameter ordering.
-          val primaryConstructorSymbol: Option[Symbol] =
-            constructorSymbol.asTerm.alternatives.find(s =>
-              s.isMethod && s.asMethod.isPrimaryConstructor)
+        val params = getConstructorParameters(t)
 
-          if (primaryConstructorSymbol.isEmpty) {
-            sys.error("Internal SQL error: Product object did not have a primary constructor.")
-          } else {
-            primaryConstructorSymbol.get.asMethod.paramss
-          }
-        }
+        val cls = getClassFromType(tpe)
 
-        val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
-
-        val arguments = params.head.zipWithIndex.map { case (p, i) =>
-          val fieldName = p.name.toString
-          val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+        val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
           val dataType = schemaFor(fieldType).dataType
           val clsName = getClassNameFromType(fieldType)
           val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
@@ -477,27 +459,9 @@ object ScalaReflection extends ScalaReflection {
           }
 
         case t if t <:< localTypeOf[Product] =>
-          val formalTypeArgs = t.typeSymbol.asClass.typeParams
-          val TypeRef(_, _, actualTypeArgs) = t
-          val constructorSymbol = t.member(nme.CONSTRUCTOR)
-          val params = if (constructorSymbol.isMethod) {
-            constructorSymbol.asMethod.paramss
-          } else {
-            // Find the primary constructor, and use its parameter ordering.
-            val primaryConstructorSymbol: Option[Symbol] =
-              constructorSymbol.asTerm.alternatives.find(s =>
-                s.isMethod && s.asMethod.isPrimaryConstructor)
-
-            if (primaryConstructorSymbol.isEmpty) {
-              sys.error("Internal SQL error: Product object did not have a primary constructor.")
-            } else {
-              primaryConstructorSymbol.get.asMethod.paramss
-            }
-          }
+          val params = getConstructorParameters(t)
 
-          CreateNamedStruct(params.head.flatMap { p =>
-            val fieldName = p.name.toString
-            val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+          CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
             val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
             val clsName = getClassNameFromType(fieldType)
             val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
@@ -595,6 +559,21 @@ object ScalaReflection extends ScalaReflection {
       }
     }
   }
+
+  /**
+   * Returns the parameter names and types for the primary constructor of this class.
+   *
+   * Note that it only works for scala classes with primary constructor, and currently doesn't
+   * support inner class.
+   */
+  def getConstructorParameters(cls: Class[_]): Seq[(String, Type)] = {
+    val m = runtimeMirror(cls.getClassLoader)
+    val classSymbol = m.staticClass(cls.getName)
+    val t = classSymbol.selfType
+    getConstructorParameters(t)
+  }
+
+  def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
 }
 
 /**
@@ -668,26 +647,11 @@ trait ScalaReflection {
         Schema(MapType(schemaFor(keyType).dataType,
           valueDataType, valueContainsNull = valueNullable), nullable = true)
       case t if t <:< localTypeOf[Product] =>
-        val formalTypeArgs = t.typeSymbol.asClass.typeParams
-        val TypeRef(_, _, actualTypeArgs) = t
-        val constructorSymbol = t.member(nme.CONSTRUCTOR)
-        val params = if (constructorSymbol.isMethod) {
-          constructorSymbol.asMethod.paramss
-        } else {
-          // Find the primary constructor, and use its parameter ordering.
-          val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find(
-            s => s.isMethod && s.asMethod.isPrimaryConstructor)
-          if (primaryConstructorSymbol.isEmpty) {
-            sys.error("Internal SQL error: Product object did not have a primary constructor.")
-          } else {
-            primaryConstructorSymbol.get.asMethod.paramss
-          }
-        }
+        val params = getConstructorParameters(t)
         Schema(StructType(
-          params.head.map { p =>
-            val Schema(dataType, nullable) =
-              schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
-            StructField(p.name.toString, dataType, nullable)
+          params.map { case (fieldName, fieldType) =>
+            val Schema(dataType, nullable) = schemaFor(fieldType)
+            StructField(fieldName, dataType, nullable)
           }), nullable = true)
       case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
       case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
@@ -740,4 +704,32 @@ trait ScalaReflection {
     assert(methods.length == 1)
     methods.head.getParameterTypes
   }
+
+  /**
+   * Returns the parameter names and types for the primary constructor of this type.
+   *
+   * Note that it only works for scala classes with primary constructor, and currently doesn't
+   * support inner class.
+   */
+  def getConstructorParameters(tpe: Type): Seq[(String, Type)] = {
+    val formalTypeArgs = tpe.typeSymbol.asClass.typeParams
+    val TypeRef(_, _, actualTypeArgs) = tpe
+    val constructorSymbol = tpe.member(nme.CONSTRUCTOR)
+    val params = if (constructorSymbol.isMethod) {
+      constructorSymbol.asMethod.paramss
+    } else {
+      // Find the primary constructor, and use its parameter ordering.
+      val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find(
+        s => s.isMethod && s.asMethod.isPrimaryConstructor)
+      if (primaryConstructorSymbol.isEmpty) {
+        sys.error("Internal SQL error: Product object did not have a primary constructor.")
+      } else {
+        primaryConstructorSymbol.get.asMethod.paramss
+      }
+    }
+
+    params.flatten.map { p =>
+      p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7634fe95/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index b6d2ddc..b616d69 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext}
 import org.apache.spark.sql.catalyst.InternalRow

http://git-wip-us.apache.org/repos/asf/spark/blob/7634fe95/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 68ec688..e3573b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.json4s.JsonAST._
 import java.sql.{Date, Timestamp}
 
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -55,6 +56,34 @@ object Literal {
    */
   def fromObject(obj: AnyRef): Literal = new Literal(obj, ObjectType(obj.getClass))
 
+  def fromJSON(json: JValue): Literal = {
+    val dataType = DataType.parseDataType(json \ "dataType")
+    json \ "value" match {
+      case JNull => Literal.create(null, dataType)
+      case JString(str) =>
+        val value = dataType match {
+          case BooleanType => str.toBoolean
+          case ByteType => str.toByte
+          case ShortType => str.toShort
+          case IntegerType => str.toInt
+          case LongType => str.toLong
+          case FloatType => str.toFloat
+          case DoubleType => str.toDouble
+          case StringType => UTF8String.fromString(str)
+          case DateType => java.sql.Date.valueOf(str)
+          case TimestampType => java.sql.Timestamp.valueOf(str)
+          case CalendarIntervalType => CalendarInterval.fromString(str)
+          case t: DecimalType =>
+            val d = Decimal(str)
+            assert(d.changePrecision(t.precision, t.scale))
+            d
+          case _ => null
+        }
+        Literal.create(value, dataType)
+      case other => sys.error(s"$other is not a valid Literal json value")
+    }
+  }
+
   def create(v: Any, dataType: DataType): Literal = {
     Literal(CatalystTypeConverters.convertToCatalyst(v), dataType)
   }
@@ -123,6 +152,18 @@ case class Literal protected (value: Any, dataType: DataType)
     case _ => false
   }
 
+  override protected def jsonFields: List[JField] = {
+    // Turns all kinds of literal values to string in json field, as the type info is hard to
+    // retain in json format, e.g. {"a": 123} can be a int, or double, or decimal, etc.
+    val jsonValue = (value, dataType) match {
+      case (null, _) => JNull
+      case (i: Int, DateType) => JString(DateTimeUtils.toJavaDate(i).toString)
+      case (l: Long, TimestampType) => JString(DateTimeUtils.toJavaTimestamp(l).toString)
+      case (other, _) => JString(other.toString)
+    }
+    ("value" -> jsonValue) :: ("dataType" -> dataType.jsonValue) :: Nil
+  }
+
   override def eval(input: InternalRow): Any = value
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {

http://git-wip-us.apache.org/repos/asf/spark/blob/7634fe95/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 26b6aca..eefd9c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -262,6 +262,10 @@ case class AttributeReference(
     }
   }
 
+  override protected final def otherCopyArgs: Seq[AnyRef] = {
+    exprId :: qualifiers :: Nil
+  }
+
   override def toString: String = s"$name#${exprId.id}$typeSuffix"
 
   // Since the expression id is not in the first constructor it is missing from the default

http://git-wip-us.apache.org/repos/asf/spark/blob/7634fe95/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index b9db783..d262644 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -88,6 +88,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
       case d: DataType => d // Avoid unpacking Structs
       case seq: Traversable[_] => seq.map(recursiveTransform)
       case other: AnyRef => other
+      case null => null
     }
 
     val newArgs = productIterator.map(recursiveTransform).toArray
@@ -120,6 +121,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
       case d: DataType => d // Avoid unpacking Structs
       case seq: Traversable[_] => seq.map(recursiveTransform)
       case other: AnyRef => other
+      case null => null
     }
 
     val newArgs = productIterator.map(recursiveTransform).toArray

http://git-wip-us.apache.org/repos/asf/spark/blob/7634fe95/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index d838d84..c97dc2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -17,9 +17,25 @@
 
 package org.apache.spark.sql.catalyst.trees
 
+import java.util.UUID
 import scala.collection.Map
-
+import scala.collection.mutable.Stack
+import org.json4s.JsonAST._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
+import org.apache.spark.util.Utils
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.rdd.{EmptyRDD, RDD}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.ScalaReflection._
+import org.apache.spark.sql.catalyst.{TableIdentifier, ScalaReflectionLock}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types.{StructType, DataType}
 
 /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */
@@ -463,4 +479,244 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     }
     s"$nodeName(${args.mkString(",")})"
   }
+
+  def toJSON: String = compact(render(jsonValue))
+
+  def prettyJson: String = pretty(render(jsonValue))
+
+  private def jsonValue: JValue = {
+    val jsonValues = scala.collection.mutable.ArrayBuffer.empty[JValue]
+
+    def collectJsonValue(tn: BaseType): Unit = {
+      val jsonFields = ("class" -> JString(tn.getClass.getName)) ::
+        ("num-children" -> JInt(tn.children.length)) :: tn.jsonFields
+      jsonValues += JObject(jsonFields)
+      tn.children.foreach(collectJsonValue)
+    }
+
+    collectJsonValue(this)
+    jsonValues
+  }
+
+  protected def jsonFields: List[JField] = {
+    val fieldNames = getConstructorParameters(getClass).map(_._1)
+    val fieldValues = productIterator.toSeq ++ otherCopyArgs
+    assert(fieldNames.length == fieldValues.length, s"${getClass.getSimpleName} fields: " +
+      fieldNames.mkString(", ") + s", values: " + fieldValues.map(_.toString).mkString(", "))
+
+    fieldNames.zip(fieldValues).map {
+      // If the field value is a child, then use an int to encode it, represents the index of
+      // this child in all children.
+      case (name, value: TreeNode[_]) if containsChild(value) =>
+        name -> JInt(children.indexOf(value))
+      case (name, value: Seq[BaseType]) if value.toSet.subsetOf(containsChild) =>
+        name -> JArray(
+          value.map(v => JInt(children.indexOf(v.asInstanceOf[TreeNode[_]]))).toList
+        )
+      case (name, value) => name -> parseToJson(value)
+    }.toList
+  }
+
+  private def parseToJson(obj: Any): JValue = obj match {
+    case b: Boolean => JBool(b)
+    case b: Byte => JInt(b.toInt)
+    case s: Short => JInt(s.toInt)
+    case i: Int => JInt(i)
+    case l: Long => JInt(l)
+    case f: Float => JDouble(f)
+    case d: Double => JDouble(d)
+    case b: BigInt => JInt(b)
+    case null => JNull
+    case s: String => JString(s)
+    case u: UUID => JString(u.toString)
+    case dt: DataType => dt.jsonValue
+    case m: Metadata => m.jsonValue
+    case s: StorageLevel =>
+      ("useDisk" -> s.useDisk) ~ ("useMemory" -> s.useMemory) ~ ("useOffHeap" -> s.useOffHeap) ~
+        ("deserialized" -> s.deserialized) ~ ("replication" -> s.replication)
+    case n: TreeNode[_] => n.jsonValue
+    case o: Option[_] => o.map(parseToJson)
+    case t: Seq[_] => JArray(t.map(parseToJson).toList)
+    case m: Map[_, _] =>
+      val fields = m.toList.map { case (k: String, v) => (k, parseToJson(v)) }
+      JObject(fields)
+    case r: RDD[_] => JNothing
+    // if it's a scala object, we can simply keep the full class path.
+    // TODO: currently if the class name ends with "$", we think it's a scala object, there is
+    // probably a better way to check it.
+    case obj if obj.getClass.getName.endsWith("$") => "object" -> obj.getClass.getName
+    // returns null if the product type doesn't have a primary constructor, e.g. HiveFunctionWrapper
+    case p: Product => try {
+      val fieldNames = getConstructorParameters(p.getClass).map(_._1)
+      val fieldValues = p.productIterator.toSeq
+      assert(fieldNames.length == fieldValues.length)
+      ("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map {
+        case (name, value) => name -> parseToJson(value)
+      }.toList
+    } catch {
+      case _: RuntimeException => null
+    }
+    case _ => JNull
+  }
+}
+
+object TreeNode {
+  def fromJSON[BaseType <: TreeNode[BaseType]](json: String, sc: SparkContext): BaseType = {
+    val jsonAST = parse(json)
+    assert(jsonAST.isInstanceOf[JArray])
+    reconstruct(jsonAST.asInstanceOf[JArray], sc).asInstanceOf[BaseType]
+  }
+
+  private def reconstruct(treeNodeJson: JArray, sc: SparkContext): TreeNode[_] = {
+    assert(treeNodeJson.arr.forall(_.isInstanceOf[JObject]))
+    val jsonNodes = Stack(treeNodeJson.arr.map(_.asInstanceOf[JObject]): _*)
+
+    def parseNextNode(): TreeNode[_] = {
+      val nextNode = jsonNodes.pop()
+
+      val cls = Utils.classForName((nextNode \ "class").asInstanceOf[JString].s)
+      if (cls == classOf[Literal]) {
+        Literal.fromJSON(nextNode)
+      } else if (cls.getName.endsWith("$")) {
+        cls.getField("MODULE$").get(cls).asInstanceOf[TreeNode[_]]
+      } else {
+        val numChildren = (nextNode \ "num-children").asInstanceOf[JInt].num.toInt
+
+        val children: Seq[TreeNode[_]] = (1 to numChildren).map(_ => parseNextNode())
+        val fields = getConstructorParameters(cls)
+
+        val parameters: Array[AnyRef] = fields.map {
+          case (fieldName, fieldType) =>
+            parseFromJson(nextNode \ fieldName, fieldType, children, sc)
+        }.toArray
+
+        val maybeCtor = cls.getConstructors.find { p =>
+          val expectedTypes = p.getParameterTypes
+          expectedTypes.length == fields.length && expectedTypes.zip(fields.map(_._2)).forall {
+            case (cls, tpe) => cls == getClassFromType(tpe)
+          }
+        }
+        if (maybeCtor.isEmpty) {
+          sys.error(s"No valid constructor for ${cls.getName}")
+        } else {
+          try {
+            maybeCtor.get.newInstance(parameters: _*).asInstanceOf[TreeNode[_]]
+          } catch {
+            case e: java.lang.IllegalArgumentException =>
+              throw new RuntimeException(
+                s"""
+                  |Failed to construct tree node: ${cls.getName}
+                  |ctor: ${maybeCtor.get}
+                  |types: ${parameters.map(_.getClass).mkString(", ")}
+                  |args: ${parameters.mkString(", ")}
+                """.stripMargin, e)
+          }
+        }
+      }
+    }
+
+    parseNextNode()
+  }
+
+  import universe._
+
+  private def parseFromJson(
+      value: JValue,
+      expectedType: Type,
+      children: Seq[TreeNode[_]],
+      sc: SparkContext): AnyRef = ScalaReflectionLock.synchronized {
+    if (value == JNull) return null
+
+    expectedType match {
+      case t if t <:< definitions.BooleanTpe =>
+        value.asInstanceOf[JBool].value: java.lang.Boolean
+      case t if t <:< definitions.ByteTpe =>
+        value.asInstanceOf[JInt].num.toByte: java.lang.Byte
+      case t if t <:< definitions.ShortTpe =>
+        value.asInstanceOf[JInt].num.toShort: java.lang.Short
+      case t if t <:< definitions.IntTpe =>
+        value.asInstanceOf[JInt].num.toInt: java.lang.Integer
+      case t if t <:< definitions.LongTpe =>
+        value.asInstanceOf[JInt].num.toLong: java.lang.Long
+      case t if t <:< definitions.FloatTpe =>
+        value.asInstanceOf[JDouble].num.toFloat: java.lang.Float
+      case t if t <:< definitions.DoubleTpe =>
+        value.asInstanceOf[JDouble].num: java.lang.Double
+
+      case t if t <:< localTypeOf[BigInt] => value.asInstanceOf[JInt].num
+      case t if t <:< localTypeOf[java.lang.String] => value.asInstanceOf[JString].s
+      case t if t <:< localTypeOf[UUID] => UUID.fromString(value.asInstanceOf[JString].s)
+      case t if t <:< localTypeOf[DataType] => DataType.parseDataType(value)
+      case t if t <:< localTypeOf[Metadata] => Metadata.fromJObject(value.asInstanceOf[JObject])
+      case t if t <:< localTypeOf[StorageLevel] =>
+        val JBool(useDisk) = value \ "useDisk"
+        val JBool(useMemory) = value \ "useMemory"
+        val JBool(useOffHeap) = value \ "useOffHeap"
+        val JBool(deserialized) = value \ "deserialized"
+        val JInt(replication) = value \ "replication"
+        StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication.toInt)
+      case t if t <:< localTypeOf[TreeNode[_]] => value match {
+        case JInt(i) => children(i.toInt)
+        case arr: JArray => reconstruct(arr, sc)
+        case _ => throw new RuntimeException(s"$value is not a valid json value for tree node.")
+      }
+      case t if t <:< localTypeOf[Option[_]] =>
+        if (value == JNothing) {
+          None
+        } else {
+          val TypeRef(_, _, Seq(optType)) = t
+          Option(parseFromJson(value, optType, children, sc))
+        }
+      case t if t <:< localTypeOf[Seq[_]] =>
+        val TypeRef(_, _, Seq(elementType)) = t
+        val JArray(elements) = value
+        elements.map(parseFromJson(_, elementType, children, sc)).toSeq
+      case t if t <:< localTypeOf[Map[_, _]] =>
+        val TypeRef(_, _, Seq(keyType, valueType)) = t
+        val JObject(fields) = value
+        fields.map {
+          case (name, value) => name -> parseFromJson(value, valueType, children, sc)
+        }.toMap
+      case t if t <:< localTypeOf[RDD[_]] =>
+        new EmptyRDD[Any](sc)
+      case _ if isScalaObject(value) =>
+        val JString(clsName) = value \ "object"
+        val cls = Utils.classForName(clsName)
+        cls.getField("MODULE$").get(cls)
+      case t if t <:< localTypeOf[Product] =>
+        val fields = getConstructorParameters(t)
+        val clsName = getClassNameFromType(t)
+        parseToProduct(clsName, fields, value, children, sc)
+      // There maybe some cases that the parameter type signature is not Product but the value is,
+      // e.g. `SpecifiedWindowFrame` with type signature `WindowFrame`, handle it here.
+      case _ if isScalaProduct(value) =>
+        val JString(clsName) = value \ "product-class"
+        val fields = getConstructorParameters(Utils.classForName(clsName))
+        parseToProduct(clsName, fields, value, children, sc)
+      case _ => sys.error(s"Do not support type $expectedType with json $value.")
+    }
+  }
+
+  private def parseToProduct(
+      clsName: String,
+      fields: Seq[(String, Type)],
+      value: JValue,
+      children: Seq[TreeNode[_]],
+      sc: SparkContext): AnyRef = {
+    val parameters: Array[AnyRef] = fields.map {
+      case (fieldName, fieldType) => parseFromJson(value \ fieldName, fieldType, children, sc)
+    }.toArray
+    val ctor = Utils.classForName(clsName).getConstructors.maxBy(_.getParameterTypes.size)
+    ctor.newInstance(parameters: _*).asInstanceOf[AnyRef]
+  }
+
+  private def isScalaObject(jValue: JValue): Boolean = (jValue \ "object") match {
+    case JString(str) if str.endsWith("$") => true
+    case _ => false
+  }
+
+  private def isScalaProduct(jValue: JValue): Boolean = (jValue \ "product-class") match {
+    case _: JString => true
+    case _ => false
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7634fe95/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index b0c43c4..f8d71c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -107,8 +107,8 @@ object DataType {
   def fromCaseClassString(string: String): DataType = CaseClassStringParser(string)
 
   private val nonDecimalNameToType = {
-    Seq(NullType, DateType, TimestampType, BinaryType,
-      IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
+    Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType,
+      DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType)
       .map(t => t.typeName -> t).toMap
   }
 
@@ -130,7 +130,7 @@ object DataType {
   }
 
   // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side.
-  private def parseDataType(json: JValue): DataType = json match {
+  private[sql] def parseDataType(json: JValue): DataType = json match {
     case JString(name) =>
       nameToType(name)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7634fe95/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index b8a4302..ea5a9af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -74,9 +74,7 @@ private[sql] case class LogicalRDD(
 
   override def children: Seq[LogicalPlan] = Nil
 
-  override protected final def otherCopyArgs: Seq[AnyRef] = {
-    sqlContext :: Nil
-  }
+  override protected final def otherCopyArgs: Seq[AnyRef] = sqlContext :: Nil
 
   override def newInstance(): LogicalRDD.this.type =
     LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type]

http://git-wip-us.apache.org/repos/asf/spark/blob/7634fe95/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
index 3c5a8cb..4afa5f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
@@ -61,9 +61,9 @@ private[sql] case class InMemoryRelation(
     storageLevel: StorageLevel,
     @transient child: SparkPlan,
     tableName: Option[String])(
-    @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null,
-    @transient private var _statistics: Statistics = null,
-    private var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null)
+    @transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null,
+    @transient private[sql] var _statistics: Statistics = null,
+    private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null)
   extends LogicalPlan with MultiInstanceRelation {
 
   private val batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] =

http://git-wip-us.apache.org/repos/asf/spark/blob/7634fe95/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index bc22fb8..9246f55 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -21,10 +21,15 @@ import java.util.{Locale, TimeZone}
 
 import scala.collection.JavaConverters._
 
-import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
-import org.apache.spark.sql.execution.Queryable
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.{LogicalRDD, Queryable}
 
 abstract class QueryTest extends PlanTest {
 
@@ -123,6 +128,8 @@ abstract class QueryTest extends PlanTest {
              |""".stripMargin)
     }
 
+    checkJsonFormat(analyzedDF)
+
     QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
       case Some(errorMessage) => fail(errorMessage)
       case None =>
@@ -177,6 +184,97 @@ abstract class QueryTest extends PlanTest {
       s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" +
         planWithCaching)
   }
+
+  /**
+   * Round-trips `df`'s analyzed plan through JSON and fails if the rebuilt plan differs.
+   */
+  private def checkJsonFormat(df: DataFrame): Unit = {
+    val logicalPlan = df.queryExecution.analyzed
+    // Skip (return early from) plans containing nodes we cannot round-trip yet.
+    logicalPlan.transform {
+      case _: MapPartitions[_, _] => return
+      case _: MapGroups[_, _, _] => return
+      case _: AppendColumns[_, _] => return
+      case _: CoGroup[_, _, _, _] => return
+      case _: LogicalRelation => return
+    }.transformAllExpressions {
+      case _: ImperativeAggregate => return
+    }
+
+    val jsonString = try {
+      logicalPlan.toJSON
+    } catch {
+      case scala.util.control.NonFatal(e) =>
+        fail(s"""
+             |Failed to convert logical plan to JSON:
+             |${logicalPlan.treeString}
+           """.stripMargin, e)
+    }
+
+    // Skip Hive suites until the remaining corner cases in the Hive module are fixed.
+    if (this.getClass.getName.startsWith("org.apache.spark.sql.hive")) return
+
+    // Scala functions are not serializable to JSON; replace them with null so that the plans
+    // can be compared later.
+    val normalized1 = logicalPlan.transformAllExpressions {
+      case udf: ScalaUDF => udf.copy(function = null)
+      case gen: UserDefinedGenerator => gen.copy(function = null)
+    }
+
+    // RDDs/data are not serializable to JSON, so collect the LogicalPlans that hold such data
+    // and use these originals to replace the null placeholders in the logical plan parsed
+    // from JSON (matched up in traversal order).
+    var logicalRDDs = logicalPlan.collect { case l: LogicalRDD => l }
+    var localRelations = logicalPlan.collect { case l: LocalRelation => l }
+    var inMemoryRelations = logicalPlan.collect { case i: InMemoryRelation => i }
+
+    val jsonBackPlan = try {
+      TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext)
+    } catch {
+      case scala.util.control.NonFatal(e) =>
+        fail(s"""
+             |Failed to rebuild the logical plan from JSON:
+             |${logicalPlan.treeString}
+             |
+             |${logicalPlan.prettyJson}
+           """.stripMargin, e)
+    }
+
+    val normalized2 = jsonBackPlan transformDown {
+      case l: LogicalRDD =>
+        val origin = logicalRDDs.head
+        logicalRDDs = logicalRDDs.drop(1)
+        LogicalRDD(l.output, origin.rdd)(sqlContext)
+      case l: LocalRelation =>
+        val origin = localRelations.head
+        localRelations = localRelations.drop(1)
+        l.copy(data = origin.data)
+      case l: InMemoryRelation =>
+        val origin = inMemoryRelations.head
+        inMemoryRelations = inMemoryRelations.drop(1)
+        InMemoryRelation(
+          l.output,
+          l.useCompression,
+          l.batchSize,
+          l.storageLevel,
+          origin.child,
+          l.tableName)(
+          origin.cachedColumnBuffers,
+          l._statistics,
+          origin._batchStats)
+    }
+
+    assert(logicalRDDs.isEmpty)
+    assert(localRelations.isEmpty)
+    assert(inMemoryRelations.isEmpty)
+
+    if (normalized1 != normalized2) {
+      fail(s"""
+           |== FAIL: the logical plan parsed from json does not match the original one ===
+           |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")}
+          """.stripMargin)
+    }
+  }
 }
 
 object QueryTest {

http://git-wip-us.apache.org/repos/asf/spark/blob/7634fe95/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index f602f2f..2a11173 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -65,6 +65,11 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
   override def userClass: Class[MyDenseVector] = classOf[MyDenseVector]
 
   private[spark] override def asNullable: MyDenseVectorUDT = this
+
+  // All MyDenseVectorUDT instances are interchangeable, so equality is by type; hashCode must
+  // agree with equals so that equal instances always hash the same (equals/hashCode contract).
+  override def equals(other: Any): Boolean = other.isInstanceOf[MyDenseVectorUDT]
+  override def hashCode(): Int = classOf[MyDenseVectorUDT].getName.hashCode
 }
 
 class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest {

http://git-wip-us.apache.org/repos/asf/spark/blob/7634fe95/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 08b291e..f099e14 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -728,6 +728,8 @@ private[hive] case class MetastoreRelation
     Objects.hashCode(databaseName, tableName, alias, output)
   }
 
+  override protected def otherCopyArgs: Seq[AnyRef] = table :: sqlContext :: Nil
+
   @transient val hiveQlTable: Table = {
     // We start by constructing an API table as Hive performs several important transformations
     // internally when converting an API table to a QL table.

http://git-wip-us.apache.org/repos/asf/spark/blob/7634fe95/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index b30117f..d9b9ba4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -58,7 +58,7 @@ case class ScriptTransformation(
     ioschema: HiveScriptIOSchema)(@transient private val sc: HiveContext)
   extends UnaryNode {
 
-  override def otherCopyArgs: Seq[HiveContext] = sc :: Nil
+  override protected def otherCopyArgs: Seq[HiveContext] = sc :: Nil
 
   private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org