Posted to commits@spark.apache.org by ma...@apache.org on 2015/06/30 00:27:24 UTC
[2/4] spark git commit: [SPARK-8478] [SQL] Harmonize UDF-related code
to use uniformly UDF instead of Udf
http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
new file mode 100644
index 0000000..9e1cff0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -0,0 +1,292 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution
+
+import java.util.{List => JList, Map => JMap}
+
+import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+
+import net.razorvine.pickle.{Pickler, Unpickler}
+
+import org.apache.spark.{Accumulator, Logging => SparkLogging}
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.api.python.{PythonBroadcast, PythonRDD}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]].
+ */
+private[spark] case class PythonUDF(
+ name: String,
+ command: Array[Byte],
+ envVars: JMap[String, String],
+ pythonIncludes: JList[String],
+ pythonExec: String,
+ pythonVer: String,
+ broadcastVars: JList[Broadcast[PythonBroadcast]],
+ accumulator: Accumulator[JList[Array[Byte]]],
+ dataType: DataType,
+ children: Seq[Expression]) extends Expression with SparkLogging {
+
+ override def toString: String = s"PythonUDF#$name(${children.mkString(",")})"
+
+ override def nullable: Boolean = true
+
+ override def eval(input: InternalRow): Any = {
+ throw new UnsupportedOperationException("PythonUDFs can not be directly evaluated.")
+ }
+}
+
+/**
+ * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
+ * alone in a batch.
+ *
+ * This has the limitation that the input to the Python UDF is not allowed to include attributes from
+ * multiple child operators.
+ */
+private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ // Skip EvaluatePython nodes.
+ case plan: EvaluatePython => plan
+
+ case plan: LogicalPlan if plan.resolved =>
+ // Extract any PythonUDFs from the current operator.
+ val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
+ if (udfs.isEmpty) {
+ // If there aren't any, we are done.
+ plan
+ } else {
+ // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time)
+ // If there is more than one, we will add another evaluation operator in a subsequent pass.
+ udfs.find(_.resolved) match {
+ case Some(udf) =>
+ var evaluation: EvaluatePython = null
+
+ // Rewrite the child that has the input required for the UDF
+ val newChildren = plan.children.map { child =>
+ // Check to make sure that the UDF can be evaluated with only the input of this child.
+ // Other cases are disallowed as they are ambiguous or would require a cartesian
+ // product.
+ if (udf.references.subsetOf(child.outputSet)) {
+ evaluation = EvaluatePython(udf, child)
+ evaluation
+ } else if (udf.references.intersect(child.outputSet).nonEmpty) {
+ sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
+ } else {
+ child
+ }
+ }
+
+ assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.")
+
+ // Trim away the new UDF value if it was only used for filtering or something.
+ logical.Project(
+ plan.output,
+ plan.transformExpressions {
+ case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
+ }.withNewChildren(newChildren))
+
+ case None =>
+ // If there is no Python UDF that is resolved, skip this round.
+ plan
+ }
+ }
+ }
+}
+
+object EvaluatePython {
+ def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython =
+ new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())
+
+ /**
+ * Helper for converting a Scala object to a Java object suitable for PySpark serialization.
+ */
+ def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
+ case (null, _) => null
+
+ case (row: Row, struct: StructType) =>
+ val fields = struct.fields.map(field => field.dataType)
+ row.toSeq.zip(fields).map {
+ case (obj, dataType) => toJava(obj, dataType)
+ }.toArray
+
+ case (seq: Seq[Any], array: ArrayType) =>
+ seq.map(x => toJava(x, array.elementType)).asJava
+ case (list: JList[_], array: ArrayType) =>
+ list.map(x => toJava(x, array.elementType)).asJava
+ case (arr, array: ArrayType) if arr.getClass.isArray =>
+ arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
+
+ case (obj: Map[_, _], mt: MapType) => obj.map {
+ case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType))
+ }.asJava
+
+ case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType)
+
+ case (date: Int, DateType) => DateTimeUtils.toJavaDate(date)
+ case (t: Long, TimestampType) => DateTimeUtils.toJavaTimestamp(t)
+ case (s: UTF8String, StringType) => s.toString
+
+ // Pyrolite can handle Timestamp and Decimal
+ case (other, _) => other
+ }
+
+ /**
+ * Converts a Row into a Java Array (to be pickled into Python).
+ */
+ def rowToArray(row: Row, fields: Seq[DataType]): Array[Any] = {
+ // TODO: this is slow!
+ row.toSeq.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray
+ }
+
+ // Converts value to the type specified by the data type.
+ // Because Python does not have data types for TimestampType, FloatType, ShortType, and
+ // ByteType, we need to explicitly convert values in columns of these data types to the desired
+ // JVM data types.
+ def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
+ // TODO: We should check nullable
+ case (null, _) => null
+
+ case (c: java.util.List[_], ArrayType(elementType, _)) =>
+ c.map { e => fromJava(e, elementType)}: Seq[Any]
+
+ case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
+ c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)): Seq[Any]
+
+ case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map {
+ case (key, value) => (fromJava(key, keyType), fromJava(value, valueType))
+ }.toMap
+
+ case (c, StructType(fields)) if c.getClass.isArray =>
+ new GenericInternalRow(c.asInstanceOf[Array[_]].zip(fields).map {
+ case (e, f) => fromJava(e, f.dataType)
+ })
+
+ case (c: java.util.Calendar, DateType) =>
+ DateTimeUtils.fromJavaDate(new java.sql.Date(c.getTimeInMillis))
+
+ case (c: java.util.Calendar, TimestampType) =>
+ c.getTimeInMillis * 10000L
+ case (t: java.sql.Timestamp, TimestampType) =>
+ DateTimeUtils.fromJavaTimestamp(t)
+
+ case (_, udt: UserDefinedType[_]) =>
+ fromJava(obj, udt.sqlType)
+
+ case (c: Int, ByteType) => c.toByte
+ case (c: Long, ByteType) => c.toByte
+ case (c: Int, ShortType) => c.toShort
+ case (c: Long, ShortType) => c.toShort
+ case (c: Long, IntegerType) => c.toInt
+ case (c: Int, LongType) => c.toLong
+ case (c: Double, FloatType) => c.toFloat
+ case (c: String, StringType) => UTF8String.fromString(c)
+ case (c, StringType) =>
+ // If we get here, c is not a string. Call toString on it.
+ UTF8String.fromString(c.toString)
+
+ case (c, _) => c
+ }
+}
+
+/**
+ * :: DeveloperApi ::
+ * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple.
+ */
+@DeveloperApi
+case class EvaluatePython(
+ udf: PythonUDF,
+ child: LogicalPlan,
+ resultAttribute: AttributeReference)
+ extends logical.UnaryNode {
+
+ def output: Seq[Attribute] = child.output :+ resultAttribute
+
+ // References should not include the produced attribute.
+ override def references: AttributeSet = udf.references
+}
+
+/**
+ * :: DeveloperApi ::
+ * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time.
+ * The input data is zipped with the result of the udf evaluation.
+ */
+@DeveloperApi
+case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
+ extends SparkPlan {
+
+ def children: Seq[SparkPlan] = child :: Nil
+
+ protected override def doExecute(): RDD[InternalRow] = {
+ val childResults = child.execute().map(_.copy())
+
+ val parent = childResults.mapPartitions { iter =>
+ val pickle = new Pickler
+ val currentRow = newMutableProjection(udf.children, child.output)()
+ val fields = udf.children.map(_.dataType)
+ iter.grouped(1000).map { inputRows =>
+ val toBePickled = inputRows.map { row =>
+ EvaluatePython.rowToArray(currentRow(row), fields)
+ }.toArray
+ pickle.dumps(toBePickled)
+ }
+ }
+
+ val pyRDD = new PythonRDD(
+ parent,
+ udf.command,
+ udf.envVars,
+ udf.pythonIncludes,
+ false,
+ udf.pythonExec,
+ udf.pythonVer,
+ udf.broadcastVars,
+ udf.accumulator
+ ).mapPartitions { iter =>
+ val pickle = new Unpickler
+ iter.flatMap { pickedResult =>
+ val unpickledBatch = pickle.loads(pickedResult)
+ unpickledBatch.asInstanceOf[java.util.ArrayList[Any]]
+ }
+ }.mapPartitions { iter =>
+ val row = new GenericMutableRow(1)
+ iter.map { result =>
+ row(0) = EvaluatePython.fromJava(result, udf.dataType)
+ row: InternalRow
+ }
+ }
+
+ childResults.zip(pyRDD).mapPartitions { iter =>
+ val joinedRow = new JoinedRow()
+ iter.map {
+ case (row, udfResult) =>
+ joinedRow(row, udfResult)
+ }
+ }
+ }
+}
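
The EvaluatePython helpers above shuttle values between Catalyst rows and the plain Java objects that Pyrolite can pickle, by pattern matching on (value, DataType) pairs. Below is a self-contained sketch of that conversion style, using toy types instead of Spark's DataType hierarchy; all names in it are illustrative and not part of this patch.

// Toy stand-ins for Spark's DataType hierarchy; illustrative only.
sealed trait ToyType
case object ToyInt extends ToyType
case object ToyString extends ToyType
case class ToyArray(element: ToyType) extends ToyType

object ToyConversions {
  import scala.collection.JavaConverters._

  // Mirrors the shape of EvaluatePython.fromJava: coerce an incoming value
  // into the JVM representation that the declared type expects.
  def fromJava(obj: Any, dataType: ToyType): Any = (obj, dataType) match {
    case (null, _) => null
    case (c: Long, ToyInt) => c.toInt                        // narrow Long to Int
    case (c: java.util.List[_], ToyArray(elementType)) =>
      c.asScala.map(e => fromJava(e, elementType))           // recurse on elements
    case (c, ToyString) => c.toString                        // anything becomes a String
    case (c, _) => c                                         // already the right shape
  }
}

object ToyConversionsDemo extends App {
  println(ToyConversions.fromJava(42L, ToyInt))                                        // 42
  println(ToyConversions.fromJava(java.util.Arrays.asList(1, 2, 3), ToyArray(ToyInt))) // Buffer(1, 2, 3)
  println(ToyConversions.fromJava(7, ToyString))                                       // 7
}
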
http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
deleted file mode 100644
index 036f5d2..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ /dev/null
@@ -1,292 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-
-package org.apache.spark.sql.execution
-
-import java.util.{List => JList, Map => JMap}
-
-import scala.collection.JavaConversions._
-import scala.collection.JavaConverters._
-
-import net.razorvine.pickle.{Pickler, Unpickler}
-
-import org.apache.spark.{Accumulator, Logging => SparkLogging}
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.api.python.{PythonBroadcast, PythonRDD}
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
-
-/**
- * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]].
- */
-private[spark] case class PythonUDF(
- name: String,
- command: Array[Byte],
- envVars: JMap[String, String],
- pythonIncludes: JList[String],
- pythonExec: String,
- pythonVer: String,
- broadcastVars: JList[Broadcast[PythonBroadcast]],
- accumulator: Accumulator[JList[Array[Byte]]],
- dataType: DataType,
- children: Seq[Expression]) extends Expression with SparkLogging {
-
- override def toString: String = s"PythonUDF#$name(${children.mkString(",")})"
-
- override def nullable: Boolean = true
-
- override def eval(input: InternalRow): Any = {
- throw new UnsupportedOperationException("PythonUDFs can not be directly evaluated.")
- }
-}
-
-/**
- * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
- * alone in a batch.
- *
- * This has the limitation that the input to the Python UDF is not allowed to include attributes from
- * multiple child operators.
- */
-private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- // Skip EvaluatePython nodes.
- case plan: EvaluatePython => plan
-
- case plan: LogicalPlan if plan.resolved =>
- // Extract any PythonUDFs from the current operator.
- val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
- if (udfs.isEmpty) {
- // If there aren't any, we are done.
- plan
- } else {
- // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time)
- // If there is more than one, we will add another evaluation operator in a subsequent pass.
- udfs.find(_.resolved) match {
- case Some(udf) =>
- var evaluation: EvaluatePython = null
-
- // Rewrite the child that has the input required for the UDF
- val newChildren = plan.children.map { child =>
- // Check to make sure that the UDF can be evaluated with only the input of this child.
- // Other cases are disallowed as they are ambiguous or would require a cartesian
- // product.
- if (udf.references.subsetOf(child.outputSet)) {
- evaluation = EvaluatePython(udf, child)
- evaluation
- } else if (udf.references.intersect(child.outputSet).nonEmpty) {
- sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
- } else {
- child
- }
- }
-
- assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.")
-
- // Trim away the new UDF value if it was only used for filtering or something.
- logical.Project(
- plan.output,
- plan.transformExpressions {
- case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
- }.withNewChildren(newChildren))
-
- case None =>
- // If there is no Python UDF that is resolved, skip this round.
- plan
- }
- }
- }
-}
-
-object EvaluatePython {
- def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython =
- new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())
-
- /**
- * Helper for converting a Scala object to a Java object suitable for PySpark serialization.
- */
- def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
- case (null, _) => null
-
- case (row: Row, struct: StructType) =>
- val fields = struct.fields.map(field => field.dataType)
- row.toSeq.zip(fields).map {
- case (obj, dataType) => toJava(obj, dataType)
- }.toArray
-
- case (seq: Seq[Any], array: ArrayType) =>
- seq.map(x => toJava(x, array.elementType)).asJava
- case (list: JList[_], array: ArrayType) =>
- list.map(x => toJava(x, array.elementType)).asJava
- case (arr, array: ArrayType) if arr.getClass.isArray =>
- arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
-
- case (obj: Map[_, _], mt: MapType) => obj.map {
- case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType))
- }.asJava
-
- case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType)
-
- case (date: Int, DateType) => DateTimeUtils.toJavaDate(date)
- case (t: Long, TimestampType) => DateTimeUtils.toJavaTimestamp(t)
- case (s: UTF8String, StringType) => s.toString
-
- // Pyrolite can handle Timestamp and Decimal
- case (other, _) => other
- }
-
- /**
- * Converts a Row into a Java Array (to be pickled into Python).
- */
- def rowToArray(row: Row, fields: Seq[DataType]): Array[Any] = {
- // TODO: this is slow!
- row.toSeq.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray
- }
-
- // Converts value to the type specified by the data type.
- // Because Python does not have data types for TimestampType, FloatType, ShortType, and
- // ByteType, we need to explicitly convert values in columns of these data types to the desired
- // JVM data types.
- def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
- // TODO: We should check nullable
- case (null, _) => null
-
- case (c: java.util.List[_], ArrayType(elementType, _)) =>
- c.map { e => fromJava(e, elementType)}: Seq[Any]
-
- case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
- c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)): Seq[Any]
-
- case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map {
- case (key, value) => (fromJava(key, keyType), fromJava(value, valueType))
- }.toMap
-
- case (c, StructType(fields)) if c.getClass.isArray =>
- new GenericInternalRow(c.asInstanceOf[Array[_]].zip(fields).map {
- case (e, f) => fromJava(e, f.dataType)
- })
-
- case (c: java.util.Calendar, DateType) =>
- DateTimeUtils.fromJavaDate(new java.sql.Date(c.getTimeInMillis))
-
- case (c: java.util.Calendar, TimestampType) =>
- c.getTimeInMillis * 10000L
- case (t: java.sql.Timestamp, TimestampType) =>
- DateTimeUtils.fromJavaTimestamp(t)
-
- case (_, udt: UserDefinedType[_]) =>
- fromJava(obj, udt.sqlType)
-
- case (c: Int, ByteType) => c.toByte
- case (c: Long, ByteType) => c.toByte
- case (c: Int, ShortType) => c.toShort
- case (c: Long, ShortType) => c.toShort
- case (c: Long, IntegerType) => c.toInt
- case (c: Int, LongType) => c.toLong
- case (c: Double, FloatType) => c.toFloat
- case (c: String, StringType) => UTF8String.fromString(c)
- case (c, StringType) =>
- // If we get here, c is not a string. Call toString on it.
- UTF8String.fromString(c.toString)
-
- case (c, _) => c
- }
-}
-
-/**
- * :: DeveloperApi ::
- * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple.
- */
-@DeveloperApi
-case class EvaluatePython(
- udf: PythonUDF,
- child: LogicalPlan,
- resultAttribute: AttributeReference)
- extends logical.UnaryNode {
-
- def output: Seq[Attribute] = child.output :+ resultAttribute
-
- // References should not include the produced attribute.
- override def references: AttributeSet = udf.references
-}
-
-/**
- * :: DeveloperApi ::
- * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time.
- * The input data is zipped with the result of the udf evaluation.
- */
-@DeveloperApi
-case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
- extends SparkPlan {
-
- def children: Seq[SparkPlan] = child :: Nil
-
- protected override def doExecute(): RDD[InternalRow] = {
- val childResults = child.execute().map(_.copy())
-
- val parent = childResults.mapPartitions { iter =>
- val pickle = new Pickler
- val currentRow = newMutableProjection(udf.children, child.output)()
- val fields = udf.children.map(_.dataType)
- iter.grouped(1000).map { inputRows =>
- val toBePickled = inputRows.map { row =>
- EvaluatePython.rowToArray(currentRow(row), fields)
- }.toArray
- pickle.dumps(toBePickled)
- }
- }
-
- val pyRDD = new PythonRDD(
- parent,
- udf.command,
- udf.envVars,
- udf.pythonIncludes,
- false,
- udf.pythonExec,
- udf.pythonVer,
- udf.broadcastVars,
- udf.accumulator
- ).mapPartitions { iter =>
- val pickle = new Unpickler
- iter.flatMap { pickedResult =>
- val unpickledBatch = pickle.loads(pickedResult)
- unpickledBatch.asInstanceOf[java.util.ArrayList[Any]]
- }
- }.mapPartitions { iter =>
- val row = new GenericMutableRow(1)
- iter.map { result =>
- row(0) = EvaluatePython.fromJava(result, udf.dataType)
- row: InternalRow
- }
- }
-
- childResults.zip(pyRDD).mapPartitions { iter =>
- val joinedRow = new JoinedRow()
- iter.map {
- case (row, udfResult) =>
- joinedRow(row, udfResult)
- }
- }
- }
-}
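
Both the old and the new file implement BatchPythonEvaluation with the same batch-then-zip shape: project the UDF inputs, group rows into batches of 1000, evaluate each batch in a single round trip, and zip the results back onto the original rows. Here is a rough, Spark-free sketch of that shape over plain iterators, with a stand-in evaluateBatch in place of the Python worker; the names are made up.

object BatchThenZipSketch extends App {
  type Row = Map[String, Any]

  // Stand-in for the pickle/unpickle round trip through the Python worker:
  // here the "UDF" just squares its input column.
  def evaluateBatch(batch: Seq[Int]): Seq[Int] = batch.map(v => v * v)

  def withUdfColumn(rows: Iterator[Row], inputCol: String, outputCol: String): Iterator[Row] =
    rows.grouped(1000).flatMap { batch =>                            // mirrors iter.grouped(1000)
      val inputs  = batch.map(r => r(inputCol).asInstanceOf[Int])    // project the UDF inputs
      val results = evaluateBatch(inputs)                            // one call per batch, not per row
      batch.zip(results).map { case (row, result) =>                 // zip results back onto the rows
        row + (outputCol -> result)
      }
    }

  val rows = Iterator(Map("id" -> 1, "value" -> 2), Map("id" -> 2, "value" -> 5))
  withUdfColumn(rows, "value", "squared").foreach(println)
  // Map(id -> 1, value -> 2, squared -> 4)
  // Map(id -> 2, value -> 5, squared -> 25)
}
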
http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 5422e06..4d9a019 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1509,7 +1509,7 @@ object functions {
(0 to 10).map { x =>
val args = (1 to x).map(i => s"arg$i: Column").mkString(", ")
val fTypes = Seq.fill(x + 1)("_").mkString(", ")
- val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ")
+ val argsInUDF = (1 to x).map(i => s"arg$i.expr").mkString(", ")
println(s"""
/**
* Call a Scala function of ${x} arguments as user-defined function (UDF). This requires
@@ -1521,7 +1521,7 @@ object functions {
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = {
- ScalaUdf(f, returnType, Seq($argsInUdf))
+ ScalaUDF(f, returnType, Seq($argsInUDF))
}""")
}
}
@@ -1659,7 +1659,7 @@ object functions {
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function0[_], returnType: DataType): Column = {
- ScalaUdf(f, returnType, Seq())
+ ScalaUDF(f, returnType, Seq())
}
/**
@@ -1672,7 +1672,7 @@ object functions {
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = {
- ScalaUdf(f, returnType, Seq(arg1.expr))
+ ScalaUDF(f, returnType, Seq(arg1.expr))
}
/**
@@ -1685,7 +1685,7 @@ object functions {
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = {
- ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr))
+ ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr))
}
/**
@@ -1698,7 +1698,7 @@ object functions {
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = {
- ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr))
+ ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr))
}
/**
@@ -1711,7 +1711,7 @@ object functions {
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = {
- ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr))
+ ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr))
}
/**
@@ -1724,7 +1724,7 @@ object functions {
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = {
- ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr))
+ ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr))
}
/**
@@ -1737,7 +1737,7 @@ object functions {
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = {
- ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr))
+ ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr))
}
/**
@@ -1750,7 +1750,7 @@ object functions {
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = {
- ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr))
+ ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr))
}
/**
@@ -1763,7 +1763,7 @@ object functions {
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = {
- ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr))
+ ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr))
}
/**
@@ -1776,7 +1776,7 @@ object functions {
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = {
- ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr))
+ ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr))
}
/**
@@ -1789,7 +1789,7 @@ object functions {
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = {
- ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr))
+ ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr))
}
// scalastyle:on
@@ -1802,8 +1802,8 @@ object functions {
*
* val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
* val sqlContext = df.sqlContext
- * sqlContext.udf.register("simpleUdf", (v: Int) => v * v)
- * df.select($"id", callUDF("simpleUdf", $"value"))
+ * sqlContext.udf.register("simpleUDF", (v: Int) => v * v)
+ * df.select($"id", callUDF("simpleUDF", $"value"))
* }}}
*
* @group udf_funcs
@@ -1821,8 +1821,8 @@ object functions {
*
* val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
* val sqlContext = df.sqlContext
- * sqlContext.udf.register("simpleUdf", (v: Int) => v * v)
- * df.select($"id", callUdf("simpleUdf", $"value"))
+ * sqlContext.udf.register("simpleUDF", (v: Int) => v * v)
+ * df.select($"id", callUdf("simpleUDF", $"value"))
* }}}
*
* @group udf_funcs
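
All of the deprecated callUDF overloads above simply build a ScalaUDF; the scaladoc examples steer users toward udf and sqlContext.udf.register instead. A short usage sketch along the lines of those examples, assuming an existing SparkContext named sc (column and UDF names are taken from the scaladoc):

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.{callUDF, udf}

val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")

// Preferred: build a typed UDF once and apply it like any other Column function.
val square = udf((v: Int) => v * v)
df.select($"id", square($"value")).show()

// Also possible: register by name, then call it via callUDF or from SQL text.
sqlContext.udf.register("simpleUDF", (v: Int) => v * v)
df.select($"id", callUDF("simpleUDF", $"value")).show()
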
http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 22c54e4..82dc0e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -140,9 +140,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index")
    // we expect the id to be materialized once
- val idUdf = udf(() => UUID.randomUUID().toString)
+ val idUDF = udf(() => UUID.randomUUID().toString)
- val dfWithId = df.withColumn("id", idUdf())
+ val dfWithId = df.withColumn("id", idUDF())
// Make a new DataFrame (actually the same reference to the old one)
val cached = dfWithId.cache()
// Trigger the cache
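
The test touched here leans on caching to pin down a non-deterministic UDF: without cache(), each action would re-run the UUID-generating UDF and could produce different ids. A condensed sketch of that pattern, assuming the usual sqlContext and its implicits are in scope:

import java.util.UUID
import org.apache.spark.sql.functions.udf

val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index")

val idUDF = udf(() => UUID.randomUUID().toString)
val dfWithId = df.withColumn("id", idUDF())

val cached = dfWithId.cache()
cached.count()   // trigger the cache; the non-deterministic ids are materialized here

// Both reads now return the same ids, because they are served from the cached
// data rather than by re-running the UUID-generating UDF.
val first  = cached.select("id").collect().map(_.getString(0)).toSet
val second = cached.select("id").collect().map(_.getString(0)).toSet
assert(first == second)
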
http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 8021f91..b91242a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.SQLConf.SQLConfEntry._
import org.apache.spark.sql.catalyst.ParserDialect
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand}
+import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand}
import org.apache.spark.sql.hive.client._
import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand}
import org.apache.spark.sql.sources.DataSourceStrategy
@@ -381,7 +381,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
catalog.ParquetConversions ::
catalog.CreateTables ::
catalog.PreInsertionCasts ::
- ExtractPythonUdfs ::
+ ExtractPythonUDFs ::
ResolveHiveWindowFunction ::
sources.PreInsertCastAndRename ::
Nil
http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 7c46209..2de7a99 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1638,7 +1638,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
sys.error(s"Couldn't find function $functionName"))
val functionClassName = functionInfo.getFunctionClass.getName
- (HiveGenericUdtf(
+ (HiveGenericUDTF(
new HiveFunctionWrapper(functionClassName),
children.map(nodeToExpr)), attributes)
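
This HiveQl branch is the path taken by Hive table-generating functions, which end up wrapped in HiveGenericUDTF (formerly HiveGenericUdtf) with the parsed children. A query using a generic UDTF such as parse_url_tuple would exercise it; a sketch, assuming a HiveContext named hiveContext and a table named urls with a string column url:

// A Hive table-generating function goes through the branch above and is wrapped
// in HiveGenericUDTF (formerly HiveGenericUdtf) with the parsed children.
hiveContext.sql(
  "SELECT host, path FROM urls " +
  "LATERAL VIEW parse_url_tuple(url, 'HOST', 'PATH') u AS host, path").show()
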
http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
new file mode 100644
index 0000000..d7827d5
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -0,0 +1,598 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.JavaConversions._
+import scala.util.Try
+
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector}
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory
+import org.apache.hadoop.hive.ql.exec._
+import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
+import org.apache.hadoop.hive.ql.udf.generic._
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.hive.HiveShim._
+import org.apache.spark.sql.types._
+
+
+private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
+ extends analysis.FunctionRegistry with HiveInspectors {
+
+ def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name)
+
+ override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
+ Try(underlying.lookupFunction(name, children)).getOrElse {
+ // We only look it up to see if it exists, but do not include it in the HiveUDF since it is
+ // not always serializable.
+ val functionInfo: FunctionInfo =
+ Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse(
+ throw new AnalysisException(s"undefined function $name"))
+
+ val functionClassName = functionInfo.getFunctionClass.getName
+
+ if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
+ HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children)
+ } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
+ HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children)
+ } else if (
+ classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
+ HiveGenericUDAF(new HiveFunctionWrapper(functionClassName), children)
+ } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
+ HiveUDAF(new HiveFunctionWrapper(functionClassName), children)
+ } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
+ HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children)
+ } else {
+ sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
+ }
+ }
+ }
+
+ override def registerFunction(name: String, builder: FunctionBuilder): Unit =
+ throw new UnsupportedOperationException
+}
+
+private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+ extends Expression with HiveInspectors with Logging {
+
+ type UDFType = UDF
+
+ override def deterministic: Boolean = isUDFDeterministic
+
+ override def nullable: Boolean = true
+
+ @transient
+ lazy val function = funcWrapper.createFunction[UDFType]()
+
+ @transient
+ protected lazy val method =
+ function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo))
+
+ @transient
+ protected lazy val arguments = children.map(toInspector).toArray
+
+ @transient
+ protected lazy val isUDFDeterministic = {
+ val udfType = function.getClass().getAnnotation(classOf[HiveUDFType])
+ udfType != null && udfType.deterministic()
+ }
+
+ override def foldable: Boolean = isUDFDeterministic && children.forall(_.foldable)
+
+ // Create parameter converters
+ @transient
+ protected lazy val conversionHelper = new ConversionHelper(method, arguments)
+
+ @transient
+ lazy val dataType = javaClassToDataType(method.getReturnType)
+
+ @transient
+ lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector(
+ method.getGenericReturnType(), ObjectInspectorOptions.JAVA)
+
+ @transient
+ protected lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
+
+ override def isThreadSafe: Boolean = false
+
+ // TODO: Finish input output types.
+ override def eval(input: InternalRow): Any = {
+ unwrap(
+ FunctionRegistry.invoke(method, function, conversionHelper
+ .convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*),
+ returnInspector)
+ }
+
+ override def toString: String = {
+ s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
+ }
+}
+
+// Adapter from Catalyst ExpressionResult to Hive DeferredObject
+private[hive] class DeferredObjectAdapter(oi: ObjectInspector)
+ extends DeferredObject with HiveInspectors {
+ private var func: () => Any = _
+ def set(func: () => Any): Unit = {
+ this.func = func
+ }
+ override def prepare(i: Int): Unit = {}
+ override def get(): AnyRef = wrap(func(), oi)
+}
+
+private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+ extends Expression with HiveInspectors with Logging {
+ type UDFType = GenericUDF
+
+ override def deterministic: Boolean = isUDFDeterministic
+
+ override def nullable: Boolean = true
+
+ @transient
+ lazy val function = funcWrapper.createFunction[UDFType]()
+
+ @transient
+ protected lazy val argumentInspectors = children.map(toInspector)
+
+ @transient
+ protected lazy val returnInspector = {
+ function.initializeAndFoldConstants(argumentInspectors.toArray)
+ }
+
+ @transient
+ protected lazy val isUDFDeterministic = {
+ val udfType = function.getClass().getAnnotation(classOf[HiveUDFType])
+ (udfType != null && udfType.deterministic())
+ }
+
+ override def foldable: Boolean =
+ isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector]
+
+ @transient
+ protected lazy val deferedObjects =
+ argumentInspectors.map(new DeferredObjectAdapter(_)).toArray[DeferredObject]
+
+ lazy val dataType: DataType = inspectorToDataType(returnInspector)
+
+ override def isThreadSafe: Boolean = false
+
+ override def eval(input: InternalRow): Any = {
+ returnInspector // Make sure initialized.
+
+ var i = 0
+ while (i < children.length) {
+ val idx = i
+ deferedObjects(i).asInstanceOf[DeferredObjectAdapter].set(
+ () => {
+ children(idx).eval(input)
+ })
+ i += 1
+ }
+ unwrap(function.evaluate(deferedObjects), returnInspector)
+ }
+
+ override def toString: String = {
+ s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
+ }
+}
+
+/**
+ * Resolves [[UnresolvedWindowFunction]] to [[HiveWindowFunction]].
+ */
+private[spark] object ResolveHiveWindowFunction extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ case p: LogicalPlan if !p.childrenResolved => p
+
+ // We are resolving WindowExpressions at here. When we get here, we have already
+ // replaced those WindowSpecReferences.
+ case p: LogicalPlan =>
+ p transformExpressions {
+ case WindowExpression(
+ UnresolvedWindowFunction(name, children),
+ windowSpec: WindowSpecDefinition) =>
+ // First, let's find the window function info.
+ val windowFunctionInfo: WindowFunctionInfo =
+ Option(FunctionRegistry.getWindowFunctionInfo(name.toLowerCase)).getOrElse(
+ throw new AnalysisException(s"Couldn't find window function $name"))
+
+ // Get the class of this function.
+ // In Hive 0.12, there is no windowFunctionInfo.getFunctionClass. So, we use
+ // windowFunctionInfo.getfInfo().getFunctionClass for both Hive 0.13 and Hive 0.13.1.
+ val functionClass = windowFunctionInfo.getfInfo().getFunctionClass
+ val newChildren =
+ // Rank(), DENSE_RANK(), CUME_DIST(), and PERCENT_RANK() do not take explicit
+ // input parameters and requires implicit parameters, which
+ // are expressions in Order By clause.
+ if (classOf[GenericUDAFRank].isAssignableFrom(functionClass)) {
+ if (children.nonEmpty) {
+ throw new AnalysisException(s"$name does not take input parameters.")
+ }
+ windowSpec.orderSpec.map(_.child)
+ } else {
+ children
+ }
+
+ // If the class is UDAF, we need to use UDAFBridge.
+ val isUDAFBridgeRequired =
+ if (classOf[UDAF].isAssignableFrom(functionClass)) {
+ true
+ } else {
+ false
+ }
+
+ // Create the HiveWindowFunction. For the meaning of isPivotResult, see the doc of
+ // HiveWindowFunction.
+ val windowFunction =
+ HiveWindowFunction(
+ new HiveFunctionWrapper(functionClass.getName),
+ windowFunctionInfo.isPivotResult,
+ isUDAFBridgeRequired,
+ newChildren)
+
+ // Second, check if the specified window function can accept window definition.
+ windowSpec.frameSpecification match {
+ case frame: SpecifiedWindowFrame if !windowFunctionInfo.isSupportsWindow =>
+ // This Hive window function does not support a user-specified window frame.
+ throw new AnalysisException(
+ s"Window function $name does not take a frame specification.")
+ case frame: SpecifiedWindowFrame if windowFunctionInfo.isSupportsWindow &&
+ windowFunctionInfo.isPivotResult =>
+ // These two should not be true at the same time when a window frame is defined.
+ // If so, throw an exception.
+ throw new AnalysisException(s"Could not handle Hive window function $name because " +
+ s"it supports both a user specified window frame and pivot result.")
+ case _ => // OK
+ }
+ // Resolve those UnspecifiedWindowFrame because the physical Window operator still needs
+ // a window frame specification to work.
+ val newWindowSpec = windowSpec.frameSpecification match {
+ case UnspecifiedFrame =>
+ val newWindowFrame =
+ SpecifiedWindowFrame.defaultWindowFrame(
+ windowSpec.orderSpec.nonEmpty,
+ windowFunctionInfo.isSupportsWindow)
+ WindowSpecDefinition(windowSpec.partitionSpec, windowSpec.orderSpec, newWindowFrame)
+ case _ => windowSpec
+ }
+
+ // Finally, we create a WindowExpression with the resolved window function and
+ // specified window spec.
+ WindowExpression(windowFunction, newWindowSpec)
+ }
+ }
+}
+
+/**
+ * A [[WindowFunction]] implementation wrapping Hive's window function.
+ * @param funcWrapper The wrapper for the Hive Window Function.
+ * @param pivotResult If it is true, the Hive function will return a list of values representing
+ * the values of the added columns. Otherwise, a single value is returned for
+ * the current row.
+ * @param isUDAFBridgeRequired If it is true, the function returned by funcWrapper's
+ * createFunction is a UDAF, and we need to use GenericUDAFBridge to wrap
+ * it as a GenericUDAFResolver2.
+ * @param children Input parameters.
+ */
+private[hive] case class HiveWindowFunction(
+ funcWrapper: HiveFunctionWrapper,
+ pivotResult: Boolean,
+ isUDAFBridgeRequired: Boolean,
+ children: Seq[Expression]) extends WindowFunction
+ with HiveInspectors {
+
+ // Hive window functions are based on GenericUDAFResolver2.
+ type UDFType = GenericUDAFResolver2
+
+ @transient
+ protected lazy val resolver: GenericUDAFResolver2 =
+ if (isUDAFBridgeRequired) {
+ new GenericUDAFBridge(funcWrapper.createFunction[UDAF]())
+ } else {
+ funcWrapper.createFunction[GenericUDAFResolver2]()
+ }
+
+ @transient
+ protected lazy val inputInspectors = children.map(toInspector).toArray
+
+ // The GenericUDAFEvaluator used to evaluate the window function.
+ @transient
+ protected lazy val evaluator: GenericUDAFEvaluator = {
+ val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false)
+ resolver.getEvaluator(parameterInfo)
+ }
+
+ // The object inspector of values returned from the Hive window function.
+ @transient
+ protected lazy val returnInspector = {
+ evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors)
+ }
+
+ def dataType: DataType =
+ if (!pivotResult) {
+ inspectorToDataType(returnInspector)
+ } else {
+ // If pivotResult is true, we should take the element type out as the data type of this
+ // function.
+ inspectorToDataType(returnInspector) match {
+ case ArrayType(dt, _) => dt
+ case _ =>
+ sys.error(
+ s"error resolve the data type of window function ${funcWrapper.functionClassName}")
+ }
+ }
+
+ def nullable: Boolean = true
+
+ override def eval(input: InternalRow): Any =
+ throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+
+ @transient
+ lazy val inputProjection = new InterpretedProjection(children)
+
+ @transient
+ private var hiveEvaluatorBuffer: AggregationBuffer = _
+ // Output buffer.
+ private var outputBuffer: Any = _
+
+ override def init(): Unit = {
+ evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors)
+ }
+
+ // Reset the hiveEvaluatorBuffer and outputPosition
+ override def reset(): Unit = {
+ // We create a new aggregation buffer to workaround the bug in GenericUDAFRowNumber.
+ // Basically, GenericUDAFRowNumberEvaluator.reset calls RowNumberBuffer.init.
+ // However, RowNumberBuffer.init does not really reset this buffer.
+ hiveEvaluatorBuffer = evaluator.getNewAggregationBuffer
+ evaluator.reset(hiveEvaluatorBuffer)
+ }
+
+ override def prepareInputParameters(input: InternalRow): AnyRef = {
+ wrap(inputProjection(input), inputInspectors, new Array[AnyRef](children.length))
+ }
+ // Add input parameters for a single row.
+ override def update(input: AnyRef): Unit = {
+ evaluator.iterate(hiveEvaluatorBuffer, input.asInstanceOf[Array[AnyRef]])
+ }
+
+ override def batchUpdate(inputs: Array[AnyRef]): Unit = {
+ var i = 0
+ while (i < inputs.length) {
+ evaluator.iterate(hiveEvaluatorBuffer, inputs(i).asInstanceOf[Array[AnyRef]])
+ i += 1
+ }
+ }
+
+ override def evaluate(): Unit = {
+ outputBuffer = unwrap(evaluator.evaluate(hiveEvaluatorBuffer), returnInspector)
+ }
+
+ override def get(index: Int): Any = {
+ if (!pivotResult) {
+ // if pivotResult is false, we will get a single value for all rows in the frame.
+ outputBuffer
+ } else {
+ // if pivotResult is true, we will get a Seq having the same size with the size
+ // of the window frame. At here, we will return the result at the position of
+ // index in the output buffer.
+ outputBuffer.asInstanceOf[Seq[Any]].get(index)
+ }
+ }
+
+ override def toString: String = {
+ s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
+ }
+
+ override def newInstance: WindowFunction =
+ new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children)
+}
+
+private[hive] case class HiveGenericUDAF(
+ funcWrapper: HiveFunctionWrapper,
+ children: Seq[Expression]) extends AggregateExpression
+ with HiveInspectors {
+
+ type UDFType = AbstractGenericUDAFResolver
+
+ @transient
+ protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction()
+
+ @transient
+ protected lazy val objectInspector = {
+ val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
+ resolver.getEvaluator(parameterInfo)
+ .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
+ }
+
+ @transient
+ protected lazy val inspectors = children.map(toInspector)
+
+ def dataType: DataType = inspectorToDataType(objectInspector)
+
+ def nullable: Boolean = true
+
+ override def toString: String = {
+ s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
+ }
+
+ def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this)
+}
+
+/** It is used as a wrapper for the hive functions which uses UDAF interface */
+private[hive] case class HiveUDAF(
+ funcWrapper: HiveFunctionWrapper,
+ children: Seq[Expression]) extends AggregateExpression
+ with HiveInspectors {
+
+ type UDFType = UDAF
+
+ @transient
+ protected lazy val resolver: AbstractGenericUDAFResolver =
+ new GenericUDAFBridge(funcWrapper.createFunction())
+
+ @transient
+ protected lazy val objectInspector = {
+ val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
+ resolver.getEvaluator(parameterInfo)
+ .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
+ }
+
+ @transient
+ protected lazy val inspectors = children.map(toInspector)
+
+ def dataType: DataType = inspectorToDataType(objectInspector)
+
+ def nullable: Boolean = true
+
+ override def toString: String = {
+ s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
+ }
+
+ def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this, true)
+}
+
+/**
+ * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a
+ * [[catalyst.expressions.Generator Generator]]. Note that the semantics of Generators do not allow
+ * Generators to maintain state in between input rows. Thus UDTFs that rely on partitioning
+ * dependent operations like calls to `close()` before producing output will not operate the same as
+ * in Hive. However, in practice this should not affect compatibility for most sane UDTFs
+ * (e.g. explode or GenericUDTFParseUrlTuple).
+ *
+ * Operators that require maintaining state in between input rows should instead be implemented as
+ * user defined aggregations, which have clean semantics even in a partitioned execution.
+ */
+private[hive] case class HiveGenericUDTF(
+ funcWrapper: HiveFunctionWrapper,
+ children: Seq[Expression])
+ extends Generator with HiveInspectors {
+
+ @transient
+ protected lazy val function: GenericUDTF = {
+ val fun: GenericUDTF = funcWrapper.createFunction()
+ fun.setCollector(collector)
+ fun
+ }
+
+ @transient
+ protected lazy val inputInspectors = children.map(toInspector)
+
+ @transient
+ protected lazy val outputInspector = function.initialize(inputInspectors.toArray)
+
+ @transient
+ protected lazy val udtInput = new Array[AnyRef](children.length)
+
+ @transient
+ protected lazy val collector = new UDTFCollector
+
+ lazy val elementTypes = outputInspector.getAllStructFieldRefs.map {
+ field => (inspectorToDataType(field.getFieldObjectInspector), true)
+ }
+
+ override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
+ outputInspector // Make sure initialized.
+
+ val inputProjection = new InterpretedProjection(children)
+
+ function.process(wrap(inputProjection(input), inputInspectors, udtInput))
+ collector.collectRows()
+ }
+
+ protected class UDTFCollector extends Collector {
+ var collected = new ArrayBuffer[InternalRow]
+
+ override def collect(input: java.lang.Object) {
+ // We need to clone the input here because implementations of
+ // GenericUDTF reuse the same object. Luckily they are always an array, so
+ // it is easy to clone.
+ collected += unwrap(input, outputInspector).asInstanceOf[InternalRow]
+ }
+
+ def collectRows(): Seq[InternalRow] = {
+ val toCollect = collected
+ collected = new ArrayBuffer[InternalRow]
+ toCollect
+ }
+ }
+
+ override def terminate(): TraversableOnce[InternalRow] = {
+ outputInspector // Make sure initialized.
+ function.close()
+ collector.collectRows()
+ }
+
+ override def toString: String = {
+ s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
+ }
+}
+
+private[hive] case class HiveUDAFFunction(
+ funcWrapper: HiveFunctionWrapper,
+ exprs: Seq[Expression],
+ base: AggregateExpression,
+ isUDAFBridgeRequired: Boolean = false)
+ extends AggregateFunction
+ with HiveInspectors {
+
+ def this() = this(null, null, null)
+
+ private val resolver =
+ if (isUDAFBridgeRequired) {
+ new GenericUDAFBridge(funcWrapper.createFunction[UDAF]())
+ } else {
+ funcWrapper.createFunction[AbstractGenericUDAFResolver]()
+ }
+
+ private val inspectors = exprs.map(toInspector).toArray
+
+ private val function = {
+ val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false)
+ resolver.getEvaluator(parameterInfo)
+ }
+
+ private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
+
+ private val buffer =
+ function.getNewAggregationBuffer
+
+ override def eval(input: InternalRow): Any = unwrap(function.evaluate(buffer), returnInspector)
+
+ @transient
+ val inputProjection = new InterpretedProjection(exprs)
+
+ @transient
+ protected lazy val cached = new Array[AnyRef](exprs.length)
+
+ def update(input: InternalRow): Unit = {
+ val inputs = inputProjection(input)
+ function.iterate(buffer, wrap(inputs, inspectors, cached))
+ }
+}
+
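
HiveFunctionRegistry.lookupFunction above chooses a wrapper expression purely by checking which Hive interface the function class implements, via isAssignableFrom. A stripped-down, Spark-free sketch of that dispatch idea, with toy marker traits standing in for Hive's UDF, GenericUDF and GenericUDTF classes (all names illustrative):

// Toy marker traits standing in for Hive's UDF, GenericUDF and GenericUDTF classes.
trait ToyUDF
trait ToyGenericUDF
trait ToyGenericUDTF

sealed trait ToyWrapper
case class SimpleUDFWrapper(className: String)  extends ToyWrapper
case class GenericUDFWrapper(className: String) extends ToyWrapper
case class UDTFWrapper(className: String)       extends ToyWrapper

object ToyRegistry {
  // Same shape as HiveFunctionRegistry.lookupFunction: inspect the function's
  // class and pick the wrapper that knows how to evaluate that interface.
  def wrap(functionClass: Class[_]): ToyWrapper =
    if (classOf[ToyUDF].isAssignableFrom(functionClass)) {
      SimpleUDFWrapper(functionClass.getName)
    } else if (classOf[ToyGenericUDF].isAssignableFrom(functionClass)) {
      GenericUDFWrapper(functionClass.getName)
    } else if (classOf[ToyGenericUDTF].isAssignableFrom(functionClass)) {
      UDTFWrapper(functionClass.getName)
    } else {
      sys.error(s"No handler for udf ${functionClass.getName}")
    }
}

object ToyRegistryDemo extends App {
  class Upper extends ToyUDF
  class ParseUrl extends ToyGenericUDTF

  println(ToyRegistry.wrap(classOf[Upper]))     // SimpleUDFWrapper(...Upper)
  println(ToyRegistry.wrap(classOf[ParseUrl]))  // UDTFWrapper(...ParseUrl)
}
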
http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
deleted file mode 100644
index 4986b1e..0000000
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ /dev/null
@@ -1,598 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.JavaConversions._
-import scala.util.Try
-
-import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector}
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory
-import org.apache.hadoop.hive.ql.exec._
-import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
-import org.apache.hadoop.hive.ql.udf.generic._
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
-
-import org.apache.spark.Logging
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.hive.HiveShim._
-import org.apache.spark.sql.types._
-
-
-private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
- extends analysis.FunctionRegistry with HiveInspectors {
-
- def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name)
-
- override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
- Try(underlying.lookupFunction(name, children)).getOrElse {
- // We only look it up to see if it exists, but do not include it in the HiveUDF since it is
- // not always serializable.
- val functionInfo: FunctionInfo =
- Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse(
- throw new AnalysisException(s"undefined function $name"))
-
- val functionClassName = functionInfo.getFunctionClass.getName
-
- if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children)
- } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children)
- } else if (
- classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children)
- } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveUdaf(new HiveFunctionWrapper(functionClassName), children)
- } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children)
- } else {
- sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
- }
- }
- }
-
- override def registerFunction(name: String, builder: FunctionBuilder): Unit =
- throw new UnsupportedOperationException
-}
-
-private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
- extends Expression with HiveInspectors with Logging {
-
- type UDFType = UDF
-
- override def deterministic: Boolean = isUDFDeterministic
-
- override def nullable: Boolean = true
-
- @transient
- lazy val function = funcWrapper.createFunction[UDFType]()
-
- @transient
- protected lazy val method =
- function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo))
-
- @transient
- protected lazy val arguments = children.map(toInspector).toArray
-
- @transient
- protected lazy val isUDFDeterministic = {
- val udfType = function.getClass().getAnnotation(classOf[HiveUDFType])
- udfType != null && udfType.deterministic()
- }
-
- override def foldable: Boolean = isUDFDeterministic && children.forall(_.foldable)
-
- // Create parameter converters
- @transient
- protected lazy val conversionHelper = new ConversionHelper(method, arguments)
-
- @transient
- lazy val dataType = javaClassToDataType(method.getReturnType)
-
- @transient
- lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector(
- method.getGenericReturnType(), ObjectInspectorOptions.JAVA)
-
- @transient
- protected lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
-
- override def isThreadSafe: Boolean = false
-
- // TODO: Finish input output types.
- override def eval(input: InternalRow): Any = {
- unwrap(
- FunctionRegistry.invoke(method, function, conversionHelper
- .convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*),
- returnInspector)
- }
-
- override def toString: String = {
- s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
- }
-}
-
-// Adapts the lazily evaluated result of a Catalyst expression to Hive's DeferredObject interface
-private[hive] class DeferredObjectAdapter(oi: ObjectInspector)
- extends DeferredObject with HiveInspectors {
- private var func: () => Any = _
- def set(func: () => Any): Unit = {
- this.func = func
- }
- override def prepare(i: Int): Unit = {}
- override def get(): AnyRef = wrap(func(), oi)
-}
-
-private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
- extends Expression with HiveInspectors with Logging {
- type UDFType = GenericUDF
-
- override def deterministic: Boolean = isUDFDeterministic
-
- override def nullable: Boolean = true
-
- @transient
- lazy val function = funcWrapper.createFunction[UDFType]()
-
- @transient
- protected lazy val argumentInspectors = children.map(toInspector)
-
- @transient
- protected lazy val returnInspector = {
- function.initializeAndFoldConstants(argumentInspectors.toArray)
- }
-
- @transient
- protected lazy val isUDFDeterministic = {
- val udfType = function.getClass().getAnnotation(classOf[HiveUDFType])
- (udfType != null && udfType.deterministic())
- }
-
- override def foldable: Boolean =
- isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector]
-
- @transient
- protected lazy val deferedObjects =
- argumentInspectors.map(new DeferredObjectAdapter(_)).toArray[DeferredObject]
-
- lazy val dataType: DataType = inspectorToDataType(returnInspector)
-
- override def isThreadSafe: Boolean = false
-
- override def eval(input: InternalRow): Any = {
- returnInspector // Make sure initialized.
-
- var i = 0
- while (i < children.length) {
- val idx = i
- deferedObjects(i).asInstanceOf[DeferredObjectAdapter].set(
- () => {
- children(idx).eval(input)
- })
- i += 1
- }
- unwrap(function.evaluate(deferedObjects), returnInspector)
- }
-
- override def toString: String = {
- s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
- }
-}
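
The DeferredObjectAdapter above exists because a GenericUDF receives its arguments as DeferredObjects and only pulls values when get() is called. A minimal GenericUDF sketch showing that contract (illustrative only, not part of this patch):

    import org.apache.hadoop.hive.ql.exec.UDFArgumentException
    import org.apache.hadoop.hive.ql.udf.generic.GenericUDF
    import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject
    import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
    import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

    class ToUpperGenericUDF extends GenericUDF {
      override def initialize(args: Array[ObjectInspector]): ObjectInspector = {
        if (args.length != 1) throw new UDFArgumentException("to_upper expects one argument")
        PrimitiveObjectInspectorFactory.javaStringObjectInspector
      }

      // Arguments arrive as DeferredObjects; calling get() is what triggers evaluation of the
      // wrapped value (in Spark's case, the child Catalyst expression behind DeferredObjectAdapter).
      override def evaluate(args: Array[DeferredObject]): AnyRef = {
        val v = args(0).get()
        if (v == null) null else v.toString.toUpperCase
      }

      override def getDisplayString(children: Array[String]): String =
        s"to_upper(${children.mkString(", ")})"
    }
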
-
-/**
- * Resolves [[UnresolvedWindowFunction]] to [[HiveWindowFunction]].
- */
-private[spark] object ResolveHiveWindowFunction extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- case p: LogicalPlan if !p.childrenResolved => p
-
-    // We are resolving WindowExpressions here. By this point, those WindowSpecReferences
-    // have already been replaced.
- case p: LogicalPlan =>
- p transformExpressions {
- case WindowExpression(
- UnresolvedWindowFunction(name, children),
- windowSpec: WindowSpecDefinition) =>
- // First, let's find the window function info.
- val windowFunctionInfo: WindowFunctionInfo =
- Option(FunctionRegistry.getWindowFunctionInfo(name.toLowerCase)).getOrElse(
- throw new AnalysisException(s"Couldn't find window function $name"))
-
- // Get the class of this function.
- // In Hive 0.12, there is no windowFunctionInfo.getFunctionClass. So, we use
- // windowFunctionInfo.getfInfo().getFunctionClass for both Hive 0.13 and Hive 0.13.1.
- val functionClass = windowFunctionInfo.getfInfo().getFunctionClass
- val newChildren =
-            // Rank(), DENSE_RANK(), CUME_DIST(), and PERCENT_RANK() do not take explicit
-            // input parameters; instead they require implicit parameters, which are the
-            // expressions in the ORDER BY clause.
- if (classOf[GenericUDAFRank].isAssignableFrom(functionClass)) {
- if (children.nonEmpty) {
- throw new AnalysisException(s"$name does not take input parameters.")
- }
- windowSpec.orderSpec.map(_.child)
- } else {
- children
- }
-
-          // If the class is a UDAF, it needs to be wrapped with GenericUDAFBridge.
-          val isUDAFBridgeRequired = classOf[UDAF].isAssignableFrom(functionClass)
-
- // Create the HiveWindowFunction. For the meaning of isPivotResult, see the doc of
- // HiveWindowFunction.
- val windowFunction =
- HiveWindowFunction(
- new HiveFunctionWrapper(functionClass.getName),
- windowFunctionInfo.isPivotResult,
- isUDAFBridgeRequired,
- newChildren)
-
- // Second, check if the specified window function can accept window definition.
- windowSpec.frameSpecification match {
- case frame: SpecifiedWindowFrame if !windowFunctionInfo.isSupportsWindow =>
-              // This Hive window function does not support a user-specified window frame.
- throw new AnalysisException(
- s"Window function $name does not take a frame specification.")
- case frame: SpecifiedWindowFrame if windowFunctionInfo.isSupportsWindow &&
- windowFunctionInfo.isPivotResult =>
- // These two should not be true at the same time when a window frame is defined.
- // If so, throw an exception.
- throw new AnalysisException(s"Could not handle Hive window function $name because " +
- s"it supports both a user specified window frame and pivot result.")
- case _ => // OK
- }
- // Resolve those UnspecifiedWindowFrame because the physical Window operator still needs
- // a window frame specification to work.
- val newWindowSpec = windowSpec.frameSpecification match {
- case UnspecifiedFrame =>
- val newWindowFrame =
- SpecifiedWindowFrame.defaultWindowFrame(
- windowSpec.orderSpec.nonEmpty,
- windowFunctionInfo.isSupportsWindow)
- WindowSpecDefinition(windowSpec.partitionSpec, windowSpec.orderSpec, newWindowFrame)
- case _ => windowSpec
- }
-
- // Finally, we create a WindowExpression with the resolved window function and
- // specified window spec.
- WindowExpression(windowFunction, newWindowSpec)
- }
- }
-}
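
As a usage-level sketch of what this rule resolves (hiveContext and the employees table are assumptions, not part of this patch): RANK() takes no explicit arguments, so its children become the ORDER BY expressions, and the unspecified frame is replaced with the default window frame.

    // Hypothetical query; RANK() is backed by Hive's GenericUDAFRank.
    hiveContext.sql(
      """SELECT name, dept,
        |       RANK() OVER (PARTITION BY dept ORDER BY salary DESC) AS salary_rank
        |FROM employees""".stripMargin)
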
-
-/**
- * A [[WindowFunction]] implementation wrapping Hive's window function.
- * @param funcWrapper The wrapper for the Hive window function.
- * @param pivotResult If true, the Hive function returns a list of values representing the values
- *                    of the added columns. Otherwise, a single value is returned for the
- *                    current row.
- * @param isUDAFBridgeRequired If true, the function returned by funcWrapper's createFunction is
- *                             a UDAF, and GenericUDAFBridge is needed to wrap it as a
- *                             GenericUDAFResolver2.
- * @param children Input parameters.
- */
-private[hive] case class HiveWindowFunction(
- funcWrapper: HiveFunctionWrapper,
- pivotResult: Boolean,
- isUDAFBridgeRequired: Boolean,
- children: Seq[Expression]) extends WindowFunction
- with HiveInspectors {
-
- // Hive window functions are based on GenericUDAFResolver2.
- type UDFType = GenericUDAFResolver2
-
- @transient
- protected lazy val resolver: GenericUDAFResolver2 =
- if (isUDAFBridgeRequired) {
- new GenericUDAFBridge(funcWrapper.createFunction[UDAF]())
- } else {
- funcWrapper.createFunction[GenericUDAFResolver2]()
- }
-
- @transient
- protected lazy val inputInspectors = children.map(toInspector).toArray
-
- // The GenericUDAFEvaluator used to evaluate the window function.
- @transient
- protected lazy val evaluator: GenericUDAFEvaluator = {
- val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false)
- resolver.getEvaluator(parameterInfo)
- }
-
- // The object inspector of values returned from the Hive window function.
- @transient
- protected lazy val returnInspector = {
- evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors)
- }
-
- def dataType: DataType =
- if (!pivotResult) {
- inspectorToDataType(returnInspector)
- } else {
- // If pivotResult is true, we should take the element type out as the data type of this
- // function.
- inspectorToDataType(returnInspector) match {
- case ArrayType(dt, _) => dt
- case _ =>
-            sys.error(
-              s"Cannot resolve the data type of window function ${funcWrapper.functionClassName}")
- }
- }
-
- def nullable: Boolean = true
-
- override def eval(input: InternalRow): Any =
- throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
-
- @transient
- lazy val inputProjection = new InterpretedProjection(children)
-
- @transient
- private var hiveEvaluatorBuffer: AggregationBuffer = _
- // Output buffer.
- private var outputBuffer: Any = _
-
- override def init(): Unit = {
- evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors)
- }
-
-  // Reset the hiveEvaluatorBuffer.
- override def reset(): Unit = {
-    // We create a new aggregation buffer to work around a bug in GenericUDAFRowNumber:
-    // GenericUDAFRowNumberEvaluator.reset calls RowNumberBuffer.init, but
-    // RowNumberBuffer.init does not really reset the buffer.
- hiveEvaluatorBuffer = evaluator.getNewAggregationBuffer
- evaluator.reset(hiveEvaluatorBuffer)
- }
-
- override def prepareInputParameters(input: InternalRow): AnyRef = {
- wrap(inputProjection(input), inputInspectors, new Array[AnyRef](children.length))
- }
- // Add input parameters for a single row.
- override def update(input: AnyRef): Unit = {
- evaluator.iterate(hiveEvaluatorBuffer, input.asInstanceOf[Array[AnyRef]])
- }
-
- override def batchUpdate(inputs: Array[AnyRef]): Unit = {
- var i = 0
- while (i < inputs.length) {
- evaluator.iterate(hiveEvaluatorBuffer, inputs(i).asInstanceOf[Array[AnyRef]])
- i += 1
- }
- }
-
- override def evaluate(): Unit = {
- outputBuffer = unwrap(evaluator.evaluate(hiveEvaluatorBuffer), returnInspector)
- }
-
- override def get(index: Int): Any = {
- if (!pivotResult) {
- // if pivotResult is false, we will get a single value for all rows in the frame.
- outputBuffer
- } else {
-      // if pivotResult is true, we get a Seq whose size equals the size of the window
-      // frame. Here, we return the result at position `index` in the output buffer.
- outputBuffer.asInstanceOf[Seq[Any]].get(index)
- }
- }
-
- override def toString: String = {
- s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
- }
-
- override def newInstance: WindowFunction =
- new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children)
-}
-
-private[hive] case class HiveGenericUdaf(
- funcWrapper: HiveFunctionWrapper,
- children: Seq[Expression]) extends AggregateExpression
- with HiveInspectors {
-
- type UDFType = AbstractGenericUDAFResolver
-
- @transient
- protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction()
-
- @transient
- protected lazy val objectInspector = {
- val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
- resolver.getEvaluator(parameterInfo)
- .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
- }
-
- @transient
- protected lazy val inspectors = children.map(toInspector)
-
- def dataType: DataType = inspectorToDataType(objectInspector)
-
- def nullable: Boolean = true
-
- override def toString: String = {
- s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
- }
-
- def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this)
-}
-
-/** A wrapper for Hive functions that use the UDAF interface. */
-private[hive] case class HiveUdaf(
- funcWrapper: HiveFunctionWrapper,
- children: Seq[Expression]) extends AggregateExpression
- with HiveInspectors {
-
- type UDFType = UDAF
-
- @transient
- protected lazy val resolver: AbstractGenericUDAFResolver =
- new GenericUDAFBridge(funcWrapper.createFunction())
-
- @transient
- protected lazy val objectInspector = {
- val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
- resolver.getEvaluator(parameterInfo)
- .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
- }
-
- @transient
- protected lazy val inspectors = children.map(toInspector)
-
- def dataType: DataType = inspectorToDataType(objectInspector)
-
- def nullable: Boolean = true
-
- override def toString: String = {
- s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
- }
-
- def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this, true)
-}
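
For illustration (hiveContext and the employees table are assumptions): Hive's built-in percentile function is, in this Hive version, implemented against the old UDAF interface, so a call like the one below should be planned through HiveUdaf and GenericUDAFBridge rather than HiveGenericUdaf. Note that Hive's percentile only accepts integral value columns.

    hiveContext.sql(
      "SELECT dept, percentile(salary, 0.5) AS median_salary FROM employees GROUP BY dept")
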
-
-/**
- * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a
- * [[catalyst.expressions.Generator Generator]]. Note that the semantics of Generators do not
- * allow a Generator to maintain state between input rows. Thus UDTFs that rely on
- * partitioning-dependent operations, such as calls to `close()` before producing output, will not
- * behave the same as in Hive. However, in practice this should not affect compatibility for most
- * sane UDTFs (e.g. explode or GenericUDTFParseUrlTuple).
- *
- * Operators that require maintaining state in between input rows should instead be implemented as
- * user defined aggregations, which have clean semantics even in a partitioned execution.
- */
-private[hive] case class HiveGenericUdtf(
- funcWrapper: HiveFunctionWrapper,
- children: Seq[Expression])
- extends Generator with HiveInspectors {
-
- @transient
- protected lazy val function: GenericUDTF = {
- val fun: GenericUDTF = funcWrapper.createFunction()
- fun.setCollector(collector)
- fun
- }
-
- @transient
- protected lazy val inputInspectors = children.map(toInspector)
-
- @transient
- protected lazy val outputInspector = function.initialize(inputInspectors.toArray)
-
- @transient
- protected lazy val udtInput = new Array[AnyRef](children.length)
-
- @transient
- protected lazy val collector = new UDTFCollector
-
- lazy val elementTypes = outputInspector.getAllStructFieldRefs.map {
- field => (inspectorToDataType(field.getFieldObjectInspector), true)
- }
-
- override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
- outputInspector // Make sure initialized.
-
- val inputProjection = new InterpretedProjection(children)
-
- function.process(wrap(inputProjection(input), inputInspectors, udtInput))
- collector.collectRows()
- }
-
- protected class UDTFCollector extends Collector {
- var collected = new ArrayBuffer[InternalRow]
-
- override def collect(input: java.lang.Object) {
- // We need to clone the input here because implementations of
- // GenericUDTF reuse the same object. Luckily they are always an array, so
- // it is easy to clone.
- collected += unwrap(input, outputInspector).asInstanceOf[InternalRow]
- }
-
- def collectRows(): Seq[InternalRow] = {
- val toCollect = collected
- collected = new ArrayBuffer[InternalRow]
- toCollect
- }
- }
-
- override def terminate(): TraversableOnce[InternalRow] = {
- outputInspector // Make sure initialized.
- function.close()
- collector.collectRows()
- }
-
- override def toString: String = {
- s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
- }
-}
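
A minimal, stateless GenericUDTF sketch of the kind that fits the Generator semantics described above, since it keeps no state between process() calls (illustrative only, not part of this patch):

    import java.util.Arrays

    import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
    import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, StructObjectInspector}
    import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

    // Emits one output row per comma-separated token of its single string argument.
    class SplitCsvUDTF extends GenericUDTF {
      override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector =
        ObjectInspectorFactory.getStandardStructObjectInspector(
          Arrays.asList("token"),
          Arrays.asList[ObjectInspector](PrimitiveObjectInspectorFactory.javaStringObjectInspector))

      override def process(args: Array[AnyRef]): Unit =
        args(0).toString.split(",").foreach(token => forward(Array[AnyRef](token)))

      override def close(): Unit = {}  // nothing buffered between rows, so nothing to flush
    }
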
-
-private[hive] case class HiveUdafFunction(
- funcWrapper: HiveFunctionWrapper,
- exprs: Seq[Expression],
- base: AggregateExpression,
- isUDAFBridgeRequired: Boolean = false)
- extends AggregateFunction
- with HiveInspectors {
-
- def this() = this(null, null, null)
-
- private val resolver =
- if (isUDAFBridgeRequired) {
- new GenericUDAFBridge(funcWrapper.createFunction[UDAF]())
- } else {
- funcWrapper.createFunction[AbstractGenericUDAFResolver]()
- }
-
- private val inspectors = exprs.map(toInspector).toArray
-
- private val function = {
- val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false)
- resolver.getEvaluator(parameterInfo)
- }
-
- private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
-
- private val buffer =
- function.getNewAggregationBuffer
-
- override def eval(input: InternalRow): Any = unwrap(function.evaluate(buffer), returnInspector)
-
- @transient
- val inputProjection = new InterpretedProjection(exprs)
-
- @transient
- protected lazy val cached = new Array[AnyRef](exprs.length)
-
- def update(input: InternalRow): Unit = {
- val inputs = inputProjection(input)
- function.iterate(buffer, wrap(inputs, inspectors, cached))
- }
-}
-
http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index ea325cc..7978fda 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -391,7 +391,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
* Records the UDFs present when the server starts, so we can delete ones that are created by
* tests.
*/
- protected val originalUdfs: JavaSet[String] = FunctionRegistry.getFunctionNames
+ protected val originalUDFs: JavaSet[String] = FunctionRegistry.getFunctionNames
/**
* Resets the test instance by deleting any tables that have been created.
@@ -410,7 +410,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
catalog.client.reset()
catalog.unregisterAllTables()
- FunctionRegistry.getFunctionNames.filterNot(originalUdfs.contains(_)).foreach { udfName =>
+ FunctionRegistry.getFunctionNames.filterNot(originalUDFs.contains(_)).foreach { udfName =>
FunctionRegistry.unregisterTemporaryUDF(udfName)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/hive/src/test/resources/data/files/testUDF/part-00000
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/data/files/testUDF/part-00000 b/sql/hive/src/test/resources/data/files/testUDF/part-00000
new file mode 100755
index 0000000..240a5c1
Binary files /dev/null and b/sql/hive/src/test/resources/data/files/testUDF/part-00000 differ
http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/hive/src/test/resources/data/files/testUdf/part-00000
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/data/files/testUdf/part-00000 b/sql/hive/src/test/resources/data/files/testUdf/part-00000
deleted file mode 100755
index 240a5c1..0000000
Binary files a/sql/hive/src/test/resources/data/files/testUdf/part-00000 and /dev/null differ