Posted to commits@flink.apache.org by fh...@apache.org on 2017/03/09 18:12:18 UTC
flink git commit: [FLINK-5804] [table] Add support for procTime non-partitioned OVER RANGE BETWEEN UNBOUNDED PRECEDING aggregation to SQL.
Repository: flink
Updated Branches:
refs/heads/master 3fcc4e37c -> 7456d78d2
[FLINK-5804] [table] Add support for procTime non-partitioned OVER RANGE BETWEEN UNBOUNDED PRECEDING aggregation to SQL.
This closes #3491.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/7456d78d
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/7456d78d
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/7456d78d
Branch: refs/heads/master
Commit: 7456d78d271b217c80d46e24029c55741807e51d
Parents: 3fcc4e3
Author: 金竹 <ji...@alibaba-inc.com>
Authored: Wed Mar 8 10:52:43 2017 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Thu Mar 9 19:11:43 2017 +0100
----------------------------------------------------------------------
.../datastream/DataStreamOverAggregate.scala | 14 ++-
.../table/runtime/aggregate/AggregateUtil.scala | 27 +++--
...rtitionedProcessingOverProcessFunction.scala | 106 +++++++++++++++++++
.../table/api/scala/stream/sql/SqlITCase.scala | 53 ++++++++++
.../scala/stream/sql/WindowAggregateTest.scala | 54 ++++++++++
5 files changed, 243 insertions(+), 11 deletions(-)
----------------------------------------------------------------------
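For context, the feature enabled by this commit can be exercised from the Table API roughly as in the new ITCases further down. A minimal sketch follows; the imports and the sample `data` collection are illustrative assumptions, not part of the patch itself:

  // Sketch only: mirrors the ITCases added below in SqlITCase; the sample
  // collection and table/field names are placeholders for illustration.
  import org.apache.flink.api.scala._
  import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
  import org.apache.flink.table.api.TableEnvironment
  import org.apache.flink.table.api.scala._
  import org.apache.flink.types.Row

  val env = StreamExecutionEnvironment.getExecutionEnvironment
  val tEnv = TableEnvironment.getTableEnvironment(env)

  val data = Seq((1, 1L, "Hello"), (2, 2L, "Hello"), (3, 3L, "Hello World"))
  val t1 = env.fromCollection(data).toTable(tEnv).as('a, 'b, 'c)
  tEnv.registerTable("T1", t1)

  // Non-partitioned (no PARTITION BY) processing-time OVER RANGE aggregation,
  // the case this commit adds support for:
  val sqlQuery =
    "SELECT c, " +
    "count(a) OVER (ORDER BY ProcTime() RANGE UNBOUNDED preceding) AS cnt1, " +
    "sum(a) OVER (ORDER BY ProcTime() RANGE UNBOUNDED preceding) AS cnt2 " +
    "FROM T1"

  // The plan routes through DataStreamOverAggregate and, lacking a key,
  // runs the new process function with parallelism 1.
  val result = tEnv.sql(sqlQuery).toDataStream[Row]
  result.print()
  env.execute()
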
http://git-wip-us.apache.org/repos/asf/flink/blob/7456d78d/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
index db115e0..34b3b0f 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
@@ -143,10 +143,18 @@ class DataStreamOverAggregate(
.name(aggOpName)
.asInstanceOf[DataStream[Row]]
}
- // global non-partitioned aggregation
+ // non-partitioned aggregation
else {
- throw TableException(
- "Non-partitioned processing time OVER aggregation is not supported yet.")
+ val processFunction = AggregateUtil.CreateUnboundedProcessingOverProcessFunction(
+ namedAggregates,
+ inputType,
+ false)
+
+ inputDS
+ .process(processFunction).setParallelism(1).setMaxParallelism(1)
+ .returns(rowTypeInfo)
+ .name(aggOpName)
+ .asInstanceOf[DataStream[Row]]
}
result
}
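A note on the new branch above: without a PARTITION BY clause there is no key on which to distribute rows, so the operator is pinned to parallelism 1 (and max parallelism 1), since a single task must see every record to maintain the global accumulators. As a consequence, the accumulators are kept in operator ListState via the CheckpointedFunction interface (see UnboundedNonPartitionedProcessingOverProcessFunction below) rather than in keyed state as in the partitioned case.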
http://git-wip-us.apache.org/repos/asf/flink/blob/7456d78d/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index 6555143..b6b3445 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -53,15 +53,18 @@ object AggregateUtil {
type JavaList[T] = java.util.List[T]
/**
- * Create an [[ProcessFunction]] to evaluate final aggregate value.
+ * Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] to evaluate final
+ * aggregate value.
*
* @param namedAggregates List of calls to aggregate functions and their output field names
* @param inputType Input row type
- * @return [[UnboundedProcessingOverProcessFunction]]
+ * @param isPartitioned Flag to indicate whether the input is partitioned or not
+ * @return [[org.apache.flink.streaming.api.functions.ProcessFunction]]
*/
private[flink] def CreateUnboundedProcessingOverProcessFunction(
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
- inputType: RelDataType): UnboundedProcessingOverProcessFunction = {
+ inputType: RelDataType,
+ isPartitioned: Boolean = true): ProcessFunction[Row, Row] = {
val (aggFields, aggregates) =
transformToAggregateFunctions(
@@ -72,11 +75,19 @@ object AggregateUtil {
val aggregationStateType: RowTypeInfo =
createDataSetAggregateBufferDataType(Array(), aggregates, inputType)
- new UnboundedProcessingOverProcessFunction(
- aggregates,
- aggFields,
- inputType.getFieldCount,
- aggregationStateType)
+ if (isPartitioned) {
+ new UnboundedProcessingOverProcessFunction(
+ aggregates,
+ aggFields,
+ inputType.getFieldCount,
+ aggregationStateType)
+ } else {
+ new UnboundedNonPartitionedProcessingOverProcessFunction(
+ aggregates,
+ aggFields,
+ inputType.getFieldCount,
+ aggregationStateType)
+ }
}
/**
http://git-wip-us.apache.org/repos/asf/flink/blob/7456d78d/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedNonPartitionedProcessingOverProcessFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedNonPartitionedProcessingOverProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedNonPartitionedProcessingOverProcessFunction.scala
new file mode 100644
index 0000000..51c8315
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedNonPartitionedProcessingOverProcessFunction.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.aggregate
+
+import org.apache.flink.api.common.state.{ListState, ListStateDescriptor}
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.runtime.state.{FunctionInitializationContext, FunctionSnapshotContext}
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction
+import org.apache.flink.streaming.api.functions.ProcessFunction
+import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.types.Row
+import org.apache.flink.util.{Collector, Preconditions}
+
+/**
+ * Process Function used for the aggregate in
+ * [[org.apache.flink.streaming.api.datastream.DataStream]]
+ *
+ * @param aggregates the list of all [[org.apache.flink.table.functions.AggregateFunction]]
+ * used for this aggregation
+ * @param aggFields the position (in the input Row) of the input value for each aggregate
+ */
+class UnboundedNonPartitionedProcessingOverProcessFunction(
+ private val aggregates: Array[AggregateFunction[_]],
+ private val aggFields: Array[Int],
+ private val forwardedFieldCount: Int,
+ private val aggregationStateType: RowTypeInfo)
+ extends ProcessFunction[Row, Row] with CheckpointedFunction {
+
+ Preconditions.checkNotNull(aggregates)
+ Preconditions.checkNotNull(aggFields)
+ Preconditions.checkArgument(aggregates.length == aggFields.length)
+
+ private var accumulators: Row = _
+ private var output: Row = _
+ private var state: ListState[Row] = null
+
+ override def open(config: Configuration) {
+ output = new Row(forwardedFieldCount + aggregates.length)
+ if (null == accumulators) {
+ val it = state.get().iterator()
+ if (it.hasNext) {
+ accumulators = it.next()
+ } else {
+ accumulators = new Row(aggregates.length)
+ var i = 0
+ while (i < aggregates.length) {
+ accumulators.setField(i, aggregates(i).createAccumulator())
+ i += 1
+ }
+ }
+ }
+ }
+
+ override def processElement(
+ input: Row,
+ ctx: ProcessFunction[Row, Row]#Context,
+ out: Collector[Row]): Unit = {
+
+ var i = 0
+ while (i < forwardedFieldCount) {
+ output.setField(i, input.getField(i))
+ i += 1
+ }
+
+ i = 0
+ while (i < aggregates.length) {
+ val index = forwardedFieldCount + i
+ val accumulator = accumulators.getField(i).asInstanceOf[Accumulator]
+ aggregates(i).accumulate(accumulator, input.getField(aggFields(i)))
+ output.setField(index, aggregates(i).getValue(accumulator))
+ i += 1
+ }
+
+ out.collect(output)
+ }
+
+ override def snapshotState(context: FunctionSnapshotContext): Unit = {
+ state.clear()
+ if (null != accumulators) {
+ state.add(accumulators)
+ }
+ }
+
+ override def initializeState(context: FunctionInitializationContext): Unit = {
+ val stateSerializer =
+ aggregationStateType.createSerializer(getRuntimeContext.getExecutionConfig)
+ val accumulatorsDescriptor = new ListStateDescriptor[Row]("overState", stateSerializer)
+ state = context.getOperatorStateStore.getOperatorState(accumulatorsDescriptor)
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/7456d78d/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
index cf8e442..d5a140a 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
@@ -240,6 +240,59 @@ class SqlITCase extends StreamingWithStateTestBase {
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
+ @Test
+ def testUnboundNonPartitionedProcessingWindowWithRange(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ env.setStateBackend(getStateBackend)
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.testResults = mutable.MutableList()
+
+ // for sum aggregation ensure that every time the order of each element is consistent
+ env.setParallelism(1)
+
+ val t1 = env.fromCollection(data).toTable(tEnv).as('a, 'b, 'c)
+
+ tEnv.registerTable("T1", t1)
+
+ val sqlQuery = "SELECT " +
+ "c, " +
+ "count(a) OVER (ORDER BY ProcTime() RANGE UNBOUNDED preceding) as cnt1, " +
+ "sum(a) OVER (ORDER BY ProcTime() RANGE UNBOUNDED preceding) as cnt2 " +
+ "from T1"
+
+ val result = tEnv.sql(sqlQuery).toDataStream[Row]
+ result.addSink(new StreamITCase.StringSink)
+ env.execute()
+
+ val expected = mutable.MutableList(
+ "Hello World,7,28", "Hello World,8,36", "Hello World,9,56",
+ "Hello,1,1", "Hello,2,3", "Hello,3,6", "Hello,4,10", "Hello,5,15", "Hello,6,21")
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
+
+ @Test
+ def testUnboundNonPartitionedProcessingWindowWithRow(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ env.setStateBackend(getStateBackend)
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.testResults = mutable.MutableList()
+
+ val t1 = env.fromCollection(data).toTable(tEnv).as('a, 'b, 'c)
+
+ tEnv.registerTable("T1", t1)
+
+ val sqlQuery = "SELECT " +
+ "count(a) OVER (ORDER BY ProcTime() ROWS BETWEEN UNBOUNDED preceding AND CURRENT ROW)" +
+ "from T1"
+
+ val result = tEnv.sql(sqlQuery).toDataStream[Row]
+ result.addSink(new StreamITCase.StringSink)
+ env.execute()
+
+ val expected = mutable.MutableList("1", "2", "3", "4", "5", "6", "7", "8", "9")
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
+
/**
* All aggregates must be computed on the same window.
*/
http://git-wip-us.apache.org/repos/asf/flink/blob/7456d78d/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala
index 85bc5a7..2781fb8 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala
@@ -185,4 +185,58 @@ class WindowAggregateTest extends TableTestBase {
)
streamUtil.verifySql(sql, expected)
}
+
+ @Test
+ def testUnboundNonPartitionedProcessingWindowWithRange() = {
+ val sql = "SELECT " +
+ "c, " +
+ "count(a) OVER (ORDER BY ProcTime() RANGE UNBOUNDED preceding) as cnt1, " +
+ "sum(a) OVER (ORDER BY ProcTime() RANGE UNBOUNDED preceding) as cnt2 " +
+ "from MyTable"
+
+ val expected =
+ unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamOverAggregate",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "a", "c", "PROCTIME() AS $2")
+ ),
+ term("orderBy", "PROCTIME"),
+ term("range", "BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"),
+ term("select", "a", "c", "PROCTIME", "COUNT(a) AS w0$o0", "$SUM0(a) AS w0$o1")
+ ),
+ term("select", "c", "w0$o0 AS cnt1", "CASE(>(w0$o0, 0)", "CAST(w0$o1), null) AS cnt2")
+ )
+ streamUtil.verifySql(sql, expected)
+ }
+
+ @Test
+ def testUnboundNonPartitionedProcessingWindowWithRow() = {
+ val sql = "SELECT " +
+ "c, " +
+ "count(a) OVER (ORDER BY ProcTime() ROWS BETWEEN UNBOUNDED preceding AND " +
+ "CURRENT ROW) as cnt1 " +
+ "from MyTable"
+
+ val expected =
+ unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamOverAggregate",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "a", "c", "PROCTIME() AS $2")
+ ),
+ term("orderBy", "PROCTIME"),
+ term("rows", "BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"),
+ term("select", "a", "c", "PROCTIME", "COUNT(a) AS w0$o0")
+ ),
+ term("select", "c", "w0$o0 AS $1")
+ )
+ streamUtil.verifySql(sql, expected)
+ }
}