You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/27 22:27:01 UTC
spark git commit: [SPARK-9349] [SQL] UDAF cleanup
Repository: spark
Updated Branches:
refs/heads/master fa84e4a7b -> 55946e76f
[SPARK-9349] [SQL] UDAF cleanup
https://issues.apache.org/jira/browse/SPARK-9349
With this PR, we only expose `UserDefinedAggregateFunction` (an abstract class) and `MutableAggregationBuffer` (a trait, i.e. an interface for Java users), both in `org.apache.spark.sql.expressions`. Other internal wrappers and helper classes are moved to `org.apache.spark.sql.execution.aggregate` and marked as `private[sql]`.
Author: Yin Huai <yh...@databricks.com>
Closes #7687 from yhuai/UDAF-cleanup and squashes the following commits:
db36542 [Yin Huai] Add comments to UDAF examples.
ae17f66 [Yin Huai] Address comments.
9c9fa5f [Yin Huai] UDAF cleanup.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/55946e76
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/55946e76
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/55946e76
Branch: refs/heads/master
Commit: 55946e76fd136958081f073c0c5e3ff8563d505b
Parents: fa84e4a
Author: Yin Huai <yh...@databricks.com>
Authored: Mon Jul 27 13:26:57 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Mon Jul 27 13:26:57 2015 -0700
----------------------------------------------------------------------
.../org/apache/spark/sql/UDAFRegistration.scala | 3 +-
.../spark/sql/execution/aggregate/udaf.scala | 231 +++++++++++++++
.../spark/sql/expressions/aggregate/udaf.scala | 287 -------------------
.../org/apache/spark/sql/expressions/udaf.scala | 101 +++++++
.../spark/sql/hive/aggregate/MyDoubleAvg.java | 34 ++-
.../spark/sql/hive/aggregate/MyDoubleSum.java | 28 +-
6 files changed, 385 insertions(+), 299 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/55946e76/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
index 5b872f5..0d4e30f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.{Expression}
-import org.apache.spark.sql.expressions.aggregate.{ScalaUDAF, UserDefinedAggregateFunction}
+import org.apache.spark.sql.execution.aggregate.ScalaUDAF
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
http://git-wip-us.apache.org/repos/asf/spark/blob/55946e76/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
new file mode 100644
index 0000000..073c45a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
+import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2
+import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
+import org.apache.spark.sql.types.{Metadata, StructField, StructType, DataType}
+
+/**
+ * A mutable [[Row]] view over a contiguous slice of an underlying aggregation buffer row.
+ *
+ * Field `i` of this view is stored at position `bufferOffset + i` of `underlyingBuffer`.
+ * Reads convert the stored Catalyst value to its Scala representation and writes convert the
+ * given Scala value to its Catalyst representation, using the per-field converters supplied
+ * by the caller. `underlyingBuffer` is a `var` so one instance can be re-pointed at other rows.
+ */
+private[sql] class MutableAggregationBufferImpl (
+    schema: StructType,
+    toCatalystConverters: Array[Any => Any],
+    toScalaConverters: Array[Any => Any],
+    bufferOffset: Int,
+    var underlyingBuffer: MutableRow)
+  extends MutableAggregationBuffer {
+
+  // Absolute position in `underlyingBuffer` of every field of this view.
+  private[this] val offsets: Array[Int] = Array.tabulate(length)(bufferOffset + _)
+
+  // The number of fields equals the number of converters supplied.
+  override def length: Int = toCatalystConverters.length
+
+  /** Throws an IllegalArgumentException when `i` is not a valid field index of this view. */
+  private[this] def assertInBounds(i: Int, verb: String): Unit = {
+    if (i < 0 || i >= length) {
+      throw new IllegalArgumentException(
+        s"Could not $verb ${i}th value in this buffer because it only has $length values.")
+    }
+  }
+
+  /** Reads field `i` from the underlying row and converts it to its Scala representation. */
+  override def get(i: Int): Any = {
+    assertInBounds(i, "access")
+    val toScala = toScalaConverters(i)
+    toScala(underlyingBuffer.get(offsets(i), schema(i).dataType))
+  }
+
+  /** Converts `value` to its Catalyst representation and stores it as field `i`. */
+  def update(i: Int, value: Any): Unit = {
+    assertInBounds(i, "update")
+    val slot = offsets(i)
+    underlyingBuffer.update(slot, toCatalystConverters(i)(value))
+  }
+
+  /**
+   * Returns a new wrapper with the same schema, converters and offset. The returned wrapper
+   * shares `underlyingBuffer` with this one; no values are copied.
+   */
+  override def copy(): MutableAggregationBufferImpl = new MutableAggregationBufferImpl(
+    schema, toCatalystConverters, toScalaConverters, bufferOffset, underlyingBuffer)
+}
+
+/**
+ * A read-only [[Row]] view over a contiguous slice of an underlying internal row.
+ *
+ * Field `i` of this view is read from position `bufferOffset + i` of `underlyingInputBuffer`
+ * and converted to its Scala representation. `underlyingInputBuffer` is a `var` so one instance
+ * can be re-pointed at other rows.
+ */
+private[sql] class InputAggregationBuffer private[sql] (
+    schema: StructType,
+    toCatalystConverters: Array[Any => Any],
+    toScalaConverters: Array[Any => Any],
+    bufferOffset: Int,
+    var underlyingInputBuffer: InternalRow)
+  extends Row {
+
+  // Absolute position in `underlyingInputBuffer` of every field of this view.
+  private[this] val offsets: Array[Int] = Array.tabulate(length)(idx => bufferOffset + idx)
+
+  // The Catalyst converters are only used to size this view (and are passed through by copy).
+  override def length: Int = toCatalystConverters.length
+
+  /** Reads field `i` from the underlying row and converts it to its Scala representation. */
+  override def get(i: Int): Any = {
+    val inRange = i >= 0 && i < length
+    if (!inRange) {
+      throw new IllegalArgumentException(
+        s"Could not access ${i}th value in this buffer because it only has $length values.")
+    }
+    // TODO: Use buffer schema to avoid using generic getter.
+    val toScala = toScalaConverters(i)
+    toScala(underlyingInputBuffer.get(offsets(i), schema(i).dataType))
+  }
+
+  /**
+   * Returns a new view with the same schema, converters and offset. The returned view shares
+   * `underlyingInputBuffer` with this one; no values are copied.
+   */
+  override def copy(): InputAggregationBuffer = new InputAggregationBuffer(
+    schema, toCatalystConverters, toScalaConverters, bufferOffset, underlyingInputBuffer)
+}
+
+/**
+ * The internal wrapper used to hook a [[UserDefinedAggregateFunction]] `udaf` in the
+ * internal aggregation code path.
+ *
+ * Input rows and buffer values are converted between their internal (Catalyst) and external
+ * (Scala) representations around every call into `udaf`, and the final result returned by
+ * `udaf.evaluate` is converted to the internal representation of [[dataType]].
+ *
+ * Construction fails with an IllegalArgumentException if the number of `children` does not
+ * match the number of fields in `udaf.inputSchema`.
+ *
+ * @param children expressions producing the input arguments of `udaf`, in argument order.
+ * @param udaf the user-defined aggregate function being evaluated.
+ */
+private[sql] case class ScalaUDAF(
+    children: Seq[Expression],
+    udaf: UserDefinedAggregateFunction)
+  extends AggregateFunction2 with Logging {
+
+  require(
+    children.length == udaf.inputSchema.length,
+    s"$udaf only accepts ${udaf.inputSchema.length} arguments, " +
+    s"but ${children.length} are provided.")
+
+  // UserDefinedAggregateFunction does not declare result nullability, so assume it may be null.
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = udaf.returnDataType
+
+  override def deterministic: Boolean = udaf.deterministic
+
+  override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType)
+
+  override val bufferSchema: StructType = udaf.bufferSchema
+
+  override val bufferAttributes: Seq[AttributeReference] = bufferSchema.toAttributes
+
+  override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
+
+  // Schema of the projected input row: one field `input<i>` per child, carrying the child's
+  // data type and nullability.
+  val childrenSchema: StructType = {
+    val inputFields = children.zipWithIndex.map {
+      case (child, index) =>
+        StructField(s"input$index", child.dataType, child.nullable, Metadata.empty)
+    }
+    StructType(inputFields)
+  }
+
+  // Evaluates `children` against an input row. Generated code is tried first; if generating
+  // it throws, the interpreted projection is used instead.
+  lazy val inputProjection = {
+    val inputAttributes = childrenSchema.toAttributes
+    log.debug(
+      s"Creating MutableProj: $children, inputSchema: $inputAttributes.")
+    try {
+      GenerateMutableProjection.generate(children, inputAttributes)()
+    } catch {
+      case e: Exception =>
+        log.error("Failed to generate mutable projection, fallback to interpreted", e)
+        new InterpretedMutableProjection(children, inputAttributes)
+    }
+  }
+
+  // Converts the whole projected input row (laid out as `childrenSchema`) into a Scala Row.
+  val inputToScalaConverters: Any => Any =
+    CatalystTypeConverters.createToScalaConverter(childrenSchema)
+
+  // Per-field converters for buffer values: Scala -> Catalyst on writes.
+  val bufferValuesToCatalystConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
+    CatalystTypeConverters.createToCatalystConverter(field.dataType)
+  }
+
+  // Per-field converters for buffer values: Catalyst -> Scala on reads.
+  val bufferValuesToScalaConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
+    CatalystTypeConverters.createToScalaConverter(field.dataType)
+  }
+
+  // Converts the Scala value returned by `udaf.evaluate` into the internal representation of
+  // `dataType`, mirroring how buffer values are converted before being written back.
+  private[this] lazy val outputToCatalystConverter: Any => Any =
+    CatalystTypeConverters.createToCatalystConverter(dataType)
+
+  // Reusable wrappers over the aggregation buffer. They are created without an underlying row
+  // and re-pointed at the row passed to each call below, so an instance carries mutable state
+  // between calls. NOTE(review): presumably lazy so that `bufferOffset` (inherited from
+  // AggregateFunction2, not shown here) is read only after it has been assigned -- confirm.
+  lazy val inputAggregateBuffer: InputAggregationBuffer =
+    new InputAggregationBuffer(
+      bufferSchema,
+      bufferValuesToCatalystConverters,
+      bufferValuesToScalaConverters,
+      bufferOffset,
+      null)
+
+  lazy val mutableAggregateBuffer: MutableAggregationBufferImpl =
+    new MutableAggregationBufferImpl(
+      bufferSchema,
+      bufferValuesToCatalystConverters,
+      bufferValuesToScalaConverters,
+      bufferOffset,
+      null)
+
+  /** Points the mutable wrapper at `buffer` and lets `udaf` write its initial values. */
+  override def initialize(buffer: MutableRow): Unit = {
+    mutableAggregateBuffer.underlyingBuffer = buffer
+
+    udaf.initialize(mutableAggregateBuffer)
+  }
+
+  /** Projects `input` through `children`, converts it to a Scala Row and folds it into `buffer`. */
+  override def update(buffer: MutableRow, input: InternalRow): Unit = {
+    mutableAggregateBuffer.underlyingBuffer = buffer
+
+    udaf.update(
+      mutableAggregateBuffer,
+      inputToScalaConverters(inputProjection(input)).asInstanceOf[Row])
+  }
+
+  /** Merges the values of `buffer2` (exposed read-only) into `buffer1`. */
+  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+    mutableAggregateBuffer.underlyingBuffer = buffer1
+    inputAggregateBuffer.underlyingInputBuffer = buffer2
+
+    udaf.merge(mutableAggregateBuffer, inputAggregateBuffer)
+  }
+
+  /**
+   * Computes the final result from `buffer` and converts it from the Scala value produced by
+   * `udaf.evaluate` to the internal representation of [[dataType]].
+   */
+  override def eval(buffer: InternalRow = null): Any = {
+    inputAggregateBuffer.underlyingInputBuffer = buffer
+
+    outputToCatalystConverter(udaf.evaluate(inputAggregateBuffer))
+  }
+
+  override def toString: String = {
+    s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})"""
+  }
+
+  override def nodeName: String = udaf.getClass.getSimpleName
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/55946e76/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
deleted file mode 100644
index 4ada9ec..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
+++ /dev/null
@@ -1,287 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.expressions.aggregate
-
-import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
-import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.Row
-
-/**
- * The abstract class for implementing user-defined aggregate function.
- */
-abstract class UserDefinedAggregateFunction extends Serializable {
-
- /**
- * A [[StructType]] represents data types of input arguments of this aggregate function.
- * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
- * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like
- *
- * ```
- * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType)))
- * ```
- *
- * The name of a field of this [[StructType]] is only used to identify the corresponding
- * input argument. Users can choose names to identify the input arguments.
- */
- def inputSchema: StructType
-
- /**
- * A [[StructType]] represents data types of values in the aggregation buffer.
- * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
- * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]],
- * the returned [[StructType]] will look like
- *
- * ```
- * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType)))
- * ```
- *
- * The name of a field of this [[StructType]] is only used to identify the corresponding
- * buffer value. Users can choose names to identify the input arguments.
- */
- def bufferSchema: StructType
-
- /**
- * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]].
- */
- def returnDataType: DataType
-
- /** Indicates if this function is deterministic. */
- def deterministic: Boolean
-
- /**
- * Initializes the given aggregation buffer. Initial values set by this method should satisfy
- * the condition that when merging two buffers with initial values, the new buffer should
- * still store initial values.
- */
- def initialize(buffer: MutableAggregationBuffer): Unit
-
- /** Updates the given aggregation buffer `buffer` with new input data from `input`. */
- def update(buffer: MutableAggregationBuffer, input: Row): Unit
-
- /** Merges two aggregation buffers and stores the updated buffer values back in `buffer1`. */
- def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
-
- /**
- * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
- * aggregation buffer.
- */
- def evaluate(buffer: Row): Any
-}
-
-private[sql] abstract class AggregationBuffer(
- toCatalystConverters: Array[Any => Any],
- toScalaConverters: Array[Any => Any],
- bufferOffset: Int)
- extends Row {
-
- override def length: Int = toCatalystConverters.length
-
- protected val offsets: Array[Int] = {
- val newOffsets = new Array[Int](length)
- var i = 0
- while (i < newOffsets.length) {
- newOffsets(i) = bufferOffset + i
- i += 1
- }
- newOffsets
- }
-}
-
-/**
- * A Mutable [[Row]] representing an mutable aggregation buffer.
- */
-class MutableAggregationBuffer private[sql] (
- schema: StructType,
- toCatalystConverters: Array[Any => Any],
- toScalaConverters: Array[Any => Any],
- bufferOffset: Int,
- var underlyingBuffer: MutableRow)
- extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {
-
- override def get(i: Int): Any = {
- if (i >= length || i < 0) {
- throw new IllegalArgumentException(
- s"Could not access ${i}th value in this buffer because it only has $length values.")
- }
- toScalaConverters(i)(underlyingBuffer.get(offsets(i), schema(i).dataType))
- }
-
- def update(i: Int, value: Any): Unit = {
- if (i >= length || i < 0) {
- throw new IllegalArgumentException(
- s"Could not update ${i}th value in this buffer because it only has $length values.")
- }
- underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value))
- }
-
- override def copy(): MutableAggregationBuffer = {
- new MutableAggregationBuffer(
- schema,
- toCatalystConverters,
- toScalaConverters,
- bufferOffset,
- underlyingBuffer)
- }
-}
-
-/**
- * A [[Row]] representing an immutable aggregation buffer.
- */
-class InputAggregationBuffer private[sql] (
- schema: StructType,
- toCatalystConverters: Array[Any => Any],
- toScalaConverters: Array[Any => Any],
- bufferOffset: Int,
- var underlyingInputBuffer: InternalRow)
- extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {
-
- override def get(i: Int): Any = {
- if (i >= length || i < 0) {
- throw new IllegalArgumentException(
- s"Could not access ${i}th value in this buffer because it only has $length values.")
- }
- // TODO: Use buffer schema to avoid using generic getter.
- toScalaConverters(i)(underlyingInputBuffer.get(offsets(i), schema(i).dataType))
- }
-
- override def copy(): InputAggregationBuffer = {
- new InputAggregationBuffer(
- schema,
- toCatalystConverters,
- toScalaConverters,
- bufferOffset,
- underlyingInputBuffer)
- }
-}
-
-/**
- * The internal wrapper used to hook a [[UserDefinedAggregateFunction]] `udaf` in the
- * internal aggregation code path.
- * @param children
- * @param udaf
- */
-case class ScalaUDAF(
- children: Seq[Expression],
- udaf: UserDefinedAggregateFunction)
- extends AggregateFunction2 with Logging {
-
- require(
- children.length == udaf.inputSchema.length,
- s"$udaf only accepts ${udaf.inputSchema.length} arguments, " +
- s"but ${children.length} are provided.")
-
- override def nullable: Boolean = true
-
- override def dataType: DataType = udaf.returnDataType
-
- override def deterministic: Boolean = udaf.deterministic
-
- override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType)
-
- override val bufferSchema: StructType = udaf.bufferSchema
-
- override val bufferAttributes: Seq[AttributeReference] = bufferSchema.toAttributes
-
- override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
-
- val childrenSchema: StructType = {
- val inputFields = children.zipWithIndex.map {
- case (child, index) =>
- StructField(s"input$index", child.dataType, child.nullable, Metadata.empty)
- }
- StructType(inputFields)
- }
-
- lazy val inputProjection = {
- val inputAttributes = childrenSchema.toAttributes
- log.debug(
- s"Creating MutableProj: $children, inputSchema: $inputAttributes.")
- try {
- GenerateMutableProjection.generate(children, inputAttributes)()
- } catch {
- case e: Exception =>
- log.error("Failed to generate mutable projection, fallback to interpreted", e)
- new InterpretedMutableProjection(children, inputAttributes)
- }
- }
-
- val inputToScalaConverters: Any => Any =
- CatalystTypeConverters.createToScalaConverter(childrenSchema)
-
- val bufferValuesToCatalystConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
- CatalystTypeConverters.createToCatalystConverter(field.dataType)
- }
-
- val bufferValuesToScalaConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
- CatalystTypeConverters.createToScalaConverter(field.dataType)
- }
-
- lazy val inputAggregateBuffer: InputAggregationBuffer =
- new InputAggregationBuffer(
- bufferSchema,
- bufferValuesToCatalystConverters,
- bufferValuesToScalaConverters,
- bufferOffset,
- null)
-
- lazy val mutableAggregateBuffer: MutableAggregationBuffer =
- new MutableAggregationBuffer(
- bufferSchema,
- bufferValuesToCatalystConverters,
- bufferValuesToScalaConverters,
- bufferOffset,
- null)
-
-
- override def initialize(buffer: MutableRow): Unit = {
- mutableAggregateBuffer.underlyingBuffer = buffer
-
- udaf.initialize(mutableAggregateBuffer)
- }
-
- override def update(buffer: MutableRow, input: InternalRow): Unit = {
- mutableAggregateBuffer.underlyingBuffer = buffer
-
- udaf.update(
- mutableAggregateBuffer,
- inputToScalaConverters(inputProjection(input)).asInstanceOf[Row])
- }
-
- override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
- mutableAggregateBuffer.underlyingBuffer = buffer1
- inputAggregateBuffer.underlyingInputBuffer = buffer2
-
- udaf.merge(mutableAggregateBuffer, inputAggregateBuffer)
- }
-
- override def eval(buffer: InternalRow = null): Any = {
- inputAggregateBuffer.underlyingInputBuffer = buffer
-
- udaf.evaluate(inputAggregateBuffer)
- }
-
- override def toString: String = {
- s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})"""
- }
-
- override def nodeName: String = udaf.getClass.getSimpleName
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/55946e76/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
new file mode 100644
index 0000000..278dd43
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+import org.apache.spark.annotation.Experimental
+
+/**
+ * :: Experimental ::
+ * The abstract class for implementing user-defined aggregate functions.
+ *
+ * An implementation declares the types of its input arguments ([[inputSchema]]), of its
+ * intermediate aggregation buffer ([[bufferSchema]]) and of its result ([[returnDataType]]),
+ * and implements [[initialize]], [[update]], [[merge]] and [[evaluate]] over that buffer.
+ */
+@Experimental
+abstract class UserDefinedAggregateFunction extends Serializable {
+
+  /**
+   * A [[StructType]] that represents the data types of the input arguments of this aggregate
+   * function. For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
+   * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like
+   *
+   * ```
+   *   new StructType()
+   *    .add("doubleInput", DoubleType)
+   *    .add("longInput", LongType)
+   * ```
+   *
+   * The name of a field of this [[StructType]] is only used to identify the corresponding
+   * input argument. Users can choose names to identify the input arguments.
+   */
+  def inputSchema: StructType
+
+  /**
+   * A [[StructType]] that represents the data types of the values in the aggregation buffer.
+   * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
+   * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]],
+   * the returned [[StructType]] will look like
+   *
+   * ```
+   *   new StructType()
+   *    .add("doubleBuffer", DoubleType)
+   *    .add("longBuffer", LongType)
+   * ```
+   *
+   * The name of a field of this [[StructType]] is only used to identify the corresponding
+   * buffer value. Users can choose names to identify the buffer values.
+   */
+  def bufferSchema: StructType
+
+  /**
+   * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]].
+   */
+  def returnDataType: DataType
+
+  /**
+   * Indicates if this function is deterministic, i.e. whether it always returns the same
+   * result for the same input.
+   */
+  def deterministic: Boolean
+
+  /**
+   * Initializes the given aggregation buffer. Initial values set by this method should satisfy
+   * the condition that when two buffers holding initial values are merged, the resulting
+   * buffer still holds the initial values.
+   */
+  def initialize(buffer: MutableAggregationBuffer): Unit
+
+  /**
+   * Updates the given aggregation buffer `buffer` with new input data from `input`.
+   * `input` holds one value per field of [[inputSchema]], in the same order.
+   */
+  def update(buffer: MutableAggregationBuffer, input: Row): Unit
+
+  /**
+   * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`.
+   * Both buffers are laid out according to [[bufferSchema]]; `buffer2` is read-only.
+   */
+  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
+
+  /**
+   * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
+   * aggregation buffer. The result should be a value of [[returnDataType]].
+   */
+  def evaluate(buffer: Row): Any
+}
+
+/**
+ * :: Experimental ::
+ * A [[Row]] representing a mutable aggregation buffer.
+ *
+ * This is the buffer type passed to [[UserDefinedAggregateFunction.initialize]] and
+ * [[UserDefinedAggregateFunction.update]], and as the first argument of
+ * [[UserDefinedAggregateFunction.merge]]. Values are read with the usual [[Row]] getters and
+ * written with [[update]].
+ */
+@Experimental
+trait MutableAggregationBuffer extends Row {
+
+  /** Updates the `i`-th value (0-based, as for other [[Row]] accessors) of this buffer to `value`. */
+  def update(i: Int, value: Any): Unit
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/55946e76/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
index 5c9d0e9..a2247e3 100644
--- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
@@ -21,13 +21,18 @@ import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer;
-import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction;
+import org.apache.spark.sql.expressions.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
+/**
+ * An example {@link UserDefinedAggregateFunction} to calculate a special average value of a
+ * {@link org.apache.spark.sql.types.DoubleType} column. This special average value is the sum
+ * of the average value of input values and 100.0.
+ */
public class MyDoubleAvg extends UserDefinedAggregateFunction {
private StructType _inputDataType;
@@ -37,10 +42,13 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction {
private DataType _returnDataType;
public MyDoubleAvg() {
- List<StructField> inputfields = new ArrayList<StructField>();
- inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
- _inputDataType = DataTypes.createStructType(inputfields);
+ List<StructField> inputFields = new ArrayList<StructField>();
+ inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
+ _inputDataType = DataTypes.createStructType(inputFields);
+ // The buffer has two values, bufferSum for storing the current sum and
+ // bufferCount for storing the number of non-null input values that have been contributed
+ // to the current sum.
List<StructField> bufferFields = new ArrayList<StructField>();
bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true));
bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true));
@@ -66,16 +74,23 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction {
}
+ /** Resets the buffer to its empty state: no sum yet and a count of zero. */
@Override public void initialize(MutableAggregationBuffer buffer) {
+ // The initial value of the sum is null; it stays null until a non-null input value is seen.
buffer.update(0, null);
+ // The initial value of the count is 0.
buffer.update(1, 0L);
}
@Override public void update(MutableAggregationBuffer buffer, Row input) {
+ // This input Row only has a single column storing the input value in Double.
+ // We only update the buffer when the input value is not null.
if (!input.isNullAt(0)) {
+ // If the buffer value (the intermediate result of the sum) is still null,
+ // we set the input value to the buffer and set the bufferCount to 1.
if (buffer.isNullAt(0)) {
buffer.update(0, input.getDouble(0));
buffer.update(1, 1L);
} else {
+ // Otherwise, update the bufferSum and increment bufferCount.
Double newValue = input.getDouble(0) + buffer.getDouble(0);
buffer.update(0, newValue);
buffer.update(1, buffer.getLong(1) + 1L);
@@ -84,11 +99,16 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction {
}
@Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+ // buffer1 and buffer2 have the same structure.
+ // We only update the buffer1 when the input buffer2's sum value is not null.
if (!buffer2.isNullAt(0)) {
if (buffer1.isNullAt(0)) {
+ // If the buffer value (intermediate result of the sum) is still null,
+ // we copy the input buffer's (buffer2's) sum and count into it.
buffer1.update(0, buffer2.getDouble(0));
buffer1.update(1, buffer2.getLong(1));
} else {
+ // Otherwise, we update the bufferSum and bufferCount.
Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
buffer1.update(0, newValue);
buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1));
@@ -98,10 +118,12 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction {
+ /** Returns the average of the non-null input values plus 100.0, or null if there were none. */
@Override public Object evaluate(Row buffer) {
if (buffer.isNullAt(0)) {
+ // If bufferSum is still null, no non-null input value has been seen, so we return null.
return null;
} else {
+ // Otherwise, we return the average (bufferSum / bufferCount) plus 100.0.
return buffer.getDouble(0) / buffer.getLong(1) + 100.0;
}
}
}
-
http://git-wip-us.apache.org/repos/asf/spark/blob/55946e76/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
index 1d4587a..da29e24 100644
--- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
@@ -20,14 +20,18 @@ package test.org.apache.spark.sql.hive.aggregate;
import java.util.ArrayList;
import java.util.List;
-import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer;
-import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction;
+import org.apache.spark.sql.expressions.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.Row;
+/**
+ * An example {@link UserDefinedAggregateFunction} to calculate the sum of a
+ * {@link org.apache.spark.sql.types.DoubleType} column.
+ */
public class MyDoubleSum extends UserDefinedAggregateFunction {
private StructType _inputDataType;
@@ -37,9 +41,9 @@ public class MyDoubleSum extends UserDefinedAggregateFunction {
private DataType _returnDataType;
public MyDoubleSum() {
- List<StructField> inputfields = new ArrayList<StructField>();
- inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
- _inputDataType = DataTypes.createStructType(inputfields);
+ List<StructField> inputFields = new ArrayList<StructField>();
+ inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
+ _inputDataType = DataTypes.createStructType(inputFields);
List<StructField> bufferFields = new ArrayList<StructField>();
bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true));
@@ -65,14 +69,20 @@ public class MyDoubleSum extends UserDefinedAggregateFunction {
}
+ /** Resets the buffer so that it holds no sum yet. */
@Override public void initialize(MutableAggregationBuffer buffer) {
+ // The initial value of the sum is null; it stays null until a non-null input value is seen.
buffer.update(0, null);
}
@Override public void update(MutableAggregationBuffer buffer, Row input) {
+ // This input Row only has a single column storing the input value in Double.
+ // We only update the buffer when the input value is not null.
if (!input.isNullAt(0)) {
if (buffer.isNullAt(0)) {
+ // If the buffer value (the intermediate result of the sum) is still null,
+ // we set the input value to the buffer.
buffer.update(0, input.getDouble(0));
} else {
+ // Otherwise, we add the input value to the buffer value.
Double newValue = input.getDouble(0) + buffer.getDouble(0);
buffer.update(0, newValue);
}
@@ -80,10 +90,16 @@ public class MyDoubleSum extends UserDefinedAggregateFunction {
}
@Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+ // buffer1 and buffer2 have the same structure.
+ // We only update the buffer1 when the input buffer2's value is not null.
if (!buffer2.isNullAt(0)) {
if (buffer1.isNullAt(0)) {
+ // If the buffer value (intermediate result of the sum) is still null,
+ // we set it to the input buffer's (buffer2's) value.
buffer1.update(0, buffer2.getDouble(0));
} else {
+ // Otherwise, we add the input buffer's value (buffer2) to the mutable
+ // buffer's value (buffer1).
Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
buffer1.update(0, newValue);
}
@@ -92,8 +108,10 @@ public class MyDoubleSum extends UserDefinedAggregateFunction {
+ /** Returns the sum of the non-null input values, or null if there were none. */
@Override public Object evaluate(Row buffer) {
if (buffer.isNullAt(0)) {
+ // If the buffer value is still null, no non-null input value has been seen; return null.
return null;
} else {
+ // Otherwise, the intermediate sum is the final result.
return buffer.getDouble(0);
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org