Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2018/09/13 21:25:55 UTC

[GitHub] asfgit closed pull request #6521: [FLINK-5315][table] Adding support for distinct operation for table API on DataStream

URL: https://github.com/apache/flink/pull/6521
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/docs/dev/table/tableApi.md b/docs/dev/table/tableApi.md
index f8bcd3da1af..a9b92fad995 100644
--- a/docs/dev/table/tableApi.md
+++ b/docs/dev/table/tableApi.md
@@ -370,6 +370,44 @@ Table result = orders
        <p><b>Note:</b> All aggregates must be defined over the same window, i.e., same partitioning, sorting, and range. Currently, only windows with PRECEDING (UNBOUNDED and bounded) to CURRENT ROW range are supported. Ranges with FOLLOWING are not supported yet. ORDER BY must be specified on a single <a href="streaming.html#time-attributes">time attribute</a>.</p>
       </td>
     </tr>
+    <tr>
+      <td>
+        <strong>Distinct Aggregation</strong><br>
+        <span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span> <br>
+        <span class="label label-info">Result Updating</span>
+      </td>
+      <td>
+        <p>Similar to a SQL DISTINCT aggregation clause such as COUNT(DISTINCT a). A distinct aggregation applies an aggregation function (built-in or user-defined) only to distinct input values. Distinct can be applied to <b>GroupBy Aggregation</b>, <b>GroupBy Window Aggregation</b> and <b>Over Window Aggregation</b>.</p>
+{% highlight java %}
+Table orders = tableEnv.scan("Orders");
+// Distinct aggregation on group by
+Table groupByDistinctResult = orders
+    .groupBy("a")
+    .select("a, b.sum.distinct as d");
+// Distinct aggregation on time window group by
+Table groupByWindowDistinctResult = orders
+    .window(Tumble.over("5.minutes").on("rowtime").as("w")).groupBy("a, w")
+    .select("a, b.sum.distinct as d");
+// Distinct aggregation on over window
+Table result = orders
+    .window(Over
+        .partitionBy("a")
+        .orderBy("rowtime")
+        .preceding("UNBOUNDED_RANGE")
+        .as("w"))
+    .select("a, b.avg.distinct over w, b.max over w, b.min over w");
+{% endhighlight %}
+        <p>User-defined aggregation functions can also be used with the DISTINCT modifier. To calculate the aggregate result only over distinct values, append the distinct modifier to the aggregation function.</p>
+{% highlight java %}
+Table orders = tableEnv.scan("Orders");
+
+// Use distinct aggregation for user-defined aggregate functions
+tableEnv.registerFunction("myUdagg", new MyUdagg());
+orders.groupBy("users").select("users, myUdagg.distinct(points) as myDistinctResult");
+{% endhighlight %}
+        <p><b>Note:</b> For streaming queries, the state required to compute the query result might grow without bound, depending on the number of distinct input values. Please provide a query configuration with a valid retention interval to prevent excessive state size. See <a href="streaming.html">Streaming Concepts</a> for details.</p>
+      </td>
+    </tr>
     <tr>
       <td>
         <strong>Distinct</strong><br>
@@ -453,6 +491,44 @@ val result: Table = orders
        <p><b>Note:</b> All aggregates must be defined over the same window, i.e., same partitioning, sorting, and range. Currently, only windows with PRECEDING (UNBOUNDED and bounded) to CURRENT ROW range are supported. Ranges with FOLLOWING are not supported yet. ORDER BY must be specified on a single <a href="streaming.html#time-attributes">time attribute</a>.</p>
       </td>
     </tr>
+    <tr>
+      <td>
+        <strong>Distinct Aggregation</strong><br>
+        <span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span> <br>
+        <span class="label label-info">Result Updating</span>
+      </td>
+      <td>
+        <p>Similar to a SQL DISTINCT aggregation clause such as COUNT(DISTINCT a). A distinct aggregation applies an aggregation function (built-in or user-defined) only to distinct input values. Distinct can be applied to <b>GroupBy Aggregation</b>, <b>GroupBy Window Aggregation</b> and <b>Over Window Aggregation</b>.</p>
+{% highlight scala %}
+val orders: Table = tableEnv.scan("Orders")
+// Distinct aggregation on group by
+val groupByDistinctResult = orders
+    .groupBy('a)
+    .select('a, 'b.sum.distinct as 'd)
+// Distinct aggregation on time window group by
+val groupByWindowDistinctResult = orders
+    .window(Tumble over 5.minutes on 'rowtime as 'w).groupBy('a, 'w)
+    .select('a, 'b.sum.distinct as 'd)
+// Distinct aggregation on over window
+val result = orders
+    .window(Over
+        partitionBy 'a
+        orderBy 'rowtime
+        preceding UNBOUNDED_RANGE
+        as 'w)
+    .select('a, 'b.avg.distinct over 'w, 'b.max over 'w, 'b.min over 'w)
+{% endhighlight %}
+        <p>User-defined aggregation functions can also be used with the DISTINCT modifier. To calculate the aggregate result only over distinct values, append the distinct modifier to the aggregation function.</p>
+{% highlight scala %}
+val orders: Table = tableEnv.scan("Orders")
+
+// Use distinct aggregation for user-defined aggregate functions
+val myUdagg = new MyUdagg()
+orders.groupBy('users).select('users, myUdagg.distinct('points) as 'myDistinctResult)
+{% endhighlight %}
+        <p><b>Note:</b> For streaming queries, the state required to compute the query result might grow without bound, depending on the number of distinct input values. Please provide a query configuration with a valid retention interval to prevent excessive state size. See <a href="streaming.html">Streaming Concepts</a> for details.</p>
+      </td>
+    </tr>
     <tr>
       <td>
         <strong>Distinct</strong><br>
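The documentation hunk above says a distinct aggregation applies a function only to distinct input values, for group-by and over windows alike. Below is a minimal plain-Java sketch of what `'b.sum.distinct` in a group-by and `'b.avg.distinct over w` compute. The window `w` is partitioned by `'a`, ordered by `'rowtime`, and uses `UNBOUNDED_RANGE` preceding. The class, the `Row` type and the in-memory data are hypothetical. The sketch models the semantics only, not how Flink evaluates such queries.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical plain-Java illustration of DISTINCT aggregate semantics (no Flink code).
class DistinctSemanticsSketch {

    // One row of a hypothetical "Orders" table with fields (a, b, rowtime).
    static final class Row {
        final String a;
        final long b;
        final long rowtime;

        Row(String a, long b, long rowtime) {
            this.a = a;
            this.b = b;
            this.rowtime = rowtime;
        }
    }

    // Like orders.groupBy('a).select('a, 'b.sum): every row contributes.
    static Map<String, Long> sumByKey(List<Row> rows) {
        Map<String, Long> result = new LinkedHashMap<>();
        for (Row r : rows) {
            result.merge(r.a, r.b, Long::sum);
        }
        return result;
    }

    // Like orders.groupBy('a).select('a, 'b.sum.distinct): each distinct b counts once per key.
    static Map<String, Long> sumDistinctByKey(List<Row> rows) {
        Map<String, Set<Long>> seen = new LinkedHashMap<>();
        for (Row r : rows) {
            seen.computeIfAbsent(r.a, k -> new HashSet<>()).add(r.b);
        }
        Map<String, Long> result = new LinkedHashMap<>();
        for (Map.Entry<String, Set<Long>> e : seen.entrySet()) {
            long sum = 0;
            for (long v : e.getValue()) {
                sum += v;
            }
            result.put(e.getKey(), sum);
        }
        return result;
    }

    // Like 'b.avg.distinct over w (partition by a, order by rowtime, unbounded preceding RANGE):
    // for each row, average the distinct b values of all rows in the same partition whose
    // rowtime is <= the current rowtime. RANGE semantics include rows sharing the current rowtime.
    // Returned as a double for readability; the real AVG result type depends on the input type.
    static List<Double> runningAvgDistinct(List<Row> rows) {
        List<Double> out = new ArrayList<>();
        for (Row current : rows) {
            Set<Long> distinct = new HashSet<>();
            for (Row other : rows) {
                if (other.a.equals(current.a) && other.rowtime <= current.rowtime) {
                    distinct.add(other.b);
                }
            }
            long sum = 0;
            for (long v : distinct) {
                sum += v;
            }
            out.add((double) sum / distinct.size());
        }
        return out;
    }

    public static void main(String[] args) {
        List<Row> rows = new ArrayList<>();
        rows.add(new Row("x", 1, 1000));
        rows.add(new Row("x", 1, 2000));
        rows.add(new Row("x", 2, 3000));
        rows.add(new Row("y", 5, 1500));
        System.out.println(sumByKey(rows));           // {x=4, y=5}
        System.out.println(sumDistinctByKey(rows));   // {x=3, y=5}
        System.out.println(runningAvgDistinct(rows)); // [1.0, 1.0, 1.5, 5.0]
    }
}
```

The duplicate value `1` for key `x` is counted twice by the plain sum but only once by the distinct variants.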
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
index dfe69cb0411..d8a68f30d73 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
@@ -22,12 +22,12 @@ import java.sql.{Date, Time, Timestamp}
 
 import org.apache.calcite.avatica.util.DateTimeUtils._
 import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation}
-import org.apache.flink.table.api.{TableException, CurrentRow, CurrentRange, UnboundedRow, UnboundedRange}
+import org.apache.flink.table.api.{CurrentRange, CurrentRow, TableException, UnboundedRange, UnboundedRow}
 import org.apache.flink.table.expressions.ExpressionUtils.{convertArray, toMilliInterval, toMonthInterval, toRowInterval}
 import org.apache.flink.table.api.Table
 import org.apache.flink.table.expressions.TimeIntervalUnit.TimeIntervalUnit
 import org.apache.flink.table.expressions._
-import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.{AggregateFunction, DistinctAggregateFunction}
 
 import scala.language.implicitConversions
 
@@ -214,7 +214,7 @@ trait ImplicitExpressionOperations {
   def varSamp = VarSamp(expr)
 
   /**
-    *  Returns multiset aggregate of a given expression.
+    * Returns multiset aggregate of a given expression.
     */
   def collect = Collect(expr)
 
@@ -972,6 +972,10 @@ trait ImplicitExpressionConversions {
   implicit def array2ArrayConstructor(array: Array[_]): Expression = convertArray(array)
   implicit def userDefinedAggFunctionConstructor[T: TypeInformation, ACC: TypeInformation]
       (udagg: AggregateFunction[T, ACC]): UDAGGExpression[T, ACC] = UDAGGExpression(udagg)
+  implicit def toDistinct(agg: Aggregation): DistinctAgg = DistinctAgg(agg)
+  implicit def toDistinct[T: TypeInformation, ACC: TypeInformation]
+      (agg: AggregateFunction[T, ACC]): DistinctAggregateFunction[T, ACC] =
+    DistinctAggregateFunction(agg)
 }
 
 // ------------------------------------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala
index 4b2440cf673..d7972110191 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala
@@ -81,6 +81,7 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
   lazy val GET: Keyword = Keyword("get")
   lazy val FLATTEN: Keyword = Keyword("flatten")
   lazy val OVER: Keyword = Keyword("over")
+  lazy val DISTINCT: Keyword = Keyword("distinct")
   lazy val CURRENT_ROW: Keyword = Keyword("current_row")
   lazy val CURRENT_RANGE: Keyword = Keyword("current_range")
   lazy val UNBOUNDED_ROW: Keyword = Keyword("unbounded_row")
@@ -311,6 +312,9 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
   lazy val suffixFlattening: PackratParser[Expression] =
     composite <~ "." ~ FLATTEN ~ opt("()") ^^ { e => Flattening(e) }
 
+  lazy val suffixDistinct: PackratParser[Expression] =
+    composite <~ "." ~ DISTINCT ~ opt("()") ^^ { e => DistinctAgg(e) }
+
   lazy val suffixAs: PackratParser[Expression] =
     composite ~ "." ~ AS ~ "(" ~ rep1sep(fieldReference, ",") ~ ")" ^^ {
       case e ~ _ ~ _ ~ _ ~ target ~ _ => Alias(e, target.head.name, target.tail.map(_.name))
@@ -330,6 +334,8 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
     suffixGet |
     // expression with special identifier
     suffixIf |
+    // expression with distinct suffix modifier
+    suffixDistinct |
     // function call must always be at the end
     suffixFunctionCall | suffixFunctionCallOneArg
 
@@ -397,6 +403,11 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
   lazy val prefixToTime: PackratParser[Expression] =
     TO_TIME ~ "(" ~> expression <~ ")" ^^ { e => Cast(e, SqlTimeTypeInfo.TIME) }
 
+  lazy val prefixDistinct: PackratParser[Expression] =
+    functionIdent ~ "." ~ DISTINCT ~ "(" ~ repsep(expression, ",") ~ ")" ^^ {
+      case name ~ _ ~ _ ~ _ ~ args ~ _ => DistinctAgg(Call(name.toUpperCase, args))
+  }
+
   lazy val prefixAs: PackratParser[Expression] =
     AS ~ "(" ~ expression ~ "," ~ rep1sep(fieldReference, ",") ~ ")" ^^ {
     case _ ~ _ ~ e ~ _ ~ target ~ _ => Alias(e, target.head.name, target.tail.map(_.name))
@@ -413,6 +424,8 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
     prefixGet |
     // expression with special identifier
     prefixIf |
+    // expression with prefix distinct
+    prefixDistinct |
     // function call must always be at the end
     prefixFunctionCall | prefixFunctionCallOneArg
 
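The parser changes above accept two spellings of the modifier. One is the suffix `a.sum.distinct`, with optional `()`. The other is the prefix `sum.distinct(a)`, where the function name is upper-cased as in `prefixDistinct`. The string-expression tests expect both spellings to yield the same table. The toy normalizer below is a hypothetical, regex-based stand-in for the Scala parser-combinator grammar. It only illustrates that both spellings map to one canonical tree, rendered here as a string such as `DistinctAgg(SUM(a))`.

```java
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Toy normalizer (hypothetical; NOT Flink's ExpressionParser). It maps the suffix and
// prefix distinct spellings, and the plain call spellings, to a canonical string.
class DistinctSyntaxSketch {

    // Suffix form: <field>.<aggregate>.distinct, optionally followed by "()".
    private static final Pattern SUFFIX_DISTINCT =
        Pattern.compile("(\\w+)\\.(\\w+)\\.distinct(\\(\\))?");
    // Prefix form: <aggregate>.distinct(<arg>, <arg>, ...).
    private static final Pattern PREFIX_DISTINCT =
        Pattern.compile("(\\w+)\\.distinct\\(([^)]*)\\)");
    // Plain forms without the modifier: <field>.<aggregate> and <aggregate>(<args>).
    private static final Pattern SUFFIX_CALL = Pattern.compile("(\\w+)\\.(\\w+)");
    private static final Pattern PREFIX_CALL = Pattern.compile("(\\w+)\\(([^)]*)\\)");

    static String normalize(String expr) {
        String e = expr.trim();
        Matcher m;
        // Distinct forms are checked before plain calls, in the same order as the grammar
        // lists suffixDistinct/prefixDistinct before the function-call alternatives.
        if ((m = SUFFIX_DISTINCT.matcher(e)).matches()) {
            return "DistinctAgg(" + upper(m.group(2)) + "(" + m.group(1) + "))";
        }
        if ((m = PREFIX_DISTINCT.matcher(e)).matches()) {
            return "DistinctAgg(" + upper(m.group(1)) + "(" + joinArgs(m.group(2)) + "))";
        }
        if ((m = SUFFIX_CALL.matcher(e)).matches()) {
            return upper(m.group(2)) + "(" + m.group(1) + ")";
        }
        if ((m = PREFIX_CALL.matcher(e)).matches()) {
            return upper(m.group(1)) + "(" + joinArgs(m.group(2)) + ")";
        }
        throw new IllegalArgumentException("Unsupported expression: " + expr);
    }

    private static String upper(String name) {
        return name.toUpperCase(Locale.ROOT);
    }

    // Normalizes whitespace in a comma-separated argument list: "b ,a" -> "b, a".
    private static String joinArgs(String raw) {
        String[] parts = raw.split(",");
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < parts.length; i++) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(parts[i].trim());
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        System.out.println(normalize("a.sum.distinct"));               // DistinctAgg(SUM(a))
        System.out.println(normalize("sum.distinct(a)"));              // DistinctAgg(SUM(a))
        System.out.println(normalize("myWeightedAvg.distinct(b, a)")); // DistinctAgg(MYWEIGHTEDAVG(b, a))
        System.out.println(normalize("sum(a)"));                       // SUM(a)
    }
}
```

Unlike the real grammar, this toy handles only a single, non-nested aggregate per string. Inputs such as `myCnt.distinct(a) + 9` are rejected.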
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
index b39bd9821d3..e03c5bef168 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
@@ -19,7 +19,6 @@ package org.apache.flink.table.expressions
 
 import org.apache.calcite.rex.RexNode
 import org.apache.calcite.sql.SqlAggFunction
-import org.apache.calcite.sql.SqlKind._
 import org.apache.calcite.sql.fun._
 import org.apache.calcite.tools.RelBuilder
 import org.apache.calcite.tools.RelBuilder.AggCall
@@ -43,7 +42,10 @@ abstract sealed class Aggregation extends Expression {
   /**
     * Convert Aggregate to its counterpart in Calcite, i.e. AggCall
     */
-  private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall
+  private[flink] def toAggCall(
+      name: String,
+      isDistinct: Boolean = false
+  )(implicit relBuilder: RelBuilder): AggCall
 
   /**
     * Returns the SqlAggFunction for this Aggregation.
@@ -52,12 +54,48 @@ abstract sealed class Aggregation extends Expression {
 
 }
 
+case class DistinctAgg(child: Expression) extends Aggregation {
+
+  private[flink] def distinct: Expression = DistinctAgg(child)
+
+  override private[flink] def resultType: TypeInformation[_] = child.resultType
+
+  override private[flink] def validateInput(): ValidationResult = {
+    super.validateInput()
+    child match {
+      case agg: Aggregation =>
+        child.validateInput()
+      case _ =>
+        ValidationFailure(s"Distinct modifier cannot be applied to $child! " +
+          s"It can only be applied to an aggregation expression, for example, " +
+          s"'a.count.distinct, which is equivalent to COUNT(DISTINCT a).")
+    }
+  }
+
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = true)(implicit relBuilder: RelBuilder) = {
+    child.asInstanceOf[Aggregation].toAggCall(name, isDistinct = true)
+  }
+
+  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
+    child.asInstanceOf[Aggregation].getSqlAggFunction()
+  }
+
+  override private[flink] def children = Seq(child)
+}
+
 case class Sum(child: Expression) extends Aggregation {
   override private[flink] def children: Seq[Expression] = Seq(child)
   override def toString = s"sum($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
-    relBuilder.aggregateCall(SqlStdOperatorTable.SUM, false, false, null, name, child.toRexNode)
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+    relBuilder.aggregateCall(SqlStdOperatorTable.SUM,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def resultType = child.resultType
@@ -77,8 +115,14 @@ case class Sum0(child: Expression) extends Aggregation {
   override private[flink] def children: Seq[Expression] = Seq(child)
   override def toString = s"sum0($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
-    relBuilder.aggregateCall(SqlStdOperatorTable.SUM0, false, false, null, name, child.toRexNode)
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+    relBuilder.aggregateCall(SqlStdOperatorTable.SUM0,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def resultType = child.resultType
@@ -94,8 +138,14 @@ case class Min(child: Expression) extends Aggregation {
   override private[flink] def children: Seq[Expression] = Seq(child)
   override def toString = s"min($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
-    relBuilder.aggregateCall(SqlStdOperatorTable.MIN, false, false, null, name, child.toRexNode)
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+    relBuilder.aggregateCall(SqlStdOperatorTable.MIN,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def resultType = child.resultType
@@ -112,8 +162,14 @@ case class Max(child: Expression) extends Aggregation {
   override private[flink] def children: Seq[Expression] = Seq(child)
   override def toString = s"max($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
-    relBuilder.aggregateCall(SqlStdOperatorTable.MAX, false, false, null, name, child.toRexNode)
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+    relBuilder.aggregateCall(SqlStdOperatorTable.MAX,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def resultType = child.resultType
@@ -130,8 +186,14 @@ case class Count(child: Expression) extends Aggregation {
   override private[flink] def children: Seq[Expression] = Seq(child)
   override def toString = s"count($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
-    relBuilder.aggregateCall(SqlStdOperatorTable.COUNT, false, false, null, name, child.toRexNode)
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+    relBuilder.aggregateCall(SqlStdOperatorTable.COUNT,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def resultType = BasicTypeInfo.LONG_TYPE_INFO
@@ -145,8 +207,14 @@ case class Avg(child: Expression) extends Aggregation {
   override private[flink] def children: Seq[Expression] = Seq(child)
   override def toString = s"avg($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
-    relBuilder.aggregateCall(SqlStdOperatorTable.AVG, false, false, null, name, child.toRexNode)
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+    relBuilder.aggregateCall(SqlStdOperatorTable.AVG,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def resultType = child.resultType
@@ -171,8 +239,14 @@ case class Collect(child: Expression) extends Aggregation  {
 
   override def toString: String = s"collect($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
-    relBuilder.aggregateCall(SqlStdOperatorTable.COLLECT, false, false, null, name, child.toRexNode)
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+    relBuilder.aggregateCall(SqlStdOperatorTable.COLLECT,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
@@ -184,9 +258,15 @@ case class StddevPop(child: Expression) extends Aggregation {
   override private[flink] def children: Seq[Expression] = Seq(child)
   override def toString = s"stddev_pop($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
     relBuilder.aggregateCall(
-      SqlStdOperatorTable.STDDEV_POP, false, false, null, name, child.toRexNode)
+      SqlStdOperatorTable.STDDEV_POP,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def resultType = child.resultType
@@ -202,9 +282,15 @@ case class StddevSamp(child: Expression) extends Aggregation {
   override private[flink] def children: Seq[Expression] = Seq(child)
   override def toString = s"stddev_samp($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
     relBuilder.aggregateCall(
-      SqlStdOperatorTable.STDDEV_SAMP, false, false, null, name, child.toRexNode)
+      SqlStdOperatorTable.STDDEV_SAMP,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def resultType = child.resultType
@@ -220,8 +306,15 @@ case class VarPop(child: Expression) extends Aggregation {
   override private[flink] def children: Seq[Expression] = Seq(child)
   override def toString = s"var_pop($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
-    relBuilder.aggregateCall(SqlStdOperatorTable.VAR_POP, false, false, null, name, child.toRexNode)
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+    relBuilder.aggregateCall(
+      SqlStdOperatorTable.VAR_POP,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def resultType = child.resultType
@@ -237,9 +330,15 @@ case class VarSamp(child: Expression) extends Aggregation {
   override private[flink] def children: Seq[Expression] = Seq(child)
   override def toString = s"var_samp($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
     relBuilder.aggregateCall(
-      SqlStdOperatorTable.VAR_SAMP, false, false, null, name, child.toRexNode)
+      SqlStdOperatorTable.VAR_SAMP,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def resultType = child.resultType
@@ -281,9 +380,15 @@ case class AggFunctionCall(
 
   override def toString: String = s"${aggregateFunction.getClass.getSimpleName}($args)"
 
-  override def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
+  override def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
     relBuilder.aggregateCall(
-      this.getSqlAggFunction(), false, false, null, name, args.map(_.toRexNode): _*)
+      this.getSqlAggFunction(),
+      isDistinct,
+      false,
+      null,
+      name,
+      args.map(_.toRexNode): _*)
   }
 
   override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
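In the hunks above, every aggregation's `toAggCall` gains an `isDistinct` parameter, which defaults to `false`. The parameter is passed as the second positional argument of `relBuilder.aggregateCall`. `DistinctAgg` ignores the flag it receives and always calls its child with `isDistinct = true`. Below is a minimal Java model of that delegation (a decorator). `AggCallSketch` is a hypothetical stand-in for Calcite's `AggCall`. It renders strings in the same shape as the planner tests, for example `SUM(DISTINCT a) AS TMP_0`.

```java
// Hypothetical Java model of DistinctAgg.toAggCall delegation; not Flink or Calcite code.
class DistinctDelegationSketch {

    // Records what would be passed to RelBuilder.aggregateCall.
    static final class AggCallSketch {
        final String function;
        final boolean distinct;
        final String alias;
        final String operand;

        AggCallSketch(String function, boolean distinct, String alias, String operand) {
            this.function = function;
            this.distinct = distinct;
            this.alias = alias;
            this.operand = operand;
        }

        // Same shape as the plan strings in the stream tests.
        @Override
        public String toString() {
            return function + "(" + (distinct ? "DISTINCT " : "") + operand + ") AS " + alias;
        }
    }

    interface Aggregation {
        AggCallSketch toAggCall(String name, boolean isDistinct);

        // Mirrors the Scala default argument isDistinct = false.
        default AggCallSketch toAggCall(String name) {
            return toAggCall(name, false);
        }
    }

    static final class Sum implements Aggregation {
        private final String child;

        Sum(String child) {
            this.child = child;
        }

        @Override
        public AggCallSketch toAggCall(String name, boolean isDistinct) {
            return new AggCallSketch("SUM", isDistinct, name, child);
        }
    }

    static final class Count implements Aggregation {
        private final String child;

        Count(String child) {
            this.child = child;
        }

        @Override
        public AggCallSketch toAggCall(String name, boolean isDistinct) {
            return new AggCallSketch("COUNT", isDistinct, name, child);
        }
    }

    // Wraps another aggregation and always requests a DISTINCT call from it,
    // regardless of the flag passed by the caller.
    static final class DistinctAgg implements Aggregation {
        private final Aggregation child;

        DistinctAgg(Aggregation child) {
            this.child = child;
        }

        @Override
        public AggCallSketch toAggCall(String name, boolean isDistinct) {
            return child.toAggCall(name, true);
        }
    }

    public static void main(String[] args) {
        System.out.println(new DistinctAgg(new Sum("a")).toAggCall("TMP_0"));   // SUM(DISTINCT a) AS TMP_0
        System.out.println(new DistinctAgg(new Count("c")).toAggCall("TMP_1")); // COUNT(DISTINCT c) AS TMP_1
        System.out.println(new Sum("a").toAggCall("TMP_2"));                    // SUM(a) AS TMP_2
    }
}
```

Keeping the flag on the existing aggregation classes means that no separate "distinct sum" or "distinct count" class is needed. The wrapper only flips one argument.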
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/DistinctAggregateFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/DistinctAggregateFunction.scala
new file mode 100644
index 00000000000..c75e6faf107
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/DistinctAggregateFunction.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.functions
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.table.expressions.{AggFunctionCall, DistinctAgg, Expression}
+import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getAccumulatorTypeOfAggregateFunction, getResultTypeOfAggregateFunction}
+
+/**
+  * Wraps an [[AggregateFunction]] and provides a `distinct` method that turns a call of the
+  * function into a [[DistinctAgg]] expression (reached via the `toDistinct` implicit conversion).
+  */
+private[flink] case class DistinctAggregateFunction[T: TypeInformation, ACC: TypeInformation]
+    (aggFunction: AggregateFunction[T, ACC]) {
+
+  private[flink] def distinct(params: Expression*): Expression = {
+    val resultTypeInfo: TypeInformation[_] = getResultTypeOfAggregateFunction(
+      aggFunction,
+      implicitly[TypeInformation[T]])
+
+    val accTypeInfo: TypeInformation[_] = getAccumulatorTypeOfAggregateFunction(
+      aggFunction,
+      implicitly[TypeInformation[ACC]])
+
+    DistinctAgg(
+      AggFunctionCall(aggFunction, resultTypeInfo, accTypeInfo, params))
+  }
+}
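With `DistinctAggregateFunction`, a user-defined aggregate such as `weightedAvg.distinct('a, 'b)` becomes `DistinctAgg(AggFunctionCall(...))`. The stream test renders it as `WeightedAvg(DISTINCT a, b)`. Below is a plain-Java sketch of what that means for a two-argument function. The accumulator and method names are hypothetical. The sketch assumes that rows are duplicates only when all arguments are equal, as in SQL's `COUNT(DISTINCT a, b)`. This diff does not show how the runtime deduplicates.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Hypothetical model of a user-defined aggregate with and without the distinct modifier.
// Assumption: distinctness is decided on the whole (value, weight) argument tuple.
class DistinctUdaggSketch {

    // Accumulator of a simple weighted average: sum(value * weight) / sum(weight).
    static final class WeightedAvgAcc {
        long weightedSum;
        long totalWeight;

        void accumulate(long value, long weight) {
            weightedSum += value * weight;
            totalWeight += weight;
        }

        // Returns null when nothing was accumulated, otherwise the average as a double.
        Double getValue() {
            return totalWeight == 0 ? null : (double) weightedSum / totalWeight;
        }
    }

    // Plain call, like weightedAvg('a, 'b): every (value, weight) row is accumulated.
    static Double weightedAvg(List<long[]> rows) {
        WeightedAvgAcc acc = new WeightedAvgAcc();
        for (long[] row : rows) {
            acc.accumulate(row[0], row[1]);
        }
        return acc.getValue();
    }

    // Distinct call, like weightedAvg.distinct('a, 'b): a tuple that was already seen
    // is skipped before it reaches accumulate().
    static Double weightedAvgDistinct(List<long[]> rows) {
        Set<List<Long>> seen = new HashSet<>();
        WeightedAvgAcc acc = new WeightedAvgAcc();
        for (long[] row : rows) {
            // Arrays.asList boxes both longs; the list has value-based equals/hashCode.
            if (seen.add(Arrays.asList(row[0], row[1]))) {
                acc.accumulate(row[0], row[1]);
            }
        }
        return acc.getValue();
    }

    public static void main(String[] args) {
        List<long[]> rows = new ArrayList<>();
        rows.add(new long[] {10, 1});
        rows.add(new long[] {10, 1}); // exact duplicate of the first tuple
        rows.add(new long[] {20, 3});
        System.out.println(weightedAvg(rows));         // 16.0 = (10 + 10 + 60) / 5
        System.out.println(weightedAvgDistinct(rows)); // 17.5 = (10 + 60) / 4
    }
}
```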
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
index a2bd1e45124..7579621a1b2 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
@@ -235,6 +235,14 @@ case class Aggregate(
     groupingExprs.foreach(validateGroupingExpression)
 
     def validateAggregateExpression(expr: Expression): Unit = expr match {
+      case distinctExpr: DistinctAgg =>
+        distinctExpr.child match {
+          case _: DistinctAgg => failValidation(
+            "Chained distinct operators are not supported!")
+          case aggExpr: Aggregation => validateAggregateExpression(aggExpr)
+          case _ => failValidation(
+            "Distinct operator can only be applied to aggregation expressions!")
+        }
       // check aggregate function
       case aggExpr: Aggregation
         if aggExpr.getSqlAggFunction.requiresOver =>
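The new `DistinctAgg` case in `validateAggregateExpression` applies two rules. It rejects a distinct applied to another distinct. It also rejects a distinct applied to anything that is not an aggregation, and otherwise validates the wrapped aggregation recursively. Because `DistinctAgg` itself extends `Aggregation`, the chained-distinct check must come before the general aggregation case. Below is a minimal Java model of that ordering. The expression classes are hypothetical stand-ins, and `failValidation` is replaced by an exception that carries the diff's messages.

```java
// Hypothetical model of the DistinctAgg checks added to Aggregate validation.
class DistinctValidationSketch {

    interface Expression {}

    // A field reference such as 'a; not an aggregation.
    static final class Field implements Expression {
        final String name;

        Field(String name) {
            this.name = name;
        }
    }

    // Any aggregation, e.g. 'a.sum.
    static class Aggregation implements Expression {
        final Expression child;

        Aggregation(Expression child) {
            this.child = child;
        }
    }

    // The distinct modifier is itself an aggregation wrapper, as in the Scala case class.
    static final class DistinctAgg extends Aggregation {
        DistinctAgg(Expression child) {
            super(child);
        }
    }

    static void validateAggregateExpression(Expression expr) {
        if (expr instanceof DistinctAgg) {
            Expression child = ((DistinctAgg) expr).child;
            // Must be tested before the Aggregation branch, since DistinctAgg is an Aggregation.
            if (child instanceof DistinctAgg) {
                throw new IllegalArgumentException("Chained distinct operators are not supported!");
            } else if (child instanceof Aggregation) {
                validateAggregateExpression(child);
            } else {
                throw new IllegalArgumentException(
                    "Distinct operator can only be applied to aggregation expressions!");
            }
        }
        // The remaining checks for plain aggregations are omitted from this sketch.
    }

    // Returns "ok" if validation passes, otherwise the validation message.
    static String check(Expression expr) {
        try {
            validateAggregateExpression(expr);
            return "ok";
        } catch (IllegalArgumentException e) {
            return e.getMessage();
        }
    }

    public static void main(String[] args) {
        Expression a = new Field("a");
        System.out.println(check(new DistinctAgg(new Aggregation(a))));                  // ok
        System.out.println(check(new DistinctAgg(new DistinctAgg(new Aggregation(a))))); // chained
        System.out.println(check(new DistinctAgg(a)));                                   // not an aggregation
    }
}
```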
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/AggregateStringExpressionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/AggregateStringExpressionTest.scala
index 4bbb1012f2f..4e7270ffcd7 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/AggregateStringExpressionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/AggregateStringExpressionTest.scala
@@ -27,6 +27,19 @@ import org.junit._
 
 class AggregateStringExpressionTest extends TableTestBase {
 
+  @Test
+  def testDistinctAggregationTypes(): Unit = {
+    val util = batchTestUtil()
+    val t = util.addTable[(Int, Long, String)]("Table3")
+
+    val t1 = t.select('_1.sum.distinct, '_1.count.distinct, '_1.avg.distinct)
+    val t2 = t.select("_1.sum.distinct, _1.count.distinct, _1.avg.distinct")
+    val t3 = t.select("sum.distinct(_1), count.distinct(_1), avg.distinct(_1)")
+
+    verifyTableEquals(t1, t2)
+    verifyTableEquals(t1, t3)
+  }
+
   @Test
   def testAggregationTypes(): Unit = {
     val util = batchTestUtil()
@@ -118,6 +131,19 @@ class AggregateStringExpressionTest extends TableTestBase {
     verifyTableEquals(distinct, distinct2)
   }
 
+  @Test
+  def testDistinctGroupedAggregate(): Unit = {
+    val util = batchTestUtil()
+    val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+
+    val t1 = t.groupBy('b).select('b, 'a.sum.distinct, 'a.sum)
+    val t2 = t.groupBy("b").select("b, a.sum.distinct, a.sum")
+    val t3 = t.groupBy("b").select("b, sum.distinct(a), sum(a)")
+
+    verifyTableEquals(t1, t2)
+    verifyTableEquals(t1, t3)
+  }
+
   @Test
   def testGroupedAggregate(): Unit = {
     val util = batchTestUtil()
@@ -238,6 +264,22 @@ class AggregateStringExpressionTest extends TableTestBase {
     verifyTableEquals(resScala, resJava)
   }
 
+  @Test
+  def testDistinctAggregateWithUDAGG(): Unit = {
+    val util = batchTestUtil()
+    val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+
+    val myCnt = new CountAggFunction
+    util.tableEnv.registerFunction("myCnt", myCnt)
+    val myWeightedAvg = new WeightedAvgWithMergeAndReset
+    util.tableEnv.registerFunction("myWeightedAvg", myWeightedAvg)
+
+    val t1 = t.select(myCnt.distinct('a) as 'aCnt, myWeightedAvg.distinct('b, 'a) as 'wAvg)
+    val t2 = t.select("myCnt.distinct(a) as aCnt, myWeightedAvg.distinct(b, a) as wAvg")
+
+    verifyTableEquals(t1, t2)
+  }
+
   @Test
   def testAggregateWithUDAGG(): Unit = {
     val util = batchTestUtil()
@@ -254,6 +296,30 @@ class AggregateStringExpressionTest extends TableTestBase {
     verifyTableEquals(t1, t2)
   }
 
+  @Test
+  def testDistinctGroupedAggregateWithUDAGG(): Unit = {
+    val util = batchTestUtil()
+    val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+
+
+    val myCnt = new CountAggFunction
+    util.tableEnv.registerFunction("myCnt", myCnt)
+    val myWeightedAvg = new WeightedAvgWithMergeAndReset
+    util.tableEnv.registerFunction("myWeightedAvg", myWeightedAvg)
+
+    val t1 = t.groupBy('b)
+      .select('b,
+        myCnt.distinct('a) + 9 as 'aCnt,
+        myWeightedAvg.distinct('b, 'a) * 2 as 'wAvg,
+        myWeightedAvg.distinct('a, 'a) as 'distAgg,
+        myWeightedAvg('a, 'a) as 'agg)
+    val t2 = t.groupBy("b")
+      .select("b, myCnt.distinct(a) + 9 as aCnt, myWeightedAvg.distinct(b, a) * 2 as wAvg, " +
+        "myWeightedAvg.distinct(a, a) as distAgg, myWeightedAvg(a, a) as agg")
+
+    verifyTableEquals(t1, t2)
+  }
+
   @Test
   def testGroupedAggregateWithUDAGG(): Unit = {
     val util = batchTestUtil()
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala
index 533235ad454..671f8dd8d1d 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala
@@ -21,12 +21,64 @@ package org.apache.flink.table.api.stream.table
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo
 import org.apache.flink.api.scala._
 import org.apache.flink.table.api.scala._
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.WeightedAvg
 import org.apache.flink.table.utils.TableTestUtil._
 import org.apache.flink.table.utils.TableTestBase
 import org.junit.Test
 
 class AggregateTest extends TableTestBase {
 
+  @Test
+  def testGroupDistinctAggregate(): Unit = {
+    val util = streamTestUtil()
+    val table = util.addTable[(Long, Int, String)]('a, 'b, 'c)
+
+    val resultTable = table
+      .groupBy('b)
+      .select('a.sum.distinct, 'c.count.distinct)
+
+    val expected =
+      unaryNode(
+        "DataStreamCalc",
+        unaryNode(
+          "DataStreamGroupAggregate",
+          streamTableNode(0),
+          term("groupBy", "b"),
+          term("select", "b", "SUM(DISTINCT a) AS TMP_0", "COUNT(DISTINCT c) AS TMP_1")
+        ),
+        term("select", "TMP_0", "TMP_1")
+      )
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testGroupDistinctAggregateWithUDAGG(): Unit = {
+    val util = streamTestUtil()
+    val table = util.addTable[(Long, Int, String)]('a, 'b, 'c)
+    val weightedAvg = new WeightedAvg
+
+    val resultTable = table
+      .groupBy('c)
+      .select(weightedAvg.distinct('a, 'b), weightedAvg('a, 'b))
+
+    val expected =
+      unaryNode(
+        "DataStreamCalc",
+        unaryNode(
+          "DataStreamGroupAggregate",
+          streamTableNode(0),
+          term("groupBy", "c"),
+          term(
+            "select",
+            "c",
+            "WeightedAvg(DISTINCT a, b) AS TMP_0",
+            "WeightedAvg(a, b) AS TMP_1")
+        ),
+        term("select", "TMP_0", "TMP_1")
+      )
+    util.verifyTable(resultTable, expected)
+  }
+
   @Test
   def testGroupAggregate() = {
     val util = streamTestUtil()
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala
index 2bef95e5b40..ec57436b420 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala
@@ -20,12 +20,80 @@ package org.apache.flink.table.api.stream.table.stringexpr
 
 import org.apache.flink.api.scala._
 import org.apache.flink.table.api.scala._
-import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.WeightedAvg
+import org.apache.flink.table.functions.aggfunctions.CountAggFunction
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{WeightedAvg, WeightedAvgWithMergeAndReset}
 import org.apache.flink.table.utils.TableTestBase
 import org.junit.Test
 
 class AggregateStringExpressionTest extends TableTestBase {
 
+
+  @Test
+  def testDistinctNonGroupedAggregate(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Int, Long, String)]("Table3")
+
+    val t1 = t.select('_1.sum.distinct, '_1.count.distinct, '_1.avg.distinct)
+    val t2 = t.select("_1.sum.distinct, _1.count.distinct, _1.avg.distinct")
+    val t3 = t.select("sum.distinct(_1), count.distinct(_1), avg.distinct(_1)")
+
+    verifyTableEquals(t1, t2)
+    verifyTableEquals(t1, t3)
+  }
+
+  @Test
+  def testDistinctGroupedAggregate(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+
+    val t1 = t.groupBy('b).select('b, 'a.sum.distinct, 'a.sum)
+    val t2 = t.groupBy("b").select("b, a.sum.distinct, a.sum")
+    val t3 = t.groupBy("b").select("b, sum.distinct(a), sum(a)")
+
+    verifyTableEquals(t1, t2)
+    verifyTableEquals(t1, t3)
+  }
+
+  @Test
+  def testDistinctNonGroupAggregateWithUDAGG(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+
+    val myCnt = new CountAggFunction
+    util.tableEnv.registerFunction("myCnt", myCnt)
+    val myWeightedAvg = new WeightedAvgWithMergeAndReset
+    util.tableEnv.registerFunction("myWeightedAvg", myWeightedAvg)
+
+    val t1 = t.select(myCnt.distinct('a) as 'aCnt, myWeightedAvg.distinct('b, 'a) as 'wAvg)
+    val t2 = t.select("myCnt.distinct(a) as aCnt, myWeightedAvg.distinct(b, a) as wAvg")
+
+    verifyTableEquals(t1, t2)
+  }
+
+  @Test
+  def testDistinctGroupedAggregateWithUDAGG(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+
+
+    val myCnt = new CountAggFunction
+    util.tableEnv.registerFunction("myCnt", myCnt)
+    val myWeightedAvg = new WeightedAvgWithMergeAndReset
+    util.tableEnv.registerFunction("myWeightedAvg", myWeightedAvg)
+
+    val t1 = t.groupBy('b)
+      .select('b,
+        myCnt.distinct('a) + 9 as 'aCnt,
+        myWeightedAvg.distinct('b, 'a) * 2 as 'wAvg,
+        myWeightedAvg.distinct('a, 'a) as 'distAgg,
+        myWeightedAvg('a, 'a) as 'agg)
+    val t2 = t.groupBy("b")
+      .select("b, myCnt.distinct(a) + 9 as aCnt, myWeightedAvg.distinct(b, a) * 2 as wAvg, " +
+        "myWeightedAvg.distinct(a, a) as distAgg, myWeightedAvg(a, a) as agg")
+
+    verifyTableEquals(t1, t2)
+  }
+
   @Test
   def testGroupedAggregate(): Unit = {
     val util = streamTestUtil()
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
index db6820ae90e..219b7653c57 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
@@ -25,7 +25,7 @@ import org.apache.flink.table.api.scala._
 import org.apache.flink.table.runtime.utils.StreamITCase.RetractingSink
 import org.apache.flink.table.api.{StreamQueryConfig, TableEnvironment, Types}
 import org.apache.flink.table.expressions.Null
-import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinct, DataViewTestAgg}
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinct, DataViewTestAgg, WeightedAvg}
 import org.apache.flink.table.runtime.utils.{JavaUserDefinedAggFunctions, StreamITCase, StreamTestData, StreamingWithStateTestBase}
 import org.apache.flink.types.Row
 import org.junit.Assert.assertEquals
@@ -40,6 +40,84 @@ class AggregateITCase extends StreamingWithStateTestBase {
   private val queryConfig = new StreamQueryConfig()
   queryConfig.withIdleStateRetentionTime(Time.hours(1), Time.hours(2))
 
+  @Test
+  def testDistinctUDAGG(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val testAgg = new DataViewTestAgg
+    val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e)
+      .groupBy('e)
+      .select('e, testAgg.distinct('d, 'e))
+
+    val results = t.toRetractStream[Row](queryConfig)
+    results.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+    env.execute()
+
+    val expected = mutable.MutableList("1,10", "2,21", "3,12")
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+  }
+
+  @Test
+  def testDistinctUDAGGMixedWithNonDistinctUsage(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val testAgg = new WeightedAvg
+    val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e)
+      .groupBy('e)
+      .select('e, testAgg.distinct('a, 'a), testAgg('a, 'a))
+
+    val results = t.toRetractStream[Row](queryConfig)
+    results.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+    env.execute()
+
+    val expected = mutable.MutableList("1,3,3", "2,3,4", "3,4,4")
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+  }
+
+  @Test
+  def testDistinctAggregate(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e)
+      .groupBy('e)
+      .select('e, 'a.count.distinct)
+
+    val results = t.toRetractStream[Row](queryConfig)
+    results.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+    env.execute()
+
+    val expected = mutable.MutableList("1,4", "2,4", "3,2")
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+  }
+
+  @Test
+  def testDistinctAggregateMixedWithNonDistinct(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e)
+      .groupBy('e)
+      .select('e, 'a.count.distinct, 'b.count)
+
+    val results = t.toRetractStream[Row](queryConfig)
+    results.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+    env.execute()
+
+    val expected = mutable.MutableList("1,4,5", "2,4,7", "3,2,3")
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+  }
+
   @Test
   def testDistinct(): Unit = {
     val env = StreamExecutionEnvironment.getExecutionEnvironment

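The tests above exercise distinct aggregations on a grouped stream, for example 'a.count.distinct next to 'b.count under groupBy('e). Their semantics can be illustrated with plain Scala collections. What follows is a minimal sketch, not Flink's runtime implementation. The object name DistinctAggSketch and the sample rows are hypothetical. The sketch only shows the end result per group: the distinct input values are collected first, and the aggregate is then applied to them. It does not model the incremental, retraction-aware state a streaming operator has to maintain.

```scala
// Hypothetical sketch: mimics what
//   groupBy('e).select('e, 'a.count.distinct, 'a.sum.distinct, 'b.count)
// returns per group, using plain Scala collections instead of Flink.
object DistinctAggSketch {
  // Made-up (a, b, e) rows, for illustration only.
  val rows: Seq[(Int, Long, Long)] = Seq(
    (1, 10L, 1L), (2, 20L, 1L), (2, 30L, 1L),
    (3, 40L, 2L), (3, 50L, 2L), (3, 60L, 2L)
  )

  // Result per group key e: (COUNT(DISTINCT a), SUM(DISTINCT a), COUNT(b)).
  def aggregate(input: Seq[(Int, Long, Long)]): Map[Long, (Long, Long, Long)] =
    input.groupBy(_._3).map { case (e, group) =>
      // DISTINCT: duplicate values of a are dropped before aggregating.
      val distinctA = group.map(_._1).distinct
      (e, (distinctA.size.toLong, distinctA.map(_.toLong).sum, group.size.toLong))
    }
}
```

For the sample rows, group e = 1 has a-values 1, 2, 2. Its distinct count is therefore 2 and its distinct sum is 3, while the plain count over the group is 3.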

 

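For user-defined aggregates such as WeightedAvg, distinct applies to the tuple of arguments. For example, testAgg.distinct('a, 'a) accumulates each distinct ('a, 'a) pair only once per group, while testAgg('a, 'a) accumulates every row. The sketch below assumes, hypothetically, that WeightedAvg returns sum(value * weight) / sum(weight) using integer division. The object and method names are illustrative and are not part of the Flink API.

```scala
// Hypothetical stand-in for a weighted-average UDAGG. It assumes the result
// is sum(value * weight) / sum(weight) with Long (integer) division.
object DistinctUdaggSketch {
  // Non-distinct call: every (value, weight) argument tuple is accumulated.
  def weightedAvg(args: Seq[(Long, Long)]): Option[Long] = {
    val totalWeight = args.map(_._2).sum
    if (totalWeight == 0L) None // no input yet, comparable to a null result
    else Some(args.map { case (v, w) => v * w }.sum / totalWeight)
  }

  // Distinct call: duplicate argument tuples are dropped before accumulating.
  def weightedAvgDistinct(args: Seq[(Long, Long)]): Option[Long] =
    weightedAvg(args.distinct)
}
```

Suppose the a-values 2, 3, 3, 4, 4, 5, 5 are passed as (a, a). The non-distinct call yields 104 / 26 = 4. The distinct call only sees 2, 3, 4, 5 and yields 54 / 14 = 3 after integer division.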
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services