Posted to commits@flink.apache.org by ji...@apache.org on 2018/12/17 06:13:11 UTC

[flink] branch master updated: [FLINK-11074] [table][tests] Enable harness tests with RocksdbStateBackend and add harness tests for CollectAggFunction

This is an automated email from the ASF dual-hosted git repository.

jincheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 9a45fca  [FLINK-11074] [table][tests] Enable harness tests with RocksdbStateBackend and add harness tests for CollectAggFunction
9a45fca is described below

commit 9a45fcabc51fc479af48c024c456aea548620bd2
Author: Dian Fu <fu...@alibaba-inc.com>
AuthorDate: Thu Dec 6 20:25:45 2018 +0800

    [FLINK-11074] [table][tests] Enable harness tests with RocksdbStateBackend and add harness tests for CollectAggFunction
    
    This closes #7253
---
 .../aggfunctions/CollectAggFunction.scala          |  10 +-
 .../runtime/harness/AggFunctionHarnessTest.scala   | 110 +++++++++++++++++++++
 .../table/runtime/harness/HarnessTestBase.scala    |  88 +++++++++++++++--
 3 files changed, 195 insertions(+), 13 deletions(-)
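
For orientation before the diffs: the commit (1) guards CollectAggFunction.retract against map entries that have already been removed, for example by state TTL cleanup, and (2) extends HarnessTestBase so that keyed-operator harness tests can run on the RocksDB state backend and inspect the operator's state views. The harness pattern the new test builds on is, in outline, the following minimal sketch (myOperator, myKeySelector and the key type are placeholders, not the actual generated aggregation):

    // Minimal sketch of the keyed harness pattern (placeholders, not the real operator).
    val testHarness = new KeyedOneInputStreamOperatorTestHarness[String, CRow, CRow](
      myOperator, myKeySelector, Types.STRING)
    testHarness.setStateBackend(getStateBackend)  // e.g. a RocksDBStateBackend
    testHarness.open()
    testHarness.processElement(new StreamRecord(CRow(1: JInt, "aaa"), 1))
    // ... compare testHarness.getOutput against an expected queue ...
    testHarness.close()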

diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala
index b10be61..5186d66 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala
@@ -112,10 +112,12 @@ class CollectAggFunction[E](valueTypeInfo: TypeInformation[_])
   def retract(acc: CollectAccumulator[E], value: E): Unit = {
     if (value != null) {
       val count = acc.map.get(value)
-      if (count == 1) {
-        acc.map.remove(value)
-      } else {
-        acc.map.put(value, count - 1)
+      if (count != null) {
+        if (count == 1) {
+          acc.map.remove(value)
+        } else {
+          acc.map.put(value, count - 1)
+        }
       }
     }
   }
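
The new null guard matters because MapView.get returns null when the entry is absent: a retraction for a value whose map entry was already removed (for instance by state TTL cleanup) used to fall into the else branch, where `count - 1` unboxes the null Integer and throws a NullPointerException. A standalone sketch of the failure mode, using a plain java.util.HashMap in place of the MapView purely for illustration:

    import java.lang.{Integer => JInt}
    import java.util.{HashMap => JHashMap}

    val map = new JHashMap[String, JInt]()
    val count: JInt = map.get("gone")   // entry already cleaned up -> null
    if (count != null) {                // the guard added by this commit
      if (count == 1) map.remove("gone") else map.put("gone", count - 1)
    }
    // Without the guard, `count == 1` is false for null, so the else branch
    // runs `count - 1`, which unboxes null and throws an NPE.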
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/AggFunctionHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/AggFunctionHarnessTest.scala
new file mode 100644
index 0000000..0549339
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/AggFunctionHarnessTest.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.harness
+
+import java.lang.{Integer => JInt}
+import java.util.concurrent.ConcurrentLinkedQueue
+
+import org.apache.flink.api.common.time.Time
+import org.apache.flink.api.scala._
+import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend
+import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.api.dataview.MapView
+import org.apache.flink.table.dataview.StateMapView
+import org.apache.flink.table.runtime.aggregate.GroupAggProcessFunction
+import org.apache.flink.table.runtime.harness.HarnessTestBase.TestStreamQueryConfig
+import org.apache.flink.table.runtime.types.CRow
+import org.apache.flink.types.Row
+import org.junit.Assert.assertTrue
+import org.junit.Test
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+class AggFunctionHarnessTest extends HarnessTestBase {
+  private val queryConfig = new TestStreamQueryConfig(Time.seconds(0), Time.seconds(0))
+
+  @Test
+  def testCollectAggregate(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+
+    val data = new mutable.MutableList[(JInt, String)]
+    val t = env.fromCollection(data).toTable(tEnv, 'a, 'b)
+    tEnv.registerTable("T", t)
+    val sqlQuery = tEnv.sqlQuery(
+      s"""
+         |SELECT
+         |  b, collect(a)
+         |FROM (
+         |  SELECT a, b
+         |  FROM T
+         |  GROUP BY a, b
+         |) GROUP BY b
+         |""".stripMargin)
+
+    val testHarness = createHarnessTester[String, CRow, CRow](
+      sqlQuery.toRetractStream[Row](queryConfig), "groupBy")
+
+    testHarness.setStateBackend(getStateBackend)
+    testHarness.open()
+
+    val operator = getOperator(testHarness)
+    val state = getState(
+      operator,
+      "function",
+      classOf[GroupAggProcessFunction],
+      "acc0_map_dataview").asInstanceOf[MapView[JInt, JInt]]
+    assertTrue(state.isInstanceOf[StateMapView[_, _]])
+    assertTrue(operator.getKeyedStateBackend.isInstanceOf[RocksDBKeyedStateBackend[_]])
+
+    val expectedOutput = new ConcurrentLinkedQueue[Object]()
+
+    testHarness.processElement(new StreamRecord(CRow(1: JInt, "aaa"), 1))
+    expectedOutput.add(new StreamRecord(CRow("aaa", Map(1 -> 1).asJava), 1))
+
+    testHarness.processElement(new StreamRecord(CRow(1: JInt, "bbb"), 1))
+    expectedOutput.add(new StreamRecord(CRow("bbb", Map(1 -> 1).asJava), 1))
+
+    testHarness.processElement(new StreamRecord(CRow(1: JInt, "aaa"), 1))
+    expectedOutput.add(new StreamRecord(CRow(false, "aaa", Map(1 -> 1).asJava), 1))
+    expectedOutput.add(new StreamRecord(CRow("aaa", Map(1 -> 2).asJava), 1))
+
+    testHarness.processElement(new StreamRecord(CRow(2: JInt, "aaa"), 1))
+    expectedOutput.add(new StreamRecord(CRow(false, "aaa", Map(1 -> 2).asJava), 1))
+    expectedOutput.add(new StreamRecord(CRow("aaa", Map(1 -> 2, 2 -> 1).asJava), 1))
+
+    // remove some state: state may be cleaned up by the state backend
+    // if not accessed beyond ttl time
+    operator.setCurrentKey(Row.of("aaa"))
+    state.remove(2)
+
+    // retract after state has been cleaned up
+    testHarness.processElement(new StreamRecord(CRow(false, 2: JInt, "aaa"), 1))
+
+    val result = testHarness.getOutput
+
+    verify(expectedOutput, result)
+
+    testHarness.close()
+  }
+}
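
A note on the expected output in the test above: in a retract stream, CRow(false, ...) marks a retraction of the previously emitted result, so each update to a group produces a retraction/emission pair. Traced for the third element:

    // processElement(CRow(1, "aaa")) on a group already holding {1=1}:
    //   CRow(false, "aaa", {1=1})  -- retract the old aggregate
    //   CRow(true,  "aaa", {1=2})  -- emit the updated aggregate
    // The final processElement retracts (2, "aaa") after state.remove(2) has
    // wiped that entry, exercising the null guard added above; the test
    // expects no further output instead of a NullPointerException.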
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
index e5cceec..c37fd0c 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
@@ -25,27 +25,27 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{LONG_TYPE_INFO, STRIN
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.functions.KeySelector
 import org.apache.flink.api.java.typeutils.RowTypeInfo
-import org.apache.flink.streaming.api.operators.OneInputStreamOperator
+import org.apache.flink.streaming.api.operators.{AbstractUdfStreamOperator, OneInputStreamOperator}
+import org.apache.flink.streaming.api.scala.DataStream
+import org.apache.flink.streaming.api.transformations._
 import org.apache.flink.streaming.api.watermark.Watermark
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
-import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, TestHarnessUtil}
+import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, OneInputStreamOperatorTestHarness, TestHarnessUtil}
+import org.apache.flink.table.api.dataview.DataView
 import org.apache.flink.table.api.{StreamQueryConfig, Types}
 import org.apache.flink.table.codegen.GeneratedAggregationsFunction
 import org.apache.flink.table.functions.aggfunctions.{CountAggFunction, IntSumWithRetractAggFunction, LongMaxWithRetractAggFunction, LongMinWithRetractAggFunction}
 import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.getAccumulatorTypeOfAggregateFunction
 import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
+import org.apache.flink.table.runtime.aggregate.GeneratedAggregations
 import org.apache.flink.table.runtime.harness.HarnessTestBase.{RowResultSortComparator, RowResultSortComparatorWithWatermarks}
 import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
+import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase
 import org.apache.flink.table.utils.EncodingUtils
-import org.junit.Rule
-import org.junit.rules.ExpectedException
 
-class HarnessTestBase {
-  // used for accurate exception information checking.
-  val expectedException = ExpectedException.none()
+import _root_.scala.collection.JavaConversions._
 
-  @Rule
-  def thrown = expectedException
+class HarnessTestBase extends StreamingWithStateTestBase {
 
   val longMinWithRetractAggFunction: String =
     EncodingUtils.encodeObjectToString(new LongMinWithRetractAggFunction)
@@ -491,6 +491,68 @@ class HarnessTestBase {
     distinctCountFuncName,
     distinctCountAggCode)
 
+  def createHarnessTester[KEY, IN, OUT](
+      dataStream: DataStream[_],
+      prefixOperatorName: String)
+  : KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT] = {
+
+    val transformation = extractExpectedTransformation(
+      dataStream.javaStream.getTransformation,
+      prefixOperatorName).asInstanceOf[OneInputTransformation[_, _]]
+    if (transformation == null) {
+      throw new Exception("Can not find the expected transformation")
+    }
+
+    val processOperator = transformation.getOperator.asInstanceOf[OneInputStreamOperator[IN, OUT]]
+    val keySelector = transformation.getStateKeySelector.asInstanceOf[KeySelector[IN, KEY]]
+    val keyType = transformation.getStateKeyType.asInstanceOf[TypeInformation[KEY]]
+
+    createHarnessTester(processOperator, keySelector, keyType)
+      .asInstanceOf[KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT]]
+  }
+
+  private def extractExpectedTransformation(
+      transformation: StreamTransformation[_],
+      prefixOperatorName: String): StreamTransformation[_] = {
+    def extractFromInputs(inputs: StreamTransformation[_]*): StreamTransformation[_] = {
+      for (input <- inputs) {
+        val t = extractExpectedTransformation(input, prefixOperatorName)
+        if (t != null) {
+          return t
+        }
+      }
+      null
+    }
+
+    transformation match {
+      case one: OneInputTransformation[_, _] =>
+        if (one.getName.startsWith(prefixOperatorName)) {
+          one
+        } else {
+          extractExpectedTransformation(one.getInput, prefixOperatorName)
+        }
+      case union: UnionTransformation[_] => extractFromInputs(union.getInputs.toSeq: _*)
+      case p: PartitionTransformation[_] => extractFromInputs(p.getInput)
+      case _: SourceTransformation[_] => null
+      case _ => throw new UnsupportedOperationException("This should not happen.")
+    }
+  }
+
+  def getState(
+      operator: AbstractUdfStreamOperator[_, _],
+      funcName: String,
+      funcClass: Class[_],
+      stateFieldName: String): DataView = {
+    val function = funcClass.getDeclaredField(funcName)
+    function.setAccessible(true)
+    val generatedAggregation =
+      function.get(operator.getUserFunction).asInstanceOf[GeneratedAggregations]
+    val cls = generatedAggregation.getClass
+    val stateField = cls.getDeclaredField(stateFieldName)
+    stateField.setAccessible(true)
+    stateField.get(generatedAggregation).asInstanceOf[DataView]
+  }
+
   def createHarnessTester[IN, OUT, KEY](
     operator: OneInputStreamOperator[IN, OUT],
     keySelector: KeySelector[IN, KEY],
@@ -498,6 +560,14 @@ class HarnessTestBase {
     new KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT](operator, keySelector, keyType)
   }
 
+  def getOperator(testHarness: OneInputStreamOperatorTestHarness[_, _])
+      : AbstractUdfStreamOperator[_, _] = {
+    val operatorField = classOf[OneInputStreamOperatorTestHarness[_, _]]
+      .getDeclaredField("oneInputOperator")
+    operatorField.setAccessible(true)
+    operatorField.get(testHarness).asInstanceOf[AbstractUdfStreamOperator[_, _]]
+  }
+
   def verify(expected: JQueue[Object], actual: JQueue[Object]): Unit = {
     verify(expected, actual, new RowResultSortComparator)
   }
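
The getOperator and getState helpers reach into private fields by reflection: getOperator reads the harness's oneInputOperator field, and getState pulls a named DataView field out of the generated aggregation held by the AbstractUdfStreamOperator. A usage sketch matching the new test (the field names "function" and "acc0_map_dataview" come from GroupAggProcessFunction and the generated aggregation code):

    val operator = getOperator(testHarness)
    val state = getState(operator, "function",
      classOf[GroupAggProcessFunction], "acc0_map_dataview")
      .asInstanceOf[MapView[JInt, JInt]]
    operator.setCurrentKey(Row.of("aaa"))   // state views are keyed
    state.remove(2)                         // mimic TTL cleanup of one entry

This avoids adding test-only accessors to production classes, at the cost of coupling the helpers to internal field names, an accepted trade-off in test code.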