Posted to commits@spark.apache.org by hv...@apache.org on 2018/04/04 16:36:21 UTC
spark git commit: [SPARK-23583][SQL] Invoke should support interpreted execution
Repository: spark
Updated Branches:
refs/heads/master 5197562af -> a35523653
[SPARK-23583][SQL] Invoke should support interpreted execution
## What changes were proposed in this pull request?
This PR adds interpreted execution for `Invoke`.
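As a minimal, hypothetical sketch (the `Adder` class and values below are illustrative, not part of the patch), an `Invoke` expression can now be evaluated directly via `eval` instead of requiring code generation:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal}
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types._

// Hypothetical target; any serializable class exposing the named method works.
class Adder extends Serializable { def add(x: Int, y: Int): Int = x + y }

val obj = Literal.create(new Adder, ObjectType(classOf[Adder]))
val args = Seq(
  BoundReference(0, IntegerType, nullable = false),
  BoundReference(1, IntegerType, nullable = false))
val expr = Invoke(obj, "add", IntegerType, args)

// Before this patch, eval threw UnsupportedOperationException; with it, the
// interpreted path reflectively invokes Adder.add and boxes the Int result.
assert(expr.eval(InternalRow(2, 3)) == 5)
```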
## How was this patch tested?
Added tests in `ObjectExpressionsSuite`.
Author: Kazuaki Ishizaki <is...@jp.ibm.com>
Closes #20797 from kiszk/SPARK-23583.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a3552365
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a3552365
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a3552365
Branch: refs/heads/master
Commit: a35523653cdac039ee2ddff316bc2c25d6514a91
Parents: 5197562
Author: Kazuaki Ishizaki <is...@jp.ibm.com>
Authored: Wed Apr 4 18:36:15 2018 +0200
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Wed Apr 4 18:36:15 2018 +0200
----------------------------------------------------------------------
.../spark/sql/catalyst/ScalaReflection.scala | 48 ++++++++++++++-
.../catalyst/expressions/objects/objects.scala | 56 +++++++++++++++--
.../expressions/ObjectExpressionsSuite.scala | 65 ++++++++++++++++++++
3 files changed, 163 insertions(+), 6 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/a3552365/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 9a4bf00..1aae3ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -794,6 +794,52 @@ object ScalaReflection extends ScalaReflection {
"interface", "long", "native", "new", "null", "package", "private", "protected", "public",
"return", "short", "static", "strictfp", "super", "switch", "synchronized", "this", "throw",
"throws", "transient", "true", "try", "void", "volatile", "while")
+
+ val typeJavaMapping = Map[DataType, Class[_]](
+ BooleanType -> classOf[Boolean],
+ ByteType -> classOf[Byte],
+ ShortType -> classOf[Short],
+ IntegerType -> classOf[Int],
+ LongType -> classOf[Long],
+ FloatType -> classOf[Float],
+ DoubleType -> classOf[Double],
+ StringType -> classOf[UTF8String],
+ DateType -> classOf[DateType.InternalType],
+ TimestampType -> classOf[TimestampType.InternalType],
+ BinaryType -> classOf[BinaryType.InternalType],
+ CalendarIntervalType -> classOf[CalendarInterval]
+ )
+
+ val typeBoxedJavaMapping = Map[DataType, Class[_]](
+ BooleanType -> classOf[java.lang.Boolean],
+ ByteType -> classOf[java.lang.Byte],
+ ShortType -> classOf[java.lang.Short],
+ IntegerType -> classOf[java.lang.Integer],
+ LongType -> classOf[java.lang.Long],
+ FloatType -> classOf[java.lang.Float],
+ DoubleType -> classOf[java.lang.Double],
+ DateType -> classOf[java.lang.Integer],
+ TimestampType -> classOf[java.lang.Long]
+ )
+
+ def dataTypeJavaClass(dt: DataType): Class[_] = {
+ dt match {
+ case _: DecimalType => classOf[Decimal]
+ case _: StructType => classOf[InternalRow]
+ case _: ArrayType => classOf[ArrayData]
+ case _: MapType => classOf[MapData]
+ case ObjectType(cls) => cls
+ case _ => typeJavaMapping.getOrElse(dt, classOf[java.lang.Object])
+ }
+ }
+
+ def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = {
+ if (arguments != Nil) {
+ arguments.map(e => dataTypeJavaClass(e.dataType))
+ } else {
+ Seq.empty
+ }
+ }
}
/**
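For reference, a small sketch (not from the patch) of how the new helpers resolve Catalyst types to the Java classes used for reflective method lookup:

```scala
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._

// Primitive Catalyst types resolve to unboxed Java classes ...
assert(ScalaReflection.dataTypeJavaClass(IntegerType) == classOf[Int])
// ... while nested types resolve to their internal Catalyst representations.
assert(ScalaReflection.dataTypeJavaClass(ArrayType(IntegerType)) == classOf[ArrayData])
assert(ScalaReflection.dataTypeJavaClass(MapType(IntegerType, LongType)) == classOf[MapData])

// expressionJavaClasses derives the parameter classes for a
// getDeclaredMethod lookup from the argument expressions' data types.
assert(ScalaReflection.expressionJavaClasses(Seq(Literal(1), Literal(0.25))) ==
  Seq(classOf[Int], classOf[Double]))
```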
http://git-wip-us.apache.org/repos/asf/spark/blob/a3552365/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 0e9d357..a455c1c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions.objects
-import java.lang.reflect.Modifier
+import java.lang.reflect.{Method, Modifier}
import scala.collection.JavaConverters._
import scala.collection.mutable.Builder
@@ -28,7 +28,7 @@ import scala.util.Try
import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.serializer._
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
@@ -104,6 +104,38 @@ trait InvokeLike extends Expression with NonSQLExpression {
(argCode, argValues.mkString(", "), resultIsNull)
}
+
+ /**
+ * Evaluates each argument against a given row, invokes the method on a given object with those
+ * arguments, and casts the return value if the return type can be mapped to a Java boxed type.
+ *
+ * @param obj the object for the method to be called. If null, perform a static method call
+ * @param method the method object to be called
+ * @param arguments the arguments used for the method call
+ * @param input the row used for evaluating arguments
+ * @param dataType the data type of the return object
+ * @return the return object of a method call
+ */
+ def invoke(
+ obj: Any,
+ method: Method,
+ arguments: Seq[Expression],
+ input: InternalRow,
+ dataType: DataType): Any = {
+ val args = arguments.map(e => e.eval(input).asInstanceOf[Object])
+ if (needNullCheck && args.exists(_ == null)) {
+ // return null if one of the arguments is null
+ null
+ } else {
+ val ret = method.invoke(obj, args: _*)
+ val boxedClass = ScalaReflection.typeBoxedJavaMapping.get(dataType)
+ if (boxedClass.isDefined) {
+ boxedClass.get.cast(ret)
+ } else {
+ ret
+ }
+ }
+ }
}
/**
@@ -264,12 +296,11 @@ case class Invoke(
propagateNull: Boolean = true,
returnNullable : Boolean = true) extends InvokeLike {
+ lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
+
override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
override def children: Seq[Expression] = targetObject +: arguments
- override def eval(input: InternalRow): Any =
- throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
-
private lazy val encodedFunctionName = TermName(functionName).encodedName.toString
@transient lazy val method = targetObject.dataType match {
@@ -283,6 +314,21 @@ case class Invoke(
case _ => None
}
+ override def eval(input: InternalRow): Any = {
+ val obj = targetObject.eval(input)
+ if (obj == null) {
+ // return null if obj is null
+ null
+ } else {
+ val invokeMethod = if (method.isDefined) {
+ method.get
+ } else {
+ obj.getClass.getDeclaredMethod(functionName, argClasses: _*)
+ }
+ invoke(obj, invokeMethod, arguments, input, dataType)
+ }
+ }
+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = CodeGenerator.javaType(dataType)
val obj = targetObject.genCode(ctx)
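A quick sketch of the null semantics on the new interpreted path (the `Target` class here is hypothetical): a null target object short-circuits to null, mirroring the generated code, while a non-null target is invoked reflectively and primitive results come back boxed.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal}
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types._

// Hypothetical target class for illustration.
class Target extends Serializable { def negate(x: Int): Int = -x }

val arg = Seq(BoundReference(0, IntegerType, nullable = false))

// A null target object evaluates to null rather than throwing ...
val nullObj = Literal.create(null, ObjectType(classOf[Target]))
assert(Invoke(nullObj, "negate", IntegerType, arg).eval(InternalRow(1)) == null)

// ... while a non-null target is invoked via reflection; the Int result is
// returned boxed as java.lang.Integer per typeBoxedJavaMapping.
val obj = Literal.create(new Target, ObjectType(classOf[Target]))
assert(Invoke(obj, "negate", IntegerType, arg).eval(InternalRow(1)) == -1)
```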
http://git-wip-us.apache.org/repos/asf/spark/blob/a3552365/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 0edd27c..9bfe291 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -24,11 +24,23 @@ import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
+class InvokeTargetClass extends Serializable {
+ def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0
+ def filterPrimitiveInt(e: Int): Boolean = e > 0
+ def binOp(e1: Int, e2: Double): Double = e1 + e2
+}
+
+class InvokeTargetSubClass extends InvokeTargetClass {
+ override def binOp(e1: Int, e2: Double): Double = e1 - e2
+}
class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -81,6 +93,41 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
}
+ test("SPARK-23583: Invoke should support interpreted execution") {
+ val targetObject = new InvokeTargetClass
+ val funcClass = classOf[InvokeTargetClass]
+ val funcObj = Literal.create(targetObject, ObjectType(funcClass))
+ val targetSubObject = new InvokeTargetSubClass
+ val funcSubObj = Literal.create(targetSubObject, ObjectType(classOf[InvokeTargetSubClass]))
+ val funcNullObj = Literal.create(null, ObjectType(funcClass))
+
+ val inputInt = Seq(BoundReference(0, ObjectType(classOf[Any]), true))
+ val inputPrimitiveInt = Seq(BoundReference(0, IntegerType, false))
+ val inputSum = Seq(BoundReference(0, IntegerType, false), BoundReference(1, DoubleType, false))
+
+ checkObjectExprEvaluation(
+ Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt),
+ java.lang.Boolean.valueOf(true), InternalRow.fromSeq(Seq(Integer.valueOf(1))))
+
+ checkObjectExprEvaluation(
+ Invoke(funcObj, "filterPrimitiveInt", BooleanType, inputPrimitiveInt),
+ false, InternalRow.fromSeq(Seq(-1)))
+
+ checkObjectExprEvaluation(
+ Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt),
+ null, InternalRow.fromSeq(Seq(null)))
+
+ checkObjectExprEvaluation(
+ Invoke(funcNullObj, "filterInt", ObjectType(classOf[Any]), inputInt),
+ null, InternalRow.fromSeq(Seq(Integer.valueOf(1))))
+
+ checkObjectExprEvaluation(
+ Invoke(funcObj, "binOp", DoubleType, inputSum), 1.25, InternalRow.apply(1, 0.25))
+
+ checkObjectExprEvaluation(
+ Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25))
+ }
+
test("SPARK-23585: UnwrapOption should support interpreted execution") {
val cls = classOf[Option[Int]]
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
@@ -105,6 +152,24 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(createExternalRow, Row.fromSeq(Seq(1, "x")), InternalRow.fromSeq(Seq()))
}
+ // by scala values instead of catalyst values.
+ private def checkObjectExprEvaluation(
+ expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
+ val serializer = new JavaSerializer(new SparkConf()).newInstance
+ val resolver = ResolveTimeZone(new SQLConf)
+ val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression)))
+ checkEvaluationWithoutCodegen(expr, expected, inputRow)
+ checkEvaluationWithGeneratedMutableProjection(expr, expected, inputRow)
+ if (GenerateUnsafeProjection.canSupport(expr.dataType)) {
+ checkEvaluationWithUnsafeProjection(
+ expr,
+ expected,
+ inputRow,
+ UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
+ }
+ checkEvaluationWithOptimization(expr, expected, inputRow)
+ }
+
test("SPARK-23594 GetExternalRowField should support interpreted execution") {
val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true)
val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0")