You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tw...@apache.org on 2017/07/13 10:18:38 UTC
[29/44] flink git commit: [FLINK-6617] [table] Restructuring of tests
http://git-wip-us.apache.org/repos/asf/flink/blob/f1fafc0e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramExtractorTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramExtractorTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramExtractorTest.scala
new file mode 100644
index 0000000..840be17
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramExtractorTest.scala
@@ -0,0 +1,524 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan
+
+import java.math.BigDecimal
+
+import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder}
+import org.apache.calcite.sql.SqlPostfixOperator
+import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, INTEGER, VARCHAR}
+import org.apache.calcite.sql.fun.SqlStdOperatorTable
+import org.apache.flink.table.expressions._
+import org.apache.flink.table.plan.util.RexProgramExtractor
+import org.apache.flink.table.utils.InputTypeBuilder.inputOf
+import org.apache.flink.table.validate.FunctionCatalog
+import org.hamcrest.CoreMatchers.is
+import org.junit.Assert.{assertArrayEquals, assertEquals, assertThat}
+import org.junit.Test
+
+import scala.collection.JavaConverters._
+
+/**
+ * Tests for [[RexProgramExtractor]]: extraction of the (nested) input fields a program
+ * references, and conversion of a program's condition into conjunctive [[Expression]]s.
+ */
+class RexProgramExtractorTest extends RexProgramTestBase {
+
+  private val functionCatalog: FunctionCatalog = FunctionCatalog.withBuiltIns
+
+  @Test
+  def testExtractRefInputFields(): Unit = {
+    val usedFields = RexProgramExtractor.extractRefInputFields(buildSimpleRexProgram())
+    // JUnit convention: expected value first, actual value second
+    assertArrayEquals(Array(2, 3, 1), usedFields)
+  }
+
+  @Test
+  def testExtractSimpleCondition(): Unit = {
+    val program = buildSimpleRexProgram()
+
+    val firstExp = ExpressionParser.parseExpression("id > 6")
+    val secondExp = ExpressionParser.parseExpression("amount * price < 100")
+    val expected: Array[Expression] = Array(firstExp, secondExp)
+
+    val (convertedExpressions, unconvertedRexNodes) = extractConditions(program)
+
+    assertExpressionArrayEquals(expected, convertedExpressions)
+    assertEquals(0, unconvertedRexNodes.length)
+  }
+
+  @Test
+  def testExtractSingleCondition(): Unit = {
+    val builder = newProgramBuilder()
+
+    // amount
+    val t0 = fieldRef(2)
+    // id
+    val t1 = fieldRef(1)
+
+    // a = amount >= id
+    val a = builder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, t0, t1))
+    builder.addCondition(a)
+
+    val (convertedExpressions, unconvertedRexNodes) = extractConditions(builder.getProgram)
+
+    val expected: Array[Expression] = Array(ExpressionParser.parseExpression("amount >= id"))
+    assertExpressionArrayEquals(expected, convertedExpressions)
+    assertEquals(0, unconvertedRexNodes.length)
+  }
+
+  // ((a AND b) OR c) AND (NOT d) => (a OR c) AND (b OR c) AND (NOT d)
+  @Test
+  def testExtractCnfCondition(): Unit = {
+    val builder = newProgramBuilder()
+
+    // amount
+    val t0 = fieldRef(2)
+    // id
+    val t1 = fieldRef(1)
+    // price
+    val t2 = fieldRef(3)
+    // 100
+    val t3 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+
+    // a = amount < 100
+    val a = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t0, t3))
+    // b = id > 100
+    val b = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t3))
+    // c = price == 100
+    val c = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, t2, t3))
+    // d = amount <= id
+    val d = builder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, t0, t1))
+
+    // a AND b
+    val and = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(a, b).asJava))
+    // (a AND b) OR c
+    val or = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.OR, List(and, c).asJava))
+    // NOT d
+    val not = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.NOT, List(d).asJava))
+
+    // ((a AND b) OR c) AND (NOT d)
+    builder.addCondition(builder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.AND, List(or, not).asJava)))
+
+    val (convertedExpressions, unconvertedRexNodes) = extractConditions(builder.getProgram)
+
+    val expected: Array[Expression] = Array(
+      ExpressionParser.parseExpression("amount < 100 || price == 100"),
+      ExpressionParser.parseExpression("id > 100 || price == 100"),
+      ExpressionParser.parseExpression("!(amount <= id)"))
+    assertExpressionArrayEquals(expected, convertedExpressions)
+    assertEquals(0, unconvertedRexNodes.length)
+  }
+
+  @Test
+  def testExtractArithmeticConditions(): Unit = {
+    val builder = newProgramBuilder()
+
+    // amount
+    val t0 = fieldRef(2)
+    // id
+    val t1 = fieldRef(1)
+    // 100
+    val t2 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+
+    val condition = List(
+      // amount < id
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t0, t1)),
+      // amount <= id
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, t0, t1)),
+      // amount <> id
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.NOT_EQUALS, t0, t1)),
+      // amount == id
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, t0, t1)),
+      // amount >= id
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, t0, t1)),
+      // amount > id
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t0, t1)),
+      // amount + id == 100
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
+        rexBuilder.makeCall(SqlStdOperatorTable.PLUS, t0, t1), t2)),
+      // amount - id == 100
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
+        rexBuilder.makeCall(SqlStdOperatorTable.MINUS, t0, t1), t2)),
+      // amount * id == 100
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
+        rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, t0, t1), t2)),
+      // amount / id == 100
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
+        rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, t0, t1), t2)),
+      // -amount == 100
+      builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
+        rexBuilder.makeCall(SqlStdOperatorTable.UNARY_MINUS, t0), t2))
+    ).asJava
+
+    builder.addCondition(builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, condition)))
+
+    val (convertedExpressions, unconvertedRexNodes) = extractConditions(builder.getProgram)
+
+    val expected: Array[Expression] = Array(
+      ExpressionParser.parseExpression("amount < id"),
+      ExpressionParser.parseExpression("amount <= id"),
+      ExpressionParser.parseExpression("amount <> id"),
+      ExpressionParser.parseExpression("amount == id"),
+      ExpressionParser.parseExpression("amount >= id"),
+      ExpressionParser.parseExpression("amount > id"),
+      ExpressionParser.parseExpression("amount + id == 100"),
+      ExpressionParser.parseExpression("amount - id == 100"),
+      ExpressionParser.parseExpression("amount * id == 100"),
+      ExpressionParser.parseExpression("amount / id == 100"),
+      ExpressionParser.parseExpression("-amount == 100")
+    )
+    assertExpressionArrayEquals(expected, convertedExpressions)
+    assertEquals(0, unconvertedRexNodes.length)
+  }
+
+  @Test
+  def testExtractPostfixConditions(): Unit = {
+    testExtractSinglePostfixCondition(4, SqlStdOperatorTable.IS_NULL, "('flag).isNull")
+    // IS_NOT_NULL will be eliminated since flag is not nullable
+    // testExtractSinglePostfixCondition(4, SqlStdOperatorTable.IS_NOT_NULL, "('flag).isNotNull")
+    testExtractSinglePostfixCondition(4, SqlStdOperatorTable.IS_TRUE, "('flag).isTrue")
+    testExtractSinglePostfixCondition(4, SqlStdOperatorTable.IS_NOT_TRUE, "('flag).isNotTrue")
+    testExtractSinglePostfixCondition(4, SqlStdOperatorTable.IS_FALSE, "('flag).isFalse")
+    testExtractSinglePostfixCondition(4, SqlStdOperatorTable.IS_NOT_FALSE, "('flag).isNotFalse")
+  }
+
+  @Test
+  def testExtractConditionWithFunctionCalls(): Unit = {
+    val builder = newProgramBuilder()
+
+    // amount
+    val t0 = fieldRef(2)
+    // id
+    val t1 = fieldRef(1)
+    // 100
+    val t2 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+
+    // sum(amount) > 100
+    val condition1 = builder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN,
+        rexBuilder.makeCall(SqlStdOperatorTable.SUM, t0), t2))
+
+    // min(id) == 100
+    val condition2 = builder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
+        rexBuilder.makeCall(SqlStdOperatorTable.MIN, t1), t2))
+
+    builder.addCondition(builder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.AND, condition1, condition2)))
+
+    val (convertedExpressions, unconvertedRexNodes) = extractConditions(builder.getProgram)
+
+    val expected: Array[Expression] = Array(
+      GreaterThan(Sum(UnresolvedFieldReference("amount")), Literal(100)),
+      EqualTo(Min(UnresolvedFieldReference("id")), Literal(100))
+    )
+    assertExpressionArrayEquals(expected, convertedExpressions)
+    assertEquals(0, unconvertedRexNodes.length)
+  }
+
+  @Test
+  def testExtractWithUnsupportedConditions(): Unit = {
+    val builder = newProgramBuilder()
+
+    // amount
+    val t0 = fieldRef(2)
+    // id
+    val t1 = fieldRef(1)
+    // 100
+    val t2 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+
+    // unsupported now: CAST(amount AS BIGINT)
+    val cast = builder.addExpr(rexBuilder.makeCast(allFieldTypes.get(1), t0))
+
+    // unsupported now: CAST(amount AS BIGINT) > 100
+    val condition1 = builder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, cast, t2))
+
+    // amount <= id
+    val condition2 = builder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, t0, t1))
+
+    // contains unsupported condition: (CAST(amount AS BIGINT) > 100 OR amount <= id)
+    val condition3 = builder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.OR, condition1, condition2))
+
+    // only condition2 can be translated
+    builder.addCondition(
+      rexBuilder.makeCall(SqlStdOperatorTable.AND, condition1, condition2, condition3))
+
+    val (convertedExpressions, unconvertedRexNodes) = extractConditions(builder.getProgram)
+
+    val expected: Array[Expression] = Array(
+      ExpressionParser.parseExpression("amount <= id")
+    )
+    assertExpressionArrayEquals(expected, convertedExpressions)
+    assertEquals(2, unconvertedRexNodes.length)
+    assertEquals(">(CAST($2):BIGINT NOT NULL, 100)", unconvertedRexNodes(0).toString)
+    assertEquals("OR(>(CAST($2):BIGINT NOT NULL, 100), <=($2, $1))",
+      unconvertedRexNodes(1).toString)
+  }
+
+  @Test
+  def testExtractRefNestedInputFields(): Unit = {
+    val rexProgram = buildRexProgramWithNesting()
+
+    val usedFields = RexProgramExtractor.extractRefInputFields(rexProgram)
+    val usedNestedFields = RexProgramExtractor.extractRefNestedInputFields(rexProgram, usedFields)
+
+    val expected = Array(Array("amount"), Array("*"))
+    assertThat(usedNestedFields, is(expected))
+  }
+
+  @Test
+  def testExtractRefNestedInputFieldsWithNoNesting(): Unit = {
+    val rexProgram = buildSimpleRexProgram()
+
+    val usedFields = RexProgramExtractor.extractRefInputFields(rexProgram)
+    val usedNestedFields = RexProgramExtractor.extractRefNestedInputFields(rexProgram, usedFields)
+
+    val expected = Array(Array("*"), Array("*"), Array("*"))
+    assertThat(usedNestedFields, is(expected))
+  }
+
+  @Test
+  def testExtractDeepRefNestedInputFields(): Unit = {
+    val rexProgram = buildRexProgramWithDeepNesting()
+
+    val usedFields = RexProgramExtractor.extractRefInputFields(rexProgram)
+    val usedNestedFields = RexProgramExtractor.extractRefNestedInputFields(rexProgram, usedFields)
+
+    val expected = Array(
+      Array("amount"),
+      Array("*"),
+      Array("with.deeper.entry", "with.deep.entry"))
+
+    assertThat(usedFields, is(Array(1, 0, 2)))
+    assertThat(usedNestedFields, is(expected))
+  }
+
+  /**
+   * Builds a program over three nested input rows (persons, payments, field) that projects
+   * fields at several nesting depths.
+   */
+  private def buildRexProgramWithDeepNesting(): RexProgram = {
+
+    // person input
+    val passportRow = inputOf(typeFactory)
+      .field("id", VARCHAR)
+      .field("status", VARCHAR)
+      .build
+
+    val personRow = inputOf(typeFactory)
+      .field("name", VARCHAR)
+      .field("age", INTEGER)
+      .nestedField("passport", passportRow)
+      .build
+
+    // payment input
+    val paymentRow = inputOf(typeFactory)
+      .field("id", BIGINT)
+      .field("amount", INTEGER)
+      .build
+
+    // deep field input
+    val deepRowType = inputOf(typeFactory)
+      .field("entry", VARCHAR)
+      .build
+
+    val entryRowType = inputOf(typeFactory)
+      .nestedField("inside", deepRowType)
+      .build
+
+    val deeperRowType = inputOf(typeFactory)
+      .nestedField("entry", entryRowType)
+      .build
+
+    val withRowType = inputOf(typeFactory)
+      .nestedField("deep", deepRowType)
+      .nestedField("deeper", deeperRowType)
+      .build
+
+    val fieldRowType = inputOf(typeFactory)
+      .nestedField("with", withRowType)
+      .build
+
+    // main input
+    val inputRowType = inputOf(typeFactory)
+      .nestedField("persons", personRow)
+      .nestedField("payments", paymentRow)
+      .nestedField("field", fieldRowType)
+      .build
+
+    // inputRowType
+    //
+    // [ persons: [ name: VARCHAR, age: INT, passport: [id: VARCHAR, status: VARCHAR ] ],
+    //   payments: [ id: BIGINT, amount: INT ],
+    //   field: [ with: [ deep: [ entry: VARCHAR ],
+    //                    deeper: [ entry: [ inside: [entry: VARCHAR ] ] ]
+    //          ] ]
+    // ]
+
+    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
+
+    val t0 = rexBuilder.makeInputRef(personRow, 0)
+    val t1 = rexBuilder.makeInputRef(paymentRow, 1)
+    val t2 = rexBuilder.makeInputRef(fieldRowType, 2)
+    val t3 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(10L))
+
+    // person
+    val personPassport = rexBuilder.makeFieldAccess(t0, "passport", false)
+    val personPassportStatus = rexBuilder.makeFieldAccess(personPassport, "status", false)
+
+    // payment
+    val paymentAmount = rexBuilder.makeFieldAccess(t1, "amount", false)
+    val multiplyAmount = builder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, paymentAmount, t3))
+
+    // field
+    val fieldWith = rexBuilder.makeFieldAccess(t2, "with", false)
+    val fieldWithDeep = rexBuilder.makeFieldAccess(fieldWith, "deep", false)
+    val fieldWithDeeper = rexBuilder.makeFieldAccess(fieldWith, "deeper", false)
+    val fieldWithDeepEntry = rexBuilder.makeFieldAccess(fieldWithDeep, "entry", false)
+    val fieldWithDeeperEntry = rexBuilder.makeFieldAccess(fieldWithDeeper, "entry", false)
+    val fieldWithDeeperEntryInside = rexBuilder
+      .makeFieldAccess(fieldWithDeeperEntry, "inside", false)
+    val fieldWithDeeperEntryInsideEntry = rexBuilder
+      .makeFieldAccess(fieldWithDeeperEntryInside, "entry", false)
+
+    builder.addProject(multiplyAmount, "amount")
+    builder.addProject(personPassportStatus, "status")
+    builder.addProject(fieldWithDeepEntry, "entry")
+    builder.addProject(fieldWithDeeperEntryInsideEntry, "entry")
+    builder.addProject(fieldWithDeeperEntry, "entry2")
+    builder.addProject(t0, "person")
+
+    // Program
+    // (
+    //   payments.amount * 10,
+    //   persons.passport.status,
+    //   field.with.deep.entry,
+    //   field.with.deeper.entry.inside.entry,
+    //   field.with.deeper.entry,
+    //   persons
+    // )
+
+    builder.getProgram
+  }
+
+  /**
+   * Builds a program over two nested input rows (persons, payments) that projects
+   * payments.amount, the whole persons row and the literal 100.
+   */
+  private def buildRexProgramWithNesting(): RexProgram = {
+
+    val personRow = inputOf(typeFactory)
+      .field("name", INTEGER)
+      .field("age", VARCHAR)
+      .build
+
+    val paymentRow = inputOf(typeFactory)
+      .field("id", BIGINT)
+      .field("amount", INTEGER)
+      .build
+
+    val types = List(personRow, paymentRow).asJava
+    val names = List("persons", "payments").asJava
+    val inputRowType = typeFactory.createStructType(types, names)
+
+    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
+
+    val t0 = rexBuilder.makeInputRef(types.get(0), 0)
+    val t1 = rexBuilder.makeInputRef(types.get(1), 1)
+    val t2 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+
+    val paymentAmount = rexBuilder.makeFieldAccess(t1, "amount", false)
+
+    builder.addProject(paymentAmount, "amount")
+    builder.addProject(t0, "persons")
+    builder.addProject(t2, "number")
+    builder.getProgram
+  }
+
+  /**
+   * Builds a program whose only condition is `op(field)` for the input field at `fieldIndex`,
+   * and asserts that it converts to exactly one expression whose string form is `expr`.
+   */
+  private def testExtractSinglePostfixCondition(
+      fieldIndex: Int,
+      op: SqlPostfixOperator,
+      expr: String): Unit = {
+    val builder = newProgramBuilder()
+
+    // the field the postfix operator is applied to (e.g. flag)
+    val t0 = fieldRef(fieldIndex)
+    builder.addCondition(builder.addExpr(rexBuilder.makeCall(op, t0)))
+
+    // request the program without normalization
+    val (convertedExpressions, unconvertedRexNodes) = extractConditions(builder.getProgram(false))
+
+    assertEquals(1, convertedExpressions.length)
+    assertEquals(expr, convertedExpressions.head.toString)
+    assertEquals(0, unconvertedRexNodes.length)
+  }
+
+  /** Creates a [[RexProgramBuilder]] over the `allFieldTypes` / `allFieldNames` test schema. */
+  private def newProgramBuilder(): RexProgramBuilder = {
+    val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
+    new RexProgramBuilder(inputRowType, rexBuilder)
+  }
+
+  /** Creates an input reference to the field at `index` of the `allFieldTypes` schema. */
+  private def fieldRef(index: Int) = rexBuilder.makeInputRef(allFieldTypes.get(index), index)
+
+  /**
+   * Runs [[RexProgramExtractor.extractConjunctiveConditions]] on `program`, using a fresh
+   * [[RexBuilder]] and the built-in function catalog.
+   *
+   * @return the converted expressions and the nodes that could not be converted
+   */
+  private def extractConditions(program: RexProgram) =
+    RexProgramExtractor.extractConjunctiveConditions(
+      program,
+      new RexBuilder(typeFactory),
+      functionCatalog)
+
+  /**
+   * Asserts that both arrays hold the same expressions regardless of order; expressions are
+   * compared by their string form.
+   */
+  private def assertExpressionArrayEquals(
+      expected: Array[Expression],
+      actual: Array[Expression]): Unit = {
+    val sortedExpected = expected.sortBy(e => e.toString)
+    val sortedActual = actual.sortBy(e => e.toString)
+
+    assertEquals(sortedExpected.length, sortedActual.length)
+    sortedExpected.zip(sortedActual).foreach {
+      case (l, r) => assertEquals(l.toString, r.toString)
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/f1fafc0e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramRewriterTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramRewriterTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramRewriterTest.scala
new file mode 100644
index 0000000..dc91a82
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramRewriterTest.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan
+
+import org.apache.flink.table.plan.util.RexProgramRewriter
+import org.junit.Assert.assertTrue
+import org.junit.Test
+
+import scala.collection.JavaConverters._
+
+/**
+ * Tests for [[RexProgramRewriter]].
+ */
+class RexProgramRewriterTest extends RexProgramTestBase {
+
+  // assertEquals reports both sides on failure, unlike assertTrue(a == b)
+  import org.junit.Assert.assertEquals
+
+  /**
+   * Rewrites the simple program onto an input row holding only the fields it uses and checks
+   * that the expression list is re-indexed against that narrower input.
+   */
+  @Test
+  def testRewriteRexProgram(): Unit = {
+    val rexProgram = buildSimpleRexProgram()
+    assertEquals(
+      Seq(
+        "$0",
+        "$1",
+        "$2",
+        "$3",
+        "$4",
+        "*($t2, $t3)",
+        "100",
+        "<($t5, $t6)",
+        "6",
+        ">($t1, $t8)",
+        "AND($t7, $t9)"),
+      extractExprStrList(rexProgram))
+
+    // use the amount, price, id fields (in that order) to create a new RexProgram
+    val usedFields = Array(2, 3, 1)
+    val types = usedFields.map(allFieldTypes.get).toList.asJava
+    val names = usedFields.map(allFieldNames.get).toList.asJava
+    val inputRowType = typeFactory.createStructType(types, names)
+    val newRexProgram = RexProgramRewriter.rewriteWithFieldProjection(
+      rexProgram, inputRowType, rexBuilder, usedFields)
+    assertEquals(
+      Seq(
+        "$0",
+        "$1",
+        "$2",
+        "*($t0, $t1)",
+        "100",
+        "<($t3, $t4)",
+        "6",
+        ">($t2, $t6)",
+        "AND($t5, $t7)"),
+      extractExprStrList(newRexProgram))
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/f1fafc0e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramTestBase.scala
new file mode 100644
index 0000000..b711604
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramTestBase.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan
+
+import java.math.BigDecimal
+
+import org.apache.calcite.adapter.java.JavaTypeFactory
+import org.apache.calcite.jdbc.JavaTypeFactoryImpl
+import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeSystem}
+import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder}
+import org.apache.calcite.sql.`type`.SqlTypeName._
+import org.apache.calcite.sql.fun.SqlStdOperatorTable
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+/**
+ * Shared fixture for RexProgram tests: a flat five-column input schema plus helpers to build
+ * and inspect a simple [[RexProgram]] over it.
+ */
+abstract class RexProgramTestBase {
+
+  val typeFactory: JavaTypeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT)
+
+  // column names of the test schema; index i matches allFieldTypes(i)
+  val allFieldNames: java.util.List[String] = List("name", "id", "amount", "price", "flag").asJava
+
+  // column types of the test schema
+  val allFieldTypes: java.util.List[RelDataType] = {
+    val sqlTypes = List(VARCHAR, BIGINT, INTEGER, DOUBLE, BOOLEAN)
+    sqlTypes.map(sqlType => typeFactory.createSqlType(sqlType)).asJava
+  }
+
+  // NOTE(review): declared as a var, so subclasses can reassign it; treat as read-only
+  var rexBuilder: RexBuilder = new RexBuilder(typeFactory)
+
+  /**
+   * Renders every entry of the program's expression list with `toString`, in list order.
+   *
+   * @param rexProgram program whose expression list is rendered
+   * @return the string form of each expression of `rexProgram`
+   */
+  protected def extractExprStrList(rexProgram: RexProgram): mutable.Buffer[String] = {
+    val exprs = rexProgram.getExprList.asScala
+    exprs.map(expr => expr.toString)
+  }
+
+  /**
+   * Builds the program
+   * `SELECT amount, amount * price AS total WHERE amount * price < 100 AND id > 6`
+   * over the test schema.
+   */
+  protected def buildSimpleRexProgram(): RexProgram = {
+    val inputType = typeFactory.createStructType(allFieldTypes, allFieldNames)
+    val programBuilder = new RexProgramBuilder(inputType, rexBuilder)
+
+    def inputRef(index: Int) = rexBuilder.makeInputRef(allFieldTypes.get(index), index)
+    val amount = inputRef(2)
+    val id = inputRef(1)
+    val price = inputRef(3)
+
+    // amount * price is shared by the projection and the condition
+    val total = programBuilder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, amount, price))
+
+    // projection: amount, amount * price AS total
+    programBuilder.addProject(amount, "amount")
+    programBuilder.addProject(total, "total")
+
+    // condition: amount * price < 100 AND id > 6
+    val totalBelowLimit = programBuilder.addExpr(
+      rexBuilder.makeCall(
+        SqlStdOperatorTable.LESS_THAN,
+        total,
+        rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))))
+    val idAboveMinimum = programBuilder.addExpr(
+      rexBuilder.makeCall(
+        SqlStdOperatorTable.GREATER_THAN,
+        id,
+        rexBuilder.makeExactLiteral(BigDecimal.valueOf(6L))))
+    val bothHold = rexBuilder.makeCall(
+      SqlStdOperatorTable.AND, List(totalBelowLimit, idAboveMinimum).asJava)
+    programBuilder.addCondition(programBuilder.addExpr(bothHold))
+
+    programBuilder.getProgram
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/f1fafc0e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala
new file mode 100644
index 0000000..7885160
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala
@@ -0,0 +1,398 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan
+
+import java.sql.Timestamp
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.expressions.{TimeIntervalUnit, WindowReference}
+import org.apache.flink.table.functions.TableFunction
+import org.apache.flink.table.plan.TimeIndicatorConversionTest.TableFunc
+import org.apache.flink.table.plan.logical.TumblingGroupWindow
+import org.apache.flink.table.utils.TableTestBase
+import org.apache.flink.table.utils.TableTestUtil._
+import org.junit.Test
+
+/**
+ * Tests for [[org.apache.flink.table.calcite.RelTimeIndicatorConverter]].
+ *
+ * Each test registers a streaming table whose schema declares time indicator attributes
+ * ('rowtime.rowtime and/or 'proctime.proctime), translates a Table API or SQL query and
+ * verifies the resulting DataStream plan against an expected plan. The expected plans show
+ * where time indicators are wrapped in TIME_MATERIALIZATION calls and where they are kept
+ * as plain time attributes.
+ */
+class TimeIndicatorConversionTest extends TableTestBase {
+
+  /** A rowtime attribute used as input of a scalar function (FLOOR) is materialized. */
+  @Test
+  def testSimpleMaterialization(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Long, Long, Int)]('rowtime.rowtime, 'long, 'int, 'proctime.proctime)
+
+    val result = t
+      .select('rowtime.floor(TimeIntervalUnit.DAY) as 'rowtime, 'long)
+      .filter('long > 0)
+      .select('rowtime)
+
+    // both projections and the filter are expected to end up in a single DataStreamCalc
+    val expected = unaryNode(
+      "DataStreamCalc",
+      streamTableNode(0),
+      term("select", "FLOOR(TIME_MATERIALIZATION(rowtime)", "FLAG(DAY)) AS rowtime"),
+      term("where", ">(long, 0)")
+    )
+
+    util.verifyTable(result, expected)
+  }
+
+  /** Selecting all fields materializes both the rowtime and the proctime attribute. */
+  @Test
+  def testSelectAll(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Long, Long, Int)]('rowtime.rowtime, 'long, 'int, 'proctime.proctime)
+
+    val result = t.select('*)
+
+    val expected = unaryNode(
+      "DataStreamCalc",
+      streamTableNode(0),
+      term("select", "TIME_MATERIALIZATION(rowtime) AS rowtime", "long", "int",
+        "TIME_MATERIALIZATION(proctime) AS proctime")
+    )
+
+    util.verifyTable(result, expected)
+  }
+
+  /** A rowtime attribute is materialized both in a filter predicate and in the projection. */
+  @Test
+  def testFilteringOnRowtime(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Long, Long, Int)]('rowtime.rowtime, 'long, 'int)
+
+    val result = t
+      .filter('rowtime > "1990-12-02 12:11:11".toTimestamp)
+      .select('rowtime)
+
+    val expected = unaryNode(
+      "DataStreamCalc",
+      streamTableNode(0),
+      term("select", "TIME_MATERIALIZATION(rowtime) AS rowtime"),
+      term("where", ">(TIME_MATERIALIZATION(rowtime), 1990-12-02 12:11:11)")
+    )
+
+    util.verifyTable(result, expected)
+  }
+
+  /**
+   * Grouping by rowtime in a non-windowed aggregation materializes the attribute in a
+   * DataStreamCalc below the aggregate.
+   */
+  @Test
+  def testGroupingOnRowtime(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Long, Long, Int)]('rowtime.rowtime, 'long, 'int, 'proctime.proctime)
+
+    val result = t
+      .groupBy('rowtime)
+      .select('long.count)
+
+    val expected = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamGroupAggregate",
+        unaryNode(
+          "DataStreamCalc",
+          streamTableNode(0),
+          term("select", "long", "TIME_MATERIALIZATION(rowtime) AS rowtime")
+        ),
+        term("groupBy", "rowtime"),
+        term("select", "rowtime", "COUNT(long) AS TMP_0")
+      ),
+      term("select", "TMP_0")
+    )
+
+    util.verifyTable(result, expected)
+  }
+
+  /** Aggregating over rowtime (MIN) materializes the attribute before the aggregate. */
+  @Test
+  def testAggregationOnRowtime(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Long, Long, Int)]('rowtime.rowtime, 'long, 'int)
+
+    val result = t
+      .groupBy('long)
+      .select('rowtime.min)
+
+    val expected = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamGroupAggregate",
+        unaryNode(
+          "DataStreamCalc",
+          streamTableNode(0),
+          term("select", "TIME_MATERIALIZATION(rowtime) AS rowtime", "long")
+        ),
+        term("groupBy", "long"),
+        term("select", "long", "MIN(rowtime) AS TMP_0")
+      ),
+      term("select", "TMP_0")
+    )
+
+    util.verifyTable(result, expected)
+  }
+
+  /**
+   * Time indicators passed as table function arguments are materialized in the invocation,
+   * and again when they are projected after the correlate.
+   */
+  @Test
+  def testTableFunction(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Long, Long, Int)]('rowtime.rowtime, 'long, 'int, 'proctime.proctime)
+    val func = new TableFunc
+
+    val result = t.join(func('rowtime, 'proctime, "") as 's).select('rowtime, 'proctime, 's)
+
+    val expected = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamCorrelate",
+        streamTableNode(0),
+        term("invocation",
+          s"${func.functionIdentifier}(TIME_MATERIALIZATION($$0), TIME_MATERIALIZATION($$3), '')"),
+        term("function", func),
+        term("rowType", "RecordType(TIMESTAMP(3) rowtime, BIGINT long, INTEGER int, " +
+          "TIMESTAMP(3) proctime, VARCHAR(2147483647) s)"),
+        term("joinType", "INNER")
+      ),
+      term("select",
+        "TIME_MATERIALIZATION(rowtime) AS rowtime",
+        "TIME_MATERIALIZATION(proctime) AS proctime",
+        "s")
+    )
+
+    util.verifyTable(result, expected)
+  }
+
+  /**
+   * A tumbling window on rowtime consumes the time attribute directly: no
+   * TIME_MATERIALIZATION appears in the plan, and the window end is projected as 'rowtime.
+   */
+  @Test
+  def testWindow(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Long, Long, Int)]('rowtime.rowtime, 'long, 'int)
+
+    val result = t
+      .window(Tumble over 100.millis on 'rowtime as 'w)
+      .groupBy('w, 'long)
+      .select('w.end as 'rowtime, 'long, 'int.sum)
+
+    val expected = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamGroupWindowAggregate",
+        streamTableNode(0),
+        term("groupBy", "long"),
+        term(
+          "window",
+          TumblingGroupWindow(
+            'w,
+            'rowtime,
+            100.millis)),
+        term("select", "long", "SUM(int) AS TMP_1", "end('w) AS TMP_0")
+      ),
+      term("select", "TMP_0 AS rowtime", "long", "TMP_1")
+    )
+
+    util.verifyTable(result, expected)
+  }
+
+  /**
+   * Time indicators pass through a union unmaterialized; only the DataStreamCalc on top of the
+   * union materializes rowtime.
+   */
+  @Test
+  def testUnion(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Long, Long, Int)]("MyTable", 'rowtime.rowtime, 'long, 'int)
+
+    val result = t.unionAll(t).select('rowtime)
+
+    val expected = unaryNode(
+      "DataStreamCalc",
+      binaryNode(
+        "DataStreamUnion",
+        streamTableNode(0),
+        streamTableNode(0),
+        term("union all", "rowtime", "long", "int")
+      ),
+      term("select", "TIME_MATERIALIZATION(rowtime) AS rowtime")
+    )
+
+    util.verifyTable(result, expected)
+  }
+
+  /**
+   * The rowtime property of a group window ('w.rowtime) is used as the time attribute of a
+   * second, cascaded tumbling window.
+   */
+  @Test
+  def testMultiWindow(): Unit = {
+    val util = streamTestUtil()
+    val t = util.addTable[(Long, Long, Int)]('rowtime.rowtime, 'long, 'int)
+
+    val result = t
+      .window(Tumble over 100.millis on 'rowtime as 'w)
+      .groupBy('w, 'long)
+      .select('w.rowtime as 'newrowtime, 'long, 'int.sum as 'int)
+      .window(Tumble over 1.second on 'newrowtime as 'w2)
+      .groupBy('w2, 'long)
+      .select('w2.end, 'long, 'int.sum)
+
+    // the 1.second size of the outer window is expected as 1000.millis
+    val expected = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamGroupWindowAggregate",
+        unaryNode(
+          "DataStreamCalc",
+          unaryNode(
+            "DataStreamGroupWindowAggregate",
+            streamTableNode(0),
+            term("groupBy", "long"),
+            term(
+              "window",
+              TumblingGroupWindow(
+                'w,
+                'rowtime,
+                100.millis)),
+            term("select", "long", "SUM(int) AS TMP_1", "rowtime('w) AS TMP_0")
+          ),
+          term("select", "TMP_0 AS newrowtime", "long", "TMP_1 AS int")
+        ),
+        term("groupBy", "long"),
+        term(
+          "window",
+          TumblingGroupWindow(
+            'w2,
+            'newrowtime,
+            1000.millis)),
+        term("select", "long", "SUM(int) AS TMP_3", "end('w2) AS TMP_2")
+      ),
+      term("select", "TMP_2", "long", "TMP_3")
+    )
+
+    util.verifyTable(result, expected)
+  }
+
+  /** SQL: grouping by proctime in a non-windowed aggregation materializes the attribute. */
+  @Test
+  def testGroupingOnProctime(): Unit = {
+    val util = streamTestUtil()
+    util.addTable[(Long, Int)]("MyTable", 'long, 'int, 'proctime.proctime)
+
+    val result = util.tableEnv.sql("SELECT COUNT(long) FROM MyTable GROUP BY proctime")
+
+    val expected = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamGroupAggregate",
+        unaryNode(
+          "DataStreamCalc",
+          streamTableNode(0),
+          term("select", "TIME_MATERIALIZATION(proctime) AS proctime", "long")
+        ),
+        term("groupBy", "proctime"),
+        term("select", "proctime", "COUNT(long) AS EXPR$0")
+      ),
+      term("select", "EXPR$0")
+    )
+
+    util.verifyTable(result, expected)
+  }
+
+  /** SQL: aggregating over proctime (MIN) materializes the attribute before the aggregate. */
+  @Test
+  def testAggregationOnProctime(): Unit = {
+    val util = streamTestUtil()
+    util.addTable[(Long, Int)]("MyTable", 'long, 'int, 'proctime.proctime)
+
+    val result = util.tableEnv.sql("SELECT MIN(proctime) FROM MyTable GROUP BY long")
+
+    val expected = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamGroupAggregate",
+        unaryNode(
+          "DataStreamCalc",
+          streamTableNode(0),
+          term("select", "long", "TIME_MATERIALIZATION(proctime) AS proctime")
+        ),
+        term("groupBy", "long"),
+        term("select", "long", "MIN(proctime) AS EXPR$0")
+      ),
+      term("select", "EXPR$0")
+    )
+
+    util.verifyTable(result, expected)
+  }
+
+  /** SQL: a TUMBLE window on rowtime whose TUMBLE_END is projected under the name rowtime. */
+  @Test
+  def testWindowSql(): Unit = {
+    val util = streamTestUtil()
+    util.addTable[(Long, Long, Int)]("MyTable", 'rowtime.rowtime, 'long, 'int)
+
+    val result = util.tableEnv.sql(
+      "SELECT TUMBLE_END(rowtime, INTERVAL '0.1' SECOND) AS `rowtime`, `long`, " +
+        "SUM(`int`) FROM MyTable " +
+        "GROUP BY `long`, TUMBLE(rowtime, INTERVAL '0.1' SECOND)")
+
+    // the SQL window is expected under the generated alias w$
+    val expected = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamGroupWindowAggregate",
+        streamTableNode(0),
+        term("groupBy", "long"),
+        term(
+          "window",
+          TumblingGroupWindow(
+            WindowReference("w$"),
+            'rowtime,
+            100.millis)),
+        term("select", "long", "SUM(int) AS EXPR$2", "start('w$) AS w$start", "end('w$) AS w$end")
+      ),
+      term("select", "w$end", "long", "EXPR$2")
+    )
+
+    util.verifyTable(result, expected)
+  }
+
+  /** SQL: MIN(rowtime) computed inside a tumbling window on rowtime. */
+  @Test
+  def testWindowWithAggregationOnRowtime(): Unit = {
+    val util = streamTestUtil()
+    util.addTable[(Long, Long, Int)]("MyTable", 'rowtime.rowtime, 'long, 'int)
+
+    val result = util.tableEnv.sql("SELECT MIN(rowtime), long FROM MyTable " +
+      "GROUP BY long, TUMBLE(rowtime, INTERVAL '0.1' SECOND)")
+
+    // the Calc below the window keeps the raw rowtime for the window itself and adds a
+    // materialized copy ($f2) that is fed to MIN; the SQL window alias is w$, referenced the
+    // same way as in testWindowSql
+    val expected = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamGroupWindowAggregate",
+        unaryNode(
+          "DataStreamCalc",
+          streamTableNode(0),
+          term("select", "long", "rowtime", "TIME_MATERIALIZATION(rowtime) AS $f2")
+        ),
+        term("groupBy", "long"),
+        term(
+          "window",
+          TumblingGroupWindow(
+            WindowReference("w$"),
+            'rowtime,
+            100.millis)),
+        term("select", "long", "MIN($f2) AS EXPR$0")
+      ),
+      term("select", "EXPR$0", "long")
+    )
+
+    util.verifyTable(result, expected)
+  }
+
+}
+
+object TimeIndicatorConversionTest {
+
+  /**
+   * Table function that takes two time values and a string. It emits one row: the first
+   * argument, whether the second argument is after `t` (Timestamp(0L)), and the string,
+   * concatenated.
+   */
+  class TableFunc extends TableFunction[String] {
+    val t = new Timestamp(0L)
+
+    def eval(time1: Long, time2: Timestamp, string: String): Unit = {
+      val isAfterT = time2.after(t)
+      collect(s"$time1$isAfterT$string")
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/f1fafc0e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/NormalizationRulesTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/NormalizationRulesTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/NormalizationRulesTest.scala
deleted file mode 100644
index c307de5..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/NormalizationRulesTest.scala
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.plan.rules
-
-import org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule
-import org.apache.calcite.tools.RuleSets
-import org.apache.flink.api.scala._
-import org.apache.flink.table.api.scala._
-import org.apache.flink.table.calcite.{CalciteConfig, CalciteConfigBuilder}
-import org.apache.flink.table.utils.TableTestBase
-import org.apache.flink.table.utils.TableTestUtil._
-import org.junit.Test
-
-/**
- * Checks that a normalization rule set configured through a [[CalciteConfig]] is applied to
- * batch and stream SQL queries. The logical and physical optimization rule sets are replaced
- * by empty sets, so the expected plans consist of Logical* nodes.
- */
-class NormalizationRulesTest extends TableTestBase {
-
- @Test
- def testApplyNormalizationRuleForBatchSQL(): Unit = {
- val util = batchTestUtil()
-
- // rewrite distinct aggregate
- val cc: CalciteConfig = new CalciteConfigBuilder()
- .replaceNormRuleSet(RuleSets.ofList(AggregateExpandDistinctAggregatesRule.JOIN))
- .replaceLogicalOptRuleSet(RuleSets.ofList())
- .replacePhysicalOptRuleSet(RuleSets.ofList())
- .build()
- util.tableEnv.getConfig.setCalciteConfig(cc)
-
- util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
-
- // NOTE(review): the concatenation yields "...COUNT(DISTINCT a)FROM MyTable ..." with no
- // space before FROM; the test relies on the SQL parser accepting this
- val sqlQuery = "SELECT " +
- "COUNT(DISTINCT a)" +
- "FROM MyTable group by b"
-
- // expect double aggregate
- val expected = unaryNode("LogicalProject",
- unaryNode("LogicalAggregate",
- unaryNode("LogicalAggregate",
- unaryNode("LogicalProject",
- values("LogicalTableScan", term("table", "[_DataSetTable_0]")),
- term("b", "$1"), term("a", "$0")),
- term("group", "{0, 1}")),
- term("group", "{0}"), term("EXPR$0", "COUNT($1)")
- ),
- term("EXPR$0", "$1")
- )
-
- util.verifySql(sqlQuery, expected)
- }
-
- // same as the batch test above, but on a stream table (scan of _DataStreamTable_0)
- @Test
- def testApplyNormalizationRuleForStreamSQL(): Unit = {
- val util = streamTestUtil()
-
- // rewrite distinct aggregate
- val cc: CalciteConfig = new CalciteConfigBuilder()
- .replaceNormRuleSet(RuleSets.ofList(AggregateExpandDistinctAggregatesRule.JOIN))
- .replaceLogicalOptRuleSet(RuleSets.ofList())
- .replacePhysicalOptRuleSet(RuleSets.ofList())
- .build()
- util.tableEnv.getConfig.setCalciteConfig(cc)
-
- util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
-
- // NOTE(review): as above, no space between ")" and "FROM" in the concatenated query
- val sqlQuery = "SELECT " +
- "COUNT(DISTINCT a)" +
- "FROM MyTable group by b"
-
- // expect double aggregate
- val expected = unaryNode(
- "LogicalProject",
- unaryNode("LogicalAggregate",
- unaryNode("LogicalAggregate",
- unaryNode("LogicalProject",
- values("LogicalTableScan", term("table", "[_DataStreamTable_0]")),
- term("b", "$1"), term("a", "$0")),
- term("group", "{0, 1}")),
- term("group", "{0}"), term("EXPR$0", "COUNT($1)")
- ),
- term("EXPR$0", "$1")
- )
-
- util.verifySql(sqlQuery, expected)
- }
-
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/f1fafc0e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/RetractionRulesTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/RetractionRulesTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/RetractionRulesTest.scala
deleted file mode 100644
index 75d2f22..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/RetractionRulesTest.scala
+++ /dev/null
@@ -1,321 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.plan.rules
-
-import org.apache.calcite.rel.RelNode
-import org.apache.flink.table.api.Table
-import org.apache.flink.table.plan.nodes.datastream._
-import org.apache.flink.table.utils.{StreamTableTestUtil, TableTestBase}
-import org.apache.flink.table.utils.TableTestUtil._
-import org.junit.Assert._
-import org.junit.{Ignore, Test}
-import org.apache.flink.api.scala._
-import org.apache.flink.table.api.scala._
-
-
-/**
- * Verifies the retraction-related traits attached to optimized DataStream plans.
- *
- * Actual plans are rendered by TraitUtil.toString, which prints every node as
- * `NodeName(updateAsRetraction, accMode)`: the first value is the node's
- * UpdateAsRetractionTraitDef trait and the second its AccModeTraitDef trait. The expected
- * strings pass these two values (e.g. "false, Acc") to unaryNode/binaryNode.
- */
-class RetractionRulesTest extends TableTestBase {
-
- // creates the test util whose verify methods compare trait renderings
- def streamTestForRetractionUtil(): StreamTableTestForRetractionUtil = {
- new StreamTableTestForRetractionUtil()
- }
-
- // a projection of all fields in their original order: only the scan is expected to remain
- @Test
- def testSelect(): Unit = {
- val util = streamTestForRetractionUtil()
- val table = util.addTable[(String, Int)]('word, 'number)
-
- val resultTable = table.select('word, 'number)
-
- val expected = s"DataStreamScan(false, Acc)"
-
- util.verifyTableTrait(resultTable, expected)
- }
-
- // one level unbounded groupBy
- // only the scan feeding the aggregate is expected with updateAsRetraction = true; the
- // aggregate and the Calc on top keep the default (false, Acc)
- @Test
- def testGroupBy(): Unit = {
- val util = streamTestForRetractionUtil()
- val table = util.addTable[(String, Int)]('word, 'number)
- val defaultStatus = "false, Acc"
-
- val resultTable = table
- .groupBy('word)
- .select('number.count)
-
- val expected =
- unaryNode(
- "DataStreamCalc",
- unaryNode(
- "DataStreamGroupAggregate",
- "DataStreamScan(true, Acc)",
- s"$defaultStatus"
- ),
- s"$defaultStatus"
- )
-
- util.verifyTableTrait(resultTable, expected)
- }
-
- // two level unbounded groupBy
- // the inner aggregate and the Calc above it are expected as (true, AccRetract), the outer
- // aggregate and the final Calc as (false, Acc)
- @Test
- def testTwoGroupBy(): Unit = {
- val util = streamTestForRetractionUtil()
- val table = util.addTable[(String, Int)]('word, 'number)
- val defaultStatus = "false, Acc"
-
- val resultTable = table
- .groupBy('word)
- .select('word, 'number.count as 'count)
- .groupBy('count)
- .select('count, 'count.count as 'frequency)
-
- val expected =
- unaryNode(
- "DataStreamCalc",
- unaryNode(
- "DataStreamGroupAggregate",
- unaryNode(
- "DataStreamCalc",
- unaryNode(
- "DataStreamGroupAggregate",
- "DataStreamScan(true, Acc)",
- "true, AccRetract"
- ),
- "true, AccRetract"
- ),
- s"$defaultStatus"
- ),
- s"$defaultStatus"
- )
-
- util.verifyTableTrait(resultTable, expected)
- }
-
- // group window
- // only the scan below the window aggregate is expected with updateAsRetraction = true
- @Test
- def testGroupWindow(): Unit = {
- val util = streamTestForRetractionUtil()
- val table = util.addTable[(String, Int)]('word, 'number, 'rowtime.rowtime)
- val defaultStatus = "false, Acc"
-
- val resultTable = table
- .window(Tumble over 50.milli on 'rowtime as 'w)
- .groupBy('w, 'word)
- .select('word, 'number.count as 'count)
-
- val expected =
- unaryNode(
- "DataStreamCalc",
- unaryNode(
- "DataStreamGroupWindowAggregate",
- "DataStreamScan(true, Acc)",
- s"$defaultStatus"
- ),
- s"$defaultStatus"
- )
-
- util.verifyTableTrait(resultTable, expected)
- }
-
- // group window after unbounded groupBy
- @Test
- @Ignore // cannot pass rowtime through non-windowed aggregation
- def testGroupWindowAfterGroupBy(): Unit = {
- val util = streamTestForRetractionUtil()
- val table = util.addTable[(String, Int)]('word, 'number, 'rowtime.rowtime)
- val defaultStatus = "false, Acc"
-
- val resultTable = table
- .groupBy('word)
- .select('word, 'number.count as 'count)
- .window(Tumble over 50.milli on 'rowtime as 'w)
- .groupBy('w, 'count)
- .select('count, 'count.count as 'frequency)
-
- val expected =
- unaryNode(
- "DataStreamCalc",
- unaryNode(
- "DataStreamGroupWindowAggregate",
- unaryNode(
- "DataStreamCalc",
- unaryNode(
- "DataStreamGroupAggregate",
- "DataStreamScan(true, Acc)",
- "true, AccRetract"
- ),
- "true, AccRetract"
- ),
- s"$defaultStatus"
- ),
- s"$defaultStatus"
- )
-
- util.verifyTableTrait(resultTable, expected)
- }
-
- // over window
- // SQL over aggregate on proctime: only the scan is expected with updateAsRetraction = true
- @Test
- def testOverWindow(): Unit = {
- val util = streamTestForRetractionUtil()
- util.addTable[(String, Int)]("T1", 'word, 'number, 'proctime.proctime)
- val defaultStatus = "false, Acc"
-
- val sqlQuery =
- "SELECT " +
- "word, count(number) " +
- "OVER (PARTITION BY word ORDER BY proctime " +
- "ROWS BETWEEN UNBOUNDED preceding AND CURRENT ROW)" +
- "FROM T1"
-
- val expected =
- unaryNode(
- "DataStreamCalc",
- unaryNode(
- "DataStreamOverAggregate",
- "DataStreamScan(true, Acc)",
- s"$defaultStatus"
- ),
- s"$defaultStatus"
- )
-
- util.verifySqlTrait(sqlQuery, expected)
- }
-
-
- // over window after unbounded groupBy
- @Test
- @Ignore // cannot pass rowtime through non-windowed aggregation
- def testOverWindowAfterGroupBy(): Unit = {
- val util = streamTestForRetractionUtil()
- util.addTable[(String, Int)]("T1", 'word, 'number, 'proctime.proctime)
- val defaultStatus = "false, Acc"
-
- val sqlQuery =
- "SELECT " +
- "_count, count(word) " +
- "OVER (PARTITION BY _count ORDER BY proctime " +
- "ROWS BETWEEN UNBOUNDED preceding AND CURRENT ROW)" +
- "FROM " +
- "(SELECT word, count(number) as _count FROM T1 GROUP BY word) "
-
- val expected =
- unaryNode(
- "DataStreamCalc",
- unaryNode(
- "DataStreamOverAggregate",
- unaryNode(
- "DataStreamCalc",
- unaryNode(
- "DataStreamGroupAggregate",
- "DataStreamScan(true, Acc)",
- "true, AccRetract"
- ),
- "true, AccRetract"
- ),
- s"$defaultStatus"
- ),
- s"$defaultStatus"
- )
-
- util.verifySqlTrait(sqlQuery, expected)
- }
-
- // test binaryNode
- // a union of an aggregated table and a plain scan, aggregated again: the union and
- // everything below it are expected with updateAsRetraction = true, and the union and the
- // Calc above it in AccRetract mode; the outer aggregate and final Calc keep (false, Acc)
- @Test
- def testBinaryNode(): Unit = {
- val util = streamTestForRetractionUtil()
- val lTable = util.addTable[(String, Int)]('word, 'number)
- val rTable = util.addTable[(String, Long)]('word_r, 'count_r)
- val defaultStatus = "false, Acc"
-
- val resultTable = lTable
- .groupBy('word)
- .select('word, 'number.count as 'count)
- .unionAll(rTable)
- .groupBy('count)
- .select('count, 'count.count as 'frequency)
-
- val expected =
- unaryNode(
- "DataStreamCalc",
- unaryNode(
- "DataStreamGroupAggregate",
- unaryNode(
- "DataStreamCalc",
- binaryNode(
- "DataStreamUnion",
- unaryNode(
- "DataStreamCalc",
- unaryNode(
- "DataStreamGroupAggregate",
- "DataStreamScan(true, Acc)",
- "true, AccRetract"
- ),
- "true, AccRetract"
- ),
- "DataStreamScan(true, Acc)",
- "true, AccRetract"
- ),
- "true, AccRetract"
- ),
- s"$defaultStatus"
- ),
- s"$defaultStatus"
- )
-
- util.verifyTableTrait(resultTable, expected)
- }
-}
-
-/**
- * [[StreamTableTestUtil]] variant whose verification compares the trait rendering of the
- * optimized plan (produced by TraitUtil.toString) with an expected string.
- */
-class StreamTableTestForRetractionUtil extends StreamTableTestUtil {
-
- // translates the SQL query with tableEnv.sql and verifies it like verifyTableTrait
- def verifySqlTrait(query: String, expected: String): Unit = {
- verifyTableTrait(tableEnv.sql(query), expected)
- }
-
- // optimizes the table's RelNode with updatesAsRetraction = false, renders its traits and
- // compares both strings line by line after trimming surrounding whitespace of every line
- def verifyTableTrait(resultTable: Table, expected: String): Unit = {
- val relNode = resultTable.getRelNode
- val optimized = tableEnv.optimize(relNode, updatesAsRetraction = false)
- val actual = TraitUtil.toString(optimized)
- assertEquals(
- expected.split("\n").map(_.trim).mkString("\n"),
- actual.split("\n").map(_.trim).mkString("\n"))
- }
-}
-
-/** Renders a plan tree together with the retraction-related traits of each node. */
-object TraitUtil {
- // Renders `rel` as `ClassName(updateAsRetraction, accMode)` on the first line, followed by
- // the renderings of its inputs in input order (parent before children, depth-first).
- def toString(rel: RelNode): String = {
- val className = rel.getClass.getSimpleName
- var childString: String = ""
- var i = 0
- while (i < rel.getInputs.size()) {
- childString += TraitUtil.toString(rel.getInput(i))
- i += 1
- }
-
- // string forms of the UpdateAsRetraction and AccMode traits of this node
- val retractString = rel.getTraitSet.getTrait(UpdateAsRetractionTraitDef.INSTANCE).toString
- val accModetString = rel.getTraitSet.getTrait(AccModeTraitDef.INSTANCE).toString
-
- // stripLineEnd drops the final line break added by the last margin line
- s"""$className($retractString, $accModetString)
- |$childString
- |""".stripMargin.stripLineEnd
- }
-}
-
http://git-wip-us.apache.org/repos/asf/flink/blob/f1fafc0e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExtractorTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExtractorTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExtractorTest.scala
deleted file mode 100644
index 5d5eece..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExtractorTest.scala
+++ /dev/null
@@ -1,523 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.plan.util
-
-import java.math.BigDecimal
-
-import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder}
-import org.apache.calcite.sql.SqlPostfixOperator
-import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, INTEGER, VARCHAR}
-import org.apache.calcite.sql.fun.SqlStdOperatorTable
-import org.apache.flink.table.expressions._
-import org.apache.flink.table.utils.InputTypeBuilder.inputOf
-import org.apache.flink.table.validate.FunctionCatalog
-import org.hamcrest.CoreMatchers.is
-import org.junit.Assert.{assertArrayEquals, assertEquals, assertThat}
-import org.junit.Test
-
-import scala.collection.JavaConverters._
-
-class RexProgramExtractorTest extends RexProgramTestBase {
-
- private val functionCatalog: FunctionCatalog = FunctionCatalog.withBuiltIns
-
- @Test
- def testExtractRefInputFields(): Unit = {
- val usedFields = RexProgramExtractor.extractRefInputFields(buildSimpleRexProgram())
- assertArrayEquals(usedFields, Array(2, 3, 1))
- }
-
- @Test
- def testExtractSimpleCondition(): Unit = {
- val builder: RexBuilder = new RexBuilder(typeFactory)
- val program = buildSimpleRexProgram()
-
- val firstExp = ExpressionParser.parseExpression("id > 6")
- val secondExp = ExpressionParser.parseExpression("amount * price < 100")
- val expected: Array[Expression] = Array(firstExp, secondExp)
-
- val (convertedExpressions, unconvertedRexNodes) =
- RexProgramExtractor.extractConjunctiveConditions(
- program,
- builder,
- functionCatalog)
-
- assertExpressionArrayEquals(expected, convertedExpressions)
- assertEquals(0, unconvertedRexNodes.length)
- }
-
- @Test
- def testExtractSingleCondition(): Unit = {
- val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
- val builder = new RexProgramBuilder(inputRowType, rexBuilder)
-
- // amount
- val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
- // id
- val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
-
- // a = amount >= id
- val a = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, t0, t1))
- builder.addCondition(a)
-
- val program = builder.getProgram
- val relBuilder: RexBuilder = new RexBuilder(typeFactory)
- val (convertedExpressions, unconvertedRexNodes) =
- RexProgramExtractor.extractConjunctiveConditions(
- program,
- relBuilder,
- functionCatalog)
-
- val expected: Array[Expression] = Array(ExpressionParser.parseExpression("amount >= id"))
- assertExpressionArrayEquals(expected, convertedExpressions)
- assertEquals(0, unconvertedRexNodes.length)
- }
-
- // ((a AND b) OR c) AND (NOT d) => (a OR c) AND (b OR c) AND (NOT d)
- @Test
- def testExtractCnfCondition(): Unit = {
- val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
- val builder = new RexProgramBuilder(inputRowType, rexBuilder)
-
- // amount
- val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
- // id
- val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
- // price
- val t2 = rexBuilder.makeInputRef(allFieldTypes.get(3), 3)
- // 100
- val t3 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
-
- // a = amount < 100
- val a = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t0, t3))
- // b = id > 100
- val b = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t3))
- // c = price == 100
- val c = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, t2, t3))
- // d = amount <= id
- val d = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, t0, t1))
-
- // a AND b
- val and = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(a, b).asJava))
- // (a AND b) or c
- val or = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.OR, List(and, c).asJava))
- // not d
- val not = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.NOT, List(d).asJava))
-
- // (a AND b) OR c) AND (NOT d)
- builder.addCondition(builder.addExpr(
- rexBuilder.makeCall(SqlStdOperatorTable.AND, List(or, not).asJava)))
-
- val program = builder.getProgram
- val relBuilder: RexBuilder = new RexBuilder(typeFactory)
- val (convertedExpressions, unconvertedRexNodes) =
- RexProgramExtractor.extractConjunctiveConditions(
- program,
- relBuilder,
- functionCatalog)
-
- val expected: Array[Expression] = Array(
- ExpressionParser.parseExpression("amount < 100 || price == 100"),
- ExpressionParser.parseExpression("id > 100 || price == 100"),
- ExpressionParser.parseExpression("!(amount <= id)"))
- assertExpressionArrayEquals(expected, convertedExpressions)
- assertEquals(0, unconvertedRexNodes.length)
- }
-
- @Test
- def testExtractArithmeticConditions(): Unit = {
- val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
- val builder = new RexProgramBuilder(inputRowType, rexBuilder)
-
- // amount
- val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
- // id
- val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
- // 100
- val t2 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
-
- val condition = List(
- // amount < id
- builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t0, t1)),
- // amount <= id
- builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, t0, t1)),
- // amount <> id
- builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.NOT_EQUALS, t0, t1)),
- // amount == id
- builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, t0, t1)),
- // amount >= id
- builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, t0, t1)),
- // amount > id
- builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t0, t1)),
- // amount + id == 100
- builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
- rexBuilder.makeCall(SqlStdOperatorTable.PLUS, t0, t1), t2)),
- // amount - id == 100
- builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
- rexBuilder.makeCall(SqlStdOperatorTable.MINUS, t0, t1), t2)),
- // amount * id == 100
- builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
- rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, t0, t1), t2)),
- // amount / id == 100
- builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
- rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, t0, t1), t2)),
- // -amount == 100
- builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
- rexBuilder.makeCall(SqlStdOperatorTable.UNARY_MINUS, t0), t2))
- ).asJava
-
- builder.addCondition(builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, condition)))
- val program = builder.getProgram
- val relBuilder: RexBuilder = new RexBuilder(typeFactory)
- val (convertedExpressions, unconvertedRexNodes) =
- RexProgramExtractor.extractConjunctiveConditions(
- program,
- relBuilder,
- functionCatalog)
-
- val expected: Array[Expression] = Array(
- ExpressionParser.parseExpression("amount < id"),
- ExpressionParser.parseExpression("amount <= id"),
- ExpressionParser.parseExpression("amount <> id"),
- ExpressionParser.parseExpression("amount == id"),
- ExpressionParser.parseExpression("amount >= id"),
- ExpressionParser.parseExpression("amount > id"),
- ExpressionParser.parseExpression("amount + id == 100"),
- ExpressionParser.parseExpression("amount - id == 100"),
- ExpressionParser.parseExpression("amount * id == 100"),
- ExpressionParser.parseExpression("amount / id == 100"),
- ExpressionParser.parseExpression("-amount == 100")
- )
- assertExpressionArrayEquals(expected, convertedExpressions)
- assertEquals(0, unconvertedRexNodes.length)
- }
-
-  @Test
-  def testExtractPostfixConditions(): Unit = {
-    // Every case targets the same input column: index 4, the BOOLEAN field `flag`.
-    val flagIndex = 4
-
-    // IS_NOT_NULL is intentionally not covered: `flag` is not nullable, so that
-    // predicate would be eliminated and there would be nothing to extract.
-    val cases = Seq(
-      SqlStdOperatorTable.IS_NULL -> "('flag).isNull",
-      SqlStdOperatorTable.IS_TRUE -> "('flag).isTrue",
-      SqlStdOperatorTable.IS_NOT_TRUE -> "('flag).isNotTrue",
-      SqlStdOperatorTable.IS_FALSE -> "('flag).isFalse",
-      SqlStdOperatorTable.IS_NOT_FALSE -> "('flag).isNotFalse")
-
-    cases.foreach { case (postfixOp, expectedExpr) =>
-      testExtractSinglePostfixCondition(flagIndex, postfixOp, expectedExpr)
-    }
-  }
-
-  /**
-    * Conditions whose operands are aggregate calls (SUM, MIN) are converted into the
-    * corresponding Table API expressions.
-    */
-  @Test
-  def testExtractConditionWithFunctionCalls(): Unit = {
-    val programBuilder = new RexProgramBuilder(
-      typeFactory.createStructType(allFieldTypes, allFieldNames), rexBuilder)
-
-    val amountRef = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
-    val idRef = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
-    val hundred = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
-
-    // sum(amount) > 100
-    val sumOfAmount = rexBuilder.makeCall(SqlStdOperatorTable.SUM, amountRef)
-    val sumGreaterThanHundred = programBuilder.addExpr(
-      rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, sumOfAmount, hundred))
-
-    // min(id) == 100
-    val minOfId = rexBuilder.makeCall(SqlStdOperatorTable.MIN, idRef)
-    val minEqualsHundred = programBuilder.addExpr(
-      rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, minOfId, hundred))
-
-    // the filter condition is the conjunction of both predicates
-    val conjunction =
-      rexBuilder.makeCall(SqlStdOperatorTable.AND, sumGreaterThanHundred, minEqualsHundred)
-    programBuilder.addCondition(programBuilder.addExpr(conjunction))
-
-    val (convertedExpressions, unconvertedRexNodes) =
-      RexProgramExtractor.extractConjunctiveConditions(
-        programBuilder.getProgram,
-        new RexBuilder(typeFactory),
-        functionCatalog)
-
-    assertExpressionArrayEquals(
-      Array[Expression](
-        GreaterThan(Sum(UnresolvedFieldReference("amount")), Literal(100)),
-        EqualTo(Min(UnresolvedFieldReference("id")), Literal(100))),
-      convertedExpressions)
-    assertEquals(0, unconvertedRexNodes.length)
-  }
-
-  /**
-    * Conjuncts that cannot be translated (here: anything containing a CAST) are returned
-    * as unconverted RexNodes, while the translatable ones are converted.
-    */
-  @Test
-  def testExtractWithUnsupportedConditions(): Unit = {
-    val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
-    val programBuilder = new RexProgramBuilder(inputRowType, rexBuilder)
-
-    val amountRef = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
-    val idRef = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
-    val hundred = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
-
-    // CAST(amount AS BIGINT) -- casts are not translatable yet
-    val castAmount =
-      programBuilder.addExpr(rexBuilder.makeCast(allFieldTypes.get(1), amountRef))
-
-    // not translatable: CAST(amount AS BIGINT) > 100
-    val castGreaterThanHundred = programBuilder.addExpr(
-      rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, castAmount, hundred))
-
-    // translatable: amount <= id
-    val amountAtMostId = programBuilder.addExpr(
-      rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, amountRef, idRef))
-
-    // not translatable, because one of its operands is not:
-    // CAST(amount AS BIGINT) > 100 OR amount <= id
-    val disjunction = programBuilder.addExpr(
-      rexBuilder.makeCall(SqlStdOperatorTable.OR, castGreaterThanHundred, amountAtMostId))
-
-    // of the three conjuncts only `amount <= id` can be converted
-    programBuilder.addCondition(rexBuilder.makeCall(
-      SqlStdOperatorTable.AND, castGreaterThanHundred, amountAtMostId, disjunction))
-
-    val (convertedExpressions, unconvertedRexNodes) =
-      RexProgramExtractor.extractConjunctiveConditions(
-        programBuilder.getProgram,
-        new RexBuilder(typeFactory),
-        functionCatalog)
-
-    assertExpressionArrayEquals(
-      Array[Expression](ExpressionParser.parseExpression("amount <= id")),
-      convertedExpressions)
-
-    // the unconverted conjuncts are expected in the order they appear in the conjunction
-    assertEquals(2, unconvertedRexNodes.length)
-    assertEquals(">(CAST($2):BIGINT NOT NULL, 100)", unconvertedRexNodes(0).toString)
-    assertEquals("OR(>(CAST($2):BIGINT NOT NULL, 100), <=($2, $1))",
-      unconvertedRexNodes(1).toString)
-  }
-
-  @Test
-  def testExtractRefNestedInputFields(): Unit = {
-    val program = buildRexProgramWithNesting()
-    val nestedFields = RexProgramExtractor.extractRefNestedInputFields(
-      program, RexProgramExtractor.extractRefInputFields(program))
-
-    // `payments` is only accessed through `amount`; `persons` is projected as a whole ("*")
-    val expected = Array(Array("amount"), Array("*"))
-    assertThat(nestedFields, is(expected))
-  }
-
-  @Test
-  def testExtractRefNestedInputFieldsWithNoNesting(): Unit = {
-    val program = buildSimpleRexProgram()
-    val referencedFields = RexProgramExtractor.extractRefInputFields(program)
-
-    // without any field access, each of the three referenced fields is used as a whole
-    val expected = Array.fill(3)(Array("*"))
-    assertThat(
-      RexProgramExtractor.extractRefNestedInputFields(program, referencedFields),
-      is(expected))
-  }
-
-  @Test
-  def testExtractDeepRefNestedInputFields(): Unit = {
-    val program = buildRexProgramWithDeepNesting()
-    val referencedFields = RexProgramExtractor.extractRefInputFields(program)
-    val nestedFields =
-      RexProgramExtractor.extractRefNestedInputFields(program, referencedFields)
-
-    // expected referenced input indices: payments (1), persons (0), field (2)
-    assertThat(referencedFields, is(Array(1, 0, 2)))
-
-    // per referenced field, the dotted nested paths accessed; "*" means the whole field
-    val expectedNestedFields = Array(
-      Array("amount"),
-      Array("*"),
-      Array("with.deeper.entry", "with.deep.entry"))
-    assertThat(nestedFields, is(expectedNestedFields))
-  }
-
- /**
- * Builds a program over three nested input columns (persons, payments, field) whose
- * projections access nested fields at different depths, including a whole column.
- *
- * NOTE: testExtractDeepRefNestedInputFields asserts the exact referenced fields and
- * nested paths of this program, so keep the projections and their order unchanged.
- */
- private def buildRexProgramWithDeepNesting(): RexProgram = {
-
- // person input
- val passportRow = inputOf(typeFactory)
- .field("id", VARCHAR)
- .field("status", VARCHAR)
- .build
-
- val personRow = inputOf(typeFactory)
- .field("name", VARCHAR)
- .field("age", INTEGER)
- .nestedField("passport", passportRow)
- .build
-
- // payment input
- val paymentRow = inputOf(typeFactory)
- .field("id", BIGINT)
- .field("amount", INTEGER)
- .build
-
- // deep field input
- // deepRowType ([entry: VARCHAR]) is reused both as `with.deep` and as
- // `with.deeper.entry.inside`
- val deepRowType = inputOf(typeFactory)
- .field("entry", VARCHAR)
- .build
-
- val entryRowType = inputOf(typeFactory)
- .nestedField("inside", deepRowType)
- .build
-
- val deeperRowType = inputOf(typeFactory)
- .nestedField("entry", entryRowType)
- .build
-
- val withRowType = inputOf(typeFactory)
- .nestedField("deep", deepRowType)
- .nestedField("deeper", deeperRowType)
- .build
-
- val fieldRowType = inputOf(typeFactory)
- .nestedField("with", withRowType)
- .build
-
- // main input
- val inputRowType = inputOf(typeFactory)
- .nestedField("persons", personRow)
- .nestedField("payments", paymentRow)
- .nestedField("field", fieldRowType)
- .build
-
- // inputRowType
- //
- // [ persons: [ name: VARCHAR, age: INT, passport: [id: VARCHAR, status: VARCHAR ] ],
- // payments: [ id: BIGINT, amount: INT ],
- // field: [ with: [ deep: [ entry: VARCHAR ],
- // deeper: [ entry: [ inside: [entry: VARCHAR ] ] ]
- // ] ]
- // ]
-
- val builder = new RexProgramBuilder(inputRowType, rexBuilder)
-
- // references to the three top-level input columns, plus the literal 10
- val t0 = rexBuilder.makeInputRef(personRow, 0)
- val t1 = rexBuilder.makeInputRef(paymentRow, 1)
- val t2 = rexBuilder.makeInputRef(fieldRowType, 2)
- val t3 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(10L))
-
- // The `false` argument of makeFieldAccess is Calcite's caseSensitive flag.
- // Only expressions passed to builder.addExpr / addProject become part of the program.
-
- // person
- val person$pass = rexBuilder.makeFieldAccess(t0, "passport", false)
- val person$pass$stat = rexBuilder.makeFieldAccess(person$pass, "status", false)
-
- // payment
- val pay$amount = rexBuilder.makeFieldAccess(t1, "amount", false)
- val multiplyAmount = builder.addExpr(
- rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, pay$amount, t3))
-
- // field
- val field$with = rexBuilder.makeFieldAccess(t2, "with", false)
- val field$with$deep = rexBuilder.makeFieldAccess(field$with, "deep", false)
- val field$with$deeper = rexBuilder.makeFieldAccess(field$with, "deeper", false)
- val field$with$deep$entry = rexBuilder.makeFieldAccess(field$with$deep, "entry", false)
- val field$with$deeper$entry = rexBuilder.makeFieldAccess(field$with$deeper, "entry", false)
- val field$with$deeper$entry$inside = rexBuilder
- .makeFieldAccess(field$with$deeper$entry, "inside", false)
- val field$with$deeper$entry$inside$entry = rexBuilder
- .makeFieldAccess(field$with$deeper$entry$inside, "entry", false)
-
- // `with.deeper.entry.inside.entry` is subsumed by the projected `with.deeper.entry`
- builder.addProject(multiplyAmount, "amount")
- builder.addProject(person$pass$stat, "status")
- builder.addProject(field$with$deep$entry, "entry")
- builder.addProject(field$with$deeper$entry$inside$entry, "entry")
- builder.addProject(field$with$deeper$entry, "entry2")
- builder.addProject(t0, "person")
-
- // Program
- // (
- // payments.amount * 10,
- // persons.passport.status,
- // field.with.deep.entry,
- // field.with.deeper.entry.inside.entry,
- // field.with.deeper.entry,
- // persons
- // )
-
- builder.getProgram
-
- }
-
-  /**
-    * Builds a program over two nested input columns, `persons` and `payments`, that projects
-    * `payments.amount`, the whole `persons` column and the literal 100.
-    *
-    * @return the program: (payments.amount AS amount, persons, 100 AS number)
-    */
-  private def buildRexProgramWithNesting(): RexProgram = {
-
-    // `name` is textual and `age` numeric, consistent with the person row used in
-    // buildRexProgramWithDeepNesting
-    val personRow = inputOf(typeFactory)
-      .field("name", VARCHAR)
-      .field("age", INTEGER)
-      .build
-
-    val paymentRow = inputOf(typeFactory)
-      .field("id", BIGINT)
-      .field("amount", INTEGER)
-      .build
-
-    val types = List(personRow, paymentRow).asJava
-    val names = List("persons", "payments").asJava
-    val inputRowType = typeFactory.createStructType(types, names)
-
-    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
-
-    val t0 = rexBuilder.makeInputRef(types.get(0), 0)
-    val t1 = rexBuilder.makeInputRef(types.get(1), 1)
-    val t2 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
-
-    // payments.amount (caseSensitive = false)
-    val payment$amount = rexBuilder.makeFieldAccess(t1, "amount", false)
-
-    builder.addProject(payment$amount, "amount")
-    builder.addProject(t0, "persons")
-    builder.addProject(t2, "number")
-    builder.getProgram
-  }
-
-  /**
-    * Builds a program whose only condition is `op(field)` for the input field at
-    * `fieldIndex`, extracts its conjunctive conditions and checks that exactly one
-    * expression, rendered as `expr`, is converted and nothing is left unconverted.
-    *
-    * @param fieldIndex index of the input field the postfix operator is applied to
-    * @param op         the postfix operator under test (e.g. IS NULL, IS TRUE)
-    * @param expr       expected string form of the converted expression
-    */
-  private def testExtractSinglePostfixCondition(
-      fieldIndex: Int,
-      op: SqlPostfixOperator,
-      expr: String): Unit = {
-
-    val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
-    // The shared rexBuilder fixture is used as-is. Replacing it here (as the code used
-    // to do) left the program builder holding the old instance and needlessly mutated
-    // shared test state.
-    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
-
-    // reference to the input field under test
-    val fieldRef = rexBuilder.makeInputRef(allFieldTypes.get(fieldIndex), fieldIndex)
-    builder.addCondition(builder.addExpr(rexBuilder.makeCall(op, fieldRef)))
-
-    // build without normalization so the postfix call is kept as written
-    val program = builder.getProgram(false)
-    val extractorRexBuilder = new RexBuilder(typeFactory)
-    val (convertedExpressions, unconvertedRexNodes) =
-      RexProgramExtractor.extractConjunctiveConditions(
-        program,
-        extractorRexBuilder,
-        functionCatalog)
-
-    assertEquals(1, convertedExpressions.length)
-    assertEquals(expr, convertedExpressions.head.toString)
-    assertEquals(0, unconvertedRexNodes.length)
-  }
-
-  /**
-    * Asserts that both arrays contain the same expressions, ignoring order.
-    * Expressions are compared by their string form; on failure both sorted lists are
-    * reported.
-    *
-    * @param expected the expected expressions
-    * @param actual   the expressions produced by the code under test
-    */
-  private def assertExpressionArrayEquals(
-      expected: Array[Expression],
-      actual: Array[Expression]): Unit = {
-    val sortedExpected = expected.map(_.toString).sorted.toSeq
-    val sortedActual = actual.map(_.toString).sorted.toSeq
-    // a single Seq comparison covers both length and element mismatches
-    assertEquals(sortedExpected, sortedActual)
-  }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/f1fafc0e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramRewriterTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramRewriterTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramRewriterTest.scala
deleted file mode 100644
index 899eed2..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramRewriterTest.scala
+++ /dev/null
@@ -1,62 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.plan.util
-
-import org.junit.Assert.assertTrue
-import org.junit.Test
-
-import scala.collection.JavaConverters._
-
-/**
-  * Tests for [[RexProgramRewriter]]: rewriting a program onto a projected input keeps
-  * its expressions but renumbers their input references.
-  */
-class RexProgramRewriterTest extends RexProgramTestBase {
-
-  // assertEquals reports both sides on failure, unlike assertTrue(a == b)
-  import org.junit.Assert.assertEquals
-
-  @Test
-  def testRewriteRexProgram(): Unit = {
-    val rexProgram = buildSimpleRexProgram()
-    assertEquals(
-      Seq(
-        "$0",
-        "$1",
-        "$2",
-        "$3",
-        "$4",
-        "*($t2, $t3)",
-        "100",
-        "<($t5, $t6)",
-        "6",
-        ">($t1, $t8)",
-        "AND($t7, $t9)"),
-      extractExprStrList(rexProgram))
-
-    // use the amount, price and id fields (in this order) to create a new RexProgram
-    val usedFields = Array(2, 3, 1)
-    val types = usedFields.map(allFieldTypes.get).toList.asJava
-    val names = usedFields.map(allFieldNames.get).toList.asJava
-    val inputRowType = typeFactory.createStructType(types, names)
-    val newRexProgram = RexProgramRewriter.rewriteWithFieldProjection(
-      rexProgram, inputRowType, rexBuilder, usedFields)
-
-    // only the three used input fields remain, and references are renumbered accordingly
-    assertEquals(
-      Seq(
-        "$0",
-        "$1",
-        "$2",
-        "*($t0, $t1)",
-        "100",
-        "<($t3, $t4)",
-        "6",
-        ">($t2, $t6)",
-        "AND($t5, $t7)"),
-      extractExprStrList(newRexProgram))
-  }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/f1fafc0e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramTestBase.scala
deleted file mode 100644
index 6ef3d82..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramTestBase.scala
+++ /dev/null
@@ -1,80 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.plan.util
-
-import java.math.BigDecimal
-import java.util
-
-import org.apache.calcite.adapter.java.JavaTypeFactory
-import org.apache.calcite.jdbc.JavaTypeFactoryImpl
-import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeSystem}
-import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder}
-import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, DOUBLE, INTEGER, VARCHAR, BOOLEAN}
-import org.apache.calcite.sql.fun.SqlStdOperatorTable
-
-import scala.collection.JavaConverters._
-import scala.collection.mutable
-
-/**
- * Shared fixtures for RexProgram tests: a Java type factory, a flat five-column input
- * schema (name, id, amount, price, flag) and a helper that builds a simple program over it.
- */
-abstract class RexProgramTestBase {
-
- val typeFactory: JavaTypeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT)
-
- // field names by input index: 0 name, 1 id, 2 amount, 3 price, 4 flag
- val allFieldNames: util.List[String] = List("name", "id", "amount", "price", "flag").asJava
-
- // field types, index-aligned with allFieldNames
- val allFieldTypes: util.List[RelDataType] =
- List(VARCHAR, BIGINT, INTEGER, DOUBLE, BOOLEAN).map(typeFactory.createSqlType).asJava
-
- // declared as a var, so subclasses can replace it
- var rexBuilder: RexBuilder = new RexBuilder(typeFactory)
-
- /**
- * Extracts the string form of every expression in the program's expression list, in order.
- *
- * @param rexProgram the RexProgram to inspect
- * @return the `toString` of each entry of `rexProgram.getExprList`
- */
- protected def extractExprStrList(rexProgram: RexProgram): mutable.Buffer[String] = {
- rexProgram.getExprList.asScala.map(_.toString)
- }
-
- // select amount, amount * price as total where amount * price < 100 and id > 6
- //
- // The order of the builder calls below determines the program's expression list,
- // which tests assert verbatim (e.g. "*($t2, $t3)", "AND($t7, $t9)"), so keep it stable.
- // The local names t0..t8 are NOT the program's $tN indices.
- protected def buildSimpleRexProgram(): RexProgram = {
- val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
- val builder = new RexProgramBuilder(inputRowType, rexBuilder)
-
- // input references: amount, id, price
- val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
- val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
- val t2 = rexBuilder.makeInputRef(allFieldTypes.get(3), 3)
- // amount * price, registered with the builder
- val t3 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, t0, t2))
- val t4 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
- val t5 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(6L))
-
- // project: amount, amount * price as total
- builder.addProject(t0, "amount")
- builder.addProject(t3, "total")
-
- // condition: amount * price < 100 and id > 6
- val t6 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t3, t4))
- val t7 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t5))
- val t8 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(t6, t7).asJava))
- builder.addCondition(t8)
-
- builder.getProgram
- }
-
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/f1fafc0e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AggFunctionTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AggFunctionTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AggFunctionTestBase.scala
new file mode 100644
index 0000000..458f80d
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AggFunctionTestBase.scala
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.aggfunctions
+
+import java.lang.reflect.Method
+import java.math.BigDecimal
+import java.util.{ArrayList => JArrayList, List => JList}
+
+import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.aggfunctions.{DecimalAvgAccumulator, DecimalSumWithRetractAccumulator}
+import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
+import org.junit.Assert.assertEquals
+import org.junit.Test
+
+/**
+  * Base class for aggregate function tests.
+  *
+  * Subclasses supply the input value sets, the expected result for each set and the
+  * [[AggregateFunction]] under test. The `accumulate`, `retract`, `merge` and
+  * `resetAccumulator` methods are invoked through reflection. The optional ones are only
+  * exercised when `ifMethodExistInFunction` reports that the function declares them.
+  *
+  * @tparam T   the type for the aggregation result
+  * @tparam ACC the type of the accumulator
+  */
+abstract class AggFunctionTestBase[T, ACC] {
+  def inputValueSets: Seq[Seq[_]]
+
+  def expectedResults: Seq[T]
+
+  def aggregator: AggregateFunction[T, ACC]
+
+  // Lazy so that it is not evaluated in this constructor. A strict val would call the
+  // abstract `aggregator` before a subclass had initialized it (NPE for a `val aggregator`).
+  lazy val accType: Class[_] = aggregator.getClass.getMethod("createAccumulator").getReturnType
+
+  def accumulateFunc: Method = aggregator.getClass.getMethod("accumulate", accType, classOf[Any])
+
+  /** Must be overridden by tests whose aggregate function declares a `retract` method. */
+  def retractFunc: Method = null
+
+  // test aggregate and retract functions without partial merge
+  @Test
+  def testAccumulateAndRetractWithoutMerge(): Unit = {
+    for ((vals, expected) <- inputsWithExpectedResults) {
+      val accumulator = accumulateVals(vals)
+      val result = aggregator.getValue(accumulator)
+      validateResult[T](expected, result)
+
+      if (ifMethodExistInFunction("retract", aggregator)) {
+        retractVals(accumulator, vals)
+        val expectedAccum = aggregator.createAccumulator()
+        // retracting every accumulated value must restore a fresh accumulator
+        validateResult[ACC](expectedAccum, accumulator)
+      }
+    }
+  }
+
+  @Test
+  def testAggregateWithMerge(): Unit = {
+
+    if (ifMethodExistInFunction("merge", aggregator)) {
+      val mergeFunc =
+        aggregator.getClass.getMethod("merge", accType, classOf[java.lang.Iterable[ACC]])
+
+      for ((vals, expected) <- inputsWithExpectedResults) {
+        // equally split the vals sequence into two sequences
+        val (firstVals, secondVals) = vals.splitAt(vals.length / 2)
+
+        // 1. verify merge with accumulate
+        val accumulators: JList[ACC] = new JArrayList[ACC]()
+        accumulators.add(accumulateVals(secondVals))
+
+        val acc = accumulateVals(firstVals)
+
+        mergeFunc.invoke(aggregator, acc.asInstanceOf[Object], accumulators)
+        val result = aggregator.getValue(acc)
+        validateResult[T](expected, result)
+
+        // 2. verify merge with accumulate & retract
+        if (ifMethodExistInFunction("retract", aggregator)) {
+          retractVals(acc, vals)
+          val expectedAccum = aggregator.createAccumulator()
+          // retracting every merged value must restore a fresh accumulator
+          validateResult[ACC](expectedAccum, acc)
+        }
+      }
+
+      for ((vals, expected) <- inputsWithExpectedResults) {
+        // 3. test partial merge with an empty accumulator
+        val accumulators: JList[ACC] = new JArrayList[ACC]()
+        accumulators.add(aggregator.createAccumulator())
+
+        val acc = accumulateVals(vals)
+
+        mergeFunc.invoke(aggregator, acc.asInstanceOf[Object], accumulators)
+        val result = aggregator.getValue(acc)
+        validateResult[T](expected, result)
+      }
+    }
+  }
+
+  // test aggregate functions with resetAccumulator
+  @Test
+  def testResetAccumulator(): Unit = {
+
+    if (ifMethodExistInFunction("resetAccumulator", aggregator)) {
+      val resetAccFunc = aggregator.getClass.getMethod("resetAccumulator", accType)
+      for (vals <- inputValueSets) {
+        val accumulator = accumulateVals(vals)
+        resetAccFunc.invoke(aggregator, accumulator.asInstanceOf[Object])
+        val expectedAccum = aggregator.createAccumulator()
+        // the accumulator after reset must equal a freshly created one
+        validateResult[ACC](expectedAccum, accumulator)
+      }
+    }
+  }
+
+  /**
+    * Pairs every input value set with its expected result. Fails if the two sequences
+    * differ in length, since `zip` would otherwise silently skip the surplus entries.
+    */
+  private def inputsWithExpectedResults: Seq[(Seq[_], T)] = {
+    val inputs = inputValueSets
+    val expected = expectedResults
+    assertEquals(
+      "inputValueSets and expectedResults must have the same number of entries",
+      inputs.size.toLong,
+      expected.size.toLong)
+    inputs.zip(expected)
+  }
+
+  /**
+    * Compares an expected and an actual value. BigDecimals (alone or inside the decimal
+    * accumulators) are compared numerically, because `BigDecimal.equals()` compares both
+    * value and scale and only the value matters here.
+    */
+  private def validateResult[V](expected: V, result: V): Unit = {
+    (expected, result) match {
+      case (e: DecimalSumWithRetractAccumulator, r: DecimalSumWithRetractAccumulator) =>
+        assert(e.f0.compareTo(r.f0) == 0 && e.f1 == r.f1,
+          s"expected accumulator (${e.f0}, ${e.f1}) but was (${r.f0}, ${r.f1})")
+      case (e: DecimalAvgAccumulator, r: DecimalAvgAccumulator) =>
+        assert(e.f0.compareTo(r.f0) == 0 && e.f1 == r.f1,
+          s"expected accumulator (${e.f0}, ${e.f1}) but was (${r.f0}, ${r.f1})")
+      case (e: BigDecimal, r: BigDecimal) =>
+        assert(e.compareTo(r) == 0, s"expected $e but was $r")
+      case _ =>
+        assertEquals(expected, result)
+    }
+  }
+
+  /** Creates a fresh accumulator and accumulates all `vals` into it. */
+  private def accumulateVals(vals: Seq[_]): ACC = {
+    val accumulator = aggregator.createAccumulator()
+    // look the reflective method up once instead of once per value
+    val accumulate = accumulateFunc
+    vals.foreach { v =>
+      accumulate.invoke(aggregator, accumulator.asInstanceOf[Object], v.asInstanceOf[Object])
+    }
+    accumulator
+  }
+
+  /** Retracts all `vals` from `accumulator` in place. */
+  private def retractVals(accumulator: ACC, vals: Seq[_]): Unit = {
+    val retract = retractFunc
+    // fail with a clear message instead of a NullPointerException from invoke()
+    assert(retract != null,
+      "the aggregate function declares retract(); the test must override retractFunc")
+    vals.foreach { v =>
+      retract.invoke(aggregator, accumulator.asInstanceOf[Object], v.asInstanceOf[Object])
+    }
+  }
+}