You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by fh...@apache.org on 2016/12/13 13:34:20 UTC
[3/3] flink git commit: [FLINK-3848] [table] Add
ProjectableTableSource and push projections into BatchTableSourceScan.
[FLINK-3848] [table] Add ProjectableTableSource and push projections into BatchTableSourceScan.
This closes #2923.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/5baea3f2
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/5baea3f2
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/5baea3f2
Branch: refs/heads/master
Commit: 5baea3f2e13cd2b6d904c617092372f368f12b55
Parents: 5c86efb
Author: beyond1920 <be...@126.com>
Authored: Fri Dec 2 11:33:12 2016 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Tue Dec 13 14:13:18 2016 +0100
----------------------------------------------------------------------
.../nodes/dataset/BatchTableSourceScan.scala | 13 +-
.../api/table/plan/rules/FlinkRuleSets.scala | 4 +-
...ushProjectIntoBatchTableSourceScanRule.scala | 84 +++++++++++
.../rules/util/RexProgramProjectExtractor.scala | 120 +++++++++++++++
.../table/sources/ProjectableTableSource.scala | 38 +++++
.../batch/ProjectableTableSourceITCase.scala | 145 +++++++++++++++++++
.../util/RexProgramProjectExtractorTest.scala | 120 +++++++++++++++
7 files changed, 522 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/5baea3f2/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/BatchTableSourceScan.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/BatchTableSourceScan.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/BatchTableSourceScan.scala
index 14da862..e368219 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/BatchTableSourceScan.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/BatchTableSourceScan.scala
@@ -19,7 +19,8 @@
package org.apache.flink.api.table.plan.nodes.dataset
import org.apache.calcite.plan._
-import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.metadata.RelMetadataQuery
+import org.apache.calcite.rel.{RelNode, RelWriter}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.DataSet
import org.apache.flink.api.table.{BatchTableEnvironment, FlinkTypeFactory}
@@ -39,6 +40,11 @@ class BatchTableSourceScan(
flinkTypeFactory.buildRowDataType(tableSource.getFieldsNames, tableSource.getFieldTypes)
}
+ override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
+ val rowCnt = metadata.getRowCount(this)
+ planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * estimateRowSize(getRowType))
+ }
+
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
new BatchTableSourceScan(
cluster,
@@ -48,6 +54,11 @@ class BatchTableSourceScan(
)
}
+ override def explainTerms(pw: RelWriter): RelWriter = {
+ super.explainTerms(pw)
+ .item("fields", tableSource.getFieldsNames.mkString(", "))
+ }
+
override def translateToPlan(
tableEnv: BatchTableEnvironment,
expectedType: Option[TypeInformation[Any]]): DataSet[Any] = {
http://git-wip-us.apache.org/repos/asf/flink/blob/5baea3f2/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
index 6847425..183065c 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
@@ -109,7 +109,9 @@ object FlinkRuleSets {
DataSetSortRule.INSTANCE,
DataSetValuesRule.INSTANCE,
DataSetCorrelateRule.INSTANCE,
- BatchTableSourceScanRule.INSTANCE
+ BatchTableSourceScanRule.INSTANCE,
+ // project pushdown optimization
+ PushProjectIntoBatchTableSourceScanRule.INSTANCE
)
/**
http://git-wip-us.apache.org/repos/asf/flink/blob/5baea3f2/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala
new file mode 100644
index 0000000..301a45b
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.plan.rules.dataSet
+
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
+import org.apache.calcite.plan.RelOptRule.{none, operand}
+import org.apache.flink.api.table.plan.nodes.dataset.{BatchTableSourceScan, DataSetCalc}
+import org.apache.flink.api.table.plan.rules.util.RexProgramProjectExtractor._
+import org.apache.flink.api.table.sources.{BatchTableSource, ProjectableTableSource}
+
+/**
+ * This rule tries to push projections into a BatchTableSourceScan.
+ */
+class PushProjectIntoBatchTableSourceScanRule extends RelOptRule(
+ operand(classOf[DataSetCalc],
+ operand(classOf[BatchTableSourceScan], none)),
+ "PushProjectIntoBatchTableSourceScanRule") {
+
+ override def matches(call: RelOptRuleCall) = {
+ val scan: BatchTableSourceScan = call.rel(1).asInstanceOf[BatchTableSourceScan]
+ scan.tableSource match {
+ case _: ProjectableTableSource[_] => true
+ case _ => false
+ }
+ }
+
+ override def onMatch(call: RelOptRuleCall) {
+ val calc: DataSetCalc = call.rel(0).asInstanceOf[DataSetCalc]
+ val scan: BatchTableSourceScan = call.rel(1).asInstanceOf[BatchTableSourceScan]
+
+ val usedFields: Array[Int] = extractRefInputFields(calc.calcProgram)
+
+ // only transform the subtree if the projection can prune unused fields from the scan
+ if (scan.tableSource.getNumberOfFields != usedFields.length) {
+ val originTableSource = scan.tableSource.asInstanceOf[ProjectableTableSource[_]]
+ val newTableSource = originTableSource.projectFields(usedFields)
+ val newScan = new BatchTableSourceScan(
+ scan.getCluster,
+ scan.getTraitSet,
+ scan.getTable,
+ newTableSource.asInstanceOf[BatchTableSource[_]])
+
+ val newCalcProgram = rewriteRexProgram(
+ calc.calcProgram,
+ newScan.getRowType,
+ usedFields,
+ calc.getCluster.getRexBuilder)
+
+ // if the projection merely forwards its input and there is no filter, remove the DataSetCalc node entirely
+ if (newCalcProgram.isTrivial) {
+ call.transformTo(newScan)
+ } else {
+ val newCalc = new DataSetCalc(
+ calc.getCluster,
+ calc.getTraitSet,
+ newScan,
+ calc.getRowType,
+ newCalcProgram,
+ description)
+ call.transformTo(newCalc)
+ }
+ }
+ }
+}
+
+object PushProjectIntoBatchTableSourceScanRule {
+ val INSTANCE: RelOptRule = new PushProjectIntoBatchTableSourceScanRule
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/5baea3f2/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/util/RexProgramProjectExtractor.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/util/RexProgramProjectExtractor.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/util/RexProgramProjectExtractor.scala
new file mode 100644
index 0000000..d78e07f
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/util/RexProgramProjectExtractor.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.plan.rules.util
+
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rex._
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable
+import scala.collection.JavaConverters._
+
+object RexProgramProjectExtractor {
+
+ /**
+ * Extracts the indexes of input fields accessed by the RexProgram.
+ *
+ * @param rexProgram The RexProgram to analyze
+ * @return The indexes of accessed input fields
+ */
+ def extractRefInputFields(rexProgram: RexProgram): Array[Int] = {
+ val visitor = new RefFieldsVisitor
+ // extract input fields from project expressions
+ rexProgram.getProjectList.foreach(exp => rexProgram.expandLocalRef(exp).accept(visitor))
+ val condition = rexProgram.getCondition
+ // extract input fields from condition expression
+ if (condition != null) {
+ rexProgram.expandLocalRef(condition).accept(visitor)
+ }
+ visitor.getFields
+ }
+
+ /**
+ * Generates a new RexProgram based on mapped input fields.
+ *
+ * @param rexProgram original RexProgram
+ * @param inputRowType input row type
+ * @param usedInputFields indexes of used input fields
+ * @param rexBuilder builder for Rex expressions
+ *
+ * @return A RexProgram with mapped input field expressions.
+ */
+ def rewriteRexProgram(
+ rexProgram: RexProgram,
+ inputRowType: RelDataType,
+ usedInputFields: Array[Int],
+ rexBuilder: RexBuilder): RexProgram = {
+
+ val inputRewriter = new InputRewriter(usedInputFields)
+ val newProjectExpressions = rexProgram.getProjectList.map(
+ exp => rexProgram.expandLocalRef(exp).accept(inputRewriter)
+ ).toList.asJava
+
+ val oldCondition = rexProgram.getCondition
+ val newConditionExpression = {
+ oldCondition match {
+ case ref: RexLocalRef => rexProgram.expandLocalRef(ref).accept(inputRewriter)
+ case _ => null // the RexProgram has no condition (getCondition returned null)
+ }
+ }
+ RexProgram.create(
+ inputRowType,
+ newProjectExpressions,
+ newConditionExpression,
+ rexProgram.getOutputRowType,
+ rexBuilder
+ )
+ }
+}
+
+/**
+ * A RexVisitor to extract used input fields
+ */
+class RefFieldsVisitor extends RexVisitorImpl[Unit](true) {
+ private var fields = mutable.LinkedHashSet[Int]()
+
+ def getFields: Array[Int] = fields.toArray
+
+ override def visitInputRef(inputRef: RexInputRef): Unit = fields += inputRef.getIndex
+
+ override def visitCall(call: RexCall): Unit =
+ call.operands.foreach(operand => operand.accept(this))
+}
+
+/**
+ * A RexShuttle to rewrite field accesses of a RexProgram.
+ *
+ * @param fields fields mapping
+ */
+class InputRewriter(fields: Array[Int]) extends RexShuttle {
+
+ /** old input fields ref index -> new input fields ref index mappings */
+ private val fieldMap: Map[Int, Int] =
+ fields.zipWithIndex.toMap
+
+ override def visitInputRef(inputRef: RexInputRef): RexNode =
+ new RexInputRef(relNodeIndex(inputRef), inputRef.getType)
+
+ override def visitLocalRef(localRef: RexLocalRef): RexNode =
+ new RexInputRef(relNodeIndex(localRef), localRef.getType)
+
+ private def relNodeIndex(ref: RexSlot): Int =
+ fieldMap.getOrElse(ref.getIndex,
+ throw new IllegalArgumentException("input field contains invalid index"))
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/5baea3f2/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/sources/ProjectableTableSource.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/sources/ProjectableTableSource.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/sources/ProjectableTableSource.scala
new file mode 100644
index 0000000..c04138a
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/sources/ProjectableTableSource.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.sources
+
+/**
+ * Adds support for projection push-down to a [[TableSource]].
+ * A [[TableSource]] extending this interface is able to project the fields of the return table.
+ *
+ * @tparam T The return type of the [[ProjectableTableSource]].
+ */
+trait ProjectableTableSource[T] {
+
+ /**
+ * Creates a copy of the [[ProjectableTableSource]] that projects its output on the specified
+ * fields.
+ *
+ * @param fields The indexes of the fields to return.
+ * @return A copy of the [[ProjectableTableSource]] that projects its output.
+ */
+ def projectFields(fields: Array[Int]): ProjectableTableSource[T]
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/5baea3f2/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/ProjectableTableSourceITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/ProjectableTableSourceITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/ProjectableTableSourceITCase.scala
new file mode 100644
index 0000000..42b9de0
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/ProjectableTableSourceITCase.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.scala.batch
+
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.{DataSet => JavaSet, ExecutionEnvironment => JavaExecEnv}
+import org.apache.flink.api.scala.ExecutionEnvironment
+import org.apache.flink.api.scala.batch.utils.TableProgramsTestBase
+import org.apache.flink.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.table.sources.{BatchTableSource, ProjectableTableSource}
+import org.apache.flink.api.table.typeutils.RowTypeInfo
+import org.apache.flink.api.table.{Row, TableEnvironment}
+import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
+import org.apache.flink.test.util.TestBaseUtils
+import org.junit.{Before, Test}
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+
+import scala.collection.JavaConverters._
+
+@RunWith(classOf[Parameterized])
+class ProjectableTableSourceITCase(mode: TestExecutionMode,
+ configMode: TableConfigMode)
+ extends TableProgramsTestBase(mode, configMode) {
+
+ private val tableName = "MyTable"
+ private var tableEnv: BatchTableEnvironment = null
+
+ @Before
+ def initTableEnv(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ tableEnv = TableEnvironment.getTableEnvironment(env, config)
+ tableEnv.registerTableSource(tableName, new TestProjectableTableSource)
+ }
+
+ @Test
+ def testTableAPI(): Unit = {
+ val results = tableEnv
+ .scan(tableName)
+ .where("amount < 4")
+ .select("id, name")
+ .collect()
+
+ val expected = Seq(
+ "0,Record_0", "1,Record_1", "2,Record_2", "3,Record_3", "16,Record_16",
+ "17,Record_17", "18,Record_18", "19,Record_19", "32,Record_32").mkString("\n")
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+
+ @Test
+ def testSQL(): Unit = {
+ val results = tableEnv
+ .sql(s"select id, name from $tableName where amount < 4 ")
+ .collect()
+
+ val expected = Seq(
+ "0,Record_0", "1,Record_1", "2,Record_2", "3,Record_3", "16,Record_16",
+ "17,Record_17", "18,Record_18", "19,Record_19", "32,Record_32").mkString("\n")
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+}
+
+class TestProjectableTableSource(
+ fieldTypes: Array[TypeInformation[_]],
+ fieldNames: Array[String])
+ extends BatchTableSource[Row] with ProjectableTableSource[Row] {
+
+ def this() = this(
+ fieldTypes = Array(
+ BasicTypeInfo.STRING_TYPE_INFO,
+ BasicTypeInfo.LONG_TYPE_INFO,
+ BasicTypeInfo.INT_TYPE_INFO,
+ BasicTypeInfo.DOUBLE_TYPE_INFO),
+ fieldNames = Array[String]("name", "id", "amount", "price")
+ )
+
+ /** Returns the data of the table as a [[org.apache.flink.api.java.DataSet]]. */
+ override def getDataSet(execEnv: JavaExecEnv): JavaSet[Row] = {
+ execEnv.fromCollection(generateDynamicCollection(33, fieldNames).asJava, getReturnType)
+ }
+
+ /** Returns the types of the table fields. */
+ override def getFieldTypes: Array[TypeInformation[_]] = fieldTypes
+
+ /** Returns the names of the table fields. */
+ override def getFieldsNames: Array[String] = fieldNames
+
+ /** Returns the [[TypeInformation]] for the return type. */
+ override def getReturnType: TypeInformation[Row] = new RowTypeInfo(fieldTypes)
+
+ /** Returns the number of fields of the table. */
+ override def getNumberOfFields: Int = fieldNames.length
+
+ override def projectFields(fields: Array[Int]): TestProjectableTableSource = {
+ val projectedFieldTypes = new Array[TypeInformation[_]](fields.length)
+ val projectedFieldNames = new Array[String](fields.length)
+
+ fields.zipWithIndex.foreach(f => {
+ projectedFieldTypes(f._2) = fieldTypes(f._1)
+ projectedFieldNames(f._2) = fieldNames(f._1)
+ })
+ new TestProjectableTableSource(projectedFieldTypes, projectedFieldNames)
+ }
+
+ private def generateDynamicCollection(num: Int, fieldNames: Array[String]): Seq[Row] = {
+ for {cnt <- 0 until num}
+ yield {
+ val row = new Row(fieldNames.length)
+ fieldNames.zipWithIndex.foreach(
+ f =>
+ f._1 match {
+ case "name" =>
+ row.setField(f._2, "Record_" + cnt)
+ case "id" =>
+ row.setField(f._2, cnt.toLong)
+ case "amount" =>
+ row.setField(f._2, cnt.toInt % 16)
+ case "price" =>
+ row.setField(f._2, cnt.toDouble / 3)
+ case _ =>
+ throw new IllegalArgumentException(s"unknown field name ${f._1}")
+ }
+ )
+ row
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/5baea3f2/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/rules/util/RexProgramProjectExtractorTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/rules/util/RexProgramProjectExtractorTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/rules/util/RexProgramProjectExtractorTest.scala
new file mode 100644
index 0000000..156f281
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/rules/util/RexProgramProjectExtractorTest.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.plan.rules.util
+
+import java.math.BigDecimal
+
+import org.apache.calcite.adapter.java.JavaTypeFactory
+import org.apache.calcite.jdbc.JavaTypeFactoryImpl
+import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeSystem}
+import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, DOUBLE, INTEGER, VARCHAR}
+import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder}
+import org.apache.calcite.sql.fun.SqlStdOperatorTable
+
+import scala.collection.JavaConverters._
+import org.apache.flink.api.table.plan.rules.util.RexProgramProjectExtractor._
+import org.junit.{Assert, Before, Test}
+
+/**
+ * This class is responsible for testing RexProgramProjectExtractor
+ */
+class RexProgramProjectExtractorTest {
+ private var typeFactory: JavaTypeFactory = null
+ private var rexBuilder: RexBuilder = null
+ private var allFieldTypes: Seq[RelDataType] = null
+ private val allFieldNames = List("name", "id", "amount", "price")
+
+ @Before
+ def setUp: Unit = {
+ typeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT)
+ rexBuilder = new RexBuilder(typeFactory)
+ allFieldTypes = List(VARCHAR, BIGINT, INTEGER, DOUBLE).map(typeFactory.createSqlType(_))
+ }
+
+ @Test
+ def testExtractRefInputFields: Unit = {
+ val usedFields = extractRefInputFields(buildRexProgram)
+ Assert.assertArrayEquals(Array(2, 3, 1), usedFields)
+ }
+
+ @Test
+ def testRewriteRexProgram: Unit = {
+ val originRexProgram = buildRexProgram
+ Assert.assertTrue(extractExprStrList(originRexProgram).sameElements(Array(
+ "$0",
+ "$1",
+ "$2",
+ "$3",
+ "*($t2, $t3)",
+ "100",
+ "<($t4, $t5)",
+ "6",
+ ">($t1, $t7)",
+ "AND($t6, $t8)")))
+ // use amount, id, price fields to create a new RexProgram
+ val usedFields = Array(2, 3, 1)
+ val types = usedFields.map(allFieldTypes(_)).toList.asJava
+ val names = usedFields.map(allFieldNames(_)).toList.asJava
+ val inputRowType = typeFactory.createStructType(types, names)
+ val newRexProgram = rewriteRexProgram(originRexProgram, inputRowType, usedFields, rexBuilder)
+ Assert.assertTrue(extractExprStrList(newRexProgram).sameElements(Array(
+ "$0",
+ "$1",
+ "$2",
+ "*($t0, $t1)",
+ "100",
+ "<($t3, $t4)",
+ "6",
+ ">($t2, $t6)",
+ "AND($t5, $t7)")))
+ }
+
+ private def buildRexProgram: RexProgram = {
+ val types = allFieldTypes.asJava
+ val names = allFieldNames.asJava
+ val inputRowType = typeFactory.createStructType(types, names)
+ val builder = new RexProgramBuilder(inputRowType, rexBuilder)
+ val t0 = rexBuilder.makeInputRef(types.get(2), 2)
+ val t1 = rexBuilder.makeInputRef(types.get(1), 1)
+ val t2 = rexBuilder.makeInputRef(types.get(3), 3)
+ val t3 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, t0, t2))
+ val t4 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+ val t5 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(6L))
+ // project: amount, amount * price
+ builder.addProject(t0, "amount")
+ builder.addProject(t3, "total")
+ // condition: amount * price < 100 and id > 6
+ val t6 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t3, t4))
+ val t7 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t5))
+ val t8 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(t6, t7).asJava))
+ builder.addCondition(t8)
+ builder.getProgram
+ }
+
+ /**
+ * extract all expression string list from input RexProgram expression lists
+ *
+ * @param rexProgram input RexProgram instance to analyze
+ * @return all expression string list of input RexProgram expression lists
+ */
+ private def extractExprStrList(rexProgram: RexProgram) = {
+ rexProgram.getExprList.asScala.map(_.toString)
+ }
+
+}