You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ji...@apache.org on 2019/07/05 01:42:16 UTC
[flink] branch master updated: [FLINK-13087][table] Add group window Aggregate operator to Table API
This is an automated email from the ASF dual-hosted git repository.
jincheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new a0012aa [FLINK-13087][table] Add group window Aggregate operator to Table API
a0012aa is described below
commit a0012aae89ec7a56133642b58b04e5f7b155c0f4
Author: hequn8128 <ch...@gmail.com>
AuthorDate: Thu Jul 4 11:11:01 2019 +0800
[FLINK-13087][table] Add group window Aggregate operator to Table API
This closes #8979
---
docs/dev/table/tableApi.md | 41 +++++-
.../apache/flink/table/api/WindowGroupedTable.java | 39 +++++-
.../apache/flink/table/api/internal/TableImpl.java | 109 +++++++++++++---
.../operations/utils/OperationTreeBuilder.java | 143 ++++++++++++++++++---
.../table/api/stream/table/AggregateTest.scala | 45 ++++++-
.../stringexpr/AggregateStringExpressionTest.scala | 25 ++++
.../table/validation/AggregateValidationTest.scala | 21 ++-
.../GroupWindowTableAggregateValidationTest.scala | 15 +++
.../validation/GroupWindowValidationTest.scala | 35 ++++-
.../runtime/stream/table/GroupWindowITCase.scala | 32 +++++
10 files changed, 461 insertions(+), 44 deletions(-)
diff --git a/docs/dev/table/tableApi.md b/docs/dev/table/tableApi.md
index bd6bd38..744a82c 100644
--- a/docs/dev/table/tableApi.md
+++ b/docs/dev/table/tableApi.md
@@ -2643,6 +2643,26 @@ Table table = input
<tr>
<td>
+ <strong>Group Window Aggregate</strong><br>
+ <span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span>
+ </td>
+ <td>
+ <p>Groups and aggregates a table on a <a href="#group-windows">group window</a> and, optionally, one or more grouping keys. The "aggregate" must be closed with a select statement, and that select statement does not support "*" or aggregate functions.</p>
+{% highlight java %}
+AggregateFunction myAggFunc = new MyMinMax();
+tableEnv.registerFunction("myAggFunc", myAggFunc);
+
+Table table = input
+ .window(Tumble.over("5.minutes").on("rowtime").as("w")) // define window
+ .groupBy("key, w") // group by key and window
+ .aggregate("myAggFunc(a) as (x, y)")
+ .select("key, x, y, w.start, w.end"); // access window properties and aggregate results
+{% endhighlight %}
+ </td>
+ </tr>
+
+ <tr>
+ <td>
<strong>FlatAggregate</strong><br>
<span class="label label-primary">Streaming</span><br>
<span class="label label-info">Result Updating</span>
@@ -2837,7 +2857,7 @@ class MyMinMax extends AggregateFunction[Row, MyMinMaxAcc] {
}
}
-val myAggFunc: AggregateFunction = new MyMinMax
+val myAggFunc = new MyMinMax
val table = input
.groupBy('key)
.aggregate(myAggFunc('a) as ('x, 'y))
@@ -2848,6 +2868,25 @@ val table = input
<tr>
<td>
+ <strong>Group Window Aggregate</strong><br>
+ <span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span>
+ </td>
+ <td>
+ <p>Groups and aggregates a table on a <a href="#group-windows">group window</a> and, optionally, one or more grouping keys. The "aggregate" must be closed with a select statement, and that select statement does not support "*" or aggregate functions.</p>
+{% highlight scala %}
+val myAggFunc = new MyMinMax
+val table = input
+ .window(Tumble over 5.minutes on 'rowtime as 'w) // define window
+ .groupBy('key, 'w) // group by key and window
+ .aggregate(myAggFunc('a) as ('x, 'y))
+ .select('key, 'x, 'y, 'w.start, 'w.end) // access window properties and aggregate results
+
+{% endhighlight %}
+ </td>
+ </tr>
+
+ <tr>
+ <td>
<strong>FlatAggregate</strong><br>
<span class="label label-primary">Streaming</span><br>
<span class="label label-info">Result Updating</span>
diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/WindowGroupedTable.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/WindowGroupedTable.java
index 0e1cf84..7e5a3ac 100644
--- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/WindowGroupedTable.java
+++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/WindowGroupedTable.java
@@ -56,6 +56,43 @@ public interface WindowGroupedTable {
Table select(Expression... fields);
/**
+ * Performs an aggregate operation on a window grouped table. You have to close the
+ * {@link #aggregate(String)} with a select statement. The output will be flattened if the
+ * output type is a composite type.
+ *
+ * <p>Example:
+ *
+ * <pre>
+ * {@code
+ * AggregateFunction aggFunc = new MyAggregateFunction();
+ * tableEnv.registerFunction("aggFunc", aggFunc);
+ * windowGroupedTable
+ * .aggregate("aggFunc(a, b) as (x, y, z)")
+ * .select("key, window.start, x, y, z")
+ * }
+ * </pre>
+ */
+ AggregatedTable aggregate(String aggregateFunction);
+
+ /**
+ * Performs an aggregate operation on a window grouped table. You have to close the
+ * {@link #aggregate(Expression)} with a select statement. The output will be flattened if the
+ * output type is a composite type.
+ *
+ * <p>Scala Example:
+ *
+ * <pre>
+ * {@code
+ * val aggFunc = new MyAggregateFunction
+ * windowGroupedTable
+ * .aggregate(aggFunc('a, 'b) as ('x, 'y, 'z))
+ * .select('key, 'window.start, 'x, 'y, 'z)
+ * }
+ * </pre>
+ */
+ AggregatedTable aggregate(Expression aggregateFunction);
+
+ /**
* Performs a flatAggregate operation on a window grouped table. FlatAggregate takes a
* TableAggregateFunction which returns multiple rows. Use a selection after flatAggregate.
*
@@ -63,7 +100,7 @@ public interface WindowGroupedTable {
*
* <pre>
* {@code
- * TableAggregateFunction tableAggFunc = new MyTableAggregateFunction
+ * TableAggregateFunction tableAggFunc = new MyTableAggregateFunction();
* tableEnv.registerFunction("tableAggFunc", tableAggFunc);
* windowGroupedTable
* .flatAggregate("tableAggFunc(a, b) as (x, y, z)")
diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/TableImpl.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/TableImpl.java
index 8d33e10..ae08d24 100644
--- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/TableImpl.java
+++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/TableImpl.java
@@ -36,6 +36,7 @@ import org.apache.flink.table.api.WindowGroupedTable;
import org.apache.flink.table.catalog.FunctionLookup;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.ExpressionParser;
+import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
import org.apache.flink.table.expressions.resolver.LookupCallResolver;
import org.apache.flink.table.functions.TemporalTableFunction;
import org.apache.flink.table.functions.TemporalTableFunctionImpl;
@@ -762,6 +763,16 @@ public class TableImpl implements Table {
}
@Override
+ public AggregatedTable aggregate(String aggregateFunction) {
+ return aggregate(ExpressionParser.parseExpression(aggregateFunction));
+ }
+
+ @Override
+ public AggregatedTable aggregate(Expression aggregateFunction) {
+ return new WindowAggregatedTableImpl(table, groupKeys, aggregateFunction, window);
+ }
+
+ @Override
public FlatAggregateTable flatAggregate(String tableAggregateFunction) {
return flatAggregate(ExpressionParser.parseExpression(tableAggregateFunction));
}
@@ -772,6 +783,62 @@ public class TableImpl implements Table {
}
}
+ private static final class WindowAggregatedTableImpl implements AggregatedTable {
+ private final TableImpl table;
+ private final List<Expression> groupKeys;
+ private final Expression aggregateFunction;
+ private final GroupWindow window;
+
+ private WindowAggregatedTableImpl(
+ TableImpl table,
+ List<Expression> groupKeys,
+ Expression aggregateFunction,
+ GroupWindow window) {
+ this.table = table;
+ this.groupKeys = groupKeys;
+ this.aggregateFunction = aggregateFunction;
+ this.window = window;
+ }
+
+ @Override
+ public Table select(String fields) {
+ return select(ExpressionParser.parseExpressionList(fields).toArray(new Expression[0]));
+ }
+
+ @Override
+ public Table select(Expression... fields) {
+ List<Expression> expressionsWithResolvedCalls = Arrays.stream(fields)
+ .map(f -> f.accept(table.lookupResolver))
+ .collect(Collectors.toList());
+ CategorizedExpressions extracted = OperationExpressionsUtils.extractAggregationsAndProperties(
+ expressionsWithResolvedCalls
+ );
+
+ if (!extracted.getAggregations().isEmpty()) {
+ throw new ValidationException("Aggregate functions cannot be used in the select right " +
+ "after the aggregate.");
+ }
+
+ if (extracted.getProjections().stream()
+ .anyMatch(p -> (p instanceof UnresolvedReferenceExpression)
+ && "*".equals(((UnresolvedReferenceExpression) p).getName()))) {
+ throw new ValidationException("Can not use * for window aggregate!");
+ }
+
+ return table.createTable(
+ table.operationTreeBuilder.project(
+ extracted.getProjections(),
+ table.operationTreeBuilder.windowAggregate(
+ groupKeys,
+ window,
+ extracted.getWindowProperties(),
+ aggregateFunction,
+ table.operationTree
+ )
+ ));
+ }
+ }
+
private static final class WindowFlatAggregateTableImpl implements FlatAggregateTable {
private final TableImpl table;
@@ -804,24 +871,30 @@ public class TableImpl implements Table {
expressionsWithResolvedCalls
);
- if (!extracted.getAggregations().isEmpty()) {
- throw new ValidationException("Aggregate functions cannot be used in the select right " +
- "after the flatAggregate.");
- }
-
- return table.createTable(
- table.operationTreeBuilder.project(
- extracted.getProjections(),
- table.operationTreeBuilder.windowTableAggregate(
- groupKeys,
- window,
- extracted.getWindowProperties(),
- tableAggFunction,
- table.operationTree
- ),
- // required for proper resolution of the time attribute in multi-windows
- true
- ));
+ if (!extracted.getAggregations().isEmpty()) {
+ throw new ValidationException("Aggregate functions cannot be used in the select right " +
+ "after the flatAggregate.");
+ }
+
+ if (extracted.getProjections().stream()
+ .anyMatch(p -> (p instanceof UnresolvedReferenceExpression)
+ && "*".equals(((UnresolvedReferenceExpression) p).getName()))) {
+ throw new ValidationException("Can not use * for window aggregate!");
+ }
+
+ return table.createTable(
+ table.operationTreeBuilder.project(
+ extracted.getProjections(),
+ table.operationTreeBuilder.windowTableAggregate(
+ groupKeys,
+ window,
+ extracted.getWindowProperties(),
+ tableAggFunction,
+ table.operationTree
+ ),
+ // required for proper resolution of the time attribute in multi-windows
+ true
+ ));
}
}
diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/utils/OperationTreeBuilder.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/utils/OperationTreeBuilder.java
index 03406de..7510cb9 100644
--- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/utils/OperationTreeBuilder.java
+++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/utils/OperationTreeBuilder.java
@@ -58,6 +58,7 @@ import org.apache.flink.table.operations.utils.factories.SortOperationFactory;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalTypeRoot;
import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
+import org.apache.flink.table.types.utils.TypeConversions;
import org.apache.flink.table.typeutils.FieldInfoUtils;
import org.apache.flink.util.Preconditions;
@@ -253,6 +254,73 @@ public final class OperationTreeBuilder {
child);
}
+ public QueryOperation windowAggregate(
+ List<Expression> groupingExpressions,
+ GroupWindow window,
+ List<Expression> windowProperties,
+ Expression aggregateFunction,
+ QueryOperation child) {
+
+ ExpressionResolver resolver = getResolver(child);
+ Expression resolvedAggregate = aggregateFunction.accept(lookupResolver);
+ AggregateWithAlias aggregateWithAlias = resolvedAggregate.accept(new ExtractAliasAndAggregate(true, resolver));
+
+ List<Expression> groupsAndAggregate = new ArrayList<>(groupingExpressions);
+ groupsAndAggregate.add(aggregateWithAlias.aggregate);
+ List<Expression> namedGroupsAndAggregate = addAliasToTheCallInAggregate(
+ Arrays.asList(child.getTableSchema().getFieldNames()),
+ groupsAndAggregate);
+
+ // Step1: add a default name to the call in the grouping expressions, e.g., groupBy(a % 5) to
+ // groupBy(a % 5 as TMP_0). We need a name for every column so that to perform alias for the
+ // table aggregate function in Step6.
+ List<Expression> newGroupingExpressions = namedGroupsAndAggregate.subList(0, groupingExpressions.size());
+
+ // Step2: turn agg to a named agg, because it will be verified later.
+ Expression aggregateRenamed = namedGroupsAndAggregate.get(groupingExpressions.size());
+
+ // Step3: resolve expressions, including grouping, aggregates and window properties.
+ ResolvedGroupWindow resolvedWindow = aggregateOperationFactory.createResolvedWindow(window, resolver);
+ ExpressionResolver resolverWithWindowReferences = ExpressionResolver.resolverFor(
+ tableReferenceLookup,
+ functionCatalog,
+ child)
+ .withLocalReferences(
+ new LocalReferenceExpression(
+ resolvedWindow.getAlias(),
+ resolvedWindow.getTimeAttribute().getOutputDataType()))
+ .build();
+
+ List<ResolvedExpression> convertedGroupings = resolverWithWindowReferences.resolve(newGroupingExpressions);
+ List<ResolvedExpression> convertedAggregates = resolverWithWindowReferences.resolve(Collections.singletonList(
+ aggregateRenamed));
+ List<ResolvedExpression> convertedProperties = resolverWithWindowReferences.resolve(windowProperties);
+
+ // Step4: create window agg operation
+ QueryOperation aggregateOperation = aggregateOperationFactory.createWindowAggregate(
+ convertedGroupings,
+ Collections.singletonList(convertedAggregates.get(0)),
+ convertedProperties,
+ resolvedWindow,
+ child);
+
+ // Step5: flatten the aggregate function
+ String[] aggNames = aggregateOperation.getTableSchema().getFieldNames();
+ List<Expression> flattenedExpressions = Arrays.stream(aggNames)
+ .map(ApiExpressionUtils::unresolvedRef)
+ .collect(Collectors.toCollection(ArrayList::new));
+ flattenedExpressions.set(
+ groupingExpressions.size(),
+ unresolvedCall(
+ BuiltInFunctionDefinitions.FLATTEN,
+ unresolvedRef(aggNames[groupingExpressions.size()])));
+ QueryOperation flattenedProjection = this.project(flattenedExpressions, aggregateOperation);
+
+ // Step6: add a top project to alias the output fields of the aggregate. Also, project the
+ // window attribute.
+ return aliasBackwardFields(flattenedProjection, aggregateWithAlias.aliases, groupingExpressions.size());
+ }
+
public QueryOperation join(
QueryOperation left,
QueryOperation right,
@@ -405,21 +473,30 @@ public final class OperationTreeBuilder {
public QueryOperation aggregate(List<Expression> groupingExpressions, Expression aggregate, QueryOperation child) {
Expression resolvedAggregate = aggregate.accept(lookupResolver);
- AggregateWithAlias aggregateWithAlias = resolvedAggregate.accept(new ExtractAliasAndAggregate());
+ AggregateWithAlias aggregateWithAlias =
+ resolvedAggregate.accept(new ExtractAliasAndAggregate(true, getResolver(child)));
- // turn agg to a named agg, because it will be verified later.
- String[] childNames = child.getTableSchema().getFieldNames();
- Expression aggregateRenamed = addAliasToTheCallInGroupings(
- Arrays.asList(childNames),
- Collections.singletonList(aggregateWithAlias.aggregate)).get(0);
+ List<Expression> groupsAndAggregate = new ArrayList<>(groupingExpressions);
+ groupsAndAggregate.add(aggregateWithAlias.aggregate);
+ List<Expression> namedGroupsAndAggregate = addAliasToTheCallInAggregate(
+ Arrays.asList(child.getTableSchema().getFieldNames()),
+ groupsAndAggregate);
- // get agg table
+ // Step1: add a default name to the call in the grouping expressions, e.g., groupBy(a % 5) to
+ // groupBy(a % 5 as TMP_0). We need a name for every column so that to perform alias for the
+ // aggregate function in Step5.
+ List<Expression> newGroupingExpressions = namedGroupsAndAggregate.subList(0, groupingExpressions.size());
+
+ // Step2: turn agg to a named agg, because it will be verified later.
+ Expression aggregateRenamed = namedGroupsAndAggregate.get(groupingExpressions.size());
+
+ // Step3: get agg table
QueryOperation aggregateOperation = this.aggregate(
- groupingExpressions,
+ newGroupingExpressions,
Collections.singletonList(aggregateRenamed),
child);
- // flatten the aggregate function
+ // Step4: flatten the aggregate function
String[] aggNames = aggregateOperation.getTableSchema().getFieldNames();
List<Expression> flattenedExpressions = Arrays.asList(aggNames)
.subList(0, groupingExpressions.size())
@@ -433,7 +510,7 @@ public final class OperationTreeBuilder {
QueryOperation flattenedProjection = this.project(flattenedExpressions, aggregateOperation);
- // add alias
+ // Step5: add alias
return aliasBackwardFields(flattenedProjection, aggregateWithAlias.aliases, groupingExpressions.size());
}
@@ -448,6 +525,16 @@ public final class OperationTreeBuilder {
}
private static class ExtractAliasAndAggregate extends ApiExpressionDefaultVisitor<AggregateWithAlias> {
+
+ // need this flag to validate alias, i.e., the length of alias and function result type should be same.
+ private boolean isRowbasedAggregate = false;
+ private ExpressionResolver resolver = null;
+
+ public ExtractAliasAndAggregate(boolean isRowbasedAggregate, ExpressionResolver resolver) {
+ this.isRowbasedAggregate = isRowbasedAggregate;
+ this.resolver = resolver;
+ }
+
@Override
public AggregateWithAlias visit(UnresolvedCallExpression unresolvedCall) {
if (ApiExpressionUtils.isFunction(unresolvedCall, BuiltInFunctionDefinitions.AS)) {
@@ -489,6 +576,9 @@ public final class OperationTreeBuilder {
fieldNames = Collections.emptyList();
}
} else {
+ ResolvedExpression resolvedExpression =
+ resolver.resolve(Collections.singletonList(unresolvedCall)).get(0);
+ validateAlias(aliases, resolvedExpression, isRowbasedAggregate);
fieldNames = aliases;
}
return Optional.of(new AggregateWithAlias(unresolvedCall, fieldNames));
@@ -501,6 +591,27 @@ public final class OperationTreeBuilder {
protected AggregateWithAlias defaultMethod(Expression expression) {
throw new ValidationException("Aggregate function expected. Got: " + expression);
}
+
+ private void validateAlias(
+ List<String> aliases,
+ ResolvedExpression resolvedExpression,
+ Boolean isRowbasedAggregate) {
+
+ int length = TypeConversions
+ .fromDataTypeToLegacyInfo(resolvedExpression.getOutputDataType()).getArity();
+ int callArity = isRowbasedAggregate ? length : 1;
+ int aliasesSize = aliases.size();
+
+ if ((0 < aliasesSize) && (aliasesSize != callArity)) {
+ throw new ValidationException(String.format(
+ "List of column aliases must have same degree as table; " +
+ "the returned table of function '%s' has " +
+ "%d columns, whereas alias list has %d columns",
+ resolvedExpression,
+ callArity,
+ aliasesSize));
+ }
+ }
}
public QueryOperation tableAggregate(
@@ -511,7 +622,7 @@ public final class OperationTreeBuilder {
// Step1: add a default name to the call in the grouping expressions, e.g., groupBy(a % 5) to
// groupBy(a % 5 as TMP_0). We need a name for every column so that to perform alias for the
// table aggregate function in Step4.
- List<Expression> newGroupingExpressions = addAliasToTheCallInGroupings(
+ List<Expression> newGroupingExpressions = addAliasToTheCallInAggregate(
Arrays.asList(child.getTableSchema().getFieldNames()),
groupingExpressions);
@@ -540,7 +651,7 @@ public final class OperationTreeBuilder {
// Step1: add a default name to the call in the grouping expressions, e.g., groupBy(a % 5) to
// groupBy(a % 5 as TMP_0). We need a name for every column so that to perform alias for the
// table aggregate function in Step4.
- List<Expression> newGroupingExpressions = addAliasToTheCallInGroupings(
+ List<Expression> newGroupingExpressions = addAliasToTheCallInAggregate(
Arrays.asList(child.getTableSchema().getFieldNames()),
groupingExpressions);
@@ -605,17 +716,17 @@ public final class OperationTreeBuilder {
/**
* Add a default name to the call in the grouping expressions, e.g., groupBy(a % 5) to
- * groupBy(a % 5 as TMP_0).
+ * groupBy(a % 5 as TMP_0) or make aggregate a named aggregate.
*/
- private List<Expression> addAliasToTheCallInGroupings(
+ private List<Expression> addAliasToTheCallInAggregate(
List<String> inputFieldNames,
- List<Expression> groupingExpressions) {
+ List<Expression> expressions) {
int attrNameCntr = 0;
Set<String> usedFieldNames = new HashSet<>(inputFieldNames);
List<Expression> result = new ArrayList<>();
- for (Expression groupingExpression : groupingExpressions) {
+ for (Expression groupingExpression : expressions) {
if (groupingExpression instanceof UnresolvedCallExpression &&
!ApiExpressionUtils.isFunction(groupingExpression, BuiltInFunctionDefinitions.AS)) {
String tempName = getUniqueName("TMP_" + attrNameCntr, usedFieldNames);
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala
index 41ba5fc..2501e63 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala
@@ -345,14 +345,14 @@ class AggregateTest extends TableTestBase {
}
@Test
- def testSelectStar(): Unit = {
+ def testSelectStarAndGroupByCall(): Unit = {
val util = streamTestUtil()
val table = util.addTable[(Int, Long, String)](
"MyTable", 'a, 'b, 'c)
val testAgg = new CountMinMax
val resultTable = table
- .groupBy('b)
+ .groupBy('b % 5)
.aggregate(testAgg('a))
.select('*)
@@ -364,12 +364,12 @@ class AggregateTest extends TableTestBase {
unaryNode(
"DataStreamCalc",
streamTableNode(table),
- term("select", "a", "b")
+ term("select", "a", "MOD(b, 5) AS TMP_0")
),
- term("groupBy", "b"),
- term("select", "b", "CountMinMax(a) AS TMP_0")
+ term("groupBy", "TMP_0"),
+ term("select", "TMP_0", "CountMinMax(a) AS TMP_1")
),
- term("select", "b", "TMP_0.f0 AS f0", "TMP_0.f1 AS f1", "TMP_0.f2 AS f2")
+ term("select", "TMP_0", "TMP_1.f0 AS f0", "TMP_1.f1 AS f1", "TMP_1.f2 AS f2")
)
util.verifyTable(resultTable, expected)
}
@@ -428,4 +428,37 @@ class AggregateTest extends TableTestBase {
)
util.verifyTable(resultTable, expected)
}
+
+ @Test
+ def testAggregateOnWindowedTable(): Unit = {
+ val util = streamTestUtil()
+ val table = util.addTable[(Int, Long, String)](
+ "MyTable", 'a, 'b, 'c, 'rowtime.rowtime)
+ val testAgg = new CountMinMax
+
+ val result = table
+ .window(Tumble over 15.minute on 'rowtime as 'w)
+ .groupBy('w, 'b % 3)
+ .aggregate(testAgg('a) as ('x, 'y, 'z))
+ .select('w.start, 'x, 'y)
+
+ val expected =
+ unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamGroupWindowAggregate",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(table),
+ term("select", "a", "rowtime", "MOD(b, 3) AS TMP_0")
+ ),
+ term("groupBy", "TMP_0"),
+ term("window", "TumblingGroupWindow('w, 'rowtime, 900000.millis)"),
+ term("select", "TMP_0", "CountMinMax(a) AS TMP_1", "start('w) AS EXPR$0")
+ ),
+ term("select", "EXPR$0", "TMP_1.f0 AS x", "TMP_1.f1 AS y")
+ )
+
+ util.verifyTable(result, expected)
+ }
}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala
index 5ee4471..7de35f3 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala
@@ -235,6 +235,31 @@ class AggregateStringExpressionTest extends TableTestBase {
verifyTableEquals(resScala, resJava)
}
+
+ @Test
+ def testAggregateWithWindow(): Unit = {
+ val util = streamTestUtil()
+ val t = util.addTable[TestPojo]('int, 'long.rowtime as 'rowtime, 'string)
+
+ val testAgg = new CountMinMax
+ util.tableEnv.registerFunction("testAgg", testAgg)
+
+ // Expression / Scala API
+ val resScala = t
+ .window(Tumble over 50.milli on 'rowtime as 'w1)
+ .groupBy('w1, 'string)
+ .aggregate(testAgg('int) as ('x, 'y, 'z))
+ .select('string, 'x, 'y, 'w1.start, 'w1.end)
+
+ // String / Java API
+ val resJava = t
+ .window(Tumble.over("50.milli").on("rowtime").as("w1"))
+ .groupBy("w1, string")
+ .aggregate("testAgg(int) as (x, y, z)")
+ .select("string, x, y, w1.start, w1.end")
+
+ verifyTableEquals(resJava, resScala)
+ }
}
class TestPojo() {
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/AggregateValidationTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/AggregateValidationTest.scala
index a8de009..4b25ea2 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/AggregateValidationTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/AggregateValidationTest.scala
@@ -21,7 +21,7 @@ package org.apache.flink.table.api.stream.table.validation
import org.apache.flink.api.scala._
import org.apache.flink.table.api.{ExpressionParserException, ValidationException}
import org.apache.flink.table.api.scala._
-import org.apache.flink.table.utils.{TableFunc0, TableTestBase}
+import org.apache.flink.table.utils.{CountMinMax, TableFunc0, TableTestBase}
import org.junit.Test
class AggregateValidationTest extends TableTestBase {
@@ -123,4 +123,23 @@ class AggregateValidationTest extends TableTestBase {
// must fail. Only one AggregateFunction can be used in aggregate
.aggregate("sum(c), count(b)")
}
+
+ @Test
+ def testInvalidAlias(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage("List of column aliases must have same degree as " +
+ "table; the returned table of function 'minMax(b)' has 3 columns, " +
+ "whereas alias list has 2 columns")
+
+ val util = streamTestUtil()
+ val table = util.addTable[(Long, Int, String)]('a, 'b, 'c)
+ val minMax = new CountMinMax
+
+ util.tableEnv.registerFunction("minMax", minMax)
+ table
+ .groupBy('a)
+ // must fail. Invalid alias length
+ .aggregate("minMax(b) as (x, y)")
+ .select("x, y")
+ }
}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/GroupWindowTableAggregateValidationTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/GroupWindowTableAggregateValidationTest.scala
index 21369ca..714e1dd 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/GroupWindowTableAggregateValidationTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/GroupWindowTableAggregateValidationTest.scala
@@ -254,4 +254,19 @@ class GroupWindowTableAggregateValidationTest extends TableTestBase {
.flatAggregate(top3('int))
.select('string, 'f0.count)
}
+
+ @Test
+ def testInvalidStarInSelection(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage("Can not use * for window aggregate!")
+
+ val util = streamTestUtil()
+ val table = util.addTable[(Long, Int, String)]('long, 'int, 'string, 'proctime.proctime)
+
+ table
+ .window(Tumble over 2.rows on 'proctime as 'w)
+ .groupBy('string, 'w)
+ .flatAggregate(top3('int))
+ .select('*)
+ }
}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/GroupWindowValidationTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/GroupWindowValidationTest.scala
index 7b7ff87..2b03bc6 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/GroupWindowValidationTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/validation/GroupWindowValidationTest.scala
@@ -22,7 +22,7 @@ import org.apache.flink.api.scala._
import org.apache.flink.table.api.{Session, Slide, Tumble, ValidationException}
import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.WeightedAvgWithMerge
import org.apache.flink.table.api.scala._
-import org.apache.flink.table.utils.TableTestBase
+import org.apache.flink.table.utils.{CountMinMax, TableTestBase}
import org.junit.Test
class GroupWindowValidationTest extends TableTestBase {
@@ -290,4 +290,37 @@ class GroupWindowValidationTest extends TableTestBase {
.groupBy('w, 'string)
.select('string, 'w.start, 'w.end) // invalid start/end on rows-count window
}
+
+ @Test
+ def testInvalidAggregateInSelection(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage("Aggregate functions cannot be used in the select " +
+ "right after the aggregate.")
+
+ val util = streamTestUtil()
+ val table = util.addTable[(Long, Int, String)]('long, 'int, 'string, 'proctime.proctime)
+ val testAgg = new CountMinMax
+
+ table
+ .window(Tumble over 2.rows on 'proctime as 'w)
+ .groupBy('string, 'w)
+ .aggregate(testAgg('int))
+ .select('string, 'f0.count)
+ }
+
+ @Test
+ def testInvalidStarInSelection(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage("Can not use * for window aggregate!")
+
+ val util = streamTestUtil()
+ val table = util.addTable[(Long, Int, String)]('long, 'int, 'string, 'proctime.proctime)
+ val testAgg = new CountMinMax
+
+ table
+ .window(Tumble over 2.rows on 'proctime as 'w)
+ .groupBy('string, 'w)
+ .aggregate(testAgg('int))
+ .select('*)
+ }
}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/GroupWindowITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/GroupWindowITCase.scala
index 9054bc8..130019d 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/GroupWindowITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/GroupWindowITCase.scala
@@ -32,6 +32,7 @@ import org.apache.flink.table.functions.aggfunctions.CountAggFunction
import org.apache.flink.table.runtime.stream.table.GroupWindowITCase._
import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinct, CountDistinctWithMerge, WeightedAvg, WeightedAvgWithMerge}
import org.apache.flink.table.runtime.utils.StreamITCase
+import org.apache.flink.table.utils.CountMinMax
import org.apache.flink.test.util.AbstractTestBase
import org.apache.flink.types.Row
import org.junit.Assert._
@@ -442,6 +443,37 @@ class GroupWindowITCase extends AbstractTestBase {
"null,1,1970-01-01 00:00:00.03,1970-01-01 00:00:00.033")
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
+
+ @Test
+ def testRowbasedAggregateWithEventTimeTumblingWindow(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
+ val tEnv = StreamTableEnvironment.create(env)
+ StreamITCase.testResults = mutable.MutableList()
+
+ val stream = env
+ .fromCollection(data)
+ .assignTimestampsAndWatermarks(new TimestampAndWatermarkWithOffset[(Long, Int, String)](0L))
+ val table = stream.toTable(tEnv, 'long, 'int, 'string, 'rowtime.rowtime)
+ val minMax = new CountMinMax
+
+ val windowedTable = table
+ .window(Tumble over 5.milli on 'rowtime as 'w)
+ .groupBy('w, 'string)
+ .aggregate(minMax('int) as ('x, 'y, 'z))
+ .select('string, 'x, 'y, 'z, 'w.start, 'w.end)
+
+ val results = windowedTable.toAppendStream[Row]
+ results.addSink(new StreamITCase.StringSink[Row])
+ env.execute()
+
+ val expected = Seq(
+ "Hello world,1,3,3,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01",
+ "Hello world,1,3,3,1970-01-01 00:00:00.015,1970-01-01 00:00:00.02",
+ "Hello,2,2,2,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005",
+ "Hi,1,1,1,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005")
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
}
object GroupWindowITCase {