You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ku...@apache.org on 2019/05/23 06:42:11 UTC
[flink] branch master updated: [FLINK-12559][table-planner-blink]
Introduce metadata handlers on window aggregate
This is an automated email from the ASF dual-hosted git repository.
kurt pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new d195381 [FLINK-12559][table-planner-blink] Introduce metadata handlers on window aggregate
d195381 is described below
commit d1953818acf3dee35fc61e3b79c4471aa6329279
Author: godfrey he <go...@163.com>
AuthorDate: Thu May 23 14:41:56 2019 +0800
[FLINK-12559][table-planner-blink] Introduce metadata handlers on window aggregate
This closes #8487
---
.../metadata/AggCallSelectivityEstimator.scala | 10 +-
.../plan/metadata/FlinkRelMdColumnInterval.scala | 213 ++++++---
.../plan/metadata/FlinkRelMdColumnUniqueness.scala | 104 +++-
.../plan/metadata/FlinkRelMdDistinctRowCount.scala | 84 +++-
.../FlinkRelMdFilteredColumnInterval.scala | 10 +-
.../metadata/FlinkRelMdModifiedMonotonicity.scala | 2 +-
.../plan/metadata/FlinkRelMdPopulationSize.scala | 57 ++-
.../table/plan/metadata/FlinkRelMdRowCount.scala | 58 ++-
.../plan/metadata/FlinkRelMdSelectivity.scala | 41 +-
.../flink/table/plan/metadata/FlinkRelMdSize.scala | 12 +-
.../plan/metadata/FlinkRelMdUniqueGroups.scala | 64 ++-
.../table/plan/metadata/FlinkRelMdUniqueKeys.scala | 96 +++-
...indow.scala => FlinkLogicalOverAggregate.scala} | 17 +-
.../logical/FlinkLogicalTableFunctionScan.scala | 1 -
.../batch/BatchExecLocalHashWindowAggregate.scala | 2 +-
.../batch/BatchExecLocalSortWindowAggregate.scala | 2 +-
.../physical/batch/BatchExecOverAggregate.scala | 2 +-
.../stream/StreamExecGroupWindowAggregate.scala | 10 +-
.../table/plan/rules/FlinkBatchRuleSets.scala | 4 +-
.../table/plan/rules/FlinkStreamRuleSets.scala | 2 +-
.../plan/rules/logical/FlinkLogicalRankRule.scala | 16 +-
...Rule.scala => BatchExecOverAggregateRule.scala} | 18 +-
.../stream/StreamExecOverAggregateRule.scala | 8 +-
.../flink/table/plan/util/AggregateUtil.scala | 33 +-
.../flink/table/plan/util/FlinkRelMdUtil.scala | 174 ++++++-
.../table/plan/batch/sql/DagOptimizationTest.xml | 7 +-
...ndowAggregateTest.xml => OverAggregateTest.xml} | 0
.../plan/batch/sql/agg/WindowAggregateTest.xml | 89 ++--
.../FlinkLogicalRankRuleForConstantRangeTest.xml | 18 +-
.../FlinkLogicalRankRuleForRangeEndTest.xml | 8 +-
...ndowAggregateTest.xml => OverAggregateTest.xml} | 0
...AggregateTest.scala => OverAggregateTest.scala} | 2 +-
.../metadata/FlinkRelMdColumnIntervalTest.scala | 52 +-
.../metadata/FlinkRelMdColumnUniquenessTest.scala | 147 ++++--
.../metadata/FlinkRelMdDistinctRowCountTest.scala | 97 +++-
.../plan/metadata/FlinkRelMdHandlerTestBase.scala | 531 +++++++++++++++++++--
.../metadata/FlinkRelMdPopulationSizeTest.scala | 48 +-
.../plan/metadata/FlinkRelMdRowCountTest.scala | 40 +-
.../plan/metadata/FlinkRelMdSelectivityTest.scala | 66 ++-
.../table/plan/metadata/FlinkRelMdSizeTest.scala | 24 +-
.../plan/metadata/FlinkRelMdUniqueGroupsTest.scala | 43 +-
.../plan/metadata/FlinkRelMdUniqueKeysTest.scala | 28 +-
.../table/plan/metadata/MetadataTestUtil.scala | 49 +-
...AggregateTest.scala => OverAggregateTest.scala} | 2 +-
44 files changed, 1896 insertions(+), 395 deletions(-)
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/AggCallSelectivityEstimator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/AggCallSelectivityEstimator.scala
index fc7c803..d4e87e2 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/AggCallSelectivityEstimator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/AggCallSelectivityEstimator.scala
@@ -19,7 +19,7 @@
package org.apache.flink.table.plan.metadata
import org.apache.flink.table.JDouble
-import org.apache.flink.table.plan.nodes.physical.batch.BatchExecGroupAggregateBase
+import org.apache.flink.table.plan.nodes.physical.batch.{BatchExecGroupAggregateBase, BatchExecLocalHashWindowAggregate, BatchExecLocalSortWindowAggregate, BatchExecWindowAggregateBase}
import org.apache.flink.table.plan.stats._
import org.apache.flink.table.plan.util.AggregateUtil
@@ -62,6 +62,14 @@ class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery)
(rel.getGroupSet.toArray ++ auxGroupSet, otherAggCalls)
case rel: BatchExecGroupAggregateBase =>
(rel.getGrouping ++ rel.getAuxGrouping, rel.getAggCallList)
+ case rel: BatchExecLocalHashWindowAggregate =>
+ val fullGrouping = rel.getGrouping ++ Array(rel.inputTimeFieldIndex) ++ rel.getAuxGrouping
+ (fullGrouping, rel.getAggCallList)
+ case rel: BatchExecLocalSortWindowAggregate =>
+ val fullGrouping = rel.getGrouping ++ Array(rel.inputTimeFieldIndex) ++ rel.getAuxGrouping
+ (fullGrouping, rel.getAggCallList)
+ case rel: BatchExecWindowAggregateBase =>
+ (rel.getGrouping ++ rel.getAuxGrouping, rel.getAggCallList)
case _ => throw new IllegalArgumentException(s"Cannot handle ${agg.getRelTypeName}!")
}
require(outputIdx >= fullGrouping.length)
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnInterval.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnInterval.scala
index cc55906..20d4610 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnInterval.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnInterval.scala
@@ -20,7 +20,7 @@ package org.apache.flink.table.plan.metadata
import org.apache.flink.table.api.TableException
import org.apache.flink.table.plan.metadata.FlinkMetadata.ColumnInterval
-import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank}
+import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank, WindowAggregate}
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.nodes.physical.stream._
import org.apache.flink.table.plan.schema.FlinkRelOptTable
@@ -51,12 +51,12 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
override def getDef: MetadataDef[ColumnInterval] = FlinkMetadata.ColumnInterval.DEF
/**
- * Gets interval of the given column in TableScan.
+ * Gets interval of the given column on TableScan.
*
* @param ts TableScan RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
- * @return interval of the given column in TableScan
+ * @return interval of the given column on TableScan
*/
def getColumnInterval(ts: TableScan, mq: RelMetadataQuery, index: Int): ValueInterval = {
val relOptTable = ts.getTable.asInstanceOf[FlinkRelOptTable]
@@ -79,12 +79,12 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
- * Gets interval of the given column in Values.
+ * Gets interval of the given column on Values.
*
* @param values Values RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
- * @return interval of the given column in Values
+ * @return interval of the given column on Values
*/
def getColumnInterval(values: Values, mq: RelMetadataQuery, index: Int): ValueInterval = {
val tuples = values.tuples
@@ -101,14 +101,14 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
- * Gets interval of the given column in Project.
+ * Gets interval of the given column on Project.
*
* Note: Only support the simple RexNode, e.g RexInputRef.
*
* @param project Project RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
- * @return interval of the given column in Project
+ * @return interval of the given column on Project
*/
def getColumnInterval(project: Project, mq: RelMetadataQuery, index: Int): ValueInterval = {
val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
@@ -130,12 +130,12 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
- * Gets interval of the given column in Filter.
+ * Gets interval of the given column on Filter.
*
* @param filter Filter RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
- * @return interval of the given column in Filter
+ * @return interval of the given column on Filter
*/
def getColumnInterval(filter: Filter, mq: RelMetadataQuery, index: Int): ValueInterval = {
val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
@@ -148,12 +148,12 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
- * Gets interval of the given column in batch Calc.
+ * Gets interval of the given column on Calc.
*
* @param calc Filter RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
- * @return interval of the given column in Filter
+ * @return interval of the given column on Calc
*/
def getColumnInterval(calc: Calc, mq: RelMetadataQuery, index: Int): ValueInterval = {
val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
@@ -249,12 +249,12 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
- * Gets interval of the given column in Exchange.
+ * Gets interval of the given column on Exchange.
*
* @param exchange Exchange RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
- * @return interval of the given column in Exchange
+ * @return interval of the given column on Exchange
*/
def getColumnInterval(exchange: Exchange, mq: RelMetadataQuery, index: Int): ValueInterval = {
val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
@@ -262,12 +262,12 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
- * Gets intervals of the given column of Sort.
+ * Gets interval of the given column on Sort.
*
- * @param sort Sort to analyze
+ * @param sort Sort RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
- * @return interval of the given column in Sort
+ * @return interval of the given column on Sort
*/
def getColumnInterval(sort: Sort, mq: RelMetadataQuery, index: Int): ValueInterval = {
val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
@@ -275,9 +275,9 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
- * Gets intervals of the given column of Expand.
+ * Gets interval of the given column of Expand.
*
- * @param expand expand to analyze
+ * @param expand expand RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column in batch sort
@@ -309,12 +309,12 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
- * Gets intervals of the given column of Rank.
+ * Gets interval of the given column on Rank.
*
* @param rank [[Rank]] instance to analyze
* @param mq RelMetadataQuery instance
* @param index the index of the given column
- * @return interval of the given column in batch Rank
+ * @return interval of the given column on Rank
*/
def getColumnInterval(
rank: Rank,
@@ -344,101 +344,106 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
- * Gets intervals of the given column in Aggregates.
+ * Gets interval of the given column on Aggregates.
*
* @param aggregate Aggregate RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
- * @return interval of the given column in Aggregate
+ * @return interval of the given column on Aggregate
*/
def getColumnInterval(aggregate: Aggregate, mq: RelMetadataQuery, index: Int): ValueInterval =
estimateColumnIntervalOfAggregate(aggregate, mq, index)
/**
- * Gets intervals of the given column in batch Aggregate.
+ * Gets interval of the given column on batch group aggregate.
*
- * @param aggregate Aggregate RelNode
+ * @param aggregate batch group aggregate RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
- * @return interval of the given column in batch Aggregate
+ * @return interval of the given column on batch group aggregate
*/
def getColumnInterval(
aggregate: BatchExecGroupAggregateBase,
mq: RelMetadataQuery,
index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index)
+ /**
+ * Gets interval of the given column on stream group aggregate.
+ *
+ * @param aggregate stream group aggregate RelNode
+ * @param mq RelMetadataQuery instance
+ * @param index the index of the given column
+ * @return interval of the given column on stream group Aggregate
+ */
def getColumnInterval(
aggregate: StreamExecGroupAggregate,
mq: RelMetadataQuery,
index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index)
+ /**
+ * Gets interval of the given column on stream local group aggregate.
+ *
+ * @param aggregate stream local group aggregate RelNode
+ * @param mq RelMetadataQuery instance
+ * @param index the index of the given column
+ * @return interval of the given column on stream local group Aggregate
+ */
def getColumnInterval(
aggregate: StreamExecLocalGroupAggregate,
mq: RelMetadataQuery,
index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index)
+ /**
+ * Gets interval of the given column on stream global group aggregate.
+ *
+ * @param aggregate stream global group aggregate RelNode
+ * @param mq RelMetadataQuery instance
+ * @param index the index of the given column
+ * @return interval of the given column on stream global group Aggregate
+ */
def getColumnInterval(
aggregate: StreamExecGlobalGroupAggregate,
mq: RelMetadataQuery,
index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index)
/**
- * Gets intervals of the given column in batch OverWindowAggregate.
+ * Gets interval of the given column on window aggregate.
*
- * @param aggregate Aggregate RelNode
- * @param mq RelMetadataQuery instance
- * @param index the index of the given column
- * @return interval of the given column in batch OverWindowAggregate
+ * @param agg window aggregate RelNode
+ * @param mq RelMetadataQuery instance
+ * @param index the index of the given column
+ * @return interval of the given column on window Aggregate
*/
def getColumnInterval(
- aggregate: BatchExecOverAggregate,
+ agg: WindowAggregate,
mq: RelMetadataQuery,
- index: Int): ValueInterval = getColumnIntervalOfOverWindow(aggregate, mq, index)
+ index: Int): ValueInterval = estimateColumnIntervalOfAggregate(agg, mq, index)
/**
- * Gets intervals of the given column in batch OverWindowAggregate.
+ * Gets interval of the given column on batch window aggregate.
*
- * @param aggregate Aggregate RelNode
- * @param mq RelMetadataQuery instance
- * @param index the index of the given column
- * @return interval of the given column in batch OverWindowAggregate
+ * @param agg batch window aggregate RelNode
+ * @param mq RelMetadataQuery instance
+ * @param index the index of the given column
+ * @return interval of the given column on batch window Aggregate
*/
def getColumnInterval(
- aggregate: StreamExecOverAggregate,
+ agg: BatchExecWindowAggregateBase,
mq: RelMetadataQuery,
- index: Int): ValueInterval = getColumnIntervalOfOverWindow(aggregate, mq, index)
+ index: Int): ValueInterval = estimateColumnIntervalOfAggregate(agg, mq, index)
/**
- * Gets intervals of the given column in calcite window.
+ * Gets interval of the given column on stream window aggregate.
*
- * @param window Window RelNode
- * @param mq RelMetadataQuery instance
- * @param index the index of the given column
- * @return interval of the given column in window
+ * @param agg stream window aggregate RelNode
+ * @param mq RelMetadataQuery instance
+ * @param index the index of the given column
+ * @return interval of the given column on stream window Aggregate
*/
def getColumnInterval(
- window: Window,
+ agg: StreamExecGroupWindowAggregate,
mq: RelMetadataQuery,
- index: Int): ValueInterval = {
- getColumnIntervalOfOverWindow(window, mq, index)
- }
-
- private def getColumnIntervalOfOverWindow(
- overWindow: SingleRel,
- mq: RelMetadataQuery,
- index: Int): ValueInterval = {
- val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
- val input = overWindow.getInput
- val fieldsCountOfInput = input.getRowType.getFieldCount
- if (index < fieldsCountOfInput) {
- fmq.getColumnInterval(input, index)
- } else {
- // cannot estimate aggregate function calls columnInterval.
- null
- }
- }
-
- // TODO supports window aggregate
+ index: Int): ValueInterval = estimateColumnIntervalOfAggregate(agg, mq, index)
private def estimateColumnIntervalOfAggregate(
aggregate: SingleRel,
@@ -451,8 +456,16 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
case agg: StreamExecLocalGroupAggregate => agg.grouping
case agg: StreamExecGlobalGroupAggregate => agg.grouping
case agg: StreamExecIncrementalGroupAggregate => agg.partialAggGrouping
+ case agg: StreamExecGroupWindowAggregate => agg.getGrouping
case agg: BatchExecGroupAggregateBase => agg.getGrouping ++ agg.getAuxGrouping
case agg: Aggregate => AggregateUtil.checkAndGetFullGroupSet(agg)
+ case agg: BatchExecLocalSortWindowAggregate =>
+ // grouping + assignTs + auxGrouping
+ agg.getGrouping ++ Array(agg.inputTimeFieldIndex) ++ agg.getAuxGrouping
+ case agg: BatchExecLocalHashWindowAggregate =>
+ // grouping + assignTs + auxGrouping
+ agg.getGrouping ++ Array(agg.inputTimeFieldIndex) ++ agg.getAuxGrouping
+ case agg: BatchExecWindowAggregateBase => agg.getGrouping ++ agg.getAuxGrouping
}
if (index < groupSet.length) {
@@ -513,6 +526,8 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
case agg: StreamExecIncrementalGroupAggregate
if agg.partialAggInfoList.getActualAggregateCalls.length > aggCallIndex =>
agg.partialAggInfoList.getActualAggregateCalls(aggCallIndex)
+ case agg: StreamExecGroupWindowAggregate if agg.aggCalls.length > aggCallIndex =>
+ agg.aggCalls(aggCallIndex)
case agg: BatchExecLocalHashAggregate =>
getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
case agg: BatchExecHashAggregate if agg.isMerge =>
@@ -542,6 +557,8 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
} else {
null
}
+ case agg: BatchExecWindowAggregateBase if agg.getAggCallList.length > aggCallIndex =>
+ agg.getAggCallList(aggCallIndex)
case _ => null
}
@@ -580,12 +597,68 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
- * Gets intervals of the given column in Join.
+ * Gets interval of the given column on calcite window.
+ *
+ * @param window Window RelNode
+ * @param mq RelMetadataQuery instance
+ * @param index the index of the given column
+ * @return interval of the given column on window
+ */
+ def getColumnInterval(
+ window: Window,
+ mq: RelMetadataQuery,
+ index: Int): ValueInterval = {
+ getColumnIntervalOfOverAgg(window, mq, index)
+ }
+
+ /**
+ * Gets interval of the given column on batch over aggregate.
+ *
+ * @param agg batch over aggregate RelNode
+ * @param mq RelMetadataQuery instance
+   * @param index the index of the given column
+ * @return interval of the given column on batch over aggregate.
+ */
+ def getColumnInterval(
+ agg: BatchExecOverAggregate,
+ mq: RelMetadataQuery,
+ index: Int): ValueInterval = getColumnIntervalOfOverAgg(agg, mq, index)
+
+ /**
+ * Gets interval of the given column on stream over aggregate.
+ *
+ * @param agg stream over aggregate RelNode
+ * @param mq RelMetadataQuery instance
+   * @param index the index of the given column
+ * @return interval of the given column on stream over aggregate.
+ */
+ def getColumnInterval(
+ agg: StreamExecOverAggregate,
+ mq: RelMetadataQuery,
+ index: Int): ValueInterval = getColumnIntervalOfOverAgg(agg, mq, index)
+
+ private def getColumnIntervalOfOverAgg(
+ overAgg: SingleRel,
+ mq: RelMetadataQuery,
+ index: Int): ValueInterval = {
+ val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
+ val input = overAgg.getInput
+ val fieldsCountOfInput = input.getRowType.getFieldCount
+ if (index < fieldsCountOfInput) {
+ fmq.getColumnInterval(input, index)
+ } else {
+ // cannot estimate aggregate function calls columnInterval.
+ null
+ }
+ }
+
+ /**
+ * Gets interval of the given column on Join.
*
* @param join Join RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
- * @return interval of the given column in Join
+ * @return interval of the given column on Join
*/
def getColumnInterval(join: Join, mq: RelMetadataQuery, index: Int): ValueInterval = {
val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
@@ -612,12 +685,12 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
- * Gets interval of the given column in Union.
+ * Gets interval of the given column on Union.
*
* @param union Union RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
- * @return interval of the given column in Union
+ * @return interval of the given column on Union
*/
def getColumnInterval(union: Union, mq: RelMetadataQuery, index: Int): ValueInterval = {
val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
@@ -628,7 +701,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
- * Gets intervals of the given column of RelSubset.
+ * Gets interval of the given column on RelSubset.
*
* @param subset RelSubset to analyze
* @param mq RelMetadataQuery instance
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnUniqueness.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnUniqueness.scala
index 1279a80..e6a2686 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnUniqueness.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnUniqueness.scala
@@ -20,8 +20,9 @@ package org.apache.flink.table.plan.metadata
import org.apache.flink.table.JBoolean
import org.apache.flink.table.api.TableException
+import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.plan.nodes.FlinkRelNode
-import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank}
+import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank, WindowAggregate}
import org.apache.flink.table.plan.nodes.logical._
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.nodes.physical.stream._
@@ -305,9 +306,7 @@ class FlinkRelMdColumnUniqueness private extends MetadataHandler[BuiltInMetadata
mq: RelMetadataQuery,
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = {
- // group by keys form a unique key
- val groupKey = ImmutableBitSet.range(rel.getGroupCount)
- columns.contains(groupKey)
+ areColumnsUniqueOnAggregate(rel.getGroupSet.toArray, mq, columns, ignoreNulls)
}
def areColumnsUnique(
@@ -316,9 +315,7 @@ class FlinkRelMdColumnUniqueness private extends MetadataHandler[BuiltInMetadata
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = {
if (rel.isFinal) {
- // group key of agg output always starts from 0
- val outputGroupKey = ImmutableBitSet.range(rel.getGrouping.length)
- columns.contains(outputGroupKey)
+ areColumnsUniqueOnAggregate(rel.getGrouping, mq, columns, ignoreNulls)
} else {
null
}
@@ -329,9 +326,7 @@ class FlinkRelMdColumnUniqueness private extends MetadataHandler[BuiltInMetadata
mq: RelMetadataQuery,
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = {
- // group key of agg output always starts from 0
- val outputGroupKey = ImmutableBitSet.range(rel.grouping.length)
- columns.contains(outputGroupKey)
+ areColumnsUniqueOnAggregate(rel.grouping, mq, columns, ignoreNulls)
}
def areColumnsUnique(
@@ -339,7 +334,7 @@ class FlinkRelMdColumnUniqueness private extends MetadataHandler[BuiltInMetadata
mq: RelMetadataQuery,
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = {
- columns.contains(ImmutableBitSet.of(rel.grouping.toArray: _*))
+ areColumnsUniqueOnAggregate(rel.grouping, mq, columns, ignoreNulls)
}
def areColumnsUnique(
@@ -348,32 +343,105 @@ class FlinkRelMdColumnUniqueness private extends MetadataHandler[BuiltInMetadata
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = null
- // TODO supports window aggregate
+ private def areColumnsUniqueOnAggregate(
+ grouping: Array[Int],
+ mq: RelMetadataQuery,
+ columns: ImmutableBitSet,
+ ignoreNulls: Boolean): JBoolean = {
+ // group key of agg output always starts from 0
+ val outputGroupKey = ImmutableBitSet.of(grouping.indices: _*)
+ columns.contains(outputGroupKey)
+ }
+
+ def areColumnsUnique(
+ rel: WindowAggregate,
+ mq: RelMetadataQuery,
+ columns: ImmutableBitSet,
+ ignoreNulls: Boolean): JBoolean = {
+ areColumnsUniqueOnWindowAggregate(
+ rel.getGroupSet.toArray,
+ rel.getNamedProperties,
+ rel.getRowType.getFieldCount,
+ mq,
+ columns,
+ ignoreNulls)
+ }
+
+ def areColumnsUnique(
+ rel: BatchExecWindowAggregateBase,
+ mq: RelMetadataQuery,
+ columns: ImmutableBitSet,
+ ignoreNulls: Boolean): JBoolean = {
+ if (rel.isFinal) {
+ areColumnsUniqueOnWindowAggregate(
+ rel.getGrouping,
+ rel.getNamedProperties,
+ rel.getRowType.getFieldCount,
+ mq,
+ columns,
+ ignoreNulls)
+ } else {
+ null
+ }
+ }
+
+ def areColumnsUnique(
+ rel: StreamExecGroupWindowAggregate,
+ mq: RelMetadataQuery,
+ columns: ImmutableBitSet,
+ ignoreNulls: Boolean): JBoolean = {
+ areColumnsUniqueOnWindowAggregate(
+ rel.getGrouping,
+ rel.getWindowProperties,
+ rel.getRowType.getFieldCount,
+ mq,
+ columns,
+ ignoreNulls)
+ }
+
+ private def areColumnsUniqueOnWindowAggregate(
+ grouping: Array[Int],
+ namedProperties: Seq[NamedWindowProperty],
+ outputFieldCount: Int,
+ mq: RelMetadataQuery,
+ columns: ImmutableBitSet,
+ ignoreNulls: Boolean): JBoolean = {
+ if (namedProperties.nonEmpty) {
+ val begin = outputFieldCount - namedProperties.size
+ val end = outputFieldCount - 1
+ val keys = ImmutableBitSet.of(grouping.indices: _*)
+ (begin to end).map {
+ i => keys.union(ImmutableBitSet.of(i))
+ }.exists(columns.contains)
+ } else {
+ false
+ }
+ }
def areColumnsUnique(
rel: Window,
mq: RelMetadataQuery,
columns: ImmutableBitSet,
- ignoreNulls: Boolean): JBoolean = areColumnsUniqueOfOverWindow(rel, mq, columns, ignoreNulls)
+ ignoreNulls: Boolean): JBoolean = areColumnsUniqueOfOverAgg(rel, mq, columns, ignoreNulls)
def areColumnsUnique(
rel: BatchExecOverAggregate,
mq: RelMetadataQuery,
columns: ImmutableBitSet,
- ignoreNulls: Boolean): JBoolean = areColumnsUniqueOfOverWindow(rel, mq, columns, ignoreNulls)
+ ignoreNulls: Boolean): JBoolean = areColumnsUniqueOfOverAgg(rel, mq, columns, ignoreNulls)
def areColumnsUnique(
rel: StreamExecOverAggregate,
mq: RelMetadataQuery,
columns: ImmutableBitSet,
- ignoreNulls: Boolean): JBoolean = areColumnsUniqueOfOverWindow(rel, mq, columns, ignoreNulls)
+ ignoreNulls: Boolean): JBoolean = areColumnsUniqueOfOverAgg(rel, mq, columns, ignoreNulls)
- private def areColumnsUniqueOfOverWindow(
- overWindow: SingleRel,
+ private def areColumnsUniqueOfOverAgg(
+ overAgg: SingleRel,
mq: RelMetadataQuery,
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = {
- val input = overWindow.getInput
+ val input = overAgg.getInput
val inputFieldLength = input.getRowType.getFieldCount
val columnsBelongsToInput = ImmutableBitSet.of(columns.filter(_ < inputFieldLength).toList)
val isSubColumnsUnique = mq.areColumnsUnique(
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdDistinctRowCount.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdDistinctRowCount.scala
index 5b33318..d062b03 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdDistinctRowCount.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdDistinctRowCount.scala
@@ -19,7 +19,7 @@
package org.apache.flink.table.plan.metadata
import org.apache.flink.table.api.{PlannerConfigOptions, TableException}
-import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank}
+import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank, WindowAggregate}
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.schema.FlinkRelOptTable
import org.apache.flink.table.plan.util.{FlinkRelMdUtil, FlinkRelOptUtil, FlinkRexUtil, RankUtil}
@@ -399,26 +399,86 @@ class FlinkRelMdDistinctRowCount private extends MetadataHandler[BuiltInMetadata
FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate)
case rel: BatchExecGroupAggregateBase =>
FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate)
+ case rel: BatchExecWindowAggregateBase =>
+ FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate)
}
- // TODO supports window aggregate
+ def getDistinctRowCount(
+ rel: WindowAggregate,
+ mq: RelMetadataQuery,
+ groupKey: ImmutableBitSet,
+ predicate: RexNode): JDouble = {
+ val newPredicate = FlinkRelMdUtil.makeNamePropertiesSelectivityRexNode(rel, predicate)
+ if (newPredicate == null || newPredicate.isAlwaysTrue) {
+ if (groupKey.isEmpty) {
+ return 1D
+ }
+ }
+ val fieldCnt = rel.getRowType.getFieldCount
+ val namedPropertiesCnt = rel.getNamedProperties.size
+ val namedWindowStartIndex = fieldCnt - namedPropertiesCnt
+ val groupKeyFromNamedWindow = groupKey.toList.exists(_ >= namedWindowStartIndex)
+ if (groupKeyFromNamedWindow) {
+ // cannot estimate DistinctRowCount result when some group keys are from named windows
+ null
+ } else {
+ getDistinctRowCountOfAggregate(rel, mq, groupKey, newPredicate)
+ }
+ }
+
+ def getDistinctRowCount(
+ rel: BatchExecWindowAggregateBase,
+ mq: RelMetadataQuery,
+ groupKey: ImmutableBitSet,
+ predicate: RexNode): JDouble = {
+ if (predicate == null || predicate.isAlwaysTrue) {
+ if (groupKey.isEmpty) {
+ return 1D
+ }
+ }
+
+ val newPredicate = if (rel.isFinal) {
+ val namedWindowStartIndex = rel.getRowType.getFieldCount - rel.getNamedProperties.size
+ val groupKeyFromNamedWindow = groupKey.toList.exists(_ >= namedWindowStartIndex)
+ if (groupKeyFromNamedWindow) {
+ // cannot estimate DistinctRowCount result when some group keys are from named windows
+ return null
+ }
+ val newPredicate = FlinkRelMdUtil.makeNamePropertiesSelectivityRexNode(rel, predicate)
+ if (rel.isMerge) {
+ // set the bits as they correspond to local window aggregate
+ val localWinAggGroupKey = FlinkRelMdUtil.setChildKeysOfWinAgg(groupKey, rel)
+ val childPredicate = FlinkRelMdUtil.setChildPredicateOfWinAgg(newPredicate, rel)
+ return mq.getDistinctRowCount(rel.getInput, localWinAggGroupKey, childPredicate)
+ } else {
+ newPredicate
+ }
+ } else {
+ // local window aggregate
+ val assignTsFieldIndex = rel.getGrouping.length
+ if (groupKey.toList.contains(assignTsFieldIndex)) {
+ // groupKey contains `assignTs` fields
+ return null
+ }
+ predicate
+ }
+ getDistinctRowCountOfAggregate(rel, mq, groupKey, newPredicate)
+ }
def getDistinctRowCount(
rel: Window,
mq: RelMetadataQuery,
groupKey: ImmutableBitSet,
- predicate: RexNode): JDouble =
- getDistinctRowCountOfOverWindow(rel, mq, groupKey, predicate)
+ predicate: RexNode): JDouble = getDistinctRowCountOfOverAgg(rel, mq, groupKey, predicate)
def getDistinctRowCount(
rel: BatchExecOverAggregate,
mq: RelMetadataQuery,
groupKey: ImmutableBitSet,
- predicate: RexNode): JDouble =
- getDistinctRowCountOfOverWindow(rel, mq, groupKey, predicate)
+ predicate: RexNode): JDouble = getDistinctRowCountOfOverAgg(rel, mq, groupKey, predicate)
- private def getDistinctRowCountOfOverWindow(
- overWindow: SingleRel,
+ private def getDistinctRowCountOfOverAgg(
+ overAgg: SingleRel,
mq: RelMetadataQuery,
groupKey: ImmutableBitSet,
predicate: RexNode): JDouble = {
@@ -427,10 +487,10 @@ class FlinkRelMdDistinctRowCount private extends MetadataHandler[BuiltInMetadata
return 1D
}
}
- val input = overWindow.getInput
+ val input = overAgg.getInput
val fieldsCountOfInput = input.getRowType.getFieldCount
val groupKeyContainsAggCall = groupKey.toList.exists(_ >= fieldsCountOfInput)
- // cannot estimate ndv of aggCall result of OverWindowAgg
+ // cannot estimate ndv of aggCall result of OverAgg
if (groupKeyContainsAggCall) {
null
} else {
@@ -441,7 +501,7 @@ class FlinkRelMdDistinctRowCount private extends MetadataHandler[BuiltInMetadata
predicate,
pushable,
notPushable)
- val rexBuilder = overWindow.getCluster.getRexBuilder
+ val rexBuilder = overAgg.getCluster.getRexBuilder
val childPreds = RexUtil.composeConjunction(rexBuilder, pushable, true)
val distinctRowCount = mq.getDistinctRowCount(input, groupKey, childPreds)
if (distinctRowCount == null) {
@@ -450,7 +510,7 @@ class FlinkRelMdDistinctRowCount private extends MetadataHandler[BuiltInMetadata
distinctRowCount
} else {
val preds = RexUtil.composeConjunction(rexBuilder, notPushable, true)
- val rowCount = mq.getRowCount(overWindow)
+ val rowCount = mq.getRowCount(overAgg)
FlinkRelMdUtil.adaptNdvBasedOnSelectivity(rowCount, distinctRowCount,
RelMdUtil.guessSelectivity(preds))
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdFilteredColumnInterval.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdFilteredColumnInterval.scala
index 4269104..04dda46 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdFilteredColumnInterval.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdFilteredColumnInterval.scala
@@ -19,7 +19,7 @@ package org.apache.flink.table.plan.metadata
import org.apache.flink.table.plan.metadata.FlinkMetadata.FilteredColumnInterval
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecGroupAggregateBase
-import org.apache.flink.table.plan.nodes.physical.stream.{StreamExecGlobalGroupAggregate, StreamExecGroupAggregate, StreamExecLocalGroupAggregate}
+import org.apache.flink.table.plan.nodes.physical.stream.{StreamExecGlobalGroupAggregate, StreamExecGroupAggregate, StreamExecGroupWindowAggregate, StreamExecLocalGroupAggregate}
import org.apache.flink.table.plan.stats.ValueInterval
import org.apache.flink.table.plan.util.ColumnIntervalUtil
import org.apache.flink.util.Preconditions.checkArgument
@@ -198,7 +198,13 @@ class FlinkRelMdFilteredColumnInterval private extends MetadataHandler[FilteredC
estimateFilteredColumnIntervalOfAggregate(aggregate, mq, columnIndex, filterArg)
}
- // TODO support window aggregate
+ def getColumnInterval(
+ aggregate: StreamExecGroupWindowAggregate,
+ mq: RelMetadataQuery,
+ columnIndex: Int,
+ filterArg: Int): ValueInterval = {
+ estimateFilteredColumnIntervalOfAggregate(aggregate, mq, columnIndex, filterArg)
+ }
def estimateFilteredColumnIntervalOfAggregate(
rel: RelNode,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdModifiedMonotonicity.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdModifiedMonotonicity.scala
index d82f4f9..914768e 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdModifiedMonotonicity.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdModifiedMonotonicity.scala
@@ -272,7 +272,7 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon
}
def getRelModifiedMonotonicity(
- rel: FlinkLogicalOverWindow,
+ rel: FlinkLogicalOverAggregate,
mq: RelMetadataQuery): RelModifiedMonotonicity = constants(rel.getRowType.getFieldCount)
def getRelModifiedMonotonicity(
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdPopulationSize.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdPopulationSize.scala
index 389d36a..1d5cb00 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdPopulationSize.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdPopulationSize.scala
@@ -19,7 +19,7 @@
package org.apache.flink.table.plan.metadata
import org.apache.flink.table.api.TableException
-import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank}
+import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank, WindowAggregate}
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.util.{FlinkRelMdUtil, RankUtil}
import org.apache.flink.table.{JArrayList, JDouble}
@@ -272,26 +272,67 @@ class FlinkRelMdPopulationSize private extends MetadataHandler[BuiltInMetadata.P
NumberUtil.min(popSizeOfColsInGroupKeys * popSizeOfColsInAggCalls, inputRowCnt)
}
- // TODO supports window aggregate
+ def getPopulationSize(
+ rel: WindowAggregate,
+ mq: RelMetadataQuery,
+ groupKey: ImmutableBitSet): JDouble = {
+ val fieldCnt = rel.getRowType.getFieldCount
+ val namedPropertiesCnt = rel.getNamedProperties.size
+ val namedWindowStartIndex = fieldCnt - namedPropertiesCnt
+ val groupKeyFromNamedWindow = groupKey.toList.exists(_ >= namedWindowStartIndex)
+ if (groupKeyFromNamedWindow) {
+ // cannot estimate PopulationSize result when some group keys are from named windows
+ null
+ } else {
+ // regular aggregate
+ getPopulationSize(rel.asInstanceOf[Aggregate], mq, groupKey)
+ }
+ }
+
+ def getPopulationSize(
+ rel: BatchExecWindowAggregateBase,
+ mq: RelMetadataQuery,
+ groupKey: ImmutableBitSet): JDouble = {
+ if (rel.isFinal) {
+ val namedWindowStartIndex = rel.getRowType.getFieldCount - rel.getNamedProperties.size
+ val groupKeyFromNamedWindow = groupKey.toList.exists(_ >= namedWindowStartIndex)
+ if (groupKeyFromNamedWindow) {
+ return null
+ }
+ if (rel.isMerge) {
+ // set the bits as they correspond to local window aggregate
+ val localWinAggGroupKey = FlinkRelMdUtil.setChildKeysOfWinAgg(groupKey, rel)
+ return mq.getPopulationSize(rel.getInput, localWinAggGroupKey)
+ }
+ } else {
+ // local window aggregate
+ val assignTsFieldIndex = rel.getGrouping.length
+ if (groupKey.toList.contains(assignTsFieldIndex)) {
+ // groupKey contains `assignTs` fields
+ return null
+ }
+ }
+ getPopulationSizeOfAggregate(rel, mq, groupKey)
+ }
def getPopulationSize(
window: Window,
mq: RelMetadataQuery,
- groupKey: ImmutableBitSet): JDouble = getPopulationSizeOfOverWindow(window, mq, groupKey)
+ groupKey: ImmutableBitSet): JDouble = getPopulationSizeOfOverAgg(window, mq, groupKey)
def getPopulationSize(
rel: BatchExecOverAggregate,
mq: RelMetadataQuery,
- groupKey: ImmutableBitSet): JDouble = getPopulationSizeOfOverWindow(rel, mq, groupKey)
+ groupKey: ImmutableBitSet): JDouble = getPopulationSizeOfOverAgg(rel, mq, groupKey)
- private def getPopulationSizeOfOverWindow(
- overWindow: SingleRel,
+ private def getPopulationSizeOfOverAgg(
+ overAgg: SingleRel,
mq: RelMetadataQuery,
groupKey: ImmutableBitSet): JDouble = {
- val input = overWindow.getInput
+ val input = overAgg.getInput
val fieldsCountOfInput = input.getRowType.getFieldCount
val groupKeyContainsAggCall = groupKey.toList.exists(_ >= fieldsCountOfInput)
- // cannot estimate population size of aggCall result of OverWindowAgg
+ // cannot estimate population size of aggCall result of OverAgg
if (groupKeyContainsAggCall) {
null
} else {
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdRowCount.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdRowCount.scala
index 77f967f..0fb5f54 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdRowCount.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdRowCount.scala
@@ -20,10 +20,12 @@ package org.apache.flink.table.plan.metadata
import org.apache.flink.table.JDouble
import org.apache.flink.table.calcite.FlinkContext
-import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank}
+import org.apache.flink.table.plan.logical.{LogicalWindow, SlidingGroupWindow, TumblingGroupWindow}
+import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank, WindowAggregate}
import org.apache.flink.table.plan.nodes.exec.NodeResourceConfig
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.stats.ValueInterval
+import org.apache.flink.table.plan.util.AggregateUtil.{extractTimeIntervalValue, isTimeIntervalType}
import org.apache.flink.table.plan.util.{FlinkRelMdUtil, SortUtil}
import org.apache.calcite.adapter.enumerable.EnumerableLimit
@@ -142,6 +144,8 @@ class FlinkRelMdRowCount private extends MetadataHandler[BuiltInMetadata.RowCoun
val (grouping, isFinal, isMerge) = rel match {
case agg: BatchExecGroupAggregateBase =>
(ImmutableBitSet.of(agg.getGrouping: _*), agg.isFinal, agg.isMerge)
+ case windowAgg: BatchExecWindowAggregateBase =>
+ (ImmutableBitSet.of(windowAgg.getGrouping: _*), windowAgg.isFinal, windowAgg.isMerge)
case _ => throw new IllegalArgumentException(s"Unknown aggregate type ${rel.getRelTypeName}!")
}
val ndvOfGroupKeysOnGlobalAgg: JDouble = if (grouping.isEmpty) {
@@ -185,16 +189,56 @@ class FlinkRelMdRowCount private extends MetadataHandler[BuiltInMetadata.RowCoun
}
}
- // TODO supports window aggregate
+ def getRowCount(rel: WindowAggregate, mq: RelMetadataQuery): JDouble = {
+ val (ndvOfGroupKeys, inputRowCount) = getRowCountOfAgg(rel, rel.getGroupSet, 1, mq)
+ estimateRowCountOfWindowAgg(ndvOfGroupKeys, inputRowCount, rel.getWindow)
+ }
+
+ def getRowCount(rel: BatchExecWindowAggregateBase, mq: RelMetadataQuery): JDouble = {
+ val ndvOfGroupKeys = getRowCountOfBatchExecAgg(rel, mq)
+ val inputRowCount = mq.getRowCount(rel.getInput)
+ estimateRowCountOfWindowAgg(ndvOfGroupKeys, inputRowCount, rel.getWindow)
+ }
+
+ private def estimateRowCountOfWindowAgg(
+ ndv: JDouble,
+ inputRowCount: JDouble,
+ window: LogicalWindow): JDouble = {
+ if (ndv == null) {
+ null
+ } else {
+ // simply assume the expand factor of TumblingWindow/SessionWindow/SlideWindowWithoutOverlap is 2,
+ // and that of SlideWindowWithOverlap is 4.
+ // Introduce an expand factor here to distinguish the output rowCount of a normal agg from that of
+ // all kinds of window aggregates.
+ val expandFactorOfTumblingWindow = 2D
+ val expandFactorOfNoOverLapSlidingWindow = 2D
+ val expandFactorOfOverLapSlidingWindow = 4D
+ val expandFactorOfSessionWindow = 2D
+ window match {
+ case TumblingGroupWindow(_, _, size) if isTimeIntervalType(size.getType) =>
+ Math.min(expandFactorOfTumblingWindow * ndv, inputRowCount)
+ case SlidingGroupWindow(_, _, size, slide) if isTimeIntervalType(size.getType) =>
+ val sizeValue = extractTimeIntervalValue(size)
+ val slideValue = extractTimeIntervalValue(slide)
+ if (sizeValue > slideValue) {
+ // only a slideWindow which has overlap may generate more records than its input
+ expandFactorOfOverLapSlidingWindow * ndv
+ } else {
+ Math.min(expandFactorOfNoOverLapSlidingWindow * ndv, inputRowCount)
+ }
+ case _ => Math.min(expandFactorOfSessionWindow * ndv, inputRowCount)
+ }
+ }
+ }
- def getRowCount(rel: Window, mq: RelMetadataQuery): JDouble =
- getRowCountOfOverWindow(rel, mq)
+ def getRowCount(rel: Window, mq: RelMetadataQuery): JDouble = getRowCountOfOverAgg(rel, mq)
def getRowCount(rel: BatchExecOverAggregate, mq: RelMetadataQuery): JDouble =
- getRowCountOfOverWindow(rel, mq)
+ getRowCountOfOverAgg(rel, mq)
- private def getRowCountOfOverWindow(overWindow: SingleRel, mq: RelMetadataQuery): JDouble =
- mq.getRowCount(overWindow.getInput)
+ private def getRowCountOfOverAgg(overAgg: SingleRel, mq: RelMetadataQuery): JDouble =
+ mq.getRowCount(overAgg.getInput)
def getRowCount(join: Join, mq: RelMetadataQuery): JDouble = {
join.getJoinType match {
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSelectivity.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSelectivity.scala
index 2c63e06..1c0b7f5 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSelectivity.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSelectivity.scala
@@ -18,7 +18,7 @@
package org.apache.flink.table.plan.metadata
-import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank}
+import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank, WindowAggregate}
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.util.FlinkRelMdUtil
import org.apache.flink.table.{JArrayList, JDouble}
@@ -101,6 +101,26 @@ class FlinkRelMdSelectivity private extends MetadataHandler[BuiltInMetadata.Sele
mq: RelMetadataQuery,
predicate: RexNode): JDouble = getSelectivityOfAgg(rel, mq, predicate)
+ def getSelectivity(
+ rel: WindowAggregate,
+ mq: RelMetadataQuery,
+ predicate: RexNode): JDouble = {
+ val newPredicate = FlinkRelMdUtil.makeNamePropertiesSelectivityRexNode(rel, predicate)
+ getSelectivityOfAgg(rel, mq, newPredicate)
+ }
+
+ def getSelectivity(
+ rel: BatchExecWindowAggregateBase,
+ mq: RelMetadataQuery,
+ predicate: RexNode): JDouble = {
+ val newPredicate = if (rel.isFinal) {
+ FlinkRelMdUtil.makeNamePropertiesSelectivityRexNode(rel, predicate)
+ } else {
+ predicate
+ }
+ getSelectivityOfAgg(rel, mq, newPredicate)
+ }
+
private def getSelectivityOfAgg(
agg: SingleRel,
mq: RelMetadataQuery,
@@ -111,10 +131,17 @@ class FlinkRelMdSelectivity private extends MetadataHandler[BuiltInMetadata.Sele
val hasLocalAgg = agg match {
case _: Aggregate => false
case rel: BatchExecGroupAggregateBase => rel.isFinal && rel.isMerge
+ case rel: BatchExecWindowAggregateBase => rel.isFinal && rel.isMerge
case _ => throw new IllegalArgumentException(s"Cannot handle ${agg.getRelTypeName}!")
}
if (hasLocalAgg) {
- return mq.getSelectivity(agg.getInput, predicate)
+ val childPredicate = agg match {
+ case rel: BatchExecWindowAggregateBase =>
+ // set the predicate as they correspond to local window aggregate
+ FlinkRelMdUtil.setChildPredicateOfWinAgg(predicate, rel)
+ case _ => predicate
+ }
+ return mq.getSelectivity(agg.getInput, childPredicate)
}
val (childPred, restPred) = agg match {
@@ -122,6 +149,8 @@ class FlinkRelMdSelectivity private extends MetadataHandler[BuiltInMetadata.Sele
FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate)
case rel: BatchExecGroupAggregateBase =>
FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate)
+ case rel: BatchExecWindowAggregateBase =>
+ FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate)
case _ => throw new IllegalArgumentException(s"Cannot handle ${agg.getRelTypeName}!")
}
val childSelectivity = mq.getSelectivity(agg.getInput(), childPred.orNull)
@@ -139,19 +168,17 @@ class FlinkRelMdSelectivity private extends MetadataHandler[BuiltInMetadata.Sele
}
}
- // TODO supports window aggregate
-
def getSelectivity(
overWindow: Window,
mq: RelMetadataQuery,
- predicate: RexNode): JDouble = getSelectivityOfOverWindowAgg(overWindow, mq, predicate)
+ predicate: RexNode): JDouble = getSelectivityOfOverAgg(overWindow, mq, predicate)
def getSelectivity(
rel: BatchExecOverAggregate,
mq: RelMetadataQuery,
- predicate: RexNode): JDouble = getSelectivityOfOverWindowAgg(rel, mq, predicate)
+ predicate: RexNode): JDouble = getSelectivityOfOverAgg(rel, mq, predicate)
- private def getSelectivityOfOverWindowAgg(
+ private def getSelectivityOfOverAgg(
over: SingleRel,
mq: RelMetadataQuery,
predicate: RexNode): JDouble = {
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSize.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSize.scala
index c52c947..5cd6ff3 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSize.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSize.scala
@@ -231,16 +231,16 @@ class FlinkRelMdSize private extends MetadataHandler[BuiltInMetadata.Size] {
}
def averageColumnSizes(overWindow: Window, mq: RelMetadataQuery): JList[JDouble] =
- averageColumnSizesOfOverWindow(overWindow, mq)
+ averageColumnSizesOfOverAgg(overWindow, mq)
def averageColumnSizes(rel: BatchExecOverAggregate, mq: RelMetadataQuery): JList[JDouble] =
- averageColumnSizesOfOverWindow(rel, mq)
+ averageColumnSizesOfOverAgg(rel, mq)
- private def averageColumnSizesOfOverWindow(
- overWindow: SingleRel,
+ private def averageColumnSizesOfOverAgg(
+ overAgg: SingleRel,
mq: RelMetadataQuery): JList[JDouble] = {
- val inputFieldCount = overWindow.getInput.getRowType.getFieldCount
- getColumnSizesFromInputOrType(overWindow, mq, (0 until inputFieldCount).zipWithIndex.toMap)
+ val inputFieldCount = overAgg.getInput.getRowType.getFieldCount
+ getColumnSizesFromInputOrType(overAgg, mq, (0 until inputFieldCount).zipWithIndex.toMap)
}
def averageColumnSizes(rel: Join, mq: RelMetadataQuery): JList[JDouble] = {
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdUniqueGroups.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdUniqueGroups.scala
index b559fe0..73a3adb 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdUniqueGroups.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdUniqueGroups.scala
@@ -18,15 +18,16 @@
package org.apache.flink.table.plan.metadata
+import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.plan.metadata.FlinkMetadata.UniqueGroups
-import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank}
+import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank, WindowAggregate}
import org.apache.flink.table.plan.nodes.physical.batch._
-import org.apache.flink.table.plan.util.{FlinkRelMdUtil, RankUtil}
+import org.apache.flink.table.plan.util.{AggregateUtil, FlinkRelMdUtil, RankUtil}
import org.apache.calcite.plan.volcano.RelSubset
-import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.core._
import org.apache.calcite.rel.metadata._
+import org.apache.calcite.rel.{RelNode, SingleRel}
import org.apache.calcite.rex._
import org.apache.calcite.sql.SqlKind
import org.apache.calcite.util.{Bug, ImmutableBitSet, Util}
@@ -252,7 +253,62 @@ class FlinkRelMdUniqueGroups private extends MetadataHandler[UniqueGroups] {
}
}
- // TODO support window aggregate
+ def getUniqueGroups(
+ agg: WindowAggregate,
+ mq: RelMetadataQuery,
+ columns: ImmutableBitSet): ImmutableBitSet = {
+ val grouping = agg.getGroupSet.map(_.toInt).toArray
+ val namedProperties = agg.getNamedProperties
+ val (auxGroupSet, _) = AggregateUtil.checkAndSplitAggCalls(agg)
+ getUniqueGroupsOfWindowAgg(agg, grouping, auxGroupSet, namedProperties, mq, columns)
+ }
+
+ def getUniqueGroups(
+ agg: BatchExecWindowAggregateBase,
+ mq: RelMetadataQuery,
+ columns: ImmutableBitSet): ImmutableBitSet = {
+ val grouping = agg.getGrouping
+ val namedProperties = agg.getNamedProperties
+ getUniqueGroupsOfWindowAgg(agg, grouping, agg.getAuxGrouping, namedProperties, mq, columns)
+ }
+
+ private def getUniqueGroupsOfWindowAgg(
+ windowAgg: SingleRel,
+ grouping: Array[Int],
+ auxGrouping: Array[Int],
+ namedProperties: Seq[NamedWindowProperty],
+ mq: RelMetadataQuery,
+ columns: ImmutableBitSet): ImmutableBitSet = {
+ val fieldCount = windowAgg.getRowType.getFieldCount
+ val columnList = columns.toList
+ val groupingInToOutMap = new mutable.HashMap[Integer, Integer]()
+ columnList.foreach { column =>
+ require(column < fieldCount)
+ if (column < grouping.length) {
+ groupingInToOutMap.put(grouping(column), column)
+ }
+ }
+ if (groupingInToOutMap.isEmpty) {
+ columns
+ } else {
+ val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
+ val inputColumns = ImmutableBitSet.of(groupingInToOutMap.keys.toList)
+ val inputUniqueGroups = fmq.getUniqueGroups(windowAgg.getInput, inputColumns)
+ val uniqueGroupsFromGrouping = inputUniqueGroups.asList.map { i =>
+ groupingInToOutMap.getOrElse(i, throw new IllegalArgumentException(s"Illegal index: $i"))
+ }
+ val fullGroupingOutputIndices =
+ grouping.indices ++ auxGrouping.indices.map(_ + grouping.length)
+ if (columns.equals(ImmutableBitSet.of(fullGroupingOutputIndices: _*))) {
+ return ImmutableBitSet.of(uniqueGroupsFromGrouping)
+ }
+
+ val groupingOutCols = groupingInToOutMap.values
+ // TODO drop some nonGroupingCols base on FlinkRelMdColumnUniqueness#areColumnsUnique(window)
+ val nonGroupingCols = columnList.filterNot(groupingOutCols.contains)
+ ImmutableBitSet.of(uniqueGroupsFromGrouping).union(ImmutableBitSet.of(nonGroupingCols))
+ }
+ }
def getUniqueGroups(
over: Window,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdUniqueKeys.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdUniqueKeys.scala
index f1c9060..956c45d 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdUniqueKeys.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdUniqueKeys.scala
@@ -18,7 +18,8 @@
package org.apache.flink.table.plan.metadata
-import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank}
+import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
+import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank, WindowAggregate}
import org.apache.flink.table.plan.nodes.logical._
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.nodes.physical.stream._
@@ -32,7 +33,7 @@ import com.google.common.collect.ImmutableSet
import org.apache.calcite.plan.RelOptTable
import org.apache.calcite.plan.volcano.RelSubset
import org.apache.calcite.rel.`type`.RelDataType
-import org.apache.calcite.rel.core.{JoinRelType, _}
+import org.apache.calcite.rel.core._
import org.apache.calcite.rel.metadata._
import org.apache.calcite.rel.{RelNode, SingleRel}
import org.apache.calcite.rex.{RexCall, RexInputRef, RexNode}
@@ -262,8 +263,7 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu
rel: Aggregate,
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
- // group by keys form a unique key
- ImmutableSet.of(ImmutableBitSet.range(rel.getGroupCount))
+ getUniqueKeysOnAggregate(rel.getGroupSet.toArray, mq, ignoreNulls)
}
def getUniqueKeys(
@@ -271,8 +271,7 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
if (rel.isFinal) {
- // group by keys form a unique key
- ImmutableSet.of(ImmutableBitSet.of(rel.getGrouping.indices: _*))
+ getUniqueKeysOnAggregate(rel.getGrouping, mq, ignoreNulls)
} else {
null
}
@@ -282,8 +281,7 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu
rel: StreamExecGroupAggregate,
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
- // group by keys form a unique key
- toImmutableSet(rel.grouping.indices.toArray)
+ getUniqueKeysOnAggregate(rel.grouping, mq, ignoreNulls)
}
def getUniqueKeys(
@@ -291,44 +289,98 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = null
-
def getUniqueKeys(
rel: StreamExecGlobalGroupAggregate,
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
- ImmutableSet.of(ImmutableBitSet.of(rel.grouping.indices.toArray: _*))
+ getUniqueKeysOnAggregate(rel.grouping, mq, ignoreNulls)
+ }
+
+ def getUniqueKeysOnAggregate(
+ grouping: Array[Int],
+ mq: RelMetadataQuery,
+ ignoreNulls: Boolean): util.Set[ImmutableBitSet] = {
+ // group by keys form a unique key
+ ImmutableSet.of(ImmutableBitSet.of(grouping.indices: _*))
}
def getUniqueKeys(
- rel: StreamExecWindowJoin,
+ rel: WindowAggregate,
mq: RelMetadataQuery,
- ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
- val joinInfo = JoinInfo.of(rel.getLeft, rel.getRight, rel.joinCondition)
- getJoinUniqueKeys(joinInfo, rel.joinType, rel.getLeft, rel.getRight, mq, ignoreNulls)
+ ignoreNulls: Boolean): util.Set[ImmutableBitSet] = {
+ getUniqueKeysOnWindowAgg(
+ rel.getRowType.getFieldCount,
+ rel.getNamedProperties,
+ rel.getGroupSet.toArray,
+ mq,
+ ignoreNulls)
+ }
+
+ def getUniqueKeys(
+ rel: BatchExecWindowAggregateBase,
+ mq: RelMetadataQuery,
+ ignoreNulls: Boolean): util.Set[ImmutableBitSet] = {
+ if (rel.isFinal) {
+ getUniqueKeysOnWindowAgg(
+ rel.getRowType.getFieldCount,
+ rel.getNamedProperties,
+ rel.getGrouping,
+ mq,
+ ignoreNulls)
+ } else {
+ null
+ }
+ }
+
+ def getUniqueKeys(
+ rel: StreamExecGroupWindowAggregate,
+ mq: RelMetadataQuery,
+ ignoreNulls: Boolean): util.Set[ImmutableBitSet] = {
+ getUniqueKeysOnWindowAgg(
+ rel.getRowType.getFieldCount, rel.getWindowProperties, rel.getGrouping, mq, ignoreNulls)
+ }
+
+ private def getUniqueKeysOnWindowAgg(
+ fieldCount: Int,
+ namedProperties: Seq[NamedWindowProperty],
+ grouping: Array[Int],
+ mq: RelMetadataQuery,
+ ignoreNulls: Boolean): util.Set[ImmutableBitSet] = {
+ if (namedProperties.nonEmpty) {
+ val begin = fieldCount - namedProperties.size
+ val end = fieldCount - 1
+ // namedProperties' indices are at the end of the output record
+ val keys = ImmutableBitSet.of(grouping.indices: _*)
+ (begin to end).map {
+ i => keys.union(ImmutableBitSet.of(i))
+ }.toSet[ImmutableBitSet]
+ } else {
+ null
+ }
}
def getUniqueKeys(
rel: Window,
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
- getUniqueKeysOfOverWindow(rel, mq, ignoreNulls)
+ getUniqueKeysOfOverAgg(rel, mq, ignoreNulls)
}
def getUniqueKeys(
rel: BatchExecOverAggregate,
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
- getUniqueKeysOfOverWindow(rel, mq, ignoreNulls)
+ getUniqueKeysOfOverAgg(rel, mq, ignoreNulls)
}
def getUniqueKeys(
rel: StreamExecOverAggregate,
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
- getUniqueKeysOfOverWindow(rel, mq, ignoreNulls)
+ getUniqueKeysOfOverAgg(rel, mq, ignoreNulls)
}
- private def getUniqueKeysOfOverWindow(
+ private def getUniqueKeysOfOverAgg(
window: SingleRel,
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
@@ -350,6 +402,14 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu
}
}
+ def getUniqueKeys(
+ rel: StreamExecWindowJoin,
+ mq: RelMetadataQuery,
+ ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
+ val joinInfo = JoinInfo.of(rel.getLeft, rel.getRight, rel.joinCondition)
+ getJoinUniqueKeys(joinInfo, rel.joinType, rel.getLeft, rel.getRight, mq, ignoreNulls)
+ }
+
private def getJoinUniqueKeys(
joinInfo: JoinInfo,
joinRelType: JoinRelType,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalOverWindow.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalOverAggregate.scala
similarity index 91%
rename from flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalOverWindow.scala
rename to flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalOverAggregate.scala
index c021c7a..d43c829 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalOverWindow.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalOverAggregate.scala
@@ -19,16 +19,15 @@
package org.apache.flink.table.plan.nodes.logical
import org.apache.flink.table.api.ValidationException
-import org.apache.flink.table.plan.metadata.FlinkRelMetadataQuery
import org.apache.flink.table.plan.nodes.FlinkConventions
import org.apache.calcite.plan._
-import org.apache.calcite.rel.{RelCollation, RelCollationTraitDef, RelNode}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.convert.ConverterRule
import org.apache.calcite.rel.core.Window
import org.apache.calcite.rel.logical.LogicalWindow
import org.apache.calcite.rel.metadata.RelMdCollation
+import org.apache.calcite.rel.{RelCollation, RelCollationTraitDef, RelNode}
import org.apache.calcite.rex.RexLiteral
import org.apache.calcite.sql.SqlRankFunction
@@ -42,7 +41,7 @@ import scala.collection.JavaConversions._
* Sub-class of [[Window]] that is a relational expression
* which represents a set of over window aggregates in Flink.
*/
-class FlinkLogicalOverWindow(
+class FlinkLogicalOverAggregate(
cluster: RelOptCluster,
traitSet: RelTraitSet,
input: RelNode,
@@ -53,7 +52,7 @@ class FlinkLogicalOverWindow(
with FlinkLogicalRel {
override def copy(traitSet: RelTraitSet, inputs: JList[RelNode]): RelNode = {
- new FlinkLogicalOverWindow(
+ new FlinkLogicalOverAggregate(
cluster,
traitSet,
inputs.get(0),
@@ -64,12 +63,12 @@ class FlinkLogicalOverWindow(
}
-class FlinkLogicalOverWindowConverter
+class FlinkLogicalOverAggregateConverter
extends ConverterRule(
classOf[LogicalWindow],
Convention.NONE,
FlinkConventions.LOGICAL,
- "FlinkLogicalOverWindowConverter") {
+ "FlinkLogicalOverAggregateConverter") {
override def convert(rel: RelNode): RelNode = {
val window = rel.asInstanceOf[LogicalWindow]
@@ -92,7 +91,7 @@ class FlinkLogicalOverWindowConverter
}
}
- new FlinkLogicalOverWindow(
+ new FlinkLogicalOverAggregate(
rel.getCluster,
traitSet,
newInput,
@@ -102,6 +101,6 @@ class FlinkLogicalOverWindowConverter
}
}
-object FlinkLogicalOverWindow {
- val CONVERTER = new FlinkLogicalOverWindowConverter
+object FlinkLogicalOverAggregate {
+ val CONVERTER = new FlinkLogicalOverAggregateConverter
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTableFunctionScan.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTableFunctionScan.scala
index add914e..22cb665 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTableFunctionScan.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTableFunctionScan.scala
@@ -18,7 +18,6 @@
package org.apache.flink.table.plan.nodes.logical
-import org.apache.flink.table.plan.metadata.FlinkRelMetadataQuery
import org.apache.flink.table.plan.nodes.FlinkConventions
import com.google.common.collect.ImmutableList
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecLocalHashWindowAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecLocalHashWindowAggregate.scala
index 3fd4398..34752c8 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecLocalHashWindowAggregate.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecLocalHashWindowAggregate.scala
@@ -42,7 +42,7 @@ class BatchExecLocalHashWindowAggregate(
auxGrouping: Array[Int],
aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
window: LogicalWindow,
- inputTimeFieldIndex: Int,
+ val inputTimeFieldIndex: Int,
inputTimeIsDate: Boolean,
namedProperties: Seq[NamedWindowProperty],
enableAssignPane: Boolean = false)
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecLocalSortWindowAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecLocalSortWindowAggregate.scala
index 408bffe..3839755 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecLocalSortWindowAggregate.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecLocalSortWindowAggregate.scala
@@ -45,7 +45,7 @@ class BatchExecLocalSortWindowAggregate(
auxGrouping: Array[Int],
aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
window: LogicalWindow,
- inputTimeFieldIndex: Int,
+ val inputTimeFieldIndex: Int,
inputTimeIsDate: Boolean,
namedProperties: Seq[NamedWindowProperty],
enableAssignPane: Boolean = false)
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecOverAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecOverAggregate.scala
index fd48fc1..14f50a8 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecOverAggregate.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecOverAggregate.scala
@@ -58,7 +58,7 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
/**
- * Batch physical RelNode for sort-based over [[Window]].
+ * Batch physical RelNode for sort-based over [[Window]] aggregate.
*/
class BatchExecOverAggregate(
cluster: RelOptCluster,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecGroupWindowAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecGroupWindowAggregate.scala
index 7fe1e0a..84c336b 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecGroupWindowAggregate.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecGroupWindowAggregate.scala
@@ -62,7 +62,7 @@ class StreamExecGroupWindowAggregate(
val aggCalls: Seq[AggregateCall],
val window: LogicalWindow,
namedProperties: Seq[NamedWindowProperty],
- inputTimestampIndex: Int,
+ inputTimeFieldIndex: Int,
val emitStrategy: WindowEmitStrategy)
extends SingleRel(cluster, traitSet, inputRel)
with StreamPhysicalRel
@@ -86,7 +86,7 @@ class StreamExecGroupWindowAggregate(
case _ => false
}
- def getGroupings: Array[Int] = grouping
+ def getGrouping: Array[Int] = grouping
def getWindowProperties: Seq[NamedWindowProperty] = namedProperties
@@ -103,7 +103,7 @@ class StreamExecGroupWindowAggregate(
aggCalls,
window,
namedProperties,
- inputTimestampIndex,
+ inputTimeFieldIndex,
emitStrategy)
}
@@ -177,14 +177,14 @@ class StreamExecGroupWindowAggregate(
namedProperties)
val timeIdx = if (isRowtimeIndicatorType(window.timeAttribute.getResultType)) {
- if (inputTimestampIndex < 0) {
+ if (inputTimeFieldIndex < 0) {
throw new TableException(
"Group window aggregate must defined on a time attribute, " +
"but the time attribute can't be found.\n" +
"This should never happen. Please file an issue."
)
}
- inputTimestampIndex
+ inputTimeFieldIndex
} else {
-1
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala
index bf6f247..2474dba 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala
@@ -267,7 +267,7 @@ object FlinkBatchRuleSets {
*/
private val LOGICAL_CONVERTERS: RuleSet = RuleSets.ofList(
FlinkLogicalAggregate.BATCH_CONVERTER,
- FlinkLogicalOverWindow.CONVERTER,
+ FlinkLogicalOverAggregate.CONVERTER,
FlinkLogicalCalc.CONVERTER,
FlinkLogicalCorrelate.CONVERTER,
FlinkLogicalJoin.CONVERTER,
@@ -329,7 +329,7 @@ object FlinkBatchRuleSets {
BatchExecNestedLoopJoinRule.INSTANCE,
BatchExecSingleRowJoinRule.INSTANCE,
BatchExecCorrelateRule.INSTANCE,
- BatchExecOverWindowAggRule.INSTANCE,
+ BatchExecOverAggregateRule.INSTANCE,
BatchExecWindowAggregateRule.INSTANCE,
BatchExecLookupJoinRule.SNAPSHOT_ON_TABLESCAN,
BatchExecLookupJoinRule.SNAPSHOT_ON_CALC_TABLESCAN,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala
index cbab9e3..c030929 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala
@@ -241,7 +241,7 @@ object FlinkStreamRuleSets {
private val LOGICAL_CONVERTERS: RuleSet = RuleSets.ofList(
// translate to flink logical rel nodes
FlinkLogicalAggregate.STREAM_CONVERTER,
- FlinkLogicalOverWindow.CONVERTER,
+ FlinkLogicalOverAggregate.CONVERTER,
FlinkLogicalCalc.CONVERTER,
FlinkLogicalCorrelate.CONVERTER,
FlinkLogicalJoin.CONVERTER,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRule.scala
index 3b21083..595ec18 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRule.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRule.scala
@@ -19,7 +19,7 @@ package org.apache.flink.table.plan.rules.logical
import org.apache.flink.table.api.TableException
import org.apache.flink.table.calcite.FlinkContext
-import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalOverWindow, FlinkLogicalRank}
+import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalOverAggregate, FlinkLogicalRank}
import org.apache.flink.table.plan.util.RankUtil
import org.apache.flink.table.runtime.rank.{ConstantRankRange, ConstantRankRangeWithoutEnd, RankType}
@@ -33,17 +33,17 @@ import org.apache.calcite.sql.{SqlKind, SqlRankFunction}
import scala.collection.JavaConversions._
/**
- * Planner rule that matches a [[FlinkLogicalCalc]] on a [[FlinkLogicalOverWindow]],
+ * Planner rule that matches a [[FlinkLogicalCalc]] on a [[FlinkLogicalOverAggregate]],
* and converts them into a [[FlinkLogicalRank]].
*/
abstract class FlinkLogicalRankRuleBase
extends RelOptRule(
operand(classOf[FlinkLogicalCalc],
- operand(classOf[FlinkLogicalOverWindow], any()))) {
+ operand(classOf[FlinkLogicalOverAggregate], any()))) {
override def onMatch(call: RelOptRuleCall): Unit = {
val calc: FlinkLogicalCalc = call.rel(0)
- val window: FlinkLogicalOverWindow = call.rel(1)
+ val window: FlinkLogicalOverAggregate = call.rel(1)
val group = window.groups.get(0)
val rankFun = group.aggCalls.get(0).getOperator.asInstanceOf[SqlRankFunction]
@@ -152,7 +152,7 @@ class FlinkLogicalRankRuleForRangeEnd extends FlinkLogicalRankRuleBase {
override def matches(call: RelOptRuleCall): Boolean = {
val calc: FlinkLogicalCalc = call.rel(0)
- val window: FlinkLogicalOverWindow = call.rel(1)
+ val window: FlinkLogicalOverAggregate = call.rel(1)
if (window.groups.size > 1) {
// only accept one window
@@ -175,7 +175,7 @@ class FlinkLogicalRankRuleForRangeEnd extends FlinkLogicalRankRuleBase {
val condition = calc.getProgram.getCondition
if (condition != null) {
val predicate = calc.getProgram.expandLocalRef(condition)
- // the rank function is the last field of FlinkLogicalOverWindow
+ // the rank function is the last field of FlinkLogicalOverAggregate
val rankFieldIndex = window.getRowType.getFieldCount - 1
val config = calc.getCluster.getPlanner.getContext.asInstanceOf[FlinkContext].getTableConfig
val (rankRange, remainingPreds) = RankUtil.extractRankRange(
@@ -217,7 +217,7 @@ class FlinkLogicalRankRuleForRangeEnd extends FlinkLogicalRankRuleBase {
class FlinkLogicalRankRuleForConstantRange extends FlinkLogicalRankRuleBase {
override def matches(call: RelOptRuleCall): Boolean = {
val calc: FlinkLogicalCalc = call.rel(0)
- val window: FlinkLogicalOverWindow = call.rel(1)
+ val window: FlinkLogicalOverAggregate = call.rel(1)
if (window.groups.size > 1) {
// only accept one window
@@ -240,7 +240,7 @@ class FlinkLogicalRankRuleForConstantRange extends FlinkLogicalRankRuleBase {
val condition = calc.getProgram.getCondition
if (condition != null) {
val predicate = calc.getProgram.expandLocalRef(condition)
- // the rank function is the last field of FlinkLogicalOverWindow
+ // the rank function is the last field of FlinkLogicalOverAggregate
val rankFieldIndex = window.getRowType.getFieldCount - 1
val config = calc.getCluster.getPlanner.getContext.asInstanceOf[FlinkContext].getTableConfig
val (rankRange, remainingPreds) = RankUtil.extractRankRange(
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecOverWindowAggRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecOverAggregateRule.scala
similarity index 94%
rename from flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecOverWindowAggRule.scala
rename to flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecOverAggregateRule.scala
index a8c1324..b11773f 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecOverWindowAggRule.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecOverAggregateRule.scala
@@ -20,7 +20,7 @@ package org.apache.flink.table.plan.rules.physical.batch
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.plan.nodes.FlinkConventions
-import org.apache.flink.table.plan.nodes.logical.FlinkLogicalOverWindow
+import org.apache.flink.table.plan.nodes.logical.FlinkLogicalOverAggregate
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecOverAggregate
import org.apache.flink.table.plan.util.{AggregateUtil, OverAggregateUtil, SortUtil}
@@ -37,18 +37,18 @@ import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
/**
- * Rule that converts [[FlinkLogicalOverWindow]] to one or more [[BatchExecOverAggregate]]s.
+ * Rule that converts [[FlinkLogicalOverAggregate]] to one or more [[BatchExecOverAggregate]]s.
* If there are more than one [[Group]], this rule will combine adjacent [[Group]]s with the
* same partition keys and order keys into one BatchExecOverAggregate.
*/
-class BatchExecOverWindowAggRule
+class BatchExecOverAggregateRule
extends RelOptRule(
- operand(classOf[FlinkLogicalOverWindow],
+ operand(classOf[FlinkLogicalOverAggregate],
operand(classOf[RelNode], any)),
- "BatchExecOverWindowAggRule") {
+ "BatchExecOverAggregateRule") {
override def onMatch(call: RelOptRuleCall): Unit = {
- val logicWindow: FlinkLogicalOverWindow = call.rel(0)
+ val logicWindow: FlinkLogicalOverAggregate = call.rel(0)
var input: RelNode = call.rel(1)
var inputRowType = logicWindow.getInput.getRowType
val typeFactory = logicWindow.getCluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
@@ -135,7 +135,7 @@ class BatchExecOverWindowAggRule
/**
* Returns true if group1 satisfies group2 on keys and orderKeys, else false.
*/
- def satisfies(group1: Group, group2: Group, logicWindow: FlinkLogicalOverWindow): Boolean = {
+ def satisfies(group1: Group, group2: Group, logicWindow: FlinkLogicalOverAggregate): Boolean = {
var isSatisfied = false
val keyComp = group1.keys.compareTo(group2.keys)
if (keyComp == 0) {
@@ -176,6 +176,6 @@ class BatchExecOverWindowAggRule
}
}
-object BatchExecOverWindowAggRule {
- val INSTANCE: RelOptRule = new BatchExecOverWindowAggRule
+object BatchExecOverAggregateRule {
+ val INSTANCE: RelOptRule = new BatchExecOverAggregateRule
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecOverAggregateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecOverAggregateRule.scala
index 3a4873a..2eb09e1 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecOverAggregateRule.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecOverAggregateRule.scala
@@ -21,7 +21,7 @@ package org.apache.flink.table.plan.rules.physical.stream
import org.apache.flink.table.api.TableException
import org.apache.flink.table.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.plan.nodes.FlinkConventions
-import org.apache.flink.table.plan.nodes.logical.FlinkLogicalOverWindow
+import org.apache.flink.table.plan.nodes.logical.FlinkLogicalOverAggregate
import org.apache.flink.table.plan.nodes.physical.stream.StreamExecOverAggregate
import org.apache.calcite.plan.RelOptRule
@@ -29,19 +29,19 @@ import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.convert.ConverterRule
/**
- * Rule that converts [[FlinkLogicalOverWindow]] to [[StreamExecOverAggregate]].
+ * Rule that converts [[FlinkLogicalOverAggregate]] to [[StreamExecOverAggregate]].
* NOTES: StreamExecOverAggregate only supports one [[org.apache.calcite.rel.core.Window.Group]],
* else throw exception now
*/
class StreamExecOverAggregateRule
extends ConverterRule(
- classOf[FlinkLogicalOverWindow],
+ classOf[FlinkLogicalOverAggregate],
FlinkConventions.LOGICAL,
FlinkConventions.STREAM_PHYSICAL,
"StreamExecOverAggregateRule") {
override def convert(rel: RelNode): RelNode = {
- val logicWindow: FlinkLogicalOverWindow = rel.asInstanceOf[FlinkLogicalOverWindow]
+ val logicWindow: FlinkLogicalOverAggregate = rel.asInstanceOf[FlinkLogicalOverAggregate]
if (logicWindow.groups.size > 1) {
throw new TableException(
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/AggregateUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/AggregateUtil.scala
index 3d1798b..0c94f2e 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/AggregateUtil.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/AggregateUtil.scala
@@ -17,18 +17,8 @@
*/
package org.apache.flink.table.plan.util
-import java.lang.{Long => JLong}
-import java.time.Duration
-import java.util
-
-import org.apache.calcite.rel.`type`._
-import org.apache.calcite.rel.core.{Aggregate, AggregateCall}
-import org.apache.calcite.rex.RexInputRef
-import org.apache.calcite.sql.fun._
-import org.apache.calcite.sql.validate.SqlMonotonicity
-import org.apache.calcite.sql.{SqlKind, SqlRankFunction}
-import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation, Types}
+import org.apache.flink.table.JLong
import org.apache.flink.table.`type`.InternalTypes._
import org.apache.flink.table.`type`.{DecimalType, InternalType, InternalTypes, TypeConverters}
import org.apache.flink.table.api.{TableConfig, TableConfigOptions, TableException}
@@ -45,7 +35,18 @@ import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
import org.apache.flink.table.plan.`trait`.RelModifiedMonotonicity
import org.apache.flink.table.runtime.bundle.trigger.CountBundleTrigger
-import org.apache.flink.table.typeutils._
+import org.apache.flink.table.typeutils.{BaseRowTypeInfo, BinaryStringTypeInfo, DecimalTypeInfo, MapViewTypeInfo, TimeIndicatorTypeInfo, TimeIntervalTypeInfo}
+
+import org.apache.calcite.rel.`type`._
+import org.apache.calcite.rel.core.{Aggregate, AggregateCall}
+import org.apache.calcite.rex.RexInputRef
+import org.apache.calcite.sql.fun._
+import org.apache.calcite.sql.validate.SqlMonotonicity
+import org.apache.calcite.sql.{SqlKind, SqlRankFunction}
+import org.apache.calcite.tools.RelBuilder
+
+import java.time.Duration
+import java.util
import scala.collection.JavaConversions._
import scala.collection.mutable
@@ -737,4 +738,12 @@ object AggregateUtil extends Enumeration {
throw new IllegalArgumentException()
}
}
+
+ def extractTimeIntervalValue(literal: ValueLiteralExpression): JLong = {
+ if (isTimeIntervalType(literal.getType)) {
+ literal.getValue.asInstanceOf[JLong]
+ } else {
+ throw new IllegalArgumentException()
+ }
+ }
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala
index d82915b..87d6957 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelMdUtil.scala
@@ -19,10 +19,11 @@
package org.apache.flink.table.plan.util
import org.apache.flink.table.JDouble
+import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.dataformat.BinaryRow
-import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank}
-import org.apache.flink.table.plan.nodes.physical.batch.BatchExecGroupAggregateBase
+import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank, WindowAggregate}
+import org.apache.flink.table.plan.nodes.physical.batch.{BatchExecGroupAggregateBase, BatchExecLocalHashWindowAggregate, BatchExecLocalSortWindowAggregate, BatchExecWindowAggregateBase}
import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankRange}
import org.apache.flink.table.runtime.sort.BinaryIndexedSortable
import org.apache.flink.table.typeutils.BinaryRowSerializer
@@ -148,6 +149,80 @@ object FlinkRelMdUtil {
1.0 - math.exp(-0.1 * groupingLength)
/**
+ * Creates a RexNode that stores a selectivity value corresponding to the
+ * selectivity of a NamedProperties predicate.
+ *
+ * @param winAgg window aggregate node
+ * @param predicate a RexNode
+ * @return constructed rexNode including non-NamedProperties predicates and
+ * a predicate that stores NamedProperties predicate's selectivity
+ */
+ def makeNamePropertiesSelectivityRexNode(
+ winAgg: WindowAggregate,
+ predicate: RexNode): RexNode = {
+ val fullGroupSet = AggregateUtil.checkAndGetFullGroupSet(winAgg)
+ makeNamePropertiesSelectivityRexNode(winAgg, fullGroupSet, winAgg.getNamedProperties, predicate)
+ }
+
+ /**
+ * Creates a RexNode that stores a selectivity value corresponding to the
+ * selectivity of a NamedProperties predicate.
+ *
+ * @param globalWinAgg global window aggregate node
+ * @param predicate a RexNode
+ * @return constructed rexNode including non-NamedProperties predicates and
+ * a predicate that stores NamedProperties predicate's selectivity
+ */
+ def makeNamePropertiesSelectivityRexNode(
+ globalWinAgg: BatchExecWindowAggregateBase,
+ predicate: RexNode): RexNode = {
+ require(globalWinAgg.isFinal, "local window agg does not contain NamedProperties!")
+ val fullGrouping = globalWinAgg.getGrouping ++ globalWinAgg.getAuxGrouping
+ makeNamePropertiesSelectivityRexNode(
+ globalWinAgg, fullGrouping, globalWinAgg.getNamedProperties, predicate)
+ }
+
+ /**
+ * Creates a RexNode that stores a selectivity value corresponding to the
+ * selectivity of a NamedProperties predicate.
+ *
+ * @param winAgg window aggregate node
+ * @param fullGrouping full groupSets
+ * @param namedProperties NamedWindowProperty list
+ * @param predicate a RexNode
+ * @return constructed rexNode including non-NamedProperties predicates and
+ * a predicate that stores NamedProperties predicate's selectivity
+ */
+ def makeNamePropertiesSelectivityRexNode(
+ winAgg: SingleRel,
+ fullGrouping: Array[Int],
+ namedProperties: Seq[NamedWindowProperty],
+ predicate: RexNode): RexNode = {
+ if (predicate == null || predicate.isAlwaysTrue || namedProperties.isEmpty) {
+ return predicate
+ }
+ val rexBuilder = winAgg.getCluster.getRexBuilder
+ val namePropertiesStartIdx = winAgg.getRowType.getFieldCount - namedProperties.size
+ // split non-nameProperties predicates and nameProperties predicates
+ val pushable = new util.ArrayList[RexNode]
+ val notPushable = new util.ArrayList[RexNode]
+ RelOptUtil.splitFilters(
+ ImmutableBitSet.range(0, namePropertiesStartIdx),
+ predicate,
+ pushable,
+ notPushable)
+ if (notPushable.nonEmpty) {
+ val pred = RexUtil.composeConjunction(rexBuilder, notPushable, true)
+ val selectivity = RelMdUtil.guessSelectivity(pred)
+ val fun = rexBuilder.makeCall(
+ RelMdUtil.ARTIFICIAL_SELECTIVITY_FUNC,
+ rexBuilder.makeApproxLiteral(new BigDecimal(selectivity)))
+ pushable.add(fun)
+ }
+ RexUtil.composeConjunction(rexBuilder, pushable, true)
+ }
+
+ /**
* Estimates outputRowCount of local aggregate.
*
* output rowcount of local agg is (1 - pow((1 - 1/x) , n/m)) * m * x, based on two assumption:
@@ -212,10 +287,34 @@ object FlinkRelMdUtil {
setChildKeysOfAgg(groupKey, aggRel)
}
+ /**
+ * Takes a bitmap representing a set of input references and extracts the
+ * ones that reference the group by columns in an aggregate.
+ *
+ * @param groupKey the original bitmap
+ * @param aggRel the aggregate
+ */
+ def setAggChildKeys(
+ groupKey: ImmutableBitSet,
+ aggRel: BatchExecWindowAggregateBase): (ImmutableBitSet, Array[AggregateCall]) = {
+ require(!aggRel.isFinal || !aggRel.isMerge, "Cannot handle global agg which has local agg!")
+ setChildKeysOfAgg(groupKey, aggRel)
+ }
+
private def setChildKeysOfAgg(
groupKey: ImmutableBitSet,
agg: SingleRel): (ImmutableBitSet, Array[AggregateCall]) = {
val (aggCalls, fullGroupSet) = agg match {
+ case agg: BatchExecLocalSortWindowAggregate =>
+ // grouping + assignTs + auxGrouping
+ (agg.getAggCallList,
+ agg.getGrouping ++ Array(agg.inputTimeFieldIndex) ++ agg.getAuxGrouping)
+ case agg: BatchExecLocalHashWindowAggregate =>
+ // grouping + assignTs + auxGrouping
+ (agg.getAggCallList,
+ agg.getGrouping ++ Array(agg.inputTimeFieldIndex) ++ agg.getAuxGrouping)
+ case agg: BatchExecWindowAggregateBase =>
+ (agg.getAggCallList, agg.getGrouping ++ agg.getAuxGrouping)
case agg: BatchExecGroupAggregateBase =>
(agg.getAggCallList, agg.getGrouping ++ agg.getAuxGrouping)
case _ => throw new IllegalArgumentException(s"Unknown aggregate: ${agg.getRelTypeName}")
@@ -237,7 +336,34 @@ object FlinkRelMdUtil {
}
/**
- * Split groupKeys on Agregate/ BatchExecGroupAggregateBase/ BatchExecWindowAggregateBase
+ * Takes a bitmap representing a set of local window aggregate references.
+ *
+ * global win-agg output type: groupSet + auxGroupSet + aggCall + namedProperties
+ * local win-agg output type: groupSet + assignTs + auxGroupSet + aggCalls
+ *
+ * Skips `assignTs` when mapping `groupKey` to `childKey`.
+ *
+ * @param groupKey the original bitmap
+ * @param globalWinAgg the global window aggregate
+ */
+ def setChildKeysOfWinAgg(
+ groupKey: ImmutableBitSet,
+ globalWinAgg: BatchExecWindowAggregateBase): ImmutableBitSet = {
+ require(globalWinAgg.isMerge, "Cannot handle global agg which does not have local window agg!")
+ val childKeyBuilder = ImmutableBitSet.builder
+ groupKey.toArray.foreach { key =>
+ if (key < globalWinAgg.getGrouping.length) {
+ childKeyBuilder.set(key)
+ } else {
+ // skips `assignTs`
+ childKeyBuilder.set(key + 1)
+ }
+ }
+ childKeyBuilder.build()
+ }
+
+ /**
+ * Split groupKeys on Aggregate/ BatchExecGroupAggregateBase/ BatchExecWindowAggregateBase
* into keys on aggregate's groupKey and aggregate's aggregateCalls.
*
* @param agg the aggregate
@@ -271,6 +397,10 @@ object FlinkRelMdUtil {
val (childKeys, aggCalls) = setAggChildKeys(groupKey, rel)
val childKeyExcludeAuxKey = removeAuxKey(childKeys, rel.getGrouping, rel.getAuxGrouping)
(childKeyExcludeAuxKey, aggCalls)
+ case rel: BatchExecWindowAggregateBase =>
+ val (childKeys, aggCalls) = setAggChildKeys(groupKey, rel)
+ val childKeyExcludeAuxKey = removeAuxKey(childKeys, rel.getGrouping, rel.getAuxGrouping)
+ (childKeyExcludeAuxKey, aggCalls)
case _ => throw new IllegalArgumentException(s"Unknown aggregate: ${agg.getRelTypeName}.")
}
}
@@ -306,6 +436,44 @@ object FlinkRelMdUtil {
splitPredicateOnAgg(agg.getGrouping ++ agg.getAuxGrouping, agg, predicate)
}
+ /**
+ * Split a predicate on BatchExecWindowAggregateBase into two parts,
+ * the first one is pushable part, the second one is rest part.
+ *
+ * @param agg Aggregate which to analyze
+ * @param predicate Predicate which to analyze
+ * @return a tuple, first element is pushable part, second element is rest part.
+ * Note, pushable condition will be converted based on the input field position.
+ */
+ def splitPredicateOnAggregate(
+ agg: BatchExecWindowAggregateBase,
+ predicate: RexNode): (Option[RexNode], Option[RexNode]) = {
+ splitPredicateOnAgg(agg.getGrouping ++ agg.getAuxGrouping, agg, predicate)
+ }
+
+ /**
+ * Shifts every [[RexInputRef]] in an expression that is higher than the length of the full
+ * grouping (to skip `assignTs`).
+ *
+ * global win-agg output type: groupSet + auxGroupSet + aggCall + namedProperties
+ * local win-agg output type: groupSet + assignTs + auxGroupSet + aggCalls
+ *
+ * @param predicate a RexNode
+ * @param globalWinAgg the global window aggregate
+ */
+ def setChildPredicateOfWinAgg(
+ predicate: RexNode,
+ globalWinAgg: BatchExecWindowAggregateBase): RexNode = {
+ require(globalWinAgg.isMerge, "Cannot handle global agg which does not have local window agg!")
+ if (predicate == null) {
+ return null
+ }
+ // grouping + assignTs + auxGrouping
+ val fullGrouping = globalWinAgg.getGrouping ++ globalWinAgg.getAuxGrouping
+ // skips `assignTs`
+ RexUtil.shift(predicate, fullGrouping.length, 1)
+ }
+
private def splitPredicateOnAgg(
grouping: Array[Int],
agg: SingleRel,
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/DagOptimizationTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/DagOptimizationTest.xml
index f4c3e12..a58970b 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/DagOptimizationTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/DagOptimizationTest.xml
@@ -596,10 +596,11 @@ LogicalSink(fields=[a, sum_c, time])
</Resource>
<Resource name="planAfter">
<![CDATA[
-HashWindowAggregate(groupBy=[a], window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[a, SUM(c) AS sum_c], reuse_id=[1])
+HashWindowAggregate(groupBy=[a], window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[a, Final_SUM(sum$0) AS sum_c], reuse_id=[1])
+- Exchange(distribution=[hash[a]])
- +- Calc(select=[ts, a, CAST(c) AS c])
- +- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, ts)]]], fields=[a, b, c, ts])
+ +- LocalHashWindowAggregate(groupBy=[a], window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[a, Partial_SUM(c) AS sum$0])
+ +- Calc(select=[ts, a, CAST(c) AS c])
+ +- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, ts)]]], fields=[a, b, c, ts])
Sink(fields=[a, sum_c, time, window_start, window_end])
+- Calc(select=[a, sum_c, w$end AS time, w$start AS window_start, w$end AS window_end])
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/OverWindowAggregateTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/OverAggregateTest.xml
similarity index 100%
rename from flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/OverWindowAggregateTest.xml
rename to flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/OverAggregateTest.xml
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/WindowAggregateTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/WindowAggregateTest.xml
index 7539186..b875e28 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/WindowAggregateTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/WindowAggregateTest.xml
@@ -40,10 +40,11 @@ LogicalProject(EXPR$0=[$1], EXPR$1=[$2], EXPR$2=[$3], EXPR$3=[$4], EXPR$4=[TUMBL
<Resource name="planAfter">
<![CDATA[
Calc(select=[/(-($f0, /(*(CAST($f1), CAST($f1)), $f2)), $f2) AS EXPR$0, /(-($f0, /(*(CAST($f1), CAST($f1)), $f2)), CASE(=($f2, 1), null:BIGINT, -($f2, 1))) AS EXPR$1, POWER(/(-($f0, /(*(CAST($f1), CAST($f1)), $f2)), $f2), 0.5:DECIMAL(2, 1)) AS EXPR$2, POWER(/(-($f0, /(*(CAST($f1), CAST($f1)), $f2)), CASE(=($f2, 1), null:BIGINT, -($f2, 1))), 0.5:DECIMAL(2, 1)) AS EXPR$3, w$start AS EXPR$4, w$end AS EXPR$5])
-+- HashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[SUM($f2) AS $f0, SUM(b) AS $f1, COUNT(b) AS $f2])
++- HashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[Final_SUM(sum$0) AS $f0, Final_SUM(sum$1) AS $f1, Final_COUNT(count$2) AS $f2])
+- Exchange(distribution=[single])
- +- Calc(select=[ts, b, *(CAST(b), CAST(b)) AS $f2])
- +- TableSourceScan(table=[[MyTable1, source: [TestTableSource(ts, a, b, c)]]], fields=[ts, a, b, c])
+ +- LocalHashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[Partial_SUM($f2) AS sum$0, Partial_SUM(b) AS sum$1, Partial_COUNT(b) AS count$2])
+ +- Calc(select=[ts, b, *(CAST(b), CAST(b)) AS $f2])
+ +- TableSourceScan(table=[[MyTable1, source: [TestTableSource(ts, a, b, c)]]], fields=[ts, a, b, c])
]]>
</Resource>
</TestCase>
@@ -134,10 +135,11 @@ LogicalProject(EXPR$0=[$1], EXPR$1=[HOP_START($0)])
<Resource name="planAfter">
<![CDATA[
Calc(select=[EXPR$0, w$start AS EXPR$1], where=[AND(>($f1, 0), =(EXTRACT(FLAG(QUARTER), w$start), 1:BIGINT))])
-+- HashWindowAggregate(window=[SlidingGroupWindow('w$, ts, 60000.millis, 900000.millis)], properties=[w$start, w$end, w$rowtime], select=[COUNT(*) AS EXPR$0, SUM(a) AS $f1])
++- HashWindowAggregate(window=[SlidingGroupWindow('w$, ts, 60000.millis, 900000.millis)], properties=[w$start, w$end, w$rowtime], select=[Final_COUNT(count1$0) AS EXPR$0, Final_SUM(sum$1) AS $f1])
+- Exchange(distribution=[single])
- +- Calc(select=[ts, a])
- +- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, d, ts)]]], fields=[a, b, c, d, ts])
+ +- LocalHashWindowAggregate(window=[SlidingGroupWindow('w$, ts, 60000.millis, 900000.millis)], properties=[w$start, w$end, w$rowtime], select=[Partial_COUNT(*) AS count1$0, Partial_SUM(a) AS sum$1])
+ +- Calc(select=[ts, a])
+ +- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, d, ts)]]], fields=[a, b, c, d, ts])
]]>
</Resource>
</TestCase>
@@ -248,10 +250,11 @@ LogicalProject(EXPR$0=[$1], EXPR$1=[HOP_START($0)], EXPR$2=[HOP_END($0)])
<Resource name="planAfter">
<![CDATA[
Calc(select=[EXPR$0, w$start AS EXPR$1, w$end AS EXPR$2])
-+- HashWindowAggregate(window=[SlidingGroupWindow('w$, b, 3000.millis, 3000.millis)], properties=[w$start, w$end, w$rowtime], select=[SUM(a) AS EXPR$0])
++- HashWindowAggregate(window=[SlidingGroupWindow('w$, b, 3000.millis, 3000.millis)], properties=[w$start, w$end, w$rowtime], select=[Final_SUM(sum$0) AS EXPR$0])
+- Exchange(distribution=[single])
- +- Calc(select=[b, a])
- +- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
+ +- LocalHashWindowAggregate(window=[SlidingGroupWindow('w$, b, 3000.millis, 3000.millis)], properties=[w$start, w$end, w$rowtime], select=[Partial_SUM(a) AS sum$0])
+ +- Calc(select=[b, a])
+ +- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
]]>
</Resource>
</TestCase>
@@ -326,10 +329,11 @@ LogicalProject(EXPR$0=[$1], EXPR$1=[$2])
</Resource>
<Resource name="planAfter">
<![CDATA[
-HashWindowAggregate(window=[TumblingGroupWindow], select=[AVG(c) AS EXPR$0, SUM(a) AS EXPR$1])
+HashWindowAggregate(window=[TumblingGroupWindow], select=[Final_AVG(sum$0, count$1) AS EXPR$0, Final_SUM(sum$2) AS EXPR$1])
+- Exchange(distribution=[single])
- +- Calc(select=[b, c, a])
- +- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
+ +- LocalHashWindowAggregate(window=[TumblingGroupWindow], select=[Partial_AVG(c) AS (sum$0, count$1), Partial_SUM(a) AS sum$2])
+ +- Calc(select=[b, c, a])
+ +- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
]]>
</Resource>
</TestCase>
@@ -434,10 +438,11 @@ LogicalProject(sumA=[$1], cntB=[$2])
</Resource>
<Resource name="planAfter">
<![CDATA[
-HashWindowAggregate(window=[TumblingGroupWindow], select=[SUM(a) AS sumA, COUNT(b) AS cntB])
+HashWindowAggregate(window=[TumblingGroupWindow], select=[Final_SUM(sum$0) AS sumA, Final_COUNT(count$1) AS cntB])
+- Exchange(distribution=[single])
- +- Calc(select=[ts, a, b])
- +- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, d, ts)]]], fields=[a, b, c, d, ts])
+ +- LocalHashWindowAggregate(window=[TumblingGroupWindow], select=[Partial_SUM(a) AS sum$0, Partial_COUNT(b) AS count$1])
+ +- Calc(select=[ts, a, b])
+ +- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, d, ts)]]], fields=[a, b, c, d, ts])
]]>
</Resource>
</TestCase>
@@ -596,10 +601,11 @@ LogicalProject(EXPR$0=[TUMBLE_START($0)], EXPR$1=[TUMBLE_END($0)], EXPR$2=[TUMBL
<Resource name="planAfter">
<![CDATA[
Calc(select=[w$start AS EXPR$0, w$end AS EXPR$1, w$rowtime AS EXPR$2, c, sumA, minB])
-+- HashWindowAggregate(groupBy=[c], window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[c, SUM(a) AS sumA, MIN(b) AS minB])
++- HashWindowAggregate(groupBy=[c], window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[c, Final_SUM(sum$0) AS sumA, Final_MIN(min$1) AS minB])
+- Exchange(distribution=[hash[c]])
- +- Calc(select=[ts, c, a, b])
- +- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, d, ts)]]], fields=[a, b, c, d, ts])
+ +- LocalHashWindowAggregate(groupBy=[c], window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[c, Partial_SUM(a) AS sum$0, Partial_MIN(b) AS min$1])
+ +- Calc(select=[ts, c, a, b])
+ +- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, d, ts)]]], fields=[a, b, c, d, ts])
]]>
</Resource>
</TestCase>
@@ -750,11 +756,13 @@ LogicalProject(EXPR$0=[$2])
<Resource name="planAfter">
<![CDATA[
Calc(select=[EXPR$0])
-+- SortWindowAggregate(groupBy=[a], window=[SlidingGroupWindow('w$, ts, 3600000.millis, 3000.millis)], select=[a, MAX(c) AS EXPR$0])
- +- Sort(orderBy=[a ASC, ts ASC])
++- SortWindowAggregate(groupBy=[a], window=[SlidingGroupWindow('w$, ts, 3600000.millis, 3000.millis)], select=[a, Final_MAX(max$0) AS EXPR$0])
+ +- Sort(orderBy=[a ASC, assignedPane$ ASC])
+- Exchange(distribution=[hash[a]])
- +- Calc(select=[a, ts, c])
- +- TableSourceScan(table=[[MyTable1, source: [TestTableSource(ts, a, b, c)]]], fields=[ts, a, b, c])
+ +- LocalSortWindowAggregate(groupBy=[a], window=[SlidingGroupWindow('w$, ts, 3600000.millis, 3000.millis)], select=[a, Partial_MAX(c) AS max$0])
+ +- Sort(orderBy=[a ASC, ts ASC])
+ +- Calc(select=[a, ts, c])
+ +- TableSourceScan(table=[[MyTable1, source: [TestTableSource(ts, a, b, c)]]], fields=[ts, a, b, c])
]]>
</Resource>
</TestCase>
@@ -1004,10 +1012,11 @@ LogicalProject(EXPR$0=[$2])
<Resource name="planAfter">
<![CDATA[
Calc(select=[EXPR$0])
-+- HashWindowAggregate(groupBy=[a], window=[TumblingGroupWindow], select=[a, COUNT(c) AS EXPR$0])
++- HashWindowAggregate(groupBy=[a], window=[TumblingGroupWindow], select=[a, Final_COUNT(count$0) AS EXPR$0])
+- Exchange(distribution=[hash[a]])
- +- Calc(select=[a, ts, c])
- +- TableSourceScan(table=[[MyTable1, source: [TestTableSource(ts, a, b, c)]]], fields=[ts, a, b, c])
+ +- LocalHashWindowAggregate(groupBy=[a], window=[TumblingGroupWindow], select=[a, Partial_COUNT(c) AS count$0])
+ +- Calc(select=[a, ts, c])
+ +- TableSourceScan(table=[[MyTable1, source: [TestTableSource(ts, a, b, c)]]], fields=[ts, a, b, c])
]]>
</Resource>
</TestCase>
@@ -1071,9 +1080,10 @@ LogicalProject(EXPR$0=[$3], EXPR$1=[$4])
<Resource name="planAfter">
<![CDATA[
Calc(select=[EXPR$0, EXPR$1])
-+- HashWindowAggregate(groupBy=[a, d], window=[TumblingGroupWindow], select=[a, d, AVG(c) AS EXPR$0, COUNT(a) AS EXPR$1])
++- HashWindowAggregate(groupBy=[a, d], window=[TumblingGroupWindow], select=[a, d, Final_AVG(sum$0, count$1) AS EXPR$0, Final_COUNT(count$2) AS EXPR$1])
+- Exchange(distribution=[hash[a, d]])
- +- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
+ +- LocalHashWindowAggregate(groupBy=[a, d], window=[TumblingGroupWindow], select=[a, d, Partial_AVG(c) AS (sum$0, count$1), Partial_COUNT(a) AS count$2])
+ +- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
]]>
</Resource>
</TestCase>
@@ -1134,11 +1144,13 @@ LogicalProject(wAvg=[$1])
</Resource>
<Resource name="planAfter">
<![CDATA[
-SortWindowAggregate(window=[TumblingGroupWindow], select=[weightedAvg(b, a) AS wAvg])
-+- Sort(orderBy=[ts ASC])
+SortWindowAggregate(window=[TumblingGroupWindow], select=[Final_weightedAvg(wAvg) AS wAvg])
++- Sort(orderBy=[assignedWindow$ ASC])
+- Exchange(distribution=[single])
- +- Calc(select=[ts, b, a])
- +- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, d, ts)]]], fields=[a, b, c, d, ts])
+ +- LocalSortWindowAggregate(window=[TumblingGroupWindow], select=[Partial_weightedAvg(b, a) AS wAvg])
+ +- Sort(orderBy=[ts ASC])
+ +- Calc(select=[ts, b, a])
+ +- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, d, ts)]]], fields=[a, b, c, d, ts])
]]>
</Resource>
</TestCase>
@@ -1157,11 +1169,13 @@ LogicalProject(EXPR$0=[$2])
<Resource name="planAfter">
<![CDATA[
Calc(select=[EXPR$0])
-+- SortWindowAggregate(groupBy=[a], window=[TumblingGroupWindow], select=[a, MAX(c) AS EXPR$0])
- +- Sort(orderBy=[a ASC, ts ASC])
++- SortWindowAggregate(groupBy=[a], window=[TumblingGroupWindow], select=[a, Final_MAX(max$0) AS EXPR$0])
+ +- Sort(orderBy=[a ASC, assignedWindow$ ASC])
+- Exchange(distribution=[hash[a]])
- +- Calc(select=[a, ts, c])
- +- TableSourceScan(table=[[MyTable1, source: [TestTableSource(ts, a, b, c)]]], fields=[ts, a, b, c])
+ +- LocalSortWindowAggregate(groupBy=[a], window=[TumblingGroupWindow], select=[a, Partial_MAX(c) AS max$0])
+ +- Sort(orderBy=[a ASC, ts ASC])
+ +- Calc(select=[a, ts, c])
+ +- TableSourceScan(table=[[MyTable1, source: [TestTableSource(ts, a, b, c)]]], fields=[ts, a, b, c])
]]>
</Resource>
</TestCase>
@@ -1298,8 +1312,9 @@ LogicalProject(EXPR$0=[TUMBLE_END($0)])
Calc(select=[w$end AS EXPR$0])
+- HashWindowAggregate(groupBy=[c], window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[c])
+- Exchange(distribution=[hash[c]])
- +- Calc(select=[ts, c])
- +- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, d, ts)]]], fields=[a, b, c, d, ts])
+ +- LocalHashWindowAggregate(groupBy=[c], window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[c])
+ +- Calc(select=[ts, c])
+ +- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, d, ts)]]], fields=[a, b, c, d, ts])
]]>
</Resource>
</TestCase>
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRuleForConstantRangeTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRuleForConstantRangeTest.xml
index 128d5d4..b26f882 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRuleForConstantRangeTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRuleForConstantRangeTest.xml
@@ -61,7 +61,7 @@ LogicalProject(a=[$0], b=[$1], rk=[$2], rn=[$3])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0, w1$o0], where=[<(w0$o0, 10)])
-+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])], window#1=[window(partition {1} order by [0 ASC-nulls-first] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])
++- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])], window#1=[window(partition {1} order by [0 ASC-nulls-first] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
@@ -87,7 +87,7 @@ LogicalProject(a=[$0], b=[$1], rk1=[$2], rk2=[$3])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0, w1$o0], where=[<(w0$o0, 10)])
-+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])], window#1=[window(partition {2} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
++- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])], window#1=[window(partition {2} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
@@ -161,7 +161,7 @@ LogicalProject(a=[$0], b=[$1], rk=[$2])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0], where=[>(w0$o0, 2)])
-+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first, 2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
++- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [0 ASC-nulls-first, 2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
@@ -233,7 +233,7 @@ LogicalProject(a=[$0], b=[$1], rn=[$2])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0], where=[<=(w0$o0, 2)])
-+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])
++- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [0 ASC-nulls-first] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
@@ -257,7 +257,7 @@ LogicalProject(a=[$0], b=[$1], rk=[$2])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0], where=[<(w0$o0, a)])
-+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
++- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
@@ -281,7 +281,7 @@ LogicalProject(a=[$0], b=[$1], rk=[$2])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0], where=[>(w0$o0, a)])
-+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
++- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
@@ -305,7 +305,7 @@ LogicalProject(a=[$0], b=[$1], rk=[$2])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0], where=[AND(<(w0$o0, a), >(CAST(b), 5:BIGINT))])
-+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
++- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
@@ -329,7 +329,7 @@ LogicalProject(a=[$0], b=[$1], rk=[$2])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0], where=[=(w0$o0, b)])
-+- FlinkLogicalOverWindow(window#0=[window(partition {0} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
++- FlinkLogicalOverAggregate(window#0=[window(partition {0} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
@@ -371,7 +371,7 @@ LogicalProject(a=[$0], b=[$1], rk=[RANK() OVER (PARTITION BY $1 ORDER BY $0 NULL
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0 AS $2])
-+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
++- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRuleForRangeEndTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRuleForRangeEndTest.xml
index 70bf902..1963e1a 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRuleForRangeEndTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkLogicalRankRuleForRangeEndTest.xml
@@ -37,7 +37,7 @@ LogicalProject(a=[$0], b=[$1], rk=[$2], rn=[$3])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0, w1$o0], where=[<(w0$o0, 10)])
-+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])], window#1=[window(partition {1} order by [0 ASC-nulls-first] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])
++- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])], window#1=[window(partition {1} order by [0 ASC-nulls-first] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
@@ -63,7 +63,7 @@ LogicalProject(a=[$0], b=[$1], rk1=[$2], rk2=[$3])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0, w1$o0], where=[<(w0$o0, 10)])
-+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])], window#1=[window(partition {2} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
++- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])], window#1=[window(partition {2} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
@@ -235,7 +235,7 @@ LogicalProject(a=[$0], b=[$1], rk=[$2])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0], where=[>(w0$o0, a)])
-+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
++- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
@@ -325,7 +325,7 @@ LogicalProject(a=[$0], b=[$1], rk=[RANK() OVER (PARTITION BY $1 ORDER BY $0 NULL
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0 AS $2])
-+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
++- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/agg/OverWindowAggregateTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/agg/OverAggregateTest.xml
similarity index 100%
rename from flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/agg/OverWindowAggregateTest.xml
rename to flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/agg/OverAggregateTest.xml
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/agg/OverWindowAggregateTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/agg/OverAggregateTest.scala
similarity index 99%
rename from flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/agg/OverWindowAggregateTest.scala
rename to flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/agg/OverAggregateTest.scala
index d54a7a1..d7d4559 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/agg/OverWindowAggregateTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/agg/OverAggregateTest.scala
@@ -28,7 +28,7 @@ import org.junit.Test
import java.sql.Timestamp
-class OverWindowAggregateTest extends TableTestBase {
+class OverAggregateTest extends TableTestBase {
private val util = batchTestUtil()
util.addTableSource[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnIntervalTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnIntervalTest.scala
index 18dbbfe..164103a 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnIntervalTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnIntervalTest.scala
@@ -448,8 +448,38 @@ class FlinkRelMdColumnIntervalTest extends FlinkRelMdHandlerTestBase {
}
@Test
- def testGetColumnIntervalOnOverWindowAgg(): Unit = {
- Array(flinkLogicalOverWindow, batchOverWindowAgg).foreach {
+ def testGetColumnIntervalOnWindowAgg(): Unit = {
+ Array(logicalWindowAgg, flinkLogicalWindowAgg, batchGlobalWindowAggWithLocalAgg,
+ batchGlobalWindowAggWithoutLocalAgg, streamWindowAgg).foreach { agg =>
+ assertEquals(ValueInterval(5, 45), mq.getColumnInterval(agg, 0))
+ assertEquals(null, mq.getColumnInterval(agg, 1))
+ assertEquals(RightSemiInfiniteValueInterval(0), mq.getColumnInterval(agg, 2))
+ assertEquals(null, mq.getColumnInterval(agg, 3))
+ }
+ assertEquals(ValueInterval(5, 45), mq.getColumnInterval(batchLocalWindowAgg, 0))
+ assertEquals(null, mq.getColumnInterval(batchLocalWindowAgg, 1))
+ assertEquals(null, mq.getColumnInterval(batchLocalWindowAgg, 2))
+ assertEquals(RightSemiInfiniteValueInterval(0), mq.getColumnInterval(batchLocalWindowAgg, 3))
+ assertEquals(null, mq.getColumnInterval(batchLocalWindowAgg, 4))
+
+ Array(logicalWindowAggWithAuxGroup, flinkLogicalWindowAggWithAuxGroup,
+ batchGlobalWindowAggWithLocalAggWithAuxGroup,
+ batchGlobalWindowAggWithoutLocalAggWithAuxGroup).foreach { agg =>
+ assertEquals(ValueInterval(5, 55), mq.getColumnInterval(agg, 0))
+ assertEquals(ValueInterval(0, 50), mq.getColumnInterval(agg, 1))
+ assertEquals(ValueInterval(0, null), mq.getColumnInterval(agg, 2))
+ assertEquals(null, mq.getColumnInterval(agg, 3))
+ }
+ assertEquals(ValueInterval(5, 55), mq.getColumnInterval(batchLocalWindowAggWithAuxGroup, 0))
+ assertEquals(null, mq.getColumnInterval(batchLocalWindowAggWithAuxGroup, 1))
+ assertEquals(ValueInterval(0, 50), mq.getColumnInterval(batchLocalWindowAggWithAuxGroup, 2))
+ assertEquals(ValueInterval(0, null), mq.getColumnInterval(batchLocalWindowAggWithAuxGroup, 3))
+ assertEquals(null, mq.getColumnInterval(batchLocalWindowAggWithAuxGroup, 4))
+ }
+
+ @Test
+ def testGetColumnIntervalOnOverAgg(): Unit = {
+ Array(flinkLogicalOverAgg, batchOverAgg).foreach {
agg =>
assertEquals(ValueInterval(0, null), mq.getColumnInterval(agg, 0))
assertEquals(null, mq.getColumnInterval(agg, 1))
@@ -464,14 +494,14 @@ class FlinkRelMdColumnIntervalTest extends FlinkRelMdHandlerTestBase {
assertNull(mq.getColumnInterval(agg, 10))
}
- assertEquals(ValueInterval(0, null), mq.getColumnInterval(streamOverWindowAgg, 0))
- assertEquals(null, mq.getColumnInterval(streamOverWindowAgg, 1))
- assertEquals(ValueInterval(2.7, 4.8), mq.getColumnInterval(streamOverWindowAgg, 2))
- assertEquals(ValueInterval(12, 18), mq.getColumnInterval(streamOverWindowAgg, 3))
- assertNull(mq.getColumnInterval(streamOverWindowAgg, 4))
- assertNull(mq.getColumnInterval(streamOverWindowAgg, 5))
- assertNull(mq.getColumnInterval(streamOverWindowAgg, 6))
- assertNull(mq.getColumnInterval(streamOverWindowAgg, 7))
+ assertEquals(ValueInterval(0, null), mq.getColumnInterval(streamOverAgg, 0))
+ assertEquals(null, mq.getColumnInterval(streamOverAgg, 1))
+ assertEquals(ValueInterval(2.7, 4.8), mq.getColumnInterval(streamOverAgg, 2))
+ assertEquals(ValueInterval(12, 18), mq.getColumnInterval(streamOverAgg, 3))
+ assertNull(mq.getColumnInterval(streamOverAgg, 4))
+ assertNull(mq.getColumnInterval(streamOverAgg, 5))
+ assertNull(mq.getColumnInterval(streamOverAgg, 6))
+ assertNull(mq.getColumnInterval(streamOverAgg, 7))
}
@Test
@@ -490,7 +520,7 @@ class FlinkRelMdColumnIntervalTest extends FlinkRelMdHandlerTestBase {
assertEquals(ValueInterval(1L, 800000000L), mq.getColumnInterval(join, 1))
assertNull(mq.getColumnInterval(join, 2))
assertNull(mq.getColumnInterval(join, 3))
- assertEquals(ValueInterval(1L, 100L),mq.getColumnInterval(join, 4))
+ assertEquals(ValueInterval(1L, 100L), mq.getColumnInterval(join, 4))
assertNull(mq.getColumnInterval(join, 5))
assertEquals(ValueInterval(8L, 1000L), mq.getColumnInterval(join, 6))
assertNull(mq.getColumnInterval(join, 7))
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnUniquenessTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnUniquenessTest.scala
index 2814743..0bf706b 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnUniquenessTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnUniquenessTest.scala
@@ -355,8 +355,57 @@ class FlinkRelMdColumnUniquenessTest extends FlinkRelMdHandlerTestBase {
}
@Test
- def testAreColumnsUniqueOnOverWindow(): Unit = {
- Array(flinkLogicalOverWindow, batchOverWindowAgg).foreach { agg =>
+ def testAreColumnsUniqueOnWindowAgg(): Unit = {
+ Array(logicalWindowAgg, flinkLogicalWindowAgg, batchGlobalWindowAggWithLocalAgg,
+ batchGlobalWindowAggWithoutLocalAgg, streamWindowAgg).foreach { agg =>
+ assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 1)))
+ assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 2)))
+ assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 3)))
+ assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 1, 2)))
+ assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 1, 3)))
+ assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 1, 4)))
+ assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 1, 5)))
+ assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 1, 6)))
+ assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 1, 3, 4, 5, 6)))
+ assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 2, 3)))
+ }
+ assertNull(mq.areColumnsUnique(batchLocalWindowAgg, ImmutableBitSet.of(0, 1)))
+ assertNull(mq.areColumnsUnique(batchLocalWindowAgg, ImmutableBitSet.of(0, 1, 3)))
+
+ Array(logicalWindowAgg2, flinkLogicalWindowAgg2, batchGlobalWindowAggWithLocalAgg2,
+ batchGlobalWindowAggWithoutLocalAgg2, streamWindowAgg2).foreach { agg =>
+ assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 1)))
+ assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 2)))
+ assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 3)))
+ assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 4)))
+ assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 5)))
+ assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 2, 3, 4, 5)))
+ assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(1, 2)))
+ assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(1, 3)))
+ }
+ assertNull(mq.areColumnsUnique(batchLocalWindowAgg2, ImmutableBitSet.of(0, 1)))
+ assertNull(mq.areColumnsUnique(batchLocalWindowAgg2, ImmutableBitSet.of(0, 2)))
+
+ Array(logicalWindowAggWithAuxGroup, flinkLogicalWindowAggWithAuxGroup,
+ batchGlobalWindowAggWithLocalAggWithAuxGroup, batchGlobalWindowAggWithoutLocalAggWithAuxGroup
+ ).foreach { agg =>
+ assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 1)))
+ assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 2)))
+ assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 1, 2)))
+ assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 3)))
+ assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 4)))
+ assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 5)))
+ assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 6)))
+ assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 3, 4, 5, 6)))
+ assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(1, 3)))
+ }
+ assertNull(mq.areColumnsUnique(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0, 1)))
+ assertNull(mq.areColumnsUnique(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0, 3)))
+ }
+
+ @Test
+ def testAreColumnsUniqueOnOverAgg(): Unit = {
+ Array(flinkLogicalOverAgg, batchOverAgg).foreach { agg =>
assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0)))
assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(1)))
assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(2)))
@@ -375,20 +424,20 @@ class FlinkRelMdColumnUniquenessTest extends FlinkRelMdHandlerTestBase {
assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 10)))
assertNull(mq.areColumnsUnique(agg, ImmutableBitSet.of(5, 10)))
}
- assertTrue(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(0)))
- assertFalse(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(1)))
- assertFalse(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(2)))
- assertFalse(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(3)))
- assertFalse(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(4)))
- assertNull(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(5)))
- assertNull(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(6)))
- assertNull(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(7)))
- assertTrue(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(0, 1)))
- assertTrue(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(0, 2)))
- assertFalse(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(1, 2)))
- assertTrue(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(0, 5)))
- assertTrue(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(0, 7)))
- assertNull(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(5, 7)))
+ assertTrue(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(0)))
+ assertFalse(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(1)))
+ assertFalse(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(2)))
+ assertFalse(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(3)))
+ assertFalse(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(4)))
+ assertNull(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(5)))
+ assertNull(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(6)))
+ assertNull(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(7)))
+ assertTrue(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(0, 1)))
+ assertTrue(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(0, 2)))
+ assertFalse(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(1, 2)))
+ assertTrue(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(0, 5)))
+ assertTrue(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(0, 7)))
+ assertNull(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(5, 7)))
}
@Test
@@ -465,38 +514,38 @@ class FlinkRelMdColumnUniquenessTest extends FlinkRelMdHandlerTestBase {
@Test
def testAreColumnsUniqueOnIntersect(): Unit = {
- assertFalse(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of (0)))
- assertTrue(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of (1)))
- assertFalse(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of (2)))
- assertTrue(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of (1, 2)))
- assertFalse(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of (0, 2)))
- assertTrue(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of (1, 2)))
-
- assertFalse(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of (0)))
- assertTrue(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of (1)))
- assertFalse(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of (2)))
- assertTrue(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of (1, 2)))
- assertFalse(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of (0, 2)))
- assertTrue(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of (1, 2)))
+ assertFalse(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of(0)))
+ assertTrue(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of(1)))
+ assertFalse(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of(2)))
+ assertTrue(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of(1, 2)))
+ assertFalse(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of(0, 2)))
+ assertTrue(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of(1, 2)))
+
+ assertFalse(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of(0)))
+ assertTrue(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of(1)))
+ assertFalse(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of(2)))
+ assertTrue(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of(1, 2)))
+ assertFalse(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of(0, 2)))
+ assertTrue(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of(1, 2)))
assertTrue(mq.areColumnsUnique(logicalIntersect,
ImmutableBitSet.range(logicalIntersect.getRowType.getFieldCount)))
}
@Test
def testAreColumnsUniqueOnMinus(): Unit = {
- assertFalse(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of (0)))
- assertTrue(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of (1)))
- assertFalse(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of (2)))
- assertTrue(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of (1, 2)))
- assertFalse(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of (0, 2)))
- assertTrue(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of (1, 2)))
-
- assertFalse(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of (0)))
- assertTrue(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of (1)))
- assertFalse(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of (2)))
- assertTrue(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of (1, 2)))
- assertFalse(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of (0, 2)))
- assertTrue(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of (1, 2)))
+ assertFalse(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of(0)))
+ assertTrue(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of(1)))
+ assertFalse(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of(2)))
+ assertTrue(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of(1, 2)))
+ assertFalse(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of(0, 2)))
+ assertTrue(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of(1, 2)))
+
+ assertFalse(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of(0)))
+ assertTrue(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of(1)))
+ assertFalse(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of(2)))
+ assertTrue(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of(1, 2)))
+ assertFalse(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of(0, 2)))
+ assertTrue(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of(1, 2)))
assertTrue(mq.areColumnsUnique(logicalMinus,
ImmutableBitSet.range(logicalMinus.getRowType.getFieldCount)))
@@ -505,12 +554,12 @@ class FlinkRelMdColumnUniquenessTest extends FlinkRelMdHandlerTestBase {
.scan("MyTable2")
.scan("MyTable1")
.minus(false).build()
- assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of (0)))
- assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of (1)))
- assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of (2)))
- assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of (1, 2)))
- assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of (0, 2)))
- assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of (1, 2)))
+ assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of(0)))
+ assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of(1)))
+ assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of(2)))
+ assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of(1, 2)))
+ assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of(0, 2)))
+ assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of(1, 2)))
assertTrue(mq.areColumnsUnique(logicalMinus2,
ImmutableBitSet.range(logicalMinus2.getRowType.getFieldCount)))
}
@@ -518,7 +567,7 @@ class FlinkRelMdColumnUniquenessTest extends FlinkRelMdHandlerTestBase {
@Test
def testGetColumnNullCountOnDefault(): Unit = {
(0 until testRel.getRowType.getFieldCount).foreach { idx =>
- assertNull(mq.areColumnsUnique(testRel, ImmutableBitSet.of (idx)))
+ assertNull(mq.areColumnsUnique(testRel, ImmutableBitSet.of(idx)))
}
}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdDistinctRowCountTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdDistinctRowCountTest.scala
index f9882d8..ec7dda4 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdDistinctRowCountTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdDistinctRowCountTest.scala
@@ -19,6 +19,7 @@
package org.apache.flink.table.plan.metadata
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecRank
+import org.apache.flink.table.plan.util.FlinkRelMdUtil
import org.apache.calcite.rel.metadata.RelMdUtil
import org.apache.calcite.sql.fun.SqlStdOperatorTable._
@@ -428,8 +429,100 @@ class FlinkRelMdDistinctRowCountTest extends FlinkRelMdHandlerTestBase {
}
@Test
- def testGetDistinctRowCountOnOverWindow(): Unit = {
- Array(flinkLogicalOverWindow, batchOverWindowAgg).foreach { agg =>
+ def testGetDistinctRowCountOnWindowAgg(): Unit = {
+ Array(logicalWindowAgg, flinkLogicalWindowAgg, batchGlobalWindowAggWithoutLocalAgg,
+ batchGlobalWindowAggWithLocalAgg).foreach { agg =>
+ assertEquals(30D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0), null))
+ assertEquals(5D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(1), null))
+ assertEquals(50D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0, 1), null))
+ assertEquals(50D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0, 2), null))
+ assertEquals(null, mq.getDistinctRowCount(agg, ImmutableBitSet.of(3), null))
+ assertEquals(null, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0, 3), null))
+ assertEquals(null, mq.getDistinctRowCount(agg, ImmutableBitSet.of(1, 3), null))
+ assertEquals(null, mq.getDistinctRowCount(agg, ImmutableBitSet.of(2, 3), null))
+
+ relBuilder.clear()
+ // $1 > 10
+ val pred = relBuilder
+ .push(agg)
+ .call(GREATER_THAN, relBuilder.field(1), relBuilder.literal(10))
+ assertEquals(
+ FlinkRelMdUtil.adaptNdvBasedOnSelectivity(50.0D, 5.0D, 0.5D),
+ mq.getDistinctRowCount(agg, ImmutableBitSet.of(1), pred), 1e-6)
+ assertEquals(25D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0, 1), pred))
+
+ // b > 10 and count(c) > 1 and w$end = 100000
+ val pred1 = relBuilder
+ .push(agg)
+ .and(
+ relBuilder.call(GREATER_THAN, relBuilder.field(1), relBuilder.literal(10)),
+ relBuilder.call(GREATER_THAN, relBuilder.field(2), relBuilder.literal(1)),
+ relBuilder.call(EQUALS, relBuilder.field(4), relBuilder.literal(100000))
+ )
+ assertEquals(
+ FlinkRelMdUtil.adaptNdvBasedOnSelectivity(50.0D, 5.0D, 0.075D),
+ mq.getDistinctRowCount(agg, ImmutableBitSet.of(1), pred1), 1e-6)
+ assertEquals(25D * 0.15D * 1.0D,
+ mq.getDistinctRowCount(agg, ImmutableBitSet.of(0, 1), pred1), 1e-2)
+ }
+ assertEquals(30D, mq.getDistinctRowCount(batchLocalWindowAgg, ImmutableBitSet.of(0), null))
+ assertEquals(5D, mq.getDistinctRowCount(batchLocalWindowAgg, ImmutableBitSet.of(1), null))
+ assertEquals(50D, mq.getDistinctRowCount(batchLocalWindowAgg, ImmutableBitSet.of(0, 1), null))
+ assertEquals(null, mq.getDistinctRowCount(batchLocalWindowAgg, ImmutableBitSet.of(0, 2), null))
+ assertEquals(10D, mq.getDistinctRowCount(batchLocalWindowAgg, ImmutableBitSet.of(3), null))
+ assertEquals(50D, mq.getDistinctRowCount(batchLocalWindowAgg, ImmutableBitSet.of(0, 3), null))
+ assertEquals(50.0, mq.getDistinctRowCount(batchLocalWindowAgg, ImmutableBitSet.of(1, 3), null))
+ assertEquals(null, mq.getDistinctRowCount(batchLocalWindowAgg, ImmutableBitSet.of(2, 3), null))
+
+ Array(logicalWindowAggWithAuxGroup, flinkLogicalWindowAggWithAuxGroup,
+ batchGlobalWindowAggWithoutLocalAggWithAuxGroup,
+ batchGlobalWindowAggWithLocalAggWithAuxGroup).foreach { agg =>
+ assertEquals(50D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0), null))
+ assertEquals(48D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(1), null))
+ assertEquals(50D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0, 1), null))
+ assertEquals(50D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0, 2), null))
+ assertEquals(50D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(1, 2), null))
+ assertEquals(null, mq.getDistinctRowCount(agg, ImmutableBitSet.of(3), null))
+
+ relBuilder.clear()
+ // $1 > 10
+ val pred = relBuilder
+ .push(agg)
+ .call(GREATER_THAN, relBuilder.field(1), relBuilder.literal(10))
+ assertEquals(
+ FlinkRelMdUtil.adaptNdvBasedOnSelectivity(50.0D, 48.0D, 0.8D),
+ mq.getDistinctRowCount(agg, ImmutableBitSet.of(1), pred), 1e-6)
+ assertEquals(40D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0, 1), pred))
+
+ // b > 10 and count(c) > 1 and w$end = 100000
+ val pred1 = relBuilder
+ .push(agg)
+ .and(
+ relBuilder.call(GREATER_THAN, relBuilder.field(1), relBuilder.literal(10)),
+ relBuilder.call(GREATER_THAN, relBuilder.field(2), relBuilder.literal(1)),
+ relBuilder.call(EQUALS, relBuilder.field(4), relBuilder.literal(100000))
+ )
+ assertEquals(
+ FlinkRelMdUtil.adaptNdvBasedOnSelectivity(50.0D, 48.0D, 0.12D),
+ mq.getDistinctRowCount(agg, ImmutableBitSet.of(1), pred1), 1e-6)
+ assertEquals(40D * 0.15D * 1.0D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0, 1), pred1))
+ }
+ assertEquals(50D,
+ mq.getDistinctRowCount(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0), null))
+ assertNull(mq.getDistinctRowCount(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(1), null))
+ assertNull(
+ mq.getDistinctRowCount(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0, 1), null))
+ assertEquals(50D,
+ mq.getDistinctRowCount(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0, 2), null))
+ assertNull(
+ mq.getDistinctRowCount(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(1, 2), null))
+ assertEquals(10D,
+ mq.getDistinctRowCount(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(3), null))
+ }
+
+ @Test
+ def testGetDistinctRowCountOnOverAgg(): Unit = {
+ Array(flinkLogicalOverAgg, batchOverAgg).foreach { agg =>
assertEquals(1.0, mq.getDistinctRowCount(agg, ImmutableBitSet.of(), null))
assertEquals(50.0, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0), null))
assertEquals(48.0, mq.getDistinctRowCount(agg, ImmutableBitSet.of(1), null))
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdHandlerTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdHandlerTestBase.scala
index ad312d1..a75a90d 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdHandlerTestBase.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdHandlerTestBase.scala
@@ -18,29 +18,33 @@
package org.apache.flink.table.plan.metadata
-import org.apache.flink.table.`type`.{InternalType, InternalTypes}
+import org.apache.flink.table.`type`.{InternalType, InternalTypes, TypeConverters}
import org.apache.flink.table.api.{TableConfig, TableException}
+import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.calcite.{FlinkCalciteCatalogReader, FlinkRelBuilder, FlinkTypeFactory}
+import org.apache.flink.table.expressions.{FieldReferenceExpression, ProctimeAttribute, RowtimeAttribute, ValueLiteralExpression, WindowReference, WindowStart}
import org.apache.flink.table.functions.aggfunctions.SumAggFunction.DoubleSumAggFunction
import org.apache.flink.table.functions.aggfunctions.{DenseRankAggFunction, RankAggFunction, RowNumberAggFunction}
import org.apache.flink.table.functions.sql.FlinkSqlOperatorTable
import org.apache.flink.table.plan.PartialFinalType
import org.apache.flink.table.plan.`trait`.FlinkRelDistribution
+import org.apache.flink.table.plan.logical.{LogicalWindow, TumblingGroupWindow}
import org.apache.flink.table.plan.nodes.FlinkConventions
-import org.apache.flink.table.plan.nodes.calcite.{LogicalExpand, LogicalRank}
+import org.apache.flink.table.plan.nodes.calcite.{LogicalExpand, LogicalRank, LogicalWindowAggregate}
import org.apache.flink.table.plan.nodes.logical._
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.nodes.physical.stream._
import org.apache.flink.table.plan.schema.FlinkRelOptTable
import org.apache.flink.table.plan.util.AggregateUtil.transformToStreamAggregateInfoList
-import org.apache.flink.table.plan.util.{AggFunctionFactory, AggregateUtil, ExpandUtil, FlinkRelOptUtil, SortUtil}
+import org.apache.flink.table.plan.util.{AggFunctionFactory, AggregateUtil, ExpandUtil, FlinkRelOptUtil, SortUtil, WindowEmitStrategy}
import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankType, VariableRankRange}
+import org.apache.flink.table.typeutils.TimeIntervalTypeInfo.INTERVAL_MILLIS
import org.apache.flink.table.util.CountAggFunction
-import com.google.common.collect.ImmutableList
+import com.google.common.collect.{ImmutableList, Lists}
import org.apache.calcite.plan.{Convention, ConventionTraitDef, RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet}
import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFieldImpl}
-import org.apache.calcite.rel.core.{AggregateCall, Calc, JoinRelType, Window}
+import org.apache.calcite.rel.core.{AggregateCall, Calc, JoinRelType, Project, Window}
import org.apache.calcite.rel.logical.{LogicalAggregate, LogicalProject, LogicalSort, LogicalTableScan, LogicalValues}
import org.apache.calcite.rel.metadata.{JaninoRelMetadataProvider, RelMetadataQuery}
import org.apache.calcite.rel.{RelCollationImpl, RelCollationTraitDef, RelCollations, RelFieldCollation, RelNode, SingleRel}
@@ -49,8 +53,8 @@ import org.apache.calcite.schema.SchemaPlus
import org.apache.calcite.sql.SqlWindow
import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, BOOLEAN, DATE, DOUBLE, FLOAT, TIME, TIMESTAMP, VARCHAR}
-import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.calcite.sql.fun.SqlStdOperatorTable.{AND, CASE, DIVIDE, EQUALS, GREATER_THAN, LESS_THAN, MINUS, MULTIPLY, PLUS}
+import org.apache.calcite.sql.fun.{SqlCountAggFunction, SqlStdOperatorTable}
import org.apache.calcite.sql.parser.SqlParserPos
import org.apache.calcite.tools.FrameworkConfig
import org.apache.calcite.util.{DateString, ImmutableBitSet, TimeString, TimestampString}
@@ -90,17 +94,11 @@ class FlinkRelMdHandlerTestBase {
logicalTraits = cluster.traitSetOf(Convention.NONE)
- flinkLogicalTraits = cluster
- .traitSetOf(Convention.NONE)
- .replace(FlinkConventions.LOGICAL)
+ flinkLogicalTraits = cluster.traitSetOf(FlinkConventions.LOGICAL)
- batchPhysicalTraits = cluster
- .traitSetOf(Convention.NONE)
- .replace(FlinkConventions.BATCH_PHYSICAL)
+ batchPhysicalTraits = cluster.traitSetOf(FlinkConventions.BATCH_PHYSICAL)
- streamPhysicalTraits = cluster
- .traitSetOf(Convention.NONE)
- .replace(FlinkConventions.STREAM_PHYSICAL)
+ streamPhysicalTraits = cluster.traitSetOf(FlinkConventions.STREAM_PHYSICAL)
}
protected val intType: RelDataType = typeFactory.createTypeFromInternalType(
@@ -579,16 +577,16 @@ class FlinkRelMdHandlerTestBase {
// equivalent SQL is
// select a, b, c from (
// select a, b, c, proctime
- // ROW_NUMBER() over (partition by b order by proctime) rn from TemporalTable
+ // ROW_NUMBER() over (partition by b order by proctime) rn from TemporalTable3
// ) t where rn <= 1
//
// select a, b, c from (
// select a, b, c, proctime
- // ROW_NUMBER() over (partition by b, c order by proctime desc) rn from TemporalTable
+ // ROW_NUMBER() over (partition by b, c order by proctime desc) rn from TemporalTable3
// ) t where rn <= 1
protected lazy val (streamDeduplicateFirstRow, streamDeduplicateLastRow) = {
val scan: StreamExecDataStreamScan =
- createDataStreamScan(ImmutableList.of("TemporalTable"), streamPhysicalTraits)
+ createDataStreamScan(ImmutableList.of("TemporalTable3"), streamPhysicalTraits)
val hash1 = FlinkRelDistribution.hash(Array(1), requireStrict = true)
val streamExchange1 = new StreamExecExchange(
cluster, scan.getTraitSet.replace(hash1), scan, hash1)
@@ -943,6 +941,449 @@ class FlinkRelMdHandlerTestBase {
batchLocalAggWithAuxGroup, batchGlobalAggWithAuxGroup, batchGlobalAggWithoutLocalWithAuxGroup)
}
+ // For window start/end/proc_time the windowAttribute inferred type is a hard code val,
+ // only for row_time we distinguish by batch row time, for what we hard code DataTypes.TIMESTAMP,
+ // which is ok here for testing.
+ private lazy val windowRef: WindowReference =
+ WindowReference.apply("w$", Some(InternalTypes.TIMESTAMP))
+
+ protected lazy val tumblingGroupWindow: LogicalWindow =
+ TumblingGroupWindow(
+ windowRef,
+ new FieldReferenceExpression(
+ "rowtime",
+ TypeConverters.createExternalTypeInfoFromInternalType(InternalTypes.ROWTIME_INDICATOR),
+ 0,
+ 4),
+ new ValueLiteralExpression(900000, INTERVAL_MILLIS)
+ )
+
+ protected lazy val namedPropertiesOfWindowAgg: Seq[NamedWindowProperty] =
+ Seq(NamedWindowProperty("w$start", WindowStart(windowRef)),
+ NamedWindowProperty("w$end", WindowStart(windowRef)),
+ NamedWindowProperty("w$rowtime", RowtimeAttribute(windowRef)),
+ NamedWindowProperty("w$proctime", ProctimeAttribute(windowRef)))
+
+ // equivalent SQL is
+ // select a, b, count(c) as s,
+ // TUMBLE_START(rowtime, INTERVAL '15' MINUTE) as w$start,
+ // TUMBLE_END(rowtime, INTERVAL '15' MINUTE) as w$end,
+ // TUMBLE_ROWTIME(rowtime, INTERVAL '15' MINUTE) as w$rowtime,
+ // TUMBLE_PROCTIME(rowtime, INTERVAL '15' MINUTE) as w$proctime
+ // from TemporalTable1 group by a, b, TUMBLE(rowtime, INTERVAL '15' MINUTE)
+ protected lazy val (
+ logicalWindowAgg,
+ flinkLogicalWindowAgg,
+ batchLocalWindowAgg,
+ batchGlobalWindowAggWithLocalAgg,
+ batchGlobalWindowAggWithoutLocalAgg,
+ streamWindowAgg) = {
+ relBuilder.scan("TemporalTable1")
+ val ts = relBuilder.peek()
+ val project = relBuilder.project(relBuilder.fields(Seq[Integer](0, 1, 4, 2).toList))
+ .build().asInstanceOf[Project]
+ val program = RexProgram.create(
+ ts.getRowType, project.getProjects, null, project.getRowType, rexBuilder)
+ val aggCallOfWindowAgg = Lists.newArrayList(AggregateCall.create(
+ new SqlCountAggFunction("COUNT"), false, false, List[Integer](3), -1, 2, project, null, "s"))
+ // TUMBLE(rowtime, INTERVAL '15' MINUTE))
+ val logicalWindowAgg = new LogicalWindowAggregate(
+ ts.getCluster,
+ ts.getTraitSet,
+ project,
+ ImmutableBitSet.of(0, 1),
+ aggCallOfWindowAgg,
+ tumblingGroupWindow,
+ namedPropertiesOfWindowAgg)
+
+ val flinkLogicalTs: FlinkLogicalDataStreamTableScan =
+ createDataStreamScan(ImmutableList.of("TemporalTable1"), flinkLogicalTraits)
+ val flinkLogicalWindowAgg = new FlinkLogicalWindowAggregate(
+ ts.getCluster,
+ logicalTraits,
+ new FlinkLogicalCalc(ts.getCluster, flinkLogicalTraits, flinkLogicalTs, program),
+ ImmutableBitSet.of(0, 1),
+ aggCallOfWindowAgg,
+ tumblingGroupWindow,
+ namedPropertiesOfWindowAgg)
+
+ val batchTs: BatchExecBoundedStreamScan =
+ createDataStreamScan(ImmutableList.of("TemporalTable1"), batchPhysicalTraits)
+ val batchCalc = new BatchExecCalc(
+ cluster, batchPhysicalTraits, batchTs, program, program.getOutputRowType)
+ val hash01 = FlinkRelDistribution.hash(Array(0, 1), requireStrict = true)
+ val batchExchange1 = new BatchExecExchange(
+ cluster, batchPhysicalTraits.replace(hash01), batchCalc, hash01)
+ val (_, _, aggregates) =
+ AggregateUtil.transformToBatchAggregateFunctions(
+ flinkLogicalWindowAgg.getAggCallList, batchExchange1.getRowType)
+ val aggCallToAggFunction = flinkLogicalWindowAgg.getAggCallList.zip(aggregates)
+
+ val localWindowAggTypes =
+ (Array(0, 1).map(batchCalc.getRowType.getFieldList.get(_).getType) ++ // grouping
+ Array(longType) ++ // assignTs
+ aggCallOfWindowAgg.map(_.getType)).toList // agg calls
+ val localWindowAggNames =
+ (Array(0, 1).map(batchCalc.getRowType.getFieldNames.get(_)) ++ // grouping
+ Array("assignedWindow$") ++ // assignTs
+ Array("count$0")).toList // agg calls
+ val localWindowAggRowType = typeFactory.createStructType(
+ localWindowAggTypes, localWindowAggNames)
+ val batchLocalWindowAgg = new BatchExecLocalHashWindowAggregate(
+ batchCalc.getCluster,
+ relBuilder,
+ batchPhysicalTraits,
+ batchCalc,
+ localWindowAggRowType,
+ batchCalc.getRowType,
+ Array(0, 1),
+ Array.empty,
+ aggCallToAggFunction,
+ tumblingGroupWindow,
+ inputTimeFieldIndex = 2,
+ inputTimeIsDate = false,
+ namedPropertiesOfWindowAgg,
+ enableAssignPane = false)
+ val batchExchange2 = new BatchExecExchange(
+ cluster, batchPhysicalTraits.replace(hash01), batchLocalWindowAgg, hash01)
+ val batchWindowAggWithLocal = new BatchExecHashWindowAggregate(
+ cluster,
+ relBuilder,
+ batchPhysicalTraits,
+ batchExchange2,
+ flinkLogicalWindowAgg.getRowType,
+ batchExchange2.getRowType,
+ batchCalc.getRowType,
+ Array(0, 1),
+ Array.empty,
+ aggCallToAggFunction,
+ tumblingGroupWindow,
+ inputTimeFieldIndex = 2,
+ inputTimeIsDate = false,
+ namedPropertiesOfWindowAgg,
+ enableAssignPane = false,
+ isMerge = true
+ )
+
+ val batchWindowAggWithoutLocal = new BatchExecHashWindowAggregate(
+ batchExchange1.getCluster,
+ relBuilder,
+ batchPhysicalTraits,
+ batchExchange1,
+ flinkLogicalWindowAgg.getRowType,
+ batchExchange1.getRowType,
+ batchExchange1.getRowType,
+ Array(0, 1),
+ Array.empty,
+ aggCallToAggFunction,
+ tumblingGroupWindow,
+ inputTimeFieldIndex = 2,
+ inputTimeIsDate = false,
+ namedPropertiesOfWindowAgg,
+ enableAssignPane = false,
+ isMerge = false
+ )
+
+ val streamTs: StreamExecDataStreamScan =
+ createDataStreamScan(ImmutableList.of("TemporalTable1"), streamPhysicalTraits)
+ val streamCalc = new BatchExecCalc(
+ cluster, streamPhysicalTraits, streamTs, program, program.getOutputRowType)
+ val streamExchange = new StreamExecExchange(
+ cluster, streamPhysicalTraits.replace(hash01), streamCalc, hash01)
+ val emitStrategy = WindowEmitStrategy(tableConfig, tumblingGroupWindow)
+ val streamWindowAgg = new StreamExecGroupWindowAggregate(
+ cluster,
+ streamPhysicalTraits,
+ streamExchange,
+ flinkLogicalWindowAgg.getRowType,
+ streamExchange.getRowType,
+ Array(0, 1),
+ flinkLogicalWindowAgg.getAggCallList,
+ tumblingGroupWindow,
+ namedPropertiesOfWindowAgg,
+ inputTimeFieldIndex = 2,
+ emitStrategy
+ )
+
+ (logicalWindowAgg, flinkLogicalWindowAgg, batchLocalWindowAgg, batchWindowAggWithLocal,
+ batchWindowAggWithoutLocal, streamWindowAgg)
+ }
+
+ // equivalent SQL is
+ // select b, count(a) as s,
+ // TUMBLE_START(rowtime, INTERVAL '15' MINUTE) as w$start,
+ // TUMBLE_END(rowtime, INTERVAL '15' MINUTE) as w$end,
+ // TUMBLE_ROWTIME(rowtime, INTERVAL '15' MINUTE) as w$rowtime,
+ // TUMBLE_PROCTIME(rowtime, INTERVAL '15' MINUTE) as w$proctime
+ // from TemporalTable1 group by b, TUMBLE(rowtime, INTERVAL '15' MINUTE)
+ protected lazy val (
+ logicalWindowAgg2,
+ flinkLogicalWindowAgg2,
+ batchLocalWindowAgg2,
+ batchGlobalWindowAggWithLocalAgg2,
+ batchGlobalWindowAggWithoutLocalAgg2,
+ streamWindowAgg2) = {
+ relBuilder.scan("TemporalTable1")
+ val ts = relBuilder.peek()
+ val project = relBuilder.project(relBuilder.fields(Seq[Integer](0, 1, 4).toList))
+ .build().asInstanceOf[Project]
+ val program = RexProgram.create(
+ ts.getRowType, project.getProjects, null, project.getRowType, rexBuilder)
+ val aggCallOfWindowAgg = Lists.newArrayList(AggregateCall.create(
+ new SqlCountAggFunction("COUNT"), false, false, List[Integer](0), -1, 1, project, null, "s"))
+ // TUMBLE(rowtime, INTERVAL '15' MINUTE))
+ val logicalWindowAgg = new LogicalWindowAggregate(
+ ts.getCluster,
+ ts.getTraitSet,
+ project,
+ ImmutableBitSet.of(1),
+ aggCallOfWindowAgg,
+ tumblingGroupWindow,
+ namedPropertiesOfWindowAgg)
+
+ val flinkLogicalTs: FlinkLogicalDataStreamTableScan =
+ createDataStreamScan(ImmutableList.of("TemporalTable1"), flinkLogicalTraits)
+ val flinkLogicalWindowAgg = new FlinkLogicalWindowAggregate(
+ ts.getCluster,
+ logicalTraits,
+ new FlinkLogicalCalc(ts.getCluster, flinkLogicalTraits, flinkLogicalTs, program),
+ ImmutableBitSet.of(1),
+ aggCallOfWindowAgg,
+ tumblingGroupWindow,
+ namedPropertiesOfWindowAgg)
+
+ val batchTs: BatchExecBoundedStreamScan =
+ createDataStreamScan(ImmutableList.of("TemporalTable1"), batchPhysicalTraits)
+ val batchCalc = new BatchExecCalc(
+ cluster, batchPhysicalTraits, batchTs, program, program.getOutputRowType)
+ val hash1 = FlinkRelDistribution.hash(Array(1), requireStrict = true)
+ val batchExchange1 = new BatchExecExchange(
+ cluster, batchPhysicalTraits.replace(hash1), batchCalc, hash1)
+ val (_, _, aggregates) =
+ AggregateUtil.transformToBatchAggregateFunctions(
+ flinkLogicalWindowAgg.getAggCallList, batchExchange1.getRowType)
+ val aggCallToAggFunction = flinkLogicalWindowAgg.getAggCallList.zip(aggregates)
+
+ val localWindowAggTypes =
+ (Array(batchCalc.getRowType.getFieldList.get(1).getType) ++ // grouping
+ Array(longType) ++ // assignTs
+ aggCallOfWindowAgg.map(_.getType)).toList // agg calls
+ val localWindowAggNames =
+ (Array(batchCalc.getRowType.getFieldNames.get(1)) ++ // grouping
+ Array("assignedWindow$") ++ // assignTs
+ Array("count$0")).toList // agg calls
+ val localWindowAggRowType = typeFactory.createStructType(
+ localWindowAggTypes, localWindowAggNames)
+ val batchLocalWindowAgg = new BatchExecLocalHashWindowAggregate(
+ batchCalc.getCluster,
+ relBuilder,
+ batchPhysicalTraits,
+ batchCalc,
+ localWindowAggRowType,
+ batchCalc.getRowType,
+ Array(1),
+ Array.empty,
+ aggCallToAggFunction,
+ tumblingGroupWindow,
+ inputTimeFieldIndex = 2,
+ inputTimeIsDate = false,
+ namedPropertiesOfWindowAgg,
+ enableAssignPane = false)
+ val batchExchange2 = new BatchExecExchange(
+ cluster, batchPhysicalTraits.replace(hash1), batchLocalWindowAgg, hash1)
+ val batchWindowAggWithLocal = new BatchExecHashWindowAggregate(
+ cluster,
+ relBuilder,
+ batchPhysicalTraits,
+ batchExchange2,
+ flinkLogicalWindowAgg.getRowType,
+ batchExchange2.getRowType,
+ batchCalc.getRowType,
+ Array(0),
+ Array.empty,
+ aggCallToAggFunction,
+ tumblingGroupWindow,
+ inputTimeFieldIndex = 2,
+ inputTimeIsDate = false,
+ namedPropertiesOfWindowAgg,
+ enableAssignPane = false,
+ isMerge = true
+ )
+
+ val batchWindowAggWithoutLocal = new BatchExecHashWindowAggregate(
+ batchExchange1.getCluster,
+ relBuilder,
+ batchPhysicalTraits,
+ batchExchange1,
+ flinkLogicalWindowAgg.getRowType,
+ batchExchange1.getRowType,
+ batchExchange1.getRowType,
+ Array(1),
+ Array.empty,
+ aggCallToAggFunction,
+ tumblingGroupWindow,
+ inputTimeFieldIndex = 2,
+ inputTimeIsDate = false,
+ namedPropertiesOfWindowAgg,
+ enableAssignPane = false,
+ isMerge = false
+ )
+
+ val streamTs: StreamExecDataStreamScan =
+ createDataStreamScan(ImmutableList.of("TemporalTable1"), streamPhysicalTraits)
+ val streamCalc = new BatchExecCalc(
+ cluster, streamPhysicalTraits, streamTs, program, program.getOutputRowType)
+ val streamExchange = new StreamExecExchange(
+ cluster, streamPhysicalTraits.replace(hash1), streamCalc, hash1)
+ val emitStrategy = WindowEmitStrategy(tableConfig, tumblingGroupWindow)
+ val streamWindowAgg = new StreamExecGroupWindowAggregate(
+ cluster,
+ streamPhysicalTraits,
+ streamExchange,
+ flinkLogicalWindowAgg.getRowType,
+ streamExchange.getRowType,
+ Array(1),
+ flinkLogicalWindowAgg.getAggCallList,
+ tumblingGroupWindow,
+ namedPropertiesOfWindowAgg,
+ inputTimeFieldIndex = 2,
+ emitStrategy
+ )
+
+ (logicalWindowAgg, flinkLogicalWindowAgg, batchLocalWindowAgg, batchWindowAggWithLocal,
+ batchWindowAggWithoutLocal, streamWindowAgg)
+ }
+
+ // equivalent SQL is
+ // select a, c, count(b) as s,
+ // TUMBLE_START(rowtime, INTERVAL '15' MINUTE) as w$start,
+ // TUMBLE_END(rowtime, INTERVAL '15' MINUTE) as w$end,
+ // TUMBLE_ROWTIME(rowtime, INTERVAL '15' MINUTE) as w$rowtime,
+ // TUMBLE_PROCTIME(rowtime, INTERVAL '15' MINUTE) as w$proctime
+ // from TemporalTable2 group by a, c, TUMBLE(rowtime, INTERVAL '15' MINUTE)
+ protected lazy val (
+ logicalWindowAggWithAuxGroup,
+ flinkLogicalWindowAggWithAuxGroup,
+ batchLocalWindowAggWithAuxGroup,
+ batchGlobalWindowAggWithLocalAggWithAuxGroup,
+ batchGlobalWindowAggWithoutLocalAggWithAuxGroup) = {
+ relBuilder.scan("TemporalTable2")
+ val ts = relBuilder.peek()
+ val project = relBuilder.project(relBuilder.fields(Seq[Integer](0, 2, 4, 1).toList))
+ .build().asInstanceOf[Project]
+ val program = RexProgram.create(
+ ts.getRowType, project.getProjects, null, project.getRowType, rexBuilder)
+ val aggCallOfWindowAgg = Lists.newArrayList(
+ AggregateCall.create(FlinkSqlOperatorTable.AUXILIARY_GROUP, false, false,
+ List[Integer](1), -1, 1, project, null, "c"),
+ AggregateCall.create(new SqlCountAggFunction("COUNT"), false, false,
+ List[Integer](3), -1, 2, project, null, "s"))
+ // TUMBLE(rowtime, INTERVAL '15' MINUTE))
+ val logicalWindowAggWithAuxGroup = new LogicalWindowAggregate(
+ ts.getCluster,
+ ts.getTraitSet,
+ project,
+ ImmutableBitSet.of(0),
+ aggCallOfWindowAgg,
+ tumblingGroupWindow,
+ namedPropertiesOfWindowAgg)
+
+ val flinkLogicalTs: FlinkLogicalDataStreamTableScan =
+ createDataStreamScan(ImmutableList.of("TemporalTable2"), flinkLogicalTraits)
+ val flinkLogicalWindowAggWithAuxGroup = new FlinkLogicalWindowAggregate(
+ ts.getCluster,
+ logicalTraits,
+ new FlinkLogicalCalc(ts.getCluster, flinkLogicalTraits, flinkLogicalTs, program),
+ ImmutableBitSet.of(0),
+ aggCallOfWindowAgg,
+ tumblingGroupWindow,
+ namedPropertiesOfWindowAgg)
+
+ val batchTs: BatchExecBoundedStreamScan =
+ createDataStreamScan(ImmutableList.of("TemporalTable2"), batchPhysicalTraits)
+ val batchCalc = new BatchExecCalc(
+ cluster, batchPhysicalTraits, batchTs, program, program.getOutputRowType)
+ val hash0 = FlinkRelDistribution.hash(Array(0), requireStrict = true)
+ val batchExchange1 = new BatchExecExchange(
+ cluster, batchPhysicalTraits.replace(hash0), batchCalc, hash0)
+ val aggCallsWithoutAuxGroup = flinkLogicalWindowAggWithAuxGroup.getAggCallList.drop(1)
+ val (_, _, aggregates) =
+ AggregateUtil.transformToBatchAggregateFunctions(
+ aggCallsWithoutAuxGroup, batchExchange1.getRowType)
+ val aggCallToAggFunction = aggCallsWithoutAuxGroup.zip(aggregates)
+
+ val localWindowAggTypes =
+ (Array(batchCalc.getRowType.getFieldList.get(0).getType) ++ // grouping
+ Array(longType) ++ // assignTs
+ Array(batchCalc.getRowType.getFieldList.get(1).getType) ++ // auxGrouping
+ aggCallsWithoutAuxGroup.map(_.getType)).toList // agg calls
+ val localWindowAggNames =
+ (Array(batchCalc.getRowType.getFieldNames.get(0)) ++ // grouping
+ Array("assignedWindow$") ++ // assignTs
+ Array(batchCalc.getRowType.getFieldNames.get(1)) ++ // auxGrouping
+ Array("count$0")).toList // agg calls
+ val localWindowAggRowType = typeFactory.createStructType(
+ localWindowAggTypes, localWindowAggNames)
+ val batchLocalWindowAggWithAuxGroup = new BatchExecLocalHashWindowAggregate(
+ batchCalc.getCluster,
+ relBuilder,
+ batchPhysicalTraits,
+ batchCalc,
+ localWindowAggRowType,
+ batchCalc.getRowType,
+ Array(0),
+ Array(1),
+ aggCallToAggFunction,
+ tumblingGroupWindow,
+ inputTimeFieldIndex = 2,
+ inputTimeIsDate = false,
+ namedPropertiesOfWindowAgg,
+ enableAssignPane = false)
+ val batchExchange2 = new BatchExecExchange(
+ cluster, batchPhysicalTraits.replace(hash0), batchLocalWindowAggWithAuxGroup, hash0)
+ val batchWindowAggWithLocalWithAuxGroup = new BatchExecHashWindowAggregate(
+ cluster,
+ relBuilder,
+ batchPhysicalTraits,
+ batchExchange2,
+ flinkLogicalWindowAggWithAuxGroup.getRowType,
+ batchExchange2.getRowType,
+ batchCalc.getRowType,
+ Array(0),
+ Array(2), // local output grouping keys: grouping + assignTs + auxGrouping
+ aggCallToAggFunction,
+ tumblingGroupWindow,
+ inputTimeFieldIndex = 2,
+ inputTimeIsDate = false,
+ namedPropertiesOfWindowAgg,
+ enableAssignPane = false,
+ isMerge = true
+ )
+
+ val batchWindowAggWithoutLocalWithAuxGroup = new BatchExecHashWindowAggregate(
+ batchExchange1.getCluster,
+ relBuilder,
+ batchPhysicalTraits,
+ batchExchange1,
+ flinkLogicalWindowAggWithAuxGroup.getRowType,
+ batchExchange1.getRowType,
+ batchExchange1.getRowType,
+ Array(0),
+ Array(1),
+ aggCallToAggFunction,
+ tumblingGroupWindow,
+ inputTimeFieldIndex = 2,
+ inputTimeIsDate = false,
+ namedPropertiesOfWindowAgg,
+ enableAssignPane = false,
+ isMerge = false
+ )
+
+ (logicalWindowAggWithAuxGroup, flinkLogicalWindowAggWithAuxGroup,
+ batchLocalWindowAggWithAuxGroup, batchWindowAggWithLocalWithAuxGroup,
+ batchWindowAggWithoutLocalWithAuxGroup)
+ }
+
// equivalent SQL is
// select id, name, score, age, class,
// row_number() over(partition by class order by name) as rn,
@@ -952,7 +1393,7 @@ class FlinkRelMdHandlerTestBase {
// max(score) over (partition by age) as max_score,
// count(id) over (partition by age) as cnt
// from student
- protected lazy val (flinkLogicalOverWindow, batchOverWindowAgg) = {
+ protected lazy val (flinkLogicalOverAgg, batchOverAgg) = {
val types = Map(
"id" -> longType,
"name" -> stringType,
@@ -989,26 +1430,26 @@ class FlinkRelMdHandlerTestBase {
val rowTypeOfWindowAgg = createRowType(
"id", "name", "score", "age", "class", "rn", "rk", "drk",
"count$0_score", "sum$0_score", "max_score", "cnt")
- val flinkLogicalOverWindow = new FlinkLogicalOverWindow(
+ val flinkLogicalOverAgg = new FlinkLogicalOverAggregate(
cluster,
flinkLogicalTraits,
new FlinkLogicalCalc(cluster, flinkLogicalTraits, studentFlinkLogicalScan, rexProgram),
ImmutableList.of(),
rowTypeOfWindowAgg,
- overWindowGroups
+ overAggGroups
)
val rowTypeOfWindowAggOutput = createRowType(
"id", "name", "score", "age", "class", "rn", "rk", "drk", "avg_score", "max_score", "cnt")
val projectProgram = RexProgram.create(
- flinkLogicalOverWindow.getRowType,
- (0 until flinkLogicalOverWindow.getRowType.getFieldCount).flatMap { i =>
+ flinkLogicalOverAgg.getRowType,
+ (0 until flinkLogicalOverAgg.getRowType.getFieldCount).flatMap { i =>
if (i < 8 || i >= 10) {
- Array[RexNode](RexInputRef.of(i, flinkLogicalOverWindow.getRowType))
+ Array[RexNode](RexInputRef.of(i, flinkLogicalOverAgg.getRowType))
} else if (i == 8) {
Array[RexNode](rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE,
- RexInputRef.of(8, flinkLogicalOverWindow.getRowType),
- RexInputRef.of(9, flinkLogicalOverWindow.getRowType)))
+ RexInputRef.of(8, flinkLogicalOverAgg.getRowType),
+ RexInputRef.of(9, flinkLogicalOverAgg.getRowType)))
} else {
Array.empty[RexNode]
}
@@ -1018,10 +1459,10 @@ class FlinkRelMdHandlerTestBase {
rexBuilder
)
- val flinkLogicalOverWindowOutput = new FlinkLogicalCalc(
+ val flinkLogicalOverAggOutput = new FlinkLogicalCalc(
cluster,
flinkLogicalTraits,
- flinkLogicalOverWindow,
+ flinkLogicalOverAgg,
projectProgram
)
@@ -1048,11 +1489,11 @@ class FlinkRelMdHandlerTestBase {
Array(1),
Array(true),
Array(false),
- Seq((overWindowGroups(0), Seq(
+ Seq((overAggGroups(0), Seq(
(AggregateCall.create(SqlStdOperatorTable.ROW_NUMBER, false, ImmutableList.of(), -1,
longType, "rn"),
new RowNumberAggFunction())))),
- flinkLogicalOverWindow
+ flinkLogicalOverAgg
)
// sort class, score
@@ -1075,7 +1516,7 @@ class FlinkRelMdHandlerTestBase {
Array(2),
Array(true),
Array(false),
- Seq((overWindowGroups(1), Seq(
+ Seq((overAggGroups(1), Seq(
(AggregateCall.create(SqlStdOperatorTable.RANK, false, ImmutableList.of(), -1, longType,
"rk"),
new RankAggFunction(Array(InternalTypes.STRING))),
@@ -1089,7 +1530,7 @@ class FlinkRelMdHandlerTestBase {
ImmutableList.of(Integer.valueOf(2)), -1, doubleType, "sum$0_score"),
new DoubleSumAggFunction())
))),
- flinkLogicalOverWindow
+ flinkLogicalOverAgg
)
val hash3 = FlinkRelDistribution.hash(Array(3), requireStrict = true)
@@ -1110,7 +1551,7 @@ class FlinkRelMdHandlerTestBase {
Array.empty,
Array.empty,
Array.empty,
- Seq((overWindowGroups(2), Seq(
+ Seq((overAggGroups(2), Seq(
(AggregateCall.create(SqlStdOperatorTable.MAX, false,
ImmutableList.of(Integer.valueOf(2)), -1, longType, "max_score"),
new CountAggFunction()),
@@ -1118,7 +1559,7 @@ class FlinkRelMdHandlerTestBase {
ImmutableList.of(Integer.valueOf(0)), -1, doubleType, "cnt"),
new DoubleSumAggFunction())
))),
- flinkLogicalOverWindow
+ flinkLogicalOverAgg
)
val batchWindowAggOutput = new BatchExecCalc(
@@ -1129,7 +1570,7 @@ class FlinkRelMdHandlerTestBase {
projectProgram.getOutputRowType
)
- (flinkLogicalOverWindowOutput, batchWindowAggOutput)
+ (flinkLogicalOverAggOutput, batchWindowAggOutput)
}
// equivalent SQL is
@@ -1138,7 +1579,7 @@ class FlinkRelMdHandlerTestBase {
// dense_rank() over (partition by class order by score) as drk,
// avg(score) over (partition by class order by score) as avg_score
// from student
- protected lazy val streamOverWindowAgg: StreamPhysicalRel = {
+ protected lazy val streamOverAgg: StreamPhysicalRel = {
val types = Map(
"id" -> longType,
"name" -> stringType,
@@ -1171,13 +1612,13 @@ class FlinkRelMdHandlerTestBase {
val rowTypeOfWindowAgg = createRowType(
"id", "name", "score", "age", "class", "rk", "drk", "count$0_score", "sum$0_score")
- val flinkLogicalOverWindow = new FlinkLogicalOverWindow(
+ val flinkLogicalOverAgg = new FlinkLogicalOverAggregate(
cluster,
flinkLogicalTraits,
new FlinkLogicalCalc(cluster, flinkLogicalTraits, studentFlinkLogicalScan, rexProgram),
ImmutableList.of(),
rowTypeOfWindowAgg,
- util.Arrays.asList(overWindowGroups.get(1))
+ util.Arrays.asList(overAggGroups.get(1))
)
val streamScan: StreamExecDataStreamScan =
@@ -1193,20 +1634,20 @@ class FlinkRelMdHandlerTestBase {
exchange,
rowTypeOfWindowAgg,
exchange.getRowType,
- flinkLogicalOverWindow
+ flinkLogicalOverAgg
)
val rowTypeOfWindowAggOutput = createRowType(
"id", "name", "score", "age", "class", "rk", "drk", "avg_score")
val projectProgram = RexProgram.create(
- flinkLogicalOverWindow.getRowType,
- (0 until flinkLogicalOverWindow.getRowType.getFieldCount).flatMap { i =>
+ flinkLogicalOverAgg.getRowType,
+ (0 until flinkLogicalOverAgg.getRowType.getFieldCount).flatMap { i =>
if (i < 7) {
- Array[RexNode](RexInputRef.of(i, flinkLogicalOverWindow.getRowType))
+ Array[RexNode](RexInputRef.of(i, flinkLogicalOverAgg.getRowType))
} else if (i == 7) {
Array[RexNode](rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE,
- RexInputRef.of(7, flinkLogicalOverWindow.getRowType),
- RexInputRef.of(8, flinkLogicalOverWindow.getRowType)))
+ RexInputRef.of(7, flinkLogicalOverAgg.getRowType),
+ RexInputRef.of(8, flinkLogicalOverAgg.getRowType)))
} else {
Array.empty[RexNode]
}
@@ -1232,7 +1673,7 @@ class FlinkRelMdHandlerTestBase {
// avg(score) over (partition by class order by score) as avg_score,
// max(score) over (partition by age) as max_score,
// count(id) over (partition by age) as cnt
- private lazy val overWindowGroups = {
+ private lazy val overAggGroups = {
ImmutableList.of(
new Window.Group(
ImmutableBitSet.of(5),
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdPopulationSizeTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdPopulationSizeTest.scala
index 220b490..7147b16 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdPopulationSizeTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdPopulationSizeTest.scala
@@ -222,8 +222,52 @@ class FlinkRelMdPopulationSizeTest extends FlinkRelMdHandlerTestBase {
}
@Test
- def testGetPopulationSizeOnOverWindow(): Unit = {
- Array(flinkLogicalOverWindow, batchOverWindowAgg).foreach { agg =>
+ def testGetPopulationSizeOnWindowAgg(): Unit = {
+ Array(logicalWindowAgg, flinkLogicalWindowAgg, batchGlobalWindowAggWithoutLocalAgg,
+ batchGlobalWindowAggWithLocalAgg).foreach { agg =>
+ assertEquals(30D, mq.getPopulationSize(agg, ImmutableBitSet.of(0)))
+ assertEquals(5D, mq.getPopulationSize(agg, ImmutableBitSet.of(1)))
+ assertEquals(50D, mq.getPopulationSize(agg, ImmutableBitSet.of(0, 1)))
+ assertEquals(50D, mq.getPopulationSize(agg, ImmutableBitSet.of(0, 2)))
+ assertEquals(null, mq.getPopulationSize(agg, ImmutableBitSet.of(3)))
+ assertEquals(null, mq.getPopulationSize(agg, ImmutableBitSet.of(0, 3)))
+ assertEquals(null, mq.getPopulationSize(agg, ImmutableBitSet.of(1, 3)))
+ assertEquals(null, mq.getPopulationSize(agg, ImmutableBitSet.of(2, 3)))
+ }
+ assertEquals(30D, mq.getPopulationSize(batchLocalWindowAgg, ImmutableBitSet.of(0)))
+ assertEquals(5D, mq.getPopulationSize(batchLocalWindowAgg, ImmutableBitSet.of(1)))
+ assertEquals(null, mq.getPopulationSize(batchLocalWindowAgg, ImmutableBitSet.of(2)))
+ assertEquals(50D, mq.getPopulationSize(batchLocalWindowAgg, ImmutableBitSet.of(0, 1)))
+ assertEquals(null, mq.getPopulationSize(batchLocalWindowAgg, ImmutableBitSet.of(0, 2)))
+ assertEquals(10D, mq.getPopulationSize(batchLocalWindowAgg, ImmutableBitSet.of(3)))
+ assertEquals(50D, mq.getPopulationSize(batchLocalWindowAgg, ImmutableBitSet.of(0, 3)))
+ assertEquals(50D, mq.getPopulationSize(batchLocalWindowAgg, ImmutableBitSet.of(1, 3)))
+ assertEquals(null, mq.getPopulationSize(batchLocalWindowAgg, ImmutableBitSet.of(2, 3)))
+
+ Array(logicalWindowAggWithAuxGroup, flinkLogicalWindowAggWithAuxGroup,
+ batchGlobalWindowAggWithoutLocalAggWithAuxGroup,
+ batchGlobalWindowAggWithLocalAggWithAuxGroup).foreach { agg =>
+ assertEquals(50D, mq.getPopulationSize(agg, ImmutableBitSet.of(0)))
+ assertEquals(48D, mq.getPopulationSize(agg, ImmutableBitSet.of(1)))
+ assertEquals(10D, mq.getPopulationSize(agg, ImmutableBitSet.of(2)))
+ assertEquals(null, mq.getPopulationSize(agg, ImmutableBitSet.of(3)))
+ assertEquals(50D, mq.getPopulationSize(agg, ImmutableBitSet.of(0, 1)))
+ assertEquals(50D, mq.getPopulationSize(agg, ImmutableBitSet.of(0, 1, 2)))
+ assertEquals(null, mq.getPopulationSize( agg, ImmutableBitSet.of(0, 1, 3)))
+ }
+ assertEquals(50D, mq.getPopulationSize(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0)))
+ assertNull(mq.getPopulationSize(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(1)))
+ assertEquals(48D, mq.getPopulationSize(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(2)))
+ assertEquals(10D, mq.getPopulationSize(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(3)))
+ assertNull(mq.getPopulationSize(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0, 1)))
+ assertEquals(50D,
+ mq.getPopulationSize(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0, 2)))
+ assertNull(mq.getPopulationSize(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0, 1, 3)))
+ }
+
+ @Test
+ def testGetPopulationSizeOnOverAgg(): Unit = {
+ Array(flinkLogicalOverAgg, batchOverAgg).foreach { agg =>
assertEquals(1.0, mq.getPopulationSize(agg, ImmutableBitSet.of()))
assertEquals(50.0, mq.getPopulationSize(agg, ImmutableBitSet.of(0)))
assertEquals(48.0, mq.getPopulationSize(agg, ImmutableBitSet.of(1)))
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdRowCountTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdRowCountTest.scala
index 841c0cf..61fec6d 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdRowCountTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdRowCountTest.scala
@@ -18,9 +18,15 @@
package org.apache.flink.table.plan.metadata
+import org.apache.flink.table.plan.nodes.calcite.LogicalWindowAggregate
import org.apache.flink.table.plan.util.FlinkRelMdUtil
+import com.google.common.collect.Lists
+import org.apache.calcite.rel.core.{AggregateCall, Project}
+import org.apache.calcite.rex.RexProgram
+import org.apache.calcite.sql.fun.SqlCountAggFunction
import org.apache.calcite.sql.fun.SqlStdOperatorTable.LESS_THAN
+import org.apache.calcite.util.ImmutableBitSet
import org.junit.Assert._
import org.junit.Test
@@ -135,8 +141,38 @@ class FlinkRelMdRowCountTest extends FlinkRelMdHandlerTestBase {
}
@Test
- def testGetRowCountOnOverWindow(): Unit = {
- Array(flinkLogicalOverWindow, batchOverWindowAgg).foreach { agg =>
+ def testGetRowCountOnWindowAgg(): Unit = {
+ Array(logicalWindowAgg, flinkLogicalWindowAgg, batchLocalWindowAgg,
+ batchGlobalWindowAggWithoutLocalAgg,
+ batchGlobalWindowAggWithLocalAgg, streamWindowAgg).foreach { agg =>
+ assertEquals(50D, mq.getRowCount(agg))
+ }
+
+ Array(logicalWindowAggWithAuxGroup, flinkLogicalWindowAggWithAuxGroup,
+ batchLocalWindowAggWithAuxGroup,
+ batchGlobalWindowAggWithoutLocalAggWithAuxGroup,
+ batchGlobalWindowAggWithLocalAggWithAuxGroup).foreach { agg =>
+ assertEquals(50D, mq.getRowCount(agg))
+ }
+
+ relBuilder.clear()
+ val ts = relBuilder.scan("TemporalTable3").peek()
+ val aggCallOfWindowAgg = Lists.newArrayList(AggregateCall.create(
+ new SqlCountAggFunction("COUNT"), false, false, List[Integer](3), -1, 2, ts, null, "s"))
+ val windowAgg = new LogicalWindowAggregate(
+ ts.getCluster,
+ ts.getTraitSet,
+ ts,
+ ImmutableBitSet.of(0, 1),
+ aggCallOfWindowAgg,
+ tumblingGroupWindow,
+ namedPropertiesOfWindowAgg)
+ assertEquals(4000000000D, mq.getRowCount(windowAgg))
+ }
+
+ @Test
+ def testGetRowCountOnOverAgg(): Unit = {
+ Array(flinkLogicalOverAgg, batchOverAgg).foreach { agg =>
assertEquals(50.0, mq.getRowCount(agg))
}
}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSelectivityTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSelectivityTest.scala
index 19c1b55..9dfa3b5 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSelectivityTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSelectivityTest.scala
@@ -20,7 +20,10 @@ package org.apache.flink.table.plan.metadata
import org.apache.flink.table.functions.sql.FlinkSqlOperatorTable
import org.apache.flink.table.plan.nodes.calcite.LogicalExpand
-import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalDataStreamTableScan, FlinkLogicalExpand, FlinkLogicalOverWindow}
+import org.apache.flink.table.plan.nodes.logical.{
+ FlinkLogicalDataStreamTableScan,
+ FlinkLogicalExpand, FlinkLogicalOverAggregate
+}
import org.apache.flink.table.plan.nodes.physical.batch.{BatchExecCalc, BatchExecRank}
import org.apache.flink.table.plan.util.ExpandUtil
@@ -355,7 +358,58 @@ class FlinkRelMdSelectivityTest extends FlinkRelMdHandlerTestBase {
}
@Test
- def testGetSelectivityOnOverWindow(): Unit = {
+ def testGetSelectivityOnWindowAgg(): Unit = {
+ Array(logicalWindowAgg, flinkLogicalWindowAgg, batchGlobalWindowAggWithoutLocalAgg,
+ batchGlobalWindowAggWithLocalAgg).foreach { agg =>
+ relBuilder.clear()
+ relBuilder.push(agg)
+ // predicate without time fields and aggCall fields
+ // a > 15
+ val predicate1 = relBuilder.call(GREATER_THAN, relBuilder.field(0), relBuilder.literal(15))
+ assertEquals(0.75D, mq.getSelectivity(agg, predicate1))
+
+ // predicate with time fields only
+ // a > 15 and w$end = 1000000
+ val predicate2 = relBuilder.and(
+ relBuilder.call(GREATER_THAN, relBuilder.field(0), relBuilder.literal(15)),
+ relBuilder.call(EQUALS, relBuilder.field(4), relBuilder.literal(1000000))
+ )
+ assertEquals(0.75D * 0.15D, mq.getSelectivity(agg, predicate2))
+
+ // predicate with time fields and aggCall fields
+ // a > 15 and count(c) > 100 and w$end = 1000000
+ val predicate3 = relBuilder.and(
+ relBuilder.call(GREATER_THAN, relBuilder.field(0), relBuilder.literal(15)),
+ relBuilder.call(GREATER_THAN, relBuilder.field(2), relBuilder.literal(100)),
+ relBuilder.call(EQUALS, relBuilder.field(4), relBuilder.literal(1000000))
+ )
+ assertEquals(0.75D * 0.15D * 0.01D, mq.getSelectivity(agg, predicate3))
+ }
+
+ Array(logicalWindowAggWithAuxGroup, flinkLogicalWindowAggWithAuxGroup,
+ batchGlobalWindowAggWithoutLocalAggWithAuxGroup,
+ batchGlobalWindowAggWithLocalAggWithAuxGroup).foreach { agg =>
+ relBuilder.clear()
+ relBuilder.push(agg)
+ // a > 15
+ val predicate4 = relBuilder.call(GREATER_THAN, relBuilder.field(0), relBuilder.literal(15))
+ assertEquals(0.8D, mq.getSelectivity(agg, predicate4))
+ // b > 15
+ val predicate5 = relBuilder.call(GREATER_THAN, relBuilder.field(1), relBuilder.literal(15))
+ assertEquals(0.7D, mq.getSelectivity(agg, predicate5))
+ // a > 15 and b > 15 and count(c) > 100 and w$end = 1000000
+ val predicate6 = relBuilder.and(
+ relBuilder.call(GREATER_THAN, relBuilder.field(0), relBuilder.literal(15)),
+ relBuilder.call(GREATER_THAN, relBuilder.field(1), relBuilder.literal(15)),
+ relBuilder.call(GREATER_THAN, relBuilder.field(2), relBuilder.literal(100)),
+ relBuilder.call(EQUALS, relBuilder.field(4), relBuilder.literal(1000000))
+ )
+ assertEquals(0.8D * 0.7D * 0.15D * 0.01D, mq.getSelectivity(agg, predicate6))
+ }
+ }
+
+ @Test
+ def testGetSelectivityOnOverAgg(): Unit = {
// select a, b, c, d,
// rank() over (partition by c order by d) as rk,
// max(d) over(partition by c order by d) as max_d from MyTable4
@@ -363,7 +417,7 @@ class FlinkRelMdSelectivityTest extends FlinkRelMdHandlerTestBase {
ImmutableList.of(), -1, longType, "rk")
val maxAggCall = AggregateCall.create(SqlStdOperatorTable.MAX, false,
ImmutableList.of(Integer.valueOf(3)), -1, doubleType, "max_d")
- val overWindowGroups = ImmutableList.of(new Window.Group(
+ val overAggGroups = ImmutableList.of(new Window.Group(
ImmutableBitSet.of(2),
true,
RexWindowBound.create(SqlWindow.createUnboundedPreceding(new SqlParserPos(0, 0)), null),
@@ -383,8 +437,8 @@ class FlinkRelMdSelectivityTest extends FlinkRelMdHandlerTestBase {
scan.getRowType.getFieldList.foreach(f => builder.add(f.getName, f.getType))
builder.add(rankAggCall.getName, rankAggCall.getType)
builder.add(maxAggCall.getName, maxAggCall.getType)
- val overWindow = new FlinkLogicalOverWindow(cluster, flinkLogicalTraits, scan,
- ImmutableList.of(), builder.build(), overWindowGroups)
+ val overWindow = new FlinkLogicalOverAggregate(cluster, flinkLogicalTraits, scan,
+ ImmutableList.of(), builder.build(), overAggGroups)
relBuilder.push(overWindow)
// a <= 10
@@ -401,7 +455,7 @@ class FlinkRelMdSelectivityTest extends FlinkRelMdHandlerTestBase {
relBuilder.call(LESS_THAN, relBuilder.field(4), relBuilder.literal(2)))
assertEquals(1 / 25.0 * ((10.0 - 1.0) / (50.0 - 1)) * 0.5, mq.getSelectivity(overWindow, pred3))
- Array(flinkLogicalOverWindow, batchOverWindowAgg).foreach { agg =>
+ Array(flinkLogicalOverAgg, batchOverAgg).foreach { agg =>
relBuilder.clear()
relBuilder.push(agg)
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSizeTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSizeTest.scala
index 25e4c96..c94c907 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSizeTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdSizeTest.scala
@@ -135,13 +135,31 @@ class FlinkRelMdSizeTest extends FlinkRelMdHandlerTestBase {
}
@Test
- def testAverageColumnSizeOnOverWindow(): Unit = {
- Array(flinkLogicalOverWindow, batchOverWindowAgg).foreach { agg =>
+ def testAverageColumnSizeOnWindowAgg(): Unit = {
+ Array(logicalWindowAgg, flinkLogicalWindowAgg, batchGlobalWindowAggWithoutLocalAgg,
+ batchGlobalWindowAggWithLocalAgg).foreach { agg =>
+ assertEquals(Seq(4D, 32D, 8D, 12D, 12D, 12D, 12D), mq.getAverageColumnSizes(agg).toSeq)
+ }
+ assertEquals(Seq(4.0, 32.0, 8.0, 8.0),
+ mq.getAverageColumnSizes(batchLocalWindowAgg).toSeq)
+
+ Array(logicalWindowAggWithAuxGroup, flinkLogicalWindowAggWithAuxGroup,
+ batchGlobalWindowAggWithoutLocalAggWithAuxGroup,
+ batchGlobalWindowAggWithLocalAggWithAuxGroup).foreach { agg =>
+ assertEquals(Seq(8D, 4D, 8D, 12D, 12D, 12D, 12D), mq.getAverageColumnSizes(agg).toSeq)
+ }
+ assertEquals(Seq(8D, 8D, 4D, 8D),
+ mq.getAverageColumnSizes(batchLocalWindowAggWithAuxGroup).toSeq)
+ }
+
+ @Test
+ def testAverageColumnSizeOnOverAgg(): Unit = {
+ Array(flinkLogicalOverAgg, batchOverAgg).foreach { agg =>
assertEquals(Seq(8.0, 7.2, 8.0, 4.0, 4.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0),
mq.getAverageColumnSizes(agg).toList)
}
assertEquals(Seq(8.0, 12.0, 8.0, 4.0, 4.0, 8.0, 8.0, 8.0),
- mq.getAverageColumnSizes(streamOverWindowAgg).toList)
+ mq.getAverageColumnSizes(streamOverAgg).toList)
}
@Test
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdUniqueGroupsTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdUniqueGroupsTest.scala
index 9812e2a..8417023 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdUniqueGroupsTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdUniqueGroupsTest.scala
@@ -311,8 +311,47 @@ class FlinkRelMdUniqueGroupsTest extends FlinkRelMdHandlerTestBase {
}
@Test
- def testGetUniqueGroupsOnOverWindow(): Unit = {
- Array(flinkLogicalOverWindow, batchOverWindowAgg).foreach { agg =>
+ def testGetUniqueGroupsOnWindowAgg(): Unit = {
+ Array(logicalWindowAgg, flinkLogicalWindowAgg,
+ batchGlobalWindowAggWithoutLocalAgg,
+ batchGlobalWindowAggWithLocalAgg).foreach { agg =>
+ assertEquals(ImmutableBitSet.of(0, 1, 2, 3, 4, 5, 6),
+ mq.getUniqueGroups(agg, ImmutableBitSet.of(0, 1, 2, 3, 4, 5, 6)))
+ assertEquals(ImmutableBitSet.of(3, 4, 5, 6),
+ mq.getUniqueGroups(agg, ImmutableBitSet.of(3, 4, 5, 6)))
+ assertEquals(ImmutableBitSet.of(0, 3, 4, 5, 6),
+ mq.getUniqueGroups(agg, ImmutableBitSet.of(0, 3, 4, 5, 6)))
+ assertEquals(ImmutableBitSet.of(0, 1),
+ mq.getUniqueGroups(agg, ImmutableBitSet.of(0, 1)))
+ assertEquals(ImmutableBitSet.of(0, 1, 2),
+ mq.getUniqueGroups(agg, ImmutableBitSet.of(0, 1, 2)))
+ }
+ assertEquals(ImmutableBitSet.of(0, 1, 2, 3),
+ mq.getUniqueGroups(batchLocalWindowAgg, ImmutableBitSet.of(0, 1, 2, 3)))
+
+ Array(logicalWindowAggWithAuxGroup, flinkLogicalWindowAggWithAuxGroup,
+ batchGlobalWindowAggWithoutLocalAggWithAuxGroup,
+ batchGlobalWindowAggWithLocalAggWithAuxGroup).foreach { agg =>
+ assertEquals(ImmutableBitSet.of(1),
+ mq.getUniqueGroups(agg, ImmutableBitSet.of(1)))
+ assertEquals(ImmutableBitSet.of(0),
+ mq.getUniqueGroups(agg, ImmutableBitSet.of(0, 1)))
+ assertEquals(ImmutableBitSet.of(0, 1, 2),
+ mq.getUniqueGroups(agg, ImmutableBitSet.of(0, 1, 2)))
+ assertEquals(ImmutableBitSet.of(0, 1, 2, 3),
+ mq.getUniqueGroups(agg, ImmutableBitSet.of(0, 1, 2, 3)))
+ assertEquals(ImmutableBitSet.of(0, 1, 2, 3, 4, 5, 6),
+ mq.getUniqueGroups(agg, ImmutableBitSet.of(0, 1, 2, 3, 4, 5, 6)))
+ }
+ assertEquals(ImmutableBitSet.of(0),
+ mq.getUniqueGroups(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0, 1)))
+ assertEquals(ImmutableBitSet.of(0, 1, 2),
+ mq.getUniqueGroups(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0, 1, 2)))
+ }
+
+ @Test
+ def testGetUniqueGroupsOnOverAgg(): Unit = {
+ Array(flinkLogicalOverAgg, batchOverAgg).foreach { agg =>
assertEquals(ImmutableBitSet.of(0), mq.getUniqueGroups(agg, ImmutableBitSet.of(0, 1)))
assertEquals(ImmutableBitSet.of(0), mq.getUniqueGroups(agg, ImmutableBitSet.of(0, 1, 2)))
assertEquals(ImmutableBitSet.of(1, 2), mq.getUniqueGroups(agg, ImmutableBitSet.of(1, 2)))
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdUniqueKeysTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdUniqueKeysTest.scala
index dce3390..9e034a6 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdUniqueKeysTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/FlinkRelMdUniqueKeysTest.scala
@@ -21,7 +21,7 @@ package org.apache.flink.table.plan.metadata
import org.apache.flink.table.plan.nodes.calcite.LogicalExpand
import org.apache.flink.table.plan.util.ExpandUtil
-import com.google.common.collect.ImmutableList
+import com.google.common.collect.{ImmutableList, ImmutableSet}
import org.apache.calcite.sql.fun.SqlStdOperatorTable.{EQUALS, LESS_THAN}
import org.apache.calcite.util.ImmutableBitSet
import org.junit.Assert._
@@ -169,12 +169,32 @@ class FlinkRelMdUniqueKeysTest extends FlinkRelMdHandlerTestBase {
}
@Test
- def testGetUniqueKeysOnOverWindow(): Unit = {
- Array(flinkLogicalOverWindow, batchOverWindowAgg).foreach { agg =>
+ def testGetUniqueKeysOnWindowAgg(): Unit = {
+ Array(logicalWindowAgg, flinkLogicalWindowAgg, batchGlobalWindowAggWithoutLocalAgg,
+ batchGlobalWindowAggWithLocalAgg).foreach { agg =>
+ assertEquals(ImmutableSet.of(ImmutableBitSet.of(0, 1, 3), ImmutableBitSet.of(0, 1, 4),
+ ImmutableBitSet.of(0, 1, 5), ImmutableBitSet.of(0, 1, 6)),
+ mq.getUniqueKeys(agg))
+ }
+ assertNull(mq.getUniqueKeys(batchLocalWindowAgg))
+
+ Array(logicalWindowAggWithAuxGroup, flinkLogicalWindowAggWithAuxGroup,
+ batchGlobalWindowAggWithoutLocalAggWithAuxGroup,
+ batchGlobalWindowAggWithLocalAggWithAuxGroup).foreach { agg =>
+ assertEquals(ImmutableSet.of(ImmutableBitSet.of(0, 3), ImmutableBitSet.of(0, 4),
+ ImmutableBitSet.of(0, 5), ImmutableBitSet.of(0, 6)),
+ mq.getUniqueKeys(agg))
+ }
+ assertNull(mq.getUniqueKeys(batchLocalWindowAggWithAuxGroup))
+ }
+
+ @Test
+ def testGetUniqueKeysOnOverAgg(): Unit = {
+ Array(flinkLogicalOverAgg, batchOverAgg).foreach { agg =>
assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(agg).toSet)
}
- assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(streamOverWindowAgg).toSet)
+ assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(streamOverAgg).toSet)
}
@Test
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/MetadataTestUtil.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/MetadataTestUtil.scala
index 9806da6..43113a2 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/MetadataTestUtil.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/metadata/MetadataTestUtil.scala
@@ -89,7 +89,9 @@ object MetadataTestUtil {
rootSchema.add("MyTable2", createMyTable2())
rootSchema.add("MyTable3", createMyTable3())
rootSchema.add("MyTable4", createMyTable4())
- rootSchema.add("TemporalTable", createTemporalTable())
+ rootSchema.add("TemporalTable1", createTemporalTable1())
+ rootSchema.add("TemporalTable2", createTemporalTable2())
+ rootSchema.add("TemporalTable3", createTemporalTable3())
rootSchema
}
@@ -213,7 +215,48 @@ object MetadataTestUtil {
getDataStreamTable(schema, new FlinkStatistic(tableStats, uniqueKeys))
}
- private def createTemporalTable(): DataStreamTable[BaseRow] = {
+ private def createTemporalTable1(): DataStreamTable[BaseRow] = {
+ val fieldNames = Array("a", "b", "c", "proctime", "rowtime")
+ val fieldTypes = Array[InternalType](
+ InternalTypes.LONG,
+ InternalTypes.STRING,
+ InternalTypes.INT,
+ InternalTypes.PROCTIME_INDICATOR,
+ InternalTypes.ROWTIME_INDICATOR)
+
+ val colStatsMap = Map[String, ColumnStats](
+ "a" -> new ColumnStats(30L, 0L, 4D, 4, 45, 5),
+ "b" -> new ColumnStats(5L, 0L, 32D, 32, null, null),
+ "c" -> new ColumnStats(48L, 0L, 8D, 8, 50, 0)
+ )
+
+ val tableStats = new TableStats(50L, colStatsMap)
+ getDataStreamTable(fieldNames, fieldTypes, new FlinkStatistic(tableStats),
+ producesUpdates = false, isAccRetract = false)
+ }
+
+ private def createTemporalTable2(): DataStreamTable[BaseRow] = {
+ val fieldNames = Array("a", "b", "c", "proctime", "rowtime")
+ val fieldTypes = Array[InternalType](
+ InternalTypes.LONG,
+ InternalTypes.STRING,
+ InternalTypes.INT,
+ InternalTypes.PROCTIME_INDICATOR,
+ InternalTypes.ROWTIME_INDICATOR)
+
+ val colStatsMap = Map[String, ColumnStats](
+ "a" -> new ColumnStats(50L, 0L, 8D, 8, 55, 5),
+ "b" -> new ColumnStats(5L, 0L, 16D, 32, null, null),
+ "c" -> new ColumnStats(48L, 0L, 4D, 4, 50, 0)
+ )
+
+ val tableStats = new TableStats(50L, colStatsMap)
+ val uniqueKeys = Set(Set("a").asJava).asJava
+ getDataStreamTable(fieldNames, fieldTypes, new FlinkStatistic(tableStats, uniqueKeys),
+ producesUpdates = false, isAccRetract = false)
+ }
+
+ private def createTemporalTable3(): DataStreamTable[BaseRow] = {
val fieldNames = Array("a", "b", "c", "proctime", "rowtime")
val fieldTypes = Array[InternalType](
InternalTypes.INT,
@@ -228,7 +271,7 @@ object MetadataTestUtil {
"c" -> new ColumnStats(null, 0L, 18.6, 64, null, null)
)
- val tableStats = new TableStats(20000000L, colStatsMap)
+ val tableStats = new TableStats(4000000000L, colStatsMap)
getDataStreamTable(fieldNames, fieldTypes, new FlinkStatistic(tableStats),
producesUpdates = false, isAccRetract = false)
}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/agg/OverWindowAggregateTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/agg/OverAggregateTest.scala
similarity index 99%
rename from flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/agg/OverWindowAggregateTest.scala
rename to flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/agg/OverAggregateTest.scala
index 0905cbd..a5f2f98 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/agg/OverWindowAggregateTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/agg/OverAggregateTest.scala
@@ -27,7 +27,7 @@ import org.apache.flink.table.util.TableTestBase
import org.junit.Assert.assertEquals
import org.junit.Test
-class OverWindowAggregateTest extends TableTestBase {
+class OverAggregateTest extends TableTestBase {
private val util = streamTestUtil()
util.addDataStream[(Int, String, Long)]("MyTable", 'a, 'b, 'c, 'proctime, 'rowtime)