You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by se...@apache.org on 2017/01/25 19:56:07 UTC

[2/4] flink git commit: [FLINK-5630] [streaming api] Followups to the AggregateFunction

http://git-wip-us.apache.org/repos/asf/flink/blob/1542260d/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/WindowTranslationTest.scala
----------------------------------------------------------------------
diff --git a/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/WindowTranslationTest.scala b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/WindowTranslationTest.scala
index 7235b22..bd3fe3d 100644
--- a/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/WindowTranslationTest.scala
+++ b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/WindowTranslationTest.scala
@@ -19,8 +19,8 @@
 package org.apache.flink.streaming.api.scala
 
 import org.apache.flink.api.common.ExecutionConfig
-import org.apache.flink.api.common.functions.{FoldFunction, ReduceFunction, RichFoldFunction, RichReduceFunction}
-import org.apache.flink.api.common.state.{FoldingStateDescriptor, ListStateDescriptor, ReducingStateDescriptor}
+import org.apache.flink.api.common.functions._
+import org.apache.flink.api.common.state.{AggregatingStateDescriptor, FoldingStateDescriptor, ListStateDescriptor, ReducingStateDescriptor}
 import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
 import org.apache.flink.api.java.functions.KeySelector
 import org.apache.flink.streaming.api.TimeCharacteristic
@@ -48,6 +48,10 @@ import org.junit.Test
   */
 class WindowTranslationTest {
 
+  // --------------------------------------------------------------------------
+  //  rich function tests
+  // --------------------------------------------------------------------------
+
   /**
     * .reduce() does not support [[RichReduceFunction]], since the reduce function is used
     * internally in a [[org.apache.flink.api.common.state.ReducingState]].
@@ -70,6 +74,25 @@ class WindowTranslationTest {
   }
 
   /**
+   * .aggregate() does not support [[RichAggregateFunction]], since the aggregate function
+   * is used internally in a [[org.apache.flink.api.common.state.AggregatingState]].
+   */
+  @Test(expected = classOf[UnsupportedOperationException])
+  def testAggregateWithRichFunctionFails() {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val source = env.fromElements(("hello", 1), ("hello", 2))
+
+    env.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime)
+
+    source
+      .keyBy(0)
+      .window(SlidingEventTimeWindows.of(Time.seconds(1), Time.milliseconds(100)))
+      .aggregate(new DummyRichAggregator())
+
+    fail("exception was not thrown")
+  }
+
+  /**
     * .fold() does not support [[RichFoldFunction]], since the reduce function is used internally
     * in a [[org.apache.flink.api.common.state.FoldingState]].
     */
@@ -90,6 +113,10 @@ class WindowTranslationTest {
     fail("exception was not thrown")
   }
 
+  // --------------------------------------------------------------------------
+  //  merging window checks
+  // --------------------------------------------------------------------------
+
   @Test
   def testSessionWithFoldFails() {
     // verify that fold does not work with merging windows
@@ -151,6 +178,10 @@ class WindowTranslationTest {
     fail("The trigger call should fail.")
   }
 
+  // --------------------------------------------------------------------------
+  //  reduce() tests
+  // --------------------------------------------------------------------------
+
   @Test
   def testReduceEventTime() {
     val env = StreamExecutionEnvironment.getExecutionEnvironment
@@ -412,6 +443,186 @@ class WindowTranslationTest {
       ("hello", 1))
   }
 
+  // --------------------------------------------------------------------------
+  //  aggregate() tests
+  // --------------------------------------------------------------------------
+
+  @Test
+  def testAggregateEventTime() {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStreamTimeCharacteristic(TimeCharacteristic.IngestionTime)
+
+    val source = env.fromElements(("hello", 1), ("hello", 2))
+
+    val window1 = source
+      .keyBy(_._1)
+      .window(SlidingEventTimeWindows.of(Time.seconds(1), Time.milliseconds(100)))
+      .aggregate(new DummyAggregator())
+
+    val transform = window1
+      .javaStream
+      .getTransformation
+      .asInstanceOf[OneInputTransformation[(String, Int), (String, Int)]]
+
+    val operator = transform.getOperator
+    assertTrue(operator.isInstanceOf[WindowOperator[_, _, _, _, _ <: Window]])
+
+    val winOperator = operator
+      .asInstanceOf[WindowOperator[String, (String, Int), _, (String, Int), _ <: Window]]
+
+    assertTrue(winOperator.getTrigger.isInstanceOf[EventTimeTrigger])
+    assertTrue(winOperator.getWindowAssigner.isInstanceOf[SlidingEventTimeWindows])
+    assertTrue(winOperator.getStateDescriptor.isInstanceOf[AggregatingStateDescriptor[_, _, _]])
+
+    processElementAndEnsureOutput[String, (String, Int), (String, Int)](
+      winOperator,
+      winOperator.getKeySelector,
+      BasicTypeInfo.STRING_TYPE_INFO,
+      ("hello", 1))
+  }
+
+  @Test
+  def testAggregateProcessingTime() {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime)
+
+    val source = env.fromElements(("hello", 1), ("hello", 2))
+
+    val window1 = source
+      .keyBy(_._1)
+      .window(SlidingProcessingTimeWindows.of(Time.seconds(1), Time.milliseconds(100)))
+      .aggregate(new DummyAggregator())
+
+    val transform = window1
+      .javaStream
+      .getTransformation
+      .asInstanceOf[OneInputTransformation[(String, Int), (String, Int)]]
+
+    val operator = transform.getOperator
+    assertTrue(operator.isInstanceOf[WindowOperator[_, _, _, _, _ <: Window]])
+
+    val winOperator = operator
+      .asInstanceOf[WindowOperator[String, (String, Int), _, (String, Int), _ <: Window]]
+
+    assertTrue(winOperator.getTrigger.isInstanceOf[ProcessingTimeTrigger])
+    assertTrue(winOperator.getWindowAssigner.isInstanceOf[SlidingProcessingTimeWindows])
+    assertTrue(winOperator.getStateDescriptor.isInstanceOf[AggregatingStateDescriptor[_, _, _]])
+
+    processElementAndEnsureOutput[String, (String, Int), (String, Int)](
+      winOperator,
+      winOperator.getKeySelector,
+      BasicTypeInfo.STRING_TYPE_INFO,
+      ("hello", 1))
+  }
+
+  @Test
+  def testAggregateWithWindowFunctionEventTime() {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStreamTimeCharacteristic(TimeCharacteristic.IngestionTime)
+
+    val source = env.fromElements(("hello", 1), ("hello", 2))
+
+    val window1 = source
+      .keyBy(_._1)
+      .window(TumblingEventTimeWindows.of(Time.seconds(1)))
+      .aggregate(new DummyAggregator(), new TestWindowFunction())
+
+    val transform = window1
+      .javaStream
+      .getTransformation
+      .asInstanceOf[OneInputTransformation[(String, Int), (String, Int)]]
+
+    val operator = transform.getOperator
+    assertTrue(operator.isInstanceOf[WindowOperator[_, _, _, _, _ <: Window]])
+
+    val winOperator = operator
+      .asInstanceOf[WindowOperator[String, (String, Int), _, (String, Int), _ <: Window]]
+
+    assertTrue(winOperator.getTrigger.isInstanceOf[EventTimeTrigger])
+    assertTrue(winOperator.getWindowAssigner.isInstanceOf[TumblingEventTimeWindows])
+    assertTrue(winOperator.getStateDescriptor.isInstanceOf[AggregatingStateDescriptor[_, _, _]])
+
+    processElementAndEnsureOutput[String, (String, Int), (String, Int)](
+      winOperator,
+      winOperator.getKeySelector,
+      BasicTypeInfo.STRING_TYPE_INFO,
+      ("hello", 1))
+  }
+
+  @Test
+  def testAggregateWithWindowFunctionProcessingTime() {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime)
+
+    val source = env.fromElements(("hello", 1), ("hello", 2))
+
+    val window1 = source
+      .keyBy(_._1)
+      .window(TumblingProcessingTimeWindows.of(Time.seconds(1)))
+      .aggregate(new DummyAggregator(), new TestWindowFunction())
+
+    val transform = window1
+      .javaStream
+      .getTransformation
+      .asInstanceOf[OneInputTransformation[(String, Int), (String, Int)]]
+
+    val operator = transform.getOperator
+    assertTrue(operator.isInstanceOf[WindowOperator[_, _, _, _, _ <: Window]])
+
+    val winOperator = operator
+      .asInstanceOf[WindowOperator[String, (String, Int), _, (String, Int), _ <: Window]]
+
+    assertTrue(winOperator.getTrigger.isInstanceOf[ProcessingTimeTrigger])
+    assertTrue(winOperator.getWindowAssigner.isInstanceOf[TumblingProcessingTimeWindows])
+    assertTrue(winOperator.getStateDescriptor.isInstanceOf[AggregatingStateDescriptor[_, _, _]])
+
+    processElementAndEnsureOutput[String, (String, Int), (String, Int)](
+      winOperator,
+      winOperator.getKeySelector,
+      BasicTypeInfo.STRING_TYPE_INFO,
+      ("hello", 1))
+  }
+
+  @Test
+  def testAggregateWithWindowFunctionEventTimeWithScalaFunction() {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStreamTimeCharacteristic(TimeCharacteristic.IngestionTime)
+
+    val source = env.fromElements(("hello", 1), ("hello", 2))
+
+    val window1 = source
+      .keyBy(_._1)
+      .window(TumblingEventTimeWindows.of(Time.seconds(1)))
+      .aggregate(new DummyAggregator(),
+        { (_, _, in: Iterable[(String, Int)], out: Collector[(String, Int)]) => {
+          in foreach { x => out.collect(x)}
+        } })
+
+    val transform = window1
+      .javaStream
+      .getTransformation
+      .asInstanceOf[OneInputTransformation[(String, Int), (String, Int)]]
+
+    val operator = transform.getOperator
+    assertTrue(operator.isInstanceOf[WindowOperator[_, _, _, _, _ <: Window]])
+
+    val winOperator = operator
+      .asInstanceOf[WindowOperator[String, (String, Int), _, (String, Int), _ <: Window]]
+
+    assertTrue(winOperator.getTrigger.isInstanceOf[EventTimeTrigger])
+    assertTrue(winOperator.getWindowAssigner.isInstanceOf[TumblingEventTimeWindows])
+    assertTrue(winOperator.getStateDescriptor.isInstanceOf[AggregatingStateDescriptor[_, _, _]])
+
+    processElementAndEnsureOutput[String, (String, Int), (String, Int)](
+      winOperator,
+      winOperator.getKeySelector,
+      BasicTypeInfo.STRING_TYPE_INFO,
+      ("hello", 1))
+  }
+
+  // --------------------------------------------------------------------------
+  //  fold() tests
+  // --------------------------------------------------------------------------
 
   @Test
   def testFoldEventTime() {
@@ -685,6 +896,9 @@ class WindowTranslationTest {
       ("hello", 1))
   }
 
+  // --------------------------------------------------------------------------
+  //  apply() tests
+  // --------------------------------------------------------------------------
 
   @Test
   def testApplyEventTime() {
@@ -1082,3 +1296,39 @@ class DummyFolder extends FoldFunction[(String, Int), (String, String, Int)] {
   }
 }
 
+class DummyAggregator extends AggregateFunction[(String, Int), (String, Int), (String, Int)] {
+
+  override def createAccumulator(): (String, Int) = ("", 0)
+
+  override def merge(a: (String, Int), b: (String, Int)): (String, Int) = a
+
+  override def getResult(accumulator: (String, Int)): (String, Int) = accumulator
+
+  override def add(value: (String, Int), accumulator: (String, Int)): Unit = ()
+}
+
+class DummyRichAggregator extends RichAggregateFunction[(String, Int), (String, Int), (String, Int)]
+{
+
+  override def createAccumulator(): (String, Int) = ("", 0)
+
+  override def merge(a: (String, Int), b: (String, Int)): (String, Int) = a
+
+  override def getResult(accumulator: (String, Int)): (String, Int) = accumulator
+
+  override def add(value: (String, Int), accumulator: (String, Int)): Unit = ()
+}
+
+class TestWindowFunction extends WindowFunction[(String, Int), (String, Int), String, TimeWindow] {
+
+  override def apply(
+      key: String,
+      window: TimeWindow,
+      input: Iterable[(String, Int)],
+      out: Collector[(String, Int)]): Unit = {
+
+    input.foreach(out.collect)
+  }
+}
+
+