Posted to commits@flink.apache.org by dw...@apache.org on 2018/12/17 08:25:16 UTC

[flink] 01/06: [FLINK-7599][table] Refactored AggregateUtil#transformToAggregateFunctions

This is an automated email from the ASF dual-hosted git repository.

dwysakowicz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit eab3ce7faa57995a39845a8f66cd0d44f18bd4ed
Author: Dawid Wysakowicz <dw...@apache.org>
AuthorDate: Mon Nov 26 14:23:22 2018 +0100

    [FLINK-7599][table] Refactored AggregateUtil#transformToAggregateFunctions
---
 .../table/runtime/aggregate/AggregateUtil.scala    | 939 ++++++++++++---------
 1 file changed, 526 insertions(+), 413 deletions(-)
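
For context: this refactoring collapses the five-element tuple previously returned by transformToAggregateFunctions into a single AggregateMetadata holder with named accessors, so call sites no longer destructure positional values. The following condensed sketch of the call-site shape is illustrative only and is not taken verbatim from the commit; namedAggregates, groupings, inputRowType and needRetract stand in for the usual surrounding parameters, and the old* names are renamed here only so both variants can be shown side by side.

    // Old shape: positional destructuring of a 5-tuple.
    val (aggFields, aggregates, isDistinctAggs, accTypes, accSpecs) =
      transformToAggregateFunctions(
        namedAggregates.map(_.getKey), inputRowType, needRetract, tableConfig)
    val oldAggMapping = aggregates.indices.map(_ + groupings.length).toArray
    val oldAccRowType = new RowTypeInfo(accTypes: _*)

    // New shape: one AggregateMetadata object, queried through named accessors.
    val aggregateMetadata = extractAggregateMetadata(
      namedAggregates.map(_.getKey), inputRowType, needRetract, tableConfig)
    val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length)
    val accRowType = new RowTypeInfo(aggregateMetadata.getAggregatesAccumulatorTypes: _*)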

diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index 4a50855..f1386df 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -85,27 +85,26 @@ object AggregateUtil {
       isRowsClause: Boolean)
     : ProcessFunction[CRow, CRow] = {
 
-    val (aggFields, aggregates, isDistinctAggs, accTypes, accSpecs) =
-      transformToAggregateFunctions(
+    val aggregateMetadata = extractAggregateMetadata(
         namedAggregates.map(_.getKey),
         aggregateInputType,
         needRetraction = false,
         tableConfig,
         isStateBackedDataViews = true)
 
-    val aggregationStateType: RowTypeInfo = new RowTypeInfo(accTypes: _*)
 
     val forwardMapping = (0 until inputType.getFieldCount).toArray
-    val aggMapping = aggregates.indices.map(x => x + inputType.getFieldCount).toArray
-    val outputArity = inputType.getFieldCount + aggregates.length
+    val aggMapping = aggregateMetadata.getAdjustedMapping(inputType.getFieldCount)
+
+    val outputArity = inputType.getFieldCount + aggregateMetadata.getAggregateCallsCount
 
     val genFunction = generator.generateAggregations(
       "UnboundedProcessingOverAggregateHelper",
       inputFieldTypeInfo,
-      aggregates,
-      aggFields,
+      aggregateMetadata.getAggregateFunctions,
+      aggregateMetadata.getAggregateIndices,
       aggMapping,
-      isDistinctAggs,
+      aggregateMetadata.getAggregatesDistinctFlags,
       isStateBackedDataViews = true,
       partialResults = false,
       forwardMapping,
@@ -114,9 +113,11 @@ object AggregateUtil {
       needRetract = false,
       needMerge = false,
       needReset = false,
-      accConfig = Some(accSpecs)
+      accConfig = Some(aggregateMetadata.getAggregatesAccumulatorSpecs)
     )
 
+    val aggregationStateType: RowTypeInfo = new RowTypeInfo(aggregateMetadata
+      .getAggregatesAccumulatorTypes: _*)
     if (rowTimeIdx.isDefined) {
       if (isRowsClause) {
         // ROWS unbounded over process function
@@ -168,27 +169,23 @@ object AggregateUtil {
       generateRetraction: Boolean,
       consumeRetraction: Boolean): ProcessFunction[CRow, CRow] = {
 
-    val (aggFields, aggregates, isDistinctAggs, accTypes, accSpecs) =
-      transformToAggregateFunctions(
+    val aggregateMetadata = extractAggregateMetadata(
         namedAggregates.map(_.getKey),
         inputRowType,
         consumeRetraction,
         tableConfig,
         isStateBackedDataViews = true)
 
-    val aggMapping = aggregates.indices.map(_ + groupings.length).toArray
-
-    val outputArity = groupings.length + aggregates.length
-
-    val aggregationStateType: RowTypeInfo = new RowTypeInfo(accTypes: _*)
+    val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length)
+    val outputArity = groupings.length + aggregateMetadata.getAggregateCallsCount
 
     val genFunction = generator.generateAggregations(
       "NonWindowedAggregationHelper",
       inputFieldTypes,
-      aggregates,
-      aggFields,
+      aggregateMetadata.getAggregateFunctions,
+      aggregateMetadata.getAggregateIndices,
       aggMapping,
-      isDistinctAggs,
+      aggregateMetadata.getAggregatesDistinctFlags,
       isStateBackedDataViews = true,
       partialResults = false,
       groupings,
@@ -197,9 +194,11 @@ object AggregateUtil {
       consumeRetraction,
       needMerge = false,
       needReset = false,
-      accConfig = Some(accSpecs)
+      accConfig = Some(aggregateMetadata.getAggregatesAccumulatorSpecs)
     )
 
+    val aggregationStateType: RowTypeInfo = new RowTypeInfo(aggregateMetadata
+      .getAggregatesAccumulatorTypes: _*)
     new GroupAggProcessFunction(
       genFunction,
       aggregationStateType,
@@ -238,28 +237,27 @@ object AggregateUtil {
     : ProcessFunction[CRow, CRow] = {
 
     val needRetract = true
-    val (aggFields, aggregates, isDistinctAggs, accTypes, accSpecs) =
-      transformToAggregateFunctions(
+    val aggregateMetadata = extractAggregateMetadata(
         namedAggregates.map(_.getKey),
         aggregateInputType,
         needRetract,
         tableConfig,
         isStateBackedDataViews = true)
 
-    val aggregationStateType: RowTypeInfo = new RowTypeInfo(accTypes: _*)
     val inputRowType = CRowTypeInfo(inputTypeInfo)
 
     val forwardMapping = (0 until inputType.getFieldCount).toArray
-    val aggMapping = aggregates.indices.map(x => x + inputType.getFieldCount).toArray
-    val outputArity = inputType.getFieldCount + aggregates.length
+    val aggMapping = aggregateMetadata.getAdjustedMapping(inputType.getFieldCount)
+
+    val outputArity = inputType.getFieldCount + aggregateMetadata.getAggregateCallsCount
 
     val genFunction = generator.generateAggregations(
       "BoundedOverAggregateHelper",
       inputFieldTypeInfo,
-      aggregates,
-      aggFields,
+      aggregateMetadata.getAggregateFunctions,
+      aggregateMetadata.getAggregateIndices,
       aggMapping,
-      isDistinctAggs,
+      aggregateMetadata.getAggregatesDistinctFlags,
       isStateBackedDataViews = true,
       partialResults = false,
       forwardMapping,
@@ -268,9 +266,11 @@ object AggregateUtil {
       needRetract,
       needMerge = false,
       needReset = false,
-      accConfig = Some(accSpecs)
+      accConfig = Some(aggregateMetadata.getAggregatesAccumulatorSpecs)
     )
 
+    val aggregationStateType: RowTypeInfo = new RowTypeInfo(aggregateMetadata
+      .getAggregatesAccumulatorTypes: _*)
     if (rowTimeIdx.isDefined) {
       if (isRowsClause) {
         new RowTimeBoundedRowsOver(
@@ -343,7 +343,7 @@ object AggregateUtil {
   : MapFunction[Row, Row] = {
 
     val needRetract = false
-    val (aggFieldIndexes, aggregates, isDistinctAggs, accTypes, _) = transformToAggregateFunctions(
+    val aggregateMetadata = extractAggregateMetadata(
       namedAggregates.map(_.getKey),
       inputType,
       needRetract,
@@ -352,8 +352,8 @@ object AggregateUtil {
     val mapReturnType: RowTypeInfo =
       createRowTypeForKeysAndAggregates(
         groupings,
-        aggregates,
-        accTypes,
+        aggregateMetadata.getAggregateFunctions,
+        aggregateMetadata.getAggregatesAccumulatorTypes,
         inputType,
         Some(Array(BasicTypeInfo.LONG_TYPE_INFO)))
 
@@ -385,16 +385,16 @@ object AggregateUtil {
         throw new UnsupportedOperationException(s"$window is currently not supported on batch")
     }
 
-    val aggMapping = aggregates.indices.toArray.map(_ + groupings.length)
-    val outputArity = aggregates.length + groupings.length + 1
+    val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length)
+    val outputArity = aggregateMetadata.getAggregateCallsCount + groupings.length + 1
 
     val genFunction = generator.generateAggregations(
       "DataSetAggregatePrepareMapHelper",
       inputFieldTypeInfo,
-      aggregates,
-      aggFieldIndexes,
+      aggregateMetadata.getAggregateFunctions,
+      aggregateMetadata.getAggregateIndices,
       aggMapping,
-      isDistinctAggs,
+      aggregateMetadata.getAggregatesDistinctFlags,
       isStateBackedDataViews = false,
       partialResults = true,
       groupings,
@@ -452,7 +452,7 @@ object AggregateUtil {
     : RichGroupReduceFunction[Row, Row] = {
 
     val needRetract = false
-    val (aggFieldIndexes, aggregates, isDistinctAggs, accTypes, _) = transformToAggregateFunctions(
+    val aggregateMetadata = extractAggregateMetadata(
       namedAggregates.map(_.getKey),
       physicalInputRowType,
       needRetract,
@@ -460,8 +460,8 @@ object AggregateUtil {
 
     val returnType: RowTypeInfo = createRowTypeForKeysAndAggregates(
       groupings,
-      aggregates,
-      accTypes,
+      aggregateMetadata.getAggregateFunctions,
+      aggregateMetadata.getAggregatesAccumulatorTypes,
       physicalInputRowType,
       Some(Array(BasicTypeInfo.LONG_TYPE_INFO)))
 
@@ -470,17 +470,18 @@ object AggregateUtil {
     window match {
       case SlidingGroupWindow(_, _, size, slide) if isTimeInterval(size.resultType) =>
         // sliding time-window for partial aggregations
+        val aggMappings = aggregateMetadata.getAdjustedMapping(groupings.length)
         val genFunction = generator.generateAggregations(
           "DataSetAggregatePrepareMapHelper",
           physicalInputTypes,
-          aggregates,
-          aggFieldIndexes,
-          aggregates.indices.map(_ + groupings.length).toArray,
-          isDistinctAggs,
+          aggregateMetadata.getAggregateFunctions,
+          aggregateMetadata.getAggregateIndices,
+          aggMappings,
+          aggregateMetadata.getAggregatesDistinctFlags,
           isStateBackedDataViews = false,
           partialResults = true,
           groupings.indices.toArray,
-          Some(aggregates.indices.map(_ + groupings.length).toArray),
+          Some(aggMappings),
           keysAndAggregatesArity + 1,
           needRetract,
           needMerge = true,
@@ -569,25 +570,25 @@ object AggregateUtil {
     : RichGroupReduceFunction[Row, Row] = {
 
     val needRetract = false
-    val (aggFieldIndexes, aggregates, isDistinctAggs, _, _) = transformToAggregateFunctions(
+    val aggregateMetadata = extractAggregateMetadata(
       namedAggregates.map(_.getKey),
       physicalInputRowType,
       needRetract,
       tableConfig)
 
-    val aggMapping = aggregates.indices.toArray.map(_ + groupings.length)
+    val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length)
 
     val genPreAggFunction = generator.generateAggregations(
       "GroupingWindowAggregateHelper",
       physicalInputTypes,
-      aggregates,
-      aggFieldIndexes,
+      aggregateMetadata.getAggregateFunctions,
+      aggregateMetadata.getAggregateIndices,
       aggMapping,
-      isDistinctAggs,
+      aggregateMetadata.getAggregatesDistinctFlags,
       isStateBackedDataViews = false,
       partialResults = true,
       groupings.indices.toArray,
-      Some(aggregates.indices.map(_ + groupings.length).toArray),
+      Some(aggMapping),
       outputType.getFieldCount,
       needRetract,
       needMerge = true,
@@ -598,14 +599,14 @@ object AggregateUtil {
     val genFinalAggFunction = generator.generateAggregations(
       "GroupingWindowAggregateHelper",
       physicalInputTypes,
-      aggregates,
-      aggFieldIndexes,
+      aggregateMetadata.getAggregateFunctions,
+      aggregateMetadata.getAggregateIndices,
       aggMapping,
-      isDistinctAggs,
+      aggregateMetadata.getAggregatesDistinctFlags,
       isStateBackedDataViews = false,
       partialResults = false,
       groupings.indices.toArray,
-      Some(aggregates.indices.map(_ + groupings.length).toArray),
+      Some(aggMapping),
       outputType.getFieldCount,
       needRetract,
       needMerge = true,
@@ -619,7 +620,7 @@ object AggregateUtil {
       case TumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) =>
         // tumbling time window
         val (startPos, endPos, timePos) = computeWindowPropertyPos(properties)
-        if (doAllSupportPartialMerge(aggregates)) {
+        if (doAllSupportPartialMerge(aggregateMetadata.getAggregateFunctions)) {
           // for incremental aggregations
           new DataSetTumbleTimeWindowAggReduceCombineFunction(
             genPreAggFunction,
@@ -659,7 +660,7 @@ object AggregateUtil {
 
       case SlidingGroupWindow(_, _, size, _) if isTimeInterval(size.resultType) =>
         val (startPos, endPos, timePos) = computeWindowPropertyPos(properties)
-        if (doAllSupportPartialMerge(aggregates)) {
+        if (doAllSupportPartialMerge(aggregateMetadata.getAggregateFunctions)) {
           // for partial aggregations
           new DataSetSlideWindowAggReduceCombineFunction(
             genPreAggFunction,
@@ -726,13 +727,13 @@ object AggregateUtil {
     tableConfig: TableConfig): MapPartitionFunction[Row, Row] = {
 
     val needRetract = false
-    val (aggFieldIndexes, aggregates, isDistinctAggs, accTypes, _) = transformToAggregateFunctions(
+    val aggregateMetadata = extractAggregateMetadata(
       namedAggregates.map(_.getKey),
       physicalInputRowType,
       needRetract,
       tableConfig)
 
-    val aggMapping = aggregates.indices.map(_ + groupings.length).toArray
+    val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length)
 
     val keysAndAggregatesArity = groupings.length + namedAggregates.length
 
@@ -741,23 +742,23 @@ object AggregateUtil {
         val combineReturnType: RowTypeInfo =
           createRowTypeForKeysAndAggregates(
             groupings,
-            aggregates,
-            accTypes,
+            aggregateMetadata.getAggregateFunctions,
+            aggregateMetadata.getAggregatesAccumulatorTypes,
             physicalInputRowType,
             Option(Array(BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO)))
 
         val genFunction = generator.generateAggregations(
           "GroupingWindowAggregateHelper",
           physicalInputTypes,
-          aggregates,
-          aggFieldIndexes,
+          aggregateMetadata.getAggregateFunctions,
+          aggregateMetadata.getAggregateIndices,
           aggMapping,
-          isDistinctAggs,
+          aggregateMetadata.getAggregatesDistinctFlags,
           isStateBackedDataViews = false,
           partialResults = true,
           groupings.indices.toArray,
-          Some(aggregates.indices.map(_ + groupings.length).toArray),
-          groupings.length + aggregates.length + 2,
+          Some(aggMapping),
+          groupings.length + aggregateMetadata.getAggregateCallsCount + 2,
           needRetract,
           needMerge = true,
           needReset = true,
@@ -803,14 +804,13 @@ object AggregateUtil {
     : GroupCombineFunction[Row, Row] = {
 
     val needRetract = false
-    val (aggFieldIndexes, aggregates, isDistinctAggs, accTypes, _) = transformToAggregateFunctions(
+    val aggregateMetadata = extractAggregateMetadata(
       namedAggregates.map(_.getKey),
       physicalInputRowType,
       needRetract,
       tableConfig)
 
-    val aggMapping = aggregates.indices.map(_ + groupings.length).toArray
-
+    val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length)
     val keysAndAggregatesArity = groupings.length + namedAggregates.length
 
     window match {
@@ -819,23 +819,23 @@ object AggregateUtil {
         val combineReturnType: RowTypeInfo =
           createRowTypeForKeysAndAggregates(
             groupings,
-            aggregates,
-            accTypes,
+            aggregateMetadata.getAggregateFunctions,
+            aggregateMetadata.getAggregatesAccumulatorTypes,
             physicalInputRowType,
             Option(Array(BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO)))
 
         val genFunction = generator.generateAggregations(
           "GroupingWindowAggregateHelper",
           physicalInputTypes,
-          aggregates,
-          aggFieldIndexes,
+          aggregateMetadata.getAggregateFunctions,
+          aggregateMetadata.getAggregateIndices,
           aggMapping,
-          isDistinctAggs,
+          aggregateMetadata.getAggregatesDistinctFlags,
           isStateBackedDataViews = false,
           partialResults = true,
           groupings.indices.toArray,
-          Some(aggregates.indices.map(_ + groupings.length).toArray),
-          groupings.length + aggregates.length + 2,
+          Some(aggMapping),
+          keysAndAggregatesArity + 2,
           needRetract,
           needMerge = true,
           needReset = true,
@@ -873,7 +873,7 @@ object AggregateUtil {
         Either[DataSetAggFunction, DataSetFinalAggFunction]) = {
 
     val needRetract = false
-    val (aggInFields, aggregates, isDistinctAggs, accTypes, _) = transformToAggregateFunctions(
+    val aggregateMetadata = extractAggregateMetadata(
       namedAggregates.map(_.getKey),
       inputType,
       needRetract,
@@ -888,26 +888,28 @@ object AggregateUtil {
 
     val aggOutFields = aggOutMapping.map(_._1)
 
-    if (doAllSupportPartialMerge(aggregates)) {
+    if (doAllSupportPartialMerge(aggregateMetadata.getAggregateFunctions)) {
+
+      val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length)
 
       // compute preaggregation type
       val preAggFieldTypes = gkeyOutMapping.map(_._2)
         .map(inputType.getFieldList.get(_).getType)
-        .map(FlinkTypeFactory.toTypeInfo) ++ accTypes
+        .map(FlinkTypeFactory.toTypeInfo) ++ aggregateMetadata.getAggregatesAccumulatorTypes
       val preAggRowType = new RowTypeInfo(preAggFieldTypes: _*)
 
       val genPreAggFunction = generator.generateAggregations(
         "DataSetAggregatePrepareMapHelper",
         inputFieldTypeInfo,
-        aggregates,
-        aggInFields,
-        aggregates.indices.map(_ + groupings.length).toArray,
-        isDistinctAggs,
+        aggregateMetadata.getAggregateFunctions,
+        aggregateMetadata.getAggregateIndices,
+        aggMapping,
+        aggregateMetadata.getAggregatesDistinctFlags,
         isStateBackedDataViews = false,
         partialResults = true,
         groupings,
         None,
-        groupings.length + aggregates.length,
+        groupings.length + aggregateMetadata.getAggregateCallsCount,
         needRetract,
         needMerge = false,
         needReset = true,
@@ -927,14 +929,14 @@ object AggregateUtil {
       val genFinalAggFunction = generator.generateAggregations(
         "DataSetAggregateFinalHelper",
         inputFieldTypeInfo,
-        aggregates,
-        aggInFields,
+        aggregateMetadata.getAggregateFunctions,
+        aggregateMetadata.getAggregateIndices,
         aggOutFields,
-        isDistinctAggs,
+        aggregateMetadata.getAggregatesDistinctFlags,
         isStateBackedDataViews = false,
         partialResults = false,
         gkeyMapping,
-        Some(aggregates.indices.map(_ + groupings.length).toArray),
+        Some(aggMapping),
         outputType.getFieldCount,
         needRetract,
         needMerge = true,
@@ -952,10 +954,10 @@ object AggregateUtil {
       val genFunction = generator.generateAggregations(
         "DataSetAggregateHelper",
         inputFieldTypeInfo,
-        aggregates,
-        aggInFields,
+        aggregateMetadata.getAggregateFunctions,
+        aggregateMetadata.getAggregateIndices,
         aggOutFields,
-        isDistinctAggs,
+        aggregateMetadata.getAggregatesDistinctFlags,
         isStateBackedDataViews = false,
         partialResults = false,
         groupings,
@@ -1040,23 +1042,23 @@ object AggregateUtil {
     : (DataStreamAggFunction[CRow, Row, Row], RowTypeInfo) = {
 
     val needRetract = false
-    val (aggFields, aggregates, isDistinctAggs, accTypes, _) =
-      transformToAggregateFunctions(
+    val aggregateMetadata =
+      extractAggregateMetadata(
         namedAggregates.map(_.getKey),
         inputType,
         needRetract,
         tableConfig)
 
-    val aggMapping = aggregates.indices.toArray
-    val outputArity = aggregates.length
+    val aggMapping = aggregateMetadata.getAdjustedMapping(0)
+    val outputArity = aggregateMetadata.getAggregateCallsCount
 
     val genFunction = generator.generateAggregations(
       "GroupingWindowAggregateHelper",
       inputFieldTypeInfo,
-      aggregates,
-      aggFields,
+      aggregateMetadata.getAggregateFunctions,
+      aggregateMetadata.getAggregateIndices,
       aggMapping,
-      isDistinctAggs,
+      aggregateMetadata.getAggregatesDistinctFlags,
       isStateBackedDataViews = false,
       partialResults = false,
       groupingKeys,
@@ -1068,7 +1070,7 @@ object AggregateUtil {
       None
     )
 
-    val accumulatorRowType = new RowTypeInfo(accTypes: _*)
+    val accumulatorRowType = new RowTypeInfo(aggregateMetadata.getAggregatesAccumulatorTypes: _*)
     val aggFunction = new AggregateAggFunction(genFunction)
 
     (aggFunction, accumulatorRowType)
@@ -1083,11 +1085,11 @@ object AggregateUtil {
     groupKeysCount: Int,
     tableConfig: TableConfig): Boolean = {
 
-    val aggregateList = transformToAggregateFunctions(
+    val aggregateList = extractAggregateMetadata(
       aggregateCalls,
       inputType,
       needRetraction = false,
-      tableConfig)._2
+      tableConfig).getAggregateFunctions
 
     doAllSupportPartialMerge(aggregateList)
   }
@@ -1166,347 +1168,458 @@ object AggregateUtil {
     (propPos._1, propPos._2, propPos._3)
   }
 
-  private def transformToAggregateFunctions(
+  /**
+    * Metadata of multiple [[AggregateCall]]s required to generate a single
+    * [[GeneratedAggregations]] function.
+    */
+  private[flink] class AggregateMetadata(
+    private val aggregates: Seq[(AggregateCallMetadata, Array[Int])]) {
+
+    def getAggregateFunctions: Array[TableAggregateFunction[_, _]] = {
+      aggregates.map(_._1.aggregateFunction).toArray
+    }
+
+    def getAggregatesAccumulatorTypes: Array[TypeInformation[_]] = {
+      aggregates.map(_._1.accumulatorType).toArray
+    }
+
+    def getAggregatesAccumulatorSpecs: Array[Seq[DataViewSpec[_]]] = {
+      aggregates.map(_._1.accumulatorSpecs).toArray
+    }
+
+    def getAggregatesDistinctFlags: Array[Boolean] = {
+      aggregates.map(_._1.isDistinct).toArray
+    }
+
+    def getAggregateCallsCount: Int = {
+      aggregates.length
+    }
+
+    def getAggregateIndices: Array[Array[Int]] = {
+      aggregates.map(_._2).toArray
+    }
+
+    def getAdjustedMapping(offset: Int): Array[Int] = {
+      (0 until getAggregateCallsCount).map(_ + offset).toArray
+    }
+  }
+
+  /**
+    * Metadata of a single [[SqlAggFunction]] required to generate a [[GeneratedAggregations]]
+    * function.
+    */
+  private[flink] case class AggregateCallMetadata(
+    aggregateFunction: TableAggregateFunction[_, _],
+    accumulatorType: TypeInformation[_],
+    accumulatorSpecs: Seq[DataViewSpec[_]],
+    isDistinct: Boolean
+  )
+
+  /**
+    * Prepares metadata [[AggregateCallMetadata]] required to generate code for
+    * [[GeneratedAggregations]] for a single [[SqlAggFunction]].
+    *
+    * @param aggregateFunction Calcite's aggregate function
+    * @param isDistinct true if the aggregation should be distinct
+    * @param aggregateInputTypes input types of the given aggregate
+    * @param needRetraction true if the [[TableAggregateFunction]] should produce retractions
+    * @param tableConfig the table config, required for decimal precision
+    * @param isStateBackedDataViews true if the data should be backed by the state backend
+    * @param uniqueIdWithinAggregate index of the call within the aggregation, used to create
+    *                                unique state names for each aggregate function
+    * @return the resulting metadata, which contains:
+    *   - Flink's aggregate function
+    *   - the required accumulator information (type and specifications)
+    *   - whether the aggregate is distinct
+    */
+  private[flink] def extractAggregateCallMetadata(
+      aggregateFunction: SqlAggFunction,
+      isDistinct: Boolean,
+      aggregateInputTypes: Seq[RelDataType],
+      needRetraction: Boolean,
+      tableConfig: TableConfig,
+      isStateBackedDataViews: Boolean,
+      uniqueIdWithinAggregate: Int)
+    : AggregateCallMetadata = {
+    // store the aggregate fields of each aggregate function, by the same order of aggregates.
+    // create aggregate function instances by function type and aggregate field data type.
+
+    val aggregate: TableAggregateFunction[_, _] = createFlinkAggFunction(
+      aggregateFunction,
+      needRetraction,
+      aggregateInputTypes,
+      tableConfig)
+
+    val (accumulatorType, accSpecs) = aggregateFunction match {
+      case collect: SqlAggFunction if collect.getKind == SqlKind.COLLECT =>
+        removeStateViewFieldsFromAccTypeInfo(
+          uniqueIdWithinAggregate,
+          aggregate,
+          aggregate.getAccumulatorType,
+          isStateBackedDataViews)
+
+      case udagg: AggSqlFunction =>
+        removeStateViewFieldsFromAccTypeInfo(
+          uniqueIdWithinAggregate,
+          aggregate,
+          udagg.accType,
+          isStateBackedDataViews)
+
+      case _ =>
+        (getAccumulatorTypeOfAggregateFunction(aggregate), None)
+    }
+
+    // create distinct accumulator filter argument
+    val distinctAccumulatorType = if (isDistinct) {
+      createDistinctAccumulatorType(aggregateInputTypes, isStateBackedDataViews, accumulatorType)
+    } else {
+      accumulatorType
+    }
+
+    AggregateCallMetadata(aggregate, distinctAccumulatorType, accSpecs.getOrElse(Seq()), isDistinct)
+  }
+
+  private def createDistinctAccumulatorType(
+      aggregateInputTypes: Seq[RelDataType],
+      isStateBackedDataViews: Boolean,
+      accumulatorType: TypeInformation[_])
+    : PojoTypeInfo[DistinctAccumulator[_]] = {
+    // Using Pojo fields for the real underlying accumulator
+    val pojoFields = new util.ArrayList[PojoField]()
+    pojoFields.add(new PojoField(
+      classOf[DistinctAccumulator[_]].getDeclaredField("realAcc"),
+      accumulatorType)
+    )
+    // If StateBackend is not enabled, the distinct mapping also needs
+    // to be added to the Pojo fields.
+    if (!isStateBackedDataViews) {
+
+      val argTypes: Array[TypeInformation[_]] = aggregateInputTypes
+        .map(FlinkTypeFactory.toTypeInfo).toArray
+
+      val mapViewTypeInfo = new MapViewTypeInfo(
+        new RowTypeInfo(argTypes: _*),
+        BasicTypeInfo.LONG_TYPE_INFO)
+      pojoFields.add(new PojoField(
+        classOf[DistinctAccumulator[_]].getDeclaredField("distinctValueMap"),
+        mapViewTypeInfo)
+      )
+    }
+    new PojoTypeInfo(classOf[DistinctAccumulator[_]], pojoFields)
+  }
+
+  /**
+    * Prepares metadata [[AggregateMetadata]] required to generate code for
+    * [[GeneratedAggregations]] for all [[AggregateCall]]s.
+    *
+    * @param aggregateCalls Calcite's aggregate calls
+    * @param aggregateInputType input type of the given aggregates
+    * @param needRetraction true if the [[TableAggregateFunction]]s should produce retractions
+    * @param tableConfig the table config, required for decimal precision
+    * @param isStateBackedDataViews true if the data should be backed by the state backend
+    * @return the resulting metadata, which contains:
+    * - Flink's aggregate function
+    * - the required accumulator information (type and specifications)
+    * - the input field indices of each aggregate
+    * - whether the aggregate is distinct
+    */
+  private def extractAggregateMetadata(
       aggregateCalls: Seq[AggregateCall],
       aggregateInputType: RelDataType,
       needRetraction: Boolean,
       tableConfig: TableConfig,
       isStateBackedDataViews: Boolean = false)
-  : (Array[Array[Int]],
-    Array[TableAggregateFunction[_, _]],
-    Array[Boolean],
-    Array[TypeInformation[_]],
-    Array[Seq[DataViewSpec[_]]]) = {
+    : AggregateMetadata = {
+
+    val aggregatesWithIndices = aggregateCalls.zipWithIndex.map {
+      case (aggregateCall, index) =>
+        val argList: util.List[Integer] = aggregateCall.getArgList
+
+        val aggFieldIndices = if (argList.isEmpty) {
+          if (aggregateCall.getAggregation.isInstanceOf[SqlCountAggFunction]) {
+            Array[Int](-1)
+          } else {
+            throw new TableException("Aggregate fields should not be empty.")
+          }
+        } else {
+          argList.asScala.map(i => i.intValue).toArray
+        }
+
+        val inputTypes = argList.map(aggregateInputType.getFieldList.get(_).getType)
+        val aggregateCallMetadata = extractAggregateCallMetadata(aggregateCall.getAggregation,
+          aggregateCall.isDistinct,
+          inputTypes,
+          needRetraction,
+          tableConfig,
+          isStateBackedDataViews,
+          index)
+
+        (aggregateCallMetadata, aggFieldIndices)
+    }
 
     // store the aggregate fields of each aggregate function, by the same order of aggregates.
-    val aggFieldIndexes = new Array[Array[Int]](aggregateCalls.size)
-    val aggregates = new Array[TableAggregateFunction[_ <: Any, _ <: Any]](aggregateCalls.size)
-    val accTypes = new Array[TypeInformation[_]](aggregateCalls.size)
+    new AggregateMetadata(aggregatesWithIndices)
+  }
 
-    // create aggregate function instances by function type and aggregate field data type.
-    aggregateCalls.zipWithIndex.foreach { case (aggregateCall, index) =>
-      val argList: util.List[Integer] = aggregateCall.getArgList
+  /**
+    * Converts Calcite's [[SqlAggFunction]] to Flink's UDF [[TableAggregateFunction]].
+    * Creates the aggregate function instance based on the function type and field data type.
+    */
+  private def createFlinkAggFunction(
+      aggFunc: SqlAggFunction,
+      needRetraction: Boolean,
+      inputDataType: Seq[RelDataType],
+      tableConfig: TableConfig)
+    : TableAggregateFunction[_ <: Any, _ <: Any] = {
+
+    lazy val outputType = inputDataType.get(0)
+    lazy val outputTypeName = if (inputDataType.isEmpty) {
+      throw new TableException("Aggregate fields should not be empty.")
+    } else {
+      outputType.getSqlTypeName
+    }
 
-      if (aggregateCall.getAggregation.isInstanceOf[SqlCountAggFunction]) {
-        aggregates(index) = new CountAggFunction
-        if (argList.isEmpty) {
-          aggFieldIndexes(index) = Array[Int](-1)
+    aggFunc match {
+
+      case collect: SqlAggFunction if collect.getKind == SqlKind.COLLECT =>
+        new CollectAggFunction(FlinkTypeFactory.toTypeInfo(outputType))
+
+      case udagg: AggSqlFunction =>
+        udagg.getFunction
+
+      case _: SqlCountAggFunction =>
+        new CountAggFunction
+
+      case _: SqlSumAggFunction =>
+        if (needRetraction) {
+          outputTypeName match {
+            case TINYINT =>
+              new ByteSumWithRetractAggFunction
+            case SMALLINT =>
+              new ShortSumWithRetractAggFunction
+            case INTEGER =>
+              new IntSumWithRetractAggFunction
+            case BIGINT =>
+              new LongSumWithRetractAggFunction
+            case FLOAT =>
+              new FloatSumWithRetractAggFunction
+            case DOUBLE =>
+              new DoubleSumWithRetractAggFunction
+            case DECIMAL =>
+              new DecimalSumWithRetractAggFunction
+            case sqlType: SqlTypeName =>
+              throw new TableException(s"Sum aggregate does no support type: '$sqlType'")
+          }
         } else {
-          aggFieldIndexes(index) = argList.asScala.map(i => i.intValue).toArray
+          outputTypeName match {
+            case TINYINT =>
+              new ByteSumAggFunction
+            case SMALLINT =>
+              new ShortSumAggFunction
+            case INTEGER =>
+              new IntSumAggFunction
+            case BIGINT =>
+              new LongSumAggFunction
+            case FLOAT =>
+              new FloatSumAggFunction
+            case DOUBLE =>
+              new DoubleSumAggFunction
+            case DECIMAL =>
+              new DecimalSumAggFunction
+            case sqlType: SqlTypeName =>
+              throw new TableException(s"Sum aggregate does no support type: '$sqlType'")
+          }
         }
-      } else {
-        if (argList.isEmpty) {
-          throw new TableException("Aggregate fields should not be empty.")
+
+      case _: SqlSumEmptyIsZeroAggFunction =>
+        if (needRetraction) {
+          outputTypeName match {
+            case TINYINT =>
+              new ByteSum0WithRetractAggFunction
+            case SMALLINT =>
+              new ShortSum0WithRetractAggFunction
+            case INTEGER =>
+              new IntSum0WithRetractAggFunction
+            case BIGINT =>
+              new LongSum0WithRetractAggFunction
+            case FLOAT =>
+              new FloatSum0WithRetractAggFunction
+            case DOUBLE =>
+              new DoubleSum0WithRetractAggFunction
+            case DECIMAL =>
+              new DecimalSum0WithRetractAggFunction
+            case sqlType: SqlTypeName =>
+              throw new TableException(s"Sum0 aggregate does no support type: '$sqlType'")
+          }
         } else {
-          aggFieldIndexes(index) = argList.asScala.map(i => i.intValue).toArray
+          outputTypeName match {
+            case TINYINT =>
+              new ByteSum0AggFunction
+            case SMALLINT =>
+              new ShortSum0AggFunction
+            case INTEGER =>
+              new IntSum0AggFunction
+            case BIGINT =>
+              new LongSum0AggFunction
+            case FLOAT =>
+              new FloatSum0AggFunction
+            case DOUBLE =>
+              new DoubleSum0AggFunction
+            case DECIMAL =>
+              new DecimalSum0AggFunction
+            case sqlType: SqlTypeName =>
+              throw new TableException(s"Sum0 aggregate does no support type: '$sqlType'")
+          }
         }
 
-        val relDataType = aggregateInputType.getFieldList.get(aggFieldIndexes(index)(0)).getType
-        val sqlTypeName = relDataType.getSqlTypeName
-        aggregateCall.getAggregation match {
-
-          case _: SqlSumAggFunction =>
-            if (needRetraction) {
-              aggregates(index) = sqlTypeName match {
-                case TINYINT =>
-                  new ByteSumWithRetractAggFunction
-                case SMALLINT =>
-                  new ShortSumWithRetractAggFunction
-                case INTEGER =>
-                  new IntSumWithRetractAggFunction
-                case BIGINT =>
-                  new LongSumWithRetractAggFunction
-                case FLOAT =>
-                  new FloatSumWithRetractAggFunction
-                case DOUBLE =>
-                  new DoubleSumWithRetractAggFunction
-                case DECIMAL =>
-                  new DecimalSumWithRetractAggFunction
-                case sqlType: SqlTypeName =>
-                  throw new TableException(s"Sum aggregate does no support type: '$sqlType'")
-              }
-            } else {
-              aggregates(index) = sqlTypeName match {
-                case TINYINT =>
-                  new ByteSumAggFunction
-                case SMALLINT =>
-                  new ShortSumAggFunction
-                case INTEGER =>
-                  new IntSumAggFunction
-                case BIGINT =>
-                  new LongSumAggFunction
-                case FLOAT =>
-                  new FloatSumAggFunction
-                case DOUBLE =>
-                  new DoubleSumAggFunction
-                case DECIMAL =>
-                  new DecimalSumAggFunction
-                case sqlType: SqlTypeName =>
-                  throw new TableException(s"Sum aggregate does no support type: '$sqlType'")
-              }
-            }
-
-          case _: SqlSumEmptyIsZeroAggFunction =>
-            if (needRetraction) {
-              aggregates(index) = sqlTypeName match {
-                case TINYINT =>
-                  new ByteSum0WithRetractAggFunction
-                case SMALLINT =>
-                  new ShortSum0WithRetractAggFunction
-                case INTEGER =>
-                  new IntSum0WithRetractAggFunction
-                case BIGINT =>
-                  new LongSum0WithRetractAggFunction
-                case FLOAT =>
-                  new FloatSum0WithRetractAggFunction
-                case DOUBLE =>
-                  new DoubleSum0WithRetractAggFunction
-                case DECIMAL =>
-                  new DecimalSum0WithRetractAggFunction
-                case sqlType: SqlTypeName =>
-                  throw new TableException(s"Sum0 aggregate does no support type: '$sqlType'")
-              }
-            } else {
-              aggregates(index) = sqlTypeName match {
-                case TINYINT =>
-                  new ByteSum0AggFunction
-                case SMALLINT =>
-                  new ShortSum0AggFunction
-                case INTEGER =>
-                  new IntSum0AggFunction
-                case BIGINT =>
-                  new LongSum0AggFunction
-                case FLOAT =>
-                  new FloatSum0AggFunction
-                case DOUBLE =>
-                  new DoubleSum0AggFunction
-                case DECIMAL =>
-                  new DecimalSum0AggFunction
-                case sqlType: SqlTypeName =>
-                  throw new TableException(s"Sum0 aggregate does no support type: '$sqlType'")
-              }
-            }
+      case a: SqlAvgAggFunction if a.kind == SqlKind.AVG =>
+        outputTypeName match {
+          case TINYINT =>
+            new ByteAvgAggFunction
+          case SMALLINT =>
+            new ShortAvgAggFunction
+          case INTEGER =>
+            new IntAvgAggFunction
+          case BIGINT =>
+            new LongAvgAggFunction
+          case FLOAT =>
+            new FloatAvgAggFunction
+          case DOUBLE =>
+            new DoubleAvgAggFunction
+          case DECIMAL =>
+            new DecimalAvgAggFunction(tableConfig.getDecimalContext)
+          case sqlType: SqlTypeName =>
+            throw new TableException(s"Avg aggregate does no support type: '$sqlType'")
+        }
 
-          case a: SqlAvgAggFunction if a.kind == SqlKind.AVG =>
-            aggregates(index) = sqlTypeName match {
+      case sqlMinMaxFunction: SqlMinMaxAggFunction =>
+        if (sqlMinMaxFunction.getKind == SqlKind.MIN) {
+          if (needRetraction) {
+            outputTypeName match {
               case TINYINT =>
-                new ByteAvgAggFunction
+                new ByteMinWithRetractAggFunction
               case SMALLINT =>
-                new ShortAvgAggFunction
+                new ShortMinWithRetractAggFunction
               case INTEGER =>
-                new IntAvgAggFunction
+                new IntMinWithRetractAggFunction
               case BIGINT =>
-                new LongAvgAggFunction
+                new LongMinWithRetractAggFunction
               case FLOAT =>
-                new FloatAvgAggFunction
+                new FloatMinWithRetractAggFunction
               case DOUBLE =>
-                new DoubleAvgAggFunction
+                new DoubleMinWithRetractAggFunction
               case DECIMAL =>
-                new DecimalAvgAggFunction(tableConfig.getDecimalContext)
+                new DecimalMinWithRetractAggFunction
+              case BOOLEAN =>
+                new BooleanMinWithRetractAggFunction
+              case VARCHAR | CHAR =>
+                new StringMinWithRetractAggFunction
+              case TIMESTAMP =>
+                new TimestampMinWithRetractAggFunction
+              case DATE =>
+                new DateMinWithRetractAggFunction
+              case TIME =>
+                new TimeMinWithRetractAggFunction
               case sqlType: SqlTypeName =>
-                throw new TableException(s"Avg aggregate does no support type: '$sqlType'")
+                throw new TableException(
+                  s"Min with retract aggregate does no support type: '$sqlType'")
             }
-
-          case sqlMinMaxFunction: SqlMinMaxAggFunction =>
-            aggregates(index) = if (sqlMinMaxFunction.getKind == SqlKind.MIN) {
-              if (needRetraction) {
-                sqlTypeName match {
-                  case TINYINT =>
-                    new ByteMinWithRetractAggFunction
-                  case SMALLINT =>
-                    new ShortMinWithRetractAggFunction
-                  case INTEGER =>
-                    new IntMinWithRetractAggFunction
-                  case BIGINT =>
-                    new LongMinWithRetractAggFunction
-                  case FLOAT =>
-                    new FloatMinWithRetractAggFunction
-                  case DOUBLE =>
-                    new DoubleMinWithRetractAggFunction
-                  case DECIMAL =>
-                    new DecimalMinWithRetractAggFunction
-                  case BOOLEAN =>
-                    new BooleanMinWithRetractAggFunction
-                  case VARCHAR | CHAR =>
-                    new StringMinWithRetractAggFunction
-                  case TIMESTAMP =>
-                    new TimestampMinWithRetractAggFunction
-                  case DATE =>
-                    new DateMinWithRetractAggFunction
-                  case TIME =>
-                    new TimeMinWithRetractAggFunction
-                  case sqlType: SqlTypeName =>
-                    throw new TableException(
-                      s"Min with retract aggregate does no support type: '$sqlType'")
-                }
-              } else {
-                sqlTypeName match {
-                  case TINYINT =>
-                    new ByteMinAggFunction
-                  case SMALLINT =>
-                    new ShortMinAggFunction
-                  case INTEGER =>
-                    new IntMinAggFunction
-                  case BIGINT =>
-                    new LongMinAggFunction
-                  case FLOAT =>
-                    new FloatMinAggFunction
-                  case DOUBLE =>
-                    new DoubleMinAggFunction
-                  case DECIMAL =>
-                    new DecimalMinAggFunction
-                  case BOOLEAN =>
-                    new BooleanMinAggFunction
-                  case VARCHAR | CHAR =>
-                    new StringMinAggFunction
-                  case TIMESTAMP =>
-                    new TimestampMinAggFunction
-                  case DATE =>
-                    new DateMinAggFunction
-                  case TIME =>
-                    new TimeMinAggFunction
-                  case sqlType: SqlTypeName =>
-                    throw new TableException(s"Min aggregate does no support type: '$sqlType'")
-                }
-              }
-            } else {
-              if (needRetraction) {
-                sqlTypeName match {
-                  case TINYINT =>
-                    new ByteMaxWithRetractAggFunction
-                  case SMALLINT =>
-                    new ShortMaxWithRetractAggFunction
-                  case INTEGER =>
-                    new IntMaxWithRetractAggFunction
-                  case BIGINT =>
-                    new LongMaxWithRetractAggFunction
-                  case FLOAT =>
-                    new FloatMaxWithRetractAggFunction
-                  case DOUBLE =>
-                    new DoubleMaxWithRetractAggFunction
-                  case DECIMAL =>
-                    new DecimalMaxWithRetractAggFunction
-                  case BOOLEAN =>
-                    new BooleanMaxWithRetractAggFunction
-                  case VARCHAR | CHAR =>
-                    new StringMaxWithRetractAggFunction
-                  case TIMESTAMP =>
-                    new TimestampMaxWithRetractAggFunction
-                  case DATE =>
-                    new DateMaxWithRetractAggFunction
-                  case TIME =>
-                    new TimeMaxWithRetractAggFunction
-                  case sqlType: SqlTypeName =>
-                    throw new TableException(
-                      s"Max with retract aggregate does no support type: '$sqlType'")
-                }
-              } else {
-                sqlTypeName match {
-                  case TINYINT =>
-                    new ByteMaxAggFunction
-                  case SMALLINT =>
-                    new ShortMaxAggFunction
-                  case INTEGER =>
-                    new IntMaxAggFunction
-                  case BIGINT =>
-                    new LongMaxAggFunction
-                  case FLOAT =>
-                    new FloatMaxAggFunction
-                  case DOUBLE =>
-                    new DoubleMaxAggFunction
-                  case DECIMAL =>
-                    new DecimalMaxAggFunction
-                  case BOOLEAN =>
-                    new BooleanMaxAggFunction
-                  case VARCHAR | CHAR =>
-                    new StringMaxAggFunction
-                  case TIMESTAMP =>
-                    new TimestampMaxAggFunction
-                  case DATE =>
-                    new DateMaxAggFunction
-                  case TIME =>
-                    new TimeMaxAggFunction
-                  case sqlType: SqlTypeName =>
-                    throw new TableException(s"Max aggregate does no support type: '$sqlType'")
-                }
-              }
+          } else {
+            outputTypeName match {
+              case TINYINT =>
+                new ByteMinAggFunction
+              case SMALLINT =>
+                new ShortMinAggFunction
+              case INTEGER =>
+                new IntMinAggFunction
+              case BIGINT =>
+                new LongMinAggFunction
+              case FLOAT =>
+                new FloatMinAggFunction
+              case DOUBLE =>
+                new DoubleMinAggFunction
+              case DECIMAL =>
+                new DecimalMinAggFunction
+              case BOOLEAN =>
+                new BooleanMinAggFunction
+              case VARCHAR | CHAR =>
+                new StringMinAggFunction
+              case TIMESTAMP =>
+                new TimestampMinAggFunction
+              case DATE =>
+                new DateMinAggFunction
+              case TIME =>
+                new TimeMinAggFunction
+              case sqlType: SqlTypeName =>
+                throw new TableException(s"Min aggregate does no support type: '$sqlType'")
             }
-
-          case collect: SqlAggFunction if collect.getKind == SqlKind.COLLECT =>
-            aggregates(index) = new CollectAggFunction(FlinkTypeFactory.toTypeInfo(relDataType))
-            accTypes(index) = aggregates(index).getAccumulatorType
-
-          case udagg: AggSqlFunction =>
-            aggregates(index) = udagg.getFunction
-            accTypes(index) = udagg.accType
-
-          case unSupported: SqlAggFunction =>
-            throw new TableException(s"Unsupported Function: '${unSupported.getName}'")
-        }
-      }
-    }
-
-    val accSpecs = new Array[Seq[DataViewSpec[_]]](aggregateCalls.size)
-
-    // create accumulator type information for every aggregate function
-    aggregates.zipWithIndex.foreach { case (agg, index) =>
-      if (accTypes(index) != null) {
-        val (accType, specs) = removeStateViewFieldsFromAccTypeInfo(index,
-          agg,
-          accTypes(index),
-          isStateBackedDataViews)
-        if (specs.isDefined) {
-          accSpecs(index) = specs.get
-          accTypes(index) = accType
-        } else {
-          accSpecs(index) = Seq()
-        }
-      } else {
-        accSpecs(index) = Seq()
-        accTypes(index) = getAccumulatorTypeOfAggregateFunction(agg)
-      }
-    }
-
-    // create distinct accumulator filter argument
-    val isDistinctAggs = new Array[Boolean](aggregateCalls.size)
-
-    aggregateCalls.zipWithIndex.foreach {
-      case (aggCall, index) =>
-        if (aggCall.isDistinct) {
-          // Generate distinct aggregates and the corresponding DistinctAccumulator
-          // wrappers for storing distinct mapping
-          val argList: util.List[Integer] = aggCall.getArgList
-
-          // Using Pojo fields for the real underlying accumulator
-          val pojoFields = new util.ArrayList[PojoField]()
-          pojoFields.add(new PojoField(
-            classOf[DistinctAccumulator[_]].getDeclaredField("realAcc"),
-            accTypes(index))
-          )
-          // If StateBackend is not enabled, the distinct mapping also needs
-          // to be added to the Pojo fields.
-          if (!isStateBackedDataViews) {
-
-            val argTypes: Array[TypeInformation[_]] = argList
-              .map(aggregateInputType.getFieldList.get(_).getType)
-              .map(FlinkTypeFactory.toTypeInfo).toArray
-
-            val mapViewTypeInfo = new MapViewTypeInfo(
-              new RowTypeInfo(argTypes:_*),
-              BasicTypeInfo.LONG_TYPE_INFO)
-            pojoFields.add(new PojoField(
-              classOf[DistinctAccumulator[_]].getDeclaredField("distinctValueMap"),
-              mapViewTypeInfo)
-            )
           }
-          accTypes(index) = new PojoTypeInfo(classOf[DistinctAccumulator[_]], pojoFields)
-          isDistinctAggs(index) = true
         } else {
-          isDistinctAggs(index) = false
+          if (needRetraction) {
+            outputTypeName match {
+              case TINYINT =>
+                new ByteMaxWithRetractAggFunction
+              case SMALLINT =>
+                new ShortMaxWithRetractAggFunction
+              case INTEGER =>
+                new IntMaxWithRetractAggFunction
+              case BIGINT =>
+                new LongMaxWithRetractAggFunction
+              case FLOAT =>
+                new FloatMaxWithRetractAggFunction
+              case DOUBLE =>
+                new DoubleMaxWithRetractAggFunction
+              case DECIMAL =>
+                new DecimalMaxWithRetractAggFunction
+              case BOOLEAN =>
+                new BooleanMaxWithRetractAggFunction
+              case VARCHAR | CHAR =>
+                new StringMaxWithRetractAggFunction
+              case TIMESTAMP =>
+                new TimestampMaxWithRetractAggFunction
+              case DATE =>
+                new DateMaxWithRetractAggFunction
+              case TIME =>
+                new TimeMaxWithRetractAggFunction
+              case sqlType: SqlTypeName =>
+                throw new TableException(
+                  s"Max with retract aggregate does no support type: '$sqlType'")
+            }
+          } else {
+            outputTypeName match {
+              case TINYINT =>
+                new ByteMaxAggFunction
+              case SMALLINT =>
+                new ShortMaxAggFunction
+              case INTEGER =>
+                new IntMaxAggFunction
+              case BIGINT =>
+                new LongMaxAggFunction
+              case FLOAT =>
+                new FloatMaxAggFunction
+              case DOUBLE =>
+                new DoubleMaxAggFunction
+              case DECIMAL =>
+                new DecimalMaxAggFunction
+              case BOOLEAN =>
+                new BooleanMaxAggFunction
+              case VARCHAR | CHAR =>
+                new StringMaxAggFunction
+              case TIMESTAMP =>
+                new TimestampMaxAggFunction
+              case DATE =>
+                new DateMaxAggFunction
+              case TIME =>
+                new TimeMaxAggFunction
+              case sqlType: SqlTypeName =>
+                throw new TableException(s"Max aggregate does no support type: '$sqlType'")
+            }
+          }
         }
-    }
 
-    (aggFieldIndexes, aggregates, isDistinctAggs, accTypes, accSpecs)
+      case unSupported: SqlAggFunction =>
+        throw new TableException(s"Unsupported Function: '${unSupported.getName}'")
+    }
   }
 
   private def createRowTypeForKeysAndAggregates(