You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ji...@apache.org on 2019/05/05 07:56:14 UTC
[flink] branch master updated: [FLINK-10977][table] Add
non-windowed streaming FlatAggregate to Table API.
This is an automated email from the ASF dual-hosted git repository.
jincheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new e5cadf6 [FLINK-10977][table] Add non-windowed streaming FlatAggregate to Table API.
e5cadf6 is described below
commit e5cadf69d176181af7d097928bf2a6cade1b6c76
Author: hequn8128 <ch...@gmail.com>
AuthorDate: Mon Apr 29 23:40:12 2019 +0800
[FLINK-10977][table] Add non-windowed streaming FlatAggregate to Table API.
This closes #8230
---
docs/dev/table/tableApi.md | 103 ++
.../table/api/java/StreamTableEnvironment.java | 12 +
.../{GroupedTable.java => FlatAggregateTable.java} | 29 +-
.../org/apache/flink/table/api/GroupedTable.java | 35 +
.../java/org/apache/flink/table/api/Table.java | 33 +
.../table/api/scala/StreamTableEnvironment.scala | 15 +-
.../expressions/AggregateFunctionDefinition.java | 8 +-
.../flink/table/functions/AggregateFunction.java | 52 +-
.../table/functions/TableAggregateFunction.java | 91 ++
.../functions/UserDefinedAggregateFunction.java | 57 +
.../org/apache/flink/table/api/TableImpl.scala | 4 +
.../operations/AggregateOperationFactory.java | 139 ++-
.../flink/table/plan/TableOperationConverter.java | 32 +-
.../ExtendedAggregateExtractProjectRule.java | 41 +-
.../org/apache/flink/table/api/TableEnvImpl.scala | 4 +-
.../flink/table/api/java/StreamTableEnvImpl.scala | 23 +-
.../flink/table/api/scala/StreamTableEnvImpl.scala | 8 +-
.../flink/table/api/scala/expressionDsl.scala | 6 +-
.../org/apache/flink/table/api/tableImpl.scala | 41 +
.../flink/table/calcite/FlinkRelBuilder.scala | 12 +-
.../table/calcite/RelTimeIndicatorConverter.scala | 7 +-
.../table/codegen/AggregationCodeGenerator.scala | 1213 +++++++++++---------
.../flink/table/codegen/MatchCodeGenerator.scala | 13 +-
.../flink/table/expressions/aggregations.scala | 8 +-
.../table/functions/utils/AggSqlFunction.scala | 27 +-
.../functions/utils/UserDefinedFunctionUtils.scala | 21 +-
.../table/operations/OperationTreeBuilder.scala | 71 +-
.../plan/logical/rel/LogicalTableAggregate.scala | 87 ++
.../flink/table/plan/nodes/CommonAggregate.scala | 26 +-
.../table/plan/nodes/CommonTableAggregate.scala | 88 ++
.../plan/nodes/dataset/DataSetAggregate.scala | 11 +-
.../nodes/dataset/DataSetWindowAggregate.scala | 103 +-
.../datastream/DataStreamGroupAggregate.scala | 127 +-
...te.scala => DataStreamGroupAggregateBase.scala} | 53 +-
.../datastream/DataStreamGroupTableAggregate.scala | 101 ++
.../DataStreamGroupWindowAggregate.scala | 17 +-
.../nodes/datastream/DataStreamOverAggregate.scala | 36 +-
.../nodes/logical/FlinkLogicalTableAggregate.scala | 87 ++
.../flink/table/plan/rules/FlinkRuleSets.scala | 6 +-
.../datastream/DataStreamTableAggregateRule.scala | 59 +
.../table/runtime/aggregate/AggregateUtil.scala | 306 +++--
.../runtime/aggregate/GeneratedAggregations.scala | 83 +-
.../aggregate/GroupTableAggProcessFunction.scala | 196 ++++
.../flink/table/validate/FunctionCatalog.scala | 4 +-
.../api/stream/table/TableAggregateTest.scala | 179 +++
.../TableAggregateStringExpressionTest.scala | 120 ++
.../validation/TableAggregateValidationTest.scala | 147 +++
.../harness/TableAggregateHarnessTest.scala | 162 +++
.../stream/table/TableAggregateITCase.scala | 160 +++
.../table/utils/UserDefinedTableAggFunctions.scala | 202 ++++
50 files changed, 3506 insertions(+), 959 deletions(-)
diff --git a/docs/dev/table/tableApi.md b/docs/dev/table/tableApi.md
index 2cb526d..c93faa3 100644
--- a/docs/dev/table/tableApi.md
+++ b/docs/dev/table/tableApi.md
@@ -1892,6 +1892,59 @@ Table table = input
{% endhighlight %}
</td>
</tr>
+
+ <tr>
+ <td>
+ <strong>GroupBy TableAggregation</strong><br>
+ <span class="label label-primary">Streaming</span><br>
+ <span class="label label-info">Result Updating</span>
+ </td>
+ <td>
+ <p>Similar to a <b>GroupBy Aggregation</b>. Groups the rows on the grouping keys with the following running table aggregation operator to aggregate rows group-wise. The difference from an AggregateFunction is that a TableAggregateFunction may return 0 or more records for a group. You have to close the "flatAggregate" with a select statement, and that select statement does not support aggregate functions.</p>
+{% highlight java %}
+ public class MyMinMaxAcc {
+ public int min = 0;
+ public int max = 0;
+ }
+
+ public class MyMinMax extends TableAggregateFunction<Row, MyMinMaxAcc> {
+
+ public void accumulate(MyMinMaxAcc acc, int value) {
+ if (value < acc.min) {
+ acc.min = value;
+ }
+ if (value > acc.max) {
+ acc.max = value;
+ }
+ }
+
+ @Override
+ public MyMinMaxAcc createAccumulator() {
+ return new MyMinMaxAcc();
+ }
+
+ public void emitValue(MyMinMaxAcc acc, Collector<Row> out) {
+ out.collect(Row.of(acc.min, acc.min));
+ out.collect(Row.of(acc.max, acc.max));
+ }
+
+ @Override
+ public TypeInformation<Row> getResultType() {
+ return new RowTypeInfo(Types.INT, Types.INT);
+ }
+ }
+
+TableAggregateFunction tableAggFunc = new MyMinMax();
+tableEnv.registerFunction("myTableAggFunc", tableAggFunc);
+Table orders = tableEnv.scan("Orders");
+Table result = orders
+ .groupBy("a")
+ .flatAggregate("myTableAggFunc(a) as (x, y)")
+ .select("a, x, y");
+{% endhighlight %}
+ <p><b>Note:</b> For streaming queries, the required state to compute the query result might grow infinitely depending on the type of aggregation and the number of distinct grouping keys. Please provide a query configuration with a valid retention interval to prevent excessive state size. See <a href="streaming/query_configuration.html">Query Configuration</a> for details.</p>
+ </td>
+ </tr>
</tbody>
</table>
</div>
@@ -1960,6 +2013,56 @@ val table = input
{% endhighlight %}
</td>
</tr>
+
+ <tr>
+ <td>
+ <strong>GroupBy TableAggregation</strong><br>
+ <span class="label label-primary">Streaming</span><br>
+ <span class="label label-info">Result Updating</span>
+ </td>
+ <td>
+ <p>Similar to a <b>GroupBy Aggregation</b>. Groups the rows on the grouping keys with the following running table aggregation operator to aggregate rows group-wise. The difference from an AggregateFunction is that a TableAggregateFunction may return 0 or more records for a group. You have to close the "flatAggregate" with a select statement, and that select statement does not support aggregate functions.</p>
+{% highlight scala %}
+case class MyMinMaxAcc(var min: Int, var max: Int)
+
+class MyMinMax extends TableAggregateFunction[Row, MyMinMaxAcc] {
+
+ def accumulate(acc: MyMinMaxAcc, value: Int): Unit = {
+ if (value < acc.min) {
+ acc.min = value
+ }
+ if (value > acc.max) {
+ acc.max = value
+ }
+ }
+
+ def resetAccumulator(acc: MyMinMaxAcc): Unit = {
+ acc.min = 0
+ acc.max = 0
+ }
+
+ override def createAccumulator(): MyMinMaxAcc = MyMinMaxAcc(0, 0)
+
+ def emitValue(acc: MyMinMaxAcc, out: Collector[Row]): Unit = {
+ out.collect(Row.of(Integer.valueOf(acc.min), Integer.valueOf(acc.min)))
+ out.collect(Row.of(Integer.valueOf(acc.max), Integer.valueOf(acc.max)))
+ }
+
+ override def getResultType: TypeInformation[Row] = {
+ new RowTypeInfo(Types.INT, Types.INT)
+ }
+}
+
+val tableAggFunc = new MyMinMax
+val orders: Table = tableEnv.scan("Orders")
+val result = orders
+ .groupBy('a)
+ .flatAggregate(tableAggFunc('a) as ('x, 'y))
+ .select('a, 'x, 'y)
+{% endhighlight %}
+ <p><b>Note:</b> For streaming queries, the required state to compute the query result might grow infinitely depending on the type of aggregation and the number of distinct grouping keys. Please provide a query configuration with a valid retention interval to prevent excessive state size. See <a href="streaming/query_configuration.html">Query Configuration</a> for details.</p>
+ </td>
+ </tr>
</tbody>
</table>
</div>
diff --git a/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/api/java/StreamTableEnvironment.java b/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/api/java/StreamTableEnvironment.java
index 5940025..8b385fc 100644
--- a/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/api/java/StreamTableEnvironment.java
+++ b/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/api/java/StreamTableEnvironment.java
@@ -31,6 +31,7 @@ import org.apache.flink.table.api.TableException;
import org.apache.flink.table.descriptors.ConnectorDescriptor;
import org.apache.flink.table.descriptors.StreamTableDescriptor;
import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.functions.TableAggregateFunction;
import org.apache.flink.table.functions.TableFunction;
import java.lang.reflect.Constructor;
@@ -75,6 +76,17 @@ public interface StreamTableEnvironment extends TableEnvironment {
<T, ACC> void registerFunction(String name, AggregateFunction<T, ACC> aggregateFunction);
/**
+ * Registers a {@link TableAggregateFunction} under a unique name in the TableEnvironment's
+ * catalog. Registered functions can only be referenced in Table API.
+ *
+ * @param name The name under which the function is registered.
+ * @param tableAggregateFunction The TableAggregateFunction to register.
+ * @param <T> The type of the output value.
+ * @param <ACC> The type of aggregate accumulator.
+ */
+ <T, ACC> void registerFunction(String name, TableAggregateFunction<T, ACC> tableAggregateFunction);
+
+ /**
* Converts the given {@link DataStream} into a {@link Table}.
*
* The field names of the {@link Table} are automatically derived from the type of the
diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/GroupedTable.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/FlatAggregateTable.java
similarity index 50%
copy from flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/GroupedTable.java
copy to flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/FlatAggregateTable.java
index c535462..96b7023 100644
--- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/GroupedTable.java
+++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/FlatAggregateTable.java
@@ -22,34 +22,47 @@ import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.table.expressions.Expression;
/**
- * A table that has been grouped on a set of grouping keys.
+ * A table that performs flatAggregate on a {@link Table} or a {@link GroupedTable}.
*/
@PublicEvolving
-public interface GroupedTable {
+public interface FlatAggregateTable {
/**
- * Performs a selection operation on a grouped table. Similar to an SQL SELECT statement.
- * The field expressions can contain complex expressions and aggregations.
+ * Performs a selection operation on a FlatAggregateTable. Similar to an SQL SELECT
+ * statement. The field expressions can contain complex expressions.
+ *
+ * <p><b>Note</b>: You have to close the flatAggregate with a select statement. And the select
+ * statement does not support aggregate functions.
*
* <p>Example:
*
* <pre>
* {@code
- * tab.groupBy("key").select("key, value.avg + ' The average' as average")
+ * TableAggregateFunction tableAggFunc = new MyTableAggregateFunction();
+ * tableEnv.registerFunction("tableAggFunc", tableAggFunc);
+ * tab.groupBy("key")
+ * .flatAggregate("tableAggFunc(a, b) as (x, y, z)")
+ * .select("key, x, y, z")
* }
* </pre>
*/
Table select(String fields);
/**
- * Performs a selection operation on a grouped table. Similar to an SQL SELECT statement.
- * The field expressions can contain complex expressions and aggregations.
+ * Performs a selection operation on a FlatAggregateTable. Similar to an SQL SELECT
+ * statement. The field expressions can contain complex expressions.
+ *
+ * <p><b>Note</b>: You have to close the flatAggregate with a select statement. And the select
+ * statement does not support aggregate functions.
*
* <p>Scala Example:
*
* <pre>
* {@code
- * tab.groupBy('key).select('key, 'value.avg + " The average" as 'average)
+ * val tableAggFunc = new MyTableAggregateFunction
+ * tab.groupBy('key)
+ * .flatAggregate(tableAggFunc('a, 'b) as ('x, 'y, 'z))
+ * .select('key, 'x, 'y, 'z)
* }
* </pre>
*/
diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/GroupedTable.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/GroupedTable.java
index c535462..671964e 100644
--- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/GroupedTable.java
+++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/GroupedTable.java
@@ -54,4 +54,39 @@ public interface GroupedTable {
* </pre>
*/
Table select(Expression... fields);
+
+ /**
+ * Performs a flatAggregate operation on a grouped table. FlatAggregate takes a
+ * TableAggregateFunction which returns multiple rows. Use a selection after flatAggregate.
+ *
+ * <p>Example:
+ *
+ * <pre>
+ * {@code
+ * TableAggregateFunction tableAggFunc = new MyTableAggregateFunction();
+ * tableEnv.registerFunction("tableAggFunc", tableAggFunc);
+ * tab.groupBy("key")
+ * .flatAggregate("tableAggFunc(a, b) as (x, y, z)")
+ * .select("key, x, y, z")
+ * }
+ * </pre>
+ */
+ FlatAggregateTable flatAggregate(String tableAggFunction);
+
+ /**
+ * Performs a flatAggregate operation on a grouped table. FlatAggregate takes a
+ * TableAggregateFunction which returns multiple rows. Use a selection after flatAggregate.
+ *
+ * <p>Scala Example:
+ *
+ * <pre>
+ * {@code
+ * val tableAggFunc: TableAggregateFunction = new MyTableAggregateFunction
+ * tab.groupBy('key)
+ * .flatAggregate(tableAggFunc('a, 'b) as ('x, 'y, 'z))
+ * .select('key, 'x, 'y, 'z)
+ * }
+ * </pre>
+ */
+ FlatAggregateTable flatAggregate(Expression tableAggFunction);
}
diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/Table.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/Table.java
index 474178f..4b14f71 100644
--- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/Table.java
+++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/Table.java
@@ -1058,4 +1058,37 @@ public interface Table {
* </pre>
*/
Table flatMap(Expression tableFunction);
+
+ /**
+ * Perform a global flatAggregate without groupBy. FlatAggregate takes a TableAggregateFunction
+ * which returns multiple rows. Use a selection after the flatAggregate.
+ *
+ * <p>Example:
+ *
+ * <pre>
+ * {@code
+ * TableAggregateFunction tableAggFunc = new MyTableAggregateFunction();
+ * tableEnv.registerFunction("tableAggFunc", tableAggFunc);
+ * tab.flatAggregate("tableAggFunc(a, b) as (x, y, z)")
+ * .select("x, y, z")
+ * }
+ * </pre>
+ */
+ FlatAggregateTable flatAggregate(String tableAggFunction);
+
+ /**
+ * Perform a global flatAggregate without groupBy. FlatAggregate takes a TableAggregateFunction
+ * which returns multiple rows. Use a selection after the flatAggregate.
+ *
+ * <p>Scala Example:
+ *
+ * <pre>
+ * {@code
+ * val tableAggFunc = new MyTableAggregateFunction
+ * tab.flatAggregate(tableAggFunc('a, 'b) as ('x, 'y, 'z))
+ * .select('x, 'y, 'z)
+ * }
+ * </pre>
+ */
+ FlatAggregateTable flatAggregate(Expression tableAggFunction);
}
diff --git a/flink-table/flink-table-api-scala-bridge/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvironment.scala b/flink-table/flink-table-api-scala-bridge/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvironment.scala
index c58c2de..8f379e6 100644
--- a/flink-table/flink-table-api-scala-bridge/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvironment.scala
+++ b/flink-table/flink-table-api-scala-bridge/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvironment.scala
@@ -22,7 +22,7 @@ import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironm
import org.apache.flink.table.api.{TableEnvironment, _}
import org.apache.flink.table.descriptors.{ConnectorDescriptor, StreamTableDescriptor}
import org.apache.flink.table.expressions.Expression
-import org.apache.flink.table.functions.{AggregateFunction, TableFunction}
+import org.apache.flink.table.functions.{AggregateFunction, TableAggregateFunction, TableFunction}
/**
* The [[TableEnvironment]] for a Scala [[StreamExecutionEnvironment]] that works with
@@ -62,6 +62,19 @@ trait StreamTableEnvironment extends TableEnvironment {
f: AggregateFunction[T, ACC]): Unit
/**
+ * Registers a [[TableAggregateFunction]] under a unique name in the TableEnvironment's catalog.
+ * Registered functions can only be referenced in Table API.
+ *
+ * @param name The name under which the function is registered.
+ * @param f The TableAggregateFunction to register.
+ * @tparam T The type of the output value.
+ * @tparam ACC The type of aggregate accumulator.
+ */
+ def registerFunction[T: TypeInformation, ACC: TypeInformation](
+ name: String,
+ f: TableAggregateFunction[T, ACC]): Unit
+
+ /**
* Converts the given [[DataStream]] into a [[Table]].
*
* The field names of the [[Table]] are automatically derived from the type of the
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/AggregateFunctionDefinition.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/AggregateFunctionDefinition.java
index 18fc982..690f64e 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/AggregateFunctionDefinition.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/AggregateFunctionDefinition.java
@@ -20,7 +20,7 @@ package org.apache.flink.table.expressions;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.functions.UserDefinedAggregateFunction;
import org.apache.flink.util.Preconditions;
import static org.apache.flink.table.expressions.FunctionDefinition.Type.AGGREGATE_FUNCTION;
@@ -31,13 +31,13 @@ import static org.apache.flink.table.expressions.FunctionDefinition.Type.AGGREGA
@PublicEvolving
public final class AggregateFunctionDefinition extends FunctionDefinition {
- private final AggregateFunction<?, ?> aggregateFunction;
+ private final UserDefinedAggregateFunction<?, ?> aggregateFunction;
private final TypeInformation<?> resultTypeInfo;
private final TypeInformation<?> accumulatorTypeInfo;
public AggregateFunctionDefinition(
String name,
- AggregateFunction<?, ?> aggregateFunction,
+ UserDefinedAggregateFunction<?, ?> aggregateFunction,
TypeInformation<?> resultTypeInfo,
TypeInformation<?> accTypeInfo) {
super(name, AGGREGATE_FUNCTION);
@@ -46,7 +46,7 @@ public final class AggregateFunctionDefinition extends FunctionDefinition {
this.accumulatorTypeInfo = Preconditions.checkNotNull(accTypeInfo);
}
- public AggregateFunction<?, ?> getAggregateFunction() {
+ public UserDefinedAggregateFunction<?, ?> getAggregateFunction() {
return aggregateFunction;
}
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/AggregateFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/AggregateFunction.java
index 70066a2..2399fa75 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/AggregateFunction.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/AggregateFunction.java
@@ -19,25 +19,29 @@
package org.apache.flink.table.functions;
import org.apache.flink.annotation.PublicEvolving;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
/**
* Base class for user-defined aggregates.
*
* <p>The behavior of an {@link AggregateFunction} can be defined by implementing a series of custom
* methods. An {@link AggregateFunction} needs at least three methods:
- * - <code>createAccumulator</code>,
- * - <code>accumulate</code>, and
- * - <code>getValue</code>.
+ * <ul>
+ * <li>createAccumulator</li>
+ * <li>accumulate</li>
+ * <li>getValue</li>
+ * </ul>
*
* <p>There are a few other methods that can be optional to have:
- * - <code>retract</code>,
- * - <code>merge</code>, and
- * - <code>resetAccumulator</code>.
+ * <ul>
+ * <li>retract</li>
+ * <li>merge</li>
+ * <li>resetAccumulator</li>
+ * </ul>
*
* <p>All these methods must be declared publicly, not static, and named exactly as the names
- * mentioned above. The methods {@link #createAccumulator()} and {@link #getValue} are defined in
- * the {@link AggregateFunction} functions, while other methods are explained below.
+ * mentioned above. The method {@link #createAccumulator()} is defined in the
+ * {@link UserDefinedAggregateFunction} function, and method {@link #getValue} is defined in
+ * the {@link AggregateFunction} while other methods are explained below.
*
* <pre>
* {@code
@@ -100,15 +104,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
* AggregateFunction must be put into the accumulator.
*/
@PublicEvolving
-public abstract class AggregateFunction<T, ACC> extends UserDefinedFunction {
-
- /**
- * Creates and initializes the accumulator for this {@link AggregateFunction}. The accumulator
- * is used to keep the aggregated values which are needed to compute an aggregation result.
- *
- * @return the accumulator with the initial value
- */
- public abstract ACC createAccumulator();
+public abstract class AggregateFunction<T, ACC> extends UserDefinedAggregateFunction<T, ACC> {
/**
* Called every time when an aggregation result should be materialized.
@@ -132,24 +128,4 @@ public abstract class AggregateFunction<T, ACC> extends UserDefinedFunction {
public boolean requiresOver() {
return false;
}
-
- /**
- * Returns the {@link TypeInformation} of the {@link AggregateFunction}'s result.
- *
- * @return The {@link TypeInformation} of the {@link AggregateFunction}'s result or
- * <code>null</code> if the result type should be automatically inferred.
- */
- public TypeInformation<T> getResultType() {
- return null;
- }
-
- /**
- * Returns the {@link TypeInformation} of the {@link AggregateFunction}'s accumulator.
- *
- * @return The {@link TypeInformation} of the {@link AggregateFunction}'s accumulator or
- * <code>null</code> if the accumulator type should be automatically inferred.
- */
- public TypeInformation<ACC> getAccumulatorType() {
- return null;
- }
}
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/TableAggregateFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/TableAggregateFunction.java
new file mode 100644
index 0000000..4224983
--- /dev/null
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/TableAggregateFunction.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+/**
+ * Base class for user-defined table aggregates.
+ *
+ * <p>The behavior of a {@link TableAggregateFunction} can be defined by implementing a series of
+ * custom methods. A {@link TableAggregateFunction} needs at least three methods:
+ * <ul>
+ * <li>createAccumulator</li>
+ * <li>accumulate</li>
+ * <li>emitValue</li>
+ * </ul>
+ *
+ * <p>There is another method that can be optional to have:
+ * <ul>
+ * <li>retract</li>
+ * </ul>
+ *
+ * <p>All these methods must be declared publicly, not static, and named exactly as the names
+ * mentioned above. The method {@link #createAccumulator()} is defined in
+ * the {@link UserDefinedAggregateFunction} functions, while other methods are explained below.
+ *
+ * <pre>
+ * {@code
+ * Processes the input values and update the provided accumulator instance. The method
+ * accumulate can be overloaded with different custom types and arguments. A TableAggregateFunction
+ * requires at least one accumulate() method.
+ *
+ * param: accumulator the accumulator which contains the current aggregated results
+ * param: [user defined inputs] the input value (usually obtained from a new arrived data).
+ *
+ * public void accumulate(ACC accumulator, [user defined inputs])
+ * }
+ * </pre>
+ *
+ * <pre>
+ * {@code
+ * Retracts the input values from the accumulator instance. The current design assumes the
+ * inputs are the values that have been previously accumulated. The method retract can be
+ * overloaded with different custom types and arguments.
+ *
+ * param: accumulator the accumulator which contains the current aggregated results
+ * param: [user defined inputs] the input value (usually obtained from a new arrived data).
+ *
+ * public void retract(ACC accumulator, [user defined inputs])
+ * }
+ * </pre>
+ *
+ * <pre>
+ * {@code
+ * Called every time when an aggregation result should be materialized. The returned value could
+ * be either an early and incomplete result (periodically emitted as data arrive) or the final
+ * result of the aggregation.
+ *
+ * param: accumulator the accumulator which contains the current aggregated results
+ * param: out the collector used to output data.
+ *
+ * public void emitValue(ACC accumulator, Collector<T> out)
+ * }
+ * </pre>
+ *
+ * @param <T> the type of the table aggregation result
+ * @param <ACC> the type of the table aggregation accumulator. The accumulator is used to keep the
+ * aggregated values which are needed to compute an aggregation result.
+ * TableAggregateFunction represents its state using accumulator, thereby the state of
+ * the TableAggregateFunction must be put into the accumulator.
+ */
+@PublicEvolving
+public abstract class TableAggregateFunction<T, ACC> extends UserDefinedAggregateFunction<T, ACC> {
+
+}
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/UserDefinedAggregateFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/UserDefinedAggregateFunction.java
new file mode 100644
index 0000000..b746952
--- /dev/null
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/UserDefinedAggregateFunction.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+
+/**
+ * Base class for user-defined aggregates and table aggregates.
+ */
+@PublicEvolving
+public abstract class UserDefinedAggregateFunction<T, ACC> extends UserDefinedFunction {
+
+ /**
+ * Creates and initializes the accumulator for this {@link UserDefinedAggregateFunction}.
+ * The accumulator is used to keep the aggregated values needed to compute an aggregation result.
+ *
+ * @return the accumulator with the initial value
+ */
+ public abstract ACC createAccumulator();
+
+ /**
+ * Returns the {@link TypeInformation} of the {@link UserDefinedAggregateFunction}'s result.
+ *
+ * @return The {@link TypeInformation} of the {@link UserDefinedAggregateFunction}'s result or
+ * <code>null</code> if the result type should be automatically inferred.
+ */
+ public TypeInformation<T> getResultType() {
+ return null;
+ }
+
+ /**
+ * Returns the {@link TypeInformation} of the {@link UserDefinedAggregateFunction}'s accumulator.
+ *
+ * @return The {@link TypeInformation} of the {@link UserDefinedAggregateFunction}'s accumulator
+ * or <code>null</code> if the accumulator type should be automatically inferred.
+ */
+ public TypeInformation<ACC> getAccumulatorType() {
+ return null;
+ }
+}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableImpl.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableImpl.scala
index 4d3cba8..7f67b66 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableImpl.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableImpl.scala
@@ -204,4 +204,8 @@ class TableImpl(val tableEnv: TableEnvironment, relNode: RelNode) extends Table
override def flatMap(tableFunction: Expression): Table = ???
override def getTableOperation: TableOperation = ???
+
+ override def flatAggregate(tableAggFunction: String): FlatAggregateTable = ???
+
+ override def flatAggregate(tableAggFunction: Expression): FlatAggregateTable = ???
}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/operations/AggregateOperationFactory.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/operations/AggregateOperationFactory.java
index f21f448..54ded6b 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/operations/AggregateOperationFactory.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/operations/AggregateOperationFactory.java
@@ -22,6 +22,8 @@ import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.table.api.GroupWindow;
import org.apache.flink.table.api.SessionWithGapOnTimeWithAlias;
import org.apache.flink.table.api.SlideWithSizeAndSlideOnTimeWithAlias;
@@ -29,6 +31,7 @@ import org.apache.flink.table.api.TableException;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.api.TumbleWithSizeOnTimeWithAlias;
import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.expressions.AggFunctionCall;
import org.apache.flink.table.expressions.AggregateFunctionDefinition;
import org.apache.flink.table.expressions.ApiExpressionDefaultVisitor;
import org.apache.flink.table.expressions.BuiltInFunctionDefinitions;
@@ -36,23 +39,30 @@ import org.apache.flink.table.expressions.CallExpression;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.ExpressionBridge;
import org.apache.flink.table.expressions.ExpressionResolver;
+import org.apache.flink.table.expressions.ExpressionUtils;
import org.apache.flink.table.expressions.FieldReferenceExpression;
import org.apache.flink.table.expressions.FunctionDefinition;
import org.apache.flink.table.expressions.PlannerExpression;
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
import org.apache.flink.table.expressions.ValueLiteralExpression;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.functions.TableAggregateFunction;
+import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils;
import org.apache.flink.table.operations.WindowAggregateTableOperation.ResolvedGroupWindow;
import org.apache.flink.table.typeutils.RowIntervalTypeInfo;
import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo;
import org.apache.flink.table.typeutils.TimeIntervalTypeInfo;
+import java.util.LinkedList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static java.lang.String.format;
import static java.util.Collections.singletonList;
+import static java.util.stream.Collectors.toList;
import static org.apache.flink.api.common.typeinfo.BasicTypeInfo.LONG_TYPE_INFO;
+import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.AS;
import static org.apache.flink.table.expressions.ExpressionUtils.isFunctionOfType;
import static org.apache.flink.table.expressions.FunctionDefinition.Type.AGGREGATE_FUNCTION;
import static org.apache.flink.table.operations.OperationExpressionsUtils.extractName;
@@ -97,17 +107,31 @@ public class AggregateOperationFactory {
List<PlannerExpression> convertedGroupings = bridge(groupings);
List<PlannerExpression> convertedAggregates = bridge(aggregates);
+ Boolean isTableAggregate = aggregates.size() == 1 && isTableAggFunctionCall(aggregates.get(0));
+
TypeInformation[] fieldTypes = Stream.concat(
- convertedGroupings.stream(),
- convertedAggregates.stream()
- ).map(PlannerExpression::resultType)
- .toArray(TypeInformation[]::new);
+ convertedGroupings.stream().map(PlannerExpression::resultType),
+ convertedAggregates.stream().flatMap(expr -> {
+ if (isTableAggregate) {
+ return Stream.of(UserDefinedFunctionUtils.getFieldInfo(expr.resultType())._3());
+ } else {
+ return Stream.of(expr.resultType());
+ }
+ })
+ ).toArray(TypeInformation[]::new);
String[] fieldNames = Stream.concat(
- groupings.stream(),
- aggregates.stream()
- ).map(expr -> extractName(expr).orElseGet(expr::toString))
- .toArray(String[]::new);
+ groupings.stream().map(expr -> extractName(expr).orElseGet(expr::toString)),
+ aggregates.stream().flatMap(expr -> {
+ if (isTableAggregate) {
+ return Stream.of(UserDefinedFunctionUtils.getFieldInfo(
+ ((AggregateFunctionDefinition) ((CallExpression) expr).getFunctionDefinition())
+ .getResultTypeInfo())._1());
+ } else {
+ return Stream.of(extractName(expr).orElseGet(expr::toString));
+ }
+ })
+ ).toArray(String[]::new);
TableSchema tableSchema = new TableSchema(fieldNames, fieldTypes);
@@ -399,7 +423,10 @@ public class AggregateOperationFactory {
}
private boolean requiresOver(FunctionDefinition functionDefinition) {
- return ((AggregateFunctionDefinition) functionDefinition).getAggregateFunction().requiresOver();
+ return ((AggregateFunctionDefinition) functionDefinition).getAggregateFunction()
+ instanceof AggregateFunction &&
+ ((AggregateFunction) ((AggregateFunctionDefinition) functionDefinition)
+ .getAggregateFunction()).requiresOver();
}
@Override
@@ -408,7 +435,7 @@ public class AggregateOperationFactory {
return null;
}
- private void failExpression(Expression expression) {
+ protected void failExpression(Expression expression) {
throw new ValidationException(format("Expression '%s' is invalid because it is neither" +
" present in GROUP BY nor an aggregate function", expression));
}
@@ -465,4 +492,96 @@ public class AggregateOperationFactory {
return null;
}
}
+
+ /**
+ * Extract a table aggregate Expression and its aliases.
+ */
+ public Tuple2<Expression, List<String>> extractTableAggFunctionAndAliases(Expression callExpr) {
+ TableAggFunctionCallVisitor visitor = new TableAggFunctionCallVisitor();
+ return Tuple2.of(callExpr.accept(visitor), visitor.getAlias());
+ }
+
+ private class TableAggFunctionCallVisitor extends ApiExpressionDefaultVisitor<Expression> {
+
+ private List<String> alias = new LinkedList<>();
+
+ public List<String> getAlias() {
+ return alias;
+ }
+
+ @Override
+ public Expression visitCall(CallExpression call) {
+ FunctionDefinition definition = call.getFunctionDefinition();
+ if (definition.equals(AS)) {
+ return unwrapFromAlias(call);
+ } else if (definition instanceof AggregateFunctionDefinition) {
+ if (!isTableAggFunctionCall(call)) {
+ throw fail();
+ }
+ return call;
+ } else {
+ return defaultMethod(call);
+ }
+ }
+
+ private Expression unwrapFromAlias(CallExpression call) {
+ List<Expression> children = call.getChildren();
+ List<String> aliases = children.subList(1, children.size())
+ .stream()
+ .map(alias -> ExpressionUtils.extractValue(alias, Types.STRING)
+ .orElseThrow(() -> new ValidationException("Unexpected alias: " + alias)))
+ .collect(toList());
+
+ if (!isTableAggFunctionCall(children.get(0))) {
+ throw fail();
+ }
+
+ validateAlias(aliases, (AggregateFunctionDefinition) ((CallExpression) children.get(0)).getFunctionDefinition());
+ alias = aliases;
+ return children.get(0);
+ }
+
+ private void validateAlias(
+ List<String> aliases,
+ AggregateFunctionDefinition aggFunctionDefinition) {
+
+ TypeInformation resultType = aggFunctionDefinition.getResultTypeInfo();
+
+ int callArity = resultType.getTotalFields();
+ int aliasesSize = aliases.size();
+
+ if (aliasesSize > 0 && aliasesSize != callArity) {
+ throw new ValidationException(String.format(
+ "List of column aliases must have same degree as table; " +
+ "the returned table of function '%s' has " +
+ "%d columns, whereas alias list has %d columns",
+ aggFunctionDefinition.getName(),
+ callArity,
+ aliasesSize));
+ }
+ }
+
+ @Override
+ protected AggFunctionCall defaultMethod(Expression expression) {
+ throw fail();
+ }
+
+ private ValidationException fail() {
+ return new ValidationException(
+ "A flatAggregate only accepts an expression which defines a table aggregate " +
+ "function that might be followed by some alias.");
+ }
+ }
+
+ /**
+ * Returns true if the input {@link Expression} is a {@link CallExpression} of a table aggregate function.
+ */
+ public static boolean isTableAggFunctionCall(Expression expression) {
+ return Stream.of(expression)
+ .filter(p -> p instanceof CallExpression)
+ .map(p -> (CallExpression) p)
+ .filter(p -> p.getFunctionDefinition() instanceof AggregateFunctionDefinition)
+ .map(p -> (AggregateFunctionDefinition) p.getFunctionDefinition())
+ .anyMatch(p -> p.getAggregateFunction() instanceof TableAggregateFunction);
+ }
}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/TableOperationConverter.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/TableOperationConverter.java
index e8d0537..d652108 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/TableOperationConverter.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/TableOperationConverter.java
@@ -24,6 +24,7 @@ import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.calcite.FlinkRelBuilder;
import org.apache.flink.table.calcite.FlinkTypeFactory;
+import org.apache.flink.table.expressions.AggFunctionCall;
import org.apache.flink.table.expressions.Aggregation;
import org.apache.flink.table.expressions.CallExpression;
import org.apache.flink.table.expressions.Expression;
@@ -35,6 +36,7 @@ import org.apache.flink.table.expressions.RexPlannerExpression;
import org.apache.flink.table.expressions.WindowReference;
import org.apache.flink.table.functions.TableFunction;
import org.apache.flink.table.functions.utils.TableSqlFunction;
+import org.apache.flink.table.operations.AggregateOperationFactory;
import org.apache.flink.table.operations.AggregateTableOperation;
import org.apache.flink.table.operations.CalculatedTableOperation;
import org.apache.flink.table.operations.CatalogTableOperation;
@@ -107,6 +109,7 @@ public class TableOperationConverter extends TableOperationDefaultVisitor<RelNod
private final SingleRelVisitor singleRelVisitor = new SingleRelVisitor();
private final ExpressionBridge<PlannerExpression> expressionBridge;
private final AggregateVisitor aggregateVisitor = new AggregateVisitor();
+ private final TableAggregateVisitor tableAggregateVisitor = new TableAggregateVisitor();
private final JoinExpressionVisitor joinExpressionVisitor = new JoinExpressionVisitor();
public TableOperationConverter(
@@ -133,15 +136,27 @@ public class TableOperationConverter extends TableOperationDefaultVisitor<RelNod
@Override
public RelNode visitAggregate(AggregateTableOperation aggregate) {
+ boolean isTableAggregate = aggregate.getAggregateExpressions().size() == 1 &&
+ AggregateOperationFactory.isTableAggFunctionCall(aggregate.getAggregateExpressions().get(0));
+
List<AggCall> aggregations = aggregate.getAggregateExpressions()
.stream()
- .map(expr -> expr.accept(aggregateVisitor))
- .collect(toList());
+ .map(expr -> {
+ if (isTableAggregate) {
+ return expr.accept(tableAggregateVisitor);
+ } else {
+ return expr.accept(aggregateVisitor);
+ }
+ }).collect(toList());
List<RexNode> groupings = convertToRexNodes(aggregate.getGroupingExpressions());
GroupKey groupKey = relBuilder.groupKey(groupings);
- return relBuilder.aggregate(groupKey, aggregations).build();
+ if (isTableAggregate) {
+ return ((FlinkRelBuilder) relBuilder).tableAggregate(groupKey, aggregations).build();
+ } else {
+ return relBuilder.aggregate(groupKey, aggregations).build();
+ }
}
@Override
@@ -362,4 +377,15 @@ public class TableOperationConverter extends TableOperationDefaultVisitor<RelNod
throw new TableException("Unexpected expression: " + expression);
}
}
+
+ private class TableAggregateVisitor extends AggregateVisitor {
+ @Override
+ public AggCall visitCall(CallExpression call) {
+ if (isFunctionOfType(call, AGGREGATE_FUNCTION)) {
+ AggFunctionCall aggFunctionCall = (AggFunctionCall) expressionBridge.bridge(call);
+ return aggFunctionCall.toAggCall(aggFunctionCall.toString(), false, relBuilder);
+ }
+ throw new TableException("Expected table aggregate. Got: " + call);
+ }
+ }
}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/rules/logical/ExtendedAggregateExtractProjectRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/rules/logical/ExtendedAggregateExtractProjectRule.java
index 3393d2d..88690a4 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/rules/logical/ExtendedAggregateExtractProjectRule.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/rules/logical/ExtendedAggregateExtractProjectRule.java
@@ -20,11 +20,13 @@ package org.apache.flink.table.plan.rules.logical;
import org.apache.flink.table.expressions.ResolvedFieldReference;
import org.apache.flink.table.plan.logical.LogicalWindow;
+import org.apache.flink.table.plan.logical.rel.LogicalTableAggregate;
import org.apache.flink.table.plan.logical.rel.LogicalWindowAggregate;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.SingleRel;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
@@ -45,20 +47,20 @@ import java.util.List;
import java.util.stream.Collectors;
/**
- * Rule to extract a {@link org.apache.calcite.rel.core.Project}
- * from a {@link LogicalAggregate} or a {@link LogicalWindowAggregate}
- * and push it down towards the input.
+ * Rule to extract a {@link org.apache.calcite.rel.core.Project} from a {@link LogicalAggregate},
+ * a {@link LogicalWindowAggregate} or a {@link LogicalTableAggregate} and push it down towards
+ * the input.
*
* <p>Note: Most of the logic in this rule is same with {@link AggregateExtractProjectRule}. The
- * difference is this rule has also taken the {@link LogicalWindowAggregate} into consideration.
- * Furthermore, this rule also creates trivial {@link Project}s unless the input node is already
- * a {@link Project}.
+ * difference is this rule has also taken the {@link LogicalWindowAggregate} and
+ * {@link LogicalTableAggregate} into consideration. Furthermore, this rule also creates trivial
+ * {@link Project}s unless the input node is already a {@link Project}.
*/
public class ExtendedAggregateExtractProjectRule extends AggregateExtractProjectRule {
public static final ExtendedAggregateExtractProjectRule INSTANCE =
new ExtendedAggregateExtractProjectRule(
- operand(Aggregate.class,
+ operand(SingleRel.class,
operand(RelNode.class, any())), RelFactories.LOGICAL_BUILDER);
public ExtendedAggregateExtractProjectRule(
@@ -70,19 +72,34 @@ public class ExtendedAggregateExtractProjectRule extends AggregateExtractProject
@Override
public boolean matches(RelOptRuleCall call) {
- final Aggregate aggregate = call.rel(0);
- return aggregate instanceof LogicalWindowAggregate || aggregate instanceof LogicalAggregate;
+ final SingleRel relNode = call.rel(0);
+ return relNode instanceof LogicalWindowAggregate ||
+ relNode instanceof LogicalAggregate ||
+ relNode instanceof LogicalTableAggregate;
}
@Override
public void onMatch(RelOptRuleCall call) {
- final Aggregate aggregate = call.rel(0);
+ final RelNode relNode = call.rel(0);
final RelNode input = call.rel(1);
final RelBuilder relBuilder = call.builder().push(input);
+ if (relNode instanceof LogicalAggregate || relNode instanceof LogicalWindowAggregate) {
+ call.transformTo(performExtract((Aggregate) relNode, input, relBuilder));
+ } else if (relNode instanceof LogicalTableAggregate) {
+ LogicalAggregate logicalAggregate =
+ LogicalTableAggregate.getCorrespondingAggregate((LogicalTableAggregate) relNode);
+ RelNode newAggregate = performExtract(logicalAggregate, input, relBuilder);
+ call.transformTo(LogicalTableAggregate.create((Aggregate) newAggregate));
+ }
+ }
+
+ /**
+ * Extract a project from the input aggregate and return a new aggregate.
+ */
+ private RelNode performExtract(Aggregate aggregate, RelNode input, RelBuilder relBuilder) {
Mapping mapping = extractProjectsAndMapping(aggregate, input, relBuilder);
- RelNode newAggregate = getNewAggregate(aggregate, relBuilder, mapping);
- call.transformTo(newAggregate);
+ return getNewAggregate(aggregate, relBuilder, mapping);
}
/**
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala
index 58646e1..df5936f 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala
@@ -46,7 +46,7 @@ import org.apache.flink.table.catalog.{ExternalCatalog, ExternalCatalogSchema}
import org.apache.flink.table.codegen.{ExpressionReducer, FunctionCodeGenerator, GeneratedFunction}
import org.apache.flink.table.expressions._
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
-import org.apache.flink.table.functions.{AggregateFunction, ScalarFunction, TableFunction}
+import org.apache.flink.table.functions.{AggregateFunction, ScalarFunction, TableFunction, UserDefinedAggregateFunction}
import org.apache.flink.table.operations.{CatalogTableOperation, OperationTreeBuilder, PlannerTableOperation}
import org.apache.flink.table.plan.TableOperationConverter
import org.apache.flink.table.plan.cost.DataSetCostFactory
@@ -456,7 +456,7 @@ abstract class TableEnvImpl(val config: TableConfig) extends TableEnvironment {
* user-defined functions under this name.
*/
private[flink] def registerAggregateFunctionInternal[T: TypeInformation, ACC: TypeInformation](
- name: String, function: AggregateFunction[T, ACC]): Unit = {
+ name: String, function: UserDefinedAggregateFunction[T, ACC]): Unit = {
// check if class not Scala object
checkNotSingleton(function.getClass)
// check if class could be instantiated
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvImpl.scala
index f4c8b85..16ecebf 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvImpl.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvImpl.scala
@@ -21,7 +21,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.{TupleTypeInfo, TypeExtractor}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
import org.apache.flink.table.api._
-import org.apache.flink.table.functions.{AggregateFunction, TableFunction}
+import org.apache.flink.table.functions.{AggregateFunction, TableFunction, TableAggregateFunction, UserDefinedAggregateFunction}
import org.apache.flink.table.expressions.ExpressionParser
import org.apache.flink.streaming.api.datastream.DataStream
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
@@ -158,12 +158,29 @@ class StreamTableEnvImpl(
name: String,
f: AggregateFunction[T, ACC])
: Unit = {
+ registerUserDefinedAggregateFunction(name, f)
+ }
+
+ override def registerFunction[T, ACC](
+ name: String,
+ f: TableAggregateFunction[T, ACC])
+ : Unit = {
+ registerUserDefinedAggregateFunction(name, f)
+ }
+
+ /**
+ * Common function for registering an [[AggregateFunction]] or a [[TableAggregateFunction]].
+ */
+ private def registerUserDefinedAggregateFunction[T, ACC](
+ name: String,
+ f: UserDefinedAggregateFunction[T, ACC])
+ : Unit = {
implicit val typeInfo: TypeInformation[T] = TypeExtractor
- .createTypeInfo(f, classOf[AggregateFunction[T, ACC]], f.getClass, 0)
+ .createTypeInfo(f, classOf[UserDefinedAggregateFunction[T, ACC]], f.getClass, 0)
.asInstanceOf[TypeInformation[T]]
implicit val accTypeInfo: TypeInformation[ACC] = TypeExtractor
- .createTypeInfo(f, classOf[AggregateFunction[T, ACC]], f.getClass, 1)
+ .createTypeInfo(f, classOf[UserDefinedAggregateFunction[T, ACC]], f.getClass, 1)
.asInstanceOf[TypeInformation[ACC]]
registerAggregateFunctionInternal[T, ACC](name, f)
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvImpl.scala
index 4d2f9e2..897ab2e 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvImpl.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvImpl.scala
@@ -21,7 +21,7 @@ import org.apache.flink.api.scala._
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.table.api.{StreamQueryConfig, Table, TableConfig, TableEnvImpl}
import org.apache.flink.table.expressions.Expression
-import org.apache.flink.table.functions.{AggregateFunction, TableFunction}
+import org.apache.flink.table.functions.{AggregateFunction, TableAggregateFunction, TableFunction}
import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
import org.apache.flink.streaming.api.scala.asScalaStream
@@ -100,4 +100,10 @@ class StreamTableEnvImpl(
: Unit = {
registerAggregateFunctionInternal[T, ACC](name, f)
}
+
+ override def registerFunction[T: TypeInformation, ACC: TypeInformation](
+ name: String,
+ f: TableAggregateFunction[T, ACC]): Unit = {
+ registerAggregateFunctionInternal[T, ACC](name, f)
+ }
}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
index 333a8db..a38d956 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
@@ -27,7 +27,7 @@ import org.apache.flink.table.expressions.ApiExpressionUtils._
import org.apache.flink.table.expressions.BuiltInFunctionDefinitions.{WITH_COLUMNS, RANGE_TO, E => FDE, UUID => FDUUID, _}
import org.apache.flink.table.expressions._
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getAccumulatorTypeOfAggregateFunction, getResultTypeOfAggregateFunction}
-import org.apache.flink.table.functions.{AggregateFunction, ScalarFunction, TableFunction}
+import org.apache.flink.table.functions.{ScalarFunction, TableFunction, UserDefinedAggregateFunction}
import _root_.scala.language.implicitConversions
@@ -1096,8 +1096,8 @@ trait ImplicitExpressionConversions {
}
}
- implicit class AggregateFunctionCall[T: TypeInformation, ACC: TypeInformation]
- (val a: AggregateFunction[T, ACC]) {
+ implicit class UserDefinedAggregateFunctionCall[T: TypeInformation, ACC: TypeInformation]
+ (val a: UserDefinedAggregateFunction[T, ACC]) {
private def createFunctionDefinition(): AggregateFunctionDefinition = {
val resultTypeInfo: TypeInformation[_] = getResultTypeOfAggregateFunction(
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/tableImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/tableImpl.scala
index 2a965c3..d9b4cb6 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/tableImpl.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/tableImpl.scala
@@ -445,6 +445,14 @@ class TableImpl(
wrap(operationTreeBuilder.flatMap(tableFunction, operationTree))
}
+ override def flatAggregate(tableAggFunction: String): FlatAggregateTable = {
+ groupBy().flatAggregate(tableAggFunction)
+ }
+
+ override def flatAggregate(tableAggFunction: Expression): FlatAggregateTable = {
+ groupBy().flatAggregate(tableAggFunction)
+ }
+
/**
* Registers an unique table name under the table environment
* and return the registered table name.
@@ -494,6 +502,39 @@ class GroupedTableImpl(
)
))
}
+
+ override def flatAggregate(tableAggFunction: String): FlatAggregateTable = {
+ flatAggregate(ExpressionParser.parseExpression(tableAggFunction))
+ }
+
+ override def flatAggregate(tableAggFunction: Expression): FlatAggregateTable = {
+ new FlatAggregateTableImpl(table, groupKeys, tableAggFunction)
+ }
+}
+
+class FlatAggregateTableImpl(
+ private[flink] val table: Table,
+ private[flink] val groupKey: Seq[Expression],
+ private[flink] val tableAggFunction: Expression) extends FlatAggregateTable {
+
+ private val tableImpl = table.asInstanceOf[TableImpl]
+
+ override def select(fields: String): Table = {
+ select(ExpressionParser.parseExpressionList(fields).asScala: _*)
+ }
+
+ override def select(fields: Expression*): Table = {
+ val resolvedTableAggFunction = tableAggFunction.accept(tableImpl.callResolver)
+
+ val flatAggTable = new TableImpl(tableImpl.tableEnv,
+ tableImpl.operationTreeBuilder.tableAggregate(
+ groupKey.asJava,
+ resolvedTableAggFunction,
+ tableImpl.operationTree
+ ))
+
+ flatAggTable.select(fields: _*)
+ }
}
/**
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala
index 3ea046f..5a8678a 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala
@@ -34,7 +34,7 @@ import org.apache.flink.table.expressions.{Alias, ExpressionBridge, PlannerExpre
import org.apache.flink.table.operations.TableOperation
import org.apache.flink.table.plan.TableOperationConverter
import org.apache.flink.table.plan.logical.LogicalWindow
-import org.apache.flink.table.plan.logical.rel.LogicalWindowAggregate
+import org.apache.flink.table.plan.logical.rel.{LogicalTableAggregate, LogicalWindowAggregate}
import scala.collection.JavaConverters._
@@ -84,6 +84,16 @@ class FlinkRelBuilder(
this
}
+ def tableAggregate(
+ groupKey: GroupKey,
+ aggCalls: Iterable[AggCall]): RelBuilder = {
+
+ // build logical aggregate
+ val aggregate = super.aggregate(groupKey, aggCalls).build().asInstanceOf[LogicalAggregate]
+ // build logical table aggregate from it
+ push(LogicalTableAggregate.create(aggregate))
+ }
+
def tableOperation(tableOperation: TableOperation): RelBuilder= {
val relNode = tableOperation.accept(toRelNodeConverter)
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala
index 902940c..5b03fec 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala
@@ -30,7 +30,7 @@ import org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo
import org.apache.flink.table.api.{TableException, ValidationException}
import org.apache.flink.table.calcite.FlinkTypeFactory.{isRowtimeIndicatorType, _}
import org.apache.flink.table.functions.sql.ProctimeSqlFunction
-import org.apache.flink.table.plan.logical.rel.{LogicalTemporalTableJoin, LogicalWindowAggregate}
+import org.apache.flink.table.plan.logical.rel.{LogicalTableAggregate, LogicalTemporalTableJoin, LogicalWindowAggregate}
import org.apache.flink.table.plan.schema.TimeIndicatorRelDataType
import org.apache.flink.table.validate.BasicOperatorTable
@@ -160,6 +160,11 @@ class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttle {
aggregate.getNamedProperties,
convAggregate)
+ case tableAgg: LogicalTableAggregate =>
+ val correspondAggregate = LogicalTableAggregate.getCorrespondingAggregate(tableAgg)
+ val convAggregate = convertAggregate(correspondAggregate)
+ LogicalTableAggregate.create(convAggregate)
+
case temporalTableJoin: LogicalTemporalTableJoin =>
visit(temporalTableJoin)
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
index ea148bc..05f2362 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
@@ -24,42 +24,109 @@ import java.util.{List => JList}
import org.apache.calcite.rex.RexLiteral
import org.apache.flink.api.common.state.{ListStateDescriptor, MapStateDescriptor, State, StateDescriptor}
import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.api.java.typeutils.TypeExtractionUtils.{extractTypeArgument, getRawClass}
import org.apache.flink.table.api.TableConfig
import org.apache.flink.table.api.dataview._
import org.apache.flink.table.codegen.CodeGenUtils.{newName, reflectiveFieldWriteAccess}
import org.apache.flink.table.codegen.Indenter.toISC
import org.apache.flink.table.dataview.{StateListView, StateMapView}
-import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.UserDefinedAggregateFunction
import org.apache.flink.table.functions.aggfunctions.DistinctAccumulator
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getUserDefinedMethod, signatureToString}
-import org.apache.flink.table.runtime.aggregate.{GeneratedAggregations, SingleElementIterable}
+import org.apache.flink.table.runtime.aggregate.{GeneratedAggregations, GeneratedTableAggregations, SingleElementIterable}
import org.apache.flink.table.utils.EncodingUtils
import org.apache.flink.types.Row
+import org.apache.flink.util.Collector
import scala.collection.JavaConversions._
import scala.collection.mutable
/**
- * A code generator for generating [[GeneratedAggregations]].
+ * A code generator for generating [[GeneratedAggregations]] or
+ * [[GeneratedTableAggregations]].
*
- * @param config configuration that determines runtime behavior
- * @param nullableInput input(s) can be null.
- * @param input type information about the input of the Function
- * @param constants constant expressions that act like a second input in the parameter indices.
+ * @param config configuration that determines runtime behavior
+ * @param nullableInput input(s) can be null.
+ * @param inputTypeInfo type information about the input of the Function
+ * @param constants constant expressions that act like a second input in the
+ * parameter indices.
+ * @param classNamePrefix Class name of the function.
+ * Does not need to be unique but has to be a valid Java class
+ * identifier.
+ * @param physicalInputTypes Physical input row types
+ * @param aggregates All aggregate functions
+ * @param aggFields Indexes of the input fields for all aggregate functions
+ * @param aggMapping The mapping of aggregates to output fields
+ * @param distinctAccMapping The mapping of the distinct accumulator index to the
+ * corresponding aggregates.
+ * @param isStateBackedDataViews a flag to indicate if distinct filter uses state backend.
+ * @param partialResults A flag defining whether final or partial results (accumulators)
+ * are set
+ * to the output row.
+ * @param fwdMapping The mapping of input fields to output fields
+ * @param mergeMapping An optional mapping to specify the accumulators to merge. If not
+ * set, we
+ * assume that both rows have the accumulators at the same position.
+ * @param outputArity The number of fields in the output row.
+ * @param needRetract a flag to indicate if the aggregate needs the retract method
+ * @param needMerge a flag to indicate if the aggregate needs the merge method
+ * @param needReset a flag to indicate if the aggregate needs the resetAccumulator
+ * method
+ * @param accConfig Data view specification for accumulators
*/
class AggregationCodeGenerator(
- config: TableConfig,
- nullableInput: Boolean,
- input: TypeInformation[_ <: Any],
- constants: Option[Seq[RexLiteral]])
- extends CodeGenerator(config, nullableInput, input) {
+ config: TableConfig,
+ nullableInput: Boolean,
+ inputTypeInfo: TypeInformation[_ <: Any],
+ constants: Option[Seq[RexLiteral]],
+ classNamePrefix: String,
+ physicalInputTypes: Seq[TypeInformation[_]],
+ aggregates: Array[UserDefinedAggregateFunction[_ <: Any, _ <: Any]],
+ aggFields: Array[Array[Int]],
+ aggMapping: Array[Int],
+ distinctAccMapping: Array[(Integer, JList[Integer])],
+ isStateBackedDataViews: Boolean,
+ partialResults: Boolean,
+ fwdMapping: Array[Int],
+ mergeMapping: Option[Array[Int]],
+ outputArity: Int,
+ needRetract: Boolean,
+ needMerge: Boolean,
+ needReset: Boolean,
+ accConfig: Option[Array[Seq[DataViewSpec[_]]]])
+ extends CodeGenerator(config, nullableInput, inputTypeInfo) {
// set of statements for cleanup dataview that will be added only once
// we use a LinkedHashSet to keep the insertion order
private val reusableCleanupStatements = mutable.LinkedHashSet[String]()
+ // get unique function name
+ val funcName = newName(classNamePrefix)
+
+ // register UDAGGs
+ val aggs = aggregates.map(a => addReusableFunction(a, contextTerm))
+
+ // get java types of accumulators
+ val accTypeClasses = aggregates.map { a =>
+ a.getClass.getMethod("createAccumulator").getReturnType
+ }
+ val accTypes = accTypeClasses.map(_.getCanonicalName)
+
+ var parametersCode: Array[String] = _
+ var parametersCodeForDistinctAcc: Array[String] = _
+ var parametersCodeForDistinctMerge: Array[String] = _
+
+ // get distinct filter of acc fields for each aggregate function
+ val distinctAccType = s"${classOf[DistinctAccumulator].getName}"
+ var distinctAccCount: Int = _
+
+ // create constants
+ val constantExprs = constants.map(_.map(generateExpression)).getOrElse(Seq())
+ val constantTypes = constantExprs.map(_.resultType)
+ val constantFields = constantExprs.map(addReusableBoxedConstant)
+
/**
* @return code block of statements that need to be placed in the cleanup() method of
* [[GeneratedAggregations]]
@@ -68,68 +135,10 @@ class AggregationCodeGenerator(
reusableCleanupStatements.mkString("", "\n", "\n")
}
- /**
- * Generates a [[org.apache.flink.table.runtime.aggregate.GeneratedAggregations]] that can be
- * passed to a Java compiler.
- *
- * @param name Class name of the function.
- * Does not need to be unique but has to be a valid Java class identifier.
- * @param physicalInputTypes Physical input row types
- * @param aggregates All aggregate functions
- * @param aggFields Indexes of the input fields for all aggregate functions
- * @param aggMapping The mapping of aggregates to output fields
- * @param distinctAccMapping The mapping of the distinct accumulator index to the
- * corresponding aggregates.
- * @param isStateBackedDataViews a flag to indicate if distinct filter uses state backend.
- * @param partialResults A flag defining whether final or partial results (accumulators) are set
- * to the output row.
- * @param fwdMapping The mapping of input fields to output fields
- * @param mergeMapping An optional mapping to specify the accumulators to merge. If not set, we
- * assume that both rows have the accumulators at the same position.
- * @param outputArity The number of fields in the output row.
- * @param needRetract a flag to indicate if the aggregate needs the retract method
- * @param needMerge a flag to indicate if the aggregate needs the merge method
- * @param needReset a flag to indicate if the aggregate needs the resetAccumulator method
- * @param accConfig Data view specification for accumulators
- *
- * @return A GeneratedAggregationsFunction
- */
- def generateAggregations(
- name: String,
- physicalInputTypes: Seq[TypeInformation[_]],
- aggregates: Array[AggregateFunction[_ <: Any, _ <: Any]],
- aggFields: Array[Array[Int]],
- aggMapping: Array[Int],
- distinctAccMapping: Array[(Integer, JList[Integer])],
- isStateBackedDataViews: Boolean,
- partialResults: Boolean,
- fwdMapping: Array[Int],
- mergeMapping: Option[Array[Int]],
- outputArity: Int,
- needRetract: Boolean,
- needMerge: Boolean,
- needReset: Boolean,
- accConfig: Option[Array[Seq[DataViewSpec[_]]]])
- : GeneratedAggregationsFunction = {
-
- // get unique function name
- val funcName = newName(name)
- // register UDAGGs
- val aggs = aggregates.map(a => addReusableFunction(a, contextTerm))
-
- // get java types of accumulators
- val accTypeClasses = aggregates.map { a =>
- a.getClass.getMethod("createAccumulator").getReturnType
- }
- val accTypes = accTypeClasses.map(_.getCanonicalName)
-
- // create constants
- val constantExprs = constants.map(_.map(generateExpression)).getOrElse(Seq())
- val constantTypes = constantExprs.map(_.resultType)
- val constantFields = constantExprs.map(addReusableBoxedConstant)
+ def init(): Unit = {
// get parameter lists for aggregation functions
- val parametersCode = aggFields.map { inFields =>
+ parametersCode = aggFields.map { inFields =>
val fields = inFields.filter(_ > -1).map { f =>
// index to constant
if (f >= physicalInputTypes.length) {
@@ -145,7 +154,7 @@ class AggregationCodeGenerator(
}
// get parameter lists for distinct acc, constant fields are not necessary
- val parametersCodeForDistinctAcc = aggFields.map { inFields =>
+ parametersCodeForDistinctAcc = aggFields.map { inFields =>
val fields = inFields.filter(i => i > -1 && i < physicalInputTypes.length).map { f =>
// index to input field
s"(${CodeGenUtils.boxedTypeTermForTypeInfo(physicalInputTypes(f))}) input.getField($f)"
@@ -154,7 +163,7 @@ class AggregationCodeGenerator(
fields.mkString(", ")
}
- val parametersCodeForDistinctMerge = aggFields.map { inFields =>
+ parametersCodeForDistinctMerge = aggFields.map { inFields =>
// transform inFields to pairs of (inField, index in acc) firstly,
// e.g. (4, 2, 3, 2) will be transformed to ((4,2), (2,0), (3,1), (2,0))
val fields = inFields.filter(_ > -1).groupBy(identity).toSeq.sortBy(_._1).zipWithIndex
@@ -189,11 +198,7 @@ class AggregationCodeGenerator(
}
}
- // get distinct filter of acc fields for each aggregate functions
- val distinctAccType = s"${classOf[DistinctAccumulator].getName}"
-
- val distinctAccCount = distinctAccMapping.count(_._1 >= 0)
-
+ distinctAccCount = distinctAccMapping.count(_._1 >= 0)
if (distinctAccCount > 0 && partialResults && isStateBackedDataViews) {
// should not happen, but add an error message just in case.
throw new CodeGenException(
@@ -209,31 +214,31 @@ class AggregationCodeGenerator(
aggregates.zipWithIndex.map {
case (a, i) =>
getUserDefinedMethod(a, "accumulate", Array(accTypeClasses(i)) ++ methodSignaturesList(i))
- .getOrElse(
- throw new CodeGenException(
- s"No matching accumulate method found for AggregateFunction " +
- s"'${a.getClass.getCanonicalName}'" +
- s"with parameters '${signatureToString(methodSignaturesList(i))}'.")
- )
-
- if (needRetract) {
- getUserDefinedMethod(a, "retract", Array(accTypeClasses(i)) ++ methodSignaturesList(i))
.getOrElse(
throw new CodeGenException(
- s"No matching retract method found for AggregateFunction " +
+ s"No matching accumulate method found for AggregateFunction " +
s"'${a.getClass.getCanonicalName}'" +
s"with parameters '${signatureToString(methodSignaturesList(i))}'.")
)
+
+ if (needRetract) {
+ getUserDefinedMethod(a, "retract", Array(accTypeClasses(i)) ++ methodSignaturesList(i))
+ .getOrElse(
+ throw new CodeGenException(
+ s"No matching retract method found for AggregateFunction " +
+ s"'${a.getClass.getCanonicalName}'" +
+ s"with parameters '${signatureToString(methodSignaturesList(i))}'.")
+ )
}
if (needMerge) {
val method =
getUserDefinedMethod(a, "merge", Array(accTypeClasses(i), classOf[JIterable[Any]]))
- .getOrElse(
- throw new CodeGenException(
- s"No matching merge method found for AggregateFunction " +
- s"${a.getClass.getCanonicalName}'.")
- )
+ .getOrElse(
+ throw new CodeGenException(
+ s"No matching merge method found for AggregateFunction " +
+ s"${a.getClass.getCanonicalName}'.")
+ )
// use the TypeExtractionUtils here to support nested GenericArrayTypes and
// other complex types
@@ -250,526 +255,535 @@ class AggregationCodeGenerator(
if (needReset) {
getUserDefinedMethod(a, "resetAccumulator", Array(accTypeClasses(i)))
- .getOrElse(
- throw new CodeGenException(
- s"No matching resetAccumulator method found for " +
- s"aggregate ${a.getClass.getCanonicalName}'.")
- )
+ .getOrElse(
+ throw new CodeGenException(
+ s"No matching resetAccumulator method found for " +
+ s"aggregate ${a.getClass.getCanonicalName}'.")
+ )
}
}
+ }
- /**
- * Add all data views for all field accumulators and distinct filters defined by
- * aggregation functions.
- */
- def addAccumulatorDataViews(): Unit = {
- if (accConfig.isDefined) {
- // create state handles for DataView backed accumulator fields.
- val descMapping: Map[String, StateDescriptor[_, _]] = accConfig.get
- .flatMap(specs => specs.map(s => (s.stateId, s.toStateDescriptor)))
- .toMap[String, StateDescriptor[_ <: State, _]]
-
- for (i <- 0 until aggs.length + distinctAccCount) yield {
- for (spec <- accConfig.get(i)) yield {
- // Check if stat descriptor exists.
- val desc: StateDescriptor[_, _] = descMapping.getOrElse(spec.stateId,
- throw new CodeGenException(
- s"Can not find DataView in accumulator by id: ${spec.stateId}"))
+ /**
+ * Add all data views for all field accumulators and distinct filters defined by
+ * aggregation functions.
+ */
+ def addAccumulatorDataViews(): Unit = {
+ if (accConfig.isDefined) {
+ // create state handles for DataView backed accumulator fields.
+ val descMapping: Map[String, StateDescriptor[_, _]] = accConfig.get
+ .flatMap(specs => specs.map(s => (s.stateId, s.toStateDescriptor)))
+ .toMap[String, StateDescriptor[_ <: State, _]]
+
+ for (i <- 0 until aggs.length + distinctAccCount) yield {
+ for (spec <- accConfig.get(i)) yield {
+ // Check if the state descriptor exists.
+ val desc: StateDescriptor[_, _] = descMapping.getOrElse(spec.stateId,
+ throw new CodeGenException(
+ s"Can not find DataView in accumulator by id: ${spec.stateId}"))
- addReusableDataView(spec, desc, i)
- }
+ addReusableDataView(spec, desc, i)
}
}
}
+ }
- /**
- * Create DataView Term, for example, acc1_map_dataview.
- *
- * @param aggIndex index of aggregate function
- * @param fieldName field name of DataView
- * @return term to access [[MapView]] or [[ListView]]
- */
- def createDataViewTerm(aggIndex: Int, fieldName: String): String = {
- s"acc${aggIndex}_${fieldName}_dataview"
+ /**
+ * Create DataView Term, for example, acc1_map_dataview.
+ *
+ * @param aggIndex index of aggregate function
+ * @param fieldName field name of DataView
+ * @return term to access [[MapView]] or [[ListView]]
+ */
+ def createDataViewTerm(aggIndex: Int, fieldName: String): String = {
+ s"acc${aggIndex}_${fieldName}_dataview"
+ }
+
+ /**
+ * Adds a reusable [[org.apache.flink.table.api.dataview.DataView]] to the open, cleanup,
+ * close and member area of the generated function.
+ * @param spec the [[DataViewSpec]] of the desired data view term.
+ * @param desc the [[StateDescriptor]] of the desired data view term.
+ * @param aggIndex the aggregation function index associated with the data view.
+ */
+ def addReusableDataView(
+ spec: DataViewSpec[_],
+ desc: StateDescriptor[_, _],
+ aggIndex: Int): Unit = {
+ val dataViewField = spec.field
+ val dataViewTypeTerm = dataViewField.getType.getCanonicalName
+
+ // define the DataView variables
+ val serializedData = EncodingUtils.encodeObjectToString(desc)
+ val dataViewFieldTerm = createDataViewTerm(aggIndex, dataViewField.getName)
+ val field =
+ s"""
+ | final $dataViewTypeTerm $dataViewFieldTerm;
+ |""".stripMargin
+ reusableMemberStatements.add(field)
+
+ // create DataViews
+ val descFieldTerm = s"${dataViewFieldTerm}_desc"
+ val descClassQualifier = classOf[StateDescriptor[_, _]].getCanonicalName
+ val descDeserializeCode =
+ s"""
+ | $descClassQualifier $descFieldTerm = ($descClassQualifier)
+ | ${classOf[EncodingUtils].getCanonicalName}.decodeStringToObject(
+ | "$serializedData",
+ | $descClassQualifier.class,
+ | $contextTerm.getUserCodeClassLoader());
+ |""".stripMargin
+ val createDataView = if (dataViewField.getType == classOf[MapView[_, _]]) {
+ s"""
+ | $descDeserializeCode
+ | $dataViewFieldTerm = new ${classOf[StateMapView[_, _]].getCanonicalName}(
+ | $contextTerm.getMapState(
+ | (${classOf[MapStateDescriptor[_, _]].getCanonicalName}) $descFieldTerm));
+ |""".stripMargin
+ } else if (dataViewField.getType == classOf[ListView[_]]) {
+ s"""
+ | $descDeserializeCode
+ | $dataViewFieldTerm = new ${classOf[StateListView[_]].getCanonicalName}(
+ | $contextTerm.getListState(
+ | (${classOf[ListStateDescriptor[_]].getCanonicalName}) $descFieldTerm));
+ |""".stripMargin
+ } else {
+ throw new CodeGenException(s"Unsupported dataview type: $dataViewTypeTerm")
}
+ reusableOpenStatements.add(createDataView)
+
+ // cleanup DataViews
+ val cleanup =
+ s"""
+ | $dataViewFieldTerm.clear();
+ |""".stripMargin
+ reusableCleanupStatements.add(cleanup)
+ }
- /**
- * Adds a reusable [[org.apache.flink.table.api.dataview.DataView]] to the open, cleanup,
- * close and member area of the generated function.
- * @param spec the [[DataViewSpec]] of the desired data view term.
- * @param desc the [[StateDescriptor]] of the desired data view term.
- * @param aggIndex the aggregation function index associate with the data view.
- */
- def addReusableDataView(
- spec: DataViewSpec[_],
- desc: StateDescriptor[_, _],
- aggIndex: Int): Unit = {
- val dataViewField = spec.field
- val dataViewTypeTerm = dataViewField.getType.getCanonicalName
-
- // define the DataView variables
- val serializedData = EncodingUtils.encodeObjectToString(desc)
- val dataViewFieldTerm = createDataViewTerm(aggIndex, dataViewField.getName)
- val field =
- s"""
- | final $dataViewTypeTerm $dataViewFieldTerm;
- |""".stripMargin
- reusableMemberStatements.add(field)
-
- // create DataViews
- val descFieldTerm = s"${dataViewFieldTerm}_desc"
- val descClassQualifier = classOf[StateDescriptor[_, _]].getCanonicalName
- val descDeserializeCode =
- s"""
- | $descClassQualifier $descFieldTerm = ($descClassQualifier)
- | ${classOf[EncodingUtils].getCanonicalName}.decodeStringToObject(
- | "$serializedData",
- | $descClassQualifier.class,
- | $contextTerm.getUserCodeClassLoader());
- |""".stripMargin
- val createDataView = if (dataViewField.getType == classOf[MapView[_, _]]) {
- s"""
- | $descDeserializeCode
- | $dataViewFieldTerm = new ${classOf[StateMapView[_, _]].getCanonicalName}(
- | $contextTerm.getMapState(
- | (${classOf[MapStateDescriptor[_, _]].getCanonicalName}) $descFieldTerm));
- |""".stripMargin
- } else if (dataViewField.getType == classOf[ListView[_]]) {
- s"""
- | $descDeserializeCode
- | $dataViewFieldTerm = new ${classOf[StateListView[_]].getCanonicalName}(
- | $contextTerm.getListState(
- | (${classOf[ListStateDescriptor[_]].getCanonicalName}) $descFieldTerm));
- |""".stripMargin
- } else {
- throw new CodeGenException(s"Unsupported dataview type: $dataViewTypeTerm")
- }
- reusableOpenStatements.add(createDataView)
-
- // cleanup DataViews
- val cleanup =
- s"""
- | $dataViewFieldTerm.clear();
- |""".stripMargin
- reusableCleanupStatements.add(cleanup)
+ def genAccDataViewFieldSetter(str: String, i: Int): String = {
+ if (accConfig.isDefined) {
+ genDataViewFieldSetter(accConfig.get(i), str, i)
+ } else {
+ ""
}
+ }
- def genAccDataViewFieldSetter(str: String, i: Int): String = {
- if (accConfig.isDefined) {
- genDataViewFieldSetter(accConfig.get(i), str, i)
+ /**
+ * Generate statements to set the data view fields when using the state backend.
+ *
+ * @param specs aggregation [[DataViewSpec]]s for this aggregation term.
+ * @param accTerm aggregation term
+ * @param aggIndex index of aggregation
+ * @return data view field set statements
+ */
+ def genDataViewFieldSetter(
+ specs: Seq[DataViewSpec[_]],
+ accTerm: String,
+ aggIndex: Int): String = {
+ val setters = for (spec <- specs) yield {
+ val field = spec.field
+ val dataViewTerm = createDataViewTerm(aggIndex, field.getName)
+ val fieldSetter = if (Modifier.isPublic(field.getModifiers)) {
+ s"$accTerm.${field.getName} = $dataViewTerm;"
} else {
- ""
+ val fieldTerm = addReusablePrivateFieldAccess(field.getDeclaringClass, field.getName)
+ s"${reflectiveFieldWriteAccess(fieldTerm, field, accTerm, dataViewTerm)};"
}
- }
- /**
- * Generate statements to set data view field when use state backend.
- *
- * @param specs aggregation [[DataViewSpec]]s for this aggregation term.
- * @param accTerm aggregation term
- * @param aggIndex index of aggregation
- * @return data view field set statements
- */
- def genDataViewFieldSetter(
- specs: Seq[DataViewSpec[_]],
- accTerm: String,
- aggIndex: Int): String = {
- val setters = for (spec <- specs) yield {
- val field = spec.field
- val dataViewTerm = createDataViewTerm(aggIndex, field.getName)
- val fieldSetter = if (Modifier.isPublic(field.getModifiers)) {
- s"$accTerm.${field.getName} = $dataViewTerm;"
- } else {
- val fieldTerm = addReusablePrivateFieldAccess(field.getDeclaringClass, field.getName)
- s"${reflectiveFieldWriteAccess(fieldTerm, field, accTerm, dataViewTerm)};"
- }
-
- s"""
- | $fieldSetter
+ s"""
+ | $fieldSetter
""".stripMargin
- }
- setters.mkString("\n")
}
+ setters.mkString("\n")
+ }
- def genSetAggregationResults: String = {
-
- val sig: String =
- j"""
- | public final void setAggregationResults(
- | org.apache.flink.types.Row accs,
- | org.apache.flink.types.Row output) throws Exception """.stripMargin
-
- val setAggs: String = {
- for ((i, aggIndexes) <- distinctAccMapping) yield {
- if (partialResults) {
- def setAggs(aggIndexes: JList[Integer]) = {
- for (i <- aggIndexes) yield {
- j"""
- |output.setField(
- | ${aggMapping(i)},
- | (${accTypes(i)}) accs.getField($i));
- """.stripMargin
- }
- }.mkString("\n")
+ def genSetAggregationResults: String = {
- if (i >= 0) {
- j"""
- | output.setField(
- | ${aggMapping(i)},
- | ($distinctAccType) accs.getField($i));
- | ${setAggs(aggIndexes)}
- """.stripMargin
- } else {
+ val sig: String =
+ j"""
+ | public final void setAggregationResults(
+ | org.apache.flink.types.Row accs,
+ | org.apache.flink.types.Row output) throws Exception """.stripMargin
+
+ val setAggs: String = {
+ for ((i, aggIndexes) <- distinctAccMapping) yield {
+ if (partialResults) {
+ def setAggs(aggIndexes: JList[Integer]) = {
+ for (i <- aggIndexes) yield {
j"""
- | ${setAggs(aggIndexes)}
+ |output.setField(
+ | ${aggMapping(i)},
+ | (${accTypes(i)}) accs.getField($i));
""".stripMargin
}
+ }.mkString("\n")
+
+ if (i >= 0) {
+ j"""
+ | output.setField(
+ | ${aggMapping(i)},
+ | ($distinctAccType) accs.getField($i));
+ | ${setAggs(aggIndexes)}
+ """.stripMargin
} else {
- def setAggs(aggIndexes: JList[Integer]) = {
- for (i <- aggIndexes) yield {
- val setAccOutput =
- j"""
- |${genAccDataViewFieldSetter(s"acc$i", i)}
- |output.setField(
- | ${aggMapping(i)},
- | baseClass$i.getValue(acc$i));
+ j"""
+ | ${setAggs(aggIndexes)}
+ """.stripMargin
+ }
+ } else {
+ def setAggs(aggIndexes: JList[Integer]) = {
+ for (i <- aggIndexes) yield {
+ val setAccOutput =
+ j"""
+ |${genAccDataViewFieldSetter(s"acc$i", i)}
+ |output.setField(
+ | ${aggMapping(i)},
+ | baseClass$i.getValue(acc$i));
""".stripMargin
- j"""
- |org.apache.flink.table.functions.AggregateFunction baseClass$i =
- | (org.apache.flink.table.functions.AggregateFunction) ${aggs(i)};
- |${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
- |$setAccOutput
+ j"""
+ |org.apache.flink.table.functions.AggregateFunction baseClass$i =
+ | (org.apache.flink.table.functions.AggregateFunction) ${aggs(i)};
+ |${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
+ |$setAccOutput
""".stripMargin
- }
- }.mkString("\n")
+ }
+ }.mkString("\n")
- j"""
- | ${setAggs(aggIndexes)}
+ j"""
+ | ${setAggs(aggIndexes)}
""".stripMargin
- }
}
- }.mkString("\n")
+ }
+ }.mkString("\n")
- j"""
- |$sig {
- |$setAggs
- | }""".stripMargin
- }
+ j"""
+ |$sig {
+ |$setAggs
+ | }""".stripMargin
+ }
- def genAccumulate: String = {
+ def genAccumulate: String = {
- val sig: String =
- j"""
- | public final void accumulate(
- | org.apache.flink.types.Row accs,
- | org.apache.flink.types.Row input) throws Exception """.stripMargin
+ val sig: String =
+ j"""
+ | public final void accumulate(
+ | org.apache.flink.types.Row accs,
+ | org.apache.flink.types.Row input) throws Exception """.stripMargin
- val accumulate: String = {
- def accumulateAcc(aggIndexes: JList[Integer]) = {
- for (i <- aggIndexes) yield {
- j"""
- |${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
- |${genAccDataViewFieldSetter(s"acc$i", i)}
- |${aggs(i)}.accumulate(acc$i
- | ${if (!parametersCode(i).isEmpty) "," else ""} ${parametersCode(i)});
+ val accumulate: String = {
+ def accumulateAcc(aggIndexes: JList[Integer]) = {
+ for (i <- aggIndexes) yield {
+ j"""
+ |${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
+ |${genAccDataViewFieldSetter(s"acc$i", i)}
+ |${aggs(i)}.accumulate(acc$i
+ | ${if (!parametersCode(i).isEmpty) "," else ""} ${parametersCode(i)});
""".stripMargin
- }
- }.mkString("\n")
+ }
+ }.mkString("\n")
- for ((i, aggIndexes) <- distinctAccMapping) yield {
- if (i >= 0) {
- j"""
- | $distinctAccType distinctAcc$i = ($distinctAccType) accs.getField($i);
- | ${genAccDataViewFieldSetter(s"distinctAcc$i", i)}
- | if (distinctAcc$i.add(${classOf[Row].getCanonicalName}.of(
- | ${parametersCodeForDistinctAcc(aggIndexes.get(0))}))) {
- | ${accumulateAcc(aggIndexes)}
- | }
+ for ((i, aggIndexes) <- distinctAccMapping) yield {
+ if (i >= 0) {
+ j"""
+ | $distinctAccType distinctAcc$i = ($distinctAccType) accs.getField($i);
+ | ${genAccDataViewFieldSetter(s"distinctAcc$i", i)}
+ | if (distinctAcc$i.add(${classOf[Row].getCanonicalName}.of(
+ | ${parametersCodeForDistinctAcc(aggIndexes.get(0))}))) {
+ | ${accumulateAcc(aggIndexes)}
+ | }
""".stripMargin
- } else {
- j"""
- | ${accumulateAcc(aggIndexes)}
+ } else {
+ j"""
+ | ${accumulateAcc(aggIndexes)}
""".stripMargin
- }
}
- }.mkString("\n")
+ }
+ }.mkString("\n")
- j"""$sig {
- |$accumulate
- | }""".stripMargin
- }
+ j"""$sig {
+ |$accumulate
+ | }""".stripMargin
+ }
- def genRetract: String = {
+ def genRetract: String = {
- val sig: String =
- j"""
- | public final void retract(
- | org.apache.flink.types.Row accs,
- | org.apache.flink.types.Row input) throws Exception """.stripMargin
+ val sig: String =
+ j"""
+ | public final void retract(
+ | org.apache.flink.types.Row accs,
+ | org.apache.flink.types.Row input) throws Exception """.stripMargin
- val retract: String = {
- def retractAcc(aggIndexes: JList[Integer]) = {
- for (i <- aggIndexes) yield {
- j"""
- |${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
- |${genAccDataViewFieldSetter(s"acc$i", i)}
- |${aggs(i)}.retract(acc$i
- | ${if (!parametersCode(i).isEmpty) "," else ""} ${parametersCode(i)});
+ val retract: String = {
+ def retractAcc(aggIndexes: JList[Integer]) = {
+ for (i <- aggIndexes) yield {
+ j"""
+ |${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
+ |${genAccDataViewFieldSetter(s"acc$i", i)}
+ |${aggs(i)}.retract(acc$i
+ | ${if (!parametersCode(i).isEmpty) "," else ""} ${parametersCode(i)});
""".stripMargin
- }
- }.mkString("\n")
+ }
+ }.mkString("\n")
- for ((i, aggIndexes) <- distinctAccMapping) yield {
- if (i >= 0) {
- j"""
- | $distinctAccType distinctAcc$i = ($distinctAccType) accs.getField($i);
- | ${genAccDataViewFieldSetter(s"distinctAcc$i", i)}
- | if (distinctAcc$i.remove(${classOf[Row].getCanonicalName}.of(
- | ${parametersCodeForDistinctAcc(aggIndexes.get(0))}))) {
- | ${retractAcc(aggIndexes)}
- | }
+ for ((i, aggIndexes) <- distinctAccMapping) yield {
+ if (i >= 0) {
+ j"""
+ | $distinctAccType distinctAcc$i = ($distinctAccType) accs.getField($i);
+ | ${genAccDataViewFieldSetter(s"distinctAcc$i", i)}
+ | if (distinctAcc$i.remove(${classOf[Row].getCanonicalName}.of(
+ | ${parametersCodeForDistinctAcc(aggIndexes.get(0))}))) {
+ | ${retractAcc(aggIndexes)}
+ | }
""".stripMargin
- } else {
- j"""
- | ${retractAcc(aggIndexes)}
+ } else {
+ j"""
+ | ${retractAcc(aggIndexes)}
""".stripMargin
- }
}
- }.mkString("\n")
-
- if (needRetract) {
- j"""
- |$sig {
- |$retract
- | }""".stripMargin
- } else {
- j"""
- |$sig {
- | }""".stripMargin
}
+ }.mkString("\n")
+
+ if (needRetract) {
+ j"""
+ |$sig {
+ |$retract
+ | }""".stripMargin
+ } else {
+ j"""
+ |$sig {
+ | }""".stripMargin
}
+ }
- def genCreateAccumulators: String = {
+ def genCreateAccumulators: String = {
- val sig: String =
- j"""
- | public final org.apache.flink.types.Row createAccumulators() throws Exception
- | """.stripMargin
- val init: String =
- j"""
- | org.apache.flink.types.Row accs =
- | new org.apache.flink.types.Row(${aggs.length + distinctAccCount});"""
+ val sig: String =
+ j"""
+ | public final org.apache.flink.types.Row createAccumulators() throws Exception
+ | """.stripMargin
+ val init: String =
+ j"""
+ | org.apache.flink.types.Row accs =
+ | new org.apache.flink.types.Row(${aggs.length + distinctAccCount});"""
.stripMargin
- val create: String = {
- def createAcc(aggIndexes: JList[Integer]) = {
- for (i <- aggIndexes) yield {
- j"""
- |${accTypes(i)} acc$i = (${accTypes(i)}) ${aggs(i)}.createAccumulator();
- |accs.setField(
- | $i,
- | acc$i);
+ val create: String = {
+ def createAcc(aggIndexes: JList[Integer]) = {
+ for (i <- aggIndexes) yield {
+ j"""
+ |${accTypes(i)} acc$i = (${accTypes(i)}) ${aggs(i)}.createAccumulator();
+ |accs.setField(
+ | $i,
+ | acc$i);
""".stripMargin
- }
- }.mkString("\n")
+ }
+ }.mkString("\n")
- for ((i, aggIndexes) <- distinctAccMapping) yield {
- if (i >= 0) {
- j"""
- | $distinctAccType distinctAcc$i = ($distinctAccType)
- | new ${classOf[DistinctAccumulator].getCanonicalName}();
- | accs.setField(
- | $i,
- | distinctAcc$i);
- | ${createAcc(aggIndexes)}
+ for ((i, aggIndexes) <- distinctAccMapping) yield {
+ if (i >= 0) {
+ j"""
+ | $distinctAccType distinctAcc$i = ($distinctAccType)
+ | new ${classOf[DistinctAccumulator].getCanonicalName}();
+ | accs.setField(
+ | $i,
+ | distinctAcc$i);
+ | ${createAcc(aggIndexes)}
""".stripMargin
- } else {
- j"""
- | ${createAcc(aggIndexes)}
+ } else {
+ j"""
+ | ${createAcc(aggIndexes)}
""".stripMargin
- }
}
- }.mkString("\n")
- val ret: String =
- j"""
- | return accs;"""
+ }
+ }.mkString("\n")
+ val ret: String =
+ j"""
+ | return accs;"""
.stripMargin
- j"""$sig {
- |$init
- |$create
- |$ret
- | }""".stripMargin
- }
+ j"""$sig {
+ |$init
+ |$create
+ |$ret
+ | }""".stripMargin
+ }
- def genSetForwardedFields: String = {
+ def genSetForwardedFields: String = {
- val sig: String =
- j"""
- | public final void setForwardedFields(
- | org.apache.flink.types.Row input,
- | org.apache.flink.types.Row output)
- | """.stripMargin
-
- val forward: String = {
- for (i <- fwdMapping.indices if fwdMapping(i) >= 0) yield
- {
- j"""
- | output.setField(
- | $i,
- | input.getField(${fwdMapping(i)}));"""
+ val sig: String =
+ j"""
+ | public final void setForwardedFields(
+ | org.apache.flink.types.Row input,
+ | org.apache.flink.types.Row output)
+ | """.stripMargin
+
+ val forward: String = {
+ for (i <- fwdMapping.indices if fwdMapping(i) >= 0) yield
+ {
+ j"""
+ | output.setField(
+ | $i,
+ | input.getField(${fwdMapping(i)}));"""
.stripMargin
- }
- }.mkString("\n")
+ }
+ }.mkString("\n")
- j"""$sig {
- |$forward
- | }""".stripMargin
- }
+ j"""$sig {
+ |$forward
+ | }""".stripMargin
+ }
- def genCreateOutputRow: String = {
- j"""
- | public final org.apache.flink.types.Row createOutputRow() {
- | return new org.apache.flink.types.Row($outputArity);
- | }""".stripMargin
- }
+ def genCreateOutputRow: String = {
+ j"""
+ | public final org.apache.flink.types.Row createOutputRow() {
+ | return new org.apache.flink.types.Row($outputArity);
+ | }""".stripMargin
+ }
- def genMergeAccumulatorsPair: String = {
- val mapping = mergeMapping.getOrElse((0 until aggs.length + distinctAccCount).toArray)
+ def genMergeAccumulatorsPair: String = {
+ val mapping = mergeMapping.getOrElse((0 until aggs.length + distinctAccCount).toArray)
- val sig: String =
- j"""
- | public final org.apache.flink.types.Row mergeAccumulatorsPair(
- | org.apache.flink.types.Row a,
- | org.apache.flink.types.Row b)
+ val sig: String =
+ j"""
+ | public final org.apache.flink.types.Row mergeAccumulatorsPair(
+ | org.apache.flink.types.Row a,
+ | org.apache.flink.types.Row b)
""".stripMargin
- val merge: String = {
- def accumulateAcc(aggIndexes: JList[Integer]) = {
- for (i <- aggIndexes) yield {
- j"""
- |${accTypes(i)} aAcc$i = (${accTypes(i)}) a.getField($i);
- |${aggs(i)}.accumulate(aAcc$i, ${parametersCodeForDistinctMerge(i)});
- |a.setField($i, aAcc$i);
+ val merge: String = {
+ def accumulateAcc(aggIndexes: JList[Integer]) = {
+ for (i <- aggIndexes) yield {
+ j"""
+ |${accTypes(i)} aAcc$i = (${accTypes(i)}) a.getField($i);
+ |${aggs(i)}.accumulate(aAcc$i, ${parametersCodeForDistinctMerge(i)});
+ |a.setField($i, aAcc$i);
""".stripMargin
- }
- }.mkString("\n")
+ }
+ }.mkString("\n")
- def mergeAcc(aggIndexes: JList[Integer]) = {
- for (i <- aggIndexes) yield {
- j"""
- |${accTypes(i)} aAcc$i = (${accTypes(i)}) a.getField($i);
- |${accTypes(i)} bAcc$i = (${accTypes(i)}) b.getField(${mapping(i)});
- |accIt$i.setElement(bAcc$i);
- |${aggs(i)}.merge(aAcc$i, accIt$i);
- |a.setField($i, aAcc$i);
+ def mergeAcc(aggIndexes: JList[Integer]) = {
+ for (i <- aggIndexes) yield {
+ j"""
+ |${accTypes(i)} aAcc$i = (${accTypes(i)}) a.getField($i);
+ |${accTypes(i)} bAcc$i = (${accTypes(i)}) b.getField(${mapping(i)});
+ |accIt$i.setElement(bAcc$i);
+ |${aggs(i)}.merge(aAcc$i, accIt$i);
+ |a.setField($i, aAcc$i);
""".stripMargin
- }
- }.mkString("\n")
+ }
+ }.mkString("\n")
- for ((i, aggIndexes) <- distinctAccMapping) yield {
- if (i >= 0) {
- j"""
- | $distinctAccType aDistinctAcc$i = ($distinctAccType) a.getField($i);
- | $distinctAccType bDistinctAcc$i = ($distinctAccType) b.getField(${mapping(i)});
- | java.util.Iterator<java.util.Map.Entry> mergeIt$i =
- | bDistinctAcc$i.elements().iterator();
- |
+ for ((i, aggIndexes) <- distinctAccMapping) yield {
+ if (i >= 0) {
+ j"""
+ | $distinctAccType aDistinctAcc$i = ($distinctAccType) a.getField($i);
+ | $distinctAccType bDistinctAcc$i = ($distinctAccType) b.getField(${mapping(i)});
+ | java.util.Iterator<java.util.Map.Entry> mergeIt$i =
+ | bDistinctAcc$i.elements().iterator();
+ |
| while (mergeIt$i.hasNext()) {
- | java.util.Map.Entry entry = (java.util.Map.Entry) mergeIt$i.next();
- | ${classOf[Row].getCanonicalName} k =
- | (${classOf[Row].getCanonicalName}) entry.getKey();
- | Long v = (Long) entry.getValue();
- | if (aDistinctAcc$i.add(k, v)) {
- | ${accumulateAcc(aggIndexes)}
- | }
- | }
- | a.setField($i, aDistinctAcc$i);
+ | java.util.Map.Entry entry = (java.util.Map.Entry) mergeIt$i.next();
+ | ${classOf[Row].getCanonicalName} k =
+ | (${classOf[Row].getCanonicalName}) entry.getKey();
+ | Long v = (Long) entry.getValue();
+ | if (aDistinctAcc$i.add(k, v)) {
+ | ${accumulateAcc(aggIndexes)}
+ | }
+ | }
+ | a.setField($i, aDistinctAcc$i);
""".stripMargin
- } else {
- j"""
- | ${mergeAcc(aggIndexes)}
+ } else {
+ j"""
+ | ${mergeAcc(aggIndexes)}
""".stripMargin
- }
}
- }.mkString("\n")
- val ret: String =
- j"""
- | return a;
+ }
+ }.mkString("\n")
+ val ret: String =
+ j"""
+ | return a;
""".stripMargin
- if (needMerge) {
- if (accConfig.isDefined) {
- throw new CodeGenException("DataView doesn't support merge when the backend uses " +
- s"state when generate aggregation for $funcName.")
- }
- j"""
- |$sig {
- |$merge
- |$ret
- | }""".stripMargin
- } else {
- j"""
- |$sig {
- |$ret
- | }""".stripMargin
+ if (needMerge) {
+ if (accConfig.isDefined) {
+ throw new CodeGenException("DataView doesn't support merge when the backend uses " +
+ s"state when generate aggregation for $funcName.")
}
+ j"""
+ |$sig {
+ |$merge
+ |$ret
+ | }""".stripMargin
+ } else {
+ j"""
+ |$sig {
+ |$ret
+ | }""".stripMargin
}
+ }
- def genMergeList: String = {
- {
- val singleIterableClass = classOf[SingleElementIterable[_]].getCanonicalName
- for (i <- accTypes.indices) yield
- j"""
- | private final $singleIterableClass<${accTypes(i)}> accIt$i =
- | new $singleIterableClass<${accTypes(i)}>();
+ def genMergeList: String = {
+ {
+ val singleIterableClass = classOf[SingleElementIterable[_]].getCanonicalName
+ for (i <- accTypes.indices) yield
+ j"""
+ | private final $singleIterableClass<${accTypes(i)}> accIt$i =
+ | new $singleIterableClass<${accTypes(i)}>();
""".stripMargin
- }.mkString("\n")
- }
+ }.mkString("\n")
+ }
- def genResetAccumulator: String = {
+ def genResetAccumulator: String = {
- val sig: String =
- j"""
- | public final void resetAccumulator(
- | org.apache.flink.types.Row accs) throws Exception """.stripMargin
+ val sig: String =
+ j"""
+ | public final void resetAccumulator(
+ | org.apache.flink.types.Row accs) throws Exception """.stripMargin
- val reset: String = {
- def resetAcc(aggIndexes: JList[Integer]) = {
- for (i <- aggIndexes) yield {
- j"""
- |${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
- |${genAccDataViewFieldSetter(s"acc$i", i)}
- |${aggs(i)}.resetAccumulator(acc$i);
+ val reset: String = {
+ def resetAcc(aggIndexes: JList[Integer]) = {
+ for (i <- aggIndexes) yield {
+ j"""
+ |${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
+ |${genAccDataViewFieldSetter(s"acc$i", i)}
+ |${aggs(i)}.resetAccumulator(acc$i);
""".stripMargin
- }
- }.mkString("\n")
+ }
+ }.mkString("\n")
- for ((i, aggIndexes) <- distinctAccMapping) yield {
- if (i >= 0) {
- j"""
- | $distinctAccType distinctAcc$i = ($distinctAccType) accs.getField($i);
- | ${genAccDataViewFieldSetter(s"distinctAcc$i", i)}
- | distinctAcc$i.reset();
- | ${resetAcc(aggIndexes)}
+ for ((i, aggIndexes) <- distinctAccMapping) yield {
+ if (i >= 0) {
+ j"""
+ | $distinctAccType distinctAcc$i = ($distinctAccType) accs.getField($i);
+ | ${genAccDataViewFieldSetter(s"distinctAcc$i", i)}
+ | distinctAcc$i.reset();
+ | ${resetAcc(aggIndexes)}
""".stripMargin
- } else {
- j"""
- | ${resetAcc(aggIndexes)}
+ } else {
+ j"""
+ | ${resetAcc(aggIndexes)}
""".stripMargin
- }
}
- }.mkString("\n")
-
- if (needReset) {
- j"""$sig {
- |$reset
- | }""".stripMargin
- } else {
- j"""$sig {
- | }""".stripMargin
}
+ }.mkString("\n")
+
+ if (needReset) {
+ j"""$sig {
+ |$reset
+ | }""".stripMargin
+ } else {
+ j"""$sig {
+ | }""".stripMargin
}
+ }
+
+ /**
+ * Generates a [[GeneratedAggregations]] that can be passed to a Java compiler.
+ *
+ * @return A GeneratedAggregationsFunction
+ */
+ def generateAggregations: GeneratedAggregationsFunction = {
+ init()
val aggFuncCode = Seq(
genSetAggregationResults,
genAccumulate,
@@ -811,4 +825,159 @@ class AggregationCodeGenerator(
GeneratedAggregationsFunction(funcName, funcCode)
}
+
+ /**
+ * Generates a [[org.apache.flink.table.runtime.aggregate.GeneratedAggregations]] that can be
+ * passed to a Java compiler.
+ *
+ * @return A GeneratedAggregationsFunction
+ */
+ def generateTableAggregations(
+ tableAggOutputRowType: RowTypeInfo,
+ tableAggOutputType: TypeInformation[_]): GeneratedAggregationsFunction = {
+
+ // constants
+ val CONVERT_COLLECTOR_CLASS_TERM = "ConvertCollector"
+
+ val CONVERT_COLLECTOR_VARIABLE_TERM = "convertCollector"
+ val COLLECTOR_VARIABLE_TERM = "cRowWrappingcollector"
+ val CONVERTER_ROW_RESULT_TERM = "rowTerm"
+
+ val COLLECTOR: String = classOf[Collector[_]].getCanonicalName
+ val ROW: String = classOf[Row].getCanonicalName
+
+ def genEmit: String = {
+
+ val sig: String =
+ j"""
+ | public final void emit(
+ | $ROW accs,
+ | $COLLECTOR<$ROW> collector) throws Exception """.stripMargin
+
+ val emit: String = {
+ for (i <- aggs.indices) yield {
+ val emitAcc =
+ j"""
+ | ${genAccDataViewFieldSetter(s"acc$i", i)}
+ | ${aggs(i)}.emitValue(acc$i
+ | ${if (!parametersCode(i).isEmpty) "," else ""}
+ | $CONVERT_COLLECTOR_VARIABLE_TERM);
+ """.stripMargin
+ j"""
+ | ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
+ | $CONVERT_COLLECTOR_VARIABLE_TERM.$COLLECTOR_VARIABLE_TERM = collector;
+ | $emitAcc
+ """.stripMargin
+ }
+ }.mkString("\n")
+
+ j"""$sig {
+ |$emit
+ |}""".stripMargin
+ }
+
+ def genRecordToRow: String = {
+ // gen access expr
+
+ val functionGenerator = new FunctionCodeGenerator(
+ config,
+ false,
+ tableAggOutputType,
+ None,
+ None,
+ None)
+
+ functionGenerator.outRecordTerm = s"$CONVERTER_ROW_RESULT_TERM"
+ val resultExprs = functionGenerator.generateConverterResultExpression(
+ tableAggOutputRowType, tableAggOutputRowType.getFieldNames)
+
+ functionGenerator.reuseInputUnboxingCode() + resultExprs.code
+ }
+
+ /**
+ * Call super init and check emit methods.
+ */
+ def innerInit(): Unit = {
+ init()
+ // check and validate the emit methods
+ aggregates.zipWithIndex.map {
+ case (a, i) =>
+ val methodName = "emitValue"
+ getUserDefinedMethod(
+ a, methodName, Array(accTypeClasses(i), classOf[Collector[_]]))
+ .getOrElse(
+ throw new CodeGenException(
+ s"No matching $methodName method found for " +
+ s"tableAggregate ${a.getClass.getCanonicalName}'.")
+ )
+ }
+ }
+
+ innerInit()
+ val aggFuncCode = Seq(
+ genAccumulate,
+ genRetract,
+ genCreateAccumulators,
+ genCreateOutputRow,
+ genSetForwardedFields,
+ genMergeAccumulatorsPair,
+ genEmit).mkString("\n")
+
+ val generatedAggregationsClass = classOf[GeneratedTableAggregations].getCanonicalName
+ val aggOutputTypeName = tableAggOutputType.getTypeClass.getCanonicalName
+ val funcCode =
+ j"""
+ |public final class $funcName extends $generatedAggregationsClass {
+ |
+ | private $CONVERT_COLLECTOR_CLASS_TERM $CONVERT_COLLECTOR_VARIABLE_TERM;
+ | ${reuseMemberCode()}
+ | $genMergeList
+ | public $funcName() throws Exception {
+ | ${reuseInitCode()}
+ | $CONVERT_COLLECTOR_VARIABLE_TERM = new $CONVERT_COLLECTOR_CLASS_TERM();
+ | }
+ | ${reuseConstructorCode(funcName)}
+ |
+ | public final void open(
+ | org.apache.flink.api.common.functions.RuntimeContext $contextTerm) throws Exception {
+ | ${reuseOpenCode()}
+ | }
+ |
+ | $aggFuncCode
+ |
+ | public final void cleanup() throws Exception {
+ | ${reuseCleanupCode()}
+ | }
+ |
+ | public final void close() throws Exception {
+ | ${reuseCloseCode()}
+ | }
+ |
+ | private class $CONVERT_COLLECTOR_CLASS_TERM implements $COLLECTOR {
+ |
+ | public $COLLECTOR<$ROW> $COLLECTOR_VARIABLE_TERM;
+ | private final $ROW $CONVERTER_ROW_RESULT_TERM =
+ | new $ROW(${tableAggOutputType.getArity});
+ |
+ | public $ROW convertToRow(Object record) throws Exception {
+ | $aggOutputTypeName in1 = ($aggOutputTypeName) record;
+ | $genRecordToRow
+ | return $CONVERTER_ROW_RESULT_TERM;
+ | }
+ |
+ | @Override
+ | public void collect(Object record) throws Exception {
+ | $COLLECTOR_VARIABLE_TERM.collect(convertToRow(record));
+ | }
+ |
+ | @Override
+ | public void close() {
+ | $COLLECTOR_VARIABLE_TERM.close();
+ | }
+ | }
+ |}
+ """.stripMargin
+
+ GeneratedAggregationsFunction(funcName, funcCode)
+ }
}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala
index 6097178..8e6ac7d 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala
@@ -36,7 +36,7 @@ import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.CodeGenUtils.{boxedTypeTermForTypeInfo, newName, primitiveDefaultValue, primitiveTypeTermForTypeInfo}
import org.apache.flink.table.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
import org.apache.flink.table.codegen.Indenter.toISC
-import org.apache.flink.table.functions.{AggregateFunction => TableAggregateFunction}
+import org.apache.flink.table.functions.UserDefinedAggregateFunction
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.table.runtime.`match`.{IterativeConditionRunner, PatternProcessFunctionRunner}
import org.apache.flink.table.runtime.aggregate.AggregateUtil
@@ -729,9 +729,11 @@ class MatchCodeGenerator(
def generateAggFunction(): Unit = {
val matchAgg = extractAggregatesAndExpressions
- val aggGenerator = new AggregationCodeGenerator(config, false, input, None)
-
- val aggFunc = aggGenerator.generateAggregations(
+ val aggGenerator = new AggregationCodeGenerator(
+ config,
+ false,
+ input,
+ None,
s"AggFunction_$variableUID",
matchAgg.inputExprs.map(r => FlinkTypeFactory.toTypeInfo(r.getType)),
matchAgg.aggregations.map(_.aggFunction).toArray,
@@ -748,6 +750,7 @@ class MatchCodeGenerator(
needReset = false,
None
)
+ val aggFunc = aggGenerator.generateAggregations
reusableMemberStatements.add(aggFunc.code)
@@ -874,7 +877,7 @@ class MatchCodeGenerator(
)
private case class SingleAggCall(
- aggFunction: TableAggregateFunction[_, _],
+ aggFunction: UserDefinedAggregateFunction[_, _],
inputIndices: Array[Int],
dataViews: Seq[DataViewSpec[_]],
distinctAccIndex: Int
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/expressions/aggregations.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
index 67435bc..fd1dea1 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
@@ -23,7 +23,7 @@ import org.apache.calcite.sql.fun._
import org.apache.calcite.tools.RelBuilder
import org.apache.calcite.tools.RelBuilder.AggCall
import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.UserDefinedAggregateFunction
import org.apache.flink.table.functions.utils.AggSqlFunction
import org.apache.flink.table.typeutils.TypeCheckUtils
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
@@ -361,7 +361,7 @@ case class VarSamp(child: PlannerExpression) extends Aggregation {
* Expression for calling a user-defined aggregate function.
*/
case class AggFunctionCall(
- aggregateFunction: AggregateFunction[_, _],
+ val aggregateFunction: UserDefinedAggregateFunction[_, _],
resultTypeInfo: TypeInformation[_],
accTypeInfo: TypeInformation[_],
args: Seq[PlannerExpression])
@@ -382,6 +382,7 @@ case class AggFunctionCall(
getMethodSignatures(aggregateFunction, "accumulate")
.map(_.drop(1))
.map(signatureToString)
+ .sorted // sort the signatures so the error message is deterministic and verifiable in tests
.mkString(", ")}")
} else {
ValidationSuccess
@@ -409,8 +410,7 @@ case class AggFunctionCall(
aggregateFunction,
resultType,
accTypeInfo,
- typeFactory,
- aggregateFunction.requiresOver)
+ typeFactory)
}
override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala
index 8dc4d61..8508a01 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala
@@ -19,10 +19,8 @@
package org.apache.flink.table.functions.utils
import java.util
-import java.util.Collections
import org.apache.calcite.rel.`type`.RelDataType
-import org.apache.calcite.sql
import org.apache.calcite.sql._
import org.apache.calcite.sql.`type`._
import org.apache.calcite.sql.`type`.SqlOperandTypeChecker.Consistency
@@ -32,16 +30,17 @@ import org.apache.calcite.util.Optionality
import org.apache.flink.api.common.typeinfo._
import org.apache.flink.table.api.ValidationException
import org.apache.flink.table.calcite.FlinkTypeFactory
-import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.{AggregateFunction, TableAggregateFunction, UserDefinedAggregateFunction}
import org.apache.flink.table.functions.utils.AggSqlFunction.{createOperandTypeChecker, createOperandTypeInference, createReturnTypeInference}
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
/**
- * Calcite wrapper for user-defined aggregate functions.
+ * Calcite wrapper for user-defined aggregate functions. Currently, the aggregate function can be
+ * an [[AggregateFunction]] or a [[TableAggregateFunction]].
*
* @param name function name (used by SQL parser)
* @param displayName name to be displayed in operator name
- * @param aggregateFunction aggregate function to be called
+ * @param aggregateFunction user defined aggregate function to be called
* @param returnType the type information of returned value
* @param accType the type information of the accumulator
* @param typeFactory type factory for converting Flink's between Calcite's types
@@ -49,7 +48,7 @@ import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
class AggSqlFunction(
name: String,
displayName: String,
- aggregateFunction: AggregateFunction[_, _],
+ aggregateFunction: UserDefinedAggregateFunction[_, _],
val returnType: TypeInformation[_],
val accType: TypeInformation[_],
typeFactory: FlinkTypeFactory,
@@ -68,7 +67,7 @@ class AggSqlFunction(
typeFactory
) {
- def getFunction: AggregateFunction[_, _] = aggregateFunction
+ def getFunction: UserDefinedAggregateFunction[_, _] = aggregateFunction
override def isDeterministic: Boolean = aggregateFunction.isDeterministic
@@ -82,11 +81,15 @@ object AggSqlFunction {
def apply(
name: String,
displayName: String,
- aggregateFunction: AggregateFunction[_, _],
+ aggregateFunction: UserDefinedAggregateFunction[_, _],
returnType: TypeInformation[_],
accType: TypeInformation[_],
- typeFactory: FlinkTypeFactory,
- requiresOver: Boolean): AggSqlFunction = {
+ typeFactory: FlinkTypeFactory): AggSqlFunction = {
+
+ val requiresOver = aggregateFunction match {
+ case a: AggregateFunction[_, _] => a.requiresOver()
+ case _ => false
+ }
new AggSqlFunction(
name,
@@ -99,7 +102,7 @@ object AggSqlFunction {
}
private[flink] def createOperandTypeInference(
- aggregateFunction: AggregateFunction[_, _],
+ aggregateFunction: UserDefinedAggregateFunction[_, _],
typeFactory: FlinkTypeFactory)
: SqlOperandTypeInference = {
/**
@@ -149,7 +152,7 @@ object AggSqlFunction {
}
}
- private[flink] def createOperandTypeChecker(aggregateFunction: AggregateFunction[_, _])
+ private[flink] def createOperandTypeChecker(aggregateFunction: UserDefinedAggregateFunction[_, _])
: SqlOperandTypeChecker = {
val methods = checkAndExtractMethods(aggregateFunction, "accumulate")
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
index 8beb77b..f8443b2 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -37,9 +37,7 @@ import org.apache.flink.table.api.dataview._
import org.apache.flink.table.api.{TableEnvImpl, TableException, ValidationException}
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.dataview._
-import org.apache.flink.table.expressions._
-import org.apache.flink.table.functions.{AggregateFunction, ScalarFunction, TableFunction, UserDefinedFunction}
-import org.apache.flink.table.plan.logical._
+import org.apache.flink.table.functions._
import org.apache.flink.table.plan.schema.FlinkTableFunctionImpl
import org.apache.flink.util.InstantiationUtil
@@ -99,11 +97,11 @@ object UserDefinedFunctionUtils {
* of [[TypeInformation]]. Elements of the signature can be null (act as a wildcard).
*/
def getAccumulateMethodSignature(
- function: AggregateFunction[_, _],
+ function: UserDefinedAggregateFunction[_, _],
signature: Seq[TypeInformation[_]])
: Option[Array[Class[_]]] = {
val accType = TypeExtractor.createTypeInfo(
- function, classOf[AggregateFunction[_, _]], function.getClass, 1)
+ function, classOf[UserDefinedAggregateFunction[_, _]], function.getClass, 1)
val input = (Array(accType) ++ signature).toSeq
getUserDefinedMethod(
function,
@@ -324,7 +322,7 @@ object UserDefinedFunctionUtils {
def createAggregateSqlFunction(
name: String,
displayName: String,
- aggFunction: AggregateFunction[_, _],
+ aggFunction: UserDefinedAggregateFunction[_, _],
resultType: TypeInformation[_],
accTypeInfo: TypeInformation[_],
typeFactory: FlinkTypeFactory)
@@ -338,8 +336,7 @@ object UserDefinedFunctionUtils {
aggFunction,
resultType,
accTypeInfo,
- typeFactory,
- aggFunction.requiresOver)
+ typeFactory)
}
/**
@@ -573,7 +570,7 @@ object UserDefinedFunctionUtils {
* @return The inferred result type of the AggregateFunction.
*/
def getResultTypeOfAggregateFunction(
- aggregateFunction: AggregateFunction[_, _],
+ aggregateFunction: UserDefinedAggregateFunction[_, _],
extractedType: TypeInformation[_] = null)
: TypeInformation[_] = {
@@ -605,7 +602,7 @@ object UserDefinedFunctionUtils {
* @return The inferred accumulator type of the AggregateFunction.
*/
def getAccumulatorTypeOfAggregateFunction(
- aggregateFunction: AggregateFunction[_, _],
+ aggregateFunction: UserDefinedAggregateFunction[_, _],
extractedType: TypeInformation[_] = null)
: TypeInformation[_] = {
@@ -638,12 +635,12 @@ object UserDefinedFunctionUtils {
*/
@throws(classOf[InvalidTypesException])
private def extractTypeFromAggregateFunction(
- aggregateFunction: AggregateFunction[_, _],
+ aggregateFunction: UserDefinedAggregateFunction[_, _],
parameterTypePos: Int): TypeInformation[_] = {
TypeExtractor.createTypeInfo(
aggregateFunction,
- classOf[AggregateFunction[_, _]],
+ classOf[UserDefinedAggregateFunction[_, _]],
aggregateFunction.getClass,
parameterTypePos).asInstanceOf[TypeInformation[_]]
}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/operations/OperationTreeBuilder.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/operations/OperationTreeBuilder.scala
index 44acfa6..197f519 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/operations/OperationTreeBuilder.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/operations/OperationTreeBuilder.scala
@@ -38,6 +38,7 @@ import org.apache.flink.util.Preconditions
import _root_.scala.collection.JavaConversions._
import _root_.scala.collection.JavaConverters._
+import _root_.scala.collection.mutable.ListBuffer
/**
* Builder for [[[Operation]] tree.
@@ -159,6 +160,53 @@ class OperationTreeBuilder(private val tableEnv: TableEnvImpl) {
aggregateOperationFactory.createAggregate(resolvedGroupings, resolvedAggregates, child)
}
+ def tableAggregate(
+ groupingExpressions: JList[Expression],
+ tableAggFunction: Expression,
+ child: TableOperation)
+ : TableOperation = {
+
+ // Step1: add a default name to the call in the grouping expressions, e.g., groupBy(a % 5) to
+ // groupBy(a % 5 as TMP_0). We need a name for every column so that the output fields of the
+ // table aggregate function can be aliased in Step4.
+ var attrNameCntr: Int = 0
+ val usedFieldNames = child.getTableSchema.getFieldNames.toBuffer
+ val newGroupingExpressions = groupingExpressions.map {
+ case c: CallExpression
+ if !c.getFunctionDefinition.getName.equals(BuiltInFunctionDefinitions.AS.getName) => {
+ val tempName = getUniqueName("TMP_" + attrNameCntr, usedFieldNames)
+ usedFieldNames.append(tempName)
+ attrNameCntr += 1
+ new CallExpression(
+ BuiltInFunctionDefinitions.AS,
+ Seq(c, new ValueLiteralExpression(tempName))
+ )
+ }
+ case e => e
+ }
+
+ // Step2: resolve expressions
+ val resolver = resolverFor(tableCatalog, functionCatalog, child).build
+ val resolvedGroupings = resolver.resolve(newGroupingExpressions)
+ val resolvedFunctionAndAlias = aggregateOperationFactory.extractTableAggFunctionAndAliases(
+ resolveSingleExpression(tableAggFunction, resolver))
+
+ // Step3: create table agg operation
+ val tableAggOperation = aggregateOperationFactory
+ .createAggregate(resolvedGroupings, Seq(resolvedFunctionAndAlias.f0), child)
+
+ // Step4: add a top project to alias the output fields of the table aggregate.
+ val aliasName = resolvedFunctionAndAlias.f1
+ if (aliasName.nonEmpty) {
+ val namesBeforeAlias = tableAggOperation.getTableSchema.getFieldNames
+ val namesAfterAlias = namesBeforeAlias.dropRight(aliasName.size()) ++ aliasName
+ this.alias(namesAfterAlias.map(e =>
+ new UnresolvedReferenceExpression(e)).toList, tableAggOperation)
+ } else {
+ tableAggOperation
+ }
+ }
+
def windowAggregate(
groupingExpressions: JList[Expression],
window: GroupWindow,
@@ -342,16 +390,6 @@ class OperationTreeBuilder(private val tableEnv: TableEnvImpl) {
UserDefinedFunctionUtils.getFieldInfo(tfd.getResultType)._1
}
- def getUniqueName(inputName: String, usedFieldNames: Seq[String]): String = {
- var i = 0
- var resultName = inputName
- while (usedFieldNames.contains(resultName)) {
- resultName = resultName + "_" + i
- i += 1
- }
- resultName
- }
-
val usedFieldNames = child.getTableSchema.getFieldNames.toBuffer
val newFieldNames = originFieldNames.map({ e =>
val resultName = getUniqueName(e, usedFieldNames)
@@ -369,6 +407,19 @@ class OperationTreeBuilder(private val tableEnv: TableEnvImpl) {
alias(originFieldNames.map(a => new UnresolvedReferenceExpression(a)), rightNode)
}
+ /**
+ * Return a unique name that does not exist in usedFieldNames according to the input name.
+ */
+ private def getUniqueName(inputName: String, usedFieldNames: Seq[String]): String = {
+ var i = 0
+ var resultName = inputName
+ while (usedFieldNames.contains(resultName)) {
+ resultName = resultName + "_" + i
+ i += 1
+ }
+ resultName
+ }
+
class NoWindowPropertyChecker(val exceptionMessage: String)
extends ApiExpressionDefaultVisitor[Void] {
override def visitCall(call: CallExpression): Void = {
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/logical/rel/LogicalTableAggregate.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/logical/rel/LogicalTableAggregate.scala
new file mode 100644
index 0000000..66506b4
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/logical/rel/LogicalTableAggregate.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.logical.rel
+
+import java.util
+
+import org.apache.calcite.plan.{Convention, RelOptCluster, RelTraitSet}
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.core.{Aggregate, AggregateCall}
+import org.apache.calcite.rel.logical.LogicalAggregate
+import org.apache.calcite.rel.{RelNode, SingleRel}
+import org.apache.calcite.util.ImmutableBitSet
+import org.apache.flink.table.plan.nodes.CommonTableAggregate
+
+/**
+ * Logical Node for TableAggregate.
+ */
+class LogicalTableAggregate(
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ child: RelNode,
+ val indicator: Boolean,
+ val groupSet: ImmutableBitSet,
+ val groupSets: util.List[ImmutableBitSet],
+ val aggCalls: util.List[AggregateCall])
+ extends SingleRel(cluster, traitSet, child)
+ with CommonTableAggregate {
+
+ override def deriveRowType(): RelDataType = {
+ deriveTableAggRowType(cluster, child, groupSet, aggCalls)
+ }
+
+ override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
+ new LogicalTableAggregate(
+ cluster,
+ traitSet,
+ inputs.get(0),
+ indicator,
+ groupSet,
+ groupSets,
+ aggCalls
+ )
+ }
+}
+
+object LogicalTableAggregate {
+
+ def create(aggregate: Aggregate): LogicalTableAggregate = {
+
+ new LogicalTableAggregate(
+ aggregate.getCluster,
+ aggregate.getCluster.traitSetOf(Convention.NONE),
+ aggregate.getInput,
+ aggregate.indicator,
+ aggregate.getGroupSet,
+ aggregate.getGroupSets,
+ aggregate.getAggCallList)
+ }
+
+ def getCorrespondingAggregate(tableAgg: LogicalTableAggregate): LogicalAggregate = {
+ new LogicalAggregate(
+ tableAgg.getCluster,
+ tableAgg.getTraitSet,
+ tableAgg.getInput,
+ tableAgg.indicator,
+ tableAgg.groupSet,
+ tableAgg.groupSets,
+ tableAgg.aggCalls
+ )
+ }
+}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonAggregate.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonAggregate.scala
index 7960c8c..0190cf8 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonAggregate.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonAggregate.scala
@@ -25,6 +25,7 @@ import FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.runtime.aggregate.AggregateUtil._
import scala.collection.JavaConverters._
+import scala.collection.JavaConversions._
trait CommonAggregate {
@@ -35,16 +36,14 @@ trait CommonAggregate {
}
private[flink] def aggregationToString(
- inputType: RelDataType,
- grouping: Array[Int],
- rowType: RelDataType,
- namedAggregates: Seq[CalcitePair[AggregateCall, String]],
- namedProperties: Seq[NamedWindowProperty])
- : String = {
+ inputType: RelDataType,
+ grouping: Array[Int],
+ outFields: Seq[String],
+ namedAggregates: Seq[CalcitePair[AggregateCall, String]],
+ namedProperties: Seq[NamedWindowProperty])
+ : String = {
val inFields = inputType.getFieldNames.asScala
- val outFields = rowType.getFieldNames.asScala
-
val groupStrings = grouping.map( inFields(_) )
val aggs = namedAggregates.map(_.getKey)
@@ -67,4 +66,15 @@ trait CommonAggregate {
}
}.mkString(", ")
}
+
+ private[flink] def aggregationToString(
+ inputType: RelDataType,
+ grouping: Array[Int],
+ rowType: RelDataType,
+ namedAggregates: Seq[CalcitePair[AggregateCall, String]],
+ namedProperties: Seq[NamedWindowProperty])
+ : String = {
+ aggregationToString(
+ inputType, grouping, rowType.getFieldNames, namedAggregates, namedProperties)
+ }
}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonTableAggregate.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonTableAggregate.scala
new file mode 100644
index 0000000..09eec51
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonTableAggregate.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.nodes
+
+import java.util
+
+import org.apache.calcite.plan.RelOptCluster
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.core.AggregateCall
+import org.apache.calcite.util.{ImmutableBitSet, Pair, Util}
+import org.apache.flink.table.calcite.{FlinkRelBuilder, FlinkTypeFactory}
+import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair
+
+import scala.collection.JavaConversions._
+
+trait CommonTableAggregate extends CommonAggregate {
+
+ protected def deriveTableAggRowType(
+ cluster: RelOptCluster,
+ child: RelNode,
+ groupSet: ImmutableBitSet,
+ aggCalls: util.List[AggregateCall]): RelDataType = {
+
+ val typeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
+ val builder = typeFactory.builder
+
+ // group key fields
+ groupSet.asList().foreach(e => {
+ val field = child.getRowType.getFieldList.get(e)
+ builder.add(field)
+ })
+
+ // agg fields
+ aggCalls.get(0).`type`.getFieldList.foreach(builder.add)
+ builder.build()
+ }
+
+ override private[flink] def aggregationToString(
+ inputType: RelDataType,
+ grouping: Array[Int],
+ rowType: RelDataType,
+ namedAggregates: Seq[CalcitePair[AggregateCall, String]],
+ namedProperties: Seq[FlinkRelBuilder.NamedWindowProperty]): String = {
+
+ val outFields = rowType.getFieldNames
+ val tableAggOutputArity = namedAggregates.head.left.getType.getFieldCount
+ val groupSize = grouping.size
+ val outFieldsOfTableAgg = outFields.subList(groupSize, groupSize + tableAggOutputArity)
+ val tableAggOutputFields = Seq(s"(${outFieldsOfTableAgg.mkString(", ")})")
+
+ val newOutFields = outFields.subList(0, groupSize) ++
+ tableAggOutputFields ++
+ outFields.drop(groupSize + tableAggOutputArity)
+
+ aggregationToString(inputType, grouping, newOutFields, namedAggregates, namedProperties)
+ }
+
+ private[flink] def getNamedAggCalls(
+ aggCalls: util.List[AggregateCall],
+ rowType: RelDataType,
+ indicator: Boolean,
+ groupSet: ImmutableBitSet)
+ : util.List[Pair[AggregateCall, String]] = {
+
+ def getGroupCount: Int = groupSet.cardinality
+ def getIndicatorCount: Int = if (indicator) getGroupCount else 0
+
+ val offset = getGroupCount + getIndicatorCount
+ Pair.zip(aggCalls, Util.skip(rowType.getFieldNames, offset))
+ }
+}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
index 0a7d69b..dde1953 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
@@ -94,18 +94,15 @@ class DataSetAggregate(
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
- val generator = new AggregationCodeGenerator(
- tableEnv.getConfig,
- false,
- inputDS.getType,
- None)
-
val (
preAgg: Option[DataSetPreAggFunction],
preAggType: Option[TypeInformation[Row]],
finalAgg: Either[DataSetAggFunction, DataSetFinalAggFunction]
) = AggregateUtil.createDataSetAggregateFunctions(
- generator,
+ tableEnv.getConfig,
+ false,
+ inputDS.getType,
+ None,
namedAggregates,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala
index eef1982..ce18cf2 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala
@@ -22,13 +22,14 @@ import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
+import org.apache.calcite.rex.RexLiteral
import org.apache.flink.api.common.operators.Order
+import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.DataSet
import org.apache.flink.api.java.typeutils.{ResultTypeQueryable, RowTypeInfo}
import org.apache.flink.table.api.{BatchQueryConfig, BatchTableEnvImpl, TableConfig}
import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.calcite.FlinkTypeFactory
-import org.apache.flink.table.codegen.AggregationCodeGenerator
import org.apache.flink.table.expressions.PlannerExpressionUtils._
import org.apache.flink.table.plan.logical._
import org.apache.flink.table.plan.nodes.CommonAggregate
@@ -111,12 +112,6 @@ class DataSetWindowAggregate(
val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(tableEnv, queryConfig)
- val generator = new AggregationCodeGenerator(
- tableEnv.getConfig,
- false,
- inputDS.getType,
- None)
-
// whether identifiers are matched case-sensitively
val caseSensitive = tableEnv.getFrameworkConfig.getParserConfig.caseSensitive()
@@ -124,7 +119,10 @@ class DataSetWindowAggregate(
case TumblingGroupWindow(_, timeField, size)
if isTimePoint(timeField.resultType) || isLong(timeField.resultType) =>
createEventTimeTumblingWindowDataSet(
- generator,
+ tableEnv.getConfig,
+ false,
+ inputDS.getType,
+ None,
inputDS,
isTimeIntervalLiteral(size),
caseSensitive,
@@ -132,12 +130,22 @@ class DataSetWindowAggregate(
case SessionGroupWindow(_, timeField, gap)
if isTimePoint(timeField.resultType) || isLong(timeField.resultType) =>
- createEventTimeSessionWindowDataSet(generator, inputDS, caseSensitive, tableEnv.getConfig)
+ createEventTimeSessionWindowDataSet(
+ tableEnv.getConfig,
+ false,
+ inputDS.getType,
+ None,
+ inputDS,
+ caseSensitive,
+ tableEnv.getConfig)
case SlidingGroupWindow(_, timeField, size, slide)
if isTimePoint(timeField.resultType) || isLong(timeField.resultType) =>
createEventTimeSlidingWindowDataSet(
- generator,
+ tableEnv.getConfig,
+ false,
+ inputDS.getType,
+ None,
inputDS,
isTimeIntervalLiteral(size),
asLong(size),
@@ -152,7 +160,10 @@ class DataSetWindowAggregate(
}
private def createEventTimeTumblingWindowDataSet(
- generator: AggregationCodeGenerator,
+ config: TableConfig,
+ nullableInput: Boolean,
+ inputType: TypeInformation[_ <: Any],
+ constants: Option[Seq[RexLiteral]],
inputDS: DataSet[Row],
isTimeWindow: Boolean,
isParserCaseSensitive: Boolean,
@@ -161,7 +172,10 @@ class DataSetWindowAggregate(
val input = inputNode.asInstanceOf[DataSetRel]
val mapFunction = createDataSetWindowPrepareMapFunction(
- generator,
+ config,
+ nullableInput,
+ inputType,
+ constants,
window,
namedAggregates,
grouping,
@@ -170,7 +184,10 @@ class DataSetWindowAggregate(
isParserCaseSensitive,
tableConfig)
val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
- generator,
+ config,
+ nullableInput,
+ inputType,
+ constants,
window,
namedAggregates,
input.getRowType,
@@ -218,7 +235,10 @@ class DataSetWindowAggregate(
}
private[this] def createEventTimeSessionWindowDataSet(
- generator: AggregationCodeGenerator,
+ config: TableConfig,
+ nullableInput: Boolean,
+ inputTypeInfo: TypeInformation[_ <: Any],
+ constants: Option[Seq[RexLiteral]],
inputDS: DataSet[Row],
isParserCaseSensitive: Boolean,
tableConfig: TableConfig): DataSet[Row] = {
@@ -230,7 +250,10 @@ class DataSetWindowAggregate(
// create mapFunction for initializing the aggregations
val mapFunction = createDataSetWindowPrepareMapFunction(
- generator,
+ config,
+ nullableInput,
+ inputTypeInfo,
+ constants,
window,
namedAggregates,
grouping,
@@ -260,7 +283,10 @@ class DataSetWindowAggregate(
if (groupingKeys.length > 0) {
// create groupCombineFunction for combine the aggregations
val combineGroupFunction = createDataSetWindowAggregationCombineFunction(
- generator,
+ config,
+ nullableInput,
+ inputTypeInfo,
+ constants,
window,
namedAggregates,
input.getRowType,
@@ -270,7 +296,10 @@ class DataSetWindowAggregate(
// create groupReduceFunction for calculating the aggregations
val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
- generator,
+ config,
+ nullableInput,
+ inputTypeInfo,
+ constants,
window,
namedAggregates,
input.getRowType,
@@ -294,7 +323,10 @@ class DataSetWindowAggregate(
} else {
// non-grouping window
val mapPartitionFunction = createDataSetWindowAggregationMapPartitionFunction(
- generator,
+ config,
+ nullableInput,
+ inputTypeInfo,
+ constants,
window,
namedAggregates,
input.getRowType,
@@ -304,7 +336,10 @@ class DataSetWindowAggregate(
// create groupReduceFunction for calculating the aggregations
val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
- generator,
+ config,
+ nullableInput,
+ inputTypeInfo,
+ constants,
window,
namedAggregates,
input.getRowType,
@@ -331,7 +366,10 @@ class DataSetWindowAggregate(
// create groupReduceFunction for calculating the aggregations
val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
- generator,
+ config,
+ nullableInput,
+ inputTypeInfo,
+ constants,
window,
namedAggregates,
input.getRowType,
@@ -349,7 +387,10 @@ class DataSetWindowAggregate(
} else {
// non-grouping window
val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
- generator,
+ config,
+ nullableInput,
+ inputTypeInfo,
+ constants,
window,
namedAggregates,
input.getRowType,
@@ -369,7 +410,10 @@ class DataSetWindowAggregate(
}
private def createEventTimeSlidingWindowDataSet(
- generator: AggregationCodeGenerator,
+ config: TableConfig,
+ nullableInput: Boolean,
+ inputTypeInfo: TypeInformation[_ <: Any],
+ constants: Option[Seq[RexLiteral]],
inputDS: DataSet[Row],
isTimeWindow: Boolean,
size: Long,
@@ -383,7 +427,10 @@ class DataSetWindowAggregate(
// create MapFunction for initializing the aggregations
// it aligns the rowtime for pre-tumbling in case of a time-window for partial aggregates
val mapFunction = createDataSetWindowPrepareMapFunction(
- generator,
+ config,
+ nullableInput,
+ inputTypeInfo,
+ constants,
window,
namedAggregates,
grouping,
@@ -422,7 +469,10 @@ class DataSetWindowAggregate(
// create GroupReduceFunction
// for pre-tumbling and replicating/omitting the content for each pane
val prepareReduceFunction = createDataSetSlideWindowPrepareGroupReduceFunction(
- generator,
+ config,
+ nullableInput,
+ inputTypeInfo,
+ constants,
window,
namedAggregates,
grouping,
@@ -461,7 +511,10 @@ class DataSetWindowAggregate(
// create GroupReduceFunction for final aggregation and conversion to output row
val aggregateReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
- generator,
+ config,
+ nullableInput,
+ inputTypeInfo,
+ constants,
window,
namedAggregates,
input.getRowType,
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala
index fa4f07f..44e3ffd 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala
@@ -17,25 +17,17 @@
*/
package org.apache.flink.table.plan.nodes.datastream
-import java.lang.{Byte => JByte}
-
import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.core.AggregateCall
-import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
-import org.apache.flink.api.java.functions.NullByteKeySelector
-import org.apache.flink.streaming.api.datastream.DataStream
+import org.apache.calcite.rel.RelNode
import org.apache.flink.streaming.api.functions.KeyedProcessFunction
-import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvImpl}
-import org.apache.flink.table.codegen.AggregationCodeGenerator
-import org.apache.flink.table.plan.nodes.CommonAggregate
+import org.apache.flink.table.api.{StreamQueryConfig, TableConfig}
import org.apache.flink.table.plan.rules.datastream.DataStreamRetractionRules
import org.apache.flink.table.plan.schema.RowSchema
-import org.apache.flink.table.runtime.CRowKeySelector
import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair
import org.apache.flink.table.runtime.aggregate._
-import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
+import org.apache.flink.table.runtime.types.CRow
import org.apache.flink.table.util.Logging
-import org.apache.flink.types.Row
/**
*
@@ -58,21 +50,18 @@ class DataStreamGroupAggregate(
schema: RowSchema,
inputSchema: RowSchema,
groupings: Array[Int])
- extends SingleRel(cluster, traitSet, inputNode)
- with CommonAggregate
+ extends DataStreamGroupAggregateBase(
+ cluster,
+ traitSet,
+ inputNode,
+ namedAggregates,
+ schema,
+ inputSchema,
+ groupings,
+ "Aggregate")
with DataStreamRel
with Logging {
- override def deriveRowType() = schema.relDataType
-
- override def needsUpdatesAsRetraction = true
-
- override def producesUpdates = true
-
- override def consumesRetractions = true
-
- def getGroupings: Array[Int] = groupings
-
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
new DataStreamGroupAggregate(
cluster,
@@ -84,92 +73,22 @@ class DataStreamGroupAggregate(
groupings)
}
- override def toString: String = {
- s"Aggregate(${
- if (!groupings.isEmpty) {
- s"groupBy: (${groupingToString(inputSchema.relDataType, groupings)}), "
- } else {
- ""
- }
- }select:(${aggregationToString(
- inputSchema.relDataType, groupings, getRowType, namedAggregates, Nil)}))"
- }
-
- override def explainTerms(pw: RelWriter): RelWriter = {
- super.explainTerms(pw)
- .itemIf("groupBy", groupingToString(
- inputSchema.relDataType, groupings), !groupings.isEmpty)
- .item("select", aggregationToString(
- inputSchema.relDataType, groupings, getRowType, namedAggregates, Nil))
- }
-
- override def translateToPlan(
- tableEnv: StreamTableEnvImpl,
- queryConfig: StreamQueryConfig): DataStream[CRow] = {
+ override def createKeyedProcessFunction[K](
+ tableConfig: TableConfig,
+ queryConfig: StreamQueryConfig): KeyedProcessFunction[K, CRow, CRow] = {
- if (groupings.length > 0 && queryConfig.getMinIdleStateRetentionTime < 0) {
- LOG.warn(
- "No state retention interval configured for a query which accumulates state. " +
- "Please provide a query configuration with valid retention interval to prevent excessive " +
- "state size. You may specify a retention time of 0 to not clean up the state.")
- }
-
- val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(tableEnv, queryConfig)
-
- val outRowType = CRowTypeInfo(schema.typeInfo)
-
- val generator = new AggregationCodeGenerator(
- tableEnv.getConfig,
+ AggregateUtil.createGroupAggregateFunction[K](
+ tableConfig,
false,
inputSchema.typeInfo,
- None)
-
- val aggString = aggregationToString(
+ None,
+ namedAggregates,
inputSchema.relDataType,
+ inputSchema.fieldTypeInfos,
groupings,
- getRowType,
- namedAggregates,
- Nil)
-
- val keyedAggOpName = s"groupBy: (${groupingToString(inputSchema.relDataType, groupings)}), " +
- s"select: ($aggString)"
- val nonKeyedAggOpName = s"select: ($aggString)"
-
- def createKeyedProcessFunction[K]: KeyedProcessFunction[K, CRow, CRow] = {
- AggregateUtil.createGroupAggregateFunction[K](
- generator,
- namedAggregates,
- inputSchema.relDataType,
- inputSchema.fieldTypeInfos,
- groupings,
- queryConfig,
- tableEnv.getConfig,
- DataStreamRetractionRules.isAccRetract(this),
- DataStreamRetractionRules.isAccRetract(getInput))
- }
-
- val result: DataStream[CRow] =
- // grouped / keyed aggregation
- if (groupings.nonEmpty) {
- inputDS
- .keyBy(new CRowKeySelector(groupings, inputSchema.projectedTypeInfo(groupings)))
- .process(createKeyedProcessFunction[Row])
- .returns(outRowType)
- .name(keyedAggOpName)
- .asInstanceOf[DataStream[CRow]]
- }
- // global / non-keyed aggregation
- else {
- inputDS
- .keyBy(new NullByteKeySelector[CRow])
- .process(createKeyedProcessFunction[JByte])
- .setParallelism(1)
- .setMaxParallelism(1)
- .returns(outRowType)
- .name(nonKeyedAggOpName)
- .asInstanceOf[DataStream[CRow]]
- }
- result
+ queryConfig,
+ DataStreamRetractionRules.isAccRetract(this),
+ DataStreamRetractionRules.isAccRetract(getInput))
}
}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregateBase.scala
similarity index 80%
copy from flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala
copy to flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregateBase.scala
index fa4f07f..334161a 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregateBase.scala
@@ -25,21 +25,18 @@ import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
import org.apache.flink.api.java.functions.NullByteKeySelector
import org.apache.flink.streaming.api.datastream.DataStream
import org.apache.flink.streaming.api.functions.KeyedProcessFunction
-import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvImpl}
-import org.apache.flink.table.codegen.AggregationCodeGenerator
+import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvImpl, TableConfig}
import org.apache.flink.table.plan.nodes.CommonAggregate
-import org.apache.flink.table.plan.rules.datastream.DataStreamRetractionRules
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.table.runtime.CRowKeySelector
import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair
-import org.apache.flink.table.runtime.aggregate._
import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
import org.apache.flink.table.util.Logging
import org.apache.flink.types.Row
/**
*
- * Flink RelNode for data stream unbounded group aggregate
+ * Base RelNode for data stream unbounded group aggregate and unbounded group table aggregate.
*
* @param cluster Cluster of the RelNode, represent for an environment of related
* relational expressions during the optimization of a query.
@@ -49,15 +46,17 @@ import org.apache.flink.types.Row
* @param inputSchema The type of the rows consumed by this RelNode
* @param schema The type of the rows emitted by this RelNode
* @param groupings The position (in the input Row) of the grouping keys
+ * @param aggTypeName Name of the aggregate kind; prefixes toString, e.g. "TableAggregate"
*/
-class DataStreamGroupAggregate(
+abstract class DataStreamGroupAggregateBase(
cluster: RelOptCluster,
traitSet: RelTraitSet,
inputNode: RelNode,
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
schema: RowSchema,
inputSchema: RowSchema,
- groupings: Array[Int])
+ groupings: Array[Int],
+ aggTypeName: String)
extends SingleRel(cluster, traitSet, inputNode)
with CommonAggregate
with DataStreamRel
@@ -73,19 +72,8 @@ class DataStreamGroupAggregate(
def getGroupings: Array[Int] = groupings
- override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
- new DataStreamGroupAggregate(
- cluster,
- traitSet,
- inputs.get(0),
- namedAggregates,
- schema,
- inputSchema,
- groupings)
- }
-
override def toString: String = {
- s"Aggregate(${
+ s"$aggTypeName(${
if (!groupings.isEmpty) {
s"groupBy: (${groupingToString(inputSchema.relDataType, groupings)}), "
} else {
@@ -103,6 +91,10 @@ class DataStreamGroupAggregate(
inputSchema.relDataType, groupings, getRowType, namedAggregates, Nil))
}
+ protected def createKeyedProcessFunction[K](
+ tableConfig: TableConfig,
+ queryConfig: StreamQueryConfig): KeyedProcessFunction[K, CRow, CRow]
+
override def translateToPlan(
tableEnv: StreamTableEnvImpl,
queryConfig: StreamQueryConfig): DataStream[CRow] = {
@@ -118,12 +110,6 @@ class DataStreamGroupAggregate(
val outRowType = CRowTypeInfo(schema.typeInfo)
- val generator = new AggregationCodeGenerator(
- tableEnv.getConfig,
- false,
- inputSchema.typeInfo,
- None)
-
val aggString = aggregationToString(
inputSchema.relDataType,
groupings,
@@ -135,25 +121,12 @@ class DataStreamGroupAggregate(
s"select: ($aggString)"
val nonKeyedAggOpName = s"select: ($aggString)"
- def createKeyedProcessFunction[K]: KeyedProcessFunction[K, CRow, CRow] = {
- AggregateUtil.createGroupAggregateFunction[K](
- generator,
- namedAggregates,
- inputSchema.relDataType,
- inputSchema.fieldTypeInfos,
- groupings,
- queryConfig,
- tableEnv.getConfig,
- DataStreamRetractionRules.isAccRetract(this),
- DataStreamRetractionRules.isAccRetract(getInput))
- }
-
val result: DataStream[CRow] =
// grouped / keyed aggregation
if (groupings.nonEmpty) {
inputDS
.keyBy(new CRowKeySelector(groupings, inputSchema.projectedTypeInfo(groupings)))
- .process(createKeyedProcessFunction[Row])
+ .process(createKeyedProcessFunction[Row](tableEnv.getConfig, queryConfig))
.returns(outRowType)
.name(keyedAggOpName)
.asInstanceOf[DataStream[CRow]]
@@ -162,7 +135,7 @@ class DataStreamGroupAggregate(
else {
inputDS
.keyBy(new NullByteKeySelector[CRow])
- .process(createKeyedProcessFunction[JByte])
+ .process(createKeyedProcessFunction[JByte](tableEnv.getConfig, queryConfig))
.setParallelism(1)
.setMaxParallelism(1)
.returns(outRowType)
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupTableAggregate.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupTableAggregate.scala
new file mode 100644
index 0000000..1a4adc9
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupTableAggregate.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.plan.nodes.datastream
+
+import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
+import org.apache.calcite.rel.core.AggregateCall
+import org.apache.calcite.rel.RelNode
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction
+import org.apache.flink.table.api.{StreamQueryConfig, TableConfig}
+import org.apache.flink.table.plan.nodes.CommonTableAggregate
+import org.apache.flink.table.plan.rules.datastream.DataStreamRetractionRules
+import org.apache.flink.table.plan.schema.RowSchema
+import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair
+import org.apache.flink.table.runtime.aggregate._
+import org.apache.flink.table.runtime.types.CRow
+import org.apache.flink.table.util.Logging
+
+/**
+ * Flink RelNode for data stream unbounded table aggregate.
+ *
+ * @param cluster Cluster of the RelNode; represents an environment of related
+ * relational expressions during the optimization of a query.
+ * @param traitSet Trait set of the RelNode
+ * @param inputNode The input RelNode of aggregation
+ * @param schema The type of the rows emitted by this RelNode
+ * @param inputSchema The type of the rows consumed by this RelNode
+ * @param namedAggregates List of calls to aggregate functions and their output field names
+ * @param groupings The position (in the input Row) of the grouping keys
+ */
+class DataStreamGroupTableAggregate(
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ inputNode: RelNode,
+ schema: RowSchema,
+ inputSchema: RowSchema,
+ val namedAggregates: Seq[CalcitePair[AggregateCall, String]],
+ val groupings: Array[Int])
+ extends DataStreamGroupAggregateBase(
+ cluster,
+ traitSet,
+ inputNode,
+ namedAggregates,
+ schema,
+ inputSchema,
+ groupings,
+ "TableAggregate")
+ with CommonTableAggregate
+ with DataStreamRel
+ with Logging {
+
+ override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
+ new DataStreamGroupTableAggregate(
+ cluster,
+ traitSet,
+ inputs.get(0),
+ schema,
+ inputSchema,
+ namedAggregates,
+ groupings)
+ }
+
+ override def createKeyedProcessFunction[K](
+ tableConfig: TableConfig,
+ queryConfig: StreamQueryConfig): KeyedProcessFunction[K, CRow, CRow] = {
+
+ val tableAggOutputRowType = new RowTypeInfo(
+ schema.fieldTypeInfos.drop(groupings.length).toArray,
+ schema.fieldNames.drop(groupings.length).toArray)
+
+ AggregateUtil.createGroupTableAggregateFunction[K](
+ tableConfig,
+ false,
+ inputSchema.typeInfo,
+ None,
+ namedAggregates,
+ inputSchema.relDataType,
+ inputSchema.fieldTypeInfos,
+ tableAggOutputRowType,
+ groupings,
+ queryConfig,
+ DataStreamRetractionRules.isAccRetract(this),
+ DataStreamRetractionRules.isAccRetract(getInput))
+ }
+}
+
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala
index aa69f93..d7a1851 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala
@@ -28,7 +28,6 @@ import org.apache.flink.streaming.api.windowing.triggers.PurgingTrigger
import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow}
import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvImpl, TableException}
import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
-import org.apache.flink.table.codegen.AggregationCodeGenerator
import org.apache.flink.table.expressions.PlannerExpressionUtils._
import org.apache.flink.table.expressions.ResolvedFieldReference
import org.apache.flink.table.plan.logical._
@@ -173,12 +172,6 @@ class DataStreamGroupWindowAggregate(
s"select: ($aggString)"
val nonKeyedAggOpName = s"window: ($window), select: ($aggString)"
- val generator = new AggregationCodeGenerator(
- tableEnv.getConfig,
- false,
- inputSchema.typeInfo,
- None)
-
val needMerge = window match {
case SessionGroupWindow(_, _, _) => true
case _ => false
@@ -201,7 +194,10 @@ class DataStreamGroupWindowAggregate(
val (aggFunction, accumulatorRowType) =
AggregateUtil.createDataStreamAggregateFunction(
- generator,
+ tableEnv.getConfig,
+ false,
+ inputSchema.typeInfo,
+ None,
namedAggregates,
inputSchema.relDataType,
inputSchema.fieldTypeInfos,
@@ -227,7 +223,10 @@ class DataStreamGroupWindowAggregate(
val (aggFunction, accumulatorRowType) =
AggregateUtil.createDataStreamAggregateFunction(
- generator,
+ tableEnv.getConfig,
+ false,
+ inputSchema.typeInfo,
+ None,
namedAggregates,
inputSchema.relDataType,
inputSchema.fieldTypeInfos,
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
index 3df866d..2d2fb87 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
@@ -27,12 +27,12 @@ import org.apache.calcite.rel.core.Window.Group
import org.apache.calcite.rel.core.{AggregateCall, Window}
import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
import org.apache.calcite.rex.RexLiteral
+import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.functions.NullByteKeySelector
import org.apache.flink.streaming.api.datastream.DataStream
import org.apache.flink.streaming.api.functions.KeyedProcessFunction
import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvImpl, TableConfig, TableException}
import org.apache.flink.table.calcite.FlinkTypeFactory
-import org.apache.flink.table.codegen.AggregationCodeGenerator
import org.apache.flink.table.plan.nodes.OverAggregate
import org.apache.flink.table.plan.rules.datastream.DataStreamRetractionRules
import org.apache.flink.table.plan.schema.RowSchema
@@ -142,12 +142,6 @@ class DataStreamOverAggregate(
val constants: Seq[RexLiteral] = logicWindow.constants.asScala
- val generator = new AggregationCodeGenerator(
- tableEnv.getConfig,
- false,
- inputSchema.typeInfo,
- Some(constants))
-
val constantTypes = constants.map(_.getType)
val fieldTypes = input.getRowType.getFieldList.asScala.map(_.getType)
val aggInTypes = fieldTypes ++ constantTypes
@@ -176,7 +170,9 @@ class DataStreamOverAggregate(
createUnboundedAndCurrentRowOverWindow(
queryConfig,
tableEnv.getConfig,
- generator,
+ false,
+ inputSchema.typeInfo,
+ Some(constants),
inputDS,
rowTimeIdx,
aggregateInputType,
@@ -188,7 +184,10 @@ class DataStreamOverAggregate(
// bounded OVER window
createBoundedAndCurrentRowOverWindow(
queryConfig,
- generator,
+ tableEnv.getConfig,
+ false,
+ inputSchema.typeInfo,
+ Some(constants),
inputDS,
rowTimeIdx,
aggregateInputType,
@@ -202,7 +201,9 @@ class DataStreamOverAggregate(
def createUnboundedAndCurrentRowOverWindow(
queryConfig: StreamQueryConfig,
tableConfig: TableConfig,
- generator: AggregationCodeGenerator,
+ nullableInput: Boolean,
+ inputTypeInfo: TypeInformation[_ <: Any],
+ constants: Option[Seq[RexLiteral]],
inputDS: DataStream[CRow],
rowTimeIdx: Option[Int],
aggregateInputType: RelDataType,
@@ -219,7 +220,10 @@ class DataStreamOverAggregate(
def createKeyedProcessFunction[K]: KeyedProcessFunction[K, CRow, CRow] = {
AggregateUtil.createUnboundedOverProcessFunction[K](
- generator,
+ tableConfig,
+ nullableInput,
+ inputTypeInfo,
+ constants,
namedAggregates,
aggregateInputType,
inputSchema.relDataType,
@@ -254,7 +258,10 @@ class DataStreamOverAggregate(
def createBoundedAndCurrentRowOverWindow(
queryConfig: StreamQueryConfig,
- generator: AggregationCodeGenerator,
+ config: TableConfig,
+ nullableInput: Boolean,
+ inputTypeInfo: TypeInformation[_ <: Any],
+ constants: Option[Seq[RexLiteral]],
inputDS: DataStream[CRow],
rowTimeIdx: Option[Int],
aggregateInputType: RelDataType,
@@ -275,7 +282,10 @@ class DataStreamOverAggregate(
def createKeyedProcessFunction[K]: KeyedProcessFunction[K, CRow, CRow] = {
AggregateUtil.createBoundedOverProcessFunction[K](
- generator,
+ config,
+ nullableInput,
+ inputTypeInfo,
+ constants,
namedAggregates,
aggregateInputType,
inputSchema.relDataType,
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTableAggregate.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTableAggregate.scala
new file mode 100644
index 0000000..9334716
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTableAggregate.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.nodes.logical
+
+import java.util
+import java.util.{List => JList}
+
+import org.apache.calcite.plan._
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.{RelNode, SingleRel}
+import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.calcite.rel.core.AggregateCall
+import org.apache.calcite.util.ImmutableBitSet
+import org.apache.flink.table.plan.logical.rel.LogicalTableAggregate
+import org.apache.flink.table.plan.nodes.{CommonTableAggregate, FlinkConventions}
+
+class FlinkLogicalTableAggregate(
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ child: RelNode,
+ val indicator: Boolean,
+ val groupSet: ImmutableBitSet,
+ groupSets: util.List[ImmutableBitSet],
+ val aggCalls: util.List[AggregateCall])
+ extends SingleRel(cluster, traitSet, child)
+ with FlinkLogicalRel
+ with CommonTableAggregate {
+
+ override def copy(traitSet: RelTraitSet, inputs: JList[RelNode]): RelNode = {
+ new FlinkLogicalTableAggregate(
+ cluster,
+ traitSet,
+ inputs.get(0),
+ indicator,
+ groupSet,
+ groupSets,
+ aggCalls
+ )
+ }
+
+ override def deriveRowType(): RelDataType = {
+ deriveTableAggRowType(cluster, child, groupSet, aggCalls)
+ }
+}
+
+private class FlinkLogicalTableAggregateConverter
+ extends ConverterRule(
+ classOf[LogicalTableAggregate],
+ Convention.NONE,
+ FlinkConventions.LOGICAL,
+ "FlinkLogicalTableAggregateConverter") {
+
+ override def convert(rel: RelNode): RelNode = {
+ val agg = rel.asInstanceOf[LogicalTableAggregate]
+ val traitSet = rel.getTraitSet.replace(FlinkConventions.LOGICAL)
+ val newInput = RelOptRule.convert(agg.getInput, FlinkConventions.LOGICAL)
+
+ new FlinkLogicalTableAggregate(
+ rel.getCluster,
+ traitSet,
+ newInput,
+ agg.indicator,
+ agg.groupSet,
+ agg.groupSets,
+ agg.aggCalls)
+ }
+}
+
+object FlinkLogicalTableAggregate {
+ val CONVERTER: ConverterRule = new FlinkLogicalTableAggregateConverter()
+}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
index bf83fda..3107700 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
@@ -140,7 +140,8 @@ object FlinkRuleSets {
FlinkLogicalTableSourceScan.CONVERTER,
FlinkLogicalTableFunctionScan.CONVERTER,
FlinkLogicalNativeTableScan.CONVERTER,
- FlinkLogicalMatch.CONVERTER
+ FlinkLogicalMatch.CONVERTER,
+ FlinkLogicalTableAggregate.CONVERTER
)
/**
@@ -231,7 +232,8 @@ object FlinkRuleSets {
DataStreamJoinRule.INSTANCE,
DataStreamTemporalTableJoinRule.INSTANCE,
StreamTableSourceScanRule.INSTANCE,
- DataStreamMatchRule.INSTANCE
+ DataStreamMatchRule.INSTANCE,
+ DataStreamTableAggregateRule.INSTANCE
)
/**
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamTableAggregateRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamTableAggregateRule.scala
new file mode 100644
index 0000000..58ab2c7
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamTableAggregateRule.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.rules.datastream
+
+import org.apache.calcite.plan.{RelOptRule, RelTraitSet}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.flink.table.plan.nodes.FlinkConventions
+import org.apache.flink.table.plan.nodes.datastream.DataStreamGroupTableAggregate
+import org.apache.flink.table.plan.nodes.logical.FlinkLogicalTableAggregate
+import org.apache.flink.table.plan.schema.RowSchema
+
+import scala.collection.JavaConversions._
+
+/**
+ * Rule to convert a [[FlinkLogicalTableAggregate]] into a [[DataStreamGroupTableAggregate]].
+ */
+class DataStreamTableAggregateRule
+ extends ConverterRule(
+ classOf[FlinkLogicalTableAggregate],
+ FlinkConventions.LOGICAL,
+ FlinkConventions.DATASTREAM,
+ "DataStreamTableAggregateRule") {
+
+ override def convert(rel: RelNode): RelNode = {
+ val agg: FlinkLogicalTableAggregate = rel.asInstanceOf[FlinkLogicalTableAggregate]
+ val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASTREAM)
+ val convInput: RelNode = RelOptRule.convert(agg.getInput, FlinkConventions.DATASTREAM)
+
+ new DataStreamGroupTableAggregate(
+ rel.getCluster,
+ traitSet,
+ convInput,
+ new RowSchema(rel.getRowType),
+ new RowSchema(agg.getInput.getRowType),
+ agg.getNamedAggCalls(agg.aggCalls, agg.deriveRowType(), agg.indicator, agg.groupSet),
+ agg.groupSet.toArray)
+ }
+}
+
+object DataStreamTableAggregateRule {
+ val INSTANCE: DataStreamTableAggregateRule = new DataStreamTableAggregateRule()
+}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index 1d35082..b135bc7 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -22,6 +22,7 @@ import java.util.{ArrayList => JArrayList, List => JList}
import org.apache.calcite.rel.`type`._
import org.apache.calcite.rel.core.AggregateCall
+import org.apache.calcite.rex.RexLiteral
import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.calcite.sql.`type`.SqlTypeName._
import org.apache.calcite.sql.fun._
@@ -43,7 +44,7 @@ import org.apache.flink.table.expressions._
import org.apache.flink.table.functions.aggfunctions._
import org.apache.flink.table.functions.utils.AggSqlFunction
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
-import org.apache.flink.table.functions.{AggregateFunction => TableAggregateFunction}
+import org.apache.flink.table.functions.{AggregateFunction, UserDefinedAggregateFunction}
import org.apache.flink.table.plan.logical._
import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
import org.apache.flink.table.typeutils.TypeCheckUtils._
@@ -62,18 +63,25 @@ object AggregateUtil {
* Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for unbounded OVER
* window to evaluate final aggregate value.
*
- * @param generator code generator instance
- * @param namedAggregates Physical calls to aggregate functions and their output field names
+ * @param config configuration that determines runtime behavior
+ * @param nullableInput input(s) can be null.
+ * @param input type information about the input of the Function
+ * @param constants constant expressions that act like a second input in the
+ * parameter indices.
+ * @param namedAggregates Physical calls to aggregate functions and their output field names
* @param aggregateInputType Physical type of the aggregate functions's input row.
- * @param inputType Physical type of the row.
- * @param inputTypeInfo Physical type information of the row.
+ * @param inputType Physical type of the row.
+ * @param inputTypeInfo Physical type information of the row.
* @param inputFieldTypeInfo Physical type information of the row's fields.
- * @param rowTimeIdx The index of the rowtime field or None in case of processing time.
- * @param isPartitioned It is a tag that indicate whether the input is partitioned
- * @param isRowsClause It is a tag that indicates whether the OVER clause is ROWS clause
+ * @param rowTimeIdx The index of the rowtime field or None in case of processing time.
+ * @param isPartitioned It is a tag that indicates whether the input is partitioned
+ * @param isRowsClause It is a tag that indicates whether the OVER clause is ROWS clause
*/
private[flink] def createUnboundedOverProcessFunction[K](
- generator: AggregationCodeGenerator,
+ config: TableConfig,
+ nullableInput: Boolean,
+ input: TypeInformation[_ <: Any],
+ constants: Option[Seq[RexLiteral]],
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
aggregateInputType: RelDataType,
inputType: RelDataType,
@@ -100,7 +108,11 @@ object AggregateUtil {
val outputArity = inputType.getFieldCount + aggregateMetadata.getAggregateCallsCount
- val genFunction = generator.generateAggregations(
+ val generator = new AggregationCodeGenerator(
+ config,
+ nullableInput,
+ input,
+ constants,
"UnboundedProcessingOverAggregateHelper",
inputFieldTypeInfo,
aggregateMetadata.getAggregateFunctions,
@@ -120,6 +132,8 @@ object AggregateUtil {
val aggregationStateType: RowTypeInfo = new RowTypeInfo(aggregateMetadata
.getAggregatesAccumulatorTypes: _*)
+ val genFunction = generator.generateAggregations
+
if (rowTimeIdx.isDefined) {
if (isRowsClause) {
// ROWS unbounded over process function
@@ -150,24 +164,30 @@ object AggregateUtil {
* Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for group (without
* window) aggregate to evaluate final aggregate value.
*
- * @param generator code generator instance
- * @param namedAggregates List of calls to aggregate functions and their output field names
- * @param inputRowType Input row type
- * @param inputFieldTypes Types of the physical input fields
- * @param groupings the position (in the input Row) of the grouping keys
- * @param queryConfig The configuration of the query to generate.
+ * @param config configuration that determines runtime behavior
+ * @param nullableInput input(s) can be null.
+ * @param input type information about the input of the Function
+ * @param constants constant expressions that act like a second input in the
+ * parameter indices.
+ * @param namedAggregates List of calls to aggregate functions and their output field names
+ * @param inputRowType Input row type
+ * @param inputFieldTypes Types of the physical input fields
+ * @param groupings the position (in the input Row) of the grouping keys
+ * @param queryConfig The configuration of the query to generate.
* @param generateRetraction It is a tag that indicates whether generate retract record.
- * @param consumeRetraction It is a tag that indicates whether consume the retract record.
+ * @param consumeRetraction It is a tag that indicates whether to consume retract records.
* @return [[org.apache.flink.streaming.api.functions.ProcessFunction]]
*/
private[flink] def createGroupAggregateFunction[K](
- generator: AggregationCodeGenerator,
+ config: TableConfig,
+ nullableInput: Boolean,
+ input: TypeInformation[_ <: Any],
+ constants: Option[Seq[RexLiteral]],
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
inputRowType: RelDataType,
inputFieldTypes: Seq[TypeInformation[_]],
groupings: Array[Int],
queryConfig: StreamQueryConfig,
- tableConfig: TableConfig,
generateRetraction: Boolean,
consumeRetraction: Boolean): KeyedProcessFunction[K, CRow, CRow] = {
@@ -176,13 +196,17 @@ object AggregateUtil {
inputRowType,
inputFieldTypes.length,
consumeRetraction,
- tableConfig,
+ config,
isStateBackedDataViews = true)
val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length)
val outputArity = groupings.length + aggregateMetadata.getAggregateCallsCount
- val genFunction = generator.generateAggregations(
+ val generator = new AggregationCodeGenerator(
+ config: TableConfig,
+ nullableInput: Boolean,
+ input: TypeInformation[_ <: Any],
+ constants: Option[Seq[RexLiteral]],
"NonWindowedAggregationHelper",
inputFieldTypes,
aggregateMetadata.getAggregateFunctions,
@@ -203,7 +227,84 @@ object AggregateUtil {
val aggregationStateType: RowTypeInfo = new RowTypeInfo(aggregateMetadata
.getAggregatesAccumulatorTypes: _*)
new GroupAggProcessFunction[K](
- genFunction,
+ generator.generateAggregations,
+ aggregationStateType,
+ generateRetraction,
+ queryConfig)
+
+ }
+
+ /**
+ * Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for group (without
+ * window) aggregate to evaluate final table aggregate value.
+ *
+ * @param config configuration that determines runtime behavior
+ * @param nullableInput input(s) can be null.
+ * @param input type information about the input of the Function
+ * @param constants constant expressions that act like a second input in the
+ * parameter indices.
+ * @param namedAggregates List of calls to aggregate functions and their output field names
+ * @param inputRowType Input row type
+ * @param inputFieldTypes Types of the physical input fields
+ * @param groupings the position (in the input Row) of the grouping keys
+ * @param queryConfig The configuration of the query to generate.
+ * @param generateRetraction It is a tag that indicates whether to generate retract records.
+ * @param consumeRetraction It is a tag that indicates whether to consume retract records.
+ * @return [[org.apache.flink.streaming.api.functions.ProcessFunction]]
+ */
+ private[flink] def createGroupTableAggregateFunction[K](
+ config: TableConfig,
+ nullableInput: Boolean,
+ input: TypeInformation[_ <: Any],
+ constants: Option[Seq[RexLiteral]],
+ namedAggregates: Seq[CalcitePair[AggregateCall, String]],
+ inputRowType: RelDataType,
+ inputFieldTypes: Seq[TypeInformation[_]],
+ tableAggOutputRowType: RowTypeInfo,
+ groupings: Array[Int],
+ queryConfig: StreamQueryConfig,
+ generateRetraction: Boolean,
+ consumeRetraction: Boolean): KeyedProcessFunction[K, CRow, CRow] = {
+
+ val aggregateMetadata = extractAggregateMetadata(
+ namedAggregates.map(_.getKey),
+ inputRowType,
+ inputFieldTypes.length,
+ consumeRetraction,
+ config,
+ isStateBackedDataViews = true)
+
+ val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length)
+ val outputArity = groupings.length + tableAggOutputRowType.getTotalFields
+ val tableAggOutputType = namedAggregates
+ .head.left.getAggregation.asInstanceOf[AggSqlFunction].returnType
+
+ val generator = new AggregationCodeGenerator(
+ config,
+ nullableInput,
+ input,
+ constants,
+ "NonWindowedTableAggregationHelper",
+ inputFieldTypes,
+ aggregateMetadata.getAggregateFunctions,
+ aggregateMetadata.getAggregateIndices,
+ aggMapping,
+ aggregateMetadata.getDistinctAccMapping,
+ isStateBackedDataViews = true,
+ partialResults = false,
+ groupings,
+ None,
+ outputArity,
+ consumeRetraction,
+ needMerge = false,
+ needReset = false,
+ accConfig = Some(aggregateMetadata.getAggregatesAccumulatorSpecs)
+ )
+
+ val aggregationStateType: RowTypeInfo = new RowTypeInfo(aggregateMetadata
+ .getAggregatesAccumulatorTypes: _*)
+ new GroupTableAggProcessFunction[K](
+ generator.generateTableAggregations(tableAggOutputRowType, tableAggOutputType),
aggregationStateType,
generateRetraction,
queryConfig)
@@ -214,19 +315,26 @@ object AggregateUtil {
* Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for ROWS clause
* bounded OVER window to evaluate final aggregate value.
*
- * @param generator code generator instance
- * @param namedAggregates Physical calls to aggregate functions and their output field names
+ * @param config configuration that determines runtime behavior
+ * @param nullableInput input(s) can be null.
+ * @param input type information about the input of the Function
+ * @param constants constant expressions that act like a second input in the
+ * parameter indices.
+ * @param namedAggregates Physical calls to aggregate functions and their output field names
* @param aggregateInputType Physical type of the aggregate functions's input row.
- * @param inputType Physical type of the row.
- * @param inputTypeInfo Physical type information of the row.
+ * @param inputType Physical type of the row.
+ * @param inputTypeInfo Physical type information of the row.
* @param inputFieldTypeInfo Physical type information of the row's fields.
- * @param precedingOffset the preceding offset
- * @param isRowsClause It is a tag that indicates whether the OVER clause is ROWS clause
- * @param rowTimeIdx The index of the rowtime field or None in case of processing time.
+ * @param precedingOffset the preceding offset
+ * @param isRowsClause It is a tag that indicates whether the OVER clause is ROWS clause
+ * @param rowTimeIdx The index of the rowtime field or None in case of processing time.
* @return [[org.apache.flink.streaming.api.functions.ProcessFunction]]
*/
private[flink] def createBoundedOverProcessFunction[K](
- generator: AggregationCodeGenerator,
+ config: TableConfig,
+ nullableInput: Boolean,
+ input: TypeInformation[_ <: Any],
+ constants: Option[Seq[RexLiteral]],
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
aggregateInputType: RelDataType,
inputType: RelDataType,
@@ -255,7 +363,11 @@ object AggregateUtil {
val outputArity = inputType.getFieldCount + aggregateMetadata.getAggregateCallsCount
- val genFunction = generator.generateAggregations(
+ val generator = new AggregationCodeGenerator(
+ config,
+ nullableInput,
+ input,
+ constants,
"BoundedOverAggregateHelper",
inputFieldTypeInfo,
aggregateMetadata.getAggregateFunctions,
@@ -272,6 +384,7 @@ object AggregateUtil {
needReset = false,
accConfig = Some(aggregateMetadata.getAggregatesAccumulatorSpecs)
)
+ val genFunction = generator.generateAggregations
val aggregationStateType: RowTypeInfo = new RowTypeInfo(aggregateMetadata
.getAggregatesAccumulatorTypes: _*)
@@ -336,7 +449,10 @@ object AggregateUtil {
* NOTE: this function is only used for time based window on batch tables.
*/
def createDataSetWindowPrepareMapFunction(
- generator: AggregationCodeGenerator,
+ config: TableConfig,
+ nullableInput: Boolean,
+ inputTypeInfo: TypeInformation[_ <: Any],
+ constants: Option[Seq[RexLiteral]],
window: LogicalWindow,
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
groupings: Array[Int],
@@ -394,7 +510,11 @@ object AggregateUtil {
val outputArity = aggregateMetadata.getAggregateCallsCount + groupings.length +
aggregateMetadata.getDistinctAccCount + 1
- val genFunction = generator.generateAggregations(
+ val generator = new AggregationCodeGenerator(
+ config,
+ nullableInput,
+ inputTypeInfo,
+ constants,
"DataSetAggregatePrepareMapHelper",
inputFieldTypeInfo,
aggregateMetadata.getAggregateFunctions,
@@ -413,7 +533,7 @@ object AggregateUtil {
)
new DataSetWindowAggMapFunction(
- genFunction,
+ generator.generateAggregations,
timeFieldPos,
tumbleTimeWindowSize,
mapReturnType)
@@ -447,7 +567,10 @@ object AggregateUtil {
* NOTE: this function is only used for sliding windows with partial aggregates on batch tables.
*/
def createDataSetSlideWindowPrepareGroupReduceFunction(
- generator: AggregationCodeGenerator,
+ config: TableConfig,
+ nullableInput: Boolean,
+ inputTypeInfo: TypeInformation[_ <: Any],
+ constants: Option[Seq[RexLiteral]],
window: LogicalWindow,
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
groupings: Array[Int],
@@ -479,7 +602,11 @@ object AggregateUtil {
window match {
case SlidingGroupWindow(_, _, size, slide) if isTimeInterval(size.resultType) =>
// sliding time-window for partial aggregations
- val genFunction = generator.generateAggregations(
+ val generator = new AggregationCodeGenerator(
+ config,
+ nullableInput,
+ inputTypeInfo,
+ constants,
"DataSetAggregatePrepareMapHelper",
physicalInputTypes,
aggregateMetadata.getAggregateFunctions,
@@ -497,7 +624,7 @@ object AggregateUtil {
None
)
new DataSetSlideTimeWindowAggReduceGroupFunction(
- genFunction,
+ generator.generateAggregations,
keysAndAggregatesArity,
asLong(size),
asLong(slide),
@@ -565,7 +692,10 @@ object AggregateUtil {
* NOTE: this function is only used for window on batch tables.
*/
def createDataSetWindowAggregationGroupReduceFunction(
- generator: AggregationCodeGenerator,
+ config: TableConfig,
+ nullableInput: Boolean,
+ inputTypeInfo: TypeInformation[_ <: Any],
+ constants: Option[Seq[RexLiteral]],
window: LogicalWindow,
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
physicalInputRowType: RelDataType,
@@ -587,7 +717,11 @@ object AggregateUtil {
val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length, partialResults = true)
- val genPreAggFunction = generator.generateAggregations(
+ val generatorPre = new AggregationCodeGenerator(
+ config,
+ nullableInput,
+ inputTypeInfo,
+ constants,
"GroupingWindowAggregateHelper",
physicalInputTypes,
aggregateMetadata.getAggregateFunctions,
@@ -604,8 +738,13 @@ object AggregateUtil {
needReset = true,
None
)
+ val genPreAggFunction = generatorPre.generateAggregations
- val genFinalAggFunction = generator.generateAggregations(
+ val generatorFinal = new AggregationCodeGenerator(
+ config,
+ nullableInput,
+ inputTypeInfo,
+ constants,
"GroupingWindowAggregateHelper",
physicalInputTypes,
aggregateMetadata.getAggregateFunctions,
@@ -622,6 +761,7 @@ object AggregateUtil {
needReset = true,
None
)
+ val genFinalAggFunction = generatorFinal.generateAggregations
val keysAndAggregatesArity = groupings.length + namedAggregates.length
@@ -727,7 +867,10 @@ object AggregateUtil {
*
*/
def createDataSetWindowAggregationMapPartitionFunction(
- generator: AggregationCodeGenerator,
+ config: TableConfig,
+ nullableInput: Boolean,
+ inputTypeInfo: TypeInformation[_ <: Any],
+ constants: Option[Seq[RexLiteral]],
window: LogicalWindow,
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
physicalInputRowType: RelDataType,
@@ -757,7 +900,11 @@ object AggregateUtil {
physicalInputRowType,
Option(Array(BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO)))
- val genFunction = generator.generateAggregations(
+ val generator = new AggregationCodeGenerator(
+ config,
+ nullableInput,
+ inputTypeInfo,
+ constants,
"GroupingWindowAggregateHelper",
physicalInputTypes,
aggregateMetadata.getAggregateFunctions,
@@ -776,7 +923,7 @@ object AggregateUtil {
)
new DataSetSessionWindowAggregatePreProcessor(
- genFunction,
+ generator.generateAggregations,
keysAndAggregatesArity,
asLong(gap),
combineReturnType)
@@ -804,7 +951,10 @@ object AggregateUtil {
*
*/
private[flink] def createDataSetWindowAggregationCombineFunction(
- generator: AggregationCodeGenerator,
+ config: TableConfig,
+ nullableInput: Boolean,
+ inputTypeInfo: TypeInformation[_ <: Any],
+ constants: Option[Seq[RexLiteral]],
window: LogicalWindow,
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
physicalInputRowType: RelDataType,
@@ -836,7 +986,11 @@ object AggregateUtil {
physicalInputRowType,
Option(Array(BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO)))
- val genFunction = generator.generateAggregations(
+ val generator = new AggregationCodeGenerator(
+ config,
+ nullableInput,
+ inputTypeInfo,
+ constants,
"GroupingWindowAggregateHelper",
physicalInputTypes,
aggregateMetadata.getAggregateFunctions,
@@ -855,7 +1009,7 @@ object AggregateUtil {
)
new DataSetSessionWindowAggregatePreProcessor(
- genFunction,
+ generator.generateAggregations,
keysAndAggregatesArity,
asLong(gap),
combineReturnType)
@@ -873,7 +1027,10 @@ object AggregateUtil {
* respective output type are generated as well.
*/
private[flink] def createDataSetAggregateFunctions(
- generator: AggregationCodeGenerator,
+ config: TableConfig,
+ nullableInput: Boolean,
+ inputTypeInfo: TypeInformation[_ <: Any],
+ constants: Option[Seq[RexLiteral]],
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
inputType: RelDataType,
inputFieldTypeInfo: Seq[TypeInformation[_]],
@@ -913,7 +1070,11 @@ object AggregateUtil {
.map(FlinkTypeFactory.toTypeInfo) ++ aggregateMetadata.getAggregatesAccumulatorTypes
val preAggRowType = new RowTypeInfo(preAggFieldTypes: _*)
- val genPreAggFunction = generator.generateAggregations(
+ val generatorPre = new AggregationCodeGenerator(
+ config,
+ nullableInput,
+ inputTypeInfo,
+ constants,
"DataSetAggregatePrepareMapHelper",
inputFieldTypeInfo,
aggregateMetadata.getAggregateFunctions,
@@ -930,6 +1091,7 @@ object AggregateUtil {
needReset = true,
None
)
+ val genPreAggFunction = generatorPre.generateAggregations
// compute mapping of forwarded grouping keys
val gkeyMapping: Array[Int] = if (gkeyOutMapping.nonEmpty) {
@@ -941,7 +1103,11 @@ object AggregateUtil {
new Array[Int](0)
}
- val genFinalAggFunction = generator.generateAggregations(
+ val generatorFinal = new AggregationCodeGenerator(
+ config,
+ nullableInput,
+ inputTypeInfo,
+ constants,
"DataSetAggregateFinalHelper",
inputFieldTypeInfo,
aggregateMetadata.getAggregateFunctions,
@@ -958,6 +1124,7 @@ object AggregateUtil {
needReset = true,
None
)
+ val genFinalAggFunction = generatorFinal.generateAggregations
(
Some(new DataSetPreAggFunction(genPreAggFunction)),
@@ -966,7 +1133,11 @@ object AggregateUtil {
)
}
else {
- val genFunction = generator.generateAggregations(
+ val generator = new AggregationCodeGenerator(
+ config,
+ nullableInput,
+ inputTypeInfo,
+ constants,
"DataSetAggregateHelper",
inputFieldTypeInfo,
aggregateMetadata.getAggregateFunctions,
@@ -987,7 +1158,7 @@ object AggregateUtil {
(
None,
None,
- Left(new DataSetAggFunction(genFunction))
+ Left(new DataSetAggFunction(generator.generateAggregations))
)
}
@@ -1046,7 +1217,10 @@ object AggregateUtil {
}
private[flink] def createDataStreamAggregateFunction(
- generator: AggregationCodeGenerator,
+ config: TableConfig,
+ nullableInput: Boolean,
+ inputTypeInfo: TypeInformation[_ <: Any],
+ constants: Option[Seq[RexLiteral]],
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
inputType: RelDataType,
inputFieldTypeInfo: Seq[TypeInformation[_]],
@@ -1068,7 +1242,11 @@ object AggregateUtil {
val aggMapping = aggregateMetadata.getAdjustedMapping(0)
val outputArity = aggregateMetadata.getAggregateCallsCount
- val genFunction = generator.generateAggregations(
+ val generator = new AggregationCodeGenerator(
+ config,
+ nullableInput,
+ inputTypeInfo,
+ constants,
"GroupingWindowAggregateHelper",
inputFieldTypeInfo,
aggregateMetadata.getAggregateFunctions,
@@ -1087,7 +1265,7 @@ object AggregateUtil {
)
val accumulatorRowType = new RowTypeInfo(aggregateMetadata.getAggregatesAccumulatorTypes: _*)
- val aggFunction = new AggregateAggFunction(genFunction)
+ val aggFunction = new AggregateAggFunction(generator.generateAggregations)
(aggFunction, accumulatorRowType)
}
@@ -1115,7 +1293,7 @@ object AggregateUtil {
* Return true if all aggregates can be partially merged. False otherwise.
*/
private[flink] def doAllSupportPartialMerge(
- aggregateList: Array[TableAggregateFunction[_ <: Any, _ <: Any]]): Boolean = {
+ aggregateList: Array[UserDefinedAggregateFunction[_ <: Any, _ <: Any]]): Boolean = {
aggregateList.forall(ifMethodExistInFunction("merge", _))
}
@@ -1193,7 +1371,7 @@ object AggregateUtil {
private val aggregates: Seq[(AggregateCallMetadata, Array[Int])],
private val distinctAccTypesWithSpecs: Seq[(TypeInformation[_], Seq[DataViewSpec[_]])]) {
- def getAggregateFunctions: Array[TableAggregateFunction[_, _]] = {
+ def getAggregateFunctions: Array[UserDefinedAggregateFunction[_, _]] = {
aggregates.map(_._1.aggregateFunction).toArray
}
@@ -1239,7 +1417,7 @@ object AggregateUtil {
* function.
*/
private[flink] case class AggregateCallMetadata(
- aggregateFunction: TableAggregateFunction[_, _],
+ aggregateFunction: UserDefinedAggregateFunction[_, _],
accumulatorType: TypeInformation[_],
accumulatorSpecs: Seq[DataViewSpec[_]],
distinctAccIndex: Int
@@ -1257,7 +1435,7 @@ object AggregateUtil {
* @param aggregateCount number of aggregates
* @param inputFieldsCount number of input fields
* @param aggregateInputTypes input types of given aggregate
- * @param needRetraction if the [[TableAggregateFunction]] should produce retractions
+ * @param needRetraction if the [[AggregateFunction]] should produce retractions
* @param tableConfig tableConfig, required for decimal precision
* @param isStateBackedDataViews if data should be backed by state backend
* @param uniqueIdWithinAggregate index within an AggregateCallMetadata, used to create unique
@@ -1283,7 +1461,7 @@ object AggregateUtil {
// store the aggregate fields of each aggregate function, by the same order of aggregates.
// create aggregate function instances by function type and aggregate field data type.
- val aggregate: TableAggregateFunction[_, _] = createFlinkAggFunction(
+ val aggregate: UserDefinedAggregateFunction[_, _] = createFlinkAggFunction(
aggregateFunction,
needRetraction,
aggregateInputTypes,
@@ -1331,12 +1509,12 @@ object AggregateUtil {
/**
* Prepares metadata [[AggregateMetadata]] required to generate code for
- * [[GeneratedAggregations]] for all [[AggregateCall]].
+ * [[AggregationsFunction]] for all [[AggregateCall]].
*
* @param aggregateCalls calcite's aggregate function
* @param aggregateInputType input type of given aggregates
* @param inputFieldsCount number of input fields
- * @param needRetraction if the [[TableAggregateFunction]] should produce retractions
+ * @param needRetraction if the [[AggregateFunction]] should produce retractions
* @param tableConfig tableConfig, required for decimal precision
* @param isStateBackedDataViews if data should be backed by state backend
* @return the result contains required metadata:
@@ -1403,7 +1581,7 @@ object AggregateUtil {
}
/**
- * Converts calcite's [[SqlAggFunction]] to a Flink's UDF [[TableAggregateFunction]].
+ * Converts calcite's [[SqlAggFunction]] to a Flink's UDF [[AggregateFunction]].
* create aggregate function instances by function type and aggregate field data type.
*/
private def createFlinkAggFunction(
@@ -1411,7 +1589,7 @@ object AggregateUtil {
needRetraction: Boolean,
inputDataType: Seq[RelDataType],
tableConfig: TableConfig)
- : TableAggregateFunction[_ <: Any, _ <: Any] = {
+ : UserDefinedAggregateFunction[_ <: Any, _ <: Any] = {
lazy val outputType = inputDataType.get(0)
lazy val outputTypeName = if (inputDataType.isEmpty) {
@@ -1665,7 +1843,7 @@ object AggregateUtil {
private def createRowTypeForKeysAndAggregates(
groupings: Array[Int],
- aggregates: Array[TableAggregateFunction[_, _]],
+ aggregates: Array[UserDefinedAggregateFunction[_, _]],
aggTypes: Array[TypeInformation[_]],
inputType: RelDataType,
windowKeyTypes: Option[Array[TypeInformation[_]]] = None): RowTypeInfo = {
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
index 7549db5..b771c5e 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
@@ -20,12 +20,12 @@ package org.apache.flink.table.runtime.aggregate
import org.apache.flink.api.common.functions.{Function, RuntimeContext}
import org.apache.flink.types.Row
+import org.apache.flink.util.Collector
/**
- * Base class for code-generated aggregations.
+ * Base class for code-generated aggregations and table aggregations.
*/
-abstract class GeneratedAggregations extends Function {
-
+abstract class AggregationsFunction extends Function {
/**
* Setup method for [[org.apache.flink.table.functions.AggregateFunction]].
* It can be used for initialization work. By default, this method does nothing.
@@ -35,25 +35,6 @@ abstract class GeneratedAggregations extends Function {
def open(ctx: RuntimeContext)
/**
- * Sets the results of the aggregations (partial or final) to the output row.
- * Final results are computed with the aggregation function.
- * Partial results are the accumulators themselves.
- *
- * @param accumulators the accumulators (saved in a row) which contains the current
- * aggregated results
- * @param output output results collected in a row
- */
- def setAggregationResults(accumulators: Row, output: Row)
-
- /**
- * Copies forwarded fields, such as grouping keys, from input row to output row.
- *
- * @param input input values bundled in a row
- * @param output output results collected in a row
- */
- def setForwardedFields(input: Row, output: Row)
-
- /**
* Accumulates the input values to the accumulators.
*
* @param accumulators the accumulators (saved in a row) which contains the current
@@ -79,13 +60,6 @@ abstract class GeneratedAggregations extends Function {
def createAccumulators(): Row
/**
- * Creates an output row object with the correct arity.
- *
- * @return an output row object with the correct arity.
- */
- def createOutputRow(): Row
-
- /**
* Merges two rows of accumulators into one row.
*
* @param a First row of accumulators
@@ -95,12 +69,19 @@ abstract class GeneratedAggregations extends Function {
def mergeAccumulatorsPair(a: Row, b: Row): Row
/**
- * Resets all the accumulators.
+ * Copies forwarded fields, such as grouping keys, from input row to output row.
*
- * @param accumulators the accumulators (saved in a row) which contains the current
- * aggregated results
+ * @param input input values bundled in a row
+ * @param output output results collected in a row
*/
- def resetAccumulator(accumulators: Row)
+ def setForwardedFields(input: Row, output: Row)
+
+ /**
+ * Creates an output row object with the correct arity.
+ *
+ * @return an output row object with the correct arity.
+ */
+ def createOutputRow(): Row
/**
* Cleanup for the accumulators.
@@ -114,6 +95,42 @@ abstract class GeneratedAggregations extends Function {
def close()
}
+/**
+ * Base class for code-generated aggregations.
+ */
+abstract class GeneratedAggregations extends AggregationsFunction {
+
+ /**
+ * Sets the results of the aggregations (partial or final) to the output row.
+ * Final results are computed with the aggregation function.
+ * Partial results are the accumulators themselves.
+ *
+ * @param accumulators the accumulators (saved in a row) which contains the current
+ * aggregated results
+ * @param output output results collected in a row
+ */
+ def setAggregationResults(accumulators: Row, output: Row)
+
+ /**
+ * Resets all the accumulators.
+ *
+ * @param accumulators the accumulators (saved in a row) which contains the current
+ * aggregated results
+ */
+ def resetAccumulator(accumulators: Row)
+}
+
+/**
+ * Base class for code-generated table aggregations.
+ */
+abstract class GeneratedTableAggregations extends AggregationsFunction {
+
+ /**
+ * Emits the results derived from the given accumulators to the collector.
+ */
+ def emit(accumulators: Row, collector: Collector[_])
+}
+
class SingleElementIterable[T] extends java.lang.Iterable[T] {
class SingleElementIterator extends java.util.Iterator[T] {
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupTableAggProcessFunction.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupTableAggProcessFunction.scala
new file mode 100644
index 0000000..673a51c
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupTableAggProcessFunction.scala
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.aggregate
+
+import java.lang.{Long => JLong}
+
+import org.apache.flink.api.common.state.{ValueState, ValueStateDescriptor}
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction
+import org.apache.flink.table.api.{StreamQueryConfig, Types}
+import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
+import org.apache.flink.table.runtime.CRowWrappingCollector
+import org.apache.flink.table.runtime.types.CRow
+import org.apache.flink.table.util.Logging
+import org.apache.flink.types.Row
+import org.apache.flink.util.Collector
+
+/**
+ * Process function used for the group-by (non-windowed) table aggregate.
+ *
+ * @param genTableAggregations Generated aggregate helper function
+ * @param aggregationStateType The row type info of aggregation
+ */
+class GroupTableAggProcessFunction[K](
+ private val genTableAggregations: GeneratedAggregationsFunction,
+ private val aggregationStateType: RowTypeInfo,
+ private val generateRetraction: Boolean,
+ private val queryConfig: StreamQueryConfig)
+ extends ProcessFunctionWithCleanupState[K, CRow, CRow](queryConfig)
+ with Compiler[GeneratedTableAggregations]
+ with Logging {
+
+ private var function: GeneratedTableAggregations = _
+
+ private var firstRow: Boolean = _
+ // stores the accumulators
+ private var state: ValueState[Row] = _
+ // counts the number of added and retracted input records
+ private var cntState: ValueState[JLong] = _
+
+ private var appendKeyCollector: AppendKeyCRowCollector = _
+
+ override def open(config: Configuration) {
+ LOG.debug(s"Compiling TableAggregateHelper: ${genTableAggregations.name} \n\n " +
+ s"Code:\n${genTableAggregations.code}")
+ val clazz = compile(
+ getRuntimeContext.getUserCodeClassLoader,
+ genTableAggregations.name,
+ genTableAggregations.code)
+ LOG.debug("Instantiating TableAggregateHelper.")
+ function = clazz.newInstance()
+ function.open(getRuntimeContext)
+
+ val stateDescriptor: ValueStateDescriptor[Row] =
+ new ValueStateDescriptor[Row]("GroupTableAggregateState", aggregationStateType)
+ state = getRuntimeContext.getState(stateDescriptor)
+ val inputCntDescriptor: ValueStateDescriptor[JLong] =
+ new ValueStateDescriptor[JLong]("GroupTableAggregateInputCounter", Types.LONG)
+ cntState = getRuntimeContext.getState(inputCntDescriptor)
+
+ appendKeyCollector = new AppendKeyCRowCollector
+ appendKeyCollector.setResultRow(function.createOutputRow())
+
+ initCleanupTimeState("GroupTableAggregateCleanupTime")
+ }
+
+ override def processElement(
+ inputC: CRow,
+ ctx: KeyedProcessFunction[K, CRow, CRow]#Context,
+ out: Collector[CRow]): Unit = {
+
+ val currentTime = ctx.timerService().currentProcessingTime()
+ // register state-cleanup timer
+ processCleanupTimer(ctx, currentTime)
+
+ val input = inputC.row
+
+ // get accumulators and input counter
+ var accumulators = state.value()
+ var inputCnt = cntState.value()
+
+ if (null == accumulators) {
+ // Don't create a new accumulator for a retraction message. This
+ // might happen if the retraction message is the first message for the
+ // key or after a state clean up.
+ if (!inputC.change) {
+ return
+ }
+ // first accumulate message
+ firstRow = true
+ accumulators = function.createAccumulators()
+ } else {
+ firstRow = false
+ }
+
+ // Set group keys value to the final output
+ function.setForwardedFields(input, appendKeyCollector.getResultRow)
+
+ appendKeyCollector.out = out
+ if (!firstRow) {
+ if (generateRetraction) {
+ appendKeyCollector.setChange(false)
+ function.emit(accumulators, appendKeyCollector)
+ appendKeyCollector.setChange(true)
+ }
+ }
+
+ if (null == inputCnt) {
+ inputCnt = 0L
+ }
+
+ // update the accumulators: accumulate on an add message, retract on a retraction
+ if (inputC.change) {
+ inputCnt += 1
+ // accumulate input
+ function.accumulate(accumulators, input)
+ } else {
+ inputCnt -= 1
+ // retract input
+ function.retract(accumulators, input)
+ }
+
+ if (inputCnt != 0) {
+ // we aggregated at least one record for this key
+
+ // update the state
+ state.update(accumulators)
+ cntState.update(inputCnt)
+
+ // emit the new result
+ function.emit(accumulators, appendKeyCollector)
+
+ } else {
+ // and clear all state
+ state.clear()
+ cntState.clear()
+ }
+ }
+
+ override def onTimer(
+ timestamp: Long,
+ ctx: KeyedProcessFunction[K, CRow, CRow]#OnTimerContext,
+ out: Collector[CRow]): Unit = {
+
+ if (stateCleaningEnabled) {
+ cleanupState(state, cntState)
+ function.cleanup()
+ }
+ }
+
+ override def close(): Unit = {
+ function.close()
+ }
+}
+
+/**
+ * This collector is used to assemble the group key and the table aggregate output.
+ */
+class AppendKeyCRowCollector() extends CRowWrappingCollector {
+
+ var resultRow: Row = _
+
+ def setResultRow(row: Row): Unit = {
+ resultRow = row
+ }
+
+ def getResultRow: Row = {
+ resultRow
+ }
+
+ override def collect(record: Row): Unit = {
+ var i = 0
+ val offset = resultRow.getArity - record.getArity
+ while (i < record.getArity) {
+ resultRow.setField(i + offset, record.getField(i))
+ i += 1
+ }
+ super.collect(resultRow)
+ }
+}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
index 9f9c876..2750680 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
@@ -30,7 +30,7 @@ import org.apache.flink.table.expressions._
import org.apache.flink.table.expressions.catalog.FunctionDefinitionCatalog
import org.apache.flink.table.functions.sql.ScalarSqlFunctions
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{createAggregateSqlFunction, createScalarSqlFunction, createTableSqlFunction}
-import org.apache.flink.table.functions.{AggregateFunction, ScalarFunction, TableFunction}
+import org.apache.flink.table.functions.{ScalarFunction, TableFunction, UserDefinedAggregateFunction}
import _root_.scala.collection.JavaConversions._
import _root_.scala.collection.mutable
@@ -75,7 +75,7 @@ class FunctionCatalog extends FunctionDefinitionCatalog {
def registerAggregateFunction(
name: String,
- function: AggregateFunction[_, _],
+ function: UserDefinedAggregateFunction[_, _],
resultType: TypeInformation[_],
accType: TypeInformation[_],
typeFactory: FlinkTypeFactory)
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/TableAggregateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/TableAggregateTest.scala
new file mode 100644
index 0000000..938d2c5
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/TableAggregateTest.scala
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.api.stream.table
+
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.Types
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.expressions.utils.Func0
+import org.apache.flink.table.utils.{EmptyTableAggFunc, TableTestBase}
+import org.apache.flink.table.utils.TableTestUtil._
+import org.apache.flink.types.Row
+import org.junit.Test
+
+class TableAggregateTest extends TableTestBase {
+
+ val util = streamTestUtil()
+ val table = util.addTable[(Long, Int, Long, Long)]('a, 'b, 'c, 'd.rowtime, 'e.proctime)
+ val emptyFunc = new EmptyTableAggFunc
+
+ @Test
+ def testTableAggregateWithGroupBy(): Unit = {
+
+ val resultTable = table
+ .groupBy('b % 5 as 'bb)
+ .flatAggregate(emptyFunc('a, 'b) as ('x, 'y))
+ .select('bb, 'x + 1, 'y)
+
+ val expected =
+ unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamGroupTableAggregate",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "a", "b", "MOD(b, 5) AS bb")
+ ),
+ term("groupBy", "bb"),
+ term("select", "bb", "EmptyTableAggFunc(a, b) AS (f0, f1)")
+ ),
+ term("select", "bb", "+(f0, 1) AS _c1", "f1 AS y")
+ )
+ util.verifyTable(resultTable, expected)
+ }
+
+ @Test
+ def testTableAggregateWithoutGroupBy(): Unit = {
+
+ val resultTable = table
+ .flatAggregate(emptyFunc('a, 'b))
+ .select(Func0('f0) as 'a, 'f1 as 'b)
+
+ val expected =
+ unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamGroupTableAggregate",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "a", "b")
+ ),
+ term("select", "EmptyTableAggFunc(a, b) AS (f0, f1)")
+ ),
+ term("select", "Func0$(f0) AS a", "f1 AS b")
+ )
+ util.verifyTable(resultTable, expected)
+ }
+
+ @Test
+ def testTableAggregateWithTimeIndicator(): Unit = {
+
+ val resultTable = table
+ .flatAggregate(emptyFunc('d, 'e))
+ .select('f0 as 'a, 'f1 as 'b)
+
+ val expected =
+ unaryNode(
+ "DataStreamGroupTableAggregate",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "CAST(d) AS d", "PROCTIME(e) AS e")
+ ),
+ term("select", "EmptyTableAggFunc(d, e) AS (f0, f1)")
+ )
+ util.verifyTable(resultTable, expected)
+ }
+
+ @Test
+ def testTableAggregateWithSelectStar(): Unit = {
+
+ val resultTable = table
+ .flatAggregate(emptyFunc('b))
+ .select("*")
+
+ val expected =
+ unaryNode(
+ "DataStreamGroupTableAggregate",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "b")
+ ),
+ term("select", "EmptyTableAggFunc(b) AS (f0, f1)")
+ )
+ util.verifyTable(resultTable, expected)
+ }
+
+ @Test
+ def testTableAggregateWithAlias(): Unit = {
+
+ val resultTable = table
+ .flatAggregate(emptyFunc('b) as ('a, 'b))
+ .select('a, 'b)
+
+ val expected =
+ unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamGroupTableAggregate",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "b")
+ ),
+ term("select", "EmptyTableAggFunc(b) AS (f0, f1)")
+ ),
+ term("select", "f0 AS a", "f1 AS b")
+ )
+ util.verifyTable(resultTable, expected)
+ }
+
+ @Test
+ def testJavaRegisterFunction(): Unit = {
+ val util = streamTestUtil()
+ val typeInfo = new RowTypeInfo(Types.INT, Types.LONG, Types.STRING)
+ val table = util.addJavaTable[Row](typeInfo, "sourceTable", "a, b, c")
+
+ val func = new EmptyTableAggFunc
+ util.javaTableEnv.registerFunction("func", func)
+
+ val resultTable = table
+ .groupBy("c")
+ .flatAggregate("func(a)")
+ .select("*")
+
+ val expected =
+ unaryNode(
+ "DataStreamGroupTableAggregate",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "a", "c")
+ ),
+ term("groupBy", "c"),
+ term("select", "c", "EmptyTableAggFunc(a) AS (f0, f1)")
+ )
+ util.verifyJavaTable(resultTable, expected)
+ }
+}
+
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/TableAggregateStringExpressionTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/TableAggregateStringExpressionTest.scala
new file mode 100644
index 0000000..6d9e794
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/TableAggregateStringExpressionTest.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.api.stream.table.stringexpr
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.expressions.utils.Func0
+import org.apache.flink.table.utils.{TableTestBase, Top3WithMapView}
+import org.junit.Test
+
+class TableAggregateStringExpressionTest extends TableTestBase {
+
+ @Test
+ def testNonGroupedTableAggregate(): Unit = {
+ val util = streamTestUtil()
+ val t = util.addTable[(Int, Long, String)]('a, 'b, 'c)
+
+ val top3 = new Top3WithMapView
+ util.tableEnv.registerFunction("top3", top3)
+ util.tableEnv.registerFunction("Func0", Func0)
+
+ // Expression / Scala API
+ val resScala = t
+ .flatAggregate(top3('a))
+ .select(Func0('f0) as 'a, 'f1 as 'b)
+
+ // String / Java API
+ val resJava = t
+ .flatAggregate("top3(a)")
+ .select("Func0(f0) as a, f1 as b")
+
+ verifyTableEquals(resJava, resScala)
+ }
+
+ @Test
+ def testGroupedTableAggregate(): Unit = {
+ val util = streamTestUtil()
+ val t = util.addTable[(Int, Long, String)]('a, 'b, 'c)
+
+ val top3 = new Top3WithMapView
+ util.tableEnv.registerFunction("top3", top3)
+ util.tableEnv.registerFunction("Func0", Func0)
+
+ // Expression / Scala API
+ val resScala = t
+ .groupBy('b % 5)
+ .flatAggregate(top3('a))
+ .select(Func0('f0) as 'a, 'f1 as 'b)
+
+ // String / Java API
+ val resJava = t
+ .groupBy("b % 5")
+ .flatAggregate("top3(a)")
+ .select("Func0(f0) as a, f1 as b")
+
+ verifyTableEquals(resJava, resScala)
+ }
+
+ @Test
+ def testAliasNonGroupedTableAggregate(): Unit = {
+ val util = streamTestUtil()
+ val t = util.addTable[(Int, Long, String)]('a, 'b, 'c)
+
+ val top3 = new Top3WithMapView
+ util.tableEnv.registerFunction("top3", top3)
+ util.tableEnv.registerFunction("Func0", Func0)
+
+ // Expression / Scala API
+ val resScala = t
+ .flatAggregate(top3('a) as ('d, 'e))
+ .select('*)
+
+ // String / Java API
+ val resJava = t
+ .flatAggregate("top3(a) as (d, e)")
+ .select("*")
+
+ verifyTableEquals(resJava, resScala)
+ }
+
+ @Test
+ def testAliasGroupedTableAggregate(): Unit = {
+ val util = streamTestUtil()
+ val t = util.addTable[(Int, Long, String)]('a, 'b, 'c)
+
+ val top3 = new Top3WithMapView
+ util.tableEnv.registerFunction("top3", top3)
+ util.tableEnv.registerFunction("Func0", Func0)
+
+ // Expression / Scala API
+ val resScala = t
+ .groupBy('b)
+ .flatAggregate(top3('a) as ('d, 'e))
+ .select('*)
+
+ // String / Java API
+ val resJava = t
+ .groupBy("b")
+ .flatAggregate("top3(a) as (d, e)")
+ .select("*")
+
+ verifyTableEquals(resJava, resScala)
+ }
+}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/TableAggregateValidationTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/TableAggregateValidationTest.scala
new file mode 100644
index 0000000..4ecca0e
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/TableAggregateValidationTest.scala
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.api.stream.table.validation
+
+import java.sql.Timestamp
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.ValidationException
+import org.apache.flink.table.utils.{EmptyTableAggFunc, TableTestBase}
+import org.junit.Test
+
+class TableAggregateValidationTest extends TableTestBase {
+
+ @Test
+ def testInvalidParameterNumber(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage("Given parameters do not match any signature. \n" +
+ "Actual: (java.lang.Long, java.lang.Integer, java.lang.String) \n" +
+ "Expected: (int), (long, int), (long, java.sql.Timestamp)")
+
+ val util = streamTestUtil()
+ val table = util.addTable[(Long, Int, String)]('a, 'b, 'c)
+
+ val func = new EmptyTableAggFunc
+ table
+ .groupBy('c)
+ // must fail. func does not take 3 parameters
+ .flatAggregate(func('a, 'b, 'c))
+ .select('_1, '_2, '_3)
+ }
+
+ @Test
+ def testInvalidParameterType(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage("Given parameters do not match any signature. \n" +
+ "Actual: (java.lang.Long, java.lang.String) \n" +
+ "Expected: (int), (long, int), (long, java.sql.Timestamp)")
+
+ val util = streamTestUtil()
+ val table = util.addTable[(Long, Int, String)]('a, 'b, 'c)
+
+ val func = new EmptyTableAggFunc
+ table
+ .groupBy('c)
+ // must fail. func takes 2 parameters of type (Long, Int) or (Long, Timestamp)
+ .flatAggregate(func('a, 'c))
+ .select('_1, '_2, '_3)
+ }
+
+ @Test
+ def testInvalidWithWindowProperties(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage("Window properties can only be used on windowed tables.")
+
+ val util = streamTestUtil()
+ val table = util.addTable[(Long, Int, Timestamp)]('a, 'b, 'c)
+
+ val func = new EmptyTableAggFunc
+ table
+ .groupBy('b)
+ .flatAggregate(func('a, 'b) as ('x, 'y))
+ .select('x.start, 'y)
+ }
+
+ @Test
+ def testInvalidParameterWithAgg(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage(
+ "It's not allowed to use an aggregate function as input of another aggregate function")
+
+ val util = streamTestUtil()
+ val table = util.addTable[(Long, Int, Timestamp)]('a, 'b, 'c)
+
+ val func = new EmptyTableAggFunc
+ table
+ .groupBy('b)
+ // must fail. func takes an aggregate function as input
+ .flatAggregate(func('a.sum, 'c))
+ .select('_1, '_2, '_3)
+ }
+
+ @Test
+ def testInvalidAliasWithWrongNumber(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage("List of column aliases must have same degree as " +
+ "table; the returned table of function 'org.apache.flink.table.utils.EmptyTableAggFunc' " +
+ "has 2 columns, whereas alias list has 3 columns")
+
+ val util = streamTestUtil()
+ val table = util.addTable[(Long, Int, Timestamp)]('a, 'b, 'c)
+
+ val func = new EmptyTableAggFunc
+ table
+ .groupBy('b)
+ // must fail. alias with wrong number of fields
+ .flatAggregate(func('a, 'b) as ('a, 'b, 'c))
+ .select('*)
+ }
+
+ @Test
+ def testAliasWithNameConflict(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage("Ambiguous column name: b")
+
+ val util = streamTestUtil()
+ val table = util.addTable[(Long, Int, Timestamp)]('a, 'b, 'c)
+
+ val func = new EmptyTableAggFunc
+ table
+ .groupBy('b)
+ // must fail. alias 'b conflicts with the name of the group key 'b
+ .flatAggregate(func('a, 'b) as ('a, 'b))
+ .select('*)
+ }
+
+ @Test
+ def testInvalidDistinct(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage("A flatAggregate only accepts an expression which " +
+ "defines a table aggregate function that might be followed by some alias.")
+
+ val util = streamTestUtil()
+ val table = util.addTable[(Long, Int, Timestamp)]('a, 'b, 'c)
+
+ val func = new EmptyTableAggFunc
+ table
+ .groupBy('b)
+ .flatAggregate(func('a, 'b).distinct as ('a, 'b, 'c))
+ .select('a, 'b)
+ }
+}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/harness/TableAggregateHarnessTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/harness/TableAggregateHarnessTest.scala
new file mode 100644
index 0000000..7d92f2b
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/harness/TableAggregateHarnessTest.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.harness
+
+import java.lang.{Integer => JInt}
+import java.util.concurrent.ConcurrentLinkedQueue
+
+import org.apache.flink.api.common.time.Time
+import org.apache.flink.api.scala._
+import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.runtime.harness.HarnessTestBase._
+import org.apache.flink.table.runtime.types.CRow
+import org.apache.flink.table.utils.{Top3WithMapView}
+import org.apache.flink.types.Row
+import org.junit.Test
+
+import scala.collection.mutable
+
+class TableAggregateHarnessTest extends HarnessTestBase {
+
+ protected var queryConfig =
+ new TestStreamQueryConfig(Time.seconds(2), Time.seconds(3))
+ val data = new mutable.MutableList[(Int, Int)]
+
+ @Test
+ def testTableAggregate(): Unit = {
+
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = StreamTableEnvironment.create(env)
+
+ val top3 = new Top3WithMapView
+ tEnv.registerFunction("top3", top3)
+ val source = env.fromCollection(data).toTable(tEnv, 'a, 'b)
+ val resultTable = source
+ .groupBy('a)
+ .flatAggregate(top3('b) as ('b1, 'b2))
+ .select('a, 'b1, 'b2)
+
+ val testHarness = createHarnessTester[Int, CRow, CRow](
+ resultTable.toRetractStream[Row](queryConfig), "groupBy: (a)")
+
+ testHarness.open()
+
+ val expectedOutput = new ConcurrentLinkedQueue[Object]()
+
+ // register cleanup timer with 3001
+ testHarness.setProcessingTime(1)
+
+ // input with two columns: key and value
+ testHarness.processElement(new StreamRecord(CRow(1: JInt, 1: JInt), 1))
+ // output with three columns: key, value, value. The value is in the top3 of the key
+ expectedOutput.add(new StreamRecord(CRow(1: JInt, 1: JInt, 1: JInt), 1))
+
+ testHarness.processElement(new StreamRecord(CRow(1: JInt, 2: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(false, 1: JInt, 1: JInt, 1: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(1: JInt, 1: JInt, 1: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(1: JInt, 2: JInt, 2: JInt), 1))
+
+ testHarness.processElement(new StreamRecord(CRow(1: JInt, 3: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(false, 1: JInt, 1: JInt, 1: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(false, 1: JInt, 2: JInt, 2: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(1: JInt, 1: JInt, 1: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(1: JInt, 2: JInt, 2: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(1: JInt, 3: JInt, 3: JInt), 1))
+
+ testHarness.processElement(new StreamRecord(CRow(1: JInt, 2: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(false, 1: JInt, 1: JInt, 1: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(false, 1: JInt, 2: JInt, 2: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(false, 1: JInt, 3: JInt, 3: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(1: JInt, 2: JInt, 2: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(1: JInt, 2: JInt, 2: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(1: JInt, 3: JInt, 3: JInt), 1))
+
+ // ingest data with key value of 2
+ testHarness.processElement(new StreamRecord(CRow(2: JInt, 2: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(2: JInt, 2: JInt, 2: JInt), 1))
+
+ // trigger cleanup timer
+ testHarness.setProcessingTime(3002)
+ testHarness.processElement(new StreamRecord(CRow(1: JInt, 2: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(1: JInt, 2: JInt, 2: JInt), 1))
+
+ val result = testHarness.getOutput
+
+ verify(expectedOutput, result)
+ testHarness.close()
+ }
+
+ @Test
+ def testTableAggregateWithRetractInput(): Unit = {
+
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = StreamTableEnvironment.create(env)
+
+ val top3 = new Top3WithMapView
+ tEnv.registerFunction("top3", top3)
+ val source = env.fromCollection(data).toTable(tEnv, 'a, 'b)
+ val resultTable = source
+ .groupBy('a)
+ .select('b.sum as 'b)
+ .flatAggregate(top3('b) as ('b1, 'b2))
+ .select('b1, 'b2)
+
+ val testHarness = createHarnessTester[Int, CRow, CRow](
+ resultTable.toRetractStream[Row](queryConfig), "select: (Top3WithMapView")
+
+ testHarness.open()
+
+ val expectedOutput = new ConcurrentLinkedQueue[Object]()
+
+ // register cleanup timer with 3001
+ testHarness.setProcessingTime(1)
+
+ // input with one column: value
+ testHarness.processElement(new StreamRecord(CRow(1: JInt), 1))
+ // output with two columns: value, value. The value is in the top3
+ expectedOutput.add(new StreamRecord(CRow(1: JInt, 1: JInt), 1))
+
+ testHarness.processElement(new StreamRecord(CRow(false, 1: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(false, 1: JInt, 1: JInt), 1))
+
+ testHarness.processElement(new StreamRecord(CRow(3: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(3: JInt, 3: JInt), 1))
+
+ testHarness.processElement(new StreamRecord(CRow(4: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(false, 3: JInt, 3: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(3: JInt, 3: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(4: JInt, 4: JInt), 1))
+
+ testHarness.processElement(new StreamRecord(CRow(false, 3: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(false, 3: JInt, 3: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(false, 4: JInt, 4: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(4: JInt, 4: JInt), 1))
+
+ testHarness.processElement(new StreamRecord(CRow(5: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(false, 4: JInt, 4: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(4: JInt, 4: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(5: JInt, 5: JInt), 1))
+
+ val result = testHarness.getOutput
+
+ verify(expectedOutput, result)
+ testHarness.close()
+ }
+}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/TableAggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/TableAggregateITCase.scala
new file mode 100644
index 0000000..d21c0a9
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/TableAggregateITCase.scala
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.stream.table
+
+import org.apache.flink.api.common.time.Time
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
+import org.apache.flink.table.api.scala._
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.{StreamQueryConfig, Types, ValidationException}
+import org.apache.flink.table.runtime.utils.{StreamITCase, StreamTestData, StreamingWithStateTestBase}
+import org.apache.flink.table.utils.{Top3, Top3WithMapView}
+import org.apache.flink.types.Row
+import org.junit.Assert.assertEquals
+import org.junit.Test
+
+/**
+  * End-to-end tests for table aggregations (`flatAggregate`) that are applied without a
+  * window, both with and without a preceding `groupBy`.
+  */
+class TableAggregateITCase extends StreamingWithStateTestBase {
+  private val queryConfig = new StreamQueryConfig()
+  queryConfig.withIdleStateRetentionTime(Time.hours(1), Time.hours(2))
+
+  /**
+    * Creates the stream and table environments used by a test, configures the state
+    * backend provided by the base class and calls `StreamITCase.clear`.
+    *
+    * @return the stream execution environment together with its table environment
+    */
+  private def createEnvironments(): (StreamExecutionEnvironment, StreamTableEnvironment) = {
+    val streamEnv = StreamExecutionEnvironment.getExecutionEnvironment
+    streamEnv.setStateBackend(getStateBackend)
+    val tableEnv = StreamTableEnvironment.create(streamEnv)
+    StreamITCase.clear
+    (streamEnv, tableEnv)
+  }
+
+  /** Top3 of column `a` computed separately for every value of the grouping column `b`. */
+  @Test
+  def testGroupByFlatAggregate(): Unit = {
+    val (env, tEnv) = createEnvironments()
+
+    val top3 = new Top3
+    val input = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c)
+    val topPerGroup = input
+      .groupBy('b)
+      .flatAggregate(top3('a))
+      .select('b, 'f0, 'f1)
+      .as('category, 'v1, 'v2)
+
+    topPerGroup
+      .toRetractStream[Row](queryConfig)
+      .addSink(new StreamITCase.RetractingSink)
+      .setParallelism(1)
+    env.execute()
+
+    val expected = List(
+      "1,1,1",
+      "2,2,2",
+      "2,3,3",
+      "3,4,4",
+      "3,5,5",
+      "3,6,6",
+      "4,10,10",
+      "4,9,9",
+      "4,8,8",
+      "5,15,15",
+      "5,14,14",
+      "5,13,13",
+      "6,21,21",
+      "6,20,20",
+      "6,19,19"
+    )
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+  }
+
+  /** Top3 of column `a` computed over the whole input, i.e. without any grouping key. */
+  @Test
+  def testNonkeyedFlatAggregate(): Unit = {
+    val (env, tEnv) = createEnvironments()
+
+    val top3 = new Top3
+    val input = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c)
+    val globalTop = input
+      .flatAggregate(top3('a))
+      .select('f0, 'f1)
+      .as('v1, 'v2)
+
+    globalTop
+      .toRetractStream[Row](queryConfig)
+      .addSink(new StreamITCase.RetractingSink)
+      .setParallelism(1)
+    env.execute()
+
+    val expected = List(
+      "19,19",
+      "20,20",
+      "21,21"
+    )
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+  }
+
+  /**
+    * Feeds the result of a grouped sum into `Top3WithMapView`, which defines a `retract`
+    * method and keeps its counts in a `MapView`.
+    */
+  @Test
+  def testWithMapViewAndInputWithRetraction(): Unit = {
+    val (env, tEnv) = createEnvironments()
+
+    val top3 = new Top3WithMapView
+    val input = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c)
+    val topOfSums = input
+      .groupBy('b)
+      .select('b, 'a.sum as 'a)
+      .flatAggregate(top3('a) as ('v1, 'v2))
+      .select('v1, 'v2)
+
+    topOfSums
+      .toRetractStream[Row](queryConfig)
+      .addSink(new StreamITCase.RetractingSink)
+      .setParallelism(1)
+    env.execute()
+
+    val expected = List(
+      "111,111",
+      "65,65",
+      "34,34"
+    )
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+  }
+
+  /**
+    * Uses `Top3`, which has no `retract` method, on the result of a grouped sum and
+    * expects the query to be rejected with a [[ValidationException]].
+    */
+  @Test
+  def testTableAggFunctionWithoutRetractionMethod(): Unit = {
+    // The expectation has to be registered before any statement that may throw.
+    expectedException.expect(classOf[ValidationException])
+    expectedException.expectMessage("Function class " +
+      "'org.apache.flink.table.utils.Top3' does not implement at least one method " +
+      "named 'retract' which is public, not abstract and (in case of table functions) not static.")
+
+    val (env, tEnv) = createEnvironments()
+
+    val sinkFieldNames = Array[String]("v1", "v2")
+    val sinkFieldTypes = Array[TypeInformation[_]](Types.INT, Types.INT)
+    tEnv.registerTableSink(
+      "retractSink",
+      new TestRetractSink().configure(sinkFieldNames, sinkFieldTypes))
+
+    val top3 = new Top3
+    val input = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c)
+    input
+      .groupBy('b)
+      .select('b, 'a.sum as 'a)
+      .flatAggregate(top3('a) as ('v1, 'v2))
+      .select('v1, 'v2)
+      .insertInto("retractSink")
+
+    env.execute()
+  }
+}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/utils/UserDefinedTableAggFunctions.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/utils/UserDefinedTableAggFunctions.scala
new file mode 100644
index 0000000..2fadd44
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/utils/UserDefinedTableAggFunctions.scala
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.utils
+
+import org.apache.flink.table.functions.TableAggregateFunction
+import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
+import java.lang.{Integer => JInt}
+import java.sql.Timestamp
+import java.util
+
+import org.apache.flink.table.api.Types
+import org.apache.flink.table.api.dataview.MapView
+import org.apache.flink.util.Collector
+
+class Top3Accum { // accumulator state used by Top3 (and as the accumulator type of EmptyTableAggFunc)
+ var data: util.HashMap[JInt, JInt] = _ // retained value -> number of times that value is retained
+ var size: JInt = _ // total number of retained values, duplicates included
+ var smallest: JInt = _ // smallest retained value; Integer.MAX_VALUE while nothing is retained
+}
+
+/**
+  * Test table aggregate function that retains the three largest values passed to
+  * `accumulate` and emits every retained value as a `(value, value)` tuple.
+  *
+  * The class defines no `retract` method.
+  */
+class Top3 extends TableAggregateFunction[JTuple2[JInt, JInt], Top3Accum] {
+
+  /** Creates an accumulator with an empty map, size 0 and `smallest` set to Int max. */
+  override def createAccumulator(): Top3Accum = {
+    val acc = new Top3Accum
+    acc.data = new util.HashMap[JInt, JInt]()
+    acc.size = 0
+    acc.smallest = Integer.MAX_VALUE
+    acc
+  }
+
+  /**
+    * Records one more occurrence of `v` and increments the number of retained values.
+    *
+    * @param acc accumulator to update
+    * @param v   value to retain
+    */
+  def add(acc: Top3Accum, v: Int): Unit = {
+    // HashMap.get yields null for a value that is not retained yet; count that as 0.
+    val previous = Option(acc.data.get(v)).fold(0)(_.intValue)
+    acc.size += 1
+    acc.data.put(v, previous + 1)
+  }
+
+  /**
+    * Removes one occurrence of `v`. Does nothing when `v` is not retained.
+    *
+    * @param acc accumulator to update
+    * @param v   value to drop once
+    */
+  def delete(acc: Top3Accum, v: Int): Unit = {
+    if (acc.data.containsKey(v)) {
+      acc.size -= 1
+      val remaining = acc.data.get(v) - 1
+      if (remaining == 0) {
+        // Last occurrence is gone: drop the key so it no longer takes part in updateSmallest.
+        acc.data.remove(v)
+      } else {
+        acc.data.put(v, remaining)
+      }
+    }
+  }
+
+  /**
+    * Recomputes `acc.smallest` as the minimum retained value, or Int max when the
+    * accumulator retains nothing.
+    */
+  def updateSmallest(acc: Top3Accum): Unit = {
+    acc.smallest = Integer.MAX_VALUE
+    val keys = acc.data.keySet().iterator()
+    while (keys.hasNext) {
+      val key = keys.next()
+      if (key < acc.smallest) {
+        acc.smallest = key
+      }
+    }
+  }
+
+  /**
+    * Offers `v` to the accumulator. While fewer than three values are retained, `v` is
+    * always kept. Afterwards `v` replaces one occurrence of the current smallest value,
+    * but only if it is strictly greater than that value.
+    *
+    * @param acc accumulator to update
+    * @param v   incoming value
+    */
+  def accumulate(acc: Top3Accum, v: Int): Unit = {
+    if (acc.size == 0) {
+      acc.size = 1
+      acc.smallest = v
+      acc.data.put(v, 1)
+    } else if (acc.size < 3) {
+      add(acc, v)
+      if (v < acc.smallest) {
+        acc.smallest = v
+      }
+    } else if (v > acc.smallest) {
+      delete(acc, acc.smallest)
+      add(acc, v)
+      updateSmallest(acc)
+    }
+  }
+
+  /**
+    * Emits `(value, value)` once per retained occurrence, in the iteration order of the
+    * underlying map.
+    *
+    * @param acc accumulator to read
+    * @param out collector receiving the emitted tuples
+    */
+  def emitValue(acc: Top3Accum, out: Collector[JTuple2[JInt, JInt]]): Unit = {
+    val entries = acc.data.entrySet().iterator()
+    while (entries.hasNext) {
+      val pair = entries.next()
+      for (_ <- 0 until pair.getValue) {
+        out.collect(JTuple2.of(pair.getKey, pair.getKey))
+      }
+    }
+  }
+}
+
+class Top3WithMapViewAccum { // accumulator state used by Top3WithMapView
+ var data: MapView[JInt, JInt] = _ // retained value -> number of times that value is retained
+ var size: JInt = _ // total number of retained values, duplicates included
+ var smallest: JInt = _ // smallest retained value; Integer.MAX_VALUE while nothing is retained
+}
+
+/**
+  * Variant of the Top3 test function that keeps its per-value counts in a [[MapView]]
+  * and additionally defines a `retract` method.
+  */
+class Top3WithMapView extends TableAggregateFunction[JTuple2[JInt, JInt], Top3WithMapViewAccum] {
+
+  /** Creates an accumulator with an empty map view, size 0 and `smallest` set to Int max. */
+  override def createAccumulator(): Top3WithMapViewAccum = {
+    val acc = new Top3WithMapViewAccum
+    acc.data = new MapView(Types.INT, Types.INT)
+    acc.size = 0
+    acc.smallest = Integer.MAX_VALUE
+    acc
+  }
+
+  /**
+    * Records one more occurrence of `v` and increments the number of retained values.
+    *
+    * @param acc accumulator to update
+    * @param v   value to retain
+    */
+  def add(acc: Top3WithMapViewAccum, v: Int): Unit = {
+    // A null lookup result means the value is not retained yet; count that as 0.
+    val previous = Option(acc.data.get(v)).fold(0)(_.intValue)
+    acc.size += 1
+    acc.data.put(v, previous + 1)
+  }
+
+  /**
+    * Removes one occurrence of `v`. Does nothing when `v` is not retained.
+    *
+    * @param acc accumulator to update
+    * @param v   value to drop once
+    */
+  def delete(acc: Top3WithMapViewAccum, v: Int): Unit = {
+    if (acc.data.contains(v)) {
+      acc.size -= 1
+      val remaining = acc.data.get(v) - 1
+      if (remaining == 0) {
+        // Last occurrence is gone: drop the key so it no longer takes part in updateSmallest.
+        acc.data.remove(v)
+      } else {
+        acc.data.put(v, remaining)
+      }
+    }
+  }
+
+  /**
+    * Recomputes `acc.smallest` as the minimum retained value, or Int max when the
+    * accumulator retains nothing.
+    */
+  def updateSmallest(acc: Top3WithMapViewAccum): Unit = {
+    acc.smallest = Integer.MAX_VALUE
+    val entries = acc.data.iterator
+    while (entries.hasNext) {
+      val pair = entries.next()
+      if (pair.getKey < acc.smallest) {
+        acc.smallest = pair.getKey
+      }
+    }
+  }
+
+  /**
+    * Offers `v` to the accumulator. While fewer than three values are retained, `v` is
+    * always kept. Afterwards `v` replaces one occurrence of the current smallest value,
+    * but only if it is strictly greater than that value.
+    *
+    * @param acc accumulator to update
+    * @param v   incoming value
+    */
+  def accumulate(acc: Top3WithMapViewAccum, v: Int): Unit = {
+    if (acc.size == 0) {
+      acc.size = 1
+      acc.smallest = v
+      acc.data.put(v, 1)
+    } else if (acc.size < 3) {
+      add(acc, v)
+      if (v < acc.smallest) {
+        acc.smallest = v
+      }
+    } else if (v > acc.smallest) {
+      delete(acc, acc.smallest)
+      add(acc, v)
+      updateSmallest(acc)
+    }
+  }
+
+  /**
+    * Removes one occurrence of `v` (if it is retained) and recomputes the smallest value.
+    *
+    * @param acc accumulator to update
+    * @param v   value being retracted
+    */
+  def retract(acc: Top3WithMapViewAccum, v: Int): Unit = {
+    delete(acc, v)
+    updateSmallest(acc)
+  }
+
+  /**
+    * Emits `(value, value)` once per retained occurrence, in the iteration order of the
+    * map view.
+    *
+    * @param acc accumulator to read
+    * @param out collector receiving the emitted tuples
+    */
+  def emitValue(acc: Top3WithMapViewAccum, out: Collector[JTuple2[JInt, JInt]]): Unit = {
+    val entries = acc.data.iterator
+    while (entries.hasNext) {
+      val pair = entries.next()
+      for (_ <- 0 until pair.getValue) {
+        out.collect(JTuple2.of(pair.getKey, pair.getKey))
+      }
+    }
+  }
+}
+
+/**
+  * Table aggregate function whose methods all do nothing; intended for plan tests.
+  *
+  * It offers three `accumulate` overloads that differ only in their argument lists.
+  */
+class EmptyTableAggFunc extends TableAggregateFunction[JTuple2[JInt, JInt], Top3Accum] {
+
+  /** Returns a new accumulator whose fields are left at their default values. */
+  override def createAccumulator(): Top3Accum = {
+    new Top3Accum
+  }
+
+  /** No-op overload taking a `Long` category and a `Timestamp` value. */
+  def accumulate(acc: Top3Accum, category: Long, value: Timestamp): Unit = ()
+
+  /** No-op overload taking a `Long` category and an `Int` value. */
+  def accumulate(acc: Top3Accum, category: Long, value: Int): Unit = ()
+
+  /** No-op overload taking a single `Int` value. */
+  def accumulate(acc: Top3Accum, value: Int): Unit = ()
+
+  /** Emits nothing. */
+  def emitValue(acc: Top3Accum, out: Collector[JTuple2[JInt, JInt]]): Unit = ()
+}