Posted to commits@flink.apache.org by fh...@apache.org on 2017/08/29 20:11:10 UTC

[3/5] flink git commit: [FLINK-7206] [table] Add DataView to support direct state access in AggregateFunction accumulators.
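
With this change, an AggregateFunction accumulator can declare ListView/MapView fields
that the runtime backs with Flink keyed state, so large collections are no longer
serialized with the accumulator on every access. A minimal sketch of such an accumulator
in Scala (illustrative names; the MapView calls match their usage in the diff below):

    import java.lang.{Integer => JInt, Long => JLong}

    import org.apache.flink.api.common.typeinfo.Types
    import org.apache.flink.table.api.dataview.MapView
    import org.apache.flink.table.functions.AggregateFunction

    // Accumulator with a state-backed view: in a streaming query the MapView
    // is wired to a MapState by the generated function's open().
    class DistinctCountAcc {
      var map: MapView[String, JInt] = new MapView[String, JInt](Types.STRING, Types.INT)
      var count: Long = 0L
    }

    class DistinctCount extends AggregateFunction[JLong, DistinctCountAcc] {
      override def createAccumulator(): DistinctCountAcc = new DistinctCountAcc

      def accumulate(acc: DistinctCountAcc, id: String): Unit = {
        val cnt = acc.map.get(id) // reads go through state, not a heap map
        if (cnt == null) {
          acc.map.put(id, 1)
          acc.count += 1
        } else {
          acc.map.put(id, cnt + 1)
        }
      }

      override def getValue(acc: DistinctCountAcc): JLong = acc.count
    }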

http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala
index ccc4b46..09d98ad 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala
@@ -23,11 +23,8 @@ import org.apache.flink.configuration.Configuration
 import org.apache.flink.streaming.api.functions.ProcessFunction
 import org.apache.flink.types.Row
 import org.apache.flink.util.{Collector, Preconditions}
-import org.apache.flink.api.common.state.ValueStateDescriptor
+import org.apache.flink.api.common.state._
 import org.apache.flink.api.java.typeutils.RowTypeInfo
-import org.apache.flink.api.common.state.ValueState
-import org.apache.flink.api.common.state.MapState
-import org.apache.flink.api.common.state.MapStateDescriptor
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.typeutils.ListTypeInfo
 import java.util.{List => JList}
@@ -75,6 +72,7 @@ class ProcTimeBoundedRowsOver(
       genAggregations.code)
     LOG.debug("Instantiating AggregateHelper.")
     function = clazz.newInstance()
+    function.open(getRuntimeContext)
 
     output = new CRow(function.createOutputRow(), true)
     // We keep the elements received in a Map state keyed
@@ -194,6 +192,11 @@ class ProcTimeBoundedRowsOver(
 
     if (needToCleanupState(timestamp)) {
       cleanupState(rowMapState, accumulatorState, counterState, smallestTsState)
+      function.cleanup()
     }
   }
+
+  override def close(): Unit = {
+    function.close()
+  }
 }
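
The new function.open(getRuntimeContext) call is what gives accumulator DataViews access
to keyed state, and function.cleanup()/close() complete the lifecycle (the same pattern
is applied to the other over-window functions below). Conceptually, the generated open()
amounts to something like the following hypothetical sketch (not the actual generated
code; the descriptor name is made up):

    import org.apache.flink.api.common.functions.RuntimeContext
    import org.apache.flink.api.common.state.{MapState, MapStateDescriptor}
    import org.apache.flink.api.common.typeinfo.Types

    class GeneratedAggSketch {
      private var mapState: MapState[String, Integer] = _

      // For each MapView field of the accumulator, register per-key state and
      // redirect the view to it, so the accumulator row no longer carries the
      // map's contents.
      def open(ctx: RuntimeContext): Unit = {
        val desc = new MapStateDescriptor[String, Integer](
          "acc0$map", Types.STRING, Types.INT)
        mapState = ctx.getMapState(desc)
      }

      // Called when the operator cleans up the state of the current key.
      def cleanup(): Unit = mapState.clear()

      // Forwards close() to the user-defined functions (see the new override above).
      def close(): Unit = ()
    }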

http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedOver.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedOver.scala
index 7a7b44d..4fb5595 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedOver.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedOver.scala
@@ -21,9 +21,8 @@ import org.apache.flink.configuration.Configuration
 import org.apache.flink.streaming.api.functions.ProcessFunction
 import org.apache.flink.types.Row
 import org.apache.flink.util.Collector
-import org.apache.flink.api.common.state.ValueStateDescriptor
+import org.apache.flink.api.common.state.{StateDescriptor, ValueState, ValueStateDescriptor}
 import org.apache.flink.api.java.typeutils.RowTypeInfo
-import org.apache.flink.api.common.state.ValueState
 import org.apache.flink.table.api.StreamQueryConfig
 import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
 import org.apache.flink.table.runtime.types.CRow
@@ -56,6 +55,7 @@ class ProcTimeUnboundedOver(
       genAggregations.code)
     LOG.debug("Instantiating AggregateHelper.")
     function = clazz.newInstance()
+    function.open(getRuntimeContext)
 
     output = new CRow(function.createOutputRow(), true)
     val stateDescriptor: ValueStateDescriptor[Row] =
@@ -97,6 +97,11 @@ class ProcTimeUnboundedOver(
 
     if (needToCleanupState(timestamp)) {
       cleanupState(state)
+      function.cleanup()
     }
   }
+  
+  override def close(): Unit = {
+    function.close()
+  }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala
index 8a0d682..1ee2693 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala
@@ -78,6 +78,7 @@ class RowTimeBoundedRangeOver(
       genAggregations.code)
     LOG.debug("Instantiating AggregateHelper.")
     function = clazz.newInstance()
+    function.open(getRuntimeContext)
 
     output = new CRow(function.createOutputRow(), true)
 
@@ -158,6 +159,7 @@ class RowTimeBoundedRangeOver(
         if (noRecordsToProcess) {
           // we clean the state
           cleanupState(dataState, accumulatorState, lastTriggeringTsState)
+          function.cleanup()
         } else {
           // There are records left to process because a watermark has not been received yet.
           // This would only happen if the input stream has stopped. So we don't need to clean up.
@@ -242,6 +244,10 @@ class RowTimeBoundedRangeOver(
     // update cleanup timer
     registerProcessingCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
   }
+
+  override def close(): Unit = {
+    function.close()
+  }
 }
 
 

http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala
index ba65846..60200bc 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala
@@ -84,6 +84,7 @@ class RowTimeBoundedRowsOver(
       genAggregations.code)
     LOG.debug("Instantiating AggregateHelper.")
     function = clazz.newInstance()
+    function.open(getRuntimeContext)
 
     output = new CRow(function.createOutputRow(), true)
 
@@ -168,6 +169,7 @@ class RowTimeBoundedRowsOver(
         if (noRecordsToProcess) {
           // We clean the state
           cleanupState(dataState, accumulatorState, dataCountState, lastTriggeringTsState)
+          function.cleanup()
         } else {
           // There are records left to process because a watermark has not been received yet.
           // This would only happen if the input stream has stopped. So we don't need to clean up.
@@ -264,6 +266,10 @@ class RowTimeBoundedRowsOver(
     // update cleanup timer
     registerProcessingCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
   }
+
+  override def close(): Unit = {
+    function.close()
+  }
 }
 
 

http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeUnboundedOver.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeUnboundedOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeUnboundedOver.scala
index 9210c00..c8183c9 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeUnboundedOver.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeUnboundedOver.scala
@@ -71,6 +71,7 @@ abstract class RowTimeUnboundedOver(
       genAggregations.code)
     LOG.debug("Instantiating AggregateHelper.")
     function = clazz.newInstance()
+    function.open(getRuntimeContext)
 
     output = new CRow(function.createOutputRow(), true)
     sortedTimestamps = new util.LinkedList[Long]()
@@ -150,6 +151,7 @@ abstract class RowTimeUnboundedOver(
         if (noRecordsToProcess) {
           // we clean the state
           cleanupState(rowMapState, accumulatorState)
+          function.cleanup()
         } else {
           // There are records left to process because a watermark has not been received yet.
           // This would only happen if the input stream has stopped. So we don't need to clean up.
@@ -241,6 +243,9 @@ abstract class RowTimeUnboundedOver(
     lastAccumulator: Row,
     out: Collector[CRow]): Unit
 
+  override def close(): Unit = {
+    function.close()
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java
index 4d06bc2..14f812a 100644
--- a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java
+++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java
@@ -18,7 +18,10 @@
 
 package org.apache.flink.table.runtime.utils;
 
+import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.table.api.dataview.ListView;
+import org.apache.flink.table.api.dataview.MapView;
 import org.apache.flink.table.functions.AggregateFunction;
 
 import java.util.Iterator;
@@ -135,4 +138,200 @@ public class JavaUserDefinedAggFunctions {
 			accumulator.count -= iWeight;
 		}
 	}
+
+	/**
+	 * CountDistinct accumulator.
+	 */
+	public static class CountDistinctAccum {
+		public MapView<String, Integer> map;
+		public long count;
+	}
+
+	/**
+	 * CountDistinct aggregate.
+	 */
+	public static class CountDistinct extends AggregateFunction<Long, CountDistinctAccum> {
+
+		@Override
+		public CountDistinctAccum createAccumulator() {
+			CountDistinctAccum accum = new CountDistinctAccum();
+			accum.map = new MapView<>(Types.STRING, Types.INT);
+			accum.count = 0L;
+			return accum;
+		}
+
+		//Overloaded accumulate method
+		public void accumulate(CountDistinctAccum accumulator, String id) {
+			try {
+				Integer cnt = accumulator.map.get(id);
+				if (cnt != null) {
+					cnt += 1;
+					accumulator.map.put(id, cnt);
+				} else {
+					accumulator.map.put(id, 1);
+					accumulator.count += 1;
+				}
+			} catch (Exception e) {
+				e.printStackTrace();
+			}
+		}
+
+		//Overloaded accumulate method
+		public void accumulate(CountDistinctAccum accumulator, long id) {
+			try {
+				Integer cnt = accumulator.map.get(String.valueOf(id));
+				if (cnt != null) {
+					cnt += 1;
+					accumulator.map.put(String.valueOf(id), cnt);
+				} else {
+					accumulator.map.put(String.valueOf(id), 1);
+					accumulator.count += 1;
+				}
+			} catch (Exception e) {
+				e.printStackTrace();
+			}
+		}
+
+		@Override
+		public Long getValue(CountDistinctAccum accumulator) {
+			return accumulator.count;
+		}
+	}
+
+	/**
+	 * CountDistinct aggregate with merge.
+	 */
+	public static class CountDistinctWithMerge extends CountDistinct {
+
+		//Overloaded merge method
+		public void merge(CountDistinctAccum acc, Iterable<CountDistinctAccum> it) {
+			Iterator<CountDistinctAccum> iter = it.iterator();
+			while (iter.hasNext()) {
+				CountDistinctAccum mergeAcc = iter.next();
+				acc.count += mergeAcc.count;
+
+				try {
+					Iterator<String> itrMap = mergeAcc.map.keys().iterator();
+					while (itrMap.hasNext()) {
+						String key = itrMap.next();
+						Integer cnt = mergeAcc.map.get(key);
+						if (acc.map.contains(key)) {
+							acc.map.put(key, acc.map.get(key) + cnt);
+						} else {
+							acc.map.put(key, cnt);
+						}
+					}
+				} catch (Exception e) {
+					e.printStackTrace();
+				}
+			}
+		}
+	}
+
+	/**
+	 * CountDistinct aggregate with merge and reset.
+	 */
+	public static class CountDistinctWithMergeAndReset extends CountDistinctWithMerge {
+
+		//Resets the accumulator
+		public void resetAccumulator(CountDistinctAccum acc) {
+			acc.map.clear();
+			acc.count = 0;
+		}
+	}
+
+	/**
+	 * CountDistinct aggregate with retract.
+	 */
+	public static class CountDistinctWithRetractAndReset extends CountDistinct {
+
+		//Overloaded retract method
+		public void retract(CountDistinctAccum accumulator, long id) {
+			try {
+				Integer cnt = accumulator.map.get(String.valueOf(id));
+				if (cnt != null) {
+					cnt -= 1;
+					if (cnt <= 0) {
+						accumulator.map.remove(String.valueOf(id));
+						accumulator.count -= 1;
+					} else {
+						accumulator.map.put(String.valueOf(id), cnt);
+					}
+				}
+			} catch (Exception e) {
+				e.printStackTrace();
+			}
+		}
+
+		//Resets the accumulator
+		public void resetAccumulator(CountDistinctAccum acc) {
+			acc.map.clear();
+			acc.count = 0;
+		}
+	}
+
+	/**
+	 * Accumulator for test DataView.
+	 */
+	public static class DataViewTestAccum {
+		public MapView<String, Integer> map;
+		public MapView<String, Integer> map2; // deliberately left uninitialized to test handling of uninitialized views
+		public long count;
+		private ListView<Long> list = new ListView<>(Types.LONG);
+
+		public ListView<Long> getList() {
+			return list;
+		}
+
+		public void setList(ListView<Long> list) {
+			this.list = list;
+		}
+	}
+
+	public static boolean isCloseCalled = false;
+
+	/**
+	 * Aggregate for test DataView.
+	 */
+	public static class DataViewTestAgg extends AggregateFunction<Long, DataViewTestAccum> {
+		@Override
+		public DataViewTestAccum createAccumulator() {
+			DataViewTestAccum accum = new DataViewTestAccum();
+			accum.map = new MapView<>(Types.STRING, Types.INT);
+			accum.count = 0L;
+			return accum;
+		}
+
+		// Overloaded accumulate method
+		public void accumulate(DataViewTestAccum accumulator, String a, long b) {
+			try {
+				if (!accumulator.map.contains(a)) {
+					accumulator.map.put(a, 1);
+					accumulator.count++;
+				}
+
+				accumulator.list.add(b);
+			} catch (Exception e) {
+				e.printStackTrace();
+			}
+		}
+
+		@Override
+		public Long getValue(DataViewTestAccum accumulator) {
+			long sum = accumulator.count;
+			try {
+				for (Long value : accumulator.list.get()) {
+					sum += value;
+				}
+			} catch (Exception e) {
+				e.printStackTrace();
+			}
+			return sum;
+		}
+
+		@Override
+		public void close() {
+			isCloseCalled = true;
+		}
+	}
 }
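
DataViewTestAccum above exercises both supported declaration styles: a public view field
(map) and a JavaBean-style private field with getter and setter (list), plus a
deliberately uninitialized view (map2). For reference, a hedged Scala equivalent of the
same accumulator shape:

    import java.lang.{Integer => JInt, Long => JLong}

    import org.apache.flink.api.common.typeinfo.Types
    import org.apache.flink.table.api.dataview.{ListView, MapView}

    class DataViewAccSketch {
      // public field style: picked up directly
      var map: MapView[String, JInt] = new MapView[String, JInt](Types.STRING, Types.INT)
      var count: Long = 0L

      // JavaBean property style, mirroring getList/setList above
      private var list: ListView[JLong] = new ListView[JLong](Types.LONG)
      def getList: ListView[JLong] = list
      def setList(l: ListView[JLong]): Unit = { list = l }
    }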

http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/dataview/ListViewSerializerTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/dataview/ListViewSerializerTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/dataview/ListViewSerializerTest.scala
new file mode 100644
index 0000000..3f70bce
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/dataview/ListViewSerializerTest.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.dataview
+
+import java.lang.Long
+import java.util.Random
+
+import org.apache.flink.api.common.typeutils.base.{ListSerializer, LongSerializer}
+import org.apache.flink.api.common.typeutils.{SerializerTestBase, TypeSerializer}
+import org.apache.flink.table.api.dataview.ListView
+
+/**
+  * A test for the [[ListViewSerializer]].
+  */
+class ListViewSerializerTest extends SerializerTestBase[ListView[Long]] {
+
+  override protected def createSerializer(): TypeSerializer[ListView[Long]] = {
+    val listSerializer = new ListSerializer[Long](LongSerializer.INSTANCE)
+    new ListViewSerializer[Long](listSerializer)
+  }
+
+  override protected def getLength: Int = -1
+
+  override protected def getTypeClass: Class[ListView[Long]] = classOf[ListView[Long]]
+
+  override protected def getTestData: Array[ListView[Long]] = {
+    val rnd = new Random(321332)
+
+    // empty
+    val listview1 = new ListView[Long]()
+
+    // single element
+    val listview2 = new ListView[Long]()
+    listview2.add(12345L)
+
+    // multiple elements
+    val listview3 = new ListView[Long]()
+    var i = 0
+    while (i < rnd.nextInt(200)) {
+      listview3.add(rnd.nextLong)
+      i += 1
+    }
+
+    Array[ListView[Long]](listview1, listview2, listview3)
+  }
+}
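
ListViewSerializer wraps an ordinary ListSerializer for the element type, so a ListView
round-trips like a plain Java list. A minimal round-trip sketch, assuming Flink's
standard stream wrappers for DataOutputView/DataInputView:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
    import java.lang.{Long => JLong}

    import org.apache.flink.api.common.typeutils.base.{ListSerializer, LongSerializer}
    import org.apache.flink.core.memory.{DataInputViewStreamWrapper, DataOutputViewStreamWrapper}
    import org.apache.flink.table.api.dataview.ListView
    import org.apache.flink.table.dataview.ListViewSerializer

    object ListViewRoundTrip extends App {
      val serializer =
        new ListViewSerializer[JLong](new ListSerializer[JLong](LongSerializer.INSTANCE))

      val view = new ListView[JLong]()
      view.add(12345L)

      // serialize into a byte array, then deserialize a copy from it
      val out = new ByteArrayOutputStream()
      serializer.serialize(view, new DataOutputViewStreamWrapper(out))
      val copy = serializer.deserialize(
        new DataInputViewStreamWrapper(new ByteArrayInputStream(out.toByteArray)))
      // copy now holds the same elements as view
    }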

http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/dataview/MapViewSerializerTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/dataview/MapViewSerializerTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/dataview/MapViewSerializerTest.scala
new file mode 100644
index 0000000..15f9b02
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/dataview/MapViewSerializerTest.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.dataview
+
+import java.lang.Long
+import java.util.Random
+
+import org.apache.flink.api.common.typeutils.base.{LongSerializer, MapSerializer, StringSerializer}
+import org.apache.flink.api.common.typeutils.{SerializerTestBase, TypeSerializer}
+import org.apache.flink.table.api.dataview.MapView
+
+/**
+  * A test for the [[MapViewSerializer]].
+  */
+class MapViewSerializerTest extends SerializerTestBase[MapView[Long, String]] {
+
+  override protected def createSerializer(): TypeSerializer[MapView[Long, String]] = {
+    val mapSerializer = new MapSerializer[Long, String](LongSerializer.INSTANCE,
+      StringSerializer.INSTANCE)
+    new MapViewSerializer[Long, String](mapSerializer)
+  }
+
+  override protected def getLength: Int = -1
+
+  override protected def getTypeClass: Class[MapView[Long, String]] =
+    classOf[MapView[Long, String]]
+
+  override protected def getTestData: Array[MapView[Long, String]] = {
+    val rnd = new Random(321654)
+
+    // empty
+    val mapview1 = new MapView[Long, String]()
+
+    // single element
+    val mapview2 = new MapView[Long, String]()
+    mapview2.put(12345L, "12345L")
+
+    // multiple elements
+    val mapview3 = new MapView[Long, String]()
+    var i = 0
+    while (i < rnd.nextInt(200)) {
+      mapview3.put(rnd.nextLong, Long.toString(rnd.nextLong))
+      i += 1
+    }
+
+    // null-value maps
+    val mapview4 = new MapView[Long, String]()
+    mapview4.put(999L, null)
+
+    Array[MapView[Long, String]](mapview1, mapview2, mapview3, mapview4)
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala
index d563f96..cf96d19 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala
@@ -23,7 +23,7 @@ import java.math.BigDecimal
 import org.apache.flink.api.scala._
 import org.apache.flink.api.scala.util.CollectionDataSets
 import org.apache.flink.table.api.TableEnvironment
-import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.WeightedAvgWithMergeAndReset
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinctWithMergeAndReset, WeightedAvgWithMergeAndReset}
 import org.apache.flink.table.api.scala._
 import org.apache.flink.table.functions.aggfunctions.CountAggFunction
 import org.apache.flink.table.runtime.utils.TableProgramsCollectionTestBase
@@ -226,13 +226,14 @@ class AggregationsITCase(
     val tEnv = TableEnvironment.getTableEnvironment(env, config)
     val countFun = new CountAggFunction
     val wAvgFun = new WeightedAvgWithMergeAndReset
+    val countDistinct = new CountDistinctWithMergeAndReset
 
     val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
       .groupBy('b)
-      .select('b, 'a.sum, countFun('c), wAvgFun('b, 'a), wAvgFun('a, 'a))
+      .select('b, 'a.sum, countFun('c), wAvgFun('b, 'a), wAvgFun('a, 'a), countDistinct('c))
 
-    val expected = "1,1,1,1,1\n" + "2,5,2,2,2\n" + "3,15,3,3,5\n" + "4,34,4,4,8\n" +
-      "5,65,5,5,13\n" + "6,111,6,6,18\n"
+    val expected = "1,1,1,1,1,1\n" + "2,5,2,2,2,2\n" + "3,15,3,3,5,3\n" + "4,34,4,4,8,4\n" +
+      "5,65,5,5,13,5\n" + "6,111,6,6,18,6\n"
     val results = t.toDataSet[Row].collect()
     TestBaseUtils.compareResultAsText(results.asJava, expected)
   }

http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
index 04aada6..67164b7 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
@@ -170,6 +170,14 @@ class HarnessTestBase {
       |    return new org.apache.flink.types.Row(5);
       |  }
       |
+      |  public void open(org.apache.flink.api.common.functions.RuntimeContext ctx) {
+      |  }
+      |
+      |  public void cleanup() {
+      |  }
+      |
+      |  public void close() {
+      |  }
       |/*******  This test does not use the following methods  *******/
       |  public org.apache.flink.types.Row mergeAccumulatorsPair(
       |    org.apache.flink.types.Row a,
@@ -282,6 +290,15 @@ class HarnessTestBase {
       |  public final void resetAccumulator(
       |    org.apache.flink.types.Row accs) {
       |  }
+      |
+      |  public void open(org.apache.flink.api.common.functions.RuntimeContext ctx) {
+      |  }
+      |
+      |  public void cleanup() {
+      |  }
+      |
+      |  public void close() {
+      |  }
       |}
       |""".stripMargin
 

http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
index 744ac46..eb3d37f 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
@@ -24,7 +24,8 @@ import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
 import org.apache.flink.table.api.scala._
 import org.apache.flink.table.runtime.utils.StreamITCase.RetractingSink
 import org.apache.flink.table.api.{StreamQueryConfig, TableEnvironment}
-import org.apache.flink.table.runtime.utils.{StreamITCase, StreamTestData, StreamingWithStateTestBase}
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinct, DataViewTestAgg}
+import org.apache.flink.table.runtime.utils.{JavaUserDefinedAggFunctions, StreamITCase, StreamTestData, StreamingWithStateTestBase}
 import org.apache.flink.types.Row
 import org.junit.Assert.assertEquals
 import org.junit.Test
@@ -154,4 +155,42 @@ class AggregateITCase extends StreamingWithStateTestBase {
       "12,3,5,1", "5,3,4,2")
     assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
   }
+
+  @Test
+  def testGroupAggregateWithStateBackend(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val data = new mutable.MutableList[(Int, Long, String)]
+    data.+=((1, 1L, "A"))
+    data.+=((2, 2L, "B"))
+    data.+=((3, 2L, "B"))
+    data.+=((4, 3L, "C"))
+    data.+=((5, 3L, "C"))
+    data.+=((6, 3L, "C"))
+    data.+=((7, 4L, "B"))
+    data.+=((8, 4L, "A"))
+    data.+=((9, 4L, "D"))
+    data.+=((10, 4L, "E"))
+    data.+=((11, 5L, "A"))
+    data.+=((12, 5L, "B"))
+
+    val distinct = new CountDistinct
+    val testAgg = new DataViewTestAgg
+    val t = env.fromCollection(data).toTable(tEnv, 'a, 'b, 'c)
+      .groupBy('b)
+      .select('b, distinct('c), testAgg('c, 'b))
+
+    val results = t.toRetractStream[Row](queryConfig)
+    results.addSink(new StreamITCase.RetractingSink)
+    env.execute()
+
+    val expected = List("1,1,2", "2,1,5", "3,1,10", "4,4,20", "5,2,12")
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+
+    // verify agg close is called
+    assert(JavaUserDefinedAggFunctions.isCloseCalled)
+  }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/GroupWindowITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/GroupWindowITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/GroupWindowITCase.scala
index 1561da0..f6e739e 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/GroupWindowITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/GroupWindowITCase.scala
@@ -29,7 +29,7 @@ import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks
 import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
 import org.apache.flink.streaming.api.watermark.Watermark
 import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
-import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{WeightedAvg, WeightedAvgWithMerge}
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinct, CountDistinctWithMerge, WeightedAvg, WeightedAvgWithMerge}
 import org.apache.flink.table.functions.aggfunctions.CountAggFunction
 import org.apache.flink.table.runtime.stream.table.GroupWindowITCase._
 import org.apache.flink.table.runtime.utils.StreamITCase
@@ -75,19 +75,21 @@ class GroupWindowITCase extends StreamingMultipleProgramsTestBase {
 
     val countFun = new CountAggFunction
     val weightAvgFun = new WeightedAvg
+    val countDistinct = new CountDistinct
 
     val windowedTable = table
       .window(Slide over 2.rows every 1.rows on 'proctime as 'w)
       .groupBy('w, 'string)
       .select('string, countFun('int), 'int.avg,
-              weightAvgFun('long, 'int), weightAvgFun('int, 'int))
+        weightAvgFun('long, 'int), weightAvgFun('int, 'int),
+        countDistinct('long))
 
     val results = windowedTable.toAppendStream[Row](queryConfig)
     results.addSink(new StreamITCase.StringSink[Row])
     env.execute()
 
-    val expected = Seq("Hello world,1,3,8,3", "Hello world,2,3,12,3", "Hello,1,2,2,2",
-                       "Hello,2,2,3,2", "Hi,1,1,1,1")
+    val expected = Seq("Hello world,1,3,8,3,1", "Hello world,2,3,12,3,2", "Hello,1,2,2,2,1",
+      "Hello,2,2,3,2,2", "Hi,1,1,1,1,1")
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
 
@@ -112,6 +114,7 @@ class GroupWindowITCase extends StreamingMultipleProgramsTestBase {
 
     val countFun = new CountAggFunction
     val weightAvgFun = new WeightedAvgWithMerge
+    val countDistinct = new CountDistinctWithMerge
 
     val stream = env
       .fromCollection(sessionWindowTestdata)
@@ -122,13 +125,14 @@ class GroupWindowITCase extends StreamingMultipleProgramsTestBase {
       .window(Session withGap 5.milli on 'rowtime as 'w)
       .groupBy('w, 'string)
       .select('string, countFun('int), 'int.avg,
-              weightAvgFun('long, 'int), weightAvgFun('int, 'int))
+        weightAvgFun('long, 'int), weightAvgFun('int, 'int),
+        countDistinct('long))
 
     val results = windowedTable.toAppendStream[Row]
     results.addSink(new StreamITCase.StringSink[Row])
     env.execute()
 
-    val expected = Seq("Hello World,1,9,9,9", "Hello,1,16,16,16", "Hello,4,3,5,5")
+    val expected = Seq("Hello World,1,9,9,9,1", "Hello,1,16,16,16,1", "Hello,4,3,5,5,4")
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
 
@@ -143,18 +147,21 @@ class GroupWindowITCase extends StreamingMultipleProgramsTestBase {
     val table = stream.toTable(tEnv, 'long, 'int, 'string, 'proctime.proctime)
     val countFun = new CountAggFunction
     val weightAvgFun = new WeightedAvg
+    val countDistinct = new CountDistinct
 
     val windowedTable = table
       .window(Tumble over 2.rows on 'proctime as 'w)
       .groupBy('w)
       .select(countFun('string), 'int.avg,
-              weightAvgFun('long, 'int), weightAvgFun('int, 'int))
+        weightAvgFun('long, 'int), weightAvgFun('int, 'int),
+        countDistinct('long)
+      )
 
     val results = windowedTable.toAppendStream[Row](queryConfig)
     results.addSink(new StreamITCase.StringSink[Row])
     env.execute()
 
-    val expected = Seq("2,1,1,1", "2,2,6,2")
+    val expected = Seq("2,1,1,1,2", "2,2,6,2,2")
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
 
@@ -171,22 +178,24 @@ class GroupWindowITCase extends StreamingMultipleProgramsTestBase {
     val table = stream.toTable(tEnv, 'long, 'int, 'string, 'rowtime.rowtime)
     val countFun = new CountAggFunction
     val weightAvgFun = new WeightedAvg
+    val countDistinct = new CountDistinct
 
     val windowedTable = table
       .window(Tumble over 5.milli on 'rowtime as 'w)
       .groupBy('w, 'string)
       .select('string, countFun('string), 'int.avg, weightAvgFun('long, 'int),
-              weightAvgFun('int, 'int), 'int.min, 'int.max, 'int.sum, 'w.start, 'w.end)
+        weightAvgFun('int, 'int), 'int.min, 'int.max, 'int.sum, 'w.start, 'w.end,
+        countDistinct('long))
 
     val results = windowedTable.toAppendStream[Row]
     results.addSink(new StreamITCase.StringSink[Row])
     env.execute()
 
     val expected = Seq(
-      "Hello world,1,3,8,3,3,3,3,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01",
-      "Hello world,1,3,16,3,3,3,3,1970-01-01 00:00:00.015,1970-01-01 00:00:00.02",
-      "Hello,2,2,3,2,2,2,4,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005",
-      "Hi,1,1,1,1,1,1,1,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005")
+      "Hello world,1,3,8,3,3,3,3,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01,1",
+      "Hello world,1,3,16,3,3,3,3,1970-01-01 00:00:00.015,1970-01-01 00:00:00.02,1",
+      "Hello,2,2,3,2,2,2,4,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005,2",
+      "Hi,1,1,1,1,1,1,1,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005,1")
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
 
@@ -208,17 +217,18 @@ class GroupWindowITCase extends StreamingMultipleProgramsTestBase {
     val table = stream.toTable(tEnv, 'long, 'int, 'string, 'int2, 'int3, 'proctime.proctime)
 
     val weightAvgFun = new WeightedAvg
+    val countDistinct = new CountDistinct
 
     val windowedTable = table
       .window(Slide over 2.rows every 1.rows on 'proctime as 'w)
       .groupBy('w, 'int2, 'int3, 'string)
-      .select(weightAvgFun('long, 'int))
+      .select(weightAvgFun('long, 'int), countDistinct('long))
 
     val results = windowedTable.toAppendStream[Row]
     results.addSink(new StreamITCase.StringSink[Row])
     env.execute()
 
-    val expected = Seq("12", "8", "2", "3", "1")
+    val expected = Seq("12,2", "8,1", "2,1", "3,2", "1,1")
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/OverWindowITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/OverWindowITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/OverWindowITCase.scala
index 73484d2..54971b2 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/OverWindowITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/OverWindowITCase.scala
@@ -25,7 +25,7 @@ import org.apache.flink.streaming.api.functions.source.SourceFunction.SourceCont
 import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
 import org.apache.flink.streaming.api.watermark.Watermark
 import org.apache.flink.table.api.TableEnvironment
-import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.WeightedAvg
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinct, CountDistinctWithRetractAndReset, WeightedAvg}
 import org.apache.flink.table.runtime.utils.JavaUserDefinedScalarFunctions.JavaFunc0
 import org.apache.flink.table.api.scala._
 import org.apache.flink.table.functions.aggfunctions.CountAggFunction
@@ -51,6 +51,7 @@ class OverWindowITCase extends StreamingWithStateTestBase {
       (6L, 6, "Hello"),
       (7L, 7, "Hello World"),
       (8L, 8, "Hello World"),
+      (8L, 8, "Hello World"),
       (20L, 20, "Hello World"))
 
     val env = StreamExecutionEnvironment.getExecutionEnvironment
@@ -62,20 +63,24 @@ class OverWindowITCase extends StreamingWithStateTestBase {
     val table = stream.toTable(tEnv, 'a, 'b, 'c, 'proctime.proctime)
     val countFun = new CountAggFunction
     val weightAvgFun = new WeightedAvg
+    val countDist = new CountDistinct
 
     val windowedTable = table
       .window(
         Over partitionBy 'c orderBy 'proctime preceding UNBOUNDED_ROW as 'w)
-      .select('c, countFun('b) over 'w as 'mycount, weightAvgFun('a, 'b) over 'w as 'wAvg)
-      .select('c, 'mycount, 'wAvg)
+      .select('c,
+        countFun('b) over 'w as 'mycount,
+        weightAvgFun('a, 'b) over 'w as 'wAvg,
+        countDist('a) over 'w as 'countDist)
+      .select('c, 'mycount, 'wAvg, 'countDist)
 
     val results = windowedTable.toAppendStream[Row]
     results.addSink(new StreamITCase.StringSink[Row])
     env.execute()
 
     val expected = Seq(
-      "Hello World,1,7", "Hello World,2,7", "Hello World,3,14",
-      "Hello,1,1", "Hello,2,1", "Hello,3,2", "Hello,4,3", "Hello,5,3", "Hello,6,4")
+      "Hello World,1,7,1", "Hello World,2,7,2", "Hello World,3,7,2", "Hello World,4,13,3",
+      "Hello,1,1,1", "Hello,2,1,2", "Hello,3,2,3", "Hello,4,3,4", "Hello,5,3,5", "Hello,6,4,6")
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
 
@@ -112,6 +117,7 @@ class OverWindowITCase extends StreamingWithStateTestBase {
     val countFun = new CountAggFunction
     val weightAvgFun = new WeightedAvg
     val plusOne = new JavaFunc0
+    val countDist = new CountDistinct
 
     val windowedTable = table
       .window(Over partitionBy 'a orderBy 'rowtime preceding UNBOUNDED_RANGE following
@@ -128,26 +134,27 @@ class OverWindowITCase extends StreamingWithStateTestBase {
         'b.max over 'w,
         'b.min over 'w,
         ('b.min over 'w).abs(),
-        weightAvgFun('b, 'a) over 'w)
+        weightAvgFun('b, 'a) over 'w,
+        countDist('c) over 'w as 'countDist)
 
     val result = windowedTable.toAppendStream[Row]
     result.addSink(new StreamITCase.StringSink[Row])
     env.execute()
 
     val expected = mutable.MutableList(
-      "1,1,Hello,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2",
-      "1,2,Hello,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2",
-      "1,3,Hello world,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2",
-      "1,1,Hi,7,SUM:7,4,5,5,[1, 3],1,3,1,1,1",
-      "2,1,Hello,1,SUM:1,1,2,2,[1, 1],1,1,1,1,1",
-      "2,2,Hello world,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2",
-      "2,3,Hello world,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2",
-      "1,4,Hello world,11,SUM:11,5,6,6,[2, 4],2,4,1,1,2",
-      "1,5,Hello world,29,SUM:29,8,9,9,[3, 7],3,7,1,1,3",
-      "1,6,Hello world,29,SUM:29,8,9,9,[3, 7],3,7,1,1,3",
-      "1,7,Hello world,29,SUM:29,8,9,9,[3, 7],3,7,1,1,3",
-      "2,4,Hello world,15,SUM:15,5,6,6,[3, 5],3,5,1,1,3",
-      "2,5,Hello world,15,SUM:15,5,6,6,[3, 5],3,5,1,1,3"
+      "1,1,Hello,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2,2",
+      "1,2,Hello,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2,2",
+      "1,3,Hello world,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2,2",
+      "1,1,Hi,7,SUM:7,4,5,5,[1, 3],1,3,1,1,1,3",
+      "2,1,Hello,1,SUM:1,1,2,2,[1, 1],1,1,1,1,1,1",
+      "2,2,Hello world,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2,2",
+      "2,3,Hello world,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2,2",
+      "1,4,Hello world,11,SUM:11,5,6,6,[2, 4],2,4,1,1,2,3",
+      "1,5,Hello world,29,SUM:29,8,9,9,[3, 7],3,7,1,1,3,3",
+      "1,6,Hello world,29,SUM:29,8,9,9,[3, 7],3,7,1,1,3,3",
+      "1,7,Hello world,29,SUM:29,8,9,9,[3, 7],3,7,1,1,3,3",
+      "2,4,Hello world,15,SUM:15,5,6,6,[3, 5],3,5,1,1,3,2",
+      "2,5,Hello world,15,SUM:15,5,6,6,[3, 5],3,5,1,1,3,2"
     )
 
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
@@ -179,32 +186,33 @@ class OverWindowITCase extends StreamingWithStateTestBase {
     env.setParallelism(1)
     StreamITCase.testResults = mutable.MutableList()
 
+    val countDist = new CountDistinctWithRetractAndReset
     val stream = env.fromCollection(data)
     val table = stream.toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime.proctime)
 
     val windowedTable = table
       .window(Over partitionBy 'a orderBy 'proctime preceding 4.rows following CURRENT_ROW as 'w)
-      .select('a, 'c.sum over 'w, 'c.min over 'w)
+      .select('a, 'c.sum over 'w, 'c.min over 'w, countDist('e) over 'w)
     val result = windowedTable.toAppendStream[Row]
     result.addSink(new StreamITCase.StringSink[Row])
     env.execute()
 
     val expected = mutable.MutableList(
-      "1,0,0",
-      "2,1,1",
-      "2,3,1",
-      "3,3,3",
-      "3,7,3",
-      "3,12,3",
-      "4,6,6",
-      "4,13,6",
-      "4,21,6",
-      "4,30,6",
-      "5,10,10",
-      "5,21,10",
-      "5,33,10",
-      "5,46,10",
-      "5,60,10")
+      "1,0,0,1",
+      "2,1,1,1",
+      "2,3,1,2",
+      "3,3,3,1",
+      "3,7,3,1",
+      "3,12,3,2",
+      "4,6,6,1",
+      "4,13,6,2",
+      "4,21,6,2",
+      "4,30,6,2",
+      "5,10,10,1",
+      "5,21,10,2",
+      "5,33,10,2",
+      "5,46,10,3",
+      "5,60,10,3")
 
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
@@ -240,25 +248,27 @@ class OverWindowITCase extends StreamingWithStateTestBase {
     val tEnv = TableEnvironment.getTableEnvironment(env)
     StreamITCase.clear
 
+    val countDist = new CountDistinctWithRetractAndReset
     val table = env.addSource[(Long, Int, String)](
       new RowTimeSourceFunction[(Long, Int, String)](data))
       .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)
 
     val windowedTable = table
       .window(Over partitionBy 'c orderBy 'rowtime preceding 2.rows following CURRENT_ROW as 'w)
-      .select('c, 'a, 'a.count over 'w, 'a.sum over 'w)
+      .select('c, 'a, 'a.count over 'w, 'a.sum over 'w, countDist('a) over 'w)
 
     val result = windowedTable.toAppendStream[Row]
     result.addSink(new StreamITCase.StringSink[Row])
     env.execute()
 
     val expected = mutable.MutableList(
-      "Hello,1,1,1", "Hello,1,2,2", "Hello,1,3,3",
-      "Hello,2,3,4", "Hello,2,3,5", "Hello,2,3,6",
-      "Hello,3,3,7", "Hello,4,3,9", "Hello,5,3,12",
-      "Hello,6,3,15",
-      "Hello World,7,1,7", "Hello World,7,2,14", "Hello World,7,3,21",
-      "Hello World,7,3,21", "Hello World,8,3,22", "Hello World,20,3,35")
+      "Hello,1,1,1,1", "Hello,1,2,2,1", "Hello,1,3,3,1",
+      "Hello,2,3,4,2", "Hello,2,3,5,2", "Hello,2,3,6,1",
+      "Hello,3,3,7,2", "Hello,4,3,9,3", "Hello,5,3,12,3",
+      "Hello,6,3,15,3",
+      "Hello World,7,1,7,1", "Hello World,7,2,14,1", "Hello World,7,3,21,1",
+      "Hello World,7,3,21,1", "Hello World,8,3,22,2", "Hello World,20,3,35,3")
+
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
 
@@ -302,6 +312,7 @@ class OverWindowITCase extends StreamingWithStateTestBase {
     val tEnv = TableEnvironment.getTableEnvironment(env)
     StreamITCase.clear
 
+    val countDist = new CountDistinctWithRetractAndReset
     val table = env.addSource[(Long, Int, String)](
       new RowTimeSourceFunction[(Long, Int, String)](data))
       .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)
@@ -309,23 +320,24 @@ class OverWindowITCase extends StreamingWithStateTestBase {
     val windowedTable = table
       .window(
         Over partitionBy 'c orderBy 'rowtime preceding 1.seconds following CURRENT_RANGE as 'w)
-      .select('c, 'b, 'a.count over 'w, 'a.sum over 'w)
+      .select('c, 'b, 'a.count over 'w, 'a.sum over 'w, countDist('a) over 'w)
 
     val result = windowedTable.toAppendStream[Row]
     result.addSink(new StreamITCase.StringSink[Row])
     env.execute()
 
     val expected = mutable.MutableList(
-      "Hello,1,1,1", "Hello,15,2,2", "Hello,16,3,3",
-      "Hello,2,6,9", "Hello,3,6,9", "Hello,2,6,9",
-      "Hello,3,4,9",
-      "Hello,4,2,7",
-      "Hello,5,2,9",
-      "Hello,6,2,11", "Hello,65,2,12",
-      "Hello,9,2,12", "Hello,9,2,12", "Hello,18,3,18",
-      "Hello World,7,1,7", "Hello World,17,3,21", "Hello World,77,3,21", "Hello World,18,1,7",
-      "Hello World,8,2,15",
-      "Hello World,20,1,20")
+      "Hello,1,1,1,1", "Hello,15,2,2,1", "Hello,16,3,3,1",
+      "Hello,2,6,9,2", "Hello,3,6,9,2", "Hello,2,6,9,2",
+      "Hello,3,4,9,2",
+      "Hello,4,2,7,2",
+      "Hello,5,2,9,2",
+      "Hello,6,2,11,2", "Hello,65,2,12,1",
+      "Hello,9,2,12,1", "Hello,9,2,12,1", "Hello,18,3,18,1",
+      "Hello World,7,1,7,1", "Hello World,17,3,21,1",
+      "Hello World,77,3,21,1", "Hello World,18,1,7,1",
+      "Hello World,8,2,15,2",
+      "Hello World,20,1,20,1")
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
 }