Posted to commits@flink.apache.org by dw...@apache.org on 2018/12/17 08:25:18 UTC
[flink] 03/06: [FLINK-7599][table, cep] Support for aggregates in MATCH_RECOGNIZE
This is an automated email from the ASF dual-hosted git repository.
dwysakowicz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 1458caf77943704c345bed1a6ef517f35f858fb8
Author: Dawid Wysakowicz <dw...@apache.org>
AuthorDate: Mon Nov 26 14:27:22 2018 +0100
[FLINK-7599][table, cep] Support for aggregates in MATCH_RECOGNIZE
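
In SQL terms, aggregates can now appear in both MEASURES and DEFINE. A condensed
query in the shape of the new MatchRecognizeITCase tests below (table, column and
function names as registered there):

    val sqlQuery =
      s"""
         |SELECT *
         |FROM MyTable
         |MATCH_RECOGNIZE (
         |  ORDER BY proctime
         |  MEASURES
         |    SUM(A.price) AS sumA,
         |    AVG(B.price) AS avgB
         |  AFTER MATCH SKIP PAST LAST ROW
         |  PATTERN (A+ B+ C)
         |  DEFINE
         |    A AS SUM(A.price) < 6,
         |    B AS AVG(B.price) >= 1
         |) AS T
         |""".stripMargin

Each aggregate call must reference exactly one pattern variable; expressions that
mix variables, such as SUM(A.price + B.tax), are rejected during validation (see
DataStreamMatchRule and MatchUtil below).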
---
.../apache/flink/table/codegen/CodeGenerator.scala | 37 +-
.../flink/table/codegen/MatchCodeGenerator.scala | 425 +++++++++++++++++++--
.../rules/datastream/DataStreamMatchRule.scala | 40 +-
.../org/apache/flink/table/util/MatchUtil.scala | 53 +++
...st.scala => MatchRecognizeValidationTest.scala} | 78 ++--
.../runtime/stream/sql/MatchRecognizeITCase.scala | 168 +++++++-
6 files changed, 709 insertions(+), 92 deletions(-)
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
index fdb0d50..6c05ff0 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
@@ -1048,6 +1048,25 @@ abstract class CodeGenerator(
// generator helping methods
// ----------------------------------------------------------------------------------------------
+ protected def makeReusableInSplits(exprs: Iterable[GeneratedExpression]): Unit = {
+ // add results of expressions to member area such that all split functions can access it
+ exprs.foreach { expr =>
+
+ // declaration
+ val resultTypeTerm = primitiveTypeTermForTypeInfo(expr.resultType)
+ if (nullCheck && !expr.nullTerm.equals(NEVER_NULL) && !expr.nullTerm.equals(ALWAYS_NULL)) {
+ reusableMemberStatements.add(s"private boolean ${expr.nullTerm};")
+ }
+ reusableMemberStatements.add(s"private $resultTypeTerm ${expr.resultTerm};")
+
+ // assignment
+ if (nullCheck && !expr.nullTerm.equals(NEVER_NULL) && !expr.nullTerm.equals(ALWAYS_NULL)) {
+ reusablePerRecordStatements.add(s"this.${expr.nullTerm} = ${expr.nullTerm};")
+ }
+ reusablePerRecordStatements.add(s"this.${expr.resultTerm} = ${expr.resultTerm};")
+ }
+ }
+
private def generateCodeSplits(splits: Seq[String]): String = {
val totalLen = splits.map(_.length + 1).sum // 1 for a line break
@@ -1057,21 +1076,7 @@ abstract class CodeGenerator(
hasCodeSplits = true
// add input unboxing to member area such that all split functions can access it
- reusableInputUnboxingExprs.foreach { case (_, expr) =>
-
- // declaration
- val resultTypeTerm = primitiveTypeTermForTypeInfo(expr.resultType)
- if (nullCheck && !expr.nullTerm.equals(NEVER_NULL) && !expr.nullTerm.equals(ALWAYS_NULL)) {
- reusableMemberStatements.add(s"private boolean ${expr.nullTerm};")
- }
- reusableMemberStatements.add(s"private $resultTypeTerm ${expr.resultTerm};")
-
- // assignment
- if (nullCheck && !expr.nullTerm.equals(NEVER_NULL) && !expr.nullTerm.equals(ALWAYS_NULL)) {
- reusablePerRecordStatements.add(s"this.${expr.nullTerm} = ${expr.nullTerm};")
- }
- reusablePerRecordStatements.add(s"this.${expr.resultTerm} = ${expr.resultTerm};")
- }
+ makeReusableInSplits(reusableInputUnboxingExprs.values)
// add split methods to the member area and return the code necessary to call those methods
val methodCalls = splits.map { split =>
@@ -1196,7 +1201,7 @@ abstract class CodeGenerator(
GeneratedExpression(resultTerm, nullTerm, inputCheckCode, fieldType)
}
- private def generateFieldAccess(
+ protected def generateFieldAccess(
inputType: TypeInformation[_],
inputTerm: String,
index: Int)
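
The extracted makeReusableInSplits promotes an expression's result (and null flag)
to member fields and re-assigns them per record, so every split method sees the
value. A sketch of the statements it registers for a double expression with
resultTerm result$1 and nullTerm isNull$1, assuming nullCheck is enabled
(hypothetical terms):

    // member declarations, visible to all split methods
    reusableMemberStatements.add("private boolean isNull$1;")
    reusableMemberStatements.add("private double result$1;")
    // per-record assignments from the local terms into the fields
    reusablePerRecordStatements.add("this.isNull$1 = isNull$1;")
    reusablePerRecordStatements.add("this.result$1 = result$1;")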
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala
index 8f8a7f1..ffa7fc2 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala
@@ -22,19 +22,26 @@ import java.lang.{Long => JLong}
import java.util
import org.apache.calcite.rel.RelFieldCollation
+import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rex._
+import org.apache.calcite.sql.SqlAggFunction
import org.apache.calcite.sql.fun.SqlStdOperatorTable._
import org.apache.flink.api.common.functions._
import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation}
import org.apache.flink.cep.pattern.conditions.{IterativeCondition, RichIterativeCondition}
import org.apache.flink.cep.{RichPatternFlatSelectFunction, RichPatternSelectFunction}
import org.apache.flink.configuration.Configuration
+import org.apache.flink.table.api.dataview.DataViewSpec
import org.apache.flink.table.api.{TableConfig, TableException}
-import org.apache.flink.table.codegen.CodeGenUtils.{boxedTypeTermForTypeInfo, newName}
+import org.apache.flink.table.calcite.FlinkTypeFactory
+import org.apache.flink.table.codegen.CodeGenUtils.{boxedTypeTermForTypeInfo, newName, primitiveDefaultValue, primitiveTypeTermForTypeInfo}
import org.apache.flink.table.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
import org.apache.flink.table.codegen.Indenter.toISC
+import org.apache.flink.table.functions.{AggregateFunction => TableAggregateFunction}
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.table.runtime.`match`.{IterativeConditionRunner, PatternSelectFunctionRunner}
+import org.apache.flink.table.runtime.aggregate.AggregateUtil
+import org.apache.flink.table.util.MatchUtil.{ALL_PATTERN_VARIABLE, AggregationPatternVariableFinder}
import org.apache.flink.table.utils.EncodingUtils
import org.apache.flink.types.Row
import org.apache.flink.util.Collector
@@ -53,7 +60,7 @@ object MatchCodeGenerator {
names: Seq[String])
: IterativeConditionRunner = {
val generator = new MatchCodeGenerator(config, inputTypeInfo, names, Some(patternName))
- val condition = generator.generateExpression(patternDefinition)
+ val condition = generator.generateCondition(patternDefinition)
val body =
s"""
|${condition.code}
@@ -101,6 +108,90 @@ object MatchCodeGenerator {
/**
* A code generator for generating CEP related functions.
*
+ * Aggregates are generated as follows:
+ * 1. all aggregate [[RexCall]]s are grouped by the corresponding pattern variable
+ * 2. even if the same aggregation is used multiple times in an expression
+ * (e.g. SUM(A.price) > SUM(A.price) + 1) it is calculated only once. To do so, [[AggBuilder]]
+ * keeps a set of all distinct aggregation calls seen so far, and reuses the code that accesses
+ * the appropriate field of the aggregation result
+ * 3. after translating every expression (either in [[generateCondition]] or in
+ * [[generateOneRowPerMatchExpression]]) code is generated for
+ * - a [[GeneratedFunction]], which will be an inner class
+ * - said [[GeneratedFunction]] is instantiated in the ctor and opened/closed
+ * in the corresponding methods of the top-level generated class
+ * - a function that transforms input rows (row by row) into aggregate input rows
+ * - a function that calculates all aggregates for a variable, using the previous method
+ * The generated code will look similar to this:
+ *
+ * {{{
+ *
+ * public class MatchRecognizePatternSelectFunction$175 extends RichPatternSelectFunction {
+ *
+ * // Class used to calculate aggregates for a single pattern variable
+ * public final class AggFunction_variable$115$151 extends GeneratedAggregations {
+ * ...
+ * }
+ *
+ * private final AggFunction_variable$115$151 aggregator_variable$115;
+ *
+ * public MatchRecognizePatternSelectFunction$175() {
+ * aggregator_variable$115 = new AggFunction_variable$115$151();
+ * }
+ *
+ * public void open() {
+ * aggregator_variable$115.open();
+ * ...
+ * }
+ *
+ * // Function to transform an incoming row into an aggregate-specific row. It can, e.g.,
+ * // calculate the inner expression of said aggregate
+ * private Row transformRowForAgg_variable$115(Row inAgg) {
+ * ...
+ * }
+ *
+ * // Function to calculate all aggregates for a single pattern variable
+ * private Row calculateAgg_variable$115(List<Row> input) {
+ * Acc accumulator = aggregator_variable$115.createAccumulator();
+ * for (Row row : input) {
+ * aggregator_variable$115.accumulate(accumulator, transformRowForAgg_variable$115(row));
+ * }
+ *
+ * return aggregator_variable$115.getResult(accumulator);
+ * }
+ *
+ * @Override
+ * public Object select(Map<String, List<Row>> in1) throws Exception {
+ *
+ * // Extract list of rows assigned to a single pattern variable
+ * java.util.List patternEvents$130 = (java.util.List) in1.get("A");
+ * ...
+ *
+ * // Calculate aggregates
+ * Row aggRow_variable$115 = calculateAgg_variable$115(patternEvents$130);
+ *
+ * // Every aggregation (e.g. SUM(A.price) and AVG(A.price)) will be extracted to a variable
+ * double result$135 = aggRow_variable$115.getField(0);
+ * long result$137 = aggRow_variable$115.getField(1);
+ *
+ * // Result of aggregation will be used in expression evaluation
+ * out.setField(0, result$135);
+ *
+ * long result$140 = result$137 * 2;
+ * out.setField(1, result$140);
+ *
+ * double result$144 = result$135 + result$137;
+ * out.setField(2, result$144);
+ * }
+ *
+ * public void close() {
+ * aggregator_variable$115.close();
+ * ...
+ * }
+ *
+ * }
+ * }}}
+ *
* @param config configuration that determines runtime behavior
* @param patternNames sorted sequence of pattern variables
* @param input type information about the first input of the Function
@@ -124,17 +215,39 @@ class MatchCodeGenerator(
.HashMap[String, GeneratedPatternList]()
/**
+ * Used to deduplicate aggregation calculations. The deduplication is keyed on
+ * [[RexNode.toString]]. Those expressions need to be accessible from splits, if any exist.
+ */
+ private val reusableAggregationExpr = new mutable.HashMap[String, GeneratedExpression]()
+
+ /**
* Context information used by a pattern reference variable to index the rows mapped to it.
* Indexes an element at the given offset, from the beginning or the end, depending on first.
*/
private var offset: Int = 0
private var first : Boolean = false
+ /**
+ * Flag that tells if we generate expressions inside an aggregate. It determines how to access
+ * the input row.
+ */
+ private var isWithinAggExprState: Boolean = false
+
+ /**
+ * Name of the term in the function that transforms input rows into aggregate input rows.
+ */
+ private val inputAggRowTerm = "inAgg"
+
/** Term for row for key extraction */
- private val keyRowTerm = newName("keyRow")
+ private val keyRowTerm = "keyRow"
/** Term for list of all pattern names */
- private val patternNamesTerm = newName("patternNames")
+ private val patternNamesTerm = "patternNames"
+
+ /**
+ * Used to collect all aggregates per pattern variable.
+ */
+ private val aggregatesPerVariable = new mutable.HashMap[String, AggBuilder]
/**
* Sets the new reference variable indexing context. This should be used when resolving logical
@@ -252,7 +365,6 @@ class MatchCodeGenerator(
private def generateKeyRow() : GeneratedExpression = {
val exp = reusableInputUnboxingExprs
.get((keyRowTerm, 0)) match {
- // input access and unboxing has already been generated
case Some(expr) =>
expr
@@ -310,10 +422,26 @@ class MatchCodeGenerator(
generateExpression(measures.get(fieldName))
}
- generateResultExpression(
+ val exp = generateResultExpression(
resultExprs,
returnType.typeInfo,
returnType.fieldNames)
+ aggregatesPerVariable.values.foreach(_.generateAggFunction())
+ if (hasCodeSplits) {
+ makeReusableInSplits(reusableAggregationExpr.values)
+ }
+
+ exp
+ }
+
+ def generateCondition(call: RexNode): GeneratedExpression = {
+ val exp = call.accept(this)
+ aggregatesPerVariable.values.foreach(_.generateAggFunction())
+ if (hasCodeSplits) {
+ makeReusableInSplits(reusableAggregationExpr.values)
+ }
+
+ exp
}
override def visitCall(call: RexCall): GeneratedExpression = {
@@ -341,6 +469,21 @@ class MatchCodeGenerator(
case FINAL =>
call.getOperands.get(0).accept(this)
+ case _ : SqlAggFunction =>
+
+ val variable = call.accept(new AggregationPatternVariableFinder)
+ .getOrElse(throw new TableException("No pattern variable specified in aggregate"))
+
+ val matchAgg = aggregatesPerVariable.get(variable) match {
+ case Some(agg) => agg
+ case None =>
+ val agg = new AggBuilder(variable)
+ aggregatesPerVariable(variable) = agg
+ agg
+ }
+
+ matchAgg.generateDeduplicatedAggAccess(call)
+
case _ => super.visitCall(call)
}
}
@@ -357,10 +500,15 @@ class MatchCodeGenerator(
}
override def visitPatternFieldRef(fieldRef: RexPatternFieldRef): GeneratedExpression = {
- if (fieldRef.getAlpha.equals("*") && currentPattern.isDefined && offset == 0 && !first) {
- generateInputAccess(input, input1Term, fieldRef.getIndex)
+ if (isWithinAggExprState) {
+ generateFieldAccess(input, inputAggRowTerm, fieldRef.getIndex)
} else {
- generatePatternFieldRef(fieldRef)
+ if (fieldRef.getAlpha.equals(ALL_PATTERN_VARIABLE) &&
+ currentPattern.isDefined && offset == 0 && !first) {
+ generateInputAccess(input, input1Term, fieldRef.getIndex)
+ } else {
+ generatePatternFieldRef(fieldRef)
+ }
}
}
@@ -372,14 +520,14 @@ class MatchCodeGenerator(
val eventTypeTerm = boxedTypeTermForTypeInfo(input)
val eventNameTerm = newName("event")
- val addCurrent = if (currentPattern == patternName || patternName == "*") {
+ val addCurrent = if (currentPattern == patternName || patternName == ALL_PATTERN_VARIABLE) {
j"""
|$listName.add($input1Term);
""".stripMargin
} else {
""
}
- val listCode = if (patternName == "*") {
+ val listCode = if (patternName == ALL_PATTERN_VARIABLE) {
addReusablePatternNames()
val patternTerm = newName("pattern")
j"""
@@ -414,7 +562,7 @@ class MatchCodeGenerator(
private def generateMeasurePatternVariableExp(patternName: String): GeneratedPatternList = {
val listName = newName("patternEvents")
- val code = if (patternName == "*") {
+ val code = if (patternName == ALL_PATTERN_VARIABLE) {
addReusablePatternNames()
val patternTerm = newName("pattern")
@@ -448,21 +596,7 @@ class MatchCodeGenerator(
val eventTypeTerm = boxedTypeTermForTypeInfo(input)
val isRowNull = newName("isRowNull")
- val findEventsByPatternName = reusablePatternLists.get(patternFieldAlpha) match {
- // input access and unboxing has already been generated
- case Some(expr) =>
- expr
-
- case None =>
- val exp = currentPattern match {
- case Some(p) => generateDefinePatternVariableExp(patternFieldAlpha, p)
- case None => generateMeasurePatternVariableExp(patternFieldAlpha)
- }
- reusablePatternLists(patternFieldAlpha) = exp
- exp
- }
-
- val listName = findEventsByPatternName.resultTerm
+ val listName = findEventsByPatternName(patternFieldAlpha).resultTerm
val resultIndex = if (first) {
j"""$offset"""
} else {
@@ -482,11 +616,27 @@ class MatchCodeGenerator(
GeneratedExpression(rowNameTerm, isRowNull, funcCode, input)
}
+ private def findEventsByPatternName(
+ patternFieldAlpha: String)
+ : GeneratedPatternList = {
+ reusablePatternLists.get(patternFieldAlpha) match {
+ case Some(expr) =>
+ expr
+
+ case None =>
+ val exp = currentPattern match {
+ case Some(p) => generateDefinePatternVariableExp(patternFieldAlpha, p)
+ case None => generateMeasurePatternVariableExp(patternFieldAlpha)
+ }
+ reusablePatternLists(patternFieldAlpha) = exp
+ exp
+ }
+ }
+
private def generatePatternFieldRef(fieldRef: RexPatternFieldRef): GeneratedExpression = {
val escapedAlpha = EncodingUtils.escapeJava(fieldRef.getAlpha)
val patternVariableRef = reusableInputUnboxingExprs
.get((s"$escapedAlpha#$first", offset)) match {
- // input access and unboxing has already been generated
case Some(expr) =>
expr
@@ -498,4 +648,223 @@ class MatchCodeGenerator(
generateFieldAccess(patternVariableRef.copy(code = NO_CODE), fieldRef.getIndex)
}
+
+ class AggBuilder(variable: String) {
+
+ private val aggregates = new mutable.ListBuffer[RexCall]()
+
+ private val variableUID = newName("variable")
+
+ private val rowTypeTerm = "org.apache.flink.types.Row"
+
+ private val calculateAggFuncName = s"calculateAgg_$variableUID"
+
+ def generateDeduplicatedAggAccess(aggCall: RexCall): GeneratedExpression = {
+ reusableAggregationExpr.get(aggCall.toString) match {
+ case Some(expr) =>
+ expr
+
+ case None =>
+ val exp: GeneratedExpression = generateAggAccess(aggCall)
+ aggregates += aggCall
+ reusableAggregationExpr(aggCall.toString) = exp
+ reusablePerRecordStatements += exp.code
+ exp.copy(code = NO_CODE)
+ }
+ }
+
+ private def generateAggAccess(aggCall: RexCall): GeneratedExpression = {
+ val singleAggResultTerm = newName("result")
+ val singleAggNullTerm = newName("nullTerm")
+ val singleAggResultType = FlinkTypeFactory.toTypeInfo(aggCall.`type`)
+ val primitiveSingleAggResultTypeTerm = primitiveTypeTermForTypeInfo(singleAggResultType)
+ val boxedSingleAggResultTypeTerm = boxedTypeTermForTypeInfo(singleAggResultType)
+
+ val allAggRowTerm = s"aggRow_$variableUID"
+
+ val rowsForVariableCode = findEventsByPatternName(variable)
+ val codeForAgg =
+ j"""
+ |$rowTypeTerm $allAggRowTerm = $calculateAggFuncName(${rowsForVariableCode.resultTerm});
+ |""".stripMargin
+
+ reusablePerRecordStatements += codeForAgg
+
+ val defaultValue = primitiveDefaultValue(singleAggResultType)
+ val codeForSingleAgg = if (nullCheck) {
+ j"""
+ |boolean $singleAggNullTerm;
+ |$primitiveSingleAggResultTypeTerm $singleAggResultTerm;
+ |if ($allAggRowTerm.getField(${aggregates.size}) != null) {
+ | $singleAggResultTerm = ($boxedSingleAggResultTypeTerm) $allAggRowTerm
+ | .getField(${aggregates.size});
+ | $singleAggNullTerm = false;
+ |} else {
+ | $singleAggNullTerm = true;
+ | $singleAggResultTerm = $defaultValue;
+ |}
+ |""".stripMargin
+ } else {
+ j"""
+ |$primitiveSingleAggResultTypeTerm $singleAggResultTerm =
+ | ($boxedSingleAggResultTypeTerm) $allAggRowTerm.getField(${aggregates.size});
+ |""".stripMargin
+ }
+
+ reusablePerRecordStatements += codeForSingleAgg
+
+ GeneratedExpression(singleAggResultTerm, singleAggNullTerm, NO_CODE, singleAggResultType)
+ }
+
+ def generateAggFunction(): Unit = {
+ val matchAgg = extractAggregatesAndExpressions
+
+ val aggGenerator = new AggregationCodeGenerator(config, false, input, None)
+
+ val aggFunc = aggGenerator.generateAggregations(
+ s"AggFunction_$variableUID",
+ matchAgg.inputExprs.map(r => FlinkTypeFactory.toTypeInfo(r.getType)),
+ matchAgg.aggregations.map(_.aggFunction).toArray,
+ matchAgg.aggregations.map(_.inputIndices).toArray,
+ matchAgg.aggregations.indices.toArray,
+ Array.fill(matchAgg.aggregations.size)(false),
+ isStateBackedDataViews = false,
+ partialResults = false,
+ Array.emptyIntArray,
+ None,
+ matchAgg.aggregations.size,
+ needRetract = false,
+ needMerge = false,
+ needReset = false,
+ None
+ )
+
+ reusableMemberStatements.add(aggFunc.code)
+
+ val transformFuncName = s"transformRowForAgg_$variableUID"
+ val inputTransform: String = generateAggInputExprEvaluation(
+ matchAgg.inputExprs,
+ transformFuncName)
+
+ generateAggCalculation(aggFunc, transformFuncName, inputTransform)
+ }
+
+ private def extractAggregatesAndExpressions: MatchAgg = {
+ val inputRows = new mutable.LinkedHashMap[String, (RexNode, Int)]
+
+ val logicalAggregates = aggregates.map(aggCall => {
+ val callsWithIndices = aggCall.operands.asScala.map(innerCall => {
+ inputRows.get(innerCall.toString) match {
+ case Some(x) =>
+ x
+
+ case None =>
+ val callWithIndex = (innerCall, inputRows.size)
+ inputRows(innerCall.toString) = callWithIndex
+ callWithIndex
+ }
+ })
+
+ val agg = aggCall.getOperator.asInstanceOf[SqlAggFunction]
+ LogicalSingleAggCall(agg,
+ callsWithIndices.map(_._1.getType),
+ callsWithIndices.map(_._2).toArray)
+ })
+
+ val aggs = logicalAggregates.zipWithIndex.map {
+ case (agg, index) =>
+ val result = AggregateUtil.extractAggregateCallMetadata(
+ agg.function,
+ isDistinct = false, // TODO properly set once supported in Calcite
+ agg.inputTypes,
+ needRetraction = false,
+ config,
+ isStateBackedDataViews = false,
+ index)
+
+ SingleAggCall(result.aggregateFunction, agg.exprIndices.toArray, result.accumulatorSpecs)
+ }
+
+ MatchAgg(aggs, inputRows.values.map(_._1).toSeq)
+ }
+
+ private def generateAggCalculation(
+ aggFunc: GeneratedAggregationsFunction,
+ transformFuncName: String,
+ inputTransformFunc: String)
+ : Unit = {
+ val aggregatorTerm = s"aggregator_$variableUID"
+ val code =
+ j"""
+ |private final ${aggFunc.name} $aggregatorTerm;
+ |
+ |$inputTransformFunc
+ |
+ |private $rowTypeTerm $calculateAggFuncName(java.util.List input)
+ | throws Exception {
+ | $rowTypeTerm accumulator = $aggregatorTerm.createAccumulators();
+ | for ($rowTypeTerm row : input) {
+ | $aggregatorTerm.accumulate(accumulator, $transformFuncName(row));
+ | }
+ | $rowTypeTerm result = $aggregatorTerm.createOutputRow();
+ | $aggregatorTerm.setAggregationResults(accumulator, result);
+ | return result;
+ |}
+ """.stripMargin
+
+ reusableInitStatements.add(s"$aggregatorTerm = new ${aggFunc.name}();")
+ reusableOpenStatements.add(s"$aggregatorTerm.open(getRuntimeContext());")
+ reusableCloseStatements.add(s"$aggregatorTerm.close();")
+ reusableMemberStatements.add(code)
+ }
+
+ private def generateAggInputExprEvaluation(
+ inputExprs: Seq[RexNode],
+ funcName: String)
+ : String = {
+ isWithinAggExprState = true
+ val resultTerm = newName("result")
+ val exprs = inputExprs.zipWithIndex.map {
+ case (inputExpr, outputIndex) => {
+ val expr = generateExpression(inputExpr)
+ s"""
+ |${expr.code}
+ |if (${expr.nullTerm}) {
+ | $resultTerm.setField($outputIndex, null);
+ |} else {
+ | $resultTerm.setField($outputIndex, ${expr.resultTerm});
+ |}
+ """.stripMargin
+ }
+ }.mkString("\n")
+ isWithinAggExprState = false
+
+ j"""
+ |private $rowTypeTerm $funcName($rowTypeTerm $inputAggRowTerm) {
+ | $rowTypeTerm $resultTerm = new $rowTypeTerm(${inputExprs.size});
+ | $exprs
+ | return $resultTerm;
+ |}
+ """.stripMargin
+ }
+
+ private case class LogicalSingleAggCall(
+ function: SqlAggFunction,
+ inputTypes: Seq[RelDataType],
+ exprIndices: Seq[Int]
+ )
+
+ private case class SingleAggCall(
+ aggFunction: TableAggregateFunction[_, _],
+ inputIndices: Array[Int],
+ dataViews: Seq[DataViewSpec[_]]
+ )
+
+ private case class MatchAgg(
+ aggregations: Seq[SingleAggCall],
+ inputExprs: Seq[RexNode]
+ )
+
+ }
+
}
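
generateDeduplicatedAggAccess above keys its cache on RexNode.toString, so an
expression like SUM(A.price) > SUM(A.price) + 1 computes the sum once and reads
one field of the per-variable aggregate row twice. A minimal standalone sketch of
that bookkeeping (hypothetical keys; the real map stores GeneratedExpressions
rather than field indices):

    import scala.collection.mutable

    val seenAggCalls = mutable.LinkedHashMap[String, Int]()

    // returns the field of the aggregate result row that holds this call's value
    def fieldIndexFor(callKey: String): Int =
      seenAggCalls.getOrElseUpdate(callKey, seenAggCalls.size)

    fieldIndexFor("SUM(A.price)") // 0 -- first occurrence allocates a new field
    fieldIndexFor("SUM(A.price)") // 0 -- reused, the aggregate is calculated once
    fieldIndexFor("AVG(A.price)") // 1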
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamMatchRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamMatchRule.scala
index 5b0aa65..bc0f56e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamMatchRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamMatchRule.scala
@@ -18,15 +18,21 @@
package org.apache.flink.table.plan.rules.datastream
-import org.apache.calcite.plan.{RelOptRule, RelTraitSet}
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.calcite.rex.{RexCall, RexNode}
+import org.apache.calcite.sql.SqlAggFunction
import org.apache.flink.table.api.TableException
import org.apache.flink.table.plan.logical.MatchRecognize
import org.apache.flink.table.plan.nodes.FlinkConventions
import org.apache.flink.table.plan.nodes.datastream.DataStreamMatch
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalMatch
import org.apache.flink.table.plan.schema.RowSchema
+import org.apache.flink.table.plan.util.RexDefaultVisitor
+import org.apache.flink.table.util.MatchUtil
+
+import scala.collection.JavaConverters._
class DataStreamMatchRule
extends ConverterRule(
@@ -35,6 +41,14 @@ class DataStreamMatchRule
FlinkConventions.DATASTREAM,
"DataStreamMatchRule") {
+ override def matches(call: RelOptRuleCall): Boolean = {
+ val logicalMatch: FlinkLogicalMatch = call.rel(0).asInstanceOf[FlinkLogicalMatch]
+
+ validateAggregations(logicalMatch.getMeasures.values().asScala)
+ validateAggregations(logicalMatch.getPatternDefinitions.values().asScala)
+ true
+ }
+
override def convert(rel: RelNode): RelNode = {
val logicalMatch: FlinkLogicalMatch = rel.asInstanceOf[FlinkLogicalMatch]
val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASTREAM)
@@ -71,6 +85,30 @@ class DataStreamMatchRule
new RowSchema(logicalMatch.getRowType),
new RowSchema(logicalMatch.getInput.getRowType))
}
+
+ private def validateAggregations(expr: Iterable[RexNode]): Unit = {
+ val validator = new AggregationsValidator
+ expr.foreach(_.accept(validator))
+ }
+
+ class AggregationsValidator extends RexDefaultVisitor[Object] {
+
+ override def visitCall(call: RexCall): AnyRef = {
+ call.getOperator match {
+ case _: SqlAggFunction =>
+ call.accept(new MatchUtil.AggregationPatternVariableFinder)
+ case _ =>
+ call.getOperands.asScala.foreach(_.accept(this))
+ }
+
+ null
+ }
+
+ override def visitNode(rexNode: RexNode): AnyRef = {
+ null
+ }
+ }
+
}
object DataStreamMatchRule {
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/util/MatchUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/util/MatchUtil.scala
new file mode 100644
index 0000000..0cce171
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/util/MatchUtil.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.util
+
+import org.apache.calcite.rex.{RexCall, RexNode, RexPatternFieldRef}
+import org.apache.flink.table.api.ValidationException
+import org.apache.flink.table.plan.util.RexDefaultVisitor
+import scala.collection.JavaConverters._
+
+object MatchUtil {
+ val ALL_PATTERN_VARIABLE = "*"
+
+ class AggregationPatternVariableFinder extends RexDefaultVisitor[Option[String]] {
+
+ override def visitPatternFieldRef(patternFieldRef: RexPatternFieldRef): Option[String] = Some(
+ patternFieldRef.getAlpha)
+
+ override def visitCall(call: RexCall): Option[String] = {
+ if (call.operands.size() == 0) {
+ Some(ALL_PATTERN_VARIABLE)
+ } else {
+ call.operands.asScala.map(n => n.accept(this)).reduce((op1, op2) => (op1, op2) match {
+ case (None, None) => None
+ case (x, None) => x
+ case (None, x) => x
+ case (Some(var1), Some(var2)) if var1.equals(var2) =>
+ Some(var1)
+ case _ =>
+ throw new ValidationException(s"Aggregation must be applied to a single pattern " +
+ s"variable. Malformed expression: $call")
+ })
+ }
+ }
+
+ override def visitNode(rexNode: RexNode): Option[String] = None
+ }
+}
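
The reduce in visitCall resolves the single pattern variable that all operands of
an aggregate refer to. The same rule on plain Options (a standalone sketch with
hypothetical inputs; the real visitor walks RexNodes):

    // None = operand references no pattern variable (e.g. a literal)
    val operandVariables: Seq[Option[String]] = Seq(Some("A"), None, Some("A"))

    val resolved = operandVariables.reduce[Option[String]] {
      case (None, None) => None
      case (x, None) => x
      case (None, x) => x
      case (Some(v1), Some(v2)) if v1 == v2 => Some(v1)
      case _ => throw new IllegalStateException("more than one pattern variable")
    }
    // resolved == Some("A"), so SUM(A.price + A.tax) passes,
    // while Some("A") meeting Some("B"), as in SUM(A.price + B.tax), fails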
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/match/MatchOperatorValidationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/match/MatchRecognizeValidationTest.scala
similarity index 89%
rename from flink-libraries/flink-table/src/test/scala/org/apache/flink/table/match/MatchOperatorValidationTest.scala
rename to flink-libraries/flink-table/src/test/scala/org/apache/flink/table/match/MatchRecognizeValidationTest.scala
index e10a568..2890179 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/match/MatchOperatorValidationTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/match/MatchRecognizeValidationTest.scala
@@ -21,13 +21,13 @@ package org.apache.flink.table.`match`
import org.apache.flink.api.scala._
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.{TableException, ValidationException}
-import org.apache.flink.table.codegen.CodeGenException
import org.apache.flink.table.runtime.stream.sql.ToMillis
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.WeightedAvg
import org.apache.flink.table.utils.TableTestBase
import org.apache.flink.types.Row
import org.junit.Test
-class MatchOperatorValidationTest extends TableTestBase {
+class MatchRecognizeValidationTest extends TableTestBase {
private val streamUtils = streamTestUtil()
streamUtils.addTable[(String, Long, Int, Int)]("Ticker",
@@ -128,15 +128,10 @@ class MatchOperatorValidationTest extends TableTestBase {
streamUtils.tableEnv.sqlQuery(sqlQuery).toRetractStream[Row]
}
- // ***************************************************************************************
- // * Those validations are temporary. We should remove those tests once we support those *
- // * features. *
- // ***************************************************************************************
-
@Test
- def testAllRowsPerMatch(): Unit = {
- thrown.expectMessage("All rows per match mode is not supported yet.")
- thrown.expect(classOf[TableException])
+ def testAggregatesOnMultiplePatternVariablesNotSupported(): Unit = {
+ thrown.expect(classOf[ValidationException])
+ thrown.expectMessage("SQL validation failed.")
val sqlQuery =
s"""
@@ -145,11 +140,10 @@ class MatchOperatorValidationTest extends TableTestBase {
|MATCH_RECOGNIZE (
| ORDER BY proctime
| MEASURES
- | A.symbol AS aSymbol
- | ALL ROWS PER MATCH
+ | SUM(A.price + B.tax) AS taxedPrice
| PATTERN (A B)
| DEFINE
- | A AS symbol = 'a'
+ | A AS A.symbol = 'a'
|) AS T
|""".stripMargin
@@ -157,10 +151,11 @@ class MatchOperatorValidationTest extends TableTestBase {
}
@Test
- def testGreedyQuantifierAtTheEndIsNotSupported(): Unit = {
- thrown.expectMessage("Greedy quantifiers are not allowed as the last element of a " +
- "Pattern yet. Finish your pattern with either a simple variable or reluctant quantifier.")
- thrown.expect(classOf[TableException])
+ def testAggregatesOnMultiplePatternVariablesNotSupportedInUDAGs(): Unit = {
+ thrown.expect(classOf[ValidationException])
+ thrown.expectMessage("Aggregation must be applied to a single pattern variable")
+
+ streamUtils.tableEnv.registerFunction("weightedAvg", new WeightedAvg)
val sqlQuery =
s"""
@@ -169,20 +164,24 @@ class MatchOperatorValidationTest extends TableTestBase {
|MATCH_RECOGNIZE (
| ORDER BY proctime
| MEASURES
- | A.symbol AS aSymbol
- | PATTERN (A B+)
+ | weightedAvg(A.price, B.tax) AS weightedAvg
+ | PATTERN (A B)
| DEFINE
- | A AS symbol = 'a'
+ | A AS A.symbol = 'a'
|) AS T
|""".stripMargin
streamUtils.tableEnv.sqlQuery(sqlQuery).toAppendStream[Row]
}
+ // ***************************************************************************************
+ // * Those validations are temporary. We should remove those tests once we support those *
+ // * features. *
+ // ***************************************************************************************
+
@Test
- def testPatternsProducingEmptyMatchesAreNotSupported(): Unit = {
- thrown.expectMessage("Patterns that can produce empty matches are not supported. " +
- "There must be at least one non-optional state.")
+ def testAllRowsPerMatch(): Unit = {
+ thrown.expectMessage("All rows per match mode is not supported yet.")
thrown.expect(classOf[TableException])
val sqlQuery =
@@ -193,7 +192,8 @@ class MatchOperatorValidationTest extends TableTestBase {
| ORDER BY proctime
| MEASURES
| A.symbol AS aSymbol
- | PATTERN (A*)
+ | ALL ROWS PER MATCH
+ | PATTERN (A B)
| DEFINE
| A AS symbol = 'a'
|) AS T
@@ -203,11 +203,10 @@ class MatchOperatorValidationTest extends TableTestBase {
}
@Test
- def testAggregatesAreNotSupportedInMeasures(): Unit = {
- thrown.expectMessage(
- "Unsupported call: SUM \nIf you think this function should be supported, you can " +
- "create an issue and start a discussion for it.")
- thrown.expect(classOf[CodeGenException])
+ def testGreedyQuantifierAtTheEndIsNotSupported(): Unit = {
+ thrown.expectMessage("Greedy quantifiers are not allowed as the last element of a " +
+ "Pattern yet. Finish your pattern with either a simple variable or reluctant quantifier.")
+ thrown.expect(classOf[TableException])
val sqlQuery =
s"""
@@ -216,10 +215,10 @@ class MatchOperatorValidationTest extends TableTestBase {
|MATCH_RECOGNIZE (
| ORDER BY proctime
| MEASURES
- | SUM(A.price + A.tax) AS cost
- | PATTERN (A B)
+ | A.symbol AS aSymbol
+ | PATTERN (A B+)
| DEFINE
- | A AS A.symbol = 'a'
+ | A AS symbol = 'a'
|) AS T
|""".stripMargin
@@ -227,11 +226,10 @@ class MatchOperatorValidationTest extends TableTestBase {
}
@Test
- def testAggregatesAreNotSupportedInDefine(): Unit = {
- thrown.expectMessage(
- "Unsupported call: SUM \nIf you think this function should be supported, you can " +
- "create an issue and start a discussion for it.")
- thrown.expect(classOf[CodeGenException])
+ def testPatternsProducingEmptyMatchesAreNotSupported(): Unit = {
+ thrown.expectMessage("Patterns that can produce empty matches are not supported. " +
+ "There must be at least one non-optional state.")
+ thrown.expect(classOf[TableException])
val sqlQuery =
s"""
@@ -240,10 +238,10 @@ class MatchOperatorValidationTest extends TableTestBase {
|MATCH_RECOGNIZE (
| ORDER BY proctime
| MEASURES
- | B.price as bPrice
- | PATTERN (A+ B)
+ | A.symbol AS aSymbol
+ | PATTERN (A*)
| DEFINE
- | A AS SUM(A.price + A.tax) < 10
+ | A AS symbol = 'a'
|) AS T
|""".stripMargin
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/MatchRecognizeITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/MatchRecognizeITCase.scala
index 8f5a8f3..ffd1c2b 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/MatchRecognizeITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/MatchRecognizeITCase.scala
@@ -22,12 +22,14 @@ import java.sql.Timestamp
import java.util.TimeZone
import org.apache.flink.api.common.time.Time
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.TimeCharacteristic
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
-import org.apache.flink.table.api.{TableConfig, TableEnvironment}
import org.apache.flink.table.api.scala._
-import org.apache.flink.table.functions.{FunctionContext, ScalarFunction}
+import org.apache.flink.table.api.{TableConfig, TableEnvironment, Types}
+import org.apache.flink.table.functions.{AggregateFunction, FunctionContext, ScalarFunction}
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.WeightedAvg
import org.apache.flink.table.runtime.utils.TimeTestUtil.EventTimeSourceFunction
import org.apache.flink.table.runtime.utils.{StreamITCase, StreamingWithStateTestBase, UserDefinedFunctionTestUtils}
import org.apache.flink.types.Row
@@ -39,7 +41,7 @@ import scala.collection.mutable
class MatchRecognizeITCase extends StreamingWithStateTestBase {
@Test
- def testSimpleCEP(): Unit = {
+ def testSimplePattern(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
env.setParallelism(1)
val tEnv = TableEnvironment.getTableEnvironment(env)
@@ -86,7 +88,7 @@ class MatchRecognizeITCase extends StreamingWithStateTestBase {
}
@Test
- def testSimpleCEPWithNulls(): Unit = {
+ def testSimplePatternWithNulls(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
env.setParallelism(1)
val tEnv = TableEnvironment.getTableEnvironment(env)
@@ -464,6 +466,132 @@ class MatchRecognizeITCase extends StreamingWithStateTestBase {
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
+ /**
+ * This query checks:
+ *
+ * 1. count(D.price) produces 0, because no rows were matched to D
+ * 2. sum(D.price) produces null, because no rows were matched to D
+ * 3. aggregates that take multiple parameters work
+ * 4. aggregates with expressions work
+ */
+ @Test
+ def testAggregates(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ env.setParallelism(1)
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ tEnv.getConfig.setMaxGeneratedCodeLength(1)
+ StreamITCase.clear
+
+ val data = new mutable.MutableList[(Int, String, Long, Double, Int)]
+ data.+=((1, "a", 1, 0.8, 1))
+ data.+=((2, "z", 2, 0.8, 3))
+ data.+=((3, "b", 1, 0.8, 2))
+ data.+=((4, "c", 1, 0.8, 5))
+ data.+=((5, "d", 4, 0.1, 5))
+ data.+=((6, "a", 2, 1.5, 2))
+ data.+=((7, "b", 2, 0.8, 3))
+ data.+=((8, "c", 1, 0.8, 2))
+ data.+=((9, "h", 4, 0.8, 3))
+ data.+=((10, "h", 4, 0.8, 3))
+ data.+=((11, "h", 2, 0.8, 3))
+ data.+=((12, "h", 2, 0.8, 3))
+
+ val t = env.fromCollection(data)
+ .toTable(tEnv, 'id, 'name, 'price, 'rate, 'weight, 'proctime.proctime)
+ tEnv.registerTable("MyTable", t)
+ tEnv.registerFunction("weightedAvg", new WeightedAvg)
+
+ val sqlQuery =
+ s"""
+ |SELECT *
+ |FROM MyTable
+ |MATCH_RECOGNIZE (
+ | ORDER BY proctime
+ | MEASURES
+ | FIRST(id) as startId,
+ | SUM(A.price) AS sumA,
+ | COUNT(DISTINCT D.price) AS countD,
+ | SUM(D.price) as sumD,
+ | weightedAvg(price, weight) as wAvg,
+ | AVG(B.price) AS avgB,
+ | SUM(B.price * B.rate) as sumExprB,
+ | LAST(id) as endId
+ | AFTER MATCH SKIP PAST LAST ROW
+ | PATTERN (A+ B+ C D? E )
+ | DEFINE
+ | A AS SUM(A.price) < 6,
+ | B AS SUM(B.price * B.rate) < SUM(A.price) AND
+ | SUM(B.price * B.rate) > 0.2 AND
+ | SUM(B.price) >= 1 AND
+ | AVG(B.price) >= 1 AND
+ | weightedAvg(price, weight) > 1
+ |) AS T
+ |""".stripMargin
+
+ val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
+ result.addSink(new StreamITCase.StringSink[Row])
+ env.execute()
+
+ val expected = mutable.MutableList("1,5,0,null,2,3,3.4,8", "9,4,0,null,3,4,3.2,12")
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
+
+ @Test
+ def testAggregatesWithNullInputs(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ env.setParallelism(1)
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ tEnv.getConfig.setMaxGeneratedCodeLength(1)
+ StreamITCase.clear
+
+ val data = new mutable.MutableList[Row]
+ data.+=(Row.of(Int.box(1), "a", Int.box(10)))
+ data.+=(Row.of(Int.box(2), "z", Int.box(10)))
+ data.+=(Row.of(Int.box(3), "b", null))
+ data.+=(Row.of(Int.box(4), "c", null))
+ data.+=(Row.of(Int.box(5), "d", Int.box(3)))
+ data.+=(Row.of(Int.box(6), "c", Int.box(3)))
+ data.+=(Row.of(Int.box(7), "c", Int.box(3)))
+ data.+=(Row.of(Int.box(8), "c", Int.box(3)))
+ data.+=(Row.of(Int.box(9), "c", Int.box(2)))
+
+ val t = env.fromCollection(data)(Types.ROW(
+ BasicTypeInfo.INT_TYPE_INFO,
+ BasicTypeInfo.STRING_TYPE_INFO,
+ BasicTypeInfo.INT_TYPE_INFO))
+ .toTable(tEnv, 'id, 'name, 'price, 'proctime.proctime)
+ tEnv.registerTable("MyTable", t)
+ tEnv.registerFunction("weightedAvg", new WeightedAvg)
+
+ val sqlQuery =
+ s"""
+ |SELECT *
+ |FROM MyTable
+ |MATCH_RECOGNIZE (
+ | ORDER BY proctime
+ | MEASURES
+ | SUM(A.price) as sumA,
+ | COUNT(A.id) as countAId,
+ | COUNT(A.price) as countAPrice,
+ | COUNT(*) as countAll,
+ | COUNT(price) as countAllPrice,
+ | LAST(id) as endId
+ | AFTER MATCH SKIP PAST LAST ROW
+ | PATTERN (A+ C)
+ | DEFINE
+ | A AS SUM(A.price) < 30,
+ | C AS C.name = 'c'
+ |) AS T
+ |""".stripMargin
+
+ val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
+ result.addSink(new StreamITCase.StringSink[Row])
+ env.execute()
+
+ val expected = mutable.MutableList("29,7,5,8,6,8")
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
+
@Test
def testAccessingProctime(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
@@ -567,9 +695,11 @@ class MatchRecognizeITCase extends StreamingWithStateTestBase {
.toTable(tEnv, 'id, 'name, 'price, 'proctime.proctime)
tEnv.registerTable("MyTable", t)
tEnv.registerFunction("prefix", new PrefixingScalarFunc)
+ tEnv.registerFunction("countFrom", new RichAggFunc)
val prefix = "PREF"
+ val startFrom = 4
UserDefinedFunctionTestUtils
- .setJobParameters(env, Map("prefix" -> prefix))
+ .setJobParameters(env, Map("prefix" -> prefix, "start" -> startFrom.toString))
val sqlQuery =
s"""
@@ -580,11 +710,12 @@ class MatchRecognizeITCase extends StreamingWithStateTestBase {
| MEASURES
| FIRST(id) as firstId,
| prefix(A.name) as prefixedNameA,
+ | countFrom(A.price) as countFromA,
| LAST(id) as lastId
| AFTER MATCH SKIP PAST LAST ROW
| PATTERN (A+ C)
| DEFINE
- | A AS prefix(A.name) = '$prefix:a'
+ | A AS prefix(A.name) = '$prefix:a' AND countFrom(A.price) <= ${startFrom + 4}
|) AS T
|""".stripMargin
@@ -592,7 +723,7 @@ class MatchRecognizeITCase extends StreamingWithStateTestBase {
result.addSink(new StreamITCase.StringSink[Row])
env.execute()
- val expected = mutable.MutableList("1,PREF:a,6", "7,PREF:a,9")
+ val expected = mutable.MutableList("1,PREF:a,8,5", "7,PREF:a,6,9")
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
}
@@ -615,3 +746,26 @@ private class PrefixingScalarFunc extends ScalarFunction {
s"$prefix:$value"
}
}
+
+private case class CountAcc(var count: Long)
+
+private class RichAggFunc extends AggregateFunction[Long, CountAcc] {
+
+ private var start : Long = 0
+
+ override def open(context: FunctionContext): Unit = {
+ start = context.getJobParameter("start", "0").toLong
+ }
+
+ override def close(): Unit = {
+ start = 0
+ }
+
+ def accumulate(countAcc: CountAcc, value: Long): Unit = {
+ countAcc.count += value
+ }
+
+ override def createAccumulator(): CountAcc = CountAcc(start)
+
+ override def getValue(accumulator: CountAcc): Long = accumulator.count
+}
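
The generated calculateAgg_* method drives such a function exactly like the loop
sketched in the MatchCodeGenerator docs: create an accumulator, accumulate each
matched row, read the value. Called directly (for illustration only; open() is
normally invoked by the runtime and seeds start from the job parameters):

    val func = new RichAggFunc
    // pretend open() ran with "start" -> "4"
    val acc = CountAcc(4)
    Seq(1L, 2L).foreach(v => func.accumulate(acc, v))
    func.getValue(acc) // 7 = start + 1 + 2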