You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ji...@apache.org on 2019/01/18 12:47:02 UTC
[flink] branch master updated: [FLINK-8739] [table] Optimize
DISTINCT aggregates to use the same distinct accumulator if possible
This is an automated email from the ASF dual-hosted git repository.
jincheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 12260a2 [FLINK-8739] [table] Optimize DISTINCT aggregates to use the same distinct accumulator if possible
12260a2 is described below
commit 12260a27ddb0ec8356b4981f033c6c9954fe2ab4
Author: Dian Fu <fu...@alibaba-inc.com>
AuthorDate: Mon Dec 10 20:57:23 2018 +0800
[FLINK-8739] [table] Optimize DISTINCT aggregates to use the same distinct accumulator if possible
This closes #7286
---
.../table/codegen/AggregationCodeGenerator.scala | 318 +++++++++++----------
.../flink/table/codegen/MatchCodeGenerator.scala | 32 ++-
.../aggfunctions/DistinctAccumulator.scala | 19 +-
.../functions/utils/UserDefinedFunctionUtils.scala | 7 +-
.../table/runtime/aggregate/AggregateUtil.scala | 189 +++++++-----
.../runtime/utils/JavaUserDefinedAggFunctions.java | 76 +++++
.../harness/GroupAggregateHarnessTest.scala | 185 ++++++++++--
.../table/runtime/harness/HarnessTestBase.scala | 187 ++----------
.../runtime/stream/sql/OverWindowITCase.scala | 47 +--
.../flink/table/runtime/stream/sql/SqlITCase.scala | 15 +-
.../runtime/stream/table/AggregateITCase.scala | 23 +-
11 files changed, 627 insertions(+), 471 deletions(-)
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
index 57cc815..ea148bc 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
@@ -19,17 +19,17 @@ package org.apache.flink.table.codegen
import java.lang.reflect.Modifier
import java.lang.{Iterable => JIterable}
+import java.util.{List => JList}
import org.apache.calcite.rex.RexLiteral
import org.apache.flink.api.common.state.{ListStateDescriptor, MapStateDescriptor, State, StateDescriptor}
-import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
-import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.TypeExtractionUtils.{extractTypeArgument, getRawClass}
import org.apache.flink.table.api.TableConfig
import org.apache.flink.table.api.dataview._
import org.apache.flink.table.codegen.CodeGenUtils.{newName, reflectiveFieldWriteAccess}
import org.apache.flink.table.codegen.Indenter.toISC
-import org.apache.flink.table.dataview.{MapViewTypeInfo, StateListView, StateMapView}
+import org.apache.flink.table.dataview.{StateListView, StateMapView}
import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.table.functions.aggfunctions.DistinctAccumulator
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
@@ -38,6 +38,7 @@ import org.apache.flink.table.runtime.aggregate.{GeneratedAggregations, SingleEl
import org.apache.flink.table.utils.EncodingUtils
import org.apache.flink.types.Row
+import scala.collection.JavaConversions._
import scala.collection.mutable
/**
@@ -77,7 +78,8 @@ class AggregationCodeGenerator(
* @param aggregates All aggregate functions
* @param aggFields Indexes of the input fields for all aggregate functions
* @param aggMapping The mapping of aggregates to output fields
- * @param isDistinctAggs The flag array indicating whether it is distinct aggregate.
+ * @param distinctAccMapping The mapping of the distinct accumulator index to the
+ * corresponding aggregates.
* @param isStateBackedDataViews a flag to indicate if distinct filter uses state backend.
* @param partialResults A flag defining whether final or partial results (accumulators) are set
* to the output row.
@@ -98,7 +100,7 @@ class AggregationCodeGenerator(
aggregates: Array[AggregateFunction[_ <: Any, _ <: Any]],
aggFields: Array[Array[Int]],
aggMapping: Array[Int],
- isDistinctAggs: Array[Boolean],
+ distinctAccMapping: Array[(Integer, JList[Integer])],
isStateBackedDataViews: Boolean,
partialResults: Boolean,
fwdMapping: Array[Int],
@@ -142,21 +144,35 @@ class AggregationCodeGenerator(
fields.mkString(", ")
}
- val parametersCodeForDistinctMerge = aggFields.map { inFields =>
- val fields = inFields.filter(_ > -1).zipWithIndex.map { case (f, i) =>
- // index to constant
- if (f >= physicalInputTypes.length) {
- constantFields(f - physicalInputTypes.length)
- }
+ // get parameter lists for distinct acc, constant fields are not necessary
+ val parametersCodeForDistinctAcc = aggFields.map { inFields =>
+ val fields = inFields.filter(i => i > -1 && i < physicalInputTypes.length).map { f =>
// index to input field
- else {
- s"(${CodeGenUtils.boxedTypeTermForTypeInfo(physicalInputTypes(f))}) k.getField($i)"
- }
+ s"(${CodeGenUtils.boxedTypeTermForTypeInfo(physicalInputTypes(f))}) input.getField($f)"
}
fields.mkString(", ")
}
+ val parametersCodeForDistinctMerge = aggFields.map { inFields =>
+ // transform inFields to pairs of (inField, index in acc) firstly,
+ // e.g. (4, 2, 3, 2) will be transformed to ((4,2), (2,0), (3,1), (2,0))
+ val fields = inFields.filter(_ > -1).groupBy(identity).toSeq.sortBy(_._1).zipWithIndex
+ .flatMap { case (a, i) => a._2.map((_, i)) }
+ .map { case (f, i) =>
+ // index to constant
+ if (f >= physicalInputTypes.length) {
+ constantFields(f - physicalInputTypes.length)
+ }
+ // index to input field
+ else {
+ s"(${CodeGenUtils.boxedTypeTermForTypeInfo(physicalInputTypes(f))}) k.getField($i)"
+ }
+ }
+
+ fields.mkString(", ")
+ }
+
// get method signatures
val classes = UserDefinedFunctionUtils.typeInfoToClass(physicalInputTypes)
val constantClasses = UserDefinedFunctionUtils.typeInfoToClass(constantTypes)
@@ -174,30 +190,11 @@ class AggregationCodeGenerator(
}
// get distinct filter of acc fields for each aggregate functions
- val distinctAccType = s"${classOf[DistinctAccumulator[_]].getName}"
-
- // preparing MapViewSpecs for distinct value maps
- val distinctAggs: Array[Seq[DataViewSpec[_]]] = isDistinctAggs.zipWithIndex.map {
- case (isDistinctAgg, idx) if isDistinctAgg =>
-
- // get types of agg function arguments
- val argTypes: Array[TypeInformation[_]] = aggFields(idx)
- .map(physicalInputTypes(_))
- // create type for MapView
- val mapViewTypeInfo = new MapViewTypeInfo(
- new RowTypeInfo(argTypes:_*),
- BasicTypeInfo.LONG_TYPE_INFO)
- // create MapViewSpec for distinct value map
- Seq(
- MapViewSpec(
- "distinctAgg" + idx,
- classOf[DistinctAccumulator[_]].getDeclaredField("distinctValueMap"),
- mapViewTypeInfo)
- )
- case _ => Seq()
- }
+ val distinctAccType = s"${classOf[DistinctAccumulator].getName}"
+
+ val distinctAccCount = distinctAccMapping.count(_._1 >= 0)
- if (isDistinctAggs.contains(true) && partialResults && isStateBackedDataViews) {
+ if (distinctAccCount > 0 && partialResults && isStateBackedDataViews) {
// should not happen, but add an error message just in case.
throw new CodeGenException(
s"Cannot emit partial results if DISTINCT values are tracked in state-backed maps. " +
@@ -266,31 +263,13 @@ class AggregationCodeGenerator(
* aggregation functions.
*/
def addAccumulatorDataViews(): Unit = {
- if (isStateBackedDataViews) {
- // create MapStates for distinct value maps
- val descMapping: Map[String, StateDescriptor[_, _]] = distinctAggs
- .flatMap(specs => specs.map(s => (s.stateId, s.toStateDescriptor)))
- .toMap[String, StateDescriptor[_ <: State, _]]
-
- for (i <- aggs.indices) yield {
- for (spec <- distinctAggs(i)) {
- // Check if stat descriptor exists.
- val desc: StateDescriptor[_, _] = descMapping.getOrElse(spec.stateId,
- throw new CodeGenException(
- s"Can not find DataView for distinct filter in accumulator by id: ${spec.stateId}"))
-
- addReusableDataView(spec, desc, i)
- }
- }
- }
-
if (accConfig.isDefined) {
// create state handles for DataView backed accumulator fields.
val descMapping: Map[String, StateDescriptor[_, _]] = accConfig.get
.flatMap(specs => specs.map(s => (s.stateId, s.toStateDescriptor)))
.toMap[String, StateDescriptor[_ <: State, _]]
- for (i <- aggs.indices) yield {
+ for (i <- 0 until aggs.length + distinctAccCount) yield {
for (spec <- accConfig.get(i)) yield {
// Check if stat descriptor exists.
val desc: StateDescriptor[_, _] = descMapping.getOrElse(spec.stateId,
@@ -375,14 +354,6 @@ class AggregationCodeGenerator(
reusableCleanupStatements.add(cleanup)
}
- def genDistinctDataViewFieldSetter(str: String, i: Int): String = {
- if (isStateBackedDataViews && distinctAggs(i).nonEmpty) {
- genDataViewFieldSetter(distinctAggs(i), str, i)
- } else {
- ""
- }
- }
-
def genAccDataViewFieldSetter(str: String, i: Int): String = {
if (accConfig.isDefined) {
genDataViewFieldSetter(accConfig.get(i), str, i)
@@ -429,38 +400,55 @@ class AggregationCodeGenerator(
| org.apache.flink.types.Row output) throws Exception """.stripMargin
val setAggs: String = {
- for (i <- aggs.indices) yield
-
+ for ((i, aggIndexes) <- distinctAccMapping) yield {
if (partialResults) {
- j"""
- | output.setField(
- | ${aggMapping(i)},
- | (${accTypes(i)}) accs.getField($i));""".stripMargin
- } else {
- val setAccOutput =
+ def setAggs(aggIndexes: JList[Integer]) = {
+ for (i <- aggIndexes) yield {
+ j"""
+ |output.setField(
+ | ${aggMapping(i)},
+ | (${accTypes(i)}) accs.getField($i));
+ """.stripMargin
+ }
+ }.mkString("\n")
+
+ if (i >= 0) {
j"""
- | ${genAccDataViewFieldSetter(s"acc$i", i)}
| output.setField(
| ${aggMapping(i)},
- | baseClass$i.getValue(acc$i));
+ | ($distinctAccType) accs.getField($i));
+ | ${setAggs(aggIndexes)}
""".stripMargin
- if (isDistinctAggs(i)) {
- j"""
- | org.apache.flink.table.functions.AggregateFunction baseClass$i =
- | (org.apache.flink.table.functions.AggregateFunction) ${aggs(i)};
- | $distinctAccType distinctAcc$i = ($distinctAccType) accs.getField($i);
- | ${accTypes(i)} acc$i = (${accTypes(i)}) distinctAcc$i.getRealAcc();
- | $setAccOutput
- """.stripMargin
- } else {
+ } else {
+ j"""
+ | ${setAggs(aggIndexes)}
+ """.stripMargin
+ }
+ } else {
+ def setAggs(aggIndexes: JList[Integer]) = {
+ for (i <- aggIndexes) yield {
+ val setAccOutput =
+ j"""
+ |${genAccDataViewFieldSetter(s"acc$i", i)}
+ |output.setField(
+ | ${aggMapping(i)},
+ | baseClass$i.getValue(acc$i));
+ """.stripMargin
+
j"""
- | org.apache.flink.table.functions.AggregateFunction baseClass$i =
- | (org.apache.flink.table.functions.AggregateFunction) ${aggs(i)};
- | ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
- | $setAccOutput
+ |org.apache.flink.table.functions.AggregateFunction baseClass$i =
+ | (org.apache.flink.table.functions.AggregateFunction) ${aggs(i)};
+ |${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
+ |$setAccOutput
""".stripMargin
}
- }
+ }.mkString("\n")
+
+ j"""
+ | ${setAggs(aggIndexes)}
+ """.stripMargin
+ }
+ }
}.mkString("\n")
j"""
@@ -478,27 +466,30 @@ class AggregationCodeGenerator(
| org.apache.flink.types.Row input) throws Exception """.stripMargin
val accumulate: String = {
- for (i <- aggs.indices) yield {
- val accumulateAcc =
+ def accumulateAcc(aggIndexes: JList[Integer]) = {
+ for (i <- aggIndexes) yield {
j"""
- | ${genAccDataViewFieldSetter(s"acc$i", i)}
- | ${aggs(i)}.accumulate(acc$i
- | ${if (!parametersCode(i).isEmpty) "," else ""} ${parametersCode(i)});
+ |${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
+ |${genAccDataViewFieldSetter(s"acc$i", i)}
+ |${aggs(i)}.accumulate(acc$i
+ | ${if (!parametersCode(i).isEmpty) "," else ""} ${parametersCode(i)});
""".stripMargin
- if (isDistinctAggs(i)) {
+ }
+ }.mkString("\n")
+
+ for ((i, aggIndexes) <- distinctAccMapping) yield {
+ if (i >= 0) {
j"""
| $distinctAccType distinctAcc$i = ($distinctAccType) accs.getField($i);
- | ${genDistinctDataViewFieldSetter(s"distinctAcc$i", i)}
- | if (distinctAcc$i.add(
- | ${classOf[Row].getCanonicalName}.of(${parametersCode(i)}))) {
- | ${accTypes(i)} acc$i = (${accTypes(i)}) distinctAcc$i.getRealAcc();
- | $accumulateAcc
+ | ${genAccDataViewFieldSetter(s"distinctAcc$i", i)}
+ | if (distinctAcc$i.add(${classOf[Row].getCanonicalName}.of(
+ | ${parametersCodeForDistinctAcc(aggIndexes.get(0))}))) {
+ | ${accumulateAcc(aggIndexes)}
| }
""".stripMargin
} else {
j"""
- | ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
- | $accumulateAcc
+ | ${accumulateAcc(aggIndexes)}
""".stripMargin
}
}
@@ -518,27 +509,30 @@ class AggregationCodeGenerator(
| org.apache.flink.types.Row input) throws Exception """.stripMargin
val retract: String = {
- for (i <- aggs.indices) yield {
- val retractAcc =
+ def retractAcc(aggIndexes: JList[Integer]) = {
+ for (i <- aggIndexes) yield {
j"""
- | ${genAccDataViewFieldSetter(s"acc$i", i)}
- | ${aggs(i)}.retract(
- | acc$i ${if (!parametersCode(i).isEmpty) "," else ""} ${parametersCode(i)});
+ |${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
+ |${genAccDataViewFieldSetter(s"acc$i", i)}
+ |${aggs(i)}.retract(acc$i
+ | ${if (!parametersCode(i).isEmpty) "," else ""} ${parametersCode(i)});
""".stripMargin
- if (isDistinctAggs(i)) {
+ }
+ }.mkString("\n")
+
+ for ((i, aggIndexes) <- distinctAccMapping) yield {
+ if (i >= 0) {
j"""
| $distinctAccType distinctAcc$i = ($distinctAccType) accs.getField($i);
- | ${genDistinctDataViewFieldSetter(s"distinctAcc$i", i)}
- | if (distinctAcc$i.remove(
- | ${classOf[Row].getCanonicalName}.of(${parametersCode(i)}))) {
- | ${accTypes(i)} acc$i = (${accTypes(i)}) distinctAcc$i.getRealAcc();
- | $retractAcc
+ | ${genAccDataViewFieldSetter(s"distinctAcc$i", i)}
+ | if (distinctAcc$i.remove(${classOf[Row].getCanonicalName}.of(
+ | ${parametersCodeForDistinctAcc(aggIndexes.get(0))}))) {
+ | ${retractAcc(aggIndexes)}
| }
""".stripMargin
} else {
j"""
- | ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
- | $retractAcc
+ | ${retractAcc(aggIndexes)}
""".stripMargin
}
}
@@ -565,26 +559,34 @@ class AggregationCodeGenerator(
val init: String =
j"""
| org.apache.flink.types.Row accs =
- | new org.apache.flink.types.Row(${aggs.length});"""
+ | new org.apache.flink.types.Row(${aggs.length + distinctAccCount});"""
.stripMargin
val create: String = {
- for (i <- aggs.indices) yield {
- if (isDistinctAggs(i)) {
+ def createAcc(aggIndexes: JList[Integer]) = {
+ for (i <- aggIndexes) yield {
+ j"""
+ |${accTypes(i)} acc$i = (${accTypes(i)}) ${aggs(i)}.createAccumulator();
+ |accs.setField(
+ | $i,
+ | acc$i);
+ """.stripMargin
+ }
+ }.mkString("\n")
+
+ for ((i, aggIndexes) <- distinctAccMapping) yield {
+ if (i >= 0) {
j"""
- | ${accTypes(i)} acc$i = (${accTypes(i)}) ${aggs(i)}.createAccumulator();
| $distinctAccType distinctAcc$i = ($distinctAccType)
- | new ${classOf[DistinctAccumulator[_]].getCanonicalName} (acc$i);
+ | new ${classOf[DistinctAccumulator].getCanonicalName}();
| accs.setField(
| $i,
- | distinctAcc$i);"""
- .stripMargin
+ | distinctAcc$i);
+ | ${createAcc(aggIndexes)}
+ """.stripMargin
} else {
j"""
- | ${accTypes(i)} acc$i = (${accTypes(i)}) ${aggs(i)}.createAccumulator();
- | accs.setField(
- | $i,
- | acc$i);"""
- .stripMargin
+ | ${createAcc(aggIndexes)}
+ """.stripMargin
}
}
}.mkString("\n")
@@ -633,8 +635,7 @@ class AggregationCodeGenerator(
}
def genMergeAccumulatorsPair: String = {
-
- val mapping = mergeMapping.getOrElse(aggs.indices.toArray)
+ val mapping = mergeMapping.getOrElse((0 until aggs.length + distinctAccCount).toArray)
val sig: String =
j"""
@@ -643,14 +644,35 @@ class AggregationCodeGenerator(
| org.apache.flink.types.Row b)
""".stripMargin
val merge: String = {
- for (i <- aggs.indices) yield {
- if (isDistinctAggs(i)) {
+ def accumulateAcc(aggIndexes: JList[Integer]) = {
+ for (i <- aggIndexes) yield {
+ j"""
+ |${accTypes(i)} aAcc$i = (${accTypes(i)}) a.getField($i);
+ |${aggs(i)}.accumulate(aAcc$i, ${parametersCodeForDistinctMerge(i)});
+ |a.setField($i, aAcc$i);
+ """.stripMargin
+ }
+ }.mkString("\n")
+
+ def mergeAcc(aggIndexes: JList[Integer]) = {
+ for (i <- aggIndexes) yield {
+ j"""
+ |${accTypes(i)} aAcc$i = (${accTypes(i)}) a.getField($i);
+ |${accTypes(i)} bAcc$i = (${accTypes(i)}) b.getField(${mapping(i)});
+ |accIt$i.setElement(bAcc$i);
+ |${aggs(i)}.merge(aAcc$i, accIt$i);
+ |a.setField($i, aAcc$i);
+ """.stripMargin
+ }
+ }.mkString("\n")
+
+ for ((i, aggIndexes) <- distinctAccMapping) yield {
+ if (i >= 0) {
j"""
| $distinctAccType aDistinctAcc$i = ($distinctAccType) a.getField($i);
| $distinctAccType bDistinctAcc$i = ($distinctAccType) b.getField(${mapping(i)});
| java.util.Iterator<java.util.Map.Entry> mergeIt$i =
| bDistinctAcc$i.elements().iterator();
- | ${accTypes(i)} aAcc$i = (${accTypes(i)}) aDistinctAcc$i.getRealAcc();
|
| while (mergeIt$i.hasNext()) {
| java.util.Map.Entry entry = (java.util.Map.Entry) mergeIt$i.next();
@@ -658,18 +680,14 @@ class AggregationCodeGenerator(
| (${classOf[Row].getCanonicalName}) entry.getKey();
| Long v = (Long) entry.getValue();
| if (aDistinctAcc$i.add(k, v)) {
- | ${aggs(i)}.accumulate(aAcc$i, ${parametersCodeForDistinctMerge(i)});
+ | ${accumulateAcc(aggIndexes)}
| }
| }
| a.setField($i, aDistinctAcc$i);
""".stripMargin
} else {
j"""
- | ${accTypes(i)} aAcc$i = (${accTypes(i)}) a.getField($i);
- | ${accTypes(i)} bAcc$i = (${accTypes(i)}) b.getField(${mapping(i)});
- | accIt$i.setElement(bAcc$i);
- | ${aggs(i)}.merge(aAcc$i, accIt$i);
- | a.setField($i, aAcc$i);
+ | ${mergeAcc(aggIndexes)}
""".stripMargin
}
}
@@ -716,20 +734,28 @@ class AggregationCodeGenerator(
| org.apache.flink.types.Row accs) throws Exception """.stripMargin
val reset: String = {
- for (i <- aggs.indices) yield {
- if (isDistinctAggs(i)) {
+ def resetAcc(aggIndexes: JList[Integer]) = {
+ for (i <- aggIndexes) yield {
+ j"""
+ |${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
+ |${genAccDataViewFieldSetter(s"acc$i", i)}
+ |${aggs(i)}.resetAccumulator(acc$i);
+ """.stripMargin
+ }
+ }.mkString("\n")
+
+ for ((i, aggIndexes) <- distinctAccMapping) yield {
+ if (i >= 0) {
j"""
| $distinctAccType distinctAcc$i = ($distinctAccType) accs.getField($i);
- | ${genDistinctDataViewFieldSetter(s"distinctAcc$i", i)}
- | ${accTypes(i)} acc$i = (${accTypes(i)}) distinctAcc$i.getRealAcc();
- | ${genAccDataViewFieldSetter(s"acc$i", i)}
+ | ${genAccDataViewFieldSetter(s"distinctAcc$i", i)}
| distinctAcc$i.reset();
- | ${aggs(i)}.resetAccumulator(acc$i);""".stripMargin
+ | ${resetAcc(aggIndexes)}
+ """.stripMargin
} else {
j"""
- | ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
- | ${genAccDataViewFieldSetter(s"acc$i", i)}
- | ${aggs(i)}.resetAccumulator(acc$i);""".stripMargin
+ | ${resetAcc(aggIndexes)}
+ """.stripMargin
}
}
}.mkString("\n")
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala
index 62cad01..6097178 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/MatchCodeGenerator.scala
@@ -737,7 +737,7 @@ class MatchCodeGenerator(
matchAgg.aggregations.map(_.aggFunction).toArray,
matchAgg.aggregations.map(_.inputIndices).toArray,
matchAgg.aggregations.indices.toArray,
- Array.fill(matchAgg.aggregations.size)(false),
+ matchAgg.getDistinctAccMapping,
isStateBackedDataViews = false,
partialResults = false,
Array.emptyIntArray,
@@ -781,18 +781,27 @@ class MatchCodeGenerator(
callsWithIndices.map(_._2).toArray)
})
+ val distinctAccMap: mutable.Map[util.Set[Integer], Integer] = mutable.Map()
val aggs = logicalAggregates.zipWithIndex.map {
case (agg, index) =>
val result = AggregateUtil.extractAggregateCallMetadata(
agg.function,
isDistinct = false, // TODO properly set once supported in Calcite
+ distinctAccMap,
+ new util.ArrayList[Integer](), // TODO properly set once supported in Calcite
+ aggregates.length,
+ input.getArity,
agg.inputTypes,
needRetraction = false,
config,
isStateBackedDataViews = false,
index)
- SingleAggCall(result.aggregateFunction, agg.exprIndices.toArray, result.accumulatorSpecs)
+ SingleAggCall(
+ result.aggregateFunction,
+ agg.exprIndices.toArray,
+ result.accumulatorSpecs,
+ result.distinctAccIndex)
}
MatchAgg(aggs, inputRows.values.map(_._1).toSeq)
@@ -867,14 +876,25 @@ class MatchCodeGenerator(
private case class SingleAggCall(
aggFunction: TableAggregateFunction[_, _],
inputIndices: Array[Int],
- dataViews: Seq[DataViewSpec[_]]
+ dataViews: Seq[DataViewSpec[_]],
+ distinctAccIndex: Int
)
private case class MatchAgg(
aggregations: Seq[SingleAggCall],
- inputExprs: Seq[RexNode]
- )
-
+ inputExprs: Seq[RexNode]) {
+
+ def getDistinctAccMapping: Array[(Integer, util.List[Integer])] = {
+ val distinctAccMapping = mutable.Map[Integer, util.List[Integer]]()
+ aggregations.map(_.distinctAccIndex).zipWithIndex.foreach {
+ case (distinctAccIndex, aggIndex) =>
+ distinctAccMapping
+ .getOrElseUpdate(distinctAccIndex, new util.ArrayList[Integer]())
+ .add(aggIndex)
+ }
+ distinctAccMapping.toArray
+ }
+ }
}
}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/DistinctAccumulator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/DistinctAccumulator.scala
index 3427c9c..1c54acc 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/DistinctAccumulator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/DistinctAccumulator.scala
@@ -28,30 +28,19 @@ import org.apache.flink.types.Row
/**
* Wraps an accumulator and adds a map to filter distinct values.
*
- * @param realAcc the wrapped accumulator.
* @param distinctValueMap the [[MapView]] that stores the distinct filter hash map.
- *
- * @tparam ACC the accumulator type for the realAcc.
*/
-class DistinctAccumulator[ACC](
- var realAcc: ACC,
- var distinctValueMap: MapView[Row, JLong]) {
+class DistinctAccumulator(var distinctValueMap: MapView[Row, JLong]) {
def this() {
- this(null.asInstanceOf[ACC], new MapView[Row, JLong]())
- }
-
- def this(realAcc: ACC) {
- this(realAcc, new MapView[Row, JLong]())
+ this(new MapView[Row, JLong]())
}
- def getRealAcc: ACC = realAcc
-
- def canEqual(a: Any): Boolean = a.isInstanceOf[DistinctAccumulator[ACC]]
+ def canEqual(a: Any): Boolean = a.isInstanceOf[DistinctAccumulator]
override def equals(that: Any): Boolean =
that match {
- case that: DistinctAccumulator[ACC] => that.canEqual(this) &&
+ case that: DistinctAccumulator => that.canEqual(this) &&
this.distinctValueMap == that.distinctValueMap
case _ => false
}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
index 4faa6ed..1e58612 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -464,15 +464,15 @@ object UserDefinedFunctionUtils {
* Remove StateView fields from accumulator type information.
*
* @param index index of aggregate function
- * @param aggFun aggregate function
+ * @param acc accumulator
* @param accType accumulator type information, only support pojo type
* @param isStateBackedDataViews is data views use state backend
* @return mapping of accumulator type information and data view config which contains id,
* field name and state descriptor
*/
- def removeStateViewFieldsFromAccTypeInfo(
+ def removeStateViewFieldsFromAccTypeInfo[ACC](
index: Int,
- aggFun: AggregateFunction[_, _],
+ acc: ACC,
accType: TypeInformation[_],
isStateBackedDataViews: Boolean)
: (TypeInformation[_], Option[Seq[DataViewSpec[_]]]) = {
@@ -489,7 +489,6 @@ object UserDefinedFunctionUtils {
)
}
- val acc = aggFun.createAccumulator()
accType match {
case pojoType: PojoTypeInfo[_] if pojoType.getArity > 0 =>
val arity = pojoType.getArity
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index f5cf191..7c38254 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -18,6 +18,7 @@
package org.apache.flink.table.runtime.aggregate
import java.util
+import java.util.{ArrayList => JArrayList, List => JList}
import org.apache.calcite.rel.`type`._
import org.apache.calcite.rel.core.AggregateCall
@@ -27,7 +28,8 @@ import org.apache.calcite.sql.fun._
import org.apache.calcite.sql.{SqlAggFunction, SqlKind}
import org.apache.flink.api.common.functions.{MapFunction, RichGroupReduceFunction, AggregateFunction => DataStreamAggFunction, _}
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
-import org.apache.flink.api.java.typeutils.{PojoField, PojoTypeInfo, RowTypeInfo}
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.scala.typeutils.Types
import org.apache.flink.streaming.api.functions.ProcessFunction
import org.apache.flink.streaming.api.functions.windowing.{AllWindowFunction, WindowFunction}
import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow}
@@ -36,7 +38,6 @@ import org.apache.flink.table.api.{StreamQueryConfig, TableConfig, TableExceptio
import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.AggregationCodeGenerator
-import org.apache.flink.table.dataview.MapViewTypeInfo
import org.apache.flink.table.expressions.ExpressionUtils.isTimeIntervalLiteral
import org.apache.flink.table.expressions._
import org.apache.flink.table.functions.aggfunctions._
@@ -51,11 +52,11 @@ import org.apache.flink.types.Row
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
+import scala.collection.mutable
object AggregateUtil {
type CalcitePair[T, R] = org.apache.calcite.util.Pair[T, R]
- type JavaList[T] = java.util.List[T]
/**
* Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for unbounded OVER
@@ -88,6 +89,7 @@ object AggregateUtil {
val aggregateMetadata = extractAggregateMetadata(
namedAggregates.map(_.getKey),
aggregateInputType,
+ inputFieldTypeInfo.length,
needRetraction = false,
tableConfig,
isStateBackedDataViews = true)
@@ -104,7 +106,7 @@ object AggregateUtil {
aggregateMetadata.getAggregateFunctions,
aggregateMetadata.getAggregateIndices,
aggMapping,
- aggregateMetadata.getAggregatesDistinctFlags,
+ aggregateMetadata.getDistinctAccMapping,
isStateBackedDataViews = true,
partialResults = false,
forwardMapping,
@@ -172,6 +174,7 @@ object AggregateUtil {
val aggregateMetadata = extractAggregateMetadata(
namedAggregates.map(_.getKey),
inputRowType,
+ inputFieldTypes.length,
consumeRetraction,
tableConfig,
isStateBackedDataViews = true)
@@ -185,7 +188,7 @@ object AggregateUtil {
aggregateMetadata.getAggregateFunctions,
aggregateMetadata.getAggregateIndices,
aggMapping,
- aggregateMetadata.getAggregatesDistinctFlags,
+ aggregateMetadata.getDistinctAccMapping,
isStateBackedDataViews = true,
partialResults = false,
groupings,
@@ -240,6 +243,7 @@ object AggregateUtil {
val aggregateMetadata = extractAggregateMetadata(
namedAggregates.map(_.getKey),
aggregateInputType,
+ inputFieldTypeInfo.length,
needRetract,
tableConfig,
isStateBackedDataViews = true)
@@ -257,7 +261,7 @@ object AggregateUtil {
aggregateMetadata.getAggregateFunctions,
aggregateMetadata.getAggregateIndices,
aggMapping,
- aggregateMetadata.getAggregatesDistinctFlags,
+ aggregateMetadata.getDistinctAccMapping,
isStateBackedDataViews = true,
partialResults = false,
forwardMapping,
@@ -346,6 +350,7 @@ object AggregateUtil {
val aggregateMetadata = extractAggregateMetadata(
namedAggregates.map(_.getKey),
inputType,
+ inputFieldTypeInfo.length,
needRetract,
tableConfig)
@@ -385,8 +390,9 @@ object AggregateUtil {
throw new UnsupportedOperationException(s"$window is currently not supported on batch")
}
- val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length)
- val outputArity = aggregateMetadata.getAggregateCallsCount + groupings.length + 1
+ val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length, partialResults = true)
+ val outputArity = aggregateMetadata.getAggregateCallsCount + groupings.length +
+ aggregateMetadata.getDistinctAccCount + 1
val genFunction = generator.generateAggregations(
"DataSetAggregatePrepareMapHelper",
@@ -394,7 +400,7 @@ object AggregateUtil {
aggregateMetadata.getAggregateFunctions,
aggregateMetadata.getAggregateIndices,
aggMapping,
- aggregateMetadata.getAggregatesDistinctFlags,
+ aggregateMetadata.getDistinctAccMapping,
isStateBackedDataViews = false,
partialResults = true,
groupings,
@@ -455,6 +461,7 @@ object AggregateUtil {
val aggregateMetadata = extractAggregateMetadata(
namedAggregates.map(_.getKey),
physicalInputRowType,
+ physicalInputTypes.length,
needRetract,
tableConfig)
@@ -465,19 +472,20 @@ object AggregateUtil {
physicalInputRowType,
Some(Array(BasicTypeInfo.LONG_TYPE_INFO)))
- val keysAndAggregatesArity = groupings.length + namedAggregates.length
+ val aggMappings = aggregateMetadata.getAdjustedMapping(groupings.length, partialResults = true)
+ val keysAndAggregatesArity =
+ groupings.length + namedAggregates.length + aggregateMetadata.getDistinctAccCount
window match {
case SlidingGroupWindow(_, _, size, slide) if isTimeInterval(size.resultType) =>
// sliding time-window for partial aggregations
- val aggMappings = aggregateMetadata.getAdjustedMapping(groupings.length)
val genFunction = generator.generateAggregations(
"DataSetAggregatePrepareMapHelper",
physicalInputTypes,
aggregateMetadata.getAggregateFunctions,
aggregateMetadata.getAggregateIndices,
aggMappings,
- aggregateMetadata.getAggregatesDistinctFlags,
+ aggregateMetadata.getDistinctAccMapping,
isStateBackedDataViews = false,
partialResults = true,
groupings.indices.toArray,
@@ -573,10 +581,11 @@ object AggregateUtil {
val aggregateMetadata = extractAggregateMetadata(
namedAggregates.map(_.getKey),
physicalInputRowType,
+ physicalInputTypes.length,
needRetract,
tableConfig)
- val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length)
+ val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length, partialResults = true)
val genPreAggFunction = generator.generateAggregations(
"GroupingWindowAggregateHelper",
@@ -584,12 +593,12 @@ object AggregateUtil {
aggregateMetadata.getAggregateFunctions,
aggregateMetadata.getAggregateIndices,
aggMapping,
- aggregateMetadata.getAggregatesDistinctFlags,
+ aggregateMetadata.getDistinctAccMapping,
isStateBackedDataViews = false,
partialResults = true,
groupings.indices.toArray,
Some(aggMapping),
- outputType.getFieldCount,
+ outputType.getFieldCount + aggregateMetadata.getDistinctAccCount,
needRetract,
needMerge = true,
needReset = true,
@@ -602,7 +611,7 @@ object AggregateUtil {
aggregateMetadata.getAggregateFunctions,
aggregateMetadata.getAggregateIndices,
aggMapping,
- aggregateMetadata.getAggregatesDistinctFlags,
+ aggregateMetadata.getDistinctAccMapping,
isStateBackedDataViews = false,
partialResults = false,
groupings.indices.toArray,
@@ -730,12 +739,13 @@ object AggregateUtil {
val aggregateMetadata = extractAggregateMetadata(
namedAggregates.map(_.getKey),
physicalInputRowType,
+ physicalInputTypes.length,
needRetract,
tableConfig)
- val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length)
-
- val keysAndAggregatesArity = groupings.length + namedAggregates.length
+ val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length, partialResults = true)
+ val keysAndAggregatesArity = groupings.length + aggregateMetadata.getAggregateCallsCount +
+ aggregateMetadata.getDistinctAccCount
window match {
case SessionGroupWindow(_, _, gap) =>
@@ -753,12 +763,12 @@ object AggregateUtil {
aggregateMetadata.getAggregateFunctions,
aggregateMetadata.getAggregateIndices,
aggMapping,
- aggregateMetadata.getAggregatesDistinctFlags,
+ aggregateMetadata.getDistinctAccMapping,
isStateBackedDataViews = false,
partialResults = true,
groupings.indices.toArray,
Some(aggMapping),
- groupings.length + aggregateMetadata.getAggregateCallsCount + 2,
+ keysAndAggregatesArity + 2,
needRetract,
needMerge = true,
needReset = true,
@@ -807,11 +817,13 @@ object AggregateUtil {
val aggregateMetadata = extractAggregateMetadata(
namedAggregates.map(_.getKey),
physicalInputRowType,
+ physicalInputTypes.length,
needRetract,
tableConfig)
- val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length)
- val keysAndAggregatesArity = groupings.length + namedAggregates.length
+ val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length, partialResults = true)
+ val keysAndAggregatesArity =
+ groupings.length + namedAggregates.length + aggregateMetadata.getDistinctAccCount
window match {
@@ -830,7 +842,7 @@ object AggregateUtil {
aggregateMetadata.getAggregateFunctions,
aggregateMetadata.getAggregateIndices,
aggMapping,
- aggregateMetadata.getAggregatesDistinctFlags,
+ aggregateMetadata.getDistinctAccMapping,
isStateBackedDataViews = false,
partialResults = true,
groupings.indices.toArray,
@@ -876,6 +888,7 @@ object AggregateUtil {
val aggregateMetadata = extractAggregateMetadata(
namedAggregates.map(_.getKey),
inputType,
+ inputFieldTypeInfo.length,
needRetract,
tableConfig)
@@ -890,7 +903,9 @@ object AggregateUtil {
if (doAllSupportPartialMerge(aggregateMetadata.getAggregateFunctions)) {
- val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length)
+ val aggMapping = aggregateMetadata.getAdjustedMapping(groupings.length, partialResults = true)
+ val keysAndAggregatesArity = groupings.length + aggregateMetadata.getAggregateCallsCount +
+ aggregateMetadata.getDistinctAccCount
// compute preaggregation type
val preAggFieldTypes = gkeyOutMapping.map(_._2)
@@ -904,12 +919,12 @@ object AggregateUtil {
aggregateMetadata.getAggregateFunctions,
aggregateMetadata.getAggregateIndices,
aggMapping,
- aggregateMetadata.getAggregatesDistinctFlags,
+ aggregateMetadata.getDistinctAccMapping,
isStateBackedDataViews = false,
partialResults = true,
groupings,
None,
- groupings.length + aggregateMetadata.getAggregateCallsCount,
+ keysAndAggregatesArity,
needRetract,
needMerge = false,
needReset = true,
@@ -932,7 +947,7 @@ object AggregateUtil {
aggregateMetadata.getAggregateFunctions,
aggregateMetadata.getAggregateIndices,
aggOutFields,
- aggregateMetadata.getAggregatesDistinctFlags,
+ aggregateMetadata.getDistinctAccMapping,
isStateBackedDataViews = false,
partialResults = false,
gkeyMapping,
@@ -957,7 +972,7 @@ object AggregateUtil {
aggregateMetadata.getAggregateFunctions,
aggregateMetadata.getAggregateIndices,
aggOutFields,
- aggregateMetadata.getAggregatesDistinctFlags,
+ aggregateMetadata.getDistinctAccMapping,
isStateBackedDataViews = false,
partialResults = false,
groupings,
@@ -1046,6 +1061,7 @@ object AggregateUtil {
extractAggregateMetadata(
namedAggregates.map(_.getKey),
inputType,
+ inputFieldTypeInfo.length,
needRetract,
tableConfig)
@@ -1058,7 +1074,7 @@ object AggregateUtil {
aggregateMetadata.getAggregateFunctions,
aggregateMetadata.getAggregateIndices,
aggMapping,
- aggregateMetadata.getAggregatesDistinctFlags,
+ aggregateMetadata.getDistinctAccMapping,
isStateBackedDataViews = false,
partialResults = false,
groupingKeys,
@@ -1088,6 +1104,7 @@ object AggregateUtil {
val aggregateList = extractAggregateMetadata(
aggregateCalls,
inputType,
+ inputType.getFieldCount,
needRetraction = false,
tableConfig).getAggregateFunctions
@@ -1173,22 +1190,30 @@ object AggregateUtil {
* [[GeneratedAggregations]] function.
*/
private[flink] class AggregateMetadata(
- private val aggregates: Seq[(AggregateCallMetadata, Array[Int])]) {
+ private val aggregates: Seq[(AggregateCallMetadata, Array[Int])],
+ private val distinctAccTypesWithSpecs: Seq[(TypeInformation[_], Seq[DataViewSpec[_]])]) {
def getAggregateFunctions: Array[TableAggregateFunction[_, _]] = {
aggregates.map(_._1.aggregateFunction).toArray
}
def getAggregatesAccumulatorTypes: Array[TypeInformation[_]] = {
- aggregates.map(_._1.accumulatorType).toArray
+ aggregates.map(_._1.accumulatorType).toArray ++ distinctAccTypesWithSpecs.map(_._1)
}
def getAggregatesAccumulatorSpecs: Array[Seq[DataViewSpec[_]]] = {
- aggregates.map(_._1.accumulatorSpecs).toArray
+ aggregates.map(_._1.accumulatorSpecs).toArray ++ distinctAccTypesWithSpecs.map(_._2)
}
- def getAggregatesDistinctFlags: Array[Boolean] = {
- aggregates.map(_._1.isDistinct).toArray
+ /**
+ * Groups the aggregate call indices by the distinct accumulator index they use.
+ * Non-distinct calls carry index -1, so they all end up in the group keyed by -1.
+ * Entries come out in the iteration order of a default mutable.Map (hash order).
+ */
+ def getDistinctAccMapping: Array[(Integer, JList[Integer])] = {
+ val aggIndicesByAccIndex = mutable.Map[Integer, JList[Integer]]()
+ for ((callMetadata, aggIndex) <- aggregates.map(_._1).zipWithIndex) {
+ aggIndicesByAccIndex
+ .getOrElseUpdate(callMetadata.distinctAccIndex, new JArrayList[Integer]())
+ .add(aggIndex)
+ }
+ aggIndicesByAccIndex.toArray
}
def getAggregateCallsCount: Int = {
@@ -1199,8 +1224,13 @@ object AggregateUtil {
aggregates.map(_._2).toArray
}
- def getAdjustedMapping(offset: Int): Array[Int] = {
- (0 until getAggregateCallsCount).map(_ + offset).toArray
+ /**
+ * Returns consecutive field indices starting at `offset`. There is one index per aggregate
+ * call and, when `partialResults` is true, one more per distinct accumulator.
+ */
+ def getAdjustedMapping(offset: Int, partialResults: Boolean = false): Array[Int] = {
+ val accCount = getAggregateCallsCount + (if (partialResults) getDistinctAccCount else 0)
+ Array.tabulate(accCount)(_ + offset)
+ }
+
+ /**
+ * Returns the number of distinct accumulators. This is the number of different
+ * non-negative distinct accumulator indices used by the aggregate calls (-1 marks
+ * non-distinct calls).
+ */
+ def getDistinctAccCount: Int = {
+ // Counted directly instead of materializing getDistinctAccMapping (a map of Java lists)
+ // on every call; the result equals the number of its keys that are >= 0.
+ aggregates.map(_._1.distinctAccIndex).filter(_ >= 0).distinct.size
}
}
@@ -1212,7 +1242,7 @@ object AggregateUtil {
aggregateFunction: TableAggregateFunction[_, _],
accumulatorType: TypeInformation[_],
accumulatorSpecs: Seq[DataViewSpec[_]],
- isDistinct: Boolean
+ distinctAccIndex: Int
)
/**
@@ -1221,6 +1251,11 @@ object AggregateUtil {
*
* @param aggregateFunction calcite's aggregate function
* @param isDistinct true if should be distinct aggregation
+ * @param distinctAccMap mapping from the set of input field indices of a distinct aggregate
+ * to the index of its distinct accumulator (shared by aggregates with equal sets)
+ * @param argList indices of the input fields of the given aggregate
+ * @param aggregateCount number of aggregate calls
+ * @param inputFieldsCount number of input fields
* @param aggregateInputTypes input types of given aggregate
* @param needRetraction if the [[TableAggregateFunction]] should produce retractions
* @param tableConfig tableConfig, required for decimal precision
@@ -1235,6 +1270,10 @@ object AggregateUtil {
private[flink] def extractAggregateCallMetadata(
aggregateFunction: SqlAggFunction,
isDistinct: Boolean,
+ distinctAccMap: mutable.Map[util.Set[Integer], Integer],
+ argList: util.List[Integer],
+ aggregateCount: Int,
+ inputFieldsCount: Int,
aggregateInputTypes: Seq[RelDataType],
needRetraction: Boolean,
tableConfig: TableConfig,
@@ -1258,48 +1297,36 @@ object AggregateUtil {
removeStateViewFieldsFromAccTypeInfo(
uniqueIdWithinAggregate,
- aggregate,
+ aggregate.createAccumulator(),
accType,
isStateBackedDataViews)
}
// create distinct accumulator filter argument
- val distinctAccumulatorType = if (isDistinct) {
- createDistinctAccumulatorType(aggregateInputTypes, isStateBackedDataViews, accumulatorType)
+ val distinctAccIndex = if (isDistinct) {
+ getDistinctAccIndex(distinctAccMap, argList, aggregateCount, inputFieldsCount)
} else {
- accumulatorType
+ -1
}
- AggregateCallMetadata(aggregate, distinctAccumulatorType, accSpecs.getOrElse(Seq()), isDistinct)
+ AggregateCallMetadata(aggregate, accumulatorType, accSpecs.getOrElse(Seq()), distinctAccIndex)
}
- private def createDistinctAccumulatorType(
- aggregateInputTypes: Seq[RelDataType],
- isStateBackedDataViews: Boolean,
- accumulatorType: TypeInformation[_])
- : PojoTypeInfo[DistinctAccumulator[_]] = {
- // Using Pojo fields for the real underlying accumulator
- val pojoFields = new util.ArrayList[PojoField]()
- pojoFields.add(new PojoField(
- classOf[DistinctAccumulator[_]].getDeclaredField("realAcc"),
- accumulatorType)
- )
- // If StateBackend is not enabled, the distinct mapping also needs
- // to be added to the Pojo fields.
- if (!isStateBackedDataViews) {
-
- val argTypes: Array[TypeInformation[_]] = aggregateInputTypes
- .map(FlinkTypeFactory.toTypeInfo).toArray
-
- val mapViewTypeInfo = new MapViewTypeInfo(
- new RowTypeInfo(argTypes: _*),
- BasicTypeInfo.LONG_TYPE_INFO)
- pojoFields.add(new PojoField(
- classOf[DistinctAccumulator[_]].getDeclaredField("distinctValueMap"),
- mapViewTypeInfo)
- )
+ /**
+ * Returns the index of the distinct accumulator for a distinct aggregate call with the
+ * given arguments. If no accumulator exists yet, a new one is registered in `distinctAccMap`.
+ *
+ * Aggregate calls whose argument index sets are equal share the same accumulator index.
+ * Argument order and duplicates are ignored because the key is built with `toSet`.
+ * New indices are assigned as `aggregateCount + distinctAccMap.size`. Distinct accumulators
+ * are therefore numbered after the `aggregateCount` regular accumulator slots.
+ *
+ * @param distinctAccMap argument index set -> distinct accumulator index; updated in place
+ * @param argList input field indices of the aggregate call
+ * @param aggregateCount number of aggregate calls
+ * @param inputFieldsCount number of input fields. Argument indices outside
+ * [0, inputFieldsCount) are left out of the key (presumably constant arguments)
+ * @return the existing or newly assigned distinct accumulator index
+ */
+ private def getDistinctAccIndex(
+ distinctAccMap: mutable.Map[util.Set[Integer], Integer],
+ argList: util.List[Integer],
+ aggregateCount: Int,
+ inputFieldsCount: Int): Int = {
+
+ // Key on the set of real input field indices: order and duplicates do not matter.
+ // NOTE(review): relies on implicit Java<->Scala collection conversions for the key type.
+ val argListWithoutConstants = argList.toSet.filter(i => i > -1 && i < inputFieldsCount)
+ distinctAccMap.get(argListWithoutConstants) match {
+ case None =>
+ // First distinct aggregate with this argument set: allocate the next free index.
+ val index: Int = aggregateCount + distinctAccMap.size
+ distinctAccMap.put(argListWithoutConstants, index)
+ index
+
+ // Reuse the accumulator of an earlier aggregate over the same argument set.
+ case Some(index) => index
}
- new PojoTypeInfo(classOf[DistinctAccumulator[_]], pojoFields)
}
/**
@@ -1308,6 +1335,7 @@ object AggregateUtil {
*
* @param aggregateCalls calcite's aggregate function
* @param aggregateInputType input type of given aggregates
+ * @param inputFieldsCount number of input fields
* @param needRetraction if the [[TableAggregateFunction]] should produce retractions
* @param tableConfig tableConfig, required for decimal precision
* @param isStateBackedDataViews if data should be backed by state backend
@@ -1320,11 +1348,14 @@ object AggregateUtil {
private def extractAggregateMetadata(
aggregateCalls: Seq[AggregateCall],
aggregateInputType: RelDataType,
+ inputFieldsCount: Int,
needRetraction: Boolean,
tableConfig: TableConfig,
isStateBackedDataViews: Boolean = false)
: AggregateMetadata = {
+ val distinctAccMap = mutable.Map[util.Set[Integer], Integer]()
+
val aggregatesWithIndices = aggregateCalls.zipWithIndex.map {
case (aggregateCall, index) =>
val argList: util.List[Integer] = aggregateCall.getArgList
@@ -1340,8 +1371,13 @@ object AggregateUtil {
}
val inputTypes = argList.map(aggregateInputType.getFieldList.get(_).getType)
- val aggregateCallMetadata = extractAggregateCallMetadata(aggregateCall.getAggregation,
+ val aggregateCallMetadata = extractAggregateCallMetadata(
+ aggregateCall.getAggregation,
aggregateCall.isDistinct,
+ distinctAccMap,
+ argList,
+ aggregateCalls.length,
+ inputFieldsCount,
inputTypes,
needRetraction,
tableConfig,
@@ -1351,8 +1387,19 @@ object AggregateUtil {
(aggregateCallMetadata, aggFieldIndices)
}
+ val distinctAccType = Types.POJO(classOf[DistinctAccumulator])
+
+ val distinctAccTypesWithSpecs = (0 until distinctAccMap.size).map { idx =>
+ val (accType, accSpec) = removeStateViewFieldsFromAccTypeInfo(
+ aggregateCalls.length + idx,
+ new DistinctAccumulator(),
+ distinctAccType,
+ isStateBackedDataViews)
+ (accType, accSpec.getOrElse(Seq()))
+ }
+
// store the aggregate fields of each aggregate function, by the same order of aggregates.
- new AggregateMetadata(aggregatesWithIndices)
+ new AggregateMetadata(aggregatesWithIndices, distinctAccTypesWithSpecs)
}
/**
diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java
index 0483c40..d22fc18 100644
--- a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java
+++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java
@@ -349,4 +349,80 @@ public class JavaUserDefinedAggFunctions {
isCloseCalled = true;
}
}
+
+ /**
+ * Accumulator of {@link MultiArgCount}: the number of counted rows.
+ */
+ public static class MultiArgCountAcc {
+ public long count;
+ }
+
+ /**
+ * Count aggregate function with two arguments. It counts rows in which both arguments
+ * are non-null, and supports retraction and merging.
+ */
+ public static class MultiArgCount extends AggregateFunction<Long, MultiArgCountAcc> {
+
+ /** Creates an accumulator whose count starts at 0. */
+ @Override
+ public MultiArgCountAcc createAccumulator() {
+ MultiArgCountAcc acc = new MultiArgCountAcc();
+ acc.count = 0L;
+ return acc;
+ }
+
+ /** Increments the count if both arguments are non-null. */
+ public void accumulate(MultiArgCountAcc acc, Object in1, Object in2) {
+ if (in1 != null && in2 != null) {
+ acc.count += 1;
+ }
+ }
+
+ /** Decrements the count if both arguments are non-null (inverse of accumulate). */
+ public void retract(MultiArgCountAcc acc, Object in1, Object in2) {
+ if (in1 != null && in2 != null) {
+ acc.count -= 1;
+ }
+ }
+
+ /** Adds the counts of all accumulators in {@code iterable} to {@code accumulator}. */
+ public void merge(MultiArgCountAcc accumulator, java.lang.Iterable<MultiArgCountAcc> iterable) {
+ for (MultiArgCountAcc otherAcc : iterable) {
+ accumulator.count += otherAcc.count;
+ }
+ }
+
+ /** Returns the current count. */
+ @Override
+ public Long getValue(MultiArgCountAcc acc) {
+ return acc.count;
+ }
+ }
+
+ /**
+ * Accumulator of {@link MultiArgSum}. Despite its name, {@code count} holds the running
+ * sum, not a row count.
+ */
+ public static class MultiArgSumAcc {
+ public long count;
+ }
+
+ /**
+ * Sum aggregate function with two arguments. It adds {@code in1 + in2} for every
+ * accumulated row and supports retraction. No merge method is defined.
+ * NOTE(review): arguments are primitive {@code long}; NULL inputs are not handled here.
+ * Confirm callers never pass NULL.
+ */
+ public static class MultiArgSum extends AggregateFunction<Long, MultiArgSumAcc> {
+
+ /** Creates an accumulator whose sum starts at 0. */
+ @Override
+ public MultiArgSumAcc createAccumulator() {
+ MultiArgSumAcc acc = new MultiArgSumAcc();
+ acc.count = 0L;
+ return acc;
+ }
+
+ /** Adds {@code in1 + in2} to the running sum. */
+ public void accumulate(MultiArgSumAcc acc, long in1, long in2) {
+ acc.count += in1 + in2;
+ }
+
+ /** Subtracts {@code in1 + in2} from the running sum (inverse of accumulate). */
+ public void retract(MultiArgSumAcc acc, long in1, long in2) {
+ acc.count -= in1 + in2;
+ }
+
+ /** Returns the current sum. */
+ @Override
+ public Long getValue(MultiArgSumAcc acc) {
+ return acc.count;
+ }
+ }
}
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/GroupAggregateHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/GroupAggregateHarnessTest.scala
index 1dce994..cc81161 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/GroupAggregateHarnessTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/GroupAggregateHarnessTest.scala
@@ -22,13 +22,22 @@ import java.util.concurrent.ConcurrentLinkedQueue
import org.apache.flink.api.common.time.Time
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
+import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.operators.LegacyKeyedProcessOperator
+import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.runtime.aggregate._
import org.apache.flink.table.runtime.harness.HarnessTestBase._
import org.apache.flink.table.runtime.types.CRow
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{MultiArgCount, MultiArgSum}
+import org.apache.flink.types.Row
+import org.junit.Assert.assertTrue
import org.junit.Test
+import scala.collection.mutable
+
class GroupAggregateHarnessTest extends HarnessTestBase {
protected var queryConfig =
@@ -176,68 +185,182 @@ class GroupAggregateHarnessTest extends HarnessTestBase {
+ // count(distinct b) and sum(distinct b) take the same argument and must therefore share a
+ // single DistinctAccumulator; also checks retraction and state cleanup behaviour.
@Test
def testDistinctAggregateWithRetract(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+
+ val data = new mutable.MutableList[(JLong, JInt)]
+ val t = env.fromCollection(data).toTable(tEnv, 'a, 'b)
+ tEnv.registerTable("T", t)
+ val sqlQuery = tEnv.sqlQuery(
+ s"""
+ |SELECT
+ | a, count(distinct b), sum(distinct b)
+ |FROM (
+ | SELECT a, b
+ | FROM T
+ | GROUP BY a, b
+ |) GROUP BY a
+ |""".stripMargin)
+
+ val testHarness = createHarnessTester[String, CRow, CRow](
+ sqlQuery.toRetractStream[Row](queryConfig), "groupBy")
+
+ testHarness.setStateBackend(getStateBackend)
+ testHarness.open()
- val processFunction = new LegacyKeyedProcessOperator[String, CRow, CRow](
- new GroupAggProcessFunction(
- genDistinctCountAggFunction,
- distinctCountAggregationStateType,
- true,
- queryConfig))
+ val operator = getOperator(testHarness)
+ val fields = getGeneratedAggregationFields(
+ operator,
+ "function",
+ classOf[GroupAggProcessFunction])
- val testHarness =
- createHarnessTester(
- processFunction,
- new TupleRowKeySelector[String](2),
- BasicTypeInfo.STRING_TYPE_INFO)
+ // check only one DistinctAccumulator is used
+ assertTrue(fields.count(_.getName.endsWith("distinctValueMap_dataview")) == 1)
+
+ val expectedOutput = new ConcurrentLinkedQueue[Object]()
+
+ // register cleanup timer with 3001
+ testHarness.setProcessingTime(1)
+
+ // insert
+ testHarness.processElement(new StreamRecord(CRow(1L: JLong, 1: JInt)))
+ expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong, 1: JInt)))
+ testHarness.processElement(new StreamRecord(CRow(2L: JLong, 1: JInt)))
+ expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1L: JLong, 1: JInt)))
+
+ // distinct count retract then accumulate for downstream operators
+ testHarness.processElement(new StreamRecord(CRow(2L: JLong, 1: JInt)))
+ expectedOutput.add(new StreamRecord(CRow(false, 2L: JLong, 1L: JLong, 1: JInt)))
+ expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1L: JLong, 1: JInt)))
+
+ // update count for accumulate
+ testHarness.processElement(new StreamRecord(CRow(1L: JLong, 2: JInt)))
+ expectedOutput.add(new StreamRecord(CRow(false, 1L: JLong, 1L: JLong, 1: JInt)))
+ expectedOutput.add(new StreamRecord(CRow(1L: JLong, 2L: JLong, 3: JInt)))
+
+ // update count for retraction
+ testHarness.processElement(new StreamRecord(CRow(false, 1L: JLong, 2: JInt)))
+ expectedOutput.add(new StreamRecord(CRow(false, 1L: JLong, 2L: JLong, 3: JInt)))
+ expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong, 1: JInt)))
+
+ // insert
+ testHarness.processElement(new StreamRecord(CRow(4L: JLong, 3: JInt)))
+ expectedOutput.add(new StreamRecord(CRow(4L: JLong, 1L: JLong, 3: JInt)))
+
+ // retract entirely
+ testHarness.processElement(new StreamRecord(CRow(false, 4L: JLong, 3: JInt)))
+ expectedOutput.add(new StreamRecord(CRow(false, 4L: JLong, 1L: JLong, 3: JInt)))
+
+ // trigger cleanup timer and register cleanup timer with 6002
+ testHarness.setProcessingTime(3002)
+
+ // insert
+ testHarness.processElement(new StreamRecord(CRow(1L: JLong, 1: JInt)))
+ expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong, 1: JInt)))
+
+ // trigger cleanup timer and register cleanup timer with 9002
+ testHarness.setProcessingTime(6002)
+
+ // retract after cleanup: retract the (a, b) row inserted above; the state has been
+ // cleared, so no output is expected for it
+ testHarness.processElement(new StreamRecord(CRow(false, 1L: JLong, 1: JInt)))
+ val result = testHarness.getOutput
+
+ verify(expectedOutput, result)
+
+ testHarness.close()
+ }
+
+ @Test
+ def testDistinctAggregateWithDifferentArgumentOrder(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+
+ val data = new mutable.MutableList[(JLong, JLong, JLong)]
+ val t = env.fromCollection(data).toTable(tEnv, 'a, 'b, 'c)
+ tEnv.registerTable("T", t)
+ tEnv.registerFunction("myCount", new MultiArgCount)
+ tEnv.registerFunction("mySum", new MultiArgSum)
+ val sqlQuery = tEnv.sqlQuery(
+ s"""
+ |SELECT
+ | a, myCount(distinct b, c), mySum(distinct c, b)
+ |FROM (
+ | SELECT a, b, c
+ | FROM T
+ | GROUP BY a, b, c
+ |) GROUP BY a
+ |""".stripMargin)
+
+ val testHarness = createHarnessTester[String, CRow, CRow](
+ sqlQuery.toRetractStream[Row](queryConfig), "groupBy")
+
+ testHarness.setStateBackend(getStateBackend)
testHarness.open()
+ val operator = getOperator(testHarness)
+ val fields = getGeneratedAggregationFields(
+ operator,
+ "function",
+ classOf[GroupAggProcessFunction])
+
+ // check only one DistinctAccumulator is used
+ assertTrue(fields.count(_.getName.endsWith("distinctValueMap_dataview")) == 1)
+
val expectedOutput = new ConcurrentLinkedQueue[Object]()
// register cleanup timer with 3001
testHarness.setProcessingTime(1)
// insert
- testHarness.processElement(new StreamRecord(CRow(1L: JLong, 1: JInt, "aaa")))
- expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong)))
- testHarness.processElement(new StreamRecord(CRow(2L: JLong, 1: JInt, "bbb")))
- expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1L: JLong)))
+ testHarness.processElement(new StreamRecord(CRow(1L: JLong, 1L: JLong, 1L: JLong)))
+ expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong, 2L: JLong)))
+ testHarness.processElement(new StreamRecord(CRow(2L: JLong, 1L: JLong, 1L: JLong)))
+ expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1L: JLong, 2L: JLong)))
// distinct count retract then accumulate for downstream operators
- testHarness.processElement(new StreamRecord(CRow(2L: JLong, 1: JInt, "bbb")))
- expectedOutput.add(new StreamRecord(CRow(false, 2L: JLong, 1L: JLong)))
- expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1L: JLong)))
+ testHarness.processElement(new StreamRecord(CRow(2L: JLong, 1L: JLong, 1L: JLong)))
+ expectedOutput.add(new StreamRecord(CRow(false, 2L: JLong, 1L: JLong, 2L: JLong)))
+ expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1L: JLong, 2L: JLong)))
// update count for accumulate
- testHarness.processElement(new StreamRecord(CRow(1L: JLong, 2: JInt, "aaa")))
- expectedOutput.add(new StreamRecord(CRow(false, 1L: JLong, 1L: JLong)))
- expectedOutput.add(new StreamRecord(CRow(1L: JLong, 2L: JLong)))
+ testHarness.processElement(new StreamRecord(CRow(1L: JLong, 2L: JLong, 3L: JLong)))
+ expectedOutput.add(new StreamRecord(CRow(false, 1L: JLong, 1L: JLong, 2L: JLong)))
+ expectedOutput.add(new StreamRecord(CRow(1L: JLong, 2L: JLong, 7L: JLong)))
+
+ testHarness.processElement(new StreamRecord(CRow(1L: JLong, 2L: JLong, 3L: JLong)))
+ expectedOutput.add(new StreamRecord(CRow(false, 1L: JLong, 2L: JLong, 7L: JLong)))
+ expectedOutput.add(new StreamRecord(CRow(1L: JLong, 2L: JLong, 7L: JLong)))
// update count for retraction
- testHarness.processElement(new StreamRecord(CRow(false, 1L: JLong, 2: JInt, "aaa")))
- expectedOutput.add(new StreamRecord(CRow(false, 1L: JLong, 2L: JLong)))
- expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong)))
+ testHarness.processElement(new StreamRecord(CRow(false, 1L: JLong, 2L: JLong, 3L: JLong)))
+ expectedOutput.add(new StreamRecord(CRow(false, 1L: JLong, 2L: JLong, 7L: JLong)))
+ expectedOutput.add(new StreamRecord(CRow(1L: JLong, 2L: JLong, 7L: JLong)))
+
+ testHarness.processElement(new StreamRecord(CRow(false, 1L: JLong, 2L: JLong, 3L: JLong)))
+ expectedOutput.add(new StreamRecord(CRow(false, 1L: JLong, 2L: JLong, 7L: JLong)))
+ expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong, 2L: JLong)))
// insert
- testHarness.processElement(new StreamRecord(CRow(4L: JLong, 3: JInt, "ccc")))
- expectedOutput.add(new StreamRecord(CRow(4L: JLong, 1L: JLong)))
+ testHarness.processElement(new StreamRecord(CRow(4L: JLong, 3L: JLong, 3L: JLong)))
+ expectedOutput.add(new StreamRecord(CRow(4L: JLong, 1L: JLong, 6L: JLong)))
// retract entirely
- testHarness.processElement(new StreamRecord(CRow(false, 4L: JLong, 3: JInt, "ccc")))
- expectedOutput.add(new StreamRecord(CRow(false, 4L: JLong, 1L: JLong)))
+ testHarness.processElement(new StreamRecord(CRow(false, 4L: JLong, 3L: JLong, 3L: JLong)))
+ expectedOutput.add(new StreamRecord(CRow(false, 4L: JLong, 1L: JLong, 6L: JLong)))
// trigger cleanup timer and register cleanup timer with 6002
testHarness.setProcessingTime(3002)
// insert
- testHarness.processElement(new StreamRecord(CRow(1L: JLong, 1: JInt, "aaa")))
- expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong)))
+ testHarness.processElement(new StreamRecord(CRow(1L: JLong, 1L: JLong, 2L: JLong)))
+ expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong, 3L: JLong)))
// trigger cleanup timer and register cleanup timer with 9002
testHarness.setProcessingTime(6002)
// retract after cleanup
- testHarness.processElement(new StreamRecord(CRow(false, 1L: JLong, 1: JInt, "aaa")))
+ testHarness.processElement(new StreamRecord(CRow(false, 1L: JLong, 1: JInt, 2L: JLong)))
val result = testHarness.getOutput
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
index c37fd0c..cf7dc07 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
@@ -17,9 +17,9 @@
*/
package org.apache.flink.table.runtime.harness
+import java.lang.reflect.Field
import java.util.{Comparator, Queue => JQueue}
-import org.apache.flink.api.common.state.{MapStateDescriptor, StateDescriptor}
import org.apache.flink.api.common.time.Time
import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{LONG_TYPE_INFO, STRING_TYPE_INFO}
import org.apache.flink.api.common.typeinfo.TypeInformation
@@ -32,7 +32,7 @@ import org.apache.flink.streaming.api.watermark.Watermark
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, OneInputStreamOperatorTestHarness, TestHarnessUtil}
import org.apache.flink.table.api.dataview.DataView
-import org.apache.flink.table.api.{StreamQueryConfig, Types}
+import org.apache.flink.table.api.StreamQueryConfig
import org.apache.flink.table.codegen.GeneratedAggregationsFunction
import org.apache.flink.table.functions.aggfunctions.{CountAggFunction, IntSumWithRetractAggFunction, LongMaxWithRetractAggFunction, LongMinWithRetractAggFunction}
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.getAccumulatorTypeOfAggregateFunction
@@ -83,15 +83,8 @@ class HarnessTestBase extends StreamingWithStateTestBase {
protected val sumAggregationStateType: RowTypeInfo =
new RowTypeInfo(sumAggregates.map(getAccumulatorTypeOfAggregateFunction(_)): _*)
- protected val distinctCountAggregationStateType: RowTypeInfo =
- new RowTypeInfo(distinctCountAggregates.map(getAccumulatorTypeOfAggregateFunction(_)): _*)
-
- protected val distinctCountDescriptor: String = EncodingUtils.encodeObjectToString(
- new MapStateDescriptor("distinctAgg0", distinctCountAggregationStateType, Types.LONG))
-
protected val minMaxFuncName = "MinMaxAggregateHelper"
protected val sumFuncName = "SumAggregationHelper"
- protected val distinctCountFuncName = "DistinctCountAggregationHelper"
val minMaxCode: String =
s"""
@@ -326,170 +319,8 @@ class HarnessTestBase extends StreamingWithStateTestBase {
|}
|""".stripMargin
- val distinctCountAggCode: String =
- s"""
- |public final class $distinctCountFuncName
- | extends org.apache.flink.table.runtime.aggregate.GeneratedAggregations {
- |
- | final org.apache.flink.table.functions.aggfunctions.CountAggFunction count;
- |
- | final org.apache.flink.table.api.dataview.MapView acc0_distinctValueMap_dataview;
- |
- | final java.lang.reflect.Field distinctValueMap =
- | org.apache.flink.api.java.typeutils.TypeExtractor.getDeclaredField(
- | org.apache.flink.table.functions.aggfunctions.DistinctAccumulator.class,
- | "distinctValueMap");
- |
- |
- | private final org.apache.flink.table.runtime.aggregate.SingleElementIterable<org.apache
- | .flink.table.functions.aggfunctions.CountAccumulator> accIt0 =
- | new org.apache.flink.table.runtime.aggregate.SingleElementIterable<org.apache.flink
- | .table
- | .functions.aggfunctions.CountAccumulator>();
- |
- | public $distinctCountFuncName() throws Exception {
- |
- | count = (org.apache.flink.table.functions.aggfunctions.CountAggFunction)
- | ${classOf[EncodingUtils].getCanonicalName}.decodeStringToObject(
- | "$distinctCountAggFunction",
- | ${classOf[UserDefinedFunction].getCanonicalName}.class);
- |
- | distinctValueMap.setAccessible(true);
- | }
- |
- | public void open(org.apache.flink.api.common.functions.RuntimeContext ctx) {
- | org.apache.flink.api.common.state.StateDescriptor acc0_distinctValueMap_dataview_desc =
- | (org.apache.flink.api.common.state.StateDescriptor)
- | ${classOf[EncodingUtils].getCanonicalName}.decodeStringToObject(
- | "$distinctCountDescriptor",
- | ${classOf[StateDescriptor[_, _]].getCanonicalName}.class,
- | ctx.getUserCodeClassLoader());
- | acc0_distinctValueMap_dataview = new org.apache.flink.table.dataview.StateMapView(
- | ctx.getMapState((org.apache.flink.api.common.state.MapStateDescriptor)
- | acc0_distinctValueMap_dataview_desc));
- | }
- |
- | public final void setAggregationResults(
- | org.apache.flink.types.Row accs,
- | org.apache.flink.types.Row output) {
- |
- | org.apache.flink.table.functions.AggregateFunction baseClass0 =
- | (org.apache.flink.table.functions.AggregateFunction)
- | count;
- |
- | org.apache.flink.table.functions.aggfunctions.DistinctAccumulator distinctAcc0 =
- | (org.apache.flink.table.functions.aggfunctions.DistinctAccumulator) accs.getField(0);
- | org.apache.flink.table.functions.aggfunctions.CountAccumulator acc0 =
- | (org.apache.flink.table.functions.aggfunctions.CountAccumulator)
- | distinctAcc0.getRealAcc();
- |
- | output.setField(1, baseClass0.getValue(acc0));
- | }
- |
- | public final void accumulate(
- | org.apache.flink.types.Row accs,
- | org.apache.flink.types.Row input) throws Exception {
- |
- | org.apache.flink.table.functions.aggfunctions.DistinctAccumulator distinctAcc0 =
- | (org.apache.flink.table.functions.aggfunctions.DistinctAccumulator) accs.getField(0);
- |
- | distinctValueMap.set(distinctAcc0, acc0_distinctValueMap_dataview);
- |
- | if (distinctAcc0.add(
- | org.apache.flink.types.Row.of((java.lang.Integer) input.getField(1)))) {
- | org.apache.flink.table.functions.aggfunctions.CountAccumulator acc0 =
- | (org.apache.flink.table.functions.aggfunctions.CountAccumulator)
- | distinctAcc0.getRealAcc();
- |
- |
- | count.accumulate(acc0, (java.lang.Integer) input.getField(1));
- | }
- | }
- |
- | public final void retract(
- | org.apache.flink.types.Row accs,
- | org.apache.flink.types.Row input) throws Exception {
- |
- | org.apache.flink.table.functions.aggfunctions.DistinctAccumulator distinctAcc0 =
- | (org.apache.flink.table.functions.aggfunctions.DistinctAccumulator) accs.getField(0);
- |
- | distinctValueMap.set(distinctAcc0, acc0_distinctValueMap_dataview);
- |
- | if (distinctAcc0.remove(
- | org.apache.flink.types.Row.of((java.lang.Integer) input.getField(1)))) {
- | org.apache.flink.table.functions.aggfunctions.CountAccumulator acc0 =
- | (org.apache.flink.table.functions.aggfunctions.CountAccumulator)
- | distinctAcc0.getRealAcc();
- |
- | count.retract(acc0 , (java.lang.Integer) input.getField(1));
- | }
- | }
- |
- | public final org.apache.flink.types.Row createAccumulators()
- | {
- |
- | org.apache.flink.types.Row accs = new org.apache.flink.types.Row(1);
- |
- | org.apache.flink.table.functions.aggfunctions.CountAccumulator acc0 =
- | (org.apache.flink.table.functions.aggfunctions.CountAccumulator)
- | count.createAccumulator();
- | org.apache.flink.table.functions.aggfunctions.DistinctAccumulator distinctAcc0 =
- | (org.apache.flink.table.functions.aggfunctions.DistinctAccumulator)
- | new org.apache.flink.table.functions.aggfunctions.DistinctAccumulator (acc0);
- | accs.setField(
- | 0,
- | distinctAcc0);
- |
- | return accs;
- | }
- |
- | public final void setForwardedFields(
- | org.apache.flink.types.Row input,
- | org.apache.flink.types.Row output)
- | {
- |
- | output.setField(
- | 0,
- | input.getField(0));
- | }
- |
- | public final void setConstantFlags(org.apache.flink.types.Row output)
- | {
- |
- | }
- |
- | public final org.apache.flink.types.Row createOutputRow() {
- | return new org.apache.flink.types.Row(2);
- | }
- |
- |
- | public final org.apache.flink.types.Row mergeAccumulatorsPair(
- | org.apache.flink.types.Row a,
- | org.apache.flink.types.Row b)
- | {
- |
- | return a;
- |
- | }
- |
- | public final void resetAccumulator(
- | org.apache.flink.types.Row accs) {
- | }
- |
- | public void cleanup() {
- | acc0_distinctValueMap_dataview.clear();
- | }
- |
- | public void close() {
- | }
- |}
- |""".stripMargin
-
protected val genMinMaxAggFunction = GeneratedAggregationsFunction(minMaxFuncName, minMaxCode)
protected val genSumAggFunction = GeneratedAggregationsFunction(sumFuncName, sumAggCode)
- protected val genDistinctCountAggFunction = GeneratedAggregationsFunction(
- distinctCountFuncName,
- distinctCountAggCode)
def createHarnessTester[KEY, IN, OUT](
dataStream: DataStream[_],
@@ -553,6 +384,20 @@ class HarnessTestBase extends StreamingWithStateTestBase {
stateField.get(generatedAggregation).asInstanceOf[DataView]
}
+ def getGeneratedAggregationFields(
+ operator: AbstractUdfStreamOperator[_, _],
+ funcName: String,
+ funcClass: Class[_]): Array[Field] = {
+ val function = funcClass.getDeclaredField(funcName)
+ function.setAccessible(true)
+ val generatedAggregation =
+ function.get(operator.getUserFunction).asInstanceOf[GeneratedAggregations]
+ val cls = generatedAggregation.getClass
+ val fields = cls.getDeclaredFields
+ fields.foreach(_.setAccessible(true))
+ fields
+ }
+
def createHarnessTester[IN, OUT, KEY](
operator: OneInputStreamOperator[IN, OUT],
keySelector: KeySelector[IN, KEY],
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala
index d152804..1d2dd5e 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala
@@ -31,6 +31,7 @@ import org.apache.flink.table.api.scala._
import org.apache.flink.table.runtime.utils.TimeTestUtil.EventTimeSourceFunction
import org.apache.flink.table.api.{StreamQueryConfig, TableEnvironment}
import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.MultiArgCount
import org.apache.flink.table.runtime.utils.{StreamITCase, StreamTestData, StreamingWithStateTestBase}
import org.apache.flink.types.Row
import org.junit.Assert._
@@ -911,6 +912,7 @@ class OverWindowITCase extends StreamingWithStateTestBase {
val t = StreamTestData.get5TupleDataStream(env)
.toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime.proctime)
tEnv.registerTable("MyTable", t)
+ tEnv.registerFunction("myCount", new MultiArgCount)
val sqlQuery = "SELECT a, " +
" COUNT(e) OVER (" +
@@ -918,6 +920,10 @@ class OverWindowITCase extends StreamingWithStateTestBase {
" SUM(DISTINCT e) OVER (" +
" PARTITION BY a ORDER BY proctime RANGE UNBOUNDED preceding), " +
" MIN(DISTINCT e) OVER (" +
+ " PARTITION BY a ORDER BY proctime RANGE UNBOUNDED preceding), " +
+ " myCount(DISTINCT e, 1) OVER (" +
+ " PARTITION BY a ORDER BY proctime RANGE UNBOUNDED preceding), " +
+ " myCount(DISTINCT 1, e) OVER (" +
" PARTITION BY a ORDER BY proctime RANGE UNBOUNDED preceding) " +
"FROM MyTable"
@@ -926,21 +932,21 @@ class OverWindowITCase extends StreamingWithStateTestBase {
env.execute()
val expected = List(
- "1,1,1,1",
- "2,1,2,2",
- "2,2,3,1",
- "3,1,2,2",
- "3,2,2,2",
- "3,3,5,2",
- "4,1,2,2",
- "4,2,3,1",
- "4,3,3,1",
- "4,4,3,1",
- "5,1,1,1",
- "5,2,4,1",
- "5,3,4,1",
- "5,4,6,1",
- "5,5,6,1")
+ "1,1,1,1,1,1",
+ "2,1,2,2,1,1",
+ "2,2,3,1,2,2",
+ "3,1,2,2,1,1",
+ "3,2,2,2,1,1",
+ "3,3,5,2,2,2",
+ "4,1,2,2,1,1",
+ "4,2,3,1,2,2",
+ "4,3,3,1,2,2",
+ "4,4,3,1,2,2",
+ "5,1,1,1,1,1",
+ "5,2,4,1,2,2",
+ "5,3,4,1,2,2",
+ "5,4,6,1,3,3",
+ "5,5,6,1,3,3")
assertEquals(expected, StreamITCase.testResults)
}
@@ -973,6 +979,7 @@ class OverWindowITCase extends StreamingWithStateTestBase {
tEnv.registerTable("MyTable", table)
tEnv.registerFunction("CntNullNonNull", new CountNullNonNull)
+ tEnv.registerFunction("myCount", new MultiArgCount)
val sqlQuery = "SELECT " +
" c, " +
@@ -980,6 +987,10 @@ class OverWindowITCase extends StreamingWithStateTestBase {
" COUNT(DISTINCT c) " +
" OVER (PARTITION BY b ORDER BY rtime RANGE UNBOUNDED preceding), " +
" CntNullNonNull(DISTINCT c) " +
+ " OVER (PARTITION BY b ORDER BY rtime RANGE UNBOUNDED preceding), " +
+ " myCount(DISTINCT c, 1) " +
+ " OVER (PARTITION BY b ORDER BY rtime RANGE UNBOUNDED preceding), " +
+ " myCount(DISTINCT 1, c) " +
" OVER (PARTITION BY b ORDER BY rtime RANGE UNBOUNDED preceding)" +
"FROM " +
" MyTable"
@@ -989,9 +1000,9 @@ class OverWindowITCase extends StreamingWithStateTestBase {
env.execute()
val expected = List(
- "null,1,0,0|1", "null,1,0,0|1", "null,2,0,0|1", "null,1,2,2|1",
- "Hello,1,1,1|1", "Hello,1,1,1|1", "Hello,2,1,1|1",
- "Hello World,1,2,2|1", "Hello World,2,2,2|1", "Hello World,2,2,2|1")
+ "null,1,0,0|1,0,0", "null,1,0,0|1,0,0", "null,2,0,0|1,0,0", "null,1,2,2|1,2,2",
+ "Hello,1,1,1|1,1,1", "Hello,1,1,1|1,1,1", "Hello,2,1,1|1,1,1",
+ "Hello World,1,2,2|1,2,2", "Hello World,2,2,2|1,2,2", "Hello World,2,2,2|1,2,2")
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
index ddc2a68..818ba65 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
@@ -30,6 +30,7 @@ import org.apache.flink.table.api.{TableEnvironment, Types}
import org.apache.flink.table.descriptors.{Rowtime, Schema}
import org.apache.flink.table.expressions.utils.Func15
import org.apache.flink.table.runtime.stream.sql.SqlITCase.TimestampAndWatermarkWithOffset
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.MultiArgCount
import org.apache.flink.table.runtime.utils.TimeTestUtil.EventTimeSourceFunction
import org.apache.flink.table.runtime.utils.{JavaUserDefinedTableFunctions, StreamITCase, StreamTestData, StreamingWithStateTestBase}
import org.apache.flink.table.utils.{InMemoryTableFactory, MemoryTableSourceSinkUtil}
@@ -70,15 +71,19 @@ class SqlITCase extends StreamingWithStateTestBase {
StreamITCase.clear
val stream = env
.fromCollection(sessionWindowTestData)
- .assignTimestampsAndWatermarks(new TimestampAndWatermarkWithOffset[(Long, Int, String)](10L))
+ .assignTimestampsAndWatermarks(
+ new TimestampAndWatermarkWithOffset[(Long, Int, String)](10L))
val tEnv = TableEnvironment.getTableEnvironment(env)
val table = stream.toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)
tEnv.registerTable("MyTable", table)
+ tEnv.registerFunction("myCount", new MultiArgCount)
val sqlQuery = "SELECT c, " +
" COUNT(DISTINCT b)," +
" SUM(DISTINCT b)," +
+ " myCount(DISTINCT b, 1)," +
+ " myCount(DISTINCT 1, b)," +
" SESSION_END(rowtime, INTERVAL '0.005' SECOND) " +
"FROM MyTable " +
"GROUP BY SESSION(rowtime, INTERVAL '0.005' SECOND), c "
@@ -88,10 +93,10 @@ class SqlITCase extends StreamingWithStateTestBase {
env.execute()
val expected = Seq(
- "Hello World,1,9,1970-01-01 00:00:00.014", // window starts at [9L] till {14L}
- "Hello,1,16,1970-01-01 00:00:00.021", // window starts at [16L] till {21L}, not merged
- "Hello,3,6,1970-01-01 00:00:00.015" // window starts at [1L,2L],
- // merged with [8L,10L], by [4L], till {15L}
+ "Hello World,1,9,1,1,1970-01-01 00:00:00.014", // window starts at [9L] till {14L}
+ "Hello,1,16,1,1,1970-01-01 00:00:00.021", // window starts at [16L] till {21L}, not merged
+ "Hello,3,6,3,3,1970-01-01 00:00:00.015" // window starts at [1L,2L],
+ // merged with [8L,10L], by [4L], till {15L}
)
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
index 2e9dac5..2d1666f 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
@@ -88,15 +88,30 @@ class AggregateITCase extends StreamingWithStateTestBase {
val tEnv = TableEnvironment.getTableEnvironment(env)
StreamITCase.clear
- val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e)
- .groupBy('e)
- .select('e, 'a.count.distinct)
+ val data = new mutable.MutableList[(Int, Int, String)]
+ data.+=((1, 1, "A"))
+ data.+=((2, 2, "B"))
+ data.+=((2, 2, "B"))
+ data.+=((4, 3, "C"))
+ data.+=((5, 3, "C"))
+ data.+=((4, 3, "C"))
+ data.+=((7, 3, "B"))
+ data.+=((1, 4, "A"))
+ data.+=((9, 4, "D"))
+ data.+=((4, 1, "A"))
+ data.+=((3, 2, "B"))
+
+ val testAgg = new WeightedAvg
+ val t = env.fromCollection(data).toTable(tEnv, 'a, 'b, 'c)
+ .groupBy('c)
+ .select('c, 'a.count.distinct, 'a.sum.distinct,
+ testAgg.distinct('a, 'b), testAgg.distinct('b, 'a), testAgg('a, 'b))
val results = t.toRetractStream[Row](queryConfig)
results.addSink(new StreamITCase.RetractingSink).setParallelism(1)
env.execute()
- val expected = mutable.MutableList("1,4", "2,4", "3,2")
+ val expected = mutable.MutableList("A,2,5,1,1,1", "B,3,12,4,2,3", "C,2,9,4,3,4", "D,1,9,9,4,9")
assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
}