You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ku...@apache.org on 2017/03/17 10:03:05 UTC
[3/4] flink git commit: [FLINK-3849] [table] Add
FilterableTableSource interface and rules for pushing it (2)
http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/ProjectableTableSource.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/ProjectableTableSource.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/ProjectableTableSource.scala
index 429cccb..570bdff 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/ProjectableTableSource.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/ProjectableTableSource.scala
@@ -22,17 +22,16 @@ package org.apache.flink.table.sources
* Adds support for projection push-down to a [[TableSource]].
* A [[TableSource]] extending this interface is able to project the fields of the return table.
*
- * @tparam T The return type of the [[ProjectableTableSource]].
+ * @tparam T The return type of the [[TableSource]].
*/
trait ProjectableTableSource[T] {
/**
- * Creates a copy of the [[ProjectableTableSource]] that projects its output on the specified
- * fields.
+ * Creates a copy of the [[TableSource]] that projects its output on the specified fields.
*
* @param fields The indexes of the fields to return.
- * @return A copy of the [[ProjectableTableSource]] that projects its output.
+ * @return A copy of the [[TableSource]] that projects its output.
*/
- def projectFields(fields: Array[Int]): ProjectableTableSource[T]
+ def projectFields(fields: Array[Int]): TableSource[T]
}
http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/TableSource.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/TableSource.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/TableSource.scala
index fe205f1..c41582e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/TableSource.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/TableSource.scala
@@ -38,4 +38,6 @@ trait TableSource[T] {
/** Returns the [[TypeInformation]] for the return type of the [[TableSource]]. */
def getReturnType: TypeInformation[T]
+ /** Describes the table source. */
+ def explainSource(): String = ""
}
http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
index 2c08d8d..fcfcf43 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
@@ -86,6 +86,7 @@ class FunctionCatalog {
.getOrElse(throw ValidationException(s"Undefined scalar function: $name"))
.asInstanceOf[ScalarSqlFunction]
ScalarFunctionCall(scalarSqlFunction.getScalarFunction, children)
+
// user-defined table function call
case tf if classOf[TableFunction[_]].isAssignableFrom(tf) =>
val tableSqlFunction = sqlFunctions
@@ -105,7 +106,7 @@ class FunctionCatalog {
case Success(expr) => expr
case Failure(e) => throw new ValidationException(e.getMessage)
}
- case Failure(e) =>
+ case Failure(_) =>
val childrenClass = Seq.fill(children.length)(classOf[Expression])
// try to find a constructor matching the exact number of children
Try(funcClass.getDeclaredConstructor(childrenClass: _*)) match {
@@ -114,7 +115,7 @@ class FunctionCatalog {
case Success(expr) => expr
case Failure(exception) => throw ValidationException(exception.getMessage)
}
- case Failure(exception) =>
+ case Failure(_) =>
throw ValidationException(s"Invalid number of arguments for function $funcClass")
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableSourceTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableSourceTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableSourceTest.scala
index 058eca7..97d4d59 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableSourceTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableSourceTest.scala
@@ -22,8 +22,9 @@ import org.apache.flink.table.api.Types
import org.apache.flink.table.api.scala._
import org.apache.flink.table.sources.{CsvTableSource, TableSource}
import org.apache.flink.table.utils.TableTestUtil._
+import org.apache.flink.table.expressions.utils._
+import org.apache.flink.table.utils.{CommonTestData, TableTestBase, TestFilterableTableSource}
import org.junit.{Assert, Test}
-import org.apache.flink.table.utils.{CommonTestData, TableTestBase}
class TableSourceTest extends TableTestBase {
@@ -46,7 +47,7 @@ class TableSourceTest extends TableTestBase {
val expected = unaryNode(
"DataSetCalc",
- projectableSourceBatchTableNode(tableName, projectedFields),
+ batchSourceTableNode(tableName, projectedFields),
term("select", "UPPER(last) AS _c0", "FLOOR(id) AS _c1", "*(score, 2) AS _c2")
)
@@ -64,7 +65,7 @@ class TableSourceTest extends TableTestBase {
val expected = unaryNode(
"DataSetCalc",
- projectableSourceBatchTableNode(tableName, projectedFields),
+ batchSourceTableNode(tableName, projectedFields),
term("select", "last", "FLOOR(id) AS EXPR$1", "*(score, 2) AS EXPR$2")
)
@@ -83,12 +84,37 @@ class TableSourceTest extends TableTestBase {
.scan(tableName)
.select('id, 'score, 'first)
- val expected = projectableSourceBatchTableNode(tableName, noCalcFields)
+ val expected = batchSourceTableNode(tableName, noCalcFields)
util.verifyTable(result, expected)
}
@Test
- def testBatchFilterableSourceScanPlanTableApi(): Unit = {
+ def testBatchFilterableWithoutPushDown(): Unit = {
+ val (tableSource, tableName) = filterableTableSource
+ val util = batchTestUtil()
+ val tEnv = util.tEnv
+
+ tEnv.registerTableSource(tableName, tableSource)
+
+ val result = tEnv
+ .scan(tableName)
+ .select('price, 'id, 'amount)
+ .where("price * 2 < 32")
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ batchSourceTableNode(
+ tableName,
+ Array("name", "id", "amount", "price")),
+ term("select", "price", "id", "amount"),
+ term("where", "<(*(price, 2), 32)")
+ )
+
+ util.verifyTable(result, expected)
+ }
+
+ @Test
+ def testBatchFilterablePartialPushDown(): Unit = {
val (tableSource, tableName) = filterableTableSource
val util = batchTestUtil()
val tEnv = util.tEnv
@@ -97,18 +123,94 @@ class TableSourceTest extends TableTestBase {
val result = tEnv
.scan(tableName)
- .select('price, 'id, 'amount)
.where("amount > 2 && price * 2 < 32")
+ .select('price, 'name.lowerCase(), 'amount)
val expected = unaryNode(
"DataSetCalc",
- filterableSourceBatchTableNode(
+ batchFilterableSourceTableNode(
tableName,
Array("name", "id", "amount", "price"),
- ">(amount, 2)"),
- term("select", "price", "id", "amount"),
+ "'amount > 2"),
+ term("select", "price", "LOWER(name) AS _c1", "amount"),
term("where", "<(*(price, 2), 32)")
)
+ util.verifyTable(result, expected)
+ }
+
+ @Test
+ def testBatchFilterableFullyPushedDown(): Unit = {
+ val (tableSource, tableName) = filterableTableSource
+ val util = batchTestUtil()
+ val tEnv = util.tEnv
+
+ tEnv.registerTableSource(tableName, tableSource)
+
+ val result = tEnv
+ .scan(tableName)
+ .select('price, 'id, 'amount)
+ .where("amount > 2 && amount < 32")
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ batchFilterableSourceTableNode(
+ tableName,
+ Array("name", "id", "amount", "price"),
+ "'amount > 2 && 'amount < 32"),
+ term("select", "price", "id", "amount")
+ )
+ util.verifyTable(result, expected)
+ }
+
+ @Test
+ def testBatchFilterableWithUnconvertedExpression(): Unit = {
+ val (tableSource, tableName) = filterableTableSource
+ val util = batchTestUtil()
+ val tEnv = util.tEnv
+
+ tEnv.registerTableSource(tableName, tableSource)
+
+ val result = tEnv
+ .scan(tableName)
+ .select('price, 'id, 'amount)
+ .where("amount > 2 && (amount < 32 || amount.cast(LONG) > 10)") // cast can not be converted
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ batchFilterableSourceTableNode(
+ tableName,
+ Array("name", "id", "amount", "price"),
+ "'amount > 2"),
+ term("select", "price", "id", "amount"),
+ term("where", "OR(<(amount, 32), >(CAST(amount), 10))")
+ )
+ util.verifyTable(result, expected)
+ }
+
+ @Test
+ def testBatchFilterableWithUDF(): Unit = {
+ val (tableSource, tableName) = filterableTableSource
+ val util = batchTestUtil()
+ val tEnv = util.tEnv
+
+ tEnv.registerTableSource(tableName, tableSource)
+ val func = Func0
+ tEnv.registerFunction("func0", func)
+
+ val result = tEnv
+ .scan(tableName)
+ .select('price, 'id, 'amount)
+ .where("amount > 2 && func0(amount) < 32")
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ batchFilterableSourceTableNode(
+ tableName,
+ Array("name", "id", "amount", "price"),
+ "'amount > 2"),
+ term("select", "price", "id", "amount"),
+ term("where", s"<(${func.functionIdentifier}(amount), 32)")
+ )
util.verifyTable(result, expected)
}
@@ -129,7 +231,7 @@ class TableSourceTest extends TableTestBase {
val expected = unaryNode(
"DataStreamCalc",
- projectableSourceStreamTableNode(tableName, projectedFields),
+ streamSourceTableNode(tableName, projectedFields),
term("select", "last", "FLOOR(id) AS _c1", "*(score, 2) AS _c2")
)
@@ -147,7 +249,7 @@ class TableSourceTest extends TableTestBase {
val expected = unaryNode(
"DataStreamCalc",
- projectableSourceStreamTableNode(tableName, projectedFields),
+ streamSourceTableNode(tableName, projectedFields),
term("select", "last", "FLOOR(id) AS EXPR$1", "*(score, 2) AS EXPR$2")
)
@@ -166,7 +268,7 @@ class TableSourceTest extends TableTestBase {
.scan(tableName)
.select('id, 'score, 'first)
- val expected = projectableSourceStreamTableNode(tableName, noCalcFields)
+ val expected = streamSourceTableNode(tableName, noCalcFields)
util.verifyTable(result, expected)
}
@@ -185,10 +287,10 @@ class TableSourceTest extends TableTestBase {
val expected = unaryNode(
"DataStreamCalc",
- filterableSourceStreamTableNode(
+ streamFilterableSourceTableNode(
tableName,
Array("name", "id", "amount", "price"),
- ">(amount, 2)"),
+ "'amount > 2"),
term("select", "price", "id", "amount"),
term("where", "<(*(price, 2), 32)")
)
@@ -254,7 +356,7 @@ class TableSourceTest extends TableTestBase {
// utils
def filterableTableSource:(TableSource[_], String) = {
- val tableSource = CommonTestData.getFilterableTableSource
+ val tableSource = new TestFilterableTableSource
(tableSource, "filterableTable")
}
@@ -264,37 +366,27 @@ class TableSourceTest extends TableTestBase {
(csvTable, tableName)
}
- def projectableSourceBatchTableNode(
- sourceName: String,
- fields: Array[String]): String = {
-
- "BatchTableSourceScan(" +
- s"table=[[$sourceName]], fields=[${fields.mkString(", ")}])"
+ def batchSourceTableNode(sourceName: String, fields: Array[String]): String = {
+ s"BatchTableSourceScan(table=[[$sourceName]], fields=[${fields.mkString(", ")}])"
}
- def projectableSourceStreamTableNode(
- sourceName: String,
- fields: Array[String]): String = {
-
- "StreamTableSourceScan(" +
- s"table=[[$sourceName]], fields=[${fields.mkString(", ")}])"
+ def streamSourceTableNode(sourceName: String, fields: Array[String] ): String = {
+ s"StreamTableSourceScan(table=[[$sourceName]], fields=[${fields.mkString(", ")}])"
}
- def filterableSourceBatchTableNode(
- sourceName: String,
- fields: Array[String],
- exp: String): String = {
-
+ def batchFilterableSourceTableNode(
+ sourceName: String,
+ fields: Array[String],
+ exp: String): String = {
"BatchTableSourceScan(" +
- s"table=[[$sourceName]], fields=[${fields.mkString(", ")}], filter=[$exp])"
+ s"table=[[$sourceName]], fields=[${fields.mkString(", ")}], source=[filter=[$exp]])"
}
- def filterableSourceStreamTableNode(
- sourceName: String,
- fields: Array[String],
- exp: String): String = {
-
+ def streamFilterableSourceTableNode(
+ sourceName: String,
+ fields: Array[String],
+ exp: String): String = {
"StreamTableSourceScan(" +
- s"table=[[$sourceName]], fields=[${fields.mkString(", ")}], filter=[$exp])"
+ s"table=[[$sourceName]], fields=[${fields.mkString(", ")}], source=[filter=[$exp]])"
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/TableSourceITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/TableSourceITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/TableSourceITCase.scala
index ca7cd8a..7e349cf 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/TableSourceITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/TableSourceITCase.scala
@@ -23,7 +23,7 @@ import org.apache.flink.table.api.scala.batch.utils.TableProgramsCollectionTestB
import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.TableEnvironment
-import org.apache.flink.table.utils.CommonTestData
+import org.apache.flink.table.utils.{CommonTestData, TestFilterableTableSource}
import org.apache.flink.test.util.TestBaseUtils
import org.junit.Test
import org.junit.runner.RunWith
@@ -107,7 +107,7 @@ class TableSourceITCase(
val tableName = "MyTable"
val env = ExecutionEnvironment.getExecutionEnvironment
val tableEnv = TableEnvironment.getTableEnvironment(env, config)
- tableEnv.registerTableSource(tableName, CommonTestData.getFilterableTableSource)
+ tableEnv.registerTableSource(tableName, new TestFilterableTableSource)
val results = tableEnv
.scan(tableName)
.where("amount > 4 && price < 9")
http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/TableSourceITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/TableSourceITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/TableSourceITCase.scala
index 973c2f3..66711cb 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/TableSourceITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/TableSourceITCase.scala
@@ -24,7 +24,7 @@ import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
import org.apache.flink.table.api.TableEnvironment
-import org.apache.flink.table.utils.CommonTestData
+import org.apache.flink.table.utils.{CommonTestData, TestFilterableTableSource}
import org.apache.flink.types.Row
import org.junit.Assert._
import org.junit.Test
@@ -90,7 +90,7 @@ class TableSourceITCase extends StreamingMultipleProgramsTestBase {
val tableName = "MyTable"
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)
- tEnv.registerTableSource(tableName, CommonTestData.getFilterableTableSource)
+ tEnv.registerTableSource(tableName, new TestFilterableTableSource)
tEnv.scan(tableName)
.where("amount > 4 && price < 9")
.select("id, name")
http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala
index 30da5ba..d8de554 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala
@@ -199,7 +199,7 @@ abstract class ExpressionTestBase {
// extract RexNode
val calcProgram = dataSetCalc
.asInstanceOf[DataSetCalc]
- .calcProgram
+ .getProgram
val expanded = calcProgram.expandLocalRef(calcProgram.getProjectList.get(0))
testExprs += ((expanded, expected))
@@ -222,7 +222,7 @@ abstract class ExpressionTestBase {
// extract RexNode
val calcProgram = dataSetCalc
.asInstanceOf[DataSetCalc]
- .calcProgram
+ .getProgram
val expanded = calcProgram.expandLocalRef(calcProgram.getProjectList.get(0))
testExprs += ((expanded, expected))
http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExpressionExtractorTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExpressionExtractorTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExpressionExtractorTest.scala
deleted file mode 100644
index c4059d5..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExpressionExtractorTest.scala
+++ /dev/null
@@ -1,182 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.plan.util
-
-import java.math.BigDecimal
-
-import org.apache.calcite.adapter.java.JavaTypeFactory
-import org.apache.calcite.plan._
-import org.apache.calcite.plan.volcano.VolcanoPlanner
-import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeSystem}
-import org.apache.calcite.rel.core.TableScan
-import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder}
-import org.apache.calcite.sql.`type`.SqlTypeName._
-import org.apache.calcite.sql.fun.SqlStdOperatorTable
-import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
-import org.apache.flink.api.java.typeutils.RowTypeInfo
-import org.apache.flink.table.calcite.{FlinkRelBuilder, FlinkTypeFactory}
-import org.apache.flink.table.expressions.{Expression, ExpressionParser}
-import org.apache.flink.table.plan.util.RexProgramExpressionExtractor._
-import org.apache.flink.table.plan.schema.CompositeRelDataType
-import org.apache.flink.table.utils.CommonTestData
-import org.junit.Test
-import org.junit.Assert._
-
-import scala.collection.JavaConverters._
-
-class RexProgramExpressionExtractorTest {
-
- private val typeFactory = new FlinkTypeFactory(RelDataTypeSystem.DEFAULT)
- private val allFieldTypes = List(VARCHAR, DECIMAL, INTEGER, DOUBLE).map(typeFactory.createSqlType)
- private val allFieldTypeInfos: Array[TypeInformation[_]] =
- Array(BasicTypeInfo.STRING_TYPE_INFO,
- BasicTypeInfo.BIG_DEC_TYPE_INFO,
- BasicTypeInfo.INT_TYPE_INFO,
- BasicTypeInfo.DOUBLE_TYPE_INFO)
- private val allFieldNames = List("name", "id", "amount", "price")
-
- @Test
- def testExtractExpression(): Unit = {
- val builder: RexBuilder = new RexBuilder(typeFactory)
- val program = buildRexProgram(
- allFieldNames, allFieldTypes, typeFactory, builder)
- val firstExp = ExpressionParser.parseExpression("id > 6")
- val secondExp = ExpressionParser.parseExpression("amount * price < 100")
- val expected: Array[Expression] = Array(firstExp, secondExp)
- val actual = extractPredicateExpressions(
- program,
- builder,
- CommonTestData.getMockTableEnvironment.getFunctionCatalog)
-
- assertEquals(expected.length, actual.length)
- // todo
- }
-
- @Test
- def testRewriteRexProgramWithCondition(): Unit = {
- val originalRexProgram = buildRexProgram(
- allFieldNames, allFieldTypes, typeFactory, new RexBuilder(typeFactory))
- val array = Array(
- "$0",
- "$1",
- "$2",
- "$3",
- "*($t2, $t3)",
- "100",
- "<($t4, $t5)",
- "6",
- ">($t1, $t7)",
- "AND($t6, $t8)")
- assertTrue(extractExprStrList(originalRexProgram) sameElements array)
-
- val tEnv = CommonTestData.getMockTableEnvironment
- val builder = FlinkRelBuilder.create(tEnv.getFrameworkConfig)
- val tableScan = new MockTableScan(builder.getRexBuilder)
- val newExpression = ExpressionParser.parseExpression("amount * price < 100")
- val newRexProgram = rewriteRexProgram(
- originalRexProgram,
- tableScan,
- Array(newExpression)
- )(builder)
-
- val newArray = Array(
- "$0",
- "$1",
- "$2",
- "$3",
- "*($t2, $t3)",
- "100",
- "<($t4, $t5)")
- assertTrue(extractExprStrList(newRexProgram) sameElements newArray)
- }
-
-// @Test
-// def testVerifyExpressions(): Unit = {
-// val strPart = "f1 < 4"
-// val part = parseExpression(strPart)
-//
-// val shortFalseOrigin = parseExpression(s"f0 > 10 || $strPart")
-// assertFalse(verifyExpressions(shortFalseOrigin, part))
-//
-// val longFalseOrigin = parseExpression(s"(f0 > 10 || (($strPart) > POWER(f0, f1))) && 2")
-// assertFalse(verifyExpressions(longFalseOrigin, part))
-//
-// val shortOkayOrigin = parseExpression(s"f0 > 10 && ($strPart)")
-// assertTrue(verifyExpressions(shortOkayOrigin, part))
-//
-// val longOkayOrigin = parseExpression(s"f0 > 10 && (($strPart) > POWER(f0, f1))")
-// assertTrue(verifyExpressions(longOkayOrigin, part))
-//
-// val longOkayOrigin2 = parseExpression(s"(f0 > 10 || (2 > POWER(f0, f1))) && $strPart")
-// assertTrue(verifyExpressions(longOkayOrigin2, part))
-// }
-
- private def buildRexProgram(
- fieldNames: List[String],
- fieldTypes: Seq[RelDataType],
- typeFactory: JavaTypeFactory,
- rexBuilder: RexBuilder): RexProgram = {
-
- val inputRowType = typeFactory.createStructType(fieldTypes.asJava, fieldNames.asJava)
- val builder = new RexProgramBuilder(inputRowType, rexBuilder)
-
- val t0 = rexBuilder.makeInputRef(fieldTypes(2), 2)
- val t1 = rexBuilder.makeInputRef(fieldTypes(1), 1)
- val t2 = rexBuilder.makeInputRef(fieldTypes(3), 3)
- // t3 = t0 * t2
- val t3 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, t0, t2))
- val t4 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
- val t5 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(6L))
- // project: amount, amount * price
- builder.addProject(t0, "amount")
- builder.addProject(t3, "total")
- // t6 = t3 < t4
- val t6 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t3, t4))
- // t7 = t1 > t5
- val t7 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t5))
- val t8 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(t6, t7).asJava))
- // condition: t6 and t7
- // (t0 * t2 < t4 && t1 > t5)
- builder.addCondition(t8)
- builder.getProgram
- }
-
- /**
- * extract all expression string list from input RexProgram expression lists
- *
- * @param rexProgram input RexProgram instance to analyze
- * @return all expression string list of input RexProgram expression lists
- */
- private def extractExprStrList(rexProgram: RexProgram) =
- rexProgram.getExprList.asScala.map(_.toString).toArray
-
- class MockTableScan(
- rexBuilder: RexBuilder)
- extends TableScan(
- RelOptCluster.create(new VolcanoPlanner(), rexBuilder),
- RelTraitSet.createEmpty,
- new MockRelOptTable)
-
- class MockRelOptTable
- extends RelOptAbstractTable(
- null,
- "mockRelTable",
- new CompositeRelDataType(
- new RowTypeInfo(allFieldTypeInfos, allFieldNames.toArray), typeFactory))
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExtractorTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExtractorTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExtractorTest.scala
new file mode 100644
index 0000000..b0a5fcf
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExtractorTest.scala
@@ -0,0 +1,346 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.util
+
+import java.math.BigDecimal
+
+import org.apache.calcite.rex.{RexBuilder, RexProgramBuilder}
+import org.apache.calcite.sql.SqlPostfixOperator
+import org.apache.calcite.sql.fun.SqlStdOperatorTable
+import org.apache.flink.table.expressions.{Expression, ExpressionParser}
+import org.apache.flink.table.validate.FunctionCatalog
+import org.junit.Assert.{assertArrayEquals, assertEquals}
+import org.junit.Test
+
+import scala.collection.JavaConverters._
+
+class RexProgramExtractorTest extends RexProgramTestBase {
+
+ private val functionCatalog: FunctionCatalog = FunctionCatalog.withBuiltIns
+
+ @Test
+ def testExtractRefInputFields(): Unit = {
+ val usedFields = RexProgramExtractor.extractRefInputFields(buildSimpleRexProgram())
+ assertArrayEquals(usedFields, Array(2, 3, 1))
+ }
+
+ @Test
+ def testExtractSimpleCondition(): Unit = {
+ val builder: RexBuilder = new RexBuilder(typeFactory)
+ val program = buildSimpleRexProgram()
+
+ val firstExp = ExpressionParser.parseExpression("id > 6")
+ val secondExp = ExpressionParser.parseExpression("amount * price < 100")
+ val expected: Array[Expression] = Array(firstExp, secondExp)
+
+ val (convertedExpressions, unconvertedRexNodes) =
+ RexProgramExtractor.extractConjunctiveConditions(
+ program,
+ builder,
+ functionCatalog)
+
+ assertExpressionArrayEquals(expected, convertedExpressions)
+ assertEquals(0, unconvertedRexNodes.length)
+ }
+
+ @Test
+ def testExtractSingleCondition(): Unit = {
+ val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
+ val builder = new RexProgramBuilder(inputRowType, rexBuilder)
+
+ // amount
+ val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
+ // id
+ val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
+
+ // a = amount >= id
+ val a = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, t0, t1))
+ builder.addCondition(a)
+
+ val program = builder.getProgram
+ val relBuilder: RexBuilder = new RexBuilder(typeFactory)
+ val (convertedExpressions, unconvertedRexNodes) =
+ RexProgramExtractor.extractConjunctiveConditions(
+ program,
+ relBuilder,
+ functionCatalog)
+
+ val expected: Array[Expression] = Array(ExpressionParser.parseExpression("amount >= id"))
+ assertExpressionArrayEquals(expected, convertedExpressions)
+ assertEquals(0, unconvertedRexNodes.length)
+ }
+
+ // ((a AND b) OR c) AND (NOT d) => (a OR c) AND (b OR c) AND (NOT d)
+ @Test
+ def testExtractCnfCondition(): Unit = {
+ val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
+ val builder = new RexProgramBuilder(inputRowType, rexBuilder)
+
+ // amount
+ val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
+ // id
+ val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
+ // price
+ val t2 = rexBuilder.makeInputRef(allFieldTypes.get(3), 3)
+ // 100
+ val t3 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+
+ // a = amount < 100
+ val a = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t0, t3))
+ // b = id > 100
+ val b = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t3))
+ // c = price == 100
+ val c = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, t2, t3))
+ // d = amount <= id
+ val d = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, t0, t1))
+
+ // a AND b
+ val and = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(a, b).asJava))
+ // (a AND b) OR c
+ val or = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.OR, List(and, c).asJava))
+ // not d
+ val not = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.NOT, List(d).asJava))
+
+ // ((a AND b) OR c) AND (NOT d)
+ builder.addCondition(builder.addExpr(
+ rexBuilder.makeCall(SqlStdOperatorTable.AND, List(or, not).asJava)))
+
+ val program = builder.getProgram
+ val relBuilder: RexBuilder = new RexBuilder(typeFactory)
+ val (convertedExpressions, unconvertedRexNodes) =
+ RexProgramExtractor.extractConjunctiveConditions(
+ program,
+ relBuilder,
+ functionCatalog)
+
+ val expected: Array[Expression] = Array(
+ ExpressionParser.parseExpression("amount < 100 || price == 100"),
+ ExpressionParser.parseExpression("id > 100 || price == 100"),
+ ExpressionParser.parseExpression("!(amount <= id)"))
+ assertExpressionArrayEquals(expected, convertedExpressions)
+ assertEquals(0, unconvertedRexNodes.length)
+ }
+
+  /**
+   * Checks that comparison predicates and arithmetic expressions wrapped in an equality
+   * are all converted into expressions, leaving no unconverted RexNode behind.
+   */
+  @Test
+  def testExtractArithmeticConditions(): Unit = {
+    val rowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
+    val programBuilder = new RexProgramBuilder(rowType, rexBuilder)
+
+    // operands shared by all predicates: field 2 (amount), field 1 (id) and the literal 100
+    val amount = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
+    val id = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
+    val hundred = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+
+    // amount <op> id for <, <=, <>, ==, >=, > (registered in exactly this order)
+    val comparisons = List(
+      SqlStdOperatorTable.LESS_THAN,
+      SqlStdOperatorTable.LESS_THAN_OR_EQUAL,
+      SqlStdOperatorTable.NOT_EQUALS,
+      SqlStdOperatorTable.EQUALS,
+      SqlStdOperatorTable.GREATER_THAN_OR_EQUAL,
+      SqlStdOperatorTable.GREATER_THAN
+    ).map { op =>
+      programBuilder.addExpr(rexBuilder.makeCall(op, amount, id))
+    }
+
+    // (amount <op> id) == 100 for +, -, *, /
+    val binaryArithmetic = List(
+      SqlStdOperatorTable.PLUS,
+      SqlStdOperatorTable.MINUS,
+      SqlStdOperatorTable.MULTIPLY,
+      SqlStdOperatorTable.DIVIDE
+    ).map { op =>
+      programBuilder.addExpr(rexBuilder.makeCall(
+        SqlStdOperatorTable.EQUALS,
+        rexBuilder.makeCall(op, amount, id),
+        hundred))
+    }
+
+    // -amount == 100
+    val negation = programBuilder.addExpr(rexBuilder.makeCall(
+      SqlStdOperatorTable.EQUALS,
+      rexBuilder.makeCall(SqlStdOperatorTable.UNARY_MINUS, amount),
+      hundred))
+
+    // the program condition is the conjunction of all predicates above
+    val conjuncts = ((comparisons ++ binaryArithmetic) :+ negation).asJava
+    programBuilder.addCondition(
+      programBuilder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, conjuncts)))
+
+    val (convertedExpressions, unconvertedRexNodes) =
+      RexProgramExtractor.extractConjunctiveConditions(
+        programBuilder.getProgram,
+        new RexBuilder(typeFactory),
+        functionCatalog)
+
+    val expected: Array[Expression] = Array(
+      "amount < id",
+      "amount <= id",
+      "amount <> id",
+      "amount == id",
+      "amount >= id",
+      "amount > id",
+      "amount + id == 100",
+      "amount - id == 100",
+      "amount * id == 100",
+      "amount / id == 100",
+      "-amount == 100"
+    ).map(text => ExpressionParser.parseExpression(text))
+
+    assertExpressionArrayEquals(expected, convertedExpressions)
+    assertEquals(0, unconvertedRexNodes.length)
+  }
+
+  /**
+   * Runs the single-condition check once per supported postfix operator on input field 4
+   * (rendered as 'flag in the expected expression strings).
+   */
+  @Test
+  def testExtractPostfixConditions(): Unit = {
+    val flagIndex = 4
+
+    // IS_NOT_NULL is deliberately not covered: it will be eliminated since flag is not nullable
+    val cases = Seq(
+      SqlStdOperatorTable.IS_NULL -> "('flag).isNull",
+      SqlStdOperatorTable.IS_TRUE -> "('flag).isTrue",
+      SqlStdOperatorTable.IS_NOT_TRUE -> "('flag).isNotTrue",
+      SqlStdOperatorTable.IS_FALSE -> "('flag).isFalse",
+      SqlStdOperatorTable.IS_NOT_FALSE -> "('flag).isNotFalse")
+
+    cases.foreach { case (operator, rendered) =>
+      testExtractSinglePostfixCondition(flagIndex, operator, rendered)
+    }
+  }
+
+  /**
+   * Checks that predicates whose operands are function calls (SUM, MIN) are converted
+   * into expressions, leaving no unconverted RexNode behind.
+   */
+  @Test
+  def testExtractConditionWithFunctionCalls(): Unit = {
+    val rowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
+    val programBuilder = new RexProgramBuilder(rowType, rexBuilder)
+
+    // operands: field 2 (amount), field 1 (id) and the literal 100
+    val amount = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
+    val id = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
+    val hundred = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+
+    // sum(amount) > 100
+    val sumAboveHundred = programBuilder.addExpr(
+      rexBuilder.makeCall(
+        SqlStdOperatorTable.GREATER_THAN,
+        rexBuilder.makeCall(SqlStdOperatorTable.SUM, amount),
+        hundred))
+
+    // min(id) == 100
+    val minEqualsHundred = programBuilder.addExpr(
+      rexBuilder.makeCall(
+        SqlStdOperatorTable.EQUALS,
+        rexBuilder.makeCall(SqlStdOperatorTable.MIN, id),
+        hundred))
+
+    // condition: sum(amount) > 100 AND min(id) == 100
+    val conjunction =
+      rexBuilder.makeCall(SqlStdOperatorTable.AND, sumAboveHundred, minEqualsHundred)
+    programBuilder.addCondition(programBuilder.addExpr(conjunction))
+
+    val (convertedExpressions, unconvertedRexNodes) =
+      RexProgramExtractor.extractConjunctiveConditions(
+        programBuilder.getProgram,
+        new RexBuilder(typeFactory),
+        functionCatalog)
+
+    val expected: Array[Expression] = Array(
+      "sum(amount) > 100",
+      "min(id) == 100"
+    ).map(text => ExpressionParser.parseExpression(text))
+
+    assertExpressionArrayEquals(expected, convertedExpressions)
+    assertEquals(0, unconvertedRexNodes.length)
+  }
+
+  /**
+   * Checks that conjuncts containing an unsupported node (here: a CAST) are returned as
+   * unconverted RexNodes while the remaining conjunct is converted into an expression.
+   */
+  @Test
+  def testExtractWithUnsupportedConditions(): Unit = {
+    val rowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
+    val programBuilder = new RexProgramBuilder(rowType, rexBuilder)
+
+    // operands: field 2 (amount), field 1 (id) and the literal 100
+    val amount = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
+    val id = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
+    val hundred = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+
+    // unsupported now: amount cast to the type of field 1 (rendered as BIGINT below)
+    val amountAsBigint =
+      programBuilder.addExpr(rexBuilder.makeCast(allFieldTypes.get(1), amount))
+
+    // unsupported now: CAST(amount) > 100
+    val castAboveHundred = programBuilder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, amountAsBigint, hundred))
+
+    // supported: amount <= id
+    val amountAtMostId = programBuilder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, amount, id))
+
+    // contains an unsupported operand: (CAST(amount) > 100 OR amount <= id)
+    val eitherOne = programBuilder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.OR, castAboveHundred, amountAtMostId))
+
+    // of the three conjuncts only "amount <= id" can be translated
+    val conjunction = rexBuilder.makeCall(
+      SqlStdOperatorTable.AND, castAboveHundred, amountAtMostId, eitherOne)
+    programBuilder.addCondition(conjunction)
+
+    val (convertedExpressions, unconvertedRexNodes) =
+      RexProgramExtractor.extractConjunctiveConditions(
+        programBuilder.getProgram,
+        new RexBuilder(typeFactory),
+        functionCatalog)
+
+    val expected: Array[Expression] =
+      Array(ExpressionParser.parseExpression("amount <= id"))
+    assertExpressionArrayEquals(expected, convertedExpressions)
+
+    // both conjuncts that involve the cast stay behind, in their original order
+    assertEquals(2, unconvertedRexNodes.length)
+    assertEquals(
+      ">(CAST($2):BIGINT NOT NULL, 100)",
+      unconvertedRexNodes(0).toString)
+    assertEquals(
+      "OR(>(CAST($2):BIGINT NOT NULL, 100), <=($2, $1))",
+      unconvertedRexNodes(1).toString)
+  }
+
+  /**
+   * Builds a program whose only condition applies `op` to one input field and checks that
+   * the condition is converted into exactly one expression.
+   *
+   * @param fieldIndex index of the input field the operator is applied to
+   * @param op         postfix operator under test
+   * @param expr       expected string rendering of the converted expression
+   */
+  private def testExtractSinglePostfixCondition(
+      fieldIndex: Integer,
+      op: SqlPostfixOperator,
+      expr: String) : Unit = {
+
+    val rowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
+    val programBuilder = new RexProgramBuilder(rowType, rexBuilder)
+    // NOTE(review): the shared rexBuilder is replaced after the program builder was created
+    // from the previous instance; the reason is not visible here - confirm it is intended
+    rexBuilder = new RexBuilder(typeFactory)
+
+    // condition: <field> <op>
+    val fieldRef = rexBuilder.makeInputRef(allFieldTypes.get(fieldIndex), fieldIndex)
+    val postfixCall = rexBuilder.makeCall(op, fieldRef)
+    programBuilder.addCondition(programBuilder.addExpr(postfixCall))
+
+    val (convertedExpressions, unconvertedRexNodes) =
+      RexProgramExtractor.extractConjunctiveConditions(
+        programBuilder.getProgram(false),
+        new RexBuilder(typeFactory),
+        functionCatalog)
+
+    assertEquals(1, convertedExpressions.length)
+    assertEquals(expr, convertedExpressions.head.toString)
+    assertEquals(0, unconvertedRexNodes.length)
+  }
+
+  /**
+   * Asserts that both arrays hold the same expressions regardless of order, comparing
+   * the expressions by their string representation.
+   *
+   * @param expected the expressions that should have been produced
+   * @param actual   the expressions that were produced
+   */
+  private def assertExpressionArrayEquals(
+      expected: Array[Expression],
+      actual: Array[Expression]): Unit = {
+    // order is irrelevant, so compare the sorted string renderings position by position
+    val expectedStrings = expected.map(_.toString).sorted
+    val actualStrings = actual.map(_.toString).sorted
+
+    assertEquals(expectedStrings.length, actualStrings.length)
+    expectedStrings.indices.foreach { i =>
+      assertEquals(expectedStrings(i), actualStrings(i))
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramProjectExtractorTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramProjectExtractorTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramProjectExtractorTest.scala
deleted file mode 100644
index cea9eee..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramProjectExtractorTest.scala
+++ /dev/null
@@ -1,121 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.plan.util
-
-import java.math.BigDecimal
-
-import org.apache.calcite.adapter.java.JavaTypeFactory
-import org.apache.calcite.jdbc.JavaTypeFactoryImpl
-import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeSystem}
-import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder}
-import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, DOUBLE, INTEGER, VARCHAR}
-import org.apache.calcite.sql.fun.SqlStdOperatorTable
-import org.apache.flink.table.plan.util.RexProgramProjectExtractor._
-import org.junit.Assert.{assertArrayEquals, assertTrue}
-import org.junit.{Before, Test}
-
-import scala.collection.JavaConverters._
-
-/**
- * This class is responsible for testing RexProgramProjectExtractor.
- */
-class RexProgramProjectExtractorTest {
- private var typeFactory: JavaTypeFactory = _
- private var rexBuilder: RexBuilder = _
- private var allFieldTypes: Seq[RelDataType] = _
- private val allFieldNames = List("name", "id", "amount", "price")
-
- @Before
- def setUp(): Unit = {
- typeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT)
- rexBuilder = new RexBuilder(typeFactory)
- allFieldTypes = List(VARCHAR, BIGINT, INTEGER, DOUBLE).map(typeFactory.createSqlType(_))
- }
-
- @Test
- def testExtractRefInputFields(): Unit = {
- val usedFields = extractRefInputFields(buildRexProgram())
- assertArrayEquals(usedFields, Array(2, 3, 1))
- }
-
- @Test
- def testRewriteRexProgram(): Unit = {
- val originRexProgram = buildRexProgram()
- assertTrue(extractExprStrList(originRexProgram).sameElements(Array(
- "$0",
- "$1",
- "$2",
- "$3",
- "*($t2, $t3)",
- "100",
- "<($t4, $t5)",
- "6",
- ">($t1, $t7)",
- "AND($t6, $t8)")))
- // use amount, id, price fields to create a new RexProgram
- val usedFields = Array(2, 3, 1)
- val types = usedFields.map(allFieldTypes(_)).toList.asJava
- val names = usedFields.map(allFieldNames(_)).toList.asJava
- val inputRowType = typeFactory.createStructType(types, names)
- val newRexProgram = rewriteRexProgram(originRexProgram, inputRowType, usedFields, rexBuilder)
- assertTrue(extractExprStrList(newRexProgram).sameElements(Array(
- "$0",
- "$1",
- "$2",
- "*($t0, $t1)",
- "100",
- "<($t3, $t4)",
- "6",
- ">($t2, $t6)",
- "AND($t5, $t7)")))
- }
-
- private def buildRexProgram(): RexProgram = {
- val types = allFieldTypes.asJava
- val names = allFieldNames.asJava
- val inputRowType = typeFactory.createStructType(types, names)
- val builder = new RexProgramBuilder(inputRowType, rexBuilder)
- val t0 = rexBuilder.makeInputRef(types.get(2), 2)
- val t1 = rexBuilder.makeInputRef(types.get(1), 1)
- val t2 = rexBuilder.makeInputRef(types.get(3), 3)
- val t3 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, t0, t2))
- val t4 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
- val t5 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(6L))
- // project: amount, amount * price
- builder.addProject(t0, "amount")
- builder.addProject(t3, "total")
- // condition: amount * price < 100 and id > 6
- val t6 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t3, t4))
- val t7 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t5))
- val t8 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(t6, t7).asJava))
- builder.addCondition(t8)
- builder.getProgram
- }
-
- /**
- * extract all expression string list from input RexProgram expression lists
- *
- * @param rexProgram input RexProgram instance to analyze
- * @return all expression string list of input RexProgram expression lists
- */
- private def extractExprStrList(rexProgram: RexProgram) = {
- rexProgram.getExprList.asScala.map(_.toString)
- }
-
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramRewriterTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramRewriterTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramRewriterTest.scala
new file mode 100644
index 0000000..899eed2
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramRewriterTest.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.util
+
+import org.junit.Assert.assertTrue
+import org.junit.Test
+
+import scala.collection.JavaConverters._
+
+/**
+ * Tests [[RexProgramRewriter.rewriteWithFieldProjection]] on the simple program provided by
+ * [[RexProgramTestBase]].
+ */
+class RexProgramRewriterTest extends RexProgramTestBase {
+
+  /**
+   * Rewrites the simple program against an input row that only contains the fields
+   * amount, price and id (in that order) and checks the resulting expression list.
+   */
+  @Test
+  def testRewriteRexProgram(): Unit = {
+    val rexProgram = buildSimpleRexProgram()
+
+    // sanity check of the original program; the message shows the actual list on failure
+    val originalExprs = extractExprStrList(rexProgram)
+    assertTrue(
+      s"unexpected expressions in original program: $originalExprs",
+      originalExprs == wrapRefArray(Array(
+        "$0",
+        "$1",
+        "$2",
+        "$3",
+        "$4",
+        "*($t2, $t3)",
+        "100",
+        "<($t5, $t6)",
+        "6",
+        ">($t1, $t8)",
+        "AND($t7, $t9)")))
+
+    // use amount, id, price fields to create a new RexProgram
+    val usedFields = Array(2, 3, 1)
+    val types = usedFields.map(allFieldTypes.get).toList.asJava
+    val names = usedFields.map(allFieldNames.get).toList.asJava
+    val inputRowType = typeFactory.createStructType(types, names)
+    val newRexProgram = RexProgramRewriter.rewriteWithFieldProjection(
+      rexProgram, inputRowType, rexBuilder, usedFields)
+
+    // input references must now point at positions within the projected row
+    val rewrittenExprs = extractExprStrList(newRexProgram)
+    assertTrue(
+      s"unexpected expressions in rewritten program: $rewrittenExprs",
+      rewrittenExprs == wrapRefArray(Array(
+        "$0",
+        "$1",
+        "$2",
+        "*($t0, $t1)",
+        "100",
+        "<($t3, $t4)",
+        "6",
+        ">($t2, $t6)",
+        "AND($t5, $t7)")))
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramTestBase.scala
new file mode 100644
index 0000000..6ef3d82
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramTestBase.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.util
+
+import java.math.BigDecimal
+import java.util
+
+import org.apache.calcite.adapter.java.JavaTypeFactory
+import org.apache.calcite.jdbc.JavaTypeFactoryImpl
+import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeSystem}
+import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder}
+import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, DOUBLE, INTEGER, VARCHAR, BOOLEAN}
+import org.apache.calcite.sql.fun.SqlStdOperatorTable
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+/**
+ * Shared fixture for tests that operate on a [[RexProgram]]: a type factory, the names and
+ * types of a five-field input row, a [[RexBuilder]] and helpers to build and inspect a
+ * simple program.
+ */
+abstract class RexProgramTestBase {
+
+  val typeFactory: JavaTypeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT)
+
+  val allFieldNames: util.List[String] = List("name", "id", "amount", "price", "flag").asJava
+
+  // one SQL type per entry of allFieldNames, in the same order
+  val allFieldTypes: util.List[RelDataType] =
+    List(VARCHAR, BIGINT, INTEGER, DOUBLE, BOOLEAN)
+      .map(sqlType => typeFactory.createSqlType(sqlType))
+      .asJava
+
+  var rexBuilder: RexBuilder = new RexBuilder(typeFactory)
+
+  /**
+   * Renders every entry of the program's expression list as a string.
+   *
+   * @param rexProgram input RexProgram instance to analyze
+   * @return the string form of each expression, in expression-list order
+   */
+  protected def extractExprStrList(rexProgram: RexProgram): mutable.Buffer[String] = {
+    for (expr <- rexProgram.getExprList.asScala) yield expr.toString
+  }
+
+  /**
+   * Builds the program for
+   * `select amount, amount * price as total where amount * price < 100 and id > 6`.
+   *
+   * The statement order is significant: tests assert on the order in which expressions
+   * end up in the program's expression list.
+   */
+  protected def buildSimpleRexProgram(): RexProgram = {
+    val rowType = typeFactory.createStructType(allFieldTypes, allFieldNames)
+    val programBuilder = new RexProgramBuilder(rowType, rexBuilder)
+
+    // input references to amount (2), id (1) and price (3)
+    val amount = rexBuilder.makeInputRef(allFieldTypes.get(2), 2)
+    val id = rexBuilder.makeInputRef(allFieldTypes.get(1), 1)
+    val price = rexBuilder.makeInputRef(allFieldTypes.get(3), 3)
+    // amount * price
+    val total =
+      programBuilder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, amount, price))
+    val hundred = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+    val six = rexBuilder.makeExactLiteral(BigDecimal.valueOf(6L))
+
+    // project: amount, amount * price as total
+    programBuilder.addProject(amount, "amount")
+    programBuilder.addProject(total, "total")
+
+    // condition: amount * price < 100 and id > 6
+    val totalBelowHundred =
+      programBuilder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, total, hundred))
+    val idAboveSix =
+      programBuilder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, id, six))
+    val conjunction = programBuilder.addExpr(
+      rexBuilder.makeCall(SqlStdOperatorTable.AND, List(totalBelowHundred, idAboveSix).asJava))
+    programBuilder.addCondition(conjunction)
+
+    programBuilder.getProgram
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/CommonTestData.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/CommonTestData.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/CommonTestData.scala
index a720f02..2364f23 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/CommonTestData.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/CommonTestData.scala
@@ -21,21 +21,11 @@ package org.apache.flink.table.utils
import java.io.{File, FileOutputStream, OutputStreamWriter}
import java.util
-import org.apache.flink.api.java.typeutils.TypeExtractor
-import org.apache.flink.table.sources.{BatchTableSource, CsvTableSource}
-import org.apache.calcite.tools.RuleSet
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
-import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.java.typeutils.TypeExtractor
import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
-import org.apache.flink.streaming.api.datastream.DataStream
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
-import org.apache.flink.table.api.{Table, TableConfig, TableEnvironment}
-import org.apache.flink.table.expressions._
-import org.apache.flink.table.sinks.TableSink
-import org.apache.flink.table.sources._
-import org.apache.flink.types.Row
-
-import scala.collection.JavaConverters._
+import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.sources.{BatchTableSource, CsvTableSource}
object CommonTestData {
@@ -108,110 +98,4 @@ object CommonTestData {
def getMockTableEnvironment: TableEnvironment = new MockTableEnvironment
- def getFilterableTableSource = new TestFilterableTableSource
-}
-
-class MockTableEnvironment extends TableEnvironment(new TableConfig) {
-
- override private[flink] def writeToSink[T](table: Table, sink: TableSink[T]): Unit = ???
-
- override protected def checkValidTableName(name: String): Unit = ???
-
- override def sql(query: String): Table = ???
-
- override def registerTableSource(name: String, tableSource: TableSource[_]): Unit = ???
-
- override protected def getBuiltInNormRuleSet: RuleSet = ???
-
- override protected def getBuiltInOptRuleSet: RuleSet = ???
-}
-
-class TestFilterableTableSource
- extends BatchTableSource[Row]
- with StreamTableSource[Row]
- with FilterableTableSource
- with DefinedFieldNames {
-
- import org.apache.flink.table.api.Types._
-
- val fieldNames = Array("name", "id", "amount", "price")
- val fieldTypes = Array[TypeInformation[_]](STRING, LONG, INT, DOUBLE)
-
- private var filterLiteral: Literal = _
- private var filterPredicates: Array[Expression] = Array.empty
-
- /** Returns the data of the table as a [[DataSet]]. */
- override def getDataSet(execEnv: ExecutionEnvironment): DataSet[Row] = {
- execEnv.fromCollection[Row](generateDynamicCollection(33).asJava, getReturnType)
- }
-
- /** Returns the data of the table as a [[DataStream]]. */
- def getDataStream(execEnv: StreamExecutionEnvironment): DataStream[Row] = {
- execEnv.fromCollection[Row](generateDynamicCollection(33).asJava, getReturnType)
- }
-
- private def generateDynamicCollection(num: Int): Seq[Row] = {
-
- if (filterLiteral == null) {
- throw new RuntimeException("filter expression was not set")
- }
-
- val filterValue = filterLiteral.value.asInstanceOf[Number].intValue()
-
- def shouldCreateRow(value: Int): Boolean = {
- value > filterValue
- }
-
- for {
- cnt <- 0 until num
- if shouldCreateRow(cnt)
- } yield {
- val row = new Row(fieldNames.length)
- fieldNames.zipWithIndex.foreach { case (name, index) =>
- name match {
- case "name" =>
- row.setField(index, s"Record_$cnt")
- case "id" =>
- row.setField(index, cnt.toLong)
- case "amount" =>
- row.setField(index, cnt.toInt)
- case "price" =>
- row.setField(index, cnt.toDouble)
- }
- }
- row
- }
- }
-
- /** Returns the [[TypeInformation]] for the return type. */
- override def getReturnType: TypeInformation[Row] = new RowTypeInfo(fieldTypes, fieldNames)
-
- /** Returns the names of the table fields. */
- override def getFieldNames: Array[String] = fieldNames
-
- /** Returns the indices of the table fields. */
- override def getFieldIndices: Array[Int] = fieldNames.indices.toArray
-
- override def getPredicate: Array[Expression] = filterPredicates
-
- /** Return an unsupported predicates expression. */
- override def setPredicate(predicates: Array[Expression]): Array[Expression] = {
- predicates(0) match {
- case gt: GreaterThan =>
- gt.left match {
- case f: ResolvedFieldReference =>
- gt.right match {
- case l: Literal =>
- if (f.name.equals("amount")) {
- filterLiteral = l
- filterPredicates = Array(predicates(0))
- Array(predicates(1))
- } else predicates
- case _ => predicates
- }
- case _ => predicates
- }
- case _ => predicates
- }
- }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/MockTableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/MockTableEnvironment.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/MockTableEnvironment.scala
new file mode 100644
index 0000000..6a86ace
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/MockTableEnvironment.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.utils
+
+import org.apache.calcite.tools.RuleSet
+import org.apache.flink.table.api.{Table, TableConfig, TableEnvironment}
+import org.apache.flink.table.sinks.TableSink
+import org.apache.flink.table.sources.TableSource
+
+/**
+ * Stub [[TableEnvironment]] for tests, constructed with a fresh [[TableConfig]].
+ *
+ * Every member overridden here is implemented as `???`, so invoking any of them throws a
+ * [[scala.NotImplementedError]].
+ */
+class MockTableEnvironment extends TableEnvironment(new TableConfig) {
+
+ // not implemented: throws NotImplementedError when called
+ override private[flink] def writeToSink[T](table: Table, sink: TableSink[T]): Unit = ???
+
+ // not implemented: throws NotImplementedError when called
+ override protected def checkValidTableName(name: String): Unit = ???
+
+ // not implemented: throws NotImplementedError when called
+ override def sql(query: String): Table = ???
+
+ // not implemented: throws NotImplementedError when called
+ override def registerTableSource(name: String, tableSource: TableSource[_]): Unit = ???
+
+ // not implemented: throws NotImplementedError when called
+ override protected def getBuiltInNormRuleSet: RuleSet = ???
+
+ // not implemented: throws NotImplementedError when called
+ override protected def getBuiltInOptRuleSet: RuleSet = ???
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/78f22aae/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TestFilterableTableSource.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TestFilterableTableSource.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TestFilterableTableSource.scala
new file mode 100644
index 0000000..dcf2acd
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TestFilterableTableSource.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.utils
+
+import java.util.{List => JList}
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
+import org.apache.flink.streaming.api.datastream.DataStream
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
+import org.apache.flink.table.api.Types._
+import org.apache.flink.table.expressions._
+import org.apache.flink.table.sources.{BatchTableSource, FilterableTableSource, StreamTableSource, TableSource}
+import org.apache.flink.types.Row
+import org.apache.flink.util.Preconditions
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+/**
+ * This source can only handle simple comparison with field "amount".
+ * Supports ">, <, >=, <=, =, <>" with an integer.
+ *
+ * @param recordNum number of candidate rows; rows are generated for the counter values
+ *                  `0 until recordNum` that satisfy all pushed-down predicates
+ */
+class TestFilterableTableSource(
+    val recordNum: Int = 33)
+  extends BatchTableSource[Row]
+  with StreamTableSource[Row]
+  with FilterableTableSource[Row] {
+
+  // set to true on the copy returned by applyPredicate
+  var filterPushedDown: Boolean = false
+
+  val fieldNames: Array[String] = Array("name", "id", "amount", "price")
+
+  val fieldTypes: Array[TypeInformation[_]] = Array(STRING, LONG, INT, DOUBLE)
+
+  // all predicates with field "amount"
+  private val filterPredicates = new mutable.ArrayBuffer[Expression]
+
+  // all comparing values for field "amount"; entry i belongs to filterPredicates(i)
+  private val filterValues = new mutable.ArrayBuffer[Int]
+
+  /** Returns the generated (and already filtered) rows as a [[DataSet]]. */
+  override def getDataSet(execEnv: ExecutionEnvironment): DataSet[Row] = {
+    execEnv.fromCollection[Row](generateDynamicCollection().asJava, getReturnType)
+  }
+
+  /** Returns the generated (and already filtered) rows as a [[DataStream]]. */
+  override def getDataStream(execEnv: StreamExecutionEnvironment): DataStream[Row] = {
+    execEnv.fromCollection[Row](generateDynamicCollection().asJava, getReturnType)
+  }
+
+  /**
+   * Describes the pushed-down predicates as a single conjunction, or returns an empty
+   * string if no predicate was pushed down.
+   */
+  override def explainSource(): String = {
+    if (filterPredicates.nonEmpty) {
+      s"filter=[${filterPredicates.reduce((l, r) => And(l, r)).toString}]"
+    } else {
+      ""
+    }
+  }
+
+  override def getReturnType: TypeInformation[Row] = new RowTypeInfo(fieldTypes, fieldNames)
+
+  /**
+   * Creates a copy of this source that evaluates the predicates it can handle.
+   *
+   * Accepted predicates are binary comparisons whose left side is the field "amount" and
+   * whose right side is a literal. They are removed from `predicates`; every other
+   * predicate is left in the list untouched.
+   *
+   * @param predicates mutable list of candidate predicates
+   * @return a new source with `filterPushedDown` set and the accepted predicates applied
+   */
+  override def applyPredicate(predicates: JList[Expression]): TableSource[Row] = {
+    val newSource = new TestFilterableTableSource(recordNum)
+    newSource.filterPushedDown = true
+
+    val iterator = predicates.iterator()
+    while (iterator.hasNext) {
+      iterator.next() match {
+        case expr: BinaryComparison =>
+          (expr.left, expr.right) match {
+            case (f: ResolvedFieldReference, v: Literal) if f.name.equals("amount") =>
+              newSource.filterPredicates += expr
+              newSource.filterValues += v.value.asInstanceOf[Number].intValue()
+              iterator.remove()
+            case (_, _) =>
+          }
+        // anything that is not a binary comparison is not handled by this source and stays
+        // in the list; without this case such a predicate would raise a MatchError
+        case _ =>
+      }
+    }
+
+    newSource
+  }
+
+  override def isFilterPushedDown: Boolean = filterPushedDown
+
+  /**
+   * Generates one row per counter value in `0 until recordNum` that passes all pushed-down
+   * predicates. All four fields of a row are derived from the counter value.
+   */
+  private def generateDynamicCollection(): Seq[Row] = {
+    Preconditions.checkArgument(filterPredicates.length == filterValues.length)
+
+    for {
+      cnt <- 0 until recordNum
+      if shouldCreateRow(cnt)
+    } yield {
+      Row.of(
+        s"Record_$cnt",
+        cnt.toLong.asInstanceOf[Object],
+        cnt.toInt.asInstanceOf[Object],
+        cnt.toDouble.asInstanceOf[Object])
+    }
+  }
+
+  /**
+   * Returns true if `value` satisfies every pushed-down predicate.
+   *
+   * Throws a RuntimeException for a stored predicate that is none of the six supported
+   * comparison types.
+   */
+  private def shouldCreateRow(value: Int): Boolean = {
+    filterPredicates.zip(filterValues).forall {
+      case (_: GreaterThan, v) =>
+        value > v
+      case (_: LessThan, v) =>
+        value < v
+      case (_: GreaterThanOrEqual, v) =>
+        value >= v
+      case (_: LessThanOrEqual, v) =>
+        value <= v
+      case (_: EqualTo, v) =>
+        value == v
+      case (_: NotEqualTo, v) =>
+        value != v
+      case (expr, _) =>
+        throw new RuntimeException(expr + " not supported!")
+    }
+  }
+}
+