You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by fh...@apache.org on 2018/09/13 21:24:12 UTC

[flink] branch master updated: [FLINK-5315] [table] Add support for DISTINCT aggregation to Table API.

This is an automated email from the ASF dual-hosted git repository.

fhueske pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 04c7cdf  [FLINK-5315] [table] Add support for DISTINCT aggregation to Table API.
04c7cdf is described below

commit 04c7cdf7d3d7b0862f7a4a3e7d821b1fb62ad3f0
Author: Rong Rong <wa...@hotmail.com>
AuthorDate: Tue Aug 7 14:29:46 2018 -0700

    [FLINK-5315] [table] Add support for DISTINCT aggregation to Table API.
    
    This closes #6521.
---
 docs/dev/table/tableApi.md                         |  76 ++++++++++
 .../flink/table/api/scala/expressionDsl.scala      |  10 +-
 .../flink/table/expressions/ExpressionParser.scala |  13 ++
 .../flink/table/expressions/aggregations.scala     | 164 +++++++++++++++++----
 .../functions/DistinctAggregateFunction.scala      |  43 ++++++
 .../flink/table/plan/logical/operators.scala       |   8 +
 .../stringexpr/AggregateStringExpressionTest.scala |  66 +++++++++
 .../table/api/stream/table/AggregateTest.scala     |  52 +++++++
 .../stringexpr/AggregateStringExpressionTest.scala |  70 ++++++++-
 .../runtime/stream/table/AggregateITCase.scala     |  80 +++++++++-
 10 files changed, 551 insertions(+), 31 deletions(-)

diff --git a/docs/dev/table/tableApi.md b/docs/dev/table/tableApi.md
index 1a21a57..f50b0f5 100644
--- a/docs/dev/table/tableApi.md
+++ b/docs/dev/table/tableApi.md
@@ -372,6 +372,44 @@ Table result = orders
     </tr>
     <tr>
       <td>
+        <strong>Distinct Aggregation</strong><br>
+        <span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span> <br>
+        <span class="label label-info">Result Updating</span>
+      </td>
+      <td>
+        <p>Similar to a SQL DISTINCT aggregation clause such as COUNT(DISTINCT a). Distinct aggregation declares that an aggregation function (built-in or user-defined) is only applied on distinct input values. Distinct can be applied to <b>GroupBy Aggregation</b>, <b>GroupBy Window Aggregation</b> and <b>Over Window Aggregation</b>.</p>
+{% highlight java %}
+Table orders = tableEnv.scan("Orders");
+// Distinct aggregation on group by
+Table groupByDistinctResult = orders
+    .groupBy("a")
+    .select("a, b.sum.distinct as d");
+// Distinct aggregation on time window group by
+Table groupByWindowDistinctResult = orders
+    .window(Tumble.over("5.minutes").on("rowtime").as("w")).groupBy("a, w")
+    .select("a, b.sum.distinct as d");
+// Distinct aggregation on over window
+Table result = orders
+    .window(Over
+        .partitionBy("a")
+        .orderBy("rowtime")
+        .preceding("UNBOUNDED_RANGE")
+        .as("w"))
+    .select("a, b.avg.distinct over w, b.max over w, b.min over w");
+{% endhighlight %}
+        <p>User-defined aggregation functions can also be used with the DISTINCT modifier. To calculate the aggregate result only for distinct values, simply add the distinct modifier to the aggregation function.</p>
+{% highlight java %}
+Table orders = tEnv.scan("Orders");
+
+// Use distinct aggregation for user-defined aggregate functions
+tEnv.registerFunction("myUdagg", new MyUdagg());
+orders.groupBy("users").select("users, myUdagg.distinct(points) as myDistinctResult");
+{% endhighlight %}
+        <p><b>Note:</b> For streaming queries the required state to compute the query result might grow infinitely depending on the number of distinct fields. Please provide a query configuration with a valid retention interval to prevent excessive state size. See <a href="streaming.html">Streaming Concepts</a> for details.</p>
+      </td>
+    </tr>
+    <tr>
+      <td>
         <strong>Distinct</strong><br>
         <span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span> <br>
         <span class="label label-info">Result Updating</span>
@@ -455,6 +493,44 @@ val result: Table = orders
     </tr>
     <tr>
       <td>
+        <strong>Distinct Aggregation</strong><br>
+        <span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span> <br>
+        <span class="label label-info">Result Updating</span>
+      </td>
+      <td>
+        <p>Similar to a SQL DISTINCT aggregation clause such as COUNT(DISTINCT a). Distinct aggregation declares that an aggregation function (built-in or user-defined) is only applied on distinct input values. Distinct can be applied to <b>GroupBy Aggregation</b>, <b>GroupBy Window Aggregation</b> and <b>Over Window Aggregation</b>.</p>
+{% highlight scala %}
+val orders: Table = tableEnv.scan("Orders");
+// Distinct aggregation on group by
+val groupByDistinctResult = orders
+    .groupBy('a)
+    .select('a, 'b.sum.distinct as 'd)
+// Distinct aggregation on time window group by
+val groupByWindowDistinctResult = orders
+    .window(Tumble over 5.minutes on 'rowtime as 'w).groupBy('a, 'w)
+    .select('a, 'b.sum.distinct as 'd)
+// Distinct aggregation on over window
+val result = orders
+    .window(Over
+        partitionBy 'a
+        orderBy 'rowtime
+        preceding UNBOUNDED_RANGE
+        as 'w)
+    .select('a, 'b.avg.distinct over 'w, 'b.max over 'w, 'b.min over 'w)
+{% endhighlight %}
+        <p>User-defined aggregation functions can also be used with the DISTINCT modifier. To calculate the aggregate result only for distinct values, simply add the distinct modifier to the aggregation function.</p>
+{% highlight scala %}
+val orders: Table = tEnv.scan("Orders");
+
+// Use distinct aggregation for user-defined aggregate functions
+val myUdagg = new MyUdagg();
+orders.groupBy('users).select('users, myUdagg.distinct('points) as 'myDistinctResult);
+{% endhighlight %}
+        <p><b>Note:</b> For streaming queries the required state to compute the query result might grow infinitely depending on the number of distinct fields. Please provide a query configuration with a valid retention interval to prevent excessive state size. See <a href="streaming.html">Streaming Concepts</a> for details.</p>
+      </td>
+    </tr>
+    <tr>
+      <td>
         <strong>Distinct</strong><br>
         <span class="label label-primary">Batch</span>
       </td>
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
index cf8cd91..126cc5f 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
@@ -22,12 +22,12 @@ import java.sql.{Date, Time, Timestamp}
 
 import org.apache.calcite.avatica.util.DateTimeUtils._
 import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation}
-import org.apache.flink.table.api.{TableException, CurrentRow, CurrentRange, UnboundedRow, UnboundedRange}
+import org.apache.flink.table.api.{CurrentRange, CurrentRow, TableException, UnboundedRange, UnboundedRow}
 import org.apache.flink.table.expressions.ExpressionUtils.{convertArray, toMilliInterval, toMonthInterval, toRowInterval}
 import org.apache.flink.table.api.Table
 import org.apache.flink.table.expressions.TimeIntervalUnit.TimeIntervalUnit
 import org.apache.flink.table.expressions._
-import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.{AggregateFunction, DistinctAggregateFunction}
 
 import scala.language.implicitConversions
 
@@ -214,7 +214,7 @@ trait ImplicitExpressionOperations {
   def varSamp = VarSamp(expr)
 
   /**
-    *  Returns multiset aggregate of a given expression.
+    * Returns multiset aggregate of a given expression.
     */
   def collect = Collect(expr)
 
@@ -998,6 +998,10 @@ trait ImplicitExpressionConversions {
   implicit def array2ArrayConstructor(array: Array[_]): Expression = convertArray(array)
   implicit def userDefinedAggFunctionConstructor[T: TypeInformation, ACC: TypeInformation]
       (udagg: AggregateFunction[T, ACC]): UDAGGExpression[T, ACC] = UDAGGExpression(udagg)
+  implicit def toDistinct(agg: Aggregation): DistinctAgg = DistinctAgg(agg)
+  implicit def toDistinct[T: TypeInformation, ACC: TypeInformation]
+      (agg: AggregateFunction[T, ACC]): DistinctAggregateFunction[T, ACC] =
+    DistinctAggregateFunction(agg)
 }
 
 // ------------------------------------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala
index 633afb1..4909d2c 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala
@@ -83,6 +83,7 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
   lazy val GET: Keyword = Keyword("get")
   lazy val FLATTEN: Keyword = Keyword("flatten")
   lazy val OVER: Keyword = Keyword("over")
+  lazy val DISTINCT: Keyword = Keyword("distinct")
   lazy val CURRENT_ROW: Keyword = Keyword("current_row")
   lazy val CURRENT_RANGE: Keyword = Keyword("current_range")
   lazy val UNBOUNDED_ROW: Keyword = Keyword("unbounded_row")
@@ -324,6 +325,9 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
   lazy val suffixFlattening: PackratParser[Expression] =
     composite <~ "." ~ FLATTEN ~ opt("()") ^^ { e => Flattening(e) }
 
+  lazy val suffixDistinct: PackratParser[Expression] =
+    composite <~ "." ~ DISTINCT ~ opt("()") ^^ { e => DistinctAgg(e) }
+
   lazy val suffixAs: PackratParser[Expression] =
     composite ~ "." ~ AS ~ "(" ~ rep1sep(fieldReference, ",") ~ ")" ^^ {
       case e ~ _ ~ _ ~ _ ~ target ~ _ => Alias(e, target.head.name, target.tail.map(_.name))
@@ -345,6 +349,8 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
     suffixGet |
     // expression with special identifier
     suffixIf |
+    // expression with distinct suffix modifier
+    suffixDistinct |
     // function call must always be at the end
     suffixFunctionCall | suffixFunctionCallOneArg
 
@@ -412,6 +418,11 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
   lazy val prefixToTime: PackratParser[Expression] =
     TO_TIME ~ "(" ~> expression <~ ")" ^^ { e => Cast(e, SqlTimeTypeInfo.TIME) }
 
+  lazy val prefixDistinct: PackratParser[Expression] =
+    functionIdent ~ "." ~ DISTINCT ~ "(" ~ repsep(expression, ",") ~ ")" ^^ {
+      case name ~ _ ~ _ ~ _ ~ args ~ _ => DistinctAgg(Call(name.toUpperCase, args))
+  }
+
   lazy val prefixAs: PackratParser[Expression] =
     AS ~ "(" ~ expression ~ "," ~ rep1sep(fieldReference, ",") ~ ")" ^^ {
     case _ ~ _ ~ e ~ _ ~ target ~ _ => Alias(e, target.head.name, target.tail.map(_.name))
@@ -428,6 +439,8 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
     prefixGet |
     // expression with special identifier
     prefixIf |
+    // expression with prefix distinct
+    prefixDistinct |
     // function call must always be at the end
     prefixFunctionCall | prefixFunctionCallOneArg
 
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
index b39bd98..c77bd7a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
@@ -19,7 +19,6 @@ package org.apache.flink.table.expressions
 
 import org.apache.calcite.rex.RexNode
 import org.apache.calcite.sql.SqlAggFunction
-import org.apache.calcite.sql.SqlKind._
 import org.apache.calcite.sql.fun._
 import org.apache.calcite.tools.RelBuilder
 import org.apache.calcite.tools.RelBuilder.AggCall
@@ -43,7 +42,10 @@ abstract sealed class Aggregation extends Expression {
   /**
     * Convert Aggregate to its counterpart in Calcite, i.e. AggCall
     */
-  private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall
+  private[flink] def toAggCall(
+      name: String,
+      isDistinct: Boolean = false
+  )(implicit relBuilder: RelBuilder): AggCall
 
   /**
     * Returns the SqlAggFunction for this Aggregation.
@@ -52,12 +54,49 @@ abstract sealed class Aggregation extends Expression {
 
 }
 
+case class DistinctAgg(child: Expression) extends Aggregation {
+
+  private[flink] def distinct: Expression = DistinctAgg(child)
+
+  override private[flink] def resultType: TypeInformation[_] = child.resultType
+
+  override private[flink] def validateInput(): ValidationResult = {
+    super.validateInput()
+    child match {
+      case agg: Aggregation =>
+        child.validateInput()
+      case _ =>
+        ValidationFailure(s"Distinct modifier cannot be applied to $child! " +
+          s"It can only be applied to an aggregation expression, for example, " +
+          s"'a.count.distinct which is equivalent with COUNT(DISTINCT a).")
+    }
+  }
+
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = true)(implicit relBuilder: RelBuilder) = {
+    child.asInstanceOf[Aggregation].toAggCall(name, isDistinct = true)
+  }
+
+  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
+    child.asInstanceOf[Aggregation].getSqlAggFunction()
+  }
+
+  override private[flink] def children = Seq(child)
+}
+
 case class Sum(child: Expression) extends Aggregation {
   override private[flink] def children: Seq[Expression] = Seq(child)
   override def toString = s"sum($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
-    relBuilder.aggregateCall(SqlStdOperatorTable.SUM, false, false, null, name, child.toRexNode)
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+    relBuilder.aggregateCall(
+      SqlStdOperatorTable.SUM,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def resultType = child.resultType
@@ -77,8 +116,15 @@ case class Sum0(child: Expression) extends Aggregation {
   override private[flink] def children: Seq[Expression] = Seq(child)
   override def toString = s"sum0($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
-    relBuilder.aggregateCall(SqlStdOperatorTable.SUM0, false, false, null, name, child.toRexNode)
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+    relBuilder.aggregateCall(
+      SqlStdOperatorTable.SUM0,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def resultType = child.resultType
@@ -94,8 +140,15 @@ case class Min(child: Expression) extends Aggregation {
   override private[flink] def children: Seq[Expression] = Seq(child)
   override def toString = s"min($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
-    relBuilder.aggregateCall(SqlStdOperatorTable.MIN, false, false, null, name, child.toRexNode)
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+    relBuilder.aggregateCall(
+      SqlStdOperatorTable.MIN,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def resultType = child.resultType
@@ -112,8 +165,15 @@ case class Max(child: Expression) extends Aggregation {
   override private[flink] def children: Seq[Expression] = Seq(child)
   override def toString = s"max($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
-    relBuilder.aggregateCall(SqlStdOperatorTable.MAX, false, false, null, name, child.toRexNode)
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+    relBuilder.aggregateCall(
+      SqlStdOperatorTable.MAX,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def resultType = child.resultType
@@ -130,8 +190,15 @@ case class Count(child: Expression) extends Aggregation {
   override private[flink] def children: Seq[Expression] = Seq(child)
   override def toString = s"count($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
-    relBuilder.aggregateCall(SqlStdOperatorTable.COUNT, false, false, null, name, child.toRexNode)
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+    relBuilder.aggregateCall(
+      SqlStdOperatorTable.COUNT,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def resultType = BasicTypeInfo.LONG_TYPE_INFO
@@ -145,8 +212,15 @@ case class Avg(child: Expression) extends Aggregation {
   override private[flink] def children: Seq[Expression] = Seq(child)
   override def toString = s"avg($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
-    relBuilder.aggregateCall(SqlStdOperatorTable.AVG, false, false, null, name, child.toRexNode)
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+    relBuilder.aggregateCall(
+      SqlStdOperatorTable.AVG,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def resultType = child.resultType
@@ -171,8 +245,15 @@ case class Collect(child: Expression) extends Aggregation  {
 
   override def toString: String = s"collect($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
-    relBuilder.aggregateCall(SqlStdOperatorTable.COLLECT, false, false, null, name, child.toRexNode)
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+    relBuilder.aggregateCall(
+      SqlStdOperatorTable.COLLECT,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
@@ -184,9 +265,15 @@ case class StddevPop(child: Expression) extends Aggregation {
   override private[flink] def children: Seq[Expression] = Seq(child)
   override def toString = s"stddev_pop($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
     relBuilder.aggregateCall(
-      SqlStdOperatorTable.STDDEV_POP, false, false, null, name, child.toRexNode)
+      SqlStdOperatorTable.STDDEV_POP,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def resultType = child.resultType
@@ -202,9 +289,15 @@ case class StddevSamp(child: Expression) extends Aggregation {
   override private[flink] def children: Seq[Expression] = Seq(child)
   override def toString = s"stddev_samp($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
     relBuilder.aggregateCall(
-      SqlStdOperatorTable.STDDEV_SAMP, false, false, null, name, child.toRexNode)
+      SqlStdOperatorTable.STDDEV_SAMP,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def resultType = child.resultType
@@ -220,8 +313,15 @@ case class VarPop(child: Expression) extends Aggregation {
   override private[flink] def children: Seq[Expression] = Seq(child)
   override def toString = s"var_pop($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
-    relBuilder.aggregateCall(SqlStdOperatorTable.VAR_POP, false, false, null, name, child.toRexNode)
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
+    relBuilder.aggregateCall(
+      SqlStdOperatorTable.VAR_POP,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def resultType = child.resultType
@@ -237,9 +337,15 @@ case class VarSamp(child: Expression) extends Aggregation {
   override private[flink] def children: Seq[Expression] = Seq(child)
   override def toString = s"var_samp($child)"
 
-  override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
+  override private[flink] def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
     relBuilder.aggregateCall(
-      SqlStdOperatorTable.VAR_SAMP, false, false, null, name, child.toRexNode)
+      SqlStdOperatorTable.VAR_SAMP,
+      isDistinct,
+      false,
+      null,
+      name,
+      child.toRexNode)
   }
 
   override private[flink] def resultType = child.resultType
@@ -281,9 +387,15 @@ case class AggFunctionCall(
 
   override def toString: String = s"${aggregateFunction.getClass.getSimpleName}($args)"
 
-  override def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
+  override def toAggCall(
+      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
     relBuilder.aggregateCall(
-      this.getSqlAggFunction(), false, false, null, name, args.map(_.toRexNode): _*)
+      this.getSqlAggFunction(),
+      isDistinct,
+      false,
+      null,
+      name,
+      args.map(_.toRexNode): _*)
   }
 
   override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/DistinctAggregateFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/DistinctAggregateFunction.scala
new file mode 100644
index 0000000..c75e6fa
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/DistinctAggregateFunction.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.functions
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.table.expressions.{AggFunctionCall, DistinctAgg, Expression}
+import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getAccumulatorTypeOfAggregateFunction, getResultTypeOfAggregateFunction}
+
+/**
+  * Wraps an [[AggregateFunction]] and provides the `distinct` method, which builds a
+  * [[DistinctAgg]] expression around an [[AggFunctionCall]] with the given parameters.
+  */
+private[flink] case class DistinctAggregateFunction[T: TypeInformation, ACC: TypeInformation]
+    (aggFunction: AggregateFunction[T, ACC]) {
+
+  private[flink] def distinct(params: Expression*): Expression = {
+    val resultTypeInfo: TypeInformation[_] = getResultTypeOfAggregateFunction(
+      aggFunction,
+      implicitly[TypeInformation[T]])
+
+    val accTypeInfo: TypeInformation[_] = getAccumulatorTypeOfAggregateFunction(
+      aggFunction,
+      implicitly[TypeInformation[ACC]])
+
+    DistinctAgg(
+      AggFunctionCall(aggFunction, resultTypeInfo, accTypeInfo, params))
+  }
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
index a2bd1e4..7579621 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
@@ -235,6 +235,14 @@ case class Aggregate(
     groupingExprs.foreach(validateGroupingExpression)
 
     def validateAggregateExpression(expr: Expression): Unit = expr match {
+      case distinctExpr: DistinctAgg =>
+        distinctExpr.child match {
+          case _: DistinctAgg => failValidation(
+            "Chained distinct operators are not supported!")
+          case aggExpr: Aggregation => validateAggregateExpression(aggExpr)
+          case _ => failValidation(
+            "Distinct operator can only be applied to aggregation expressions!")
+        }
       // check aggregate function
       case aggExpr: Aggregation
         if aggExpr.getSqlAggFunction.requiresOver =>
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/AggregateStringExpressionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/AggregateStringExpressionTest.scala
index 4bbb101..4e7270f 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/AggregateStringExpressionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/AggregateStringExpressionTest.scala
@@ -28,6 +28,19 @@ import org.junit._
 class AggregateStringExpressionTest extends TableTestBase {
 
   @Test
+  def testDistinctAggregationTypes(): Unit = {
+    val util = batchTestUtil()
+    val t = util.addTable[(Int, Long, String)]("Table3")
+
+    val t1 = t.select('_1.sum.distinct, '_1.count.distinct, '_1.avg.distinct)
+    val t2 = t.select("_1.sum.distinct, _1.count.distinct, _1.avg.distinct")
+    val t3 = t.select("sum.distinct(_1), count.distinct(_1), avg.distinct(_1)")
+
+    verifyTableEquals(t1, t2)
+    verifyTableEquals(t1, t3)
+  }
+
+  @Test
   def testAggregationTypes(): Unit = {
     val util = batchTestUtil()
     val t = util.addTable[(Int, Long, String)]("Table3")
@@ -119,6 +132,19 @@ class AggregateStringExpressionTest extends TableTestBase {
   }
 
   @Test
+  def testDistinctGroupedAggregate(): Unit = {
+    val util = batchTestUtil()
+    val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+
+    val t1 = t.groupBy('b).select('b, 'a.sum.distinct, 'a.sum)
+    val t2 = t.groupBy("b").select("b, a.sum.distinct, a.sum")
+    val t3 = t.groupBy("b").select("b, sum.distinct(a), sum(a)")
+
+    verifyTableEquals(t1, t2)
+    verifyTableEquals(t1, t3)
+  }
+
+  @Test
   def testGroupedAggregate(): Unit = {
     val util = batchTestUtil()
     val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
@@ -239,6 +265,22 @@ class AggregateStringExpressionTest extends TableTestBase {
   }
 
   @Test
+  def testDistinctAggregateWithUDAGG(): Unit = {
+    val util = batchTestUtil()
+    val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+
+    val myCnt = new CountAggFunction
+    util.tableEnv.registerFunction("myCnt", myCnt)
+    val myWeightedAvg = new WeightedAvgWithMergeAndReset
+    util.tableEnv.registerFunction("myWeightedAvg", myWeightedAvg)
+
+    val t1 = t.select(myCnt.distinct('a) as 'aCnt, myWeightedAvg.distinct('b, 'a) as 'wAvg)
+    val t2 = t.select("myCnt.distinct(a) as aCnt, myWeightedAvg.distinct(b, a) as wAvg")
+
+    verifyTableEquals(t1, t2)
+  }
+
+  @Test
   def testAggregateWithUDAGG(): Unit = {
     val util = batchTestUtil()
     val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
@@ -255,6 +297,30 @@ class AggregateStringExpressionTest extends TableTestBase {
   }
 
   @Test
+  def testDistinctGroupedAggregateWithUDAGG(): Unit = {
+    val util = batchTestUtil()
+    val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+
+
+    val myCnt = new CountAggFunction
+    util.tableEnv.registerFunction("myCnt", myCnt)
+    val myWeightedAvg = new WeightedAvgWithMergeAndReset
+    util.tableEnv.registerFunction("myWeightedAvg", myWeightedAvg)
+
+    val t1 = t.groupBy('b)
+      .select('b,
+        myCnt.distinct('a) + 9 as 'aCnt,
+        myWeightedAvg.distinct('b, 'a) * 2 as 'wAvg,
+        myWeightedAvg.distinct('a, 'a) as 'distAgg,
+        myWeightedAvg('a, 'a) as 'agg)
+    val t2 = t.groupBy("b")
+      .select("b, myCnt.distinct(a) + 9 as aCnt, myWeightedAvg.distinct(b, a) * 2 as wAvg, " +
+        "myWeightedAvg.distinct(a, a) as distAgg, myWeightedAvg(a, a) as agg")
+
+    verifyTableEquals(t1, t2)
+  }
+
+  @Test
   def testGroupedAggregateWithUDAGG(): Unit = {
     val util = batchTestUtil()
     val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala
index 533235a..671f8dd 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala
@@ -21,6 +21,7 @@ package org.apache.flink.table.api.stream.table
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo
 import org.apache.flink.api.scala._
 import org.apache.flink.table.api.scala._
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.WeightedAvg
 import org.apache.flink.table.utils.TableTestUtil._
 import org.apache.flink.table.utils.TableTestBase
 import org.junit.Test
@@ -28,6 +29,57 @@ import org.junit.Test
 class AggregateTest extends TableTestBase {
 
   @Test
+  def testGroupDistinctAggregate(): Unit = {
+    val util = streamTestUtil()
+    val table = util.addTable[(Long, Int, String)]('a, 'b, 'c)
+
+    val resultTable = table
+      .groupBy('b)
+      .select('a.sum.distinct, 'c.count.distinct)
+
+    val expected =
+      unaryNode(
+        "DataStreamCalc",
+        unaryNode(
+          "DataStreamGroupAggregate",
+          streamTableNode(0),
+          term("groupBy", "b"),
+          term("select", "b", "SUM(DISTINCT a) AS TMP_0", "COUNT(DISTINCT c) AS TMP_1")
+        ),
+        term("select", "TMP_0", "TMP_1")
+      )
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testGroupDistinctAggregateWithUDAGG(): Unit = {
+    val util = streamTestUtil()
+    val table = util.addTable[(Long, Int, String)]('a, 'b, 'c)
+    val weightedAvg = new WeightedAvg
+
+    val resultTable = table
+      .groupBy('c)
+      .select(weightedAvg.distinct('a, 'b), weightedAvg('a, 'b))
+
+    val expected =
+      unaryNode(
+        "DataStreamCalc",
+        unaryNode(
+          "DataStreamGroupAggregate",
+          streamTableNode(0),
+          term("groupBy", "c"),
+          term(
+            "select",
+            "c",
+            "WeightedAvg(DISTINCT a, b) AS TMP_0",
+            "WeightedAvg(a, b) AS TMP_1")
+        ),
+        term("select", "TMP_0", "TMP_1")
+      )
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test
   def testGroupAggregate() = {
     val util = streamTestUtil()
     val table = util.addTable[(Long, Int, String)]('a, 'b, 'c)
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala
index 2bef95e..ec57436 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/AggregateStringExpressionTest.scala
@@ -20,12 +20,80 @@ package org.apache.flink.table.api.stream.table.stringexpr
 
 import org.apache.flink.api.scala._
 import org.apache.flink.table.api.scala._
-import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.WeightedAvg
+import org.apache.flink.table.functions.aggfunctions.CountAggFunction
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{WeightedAvg, WeightedAvgWithMergeAndReset}
 import org.apache.flink.table.utils.TableTestBase
 import org.junit.Test
 
 class AggregateStringExpressionTest extends TableTestBase {
 
+
+  @Test
+  def testDistinctNonGroupedAggregate(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Int, Long, String)]("Table3")
+
+    val t1 = t.select('_1.sum.distinct, '_1.count.distinct, '_1.avg.distinct)
+    val t2 = t.select("_1.sum.distinct, _1.count.distinct, _1.avg.distinct")
+    val t3 = t.select("sum.distinct(_1), count.distinct(_1), avg.distinct(_1)")
+
+    verifyTableEquals(t1, t2)
+    verifyTableEquals(t1, t3)
+  }
+
+  @Test
+  def testDistinctGroupedAggregate(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+
+    val t1 = t.groupBy('b).select('b, 'a.sum.distinct, 'a.sum)
+    val t2 = t.groupBy("b").select("b, a.sum.distinct, a.sum")
+    val t3 = t.groupBy("b").select("b, sum.distinct(a), sum(a)")
+
+    verifyTableEquals(t1, t2)
+    verifyTableEquals(t1, t3)
+  }
+
+  @Test
+  def testDistinctNonGroupAggregateWithUDAGG(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+
+    val myCnt = new CountAggFunction
+    util.tableEnv.registerFunction("myCnt", myCnt)
+    val myWeightedAvg = new WeightedAvgWithMergeAndReset
+    util.tableEnv.registerFunction("myWeightedAvg", myWeightedAvg)
+
+    val t1 = t.select(myCnt.distinct('a) as 'aCnt, myWeightedAvg.distinct('b, 'a) as 'wAvg)
+    val t2 = t.select("myCnt.distinct(a) as aCnt, myWeightedAvg.distinct(b, a) as wAvg")
+
+    verifyTableEquals(t1, t2)
+  }
+
+  @Test
+  def testDistinctGroupedAggregateWithUDAGG(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+
+
+    val myCnt = new CountAggFunction
+    util.tableEnv.registerFunction("myCnt", myCnt)
+    val myWeightedAvg = new WeightedAvgWithMergeAndReset
+    util.tableEnv.registerFunction("myWeightedAvg", myWeightedAvg)
+
+    val t1 = t.groupBy('b)
+      .select('b,
+        myCnt.distinct('a) + 9 as 'aCnt,
+        myWeightedAvg.distinct('b, 'a) * 2 as 'wAvg,
+        myWeightedAvg.distinct('a, 'a) as 'distAgg,
+        myWeightedAvg('a, 'a) as 'agg)
+    val t2 = t.groupBy("b")
+      .select("b, myCnt.distinct(a) + 9 as aCnt, myWeightedAvg.distinct(b, a) * 2 as wAvg, " +
+        "myWeightedAvg.distinct(a, a) as distAgg, myWeightedAvg(a, a) as agg")
+
+    verifyTableEquals(t1, t2)
+  }
+
   @Test
   def testGroupedAggregate(): Unit = {
     val util = streamTestUtil()
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
index db6820a..219b765 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
@@ -25,7 +25,7 @@ import org.apache.flink.table.api.scala._
 import org.apache.flink.table.runtime.utils.StreamITCase.RetractingSink
 import org.apache.flink.table.api.{StreamQueryConfig, TableEnvironment, Types}
 import org.apache.flink.table.expressions.Null
-import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinct, DataViewTestAgg}
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinct, DataViewTestAgg, WeightedAvg}
 import org.apache.flink.table.runtime.utils.{JavaUserDefinedAggFunctions, StreamITCase, StreamTestData, StreamingWithStateTestBase}
 import org.apache.flink.types.Row
 import org.junit.Assert.assertEquals
@@ -41,6 +41,84 @@ class AggregateITCase extends StreamingWithStateTestBase {
   queryConfig.withIdleStateRetentionTime(Time.hours(1), Time.hours(2))
 
   @Test
+  def testDistinctUDAGG(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val testAgg = new DataViewTestAgg
+    val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e)
+      .groupBy('e)
+      .select('e, testAgg.distinct('d, 'e))
+
+    val results = t.toRetractStream[Row](queryConfig)
+    results.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+    env.execute()
+
+    val expected = mutable.MutableList("1,10", "2,21", "3,12")
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+  }
+
+  @Test
+  def testDistinctUDAGGMixedWithNonDistinctUsage(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val testAgg = new WeightedAvg
+    val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e)
+      .groupBy('e)
+      .select('e, testAgg.distinct('a, 'a), testAgg('a, 'a))
+
+    val results = t.toRetractStream[Row](queryConfig)
+    results.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+    env.execute()
+
+    val expected = mutable.MutableList("1,3,3", "2,3,4", "3,4,4")
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+  }
+
+  @Test
+  def testDistinctAggregate(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e)
+      .groupBy('e)
+      .select('e, 'a.count.distinct)
+
+    val results = t.toRetractStream[Row](queryConfig)
+    results.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+    env.execute()
+
+    val expected = mutable.MutableList("1,4", "2,4", "3,2")
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+  }
+
+  @Test
+  def testDistinctAggregateMixedWithNonDistinct(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e)
+      .groupBy('e)
+      .select('e, 'a.count.distinct, 'b.count)
+
+    val results = t.toRetractStream[Row](queryConfig)
+    results.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+    env.execute()
+
+    val expected = mutable.MutableList("1,4,5", "2,4,7", "3,2,3")
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+  }
+
+  @Test
   def testDistinct(): Unit = {
     val env = StreamExecutionEnvironment.getExecutionEnvironment
     env.setStateBackend(getStateBackend)