You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by go...@apache.org on 2023/01/31 10:27:59 UTC

[flink] branch master updated: [FLINK-30542][table-planner] Introduce adaptive local hash aggregate to adaptively determine whether local hash aggregate is required at runtime

This is an automated email from the ASF dual-hosted git repository.

godfrey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 122ba8f319b [FLINK-30542][table-planner] Introduce adaptive local hash aggregate to adaptively determine whether local hash aggregate is required at runtime
122ba8f319b is described below

commit 122ba8f319b0d68374abba08d676e6dfa82cc114
Author: zhengyunhong.zyh <33...@qq.com>
AuthorDate: Wed Jan 11 21:57:02 2023 +0800

    [FLINK-30542][table-planner] Introduce adaptive local hash aggregate to adaptively determine whether local hash aggregate is required at runtime
    
    This closes #21586
---
 .../nodes/exec/batch/BatchExecHashAggregate.java   |  25 +-
 .../planner/codegen/ProjectionCodeGenerator.scala  | 166 ++++++++++-
 .../codegen/agg/batch/HashAggCodeGenerator.scala   | 205 +++++++++++--
 .../batch/BatchPhysicalHashAggregate.scala         |   1 +
 .../batch/BatchPhysicalLocalHashAggregate.scala    |   3 +
 .../physical/batch/BatchPhysicalAggRuleBase.scala  |   4 +-
 .../physical/batch/BatchPhysicalHashAggRule.scala  |  10 +-
 .../physical/batch/BatchPhysicalJoinRuleBase.scala |   1 +
 .../physical/batch/BatchPhysicalSortAggRule.scala  |   4 +-
 .../physical/batch/EnforceLocalAggRuleBase.scala   |   3 +-
 .../table/planner/plan/utils/AggregateUtil.scala   |  20 ++
 .../agg/batch/HashAggCodeGeneratorTest.scala       |   8 +-
 .../plan/metadata/FlinkRelMdHandlerTestBase.scala  |   3 +
 .../batch/sql/agg/AggregateITCaseBase.scala        |  15 +-
 .../runtime/batch/sql/agg/HashAggITCase.scala      | 316 ++++++++++++++++++++-
 15 files changed, 733 insertions(+), 51 deletions(-)

diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecHashAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecHashAggregate.java
index 02ed8a737da..53ebbf2132c 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecHashAggregate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecHashAggregate.java
@@ -58,6 +58,7 @@ public class BatchExecHashAggregate extends ExecNodeBase<RowData>
     private final RowType aggInputRowType;
     private final boolean isMerge;
     private final boolean isFinal;
+    private final boolean supportAdaptiveLocalHashAgg;
 
     public BatchExecHashAggregate(
             ReadableConfig tableConfig,
@@ -67,6 +68,7 @@ public class BatchExecHashAggregate extends ExecNodeBase<RowData>
             RowType aggInputRowType,
             boolean isMerge,
             boolean isFinal,
+            boolean supportAdaptiveLocalHashAgg,
             InputProperty inputProperty,
             RowType outputType,
             String description) {
@@ -83,6 +85,7 @@ public class BatchExecHashAggregate extends ExecNodeBase<RowData>
         this.aggInputRowType = aggInputRowType;
         this.isMerge = isMerge;
         this.isFinal = isFinal;
+        this.supportAdaptiveLocalHashAgg = supportAdaptiveLocalHashAgg;
     }
 
     @SuppressWarnings("unchecked")
@@ -126,17 +129,17 @@ public class BatchExecHashAggregate extends ExecNodeBase<RowData>
                     config.get(ExecutionConfigOptions.TABLE_EXEC_RESOURCE_HASH_AGG_MEMORY)
                             .getBytes();
             generatedOperator =
-                    new HashAggCodeGenerator(
-                                    ctx,
-                                    planner.createRelBuilder(),
-                                    aggInfos,
-                                    inputRowType,
-                                    outputRowType,
-                                    grouping,
-                                    auxGrouping,
-                                    isMerge,
-                                    isFinal)
-                            .genWithKeys();
+                    HashAggCodeGenerator.genWithKeys(
+                            ctx,
+                            planner.createRelBuilder(),
+                            aggInfos,
+                            inputRowType,
+                            outputRowType,
+                            grouping,
+                            auxGrouping,
+                            isMerge,
+                            isFinal,
+                            supportAdaptiveLocalHashAgg);
         }
 
         return ExecNodeUtil.createOneInputTransformation(
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProjectionCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProjectionCodeGenerator.scala
index 496af1394fe..e2c3759ddaa 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProjectionCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProjectionCodeGenerator.scala
@@ -19,11 +19,18 @@ package org.apache.flink.table.planner.codegen
 
 import org.apache.flink.table.data.RowData
 import org.apache.flink.table.data.binary.BinaryRowData
+import org.apache.flink.table.data.writer.BinaryRowWriter
 import org.apache.flink.table.planner.codegen.CodeGenUtils._
 import org.apache.flink.table.planner.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
 import org.apache.flink.table.planner.codegen.GenerateUtils.generateRecordStatement
+import org.apache.flink.table.planner.codegen.calls.ScalarOperatorGens
+import org.apache.flink.table.planner.functions.aggfunctions._
+import org.apache.flink.table.planner.plan.utils.AggregateInfo
 import org.apache.flink.table.runtime.generated.{GeneratedProjection, Projection}
-import org.apache.flink.table.types.logical.RowType
+import org.apache.flink.table.types.logical.{BigIntType, LogicalType, RowType}
+import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getFieldTypes
+
+import scala.collection.mutable.ArrayBuffer
 
 /**
  * CodeGenerator for projection, Take out some fields of [[RowData]] to generate a new [[RowData]].
@@ -124,6 +131,163 @@ object ProjectionCodeGenerator {
     new GeneratedProjection(className, code, ctx.references.toArray, ctx.tableConfig)
   }
 
+  /**
+   * If adaptive local hash aggregation takes effect, local hash aggregation will be suppressed. To
+   * ensure that the rows sent downstream when local aggregation is suppressed have the same format
+   * as the rows sent downstream when local aggregation is performed, we project each input row into
+   * the aggregate buffer layout (the values of the aggregate functions).
+   *
+   * <p> For example, for the SQL statement "select a, avg(b), count(c) from T group by a", if local
+   * hash aggregation is suppressed and a row (1, 5, "a") arrives, we pass (1, 5, 1, 1) downstream:
+   * the group key 1, the sum (5) and count (1) buffers of avg(b), and the count (1) of count(c).
+   */
+  def genAdaptiveLocalHashAggValueProjectionCode(
+      ctx: CodeGeneratorContext,
+      inputType: RowType,
+      outClass: Class[_ <: RowData] = classOf[BinaryRowData],
+      inputTerm: String = DEFAULT_INPUT1_TERM,
+      aggInfos: Array[AggregateInfo],
+      outRecordTerm: String = DEFAULT_OUT_RECORD_TERM,
+      outRecordWriterTerm: String = DEFAULT_OUT_RECORD_WRITER_TERM): String = {
+    val fieldExprs: ArrayBuffer[GeneratedExpression] = ArrayBuffer()
+    aggInfos.map {
+      aggInfo =>
+        aggInfo.function match {
+          case sumAggFunction: SumAggFunction =>
+            fieldExprs += genValueProjectionForSumAggFunc(
+              ctx,
+              inputType,
+              inputTerm,
+              sumAggFunction.getResultType.getLogicalType,
+              aggInfo.agg.getArgList.get(0))
+          case _: MaxAggFunction | _: MinAggFunction =>
+            fieldExprs += GenerateUtils.generateFieldAccess(
+              ctx,
+              inputType,
+              inputTerm,
+              aggInfo.agg.getArgList.get(0))
+          case avgAggFunction: AvgAggFunction =>
+            fieldExprs += genValueProjectionForSumAggFunc(
+              ctx,
+              inputType,
+              inputTerm,
+              avgAggFunction.getSumType.getLogicalType,
+              aggInfo.agg.getArgList.get(0))
+            fieldExprs += genValueProjectionForCountAggFunc(
+              ctx,
+              inputTerm,
+              aggInfo.agg.getArgList.get(0))
+          case _: CountAggFunction =>
+            fieldExprs += genValueProjectionForCountAggFunc(
+              ctx,
+              inputTerm,
+              aggInfo.agg.getArgList.get(0))
+          case _: Count1AggFunction =>
+            fieldExprs += genValueProjectionForCount1AggFunc(ctx)
+        }
+    }
+
+    val binaryRowWriter = CodeGenUtils.className[BinaryRowWriter]
+    val typeTerm = outClass.getCanonicalName
+    ctx.addReusableMember(s"private $typeTerm $outRecordTerm= new $typeTerm(${fieldExprs.size});")
+    ctx.addReusableMember(
+      s"private $binaryRowWriter $outRecordWriterTerm = new $binaryRowWriter($outRecordTerm);")
+
+    val fieldExprIdxToOutputRowPosMap = fieldExprs.indices.map(i => i -> i).toMap
+    val setFieldsCode = fieldExprs.zipWithIndex
+      .map {
+        case (fieldExpr, index) =>
+          val pos = fieldExprIdxToOutputRowPosMap.getOrElse(
+            index,
+            throw new CodeGenException(s"Illegal field expr index: $index"))
+          rowSetField(
+            ctx,
+            classOf[BinaryRowData],
+            outRecordTerm,
+            pos.toString,
+            fieldExpr,
+            Option(outRecordWriterTerm))
+      }
+      .mkString("\n")
+
+    val writer = outRecordWriterTerm
+    val resetWriter = s"$writer.reset();"
+    val completeWriter: String = s"$writer.complete();"
+    s"""
+       |$resetWriter
+       |$setFieldsCode
+       |$completeWriter
+        """.stripMargin
+  }
+
+  /**
+   * Do projection for aggregate function 'sum(col)' if adaptive local hash aggregation takes effect.
+   * For 'sum(col)', the projected value is cast to the sum agg function's result type; if col is
+   * null, the projected value is null (its field holds the input type's default value).
+   */
+  def genValueProjectionForSumAggFunc(
+      ctx: CodeGeneratorContext,
+      inputType: LogicalType,
+      inputTerm: String,
+      targetType: LogicalType,
+      index: Int): GeneratedExpression = {
+    val fieldType = getFieldTypes(inputType).get(index)
+    val resultTypeTerm = primitiveTypeTermForType(fieldType)
+    val defaultValue = primitiveDefaultValue(fieldType)
+    val readCode = rowFieldReadAccess(index.toString, inputTerm, fieldType)
+    val Seq(fieldTerm, nullTerm) =
+      ctx.addReusableLocalVariables((resultTypeTerm, "field"), ("boolean", "isNull"))
+
+    val inputCode =
+      s"""
+         |$nullTerm = $inputTerm.isNullAt($index);
+         |$fieldTerm = $defaultValue;
+         |if (!$nullTerm) {
+         |  $fieldTerm = $readCode;
+         |}
+           """.stripMargin.trim
+
+    val expression = GeneratedExpression(fieldTerm, nullTerm, inputCode, fieldType)
+    // Convert the projected value type to sum agg func target type.
+    ScalarOperatorGens.generateCast(ctx, expression, targetType, true)
+  }
+
+  /**
+   * Do projection for aggregate function 'count(col)' if adaptive local hash aggregation takes
+   * effect. 'count(col)' is converted to 1L if col is not null and to 0L if col is null.
+   */
+  def genValueProjectionForCountAggFunc(
+      ctx: CodeGeneratorContext,
+      inputTerm: String,
+      index: Int): GeneratedExpression = {
+    val Seq(fieldTerm, nullTerm) =
+      ctx.addReusableLocalVariables(("long", "field"), ("boolean", "isNull"))
+
+    val inputCode =
+      s"""
+         |$fieldTerm = 0L;
+         |if (!$inputTerm.isNullAt($index)) {
+         |  $fieldTerm = 1L;
+         |}
+           """.stripMargin.trim
+
+    GeneratedExpression(fieldTerm, nullTerm, inputCode, new BigIntType())
+  }
+
+  /**
+   * Do projection for aggregate function 'count(*)' or 'count(1)' if adaptive local hash agg takes
+   * effect. 'count(*)' or 'count(1)' is converted to 1L and transmitted downstream.
+   */
+  def genValueProjectionForCount1AggFunc(ctx: CodeGeneratorContext): GeneratedExpression = {
+    val Seq(fieldTerm, nullTerm) =
+      ctx.addReusableLocalVariables(("long", "field"), ("boolean", "isNull"))
+    val inputCode =
+      s"""
+         |$fieldTerm = 1L;
+         |""".stripMargin.trim
+    GeneratedExpression(fieldTerm, nullTerm, inputCode, new BigIntType())
+  }
+
   /** For java invoke. */
   def generateProjection(
       ctx: CodeGeneratorContext,
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
index 98193260222..e98ca5dd266 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
@@ -17,19 +17,24 @@
  */
 package org.apache.flink.table.planner.codegen.agg.batch
 
+import org.apache.flink.annotation.Experimental
+import org.apache.flink.configuration.ConfigOption
+import org.apache.flink.configuration.ConfigOptions.key
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator
 import org.apache.flink.table.data.{GenericRowData, RowData}
 import org.apache.flink.table.data.binary.BinaryRowData
 import org.apache.flink.table.data.utils.JoinedRowData
-import org.apache.flink.table.functions.{AggregateFunction, DeclarativeAggregateFunction}
-import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenUtils, ProjectionCodeGenerator}
-import org.apache.flink.table.planner.plan.utils.{AggregateInfo, AggregateInfoList}
+import org.apache.flink.table.functions.DeclarativeAggregateFunction
+import org.apache.flink.table.planner.{JBoolean, JDouble, JLong}
+import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenUtils, OperatorCodeGenerator, ProjectionCodeGenerator}
+import org.apache.flink.table.planner.codegen.CodeGenUtils.ROW_DATA
+import org.apache.flink.table.planner.plan.utils.AggregateInfoList
 import org.apache.flink.table.planner.typeutils.RowTypeUtils
 import org.apache.flink.table.runtime.generated.GeneratedOperator
 import org.apache.flink.table.runtime.operators.TableStreamOperator
 import org.apache.flink.table.runtime.operators.aggregate.BytesHashMapSpillMemorySegmentPool
 import org.apache.flink.table.runtime.util.collections.binary.BytesMap
-import org.apache.flink.table.types.logical.{LogicalType, RowType}
+import org.apache.flink.table.types.logical.RowType
 
 import org.apache.calcite.tools.RelBuilder
 
@@ -38,32 +43,78 @@ import org.apache.calcite.tools.RelBuilder
  * aggregateBuffers should be update(e.g.: setInt) in [[BinaryRowData]]. (Hash Aggregate performs
  * much better than Sort Aggregate).
  */
-class HashAggCodeGenerator(
-    ctx: CodeGeneratorContext,
-    builder: RelBuilder,
-    aggInfoList: AggregateInfoList,
-    inputType: RowType,
-    outputType: RowType,
-    grouping: Array[Int],
-    auxGrouping: Array[Int],
-    isMerge: Boolean,
-    isFinal: Boolean) {
+object HashAggCodeGenerator {
 
-  private lazy val aggInfos: Array[AggregateInfo] = aggInfoList.aggInfos
+  // It is an experimental config and may be removed later.
+  @Experimental
+  val TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_ENABLED: ConfigOption[JBoolean] =
+    key("table.exec.local-hash-agg.adaptive.enabled")
+      .booleanType()
+      .defaultValue(Boolean.box(true))
+      .withDescription(
+        s"""
+           |Whether to enable adaptive local hash aggregation. Adaptive local hash
+           |aggregation is an optimization of local hash aggregation, which can adaptively 
+           |determine whether to continue to do local hash aggregation according to the distinct
+           | value rate of sampling data. If distinct value rate bigger than defined threshold
+           |(see parameter: table.exec.local-hash-agg.adaptive.distinct-value-rate-threshold), 
+           |we will stop aggregating and just send the input data to the downstream after a simple 
+           |projection. Otherwise, we will continue to do aggregation. Adaptive local hash aggregation
+           |only works in batch mode. Default value of this parameter is true.
+           |""".stripMargin)
 
-  private lazy val functionIdentifiers: Map[AggregateFunction[_, _], String] =
-    AggCodeGenHelper.getFunctionIdentifiers(aggInfos)
+  @Experimental
+  val TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_SAMPLING_THRESHOLD: ConfigOption[JLong] =
+    key("table.exec.local-hash-agg.adaptive.sampling-threshold")
+      .longType()
+      .defaultValue(Long.box(5000000L))
+      .withDescription(
+        s"""
+           |If adaptive local hash aggregation is enabled, this value defines how 
+           |many records will be used as sampled data to calculate distinct value rate 
+           |(see parameter: table.exec.local-hash-agg.adaptive.distinct-value-rate-threshold) 
+           |for the local aggregate. The higher the sampling threshold, the more accurate 
+           |the distinct value rate is. But as the sampling threshold increases, local 
+           |aggregation is meaningless when the distinct values rate is low. 
+           |The default value is 5000000.
+           |""".stripMargin)
 
-  private lazy val aggBufferNames: Array[Array[String]] =
-    AggCodeGenHelper.getAggBufferNames(auxGrouping, aggInfos)
+  @Experimental
+  val TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_DISTINCT_VALUE_RATE_THRESHOLD: ConfigOption[JDouble] =
+    key("table.exec.local-hash-agg.adaptive.distinct-value-rate-threshold")
+      .doubleType()
+      .defaultValue(0.5d)
+      .withDescription(
+        s"""
+           |The distinct value rate can be defined as the number of local 
+           |aggregation result for the sampled data divided by the sampling 
+           |threshold (see ${TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_SAMPLING_THRESHOLD.key()}). 
+           |If the computed result is lower than the given configuration value, 
+           |the remaining input records proceed to do local aggregation, otherwise 
+           |the remaining input records are subjected to simple projection which 
+           |calculation cost is less than local aggregation. The default value is 0.5.
+           |""".stripMargin)
 
-  private lazy val aggBufferTypes: Array[Array[LogicalType]] =
-    AggCodeGenHelper.getAggBufferTypes(inputType, auxGrouping, aggInfos)
+  def genWithKeys(
+      ctx: CodeGeneratorContext,
+      builder: RelBuilder,
+      aggInfoList: AggregateInfoList,
+      inputType: RowType,
+      outputType: RowType,
+      grouping: Array[Int],
+      auxGrouping: Array[Int],
+      isMerge: Boolean,
+      isFinal: Boolean,
+      supportAdaptiveLocalHashAgg: Boolean)
+      : GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
 
-  private lazy val groupKeyRowType = RowTypeUtils.projectRowType(inputType, grouping)
-  private lazy val aggBufferRowType = RowType.of(aggBufferTypes.flatten, aggBufferNames.flatten)
+    val aggInfos = aggInfoList.aggInfos
+    val functionIdentifiers = AggCodeGenHelper.getFunctionIdentifiers(aggInfos)
+    val aggBufferNames = AggCodeGenHelper.getAggBufferNames(auxGrouping, aggInfos)
+    val aggBufferTypes = AggCodeGenHelper.getAggBufferTypes(inputType, auxGrouping, aggInfos)
+    val groupKeyRowType = RowTypeUtils.projectRowType(inputType, grouping)
+    val aggBufferRowType = RowType.of(aggBufferTypes.flatten, aggBufferNames.flatten)
 
-  def genWithKeys(): GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
     val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM
     val className = if (isFinal) "HashAggregateWithKeys" else "LocalHashAggregateWithKeys"
 
@@ -74,6 +125,10 @@ class HashAggCodeGenerator(
     // gen code to do group key projection from input
     val currentKeyTerm = CodeGenUtils.newName("currentKey")
     val currentKeyWriterTerm = CodeGenUtils.newName("currentKeyWriter")
+    // currentValueTerm and currentValueWriterTerm are used for value
+    // projection while supportAdaptiveLocalHashAgg is true.
+    val currentValueTerm = CodeGenUtils.newName("currentValue")
+    val currentValueWriterTerm = CodeGenUtils.newName("currentValueWriter")
     val keyProjectionCode = ProjectionCodeGenerator
       .generateProjectionExpression(
         ctx,
@@ -85,6 +140,21 @@ class HashAggCodeGenerator(
         outRecordWriterTerm = currentKeyWriterTerm)
       .code
 
+    val valueProjectionCode =
+      if (!isFinal && supportAdaptiveLocalHashAgg) {
+        ProjectionCodeGenerator.genAdaptiveLocalHashAggValueProjectionCode(
+          ctx,
+          inputType,
+          classOf[BinaryRowData],
+          inputTerm = inputTerm,
+          aggInfos,
+          outRecordTerm = currentValueTerm,
+          outRecordWriterTerm = currentValueWriterTerm
+        )
+      } else {
+        ""
+      }
+
     // gen code to create groupKey, aggBuffer Type array
     // it will be used in BytesHashMap and BufferedKVExternalSorter if enable fallback
     val groupKeyTypesTerm = CodeGenUtils.newName("groupKeyTypes")
@@ -173,6 +243,77 @@ class HashAggCodeGenerator(
 
     HashAggCodeGenHelper.prepareMetrics(ctx, aggregateMapTerm, if (isFinal) sorterTerm else null)
 
+    // Do adaptive hash aggregation
+    val outputResultForAdaptiveLocalHashAgg = {
+      // gen code to output the projected key and value of the current input row to downstream
+      val inputUnboxingCode = s"${ctx.reuseInputUnboxingCode(reuseAggBufferTerm)}"
+      s"""
+         |   // set result and output
+         |   $reuseGroupKeyTerm =  ($ROW_DATA)$currentKeyTerm;
+         |   $reuseAggBufferTerm = ($ROW_DATA)$currentValueTerm;
+         |   $inputUnboxingCode
+         |   ${outputExpr.code}
+         |   ${OperatorCodeGenerator.generateCollect(outputExpr.resultTerm)}
+         |
+       """.stripMargin
+    }
+    val localAggSuppressedTerm = CodeGenUtils.newName("localAggSuppressed")
+    ctx.addReusableMember(s"private transient boolean $localAggSuppressedTerm = false;")
+    val (
+      distinctCountIncCode,
+      totalCountIncCode,
+      adaptiveSamplingCode,
+      adaptiveLocalHashAggCode,
+      flushResultSuppressEnableCode) = {
+      // only generate adaptive code for the local phase, when enabled and all agg calls support it
+      if (
+        !isFinal &&
+        ctx.tableConfig.get(TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_ENABLED) &&
+        supportAdaptiveLocalHashAgg
+      ) {
+        val adaptiveDistinctCountTerm = CodeGenUtils.newName("distinctCount")
+        val adaptiveTotalCountTerm = CodeGenUtils.newName("totalCount")
+        ctx.addReusableMember(s"private transient long $adaptiveDistinctCountTerm = 0;")
+        ctx.addReusableMember(s"private transient long $adaptiveTotalCountTerm = 0;")
+
+        val samplingThreshold =
+          ctx.tableConfig.get(TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_SAMPLING_THRESHOLD)
+        val distinctValueRateThreshold =
+          ctx.tableConfig.get(TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_DISTINCT_VALUE_RATE_THRESHOLD)
+
+        (
+          s"$adaptiveDistinctCountTerm++;",
+          s"$adaptiveTotalCountTerm++;",
+          s"""
+             |if ($adaptiveTotalCountTerm == $samplingThreshold) {
+             |  $logTerm.info("Local hash aggregation checkpoint reached, sampling threshold = " +
+             |    $samplingThreshold + ", distinct value count = " + $adaptiveDistinctCountTerm + ", total = " +
+             |    $adaptiveTotalCountTerm + ", distinct value rate threshold = " 
+             |    + $distinctValueRateThreshold);
+             |  if ($adaptiveDistinctCountTerm / (1.0 * $adaptiveTotalCountTerm) > $distinctValueRateThreshold) {
+             |    $logTerm.info("Local hash aggregation is suppressed");
+             |    $localAggSuppressedTerm = true;
+             |  }
+             |}
+             |""".stripMargin,
+          s"""
+             |if ($localAggSuppressedTerm) {
+             |  $valueProjectionCode
+             |  $outputResultForAdaptiveLocalHashAgg
+             |  return;
+             |}
+             |""".stripMargin,
+          s"""
+             |if ($localAggSuppressedTerm) {
+             |  $outputResultFromMap
+             |  return;
+             |}
+             |""".stripMargin)
+      } else {
+        ("", "", "", "", "")
+      }
+    }
+
     val lazyInitAggBufferCode = if (auxGrouping.nonEmpty) {
       s"""
          |// lazy init agg buffer (with auxGrouping)
@@ -188,11 +329,15 @@ class HashAggCodeGenerator(
          |${ctx.reuseInputUnboxingCode(inputTerm)}
          | // project key from input
          |$keyProjectionCode
+         |
+         |$adaptiveLocalHashAggCode
+         |
          | // look up output buffer using current group key
          |$lookupInfo = ($lookupInfoTypeTerm) $aggregateMapTerm.lookup($currentKeyTerm);
          |$currentAggBufferTerm = ($binaryRowTypeTerm) $lookupInfo.getValue();
          |
          |if (!$lookupInfo.isFound()) {
+         |  $distinctCountIncCode
          |  $lazyInitAggBufferCode
          |  // append empty agg buffer into aggregate map for current group key
          |  try {
@@ -202,10 +347,16 @@ class HashAggCodeGenerator(
          |    $dealWithAggHashMapOOM
          |  }
          |}
+         |
+         |$totalCountIncCode
+         |$adaptiveSamplingCode
+         |
          | // aggregate buffer fields access
          |${ctx.reuseInputUnboxingCode(currentAggBufferTerm)}
          | // do aggregate and update agg buffer
          |${aggregate.code}
+         | // flush result form map if suppress is enable. 
+         |$flushResultSuppressEnableCode
          |""".stripMargin.trim
 
     val endInputCode = if (isFinal) {
@@ -227,7 +378,11 @@ class HashAggCodeGenerator(
          |}
        """.stripMargin
     } else {
-      s"$outputResultFromMap"
+      s"""
+         |if (!$localAggSuppressedTerm) {
+         | $outputResultFromMap
+         |}
+         |""".stripMargin
     }
 
     AggCodeGenHelper.generateOperator(
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalHashAggregate.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalHashAggregate.scala
index 95cdc4800a2..e22cf17d735 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalHashAggregate.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalHashAggregate.scala
@@ -161,6 +161,7 @@ class BatchPhysicalHashAggregate(
       FlinkTypeFactory.toLogicalRowType(aggInputRowType),
       isMerge,
       true, // isFinal is always true
+      false, // supportAdaptiveLocalHashAgg is always false
       InputProperty
         .builder()
         .requiredDistribution(requiredDistribution)
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalLocalHashAggregate.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalLocalHashAggregate.scala
index b0c0ea8c63a..6963440ef3a 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalLocalHashAggregate.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalLocalHashAggregate.scala
@@ -50,6 +50,7 @@ class BatchPhysicalLocalHashAggregate(
     inputRowType: RelDataType,
     grouping: Array[Int],
     auxGrouping: Array[Int],
+    val supportAdaptiveLocalHashAgg: Boolean,
     aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)])
   extends BatchPhysicalHashAggregateBase(
     cluster,
@@ -71,6 +72,7 @@ class BatchPhysicalLocalHashAggregate(
       inputRowType,
       grouping,
       auxGrouping,
+      supportAdaptiveLocalHashAgg,
       aggCallToAggFunction)
   }
 
@@ -134,6 +136,7 @@ class BatchPhysicalLocalHashAggregate(
       FlinkTypeFactory.toLogicalRowType(inputRowType),
       false, // isMerge is always false
       false, // isFinal is always false
+      supportAdaptiveLocalHashAgg,
       getInputProperty,
       FlinkTypeFactory.toLogicalRowType(getRowType),
       getRelDetailedDescription)
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalAggRuleBase.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalAggRuleBase.scala
index 7b4de5b5f60..412046d2400 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalAggRuleBase.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalAggRuleBase.scala
@@ -203,7 +203,8 @@ trait BatchPhysicalAggRuleBase {
       auxGrouping: Array[Int],
       aggBufferTypes: Array[Array[DataType]],
       aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
-      isLocalHashAgg: Boolean): BatchPhysicalGroupAggregateBase = {
+      isLocalHashAgg: Boolean,
+      supportAdaptiveLocalHashAgg: Boolean): BatchPhysicalGroupAggregateBase = {
     val inputRowType = input.getRowType
     val aggFunctions = aggCallToAggFunction.map(_._2).toArray
 
@@ -231,6 +232,7 @@ trait BatchPhysicalAggRuleBase {
         inputRowType,
         grouping,
         auxGrouping,
+        supportAdaptiveLocalHashAgg,
         aggCallToAggFunction)
     } else {
       new BatchPhysicalLocalSortAggregate(
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalHashAggRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalHashAggRule.scala
index f208cdfbed6..f6ad3a06f89 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalHashAggRule.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalHashAggRule.scala
@@ -94,6 +94,12 @@ class BatchPhysicalHashAggRule
     val aggCallToAggFunction = aggCallsWithoutAuxGroupCalls.zip(aggFunctions)
     val aggProvidedTraitSet = agg.getTraitSet.replace(FlinkConventions.BATCH_PHYSICAL)
 
+    // Determine whether this agg operator supports adaptive local hash agg.
+    // It is supported only if every agg function in the operator can be
+    // computed by a simple value projection.
+    val supportAdaptiveLocalHashAgg =
+      AggregateUtil.doAllAggSupportAdaptiveLocalHashAgg(aggCallToAggFunction.map(_._1))
+
     // create two-phase agg if possible
     if (isTwoPhaseAggWorkable(aggFunctions, tableConfig)) {
       // create BatchPhysicalLocalHashAggregate
@@ -109,7 +115,9 @@ class BatchPhysicalHashAggRule
         auxGroupSet,
         aggBufferTypes,
         aggCallToAggFunction,
-        isLocalHashAgg = true)
+        isLocalHashAgg = true,
+        supportAdaptiveLocalHashAgg
+      )
 
       // create global BatchPhysicalHashAggregate
       val (globalGroupSet, globalAuxGroupSet) = getGlobalAggGroupSetPair(groupSet, auxGroupSet)
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalJoinRuleBase.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalJoinRuleBase.scala
index 94cf8884a67..66540df1989 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalJoinRuleBase.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalJoinRuleBase.scala
@@ -70,6 +70,7 @@ trait BatchPhysicalJoinRuleBase {
       node.getRowType, // input row type
       distinctKeys.toArray,
       Array.empty,
+      supportAdaptiveLocalHashAgg = false, // adaptive local hash agg is never enabled here
       Seq())
   }
 
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortAggRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortAggRule.scala
index f36dbbc6523..8758bb1e17e 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortAggRule.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortAggRule.scala
@@ -108,7 +108,9 @@ class BatchPhysicalSortAggRule
         auxGroupSet,
         aggBufferTypes,
         aggCallToAggFunction,
-        isLocalHashAgg = false)
+        isLocalHashAgg = false,
+        supportAdaptiveLocalHashAgg = false // local agg is sort-based here, so never adaptive
+      )
 
       // create global BatchPhysicalSortAggregate
       val (globalGroupSet, globalAuxGroupSet) = getGlobalAggGroupSetPair(groupSet, auxGroupSet)
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalAggRuleBase.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalAggRuleBase.scala
index d280fbed810..aba39ef31a2 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalAggRuleBase.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalAggRuleBase.scala
@@ -94,7 +94,8 @@ abstract class EnforceLocalAggRuleBase(operand: RelOptRuleOperand, description:
       auxGrouping,
       aggBufferTypes,
       aggCallToAggFunction,
-      isLocalHashAgg
+      isLocalHashAgg,
+      supportAdaptiveLocalHashAgg = false // always disabled for locally-enforced aggs
     )
   }
 
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
index ec06e058a88..ef25844477f 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
@@ -926,6 +926,26 @@ object AggregateUtil extends Enumeration {
     aggInfos.isEmpty || supportMerge
   }
 
+  /**
+   * Return true if all aggregates can be projected for adaptive local hash aggregate, i.e. none
+   * has a filter argument and each is a COUNT, AVG, MIN/MAX or SUM call. False otherwise.
+   * NOTE(review): Calcite's SqlAvgAggFunction also models STDDEV/VAR kinds; presumably those
+   * are rewritten to SUM/COUNT before physical planning - confirm.
+   */
+  def doAllAggSupportAdaptiveLocalHashAgg(aggCalls: Seq[AggregateCall]): Boolean = {
+    aggCalls.forall {
+      aggCall =>
+        // TODO support adaptive local hash agg while agg call with filter condition.
+        // A single boolean expression avoids a non-local `return` out of this closure.
+        aggCall.filterArg < 0 && (aggCall.getAggregation match {
+          case _: SqlCountAggFunction | _: SqlAvgAggFunction | _: SqlMinMaxAggFunction |
+              _: SqlSumAggFunction =>
+            true
+          case _ => false
+        })
+    }
+  }
+
   /** Return true if all aggregates can be split. False otherwise. */
   def doAllAggSupportSplit(aggCalls: util.List[AggregateCall]): Boolean = {
     aggCalls.forall {
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGeneratorTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGeneratorTest.scala
index 2a7b433ad00..78e50368396 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGeneratorTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGeneratorTest.scala
@@ -22,7 +22,7 @@ import org.apache.flink.table.data.RowData
 import org.apache.flink.table.planner.functions.aggfunctions.AvgAggFunction.LongAvgAggFunction
 import org.apache.flink.table.planner.plan.utils.{AggregateInfo, AggregateInfoList}
 import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory
-import org.apache.flink.table.types.logical.{BigIntType, DoubleType, LogicalType, RowType, VarCharType}
+import org.apache.flink.table.types.logical._
 
 import org.apache.calcite.rel.core.AggregateCall
 import org.junit.Test
@@ -124,7 +124,7 @@ class HashAggCodeGeneratorTest extends BatchAggTestBase {
       (inputType, localOutputType)
     }
     val auxGrouping = if (isMerge) Array(1) else Array(4)
-    val generator = new HashAggCodeGenerator(
+    val genOp = HashAggCodeGenerator.genWithKeys(
       ctx,
       relBuilder,
       aggInfoList,
@@ -133,8 +133,8 @@ class HashAggCodeGeneratorTest extends BatchAggTestBase {
       Array(0),
       auxGrouping,
       isMerge,
-      isFinal)
-    val genOp = generator.genWithKeys()
+      isFinal,
+      false) // NOTE(review): presumably the adaptive local hash agg flag, disabled - confirm
     (new CodeGenOperatorFactory[RowData](genOp), iType, oType)
   }
 
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
index b9074397fa6..5535fc0671e 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
@@ -1163,6 +1163,7 @@ class FlinkRelMdHandlerTestBase {
       studentBatchScan.getRowType,
       Array(3),
       auxGrouping = Array(),
+      supportAdaptiveLocalHashAgg = true,
       aggCallToAggFunction)
 
     val batchExchange1 = new BatchPhysicalExchange(
@@ -1428,6 +1429,7 @@ class FlinkRelMdHandlerTestBase {
       calcOnStudentScan.getRowType,
       Array(3),
       auxGrouping = Array(),
+      supportAdaptiveLocalHashAgg = true,
       aggCallToAggFunction)
 
     val batchExchange1 = new BatchPhysicalExchange(
@@ -1580,6 +1582,7 @@ class FlinkRelMdHandlerTestBase {
       studentBatchScan.getRowType,
       Array(0),
       auxGrouping = Array(1, 4),
+      supportAdaptiveLocalHashAgg = true,
       aggCallToAggFunction)
 
     val hash0 = FlinkRelDistribution.hash(Array(0), requireStrict = true)
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/AggregateITCaseBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/AggregateITCaseBase.scala
index 6bcf4364369..556f09eba54 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/AggregateITCaseBase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/AggregateITCaseBase.scala
@@ -20,7 +20,7 @@ package org.apache.flink.table.planner.runtime.batch.sql.agg
 import org.apache.flink.api.common.BatchShuffleMode
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
-import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.java.typeutils.{RowTypeInfo, TupleTypeInfoBase}
 import org.apache.flink.api.scala._
 import org.apache.flink.configuration.{ExecutionOptions, JobManagerOptions}
 import org.apache.flink.configuration.JobManagerOptions.SchedulerType
@@ -298,9 +298,16 @@ abstract class AggregateITCaseBase(testName: String) extends BatchTestBase {
     val tableRows = tableData.map(toRow)
 
     val tupleTypeInfo = implicitly[TypeInformation[T]]
-    val fieldInfos = tupleTypeInfo.getGenericParameters.values()
-    import scala.collection.JavaConverters._
-    val rowTypeInfo = new RowTypeInfo(fieldInfos.asScala.toArray: _*)
+    // Tuple-like type infos (TupleTypeInfoBase subclasses reporting isTupleType) expose their
+    // field types directly; anything else falls back to its generic parameters.
+    val rowTypeInfo: RowTypeInfo = tupleTypeInfo match {
+      case tupleInfo: TupleTypeInfoBase[_] if tupleInfo.isTupleType =>
+        new RowTypeInfo(tupleInfo.getFieldTypes: _*)
+      case _ =>
+        val fieldInfos = tupleTypeInfo.getGenericParameters.values()
+        import scala.collection.JavaConverters._
+        new RowTypeInfo(fieldInfos.asScala.toArray: _*)
+    }
 
     newTableId += 1
     val tableName = "TestTableX" + newTableId
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/HashAggITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/HashAggITCase.scala
index 15191584f48..4ff207754b4 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/HashAggITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/HashAggITCase.scala
@@ -17,12 +17,332 @@
 */
 package org.apache.flink.table.planner.runtime.batch.sql.agg
 
-import org.apache.flink.table.api.config.ExecutionConfigOptions
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.DataTypes
+import org.apache.flink.table.api.config.{ExecutionConfigOptions, OptimizerConfigOptions}
+import org.apache.flink.table.planner.codegen.agg.batch.HashAggCodeGenerator
+import org.apache.flink.table.planner.runtime.utils.BatchTestBase.row
+
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+
+import java.math.BigDecimal
 
 /** AggregateITCase using HashAgg Operator. */
-class HashAggITCase extends AggregateITCaseBase("HashAggregate") {
+@RunWith(classOf[Parameterized])
+class HashAggITCase(adaptiveLocalHashAggEnable: Boolean)
+  extends AggregateITCaseBase("HashAggregate") {
 
   override def prepareAggOp(): Unit = {
     tEnv.getConfig.set(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "SortAgg")
+    // Adaptive mode forces two-phase aggregation, so a local agg phase exists, and sets a small
+    // sampling threshold of 5 rows so that the larger test inputs below exceed it.
+    if (adaptiveLocalHashAggEnable) {
+      tEnv.getConfig
+        .set(OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY, "TWO_PHASE")
+      tEnv.getConfig.set(
+        HashAggCodeGenerator.TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_ENABLED,
+        Boolean.box(true))
+      tEnv.getConfig.set(
+        HashAggCodeGenerator.TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_SAMPLING_THRESHOLD,
+        Long.box(5L))
+    }
+  }
+
+  /** 12 input rows collapse into 6 groups, i.e. a high aggregation degree. */
+  @Test
+  def testAdaptiveLocalHashAggWithHighAggregationDegree(): Unit = {
+    checkQuery(
+      Seq(
+        (1, 1, 1, 1, 1L, 1.1d),
+        (1, 1, 1, 2, 1L, 1.2d),
+        (1, 1, 2, 3, 2L, 2.2d),
+        (1, 1, 2, 2, 2L, 1d),
+        (1, 1, 3, 3, 3L, 3d),
+        (1, 1, 2, 2, 2L, 4d),
+        (1, 2, 1, 1, 1L, 1.1d),
+        (1, 2, 1, 2, 1L, 2.3d),
+        (1, 3, 1, 1, 1L, 3.3d),
+        (1, 4, 1, 1, 1L, 1.1d),
+        (2, 1, 2, 2, 2L, 2.2d),
+        (2, 2, 3, 3, 3L, 3.3d)
+      ),
+      """
+        | SELECT f0, f1, sum(f2), avg(f2), max(f3), min(f3), count(f3), count(*), sum(f4), sum(f5), avg(f4), avg(f5)
+        | FROM TableName GROUP BY f0, f1
+        |""".stripMargin,
+      Seq(
+        (1, 1, 11, 1, 3, 1, 6, 6, 11, 12.5, 1, 2.0833333333333335),
+        (1, 2, 2, 1, 2, 1, 2, 2, 2, 3.4, 1, 1.7),
+        (1, 3, 1, 1, 1, 1, 1, 1, 1, 3.3, 1, 3.3),
+        (1, 4, 1, 1, 1, 1, 1, 1, 1, 1.1, 1, 1.1),
+        (2, 1, 2, 2, 2, 2, 1, 1, 2, 2.2, 2, 2.2),
+        (2, 2, 3, 3, 3, 3, 1, 1, 3, 3.3, 3, 3.3)
+      )
+    )
+  }
+
+  /** 12 input rows yield 9 groups, i.e. a low aggregation degree. */
+  @Test
+  def testAdaptiveLocalHashAggWithLowAggregationDegree(): Unit = {
+    checkQuery(
+      Seq(
+        (1, 1, 1, 1, 1L, 1.1d),
+        (1, 1, 1, 2, 1L, 1.2d),
+        (1, 2, 2, 3, 2L, 2.2d),
+        (1, 3, 2, 2, 2L, 1d),
+        (1, 4, 3, 3, 3L, 3d),
+        (1, 5, 2, 2, 2L, 4d),
+        (1, 6, 1, 1, 1L, 3.3d),
+        (2, 1, 1, 2, 1L, 2.3d),
+        (2, 2, 1, 1, 1L, 3.3d),
+        (2, 3, 1, 1, 1L, 3.3d),
+        (2, 3, 2, 2, 2L, 2.2d),
+        (2, 3, 3, 3, 3L, 3.3d)
+      ),
+      """
+        | SELECT f0, f1, sum(f2), avg(f2), max(f3), min(f3), count(f3), count(*), sum(f4), sum(f5), avg(f4), avg(f5)
+        | FROM TableName GROUP BY f0, f1
+        |""".stripMargin,
+      Seq(
+        (1, 1, 2, 1, 2, 1, 2, 2, 2, 2.3, 1, 1.15),
+        (1, 2, 2, 2, 3, 3, 1, 1, 2, 2.2, 2, 2.2),
+        (1, 3, 2, 2, 2, 2, 1, 1, 2, 1.0, 2, 1.0),
+        (1, 4, 3, 3, 3, 3, 1, 1, 3, 3.0, 3, 3.0),
+        (1, 5, 2, 2, 2, 2, 1, 1, 2, 4.0, 2, 4.0),
+        (1, 6, 1, 1, 1, 1, 1, 1, 1, 3.3, 1, 3.3),
+        (2, 1, 1, 1, 2, 2, 1, 1, 1, 2.3, 1, 2.3),
+        (2, 2, 1, 1, 1, 1, 1, 1, 1, 3.3, 1, 3.3),
+        (2, 3, 6, 2, 3, 1, 3, 3, 6, 8.8, 2, 2.9333333333333336)
+      )
+    )
+  }
+
+  /** Only 3 input rows, fewer than the adaptive sampling threshold of 5. */
+  @Test
+  def testAdaptiveLocalHashAggWithRowLessThanSamplingThreshold(): Unit = {
+    checkQuery(
+      Seq((1, 1, 1, 1, 1L, 1.1d), (1, 1, 1, 2, 1L, 1.2d), (1, 2, 2, 3, 2L, 2.2d)),
+      """
+        | SELECT f0, f1, sum(f2), avg(f2), max(f3), min(f3), count(f3), count(*), sum(f4), sum(f5), avg(f4), avg(f5)
+        | FROM TableName GROUP BY f0, f1
+        |""".stripMargin,
+      Seq((1, 1, 2, 1, 2, 1, 2, 2, 2, 2.3, 1, 1.15), (1, 2, 2, 2, 3, 3, 1, 1, 2, 2.2, 2, 2.2))
+    )
+  }
+
+  /** Null aggregate inputs, including groups whose aggregated column is entirely null. */
+  @Test
+  def testAdaptiveLocalHashAggWithNullValue(): Unit = {
+    val testDataWithNullValue = tEnv.fromValues(
+      DataTypes.ROW(
+        DataTypes.FIELD("f0", DataTypes.INT()),
+        DataTypes.FIELD("f1", DataTypes.INT()),
+        DataTypes.FIELD("f2", DataTypes.INT()),
+        DataTypes.FIELD("f3", DataTypes.INT()),
+        DataTypes.FIELD("f4", DataTypes.BIGINT()),
+        DataTypes.FIELD("f5", DataTypes.DOUBLE())
+      ),
+      row(1, 1, 1, 1, null, 1.1d),
+      row(1, 1, null, null, 1L, null),
+      row(1, 1, 2, 3, 2L, 2.2d),
+      row(1, 1, 2, 2, null, 1d),
+      row(1, 1, 2, 2, 2L, 4d),
+      row(1, 1, null, 3, 3L, null),
+      row(1, 2, 1, 1, 1L, 1.1d),
+      row(1, 2, null, 2, null, null),
+      row(1, 3, 1, 1, 1L, null),
+      row(1, 4, 1, 1, 1L, 1.1d),
+      row(2, 1, null, 2, 2L, 2.2d),
+      row(2, 2, 3, null, 3L, 3.3d)
+    )
+
+    checkResult(
+      s"""
+         | SELECT f0, f1, sum(f2), avg(f2), max(f3), min(f3), count(f3), count(*), sum(f4), sum(f5), avg(f4), avg(f5)
+         | FROM $testDataWithNullValue GROUP BY f0, f1
+         |""".stripMargin,
+      Seq(
+        row(1, 1, 7, 1, 3, 1, 5, 6, 8, 8.3, 2, 2.075),
+        row(1, 2, 1, 1, 2, 1, 2, 2, 1, 1.1, 1, 1.1),
+        row(1, 3, 1, 1, 1, 1, 1, 1, 1, null, 1, null),
+        row(1, 4, 1, 1, 1, 1, 1, 1, 1, 1.1, 1, 1.1),
+        row(2, 1, null, null, 2, 2, 1, 1, 2, 2.2, 2, 2.2),
+        row(2, 2, 3, 3, null, null, 0, 1, 3, 3.3, 3, 3.3)
+      )
+    )
+
+  }
+
+  /** SUM and AVG over TINYINT, SMALLINT, BIGINT, FLOAT, DOUBLE and DECIMAL columns. */
+  @Test
+  def testAdaptiveLocalHashAggWithSumAndAvgFunctionForNumericalType(): Unit = {
+    val testDataWithAllTypes = tEnv.fromValues(
+      DataTypes.ROW(
+        DataTypes.FIELD("f0", DataTypes.INT()),
+        DataTypes.FIELD("f1", DataTypes.TINYINT()),
+        DataTypes.FIELD("f2", DataTypes.SMALLINT()),
+        DataTypes.FIELD("f3", DataTypes.BIGINT()),
+        DataTypes.FIELD("f4", DataTypes.FLOAT()),
+        DataTypes.FIELD("f5", DataTypes.DOUBLE()),
+        DataTypes.FIELD("f6", DataTypes.DECIMAL(5, 2)),
+        DataTypes.FIELD("f7", DataTypes.DECIMAL(14, 3)),
+        DataTypes.FIELD("f8", DataTypes.DECIMAL(38, 18))
+      ),
+      row(
+        1,
+        1,
+        1,
+        1000L,
+        -1.1f,
+        1.1d,
+        new BigDecimal("111.11"),
+        new BigDecimal("11111111111.111"),
+        new BigDecimal("11111111111111111111.111111111111111111")
+      ),
+      row(
+        1,
+        1,
+        1,
+        1000L,
+        -1.1f,
+        1.1d,
+        new BigDecimal("111.11"),
+        new BigDecimal("11111111111.111"),
+        new BigDecimal("11111111111111111111.111111111111111111")
+      ),
+      row(
+        1,
+        1,
+        1,
+        1000L,
+        -1.1f,
+        1.1d,
+        new BigDecimal("111.11"),
+        new BigDecimal("11111111111.111"),
+        new BigDecimal("11111111111111111111.111111111111111111")
+      ),
+      row(
+        1,
+        1,
+        1,
+        1000L,
+        -1.1f,
+        1.1d,
+        new BigDecimal("111.11"),
+        new BigDecimal("11111111111.111"),
+        new BigDecimal("11111111111111111111.111111111111111111")
+      ),
+      row(
+        2,
+        1,
+        1,
+        1000L,
+        -1.1f,
+        1.1d,
+        new BigDecimal("111.11"),
+        new BigDecimal("11111111111.111"),
+        new BigDecimal("11111111111111111111.111111111111111111")
+      ),
+      row(
+        2,
+        1,
+        1,
+        1000L,
+        -1.1f,
+        1.1d,
+        new BigDecimal("111.11"),
+        new BigDecimal("11111111111.111"),
+        new BigDecimal("11111111111111111111.111111111111111111")
+      ),
+      row(
+        3,
+        1,
+        1,
+        1000L,
+        -1.1f,
+        1.1d,
+        new BigDecimal("111.11"),
+        new BigDecimal("11111111111.111"),
+        new BigDecimal("11111111111111111111.111111111111111111")
+      )
+    )
+
+    checkResult(
+      s"""
+         | SELECT f0, sum(f1), avg(f1), sum(f2), avg(f2), sum(f3), avg(f3),
+         | sum(f4), avg(f4), sum(f5), avg(f5), sum(f6), avg(f6),
+         | sum(f7), avg(f7), sum(f8), avg(f8)
+         | FROM $testDataWithAllTypes GROUP BY f0
+         |""".stripMargin,
+      Seq(
+        row(
+          1,
+          4.toByte,
+          1.toByte,
+          4.toShort,
+          1.toShort,
+          4000L,
+          1000L,
+          -4.4f,
+          -1.1f,
+          4.4d,
+          1.1d,
+          new BigDecimal("444.44"),
+          new BigDecimal("111.110000"),
+          new BigDecimal("44444444444.444"),
+          new BigDecimal("11111111111.111000"),
+          new BigDecimal("44444444444444444444.444444444444444444"),
+          new BigDecimal("11111111111111111111.111111111111111111")
+        ),
+        row(
+          2,
+          2.toByte,
+          1.toByte,
+          2.toShort,
+          1.toShort,
+          2000L,
+          1000L,
+          -2.2f,
+          -1.1f,
+          2.2d,
+          1.1d,
+          new BigDecimal("222.22"),
+          new BigDecimal("111.110000"),
+          new BigDecimal("22222222222.222"),
+          new BigDecimal("11111111111.111000"),
+          new BigDecimal("22222222222222222222.222222222222222222"),
+          new BigDecimal("11111111111111111111.111111111111111111")
+        ),
+        row(
+          3,
+          1.toByte,
+          1.toByte,
+          1.toShort,
+          1.toShort,
+          1000L,
+          1000L,
+          -1.1f,
+          -1.1f,
+          1.1d,
+          1.1d,
+          new BigDecimal("111.11"),
+          new BigDecimal("111.110000"),
+          new BigDecimal("11111111111.111"),
+          new BigDecimal("11111111111.111000"),
+          new BigDecimal("11111111111111111111.111111111111111111"),
+          new BigDecimal("11111111111111111111.111111111111111111")
+        )
+      )
+    )
+  }
+}
+
+object HashAggITCase {
+  /** Runs every test with adaptive local hash agg both enabled and disabled. */
+  @Parameterized.Parameters(name = "adaptiveLocalHashAggEnable={0}")
+  def parameters(): java.util.Collection[Boolean] = {
+    java.util.Arrays.asList(true, false)
   }
 }