You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2017/08/29 16:09:02 UTC

spark git commit: [SPARK-21848][SQL] Add trait UserDefinedExpression to identify user-defined functions

Repository: spark
Updated Branches:
  refs/heads/master 32fa0b814 -> 8fcbda9c9


[SPARK-21848][SQL] Add trait UserDefinedExpression to identify user-defined functions

## What changes were proposed in this pull request?

Add trait UserDefinedExpression to identify user-defined functions.
UDFs can be expensive, so in the optimizer we may need to avoid executing a UDF multiple times.
E.g.
```scala
table.select(UDF as 'a).select('a, ('a + 1) as 'b)
```
If the UDF is expensive in this case, the optimizer should not collapse the two projections into
```scala
table.select(UDF as 'a, (UDF+1) as 'b)
```

Currently, UDF classes such as PythonUDF and HiveGenericUDF are not defined in catalyst, so rules in catalyst cannot match on them directly.
This PR adds a new trait in catalyst that makes it easier to identify user-defined functions.

## How was this patch tested?

Unit test

Author: Wang Gengliang <lt...@gmail.com>

Closes #19064 from gengliangwang/UDFType.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8fcbda9c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8fcbda9c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8fcbda9c

Branch: refs/heads/master
Commit: 8fcbda9c93175c0d44b0e4deaf10df1a427e03ea
Parents: 32fa0b8
Author: Wang Gengliang <lt...@gmail.com>
Authored: Tue Aug 29 09:08:59 2017 -0700
Committer: gatorsmile <ga...@gmail.com>
Committed: Tue Aug 29 09:08:59 2017 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/Expression.scala     |  6 ++++++
 .../spark/sql/catalyst/expressions/ScalaUDF.scala |  2 +-
 .../spark/sql/execution/aggregate/udaf.scala      |  6 +++++-
 .../spark/sql/execution/python/PythonUDF.scala    |  4 ++--
 .../org/apache/spark/sql/hive/hiveUDFs.scala      | 18 ++++++++++++++----
 5 files changed, 28 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8fcbda9c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 74c4cdd..c058425 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -635,3 +635,9 @@ abstract class TernaryExpression extends Expression {
     }
   }
 }
+
+/**
+ * Common base trait for user-defined functions, including UDF/UDAF/UDTF of different languages
+ * and Hive function wrappers.
+ */
+trait UserDefinedExpression

http://git-wip-us.apache.org/repos/asf/spark/blob/8fcbda9c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 9df0e2e..527f167 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -47,7 +47,7 @@ case class ScalaUDF(
     udfName: Option[String] = None,
     nullable: Boolean = true,
     udfDeterministic: Boolean = true)
-  extends Expression with ImplicitCastInputTypes with NonSQLExpression {
+  extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression {
 
   override def deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8fcbda9c/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index ae5e2c6..fec1add 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -324,7 +324,11 @@ case class ScalaUDAF(
     udaf: UserDefinedAggregateFunction,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends ImperativeAggregate with NonSQLExpression with Logging with ImplicitCastInputTypes {
+  extends ImperativeAggregate
+  with NonSQLExpression
+  with Logging
+  with ImplicitCastInputTypes
+  with UserDefinedExpression {
 
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)

http://git-wip-us.apache.org/repos/asf/spark/blob/8fcbda9c/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
index 59d7e8d..7ebbdb9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.python
 
 import org.apache.spark.api.python.PythonFunction
-import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable}
+import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable, UserDefinedExpression}
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -29,7 +29,7 @@ case class PythonUDF(
     func: PythonFunction,
     dataType: DataType,
     children: Seq[Expression])
-  extends Expression with Unevaluable with NonSQLExpression {
+  extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression {
 
   override def toString: String = s"$name(${children.mkString(", ")})"
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8fcbda9c/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index a83ad61..e9bdcf0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -42,7 +42,11 @@ import org.apache.spark.sql.types._
 
 private[hive] case class HiveSimpleUDF(
     name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
-  extends Expression with HiveInspectors with CodegenFallback with Logging {
+  extends Expression
+  with HiveInspectors
+  with CodegenFallback
+  with Logging
+  with UserDefinedExpression {
 
   override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic)
 
@@ -119,7 +123,11 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp
 
 private[hive] case class HiveGenericUDF(
     name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
-  extends Expression with HiveInspectors with CodegenFallback with Logging {
+  extends Expression
+  with HiveInspectors
+  with CodegenFallback
+  with Logging
+  with UserDefinedExpression {
 
   override def nullable: Boolean = true
 
@@ -191,7 +199,7 @@ private[hive] case class HiveGenericUDTF(
     name: String,
     funcWrapper: HiveFunctionWrapper,
     children: Seq[Expression])
-  extends Generator with HiveInspectors with CodegenFallback {
+  extends Generator with HiveInspectors with CodegenFallback with UserDefinedExpression {
 
   @transient
   protected lazy val function: GenericUDTF = {
@@ -303,7 +311,9 @@ private[hive] case class HiveUDAFFunction(
     isUDAFBridgeRequired: Boolean = false,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer] with HiveInspectors {
+  extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer]
+  with HiveInspectors
+  with UserDefinedExpression {
 
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org