You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@flink.apache.org by fh...@apache.org on 2017/04/21 21:06:49 UTC
[3/5] flink git commit: [FLINK-6242] [table] Add code generation for DataSet Aggregates
http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceCombineFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceCombineFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceCombineFunction.scala
index 8eac79d..381d443 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceCombineFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceCombineFunction.scala
@@ -20,7 +20,8 @@ package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
import org.apache.flink.api.common.functions.CombineFunction
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.table.codegen.GeneratedAggregationsFunction
import org.apache.flink.types.Row
/**
@@ -30,87 +31,62 @@ import org.apache.flink.types.Row
*
* It is used for sliding on batch for both time and count-windows.
*
- * @param aggregates aggregate functions.
- * @param groupKeysMapping index mapping of group keys between intermediate aggregate Row
- * and output Row.
- * @param aggregateMapping index mapping between aggregate function list and aggregated value
- * index in output Row.
- * @param finalRowArity output row field count
+ * @param genPreAggregations Code-generated [[GeneratedAggregations]] for partial aggregation.
+ * @param genFinalAggregations Code-generated [[GeneratedAggregations]] for final aggregation.
+ * @param keysAndAggregatesArity The total arity of keys and aggregates
* @param finalRowWindowStartPos relative window-start position to last field of output row
* @param finalRowWindowEndPos relative window-end position to last field of output row
* @param windowSize size of the window, used to determine window-end for output row
*/
class DataSetSlideWindowAggReduceCombineFunction(
- aggregates: Array[AggregateFunction[_ <: Any]],
- groupKeysMapping: Array[(Int, Int)],
- aggregateMapping: Array[(Int, Int)],
- finalRowArity: Int,
+ genPreAggregations: GeneratedAggregationsFunction,
+ genFinalAggregations: GeneratedAggregationsFunction,
+ keysAndAggregatesArity: Int,
finalRowWindowStartPos: Option[Int],
finalRowWindowEndPos: Option[Int],
windowSize: Long)
extends DataSetSlideWindowAggReduceGroupFunction(
- aggregates,
- groupKeysMapping,
- aggregateMapping,
- finalRowArity,
+ genFinalAggregations,
+ keysAndAggregatesArity,
finalRowWindowStartPos,
finalRowWindowEndPos,
windowSize)
with CombineFunction[Row, Row] {
- private val intermediateRowArity: Int = groupKeysMapping.length + aggregateMapping.length + 1
- private val intermediateRow: Row = new Row(intermediateRowArity)
+ private val intermediateRow: Row = new Row(keysAndAggregatesArity + 1)
- override def combine(records: Iterable[Row]): Row = {
-
- // reset first accumulator
- var i = 0
- while (i < aggregates.length) {
- aggregates(i).resetAccumulator(accumulatorList(i).get(0))
- i += 1
- }
-
- val iterator = records.iterator()
- while (iterator.hasNext) {
- val record = iterator.next()
+ protected var preAggfunction: GeneratedAggregations = _
- // accumulate
- i = 0
- while (i < aggregates.length) {
- // insert received accumulator into acc list
- val newAcc = record.getField(groupKeysMapping.length + i).asInstanceOf[Accumulator]
- accumulatorList(i).set(1, newAcc)
- // merge acc list
- val retAcc = aggregates(i).merge(accumulatorList(i))
- // insert result into acc list
- accumulatorList(i).set(0, retAcc)
- i += 1
- }
+ override def open(config: Configuration): Unit = {
+ super.open(config)
- // check if this record is the last record
- if (!iterator.hasNext) {
- // set group keys
- i = 0
- while (i < groupKeysMapping.length) {
- intermediateRow.setField(i, record.getField(i))
- i += 1
- }
+ LOG.debug(s"Compiling AggregateHelper: ${genPreAggregations.name} \n\n " +
+ s"Code:\n${genPreAggregations.code}")
+ val clazz = compile(
+ getRuntimeContext.getUserCodeClassLoader, // may load user-jar classes
+ genPreAggregations.name,
+ genPreAggregations.code)
+ LOG.debug("Instantiating AggregateHelper.")
+ preAggfunction = clazz.newInstance()
+ }
- // set the partial accumulated result
- i = 0
- while (i < aggregates.length) {
- intermediateRow.setField(groupKeysMapping.length + i, accumulatorList(i).get(0))
- i += 1
- }
+ override def combine(records: Iterable[Row]): Row = {
- intermediateRow.setField(windowStartPos, record.getField(windowStartPos))
+ // reset the accumulators before merging this group's partial results
+ preAggfunction.resetAccumulator(accumulators)
- return intermediateRow
- }
+ val iterator = records.iterator()
+ var record: Row = null // groups passed to combine() are assumed non-empty
+ while (iterator.hasNext) {
+ record = iterator.next()
+ preAggfunction.mergeAccumulatorsPair(accumulators, record)
}
+ // set group keys and partial accumulated result
+ preAggfunction.setAggregationResults(accumulators, intermediateRow)
+ preAggfunction.setForwardedFields(record, intermediateRow)
+ // forward window-start, kept in the last field of the intermediate row
+ intermediateRow.setField(windowStartPos, record.getField(windowStartPos))
- // this code path should never be reached as we return before the loop finishes
- // we need this to prevent a compiler error
- throw new IllegalArgumentException("Group is empty. This should never happen.")
+ intermediateRow
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceGroupFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceGroupFunction.scala
index d6bc006..a221c53 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceGroupFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceGroupFunction.scala
@@ -18,13 +18,13 @@
package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
-import java.util.{ArrayList => JArrayList}
import org.apache.flink.api.common.functions.RichGroupReduceFunction
import org.apache.flink.configuration.Configuration
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
import org.apache.flink.types.Row
-import org.apache.flink.util.{Collector, Preconditions}
+import org.apache.flink.util.Collector
+import org.slf4j.LoggerFactory
/**
* It wraps the aggregate logic inside of
@@ -32,109 +32,72 @@ import org.apache.flink.util.{Collector, Preconditions}
*
* It is used for sliding on batch for both time and count-windows.
*
- * @param aggregates aggregate functions.
- * @param groupKeysMapping index mapping of group keys between intermediate aggregate Row
- * and output Row.
- * @param aggregateMapping index mapping between aggregate function list and aggregated value
- * index in output Row.
- * @param finalRowArity output row field count
+ * @param genAggregations Code-generated [[GeneratedAggregations]]
+ * @param keysAndAggregatesArity The total arity of keys and aggregates
* @param finalRowWindowStartPos relative window-start position to last field of output row
* @param finalRowWindowEndPos relative window-end position to last field of output row
* @param windowSize size of the window, used to determine window-end for output row
*/
class DataSetSlideWindowAggReduceGroupFunction(
- aggregates: Array[AggregateFunction[_ <: Any]],
- groupKeysMapping: Array[(Int, Int)],
- aggregateMapping: Array[(Int, Int)],
- finalRowArity: Int,
+ genAggregations: GeneratedAggregationsFunction,
+ keysAndAggregatesArity: Int,
finalRowWindowStartPos: Option[Int],
finalRowWindowEndPos: Option[Int],
windowSize: Long)
- extends RichGroupReduceFunction[Row, Row] {
-
- Preconditions.checkNotNull(aggregates)
- Preconditions.checkNotNull(groupKeysMapping)
+ extends RichGroupReduceFunction[Row, Row]
+ with Compiler[GeneratedAggregations] {
private var collector: TimeWindowPropertyCollector = _
+ protected val windowStartPos: Int = keysAndAggregatesArity
+
private var output: Row = _
- private val accumulatorStartPos: Int = groupKeysMapping.length
- protected val windowStartPos: Int = accumulatorStartPos + aggregates.length
+ protected var accumulators: Row = _
- val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) {
- new JArrayList[Accumulator](2)
- }
+ val LOG = LoggerFactory.getLogger(this.getClass)
+ protected var function: GeneratedAggregations = _
override def open(config: Configuration) {
- output = new Row(finalRowArity)
+ LOG.debug(s"Compiling AggregateHelper: ${genAggregations.name} \n\n " +
+ s"Code:\n${genAggregations.code}")
+ val clazz = compile(
+ getRuntimeContext.getUserCodeClassLoader, // may load user-jar classes
+ genAggregations.name,
+ genAggregations.code)
+ LOG.debug("Instantiating AggregateHelper.")
+ function = clazz.newInstance()
+
+ output = function.createOutputRow()
+ accumulators = function.createAccumulators()
collector = new TimeWindowPropertyCollector(finalRowWindowStartPos, finalRowWindowEndPos)
-
- // init lists with two empty accumulators
- var i = 0
- while (i < aggregates.length) {
- val accumulator = aggregates(i).createAccumulator()
- accumulatorList(i).add(accumulator)
- accumulatorList(i).add(accumulator)
- i += 1
- }
}
override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = {
- // reset first accumulator
- var i = 0
- while (i < aggregates.length) {
- aggregates(i).resetAccumulator(accumulatorList(i).get(0))
- i += 1
- }
+ // reset accumulator
+ function.resetAccumulator(accumulators)
val iterator = records.iterator()
+ var record: Row = null // groups passed to reduce() are assumed non-empty
while (iterator.hasNext) {
- val record = iterator.next()
-
- // accumulate
- i = 0
- while (i < aggregates.length) {
- // insert received accumulator into acc list
- val newAcc = record.getField(accumulatorStartPos + i).asInstanceOf[Accumulator]
- accumulatorList(i).set(1, newAcc)
- // merge acc list
- val retAcc = aggregates(i).merge(accumulatorList(i))
- // insert result into acc list
- accumulatorList(i).set(0, retAcc)
- i += 1
- }
+ record = iterator.next()
+ function.mergeAccumulatorsPair(accumulators, record)
+ }
- // check if this record is the last record
- if (!iterator.hasNext) {
- // set group keys value to final output
- i = 0
- while (i < groupKeysMapping.length) {
- val mapping = groupKeysMapping(i)
- output.setField(mapping._1, record.getField(mapping._2))
- i += 1
- }
+ // set group keys value to final output
+ function.setForwardedFields(record, output)
- // get final aggregate value and set to output.
- i = 0
- while (i < aggregateMapping.length) {
- val mapping = aggregateMapping(i)
- val agg = aggregates(i)
- val result = agg.getValue(accumulatorList(mapping._2).get(0))
- output.setField(mapping._1, result)
- i += 1
- }
+ // get final aggregate value and set to output
+ function.setAggregationResults(accumulators, output)
- // adds TimeWindow properties to output then emit output
- if (finalRowWindowStartPos.isDefined || finalRowWindowEndPos.isDefined) {
- collector.wrappedCollector = out
- collector.windowStart = record.getField(windowStartPos).asInstanceOf[Long]
- collector.windowEnd = collector.windowStart + windowSize
+ // adds TimeWindow properties to output then emit output
+ if (finalRowWindowStartPos.isDefined || finalRowWindowEndPos.isDefined) {
+ collector.wrappedCollector = out
+ collector.windowStart = record.getField(windowStartPos).asInstanceOf[Long]
+ collector.windowEnd = collector.windowStart + windowSize
- collector.collect(output)
- } else {
- out.collect(output)
- }
- }
+ collector.collect(output)
+ } else {
+ out.collect(output)
}
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala
index 0a525f8..0e73f7b 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala
@@ -18,106 +18,69 @@
package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
-import java.util.{ArrayList => JArrayList}
import org.apache.flink.api.common.functions.RichGroupReduceFunction
import org.apache.flink.configuration.Configuration
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
import org.apache.flink.types.Row
-import org.apache.flink.util.{Collector, Preconditions}
+import org.apache.flink.util.Collector
+import org.slf4j.LoggerFactory
/**
* It wraps the aggregate logic inside of
* [[org.apache.flink.api.java.operators.GroupReduceOperator]].
* It is only used for tumbling count-window on batch.
*
+ * @param genAggregations Code-generated [[GeneratedAggregations]]
* @param windowSize Tumble count window size
- * @param aggregates The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
- * and output Row.
- * @param aggregateMapping The index mapping between aggregate function list and aggregated value
- * index in output Row.
- * @param finalRowArity The output row field count
*/
class DataSetTumbleCountWindowAggReduceGroupFunction(
- private val windowSize: Long,
- private val aggregates: Array[AggregateFunction[_ <: Any]],
- private val groupKeysMapping: Array[(Int, Int)],
- private val aggregateMapping: Array[(Int, Int)],
- private val finalRowArity: Int)
- extends RichGroupReduceFunction[Row, Row] {
-
- Preconditions.checkNotNull(aggregates)
- Preconditions.checkNotNull(groupKeysMapping)
+ private val genAggregations: GeneratedAggregationsFunction,
+ private val windowSize: Long)
+ extends RichGroupReduceFunction[Row, Row]
+ with Compiler[GeneratedAggregations] {
private var output: Row = _
- private val accumStartPos: Int = groupKeysMapping.length
+ private var accumulators: Row = _
- val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) {
- new JArrayList[Accumulator](2)
- }
+ val LOG = LoggerFactory.getLogger(this.getClass)
+ private var function: GeneratedAggregations = _
override def open(config: Configuration) {
- output = new Row(finalRowArity)
-
- // init lists with two empty accumulators
- for (i <- aggregates.indices) {
- val accumulator = aggregates(i).createAccumulator()
- accumulatorList(i).add(accumulator)
- accumulatorList(i).add(accumulator)
- }
+ LOG.debug(s"Compiling AggregateHelper: ${genAggregations.name} \n\n " +
+ s"Code:\n${genAggregations.code}")
+ val clazz = compile(
+ getRuntimeContext.getUserCodeClassLoader, // may load user-jar classes
+ genAggregations.name,
+ genAggregations.code)
+ LOG.debug("Instantiating AggregateHelper.")
+ function = clazz.newInstance()
+
+ output = function.createOutputRow()
+ accumulators = function.createAccumulators()
}
override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = {
var count: Long = 0
val iterator = records.iterator()
- var i = 0
while (iterator.hasNext) {
if (count == 0) {
- // reset first accumulator
- i = 0
- while (i < aggregates.length) {
- aggregates(i).resetAccumulator(accumulatorList(i).get(0))
- i += 1
- }
+ function.resetAccumulator(accumulators)
}
val record = iterator.next()
count += 1
- i = 0
- while (i < aggregates.length) {
- // insert received accumulator into acc list
- val newAcc = record.getField(accumStartPos + i).asInstanceOf[Accumulator]
- accumulatorList(i).set(1, newAcc)
- // merge acc list
- val retAcc = aggregates(i).merge(accumulatorList(i))
- // insert result into acc list
- accumulatorList(i).set(0, retAcc)
- i += 1
- }
+ accumulators = function.mergeAccumulatorsPair(accumulators, record)
if (windowSize == count) {
// set group keys value to final output.
- i = 0
- while (i < groupKeysMapping.length) {
- val (after, previous) = groupKeysMapping(i)
- output.setField(after, record.getField(previous))
- i += 1
- }
-
- // merge the accumulators and then get value for the final output
- i = 0
- while (i < aggregateMapping.length) {
- val (after, previous) = aggregateMapping(i)
- val agg = aggregates(previous)
- output.setField(after, agg.getValue(accumulatorList(previous).get(0)))
- i += 1
- }
+ function.setForwardedFields(record, output)
+ function.setAggregationResults(accumulators, output)
// emit the output
out.collect(output)
count = 0
http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala
index 904c76c..4a459b2 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala
@@ -20,7 +20,8 @@ package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
import org.apache.flink.api.common.functions.CombineFunction
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.table.codegen.GeneratedAggregationsFunction
import org.apache.flink.types.Row
/**
@@ -29,34 +30,45 @@ import org.apache.flink.types.Row
* [[org.apache.flink.api.java.operators.GroupCombineOperator]].
* It is used for tumbling time-window on batch.
*
- * @param windowSize Tumbling time window size
- * @param windowStartPos The relative window-start field position to the last field of output row
- * @param windowEndPos The relative window-end field position to the last field of output row
- * @param aggregates The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
- * and output Row.
- * @param aggregateMapping The index mapping between aggregate function list and aggregated value
- * index in output Row.
- * @param finalRowArity The output row field count
+ * @param genPreAggregations Code-generated [[GeneratedAggregations]] for partial aggs.
+ * @param genFinalAggregations Code-generated [[GeneratedAggregations]] for final aggs.
+ * @param windowSize Tumbling time window size
+ * @param windowStartPos The relative window-start field position to the last field of
+ * output row
+ * @param windowEndPos The relative window-end field position to the last field of
+ * output row
+ * @param keysAndAggregatesArity The total arity of keys and aggregates
*/
class DataSetTumbleTimeWindowAggReduceCombineFunction(
+ genPreAggregations: GeneratedAggregationsFunction,
+ genFinalAggregations: GeneratedAggregationsFunction,
windowSize: Long,
windowStartPos: Option[Int],
windowEndPos: Option[Int],
- aggregates: Array[AggregateFunction[_ <: Any]],
- groupKeysMapping: Array[(Int, Int)],
- aggregateMapping: Array[(Int, Int)],
- finalRowArity: Int)
+ keysAndAggregatesArity: Int)
extends DataSetTumbleTimeWindowAggReduceGroupFunction(
+ genFinalAggregations,
windowSize,
windowStartPos,
windowEndPos,
- aggregates,
- groupKeysMapping,
- aggregateMapping,
- finalRowArity)
+ keysAndAggregatesArity)
with CombineFunction[Row, Row] {
+ protected var preAggfunction: GeneratedAggregations = _
+
+ override def open(config: Configuration): Unit = {
+ super.open(config)
+ // compile the pre-aggregation code used to merge partial results
+ LOG.debug(s"Compiling AggregateHelper: ${genPreAggregations.name} \n\n " +
+ s"Code:\n${genPreAggregations.code}")
+ val clazz = compile(
+ getRuntimeContext.getUserCodeClassLoader, // may load user-jar classes
+ genPreAggregations.name,
+ genPreAggregations.code)
+ LOG.debug("Instantiating AggregateHelper.")
+ preAggfunction = clazz.newInstance()
+ }
+
/**
* For sub-grouped intermediate aggregate Rows, merge all of them into aggregate buffer,
*
@@ -69,47 +81,21 @@ class DataSetTumbleTimeWindowAggReduceCombineFunction(
var last: Row = null
val iterator = records.iterator()
- // reset first accumulator in merge list
- var i = 0
- while (i < aggregates.length) {
- aggregates(i).resetAccumulator(accumulatorList(i).get(0))
- i += 1
- }
+ // reset accumulator
+ preAggfunction.resetAccumulator(accumulators)
while (iterator.hasNext) {
val record = iterator.next()
-
- i = 0
- while (i < aggregates.length) {
- // insert received accumulator into acc list
- val newAcc = record.getField(groupKeysMapping.length + i).asInstanceOf[Accumulator]
- accumulatorList(i).set(1, newAcc)
- // merge acc list
- val retAcc = aggregates(i).merge(accumulatorList(i))
- // insert result into acc list
- accumulatorList(i).set(0, retAcc)
- i += 1
- }
-
+ preAggfunction.mergeAccumulatorsPair(accumulators, record)
last = record
}
- // set the partial merged result to the aggregateBuffer
- i = 0
- while (i < aggregates.length) {
- aggregateBuffer.setField(groupKeysMapping.length + i, accumulatorList(i).get(0))
- i += 1
- }
-
- // set group keys to aggregateBuffer.
- i = 0
- while (i < groupKeysMapping.length) {
- aggregateBuffer.setField(i, last.getField(i))
- i += 1
- }
+ // set group keys and partial merged result to aggregateBuffer
+ preAggfunction.setAggregationResults(accumulators, aggregateBuffer)
+ preAggfunction.setForwardedFields(last, aggregateBuffer)
// set the rowtime attribute
- val rowtimePos = groupKeysMapping.length + aggregates.length
+ val rowtimePos = keysAndAggregatesArity
aggregateBuffer.setField(rowtimePos, last.getField(rowtimePos))
http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala
index 99e2a0a..f4a1fc5 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala
@@ -18,65 +18,56 @@
package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
-import java.util.{ArrayList => JArrayList}
import org.apache.flink.api.common.functions.RichGroupReduceFunction
import org.apache.flink.configuration.Configuration
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
import org.apache.flink.types.Row
-import org.apache.flink.util.{Collector, Preconditions}
+import org.apache.flink.util.Collector
+import org.slf4j.LoggerFactory
/**
* It wraps the aggregate logic inside of
* [[org.apache.flink.api.java.operators.GroupReduceOperator]]. It is used for tumbling time-window
* on batch.
*
+ * @param genAggregations Code-generated [[GeneratedAggregations]]
* @param windowSize Tumbling time window size
* @param windowStartPos The relative window-start field position to the last field of output row
* @param windowEndPos The relative window-end field position to the last field of output row
- * @param aggregates The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
- * and output Row.
- * @param aggregateMapping The index mapping between aggregate function list and aggregated value
- * index in output Row.
- * @param finalRowArity The output row field count
+ * @param keysAndAggregatesArity The total arity of keys and aggregates
*/
class DataSetTumbleTimeWindowAggReduceGroupFunction(
+ genAggregations: GeneratedAggregationsFunction,
windowSize: Long,
windowStartPos: Option[Int],
windowEndPos: Option[Int],
- aggregates: Array[AggregateFunction[_ <: Any]],
- groupKeysMapping: Array[(Int, Int)],
- aggregateMapping: Array[(Int, Int)],
- finalRowArity: Int)
- extends RichGroupReduceFunction[Row, Row] {
-
- Preconditions.checkNotNull(aggregates)
- Preconditions.checkNotNull(groupKeysMapping)
+ keysAndAggregatesArity: Int)
+ extends RichGroupReduceFunction[Row, Row]
+ with Compiler[GeneratedAggregations] {
private var collector: TimeWindowPropertyCollector = _
- protected var aggregateBuffer: Row = _
- private var output: Row = _
- private val accumStartPos: Int = groupKeysMapping.length
- private val rowtimePos: Int = accumStartPos + aggregates.length
- private val intermediateRowArity: Int = rowtimePos + 1
+ protected var aggregateBuffer: Row = new Row(keysAndAggregatesArity + 1)
+ private var output: Row = _
+ protected var accumulators: Row = _
- val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) {
- new JArrayList[Accumulator](2)
- }
+ val LOG = LoggerFactory.getLogger(this.getClass)
+ protected var function: GeneratedAggregations = _
override def open(config: Configuration) {
- aggregateBuffer = new Row(intermediateRowArity)
- output = new Row(finalRowArity)
+ LOG.debug(s"Compiling AggregateHelper: ${genAggregations.name} \n\n " +
+ s"Code:\n${genAggregations.code}")
+ val clazz = compile(
+ getRuntimeContext.getUserCodeClassLoader, // may load user-jar classes
+ genAggregations.name,
+ genAggregations.code)
+ LOG.debug("Instantiating AggregateHelper.")
+ function = clazz.newInstance()
+
+ output = function.createOutputRow()
+ accumulators = function.createAccumulators()
collector = new TimeWindowPropertyCollector(windowStartPos, windowEndPos)
-
- // init lists with two empty accumulators
- for (i <- aggregates.indices) {
- val accumulator = aggregates(i).createAccumulator()
- accumulatorList(i).add(accumulator)
- accumulatorList(i).add(accumulator)
- }
}
override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = {
@@ -84,51 +75,23 @@ class DataSetTumbleTimeWindowAggReduceGroupFunction(
var last: Row = null
val iterator = records.iterator()
- // reset first accumulator in merge list
- var i = 0
- while (i < aggregates.length) {
- aggregates(i).resetAccumulator(accumulatorList(i).get(0))
- i += 1
- }
+ // reset accumulator
+ function.resetAccumulator(accumulators)
while (iterator.hasNext) {
val record = iterator.next()
-
- i = 0
- while (i < aggregates.length) {
- // insert received accumulator into acc list
- val newAcc = record.getField(groupKeysMapping.length + i).asInstanceOf[Accumulator]
- accumulatorList(i).set(1, newAcc)
- // merge acc list
- val retAcc = aggregates(i).merge(accumulatorList(i))
- // insert result into acc list
- accumulatorList(i).set(0, retAcc)
- i += 1
- }
-
+ function.mergeAccumulatorsPair(accumulators, record)
last = record
}
// set group keys value to final output.
- i = 0
- while (i < groupKeysMapping.length) {
- val (after, previous) = groupKeysMapping(i)
- output.setField(after, last.getField(previous))
- i += 1
- }
+ function.setForwardedFields(last, output)
// get final aggregate value and set to output.
- i = 0
- while (i < aggregateMapping.length) {
- val (after, previous) = aggregateMapping(i)
- val agg = aggregates(previous)
- val result = agg.getValue(accumulatorList(previous).get(0))
- output.setField(after, result)
- i += 1
- }
+ function.setAggregationResults(accumulators, output)
// get window start timestamp
- val startTs: Long = last.getField(rowtimePos).asInstanceOf[Long]
+ val startTs: Long = last.getField(keysAndAggregatesArity).asInstanceOf[Long]
// set collector and window
collector.wrappedCollector = out
http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggMapFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggMapFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggMapFunction.scala
index 5cc7ada..d49ed0e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggMapFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggMapFunction.scala
@@ -25,58 +25,60 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.ResultTypeQueryable
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.windowing.windows.TimeWindow
-import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
import org.apache.flink.types.Row
-import org.apache.flink.util.Preconditions
-
+import org.slf4j.LoggerFactory
/**
* This map function only works for windows on batch tables.
* It appends an (aligned) rowtime field to the end of the output row.
+ *
+ * @param genAggregations Code-generated [[GeneratedAggregations]]
+ * @param timeFieldPos Time field position in input row
+ * @param tumbleTimeWindowSize The size of tumble time window
*/
class DataSetWindowAggMapFunction(
- private val aggregates: Array[AggregateFunction[_]],
- private val aggFields: Array[Array[Int]],
- private val groupingKeys: Array[Int],
- private val timeFieldPos: Int, // time field position in input row
+ private val genAggregations: GeneratedAggregationsFunction,
+ private val timeFieldPos: Int,
private val tumbleTimeWindowSize: Option[Long],
@transient private val returnType: TypeInformation[Row])
- extends RichMapFunction[Row, Row] with ResultTypeQueryable[Row] {
-
- Preconditions.checkNotNull(aggregates)
- Preconditions.checkNotNull(aggFields)
- Preconditions.checkArgument(aggregates.length == aggFields.length)
+ extends RichMapFunction[Row, Row]
+ with ResultTypeQueryable[Row]
+ with Compiler[GeneratedAggregations] {
+ private var accs: Row = _ // accumulators: created once in open(), reset for every input record in map()
private var output: Row = _
- // add one more arity to store rowtime
- private val partialRowLength = groupingKeys.length + aggregates.length + 1
- // rowtime index in the buffer output row
- private val rowtimeIndex: Int = partialRowLength - 1
+
+ @transient lazy val LOG = LoggerFactory.getLogger(this.getClass) // transient + lazy: not shipped with the serialized function, re-created on first use
+ private var function: GeneratedAggregations = _ // compiled from genAggregations in open()
override def open(config: Configuration) {
- output = new Row(partialRowLength)
+ LOG.debug(s"Compiling AggregateHelper: ${genAggregations.name} \n\n " +
+ s"Code:\n${genAggregations.code}")
+ val clazz = compile(
+ getRuntimeContext.getUserCodeClassLoader,
+ genAggregations.name,
+ genAggregations.code)
+ LOG.debug("Instantiating AggregateHelper.")
+ function = clazz.newInstance() // instance of the code-generated GeneratedAggregations subclass
+
+ accs = function.createAccumulators()
+ output = function.createOutputRow()
}
override def map(input: Row): Row = {
- var i = 0
- while (i < aggregates.length) {
- val agg = aggregates(i)
- val fieldValue = input.getField(aggFields(i)(0))
- val accumulator = agg.createAccumulator()
- agg.accumulate(accumulator, fieldValue)
- output.setField(groupingKeys.length + i, accumulator)
- i += 1
- }
+ function.resetAccumulator(accs)
- i = 0
- while (i < groupingKeys.length) {
- output.setField(i, input.getField(groupingKeys(i)))
- i += 1
- }
+ function.accumulate(accs, input)
+
+ function.setAggregationResults(accs, output)
+
+ function.setForwardedFields(input, output)
val timeField = input.getField(timeFieldPos)
val rowtime = getTimestamp(timeField)
+ val rowtimeIndex = output.getArity - 1
if (tumbleTimeWindowSize.isDefined) {
// in case of tumble time window, align rowtime to window start to represent the window
output.setField(
http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
index 17a1128..bee39fa 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
@@ -27,7 +27,9 @@ import org.apache.flink.types.Row
abstract class GeneratedAggregations extends Function {
/**
- * Calculate the results from accumulators, and set the results to the output
+ * Sets the results of the aggregations (partial or final) to the output row.
+ * Final results are computed with the aggregation function.
+ * Partial results are the accumulators themselves.
*
* @param accumulators the accumulators (saved in a row) which contains the current
* aggregated results
@@ -36,15 +38,22 @@ abstract class GeneratedAggregations extends Function {
def setAggregationResults(accumulators: Row, output: Row)
/**
- * Copies forwarded fields from input row to output row.
+ * Copies forwarded fields, such as grouping keys, from input row to output row.
*
- * @param input input values bundled in a row
- * @param output output results collected in a row
+ * @param input input values bundled in a row
+ * @param output output results collected in a row
*/
def setForwardedFields(input: Row, output: Row)
/**
- * Accumulate the input values to the accumulators
+ * Sets constant flags (boolean fields) to an output row.
+ *
+ * @param output The output row to which the constant flags are set.
+ */
+ def setConstantFlags(output: Row)
+
+ /**
+ * Accumulates the input values to the accumulators.
*
* @param accumulators the accumulators (saved in a row) which contains the current
* aggregated results
@@ -53,7 +62,7 @@ abstract class GeneratedAggregations extends Function {
def accumulate(accumulators: Row, input: Row)
/**
- * Retract the input values from the accumulators
+ * Retracts the input values from the accumulators.
*
* @param accumulators the accumulators (saved in a row) which contains the current
* aggregated results
@@ -62,7 +71,7 @@ abstract class GeneratedAggregations extends Function {
def retract(accumulators: Row, input: Row)
/**
- * Init the accumulators, and save them to a accumulators Row.
+ * Initializes the accumulators and saves them to an accumulators row.
*
* @return a row of accumulators which contains the aggregated results
*/
@@ -76,7 +85,7 @@ abstract class GeneratedAggregations extends Function {
def createOutputRow(): Row
/**
- * Merges two rows of accumulators into one row
+ * Merges two rows of accumulators into one row.
*
* @param a First row of accumulators
* @param b The other row of accumulators
@@ -84,4 +93,11 @@ abstract class GeneratedAggregations extends Function {
*/
def mergeAccumulatorsPair(a: Row, b: Row): Row
+ /**
+ * Resets all the accumulators.
+ *
+ * @param accumulators the accumulators (saved in a row) which contains the current
+ * aggregated results
+ */
+ def resetAccumulator(accumulators: Row)
}
http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala
index 5ac09b9..4838747 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala
@@ -339,5 +339,6 @@ class AggregationsITCase(
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
- case class WC(word: String, frequency: Long)
}
+
+case class WC(word: String, frequency: Long)
http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunctionTest.scala
index adc84bf..16c493e 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunctionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunctionTest.scala
@@ -159,12 +159,18 @@ class BoundedProcessingOverRangeProcessFunctionTest {
| return new org.apache.flink.types.Row(7);
| }
|
- | //The test won't use this method
+ |/******* This test does not use the following methods *******/
| public org.apache.flink.types.Row mergeAccumulatorsPair(
| org.apache.flink.types.Row a,
| org.apache.flink.types.Row b) {
| return null;
| }
+ |
+ | public void resetAccumulator(org.apache.flink.types.Row accs) {
+ | }
+ |
+ | public void setConstantFlags(org.apache.flink.types.Row output) {
+ | }
|}
""".stripMargin