You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by go...@apache.org on 2023/01/31 10:27:59 UTC
[flink] branch master updated: [FLINK-30542][table-planner] Introduce adaptive local hash aggregate to adaptively determine whether local hash aggregate is required at runtime
This is an automated email from the ASF dual-hosted git repository.
godfrey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 122ba8f319b [FLINK-30542][table-planner] Introduce adaptive local hash aggregate to adaptively determine whether local hash aggregate is required at runtime
122ba8f319b is described below
commit 122ba8f319b0d68374abba08d676e6dfa82cc114
Author: zhengyunhong.zyh <33...@qq.com>
AuthorDate: Wed Jan 11 21:57:02 2023 +0800
[FLINK-30542][table-planner] Introduce adaptive local hash aggregate to adaptively determine whether local hash aggregate is required at runtime
This closes #21586
---
.../nodes/exec/batch/BatchExecHashAggregate.java | 25 +-
.../planner/codegen/ProjectionCodeGenerator.scala | 166 ++++++++++-
.../codegen/agg/batch/HashAggCodeGenerator.scala | 205 +++++++++++--
.../batch/BatchPhysicalHashAggregate.scala | 1 +
.../batch/BatchPhysicalLocalHashAggregate.scala | 3 +
.../physical/batch/BatchPhysicalAggRuleBase.scala | 4 +-
.../physical/batch/BatchPhysicalHashAggRule.scala | 10 +-
.../physical/batch/BatchPhysicalJoinRuleBase.scala | 1 +
.../physical/batch/BatchPhysicalSortAggRule.scala | 4 +-
.../physical/batch/EnforceLocalAggRuleBase.scala | 3 +-
.../table/planner/plan/utils/AggregateUtil.scala | 20 ++
.../agg/batch/HashAggCodeGeneratorTest.scala | 8 +-
.../plan/metadata/FlinkRelMdHandlerTestBase.scala | 3 +
.../batch/sql/agg/AggregateITCaseBase.scala | 15 +-
.../runtime/batch/sql/agg/HashAggITCase.scala | 316 ++++++++++++++++++++-
15 files changed, 733 insertions(+), 51 deletions(-)
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecHashAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecHashAggregate.java
index 02ed8a737da..53ebbf2132c 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecHashAggregate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecHashAggregate.java
@@ -58,6 +58,7 @@ public class BatchExecHashAggregate extends ExecNodeBase<RowData>
private final RowType aggInputRowType;
private final boolean isMerge;
private final boolean isFinal;
+ private final boolean supportAdaptiveLocalHashAgg;
public BatchExecHashAggregate(
ReadableConfig tableConfig,
@@ -67,6 +68,7 @@ public class BatchExecHashAggregate extends ExecNodeBase<RowData>
RowType aggInputRowType,
boolean isMerge,
boolean isFinal,
+ boolean supportAdaptiveLocalHashAgg,
InputProperty inputProperty,
RowType outputType,
String description) {
@@ -83,6 +85,7 @@ public class BatchExecHashAggregate extends ExecNodeBase<RowData>
this.aggInputRowType = aggInputRowType;
this.isMerge = isMerge;
this.isFinal = isFinal;
+ this.supportAdaptiveLocalHashAgg = supportAdaptiveLocalHashAgg;
}
@SuppressWarnings("unchecked")
@@ -126,17 +129,17 @@ public class BatchExecHashAggregate extends ExecNodeBase<RowData>
config.get(ExecutionConfigOptions.TABLE_EXEC_RESOURCE_HASH_AGG_MEMORY)
.getBytes();
generatedOperator =
- new HashAggCodeGenerator(
- ctx,
- planner.createRelBuilder(),
- aggInfos,
- inputRowType,
- outputRowType,
- grouping,
- auxGrouping,
- isMerge,
- isFinal)
- .genWithKeys();
+ HashAggCodeGenerator.genWithKeys(
+ ctx,
+ planner.createRelBuilder(),
+ aggInfos,
+ inputRowType,
+ outputRowType,
+ grouping,
+ auxGrouping,
+ isMerge,
+ isFinal,
+ supportAdaptiveLocalHashAgg);
}
return ExecNodeUtil.createOneInputTransformation(
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProjectionCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProjectionCodeGenerator.scala
index 496af1394fe..e2c3759ddaa 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProjectionCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProjectionCodeGenerator.scala
@@ -19,11 +19,18 @@ package org.apache.flink.table.planner.codegen
import org.apache.flink.table.data.RowData
import org.apache.flink.table.data.binary.BinaryRowData
+import org.apache.flink.table.data.writer.BinaryRowWriter
import org.apache.flink.table.planner.codegen.CodeGenUtils._
import org.apache.flink.table.planner.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
import org.apache.flink.table.planner.codegen.GenerateUtils.generateRecordStatement
+import org.apache.flink.table.planner.codegen.calls.ScalarOperatorGens
+import org.apache.flink.table.planner.functions.aggfunctions._
+import org.apache.flink.table.planner.plan.utils.AggregateInfo
import org.apache.flink.table.runtime.generated.{GeneratedProjection, Projection}
-import org.apache.flink.table.types.logical.RowType
+import org.apache.flink.table.types.logical.{BigIntType, LogicalType, RowType}
+import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getFieldTypes
+
+import scala.collection.mutable.ArrayBuffer
/**
* CodeGenerator for projection, Take out some fields of [[RowData]] to generate a new [[RowData]].
@@ -124,6 +131,163 @@ object ProjectionCodeGenerator {
new GeneratedProjection(className, code, ctx.references.toArray, ctx.tableConfig)
}
+ /**
+ * If adaptive local hash aggregation takes effect, local hash aggregation will be suppressed. To
+ * ensure that the rows sent downstream when local hash aggregation is suppressed have the same
+ * format as the rows sent downstream when it is performed, we need to project each input row
+ * into the aggregate buffer layout of the aggregate functions.
+ *
+ * <p> For example, for sql statement "select a, avg(b), count(c) from T group by a", if local
+ * hash aggregation suppressed and a row (1, 5, "a") comes to local hash aggregation, we will pass
+ * (1, 5, 1, 1) to downstream.
+ */
+ def genAdaptiveLocalHashAggValueProjectionCode(
+ ctx: CodeGeneratorContext,
+ inputType: RowType,
+ outClass: Class[_ <: RowData] = classOf[BinaryRowData],
+ inputTerm: String = DEFAULT_INPUT1_TERM,
+ aggInfos: Array[AggregateInfo],
+ outRecordTerm: String = DEFAULT_OUT_RECORD_TERM,
+ outRecordWriterTerm: String = DEFAULT_OUT_RECORD_WRITER_TERM): String = {
+ val fieldExprs: ArrayBuffer[GeneratedExpression] = ArrayBuffer()
+ aggInfos.map {
+ aggInfo =>
+ aggInfo.function match {
+ case sumAggFunction: SumAggFunction =>
+ fieldExprs += genValueProjectionForSumAggFunc(
+ ctx,
+ inputType,
+ inputTerm,
+ sumAggFunction.getResultType.getLogicalType,
+ aggInfo.agg.getArgList.get(0))
+ case _: MaxAggFunction | _: MinAggFunction =>
+ fieldExprs += GenerateUtils.generateFieldAccess(
+ ctx,
+ inputType,
+ inputTerm,
+ aggInfo.agg.getArgList.get(0))
+ case avgAggFunction: AvgAggFunction =>
+ fieldExprs += genValueProjectionForSumAggFunc(
+ ctx,
+ inputType,
+ inputTerm,
+ avgAggFunction.getSumType.getLogicalType,
+ aggInfo.agg.getArgList.get(0))
+ fieldExprs += genValueProjectionForCountAggFunc(
+ ctx,
+ inputTerm,
+ aggInfo.agg.getArgList.get(0))
+ case _: CountAggFunction =>
+ fieldExprs += genValueProjectionForCountAggFunc(
+ ctx,
+ inputTerm,
+ aggInfo.agg.getArgList.get(0))
+ case _: Count1AggFunction =>
+ fieldExprs += genValueProjectionForCount1AggFunc(ctx)
+ }
+ }
+
+ val binaryRowWriter = CodeGenUtils.className[BinaryRowWriter]
+ val typeTerm = outClass.getCanonicalName
+ ctx.addReusableMember(s"private $typeTerm $outRecordTerm= new $typeTerm(${fieldExprs.size});")
+ ctx.addReusableMember(
+ s"private $binaryRowWriter $outRecordWriterTerm = new $binaryRowWriter($outRecordTerm);")
+
+ val fieldExprIdxToOutputRowPosMap = fieldExprs.indices.map(i => i -> i).toMap
+ val setFieldsCode = fieldExprs.zipWithIndex
+ .map {
+ case (fieldExpr, index) =>
+ val pos = fieldExprIdxToOutputRowPosMap.getOrElse(
+ index,
+ throw new CodeGenException(s"Illegal field expr index: $index"))
+ rowSetField(
+ ctx,
+ classOf[BinaryRowData],
+ outRecordTerm,
+ pos.toString,
+ fieldExpr,
+ Option(outRecordWriterTerm))
+ }
+ .mkString("\n")
+
+ val writer = outRecordWriterTerm
+ val resetWriter = s"$writer.reset();"
+ val completeWriter: String = s"$writer.complete();"
+ s"""
+ |$resetWriter
+ |$setFieldsCode
+ |$completeWriter
+ """.stripMargin
+ }
+
+ /**
+ * Do projection for grouping function 'sum(col)' if adaptive local hash aggregation takes effect.
+ * For 'sum(col)', we will try to convert the projected value type to sum agg function target
+ * type if col is not null and convert it to default value type if col is null.
+ */
+ def genValueProjectionForSumAggFunc(
+ ctx: CodeGeneratorContext,
+ inputType: LogicalType,
+ inputTerm: String,
+ targetType: LogicalType,
+ index: Int): GeneratedExpression = {
+ val fieldType = getFieldTypes(inputType).get(index)
+ val resultTypeTerm = primitiveTypeTermForType(fieldType)
+ val defaultValue = primitiveDefaultValue(fieldType)
+ val readCode = rowFieldReadAccess(index.toString, inputTerm, fieldType)
+ val Seq(fieldTerm, nullTerm) =
+ ctx.addReusableLocalVariables((resultTypeTerm, "field"), ("boolean", "isNull"))
+
+ val inputCode =
+ s"""
+ |$nullTerm = $inputTerm.isNullAt($index);
+ |$fieldTerm = $defaultValue;
+ |if (!$nullTerm) {
+ | $fieldTerm = $readCode;
+ |}
+ """.stripMargin.trim
+
+ val expression = GeneratedExpression(fieldTerm, nullTerm, inputCode, fieldType)
+ // Convert the projected value type to sum agg func target type.
+ ScalarOperatorGens.generateCast(ctx, expression, targetType, true)
+ }
+
+ /**
+ * Do projection for grouping function 'count(col)' if adaptive local hash aggregation takes
+ * effect. 'count(col)' is converted to 1L if col is not null and to 0L if col is null.
+ */
+ def genValueProjectionForCountAggFunc(
+ ctx: CodeGeneratorContext,
+ inputTerm: String,
+ index: Int): GeneratedExpression = {
+ val Seq(fieldTerm, nullTerm) =
+ ctx.addReusableLocalVariables(("long", "field"), ("boolean", "isNull"))
+
+ val inputCode =
+ s"""
+ |$fieldTerm = 0L;
+ |if (!$inputTerm.isNullAt($index)) {
+ | $fieldTerm = 1L;
+ |}
+ """.stripMargin.trim
+
+ GeneratedExpression(fieldTerm, nullTerm, inputCode, new BigIntType())
+ }
+
+ /**
+ * Do projection for grouping function 'count(*)' or 'count(1)' if adaptive local hash agg takes
+ * effect. 'count(*) or count(1)' will be converted to 1L and transmitted to downstream.
+ */
+ def genValueProjectionForCount1AggFunc(ctx: CodeGeneratorContext): GeneratedExpression = {
+ val Seq(fieldTerm, nullTerm) =
+ ctx.addReusableLocalVariables(("long", "field"), ("boolean", "isNull"))
+ val inputCode =
+ s"""
+ |$fieldTerm = 1L;
+ |""".stripMargin.trim
+ GeneratedExpression(fieldTerm, nullTerm, inputCode, new BigIntType())
+ }
+
/** For java invoke. */
def generateProjection(
ctx: CodeGeneratorContext,
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
index 98193260222..e98ca5dd266 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
@@ -17,19 +17,24 @@
*/
package org.apache.flink.table.planner.codegen.agg.batch
+import org.apache.flink.annotation.Experimental
+import org.apache.flink.configuration.ConfigOption
+import org.apache.flink.configuration.ConfigOptions.key
import org.apache.flink.streaming.api.operators.OneInputStreamOperator
import org.apache.flink.table.data.{GenericRowData, RowData}
import org.apache.flink.table.data.binary.BinaryRowData
import org.apache.flink.table.data.utils.JoinedRowData
-import org.apache.flink.table.functions.{AggregateFunction, DeclarativeAggregateFunction}
-import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenUtils, ProjectionCodeGenerator}
-import org.apache.flink.table.planner.plan.utils.{AggregateInfo, AggregateInfoList}
+import org.apache.flink.table.functions.DeclarativeAggregateFunction
+import org.apache.flink.table.planner.{JBoolean, JDouble, JLong}
+import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenUtils, OperatorCodeGenerator, ProjectionCodeGenerator}
+import org.apache.flink.table.planner.codegen.CodeGenUtils.ROW_DATA
+import org.apache.flink.table.planner.plan.utils.AggregateInfoList
import org.apache.flink.table.planner.typeutils.RowTypeUtils
import org.apache.flink.table.runtime.generated.GeneratedOperator
import org.apache.flink.table.runtime.operators.TableStreamOperator
import org.apache.flink.table.runtime.operators.aggregate.BytesHashMapSpillMemorySegmentPool
import org.apache.flink.table.runtime.util.collections.binary.BytesMap
-import org.apache.flink.table.types.logical.{LogicalType, RowType}
+import org.apache.flink.table.types.logical.RowType
import org.apache.calcite.tools.RelBuilder
@@ -38,32 +43,78 @@ import org.apache.calcite.tools.RelBuilder
* aggregateBuffers should be update(e.g.: setInt) in [[BinaryRowData]]. (Hash Aggregate performs
* much better than Sort Aggregate).
*/
-class HashAggCodeGenerator(
- ctx: CodeGeneratorContext,
- builder: RelBuilder,
- aggInfoList: AggregateInfoList,
- inputType: RowType,
- outputType: RowType,
- grouping: Array[Int],
- auxGrouping: Array[Int],
- isMerge: Boolean,
- isFinal: Boolean) {
+object HashAggCodeGenerator {
- private lazy val aggInfos: Array[AggregateInfo] = aggInfoList.aggInfos
+ // This is an experimental config and may be removed later.
+ @Experimental
+ val TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_ENABLED: ConfigOption[JBoolean] =
+ key("table.exec.local-hash-agg.adaptive.enabled")
+ .booleanType()
+ .defaultValue(Boolean.box(true))
+ .withDescription(
+ s"""
+ |Whether to enable adaptive local hash aggregation. Adaptive local hash
+ |aggregation is an optimization of local hash aggregation, which can adaptively
+ |determine whether to continue to do local hash aggregation according to the distinct
+ | value rate of sampling data. If distinct value rate bigger than defined threshold
+ |(see parameter: table.exec.local-hash-agg.adaptive.distinct-value-rate-threshold),
+ |we will stop aggregating and just send the input data to the downstream after a simple
+ |projection. Otherwise, we will continue to do aggregation. Adaptive local hash aggregation
+ |only works in batch mode. Default value of this parameter is true.
+ |""".stripMargin)
- private lazy val functionIdentifiers: Map[AggregateFunction[_, _], String] =
- AggCodeGenHelper.getFunctionIdentifiers(aggInfos)
+ @Experimental
+ val TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_SAMPLING_THRESHOLD: ConfigOption[JLong] =
+ key("table.exec.local-hash-agg.adaptive.sampling-threshold")
+ .longType()
+ .defaultValue(Long.box(5000000L))
+ .withDescription(
+ s"""
+ |If adaptive local hash aggregation is enabled, this value defines how
+ |many records will be used as sampled data to calculate distinct value rate
+ |(see parameter: table.exec.local-hash-agg.adaptive.distinct-value-rate-threshold)
+ |for the local aggregate. The higher the sampling threshold, the more accurate
+ |the distinct value rate is. But as the sampling threshold increases, local
+ |aggregation is meaningless when the distinct values rate is low.
+ |The default value is 5000000.
+ |""".stripMargin)
- private lazy val aggBufferNames: Array[Array[String]] =
- AggCodeGenHelper.getAggBufferNames(auxGrouping, aggInfos)
+ @Experimental
+ val TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_DISTINCT_VALUE_RATE_THRESHOLD: ConfigOption[JDouble] =
+ key("table.exec.local-hash-agg.adaptive.distinct-value-rate-threshold")
+ .doubleType()
+ .defaultValue(0.5d)
+ .withDescription(
+ s"""
+ |The distinct value rate can be defined as the number of local
+ |aggregation result for the sampled data divided by the sampling
+ |threshold (see ${TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_SAMPLING_THRESHOLD.key()}).
+ |If the computed result is lower than the given configuration value,
+ |the remaining input records proceed to do local aggregation, otherwise
+ |the remaining input records are subjected to simple projection which
+ |calculation cost is less than local aggregation. The default value is 0.5.
+ |""".stripMargin)
- private lazy val aggBufferTypes: Array[Array[LogicalType]] =
- AggCodeGenHelper.getAggBufferTypes(inputType, auxGrouping, aggInfos)
+ def genWithKeys(
+ ctx: CodeGeneratorContext,
+ builder: RelBuilder,
+ aggInfoList: AggregateInfoList,
+ inputType: RowType,
+ outputType: RowType,
+ grouping: Array[Int],
+ auxGrouping: Array[Int],
+ isMerge: Boolean,
+ isFinal: Boolean,
+ supportAdaptiveLocalHashAgg: Boolean)
+ : GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
- private lazy val groupKeyRowType = RowTypeUtils.projectRowType(inputType, grouping)
- private lazy val aggBufferRowType = RowType.of(aggBufferTypes.flatten, aggBufferNames.flatten)
+ val aggInfos = aggInfoList.aggInfos
+ val functionIdentifiers = AggCodeGenHelper.getFunctionIdentifiers(aggInfos)
+ val aggBufferNames = AggCodeGenHelper.getAggBufferNames(auxGrouping, aggInfos)
+ val aggBufferTypes = AggCodeGenHelper.getAggBufferTypes(inputType, auxGrouping, aggInfos)
+ val groupKeyRowType = RowTypeUtils.projectRowType(inputType, grouping)
+ val aggBufferRowType = RowType.of(aggBufferTypes.flatten, aggBufferNames.flatten)
- def genWithKeys(): GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM
val className = if (isFinal) "HashAggregateWithKeys" else "LocalHashAggregateWithKeys"
@@ -74,6 +125,10 @@ class HashAggCodeGenerator(
// gen code to do group key projection from input
val currentKeyTerm = CodeGenUtils.newName("currentKey")
val currentKeyWriterTerm = CodeGenUtils.newName("currentKeyWriter")
+ // currentValueTerm and currentValueWriterTerm are used for value
+ // projection while supportAdaptiveLocalHashAgg is true.
+ val currentValueTerm = CodeGenUtils.newName("currentValue")
+ val currentValueWriterTerm = CodeGenUtils.newName("currentValueWriter")
val keyProjectionCode = ProjectionCodeGenerator
.generateProjectionExpression(
ctx,
@@ -85,6 +140,21 @@ class HashAggCodeGenerator(
outRecordWriterTerm = currentKeyWriterTerm)
.code
+ val valueProjectionCode =
+ if (!isFinal && supportAdaptiveLocalHashAgg) {
+ ProjectionCodeGenerator.genAdaptiveLocalHashAggValueProjectionCode(
+ ctx,
+ inputType,
+ classOf[BinaryRowData],
+ inputTerm = inputTerm,
+ aggInfos,
+ outRecordTerm = currentValueTerm,
+ outRecordWriterTerm = currentValueWriterTerm
+ )
+ } else {
+ ""
+ }
+
// gen code to create groupKey, aggBuffer Type array
// it will be used in BytesHashMap and BufferedKVExternalSorter if enable fallback
val groupKeyTypesTerm = CodeGenUtils.newName("groupKeyTypes")
@@ -173,6 +243,77 @@ class HashAggCodeGenerator(
HashAggCodeGenHelper.prepareMetrics(ctx, aggregateMapTerm, if (isFinal) sorterTerm else null)
+ // Do adaptive hash aggregation
+ val outputResultForAdaptiveLocalHashAgg = {
+ // gen code to output the current projected key and value directly to downstream
+ val inputUnboxingCode = s"${ctx.reuseInputUnboxingCode(reuseAggBufferTerm)}"
+ s"""
+ | // set result and output
+ | $reuseGroupKeyTerm = ($ROW_DATA)$currentKeyTerm;
+ | $reuseAggBufferTerm = ($ROW_DATA)$currentValueTerm;
+ | $inputUnboxingCode
+ | ${outputExpr.code}
+ | ${OperatorCodeGenerator.generateCollect(outputExpr.resultTerm)}
+ |
+ """.stripMargin
+ }
+ val localAggSuppressedTerm = CodeGenUtils.newName("localAggSuppressed")
+ ctx.addReusableMember(s"private transient boolean $localAggSuppressedTerm = false;")
+ val (
+ distinctCountIncCode,
+ totalCountIncCode,
+ adaptiveSamplingCode,
+ adaptiveLocalHashAggCode,
+ flushResultSuppressEnableCode) = {
+ // adaptive sampling only applies to a local (non-final) agg that supports it and has it enabled
+ if (
+ !isFinal &&
+ ctx.tableConfig.get(TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_ENABLED) &&
+ supportAdaptiveLocalHashAgg
+ ) {
+ val adaptiveDistinctCountTerm = CodeGenUtils.newName("distinctCount")
+ val adaptiveTotalCountTerm = CodeGenUtils.newName("totalCount")
+ ctx.addReusableMember(s"private transient long $adaptiveDistinctCountTerm = 0;")
+ ctx.addReusableMember(s"private transient long $adaptiveTotalCountTerm = 0;")
+
+ val samplingThreshold =
+ ctx.tableConfig.get(TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_SAMPLING_THRESHOLD)
+ val distinctValueRateThreshold =
+ ctx.tableConfig.get(TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_DISTINCT_VALUE_RATE_THRESHOLD)
+
+ (
+ s"$adaptiveDistinctCountTerm++;",
+ s"$adaptiveTotalCountTerm++;",
+ s"""
+ |if ($adaptiveTotalCountTerm == $samplingThreshold) {
+ | $logTerm.info("Local hash aggregation checkpoint reached, sampling threshold = " +
+ | $samplingThreshold + ", distinct value count = " + $adaptiveDistinctCountTerm + ", total = " +
+ | $adaptiveTotalCountTerm + ", distinct value rate threshold = "
+ | + $distinctValueRateThreshold);
+ | if ($adaptiveDistinctCountTerm / (1.0 * $adaptiveTotalCountTerm) > $distinctValueRateThreshold) {
+ | $logTerm.info("Local hash aggregation is suppressed");
+ | $localAggSuppressedTerm = true;
+ | }
+ |}
+ |""".stripMargin,
+ s"""
+ |if ($localAggSuppressedTerm) {
+ | $valueProjectionCode
+ | $outputResultForAdaptiveLocalHashAgg
+ | return;
+ |}
+ |""".stripMargin,
+ s"""
+ |if ($localAggSuppressedTerm) {
+ | $outputResultFromMap
+ | return;
+ |}
+ |""".stripMargin)
+ } else {
+ ("", "", "", "", "")
+ }
+ }
+
val lazyInitAggBufferCode = if (auxGrouping.nonEmpty) {
s"""
|// lazy init agg buffer (with auxGrouping)
@@ -188,11 +329,15 @@ class HashAggCodeGenerator(
|${ctx.reuseInputUnboxingCode(inputTerm)}
| // project key from input
|$keyProjectionCode
+ |
+ |$adaptiveLocalHashAggCode
+ |
| // look up output buffer using current group key
|$lookupInfo = ($lookupInfoTypeTerm) $aggregateMapTerm.lookup($currentKeyTerm);
|$currentAggBufferTerm = ($binaryRowTypeTerm) $lookupInfo.getValue();
|
|if (!$lookupInfo.isFound()) {
+ | $distinctCountIncCode
| $lazyInitAggBufferCode
| // append empty agg buffer into aggregate map for current group key
| try {
@@ -202,10 +347,16 @@ class HashAggCodeGenerator(
| $dealWithAggHashMapOOM
| }
|}
+ |
+ |$totalCountIncCode
+ |$adaptiveSamplingCode
+ |
| // aggregate buffer fields access
|${ctx.reuseInputUnboxingCode(currentAggBufferTerm)}
| // do aggregate and update agg buffer
|${aggregate.code}
+ | // flush result form map if suppress is enable.
+ |$flushResultSuppressEnableCode
|""".stripMargin.trim
val endInputCode = if (isFinal) {
@@ -227,7 +378,11 @@ class HashAggCodeGenerator(
|}
""".stripMargin
} else {
- s"$outputResultFromMap"
+ s"""
+ |if (!$localAggSuppressedTerm) {
+ | $outputResultFromMap
+ |}
+ |""".stripMargin
}
AggCodeGenHelper.generateOperator(
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalHashAggregate.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalHashAggregate.scala
index 95cdc4800a2..e22cf17d735 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalHashAggregate.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalHashAggregate.scala
@@ -161,6 +161,7 @@ class BatchPhysicalHashAggregate(
FlinkTypeFactory.toLogicalRowType(aggInputRowType),
isMerge,
true, // isFinal is always true
+ false, // supportAdaptiveLocalHashAgg is always false
InputProperty
.builder()
.requiredDistribution(requiredDistribution)
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalLocalHashAggregate.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalLocalHashAggregate.scala
index b0c0ea8c63a..6963440ef3a 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalLocalHashAggregate.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalLocalHashAggregate.scala
@@ -50,6 +50,7 @@ class BatchPhysicalLocalHashAggregate(
inputRowType: RelDataType,
grouping: Array[Int],
auxGrouping: Array[Int],
+ val supportAdaptiveLocalHashAgg: Boolean,
aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)])
extends BatchPhysicalHashAggregateBase(
cluster,
@@ -71,6 +72,7 @@ class BatchPhysicalLocalHashAggregate(
inputRowType,
grouping,
auxGrouping,
+ supportAdaptiveLocalHashAgg,
aggCallToAggFunction)
}
@@ -134,6 +136,7 @@ class BatchPhysicalLocalHashAggregate(
FlinkTypeFactory.toLogicalRowType(inputRowType),
false, // isMerge is always false
false, // isFinal is always false
+ supportAdaptiveLocalHashAgg,
getInputProperty,
FlinkTypeFactory.toLogicalRowType(getRowType),
getRelDetailedDescription)
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalAggRuleBase.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalAggRuleBase.scala
index 7b4de5b5f60..412046d2400 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalAggRuleBase.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalAggRuleBase.scala
@@ -203,7 +203,8 @@ trait BatchPhysicalAggRuleBase {
auxGrouping: Array[Int],
aggBufferTypes: Array[Array[DataType]],
aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
- isLocalHashAgg: Boolean): BatchPhysicalGroupAggregateBase = {
+ isLocalHashAgg: Boolean,
+ supportAdaptiveLocalHashAgg: Boolean): BatchPhysicalGroupAggregateBase = {
val inputRowType = input.getRowType
val aggFunctions = aggCallToAggFunction.map(_._2).toArray
@@ -231,6 +232,7 @@ trait BatchPhysicalAggRuleBase {
inputRowType,
grouping,
auxGrouping,
+ supportAdaptiveLocalHashAgg,
aggCallToAggFunction)
} else {
new BatchPhysicalLocalSortAggregate(
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalHashAggRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalHashAggRule.scala
index f208cdfbed6..f6ad3a06f89 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalHashAggRule.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalHashAggRule.scala
@@ -94,6 +94,12 @@ class BatchPhysicalHashAggRule
val aggCallToAggFunction = aggCallsWithoutAuxGroupCalls.zip(aggFunctions)
val aggProvidedTraitSet = agg.getTraitSet.replace(FlinkConventions.BATCH_PHYSICAL)
+ // Judge whether this agg operator supports adaptive local hash agg.
+ // If all agg functions in the agg operator can be projected, then it supports
+ // adaptive local hash agg. Otherwise it does not.
+ val supportAdaptiveLocalHashAgg =
+ AggregateUtil.doAllAggSupportAdaptiveLocalHashAgg(aggCallToAggFunction.map(_._1))
+
// create two-phase agg if possible
if (isTwoPhaseAggWorkable(aggFunctions, tableConfig)) {
// create BatchPhysicalLocalHashAggregate
@@ -109,7 +115,9 @@ class BatchPhysicalHashAggRule
auxGroupSet,
aggBufferTypes,
aggCallToAggFunction,
- isLocalHashAgg = true)
+ isLocalHashAgg = true,
+ supportAdaptiveLocalHashAgg
+ )
// create global BatchPhysicalHashAggregate
val (globalGroupSet, globalAuxGroupSet) = getGlobalAggGroupSetPair(groupSet, auxGroupSet)
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalJoinRuleBase.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalJoinRuleBase.scala
index 94cf8884a67..66540df1989 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalJoinRuleBase.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalJoinRuleBase.scala
@@ -70,6 +70,7 @@ trait BatchPhysicalJoinRuleBase {
node.getRowType, // input row type
distinctKeys.toArray,
Array.empty,
+ supportAdaptiveLocalHashAgg = false,
Seq())
}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortAggRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortAggRule.scala
index f36dbbc6523..8758bb1e17e 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortAggRule.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortAggRule.scala
@@ -108,7 +108,9 @@ class BatchPhysicalSortAggRule
auxGroupSet,
aggBufferTypes,
aggCallToAggFunction,
- isLocalHashAgg = false)
+ isLocalHashAgg = false,
+ supportAdaptiveLocalHashAgg = false
+ )
// create global BatchPhysicalSortAggregate
val (globalGroupSet, globalAuxGroupSet) = getGlobalAggGroupSetPair(groupSet, auxGroupSet)
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalAggRuleBase.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalAggRuleBase.scala
index d280fbed810..aba39ef31a2 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalAggRuleBase.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalAggRuleBase.scala
@@ -94,7 +94,8 @@ abstract class EnforceLocalAggRuleBase(operand: RelOptRuleOperand, description:
auxGrouping,
aggBufferTypes,
aggCallToAggFunction,
- isLocalHashAgg
+ isLocalHashAgg,
+ supportAdaptiveLocalHashAgg = false
)
}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
index ec06e058a88..ef25844477f 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
@@ -926,6 +926,26 @@ object AggregateUtil extends Enumeration {
aggInfos.isEmpty || supportMerge
}
+ /**
+ * Return true if all aggregates can be projected for adaptive local hash aggregate. False
+ * otherwise.
+ */
+ def doAllAggSupportAdaptiveLocalHashAgg(aggCalls: Seq[AggregateCall]): Boolean = {
+ aggCalls.forall {
+ aggCall =>
+ // TODO support adaptive local hash agg when the agg call has a filter condition.
+ if (aggCall.filterArg >= 0) {
+ return false
+ }
+ aggCall.getAggregation match {
+ case _: SqlCountAggFunction | _: SqlAvgAggFunction | _: SqlMinMaxAggFunction |
+ _: SqlSumAggFunction =>
+ true
+ case _ => false
+ }
+ }
+ }
+
/** Return true if all aggregates can be split. False otherwise. */
def doAllAggSupportSplit(aggCalls: util.List[AggregateCall]): Boolean = {
aggCalls.forall {
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGeneratorTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGeneratorTest.scala
index 2a7b433ad00..78e50368396 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGeneratorTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGeneratorTest.scala
@@ -22,7 +22,7 @@ import org.apache.flink.table.data.RowData
import org.apache.flink.table.planner.functions.aggfunctions.AvgAggFunction.LongAvgAggFunction
import org.apache.flink.table.planner.plan.utils.{AggregateInfo, AggregateInfoList}
import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory
-import org.apache.flink.table.types.logical.{BigIntType, DoubleType, LogicalType, RowType, VarCharType}
+import org.apache.flink.table.types.logical._
import org.apache.calcite.rel.core.AggregateCall
import org.junit.Test
@@ -124,7 +124,7 @@ class HashAggCodeGeneratorTest extends BatchAggTestBase {
(inputType, localOutputType)
}
val auxGrouping = if (isMerge) Array(1) else Array(4)
- val generator = new HashAggCodeGenerator(
+ val genOp = HashAggCodeGenerator.genWithKeys(
ctx,
relBuilder,
aggInfoList,
@@ -133,8 +133,8 @@ class HashAggCodeGeneratorTest extends BatchAggTestBase {
Array(0),
auxGrouping,
isMerge,
- isFinal)
- val genOp = generator.genWithKeys()
+ isFinal,
+ false)
(new CodeGenOperatorFactory[RowData](genOp), iType, oType)
}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
index b9074397fa6..5535fc0671e 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
@@ -1163,6 +1163,7 @@ class FlinkRelMdHandlerTestBase {
studentBatchScan.getRowType,
Array(3),
auxGrouping = Array(),
+ true,
aggCallToAggFunction)
val batchExchange1 = new BatchPhysicalExchange(
@@ -1428,6 +1429,7 @@ class FlinkRelMdHandlerTestBase {
calcOnStudentScan.getRowType,
Array(3),
auxGrouping = Array(),
+ true,
aggCallToAggFunction)
val batchExchange1 = new BatchPhysicalExchange(
@@ -1580,6 +1582,7 @@ class FlinkRelMdHandlerTestBase {
studentBatchScan.getRowType,
Array(0),
auxGrouping = Array(1, 4),
+ true,
aggCallToAggFunction)
val hash0 = FlinkRelDistribution.hash(Array(0), requireStrict = true)
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/AggregateITCaseBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/AggregateITCaseBase.scala
index 6bcf4364369..556f09eba54 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/AggregateITCaseBase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/AggregateITCaseBase.scala
@@ -20,7 +20,7 @@ package org.apache.flink.table.planner.runtime.batch.sql.agg
import org.apache.flink.api.common.BatchShuffleMode
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
-import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.java.typeutils.{RowTypeInfo, TupleTypeInfoBase}
import org.apache.flink.api.scala._
import org.apache.flink.configuration.{ExecutionOptions, JobManagerOptions}
import org.apache.flink.configuration.JobManagerOptions.SchedulerType
@@ -298,9 +298,16 @@ abstract class AggregateITCaseBase(testName: String) extends BatchTestBase {
val tableRows = tableData.map(toRow)
val tupleTypeInfo = implicitly[TypeInformation[T]]
- val fieldInfos = tupleTypeInfo.getGenericParameters.values()
- import scala.collection.JavaConverters._
- val rowTypeInfo = new RowTypeInfo(fieldInfos.asScala.toArray: _*)
+ val rowTypeInfo: RowTypeInfo = if (tupleTypeInfo.isTupleType) {
+ new RowTypeInfo(
+ tupleTypeInfo
+ .asInstanceOf[TupleTypeInfoBase[T]]
+ .getFieldTypes: _*)
+ } else {
+ val fieldInfos = tupleTypeInfo.getGenericParameters.values()
+ import scala.collection.JavaConverters._
+ new RowTypeInfo(fieldInfos.asScala.toArray: _*)
+ }
newTableId += 1
val tableName = "TestTableX" + newTableId
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/HashAggITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/HashAggITCase.scala
index 15191584f48..4ff207754b4 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/HashAggITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/HashAggITCase.scala
@@ -17,12 +17,324 @@
*/
package org.apache.flink.table.planner.runtime.batch.sql.agg
-import org.apache.flink.table.api.config.ExecutionConfigOptions
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.DataTypes
+import org.apache.flink.table.api.config.{ExecutionConfigOptions, OptimizerConfigOptions}
+import org.apache.flink.table.planner.codegen.agg.batch.HashAggCodeGenerator
+import org.apache.flink.table.planner.runtime.utils.BatchTestBase.row
+
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+
+import java.math.BigDecimal
/** AggregateITCase using HashAgg Operator. */
-class HashAggITCase extends AggregateITCaseBase("HashAggregate") {
+@RunWith(classOf[Parameterized])
+class HashAggITCase(adaptiveLocalHashAggEnable: Boolean)
+ extends AggregateITCaseBase("HashAggregate") {
+ /**
+ * Configures the table environment for this suite: the "SortAgg" operator is always disabled,
+ * and when `adaptiveLocalHashAggEnable` is true the adaptive local hash aggregate options are
+ * switched on as well.
+ */
override def prepareAggOp(): Unit = {
tEnv.getConfig.set(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "SortAgg")
+ if (adaptiveLocalHashAggEnable) {
+ // NOTE(review): TWO_PHASE is presumably forced so that a local hash aggregate is planned
+ // at all - confirm against the planner rules.
+ tEnv.getConfig
+ .set(OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY, "TWO_PHASE")
+ tEnv.getConfig.set(
+ HashAggCodeGenerator.TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_ENABLED,
+ Boolean.box(true))
+ // Sampling threshold of 5 keeps the threshold below the 12-row fixtures used in this suite.
+ tEnv.getConfig.set(
+ HashAggCodeGenerator.TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_SAMPLING_THRESHOLD,
+ Long.box(5L))
+ }
+ }
+
+ /**
+ * Groups 12 input rows by (f0, f1) into 6 groups, half of them falling into group (1, 1), and
+ * checks sum/avg/max/min/count over the INT, BIGINT and DOUBLE columns.
+ */
+ @Test
+ def testAdaptiveLocalHashAggWithHighAggregationDegree(): Unit = {
+ checkQuery(
+ Seq(
+ (1, 1, 1, 1, 1L, 1.1d),
+ (1, 1, 1, 2, 1L, 1.2d),
+ (1, 1, 2, 3, 2L, 2.2d),
+ (1, 1, 2, 2, 2L, 1d),
+ (1, 1, 3, 3, 3L, 3d),
+ (1, 1, 2, 2, 2L, 4d),
+ (1, 2, 1, 1, 1L, 1.1d),
+ (1, 2, 1, 2, 1L, 2.3d),
+ (1, 3, 1, 1, 1L, 3.3d),
+ (1, 4, 1, 1, 1L, 1.1d),
+ (2, 1, 2, 2, 2L, 2.2d),
+ (2, 2, 3, 3, 3L, 3.3d)
+ ),
+ """
+ | SELECT f0, f1, sum(f2), avg(f2), max(f3), min(f3), count(f3), count(*), sum(f4), sum(f5), avg(f4), avg(f5)
+ | FROM TableName GROUP BY f0, f1
+ |""".stripMargin,
+ Seq(
+ (1, 1, 11, 1, 3, 1, 6, 6, 11, 12.5, 1, 2.0833333333333335),
+ (1, 2, 2, 1, 2, 1, 2, 2, 2, 3.4, 1, 1.7),
+ (1, 3, 1, 1, 1, 1, 1, 1, 1, 3.3, 1, 3.3),
+ (1, 4, 1, 1, 1, 1, 1, 1, 1, 1.1, 1, 1.1),
+ (2, 1, 2, 2, 2, 2, 1, 1, 2, 2.2, 2, 2.2),
+ (2, 2, 3, 3, 3, 3, 1, 1, 3, 3.3, 3, 3.3)
+ )
+ )
+ }
+
+ /**
+ * Groups 12 input rows by (f0, f1) into 9 groups, so most groups hold a single row, and checks
+ * the same sum/avg/max/min/count query as the high-aggregation-degree test.
+ */
+ @Test
+ def testAdaptiveLocalHashAggWithLowAggregationDegree(): Unit = {
+ checkQuery(
+ Seq(
+ (1, 1, 1, 1, 1L, 1.1d),
+ (1, 1, 1, 2, 1L, 1.2d),
+ (1, 2, 2, 3, 2L, 2.2d),
+ (1, 3, 2, 2, 2L, 1d),
+ (1, 4, 3, 3, 3L, 3d),
+ (1, 5, 2, 2, 2L, 4d),
+ (1, 6, 1, 1, 1L, 3.3d),
+ (2, 1, 1, 2, 1L, 2.3d),
+ (2, 2, 1, 1, 1L, 3.3d),
+ (2, 3, 1, 1, 1L, 3.3d),
+ (2, 3, 2, 2, 2L, 2.2d),
+ (2, 3, 3, 3, 3L, 3.3d)
+ ),
+ """
+ | SELECT f0, f1, sum(f2), avg(f2), max(f3), min(f3), count(f3), count(*), sum(f4), sum(f5), avg(f4), avg(f5)
+ | FROM TableName GROUP BY f0, f1
+ |""".stripMargin,
+ Seq(
+ (1, 1, 2, 1, 2, 1, 2, 2, 2, 2.3, 1, 1.15),
+ (1, 2, 2, 2, 3, 3, 1, 1, 2, 2.2, 2, 2.2),
+ (1, 3, 2, 2, 2, 2, 1, 1, 2, 1.0, 2, 1.0),
+ (1, 4, 3, 3, 3, 3, 1, 1, 3, 3.0, 3, 3.0),
+ (1, 5, 2, 2, 2, 2, 1, 1, 2, 4.0, 2, 4.0),
+ (1, 6, 1, 1, 1, 1, 1, 1, 1, 3.3, 1, 3.3),
+ (2, 1, 1, 1, 2, 2, 1, 1, 1, 2.3, 1, 2.3),
+ (2, 2, 1, 1, 1, 1, 1, 1, 1, 3.3, 1, 3.3),
+ (2, 3, 6, 2, 3, 1, 3, 3, 6, 8.8, 2, 2.9333333333333336)
+ )
+ )
+ }
+
+ /**
+ * Runs the aggregate query over only 3 input rows, which is fewer than the sampling threshold
+ * of 5 that prepareAggOp configures when the adaptive option is enabled.
+ */
+ @Test
+ def testAdaptiveLocalHashAggWithRowLessThanSamplingThreshold(): Unit = {
+ checkQuery(
+ Seq((1, 1, 1, 1, 1L, 1.1d), (1, 1, 1, 2, 1L, 1.2d), (1, 2, 2, 3, 2L, 2.2d)),
+ """
+ | SELECT f0, f1, sum(f2), avg(f2), max(f3), min(f3), count(f3), count(*), sum(f4), sum(f5), avg(f4), avg(f5)
+ | FROM TableName GROUP BY f0, f1
+ |""".stripMargin,
+ Seq((1, 1, 2, 1, 2, 1, 2, 2, 2, 2.3, 1, 1.15), (1, 2, 2, 2, 3, 3, 1, 1, 2, 2.2, 2, 2.2))
+ )
+ }
+
+ /**
+ * Runs the aggregate query over 12 rows (6 groups) whose aggregated columns f2..f5 contain
+ * nulls. The expected rows cover groups where every aggregated input is null (sum/avg/max/min
+ * yield null and count(f3) yields 0) while count(*) still counts the row.
+ */
+ @Test
+ def testAdaptiveLocalHashAggWithNullValue(): Unit = {
+ val testDataWithNullValue = tEnv.fromValues(
+ DataTypes.ROW(
+ DataTypes.FIELD("f0", DataTypes.INT()),
+ DataTypes.FIELD("f1", DataTypes.INT()),
+ DataTypes.FIELD("f2", DataTypes.INT()),
+ DataTypes.FIELD("f3", DataTypes.INT()),
+ DataTypes.FIELD("f4", DataTypes.BIGINT()),
+ DataTypes.FIELD("f5", DataTypes.DOUBLE())
+ ),
+ row(1, 1, 1, 1, null, 1.1d),
+ row(1, 1, null, null, 1L, null),
+ row(1, 1, 2, 3, 2L, 2.2d),
+ row(1, 1, 2, 2, null, 1d),
+ row(1, 1, 2, 2, 2L, 4d),
+ row(1, 1, null, 3, 3L, null),
+ row(1, 2, 1, 1, 1L, 1.1d),
+ row(1, 2, null, 2, null, null),
+ row(1, 3, 1, 1, 1L, null),
+ row(1, 4, 1, 1, 1L, 1.1d),
+ row(2, 1, null, 2, 2L, 2.2d),
+ row(2, 2, 3, null, 3L, 3.3d)
+ )
+
+ checkResult(
+ s"""
+ | SELECT f0, f1, sum(f2), avg(f2), max(f3), min(f3), count(f3), count(*), sum(f4), sum(f5), avg(f4), avg(f5)
+ | FROM $testDataWithNullValue GROUP BY f0, f1
+ |""".stripMargin,
+ Seq(
+ row(1, 1, 7, 1, 3, 1, 5, 6, 8, 8.3, 2, 2.075),
+ row(1, 2, 1, 1, 2, 1, 2, 2, 1, 1.1, 1, 1.1),
+ row(1, 3, 1, 1, 1, 1, 1, 1, 1, null, 1, null),
+ row(1, 4, 1, 1, 1, 1, 1, 1, 1, 1.1, 1, 1.1),
+ row(2, 1, null, null, 2, 2, 1, 1, 2, 2.2, 2, 2.2),
+ row(2, 2, 3, 3, null, null, 0, 1, 3, 3.3, 3, 3.3)
+ )
+ )
+
+ }
+
+ /**
+ * Runs SUM and AVG over TINYINT, SMALLINT, BIGINT, FLOAT, DOUBLE, DECIMAL(5, 2),
+ * DECIMAL(14, 3) and DECIMAL(38, 18) columns. The seven input rows carry identical values in
+ * f1..f8 and differ only in the grouping key f0, which splits them into groups of 4, 2 and 1
+ * rows. The expected AVG results for the DECIMAL(5, 2) and DECIMAL(14, 3) columns are written
+ * with six fractional digits.
+ */
+ @Test
+ def testAdaptiveHashAggWithSumAndAvgFunctionForNumericalType(): Unit = {
+ val testDataWithAllTypes = tEnv.fromValues(
+ DataTypes.ROW(
+ DataTypes.FIELD("f0", DataTypes.INT()),
+ DataTypes.FIELD("f1", DataTypes.TINYINT()),
+ DataTypes.FIELD("f2", DataTypes.SMALLINT()),
+ DataTypes.FIELD("f3", DataTypes.BIGINT()),
+ DataTypes.FIELD("f4", DataTypes.FLOAT()),
+ DataTypes.FIELD("f5", DataTypes.DOUBLE()),
+ DataTypes.FIELD("f6", DataTypes.DECIMAL(5, 2)),
+ DataTypes.FIELD("f7", DataTypes.DECIMAL(14, 3)),
+ DataTypes.FIELD("f8", DataTypes.DECIMAL(38, 18))
+ ),
+ row(
+ 1,
+ 1,
+ 1,
+ 1000L,
+ -1.1f,
+ 1.1d,
+ new BigDecimal("111.11"),
+ new BigDecimal("11111111111.111"),
+ new BigDecimal("11111111111111111111.111111111111111111")
+ ),
+ row(
+ 1,
+ 1,
+ 1,
+ 1000L,
+ -1.1f,
+ 1.1d,
+ new BigDecimal("111.11"),
+ new BigDecimal("11111111111.111"),
+ new BigDecimal("11111111111111111111.111111111111111111")
+ ),
+ row(
+ 1,
+ 1,
+ 1,
+ 1000L,
+ -1.1f,
+ 1.1d,
+ new BigDecimal("111.11"),
+ new BigDecimal("11111111111.111"),
+ new BigDecimal("11111111111111111111.111111111111111111")
+ ),
+ row(
+ 1,
+ 1,
+ 1,
+ 1000L,
+ -1.1f,
+ 1.1d,
+ new BigDecimal("111.11"),
+ new BigDecimal("11111111111.111"),
+ new BigDecimal("11111111111111111111.111111111111111111")
+ ),
+ row(
+ 2,
+ 1,
+ 1,
+ 1000L,
+ -1.1f,
+ 1.1d,
+ new BigDecimal("111.11"),
+ new BigDecimal("11111111111.111"),
+ new BigDecimal("11111111111111111111.111111111111111111")
+ ),
+ row(
+ 2,
+ 1,
+ 1,
+ 1000L,
+ -1.1f,
+ 1.1d,
+ new BigDecimal("111.11"),
+ new BigDecimal("11111111111.111"),
+ new BigDecimal("11111111111111111111.111111111111111111")
+ ),
+ row(
+ 3,
+ 1,
+ 1,
+ 1000L,
+ -1.1f,
+ 1.1d,
+ new BigDecimal("111.11"),
+ new BigDecimal("11111111111.111"),
+ new BigDecimal("11111111111111111111.111111111111111111")
+ )
+ )
+
+ checkResult(
+ s"""
+ | SELECT f0, sum(f1), avg(f1), sum(f2), avg(f2), sum(f3), avg(f3),
+ | sum(f4), avg(f4), sum(f5), avg(f5), sum(f6), avg(f6),
+ | sum(f7), avg(f7), sum(f8), avg(f8)
+ | FROM $testDataWithAllTypes GROUP BY f0
+ |""".stripMargin,
+ Seq(
+ row(
+ 1,
+ 4.toByte,
+ 1.toByte,
+ 4.toShort,
+ 1.toShort,
+ 4000L,
+ 1000L,
+ -4.4f,
+ -1.1f,
+ 4.4d,
+ 1.1d,
+ new BigDecimal("444.44"),
+ new BigDecimal("111.110000"),
+ new BigDecimal("44444444444.444"),
+ new BigDecimal("11111111111.111000"),
+ new BigDecimal("44444444444444444444.444444444444444444"),
+ new BigDecimal("11111111111111111111.111111111111111111")
+ ),
+ row(
+ 2,
+ 2.toByte,
+ 1.toByte,
+ 2.toShort,
+ 1.toShort,
+ 2000L,
+ 1000L,
+ -2.2f,
+ -1.1f,
+ 2.2d,
+ 1.1d,
+ new BigDecimal("222.22"),
+ new BigDecimal("111.110000"),
+ new BigDecimal("22222222222.222"),
+ new BigDecimal("11111111111.111000"),
+ new BigDecimal("22222222222222222222.222222222222222222"),
+ new BigDecimal("11111111111111111111.111111111111111111")
+ ),
+ row(
+ 3,
+ 1.toByte,
+ 1.toByte,
+ 1.toShort,
+ 1.toShort,
+ 1000L,
+ 1000L,
+ -1.1f,
+ -1.1f,
+ 1.1d,
+ 1.1d,
+ new BigDecimal("111.11"),
+ new BigDecimal("111.110000"),
+ new BigDecimal("11111111111.111"),
+ new BigDecimal("11111111111.111000"),
+ new BigDecimal("11111111111111111111.111111111111111111"),
+ new BigDecimal("11111111111111111111.111111111111111111")
+ )
+ )
+ )
+ }
+}
+
+/** Companion object supplying the parameter values for the parameterized HashAggITCase. */
+object HashAggITCase {
+ // Each element becomes the `adaptiveLocalHashAggEnable` constructor argument of one run, so
+ // the suite runs once with true and once with false.
+ @Parameterized.Parameters(name = "adaptiveLocalHashAggEnable={0}")
+ def parameters(): java.util.Collection[Boolean] = {
+ java.util.Arrays.asList(true, false)
}
}