You are viewing a plain-text version of this content. (The link to the canonical version was lost in the plain-text conversion.)
Posted to commits@flink.apache.org by fh...@apache.org on 2017/11/03 08:43:22 UTC

[3/3] flink git commit: [FLINK-7338] [table] Fix retrieval of OVER window lower bound.

[FLINK-7338] [table] Fix retrieval of OVER window lower bound.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/16b08821
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/16b08821
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/16b08821

Branch: refs/heads/master
Commit: 16b088218435dc64848ad641a383f9ce808f07c2
Parents: 78c8ea2
Author: Fabian Hueske <fh...@apache.org>
Authored: Thu Nov 2 23:09:36 2017 +0100
Committer: Fabian Hueske <fh...@apache.org>
Committed: Fri Nov 3 00:01:44 2017 +0100

----------------------------------------------------------------------
 .../flink/table/plan/nodes/OverAggregate.scala  |  21 ++--
 .../runtime/stream/sql/OverWindowITCase.scala   | 119 ++++++++++++-------
 2 files changed, 85 insertions(+), 55 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/16b08821/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala
index 87ebd86..f9bf803 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala
@@ -35,8 +35,8 @@ trait OverAggregate {
   }
 
   private[flink] def orderingToString(
-    inputType: RelDataType,
-    orderFields: java.util.List[RelFieldCollation]): String = {
+      inputType: RelDataType,
+      orderFields: java.util.List[RelFieldCollation]): String = {
 
     val inFields = inputType.getFieldList.asScala
 
@@ -48,9 +48,9 @@ trait OverAggregate {
   }
 
   private[flink] def windowRange(
-    logicWindow: Window,
-    overWindow: Group,
-    input: RelNode): String = {
+      logicWindow: Window,
+      overWindow: Group,
+      input: RelNode): String = {
     if (overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded) {
       s"BETWEEN ${getLowerBoundary(logicWindow, overWindow, input)} PRECEDING " +
           s"AND ${overWindow.upperBound}"
@@ -63,8 +63,7 @@ trait OverAggregate {
       inputType: RelDataType,
       constants: Seq[RexLiteral],
       rowType: RelDataType,
-      namedAggregates: Seq[CalcitePair[AggregateCall, String]])
-    : String = {
+      namedAggregates: Seq[CalcitePair[AggregateCall, String]]): String = {
 
     val inFields = inputType.getFieldNames.asScala
     val outFields = rowType.getFieldNames.asScala
@@ -97,12 +96,12 @@ trait OverAggregate {
   }
 
   private[flink] def getLowerBoundary(
-    logicWindow: Window,
-    overWindow: Group,
-    input: RelNode): Long = {
+      logicWindow: Window,
+      overWindow: Group,
+      input: RelNode): Long = {
 
     val ref: RexInputRef = overWindow.lowerBound.getOffset.asInstanceOf[RexInputRef]
-    val lowerBoundIndex = input.getRowType.getFieldCount - ref.getIndex
+    val lowerBoundIndex = ref.getIndex - input.getRowType.getFieldCount
     val lowerBound = logicWindow.constants.get(lowerBoundIndex).getValue2
     lowerBound match {
       case x: java.math.BigDecimal => x.asInstanceOf[java.math.BigDecimal].longValue()

http://git-wip-us.apache.org/repos/asf/flink/blob/16b08821/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala
index 4884513..9bfdc4c 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala
@@ -19,6 +19,7 @@
 package org.apache.flink.table.runtime.stream.sql
 
 import org.apache.flink.api.common.time.Time
+import org.apache.flink.api.java.tuple.Tuple1
 import org.apache.flink.api.scala._
 import org.apache.flink.streaming.api.TimeCharacteristic
 import org.apache.flink.streaming.api.functions.source.SourceFunction
@@ -28,6 +29,7 @@ import org.apache.flink.streaming.api.watermark.Watermark
 import org.apache.flink.table.api.scala._
 import org.apache.flink.table.runtime.utils.TimeTestUtil.EventTimeSourceFunction
 import org.apache.flink.table.api.{StreamQueryConfig, TableEnvironment}
+import org.apache.flink.table.functions.AggregateFunction
 import org.apache.flink.table.runtime.utils.{StreamITCase, StreamTestData, StreamingWithStateTestBase}
 import org.apache.flink.types.Row
 import org.junit.Assert._
@@ -293,13 +295,16 @@ class OverWindowITCase extends StreamingWithStateTestBase {
       .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)
 
     tEnv.registerTable("T1", t1)
+    tEnv.registerFunction("LTCNT", new LargerThanCount)
 
     val sqlQuery = "SELECT " +
       "  c, b, " +
+      "  LTCNT(a, CAST('4' AS BIGINT)) OVER (PARTITION BY c ORDER BY rowtime RANGE " +
+      "    BETWEEN INTERVAL '1' SECOND PRECEDING AND CURRENT ROW), " +
       "  COUNT(a) OVER (PARTITION BY c ORDER BY rowtime RANGE " +
       "    BETWEEN INTERVAL '1' SECOND PRECEDING AND CURRENT ROW), " +
       "  SUM(a) OVER (PARTITION BY c ORDER BY rowtime RANGE " +
-      "    BETWEEN INTERVAL '1' SECOND PRECEDING AND CURRENT ROW)" +
+      "    BETWEEN INTERVAL '1' SECOND PRECEDING AND CURRENT ROW) " +
       " FROM T1"
 
     val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
@@ -307,16 +312,17 @@ class OverWindowITCase extends StreamingWithStateTestBase {
     env.execute()
 
     val expected = List(
-      "Hello,1,1,1", "Hello,15,2,2", "Hello,16,3,3",
-      "Hello,2,6,9", "Hello,3,6,9", "Hello,2,6,9",
-      "Hello,3,4,9",
-      "Hello,4,2,7",
-      "Hello,5,2,9",
-      "Hello,6,2,11", "Hello,65,2,12",
-      "Hello,9,2,12", "Hello,9,2,12", "Hello,18,3,18",
-      "Hello World,7,1,7", "Hello World,17,3,21", "Hello World,77,3,21", "Hello World,18,1,7",
-      "Hello World,8,2,15",
-      "Hello World,20,1,20")
+      "Hello,1,0,1,1", "Hello,15,0,2,2", "Hello,16,0,3,3",
+      "Hello,2,0,6,9", "Hello,3,0,6,9", "Hello,2,0,6,9",
+      "Hello,3,0,4,9",
+      "Hello,4,0,2,7",
+      "Hello,5,1,2,9",
+      "Hello,6,2,2,11", "Hello,65,2,2,12",
+      "Hello,9,2,2,12", "Hello,9,2,2,12", "Hello,18,3,3,18",
+      "Hello World,7,1,1,7", "Hello World,17,3,3,21", "Hello World,77,3,3,21",
+      "Hello World,18,1,1,7",
+      "Hello World,8,2,2,15",
+      "Hello World,20,1,1,20")
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
 
@@ -354,9 +360,12 @@ class OverWindowITCase extends StreamingWithStateTestBase {
       .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)
 
     tEnv.registerTable("T1", t1)
+    tEnv.registerFunction("LTCNT", new LargerThanCount)
 
     val sqlQuery = "SELECT " +
       " c, a, " +
+      "  LTCNT(a, CAST('4' AS BIGINT)) " +
+      "    OVER (PARTITION BY c ORDER BY rowtime ROWS BETWEEN 2 PRECEDING AND CURRENT ROW), " +
       "  COUNT(a) " +
       "    OVER (PARTITION BY c ORDER BY rowtime ROWS BETWEEN 2 PRECEDING AND CURRENT ROW), " +
       "  SUM(a) " +
@@ -368,12 +377,12 @@ class OverWindowITCase extends StreamingWithStateTestBase {
     env.execute()
 
     val expected = List(
-      "Hello,1,1,1", "Hello,1,2,2", "Hello,1,3,3",
-      "Hello,2,3,4", "Hello,2,3,5", "Hello,2,3,6",
-      "Hello,3,3,7", "Hello,4,3,9", "Hello,5,3,12",
-      "Hello,6,3,15",
-      "Hello World,7,1,7", "Hello World,7,2,14", "Hello World,7,3,21",
-      "Hello World,7,3,21", "Hello World,8,3,22", "Hello World,20,3,35")
+      "Hello,1,0,1,1", "Hello,1,0,2,2", "Hello,1,0,3,3",
+      "Hello,2,0,3,4", "Hello,2,0,3,5", "Hello,2,0,3,6",
+      "Hello,3,0,3,7", "Hello,4,0,3,9", "Hello,5,1,3,12",
+      "Hello,6,2,3,15",
+      "Hello World,7,1,1,7", "Hello World,7,2,2,14", "Hello World,7,3,3,21",
+      "Hello World,7,3,3,21", "Hello World,8,3,3,22", "Hello World,20,3,3,35")
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
 
@@ -518,6 +527,8 @@ class OverWindowITCase extends StreamingWithStateTestBase {
     StreamITCase.clear
 
     val sqlQuery = "SELECT a, b, c, " +
+      "  LTCNT(b, CAST('4' AS BIGINT)) OVER(" +
+      "    PARTITION BY a ORDER BY rowtime RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), " +
       "  SUM(b) OVER (" +
       "    PARTITION BY a ORDER BY rowtime RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), " +
       "  COUNT(b) OVER (" +
@@ -552,25 +563,26 @@ class OverWindowITCase extends StreamingWithStateTestBase {
       .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)
 
     tEnv.registerTable("T1", t1)
+    tEnv.registerFunction("LTCNT", new LargerThanCount)
 
     val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
     result.addSink(new StreamITCase.StringSink[Row])
     env.execute()
 
     val expected = List(
-      "1,1,Hello,6,3,2,3,1",
-      "1,2,Hello,6,3,2,3,1",
-      "1,3,Hello world,6,3,2,3,1",
-      "1,1,Hi,7,4,1,3,1",
-      "2,1,Hello,1,1,1,1,1",
-      "2,2,Hello world,6,3,2,3,1",
-      "2,3,Hello world,6,3,2,3,1",
-      "1,4,Hello world,11,5,2,4,1",
-      "1,5,Hello world,29,8,3,7,1",
-      "1,6,Hello world,29,8,3,7,1",
-      "1,7,Hello world,29,8,3,7,1",
-      "2,4,Hello world,15,5,3,5,1",
-      "2,5,Hello world,15,5,3,5,1")
+      "1,1,Hello,0,6,3,2,3,1",
+      "1,2,Hello,0,6,3,2,3,1",
+      "1,3,Hello world,0,6,3,2,3,1",
+      "1,1,Hi,0,7,4,1,3,1",
+      "2,1,Hello,0,1,1,1,1,1",
+      "2,2,Hello world,0,6,3,2,3,1",
+      "2,3,Hello world,0,6,3,2,3,1",
+      "1,4,Hello world,0,11,5,2,4,1",
+      "1,5,Hello world,3,29,8,3,7,1",
+      "1,6,Hello world,3,29,8,3,7,1",
+      "1,7,Hello world,3,29,8,3,7,1",
+      "2,4,Hello world,1,15,5,3,5,1",
+      "2,5,Hello world,1,15,5,3,5,1")
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
 
@@ -583,6 +595,8 @@ class OverWindowITCase extends StreamingWithStateTestBase {
     StreamITCase.testResults = mutable.MutableList()
 
     val sqlQuery = "SELECT a, b, c, " +
+      "LTCNT(b, CAST('4' AS BIGINT)) over(" +
+      "partition by a order by rowtime rows between unbounded preceding and current row), " +
       "SUM(b) over (" +
       "partition by a order by rowtime rows between unbounded preceding and current row), " +
       "count(b) over (" +
@@ -618,26 +632,27 @@ class OverWindowITCase extends StreamingWithStateTestBase {
              .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)
 
     tEnv.registerTable("T1", t1)
+    tEnv.registerFunction("LTCNT", new LargerThanCount)
 
     val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
     result.addSink(new StreamITCase.StringSink[Row])
     env.execute()
 
     val expected = mutable.MutableList(
-      "1,2,Hello,2,1,2,2,2",
-      "1,3,Hello world,5,2,2,3,2",
-      "1,1,Hi,6,3,2,3,1",
-      "2,1,Hello,1,1,1,1,1",
-      "2,2,Hello world,3,2,1,2,1",
-      "3,1,Hello,1,1,1,1,1",
-      "3,2,Hello world,3,2,1,2,1",
-      "1,5,Hello world,11,4,2,5,1",
-      "1,6,Hello world,17,5,3,6,1",
-      "1,9,Hello world,26,6,4,9,1",
-      "1,8,Hello world,34,7,4,9,1",
-      "1,7,Hello world,41,8,5,9,1",
-      "2,5,Hello world,8,3,2,5,1",
-      "3,5,Hello world,8,3,2,5,1")
+      "1,2,Hello,0,2,1,2,2,2",
+      "1,3,Hello world,0,5,2,2,3,2",
+      "1,1,Hi,0,6,3,2,3,1",
+      "2,1,Hello,0,1,1,1,1,1",
+      "2,2,Hello world,0,3,2,1,2,1",
+      "3,1,Hello,0,1,1,1,1,1",
+      "3,2,Hello world,0,3,2,1,2,1",
+      "1,5,Hello world,1,11,4,2,5,1",
+      "1,6,Hello world,2,17,5,3,6,1",
+      "1,9,Hello world,3,26,6,4,9,1",
+      "1,8,Hello world,4,34,7,4,9,1",
+      "1,7,Hello world,5,41,8,5,9,1",
+      "2,5,Hello world,1,8,3,2,5,1",
+      "3,5,Hello world,1,8,3,2,5,1")
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
 
@@ -852,3 +867,19 @@ object OverWindowITCase {
     override def cancel(): Unit = ???
   }
 }
+
+/** Test UDAGG: counts rows whose first argument is strictly larger than the second argument. */
+class LargerThanCount extends AggregateFunction[Long, Tuple1[Long]] {  // count held in acc.f0
+
+  def accumulate(acc: Tuple1[Long], a: Long, b: Long): Unit = {  // counts the row iff a > b
+    if (a > b) acc.f0 += 1
+  }
+
+  def retract(acc: Tuple1[Long], a: Long, b: Long): Unit = {  // exact inverse of accumulate()
+    if (a > b) acc.f0 -= 1
+  }
+
+  override def createAccumulator(): Tuple1[Long] = Tuple1.of(0L)  // count starts at zero
+
+  override def getValue(acc: Tuple1[Long]): Long = acc.f0  // current count is the result
+}