Posted to commits@flink.apache.org by fh...@apache.org on 2017/03/24 19:23:38 UTC
[4/5] flink git commit: [FLINK-5990] [table] Add event-time OVER ROWS BETWEEN x PRECEDING aggregation to SQL.
[FLINK-5990] [table] Add event-time OVER ROWS BETWEEN x PRECEDING aggregation to SQL.
This closes #3585.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/7a9d39fe
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/7a9d39fe
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/7a9d39fe
Branch: refs/heads/master
Commit: 7a9d39fe9f659d43bf4719a2981f6c4771ffbe48
Parents: 6949c8c
Author: 金竹 <ji...@alibaba-inc.com>
Authored: Sun Mar 19 23:31:00 2017 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Fri Mar 24 20:19:17 2017 +0100
----------------------------------------------------------------------
.../flink/table/plan/nodes/OverAggregate.scala | 31 ++-
.../datastream/DataStreamOverAggregate.scala | 149 +++++++++---
.../table/runtime/aggregate/AggregateUtil.scala | 48 +++-
.../RowsClauseBoundedOverProcessFunction.scala | 239 +++++++++++++++++++
.../table/api/scala/stream/sql/SqlITCase.scala | 139 ++++++++++-
.../scala/stream/sql/WindowAggregateTest.scala | 55 +++++
6 files changed, 623 insertions(+), 38 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/7a9d39fe/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala
index 793ab23..91c8cef 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala
@@ -18,12 +18,15 @@
package org.apache.flink.table.plan.nodes
-import org.apache.calcite.rel.RelFieldCollation
+import org.apache.calcite.rel.{RelFieldCollation, RelNode}
import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFieldImpl}
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rel.core.Window.Group
+import org.apache.calcite.rel.core.Window
+import org.apache.calcite.rex.RexInputRef
import org.apache.flink.table.runtime.aggregate.AggregateUtil._
import org.apache.flink.table.functions.{ProcTimeType, RowTimeType}
+
import scala.collection.JavaConverters._
trait OverAggregate {
@@ -46,8 +49,16 @@ trait OverAggregate {
orderingString
}
- private[flink] def windowRange(overWindow: Group): String = {
- s"BETWEEN ${overWindow.lowerBound} AND ${overWindow.upperBound}"
+ private[flink] def windowRange(
+ logicWindow: Window,
+ overWindow: Group,
+ input: RelNode): String = {
+ if (overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded) {
+ s"BETWEEN ${getLowerBoundary(logicWindow, overWindow, input)} PRECEDING " +
+ s"AND ${overWindow.upperBound}"
+ } else {
+ s"BETWEEN ${overWindow.lowerBound} AND ${overWindow.upperBound}"
+ }
}
private[flink] def aggregationToString(
@@ -92,4 +103,18 @@ trait OverAggregate {
}.mkString(", ")
}
+ private[flink] def getLowerBoundary(
+ logicWindow: Window,
+ overWindow: Group,
+ input: RelNode): Long = {
+
+ val ref: RexInputRef = overWindow.lowerBound.getOffset.asInstanceOf[RexInputRef]
+ val lowerBoundIndex = ref.getIndex - input.getRowType.getFieldCount
+ val lowerBound = logicWindow.constants.get(lowerBoundIndex).getValue2
+ lowerBound match {
+ case x: java.math.BigDecimal => x.longValue()
+ case _ => lowerBound.asInstanceOf[Long]
+ }
+ }
+
}
http://git-wip-us.apache.org/repos/asf/flink/blob/7a9d39fe/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
index 34b3b0f..547c875 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
@@ -32,6 +32,7 @@ import org.apache.calcite.rel.core.Window
import org.apache.calcite.rel.core.Window.Group
import java.util.{List => JList}
+import org.apache.flink.api.java.functions.NullByteKeySelector
import org.apache.flink.table.functions.{ProcTimeType, RowTimeType}
import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair
@@ -70,9 +71,9 @@ class DataStreamOverAggregate(
super.explainTerms(pw)
.itemIf("partitionBy", partitionToString(inputType, partitionKeys), partitionKeys.nonEmpty)
- .item("orderBy",orderingToString(inputType, overWindow.orderKeys.getFieldCollations))
- .itemIf("rows", windowRange(overWindow), overWindow.isRows)
- .itemIf("range", windowRange(overWindow), !overWindow.isRows)
+ .item("orderBy",orderingToString(inputType, overWindow.orderKeys.getFieldCollations))
+ .itemIf("rows", windowRange(logicWindow, overWindow, getInput), overWindow.isRows)
+ .itemIf("range", windowRange(logicWindow, overWindow, getInput), !overWindow.isRows)
.item(
"select", aggregationToString(
inputType,
@@ -99,20 +100,58 @@ class DataStreamOverAggregate(
.getFieldList
.get(overWindow.orderKeys.getFieldCollations.get(0).getFieldIndex)
.getValue
-
timeType match {
case _: ProcTimeType =>
- // both ROWS and RANGE clause with UNBOUNDED PRECEDING and CURRENT ROW condition.
- if (overWindow.lowerBound.isUnbounded &&
- overWindow.upperBound.isCurrentRow) {
+ // proc-time OVER window
+ if (overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) {
+ // non-bounded OVER window
createUnboundedAndCurrentRowProcessingTimeOverWindow(inputDS)
+ } else if (
+ overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded &&
+ overWindow.upperBound.isCurrentRow) {
+ // bounded OVER window
+ if (overWindow.isRows) {
+ // ROWS clause bounded OVER window
+ throw new TableException(
+ "ROWS clause bounded proc-time OVER window no supported yet.")
+ } else {
+ // RANGE clause bounded OVER window
+ throw new TableException(
+ "RANGE clause bounded proc-time OVER window no supported yet.")
+ }
} else {
throw new TableException(
- "OVER window only support ProcessingTime UNBOUNDED PRECEDING and CURRENT ROW " +
- "condition.")
+ "OVER window only support ProcessingTime UNBOUNDED PRECEDING and CURRENT ROW " +
+ "condition.")
}
case _: RowTimeType =>
- throw new TableException("OVER Window of the EventTime type is not currently supported.")
+ // row-time OVER window
+ if (overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) {
+ // non-bounded OVER window
+ if (overWindow.isRows) {
+ // ROWS clause unbounded OVER window
+ throw new TableException(
+ "ROWS clause unbounded row-time OVER window no supported yet.")
+ } else {
+ // RANGE clause unbounded OVER window
+ throw new TableException(
+ "RANGE clause unbounded row-time OVER window no supported yet.")
+ }
+ } else if (overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded &&
+ overWindow.upperBound.isCurrentRow) {
+ // bounded OVER window
+ if (overWindow.isRows) {
+ // ROWS clause bounded OVER window
+ createRowsClauseBoundedAndCurrentRowOverWindow(inputDS, true)
+ } else {
+ // RANGE clause bounded OVER window
+ throw new TableException(
+ "RANGE clause bounded row-time OVER window no supported yet.")
+ }
+ } else {
+ throw new TableException(
+ "row-time OVER window only support CURRENT ROW condition.")
+ }
case _ =>
throw new TableException(s"Unsupported time type {$timeType}")
}
@@ -120,7 +159,7 @@ class DataStreamOverAggregate(
}
def createUnboundedAndCurrentRowProcessingTimeOverWindow(
- inputDS: DataStream[Row]): DataStream[Row] = {
+ inputDS: DataStream[Row]): DataStream[Row] = {
val overWindow: Group = logicWindow.groups.get(0)
val partitionKeys: Array[Int] = overWindow.keys.toArray
@@ -130,32 +169,78 @@ class DataStreamOverAggregate(
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
val result: DataStream[Row] =
- // partitioned aggregation
- if (partitionKeys.nonEmpty) {
- val processFunction = AggregateUtil.CreateUnboundedProcessingOverProcessFunction(
- namedAggregates,
- inputType)
+ // partitioned aggregation
+ if (partitionKeys.nonEmpty) {
+ val processFunction = AggregateUtil.createUnboundedProcessingOverProcessFunction(
+ namedAggregates,
+ inputType)
- inputDS
+ inputDS
.keyBy(partitionKeys: _*)
.process(processFunction)
.returns(rowTypeInfo)
.name(aggOpName)
.asInstanceOf[DataStream[Row]]
- }
- // non-partitioned aggregation
- else {
- val processFunction = AggregateUtil.CreateUnboundedProcessingOverProcessFunction(
- namedAggregates,
- inputType,
- false)
-
- inputDS
- .process(processFunction).setParallelism(1).setMaxParallelism(1)
- .returns(rowTypeInfo)
- .name(aggOpName)
- .asInstanceOf[DataStream[Row]]
- }
+ }
+ // non-partitioned aggregation
+ else {
+ val processFunction = AggregateUtil.createUnboundedProcessingOverProcessFunction(
+ namedAggregates,
+ inputType,
+ false)
+
+ inputDS
+ .process(processFunction).setParallelism(1).setMaxParallelism(1)
+ .returns(rowTypeInfo)
+ .name(aggOpName)
+ .asInstanceOf[DataStream[Row]]
+ }
+ result
+ }
+
+ def createRowsClauseBoundedAndCurrentRowOverWindow(
+ inputDS: DataStream[Row],
+ isRowTimeType: Boolean = false): DataStream[Row] = {
+
+ val overWindow: Group = logicWindow.groups.get(0)
+ val partitionKeys: Array[Int] = overWindow.keys.toArray
+ val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates
+ val inputFields = (0 until inputType.getFieldCount).toArray
+
+ val precedingOffset =
+ getLowerBoundary(logicWindow, overWindow, getInput()) + 1
+
+ // get the output types
+ val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
+
+ val processFunction = AggregateUtil.createRowsClauseBoundedOverProcessFunction(
+ namedAggregates,
+ inputType,
+ inputFields,
+ precedingOffset,
+ isRowTimeType
+ )
+ val result: DataStream[Row] =
+ // partitioned aggregation
+ if (partitionKeys.nonEmpty) {
+ inputDS
+ .keyBy(partitionKeys: _*)
+ .process(processFunction)
+ .returns(rowTypeInfo)
+ .name(aggOpName)
+ .asInstanceOf[DataStream[Row]]
+ }
+ // non-partitioned aggregation
+ else {
+ inputDS
+ .keyBy(new NullByteKeySelector[Row])
+ .process(processFunction)
+ .setParallelism(1)
+ .setMaxParallelism(1)
+ .returns(rowTypeInfo)
+ .name(aggOpName)
+ .asInstanceOf[DataStream[Row]]
+ }
result
}
@@ -180,7 +265,7 @@ class DataStreamOverAggregate(
}
}ORDER BY: ${orderingToString(inputType, overWindow.orderKeys.getFieldCollations)}, " +
s"${if (overWindow.isRows) "ROWS" else "RANGE"}" +
- s"${windowRange(overWindow)}, " +
+ s"${windowRange(logicWindow, overWindow, getInput)}, " +
s"select: (${
aggregationToString(
inputType,
http://git-wip-us.apache.org/repos/asf/flink/blob/7a9d39fe/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index 9feec17..0084ee5 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -61,7 +61,7 @@ object AggregateUtil {
* @param isPartitioned Flag to indicate whether the input is partitioned or not
* @return [[org.apache.flink.streaming.api.functions.ProcessFunction]]
*/
- private[flink] def CreateUnboundedProcessingOverProcessFunction(
+ private[flink] def createUnboundedProcessingOverProcessFunction(
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
inputType: RelDataType,
isPartitioned: Boolean = true): ProcessFunction[Row, Row] = {
@@ -91,6 +91,52 @@ object AggregateUtil {
}
/**
+ * Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for ROWS clause
+ * bounded OVER window to evaluate final aggregate value.
+ *
+ * @param namedAggregates List of calls to aggregate functions and their output field names
+ * @param inputType Input row type
+ * @param inputFields All input fields
+ * @param precedingOffset the preceding row offset (x PRECEDING plus the current row)
+ * @param isRowTimeType flag that indicates whether the time type is row time
+ * @return [[org.apache.flink.streaming.api.functions.ProcessFunction]]
+ */
+ private[flink] def createRowsClauseBoundedOverProcessFunction(
+ namedAggregates: Seq[CalcitePair[AggregateCall, String]],
+ inputType: RelDataType,
+ inputFields: Array[Int],
+ precedingOffset: Long,
+ isRowTimeType: Boolean): ProcessFunction[Row, Row] = {
+
+ val (aggFields, aggregates) =
+ transformToAggregateFunctions(
+ namedAggregates.map(_.getKey),
+ inputType,
+ needRetraction = true)
+
+ val aggregationStateType: RowTypeInfo =
+ createDataSetAggregateBufferDataType(Array(), aggregates, inputType)
+
+ val inputRowType: RowTypeInfo =
+ createDataSetAggregateBufferDataType(inputFields, Array(), inputType)
+
+ val processFunction = if (isRowTimeType) {
+ new RowsClauseBoundedOverProcessFunction(
+ aggregates,
+ aggFields,
+ inputType.getFieldCount,
+ aggregationStateType,
+ inputRowType,
+ precedingOffset
+ )
+ } else {
+ throw TableException(
+ "Bounded partitioned proc-time OVER aggregation is not supported yet.")
+ }
+ processFunction
+ }
+
+ /**
* Create a [[org.apache.flink.api.common.functions.MapFunction]] that prepares for aggregates.
* The output of the function contains the grouping keys and the timestamp and the intermediate
* aggregate values of all aggregate function. The timestamp field is aligned to time window
http://git-wip-us.apache.org/repos/asf/flink/blob/7a9d39fe/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowsClauseBoundedOverProcessFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowsClauseBoundedOverProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowsClauseBoundedOverProcessFunction.scala
new file mode 100644
index 0000000..1678d57
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowsClauseBoundedOverProcessFunction.scala
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.aggregate
+
+import java.util
+import java.util.{List => JList}
+
+import org.apache.flink.api.common.state._
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils.{ListTypeInfo, RowTypeInfo}
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.streaming.api.functions.ProcessFunction
+import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.types.Row
+import org.apache.flink.util.{Collector, Preconditions}
+
+/**
+ * ProcessFunction for a ROWS clause event-time bounded OVER window.
+ *
+ * @param aggregates the list of all [[AggregateFunction]] used for this aggregation
+ * @param aggFields the position (in the input Row) of the input value for each aggregate
+ * @param forwardedFieldCount the count of forwarded fields.
+ * @param aggregationStateType the row type info of aggregation
+ * @param inputRowType the row type info of input row
+ * @param precedingOffset the preceding offset
+ */
+class RowsClauseBoundedOverProcessFunction(
+ private val aggregates: Array[AggregateFunction[_]],
+ private val aggFields: Array[Int],
+ private val forwardedFieldCount: Int,
+ private val aggregationStateType: RowTypeInfo,
+ private val inputRowType: RowTypeInfo,
+ private val precedingOffset: Long)
+ extends ProcessFunction[Row, Row] {
+
+ Preconditions.checkNotNull(aggregates)
+ Preconditions.checkNotNull(aggFields)
+ Preconditions.checkArgument(aggregates.length == aggFields.length)
+ Preconditions.checkNotNull(forwardedFieldCount)
+ Preconditions.checkNotNull(aggregationStateType)
+ Preconditions.checkNotNull(precedingOffset)
+
+ private var output: Row = _
+
+ // the state which keeps the last triggering timestamp
+ private var lastTriggeringTsState: ValueState[Long] = _
+
+ // the state which keeps the number of processed rows
+ private var dataCountState: ValueState[Long] = _
+
+ // the state which is used to materialize the accumulator for incremental calculation
+ private var accumulatorState: ValueState[Row] = _
+
+ // the state which keeps all the data that is not expired.
+ // The map key is the timestamp; the map value is a list that holds
+ // all rows received for that timestamp.
+ private var dataState: MapState[Long, JList[Row]] = _
+
+ override def open(config: Configuration) {
+
+ output = new Row(forwardedFieldCount + aggregates.length)
+
+ val lastTriggeringTsDescriptor: ValueStateDescriptor[Long] =
+ new ValueStateDescriptor[Long]("lastTriggeringTsState", classOf[Long])
+ lastTriggeringTsState = getRuntimeContext.getState(lastTriggeringTsDescriptor)
+
+ val dataCountStateDescriptor =
+ new ValueStateDescriptor[Long]("dataCountState", classOf[Long])
+ dataCountState = getRuntimeContext.getState(dataCountStateDescriptor)
+
+ val accumulatorStateDescriptor =
+ new ValueStateDescriptor[Row]("accumulatorState", aggregationStateType)
+ accumulatorState = getRuntimeContext.getState(accumulatorStateDescriptor)
+
+ val keyTypeInformation: TypeInformation[Long] =
+ BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]]
+ val valueTypeInformation: TypeInformation[JList[Row]] = new ListTypeInfo[Row](inputRowType)
+
+ val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] =
+ new MapStateDescriptor[Long, JList[Row]](
+ "dataState",
+ keyTypeInformation,
+ valueTypeInformation)
+
+ dataState = getRuntimeContext.getMapState(mapStateDescriptor)
+
+ }
+
+ override def processElement(
+ input: Row,
+ ctx: ProcessFunction[Row, Row]#Context,
+ out: Collector[Row]): Unit = {
+
+ // triggering timestamp for trigger calculation
+ val triggeringTs = ctx.timestamp
+
+ val lastTriggeringTs = lastTriggeringTsState.value
+
+ // check if the data is expired; if not, save the data and register an event-time timer
+ if (triggeringTs > lastTriggeringTs) {
+ val data = dataState.get(triggeringTs)
+ if (null != data) {
+ data.add(input)
+ dataState.put(triggeringTs, data)
+ } else {
+ val data = new util.ArrayList[Row]
+ data.add(input)
+ dataState.put(triggeringTs, data)
+ // register event time timer
+ ctx.timerService.registerEventTimeTimer(triggeringTs)
+ }
+ }
+ }
+
+ override def onTimer(
+ timestamp: Long,
+ ctx: ProcessFunction[Row, Row]#OnTimerContext,
+ out: Collector[Row]): Unit = {
+
+ // gets all window data from state for the calculation
+ val inputs: JList[Row] = dataState.get(timestamp)
+
+ if (null != inputs) {
+
+ var accumulators = accumulatorState.value
+ var dataCount = dataCountState.value
+
+ var retractList: JList[Row] = null
+ var retractTs: Long = Long.MaxValue
+ var retractCnt: Int = 0
+ var j = 0
+ var i = 0
+
+ while (j < inputs.size) {
+ val input = inputs.get(j)
+
+ // initialize the accumulators on the first run or after failover recovery, per key
+ if (null == accumulators) {
+ accumulators = new Row(aggregates.length)
+ i = 0
+ while (i < aggregates.length) {
+ accumulators.setField(i, aggregates(i).createAccumulator())
+ i += 1
+ }
+ }
+
+ var retractRow: Row = null
+
+ if (dataCount >= precedingOffset) {
+ if (null == retractList) {
+ // find the smallest timestamp
+ retractTs = Long.MaxValue
+ val dataTimestampIt = dataState.keys.iterator
+ while (dataTimestampIt.hasNext) {
+ val dataTs = dataTimestampIt.next
+ if (dataTs < retractTs) {
+ retractTs = dataTs
+ }
+ }
+ // get the oldest rows to retract them
+ retractList = dataState.get(retractTs)
+ }
+
+ retractRow = retractList.get(retractCnt)
+ retractCnt += 1
+
+ // remove retracted values from state
+ if (retractList.size == retractCnt) {
+ dataState.remove(retractTs)
+ retractList = null
+ retractCnt = 0
+ }
+ } else {
+ dataCount += 1
+ }
+
+ // copy forwarded fields to output row
+ i = 0
+ while (i < forwardedFieldCount) {
+ output.setField(i, input.getField(i))
+ i += 1
+ }
+
+ // retract old row from accumulators
+ if (null != retractRow) {
+ i = 0
+ while (i < aggregates.length) {
+ val accumulator = accumulators.getField(i).asInstanceOf[Accumulator]
+ aggregates(i).retract(accumulator, retractRow.getField(aggFields(i)))
+ i += 1
+ }
+ }
+
+ // accumulate current row and set aggregate in output row
+ i = 0
+ while (i < aggregates.length) {
+ val index = forwardedFieldCount + i
+ val accumulator = accumulators.getField(i).asInstanceOf[Accumulator]
+ aggregates(i).accumulate(accumulator, input.getField(aggFields(i)))
+ output.setField(index, aggregates(i).getValue(accumulator))
+ i += 1
+ }
+ j += 1
+
+ out.collect(output)
+ }
+
+ // update all states
+ if (dataState.contains(retractTs)) {
+ if (retractCnt > 0) {
+ retractList.subList(0, retractCnt).clear()
+ dataState.put(retractTs, retractList)
+ }
+ }
+ dataCountState.update(dataCount)
+ accumulatorState.update(accumulators)
+ }
+
+ lastTriggeringTsState.update(timestamp)
+ }
+}
+
+
http://git-wip-us.apache.org/repos/asf/flink/blob/7a9d39fe/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
index d5a140a..19350a7 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
@@ -19,14 +19,18 @@
package org.apache.flink.table.api.scala.stream.sql
import org.apache.flink.api.scala._
+import org.apache.flink.streaming.api.functions.source.SourceFunction
+import org.apache.flink.table.api.scala.stream.sql.SqlITCase.EventTimeSourceFunction
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
+import org.apache.flink.streaming.api.watermark.Watermark
import org.apache.flink.table.api.{TableEnvironment, TableException}
import org.apache.flink.table.api.scala._
-import org.apache.flink.table.api.scala.stream.utils.{StreamingWithStateTestBase, StreamITCase,
-StreamTestData}
+import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData, StreamingWithStateTestBase}
import org.apache.flink.types.Row
import org.junit.Assert._
import org.junit._
+import org.apache.flink.streaming.api.TimeCharacteristic
+import org.apache.flink.streaming.api.functions.source.SourceFunction.SourceContext
import scala.collection.mutable
@@ -293,6 +297,120 @@ class SqlITCase extends StreamingWithStateTestBase {
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
+ @Test
+ def testBoundPartitionedEventTimeWindowWithRow(): Unit = {
+ val data = Seq(
+ Left((1L, (1L, 1, "Hello"))),
+ Left((2L, (2L, 2, "Hello"))),
+ Left((1L, (1L, 1, "Hello"))),
+ Left((2L, (2L, 2, "Hello"))),
+ Left((2L, (2L, 2, "Hello"))),
+ Left((1L, (1L, 1, "Hello"))),
+ Left((3L, (7L, 7, "Hello World"))),
+ Left((1L, (7L, 7, "Hello World"))),
+ Left((1L, (7L, 7, "Hello World"))),
+ Right(2L),
+ Left((3L, (3L, 3, "Hello"))),
+ Left((4L, (4L, 4, "Hello"))),
+ Left((5L, (5L, 5, "Hello"))),
+ Left((6L, (6L, 6, "Hello"))),
+ Left((20L, (20L, 20, "Hello World"))),
+ Right(6L),
+ Left((8L, (8L, 8, "Hello World"))),
+ Left((7L, (7L, 7, "Hello World"))),
+ Right(20L))
+
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
+ env.setStateBackend(getStateBackend)
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.clear
+
+ val t1 = env
+ .addSource[(Long, Int, String)](new EventTimeSourceFunction[(Long, Int, String)](data))
+ .toTable(tEnv).as('a, 'b, 'c)
+
+ tEnv.registerTable("T1", t1)
+
+ val sqlQuery = "SELECT " +
+ "c, a, " +
+ "count(a) OVER (PARTITION BY c ORDER BY RowTime() ROWS BETWEEN 2 preceding AND CURRENT ROW)" +
+ ", sum(a) OVER (PARTITION BY c ORDER BY RowTime() ROWS BETWEEN 2 preceding AND CURRENT ROW)" +
+ " from T1"
+
+ val result = tEnv.sql(sqlQuery).toDataStream[Row]
+ result.addSink(new StreamITCase.StringSink)
+ env.execute()
+
+ val expected = mutable.MutableList(
+ "Hello,1,1,1", "Hello,1,2,2", "Hello,1,3,3",
+ "Hello,2,3,4", "Hello,2,3,5","Hello,2,3,6",
+ "Hello,3,3,7", "Hello,4,3,9", "Hello,5,3,12",
+ "Hello,6,3,15",
+ "Hello World,7,1,7", "Hello World,7,2,14", "Hello World,7,3,21",
+ "Hello World,7,3,21", "Hello World,8,3,22", "Hello World,20,3,35")
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
+
+ @Test
+ def testBoundNonPartitionedEventTimeWindowWithRow(): Unit = {
+
+ val data = Seq(
+ Left((2L, (2L, 2, "Hello"))),
+ Left((2L, (2L, 2, "Hello"))),
+ Left((1L, (1L, 1, "Hello"))),
+ Left((1L, (1L, 1, "Hello"))),
+ Left((2L, (2L, 2, "Hello"))),
+ Left((1L, (1L, 1, "Hello"))),
+ Left((20L, (20L, 20, "Hello World"))), // early row
+ Right(3L),
+ Left((2L, (2L, 2, "Hello"))), // late row
+ Left((3L, (3L, 3, "Hello"))),
+ Left((4L, (4L, 4, "Hello"))),
+ Left((5L, (5L, 5, "Hello"))),
+ Left((6L, (6L, 6, "Hello"))),
+ Left((7L, (7L, 7, "Hello World"))),
+ Right(7L),
+ Left((9L, (9L, 9, "Hello World"))),
+ Left((8L, (8L, 8, "Hello World"))),
+ Left((8L, (8L, 8, "Hello World"))),
+ Right(20L))
+
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
+ env.setStateBackend(getStateBackend)
+ env.setParallelism(1)
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.clear
+
+ val t1 = env
+ .addSource[(Long, Int, String)](new EventTimeSourceFunction[(Long, Int, String)](data))
+ .toTable(tEnv).as('a, 'b, 'c)
+
+ tEnv.registerTable("T1", t1)
+
+ val sqlQuery = "SELECT " +
+ "c, a, " +
+ "count(a) OVER (ORDER BY RowTime() ROWS BETWEEN 2 preceding AND CURRENT ROW)," +
+ "sum(a) OVER (ORDER BY RowTime() ROWS BETWEEN 2 preceding AND CURRENT ROW)" +
+ "from T1"
+
+ val result = tEnv.sql(sqlQuery).toDataStream[Row]
+ result.addSink(new StreamITCase.StringSink)
+ env.execute()
+
+ val expected = mutable.MutableList(
+ "Hello,1,1,1", "Hello,1,2,2", "Hello,1,3,3",
+ "Hello,2,3,4", "Hello,2,3,5", "Hello,2,3,6",
+ "Hello,3,3,7",
+ "Hello,4,3,9", "Hello,5,3,12",
+ "Hello,6,3,15", "Hello World,7,3,18",
+ "Hello World,8,3,21", "Hello World,8,3,23",
+ "Hello World,9,3,25",
+ "Hello World,20,3,37")
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
+
/**
* All aggregates must be computed on the same window.
*/
@@ -317,4 +435,21 @@ class SqlITCase extends StreamingWithStateTestBase {
result.addSink(new StreamITCase.StringSink)
env.execute()
}
+
+}
+
+object SqlITCase {
+
+ class EventTimeSourceFunction[T](
+ dataWithTimestampList: Seq[Either[(Long, T), Long]]) extends SourceFunction[T] {
+ override def run(ctx: SourceContext[T]): Unit = {
+ dataWithTimestampList.foreach {
+ case Left(t) => ctx.collectWithTimestamp(t._2, t._1)
+ case Right(w) => ctx.emitWatermark(new Watermark(w))
+ }
+ }
+
+ override def cancel(): Unit = ???
+ }
+
}
http://git-wip-us.apache.org/repos/asf/flink/blob/7a9d39fe/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala
index a25e59c..9a425b3 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala
@@ -239,4 +239,59 @@ class WindowAggregateTest extends TableTestBase {
)
streamUtil.verifySql(sql, expected)
}
+
+ @Test
+ def testBoundPartitionedRowTimeWindowWithRow() = {
+ val sql = "SELECT " +
+ "c, " +
+ "count(a) OVER (PARTITION BY c ORDER BY RowTime() ROWS BETWEEN 5 preceding AND " +
+ "CURRENT ROW) as cnt1 " +
+ "from MyTable"
+
+ val expected =
+ unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamOverAggregate",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "a", "c", "ROWTIME() AS $2")
+ ),
+ term("partitionBy", "c"),
+ term("orderBy", "ROWTIME"),
+ term("rows", "BETWEEN 5 PRECEDING AND CURRENT ROW"),
+ term("select", "a", "c", "ROWTIME", "COUNT(a) AS w0$o0")
+ ),
+ term("select", "c", "w0$o0 AS $1")
+ )
+ streamUtil.verifySql(sql, expected)
+ }
+
+ @Test
+ def testBoundNonPartitionedRowTimeWindowWithRow() = {
+ val sql = "SELECT " +
+ "c, " +
+ "count(a) OVER (ORDER BY RowTime() ROWS BETWEEN 5 preceding AND " +
+ "CURRENT ROW) as cnt1 " +
+ "from MyTable"
+
+ val expected =
+ unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamOverAggregate",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "a", "c", "ROWTIME() AS $2")
+ ),
+ term("orderBy", "ROWTIME"),
+ term("rows", "BETWEEN 5 PRECEDING AND CURRENT ROW"),
+ term("select", "a", "c", "ROWTIME", "COUNT(a) AS w0$o0")
+ ),
+ term("select", "c", "w0$o0 AS $1")
+ )
+ streamUtil.verifySql(sql, expected)
+ }
}