You are viewing a plain-text version of this content; the hyperlink to the canonical version was lost in this copy.
Posted to commits@flink.apache.org by ji...@apache.org on 2018/12/17 06:13:11 UTC
[flink] branch master updated: [FLINK-11074] [table][tests] Enable harness tests with RocksdbStateBackend and add harness tests for CollectAggFunction
This is an automated email from the ASF dual-hosted git repository.
jincheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 9a45fca [FLINK-11074] [table][tests] Enable harness tests with RocksdbStateBackend and add harness tests for CollectAggFunction
9a45fca is described below
commit 9a45fcabc51fc479af48c024c456aea548620bd2
Author: Dian Fu <fu...@alibaba-inc.com>
AuthorDate: Thu Dec 6 20:25:45 2018 +0800
[FLINK-11074] [table][tests] Enable harness tests with RocksdbStateBackend and add harness tests for CollectAggFunction
This closes #7253
---
.../aggfunctions/CollectAggFunction.scala | 10 +-
.../runtime/harness/AggFunctionHarnessTest.scala | 110 +++++++++++++++++++++
.../table/runtime/harness/HarnessTestBase.scala | 88 +++++++++++++++--
3 files changed, 195 insertions(+), 13 deletions(-)
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala
index b10be61..5186d66 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala
@@ -112,10 +112,12 @@ class CollectAggFunction[E](valueTypeInfo: TypeInformation[_])
def retract(acc: CollectAccumulator[E], value: E): Unit = {
if (value != null) {
val count = acc.map.get(value)
- if (count == 1) {
- acc.map.remove(value)
- } else {
- acc.map.put(value, count - 1)
+ // Removes one occurrence of `value`: drop the entry at count 1, otherwise decrement.
+ // The entry may be absent (null) if the accumulator's map state was cleaned up before
+ // this retraction arrived (e.g. state not accessed within the TTL); there is then
+ // nothing to retract, so the call is a no-op instead of failing on a null count.
+ if (count != null) {
+ if (count == 1) {
+ acc.map.remove(value)
+ } else {
+ acc.map.put(value, count - 1)
+ }
}
}
}
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/AggFunctionHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/AggFunctionHarnessTest.scala
new file mode 100644
index 0000000..0549339
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/AggFunctionHarnessTest.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.harness
+
+import java.lang.{Integer => JInt}
+import java.util.concurrent.ConcurrentLinkedQueue
+
+import org.apache.flink.api.common.time.Time
+import org.apache.flink.api.scala._
+import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend
+import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.api.dataview.MapView
+import org.apache.flink.table.dataview.StateMapView
+import org.apache.flink.table.runtime.aggregate.GroupAggProcessFunction
+import org.apache.flink.table.runtime.harness.HarnessTestBase.TestStreamQueryConfig
+import org.apache.flink.table.runtime.types.CRow
+import org.apache.flink.types.Row
+import org.junit.Assert.assertTrue
+import org.junit.Test
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+class AggFunctionHarnessTest extends HarnessTestBase {
+  private val queryConfig = new TestStreamQueryConfig(Time.seconds(0), Time.seconds(0))
+
+  /**
+   * Runs COLLECT over a two-level GROUP BY inside a keyed operator test harness and checks
+   * the emitted retract stream. It also asserts that the aggregate's MapView is a
+   * [[StateMapView]] on a RocksDB keyed state backend. It then removes one entry from that
+   * state directly and retracts the removed value, which must not fail.
+   */
+  @Test
+  def testCollectAggregate(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+
+    val data = new mutable.MutableList[(JInt, String)]
+    val t = env.fromCollection(data).toTable(tEnv, 'a, 'b)
+    tEnv.registerTable("T", t)
+    val sqlQuery = tEnv.sqlQuery(
+      s"""
+        |SELECT
+        | b, collect(a)
+        |FROM (
+        | SELECT a, b
+        | FROM T
+        | GROUP BY a, b
+        |) GROUP BY b
+        |""".stripMargin)
+
+    // Harness around the first operator in the plan whose name starts with "groupBy".
+    val testHarness = createHarnessTester[String, CRow, CRow](
+      sqlQuery.toRetractStream[Row](queryConfig), "groupBy")
+
+    testHarness.setStateBackend(getStateBackend)
+    testHarness.open()
+
+    try {
+      val operator = getOperator(testHarness)
+      // MapView field of the generated aggregations (presumably the first aggregate's
+      // accumulator, per the "acc0" prefix).
+      val state = getState(
+        operator,
+        "function",
+        classOf[GroupAggProcessFunction],
+        "acc0_map_dataview").asInstanceOf[MapView[JInt, JInt]]
+      assertTrue(state.isInstanceOf[StateMapView[_, _]])
+      assertTrue(operator.getKeyedStateBackend.isInstanceOf[RocksDBKeyedStateBackend[_]])
+
+      val expectedOutput = new ConcurrentLinkedQueue[Object]()
+
+      testHarness.processElement(new StreamRecord(CRow(1: JInt, "aaa"), 1))
+      expectedOutput.add(new StreamRecord(CRow("aaa", Map(1 -> 1).asJava), 1))
+
+      testHarness.processElement(new StreamRecord(CRow(1: JInt, "bbb"), 1))
+      expectedOutput.add(new StreamRecord(CRow("bbb", Map(1 -> 1).asJava), 1))
+
+      testHarness.processElement(new StreamRecord(CRow(1: JInt, "aaa"), 1))
+      expectedOutput.add(new StreamRecord(CRow(false, "aaa", Map(1 -> 1).asJava), 1))
+      expectedOutput.add(new StreamRecord(CRow("aaa", Map(1 -> 2).asJava), 1))
+
+      testHarness.processElement(new StreamRecord(CRow(2: JInt, "aaa"), 1))
+      expectedOutput.add(new StreamRecord(CRow(false, "aaa", Map(1 -> 2).asJava), 1))
+      expectedOutput.add(new StreamRecord(CRow("aaa", Map(1 -> 2, 2 -> 1).asJava), 1))
+
+      // Simulate state cleanup: the state backend may drop entries that were not
+      // accessed within the state TTL.
+      operator.setCurrentKey(Row.of("aaa"))
+      state.remove(2)
+
+      // Retract a value whose state entry is already gone; the test expects no
+      // additional output for this record.
+      testHarness.processElement(new StreamRecord(CRow(false, 2: JInt, "aaa"), 1))
+
+      val result = testHarness.getOutput
+
+      verify(expectedOutput, result)
+    } finally {
+      // Always close the harness so the operator and its RocksDB keyed backend are
+      // released, even when an assertion above fails.
+      testHarness.close()
+    }
+  }
+}
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
index e5cceec..c37fd0c 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
@@ -25,27 +25,27 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{LONG_TYPE_INFO, STRIN
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.functions.KeySelector
import org.apache.flink.api.java.typeutils.RowTypeInfo
-import org.apache.flink.streaming.api.operators.OneInputStreamOperator
+import org.apache.flink.streaming.api.operators.{AbstractUdfStreamOperator, OneInputStreamOperator}
+import org.apache.flink.streaming.api.scala.DataStream
+import org.apache.flink.streaming.api.transformations._
import org.apache.flink.streaming.api.watermark.Watermark
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
-import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, TestHarnessUtil}
+import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, OneInputStreamOperatorTestHarness, TestHarnessUtil}
+import org.apache.flink.table.api.dataview.DataView
import org.apache.flink.table.api.{StreamQueryConfig, Types}
import org.apache.flink.table.codegen.GeneratedAggregationsFunction
import org.apache.flink.table.functions.aggfunctions.{CountAggFunction, IntSumWithRetractAggFunction, LongMaxWithRetractAggFunction, LongMinWithRetractAggFunction}
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.getAccumulatorTypeOfAggregateFunction
import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
+import org.apache.flink.table.runtime.aggregate.GeneratedAggregations
import org.apache.flink.table.runtime.harness.HarnessTestBase.{RowResultSortComparator, RowResultSortComparatorWithWatermarks}
import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
+import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase
import org.apache.flink.table.utils.EncodingUtils
-import org.junit.Rule
-import org.junit.rules.ExpectedException
-class HarnessTestBase {
- // used for accurate exception information checking.
- val expectedException = ExpectedException.none()
+import _root_.scala.collection.JavaConversions._
- @Rule
- def thrown = expectedException
+class HarnessTestBase extends StreamingWithStateTestBase {
val longMinWithRetractAggFunction: String =
EncodingUtils.encodeObjectToString(new LongMinWithRetractAggFunction)
@@ -491,6 +491,68 @@ class HarnessTestBase {
distinctCountFuncName,
distinctCountAggCode)
+  /**
+   * Builds a keyed test harness around the first one-input operator, walking upstream from
+   * `dataStream`, whose name starts with `prefixOperatorName`. The operator, state key
+   * selector and state key type are taken from that transformation.
+   *
+   * @throws Exception if no transformation with the given name prefix is found
+   */
+  def createHarnessTester[KEY, IN, OUT](
+      dataStream: DataStream[_],
+      prefixOperatorName: String)
+    : KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT] = {
+
+    val found = extractExpectedTransformation(
+      dataStream.javaStream.getTransformation,
+      prefixOperatorName)
+
+    val oneInput = Option(found) match {
+      case Some(t) => t.asInstanceOf[OneInputTransformation[_, _]]
+      case None => throw new Exception("Can not find the expected transformation")
+    }
+
+    val operator = oneInput.getOperator.asInstanceOf[OneInputStreamOperator[IN, OUT]]
+    val selector = oneInput.getStateKeySelector.asInstanceOf[KeySelector[IN, KEY]]
+    val stateKeyType = oneInput.getStateKeyType.asInstanceOf[TypeInformation[KEY]]
+
+    createHarnessTester(operator, selector, stateKeyType)
+      .asInstanceOf[KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT]]
+  }
+
+  /**
+   * Depth-first search, upstream from `transformation`, for the first
+   * [[OneInputTransformation]] whose name starts with `prefixOperatorName`.
+   *
+   * @return the matching transformation, or null if the search reaches only sources
+   * @throws UnsupportedOperationException for transformation kinds this search cannot walk
+   */
+  private def extractExpectedTransformation(
+      transformation: StreamTransformation[_],
+      prefixOperatorName: String): StreamTransformation[_] = {
+
+    // First match among `inputs`, tried in order, or null. A lazy iterator stops at the
+    // first hit without `return` inside a loop closure (a non-local, exception-based return).
+    def extractFromInputs(inputs: StreamTransformation[_]*): StreamTransformation[_] =
+      inputs.iterator
+        .map(extractExpectedTransformation(_, prefixOperatorName))
+        .find(_ != null)
+        .orNull
+
+    transformation match {
+      case one: OneInputTransformation[_, _] =>
+        if (one.getName.startsWith(prefixOperatorName)) {
+          one
+        } else {
+          extractExpectedTransformation(one.getInput, prefixOperatorName)
+        }
+      // Continue through two-input operators (e.g. joins) so upstream operators can be found.
+      case two: TwoInputTransformation[_, _, _] =>
+        extractFromInputs(two.getInput1, two.getInput2)
+      case union: UnionTransformation[_] => extractFromInputs(union.getInputs.toSeq: _*)
+      case p: PartitionTransformation[_] => extractFromInputs(p.getInput)
+      case _: SourceTransformation[_] => null
+      case other =>
+        throw new UnsupportedOperationException(
+          s"Unsupported transformation type: ${other.getClass.getName}")
+    }
+  }
+
+  /**
+   * Reads a state field from the generated aggregations held by an operator's user function.
+   * It reflectively reads the field `funcName`, declared on `funcClass`, from the operator's
+   * user function and treats that value as [[GeneratedAggregations]]. It then reflectively
+   * reads `stateFieldName` from the runtime class of those aggregations.
+   *
+   * @return the field value cast to [[DataView]]
+   */
+  def getState(
+      operator: AbstractUdfStreamOperator[_, _],
+      funcName: String,
+      funcClass: Class[_],
+      stateFieldName: String): DataView = {
+
+    // Reads a (possibly private) field declared directly on `owner` from `target`.
+    def readDeclaredField(owner: Class[_], fieldName: String, target: Any): AnyRef = {
+      val field = owner.getDeclaredField(fieldName)
+      field.setAccessible(true)
+      field.get(target)
+    }
+
+    val aggregations = readDeclaredField(funcClass, funcName, operator.getUserFunction)
+      .asInstanceOf[GeneratedAggregations]
+    readDeclaredField(aggregations.getClass, stateFieldName, aggregations)
+      .asInstanceOf[DataView]
+  }
+
def createHarnessTester[IN, OUT, KEY](
operator: OneInputStreamOperator[IN, OUT],
keySelector: KeySelector[IN, KEY],
@@ -498,6 +560,14 @@ class HarnessTestBase {
new KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT](operator, keySelector, keyType)
}
+  /**
+   * Returns the operator wrapped by `testHarness`. The value is read reflectively from the
+   * harness's private `oneInputOperator` field and cast to [[AbstractUdfStreamOperator]].
+   */
+  def getOperator(testHarness: OneInputStreamOperatorTestHarness[_, _])
+    : AbstractUdfStreamOperator[_, _] = {
+    val harnessClass = classOf[OneInputStreamOperatorTestHarness[_, _]]
+    val field = harnessClass.getDeclaredField("oneInputOperator")
+    field.setAccessible(true)
+    val wrapped = field.get(testHarness)
+    wrapped.asInstanceOf[AbstractUdfStreamOperator[_, _]]
+  }
+
def verify(expected: JQueue[Object], actual: JQueue[Object]): Unit = {
verify(expected, actual, new RowResultSortComparator)
}