You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@flink.apache.org by fh...@apache.org on 2017/05/16 20:03:46 UTC
flink git commit: [FLINK-6462] [table] Add requiresOver method to AggregateFunction.
Repository: flink
Updated Branches:
refs/heads/release-1.3 fc2012702 -> 629d3633b
[FLINK-6462] [table] Add requiresOver method to AggregateFunction.
This closes #3851.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/629d3633
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/629d3633
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/629d3633
Branch: refs/heads/release-1.3
Commit: 629d3633bcc458dc4ba5e660f48ed42b1a90b834
Parents: fc20127
Author: sunjincheng121 <su...@gmail.com>
Authored: Mon May 8 12:04:47 2017 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Tue May 16 21:36:27 2017 +0200
----------------------------------------------------------------------
.../sql/validate/SqlUserDefinedAggFunction.java | 82 ++++++++++++++++++++
.../flink/table/expressions/aggregations.scala | 21 ++---
.../table/functions/AggregateFunction.scala | 5 ++
.../table/functions/utils/AggSqlFunction.scala | 11 ++-
.../utils/UserDefinedFunctionUtils.scala | 2 +-
.../flink/table/plan/logical/operators.scala | 11 ++-
.../api/java/utils/UserDefinedAggFunctions.java | 24 ++++++
.../api/scala/batch/sql/AggregationsTest.scala | 43 ++++++++++
.../scala/batch/sql/WindowAggregateTest.scala | 16 +++-
.../api/scala/batch/table/GroupWindowTest.scala | 16 +++-
.../validation/AggregationsValidationTest.scala | 28 +++++--
.../api/scala/stream/sql/AggregationsTest.scala | 42 ++++++++++
.../scala/stream/sql/WindowAggregateTest.scala | 13 +++-
.../stream/table/GroupAggregationsTest.scala | 12 +++
.../scala/stream/table/GroupWindowTest.scala | 17 +++-
.../table/runtime/harness/HarnessTestBase.scala | 43 ++++------
.../flink/table/utils/TableTestBase.scala | 13 ++++
17 files changed, 339 insertions(+), 60 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/629d3633/flink-libraries/flink-table/src/main/java/org/apache/calcite/sql/validate/SqlUserDefinedAggFunction.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/java/org/apache/calcite/sql/validate/SqlUserDefinedAggFunction.java b/flink-libraries/flink-table/src/main/java/org/apache/calcite/sql/validate/SqlUserDefinedAggFunction.java
new file mode 100644
index 0000000..3733d61
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/java/org/apache/calcite/sql/validate/SqlUserDefinedAggFunction.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.sql.validate;
+
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.schema.AggregateFunction;
+import org.apache.calcite.schema.FunctionParameter;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.SqlFunctionCategory;
+import org.apache.calcite.sql.SqlIdentifier;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.type.SqlOperandTypeChecker;
+import org.apache.calcite.sql.type.SqlOperandTypeInference;
+import org.apache.calcite.sql.type.SqlReturnTypeInference;
+import org.apache.calcite.util.Util;
+
+import com.google.common.base.Function;
+import com.google.common.collect.Lists;
+
+import java.util.List;
+
+/**
+ * User-defined aggregate function.
+ *
+ * <p>Created by the validator, after resolving a function call to a function
+ * defined in a Calcite schema.</p>
+ *
+ * <p>This copy lives in Flink's table module under Calcite's package name and
+ * adds a {@code requiresOver} flag that is forwarded to {@link SqlAggFunction},
+ * so aggregates that may only be used with an OVER clause can be marked as such.</p>
+ */
+public class SqlUserDefinedAggFunction extends SqlAggFunction {
+  /**
+   * Calcite aggregate function backing this operator. Flink's AggSqlFunction
+   * passes {@code null} here.
+   */
+  public final AggregateFunction function;
+
+  /**
+   * Creates a user-defined aggregate function that does not require an OVER clause.
+   */
+  public SqlUserDefinedAggFunction(SqlIdentifier opName,
+      SqlReturnTypeInference returnTypeInference,
+      SqlOperandTypeInference operandTypeInference,
+      SqlOperandTypeChecker operandTypeChecker, AggregateFunction function) {
+    this(opName, returnTypeInference, operandTypeInference, operandTypeChecker,
+        function, false);
+  }
+
+  /**
+   * Creates a user-defined aggregate function.
+   *
+   * @param opName qualified function name; its last part is used as the simple name
+   * @param returnTypeInference strategy used to infer the return type
+   * @param operandTypeInference strategy used to infer operand types
+   * @param operandTypeChecker strategy used to validate operand types
+   * @param function backing Calcite aggregate function
+   * @param requiresOver whether the function may only be called with an OVER clause
+   */
+  public SqlUserDefinedAggFunction(SqlIdentifier opName,
+      SqlReturnTypeInference returnTypeInference,
+      SqlOperandTypeInference operandTypeInference,
+      SqlOperandTypeChecker operandTypeChecker,
+      AggregateFunction function,
+      boolean requiresOver) {
+    // Primitive flag: SqlAggFunction takes a primitive boolean, so a boxed
+    // Boolean would only add an unboxing NPE risk. requiresOrder is always false.
+    super(Util.last(opName.names), opName, SqlKind.OTHER_FUNCTION,
+        returnTypeInference, operandTypeInference, operandTypeChecker,
+        SqlFunctionCategory.USER_DEFINED_FUNCTION, false, requiresOver);
+    this.function = function;
+  }
+
+  /**
+   * Returns the parameter types of {@link #function} as a lazily transformed view.
+   * Requires {@link #function} to be non-null.
+   */
+  @SuppressWarnings("deprecation")
+  public List<RelDataType> getParameterTypes(
+      final RelDataTypeFactory typeFactory) {
+    return Lists.transform(function.getParameters(),
+        new Function<FunctionParameter, RelDataType>() {
+          @Override
+          public RelDataType apply(FunctionParameter input) {
+            return input.getType(typeFactory);
+          }
+        });
+  }
+
+  /**
+   * Returns the return type declared by {@link #function}.
+   * Requires {@link #function} to be non-null.
+   */
+  @SuppressWarnings("deprecation")
+  public RelDataType getReturnType(RelDataTypeFactory typeFactory) {
+    return function.getReturnType(typeFactory);
+  }
+}
+
+// End SqlUserDefinedAggFunction.java
http://git-wip-us.apache.org/repos/asf/flink/blob/629d3633/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
index 7b180ae..6d906b9 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
@@ -253,29 +253,20 @@ case class AggFunctionCall(
override def toString(): String = s"${aggregateFunction.getClass.getSimpleName}($args)"
override def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
- val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
- val sqlFunction = AggSqlFunction(aggregateFunction.getClass.getSimpleName,
- aggregateFunction,
- resultType,
- typeFactory)
- relBuilder.aggregateCall(sqlFunction, false, null, name, args.map(_.toRexNode): _*)
+ relBuilder.aggregateCall(this.getSqlAggFunction(), false, null, name, args.map(_.toRexNode): _*)
}
override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
- AggSqlFunction(aggregateFunction.getClass.getSimpleName,
+ val sqlAgg = AggSqlFunction(aggregateFunction.getClass.getSimpleName,
aggregateFunction,
resultType,
- typeFactory)
+ typeFactory,
+ aggregateFunction.requiresOver)
+ sqlAgg
}
override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
- val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
- relBuilder.call(
- AggSqlFunction(aggregateFunction.getClass.getSimpleName,
- aggregateFunction,
- resultType,
- typeFactory),
- args.map(_.toRexNode): _*)
+ relBuilder.call(this.getSqlAggFunction(), args.map(_.toRexNode): _*)
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/629d3633/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
index 9c79439..f90860b 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
@@ -135,4 +135,9 @@ abstract class AggregateFunction[T, ACC] extends UserDefinedFunction {
* @return the aggregation result
*/
def getValue(accumulator: ACC): T
+
+ /**
+ * Returns whether this aggregate function may only be used with an OVER clause.
+ *
+ * The default is `false`. Functions that override this to return `true` are
+ * rejected during validation when used as a regular or group-window aggregate.
+ *
+ * @return true if an OVER clause is required, false otherwise
+ */
+ def requiresOver: Boolean = false
}
http://git-wip-us.apache.org/repos/asf/flink/blob/629d3633/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala
index c3f6c4c..816bc52 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala
@@ -43,7 +43,8 @@ class AggSqlFunction(
name: String,
aggregateFunction: AggregateFunction[_, _],
returnType: TypeInformation[_],
- typeFactory: FlinkTypeFactory)
+ typeFactory: FlinkTypeFactory,
+ requiresOver: Boolean)
extends SqlUserDefinedAggFunction(
new SqlIdentifier(name, SqlParserPos.ZERO),
createReturnTypeInference(returnType, typeFactory),
@@ -51,7 +52,8 @@ class AggSqlFunction(
createOperandTypeChecker(aggregateFunction),
// Do not need to provide a calcite aggregateFunction here. Flink aggregateion function
// will be generated when translating the calcite relnode to flink runtime execution plan
- null
+ null,
+ requiresOver
) {
def getFunction: AggregateFunction[_, _] = aggregateFunction
@@ -63,9 +65,10 @@ object AggSqlFunction {
name: String,
aggregateFunction: AggregateFunction[_, _],
returnType: TypeInformation[_],
- typeFactory: FlinkTypeFactory): AggSqlFunction = {
+ typeFactory: FlinkTypeFactory,
+ requiresOver: Boolean): AggSqlFunction = {
- new AggSqlFunction(name, aggregateFunction, returnType, typeFactory)
+ new AggSqlFunction(name, aggregateFunction, returnType, typeFactory, requiresOver)
}
private[flink] def createOperandTypeInference(
http://git-wip-us.apache.org/repos/asf/flink/blob/629d3633/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
index 11174de..1016574 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -291,7 +291,7 @@ object UserDefinedFunctionUtils {
//check if a qualified accumulate method exists before create Sql function
checkAndExtractMethods(aggFunction, "accumulate")
val resultType: TypeInformation[_] = getResultTypeOfAggregateFunction(aggFunction, typeInfo)
- AggSqlFunction(name, aggFunction, resultType, typeFactory)
+ AggSqlFunction(name, aggFunction, resultType, typeFactory, aggFunction.requiresOver)
}
// ----------------------------------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/629d3633/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
index bfb6cbf..6777ef5 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
@@ -215,7 +215,7 @@ case class Aggregate(
}
override def validate(tableEnv: TableEnvironment): LogicalNode = {
-
+ implicit val relBuilder: RelBuilder = tableEnv.getRelBuilder
val resolvedAggregate = super.validate(tableEnv).asInstanceOf[Aggregate]
val groupingExprs = resolvedAggregate.groupingExpressions
val aggregateExprs = resolvedAggregate.aggregateExpressions
@@ -223,6 +223,10 @@ case class Aggregate(
groupingExprs.foreach(validateGroupingExpression)
def validateAggregateExpression(expr: Expression): Unit = expr match {
+ // check aggregate function
+ case aggExpr: Aggregation
+ if aggExpr.getSqlAggFunction.requiresOver =>
+ failValidation(s"OVER clause is necessary for window functions: [${aggExpr.getClass}].")
// check no nested aggregation exists.
case aggExpr: Aggregation =>
aggExpr.children.foreach { child =>
@@ -602,6 +606,7 @@ case class WindowAggregate(
}
override def validate(tableEnv: TableEnvironment): LogicalNode = {
+ implicit val relBuilder: RelBuilder = tableEnv.getRelBuilder
val resolvedWindowAggregate = super.validate(tableEnv).asInstanceOf[WindowAggregate]
val groupingExprs = resolvedWindowAggregate.groupingExpressions
val aggregateExprs = resolvedWindowAggregate.aggregateExpressions
@@ -609,6 +614,10 @@ case class WindowAggregate(
groupingExprs.foreach(validateGroupingExpression)
def validateAggregateExpression(expr: Expression): Unit = expr match {
+ // check aggregate function
+ case aggExpr: Aggregation
+ if aggExpr.getSqlAggFunction.requiresOver =>
+ failValidation(s"OVER clause is necessary for window functions: [${aggExpr.getClass}].")
// check no nested aggregation exists.
case aggExpr: Aggregation =>
aggExpr.children.foreach { child =>
http://git-wip-us.apache.org/repos/asf/flink/blob/629d3633/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedAggFunctions.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedAggFunctions.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedAggFunctions.java
index cfddc57..a51a4af 100644
--- a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedAggFunctions.java
+++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedAggFunctions.java
@@ -23,6 +23,30 @@ import org.apache.flink.table.functions.AggregateFunction;
import java.util.Iterator;
public class UserDefinedAggFunctions {
+ /** Accumulator for the requiresOver test function OverAgg0, which never reads or writes its fields. */
+ public static class Accumulator0 extends Tuple2<Long, Integer>{}
+
+	/**
+	 * Test aggregate function whose {@link #requiresOver()} returns {@code true},
+	 * i.e. it may only be called with an OVER clause. It always returns {@code 1L}.
+	 */
+	public static class OverAgg0 extends AggregateFunction<Long, Accumulator0> {
+		@Override
+		public Accumulator0 createAccumulator() {
+			return new Accumulator0();
+		}
+
+		@Override
+		public Long getValue(Accumulator0 accumulator) {
+			// Constant result: the tests using this function only check validation.
+			return 1L;
+		}
+
+		// Single accumulate method taking (long, int) arguments; intentionally a no-op.
+		public void accumulate(Accumulator0 accumulator, long iValue, int iWeight) {
+		}
+
+		@Override
+		public boolean requiresOver() {
+			return true;
+		}
+	}
// Accumulator for WeightedAvg
public static class WeightedAvgAccum extends Tuple2<Long, Integer> {
http://git-wip-us.apache.org/repos/asf/flink/blob/629d3633/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsTest.scala
new file mode 100644
index 0000000..bf150c3
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsTest.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.api.scala.batch.sql
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.ValidationException
+import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.OverAgg0
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.utils.TableTestBase
+import org.junit.Test
+
+class AggregationsTest extends TableTestBase {
+
+  /**
+   * [[OverAgg0]] requires an OVER clause, so calling it in a plain SELECT
+   * without one must fail validation.
+   */
+  @Test(expected = classOf[ValidationException])
+  def testOverAggregation(): Unit = {
+    val testUtil = batchTestUtil()
+    testUtil.addTable[(Int, Long, String)]("T", 'a, 'b, 'c)
+    testUtil.addFunction("overAgg", new OverAgg0)
+
+    testUtil.tEnv.sql("SELECT overAgg(b, a) FROM T")
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/629d3633/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/WindowAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/WindowAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/WindowAggregateTest.scala
index 71d0002..328c03c 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/WindowAggregateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/WindowAggregateTest.scala
@@ -22,7 +22,7 @@ import java.sql.Timestamp
import org.apache.flink.api.scala._
import org.apache.flink.table.api.{TableException, ValidationException}
-import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.WeightedAvgWithMerge
+import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.{OverAgg0, WeightedAvgWithMerge}
import org.apache.flink.table.api.scala._
import org.apache.flink.table.plan.logical._
import org.apache.flink.table.utils.TableTestBase
@@ -31,6 +31,20 @@ import org.junit.Test
class WindowAggregateTest extends TableTestBase {
+  /**
+   * [[OverAgg0]] requires an OVER clause, so using it as a TUMBLE group-window
+   * aggregate must fail validation.
+   */
+  @Test(expected = classOf[ValidationException])
+  def testOverAggregation(): Unit = {
+    val testUtil = batchTestUtil()
+    testUtil.addTable[(Int, Long, String, Timestamp)]("T", 'a, 'b, 'c, 'ts)
+    testUtil.addFunction("overAgg", new OverAgg0)
+
+    testUtil.tEnv.sql(
+      "SELECT overAgg(b, a) FROM T GROUP BY TUMBLE(ts, INTERVAL '2' HOUR)")
+  }
+
@Test
def testNonPartitionedTumbleWindow(): Unit = {
val util = batchTestUtil()
http://git-wip-us.apache.org/repos/asf/flink/blob/629d3633/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/GroupWindowTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/GroupWindowTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/GroupWindowTest.scala
index aa6edd3..12e8897 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/GroupWindowTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/GroupWindowTest.scala
@@ -20,7 +20,7 @@ package org.apache.flink.table.api.scala.batch.table
import org.apache.flink.api.scala._
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.ValidationException
-import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.WeightedAvgWithMerge
+import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.{OverAgg0, WeightedAvgWithMerge}
import org.apache.flink.table.expressions.WindowReference
import org.apache.flink.table.plan.logical._
import org.apache.flink.table.utils.TableTestBase
@@ -33,6 +33,20 @@ class GroupWindowTest extends TableTestBase {
// Common test
//===============================================================================================
+  /**
+   * [[OverAgg0]] requires an OVER clause, so it cannot be used as a tumbling
+   * group-window aggregate in the Table API.
+   */
+  @Test(expected = classOf[ValidationException])
+  def testOverAggregation(): Unit = {
+    val testUtil = batchTestUtil()
+    val input = testUtil.addTable[(Long, Int, String)]('long, 'int, 'string)
+    val overAgg = new OverAgg0
+
+    val grouped = input
+      .window(Tumble over 5.milli on 'long as 'w)
+      .groupBy('string, 'w)
+    grouped.select(overAgg('long, 'int))
+  }
+
@Test(expected = classOf[ValidationException])
def testGroupByWithoutWindowAlias(): Unit = {
val util = batchTestUtil()
http://git-wip-us.apache.org/repos/asf/flink/blob/629d3633/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/validation/AggregationsValidationTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/validation/AggregationsValidationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/validation/AggregationsValidationTest.scala
index 278711c..8e90fa8 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/validation/AggregationsValidationTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/validation/AggregationsValidationTest.scala
@@ -20,13 +20,25 @@ package org.apache.flink.table.api.scala.batch.table.validation
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.util.CollectionDataSets
-import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.WeightedAvgWithMergeAndReset
+import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.{OverAgg0, WeightedAvgWithMergeAndReset}
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.{TableEnvironment, ValidationException}
import org.junit._
class AggregationsValidationTest {
+  /**
+   * [[OverAgg0]] requires an OVER clause, so it must be rejected when used as a
+   * regular (non-windowed) aggregate next to a built-in one.
+   */
+  @Test(expected = classOf[ValidationException])
+  def testOverAggregation(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tableEnv = TableEnvironment.getTableEnvironment(env)
+    val overAgg = new OverAgg0
+
+    CollectionDataSets.get3TupleDataSet(env)
+      .toTable(tableEnv, 'a, 'b, 'c)
+      .select('c.count, overAgg('b, 'a))
+  }
+
@Test(expected = classOf[ValidationException])
def testNonWorkingAggregationDataTypes(): Unit = {
@@ -150,7 +162,7 @@ class AggregationsValidationTest {
@Test(expected = classOf[ValidationException])
@throws[Exception]
def testNoNestedAggregationsJava() {
- val env= ExecutionEnvironment.getExecutionEnvironment
+ val env = ExecutionEnvironment.getExecutionEnvironment
val tableEnv = TableEnvironment.getTableEnvironment(env)
val table = env.fromElements((1f, "Hello")).toTable(tableEnv)
// Must fail. Aggregation on aggregation not allowed.
@@ -160,7 +172,7 @@ class AggregationsValidationTest {
@Test(expected = classOf[ValidationException])
@throws[Exception]
def testGroupingOnNonExistentFieldJava() {
- val env= ExecutionEnvironment.getExecutionEnvironment
+ val env = ExecutionEnvironment.getExecutionEnvironment
val tableEnv = TableEnvironment.getTableEnvironment(env)
val input = CollectionDataSets.get3TupleDataSet(env).toTable(tableEnv, 'a, 'b, 'c)
input
@@ -172,7 +184,7 @@ class AggregationsValidationTest {
@Test(expected = classOf[ValidationException])
@throws[Exception]
def testGroupingInvalidSelectionJava() {
- val env= ExecutionEnvironment.getExecutionEnvironment
+ val env = ExecutionEnvironment.getExecutionEnvironment
val tableEnv = TableEnvironment.getTableEnvironment(env)
val input = CollectionDataSets.get3TupleDataSet(env).toTable(tableEnv, 'a, 'b, 'c)
input
@@ -184,7 +196,7 @@ class AggregationsValidationTest {
@Test(expected = classOf[ValidationException])
@throws[Exception]
def testUnknownUdAggJava() {
- val env= ExecutionEnvironment.getExecutionEnvironment
+ val env = ExecutionEnvironment.getExecutionEnvironment
val tableEnv = TableEnvironment.getTableEnvironment(env)
val input = CollectionDataSets.get3TupleDataSet(env).toTable(tableEnv, 'a, 'b, 'c)
input
@@ -195,7 +207,7 @@ class AggregationsValidationTest {
@Test(expected = classOf[ValidationException])
@throws[Exception]
def testGroupingUnknownUdAggJava() {
- val env= ExecutionEnvironment.getExecutionEnvironment
+ val env = ExecutionEnvironment.getExecutionEnvironment
val tableEnv = TableEnvironment.getTableEnvironment(env)
val input = CollectionDataSets.get3TupleDataSet(env).toTable(tableEnv, 'a, 'b, 'c)
input
@@ -207,7 +219,7 @@ class AggregationsValidationTest {
@Test(expected = classOf[ValidationException])
@throws[Exception]
def testInvalidUdAggArgsJava() {
- val env= ExecutionEnvironment.getExecutionEnvironment
+ val env = ExecutionEnvironment.getExecutionEnvironment
val tableEnv = TableEnvironment.getTableEnvironment(env)
val myWeightedAvg = new WeightedAvgWithMergeAndReset
@@ -222,7 +234,7 @@ class AggregationsValidationTest {
@Test(expected = classOf[ValidationException])
@throws[Exception]
def testGroupingInvalidUdAggArgsJava() {
- val env= ExecutionEnvironment.getExecutionEnvironment
+ val env = ExecutionEnvironment.getExecutionEnvironment
val tableEnv = TableEnvironment.getTableEnvironment(env)
val myWeightedAvg = new WeightedAvgWithMergeAndReset
http://git-wip-us.apache.org/repos/asf/flink/blob/629d3633/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/AggregationsTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/AggregationsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/AggregationsTest.scala
new file mode 100644
index 0000000..585c390
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/AggregationsTest.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.api.scala.stream.sql
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.ValidationException
+import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.OverAgg0
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.utils.{StreamTableTestUtil, TableTestBase}
+import org.junit.Test
+
+class AggregationsTest extends TableTestBase {
+  private val testUtil: StreamTableTestUtil = streamTestUtil()
+  testUtil.addTable[(Int, String, Long)]("MyTable", 'a, 'b, 'c)
+
+  /**
+   * [[OverAgg0]] requires an OVER clause, so a streaming SELECT that calls it
+   * without one must fail validation.
+   */
+  @Test(expected = classOf[ValidationException])
+  def testOverAggregation(): Unit = {
+    testUtil.addFunction("overAgg", new OverAgg0)
+    testUtil.tEnv.sql("SELECT overAgg(c, a) FROM MyTable")
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/629d3633/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala
index 125d071..3729ef0 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala
@@ -19,7 +19,7 @@ package org.apache.flink.table.api.scala.stream.sql
import org.apache.flink.api.scala._
import org.apache.flink.table.api.{TableException, ValidationException}
-import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.WeightedAvgWithMerge
+import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.{OverAgg0, WeightedAvgWithMerge}
import org.apache.flink.table.api.scala._
import org.apache.flink.table.plan.logical._
import org.apache.flink.table.utils.TableTestUtil._
@@ -31,6 +31,17 @@ class WindowAggregateTest extends TableTestBase {
streamUtil.addTable[(Int, String, Long)](
"MyTable", 'a, 'b, 'c, 'proctime.proctime, 'rowtime.rowtime)
+  /**
+   * [[OverAgg0]] requires an OVER clause, so using it as a TUMBLE group-window
+   * aggregate must fail validation.
+   */
+  @Test(expected = classOf[ValidationException])
+  def testOverAggregation(): Unit = {
+    streamUtil.addFunction("overAgg", new OverAgg0)
+
+    // Group by a window so this test covers the windowed path; the plain
+    // (non-windowed) SELECT is already covered by the stream SQL AggregationsTest.
+    val sqlQuery =
+      "SELECT overAgg(c, a) FROM MyTable GROUP BY TUMBLE(proctime, INTERVAL '2' HOUR)"
+    streamUtil.tEnv.sql(sqlQuery)
+  }
+
@Test
def testGroupbyWithoutWindow() = {
val sql = "SELECT COUNT(a) FROM MyTable GROUP BY b"
http://git-wip-us.apache.org/repos/asf/flink/blob/629d3633/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsTest.scala
index 520592c..a16688e 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsTest.scala
@@ -24,10 +24,22 @@ import org.apache.flink.table.utils.TableTestBase
import org.junit.Test
import org.apache.flink.table.api.scala._
import org.apache.flink.api.scala._
+import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.OverAgg0
import org.apache.flink.table.utils.TableTestUtil._
class GroupAggregationsTest extends TableTestBase {
+  /**
+   * [[OverAgg0]] only supports OVER windows, so selecting it directly through the
+   * Table API (without an OVER window) must fail with a [[ValidationException]].
+   */
+  @Test(expected = classOf[ValidationException])
+  def testOverAggregation(): Unit = {
+    val sourceTable = streamTestUtil().addTable[(Long, Int, String)]('a, 'b, 'c)
+    val overAggFunc = new OverAgg0
+    sourceTable.select(overAggFunc('a, 'b))
+  }
+
@Test(expected = classOf[ValidationException])
def testGroupingOnNonExistentField(): Unit = {
val util = streamTestUtil()
http://git-wip-us.apache.org/repos/asf/flink/blob/629d3633/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupWindowTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupWindowTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupWindowTest.scala
index b389183..55689d0 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupWindowTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupWindowTest.scala
@@ -19,7 +19,7 @@
package org.apache.flink.table.api.scala.stream.table
import org.apache.flink.api.scala._
-import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.{WeightedAvg, WeightedAvgWithMerge}
+import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.{OverAgg0, WeightedAvg, WeightedAvgWithMerge}
import org.apache.flink.table.api.ValidationException
import org.apache.flink.table.api.scala._
import org.apache.flink.table.expressions.WindowReference
@@ -30,6 +30,21 @@ import org.junit.{Ignore, Test}
class GroupWindowTest extends TableTestBase {
+  /**
+   * [[OverAgg0]] only supports OVER windows; using it in a tumbling group-window
+   * aggregation must therefore be rejected with a [[ValidationException]].
+   */
+  @Test(expected = classOf[ValidationException])
+  def testOverAggregation(): Unit = {
+    val windowedTable = streamTestUtil()
+      .addTable[(Long, Int, String)]('long, 'int, 'string, 'proctime.proctime)
+      .window(Tumble over 2.rows on 'proctime as 'w)
+
+    val overAggFunc = new OverAgg0
+    // Group-window aggregation, not an OVER window: validation is expected to fail.
+    windowedTable
+      .groupBy('w, 'string)
+      .select(overAggFunc('long, 'int))
+  }
+
@Test(expected = classOf[ValidationException])
def testInvalidWindowProperty(): Unit = {
val util = streamTestUtil()
http://git-wip-us.apache.org/repos/asf/flink/blob/629d3633/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
index 77798f9..0b3373c 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
@@ -29,12 +29,22 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, TestHarnessUtil}
import org.apache.flink.table.codegen.GeneratedAggregationsFunction
import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
import org.apache.flink.table.functions.aggfunctions.{LongMaxWithRetractAggFunction, LongMinWithRetractAggFunction, IntSumWithRetractAggFunction}
import org.apache.flink.table.runtime.aggregate.AggregateUtil
import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
class HarnessTestBase {
+  // Serialized forms of the retraction aggregate functions used by the hand-written
+  // "generated" aggregation helpers in this class (minMaxCode / sumAggCode). They are
+  // spliced into that Java source via string interpolation and turned back into
+  // instances with UserDefinedFunctionUtils.deserialize. Computing them here, instead of
+  // hard-coding the encoded strings, means they cannot drift from the current classes.
+  // NOTE: keep these declared BEFORE the code strings that interpolate them. Class-body
+  // vals are initialized in declaration order; a later declaration would be read as
+  // null and the literal text "null" would be spliced into the generated code.
+  protected val longMinWithRetractAggFunction: String =
+    UserDefinedFunctionUtils.serialize(new LongMinWithRetractAggFunction)
+
+  protected val longMaxWithRetractAggFunction: String =
+    UserDefinedFunctionUtils.serialize(new LongMaxWithRetractAggFunction)
+
+  protected val intSumWithRetractAggFunction: String =
+    UserDefinedFunctionUtils.serialize(new IntSumWithRetractAggFunction)
+
protected val MinMaxRowType = new RowTypeInfo(Array[TypeInformation[_]](
INT_TYPE_INFO,
LONG_TYPE_INFO,
@@ -66,7 +76,7 @@ class HarnessTestBase {
AggregateUtil.createAccumulatorRowType(sumAggregates)
val minMaxCode: String =
- """
+ s"""
|public class MinMaxAggregateHelper
| extends org.apache.flink.table.runtime.aggregate.GeneratedAggregations {
|
@@ -80,23 +90,11 @@ class HarnessTestBase {
|
| fmin = (org.apache.flink.table.functions.aggfunctions.LongMinWithRetractAggFunction)
| org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
- | .deserialize("rO0ABXNyAEtvcmcuYXBhY2hlLmZsaW5rLnRhYmxlLmZ1bmN0aW9ucy5hZ2dmdW5jdGlvbn" +
- | "MuTG9uZ01pbldpdGhSZXRyYWN0QWdnRnVuY3Rpb26oIdX_DaMPxQIAAHhyAEdvcmcuYXBhY2hlLmZsaW5rL" +
- | "nRhYmxlLmZ1bmN0aW9ucy5hZ2dmdW5jdGlvbnMuTWluV2l0aFJldHJhY3RBZ2dGdW5jdGlvbq_ZGuzxtA_S" +
- | "AgABTAADb3JkdAAVTHNjYWxhL21hdGgvT3JkZXJpbmc7eHIAMm9yZy5hcGFjaGUuZmxpbmsudGFibGUuZnV" +
- | "uY3Rpb25zLkFnZ3JlZ2F0ZUZ1bmN0aW9uTcYVPtJjNfwCAAB4cgA0b3JnLmFwYWNoZS5mbGluay50YWJsZS" +
- | "5mdW5jdGlvbnMuVXNlckRlZmluZWRGdW5jdGlvbi0B91QxuAyTAgAAeHBzcgAZc2NhbGEubWF0aC5PcmRlc" +
- | "mluZyRMb25nJOda0iCPo2ukAgAAeHA");
+ | .deserialize("${longMinWithRetractAggFunction}");
|
| fmax = (org.apache.flink.table.functions.aggfunctions.LongMaxWithRetractAggFunction)
| org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
- | .deserialize("rO0ABXNyAEtvcmcuYXBhY2hlLmZsaW5rLnRhYmxlLmZ1bmN0aW9ucy5hZ2dmdW5jdGlvbn" +
- | "MuTG9uZ01heFdpdGhSZXRyYWN0QWdnRnVuY3Rpb25RmsI8azNGXwIAAHhyAEdvcmcuYXBhY2hlLmZsaW5rL" +
- | "nRhYmxlLmZ1bmN0aW9ucy5hZ2dmdW5jdGlvbnMuTWF4V2l0aFJldHJhY3RBZ2dGdW5jdGlvbvnwowlX0_Qf" +
- | "AgABTAADb3JkdAAVTHNjYWxhL21hdGgvT3JkZXJpbmc7eHIAMm9yZy5hcGFjaGUuZmxpbmsudGFibGUuZnV" +
- | "uY3Rpb25zLkFnZ3JlZ2F0ZUZ1bmN0aW9uTcYVPtJjNfwCAAB4cgA0b3JnLmFwYWNoZS5mbGluay50YWJsZS" +
- | "5mdW5jdGlvbnMuVXNlckRlZmluZWRGdW5jdGlvbi0B91QxuAyTAgAAeHBzcgAZc2NhbGEubWF0aC5PcmRlc" +
- | "mluZyRMb25nJOda0iCPo2ukAgAAeHA");
+ | .deserialize("${longMaxWithRetractAggFunction}");
| }
|
| public void setAggregationResults(
@@ -192,7 +190,7 @@ class HarnessTestBase {
""".stripMargin
val sumAggCode: String =
- """
+ s"""
|public final class SumAggregationHelper
| extends org.apache.flink.table.runtime.aggregate.GeneratedAggregations {
|
@@ -209,17 +207,8 @@ class HarnessTestBase {
|
|sum = (org.apache.flink.table.functions.aggfunctions.IntSumWithRetractAggFunction)
|org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
- |.deserialize
- |("rO0ABXNyAEpvcmcuYXBhY2hlLmZsaW5rLnRhYmxlLmZ1bmN0aW9ucy5hZ2dmdW5jdGlvbnMuSW50U3VtV2l0a" +
- |"FJldHJhY3RBZ2dGdW5jdGlvblkfWkeNZDeDAgAAeHIAR29yZy5hcGFjaGUuZmxpbmsudGFibGUuZnVuY3Rpb25" +
- |"zLmFnZ2Z1bmN0aW9ucy5TdW1XaXRoUmV0cmFjdEFnZ0Z1bmN0aW9ut2oWrOsLrs0CAAFMAAdudW1lcmljdAAUT" +
- |"HNjYWxhL21hdGgvTnVtZXJpYzt4cgAyb3JnLmFwYWNoZS5mbGluay50YWJsZS5mdW5jdGlvbnMuQWdncmVnYXR" +
- |"lRnVuY3Rpb25NxhU-0mM1_AIAAHhyADRvcmcuYXBhY2hlLmZsaW5rLnRhYmxlLmZ1bmN0aW9ucy5Vc2VyRGVma" +
- |"W5lZEZ1bmN0aW9uLQH3VDG4DJMCAAB4cHNyACFzY2FsYS5tYXRoLk51bWVyaWMkSW50SXNJbnRlZ3JhbCTw6XA" +
- |"59sPAzAIAAHhw");
- |
- |
- | }
+ |.deserialize("${intSumWithRetractAggFunction}");
+ |}
|
| public final void setAggregationResults(
| org.apache.flink.types.Row accs,
http://git-wip-us.apache.org/repos/asf/flink/blob/629d3633/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala
index 65014cd..a1b28d3 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala
@@ -31,6 +31,7 @@ import org.apache.flink.streaming.api.datastream.{DataStream => JDataStream}
import org.apache.flink.streaming.api.environment.{StreamExecutionEnvironment => JStreamExecutionEnvironment}
import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
import org.apache.flink.table.api.scala.batch.utils.LogicalPlanFormatUtils
+import org.apache.flink.table.functions.AggregateFunction
import org.junit.Assert.assertEquals
import org.mockito.Mockito.{mock, when}
@@ -150,6 +151,12 @@ case class BatchTableTestUtil() extends TableTestUtil {
tEnv.registerFunction(name, function)
}
+  /**
+   * Registers an [[AggregateFunction]] under `name` in this util's table environment,
+   * so that it can be referenced by that name (e.g. from SQL queries).
+   *
+   * @param name     name under which the aggregate function is registered
+   * @param function the aggregate function instance to register
+   * @tparam T   result type of the aggregate function
+   * @tparam ACC accumulator type of the aggregate function
+   */
+  def addFunction[T: TypeInformation, ACC: TypeInformation](
+      name: String,
+      function: AggregateFunction[T, ACC]): Unit =
+    tEnv.registerFunction(name, function)
+
def verifySql(query: String, expected: String): Unit = {
verifyTable(tEnv.sql(query), expected)
}
@@ -210,6 +217,12 @@ case class StreamTableTestUtil() extends TableTestUtil {
tEnv.registerFunction(name, function)
}
+  /**
+   * Registers an [[AggregateFunction]] under `name` in this util's table environment,
+   * so that it can be referenced by that name (e.g. from SQL queries).
+   *
+   * @param name     name under which the aggregate function is registered
+   * @param function the aggregate function instance to register
+   * @tparam T   result type of the aggregate function
+   * @tparam ACC accumulator type of the aggregate function
+   */
+  def addFunction[T: TypeInformation, ACC: TypeInformation](
+      name: String,
+      function: AggregateFunction[T, ACC]): Unit =
+    tEnv.registerFunction(name, function)
+
def verifySql(query: String, expected: String): Unit = {
verifyTable(tEnv.sql(query), expected)
}