You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2018/04/20 13:02:33 UTC

spark git commit: [SPARK-23595][SQL] ValidateExternalType should support interpreted execution

Repository: spark
Updated Branches:
  refs/heads/master 074a7f905 -> 0dd97f6ea


[SPARK-23595][SQL] ValidateExternalType should support interpreted execution

## What changes were proposed in this pull request?
This PR adds support for interpreted (non-codegen) evaluation to `ValidateExternalType`.

## How was this patch tested?
Added tests in `ObjectExpressionsSuite`.

Author: Takeshi Yamamuro <ya...@apache.org>

Closes #20757 from maropu/SPARK-23595.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0dd97f6e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0dd97f6e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0dd97f6e

Branch: refs/heads/master
Commit: 0dd97f6ea4affde1531dec1bec004b7ab18c6965
Parents: 074a7f9
Author: Takeshi Yamamuro <ya...@apache.org>
Authored: Fri Apr 20 15:02:27 2018 +0200
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Fri Apr 20 15:02:27 2018 +0200

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    | 13 ++++++++
 .../sql/catalyst/encoders/RowEncoder.scala      |  2 +-
 .../catalyst/expressions/objects/objects.scala  | 34 +++++++++++++++++---
 .../expressions/ObjectExpressionsSuite.scala    | 33 +++++++++++++++++--
 4 files changed, 74 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0dd97f6e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 818cc2f..f9acc20 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -846,6 +846,19 @@ object ScalaReflection extends ScalaReflection {
     }
   }
 
+  def javaBoxedType(dt: DataType): Class[_] = dt match { // boxed class of a value of type `dt`
+    case _: DecimalType => classOf[Decimal]
+    case BinaryType => classOf[Array[Byte]]
+    case StringType => classOf[UTF8String]
+    case CalendarIntervalType => classOf[CalendarInterval]
+    case _: StructType => classOf[InternalRow]
+    case _: ArrayType => classOf[org.apache.spark.sql.catalyst.util.ArrayData]
+    case _: MapType => classOf[org.apache.spark.sql.catalyst.util.MapData]
+    case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType)
+    case ObjectType(cls) => cls
+    case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object])
+  }
+
   def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = {
     if (arguments != Nil) {
       arguments.map(e => dataTypeJavaClass(e.dataType))

http://git-wip-us.apache.org/repos/asf/spark/blob/0dd97f6e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 789750f..3340789 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0dd97f6e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 32c1f34..f1ffcae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 import org.apache.spark.util.Utils
 
 /**
@@ -1672,13 +1673,36 @@ case class ValidateExternalType(child: Expression, expected: DataType)
 
   override def nullable: Boolean = child.nullable
 
-  override def dataType: DataType = RowEncoder.externalDataTypeForInput(expected)
-
-  override def eval(input: InternalRow): Any =
-    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+  override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected)
 
   private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}"
 
+  private lazy val checkType: (Any) => Boolean = expected match {
+    case _: DecimalType =>
+      (value: Any) => {
+        value.isInstanceOf[java.math.BigDecimal] || value.isInstanceOf[scala.math.BigDecimal] ||
+          value.isInstanceOf[Decimal]
+      }
+    case _: ArrayType =>
+      (value: Any) => {
+        value.getClass.isArray || value.isInstanceOf[Seq[_]]
+      }
+    case _ =>
+      val dataTypeClazz = ScalaReflection.javaBoxedType(dataType)
+      (value: Any) => {
+        dataTypeClazz.isInstance(value)
+      }
+  }
+
+  // Interpreted evaluation: returns the child's value unchanged when `checkType` accepts it,
+  // and throws a RuntimeException naming the offending class otherwise.
+  override def eval(input: InternalRow): Any = child.eval(input) match {
+    // `nullable` follows the child: pass null through instead of failing on `getClass`.
+    case null => null
+    case result if checkType(result) => result
+    case result => throw new RuntimeException(s"${result.getClass.getName}$errMsg")
+  }
+
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     // Use unnamed reference that doesn't create a local field here to reduce the number of fields
     // because errMsgField is used only when the type doesn't match.
@@ -1691,7 +1715,7 @@ case class ValidateExternalType(child: Expression, expected: DataType)
         Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal])
           .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ")
       case _: ArrayType =>
-        s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()"
+        s"$obj.getClass().isArray() || $obj instanceof ${classOf[Seq[_]].getName}"
       case _ =>
         s"$obj instanceof ${CodeGenerator.boxedType(dataType)}"
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/0dd97f6e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index bcd035c..7136af8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, Generic
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 class InvokeTargetClass extends Serializable {
   def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0
@@ -296,7 +296,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true)
     val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0")
     Seq((Row(1), 1), (Row(3), 3)).foreach { case (input, expected) =>
-      checkEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input)))
+      checkObjectExprEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input)))
     }
 
     // If an input row or a field are null, a runtime exception will be thrown
@@ -472,6 +472,35 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     val deserializer = toMapExpr.copy(inputData = Literal.create(data))
     checkObjectExprEvaluation(deserializer, expected = data)
   }
+
+  test("SPARK-23595 ValidateExternalType should support interpreted execution") {
+    val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true)
+    Seq(
+      (true, BooleanType),
+      (2.toByte, ByteType),
+      (5.toShort, ShortType),
+      (23, IntegerType),
+      (61L, LongType),
+      (1.0f, FloatType),
+      (10.0, DoubleType),
+      ("abcd".getBytes, BinaryType),
+      ("abcd", StringType),
+      (BigDecimal.valueOf(10), DecimalType.IntDecimal),
+      (CalendarInterval.fromString("interval 3 day"), CalendarIntervalType),
+      (java.math.BigDecimal.valueOf(10), DecimalType.BigIntDecimal),
+      (Array(3, 2, 1), ArrayType(IntegerType))
+    ).foreach { case (input, dt) =>
+      val validateType = ValidateExternalType(
+        GetExternalRowField(inputObject, index = 0, fieldName = "c0"), dt)
+      checkObjectExprEvaluation(validateType, input, InternalRow.fromSeq(Seq(Row(input))))
+    }
+
+    checkExceptionInExpression[RuntimeException](
+      ValidateExternalType(
+        GetExternalRowField(inputObject, index = 0, fieldName = "c0"), DoubleType),
+      InternalRow.fromSeq(Seq(Row(1))),
+      "java.lang.Integer is not a valid external type for schema of double")
+  }
 }
 
 class TestBean extends Serializable {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org