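You are viewing a plain-text version of this content. The canonical (linked) version is available in the commits@flink.apache.org mailing-list archive; the original hyperlink was not preserved in this copy.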
Posted to commits@flink.apache.org by tw...@apache.org on 2020/07/24 10:49:29 UTC

[flink] branch master updated: [FLINK-15803][table] Use AggregateInfo as the single source of type description

This is an automated email from the ASF dual-hosted git repository.

twalthr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new a37131d  [FLINK-15803][table] Use AggregateInfo as the single source of type description
a37131d is described below

commit a37131d576c2cc281a1aa5bfe186bdd72748ac06
Author: Timo Walther <tw...@apache.org>
AuthorDate: Wed Jul 22 13:12:19 2020 +0200

    [FLINK-15803][table] Use AggregateInfo as the single source of type description
    
    This refactors much of the code generation around aggregate functions. It does
    this for better code maintainability and, in particular, to have a single source
    for generating all types (arguments, accumulator, result).
    
    This closes #12967.
---
 .../codegen/agg/batch/AggCodeGenHelper.scala       | 590 ++++++++++++---------
 .../agg/batch/AggWithoutKeysCodeGenerator.scala    |  31 +-
 .../codegen/agg/batch/HashAggCodeGenHelper.scala   | 216 ++++----
 .../codegen/agg/batch/HashAggCodeGenerator.scala   |  40 +-
 .../agg/batch/HashWindowCodeGenerator.scala        |  11 +-
 .../codegen/agg/batch/SortAggCodeGenerator.scala   |  34 +-
 .../agg/batch/SortWindowCodeGenerator.scala        |  14 +-
 .../codegen/agg/batch/WindowCodeGenerator.scala    | 116 ++--
 .../types/LogicalTypeDataTypeConverter.java        |   2 +
 9 files changed, 568 insertions(+), 486 deletions(-)

diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala
index e7fa871..72be697 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala
@@ -22,26 +22,24 @@ import org.apache.flink.runtime.util.SingleElementIterator
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator
 import org.apache.flink.table.data.{GenericRowData, RowData}
 import org.apache.flink.table.expressions.ApiExpressionUtils.localRef
-import org.apache.flink.table.expressions.{Expression, _}
-import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
+import org.apache.flink.table.expressions._
+import org.apache.flink.table.functions.AggregateFunction
 import org.apache.flink.table.planner.codegen.CodeGenUtils._
+import org.apache.flink.table.planner.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
 import org.apache.flink.table.planner.codegen.OperatorCodeGenerator.STREAM_RECORD
 import org.apache.flink.table.planner.codegen._
 import org.apache.flink.table.planner.expressions.DeclarativeExpressionResolver
 import org.apache.flink.table.planner.expressions.DeclarativeExpressionResolver.toRexInputRef
 import org.apache.flink.table.planner.expressions.converter.ExpressionConverter
 import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction
-import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils.{getAccumulatorTypeOfAggregateFunction, getAggUserDefinedInputTypes}
+import org.apache.flink.table.planner.plan.utils.AggregateInfo
 import org.apache.flink.table.runtime.context.ExecutionContextImpl
 import org.apache.flink.table.runtime.generated.{GeneratedAggsHandleFunction, GeneratedOperator}
-import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.{fromDataTypeToLogicalType, fromLogicalTypeToDataType}
+import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter._
 import org.apache.flink.table.runtime.typeutils.InternalSerializers
-import org.apache.flink.table.types.DataType
 import org.apache.flink.table.types.logical.LogicalTypeRoot._
 import org.apache.flink.table.types.logical.{DistinctType, LogicalType, RowType}
 
-import org.apache.calcite.rel.core.AggregateCall
-import org.apache.calcite.rex.RexNode
 import org.apache.calcite.tools.RelBuilder
 
 import scala.annotation.tailrec
@@ -52,34 +50,50 @@ import scala.annotation.tailrec
 object AggCodeGenHelper {
 
   def getAggBufferNames(
-      auxGrouping: Array[Int], aggregates: Seq[UserDefinedFunction]): Array[Array[String]] = {
-    auxGrouping.zipWithIndex.map {
-      case (_, index) => Array(s"aux_group$index")
-    } ++ aggregates.zipWithIndex.toArray.map {
-      case (a: DeclarativeAggregateFunction, index) =>
-        val idx = auxGrouping.length + index
-        a.aggBufferAttributes.map(attr => s"agg${idx}_${attr.getName}")
-      case (_: AggregateFunction[_, _], index) =>
-        val idx = auxGrouping.length + index
-        Array(s"agg$idx")
+      auxGrouping: Array[Int],
+      aggInfos: Seq[AggregateInfo])
+    : Array[Array[String]] = {
+    val auxGroupingNames = auxGrouping.indices
+      .map(index => Array(s"aux_group$index"))
+
+    val aggNames = aggInfos.map { aggInfo =>
+
+      val aggBufferIdx = auxGrouping.length + aggInfo.aggIndex
+
+      aggInfo.function match {
+
+        // create one buffer for each attribute in declarative functions
+        case function: DeclarativeAggregateFunction =>
+          function.aggBufferAttributes.map(attr => s"agg${aggBufferIdx}_${attr.getName}")
+
+        // create one buffer for imperative functions
+        case _: AggregateFunction[_, _] =>
+          Array(s"agg$aggBufferIdx")
+      }
     }
+
+    (auxGroupingNames ++ aggNames).toArray
   }
 
   def getAggBufferTypes(
-      inputType: RowType, auxGrouping: Array[Int], aggregates: Seq[UserDefinedFunction])
+      inputType: RowType,
+      auxGrouping: Array[Int],
+      aggInfos: Seq[AggregateInfo])
     : Array[Array[LogicalType]] = {
-    auxGrouping.map { index =>
-      Array(inputType.getTypeAt(index))
-    } ++ aggregates.map {
-      case a: DeclarativeAggregateFunction => a.getAggBufferTypes.map(_.getLogicalType)
-      case a: AggregateFunction[_, _] =>
-        Array(fromDataTypeToLogicalType(getAccumulatorTypeOfAggregateFunction(a)))
-    }.toArray[Array[LogicalType]]
+    val auxGroupingTypes = auxGrouping
+      .map { index =>
+        Array(inputType.getTypeAt(index))
+      }
+
+    val aggTypes = aggInfos
+      .map(_.externalAccTypes.map(fromDataTypeToLogicalType))
+
+    auxGroupingTypes ++ aggTypes
   }
 
-  def getUdaggs(
-      aggregates: Seq[UserDefinedFunction]): Map[AggregateFunction[_, _], String] = {
-    aggregates
+  def getFunctionIdentifiers(aggInfos: Seq[AggregateInfo]): Map[AggregateFunction[_, _], String] = {
+    aggInfos
+        .map(_.function)
         .filter(a => a.isInstanceOf[AggregateFunction[_, _]])
         .map(a => a -> CodeGenUtils.udfFieldName(a)).toMap
         .asInstanceOf[Map[AggregateFunction[_, _], String]]
@@ -95,7 +109,8 @@ object AggCodeGenHelper {
   private[flink] def addAggsHandler(
       aggsHandler: GeneratedAggsHandleFunction,
       ctx: CodeGeneratorContext,
-      aggsHandlerCtx: CodeGeneratorContext): String = {
+      aggsHandlerCtx: CodeGeneratorContext)
+    : String = {
     ctx.addReusableInnerClass(aggsHandler.getClassName, aggsHandler.getCode)
     val handler = CodeGenUtils.newName("handler")
     ctx.addReusableMember(s"${aggsHandler.getClassName} $handler = null;")
@@ -116,7 +131,8 @@ object AggCodeGenHelper {
     */
   private[flink] def genGroupKeyChangedCheckCode(
       currentKeyTerm: String,
-      lastKeyTerm: String): String = {
+      lastKeyTerm: String)
+    : String = {
     s"""
        |$currentKeyTerm.getSizeInBytes() != $lastKeyTerm.getSizeInBytes() ||
        |  !(org.apache.flink.table.data.binary.BinaryRowDataUtil.byteArrayEquals(
@@ -133,27 +149,25 @@ object AggCodeGenHelper {
       builder: RelBuilder,
       grouping: Array[Int],
       auxGrouping: Array[Int],
-      aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
-      aggArgs: Array[Array[Int]],
-      aggregates: Seq[UserDefinedFunction],
-      aggResultTypes: Seq[DataType],
-      udaggs: Map[AggregateFunction[_, _], String],
+      aggInfos: Seq[AggregateInfo],
+      functionIdentifiers: Map[AggregateFunction[_, _], String],
       inputTerm: String,
       inputType: RowType,
       aggBufferNames: Array[Array[String]],
       aggBufferTypes: Array[Array[LogicalType]],
       outputType: RowType,
-      forHashAgg: Boolean = false): (String, String, GeneratedExpression) = {
+      forHashAgg: Boolean = false)
+    : (String, String, GeneratedExpression) = {
     // gen code to apply aggregate functions to grouping elements
     val argsMapping = buildAggregateArgsMapping(
-      isMerge, grouping.length, inputType, auxGrouping, aggArgs, aggBufferTypes)
+      isMerge, grouping.length, inputType, auxGrouping, aggInfos, aggBufferTypes)
 
     val aggBufferExprs = genFlatAggBufferExprs(
       isMerge,
       ctx,
       builder,
       auxGrouping,
-      aggregates,
+      aggInfos,
       argsMapping,
       aggBufferNames,
       aggBufferTypes)
@@ -165,8 +179,8 @@ object AggCodeGenHelper {
       inputTerm,
       grouping,
       auxGrouping,
-      aggregates,
-      udaggs,
+      aggInfos,
+      functionIdentifiers,
       aggBufferExprs,
       forHashAgg)
 
@@ -177,9 +191,8 @@ object AggCodeGenHelper {
       inputType,
       inputTerm,
       auxGrouping,
-      aggCallToAggFunction,
-      aggregates,
-      udaggs,
+      aggInfos,
+      functionIdentifiers,
       argsMapping,
       aggBufferNames,
       aggBufferTypes,
@@ -192,9 +205,8 @@ object AggCodeGenHelper {
       builder,
       grouping,
       auxGrouping,
-      aggregates,
-      aggResultTypes,
-      udaggs,
+      aggInfos,
+      functionIdentifiers,
       argsMapping,
       aggBufferNames,
       aggBufferTypes,
@@ -218,8 +230,10 @@ object AggCodeGenHelper {
       aggBufferOffset: Int,
       inputType: RowType,
       auxGrouping: Array[Int],
-      aggArgs: Array[Array[Int]],
-      aggBufferTypes: Array[Array[LogicalType]]): Array[Array[(Int, LogicalType)]] = {
+      aggInfos: Seq[AggregateInfo],
+      aggBufferTypes: Array[Array[LogicalType]])
+    : Array[Array[(Int, LogicalType)]] = {
+    val aggArgs = aggInfos.map(_.argIndexes).toArray
     val auxGroupingMapping = auxGrouping.indices.map {
       i => Array[(Int, LogicalType)]((i, aggBufferTypes(i)(0)))
     }.toArray
@@ -281,29 +295,47 @@ object AggCodeGenHelper {
       ctx: CodeGeneratorContext,
       builder: RelBuilder,
       auxGrouping: Array[Int],
-      aggregates: Seq[UserDefinedFunction],
+      aggInfos: Seq[AggregateInfo],
       argsMapping: Array[Array[(Int, LogicalType)]],
       aggBufferNames: Array[Array[String]],
       aggBufferTypes: Array[Array[LogicalType]]): Seq[GeneratedExpression] = {
-    val exprCodegen = new ExprCodeGenerator(ctx, false)
+    val exprCodeGen = new ExprCodeGenerator(ctx, false)
     val converter = new ExpressionConverter(builder)
 
-    val accessAuxGroupingExprs = auxGrouping.indices.map {
-      idx => newLocalReference(aggBufferNames(idx)(0), aggBufferTypes(idx)(0))
-    }.map(_.accept(converter)).map(exprCodegen.generateExpression)
-
-    val aggCallExprs = aggregates.zipWithIndex.flatMap {
-      case (agg: DeclarativeAggregateFunction, aggIndex: Int) =>
-        val idx = auxGrouping.length + aggIndex
-        agg.aggBufferAttributes.map(_.accept(
-          ResolveReference(ctx, builder, isMerge, agg, idx, argsMapping, aggBufferTypes)))
-      case (_: AggregateFunction[_, _], aggIndex: Int) =>
-        val idx = auxGrouping.length + aggIndex
-        val variableName = aggBufferNames(idx)(0)
-        Some(newLocalReference(variableName, aggBufferTypes(idx)(0)))
-    }.map(_.accept(converter)).map(exprCodegen.generateExpression)
-
-    accessAuxGroupingExprs ++ aggCallExprs
+    val accessAuxGroupingExprs = auxGrouping.indices
+      .map(idx => newLocalReference(aggBufferNames(idx).head, aggBufferTypes(idx).head))
+
+    val aggCallExprs = aggInfos.flatMap { aggInfo =>
+
+      val aggBufferIdx = auxGrouping.length + aggInfo.aggIndex
+
+      aggInfo.function match {
+
+        // create a buffer expression for each attribute in declarative functions
+        case function: DeclarativeAggregateFunction =>
+          val ref = ResolveReference(
+            ctx,
+            builder,
+            isMerge,
+            function,
+            aggBufferIdx,
+            argsMapping,
+            aggBufferTypes)
+          function.aggBufferAttributes().map(_.accept(ref))
+
+        // create one buffer for imperative functions
+        case _: AggregateFunction[_, _] =>
+          val aggBufferName = aggBufferNames(aggBufferIdx).head
+          val aggBufferType = aggBufferTypes(aggBufferIdx).head
+          Some(newLocalReference(aggBufferName, aggBufferType))
+      }
+    }
+
+    val aggBufferExprs = accessAuxGroupingExprs ++ aggCallExprs
+
+    aggBufferExprs
+      .map(_.accept(converter))
+      .map(exprCodeGen.generateExpression)
   }
 
   /**
@@ -316,12 +348,14 @@ object AggCodeGenHelper {
       inputTerm: String,
       grouping: Array[Int],
       auxGrouping: Array[Int],
-      aggregates: Seq[UserDefinedFunction],
-      udaggs: Map[AggregateFunction[_, _], String],
+      aggInfos: Seq[AggregateInfo],
+      functionIdentifiers: Map[AggregateFunction[_, _], String],
       aggBufferExprs: Seq[GeneratedExpression],
-      forHashAgg: Boolean = false): String = {
-    val exprCodegen = new ExprCodeGenerator(ctx, false)
+      forHashAgg: Boolean = false)
+    : String = {
+    val exprCodeGen = new ExprCodeGenerator(ctx, false)
         .bindInput(inputType, inputTerm = inputTerm, inputFieldMapping = Some(auxGrouping))
+    val converter = new ExpressionConverter(builder)
 
     val initAuxGroupingExprs = {
       if (forHashAgg) {
@@ -335,25 +369,27 @@ object AggCodeGenHelper {
       GenerateUtils.generateFieldAccess(ctx, inputType, inputTerm, idx)
     }
 
-    val initAggCallBufferExprs = aggregates.flatMap {
-      case (agg: DeclarativeAggregateFunction) =>
-        agg.initialValuesExpressions
-      case (agg: AggregateFunction[_, _]) =>
-        Some(agg)
-    }.map {
-      case (expr: Expression) => expr.accept(new ExpressionConverter(builder))
-      case t@_ => t
-    }.map {
-      case (rex: RexNode) => exprCodegen.generateExpression(rex)
-      case (agg: AggregateFunction[_, _]) =>
-        val resultTerm = s"${udaggs(agg)}.createAccumulator()"
-        val nullTerm = "false"
-        val resultType = getAccumulatorTypeOfAggregateFunction(agg)
-        GeneratedExpression(
-          genToInternal(ctx, resultType, resultTerm),
-          nullTerm,
-          "",
-          fromDataTypeToLogicalType(resultType))
+    val initAggCallBufferExprs = aggInfos.flatMap { aggInfo =>
+      aggInfo.function match {
+
+        // generate code for each agg buffer in declarative functions
+        case function: DeclarativeAggregateFunction =>
+          val expressions = function.initialValuesExpressions
+          val rexNodes = expressions.map(_.accept(converter))
+          rexNodes.map(exprCodeGen.generateExpression)
+
+        // call createAccumulator() in imperative functions
+        case function: AggregateFunction[_, _] =>
+          val accTerm = s"${functionIdentifiers(function)}.createAccumulator()"
+          val externalAccType = aggInfo.externalAccTypes.head
+          val internalAccType = externalAccType.getLogicalType
+          val genExpr = GeneratedExpression(
+            genToInternal(ctx, externalAccType)(accTerm),
+            NEVER_NULL,
+            NO_CODE,
+            internalAccType)
+          Seq(genExpr)
+        }
     }
 
     val initAggBufferExprs = initAuxGroupingExprs ++ initAggCallBufferExprs
@@ -394,13 +430,13 @@ object AggCodeGenHelper {
       inputType: RowType,
       inputTerm: String,
       auxGrouping: Array[Int],
-      aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
-      aggregates: Seq[UserDefinedFunction],
-      udaggs: Map[AggregateFunction[_, _], String],
+      aggInfos: Seq[AggregateInfo],
+      functionIdentifiers: Map[AggregateFunction[_, _], String],
       argsMapping: Array[Array[(Int, LogicalType)]],
       aggBufferNames: Array[Array[String]],
       aggBufferTypes: Array[Array[LogicalType]],
-      aggBufferExprs: Seq[GeneratedExpression]): String = {
+      aggBufferExprs: Seq[GeneratedExpression])
+    : String = {
     if (isMerge) {
       genMergeFlatAggregateBuffer(
         ctx,
@@ -408,8 +444,8 @@ object AggCodeGenHelper {
         inputTerm,
         inputType,
         auxGrouping,
-        aggregates,
-        udaggs,
+        aggInfos,
+        functionIdentifiers,
         argsMapping,
         aggBufferNames,
         aggBufferTypes,
@@ -421,8 +457,8 @@ object AggCodeGenHelper {
         inputTerm,
         inputType,
         auxGrouping,
-        aggCallToAggFunction,
-        udaggs,
+        aggInfos,
+        functionIdentifiers,
         argsMapping,
         aggBufferNames,
         aggBufferTypes,
@@ -437,35 +473,34 @@ object AggCodeGenHelper {
       builder: RelBuilder,
       grouping: Array[Int],
       auxGrouping: Array[Int],
-      aggregates: Seq[UserDefinedFunction],
-      aggResultTypes: Seq[DataType],
-      udaggs: Map[AggregateFunction[_, _], String],
+      aggInfos: Seq[AggregateInfo],
+      functionIdentifiers: Map[AggregateFunction[_, _], String],
       argsMapping: Array[Array[(Int, LogicalType)]],
       aggBufferNames: Array[Array[String]],
       aggBufferTypes: Array[Array[LogicalType]],
       aggBufferExprs: Seq[GeneratedExpression],
-      outputType: RowType): GeneratedExpression = {
+      outputType: RowType)
+    : GeneratedExpression = {
     val valueRow = CodeGenUtils.newName("valueRow")
-    val resultCodegen = new ExprCodeGenerator(ctx, false)
+    val resultCodeGen = new ExprCodeGenerator(ctx, false)
     if (isFinal) {
       val getValueExprs = genGetValueFromFlatAggregateBuffer(
         isMerge,
         ctx,
         builder,
         auxGrouping,
-        aggregates,
-        aggResultTypes,
-        udaggs,
+        aggInfos,
+        functionIdentifiers,
         argsMapping,
         aggBufferNames,
         aggBufferTypes,
         outputType)
       val valueRowType = RowType.of(getValueExprs.map(_.resultType): _*)
-      resultCodegen.generateResultExpression(
+      resultCodeGen.generateResultExpression(
         getValueExprs, valueRowType, classOf[GenericRowData], valueRow)
     } else {
       val valueRowType = RowType.of(aggBufferExprs.map(_.resultType): _*)
-      resultCodegen.generateResultExpression(
+      resultCodeGen.generateResultExpression(
         aggBufferExprs, valueRowType, classOf[GenericRowData], valueRow)
     }
   }
@@ -478,44 +513,58 @@ object AggCodeGenHelper {
       ctx: CodeGeneratorContext,
       builder: RelBuilder,
       auxGrouping: Array[Int],
-      aggregates: Seq[UserDefinedFunction],
-      aggResultTypes: Seq[DataType],
-      udaggs: Map[AggregateFunction[_, _], String],
+      aggInfos: Seq[AggregateInfo],
+      functionIdentifiers: Map[AggregateFunction[_, _], String],
       argsMapping: Array[Array[(Int, LogicalType)]],
       aggBufferNames: Array[Array[String]],
       aggBufferTypes: Array[Array[LogicalType]],
       outputType: RowType): Seq[GeneratedExpression] = {
-    val exprCodegen = new ExprCodeGenerator(ctx, false)
+    val exprCodeGen = new ExprCodeGenerator(ctx, false)
+    val converter = new ExpressionConverter(builder)
 
     val auxGroupingExprs = auxGrouping.indices.map { idx =>
-      val resultTerm = aggBufferNames(idx)(0)
-      val nullTerm = s"${resultTerm}IsNull"
-      GeneratedExpression(resultTerm, nullTerm, "", aggBufferTypes(idx)(0))
+      val aggBufferName = aggBufferNames(idx).head
+      val aggBufferType = aggBufferTypes(idx).head
+      val nullTerm = s"${aggBufferName}IsNull"
+      GeneratedExpression(aggBufferName, nullTerm, NO_CODE, aggBufferType)
     }
 
-    val aggExprs = aggregates.zipWithIndex.map {
-      case (agg: DeclarativeAggregateFunction, aggIndex) =>
-        val idx = auxGrouping.length + aggIndex
-        agg.getValueExpression.accept(ResolveReference(
-          ctx, builder, isMerge, agg, idx, argsMapping, aggBufferTypes))
-      case (agg: AggregateFunction[_, _], aggIndex) =>
-        val idx = auxGrouping.length + aggIndex
-        (agg, idx)
-    }.map {
-      case (expr: Expression) => expr.accept(new ExpressionConverter(builder))
-      case t@_ => t
-    }.map {
-      case (rex: RexNode) => exprCodegen.generateExpression(rex)
-      case (agg: AggregateFunction[_, _], aggIndex: Int) =>
-        val resultType = aggResultTypes(aggIndex - auxGrouping.length)
-        val accType = getAccumulatorTypeOfAggregateFunction(agg)
-        val resultTerm = genToInternal(ctx, resultType,
-          s"${udaggs(agg)}.getValue(${genToExternal(ctx, accType, aggBufferNames(aggIndex)(0))})")
-        val nullTerm = s"${aggBufferNames(aggIndex)(0)}IsNull"
-        GeneratedExpression(resultTerm, nullTerm, "", fromDataTypeToLogicalType(resultType))
+    val getValueExprs = aggInfos.map { aggInfo =>
+
+      val aggBufferIdx = auxGrouping.length + aggInfo.aggIndex
+
+      aggInfo.function match {
+
+        // evaluate the value expression in declarative functions
+        case function: DeclarativeAggregateFunction =>
+          val ref = ResolveReference(
+            ctx,
+            builder,
+            isMerge,
+            function,
+            aggBufferIdx,
+            argsMapping,
+            aggBufferTypes)
+          val getValueRexNode = function.getValueExpression
+            .accept(ref)
+            .accept(converter)
+          exprCodeGen.generateExpression(getValueRexNode)
+
+        // call getValue() for imperative functions
+        case function: AggregateFunction[_, _] =>
+          val aggBufferName = aggBufferNames(aggBufferIdx).head
+          val externalAccType = aggInfo.externalAccTypes.head
+          val externalResultType = aggInfo.externalResultType
+          val resultType = externalResultType.getLogicalType
+          val getValueCode = s"${functionIdentifiers(function)}.getValue(" +
+            s"${genToExternal(ctx, externalAccType, aggBufferName)})"
+          val resultTerm = genToInternal(ctx, externalResultType)(getValueCode)
+          val nullTerm = s"${aggBufferName}IsNull"
+          GeneratedExpression(resultTerm, nullTerm, NO_CODE, resultType)
+      }
     }
 
-    auxGroupingExprs ++ aggExprs
+    auxGroupingExprs ++ getValueExprs
   }
 
   /**
@@ -527,55 +576,81 @@ object AggCodeGenHelper {
       inputTerm: String,
       inputType: RowType,
       auxGrouping: Array[Int],
-      aggregates: Seq[UserDefinedFunction],
-      udaggs: Map[AggregateFunction[_, _], String],
+      aggInfos: Seq[AggregateInfo],
+      functionIdentifiers: Map[AggregateFunction[_, _], String],
       argsMapping: Array[Array[(Int, LogicalType)]],
       aggBufferNames: Array[Array[String]],
       aggBufferTypes: Array[Array[LogicalType]],
-      aggBufferExprs: Seq[GeneratedExpression]): String = {
-    val exprCodegen = new ExprCodeGenerator(ctx, false).bindInput(inputType, inputTerm = inputTerm)
-
-    // flat map to get flat agg buffers.
-    aggregates.zipWithIndex.flatMap {
-      case (agg: DeclarativeAggregateFunction, aggIndex) =>
-        val idx = auxGrouping.length + aggIndex
-        agg.mergeExpressions.map(_.accept(ResolveReference(
-          ctx, builder, isMerge = true, agg, idx, argsMapping, aggBufferTypes)))
-      case (agg: AggregateFunction[_, _], aggIndex) =>
-        val idx = auxGrouping.length + aggIndex
-        Some(agg, idx)
-    }.zip(aggBufferExprs.slice(auxGrouping.length, aggBufferExprs.size)).map {
-      // DeclarativeAggregateFunction
-      case ((expr: Expression), aggBufVar) =>
-        val mergeExpr = exprCodegen.generateExpression(
-          expr.accept(new ExpressionConverter(builder)))
-        s"""
-           |${mergeExpr.code}
-           |${aggBufVar.nullTerm} = ${mergeExpr.nullTerm};
-           |if (!${mergeExpr.nullTerm}) {
-           |  ${mergeExpr.copyResultTermToTargetIfChanged(ctx, aggBufVar.resultTerm)}
-           |}
-           """.stripMargin.trim
-      // UserDefinedAggregateFunction
-      case ((agg: AggregateFunction[_, _], aggIndex: Int), aggBufVar) =>
-        val (inputIndex, inputType) = argsMapping(aggIndex)(0)
-        val inputRef = toRexInputRef(builder, inputIndex, inputType)
-        val inputExpr = exprCodegen.generateExpression(
-          inputRef.accept(new ExpressionConverter(builder)))
-        val singleIterableClass = classOf[SingleElementIterator[_]].getCanonicalName
-
-        val externalAccT = getAccumulatorTypeOfAggregateFunction(agg)
-        val javaField = typeTerm(externalAccT.getConversionClass)
-        val tmpAcc = newName("tmpAcc")
-        s"""
-           |final $singleIterableClass accIt$aggIndex = new  $singleIterableClass();
-           |accIt$aggIndex.set(${genToExternal(ctx, externalAccT, inputExpr.resultTerm)});
-           |$javaField $tmpAcc = ${genToExternal(ctx, externalAccT, aggBufferNames(aggIndex)(0))};
-           |${udaggs(agg)}.merge($tmpAcc, accIt$aggIndex);
-           |${aggBufferNames(aggIndex)(0)} = ${genToInternal(ctx, externalAccT, tmpAcc)};
-           |${aggBufVar.nullTerm} = ${aggBufferNames(aggIndex)(0)}IsNull || ${inputExpr.nullTerm};
-         """.stripMargin
-    } mkString "\n"
+      aggBufferExprs: Seq[GeneratedExpression])
+    : String = {
+    val exprCodeGen = new ExprCodeGenerator(ctx, false)
+      .bindInput(inputType, inputTerm = inputTerm)
+    val converter = new ExpressionConverter(builder)
+
+    var currentAggBufferExprIdx = auxGrouping.length
+
+    val mergeCode = aggInfos.map { aggInfo =>
+
+      val aggBufferIdx = auxGrouping.length + aggInfo.aggIndex
+
+      aggInfo.function match {
+
+        // merge each agg buffer for declarative functions
+        case function: DeclarativeAggregateFunction =>
+          val ref = ResolveReference(
+            ctx,
+            builder,
+            isMerge = true,
+            function,
+            aggBufferIdx,
+            argsMapping,
+            aggBufferTypes)
+          val mergeExprs = function.mergeExpressions
+            .map(_.accept(ref))
+            .map(_.accept(converter))
+            .map(exprCodeGen.generateExpression)
+          mergeExprs
+            .map { mergeExpr =>
+              val aggBufferExpr = aggBufferExprs(currentAggBufferExprIdx)
+              currentAggBufferExprIdx += 1
+              s"""
+                 |${mergeExpr.code}
+                 |${aggBufferExpr.nullTerm} = ${mergeExpr.nullTerm};
+                 |if (!${mergeExpr.nullTerm}) {
+                 |  ${mergeExpr.copyResultTermToTargetIfChanged(ctx, aggBufferExpr.resultTerm)}
+                 |}
+              """.stripMargin
+            }
+            .mkString("\n")
+
+        // call merge() for imperative functions
+        case function: AggregateFunction[_, _] =>
+          val (inputIndex, inputType) = argsMapping(aggBufferIdx).head
+          val inputRef = toRexInputRef(builder, inputIndex, inputType)
+          val inputExpr = exprCodeGen.generateExpression(
+            inputRef.accept(converter))
+          val aggBufferName = aggBufferNames(aggBufferIdx).head
+          val aggBufferExpr = aggBufferExprs(currentAggBufferExprIdx)
+          currentAggBufferExprIdx += 1
+          val iterableTypeTerm = className[SingleElementIterator[_]]
+          val externalAccType = aggInfo.externalAccTypes.head
+          val externalAccTypeTerm = typeTerm(externalAccType.getConversionClass)
+          val externalAccTerm = newName("acc")
+          val aggIndex = aggInfo.aggIndex
+          s"""
+            |$iterableTypeTerm accIt$aggIndex = new $iterableTypeTerm();
+            |accIt$aggIndex.set(${
+              genToExternal(ctx, externalAccType, inputExpr.resultTerm)});
+            |$externalAccTypeTerm $externalAccTerm = ${
+              genToExternal(ctx, externalAccType, aggBufferName)};
+            |${functionIdentifiers(function)}.merge($externalAccTerm, accIt$aggIndex);
+            |$aggBufferName = ${genToInternal(ctx, externalAccType)(externalAccTerm)};
+            |${aggBufferExpr.nullTerm} = ${aggBufferName}IsNull || ${inputExpr.nullTerm};
+          """.stripMargin
+      }
+    }
+
+    mergeCode.mkString("\n")
   }
 
   /**
@@ -587,81 +662,96 @@ object AggCodeGenHelper {
       inputTerm: String,
       inputType: RowType,
       auxGrouping: Array[Int],
-      aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
-      udaggs: Map[AggregateFunction[_, _], String],
+      aggInfos: Seq[AggregateInfo],
+      functionIdentifiers: Map[AggregateFunction[_, _], String],
       argsMapping: Array[Array[(Int, LogicalType)]],
       aggBufferNames: Array[Array[String]],
       aggBufferTypes: Array[Array[LogicalType]],
-      aggBufferExprs: Seq[GeneratedExpression]): String = {
-    val exprCodegen = new ExprCodeGenerator(ctx, false).bindInput(inputType, inputTerm = inputTerm)
-
-    // flat map to get flat agg buffers.
-    aggCallToAggFunction.zipWithIndex.flatMap {
-      case (aggCallToAggFun, aggIndex) =>
-        val idx = auxGrouping.length + aggIndex
-        val aggCall = aggCallToAggFun._1
-        aggCallToAggFun._2 match {
-          case agg: DeclarativeAggregateFunction =>
-            agg.accumulateExpressions.map(_.accept(ResolveReference(
-              ctx, builder, isMerge = false, agg, idx, argsMapping, aggBufferTypes)))
-                .map(e => (e, aggCall))
-          case agg: AggregateFunction[_, _] =>
-            val idx = auxGrouping.length + aggIndex
-            Some(agg, idx, aggCall)
-        }
-    }.zip(aggBufferExprs.slice(auxGrouping.length, aggBufferExprs.size)).map {
-      // DeclarativeAggregateFunction
-      case ((expr: Expression, aggCall: AggregateCall), aggBufVar) =>
-        val accExpr = exprCodegen.generateExpression(expr.accept(new ExpressionConverter(builder)))
-        (s"""
-            |${accExpr.code}
-            |${aggBufVar.nullTerm} = ${accExpr.nullTerm};
-            |if (!${accExpr.nullTerm}) {
-            |  ${accExpr.copyResultTermToTargetIfChanged(ctx, aggBufVar.resultTerm)}
-            |}
-           """.stripMargin, aggCall.filterArg)
-      // UserDefinedAggregateFunction
-      case ((agg: AggregateFunction[_, _], aggIndex: Int, aggCall: AggregateCall),
-      aggBufVar) =>
-        val inFields = argsMapping(aggIndex)
-        val externalAccType = getAccumulatorTypeOfAggregateFunction(agg)
-
-        val inputExprs = inFields.map {
-          f =>
-            val inputRef = toRexInputRef(builder, f._1, f._2)
-            exprCodegen.generateExpression(inputRef.accept(new ExpressionConverter(builder)))
-        }
-
-        val externalUDITypes = getAggUserDefinedInputTypes(
-          agg, externalAccType, inputExprs.map(_.resultType))
-        val parameters = inputExprs.zipWithIndex.map {
-          case (expr, i) =>
-            genToExternalIfNeeded(ctx, externalUDITypes(i), expr)
-        }
+      aggBufferExprs: Seq[GeneratedExpression])
+    : String = {
+    val exprCodeGen = new ExprCodeGenerator(ctx, false)
+      .bindInput(inputType, inputTerm = inputTerm)
+    val converter = new ExpressionConverter(builder)
 
-        val javaTerm = typeTerm(externalAccType.getConversionClass)
-        val tmpAcc = newName("tmpAcc")
-        val innerCode =
+    var currentAggBufferExprIdx = auxGrouping.length
+
+    val filteredAccCode = aggInfos.map { aggInfo =>
+
+      val aggCall = aggInfo.agg
+
+      val aggBufferIdx = auxGrouping.length + aggInfo.aggIndex
+
+      val accCode = aggInfo.function match {
+
+        // update each agg buffer for declarative functions
+        case function: DeclarativeAggregateFunction =>
+          val ref = ResolveReference(
+            ctx,
+            builder,
+            isMerge = false,
+            function,
+            aggBufferIdx,
+            argsMapping,
+            aggBufferTypes)
+          val accExprs = function.accumulateExpressions
+            .map(_.accept(ref))
+            .map(_.accept(new ExpressionConverter(builder)))
+            .map(exprCodeGen.generateExpression)
+          accExprs
+            .map { accExpr =>
+              val aggBufferExpr = aggBufferExprs(currentAggBufferExprIdx)
+              currentAggBufferExprIdx += 1
+              s"""
+                |${accExpr.code}
+                |${aggBufferExpr.nullTerm} = ${accExpr.nullTerm};
+                |if (!${accExpr.nullTerm}) {
+                |  // copy result term
+                |  ${accExpr.copyResultTermToTargetIfChanged(ctx, aggBufferExpr.resultTerm)}
+                |}
+              """.stripMargin
+            }
+            .mkString("\n")
+
+        // call accumulate() for imperative functions
+        case function: AggregateFunction[_, _] =>
+          val args = argsMapping(aggBufferIdx)
+          val inputExprs = args.map { case (argIndex, argType) =>
+              val inputRef = toRexInputRef(builder, argIndex, argType)
+              exprCodeGen.generateExpression(inputRef.accept(converter))
+          }
+          val operandTerms = inputExprs.zipWithIndex.map { case (expr, i) =>
+              genToExternalIfNeeded(ctx, aggInfo.externalArgTypes(i), expr)
+          }
+          val aggBufferName = aggBufferNames(aggBufferIdx).head
+          val aggBufferExpr = aggBufferExprs(currentAggBufferExprIdx)
+          currentAggBufferExprIdx += 1
+          val externalAccType = aggInfo.externalAccTypes.head
+          val externalAccTypeTerm = typeTerm(externalAccType.getConversionClass)
+          val externalAccTerm = newName("acc")
+          val externalAccCode = genToExternal(ctx, externalAccType, aggBufferName)
           s"""
-             |  $javaTerm $tmpAcc = ${
-            genToExternal(ctx, externalAccType, aggBufferNames(aggIndex)(0))};
-             |  ${udaggs(agg)}.accumulate($tmpAcc, ${parameters.mkString(", ")});
-             |  ${aggBufferNames(aggIndex)(0)} = ${genToInternal(ctx, externalAccType, tmpAcc)};
-             |  ${aggBufVar.nullTerm} = false;
-           """.stripMargin
-        (innerCode, aggCall.filterArg)
-    }.map({
-      case (innerCode, filterArg) =>
-        if (filterArg >= 0) {
-          s"""
-             |if ($inputTerm.getBoolean($filterArg)) {
-             | $innerCode
-             |}
+            |$externalAccTypeTerm $externalAccTerm = $externalAccCode;
+            |${functionIdentifiers(function)}.accumulate(
+            |  $externalAccTerm,
+            |  ${operandTerms.mkString(", ")});
+            |$aggBufferName = ${genToInternal(ctx, externalAccType)(externalAccTerm)};
+            |${aggBufferExpr.nullTerm} = false;
           """.stripMargin
-        } else {
-          innerCode
-        }
-    }) mkString "\n"
+      }
+
+      // apply filter if present
+      if (aggInfo.agg.filterArg >= 0) {
+        s"""
+          |if ($inputTerm.getBoolean(${aggCall.filterArg})) {
+          |  $accCode
+          |}
+        """.stripMargin
+      } else {
+        accCode
+      }
+    }
+
+    filteredAccCode.mkString("\n")
   }
 
   /**
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggWithoutKeysCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggWithoutKeysCodeGenerator.scala
index 43725b9..5a99b5a 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggWithoutKeysCodeGenerator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggWithoutKeysCodeGenerator.scala
@@ -44,19 +44,21 @@ object AggWithoutKeysCodeGenerator {
       outputType: RowType,
       isMerge: Boolean,
       isFinal: Boolean,
-      prefix: String): GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
-    val aggCallToAggFunction = aggInfoList.aggInfos.map(info => (info.agg, info.function))
-    val aggregates = aggCallToAggFunction.map(_._2)
-    val udaggs = AggCodeGenHelper.getUdaggs(aggregates)
-    val aggBufferNames = AggCodeGenHelper.getAggBufferNames(Array(), aggregates)
-    val aggBufferTypes = AggCodeGenHelper.getAggBufferTypes(inputType, Array(), aggregates)
-    val aggArgs = aggInfoList.aggInfos.map(_.argIndexes)
+      prefix: String)
+    : GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
 
-    val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM
+    // prepare for aggregation
+    val auxGrouping = Array[Int]()
+    val aggInfos = aggInfoList.aggInfos
+    aggInfos
+      .map(_.function)
+      .filter(_.isInstanceOf[AggregateFunction[_, _]])
+      .map(ctx.addReusableFunction(_))
+    val functionIdentifiers = AggCodeGenHelper.getFunctionIdentifiers(aggInfos)
+    val aggBufferNames = AggCodeGenHelper.getAggBufferNames(auxGrouping, aggInfos)
+    val aggBufferTypes = AggCodeGenHelper.getAggBufferTypes(inputType, auxGrouping, aggInfos)
 
-    // register udagg
-    aggregates.filter(a => a.isInstanceOf[AggregateFunction[_, _]])
-        .map(a => ctx.addReusableFunction(a))
+    val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM
 
     val (initAggBufferCode, doAggregateCode, aggOutputExpr) = genSortAggCodes(
       isMerge,
@@ -65,11 +67,8 @@ object AggWithoutKeysCodeGenerator {
       builder,
       Array(),
       Array(),
-      aggCallToAggFunction,
-      aggArgs,
-      aggregates,
-      aggInfoList.aggInfos.map(_.externalResultType),
-      udaggs,
+      aggInfos,
+      functionIdentifiers,
       inputTerm,
       inputType,
       aggBufferNames,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenHelper.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenHelper.scala
index 5ba6227..e8efbee 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenHelper.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenHelper.scala
@@ -18,12 +18,13 @@
 
 package org.apache.flink.table.planner.codegen.agg.batch
 
+import org.apache.calcite.tools.RelBuilder
 import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
 import org.apache.flink.metrics.Gauge
 import org.apache.flink.table.data.binary.BinaryRowData
 import org.apache.flink.table.data.{GenericRowData, JoinedRowData, RowData}
-import org.apache.flink.table.expressions.{Expression, _}
-import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
+import org.apache.flink.table.expressions._
+import org.apache.flink.table.functions.AggregateFunction
 import org.apache.flink.table.planner.codegen.CodeGenUtils.{binaryRowFieldSetAccess, binaryRowSetNull}
 import org.apache.flink.table.planner.codegen._
 import org.apache.flink.table.planner.codegen.agg.batch.AggCodeGenHelper.buildAggregateArgsMapping
@@ -32,17 +33,13 @@ import org.apache.flink.table.planner.expressions.DeclarativeExpressionResolver
 import org.apache.flink.table.planner.expressions.DeclarativeExpressionResolver.toRexInputRef
 import org.apache.flink.table.planner.expressions.converter.ExpressionConverter
 import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction
-import org.apache.flink.table.planner.plan.utils.SortUtil
+import org.apache.flink.table.planner.plan.utils.{AggregateInfo, SortUtil}
 import org.apache.flink.table.runtime.generated.{NormalizedKeyComputer, RecordComparator}
 import org.apache.flink.table.runtime.operators.aggregate.{BytesHashMap, BytesHashMapSpillMemorySegmentPool}
 import org.apache.flink.table.runtime.operators.sort.BufferedKVExternalSorter
 import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer
-import org.apache.flink.table.types.DataType
 import org.apache.flink.table.types.logical.{LogicalType, RowType}
 
-import org.apache.calcite.rel.core.AggregateCall
-import org.apache.calcite.tools.RelBuilder
-
 import scala.collection.JavaConversions._
 
 object HashAggCodeGenHelper {
@@ -118,24 +115,23 @@ object HashAggCodeGenHelper {
       groupingAndAuxGrouping: (Array[Int], Array[Int]),
       inputTerm: String,
       inputType: RowType,
-      aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
-      aggArgs: Array[Array[Int]],
-      aggregates: Seq[UserDefinedFunction],
+      aggInfos: Seq[AggregateInfo],
       currentAggBufferTerm: String,
       aggBufferRowType: RowType,
       aggBufferTypes: Array[Array[LogicalType]],
       outputTerm: String,
       outputType: RowType,
       groupKeyTerm: String,
-      aggBufferTerm: String): (GeneratedExpression, GeneratedExpression, GeneratedExpression) = {
+      aggBufferTerm: String)
+    : (GeneratedExpression, GeneratedExpression, GeneratedExpression) = {
     val (grouping, auxGrouping) = groupingAndAuxGrouping
     // build mapping for DeclarativeAggregationFunction binding references
     val argsMapping = buildAggregateArgsMapping(
-      isMerge, grouping.length, inputType, auxGrouping, aggArgs, aggBufferTypes)
+      isMerge, grouping.length, inputType, auxGrouping, aggInfos, aggBufferTypes)
     val aggBuffMapping = buildAggregateAggBuffMapping(aggBufferTypes)
     // gen code to create empty agg buffer
     val initedAggBuffer = genReusableEmptyAggBuffer(
-      ctx, builder, inputTerm, inputType, auxGrouping, aggregates, aggBufferRowType)
+      ctx, builder, inputTerm, inputType, auxGrouping, aggInfos, aggBufferRowType)
     if (auxGrouping.isEmpty) {
       // create an empty agg buffer and initialized make it reusable
       ctx.addReusableOpenStatement(initedAggBuffer.code)
@@ -148,8 +144,7 @@ object HashAggCodeGenHelper {
       inputType,
       inputTerm,
       auxGrouping,
-      aggregates,
-      aggCallToAggFunction,
+      aggInfos,
       argsMapping,
       aggBuffMapping,
       currentAggBufferTerm,
@@ -161,7 +156,7 @@ object HashAggCodeGenHelper {
       ctx,
       builder,
       auxGrouping,
-      aggregates,
+      aggInfos,
       argsMapping,
       aggBuffMapping,
       outputTerm,
@@ -196,26 +191,29 @@ object HashAggCodeGenHelper {
       inputTerm: String,
       inputType: RowType,
       auxGrouping: Array[Int],
-      aggregates: Seq[UserDefinedFunction],
-      aggBufferType: RowType): GeneratedExpression = {
-    val exprCodegen = new ExprCodeGenerator(ctx, false)
+      aggInfos: Seq[AggregateInfo],
+      aggBufferType: RowType)
+    : GeneratedExpression = {
+    val exprCodeGen = new ExprCodeGenerator(ctx, false)
         .bindInput(inputType, inputTerm = inputTerm, inputFieldMapping = Some(auxGrouping))
+    val converter = new ExpressionConverter(builder)
 
     val initAuxGroupingExprs = auxGrouping.map { idx =>
       GenerateUtils.generateFieldAccess(ctx, inputType, inputTerm, idx)
     }
 
-    val initAggCallBufferExprs = aggregates.flatMap(a =>
-      a.asInstanceOf[DeclarativeAggregateFunction].initialValuesExpressions)
-        .map(_.accept(new ExpressionConverter(builder)))
-        .map(exprCodegen.generateExpression)
+    val initAggCallBufferExprs = aggInfos
+      .map(_.function.asInstanceOf[DeclarativeAggregateFunction])
+      .flatMap(_.initialValuesExpressions)
+      .map(_.accept(converter))
+      .map(exprCodeGen.generateExpression)
 
     val initAggBufferExprs = initAuxGroupingExprs ++ initAggCallBufferExprs
 
     // empty agg buffer and writer will be reused
     val emptyAggBufferTerm = CodeGenUtils.newName("emptyAggBuffer")
     val emptyAggBufferWriterTerm = CodeGenUtils.newName("emptyAggBufferWriterTerm")
-    exprCodegen.generateResultExpression(
+    exprCodeGen.generateResultExpression(
       initAggBufferExprs,
       aggBufferType,
       classOf[BinaryRowData],
@@ -231,8 +229,7 @@ object HashAggCodeGenHelper {
       inputType: RowType,
       inputTerm: String,
       auxGrouping: Array[Int],
-      aggregates: Seq[UserDefinedFunction],
-      aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
+      aggInfos: Seq[AggregateInfo],
       argsMapping: Array[Array[(Int, LogicalType)]],
       aggBuffMapping: Array[Array[(Int, LogicalType)]],
       currentAggBufferTerm: String,
@@ -245,7 +242,7 @@ object HashAggCodeGenHelper {
         inputType,
         currentAggBufferTerm,
         auxGrouping,
-        aggregates,
+        aggInfos,
         argsMapping,
         aggBuffMapping,
         aggBufferRowType)
@@ -257,7 +254,7 @@ object HashAggCodeGenHelper {
         inputType,
         currentAggBufferTerm,
         auxGrouping,
-        aggCallToAggFunction,
+        aggInfos,
         argsMapping,
         aggBuffMapping,
         aggBufferRowType)
@@ -270,7 +267,7 @@ object HashAggCodeGenHelper {
       ctx: CodeGeneratorContext,
       builder: RelBuilder,
       auxGrouping: Array[Int],
-      aggregates: Seq[UserDefinedFunction],
+      aggInfos: Seq[AggregateInfo],
       argsMapping: Array[Array[(Int, LogicalType)]],
       aggBuffMapping: Array[Array[(Int, LogicalType)]],
       outputTerm: String,
@@ -279,29 +276,46 @@ object HashAggCodeGenHelper {
       inputType: RowType,
       groupKeyTerm: Option[String],
       aggBufferTerm: String,
-      aggBufferType: RowType): GeneratedExpression = {
-    // gen code to get agg result
-    val exprCodegen = new ExprCodeGenerator(ctx, false)
+      aggBufferType: RowType)
+    : GeneratedExpression = {
+    val exprCodeGen = new ExprCodeGenerator(ctx, false)
         .bindInput(inputType, inputTerm = inputTerm)
         .bindSecondInput(aggBufferType, inputTerm = aggBufferTerm)
+    val converter = new ExpressionConverter(builder)
+
     val resultExpr = if (isFinal) {
+
       val bindRefOffset = inputType.getFieldCount
-      val getAuxGroupingExprs = auxGrouping.indices.map { idx =>
-        val (_, resultType) = aggBuffMapping(idx)(0)
-        toRexInputRef(builder, bindRefOffset + idx, resultType)
-      }.map(_.accept(new ExpressionConverter(builder))).map(exprCodegen.generateExpression)
 
-      val getAggValueExprs = aggregates.zipWithIndex.map {
-        case (agg: DeclarativeAggregateFunction, aggIndex) =>
-          val idx = auxGrouping.length + aggIndex
-          agg.getValueExpression.accept(ResolveReference(
-            ctx, builder, isMerge, bindRefOffset, agg, idx, argsMapping, aggBuffMapping))
-      }.map(_.accept(new ExpressionConverter(builder))).map(exprCodegen.generateExpression)
+      val getAuxGroupingExprs = auxGrouping.indices
+        .map { idx =>
+          val (_, resultType) = aggBuffMapping(idx)(0)
+          toRexInputRef(builder, bindRefOffset + idx, resultType)
+        }
+
+      val getAggValueExprs = aggInfos.map { aggInfo =>
+        val aggBufferIdx = auxGrouping.length + aggInfo.aggIndex
+        val function = aggInfo.function.asInstanceOf[DeclarativeAggregateFunction]
+        val ref = ResolveReference(
+          ctx,
+          builder,
+          isMerge,
+          bindRefOffset,
+          function,
+          aggBufferIdx,
+          argsMapping,
+          aggBuffMapping)
+        function.getValueExpression
+          .accept(ref)
+      }
+
+      val getValueExprs = (getAuxGroupingExprs ++ getAggValueExprs)
+        .map(_.accept(converter))
+        .map(exprCodeGen.generateExpression)
 
-      val getValueExprs = getAuxGroupingExprs ++ getAggValueExprs
       val aggValueTerm = CodeGenUtils.newName("aggVal")
       val valueType = RowType.of(getValueExprs.map(_.resultType): _*)
-      exprCodegen.generateResultExpression(
+      exprCodeGen.generateResultExpression(
         getValueExprs,
         valueType,
         classOf[GenericRowData],
@@ -365,22 +379,36 @@ object HashAggCodeGenHelper {
       inputType: RowType,
       currentAggBufferTerm: String,
       auxGrouping: Array[Int],
-      aggregates: Seq[UserDefinedFunction],
+      aggInfos: Seq[AggregateInfo],
       argsMapping: Array[Array[(Int, LogicalType)]],
       aggBuffMapping: Array[Array[(Int, LogicalType)]],
-      aggBufferType: RowType): GeneratedExpression = {
-    val exprCodegen = new ExprCodeGenerator(ctx, false)
+      aggBufferType: RowType)
+    : GeneratedExpression = {
+    val exprCodeGen = new ExprCodeGenerator(ctx, false)
         .bindInput(inputType, inputTerm = inputTerm)
         .bindSecondInput(aggBufferType, inputTerm = currentAggBufferTerm)
+    val converter = new ExpressionConverter(builder)
 
-    val mergeExprs = aggregates.zipWithIndex.flatMap {
-      case (agg: DeclarativeAggregateFunction, aggIndex) =>
-        val idx = auxGrouping.length + aggIndex
-        val bindRefOffset = inputType.getFieldCount
-        agg.mergeExpressions.map(
-          _.accept(ResolveReference(
-            ctx, builder, isMerge = true, bindRefOffset, agg, idx, argsMapping, aggBuffMapping)))
-    }.map(_.accept(new ExpressionConverter(builder))).map(exprCodegen.generateExpression)
+    val mergeExprs = aggInfos
+      .map(_.function)
+      .zipWithIndex
+      .flatMap {
+        case (agg: DeclarativeAggregateFunction, aggIndex) =>
+          val aggBufferIdx = auxGrouping.length + aggIndex
+          val bindRefOffset = inputType.getFieldCount
+          val ref = ResolveReference(
+            ctx,
+            builder,
+            isMerge = true,
+            bindRefOffset,
+            agg,
+            aggBufferIdx,
+            argsMapping,
+            aggBuffMapping)
+          agg.mergeExpressions.map(_.accept(ref))
+      }
+      .map(_.accept(converter))
+      .map(exprCodeGen.generateExpression)
 
     val aggBufferTypeWithoutAuxGrouping = if (auxGrouping.nonEmpty) {
       // auxGrouping does not need merge-code
@@ -398,7 +426,7 @@ object HashAggCodeGenHelper {
     }.toMap
 
     // update agg buff in-place
-    exprCodegen.generateResultExpression(
+    exprCodeGen.generateResultExpression(
       mergeExprs,
       mergeExprIdxToOutputRowPosMap,
       aggBufferTypeWithoutAuxGrouping,
@@ -423,30 +451,37 @@ object HashAggCodeGenHelper {
       inputType: RowType,
       currentAggBufferTerm: String,
       auxGrouping: Array[Int],
-      aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
+      aggInfos: Seq[AggregateInfo],
       argsMapping: Array[Array[(Int, LogicalType)]],
       aggBuffMapping: Array[Array[(Int, LogicalType)]],
-      aggBufferType: RowType): GeneratedExpression = {
-    val exprCodegen = new ExprCodeGenerator(ctx, false)
+      aggBufferType: RowType)
+    : GeneratedExpression = {
+    val exprCodeGen = new ExprCodeGenerator(ctx, false)
         .bindInput(inputType, inputTerm = inputTerm)
         .bindSecondInput(aggBufferType, inputTerm = currentAggBufferTerm)
-
-    val accumulateExprsWithFilterArgs = aggCallToAggFunction.zipWithIndex.flatMap {
-      case (aggCallToAggFun, aggIndex) =>
-        val idx = auxGrouping.length + aggIndex
-        val bindRefOffset = inputType.getFieldCount
-        val aggCall = aggCallToAggFun._1
-        aggCallToAggFun._2 match {
-          case agg: DeclarativeAggregateFunction =>
-            agg.accumulateExpressions.map(_.accept(ResolveReference(
-              ctx, builder, isMerge = false, bindRefOffset, agg, idx, argsMapping, aggBuffMapping))
-            ).map(e => (e, aggCall))
-        }
-    }.map {
-      case (expr: Expression, aggCall: AggregateCall) =>
-        (exprCodegen.generateExpression(expr.accept(new ExpressionConverter(builder))),
-            aggCall.filterArg)
-    }
+    val converter = new ExpressionConverter(builder)
+
+    val bindRefOffset = inputType.getFieldCount
+
+    val accumulateExprsWithFilterArgs = aggInfos
+      .flatMap { aggInfo =>
+        val aggBufferIdx = auxGrouping.length + aggInfo.aggIndex
+        val function = aggInfo.function.asInstanceOf[DeclarativeAggregateFunction]
+        val ref = ResolveReference(
+          ctx,
+          builder,
+          isMerge = false,
+          bindRefOffset,
+          function,
+          aggBufferIdx,
+          argsMapping,
+          aggBuffMapping)
+        function.accumulateExpressions
+          .map(_.accept(ref))
+          .map { e =>
+            (exprCodeGen.generateExpression(e.accept(converter)), aggInfo.agg.filterArg)
+          }
+      }
 
     // update agg buff in-place
     val code = accumulateExprsWithFilterArgs.zipWithIndex.map({
@@ -537,10 +572,8 @@ object HashAggCodeGenHelper {
       ctx: CodeGeneratorContext,
       builder: RelBuilder,
       groupingAndAuxGrouping: (Array[Int], Array[Int]),
-      aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
-      aggArgs: Array[Array[Int]],
-      aggResultTypes: Seq[DataType],
-      udaggs: Map[AggregateFunction[_, _], String],
+      aggInfos: Seq[AggregateInfo],
+      functionIdentifiers: Map[AggregateFunction[_, _], String],
       logTerm: String,
       aggregateMapTerm: String,
       aggMapKVTypesTerm: (String, String),
@@ -570,11 +603,8 @@ object HashAggCodeGenHelper {
         builder,
         grouping,
         auxGrouping,
-        aggCallToAggFunction,
-        aggArgs,
-        aggCallToAggFunction.map(_._2),
-        aggResultTypes,
-        udaggs,
+        aggInfos,
+        functionIdentifiers,
         aggregateMapTerm,
         (groupKeyRowType, aggBufferRowType),
         aggregateMapTerm,
@@ -697,11 +727,8 @@ object HashAggCodeGenHelper {
       builder: RelBuilder,
       grouping: Array[Int],
       auxGrouping: Array[Int],
-      aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
-      aggArgs: Array[Array[Int]],
-      aggregates: Seq[UserDefinedFunction],
-      aggResultTypes: Seq[DataType],
-      udaggs: Map[AggregateFunction[_, _], String],
+      aggInfos: Seq[AggregateInfo],
+      functionIdentifiers: Map[AggregateFunction[_, _], String],
       mapTerm: String,
       mapKVRowTypes: (RowType, RowType),
       aggregateMapTerm: String,
@@ -728,11 +755,8 @@ object HashAggCodeGenHelper {
       builder,
       grouping,
       auxGrouping,
-      aggCallToAggFunction,
-      aggArgs,
-      aggregates,
-      aggResultTypes,
-      udaggs,
+      aggInfos,
+      functionIdentifiers,
       fallbackInputTerm,
       fallbackInputType,
       aggBufferNames,
@@ -796,7 +820,7 @@ object HashAggCodeGenHelper {
       aggMapKeyType: RowType) : String = {
     val keyFieldTypes = aggMapKeyType.getChildren.toArray(Array[LogicalType]())
     val keys = keyFieldTypes.indices.toArray
-    val orders = keys.map((_) => true)
+    val orders = keys.map(_ => true)
     val nullsIsLast = SortUtil.getNullDefaultOrders(orders)
 
     val sortCodeGenerator = new SortCodeGenerator(
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
index 5e74793..66552cd 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
@@ -21,15 +21,14 @@ package org.apache.flink.table.planner.codegen.agg.batch
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator
 import org.apache.flink.table.data.binary.BinaryRowData
 import org.apache.flink.table.data.{GenericRowData, JoinedRowData, RowData}
-import org.apache.flink.table.functions.UserDefinedFunction
+import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
 import org.apache.flink.table.planner.codegen.{CodeGenUtils, CodeGeneratorContext, ProjectionCodeGenerator}
 import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction
-import org.apache.flink.table.planner.plan.utils.AggregateInfoList
+import org.apache.flink.table.planner.plan.utils.{AggregateInfo, AggregateInfoList}
 import org.apache.flink.table.runtime.generated.GeneratedOperator
 import org.apache.flink.table.runtime.operators.TableStreamOperator
 import org.apache.flink.table.runtime.operators.aggregate.{BytesHashMap, BytesHashMapSpillMemorySegmentPool}
-import org.apache.flink.table.types.logical.RowType
-
+import org.apache.flink.table.types.logical.{LogicalType, RowType}
 import org.apache.calcite.tools.RelBuilder
 
 /**
@@ -48,17 +47,20 @@ class HashAggCodeGenerator(
     isMerge: Boolean,
     isFinal: Boolean) {
 
+  private lazy val aggInfos: Array[AggregateInfo] = aggInfoList.aggInfos
+
+  private lazy val functionIdentifiers: Map[AggregateFunction[_, _], String] =
+    AggCodeGenHelper.getFunctionIdentifiers(aggInfos)
+
+  private lazy val aggBufferNames: Array[Array[String]] =
+    AggCodeGenHelper.getAggBufferNames(auxGrouping, aggInfos)
+
+  private lazy val aggBufferTypes: Array[Array[LogicalType]] = AggCodeGenHelper.getAggBufferTypes(
+    inputType,
+    auxGrouping,
+    aggInfos)
+
   private lazy val groupKeyRowType = AggCodeGenHelper.projectRowType(inputType, grouping)
-  private lazy val aggCallToAggFunction =
-    aggInfoList.aggInfos.map(info => (info.agg, info.function))
-  private lazy val aggregates: Seq[UserDefinedFunction] = aggInfoList.aggInfos.map(_.function)
-  private lazy val aggArgs: Array[Array[Int]] = aggInfoList.aggInfos.map(_.argIndexes)
-  // get udagg instance names
-  private lazy val udaggs = AggCodeGenHelper.getUdaggs(aggregates)
-  // currently put auxGrouping to aggBuffer in code-gen
-  private lazy val aggBufferNames = AggCodeGenHelper.getAggBufferNames(auxGrouping, aggregates)
-  private lazy val aggBufferTypes =
-    AggCodeGenHelper.getAggBufferTypes(inputType, auxGrouping, aggregates)
   private lazy val aggBufferRowType = RowType.of(aggBufferTypes.flatten, aggBufferNames.flatten)
 
   def genWithKeys(): GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
@@ -119,9 +121,7 @@ class HashAggCodeGenerator(
       (grouping, auxGrouping),
       inputTerm,
       inputType,
-      aggCallToAggFunction,
-      aggArgs,
-      aggregates,
+      aggInfos,
       currentAggBufferTerm,
       aggBufferRowType,
       aggBufferTypes,
@@ -143,10 +143,8 @@ class HashAggCodeGenerator(
       ctx,
       builder,
       (grouping, auxGrouping),
-      aggCallToAggFunction,
-      aggArgs,
-      aggInfoList.aggInfos.map(_.externalResultType),
-      udaggs,
+      aggInfos,
+      functionIdentifiers,
       logTerm,
       aggregateMapTerm,
       (groupKeyTypesTerm, aggBufferTypesTerm),
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashWindowCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashWindowCodeGenerator.scala
index 53cc295..cebfc73 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashWindowCodeGenerator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashWindowCodeGenerator.scala
@@ -369,11 +369,11 @@ class HashWindowCodeGenerator(
     // build mapping for DeclarativeAggregationFunction binding references
     val offset = if (isMerge) grouping.length + 1 else grouping.length
     val argsMapping = AggCodeGenHelper.buildAggregateArgsMapping(
-      isMerge, offset, inputType,  auxGrouping, aggArgs, aggBufferTypes)
+      isMerge, offset, inputType, auxGrouping, aggInfos, aggBufferTypes)
     val aggBuffMapping = HashAggCodeGenHelper.buildAggregateAggBuffMapping(aggBufferTypes)
     // gen code to create empty agg buffer
     val initedAggBuffer = HashAggCodeGenHelper.genReusableEmptyAggBuffer(
-      ctx, builder, inputTerm, inputType, auxGrouping, aggregates, aggBufferRowType)
+      ctx, builder, inputTerm, inputType, auxGrouping, aggInfos, aggBufferRowType)
     if (auxGrouping.isEmpty) {
       // init aggBuffer in open function when there is no auxGrouping
       ctx.addReusableOpenStatement(initedAggBuffer.code)
@@ -386,8 +386,7 @@ class HashWindowCodeGenerator(
       inputType,
       inputTerm,
       auxGrouping,
-      aggregates,
-      aggCallToAggFunction,
+      aggInfos,
       argsMapping,
       aggBuffMapping,
       currentAggBufferTerm,
@@ -650,7 +649,7 @@ class HashWindowCodeGenerator(
         ctx,
         builder,
         auxGrouping,
-        aggregates,
+        aggInfos,
         argsMapping,
         aggBuffMapping,
         outputTerm,
@@ -697,7 +696,7 @@ class HashWindowCodeGenerator(
         ctx,
         builder,
         auxGrouping,
-        aggregates,
+        aggInfos,
         argsMapping,
         aggBuffMapping,
         outputTerm,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala
index b413324..b0f6f99 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala
@@ -46,16 +46,20 @@ object SortAggCodeGenerator {
       grouping: Array[Int],
       auxGrouping: Array[Int],
       isMerge: Boolean,
-      isFinal: Boolean): GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
-    val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM
-
-    val aggCallToAggFunction = aggInfoList.aggInfos.map(info => (info.agg, info.function))
-    val aggArgs = aggInfoList.aggInfos.map(_.argIndexes)
-
-    // register udaggs
-    aggCallToAggFunction.map(_._2).filter(a => a.isInstanceOf[AggregateFunction[_, _]])
-        .map(a => ctx.addReusableFunction(a))
+      isFinal: Boolean)
+    : GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
+
+    // prepare for aggregation
+    val aggInfos = aggInfoList.aggInfos
+    aggInfos
+      .map(_.function)
+      .filter(_.isInstanceOf[AggregateFunction[_, _]])
+      .map(ctx.addReusableFunction(_))
+    val functionIdentifiers = AggCodeGenHelper.getFunctionIdentifiers(aggInfos)
+    val aggBufferNames = AggCodeGenHelper.getAggBufferNames(auxGrouping, aggInfos)
+    val aggBufferTypes = AggCodeGenHelper.getAggBufferTypes(inputType, auxGrouping, aggInfos)
 
+    val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM
     val lastKeyTerm = "lastKey"
     val currentKeyTerm = "currentKey"
     val currentKeyWriterTerm = "currentKeyWriter"
@@ -72,11 +76,6 @@ object SortAggCodeGenerator {
 
     val keyNotEquals = AggCodeGenHelper.genGroupKeyChangedCheckCode(currentKeyTerm, lastKeyTerm)
 
-    val aggregates = aggCallToAggFunction.map(_._2)
-    val udaggs = AggCodeGenHelper.getUdaggs(aggregates)
-    val aggBufferNames = AggCodeGenHelper.getAggBufferNames(auxGrouping, aggregates)
-    val aggBufferTypes = AggCodeGenHelper.getAggBufferTypes(inputType, auxGrouping, aggregates)
-
     val (initAggBufferCode, doAggregateCode, aggOutputExpr) = AggCodeGenHelper.genSortAggCodes(
       isMerge,
       isFinal,
@@ -84,11 +83,8 @@ object SortAggCodeGenerator {
       builder,
       grouping,
       auxGrouping,
-      aggCallToAggFunction,
-      aggArgs,
-      aggregates,
-      aggInfoList.aggInfos.map(_.externalResultType),
-      udaggs,
+      aggInfos,
+      functionIdentifiers,
       inputTerm,
       inputType,
       aggBufferNames,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortWindowCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortWindowCodeGenerator.scala
index f51a74a..3677b41 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortWindowCodeGenerator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortWindowCodeGenerator.scala
@@ -88,13 +88,15 @@ class SortWindowCodeGenerator(
     isMerge,
     isFinal) {
 
+  // prepare for aggregation
+  aggInfos
+      .map(_.function)
+      .filter(_.isInstanceOf[AggregateFunction[_, _]])
+      .map(ctx.addReusableFunction(_))
+
   def genWithoutKeys(): GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
     val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM
 
-    aggCallToAggFunction
-        .map(_._2).filter(a => a.isInstanceOf[AggregateFunction[_, _]])
-        .map(a => ctx.addReusableFunction(a))
-
     val timeWindowType = classOf[TimeWindow].getName
     val currentWindow = CodeGenUtils.newName("currentWindow")
     ctx.addReusableMember(s"transient $timeWindowType $currentWindow = null;")
@@ -158,10 +160,6 @@ class SortWindowCodeGenerator(
   }
 
   def genWithKeys(): GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
-    aggCallToAggFunction
-        .map(_._2).filter(a => a.isInstanceOf[AggregateFunction[_, _]])
-        .map(a => ctx.addReusableFunction(a))
-
     val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM
 
     val currentKey = CodeGenUtils.newName("currentKey")
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala
index 9b5cc68..9f9576b 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala
@@ -18,12 +18,17 @@
 
 package org.apache.flink.table.planner.codegen.agg.batch
 
+import org.apache.calcite.avatica.util.DateTimeUtils
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.core.AggregateCall
+import org.apache.calcite.tools.RelBuilder
+import org.apache.commons.math3.util.ArithmeticUtils
 import org.apache.flink.table.api.DataTypes
 import org.apache.flink.table.data.binary.BinaryRowData
 import org.apache.flink.table.data.{GenericRowData, JoinedRowData, RowData}
 import org.apache.flink.table.expressions.ExpressionUtils.extractValue
 import org.apache.flink.table.expressions.{Expression, ValueLiteralExpression}
-import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
+import org.apache.flink.table.functions.AggregateFunction
 import org.apache.flink.table.planner.JLong
 import org.apache.flink.table.planner.calcite.FlinkRelBuilder.PlannerNamedWindowProperty
 import org.apache.flink.table.planner.calcite.FlinkTypeFactory
@@ -37,24 +42,15 @@ import org.apache.flink.table.planner.codegen.agg.batch.WindowCodeGenerator.{asL
 import org.apache.flink.table.planner.expressions.CallExpressionResolver
 import org.apache.flink.table.planner.expressions.ExpressionBuilder._
 import org.apache.flink.table.planner.expressions.converter.ExpressionConverter
-import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction
-import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils.getAccumulatorTypeOfAggregateFunction
 import org.apache.flink.table.planner.plan.logical.{LogicalWindow, SlidingGroupWindow, TumblingGroupWindow}
-import org.apache.flink.table.planner.plan.utils.{AggregateInfoList, AggregateUtil}
+import org.apache.flink.table.planner.plan.utils.{AggregateInfo, AggregateInfoList, AggregateUtil}
 import org.apache.flink.table.runtime.operators.window.TimeWindow
 import org.apache.flink.table.runtime.operators.window.grouping.{HeapWindowsGrouping, WindowsGrouping}
-import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType
 import org.apache.flink.table.runtime.util.RowIterator
 import org.apache.flink.table.types.logical.LogicalTypeRoot.INTERVAL_DAY_TIME
 import org.apache.flink.table.types.logical._
 import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.hasRoot
 
-import org.apache.calcite.avatica.util.DateTimeUtils
-import org.apache.calcite.rel.`type`.RelDataType
-import org.apache.calcite.rel.core.AggregateCall
-import org.apache.calcite.tools.RelBuilder
-import org.apache.commons.math3.util.ArithmeticUtils
-
 import scala.collection.JavaConversions._
 
 abstract class WindowCodeGenerator(
@@ -71,49 +67,32 @@ abstract class WindowCodeGenerator(
     val isMerge: Boolean,
     val isFinal: Boolean) {
 
-  lazy val builder: RelBuilder = relBuilder.values(inputRowType)
-  lazy val timestampInternalType: LogicalType =
+  protected lazy val builder: RelBuilder = relBuilder.values(inputRowType)
+
+  protected lazy val aggInfos: Array[AggregateInfo] = aggInfoList.aggInfos
+
+  protected lazy val functionIdentifiers: Map[AggregateFunction[_, _], String] =
+    AggCodeGenHelper.getFunctionIdentifiers(aggInfos)
+
+  protected lazy val aggBufferNames: Array[Array[String]] =
+    AggCodeGenHelper.getAggBufferNames(auxGrouping, aggInfos)
+
+  protected lazy val aggBufferTypes: Array[Array[LogicalType]] = AggCodeGenHelper.getAggBufferTypes(
+    inputType,
+    auxGrouping,
+    aggInfos)
+
+  protected lazy val groupKeyRowType: RowType = AggCodeGenHelper.projectRowType(inputType, grouping)
+
+  private lazy val inputType: RowType =
+    FlinkTypeFactory.toLogicalType(inputRowType).asInstanceOf[RowType]
+
+  protected lazy val timestampInternalType: LogicalType =
     if (inputTimeIsDate) new IntType() else new BigIntType()
-  lazy val timestampInternalTypeName: String = if (inputTimeIsDate) "Int" else "Long"
-  lazy val aggCallToAggFunction: Array[(AggregateCall, UserDefinedFunction)] =
-    aggInfoList.aggInfos.map(info => (info.agg, info.function))
-  lazy val aggregateCalls: Seq[AggregateCall] = aggCallToAggFunction.map(_._1)
-  lazy val aggregates: Seq[UserDefinedFunction] = aggCallToAggFunction.map(_._2)
-
-  lazy val aggArgs: Array[Array[Int]] = aggInfoList.aggInfos.map(_.argIndexes)
-
-  // currently put auxGrouping to aggBuffer in code-gen
-  lazy val aggBufferNames: Array[Array[String]] = auxGrouping.zipWithIndex.map {
-    case (_, index) => Array(s"aux_group$index")
-  } ++ aggregates.zipWithIndex.toArray.map {
-    case (a: DeclarativeAggregateFunction, index) =>
-      val idx = auxGrouping.length + index
-      a.aggBufferAttributes.map(attr => s"agg${idx}_${attr.getName}")
-    case (_: AggregateFunction[_, _], index) =>
-      val idx = auxGrouping.length + index
-      Array(s"agg$idx")
-  }
 
-  lazy val aggBufferTypes: Array[Array[LogicalType]] = auxGrouping.map { index =>
-    Array(FlinkTypeFactory.toLogicalType(inputRowType.getFieldList.get(index).getType))
-  } ++ aggregates.map {
-    case a: DeclarativeAggregateFunction => a.getAggBufferTypes.map(_.getLogicalType)
-    case a: AggregateFunction[_, _] =>
-      Array(getAccumulatorTypeOfAggregateFunction(a)).map(fromDataTypeToLogicalType)
-  }.toArray[Array[LogicalType]]
-
-  lazy val groupKeyRowType: RowType = RowType.of(
-    grouping.map { index =>
-      FlinkTypeFactory.toLogicalType(inputRowType.getFieldList.get(index).getType)
-    }, grouping.map(inputRowType.getFieldNames.get(_)))
-
-  // get udagg instance names
-  lazy val udaggs: Map[AggregateFunction[_, _], String] = aggregates
-      .filter(a => a.isInstanceOf[AggregateFunction[_, _]])
-      .map(a => a -> CodeGenUtils.udfFieldName(a)).toMap
-      .asInstanceOf[Map[AggregateFunction[_, _], String]]
-
-  lazy val windowedGroupKeyType: RowType = RowType.of(
+  protected lazy val timestampInternalTypeName: String = if (inputTimeIsDate) "Int" else "Long"
+
+  private lazy val windowedGroupKeyType: RowType = RowType.of(
     (groupKeyRowType.getChildren :+ timestampInternalType).toArray,
     (groupKeyRowType.getFieldNames :+ "assignedTs$").toArray)
 
@@ -224,13 +203,13 @@ abstract class WindowCodeGenerator(
     // gen code to apply aggregate functions to grouping window elements
     val offset = if (enablePreAcc) grouping.length + 1 else grouping.length
     val argsMapping = buildAggregateArgsMapping(
-      enablePreAcc, offset, inputType, auxGrouping, aggArgs, aggBufferTypes)
+      enablePreAcc, offset, inputType, auxGrouping, aggInfos, aggBufferTypes)
     val aggBufferExprs = genFlatAggBufferExprs(
       enablePreAcc,
       ctx,
       builder,
       auxGrouping,
-      aggregates,
+      aggInfos,
       argsMapping,
       aggBufferNames,
       aggBufferTypes)
@@ -241,8 +220,8 @@ abstract class WindowCodeGenerator(
       inputTerm,
       grouping,
       auxGrouping,
-      aggregates,
-      udaggs,
+      aggInfos,
+      functionIdentifiers,
       aggBufferExprs)
     val doAggregateCode = genAggregateByFlatAggregateBuffer(
       enablePreAcc,
@@ -251,9 +230,8 @@ abstract class WindowCodeGenerator(
       inputType,
       inputTerm,
       auxGrouping,
-      aggCallToAggFunction,
-      aggregates,
-      udaggs,
+      aggInfos,
+      functionIdentifiers,
       argsMapping,
       aggBufferNames,
       aggBufferTypes,
@@ -271,9 +249,8 @@ abstract class WindowCodeGenerator(
         builder,
         grouping,
         auxGrouping,
-        aggregates,
-        aggInfoList.aggInfos.map(_.externalResultType),
-        udaggs,
+        aggInfos,
+        functionIdentifiers,
         argsMapping,
         aggBufferNames,
         aggBufferTypes,
@@ -442,13 +419,13 @@ abstract class WindowCodeGenerator(
         // case: global/complete window agg: Sliding window with with pane optimization
         val offset = if (isMerge) grouping.length + 1 else grouping.length
         val argsMapping = buildAggregateArgsMapping(
-          isMerge, offset, inputType, auxGrouping, aggArgs, aggBufferTypes)
+          isMerge, offset, inputType, auxGrouping, aggInfos, aggBufferTypes)
         val aggBufferExprs = genFlatAggBufferExprs(
           isMerge,
           ctx,
           builder,
           auxGrouping,
-          aggregates,
+          aggInfos,
           argsMapping,
           aggBufferNames,
           aggBufferTypes)
@@ -459,8 +436,8 @@ abstract class WindowCodeGenerator(
           inputTerm,
           grouping,
           auxGrouping,
-          aggregates,
-          udaggs,
+          aggInfos,
+          functionIdentifiers,
           aggBufferExprs)
         val doAggregateCode = genAggregateByFlatAggregateBuffer(
           isMerge,
@@ -469,9 +446,8 @@ abstract class WindowCodeGenerator(
           inputType,
           inputTerm,
           auxGrouping,
-          aggCallToAggFunction,
-          aggregates,
-          udaggs,
+          aggInfos,
+          functionIdentifiers,
           argsMapping,
           aggBufferNames,
           aggBufferTypes,
@@ -710,7 +686,7 @@ abstract class WindowCodeGenerator(
 
   def getAuxGrouping: Array[Int] = auxGrouping
 
-  def getAggCallList: Seq[AggregateCall] = aggCallToAggFunction.map(_._1)
+  def getAggCallList: Seq[AggregateCall] = aggInfos.map(_.agg)
 
   def getInputTimeValue(inputTerm: String, index: Int): String = {
     if(inputTimeIsDate) {
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/types/LogicalTypeDataTypeConverter.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/types/LogicalTypeDataTypeConverter.java
index fde61b8..47bc778 100644
--- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/types/LogicalTypeDataTypeConverter.java
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/types/LogicalTypeDataTypeConverter.java
@@ -34,6 +34,7 @@ import org.apache.flink.table.types.utils.TypeConversions;
 @Deprecated
 public class LogicalTypeDataTypeConverter {
 
+	@Deprecated
 	public static DataType fromLogicalTypeToDataType(LogicalType logicalType) {
 		return TypeConversions.fromLogicalToDataType(logicalType);
 	}
@@ -41,6 +42,7 @@ public class LogicalTypeDataTypeConverter {
 	/**
 	 * It convert {@link LegacyTypeInformationType} to planner types.
 	 */
+	@Deprecated
 	public static LogicalType fromDataTypeToLogicalType(DataType dataType) {
 		return PlannerTypeUtils.removeLegacyTypes(dataType.getLogicalType());
 	}