You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2016/01/09 00:43:20 UTC

[1/2] spark git commit: [SPARK-12696] Backport Dataset Bug fixes to 1.6

Repository: spark
Updated Branches:
  refs/heads/branch-1.6 faf094c7c -> a6190508b


http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 0dbaeb8..9f8db39 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -23,6 +23,8 @@ import java.sql.Date;
 import java.sql.Timestamp;
 import java.util.*;
 
+import com.google.common.base.Objects;
+import org.junit.rules.ExpectedException;
 import scala.Tuple2;
 import scala.Tuple3;
 import scala.Tuple4;
@@ -39,7 +41,6 @@ import org.apache.spark.sql.expressions.Aggregator;
 import org.apache.spark.sql.test.TestSQLContext;
 import org.apache.spark.sql.catalyst.encoders.OuterScopes;
 import org.apache.spark.sql.catalyst.expressions.GenericRow;
-import org.apache.spark.sql.types.DecimalType;
 import org.apache.spark.sql.types.StructType;
 
 import static org.apache.spark.sql.functions.*;
@@ -741,4 +742,127 @@ public class JavaDatasetSuite implements Serializable {
       context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class));
     ds.collect();
   }
+
+  public class SmallBean implements Serializable {
+    private String a;
+
+    private int b;
+
+    public int getB() {
+      return b;
+    }
+
+    public void setB(int b) {
+      this.b = b;
+    }
+
+    public String getA() {
+      return a;
+    }
+
+    public void setA(String a) {
+      this.a = a;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      SmallBean smallBean = (SmallBean) o;
+      return b == smallBean.b && com.google.common.base.Objects.equal(a, smallBean.a);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hashCode(a, b);
+    }
+  }
+
+  public class NestedSmallBean implements Serializable {
+    private SmallBean f;
+
+    public SmallBean getF() {
+      return f;
+    }
+
+    public void setF(SmallBean f) {
+      this.f = f;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      NestedSmallBean that = (NestedSmallBean) o;
+      return Objects.equal(f, that.f);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hashCode(f);
+    }
+  }
+
+  @Rule
+  public transient ExpectedException nullabilityCheck = ExpectedException.none();
+
+  @Test
+  public void testRuntimeNullabilityCheck() {
+    OuterScopes.addOuterScope(this);
+
+    StructType schema = new StructType()
+      .add("f", new StructType()
+        .add("a", StringType, true)
+        .add("b", IntegerType, true), true);
+
+    // Shouldn't throw runtime exception since it passes nullability check.
+    {
+      Row row = new GenericRow(new Object[] {
+          new GenericRow(new Object[] {
+              "hello", 1
+          })
+      });
+
+      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
+
+      SmallBean smallBean = new SmallBean();
+      smallBean.setA("hello");
+      smallBean.setB(1);
+
+      NestedSmallBean nestedSmallBean = new NestedSmallBean();
+      nestedSmallBean.setF(smallBean);
+
+      Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean));
+    }
+
+    // Shouldn't throw runtime exception when parent object (`ClassData`) is null
+    {
+      Row row = new GenericRow(new Object[] { null });
+
+      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
+
+      NestedSmallBean nestedSmallBean = new NestedSmallBean();
+      Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean));
+    }
+
+    nullabilityCheck.expect(RuntimeException.class);
+    nullabilityCheck.expectMessage(
+      "Null value appeared in non-nullable field " +
+        "test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int.");
+
+    {
+      Row row = new GenericRow(new Object[] {
+          new GenericRow(new Object[] {
+              "hello", null
+          })
+      });
+
+      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
+
+      ds.collect();
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 854dec0..0b7573c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -578,6 +578,21 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     assert(df.showString(10) === expectedAnswer)
   }
 
+  test("showString: binary") {
+    val df = Seq(
+      ("12".getBytes, "ABC.".getBytes),
+      ("34".getBytes, "12346".getBytes)
+    ).toDF()
+    val expectedAnswer = """+-------+----------------+
+                           ||     _1|              _2|
+                           |+-------+----------------+
+                           ||[31 32]|   [41 42 43 2E]|
+                           ||[33 34]|[31 32 33 34 36]|
+                           |+-------+----------------+
+                           |""".stripMargin
+    assert(df.showString(10) === expectedAnswer)
+  }
+
   test("showString: minimum column width") {
     val df = Seq(
       (1, 1),

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index c6b3991..c19b5a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -24,6 +24,7 @@ import scala.language.postfixOps
 
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 
 
 class DatasetSuite extends QueryTest with SharedSQLContext {
@@ -438,6 +439,20 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(ds.toString == "[_1: int, _2: int]")
   }
 
+  test("showString: Kryo encoder") {
+    implicit val kryoEncoder = Encoders.kryo[KryoData]
+    val ds = Seq(KryoData(1), KryoData(2)).toDS()
+
+    val expectedAnswer = """+-----------+
+                           ||      value|
+                           |+-----------+
+                           ||KryoData(1)|
+                           ||KryoData(2)|
+                           |+-----------+
+                           |""".stripMargin
+    assert(ds.showString(10) === expectedAnswer)
+  }
+
   test("Kryo encoder") {
     implicit val kryoEncoder = Encoders.kryo[KryoData]
     val ds = Seq(KryoData(1), KryoData(2)).toDS()
@@ -493,12 +508,63 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val ds = Seq(2 -> 2.toByte, 3 -> 3.toByte).toDF("a", "b").as[ClassData]
     assert(ds.collect().toSeq == Seq(ClassData("2", 2), ClassData("3", 3)))
   }
-}
 
+  test("verify mismatching field names fail with a good error") {
+    val ds = Seq(ClassData("a", 1)).toDS()
+    val e = intercept[AnalysisException] {
+      ds.as[ClassData2].collect()
+    }
+    assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage)
+  }
+
+  test("runtime nullability check") {
+    val schema = StructType(Seq(
+      StructField("f", StructType(Seq(
+        StructField("a", StringType, nullable = true),
+        StructField("b", IntegerType, nullable = false)
+      )), nullable = true)
+    ))
+
+    def buildDataset(rows: Row*): Dataset[NestedStruct] = {
+      val rowRDD = sqlContext.sparkContext.parallelize(rows)
+      sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct]
+    }
+
+    checkAnswer(
+      buildDataset(Row(Row("hello", 1))),
+      NestedStruct(ClassData("hello", 1))
+    )
+
+    // Shouldn't throw runtime exception when parent object (`ClassData`) is null
+    assert(buildDataset(Row(null)).collect() === Array(NestedStruct(null)))
+
+    val message = intercept[RuntimeException] {
+      buildDataset(Row(Row("hello", null))).collect()
+    }.getMessage
+
+    assert(message.contains(
+      "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int."
+    ))
+  }
+
+  test("SPARK-12478: top level null field") {
+    val ds0 = Seq(NestedStruct(null)).toDS()
+    checkAnswer(ds0, NestedStruct(null))
+    checkAnswer(ds0.toDF(), Row(null))
+
+    val ds1 = Seq(DeepNestedStruct(NestedStruct(null))).toDS()
+    checkAnswer(ds1, DeepNestedStruct(NestedStruct(null)))
+    checkAnswer(ds1.toDF(), Row(Row(null)))
+  }
+}
 
 case class ClassData(a: String, b: Int)
+case class ClassData2(c: String, d: Int)
 case class ClassNullableData(a: String, b: Integer)
 
+case class NestedStruct(f: ClassData)
+case class DeepNestedStruct(f: NestedStruct)
+
 /**
  * A class used to test serialization using encoders. This class throws exceptions when using
  * Java serialization -- so the only way it can be "serialized" is through our encoders.

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index bc22fb8..9246f55 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -21,10 +21,15 @@ import java.util.{Locale, TimeZone}
 
 import scala.collection.JavaConverters._
 
-import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
-import org.apache.spark.sql.execution.Queryable
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.{LogicalRDD, Queryable}
 
 abstract class QueryTest extends PlanTest {
 
@@ -123,6 +128,8 @@ abstract class QueryTest extends PlanTest {
              |""".stripMargin)
     }
 
+    checkJsonFormat(analyzedDF)
+
     QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
       case Some(errorMessage) => fail(errorMessage)
       case None =>
@@ -177,6 +184,97 @@ abstract class QueryTest extends PlanTest {
       s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" +
         planWithCaching)
   }
+
+  private def checkJsonFormat(df: DataFrame): Unit = {
+    val logicalPlan = df.queryExecution.analyzed
+    // bypass some cases that we can't handle currently.
+    logicalPlan.transform {
+      case _: MapPartitions[_, _] => return
+      case _: MapGroups[_, _, _] => return
+      case _: AppendColumns[_, _] => return
+      case _: CoGroup[_, _, _, _] => return
+      case _: LogicalRelation => return
+    }.transformAllExpressions {
+      case a: ImperativeAggregate => return
+    }
+
+    val jsonString = try {
+      logicalPlan.toJSON
+    } catch {
+      case e =>
+        fail(
+          s"""
+             |Failed to parse logical plan to JSON:
+             |${logicalPlan.treeString}
+           """.stripMargin, e)
+    }
+
+    // bypass hive tests before we fix all corner cases in hive module.
+    if (this.getClass.getName.startsWith("org.apache.spark.sql.hive")) return
+
+    // scala function is not serializable to JSON, use null to replace them so that we can compare
+    // the plans later.
+    val normalized1 = logicalPlan.transformAllExpressions {
+      case udf: ScalaUDF => udf.copy(function = null)
+      case gen: UserDefinedGenerator => gen.copy(function = null)
+    }
+
+    // RDDs/data are not serializable to JSON, so we need to collect LogicalPlans that contains
+    // these non-serializable stuff, and use these original ones to replace the null-placeholders
+    // in the logical plans parsed from JSON.
+    var logicalRDDs = logicalPlan.collect { case l: LogicalRDD => l }
+    var localRelations = logicalPlan.collect { case l: LocalRelation => l }
+    var inMemoryRelations = logicalPlan.collect { case i: InMemoryRelation => i }
+
+    val jsonBackPlan = try {
+      TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext)
+    } catch {
+      case e =>
+        fail(
+          s"""
+             |Failed to rebuild the logical plan from JSON:
+             |${logicalPlan.treeString}
+             |
+             |${logicalPlan.prettyJson}
+           """.stripMargin, e)
+    }
+
+    val normalized2 = jsonBackPlan transformDown {
+      case l: LogicalRDD =>
+        val origin = logicalRDDs.head
+        logicalRDDs = logicalRDDs.drop(1)
+        LogicalRDD(l.output, origin.rdd)(sqlContext)
+      case l: LocalRelation =>
+        val origin = localRelations.head
+        localRelations = localRelations.drop(1)
+        l.copy(data = origin.data)
+      case l: InMemoryRelation =>
+        val origin = inMemoryRelations.head
+        inMemoryRelations = inMemoryRelations.drop(1)
+        InMemoryRelation(
+          l.output,
+          l.useCompression,
+          l.batchSize,
+          l.storageLevel,
+          origin.child,
+          l.tableName)(
+          origin.cachedColumnBuffers,
+          l._statistics,
+          origin._batchStats)
+    }
+
+    assert(logicalRDDs.isEmpty)
+    assert(localRelations.isEmpty)
+    assert(inMemoryRelations.isEmpty)
+
+    if (normalized1 != normalized2) {
+      fail(
+        s"""
+           |== FAIL: the logical plan parsed from json does not match the original one ===
+           |${sideBySide(logicalPlan.treeString, normalized2.treeString).mkString("\n")}
+          """.stripMargin)
+    }
+  }
 }
 
 object QueryTest {

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index f602f2f..2a11173 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -65,6 +65,11 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
   override def userClass: Class[MyDenseVector] = classOf[MyDenseVector]
 
   private[spark] override def asNullable: MyDenseVectorUDT = this
+
+  override def equals(other: Any): Boolean = other match {
+    case _: MyDenseVectorUDT => true
+    case _ => false
+  }
 }
 
 class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest {

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 08b291e..f099e14 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -728,6 +728,8 @@ private[hive] case class MetastoreRelation
     Objects.hashCode(databaseName, tableName, alias, output)
   }
 
+  override protected def otherCopyArgs: Seq[AnyRef] = table :: sqlContext :: Nil
+
   @transient val hiveQlTable: Table = {
     // We start by constructing an API table as Hive performs several important transformations
     // internally when converting an API table to a QL table.

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index b30117f..d9b9ba4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -58,7 +58,7 @@ case class ScriptTransformation(
     ioschema: HiveScriptIOSchema)(@transient private val sc: HiveContext)
   extends UnaryNode {
 
-  override def otherCopyArgs: Seq[HiveContext] = sc :: Nil
+  override protected def otherCopyArgs: Seq[HiveContext] = sc :: Nil
 
   private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


[2/2] spark git commit: [SPARK-12696] Backport Dataset Bug fixes to 1.6

Posted by ma...@apache.org.
[SPARK-12696] Backport Dataset Bug fixes to 1.6

We've fixed a lot of bugs in master, and since the Dataset API is experimental in 1.6, we should consider backporting the fixes. The only change that seems obviously risky to me is 0e07ed3; we might try to remove that.

Author: Wenchen Fan <we...@databricks.com>
Author: gatorsmile <ga...@gmail.com>
Author: Liang-Chi Hsieh <vi...@gmail.com>
Author: Cheng Lian <li...@databricks.com>
Author: Nong Li <no...@databricks.com>

Closes #10650 from marmbrus/dataset-backports.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a6190508
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a6190508
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a6190508

Branch: refs/heads/branch-1.6
Commit: a6190508b20673952303eff32b3a559f0a264d03
Parents: faf094c
Author: Michael Armbrust <mi...@databricks.com>
Authored: Fri Jan 8 15:43:11 2016 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Fri Jan 8 15:43:11 2016 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/JavaTypeInference.scala  |  25 +-
 .../spark/sql/catalyst/ScalaReflection.scala    | 176 +++++++------
 .../sql/catalyst/analysis/CheckAnalysis.scala   |   2 +-
 .../apache/spark/sql/catalyst/dsl/package.scala |   3 +-
 .../catalyst/encoders/ExpressionEncoder.scala   |  39 ++-
 .../sql/catalyst/encoders/RowEncoder.scala      |  15 +-
 .../spark/sql/catalyst/expressions/Cast.scala   |  14 +
 .../expressions/aggregate/interfaces.scala      |   1 -
 .../expressions/complexTypeExtractors.scala     |  10 +-
 .../sql/catalyst/expressions/literals.scala     |  41 +++
 .../catalyst/expressions/namedExpressions.scala |   4 +
 .../sql/catalyst/expressions/objects.scala      | 140 ++++++----
 .../spark/sql/catalyst/plans/QueryPlan.scala    |   2 +
 .../spark/sql/catalyst/trees/TreeNode.scala     | 258 ++++++++++++++++++-
 .../org/apache/spark/sql/types/DataType.scala   |   6 +-
 .../encoders/EncoderResolutionSuite.scala       | 101 ++++++--
 .../encoders/ExpressionEncoderSuite.scala       |   9 +
 .../sql/catalyst/encoders/RowEncoderSuite.scala |  15 +-
 .../catalyst/expressions/ComplexTypeSuite.scala |   2 +-
 .../scala/org/apache/spark/sql/DataFrame.scala  |  50 +---
 .../scala/org/apache/spark/sql/Dataset.scala    |  42 ++-
 .../spark/sql/execution/ExistingRDD.scala       |   4 +-
 .../apache/spark/sql/execution/Queryable.scala  |  65 +++++
 .../columnar/InMemoryColumnarTableScan.scala    |   6 +-
 .../org/apache/spark/sql/JavaDatasetSuite.java  | 126 ++++++++-
 .../org/apache/spark/sql/DataFrameSuite.scala   |  15 ++
 .../org/apache/spark/sql/DatasetSuite.scala     |  68 ++++-
 .../scala/org/apache/spark/sql/QueryTest.scala  | 102 +++++++-
 .../apache/spark/sql/UserDefinedTypeSuite.scala |   5 +
 .../spark/sql/hive/HiveMetastoreCatalog.scala   |   2 +
 .../hive/execution/ScriptTransformation.scala   |   2 +-
 31 files changed, 1086 insertions(+), 264 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index f566d1b..ed153d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -178,19 +178,19 @@ object JavaTypeInference {
       case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath
 
       case c if c == classOf[java.lang.Short] =>
-        NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+        NewInstance(c, getPath :: Nil, ObjectType(c))
       case c if c == classOf[java.lang.Integer] =>
-        NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+        NewInstance(c, getPath :: Nil, ObjectType(c))
       case c if c == classOf[java.lang.Long] =>
-        NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+        NewInstance(c, getPath :: Nil, ObjectType(c))
       case c if c == classOf[java.lang.Double] =>
-        NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+        NewInstance(c, getPath :: Nil, ObjectType(c))
       case c if c == classOf[java.lang.Byte] =>
-        NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+        NewInstance(c, getPath :: Nil, ObjectType(c))
       case c if c == classOf[java.lang.Float] =>
-        NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+        NewInstance(c, getPath :: Nil, ObjectType(c))
       case c if c == classOf[java.lang.Boolean] =>
-        NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+        NewInstance(c, getPath :: Nil, ObjectType(c))
 
       case c if c == classOf[java.sql.Date] =>
         StaticInvoke(
@@ -288,10 +288,17 @@ object JavaTypeInference {
         val setters = properties.map { p =>
           val fieldName = p.getName
           val fieldType = typeToken.method(p.getReadMethod).getReturnType
-          p.getWriteMethod.getName -> constructorFor(fieldType, Some(addToPath(fieldName)))
+          val (_, nullable) = inferDataType(fieldType)
+          val constructor = constructorFor(fieldType, Some(addToPath(fieldName)))
+          val setter = if (nullable) {
+            constructor
+          } else {
+            AssertNotNull(constructor, other.getName, fieldName, fieldType.toString)
+          }
+          p.getWriteMethod.getName -> setter
         }.toMap
 
-        val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other))
+        val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false)
         val result = InitializeJavaBean(newInstance, setters)
 
         if (path.nonEmpty) {

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index ea98956..b0efdf3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -68,7 +68,7 @@ object ScalaReflection extends ScalaReflection {
             val TypeRef(_, _, Seq(elementType)) = tpe
             arrayClassFor(elementType)
           case other =>
-            val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
+            val clazz = getClassFromType(tpe)
             ObjectType(clazz)
         }
     }
@@ -177,6 +177,7 @@ object ScalaReflection extends ScalaReflection {
       case _ => UpCast(expr, expected, walkedTypePath)
     }
 
+    val className = getClassNameFromType(tpe)
     tpe match {
       case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
 
@@ -184,42 +185,42 @@ object ScalaReflection extends ScalaReflection {
         val TypeRef(_, _, Seq(optType)) = t
         val className = getClassNameFromType(optType)
         val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath
-        WrapOption(constructorFor(optType, path, newTypePath))
+        WrapOption(constructorFor(optType, path, newTypePath), dataTypeFor(optType))
 
       case t if t <:< localTypeOf[java.lang.Integer] =>
         val boxedType = classOf[java.lang.Integer]
         val objectType = ObjectType(boxedType)
-        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+        NewInstance(boxedType, getPath :: Nil, objectType)
 
       case t if t <:< localTypeOf[java.lang.Long] =>
         val boxedType = classOf[java.lang.Long]
         val objectType = ObjectType(boxedType)
-        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+        NewInstance(boxedType, getPath :: Nil, objectType)
 
       case t if t <:< localTypeOf[java.lang.Double] =>
         val boxedType = classOf[java.lang.Double]
         val objectType = ObjectType(boxedType)
-        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+        NewInstance(boxedType, getPath :: Nil, objectType)
 
       case t if t <:< localTypeOf[java.lang.Float] =>
         val boxedType = classOf[java.lang.Float]
         val objectType = ObjectType(boxedType)
-        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+        NewInstance(boxedType, getPath :: Nil, objectType)
 
       case t if t <:< localTypeOf[java.lang.Short] =>
         val boxedType = classOf[java.lang.Short]
         val objectType = ObjectType(boxedType)
-        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+        NewInstance(boxedType, getPath :: Nil, objectType)
 
       case t if t <:< localTypeOf[java.lang.Byte] =>
         val boxedType = classOf[java.lang.Byte]
         val objectType = ObjectType(boxedType)
-        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+        NewInstance(boxedType, getPath :: Nil, objectType)
 
       case t if t <:< localTypeOf[java.lang.Boolean] =>
         val boxedType = classOf[java.lang.Boolean]
         val objectType = ObjectType(boxedType)
-        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+        NewInstance(boxedType, getPath :: Nil, objectType)
 
       case t if t <:< localTypeOf[java.sql.Date] =>
         StaticInvoke(
@@ -321,30 +322,12 @@ object ScalaReflection extends ScalaReflection {
           keyData :: valueData :: Nil)
 
       case t if t <:< localTypeOf[Product] =>
-        val formalTypeArgs = t.typeSymbol.asClass.typeParams
-        val TypeRef(_, _, actualTypeArgs) = t
-        val constructorSymbol = t.member(nme.CONSTRUCTOR)
-        val params = if (constructorSymbol.isMethod) {
-          constructorSymbol.asMethod.paramss
-        } else {
-          // Find the primary constructor, and use its parameter ordering.
-          val primaryConstructorSymbol: Option[Symbol] =
-            constructorSymbol.asTerm.alternatives.find(s =>
-              s.isMethod && s.asMethod.isPrimaryConstructor)
-
-          if (primaryConstructorSymbol.isEmpty) {
-            sys.error("Internal SQL error: Product object did not have a primary constructor.")
-          } else {
-            primaryConstructorSymbol.get.asMethod.paramss
-          }
-        }
+        val params = getConstructorParameters(t)
 
-        val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
+        val cls = getClassFromType(tpe)
 
-        val arguments = params.head.zipWithIndex.map { case (p, i) =>
-          val fieldName = p.name.toString
-          val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
-          val dataType = schemaFor(fieldType).dataType
+        val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
+          val Schema(dataType, nullable) = schemaFor(fieldType)
           val clsName = getClassNameFromType(fieldType)
           val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
           // For tuples, we based grab the inner fields by ordinal instead of name.
@@ -354,14 +337,20 @@ object ScalaReflection extends ScalaReflection {
               Some(addToPathOrdinal(i, dataType, newTypePath)),
               newTypePath)
           } else {
-            constructorFor(
+            val constructor = constructorFor(
               fieldType,
               Some(addToPath(fieldName, dataType, newTypePath)),
               newTypePath)
+
+            if (!nullable) {
+              AssertNotNull(constructor, t.toString, fieldName, fieldType.toString)
+            } else {
+              constructor
+            }
           }
         }
 
-        val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls))
+        val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
 
         if (path.nonEmpty) {
           expressions.If(
@@ -372,6 +361,16 @@ object ScalaReflection extends ScalaReflection {
         } else {
           newInstance
         }
+
+      case t if Utils.classIsLoadable(className) &&
+        Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+        val udt = Utils.classForName(className)
+          .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+        val obj = NewInstance(
+          udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+          Nil,
+          dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+        Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
     }
   }
 
@@ -392,7 +391,7 @@ object ScalaReflection extends ScalaReflection {
     val clsName = getClassNameFromType(tpe)
     val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
     extractorFor(inputObject, tpe, walkedTypePath) match {
-      case s: CreateNamedStruct => s
+      case expressions.If(_, _, s: CreateNamedStruct) if tpe <:< localTypeOf[Product] => s
       case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
     }
   }
@@ -406,7 +405,7 @@ object ScalaReflection extends ScalaReflection {
     def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
       val externalDataType = dataTypeFor(elementType)
       val Schema(catalystType, nullable) = silentSchemaFor(elementType)
-      if (isNativeType(catalystType)) {
+      if (isNativeType(externalDataType)) {
         NewInstance(
           classOf[GenericArrayData],
           input :: Nil,
@@ -414,10 +413,6 @@ object ScalaReflection extends ScalaReflection {
       } else {
         val clsName = getClassNameFromType(elementType)
         val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
-        // `MapObjects` will run `extractorFor` lazily, we need to eagerly call `extractorFor` here
-        // to trigger the type check.
-        extractorFor(inputObject, elementType, newPath)
-
         MapObjects(extractorFor(_, elementType, newPath), input, externalDataType)
       }
     }
@@ -425,6 +420,7 @@ object ScalaReflection extends ScalaReflection {
     if (!inputObject.dataType.isInstanceOf[ObjectType]) {
       inputObject
     } else {
+      val className = getClassNameFromType(tpe)
       tpe match {
         case t if t <:< localTypeOf[Option[_]] =>
           val TypeRef(_, _, Seq(optType)) = t
@@ -481,33 +477,15 @@ object ScalaReflection extends ScalaReflection {
           }
 
         case t if t <:< localTypeOf[Product] =>
-          val formalTypeArgs = t.typeSymbol.asClass.typeParams
-          val TypeRef(_, _, actualTypeArgs) = t
-          val constructorSymbol = t.member(nme.CONSTRUCTOR)
-          val params = if (constructorSymbol.isMethod) {
-            constructorSymbol.asMethod.paramss
-          } else {
-            // Find the primary constructor, and use its parameter ordering.
-            val primaryConstructorSymbol: Option[Symbol] =
-              constructorSymbol.asTerm.alternatives.find(s =>
-                s.isMethod && s.asMethod.isPrimaryConstructor)
-
-            if (primaryConstructorSymbol.isEmpty) {
-              sys.error("Internal SQL error: Product object did not have a primary constructor.")
-            } else {
-              primaryConstructorSymbol.get.asMethod.paramss
-            }
-          }
-
-          CreateNamedStruct(params.head.flatMap { p =>
-            val fieldName = p.name.toString
-            val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+          val params = getConstructorParameters(t)
+          val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
             val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
             val clsName = getClassNameFromType(fieldType)
             val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
-
             expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil
           })
+          val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+          expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
 
         case t if t <:< localTypeOf[Array[_]] =>
           val TypeRef(_, _, Seq(elementType)) = t
@@ -593,12 +571,37 @@ object ScalaReflection extends ScalaReflection {
         case t if t <:< localTypeOf[java.lang.Boolean] =>
           Invoke(inputObject, "booleanValue", BooleanType)
 
+        case t if Utils.classIsLoadable(className) &&
+          Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+          val udt = Utils.classForName(className)
+            .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+          val obj = NewInstance(
+            udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+            Nil,
+            dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+          Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
+
         case other =>
           throw new UnsupportedOperationException(
             s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
       }
     }
   }
+
+  /**
+   * Returns the parameter names and types for the primary constructor of this class.
+   *
+   * Note that it only works for scala classes with primary constructor, and currently doesn't
+   * support inner class.
+   */
+  def getConstructorParameters(cls: Class[_]): Seq[(String, Type)] = {
+    val m = runtimeMirror(cls.getClassLoader)
+    val classSymbol = m.staticClass(cls.getName)
+    val t = classSymbol.selfType
+    getConstructorParameters(t)
+  }
+
+  def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
 }
 
 /**
@@ -672,26 +675,11 @@ trait ScalaReflection {
         Schema(MapType(schemaFor(keyType).dataType,
           valueDataType, valueContainsNull = valueNullable), nullable = true)
       case t if t <:< localTypeOf[Product] =>
-        val formalTypeArgs = t.typeSymbol.asClass.typeParams
-        val TypeRef(_, _, actualTypeArgs) = t
-        val constructorSymbol = t.member(nme.CONSTRUCTOR)
-        val params = if (constructorSymbol.isMethod) {
-          constructorSymbol.asMethod.paramss
-        } else {
-          // Find the primary constructor, and use its parameter ordering.
-          val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find(
-            s => s.isMethod && s.asMethod.isPrimaryConstructor)
-          if (primaryConstructorSymbol.isEmpty) {
-            sys.error("Internal SQL error: Product object did not have a primary constructor.")
-          } else {
-            primaryConstructorSymbol.get.asMethod.paramss
-          }
-        }
+        val params = getConstructorParameters(t)
         Schema(StructType(
-          params.head.map { p =>
-            val Schema(dataType, nullable) =
-              schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
-            StructField(p.name.toString, dataType, nullable)
+          params.map { case (fieldName, fieldType) =>
+            val Schema(dataType, nullable) = schemaFor(fieldType)
+            StructField(fieldName, dataType, nullable)
           }), nullable = true)
       case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
       case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
@@ -744,4 +732,32 @@ trait ScalaReflection {
     assert(methods.length == 1)
     methods.head.getParameterTypes
   }
+
+  /**
+   * Returns the parameter names and types for the primary constructor of this type.
+   *
+   * Note that it only works for scala classes with primary constructor, and currently doesn't
+   * support inner class.
+   */
+  def getConstructorParameters(tpe: Type): Seq[(String, Type)] = {
+    val formalTypeArgs = tpe.typeSymbol.asClass.typeParams
+    val TypeRef(_, _, actualTypeArgs) = tpe
+    val constructorSymbol = tpe.member(nme.CONSTRUCTOR)
+    val params = if (constructorSymbol.isMethod) {
+      constructorSymbol.asMethod.paramss
+    } else {
+      // Find the primary constructor, and use its parameter ordering.
+      val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find(
+        s => s.isMethod && s.asMethod.isPrimaryConstructor)
+      if (primaryConstructorSymbol.isEmpty) {
+        sys.error("Internal SQL error: Product object did not have a primary constructor.")
+      } else {
+        primaryConstructorSymbol.get.asMethod.paramss
+      }
+    }
+
+    params.flatten.map { p =>
+      p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 7b2c93d..8c19a85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -57,7 +57,7 @@ trait CheckAnalysis {
         operator transformExpressionsUp {
           case a: Attribute if !a.resolved =>
             val from = operator.inputSet.map(_.name).mkString(", ")
-            a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")
+            a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns: [$from]")
 
           case e: Expression if e.checkInputDataTypes().isFailure =>
             e.checkInputDataTypes() match {

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index e509711..8102c93 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -227,9 +227,10 @@ package object dsl {
         AttributeReference(s, mapType, nullable = true)()
 
       /** Creates a new AttributeReference of type struct */
-      def struct(fields: StructField*): AttributeReference = struct(StructType(fields))
       def struct(structType: StructType): AttributeReference =
         AttributeReference(s, structType, nullable = true)()
+      def struct(attrs: AttributeReference*): AttributeReference =
+        struct(StructType.fromAttributes(attrs))
     }
 
     implicit class DslAttribute(a: AttributeReference) {

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 0f6dc2c..6c05846 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -133,7 +133,7 @@ object ExpressionEncoder {
     }
 
     val fromRowExpression =
-      NewInstance(cls, fromRowExpressions, propagateNull = false, ObjectType(cls))
+      NewInstance(cls, fromRowExpressions, ObjectType(cls), propagateNull = false)
 
     new ExpressionEncoder[Any](
       schema,
@@ -253,13 +253,46 @@ case class ExpressionEncoder[T](
   def resolve(
       schema: Seq[Attribute],
       outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
-    val positionToAttribute = AttributeMap.toIndex(schema)
+    def fail(st: StructType, maxOrdinal: Int): Unit = {
+      throw new AnalysisException(s"Try to map ${st.simpleString} to Tuple${maxOrdinal + 1}, " +
+        "but failed as the number of fields does not line up.\n" +
+        " - Input schema: " + StructType.fromAttributes(schema).simpleString + "\n" +
+        " - Target schema: " + this.schema.simpleString)
+    }
+
+    var maxOrdinal = -1
+    fromRowExpression.foreach {
+      case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal
+      case _ =>
+    }
+    if (maxOrdinal >= 0 && maxOrdinal != schema.length - 1) {
+      fail(StructType.fromAttributes(schema), maxOrdinal)
+    }
+
     val unbound = fromRowExpression transform {
-      case b: BoundReference => positionToAttribute(b.ordinal)
+      case b: BoundReference => schema(b.ordinal)
+    }
+
+    val exprToMaxOrdinal = scala.collection.mutable.HashMap.empty[Expression, Int]
+    unbound.foreach {
+      case g: GetStructField =>
+        val maxOrdinal = exprToMaxOrdinal.getOrElse(g.child, -1)
+        if (maxOrdinal < g.ordinal) {
+          exprToMaxOrdinal.update(g.child, g.ordinal)
+        }
+      case _ =>
+    }
+    exprToMaxOrdinal.foreach {
+      case (expr, maxOrdinal) =>
+        val schema = expr.dataType.asInstanceOf[StructType]
+        if (maxOrdinal != schema.length - 1) {
+          fail(schema, maxOrdinal)
+        }
     }
 
     val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
     val analyzedPlan = SimpleAnalyzer.execute(plan)
+    SimpleAnalyzer.checkAnalysis(analyzedPlan)
     val optimizedPlan = SimplifyCasts(analyzedPlan)
 
     // In order to construct instances of inner classes (for example those declared in a REPL cell),

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 6c1220a..3903086 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -35,7 +35,8 @@ object RowEncoder {
   def apply(schema: StructType): ExpressionEncoder[Row] = {
     val cls = classOf[Row]
     val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
-    val extractExpressions = extractorsFor(inputObject, schema)
+    // We use an If expression to wrap extractorsFor result of StructType
+    val extractExpressions = extractorsFor(inputObject, schema).asInstanceOf[If].falseValue
     val constructExpression = constructorFor(schema)
     new ExpressionEncoder[Row](
       schema,
@@ -55,7 +56,6 @@ object RowEncoder {
       val obj = NewInstance(
         udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
         Nil,
-        false,
         dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
       Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
 
@@ -130,7 +130,9 @@ object RowEncoder {
             Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil),
             f.dataType))
       }
-      CreateStruct(convertedFields)
+      If(IsNull(inputObject),
+        Literal.create(null, inputType),
+        CreateStruct(convertedFields))
   }
 
   private def externalDataTypeFor(dt: DataType): DataType = dt match {
@@ -166,7 +168,6 @@ object RowEncoder {
       val obj = NewInstance(
         udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
         Nil,
-        false,
         dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
       Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)
 
@@ -193,7 +194,7 @@ object RowEncoder {
     case ArrayType(et, nullable) =>
       val arrayData =
         Invoke(
-          MapObjects(constructorFor, input, et),
+          MapObjects(constructorFor(_), input, et),
           "array",
           ObjectType(classOf[Array[_]]))
       StaticInvoke(
@@ -222,6 +223,8 @@ object RowEncoder {
           Literal.create(null, externalDataTypeFor(f.dataType)),
           constructorFor(GetStructField(input, i)))
       }
-      CreateExternalRow(convertedFields)
+      If(IsNull(input),
+        Literal.create(null, externalDataTypeFor(input.dataType)),
+        CreateExternalRow(convertedFields))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 79e0438..c53e84d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.math.{BigDecimal => JavaBigDecimal}
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -81,6 +82,9 @@ object Cast {
                 toField.nullable)
         }
 
+    case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt1.userClass == udt2.userClass =>
+      true
+
     case _ => false
   }
 
@@ -428,6 +432,11 @@ case class Cast(child: Expression, dataType: DataType)
     case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
     case map: MapType => castMap(from.asInstanceOf[MapType], map)
     case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
+    case udt: UserDefinedType[_]
+      if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
+      identity[Any]
+    case _: UserDefinedType[_] =>
+      throw new SparkException(s"Cannot cast $from to $to.")
   }
 
   private[this] lazy val cast: Any => Any = cast(child.dataType, dataType)
@@ -470,6 +479,11 @@ case class Cast(child: Expression, dataType: DataType)
       castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx)
     case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx)
     case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx)
+    case udt: UserDefinedType[_]
+      if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
+      (c, evPrim, evNull) => s"$evPrim = $c;"
+    case _: UserDefinedType[_] =>
+      throw new SparkException(s"Cannot cast $from to $to.")
   }
 
   // Since we need to cast child expressions recursively inside ComplexTypes, such as Map's

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 3b441de..e6fd726 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
 import org.apache.spark.sql.catalyst.InternalRow

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 31520f5..1f79683 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -104,14 +104,14 @@ object ExtractValue {
 case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None)
   extends UnaryExpression {
 
-  private lazy val field = child.dataType.asInstanceOf[StructType](ordinal)
+  private[sql] lazy val childSchema = child.dataType.asInstanceOf[StructType]
 
-  override def dataType: DataType = field.dataType
-  override def nullable: Boolean = child.nullable || field.nullable
-  override def toString: String = s"$child.${name.getOrElse(field.name)}"
+  override def dataType: DataType = childSchema(ordinal).dataType
+  override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable
+  override def toString: String = s"$child.${name.getOrElse(childSchema(ordinal).name)}"
 
   protected override def nullSafeEval(input: Any): Any =
-    input.asInstanceOf[InternalRow].get(ordinal, field.dataType)
+    input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType)
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     nullSafeCodeGen(ctx, ev, eval => {

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 68ec688..e3573b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.json4s.JsonAST._
 import java.sql.{Date, Timestamp}
 
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -55,6 +56,34 @@ object Literal {
    */
   def fromObject(obj: AnyRef): Literal = new Literal(obj, ObjectType(obj.getClass))
 
+  def fromJSON(json: JValue): Literal = {
+    val dataType = DataType.parseDataType(json \ "dataType")
+    json \ "value" match {
+      case JNull => Literal.create(null, dataType)
+      case JString(str) =>
+        val value = dataType match {
+          case BooleanType => str.toBoolean
+          case ByteType => str.toByte
+          case ShortType => str.toShort
+          case IntegerType => str.toInt
+          case LongType => str.toLong
+          case FloatType => str.toFloat
+          case DoubleType => str.toDouble
+          case StringType => UTF8String.fromString(str)
+          case DateType => java.sql.Date.valueOf(str)
+          case TimestampType => java.sql.Timestamp.valueOf(str)
+          case CalendarIntervalType => CalendarInterval.fromString(str)
+          case t: DecimalType =>
+            val d = Decimal(str)
+            assert(d.changePrecision(t.precision, t.scale))
+            d
+          case _ => null
+        }
+        Literal.create(value, dataType)
+      case other => sys.error(s"$other is not a valid Literal json value")
+    }
+  }
+
   def create(v: Any, dataType: DataType): Literal = {
     Literal(CatalystTypeConverters.convertToCatalyst(v), dataType)
   }
@@ -123,6 +152,18 @@ case class Literal protected (value: Any, dataType: DataType)
     case _ => false
   }
 
+  override protected def jsonFields: List[JField] = {
+    // Turns all kinds of literal values to string in json field, as the type info is hard to
+    // retain in json format, e.g. {"a": 123} can be a int, or double, or decimal, etc.
+    val jsonValue = (value, dataType) match {
+      case (null, _) => JNull
+      case (i: Int, DateType) => JString(DateTimeUtils.toJavaDate(i).toString)
+      case (l: Long, TimestampType) => JString(DateTimeUtils.toJavaTimestamp(l).toString)
+      case (other, _) => JString(other.toString)
+    }
+    ("value" -> jsonValue) :: ("dataType" -> dataType.jsonValue) :: Nil
+  }
+
   override def eval(input: InternalRow): Any = value
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 26b6aca..eefd9c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -262,6 +262,10 @@ case class AttributeReference(
     }
   }
 
+  override protected final def otherCopyArgs: Seq[AnyRef] = {
+    exprId :: qualifiers :: Nil
+  }
+
   override def toString: String = s"$name#${exprId.id}$typeSuffix"
 
   // Since the expression id is not in the first constructor it is missing from the default

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 749936c..c0c3e6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -23,11 +23,9 @@ import scala.reflect.ClassTag
 import org.apache.spark.SparkConf
 import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
-import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}
-import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
+import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.types._
 
 /**
@@ -167,7 +165,7 @@ case class Invoke(
       ${obj.code}
       ${argGen.map(_.code).mkString("\n")}
 
-      boolean ${ev.isNull} = ${obj.value} == null;
+      boolean ${ev.isNull} = ${obj.isNull};
       $javaType ${ev.value} =
         ${ev.isNull} ?
         ${ctx.defaultValue(dataType)} : ($javaType) $value;
@@ -180,8 +178,8 @@ object NewInstance {
   def apply(
       cls: Class[_],
       arguments: Seq[Expression],
-      propagateNull: Boolean = false,
-      dataType: DataType): NewInstance =
+      dataType: DataType,
+      propagateNull: Boolean = true): NewInstance =
     new NewInstance(cls, arguments, propagateNull, dataType, None)
 }
 
@@ -233,7 +231,7 @@ case class NewInstance(
       s"new $className($argString)"
     }
 
-    if (propagateNull) {
+    if (propagateNull && argGen.nonEmpty) {
       val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"
 
       s"""
@@ -250,8 +248,8 @@ case class NewInstance(
       s"""
         $setup
 
-        $javaType ${ev.value} = $constructorCall;
-        final boolean ${ev.isNull} = ${ev.value} == null;
+        final $javaType ${ev.value} = $constructorCall;
+        final boolean ${ev.isNull} = false;
       """
     }
   }
@@ -293,15 +291,16 @@ case class UnwrapOption(
  * Converts the result of evaluating `child` into an option, checking both the isNull bit and
  * (in the case of reference types) equality with null.
  * @param child The expression to evaluate and wrap.
+ * @param optType The type of this option.
  */
-case class WrapOption(child: Expression)
+case class WrapOption(child: Expression, optType: DataType)
   extends UnaryExpression with ExpectsInputTypes {
 
   override def dataType: DataType = ObjectType(classOf[Option[_]])
 
   override def nullable: Boolean = true
 
-  override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil
+  override def inputTypes: Seq[AbstractDataType] = optType :: Nil
 
   override def eval(input: InternalRow): Any =
     throw new UnsupportedOperationException("Only code-generated evaluation is supported")
@@ -324,19 +323,28 @@ case class WrapOption(child: Expression)
  * A place holder for the loop variable used in [[MapObjects]].  This should never be constructed
  * manually, but will instead be passed into the provided lambda function.
  */
-case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends Expression {
+case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends LeafExpression
+  with Unevaluable {
 
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
-    throw new UnsupportedOperationException("Only calling gen() is supported.")
+  override def nullable: Boolean = true
 
-  override def children: Seq[Expression] = Nil
-  override def gen(ctx: CodeGenContext): GeneratedExpressionCode =
+  override def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
     GeneratedExpressionCode(code = "", value = value, isNull = isNull)
+  }
+}
 
-  override def nullable: Boolean = false
-  override def eval(input: InternalRow): Any =
-    throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+object MapObjects {
+  private val curId = new java.util.concurrent.atomic.AtomicInteger()
 
+  def apply(
+      function: Expression => Expression,
+      inputData: Expression,
+      elementType: DataType): MapObjects = {
+    val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
+    val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
+    val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
+    MapObjects(loopVar, function(loopVar), inputData)
+  }
 }
 
 /**
@@ -347,20 +355,16 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext
  * The following collection ObjectTypes are currently supported:
  *   Seq, Array, ArrayData, java.util.List
  *
- * @param function A function that returns an expression, given an attribute that can be used
- *                 to access the current value.  This is does as a lambda function so that
- *                 a unique attribute reference can be provided for each expression (thus allowing
- *                 us to nest multiple MapObject calls).
+ * @param loopVar A place holder that used as the loop variable when iterate the collection, and
+ *                used as input for the `lambdaFunction`. It also carries the element type info.
+ * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function
+ *                       to handle collection elements.
  * @param inputData An expression that when evaluted returns a collection object.
- * @param elementType The type of element in the collection, expressed as a DataType.
  */
 case class MapObjects(
-    function: AttributeReference => Expression,
-    inputData: Expression,
-    elementType: DataType) extends Expression {
-
-  private lazy val loopAttribute = AttributeReference("loopVar", elementType)()
-  private lazy val completeFunction = function(loopAttribute)
+    loopVar: LambdaVariable,
+    lambdaFunction: Expression,
+    inputData: Expression) extends Expression {
 
   private def itemAccessorMethod(dataType: DataType): String => String = dataType match {
     case NullType =>
@@ -400,37 +404,23 @@ case class MapObjects(
 
   override def nullable: Boolean = true
 
-  override def children: Seq[Expression] = completeFunction :: inputData :: Nil
+  override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil
 
   override def eval(input: InternalRow): Any =
     throw new UnsupportedOperationException("Only code-generated evaluation is supported")
 
-  override def dataType: DataType = ArrayType(completeFunction.dataType)
+  override def dataType: DataType = ArrayType(lambdaFunction.dataType)
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     val javaType = ctx.javaType(dataType)
-    val elementJavaType = ctx.javaType(elementType)
+    val elementJavaType = ctx.javaType(loopVar.dataType)
     val genInputData = inputData.gen(ctx)
-
-    // Variables to hold the element that is currently being processed.
-    val loopValue = ctx.freshName("loopValue")
-    val loopIsNull = ctx.freshName("loopIsNull")
-
-    val loopVariable = LambdaVariable(loopValue, loopIsNull, elementType)
-    val substitutedFunction = completeFunction transform {
-      case a: AttributeReference if a == loopAttribute => loopVariable
-    }
-    // A hack to run this through the analyzer (to bind extractions).
-    val boundFunction =
-      SimpleAnalyzer.execute(Project(Alias(substitutedFunction, "")() :: Nil, LocalRelation(Nil)))
-        .expressions.head.children.head
-
-    val genFunction = boundFunction.gen(ctx)
+    val genFunction = lambdaFunction.gen(ctx)
     val dataLength = ctx.freshName("dataLength")
     val convertedArray = ctx.freshName("convertedArray")
     val loopIndex = ctx.freshName("loopIndex")
 
-    val convertedType = ctx.boxedType(boundFunction.dataType)
+    val convertedType = ctx.boxedType(lambdaFunction.dataType)
 
     // Because of the way Java defines nested arrays, we have to handle the syntax specially.
     // Specifically, we have to insert the [$dataLength] in between the type and any extra nested
@@ -444,9 +434,9 @@ case class MapObjects(
     }
 
     val loopNullCheck = if (primitiveElement) {
-      s"boolean $loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
+      s"boolean ${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
     } else {
-      s"boolean $loopIsNull = ${genInputData.isNull} || $loopValue == null;"
+      s"boolean ${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;"
     }
 
     s"""
@@ -462,14 +452,14 @@ case class MapObjects(
 
         int $loopIndex = 0;
         while ($loopIndex < $dataLength) {
-          $elementJavaType $loopValue =
+          $elementJavaType ${loopVar.value} =
             ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)};
           $loopNullCheck
 
-          if ($loopIsNull) {
+          ${genFunction.code}
+          if (${genFunction.isNull}) {
             $convertedArray[$loopIndex] = null;
           } else {
-            ${genFunction.code}
             $convertedArray[$loopIndex] = ${genFunction.value};
           }
 
@@ -634,3 +624,43 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
      """
   }
 }
+
+/**
+ * Asserts that input values of a non-nullable child expression are not null.
+ *
+ * Note that there are cases where `child.nullable == true`, while we still need to add this
+ * assertion.  Consider a nullable column `s` whose data type is a struct containing a non-nullable
+ * `Int` field named `i`.  Expression `s.i` is nullable because `s` can be null.  However, for all
+ * non-null `s`, `s.i` can't be null.
+ */
+case class AssertNotNull(
+    child: Expression, parentType: String, fieldName: String, fieldType: String)
+  extends UnaryExpression {
+
+  override def dataType: DataType = child.dataType
+
+  override def nullable: Boolean = false
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val childGen = child.gen(ctx)
+
+    ev.isNull = "false"
+    ev.value = childGen.value
+
+    s"""
+      ${childGen.code}
+
+      if (${childGen.isNull}) {
+        throw new RuntimeException(
+          "Null value appeared in non-nullable field $parentType.$fieldName of type $fieldType. " +
+          "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
+          "please try to use scala.Option[_] or other nullable types " +
+          "(e.g. java.lang.Integer instead of int/scala.Int)."
+        );
+      }
+     """
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index b9db783..d262644 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -88,6 +88,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
       case d: DataType => d // Avoid unpacking Structs
       case seq: Traversable[_] => seq.map(recursiveTransform)
       case other: AnyRef => other
+      case null => null
     }
 
     val newArgs = productIterator.map(recursiveTransform).toArray
@@ -120,6 +121,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
       case d: DataType => d // Avoid unpacking Structs
       case seq: Traversable[_] => seq.map(recursiveTransform)
       case other: AnyRef => other
+      case null => null
     }
 
     val newArgs = productIterator.map(recursiveTransform).toArray

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index d838d84..c97dc2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -17,9 +17,25 @@
 
 package org.apache.spark.sql.catalyst.trees
 
+import java.util.UUID
 import scala.collection.Map
-
+import scala.collection.mutable.Stack
+import org.json4s.JsonAST._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
+import org.apache.spark.util.Utils
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.rdd.{EmptyRDD, RDD}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.ScalaReflection._
+import org.apache.spark.sql.catalyst.{TableIdentifier, ScalaReflectionLock}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types.{StructType, DataType}
 
 /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */
@@ -463,4 +479,244 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     }
     s"$nodeName(${args.mkString(",")})"
   }
+
+  def toJSON: String = compact(render(jsonValue))
+
+  def prettyJson: String = pretty(render(jsonValue))
+
+  private def jsonValue: JValue = {
+    val jsonValues = scala.collection.mutable.ArrayBuffer.empty[JValue]
+
+    def collectJsonValue(tn: BaseType): Unit = {
+      val jsonFields = ("class" -> JString(tn.getClass.getName)) ::
+        ("num-children" -> JInt(tn.children.length)) :: tn.jsonFields
+      jsonValues += JObject(jsonFields)
+      tn.children.foreach(collectJsonValue)
+    }
+
+    collectJsonValue(this)
+    jsonValues
+  }
+
+  protected def jsonFields: List[JField] = {
+    val fieldNames = getConstructorParameters(getClass).map(_._1)
+    val fieldValues = productIterator.toSeq ++ otherCopyArgs
+    assert(fieldNames.length == fieldValues.length, s"${getClass.getSimpleName} fields: " +
+      fieldNames.mkString(", ") + s", values: " + fieldValues.map(_.toString).mkString(", "))
+
+    fieldNames.zip(fieldValues).map {
+      // If the field value is a child, encode it as an int: the index of this child among all
+      // children.
+      case (name, value: TreeNode[_]) if containsChild(value) =>
+        name -> JInt(children.indexOf(value))
+      case (name, value: Seq[BaseType]) if value.toSet.subsetOf(containsChild) =>
+        name -> JArray(
+          value.map(v => JInt(children.indexOf(v.asInstanceOf[TreeNode[_]]))).toList
+        )
+      case (name, value) => name -> parseToJson(value)
+    }.toList
+  }
+
+  private def parseToJson(obj: Any): JValue = obj match {
+    case b: Boolean => JBool(b)
+    case b: Byte => JInt(b.toInt)
+    case s: Short => JInt(s.toInt)
+    case i: Int => JInt(i)
+    case l: Long => JInt(l)
+    case f: Float => JDouble(f)
+    case d: Double => JDouble(d)
+    case b: BigInt => JInt(b)
+    case null => JNull
+    case s: String => JString(s)
+    case u: UUID => JString(u.toString)
+    case dt: DataType => dt.jsonValue
+    case m: Metadata => m.jsonValue
+    case s: StorageLevel =>
+      ("useDisk" -> s.useDisk) ~ ("useMemory" -> s.useMemory) ~ ("useOffHeap" -> s.useOffHeap) ~
+        ("deserialized" -> s.deserialized) ~ ("replication" -> s.replication)
+    case n: TreeNode[_] => n.jsonValue
+    case o: Option[_] => o.map(parseToJson)
+    case t: Seq[_] => JArray(t.map(parseToJson).toList)
+    case m: Map[_, _] =>
+      val fields = m.toList.map { case (k: String, v) => (k, parseToJson(v)) }
+      JObject(fields)
+    case r: RDD[_] => JNothing
+    // if it's a scala object, we can simply keep the full class path.
+    // TODO: currently if the class name ends with "$", we treat it as a scala object; there is
+    // probably a better way to check it.
+    case obj if obj.getClass.getName.endsWith("$") => "object" -> obj.getClass.getName
+    // returns null if the product type doesn't have a primary constructor, e.g. HiveFunctionWrapper
+    case p: Product => try {
+      val fieldNames = getConstructorParameters(p.getClass).map(_._1)
+      val fieldValues = p.productIterator.toSeq
+      assert(fieldNames.length == fieldValues.length)
+      ("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map {
+        case (name, value) => name -> parseToJson(value)
+      }.toList
+    } catch {
+      case _: RuntimeException => null
+    }
+    case _ => JNull
+  }
+}
+
+object TreeNode {
+  def fromJSON[BaseType <: TreeNode[BaseType]](json: String, sc: SparkContext): BaseType = {
+    val jsonAST = parse(json)
+    assert(jsonAST.isInstanceOf[JArray])
+    reconstruct(jsonAST.asInstanceOf[JArray], sc).asInstanceOf[BaseType]
+  }
+
+  private def reconstruct(treeNodeJson: JArray, sc: SparkContext): TreeNode[_] = {
+    assert(treeNodeJson.arr.forall(_.isInstanceOf[JObject]))
+    val jsonNodes = Stack(treeNodeJson.arr.map(_.asInstanceOf[JObject]): _*)
+
+    def parseNextNode(): TreeNode[_] = {
+      val nextNode = jsonNodes.pop()
+
+      val cls = Utils.classForName((nextNode \ "class").asInstanceOf[JString].s)
+      if (cls == classOf[Literal]) {
+        Literal.fromJSON(nextNode)
+      } else if (cls.getName.endsWith("$")) {
+        cls.getField("MODULE$").get(cls).asInstanceOf[TreeNode[_]]
+      } else {
+        val numChildren = (nextNode \ "num-children").asInstanceOf[JInt].num.toInt
+
+        val children: Seq[TreeNode[_]] = (1 to numChildren).map(_ => parseNextNode())
+        val fields = getConstructorParameters(cls)
+
+        val parameters: Array[AnyRef] = fields.map {
+          case (fieldName, fieldType) =>
+            parseFromJson(nextNode \ fieldName, fieldType, children, sc)
+        }.toArray
+
+        val maybeCtor = cls.getConstructors.find { p =>
+          val expectedTypes = p.getParameterTypes
+          expectedTypes.length == fields.length && expectedTypes.zip(fields.map(_._2)).forall {
+            case (cls, tpe) => cls == getClassFromType(tpe)
+          }
+        }
+        if (maybeCtor.isEmpty) {
+          sys.error(s"No valid constructor for ${cls.getName}")
+        } else {
+          try {
+            maybeCtor.get.newInstance(parameters: _*).asInstanceOf[TreeNode[_]]
+          } catch {
+            case e: java.lang.IllegalArgumentException =>
+              throw new RuntimeException(
+                s"""
+                  |Failed to construct tree node: ${cls.getName}
+                  |ctor: ${maybeCtor.get}
+                  |types: ${parameters.map(_.getClass).mkString(", ")}
+                  |args: ${parameters.mkString(", ")}
+                """.stripMargin, e)
+          }
+        }
+      }
+    }
+
+    parseNextNode()
+  }
+
+  import universe._
+
+  private def parseFromJson(
+      value: JValue,
+      expectedType: Type,
+      children: Seq[TreeNode[_]],
+      sc: SparkContext): AnyRef = ScalaReflectionLock.synchronized {
+    if (value == JNull) return null
+
+    expectedType match {
+      case t if t <:< definitions.BooleanTpe =>
+        value.asInstanceOf[JBool].value: java.lang.Boolean
+      case t if t <:< definitions.ByteTpe =>
+        value.asInstanceOf[JInt].num.toByte: java.lang.Byte
+      case t if t <:< definitions.ShortTpe =>
+        value.asInstanceOf[JInt].num.toShort: java.lang.Short
+      case t if t <:< definitions.IntTpe =>
+        value.asInstanceOf[JInt].num.toInt: java.lang.Integer
+      case t if t <:< definitions.LongTpe =>
+        value.asInstanceOf[JInt].num.toLong: java.lang.Long
+      case t if t <:< definitions.FloatTpe =>
+        value.asInstanceOf[JDouble].num.toFloat: java.lang.Float
+      case t if t <:< definitions.DoubleTpe =>
+        value.asInstanceOf[JDouble].num: java.lang.Double
+
+      case t if t <:< localTypeOf[BigInt] => value.asInstanceOf[JInt].num
+      case t if t <:< localTypeOf[java.lang.String] => value.asInstanceOf[JString].s
+      case t if t <:< localTypeOf[UUID] => UUID.fromString(value.asInstanceOf[JString].s)
+      case t if t <:< localTypeOf[DataType] => DataType.parseDataType(value)
+      case t if t <:< localTypeOf[Metadata] => Metadata.fromJObject(value.asInstanceOf[JObject])
+      case t if t <:< localTypeOf[StorageLevel] =>
+        val JBool(useDisk) = value \ "useDisk"
+        val JBool(useMemory) = value \ "useMemory"
+        val JBool(useOffHeap) = value \ "useOffHeap"
+        val JBool(deserialized) = value \ "deserialized"
+        val JInt(replication) = value \ "replication"
+        StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication.toInt)
+      case t if t <:< localTypeOf[TreeNode[_]] => value match {
+        case JInt(i) => children(i.toInt)
+        case arr: JArray => reconstruct(arr, sc)
+        case _ => throw new RuntimeException(s"$value is not a valid json value for tree node.")
+      }
+      case t if t <:< localTypeOf[Option[_]] =>
+        if (value == JNothing) {
+          None
+        } else {
+          val TypeRef(_, _, Seq(optType)) = t
+          Option(parseFromJson(value, optType, children, sc))
+        }
+      case t if t <:< localTypeOf[Seq[_]] =>
+        val TypeRef(_, _, Seq(elementType)) = t
+        val JArray(elements) = value
+        elements.map(parseFromJson(_, elementType, children, sc)).toSeq
+      case t if t <:< localTypeOf[Map[_, _]] =>
+        val TypeRef(_, _, Seq(keyType, valueType)) = t
+        val JObject(fields) = value
+        fields.map {
+          case (name, value) => name -> parseFromJson(value, valueType, children, sc)
+        }.toMap
+      case t if t <:< localTypeOf[RDD[_]] =>
+        new EmptyRDD[Any](sc)
+      case _ if isScalaObject(value) =>
+        val JString(clsName) = value \ "object"
+        val cls = Utils.classForName(clsName)
+        cls.getField("MODULE$").get(cls)
+      case t if t <:< localTypeOf[Product] =>
+        val fields = getConstructorParameters(t)
+        val clsName = getClassNameFromType(t)
+        parseToProduct(clsName, fields, value, children, sc)
+      // In some cases the parameter type signature is not Product but the value is,
+      // e.g. `SpecifiedWindowFrame` with type signature `WindowFrame`, handle it here.
+      case _ if isScalaProduct(value) =>
+        val JString(clsName) = value \ "product-class"
+        val fields = getConstructorParameters(Utils.classForName(clsName))
+        parseToProduct(clsName, fields, value, children, sc)
+      case _ => sys.error(s"Do not support type $expectedType with json $value.")
+    }
+  }
+
+  private def parseToProduct(
+      clsName: String,
+      fields: Seq[(String, Type)],
+      value: JValue,
+      children: Seq[TreeNode[_]],
+      sc: SparkContext): AnyRef = {
+    val parameters: Array[AnyRef] = fields.map {
+      case (fieldName, fieldType) => parseFromJson(value \ fieldName, fieldType, children, sc)
+    }.toArray
+    val ctor = Utils.classForName(clsName).getConstructors.maxBy(_.getParameterTypes.size)
+    ctor.newInstance(parameters: _*).asInstanceOf[AnyRef]
+  }
+
+  private def isScalaObject(jValue: JValue): Boolean = (jValue \ "object") match {
+    case JString(str) if str.endsWith("$") => true
+    case _ => false
+  }
+
+  private def isScalaProduct(jValue: JValue): Boolean = (jValue \ "product-class") match {
+    case _: JString => true
+    case _ => false
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 4b54c31..34ce6be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -104,8 +104,8 @@ object DataType {
   def fromCaseClassString(string: String): DataType = CaseClassStringParser(string)
 
   private val nonDecimalNameToType = {
-    Seq(NullType, DateType, TimestampType, BinaryType,
-      IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
+    Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType,
+      DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType)
       .map(t => t.typeName -> t).toMap
   }
 
@@ -127,7 +127,7 @@ object DataType {
   }
 
   // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side.
-  private def parseDataType(json: JValue): DataType = json match {
+  private[sql] def parseDataType(json: JValue): DataType = json match {
     case JString(name) =>
       nameToType(name)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 0289988..bc36a55 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -36,14 +36,18 @@ class EncoderResolutionSuite extends PlanTest {
     val encoder = ExpressionEncoder[StringLongClass]
     val cls = classOf[StringLongClass]
 
+
     {
       val attrs = Seq('a.string, 'b.int)
       val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
       val expected: Expression = NewInstance(
         cls,
-        toExternalString('a.string) :: 'b.int.cast(LongType) :: Nil,
-        false,
-        ObjectType(cls))
+        Seq(
+          toExternalString('a.string),
+          AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long")
+        ),
+        ObjectType(cls),
+        propagateNull = false)
       compareExpressions(fromRowExpr, expected)
     }
 
@@ -52,9 +56,12 @@ class EncoderResolutionSuite extends PlanTest {
       val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression
       val expected = NewInstance(
         cls,
-        toExternalString('a.int.cast(StringType)) :: 'b.long :: Nil,
-        false,
-        ObjectType(cls))
+        Seq(
+          toExternalString('a.int.cast(StringType)),
+          AssertNotNull('b.long, cls.getName, "b", "Long")
+        ),
+        ObjectType(cls),
+        propagateNull = false)
       compareExpressions(fromRowExpr, expected)
     }
   }
@@ -64,27 +71,28 @@ class EncoderResolutionSuite extends PlanTest {
     val innerCls = classOf[StringLongClass]
     val cls = classOf[ComplexClass]
 
-    val structType = new StructType().add("a", IntegerType).add("b", LongType)
-    val attrs = Seq('a.int, 'b.struct(structType))
+    val attrs = Seq('a.int, 'b.struct('a.int, 'b.long))
     val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
     val expected: Expression = NewInstance(
       cls,
       Seq(
-        'a.int.cast(LongType),
+        AssertNotNull('a.int.cast(LongType), cls.getName, "a", "Long"),
         If(
-          'b.struct(structType).isNull,
+          'b.struct('a.int, 'b.long).isNull,
           Literal.create(null, ObjectType(innerCls)),
           NewInstance(
             innerCls,
             Seq(
               toExternalString(
-                GetStructField('b.struct(structType), 0, Some("a")).cast(StringType)),
-              GetStructField('b.struct(structType), 1, Some("b"))),
-            false,
-            ObjectType(innerCls))
+                GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)),
+              AssertNotNull(
+                GetStructField('b.struct('a.int, 'b.long), 1, Some("b")),
+                innerCls.getName, "b", "Long")),
+            ObjectType(innerCls),
+            propagateNull = false)
         )),
-      false,
-      ObjectType(cls))
+      ObjectType(cls),
+      propagateNull = false)
     compareExpressions(fromRowExpr, expected)
   }
 
@@ -94,8 +102,7 @@ class EncoderResolutionSuite extends PlanTest {
       ExpressionEncoder[Long])
     val cls = classOf[StringLongClass]
 
-    val structType = new StructType().add("a", StringType).add("b", ByteType)
-    val attrs = Seq('a.struct(structType), 'b.int)
+    val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int)
     val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
     val expected: Expression = NewInstance(
       classOf[Tuple2[_, _]],
@@ -103,16 +110,62 @@ class EncoderResolutionSuite extends PlanTest {
         NewInstance(
           cls,
           Seq(
-            toExternalString(GetStructField('a.struct(structType), 0, Some("a"))),
-            GetStructField('a.struct(structType), 1, Some("b")).cast(LongType)),
-          false,
-          ObjectType(cls)),
+            toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))),
+            AssertNotNull(
+              GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType),
+              cls.getName, "b", "Long")),
+          ObjectType(cls),
+          propagateNull = false),
         'b.int.cast(LongType)),
-      false,
-      ObjectType(classOf[Tuple2[_, _]]))
+      ObjectType(classOf[Tuple2[_, _]]),
+      propagateNull = false)
     compareExpressions(fromRowExpr, expected)
   }
 
+  test("the real number of fields doesn't match encoder schema: tuple encoder") {
+    val encoder = ExpressionEncoder[(String, Long)]
+
+    {
+      val attrs = Seq('a.string, 'b.long, 'c.int)
+      assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+        "Try to map struct<a:string,b:bigint,c:int> to Tuple2, " +
+          "but failed as the number of fields does not line up.\n" +
+          " - Input schema: struct<a:string,b:bigint,c:int>\n" +
+          " - Target schema: struct<_1:string,_2:bigint>")
+    }
+
+    {
+      val attrs = Seq('a.string)
+      assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+        "Try to map struct<a:string> to Tuple2, " +
+          "but failed as the number of fields does not line up.\n" +
+          " - Input schema: struct<a:string>\n" +
+          " - Target schema: struct<_1:string,_2:bigint>")
+    }
+  }
+
+  test("the real number of fields doesn't match encoder schema: nested tuple encoder") {
+    val encoder = ExpressionEncoder[(String, (Long, String))]
+
+    {
+      val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int))
+      assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+        "Try to map struct<x:bigint,y:string,z:int> to Tuple2, " +
+          "but failed as the number of fields does not line up.\n" +
+          " - Input schema: struct<a:string,b:struct<x:bigint,y:string,z:int>>\n" +
+          " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>")
+    }
+
+    {
+      val attrs = Seq('a.string, 'b.struct('x.long))
+      assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+        "Try to map struct<x:bigint> to Tuple2, " +
+          "but failed as the number of fields does not line up.\n" +
+          " - Input schema: struct<a:string,b:struct<x:bigint>>\n" +
+          " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>")
+    }
+  }
+
   private def toExternalString(e: Expression): Expression = {
     Invoke(e, "toString", ObjectType(classOf[String]), Nil)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 47b07df..98f29e5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -130,6 +130,9 @@ class ExpressionEncoderSuite extends SparkFunSuite {
   encodeDecodeTest(Map(1 -> "a", 2 -> null), "map with null")
   encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), "map of map")
 
+  encodeDecodeTest(Tuple1[Seq[Int]](null), "null seq in tuple")
+  encodeDecodeTest(Tuple1[Map[String, String]](null), "null map in tuple")
+
   // Kryo encoders
   encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String]))
   encodeDecodeTest(new KryoSerializable(15), "kryo object")(
@@ -147,6 +150,7 @@ class ExpressionEncoderSuite extends SparkFunSuite {
 
   case class InnerClass(i: Int)
   productTest(InnerClass(1))
+  encodeDecodeTest(Array(InnerClass(1)), "array of inner class")
 
   productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
 
@@ -156,6 +160,9 @@ class ExpressionEncoderSuite extends SparkFunSuite {
 
   productTest(OptionalData(None, None, None, None, None, None, None, None))
 
+  encodeDecodeTest(Seq(Some(1), None), "Option in array")
+  encodeDecodeTest(Map(1 -> Some(10L), 2 -> Some(20L), 3 -> None), "Option in map")
+
   productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true))
 
   productTest(BoxedData(null, null, null, null, null, null, null))
@@ -240,6 +247,8 @@ class ExpressionEncoderSuite extends SparkFunSuite {
     ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
   }
 
+  productTest(("UDT", new ExamplePoint(0.1, 0.2)))
+
   test("nullable of encoder schema") {
     def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = {
       assert(implicitly[ExpressionEncoder[T]].schema.map(_.nullable) === nullable.toSeq)

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 0ea51ec..8f4faab 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -108,7 +108,8 @@ class RowEncoderSuite extends SparkFunSuite {
       .add("arrayOfArrayOfString", ArrayType(arrayOfString))
       .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType)))
       .add("arrayOfMap", ArrayType(mapOfString))
-      .add("arrayOfStruct", ArrayType(structOfString)))
+      .add("arrayOfStruct", ArrayType(structOfString))
+      .add("arrayOfUDT", arrayOfUDT))
 
   encodeDecodeTest(
     new StructType()
@@ -130,18 +131,6 @@ class RowEncoderSuite extends SparkFunSuite {
         new StructType().add("array", arrayOfString).add("map", mapOfString))
       .add("structOfUDT", structOfUDT))
 
-  test(s"encode/decode: arrayOfUDT") {
-    val schema = new StructType()
-      .add("arrayOfUDT", arrayOfUDT)
-
-    val encoder = RowEncoder(schema)
-
-    val input: Row = Row(Seq(new ExamplePoint(0.1, 0.2), new ExamplePoint(0.3, 0.4)))
-    val row = encoder.toRow(input)
-    val convertedBack = encoder.fromRow(row)
-    assert(input.getSeq[ExamplePoint](0) == convertedBack.getSeq[ExamplePoint](0))
-  }
-
   test(s"encode/decode: Product") {
     val schema = new StructType()
       .add("structAsProduct",

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 62fd472..9f1b192 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -165,7 +165,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
       "b", create_row(Map("a" -> "b")))
     checkEvaluation(quickResolve('c.array(StringType).at(0).getItem(1)),
       "b", create_row(Seq("a", "b")))
-    checkEvaluation(quickResolve('c.struct(StructField("a", IntegerType)).at(0).getField("a")),
+    checkEvaluation(quickResolve('c.struct('a.int).at(0).getField("a")),
       1, create_row(create_row(1)))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 3180049..1beb080 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -165,13 +165,11 @@ class DataFrame private[sql](
    * @param _numRows Number of rows to show
    * @param truncate Whether truncate long strings and align cells right
    */
-  private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = {
+  override private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = {
     val numRows = _numRows.max(0)
-    val sb = new StringBuilder
     val takeResult = take(numRows + 1)
     val hasMoreData = takeResult.length > numRows
     val data = takeResult.take(numRows)
-    val numCols = schema.fieldNames.length
 
     // For array values, replace Seq and Array with square brackets
     // For cells that are beyond 20 characters, replace it with the first 17 and "..."
@@ -179,6 +177,7 @@ class DataFrame private[sql](
       row.toSeq.map { cell =>
         val str = cell match {
           case null => "null"
+          case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]")
           case array: Array[_] => array.mkString("[", ", ", "]")
           case seq: Seq[_] => seq.mkString("[", ", ", "]")
           case _ => cell.toString
@@ -187,50 +186,7 @@ class DataFrame private[sql](
       }: Seq[String]
     }
 
-    // Initialise the width of each column to a minimum value of '3'
-    val colWidths = Array.fill(numCols)(3)
-
-    // Compute the width of each column
-    for (row <- rows) {
-      for ((cell, i) <- row.zipWithIndex) {
-        colWidths(i) = math.max(colWidths(i), cell.length)
-      }
-    }
-
-    // Create SeparateLine
-    val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString()
-
-    // column names
-    rows.head.zipWithIndex.map { case (cell, i) =>
-      if (truncate) {
-        StringUtils.leftPad(cell, colWidths(i))
-      } else {
-        StringUtils.rightPad(cell, colWidths(i))
-      }
-    }.addString(sb, "|", "|", "|\n")
-
-    sb.append(sep)
-
-    // data
-    rows.tail.map {
-      _.zipWithIndex.map { case (cell, i) =>
-        if (truncate) {
-          StringUtils.leftPad(cell.toString, colWidths(i))
-        } else {
-          StringUtils.rightPad(cell.toString, colWidths(i))
-        }
-      }.addString(sb, "|", "|", "|\n")
-    }
-
-    sb.append(sep)
-
-    // For Data that has more than "numRows" records
-    if (hasMoreData) {
-      val rowsString = if (numRows == 1) "row" else "rows"
-      sb.append(s"only showing top $numRows $rowsString\n")
-    }
-
-    sb.toString()
+    formatString ( rows, numRows, hasMoreData, truncate )
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 3bd18a1..cbf15a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import scala.collection.JavaConverters._
 
+import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.function._
 import org.apache.spark.rdd.RDD
@@ -64,7 +65,7 @@ import org.apache.spark.util.Utils
 class Dataset[T] private[sql](
     @transient override val sqlContext: SQLContext,
     @transient override val queryExecution: QueryExecution,
-    tEncoder: Encoder[T]) extends Queryable with Serializable {
+    tEncoder: Encoder[T]) extends Queryable with Serializable with Logging {
 
   /**
    * An unresolved version of the internal encoder for the type of this [[Dataset]].  This one is
@@ -79,7 +80,7 @@ class Dataset[T] private[sql](
 
   /**
    * The encoder where the expressions used to construct an object from an input row have been
-   * bound to the ordinals of the given schema.
+   * bound to the ordinals of this [[Dataset]]'s output schema.
    */
   private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
 
@@ -225,7 +226,42 @@ class Dataset[T] private[sql](
    *
    * @since 1.6.0
    */
-  def show(numRows: Int, truncate: Boolean): Unit = toDF().show(numRows, truncate)
+  // scalastyle:off println
+  def show(numRows: Int, truncate: Boolean): Unit = println(showString(numRows, truncate))
+  // scalastyle:on println
+
+  /**
+   * Compose the string representing rows for output
+   * @param _numRows Number of rows to show
+   * @param truncate Whether truncate long strings and align cells right
+   */
+  override private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = {
+    val numRows = _numRows.max(0)
+    val takeResult = take(numRows + 1)
+    val hasMoreData = takeResult.length > numRows
+    val data = takeResult.take(numRows)
+
+    // For array values, replace Seq and Array with square brackets
+    // For cells that are beyond 20 characters, replace it with the first 17 and "..."
+    val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: (data.map {
+      case r: Row => r
+      case tuple: Product => Row.fromTuple(tuple)
+      case o => Row(o)
+    } map { row =>
+      row.toSeq.map { cell =>
+        val str = cell match {
+          case null => "null"
+          case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]")
+          case array: Array[_] => array.mkString("[", ", ", "]")
+          case seq: Seq[_] => seq.mkString("[", ", ", "]")
+          case _ => cell.toString
+        }
+        if (truncate && str.length > 20) str.substring(0, 17) + "..." else str
+      }: Seq[String]
+    })
+
+    formatString ( rows, numRows, hasMoreData, truncate )
+  }
 
   /**
     * Returns a new [[Dataset]] that has exactly `numPartitions` partitions.

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index b8a4302..ea5a9af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -74,9 +74,7 @@ private[sql] case class LogicalRDD(
 
   override def children: Seq[LogicalPlan] = Nil
 
-  override protected final def otherCopyArgs: Seq[AnyRef] = {
-    sqlContext :: Nil
-  }
+  override protected final def otherCopyArgs: Seq[AnyRef] = sqlContext :: Nil
 
   override def newInstance(): LogicalRDD.this.type =
     LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type]

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala
index f2f5997..b397d42 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
 
 import scala.util.control.NonFatal
 
+import org.apache.commons.lang3.StringUtils
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.types.StructType
 
@@ -42,4 +43,68 @@ private[sql] trait Queryable {
   def explain(extended: Boolean): Unit
 
   def explain(): Unit
+
+  private[sql] def showString(_numRows: Int, truncate: Boolean = true): String
+
+  /**
+   * Format the string representing rows for output
+   * @param rows The rows to show
+   * @param numRows Number of rows to show
+   * @param hasMoreData Whether some rows are not shown due to the limit
+   * @param truncate Whether truncate long strings and align cells right
+   *
+   */
+  private[sql] def formatString (
+      rows: Seq[Seq[String]],
+      numRows: Int,
+      hasMoreData : Boolean,
+      truncate: Boolean = true): String = {
+    val sb = new StringBuilder
+    val numCols = schema.fieldNames.length
+
+    // Initialise the width of each column to a minimum value of '3'
+    val colWidths = Array.fill(numCols)(3)
+
+    // Compute the width of each column
+    for (row <- rows) {
+      for ((cell, i) <- row.zipWithIndex) {
+        colWidths(i) = math.max(colWidths(i), cell.length)
+      }
+    }
+
+    // Create SeparateLine
+    val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString()
+
+    // column names
+    rows.head.zipWithIndex.map { case (cell, i) =>
+      if (truncate) {
+        StringUtils.leftPad(cell, colWidths(i))
+      } else {
+        StringUtils.rightPad(cell, colWidths(i))
+      }
+    }.addString(sb, "|", "|", "|\n")
+
+    sb.append(sep)
+
+    // data
+    rows.tail.map {
+      _.zipWithIndex.map { case (cell, i) =>
+        if (truncate) {
+          StringUtils.leftPad(cell.toString, colWidths(i))
+        } else {
+          StringUtils.rightPad(cell.toString, colWidths(i))
+        }
+      }.addString(sb, "|", "|", "|\n")
+    }
+
+    sb.append(sep)
+
+    // For Data that has more than "numRows" records
+    if (hasMoreData) {
+      val rowsString = if (numRows == 1) "row" else "rows"
+      sb.append(s"only showing top $numRows $rowsString\n")
+    }
+
+    sb.toString()
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
index ce701fb..f817d5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
@@ -60,9 +60,9 @@ private[sql] case class InMemoryRelation(
     storageLevel: StorageLevel,
     @transient child: SparkPlan,
     tableName: Option[String])(
-    @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null,
-    @transient private var _statistics: Statistics = null,
-    private var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null)
+    @transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null,
+    @transient private[sql] var _statistics: Statistics = null,
+    private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null)
   extends LogicalPlan with MultiInstanceRelation {
 
   private val batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] =


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org