You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2023/03/22 04:15:26 UTC
[spark] branch master updated: [SPARK-42052][SQL] Codegen Support for HiveSimpleUDF
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5825db81e00 [SPARK-42052][SQL] Codegen Support for HiveSimpleUDF
5825db81e00 is described below
commit 5825db81e0059a4895b4f59d57dec67b0bc618b4
Author: panbingkun <pb...@gmail.com>
AuthorDate: Wed Mar 22 12:14:58 2023 +0800
[SPARK-42052][SQL] Codegen Support for HiveSimpleUDF
### What changes were proposed in this pull request?
- As a subtask of [SPARK-42050](https://issues.apache.org/jira/browse/SPARK-42050), this PR adds Codegen Support for HiveSimpleUDF
- Extract a `HiveUDFEvaluatorBase` class for the common behaviors of HiveSimpleUDFEvaluator & HiveGenericUDFEvaluator.
### Why are the changes needed?
- Improve codegen coverage and performance.
- Following https://github.com/apache/spark/pull/39949. Make the code more concise.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Add new UT.
Pass GA.
Closes #40397 from panbingkun/refactor_HiveSimpleUDF.
Authored-by: panbingkun <pb...@gmail.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../apache/spark/sql/hive/hiveUDFEvaluators.scala | 148 +++++++++++++++++++++
.../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 147 ++++++--------------
.../spark/sql/hive/execution/HiveUDFSuite.scala | 42 ++++++
3 files changed, 232 insertions(+), 105 deletions(-)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
new file mode 100644
index 00000000000..094f8ba7a0f
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, UDF}
+import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
+import org.apache.spark.sql.types.DataType
+
+abstract class HiveUDFEvaluatorBase[UDFType <: AnyRef](
+ funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+ extends HiveInspectors with Serializable {
+
+ @transient
+ lazy val function = funcWrapper.createFunction[UDFType]()
+
+ @transient
+ lazy val isUDFDeterministic = {
+ val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
+ udfType != null && udfType.deterministic() && !udfType.stateful()
+ }
+
+ def returnType: DataType
+
+ def setArg(index: Int, arg: Any): Unit
+
+ def doEvaluate(): Any
+
+ final def evaluate(): Any = {
+ try {
+ doEvaluate()
+ } catch {
+ case e: Throwable =>
+ throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
+ s"${funcWrapper.functionClassName}",
+ s"${children.map(_.dataType.catalogString).mkString(", ")}",
+ s"${returnType.catalogString}",
+ e)
+ }
+ }
+}
+
+class HiveSimpleUDFEvaluator(
+ funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+ extends HiveUDFEvaluatorBase[UDF](funcWrapper, children) {
+
+ @transient
+ lazy val method = function.getResolver.
+ getEvalMethod(children.map(_.dataType.toTypeInfo).asJava)
+
+ @transient
+ private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
+
+ @transient
+ private lazy val arguments = children.map(toInspector).toArray
+
+ // Create parameter converters
+ @transient
+ private lazy val conversionHelper = new ConversionHelper(method, arguments)
+
+ @transient
+ private lazy val inputs: Array[AnyRef] = new Array[AnyRef](children.length)
+
+ override def returnType: DataType = javaTypeToDataType(method.getGenericReturnType)
+
+ override def setArg(index: Int, arg: Any): Unit = {
+ inputs(index) = wrappers(index)(arg).asInstanceOf[AnyRef]
+ }
+
+ @transient
+ private lazy val unwrapper: Any => Any =
+ unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector(
+ method.getGenericReturnType, ObjectInspectorOptions.JAVA))
+
+ override def doEvaluate(): Any = {
+ val ret = FunctionRegistry.invoke(
+ method,
+ function,
+ conversionHelper.convertIfNecessary(inputs: _*): _*)
+ unwrapper(ret)
+ }
+}
+
+class HiveGenericUDFEvaluator(
+ funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+ extends HiveUDFEvaluatorBase[GenericUDF](funcWrapper, children) {
+
+ @transient
+ private lazy val argumentInspectors = children.map(toInspector)
+
+ @transient
+ lazy val returnInspector = {
+ function.initializeAndFoldConstants(argumentInspectors.toArray)
+ }
+
+ @transient
+ private lazy val deferredObjects: Array[DeferredObject] = argumentInspectors.zip(children).map {
+ case (inspect, child) => new DeferredObjectAdapter(inspect, child.dataType)
+ }.toArray[DeferredObject]
+
+ @transient
+ private lazy val unwrapper: Any => Any = unwrapperFor(returnInspector)
+
+ override def returnType: DataType = inspectorToDataType(returnInspector)
+
+ def setArg(index: Int, arg: Any): Unit =
+ deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(arg)
+
+ override def doEvaluate(): Any = unwrapper(function.evaluate(deferredObjects))
+}
+
+// Adapter from Catalyst ExpressionResult to Hive DeferredObject
+private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataType)
+ extends DeferredObject with HiveInspectors {
+
+ private val wrapper = wrapperFor(oi, dataType)
+ private var func: Any = _
+ def set(func: Any): Unit = {
+ this.func = func
+ }
+ override def prepare(i: Int): Unit = {}
+ override def get(): AnyRef = wrapper(func).asInstanceOf[AnyRef]
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 98b2258ea13..b07a1b717e7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -23,15 +23,10 @@ import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.hive.ql.exec._
-import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
import org.apache.hadoop.hive.ql.udf.generic._
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory}
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
-import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -49,56 +44,26 @@ private[hive] case class HiveSimpleUDF(
name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
extends Expression
with HiveInspectors
- with CodegenFallback
- with Logging
with UserDefinedExpression {
- override lazy val deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic)
-
- override def nullable: Boolean = true
-
- @transient
- lazy val function = funcWrapper.createFunction[UDF]()
-
- @transient
- private lazy val method =
- function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo).asJava)
-
- @transient
- private lazy val arguments = children.map(toInspector).toArray
-
@transient
- private lazy val isUDFDeterministic = {
- val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
- udfType != null && udfType.deterministic() && !udfType.stateful()
- }
+ private lazy val evaluator = new HiveSimpleUDFEvaluator(funcWrapper, children)
- override def foldable: Boolean = isUDFDeterministic && children.forall(_.foldable)
+ override lazy val deterministic: Boolean =
+ evaluator.isUDFDeterministic && children.forall(_.deterministic)
- // Create parameter converters
- @transient
- private lazy val conversionHelper = new ConversionHelper(method, arguments)
+ override def nullable: Boolean = true
- override lazy val dataType = javaTypeToDataType(method.getGenericReturnType)
+ override def foldable: Boolean = evaluator.isUDFDeterministic && children.forall(_.foldable)
- @transient
- private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
-
- @transient
- lazy val unwrapper = unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector(
- method.getGenericReturnType, ObjectInspectorOptions.JAVA))
-
- @transient
- private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
+ override lazy val dataType: DataType = javaTypeToDataType(evaluator.method.getGenericReturnType)
// TODO: Finish input output types.
override def eval(input: InternalRow): Any = {
- val inputs = wrap(children.map(_.eval(input)), wrappers, cached)
- val ret = FunctionRegistry.invoke(
- method,
- function,
- conversionHelper.convertIfNecessary(inputs : _*): _*)
- unwrapper(ret)
+ children.zipWithIndex.map {
+ case (child, idx) => evaluator.setArg(idx, child.eval(input))
+ }
+ evaluator.evaluate()
}
override def toString: String = {
@@ -111,19 +76,37 @@ private[hive] case class HiveSimpleUDF(
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(children = newChildren)
-}
-// Adapter from Catalyst ExpressionResult to Hive DeferredObject
-private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataType)
- extends DeferredObject with HiveInspectors {
+ protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val refEvaluator = ctx.addReferenceObj("evaluator", evaluator)
+ val evals = children.map(_.genCode(ctx))
+
+ val setValues = evals.zipWithIndex.map {
+ case (eval, i) =>
+ s"""
+ |if (${eval.isNull}) {
+ | $refEvaluator.setArg($i, null);
+ |} else {
+ | $refEvaluator.setArg($i, ${eval.value});
+ |}
+ |""".stripMargin
+ }
- private val wrapper = wrapperFor(oi, dataType)
- private var func: Any = _
- def set(func: Any): Unit = {
- this.func = func
+ val resultType = CodeGenerator.boxedType(dataType)
+ val resultTerm = ctx.freshName("result")
+ ev.copy(code =
+ code"""
+ |${evals.map(_.code).mkString("\n")}
+ |${setValues.mkString("\n")}
+ |$resultType $resultTerm = ($resultType) $refEvaluator.evaluate();
+ |boolean ${ev.isNull} = $resultTerm == null;
+ |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+ |if (!${ev.isNull}) {
+ | ${ev.value} = $resultTerm;
+ |}
+ |""".stripMargin
+ )
}
- override def prepare(i: Int): Unit = {}
- override def get(): AnyRef = wrapper(func).asInstanceOf[AnyRef]
}
private[hive] case class HiveGenericUDF(
@@ -135,9 +118,9 @@ private[hive] case class HiveGenericUDF(
override def nullable: Boolean = true
override lazy val deterministic: Boolean =
- isUDFDeterministic && children.forall(_.deterministic)
+ evaluator.isUDFDeterministic && children.forall(_.deterministic)
- override def foldable: Boolean = isUDFDeterministic &&
+ override def foldable: Boolean = evaluator.isUDFDeterministic &&
evaluator.returnInspector.isInstanceOf[ConstantObjectInspector]
override lazy val dataType: DataType = inspectorToDataType(evaluator.returnInspector)
@@ -145,12 +128,6 @@ private[hive] case class HiveGenericUDF(
@transient
private lazy val evaluator = new HiveGenericUDFEvaluator(funcWrapper, children)
- @transient
- private val isUDFDeterministic = {
- val udfType = evaluator.function.getClass.getAnnotation(classOf[HiveUDFType])
- udfType != null && udfType.deterministic() && !udfType.stateful()
- }
-
override def eval(input: InternalRow): Any = {
children.zipWithIndex.map {
case (child, idx) => evaluator.setArg(idx, child.eval(input))
@@ -188,18 +165,8 @@ private[hive] case class HiveGenericUDF(
code"""
|${evals.map(_.code).mkString("\n")}
|${setValues.mkString("\n")}
- |$resultType $resultTerm = null;
- |boolean ${ev.isNull} = false;
- |try {
- | $resultTerm = ($resultType) $refEvaluator.evaluate();
- | ${ev.isNull} = $resultTerm == null;
- |} catch (Throwable e) {
- | throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
- | "${funcWrapper.functionClassName}",
- | "${children.map(_.dataType.catalogString).mkString(", ")}",
- | "${dataType.catalogString}",
- | e);
- |}
+ |$resultType $resultTerm = ($resultType) $refEvaluator.evaluate();
+ |boolean ${ev.isNull} = $resultTerm == null;
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|if (!${ev.isNull}) {
| ${ev.value} = $resultTerm;
@@ -209,36 +176,6 @@ private[hive] case class HiveGenericUDF(
}
}
-class HiveGenericUDFEvaluator(
- funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
- extends HiveInspectors
- with Serializable {
-
- @transient
- lazy val function = funcWrapper.createFunction[GenericUDF]()
-
- @transient
- private lazy val argumentInspectors = children.map(toInspector)
-
- @transient
- lazy val returnInspector = {
- function.initializeAndFoldConstants(argumentInspectors.toArray)
- }
-
- @transient
- private lazy val deferredObjects: Array[DeferredObject] = argumentInspectors.zip(children).map {
- case (inspect, child) => new DeferredObjectAdapter(inspect, child.dataType)
- }.toArray[DeferredObject]
-
- @transient
- private lazy val unwrapper: Any => Any = unwrapperFor(returnInspector)
-
- def setArg(index: Int, arg: Any): Unit =
- deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(arg)
-
- def evaluate(): Any = unwrapper(function.evaluate(deferredObjects))
-}
-
/**
* Converts a Hive Generic User Defined Table Generating Function (UDTF) to a
* `Generator`. Note that the semantics of Generators do not allow
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index baa25843d48..8fb9209f9cb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hive.ql.exec.UDF
+import org.apache.hadoop.hive.ql.metadata.HiveException
import org.apache.hadoop.hive.ql.udf.{UDAFPercentile, UDFType}
import org.apache.hadoop.hive.ql.udf.generic._
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject
@@ -743,6 +744,38 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
}
}
}
+
+ test("SPARK-42052: HiveSimpleUDF Codegen Support") {
+ withUserDefinedFunction("CodeGenHiveSimpleUDF" -> false) {
+ sql(s"CREATE FUNCTION CodeGenHiveSimpleUDF AS '${classOf[UDFStringString].getName}'")
+ withTable("HiveSimpleUDFTable") {
+ sql(s"create table HiveSimpleUDFTable as select 'Spark SQL' as v")
+ val df = sql("SELECT CodeGenHiveSimpleUDF('Hello', v) from HiveSimpleUDFTable")
+ val plan = df.queryExecution.executedPlan
+ assert(plan.isInstanceOf[WholeStageCodegenExec])
+ checkAnswer(df, Seq(Row("Hello Spark SQL")))
+ }
+ }
+ }
+
+ test("SPARK-42052: HiveSimpleUDF Codegen Support w/ execution failure") {
+ withUserDefinedFunction("CodeGenHiveSimpleUDF" -> false) {
+ sql(s"CREATE FUNCTION CodeGenHiveSimpleUDF AS '${classOf[SimpleUDFAssertTrue].getName}'")
+ withTable("HiveSimpleUDFTable") {
+ sql(s"create table HiveSimpleUDFTable as select false as v")
+ val df = sql("SELECT CodeGenHiveSimpleUDF(v) from HiveSimpleUDFTable")
+ checkError(
+ exception = intercept[SparkException](df.collect()).getCause.asInstanceOf[SparkException],
+ errorClass = "FAILED_EXECUTE_UDF",
+ parameters = Map(
+ "functionName" -> s"${classOf[SimpleUDFAssertTrue].getName}",
+ "signature" -> "boolean",
+ "result" -> "boolean"
+ )
+ )
+ }
+ }
+ }
}
class TestPair(x: Int, y: Int) extends Writable with Serializable {
@@ -844,3 +877,12 @@ class ListFiles extends UDF {
if (fileArray != null) Arrays.asList(fileArray: _*) else new ArrayList[String]()
}
}
+
+class SimpleUDFAssertTrue extends UDF {
+ def evaluate(condition: Boolean): Boolean = {
+ if (!condition) {
+ throw new HiveException("ASSERT_TRUE(): assertion failed.");
+ }
+ condition
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org