You are viewing a plain-text version of this content. The canonical link to it (originally a "here" hyperlink) was lost in the plain-text conversion.
Posted to commits@flink.apache.org by tw...@apache.org on 2020/07/24 10:49:29 UTC
[flink] branch master updated: [FLINK-15803][table] Use AggregateInfo as the single source of type description
This is an automated email from the ASF dual-hosted git repository.
twalthr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new a37131d [FLINK-15803][table] Use AggregateInfo as the single source of type description
a37131d is described below
commit a37131d576c2cc281a1aa5bfe186bdd72748ac06
Author: Timo Walther <tw...@apache.org>
AuthorDate: Wed Jul 22 13:12:19 2020 +0200
[FLINK-15803][table] Use AggregateInfo as the single source of type description
This refactors much of the code generation around aggregate functions. The goal is
better code maintainability and, in particular, a single source for generating all
types (arguments, accumulator, result).
This closes #12967.
---
.../codegen/agg/batch/AggCodeGenHelper.scala | 590 ++++++++++++---------
.../agg/batch/AggWithoutKeysCodeGenerator.scala | 31 +-
.../codegen/agg/batch/HashAggCodeGenHelper.scala | 216 ++++----
.../codegen/agg/batch/HashAggCodeGenerator.scala | 40 +-
.../agg/batch/HashWindowCodeGenerator.scala | 11 +-
.../codegen/agg/batch/SortAggCodeGenerator.scala | 34 +-
.../agg/batch/SortWindowCodeGenerator.scala | 14 +-
.../codegen/agg/batch/WindowCodeGenerator.scala | 116 ++--
.../types/LogicalTypeDataTypeConverter.java | 2 +
9 files changed, 568 insertions(+), 486 deletions(-)
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala
index e7fa871..72be697 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala
@@ -22,26 +22,24 @@ import org.apache.flink.runtime.util.SingleElementIterator
import org.apache.flink.streaming.api.operators.OneInputStreamOperator
import org.apache.flink.table.data.{GenericRowData, RowData}
import org.apache.flink.table.expressions.ApiExpressionUtils.localRef
-import org.apache.flink.table.expressions.{Expression, _}
-import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
+import org.apache.flink.table.expressions._
+import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.table.planner.codegen.CodeGenUtils._
+import org.apache.flink.table.planner.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
import org.apache.flink.table.planner.codegen.OperatorCodeGenerator.STREAM_RECORD
import org.apache.flink.table.planner.codegen._
import org.apache.flink.table.planner.expressions.DeclarativeExpressionResolver
import org.apache.flink.table.planner.expressions.DeclarativeExpressionResolver.toRexInputRef
import org.apache.flink.table.planner.expressions.converter.ExpressionConverter
import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction
-import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils.{getAccumulatorTypeOfAggregateFunction, getAggUserDefinedInputTypes}
+import org.apache.flink.table.planner.plan.utils.AggregateInfo
import org.apache.flink.table.runtime.context.ExecutionContextImpl
import org.apache.flink.table.runtime.generated.{GeneratedAggsHandleFunction, GeneratedOperator}
-import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.{fromDataTypeToLogicalType, fromLogicalTypeToDataType}
+import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter._
import org.apache.flink.table.runtime.typeutils.InternalSerializers
-import org.apache.flink.table.types.DataType
import org.apache.flink.table.types.logical.LogicalTypeRoot._
import org.apache.flink.table.types.logical.{DistinctType, LogicalType, RowType}
-import org.apache.calcite.rel.core.AggregateCall
-import org.apache.calcite.rex.RexNode
import org.apache.calcite.tools.RelBuilder
import scala.annotation.tailrec
@@ -52,34 +50,50 @@ import scala.annotation.tailrec
object AggCodeGenHelper {
def getAggBufferNames(
- auxGrouping: Array[Int], aggregates: Seq[UserDefinedFunction]): Array[Array[String]] = {
- auxGrouping.zipWithIndex.map {
- case (_, index) => Array(s"aux_group$index")
- } ++ aggregates.zipWithIndex.toArray.map {
- case (a: DeclarativeAggregateFunction, index) =>
- val idx = auxGrouping.length + index
- a.aggBufferAttributes.map(attr => s"agg${idx}_${attr.getName}")
- case (_: AggregateFunction[_, _], index) =>
- val idx = auxGrouping.length + index
- Array(s"agg$idx")
+ auxGrouping: Array[Int],
+ aggInfos: Seq[AggregateInfo])
+ : Array[Array[String]] = {
+ val auxGroupingNames = auxGrouping.indices
+ .map(index => Array(s"aux_group$index"))
+
+ val aggNames = aggInfos.map { aggInfo =>
+
+ val aggBufferIdx = auxGrouping.length + aggInfo.aggIndex
+
+ aggInfo.function match {
+
+ // create one buffer for each attribute in declarative functions
+ case function: DeclarativeAggregateFunction =>
+ function.aggBufferAttributes.map(attr => s"agg${aggBufferIdx}_${attr.getName}")
+
+ // create one buffer for imperative functions
+ case _: AggregateFunction[_, _] =>
+ Array(s"agg$aggBufferIdx")
+ }
}
+
+ (auxGroupingNames ++ aggNames).toArray
}
def getAggBufferTypes(
- inputType: RowType, auxGrouping: Array[Int], aggregates: Seq[UserDefinedFunction])
+ inputType: RowType,
+ auxGrouping: Array[Int],
+ aggInfos: Seq[AggregateInfo])
: Array[Array[LogicalType]] = {
- auxGrouping.map { index =>
- Array(inputType.getTypeAt(index))
- } ++ aggregates.map {
- case a: DeclarativeAggregateFunction => a.getAggBufferTypes.map(_.getLogicalType)
- case a: AggregateFunction[_, _] =>
- Array(fromDataTypeToLogicalType(getAccumulatorTypeOfAggregateFunction(a)))
- }.toArray[Array[LogicalType]]
+ val auxGroupingTypes = auxGrouping
+ .map { index =>
+ Array(inputType.getTypeAt(index))
+ }
+
+ val aggTypes = aggInfos
+ .map(_.externalAccTypes.map(fromDataTypeToLogicalType))
+
+ auxGroupingTypes ++ aggTypes
}
- def getUdaggs(
- aggregates: Seq[UserDefinedFunction]): Map[AggregateFunction[_, _], String] = {
- aggregates
+ def getFunctionIdentifiers(aggInfos: Seq[AggregateInfo]): Map[AggregateFunction[_, _], String] = {
+ aggInfos
+ .map(_.function)
.filter(a => a.isInstanceOf[AggregateFunction[_, _]])
.map(a => a -> CodeGenUtils.udfFieldName(a)).toMap
.asInstanceOf[Map[AggregateFunction[_, _], String]]
@@ -95,7 +109,8 @@ object AggCodeGenHelper {
private[flink] def addAggsHandler(
aggsHandler: GeneratedAggsHandleFunction,
ctx: CodeGeneratorContext,
- aggsHandlerCtx: CodeGeneratorContext): String = {
+ aggsHandlerCtx: CodeGeneratorContext)
+ : String = {
ctx.addReusableInnerClass(aggsHandler.getClassName, aggsHandler.getCode)
val handler = CodeGenUtils.newName("handler")
ctx.addReusableMember(s"${aggsHandler.getClassName} $handler = null;")
@@ -116,7 +131,8 @@ object AggCodeGenHelper {
*/
private[flink] def genGroupKeyChangedCheckCode(
currentKeyTerm: String,
- lastKeyTerm: String): String = {
+ lastKeyTerm: String)
+ : String = {
s"""
|$currentKeyTerm.getSizeInBytes() != $lastKeyTerm.getSizeInBytes() ||
| !(org.apache.flink.table.data.binary.BinaryRowDataUtil.byteArrayEquals(
@@ -133,27 +149,25 @@ object AggCodeGenHelper {
builder: RelBuilder,
grouping: Array[Int],
auxGrouping: Array[Int],
- aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
- aggArgs: Array[Array[Int]],
- aggregates: Seq[UserDefinedFunction],
- aggResultTypes: Seq[DataType],
- udaggs: Map[AggregateFunction[_, _], String],
+ aggInfos: Seq[AggregateInfo],
+ functionIdentifiers: Map[AggregateFunction[_, _], String],
inputTerm: String,
inputType: RowType,
aggBufferNames: Array[Array[String]],
aggBufferTypes: Array[Array[LogicalType]],
outputType: RowType,
- forHashAgg: Boolean = false): (String, String, GeneratedExpression) = {
+ forHashAgg: Boolean = false)
+ : (String, String, GeneratedExpression) = {
// gen code to apply aggregate functions to grouping elements
val argsMapping = buildAggregateArgsMapping(
- isMerge, grouping.length, inputType, auxGrouping, aggArgs, aggBufferTypes)
+ isMerge, grouping.length, inputType, auxGrouping, aggInfos, aggBufferTypes)
val aggBufferExprs = genFlatAggBufferExprs(
isMerge,
ctx,
builder,
auxGrouping,
- aggregates,
+ aggInfos,
argsMapping,
aggBufferNames,
aggBufferTypes)
@@ -165,8 +179,8 @@ object AggCodeGenHelper {
inputTerm,
grouping,
auxGrouping,
- aggregates,
- udaggs,
+ aggInfos,
+ functionIdentifiers,
aggBufferExprs,
forHashAgg)
@@ -177,9 +191,8 @@ object AggCodeGenHelper {
inputType,
inputTerm,
auxGrouping,
- aggCallToAggFunction,
- aggregates,
- udaggs,
+ aggInfos,
+ functionIdentifiers,
argsMapping,
aggBufferNames,
aggBufferTypes,
@@ -192,9 +205,8 @@ object AggCodeGenHelper {
builder,
grouping,
auxGrouping,
- aggregates,
- aggResultTypes,
- udaggs,
+ aggInfos,
+ functionIdentifiers,
argsMapping,
aggBufferNames,
aggBufferTypes,
@@ -218,8 +230,10 @@ object AggCodeGenHelper {
aggBufferOffset: Int,
inputType: RowType,
auxGrouping: Array[Int],
- aggArgs: Array[Array[Int]],
- aggBufferTypes: Array[Array[LogicalType]]): Array[Array[(Int, LogicalType)]] = {
+ aggInfos: Seq[AggregateInfo],
+ aggBufferTypes: Array[Array[LogicalType]])
+ : Array[Array[(Int, LogicalType)]] = {
+ val aggArgs = aggInfos.map(_.argIndexes).toArray
val auxGroupingMapping = auxGrouping.indices.map {
i => Array[(Int, LogicalType)]((i, aggBufferTypes(i)(0)))
}.toArray
@@ -281,29 +295,47 @@ object AggCodeGenHelper {
ctx: CodeGeneratorContext,
builder: RelBuilder,
auxGrouping: Array[Int],
- aggregates: Seq[UserDefinedFunction],
+ aggInfos: Seq[AggregateInfo],
argsMapping: Array[Array[(Int, LogicalType)]],
aggBufferNames: Array[Array[String]],
aggBufferTypes: Array[Array[LogicalType]]): Seq[GeneratedExpression] = {
- val exprCodegen = new ExprCodeGenerator(ctx, false)
+ val exprCodeGen = new ExprCodeGenerator(ctx, false)
val converter = new ExpressionConverter(builder)
- val accessAuxGroupingExprs = auxGrouping.indices.map {
- idx => newLocalReference(aggBufferNames(idx)(0), aggBufferTypes(idx)(0))
- }.map(_.accept(converter)).map(exprCodegen.generateExpression)
-
- val aggCallExprs = aggregates.zipWithIndex.flatMap {
- case (agg: DeclarativeAggregateFunction, aggIndex: Int) =>
- val idx = auxGrouping.length + aggIndex
- agg.aggBufferAttributes.map(_.accept(
- ResolveReference(ctx, builder, isMerge, agg, idx, argsMapping, aggBufferTypes)))
- case (_: AggregateFunction[_, _], aggIndex: Int) =>
- val idx = auxGrouping.length + aggIndex
- val variableName = aggBufferNames(idx)(0)
- Some(newLocalReference(variableName, aggBufferTypes(idx)(0)))
- }.map(_.accept(converter)).map(exprCodegen.generateExpression)
-
- accessAuxGroupingExprs ++ aggCallExprs
+ val accessAuxGroupingExprs = auxGrouping.indices
+ .map(idx => newLocalReference(aggBufferNames(idx).head, aggBufferTypes(idx).head))
+
+ val aggCallExprs = aggInfos.flatMap { aggInfo =>
+
+ val aggBufferIdx = auxGrouping.length + aggInfo.aggIndex
+
+ aggInfo.function match {
+
+ // create a buffer expression for each attribute in declarative functions
+ case function: DeclarativeAggregateFunction =>
+ val ref = ResolveReference(
+ ctx,
+ builder,
+ isMerge,
+ function,
+ aggBufferIdx,
+ argsMapping,
+ aggBufferTypes)
+ function.aggBufferAttributes().map(_.accept(ref))
+
+ // create one buffer for imperative functions
+ case _: AggregateFunction[_, _] =>
+ val aggBufferName = aggBufferNames(aggBufferIdx).head
+ val aggBufferType = aggBufferTypes(aggBufferIdx).head
+ Some(newLocalReference(aggBufferName, aggBufferType))
+ }
+ }
+
+ val aggBufferExprs = accessAuxGroupingExprs ++ aggCallExprs
+
+ aggBufferExprs
+ .map(_.accept(converter))
+ .map(exprCodeGen.generateExpression)
}
/**
@@ -316,12 +348,14 @@ object AggCodeGenHelper {
inputTerm: String,
grouping: Array[Int],
auxGrouping: Array[Int],
- aggregates: Seq[UserDefinedFunction],
- udaggs: Map[AggregateFunction[_, _], String],
+ aggInfos: Seq[AggregateInfo],
+ functionIdentifiers: Map[AggregateFunction[_, _], String],
aggBufferExprs: Seq[GeneratedExpression],
- forHashAgg: Boolean = false): String = {
- val exprCodegen = new ExprCodeGenerator(ctx, false)
+ forHashAgg: Boolean = false)
+ : String = {
+ val exprCodeGen = new ExprCodeGenerator(ctx, false)
.bindInput(inputType, inputTerm = inputTerm, inputFieldMapping = Some(auxGrouping))
+ val converter = new ExpressionConverter(builder)
val initAuxGroupingExprs = {
if (forHashAgg) {
@@ -335,25 +369,27 @@ object AggCodeGenHelper {
GenerateUtils.generateFieldAccess(ctx, inputType, inputTerm, idx)
}
- val initAggCallBufferExprs = aggregates.flatMap {
- case (agg: DeclarativeAggregateFunction) =>
- agg.initialValuesExpressions
- case (agg: AggregateFunction[_, _]) =>
- Some(agg)
- }.map {
- case (expr: Expression) => expr.accept(new ExpressionConverter(builder))
- case t@_ => t
- }.map {
- case (rex: RexNode) => exprCodegen.generateExpression(rex)
- case (agg: AggregateFunction[_, _]) =>
- val resultTerm = s"${udaggs(agg)}.createAccumulator()"
- val nullTerm = "false"
- val resultType = getAccumulatorTypeOfAggregateFunction(agg)
- GeneratedExpression(
- genToInternal(ctx, resultType, resultTerm),
- nullTerm,
- "",
- fromDataTypeToLogicalType(resultType))
+ val initAggCallBufferExprs = aggInfos.flatMap { aggInfo =>
+ aggInfo.function match {
+
+ // generate code for each agg buffer in declarative functions
+ case function: DeclarativeAggregateFunction =>
+ val expressions = function.initialValuesExpressions
+ val rexNodes = expressions.map(_.accept(converter))
+ rexNodes.map(exprCodeGen.generateExpression)
+
+ // call createAccumulator() in imperative functions
+ case function: AggregateFunction[_, _] =>
+ val accTerm = s"${functionIdentifiers(function)}.createAccumulator()"
+ val externalAccType = aggInfo.externalAccTypes.head
+ val internalAccType = externalAccType.getLogicalType
+ val genExpr = GeneratedExpression(
+ genToInternal(ctx, externalAccType)(accTerm),
+ NEVER_NULL,
+ NO_CODE,
+ internalAccType)
+ Seq(genExpr)
+ }
}
val initAggBufferExprs = initAuxGroupingExprs ++ initAggCallBufferExprs
@@ -394,13 +430,13 @@ object AggCodeGenHelper {
inputType: RowType,
inputTerm: String,
auxGrouping: Array[Int],
- aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
- aggregates: Seq[UserDefinedFunction],
- udaggs: Map[AggregateFunction[_, _], String],
+ aggInfos: Seq[AggregateInfo],
+ functionIdentifiers: Map[AggregateFunction[_, _], String],
argsMapping: Array[Array[(Int, LogicalType)]],
aggBufferNames: Array[Array[String]],
aggBufferTypes: Array[Array[LogicalType]],
- aggBufferExprs: Seq[GeneratedExpression]): String = {
+ aggBufferExprs: Seq[GeneratedExpression])
+ : String = {
if (isMerge) {
genMergeFlatAggregateBuffer(
ctx,
@@ -408,8 +444,8 @@ object AggCodeGenHelper {
inputTerm,
inputType,
auxGrouping,
- aggregates,
- udaggs,
+ aggInfos,
+ functionIdentifiers,
argsMapping,
aggBufferNames,
aggBufferTypes,
@@ -421,8 +457,8 @@ object AggCodeGenHelper {
inputTerm,
inputType,
auxGrouping,
- aggCallToAggFunction,
- udaggs,
+ aggInfos,
+ functionIdentifiers,
argsMapping,
aggBufferNames,
aggBufferTypes,
@@ -437,35 +473,34 @@ object AggCodeGenHelper {
builder: RelBuilder,
grouping: Array[Int],
auxGrouping: Array[Int],
- aggregates: Seq[UserDefinedFunction],
- aggResultTypes: Seq[DataType],
- udaggs: Map[AggregateFunction[_, _], String],
+ aggInfos: Seq[AggregateInfo],
+ functionIdentifiers: Map[AggregateFunction[_, _], String],
argsMapping: Array[Array[(Int, LogicalType)]],
aggBufferNames: Array[Array[String]],
aggBufferTypes: Array[Array[LogicalType]],
aggBufferExprs: Seq[GeneratedExpression],
- outputType: RowType): GeneratedExpression = {
+ outputType: RowType)
+ : GeneratedExpression = {
val valueRow = CodeGenUtils.newName("valueRow")
- val resultCodegen = new ExprCodeGenerator(ctx, false)
+ val resultCodeGen = new ExprCodeGenerator(ctx, false)
if (isFinal) {
val getValueExprs = genGetValueFromFlatAggregateBuffer(
isMerge,
ctx,
builder,
auxGrouping,
- aggregates,
- aggResultTypes,
- udaggs,
+ aggInfos,
+ functionIdentifiers,
argsMapping,
aggBufferNames,
aggBufferTypes,
outputType)
val valueRowType = RowType.of(getValueExprs.map(_.resultType): _*)
- resultCodegen.generateResultExpression(
+ resultCodeGen.generateResultExpression(
getValueExprs, valueRowType, classOf[GenericRowData], valueRow)
} else {
val valueRowType = RowType.of(aggBufferExprs.map(_.resultType): _*)
- resultCodegen.generateResultExpression(
+ resultCodeGen.generateResultExpression(
aggBufferExprs, valueRowType, classOf[GenericRowData], valueRow)
}
}
@@ -478,44 +513,58 @@ object AggCodeGenHelper {
ctx: CodeGeneratorContext,
builder: RelBuilder,
auxGrouping: Array[Int],
- aggregates: Seq[UserDefinedFunction],
- aggResultTypes: Seq[DataType],
- udaggs: Map[AggregateFunction[_, _], String],
+ aggInfos: Seq[AggregateInfo],
+ functionIdentifiers: Map[AggregateFunction[_, _], String],
argsMapping: Array[Array[(Int, LogicalType)]],
aggBufferNames: Array[Array[String]],
aggBufferTypes: Array[Array[LogicalType]],
outputType: RowType): Seq[GeneratedExpression] = {
- val exprCodegen = new ExprCodeGenerator(ctx, false)
+ val exprCodeGen = new ExprCodeGenerator(ctx, false)
+ val converter = new ExpressionConverter(builder)
val auxGroupingExprs = auxGrouping.indices.map { idx =>
- val resultTerm = aggBufferNames(idx)(0)
- val nullTerm = s"${resultTerm}IsNull"
- GeneratedExpression(resultTerm, nullTerm, "", aggBufferTypes(idx)(0))
+ val aggBufferName = aggBufferNames(idx).head
+ val aggBufferType = aggBufferTypes(idx).head
+ val nullTerm = s"${aggBufferName}IsNull"
+ GeneratedExpression(aggBufferName, nullTerm, NO_CODE, aggBufferType)
}
- val aggExprs = aggregates.zipWithIndex.map {
- case (agg: DeclarativeAggregateFunction, aggIndex) =>
- val idx = auxGrouping.length + aggIndex
- agg.getValueExpression.accept(ResolveReference(
- ctx, builder, isMerge, agg, idx, argsMapping, aggBufferTypes))
- case (agg: AggregateFunction[_, _], aggIndex) =>
- val idx = auxGrouping.length + aggIndex
- (agg, idx)
- }.map {
- case (expr: Expression) => expr.accept(new ExpressionConverter(builder))
- case t@_ => t
- }.map {
- case (rex: RexNode) => exprCodegen.generateExpression(rex)
- case (agg: AggregateFunction[_, _], aggIndex: Int) =>
- val resultType = aggResultTypes(aggIndex - auxGrouping.length)
- val accType = getAccumulatorTypeOfAggregateFunction(agg)
- val resultTerm = genToInternal(ctx, resultType,
- s"${udaggs(agg)}.getValue(${genToExternal(ctx, accType, aggBufferNames(aggIndex)(0))})")
- val nullTerm = s"${aggBufferNames(aggIndex)(0)}IsNull"
- GeneratedExpression(resultTerm, nullTerm, "", fromDataTypeToLogicalType(resultType))
+ val getValueExprs = aggInfos.map { aggInfo =>
+
+ val aggBufferIdx = auxGrouping.length + aggInfo.aggIndex
+
+ aggInfo.function match {
+
+ // evaluate the value expression in declarative functions
+ case function: DeclarativeAggregateFunction =>
+ val ref = ResolveReference(
+ ctx,
+ builder,
+ isMerge,
+ function,
+ aggBufferIdx,
+ argsMapping,
+ aggBufferTypes)
+ val getValueRexNode = function.getValueExpression
+ .accept(ref)
+ .accept(converter)
+ exprCodeGen.generateExpression(getValueRexNode)
+
+ // call getValue() for imperative functions
+ case function: AggregateFunction[_, _] =>
+ val aggBufferName = aggBufferNames(aggBufferIdx).head
+ val externalAccType = aggInfo.externalAccTypes.head
+ val externalResultType = aggInfo.externalResultType
+ val resultType = externalResultType.getLogicalType
+ val getValueCode = s"${functionIdentifiers(function)}.getValue(" +
+ s"${genToExternal(ctx, externalAccType, aggBufferName)})"
+ val resultTerm = genToInternal(ctx, externalResultType)(getValueCode)
+ val nullTerm = s"${aggBufferName}IsNull"
+ GeneratedExpression(resultTerm, nullTerm, NO_CODE, resultType)
+ }
}
- auxGroupingExprs ++ aggExprs
+ auxGroupingExprs ++ getValueExprs
}
/**
@@ -527,55 +576,81 @@ object AggCodeGenHelper {
inputTerm: String,
inputType: RowType,
auxGrouping: Array[Int],
- aggregates: Seq[UserDefinedFunction],
- udaggs: Map[AggregateFunction[_, _], String],
+ aggInfos: Seq[AggregateInfo],
+ functionIdentifiers: Map[AggregateFunction[_, _], String],
argsMapping: Array[Array[(Int, LogicalType)]],
aggBufferNames: Array[Array[String]],
aggBufferTypes: Array[Array[LogicalType]],
- aggBufferExprs: Seq[GeneratedExpression]): String = {
- val exprCodegen = new ExprCodeGenerator(ctx, false).bindInput(inputType, inputTerm = inputTerm)
-
- // flat map to get flat agg buffers.
- aggregates.zipWithIndex.flatMap {
- case (agg: DeclarativeAggregateFunction, aggIndex) =>
- val idx = auxGrouping.length + aggIndex
- agg.mergeExpressions.map(_.accept(ResolveReference(
- ctx, builder, isMerge = true, agg, idx, argsMapping, aggBufferTypes)))
- case (agg: AggregateFunction[_, _], aggIndex) =>
- val idx = auxGrouping.length + aggIndex
- Some(agg, idx)
- }.zip(aggBufferExprs.slice(auxGrouping.length, aggBufferExprs.size)).map {
- // DeclarativeAggregateFunction
- case ((expr: Expression), aggBufVar) =>
- val mergeExpr = exprCodegen.generateExpression(
- expr.accept(new ExpressionConverter(builder)))
- s"""
- |${mergeExpr.code}
- |${aggBufVar.nullTerm} = ${mergeExpr.nullTerm};
- |if (!${mergeExpr.nullTerm}) {
- | ${mergeExpr.copyResultTermToTargetIfChanged(ctx, aggBufVar.resultTerm)}
- |}
- """.stripMargin.trim
- // UserDefinedAggregateFunction
- case ((agg: AggregateFunction[_, _], aggIndex: Int), aggBufVar) =>
- val (inputIndex, inputType) = argsMapping(aggIndex)(0)
- val inputRef = toRexInputRef(builder, inputIndex, inputType)
- val inputExpr = exprCodegen.generateExpression(
- inputRef.accept(new ExpressionConverter(builder)))
- val singleIterableClass = classOf[SingleElementIterator[_]].getCanonicalName
-
- val externalAccT = getAccumulatorTypeOfAggregateFunction(agg)
- val javaField = typeTerm(externalAccT.getConversionClass)
- val tmpAcc = newName("tmpAcc")
- s"""
- |final $singleIterableClass accIt$aggIndex = new $singleIterableClass();
- |accIt$aggIndex.set(${genToExternal(ctx, externalAccT, inputExpr.resultTerm)});
- |$javaField $tmpAcc = ${genToExternal(ctx, externalAccT, aggBufferNames(aggIndex)(0))};
- |${udaggs(agg)}.merge($tmpAcc, accIt$aggIndex);
- |${aggBufferNames(aggIndex)(0)} = ${genToInternal(ctx, externalAccT, tmpAcc)};
- |${aggBufVar.nullTerm} = ${aggBufferNames(aggIndex)(0)}IsNull || ${inputExpr.nullTerm};
- """.stripMargin
- } mkString "\n"
+ aggBufferExprs: Seq[GeneratedExpression])
+ : String = {
+ val exprCodeGen = new ExprCodeGenerator(ctx, false)
+ .bindInput(inputType, inputTerm = inputTerm)
+ val converter = new ExpressionConverter(builder)
+
+ var currentAggBufferExprIdx = auxGrouping.length
+
+ val mergeCode = aggInfos.map { aggInfo =>
+
+ val aggBufferIdx = auxGrouping.length + aggInfo.aggIndex
+
+ aggInfo.function match {
+
+ // merge each agg buffer for declarative functions
+ case function: DeclarativeAggregateFunction =>
+ val ref = ResolveReference(
+ ctx,
+ builder,
+ isMerge = true,
+ function,
+ aggBufferIdx,
+ argsMapping,
+ aggBufferTypes)
+ val mergeExprs = function.mergeExpressions
+ .map(_.accept(ref))
+ .map(_.accept(converter))
+ .map(exprCodeGen.generateExpression)
+ mergeExprs
+ .map { mergeExpr =>
+ val aggBufferExpr = aggBufferExprs(currentAggBufferExprIdx)
+ currentAggBufferExprIdx += 1
+ s"""
+ |${mergeExpr.code}
+ |${aggBufferExpr.nullTerm} = ${mergeExpr.nullTerm};
+ |if (!${mergeExpr.nullTerm}) {
+ | ${mergeExpr.copyResultTermToTargetIfChanged(ctx, aggBufferExpr.resultTerm)}
+ |}
+ """.stripMargin
+ }
+ .mkString("\n")
+
+ // call merge() for imperative functions
+ case function: AggregateFunction[_, _] =>
+ val (inputIndex, inputType) = argsMapping(aggBufferIdx).head
+ val inputRef = toRexInputRef(builder, inputIndex, inputType)
+ val inputExpr = exprCodeGen.generateExpression(
+ inputRef.accept(converter))
+ val aggBufferName = aggBufferNames(aggBufferIdx).head
+ val aggBufferExpr = aggBufferExprs(currentAggBufferExprIdx)
+ currentAggBufferExprIdx += 1
+ val iterableTypeTerm = className[SingleElementIterator[_]]
+ val externalAccType = aggInfo.externalAccTypes.head
+ val externalAccTypeTerm = typeTerm(externalAccType.getConversionClass)
+ val externalAccTerm = newName("acc")
+ val aggIndex = aggInfo.aggIndex
+ s"""
+ |$iterableTypeTerm accIt$aggIndex = new $iterableTypeTerm();
+ |accIt$aggIndex.set(${
+ genToExternal(ctx, externalAccType, inputExpr.resultTerm)});
+ |$externalAccTypeTerm $externalAccTerm = ${
+ genToExternal(ctx, externalAccType, aggBufferName)};
+ |${functionIdentifiers(function)}.merge($externalAccTerm, accIt$aggIndex);
+ |$aggBufferName = ${genToInternal(ctx, externalAccType)(externalAccTerm)};
+ |${aggBufferExpr.nullTerm} = ${aggBufferName}IsNull || ${inputExpr.nullTerm};
+ """.stripMargin
+ }
+ }
+
+ mergeCode.mkString("\n")
}
/**
@@ -587,81 +662,96 @@ object AggCodeGenHelper {
inputTerm: String,
inputType: RowType,
auxGrouping: Array[Int],
- aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
- udaggs: Map[AggregateFunction[_, _], String],
+ aggInfos: Seq[AggregateInfo],
+ functionIdentifiers: Map[AggregateFunction[_, _], String],
argsMapping: Array[Array[(Int, LogicalType)]],
aggBufferNames: Array[Array[String]],
aggBufferTypes: Array[Array[LogicalType]],
- aggBufferExprs: Seq[GeneratedExpression]): String = {
- val exprCodegen = new ExprCodeGenerator(ctx, false).bindInput(inputType, inputTerm = inputTerm)
-
- // flat map to get flat agg buffers.
- aggCallToAggFunction.zipWithIndex.flatMap {
- case (aggCallToAggFun, aggIndex) =>
- val idx = auxGrouping.length + aggIndex
- val aggCall = aggCallToAggFun._1
- aggCallToAggFun._2 match {
- case agg: DeclarativeAggregateFunction =>
- agg.accumulateExpressions.map(_.accept(ResolveReference(
- ctx, builder, isMerge = false, agg, idx, argsMapping, aggBufferTypes)))
- .map(e => (e, aggCall))
- case agg: AggregateFunction[_, _] =>
- val idx = auxGrouping.length + aggIndex
- Some(agg, idx, aggCall)
- }
- }.zip(aggBufferExprs.slice(auxGrouping.length, aggBufferExprs.size)).map {
- // DeclarativeAggregateFunction
- case ((expr: Expression, aggCall: AggregateCall), aggBufVar) =>
- val accExpr = exprCodegen.generateExpression(expr.accept(new ExpressionConverter(builder)))
- (s"""
- |${accExpr.code}
- |${aggBufVar.nullTerm} = ${accExpr.nullTerm};
- |if (!${accExpr.nullTerm}) {
- | ${accExpr.copyResultTermToTargetIfChanged(ctx, aggBufVar.resultTerm)}
- |}
- """.stripMargin, aggCall.filterArg)
- // UserDefinedAggregateFunction
- case ((agg: AggregateFunction[_, _], aggIndex: Int, aggCall: AggregateCall),
- aggBufVar) =>
- val inFields = argsMapping(aggIndex)
- val externalAccType = getAccumulatorTypeOfAggregateFunction(agg)
-
- val inputExprs = inFields.map {
- f =>
- val inputRef = toRexInputRef(builder, f._1, f._2)
- exprCodegen.generateExpression(inputRef.accept(new ExpressionConverter(builder)))
- }
-
- val externalUDITypes = getAggUserDefinedInputTypes(
- agg, externalAccType, inputExprs.map(_.resultType))
- val parameters = inputExprs.zipWithIndex.map {
- case (expr, i) =>
- genToExternalIfNeeded(ctx, externalUDITypes(i), expr)
- }
+ aggBufferExprs: Seq[GeneratedExpression])
+ : String = {
+ val exprCodeGen = new ExprCodeGenerator(ctx, false)
+ .bindInput(inputType, inputTerm = inputTerm)
+ val converter = new ExpressionConverter(builder)
- val javaTerm = typeTerm(externalAccType.getConversionClass)
- val tmpAcc = newName("tmpAcc")
- val innerCode =
+ var currentAggBufferExprIdx = auxGrouping.length
+
+ val filteredAccCode = aggInfos.map { aggInfo =>
+
+ val aggCall = aggInfo.agg
+
+ val aggBufferIdx = auxGrouping.length + aggInfo.aggIndex
+
+ val accCode = aggInfo.function match {
+
+ // update each agg buffer for declarative functions
+ case function: DeclarativeAggregateFunction =>
+ val ref = ResolveReference(
+ ctx,
+ builder,
+ isMerge = false,
+ function,
+ aggBufferIdx,
+ argsMapping,
+ aggBufferTypes)
+ val accExprs = function.accumulateExpressions
+ .map(_.accept(ref))
+ .map(_.accept(new ExpressionConverter(builder)))
+ .map(exprCodeGen.generateExpression)
+ accExprs
+ .map { accExpr =>
+ val aggBufferExpr = aggBufferExprs(currentAggBufferExprIdx)
+ currentAggBufferExprIdx += 1
+ s"""
+ |${accExpr.code}
+ |${aggBufferExpr.nullTerm} = ${accExpr.nullTerm};
+ |if (!${accExpr.nullTerm}) {
+ | // copy result term
+ | ${accExpr.copyResultTermToTargetIfChanged(ctx, aggBufferExpr.resultTerm)}
+ |}
+ """.stripMargin
+ }
+ .mkString("\n")
+
+ // call accumulate() for imperative functions
+ case function: AggregateFunction[_, _] =>
+ val args = argsMapping(aggBufferIdx)
+ val inputExprs = args.map { case (argIndex, argType) =>
+ val inputRef = toRexInputRef(builder, argIndex, argType)
+ exprCodeGen.generateExpression(inputRef.accept(converter))
+ }
+ val operandTerms = inputExprs.zipWithIndex.map { case (expr, i) =>
+ genToExternalIfNeeded(ctx, aggInfo.externalArgTypes(i), expr)
+ }
+ val aggBufferName = aggBufferNames(aggBufferIdx).head
+ val aggBufferExpr = aggBufferExprs(currentAggBufferExprIdx)
+ currentAggBufferExprIdx += 1
+ val externalAccType = aggInfo.externalAccTypes.head
+ val externalAccTypeTerm = typeTerm(externalAccType.getConversionClass)
+ val externalAccTerm = newName("acc")
+ val externalAccCode = genToExternal(ctx, externalAccType, aggBufferName)
s"""
- | $javaTerm $tmpAcc = ${
- genToExternal(ctx, externalAccType, aggBufferNames(aggIndex)(0))};
- | ${udaggs(agg)}.accumulate($tmpAcc, ${parameters.mkString(", ")});
- | ${aggBufferNames(aggIndex)(0)} = ${genToInternal(ctx, externalAccType, tmpAcc)};
- | ${aggBufVar.nullTerm} = false;
- """.stripMargin
- (innerCode, aggCall.filterArg)
- }.map({
- case (innerCode, filterArg) =>
- if (filterArg >= 0) {
- s"""
- |if ($inputTerm.getBoolean($filterArg)) {
- | $innerCode
- |}
+ |$externalAccTypeTerm $externalAccTerm = $externalAccCode;
+ |${functionIdentifiers(function)}.accumulate(
+ | $externalAccTerm,
+ | ${operandTerms.mkString(", ")});
+ |$aggBufferName = ${genToInternal(ctx, externalAccType)(externalAccTerm)};
+ |${aggBufferExpr.nullTerm} = false;
""".stripMargin
- } else {
- innerCode
- }
- }) mkString "\n"
+ }
+
+ // apply filter if present
+ if (aggInfo.agg.filterArg >= 0) {
+ s"""
+ |if ($inputTerm.getBoolean(${aggCall.filterArg})) {
+ | $accCode
+ |}
+ """.stripMargin
+ } else {
+ accCode
+ }
+ }
+
+ filteredAccCode.mkString("\n")
}
/**
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggWithoutKeysCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggWithoutKeysCodeGenerator.scala
index 43725b9..5a99b5a 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggWithoutKeysCodeGenerator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggWithoutKeysCodeGenerator.scala
@@ -44,19 +44,21 @@ object AggWithoutKeysCodeGenerator {
outputType: RowType,
isMerge: Boolean,
isFinal: Boolean,
- prefix: String): GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
- val aggCallToAggFunction = aggInfoList.aggInfos.map(info => (info.agg, info.function))
- val aggregates = aggCallToAggFunction.map(_._2)
- val udaggs = AggCodeGenHelper.getUdaggs(aggregates)
- val aggBufferNames = AggCodeGenHelper.getAggBufferNames(Array(), aggregates)
- val aggBufferTypes = AggCodeGenHelper.getAggBufferTypes(inputType, Array(), aggregates)
- val aggArgs = aggInfoList.aggInfos.map(_.argIndexes)
+ prefix: String)
+ : GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
- val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM
+ // prepare for aggregation
+ val auxGrouping = Array[Int]()
+ val aggInfos = aggInfoList.aggInfos
+ aggInfos
+ .map(_.function)
+ .filter(_.isInstanceOf[AggregateFunction[_, _]])
+ .map(ctx.addReusableFunction(_))
+ val functionIdentifiers = AggCodeGenHelper.getFunctionIdentifiers(aggInfos)
+ val aggBufferNames = AggCodeGenHelper.getAggBufferNames(auxGrouping, aggInfos)
+ val aggBufferTypes = AggCodeGenHelper.getAggBufferTypes(inputType, auxGrouping, aggInfos)
- // register udagg
- aggregates.filter(a => a.isInstanceOf[AggregateFunction[_, _]])
- .map(a => ctx.addReusableFunction(a))
+ val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM
val (initAggBufferCode, doAggregateCode, aggOutputExpr) = genSortAggCodes(
isMerge,
@@ -65,11 +67,8 @@ object AggWithoutKeysCodeGenerator {
builder,
Array(),
Array(),
- aggCallToAggFunction,
- aggArgs,
- aggregates,
- aggInfoList.aggInfos.map(_.externalResultType),
- udaggs,
+ aggInfos,
+ functionIdentifiers,
inputTerm,
inputType,
aggBufferNames,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenHelper.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenHelper.scala
index 5ba6227..e8efbee 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenHelper.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenHelper.scala
@@ -18,12 +18,13 @@
package org.apache.flink.table.planner.codegen.agg.batch
+import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
import org.apache.flink.metrics.Gauge
import org.apache.flink.table.data.binary.BinaryRowData
import org.apache.flink.table.data.{GenericRowData, JoinedRowData, RowData}
-import org.apache.flink.table.expressions.{Expression, _}
-import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
+import org.apache.flink.table.expressions._
+import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.table.planner.codegen.CodeGenUtils.{binaryRowFieldSetAccess, binaryRowSetNull}
import org.apache.flink.table.planner.codegen._
import org.apache.flink.table.planner.codegen.agg.batch.AggCodeGenHelper.buildAggregateArgsMapping
@@ -32,17 +33,13 @@ import org.apache.flink.table.planner.expressions.DeclarativeExpressionResolver
import org.apache.flink.table.planner.expressions.DeclarativeExpressionResolver.toRexInputRef
import org.apache.flink.table.planner.expressions.converter.ExpressionConverter
import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction
-import org.apache.flink.table.planner.plan.utils.SortUtil
+import org.apache.flink.table.planner.plan.utils.{AggregateInfo, SortUtil}
import org.apache.flink.table.runtime.generated.{NormalizedKeyComputer, RecordComparator}
import org.apache.flink.table.runtime.operators.aggregate.{BytesHashMap, BytesHashMapSpillMemorySegmentPool}
import org.apache.flink.table.runtime.operators.sort.BufferedKVExternalSorter
import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer
-import org.apache.flink.table.types.DataType
import org.apache.flink.table.types.logical.{LogicalType, RowType}
-import org.apache.calcite.rel.core.AggregateCall
-import org.apache.calcite.tools.RelBuilder
-
import scala.collection.JavaConversions._
object HashAggCodeGenHelper {
@@ -118,24 +115,23 @@ object HashAggCodeGenHelper {
groupingAndAuxGrouping: (Array[Int], Array[Int]),
inputTerm: String,
inputType: RowType,
- aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
- aggArgs: Array[Array[Int]],
- aggregates: Seq[UserDefinedFunction],
+ aggInfos: Seq[AggregateInfo],
currentAggBufferTerm: String,
aggBufferRowType: RowType,
aggBufferTypes: Array[Array[LogicalType]],
outputTerm: String,
outputType: RowType,
groupKeyTerm: String,
- aggBufferTerm: String): (GeneratedExpression, GeneratedExpression, GeneratedExpression) = {
+ aggBufferTerm: String)
+ : (GeneratedExpression, GeneratedExpression, GeneratedExpression) = {
val (grouping, auxGrouping) = groupingAndAuxGrouping
// build mapping for DeclarativeAggregationFunction binding references
val argsMapping = buildAggregateArgsMapping(
- isMerge, grouping.length, inputType, auxGrouping, aggArgs, aggBufferTypes)
+ isMerge, grouping.length, inputType, auxGrouping, aggInfos, aggBufferTypes)
val aggBuffMapping = buildAggregateAggBuffMapping(aggBufferTypes)
// gen code to create empty agg buffer
val initedAggBuffer = genReusableEmptyAggBuffer(
- ctx, builder, inputTerm, inputType, auxGrouping, aggregates, aggBufferRowType)
+ ctx, builder, inputTerm, inputType, auxGrouping, aggInfos, aggBufferRowType)
if (auxGrouping.isEmpty) {
// create an empty agg buffer and initialized make it reusable
ctx.addReusableOpenStatement(initedAggBuffer.code)
@@ -148,8 +144,7 @@ object HashAggCodeGenHelper {
inputType,
inputTerm,
auxGrouping,
- aggregates,
- aggCallToAggFunction,
+ aggInfos,
argsMapping,
aggBuffMapping,
currentAggBufferTerm,
@@ -161,7 +156,7 @@ object HashAggCodeGenHelper {
ctx,
builder,
auxGrouping,
- aggregates,
+ aggInfos,
argsMapping,
aggBuffMapping,
outputTerm,
@@ -196,26 +191,29 @@ object HashAggCodeGenHelper {
inputTerm: String,
inputType: RowType,
auxGrouping: Array[Int],
- aggregates: Seq[UserDefinedFunction],
- aggBufferType: RowType): GeneratedExpression = {
- val exprCodegen = new ExprCodeGenerator(ctx, false)
+ aggInfos: Seq[AggregateInfo],
+ aggBufferType: RowType)
+ : GeneratedExpression = {
+ val exprCodeGen = new ExprCodeGenerator(ctx, false)
.bindInput(inputType, inputTerm = inputTerm, inputFieldMapping = Some(auxGrouping))
+ val converter = new ExpressionConverter(builder)
val initAuxGroupingExprs = auxGrouping.map { idx =>
GenerateUtils.generateFieldAccess(ctx, inputType, inputTerm, idx)
}
- val initAggCallBufferExprs = aggregates.flatMap(a =>
- a.asInstanceOf[DeclarativeAggregateFunction].initialValuesExpressions)
- .map(_.accept(new ExpressionConverter(builder)))
- .map(exprCodegen.generateExpression)
+ val initAggCallBufferExprs = aggInfos
+ .map(_.function.asInstanceOf[DeclarativeAggregateFunction])
+ .flatMap(_.initialValuesExpressions)
+ .map(_.accept(converter))
+ .map(exprCodeGen.generateExpression)
val initAggBufferExprs = initAuxGroupingExprs ++ initAggCallBufferExprs
// empty agg buffer and writer will be reused
val emptyAggBufferTerm = CodeGenUtils.newName("emptyAggBuffer")
val emptyAggBufferWriterTerm = CodeGenUtils.newName("emptyAggBufferWriterTerm")
- exprCodegen.generateResultExpression(
+ exprCodeGen.generateResultExpression(
initAggBufferExprs,
aggBufferType,
classOf[BinaryRowData],
@@ -231,8 +229,7 @@ object HashAggCodeGenHelper {
inputType: RowType,
inputTerm: String,
auxGrouping: Array[Int],
- aggregates: Seq[UserDefinedFunction],
- aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
+ aggInfos: Seq[AggregateInfo],
argsMapping: Array[Array[(Int, LogicalType)]],
aggBuffMapping: Array[Array[(Int, LogicalType)]],
currentAggBufferTerm: String,
@@ -245,7 +242,7 @@ object HashAggCodeGenHelper {
inputType,
currentAggBufferTerm,
auxGrouping,
- aggregates,
+ aggInfos,
argsMapping,
aggBuffMapping,
aggBufferRowType)
@@ -257,7 +254,7 @@ object HashAggCodeGenHelper {
inputType,
currentAggBufferTerm,
auxGrouping,
- aggCallToAggFunction,
+ aggInfos,
argsMapping,
aggBuffMapping,
aggBufferRowType)
@@ -270,7 +267,7 @@ object HashAggCodeGenHelper {
ctx: CodeGeneratorContext,
builder: RelBuilder,
auxGrouping: Array[Int],
- aggregates: Seq[UserDefinedFunction],
+ aggInfos: Seq[AggregateInfo],
argsMapping: Array[Array[(Int, LogicalType)]],
aggBuffMapping: Array[Array[(Int, LogicalType)]],
outputTerm: String,
@@ -279,29 +276,46 @@ object HashAggCodeGenHelper {
inputType: RowType,
groupKeyTerm: Option[String],
aggBufferTerm: String,
- aggBufferType: RowType): GeneratedExpression = {
- // gen code to get agg result
- val exprCodegen = new ExprCodeGenerator(ctx, false)
+ aggBufferType: RowType)
+ : GeneratedExpression = {
+ val exprCodeGen = new ExprCodeGenerator(ctx, false)
.bindInput(inputType, inputTerm = inputTerm)
.bindSecondInput(aggBufferType, inputTerm = aggBufferTerm)
+ val converter = new ExpressionConverter(builder)
+
val resultExpr = if (isFinal) {
+
val bindRefOffset = inputType.getFieldCount
- val getAuxGroupingExprs = auxGrouping.indices.map { idx =>
- val (_, resultType) = aggBuffMapping(idx)(0)
- toRexInputRef(builder, bindRefOffset + idx, resultType)
- }.map(_.accept(new ExpressionConverter(builder))).map(exprCodegen.generateExpression)
- val getAggValueExprs = aggregates.zipWithIndex.map {
- case (agg: DeclarativeAggregateFunction, aggIndex) =>
- val idx = auxGrouping.length + aggIndex
- agg.getValueExpression.accept(ResolveReference(
- ctx, builder, isMerge, bindRefOffset, agg, idx, argsMapping, aggBuffMapping))
- }.map(_.accept(new ExpressionConverter(builder))).map(exprCodegen.generateExpression)
+ val getAuxGroupingExprs = auxGrouping.indices
+ .map { idx =>
+ val (_, resultType) = aggBuffMapping(idx)(0)
+ toRexInputRef(builder, bindRefOffset + idx, resultType)
+ }
+
+ val getAggValueExprs = aggInfos.map { aggInfo =>
+ val aggBufferIdx = auxGrouping.length + aggInfo.aggIndex
+ val function = aggInfo.function.asInstanceOf[DeclarativeAggregateFunction]
+ val ref = ResolveReference(
+ ctx,
+ builder,
+ isMerge,
+ bindRefOffset,
+ function,
+ aggBufferIdx,
+ argsMapping,
+ aggBuffMapping)
+ function.getValueExpression
+ .accept(ref)
+ }
+
+ val getValueExprs = (getAuxGroupingExprs ++ getAggValueExprs)
+ .map(_.accept(converter))
+ .map(exprCodeGen.generateExpression)
- val getValueExprs = getAuxGroupingExprs ++ getAggValueExprs
val aggValueTerm = CodeGenUtils.newName("aggVal")
val valueType = RowType.of(getValueExprs.map(_.resultType): _*)
- exprCodegen.generateResultExpression(
+ exprCodeGen.generateResultExpression(
getValueExprs,
valueType,
classOf[GenericRowData],
@@ -365,22 +379,36 @@ object HashAggCodeGenHelper {
inputType: RowType,
currentAggBufferTerm: String,
auxGrouping: Array[Int],
- aggregates: Seq[UserDefinedFunction],
+ aggInfos: Seq[AggregateInfo],
argsMapping: Array[Array[(Int, LogicalType)]],
aggBuffMapping: Array[Array[(Int, LogicalType)]],
- aggBufferType: RowType): GeneratedExpression = {
- val exprCodegen = new ExprCodeGenerator(ctx, false)
+ aggBufferType: RowType)
+ : GeneratedExpression = {
+ val exprCodeGen = new ExprCodeGenerator(ctx, false)
.bindInput(inputType, inputTerm = inputTerm)
.bindSecondInput(aggBufferType, inputTerm = currentAggBufferTerm)
+ val converter = new ExpressionConverter(builder)
- val mergeExprs = aggregates.zipWithIndex.flatMap {
- case (agg: DeclarativeAggregateFunction, aggIndex) =>
- val idx = auxGrouping.length + aggIndex
- val bindRefOffset = inputType.getFieldCount
- agg.mergeExpressions.map(
- _.accept(ResolveReference(
- ctx, builder, isMerge = true, bindRefOffset, agg, idx, argsMapping, aggBuffMapping)))
- }.map(_.accept(new ExpressionConverter(builder))).map(exprCodegen.generateExpression)
+ val mergeExprs = aggInfos
+ .map(_.function)
+ .zipWithIndex
+ .flatMap {
+ case (agg: DeclarativeAggregateFunction, aggIndex) =>
+ val aggBufferIdx = auxGrouping.length + aggIndex
+ val bindRefOffset = inputType.getFieldCount
+ val ref = ResolveReference(
+ ctx,
+ builder,
+ isMerge = true,
+ bindRefOffset,
+ agg,
+ aggBufferIdx,
+ argsMapping,
+ aggBuffMapping)
+ agg.mergeExpressions.map(_.accept(ref))
+ }
+ .map(_.accept(converter))
+ .map(exprCodeGen.generateExpression)
val aggBufferTypeWithoutAuxGrouping = if (auxGrouping.nonEmpty) {
// auxGrouping does not need merge-code
@@ -398,7 +426,7 @@ object HashAggCodeGenHelper {
}.toMap
// update agg buff in-place
- exprCodegen.generateResultExpression(
+ exprCodeGen.generateResultExpression(
mergeExprs,
mergeExprIdxToOutputRowPosMap,
aggBufferTypeWithoutAuxGrouping,
@@ -423,30 +451,37 @@ object HashAggCodeGenHelper {
inputType: RowType,
currentAggBufferTerm: String,
auxGrouping: Array[Int],
- aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
+ aggInfos: Seq[AggregateInfo],
argsMapping: Array[Array[(Int, LogicalType)]],
aggBuffMapping: Array[Array[(Int, LogicalType)]],
- aggBufferType: RowType): GeneratedExpression = {
- val exprCodegen = new ExprCodeGenerator(ctx, false)
+ aggBufferType: RowType)
+ : GeneratedExpression = {
+ val exprCodeGen = new ExprCodeGenerator(ctx, false)
.bindInput(inputType, inputTerm = inputTerm)
.bindSecondInput(aggBufferType, inputTerm = currentAggBufferTerm)
-
- val accumulateExprsWithFilterArgs = aggCallToAggFunction.zipWithIndex.flatMap {
- case (aggCallToAggFun, aggIndex) =>
- val idx = auxGrouping.length + aggIndex
- val bindRefOffset = inputType.getFieldCount
- val aggCall = aggCallToAggFun._1
- aggCallToAggFun._2 match {
- case agg: DeclarativeAggregateFunction =>
- agg.accumulateExpressions.map(_.accept(ResolveReference(
- ctx, builder, isMerge = false, bindRefOffset, agg, idx, argsMapping, aggBuffMapping))
- ).map(e => (e, aggCall))
- }
- }.map {
- case (expr: Expression, aggCall: AggregateCall) =>
- (exprCodegen.generateExpression(expr.accept(new ExpressionConverter(builder))),
- aggCall.filterArg)
- }
+ val converter = new ExpressionConverter(builder)
+
+ val bindRefOffset = inputType.getFieldCount
+
+ val accumulateExprsWithFilterArgs = aggInfos
+ .flatMap { aggInfo =>
+ val aggBufferIdx = auxGrouping.length + aggInfo.aggIndex
+ val function = aggInfo.function.asInstanceOf[DeclarativeAggregateFunction]
+ val ref = ResolveReference(
+ ctx,
+ builder,
+ isMerge = false,
+ bindRefOffset,
+ function,
+ aggBufferIdx,
+ argsMapping,
+ aggBuffMapping)
+ function.accumulateExpressions
+ .map(_.accept(ref))
+ .map { e =>
+ (exprCodeGen.generateExpression(e.accept(converter)), aggInfo.agg.filterArg)
+ }
+ }
// update agg buff in-place
val code = accumulateExprsWithFilterArgs.zipWithIndex.map({
@@ -537,10 +572,8 @@ object HashAggCodeGenHelper {
ctx: CodeGeneratorContext,
builder: RelBuilder,
groupingAndAuxGrouping: (Array[Int], Array[Int]),
- aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
- aggArgs: Array[Array[Int]],
- aggResultTypes: Seq[DataType],
- udaggs: Map[AggregateFunction[_, _], String],
+ aggInfos: Seq[AggregateInfo],
+ functionIdentifiers: Map[AggregateFunction[_, _], String],
logTerm: String,
aggregateMapTerm: String,
aggMapKVTypesTerm: (String, String),
@@ -570,11 +603,8 @@ object HashAggCodeGenHelper {
builder,
grouping,
auxGrouping,
- aggCallToAggFunction,
- aggArgs,
- aggCallToAggFunction.map(_._2),
- aggResultTypes,
- udaggs,
+ aggInfos,
+ functionIdentifiers,
aggregateMapTerm,
(groupKeyRowType, aggBufferRowType),
aggregateMapTerm,
@@ -697,11 +727,8 @@ object HashAggCodeGenHelper {
builder: RelBuilder,
grouping: Array[Int],
auxGrouping: Array[Int],
- aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
- aggArgs: Array[Array[Int]],
- aggregates: Seq[UserDefinedFunction],
- aggResultTypes: Seq[DataType],
- udaggs: Map[AggregateFunction[_, _], String],
+ aggInfos: Seq[AggregateInfo],
+ functionIdentifiers: Map[AggregateFunction[_, _], String],
mapTerm: String,
mapKVRowTypes: (RowType, RowType),
aggregateMapTerm: String,
@@ -728,11 +755,8 @@ object HashAggCodeGenHelper {
builder,
grouping,
auxGrouping,
- aggCallToAggFunction,
- aggArgs,
- aggregates,
- aggResultTypes,
- udaggs,
+ aggInfos,
+ functionIdentifiers,
fallbackInputTerm,
fallbackInputType,
aggBufferNames,
@@ -796,7 +820,7 @@ object HashAggCodeGenHelper {
aggMapKeyType: RowType) : String = {
val keyFieldTypes = aggMapKeyType.getChildren.toArray(Array[LogicalType]())
val keys = keyFieldTypes.indices.toArray
- val orders = keys.map((_) => true)
+ val orders = keys.map(_ => true)
val nullsIsLast = SortUtil.getNullDefaultOrders(orders)
val sortCodeGenerator = new SortCodeGenerator(
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
index 5e74793..66552cd 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
@@ -21,15 +21,14 @@ package org.apache.flink.table.planner.codegen.agg.batch
import org.apache.flink.streaming.api.operators.OneInputStreamOperator
import org.apache.flink.table.data.binary.BinaryRowData
import org.apache.flink.table.data.{GenericRowData, JoinedRowData, RowData}
-import org.apache.flink.table.functions.UserDefinedFunction
+import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
import org.apache.flink.table.planner.codegen.{CodeGenUtils, CodeGeneratorContext, ProjectionCodeGenerator}
import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction
-import org.apache.flink.table.planner.plan.utils.AggregateInfoList
+import org.apache.flink.table.planner.plan.utils.{AggregateInfo, AggregateInfoList}
import org.apache.flink.table.runtime.generated.GeneratedOperator
import org.apache.flink.table.runtime.operators.TableStreamOperator
import org.apache.flink.table.runtime.operators.aggregate.{BytesHashMap, BytesHashMapSpillMemorySegmentPool}
-import org.apache.flink.table.types.logical.RowType
-
+import org.apache.flink.table.types.logical.{LogicalType, RowType}
import org.apache.calcite.tools.RelBuilder
/**
@@ -48,17 +47,20 @@ class HashAggCodeGenerator(
isMerge: Boolean,
isFinal: Boolean) {
+ private lazy val aggInfos: Array[AggregateInfo] = aggInfoList.aggInfos
+
+ private lazy val functionIdentifiers: Map[AggregateFunction[_, _], String] =
+ AggCodeGenHelper.getFunctionIdentifiers(aggInfos)
+
+ private lazy val aggBufferNames: Array[Array[String]] =
+ AggCodeGenHelper.getAggBufferNames(auxGrouping, aggInfos)
+
+ private lazy val aggBufferTypes: Array[Array[LogicalType]] = AggCodeGenHelper.getAggBufferTypes(
+ inputType,
+ auxGrouping,
+ aggInfos)
+
private lazy val groupKeyRowType = AggCodeGenHelper.projectRowType(inputType, grouping)
- private lazy val aggCallToAggFunction =
- aggInfoList.aggInfos.map(info => (info.agg, info.function))
- private lazy val aggregates: Seq[UserDefinedFunction] = aggInfoList.aggInfos.map(_.function)
- private lazy val aggArgs: Array[Array[Int]] = aggInfoList.aggInfos.map(_.argIndexes)
- // get udagg instance names
- private lazy val udaggs = AggCodeGenHelper.getUdaggs(aggregates)
- // currently put auxGrouping to aggBuffer in code-gen
- private lazy val aggBufferNames = AggCodeGenHelper.getAggBufferNames(auxGrouping, aggregates)
- private lazy val aggBufferTypes =
- AggCodeGenHelper.getAggBufferTypes(inputType, auxGrouping, aggregates)
private lazy val aggBufferRowType = RowType.of(aggBufferTypes.flatten, aggBufferNames.flatten)
def genWithKeys(): GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
@@ -119,9 +121,7 @@ class HashAggCodeGenerator(
(grouping, auxGrouping),
inputTerm,
inputType,
- aggCallToAggFunction,
- aggArgs,
- aggregates,
+ aggInfos,
currentAggBufferTerm,
aggBufferRowType,
aggBufferTypes,
@@ -143,10 +143,8 @@ class HashAggCodeGenerator(
ctx,
builder,
(grouping, auxGrouping),
- aggCallToAggFunction,
- aggArgs,
- aggInfoList.aggInfos.map(_.externalResultType),
- udaggs,
+ aggInfos,
+ functionIdentifiers,
logTerm,
aggregateMapTerm,
(groupKeyTypesTerm, aggBufferTypesTerm),
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashWindowCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashWindowCodeGenerator.scala
index 53cc295..cebfc73 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashWindowCodeGenerator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashWindowCodeGenerator.scala
@@ -369,11 +369,11 @@ class HashWindowCodeGenerator(
// build mapping for DeclarativeAggregationFunction binding references
val offset = if (isMerge) grouping.length + 1 else grouping.length
val argsMapping = AggCodeGenHelper.buildAggregateArgsMapping(
- isMerge, offset, inputType, auxGrouping, aggArgs, aggBufferTypes)
+ isMerge, offset, inputType, auxGrouping, aggInfos, aggBufferTypes)
val aggBuffMapping = HashAggCodeGenHelper.buildAggregateAggBuffMapping(aggBufferTypes)
// gen code to create empty agg buffer
val initedAggBuffer = HashAggCodeGenHelper.genReusableEmptyAggBuffer(
- ctx, builder, inputTerm, inputType, auxGrouping, aggregates, aggBufferRowType)
+ ctx, builder, inputTerm, inputType, auxGrouping, aggInfos, aggBufferRowType)
if (auxGrouping.isEmpty) {
// init aggBuffer in open function when there is no auxGrouping
ctx.addReusableOpenStatement(initedAggBuffer.code)
@@ -386,8 +386,7 @@ class HashWindowCodeGenerator(
inputType,
inputTerm,
auxGrouping,
- aggregates,
- aggCallToAggFunction,
+ aggInfos,
argsMapping,
aggBuffMapping,
currentAggBufferTerm,
@@ -650,7 +649,7 @@ class HashWindowCodeGenerator(
ctx,
builder,
auxGrouping,
- aggregates,
+ aggInfos,
argsMapping,
aggBuffMapping,
outputTerm,
@@ -697,7 +696,7 @@ class HashWindowCodeGenerator(
ctx,
builder,
auxGrouping,
- aggregates,
+ aggInfos,
argsMapping,
aggBuffMapping,
outputTerm,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala
index b413324..b0f6f99 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala
@@ -46,16 +46,20 @@ object SortAggCodeGenerator {
grouping: Array[Int],
auxGrouping: Array[Int],
isMerge: Boolean,
- isFinal: Boolean): GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
- val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM
-
- val aggCallToAggFunction = aggInfoList.aggInfos.map(info => (info.agg, info.function))
- val aggArgs = aggInfoList.aggInfos.map(_.argIndexes)
-
- // register udaggs
- aggCallToAggFunction.map(_._2).filter(a => a.isInstanceOf[AggregateFunction[_, _]])
- .map(a => ctx.addReusableFunction(a))
+ isFinal: Boolean)
+ : GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
+
+ // prepare for aggregation
+ val aggInfos = aggInfoList.aggInfos
+ aggInfos
+ .map(_.function)
+ .filter(_.isInstanceOf[AggregateFunction[_, _]])
+ .map(ctx.addReusableFunction(_))
+ val functionIdentifiers = AggCodeGenHelper.getFunctionIdentifiers(aggInfos)
+ val aggBufferNames = AggCodeGenHelper.getAggBufferNames(auxGrouping, aggInfos)
+ val aggBufferTypes = AggCodeGenHelper.getAggBufferTypes(inputType, auxGrouping, aggInfos)
+ val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM
val lastKeyTerm = "lastKey"
val currentKeyTerm = "currentKey"
val currentKeyWriterTerm = "currentKeyWriter"
@@ -72,11 +76,6 @@ object SortAggCodeGenerator {
val keyNotEquals = AggCodeGenHelper.genGroupKeyChangedCheckCode(currentKeyTerm, lastKeyTerm)
- val aggregates = aggCallToAggFunction.map(_._2)
- val udaggs = AggCodeGenHelper.getUdaggs(aggregates)
- val aggBufferNames = AggCodeGenHelper.getAggBufferNames(auxGrouping, aggregates)
- val aggBufferTypes = AggCodeGenHelper.getAggBufferTypes(inputType, auxGrouping, aggregates)
-
val (initAggBufferCode, doAggregateCode, aggOutputExpr) = AggCodeGenHelper.genSortAggCodes(
isMerge,
isFinal,
@@ -84,11 +83,8 @@ object SortAggCodeGenerator {
builder,
grouping,
auxGrouping,
- aggCallToAggFunction,
- aggArgs,
- aggregates,
- aggInfoList.aggInfos.map(_.externalResultType),
- udaggs,
+ aggInfos,
+ functionIdentifiers,
inputTerm,
inputType,
aggBufferNames,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortWindowCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortWindowCodeGenerator.scala
index f51a74a..3677b41 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortWindowCodeGenerator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortWindowCodeGenerator.scala
@@ -88,13 +88,15 @@ class SortWindowCodeGenerator(
isMerge,
isFinal) {
+ // prepare for aggregation
+ aggInfos
+ .map(_.function)
+ .filter(_.isInstanceOf[AggregateFunction[_, _]])
+ .map(ctx.addReusableFunction(_))
+
def genWithoutKeys(): GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM
- aggCallToAggFunction
- .map(_._2).filter(a => a.isInstanceOf[AggregateFunction[_, _]])
- .map(a => ctx.addReusableFunction(a))
-
val timeWindowType = classOf[TimeWindow].getName
val currentWindow = CodeGenUtils.newName("currentWindow")
ctx.addReusableMember(s"transient $timeWindowType $currentWindow = null;")
@@ -158,10 +160,6 @@ class SortWindowCodeGenerator(
}
def genWithKeys(): GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
- aggCallToAggFunction
- .map(_._2).filter(a => a.isInstanceOf[AggregateFunction[_, _]])
- .map(a => ctx.addReusableFunction(a))
-
val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM
val currentKey = CodeGenUtils.newName("currentKey")
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala
index 9b5cc68..9f9576b 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala
@@ -18,12 +18,17 @@
package org.apache.flink.table.planner.codegen.agg.batch
+import org.apache.calcite.avatica.util.DateTimeUtils
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.core.AggregateCall
+import org.apache.calcite.tools.RelBuilder
+import org.apache.commons.math3.util.ArithmeticUtils
import org.apache.flink.table.api.DataTypes
import org.apache.flink.table.data.binary.BinaryRowData
import org.apache.flink.table.data.{GenericRowData, JoinedRowData, RowData}
import org.apache.flink.table.expressions.ExpressionUtils.extractValue
import org.apache.flink.table.expressions.{Expression, ValueLiteralExpression}
-import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
+import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.table.planner.JLong
import org.apache.flink.table.planner.calcite.FlinkRelBuilder.PlannerNamedWindowProperty
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
@@ -37,24 +42,15 @@ import org.apache.flink.table.planner.codegen.agg.batch.WindowCodeGenerator.{asL
import org.apache.flink.table.planner.expressions.CallExpressionResolver
import org.apache.flink.table.planner.expressions.ExpressionBuilder._
import org.apache.flink.table.planner.expressions.converter.ExpressionConverter
-import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction
-import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils.getAccumulatorTypeOfAggregateFunction
import org.apache.flink.table.planner.plan.logical.{LogicalWindow, SlidingGroupWindow, TumblingGroupWindow}
-import org.apache.flink.table.planner.plan.utils.{AggregateInfoList, AggregateUtil}
+import org.apache.flink.table.planner.plan.utils.{AggregateInfo, AggregateInfoList, AggregateUtil}
import org.apache.flink.table.runtime.operators.window.TimeWindow
import org.apache.flink.table.runtime.operators.window.grouping.{HeapWindowsGrouping, WindowsGrouping}
-import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType
import org.apache.flink.table.runtime.util.RowIterator
import org.apache.flink.table.types.logical.LogicalTypeRoot.INTERVAL_DAY_TIME
import org.apache.flink.table.types.logical._
import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.hasRoot
-import org.apache.calcite.avatica.util.DateTimeUtils
-import org.apache.calcite.rel.`type`.RelDataType
-import org.apache.calcite.rel.core.AggregateCall
-import org.apache.calcite.tools.RelBuilder
-import org.apache.commons.math3.util.ArithmeticUtils
-
import scala.collection.JavaConversions._
abstract class WindowCodeGenerator(
@@ -71,49 +67,32 @@ abstract class WindowCodeGenerator(
val isMerge: Boolean,
val isFinal: Boolean) {
- lazy val builder: RelBuilder = relBuilder.values(inputRowType)
- lazy val timestampInternalType: LogicalType =
+ protected lazy val builder: RelBuilder = relBuilder.values(inputRowType)
+
+ protected lazy val aggInfos: Array[AggregateInfo] = aggInfoList.aggInfos
+
+ protected lazy val functionIdentifiers: Map[AggregateFunction[_, _], String] =
+ AggCodeGenHelper.getFunctionIdentifiers(aggInfos)
+
+ protected lazy val aggBufferNames: Array[Array[String]] =
+ AggCodeGenHelper.getAggBufferNames(auxGrouping, aggInfos)
+
+ protected lazy val aggBufferTypes: Array[Array[LogicalType]] = AggCodeGenHelper.getAggBufferTypes(
+ inputType,
+ auxGrouping,
+ aggInfos)
+
+ protected lazy val groupKeyRowType: RowType = AggCodeGenHelper.projectRowType(inputType, grouping)
+
+ private lazy val inputType: RowType =
+ FlinkTypeFactory.toLogicalType(inputRowType).asInstanceOf[RowType]
+
+ protected lazy val timestampInternalType: LogicalType =
if (inputTimeIsDate) new IntType() else new BigIntType()
- lazy val timestampInternalTypeName: String = if (inputTimeIsDate) "Int" else "Long"
- lazy val aggCallToAggFunction: Array[(AggregateCall, UserDefinedFunction)] =
- aggInfoList.aggInfos.map(info => (info.agg, info.function))
- lazy val aggregateCalls: Seq[AggregateCall] = aggCallToAggFunction.map(_._1)
- lazy val aggregates: Seq[UserDefinedFunction] = aggCallToAggFunction.map(_._2)
-
- lazy val aggArgs: Array[Array[Int]] = aggInfoList.aggInfos.map(_.argIndexes)
-
- // currently put auxGrouping to aggBuffer in code-gen
- lazy val aggBufferNames: Array[Array[String]] = auxGrouping.zipWithIndex.map {
- case (_, index) => Array(s"aux_group$index")
- } ++ aggregates.zipWithIndex.toArray.map {
- case (a: DeclarativeAggregateFunction, index) =>
- val idx = auxGrouping.length + index
- a.aggBufferAttributes.map(attr => s"agg${idx}_${attr.getName}")
- case (_: AggregateFunction[_, _], index) =>
- val idx = auxGrouping.length + index
- Array(s"agg$idx")
- }
- lazy val aggBufferTypes: Array[Array[LogicalType]] = auxGrouping.map { index =>
- Array(FlinkTypeFactory.toLogicalType(inputRowType.getFieldList.get(index).getType))
- } ++ aggregates.map {
- case a: DeclarativeAggregateFunction => a.getAggBufferTypes.map(_.getLogicalType)
- case a: AggregateFunction[_, _] =>
- Array(getAccumulatorTypeOfAggregateFunction(a)).map(fromDataTypeToLogicalType)
- }.toArray[Array[LogicalType]]
-
- lazy val groupKeyRowType: RowType = RowType.of(
- grouping.map { index =>
- FlinkTypeFactory.toLogicalType(inputRowType.getFieldList.get(index).getType)
- }, grouping.map(inputRowType.getFieldNames.get(_)))
-
- // get udagg instance names
- lazy val udaggs: Map[AggregateFunction[_, _], String] = aggregates
- .filter(a => a.isInstanceOf[AggregateFunction[_, _]])
- .map(a => a -> CodeGenUtils.udfFieldName(a)).toMap
- .asInstanceOf[Map[AggregateFunction[_, _], String]]
-
- lazy val windowedGroupKeyType: RowType = RowType.of(
+ protected lazy val timestampInternalTypeName: String = if (inputTimeIsDate) "Int" else "Long"
+
+ private lazy val windowedGroupKeyType: RowType = RowType.of(
(groupKeyRowType.getChildren :+ timestampInternalType).toArray,
(groupKeyRowType.getFieldNames :+ "assignedTs$").toArray)
@@ -224,13 +203,13 @@ abstract class WindowCodeGenerator(
// gen code to apply aggregate functions to grouping window elements
val offset = if (enablePreAcc) grouping.length + 1 else grouping.length
val argsMapping = buildAggregateArgsMapping(
- enablePreAcc, offset, inputType, auxGrouping, aggArgs, aggBufferTypes)
+ enablePreAcc, offset, inputType, auxGrouping, aggInfos, aggBufferTypes)
val aggBufferExprs = genFlatAggBufferExprs(
enablePreAcc,
ctx,
builder,
auxGrouping,
- aggregates,
+ aggInfos,
argsMapping,
aggBufferNames,
aggBufferTypes)
@@ -241,8 +220,8 @@ abstract class WindowCodeGenerator(
inputTerm,
grouping,
auxGrouping,
- aggregates,
- udaggs,
+ aggInfos,
+ functionIdentifiers,
aggBufferExprs)
val doAggregateCode = genAggregateByFlatAggregateBuffer(
enablePreAcc,
@@ -251,9 +230,8 @@ abstract class WindowCodeGenerator(
inputType,
inputTerm,
auxGrouping,
- aggCallToAggFunction,
- aggregates,
- udaggs,
+ aggInfos,
+ functionIdentifiers,
argsMapping,
aggBufferNames,
aggBufferTypes,
@@ -271,9 +249,8 @@ abstract class WindowCodeGenerator(
builder,
grouping,
auxGrouping,
- aggregates,
- aggInfoList.aggInfos.map(_.externalResultType),
- udaggs,
+ aggInfos,
+ functionIdentifiers,
argsMapping,
aggBufferNames,
aggBufferTypes,
@@ -442,13 +419,13 @@ abstract class WindowCodeGenerator(
// case: global/complete window agg: Sliding window with with pane optimization
val offset = if (isMerge) grouping.length + 1 else grouping.length
val argsMapping = buildAggregateArgsMapping(
- isMerge, offset, inputType, auxGrouping, aggArgs, aggBufferTypes)
+ isMerge, offset, inputType, auxGrouping, aggInfos, aggBufferTypes)
val aggBufferExprs = genFlatAggBufferExprs(
isMerge,
ctx,
builder,
auxGrouping,
- aggregates,
+ aggInfos,
argsMapping,
aggBufferNames,
aggBufferTypes)
@@ -459,8 +436,8 @@ abstract class WindowCodeGenerator(
inputTerm,
grouping,
auxGrouping,
- aggregates,
- udaggs,
+ aggInfos,
+ functionIdentifiers,
aggBufferExprs)
val doAggregateCode = genAggregateByFlatAggregateBuffer(
isMerge,
@@ -469,9 +446,8 @@ abstract class WindowCodeGenerator(
inputType,
inputTerm,
auxGrouping,
- aggCallToAggFunction,
- aggregates,
- udaggs,
+ aggInfos,
+ functionIdentifiers,
argsMapping,
aggBufferNames,
aggBufferTypes,
@@ -710,7 +686,7 @@ abstract class WindowCodeGenerator(
def getAuxGrouping: Array[Int] = auxGrouping
- def getAggCallList: Seq[AggregateCall] = aggCallToAggFunction.map(_._1)
+ def getAggCallList: Seq[AggregateCall] = aggInfos.map(_.agg)
def getInputTimeValue(inputTerm: String, index: Int): String = {
if(inputTimeIsDate) {
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/types/LogicalTypeDataTypeConverter.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/types/LogicalTypeDataTypeConverter.java
index fde61b8..47bc778 100644
--- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/types/LogicalTypeDataTypeConverter.java
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/types/LogicalTypeDataTypeConverter.java
@@ -34,6 +34,7 @@ import org.apache.flink.table.types.utils.TypeConversions;
@Deprecated
public class LogicalTypeDataTypeConverter {
+ @Deprecated
public static DataType fromLogicalTypeToDataType(LogicalType logicalType) {
return TypeConversions.fromLogicalToDataType(logicalType);
}
@@ -41,6 +42,7 @@ public class LogicalTypeDataTypeConverter {
/**
* It convert {@link LegacyTypeInformationType} to planner types.
*/
+ @Deprecated
public static LogicalType fromDataTypeToLogicalType(DataType dataType) {
return PlannerTypeUtils.removeLegacyTypes(dataType.getLogicalType());
}