You are viewing a plain text version of this content. The canonical link to it was a hyperlink in the original page and is not preserved in this plain-text copy.
Posted to commits@spark.apache.org by ma...@apache.org on 2015/12/16 22:20:27 UTC

spark git commit: [SPARK-12320][SQL] throw exception if the number of fields does not line up for Tuple encoder

Repository: spark
Updated Branches:
  refs/heads/master 1a8b2a17d -> a783a8ed4


[SPARK-12320][SQL] throw exception if the number of fields does not line up for Tuple encoder

Author: Wenchen Fan <we...@databricks.com>

Closes #10293 from cloud-fan/err-msg.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a783a8ed
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a783a8ed
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a783a8ed

Branch: refs/heads/master
Commit: a783a8ed49814a09fde653433a3d6de398ddf888
Parents: 1a8b2a1
Author: Wenchen Fan <we...@databricks.com>
Authored: Wed Dec 16 13:18:56 2015 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Wed Dec 16 13:20:12 2015 -0800

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/dsl/package.scala |  3 +-
 .../catalyst/encoders/ExpressionEncoder.scala   | 36 +++++++++++-
 .../expressions/complexTypeExtractors.scala     | 10 ++--
 .../encoders/EncoderResolutionSuite.scala       | 60 +++++++++++++++++---
 .../catalyst/expressions/ComplexTypeSuite.scala |  2 +-
 5 files changed, 93 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a783a8ed/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index e509711..8102c93 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -227,9 +227,10 @@ package object dsl {
         AttributeReference(s, mapType, nullable = true)()
 
       /** Creates a new AttributeReference of type struct */
-      def struct(fields: StructField*): AttributeReference = struct(StructType(fields))
       def struct(structType: StructType): AttributeReference =
         AttributeReference(s, structType, nullable = true)()
+      def struct(attrs: AttributeReference*): AttributeReference = // e.g. 'b.struct('a.int, 'b.long)
+        struct(StructType.fromAttributes(attrs)) // struct type built from attrs via StructType.fromAttributes
     }
 
     implicit class DslAttribute(a: AttributeReference) {

http://git-wip-us.apache.org/repos/asf/spark/blob/a783a8ed/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 363178b..7a4401c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -244,9 +244,41 @@ case class ExpressionEncoder[T](
   def resolve(
       schema: Seq[Attribute],
       outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
-    val positionToAttribute = AttributeMap.toIndex(schema)
+    def fail(st: StructType, maxOrdinal: Int): Unit = {
+      throw new AnalysisException(s"Try to map ${st.simpleString} to Tuple${maxOrdinal + 1}, " +
+        "but failed as the number of fields does not line up.\n" +
+        " - Input schema: " + StructType.fromAttributes(schema).simpleString + "\n" +
+        " - Target schema: " + this.schema.simpleString)
+    }
+
+    var maxOrdinal = -1
+    fromRowExpression.foreach {
+      case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal
+      case _ =>
+    }
+    if (maxOrdinal >= 0 && maxOrdinal != schema.length - 1) {
+      fail(StructType.fromAttributes(schema), maxOrdinal)
+    }
+
     val unbound = fromRowExpression transform {
-      case b: BoundReference => positionToAttribute(b.ordinal)
+      case b: BoundReference => schema(b.ordinal)
+    }
+
+    val exprToMaxOrdinal = scala.collection.mutable.HashMap.empty[Expression, Int]
+    unbound.foreach {
+      case g: GetStructField =>
+        val maxOrdinal = exprToMaxOrdinal.getOrElse(g.child, -1)
+        if (maxOrdinal < g.ordinal) {
+          exprToMaxOrdinal.update(g.child, g.ordinal)
+        }
+      case _ =>
+    }
+    exprToMaxOrdinal.foreach {
+      case (expr, maxOrdinal) =>
+        val schema = expr.dataType.asInstanceOf[StructType]
+        if (maxOrdinal != schema.length - 1) {
+          fail(schema, maxOrdinal)
+        }
     }
 
     val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))

http://git-wip-us.apache.org/repos/asf/spark/blob/a783a8ed/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 10ce10a..58f6a7e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -104,14 +104,14 @@ object ExtractValue {
 case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None)
   extends UnaryExpression {
 
-  private lazy val field = child.dataType.asInstanceOf[StructType](ordinal)
+  private[sql] lazy val childSchema = child.dataType.asInstanceOf[StructType]
 
-  override def dataType: DataType = field.dataType
-  override def nullable: Boolean = child.nullable || field.nullable
-  override def toString: String = s"$child.${name.getOrElse(field.name)}"
+  override def dataType: DataType = childSchema(ordinal).dataType
+  override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable
+  override def toString: String = s"$child.${name.getOrElse(childSchema(ordinal).name)}"
 
   protected override def nullSafeEval(input: Any): Any =
-    input.asInstanceOf[InternalRow].get(ordinal, field.dataType)
+    input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType)
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     nullSafeCodeGen(ctx, ev, eval => {

http://git-wip-us.apache.org/repos/asf/spark/blob/a783a8ed/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 0289988..815a03f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -64,22 +64,21 @@ class EncoderResolutionSuite extends PlanTest {
     val innerCls = classOf[StringLongClass]
     val cls = classOf[ComplexClass]
 
-    val structType = new StructType().add("a", IntegerType).add("b", LongType)
-    val attrs = Seq('a.int, 'b.struct(structType))
+    val attrs = Seq('a.int, 'b.struct('a.int, 'b.long))
     val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
     val expected: Expression = NewInstance(
       cls,
       Seq(
         'a.int.cast(LongType),
         If(
-          'b.struct(structType).isNull,
+          'b.struct('a.int, 'b.long).isNull,
           Literal.create(null, ObjectType(innerCls)),
           NewInstance(
             innerCls,
             Seq(
               toExternalString(
-                GetStructField('b.struct(structType), 0, Some("a")).cast(StringType)),
-              GetStructField('b.struct(structType), 1, Some("b"))),
+                GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)),
+              GetStructField('b.struct('a.int, 'b.long), 1, Some("b"))),
             false,
             ObjectType(innerCls))
         )),
@@ -94,8 +93,7 @@ class EncoderResolutionSuite extends PlanTest {
       ExpressionEncoder[Long])
     val cls = classOf[StringLongClass]
 
-    val structType = new StructType().add("a", StringType).add("b", ByteType)
-    val attrs = Seq('a.struct(structType), 'b.int)
+    val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int)
     val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
     val expected: Expression = NewInstance(
       classOf[Tuple2[_, _]],
@@ -103,8 +101,8 @@ class EncoderResolutionSuite extends PlanTest {
         NewInstance(
           cls,
           Seq(
-            toExternalString(GetStructField('a.struct(structType), 0, Some("a"))),
-            GetStructField('a.struct(structType), 1, Some("b")).cast(LongType)),
+            toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))),
+            GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType)),
           false,
           ObjectType(cls)),
         'b.int.cast(LongType)),
@@ -113,6 +111,50 @@ class EncoderResolutionSuite extends PlanTest {
     compareExpressions(fromRowExpr, expected)
   }
 
+  test("the real number of fields doesn't match encoder schema: tuple encoder") {
+    val encoder = ExpressionEncoder[(String, Long)] // Tuple2 encoder: resolve() must reject inputs whose field count is not 2
+
+    {
+      val attrs = Seq('a.string, 'b.long, 'c.int) // one field too many (3 fields vs. arity 2)
+      assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+        "Try to map struct<a:string,b:bigint,c:int> to Tuple2, " +
+          "but failed as the number of fields does not line up.\n" +
+          " - Input schema: struct<a:string,b:bigint,c:int>\n" +
+          " - Target schema: struct<_1:string,_2:bigint>")
+    }
+
+    {
+      val attrs = Seq('a.string) // one field too few (1 field vs. arity 2)
+      assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+        "Try to map struct<a:string> to Tuple2, " +
+          "but failed as the number of fields does not line up.\n" +
+          " - Input schema: struct<a:string>\n" +
+          " - Target schema: struct<_1:string,_2:bigint>")
+    }
+  }
+
+  test("the real number of fields doesn't match encoder schema: nested tuple encoder") {
+    val encoder = ExpressionEncoder[(String, (Long, String))] // the inner Tuple2 must also match the field count of column b
+
+    {
+      val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int)) // inner struct has 3 fields, inner tuple expects 2
+      assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == // first line names the mismatched inner struct
+        "Try to map struct<x:bigint,y:string,z:int> to Tuple2, " +
+          "but failed as the number of fields does not line up.\n" +
+          " - Input schema: struct<a:string,b:struct<x:bigint,y:string,z:int>>\n" +
+          " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>")
+    }
+
+    {
+      val attrs = Seq('a.string, 'b.struct('x.long)) // inner struct has 1 field, inner tuple expects 2
+      assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == // "Input schema" shows the full top-level schema
+        "Try to map struct<x:bigint> to Tuple2, " +
+          "but failed as the number of fields does not line up.\n" +
+          " - Input schema: struct<a:string,b:struct<x:bigint>>\n" +
+          " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>")
+    }
+  }
+
   private def toExternalString(e: Expression): Expression = {
     Invoke(e, "toString", ObjectType(classOf[String]), Nil)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/a783a8ed/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 62fd472..9f1b192 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -165,7 +165,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
       "b", create_row(Map("a" -> "b")))
     checkEvaluation(quickResolve('c.array(StringType).at(0).getItem(1)),
       "b", create_row(Seq("a", "b")))
-    checkEvaluation(quickResolve('c.struct(StructField("a", IntegerType)).at(0).getField("a")),
+    checkEvaluation(quickResolve('c.struct('a.int).at(0).getField("a")),
       1, create_row(create_row(1)))
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org