Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2018/11/22 13:54:30 UTC

[GitHub] asfgit closed pull request #7147: [FLINK-10674] [table] Fix handling of retractions after clean up

asfgit closed pull request #7147: [FLINK-10674] [table] Fix handling of retractions after clean up
URL: https://github.com/apache/flink/pull/7147

This is a PR merged from a forked repository. Because GitHub hides the original
diff of a fork-based pull request once it is merged, the diff is reproduced
below for the sake of provenance:

diff --git a/flink-libraries/flink-table-common/src/main/java/org/apache/flink/table/utils/EncodingUtils.java b/flink-libraries/flink-table-common/src/main/java/org/apache/flink/table/utils/EncodingUtils.java
index 47aac25e897..5531082611d 100644
--- a/flink-libraries/flink-table-common/src/main/java/org/apache/flink/table/utils/EncodingUtils.java
+++ b/flink-libraries/flink-table-common/src/main/java/org/apache/flink/table/utils/EncodingUtils.java
@@ -76,7 +76,7 @@ public static String encodeObjectToString(Serializable obj) {
 			return instance;
 		} catch (Exception e) {
 			throw new ValidationException(
-				"Unable to deserialize string '" + base64String + "' of base class '" + baseClass.getName() + "'.");
+				"Unable to deserialize string '" + base64String + "' of base class '" + baseClass.getName() + "'.", e);
 		}
 	}
 
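Note that this first hunk only chains the caught exception as the cause of the
ValidationException, so deserialization failures no longer lose their original
stack trace.
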
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala
index 397032003ec..f591c4f2299 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala
@@ -95,6 +95,12 @@ class GroupAggProcessFunction(
     var inputCnt = cntState.value()
 
     if (null == accumulators) {
+      // don't create a new accumulator for unknown retractions
+      // e.g. retractions that come in right after state clean up
+      if (!inputC.change) {
+        return
+      }
+      // first accumulate message
       firstRow = true
       accumulators = function.createAccumulators()
     } else {
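
The hunk above is the actual fix: once the keyed state has been cleaned up,
accumulators comes back null, and a retraction (inputC.change == false) must
not seed fresh state. The following is a dependency-free Scala sketch of that
guard, with hypothetical names (Message, process) standing in for CRow and
processElement:

    // Hypothetical, self-contained sketch of the retraction guard above.
    // A retraction arriving after state clean up is dropped instead of
    // creating a new accumulator and emitting a spurious result.
    object RetractionGuardSketch {

      // isAccumulate plays the role of CRow.change in the real code
      final case class Message(value: Int, isAccumulate: Boolean)

      def process(state: Option[Int], msg: Message): Option[Int] =
        state match {
          case None if !msg.isAccumulate =>
            None // unknown retraction: state was cleaned up, ignore it
          case None =>
            Some(msg.value) // first accumulate message: create the accumulator
          case Some(acc) =>
            if (msg.isAccumulate) Some(acc + msg.value) else Some(acc - msg.value)
        }

      def main(args: Array[String]): Unit = {
        // retraction after clean up: no state is created
        assert(process(None, Message(3, isAccumulate = false)).isEmpty)
        // an accumulate message still creates state as before
        assert(process(None, Message(3, isAccumulate = true)).contains(3))
      }
    }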
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/NonWindowHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/GroupAggregateHarnessTest.scala
similarity index 65%
rename from flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/NonWindowHarnessTest.scala
rename to flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/GroupAggregateHarnessTest.scala
index 7c4f5430328..1dce9946877 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/NonWindowHarnessTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/GroupAggregateHarnessTest.scala
@@ -24,20 +24,18 @@ import org.apache.flink.api.common.time.Time
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo
 import org.apache.flink.streaming.api.operators.LegacyKeyedProcessOperator
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
-import org.apache.flink.table.api.StreamQueryConfig
 import org.apache.flink.table.runtime.aggregate._
 import org.apache.flink.table.runtime.harness.HarnessTestBase._
 import org.apache.flink.table.runtime.types.CRow
-import org.apache.flink.types.Row
 import org.junit.Test
 
-class NonWindowHarnessTest extends HarnessTestBase {
+class GroupAggregateHarnessTest extends HarnessTestBase {
 
   protected var queryConfig =
     new TestStreamQueryConfig(Time.seconds(2), Time.seconds(3))
 
   @Test
-  def testNonWindow(): Unit = {
+  def testAggregate(): Unit = {
 
     val processFunction = new LegacyKeyedProcessOperator[String, CRow, CRow](
       new GroupAggProcessFunction(
@@ -54,50 +52,49 @@ class NonWindowHarnessTest extends HarnessTestBase {
 
     testHarness.open()
 
+    val expectedOutput = new ConcurrentLinkedQueue[Object]()
+
     // register cleanup timer with 3001
     testHarness.setProcessingTime(1)
 
     testHarness.processElement(new StreamRecord(CRow(1L: JLong, 1: JInt, "aaa"), 1))
+    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1: JInt), 1))
     testHarness.processElement(new StreamRecord(CRow(2L: JLong, 1: JInt, "bbb"), 1))
+    expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1: JInt), 1))
     // reuse timer 3001
     testHarness.setProcessingTime(1000)
     testHarness.processElement(new StreamRecord(CRow(3L: JLong, 2: JInt, "aaa"), 1))
+    expectedOutput.add(new StreamRecord(CRow(3L: JLong, 3: JInt), 1))
     testHarness.processElement(new StreamRecord(CRow(4L: JLong, 3: JInt, "aaa"), 1))
+    expectedOutput.add(new StreamRecord(CRow(4L: JLong, 6: JInt), 1))
 
     // register cleanup timer with 4002
     testHarness.setProcessingTime(1002)
     testHarness.processElement(new StreamRecord(CRow(5L: JLong, 4: JInt, "aaa"), 1))
+    expectedOutput.add(new StreamRecord(CRow(5L: JLong, 10: JInt), 1))
     testHarness.processElement(new StreamRecord(CRow(6L: JLong, 2: JInt, "bbb"), 1))
+    expectedOutput.add(new StreamRecord(CRow(6L: JLong, 3: JInt), 1))
 
     // trigger cleanup timer and register cleanup timer with 7003
     testHarness.setProcessingTime(4003)
     testHarness.processElement(new StreamRecord(CRow(7L: JLong, 5: JInt, "aaa"), 1))
+    expectedOutput.add(new StreamRecord(CRow(7L: JLong, 5: JInt), 1))
     testHarness.processElement(new StreamRecord(CRow(8L: JLong, 6: JInt, "aaa"), 1))
+    expectedOutput.add(new StreamRecord(CRow(8L: JLong, 11: JInt), 1))
     testHarness.processElement(new StreamRecord(CRow(9L: JLong, 7: JInt, "aaa"), 1))
+    expectedOutput.add(new StreamRecord(CRow(9L: JLong, 18: JInt), 1))
     testHarness.processElement(new StreamRecord(CRow(10L: JLong, 3: JInt, "bbb"), 1))
+    expectedOutput.add(new StreamRecord(CRow(10L: JLong, 3: JInt), 1))
 
     val result = testHarness.getOutput
 
-    val expectedOutput = new ConcurrentLinkedQueue[Object]()
-
-    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1: JInt), 1))
-    expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1: JInt), 1))
-    expectedOutput.add(new StreamRecord(CRow(3L: JLong, 3: JInt), 1))
-    expectedOutput.add(new StreamRecord(CRow(4L: JLong, 6: JInt), 1))
-    expectedOutput.add(new StreamRecord(CRow(5L: JLong, 10: JInt), 1))
-    expectedOutput.add(new StreamRecord(CRow(6L: JLong, 3: JInt), 1))
-    expectedOutput.add(new StreamRecord(CRow(7L: JLong, 5: JInt), 1))
-    expectedOutput.add(new StreamRecord(CRow(8L: JLong, 11: JInt), 1))
-    expectedOutput.add(new StreamRecord(CRow(9L: JLong, 18: JInt), 1))
-    expectedOutput.add(new StreamRecord(CRow(10L: JLong, 3: JInt), 1))
-
     verify(expectedOutput, result)
 
     testHarness.close()
   }
 
   @Test
-  def testNonWindowWithRetract(): Unit = {
+  def testAggregateWithRetract(): Unit = {
 
     val processFunction = new LegacyKeyedProcessOperator[String, CRow, CRow](
       new GroupAggProcessFunction(
@@ -114,42 +111,136 @@ class NonWindowHarnessTest extends HarnessTestBase {
 
     testHarness.open()
 
+    val expectedOutput = new ConcurrentLinkedQueue[Object]()
+
     // register cleanup timer with 3001
     testHarness.setProcessingTime(1)
 
+    // accumulate
     testHarness.processElement(new StreamRecord(CRow(1L: JLong, 1: JInt, "aaa"), 1))
+    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1: JInt), 1))
+
+    // accumulate
     testHarness.processElement(new StreamRecord(CRow(2L: JLong, 1: JInt, "bbb"), 2))
+    expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1: JInt), 2))
+
+    // retract for insertion
     testHarness.processElement(new StreamRecord(CRow(3L: JLong, 2: JInt, "aaa"), 3))
+    expectedOutput.add(new StreamRecord(CRow(false, 3L: JLong, 1: JInt), 3))
+    expectedOutput.add(new StreamRecord(CRow(3L: JLong, 3: JInt), 3))
+
+    // retract for deletion
+    testHarness.processElement(new StreamRecord(CRow(false, 3L: JLong, 2: JInt, "aaa"), 3))
+    expectedOutput.add(new StreamRecord(CRow(false, 3L: JLong, 3: JInt), 3))
+    expectedOutput.add(new StreamRecord(CRow(3L: JLong, 1: JInt), 3))
+
+    // accumulate
     testHarness.processElement(new StreamRecord(CRow(4L: JLong, 3: JInt, "ccc"), 4))
+    expectedOutput.add(new StreamRecord(CRow(4L: JLong, 3: JInt), 4))
 
     // trigger cleanup timer and register cleanup timer with 6002
     testHarness.setProcessingTime(3002)
-    testHarness.processElement(new StreamRecord(CRow(5L: JLong, 4: JInt, "aaa"), 5))
-    testHarness.processElement(new StreamRecord(CRow(6L: JLong, 2: JInt, "bbb"), 6))
-    testHarness.processElement(new StreamRecord(CRow(7L: JLong, 5: JInt, "aaa"), 7))
-    testHarness.processElement(new StreamRecord(CRow(8L: JLong, 6: JInt, "eee"), 8))
-    testHarness.processElement(new StreamRecord(CRow(9L: JLong, 7: JInt, "aaa"), 9))
-    testHarness.processElement(new StreamRecord(CRow(10L: JLong, 3: JInt, "bbb"), 10))
-
-    val result = testHarness.getOutput
 
-    val expectedOutput = new ConcurrentLinkedQueue[Object]()
+    // retract after clean up
+    testHarness.processElement(new StreamRecord(CRow(false, 4L: JLong, 3: JInt, "ccc"), 4))
 
-    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1: JInt), 1))
-    expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1: JInt), 2))
-    expectedOutput.add(new StreamRecord(CRow(false, 3L: JLong, 1: JInt), 3))
-    expectedOutput.add(new StreamRecord(CRow(3L: JLong, 3: JInt), 3))
-    expectedOutput.add(new StreamRecord(CRow(4L: JLong, 3: JInt), 4))
+    // accumulate
+    testHarness.processElement(new StreamRecord(CRow(5L: JLong, 4: JInt, "aaa"), 5))
     expectedOutput.add(new StreamRecord(CRow(5L: JLong, 4: JInt), 5))
+    testHarness.processElement(new StreamRecord(CRow(6L: JLong, 2: JInt, "bbb"), 6))
     expectedOutput.add(new StreamRecord(CRow(6L: JLong, 2: JInt), 6))
+
+    // retract
+    testHarness.processElement(new StreamRecord(CRow(7L: JLong, 5: JInt, "aaa"), 7))
     expectedOutput.add(new StreamRecord(CRow(false, 7L: JLong, 4: JInt), 7))
     expectedOutput.add(new StreamRecord(CRow(7L: JLong, 9: JInt), 7))
+
+    // accumulate
+    testHarness.processElement(new StreamRecord(CRow(8L: JLong, 6: JInt, "eee"), 8))
     expectedOutput.add(new StreamRecord(CRow(8L: JLong, 6: JInt), 8))
+
+    // retract
+    testHarness.processElement(new StreamRecord(CRow(9L: JLong, 7: JInt, "aaa"), 9))
     expectedOutput.add(new StreamRecord(CRow(false, 9L: JLong, 9: JInt), 9))
     expectedOutput.add(new StreamRecord(CRow(9L: JLong, 16: JInt), 9))
+    testHarness.processElement(new StreamRecord(CRow(10L: JLong, 3: JInt, "bbb"), 10))
     expectedOutput.add(new StreamRecord(CRow(false, 10L: JLong, 2: JInt), 10))
     expectedOutput.add(new StreamRecord(CRow(10L: JLong, 5: JInt), 10))
 
+    val result = testHarness.getOutput
+
+    verify(expectedOutput, result)
+
+    testHarness.close()
+  }
+
+  @Test
+  def testDistinctAggregateWithRetract(): Unit = {
+
+    val processFunction = new LegacyKeyedProcessOperator[String, CRow, CRow](
+      new GroupAggProcessFunction(
+        genDistinctCountAggFunction,
+        distinctCountAggregationStateType,
+        true,
+        queryConfig))
+
+    val testHarness =
+      createHarnessTester(
+        processFunction,
+        new TupleRowKeySelector[String](2),
+        BasicTypeInfo.STRING_TYPE_INFO)
+
+    testHarness.open()
+
+    val expectedOutput = new ConcurrentLinkedQueue[Object]()
+
+    // register cleanup timer with 3001
+    testHarness.setProcessingTime(1)
+
+    // insert
+    testHarness.processElement(new StreamRecord(CRow(1L: JLong, 1: JInt, "aaa")))
+    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong)))
+    testHarness.processElement(new StreamRecord(CRow(2L: JLong, 1: JInt, "bbb")))
+    expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1L: JLong)))
+
+    // distinct count retract then accumulate for downstream operators
+    testHarness.processElement(new StreamRecord(CRow(2L: JLong, 1: JInt, "bbb")))
+    expectedOutput.add(new StreamRecord(CRow(false, 2L: JLong, 1L: JLong)))
+    expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1L: JLong)))
+
+    // update count for accumulate
+    testHarness.processElement(new StreamRecord(CRow(1L: JLong, 2: JInt, "aaa")))
+    expectedOutput.add(new StreamRecord(CRow(false, 1L: JLong, 1L: JLong)))
+    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 2L: JLong)))
+
+    // update count for retraction
+    testHarness.processElement(new StreamRecord(CRow(false, 1L: JLong, 2: JInt, "aaa")))
+    expectedOutput.add(new StreamRecord(CRow(false, 1L: JLong, 2L: JLong)))
+    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong)))
+
+    // insert
+    testHarness.processElement(new StreamRecord(CRow(4L: JLong, 3: JInt, "ccc")))
+    expectedOutput.add(new StreamRecord(CRow(4L: JLong, 1L: JLong)))
+
+    // retract entirely
+    testHarness.processElement(new StreamRecord(CRow(false, 4L: JLong, 3: JInt, "ccc")))
+    expectedOutput.add(new StreamRecord(CRow(false, 4L: JLong, 1L: JLong)))
+
+    // trigger cleanup timer and register cleanup timer with 6002
+    testHarness.setProcessingTime(3002)
+
+    // insert
+    testHarness.processElement(new StreamRecord(CRow(1L: JLong, 1: JInt, "aaa")))
+    expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong)))
+
+    // trigger cleanup timer and register cleanup timer with 9002
+    testHarness.setProcessingTime(6002)
+
+    // retract after cleanup
+    testHarness.processElement(new StreamRecord(CRow(false, 1L: JLong, 1: JInt, "aaa")))
+
+    val result = testHarness.getOutput
+
     verify(expectedOutput, result)
 
     testHarness.close()
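
Note that the "retract after clean up" elements in these tests are processed
without a matching expectedOutput.add, so verify only passes if the function
stays silent. A hypothetical harness fragment making that intent explicit,
reusing the names from the test above:

    // Hypothetical fragment: after the cleanup timer has fired, a retraction
    // for the cleaned-up key must produce no output at all.
    testHarness.setProcessingTime(3002) // fires the cleanup timer from above
    val sizeBefore = testHarness.getOutput.size()
    testHarness.processElement(
      new StreamRecord(CRow(false, 4L: JLong, 3: JInt, "ccc"), 4))
    // no expectedOutput.add(...) here: the retraction is dropped silently
    assert(testHarness.getOutput.size() == sizeBefore)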
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
index f70d991e50b..e5cceecc560 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
@@ -19,8 +19,9 @@ package org.apache.flink.table.runtime.harness
 
 import java.util.{Comparator, Queue => JQueue}
 
+import org.apache.flink.api.common.state.{MapStateDescriptor, StateDescriptor}
 import org.apache.flink.api.common.time.Time
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{INT_TYPE_INFO, LONG_TYPE_INFO, STRING_TYPE_INFO}
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{LONG_TYPE_INFO, STRING_TYPE_INFO}
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.functions.KeySelector
 import org.apache.flink.api.java.typeutils.RowTypeInfo
@@ -28,11 +29,11 @@ import org.apache.flink.streaming.api.operators.OneInputStreamOperator
 import org.apache.flink.streaming.api.watermark.Watermark
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
 import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, TestHarnessUtil}
-import org.apache.flink.table.api.StreamQueryConfig
+import org.apache.flink.table.api.{StreamQueryConfig, Types}
 import org.apache.flink.table.codegen.GeneratedAggregationsFunction
-import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
-import org.apache.flink.table.functions.aggfunctions.{IntSumWithRetractAggFunction, LongMaxWithRetractAggFunction, LongMinWithRetractAggFunction}
+import org.apache.flink.table.functions.aggfunctions.{CountAggFunction, IntSumWithRetractAggFunction, LongMaxWithRetractAggFunction, LongMinWithRetractAggFunction}
 import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.getAccumulatorTypeOfAggregateFunction
+import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
 import org.apache.flink.table.runtime.harness.HarnessTestBase.{RowResultSortComparator, RowResultSortComparatorWithWatermarks}
 import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
 import org.apache.flink.table.utils.EncodingUtils
@@ -55,20 +56,16 @@ class HarnessTestBase {
   val intSumWithRetractAggFunction: String =
     EncodingUtils.encodeObjectToString(new IntSumWithRetractAggFunction)
 
+  val distinctCountAggFunction: String =
+    EncodingUtils.encodeObjectToString(new CountAggFunction())
+
   protected val MinMaxRowType = new RowTypeInfo(Array[TypeInformation[_]](
     LONG_TYPE_INFO,
     STRING_TYPE_INFO,
     LONG_TYPE_INFO),
     Array("rowtime", "a", "b"))
 
-  protected val SumRowType = new RowTypeInfo(Array[TypeInformation[_]](
-    LONG_TYPE_INFO,
-    INT_TYPE_INFO,
-    STRING_TYPE_INFO),
-    Array("a", "b", "c"))
-
   protected val minMaxCRowType = new CRowTypeInfo(MinMaxRowType)
-  protected val sumCRowType = new CRowTypeInfo(SumRowType)
 
   protected val minMaxAggregates: Array[AggregateFunction[_, _]] =
     Array(new LongMinWithRetractAggFunction,
@@ -77,15 +74,28 @@ class HarnessTestBase {
   protected val sumAggregates: Array[AggregateFunction[_, _]] =
     Array(new IntSumWithRetractAggFunction).asInstanceOf[Array[AggregateFunction[_, _]]]
 
+  protected val distinctCountAggregates: Array[AggregateFunction[_, _]] =
+    Array(new CountAggFunction).asInstanceOf[Array[AggregateFunction[_, _]]]
+
   protected val minMaxAggregationStateType: RowTypeInfo =
     new RowTypeInfo(minMaxAggregates.map(getAccumulatorTypeOfAggregateFunction(_)): _*)
 
   protected val sumAggregationStateType: RowTypeInfo =
     new RowTypeInfo(sumAggregates.map(getAccumulatorTypeOfAggregateFunction(_)): _*)
 
+  protected val distinctCountAggregationStateType: RowTypeInfo =
+    new RowTypeInfo(distinctCountAggregates.map(getAccumulatorTypeOfAggregateFunction(_)): _*)
+
+  protected val distinctCountDescriptor: String = EncodingUtils.encodeObjectToString(
+    new MapStateDescriptor("distinctAgg0", distinctCountAggregationStateType, Types.LONG))
+
+  protected val minMaxFuncName = "MinMaxAggregateHelper"
+  protected val sumFuncName = "SumAggregationHelper"
+  protected val distinctCountFuncName = "DistinctCountAggregationHelper"
+
   val minMaxCode: String =
     s"""
-      |public class MinMaxAggregateHelper
+      |public class $minMaxFuncName
       |  extends org.apache.flink.table.runtime.aggregate.GeneratedAggregations {
       |
       |  transient org.apache.flink.table.functions.aggfunctions.LongMinWithRetractAggFunction
@@ -94,7 +104,7 @@ class HarnessTestBase {
       |  transient org.apache.flink.table.functions.aggfunctions.LongMaxWithRetractAggFunction
       |    fmax = null;
       |
-      |  public MinMaxAggregateHelper() throws Exception {
+      |  public $minMaxFuncName() throws Exception {
       |
       |    fmin = (org.apache.flink.table.functions.aggfunctions.LongMinWithRetractAggFunction)
       |    ${classOf[EncodingUtils].getCanonicalName}.decodeStringToObject(
@@ -207,25 +217,25 @@ class HarnessTestBase {
 
   val sumAggCode: String =
     s"""
-      |public final class SumAggregationHelper
+      |public final class $sumFuncName
       |  extends org.apache.flink.table.runtime.aggregate.GeneratedAggregations {
       |
       |
-      |transient org.apache.flink.table.functions.aggfunctions.IntSumWithRetractAggFunction
-      |sum = null;
-      |private final org.apache.flink.table.runtime.aggregate.SingleElementIterable<org.apache
+      |  transient org.apache.flink.table.functions.aggfunctions.IntSumWithRetractAggFunction
+      |  sum = null;
+      |  private final org.apache.flink.table.runtime.aggregate.SingleElementIterable<org.apache
       |    .flink.table.functions.aggfunctions.SumWithRetractAccumulator> accIt0 =
       |      new org.apache.flink.table.runtime.aggregate.SingleElementIterable<org.apache.flink
       |      .table
       |      .functions.aggfunctions.SumWithRetractAccumulator>();
       |
-      |  public SumAggregationHelper() throws Exception {
+      |  public $sumFuncName() throws Exception {
       |
-      |sum = (org.apache.flink.table.functions.aggfunctions.IntSumWithRetractAggFunction)
-      |${classOf[EncodingUtils].getCanonicalName}.decodeStringToObject(
-      |  "$intSumWithRetractAggFunction",
-      |  ${classOf[UserDefinedFunction].getCanonicalName}.class);
-      |}
+      |    sum = (org.apache.flink.table.functions.aggfunctions.IntSumWithRetractAggFunction)
+      |      ${classOf[EncodingUtils].getCanonicalName}.decodeStringToObject(
+      |        "$intSumWithRetractAggFunction",
+      |        ${classOf[UserDefinedFunction].getCanonicalName}.class);
+      |  }
       |
       |  public final void setAggregationResults(
       |    org.apache.flink.types.Row accs,
@@ -256,6 +266,12 @@ class HarnessTestBase {
       |  public final void retract(
       |    org.apache.flink.types.Row accs,
       |    org.apache.flink.types.Row input) {
+      |
+      |    sum.retract(
+      |      ((org.apache.flink.table.functions.aggfunctions.SumWithRetractAccumulator) accs
+      |      .getField
+      |      (0)),
+      |      (java.lang.Integer) input.getField(1));
       |  }
       |
       |  public final org.apache.flink.types.Row createAccumulators()
@@ -281,6 +297,162 @@ class HarnessTestBase {
       |      input.getField(0));
       |  }
       |
+      |  public final org.apache.flink.types.Row createOutputRow() {
+      |    return new org.apache.flink.types.Row(2);
+      |  }
+      |
+      |
+      |  public final org.apache.flink.types.Row mergeAccumulatorsPair(
+      |    org.apache.flink.types.Row a,
+      |    org.apache.flink.types.Row b)
+      |            {
+      |
+      |      return a;
+      |
+      |  }
+      |
+      |  public final void resetAccumulator(
+      |    org.apache.flink.types.Row accs) {
+      |  }
+      |
+      |  public void open(org.apache.flink.api.common.functions.RuntimeContext ctx) {
+      |  }
+      |
+      |  public void cleanup() {
+      |  }
+      |
+      |  public void close() {
+      |  }
+      |}
+      |""".stripMargin
+
+    val distinctCountAggCode: String =
+    s"""
+      |public final class $distinctCountFuncName
+      |  extends org.apache.flink.table.runtime.aggregate.GeneratedAggregations {
+      |
+      |  final org.apache.flink.table.functions.aggfunctions.CountAggFunction count;
+      |
+      |  final org.apache.flink.table.api.dataview.MapView acc0_distinctValueMap_dataview;
+      |
+      |  final java.lang.reflect.Field distinctValueMap =
+      |      org.apache.flink.api.java.typeutils.TypeExtractor.getDeclaredField(
+      |        org.apache.flink.table.functions.aggfunctions.DistinctAccumulator.class,
+      |        "distinctValueMap");
+      |
+      |
+      |  private final org.apache.flink.table.runtime.aggregate.SingleElementIterable<org.apache
+      |    .flink.table.functions.aggfunctions.CountAccumulator> accIt0 =
+      |      new org.apache.flink.table.runtime.aggregate.SingleElementIterable<org.apache.flink
+      |      .table
+      |      .functions.aggfunctions.CountAccumulator>();
+      |
+      |  public $distinctCountFuncName() throws Exception {
+      |
+      |    count = (org.apache.flink.table.functions.aggfunctions.CountAggFunction)
+      |      ${classOf[EncodingUtils].getCanonicalName}.decodeStringToObject(
+      |      "$distinctCountAggFunction",
+      |      ${classOf[UserDefinedFunction].getCanonicalName}.class);
+      |
+      |    distinctValueMap.setAccessible(true);
+      |  }
+      |
+      |  public void open(org.apache.flink.api.common.functions.RuntimeContext ctx) {
+      |    org.apache.flink.api.common.state.StateDescriptor acc0_distinctValueMap_dataview_desc =
+      |      (org.apache.flink.api.common.state.StateDescriptor)
+      |      ${classOf[EncodingUtils].getCanonicalName}.decodeStringToObject(
+      |      "$distinctCountDescriptor",
+      |      ${classOf[StateDescriptor[_, _]].getCanonicalName}.class,
+      |      ctx.getUserCodeClassLoader());
+      |    acc0_distinctValueMap_dataview = new org.apache.flink.table.dataview.StateMapView(
+      |          ctx.getMapState((org.apache.flink.api.common.state.MapStateDescriptor)
+      |          acc0_distinctValueMap_dataview_desc));
+      |  }
+      |
+      |  public final void setAggregationResults(
+      |    org.apache.flink.types.Row accs,
+      |    org.apache.flink.types.Row output) {
+      |
+      |    org.apache.flink.table.functions.AggregateFunction baseClass0 =
+      |      (org.apache.flink.table.functions.AggregateFunction)
+      |      count;
+      |
+      |    org.apache.flink.table.functions.aggfunctions.DistinctAccumulator distinctAcc0 =
+      |      (org.apache.flink.table.functions.aggfunctions.DistinctAccumulator) accs.getField(0);
+      |    org.apache.flink.table.functions.aggfunctions.CountAccumulator acc0 =
+      |      (org.apache.flink.table.functions.aggfunctions.CountAccumulator)
+      |      distinctAcc0.getRealAcc();
+      |
+      |    output.setField(1, baseClass0.getValue(acc0));
+      |  }
+      |
+      |  public final void accumulate(
+      |    org.apache.flink.types.Row accs,
+      |    org.apache.flink.types.Row input) throws Exception {
+      |
+      |    org.apache.flink.table.functions.aggfunctions.DistinctAccumulator distinctAcc0 =
+      |      (org.apache.flink.table.functions.aggfunctions.DistinctAccumulator) accs.getField(0);
+      |
+      |    distinctValueMap.set(distinctAcc0, acc0_distinctValueMap_dataview);
+      |
+      |    if (distinctAcc0.add(
+      |          org.apache.flink.types.Row.of((java.lang.Integer) input.getField(1)))) {
+      |        org.apache.flink.table.functions.aggfunctions.CountAccumulator acc0 =
+      |          (org.apache.flink.table.functions.aggfunctions.CountAccumulator)
+      |          distinctAcc0.getRealAcc();
+      |
+      |
+      |        count.accumulate(acc0, (java.lang.Integer) input.getField(1));
+      |    }
+      |  }
+      |
+      |  public final void retract(
+      |    org.apache.flink.types.Row accs,
+      |    org.apache.flink.types.Row input) throws Exception {
+      |
+      |    org.apache.flink.table.functions.aggfunctions.DistinctAccumulator distinctAcc0 =
+      |      (org.apache.flink.table.functions.aggfunctions.DistinctAccumulator) accs.getField(0);
+      |
+      |    distinctValueMap.set(distinctAcc0, acc0_distinctValueMap_dataview);
+      |
+      |    if (distinctAcc0.remove(
+      |          org.apache.flink.types.Row.of((java.lang.Integer) input.getField(1)))) {
+      |        org.apache.flink.table.functions.aggfunctions.CountAccumulator acc0 =
+      |          (org.apache.flink.table.functions.aggfunctions.CountAccumulator)
+      |            distinctAcc0.getRealAcc();
+      |
+      |        count.retract(acc0 , (java.lang.Integer) input.getField(1));
+      |      }
+      |  }
+      |
+      |  public final org.apache.flink.types.Row createAccumulators()
+      |     {
+      |
+      |      org.apache.flink.types.Row accs = new org.apache.flink.types.Row(1);
+      |
+      |      org.apache.flink.table.functions.aggfunctions.CountAccumulator acc0 =
+      |        (org.apache.flink.table.functions.aggfunctions.CountAccumulator)
+      |        count.createAccumulator();
+      |      org.apache.flink.table.functions.aggfunctions.DistinctAccumulator distinctAcc0 =
+      |        (org.apache.flink.table.functions.aggfunctions.DistinctAccumulator)
+      |        new org.apache.flink.table.functions.aggfunctions.DistinctAccumulator (acc0);
+      |      accs.setField(
+      |        0,
+      |        distinctAcc0);
+      |
+      |        return accs;
+      |  }
+      |
+      |  public final void setForwardedFields(
+      |    org.apache.flink.types.Row input,
+      |    org.apache.flink.types.Row output)
+      |     {
+      |
+      |    output.setField(
+      |      0,
+      |      input.getField(0));
+      |  }
+      |
       |  public final void setConstantFlags(org.apache.flink.types.Row output)
       |     {
       |
@@ -304,10 +476,8 @@ class HarnessTestBase {
       |    org.apache.flink.types.Row accs) {
       |  }
       |
-      |  public void open(org.apache.flink.api.common.functions.RuntimeContext ctx) {
-      |  }
-      |
       |  public void cleanup() {
+      |    acc0_distinctValueMap_dataview.clear();
       |  }
       |
       |  public void close() {
@@ -315,12 +485,11 @@ class HarnessTestBase {
       |}
       |""".stripMargin
 
-
-  protected val minMaxFuncName = "MinMaxAggregateHelper"
-  protected val sumFuncName = "SumAggregationHelper"
-
   protected val genMinMaxAggFunction = GeneratedAggregationsFunction(minMaxFuncName, minMaxCode)
   protected val genSumAggFunction = GeneratedAggregationsFunction(sumFuncName, sumAggCode)
+  protected val genDistinctCountAggFunction = GeneratedAggregationsFunction(
+    distinctCountFuncName,
+    distinctCountAggCode)
 
   def createHarnessTester[IN, OUT, KEY](
     operator: OneInputStreamOperator[IN, OUT],
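
The base-class changes mirror how the planner hands state to generated code:
the MapStateDescriptor is serialized to a string (distinctCountDescriptor) and
decoded again inside the generated open(...). A hedged sketch of that round
trip, using only the EncodingUtils calls that appear in this diff:

    // Sketch: encode an object to a string and decode it back, as the
    // generated aggregation code does for functions and state descriptors.
    import org.apache.flink.table.functions.UserDefinedFunction
    import org.apache.flink.table.functions.aggfunctions.CountAggFunction
    import org.apache.flink.table.utils.EncodingUtils

    val encoded: String = EncodingUtils.encodeObjectToString(new CountAggFunction)
    // throws ValidationException (now with a chained cause) on failure
    val decoded: UserDefinedFunction = EncodingUtils.decodeStringToObject(
      encoded, classOf[UserDefinedFunction])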
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
index c3da65f887a..46dde8e0225 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
@@ -263,6 +263,43 @@ class SqlITCase extends StreamingWithStateTestBase {
     assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
   }
 
+  @Test
+  def testDistinctWithRetraction(): Unit = {
+
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val data = new mutable.MutableList[(Int, Long, String)]
+    data.+=((1, 1L, "Hi"))
+    data.+=((1, 1L, "Hi World"))
+    data.+=((1, 1L, "Test"))
+    data.+=((2, 1L, "Hi World"))
+    data.+=((2, 1L, "Test"))
+    data.+=((3, 1L, "Hi World"))
+    data.+=((3, 1L, "Hi World"))
+    data.+=((3, 1L, "Hi World"))
+    data.+=((4, 1L, "Hi World"))
+    data.+=((4, 1L, "Test"))
+
+    val t = env.fromCollection(data).toTable(tEnv).as('a, 'b, 'c)
+    tEnv.registerTable("MyTable", t)
+
+    // "1,1,3", "2,1,2", "3,1,1", "4,1,2"
+    val distinct = "SELECT a, COUNT(DISTINCT b) AS distinct_b, COUNT(DISTINCT c) AS distinct_c " +
+      "FROM MyTable GROUP BY a"
+    val nestedDistinct = s"SELECT distinct_b, COUNT(DISTINCT distinct_c) " +
+      s"FROM ($distinct) GROUP BY distinct_b"
+
+    val result = tEnv.sqlQuery(nestedDistinct).toRetractStream[Row]
+    result.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+
+    env.execute()
+
+    val expected = List("1,3")
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+  }
+
   @Test
   def testUnboundedGroupByCollect(): Unit = {
 
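The expected result List("1,3") follows from the comment in the test: the inner
query produces (1,1,3), (2,1,2), (3,1,1), (4,1,2), so all four rows share
distinct_b = 1 and the outer query counts three distinct distinct_c values
{3, 2, 1}. A throwaway plain-Scala check of that arithmetic (no Flink involved):

    // Re-computes the IT case's expected result with plain collections.
    val data = Seq(
      (1, 1L, "Hi"), (1, 1L, "Hi World"), (1, 1L, "Test"),
      (2, 1L, "Hi World"), (2, 1L, "Test"),
      (3, 1L, "Hi World"), (3, 1L, "Hi World"), (3, 1L, "Hi World"),
      (4, 1L, "Hi World"), (4, 1L, "Test"))

    // inner: SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c) ... GROUP BY a
    val inner = data.groupBy(_._1).toSeq.map { case (a, rows) =>
      (a, rows.map(_._2).distinct.size, rows.map(_._3).distinct.size)
    }
    // inner.sorted == Seq((1,1,3), (2,1,2), (3,1,1), (4,1,2))

    // outer: SELECT distinct_b, COUNT(DISTINCT distinct_c) ... GROUP BY distinct_b
    val outer = inner.groupBy(_._2).toSeq.map { case (b, rows) =>
      s"$b,${rows.map(_._3).distinct.size}"
    }
    assert(outer == Seq("1,3"))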


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services