Posted to commits@flink.apache.org by fh...@apache.org on 2018/09/13 21:24:12 UTC
[flink] branch master updated: [FLINK-5315] [table] Add support for DISTINCT aggregation to Table API.
This is an automated email from the ASF dual-hosted git repository.
fhueske pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 04c7cdf [FLINK-5315] [table] Add support for DISTINCT aggregation to Table API.
04c7cdf is described below
commit 04c7cdf7d3d7b0862f7a4a3e7d821b1fb62ad3f0
Author: Rong Rong <wa...@hotmail.com>
AuthorDate: Tue Aug 7 14:29:46 2018 -0700
[FLINK-5315] [table] Add support for DISTINCT aggregation to Table API.
This closes #6521.
---
docs/dev/table/tableApi.md | 76 ++++++++++
.../flink/table/api/scala/expressionDsl.scala | 10 +-
.../flink/table/expressions/ExpressionParser.scala | 13 ++
.../flink/table/expressions/aggregations.scala | 164 +++++++++++++++++----
.../functions/DistinctAggregateFunction.scala | 43 ++++++
.../flink/table/plan/logical/operators.scala | 8 +
.../stringexpr/AggregateStringExpressionTest.scala | 66 +++++++++
.../table/api/stream/table/AggregateTest.scala | 52 +++++++
.../stringexpr/AggregateStringExpressionTest.scala | 70 ++++++++-
.../runtime/stream/table/AggregateITCase.scala | 80 +++++++++-
10 files changed, 551 insertions(+), 31 deletions(-)
diff --git a/docs/dev/table/tableApi.md b/docs/dev/table/tableApi.md
index 1a21a57..f50b0f5 100644
--- a/docs/dev/table/tableApi.md
+++ b/docs/dev/table/tableApi.md
@@ -372,6 +372,44 @@ Table result = orders
</tr>
<tr>
<td>
+ <strong>Distinct Aggregation</strong><br>
+ <span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span> <br>
+ <span class="label label-info">Result Updating</span>
+ </td>
+ <td>
+ <p>Similar to a SQL DISTINCT aggregation clause such as COUNT(DISTINCT a). Distinct aggregation declares that an aggregation function (built-in or user-defined) is only applied on distinct input values. Distinct can be applied to <b>GroupBy Aggregation</b>, <b>GroupBy Window Aggregation</b> and <b>Over Window Aggregation</b>.</p>
+{% highlight java %}
+Table orders = tableEnv.scan("Orders");
+// Distinct aggregation on group by
+Table groupByDistinctResult = orders
+ .groupBy("a")
+ .select("a, b.sum.distinct as d");
+// Distinct aggregation on time window group by
+Table groupByWindowDistinctResult = orders
+ .window(Tumble.over("5.minutes").on("rowtime").as("w")).groupBy("a, w")
+ .select("a, b.sum.distinct as d");
+// Distinct aggregation on over window
+Table result = orders
+ .window(Over
+ .partitionBy("a")
+ .orderBy("rowtime")
+ .preceding("UNBOUNDED_RANGE")
+ .as("w"))
+ .select("a, b.avg.distinct over w, b.max over w, b.min over w");
+{% endhighlight %}
+ <p>User-defined aggregate functions can also be used with the DISTINCT modifier. To compute the aggregate only over distinct values, add the distinct modifier to the aggregate function call.</p>
+{% highlight java %}
+Table orders = tEnv.scan("Orders");
+
+// Use distinct aggregation for user-defined aggregate functions
+tEnv.registerFunction("myUdagg", new MyUdagg());
+orders.groupBy("users").select("users, myUdagg.distinct(points) as myDistinctResult");
+{% endhighlight %}
+ <p><b>Note:</b> For streaming queries, the state required to compute the query result might grow infinitely, depending on the number of distinct input values. Please provide a query configuration with a valid retention interval to prevent excessive state size. See <a href="streaming.html">Streaming Concepts</a> for details.</p>
+ </td>
+ </tr>
+ <tr>
+ <td>
<strong>Distinct</strong><br>
<span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span> <br>
<span class="label label-info">Result Updating</span>
@@ -455,6 +493,44 @@ val result: Table = orders
</tr>
<tr>
<td>
+ <strong>Distinct Aggregation</strong><br>
+ <span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span> <br>
+ <span class="label label-info">Result Updating</span>
+ </td>
+ <td>
+ <p>Similar to a SQL DISTINCT aggregation clause such as COUNT(DISTINCT a). Distinct aggregation declares that an aggregation function (built-in or user-defined) is only applied on distinct input values. Distinct can be applied to <b>GroupBy Aggregation</b>, <b>GroupBy Window Aggregation</b> and <b>Over Window Aggregation</b>.</p>
+{% highlight scala %}
+val orders: Table = tableEnv.scan("Orders")
+// Distinct aggregation on group by
+val groupByDistinctResult = orders
+ .groupBy('a)
+ .select('a, 'b.sum.distinct as 'd)
+// Distinct aggregation on time window group by
+val groupByWindowDistinctResult = orders
+ .window(Tumble over 5.minutes on 'rowtime as 'w).groupBy('a, 'w)
+ .select('a, 'b.sum.distinct as 'd)
+// Distinct aggregation on over window
+val result = orders
+ .window(Over
+ partitionBy 'a
+ orderBy 'rowtime
+ preceding UNBOUNDED_RANGE
+ as 'w)
+ .select('a, 'b.avg.distinct over 'w, 'b.max over 'w, 'b.min over 'w)
+{% endhighlight %}
+ <p>User-defined aggregate functions can also be used with the DISTINCT modifier. To compute the aggregate only over distinct values, add the distinct modifier to the aggregate function call.</p>
+{% highlight scala %}
+val orders: Table = tEnv.scan("Orders")
+
+// Use distinct aggregation for user-defined aggregate functions
+val myUdagg = new MyUdagg()
+orders.groupBy('users).select('users, myUdagg.distinct('points) as 'myDistinctResult)
+{% endhighlight %}
+ <p><b>Note:</b> For streaming queries, the state required to compute the query result might grow infinitely, depending on the number of distinct input values. Please provide a query configuration with a valid retention interval to prevent excessive state size. See <a href="streaming.html">Streaming Concepts</a> for details.</p>
+ </td>
+ </tr>
+ <tr>
+ <td>
<strong>Distinct</strong><br>
<span class="label label-primary">Batch</span>
</td>
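The documentation hunks describe distinct aggregation as applying an aggregate only to distinct input values. As a rough illustration in plain Scala collections (not the Flink runtime), the sketch below compares SUM(b) with SUM(DISTINCT b) and COUNT(DISTINCT b) per group over made-up rows. It also adds a hypothetical incremental variant that suggests why a streaming query has to keep every distinct value per key in state. The names DistinctSemantics and IncrementalDistinctSum, and the data, are assumptions for illustration only.

```scala
import scala.collection.mutable

// Plain-collections model of distinct vs. non-distinct aggregation per group.
// Not the Flink runtime; names and data are hypothetical.
object DistinctSemantics {
  // Made-up (a, b) rows standing in for an "Orders" table.
  val orders: Seq[(String, Int)] = Seq(("x", 1), ("x", 1), ("x", 2), ("y", 5), ("y", 5))

  // GROUP BY a, SUM(b): every input value contributes, duplicates included.
  def sumPerGroup(rows: Seq[(String, Int)]): Map[String, Int] =
    rows.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).sum }

  // GROUP BY a, SUM(DISTINCT b): duplicate values are dropped within each group first.
  def distinctSumPerGroup(rows: Seq[(String, Int)]): Map[String, Int] =
    rows.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).distinct.sum }

  // GROUP BY a, COUNT(DISTINCT b).
  def distinctCountPerGroup(rows: Seq[(String, Int)]): Map[String, Int] =
    rows.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).distinct.size }

  // Streaming-style model: to emit an updated SUM(DISTINCT b) after each row,
  // the operator must remember every distinct value seen per key, so this
  // state grows with the number of distinct values.
  final class IncrementalDistinctSum {
    private val seen = mutable.Map.empty[String, Set[Int]]
    private val sums = mutable.Map.empty[String, Int]

    // Consumes one row and returns the updated result for its key.
    def add(key: String, value: Int): Int = {
      val values = seen.getOrElse(key, Set.empty[Int])
      if (!values.contains(value)) {
        seen(key) = values + value
        sums(key) = sums.getOrElse(key, 0) + value
      }
      sums(key)
    }

    // Total number of remembered (key, value) pairs.
    def stateSize: Int = seen.valuesIterator.map(_.size).sum
  }
}
```

For group "x" with values 1, 1, 2, the plain sum is 4 while the distinct sum is 3; the incremental model ends up holding one state entry per distinct (key, value) pair.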
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
index cf8cd91..126cc5f 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
@@ -22,12 +22,12 @@ import java.sql.{Date, Time, Timestamp}
import org.apache.calcite.avatica.util.DateTimeUtils._
import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation}
-import org.apache.flink.table.api.{TableException, CurrentRow, CurrentRange, UnboundedRow, UnboundedRange}
+import org.apache.flink.table.api.{CurrentRange, CurrentRow, TableException, UnboundedRange, UnboundedRow}
import org.apache.flink.table.expressions.ExpressionUtils.{convertArray, toMilliInterval, toMonthInterval, toRowInterval}
import org.apache.flink.table.api.Table
import org.apache.flink.table.expressions.TimeIntervalUnit.TimeIntervalUnit
import org.apache.flink.table.expressions._
-import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.{AggregateFunction, DistinctAggregateFunction}
import scala.language.implicitConversions
@@ -214,7 +214,7 @@ trait ImplicitExpressionOperations {
def varSamp = VarSamp(expr)
/**
- * Returns multiset aggregate of a given expression.
+ * Returns multiset aggregate of a given expression.
*/
def collect = Collect(expr)
@@ -998,6 +998,10 @@ trait ImplicitExpressionConversions {
implicit def array2ArrayConstructor(array: Array[_]): Expression = convertArray(array)
implicit def userDefinedAggFunctionConstructor[T: TypeInformation, ACC: TypeInformation]
(udagg: AggregateFunction[T, ACC]): UDAGGExpression[T, ACC] = UDAGGExpression(udagg)
+ implicit def toDistinct(agg: Aggregation): DistinctAgg = DistinctAgg(agg)
+ implicit def toDistinct[T: TypeInformation, ACC: TypeInformation]
+ (agg: AggregateFunction[T, ACC]): DistinctAggregateFunction[T, ACC] =
+ DistinctAggregateFunction(agg)
}
// ------------------------------------------------------------------------------------------------
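The two implicit conversions added to ImplicitExpressionConversions are what let `'b.sum.distinct` and `myUdagg.distinct('points)` compile. An Aggregation, which has no `distinct` member, is wrapped into DistinctAgg. An AggregateFunction is wrapped into DistinctAggregateFunction, whose `distinct(params*)` builds `DistinctAgg(AggFunctionCall(...))`. The following is a self-contained toy model of that pattern. Every type in it (Expr, Field, MyUdagg, DistinctUdagg, FieldOps, and so on) is a hypothetical stand-in, not a Flink class, and the conversions are given distinct names here rather than overloading one implicit.

```scala
import scala.language.implicitConversions

// Toy model of the implicit-conversion pattern; none of these are Flink types.
object DistinctDslModel {
  sealed trait Expr
  final case class Field(name: String) extends Expr
  sealed trait Aggregation extends Expr
  final case class Sum(child: Expr) extends Aggregation
  final case class Count(child: Expr) extends Aggregation
  final case class UdaggCall(fn: String, args: Seq[Expr]) extends Aggregation

  // Like the diff's DistinctAgg: wraps an aggregation and exposes `distinct`.
  final case class DistinctAgg(child: Expr) extends Aggregation {
    def distinct: Expr = DistinctAgg(child)
  }

  // Hypothetical stand-in for a user-defined AggregateFunction instance.
  final case class MyUdagg(name: String) {
    def apply(args: Expr*): Expr = UdaggCall(name, args)
  }

  // Like DistinctAggregateFunction: adds `distinct(params*)` to a UDAGG.
  final case class DistinctUdagg(f: MyUdagg) {
    def distinct(args: Expr*): Expr = DistinctAgg(UdaggCall(f.name, args))
  }

  // Minimal DSL so that Field("a").sum builds an aggregation.
  implicit class FieldOps(f: Field) {
    def sum: Aggregation = Sum(f)
    def count: Aggregation = Count(f)
  }

  // The two conversions that make `.distinct` available.
  implicit def aggToDistinct(agg: Aggregation): DistinctAgg = DistinctAgg(agg)
  implicit def udaggToDistinct(f: MyUdagg): DistinctUdagg = DistinctUdagg(f)
}
```

With `import DistinctDslModel._` in scope, `Field("a").sum.distinct` evaluates to `DistinctAgg(Sum(Field("a")))`, which mirrors how `'a.sum.distinct` is expected to resolve in the Scala Table API.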
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala
index 633afb1..4909d2c 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala
@@ -83,6 +83,7 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
lazy val GET: Keyword = Keyword("get")
lazy val FLATTEN: Keyword = Keyword("flatten")
lazy val OVER: Keyword = Keyword("over")
+ lazy val DISTINCT: Keyword = Keyword("distinct")
lazy val CURRENT_ROW: Keyword = Keyword("current_row")
lazy val CURRENT_RANGE: Keyword = Keyword("current_range")
lazy val UNBOUNDED_ROW: Keyword = Keyword("unbounded_row")
@@ -324,6 +325,9 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
lazy val suffixFlattening: PackratParser[Expression] =
composite <~ "." ~ FLATTEN ~ opt("()") ^^ { e => Flattening(e) }
+ lazy val suffixDistinct: PackratParser[Expression] =
+ composite <~ "." ~ DISTINCT ~ opt("()") ^^ { e => DistinctAgg(e) }
+
lazy val suffixAs: PackratParser[Expression] =
composite ~ "." ~ AS ~ "(" ~ rep1sep(fieldReference, ",") ~ ")" ^^ {
case e ~ _ ~ _ ~ _ ~ target ~ _ => Alias(e, target.head.name, target.tail.map(_.name))
@@ -345,6 +349,8 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
suffixGet |
// expression with special identifier
suffixIf |
+ // expression with distinct suffix modifier
+ suffixDistinct |
// function call must always be at the end
suffixFunctionCall | suffixFunctionCallOneArg
@@ -412,6 +418,11 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
lazy val prefixToTime: PackratParser[Expression] =
TO_TIME ~ "(" ~> expression <~ ")" ^^ { e => Cast(e, SqlTimeTypeInfo.TIME) }
+ lazy val prefixDistinct: PackratParser[Expression] =
+ functionIdent ~ "." ~ DISTINCT ~ "(" ~ repsep(expression, ",") ~ ")" ^^ {
+ case name ~ _ ~ _ ~ _ ~ args ~ _ => DistinctAgg(Call(name.toUpperCase, args))
+ }
+
lazy val prefixAs: PackratParser[Expression] =
AS ~ "(" ~ expression ~ "," ~ rep1sep(fieldReference, ",") ~ ")" ^^ {
case _ ~ _ ~ e ~ _ ~ target ~ _ => Alias(e, target.head.name, target.tail.map(_.name))
@@ -428,6 +439,8 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
prefixGet |
// expression with special identifier
prefixIf |
+ // expression with prefix distinct
+ prefixDistinct |
// function call must always be at the end
prefixFunctionCall | prefixFunctionCallOneArg
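The parser changes accept two spellings in string expressions: the suffix form `a.sum.distinct` (with an optional `()`) and the prefix form `sum.distinct(a)`. The prefix rule upper-cases the function name before building `DistinctAgg(Call(...))`. The regex-based sketch below is only a stand-in for those PackratParsers rules and does not use scala-parser-combinators. It normalizes both spellings to one hypothetical `FN(DISTINCT args)` string so that the equivalence the tests check with verifyTableEquals is easy to see. DistinctSyntaxModel and `normalize` are assumptions, not Flink API.

```scala
// Regex stand-in for the suffixDistinct / prefixDistinct parser rules.
// Handles only simple identifiers and flat argument lists.
object DistinctSyntaxModel {
  // "field.fn.distinct" or "field.fn.distinct()"
  private val Suffix = """(\w+)\.(\w+)\.distinct(?:\(\))?""".r
  // "fn.distinct(arg1, arg2, ...)"
  private val Prefix = """(\w+)\.distinct\(([^()]*)\)""".r

  // Returns a normalized "FN(DISTINCT args)" string, or None if neither form matches.
  def normalize(expr: String): Option[String] = expr.trim match {
    case Suffix(field, fn) =>
      Some(s"${fn.toUpperCase}(DISTINCT $field)")
    case Prefix(fn, args) =>
      val argList = args.split(",").map(_.trim).filter(_.nonEmpty).mkString(", ")
      Some(s"${fn.toUpperCase}(DISTINCT $argList)")
    case _ =>
      None
  }
}
```

Both `normalize("a.sum.distinct")` and `normalize("sum.distinct(a)")` yield the same string, echoing the t1/t2/t3 comparisons in the string-expression tests.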
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
index b39bd98..c77bd7a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
@@ -19,7 +19,6 @@ package org.apache.flink.table.expressions
import org.apache.calcite.rex.RexNode
import org.apache.calcite.sql.SqlAggFunction
-import org.apache.calcite.sql.SqlKind._
import org.apache.calcite.sql.fun._
import org.apache.calcite.tools.RelBuilder
import org.apache.calcite.tools.RelBuilder.AggCall
@@ -43,7 +42,10 @@ abstract sealed class Aggregation extends Expression {
/**
* Convert Aggregate to its counterpart in Calcite, i.e. AggCall
*/
- private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall
+ private[flink] def toAggCall(
+ name: String,
+ isDistinct: Boolean = false
+ )(implicit relBuilder: RelBuilder): AggCall
/**
* Returns the SqlAggFunction for this Aggregation.
@@ -52,12 +54,49 @@ abstract sealed class Aggregation extends Expression {
}
+case class DistinctAgg(child: Expression) extends Aggregation {
+
+ private[flink] def distinct: Expression = DistinctAgg(child)
+
+ override private[flink] def resultType: TypeInformation[_] = child.resultType
+
+ override private[flink] def validateInput(): ValidationResult = {
+ super.validateInput()
+ child match {
+ case agg: Aggregation =>
+ child.validateInput()
+ case _ =>
+ ValidationFailure(s"Distinct modifier cannot be applied to $child! " +
+ s"It can only be applied to an aggregation expression, for example, " +
+ s"'a.count.distinct which is equivalent with COUNT(DISTINCT a).")
+ }
+ }
+
+ override private[flink] def toAggCall(
+ name: String, isDistinct: Boolean = true)(implicit relBuilder: RelBuilder) = {
+ child.asInstanceOf[Aggregation].toAggCall(name, isDistinct = true)
+ }
+
+ override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
+ child.asInstanceOf[Aggregation].getSqlAggFunction()
+ }
+
+ override private[flink] def children = Seq(child)
+}
+
case class Sum(child: Expression) extends Aggregation {
override private[flink] def children: Seq[Expression] = Seq(child)
override def toString = s"sum($child)"
- override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
- relBuilder.aggregateCall(SqlStdOperatorTable.SUM, false, false, null, name, child.toRexNode)
+ override private[flink] def toAggCall(
+ name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+ relBuilder.aggregateCall(
+ SqlStdOperatorTable.SUM,
+ isDistinct,
+ false,
+ null,
+ name,
+ child.toRexNode)
}
override private[flink] def resultType = child.resultType
@@ -77,8 +116,15 @@ case class Sum0(child: Expression) extends Aggregation {
override private[flink] def children: Seq[Expression] = Seq(child)
override def toString = s"sum0($child)"
- override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
- relBuilder.aggregateCall(SqlStdOperatorTable.SUM0, false, false, null, name, child.toRexNode)
+ override private[flink] def toAggCall(
+ name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+ relBuilder.aggregateCall(
+ SqlStdOperatorTable.SUM0,
+ isDistinct,
+ false,
+ null,
+ name,
+ child.toRexNode)
}
override private[flink] def resultType = child.resultType
@@ -94,8 +140,15 @@ case class Min(child: Expression) extends Aggregation {
override private[flink] def children: Seq[Expression] = Seq(child)
override def toString = s"min($child)"
- override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
- relBuilder.aggregateCall(SqlStdOperatorTable.MIN, false, false, null, name, child.toRexNode)
+ override private[flink] def toAggCall(
+ name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+ relBuilder.aggregateCall(
+ SqlStdOperatorTable.MIN,
+ isDistinct,
+ false,
+ null,
+ name,
+ child.toRexNode)
}
override private[flink] def resultType = child.resultType
@@ -112,8 +165,15 @@ case class Max(child: Expression) extends Aggregation {
override private[flink] def children: Seq[Expression] = Seq(child)
override def toString = s"max($child)"
- override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
- relBuilder.aggregateCall(SqlStdOperatorTable.MAX, false, false, null, name, child.toRexNode)
+ override private[flink] def toAggCall(
+ name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+ relBuilder.aggregateCall(
+ SqlStdOperatorTable.MAX,
+ isDistinct,
+ false,
+ null,
+ name,
+ child.toRexNode)
}
override private[flink] def resultType = child.resultType
@@ -130,8 +190,15 @@ case class Count(child: Expression) extends Aggregation {
override private[flink] def children: Seq[Expression] = Seq(child)
override def toString = s"count($child)"
- override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
- relBuilder.aggregateCall(SqlStdOperatorTable.COUNT, false, false, null, name, child.toRexNode)
+ override private[flink] def toAggCall(
+ name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+ relBuilder.aggregateCall(
+ SqlStdOperatorTable.COUNT,
+ isDistinct,
+ false,
+ null,
+ name,
+ child.toRexNode)
}
override private[flink] def resultType = BasicTypeInfo.LONG_TYPE_INFO
@@ -145,8 +212,15 @@ case class Avg(child: Expression) extends Aggregation {
override private[flink] def children: Seq[Expression] = Seq(child)
override def toString = s"avg($child)"
- override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
- relBuilder.aggregateCall(SqlStdOperatorTable.AVG, false, false, null, name, child.toRexNode)
+ override private[flink] def toAggCall(
+ name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+ relBuilder.aggregateCall(
+ SqlStdOperatorTable.AVG,
+ isDistinct,
+ false,
+ null,
+ name,
+ child.toRexNode)
}
override private[flink] def resultType = child.resultType
@@ -171,8 +245,15 @@ case class Collect(child: Expression) extends Aggregation {
override def toString: String = s"collect($child)"
- override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
- relBuilder.aggregateCall(SqlStdOperatorTable.COLLECT, false, false, null, name, child.toRexNode)
+ override private[flink] def toAggCall(
+ name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+ relBuilder.aggregateCall(
+ SqlStdOperatorTable.COLLECT,
+ isDistinct,
+ false,
+ null,
+ name,
+ child.toRexNode)
}
override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
@@ -184,9 +265,15 @@ case class StddevPop(child: Expression) extends Aggregation {
override private[flink] def children: Seq[Expression] = Seq(child)
override def toString = s"stddev_pop($child)"
- override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
+ override private[flink] def toAggCall(
+ name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
relBuilder.aggregateCall(
- SqlStdOperatorTable.STDDEV_POP, false, false, null, name, child.toRexNode)
+ SqlStdOperatorTable.STDDEV_POP,
+ isDistinct,
+ false,
+ null,
+ name,
+ child.toRexNode)
}
override private[flink] def resultType = child.resultType
@@ -202,9 +289,15 @@ case class StddevSamp(child: Expression) extends Aggregation {
override private[flink] def children: Seq[Expression] = Seq(child)
override def toString = s"stddev_samp($child)"
- override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
+ override private[flink] def toAggCall(
+ name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
relBuilder.aggregateCall(
- SqlStdOperatorTable.STDDEV_SAMP, false, false, null, name, child.toRexNode)
+ SqlStdOperatorTable.STDDEV_SAMP,
+ isDistinct,
+ false,
+ null,
+ name,
+ child.toRexNode)
}
override private[flink] def resultType = child.resultType
@@ -220,8 +313,15 @@ case class VarPop(child: Expression) extends Aggregation {
override private[flink] def children: Seq[Expression] = Seq(child)
override def toString = s"var_pop($child)"
- override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
- relBuilder.aggregateCall(SqlStdOperatorTable.VAR_POP, false, false, null, name, child.toRexNode)
+ override private[flink] def toAggCall(
+ name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+ relBuilder.aggregateCall(
+ SqlStdOperatorTable.VAR_POP,
+ isDistinct,
+ false,
+ null,
+ name,
+ child.toRexNode)
}
override private[flink] def resultType = child.resultType
@@ -237,9 +337,15 @@ case class VarSamp(child: Expression) extends Aggregation {
override private[flink] def children: Seq[Expression] = Seq(child)
override def toString = s"var_samp($child)"
- override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
+ override private[flink] def toAggCall(
+ name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
relBuilder.aggregateCall(
- SqlStdOperatorTable.VAR_SAMP, false, false, null, name, child.toRexNode)
+ SqlStdOperatorTable.VAR_SAMP,
+ isDistinct,
+ false,
+ null,
+ name,
+ child.toRexNode)
}
override private[flink] def resultType = child.resultType
@@ -281,9 +387,15 @@ case class AggFunctionCall(
override def toString: String = s"${aggregateFunction.getClass.getSimpleName}($args)"
- override def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
+ override def toAggCall(
+ name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
relBuilder.aggregateCall(
- this.getSqlAggFunction(), false, false, null, name, args.map(_.toRexNode): _*)
+ this.getSqlAggFunction(),
+ isDistinct,
+ false,
+ null,
+ name,
+ args.map(_.toRexNode): _*)
}
override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
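In aggregations.scala every built-in aggregation now forwards an `isDistinct` flag to `relBuilder.aggregateCall`, and DistinctAgg calls its child's `toAggCall` with `isDistinct = true`. The sketch below models that delegation without Calcite. Instead of an AggCall it renders a string in the same shape as the plan terms expected by the stream AggregateTest, such as `SUM(DISTINCT a)`. AggCallModel, `render`, and UdaggCall are hypothetical, and the flag is a required parameter here rather than a defaulted one.

```scala
// Toy model of isDistinct forwarding; renders strings instead of Calcite AggCalls.
object AggCallModel {
  sealed trait Expr
  final case class Field(name: String) extends Expr

  sealed trait Aggregation extends Expr {
    // Counterpart of toAggCall(name, isDistinct): here it only renders the call.
    def toAggCall(isDistinct: Boolean): String
  }

  // Renders e.g. "SUM(a)" or "SUM(DISTINCT a)".
  private def render(fn: String, isDistinct: Boolean, args: Seq[Expr]): String = {
    val argList = args.map {
      case Field(n) => n
      case other    => other.toString
    }.mkString(", ")
    if (isDistinct) s"$fn(DISTINCT $argList)" else s"$fn($argList)"
  }

  final case class Sum(child: Expr) extends Aggregation {
    def toAggCall(isDistinct: Boolean): String = render("SUM", isDistinct, Seq(child))
  }

  final case class Count(child: Expr) extends Aggregation {
    def toAggCall(isDistinct: Boolean): String = render("COUNT", isDistinct, Seq(child))
  }

  // Hypothetical user-defined aggregate call with several arguments.
  final case class UdaggCall(fn: String, args: Seq[Expr]) extends Aggregation {
    def toAggCall(isDistinct: Boolean): String = render(fn, isDistinct, args)
  }

  // Ignores the caller's flag and asks the wrapped aggregation for DISTINCT.
  final case class DistinctAgg(child: Expr) extends Aggregation {
    def toAggCall(isDistinct: Boolean): String = child match {
      case agg: Aggregation => agg.toAggCall(isDistinct = true)
      case other => throw new IllegalArgumentException(s"distinct on non-aggregation: $other")
    }
  }
}
```

Keeping the flag on each leaf aggregation, rather than adding separate distinct variants of every class, means a single wrapper type is enough to mark any aggregation as distinct.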
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/DistinctAggregateFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/DistinctAggregateFunction.scala
new file mode 100644
index 0000000..c75e6fa
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/DistinctAggregateFunction.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.functions
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.table.expressions.{AggFunctionCall, DistinctAgg, Expression}
+import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getAccumulatorTypeOfAggregateFunction, getResultTypeOfAggregateFunction}
+
+/**
+ * Defines an implicit conversion method (distinct) that converts [[AggregateFunction]]s into
+ * [[DistinctAgg]] Expressions.
+ */
+private[flink] case class DistinctAggregateFunction[T: TypeInformation, ACC: TypeInformation]
+ (aggFunction: AggregateFunction[T, ACC]) {
+
+ private[flink] def distinct(params: Expression*): Expression = {
+ val resultTypeInfo: TypeInformation[_] = getResultTypeOfAggregateFunction(
+ aggFunction,
+ implicitly[TypeInformation[T]])
+
+ val accTypeInfo: TypeInformation[_] = getAccumulatorTypeOfAggregateFunction(
+ aggFunction,
+ implicitly[TypeInformation[ACC]])
+
+ DistinctAgg(
+ AggFunctionCall(aggFunction, resultTypeInfo, accTypeInfo, params))
+ }
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
index a2bd1e4..7579621 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
@@ -235,6 +235,14 @@ case class Aggregate(
groupingExprs.foreach(validateGroupingExpression)
def validateAggregateExpression(expr: Expression): Unit = expr match {
+ case distinctExpr: DistinctAgg =>
+ distinctExpr.child match {
+ case _: DistinctAgg => failValidation(
+ "Chained distinct operators are not supported!")
+ case aggExpr: Aggregation => validateAggregateExpression(aggExpr)
+ case _ => failValidation(
+ "Distinct operator can only be applied to aggregation expressions!")
+ }
// check aggregate function
case aggExpr: Aggregation
if aggExpr.getSqlAggFunction.requiresOver =>
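The new cases in Aggregate.validateAggregateExpression reject a distinct modifier wrapped around another distinct modifier, and a distinct modifier around anything that is not an aggregation; a valid DistinctAgg is validated through its inner aggregation. Below is a minimal model of just those three rules. It returns Either instead of calling failValidation, the types are hypothetical stand-ins, and the remaining checks of the real method are not modeled.

```scala
// Toy model of the distinct checks added to validateAggregateExpression.
object DistinctValidationModel {
  sealed trait Expr
  final case class Field(name: String) extends Expr
  sealed trait Aggregation extends Expr
  final case class Sum(child: Expr) extends Aggregation
  final case class DistinctAgg(child: Expr) extends Aggregation

  // Left(message) where the real code would call failValidation.
  def validate(expr: Expr): Either[String, Unit] = expr match {
    // Checked first, because a DistinctAgg is itself an Aggregation.
    case DistinctAgg(_: DistinctAgg) =>
      Left("Chained distinct operators are not supported!")
    case DistinctAgg(agg: Aggregation) =>
      validate(agg)
    case DistinctAgg(_) =>
      Left("Distinct operator can only be applied to aggregation expressions!")
    // All other checks of the real method are out of scope for this sketch.
    case _ =>
      Right(())
  }
}
```

The ordering of the cases matters: if the `Aggregation` case came first, a chained `DistinctAgg(DistinctAgg(...))` would be accepted and silently recursed into.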
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/AggregateStringExpressionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/AggregateStringExpressionTest.scala
index 4bbb101..4e7270f 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/AggregateStringExpressionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/AggregateStringExpressionTest.scala
@@ -28,6 +28,19 @@ import org.junit._
class AggregateStringExpressionTest extends TableTestBase {
@Test
+ def testDistinctAggregationTypes(): Unit = {
+ val util = batchTestUtil()
+ val t = util.addTable[(Int, Long, String)]("Table3")
+
+ val t1 = t.select('_1.sum.distinct, '_1.count.distinct, '_1.avg.distinct)
+ val t2 = t.select("_1.sum.distinct, _1.count.distinct, _1.avg.distinct")
+ val t3 = t.select("sum.distinct(_1), count.distinct(_1), avg.distinct(_1)")
+
+ verifyTableEquals(t1, t2)
+ verifyTableEquals(t1, t3)
+ }
+
+ @Test
def testAggregationTypes(): Unit = {
val util = batchTestUtil()
val t = util.addTable[(Int, Long, String)]("Table3")
@@ -119,6 +132,19 @@ class AggregateStringExpressionTest extends TableTestBase {
}
@Test
+ def testDistinctGroupedAggregate(): Unit = {
+ val util = batchTestUtil()
+ val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+
+ val t1 = t.groupBy('b).select('b, 'a.sum.distinct, 'a.sum)
+ val t2 = t.groupBy("b").select("b, a.sum.distinct, a.sum")
+ val t3 = t.groupBy("b").select("b, sum.distinct(a), sum(a)")
+
+ verifyTableEquals(t1, t2)
+ verifyTableEquals(t1, t3)
+ }
+
+ @Test
def testGroupedAggregate(): Unit = {
val util = batchTestUtil()
val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
@@ -239,6 +265,22 @@ class AggregateStringExpressionTest extends TableTestBase {
}
@Test
+ def testDistinctAggregateWithUDAGG(): Unit = {
+ val util = batchTestUtil()
+ val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+
+ val myCnt = new CountAggFunction
+ util.tableEnv.registerFunction("myCnt", myCnt)
+ val myWeightedAvg = new WeightedAvgWithMergeAndReset
+ util.tableEnv.registerFunction("myWeightedAvg", myWeightedAvg)
+
+ val t1 = t.select(myCnt.distinct('a) as 'aCnt, myWeightedAvg.distinct('b, 'a) as 'wAvg)
+ val t2 = t.select("myCnt.distinct(a) as aCnt, myWeightedAvg.distinct(b, a) as wAvg")
+
+ verifyTableEquals(t1, t2)
+ }
+
+ @Test
def testAggregateWithUDAGG(): Unit = {
val util = batchTestUtil()
val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
@@ -255,6 +297,30 @@ class AggregateStringExpressionTest extends TableTestBase {
}
@Test
+ def testDistinctGroupedAggregateWithUDAGG(): Unit = {
+ val util = batchTestUtil()
+ val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+
+
+ val myCnt = new CountAggFunction
+ util.tableEnv.registerFunction("myCnt", myCnt)
+ val myWeightedAvg = new WeightedAvgWithMergeAndReset
+ util.tableEnv.registerFunction("myWeightedAvg", myWeightedAvg)
+
+ val t1 = t.groupBy('b)
+ .select('b,
+ myCnt.distinct('a) + 9 as 'aCnt,
+ myWeightedAvg.distinct('b, 'a) * 2 as 'wAvg,
+ myWeightedAvg.distinct('a, 'a) as 'distAgg,
+ myWeightedAvg('a, 'a) as 'agg)
+ val t2 = t.groupBy("b")
+ .select("b, myCnt.distinct(a) + 9 as aCnt, myWeightedAvg.distinct(b, a) * 2 as wAvg, " +
+ "myWeightedAvg.distinct(a, a) as distAgg, myWeightedAvg(a, a) as agg")
+
+ verifyTableEquals(t1, t2)
+ }
+
+ @Test
def testGroupedAggregateWithUDAGG(): Unit = {
val util = batchTestUtil()
val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala
index 533235a..671f8dd 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala
@@ -21,6 +21,7 @@ package org.apache.flink.table.api.stream.table
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.scala._
import org.apache.flink.table.api.scala._
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.WeightedAvg
import org.apache.flink.table.utils.TableTestUtil._
import org.apache.flink.table.utils.TableTestBase
import org.junit.Test
@@ -28,6 +29,57 @@ import org.junit.Test
class AggregateTest extends TableTestBase {
@Test
+ def testGroupDistinctAggregate(): Unit = {
+ val util = streamTestUtil()
+ val table = util.addTable[(Long, Int, String)]('a, 'b, 'c)
+
+ val resultTable = table
+ .groupBy('b)
+ .select('a.sum.distinct, 'c.count.distinct)
+
+ val expected =
+ unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamGroupAggregate",
+ streamTableNode(0),
+ term("groupBy", "b"),
+ term("select", "b", "SUM(DISTINCT a) AS TMP_0", "COUNT(DISTINCT c) AS TMP_1")
+ ),
+ term("select", "TMP_0", "TMP_1")
+ )
+ util.verifyTable(resultTable, expected)
+ }
+
+ @Test
+ def testGroupDistinctAggregateWithUDAGG(): Unit = {
+ val util = streamTestUtil()
+ val table = util.addTable[(Long, Int, String)]('a, 'b, 'c)
+ val weightedAvg = new WeightedAvg
+
+ val resultTable = table
+ .groupBy('c)
+ .select(weightedAvg.distinct('a, 'b), weightedAvg('a, 'b))
+
+ val expected =
+ unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamGroupAggregate",
+ streamTableNode(0),
+ term("groupBy", "c"),
+ term(
+ "select",
+ "c",
+ "WeightedAvg(DISTINCT a, b) AS TMP_0",
+ "WeightedAvg(a, b) AS TMP_1")
+ ),
+ term("select", "TMP_0", "TMP_1")
+ )
+ util.verifyTable(resultTable, expected)
+ }
+
+ @Test
def testGroupAggregate() = {
val util = streamTestUtil()
val table = util.addTable[(Long, Int, String)]('a, 'b, 'c)
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala
index 2bef95e..ec57436 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala
@@ -20,12 +20,80 @@ package org.apache.flink.table.api.stream.table.stringexpr
import org.apache.flink.api.scala._
import org.apache.flink.table.api.scala._
-import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.WeightedAvg
+import org.apache.flink.table.functions.aggfunctions.CountAggFunction
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{WeightedAvg, WeightedAvgWithMergeAndReset}
import org.apache.flink.table.utils.TableTestBase
import org.junit.Test
class AggregateStringExpressionTest extends TableTestBase {
+
+ @Test
+ def testDistinctNonGroupedAggregate(): Unit = {
+ val util = streamTestUtil()
+ val t = util.addTable[(Int, Long, String)]("Table3")
+
+ val t1 = t.select('_1.sum.distinct, '_1.count.distinct, '_1.avg.distinct)
+ val t2 = t.select("_1.sum.distinct, _1.count.distinct, _1.avg.distinct")
+ val t3 = t.select("sum.distinct(_1), count.distinct(_1), avg.distinct(_1)")
+
+ verifyTableEquals(t1, t2)
+ verifyTableEquals(t1, t3)
+ }
+
+ @Test
+ def testDistinctGroupedAggregate(): Unit = {
+ val util = streamTestUtil()
+ val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+
+ val t1 = t.groupBy('b).select('b, 'a.sum.distinct, 'a.sum)
+ val t2 = t.groupBy("b").select("b, a.sum.distinct, a.sum")
+ val t3 = t.groupBy("b").select("b, sum.distinct(a), sum(a)")
+
+ verifyTableEquals(t1, t2)
+ verifyTableEquals(t1, t3)
+ }
+
+ @Test
+ def testDistinctNonGroupAggregateWithUDAGG(): Unit = {
+ val util = streamTestUtil()
+ val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+
+ val myCnt = new CountAggFunction
+ util.tableEnv.registerFunction("myCnt", myCnt)
+ val myWeightedAvg = new WeightedAvgWithMergeAndReset
+ util.tableEnv.registerFunction("myWeightedAvg", myWeightedAvg)
+
+ val t1 = t.select(myCnt.distinct('a) as 'aCnt, myWeightedAvg.distinct('b, 'a) as 'wAvg)
+ val t2 = t.select("myCnt.distinct(a) as aCnt, myWeightedAvg.distinct(b, a) as wAvg")
+
+ verifyTableEquals(t1, t2)
+ }
+
+ @Test
+ def testDistinctGroupedAggregateWithUDAGG(): Unit = {
+ val util = streamTestUtil()
+ val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+
+
+ val myCnt = new CountAggFunction
+ util.tableEnv.registerFunction("myCnt", myCnt)
+ val myWeightedAvg = new WeightedAvgWithMergeAndReset
+ util.tableEnv.registerFunction("myWeightedAvg", myWeightedAvg)
+
+ val t1 = t.groupBy('b)
+ .select('b,
+ myCnt.distinct('a) + 9 as 'aCnt,
+ myWeightedAvg.distinct('b, 'a) * 2 as 'wAvg,
+ myWeightedAvg.distinct('a, 'a) as 'distAgg,
+ myWeightedAvg('a, 'a) as 'agg)
+ val t2 = t.groupBy("b")
+ .select("b, myCnt.distinct(a) + 9 as aCnt, myWeightedAvg.distinct(b, a) * 2 as wAvg, " +
+ "myWeightedAvg.distinct(a, a) as distAgg, myWeightedAvg(a, a) as agg")
+
+ verifyTableEquals(t1, t2)
+ }
+
@Test
def testGroupedAggregate(): Unit = {
val util = streamTestUtil()
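The `AggregateITCase` tests in the next file assert per-group results computed over the `(a, e)` columns of `StreamTestData.get5TupleDataStream`. The sketch below recomputes the expectations of `testDistinctAggregateMixedWithNonDistinct` (`'a.count.distinct` next to `'b.count`, grouped by `'e`) with ordinary collections. The `(a, e)` values listed are our reading of that 15-row test data set. They agree with the asserted results but are not copied from the Flink sources. `CountDistinctSketch` and `countDistinctByGroup` are hypothetical names.

```scala
object CountDistinctSketch {
  // (a, e) pairs of the 15-row test data set, as we read
  // StreamTestData.get5TupleDataStream (assumption); columns b, c, d are
  // omitted because b is never null, so COUNT(b) equals the row count.
  val rows: Seq[(Int, Long)] = Seq(
    (1, 1L), (2, 2L), (2, 1L), (3, 2L), (3, 2L),
    (3, 3L), (4, 2L), (4, 1L), (4, 1L), (4, 2L),
    (5, 1L), (5, 3L), (5, 3L), (5, 2L), (5, 2L))

  // For each group key e (sorted), emit "e,COUNT(DISTINCT a),COUNT(b)".
  def countDistinctByGroup(data: Seq[(Int, Long)]): Seq[String] =
    data.groupBy(_._2).toSeq.sortBy(_._1).map { case (e, group) =>
      val aValues = group.map(_._1)
      s"$e,${aValues.distinct.size},${aValues.size}"
    }

  def main(args: Array[String]): Unit =
    countDistinctByGroup(rows).foreach(println)
}
```

Running `main` prints `1,4,5`, `2,4,7` and `3,2,3`. These match the `expected` list of that test, and the middle column alone matches `testDistinctAggregate`.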
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
index db6820a..219b765 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
@@ -25,7 +25,7 @@ import org.apache.flink.table.api.scala._
import org.apache.flink.table.runtime.utils.StreamITCase.RetractingSink
import org.apache.flink.table.api.{StreamQueryConfig, TableEnvironment, Types}
import org.apache.flink.table.expressions.Null
-import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinct, DataViewTestAgg}
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinct, DataViewTestAgg, WeightedAvg}
import org.apache.flink.table.runtime.utils.{JavaUserDefinedAggFunctions, StreamITCase, StreamTestData, StreamingWithStateTestBase}
import org.apache.flink.types.Row
import org.junit.Assert.assertEquals
@@ -41,6 +41,84 @@ class AggregateITCase extends StreamingWithStateTestBase {
queryConfig.withIdleStateRetentionTime(Time.hours(1), Time.hours(2))
@Test
+ def testDistinctUDAGG(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ env.setStateBackend(getStateBackend)
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.clear
+
+ val testAgg = new DataViewTestAgg
+ val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e)
+ .groupBy('e)
+ .select('e, testAgg.distinct('d, 'e))
+
+ val results = t.toRetractStream[Row](queryConfig)
+ results.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+ env.execute()
+
+ val expected = mutable.MutableList("1,10", "2,21", "3,12")
+ assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+ }
+
+ @Test
+ def testDistinctUDAGGMixedWithNonDistinctUsage(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ env.setStateBackend(getStateBackend)
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.clear
+
+ val testAgg = new WeightedAvg
+ val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e)
+ .groupBy('e)
+ .select('e, testAgg.distinct('a, 'a), testAgg('a, 'a))
+
+ val results = t.toRetractStream[Row](queryConfig)
+ results.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+ env.execute()
+
+ val expected = mutable.MutableList("1,3,3", "2,3,4", "3,4,4")
+ assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+ }
+
+ @Test
+ def testDistinctAggregate(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ env.setStateBackend(getStateBackend)
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.clear
+
+ val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e)
+ .groupBy('e)
+ .select('e, 'a.count.distinct)
+
+ val results = t.toRetractStream[Row](queryConfig)
+ results.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+ env.execute()
+
+ val expected = mutable.MutableList("1,4", "2,4", "3,2")
+ assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+ }
+
+ @Test
+ def testDistinctAggregateMixedWithNonDistinct(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ env.setStateBackend(getStateBackend)
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.clear
+
+ val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e)
+ .groupBy('e)
+ .select('e, 'a.count.distinct, 'b.count)
+
+ val results = t.toRetractStream[Row](queryConfig)
+ results.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+ env.execute()
+
+ val expected = mutable.MutableList("1,4,5", "2,4,7", "3,2,3")
+ assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+ }
+
+ @Test
def testDistinct(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
env.setStateBackend(getStateBackend)