You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ji...@apache.org on 2019/08/07 20:41:46 UTC
[flink] branch master updated: [FLINK-13471][table] Add stream FlatAggregate support for blink planner.
This is an automated email from the ASF dual-hosted git repository.
jincheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 411da97 [FLINK-13471][table] Add stream FlatAggregate support for blink planner.
411da97 is described below
commit 411da973c4671eb3f7998bc6ce52259338fe04f1
Author: sunjincheng121 <su...@gmail.com>
AuthorDate: Sat Aug 3 16:14:05 2019 +0200
[FLINK-13471][table] Add stream FlatAggregate support for blink planner.
This closes #9322
---
.../planner/expressions/SqlAggFunctionVisitor.java | 47 ++--
.../planner/plan/QueryOperationConverter.java | 109 +++++---
.../table/planner/calcite/FlinkRelBuilder.scala | 20 +-
.../calcite/RelTimeIndicatorConverter.scala | 9 +
.../flink/table/planner/codegen/CodeGenUtils.scala | 4 +-
.../table/planner/codegen/agg/AggCodeGen.scala | 3 +-
.../codegen/agg/AggsHandlerCodeGenerator.scala | 190 ++++++++++++-
.../codegen/agg/DeclarativeAggCodeGen.scala | 3 +-
.../planner/codegen/agg/DistinctAggCodeGen.scala | 7 +-
.../planner/codegen/agg/ImperativeAggCodeGen.scala | 23 +-
.../table/planner/dataview/DataViewUtils.scala | 6 +-
.../table/planner/expressions/aggregations.scala | 12 +-
.../planner/functions/utils/AggSqlFunction.scala | 17 +-
.../functions/utils/UserDefinedFunctionUtils.scala | 52 ++--
.../plan/metadata/FlinkRelMdColumnInterval.scala | 31 ++-
.../FlinkRelMdFilteredColumnInterval.scala | 19 +-
.../metadata/FlinkRelMdModifiedMonotonicity.scala | 36 ++-
.../plan/nodes/calcite/LogicalTableAggregate.scala | 65 +++++
.../plan/nodes/calcite/TableAggregate.scala | 108 ++++++++
.../nodes/logical/FlinkLogicalTableAggregate.scala | 82 ++++++
.../stream/StreamExecGroupTableAggregate.scala | 189 +++++++++++++
.../planner/plan/rules/FlinkStreamRuleSets.scala | 2 +
.../stream/StreamExecGroupTableAggregateRule.scala | 64 +++++
.../table/planner/plan/utils/AggregateUtil.scala | 11 +-
.../table/planner/plan/utils/RelExplainUtil.scala | 5 +
.../plan/stream/table/TableAggregateTest.xml | 137 ++++++++++
.../metadata/FlinkRelMdColumnIntervalTest.scala | 10 +
.../FlinkRelMdFilteredColumnIntervalTest.scala | 14 +-
.../plan/metadata/FlinkRelMdHandlerTestBase.scala | 73 ++++-
.../FlinkRelMdModifiedMonotonicityTest.scala | 32 ++-
.../plan/stream/table/TableAggregateTest.scala | 112 ++++++++
.../TableAggregateStringExpressionTest.scala | 120 +++++++++
.../validation/TableAggregateValidationTest.scala | 149 +++++++++++
.../planner/runtime/harness/HarnessTestBase.scala | 26 ++
.../runtime/harness/OverWindowHarnessTest.scala | 26 +-
.../harness/TableAggregateHarnessTest.scala | 164 ++++++++++++
.../stream/table/TableAggregateITCase.scala | 241 +++++++++++++++++
.../flink/table/planner/utils/TableTestBase.scala | 40 ++-
.../utils/UserDefinedTableAggFunctions.scala | 298 +++++++++++++++++++++
.../flink/table/expressions/aggregations.scala | 4 +-
.../table/functions/utils/AggSqlFunction.scala | 4 +-
.../table/plan/logical/rel/TableAggregate.scala | 8 +-
.../runtime/generated/AggsHandleFunction.java | 70 +----
...leFunction.java => AggsHandleFunctionBase.java} | 20 +-
.../GeneratedTableAggsHandleFunction.java | 31 +++
.../runtime/generated/TableAggsHandleFunction.java | 42 +++
.../operators/aggregate/GroupTableAggFunction.java | 171 ++++++++++++
47 files changed, 2674 insertions(+), 232 deletions(-)
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java
index 4b3b9a7..bb65381 100644
--- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java
@@ -26,11 +26,12 @@ import org.apache.flink.table.functions.AggregateFunction;
import org.apache.flink.table.functions.AggregateFunctionDefinition;
import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
import org.apache.flink.table.functions.FunctionDefinition;
-import org.apache.flink.table.functions.UserDefinedAggregateFunction;
+import org.apache.flink.table.functions.FunctionRequirement;
+import org.apache.flink.table.functions.TableAggregateFunction;
+import org.apache.flink.table.functions.TableAggregateFunctionDefinition;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable;
import org.apache.flink.table.planner.functions.utils.AggSqlFunction;
-import org.apache.flink.util.Preconditions;
import org.apache.calcite.sql.SqlAggFunction;
@@ -39,6 +40,7 @@ import java.util.Map;
import static org.apache.flink.table.expressions.utils.ApiExpressionUtils.isFunctionOfKind;
import static org.apache.flink.table.functions.FunctionKind.AGGREGATE;
+import static org.apache.flink.table.functions.FunctionKind.TABLE_AGGREGATE;
import static org.apache.flink.table.types.utils.TypeConversions.fromLegacyInfoToDataType;
/**
@@ -70,7 +72,10 @@ public class SqlAggFunctionVisitor extends ExpressionDefaultVisitor<SqlAggFuncti
@Override
public SqlAggFunction visit(CallExpression call) {
- Preconditions.checkArgument(isFunctionOfKind(call, AGGREGATE));
+ if (!isFunctionOfKind(call, AGGREGATE) && !isFunctionOfKind(call, TABLE_AGGREGATE)) {
+ defaultMethod(call);
+ }
+
FunctionDefinition def = call.getFunctionDefinition();
if (AGG_DEF_SQL_OPERATOR_MAPPING.containsKey(def)) {
return AGG_DEF_SQL_OPERATOR_MAPPING.get(def);
@@ -79,21 +84,31 @@ public class SqlAggFunctionVisitor extends ExpressionDefaultVisitor<SqlAggFuncti
Expression innerAgg = call.getChildren().get(0);
return innerAgg.accept(this);
}
- AggregateFunctionDefinition aggDef = (AggregateFunctionDefinition) def;
- UserDefinedAggregateFunction userDefinedAggregateFunc = aggDef.getAggregateFunction();
- if (userDefinedAggregateFunc instanceof AggregateFunction) {
- AggregateFunction aggFunc = (AggregateFunction) userDefinedAggregateFunc;
+
+ if (isFunctionOfKind(call, AGGREGATE)) {
+ AggregateFunctionDefinition aggDef = (AggregateFunctionDefinition) def;
+ AggregateFunction aggFunc = aggDef.getAggregateFunction();
return new AggSqlFunction(
- aggFunc.functionIdentifier(),
- aggFunc.toString(),
- aggFunc,
- fromLegacyInfoToDataType(aggDef.getResultTypeInfo()),
- fromLegacyInfoToDataType(aggDef.getAccumulatorTypeInfo()),
- typeFactory,
- aggFunc.requiresOver(),
- scala.Option.empty());
+ aggFunc.functionIdentifier(),
+ aggFunc.toString(),
+ aggFunc,
+ fromLegacyInfoToDataType(aggDef.getResultTypeInfo()),
+ fromLegacyInfoToDataType(aggDef.getAccumulatorTypeInfo()),
+ typeFactory,
+ aggFunc.getRequirements().contains(FunctionRequirement.OVER_WINDOW_ONLY),
+ scala.Option.empty());
} else {
- throw new UnsupportedOperationException("TableAggregateFunction is not supported yet!");
+ TableAggregateFunctionDefinition aggDef = (TableAggregateFunctionDefinition) def;
+ TableAggregateFunction aggFunc = aggDef.getTableAggregateFunction();
+ return new AggSqlFunction(
+ aggFunc.functionIdentifier(),
+ aggFunc.toString(),
+ aggFunc,
+ fromLegacyInfoToDataType(aggDef.getResultTypeInfo()),
+ fromLegacyInfoToDataType(aggDef.getAccumulatorTypeInfo()),
+ typeFactory,
+ false,
+ scala.Option.empty());
}
}
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java
index e0f4763..7089ab0 100644
--- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java
@@ -120,6 +120,7 @@ public class QueryOperationConverter extends QueryOperationDefaultVisitor<RelNod
private final LookupCallResolver callResolver;
private final RexNodeConverter rexNodeConverter;
private final AggregateVisitor aggregateVisitor = new AggregateVisitor();
+ private final TableAggregateVisitor tableAggregateVisitor = new TableAggregateVisitor();
private final JoinExpressionVisitor joinExpressionVisitor = new JoinExpressionVisitor();
public QueryOperationConverter(FlinkRelBuilder relBuilder, FunctionLookup functionCatalog) {
@@ -171,7 +172,7 @@ public class QueryOperationConverter extends QueryOperationDefaultVisitor<RelNod
.map(expr -> convertToWindowProperty(expr.accept(callResolver), windowReference))
.collect(toList());
GroupKey groupKey = relBuilder.groupKey(groupings);
- return relBuilder.aggregate(logicalWindow, groupKey, windowProperties, aggregations).build();
+ return relBuilder.windowAggregate(logicalWindow, groupKey, windowProperties, aggregations).build();
}
private FlinkRelBuilder.PlannerNamedWindowProperty convertToWindowProperty(Expression expression,
@@ -202,11 +203,11 @@ public class QueryOperationConverter extends QueryOperationDefaultVisitor<RelNod
}
/**
- * Get the {@link AggCall} correspond to the aggregate expression.
+ * Get the {@link AggCall} corresponding to the aggregate or table aggregate expression.
*/
private AggCall getAggCall(Expression aggregateExpression) {
if (isFunctionOfKind(aggregateExpression, TABLE_AGGREGATE)) {
- throw new UnsupportedOperationException("TableAggFunction is not supported yet!");
+ return aggregateExpression.accept(tableAggregateVisitor);
} else {
return aggregateExpression.accept(aggregateVisitor);
}
@@ -502,48 +503,92 @@ public class QueryOperationConverter extends QueryOperationDefaultVisitor<RelNod
protected AggCall defaultMethod(Expression expression) {
throw new TableException("Unexpected expression: " + expression);
}
- }
- private class AggCallVisitor extends ExpressionDefaultVisitor<RelBuilder.AggCall> {
-
- private final RelBuilder relBuilder;
- private final SqlAggFunctionVisitor sqlAggFunctionVisitor;
- private final RexNodeConverter rexNodeConverter;
- private final String name;
- private final boolean isDistinct;
-
- public AggCallVisitor(RelBuilder relBuilder, RexNodeConverter rexNodeConverter, String name,
- boolean isDistinct) {
- this.relBuilder = relBuilder;
- this.sqlAggFunctionVisitor = new SqlAggFunctionVisitor((FlinkTypeFactory) relBuilder.getTypeFactory());
- this.rexNodeConverter = rexNodeConverter;
- this.name = name;
- this.isDistinct = isDistinct;
- }
+ private class AggCallVisitor extends ExpressionDefaultVisitor<RelBuilder.AggCall> {
+
+ private final RelBuilder relBuilder;
+ private final SqlAggFunctionVisitor sqlAggFunctionVisitor;
+ private final RexNodeConverter rexNodeConverter;
+ private final String name;
+ private final boolean isDistinct;
+
+ public AggCallVisitor(RelBuilder relBuilder, RexNodeConverter rexNodeConverter, String name,
+ boolean isDistinct) {
+ this.relBuilder = relBuilder;
+ this.sqlAggFunctionVisitor = new SqlAggFunctionVisitor((FlinkTypeFactory) relBuilder.getTypeFactory());
+ this.rexNodeConverter = rexNodeConverter;
+ this.name = name;
+ this.isDistinct = isDistinct;
+ }
- @Override
- public RelBuilder.AggCall visit(CallExpression call) {
- FunctionDefinition def = call.getFunctionDefinition();
- if (BuiltInFunctionDefinitions.DISTINCT == def) {
- Expression innerAgg = call.getChildren().get(0);
- return innerAgg.accept(new AggCallVisitor(relBuilder, rexNodeConverter, name, true));
- } else {
- SqlAggFunction sqlAggFunction = call.accept(sqlAggFunctionVisitor);
- return relBuilder.aggregateCall(
+ @Override
+ public RelBuilder.AggCall visit(CallExpression call) {
+ FunctionDefinition def = call.getFunctionDefinition();
+ if (BuiltInFunctionDefinitions.DISTINCT == def) {
+ Expression innerAgg = call.getChildren().get(0);
+ return innerAgg.accept(new AggCallVisitor(relBuilder, rexNodeConverter, name, true));
+ } else {
+ SqlAggFunction sqlAggFunction = call.accept(sqlAggFunctionVisitor);
+ return relBuilder.aggregateCall(
sqlAggFunction,
isDistinct,
false,
null,
name,
call.getChildren().stream().map(expr -> expr.accept(rexNodeConverter))
- .collect(Collectors.toList()));
+ .collect(Collectors.toList()));
+ }
}
+ @Override
+ protected RelBuilder.AggCall defaultMethod(Expression expression) {
+ throw new TableException("Unexpected expression: " + expression);
+ }
}
+ }
+ private class TableAggregateVisitor extends ExpressionDefaultVisitor<RelBuilder.AggCall> {
@Override
- protected RelBuilder.AggCall defaultMethod(Expression expression) {
- throw new TableException("Unexpected expression: " + expression);
+ public AggCall visit(CallExpression call) {
+ if (isFunctionOfKind(call, TABLE_AGGREGATE)) {
+ return call.accept(new TableAggCallVisitor(relBuilder, rexNodeConverter));
+ }
+ return defaultMethod(call);
+ }
+
+ @Override
+ protected AggCall defaultMethod(Expression expression) {
+ throw new TableException("Expected table aggregate. Got: " + expression);
+ }
+
+ private class TableAggCallVisitor extends ExpressionDefaultVisitor<RelBuilder.AggCall> {
+
+ private final RelBuilder relBuilder;
+ private final SqlAggFunctionVisitor sqlAggFunctionVisitor;
+ private final RexNodeConverter rexNodeConverter;
+
+ public TableAggCallVisitor(RelBuilder relBuilder, RexNodeConverter rexNodeConverter) {
+ this.relBuilder = relBuilder;
+ this.sqlAggFunctionVisitor = new SqlAggFunctionVisitor((FlinkTypeFactory) relBuilder.getTypeFactory());
+ this.rexNodeConverter = rexNodeConverter;
+ }
+
+ @Override
+ public RelBuilder.AggCall visit(CallExpression call) {
+ SqlAggFunction sqlAggFunction = call.accept(sqlAggFunctionVisitor);
+ return relBuilder.aggregateCall(
+ sqlAggFunction,
+ false,
+ false,
+ null,
+ sqlAggFunction.toString(),
+ call.getChildren().stream().map(expr -> expr.accept(rexNodeConverter)).collect(toList()));
+ }
+
+ @Override
+ protected RelBuilder.AggCall defaultMethod(Expression expression) {
+ throw new TableException("Expected table aggregate. Got: " + expression);
+ }
}
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/FlinkRelBuilder.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/FlinkRelBuilder.scala
index 31f8079..28e6110 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/FlinkRelBuilder.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/FlinkRelBuilder.scala
@@ -26,7 +26,8 @@ import org.apache.flink.table.planner.calcite.FlinkRelFactories.{ExpandFactory,
import org.apache.flink.table.planner.expressions.{PlannerWindowProperty, WindowProperty}
import org.apache.flink.table.planner.plan.QueryOperationConverter
import org.apache.flink.table.planner.plan.logical.LogicalWindow
-import org.apache.flink.table.planner.plan.nodes.calcite.LogicalWindowAggregate
+import org.apache.flink.table.planner.plan.nodes.calcite.{LogicalTableAggregate, LogicalWindowAggregate}
+import org.apache.flink.table.planner.plan.utils.AggregateUtil
import org.apache.flink.table.runtime.operators.rank.{RankRange, RankType}
import org.apache.flink.table.sinks.TableSink
@@ -111,7 +112,22 @@ class FlinkRelBuilder(
push(rank)
}
- def aggregate(
+ /**
+ * Build non-window aggregate for either aggregate or table aggregate.
+ */
+ override def aggregate(groupKey: GroupKey, aggCalls: Iterable[AggCall]): RelBuilder = {
+ // build a relNode, the build() may also return a project
+ val relNode = super.aggregate(groupKey, aggCalls).build()
+
+ relNode match {
+ case logicalAggregate: LogicalAggregate
+ if AggregateUtil.isTableAggregate(logicalAggregate.getAggCallList) =>
+ push(LogicalTableAggregate.create(logicalAggregate))
+ case _ => push(relNode)
+ }
+ }
+
+ def windowAggregate(
window: LogicalWindow,
groupKey: GroupKey,
namedProperties: List[PlannerNamedWindowProperty],
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/RelTimeIndicatorConverter.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/RelTimeIndicatorConverter.scala
index b227d67..641643b 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/RelTimeIndicatorConverter.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/RelTimeIndicatorConverter.scala
@@ -133,6 +133,15 @@ class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttle {
aggregate.getNamedProperties,
convAggregate)
+ case tableAggregate: LogicalTableAggregate =>
+ val correspondingAggregate = LogicalAggregate.create(
+ tableAggregate.getInput,
+ tableAggregate.getGroupSet,
+ tableAggregate.getGroupSets,
+ tableAggregate.getAggCallList)
+ val convAggregate = convertAggregate(correspondingAggregate)
+ LogicalTableAggregate.create(convAggregate)
+
case watermarkAssigner: LogicalWatermarkAssigner =>
watermarkAssigner
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala
index 6f097f6..4ecfdbb 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala
@@ -25,7 +25,7 @@ import org.apache.flink.table.dataformat.util.BinaryRowUtil.BYTE_ARRAY_BASE_OFFS
import org.apache.flink.table.dataformat.{BinaryStringUtil, Decimal, _}
import org.apache.flink.table.functions.UserDefinedFunction
import org.apache.flink.table.runtime.dataview.StateDataViewStore
-import org.apache.flink.table.runtime.generated.{AggsHandleFunction, HashFunction, NamespaceAggsHandleFunction}
+import org.apache.flink.table.runtime.generated.{AggsHandleFunction, HashFunction, NamespaceAggsHandleFunction, TableAggsHandleFunction}
import org.apache.flink.table.runtime.types.ClassLogicalTypeConverter
import org.apache.flink.table.runtime.types.ClassLogicalTypeConverter.getInternalClassForType
import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType
@@ -91,6 +91,8 @@ object CodeGenUtils {
val AGGS_HANDLER_FUNCTION: String = className[AggsHandleFunction]
+ val TABLE_AGGS_HANDLER_FUNCTION: String = className[TableAggsHandleFunction]
+
val NAMESPACE_AGGS_HANDLER_FUNCTION: String = className[NamespaceAggsHandleFunction[_]]
val STATE_DATA_VIEW_STORE: String = className[StateDataViewStore]
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggCodeGen.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggCodeGen.scala
index f890f93..1d6d80d 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggCodeGen.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggCodeGen.scala
@@ -45,5 +45,6 @@ trait AggCodeGen {
needAccumulate: Boolean = false,
needRetract: Boolean = false,
needMerge: Boolean = false,
- needReset: Boolean = false): Unit
+ needReset: Boolean = false,
+ needEmitValue: Boolean = false): Unit
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala
index 1e73781..0b5f94f 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala
@@ -19,8 +19,8 @@ package org.apache.flink.table.planner.codegen.agg
import org.apache.flink.table.api.TableException
import org.apache.flink.table.dataformat.GenericRow
+import org.apache.flink.table.dataformat.util.BaseRowUtil
import org.apache.flink.table.expressions._
-import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.table.planner.codegen.CodeGenUtils.{BASE_ROW, _}
import org.apache.flink.table.planner.codegen.Indenter.toISC
import org.apache.flink.table.planner.codegen._
@@ -30,11 +30,14 @@ import org.apache.flink.table.planner.expressions.{PlannerProctimeAttribute, Pla
import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction
import org.apache.flink.table.planner.plan.utils.AggregateInfoList
import org.apache.flink.table.runtime.dataview.{StateListView, StateMapView}
-import org.apache.flink.table.runtime.generated.{AggsHandleFunction, GeneratedAggsHandleFunction, GeneratedNamespaceAggsHandleFunction, NamespaceAggsHandleFunction}
+import org.apache.flink.table.runtime.generated.{AggsHandleFunction, GeneratedAggsHandleFunction, GeneratedNamespaceAggsHandleFunction, GeneratedTableAggsHandleFunction, NamespaceAggsHandleFunction, TableAggsHandleFunction}
+import org.apache.flink.table.runtime.types.PlannerTypeUtils
import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType
import org.apache.flink.table.types.DataType
import org.apache.flink.table.types.logical.{BooleanType, IntType, LogicalType, RowType}
import org.apache.flink.table.types.utils.TypeConversions.fromLegacyInfoToDataType
+import org.apache.flink.table.functions.UserDefinedAggregateFunction
+import org.apache.flink.util.Collector
import org.apache.calcite.rex.RexLiteral
import org.apache.calcite.tools.RelBuilder
@@ -221,7 +224,7 @@ class AggsHandlerCodeGenerator(
inputFieldTypes,
constantExprs,
relBuilder)
- case _: AggregateFunction[_, _] =>
+ case _: UserDefinedAggregateFunction[_, _] =>
new ImperativeAggCodeGen(
ctx,
aggInfo,
@@ -398,6 +401,169 @@ class AggsHandlerCodeGenerator(
}
/**
+ * Generate [[GeneratedTableAggsHandleFunction]] with the given function name and aggregate
+ * infos.
+ */
+ def generateTableAggsHandler(
+ name: String,
+ aggInfoList: AggregateInfoList): GeneratedTableAggsHandleFunction = {
+
+ initialAggregateInformation(aggInfoList)
+
+ // generates all methods body first to add necessary reuse code to context
+ val createAccumulatorsCode = genCreateAccumulators()
+ val getAccumulatorsCode = genGetAccumulators()
+ val setAccumulatorsCode = genSetAccumulators()
+ val resetAccumulatorsCode = genResetAccumulators()
+ val accumulateCode = genAccumulate()
+ val retractCode = genRetract()
+ val mergeCode = genMerge()
+ val emitValueCode = genEmitValue()
+
+ // gen converter
+ val aggExternalType = aggInfoList.getActualAggregateInfos(0).externalResultType
+ val recordInputName = newName("recordInput")
+ val converterCode = CodeGenUtils.genToInternal(ctx, aggExternalType, recordInputName)
+
+ def genRecordToBaseRow: String = {
+ val resultType = fromDataTypeToLogicalType(aggExternalType)
+ val resultBaseRowType = PlannerTypeUtils.toRowType(resultType)
+
+ val newCtx = CodeGeneratorContext(ctx.tableConfig)
+ val exprGenerator = new ExprCodeGenerator(newCtx, false).bindInput(resultType)
+ val resultExpr = exprGenerator.generateConverterResultExpression(
+ resultBaseRowType, classOf[GenericRow], "convertResult")
+
+ val resultTypeClass = boxedTypeTermForType(resultType)
+ s"""
+ |${newCtx.reuseMemberCode()}
+ |$resultTypeClass ${exprGenerator.input1Term} = ($resultTypeClass) $converterCode;
+ |${newCtx.reuseLocalVariableCode()}
+ |${newCtx.reuseInputUnboxingCode()}
+ |${resultExpr.code}
+ |return ${resultExpr.resultTerm};
+ """.stripMargin
+ }
+
+ val functionName = newName(name)
+ val functionCode =
+ j"""
+ public final class $functionName implements ${className[TableAggsHandleFunction]} {
+
+ ${ctx.reuseMemberCode()}
+ private $CONVERT_COLLECTOR_TYPE_TERM $MEMBER_COLLECTOR_TERM;
+
+ public $functionName(java.lang.Object[] references) throws Exception {
+ ${ctx.reuseInitCode()}
+ $MEMBER_COLLECTOR_TERM = new $CONVERT_COLLECTOR_TYPE_TERM(references);
+ }
+
+ @Override
+ public void open($STATE_DATA_VIEW_STORE store) throws Exception {
+ ${ctx.reuseOpenCode()}
+ }
+
+ @Override
+ public void accumulate($BASE_ROW $ACCUMULATE_INPUT_TERM) throws Exception {
+ $accumulateCode
+ }
+
+ @Override
+ public void retract($BASE_ROW $RETRACT_INPUT_TERM) throws Exception {
+ $retractCode
+ }
+
+ @Override
+ public void merge($BASE_ROW $MERGED_ACC_TERM) throws Exception {
+ $mergeCode
+ }
+
+ @Override
+ public void setAccumulators($BASE_ROW $ACC_TERM) throws Exception {
+ $setAccumulatorsCode
+ }
+
+ @Override
+ public void resetAccumulators() throws Exception {
+ $resetAccumulatorsCode
+ }
+
+ @Override
+ public $BASE_ROW getAccumulators() throws Exception {
+ $getAccumulatorsCode
+ }
+
+ @Override
+ public $BASE_ROW createAccumulators() throws Exception {
+ $createAccumulatorsCode
+ }
+
+ @Override
+ public void emitValue(
+ $COLLECTOR<$BASE_ROW> $COLLECTOR_TERM, $BASE_ROW key, boolean isRetract)
+ throws Exception {
+
+ $MEMBER_COLLECTOR_TERM.reset(key, isRetract, $COLLECTOR_TERM);
+ $emitValueCode
+ }
+
+ @Override
+ public void cleanup() throws Exception {
+ ${ctx.reuseCleanupCode()}
+ }
+
+ @Override
+ public void close() throws Exception {
+ ${ctx.reuseCloseCode()}
+ }
+
+ private class $CONVERT_COLLECTOR_TYPE_TERM implements $COLLECTOR {
+ private $COLLECTOR<$BASE_ROW> $COLLECTOR_TERM;
+ private $BASE_ROW key;
+ private $JOINED_ROW result;
+ private boolean isRetract = false;
+ ${ctx.reuseMemberCode()}
+
+ public $CONVERT_COLLECTOR_TYPE_TERM(java.lang.Object[] references) throws Exception {
+ ${ctx.reuseInitCode()}
+ result = new $JOINED_ROW();
+ }
+
+ public void reset(
+ $BASE_ROW key, boolean isRetract, $COLLECTOR<$BASE_ROW> $COLLECTOR_TERM) {
+ this.key = key;
+ this.isRetract = isRetract;
+ this.$COLLECTOR_TERM = $COLLECTOR_TERM;
+ }
+
+ public $BASE_ROW convertToBaseRow(Object $recordInputName) throws Exception {
+ $genRecordToBaseRow
+ }
+
+ @Override
+ public void collect(Object $recordInputName) throws Exception {
+ $BASE_ROW tempBaseRow = convertToBaseRow($recordInputName);
+ result.replace(key, tempBaseRow);
+ if (isRetract) {
+ result.setHeader(${className[BaseRowUtil]}.RETRACT_MSG);
+ } else {
+ result.setHeader(${className[BaseRowUtil]}.ACCUMULATE_MSG);
+ }
+ $COLLECTOR_TERM.collect(result);
+ }
+
+ @Override
+ public void close() {
+ $COLLECTOR_TERM.close();
+ }
+ }
+ }
+ """.stripMargin
+
+ new GeneratedTableAggsHandleFunction(functionName, functionCode, ctx.references.toArray)
+ }
+
+ /**
* Generate [[GeneratedAggsHandleFunction]] with the given function name and aggregate infos
* and window properties.
*/
@@ -700,14 +866,21 @@ class AggsHandlerCodeGenerator(
""".stripMargin
}
+ private def genEmitValue(): String = {
+ // validation check
+ checkNeededMethods(needEmitValue = true)
+ aggBufferCodeGens(0).asInstanceOf[ImperativeAggCodeGen].emitValue
+ }
+
private def checkNeededMethods(
needAccumulate: Boolean = false,
needRetract: Boolean = false,
needMerge: Boolean = false,
- needReset: Boolean = false): Unit = {
+ needReset: Boolean = false,
+ needEmitValue: Boolean = false): Unit = {
// check and validate the needed methods
- aggBufferCodeGens
- .foreach(_.checkNeededMethods(needAccumulate, needRetract, needMerge, needReset))
+ aggBufferCodeGens.foreach(
+ _.checkNeededMethods(needAccumulate, needRetract, needMerge, needReset, needEmitValue))
}
private def genThrowException(msg: String): String = {
@@ -729,6 +902,11 @@ object AggsHandlerCodeGenerator {
val NAMESPACE_TERM = "namespace"
val STORE_TERM = "store"
+ val COLLECTOR: String = className[Collector[_]]
+ val COLLECTOR_TERM = "out"
+ val MEMBER_COLLECTOR_TERM = "convertCollector"
+ val CONVERT_COLLECTOR_TYPE_TERM = "ConvertCollector"
+
val INPUT_NOT_NULL = false
/**
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DeclarativeAggCodeGen.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DeclarativeAggCodeGen.scala
index a567923..7ae4c03 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DeclarativeAggCodeGen.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DeclarativeAggCodeGen.scala
@@ -304,7 +304,8 @@ class DeclarativeAggCodeGen(
needAccumulate: Boolean = false,
needRetract: Boolean = false,
needMerge: Boolean = false,
- needReset: Boolean = false): Unit = {
+ needReset: Boolean = false,
+ needEmitValue: Boolean = false): Unit = {
// skip the check for DeclarativeAggregateFunction for now
}
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DistinctAggCodeGen.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DistinctAggCodeGen.scala
index 98df983..78cb2f4 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DistinctAggCodeGen.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DistinctAggCodeGen.scala
@@ -354,14 +354,15 @@ class DistinctAggCodeGen(
needAccumulate: Boolean,
needRetract: Boolean,
needMerge: Boolean,
- needReset: Boolean): Unit = {
+ needReset: Boolean,
+ needEmitValue: Boolean): Unit = {
if (needMerge) {
// see merge method for more information
innerAggCodeGens
.foreach(_.checkNeededMethods(needAccumulate = true, needRetract = consumeRetraction))
} else {
- innerAggCodeGens
- .foreach(_.checkNeededMethods(needAccumulate, needRetract, needMerge, needReset))
+ innerAggCodeGens.foreach(
+ _.checkNeededMethods(needAccumulate, needRetract, needMerge, needReset, needEmitValue))
}
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala
index fced26c..87f91c9 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala
@@ -19,7 +19,7 @@ package org.apache.flink.table.planner.codegen.agg
import org.apache.flink.table.dataformat.{BaseRow, GenericRow, UpdatableRow}
import org.apache.flink.table.expressions.Expression
-import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.UserDefinedAggregateFunction
import org.apache.flink.table.planner.codegen.CodeGenUtils._
import org.apache.flink.table.planner.codegen.GenerateUtils.generateFieldAccess
import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator._
@@ -34,6 +34,7 @@ import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDat
import org.apache.flink.table.runtime.types.{ClassDataTypeConverter, PlannerTypeUtils}
import org.apache.flink.table.types.DataType
import org.apache.flink.table.types.logical.{LogicalType, RowType}
+import org.apache.flink.util.Collector
import org.apache.calcite.tools.RelBuilder
@@ -79,7 +80,7 @@ class ImperativeAggCodeGen(
private val SINGLE_ITERABLE = className[SingleElementIterator[_]]
private val UPDATABLE_ROW = className[UpdatableRow]
- val function: AggregateFunction[_, _] = aggInfo.function.asInstanceOf[AggregateFunction[_, _]]
+ val function = aggInfo.function.asInstanceOf[UserDefinedAggregateFunction[_, _]]
val functionTerm: String = ctx.addReusableFunction(
function,
contextTerm = s"$STORE_TERM.getRuntimeContext()")
@@ -441,7 +442,8 @@ class ImperativeAggCodeGen(
needAccumulate: Boolean = false,
needRetract: Boolean = false,
needMerge: Boolean = false,
- needReset: Boolean = false): Unit = {
+ needReset: Boolean = false,
+ needEmitValue: Boolean = false): Unit = {
val methodSignatures = internalTypesToClasses(argTypes)
@@ -503,5 +505,20 @@ class ImperativeAggCodeGen(
s"aggregate ${function.getClass.getCanonicalName}'.")
)
}
+
+ if (needEmitValue) {
+ val collectorDataType = ClassDataTypeConverter.fromClassToDataType(classOf[Collector[_]])
+ getUserDefinedMethod(function, "emitValue", Array(externalAccType, collectorDataType))
+ .getOrElse(
+ throw new CodeGenException(
+ s"No matching emitValue method found for " +
+ s"table aggregate ${function.getClass.getCanonicalName}'.")
+ )
+ }
+ }
+
+ /**
+ * Generates the statement that invokes the function's emitValue method, passing the
+ * accumulator term (internal or external form, depending on isAccTypeInternal) and the
+ * member collector term.
+ */
+ def emitValue: String = {
+ val accumulator = if (isAccTypeInternal) {
+ accInternalTerm
+ } else {
+ accExternalTerm
+ }
+ s"$functionTerm.emitValue($accumulator, $MEMBER_COLLECTOR_TERM);"
}
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/dataview/DataViewUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/dataview/DataViewUtils.scala
index 6076ffc..d903884 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/dataview/DataViewUtils.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/dataview/DataViewUtils.scala
@@ -24,7 +24,7 @@ import org.apache.flink.table.api.TableException
import org.apache.flink.table.api.dataview._
import org.apache.flink.table.dataformat.{BinaryGeneric, GenericRow}
import org.apache.flink.table.dataview.{ListViewTypeInfo, MapViewTypeInfo}
-import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.UserDefinedAggregateFunction
import org.apache.flink.table.runtime.types.TypeInfoLogicalTypeConverter.fromTypeInfoToLogicalType
import org.apache.flink.table.runtime.typeutils.BaseRowTypeInfo
import org.apache.flink.table.types.DataType
@@ -41,7 +41,7 @@ object DataViewUtils {
* Use NullSerializer for StateView fields from accumulator type information.
*
* @param index index of aggregate function
- * @param aggFun aggregate function
+ * @param aggFun aggregate or table aggregate function
* @param externalAccType accumulator type information, only support pojo type
* @param isStateBackedDataViews is data views use state backend
* @return mapping of accumulator type information and data view config which contains id,
@@ -49,7 +49,7 @@ object DataViewUtils {
*/
def useNullSerializerForStateViewFieldsFromAccType(
index: Int,
- aggFun: AggregateFunction[_, _],
+ aggFun: UserDefinedAggregateFunction[_, _],
externalAccType: DataType,
isStateBackedDataViews: Boolean): (DataType, Array[DataViewSpec]) = {
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/expressions/aggregations.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/expressions/aggregations.scala
index 6406372..08a2ff5 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/expressions/aggregations.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/expressions/aggregations.scala
@@ -186,21 +186,15 @@ case class VarSamp(child: PlannerExpression) extends Aggregation {
}
/**
- * Expression for calling a user-defined aggregate function.
+ * Expression for calling a user-defined aggregate or table aggregate function.
*/
case class AggFunctionCall(
- val aggregateFunction: UserDefinedAggregateFunction[_, _],
+ aggregateFunction: UserDefinedAggregateFunction[_, _],
resultTypeInfo: TypeInformation[_],
accTypeInfo: TypeInformation[_],
args: Seq[PlannerExpression])
extends Aggregation {
- if (aggregateFunction.isInstanceOf[TableAggregateFunction[_, _]]) {
- throw new UnsupportedOperationException("TableAggregateFunction is unsupported now.")
- }
-
- private val aggFunction = aggregateFunction.asInstanceOf[AggregateFunction[_, _]]
-
override private[flink] def children: Seq[PlannerExpression] = args
override def resultType: TypeInformation[_] = resultTypeInfo
@@ -209,7 +203,7 @@ case class AggFunctionCall(
val signature = children.map(_.resultType)
// look for a signature that matches the input types
val foundSignature = getAccumulateMethodSignature(
- aggFunction,
+ aggregateFunction,
signature.map(fromTypeInfoToLogicalType))
if (foundSignature.isEmpty) {
ValidationFailure(s"Given parameters do not match any signature. \n" +
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/AggSqlFunction.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/AggSqlFunction.scala
index 5c5f744..c6ca9a3 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/AggSqlFunction.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/AggSqlFunction.scala
@@ -19,7 +19,7 @@
package org.apache.flink.table.planner.functions.utils
import org.apache.flink.table.api.ValidationException
-import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.{AggregateFunction, TableAggregateFunction, UserDefinedAggregateFunction}
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.functions.utils.AggSqlFunction.{createOperandTypeChecker, createOperandTypeInference, createReturnTypeInference}
import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils._
@@ -38,7 +38,8 @@ import org.apache.calcite.util.Optionality
import java.util
/**
- * Calcite wrapper for user-defined aggregate functions.
+ * Calcite wrapper for user-defined aggregate functions. Currently, the aggregate function can be
+ * an [[AggregateFunction]] or a [[TableAggregateFunction]].
*
* @param name function name (used by SQL parser)
* @param displayName name to be displayed in operator name
@@ -50,7 +51,7 @@ import java.util
class AggSqlFunction(
name: String,
displayName: String,
- aggregateFunction: AggregateFunction[_, _],
+ aggregateFunction: UserDefinedAggregateFunction[_, _],
val externalResultType: DataType,
val externalAccType: DataType,
typeFactory: FlinkTypeFactory,
@@ -72,8 +73,8 @@ class AggSqlFunction(
) {
def makeFunction(
- constants: Array[AnyRef], argTypes: Array[LogicalType]): AggregateFunction[_, _] =
- aggregateFunction
+ constants: Array[AnyRef],
+ argTypes: Array[LogicalType]): UserDefinedAggregateFunction[_, _] = aggregateFunction
override def isDeterministic: Boolean = aggregateFunction.isDeterministic
@@ -87,7 +88,7 @@ object AggSqlFunction {
def apply(
name: String,
displayName: String,
- aggregateFunction: AggregateFunction[_, _],
+ aggregateFunction: UserDefinedAggregateFunction[_, _],
externalResultType: DataType,
externalAccType: DataType,
typeFactory: FlinkTypeFactory,
@@ -105,7 +106,7 @@ object AggSqlFunction {
private[flink] def createOperandTypeInference(
name: String,
- aggregateFunction: AggregateFunction[_, _],
+ aggregateFunction: UserDefinedAggregateFunction[_, _],
typeFactory: FlinkTypeFactory): SqlOperandTypeInference = {
/**
* Operand type inference based on [[AggregateFunction]] given information.
@@ -155,7 +156,7 @@ object AggSqlFunction {
private[flink] def createOperandTypeChecker(
name: String,
- aggregateFunction: AggregateFunction[_, _]): SqlOperandTypeChecker = {
+ aggregateFunction: UserDefinedAggregateFunction[_, _]): SqlOperandTypeChecker = {
val methods = checkAndExtractMethods(aggregateFunction, "accumulate")
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/UserDefinedFunctionUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/UserDefinedFunctionUtils.scala
index 3552a7f..664dd8e 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/UserDefinedFunctionUtils.scala
@@ -148,7 +148,7 @@ object UserDefinedFunctionUtils {
}
def getAggUserDefinedInputTypes(
- func: AggregateFunction[_, _],
+ func: UserDefinedAggregateFunction[_, _],
externalAccType: DataType,
expectedTypes: Array[LogicalType]): Array[DataType] = {
val accMethod = getAggFunctionUDIMethod(
@@ -188,7 +188,7 @@ object UserDefinedFunctionUtils {
* Elements of the signature can be null (act as a wildcard).
*/
def getAccumulateMethodSignature(
- function: AggregateFunction[_, _],
+ function: UserDefinedAggregateFunction[_, _],
expectedTypes: Seq[LogicalType])
: Option[Array[Class[_]]] = {
getAggFunctionUDIMethod(
@@ -239,7 +239,7 @@ object UserDefinedFunctionUtils {
}
def getAggFunctionUDIMethod(
- function: AggregateFunction[_, _],
+ function: UserDefinedAggregateFunction[_, _],
methodName: String,
accType: DataType,
expectedTypes: Seq[LogicalType])
@@ -530,30 +530,31 @@ object UserDefinedFunctionUtils {
// ----------------------------------------------------------------------------------------------
/**
- * Tries to infer the DataType of an AggregateFunction's return type.
+ * Tries to infer the DataType of a [[UserDefinedAggregateFunction]]'s return type.
*
- * @param aggregateFunction The AggregateFunction for which the return type is inferred.
- * @param extractedType The implicitly inferred type of the result type.
- *
- * @return The inferred result type of the AggregateFunction.
+ * @param userDefinedAggregateFunction The [[UserDefinedAggregateFunction]] for which the return
+ * type is inferred.
+ * @param extractedType The implicitly inferred type of the result type.
+ * @return The inferred result type of the [[UserDefinedAggregateFunction]].
*/
def getResultTypeOfAggregateFunction(
- aggregateFunction: AggregateFunction[_, _],
+ userDefinedAggregateFunction: UserDefinedAggregateFunction[_, _],
extractedType: DataType = null): DataType = {
- val resultType = aggregateFunction.getResultType
+ val resultType = userDefinedAggregateFunction.getResultType
if (resultType != null) {
fromLegacyInfoToDataType(resultType)
} else if (extractedType != null) {
extractedType
} else {
try {
- extractTypeFromAggregateFunction(aggregateFunction, 0)
+ extractTypeFromAggregateFunction(userDefinedAggregateFunction, 0)
} catch {
case ite: InvalidTypesException =>
throw new TableException(
"Cannot infer generic type of ${aggregateFunction.getClass}. " +
- "You can override AggregateFunction.getResultType() to specify the type.",
+ "You can override UserDefinedAggregateFunction.getResultType() to " +
+ "specify the type.",
ite
)
}
@@ -561,30 +562,31 @@ object UserDefinedFunctionUtils {
}
/**
- * Tries to infer the Type of an AggregateFunction's accumulator type.
- *
- * @param aggregateFunction The AggregateFunction for which the accumulator type is inferred.
- * @param extractedType The implicitly inferred type of the accumulator type.
+ * Tries to infer the Type of a [[UserDefinedAggregateFunction]]'s accumulator type.
*
- * @return The inferred accumulator type of the AggregateFunction.
+ * @param userDefinedAggregateFunction The [[UserDefinedAggregateFunction]] for which the
+ * accumulator type is inferred.
+ * @param extractedType The implicitly inferred type of the accumulator type.
+ * @return The inferred accumulator type of the [[UserDefinedAggregateFunction]].
*/
def getAccumulatorTypeOfAggregateFunction(
- aggregateFunction: AggregateFunction[_, _],
+ userDefinedAggregateFunction: UserDefinedAggregateFunction[_, _],
extractedType: DataType = null): DataType = {
- val accType = aggregateFunction.getAccumulatorType
+ val accType = userDefinedAggregateFunction.getAccumulatorType
if (accType != null) {
fromLegacyInfoToDataType(accType)
} else if (extractedType != null) {
extractedType
} else {
try {
- extractTypeFromAggregateFunction(aggregateFunction, 1)
+ extractTypeFromAggregateFunction(userDefinedAggregateFunction, 1)
} catch {
case ite: InvalidTypesException =>
throw new TableException(
"Cannot infer generic type of ${aggregateFunction.getClass}. " +
- "You can override AggregateFunction.getAccumulatorType() to specify the type.",
+ "You can override UserDefinedAggregateFunction.getAccumulatorType() to specify " +
+ "the type.",
ite
)
}
@@ -593,21 +595,21 @@ object UserDefinedFunctionUtils {
}
/**
- * Internal method to extract a type from an AggregateFunction's type parameters.
+ * Internal method to extract a type from a [[UserDefinedAggregateFunction]]'s type parameters.
*
- * @param aggregateFunction The AggregateFunction for which the type is extracted.
+ * @param aggregateFunction The [[UserDefinedAggregateFunction]] for which the type is extracted.
* @param parameterTypePos The position of the type parameter for which the type is extracted.
*
* @return The extracted type.
*/
@throws(classOf[InvalidTypesException])
private def extractTypeFromAggregateFunction(
- aggregateFunction: AggregateFunction[_, _],
+ aggregateFunction: UserDefinedAggregateFunction[_, _],
parameterTypePos: Int): DataType = {
fromLegacyInfoToDataType(TypeExtractor.createTypeInfo(
aggregateFunction,
- classOf[AggregateFunction[_, _]],
+ classOf[UserDefinedAggregateFunction[_, _]],
aggregateFunction.getClass,
parameterTypePos))
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala
index 848d8f9..c512689 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala
@@ -20,7 +20,7 @@ package org.apache.flink.table.planner.plan.metadata
import org.apache.flink.table.api.TableException
import org.apache.flink.table.planner.plan.metadata.FlinkMetadata.ColumnInterval
-import org.apache.flink.table.planner.plan.nodes.calcite.{Expand, Rank, WindowAggregate}
+import org.apache.flink.table.planner.plan.nodes.calcite.{Expand, Rank, TableAggregate, WindowAggregate}
import org.apache.flink.table.planner.plan.nodes.physical.batch._
import org.apache.flink.table.planner.plan.nodes.physical.stream._
import org.apache.flink.table.planner.plan.schema.FlinkRelOptTable
@@ -365,6 +365,20 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
estimateColumnIntervalOfAggregate(aggregate, mq, index)
/**
+ * Gets interval of the given column on TableAggregates.
+ *
+ * @param aggregate TableAggregate RelNode
+ * @param mq RelMetadataQuery instance
+ * @param index the index of the given column
+ * @return interval of the given column on TableAggregate
+ */
+ def getColumnInterval(
+ aggregate: TableAggregate,
+ mq: RelMetadataQuery,
+ index: Int): ValueInterval = {
+ // delegate to the estimation shared by the aggregate overloads
+ estimateColumnIntervalOfAggregate(aggregate, mq, index)
+ }
+
+ /**
* Gets interval of the given column on batch group aggregate.
*
* @param aggregate batch group aggregate RelNode
@@ -391,6 +405,19 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index)
/**
+ * Gets interval of the given column on stream group table aggregate.
+ *
+ * @param aggregate stream group table aggregate RelNode
+ * @param mq RelMetadataQuery instance
+ * @param index the index of the given column
+ * @return interval of the given column on stream group TableAggregate
+ */
+ def getColumnInterval(
+ aggregate: StreamExecGroupTableAggregate,
+ mq: RelMetadataQuery,
+ index: Int): ValueInterval = {
+ // delegate to the estimation shared by the aggregate overloads
+ estimateColumnIntervalOfAggregate(aggregate, mq, index)
+ }
+
+ /**
* Gets interval of the given column on stream local group aggregate.
*
* @param aggregate stream local group aggregate RelNode
@@ -476,6 +503,8 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
// grouping + assignTs + auxGrouping
agg.getGrouping ++ Array(agg.inputTimeFieldIndex) ++ agg.getAuxGrouping
case agg: BatchExecWindowAggregateBase => agg.getGrouping ++ agg.getAuxGrouping
+ case agg: TableAggregate => agg.getGroupSet.toArray
+ case agg: StreamExecGroupTableAggregate => agg.grouping
}
if (index < groupSet.length) {
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnInterval.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnInterval.scala
index c5ecc1f..0f962b0 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnInterval.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnInterval.scala
@@ -18,8 +18,9 @@
package org.apache.flink.table.planner.plan.metadata
import org.apache.flink.table.planner.plan.metadata.FlinkMetadata.FilteredColumnInterval
+import org.apache.flink.table.planner.plan.nodes.calcite.TableAggregate
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecGroupAggregateBase
-import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamExecGlobalGroupAggregate, StreamExecGroupAggregate, StreamExecGroupWindowAggregate, StreamExecLocalGroupAggregate}
+import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamExecGlobalGroupAggregate, StreamExecGroupAggregate, StreamExecGroupTableAggregate, StreamExecGroupWindowAggregate, StreamExecLocalGroupAggregate}
import org.apache.flink.table.planner.plan.stats.ValueInterval
import org.apache.flink.table.planner.plan.utils.ColumnIntervalUtil
import org.apache.flink.util.Preconditions.checkArgument
@@ -167,6 +168,14 @@ class FlinkRelMdFilteredColumnInterval private extends MetadataHandler[FilteredC
}
def getFilteredColumnInterval(
+ aggregate: TableAggregate,
+ mq: RelMetadataQuery,
+ columnIndex: Int,
+ filterArg: Int): ValueInterval =
+ // TableAggregate is handled by the shared aggregate estimation
+ estimateFilteredColumnIntervalOfAggregate(aggregate, mq, columnIndex, filterArg)
+
+ def getFilteredColumnInterval(
aggregate: BatchExecGroupAggregateBase,
mq: RelMetadataQuery,
columnIndex: Int,
@@ -183,6 +192,14 @@ class FlinkRelMdFilteredColumnInterval private extends MetadataHandler[FilteredC
}
def getFilteredColumnInterval(
+ aggregate: StreamExecGroupTableAggregate,
+ mq: RelMetadataQuery,
+ columnIndex: Int,
+ filterArg: Int): ValueInterval =
+ // StreamExecGroupTableAggregate is handled by the shared aggregate estimation
+ estimateFilteredColumnIntervalOfAggregate(aggregate, mq, columnIndex, filterArg)
+
+ def getFilteredColumnInterval(
aggregate: StreamExecLocalGroupAggregate,
mq: RelMetadataQuery,
columnIndex: Int,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala
index 9c6c77a..cd55234 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala
@@ -22,7 +22,7 @@ import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.functions.utils.ScalarSqlFunction
import org.apache.flink.table.planner.plan.`trait`.RelModifiedMonotonicity
import org.apache.flink.table.planner.plan.metadata.FlinkMetadata.ModifiedMonotonicity
-import org.apache.flink.table.planner.plan.nodes.calcite.{Expand, Rank, WindowAggregate}
+import org.apache.flink.table.planner.plan.nodes.calcite.{Expand, Rank, TableAggregate, WindowAggregate}
import org.apache.flink.table.planner.plan.nodes.logical._
import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecCorrelate, BatchExecGroupAggregateBase}
import org.apache.flink.table.planner.plan.nodes.physical.stream._
@@ -226,6 +226,12 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon
}
def getRelModifiedMonotonicity(
+ rel: TableAggregate,
+ mq: RelMetadataQuery): RelModifiedMonotonicity = {
+ // unpack the pieces the shared table-aggregate helper needs
+ val child = rel.getInput
+ val groupKeys = rel.getGroupSet.toArray
+ val fieldCount = rel.getRowType.getFieldCount
+ getRelModifiedMonotonicityOnTableAggregate(child, groupKeys, fieldCount, mq)
+ }
+
+ def getRelModifiedMonotonicity(
rel: BatchExecGroupAggregateBase,
mq: RelMetadataQuery): RelModifiedMonotonicity = null
@@ -236,6 +242,13 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon
}
def getRelModifiedMonotonicity(
+ rel: StreamExecGroupTableAggregate,
+ mq: RelMetadataQuery): RelModifiedMonotonicity = {
+ // unpack the pieces the shared table-aggregate helper needs
+ val child = rel.getInput
+ val groupKeys = rel.grouping
+ val fieldCount = rel.getRowType.getFieldCount
+ getRelModifiedMonotonicityOnTableAggregate(child, groupKeys, fieldCount, mq)
+ }
+
+ def getRelModifiedMonotonicity(
rel: StreamExecGlobalGroupAggregate,
mq: RelMetadataQuery): RelModifiedMonotonicity = {
// global and local agg should have same update monotonicity
@@ -278,6 +291,27 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon
rel: StreamExecOverAggregate,
mq: RelMetadataQuery): RelModifiedMonotonicity = constants(rel.getRowType.getFieldCount)
+ /**
+ * Computes the modified monotonicity of a table aggregate output row.
+ *
+ * Returns null when any grouping key is not CONSTANT in the input, or when the input
+ * monotonicity is null and there is at least one grouping key. Otherwise the grouping
+ * fields are CONSTANT and the remaining (rowSize - grouping.length) fields are
+ * NOT_MONOTONIC. With an empty grouping the key check is vacuous, so a non-null result
+ * is returned even if the input monotonicity is null.
+ */
+ def getRelModifiedMonotonicityOnTableAggregate(
+ input: RelNode,
+ grouping: Array[Int],
+ rowSize: Int,
+ mq: RelMetadataQuery): RelModifiedMonotonicity = {
+
+ val inputMono = FlinkRelMetadataQuery.reuseOrCreate(mq).getRelModifiedMonotonicity(input)
+
+ // a grouping key qualifies only if the input monotonicity is known and marks it CONSTANT
+ def isConstantKey(key: Int): Boolean =
+ inputMono != null && inputMono.fieldMonotonicities(key) == CONSTANT
+
+ if (grouping.forall(isConstantKey)) {
+ val keyCount = grouping.length
+ val fieldMonotonicity =
+ Array.fill(keyCount)(CONSTANT) ++ Array.fill(rowSize - keyCount)(NOT_MONOTONIC)
+ new RelModifiedMonotonicity(fieldMonotonicity)
+ } else {
+ null
+ }
+ }
+
def getRelModifiedMonotonicityOnAggregate(
input: RelNode,
mq: RelMetadataQuery,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/calcite/LogicalTableAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/calcite/LogicalTableAggregate.scala
new file mode 100644
index 0000000..c8b5348
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/calcite/LogicalTableAggregate.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.nodes.calcite
+
+import java.util
+
+import org.apache.calcite.plan.{Convention, RelOptCluster, RelTraitSet}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.core.{Aggregate, AggregateCall}
+import org.apache.calcite.util.ImmutableBitSet
+
+/**
+ * Sub-class of [[TableAggregate]] that is a relational expression which performs aggregations but
+ * outputs 0 or more records for a group. This class corresponds to Calcite logical rel.
+ */
+class LogicalTableAggregate(
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ input: RelNode,
+ groupSet: ImmutableBitSet,
+ groupSets: util.List[ImmutableBitSet],
+ aggCalls: util.List[AggregateCall])
+ extends TableAggregate(cluster, traitSet, input, groupSet, groupSets, aggCalls) {
+
+ /**
+ * Creates a new node with the given trait set and the first of the given inputs; the
+ * cluster, group set(s) and aggregate calls are carried over from this node.
+ */
+ override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): TableAggregate = {
+ val newInput = inputs.get(0)
+ new LogicalTableAggregate(cluster, traitSet, newInput, groupSet, groupSets, aggCalls)
+ }
+}
+
+object LogicalTableAggregate {
+
+ /**
+ * Builds a [[LogicalTableAggregate]] from the cluster, input, group set(s) and aggregate
+ * calls of the given aggregate, using a trait set of Convention.NONE.
+ */
+ def create(aggregate: Aggregate): LogicalTableAggregate = {
+ val cluster = aggregate.getCluster
+ val noneTraits = cluster.traitSetOf(Convention.NONE)
+ new LogicalTableAggregate(
+ cluster,
+ noneTraits,
+ aggregate.getInput,
+ aggregate.getGroupSet,
+ aggregate.getGroupSets,
+ aggregate.getAggCallList)
+ }
+}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/calcite/TableAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/calcite/TableAggregate.scala
new file mode 100644
index 0000000..cb1e254
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/calcite/TableAggregate.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.nodes.calcite
+
+import java.util
+
+import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.core.AggregateCall
+import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
+import org.apache.calcite.util.{ImmutableBitSet, Pair, Util}
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory
+import org.apache.flink.table.types.utils.{LegacyTypeInfoDataTypeConverter, TypeConversions}
+import org.apache.flink.table.typeutils.FieldInfoUtils
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.ListBuffer
+
+/**
+ * Relational operator that represents a table aggregate. A TableAggregate is similar to the
+ * [[org.apache.calcite.rel.core.Aggregate]] but may output 0 or more records for a group.
+ */
+abstract class TableAggregate(
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ input: RelNode,
+ groupSet: ImmutableBitSet,
+ groupSets: util.List[ImmutableBitSet],
+ val aggCalls: util.List[AggregateCall])
+ extends SingleRel(cluster, traitSet, input) {
+
+ // Plain accessors for the constructor arguments.
+ private[flink] def getGroupSet: ImmutableBitSet = groupSet
+
+ private[flink] def getGroupSets: util.List[ImmutableBitSet] = groupSets
+
+ private[flink] def getAggCallList: util.List[AggregateCall] = aggCalls
+
+ /** Pairs this node's aggregate calls with the non-group field names of its derived row type. */
+ private[flink] def getNamedAggCalls: util.List[Pair[AggregateCall, String]] = {
+ val derivedRowType = deriveRowType()
+ getNamedAggCalls(aggCalls, derivedRowType, groupSet)
+ }
+
+ override def deriveRowType(): RelDataType = {
+ deriveTableAggRowType(cluster, input, groupSet, aggCalls)
+ }
+
+ /**
+ * Builds the output row type: the group key fields of the child, followed by the fields
+ * contributed by the first aggregate call (only the call at index 0 is consulted).
+ */
+ protected def deriveTableAggRowType(
+ cluster: RelOptCluster,
+ child: RelNode,
+ groupSet: ImmutableBitSet,
+ aggCalls: util.List[AggregateCall]): RelDataType = {
+
+ val flinkTypeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
+ val rowTypeBuilder = flinkTypeFactory.builder
+ val keyNames = new ListBuffer[String]
+
+ // group key fields: copied from the child row type, their names remembered
+ for (keyIndex <- groupSet.asList()) {
+ val keyField = child.getRowType.getFieldList.get(keyIndex)
+ keyNames += keyField.getName
+ rowTypeBuilder.add(keyField)
+ }
+
+ // agg fields
+ val resultType = aggCalls.get(0).`type`
+ if (!resultType.isStruct) {
+ // A non-structured type has no field list, so the single field name is obtained
+ // through FieldInfoUtils.getFieldNames, which is also given the group key names.
+ val dataType = TypeConversions.fromLogicalToDataType(FlinkTypeFactory.toLogicalType(resultType))
+ val fieldName = FieldInfoUtils
+ .getFieldNames(LegacyTypeInfoDataTypeConverter.toLegacyTypeInfo(dataType), keyNames).head
+ rowTypeBuilder.add(fieldName, resultType)
+ } else {
+ // a structured type contributes each of its fields as-is
+ for (resultField <- resultType.getFieldList) {
+ rowTypeBuilder.add(resultField)
+ }
+ }
+ rowTypeBuilder.build()
+ }
+
+ /** Zips the aggregate calls with the row type's field names after skipping the group keys. */
+ private[flink] def getNamedAggCalls(
+ aggCalls: util.List[AggregateCall],
+ rowType: RelDataType,
+ groupSet: ImmutableBitSet): util.List[Pair[AggregateCall, String]] = {
+ val aggFieldNames = Util.skip(rowType.getFieldNames, groupSet.cardinality)
+ Pair.zip(aggCalls, aggFieldNames)
+ }
+
+ override def explainTerms(pw: RelWriter): RelWriter = {
+ val withGroup = super.explainTerms(pw).item("group", groupSet)
+ withGroup.item("tableAggregate", aggCalls)
+ }
+}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/logical/FlinkLogicalTableAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/logical/FlinkLogicalTableAggregate.scala
new file mode 100644
index 0000000..161da94
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/logical/FlinkLogicalTableAggregate.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.nodes.logical
+
+import java.util
+import java.util.{List => JList}
+
+import org.apache.calcite.plan._
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.calcite.rel.core.AggregateCall
+import org.apache.calcite.util.ImmutableBitSet
+import org.apache.flink.table.planner.plan.nodes.FlinkConventions
+import org.apache.flink.table.planner.plan.nodes.calcite.{LogicalTableAggregate, TableAggregate}
+
+/**
+  * Flink logical node for a table aggregate: a [[TableAggregate]] (a relational expression that
+  * performs aggregations but outputs 0 or more records for a group) that is also a
+  * [[FlinkLogicalRel]].
+  */
+class FlinkLogicalTableAggregate(
+    cluster: RelOptCluster,
+    traitSet: RelTraitSet,
+    input: RelNode,
+    groupSet: ImmutableBitSet,
+    groupSets: util.List[ImmutableBitSet],
+    aggCalls: util.List[AggregateCall])
+  extends TableAggregate(cluster, traitSet, input, groupSet, groupSets, aggCalls)
+  with FlinkLogicalRel {
+
+  /** Creates a copy that uses the given traits and the first of the given inputs. */
+  override def copy(traitSet: RelTraitSet, inputs: JList[RelNode]): RelNode = {
+    val newInput = inputs.get(0)
+    new FlinkLogicalTableAggregate(cluster, traitSet, newInput, groupSet, groupSets, aggCalls)
+  }
+}
+
+/**
+  * Converter rule that turns a [[LogicalTableAggregate]] in the `NONE` convention into a
+  * [[FlinkLogicalTableAggregate]] in the Flink logical convention.
+  */
+private class FlinkLogicalTableAggregateConverter
+  extends ConverterRule(
+    classOf[LogicalTableAggregate],
+    Convention.NONE,
+    FlinkConventions.LOGICAL,
+    "FlinkLogicalTableAggregateConverter") {
+
+  override def convert(rel: RelNode): RelNode = {
+    val tableAgg = rel.asInstanceOf[LogicalTableAggregate]
+    // Both the node itself and its input are moved to the logical convention.
+    val logicalTraits = rel.getTraitSet.replace(FlinkConventions.LOGICAL)
+    val logicalInput = RelOptRule.convert(tableAgg.getInput, FlinkConventions.LOGICAL)
+    val groupSet = tableAgg.getGroupSet
+    val groupSets = tableAgg.getGroupSets
+
+    new FlinkLogicalTableAggregate(
+      rel.getCluster, logicalTraits, logicalInput, groupSet, groupSets, tableAgg.aggCalls)
+  }
+}
+
+/** Companion exposing the converter rule for [[FlinkLogicalTableAggregate]]. */
+object FlinkLogicalTableAggregate {
+  /** Shared instance of the logical converter rule. */
+  val CONVERTER: ConverterRule = new FlinkLogicalTableAggregateConverter
+}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupTableAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupTableAggregate.scala
new file mode 100644
index 0000000..e06b9b8
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupTableAggregate.scala
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.planner.plan.nodes.physical.stream
+
+import java.util
+
+import org.apache.flink.api.dag.Transformation
+import org.apache.flink.streaming.api.operators.KeyedProcessOperator
+import org.apache.flink.streaming.api.transformations.OneInputTransformation
+import org.apache.flink.table.dataformat.BaseRow
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory
+import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator
+import org.apache.flink.table.planner.codegen.CodeGeneratorContext
+import org.apache.flink.table.planner.delegation.StreamPlanner
+import org.apache.flink.table.planner.plan.metadata.FlinkRelMetadataQuery
+import org.apache.flink.table.planner.plan.nodes.exec.{ExecNode, StreamExecNode}
+import org.apache.flink.table.planner.plan.rules.physical.stream.StreamExecRetractionRules
+import org.apache.flink.table.planner.plan.utils.{AggregateInfoList, AggregateUtil, KeySelectorUtil, RelExplainUtil}
+import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType
+import org.apache.flink.table.runtime.typeutils.BaseRowTypeInfo
+import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.core.AggregateCall
+import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
+import org.apache.flink.table.runtime.operators.aggregate.GroupTableAggFunction
+
+import scala.collection.JavaConversions._
+
+/**
+ * Stream physical RelNode for unbounded group table aggregate.
+ *
+ * The node is translated into a [[KeyedProcessOperator]] that runs a [[GroupTableAggFunction]]
+ * keyed by the `grouping` fields of the input row.
+ *
+ * @param cluster cluster this node belongs to
+ * @param traitSet traits of this node
+ * @param inputRel the single input of this node
+ * @param outputRowType row type returned by `deriveRowType`
+ * @param grouping indices of the grouping fields in the input row
+ * @param aggCalls aggregate calls evaluated by this node
+ */
+class StreamExecGroupTableAggregate(
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ inputRel: RelNode,
+ outputRowType: RelDataType,
+ val grouping: Array[Int],
+ val aggCalls: Seq[AggregateCall])
+ extends SingleRel(cluster, traitSet, inputRel)
+ with StreamPhysicalRel
+ with StreamExecNode[BaseRow] {
+
+ // Aggregate information derived from the aggregate calls and the input row type.
+ // Whether retraction is needed is taken from the input's AccRetract mode, refined per
+ // aggregate call using the modified monotonicity reported for this node.
+ val aggInfoList: AggregateInfoList = {
+ val needRetraction = StreamExecRetractionRules.isAccRetract(getInput)
+ val fmq = FlinkRelMetadataQuery.reuseOrCreate(cluster.getMetadataQuery)
+ val monotonicity = fmq.getRelModifiedMonotonicity(this)
+ val needRetractionArray = AggregateUtil.getNeedRetractions(
+ grouping.length, needRetraction, monotonicity, aggCalls)
+ AggregateUtil.transformToStreamAggregateInfoList(
+ aggCalls,
+ getInput.getRowType,
+ needRetractionArray,
+ needInputCount = needRetraction,
+ isStateBackendDataViews = true)
+ }
+
+ // This node always reports that it produces updates.
+ override def producesUpdates = true
+
+ // Updates of any input are always requested as retractions.
+ override def needsUpdatesAsRetraction(input: RelNode) = true
+
+ // This node always reports that it consumes retractions.
+ override def consumesRetractions = true
+
+ // This node never reports producing retractions by itself.
+ override def producesRetractions: Boolean = false
+
+ // No watermark is required by this node.
+ override def requireWatermark: Boolean = false
+
+ // The row type is the one handed to the constructor; nothing is derived here.
+ override def deriveRowType(): RelDataType = outputRowType
+
+ // Creates a copy with the given traits and the first of the given inputs; the output row
+ // type, grouping and aggregate calls are carried over unchanged.
+ override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
+ new StreamExecGroupTableAggregate(
+ cluster,
+ traitSet,
+ inputs.get(0),
+ outputRowType,
+ grouping,
+ aggCalls)
+ }
+
+ // Writes the "groupBy" item (only when there are grouping fields) and the "select" item.
+ override def explainTerms(pw: RelWriter): RelWriter = {
+ val inputRowType = getInput.getRowType
+ super.explainTerms(pw)
+ .itemIf("groupBy",
+ RelExplainUtil.fieldToString(grouping, inputRowType), grouping.nonEmpty)
+ .item("select", RelExplainUtil.streamGroupAggregationToString(
+ inputRowType,
+ getRowType,
+ aggInfoList,
+ grouping))
+ }
+
+ //~ ExecNode methods -----------------------------------------------------------
+
+ // Exposes the RelNode inputs as ExecNodes.
+ override def getInputNodes: util.List[ExecNode[StreamPlanner, _]] = {
+ getInputs.map(_.asInstanceOf[ExecNode[StreamPlanner, _]])
+ }
+
+ // Replaces the input at the given ordinal; the new ExecNode is expected to be a RelNode.
+ override def replaceInputNode(
+ ordinalInParent: Int,
+ newInputNode: ExecNode[StreamPlanner, _]): Unit = {
+ replaceInput(ordinalInParent, newInputNode.asInstanceOf[RelNode])
+ }
+
+ // Translates this node into a keyed one-input transformation named "GroupTableAggregate".
+ override protected def translateToPlanInternal(
+ planner: StreamPlanner): Transformation[BaseRow] = {
+
+ val tableConfig = planner.getTableConfig
+
+ // Only warn: the query is still translated when the minimum retention time is negative.
+ if (grouping.length > 0 && tableConfig.getMinIdleStateRetentionTime < 0) {
+ LOG.warn("No state retention interval configured for a query which accumulates state. " +
+ "Please provide a query configuration with valid retention interval to prevent excessive " +
+ "state size. You may specify a retention time of 0 to not clean up the state.")
+ }
+
+ val inputTransformation = getInputNodes.get(0).translateToPlan(planner)
+ .asInstanceOf[Transformation[BaseRow]]
+
+ val outRowType = FlinkTypeFactory.toLogicalRowType(outputRowType)
+ val inputRowType = FlinkTypeFactory.toLogicalRowType(getInput.getRowType)
+
+ // generateRetraction reflects the mode of this node, needRetraction the mode of its input.
+ val generateRetraction = StreamExecRetractionRules.isAccRetract(this)
+ val needRetraction = StreamExecRetractionRules.isAccRetract(getInput)
+
+ val generator = new AggsHandlerCodeGenerator(
+ CodeGeneratorContext(tableConfig),
+ planner.getRelBuilder,
+ inputRowType.getChildren,
+ // TODO: the heap state backend does not copy keys currently, so the input field has to
+ // TODO: be copied here. The copy is not needed when the state backend is RocksDB;
+ // TODO: improve this in the future. Note that other operators do not copy this input field.
+ copyInputField = true)
+
+ // Retraction support is only generated when the input is in AccRetract mode.
+ if (needRetraction) {
+ generator.needRetract()
+ }
+
+ val aggsHandler = generator
+ .needAccumulate()
+ .generateTableAggsHandler("GroupTableAggHandler", aggInfoList)
+
+ val accTypes = aggInfoList.getAccTypes.map(fromDataTypeToLogicalType)
+ val inputCountIndex = aggInfoList.getIndexOfCountStar
+
+ val aggFunction = new GroupTableAggFunction(
+ tableConfig.getMinIdleStateRetentionTime,
+ tableConfig.getMaxIdleStateRetentionTime,
+ aggsHandler,
+ accTypes,
+ inputCountIndex,
+ generateRetraction)
+ val operator = new KeyedProcessOperator[BaseRow, BaseRow, BaseRow](aggFunction)
+
+ // The key selector extracts the grouping fields from the input row.
+ val selector = KeySelectorUtil.getBaseRowSelector(
+ grouping,
+ BaseRowTypeInfo.of(inputRowType))
+
+ // partitioned aggregation
+ val ret = new OneInputTransformation(
+ inputTransformation,
+ "GroupTableAggregate",
+ operator,
+ BaseRowTypeInfo.of(outRowType),
+ inputTransformation.getParallelism)
+
+ // A singleton input pins both parallelism and max parallelism to 1.
+ if (inputsContainSingleton()) {
+ ret.setParallelism(1)
+ ret.setMaxParallelism(1)
+ }
+
+ // set KeyType and Selector for state
+ ret.setStateKeySelector(selector)
+ ret.setStateKeyType(selector.getProducedType)
+ ret
+ }
+}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
index cb371ab..beb00e5 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
@@ -290,6 +290,7 @@ object FlinkStreamRuleSets {
private val LOGICAL_CONVERTERS: RuleSet = RuleSets.ofList(
// translate to flink logical rel nodes
FlinkLogicalAggregate.STREAM_CONVERTER,
+ FlinkLogicalTableAggregate.CONVERTER,
FlinkLogicalOverAggregate.CONVERTER,
FlinkLogicalCalc.CONVERTER,
FlinkLogicalCorrelate.CONVERTER,
@@ -365,6 +366,7 @@ object FlinkStreamRuleSets {
StreamExecExpandRule.INSTANCE,
// group agg
StreamExecGroupAggregateRule.INSTANCE,
+ StreamExecGroupTableAggregateRule.INSTANCE,
// over agg
StreamExecOverAggregateRule.INSTANCE,
// window agg
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecGroupTableAggregateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecGroupTableAggregateRule.scala
new file mode 100644
index 0000000..ba441e7
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecGroupTableAggregateRule.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.physical.stream
+
+import org.apache.calcite.plan.RelOptRule
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution
+import org.apache.flink.table.planner.plan.nodes.FlinkConventions
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableAggregate
+import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamExecGroupTableAggregate
+
+import scala.collection.JavaConversions._
+
+/**
+  * Rule that converts a [[FlinkLogicalTableAggregate]] into a [[StreamExecGroupTableAggregate]].
+  *
+  * The input is required to be hash-distributed on the group keys, or to be a singleton when
+  * the aggregate has no group keys.
+  */
+class StreamExecGroupTableAggregateRule extends ConverterRule(
+  classOf[FlinkLogicalTableAggregate],
+  FlinkConventions.LOGICAL,
+  FlinkConventions.STREAM_PHYSICAL,
+  "StreamExecGroupTableAggregateRule") {
+
+  def convert(rel: RelNode): RelNode = {
+    val tableAgg = rel.asInstanceOf[FlinkLogicalTableAggregate]
+
+    // Distribution the input has to provide.
+    val inputDistribution =
+      if (tableAgg.getGroupSet.cardinality() == 0) {
+        FlinkRelDistribution.SINGLETON
+      } else {
+        FlinkRelDistribution.hash(tableAgg.getGroupSet.asList)
+      }
+
+    val emptyTraits = rel.getCluster.getPlanner.emptyTraitSet()
+    val inputTraits = emptyTraits
+      .replace(inputDistribution)
+      .replace(FlinkConventions.STREAM_PHYSICAL)
+    // The new node keeps the traits of the logical node, moved to the stream physical convention.
+    val outputTraits = rel.getTraitSet.replace(FlinkConventions.STREAM_PHYSICAL)
+    val convertedInput: RelNode = RelOptRule.convert(tableAgg.getInput, inputTraits)
+
+    new StreamExecGroupTableAggregate(
+      rel.getCluster,
+      outputTraits,
+      convertedInput,
+      tableAgg.getRowType,
+      tableAgg.getGroupSet.toArray,
+      tableAgg.getAggCallList)
+  }
+}
+
+/** Companion exposing the rule instance of [[StreamExecGroupTableAggregateRule]]. */
+object StreamExecGroupTableAggregateRule {
+  /** Shared instance of the rule. */
+  val INSTANCE: StreamExecGroupTableAggregateRule = new StreamExecGroupTableAggregateRule
+}
+
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
index 8877be5..929c242 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
@@ -24,7 +24,7 @@ import org.apache.flink.table.dataformat.BaseRow
import org.apache.flink.table.dataview.MapViewTypeInfo
import org.apache.flink.table.expressions.ExpressionUtils.extractValue
import org.apache.flink.table.expressions._
-import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
+import org.apache.flink.table.functions.{AggregateFunction, TableAggregateFunction, UserDefinedAggregateFunction, UserDefinedFunction}
import org.apache.flink.table.planner.JLong
import org.apache.flink.table.planner.calcite.FlinkRelBuilder.PlannerNamedWindowProperty
import org.apache.flink.table.planner.calcite.{FlinkTypeFactory, FlinkTypeSystem}
@@ -289,7 +289,7 @@ object AggregateUtil extends Enumeration {
val bufferTypeInfos = bufferTypes.map(fromLogicalTypeToDataType)
(bufferTypeInfos, Array.empty[DataViewSpec],
fromLogicalTypeToDataType(a.getResultType.getLogicalType))
- case a: AggregateFunction[_, _] =>
+ case a: UserDefinedAggregateFunction[_, _] =>
val (implicitAccType, implicitResultType) = call.getAggregation match {
case aggSqlFun: AggSqlFunction =>
(aggSqlFun.externalAccType, aggSqlFun.externalResultType)
@@ -745,4 +745,11 @@ object AggregateUtil extends Enumeration {
def toDuration(literalExpr: ValueLiteralExpression): Duration =
extractValue(literalExpr, classOf[Duration]).get()
+
+ /**
+   * Returns true if at least one of the given calls is backed by an [[AggSqlFunction]] whose
+   * function, as returned by `makeFunction(null, null)`, is a [[TableAggregateFunction]].
+   * Calls whose aggregation is not an [[AggSqlFunction]] are ignored.
+   */
+ private[flink] def isTableAggregate(aggCalls: util.List[AggregateCall]): Boolean = {
+   // makeFunction is invoked for every AggSqlFunction before the result is inspected.
+   val userDefinedFunctions = aggCalls
+     .map(_.getAggregation)
+     .collect { case sqlFunction: AggSqlFunction => sqlFunction.makeFunction(null, null) }
+   userDefinedFunctions.exists(_.isInstanceOf[TableAggregateFunction[_, _]])
+ }
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/RelExplainUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/RelExplainUtil.scala
index 6cf72aa..93440a3 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/RelExplainUtil.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/RelExplainUtil.scala
@@ -329,11 +329,16 @@ object RelExplainUtil {
stringifyAggregates(aggInfos, distinctAggs, aggFilters, inFieldNames)
}
+ val isTableAggregate =
+ AggregateUtil.isTableAggregate(aggInfoList.getActualAggregateCalls.toList)
val outputFieldNames = if (isLocal) {
grouping.map(inFieldNames(_)) ++ localAggOutputFieldNames(aggOffset, aggInfos, outFieldNames)
} else if (isIncremental) {
val accFieldNames = inputRowType.getFieldNames.toList.toArray
grouping.map(inFieldNames(_)) ++ localAggOutputFieldNames(aggOffset, aggInfos, accFieldNames)
+ } else if (isTableAggregate) {
+ outFieldNames.slice(0, grouping.length) ++
+ Seq(s"(${outFieldNames.drop(grouping.length).mkString(", ")})")
} else {
outFieldNames
}
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/table/TableAggregateTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/table/TableAggregateTest.xml
new file mode 100644
index 0000000..4792ed2
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/table/TableAggregateTest.xml
@@ -0,0 +1,137 @@
+<?xml version="1.0" ?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to you under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+<Root>
+ <TestCase name="testJavaRegisterFunction">
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(c=[$0], f0=[$1], f1=[$2])
++- LogicalTableAggregate(group=[{2}], tableAggregate=[[EmptyTableAggFunc($0)]])
+ +- LogicalTableScan(table=[[default_catalog, default_database, Table1, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+GroupTableAggregate(groupBy=[c], select=[c, EmptyTableAggFunc(a) AS (f0, f1)])
++- Exchange(distribution=[hash[c]])
+ +- TableSourceScan(table=[[default_catalog, default_database, Table1, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testTableAggregateWithAlias">
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[AS($0, _UTF-16LE'a')], b=[AS($1, _UTF-16LE'b')])
++- LogicalTableAggregate(group=[{}], tableAggregate=[[EmptyTableAggFunc($1)]])
+ +- LogicalTableScan(table=[[default_catalog, default_database, Table1, source: [TestTableSource(a, b, c, d, e)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[f0 AS a, f1 AS b])
++- GroupTableAggregate(select=[EmptyTableAggFunc(b) AS (f0, f1)])
+ +- Exchange(distribution=[single])
+ +- TableSourceScan(table=[[default_catalog, default_database, Table1, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testTableAggregateWithGroupBy">
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(bb=[AS($0, _UTF-16LE'bb')], _c1=[+(AS($1, _UTF-16LE'x'), 1)], y=[AS($2, _UTF-16LE'y')])
++- LogicalTableAggregate(group=[{5}], tableAggregate=[[EmptyTableAggFunc($0, $1)]])
+ +- LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], e=[$4], bb=[MOD($1, 5)])
+ +- LogicalTableScan(table=[[default_catalog, default_database, Table1, source: [TestTableSource(a, b, c, d, e)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[bb, +(f0, 1) AS _c1, f1 AS y])
++- GroupTableAggregate(groupBy=[bb], select=[bb, EmptyTableAggFunc(a, b) AS (f0, f1)])
+ +- Exchange(distribution=[hash[bb]])
+ +- Calc(select=[a, b, c, d, e, MOD(b, 5) AS bb])
+ +- TableSourceScan(table=[[default_catalog, default_database, Table1, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testTableAggregateWithIntResultType">
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(f0=[$0], f0_0=[$1])
++- LogicalTableAggregate(group=[{0}], tableAggregate=[[EmptyTableAggFuncWithIntResultType($1)]])
+ +- LogicalTableScan(table=[[default_catalog, default_database, Table2, source: [TestTableSource(f0, f1, f2, d, e)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+GroupTableAggregate(groupBy=[f0], select=[f0, EmptyTableAggFuncWithIntResultType(f1) AS (f0_0)])
++- Exchange(distribution=[hash[f0]])
+ +- TableSourceScan(table=[[default_catalog, default_database, Table2, source: [TestTableSource(f0, f1, f2, d, e)]]], fields=[f0, f1, f2, d, e])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testTableAggregateWithoutGroupBy">
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[org$apache$flink$table$planner$expressions$utils$Func0$$33f403d0ad41527ec7747c2e4fdebaf9($0)], b=[$1])
++- LogicalTableAggregate(group=[{}], tableAggregate=[[EmptyTableAggFunc($0, $1)]])
+ +- LogicalTableScan(table=[[default_catalog, default_database, Table1, source: [TestTableSource(a, b, c, d, e)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[Func0$(f0) AS a, f1 AS b])
++- GroupTableAggregate(select=[EmptyTableAggFunc(a, b) AS (f0, f1)])
+ +- Exchange(distribution=[single])
+ +- TableSourceScan(table=[[default_catalog, default_database, Table1, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testTableAggregateWithTimeIndicator">
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1])
++- LogicalTableAggregate(group=[{}], tableAggregate=[[EmptyTableAggFunc($3, $4)]])
+ +- LogicalTableScan(table=[[default_catalog, default_database, Table1, source: [TestTableSource(a, b, c, d, e)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+GroupTableAggregate(select=[EmptyTableAggFunc(d, e) AS (f0, f1)])
++- Exchange(distribution=[single])
+ +- Calc(select=[a, b, c, CAST(d) AS d, PROCTIME_MATERIALIZE(e) AS e])
+ +- TableSourceScan(table=[[default_catalog, default_database, Table1, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testTableAggregateWithSelectStar">
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(f0=[$0], f1=[$1])
++- LogicalTableAggregate(group=[{}], tableAggregate=[[EmptyTableAggFunc($1)]])
+ +- LogicalTableScan(table=[[default_catalog, default_database, Table1, source: [TestTableSource(a, b, c, d, e)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+GroupTableAggregate(select=[EmptyTableAggFunc(b) AS (f0, f1)])
++- Exchange(distribution=[single])
+ +- TableSourceScan(table=[[default_catalog, default_database, Table1, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
+]]>
+ </Resource>
+ </TestCase>
+</Root>
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnIntervalTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnIntervalTest.scala
index 80d4ad4..020f441 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnIntervalTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnIntervalTest.scala
@@ -455,6 +455,16 @@ class FlinkRelMdColumnIntervalTest extends FlinkRelMdHandlerTestBase {
}
@Test
+ def testGetColumnIntervalOnTableAggregate(): Unit = {
+   // Column 0 is expected to have the interval [0, +inf); columns 1 and 2 have no interval.
+   val expectedKeyInterval = RightSemiInfiniteValueInterval(0, true)
+   for (agg <- Array(logicalTableAgg, flinkLogicalTableAgg, streamExecTableAgg)) {
+     assertEquals(expectedKeyInterval, mq.getColumnInterval(agg, 0))
+     assertNull(mq.getColumnInterval(agg, 1))
+     assertNull(mq.getColumnInterval(agg, 2))
+   }
+ }
+
+ @Test
def testGetColumnIntervalOnWindowAgg(): Unit = {
Array(logicalWindowAgg, flinkLogicalWindowAgg, batchGlobalWindowAggWithLocalAgg,
batchGlobalWindowAggWithoutLocalAgg, streamWindowAgg).foreach { agg =>
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnIntervalTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnIntervalTest.scala
index 278ceb0..36410e7 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnIntervalTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnIntervalTest.scala
@@ -18,7 +18,7 @@
package org.apache.flink.table.planner.plan.metadata
import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable
-import org.apache.flink.table.planner.plan.stats.ValueInterval
+import org.apache.flink.table.planner.plan.stats.{RightSemiInfiniteValueInterval,ValueInterval}
import org.apache.flink.table.types.logical._
import org.apache.calcite.rel.RelNode
@@ -162,6 +162,18 @@ class FlinkRelMdFilteredColumnIntervalTest extends FlinkRelMdHandlerTestBase {
}
@Test
+ def testGetColumnIntervalOnTableAggregate(): Unit = {
+   // Column 0 is expected to have the interval [0, +inf); columns 1 and 2 have no interval.
+   // The filter argument is -1 in every lookup.
+   val expectedKeyInterval = RightSemiInfiniteValueInterval(0, true)
+   for (agg <- Array(logicalTableAgg, flinkLogicalTableAgg, streamExecTableAgg)) {
+     assertEquals(expectedKeyInterval, mq.getFilteredColumnInterval(agg, 0, -1))
+     assertNull(mq.getFilteredColumnInterval(agg, 1, -1))
+     assertNull(mq.getFilteredColumnInterval(agg, 2, -1))
+   }
+ }
+
+ @Test
def testGetColumnIntervalOnUnion(): Unit = {
Array(logicalUnion, logicalUnionAll).foreach { union =>
assertNull(mq.getFilteredColumnInterval(union, 0, -1))
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
index 0923bba..c4270cb 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
@@ -22,6 +22,7 @@ import org.apache.flink.table.api.{TableConfig, TableException}
import org.apache.flink.table.catalog.{CatalogManager, FunctionCatalog, GenericInMemoryCatalog}
import org.apache.flink.table.expressions._
import org.apache.flink.table.expressions.utils.ApiExpressionUtils.intervalOfMillis
+import org.apache.flink.table.functions.UserFunctionsTypeHelper
import org.apache.flink.table.planner.calcite.FlinkRelBuilder.PlannerNamedWindowProperty
import org.apache.flink.table.planner.calcite.{FlinkRelBuilder, FlinkTypeFactory}
import org.apache.flink.table.planner.delegation.PlannerContext
@@ -29,11 +30,12 @@ import org.apache.flink.table.planner.expressions.{PlannerProctimeAttribute, Pla
import org.apache.flink.table.planner.functions.aggfunctions.SumAggFunction.DoubleSumAggFunction
import org.apache.flink.table.planner.functions.aggfunctions.{DenseRankAggFunction, RankAggFunction, RowNumberAggFunction}
import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable
+import org.apache.flink.table.planner.functions.utils.AggSqlFunction
import org.apache.flink.table.planner.plan.PartialFinalType
import org.apache.flink.table.planner.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef}
import org.apache.flink.table.planner.plan.logical.{LogicalWindow, TumblingGroupWindow}
import org.apache.flink.table.planner.plan.nodes.FlinkConventions
-import org.apache.flink.table.planner.plan.nodes.calcite.{LogicalExpand, LogicalRank, LogicalWindowAggregate}
+import org.apache.flink.table.planner.plan.nodes.calcite.{LogicalExpand, LogicalRank, LogicalTableAggregate, LogicalWindowAggregate}
import org.apache.flink.table.planner.plan.nodes.logical._
import org.apache.flink.table.planner.plan.nodes.physical.batch._
import org.apache.flink.table.planner.plan.nodes.physical.stream._
@@ -41,11 +43,11 @@ import org.apache.flink.table.planner.plan.schema.FlinkRelOptTable
import org.apache.flink.table.planner.plan.stream.sql.join.TestTemporalTable
import org.apache.flink.table.planner.plan.utils.AggregateUtil.transformToStreamAggregateInfoList
import org.apache.flink.table.planner.plan.utils._
-import org.apache.flink.table.planner.utils.CountAggFunction
+import org.apache.flink.table.planner.utils.{CountAggFunction, Top3}
import org.apache.flink.table.runtime.operators.rank.{ConstantRankRange, RankType, VariableRankRange}
import org.apache.flink.table.types.AtomicDataType
import org.apache.flink.table.types.logical.{BigIntType, DoubleType, IntType, LogicalType, TimestampKind, TimestampType, VarCharType}
-
+import org.apache.flink.table.types.utils.TypeConversions
import com.google.common.collect.{ImmutableList, Lists}
import org.apache.calcite.jdbc.CalciteSchema
import org.apache.calcite.plan._
@@ -58,7 +60,7 @@ import org.apache.calcite.rel.metadata.{JaninoRelMetadataProvider, RelMetadataQu
import org.apache.calcite.rex._
import org.apache.calcite.schema.SchemaPlus
import org.apache.calcite.sql.SqlWindow
-import org.apache.calcite.sql.`type`.SqlTypeName
+import org.apache.calcite.sql.`type`.{BasicSqlType, SqlTypeName}
import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, BOOLEAN, DATE, DOUBLE, FLOAT, TIME, TIMESTAMP, VARCHAR}
import org.apache.calcite.sql.fun.SqlStdOperatorTable.{AND, CASE, DIVIDE, EQUALS, GREATER_THAN, LESS_THAN, MINUS, MULTIPLY, OR, PLUS}
import org.apache.calcite.sql.fun.{SqlCountAggFunction, SqlStdOperatorTable}
@@ -705,6 +707,69 @@ class FlinkRelMdHandlerTestBase {
(logicalRankWithVariableRange, flinkLogicalRankWithVariableRange, streamRankWithVariableRange)
}
+ protected lazy val tableAggCall = {
+ val top3 = new Top3
+ val resultTypeInfo = UserFunctionsTypeHelper.getReturnTypeOfAggregateFunction(top3)
+ val accTypeInfo = UserFunctionsTypeHelper.getAccumulatorTypeOfAggregateFunction(top3)
+
+ val resultDataType = TypeConversions.fromLegacyInfoToDataType(resultTypeInfo)
+ val accDataType = TypeConversions.fromLegacyInfoToDataType(accTypeInfo)
+
+ val builder = typeFactory.builder()
+ builder.add("f0", new BasicSqlType(typeFactory.getTypeSystem, SqlTypeName.INTEGER))
+ builder.add("f1", new BasicSqlType(typeFactory.getTypeSystem, SqlTypeName.INTEGER))
+ val relDataType = builder.build()
+
+ AggregateCall.create(
+ AggSqlFunction("top3", "top3", new Top3, resultDataType, accDataType, typeFactory, false),
+ false,
+ false,
+ false,
+ Seq(Integer.valueOf(0)).toList,
+ -1,
+ RelCollationImpl.of(),
+ relDataType,
+ ""
+ )
+ }
+
+ protected lazy val (logicalTableAgg, flinkLogicalTableAgg, streamExecTableAgg) = {
+
+ val logicalTableAgg = new LogicalTableAggregate(
+ cluster,
+ logicalTraits,
+ studentLogicalScan,
+ ImmutableBitSet.of(0),
+ null,
+ Seq(tableAggCall))
+
+ val flinkLogicalTableAgg = new FlinkLogicalTableAggregate(
+ cluster,
+ logicalTraits,
+ studentLogicalScan,
+ ImmutableBitSet.of(0),
+ null,
+ Seq(tableAggCall)
+ )
+
+ val builder = typeFactory.builder()
+ builder.add("key", new BasicSqlType(typeFactory.getTypeSystem, SqlTypeName.BIGINT))
+ builder.add("f0", new BasicSqlType(typeFactory.getTypeSystem, SqlTypeName.INTEGER))
+ builder.add("f1", new BasicSqlType(typeFactory.getTypeSystem, SqlTypeName.INTEGER))
+ val relDataType = builder.build()
+
+ val streamExecTableAgg = new StreamExecGroupTableAggregate(
+ cluster,
+ logicalTraits,
+ studentLogicalScan,
+ relDataType,
+ Array(0),
+ Seq(tableAggCall)
+ )
+
+ (logicalTableAgg, flinkLogicalTableAgg, streamExecTableAgg)
+ }
+
// equivalent SQL is
// select age,
// avg(score) as avg_score,
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicityTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicityTest.scala
index 3f79b07..76a3a50 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicityTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicityTest.scala
@@ -19,7 +19,7 @@
package org.apache.flink.table.planner.plan.metadata
import org.apache.flink.table.planner.plan.`trait`.RelModifiedMonotonicity
-import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalRank
+import org.apache.flink.table.planner.plan.nodes.logical.{FlinkLogicalRank, FlinkLogicalTableAggregate}
import org.apache.flink.table.runtime.operators.rank.{ConstantRankRange, RankType}
import org.apache.calcite.rel.RelCollations
@@ -31,6 +31,8 @@ import org.apache.calcite.util.ImmutableBitSet
import org.junit.Assert._
import org.junit.Test
+import scala.collection.JavaConversions._
+
class FlinkRelMdModifiedMonotonicityTest extends FlinkRelMdHandlerTestBase {
@Test
@@ -119,6 +121,34 @@ class FlinkRelMdModifiedMonotonicityTest extends FlinkRelMdHandlerTestBase {
}
@Test
+ def testGetRelMonotonicityOnTableAggregateAfterScan(): Unit = {
+ assertEquals(
+ new RelModifiedMonotonicity(Array(CONSTANT, NOT_MONOTONIC, NOT_MONOTONIC)),
+ mq.getRelModifiedMonotonicity(logicalTableAgg))
+ }
+
+ @Test
+ def testGetRelMonotonicityOnTableAggregateAfterAggregate(): Unit = {
+ val projectWithMaxAgg = relBuilder.scan("MyTable4")
+ .aggregate(
+ relBuilder.groupKey(relBuilder.field("a"), relBuilder.field("b")),
+ relBuilder.max("max_c", relBuilder.field("c")),
+ relBuilder.sum(false, "sum_d", relBuilder.field("d")))
+ .project(relBuilder.field(2), relBuilder.field(1))
+ .build()
+
+ val tableAggregate = new FlinkLogicalTableAggregate(
+ cluster,
+ logicalTraits,
+ projectWithMaxAgg,
+ ImmutableBitSet.of(0),
+ null,
+ Seq(tableAggCall)
+ )
+ assertEquals(null, mq.getRelModifiedMonotonicity(tableAggregate))
+ }
+
+ @Test
def testGetRelMonotonicityOnAggregate(): Unit = {
// select b, sum(a) from (select a + 10 as a, b from MyTable3) t group by b
val aggWithSum = relBuilder.scan("MyTable3")
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/table/TableAggregateTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/table/TableAggregateTest.scala
new file mode 100644
index 0000000..351a854
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/table/TableAggregateTest.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.stream.table
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.planner.expressions.utils.Func0
+import org.apache.flink.table.planner.utils.{EmptyTableAggFunc, EmptyTableAggFuncWithIntResultType, TableTestBase}
+import org.junit.Test
+
+class TableAggregateTest extends TableTestBase {
+
+ val util = streamTestUtil()
+ val table = util.addTableSource[(Long, Int, Long, Long)]('a, 'b, 'c, 'd.rowtime, 'e.proctime)
+ val emptyFunc = new EmptyTableAggFunc
+
+ @Test
+ def testTableAggregateWithGroupBy(): Unit = {
+ val resultTable = table
+ .groupBy('b % 5 as 'bb)
+ .flatAggregate(emptyFunc('a, 'b) as ('x, 'y))
+ .select('bb, 'x + 1, 'y)
+
+ util.verifyPlan(resultTable)
+ }
+
+ @Test
+ def testTableAggregateWithoutGroupBy(): Unit = {
+ val resultTable = table
+ .flatAggregate(emptyFunc('a, 'b))
+ .select(Func0('f0) as 'a, 'f1 as 'b)
+
+ util.verifyPlan(resultTable)
+ }
+
+ @Test
+ def testTableAggregateWithTimeIndicator(): Unit = {
+
+ val resultTable = table
+ .flatAggregate(emptyFunc('d, 'e))
+ .select('f0 as 'a, 'f1 as 'b)
+
+ util.verifyPlan(resultTable)
+ }
+
+ @Test
+ def testTableAggregateWithSelectStar(): Unit = {
+
+ val resultTable = table
+ .flatAggregate(emptyFunc('b))
+ .select("*")
+
+ util.verifyPlan(resultTable)
+ }
+
+ @Test
+ def testTableAggregateWithAlias(): Unit = {
+
+ val resultTable = table
+ .flatAggregate(emptyFunc('b) as ('a, 'b))
+ .select('a, 'b)
+
+ util.verifyPlan(resultTable)
+ }
+
+ @Test
+ def testTableAggregateWithIntResultType(): Unit = {
+
+ val table = util.addTableSource[(Long, Int, Long, Long)]('f0, 'f1, 'f2, 'd.rowtime, 'e.proctime)
+ val func = new EmptyTableAggFuncWithIntResultType
+
+ val resultTable = table
+ .groupBy('f0)
+ .flatAggregate(func('f1))
+ .select('f0, 'f0_0)
+
+ util.verifyPlan(resultTable)
+ }
+
+ @Test
+ def testJavaRegisterFunction(): Unit = {
+
+ val util = javaStreamTestUtil()
+ val table = util.addTableSource[(Int, Long, Long)]('a, 'b, 'c)
+
+ val func = new EmptyTableAggFunc
+ util.addFunction("func", func)
+
+ val resultTable = table
+ .groupBy("c")
+ .flatAggregate("func(a)")
+ .select("*")
+
+ util.verifyPlan(resultTable)
+ }
+}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/table/stringexpr/TableAggregateStringExpressionTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/table/stringexpr/TableAggregateStringExpressionTest.scala
new file mode 100644
index 0000000..5af0a53
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/table/stringexpr/TableAggregateStringExpressionTest.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.stream.table.stringexpr
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.planner.expressions.utils.Func0
+import org.apache.flink.table.planner.utils.{EmptyTableAggFunc, TableTestBase}
+import org.junit.Test
+
+class TableAggregateStringExpressionTest extends TableTestBase {
+
+ @Test
+ def testNonGroupedTableAggregate(): Unit = {
+ val util = streamTestUtil()
+ val t = util.addTableSource[(Int, Long, String)]('a, 'b, 'c)
+
+ val top3 = new EmptyTableAggFunc
+ util.addFunction("top3", top3)
+ util.addFunction("Func0", Func0)
+
+ // Expression / Scala API
+ val resScala = t
+ .flatAggregate(top3('a))
+ .select(Func0('f0) as 'a, 'f1 as 'b)
+
+ // String / Java API
+ val resJava = t
+ .flatAggregate("top3(a)")
+ .select("Func0(f0) as a, f1 as b")
+
+ verifyTableEquals(resJava, resScala)
+ }
+
+ @Test
+ def testGroupedTableAggregate(): Unit = {
+ val util = streamTestUtil()
+ val t = util.addTableSource[(Int, Long, String)]('a, 'b, 'c)
+
+ val top3 = new EmptyTableAggFunc
+ util.addFunction("top3", top3)
+ util.addFunction("Func0", Func0)
+
+ // Expression / Scala API
+ val resScala = t
+ .groupBy('b % 5)
+ .flatAggregate(top3('a))
+ .select(Func0('f0) as 'a, 'f1 as 'b)
+
+ // String / Java API
+ val resJava = t
+ .groupBy("b % 5")
+ .flatAggregate("top3(a)")
+ .select("Func0(f0) as a, f1 as b")
+
+ verifyTableEquals(resJava, resScala)
+ }
+
+ @Test
+ def testAliasNonGroupedTableAggregate(): Unit = {
+ val util = streamTestUtil()
+ val t = util.addTableSource[(Int, Long, String)]('a, 'b, 'c)
+
+ val top3 = new EmptyTableAggFunc
+ util.addFunction("top3", top3)
+ util.addFunction("Func0", Func0)
+
+ // Expression / Scala API
+ val resScala = t
+ .flatAggregate(top3('a) as ('d, 'e))
+ .select('*)
+
+ // String / Java API
+ val resJava = t
+ .flatAggregate("top3(a) as (d, e)")
+ .select("*")
+
+ verifyTableEquals(resJava, resScala)
+ }
+
+ @Test
+ def testAliasGroupedTableAggregate(): Unit = {
+ val util = streamTestUtil()
+ val t = util.addTableSource[(Int, Long, String)]('a, 'b, 'c)
+
+ val top3 = new EmptyTableAggFunc
+ util.addFunction("top3", top3)
+ util.addFunction("Func0", Func0)
+
+ // Expression / Scala API
+ val resScala = t
+ .groupBy('b)
+ .flatAggregate(top3('a) as ('d, 'e))
+ .select('*)
+
+ // String / Java API
+ val resJava = t
+ .groupBy("b")
+ .flatAggregate("top3(a) as (d, e)")
+ .select("*")
+
+ verifyTableEquals(resJava, resScala)
+ }
+}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/table/validation/TableAggregateValidationTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/table/validation/TableAggregateValidationTest.scala
new file mode 100644
index 0000000..ad22da3
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/table/validation/TableAggregateValidationTest.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.stream.table.validation
+
+import java.sql.Timestamp
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.ValidationException
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.planner.utils.{EmptyTableAggFunc, TableTestBase}
+import org.junit.Test
+
+class TableAggregateValidationTest extends TableTestBase {
+
+ @Test
+ def testInvalidParameterNumber(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage("Given parameters do not match any signature. \n" +
+ "Actual: (java.lang.Long, java.lang.Integer, java.lang.String) \n" +
+ "Expected: (int), (long, int), (long, java.sql.Timestamp)")
+
+ val util = streamTestUtil()
+ val table = util.addTableSource[(Long, Int, String)]('a, 'b, 'c)
+
+ val func = new EmptyTableAggFunc
+ table
+ .groupBy('c)
+ // must fail. func does not take 3 parameters
+ .flatAggregate(func('a, 'b, 'c))
+ .select('_1, '_2, '_3)
+ }
+
+ @Test
+ def testInvalidParameterType(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage("Given parameters do not match any signature. \n" +
+ "Actual: (java.lang.Long, java.lang.String) \n" +
+ "Expected: (int), (long, int), (long, java.sql.Timestamp)")
+
+ val util = streamTestUtil()
+ val table = util.addTableSource[(Long, Int, String)]('a, 'b, 'c)
+
+ val func = new EmptyTableAggFunc
+ table
+ .groupBy('c)
+ // must fail. func takes (Long, Int) or (Long, Timestamp), not (Long, String)
+ .flatAggregate(func('a, 'c))
+ .select('_1, '_2, '_3)
+ }
+
+ @Test
+ def testInvalidWithWindowProperties(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage("Window properties can only be used on windowed tables.")
+
+ val util = streamTestUtil()
+ val table = util.addTableSource[(Long, Int, Timestamp)]('a, 'b, 'c)
+
+ val func = new EmptyTableAggFunc
+ table
+ .groupBy('b)
+ .flatAggregate(func('a, 'b) as ('x, 'y))
+ .select('x.start, 'y)
+ }
+
+ @Test
+ def testInvalidWithAggregation(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage("Aggregate functions are not supported in the " +
+ "select right after the aggregate or flatAggregate operation.")
+
+ val util = streamTestUtil()
+ val table = util.addTableSource[(Long, Int, Timestamp)]('a, 'b, 'c)
+
+ val func = new EmptyTableAggFunc
+ table
+ .groupBy('b)
+ .flatAggregate(func('a, 'b) as ('x, 'y))
+ .select('x.count)
+ }
+
+ @Test
+ def testInvalidParameterWithAgg(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage(
+ "It's not allowed to use an aggregate function as input of another aggregate function")
+
+ val util = streamTestUtil()
+ val table = util.addTableSource[(Long, Int, Timestamp)]('a, 'b, 'c)
+
+ val func = new EmptyTableAggFunc
+ table
+ .groupBy('b)
+ // must fail. an aggregate function ('a.sum) is used as input of func
+ .flatAggregate(func('a.sum, 'c))
+ .select('_1, '_2, '_3)
+ }
+
+ @Test
+ def testInvalidAliasWithWrongNumber(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage("List of column aliases must have same degree as " +
+ "table; the returned table of function " +
+ "'org.apache.flink.table.planner.utils.EmptyTableAggFunc' has 2 columns, " +
+ "whereas alias list has 3 columns")
+
+ val util = streamTestUtil()
+ val table = util.addTableSource[(Long, Int, Timestamp)]('a, 'b, 'c)
+
+ val func = new EmptyTableAggFunc
+ table
+ .groupBy('b)
+ // must fail. alias with wrong number of fields
+ .flatAggregate(func('a, 'b) as ('a, 'b, 'c))
+ .select('*)
+ }
+
+ @Test
+ def testAliasWithNameConflict(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage("Ambiguous column name: b")
+
+ val util = streamTestUtil()
+ val table = util.addTableSource[(Long, Int, Timestamp)]('a, 'b, 'c)
+
+ val func = new EmptyTableAggFunc
+ table
+ .groupBy('b)
+ // must fail. alias with name conflict
+ .flatAggregate(func('a, 'b) as ('a, 'b))
+ .select('*)
+ }
+}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/harness/HarnessTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/harness/HarnessTestBase.scala
index d076171..38b8534 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/harness/HarnessTestBase.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/harness/HarnessTestBase.scala
@@ -17,9 +17,11 @@
*/
package org.apache.flink.table.planner.runtime.harness
+import org.apache.flink.api.common.time.Time
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.dag.Transformation
import org.apache.flink.api.java.functions.KeySelector
+import org.apache.flink.table.api.{StreamQueryConfig, TableConfig}
import org.apache.flink.configuration.{CheckpointingOptions, Configuration}
import org.apache.flink.contrib.streaming.state.RocksDBStateBackend
import org.apache.flink.runtime.state.StateBackend
@@ -103,6 +105,30 @@ class HarnessTestBase(mode: StateBackendMode) extends StreamingTestBase {
def dropWatermarks(elements: Array[AnyRef]): util.Collection[AnyRef] = {
elements.filter(e => !e.isInstanceOf[Watermark]).toList
}
+
+ /**
+ * Test class used to test min and max retention time.
+ */
+ class TestStreamQueryConfig(min: Time, max: Time) extends StreamQueryConfig {
+ override def getMinIdleStateRetentionTime: Long = min.toMilliseconds
+ override def getMaxIdleStateRetentionTime: Long = max.toMilliseconds
+ }
+
+ class TestTableConfig extends TableConfig {
+
+ private var minIdleStateRetentionTime = 0L
+
+ private var maxIdleStateRetentionTime = 0L
+
+ override def getMinIdleStateRetentionTime: Long = minIdleStateRetentionTime
+
+ override def getMaxIdleStateRetentionTime: Long = maxIdleStateRetentionTime
+
+ override def setIdleStateRetentionTime(minTime: Time, maxTime: Time): Unit = {
+ minIdleStateRetentionTime = minTime.toMilliseconds
+ maxIdleStateRetentionTime = maxTime.toMilliseconds
+ }
+ }
}
object HarnessTestBase {
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/harness/OverWindowHarnessTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/harness/OverWindowHarnessTest.scala
index 4c393d2..cfd9157 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/harness/OverWindowHarnessTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/harness/OverWindowHarnessTest.scala
@@ -24,7 +24,7 @@ import org.apache.flink.streaming.api.TimeCharacteristic
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.scala.internal.StreamTableEnvironmentImpl
-import org.apache.flink.table.api.{EnvironmentSettings, StreamQueryConfig, TableConfig, Types}
+import org.apache.flink.table.api.{EnvironmentSettings, Types}
import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.StateBackendMode
import org.apache.flink.table.runtime.util.BaseRowHarnessAssertor
import org.apache.flink.table.runtime.util.StreamRecordUtils.{baserow, binaryrow}
@@ -959,28 +959,4 @@ class OverWindowHarnessTest(mode: StateBackendMode) extends HarnessTestBase(mode
assertor.assertOutputEqualsSorted("result mismatch", expectedOutput, result)
testHarness.close()
}
-
- /**
- * Test class used to test min and max retention time.
- */
- class TestStreamQueryConfig(min: Time, max: Time) extends StreamQueryConfig {
- override def getMinIdleStateRetentionTime: Long = min.toMilliseconds
- override def getMaxIdleStateRetentionTime: Long = max.toMilliseconds
- }
-
- class TestTableConfig extends TableConfig {
-
- private var minIdleStateRetentionTime = 0L
-
- private var maxIdleStateRetentionTime = 0L
-
- override def getMinIdleStateRetentionTime: Long = minIdleStateRetentionTime
-
- override def getMaxIdleStateRetentionTime: Long = maxIdleStateRetentionTime
-
- override def setIdleStateRetentionTime(minTime: Time, maxTime: Time): Unit = {
- minIdleStateRetentionTime = minTime.toMilliseconds
- maxIdleStateRetentionTime = maxTime.toMilliseconds
- }
- }
}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/harness/TableAggregateHarnessTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/harness/TableAggregateHarnessTest.scala
new file mode 100644
index 0000000..826eea5
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/harness/TableAggregateHarnessTest.scala
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.runtime.harness
+
+import java.lang.{Integer => JInt}
+import java.util.concurrent.ConcurrentLinkedQueue
+
+import org.apache.flink.api.common.time.Time
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.scala.internal.StreamTableEnvironmentImpl
+import org.apache.flink.table.api.{EnvironmentSettings, Types}
+import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.StateBackendMode
+import org.apache.flink.table.planner.utils.{Top3WithMapView, Top3WithRetractInput}
+import org.apache.flink.table.runtime.util.BaseRowHarnessAssertor
+import org.apache.flink.table.runtime.util.StreamRecordUtils.{record, retractRecord}
+import org.apache.flink.types.Row
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+import org.junit.{Before, Test}
+
+import scala.collection.mutable
+
+@RunWith(classOf[Parameterized])
+class TableAggregateHarnessTest(mode: StateBackendMode) extends HarnessTestBase(mode) {
+
+ @Before
+ override def before(): Unit = {
+ super.before()
+ val setting = EnvironmentSettings.newInstance().useBlinkPlanner().inStreamingMode().build()
+ val config = new TestTableConfig
+ this.tEnv = StreamTableEnvironmentImpl.create(env, setting, config)
+ }
+
+ val data = new mutable.MutableList[(Int, Int)]
+ val queryConfig = new TestStreamQueryConfig(Time.seconds(2), Time.seconds(2))
+
+ @Test
+ def testTableAggregate(): Unit = {
+ val top3 = new Top3WithMapView
+ tEnv.registerFunction("top3", top3)
+ val source = env.fromCollection(data).toTable(tEnv, 'a, 'b)
+ val resultTable = source
+ .groupBy('a)
+ .flatAggregate(top3('b) as ('b1, 'b2))
+ .select('a, 'b1, 'b2)
+
+ val testHarness = createHarnessTester(
+ resultTable.toRetractStream[Row](queryConfig), "GroupTableAggregate")
+ val assertor = new BaseRowHarnessAssertor(Array(Types.INT, Types.INT, Types.INT))
+
+ testHarness.open()
+ val expectedOutput = new ConcurrentLinkedQueue[Object]()
+
+ // register cleanup timer with 3001
+ testHarness.setProcessingTime(1)
+
+ // input with two columns: key and value
+ testHarness.processElement(record(1: JInt, 1: JInt))
+ // output with three columns: key, value, value. The value is in the top3 of the key
+ expectedOutput.add(record(1: JInt, 1: JInt, 1: JInt))
+
+ testHarness.processElement(record(1: JInt, 2: JInt))
+ expectedOutput.add(retractRecord(1: JInt, 1: JInt, 1: JInt))
+ expectedOutput.add(record(1: JInt, 1: JInt, 1: JInt))
+ expectedOutput.add(record(1: JInt, 2: JInt, 2: JInt))
+
+ testHarness.processElement(record(1: JInt, 3: JInt))
+ expectedOutput.add(retractRecord(1: JInt, 1: JInt, 1: JInt))
+ expectedOutput.add(retractRecord(1: JInt, 2: JInt, 2: JInt))
+ expectedOutput.add(record(1: JInt, 1: JInt, 1: JInt))
+ expectedOutput.add(record(1: JInt, 2: JInt, 2: JInt))
+ expectedOutput.add(record(1: JInt, 3: JInt, 3: JInt))
+
+ testHarness.processElement(record(1: JInt, 2: JInt))
+ expectedOutput.add(retractRecord(1: JInt, 1: JInt, 1: JInt))
+ expectedOutput.add(retractRecord(1: JInt, 2: JInt, 2: JInt))
+ expectedOutput.add(retractRecord(1: JInt, 3: JInt, 3: JInt))
+ expectedOutput.add(record(1: JInt, 2: JInt, 2: JInt))
+ expectedOutput.add(record(1: JInt, 2: JInt, 2: JInt))
+ expectedOutput.add(record(1: JInt, 3: JInt, 3: JInt))
+
+ // ingest data with key value of 2
+ testHarness.processElement(record(2: JInt, 2: JInt))
+ expectedOutput.add(record(2: JInt, 2: JInt, 2: JInt))
+
+ // trigger cleanup timer
+ testHarness.setProcessingTime(3002)
+ testHarness.processElement(record(1: JInt, 2: JInt))
+ expectedOutput.add(record(1: JInt, 2: JInt, 2: JInt))
+
+ val result = testHarness.getOutput
+ assertor.assertOutputEqualsSorted("result mismatch", expectedOutput, result)
+ testHarness.close()
+ }
+
+ @Test
+ def testTableAggregateWithRetractInput(): Unit = {
+ val top3 = new Top3WithRetractInput
+ tEnv.registerFunction("top3", top3)
+ val source = env.fromCollection(data).toTable(tEnv, 'a, 'b)
+ val resultTable = source
+ .groupBy('a)
+ .select('b.sum as 'b)
+ .flatAggregate(top3('b) as ('b1, 'b2))
+ .select('b1, 'b2)
+
+ val testHarness = createHarnessTester(
+ resultTable.toRetractStream[Row](queryConfig), "GroupTableAggregate")
+ val assertor = new BaseRowHarnessAssertor(Array(Types.INT, Types.INT))
+
+ testHarness.open()
+ val expectedOutput = new ConcurrentLinkedQueue[Object]()
+
+ // register cleanup timer with 3001
+ testHarness.setProcessingTime(1)
+
+ // input with a single column: the value (records may also be retractions)
+ testHarness.processElement(record(1: JInt))
+ // output with two columns: value, value (no key, the flatAggregate is not grouped)
+ expectedOutput.add(record(1: JInt, 1: JInt))
+
+ testHarness.processElement(retractRecord(1: JInt))
+ expectedOutput.add(retractRecord(1: JInt, 1: JInt))
+
+ testHarness.processElement(record(3: JInt))
+ expectedOutput.add(record(3: JInt, 3: JInt))
+
+ testHarness.processElement(record(4: JInt))
+ expectedOutput.add(retractRecord(3: JInt, 3: JInt))
+ expectedOutput.add(record(3: JInt, 3: JInt))
+ expectedOutput.add(record(4: JInt, 4: JInt))
+
+ testHarness.processElement(retractRecord(3: JInt))
+ expectedOutput.add(retractRecord(3: JInt, 3: JInt))
+ expectedOutput.add(retractRecord(4: JInt, 4: JInt))
+ expectedOutput.add(record(4: JInt, 4: JInt))
+
+ testHarness.processElement(record(5: JInt))
+ expectedOutput.add(retractRecord(4: JInt, 4: JInt))
+ expectedOutput.add(record(4: JInt, 4: JInt))
+ expectedOutput.add(record(5: JInt, 5: JInt))
+
+ val result = testHarness.getOutput
+ assertor.assertOutputEqualsSorted("result mismatch", expectedOutput, result)
+ testHarness.close()
+ }
+}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/TableAggregateITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/TableAggregateITCase.scala
new file mode 100644
index 0000000..7b18b96
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/TableAggregateITCase.scala
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.runtime.stream.table
+
+import org.apache.flink.api.common.time.Time
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.ValidationException
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.StateBackendMode
+import org.apache.flink.table.planner.runtime.utils.TestData.tupleData3
+import org.apache.flink.table.planner.runtime.utils.{StreamingWithStateTestBase, TestingRetractSink}
+import org.apache.flink.table.planner.utils.{EmptyTableAggFuncWithoutEmit, TableAggSum, Top3, Top3WithMapView, Top3WithRetractInput}
+import org.apache.flink.types.Row
+import org.junit.Assert.assertEquals
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+import org.junit.{Before, Test}
+
+/**
+ * Tests of groupby (without window) table aggregations
+ */
+@RunWith(classOf[Parameterized])
+class TableAggregateITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode) {
+
+ @Before
+ override def before(): Unit = {
+ super.before()
+ tEnv.getConfig.setIdleStateRetentionTime(Time.hours(1), Time.hours(2))
+ }
+
+ @Test
+ def testGroupByFlatAggregate(): Unit = {
+ val top3 = new Top3
+
+ val resultTable = failingDataSource(tupleData3).toTable(tEnv, 'a, 'b, 'c)
+ .groupBy('b)
+ .flatAggregate(top3('a))
+ .select('b, 'f0, 'f1)
+ .as('category, 'v1, 'v2)
+
+ val sink = new TestingRetractSink()
+ resultTable.toRetractStream[Row].addSink(sink).setParallelism(1)
+ env.execute()
+
+ val expected = List(
+ "1,1,1",
+ "2,2,2",
+ "2,3,3",
+ "3,4,4",
+ "3,5,5",
+ "3,6,6",
+ "4,10,10",
+ "4,9,9",
+ "4,8,8",
+ "5,15,15",
+ "5,14,14",
+ "5,13,13",
+ "6,21,21",
+ "6,20,20",
+ "6,19,19"
+ ).sorted
+ assertEquals(expected, sink.getRetractResults.sorted)
+ }
+
+ @Test
+ def testNonkeyedFlatAggregate(): Unit = {
+
+ val top3 = new Top3
+ val source = failingDataSource(tupleData3).toTable(tEnv, 'a, 'b, 'c)
+ val resultTable = source
+ .flatAggregate(top3('a))
+ .select('f0, 'f1)
+ .as('v1, 'v2)
+
+ val sink = new TestingRetractSink()
+ resultTable.toRetractStream[Row].addSink(sink).setParallelism(1)
+ env.execute()
+
+ val expected = List(
+ "19,19",
+ "20,20",
+ "21,21"
+ ).sorted
+ assertEquals(expected, sink.getRetractResults.sorted)
+ }
+
+ @Test
+ def testAggregateAfterTableAggregate(): Unit = {
+ val top3 = new Top3
+
+ val resultTable = failingDataSource(tupleData3).toTable(tEnv, 'a, 'b, 'c)
+ .groupBy('b)
+ .flatAggregate(top3('a))
+ .select('b, 'f0, 'f1)
+ .as('category, 'v1, 'v2)
+ .groupBy('category)
+ .select('category, 'v1.max)
+
+ val sink = new TestingRetractSink()
+ resultTable.toRetractStream[Row].addSink(sink).setParallelism(1)
+ env.execute()
+
+ val expected = List(
+ "1,1",
+ "2,3",
+ "3,6",
+ "4,10",
+ "5,15",
+ "6,21"
+ ).sorted
+ assertEquals(expected, sink.getRetractResults.sorted)
+ }
+
+ @Test
+ def testGroupByFlatAggregateWithMapView(): Unit = {
+ val top3 = new Top3WithMapView
+
+ val resultTable = failingDataSource(tupleData3).toTable(tEnv, 'a, 'b, 'c)
+ .groupBy('b)
+ .flatAggregate(top3('a))
+ .select('b, 'f0, 'f1)
+ .as('category, 'v1, 'v2)
+
+ val sink = new TestingRetractSink()
+ resultTable.toRetractStream[Row].addSink(sink).setParallelism(1)
+ env.execute()
+
+ val expected = List(
+ "1,1,1",
+ "2,2,2",
+ "2,3,3",
+ "3,4,4",
+ "3,5,5",
+ "3,6,6",
+ "4,10,10",
+ "4,9,9",
+ "4,8,8",
+ "5,15,15",
+ "5,14,14",
+ "5,13,13",
+ "6,21,21",
+ "6,20,20",
+ "6,19,19"
+ ).sorted
+ assertEquals(expected, sink.getRetractResults.sorted)
+ }
+
+ @Test
+ def testInputWithRetraction(): Unit = {
+
+ val top3 = new Top3WithRetractInput
+ val source = failingDataSource(tupleData3).toTable(tEnv, 'a, 'b, 'c)
+ val resultTable = source
+ .groupBy('b)
+ .select('b, 'a.sum as 'a)
+ .flatAggregate(top3('a) as ('v1, 'v2))
+ .select('v1, 'v2)
+
+ val sink = new TestingRetractSink()
+ resultTable.toRetractStream[Row].addSink(sink).setParallelism(1)
+ env.execute()
+
+ val expected = List(
+ "111,111",
+ "65,65",
+ "34,34"
+ ).sorted
+ assertEquals(expected, sink.getRetractResults.sorted)
+ }
+
+ @Test
+ def testInternalAccumulatorType(): Unit = {
+ val tableAggSum = new TableAggSum
+ val source = failingDataSource(tupleData3).toTable(tEnv, 'a, 'b, 'c)
+ val resultTable = source
+ .groupBy('b)
+ .flatAggregate(tableAggSum('a) as 'sum)
+ .select('b, 'sum)
+
+ val sink = new TestingRetractSink()
+ resultTable.toRetractStream[Row].addSink(sink).setParallelism(1)
+ env.execute()
+
+ val expected = List("6,111", "6,111", "5,65", "5,65", "4,34", "4,34", "3,15", "3,15",
+ "2,5", "2,5", "1,1", "1,1").sorted
+ assertEquals(expected, sink.getRetractResults.sorted)
+ }
+
+ @Test
+ def testTableAggFunctionWithoutRetractionMethod(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage("Function class 'org.apache.flink.table.planner.utils.Top3'" +
+ " does not implement at least one method named 'retract' which is public, " +
+ "not abstract and (in case of table functions) not static.")
+
+ val top3 = new Top3
+ val source = env.fromCollection(tupleData3).toTable(tEnv, 'a, 'b, 'c)
+ source
+ .groupBy('b)
+ .select('b, 'a.sum as 'a)
+ .flatAggregate(top3('a) as ('v1, 'v2))
+ .select('v1, 'v2)
+ .toRetractStream[Row]
+
+ env.execute()
+ }
+
+ @Test
+ def testTableAggFunctionWithoutEmitValueMethod(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage("Function class " +
+ "'org.apache.flink.table.planner.utils.EmptyTableAggFuncWithoutEmit' does not " +
+ "implement at least one method named 'emitValue' which is public, " +
+ "not abstract and (in case of table functions) not static.")
+
+ val func = new EmptyTableAggFuncWithoutEmit
+ val source = env.fromCollection(tupleData3).toTable(tEnv, 'a, 'b, 'c)
+ source
+ .flatAggregate(func('a) as ('v1, 'v2))
+ .select('v1, 'v2)
+ .toRetractStream[Row]
+
+ env.execute()
+ }
+}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala
index b117a11..23d78b6 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala
@@ -38,7 +38,7 @@ import org.apache.flink.table.dataformat.BaseRow
import org.apache.flink.table.delegation.{Executor, ExecutorFactory, PlannerFactory}
import org.apache.flink.table.expressions.Expression
import org.apache.flink.table.factories.ComponentFactoryService
-import org.apache.flink.table.functions.{AggregateFunction, ScalarFunction, TableFunction, UserFunctionsTypeHelper}
+import org.apache.flink.table.functions.{AggregateFunction, ScalarFunction, TableAggregateFunction, TableFunction, UserDefinedAggregateFunction, UserFunctionsTypeHelper}
import org.apache.flink.table.operations.{CatalogSinkModifyOperation, ModifyOperation, QueryOperation}
import org.apache.flink.table.planner.calcite.CalciteConfig
import org.apache.flink.table.planner.delegation.PlannerBase
@@ -549,6 +549,15 @@ abstract class TableTestUtil(
name: String,
function: AggregateFunction[T, ACC]): Unit = testingTableEnv.registerFunction(name, function)
+ /**
+ * Registers a [[TableAggregateFunction]] under given name into the TableEnvironment's catalog.
+ */
+ def addFunction[T: TypeInformation, ACC: TypeInformation](
+ name: String,
+ function: TableAggregateFunction[T, ACC]): Unit = {
+ testingTableEnv.registerFunction(name, function)
+ }
+
def verifyPlan(): Unit = {
doVerifyPlan(
SqlExplainLevel.EXPPLAN_ATTRIBUTES,
@@ -616,6 +625,13 @@ abstract class ScalaTableTestUtil(
def addFunction[T: TypeInformation, ACC: TypeInformation](
name: String,
function: AggregateFunction[T, ACC]): Unit = tableEnv.registerFunction(name, function)
+
+ /**
+ * Registers a [[TableAggregateFunction]] under given name into the TableEnvironment's catalog.
+ */
+ def addFunction[T: TypeInformation, ACC: TypeInformation](
+ name: String,
+ function: TableAggregateFunction[T, ACC]): Unit = tableEnv.registerFunction(name, function)
}
abstract class JavaTableTestUtil(
@@ -645,6 +661,13 @@ abstract class JavaTableTestUtil(
def addFunction[T: TypeInformation, ACC: TypeInformation](
name: String,
function: AggregateFunction[T, ACC]): Unit = tableEnv.registerFunction(name, function)
+
+ /**
+ * Registers a [[TableAggregateFunction]] under given name into the TableEnvironment's catalog.
+ */
+ def addFunction[T: TypeInformation, ACC: TypeInformation](
+ name: String,
+ function: TableAggregateFunction[T, ACC]): Unit = tableEnv.registerFunction(name, function)
}
/**
@@ -902,6 +925,21 @@ class TestingTableEnvironment private(
def registerFunction[T: TypeInformation, ACC: TypeInformation](
name: String,
f: AggregateFunction[T, ACC]): Unit = {
+ registerUserDefinedAggregateFunction(name, f)
+ }
+
+ // just for testing, remove this method once
+ // `<T, ACC> void registerFunction(String name, TableAggregateFunction<T, ACC> tableAggFunc);`
+ // is added to TableEnvironment
+ def registerFunction[T: TypeInformation, ACC: TypeInformation](
+ name: String,
+ f: TableAggregateFunction[T, ACC]): Unit = {
+ registerUserDefinedAggregateFunction(name, f)
+ }
+
+ private def registerUserDefinedAggregateFunction[T: TypeInformation, ACC: TypeInformation](
+ name: String,
+ f: UserDefinedAggregateFunction[T, ACC]): Unit = {
val typeInfo = UserFunctionsTypeHelper
.getReturnTypeOfAggregateFunction(f, implicitly[TypeInformation[T]])
val accTypeInfo = UserFunctionsTypeHelper
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/utils/UserDefinedTableAggFunctions.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/utils/UserDefinedTableAggFunctions.scala
new file mode 100644
index 0000000..ab9b37a
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/utils/UserDefinedTableAggFunctions.scala
@@ -0,0 +1,298 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.utils
+
+import java.lang.{Integer => JInt, Iterable => JIterable}
+import java.sql.Timestamp
+import java.util
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
+import org.apache.flink.table.api.Types
+import org.apache.flink.table.api.dataview.MapView
+import org.apache.flink.table.dataformat.GenericRow
+import org.apache.flink.table.functions.TableAggregateFunction
+import org.apache.flink.table.runtime.typeutils.BaseRowTypeInfo
+import org.apache.flink.table.types.logical.IntType
+import org.apache.flink.util.Collector
+
+import scala.collection.mutable.ListBuffer
+
+/**** Note: Functions in this file suffer from performance problems. Only use them in tests. ****/
+
+
+/****** Function for testing basic functionality of TableAggregateFunction ******/
+
+/**
+ * Accumulator used by [[Top3]]. All fields start out unset and are initialised in
+ * `Top3.createAccumulator()`.
+ */
+class Top3Accum {
+ // maps each stored value to the number of times it is currently stored
+ var data: util.HashMap[JInt, JInt] = _
+ // total number of stored values, duplicates included
+ var size: JInt = _
+ // smallest value currently stored
+ var smallest: JInt = _
+}
+
+/**
+ * Test table aggregate function that keeps the three largest input values (duplicates are
+ * kept) and emits every stored value `v` as the tuple `(v, v)`.
+ *
+ * The function does not define a `retract` method, so it cannot consume retraction input.
+ */
+class Top3 extends TableAggregateFunction[JTuple2[JInt, JInt], Top3Accum] {
+
+  /** Creates an empty accumulator: empty map, size 0 and `smallest` set to Int.MaxValue. */
+  override def createAccumulator(): Top3Accum = {
+    val acc = new Top3Accum
+    acc.data = new util.HashMap[JInt, JInt]()
+    acc.size = 0
+    acc.smallest = Integer.MAX_VALUE
+    acc
+  }
+
+  /** Stores one more occurrence of `v` and increases the accumulator size. */
+  def add(acc: Top3Accum, v: Int): Unit = {
+    var cnt = acc.data.get(v)
+    acc.size += 1
+    // the map returns null for a value that has not been stored yet
+    if (cnt == null) {
+      cnt = 0
+    }
+    acc.data.put(v, cnt + 1)
+  }
+
+  /** Removes one occurrence of `v`; does nothing if `v` is not stored. */
+  def delete(acc: Top3Accum, v: Int): Unit = {
+    if (acc.data.containsKey(v)) {
+      acc.size -= 1
+      val cnt = acc.data.get(v) - 1
+      if (cnt == 0) {
+        acc.data.remove(v)
+      } else {
+        acc.data.put(v, cnt)
+      }
+    }
+  }
+
+  /** Recomputes `smallest` by scanning all stored keys. */
+  def updateSmallest(acc: Top3Accum): Unit = {
+    acc.smallest = Integer.MAX_VALUE
+    val keys = acc.data.keySet().iterator()
+    while (keys.hasNext) {
+      val key = keys.next()
+      if (key < acc.smallest) {
+        acc.smallest = key
+      }
+    }
+  }
+
+  /**
+   * Accumulates `v`: the value is stored while fewer than three values are held; afterwards
+   * it replaces one occurrence of the smallest stored value, but only if it is larger.
+   */
+  def accumulate(acc: Top3Accum, v: Int): Unit = {
+    if (acc.size == 0) {
+      acc.size = 1
+      acc.smallest = v
+      acc.data.put(v, 1)
+    } else if (acc.size < 3) {
+      add(acc, v)
+      if (v < acc.smallest) {
+        acc.smallest = v
+      }
+    } else if (v > acc.smallest) {
+      delete(acc, acc.smallest)
+      add(acc, v)
+      updateSmallest(acc)
+    }
+  }
+
+  /** Merges other accumulators into `acc` by re-accumulating every stored occurrence. */
+  def merge(acc: Top3Accum, its: JIterable[Top3Accum]): Unit = {
+    val iter = its.iterator()
+    while (iter.hasNext) {
+      val map = iter.next().data
+      val mapIter = map.entrySet().iterator()
+      while (mapIter.hasNext) {
+        val entry = mapIter.next()
+        // the map value is the occurrence count of the key
+        for (_ <- 0 until entry.getValue) {
+          accumulate(acc, entry.getKey)
+        }
+      }
+    }
+  }
+
+  /** Emits `(v, v)` once per stored occurrence of every value `v`. */
+  def emitValue(acc: Top3Accum, out: Collector[JTuple2[JInt, JInt]]): Unit = {
+    val entries = acc.data.entrySet().iterator()
+    while (entries.hasNext) {
+      val pair = entries.next()
+      for (_ <- 0 until pair.getValue) {
+        out.collect(JTuple2.of(pair.getKey, pair.getKey))
+      }
+    }
+  }
+}
+
+/****** Function for testing MapView ******/
+
+/**
+ * Accumulator used by [[Top3WithMapView]]; same layout as [[Top3Accum]] but the counts are
+ * kept in a [[MapView]].
+ */
+class Top3WithMapViewAccum {
+ // maps each stored value to the number of times it is currently stored
+ var data: MapView[JInt, JInt] = _
+ // total number of stored values, duplicates included
+ var size: JInt = _
+ // smallest value currently stored
+ var smallest: JInt = _
+}
+
+/**
+ * Variant of [[Top3]] whose accumulator keeps the value counts in a [[MapView]]. It keeps the
+ * three largest input values (duplicates are kept) and emits every stored value `v` as the
+ * tuple `(v, v)`.
+ */
+class Top3WithMapView extends TableAggregateFunction[JTuple2[JInt, JInt], Top3WithMapViewAccum] {
+
+  /** Creates an empty accumulator: empty view, size 0 and `smallest` set to Int.MaxValue. */
+  override def createAccumulator(): Top3WithMapViewAccum = {
+    val acc = new Top3WithMapViewAccum
+    acc.data = new MapView(Types.INT, Types.INT)
+    acc.size = 0
+    acc.smallest = Integer.MAX_VALUE
+    acc
+  }
+
+  /** Stores one more occurrence of `v` and increases the accumulator size. */
+  def add(acc: Top3WithMapViewAccum, v: Int): Unit = {
+    var cnt = acc.data.get(v)
+    acc.size += 1
+    // the view returns null for a value that has not been stored yet
+    if (cnt == null) {
+      cnt = 0
+    }
+    acc.data.put(v, cnt + 1)
+  }
+
+  /** Removes one occurrence of `v`; does nothing if `v` is not stored. */
+  def delete(acc: Top3WithMapViewAccum, v: Int): Unit = {
+    if (acc.data.contains(v)) {
+      acc.size -= 1
+      val cnt = acc.data.get(v) - 1
+      if (cnt == 0) {
+        acc.data.remove(v)
+      } else {
+        acc.data.put(v, cnt)
+      }
+    }
+  }
+
+  /** Recomputes `smallest` by scanning all stored entries. */
+  def updateSmallest(acc: Top3WithMapViewAccum): Unit = {
+    acc.smallest = Integer.MAX_VALUE
+    val keys = acc.data.iterator
+    while (keys.hasNext) {
+      val pair = keys.next()
+      if (pair.getKey < acc.smallest) {
+        acc.smallest = pair.getKey
+      }
+    }
+  }
+
+  /**
+   * Accumulates `v`: the value is stored while fewer than three values are held; afterwards
+   * it replaces one occurrence of the smallest stored value, but only if it is larger.
+   */
+  def accumulate(acc: Top3WithMapViewAccum, v: Int): Unit = {
+    if (acc.size == 0) {
+      acc.size = 1
+      acc.smallest = v
+      acc.data.put(v, 1)
+    } else if (acc.size < 3) {
+      add(acc, v)
+      if (v < acc.smallest) {
+        acc.smallest = v
+      }
+    } else if (v > acc.smallest) {
+      delete(acc, acc.smallest)
+      add(acc, v)
+      updateSmallest(acc)
+    }
+  }
+
+  /** Emits `(v, v)` once per stored occurrence of every value `v`. */
+  def emitValue(acc: Top3WithMapViewAccum, out: Collector[JTuple2[JInt, JInt]]): Unit = {
+    val keys = acc.data.iterator
+    while (keys.hasNext) {
+      val pair = keys.next()
+      // the entry value is the occurrence count of the key
+      for (_ <- 0 until pair.getValue) {
+        out.collect(JTuple2.of(pair.getKey, pair.getKey))
+      }
+    }
+  }
+}
+
+/****** Function for testing retract input ******/
+
+/** Accumulator used by [[Top3WithRetractInput]]. */
+class Top3WithRetractInputAcc {
+ // all values accumulated and not yet retracted; initialised in createAccumulator()
+ var data: ListBuffer[Int] = _
+}
+
+/**
+ * Test table aggregate function that accepts retraction input. It keeps every accumulated
+ * value and emits the (up to) three largest ones, each value `v` as the tuple `(v, v)`.
+ */
+class Top3WithRetractInput
+  extends TableAggregateFunction[JTuple2[JInt, JInt], Top3WithRetractInputAcc] {
+
+  /** Creates an accumulator holding an empty buffer. */
+  override def createAccumulator(): Top3WithRetractInputAcc = {
+    val acc = new Top3WithRetractInputAcc
+    acc.data = new ListBuffer[Int]
+    acc
+  }
+
+  /** Appends `v` to the stored values. */
+  def accumulate(acc: Top3WithRetractInputAcc, v: Int): Unit = {
+    acc.data.append(v)
+  }
+
+  /**
+   * Removes one occurrence of `v` from the stored values. A value that is not stored is
+   * ignored instead of failing with an IndexOutOfBoundsException from `remove(-1)`.
+   */
+  def retract(acc: Top3WithRetractInputAcc, v: Int): Unit = {
+    val idx = acc.data.indexOf(v)
+    if (idx >= 0) {
+      acc.data.remove(idx)
+    }
+  }
+
+  /** Sorts the stored values in descending order and emits the first three (or fewer). */
+  def emitValue(acc: Top3WithRetractInputAcc, out: Collector[JTuple2[JInt, JInt]]): Unit = {
+    // the sorted buffer is written back to the accumulator, as before
+    acc.data = acc.data.sorted.reverse
+    acc.data.take(3).foreach(v => out.collect(JTuple2.of[JInt, JInt](v, v)))
+  }
+}
+
+/****** Function for testing internal accumulator type ******/
+
+/**
+ * Test table aggregate function that sums its integer input in a single-field [[GenericRow]]
+ * accumulator and emits the current sum twice.
+ */
+class TableAggSum extends TableAggregateFunction[JInt, GenericRow] {
+
+  /** Creates a one-field row whose only field holds the initial sum 0. */
+  override def createAccumulator(): GenericRow = {
+    val sumRow = new GenericRow(1)
+    sumRow.setInt(0, 0)
+    sumRow
+  }
+
+  /** Adds `v` to the sum stored in field 0 of the accumulator. */
+  def accumulate(acc: GenericRow, v: Int): Unit = {
+    val updated = acc.getInt(0) + v
+    acc.setInt(0, updated)
+  }
+
+  /** Emits the current sum as two identical records. */
+  def emitValue(acc: GenericRow, out: Collector[JInt]): Unit = {
+    val total = acc.getInt(0)
+    // output two records
+    (0 until 2).foreach(_ => out.collect(total))
+  }
+
+  /** Declares the accumulator as a row with a single INT field. */
+  override def getAccumulatorType: TypeInformation[GenericRow] = {
+    val rowType = new BaseRowTypeInfo(new IntType())
+    rowType.asInstanceOf[TypeInformation[GenericRow]]
+  }
+}
+
+/**
+ * Test function for plan test.
+ *
+ * It declares several no-op `accumulate` overloads but no `emitValue` method, so it can also
+ * be used to check the validation error raised for a missing `emitValue`.
+ */
+class EmptyTableAggFuncWithoutEmit extends TableAggregateFunction[JTuple2[JInt, JInt], Top3Accum] {
+
+ // returns an accumulator whose fields are left uninitialised; no method here reads them
+ override def createAccumulator(): Top3Accum = new Top3Accum
+
+ // no-op overload taking (Long, Timestamp) arguments
+ def accumulate(acc: Top3Accum, category: Long, value: Timestamp): Unit = {}
+
+ // no-op overload taking (Long, Int) arguments
+ def accumulate(acc: Top3Accum, category: Long, value: Int): Unit = {}
+
+ // no-op overload taking a single Int argument
+ def accumulate(acc: Top3Accum, value: Int): Unit = {}
+}
+
+/**
+ * No-op table aggregate function: adds an empty `emitValue` to the `accumulate` overloads
+ * inherited from [[EmptyTableAggFuncWithoutEmit]].
+ */
+class EmptyTableAggFunc extends EmptyTableAggFuncWithoutEmit {
+
+ // emits nothing
+ def emitValue(acc: Top3Accum, out: Collector[JTuple2[JInt, JInt]]): Unit = {}
+}
+
+/**
+ * No-op table aggregate function whose result type is a single integer instead of a tuple.
+ */
+class EmptyTableAggFuncWithIntResultType extends TableAggregateFunction[JInt, Top3Accum] {
+
+ // returns an accumulator whose fields are left uninitialised; no method here reads them
+ override def createAccumulator(): Top3Accum = new Top3Accum
+
+ // no-op
+ def accumulate(acc: Top3Accum, value: Int): Unit = {}
+
+ // emits nothing
+ def emitValue(acc: Top3Accum, out: Collector[JInt]): Unit = {}
+}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/expressions/aggregations.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
index fd1dea1..f6d0725 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
@@ -358,10 +358,10 @@ case class VarSamp(child: PlannerExpression) extends Aggregation {
}
/**
- * Expression for calling a user-defined aggregate function.
+ * Expression for calling a user-defined aggregate or table aggregate function.
*/
case class AggFunctionCall(
- val aggregateFunction: UserDefinedAggregateFunction[_, _],
+ aggregateFunction: UserDefinedAggregateFunction[_, _],
resultTypeInfo: TypeInformation[_],
accTypeInfo: TypeInformation[_],
args: Seq[PlannerExpression])
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala
index 3aa021d..4859c2a 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala
@@ -35,8 +35,8 @@ import org.apache.flink.table.functions.utils.AggSqlFunction.{createOperandTypeC
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
/**
- * Calcite wrapper for user-defined aggregate functions. Current, the aggregate function can be an
- * [[AggregateFunction]] or a [[TableAggregateFunction]]
+ * Calcite wrapper for user-defined aggregate functions. Currently, the aggregate function can be
+ * an [[AggregateFunction]] or a [[TableAggregateFunction]]
*
* @param name function name (used by SQL parser)
* @param displayName name to be displayed in operator name
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/logical/rel/TableAggregate.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/logical/rel/TableAggregate.scala
index cf52db1..b4bca27 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/logical/rel/TableAggregate.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/logical/rel/TableAggregate.scala
@@ -24,7 +24,7 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.{Aggregate, AggregateCall}
import org.apache.calcite.rel.logical.LogicalAggregate
-import org.apache.calcite.rel.{RelNode, SingleRel}
+import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
import org.apache.calcite.util.{ImmutableBitSet, Pair}
import org.apache.flink.table.plan.nodes.CommonTableAggregate
@@ -70,4 +70,10 @@ abstract class TableAggregate(
aggCalls
)
}
+
+ /**
+ * Adds the grouping set (as "group") and the aggregate calls (as "tableAggregate") to the
+ * terms written by the parent implementation.
+ */
+ override def explainTerms(pw: RelWriter): RelWriter = {
+ super.explainTerms(pw)
+ .item("group", groupSet)
+ .item("tableAggregate", aggCalls)
+ }
}
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/generated/AggsHandleFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/generated/AggsHandleFunction.java
index 953667e..2a6e954 100644
--- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/generated/AggsHandleFunction.java
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/generated/AggsHandleFunction.java
@@ -18,10 +18,8 @@
package org.apache.flink.table.runtime.generated;
-import org.apache.flink.api.common.functions.AggregateFunction;
-import org.apache.flink.api.common.functions.Function;
import org.apache.flink.table.dataformat.BaseRow;
-import org.apache.flink.table.runtime.dataview.StateDataViewStore;
+import org.apache.flink.table.functions.AggregateFunction;
/**
* The base class for handling aggregate functions.
@@ -30,75 +28,11 @@ import org.apache.flink.table.runtime.dataview.StateDataViewStore;
*
* <p>It is the entry point for aggregate operators to operate all {@link AggregateFunction}s.
*/
-public interface AggsHandleFunction extends Function {
-
- /**
- * Initialization method for the function. It is called before the actual working methods.
- */
- void open(StateDataViewStore store) throws Exception;
-
- /**
- * Accumulates the input values to the accumulators.
- * @param input input values bundled in a row
- */
- void accumulate(BaseRow input) throws Exception;
-
- /**
- * Retracts the input values from the accumulators.
- * @param input input values bundled in a row
- */
- void retract(BaseRow input) throws Exception;
-
- /**
- * Merges the other accumulators into current accumulators.
- *
- * @param accumulators The other row of accumulators
- */
- void merge(BaseRow accumulators) throws Exception;
-
- /**
- * Set the current accumulators (saved in a row) which contains the current aggregated results.
- * In streaming: accumulators are store in the state, we need to restore aggregate buffers from state.
- * In batch: accumulators are store in the hashMap, we need to restore aggregate buffers from hashMap.
- *
- * @param accumulators current accumulators
- */
- void setAccumulators(BaseRow accumulators) throws Exception;
-
- /**
- * Resets all the accumulators.
- */
- void resetAccumulators() throws Exception;
-
- /**
- * Gets the current accumulators (saved in a row) which contains the current
- * aggregated results.
- * @return the current accumulators
- */
- BaseRow getAccumulators() throws Exception;
-
- /**
- * Initializes the accumulators and save them to a accumulators row.
- *
- * @return a row of accumulators which contains the aggregated results
- */
- BaseRow createAccumulators() throws Exception;
+public interface AggsHandleFunction extends AggsHandleFunctionBase {
/**
* Gets the result of the aggregation from the current accumulators.
* @return the final result (saved in a row) of the current accumulators.
*/
BaseRow getValue() throws Exception;
-
- /**
- * Cleanup for the retired accumulators state.
- */
- void cleanup() throws Exception;
-
- /**
- * Tear-down method for this function. It can be used for clean up work.
- * By default, this method does nothing.
- */
- void close() throws Exception;
-
}
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/generated/AggsHandleFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/generated/AggsHandleFunctionBase.java
similarity index 87%
copy from flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/generated/AggsHandleFunction.java
copy to flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/generated/AggsHandleFunctionBase.java
index 953667e..88b2fd8 100644
--- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/generated/AggsHandleFunction.java
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/generated/AggsHandleFunctionBase.java
@@ -18,19 +18,22 @@
package org.apache.flink.table.runtime.generated;
-import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.Function;
import org.apache.flink.table.dataformat.BaseRow;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.functions.TableAggregateFunction;
import org.apache.flink.table.runtime.dataview.StateDataViewStore;
/**
- * The base class for handling aggregate functions.
+ * The base class for handling aggregate or table aggregate functions.
*
- * <p>It is code generated to handle all {@link AggregateFunction}s together in an aggregation.
+ * <p>It is code generated to handle all {@link AggregateFunction}s and
+ * {@link TableAggregateFunction}s together in an aggregation.
*
- * <p>It is the entry point for aggregate operators to operate all {@link AggregateFunction}s.
+ * <p>It is the entry point for aggregate operators to operate all {@link AggregateFunction}s and
+ * {@link TableAggregateFunction}s.
*/
-public interface AggsHandleFunction extends Function {
+public interface AggsHandleFunctionBase extends Function {
/**
* Initialization method for the function. It is called before the actual working methods.
@@ -85,12 +88,6 @@ public interface AggsHandleFunction extends Function {
BaseRow createAccumulators() throws Exception;
/**
- * Gets the result of the aggregation from the current accumulators.
- * @return the final result (saved in a row) of the current accumulators.
- */
- BaseRow getValue() throws Exception;
-
- /**
* Cleanup for the retired accumulators state.
*/
void cleanup() throws Exception;
@@ -100,5 +97,4 @@ public interface AggsHandleFunction extends Function {
* By default, this method does nothing.
*/
void close() throws Exception;
-
}
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/generated/GeneratedTableAggsHandleFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/generated/GeneratedTableAggsHandleFunction.java
new file mode 100644
index 0000000..f45dd7c
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/generated/GeneratedTableAggsHandleFunction.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.generated;
+
+/**
+ * Describes a generated {@link TableAggsHandleFunction}.
+ */
+public class GeneratedTableAggsHandleFunction extends GeneratedClass<TableAggsHandleFunction> {
+
+ private static final long serialVersionUID = 1L;
+
+ /**
+ * Creates the descriptor; all three arguments are passed unchanged to {@link GeneratedClass}.
+ *
+ * @param className name of the generated class
+ * @param code code of the generated class
+ * @param references objects referenced by the generated code
+ */
+ public GeneratedTableAggsHandleFunction(String className, String code, Object[] references) {
+ super(className, code, references);
+ }
+}
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/generated/TableAggsHandleFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/generated/TableAggsHandleFunction.java
new file mode 100644
index 0000000..e56a26a
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/generated/TableAggsHandleFunction.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.generated;
+
+import org.apache.flink.table.dataformat.BaseRow;
+import org.apache.flink.table.functions.TableAggregateFunction;
+import org.apache.flink.util.Collector;
+
+/**
+ * The base class for handling table aggregate functions.
+ *
+ * <p>It is code generated to handle all {@link TableAggregateFunction}s together in an aggregation.
+ *
+ * <p>It is the entry point for aggregate operators to operate all {@link TableAggregateFunction}s.
+ */
+public interface TableAggsHandleFunction extends AggsHandleFunctionBase {
+
+ /**
+ * Emit the result of the table aggregation through the collector.
+ *
+ * @param out the collector used to emit records.
+ * @param currentKey the current group key.
+ * @param isRetract the retraction flag which indicates whether to emit retract values.
+ */
+ void emitValue(Collector<BaseRow> out, BaseRow currentKey, boolean isRetract) throws Exception;
+}
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java
new file mode 100644
index 0000000..69c9859
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.aggregate;
+
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.table.dataformat.BaseRow;
+import org.apache.flink.table.runtime.dataview.PerKeyStateDataViewStore;
+import org.apache.flink.table.runtime.functions.KeyedProcessFunctionWithCleanupState;
+import org.apache.flink.table.runtime.generated.GeneratedTableAggsHandleFunction;
+import org.apache.flink.table.runtime.generated.TableAggsHandleFunction;
+import org.apache.flink.table.runtime.typeutils.BaseRowTypeInfo;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.util.Collector;
+
+import static org.apache.flink.table.dataformat.util.BaseRowUtil.isAccumulateMsg;
+
+/**
+ * Keyed process function for the group-by (without window) table aggregate; it delegates to a generated {@link TableAggsHandleFunction}.
+ */
+public class GroupTableAggFunction extends KeyedProcessFunctionWithCleanupState<BaseRow, BaseRow, BaseRow> {
+
+ private static final long serialVersionUID = 1L;
+
+ /**
+ * The code generated function used to handle table aggregates; instantiated in {@code open}.
+ */
+ private final GeneratedTableAggsHandleFunction genAggsHandler;
+
+ /**
+ * The accumulator types, used to build the type information of the accumulator state.
+ */
+ private final LogicalType[] accTypes;
+
+ /**
+ * Used to check on the accumulators whether the record count of the current key is zero.
+ */
+ private final RecordCounter recordCounter;
+
+ /**
+ * Whether this operator emits a retraction of the previous result before processing a non-first row of a key.
+ */
+ private final boolean generateRetraction;
+
+ // handler instance for all table aggregates, instantiated from genAggsHandler in open()
+ private transient TableAggsHandleFunction function = null;
+
+ // value state storing the accumulators as a single row
+ private transient ValueState<BaseRow> accState = null;
+
+ /**
+ * Creates a {@link GroupTableAggFunction}.
+ *
+ * @param minRetentionTime minimal state idle retention time, passed to the superclass.
+ * @param maxRetentionTime maximal state idle retention time, passed to the superclass.
+ * @param genAggsHandler The code generated function used to handle table aggregates.
+ * @param accTypes The accumulator types.
+ * @param indexOfCountStar The index of COUNT(*) in the aggregates.
+ * -1 when the input doesn't contain COUNT(*), i.e. doesn't contain retraction messages.
+ * We make sure there is a COUNT(*) if the input stream contains retraction.
+ * @param generateRetraction Whether this operator will generate retraction.
+ */
+ public GroupTableAggFunction(
+ long minRetentionTime,
+ long maxRetentionTime,
+ GeneratedTableAggsHandleFunction genAggsHandler,
+ LogicalType[] accTypes,
+ int indexOfCountStar,
+ boolean generateRetraction) {
+ super(minRetentionTime, maxRetentionTime);
+ this.genAggsHandler = genAggsHandler;
+ this.accTypes = accTypes;
+ this.recordCounter = RecordCounter.of(indexOfCountStar);
+ this.generateRetraction = generateRetraction;
+ }
+
+ @Override
+ public void open(Configuration parameters) throws Exception {
+ super.open(parameters);
+ // instantiate the generated handler with the user code class loader and open it
+ function = genAggsHandler.newInstance(getRuntimeContext().getUserCodeClassLoader());
+ function.open(new PerKeyStateDataViewStore(getRuntimeContext()));
+
+ BaseRowTypeInfo accTypeInfo = new BaseRowTypeInfo(accTypes);
+ ValueStateDescriptor<BaseRow> accDesc = new ValueStateDescriptor<>("accState", accTypeInfo);
+ accState = getRuntimeContext().getState(accDesc);
+
+ initCleanupTimeState("GroupTableAggregateCleanupTime");
+ }
+
+ @Override
+ public void processElement(BaseRow input, Context ctx, Collector<BaseRow> out) throws Exception {
+ long currentTime = ctx.timerService().currentProcessingTime();
+ // register state-cleanup timer based on the current processing time
+ registerProcessingCleanupTimer(ctx, currentTime);
+
+ BaseRow currentKey = ctx.getCurrentKey();
+
+ boolean firstRow;
+ BaseRow accumulators = accState.value();
+ if (null == accumulators) {
+ firstRow = true;
+ accumulators = function.createAccumulators();
+ } else {
+ firstRow = false;
+ }
+
+ // hand the accumulators to the handler before emitting or updating anything
+ function.setAccumulators(accumulators);
+
+ if (!firstRow && generateRetraction) { // emit with the retract flag set before the accumulators change
+ function.emitValue(out, currentKey, true);
+ }
+
+ // update the accumulators with the input row
+ if (isAccumulateMsg(input)) {
+ // accumulate input
+ function.accumulate(input);
+ } else {
+ // retract input
+ function.retract(input);
+ }
+
+ // fetch the updated accumulators from the handler
+ accumulators = function.getAccumulators();
+ if (!recordCounter.recordCountIsZero(accumulators)) {
+ function.emitValue(out, currentKey, false);
+
+ // update the state
+ accState.update(accumulators);
+
+ } else {
+ // record count is zero for this key: clear the accumulator state
+ accState.clear();
+ // cleanup dataview under current key
+ function.cleanup();
+ }
+ }
+
+ @Override
+ public void onTimer(long timestamp, OnTimerContext ctx, Collector<BaseRow> out) throws Exception {
+ if (stateCleaningEnabled) { // timers only clean up state; nothing is emitted here
+ cleanupState(accState);
+ function.cleanup();
+ }
+ }
+
+ @Override
+ public void close() throws Exception {
+ if (function != null) {
+ function.close();
+ }
+ }
+}