Posted to commits@flink.apache.org by dw...@apache.org on 2018/12/17 08:25:18 UTC

[flink] 03/06: [FLINK-7599][table, cep] Support for aggregates in MATCH_RECOGNIZE

This is an automated email from the ASF dual-hosted git repository.

dwysakowicz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 1458caf77943704c345bed1a6ef517f35f858fb8
Author: Dawid Wysakowicz <dw...@apache.org>
AuthorDate: Mon Nov 26 14:27:22 2018 +0100

    [FLINK-7599][table, cep] Support for aggregates in MATCH_RECOGNIZE
---
 .../apache/flink/table/codegen/CodeGenerator.scala |  37 +-
 .../flink/table/codegen/MatchCodeGenerator.scala   | 425 +++++++++++++++++++--
 .../rules/datastream/DataStreamMatchRule.scala     |  40 +-
 .../org/apache/flink/table/util/MatchUtil.scala    |  53 +++
 ...st.scala => MatchRecognizeValidationTest.scala} |  78 ++--
 .../runtime/stream/sql/MatchRecognizeITCase.scala  | 168 +++++++-
 6 files changed, 709 insertions(+), 92 deletions(-)
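
Illustration (adapted from the MatchRecognizeITCase tests added in this commit;
table and column names are examples only): with this change, aggregates such as
SUM, AVG, COUNT and user-defined aggregate functions can be used in both the
MEASURES and the DEFINE clauses of MATCH_RECOGNIZE, as long as each aggregate
call references a single pattern variable:

    SELECT *
    FROM MyTable
    MATCH_RECOGNIZE (
      ORDER BY proctime
      MEASURES
        SUM(A.price) AS sumA,           -- built-in aggregate in MEASURES
        SUM(B.price * B.rate) AS sumB   -- aggregate over an expression
      AFTER MATCH SKIP PAST LAST ROW
      PATTERN (A+ B+ C)
      DEFINE
        A AS SUM(A.price) < 6,          -- aggregate in DEFINE
        B AS SUM(B.price * B.rate) < SUM(A.price)
    ) AS T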

diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
index fdb0d50..6c05ff0 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
@@ -1048,6 +1048,25 @@ abstract class CodeGenerator(
   // generator helping methods
   // ----------------------------------------------------------------------------------------------
 
+  protected def makeReusableInSplits(exprs: Iterable[GeneratedExpression]): Unit = {
+    // add results of expressions to member area such that all split functions can access it
+    exprs.foreach { expr =>
+
+      // declaration
+      val resultTypeTerm = primitiveTypeTermForTypeInfo(expr.resultType)
+      if (nullCheck && !expr.nullTerm.equals(NEVER_NULL) && !expr.nullTerm.equals(ALWAYS_NULL)) {
+        reusableMemberStatements.add(s"private boolean ${expr.nullTerm};")
+      }
+      reusableMemberStatements.add(s"private $resultTypeTerm ${expr.resultTerm};")
+
+      // assignment
+      if (nullCheck && !expr.nullTerm.equals(NEVER_NULL) && !expr.nullTerm.equals(ALWAYS_NULL)) {
+        reusablePerRecordStatements.add(s"this.${expr.nullTerm} = ${expr.nullTerm};")
+      }
+      reusablePerRecordStatements.add(s"this.${expr.resultTerm} = ${expr.resultTerm};")
+    }
+  }
+
   private def generateCodeSplits(splits: Seq[String]): String = {
     val totalLen = splits.map(_.length + 1).sum // 1 for a line break
 
@@ -1057,21 +1076,7 @@ abstract class CodeGenerator(
       hasCodeSplits = true
 
       // add input unboxing to member area such that all split functions can access it
-      reusableInputUnboxingExprs.foreach { case (_, expr) =>
-
-        // declaration
-        val resultTypeTerm = primitiveTypeTermForTypeInfo(expr.resultType)
-        if (nullCheck && !expr.nullTerm.equals(NEVER_NULL) && !expr.nullTerm.equals(ALWAYS_NULL)) {
-          reusableMemberStatements.add(s"private boolean ${expr.nullTerm};")
-        }
-        reusableMemberStatements.add(s"private $resultTypeTerm ${expr.resultTerm};")
-
-        // assignment
-        if (nullCheck && !expr.nullTerm.equals(NEVER_NULL) && !expr.nullTerm.equals(ALWAYS_NULL)) {
-          reusablePerRecordStatements.add(s"this.${expr.nullTerm} = ${expr.nullTerm};")
-        }
-        reusablePerRecordStatements.add(s"this.${expr.resultTerm} = ${expr.resultTerm};")
-      }
+      makeReusableInSplits(reusableInputUnboxingExprs.values)
 
       // add split methods to the member area and return the code necessary to call those methods
       val methodCalls = splits.map { split =>
@@ -1196,7 +1201,7 @@ abstract class CodeGenerator(
     GeneratedExpression(resultTerm, nullTerm, inputCheckCode, fieldType)
   }
 
-  private def generateFieldAccess(
+  protected def generateFieldAccess(
       inputType: TypeInformation[_],
       inputTerm: String,
       index: Int)
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala
index 8f8a7f1..ffa7fc2 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala
@@ -22,19 +22,26 @@ import java.lang.{Long => JLong}
 import java.util
 
 import org.apache.calcite.rel.RelFieldCollation
+import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rex._
+import org.apache.calcite.sql.SqlAggFunction
 import org.apache.calcite.sql.fun.SqlStdOperatorTable._
 import org.apache.flink.api.common.functions._
 import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation}
 import org.apache.flink.cep.pattern.conditions.{IterativeCondition, RichIterativeCondition}
 import org.apache.flink.cep.{RichPatternFlatSelectFunction, RichPatternSelectFunction}
 import org.apache.flink.configuration.Configuration
+import org.apache.flink.table.api.dataview.DataViewSpec
 import org.apache.flink.table.api.{TableConfig, TableException}
-import org.apache.flink.table.codegen.CodeGenUtils.{boxedTypeTermForTypeInfo, newName}
+import org.apache.flink.table.calcite.FlinkTypeFactory
+import org.apache.flink.table.codegen.CodeGenUtils.{boxedTypeTermForTypeInfo, newName, primitiveDefaultValue, primitiveTypeTermForTypeInfo}
 import org.apache.flink.table.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
 import org.apache.flink.table.codegen.Indenter.toISC
+import org.apache.flink.table.functions.{AggregateFunction => TableAggregateFunction}
 import org.apache.flink.table.plan.schema.RowSchema
 import org.apache.flink.table.runtime.`match`.{IterativeConditionRunner, PatternSelectFunctionRunner}
+import org.apache.flink.table.runtime.aggregate.AggregateUtil
+import org.apache.flink.table.util.MatchUtil.{ALL_PATTERN_VARIABLE, AggregationPatternVariableFinder}
 import org.apache.flink.table.utils.EncodingUtils
 import org.apache.flink.types.Row
 import org.apache.flink.util.Collector
@@ -53,7 +60,7 @@ object MatchCodeGenerator {
     names: Seq[String])
   : IterativeConditionRunner = {
     val generator = new MatchCodeGenerator(config, inputTypeInfo, names, Some(patternName))
-    val condition = generator.generateExpression(patternDefinition)
+    val condition = generator.generateCondition(patternDefinition)
     val body =
       s"""
          |${condition.code}
@@ -101,6 +108,90 @@ object MatchCodeGenerator {
 /**
   * A code generator for generating CEP related functions.
   *
+  * Aggregates are generated as follows:
+  * 1. all aggregate [[RexCall]]s are grouped by the corresponding pattern variable
+  * 2. even if the same aggregation is used multiple times in an expression
+  *    (e.g. SUM(A.price) > SUM(A.price) + 1) it will be calculated only once. To do so,
+  *    [[AggBuilder]] keeps a set of already seen aggregation calls and reuses the code that
+  *    accesses the appropriate field of the aggregation result
+  * 3. after translating every expression (either in [[generateCondition]] or in
+  *    [[generateOneRowPerMatchExpression]]) code is generated for
+  *       - a [[GeneratedFunction]], which will be an inner class
+  *       - said [[GeneratedFunction]] will be instantiated in the constructor and opened/closed
+  *         in the corresponding methods of the top-level generated classes
+  *       - a function that transforms input rows (row by row) into aggregate input rows
+  *       - a function that calculates aggregates for a variable, using the previous method
+  *    The generated code will look similar to this:
+  *
+  *
+  * {{{
+  *
+  * public class MatchRecognizePatternSelectFunction$175 extends RichPatternSelectFunction {
+  *
+  *     // Class used to calculate aggregates for a single pattern variable
+  *     public final class AggFunction_variable$115$151 extends GeneratedAggregations {
+  *       ...
+  *     }
+  *
+  *     private final AggFunction_variable$115$151 aggregator_variable$115;
+  *
+  *     public MatchRecognizePatternSelectFunction$175() {
+  *       aggregator_variable$115 = new AggFunction_variable$115$151();
+  *     }
+  *
+  *     public void open() {
+  *       aggregator_variable$115.open();
+  *       ...
+  *     }
+  *
+  *     // Function to transform an incoming row into an aggregate-specific row. It can, e.g.,
+  *     // calculate the inner expression of said aggregate
+  *     private Row transformRowForAgg_variable$115(Row inAgg) {
+  *         ...
+  *     }
+  *
+  *     // Function to calculate all aggregates for a single pattern variable
+  *     private Row calculateAgg_variable$115(List<Row> input) {
+  *       Acc accumulator = aggregator_variable$115.createAccumulator();
+  *       for (Row row : input) {
+  *         aggregator_variable$115.accumulate(accumulator, transformRowForAgg_variable$115(row));
+  *       }
+  *
+  *       return aggregator_variable$115.getResult(accumulator);
+  *     }
+  *
+  *     @Override
+  *     public Object select(Map<String, List<Row>> in1) throws Exception {
+  *
+  *       // Extract list of rows assigned to a single pattern variable
+  *       java.util.List patternEvents$130 = (java.util.List) in1.get("A");
+  *       ...
+  *
+  *       // Calculate aggregates
+  *       Row aggRow_variable$110$111 = calculateAgg_variable$110(patternEvents$114);
+  *
+  *       // Every aggregation (e.g. SUM(A.price) and AVG(A.price)) will be extracted to a variable
+  *       double result$135 = aggRow_variable$126$127.getField(0);
+  *       long result$137 = aggRow_variable$126$127.getField(1);
+  *
+  *       // The result of an aggregation will be used in expression evaluation
+  *       out.setField(0, result$135);
+  *
+  *       long result$140 = result$137 * 2;
+  *       out.setField(1, result$140);
+  *
+  *       double result$144 = result$135 + result$137;
+  *       out.setField(2, result$144);
+  *     }
+  *
+  *     public void close() {
+  *       aggregator_variable$115.close();
+  *       ...
+  *     }
+  *
+  * }
+  * }}}
+  *
   * @param config configuration that determines runtime behavior
   * @param patternNames sorted sequence of pattern variables
   * @param input type information about the first input of the Function
@@ -124,17 +215,39 @@ class MatchCodeGenerator(
     .HashMap[String, GeneratedPatternList]()
 
   /**
+    * Used to deduplicate aggregation calculations. The deduplication is performed by
+    * [[RexNode.toString]]. Those expressions need to be accessible from splits, if any exist.
+    */
+  private val reusableAggregationExpr = new mutable.HashMap[String, GeneratedExpression]()
+
+  /**
     * Context information used by Pattern reference variable to index rows mapped to it.
     * Indexes element at offset either from beginning or the end based on the value of first.
     */
   private var offset: Int = 0
   private var first : Boolean = false
 
+  /**
+    * Flag that tells if we generate expressions inside an aggregate. It determines how to
+    * access the input row.
+    */
+  private var isWithinAggExprState: Boolean = false
+
+  /**
+    * Name of the term in the function used to transform an input row into an aggregate
+    * input row.
+    */
+  private val inputAggRowTerm = "inAgg"
+
   /** Term for row for key extraction */
-  private val keyRowTerm = newName("keyRow")
+  private val keyRowTerm = "keyRow"
 
   /** Term for list of all pattern names */
-  private val patternNamesTerm = newName("patternNames")
+  private val patternNamesTerm = "patternNames"
+
+  /**
+    * Used to collect all aggregates per pattern variable.
+    */
+  private val aggregatesPerVariable = new mutable.HashMap[String, AggBuilder]
 
   /**
     * Sets the new reference variable indexing context. This should be used when resolving logical
@@ -252,7 +365,6 @@ class MatchCodeGenerator(
   private def generateKeyRow() : GeneratedExpression = {
     val exp = reusableInputUnboxingExprs
       .get((keyRowTerm, 0)) match {
-      // input access and unboxing has already been generated
       case Some(expr) =>
         expr
 
@@ -310,10 +422,26 @@ class MatchCodeGenerator(
         generateExpression(measures.get(fieldName))
       }
 
-    generateResultExpression(
+    val exp = generateResultExpression(
       resultExprs,
       returnType.typeInfo,
       returnType.fieldNames)
+    aggregatesPerVariable.values.foreach(_.generateAggFunction())
+    if (hasCodeSplits) {
+      makeReusableInSplits(reusableAggregationExpr.values)
+    }
+
+    exp
+  }
+
+  def generateCondition(call: RexNode): GeneratedExpression = {
+    val exp = call.accept(this)
+    aggregatesPerVariable.values.foreach(_.generateAggFunction())
+    if (hasCodeSplits) {
+      makeReusableInSplits(reusableAggregationExpr.values)
+    }
+
+    exp
   }
 
   override def visitCall(call: RexCall): GeneratedExpression = {
@@ -341,6 +469,21 @@ class MatchCodeGenerator(
       case FINAL =>
         call.getOperands.get(0).accept(this)
 
+      case _: SqlAggFunction =>
+
+        val variable = call.accept(new AggregationPatternVariableFinder)
+          .getOrElse(throw new TableException("No pattern variable specified in aggregate"))
+
+        val matchAgg = aggregatesPerVariable.get(variable) match {
+          case Some(agg) => agg
+          case None =>
+            val agg = new AggBuilder(variable)
+            aggregatesPerVariable(variable) = agg
+            agg
+        }
+
+        matchAgg.generateDeduplicatedAggAccess(call)
+
       case _ => super.visitCall(call)
     }
   }
@@ -357,10 +500,15 @@ class MatchCodeGenerator(
   }
 
   override def visitPatternFieldRef(fieldRef: RexPatternFieldRef): GeneratedExpression = {
-    if (fieldRef.getAlpha.equals("*") && currentPattern.isDefined && offset == 0 && !first) {
-      generateInputAccess(input, input1Term, fieldRef.getIndex)
+    if (isWithinAggExprState) {
+      generateFieldAccess(input, inputAggRowTerm, fieldRef.getIndex)
     } else {
-      generatePatternFieldRef(fieldRef)
+      if (fieldRef.getAlpha.equals(ALL_PATTERN_VARIABLE) &&
+          currentPattern.isDefined && offset == 0 && !first) {
+        generateInputAccess(input, input1Term, fieldRef.getIndex)
+      } else {
+        generatePatternFieldRef(fieldRef)
+      }
     }
   }
 
@@ -372,14 +520,14 @@ class MatchCodeGenerator(
     val eventTypeTerm = boxedTypeTermForTypeInfo(input)
     val eventNameTerm = newName("event")
 
-    val addCurrent = if (currentPattern == patternName || patternName == "*") {
+    val addCurrent = if (currentPattern == patternName || patternName == ALL_PATTERN_VARIABLE) {
       j"""
          |$listName.add($input1Term);
          """.stripMargin
     } else {
       ""
     }
-    val listCode = if (patternName == "*") {
+    val listCode = if (patternName == ALL_PATTERN_VARIABLE) {
       addReusablePatternNames()
       val patternTerm = newName("pattern")
       j"""
@@ -414,7 +562,7 @@ class MatchCodeGenerator(
   private def generateMeasurePatternVariableExp(patternName: String): GeneratedPatternList = {
     val listName = newName("patternEvents")
 
-    val code = if (patternName == "*") {
+    val code = if (patternName == ALL_PATTERN_VARIABLE) {
       addReusablePatternNames()
 
       val patternTerm = newName("pattern")
@@ -448,21 +596,7 @@ class MatchCodeGenerator(
     val eventTypeTerm = boxedTypeTermForTypeInfo(input)
     val isRowNull = newName("isRowNull")
 
-    val findEventsByPatternName = reusablePatternLists.get(patternFieldAlpha) match {
-      // input access and unboxing has already been generated
-      case Some(expr) =>
-        expr
-
-      case None =>
-        val exp = currentPattern match {
-          case Some(p) => generateDefinePatternVariableExp(patternFieldAlpha, p)
-          case None => generateMeasurePatternVariableExp(patternFieldAlpha)
-        }
-        reusablePatternLists(patternFieldAlpha) = exp
-        exp
-    }
-
-    val listName = findEventsByPatternName.resultTerm
+    val listName = findEventsByPatternName(patternFieldAlpha).resultTerm
     val resultIndex = if (first) {
       j"""$offset"""
     } else {
@@ -482,11 +616,27 @@ class MatchCodeGenerator(
     GeneratedExpression(rowNameTerm, isRowNull, funcCode, input)
   }
 
+  private def findEventsByPatternName(
+      patternFieldAlpha: String)
+    : GeneratedPatternList = {
+    reusablePatternLists.get(patternFieldAlpha) match {
+      case Some(expr) =>
+        expr
+
+      case None =>
+        val exp = currentPattern match {
+          case Some(p) => generateDefinePatternVariableExp(patternFieldAlpha, p)
+          case None => generateMeasurePatternVariableExp(patternFieldAlpha)
+        }
+        reusablePatternLists(patternFieldAlpha) = exp
+        exp
+    }
+  }
+
   private def generatePatternFieldRef(fieldRef: RexPatternFieldRef): GeneratedExpression = {
     val escapedAlpha = EncodingUtils.escapeJava(fieldRef.getAlpha)
     val patternVariableRef = reusableInputUnboxingExprs
       .get((s"$escapedAlpha#$first", offset)) match {
-      // input access and unboxing has already been generated
       case Some(expr) =>
         expr
 
@@ -498,4 +648,223 @@ class MatchCodeGenerator(
 
     generateFieldAccess(patternVariableRef.copy(code = NO_CODE), fieldRef.getIndex)
   }
+
+  class AggBuilder(variable: String) {
+
+    private val aggregates = new mutable.ListBuffer[RexCall]()
+
+    private val variableUID = newName("variable")
+
+    private val rowTypeTerm = "org.apache.flink.types.Row"
+
+    private val calculateAggFuncName = s"calculateAgg_$variableUID"
+
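+    /**
+      * Returns a [[GeneratedExpression]] that accesses the result of the given aggregation call.
+      * Calls are deduplicated by [[RexNode.toString]]: an already generated call reuses the
+      * cached expression, so e.g. SUM(A.price) used twice is computed only once.
+      */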
+    def generateDeduplicatedAggAccess(aggCall: RexCall): GeneratedExpression = {
+      reusableAggregationExpr.get(aggCall.toString) match {
+        case Some(expr) =>
+          expr
+
+        case None =>
+          val exp: GeneratedExpression = generateAggAccess(aggCall)
+          aggregates += aggCall
+          reusableAggregationExpr(aggCall.toString) = exp
+          reusablePerRecordStatements += exp.code
+          exp.copy(code = NO_CODE)
+      }
+    }
+
+    private def generateAggAccess(aggCall: RexCall): GeneratedExpression = {
+      val singleAggResultTerm = newName("result")
+      val singleAggNullTerm = newName("nullTerm")
+      val singleAggResultType = FlinkTypeFactory.toTypeInfo(aggCall.`type`)
+      val primitiveSingleAggResultTypeTerm = primitiveTypeTermForTypeInfo(singleAggResultType)
+      val boxedSingleAggResultTypeTerm = boxedTypeTermForTypeInfo(singleAggResultType)
+
+      val allAggRowTerm = s"aggRow_$variableUID"
+
+      val rowsForVariableCode = findEventsByPatternName(variable)
+      val codeForAgg =
+        j"""
+           |$rowTypeTerm $allAggRowTerm = $calculateAggFuncName(${rowsForVariableCode.resultTerm});
+           |""".stripMargin
+
+      reusablePerRecordStatements += codeForAgg
+
+      val defaultValue = primitiveDefaultValue(singleAggResultType)
+      val codeForSingleAgg = if (nullCheck) {
+        j"""
+           |boolean $singleAggNullTerm;
+           |$primitiveSingleAggResultTypeTerm $singleAggResultTerm;
+           |if ($allAggRowTerm.getField(${aggregates.size}) != null) {
+           |  $singleAggResultTerm = ($boxedSingleAggResultTypeTerm) $allAggRowTerm
+           |    .getField(${aggregates.size});
+           |  $singleAggNullTerm = false;
+           |} else {
+           |  $singleAggNullTerm = true;
+           |  $singleAggResultTerm = $defaultValue;
+           |}
+           |""".stripMargin
+      } else {
+        j"""
+           |$primitiveSingleAggResultTypeTerm $singleAggResultTerm =
+           |    ($boxedSingleAggResultTypeTerm) $allAggRowTerm.getField(${aggregates.size});
+           |""".stripMargin
+      }
+
+      reusablePerRecordStatements += codeForSingleAgg
+
+      GeneratedExpression(singleAggResultTerm, singleAggNullTerm, NO_CODE, singleAggResultType)
+    }
+
+    def generateAggFunction(): Unit = {
+      val matchAgg = extractAggregatesAndExpressions
+
+      val aggGenerator = new AggregationCodeGenerator(config, false, input, None)
+
+      val aggFunc = aggGenerator.generateAggregations(
+        s"AggFunction_$variableUID",
+        matchAgg.inputExprs.map(r => FlinkTypeFactory.toTypeInfo(r.getType)),
+        matchAgg.aggregations.map(_.aggFunction).toArray,
+        matchAgg.aggregations.map(_.inputIndices).toArray,
+        matchAgg.aggregations.indices.toArray,
+        Array.fill(matchAgg.aggregations.size)(false),
+        isStateBackedDataViews = false,
+        partialResults = false,
+        Array.emptyIntArray,
+        None,
+        matchAgg.aggregations.size,
+        needRetract = false,
+        needMerge = false,
+        needReset = false,
+        None
+      )
+
+      reusableMemberStatements.add(aggFunc.code)
+
+      val transformFuncName = s"transformRowForAgg_$variableUID"
+      val inputTransform: String = generateAggInputExprEvaluation(
+        matchAgg.inputExprs,
+        transformFuncName)
+
+      generateAggCalculation(aggFunc, transformFuncName, inputTransform)
+    }
+
+    private def extractAggregatesAndExpressions: MatchAgg = {
+      val inputRows = new mutable.LinkedHashMap[String, (RexNode, Int)]
+
+      val logicalAggregates = aggregates.map(aggCall => {
+        val callsWithIndices = aggCall.operands.asScala.map(innerCall => {
+          inputRows.get(innerCall.toString) match {
+            case Some(x) =>
+              x
+
+            case None =>
+              val callWithIndex = (innerCall, inputRows.size)
+              inputRows(innerCall.toString) = callWithIndex
+              callWithIndex
+          }
+        })
+
+        val agg = aggCall.getOperator.asInstanceOf[SqlAggFunction]
+        LogicalSingleAggCall(agg,
+          callsWithIndices.map(_._1.getType),
+          callsWithIndices.map(_._2).toArray)
+      })
+
+      val aggs = logicalAggregates.zipWithIndex.map {
+        case (agg, index) =>
+          val result = AggregateUtil.extractAggregateCallMetadata(
+            agg.function,
+            isDistinct = false, // TODO properly set once supported in Calcite
+            agg.inputTypes,
+            needRetraction = false,
+            config,
+            isStateBackedDataViews = false,
+            index)
+
+          SingleAggCall(result.aggregateFunction, agg.exprIndices.toArray, result.accumulatorSpecs)
+      }
+
+      MatchAgg(aggs, inputRows.values.map(_._1).toSeq)
+    }
+
+    private def generateAggCalculation(
+        aggFunc: GeneratedAggregationsFunction,
+        transformFuncName: String,
+        inputTransformFunc: String)
+      : Unit = {
+      val aggregatorTerm = s"aggregator_$variableUID"
+      val code =
+        j"""
+           |private final ${aggFunc.name} $aggregatorTerm;
+           |
+           |$inputTransformFunc
+           |
+           |private $rowTypeTerm $calculateAggFuncName(java.util.List input)
+           |    throws Exception {
+           |  $rowTypeTerm accumulator = $aggregatorTerm.createAccumulators();
+           |  for ($rowTypeTerm row : input) {
+           |    $aggregatorTerm.accumulate(accumulator, $transformFuncName(row));
+           |  }
+           |  $rowTypeTerm result = $aggregatorTerm.createOutputRow();
+           |  $aggregatorTerm.setAggregationResults(accumulator, result);
+           |  return result;
+           |}
+         """.stripMargin
+
+      reusableInitStatements.add(s"$aggregatorTerm = new ${aggFunc.name}();")
+      reusableOpenStatements.add(s"$aggregatorTerm.open(getRuntimeContext());")
+      reusableCloseStatements.add(s"$aggregatorTerm.close();")
+      reusableMemberStatements.add(code)
+    }
+
+    private def generateAggInputExprEvaluation(
+        inputExprs: Seq[RexNode],
+        funcName: String)
+      : String = {
+      isWithinAggExprState = true
+      val resultTerm = newName("result")
+      val exprs = inputExprs.zipWithIndex.map {
+        case (inputExpr, outputIndex) => {
+          val expr = generateExpression(inputExpr)
+          s"""
+             |${expr.code}
+             |if (${expr.nullTerm}) {
+             |  $resultTerm.setField($outputIndex, null);
+             |} else {
+             |  $resultTerm.setField($outputIndex, ${expr.resultTerm});
+             |}
+         """.stripMargin
+        }
+      }.mkString("\n")
+      isWithinAggExprState = false
+
+      j"""
+         |private $rowTypeTerm $funcName($rowTypeTerm $inputAggRowTerm) {
+         |  $rowTypeTerm $resultTerm = new $rowTypeTerm(${inputExprs.size});
+         |  $exprs
+         |  return $resultTerm;
+         |}
+       """.stripMargin
+    }
+
+    private case class LogicalSingleAggCall(
+      function: SqlAggFunction,
+      inputTypes: Seq[RelDataType],
+      exprIndices: Seq[Int]
+    )
+
+    private case class SingleAggCall(
+      aggFunction: TableAggregateFunction[_, _],
+      inputIndices: Array[Int],
+      dataViews: Seq[DataViewSpec[_]]
+    )
+
+    private case class MatchAgg(
+      aggregations: Seq[SingleAggCall],
+      inputExprs: Seq[RexNode]
+    )
+
+  }
+
 }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamMatchRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamMatchRule.scala
index 5b0aa65..bc0f56e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamMatchRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamMatchRule.scala
@@ -18,15 +18,21 @@
 
 package org.apache.flink.table.plan.rules.datastream
 
-import org.apache.calcite.plan.{RelOptRule, RelTraitSet}
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet}
 import org.apache.calcite.rel.RelNode
 import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.calcite.rex.{RexCall, RexNode}
+import org.apache.calcite.sql.SqlAggFunction
 import org.apache.flink.table.api.TableException
 import org.apache.flink.table.plan.logical.MatchRecognize
 import org.apache.flink.table.plan.nodes.FlinkConventions
 import org.apache.flink.table.plan.nodes.datastream.DataStreamMatch
 import org.apache.flink.table.plan.nodes.logical.FlinkLogicalMatch
 import org.apache.flink.table.plan.schema.RowSchema
+import org.apache.flink.table.plan.util.RexDefaultVisitor
+import org.apache.flink.table.util.MatchUtil
+
+import scala.collection.JavaConverters._
 
 class DataStreamMatchRule
   extends ConverterRule(
@@ -35,6 +41,14 @@ class DataStreamMatchRule
     FlinkConventions.DATASTREAM,
     "DataStreamMatchRule") {
 
+  override def matches(call: RelOptRuleCall): Boolean = {
+    val logicalMatch: FlinkLogicalMatch = call.rel(0).asInstanceOf[FlinkLogicalMatch]
+
+    validateAggregations(logicalMatch.getMeasures.values().asScala)
+    validateAggregations(logicalMatch.getPatternDefinitions.values().asScala)
+    true
+  }
+
   override def convert(rel: RelNode): RelNode = {
     val logicalMatch: FlinkLogicalMatch = rel.asInstanceOf[FlinkLogicalMatch]
     val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASTREAM)
@@ -71,6 +85,30 @@ class DataStreamMatchRule
       new RowSchema(logicalMatch.getRowType),
       new RowSchema(logicalMatch.getInput.getRowType))
   }
+
+  private def validateAggregations(expr: Iterable[RexNode]): Unit = {
+    val validator = new AggregationsValidator
+    expr.foreach(_.accept(validator))
+  }
+
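+  /**
+    * Validates that each aggregate call references exactly one pattern variable.
+    * [[MatchUtil.AggregationPatternVariableFinder]] throws a [[ValidationException]] for
+    * aggregates that mix pattern variables.
+    */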
+  class AggregationsValidator extends RexDefaultVisitor[Object] {
+
+    override def visitCall(call: RexCall): AnyRef = {
+      call.getOperator match {
+        case _: SqlAggFunction =>
+          call.accept(new MatchUtil.AggregationPatternVariableFinder)
+        case _ =>
+          call.getOperands.asScala.foreach(_.accept(this))
+      }
+
+      null
+    }
+
+    override def visitNode(rexNode: RexNode): AnyRef = {
+      null
+    }
+  }
+
 }
 
 object DataStreamMatchRule {
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/util/MatchUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/util/MatchUtil.scala
new file mode 100644
index 0000000..0cce171
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/util/MatchUtil.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.util
+
+import org.apache.calcite.rex.{RexCall, RexNode, RexPatternFieldRef}
+import org.apache.flink.table.api.ValidationException
+import org.apache.flink.table.plan.util.RexDefaultVisitor
+import scala.collection.JavaConverters._
+
+object MatchUtil {
+  val ALL_PATTERN_VARIABLE = "*"
+
+  class AggregationPatternVariableFinder extends RexDefaultVisitor[Option[String]] {
+
+    override def visitPatternFieldRef(patternFieldRef: RexPatternFieldRef): Option[String] = Some(
+      patternFieldRef.getAlpha)
+
+    override def visitCall(call: RexCall): Option[String] = {
+      if (call.operands.size() == 0) {
+        Some(ALL_PATTERN_VARIABLE)
+      } else {
+        call.operands.asScala.map(n => n.accept(this)).reduce((op1, op2) => (op1, op2) match {
+          case (None, None) => None
+          case (x, None) => x
+          case (None, x) => x
+          case (Some(var1), Some(var2)) if var1.equals(var2) =>
+            Some(var1)
+          case _ =>
+            throw new ValidationException(s"Aggregation must be applied to a single pattern " +
+              s"variable. Malformed expression: $call")
+        })
+      }
+    }
+
+    override def visitNode(rexNode: RexNode): Option[String] = None
+  }
+}
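
For reference, how AggregationPatternVariableFinder resolves the pattern variable
of an aggregate call (behavior derived from the visitor above):

    SUM(A.price + A.tax)  ->  Some("A")             single pattern variable
    COUNT(*)              ->  Some("*")             no operands: ALL_PATTERN_VARIABLE
    SUM(A.price + B.tax)  ->  ValidationException   mixed pattern variables

Operands without pattern field references (e.g. literals) resolve to None and do
not affect the result.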
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/match/MatchOperatorValidationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/match/MatchRecognizeValidationTest.scala
similarity index 89%
rename from flink-libraries/flink-table/src/test/scala/org/apache/flink/table/match/MatchOperatorValidationTest.scala
rename to flink-libraries/flink-table/src/test/scala/org/apache/flink/table/match/MatchRecognizeValidationTest.scala
index e10a568..2890179 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/match/MatchOperatorValidationTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/match/MatchRecognizeValidationTest.scala
@@ -21,13 +21,13 @@ package org.apache.flink.table.`match`
 import org.apache.flink.api.scala._
 import org.apache.flink.table.api.scala._
 import org.apache.flink.table.api.{TableException, ValidationException}
-import org.apache.flink.table.codegen.CodeGenException
 import org.apache.flink.table.runtime.stream.sql.ToMillis
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.WeightedAvg
 import org.apache.flink.table.utils.TableTestBase
 import org.apache.flink.types.Row
 import org.junit.Test
 
-class MatchOperatorValidationTest extends TableTestBase {
+class MatchRecognizeValidationTest extends TableTestBase {
 
   private val streamUtils = streamTestUtil()
   streamUtils.addTable[(String, Long, Int, Int)]("Ticker",
@@ -128,15 +128,10 @@ class MatchOperatorValidationTest extends TableTestBase {
     streamUtils.tableEnv.sqlQuery(sqlQuery).toRetractStream[Row]
   }
 
-  // ***************************************************************************************
-  // * Those validations are temporary. We should remove those tests once we support those *
-  // * features.                                                                           *
-  // ***************************************************************************************
-
   @Test
-  def testAllRowsPerMatch(): Unit = {
-    thrown.expectMessage("All rows per match mode is not supported yet.")
-    thrown.expect(classOf[TableException])
+  def testAggregatesOnMultiplePatternVariablesNotSupported(): Unit = {
+    thrown.expect(classOf[ValidationException])
+    thrown.expectMessage("SQL validation failed.")
 
     val sqlQuery =
       s"""
@@ -145,11 +140,10 @@ class MatchOperatorValidationTest extends TableTestBase {
          |MATCH_RECOGNIZE (
          |  ORDER BY proctime
          |  MEASURES
-         |    A.symbol AS aSymbol
-         |  ALL ROWS PER MATCH
+         |    SUM(A.price + B.tax) AS taxedPrice
          |  PATTERN (A B)
          |  DEFINE
-         |    A AS symbol = 'a'
+         |    A AS A.symbol = 'a'
          |) AS T
          |""".stripMargin
 
@@ -157,10 +151,11 @@ class MatchOperatorValidationTest extends TableTestBase {
   }
 
   @Test
-  def testGreedyQuantifierAtTheEndIsNotSupported(): Unit = {
-    thrown.expectMessage("Greedy quantifiers are not allowed as the last element of a " +
-      "Pattern yet. Finish your pattern with either a simple variable or reluctant quantifier.")
-    thrown.expect(classOf[TableException])
+  def testAggregatesOnMultiplePatternVariablesNotSupportedInUDAGs(): Unit = {
+    thrown.expect(classOf[ValidationException])
+    thrown.expectMessage("Aggregation must be applied to a single pattern variable")
+
+    streamUtils.tableEnv.registerFunction("weightedAvg", new WeightedAvg)
 
     val sqlQuery =
       s"""
@@ -169,20 +164,24 @@ class MatchOperatorValidationTest extends TableTestBase {
          |MATCH_RECOGNIZE (
          |  ORDER BY proctime
          |  MEASURES
-         |    A.symbol AS aSymbol
-         |  PATTERN (A B+)
+         |    weightedAvg(A.price, B.tax) AS weightedAvg
+         |  PATTERN (A B)
          |  DEFINE
-         |    A AS symbol = 'a'
+         |    A AS A.symbol = 'a'
          |) AS T
          |""".stripMargin
 
     streamUtils.tableEnv.sqlQuery(sqlQuery).toAppendStream[Row]
   }
 
+  // ***************************************************************************************
+  // * Those validations are temporary. We should remove those tests once we support those *
+  // * features.                                                                           *
+  // ***************************************************************************************
+
   @Test
-  def testPatternsProducingEmptyMatchesAreNotSupported(): Unit = {
-    thrown.expectMessage("Patterns that can produce empty matches are not supported. " +
-      "There must be at least one non-optional state.")
+  def testAllRowsPerMatch(): Unit = {
+    thrown.expectMessage("All rows per match mode is not supported yet.")
     thrown.expect(classOf[TableException])
 
     val sqlQuery =
@@ -193,7 +192,8 @@ class MatchOperatorValidationTest extends TableTestBase {
          |  ORDER BY proctime
          |  MEASURES
          |    A.symbol AS aSymbol
-         |  PATTERN (A*)
+         |  ALL ROWS PER MATCH
+         |  PATTERN (A B)
          |  DEFINE
          |    A AS symbol = 'a'
          |) AS T
@@ -203,11 +203,10 @@ class MatchOperatorValidationTest extends TableTestBase {
   }
 
   @Test
-  def testAggregatesAreNotSupportedInMeasures(): Unit = {
-    thrown.expectMessage(
-      "Unsupported call: SUM \nIf you think this function should be supported, you can " +
-        "create an issue and start a discussion for it.")
-    thrown.expect(classOf[CodeGenException])
+  def testGreedyQuantifierAtTheEndIsNotSupported(): Unit = {
+    thrown.expectMessage("Greedy quantifiers are not allowed as the last element of a " +
+      "Pattern yet. Finish your pattern with either a simple variable or reluctant quantifier.")
+    thrown.expect(classOf[TableException])
 
     val sqlQuery =
       s"""
@@ -216,10 +215,10 @@ class MatchOperatorValidationTest extends TableTestBase {
          |MATCH_RECOGNIZE (
          |  ORDER BY proctime
          |  MEASURES
-         |    SUM(A.price + A.tax) AS cost
-         |  PATTERN (A B)
+         |    A.symbol AS aSymbol
+         |  PATTERN (A B+)
          |  DEFINE
-         |    A AS A.symbol = 'a'
+         |    A AS symbol = 'a'
          |) AS T
          |""".stripMargin
 
@@ -227,11 +226,10 @@ class MatchOperatorValidationTest extends TableTestBase {
   }
 
   @Test
-  def testAggregatesAreNotSupportedInDefine(): Unit = {
-    thrown.expectMessage(
-      "Unsupported call: SUM \nIf you think this function should be supported, you can " +
-        "create an issue and start a discussion for it.")
-    thrown.expect(classOf[CodeGenException])
+  def testPatternsProducingEmptyMatchesAreNotSupported(): Unit = {
+    thrown.expectMessage("Patterns that can produce empty matches are not supported. " +
+      "There must be at least one non-optional state.")
+    thrown.expect(classOf[TableException])
 
     val sqlQuery =
       s"""
@@ -240,10 +238,10 @@ class MatchOperatorValidationTest extends TableTestBase {
          |MATCH_RECOGNIZE (
          |  ORDER BY proctime
          |  MEASURES
-         |    B.price as bPrice
-         |  PATTERN (A+ B)
+         |    A.symbol AS aSymbol
+         |  PATTERN (A*)
          |  DEFINE
-         |    A AS SUM(A.price + A.tax) < 10
+         |    A AS symbol = 'a'
          |) AS T
          |""".stripMargin
 
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/MatchRecognizeITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/MatchRecognizeITCase.scala
index 8f5a8f3..ffd1c2b 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/MatchRecognizeITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/MatchRecognizeITCase.scala
@@ -22,12 +22,14 @@ import java.sql.Timestamp
 import java.util.TimeZone
 
 import org.apache.flink.api.common.time.Time
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo
 import org.apache.flink.api.scala._
 import org.apache.flink.streaming.api.TimeCharacteristic
 import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
-import org.apache.flink.table.api.{TableConfig, TableEnvironment}
 import org.apache.flink.table.api.scala._
-import org.apache.flink.table.functions.{FunctionContext, ScalarFunction}
+import org.apache.flink.table.api.{TableConfig, TableEnvironment, Types}
+import org.apache.flink.table.functions.{AggregateFunction, FunctionContext, ScalarFunction}
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.WeightedAvg
 import org.apache.flink.table.runtime.utils.TimeTestUtil.EventTimeSourceFunction
 import org.apache.flink.table.runtime.utils.{StreamITCase, StreamingWithStateTestBase, UserDefinedFunctionTestUtils}
 import org.apache.flink.types.Row
@@ -39,7 +41,7 @@ import scala.collection.mutable
 class MatchRecognizeITCase extends StreamingWithStateTestBase {
 
   @Test
-  def testSimpleCEP(): Unit = {
+  def testSimplePattern(): Unit = {
     val env = StreamExecutionEnvironment.getExecutionEnvironment
     env.setParallelism(1)
     val tEnv = TableEnvironment.getTableEnvironment(env)
@@ -86,7 +88,7 @@ class MatchRecognizeITCase extends StreamingWithStateTestBase {
   }
 
   @Test
-  def testSimpleCEPWithNulls(): Unit = {
+  def testSimplePatternWithNulls(): Unit = {
     val env = StreamExecutionEnvironment.getExecutionEnvironment
     env.setParallelism(1)
     val tEnv = TableEnvironment.getTableEnvironment(env)
@@ -464,6 +466,132 @@ class MatchRecognizeITCase extends StreamingWithStateTestBase {
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
 
+  /**
+    * This query checks:
+    *
+    * 1. count(D.price) produces 0, because no rows were matched to D
+    * 2. sum(D.price) produces null, because no rows were matched to D
+    * 3. aggregates that take multiple parameters work
+    * 4. aggregates with expressions work
+    */
+  @Test
+  def testAggregates(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setParallelism(1)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    tEnv.getConfig.setMaxGeneratedCodeLength(1)
+    StreamITCase.clear
+
+    val data = new mutable.MutableList[(Int, String, Long, Double, Int)]
+    data.+=((1, "a", 1, 0.8, 1))
+    data.+=((2, "z", 2, 0.8, 3))
+    data.+=((3, "b", 1, 0.8, 2))
+    data.+=((4, "c", 1, 0.8, 5))
+    data.+=((5, "d", 4, 0.1, 5))
+    data.+=((6, "a", 2, 1.5, 2))
+    data.+=((7, "b", 2, 0.8, 3))
+    data.+=((8, "c", 1, 0.8, 2))
+    data.+=((9, "h", 4, 0.8, 3))
+    data.+=((10, "h", 4, 0.8, 3))
+    data.+=((11, "h", 2, 0.8, 3))
+    data.+=((12, "h", 2, 0.8, 3))
+
+    val t = env.fromCollection(data)
+      .toTable(tEnv, 'id, 'name, 'price, 'rate, 'weight, 'proctime.proctime)
+    tEnv.registerTable("MyTable", t)
+    tEnv.registerFunction("weightedAvg", new WeightedAvg)
+
+    val sqlQuery =
+      s"""
+         |SELECT *
+         |FROM MyTable
+         |MATCH_RECOGNIZE (
+         |  ORDER BY proctime
+         |  MEASURES
+         |    FIRST(id) as startId,
+         |    SUM(A.price) AS sumA,
+         |    COUNT(DISTINCT D.price) AS countD,
+         |    SUM(D.price) as sumD,
+         |    weightedAvg(price, weight) as wAvg,
+         |    AVG(B.price) AS avgB,
+         |    SUM(B.price * B.rate) as sumExprB,
+         |    LAST(id) as endId
+         |  AFTER MATCH SKIP PAST LAST ROW
+         |  PATTERN (A+ B+ C D? E )
+         |  DEFINE
+         |    A AS SUM(A.price) < 6,
+         |    B AS SUM(B.price * B.rate) < SUM(A.price) AND
+         |         SUM(B.price * B.rate) > 0.2 AND
+         |         SUM(B.price) >= 1 AND
+         |         AVG(B.price) >= 1 AND
+         |         weightedAvg(price, weight) > 1
+         |) AS T
+         |""".stripMargin
+
+    val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
+    result.addSink(new StreamITCase.StringSink[Row])
+    env.execute()
+
+    val expected = mutable.MutableList("1,5,0,null,2,3,3.4,8", "9,4,0,null,3,4,3.2,12")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
+
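+  /**
+    * This query checks that aggregates handle null inputs according to SQL semantics:
+    * SUM and COUNT(A.price) skip null prices, while COUNT(A.id) and COUNT(*) count every
+    * row matched to the corresponding variable(s).
+    */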
+  @Test
+  def testAggregatesWithNullInputs(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setParallelism(1)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    tEnv.getConfig.setMaxGeneratedCodeLength(1)
+    StreamITCase.clear
+
+    val data = new mutable.MutableList[Row]
+    data.+=(Row.of(Int.box(1), "a", Int.box(10)))
+    data.+=(Row.of(Int.box(2), "z", Int.box(10)))
+    data.+=(Row.of(Int.box(3), "b", null))
+    data.+=(Row.of(Int.box(4), "c", null))
+    data.+=(Row.of(Int.box(5), "d", Int.box(3)))
+    data.+=(Row.of(Int.box(6), "c", Int.box(3)))
+    data.+=(Row.of(Int.box(7), "c", Int.box(3)))
+    data.+=(Row.of(Int.box(8), "c", Int.box(3)))
+    data.+=(Row.of(Int.box(9), "c", Int.box(2)))
+
+    val t = env.fromCollection(data)(Types.ROW(
+      BasicTypeInfo.INT_TYPE_INFO,
+      BasicTypeInfo.STRING_TYPE_INFO,
+      BasicTypeInfo.INT_TYPE_INFO))
+      .toTable(tEnv, 'id, 'name, 'price, 'proctime.proctime)
+    tEnv.registerTable("MyTable", t)
+    tEnv.registerFunction("weightedAvg", new WeightedAvg)
+
+    val sqlQuery =
+      s"""
+         |SELECT *
+         |FROM MyTable
+         |MATCH_RECOGNIZE (
+         |  ORDER BY proctime
+         |  MEASURES
+         |    SUM(A.price) as sumA,
+         |    COUNT(A.id) as countAId,
+         |    COUNT(A.price) as countAPrice,
+         |    COUNT(*) as countAll,
+         |    COUNT(price) as countAllPrice,
+         |    LAST(id) as endId
+         |  AFTER MATCH SKIP PAST LAST ROW
+         |  PATTERN (A+ C)
+         |  DEFINE
+         |    A AS SUM(A.price) < 30,
+         |    C AS C.name = 'c'
+         |) AS T
+         |""".stripMargin
+
+    val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
+    result.addSink(new StreamITCase.StringSink[Row])
+    env.execute()
+
+    val expected = mutable.MutableList("29,7,5,8,6,8")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
+
   @Test
   def testAccessingProctime(): Unit = {
     val env = StreamExecutionEnvironment.getExecutionEnvironment
@@ -567,9 +695,11 @@ class MatchRecognizeITCase extends StreamingWithStateTestBase {
       .toTable(tEnv, 'id, 'name, 'price, 'proctime.proctime)
     tEnv.registerTable("MyTable", t)
     tEnv.registerFunction("prefix", new PrefixingScalarFunc)
+    tEnv.registerFunction("countFrom", new RichAggFunc)
     val prefix = "PREF"
+    val startFrom = 4
     UserDefinedFunctionTestUtils
-      .setJobParameters(env, Map("prefix" -> prefix))
+      .setJobParameters(env, Map("prefix" -> prefix, "start" -> startFrom.toString))
 
     val sqlQuery =
       s"""
@@ -580,11 +710,12 @@ class MatchRecognizeITCase extends StreamingWithStateTestBase {
          |  MEASURES
          |    FIRST(id) as firstId,
          |    prefix(A.name) as prefixedNameA,
+         |    countFrom(A.price) as countFromA,
          |    LAST(id) as lastId
          |  AFTER MATCH SKIP PAST LAST ROW
          |  PATTERN (A+ C)
          |  DEFINE
-         |    A AS prefix(A.name) = '$prefix:a'
+         |    A AS prefix(A.name) = '$prefix:a' AND countFrom(A.price) <= ${startFrom + 4}
          |) AS T
          |""".stripMargin
 
@@ -592,7 +723,7 @@ class MatchRecognizeITCase extends StreamingWithStateTestBase {
     result.addSink(new StreamITCase.StringSink[Row])
     env.execute()
 
-    val expected = mutable.MutableList("1,PREF:a,6", "7,PREF:a,9")
+    val expected = mutable.MutableList("1,PREF:a,8,5", "7,PREF:a,6,9")
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
 }
@@ -615,3 +746,26 @@ private class PrefixingScalarFunc extends ScalarFunction {
     s"$prefix:$value"
   }
 }
+
+private case class CountAcc(var count: Long)
+
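+/**
+  * Test [[AggregateFunction]] that sums its inputs into an accumulator initialized from the
+  * "start" job parameter in [[open()]], verifying that user-defined aggregates used in
+  * MATCH_RECOGNIZE can access the [[FunctionContext]].
+  */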
+private class RichAggFunc extends AggregateFunction[Long, CountAcc] {
+
+  private var start : Long = 0
+
+  override def open(context: FunctionContext): Unit = {
+    start = context.getJobParameter("start", "0").toLong
+  }
+
+  override def close(): Unit = {
+    start = 0
+  }
+
+  def accumulate(countAcc: CountAcc, value: Long): Unit = {
+    countAcc.count += value
+  }
+
+  override def createAccumulator(): CountAcc = CountAcc(start)
+
+  override def getValue(accumulator: CountAcc): Long = accumulator.count
+}