Posted to commits@spark.apache.org by yh...@apache.org on 2016/08/25 23:36:22 UTC

spark git commit: [SPARK-17187][SQL] Supports using arbitrary Java object as internal aggregation buffer object

Repository: spark
Updated Branches:
  refs/heads/master 9b5a1d1d5 -> d96d15156


[SPARK-17187][SQL] Supports using arbitrary Java object as internal aggregation buffer object

## What changes were proposed in this pull request?

This PR introduces an abstract class `TypedImperativeAggregate` so that an aggregation function extending `TypedImperativeAggregate` can use an **arbitrary** user-defined Java object as its intermediate aggregation buffer.

**This has advantages such as:**
1. It can now support a larger category of aggregation functions. For example, it becomes much easier to implement an aggregation function like `percentile_approx`, which has a complex aggregation buffer definition.
2. It avoids serializing/de-serializing the domain-specific aggregation object to the internal Spark SQL storage format on every call of `update` or `merge`.
3. It makes it easier to integrate with existing monoid libraries such as Algebird, supporting more aggregation functions with high performance.

Please see `org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax` for an example of how to define a `TypedImperativeAggregate` aggregation function.
Please see Java doc of `TypedImperativeAggregate` and Jira ticket SPARK-17187 for more information.
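
For a rough illustration of the contract, here is a minimal sketch of a hypothetical `TypedCount` function that keeps a plain mutable Java object as its buffer. Only the `TypedImperativeAggregate` methods it overrides come from this patch; the buffer class and every other name below are made up for the example.

```scala
import java.nio.ByteBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, LongType}

// Hypothetical example, not part of this patch: an arbitrary mutable Java object
// used as the aggregation buffer.
class LongHolder(var value: Long)

case class TypedCount(
    child: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[LongHolder] {

  // Called once per key group to create an empty buffer object.
  override def createAggregationBuffer(): LongHolder = new LongHolder(0L)

  // Partial/Complete mode: fold one input row into the buffer in place.
  override def update(buffer: LongHolder, input: InternalRow): Unit = {
    if (child.eval(input) != null) buffer.value += 1
  }

  // PartialMerge/Final mode: fold a de-serialized partial result into the buffer.
  override def merge(buffer: LongHolder, input: LongHolder): Unit = {
    buffer.value += input.value
  }

  // Final result for the current key group.
  override def eval(buffer: LongHolder): Any = buffer.value

  // Storage format used when the buffer has to cross a shuffle boundary.
  override def serialize(buffer: LongHolder): Array[Byte] =
    ByteBuffer.allocate(8).putLong(buffer.value).array()

  override def deserialize(storageFormat: Array[Byte]): LongHolder =
    new LongHolder(ByteBuffer.wrap(storageFormat).getLong)

  // Boilerplate required by ImperativeAggregate / Expression.
  override def children: Seq[Expression] = Seq(child)
  override def nullable: Boolean = false
  override def dataType: DataType = LongType
  override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)

  override def withNewMutableAggBufferOffset(newOffset: Int): TypedCount =
    copy(mutableAggBufferOffset = newOffset)
  override def withNewInputAggBufferOffset(newOffset: Int): TypedCount =
    copy(inputAggBufferOffset = newOffset)
}
```

A function like this would be wrapped via `Column(agg.toAggregateExpression())`, as the test suite does for its `TypedMax` example.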

## How was this patch tested?

Unit tests.

Author: Sean Zhong <se...@databricks.com>
Author: Yin Huai <yh...@databricks.com>

Closes #14753 from clockfly/object_aggregation_buffer_try_2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d96d1515
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d96d1515
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d96d1515

Branch: refs/heads/master
Commit: d96d1515638da20b594f7bfe3cfdb50088f25a04
Parents: 9b5a1d1
Author: Sean Zhong <se...@databricks.com>
Authored: Thu Aug 25 16:36:16 2016 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Thu Aug 25 16:36:16 2016 -0700

----------------------------------------------------------------------
 .../expressions/aggregate/interfaces.scala      | 141 +++++++++
 .../aggregate/AggregationIterator.scala         |  15 +
 .../sql/TypedImperativeAggregateSuite.scala     | 300 +++++++++++++++++++
 3 files changed, 456 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d96d1515/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 7a39e56..ecbaa2f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -389,3 +389,144 @@ abstract class DeclarativeAggregate
     def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a))
   }
 }
+
+/**
+ * Aggregation function which allows an **arbitrary** user-defined Java object to be used as the
+ * internal aggregation buffer object.
+ *
+ * {{{
+ *                aggregation buffer for normal aggregation function `avg`
+ *                    |
+ *                    v
+ *                  +--------------+---------------+-----------------------------------+
+ *                  |  sum1 (Long) | count1 (Long) | generic user-defined java objects |
+ *                  +--------------+---------------+-----------------------------------+
+ *                                                     ^
+ *                                                     |
+ *                    Aggregation buffer object for `TypedImperativeAggregate` aggregation function
+ * }}}
+ *
+ * Workflow (Partial mode aggregate at the Mapper side, and Final mode aggregate at the Reducer side):
+ *
+ * Stage 1: Partial aggregate at Mapper side:
+ *
+ *  1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation
+ *     buffer object.
+ *  2. Upon each input row, the framework calls
+ *     `update(buffer: T, input: InternalRow): Unit` to update the aggregation buffer object T.
+ *  3. After processing all rows of the current group (group by key), the framework serializes
+ *     the aggregation buffer object T to the storage format (Array[Byte]) and persists the
+ *     Array[Byte] to disk if needed.
+ *  4. The framework moves on to the next group, until all groups have been processed.
+ *
+ * Shuffle exchanges the data to the Reducer tasks...
+ *
+ * Stage 2: Final mode aggregate at Reducer side:
+ *
+ *  1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation
+ *     buffer object (type T) for merging.
+ *  2. For each aggregation output of Stage 1, the framework de-serializes the storage
+ *     format (Array[Byte]) and produces one input aggregation object (type T).
+ *  3. For each input aggregation object, the framework calls `merge(buffer: T, input: T): Unit`
+ *     to merge the input aggregation object into the aggregation buffer object.
+ *  4. After processing all input aggregation objects of the current group (group by key), the
+ *     framework calls `eval(buffer: T)` to generate the final output for this group.
+ *  5. The framework moves on to the next group, until all groups have been processed.
+ *
+ * NOTE: SQL with TypedImperativeAggregate functions is planned with sort-based aggregation,
+ * instead of hash-based aggregation, as TypedImperativeAggregate uses BinaryType as the
+ * aggregation buffer's storage format, which is not supported by hash-based aggregation.
+ * Hash-based aggregation only supports aggregation buffers of mutable types (like LongType and
+ * IntegerType, which have fixed length and can be mutated in place in UnsafeRow).
+ */
+abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
+
+  /**
+   * Creates an empty aggregation buffer object. This is called before processing each key group
+   * (group by key).
+   *
+   * @return an aggregation buffer object
+   */
+  def createAggregationBuffer(): T
+
+  /**
+   * Updates the aggregation buffer object in place with an input row: buffer = buffer + input.
+   * This is typically called when doing Partial or Complete mode aggregation.
+   *
+   * @param buffer The aggregation buffer object.
+   * @param input an input row
+   */
+  def update(buffer: T, input: InternalRow): Unit
+
+  /**
+   * Merges an input aggregation object into the aggregation buffer object: buffer = buffer + input.
+   * This is typically called when doing PartialMerge or Final mode aggregation.
+   *
+   * @param buffer the aggregation buffer object used to store the aggregation result.
+   * @param input an input aggregation object, which can be produced by de-serializing the
+   *              partial aggregate's output from the Mapper side.
+   */
+  def merge(buffer: T, input: T): Unit
+
+  /**
+   * Generates the final aggregation result value for the current key group with the aggregation buffer
+   * object.
+   *
+   * @param buffer aggregation buffer object.
+   * @return The aggregation result of the current key group
+   */
+  def eval(buffer: T): Any
+
+  /** Serializes the aggregation buffer object T to Array[Byte] */
+  def serialize(buffer: T): Array[Byte]
+
+  /** De-serializes the storage format Array[Byte] and produces the aggregation buffer object T */
+  def deserialize(storageFormat: Array[Byte]): T
+
+  final override def initialize(buffer: MutableRow): Unit = {
+    val bufferObject = createAggregationBuffer()
+    buffer.update(mutableAggBufferOffset, bufferObject)
+  }
+
+  final override def update(buffer: MutableRow, input: InternalRow): Unit = {
+    val bufferObject = getField[T](buffer, mutableAggBufferOffset)
+    update(bufferObject, input)
+  }
+
+  final override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = {
+    val bufferObject = getField[T](buffer, mutableAggBufferOffset)
+    // The inputBuffer stores the serialized aggregation buffer object produced by the partial aggregate
+    val inputObject = deserialize(inputBuffer.getBinary(inputAggBufferOffset))
+    merge(bufferObject, inputObject)
+  }
+
+  final override def eval(buffer: InternalRow): Any = {
+    val bufferObject = getField[T](buffer, mutableAggBufferOffset)
+    eval(bufferObject)
+  }
+
+  private[this] val anyObjectType = ObjectType(classOf[AnyRef])
+  private def getField[U](input: InternalRow, fieldIndex: Int): U = {
+    input.get(fieldIndex, anyObjectType).asInstanceOf[U]
+  }
+
+  final override lazy val aggBufferAttributes: Seq[AttributeReference] = {
+    // Underlying storage type for the aggregation buffer object
+    Seq(AttributeReference("buf", BinaryType)())
+  }
+
+  final override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
+    aggBufferAttributes.map(_.newInstance())
+
+  final override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+
+  /**
+   * Replaces, in place, the aggregation buffer object stored at the buffer's index
+   * `mutableAggBufferOffset` with the underlying storage format (BinaryType) that Spark SQL
+   * supports internally.
+   */
+  final def serializeAggregateBufferInPlace(buffer: MutableRow): Unit = {
+    val bufferObject = getField[T](buffer, mutableAggBufferOffset)
+    buffer(mutableAggBufferOffset) = serialize(bufferObject)
+  }
+}
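
As a rough usage trace of the workflow documented above (not part of the patch), the framework-facing `MutableRow` API can be exercised directly, assuming a `TypedMax`-like function such as the one defined in the test suite further down:

```scala
// Sketch only: assumes a TypedMax function analogous to the one in
// TypedImperativeAggregateSuite (the class there is private to the suite).
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericMutableRow}
import org.apache.spark.sql.types.IntegerType

val agg = new TypedMax(BoundReference(0, IntegerType, nullable = false))

// Stage 1 (Mapper side): create the buffer object and fold in the rows of one group.
val partial = new GenericMutableRow(1)
agg.initialize(partial)                       // stores a fresh MaxValue object in the row
agg.update(partial, InternalRow(3))
agg.update(partial, InternalRow(7))
// Convert the generic object to its BinaryType storage format before shuffling.
agg.serializeAggregateBufferInPlace(partial)

// Stage 2 (Reducer side): merge the serialized partial result and emit the final value.
val merged = new GenericMutableRow(1)
agg.initialize(merged)
agg.merge(merged, partial)                    // deserializes the partial result, then merges
assert(agg.eval(merged) == 7)
```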

http://git-wip-us.apache.org/repos/asf/spark/blob/d96d1515/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 34de76d..dfed084 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -234,7 +234,22 @@ abstract class AggregationIterator(
       val resultProjection = UnsafeProjection.create(
         groupingAttributes ++ bufferAttributes,
         groupingAttributes ++ bufferAttributes)
+
+      // TypedImperativeAggregate stores a generic object in the aggregation buffer and requires
+      // serialization before shuffling. See [[TypedImperativeAggregate]] for more info.
+      val typedImperativeAggregates: Array[TypedImperativeAggregate[_]] = {
+        aggregateFunctions.collect {
+          case (ag: TypedImperativeAggregate[_]) => ag
+        }
+      }
+
       (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => {
+        // Serializes the generic object stored in the aggregation buffer
+        var i = 0
+        while (i < typedImperativeAggregates.length) {
+          typedImperativeAggregates(i).serializeAggregateBufferInPlace(currentBuffer)
+          i += 1
+        }
         resultProjection(joinedRow(currentGroupingKey, currentBuffer))
       }
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/d96d1515/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
new file mode 100644
index 0000000..b5eb16b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
+import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericMutableRow, SpecificMutableRow}
+import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
+import org.apache.spark.sql.execution.aggregate.SortAggregateExec
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, IntegerType, LongType}
+
+class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
+
+  import testImplicits._
+
+  private val random = new java.util.Random()
+
+  private val data = (0 until 1000).map { _ =>
+    (random.nextInt(10), random.nextInt(100))
+  }
+
+  test("aggregate with object aggregate buffer") {
+    val agg = new TypedMax(BoundReference(0, IntegerType, nullable = false))
+
+    val group1 = (0 until data.length / 2)
+    val group1Buffer = agg.createAggregationBuffer()
+    group1.foreach { index =>
+      val input = InternalRow(data(index)._1, data(index)._2)
+      agg.update(group1Buffer, input)
+    }
+
+    val group2 = (data.length / 2 until data.length)
+    val group2Buffer = agg.createAggregationBuffer()
+    group2.foreach { index =>
+      val input = InternalRow(data(index)._1, data(index)._2)
+      agg.update(group2Buffer, input)
+    }
+
+    val mergeBuffer = agg.createAggregationBuffer()
+    agg.merge(mergeBuffer, group1Buffer)
+    agg.merge(mergeBuffer, group2Buffer)
+
+    assert(mergeBuffer.value == data.map(_._1).max)
+    assert(agg.eval(mergeBuffer) == data.map(_._1).max)
+
+    // Tests the low-level eval(row: InternalRow) API.
+    val row = new GenericMutableRow(Array(mergeBuffer): Array[Any])
+
+    // Evaluates directly on a row consisting of the aggregation buffer object.
+    assert(agg.eval(row) == data.map(_._1).max)
+  }
+
+  test("supports SpecificMutableRow as mutable row") {
+    val aggregationBufferSchema = Seq(IntegerType, LongType, BinaryType, IntegerType)
+    val aggBufferOffset = 2
+    val buffer = new SpecificMutableRow(aggregationBufferSchema)
+    val agg = new TypedMax(BoundReference(ordinal = 1, dataType = IntegerType, nullable = false))
+      .withNewMutableAggBufferOffset(aggBufferOffset)
+
+    agg.initialize(buffer)
+    data.foreach { kv =>
+      val input = InternalRow(kv._1, kv._2)
+      agg.update(buffer, input)
+    }
+    assert(agg.eval(buffer) == data.map(_._2).max)
+  }
+
+  test("dataframe aggregate with object aggregate buffer, should not use HashAggregate") {
+    val df = data.toDF("a", "b")
+    val max = new TypedMax($"a".expr)
+
+    // Always uses SortAggregateExec
+    val sparkPlan = df.select(Column(max.toAggregateExpression())).queryExecution.sparkPlan
+    assert(sparkPlan.isInstanceOf[SortAggregateExec])
+  }
+
+  test("dataframe aggregate with object aggregate buffer, no group by") {
+    val df = data.toDF("key", "value").coalesce(2)
+    val query = df.select(typedMax($"key"), count($"key"), typedMax($"value"), count($"value"))
+    val maxKey = data.map(_._1).max
+    val countKey = data.size
+    val maxValue = data.map(_._2).max
+    val countValue = data.size
+    val expected = Seq(Row(maxKey, countKey, maxValue, countValue))
+    checkAnswer(query, expected)
+  }
+
+  test("dataframe aggregate with object aggregate buffer, non-nullable aggregator") {
+    val df = data.toDF("key", "value").coalesce(2)
+
+    // Test non-nullable typedMax
+    val query = df.select(typedMax(lit(null)), count($"key"), typedMax(lit(null)),
+      count($"value"))
+
+    // typedMax is not nullable
+    val maxNull = Int.MinValue
+    val countKey = data.size
+    val countValue = data.size
+    val expected = Seq(Row(maxNull, countKey, maxNull, countValue))
+    checkAnswer(query, expected)
+  }
+
+  test("dataframe aggregate with object aggregate buffer, nullable aggregator") {
+    val df = data.toDF("key", "value").coalesce(2)
+
+    // Test nullable nullableTypedMax
+    val query = df.select(nullableTypedMax(lit(null)), count($"key"), nullableTypedMax(lit(null)),
+      count($"value"))
+
+    // nullableTypedMax is nullable
+    val maxNull = null
+    val countKey = data.size
+    val countValue = data.size
+    val expected = Seq(Row(maxNull, countKey, maxNull, countValue))
+    checkAnswer(query, expected)
+  }
+
+  test("dataframe aggregation with object aggregate buffer, input row contains null") {
+
+    val nullableData = (0 until 1000).map {id =>
+      val nullableKey: Integer = if (random.nextBoolean()) null else random.nextInt(100)
+      val nullableValue: Integer = if (random.nextBoolean()) null else random.nextInt(100)
+      (nullableKey, nullableValue)
+    }
+
+    val df = nullableData.toDF("key", "value").coalesce(2)
+    val query = df.select(typedMax($"key"), count($"key"), typedMax($"value"),
+      count($"value"))
+    val maxKey = nullableData.map(_._1).filter(_ != null).max
+    val countKey = nullableData.map(_._1).filter(_ != null).size
+    val maxValue = nullableData.map(_._2).filter(_ != null).max
+    val countValue = nullableData.map(_._2).filter(_ != null).size
+    val expected = Seq(Row(maxKey, countKey, maxValue, countValue))
+    checkAnswer(query, expected)
+  }
+
+  test("dataframe aggregate with object aggregate buffer, with group by") {
+    val df = data.toDF("value", "key").coalesce(2)
+    val query = df.groupBy($"key").agg(typedMax($"value"), count($"value"), typedMax($"value"))
+    val expected = data.groupBy(_._2).toSeq.map { group =>
+      val (key, values) = group
+      val valueMax = values.map(_._1).max
+      val countValue = values.size
+      Row(key, valueMax, countValue, valueMax)
+    }
+    checkAnswer(query, expected)
+  }
+
+  test("dataframe aggregate with object aggregate buffer, empty inputs, no group by") {
+    val empty = Seq.empty[(Int, Int)].toDF("a", "b")
+    checkAnswer(
+      empty.select(typedMax($"a"), count($"a"), typedMax($"b"), count($"b")),
+      Seq(Row(Int.MinValue, 0, Int.MinValue, 0)))
+  }
+
+  test("dataframe aggregate with object aggregate buffer, empty inputs, with group by") {
+    val empty = Seq.empty[(Int, Int)].toDF("a", "b")
+    checkAnswer(
+      empty.groupBy($"b").agg(typedMax($"a"), count($"a"), typedMax($"a")),
+      Seq.empty[Row])
+  }
+
+  test("TypedImperativeAggregate should not break Window function") {
+    val df = data.toDF("key", "value")
+    // OVER (PARTITION BY a ORDER BY b ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
+    val w = Window.orderBy("value").partitionBy("key").rowsBetween(Long.MinValue, 0)
+
+    val query = df.select(sum($"key").over(w), typedMax($"key").over(w), sum($"value").over(w),
+      typedMax($"value").over(w))
+
+    val expected = data.groupBy(_._1).toSeq.flatMap { group =>
+      val (key, values) = group
+      val sortedValues = values.map(_._2).sorted
+
+      var outputRows = Seq.empty[Row]
+      var i = 0
+      while (i < sortedValues.size) {
+        val unboundedPrecedingAndCurrent = sortedValues.slice(0, i + 1)
+        val sumKey = key * unboundedPrecedingAndCurrent.size
+        val maxKey = key
+        val sumValue = unboundedPrecedingAndCurrent.sum
+        val maxValue = unboundedPrecedingAndCurrent.max
+
+        outputRows :+= Row(sumKey, maxKey, sumValue, maxValue)
+        i += 1
+      }
+
+      outputRows
+    }
+    checkAnswer(query, expected)
+  }
+
+  private def typedMax(column: Column): Column = {
+    val max = TypedMax(column.expr, nullable = false)
+    Column(max.toAggregateExpression())
+  }
+
+  private def nullableTypedMax(column: Column): Column = {
+    val max = TypedMax(column.expr, nullable = true)
+    Column(max.toAggregateExpression())
+  }
+}
+
+object TypedImperativeAggregateSuite {
+
+  /**
+   * Calculates the max value with an object aggregation buffer. This stores an instance of
+   * class MaxValue in the aggregation buffer.
+   */
+  private case class TypedMax(
+      child: Expression,
+      nullable: Boolean = false,
+      mutableAggBufferOffset: Int = 0,
+      inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[MaxValue] {
+
+
+    override def createAggregationBuffer(): MaxValue = {
+      // Returns Int.MinValue if all inputs are null
+      new MaxValue(Int.MinValue)
+    }
+
+    override def update(buffer: MaxValue, input: InternalRow): Unit = {
+      child.eval(input) match {
+        case inputValue: Int =>
+          if (inputValue > buffer.value) {
+            buffer.value = inputValue
+            buffer.isValueSet = true
+          }
+        case null => // skip
+      }
+    }
+
+    override def merge(bufferMax: MaxValue, inputMax: MaxValue): Unit = {
+      if (inputMax.value > bufferMax.value) {
+        bufferMax.value = inputMax.value
+        bufferMax.isValueSet = bufferMax.isValueSet || inputMax.isValueSet
+      }
+    }
+
+    override def eval(bufferMax: MaxValue): Any = {
+      if (nullable && bufferMax.isValueSet == false) {
+        null
+      } else {
+        bufferMax.value
+      }
+    }
+
+    override def deterministic: Boolean = true
+
+    override def children: Seq[Expression] = Seq(child)
+
+    override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)
+
+    override def dataType: DataType = IntegerType
+
+    override def withNewMutableAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] =
+      copy(mutableAggBufferOffset = newOffset)
+
+    override def withNewInputAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] =
+      copy(inputAggBufferOffset = newOffset)
+
+    override def serialize(buffer: MaxValue): Array[Byte] = {
+      val out = new ByteArrayOutputStream()
+      val stream = new DataOutputStream(out)
+      stream.writeBoolean(buffer.isValueSet)
+      stream.writeInt(buffer.value)
+      out.toByteArray
+    }
+
+    override def deserialize(storageFormat: Array[Byte]): MaxValue = {
+      val in = new ByteArrayInputStream(storageFormat)
+      val stream = new DataInputStream(in)
+      val isValueSet = stream.readBoolean()
+      val value = stream.readInt()
+      new MaxValue(value, isValueSet)
+    }
+  }
+
+  private class MaxValue(var value: Int, var isValueSet: Boolean = false)
+}

