Posted to commits@flink.apache.org by fh...@apache.org on 2017/03/09 18:12:18 UTC

flink git commit: [FLINK-5804] [table] Add support for procTime non-partitioned OVER RANGE BETWEEN UNBOUNDED PRECEDING aggregation to SQL.

Repository: flink
Updated Branches:
  refs/heads/master 3fcc4e37c -> 7456d78d2


[FLINK-5804] [table] Add support for procTime non-partitioned OVER RANGE BETWEEN UNBOUNDED PRECEDING aggregation to SQL.

This closes #3491.
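
For context, the query shape this commit enables (taken from the new tests below) is an OVER aggregate without a PARTITION BY clause, ordered on processing time:

  SELECT
    c,
    COUNT(a) OVER (ORDER BY ProcTime() RANGE UNBOUNDED PRECEDING) AS cnt1,
    SUM(a)   OVER (ORDER BY ProcTime() RANGE UNBOUNDED PRECEDING) AS cnt2
  FROM T1

Because there is no PARTITION BY, the aggregates run over the entire stream, which is why the planner pins the operator to parallelism 1 (see the first diff below).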


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/7456d78d
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/7456d78d
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/7456d78d

Branch: refs/heads/master
Commit: 7456d78d271b217c80d46e24029c55741807e51d
Parents: 3fcc4e3
Author: 金竹 <ji...@alibaba-inc.com>
Authored: Wed Mar 8 10:52:43 2017 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Thu Mar 9 19:11:43 2017 +0100

----------------------------------------------------------------------
 .../datastream/DataStreamOverAggregate.scala    |  14 ++-
 .../table/runtime/aggregate/AggregateUtil.scala |  27 +++--
 ...rtitionedProcessingOverProcessFunction.scala | 106 +++++++++++++++++++
 .../table/api/scala/stream/sql/SqlITCase.scala  |  53 ++++++++++
 .../scala/stream/sql/WindowAggregateTest.scala  |  54 ++++++++++
 5 files changed, 243 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/7456d78d/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
index db115e0..34b3b0f 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
@@ -143,10 +143,18 @@ class DataStreamOverAggregate(
           .name(aggOpName)
           .asInstanceOf[DataStream[Row]]
         }
-        // global non-partitioned aggregation
+        // non-partitioned aggregation
         else {
-          throw TableException(
-            "Non-partitioned processing time OVER aggregation is not supported yet.")
+          val processFunction = AggregateUtil.CreateUnboundedProcessingOverProcessFunction(
+            namedAggregates,
+            inputType,
+            false)
+
+          inputDS
+            .process(processFunction).setParallelism(1).setMaxParallelism(1)
+            .returns(rowTypeInfo)
+            .name(aggOpName)
+            .asInstanceOf[DataStream[Row]]
         }
     result
   }
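
Note on the change above: without PARTITION BY, every row must pass through a single operator instance that holds the global accumulators, so the plan forces parallelism 1. A minimal sketch of the two execution paths (variable names are illustrative, not from the commit):

  // partitioned OVER: keyed state, runs in parallel per key
  val partitioned: DataStream[Row] = inputDS
    .keyBy(partitionKeys: _*)
    .process(partitionedProcessFunction)
    .returns(rowTypeInfo)

  // non-partitioned OVER: one global instance sees every row
  val global: DataStream[Row] = inputDS
    .process(nonPartitionedProcessFunction)
    .setParallelism(1)     // a single instance holds the global accumulators
    .setMaxParallelism(1)  // and the operator may never be rescaled
    .returns(rowTypeInfo)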

http://git-wip-us.apache.org/repos/asf/flink/blob/7456d78d/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index 6555143..b6b3445 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -53,15 +53,18 @@ object AggregateUtil {
   type JavaList[T] = java.util.List[T]
 
   /**
-    * Create an [[ProcessFunction]] to evaluate final aggregate value.
+    * Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] to evaluate final
+    * aggregate value.
     *
     * @param namedAggregates List of calls to aggregate functions and their output field names
     * @param inputType Input row type
-    * @return [[UnboundedProcessingOverProcessFunction]]
+    * @param isPartitioned Flag to indicate whether the input is partitioned or not
+    * @return [[org.apache.flink.streaming.api.functions.ProcessFunction]]
     */
   private[flink] def CreateUnboundedProcessingOverProcessFunction(
     namedAggregates: Seq[CalcitePair[AggregateCall, String]],
-    inputType: RelDataType): UnboundedProcessingOverProcessFunction = {
+    inputType: RelDataType,
+    isPartitioned: Boolean = true): ProcessFunction[Row, Row] = {
 
     val (aggFields, aggregates) =
       transformToAggregateFunctions(
@@ -72,11 +75,19 @@ object AggregateUtil {
     val aggregationStateType: RowTypeInfo =
       createDataSetAggregateBufferDataType(Array(), aggregates, inputType)
 
-    new UnboundedProcessingOverProcessFunction(
-      aggregates,
-      aggFields,
-      inputType.getFieldCount,
-      aggregationStateType)
+    if (isPartitioned) {
+      new UnboundedProcessingOverProcessFunction(
+        aggregates,
+        aggFields,
+        inputType.getFieldCount,
+        aggregationStateType)
+    } else {
+      new UnboundedNonPartitionedProcessingOverProcessFunction(
+        aggregates,
+        aggFields,
+        inputType.getFieldCount,
+        aggregationStateType)
+    }
   }
 
   /**
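
Usage note for the factory above (call sites abridged; see the first diff): the partitioned path keeps the default flag, while the non-partitioned planner branch passes false:

  // partitioned (default) => UnboundedProcessingOverProcessFunction
  val keyedFn = AggregateUtil.CreateUnboundedProcessingOverProcessFunction(
    namedAggregates, inputType)

  // non-partitioned => UnboundedNonPartitionedProcessingOverProcessFunction
  val globalFn = AggregateUtil.CreateUnboundedProcessingOverProcessFunction(
    namedAggregates, inputType, isPartitioned = false)

Both calls now return the common supertype ProcessFunction[Row, Row], so the caller no longer depends on the concrete class.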

http://git-wip-us.apache.org/repos/asf/flink/blob/7456d78d/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedNonPartitionedProcessingOverProcessFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedNonPartitionedProcessingOverProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedNonPartitionedProcessingOverProcessFunction.scala
new file mode 100644
index 0000000..51c8315
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedNonPartitionedProcessingOverProcessFunction.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.aggregate
+
+import org.apache.flink.api.common.state.{ListState, ListStateDescriptor}
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.runtime.state.{FunctionInitializationContext, FunctionSnapshotContext}
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction
+import org.apache.flink.streaming.api.functions.ProcessFunction
+import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.types.Row
+import org.apache.flink.util.{Collector, Preconditions}
+
+/**
+  * Process Function used for the aggregate in
+  * [[org.apache.flink.streaming.api.datastream.DataStream]]
+  *
+  * @param aggregates the list of all [[org.apache.flink.table.functions.AggregateFunction]]
+  *                   used for this aggregation
+  * @param aggFields  the position (in the input Row) of the input value for each aggregate
+  */
+class UnboundedNonPartitionedProcessingOverProcessFunction(
+    private val aggregates: Array[AggregateFunction[_]],
+    private val aggFields: Array[Int],
+    private val forwardedFieldCount: Int,
+    private val aggregationStateType: RowTypeInfo)
+  extends ProcessFunction[Row, Row] with CheckpointedFunction {
+
+  Preconditions.checkNotNull(aggregates)
+  Preconditions.checkNotNull(aggFields)
+  Preconditions.checkArgument(aggregates.length == aggFields.length)
+
+  private var accumulators: Row = _
+  private var output: Row = _
+  private var state: ListState[Row] = null
+
+  override def open(config: Configuration) {
+    output = new Row(forwardedFieldCount + aggregates.length)
+    if (null == accumulators) {
+      val it = state.get().iterator()
+      if (it.hasNext) {
+        accumulators = it.next()
+      } else {
+        accumulators = new Row(aggregates.length)
+        var i = 0
+        while (i < aggregates.length) {
+          accumulators.setField(i, aggregates(i).createAccumulator())
+          i += 1
+        }
+      }
+    }
+  }
+
+  override def processElement(
+    input: Row,
+    ctx: ProcessFunction[Row, Row]#Context,
+    out: Collector[Row]): Unit = {
+
+    var i = 0
+    while (i < forwardedFieldCount) {
+      output.setField(i, input.getField(i))
+      i += 1
+    }
+
+    i = 0
+    while (i < aggregates.length) {
+      val index = forwardedFieldCount + i
+      val accumulator = accumulators.getField(i).asInstanceOf[Accumulator]
+      aggregates(i).accumulate(accumulator, input.getField(aggFields(i)))
+      output.setField(index, aggregates(i).getValue(accumulator))
+      i += 1
+    }
+
+    out.collect(output)
+  }
+
+  override def snapshotState(context: FunctionSnapshotContext): Unit = {
+    state.clear()
+    if (null != accumulators) {
+      state.add(accumulators)
+    }
+  }
+
+  override def initializeState(context: FunctionInitializationContext): Unit = {
+    val stateSerializer =
+      aggregationStateType.createSerializer(getRuntimeContext.getExecutionConfig)
+    val accumulatorsDescriptor = new ListStateDescriptor[Row]("overState", stateSerializer)
+    state = context.getOperatorStateStore.getOperatorState(accumulatorsDescriptor)
+  }
+}
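
The new function keeps its single accumulator row in operator state (a ListState with at most one element): snapshotState() replaces the stored row on every checkpoint, initializeState() re-acquires the state handle, and open() restores the accumulators from it after a failure. A minimal, self-contained sketch of the same pattern (a hypothetical running counter, not part of this commit):

  import org.apache.flink.api.common.state.{ListState, ListStateDescriptor}
  import org.apache.flink.runtime.state.{FunctionInitializationContext, FunctionSnapshotContext}
  import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction
  import org.apache.flink.streaming.api.functions.ProcessFunction
  import org.apache.flink.util.Collector

  class RunningCountFunction extends ProcessFunction[String, Long]
      with CheckpointedFunction {

    private var count: Long = 0L
    private var state: ListState[java.lang.Long] = _

    override def processElement(
        in: String,
        ctx: ProcessFunction[String, Long]#Context,
        out: Collector[Long]): Unit = {
      count += 1
      out.collect(count)  // emit the running count for every element
    }

    override def snapshotState(context: FunctionSnapshotContext): Unit = {
      state.clear()       // keep exactly one element in the list state
      state.add(count)
    }

    override def initializeState(context: FunctionInitializationContext): Unit = {
      val descriptor =
        new ListStateDescriptor[java.lang.Long]("countState", classOf[java.lang.Long])
      state = context.getOperatorStateStore.getOperatorState(descriptor)
      val it = state.get().iterator()
      if (it.hasNext) {
        count = it.next() // restore the last checkpointed value
      }
    }
  }

Like the function in the diff, such an operator only works correctly with parallelism 1, so that a single instance owns the state.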

http://git-wip-us.apache.org/repos/asf/flink/blob/7456d78d/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
index cf8e442..d5a140a 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
@@ -240,6 +240,59 @@ class SqlITCase extends StreamingWithStateTestBase {
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
 
+  @Test
+  def testUnboundNonPartitionedProcessingWindowWithRange(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.testResults = mutable.MutableList()
+
+    // for sum aggregation ensure that every time the order of each element is consistent
+    env.setParallelism(1)
+
+    val t1 = env.fromCollection(data).toTable(tEnv).as('a, 'b, 'c)
+
+    tEnv.registerTable("T1", t1)
+
+    val sqlQuery = "SELECT " +
+      "c, " +
+      "count(a) OVER (ORDER BY ProcTime()  RANGE UNBOUNDED preceding) as cnt1, " +
+      "sum(a) OVER (ORDER BY ProcTime() RANGE UNBOUNDED preceding) as cnt2 " +
+      "from T1"
+
+    val result = tEnv.sql(sqlQuery).toDataStream[Row]
+    result.addSink(new StreamITCase.StringSink)
+    env.execute()
+
+    val expected = mutable.MutableList(
+      "Hello World,7,28", "Hello World,8,36", "Hello World,9,56",
+      "Hello,1,1", "Hello,2,3", "Hello,3,6", "Hello,4,10", "Hello,5,15", "Hello,6,21")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
+
+  @Test
+  def testUnboundNonPartitionedProcessingWindowWithRow(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.testResults = mutable.MutableList()
+
+    val t1 = env.fromCollection(data).toTable(tEnv).as('a, 'b, 'c)
+
+    tEnv.registerTable("T1", t1)
+
+    val sqlQuery = "SELECT " +
+      "count(a) OVER (ORDER BY ProcTime() ROWS BETWEEN UNBOUNDED preceding AND CURRENT ROW)" +
+      "from T1"
+
+    val result = tEnv.sql(sqlQuery).toDataStream[Row]
+    result.addSink(new StreamITCase.StringSink)
+    env.execute()
+
+    val expected = mutable.MutableList("1", "2", "3", "4", "5", "6", "7", "8", "9")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
+
   /**
     *  All aggregates must be computed on the same window.
     */
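
Reading the expected rows of the first test above (RANGE): since the OVER window is not partitioned, cnt1 counts all rows seen so far on the single stream and cnt2 is the running sum of a, independent of the value of c. A worked trace, derived from the expected list itself:

  row  c            cnt1  cnt2
  1    Hello        1     1
  2    Hello        2     3    (1 + 2)
  ...
  6    Hello        6     21
  7    Hello World  7     28   (21 + 7)
  8    Hello World  8     36   (28 + 8)
  9    Hello World  9     56   (36 + 20)

The first test sets env.setParallelism(1) so the element order, and hence each intermediate sum, is deterministic; the second (ROWS) test needs no such pinning because a pure count is insensitive to the order in which elements arrive.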

http://git-wip-us.apache.org/repos/asf/flink/blob/7456d78d/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala
index 85bc5a7..2781fb8 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala
@@ -185,4 +185,58 @@ class WindowAggregateTest extends TableTestBase {
       )
     streamUtil.verifySql(sql, expected)
   }
+
+  @Test
+  def testUnboundNonPartitionedProcessingWindowWithRange() = {
+    val sql = "SELECT " +
+      "c, " +
+      "count(a) OVER (ORDER BY ProcTime() RANGE UNBOUNDED preceding) as cnt1, " +
+      "sum(a) OVER (ORDER BY ProcTime() RANGE UNBOUNDED preceding) as cnt2 " +
+      "from MyTable"
+
+    val expected =
+      unaryNode(
+        "DataStreamCalc",
+        unaryNode(
+          "DataStreamOverAggregate",
+          unaryNode(
+            "DataStreamCalc",
+            streamTableNode(0),
+            term("select", "a", "c", "PROCTIME() AS $2")
+          ),
+          term("orderBy", "PROCTIME"),
+          term("range", "BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"),
+          term("select", "a", "c", "PROCTIME", "COUNT(a) AS w0$o0", "$SUM0(a) AS w0$o1")
+        ),
+        term("select", "c", "w0$o0 AS cnt1", "CASE(>(w0$o0, 0)", "CAST(w0$o1), null) AS cnt2")
+      )
+    streamUtil.verifySql(sql, expected)
+  }
+
+  @Test
+  def testUnboundNonPartitionedProcessingWindowWithRow() = {
+    val sql = "SELECT " +
+      "c, " +
+      "count(a) OVER (ORDER BY ProcTime() ROWS BETWEEN UNBOUNDED preceding AND " +
+      "CURRENT ROW) as cnt1 " +
+      "from MyTable"
+
+    val expected =
+      unaryNode(
+        "DataStreamCalc",
+        unaryNode(
+          "DataStreamOverAggregate",
+          unaryNode(
+            "DataStreamCalc",
+            streamTableNode(0),
+            term("select", "a", "c", "PROCTIME() AS $2")
+          ),
+          term("orderBy", "PROCTIME"),
+          term("rows", "BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"),
+          term("select", "a", "c", "PROCTIME", "COUNT(a) AS w0$o0")
+        ),
+        term("select", "c", "w0$o0 AS $1")
+      )
+    streamUtil.verifySql(sql, expected)
+  }
 }
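
One detail worth noting in the first expected plan above: SUM(a) does not survive the optimization directly. Calcite rewrites it into the NULL-safe pair COUNT(a)/$SUM0(a), which is why the DataStreamOverAggregate computes w0$o0 and w0$o1 and the outer DataStreamCalc recombines them, roughly:

  SUM(a)  =>  CASE WHEN COUNT(a) > 0 THEN CAST($SUM0(a)) ELSE NULL END

so that a window containing no rows would yield NULL instead of 0, matching standard SQL semantics.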