You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tw...@apache.org on 2017/07/13 10:18:38 UTC

[29/44] flink git commit: [FLINK-6617] [table] Restructuring of tests

http://git-wip-us.apache.org/repos/asf/flink/blob/f1fafc0e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramExtractorTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramExtractorTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramExtractorTest.scala
new file mode 100644
index 0000000..840be17
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramExtractorTest.scala
@@ -0,0 +1,524 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan
+
+import java.math.BigDecimal
+
+import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder}
+import org.apache.calcite.sql.SqlPostfixOperator
+import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, INTEGER, VARCHAR}
+import org.apache.calcite.sql.fun.SqlStdOperatorTable
+import org.apache.flink.table.expressions._
+import org.apache.flink.table.plan.util.RexProgramExtractor
+import org.apache.flink.table.utils.InputTypeBuilder.inputOf
+import org.apache.flink.table.validate.FunctionCatalog
+import org.hamcrest.CoreMatchers.is
+import org.junit.Assert.{assertArrayEquals, assertEquals, assertThat}
+import org.junit.Test
+
+import scala.collection.JavaConverters._
+
+/**
+  * Tests for [[RexProgramExtractor]]: extraction of referenced input fields, of nested field
+  * accesses, and of conjunctive filter conditions from a Calcite [[RexProgram]].
+  */
+class RexProgramExtractorTest extends RexProgramTestBase {
+
+  private val functionCatalog: FunctionCatalog = FunctionCatalog.withBuiltIns
+
+  @Test
+  def testExtractRefInputFields(): Unit = {
+    val usedFields = RexProgramExtractor.extractRefInputFields(buildSimpleRexProgram())
+    // JUnit convention: expected first, actual second (keeps failure messages accurate)
+    assertArrayEquals(Array(2, 3, 1), usedFields)
+  }
+
+  @Test
+  def testExtractSimpleCondition(): Unit = {
+    val program = buildSimpleRexProgram()
+
+    val firstExp = ExpressionParser.parseExpression("id > 6")
+    val secondExp = ExpressionParser.parseExpression("amount * price < 100")
+    val expected: Array[Expression] = Array(firstExp, secondExp)
+
+    val (convertedExpressions, unconvertedRexNodes) = extractConditions(program)
+
+    assertExpressionArrayEquals(expected, convertedExpressions)
+    assertEquals(0, unconvertedRexNodes.length)
+  }
+
+  @Test
+  def testExtractSingleCondition(): Unit = {
+    val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
+    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
+
+    // amount
+    val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
+    // id
+    val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
+
+    // a = amount >= id
+    val a = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, t0, t1))
+    builder.addCondition(a)
+
+    val (convertedExpressions, unconvertedRexNodes) = extractConditions(builder.getProgram)
+
+    val expected: Array[Expression] = Array(ExpressionParser.parseExpression("amount >= id"))
+    assertExpressionArrayEquals(expected, convertedExpressions)
+    assertEquals(0, unconvertedRexNodes.length)
+  }
+
+  // ((a AND b) OR c) AND (NOT d) => (a OR c) AND (b OR c) AND (NOT d)
+  @Test
+  def testExtractCnfCondition(): Unit = {
+    val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
+    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
+
+    // amount
+    val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
+    // id
+    val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
+    // price
+    val t2 = rexBuilder.makeInputRef(allFieldTypes.get(3), 3)
+    // 100
+    val t3 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+
+    // a = amount < 100
+    val a = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t0, t3))
+    // b = id > 100
+    val b = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t3))
+    // c = price == 100
+    val c = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, t2, t3))
+    // d = amount <= id
+    val d = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, t0, t1))
+
+    // a AND b
+    val and = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(a, b).asJava))
+    // (a AND b) OR c
+    val or = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.OR, List(and, c).asJava))
+    // NOT d
+    val not = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.NOT, List(d).asJava))
+
+    // ((a AND b) OR c) AND (NOT d)
+    builder.addCondition(builder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.AND, List(or, not).asJava)))
+
+    val (convertedExpressions, unconvertedRexNodes) = extractConditions(builder.getProgram)
+
+    val expected: Array[Expression] = Array(
+      ExpressionParser.parseExpression("amount < 100 || price == 100"),
+      ExpressionParser.parseExpression("id > 100 || price == 100"),
+      ExpressionParser.parseExpression("!(amount <= id)"))
+    assertExpressionArrayEquals(expected, convertedExpressions)
+    assertEquals(0, unconvertedRexNodes.length)
+  }
+
+  @Test
+  def testExtractArithmeticConditions(): Unit = {
+    val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
+    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
+
+    // amount
+    val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
+    // id
+    val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
+    // 100
+    val t2 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+
+    val condition = List(
+      // amount < id
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t0, t1)),
+      // amount <= id
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, t0, t1)),
+      // amount <> id
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.NOT_EQUALS, t0, t1)),
+      // amount == id
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, t0, t1)),
+      // amount >= id
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, t0, t1)),
+      // amount > id
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t0, t1)),
+      // amount + id == 100
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
+        rexBuilder.makeCall(SqlStdOperatorTable.PLUS, t0, t1), t2)),
+      // amount - id == 100
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
+        rexBuilder.makeCall(SqlStdOperatorTable.MINUS, t0, t1), t2)),
+      // amount * id == 100
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
+        rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, t0, t1), t2)),
+      // amount / id == 100
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
+        rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, t0, t1), t2)),
+      // -amount == 100
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
+        rexBuilder.makeCall(SqlStdOperatorTable.UNARY_MINUS, t0), t2))
+    ).asJava
+
+    builder.addCondition(builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, condition)))
+    val (convertedExpressions, unconvertedRexNodes) = extractConditions(builder.getProgram)
+
+    val expected: Array[Expression] = Array(
+      ExpressionParser.parseExpression("amount < id"),
+      ExpressionParser.parseExpression("amount <= id"),
+      ExpressionParser.parseExpression("amount <> id"),
+      ExpressionParser.parseExpression("amount == id"),
+      ExpressionParser.parseExpression("amount >= id"),
+      ExpressionParser.parseExpression("amount > id"),
+      ExpressionParser.parseExpression("amount + id == 100"),
+      ExpressionParser.parseExpression("amount - id == 100"),
+      ExpressionParser.parseExpression("amount * id == 100"),
+      ExpressionParser.parseExpression("amount / id == 100"),
+      ExpressionParser.parseExpression("-amount == 100")
+    )
+    assertExpressionArrayEquals(expected, convertedExpressions)
+    assertEquals(0, unconvertedRexNodes.length)
+  }
+
+  @Test
+  def testExtractPostfixConditions(): Unit = {
+    testExtractSinglePostfixCondition(4, SqlStdOperatorTable.IS_NULL, "('flag).isNull")
+    // IS_NOT_NULL will be eliminated since flag is not nullable
+    // testExtractSinglePostfixCondition(4, SqlStdOperatorTable.IS_NOT_NULL, "('flag).isNotNull")
+    testExtractSinglePostfixCondition(4, SqlStdOperatorTable.IS_TRUE, "('flag).isTrue")
+    testExtractSinglePostfixCondition(4, SqlStdOperatorTable.IS_NOT_TRUE, "('flag).isNotTrue")
+    testExtractSinglePostfixCondition(4, SqlStdOperatorTable.IS_FALSE, "('flag).isFalse")
+    testExtractSinglePostfixCondition(4, SqlStdOperatorTable.IS_NOT_FALSE, "('flag).isNotFalse")
+  }
+
+  @Test
+  def testExtractConditionWithFunctionCalls(): Unit = {
+    val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
+    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
+
+    // amount
+    val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
+    // id
+    val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
+    // 100
+    val t2 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+
+    // sum(amount) > 100
+    val condition1 = builder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN,
+        rexBuilder.makeCall(SqlStdOperatorTable.SUM, t0), t2))
+
+    // min(id) == 100
+    val condition2 = builder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
+        rexBuilder.makeCall(SqlStdOperatorTable.MIN, t1), t2))
+
+    builder.addCondition(builder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.AND, condition1, condition2)))
+
+    val (convertedExpressions, unconvertedRexNodes) = extractConditions(builder.getProgram)
+
+    val expected: Array[Expression] = Array(
+      GreaterThan(Sum(UnresolvedFieldReference("amount")), Literal(100)),
+      EqualTo(Min(UnresolvedFieldReference("id")), Literal(100))
+    )
+    assertExpressionArrayEquals(expected, convertedExpressions)
+    assertEquals(0, unconvertedRexNodes.length)
+  }
+
+  @Test
+  def testExtractWithUnsupportedConditions(): Unit = {
+    val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
+    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
+
+    // amount
+    val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
+    // id
+    val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
+    // 100
+    val t2 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+
+    // unsupported now: CAST(amount AS BIGINT) (the type of field `id`)
+    val cast = builder.addExpr(rexBuilder.makeCast(allFieldTypes.get(1), t0))
+
+    // unsupported now: CAST(amount AS BIGINT) > 100
+    val condition1 = builder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, cast, t2))
+
+    // amount <= id
+    val condition2 = builder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, t0, t1))
+
+    // contains unsupported condition: (CAST(amount AS BIGINT) > 100 OR amount <= id)
+    val condition3 = builder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.OR, condition1, condition2))
+
+    // only condition2 can be translated
+    builder.addCondition(
+      rexBuilder.makeCall(SqlStdOperatorTable.AND, condition1, condition2, condition3))
+
+    val (convertedExpressions, unconvertedRexNodes) = extractConditions(builder.getProgram)
+
+    val expected: Array[Expression] = Array(
+      ExpressionParser.parseExpression("amount <= id")
+    )
+    assertExpressionArrayEquals(expected, convertedExpressions)
+    assertEquals(2, unconvertedRexNodes.length)
+    assertEquals(">(CAST($2):BIGINT NOT NULL, 100)", unconvertedRexNodes(0).toString)
+    assertEquals("OR(>(CAST($2):BIGINT NOT NULL, 100), <=($2, $1))",
+      unconvertedRexNodes(1).toString)
+  }
+
+  @Test
+  def testExtractRefNestedInputFields(): Unit = {
+    val rexProgram = buildRexProgramWithNesting()
+
+    val usedFields = RexProgramExtractor.extractRefInputFields(rexProgram)
+    val usedNestedFields = RexProgramExtractor.extractRefNestedInputFields(rexProgram, usedFields)
+
+    val expected = Array(Array("amount"), Array("*"))
+    assertThat(usedNestedFields, is(expected))
+  }
+
+  @Test
+  def testExtractRefNestedInputFieldsWithNoNesting(): Unit = {
+    val rexProgram = buildSimpleRexProgram()
+
+    val usedFields = RexProgramExtractor.extractRefInputFields(rexProgram)
+    val usedNestedFields = RexProgramExtractor.extractRefNestedInputFields(rexProgram, usedFields)
+
+    val expected = Array(Array("*"), Array("*"), Array("*"))
+    assertThat(usedNestedFields, is(expected))
+  }
+
+  @Test
+  def testExtractDeepRefNestedInputFields(): Unit = {
+    val rexProgram = buildRexProgramWithDeepNesting()
+
+    val usedFields = RexProgramExtractor.extractRefInputFields(rexProgram)
+    val usedNestedFields = RexProgramExtractor.extractRefNestedInputFields(rexProgram, usedFields)
+
+    val expected = Array(
+      Array("amount"),
+      Array("*"),
+      Array("with.deeper.entry", "with.deep.entry"))
+
+    assertThat(usedFields, is(Array(1, 0, 2)))
+    assertThat(usedNestedFields, is(expected))
+  }
+
+  private def buildRexProgramWithDeepNesting(): RexProgram = {
+
+    // person input
+    val passportRow = inputOf(typeFactory)
+      .field("id", VARCHAR)
+      .field("status", VARCHAR)
+      .build
+
+    val personRow = inputOf(typeFactory)
+      .field("name", VARCHAR)
+      .field("age", INTEGER)
+      .nestedField("passport", passportRow)
+      .build
+
+    // payment input
+    val paymentRow = inputOf(typeFactory)
+      .field("id", BIGINT)
+      .field("amount", INTEGER)
+      .build
+
+    // deep field input
+    val deepRowType = inputOf(typeFactory)
+      .field("entry", VARCHAR)
+      .build
+
+    val entryRowType = inputOf(typeFactory)
+      .nestedField("inside", deepRowType)
+      .build
+
+    val deeperRowType = inputOf(typeFactory)
+      .nestedField("entry", entryRowType)
+      .build
+
+    val withRowType = inputOf(typeFactory)
+      .nestedField("deep", deepRowType)
+      .nestedField("deeper", deeperRowType)
+      .build
+
+    val fieldRowType = inputOf(typeFactory)
+      .nestedField("with", withRowType)
+      .build
+
+    // main input
+    val inputRowType = inputOf(typeFactory)
+      .nestedField("persons", personRow)
+      .nestedField("payments", paymentRow)
+      .nestedField("field", fieldRowType)
+      .build
+
+    // inputRowType
+    //
+    // [ persons:  [ name: VARCHAR, age:  INT, passport: [id: VARCHAR, status: VARCHAR ] ],
+    //   payments: [ id: BIGINT, amount: INT ],
+    //   field:    [ with: [ deep: [ entry: VARCHAR ],
+    //                       deeper: [ entry: [ inside: [entry: VARCHAR ] ] ]
+    //             ] ]
+    // ]
+
+    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
+
+    val t0 = rexBuilder.makeInputRef(personRow, 0)
+    val t1 = rexBuilder.makeInputRef(paymentRow, 1)
+    val t2 = rexBuilder.makeInputRef(fieldRowType, 2)
+    val t3 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(10L))
+
+    // person
+    val person$pass = rexBuilder.makeFieldAccess(t0, "passport", false)
+    val person$pass$stat = rexBuilder.makeFieldAccess(person$pass, "status", false)
+
+    // payment
+    val pay$amount = rexBuilder.makeFieldAccess(t1, "amount", false)
+    val multiplyAmount = builder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, pay$amount, t3))
+
+    // field
+    val field$with = rexBuilder.makeFieldAccess(t2, "with", false)
+    val field$with$deep = rexBuilder.makeFieldAccess(field$with, "deep", false)
+    val field$with$deeper = rexBuilder.makeFieldAccess(field$with, "deeper", false)
+    val field$with$deep$entry = rexBuilder.makeFieldAccess(field$with$deep, "entry", false)
+    val field$with$deeper$entry = rexBuilder.makeFieldAccess(field$with$deeper, "entry", false)
+    val field$with$deeper$entry$inside = rexBuilder
+      .makeFieldAccess(field$with$deeper$entry, "inside", false)
+    val field$with$deeper$entry$inside$entry = rexBuilder
+      .makeFieldAccess(field$with$deeper$entry$inside, "entry", false)
+
+    builder.addProject(multiplyAmount, "amount")
+    builder.addProject(person$pass$stat, "status")
+    builder.addProject(field$with$deep$entry, "entry")
+    builder.addProject(field$with$deeper$entry$inside$entry, "entry")
+    builder.addProject(field$with$deeper$entry, "entry2")
+    builder.addProject(t0, "person")
+
+    // Program
+    // (
+    //   payments.amount * 10,
+    //   persons.passport.status,
+    //   field.with.deep.entry,
+    //   field.with.deeper.entry.inside.entry,
+    //   field.with.deeper.entry,
+    //   persons
+    // )
+
+    builder.getProgram
+
+  }
+
+  private def buildRexProgramWithNesting(): RexProgram = {
+
+    val personRow = inputOf(typeFactory)
+      .field("name", INTEGER)
+      .field("age", VARCHAR)
+      .build
+
+    val paymentRow = inputOf(typeFactory)
+      .field("id", BIGINT)
+      .field("amount", INTEGER)
+      .build
+
+    val types = List(personRow, paymentRow).asJava
+    val names = List("persons", "payments").asJava
+    val inputRowType = typeFactory.createStructType(types, names)
+
+    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
+
+    val t0 = rexBuilder.makeInputRef(types.get(0), 0)
+    val t1 = rexBuilder.makeInputRef(types.get(1), 1)
+    val t2 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+
+    val payment$amount = rexBuilder.makeFieldAccess(t1, "amount", false)
+
+    builder.addProject(payment$amount, "amount")
+    builder.addProject(t0, "persons")
+    builder.addProject(t2, "number")
+    builder.getProgram
+  }
+
+  /**
+    * Builds a program whose only condition is `op` applied to the input field at `fieldIndex`
+    * and asserts that it converts into exactly one expression whose string form is `expr`.
+    */
+  private def testExtractSinglePostfixCondition(
+      fieldIndex: Int,
+      op: SqlPostfixOperator,
+      expr: String) : Unit = {
+
+    val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
+    // Uses the shared fixture builder as-is; it is no longer reassigned here
+    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
+
+    // flag
+    val t0 = rexBuilder.makeInputRef(allFieldTypes.get(fieldIndex), fieldIndex)
+    builder.addCondition(builder.addExpr(rexBuilder.makeCall(op, t0)))
+
+    // getProgram(false): build the program without normalizing it
+    val (convertedExpressions, unconvertedRexNodes) = extractConditions(builder.getProgram(false))
+
+    assertEquals(1, convertedExpressions.length)
+    assertEquals(expr, convertedExpressions.head.toString)
+    assertEquals(0, unconvertedRexNodes.length)
+  }
+
+  /**
+    * Runs [[RexProgramExtractor.extractConjunctiveConditions]] on `program` with a fresh
+    * [[RexBuilder]] and the built-in function catalog.
+    *
+    * @return the (converted expressions, unconverted RexNodes) pair from the extractor
+    */
+  private def extractConditions(program: RexProgram) =
+    RexProgramExtractor.extractConjunctiveConditions(
+      program,
+      new RexBuilder(typeFactory),
+      functionCatalog)
+
+  /**
+    * Asserts that both arrays contain the same expressions, ignoring order, by comparing
+    * their string forms after sorting.
+    */
+  private def assertExpressionArrayEquals(
+      expected: Array[Expression],
+      actual: Array[Expression]): Unit = {
+    val sortedExpected = expected.sortBy(_.toString)
+    val sortedActual = actual.sortBy(_.toString)
+
+    assertEquals(sortedExpected.length, sortedActual.length)
+    sortedExpected.zip(sortedActual).foreach {
+      case (l, r) => assertEquals(l.toString, r.toString)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/f1fafc0e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramRewriterTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramRewriterTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramRewriterTest.scala
new file mode 100644
index 0000000..dc91a82
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramRewriterTest.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan
+
+import org.apache.flink.table.plan.util.RexProgramRewriter
+import org.junit.Assert.assertTrue
+import org.junit.Test
+
+import scala.collection.JavaConverters._
+
+/**
+  * Tests for [[RexProgramRewriter]].
+  */
+class RexProgramRewriterTest extends RexProgramTestBase {
+
+  @Test
+  def testRewriteRexProgram(): Unit = {
+    // local import: assertEquals reports both lists on failure, unlike assertTrue(a == b)
+    import org.junit.Assert.assertEquals
+
+    val rexProgram = buildSimpleRexProgram()
+    assertEquals(
+      List(
+        "$0",
+        "$1",
+        "$2",
+        "$3",
+        "$4",
+        "*($t2, $t3)",
+        "100",
+        "<($t5, $t6)",
+        "6",
+        ">($t1, $t8)",
+        "AND($t7, $t9)"),
+      extractExprStrList(rexProgram).toList)
+
+    // use amount, id, price fields to create a new RexProgram
+    val usedFields = Array(2, 3, 1)
+    val types = usedFields.map(allFieldTypes.get).toList.asJava
+    val names = usedFields.map(allFieldNames.get).toList.asJava
+    val inputRowType = typeFactory.createStructType(types, names)
+    val newRexProgram = RexProgramRewriter.rewriteWithFieldProjection(
+      rexProgram, inputRowType, rexBuilder, usedFields)
+    assertEquals(
+      List(
+        "$0",
+        "$1",
+        "$2",
+        "*($t0, $t1)",
+        "100",
+        "<($t3, $t4)",
+        "6",
+        ">($t2, $t6)",
+        "AND($t5, $t7)"),
+      extractExprStrList(newRexProgram).toList)
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/f1fafc0e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramTestBase.scala
new file mode 100644
index 0000000..b711604
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramTestBase.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan
+
+import java.math.BigDecimal
+
+import org.apache.calcite.adapter.java.JavaTypeFactory
+import org.apache.calcite.jdbc.JavaTypeFactoryImpl
+import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeSystem}
+import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder}
+import org.apache.calcite.sql.`type`.SqlTypeName._
+import org.apache.calcite.sql.fun.SqlStdOperatorTable
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+/**
+  * Shared fixtures for tests that build Calcite [[RexProgram]]s over a five-column input row
+  * (name VARCHAR, id BIGINT, amount INTEGER, price DOUBLE, flag BOOLEAN).
+  */
+abstract class RexProgramTestBase {
+
+  val typeFactory: JavaTypeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT)
+
+  val allFieldNames: java.util.List[String] = List("name", "id", "amount", "price", "flag").asJava
+
+  val allFieldTypes: java.util.List[RelDataType] = {
+    val sqlTypes = List(VARCHAR, BIGINT, INTEGER, DOUBLE, BOOLEAN)
+    sqlTypes.map(sqlType => typeFactory.createSqlType(sqlType)).asJava
+  }
+
+  // NOTE(review): kept as a `var` because subclasses may reassign it
+  var rexBuilder: RexBuilder = new RexBuilder(typeFactory)
+
+  /**
+    * Renders every entry of the program's expression list via `toString`, in list order.
+    *
+    * @param rexProgram the program whose expression list is rendered
+    * @return one string per expression, in the program's own order
+    */
+  protected def extractExprStrList(rexProgram: RexProgram): mutable.Buffer[String] = {
+    val exprs = rexProgram.getExprList.asScala
+    exprs.map(expr => expr.toString)
+  }
+
+  // select amount, amount * price as total where amount * price < 100 and id > 6
+  protected def buildSimpleRexProgram(): RexProgram = {
+    def fieldRef(index: Int) = rexBuilder.makeInputRef(allFieldTypes.get(index), index)
+    def literal(value: Long) = rexBuilder.makeExactLiteral(BigDecimal.valueOf(value))
+
+    val programBuilder = new RexProgramBuilder(
+      typeFactory.createStructType(allFieldTypes, allFieldNames), rexBuilder)
+
+    val amount = fieldRef(2)
+    val id = fieldRef(1)
+    val price = fieldRef(3)
+
+    // The order of addExpr/addProject/addCondition below fixes the $tN numbering
+    // of the resulting expression list; tests assert on that numbering.
+    val total = programBuilder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, amount, price))
+
+    // project: amount, amount * price as total
+    programBuilder.addProject(amount, "amount")
+    programBuilder.addProject(total, "total")
+
+    // condition: amount * price < 100 and id > 6
+    val totalBelowLimit = programBuilder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, total, literal(100L)))
+    val idAboveLimit = programBuilder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, id, literal(6L)))
+    val bothHold = programBuilder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.AND, List(totalBelowLimit, idAboveLimit).asJava))
+    programBuilder.addCondition(bothHold)
+
+    programBuilder.getProgram
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/f1fafc0e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala
new file mode 100644
index 0000000..7885160
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala
@@ -0,0 +1,398 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan
+
+import java.sql.Timestamp
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.expressions.{TimeIntervalUnit, WindowReference}
+import org.apache.flink.table.functions.TableFunction
+import org.apache.flink.table.plan.TimeIndicatorConversionTest.TableFunc
+import org.apache.flink.table.plan.logical.TumblingGroupWindow
+import org.apache.flink.table.utils.TableTestBase
+import org.apache.flink.table.utils.TableTestUtil._
+import org.junit.Test
+
+/**
+  * Tests for [[org.apache.flink.table.calcite.RelTimeIndicatorConverter]].
+  */
+class TimeIndicatorConversionTest extends TableTestBase {
+
+  /** FLOOR on a rowtime attribute is applied to the materialized timestamp. */
+  @Test
+  def testSimpleMaterialization(): Unit = {
+    val testUtil = streamTestUtil()
+    val table = testUtil.addTable[(Long, Long, Int)](
+      'rowtime.rowtime, 'long, 'int, 'proctime.proctime)
+
+    val resultTable = table
+      .select('rowtime.floor(TimeIntervalUnit.DAY) as 'rowtime, 'long)
+      .filter('long > 0)
+      .select('rowtime)
+
+    val expectedPlan = unaryNode(
+      "DataStreamCalc",
+      streamTableNode(0),
+      term("select", "FLOOR(TIME_MATERIALIZATION(rowtime)", "FLAG(DAY)) AS rowtime"),
+      term("where", ">(long, 0)")
+    )
+
+    testUtil.verifyTable(resultTable, expectedPlan)
+  }
+
+  /** Selecting all fields materializes both the rowtime and the proctime attribute. */
+  @Test
+  def testSelectAll(): Unit = {
+    val testUtil = streamTestUtil()
+    val table = testUtil.addTable[(Long, Long, Int)](
+      'rowtime.rowtime, 'long, 'int, 'proctime.proctime)
+
+    val resultTable = table.select('*)
+
+    val expectedPlan = unaryNode(
+      "DataStreamCalc",
+      streamTableNode(0),
+      term("select", "TIME_MATERIALIZATION(rowtime) AS rowtime", "long", "int",
+        "TIME_MATERIALIZATION(proctime) AS proctime")
+    )
+
+    testUtil.verifyTable(resultTable, expectedPlan)
+  }
+
+  /** A comparison on rowtime materializes it in both the filter and the projection. */
+  @Test
+  def testFilteringOnRowtime(): Unit = {
+    val testUtil = streamTestUtil()
+    val table = testUtil.addTable[(Long, Long, Int)]('rowtime.rowtime, 'long, 'int)
+
+    val resultTable = table
+      .filter('rowtime > "1990-12-02 12:11:11".toTimestamp)
+      .select('rowtime)
+
+    val expectedPlan = unaryNode(
+      "DataStreamCalc",
+      streamTableNode(0),
+      term("select", "TIME_MATERIALIZATION(rowtime) AS rowtime"),
+      term("where", ">(TIME_MATERIALIZATION(rowtime), 1990-12-02 12:11:11)")
+    )
+
+    testUtil.verifyTable(resultTable, expectedPlan)
+  }
+
+  /** Grouping by rowtime materializes it in a Calc below the non-windowed aggregate. */
+  @Test
+  def testGroupingOnRowtime(): Unit = {
+    val testUtil = streamTestUtil()
+    val table = testUtil.addTable[(Long, Long, Int)](
+      'rowtime.rowtime, 'long, 'int, 'proctime.proctime)
+
+    val resultTable = table
+      .groupBy('rowtime)
+      .select('long.count)
+
+    val expectedPlan = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamGroupAggregate",
+        unaryNode(
+          "DataStreamCalc",
+          streamTableNode(0),
+          term("select", "long", "TIME_MATERIALIZATION(rowtime) AS rowtime")
+        ),
+        term("groupBy", "rowtime"),
+        term("select", "rowtime", "COUNT(long) AS TMP_0")
+      ),
+      term("select", "TMP_0")
+    )
+
+    testUtil.verifyTable(resultTable, expectedPlan)
+  }
+
+  /** MIN over rowtime in a non-windowed aggregate reads a materialized rowtime. */
+  @Test
+  def testAggregationOnRowtime(): Unit = {
+    val testUtil = streamTestUtil()
+    val table = testUtil.addTable[(Long, Long, Int)]('rowtime.rowtime, 'long, 'int)
+
+    val resultTable = table
+      .groupBy('long)
+      .select('rowtime.min)
+
+    val expectedPlan = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamGroupAggregate",
+        unaryNode(
+          "DataStreamCalc",
+          streamTableNode(0),
+          term("select", "TIME_MATERIALIZATION(rowtime) AS rowtime", "long")
+        ),
+        term("groupBy", "long"),
+        term("select", "long", "MIN(rowtime) AS TMP_0")
+      ),
+      term("select", "TMP_0")
+    )
+
+    testUtil.verifyTable(resultTable, expectedPlan)
+  }
+
+  /**
+    * Time attributes passed to a table function are materialized in the invocation, and again
+    * in the projection on top of the correlate.
+    */
+  @Test
+  def testTableFunction(): Unit = {
+    val testUtil = streamTestUtil()
+    val table = testUtil.addTable[(Long, Long, Int)](
+      'rowtime.rowtime, 'long, 'int, 'proctime.proctime)
+    val tableFunc = new TableFunc
+
+    val resultTable = table
+      .join(tableFunc('rowtime, 'proctime, "") as 's)
+      .select('rowtime, 'proctime, 's)
+
+    val expectedPlan = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamCorrelate",
+        streamTableNode(0),
+        term("invocation",
+          s"${tableFunc.functionIdentifier}(" +
+            "TIME_MATERIALIZATION($0), TIME_MATERIALIZATION($3), '')"),
+        term("function", tableFunc),
+        term("rowType", "RecordType(TIMESTAMP(3) rowtime, BIGINT long, INTEGER int, " +
+          "TIMESTAMP(3) proctime, VARCHAR(2147483647) s)"),
+        term("joinType", "INNER")
+      ),
+      term("select",
+        "TIME_MATERIALIZATION(rowtime) AS rowtime",
+        "TIME_MATERIALIZATION(proctime) AS proctime",
+        "s")
+    )
+
+    testUtil.verifyTable(resultTable, expectedPlan)
+  }
+
+  /** A tumbling window on rowtime uses the attribute without any materialization. */
+  @Test
+  def testWindow(): Unit = {
+    val testUtil = streamTestUtil()
+    val table = testUtil.addTable[(Long, Long, Int)]('rowtime.rowtime, 'long, 'int)
+
+    val resultTable = table
+      .window(Tumble over 100.millis on 'rowtime as 'w)
+      .groupBy('w, 'long)
+      .select('w.end as 'rowtime, 'long, 'int.sum)
+
+    val expectedPlan = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamGroupWindowAggregate",
+        streamTableNode(0),
+        term("groupBy", "long"),
+        term("window", TumblingGroupWindow('w, 'rowtime, 100.millis)),
+        term("select", "long", "SUM(int) AS TMP_1", "end('w) AS TMP_0")
+      ),
+      term("select", "TMP_0 AS rowtime", "long", "TMP_1")
+    )
+
+    testUtil.verifyTable(resultTable, expectedPlan)
+  }
+
+  /** UNION ALL forwards rowtime; it is only materialized in the Calc above the union. */
+  @Test
+  def testUnion(): Unit = {
+    val testUtil = streamTestUtil()
+    val table = testUtil.addTable[(Long, Long, Int)]("MyTable", 'rowtime.rowtime, 'long, 'int)
+
+    val resultTable = table.unionAll(table).select('rowtime)
+
+    val expectedPlan = unaryNode(
+      "DataStreamCalc",
+      binaryNode(
+        "DataStreamUnion",
+        streamTableNode(0),
+        streamTableNode(0),
+        term("union all", "rowtime", "long", "int")
+      ),
+      term("select", "TIME_MATERIALIZATION(rowtime) AS rowtime")
+    )
+
+    testUtil.verifyTable(resultTable, expectedPlan)
+  }
+
+  /** The rowtime property of one tumbling window drives a second tumbling window. */
+  @Test
+  def testMultiWindow(): Unit = {
+    val testUtil = streamTestUtil()
+    val table = testUtil.addTable[(Long, Long, Int)]('rowtime.rowtime, 'long, 'int)
+
+    val resultTable = table
+      .window(Tumble over 100.millis on 'rowtime as 'w)
+      .groupBy('w, 'long)
+      .select('w.rowtime as 'newrowtime, 'long, 'int.sum as 'int)
+      .window(Tumble over 1.second on 'newrowtime as 'w2)
+      .groupBy('w2, 'long)
+      .select('w2.end, 'long, 'int.sum)
+
+    val innerWindow = unaryNode(
+      "DataStreamGroupWindowAggregate",
+      streamTableNode(0),
+      term("groupBy", "long"),
+      term("window", TumblingGroupWindow('w, 'rowtime, 100.millis)),
+      term("select", "long", "SUM(int) AS TMP_1", "rowtime('w) AS TMP_0")
+    )
+
+    val expectedPlan = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamGroupWindowAggregate",
+        unaryNode(
+          "DataStreamCalc",
+          innerWindow,
+          term("select", "TMP_0 AS newrowtime", "long", "TMP_1 AS int")
+        ),
+        term("groupBy", "long"),
+        term("window", TumblingGroupWindow('w2, 'newrowtime, 1000.millis)),
+        term("select", "long", "SUM(int) AS TMP_3", "end('w2) AS TMP_2")
+      ),
+      term("select", "TMP_2", "long", "TMP_3")
+    )
+
+    testUtil.verifyTable(resultTable, expectedPlan)
+  }
+
+  /** SQL: grouping by proctime materializes it below the non-windowed aggregate. */
+  @Test
+  def testGroupingOnProctime(): Unit = {
+    val testUtil = streamTestUtil()
+    testUtil.addTable[(Long, Int)]("MyTable", 'long, 'int, 'proctime.proctime)
+
+    val sqlQuery = "SELECT COUNT(long) FROM MyTable GROUP BY proctime"
+    val resultTable = testUtil.tableEnv.sql(sqlQuery)
+
+    val expectedPlan = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamGroupAggregate",
+        unaryNode(
+          "DataStreamCalc",
+          streamTableNode(0),
+          term("select", "TIME_MATERIALIZATION(proctime) AS proctime", "long")
+        ),
+        term("groupBy", "proctime"),
+        term("select", "proctime", "COUNT(long) AS EXPR$0")
+      ),
+      term("select", "EXPR$0")
+    )
+
+    testUtil.verifyTable(resultTable, expectedPlan)
+  }
+
+  /** SQL: MIN over proctime in a non-windowed aggregate reads a materialized proctime. */
+  @Test
+  def testAggregationOnProctime(): Unit = {
+    val testUtil = streamTestUtil()
+    testUtil.addTable[(Long, Int)]("MyTable", 'long, 'int, 'proctime.proctime)
+
+    val sqlQuery = "SELECT MIN(proctime) FROM MyTable GROUP BY long"
+    val resultTable = testUtil.tableEnv.sql(sqlQuery)
+
+    val expectedPlan = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamGroupAggregate",
+        unaryNode(
+          "DataStreamCalc",
+          streamTableNode(0),
+          term("select", "long", "TIME_MATERIALIZATION(proctime) AS proctime")
+        ),
+        term("groupBy", "long"),
+        term("select", "long", "MIN(proctime) AS EXPR$0")
+      ),
+      term("select", "EXPR$0")
+    )
+
+    testUtil.verifyTable(resultTable, expectedPlan)
+  }
+
+  /** SQL: TUMBLE / TUMBLE_END on rowtime is planned without materializing rowtime. */
+  @Test
+  def testWindowSql(): Unit = {
+    val testUtil = streamTestUtil()
+    testUtil.addTable[(Long, Long, Int)]("MyTable", 'rowtime.rowtime, 'long, 'int)
+
+    val sqlQuery =
+      "SELECT TUMBLE_END(rowtime, INTERVAL '0.1' SECOND) AS `rowtime`, `long`, " +
+        "SUM(`int`) FROM MyTable " +
+        "GROUP BY `long`, TUMBLE(rowtime, INTERVAL '0.1' SECOND)"
+    val resultTable = testUtil.tableEnv.sql(sqlQuery)
+
+    val expectedPlan = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamGroupWindowAggregate",
+        streamTableNode(0),
+        term("groupBy", "long"),
+        term("window", TumblingGroupWindow(WindowReference("w$"), 'rowtime, 100.millis)),
+        term("select",
+          "long",
+          "SUM(int) AS EXPR$2",
+          "start('w$) AS w$start",
+          "end('w$) AS w$end")
+      ),
+      term("select", "w$end", "long", "EXPR$2")
+    )
+
+    testUtil.verifyTable(resultTable, expectedPlan)
+  }
+
+  /**
+    * SQL: MIN(rowtime) inside a tumbling window. The window keeps the raw rowtime, while the
+    * aggregate argument is a materialized copy ($f2).
+    */
+  @Test
+  def testWindowWithAggregationOnRowtime(): Unit = {
+    val testUtil = streamTestUtil()
+    testUtil.addTable[(Long, Long, Int)]("MyTable", 'rowtime.rowtime, 'long, 'int)
+
+    val sqlQuery = "SELECT MIN(rowtime), long FROM MyTable " +
+      "GROUP BY long, TUMBLE(rowtime, INTERVAL '0.1' SECOND)"
+    val resultTable = testUtil.tableEnv.sql(sqlQuery)
+
+    val expectedPlan = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamGroupWindowAggregate",
+        unaryNode(
+          "DataStreamCalc",
+          streamTableNode(0),
+          term("select", "long", "rowtime", "TIME_MATERIALIZATION(rowtime) AS $f2")
+        ),
+        term("groupBy", "long"),
+        term("window", TumblingGroupWindow('w$, 'rowtime, 100.millis)),
+        term("select", "long", "MIN($f2) AS EXPR$0")
+      ),
+      term("select", "EXPR$0", "long")
+    )
+
+    testUtil.verifyTable(resultTable, expectedPlan)
+  }
+
+}
+
+object TimeIndicatorConversionTest {
+
+  /**
+    * Table function used by `testTableFunction`. Each call emits one string built from
+    * `time1`, whether `time2` lies after the epoch timestamp `t`, and `string`.
+    */
+  class TableFunc extends TableFunction[String] {
+    val t = new Timestamp(0L)
+
+    def eval(time1: Long, time2: Timestamp, string: String): Unit = {
+      val afterEpoch = time2.after(t)
+      collect(s"$time1$afterEpoch$string")
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/f1fafc0e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/NormalizationRulesTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/NormalizationRulesTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/NormalizationRulesTest.scala
deleted file mode 100644
index c307de5..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/NormalizationRulesTest.scala
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.plan.rules
-
-import org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule
-import org.apache.calcite.tools.RuleSets
-import org.apache.flink.api.scala._
-import org.apache.flink.table.api.scala._
-import org.apache.flink.table.calcite.{CalciteConfig, CalciteConfigBuilder}
-import org.apache.flink.table.utils.TableTestBase
-import org.apache.flink.table.utils.TableTestUtil._
-import org.junit.Test
-
-/**
-  * Tests that a normalization rule set installed through [[CalciteConfig]] is applied to
-  * batch and stream SQL queries.
-  */
-class NormalizationRulesTest extends TableTestBase {
-
-  /** Query whose COUNT(DISTINCT ...) is rewritten by the distinct-aggregate rule. */
-  private val countDistinctQuery = "SELECT COUNT(DISTINCT a) FROM MyTable group by b"
-
-  /**
-    * Builds a config whose normalization phase applies only
-    * AggregateExpandDistinctAggregatesRule.JOIN and whose logical and physical optimization
-    * phases use empty rule sets.
-    */
-  private def distinctRewriteOnlyConfig(): CalciteConfig =
-    new CalciteConfigBuilder()
-      .replaceNormRuleSet(RuleSets.ofList(AggregateExpandDistinctAggregatesRule.JOIN))
-      .replaceLogicalOptRuleSet(RuleSets.ofList())
-      .replacePhysicalOptRuleSet(RuleSets.ofList())
-      .build()
-
-  @Test
-  def testApplyNormalizationRuleForBatchSQL(): Unit = {
-    val util = batchTestUtil()
-    util.tableEnv.getConfig.setCalciteConfig(distinctRewriteOnlyConfig())
-    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
-
-    // expect double aggregate: group by {b, a} first, then COUNT per b
-    val expected = unaryNode("LogicalProject",
-      unaryNode("LogicalAggregate",
-        unaryNode("LogicalAggregate",
-          unaryNode("LogicalProject",
-            values("LogicalTableScan", term("table", "[_DataSetTable_0]")),
-            term("b", "$1"), term("a", "$0")),
-          term("group", "{0, 1}")),
-        term("group", "{0}"), term("EXPR$0", "COUNT($1)")
-      ),
-      term("EXPR$0", "$1")
-    )
-
-    util.verifySql(countDistinctQuery, expected)
-  }
-
-  @Test
-  def testApplyNormalizationRuleForStreamSQL(): Unit = {
-    val util = streamTestUtil()
-    util.tableEnv.getConfig.setCalciteConfig(distinctRewriteOnlyConfig())
-    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
-
-    // expect double aggregate: group by {b, a} first, then COUNT per b
-    val expected = unaryNode(
-      "LogicalProject",
-      unaryNode("LogicalAggregate",
-        unaryNode("LogicalAggregate",
-          unaryNode("LogicalProject",
-            values("LogicalTableScan", term("table", "[_DataStreamTable_0]")),
-            term("b", "$1"), term("a", "$0")),
-          term("group", "{0, 1}")),
-        term("group", "{0}"), term("EXPR$0", "COUNT($1)")
-      ),
-      term("EXPR$0", "$1")
-    )
-
-    util.verifySql(countDistinctQuery, expected)
-  }
-
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/f1fafc0e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/RetractionRulesTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/RetractionRulesTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/RetractionRulesTest.scala
deleted file mode 100644
index 75d2f22..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/RetractionRulesTest.scala
+++ /dev/null
@@ -1,321 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.plan.rules
-
-import org.apache.calcite.rel.RelNode
-import org.apache.flink.table.api.Table
-import org.apache.flink.table.plan.nodes.datastream._
-import org.apache.flink.table.utils.{StreamTableTestUtil, TableTestBase}
-import org.apache.flink.table.utils.TableTestUtil._
-import org.junit.Assert._
-import org.junit.{Ignore, Test}
-import org.apache.flink.api.scala._
-import org.apache.flink.table.api.scala._
-
-
-/**
-  * Tests the retraction traits that the optimizer assigns to DataStream nodes.
-  *
-  * Each node in an expected plan is written as `NodeName(<UpdateAsRetraction>, <AccMode>)`.
-  * This is the format rendered by TraitUtil.toString and checked through
-  * StreamTableTestForRetractionUtil.verifyTableTrait / verifySqlTrait.
-  */
-class RetractionRulesTest extends TableTestBase {
-
-  /** Creates a stream test util that can verify retraction traits. */
-  def streamTestForRetractionUtil(): StreamTableTestForRetractionUtil = {
-    new StreamTableTestForRetractionUtil()
-  }
-
-  // projection of all fields: only the scan remains in the optimized plan
-  @Test
-  def testSelect(): Unit = {
-    val util = streamTestForRetractionUtil()
-    val table = util.addTable[(String, Int)]('word, 'number)
-
-    val resultTable = table.select('word, 'number)
-
-    val expected = "DataStreamScan(false, Acc)"
-
-    util.verifyTableTrait(resultTable, expected)
-  }
-
-  // one level unbounded groupBy
-  @Test
-  def testGroupBy(): Unit = {
-    val util = streamTestForRetractionUtil()
-    val table = util.addTable[(String, Int)]('word, 'number)
-    val defaultStatus = "false, Acc"
-
-    val resultTable = table
-      .groupBy('word)
-      .select('number.count)
-
-    val expected =
-      unaryNode(
-        "DataStreamCalc",
-        unaryNode(
-          "DataStreamGroupAggregate",
-          "DataStreamScan(true, Acc)",
-          defaultStatus
-        ),
-        defaultStatus
-      )
-
-    util.verifyTableTrait(resultTable, expected)
-  }
-
-  // two level unbounded groupBy: the inner aggregate must emit retractions
-  @Test
-  def testTwoGroupBy(): Unit = {
-    val util = streamTestForRetractionUtil()
-    val table = util.addTable[(String, Int)]('word, 'number)
-    val defaultStatus = "false, Acc"
-
-    val resultTable = table
-      .groupBy('word)
-      .select('word, 'number.count as 'count)
-      .groupBy('count)
-      .select('count, 'count.count as 'frequency)
-
-    val expected =
-      unaryNode(
-        "DataStreamCalc",
-        unaryNode(
-          "DataStreamGroupAggregate",
-          unaryNode(
-            "DataStreamCalc",
-            unaryNode(
-              "DataStreamGroupAggregate",
-              "DataStreamScan(true, Acc)",
-              "true, AccRetract"
-            ),
-            "true, AccRetract"
-          ),
-          defaultStatus
-        ),
-        defaultStatus
-      )
-
-    util.verifyTableTrait(resultTable, expected)
-  }
-
-  // group window
-  @Test
-  def testGroupWindow(): Unit = {
-    val util = streamTestForRetractionUtil()
-    val table = util.addTable[(String, Int)]('word, 'number, 'rowtime.rowtime)
-    val defaultStatus = "false, Acc"
-
-    val resultTable = table
-      .window(Tumble over 50.milli on 'rowtime as 'w)
-      .groupBy('w, 'word)
-      .select('word, 'number.count as 'count)
-
-    val expected =
-      unaryNode(
-        "DataStreamCalc",
-        unaryNode(
-          "DataStreamGroupWindowAggregate",
-          "DataStreamScan(true, Acc)",
-          defaultStatus
-        ),
-        defaultStatus
-      )
-
-    util.verifyTableTrait(resultTable, expected)
-  }
-
-  // group window after unbounded groupBy
-  @Test
-  @Ignore // cannot pass rowtime through non-windowed aggregation
-  def testGroupWindowAfterGroupBy(): Unit = {
-    val util = streamTestForRetractionUtil()
-    val table = util.addTable[(String, Int)]('word, 'number, 'rowtime.rowtime)
-    val defaultStatus = "false, Acc"
-
-    val resultTable = table
-      .groupBy('word)
-      .select('word, 'number.count as 'count)
-      .window(Tumble over 50.milli on 'rowtime as 'w)
-      .groupBy('w, 'count)
-      .select('count, 'count.count as 'frequency)
-
-    val expected =
-      unaryNode(
-        "DataStreamCalc",
-        unaryNode(
-          "DataStreamGroupWindowAggregate",
-          unaryNode(
-            "DataStreamCalc",
-            unaryNode(
-              "DataStreamGroupAggregate",
-              "DataStreamScan(true, Acc)",
-              "true, AccRetract"
-            ),
-            "true, AccRetract"
-          ),
-          defaultStatus
-        ),
-        defaultStatus
-      )
-
-    util.verifyTableTrait(resultTable, expected)
-  }
-
-  // over window
-  @Test
-  def testOverWindow(): Unit = {
-    val util = streamTestForRetractionUtil()
-    util.addTable[(String, Int)]("T1", 'word, 'number, 'proctime.proctime)
-    val defaultStatus = "false, Acc"
-
-    // note the trailing space after "CURRENT ROW)" so the keywords do not run together
-    val sqlQuery =
-      "SELECT " +
-        "word, count(number) " +
-        "OVER (PARTITION BY word ORDER BY proctime " +
-        "ROWS BETWEEN UNBOUNDED preceding AND CURRENT ROW) " +
-        "FROM T1"
-
-    val expected =
-      unaryNode(
-        "DataStreamCalc",
-        unaryNode(
-          "DataStreamOverAggregate",
-          "DataStreamScan(true, Acc)",
-          defaultStatus
-        ),
-        defaultStatus
-      )
-
-    util.verifySqlTrait(sqlQuery, expected)
-  }
-
-  // over window after unbounded groupBy
-  @Test
-  @Ignore // cannot pass rowtime through non-windowed aggregation
-  def testOverWindowAfterGroupBy(): Unit = {
-    val util = streamTestForRetractionUtil()
-    util.addTable[(String, Int)]("T1", 'word, 'number, 'proctime.proctime)
-    val defaultStatus = "false, Acc"
-
-    val sqlQuery =
-      "SELECT " +
-        "_count, count(word) " +
-        "OVER (PARTITION BY _count ORDER BY proctime " +
-        "ROWS BETWEEN UNBOUNDED preceding AND CURRENT ROW) " +
-        "FROM " +
-        "(SELECT word, count(number) as _count FROM T1 GROUP BY word) "
-
-    val expected =
-      unaryNode(
-        "DataStreamCalc",
-        unaryNode(
-          "DataStreamOverAggregate",
-          unaryNode(
-            "DataStreamCalc",
-            unaryNode(
-              "DataStreamGroupAggregate",
-              "DataStreamScan(true, Acc)",
-              "true, AccRetract"
-            ),
-            "true, AccRetract"
-          ),
-          defaultStatus
-        ),
-        defaultStatus
-      )
-
-    util.verifySqlTrait(sqlQuery, expected)
-  }
-
-  // binary node: a union of a retracting aggregate with a plain scan
-  @Test
-  def testBinaryNode(): Unit = {
-    val util = streamTestForRetractionUtil()
-    val lTable = util.addTable[(String, Int)]('word, 'number)
-    val rTable = util.addTable[(String, Long)]('word_r, 'count_r)
-    val defaultStatus = "false, Acc"
-
-    val resultTable = lTable
-      .groupBy('word)
-      .select('word, 'number.count as 'count)
-      .unionAll(rTable)
-      .groupBy('count)
-      .select('count, 'count.count as 'frequency)
-
-    val expected =
-      unaryNode(
-        "DataStreamCalc",
-        unaryNode(
-          "DataStreamGroupAggregate",
-          unaryNode(
-            "DataStreamCalc",
-            binaryNode(
-              "DataStreamUnion",
-              unaryNode(
-                "DataStreamCalc",
-                unaryNode(
-                  "DataStreamGroupAggregate",
-                  "DataStreamScan(true, Acc)",
-                  "true, AccRetract"
-                ),
-                "true, AccRetract"
-              ),
-              "DataStreamScan(true, Acc)",
-              "true, AccRetract"
-            ),
-            "true, AccRetract"
-          ),
-          defaultStatus
-        ),
-        defaultStatus
-      )
-
-    util.verifyTableTrait(resultTable, expected)
-  }
-}
-
-/** Stream test util that checks the retraction traits of an optimized plan. */
-class StreamTableTestForRetractionUtil extends StreamTableTestUtil {
-
-  /** Translates `query` with `tableEnv.sql` and verifies it like `verifyTableTrait`. */
-  def verifySqlTrait(query: String, expected: String): Unit =
-    verifyTableTrait(tableEnv.sql(query), expected)
-
-  /**
-    * Optimizes `resultTable` with `updatesAsRetraction = false`. Asserts that the per-node trait
-    * rendering of the result equals `expected`, comparing line by line with surrounding
-    * whitespace trimmed.
-    */
-  def verifyTableTrait(resultTable: Table, expected: String): Unit = {
-    def normalize(plan: String): String = plan.split("\n").map(_.trim).mkString("\n")
-
-    val optimizedPlan = tableEnv.optimize(resultTable.getRelNode, updatesAsRetraction = false)
-    assertEquals(normalize(expected), normalize(TraitUtil.toString(optimizedPlan)))
-  }
-}
-
-object TraitUtil {
-
-  /**
-    * Renders `rel` and, recursively, its inputs as one line per node, parent before inputs.
-    * Each line has the form `ClassName(<UpdateAsRetraction trait>, <AccMode trait>)`.
-    */
-  def toString(rel: RelNode): String = {
-    val inputsString = (0 until rel.getInputs.size())
-      .map(idx => TraitUtil.toString(rel.getInput(idx)))
-      .mkString
-
-    val traits = rel.getTraitSet
-    val retractionTrait = traits.getTrait(UpdateAsRetractionTraitDef.INSTANCE).toString
-    val accModeTrait = traits.getTrait(AccModeTraitDef.INSTANCE).toString
-
-    s"""${rel.getClass.getSimpleName}($retractionTrait, $accModeTrait)
-       |$inputsString
-       |""".stripMargin.stripLineEnd
-  }
-}
-

http://git-wip-us.apache.org/repos/asf/flink/blob/f1fafc0e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExtractorTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExtractorTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExtractorTest.scala
deleted file mode 100644
index 5d5eece..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExtractorTest.scala
+++ /dev/null
@@ -1,523 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.plan.util
-
-import java.math.BigDecimal
-
-import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder}
-import org.apache.calcite.sql.SqlPostfixOperator
-import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, INTEGER, VARCHAR}
-import org.apache.calcite.sql.fun.SqlStdOperatorTable
-import org.apache.flink.table.expressions._
-import org.apache.flink.table.utils.InputTypeBuilder.inputOf
-import org.apache.flink.table.validate.FunctionCatalog
-import org.hamcrest.CoreMatchers.is
-import org.junit.Assert.{assertArrayEquals, assertEquals, assertThat}
-import org.junit.Test
-
-import scala.collection.JavaConverters._
-
-class RexProgramExtractorTest extends RexProgramTestBase {
-
-  private val functionCatalog: FunctionCatalog = FunctionCatalog.withBuiltIns
-
-  @Test
-  def testExtractRefInputFields(): Unit = {
-    val usedFields = RexProgramExtractor.extractRefInputFields(buildSimpleRexProgram())
-    assertArrayEquals(usedFields, Array(2, 3, 1))
-  }
-
-  @Test
-  def testExtractSimpleCondition(): Unit = {
-    val builder: RexBuilder = new RexBuilder(typeFactory)
-    val program = buildSimpleRexProgram()
-
-    val firstExp = ExpressionParser.parseExpression("id > 6")
-    val secondExp = ExpressionParser.parseExpression("amount * price < 100")
-    val expected: Array[Expression] = Array(firstExp, secondExp)
-
-    val (convertedExpressions, unconvertedRexNodes) =
-      RexProgramExtractor.extractConjunctiveConditions(
-        program,
-        builder,
-        functionCatalog)
-
-    assertExpressionArrayEquals(expected, convertedExpressions)
-    assertEquals(0, unconvertedRexNodes.length)
-  }
-
-  @Test
-  def testExtractSingleCondition(): Unit = {
-    val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
-    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
-
-    // amount
-    val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
-    // id
-    val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
-
-    // a = amount >= id
-    val a = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, t0, t1))
-    builder.addCondition(a)
-
-    val program = builder.getProgram
-    val relBuilder: RexBuilder = new RexBuilder(typeFactory)
-    val (convertedExpressions, unconvertedRexNodes) =
-      RexProgramExtractor.extractConjunctiveConditions(
-        program,
-        relBuilder,
-        functionCatalog)
-
-    val expected: Array[Expression] = Array(ExpressionParser.parseExpression("amount >= id"))
-    assertExpressionArrayEquals(expected, convertedExpressions)
-    assertEquals(0, unconvertedRexNodes.length)
-  }
-
-  // ((a AND b) OR c) AND (NOT d) => (a OR c) AND (b OR c) AND (NOT d)
-  @Test
-  def testExtractCnfCondition(): Unit = {
-    val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
-    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
-
-    // amount
-    val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
-    // id
-    val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
-    // price
-    val t2 = rexBuilder.makeInputRef(allFieldTypes.get(3), 3)
-    // 100
-    val t3 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
-
-    // a = amount < 100
-    val a = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t0, t3))
-    // b = id > 100
-    val b = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t3))
-    // c = price == 100
-    val c = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, t2, t3))
-    // d = amount <= id
-    val d = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, t0, t1))
-
-    // a AND b
-    val and = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(a, b).asJava))
-    // (a AND b) or c
-    val or = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.OR, List(and, c).asJava))
-    // not d
-    val not = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.NOT, List(d).asJava))
-
-    // (a AND b) OR c) AND (NOT d)
-    builder.addCondition(builder.addExpr(
-      rexBuilder.makeCall(SqlStdOperatorTable.AND, List(or, not).asJava)))
-
-    val program = builder.getProgram
-    val relBuilder: RexBuilder = new RexBuilder(typeFactory)
-    val (convertedExpressions, unconvertedRexNodes) =
-      RexProgramExtractor.extractConjunctiveConditions(
-        program,
-        relBuilder,
-        functionCatalog)
-
-    val expected: Array[Expression] = Array(
-      ExpressionParser.parseExpression("amount < 100 || price == 100"),
-      ExpressionParser.parseExpression("id > 100 || price == 100"),
-      ExpressionParser.parseExpression("!(amount <= id)"))
-    assertExpressionArrayEquals(expected, convertedExpressions)
-    assertEquals(0, unconvertedRexNodes.length)
-  }
-
-  @Test
-  def testExtractArithmeticConditions(): Unit = {
-    val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
-    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
-
-    // amount
-    val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
-    // id
-    val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
-    // 100
-    val t2 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
-
-    val condition = List(
-      // amount < id
-      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t0, t1)),
-      // amount <= id
-      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, t0, t1)),
-      // amount <> id
-      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.NOT_EQUALS, t0, t1)),
-      // amount == id
-      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, t0, t1)),
-      // amount >= id
-      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, t0, t1)),
-      // amount > id
-      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t0, t1)),
-      // amount + id == 100
-      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
-        rexBuilder.makeCall(SqlStdOperatorTable.PLUS, t0, t1), t2)),
-      // amount - id == 100
-      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
-        rexBuilder.makeCall(SqlStdOperatorTable.MINUS, t0, t1), t2)),
-      // amount * id == 100
-      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
-        rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, t0, t1), t2)),
-      // amount / id == 100
-      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
-        rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, t0, t1), t2)),
-      // -amount == 100
-      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
-        rexBuilder.makeCall(SqlStdOperatorTable.UNARY_MINUS, t0), t2))
-    ).asJava
-
-    builder.addCondition(builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, condition)))
-    val program = builder.getProgram
-    val relBuilder: RexBuilder = new RexBuilder(typeFactory)
-    val (convertedExpressions, unconvertedRexNodes) =
-      RexProgramExtractor.extractConjunctiveConditions(
-        program,
-        relBuilder,
-        functionCatalog)
-
-    val expected: Array[Expression] = Array(
-      ExpressionParser.parseExpression("amount < id"),
-      ExpressionParser.parseExpression("amount <= id"),
-      ExpressionParser.parseExpression("amount <> id"),
-      ExpressionParser.parseExpression("amount == id"),
-      ExpressionParser.parseExpression("amount >= id"),
-      ExpressionParser.parseExpression("amount > id"),
-      ExpressionParser.parseExpression("amount + id == 100"),
-      ExpressionParser.parseExpression("amount - id == 100"),
-      ExpressionParser.parseExpression("amount * id == 100"),
-      ExpressionParser.parseExpression("amount / id == 100"),
-      ExpressionParser.parseExpression("-amount == 100")
-    )
-    assertExpressionArrayEquals(expected, convertedExpressions)
-    assertEquals(0, unconvertedRexNodes.length)
-  }
-
-  @Test
-  def testExtractPostfixConditions(): Unit = {
-    testExtractSinglePostfixCondition(4, SqlStdOperatorTable.IS_NULL, "('flag).isNull")
-    // IS_NOT_NULL will be eliminated since flag is not nullable
-    // testExtractSinglePostfixCondition(SqlStdOperatorTable.IS_NOT_NULL, "('flag).isNotNull")
-    testExtractSinglePostfixCondition(4, SqlStdOperatorTable.IS_TRUE, "('flag).isTrue")
-    testExtractSinglePostfixCondition(4, SqlStdOperatorTable.IS_NOT_TRUE, "('flag).isNotTrue")
-    testExtractSinglePostfixCondition(4, SqlStdOperatorTable.IS_FALSE, "('flag).isFalse")
-    testExtractSinglePostfixCondition(4, SqlStdOperatorTable.IS_NOT_FALSE, "('flag).isNotFalse")
-  }
-
-  @Test
-  def testExtractConditionWithFunctionCalls(): Unit = {
-    val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
-    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
-
-    // amount
-    val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
-    // id
-    val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
-    // 100
-    val t2 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
-
-    // sum(amount) > 100
-    val condition1 = builder.addExpr(
-      rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN,
-        rexBuilder.makeCall(SqlStdOperatorTable.SUM, t0), t2))
-
-    // min(id) == 100
-    val condition2 = builder.addExpr(
-      rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
-        rexBuilder.makeCall(SqlStdOperatorTable.MIN, t1), t2))
-
-    builder.addCondition(builder.addExpr(
-      rexBuilder.makeCall(SqlStdOperatorTable.AND, condition1, condition2)))
-
-    val program = builder.getProgram
-    val relBuilder: RexBuilder = new RexBuilder(typeFactory)
-    val (convertedExpressions, unconvertedRexNodes) =
-      RexProgramExtractor.extractConjunctiveConditions(
-        program,
-        relBuilder,
-        functionCatalog)
-
-    val expected: Array[Expression] = Array(
-      GreaterThan(Sum(UnresolvedFieldReference("amount")), Literal(100)),
-      EqualTo(Min(UnresolvedFieldReference("id")), Literal(100))
-    )
-    assertExpressionArrayEquals(expected, convertedExpressions)
-    assertEquals(0, unconvertedRexNodes.length)
-  }
-
-  @Test
-  def testExtractWithUnsupportedConditions(): Unit = {
-    val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
-    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
-
-    // amount
-    val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
-    // id
-    val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
-    // 100
-    val t2 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
-
-    // unsupported now: amount.cast(BigInteger)
-    val cast = builder.addExpr(rexBuilder.makeCast(allFieldTypes.get(1), t0))
-
-    // unsupported now: amount.cast(BigInteger) > 100
-    val condition1 = builder.addExpr(
-      rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, cast, t2))
-
-    // amount <= id
-    val condition2 = builder.addExpr(
-      rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, t0, t1))
-
-    // contains unsupported condition: (amount.cast(BigInteger) > 100 OR amount <= id)
-    val condition3 = builder.addExpr(
-      rexBuilder.makeCall(SqlStdOperatorTable.OR, condition1, condition2))
-
-    // only condition2 can be translated
-    builder.addCondition(
-      rexBuilder.makeCall(SqlStdOperatorTable.AND, condition1, condition2, condition3))
-
-    val program = builder.getProgram
-    val relBuilder: RexBuilder = new RexBuilder(typeFactory)
-    val (convertedExpressions, unconvertedRexNodes) =
-      RexProgramExtractor.extractConjunctiveConditions(
-        program,
-        relBuilder,
-        functionCatalog)
-
-    val expected: Array[Expression] = Array(
-      ExpressionParser.parseExpression("amount <= id")
-    )
-    assertExpressionArrayEquals(expected, convertedExpressions)
-    assertEquals(2, unconvertedRexNodes.length)
-    assertEquals(">(CAST($2):BIGINT NOT NULL, 100)", unconvertedRexNodes(0).toString)
-    assertEquals("OR(>(CAST($2):BIGINT NOT NULL, 100), <=($2, $1))",
-      unconvertedRexNodes(1).toString)
-  }
-
-  @Test
-  def testExtractRefNestedInputFields(): Unit = {
-    val rexProgram = buildRexProgramWithNesting()
-
-    val usedFields = RexProgramExtractor.extractRefInputFields(rexProgram)
-    val usedNestedFields = RexProgramExtractor.extractRefNestedInputFields(rexProgram, usedFields)
-
-    val expected = Array(Array("amount"), Array("*"))
-    assertThat(usedNestedFields, is(expected))
-  }
-
-  @Test
-  def testExtractRefNestedInputFieldsWithNoNesting(): Unit = {
-    val rexProgram = buildSimpleRexProgram()
-
-    val usedFields = RexProgramExtractor.extractRefInputFields(rexProgram)
-    val usedNestedFields = RexProgramExtractor.extractRefNestedInputFields(rexProgram, usedFields)
-
-    val expected = Array(Array("*"), Array("*"), Array("*"))
-    assertThat(usedNestedFields, is(expected))
-  }
-
-  @Test
-  def testExtractDeepRefNestedInputFields(): Unit = {
-    val rexProgram = buildRexProgramWithDeepNesting()
-
-    val usedFields = RexProgramExtractor.extractRefInputFields(rexProgram)
-    val usedNestedFields = RexProgramExtractor.extractRefNestedInputFields(rexProgram, usedFields)
-
-    val expected = Array(
-      Array("amount"),
-      Array("*"),
-      Array("with.deeper.entry", "with.deep.entry"))
-
-    assertThat(usedFields, is(Array(1, 0, 2)))
-    assertThat(usedNestedFields, is(expected))
-  }
-
-  private def buildRexProgramWithDeepNesting(): RexProgram = {
-
-    // person input
-    val passportRow = inputOf(typeFactory)
-      .field("id", VARCHAR)
-      .field("status", VARCHAR)
-      .build
-
-    val personRow = inputOf(typeFactory)
-      .field("name", VARCHAR)
-      .field("age", INTEGER)
-      .nestedField("passport", passportRow)
-      .build
-
-    // payment input
-    val paymentRow = inputOf(typeFactory)
-      .field("id", BIGINT)
-      .field("amount", INTEGER)
-      .build
-
-    // deep field input
-    val deepRowType = inputOf(typeFactory)
-      .field("entry", VARCHAR)
-      .build
-
-    val entryRowType = inputOf(typeFactory)
-      .nestedField("inside", deepRowType)
-      .build
-
-    val deeperRowType = inputOf(typeFactory)
-      .nestedField("entry", entryRowType)
-      .build
-
-    val withRowType = inputOf(typeFactory)
-      .nestedField("deep", deepRowType)
-      .nestedField("deeper", deeperRowType)
-      .build
-
-    val fieldRowType = inputOf(typeFactory)
-      .nestedField("with", withRowType)
-      .build
-
-    // main input
-    val inputRowType = inputOf(typeFactory)
-      .nestedField("persons", personRow)
-      .nestedField("payments", paymentRow)
-      .nestedField("field", fieldRowType)
-      .build
-
-    // inputRowType
-    //
-    // [ persons:  [ name: VARCHAR, age:  INT, passport: [id: VARCHAR, status: VARCHAR ] ],
-    //   payments: [ id: BIGINT, amount: INT ],
-    //   field:    [ with: [ deep: [ entry: VARCHAR ],
-    //                       deeper: [ entry: [ inside: [entry: VARCHAR ] ] ]
-    //             ] ]
-    // ]
-
-    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
-
-    val t0 = rexBuilder.makeInputRef(personRow, 0)
-    val t1 = rexBuilder.makeInputRef(paymentRow, 1)
-    val t2 = rexBuilder.makeInputRef(fieldRowType, 2)
-    val t3 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(10L))
-
-    // person
-    val person$pass = rexBuilder.makeFieldAccess(t0, "passport", false)
-    val person$pass$stat = rexBuilder.makeFieldAccess(person$pass, "status", false)
-
-    // payment
-    val pay$amount = rexBuilder.makeFieldAccess(t1, "amount", false)
-    val multiplyAmount = builder.addExpr(
-      rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, pay$amount, t3))
-
-    // field
-    val field$with = rexBuilder.makeFieldAccess(t2, "with", false)
-    val field$with$deep = rexBuilder.makeFieldAccess(field$with, "deep", false)
-    val field$with$deeper = rexBuilder.makeFieldAccess(field$with, "deeper", false)
-    val field$with$deep$entry = rexBuilder.makeFieldAccess(field$with$deep, "entry", false)
-    val field$with$deeper$entry = rexBuilder.makeFieldAccess(field$with$deeper, "entry", false)
-    val field$with$deeper$entry$inside = rexBuilder
-      .makeFieldAccess(field$with$deeper$entry, "inside", false)
-    val field$with$deeper$entry$inside$entry = rexBuilder
-      .makeFieldAccess(field$with$deeper$entry$inside, "entry", false)
-
-    builder.addProject(multiplyAmount, "amount")
-    builder.addProject(person$pass$stat, "status")
-    builder.addProject(field$with$deep$entry, "entry")
-    builder.addProject(field$with$deeper$entry$inside$entry, "entry")
-    builder.addProject(field$with$deeper$entry, "entry2")
-    builder.addProject(t0, "person")
-
-    // Program
-    // (
-    //   payments.amount * 10),
-    //   persons.passport.status,
-    //   field.with.deep.entry
-    //   field.with.deeper.entry.inside.entry
-    //   field.with.deeper.entry
-    //   persons
-    // )
-
-    builder.getProgram
-
-  }
-
-  private def buildRexProgramWithNesting(): RexProgram = {
-
-    val personRow = inputOf(typeFactory)
-      .field("name", INTEGER)
-      .field("age", VARCHAR)
-      .build
-
-    val paymentRow = inputOf(typeFactory)
-      .field("id", BIGINT)
-      .field("amount", INTEGER)
-      .build
-
-    val types = List(personRow, paymentRow).asJava
-    val names = List("persons", "payments").asJava
-    val inputRowType = typeFactory.createStructType(types, names)
-
-    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
-
-    val t0 = rexBuilder.makeInputRef(types.get(0), 0)
-    val t1 = rexBuilder.makeInputRef(types.get(1), 1)
-    val t2 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
-
-    val payment$amount = rexBuilder.makeFieldAccess(t1, "amount", false)
-
-    builder.addProject(payment$amount, "amount")
-    builder.addProject(t0, "persons")
-    builder.addProject(t2, "number")
-    builder.getProgram
-  }
-
-  private def testExtractSinglePostfixCondition(
-      fieldIndex: Integer,
-      op: SqlPostfixOperator,
-      expr: String) : Unit = {
-
-    val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
-    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
-    rexBuilder = new RexBuilder(typeFactory)
-
-    // flag
-    val t0 = rexBuilder.makeInputRef(allFieldTypes.get(fieldIndex), fieldIndex)
-    builder.addCondition(builder.addExpr(rexBuilder.makeCall(op, t0)))
-
-    val program = builder.getProgram(false)
-    val relBuilder: RexBuilder = new RexBuilder(typeFactory)
-    val (convertedExpressions, unconvertedRexNodes) =
-      RexProgramExtractor.extractConjunctiveConditions(
-        program,
-        relBuilder,
-        functionCatalog)
-
-    assertEquals(1, convertedExpressions.length)
-    assertEquals(expr, convertedExpressions.head.toString)
-    assertEquals(0, unconvertedRexNodes.length)
-  }
-
-  private def assertExpressionArrayEquals(
-      expected: Array[Expression],
-      actual: Array[Expression]) = {
-    val sortedExpected = expected.sortBy(e => e.toString)
-    val sortedActual = actual.sortBy(e => e.toString)
-
-    assertEquals(sortedExpected.length, sortedActual.length)
-    sortedExpected.zip(sortedActual).foreach {
-      case (l, r) => assertEquals(l.toString, r.toString)
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/f1fafc0e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramRewriterTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramRewriterTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramRewriterTest.scala
deleted file mode 100644
index 899eed2..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramRewriterTest.scala
+++ /dev/null
@@ -1,62 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.plan.util
-
-import org.junit.Assert.assertTrue
-import org.junit.Test
-
-import scala.collection.JavaConverters._
-
-class RexProgramRewriterTest extends RexProgramTestBase {
-
-  @Test
-  def testRewriteRexProgram(): Unit = {
-    val rexProgram = buildSimpleRexProgram()
-    assertTrue(extractExprStrList(rexProgram) == wrapRefArray(Array(
-      "$0",
-      "$1",
-      "$2",
-      "$3",
-      "$4",
-      "*($t2, $t3)",
-      "100",
-      "<($t5, $t6)",
-      "6",
-      ">($t1, $t8)",
-      "AND($t7, $t9)")))
-
-    // use amount, id, price fields to create a new RexProgram
-    val usedFields = Array(2, 3, 1)
-    val types = usedFields.map(allFieldTypes.get).toList.asJava
-    val names = usedFields.map(allFieldNames.get).toList.asJava
-    val inputRowType = typeFactory.createStructType(types, names)
-    val newRexProgram = RexProgramRewriter.rewriteWithFieldProjection(
-      rexProgram, inputRowType, rexBuilder, usedFields)
-    assertTrue(extractExprStrList(newRexProgram) == wrapRefArray(Array(
-      "$0",
-      "$1",
-      "$2",
-      "*($t0, $t1)",
-      "100",
-      "<($t3, $t4)",
-      "6",
-      ">($t2, $t6)",
-      "AND($t5, $t7)")))
-  }
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/f1fafc0e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramTestBase.scala
deleted file mode 100644
index 6ef3d82..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramTestBase.scala
+++ /dev/null
@@ -1,80 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.plan.util
-
-import java.math.BigDecimal
-import java.util
-
-import org.apache.calcite.adapter.java.JavaTypeFactory
-import org.apache.calcite.jdbc.JavaTypeFactoryImpl
-import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeSystem}
-import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder}
-import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, DOUBLE, INTEGER, VARCHAR, BOOLEAN}
-import org.apache.calcite.sql.fun.SqlStdOperatorTable
-
-import scala.collection.JavaConverters._
-import scala.collection.mutable
-
-abstract class RexProgramTestBase {
-
-  val typeFactory: JavaTypeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT)
-
-  val allFieldNames: util.List[String] = List("name", "id", "amount", "price", "flag").asJava
-
-  val allFieldTypes: util.List[RelDataType] =
-    List(VARCHAR, BIGINT, INTEGER, DOUBLE, BOOLEAN).map(typeFactory.createSqlType).asJava
-
-  var rexBuilder: RexBuilder = new RexBuilder(typeFactory)
-
-  /**
-    * extract all expression string list from input RexProgram expression lists
-    *
-    * @param rexProgram input RexProgram instance to analyze
-    * @return all expression string list of input RexProgram expression lists
-    */
-  protected def extractExprStrList(rexProgram: RexProgram): mutable.Buffer[String] = {
-    rexProgram.getExprList.asScala.map(_.toString)
-  }
-
-  // select amount, amount * price as total where amount * price < 100 and id > 6
-  protected def buildSimpleRexProgram(): RexProgram = {
-    val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
-    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
-
-    val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
-    val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
-    val t2 = rexBuilder.makeInputRef(allFieldTypes.get(3), 3)
-    val t3 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, t0, t2))
-    val t4 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
-    val t5 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(6L))
-
-    // project: amount, amount * price as total
-    builder.addProject(t0, "amount")
-    builder.addProject(t3, "total")
-
-    // condition: amount * price < 100 and id > 6
-    val t6 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t3, t4))
-    val t7 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t5))
-    val t8 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(t6, t7).asJava))
-    builder.addCondition(t8)
-
-    builder.getProgram
-  }
-
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/f1fafc0e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AggFunctionTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AggFunctionTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AggFunctionTestBase.scala
new file mode 100644
index 0000000..458f80d
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AggFunctionTestBase.scala
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.aggfunctions
+
+import java.lang.reflect.Method
+import java.math.BigDecimal
+import java.util.{ArrayList => JArrayList, List => JList}
+
+import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.aggfunctions.{DecimalAvgAccumulator, DecimalSumWithRetractAccumulator}
+import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
+import org.junit.Assert.assertEquals
+import org.junit.Test
+
+/**
+  * Base class for aggregate function tests.
+  *
+  * Subclasses provide input value sets, the expected result for each set, and the aggregate
+  * function under test. Every set is accumulated and its result compared with the expected
+  * value. If the function defines `retract`, `merge` or `resetAccumulator`, those methods are
+  * exercised as well.
+  *
+  * @tparam T the type for the aggregation result
+  * @tparam ACC the type of the accumulator
+  */
+abstract class AggFunctionTestBase[T, ACC] {
+
+  /** Input value sets; each set is aggregated independently. */
+  def inputValueSets: Seq[Seq[_]]
+
+  /** Expected result for each entry of `inputValueSets`, in the same order. */
+  def expectedResults: Seq[T]
+
+  /** The aggregate function under test. */
+  def aggregator: AggregateFunction[T, ACC]
+
+  /**
+    * Declared return type of `createAccumulator()`, used to look up the accumulator-taking
+    * methods by reflection.
+    *
+    * Lazy so that it is not evaluated while this base class is being constructed: at that
+    * point a subclass that implements `aggregator` as a `val` has not initialized it yet.
+    */
+  lazy val accType: Class[_] = aggregator.getClass.getMethod("createAccumulator").getReturnType
+
+  /** Reflective handle to `accumulate(acc, value)`, with the value parameter typed as `Any`. */
+  def accumulateFunc: Method = aggregator.getClass.getMethod("accumulate", accType, classOf[Any])
+
+  /**
+    * Reflective handle to `retract(acc, value)`. `null` by default; subclasses whose
+    * aggregator defines `retract` must override this.
+    */
+  def retractFunc: Method = null
+
+  /** Tests accumulate, and retract if defined, without partial merges. */
+  @Test
+  def testAccumulateAndRetractWithoutMerge(): Unit = {
+    for ((vals, expected) <- inputsWithExpected) {
+      val accumulator = accumulateVals(vals)
+      validateResult[T](expected, aggregator.getValue(accumulator))
+
+      if (ifMethodExistInFunction("retract", aggregator)) {
+        retractVals(accumulator, vals)
+        // retracting every accumulated value must restore a freshly created accumulator
+        validateResult[ACC](aggregator.createAccumulator(), accumulator)
+      }
+    }
+  }
+
+  /** Tests merge (and retract after merge, if defined) when the aggregator defines merge. */
+  @Test
+  def testAggregateWithMerge(): Unit = {
+    if (ifMethodExistInFunction("merge", aggregator)) {
+      val mergeFunc =
+        aggregator.getClass.getMethod("merge", accType, classOf[java.lang.Iterable[ACC]])
+
+      for ((vals, expected) <- inputsWithExpected) {
+        // 1. split the values into two halves, accumulate them separately, then merge
+        val (firstVals, secondVals) = vals.splitAt(vals.length / 2)
+        val accumulators: JList[ACC] = new JArrayList[ACC]()
+        accumulators.add(accumulateVals(secondVals))
+
+        val acc = accumulateVals(firstVals)
+        mergeFunc.invoke(aggregator, acc.asInstanceOf[Object], accumulators)
+        validateResult[T](expected, aggregator.getValue(acc))
+
+        // 2. retracting all values from the merged accumulator must yield a fresh one
+        if (ifMethodExistInFunction("retract", aggregator)) {
+          retractVals(acc, vals)
+          validateResult[ACC](aggregator.createAccumulator(), acc)
+        }
+      }
+
+      for ((vals, expected) <- inputsWithExpected) {
+        // 3. merging an empty accumulator must not change the result
+        val accumulators: JList[ACC] = new JArrayList[ACC]()
+        accumulators.add(aggregator.createAccumulator())
+
+        val acc = accumulateVals(vals)
+        mergeFunc.invoke(aggregator, acc.asInstanceOf[Object], accumulators)
+        validateResult[T](expected, aggregator.getValue(acc))
+      }
+    }
+  }
+
+  /** Tests that resetAccumulator, if defined, restores a freshly created accumulator. */
+  @Test
+  def testResetAccumulator(): Unit = {
+    if (ifMethodExistInFunction("resetAccumulator", aggregator)) {
+      val resetAccFunc = aggregator.getClass.getMethod("resetAccumulator", accType)
+      for ((vals, _) <- inputsWithExpected) {
+        val accumulator = accumulateVals(vals)
+        resetAccFunc.invoke(aggregator, accumulator.asInstanceOf[Object])
+        // the accumulator after reset should be exactly the same as a new accumulator
+        validateResult[ACC](aggregator.createAccumulator(), accumulator)
+      }
+    }
+  }
+
+  /**
+    * Pairs each input value set with its expected result. Fails if the two sequences differ
+    * in length, because a plain `zip` would silently drop the surplus test cases.
+    */
+  private def inputsWithExpected: Seq[(Seq[_], T)] = {
+    assertEquals(
+      "inputValueSets and expectedResults must have the same number of entries",
+      inputValueSets.size.toLong,
+      expectedResults.size.toLong)
+    inputValueSets.zip(expectedResults)
+  }
+
+  /**
+    * Asserts that `result` equals `expected`.
+    *
+    * BigDecimal values, including those held in the decimal accumulators, are compared by
+    * numeric value via `compareTo`, because `BigDecimal.equals()` compares value and scale
+    * and only the value matters here.
+    */
+  private def validateResult[V](expected: V, result: V): Unit = {
+    (expected, result) match {
+      case (e: DecimalSumWithRetractAccumulator, r: DecimalSumWithRetractAccumulator) =>
+        assertDecimalValueEquals(e.f0, r.f0)
+        assertEquals(e.f1.toLong, r.f1.toLong)
+      case (e: DecimalAvgAccumulator, r: DecimalAvgAccumulator) =>
+        assertDecimalValueEquals(e.f0, r.f0)
+        assertEquals(e.f1.toLong, r.f1.toLong)
+      case (e: BigDecimal, r: BigDecimal) =>
+        assertDecimalValueEquals(e, r)
+      case _ =>
+        assertEquals(expected, result)
+    }
+  }
+
+  /**
+    * Asserts numeric equality of two BigDecimals, ignoring scale. Uses a JUnit assertion
+    * (not the elidable `Predef.assert`) so the check always runs and reports both values.
+    */
+  private def assertDecimalValueEquals(expected: BigDecimal, actual: BigDecimal): Unit = {
+    assertEquals(
+      s"expected: <$expected> but was: <$actual> (compared ignoring scale)",
+      0L,
+      expected.compareTo(actual).toLong)
+  }
+
+  /** Creates a fresh accumulator and accumulates all `vals` into it via `accumulateFunc`. */
+  private def accumulateVals(vals: Seq[_]): ACC = {
+    // resolve the aggregator and the reflective method once rather than once per value
+    val agg = aggregator
+    val accumulate = accumulateFunc
+    val accumulator = agg.createAccumulator()
+    vals.foreach { v =>
+      accumulate.invoke(agg, accumulator.asInstanceOf[Object], v.asInstanceOf[Object])
+    }
+    accumulator
+  }
+
+  /**
+    * Retracts all `vals` from `accumulator` via `retractFunc`.
+    *
+    * @throws IllegalStateException if there are values to retract but `retractFunc` was not
+    *                               overridden (it defaults to `null`)
+    */
+  private def retractVals(accumulator: ACC, vals: Seq[_]): Unit = {
+    if (vals.nonEmpty) {
+      val retract = retractFunc
+      if (retract == null) {
+        throw new IllegalStateException(
+          s"${aggregator.getClass.getName} defines retract(), but the test does not " +
+            "override retractFunc")
+      }
+      val agg = aggregator
+      vals.foreach { v =>
+        retract.invoke(agg, accumulator.asInstanceOf[Object], v.asInstanceOf[Object])
+      }
+    }
+  }
+}