Posted to commits@flink.apache.org by fh...@apache.org on 2018/04/26 21:22:02 UTC
flink git commit: [FLINK-8689] [table] Add support for DISTINCT aggregates in OVER windows.
Repository: flink
Updated Branches:
refs/heads/master 3ac282322 -> 6aef014a3
[FLINK-8689] [table] Add support for DISTINCT aggregates in OVER windows.
- Distinct values are tracked in a MapView that is stored either on the heap or in a state backend, depending on the context.
This closes #5555.
This closes #3783 (PR was superseded by #5555)
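
For context, a minimal sketch of the kind of query this change enables (the table, columns, and the tEnv/StreamTableEnvironment setup are illustrative, mirroring the tests added below):

    val sql =
      "SELECT b, " +
      "COUNT(DISTINCT a) OVER (PARTITION BY b ORDER BY proctime " +
      "ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS cnt " +
      "FROM MyTable"
    // distinct values per partition are tracked in a MapView behind the scenes
    val result = tEnv.sqlQuery(sql)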
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/6aef014a
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/6aef014a
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/6aef014a
Branch: refs/heads/master
Commit: 6aef014a3e27dd1a93b463dd47ea0b6ec339597a
Parents: 3ac2823
Author: Rong Rong <ro...@uber.com>
Authored: Wed Feb 21 13:40:29 2018 -0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Thu Apr 26 23:15:56 2018 +0200
----------------------------------------------------------------------
.../codegen/AggregationCodeGenerator.scala | 420 +++++++++++++------
.../aggfunctions/DistinctAccumulator.scala | 121 ++++++
.../flink/table/plan/nodes/OverAggregate.scala | 9 +-
.../table/runtime/aggregate/AggregateUtil.scala | 91 +++-
.../table/api/stream/sql/OverWindowTest.scala | 67 +++
.../runtime/stream/sql/OverWindowITCase.scala | 228 ++++++++++
6 files changed, 801 insertions(+), 135 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/6aef014a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
index a9ec112..d6a7b1a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
@@ -22,17 +22,21 @@ import java.lang.{Iterable => JIterable}
import org.apache.calcite.rex.RexLiteral
import org.apache.commons.codec.binary.Base64
-import org.apache.flink.api.common.state.{State, StateDescriptor}
-import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.common.state.{ListStateDescriptor, MapStateDescriptor, State, StateDescriptor}
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.api.java.typeutils.TypeExtractionUtils.{extractTypeArgument, getRawClass}
import org.apache.flink.table.api.TableConfig
import org.apache.flink.table.api.dataview._
import org.apache.flink.table.codegen.CodeGenUtils.{newName, reflectiveFieldWriteAccess}
import org.apache.flink.table.codegen.Indenter.toISC
+import org.apache.flink.table.dataview.{MapViewTypeInfo, StateListView, StateMapView}
import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.aggfunctions.DistinctAccumulator
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getUserDefinedMethod, signatureToString}
import org.apache.flink.table.runtime.aggregate.{GeneratedAggregations, SingleElementIterable}
+import org.apache.flink.types.Row
import org.apache.flink.util.InstantiationUtil
import scala.collection.mutable
@@ -74,6 +78,8 @@ class AggregationCodeGenerator(
* @param aggregates All aggregate functions
* @param aggFields Indexes of the input fields for all aggregate functions
* @param aggMapping The mapping of aggregates to output fields
+ * @param isDistinctAggs Flags indicating for each aggregate whether it is a DISTINCT aggregate.
+ * @param isStateBackedDataViews A flag indicating whether the distinct filters use state-backed data views.
* @param partialResults A flag defining whether final or partial results (accumulators) are set
* to the output row.
* @param fwdMapping The mapping of input fields to output fields
@@ -93,6 +99,8 @@ class AggregationCodeGenerator(
aggregates: Array[AggregateFunction[_ <: Any, _ <: Any]],
aggFields: Array[Array[Int]],
aggMapping: Array[Int],
+ isDistinctAggs: Array[Boolean],
+ isStateBackedDataViews: Boolean,
partialResults: Boolean,
fwdMapping: Array[Int],
mergeMapping: Option[Array[Int]],
@@ -151,8 +159,40 @@ class AggregationCodeGenerator(
}
}
- // initialize and create data views
- addReusableDataViews()
+ // qualified class name of the DistinctAccumulator wrapper used for distinct aggregates
+ val distinctAccType = s"${classOf[DistinctAccumulator[_]].getName}"
+
+ // preparing MapViewSpecs for distinct value maps
+ val distinctAggs: Array[Seq[DataViewSpec[_]]] = isDistinctAggs.zipWithIndex.map {
+ case (isDistinctAgg, idx) if isDistinctAgg =>
+
+ // get types of agg function arguments
+ val argTypes: Array[TypeInformation[_]] = aggFields(idx)
+ .map(physicalInputTypes(_))
+ // create type for MapView
+ val mapViewTypeInfo = new MapViewTypeInfo(
+ new RowTypeInfo(argTypes:_*),
+ BasicTypeInfo.LONG_TYPE_INFO)
+ // create MapViewSpec for distinct value map
+ Seq(
+ MapViewSpec(
+ "distinctAgg" + idx,
+ classOf[DistinctAccumulator[_]].getDeclaredField("distinctValueMap"),
+ mapViewTypeInfo)
+ )
+ case _ => Seq()
+ }
+
+ if (isDistinctAggs.contains(true) && partialResults && isStateBackedDataViews) {
+ // This should not happen; fail with a descriptive error just in case.
+ throw new CodeGenException(
+ s"Cannot emit partial results if DISTINCT values are tracked in state-backed maps. " +
+ s"Please report this bug."
+ )
+ }
+
+ // initialize and create data views for accumulators & distinct filters
+ addAccumulatorDataViews()
// check and validate the needed methods
aggregates.zipWithIndex.map {
@@ -208,6 +248,48 @@ class AggregationCodeGenerator(
}
/**
+ * Adds all data views for the accumulator fields and distinct filters defined by the
+ * aggregation functions.
+ */
+ def addAccumulatorDataViews(): Unit = {
+ if (isStateBackedDataViews) {
+ // create MapStates for distinct value maps
+ val descMapping: Map[String, StateDescriptor[_, _]] = distinctAggs
+ .flatMap(specs => specs.map(s => (s.stateId, s.toStateDescriptor)))
+ .toMap[String, StateDescriptor[_ <: State, _]]
+
+ for (i <- aggs.indices) yield {
+ for (spec <- distinctAggs(i)) {
+ // Check if the state descriptor exists.
+ val desc: StateDescriptor[_, _] = descMapping.getOrElse(spec.stateId,
+ throw new CodeGenException(
+ s"Can not find DataView for distinct filter in accumulator by id: ${spec.stateId}"))
+
+ addReusableDataView(spec, desc, i)
+ }
+ }
+ }
+
+ if (accConfig.isDefined) {
+ // create state handles for DataView backed accumulator fields.
+ val descMapping: Map[String, StateDescriptor[_, _]] = accConfig.get
+ .flatMap(specs => specs.map(s => (s.stateId, s.toStateDescriptor)))
+ .toMap[String, StateDescriptor[_ <: State, _]]
+
+ for (i <- aggs.indices) yield {
+ for (spec <- accConfig.get(i)) yield {
+ // Check if the state descriptor exists.
+ val desc: StateDescriptor[_, _] = descMapping.getOrElse(spec.stateId,
+ throw new CodeGenException(
+ s"Can not find DataView in accumulator by id: ${spec.stateId}"))
+
+ addReusableDataView(spec, desc, i)
+ }
+ }
+ }
+ }
+
+ /**
* Create DataView Term, for example, acc1_map_dataview.
*
* @param aggIndex index of aggregate function
@@ -221,98 +303,106 @@ class AggregationCodeGenerator(
/**
* Adds a reusable [[org.apache.flink.table.api.dataview.DataView]] to the open, cleanup,
* close and member area of the generated function.
- *
+ * @param spec the [[DataViewSpec]] of the desired data view term.
+ * @param desc the [[StateDescriptor]] of the desired data view term.
+ * @param aggIndex the aggregation function index associated with the data view.
*/
- def addReusableDataViews(): Unit = {
- if (accConfig.isDefined) {
- val descMapping: Map[String, StateDescriptor[_, _]] = accConfig.get
- .flatMap(specs => specs.map(s => (s.stateId, s.toStateDescriptor)))
- .toMap[String, StateDescriptor[_ <: State, _]]
+ def addReusableDataView(
+ spec: DataViewSpec[_],
+ desc: StateDescriptor[_, _],
+ aggIndex: Int): Unit = {
+ val dataViewField = spec.field
+ val dataViewTypeTerm = dataViewField.getType.getCanonicalName
+
+ // define the DataView variables
+ val serializedData = serializeStateDescriptor(desc)
+ val dataViewFieldTerm = createDataViewTerm(aggIndex, dataViewField.getName)
+ val field =
+ s"""
+ | final $dataViewTypeTerm $dataViewFieldTerm;
+ |""".stripMargin
+ reusableMemberStatements.add(field)
+
+ // create DataViews
+ val descFieldTerm = s"${dataViewFieldTerm}_desc"
+ val descClassQualifier = classOf[StateDescriptor[_, _]].getCanonicalName
+ val descDeserializeCode =
+ s"""
+ | $descClassQualifier $descFieldTerm = ($descClassQualifier)
+ | org.apache.flink.util.InstantiationUtil.deserializeObject(
+ | org.apache.commons.codec.binary.Base64.decodeBase64("$serializedData"),
+ | $contextTerm.getUserCodeClassLoader());
+ |""".stripMargin
+ val createDataView = if (dataViewField.getType == classOf[MapView[_, _]]) {
+ s"""
+ | $descDeserializeCode
+ | $dataViewFieldTerm = new ${classOf[StateMapView[_, _]].getCanonicalName}(
+ | $contextTerm.getMapState(
+ | (${classOf[MapStateDescriptor[_, _]].getCanonicalName}) $descFieldTerm));
+ |""".stripMargin
+ } else if (dataViewField.getType == classOf[ListView[_]]) {
+ s"""
+ | $descDeserializeCode
+ | $dataViewFieldTerm = new ${classOf[StateListView[_]].getCanonicalName}(
+ | $contextTerm.getListState(
+ | (${classOf[ListStateDescriptor[_]].getCanonicalName}) $descFieldTerm));
+ |""".stripMargin
+ } else {
+ throw new CodeGenException(s"Unsupported dataview type: $dataViewTypeTerm")
+ }
+ reusableOpenStatements.add(createDataView)
+
+ // cleanup DataViews
+ val cleanup =
+ s"""
+ | $dataViewFieldTerm.clear();
+ |""".stripMargin
+ reusableCleanupStatements.add(cleanup)
+ }
- for (i <- aggs.indices) yield {
- for (spec <- accConfig.get(i)) yield {
- val dataViewField = spec.field
- val dataViewTypeTerm = dataViewField.getType.getCanonicalName
- val desc = descMapping.getOrElse(spec.stateId,
- throw new CodeGenException(
- s"Can not find DataView in accumulator by id: ${spec.stateId}"))
+ def genDistinctDataViewFieldSetter(str: String, i: Int): String = {
+ if (isStateBackedDataViews && distinctAggs(i).nonEmpty) {
+ genDataViewFieldSetter(distinctAggs(i), str, i)
+ } else {
+ ""
+ }
+ }
- // define the DataView variables
- val serializedData = serializeStateDescriptor(desc)
- val dataViewFieldTerm = createDataViewTerm(i, dataViewField.getName)
- val field =
- s"""
- | final $dataViewTypeTerm $dataViewFieldTerm;
- |""".stripMargin
- reusableMemberStatements.add(field)
-
- // create DataViews
- val descFieldTerm = s"${dataViewFieldTerm}_desc"
- val descClassQualifier = classOf[StateDescriptor[_, _]].getCanonicalName
- val descDeserializeCode =
- s"""
- | $descClassQualifier $descFieldTerm = ($descClassQualifier)
- | org.apache.flink.util.InstantiationUtil.deserializeObject(
- | org.apache.commons.codec.binary.Base64.decodeBase64("$serializedData"),
- | $contextTerm.getUserCodeClassLoader());
- |""".stripMargin
- val createDataView = if (dataViewField.getType == classOf[MapView[_, _]]) {
- s"""
- | $descDeserializeCode
- | $dataViewFieldTerm = new org.apache.flink.table.dataview.StateMapView(
- | $contextTerm.getMapState((
- | org.apache.flink.api.common.state.MapStateDescriptor)$descFieldTerm));
- |""".stripMargin
- } else if (dataViewField.getType == classOf[ListView[_]]) {
- s"""
- | $descDeserializeCode
- | $dataViewFieldTerm = new org.apache.flink.table.dataview.StateListView(
- | $contextTerm.getListState((
- | org.apache.flink.api.common.state.ListStateDescriptor)$descFieldTerm));
- |""".stripMargin
- } else {
- throw new CodeGenException(s"Unsupported dataview type: $dataViewTypeTerm")
- }
- reusableOpenStatements.add(createDataView)
-
- // cleanup DataViews
- val cleanup =
- s"""
- | $dataViewFieldTerm.clear();
- |""".stripMargin
- reusableCleanupStatements.add(cleanup)
- }
- }
+ def genAccDataViewFieldSetter(str: String, i: Int): String = {
+ if (accConfig.isDefined) {
+ genDataViewFieldSetter(accConfig.get(i), str, i)
+ } else {
+ ""
}
}
/**
* Generates statements to set the data view fields when a state backend is used.
*
+ * @param specs aggregation [[DataViewSpec]]s for this aggregation term.
* @param accTerm aggregation term
* @param aggIndex index of aggregation
* @return data view field set statements
*/
- def genDataViewFieldSetter(accTerm: String, aggIndex: Int): String = {
- if (accConfig.isDefined) {
- val setters = for (spec <- accConfig.get(aggIndex)) yield {
- val field = spec.field
- val dataViewTerm = createDataViewTerm(aggIndex, field.getName)
- val fieldSetter = if (Modifier.isPublic(field.getModifiers)) {
- s"$accTerm.${field.getName} = $dataViewTerm;"
- } else {
- val fieldTerm = addReusablePrivateFieldAccess(field.getDeclaringClass, field.getName)
- s"${reflectiveFieldWriteAccess(fieldTerm, field, accTerm, dataViewTerm)};"
- }
+ def genDataViewFieldSetter(
+ specs: Seq[DataViewSpec[_]],
+ accTerm: String,
+ aggIndex: Int): String = {
+ val setters = for (spec <- specs) yield {
+ val field = spec.field
+ val dataViewTerm = createDataViewTerm(aggIndex, field.getName)
+ val fieldSetter = if (Modifier.isPublic(field.getModifiers)) {
+ s"$accTerm.${field.getName} = $dataViewTerm;"
+ } else {
+ val fieldTerm = addReusablePrivateFieldAccess(field.getDeclaringClass, field.getName)
+ s"${reflectiveFieldWriteAccess(fieldTerm, field, accTerm, dataViewTerm)};"
+ }
- s"""
- | $fieldSetter
+ s"""
+ | $fieldSetter
""".stripMargin
- }
- setters.mkString("\n")
- } else {
- ""
}
+ setters.mkString("\n")
}
def genSetAggregationResults: String = {
@@ -332,15 +422,30 @@ class AggregationCodeGenerator(
| ${aggMapping(i)},
| (${accTypes(i)}) accs.getField($i));""".stripMargin
} else {
- j"""
- | org.apache.flink.table.functions.AggregateFunction baseClass$i =
- | (org.apache.flink.table.functions.AggregateFunction) ${aggs(i)};
- | ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
- | ${genDataViewFieldSetter(s"acc$i", i)}
- | output.setField(
- | ${aggMapping(i)},
- | baseClass$i.getValue(acc$i));""".stripMargin
- }
+ val setAccOutput =
+ j"""
+ | ${genAccDataViewFieldSetter(s"acc$i", i)}
+ | output.setField(
+ | ${aggMapping(i)},
+ | baseClass$i.getValue(acc$i));
+ """.stripMargin
+ if (isDistinctAggs(i)) {
+ j"""
+ | org.apache.flink.table.functions.AggregateFunction baseClass$i =
+ | (org.apache.flink.table.functions.AggregateFunction) ${aggs(i)};
+ | $distinctAccType distinctAcc$i = ($distinctAccType) accs.getField($i);
+ | ${accTypes(i)} acc$i = (${accTypes(i)}) distinctAcc$i.getRealAcc();
+ | $setAccOutput
+ """.stripMargin
+ } else {
+ j"""
+ | org.apache.flink.table.functions.AggregateFunction baseClass$i =
+ | (org.apache.flink.table.functions.AggregateFunction) ${aggs(i)};
+ | ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
+ | $setAccOutput
+ """.stripMargin
+ }
+ }
}.mkString("\n")
j"""
@@ -359,12 +464,28 @@ class AggregationCodeGenerator(
val accumulate: String = {
for (i <- aggs.indices) yield {
- j"""
- | ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
- | ${genDataViewFieldSetter(s"acc$i", i)}
- | ${aggs(i)}.accumulate(
- | acc$i ${if (!parametersCode(i).isEmpty) "," else "" } ${parametersCode(i)});
- """.stripMargin
+ val accumulateAcc =
+ j"""
+ | ${genAccDataViewFieldSetter(s"acc$i", i)}
+ | ${aggs(i)}.accumulate(acc$i
+ | ${if (!parametersCode(i).isEmpty) "," else ""} ${parametersCode(i)});
+ """.stripMargin
+ if (isDistinctAggs(i)) {
+ j"""
+ | $distinctAccType distinctAcc$i = ($distinctAccType) accs.getField($i);
+ | ${genDistinctDataViewFieldSetter(s"distinctAcc$i", i)}
+ | if (distinctAcc$i.add(
+ | ${classOf[Row].getCanonicalName}.of(${parametersCode(i)}))) {
+ | ${accTypes(i)} acc$i = (${accTypes(i)}) distinctAcc$i.getRealAcc();
+ | $accumulateAcc
+ | }
+ """.stripMargin
+ } else {
+ j"""
+ | ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
+ | $accumulateAcc
+ """.stripMargin
+ }
}
}.mkString("\n")
@@ -383,12 +504,28 @@ class AggregationCodeGenerator(
val retract: String = {
for (i <- aggs.indices) yield {
- j"""
- | ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
- | ${genDataViewFieldSetter(s"acc$i", i)}
- | ${aggs(i)}.retract(
- | acc$i ${if (!parametersCode(i).isEmpty) "," else "" } ${parametersCode(i)});
- """.stripMargin
+ val retractAcc =
+ j"""
+ | ${genAccDataViewFieldSetter(s"acc$i", i)}
+ | ${aggs(i)}.retract(
+ | acc$i ${if (!parametersCode(i).isEmpty) "," else ""} ${parametersCode(i)});
+ """.stripMargin
+ if (isDistinctAggs(i)) {
+ j"""
+ | $distinctAccType distinctAcc$i = ($distinctAccType) accs.getField($i);
+ | ${genDistinctDataViewFieldSetter(s"distinctAcc$i", i)}
+ | if (distinctAcc$i.remove(
+ | ${classOf[Row].getCanonicalName}.of(${parametersCode(i)}))) {
+ | ${accTypes(i)} acc$i = (${accTypes(i)}) distinctAcc$i.getRealAcc();
+ | $retractAcc
+ | }
+ """.stripMargin
+ } else {
+ j"""
+ | ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
+ | $retractAcc
+ """.stripMargin
+ }
}
}.mkString("\n")
@@ -417,13 +554,23 @@ class AggregationCodeGenerator(
.stripMargin
val create: String = {
for (i <- aggs.indices) yield {
- j"""
- | ${accTypes(i)} acc$i = (${accTypes(i)}) ${aggs(i)}.createAccumulator();
- | ${genDataViewFieldSetter(s"acc$i", i)}
- | accs.setField(
- | $i,
- | acc$i);"""
- .stripMargin
+ if (isDistinctAggs(i)) {
+ j"""
+ | ${accTypes(i)} acc$i = (${accTypes(i)}) ${aggs(i)}.createAccumulator();
+ | $distinctAccType distinctAcc$i = ($distinctAccType)
+ | new ${classOf[DistinctAccumulator[_]].getCanonicalName} (acc$i);
+ | accs.setField(
+ | $i,
+ | distinctAcc$i);"""
+ .stripMargin
+ } else {
+ j"""
+ | ${accTypes(i)} acc$i = (${accTypes(i)}) ${aggs(i)}.createAccumulator();
+ | accs.setField(
+ | $i,
+ | acc$i);"""
+ .stripMargin
+ }
}
}.mkString("\n")
val ret: String =
@@ -481,14 +628,35 @@ class AggregationCodeGenerator(
| org.apache.flink.types.Row b)
""".stripMargin
val merge: String = {
- for (i <- aggs.indices) yield
- j"""
- | ${accTypes(i)} aAcc$i = (${accTypes(i)}) a.getField($i);
- | ${accTypes(i)} bAcc$i = (${accTypes(i)}) b.getField(${mapping(i)});
- | accIt$i.setElement(bAcc$i);
- | ${aggs(i)}.merge(aAcc$i, accIt$i);
- | a.setField($i, aAcc$i);
- """.stripMargin
+ for (i <- aggs.indices) yield {
+ if (isDistinctAggs(i)) {
+ j"""
+ | $distinctAccType aDistinctAcc$i = ($distinctAccType) a.getField($i);
+ | $distinctAccType bDistinctAcc$i = ($distinctAccType) b.getField(${mapping(i)});
+ | java.util.Iterator<java.util.Map.Entry> mergeIt$i =
+ | bDistinctAcc$i.elements().iterator();
+ | ${accTypes(i)} aAcc$i = (${accTypes(i)}) aDistinctAcc$i.getRealAcc();
+ |
+ | while (mergeIt$i.hasNext()) {
+ | java.util.Map.Entry entry = (java.util.Map.Entry) mergeIt$i.next();
+ | Object k = entry.getKey();
+ | Long v = (Long) entry.getValue();
+ | if (aDistinctAcc$i.add(k, v)) {
+ | ${aggs(i)}.accumulate(aAcc$i, k);
+ | }
+ | }
+ | a.setField($i, aDistinctAcc$i);
+ """.stripMargin
+ } else {
+ j"""
+ | ${accTypes(i)} aAcc$i = (${accTypes(i)}) a.getField($i);
+ | ${accTypes(i)} bAcc$i = (${accTypes(i)}) b.getField(${mapping(i)});
+ | accIt$i.setElement(bAcc$i);
+ | ${aggs(i)}.merge(aAcc$i, accIt$i);
+ | a.setField($i, aAcc$i);
+ """.stripMargin
+ }
+ }
}.mkString("\n")
val ret: String =
j"""
@@ -533,10 +701,20 @@ class AggregationCodeGenerator(
val reset: String = {
for (i <- aggs.indices) yield {
- j"""
- | ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
- | ${genDataViewFieldSetter(s"acc$i", i)}
- | ${aggs(i)}.resetAccumulator(acc$i);""".stripMargin
+ if (isDistinctAggs(i)) {
+ j"""
+ | $distinctAccType distinctAcc$i = ($distinctAccType) accs.getField($i);
+ | ${genDistinctDataViewFieldSetter(s"distinctAcc$i", i)}
+ | ${accTypes(i)} acc$i = (${accTypes(i)}) distinctAcc$i.getRealAcc();
+ | ${genAccDataViewFieldSetter(s"acc$i", i)}
+ | distinctAcc$i.reset();
+ | ${aggs(i)}.resetAccumulator(acc$i);""".stripMargin
+ } else {
+ j"""
+ | ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
+ | ${genAccDataViewFieldSetter(s"acc$i", i)}
+ | ${aggs(i)}.resetAccumulator(acc$i);""".stripMargin
+ }
}
}.mkString("\n")
@@ -561,7 +739,7 @@ class AggregationCodeGenerator(
genResetAccumulator).mkString("\n")
val generatedAggregationsClass = classOf[GeneratedAggregations].getCanonicalName
- var funcCode =
+ val funcCode =
j"""
|public final class $funcName extends $generatedAggregationsClass {
|
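
To summarize the code generation above: for a distinct aggregate, the generated Java wraps the real accumulator in a DistinctAccumulator and forwards accumulate/retract calls only on the first add or last remove of an argument row. A hand-written Scala sketch of that pattern (the function-valued doAccumulate/doRetract parameters stand in for the generated aggX.accumulate/aggX.retract calls):

    import org.apache.flink.table.functions.aggfunctions.DistinctAccumulator
    import org.apache.flink.types.Row

    def accumulateDistinct[ACC](
        distinctAcc: DistinctAccumulator[ACC],
        doAccumulate: (ACC, Row) => Unit,
        args: Row): Unit = {
      // add() reference-counts the argument row and returns true only for the
      // first occurrence, so the real accumulator sees each distinct row once.
      if (distinctAcc.add(args)) {
        doAccumulate(distinctAcc.getRealAcc, args)
      }
    }

    def retractDistinct[ACC](
        distinctAcc: DistinctAccumulator[ACC],
        doRetract: (ACC, Row) => Unit,
        args: Row): Unit = {
      // remove() returns true only when the last instance of the row disappears.
      if (distinctAcc.remove(args)) {
        doRetract(distinctAcc.getRealAcc, args)
      }
    }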
http://git-wip-us.apache.org/repos/asf/flink/blob/6aef014a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/DistinctAccumulator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/DistinctAccumulator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/DistinctAccumulator.scala
new file mode 100644
index 0000000..3427c9c
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/DistinctAccumulator.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions.aggfunctions
+
+import java.lang.{Long => JLong}
+import java.lang.{Iterable => JIterable}
+import java.util.{Map => JMap}
+
+import org.apache.flink.table.api.dataview.MapView
+import org.apache.flink.types.Row
+
+/**
+ * Wraps an accumulator and adds a map to filter distinct values.
+ *
+ * @param realAcc the wrapped accumulator.
+ * @param distinctValueMap the [[MapView]] that maps distinct argument rows to their reference counts.
+ *
+ * @tparam ACC the accumulator type for the realAcc.
+ */
+class DistinctAccumulator[ACC](
+ var realAcc: ACC,
+ var distinctValueMap: MapView[Row, JLong]) {
+
+ def this() {
+ this(null.asInstanceOf[ACC], new MapView[Row, JLong]())
+ }
+
+ def this(realAcc: ACC) {
+ this(realAcc, new MapView[Row, JLong]())
+ }
+
+ def getRealAcc: ACC = realAcc
+
+ def canEqual(a: Any): Boolean = a.isInstanceOf[DistinctAccumulator[ACC]]
+
+ override def equals(that: Any): Boolean =
+ that match {
+ case that: DistinctAccumulator[ACC] => that.canEqual(this) &&
+ this.distinctValueMap == that.distinctValueMap
+ case _ => false
+ }
+
+ /**
+ * Checks if the parameters are unique and adds the parameters to the distinct map.
+ * Returns true if the parameters are unique (haven't been in the map yet), false otherwise.
+ *
+ * @param params the parameters to check.
+ * @return true if the parameters are unique (haven't been in the map yet), false otherwise.
+ */
+ def add(params: Row): Boolean = {
+ val currentCnt = distinctValueMap.get(params)
+ if (currentCnt != null) {
+ distinctValueMap.put(params, currentCnt + 1L)
+ false
+ } else {
+ distinctValueMap.put(params, 1L)
+ true
+ }
+ }
+
+ /**
+ * Adds the given count for the parameters to the distinct map and checks if the
+ * parameters were already present.
+ * Returns true if the parameters are unique (haven't been in the map yet), false otherwise.
+ *
+ * @param params the parameters to check.
+ * @param count the number of instances of the parameters to add.
+ * @return true if the parameters are unique (haven't been in the map yet), false otherwise.
+ */
+ def add(params: Row, count: JLong): Boolean = {
+ val currentCnt = distinctValueMap.get(params)
+ if (currentCnt != null) {
+ distinctValueMap.put(params, currentCnt + count)
+ false
+ } else {
+ distinctValueMap.put(params, count)
+ true
+ }
+ }
+
+ /**
+ * Removes one instance of the parameters from the distinct map and checks if this was the last
+ * instance.
+ * Returns true if no instances of the parameters remain in the map, false otherwise.
+ *
+ * @param params the parameters to check.
+ * @return true if no instances of the parameters remain in the map, false otherwise.
+ */
+ def remove(params: Row): Boolean = {
+ val currentCnt = distinctValueMap.get(params)
+ if (currentCnt == 1) {
+ distinctValueMap.remove(params)
+ true
+ } else {
+ distinctValueMap.put(params, currentCnt - 1L)
+ false
+ }
+ }
+
+ def reset(): Unit = {
+ distinctValueMap.clear()
+ }
+
+ def elements(): JIterable[JMap.Entry[Row, JLong]] = {
+ distinctValueMap.map.entrySet()
+ }
+}
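
A standalone sketch of the reference-counting contract of this class (the single String field is illustrative; the wrapped real accumulator is omitted here):

    import org.apache.flink.types.Row

    val acc = new DistinctAccumulator[AnyRef]()
    val row = Row.of("a")

    acc.add(row)     // true:  first occurrence, the real aggregate should accumulate
    acc.add(row)     // false: duplicate, the real aggregate is skipped
    acc.remove(row)  // false: one instance still remains
    acc.remove(row)  // true:  last instance removed, the real aggregate should retract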
http://git-wip-us.apache.org/repos/asf/flink/blob/6aef014a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala
index f9bf803..8d37c29 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala
@@ -70,7 +70,12 @@ trait OverAggregate {
val aggStrings = namedAggregates.map(_.getKey).map(
a => s"${a.getAggregation}(${
- if (a.getArgList.size() > 0) {
+ val prefix = if (a.isDistinct) {
+ "DISTINCT "
+ } else {
+ ""
+ }
+ prefix + (if (a.getArgList.size() > 0) {
a.getArgList.asScala.map { arg =>
// index to constant
if (arg >= inputType.getFieldCount) {
@@ -83,7 +88,7 @@ trait OverAggregate {
}.mkString(", ")
} else {
"*"
- }
+ })
})")
(inFields ++ aggStrings).zip(outFields).map {
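
Condensed, the term construction above renders the DISTINCT qualifier into the plan string, e.g. COUNT(DISTINCT a) instead of COUNT(a). A runnable sketch of the same string logic (aggName/isDistinct/argNames are illustrative stand-ins for the Calcite AggregateCall fields):

    def aggToString(aggName: String, isDistinct: Boolean, argNames: Seq[String]): String = {
      val prefix = if (isDistinct) "DISTINCT " else ""
      val args = if (argNames.nonEmpty) argNames.mkString(", ") else "*"
      s"$aggName($prefix$args)"
    }

    aggToString("COUNT", isDistinct = true, Seq("a"))   // "COUNT(DISTINCT a)"
    aggToString("COUNT", isDistinct = false, Seq.empty) // "COUNT(*)"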
http://git-wip-us.apache.org/repos/asf/flink/blob/6aef014a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index ce0a9c9..7ce44a6 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -27,7 +27,7 @@ import org.apache.calcite.sql.fun._
import org.apache.calcite.sql.{SqlAggFunction, SqlKind}
import org.apache.flink.api.common.functions.{MapFunction, RichGroupReduceFunction, AggregateFunction => DataStreamAggFunction, _}
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
-import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.java.typeutils.{PojoField, PojoTypeInfo, RowTypeInfo}
import org.apache.flink.streaming.api.functions.ProcessFunction
import org.apache.flink.streaming.api.functions.windowing.{AllWindowFunction, WindowFunction}
import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow}
@@ -36,6 +36,7 @@ import org.apache.flink.table.api.{StreamQueryConfig, TableConfig, TableExceptio
import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.AggregationCodeGenerator
+import org.apache.flink.table.dataview.MapViewTypeInfo
import org.apache.flink.table.expressions.ExpressionUtils.isTimeIntervalLiteral
import org.apache.flink.table.expressions._
import org.apache.flink.table.functions.aggfunctions._
@@ -84,7 +85,7 @@ object AggregateUtil {
isRowsClause: Boolean)
: ProcessFunction[CRow, CRow] = {
- val (aggFields, aggregates, accTypes, accSpecs) =
+ val (aggFields, aggregates, isDistinctAggs, accTypes, accSpecs) =
transformToAggregateFunctions(
namedAggregates.map(_.getKey),
aggregateInputType,
@@ -104,6 +105,8 @@ object AggregateUtil {
aggregates,
aggFields,
aggMapping,
+ isDistinctAggs,
+ isStateBackedDataViews = true,
partialResults = false,
forwardMapping,
None,
@@ -165,7 +168,7 @@ object AggregateUtil {
generateRetraction: Boolean,
consumeRetraction: Boolean): ProcessFunction[CRow, CRow] = {
- val (aggFields, aggregates, accTypes, accSpecs) =
+ val (aggFields, aggregates, isDistinctAggs, accTypes, accSpecs) =
transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputRowType,
@@ -185,6 +188,8 @@ object AggregateUtil {
aggregates,
aggFields,
aggMapping,
+ isDistinctAggs,
+ isStateBackedDataViews = true,
partialResults = false,
groupings,
None,
@@ -233,7 +238,7 @@ object AggregateUtil {
: ProcessFunction[CRow, CRow] = {
val needRetract = true
- val (aggFields, aggregates, accTypes, accSpecs) =
+ val (aggFields, aggregates, isDistinctAggs, accTypes, accSpecs) =
transformToAggregateFunctions(
namedAggregates.map(_.getKey),
aggregateInputType,
@@ -254,6 +259,8 @@ object AggregateUtil {
aggregates,
aggFields,
aggMapping,
+ isDistinctAggs,
+ isStateBackedDataViews = true,
partialResults = false,
forwardMapping,
None,
@@ -336,7 +343,7 @@ object AggregateUtil {
: MapFunction[Row, Row] = {
val needRetract = false
- val (aggFieldIndexes, aggregates, accTypes, _) = transformToAggregateFunctions(
+ val (aggFieldIndexes, aggregates, isDistinctAggs, accTypes, _) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
needRetract,
@@ -387,6 +394,8 @@ object AggregateUtil {
aggregates,
aggFieldIndexes,
aggMapping,
+ isDistinctAggs,
+ isStateBackedDataViews = false,
partialResults = true,
groupings,
None,
@@ -443,7 +452,7 @@ object AggregateUtil {
: RichGroupReduceFunction[Row, Row] = {
val needRetract = false
- val (aggFieldIndexes, aggregates, accTypes, _) = transformToAggregateFunctions(
+ val (aggFieldIndexes, aggregates, isDistinctAggs, accTypes, _) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
physicalInputRowType,
needRetract,
@@ -467,6 +476,8 @@ object AggregateUtil {
aggregates,
aggFieldIndexes,
aggregates.indices.map(_ + groupings.length).toArray,
+ isDistinctAggs,
+ isStateBackedDataViews = false,
partialResults = true,
groupings.indices.toArray,
Some(aggregates.indices.map(_ + groupings.length).toArray),
@@ -558,7 +569,7 @@ object AggregateUtil {
: RichGroupReduceFunction[Row, Row] = {
val needRetract = false
- val (aggFieldIndexes, aggregates, _, _) = transformToAggregateFunctions(
+ val (aggFieldIndexes, aggregates, isDistinctAggs, _, _) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
physicalInputRowType,
needRetract,
@@ -572,6 +583,8 @@ object AggregateUtil {
aggregates,
aggFieldIndexes,
aggMapping,
+ isDistinctAggs,
+ isStateBackedDataViews = false,
partialResults = true,
groupings,
Some(aggregates.indices.map(_ + groupings.length).toArray),
@@ -588,6 +601,8 @@ object AggregateUtil {
aggregates,
aggFieldIndexes,
aggMapping,
+ isDistinctAggs,
+ isStateBackedDataViews = false,
partialResults = false,
groupings.indices.toArray,
Some(aggregates.indices.map(_ + groupings.length).toArray),
@@ -711,7 +726,7 @@ object AggregateUtil {
tableConfig: TableConfig): MapPartitionFunction[Row, Row] = {
val needRetract = false
- val (aggFieldIndexes, aggregates, accTypes, _) = transformToAggregateFunctions(
+ val (aggFieldIndexes, aggregates, isDistinctAggs, accTypes, _) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
physicalInputRowType,
needRetract,
@@ -737,6 +752,8 @@ object AggregateUtil {
aggregates,
aggFieldIndexes,
aggMapping,
+ isDistinctAggs,
+ isStateBackedDataViews = false,
partialResults = true,
groupings.indices.toArray,
Some(aggregates.indices.map(_ + groupings.length).toArray),
@@ -786,7 +803,7 @@ object AggregateUtil {
: GroupCombineFunction[Row, Row] = {
val needRetract = false
- val (aggFieldIndexes, aggregates, accTypes, _) = transformToAggregateFunctions(
+ val (aggFieldIndexes, aggregates, isDistinctAggs, accTypes, _) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
physicalInputRowType,
needRetract,
@@ -813,6 +830,8 @@ object AggregateUtil {
aggregates,
aggFieldIndexes,
aggMapping,
+ isDistinctAggs,
+ isStateBackedDataViews = false,
partialResults = true,
groupings.indices.toArray,
Some(aggregates.indices.map(_ + groupings.length).toArray),
@@ -854,7 +873,7 @@ object AggregateUtil {
Either[DataSetAggFunction, DataSetFinalAggFunction]) = {
val needRetract = false
- val (aggInFields, aggregates, accTypes, _) = transformToAggregateFunctions(
+ val (aggInFields, aggregates, isDistinctAggs, accTypes, _) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
needRetract,
@@ -883,6 +902,8 @@ object AggregateUtil {
aggregates,
aggInFields,
aggregates.indices.map(_ + groupings.length).toArray,
+ isDistinctAggs,
+ isStateBackedDataViews = false,
partialResults = true,
groupings,
None,
@@ -909,6 +930,8 @@ object AggregateUtil {
aggregates,
aggInFields,
aggOutFields,
+ isDistinctAggs,
+ isStateBackedDataViews = false,
partialResults = false,
gkeyMapping,
Some(aggregates.indices.map(_ + groupings.length).toArray),
@@ -932,6 +955,8 @@ object AggregateUtil {
aggregates,
aggInFields,
aggOutFields,
+ isDistinctAggs,
+ isStateBackedDataViews = false,
partialResults = false,
groupings,
None,
@@ -1015,7 +1040,7 @@ object AggregateUtil {
: (DataStreamAggFunction[CRow, Row, Row], RowTypeInfo, RowTypeInfo) = {
val needRetract = false
- val (aggFields, aggregates, accTypes, _) =
+ val (aggFields, aggregates, isDistinctAggs, accTypes, _) =
transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
@@ -1031,6 +1056,8 @@ object AggregateUtil {
aggregates,
aggFields,
aggMapping,
+ isDistinctAggs,
+ isStateBackedDataViews = false,
partialResults = false,
groupingKeys,
None,
@@ -1147,6 +1174,7 @@ object AggregateUtil {
isStateBackedDataViews: Boolean = false)
: (Array[Array[Int]],
Array[TableAggregateFunction[_, _]],
+ Array[Boolean],
Array[TypeInformation[_]],
Array[Seq[DataViewSpec[_]]]) = {
@@ -1439,7 +1467,46 @@ object AggregateUtil {
}
}
- (aggFieldIndexes, aggregates, accTypes, accSpecs)
+ // create distinct accumulator filter argument
+ val isDistinctAggs = new Array[Boolean](aggregateCalls.size)
+
+ aggregateCalls.zipWithIndex.foreach {
+ case (aggCall, index) =>
+ if (aggCall.isDistinct) {
+ // Generate distinct aggregates and the corresponding DistinctAccumulator
+ // wrappers for storing distinct mapping
+ val argList: util.List[Integer] = aggCall.getArgList
+
+ // Using Pojo fields for the real underlying accumulator
+ val pojoFields = new util.ArrayList[PojoField]()
+ pojoFields.add(new PojoField(
+ classOf[DistinctAccumulator[_]].getDeclaredField("realAcc"),
+ accTypes(index))
+ )
+ // If StateBackend is not enabled, the distinct mapping also needs
+ // to be added to the Pojo fields.
+ if (!isStateBackedDataViews) {
+
+ val argTypes: Array[TypeInformation[_]] = argList
+ .map(aggregateInputType.getFieldList.get(_).getType)
+ .map(FlinkTypeFactory.toTypeInfo).toArray
+
+ val mapViewTypeInfo = new MapViewTypeInfo(
+ new RowTypeInfo(argTypes:_*),
+ BasicTypeInfo.LONG_TYPE_INFO)
+ pojoFields.add(new PojoField(
+ classOf[DistinctAccumulator[_]].getDeclaredField("distinctValueMap"),
+ mapViewTypeInfo)
+ )
+ }
+ accTypes(index) = new PojoTypeInfo(classOf[DistinctAccumulator[_]], pojoFields)
+ isDistinctAggs(index) = true
+ } else {
+ isDistinctAggs(index) = false
+ }
+ }
+
+ (aggFieldIndexes, aggregates, isDistinctAggs, accTypes, accSpecs)
}
private def createRowTypeForKeysAndAggregates(
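
For reference, the key/value typing of the distinct map built above, as a standalone sketch (the single INT argument type is illustrative):

    import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
    import org.apache.flink.api.java.typeutils.RowTypeInfo
    import org.apache.flink.table.dataview.MapViewTypeInfo

    // One Row per distinct combination of aggregate arguments, mapped to a
    // Long reference count (maintained by DistinctAccumulator.add/remove).
    val argTypes: Array[TypeInformation[_]] = Array(BasicTypeInfo.INT_TYPE_INFO)
    val mapViewTypeInfo = new MapViewTypeInfo(
      new RowTypeInfo(argTypes: _*),
      BasicTypeInfo.LONG_TYPE_INFO)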
http://git-wip-us.apache.org/repos/asf/flink/blob/6aef014a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/OverWindowTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/OverWindowTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/OverWindowTest.scala
index eea395c..c8257b4 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/OverWindowTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/OverWindowTest.scala
@@ -32,6 +32,73 @@ class OverWindowTest extends TableTestBase {
'proctime.proctime, 'rowtime.rowtime)
@Test
+ def testProctimeBoundedDistinctWithNonDistinctPartitionedRowOver() = {
+ val sql = "SELECT " +
+ "b, " +
+ "count(a) OVER (PARTITION BY b ORDER BY proctime ROWS BETWEEN 2 preceding AND " +
+ "CURRENT ROW) as cnt1, " +
+ "sum(a) OVER (PARTITION BY b ORDER BY proctime ROWS BETWEEN 2 preceding AND " +
+ "CURRENT ROW) as sum1, " +
+ "count(DISTINCT a) OVER (PARTITION BY b ORDER BY proctime ROWS BETWEEN 2 preceding AND " +
+ "CURRENT ROW) as cnt2, " +
+ "sum(DISTINCT c) OVER (PARTITION BY b ORDER BY proctime ROWS BETWEEN 2 preceding AND " +
+ "CURRENT ROW) as sum2 " +
+ "from MyTable"
+
+ val expected =
+ unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamOverAggregate",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "a", "b", "c", "proctime")
+ ),
+ term("partitionBy", "b"),
+ term("orderBy", "proctime"),
+ term("rows", "BETWEEN 2 PRECEDING AND CURRENT ROW"),
+ term("select", "a", "b", "c", "proctime", "COUNT(a) AS w0$o0, $SUM0(a) AS w0$o1, " +
+ "COUNT(DISTINCT a) AS w0$o2, COUNT(DISTINCT c) AS w0$o3, $SUM0(DISTINCT c) AS w0$o4")
+ ),
+ term("select", "b", "w0$o0 AS cnt1, CASE(>(w0$o0, 0), CAST(w0$o1), null) AS sum1, " +
+ "w0$o2 AS cnt2, CASE(>(w0$o3, 0), CAST(w0$o4), null) AS sum2")
+ )
+ streamUtil.verifySql(sql, expected)
+ }
+
+ @Test
+ def testProctimeBoundedDistinctPartitionedRowOver() = {
+ val sql = "SELECT " +
+ "c, " +
+ "count(DISTINCT a) OVER (PARTITION BY c ORDER BY proctime ROWS BETWEEN 2 preceding AND " +
+ "CURRENT ROW) as cnt1, " +
+ "sum(DISTINCT a) OVER (PARTITION BY c ORDER BY proctime ROWS BETWEEN 2 preceding AND " +
+ "CURRENT ROW) as sum1 " +
+ "from MyTable"
+
+ val expected =
+ unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamOverAggregate",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "a", "c", "proctime")
+ ),
+ term("partitionBy", "c"),
+ term("orderBy", "proctime"),
+ term("rows", "BETWEEN 2 PRECEDING AND CURRENT ROW"),
+ term("select", "a", "c", "proctime",
+ "COUNT(DISTINCT a) AS w0$o0, $SUM0(DISTINCT a) AS w0$o1")
+ ),
+ term("select", "c", "w0$o0 AS cnt1, CASE(>(w0$o0, 0), CAST(w0$o1), null) AS sum1")
+ )
+ streamUtil.verifySql(sql, expected)
+ }
+
+ @Test
def testProcTimeBoundedPartitionedRowsOver() = {
val sql = "SELECT " +
"c, " +
http://git-wip-us.apache.org/repos/asf/flink/blob/6aef014a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala
index 411cbb1..d152804 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala
@@ -20,6 +20,7 @@ package org.apache.flink.table.runtime.stream.sql
import org.apache.flink.api.common.time.Time
import org.apache.flink.api.java.tuple.Tuple1
+import org.apache.flink.api.java.tuple.Tuple2
import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.TimeCharacteristic
import org.apache.flink.streaming.api.functions.source.SourceFunction
@@ -852,6 +853,195 @@ class OverWindowITCase extends StreamingWithStateTestBase {
)
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
+
+ @Test
+ def testProcTimeDistinctBoundedPartitionedRowsOver(): Unit = {
+
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ env.setStateBackend(getStateBackend)
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ env.setParallelism(1)
+ StreamITCase.clear
+
+ val t = StreamTestData.get5TupleDataStream(env)
+ .toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime.proctime)
+ tEnv.registerTable("MyTable", t)
+
+ val sqlQuery = "SELECT a, " +
+ " SUM(DISTINCT e) OVER (" +
+ " PARTITION BY a ORDER BY proctime ROWS BETWEEN 3 PRECEDING AND CURRENT ROW), " +
+ " MIN(DISTINCT e) OVER (" +
+ " PARTITION BY a ORDER BY proctime ROWS BETWEEN 3 PRECEDING AND CURRENT ROW), " +
+ " COLLECT(DISTINCT e) OVER (" +
+ " PARTITION BY a ORDER BY proctime ROWS BETWEEN 3 PRECEDING AND CURRENT ROW) " +
+ "FROM MyTable"
+
+ val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
+ result.addSink(new StreamITCase.StringSink[Row])
+ env.execute()
+
+ val expected = List(
+ "1,1,1,{1=1}",
+ "2,2,2,{2=1}",
+ "2,3,1,{1=1, 2=1}",
+ "3,2,2,{2=1}",
+ "3,2,2,{2=1}",
+ "3,5,2,{2=1, 3=1}",
+ "4,2,2,{2=1}",
+ "4,3,1,{1=1, 2=1}",
+ "4,3,1,{1=1, 2=1}",
+ "4,3,1,{1=1, 2=1}",
+ "5,1,1,{1=1}",
+ "5,4,1,{1=1, 3=1}",
+ "5,4,1,{1=1, 3=1}",
+ "5,6,1,{1=1, 2=1, 3=1}",
+ "5,5,2,{2=1, 3=1}")
+ assertEquals(expected, StreamITCase.testResults)
+ }
+
+ @Test
+ def testProcTimeDistinctUnboundedPartitionedRowsOver(): Unit = {
+
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ env.setStateBackend(getStateBackend)
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ env.setParallelism(1)
+ StreamITCase.clear
+
+ val t = StreamTestData.get5TupleDataStream(env)
+ .toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime.proctime)
+ tEnv.registerTable("MyTable", t)
+
+ val sqlQuery = "SELECT a, " +
+ " COUNT(e) OVER (" +
+ " PARTITION BY a ORDER BY proctime RANGE UNBOUNDED preceding), " +
+ " SUM(DISTINCT e) OVER (" +
+ " PARTITION BY a ORDER BY proctime RANGE UNBOUNDED preceding), " +
+ " MIN(DISTINCT e) OVER (" +
+ " PARTITION BY a ORDER BY proctime RANGE UNBOUNDED preceding) " +
+ "FROM MyTable"
+
+ val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
+ result.addSink(new StreamITCase.StringSink[Row])
+ env.execute()
+
+ val expected = List(
+ "1,1,1,1",
+ "2,1,2,2",
+ "2,2,3,1",
+ "3,1,2,2",
+ "3,2,2,2",
+ "3,3,5,2",
+ "4,1,2,2",
+ "4,2,3,1",
+ "4,3,3,1",
+ "4,4,3,1",
+ "5,1,1,1",
+ "5,2,4,1",
+ "5,3,4,1",
+ "5,4,6,1",
+ "5,5,6,1")
+ assertEquals(expected, StreamITCase.testResults)
+ }
+
+ @Test
+ def testProcTimeDistinctUnboundedPartitionedRangeOverWithNullValues(): Unit = {
+ val data = List(
+ (1L, 1, null),
+ (2L, 1, null),
+ (3L, 2, null),
+ (4L, 1, "Hello"),
+ (5L, 1, "Hello"),
+ (6L, 2, "Hello"),
+ (7L, 1, "Hello World"),
+ (8L, 2, "Hello World"),
+ (9L, 2, "Hello World"),
+ (10L, 1, null))
+
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ env.setStateBackend(getStateBackend)
+ env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.clear
+
+ // for sum aggregation ensure that every time the order of each element is consistent
+ env.setParallelism(1)
+
+ val table = env.fromCollection(data)
+ .assignAscendingTimestamps(_._1)
+ .toTable(tEnv, 'a, 'b, 'c, 'rtime.rowtime)
+
+ tEnv.registerTable("MyTable", table)
+ tEnv.registerFunction("CntNullNonNull", new CountNullNonNull)
+
+ val sqlQuery = "SELECT " +
+ " c, " +
+ " b, " +
+ " COUNT(DISTINCT c) " +
+ " OVER (PARTITION BY b ORDER BY rtime RANGE UNBOUNDED preceding), " +
+ " CntNullNonNull(DISTINCT c) " +
+ " OVER (PARTITION BY b ORDER BY rtime RANGE UNBOUNDED preceding)" +
+ "FROM " +
+ " MyTable"
+
+ val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
+ result.addSink(new StreamITCase.StringSink[Row])
+ env.execute()
+
+ val expected = List(
+ "null,1,0,0|1", "null,1,0,0|1", "null,2,0,0|1", "null,1,2,2|1",
+ "Hello,1,1,1|1", "Hello,1,1,1|1", "Hello,2,1,1|1",
+ "Hello World,1,2,2|1", "Hello World,2,2,2|1", "Hello World,2,2,2|1")
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
+
+ @Test
+ def testProcTimeDistinctPairWithNulls(): Unit = {
+
+ val data = List(
+ ("A", null),
+ ("A", null),
+ ("B", null),
+ (null, "Hello"),
+ ("A", "Hello"),
+ ("A", "Hello"),
+ (null, "Hello World"),
+ (null, "Hello World"),
+ ("A", "Hello World"),
+ ("B", "Hello World"))
+
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ env.setStateBackend(getStateBackend)
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ env.setParallelism(1)
+ StreamITCase.clear
+
+ val table = env.fromCollection(data).toTable(tEnv, 'a, 'b, 'proctime.proctime)
+ tEnv.registerTable("MyTable", table)
+ tEnv.registerFunction("PairCount", new CountPairs)
+
+ val sqlQuery = "SELECT a, b, " +
+ " PairCount(a, b) OVER (ORDER BY proctime RANGE UNBOUNDED preceding), " +
+ " PairCount(DISTINCT a, b) OVER (ORDER BY proctime RANGE UNBOUNDED preceding) " +
+ "FROM MyTable"
+
+ val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
+ result.addSink(new StreamITCase.StringSink[Row])
+ env.execute()
+
+ val expected = List(
+ "A,null,1,1",
+ "A,null,2,1",
+ "B,null,3,2",
+ "null,Hello,4,3",
+ "A,Hello,5,4",
+ "A,Hello,6,4",
+ "null,Hello World,7,5",
+ "null,Hello World,8,5",
+ "A,Hello World,9,6",
+ "B,Hello World,10,7")
+ assertEquals(expected, StreamITCase.testResults)
+ }
}
object OverWindowITCase {
@@ -884,3 +1074,41 @@ class LargerThanCount extends AggregateFunction[Long, Tuple1[Long]] {
override def getValue(acc: Tuple1[Long]): Long = acc.f0
}
+
+class CountNullNonNull extends AggregateFunction[String, Tuple2[Long, Long]] {
+
+ override def createAccumulator(): Tuple2[Long, Long] = Tuple2.of(0L, 0L)
+
+ override def getValue(acc: Tuple2[Long, Long]): String = s"${acc.f0}|${acc.f1}"
+
+ def accumulate(acc: Tuple2[Long, Long], v: String): Unit = {
+ if (v == null) {
+ acc.f1 += 1
+ } else {
+ acc.f0 += 1
+ }
+ }
+
+ def retract(acc: Tuple2[Long, Long], v: String): Unit = {
+ if (v == null) {
+ acc.f1 -= 1
+ } else {
+ acc.f0 -= 1
+ }
+ }
+}
+
+class CountPairs extends AggregateFunction[Long, Tuple1[Long]] {
+
+ def accumulate(acc: Tuple1[Long], a: String, b: String): Unit = {
+ acc.f0 += 1
+ }
+
+ def retract(acc: Tuple1[Long], a: String, b: String): Unit = {
+ acc.f0 -= 1
+ }
+
+ override def createAccumulator(): Tuple1[Long] = Tuple1.of(0L)
+
+ override def getValue(acc: Tuple1[Long]): Long = acc.f0
+}