Posted to commits@flink.apache.org by ji...@apache.org on 2019/01/18 12:47:02 UTC

[flink] branch master updated: [FLINK-8739] [table] Optimize DISTINCT aggregates to use the same distinct accumulator if possible

This is an automated email from the ASF dual-hosted git repository.

jincheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 12260a2  [FLINK-8739] [table] Optimize DISTINCT aggregates to use the same distinct accumulator if possible
12260a2 is described below

commit 12260a27ddb0ec8356b4981f033c6c9954fe2ab4
Author: Dian Fu <fu...@alibaba-inc.com>
AuthorDate: Mon Dec 10 20:57:23 2018 +0800

    [FLINK-8739] [table] Optimize DISTINCT aggregates to use the same distinct accumulator if possible
    
    This closes #7286
---
 .../table/codegen/AggregationCodeGenerator.scala   | 318 +++++++++++----------
 .../flink/table/codegen/MatchCodeGenerator.scala   |  32 ++-
 .../aggfunctions/DistinctAccumulator.scala         |  19 +-
 .../functions/utils/UserDefinedFunctionUtils.scala |   7 +-
 .../table/runtime/aggregate/AggregateUtil.scala    | 189 +++++++-----
 .../runtime/utils/JavaUserDefinedAggFunctions.java |  76 +++++
 .../harness/GroupAggregateHarnessTest.scala        | 185 ++++++++++--
 .../table/runtime/harness/HarnessTestBase.scala    | 187 ++----------
 .../runtime/stream/sql/OverWindowITCase.scala      |  47 +--
 .../flink/table/runtime/stream/sql/SqlITCase.scala |  15 +-
 .../runtime/stream/table/AggregateITCase.scala     |  23 +-
 11 files changed, 627 insertions(+), 471 deletions(-)
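
Note on the change, for readers skimming the diff: aggregates that are DISTINCT
on the same argument set, such as COUNT(DISTINCT b) and SUM(DISTINCT b), used to
each wrap their real accumulator in a private DistinctAccumulator; after this
commit they share a single reference-counting filter, and the real accumulators
get their own slots in the accumulator row. A minimal, self-contained Scala
sketch of the sharing mechanics (illustrative only, not the generated code; the
class and method names here are hypothetical):

    import scala.collection.mutable

    // One reference-counting filter guards every aggregate that is DISTINCT
    // on the same argument set (cf. DistinctAccumulator below).
    class SharedDistinctFilter[K] {
      private val counts = mutable.Map[K, Long]().withDefaultValue(0L)

      // true only for the first occurrence of a key
      def add(key: K): Boolean = { counts(key) += 1; counts(key) == 1L }

      // true only when the last occurrence of a key is removed
      def remove(key: K): Boolean = {
        counts(key) -= 1
        if (counts(key) <= 0L) { counts.remove(key); true } else false
      }
    }

    object SharedDistinctFilterExample extends App {
      val filter = new SharedDistinctFilter[Int]
      var count = 0L // COUNT(DISTINCT b)
      var sum = 0L   // SUM(DISTINCT b)
      for (b <- Seq(1, 1, 2)) {
        if (filter.add(b)) { // one membership check drives both aggregates
          count += 1
          sum += b
        }
      }
      println(s"count=$count, sum=$sum") // prints: count=2, sum=3
    }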

diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
index 57cc815..ea148bc 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
@@ -19,17 +19,17 @@ package org.apache.flink.table.codegen
 
 import java.lang.reflect.Modifier
 import java.lang.{Iterable => JIterable}
+import java.util.{List => JList}
 
 import org.apache.calcite.rex.RexLiteral
 import org.apache.flink.api.common.state.{ListStateDescriptor, MapStateDescriptor, State, StateDescriptor}
-import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
-import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.typeutils.TypeExtractionUtils.{extractTypeArgument, getRawClass}
 import org.apache.flink.table.api.TableConfig
 import org.apache.flink.table.api.dataview._
 import org.apache.flink.table.codegen.CodeGenUtils.{newName, reflectiveFieldWriteAccess}
 import org.apache.flink.table.codegen.Indenter.toISC
-import org.apache.flink.table.dataview.{MapViewTypeInfo, StateListView, StateMapView}
+import org.apache.flink.table.dataview.{StateListView, StateMapView}
 import org.apache.flink.table.functions.AggregateFunction
 import org.apache.flink.table.functions.aggfunctions.DistinctAccumulator
 import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
@@ -38,6 +38,7 @@ import org.apache.flink.table.runtime.aggregate.{GeneratedAggregations, SingleEl
 import org.apache.flink.table.utils.EncodingUtils
 import org.apache.flink.types.Row
 
+import scala.collection.JavaConversions._
 import scala.collection.mutable
 
 /**
@@ -77,7 +78,8 @@ class AggregationCodeGenerator(
     * @param aggregates  All aggregate functions
     * @param aggFields   Indexes of the input fields for all aggregate functions
     * @param aggMapping  The mapping of aggregates to output fields
-    * @param isDistinctAggs The flag array indicating whether it is distinct aggregate.
+    * @param distinctAccMapping The mapping of the distinct accumulator index to the
+    *                           corresponding aggregates.
     * @param isStateBackedDataViews a flag to indicate if distinct filter uses state backend.
     * @param partialResults A flag defining whether final or partial results (accumulators) are set
     *                       to the output row.
@@ -98,7 +100,7 @@ class AggregationCodeGenerator(
       aggregates: Array[AggregateFunction[_ <: Any, _ <: Any]],
       aggFields: Array[Array[Int]],
       aggMapping: Array[Int],
-      isDistinctAggs: Array[Boolean],
+      distinctAccMapping: Array[(Integer, JList[Integer])],
       isStateBackedDataViews: Boolean,
       partialResults: Boolean,
       fwdMapping: Array[Int],
@@ -142,21 +144,35 @@ class AggregationCodeGenerator(
       fields.mkString(", ")
     }
 
-    val parametersCodeForDistinctMerge = aggFields.map { inFields =>
-      val fields = inFields.filter(_ > -1).zipWithIndex.map { case (f, i) =>
-        // index to constant
-        if (f >= physicalInputTypes.length) {
-          constantFields(f - physicalInputTypes.length)
-        }
+    // get parameter lists for distinct acc, constant fields are not necessary
+    val parametersCodeForDistinctAcc = aggFields.map { inFields =>
+      val fields = inFields.filter(i => i > -1 && i < physicalInputTypes.length).map { f =>
         // index to input field
-        else {
-          s"(${CodeGenUtils.boxedTypeTermForTypeInfo(physicalInputTypes(f))}) k.getField($i)"
-        }
+        s"(${CodeGenUtils.boxedTypeTermForTypeInfo(physicalInputTypes(f))}) input.getField($f)"
       }
 
       fields.mkString(", ")
     }
 
+    val parametersCodeForDistinctMerge = aggFields.map { inFields =>
+      // transform inFields to pairs of (inField, index in acc) firstly,
+      // e.g. (4, 2, 3, 2) will be transformed to ((4,2), (2,0), (3,1), (2,0))
+      val fields = inFields.filter(_ > -1).groupBy(identity).toSeq.sortBy(_._1).zipWithIndex
+        .flatMap { case (a, i) => a._2.map((_, i)) }
+        .map { case (f, i) =>
+          // index to constant
+          if (f >= physicalInputTypes.length) {
+            constantFields(f - physicalInputTypes.length)
+          }
+          // index to input field
+          else {
+            s"(${CodeGenUtils.boxedTypeTermForTypeInfo(physicalInputTypes(f))}) k.getField($i)"
+          }
+        }
+
+      fields.mkString(", ")
+    }
+
     // get method signatures
     val classes = UserDefinedFunctionUtils.typeInfoToClass(physicalInputTypes)
     val constantClasses = UserDefinedFunctionUtils.typeInfoToClass(constantTypes)
@@ -174,30 +190,11 @@ class AggregationCodeGenerator(
     }
 
     // get distinct filter of acc fields for each aggregate functions
-    val distinctAccType = s"${classOf[DistinctAccumulator[_]].getName}"
-
-    // preparing MapViewSpecs for distinct value maps
-    val distinctAggs: Array[Seq[DataViewSpec[_]]] = isDistinctAggs.zipWithIndex.map {
-      case (isDistinctAgg, idx) if isDistinctAgg =>
-
-        // get types of agg function arguments
-        val argTypes: Array[TypeInformation[_]] = aggFields(idx)
-          .map(physicalInputTypes(_))
-        // create type for MapView
-        val mapViewTypeInfo = new MapViewTypeInfo(
-          new RowTypeInfo(argTypes:_*),
-          BasicTypeInfo.LONG_TYPE_INFO)
-        // create MapViewSpec for distinct value map
-        Seq(
-          MapViewSpec(
-            "distinctAgg" + idx,
-            classOf[DistinctAccumulator[_]].getDeclaredField("distinctValueMap"),
-            mapViewTypeInfo)
-        )
-      case _ => Seq()
-    }
+    val distinctAccType = s"${classOf[DistinctAccumulator].getName}"
+
+    val distinctAccCount = distinctAccMapping.count(_._1 >= 0)
 
-    if (isDistinctAggs.contains(true) && partialResults && isStateBackedDataViews) {
+    if (distinctAccCount > 0 && partialResults && isStateBackedDataViews) {
       // should not happen, but add an error message just in case.
       throw new CodeGenException(
         s"Cannot emit partial results if DISTINCT values are tracked in state-backed maps. " +
@@ -266,31 +263,13 @@ class AggregationCodeGenerator(
       * aggregation functions.
       */
     def addAccumulatorDataViews(): Unit = {
-      if (isStateBackedDataViews) {
-        // create MapStates for distinct value maps
-        val descMapping: Map[String, StateDescriptor[_, _]] = distinctAggs
-          .flatMap(specs => specs.map(s => (s.stateId, s.toStateDescriptor)))
-          .toMap[String, StateDescriptor[_ <: State, _]]
-
-        for (i <- aggs.indices) yield {
-          for (spec <- distinctAggs(i)) {
-            // Check if stat descriptor exists.
-            val desc: StateDescriptor[_, _] = descMapping.getOrElse(spec.stateId,
-              throw new CodeGenException(
-                s"Can not find DataView for distinct filter in accumulator by id: ${spec.stateId}"))
-
-            addReusableDataView(spec, desc, i)
-          }
-        }
-      }
-
       if (accConfig.isDefined) {
         // create state handles for DataView backed accumulator fields.
         val descMapping: Map[String, StateDescriptor[_, _]] = accConfig.get
           .flatMap(specs => specs.map(s => (s.stateId, s.toStateDescriptor)))
           .toMap[String, StateDescriptor[_ <: State, _]]
 
-        for (i <- aggs.indices) yield {
+        for (i <- 0 until aggs.length + distinctAccCount) yield {
           for (spec <- accConfig.get(i)) yield {
             // Check if stat descriptor exists.
             val desc: StateDescriptor[_, _] = descMapping.getOrElse(spec.stateId,
@@ -375,14 +354,6 @@ class AggregationCodeGenerator(
       reusableCleanupStatements.add(cleanup)
     }
 
-    def genDistinctDataViewFieldSetter(str: String, i: Int): String = {
-      if (isStateBackedDataViews && distinctAggs(i).nonEmpty) {
-        genDataViewFieldSetter(distinctAggs(i), str, i)
-      } else {
-        ""
-      }
-    }
-
     def genAccDataViewFieldSetter(str: String, i: Int): String = {
       if (accConfig.isDefined) {
         genDataViewFieldSetter(accConfig.get(i), str, i)
@@ -429,38 +400,55 @@ class AggregationCodeGenerator(
            |    org.apache.flink.types.Row output) throws Exception """.stripMargin
 
       val setAggs: String = {
-        for (i <- aggs.indices) yield
-
+        for ((i, aggIndexes) <- distinctAccMapping) yield {
           if (partialResults) {
-            j"""
-               |    output.setField(
-               |      ${aggMapping(i)},
-               |      (${accTypes(i)}) accs.getField($i));""".stripMargin
-          } else {
-            val setAccOutput =
+            def setAggs(aggIndexes: JList[Integer]) = {
+              for (i <- aggIndexes) yield {
+                j"""
+                   |output.setField(
+                   |  ${aggMapping(i)},
+                   |  (${accTypes(i)}) accs.getField($i));
+                 """.stripMargin
+              }
+            }.mkString("\n")
+
+            if (i >= 0) {
               j"""
-                 |    ${genAccDataViewFieldSetter(s"acc$i", i)}
                  |    output.setField(
                  |      ${aggMapping(i)},
-                 |      baseClass$i.getValue(acc$i));
+                 |      ($distinctAccType) accs.getField($i));
+                 |    ${setAggs(aggIndexes)}
                  """.stripMargin
-            if (isDistinctAggs(i)) {
-                j"""
-                   |    org.apache.flink.table.functions.AggregateFunction baseClass$i =
-                   |      (org.apache.flink.table.functions.AggregateFunction) ${aggs(i)};
-                   |    $distinctAccType distinctAcc$i = ($distinctAccType) accs.getField($i);
-                   |    ${accTypes(i)} acc$i = (${accTypes(i)}) distinctAcc$i.getRealAcc();
-                   |    $setAccOutput
-                   """.stripMargin
-              } else {
+            } else {
+              j"""
+                 |    ${setAggs(aggIndexes)}
+                 """.stripMargin
+            }
+          } else {
+            def setAggs(aggIndexes: JList[Integer]) = {
+              for (i <- aggIndexes) yield {
+                val setAccOutput =
+                  j"""
+                     |${genAccDataViewFieldSetter(s"acc$i", i)}
+                     |output.setField(
+                     |  ${aggMapping(i)},
+                     |  baseClass$i.getValue(acc$i));
+                     """.stripMargin
+
                 j"""
-                   |    org.apache.flink.table.functions.AggregateFunction baseClass$i =
-                   |      (org.apache.flink.table.functions.AggregateFunction) ${aggs(i)};
-                   |    ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
-                   |    $setAccOutput
+                   |org.apache.flink.table.functions.AggregateFunction baseClass$i =
+                   |  (org.apache.flink.table.functions.AggregateFunction) ${aggs(i)};
+                   |${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
+                   |$setAccOutput
                    """.stripMargin
               }
-            }
+            }.mkString("\n")
+
+            j"""
+               |    ${setAggs(aggIndexes)}
+               """.stripMargin
+          }
+        }
       }.mkString("\n")
 
       j"""
@@ -478,27 +466,30 @@ class AggregationCodeGenerator(
            |    org.apache.flink.types.Row input) throws Exception """.stripMargin
 
       val accumulate: String = {
-        for (i <- aggs.indices) yield {
-          val accumulateAcc =
+        def accumulateAcc(aggIndexes: JList[Integer]) = {
+          for (i <- aggIndexes) yield {
             j"""
-               |      ${genAccDataViewFieldSetter(s"acc$i", i)}
-               |      ${aggs(i)}.accumulate(acc$i
-               |        ${if (!parametersCode(i).isEmpty) "," else ""} ${parametersCode(i)});
+               |${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
+               |${genAccDataViewFieldSetter(s"acc$i", i)}
+               |${aggs(i)}.accumulate(acc$i
+               |  ${if (!parametersCode(i).isEmpty) "," else ""} ${parametersCode(i)});
                """.stripMargin
-          if (isDistinctAggs(i)) {
+          }
+        }.mkString("\n")
+
+        for ((i, aggIndexes) <- distinctAccMapping) yield {
+          if (i >= 0) {
             j"""
                |    $distinctAccType distinctAcc$i = ($distinctAccType) accs.getField($i);
-               |    ${genDistinctDataViewFieldSetter(s"distinctAcc$i", i)}
-               |    if (distinctAcc$i.add(
-               |        ${classOf[Row].getCanonicalName}.of(${parametersCode(i)}))) {
-               |      ${accTypes(i)} acc$i = (${accTypes(i)}) distinctAcc$i.getRealAcc();
-               |      $accumulateAcc
+               |    ${genAccDataViewFieldSetter(s"distinctAcc$i", i)}
+               |    if (distinctAcc$i.add(${classOf[Row].getCanonicalName}.of(
+               |        ${parametersCodeForDistinctAcc(aggIndexes.get(0))}))) {
+               |      ${accumulateAcc(aggIndexes)}
                |    }
                """.stripMargin
           } else {
             j"""
-               |    ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
-               |    $accumulateAcc
+               |    ${accumulateAcc(aggIndexes)}
                """.stripMargin
           }
         }
@@ -518,27 +509,30 @@ class AggregationCodeGenerator(
            |    org.apache.flink.types.Row input) throws Exception """.stripMargin
 
       val retract: String = {
-        for (i <- aggs.indices) yield {
-          val retractAcc =
+        def retractAcc(aggIndexes: JList[Integer]) = {
+          for (i <- aggIndexes) yield {
             j"""
-               |    ${genAccDataViewFieldSetter(s"acc$i", i)}
-               |    ${aggs(i)}.retract(
-               |      acc$i ${if (!parametersCode(i).isEmpty) "," else ""} ${parametersCode(i)});
+               |${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
+               |${genAccDataViewFieldSetter(s"acc$i", i)}
+               |${aggs(i)}.retract(acc$i
+               |  ${if (!parametersCode(i).isEmpty) "," else ""} ${parametersCode(i)});
                """.stripMargin
-          if (isDistinctAggs(i)) {
+          }
+        }.mkString("\n")
+
+        for ((i, aggIndexes) <- distinctAccMapping) yield {
+          if (i >= 0) {
             j"""
                |    $distinctAccType distinctAcc$i = ($distinctAccType) accs.getField($i);
-               |    ${genDistinctDataViewFieldSetter(s"distinctAcc$i", i)}
-               |    if (distinctAcc$i.remove(
-               |        ${classOf[Row].getCanonicalName}.of(${parametersCode(i)}))) {
-               |      ${accTypes(i)} acc$i = (${accTypes(i)}) distinctAcc$i.getRealAcc();
-               |      $retractAcc
+               |    ${genAccDataViewFieldSetter(s"distinctAcc$i", i)}
+               |    if (distinctAcc$i.remove(${classOf[Row].getCanonicalName}.of(
+               |        ${parametersCodeForDistinctAcc(aggIndexes.get(0))}))) {
+               |      ${retractAcc(aggIndexes)}
                |    }
                """.stripMargin
           } else {
             j"""
-               |    ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
-               |    $retractAcc
+               |    ${retractAcc(aggIndexes)}
                """.stripMargin
           }
         }
@@ -565,26 +559,34 @@ class AggregationCodeGenerator(
       val init: String =
         j"""
            |      org.apache.flink.types.Row accs =
-           |          new org.apache.flink.types.Row(${aggs.length});"""
+           |          new org.apache.flink.types.Row(${aggs.length + distinctAccCount});"""
         .stripMargin
       val create: String = {
-        for (i <- aggs.indices) yield {
-          if (isDistinctAggs(i)) {
+        def createAcc(aggIndexes: JList[Integer]) = {
+          for (i <- aggIndexes) yield {
+            j"""
+               |${accTypes(i)} acc$i = (${accTypes(i)}) ${aggs(i)}.createAccumulator();
+               |accs.setField(
+               |  $i,
+               |  acc$i);
+               """.stripMargin
+          }
+        }.mkString("\n")
+
+        for ((i, aggIndexes) <- distinctAccMapping) yield {
+          if (i >= 0) {
             j"""
-               |    ${accTypes(i)} acc$i = (${accTypes(i)}) ${aggs(i)}.createAccumulator();
                |    $distinctAccType distinctAcc$i = ($distinctAccType)
-               |      new ${classOf[DistinctAccumulator[_]].getCanonicalName} (acc$i);
+               |      new ${classOf[DistinctAccumulator].getCanonicalName}();
                |    accs.setField(
                |      $i,
-               |      distinctAcc$i);"""
-              .stripMargin
+               |      distinctAcc$i);
+               |    ${createAcc(aggIndexes)}
+               """.stripMargin
           } else {
             j"""
-               |    ${accTypes(i)} acc$i = (${accTypes(i)}) ${aggs(i)}.createAccumulator();
-               |    accs.setField(
-               |      $i,
-               |      acc$i);"""
-              .stripMargin
+               |    ${createAcc(aggIndexes)}
+               """.stripMargin
           }
         }
       }.mkString("\n")
@@ -633,8 +635,7 @@ class AggregationCodeGenerator(
     }
 
     def genMergeAccumulatorsPair: String = {
-
-      val mapping = mergeMapping.getOrElse(aggs.indices.toArray)
+      val mapping = mergeMapping.getOrElse((0 until aggs.length + distinctAccCount).toArray)
 
       val sig: String =
         j"""
@@ -643,14 +644,35 @@ class AggregationCodeGenerator(
            |    org.apache.flink.types.Row b)
            """.stripMargin
       val merge: String = {
-        for (i <- aggs.indices) yield {
-          if (isDistinctAggs(i)) {
+        def accumulateAcc(aggIndexes: JList[Integer]) = {
+          for (i <- aggIndexes) yield {
+            j"""
+               |${accTypes(i)} aAcc$i = (${accTypes(i)}) a.getField($i);
+               |${aggs(i)}.accumulate(aAcc$i, ${parametersCodeForDistinctMerge(i)});
+               |a.setField($i, aAcc$i);
+               """.stripMargin
+          }
+        }.mkString("\n")
+
+        def mergeAcc(aggIndexes: JList[Integer]) = {
+          for (i <- aggIndexes) yield {
+            j"""
+               |${accTypes(i)} aAcc$i = (${accTypes(i)}) a.getField($i);
+               |${accTypes(i)} bAcc$i = (${accTypes(i)}) b.getField(${mapping(i)});
+               |accIt$i.setElement(bAcc$i);
+               |${aggs(i)}.merge(aAcc$i, accIt$i);
+               |a.setField($i, aAcc$i);
+               """.stripMargin
+          }
+        }.mkString("\n")
+
+        for ((i, aggIndexes) <- distinctAccMapping) yield {
+          if (i >= 0) {
             j"""
                |    $distinctAccType aDistinctAcc$i = ($distinctAccType) a.getField($i);
                |    $distinctAccType bDistinctAcc$i = ($distinctAccType) b.getField(${mapping(i)});
                |    java.util.Iterator<java.util.Map.Entry> mergeIt$i =
                |        bDistinctAcc$i.elements().iterator();
-               |    ${accTypes(i)} aAcc$i = (${accTypes(i)}) aDistinctAcc$i.getRealAcc();
                |
                |    while (mergeIt$i.hasNext()) {
                |      java.util.Map.Entry entry = (java.util.Map.Entry) mergeIt$i.next();
@@ -658,18 +680,14 @@ class AggregationCodeGenerator(
                |          (${classOf[Row].getCanonicalName}) entry.getKey();
                |      Long v = (Long) entry.getValue();
                |      if (aDistinctAcc$i.add(k, v)) {
-               |        ${aggs(i)}.accumulate(aAcc$i, ${parametersCodeForDistinctMerge(i)});
+               |        ${accumulateAcc(aggIndexes)}
                |      }
                |    }
                |    a.setField($i, aDistinctAcc$i);
                """.stripMargin
           } else {
             j"""
-               |    ${accTypes(i)} aAcc$i = (${accTypes(i)}) a.getField($i);
-               |    ${accTypes(i)} bAcc$i = (${accTypes(i)}) b.getField(${mapping(i)});
-               |    accIt$i.setElement(bAcc$i);
-               |    ${aggs(i)}.merge(aAcc$i, accIt$i);
-               |    a.setField($i, aAcc$i);
+               |    ${mergeAcc(aggIndexes)}
                """.stripMargin
           }
         }
@@ -716,20 +734,28 @@ class AggregationCodeGenerator(
            |    org.apache.flink.types.Row accs) throws Exception """.stripMargin
 
       val reset: String = {
-        for (i <- aggs.indices) yield {
-          if (isDistinctAggs(i)) {
+        def resetAcc(aggIndexes: JList[Integer]) = {
+          for (i <- aggIndexes) yield {
+            j"""
+               |${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
+               |${genAccDataViewFieldSetter(s"acc$i", i)}
+               |${aggs(i)}.resetAccumulator(acc$i);
+               """.stripMargin
+          }
+        }.mkString("\n")
+
+        for ((i, aggIndexes) <- distinctAccMapping) yield {
+          if (i >= 0) {
             j"""
                |    $distinctAccType distinctAcc$i = ($distinctAccType) accs.getField($i);
-               |    ${genDistinctDataViewFieldSetter(s"distinctAcc$i", i)}
-               |    ${accTypes(i)} acc$i = (${accTypes(i)}) distinctAcc$i.getRealAcc();
-               |    ${genAccDataViewFieldSetter(s"acc$i", i)}
+               |    ${genAccDataViewFieldSetter(s"distinctAcc$i", i)}
                |    distinctAcc$i.reset();
-               |    ${aggs(i)}.resetAccumulator(acc$i);""".stripMargin
+               |    ${resetAcc(aggIndexes)}
+               """.stripMargin
           } else {
             j"""
-               |    ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
-               |    ${genAccDataViewFieldSetter(s"acc$i", i)}
-               |    ${aggs(i)}.resetAccumulator(acc$i);""".stripMargin
+               |    ${resetAcc(aggIndexes)}
+               """.stripMargin
           }
         }
       }.mkString("\n")
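
Note: throughout the generator above, distinctAccMapping replaces the old
per-aggregate isDistinctAggs flags. Each entry pairs a distinct-accumulator
slot (or -1 for the group of non-distinct aggregates) with the indexes of the
aggregates that slot guards. An illustrative value, assuming a hypothetical
query with COUNT(DISTINCT b), SUM(DISTINCT b) and MAX(c):

    import java.util.{Arrays => JArrays, List => JList}

    // Aggregate indexes: 0 = COUNT(DISTINCT b), 1 = SUM(DISTINCT b), 2 = MAX(c).
    // Accumulator row: slots 0..2 hold the real accumulators; slot 3 holds the
    // single DistinctAccumulator shared by aggregates 0 and 1 (argument set {b}).
    val distinctAccMapping: Array[(Integer, JList[Integer])] = Array(
      (Int.box(3), JArrays.asList(Int.box(0), Int.box(1))), // shared distinct acc
      (Int.box(-1), JArrays.asList(Int.box(2)))             // non-distinct aggs
    )
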
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala
index 62cad01..6097178 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala
@@ -737,7 +737,7 @@ class MatchCodeGenerator(
         matchAgg.aggregations.map(_.aggFunction).toArray,
         matchAgg.aggregations.map(_.inputIndices).toArray,
         matchAgg.aggregations.indices.toArray,
-        Array.fill(matchAgg.aggregations.size)(false),
+        matchAgg.getDistinctAccMapping,
         isStateBackedDataViews = false,
         partialResults = false,
         Array.emptyIntArray,
@@ -781,18 +781,27 @@ class MatchCodeGenerator(
           callsWithIndices.map(_._2).toArray)
       })
 
+      val distinctAccMap: mutable.Map[util.Set[Integer], Integer] = mutable.Map()
       val aggs = logicalAggregates.zipWithIndex.map {
         case (agg, index) =>
           val result = AggregateUtil.extractAggregateCallMetadata(
             agg.function,
             isDistinct = false, // TODO properly set once supported in Calcite
+            distinctAccMap,
+            new util.ArrayList[Integer](), // TODO properly set once supported in Calcite
+            aggregates.length,
+            input.getArity,
             agg.inputTypes,
             needRetraction = false,
             config,
             isStateBackedDataViews = false,
             index)
 
-          SingleAggCall(result.aggregateFunction, agg.exprIndices.toArray, result.accumulatorSpecs)
+          SingleAggCall(
+            result.aggregateFunction,
+            agg.exprIndices.toArray,
+            result.accumulatorSpecs,
+            result.distinctAccIndex)
       }
 
       MatchAgg(aggs, inputRows.values.map(_._1).toSeq)
@@ -867,14 +876,25 @@ class MatchCodeGenerator(
     private case class SingleAggCall(
       aggFunction: TableAggregateFunction[_, _],
       inputIndices: Array[Int],
-      dataViews: Seq[DataViewSpec[_]]
+      dataViews: Seq[DataViewSpec[_]],
+      distinctAccIndex: Int
     )
 
     private case class MatchAgg(
       aggregations: Seq[SingleAggCall],
-      inputExprs: Seq[RexNode]
-    )
-
+      inputExprs: Seq[RexNode]) {
+
+      def getDistinctAccMapping: Array[(Integer, util.List[Integer])] = {
+        val distinctAccMapping = mutable.Map[Integer, util.List[Integer]]()
+        aggregations.map(_.distinctAccIndex).zipWithIndex.foreach {
+          case (distinctAccIndex, aggIndex) =>
+            distinctAccMapping
+              .getOrElseUpdate(distinctAccIndex, new util.ArrayList[Integer]())
+              .add(aggIndex)
+        }
+        distinctAccMapping.toArray
+      }
+    }
   }
 
 }
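
Note: getDistinctAccMapping above simply inverts the per-aggregate
distinctAccIndex values into that (accIndex, aggIndexes) shape. A standalone
Scala sketch of the same grouping:

    import java.util.{ArrayList => JArrayList, List => JList}
    import scala.collection.mutable

    // Invert per-aggregate distinct-acc indexes into (accIndex, aggIndexes) pairs.
    def toDistinctAccMapping(distinctAccIndexes: Seq[Int]): Array[(Integer, JList[Integer])] = {
      val mapping = mutable.LinkedHashMap[Integer, JList[Integer]]()
      distinctAccIndexes.zipWithIndex.foreach { case (accIndex, aggIndex) =>
        mapping.getOrElseUpdate(accIndex, new JArrayList[Integer]()).add(aggIndex)
      }
      mapping.toArray
    }

    // toDistinctAccMapping(Seq(3, 3, -1)) yields
    //   Array((3, [0, 1]), (-1, [2]))  // two aggs share acc 3, one is plain
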
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/DistinctAccumulator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/DistinctAccumulator.scala
index 3427c9c..1c54acc 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/DistinctAccumulator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/DistinctAccumulator.scala
@@ -28,30 +28,19 @@ import org.apache.flink.types.Row
 /**
   * Wraps an accumulator and adds a map to filter distinct values.
   *
-  * @param realAcc the wrapped accumulator.
   * @param distinctValueMap the [[MapView]] that stores the distinct filter hash map.
-  *
-  * @tparam ACC the accumulator type for the realAcc.
   */
-class DistinctAccumulator[ACC](
-    var realAcc: ACC,
-    var distinctValueMap: MapView[Row, JLong]) {
+class DistinctAccumulator(var distinctValueMap: MapView[Row, JLong]) {
 
   def this() {
-    this(null.asInstanceOf[ACC], new MapView[Row, JLong]())
-  }
-
-  def this(realAcc: ACC) {
-    this(realAcc, new MapView[Row, JLong]())
+    this(new MapView[Row, JLong]())
   }
 
-  def getRealAcc: ACC = realAcc
-
-  def canEqual(a: Any): Boolean = a.isInstanceOf[DistinctAccumulator[ACC]]
+  def canEqual(a: Any): Boolean = a.isInstanceOf[DistinctAccumulator]
 
   override def equals(that: Any): Boolean =
     that match {
-      case that: DistinctAccumulator[ACC] => that.canEqual(this) &&
+      case that: DistinctAccumulator => that.canEqual(this) &&
         this.distinctValueMap == that.distinctValueMap
       case _ => false
     }
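
Note: the slimmed-down accumulator now carries only the reference-counting
distinctValueMap. The generated merge loop in genMergeAccumulatorsPair above
relies on the two-argument add(k, v) folding a remote count into the local one
and reporting whether the key was new on this side. A hedged sketch of that
contract, with a Scala map standing in for the MapView:

    import scala.collection.mutable

    // Counting contract assumed by the generated merge loop (illustrative only).
    class CountingMapSketch[K] {
      private val counts = mutable.Map[K, Long]().withDefaultValue(0L)

      // Fold a remote count into the local one; report whether the key was
      // previously absent, i.e. whether the shared aggregates must accumulate it.
      def add(key: K, remoteCount: Long): Boolean = {
        val isNew = counts(key) == 0L
        counts(key) += remoteCount
        isNew
      }

      def elements: Iterator[(K, Long)] = counts.iterator
    }
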
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
index 4faa6ed..1e58612 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -464,15 +464,15 @@ object UserDefinedFunctionUtils {
     * Remove StateView fields from accumulator type information.
     *
     * @param index index of aggregate function
-    * @param aggFun aggregate function
+    * @param acc accumulator
     * @param accType accumulator type information, only support pojo type
     * @param isStateBackedDataViews is data views use state backend
     * @return mapping of accumulator type information and data view config which contains id,
     *         field name and state descriptor
     */
-  def removeStateViewFieldsFromAccTypeInfo(
+  def removeStateViewFieldsFromAccTypeInfo[ACC](
       index: Int,
-      aggFun: AggregateFunction[_, _],
+      acc: ACC,
       accType: TypeInformation[_],
       isStateBackedDataViews: Boolean)
     : (TypeInformation[_], Option[Seq[DataViewSpec[_]]]) = {
@@ -489,7 +489,6 @@ object UserDefinedFunctionUtils {
       )
     }
 
-    val acc = aggFun.createAccumulator()
     accType match {
       case pojoType: PojoTypeInfo[_] if pojoType.getArity > 0 =>
         val arity = pojoType.getArity
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index f5cf191..7c38254 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -18,6 +18,7 @@
 package org.apache.flink.table.runtime.aggregate
 
 import java.util
+import java.util.{ArrayList => JArrayList, List => JList}
 
 import org.apache.calcite.rel.`type`._
 import org.apache.calcite.rel.core.AggregateCall
@@ -27,7 +28,8 @@ import org.apache.calcite.sql.fun._
 import org.apache.calcite.sql.{SqlAggFunction, SqlKind}
 import org.apache.flink.api.common.functions.{MapFunction, RichGroupReduceFunction, AggregateFunction => DataStreamAggFunction, _}
 import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
-import org.apache.flink.api.java.typeutils.{PojoField, PojoTypeInfo, RowTypeInfo}
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.scala.typeutils.Types
 import org.apache.flink.streaming.api.functions.ProcessFunction
 import org.apache.flink.streaming.api.functions.windowing.{AllWindowFunction, WindowFunction}
 import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow}
@@ -36,7 +38,6 @@ import org.apache.flink.table.api.{StreamQueryConfig, TableConfig, TableExceptio
 import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
 import org.apache.flink.table.calcite.FlinkTypeFactory
 import org.apache.flink.table.codegen.AggregationCodeGenerator
-import org.apache.flink.table.dataview.MapViewTypeInfo
 import org.apache.flink.table.expressions.ExpressionUtils.isTimeIntervalLiteral
 import org.apache.flink.table.expressions._
 import org.apache.flink.table.functions.aggfunctions._
@@ -51,11 +52,11 @@ import org.apache.flink.types.Row
 
 import scala.collection.JavaConversions._
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 
 object AggregateUtil {
 
   type CalcitePair[T, R] = org.apache.calcite.util.Pair[T, R]
-  type JavaList[T] = java.util.List[T]
 
   /**
     * Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for unbounded OVER
@@ -88,6 +89,7 @@ object AggregateUtil {
     val aggregateMetadata = extractAggregateMetadata(
         namedAggregates.map(_.getKey),
         aggregateInputType,
+        inputFieldTypeInfo.length,
         needRetraction = false,
         tableConfig,
         isStateBackedDataViews = true)
@@ -104,7 +106,7 @@ object AggregateUtil {
       aggregateMetadata.getAggregateFunctions,
       aggregateMetadata.getAggregateIndices,
       aggMapping,
-      aggregateMetadata.getAggregatesDistinctFlags,
+      aggregateMetadata.getDistinctAccMapping,
       isStateBackedDataViews = true,
       partialResults = false,
       forwardMapping,
@@ -172,6 +174,7 @@ object AggregateUtil {
     val aggregateMetadata = extractAggregateMetadata(
         namedAggregates.map(_.getKey),
         inputRowType,
+        inputFieldTypes.length,
         consumeRetraction,
         tableConfig,
         isStateBackedDataViews = true)
@@ -185,7 +188,7 @@ object AggregateUtil {
       aggregateMetadata.getAggregateFunctions,
       aggregateMetadata.getAggregateIndices,
       aggMapping,
-      aggregateMetadata.getAggregatesDistinctFlags,
+      aggregateMetadata.getDistinctAccMapping,
       isStateBackedDataViews = true,
       partialResults = false,
       groupings,
@@ -240,6 +243,7 @@ object AggregateUtil {
     val aggregateMetadata = extractAggregateMetadata(
         namedAggregates.map(_.getKey),
         aggregateInputType,
+        inputFieldTypeInfo.length,
         needRetract,
         tableConfig,
         isStateBackedDataViews = true)
@@ -257,7 +261,7 @@ object AggregateUtil {
       aggregateMetadata.getAggregateFunctions,
       aggregateMetadata.getAggregateIndices,
       aggMapping,
-      aggregateMetadata.getAggregatesDistinctFlags,
+      aggregateMetadata.getDistinctAccMapping,
       isStateBackedDataViews = true,
       partialResults = false,
       forwardMapping,
@@ -346,6 +350,7 @@ object AggregateUtil {
     val aggregateMetadata = extractAggregateMetadata(
       namedAggregates.map(_.getKey),
       inputType,
+      inputFieldTypeInfo.length,
       needRetract,
       tableConfig)
 
@@ -385,8 +390,9 @@ object AggregateUtil {
         throw new UnsupportedOperationException(s"$window is currently not supported on batch")
     }
 
-    val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length)
-    val outputArity = aggregateMetadata.getAggregateCallsCount + groupings.length + 1
+    val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length, partialResults = true)
+    val outputArity = aggregateMetadata.getAggregateCallsCount + groupings.length +
+      aggregateMetadata.getDistinctAccCount + 1
 
     val genFunction = generator.generateAggregations(
       "DataSetAggregatePrepareMapHelper",
@@ -394,7 +400,7 @@ object AggregateUtil {
       aggregateMetadata.getAggregateFunctions,
       aggregateMetadata.getAggregateIndices,
       aggMapping,
-      aggregateMetadata.getAggregatesDistinctFlags,
+      aggregateMetadata.getDistinctAccMapping,
       isStateBackedDataViews = false,
       partialResults = true,
       groupings,
@@ -455,6 +461,7 @@ object AggregateUtil {
     val aggregateMetadata = extractAggregateMetadata(
       namedAggregates.map(_.getKey),
       physicalInputRowType,
+      physicalInputTypes.length,
       needRetract,
       tableConfig)
 
@@ -465,19 +472,20 @@ object AggregateUtil {
       physicalInputRowType,
       Some(Array(BasicTypeInfo.LONG_TYPE_INFO)))
 
-    val keysAndAggregatesArity = groupings.length + namedAggregates.length
+    val aggMappings = aggregateMetadata.getAdjustedMapping(groupings.length, partialResults = true)
+    val keysAndAggregatesArity =
+      groupings.length + namedAggregates.length + aggregateMetadata.getDistinctAccCount
 
     window match {
       case SlidingGroupWindow(_, _, size, slide) if isTimeInterval(size.resultType) =>
         // sliding time-window for partial aggregations
-        val aggMappings = aggregateMetadata.getAdjustedMapping(groupings.length)
         val genFunction = generator.generateAggregations(
           "DataSetAggregatePrepareMapHelper",
           physicalInputTypes,
           aggregateMetadata.getAggregateFunctions,
           aggregateMetadata.getAggregateIndices,
           aggMappings,
-          aggregateMetadata.getAggregatesDistinctFlags,
+          aggregateMetadata.getDistinctAccMapping,
           isStateBackedDataViews = false,
           partialResults = true,
           groupings.indices.toArray,
@@ -573,10 +581,11 @@ object AggregateUtil {
     val aggregateMetadata = extractAggregateMetadata(
       namedAggregates.map(_.getKey),
       physicalInputRowType,
+      physicalInputTypes.length,
       needRetract,
       tableConfig)
 
-    val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length)
+    val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length, partialResults = true)
 
     val genPreAggFunction = generator.generateAggregations(
       "GroupingWindowAggregateHelper",
@@ -584,12 +593,12 @@ object AggregateUtil {
       aggregateMetadata.getAggregateFunctions,
       aggregateMetadata.getAggregateIndices,
       aggMapping,
-      aggregateMetadata.getAggregatesDistinctFlags,
+      aggregateMetadata.getDistinctAccMapping,
       isStateBackedDataViews = false,
       partialResults = true,
       groupings.indices.toArray,
       Some(aggMapping),
-      outputType.getFieldCount,
+      outputType.getFieldCount + aggregateMetadata.getDistinctAccCount,
       needRetract,
       needMerge = true,
       needReset = true,
@@ -602,7 +611,7 @@ object AggregateUtil {
       aggregateMetadata.getAggregateFunctions,
       aggregateMetadata.getAggregateIndices,
       aggMapping,
-      aggregateMetadata.getAggregatesDistinctFlags,
+      aggregateMetadata.getDistinctAccMapping,
       isStateBackedDataViews = false,
       partialResults = false,
       groupings.indices.toArray,
@@ -730,12 +739,13 @@ object AggregateUtil {
     val aggregateMetadata = extractAggregateMetadata(
       namedAggregates.map(_.getKey),
       physicalInputRowType,
+      physicalInputTypes.length,
       needRetract,
       tableConfig)
 
-    val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length)
-
-    val keysAndAggregatesArity = groupings.length + namedAggregates.length
+    val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length, partialResults = true)
+    val keysAndAggregatesArity = groupings.length + aggregateMetadata.getAggregateCallsCount +
+      aggregateMetadata.getDistinctAccCount
 
     window match {
       case SessionGroupWindow(_, _, gap) =>
@@ -753,12 +763,12 @@ object AggregateUtil {
           aggregateMetadata.getAggregateFunctions,
           aggregateMetadata.getAggregateIndices,
           aggMapping,
-          aggregateMetadata.getAggregatesDistinctFlags,
+          aggregateMetadata.getDistinctAccMapping,
           isStateBackedDataViews = false,
           partialResults = true,
           groupings.indices.toArray,
           Some(aggMapping),
-          groupings.length + aggregateMetadata.getAggregateCallsCount + 2,
+          keysAndAggregatesArity + 2,
           needRetract,
           needMerge = true,
           needReset = true,
@@ -807,11 +817,13 @@ object AggregateUtil {
     val aggregateMetadata = extractAggregateMetadata(
       namedAggregates.map(_.getKey),
       physicalInputRowType,
+      physicalInputTypes.length,
       needRetract,
       tableConfig)
 
-    val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length)
-    val keysAndAggregatesArity = groupings.length + namedAggregates.length
+    val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length, partialResults = true)
+    val keysAndAggregatesArity =
+      groupings.length + namedAggregates.length + aggregateMetadata.getDistinctAccCount
 
     window match {
 
@@ -830,7 +842,7 @@ object AggregateUtil {
           aggregateMetadata.getAggregateFunctions,
           aggregateMetadata.getAggregateIndices,
           aggMapping,
-          aggregateMetadata.getAggregatesDistinctFlags,
+          aggregateMetadata.getDistinctAccMapping,
           isStateBackedDataViews = false,
           partialResults = true,
           groupings.indices.toArray,
@@ -876,6 +888,7 @@ object AggregateUtil {
     val aggregateMetadata = extractAggregateMetadata(
       namedAggregates.map(_.getKey),
       inputType,
+      inputFieldTypeInfo.length,
       needRetract,
       tableConfig)
 
@@ -890,7 +903,9 @@ object AggregateUtil {
 
     if (doAllSupportPartialMerge(aggregateMetadata.getAggregateFunctions)) {
 
-      val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length)
+      val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length, partialResults = true)
+      val keysAndAggregatesArity = groupings.length + aggregateMetadata.getAggregateCallsCount +
+        aggregateMetadata.getDistinctAccCount
 
       // compute preaggregation type
       val preAggFieldTypes = gkeyOutMapping.map(_._2)
@@ -904,12 +919,12 @@ object AggregateUtil {
         aggregateMetadata.getAggregateFunctions,
         aggregateMetadata.getAggregateIndices,
         aggMapping,
-        aggregateMetadata.getAggregatesDistinctFlags,
+        aggregateMetadata.getDistinctAccMapping,
         isStateBackedDataViews = false,
         partialResults = true,
         groupings,
         None,
-        groupings.length + aggregateMetadata.getAggregateCallsCount,
+        keysAndAggregatesArity,
         needRetract,
         needMerge = false,
         needReset = true,
@@ -932,7 +947,7 @@ object AggregateUtil {
         aggregateMetadata.getAggregateFunctions,
         aggregateMetadata.getAggregateIndices,
         aggOutFields,
-        aggregateMetadata.getAggregatesDistinctFlags,
+        aggregateMetadata.getDistinctAccMapping,
         isStateBackedDataViews = false,
         partialResults = false,
         gkeyMapping,
@@ -957,7 +972,7 @@ object AggregateUtil {
         aggregateMetadata.getAggregateFunctions,
         aggregateMetadata.getAggregateIndices,
         aggOutFields,
-        aggregateMetadata.getAggregatesDistinctFlags,
+        aggregateMetadata.getDistinctAccMapping,
         isStateBackedDataViews = false,
         partialResults = false,
         groupings,
@@ -1046,6 +1061,7 @@ object AggregateUtil {
       extractAggregateMetadata(
         namedAggregates.map(_.getKey),
         inputType,
+        inputFieldTypeInfo.length,
         needRetract,
         tableConfig)
 
@@ -1058,7 +1074,7 @@ object AggregateUtil {
       aggregateMetadata.getAggregateFunctions,
       aggregateMetadata.getAggregateIndices,
       aggMapping,
-      aggregateMetadata.getAggregatesDistinctFlags,
+      aggregateMetadata.getDistinctAccMapping,
       isStateBackedDataViews = false,
       partialResults = false,
       groupingKeys,
@@ -1088,6 +1104,7 @@ object AggregateUtil {
     val aggregateList = extractAggregateMetadata(
       aggregateCalls,
       inputType,
+      inputType.getFieldCount,
       needRetraction = false,
       tableConfig).getAggregateFunctions
 
@@ -1173,22 +1190,30 @@ object AggregateUtil {
     * [[GeneratedAggregations]] function.
     */
   private[flink] class AggregateMetadata(
-    private val aggregates: Seq[(AggregateCallMetadata, Array[Int])]) {
+    private val aggregates: Seq[(AggregateCallMetadata, Array[Int])],
+    private val distinctAccTypesWithSpecs: Seq[(TypeInformation[_], Seq[DataViewSpec[_]])]) {
 
     def getAggregateFunctions: Array[TableAggregateFunction[_, _]] = {
       aggregates.map(_._1.aggregateFunction).toArray
     }
 
     def getAggregatesAccumulatorTypes: Array[TypeInformation[_]] = {
-      aggregates.map(_._1.accumulatorType).toArray
+      aggregates.map(_._1.accumulatorType).toArray ++ distinctAccTypesWithSpecs.map(_._1)
     }
 
     def getAggregatesAccumulatorSpecs: Array[Seq[DataViewSpec[_]]] = {
-      aggregates.map(_._1.accumulatorSpecs).toArray
+      aggregates.map(_._1.accumulatorSpecs).toArray ++ distinctAccTypesWithSpecs.map(_._2)
     }
 
-    def getAggregatesDistinctFlags: Array[Boolean] = {
-      aggregates.map(_._1.isDistinct).toArray
+    def getDistinctAccMapping: Array[(Integer, JList[Integer])] = {
+      val distinctAccMapping = mutable.Map[Integer, JList[Integer]]()
+      aggregates.map(_._1.distinctAccIndex).zipWithIndex.foreach {
+        case (distinctAccIndex, aggIndex) =>
+          distinctAccMapping
+            .getOrElseUpdate(distinctAccIndex, new JArrayList[Integer]())
+            .add(aggIndex)
+      }
+      distinctAccMapping.toArray
     }
 
     def getAggregateCallsCount: Int = {
@@ -1199,8 +1224,13 @@ object AggregateUtil {
       aggregates.map(_._2).toArray
     }
 
-    def getAdjustedMapping(offset: Int): Array[Int] = {
-      (0 until getAggregateCallsCount).map(_ + offset).toArray
+    def getAdjustedMapping(offset: Int, partialResults: Boolean = false): Array[Int] = {
+      val accCount = getAggregateCallsCount + (if (partialResults) getDistinctAccCount else 0)
+      (0 until accCount).map(_ + offset).toArray
+    }
+
+    def getDistinctAccCount: Int = {
+      getDistinctAccMapping.count(_._1 >= 0)
     }
   }
 
@@ -1212,7 +1242,7 @@ object AggregateUtil {
     aggregateFunction: TableAggregateFunction[_, _],
     accumulatorType: TypeInformation[_],
     accumulatorSpecs: Seq[DataViewSpec[_]],
-    isDistinct: Boolean
+    distinctAccIndex: Int
   )
 
   /**
@@ -1221,6 +1251,11 @@ object AggregateUtil {
     *
     * @param aggregateFunction calcite's aggregate function
     * @param isDistinct true if should be distinct aggregation
+    * @param distinctAccMap mapping of the distinct aggregate input fields index
+    *                       to the corresponding acc index
+    * @param argList indexes of the input fields of given aggregates
+    * @param aggregateCount number of aggregates
+    * @param inputFieldsCount number of input fields
     * @param aggregateInputTypes input types of given aggregate
     * @param needRetraction if the [[TableAggregateFunction]] should produce retractions
     * @param tableConfig tableConfig, required for decimal precision
@@ -1235,6 +1270,10 @@ object AggregateUtil {
   private[flink] def extractAggregateCallMetadata(
       aggregateFunction: SqlAggFunction,
       isDistinct: Boolean,
+      distinctAccMap: mutable.Map[util.Set[Integer], Integer],
+      argList: util.List[Integer],
+      aggregateCount: Int,
+      inputFieldsCount: Int,
       aggregateInputTypes: Seq[RelDataType],
       needRetraction: Boolean,
       tableConfig: TableConfig,
@@ -1258,48 +1297,36 @@ object AggregateUtil {
 
       removeStateViewFieldsFromAccTypeInfo(
         uniqueIdWithinAggregate,
-        aggregate,
+        aggregate.createAccumulator(),
         accType,
         isStateBackedDataViews)
     }
 
     // create distinct accumulator filter argument
-    val distinctAccumulatorType = if (isDistinct) {
-      createDistinctAccumulatorType(aggregateInputTypes, isStateBackedDataViews, accumulatorType)
+    val distinctAccIndex = if (isDistinct) {
+      getDistinctAccIndex(distinctAccMap, argList, aggregateCount, inputFieldsCount)
     } else {
-      accumulatorType
+      -1
     }
 
-    AggregateCallMetadata(aggregate, distinctAccumulatorType, accSpecs.getOrElse(Seq()), isDistinct)
+    AggregateCallMetadata(aggregate, accumulatorType, accSpecs.getOrElse(Seq()), distinctAccIndex)
   }
 
-  private def createDistinctAccumulatorType(
-      aggregateInputTypes: Seq[RelDataType],
-      isStateBackedDataViews: Boolean,
-      accumulatorType: TypeInformation[_])
-    : PojoTypeInfo[DistinctAccumulator[_]] = {
-    // Using Pojo fields for the real underlying accumulator
-    val pojoFields = new util.ArrayList[PojoField]()
-    pojoFields.add(new PojoField(
-      classOf[DistinctAccumulator[_]].getDeclaredField("realAcc"),
-      accumulatorType)
-    )
-    // If StateBackend is not enabled, the distinct mapping also needs
-    // to be added to the Pojo fields.
-    if (!isStateBackedDataViews) {
-
-      val argTypes: Array[TypeInformation[_]] = aggregateInputTypes
-        .map(FlinkTypeFactory.toTypeInfo).toArray
-
-      val mapViewTypeInfo = new MapViewTypeInfo(
-        new RowTypeInfo(argTypes: _*),
-        BasicTypeInfo.LONG_TYPE_INFO)
-      pojoFields.add(new PojoField(
-        classOf[DistinctAccumulator[_]].getDeclaredField("distinctValueMap"),
-        mapViewTypeInfo)
-      )
+  private def getDistinctAccIndex(
+      distinctAccMap: mutable.Map[util.Set[Integer], Integer],
+      argList: util.List[Integer],
+      aggregateCount: Int,
+      inputFieldsCount: Int): Int = {
+
+    val argListWithoutConstants = argList.toSet.filter(i => i > -1 && i < inputFieldsCount)
+    distinctAccMap.get(argListWithoutConstants) match {
+      case None =>
+        val index: Int = aggregateCount + distinctAccMap.size
+        distinctAccMap.put(argListWithoutConstants, index)
+        index
+
+      case Some(index) => index
     }
-    new PojoTypeInfo(classOf[DistinctAccumulator[_]], pojoFields)
   }
 
   /**
@@ -1308,6 +1335,7 @@ object AggregateUtil {
     *
     * @param aggregateCalls calcite's aggregate function
     * @param aggregateInputType input type of given aggregates
+    * @param inputFieldsCount number of input fields
     * @param needRetraction if the [[TableAggregateFunction]] should produce retractions
     * @param tableConfig tableConfig, required for decimal precision
     * @param isStateBackedDataViews if data should be backed by state backend
@@ -1320,11 +1348,14 @@ object AggregateUtil {
   private def extractAggregateMetadata(
       aggregateCalls: Seq[AggregateCall],
       aggregateInputType: RelDataType,
+      inputFieldsCount: Int,
       needRetraction: Boolean,
       tableConfig: TableConfig,
       isStateBackedDataViews: Boolean = false)
     : AggregateMetadata = {
 
+    val distinctAccMap = mutable.Map[util.Set[Integer], Integer]()
+
     val aggregatesWithIndices = aggregateCalls.zipWithIndex.map {
       case (aggregateCall, index) =>
         val argList: util.List[Integer] = aggregateCall.getArgList
@@ -1340,8 +1371,13 @@ object AggregateUtil {
         }
 
         val inputTypes = argList.map(aggregateInputType.getFieldList.get(_).getType)
-        val aggregateCallMetadata = extractAggregateCallMetadata(aggregateCall.getAggregation,
+        val aggregateCallMetadata = extractAggregateCallMetadata(
+          aggregateCall.getAggregation,
           aggregateCall.isDistinct,
+          distinctAccMap,
+          argList,
+          aggregateCalls.length,
+          inputFieldsCount,
           inputTypes,
           needRetraction,
           tableConfig,
@@ -1351,8 +1387,19 @@ object AggregateUtil {
         (aggregateCallMetadata, aggFieldIndices)
     }
 
+    val distinctAccType = Types.POJO(classOf[DistinctAccumulator])
+
+    val distinctAccTypesWithSpecs = (0 until distinctAccMap.size).map { idx =>
+      val (accType, accSpec) = removeStateViewFieldsFromAccTypeInfo(
+        aggregateCalls.length + idx,
+        new DistinctAccumulator(),
+        distinctAccType,
+        isStateBackedDataViews)
+      (accType, accSpec.getOrElse(Seq()))
+    }
+
     // store the aggregate fields of each aggregate function, by the same order of aggregates.
-    new AggregateMetadata(aggregatesWithIndices)
+    new AggregateMetadata(aggregatesWithIndices, distinctAccTypesWithSpecs)
   }
 
   /**
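
Note: the key new piece above is getDistinctAccIndex. Distinct accumulators are
keyed by the de-duplicated argument set (constants and -1 markers filtered
out), so DISTINCT calls over the same columns resolve to one shared slot
appended after the regular accumulators. A standalone sketch of that lookup
with a worked example:

    import scala.collection.mutable

    // The first miss for an argument set allocates the next slot after the
    // regular accumulators; later calls with the same set reuse it.
    def distinctAccIndex(
        distinctAccMap: mutable.Map[Set[Int], Int],
        argList: Seq[Int],
        aggregateCount: Int,
        inputFieldsCount: Int): Int = {
      val key = argList.toSet.filter(i => i > -1 && i < inputFieldsCount)
      // the map size is read before insertion, matching the put-then-return above
      distinctAccMap.getOrElseUpdate(key, aggregateCount + distinctAccMap.size)
    }

    // Two aggregates over a 2-field input, both DISTINCT on field 1:
    val m = mutable.Map[Set[Int], Int]()
    assert(distinctAccIndex(m, Seq(1), 2, 2) == 2) // COUNT(DISTINCT b): new slot
    assert(distinctAccIndex(m, Seq(1), 2, 2) == 2) // SUM(DISTINCT b): shared
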
diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java
index 0483c40..d22fc18 100644
--- a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java
+++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java
@@ -349,4 +349,80 @@ public class JavaUserDefinedAggFunctions {
 			isCloseCalled = true;
 		}
 	}
+
+	/**
+	 * Count accumulator.
+	 */
+	public static class MultiArgCountAcc {
+		public long count;
+	}
+
+	/**
+	 * Count aggregate function with multiple arguments.
+	 */
+	public static class MultiArgCount extends AggregateFunction<Long, MultiArgCountAcc> {
+
+		@Override
+		public MultiArgCountAcc createAccumulator() {
+			MultiArgCountAcc acc = new MultiArgCountAcc();
+			acc.count = 0L;
+			return acc;
+		}
+
+		public void accumulate(MultiArgCountAcc acc, Object in1, Object in2) {
+			if (in1 != null && in2 != null) {
+				acc.count += 1;
+			}
+		}
+
+		public void retract(MultiArgCountAcc acc, Object in1, Object in2) {
+			if (in1 != null && in2 != null) {
+				acc.count -= 1;
+			}
+		}
+
+		public void merge(MultiArgCountAcc accumulator, java.lang.Iterable<MultiArgCountAcc> iterable) {
+			for (MultiArgCountAcc otherAcc : iterable) {
+				accumulator.count += otherAcc.count;
+			}
+		}
+
+		@Override
+		public Long getValue(MultiArgCountAcc acc) {
+			return acc.count;
+		}
+	}
+
+	/**
+	 * Sum accumulator.
+	 */
+	public static class MultiArgSumAcc {
+		public long count;
+	}
+
+	/**
+	 * Sum aggregate function with multiple arguments.
+	 */
+	public static class MultiArgSum extends AggregateFunction<Long, MultiArgSumAcc> {
+
+		@Override
+		public MultiArgSumAcc createAccumulator() {
+			MultiArgSumAcc acc = new MultiArgSumAcc();
+			acc.count = 0L;
+			return acc;
+		}
+
+		public void accumulate(MultiArgSumAcc acc, long in1, long in2) {
+			acc.count += in1 + in2;
+		}
+
+		public void retract(MultiArgSumAcc acc, long in1, long in2) {
+			acc.count -= in1 + in2;
+		}
+
+		@Override
+		public Long getValue(MultiArgSumAcc acc) {
+			return acc.count;
+		}
+	}
 }
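
Note: the new MultiArgCount and MultiArgSum test functions exercise sharing
across multi-argument DISTINCT calls. A hedged usage sketch in the style of the
harness tests below (setup mirrors testDistinctAggregateWithRetract; the data
and function registration names are illustrative):

    import org.apache.flink.api.scala._
    import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
    import org.apache.flink.table.api.TableEnvironment
    import org.apache.flink.table.api.scala._
    import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{MultiArgCount, MultiArgSum}

    val env = StreamExecutionEnvironment.getExecutionEnvironment
    val tEnv = TableEnvironment.getTableEnvironment(env)
    tEnv.registerFunction("multiArgCount", new MultiArgCount)
    tEnv.registerFunction("multiArgSum", new MultiArgSum)
    val t = env.fromElements((1L, 1L, 2L), (1L, 1L, 2L)).toTable(tEnv, 'a, 'b, 'c)
    tEnv.registerTable("T", t)
    // Both DISTINCT calls use the same argument set (b, c), so the generated
    // aggregation keeps a single shared DistinctAccumulator for them.
    val result = tEnv.sqlQuery(
      "SELECT a, multiArgCount(DISTINCT b, c), multiArgSum(DISTINCT b, c) " +
      "FROM T GROUP BY a")
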
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/GroupAggregateHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/GroupAggregateHarnessTest.scala
index 1dce994..cc81161 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/GroupAggregateHarnessTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/GroupAggregateHarnessTest.scala
@@ -22,13 +22,22 @@ import java.util.concurrent.ConcurrentLinkedQueue
 
 import org.apache.flink.api.common.time.Time
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo
+import org.apache.flink.api.scala._
 import org.apache.flink.streaming.api.operators.LegacyKeyedProcessOperator
+import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.TableEnvironment
 import org.apache.flink.table.runtime.aggregate._
 import org.apache.flink.table.runtime.harness.HarnessTestBase._
 import org.apache.flink.table.runtime.types.CRow
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{MultiArgCount, MultiArgSum}
+import org.apache.flink.types.Row
+import org.junit.Assert.assertTrue
 import org.junit.Test
 
+import scala.collection.mutable
+
 class GroupAggregateHarnessTest extends HarnessTestBase {
 
   protected var queryConfig =
@@ -176,68 +185,182 @@ class GroupAggregateHarnessTest extends HarnessTestBase {
 
   @Test
   def testDistinctAggregateWithRetract(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+
+    val data = new mutable.MutableList[(JLong, JInt)]
+    val t = env.fromCollection(data).toTable(tEnv, 'a, 'b)
+    tEnv.registerTable("T", t)
+    val sqlQuery = tEnv.sqlQuery(
+      s"""
+         |SELECT
+         |  a, count(distinct b), sum(distinct b)
+         |FROM (
+         |  SELECT a, b
+         |  FROM T
+         |  GROUP BY a, b
+         |) GROUP BY a
+         |""".stripMargin)
+
+    val testHarness = createHarnessTester[String, CRow, CRow](
+      sqlQuery.toRetractStream[Row](queryConfig), "groupBy")
+
+    testHarness.setStateBackend(getStateBackend)
+    testHarness.open()
 
-    val processFunction = new LegacyKeyedProcessOperator[String, CRow, CRow](
-      new GroupAggProcessFunction(
-        genDistinctCountAggFunction,
-        distinctCountAggregationStateType,
-        true,
-        queryConfig))
+    val operator = getOperator(testHarness)
+    val fields = getGeneratedAggregationFields(
+      operator,
+      "function",
+      classOf[GroupAggProcessFunction])
 
-    val testHarness =
-      createHarnessTester(
-        processFunction,
-        new TupleRowKeySelector[String](2),
-        BasicTypeInfo.STRING_TYPE_INFO)
+    // check only one DistinctAccumulator is used
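+    // (count(distinct b) and sum(distinct b) share the distinct key b, so the
+    // planner maps both to a single shared DistinctAccumulator)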
+    assertTrue(fields.count(_.getName.endsWith("distinctValueMap_dataview")) == 1)
+
+    val expectedOutput = new ConcurrentLinkedQueue[Object]()
+
+    // register cleanup timer with 3001
+    testHarness.setProcessingTime(1)
+
+    // insert
+    testHarness.processElement(new StreamRecord(CRow(1L: JLong, 1: JInt)))
+    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong, 1: JInt)))
+    testHarness.processElement(new StreamRecord(CRow(2L: JLong, 1: JInt)))
+    expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1L: JLong, 1: JInt)))
+
+    // distinct count retract then accumulate for downstream operators
+    testHarness.processElement(new StreamRecord(CRow(2L: JLong, 1: JInt)))
+    expectedOutput.add(new StreamRecord(CRow(false, 2L: JLong, 1L: JLong, 1: JInt)))
+    expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1L: JLong, 1: JInt)))
+
+    // update count for accumulate
+    testHarness.processElement(new StreamRecord(CRow(1L: JLong, 2: JInt)))
+    expectedOutput.add(new StreamRecord(CRow(false, 1L: JLong, 1L: JLong, 1: JInt)))
+    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 2L: JLong, 3: JInt)))
+
+    // update count for retraction
+    testHarness.processElement(new StreamRecord(CRow(false, 1L: JLong, 2: JInt)))
+    expectedOutput.add(new StreamRecord(CRow(false, 1L: JLong, 2L: JLong, 3: JInt)))
+    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong, 1: JInt)))
+
+    // insert
+    testHarness.processElement(new StreamRecord(CRow(4L: JLong, 3: JInt)))
+    expectedOutput.add(new StreamRecord(CRow(4L: JLong, 1L: JLong, 3: JInt)))
+
+    // retract entirely
+    testHarness.processElement(new StreamRecord(CRow(false, 4L: JLong, 3: JInt)))
+    expectedOutput.add(new StreamRecord(CRow(false, 4L: JLong, 1L: JLong, 3: JInt)))
+
+    // trigger cleanup timer and register cleanup timer with 6002
+    testHarness.setProcessingTime(3002)
+
+    // insert
+    testHarness.processElement(new StreamRecord(CRow(1L: JLong, 1: JInt)))
+    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong, 1: JInt)))
+
+    // trigger cleanup timer and register cleanup timer with 9002
+    testHarness.setProcessingTime(6002)
+
+    // retract after cleanup
+    testHarness.processElement(new StreamRecord(CRow(false, 1L: JLong, 1: JInt)))
 
+    val result = testHarness.getOutput
+
+    verify(expectedOutput, result)
+
+    testHarness.close()
+  }
+
+  @Test
+  def testDistinctAggregateWithDifferentArgumentOrder(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+
+    val data = new mutable.MutableList[(JLong, JLong, JLong)]
+    val t = env.fromCollection(data).toTable(tEnv, 'a, 'b, 'c)
+    tEnv.registerTable("T", t)
+    tEnv.registerFunction("myCount", new MultiArgCount)
+    tEnv.registerFunction("mySum", new MultiArgSum)
+    val sqlQuery = tEnv.sqlQuery(
+      s"""
+         |SELECT
+         |  a, myCount(distinct b, c), mySum(distinct c, b)
+         |FROM (
+         |  SELECT a, b, c
+         |  FROM T
+         |  GROUP BY a, b, c
+         |) GROUP BY a
+         |""".stripMargin)
+
+    val testHarness = createHarnessTester[String, CRow, CRow](
+      sqlQuery.toRetractStream[Row](queryConfig), "groupBy")
+
+    testHarness.setStateBackend(getStateBackend)
     testHarness.open()
 
+    val operator = getOperator(testHarness)
+    val fields = getGeneratedAggregationFields(
+      operator,
+      "function",
+      classOf[GroupAggProcessFunction])
+
+    // check only one DistinctAccumulator is used
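+    // (myCount(distinct b, c) and mySum(distinct c, b) reference the same set
+    // of fields, just in a different order, so one distinct map suffices)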
+    assertTrue(fields.count(_.getName.endsWith("distinctValueMap_dataview")) == 1)
+
     val expectedOutput = new ConcurrentLinkedQueue[Object]()
 
     // register cleanup timer with 3001
     testHarness.setProcessingTime(1)
 
     // insert
-    testHarness.processElement(new StreamRecord(CRow(1L: JLong, 1: JInt, "aaa")))
-    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong)))
-    testHarness.processElement(new StreamRecord(CRow(2L: JLong, 1: JInt, "bbb")))
-    expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1L: JLong)))
+    testHarness.processElement(new StreamRecord(CRow(1L: JLong, 1L: JLong, 1L: JLong)))
+    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong, 2L: JLong)))
+    testHarness.processElement(new StreamRecord(CRow(2L: JLong, 1L: JLong, 1L: JLong)))
+    expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1L: JLong, 2L: JLong)))
 
     // distinct count retract then accumulate for downstream operators
-    testHarness.processElement(new StreamRecord(CRow(2L: JLong, 1: JInt, "bbb")))
-    expectedOutput.add(new StreamRecord(CRow(false, 2L: JLong, 1L: JLong)))
-    expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1L: JLong)))
+    testHarness.processElement(new StreamRecord(CRow(2L: JLong, 1L: JLong, 1L: JLong)))
+    expectedOutput.add(new StreamRecord(CRow(false, 2L: JLong, 1L: JLong, 2L: JLong)))
+    expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1L: JLong, 2L: JLong)))
 
     // update count for accumulate
-    testHarness.processElement(new StreamRecord(CRow(1L: JLong, 2: JInt, "aaa")))
-    expectedOutput.add(new StreamRecord(CRow(false, 1L: JLong, 1L: JLong)))
-    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 2L: JLong)))
+    testHarness.processElement(new StreamRecord(CRow(1L: JLong, 2L: JLong, 3L: JLong)))
+    expectedOutput.add(new StreamRecord(CRow(false, 1L: JLong, 1L: JLong, 2L: JLong)))
+    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 2L: JLong, 7L: JLong)))
+
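+    // duplicate row: the distinct aggregates are unchanged, but a
+    // retract/accumulate pair is still emitted for downstream operators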
+    testHarness.processElement(new StreamRecord(CRow(1L: JLong, 2L: JLong, 3L: JLong)))
+    expectedOutput.add(new StreamRecord(CRow(false, 1L: JLong, 2L: JLong, 7L: JLong)))
+    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 2L: JLong, 7L: JLong)))
 
     // update count for retraction
-    testHarness.processElement(new StreamRecord(CRow(false, 1L: JLong, 2: JInt, "aaa")))
-    expectedOutput.add(new StreamRecord(CRow(false, 1L: JLong, 2L: JLong)))
-    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong)))
+    testHarness.processElement(new StreamRecord(CRow(false, 1L: JLong, 2L: JLong, 3L: JLong)))
+    expectedOutput.add(new StreamRecord(CRow(false, 1L: JLong, 2L: JLong, 7L: JLong)))
+    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 2L: JLong, 7L: JLong)))
+
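+    // second retraction removes the last copy of (b=2, c=3), so count and sum
+    // fall back to the remaining distinct row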
+    testHarness.processElement(new StreamRecord(CRow(false, 1L: JLong, 2L: JLong, 3L: JLong)))
+    expectedOutput.add(new StreamRecord(CRow(false, 1L: JLong, 2L: JLong, 7L: JLong)))
+    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong, 2L: JLong)))
 
     // insert
-    testHarness.processElement(new StreamRecord(CRow(4L: JLong, 3: JInt, "ccc")))
-    expectedOutput.add(new StreamRecord(CRow(4L: JLong, 1L: JLong)))
+    testHarness.processElement(new StreamRecord(CRow(4L: JLong, 3L: JLong, 3L: JLong)))
+    expectedOutput.add(new StreamRecord(CRow(4L: JLong, 1L: JLong, 6L: JLong)))
 
     // retract entirely
-    testHarness.processElement(new StreamRecord(CRow(false, 4L: JLong, 3: JInt, "ccc")))
-    expectedOutput.add(new StreamRecord(CRow(false, 4L: JLong, 1L: JLong)))
+    testHarness.processElement(new StreamRecord(CRow(false, 4L: JLong, 3L: JLong, 3L: JLong)))
+    expectedOutput.add(new StreamRecord(CRow(false, 4L: JLong, 1L: JLong, 6L: JLong)))
 
     // trigger cleanup timer and register cleanup timer with 6002
     testHarness.setProcessingTime(3002)
 
     // insert
-    testHarness.processElement(new StreamRecord(CRow(1L: JLong, 1: JInt, "aaa")))
-    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong)))
+    testHarness.processElement(new StreamRecord(CRow(1L: JLong, 1L: JLong, 2L: JLong)))
+    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong, 3L: JLong)))
 
     // trigger cleanup timer and register cleanup timer with 9002
     testHarness.setProcessingTime(6002)
 
     // retract after cleanup
-    testHarness.processElement(new StreamRecord(CRow(false, 1L: JLong, 1: JInt, "aaa")))
+    testHarness.processElement(new StreamRecord(CRow(false, 1L: JLong, 1L: JLong, 2L: JLong)))
 
     val result = testHarness.getOutput
 
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
index c37fd0c..cf7dc07 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
@@ -17,9 +17,9 @@
  */
 package org.apache.flink.table.runtime.harness
 
+import java.lang.reflect.Field
 import java.util.{Comparator, Queue => JQueue}
 
-import org.apache.flink.api.common.state.{MapStateDescriptor, StateDescriptor}
 import org.apache.flink.api.common.time.Time
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{LONG_TYPE_INFO, STRING_TYPE_INFO}
 import org.apache.flink.api.common.typeinfo.TypeInformation
@@ -32,7 +32,7 @@ import org.apache.flink.streaming.api.watermark.Watermark
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
 import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, OneInputStreamOperatorTestHarness, TestHarnessUtil}
 import org.apache.flink.table.api.dataview.DataView
-import org.apache.flink.table.api.{StreamQueryConfig, Types}
+import org.apache.flink.table.api.StreamQueryConfig
 import org.apache.flink.table.codegen.GeneratedAggregationsFunction
 import org.apache.flink.table.functions.aggfunctions.{CountAggFunction, IntSumWithRetractAggFunction, LongMaxWithRetractAggFunction, LongMinWithRetractAggFunction}
 import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.getAccumulatorTypeOfAggregateFunction
@@ -83,15 +83,8 @@ class HarnessTestBase extends StreamingWithStateTestBase {
   protected val sumAggregationStateType: RowTypeInfo =
     new RowTypeInfo(sumAggregates.map(getAccumulatorTypeOfAggregateFunction(_)): _*)
 
-  protected val distinctCountAggregationStateType: RowTypeInfo =
-    new RowTypeInfo(distinctCountAggregates.map(getAccumulatorTypeOfAggregateFunction(_)): _*)
-
-  protected val distinctCountDescriptor: String = EncodingUtils.encodeObjectToString(
-    new MapStateDescriptor("distinctAgg0", distinctCountAggregationStateType, Types.LONG))
-
   protected val minMaxFuncName = "MinMaxAggregateHelper"
   protected val sumFuncName = "SumAggregationHelper"
-  protected val distinctCountFuncName = "DistinctCountAggregationHelper"
 
   val minMaxCode: String =
     s"""
@@ -326,170 +319,8 @@ class HarnessTestBase extends StreamingWithStateTestBase {
       |}
       |""".stripMargin
 
-    val distinctCountAggCode: String =
-    s"""
-      |public final class $distinctCountFuncName
-      |  extends org.apache.flink.table.runtime.aggregate.GeneratedAggregations {
-      |
-      |  final org.apache.flink.table.functions.aggfunctions.CountAggFunction count;
-      |
-      |  final org.apache.flink.table.api.dataview.MapView acc0_distinctValueMap_dataview;
-      |
-      |  final java.lang.reflect.Field distinctValueMap =
-      |      org.apache.flink.api.java.typeutils.TypeExtractor.getDeclaredField(
-      |        org.apache.flink.table.functions.aggfunctions.DistinctAccumulator.class,
-      |        "distinctValueMap");
-      |
-      |
-      |  private final org.apache.flink.table.runtime.aggregate.SingleElementIterable<org.apache
-      |    .flink.table.functions.aggfunctions.CountAccumulator> accIt0 =
-      |      new org.apache.flink.table.runtime.aggregate.SingleElementIterable<org.apache.flink
-      |      .table
-      |      .functions.aggfunctions.CountAccumulator>();
-      |
-      |  public $distinctCountFuncName() throws Exception {
-      |
-      |    count = (org.apache.flink.table.functions.aggfunctions.CountAggFunction)
-      |      ${classOf[EncodingUtils].getCanonicalName}.decodeStringToObject(
-      |      "$distinctCountAggFunction",
-      |      ${classOf[UserDefinedFunction].getCanonicalName}.class);
-      |
-      |    distinctValueMap.setAccessible(true);
-      |  }
-      |
-      |  public void open(org.apache.flink.api.common.functions.RuntimeContext ctx) {
-      |    org.apache.flink.api.common.state.StateDescriptor acc0_distinctValueMap_dataview_desc =
-      |      (org.apache.flink.api.common.state.StateDescriptor)
-      |      ${classOf[EncodingUtils].getCanonicalName}.decodeStringToObject(
-      |      "$distinctCountDescriptor",
-      |      ${classOf[StateDescriptor[_, _]].getCanonicalName}.class,
-      |      ctx.getUserCodeClassLoader());
-      |    acc0_distinctValueMap_dataview = new org.apache.flink.table.dataview.StateMapView(
-      |          ctx.getMapState((org.apache.flink.api.common.state.MapStateDescriptor)
-      |          acc0_distinctValueMap_dataview_desc));
-      |  }
-      |
-      |  public final void setAggregationResults(
-      |    org.apache.flink.types.Row accs,
-      |    org.apache.flink.types.Row output) {
-      |
-      |    org.apache.flink.table.functions.AggregateFunction baseClass0 =
-      |      (org.apache.flink.table.functions.AggregateFunction)
-      |      count;
-      |
-      |    org.apache.flink.table.functions.aggfunctions.DistinctAccumulator distinctAcc0 =
-      |      (org.apache.flink.table.functions.aggfunctions.DistinctAccumulator) accs.getField(0);
-      |    org.apache.flink.table.functions.aggfunctions.CountAccumulator acc0 =
-      |      (org.apache.flink.table.functions.aggfunctions.CountAccumulator)
-      |      distinctAcc0.getRealAcc();
-      |
-      |    output.setField(1, baseClass0.getValue(acc0));
-      |  }
-      |
-      |  public final void accumulate(
-      |    org.apache.flink.types.Row accs,
-      |    org.apache.flink.types.Row input) throws Exception {
-      |
-      |    org.apache.flink.table.functions.aggfunctions.DistinctAccumulator distinctAcc0 =
-      |      (org.apache.flink.table.functions.aggfunctions.DistinctAccumulator) accs.getField(0);
-      |
-      |    distinctValueMap.set(distinctAcc0, acc0_distinctValueMap_dataview);
-      |
-      |    if (distinctAcc0.add(
-      |          org.apache.flink.types.Row.of((java.lang.Integer) input.getField(1)))) {
-      |        org.apache.flink.table.functions.aggfunctions.CountAccumulator acc0 =
-      |          (org.apache.flink.table.functions.aggfunctions.CountAccumulator)
-      |          distinctAcc0.getRealAcc();
-      |
-      |
-      |        count.accumulate(acc0, (java.lang.Integer) input.getField(1));
-      |    }
-      |  }
-      |
-      |  public final void retract(
-      |    org.apache.flink.types.Row accs,
-      |    org.apache.flink.types.Row input) throws Exception {
-      |
-      |    org.apache.flink.table.functions.aggfunctions.DistinctAccumulator distinctAcc0 =
-      |      (org.apache.flink.table.functions.aggfunctions.DistinctAccumulator) accs.getField(0);
-      |
-      |    distinctValueMap.set(distinctAcc0, acc0_distinctValueMap_dataview);
-      |
-      |    if (distinctAcc0.remove(
-      |          org.apache.flink.types.Row.of((java.lang.Integer) input.getField(1)))) {
-      |        org.apache.flink.table.functions.aggfunctions.CountAccumulator acc0 =
-      |          (org.apache.flink.table.functions.aggfunctions.CountAccumulator)
-      |            distinctAcc0.getRealAcc();
-      |
-      |        count.retract(acc0 , (java.lang.Integer) input.getField(1));
-      |      }
-      |  }
-      |
-      |  public final org.apache.flink.types.Row createAccumulators()
-      |     {
-      |
-      |      org.apache.flink.types.Row accs = new org.apache.flink.types.Row(1);
-      |
-      |      org.apache.flink.table.functions.aggfunctions.CountAccumulator acc0 =
-      |        (org.apache.flink.table.functions.aggfunctions.CountAccumulator)
-      |        count.createAccumulator();
-      |      org.apache.flink.table.functions.aggfunctions.DistinctAccumulator distinctAcc0 =
-      |        (org.apache.flink.table.functions.aggfunctions.DistinctAccumulator)
-      |        new org.apache.flink.table.functions.aggfunctions.DistinctAccumulator (acc0);
-      |      accs.setField(
-      |        0,
-      |        distinctAcc0);
-      |
-      |        return accs;
-      |  }
-      |
-      |  public final void setForwardedFields(
-      |    org.apache.flink.types.Row input,
-      |    org.apache.flink.types.Row output)
-      |     {
-      |
-      |    output.setField(
-      |      0,
-      |      input.getField(0));
-      |  }
-      |
-      |  public final void setConstantFlags(org.apache.flink.types.Row output)
-      |     {
-      |
-      |  }
-      |
-      |  public final org.apache.flink.types.Row createOutputRow() {
-      |    return new org.apache.flink.types.Row(2);
-      |  }
-      |
-      |
-      |  public final org.apache.flink.types.Row mergeAccumulatorsPair(
-      |    org.apache.flink.types.Row a,
-      |    org.apache.flink.types.Row b)
-      |            {
-      |
-      |      return a;
-      |
-      |  }
-      |
-      |  public final void resetAccumulator(
-      |    org.apache.flink.types.Row accs) {
-      |  }
-      |
-      |  public void cleanup() {
-      |    acc0_distinctValueMap_dataview.clear();
-      |  }
-      |
-      |  public void close() {
-      |  }
-      |}
-      |""".stripMargin
-
   protected val genMinMaxAggFunction = GeneratedAggregationsFunction(minMaxFuncName, minMaxCode)
   protected val genSumAggFunction = GeneratedAggregationsFunction(sumFuncName, sumAggCode)
-  protected val genDistinctCountAggFunction = GeneratedAggregationsFunction(
-    distinctCountFuncName,
-    distinctCountAggCode)
 
   def createHarnessTester[KEY, IN, OUT](
       dataStream: DataStream[_],
@@ -553,6 +384,20 @@ class HarnessTestBase extends StreamingWithStateTestBase {
     stateField.get(generatedAggregation).asInstanceOf[DataView]
   }
 
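+  /**
+    * Returns all declared fields of the code-generated [[GeneratedAggregations]]
+    * instance held by the given operator's user function, made accessible so
+    * that tests can inspect the generated state views.
+    */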
+  def getGeneratedAggregationFields(
+      operator: AbstractUdfStreamOperator[_, _],
+      funcName: String,
+      funcClass: Class[_]): Array[Field] = {
+    val function = funcClass.getDeclaredField(funcName)
+    function.setAccessible(true)
+    val generatedAggregation =
+      function.get(operator.getUserFunction).asInstanceOf[GeneratedAggregations]
+    val cls = generatedAggregation.getClass
+    val fields = cls.getDeclaredFields
+    fields.foreach(_.setAccessible(true))
+    fields
+  }
+
   def createHarnessTester[IN, OUT, KEY](
     operator: OneInputStreamOperator[IN, OUT],
     keySelector: KeySelector[IN, KEY],
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala
index d152804..1d2dd5e 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala
@@ -31,6 +31,7 @@ import org.apache.flink.table.api.scala._
 import org.apache.flink.table.runtime.utils.TimeTestUtil.EventTimeSourceFunction
 import org.apache.flink.table.api.{StreamQueryConfig, TableEnvironment}
 import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.MultiArgCount
 import org.apache.flink.table.runtime.utils.{StreamITCase, StreamTestData, StreamingWithStateTestBase}
 import org.apache.flink.types.Row
 import org.junit.Assert._
@@ -911,6 +912,7 @@ class OverWindowITCase extends StreamingWithStateTestBase {
     val t = StreamTestData.get5TupleDataStream(env)
       .toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime.proctime)
     tEnv.registerTable("MyTable", t)
+    tEnv.registerFunction("myCount", new MultiArgCount)
 
     val sqlQuery = "SELECT a, " +
       "  COUNT(e) OVER (" +
@@ -918,6 +920,10 @@ class OverWindowITCase extends StreamingWithStateTestBase {
       "  SUM(DISTINCT e) OVER (" +
       "    PARTITION BY a ORDER BY proctime RANGE UNBOUNDED preceding), " +
       "  MIN(DISTINCT e) OVER (" +
+      "    PARTITION BY a ORDER BY proctime RANGE UNBOUNDED preceding), " +
+      "  myCount(DISTINCT e, 1) OVER (" +
+      "    PARTITION BY a ORDER BY proctime RANGE UNBOUNDED preceding), " +
+      "  myCount(DISTINCT 1, e) OVER (" +
       "    PARTITION BY a ORDER BY proctime RANGE UNBOUNDED preceding) " +
       "FROM MyTable"
 
@@ -926,21 +932,21 @@ class OverWindowITCase extends StreamingWithStateTestBase {
     env.execute()
 
     val expected = List(
-      "1,1,1,1",
-      "2,1,2,2",
-      "2,2,3,1",
-      "3,1,2,2",
-      "3,2,2,2",
-      "3,3,5,2",
-      "4,1,2,2",
-      "4,2,3,1",
-      "4,3,3,1",
-      "4,4,3,1",
-      "5,1,1,1",
-      "5,2,4,1",
-      "5,3,4,1",
-      "5,4,6,1",
-      "5,5,6,1")
+      "1,1,1,1,1,1",
+      "2,1,2,2,1,1",
+      "2,2,3,1,2,2",
+      "3,1,2,2,1,1",
+      "3,2,2,2,1,1",
+      "3,3,5,2,2,2",
+      "4,1,2,2,1,1",
+      "4,2,3,1,2,2",
+      "4,3,3,1,2,2",
+      "4,4,3,1,2,2",
+      "5,1,1,1,1,1",
+      "5,2,4,1,2,2",
+      "5,3,4,1,2,2",
+      "5,4,6,1,3,3",
+      "5,5,6,1,3,3")
     assertEquals(expected, StreamITCase.testResults)
   }
 
@@ -973,6 +979,7 @@ class OverWindowITCase extends StreamingWithStateTestBase {
 
     tEnv.registerTable("MyTable", table)
     tEnv.registerFunction("CntNullNonNull", new CountNullNonNull)
+    tEnv.registerFunction("myCount", new MultiArgCount)
 
     val sqlQuery = "SELECT " +
       "  c, " +
@@ -980,6 +987,10 @@ class OverWindowITCase extends StreamingWithStateTestBase {
       "  COUNT(DISTINCT c) " +
       "    OVER (PARTITION BY b ORDER BY rtime RANGE UNBOUNDED preceding), " +
       "  CntNullNonNull(DISTINCT c) " +
+      "    OVER (PARTITION BY b ORDER BY rtime RANGE UNBOUNDED preceding), " +
+      "  myCount(DISTINCT c, 1) " +
+      "    OVER (PARTITION BY b ORDER BY rtime RANGE UNBOUNDED preceding), " +
+      "  myCount(DISTINCT 1, c) " +
       "    OVER (PARTITION BY b ORDER BY rtime RANGE UNBOUNDED preceding)" +
       "FROM " +
       "  MyTable"
@@ -989,9 +1000,9 @@ class OverWindowITCase extends StreamingWithStateTestBase {
     env.execute()
 
     val expected = List(
-      "null,1,0,0|1", "null,1,0,0|1", "null,2,0,0|1", "null,1,2,2|1",
-      "Hello,1,1,1|1", "Hello,1,1,1|1", "Hello,2,1,1|1",
-      "Hello World,1,2,2|1", "Hello World,2,2,2|1", "Hello World,2,2,2|1")
+      "null,1,0,0|1,0,0", "null,1,0,0|1,0,0", "null,2,0,0|1,0,0", "null,1,2,2|1,2,2",
+      "Hello,1,1,1|1,1,1", "Hello,1,1,1|1,1,1", "Hello,2,1,1|1,1,1",
+      "Hello World,1,2,2|1,2,2", "Hello World,2,2,2|1,2,2", "Hello World,2,2,2|1,2,2")
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
 
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
index ddc2a68..818ba65 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
@@ -30,6 +30,7 @@ import org.apache.flink.table.api.{TableEnvironment, Types}
 import org.apache.flink.table.descriptors.{Rowtime, Schema}
 import org.apache.flink.table.expressions.utils.Func15
 import org.apache.flink.table.runtime.stream.sql.SqlITCase.TimestampAndWatermarkWithOffset
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.MultiArgCount
 import org.apache.flink.table.runtime.utils.TimeTestUtil.EventTimeSourceFunction
 import org.apache.flink.table.runtime.utils.{JavaUserDefinedTableFunctions, StreamITCase, StreamTestData, StreamingWithStateTestBase}
 import org.apache.flink.table.utils.{InMemoryTableFactory, MemoryTableSourceSinkUtil}
@@ -70,15 +71,19 @@ class SqlITCase extends StreamingWithStateTestBase {
     StreamITCase.clear
     val stream = env
       .fromCollection(sessionWindowTestData)
-      .assignTimestampsAndWatermarks(new TimestampAndWatermarkWithOffset[(Long, Int, String)](10L))
+      .assignTimestampsAndWatermarks(
+        new TimestampAndWatermarkWithOffset[(Long, Int, String)](10L))
 
     val tEnv = TableEnvironment.getTableEnvironment(env)
     val table = stream.toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)
     tEnv.registerTable("MyTable", table)
+    tEnv.registerFunction("myCount", new MultiArgCount)
 
     val sqlQuery = "SELECT c, " +
       "  COUNT(DISTINCT b)," +
       "  SUM(DISTINCT b)," +
+      "  myCount(DISTINCT b, 1)," +
+      "  myCount(DISTINCT 1, b)," +
       "  SESSION_END(rowtime, INTERVAL '0.005' SECOND) " +
       "FROM MyTable " +
       "GROUP BY SESSION(rowtime, INTERVAL '0.005' SECOND), c "
@@ -88,10 +93,10 @@ class SqlITCase extends StreamingWithStateTestBase {
     env.execute()
 
     val expected = Seq(
-      "Hello World,1,9,1970-01-01 00:00:00.014", // window starts at [9L] till {14L}
-      "Hello,1,16,1970-01-01 00:00:00.021",       // window starts at [16L] till {21L}, not merged
-      "Hello,3,6,1970-01-01 00:00:00.015"        // window starts at [1L,2L],
-                                               //   merged with [8L,10L], by [4L], till {15L}
+      "Hello World,1,9,1,1,1970-01-01 00:00:00.014", // window starts at [9L] till {14L}
+      "Hello,1,16,1,1,1970-01-01 00:00:00.021",    // window starts at [16L] till {21L}, not merged
+      "Hello,3,6,3,3,1970-01-01 00:00:00.015"      // window starts at [1L,2L],
+                                                   // merged with [8L,10L], by [4L], till {15L}
     )
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
index 2e9dac5..2d1666f 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
@@ -88,15 +88,30 @@ class AggregateITCase extends StreamingWithStateTestBase {
     val tEnv = TableEnvironment.getTableEnvironment(env)
     StreamITCase.clear
 
-    val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e)
-      .groupBy('e)
-      .select('e, 'a.count.distinct)
+    val data = new mutable.MutableList[(Int, Int, String)]
+    data.+=((1, 1, "A"))
+    data.+=((2, 2, "B"))
+    data.+=((2, 2, "B"))
+    data.+=((4, 3, "C"))
+    data.+=((5, 3, "C"))
+    data.+=((4, 3, "C"))
+    data.+=((7, 3, "B"))
+    data.+=((1, 4, "A"))
+    data.+=((9, 4, "D"))
+    data.+=((4, 1, "A"))
+    data.+=((3, 2, "B"))
+
+    val testAgg = new WeightedAvg
+    val t = env.fromCollection(data).toTable(tEnv, 'a, 'b, 'c)
+      .groupBy('c)
+      .select('c, 'a.count.distinct, 'a.sum.distinct,
+              testAgg.distinct('a, 'b), testAgg.distinct('b, 'a), testAgg('a, 'b))
 
     val results = t.toRetractStream[Row](queryConfig)
     results.addSink(new StreamITCase.RetractingSink).setParallelism(1)
     env.execute()
 
-    val expected = mutable.MutableList("1,4", "2,4", "3,2")
+    val expected = mutable.MutableList("A,2,5,1,1,1", "B,3,12,4,2,3", "C,2,9,4,3,4", "D,1,9,9,4,9")
     assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
   }