You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by fh...@apache.org on 2016/12/16 15:46:50 UTC
[21/47] flink git commit: [FLINK-4704] [table] Refactor package
structure of flink-table.
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala
new file mode 100644
index 0000000..ed6cf7b
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan
+
+import org.apache.flink.api.common.typeutils.CompositeType
+import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.expressions._
+import org.apache.flink.table.plan.logical.{LogicalNode, Project}
+
+import scala.collection.mutable.ListBuffer
+
+object ProjectionTranslator {
+
+ /**
+ * Extracts and deduplicates all aggregation and window property expressions (zero, one, or more)
+ * from the given expressions.
+ *
+ * @param exprs a list of expressions to extract
+ * @param tableEnv the TableEnvironment
+ * @return a Tuple2, the first field contains the extracted and deduplicated aggregations,
+ * and the second field contains the extracted and deduplicated window properties.
+ */
+ def extractAggregationsAndProperties(
+ exprs: Seq[Expression],
+ tableEnv: TableEnvironment): (Map[Expression, String], Map[Expression, String]) = {
+ exprs.foldLeft((Map[Expression, String](), Map[Expression, String]())) {
+ (x, y) => identifyAggregationsAndProperties(y, tableEnv, x._1, x._2)
+ }
+ }
+
+ /**
+ * Identifies and deduplicates aggregation functions and window properties.
+ *
+ * Recursively walks the expression tree, threading the two accumulator maps through each
+ * visited node. A fresh unique attribute name (from the TableEnvironment) is assigned the
+ * first time an aggregation or window property is seen; duplicates reuse the existing entry.
+ *
+ * @param exp the expression to inspect
+ * @param tableEnv used to generate unique attribute names
+ * @param aggNames accumulator for aggregation -> generated name
+ * @param propNames accumulator for window property -> generated name
+ */
+ private def identifyAggregationsAndProperties(
+ exp: Expression,
+ tableEnv: TableEnvironment,
+ aggNames: Map[Expression, String],
+ propNames: Map[Expression, String]) : (Map[Expression, String], Map[Expression, String]) = {
+
+ exp match {
+ case agg: Aggregation =>
+ if (aggNames contains agg) {
+ (aggNames, propNames)
+ } else {
+ (aggNames + (agg -> tableEnv.createUniqueAttributeName()), propNames)
+ }
+ case prop: WindowProperty =>
+ if (propNames contains prop) {
+ (aggNames, propNames)
+ } else {
+ (aggNames, propNames + (prop -> tableEnv.createUniqueAttributeName()))
+ }
+ // leaves carry no nested aggregations/properties
+ case l: LeafExpression =>
+ (aggNames, propNames)
+ case u: UnaryExpression =>
+ identifyAggregationsAndProperties(u.child, tableEnv, aggNames, propNames)
+ case b: BinaryExpression =>
+ val l = identifyAggregationsAndProperties(b.left, tableEnv, aggNames, propNames)
+ identifyAggregationsAndProperties(b.right, l._1, l._2) match { case _ => identifyAggregationsAndProperties(b.right, tableEnv, l._1, l._2) }
+
+ // Function calls
+ case c @ Call(name, args) =>
+ args.foldLeft((aggNames, propNames)){
+ (x, y) => identifyAggregationsAndProperties(y, tableEnv, x._1, x._2)
+ }
+
+ case sfc @ ScalarFunctionCall(clazz, args) =>
+ args.foldLeft((aggNames, propNames)){
+ (x, y) => identifyAggregationsAndProperties(y, tableEnv, x._1, x._2)
+ }
+
+ // General expression: recurse into any Expression-typed product elements, pass others through
+ case e: Expression =>
+ e.productIterator.foldLeft((aggNames, propNames)){
+ (x, y) => y match {
+ case e: Expression => identifyAggregationsAndProperties(e, tableEnv, x._1, x._2)
+ case _ => (x._1, x._2)
+ }
+ }
+ }
+ }
+
+ /**
+ * Replaces expressions with deduplicated aggregations and properties.
+ *
+ * @param exprs a list of expressions to replace
+ * @param tableEnv the TableEnvironment
+ * @param aggNames the deduplicated aggregations
+ * @param propNames the deduplicated properties
+ * @return a list of replaced expressions
+ */
+ def replaceAggregationsAndProperties(
+ exprs: Seq[Expression],
+ tableEnv: TableEnvironment,
+ aggNames: Map[Expression, String],
+ propNames: Map[Expression, String]): Seq[NamedExpression] = {
+ exprs.map(replaceAggregationsAndProperties(_, tableEnv, aggNames, propNames))
+ .map(UnresolvedAlias)
+ }
+
+ /**
+ * Rewrites a single expression, substituting every aggregation / window property (or an
+ * alias thereof) by a field reference to the name assigned in aggNames/propNames.
+ * Assumes the maps were produced by [[extractAggregationsAndProperties]] for the same
+ * expressions, otherwise the map lookups throw NoSuchElementException.
+ */
+ private def replaceAggregationsAndProperties(
+ exp: Expression,
+ tableEnv: TableEnvironment,
+ aggNames: Map[Expression, String],
+ propNames: Map[Expression, String]) : Expression = {
+
+ exp match {
+ case agg: Aggregation =>
+ val name = aggNames(agg)
+ Alias(UnresolvedFieldReference(name), tableEnv.createUniqueAttributeName())
+ case prop: WindowProperty =>
+ val name = propNames(prop)
+ Alias(UnresolvedFieldReference(name), tableEnv.createUniqueAttributeName())
+ // user-provided aliases of aggregations/properties keep their alias name
+ case n @ Alias(agg: Aggregation, name, _) =>
+ val aName = aggNames(agg)
+ Alias(UnresolvedFieldReference(aName), name)
+ case n @ Alias(prop: WindowProperty, name, _) =>
+ val pName = propNames(prop)
+ Alias(UnresolvedFieldReference(pName), name)
+ case l: LeafExpression => l
+ case u: UnaryExpression =>
+ val c = replaceAggregationsAndProperties(u.child, tableEnv, aggNames, propNames)
+ u.makeCopy(Array(c))
+ case b: BinaryExpression =>
+ val l = replaceAggregationsAndProperties(b.left, tableEnv, aggNames, propNames)
+ val r = replaceAggregationsAndProperties(b.right, tableEnv, aggNames, propNames)
+ b.makeCopy(Array(l, r))
+
+ // Function calls
+ case c @ Call(name, args) =>
+ val newArgs = args.map(replaceAggregationsAndProperties(_, tableEnv, aggNames, propNames))
+ c.makeCopy(Array(name, newArgs))
+
+ case sfc @ ScalarFunctionCall(clazz, args) =>
+ val newArgs: Seq[Expression] = args
+ .map(replaceAggregationsAndProperties(_, tableEnv, aggNames, propNames))
+ sfc.makeCopy(Array(clazz, newArgs))
+
+ // array constructor
+ case c @ ArrayConstructor(args) =>
+ val newArgs = c.elements
+ .map(replaceAggregationsAndProperties(_, tableEnv, aggNames, propNames))
+ c.makeCopy(Array(newArgs))
+
+ // General expression
+ // NOTE(review): this map uses a partial function with a single `case arg: Expression`
+ // pattern; unlike the identify* counterpart above there is no catch-all case, so an
+ // expression whose Product contains a non-Expression element (e.g. a String name)
+ // would throw a MatchError here — confirm all reachable Expression types only carry
+ // Expression children.
+ case e: Expression =>
+ val newArgs = e.productIterator.map {
+ case arg: Expression =>
+ replaceAggregationsAndProperties(arg, tableEnv, aggNames, propNames)
+ }
+ e.makeCopy(newArgs.toArray)
+ }
+ }
+
+ /**
+ * Expands an UnresolvedFieldReference("*") to parent's full project list.
+ */
+ def expandProjectList(
+ exprs: Seq[Expression],
+ parent: LogicalNode,
+ tableEnv: TableEnvironment)
+ : Seq[Expression] = {
+
+ val projectList = new ListBuffer[Expression]
+
+ exprs.foreach {
+ case n: UnresolvedFieldReference if n.name == "*" =>
+ projectList ++= parent.output.map(a => UnresolvedFieldReference(a.name))
+
+ case Flattening(unresolved) =>
+ // simulate a simple project to resolve fields using current parent
+ val project = Project(Seq(UnresolvedAlias(unresolved)), parent).validate(tableEnv)
+ val resolvedExpr = project
+ .output
+ .headOption
+ .getOrElse(throw new RuntimeException("Could not find resolved composite."))
+ resolvedExpr.validateInput()
+ // NOTE(review): `newProjects` is never read — the match is executed purely for the
+ // `projectList +=` side effects inside it; the val could be dropped.
+ val newProjects = resolvedExpr.resultType match {
+ case ct: CompositeType[_] =>
+ (0 until ct.getArity).map { idx =>
+ projectList += GetCompositeField(unresolved, ct.getFieldNames()(idx))
+ }
+ case _ =>
+ projectList += unresolved
+ }
+
+ case e: Expression => projectList += e
+ }
+ projectList
+ }
+
+ /**
+ * Extract all field references from the given expressions.
+ *
+ * @param exprs a list of expressions to extract
+ * @return a list of field references extracted from the given expressions
+ */
+ def extractFieldReferences(exprs: Seq[Expression]): Seq[NamedExpression] = {
+ exprs.foldLeft(Set[NamedExpression]()) {
+ (fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences)
+ }.toSeq
+ }
+
+ /**
+ * Recursively collects every UnresolvedFieldReference in the expression tree into the
+ * accumulator set (wrapped in UnresolvedAlias). Fields referenced only from within a
+ * window property are deliberately skipped.
+ */
+ private def identifyFieldReferences(
+ expr: Expression,
+ fieldReferences: Set[NamedExpression]): Set[NamedExpression] = expr match {
+
+ case f: UnresolvedFieldReference =>
+ fieldReferences + UnresolvedAlias(f)
+
+ case b: BinaryExpression =>
+ val l = identifyFieldReferences(b.left, fieldReferences)
+ identifyFieldReferences(b.right, l)
+
+ // Function calls
+ case c @ Call(name, args) =>
+ args.foldLeft(fieldReferences) {
+ (fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences)
+ }
+ case sfc @ ScalarFunctionCall(clazz, args) =>
+ args.foldLeft(fieldReferences) {
+ (fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences)
+ }
+
+ // array constructor
+ case c @ ArrayConstructor(args) =>
+ args.foldLeft(fieldReferences) {
+ (fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences)
+ }
+
+ // ignore fields from window property
+ case w : WindowProperty =>
+ fieldReferences
+
+ // keep this case after all unwanted unary expressions
+ case u: UnaryExpression =>
+ identifyFieldReferences(u.child, fieldReferences)
+
+ // General expression
+ case e: Expression =>
+ e.productIterator.foldLeft(fieldReferences) {
+ (fieldReferences, expr) => expr match {
+ case e: Expression => identifyFieldReferences(e, fieldReferences)
+ case _ => fieldReferences
+ }
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/TreeNode.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/TreeNode.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/TreeNode.scala
new file mode 100644
index 0000000..fdf45e7
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/TreeNode.scala
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.plan
+
+import org.apache.commons.lang.ClassUtils
+
+/**
+ * Generic base class for trees that can be transformed and traversed.
+ */
+abstract class TreeNode[A <: TreeNode[A]] extends Product { self: A =>
+
+ /**
+ * List of child nodes that should be considered when doing transformations. Other values
+ * in the Product will not be transformed, only handed through.
+ */
+ private[flink] def children: Seq[A]
+
+ /**
+ * Tests for equality by first testing for reference equality.
+ */
+ private[flink] def fastEquals(other: TreeNode[_]): Boolean = this.eq(other) || this == other
+
+ /**
+ * Do tree transformation in post order.
+ */
+ private[flink] def postOrderTransform(rule: PartialFunction[A, A]): A = {
+ // Transforms all direct children (including those nested inside a Traversable product
+ // element) and rebuilds this node via makeCopy only if at least one child changed.
+ def childrenTransform(rule: PartialFunction[A, A]): A = {
+ var changed = false
+ val newArgs = productIterator.map {
+ case arg: TreeNode[_] if children.contains(arg) =>
+ val newChild = arg.asInstanceOf[A].postOrderTransform(rule)
+ if (!(newChild fastEquals arg)) {
+ changed = true
+ newChild
+ } else {
+ arg
+ }
+ case args: Traversable[_] => args.map {
+ case arg: TreeNode[_] if children.contains(arg) =>
+ val newChild = arg.asInstanceOf[A].postOrderTransform(rule)
+ if (!(newChild fastEquals arg)) {
+ changed = true
+ newChild
+ } else {
+ arg
+ }
+ case other => other
+ }
+ case nonChild: AnyRef => nonChild
+ case null => null
+ }.toArray
+ if (changed) makeCopy(newArgs) else this
+ }
+
+ // post order: children first, then apply the rule to this (possibly rebuilt) node
+ val afterChildren = childrenTransform(rule)
+ if (afterChildren fastEquals this) {
+ rule.applyOrElse(this, identity[A])
+ } else {
+ rule.applyOrElse(afterChildren, identity[A])
+ }
+ }
+
+ /**
+ * Runs the given function first on the node and then recursively on all its children.
+ */
+ private[flink] def preOrderVisit(f: A => Unit): Unit = {
+ f(this)
+ children.foreach(_.preOrderVisit(f))
+ }
+
+ /**
+ * Creates a new copy of this expression with new children. This is used during transformation
+ * if children change.
+ */
+ private[flink] def makeCopy(newArgs: Array[AnyRef]): A = {
+ val ctors = getClass.getConstructors.filter(_.getParameterTypes.size > 0)
+ if (ctors.isEmpty) {
+ throw new RuntimeException(s"No valid constructor for ${getClass.getSimpleName}")
+ }
+
+ // Prefer a constructor whose arity and parameter types match newArgs exactly.
+ // NOTE(review): the getOrElse fallback blindly picks the constructor with the most
+ // parameters; if no constructor is assignable, newInstance below will fail anyway —
+ // the fallback only changes which error is reported.
+ val defaultCtor = ctors.find { ctor =>
+ if (ctor.getParameterTypes.size != newArgs.length) {
+ false
+ } else if (newArgs.contains(null)) {
+ false
+ } else {
+ val argsClasses: Array[Class[_]] = newArgs.map(_.getClass)
+ ClassUtils.isAssignable(argsClasses, ctor.getParameterTypes)
+ }
+ }.getOrElse(ctors.maxBy(_.getParameterTypes.size))
+
+ try {
+ defaultCtor.newInstance(newArgs: _*).asInstanceOf[A]
+ } catch {
+ // NOTE(review): catching Throwable swallows fatal errors (OOM, interrupts); idiomatic
+ // Scala would catch NonFatal(e). Also `getStackTraceString` is deprecated — consider
+ // wrapping `e` as the cause instead.
+ case e: Throwable =>
+ throw new RuntimeException(
+ s"Fail to copy treeNode ${getClass.getName}: ${e.getStackTraceString}")
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/cost/DataSetCost.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/cost/DataSetCost.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/cost/DataSetCost.scala
new file mode 100644
index 0000000..7b439ec
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/cost/DataSetCost.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.cost
+
+import org.apache.calcite.plan.{RelOptUtil, RelOptCostFactory, RelOptCost}
+import org.apache.calcite.util.Util
+
+/**
+ * This class is based on Apache Calcite's `org.apache.calcite.plan.volcano.VolcanoCost` and has
+ * an adapted cost comparison method `isLe(other: RelOptCost)` that takes io and cpu into account.
+ */
+class DataSetCost(val rowCount: Double, val cpu: Double, val io: Double) extends RelOptCost {
+
+ def getCpu: Double = cpu
+
+ // infinite if this is the Infinity singleton or any component is +Inf
+ def isInfinite: Boolean = {
+ (this eq DataSetCost.Infinity) ||
+ (this.rowCount == Double.PositiveInfinity) ||
+ (this.cpu == Double.PositiveInfinity) ||
+ (this.io == Double.PositiveInfinity)
+ }
+
+ def getIo: Double = io
+
+ // Lexicographic comparison: io first, then cpu, then rowCount (this is the adapted
+ // ordering mentioned in the class doc; VolcanoCost compares rowCount only).
+ def isLe(other: RelOptCost): Boolean = {
+ val that: DataSetCost = other.asInstanceOf[DataSetCost]
+ (this eq that) ||
+ (this.io < that.io) ||
+ (this.io == that.io && this.cpu < that.cpu) ||
+ (this.io == that.io && this.cpu == that.cpu && this.rowCount < that.rowCount)
+ }
+
+ // NOTE(review): `this == other` dispatches to Object.equals (reference equality), because
+ // `equals(other: RelOptCost)` below is an overload, not an override of equals(Any). So two
+ // distinct-but-equal costs satisfy isLt even when isLe holds via component equality only
+ // through the strict branches — confirm this matches Calcite's VolcanoCost semantics.
+ def isLt(other: RelOptCost): Boolean = {
+ isLe(other) && !(this == other)
+ }
+
+ def getRows: Double = rowCount
+
+ override def hashCode: Int = Util.hashCode(rowCount) + Util.hashCode(cpu) + Util.hashCode(io)
+
+ // component-wise equality, as required by the RelOptCost interface
+ def equals(other: RelOptCost): Boolean = {
+ (this eq other) ||
+ other.isInstanceOf[DataSetCost] &&
+ (this.rowCount == other.asInstanceOf[DataSetCost].rowCount) &&
+ (this.cpu == other.asInstanceOf[DataSetCost].cpu) &&
+ (this.io == other.asInstanceOf[DataSetCost].io)
+ }
+
+ // equality within RelOptUtil.EPSILON per component
+ def isEqWithEpsilon(other: RelOptCost): Boolean = {
+ if (!other.isInstanceOf[DataSetCost]) {
+ return false
+ }
+ val that: DataSetCost = other.asInstanceOf[DataSetCost]
+ (this eq that) ||
+ ((Math.abs(this.rowCount - that.rowCount) < RelOptUtil.EPSILON) &&
+ (Math.abs(this.cpu - that.cpu) < RelOptUtil.EPSILON) &&
+ (Math.abs(this.io - that.io) < RelOptUtil.EPSILON))
+ }
+
+ def minus(other: RelOptCost): RelOptCost = {
+ if (this eq DataSetCost.Infinity) {
+ return this
+ }
+ val that: DataSetCost = other.asInstanceOf[DataSetCost]
+ new DataSetCost(this.rowCount - that.rowCount, this.cpu - that.cpu, this.io - that.io)
+ }
+
+ def multiplyBy(factor: Double): RelOptCost = {
+ if (this eq DataSetCost.Infinity) {
+ return this
+ }
+ new DataSetCost(rowCount * factor, cpu * factor, io * factor)
+ }
+
+ // Geometric mean of the component-wise ratios, skipping components that are zero or
+ // infinite on either side; returns 1.0 when no component is usable.
+ def divideBy(cost: RelOptCost): Double = {
+ val that: DataSetCost = cost.asInstanceOf[DataSetCost]
+ var d: Double = 1
+ var n: Double = 0
+ if ((this.rowCount != 0) && !this.rowCount.isInfinite &&
+ (that.rowCount != 0) && !that.rowCount.isInfinite)
+ {
+ d *= this.rowCount / that.rowCount
+ n += 1
+ }
+ if ((this.cpu != 0) && !this.cpu.isInfinite && (that.cpu != 0) && !that.cpu.isInfinite) {
+ d *= this.cpu / that.cpu
+ n += 1
+ }
+ if ((this.io != 0) && !this.io.isInfinite && (that.io != 0) && !that.io.isInfinite) {
+ d *= this.io / that.io
+ n += 1
+ }
+ if (n == 0) {
+ return 1.0
+ }
+ Math.pow(d, 1 / n)
+ }
+
+ def plus(other: RelOptCost): RelOptCost = {
+ val that: DataSetCost = other.asInstanceOf[DataSetCost]
+ if ((this eq DataSetCost.Infinity) || (that eq DataSetCost.Infinity)) {
+ return DataSetCost.Infinity
+ }
+ new DataSetCost(this.rowCount + that.rowCount, this.cpu + that.cpu, this.io + that.io)
+ }
+
+ override def toString: String = s"{$rowCount rows, $cpu cpu, $io io}"
+
+}
+
+/** Well-known cost singletons; identity (`eq`) on these is significant in the class above. */
+object DataSetCost {
+
+ private[flink] val Infinity = new DataSetCost(
+ Double.PositiveInfinity,
+ Double.PositiveInfinity,
+ Double.PositiveInfinity)
+ {
+ override def toString: String = "{inf}"
+ }
+
+ private[flink] val Huge = new DataSetCost(Double.MaxValue, Double.MaxValue, Double.MaxValue) {
+ override def toString: String = "{huge}"
+ }
+
+ private[flink] val Zero = new DataSetCost(0.0, 0.0, 0.0) {
+ override def toString: String = "{0}"
+ }
+
+ private[flink] val Tiny = new DataSetCost(1.0, 1.0, 0.0) {
+ override def toString = "{tiny}"
+ }
+
+ val FACTORY: RelOptCostFactory = new DataSetCostFactory
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/cost/DataSetCostFactory.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/cost/DataSetCostFactory.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/cost/DataSetCostFactory.scala
new file mode 100644
index 0000000..50d3842
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/cost/DataSetCostFactory.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.cost
+
+import org.apache.calcite.plan.{RelOptCost, RelOptCostFactory}
+
+/**
+ * This class is based on Apache Calcite's `org.apache.calcite.plan.volcano.VolcanoCost#Factory`.
+ */
+class DataSetCostFactory extends RelOptCostFactory {
+
+ /** Creates a cost with the given row count, cpu, and io components. */
+ override def makeCost(dRows: Double, dCpu: Double, dIo: Double): RelOptCost = {
+ new DataSetCost(dRows, dCpu, dIo)
+ }
+
+ // The remaining factory methods return the shared singletons from the companion object,
+ // so identity comparisons (`eq`) against them in DataSetCost remain valid.
+ override def makeHugeCost: RelOptCost = {
+ DataSetCost.Huge
+ }
+
+ override def makeInfiniteCost: RelOptCost = {
+ DataSetCost.Infinity
+ }
+
+ override def makeTinyCost: RelOptCost = {
+ DataSetCost.Tiny
+ }
+
+ override def makeZeroCost: RelOptCost = {
+ DataSetCost.Zero
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/LogicalNode.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/LogicalNode.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/LogicalNode.scala
new file mode 100644
index 0000000..7a9b08e
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/LogicalNode.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.plan.logical
+
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.tools.RelBuilder
+import org.apache.flink.table.plan.TreeNode
+import org.apache.flink.table.api.{TableEnvironment, ValidationException}
+import org.apache.flink.table.expressions._
+import org.apache.flink.table.typeutils.TypeCoercion
+import org.apache.flink.table.validate._
+
+/**
+ * LogicalNode is created and validated as we construct query plan using Table API.
+ *
+ * The main validation procedure is separated into two phases:
+ *
+ * Expressions' resolution and transformation ([[resolveExpressions]]):
+ *
+ * - translate [[UnresolvedFieldReference]] into [[ResolvedFieldReference]]
+ * using child operator's output
+ * - translate [[Call]](UnresolvedFunction) into solid Expression
+ * - generate alias names for query output
+ * - ....
+ *
+ * LogicalNode validation ([[validate]]):
+ *
+ * - check no [[UnresolvedFieldReference]] exists any more
+ * - check if all expressions have children of needed type
+ * - check each logical operator have desired input
+ *
+ * Once we pass the validation phase, we can safely convert LogicalNode into Calcite's RelNode.
+ */
+abstract class LogicalNode extends TreeNode[LogicalNode] {
+ /** The attributes this node produces, used by children-based reference resolution. */
+ def output: Seq[Attribute]
+
+ /**
+ * Phase 1 of validation: resolves field references and unresolved function calls, then
+ * inserts implicit casts where a child's type can be safely widened to the expected type.
+ */
+ def resolveExpressions(tableEnv: TableEnvironment): LogicalNode = {
+ // resolve references and function calls
+ val exprResolved = expressionPostOrderTransform {
+ case u @ UnresolvedFieldReference(name) =>
+ resolveReference(tableEnv, name).getOrElse(u)
+ case c @ Call(name, children) if c.childrenValid =>
+ tableEnv.getFunctionCatalog.lookupFunction(name, children)
+ }
+
+ // insert casts where expected and actual input types differ but can be safely coerced
+ exprResolved.expressionPostOrderTransform {
+ case ips: InputTypeSpec if ips.childrenValid =>
+ var changed: Boolean = false
+ val newChildren = ips.expectedTypes.zip(ips.children).map { case (tpe, child) =>
+ val childType = child.resultType
+ if (childType != tpe && TypeCoercion.canSafelyCast(childType, tpe)) {
+ changed = true
+ Cast(child, tpe)
+ } else {
+ child
+ }
+ }.toArray[AnyRef]
+ if (changed) ips.makeCopy(newChildren) else ips
+ }
+ }
+
+ final def toRelNode(relBuilder: RelBuilder): RelNode = construct(relBuilder).build()
+
+ protected[logical] def construct(relBuilder: RelBuilder): RelBuilder
+
+ /**
+ * Phase 2 of validation: after resolution, fails on any remaining unresolved attribute
+ * and on any expression whose input check fails.
+ */
+ def validate(tableEnv: TableEnvironment): LogicalNode = {
+ val resolvedNode = resolveExpressions(tableEnv)
+ resolvedNode.expressionPostOrderTransform {
+ case a: Attribute if !a.valid =>
+ // NOTE(review): `children` here is the closure's enclosing node (`this`), not
+ // `resolvedNode`; the error message therefore lists the pre-resolution children's
+ // output — confirm these are always equivalent.
+ val from = children.flatMap(_.output).map(_.name).mkString(", ")
+ failValidation(s"Cannot resolve [${a.name}] given input [$from].")
+
+ case e: Expression if e.validateInput().isFailure =>
+ failValidation(s"Expression $e failed on input check: " +
+ s"${e.validateInput().asInstanceOf[ValidationFailure].message}")
+ }
+ }
+
+ /**
+ * Resolves the given strings to a [[NamedExpression]] using the input from all child
+ * nodes of this LogicalPlan.
+ */
+ def resolveReference(tableEnv: TableEnvironment, name: String): Option[NamedExpression] = {
+ val childrenOutput = children.flatMap(_.output)
+ // case-insensitive match against the children's output attributes
+ val candidates = childrenOutput.filter(_.name.equalsIgnoreCase(name))
+ if (candidates.length > 1) {
+ failValidation(s"Reference $name is ambiguous.")
+ } else if (candidates.isEmpty) {
+ None
+ } else {
+ Some(candidates.head.withName(name))
+ }
+ }
+
+ /**
+ * Runs [[postOrderTransform]] with `rule` on all expressions present in this logical node.
+ *
+ * Handles expressions appearing directly as product elements, inside an Option, inside a
+ * Traversable, or nested in a [[Resolvable]] parameter; all other product elements are
+ * handed through unchanged. A copy is made only if some expression actually changed.
+ *
+ * @param rule the rule to be applied to every expression in this logical node.
+ */
+ def expressionPostOrderTransform(rule: PartialFunction[Expression, Expression]): LogicalNode = {
+ var changed = false
+
+ def expressionPostOrderTransform(e: Expression): Expression = {
+ val newExpr = e.postOrderTransform(rule)
+ if (newExpr.fastEquals(e)) {
+ e
+ } else {
+ changed = true
+ newExpr
+ }
+ }
+
+ val newArgs = productIterator.map {
+ case e: Expression => expressionPostOrderTransform(e)
+ case Some(e: Expression) => Some(expressionPostOrderTransform(e))
+ case seq: Traversable[_] => seq.map {
+ case e: Expression => expressionPostOrderTransform(e)
+ case other => other
+ }
+ case r: Resolvable[_] => r.resolveExpressions(e => expressionPostOrderTransform(e))
+ case other: AnyRef => other
+ }.toArray
+
+ if (changed) makeCopy(newArgs) else this
+ }
+
+ protected def failValidation(msg: String): Nothing = {
+ throw new ValidationException(msg)
+ }
+}
+
+/** A logical node without inputs (e.g. a scan). */
+abstract class LeafNode extends LogicalNode {
+ override def children: Seq[LogicalNode] = Nil
+}
+
+/** A logical node with exactly one input. */
+abstract class UnaryNode extends LogicalNode {
+ def child: LogicalNode
+
+ override def children: Seq[LogicalNode] = child :: Nil
+}
+
+/** A logical node with exactly two inputs (e.g. a join). */
+abstract class BinaryNode extends LogicalNode {
+ def left: LogicalNode
+ def right: LogicalNode
+
+ override def children: Seq[LogicalNode] = left :: right :: Nil
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/LogicalWindow.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/LogicalWindow.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/LogicalWindow.scala
new file mode 100644
index 0000000..1264566
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/LogicalWindow.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.logical
+
+import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.expressions.{Expression, WindowReference}
+import org.apache.flink.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess}
+
+abstract class LogicalWindow(val alias: Option[Expression]) extends Resolvable[LogicalWindow] {
+
+ // default: a window has no expression parameters to resolve; subclasses override as needed
+ def resolveExpressions(resolver: (Expression) => Expression): LogicalWindow = this
+
+ /** Validates that the optional alias, if present, is a [[WindowReference]]. */
+ def validate(tableEnv: TableEnvironment): ValidationResult = alias match {
+ case Some(WindowReference(_)) => ValidationSuccess
+ case Some(_) => ValidationFailure("Window reference for window expected.")
+ case None => ValidationSuccess
+ }
+
+ override def toString: String = getClass.getSimpleName
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/Resolvable.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/Resolvable.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/Resolvable.scala
new file mode 100644
index 0000000..995bac5
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/Resolvable.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.logical
+
+import org.apache.flink.table.expressions.Expression
+
+/**
+ * A class implementing this interface can resolve the expressions of its parameters and
+ * return a new instance with resolved parameters. This is necessary if expression are nested in
+ * a not supported structure. By default, the validation of a logical node can resolve common
+ * structures like `Expression`, `Option[Expression]`, `Traversable[Expression]`.
+ *
+ * See also [[LogicalNode.expressionPostOrderTransform(scala.PartialFunction)]].
+ *
+ * @tparam T class which expression parameters need to be resolved
+ */
+trait Resolvable[T <: AnyRef] {
+
+ /**
+ * An implementing class can resolve its expressions by applying the given resolver
+ * function on its parameters.
+ *
+ * Implementations should return `this` unchanged when no parameter contains expressions.
+ *
+ * @param resolver function that can resolve an expression
+ * @return class with resolved expression parameters
+ */
+ def resolveExpressions(resolver: (Expression) => Expression): T
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/groupWindows.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/groupWindows.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/groupWindows.scala
new file mode 100644
index 0000000..b12e654
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/groupWindows.scala
@@ -0,0 +1,258 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.logical
+
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo
+import org.apache.flink.table.api.{BatchTableEnvironment, StreamTableEnvironment, TableEnvironment}
+import org.apache.flink.table.expressions._
+import org.apache.flink.table.typeutils.{RowIntervalTypeInfo, TimeIntervalTypeInfo, TypeCoercion}
+import org.apache.flink.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess}
+
+/**
+ * Logical super class for group windows that are evaluated on event-time.
+ *
+ * @param name optional alias under which the window can be referenced
+ * @param time expression referencing the time field of the input
+ */
+abstract class EventTimeGroupWindow(
+    name: Option[Expression],
+    time: Expression)
+  extends LogicalWindow(name) {
+
+  override def validate(tableEnv: TableEnvironment): ValidationResult = {
+    // first run the common window validation; bail out on the first failure
+    val valid = super.validate(tableEnv)
+    if (valid.isFailure) {
+      return valid
+    }
+
+    // NOTE(review): only Stream/Batch environments are matched; any other
+    // TableEnvironment subclass would raise a MatchError here — confirm intended.
+    tableEnv match {
+      case _: StreamTableEnvironment =>
+        // on streams, event-time windows must reference the 'rowtime' attribute
+        time match {
+          case RowtimeAttribute() =>
+            ValidationSuccess
+          case _ =>
+            ValidationFailure("Event-time window expects a 'rowtime' time field.")
+        }
+      case _: BatchTableEnvironment =>
+        // on batch tables any field that can safely be cast to Long may serve as time
+        if (!TypeCoercion.canCast(time.resultType, BasicTypeInfo.LONG_TYPE_INFO)) {
+          ValidationFailure(s"Event-time window expects a time field that can be safely cast " +
+            s"to Long, but is ${time.resultType}")
+        } else {
+          ValidationSuccess
+        }
+    }
+
+  }
+}
+
+abstract class ProcessingTimeGroupWindow(name: Option[Expression]) extends LogicalWindow(name)
+
+// ------------------------------------------------------------------------------------------------
+// Tumbling group windows
+// ------------------------------------------------------------------------------------------------
+
+/** Shared validation for tumbling group windows: the size must be an interval literal. */
+object TumblingGroupWindow {
+  def validate(tableEnv: TableEnvironment, size: Expression): ValidationResult =
+    size match {
+      // a time interval or a row-count interval literal is accepted
+      case Literal(_, TimeIntervalTypeInfo.INTERVAL_MILLIS) |
+           Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) =>
+        ValidationSuccess
+      case _ =>
+        ValidationFailure("Tumbling window expects size literal of type Interval of Milliseconds " +
+          "or Interval of Rows.")
+    }
+}
+
+/**
+ * Tumbling group window evaluated on processing-time.
+ *
+ * @param name optional window alias
+ * @param size window size (time or row-count interval literal)
+ */
+case class ProcessingTimeTumblingGroupWindow(
+    name: Option[Expression],
+    size: Expression)
+  extends ProcessingTimeGroupWindow(name) {
+
+  override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow =
+    ProcessingTimeTumblingGroupWindow(
+      name.map(resolve),
+      resolve(size))
+
+  override def validate(tableEnv: TableEnvironment): ValidationResult =
+    super.validate(tableEnv).orElse(TumblingGroupWindow.validate(tableEnv, size))
+
+  override def toString: String = s"ProcessingTimeTumblingGroupWindow($name, $size)"
+}
+
+/**
+ * Tumbling group window evaluated on event-time.
+ *
+ * @param name optional window alias
+ * @param timeField expression referencing the time field of the input
+ * @param size window size (time interval literal; row-count intervals are rejected below)
+ */
+case class EventTimeTumblingGroupWindow(
+    name: Option[Expression],
+    timeField: Expression,
+    size: Expression)
+  extends EventTimeGroupWindow(
+    name,
+    timeField) {
+
+  override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow =
+    EventTimeTumblingGroupWindow(
+      name.map(resolve),
+      resolve(timeField),
+      resolve(size))
+
+  override def validate(tableEnv: TableEnvironment): ValidationResult =
+    super.validate(tableEnv)
+      .orElse(TumblingGroupWindow.validate(tableEnv, size))
+      // row-count based event-time windows are a known unsupported combination
+      .orElse(size match {
+        case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) =>
+          ValidationFailure(
+            "Event-time grouping windows on row intervals are currently not supported.")
+        case _ =>
+          ValidationSuccess
+      })
+
+  override def toString: String = s"EventTimeTumblingGroupWindow($name, $timeField, $size)"
+}
+
+// ------------------------------------------------------------------------------------------------
+// Sliding group windows
+// ------------------------------------------------------------------------------------------------
+
+/**
+ * Shared validation for sliding group windows: size and slide must be interval
+ * literals of the same type.
+ */
+object SlidingGroupWindow {
+  def validate(
+      tableEnv: TableEnvironment,
+      size: Expression,
+      slide: Expression)
+    : ValidationResult = {
+
+    // size and slide obey the same rule; the previously duplicated match is
+    // factored into a single helper parameterized by the message role
+    def validateInterval(expr: Expression, role: String): ValidationResult = expr match {
+      case Literal(_, TimeIntervalTypeInfo.INTERVAL_MILLIS) =>
+        ValidationSuccess
+      case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) =>
+        ValidationSuccess
+      case _ =>
+        ValidationFailure(s"Sliding window expects $role literal of type Interval of " +
+          "Milliseconds or Interval of Rows.")
+    }
+
+    validateInterval(size, "size")
+      .orElse(validateInterval(slide, "slide"))
+      .orElse {
+        // mixing a time-based size with a count-based slide (or vice versa) is invalid
+        if (size.resultType != slide.resultType) {
+          ValidationFailure("Sliding window expects same type of size and slide.")
+        } else {
+          ValidationSuccess
+        }
+      }
+  }
+}
+
+/**
+ * Sliding group window evaluated on processing-time.
+ *
+ * @param name optional window alias
+ * @param size window size (time or row-count interval literal)
+ * @param slide slide interval; must have the same interval type as size
+ */
+case class ProcessingTimeSlidingGroupWindow(
+    name: Option[Expression],
+    size: Expression,
+    slide: Expression)
+  extends ProcessingTimeGroupWindow(name) {
+
+  override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow =
+    ProcessingTimeSlidingGroupWindow(
+      name.map(resolve),
+      resolve(size),
+      resolve(slide))
+
+  override def validate(tableEnv: TableEnvironment): ValidationResult =
+    super.validate(tableEnv).orElse(SlidingGroupWindow.validate(tableEnv, size, slide))
+
+  override def toString: String = s"ProcessingTimeSlidingGroupWindow($name, $size, $slide)"
+}
+
+/**
+ * Sliding group window evaluated on event-time.
+ *
+ * @param name optional window alias
+ * @param timeField expression referencing the time field of the input
+ * @param size window size (time interval literal; row-count intervals are rejected below)
+ * @param slide slide interval; must have the same interval type as size
+ */
+case class EventTimeSlidingGroupWindow(
+    name: Option[Expression],
+    timeField: Expression,
+    size: Expression,
+    slide: Expression)
+  extends EventTimeGroupWindow(name, timeField) {
+
+  override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow =
+    EventTimeSlidingGroupWindow(
+      name.map(resolve),
+      resolve(timeField),
+      resolve(size),
+      resolve(slide))
+
+  override def validate(tableEnv: TableEnvironment): ValidationResult =
+    super.validate(tableEnv)
+      .orElse(SlidingGroupWindow.validate(tableEnv, size, slide))
+      // row-count based event-time windows are a known unsupported combination
+      .orElse(size match {
+        case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) =>
+          ValidationFailure(
+            "Event-time grouping windows on row intervals are currently not supported.")
+        case _ =>
+          ValidationSuccess
+      })
+
+  override def toString: String = s"EventTimeSlidingGroupWindow($name, $timeField, $size, $slide)"
+}
+
+// ------------------------------------------------------------------------------------------------
+// Session group windows
+// ------------------------------------------------------------------------------------------------
+
+/** Shared validation for session group windows: the gap must be a time-interval literal. */
+object SessionGroupWindow {
+
+  // the literal value itself is not needed here, so the former unused binding
+  // `timeInterval: Long` is replaced by a wildcard, consistent with the other validators
+  def validate(tableEnv: TableEnvironment, gap: Expression): ValidationResult = gap match {
+    case Literal(_, TimeIntervalTypeInfo.INTERVAL_MILLIS) =>
+      ValidationSuccess
+    case _ =>
+      ValidationFailure(
+        "Session window expects gap literal of type Interval of Milliseconds.")
+  }
+}
+
+/**
+ * Session group window evaluated on processing-time.
+ *
+ * @param name optional window alias
+ * @param gap session gap (time interval literal)
+ */
+case class ProcessingTimeSessionGroupWindow(
+    name: Option[Expression],
+    gap: Expression)
+  extends ProcessingTimeGroupWindow(name) {
+
+  override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow =
+    ProcessingTimeSessionGroupWindow(
+      name.map(resolve),
+      resolve(gap))
+
+  override def validate(tableEnv: TableEnvironment): ValidationResult =
+    super.validate(tableEnv).orElse(SessionGroupWindow.validate(tableEnv, gap))
+
+  override def toString: String = s"ProcessingTimeSessionGroupWindow($name, $gap)"
+}
+
+/**
+ * Session group window evaluated on event-time.
+ *
+ * @param name optional window alias
+ * @param timeField expression referencing the time field of the input
+ * @param gap session gap (time interval literal)
+ */
+case class EventTimeSessionGroupWindow(
+    name: Option[Expression],
+    timeField: Expression,
+    gap: Expression)
+  extends EventTimeGroupWindow(
+    name,
+    timeField) {
+
+  override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow =
+    EventTimeSessionGroupWindow(
+      name.map(resolve),
+      resolve(timeField),
+      resolve(gap))
+
+  override def validate(tableEnv: TableEnvironment): ValidationResult =
+    super.validate(tableEnv).orElse(SessionGroupWindow.validate(tableEnv, gap))
+
+  override def toString: String = s"EventTimeSessionGroupWindow($name, $timeField, $gap)"
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
new file mode 100644
index 0000000..eae42cd
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
@@ -0,0 +1,694 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.plan.logical
+
+import java.lang.reflect.Method
+import java.util
+
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.core.CorrelationId
+import org.apache.calcite.rel.logical.{LogicalProject, LogicalTableFunctionScan}
+import org.apache.calcite.rex.{RexInputRef, RexNode}
+import org.apache.calcite.tools.RelBuilder
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.operators.join.JoinType
+import org.apache.flink.table._
+import org.apache.flink.table.api.{StreamTableEnvironment, TableEnvironment, UnresolvedException}
+import org.apache.flink.table.calcite.{FlinkRelBuilder, FlinkTypeFactory}
+import org.apache.flink.table.expressions._
+import org.apache.flink.table.functions.TableFunction
+import org.apache.flink.table.functions.utils.TableSqlFunction
+import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
+import org.apache.flink.table.plan.schema.FlinkTableFunctionImpl
+import org.apache.flink.table.typeutils.TypeConverter
+import org.apache.flink.table.validate.{ValidationFailure, ValidationSuccess}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+/**
+ * Logical projection: selects, computes, and/or renames the fields of `child`
+ * according to `projectList`.
+ */
+case class Project(projectList: Seq[NamedExpression], child: LogicalNode) extends UnaryNode {
+  override def output: Seq[Attribute] = projectList.map(_.toAttribute)
+
+  override def resolveExpressions(tableEnv: TableEnvironment): LogicalNode = {
+    val afterResolve = super.resolveExpressions(tableEnv).asInstanceOf[Project]
+    // give every still-unnamed expression an explicit alias (e.g. "_c0" by position)
+    val newProjectList =
+      afterResolve.projectList.zipWithIndex.map { case (e, i) =>
+        e match {
+          case u @ UnresolvedAlias(c) => c match {
+            case ne: NamedExpression => ne
+            // keep the unresolved alias as-is if the inner expression is not yet valid
+            case expr if !expr.valid => u
+            case c @ Cast(ne: NamedExpression, tp) => Alias(c, s"${ne.name}-$tp")
+            case gcf: GetCompositeField => Alias(gcf, gcf.aliasName().getOrElse(s"_c$i"))
+            case other => Alias(other, s"_c$i")
+          }
+          case _ =>
+            throw new RuntimeException("This should never be called and probably points to a bug.")
+        }
+      }
+    Project(newProjectList, child)
+  }
+
+  override def validate(tableEnv: TableEnvironment): LogicalNode = {
+    val resolvedProject = super.validate(tableEnv).asInstanceOf[Project]
+    val names: mutable.Set[String] = mutable.Set()
+
+    // rejects duplicate output names and the reserved name 'rowtime' on streams
+    def checkName(name: String): Unit = {
+      if (names.contains(name)) {
+        failValidation(s"Duplicate field name $name.")
+      } else if (tableEnv.isInstanceOf[StreamTableEnvironment] && name == "rowtime") {
+        failValidation("'rowtime' cannot be used as field name in a streaming environment.")
+      } else {
+        names.add(name)
+      }
+    }
+
+    resolvedProject.projectList.foreach {
+      case n: Alias =>
+        // explicit name
+        checkName(n.name)
+      case r: ResolvedFieldReference =>
+        // simple field forwarding
+        checkName(r.name)
+      case _ => // Do nothing
+    }
+    resolvedProject
+  }
+
+  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
+    val allAlias = projectList.forall(_.isInstanceOf[Alias])
+    child.construct(relBuilder)
+    if (allAlias) {
+      // Calcite's RelBuilder does not translate identity projects even if they rename fields.
+      // Add a projection ourselves (will be automatically removed by translation rules).
+      val project = LogicalProject.create(relBuilder.peek(),
+        // avoid AS call
+        projectList.map(_.asInstanceOf[Alias].child.toRexNode(relBuilder)).asJava,
+        projectList.map(_.name).asJava)
+      relBuilder.build()  // pop previous relNode
+      relBuilder.push(project)
+    } else {
+      relBuilder.project(projectList.map(_.toRexNode(relBuilder)): _*)
+    }
+  }
+}
+
+/**
+ * Intermediate node for `as(...)` calls: renames the first `aliasList.length` fields
+ * of `child`. Resolves itself into a [[Project]]; it is never translated directly.
+ */
+case class AliasNode(aliasList: Seq[Expression], child: LogicalNode) extends UnaryNode {
+  override def output: Seq[Attribute] =
+    throw UnresolvedException("Invalid call to output on AliasNode")
+
+  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder =
+    throw UnresolvedException("Invalid call to toRelNode on AliasNode")
+
+  override def resolveExpressions(tableEnv: TableEnvironment): LogicalNode = {
+    if (aliasList.length > child.output.length) {
+      failValidation("Aliasing more fields than we actually have")
+    } else if (!aliasList.forall(_.isInstanceOf[UnresolvedFieldReference])) {
+      failValidation("Alias only accept name expressions as arguments")
+    } else if (!aliasList.forall(_.asInstanceOf[UnresolvedFieldReference].name != "*")) {
+      failValidation("Alias can not accept '*' as name")
+    } else if (tableEnv.isInstanceOf[StreamTableEnvironment] && !aliasList.forall {
+      case UnresolvedFieldReference(name) => name != "rowtime"
+    }) {
+      failValidation("'rowtime' cannot be used as field name in a streaming environment.")
+    } else {
+      // rename the leading fields; fields beyond the alias list keep their names
+      val names = aliasList.map(_.asInstanceOf[UnresolvedFieldReference].name)
+      val input = child.output
+      Project(
+        names.zip(input).map { case (name, attr) =>
+          Alias(attr, name)} ++ input.drop(names.length), child)
+    }
+  }
+}
+
+/** Logical duplicate elimination over the rows of `child`. Batch-only. */
+case class Distinct(child: LogicalNode) extends UnaryNode {
+  override def output: Seq[Attribute] = child.output
+
+  // construct() on the child returns the same builder, so the calls can be chained
+  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder =
+    child.construct(relBuilder).distinct()
+
+  override def validate(tableEnv: TableEnvironment): LogicalNode = {
+    if (tableEnv.isInstanceOf[StreamTableEnvironment]) {
+      failValidation(s"Distinct on stream tables is currently not supported.")
+    }
+    this
+  }
+}
+
+/** Logical sort of `child` by the given ordering expressions. Batch-only. */
+case class Sort(order: Seq[Ordering], child: LogicalNode) extends UnaryNode {
+  override def output: Seq[Attribute] = child.output
+
+  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
+    // the child must be on the builder stack before the sort keys are converted
+    child.construct(relBuilder)
+    val sortKeys = order.map(_.toRexNode(relBuilder))
+    relBuilder.sort(sortKeys.asJava)
+  }
+
+  override def validate(tableEnv: TableEnvironment): LogicalNode = {
+    if (tableEnv.isInstanceOf[StreamTableEnvironment]) {
+      failValidation(s"Sort on stream tables is currently not supported.")
+    }
+    super.validate(tableEnv)
+  }
+}
+
+/**
+ * Logical limit: skips `offset` rows and returns at most `fetch` rows of `child`
+ * (fetch = -1 means unbounded). Batch-only and must follow a [[Sort]].
+ */
+case class Limit(offset: Int, fetch: Int = -1, child: LogicalNode) extends UnaryNode {
+  override def output: Seq[Attribute] = child.output
+
+  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
+    child.construct(relBuilder)
+    relBuilder.limit(offset, fetch)
+  }
+
+  override def validate(tableEnv: TableEnvironment): LogicalNode = {
+    if (tableEnv.isInstanceOf[StreamTableEnvironment]) {
+      failValidation(s"Limit on stream tables is currently not supported.")
+    }
+    // NOTE(review): this validates the child a second time (super.validate below
+    // does so as well) purely to check the operator ordering — confirm acceptable.
+    if (!child.validate(tableEnv).isInstanceOf[Sort]) {
+      failValidation(s"Limit operator must be preceded by an OrderBy operator.")
+    }
+    if (offset < 0) {
+      failValidation(s"Offset should be greater than or equal to zero.")
+    }
+    super.validate(tableEnv)
+  }
+}
+
+/** Logical filter: keeps the rows of `child` for which `condition` evaluates to true. */
+case class Filter(condition: Expression, child: LogicalNode) extends UnaryNode {
+  override def output: Seq[Attribute] = child.output
+
+  // construct() on the child returns the same builder, so the calls can be chained
+  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder =
+    child.construct(relBuilder).filter(condition.toRexNode(relBuilder))
+
+  override def validate(tableEnv: TableEnvironment): LogicalNode = {
+    val resolvedFilter = super.validate(tableEnv).asInstanceOf[Filter]
+    val cond = resolvedFilter.condition
+    // the predicate must be boolean-typed after resolution
+    if (cond.resultType != BOOLEAN_TYPE_INFO) {
+      failValidation(s"Filter operator requires a boolean expression as input," +
+        s" but $cond is of type ${cond.resultType}")
+    }
+    resolvedFilter
+  }
+}
+
+/**
+ * Logical aggregation: groups `child` by `groupingExpressions` and evaluates
+ * `aggregateExpressions` per group. Batch-only.
+ */
+case class Aggregate(
+    groupingExpressions: Seq[Expression],
+    aggregateExpressions: Seq[NamedExpression],
+    child: LogicalNode) extends UnaryNode {
+
+  override def output: Seq[Attribute] = {
+    // grouping fields come first, then the aggregates; unnamed ones get an auto alias
+    (groupingExpressions ++ aggregateExpressions) map {
+      case ne: NamedExpression => ne.toAttribute
+      case e => Alias(e, e.toString).toAttribute
+    }
+  }
+
+  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
+    child.construct(relBuilder)
+    relBuilder.aggregate(
+      relBuilder.groupKey(groupingExpressions.map(_.toRexNode(relBuilder)).asJava),
+      aggregateExpressions.map {
+        // by this point each aggregate must be an aliased Aggregation (see validate)
+        case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder)
+        case _ => throw new RuntimeException("This should never happen.")
+      }.asJava)
+  }
+
+  override def validate(tableEnv: TableEnvironment): LogicalNode = {
+    if (tableEnv.isInstanceOf[StreamTableEnvironment]) {
+      failValidation(s"Aggregate on stream tables is currently not supported.")
+    }
+
+    val resolvedAggregate = super.validate(tableEnv).asInstanceOf[Aggregate]
+    val groupingExprs = resolvedAggregate.groupingExpressions
+    val aggregateExprs = resolvedAggregate.aggregateExpressions
+    aggregateExprs.foreach(validateAggregateExpression)
+    groupingExprs.foreach(validateGroupingExpression)
+
+    def validateAggregateExpression(expr: Expression): Unit = expr match {
+      // check no nested aggregation exists.
+      case aggExpr: Aggregation =>
+        aggExpr.children.foreach { child =>
+          child.preOrderVisit {
+            case agg: Aggregation =>
+              failValidation(
+                "It's not allowed to use an aggregate function as " +
+                  "input of another aggregate function")
+            case _ => // OK
+          }
+        }
+      // a bare attribute must appear in the group-by clause
+      case a: Attribute if !groupingExprs.exists(_.checkEquals(a)) =>
+        failValidation(
+          s"expression '$a' is invalid because it is neither" +
+            " present in group by nor an aggregate function")
+      case e if groupingExprs.exists(_.checkEquals(e)) => // OK
+      case e => e.children.foreach(validateAggregateExpression)
+    }
+
+    def validateGroupingExpression(expr: Expression): Unit = {
+      if (!expr.resultType.isKeyType) {
+        failValidation(
+          s"expression $expr cannot be used as a grouping expression " +
+            "because it's not a valid key type which must be hashable and comparable")
+      }
+    }
+    resolvedAggregate
+  }
+}
+
+/**
+ * Logical set difference of two relations with identical schemas. Batch-only.
+ * `all = true` keeps duplicates (MINUS ALL semantics).
+ */
+case class Minus(left: LogicalNode, right: LogicalNode, all: Boolean) extends BinaryNode {
+  override def output: Seq[Attribute] = left.output
+
+  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
+    left.construct(relBuilder)
+    right.construct(relBuilder)
+    relBuilder.minus(all)
+  }
+
+  override def validate(tableEnv: TableEnvironment): LogicalNode = {
+    if (tableEnv.isInstanceOf[StreamTableEnvironment]) {
+      failValidation(s"Minus on stream tables is currently not supported.")
+    }
+
+    val resolvedMinus = super.validate(tableEnv).asInstanceOf[Minus]
+    if (left.output.length != right.output.length) {
+      failValidation(s"Minus two table of different column sizes:" +
+        s" ${left.output.size} and ${right.output.size}")
+    }
+    // column names may differ; only the field types must match position-wise
+    val sameSchema = left.output.zip(right.output).forall { case (l, r) =>
+      l.resultType == r.resultType
+    }
+    if (!sameSchema) {
+      failValidation(s"Minus two table of different schema:" +
+        s" [${left.output.map(a => (a.name, a.resultType)).mkString(", ")}] and" +
+        s" [${right.output.map(a => (a.name, a.resultType)).mkString(", ")}]")
+    }
+    resolvedMinus
+  }
+}
+
+/**
+ * Logical union of two relations with identical schemas.
+ * On streams only UNION ALL (`all = true`) is supported.
+ */
+case class Union(left: LogicalNode, right: LogicalNode, all: Boolean) extends BinaryNode {
+  override def output: Seq[Attribute] = left.output
+
+  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
+    left.construct(relBuilder)
+    right.construct(relBuilder)
+    relBuilder.union(all)
+  }
+
+  override def validate(tableEnv: TableEnvironment): LogicalNode = {
+    if (tableEnv.isInstanceOf[StreamTableEnvironment] && !all) {
+      failValidation(s"Union on stream tables is currently not supported.")
+    }
+
+    val resolvedUnion = super.validate(tableEnv).asInstanceOf[Union]
+    if (left.output.length != right.output.length) {
+      failValidation(s"Union two tables of different column sizes:" +
+        s" ${left.output.size} and ${right.output.size}")
+    }
+    // column names may differ; only the field types must match position-wise
+    val sameSchema = left.output.zip(right.output).forall { case (l, r) =>
+      l.resultType == r.resultType
+    }
+    if (!sameSchema) {
+      failValidation(s"Union two tables of different schema:" +
+        s" [${left.output.map(a => (a.name, a.resultType)).mkString(", ")}] and" +
+        s" [${right.output.map(a => (a.name, a.resultType)).mkString(", ")}]")
+    }
+    resolvedUnion
+  }
+}
+
+/**
+ * Logical set intersection of two relations with identical schemas. Batch-only.
+ * `all = true` keeps duplicates (INTERSECT ALL semantics).
+ */
+case class Intersect(left: LogicalNode, right: LogicalNode, all: Boolean) extends BinaryNode {
+  override def output: Seq[Attribute] = left.output
+
+  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
+    left.construct(relBuilder)
+    right.construct(relBuilder)
+    relBuilder.intersect(all)
+  }
+
+  override def validate(tableEnv: TableEnvironment): LogicalNode = {
+    if (tableEnv.isInstanceOf[StreamTableEnvironment]) {
+      failValidation(s"Intersect on stream tables is currently not supported.")
+    }
+
+    val resolvedIntersect = super.validate(tableEnv).asInstanceOf[Intersect]
+    if (left.output.length != right.output.length) {
+      failValidation(s"Intersect two tables of different column sizes:" +
+        s" ${left.output.size} and ${right.output.size}")
+    }
+    // allow different column names between tables
+    val sameSchema = left.output.zip(right.output).forall { case (l, r) =>
+      l.resultType == r.resultType
+    }
+    if (!sameSchema) {
+      failValidation(s"Intersect two tables of different schema:" +
+        s" [${left.output.map(a => (a.name, a.resultType)).mkString(", ")}] and" +
+        s" [${right.output.map(a => (a.name, a.resultType)).mkString(", ")}]")
+    }
+    resolvedIntersect
+  }
+}
+
+/**
+ * Logical join of two relations.
+ *
+ * @param left left join input
+ * @param right right join input
+ * @param joinType type of the join (inner / outer variants)
+ * @param condition optional join predicate; must contain at least one equi-join
+ * @param correlated true if the right side is correlated with the left
+ */
+case class Join(
+    left: LogicalNode,
+    right: LogicalNode,
+    joinType: JoinType,
+    condition: Option[Expression],
+    correlated: Boolean) extends BinaryNode {
+
+  override def output: Seq[Attribute] = {
+    left.output ++ right.output
+  }
+
+  // field reference that remembers which join input it stems from and carries the
+  // index offset for right-side fields in the combined row
+  private case class JoinFieldReference(
+    name: String,
+    resultType: TypeInformation[_],
+    left: LogicalNode,
+    right: LogicalNode) extends Attribute {
+
+    val isFromLeftInput = left.output.map(_.name).contains(name)
+
+    val (indexInInput, indexInJoin) = if (isFromLeftInput) {
+      val indexInLeft = left.output.map(_.name).indexOf(name)
+      (indexInLeft, indexInLeft)
+    } else {
+      val indexInRight = right.output.map(_.name).indexOf(name)
+      (indexInRight, indexInRight + left.output.length)
+    }
+
+    override def toString = s"'$name"
+
+    override def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
+      // look up type of field
+      val fieldType = relBuilder.field(2, if (isFromLeftInput) 0 else 1, name).getType
+      // create a new RexInputRef with index offset
+      new RexInputRef(indexInJoin, fieldType)
+    }
+
+    override def withName(newName: String): Attribute = {
+      if (newName == name) {
+        this
+      } else {
+        JoinFieldReference(newName, resultType, left, right)
+      }
+    }
+  }
+
+  override def resolveExpressions(tableEnv: TableEnvironment): LogicalNode = {
+    val node = super.resolveExpressions(tableEnv).asInstanceOf[Join]
+    // replace plain field references in the condition by join-aware references
+    val partialFunction: PartialFunction[Expression, Expression] = {
+      case field: ResolvedFieldReference => JoinFieldReference(
+        field.name,
+        field.resultType,
+        left,
+        right)
+    }
+    val resolvedCondition = node.condition.map(_.postOrderTransform(partialFunction))
+    Join(node.left, node.right, node.joinType, resolvedCondition, correlated)
+  }
+
+  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
+    left.construct(relBuilder)
+    right.construct(relBuilder)
+
+    val corSet = mutable.Set[CorrelationId]()
+    if (correlated) {
+      corSet += relBuilder.peek().getCluster.createCorrel()
+    }
+
+    relBuilder.join(
+      TypeConverter.flinkJoinTypeToRelType(joinType),
+      // a missing condition translates to a cross join (condition TRUE)
+      condition.map(_.toRexNode(relBuilder)).getOrElse(relBuilder.literal(true)),
+      corSet.asJava)
+  }
+
+  private def ambiguousName: Set[String] =
+    left.output.map(_.name).toSet.intersect(right.output.map(_.name).toSet)
+
+  override def validate(tableEnv: TableEnvironment): LogicalNode = {
+    if (tableEnv.isInstanceOf[StreamTableEnvironment]
+      && !right.isInstanceOf[LogicalTableFunctionCall]) {
+      failValidation(s"Join on stream tables is currently not supported.")
+    }
+
+    val resolvedJoin = super.validate(tableEnv).asInstanceOf[Join]
+    if (!resolvedJoin.condition.forall(_.resultType == BOOLEAN_TYPE_INFO)) {
+      // fixed copy-paste error: the message used to say "Filter operator" and
+      // printed the join type instead of the condition's result type
+      failValidation(s"Join operator requires a boolean expression as condition, " +
+        s"but ${resolvedJoin.condition} is of type ${resolvedJoin.condition.map(_.resultType)}")
+    } else if (ambiguousName.nonEmpty) {
+      failValidation(s"join relations with ambiguous names: ${ambiguousName.mkString(", ")}")
+    }
+
+    resolvedJoin.condition.foreach(testJoinCondition)
+    resolvedJoin
+  }
+
+  private def testJoinCondition(expression: Expression): Unit = {
+
+    // a comparison qualifies as join predicate only if it compares a field of the
+    // left input with a field of the right input
+    def checkIfJoinCondition(exp: BinaryComparison) = exp.children match {
+      case (x: JoinFieldReference) :: (y: JoinFieldReference) :: Nil
+        if x.isFromLeftInput != y.isFromLeftInput => Unit
+      case x => failValidation(
+        s"Invalid non-join predicate $exp. For non-join predicates use Table#where.")
+    }
+
+    // at least one equality predicate must appear on an AND-only path of the condition
+    var equiJoinFound = false
+    def validateConditions(exp: Expression, isAndBranch: Boolean): Unit = exp match {
+      case x: And => x.children.foreach(validateConditions(_, isAndBranch))
+      case x: Or => x.children.foreach(validateConditions(_, isAndBranch = false))
+      case x: EqualTo =>
+        if (isAndBranch) {
+          equiJoinFound = true
+        }
+        checkIfJoinCondition(x)
+      case x: BinaryComparison => checkIfJoinCondition(x)
+      case x => failValidation(
+        s"Unsupported condition type: ${x.getClass.getSimpleName}. Condition: $x")
+    }
+
+    validateConditions(expression, isAndBranch = true)
+    if (!equiJoinFound) {
+      failValidation(s"Invalid join condition: $expression. At least one equi-join required.")
+    }
+  }
+}
+
+/** Leaf node that scans a table registered in the catalog under `tableName`. */
+case class CatalogNode(
+    tableName: String,
+    rowType: RelDataType) extends LeafNode {
+
+  // expose every catalog column as a resolved field reference
+  val output: Seq[Attribute] =
+    for (field <- rowType.getFieldList.asScala)
+      yield ResolvedFieldReference(field.getName, FlinkTypeFactory.toTypeInfo(field.getType))
+
+  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder =
+    relBuilder.scan(tableName)
+
+  // catalog tables are always valid as-is
+  override def validate(tableEnv: TableEnvironment): LogicalNode = this
+}
+
+/**
+ * Wrapper for valid logical plans generated from SQL String.
+ */
+case class LogicalRelNode(
+    relNode: RelNode) extends LeafNode {
+
+  // derive the output attributes from the wrapped plan's row type
+  val output: Seq[Attribute] =
+    for (field <- relNode.getRowType.getFieldList.asScala)
+      yield ResolvedFieldReference(field.getName, FlinkTypeFactory.toTypeInfo(field.getType))
+
+  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder =
+    relBuilder.push(relNode)
+
+  // the wrapped plan was already validated by the SQL planner
+  override def validate(tableEnv: TableEnvironment): LogicalNode = this
+}
+
+/**
+ * Logical window aggregation: groups `child` by `groupingExpressions` within `window`
+ * and evaluates `aggregateExpressions` plus window `propertyExpressions` (start/end).
+ */
+case class WindowAggregate(
+    groupingExpressions: Seq[Expression],
+    window: LogicalWindow,
+    propertyExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[NamedExpression],
+    child: LogicalNode)
+  extends UnaryNode {
+
+  override def output: Seq[Attribute] = {
+    // grouping fields, then aggregates, then window properties; unnamed get auto aliases
+    (groupingExpressions ++ aggregateExpressions ++ propertyExpressions) map {
+      case ne: NamedExpression => ne.toAttribute
+      case e => Alias(e, e.toString).toAttribute
+    }
+  }
+
+  // resolve references of this operator's parameters
+  override def resolveReference(
+      tableEnv: TableEnvironment,
+      name: String)
+    : Option[NamedExpression] = tableEnv match {
+    // resolve reference to rowtime attribute in a streaming environment
+    case _: StreamTableEnvironment if name == "rowtime" =>
+      Some(RowtimeAttribute())
+    case _ =>
+      window.alias match {
+        // resolve reference to this window's alias
+        case Some(UnresolvedFieldReference(alias)) if name == alias =>
+          // check if reference can already be resolved by input fields
+          val found = super.resolveReference(tableEnv, name)
+          if (found.isDefined) {
+            failValidation(s"Reference $name is ambiguous.")
+          } else {
+            Some(WindowReference(name))
+          }
+        case _ =>
+          // resolve references as usual
+          super.resolveReference(tableEnv, name)
+      }
+  }
+
+  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
+    // window aggregation needs Flink's extended builder, not the plain Calcite one
+    val flinkRelBuilder = relBuilder.asInstanceOf[FlinkRelBuilder]
+    child.construct(flinkRelBuilder)
+    flinkRelBuilder.aggregate(
+      window,
+      relBuilder.groupKey(groupingExpressions.map(_.toRexNode(relBuilder)).asJava),
+      propertyExpressions.map {
+        case Alias(prop: WindowProperty, name, _) => prop.toNamedWindowProperty(name)(relBuilder)
+        case _ => throw new RuntimeException("This should never happen.")
+      },
+      aggregateExpressions.map {
+        case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder)
+        case _ => throw new RuntimeException("This should never happen.")
+      }.asJava)
+  }
+
+  override def validate(tableEnv: TableEnvironment): LogicalNode = {
+    val resolvedWindowAggregate = super.validate(tableEnv).asInstanceOf[WindowAggregate]
+    val groupingExprs = resolvedWindowAggregate.groupingExpressions
+    val aggregateExprs = resolvedWindowAggregate.aggregateExpressions
+    aggregateExprs.foreach(validateAggregateExpression)
+    groupingExprs.foreach(validateGroupingExpression)
+
+    def validateAggregateExpression(expr: Expression): Unit = expr match {
+      // check no nested aggregation exists.
+      case aggExpr: Aggregation =>
+        aggExpr.children.foreach { child =>
+          child.preOrderVisit {
+            case agg: Aggregation =>
+              failValidation(
+                "It's not allowed to use an aggregate function as " +
+                  "input of another aggregate function")
+            case _ => // ok
+          }
+        }
+      // a bare attribute must appear in the group-by clause
+      case a: Attribute if !groupingExprs.exists(_.checkEquals(a)) =>
+        failValidation(
+          s"Expression '$a' is invalid because it is neither" +
+            " present in group by nor an aggregate function")
+      case e if groupingExprs.exists(_.checkEquals(e)) => // ok
+      case e => e.children.foreach(validateAggregateExpression)
+    }
+
+    def validateGroupingExpression(expr: Expression): Unit = {
+      if (!expr.resultType.isKeyType) {
+        failValidation(
+          s"Expression $expr cannot be used as a grouping expression " +
+            "because it's not a valid key type which must be hashable and comparable")
+      }
+    }
+
+    // validate window
+    resolvedWindowAggregate.window.validate(tableEnv) match {
+      case ValidationFailure(msg) =>
+        failValidation(s"$window is invalid: $msg")
+      case ValidationSuccess => // ok
+    }
+
+    resolvedWindowAggregate
+  }
+}
+
+/**
+ * LogicalNode for calling a user-defined table functions.
+ *
+ * @param functionName function name
+ * @param tableFunction table function to be called (might be overloaded)
+ * @param parameters actual parameters
+ * @param fieldNames output field names
+ * @param child child logical node
+ */
+case class LogicalTableFunctionCall(
+    functionName: String,
+    tableFunction: TableFunction[_],
+    parameters: Seq[Expression],
+    resultType: TypeInformation[_],
+    fieldNames: Array[String],
+    child: LogicalNode)
+  extends UnaryNode {
+
+  // field indexes and types derived once from the declared result type
+  private val (_, fieldIndexes, fieldTypes) = getFieldInfo(resultType)
+  // resolved eval() method; assigned in validate() once a matching signature is found
+  private var evalMethod: Method = _
+
+  override def output: Seq[Attribute] = fieldNames.zip(fieldTypes).map {
+    case (n, t) => ResolvedFieldReference(n, t)
+  }
+
+  override def validate(tableEnv: TableEnvironment): LogicalNode = {
+    val node = super.validate(tableEnv).asInstanceOf[LogicalTableFunctionCall]
+    // check if not Scala object
+    checkNotSingleton(tableFunction.getClass)
+    // check if class could be instantiated
+    checkForInstantiation(tableFunction.getClass)
+    // look for a signature that matches the input types
+    val signature = node.parameters.map(_.resultType)
+    val foundMethod = getEvalMethod(tableFunction, signature)
+    if (foundMethod.isEmpty) {
+      failValidation(
+        s"Given parameters of function '$functionName' do not match any signature. \n" +
+          s"Actual: ${signatureToString(signature)} \n" +
+          s"Expected: ${signaturesToString(tableFunction)}")
+    } else {
+      node.evalMethod = foundMethod.get
+    }
+    node
+  }
+
+  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
+    // reuse the field indexes computed at construction time instead of recomputing
+    // them via getFieldInfo(resultType)._2 (redundant work on every construct call)
+    val function = new FlinkTableFunctionImpl(resultType, fieldIndexes, fieldNames, evalMethod)
+    val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
+    val sqlFunction = TableSqlFunction(
+      tableFunction.toString,
+      tableFunction,
+      resultType,
+      typeFactory,
+      function)
+
+    val scan = LogicalTableFunctionScan.create(
+      relBuilder.peek().getCluster,
+      new util.ArrayList[RelNode](),
+      relBuilder.call(sqlFunction, parameters.map(_.toRexNode(relBuilder)).asJava),
+      function.getElementType(null),
+      function.getRowType(relBuilder.getTypeFactory, null),
+      null)
+
+    relBuilder.push(scan)
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/rel/LogicalWindowAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/rel/LogicalWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/rel/LogicalWindowAggregate.scala
new file mode 100644
index 0000000..d0d9af4
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/rel/LogicalWindowAggregate.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.logical.rel
+
+import java.util
+
+import org.apache.calcite.plan.{Convention, RelOptCluster, RelTraitSet}
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.core.{Aggregate, AggregateCall}
+import org.apache.calcite.rel.{RelNode, RelShuttle}
+import org.apache.calcite.util.ImmutableBitSet
+import org.apache.flink.table.calcite.{FlinkTypeFactory, FlinkRelBuilder}
+import FlinkRelBuilder.NamedWindowProperty
+import org.apache.flink.table.plan.logical.LogicalWindow
+
+/**
+  * Logical aggregation that additionally carries a [[LogicalWindow]] and the named
+  * window properties produced by the aggregation. Extends Calcite's [[Aggregate]];
+  * the window information survives [[copy]] and the property fields are appended
+  * to the row type in [[deriveRowType]].
+  */
+class LogicalWindowAggregate(
+ window: LogicalWindow,
+ namedProperties: Seq[NamedWindowProperty],
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ child: RelNode,
+ indicator: Boolean,
+ groupSet: ImmutableBitSet,
+ groupSets: util.List[ImmutableBitSet],
+ aggCalls: util.List[AggregateCall])
+ extends Aggregate(
+ cluster,
+ traitSet,
+ child,
+ indicator,
+ groupSet,
+ groupSets,
+ aggCalls) {
+
+ // Accessor for the window definition attached to this aggregate.
+ def getWindow = window
+
+ // Accessor for the named window properties attached to this aggregate.
+ def getNamedProperties = namedProperties
+
+ // Copies this node with new traits/input while retaining the original
+ // window and named properties (required by Calcite's planner contract).
+ override def copy(
+ traitSet: RelTraitSet,
+ input: RelNode,
+ indicator: Boolean,
+ groupSet: ImmutableBitSet,
+ groupSets: util.List[ImmutableBitSet],
+ aggCalls: util.List[AggregateCall])
+ : Aggregate = {
+
+ new LogicalWindowAggregate(
+ window,
+ namedProperties,
+ cluster,
+ traitSet,
+ input,
+ indicator,
+ groupSet,
+ groupSets,
+ aggCalls)
+ }
+
+ override def accept(shuttle: RelShuttle): RelNode = shuttle.visit(this)
+
+ // Row type = the standard aggregate row type plus one trailing field per
+ // named window property, typed via the Flink type factory.
+ override def deriveRowType(): RelDataType = {
+ val aggregateRowType = super.deriveRowType()
+ val typeFactory = getCluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
+ val builder = typeFactory.builder
+ builder.addAll(aggregateRowType.getFieldList)
+ namedProperties.foreach { namedProp =>
+ builder.add(
+ namedProp.name,
+ typeFactory.createTypeFromTypeInfo(namedProp.property.resultType)
+ )
+ }
+ builder.build()
+ }
+}
+
+object LogicalWindowAggregate {
+
+ /**
+   * Factory method: wraps an existing plain [[Aggregate]] into a
+   * [[LogicalWindowAggregate]], attaching the given window and named
+   * properties. The new node reuses the aggregate's cluster, input,
+   * group sets and calls, with a Convention.NONE trait set.
+   */
+ def create(
+ window: LogicalWindow,
+ namedProperties: Seq[NamedWindowProperty],
+ aggregate: Aggregate)
+ : LogicalWindowAggregate = {
+
+ val cluster: RelOptCluster = aggregate.getCluster
+ val traitSet: RelTraitSet = cluster.traitSetOf(Convention.NONE)
+ new LogicalWindowAggregate(
+ window,
+ namedProperties,
+ cluster,
+ traitSet,
+ aggregate.getInput,
+ aggregate.indicator,
+ aggregate.getGroupSet,
+ aggregate.getGroupSets,
+ aggregate.getAggCallList)
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkAggregate.scala
new file mode 100644
index 0000000..7290594
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkAggregate.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.nodes
+
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.core.AggregateCall
+import org.apache.flink.table.calcite.FlinkRelBuilder
+import FlinkRelBuilder.NamedWindowProperty
+import org.apache.flink.table.runtime.aggregate.AggregateUtil._
+
+import scala.collection.JavaConverters._
+
+/** String-rendering helpers shared by Flink aggregate nodes. */
+trait FlinkAggregate {
+
+ // Renders the grouping keys: each grouping index mapped to its input field
+ // name, comma-separated.
+ private[flink] def groupingToString(inputType: RelDataType, grouping: Array[Int]): String = {
+
+ val inFields = inputType.getFieldNames.asScala
+ grouping.map( inFields(_) ).mkString(", ")
+ }
+
+ // Renders grouping fields, aggregate calls and window properties aligned with
+ // the output row's field names; emits "x AS y" when the rendered expression
+ // differs from the output name. Zero-argument calls are shown as "*".
+ // NOTE(review): only the first argument of an aggregate call is printed —
+ // presumably no multi-argument aggregations exist at this point; confirm.
+ private[flink] def aggregationToString(
+ inputType: RelDataType,
+ grouping: Array[Int],
+ rowType: RelDataType,
+ namedAggregates: Seq[CalcitePair[AggregateCall, String]],
+ namedProperties: Seq[NamedWindowProperty])
+ : String = {
+
+ val inFields = inputType.getFieldNames.asScala
+ val outFields = rowType.getFieldNames.asScala
+
+ val groupStrings = grouping.map( inFields(_) )
+
+ val aggs = namedAggregates.map(_.getKey)
+ val aggStrings = aggs.map( a => s"${a.getAggregation}(${
+ if (a.getArgList.size() > 0) {
+ inFields(a.getArgList.get(0))
+ } else {
+ "*"
+ }
+ })")
+
+ val propStrings = namedProperties.map(_.property.toString)
+
+ // Zip the rendered expressions (in output order: groups, aggs, properties)
+ // with the output field names.
+ (groupStrings ++ aggStrings ++ propStrings).zip(outFields).map {
+ case (f, o) => if (f == o) {
+ f
+ } else {
+ s"$f AS $o"
+ }
+ }.mkString(", ")
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCalc.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCalc.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCalc.scala
new file mode 100644
index 0000000..5ebd3ee
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCalc.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.nodes
+
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rex.{RexNode, RexProgram}
+import org.apache.flink.api.common.functions.{FlatMapFunction, RichFlatMapFunction}
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.table.api.TableConfig
+import org.apache.flink.table.codegen.{CodeGenerator, GeneratedFunction}
+import org.apache.flink.table.runtime.FlatMapRunner
+import org.apache.flink.table.typeutils.TypeConverter._
+
+import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+
+/** Code-generation and string-rendering helpers shared by Flink Calc nodes. */
+trait FlinkCalc {
+
+ // Generates the body code (as a String) of a flat-map style function that
+ // implements the given RexProgram: any combination of projection and filter.
+ // The generated code emits results through the generator's collector term.
+ private[flink] def functionBody(
+ generator: CodeGenerator,
+ inputType: TypeInformation[Any],
+ rowType: RelDataType,
+ calcProgram: RexProgram,
+ config: TableConfig,
+ expectedType: Option[TypeInformation[Any]]): String = {
+
+ // Output type, resolved from the row type and the caller's expectation.
+ val returnType = determineReturnType(
+ rowType,
+ expectedType,
+ config.getNullCheck,
+ config.getEfficientTypeUsage)
+
+ val condition = calcProgram.getCondition
+ // Local refs must be expanded before code generation.
+ val expandedExpressions = calcProgram.getProjectList.map(
+ expr => calcProgram.expandLocalRef(expr))
+ val projection = generator.generateResultExpression(
+ returnType,
+ rowType.getFieldNames,
+ expandedExpressions)
+
+ // only projection
+ if (condition == null) {
+ s"""
+ |${projection.code}
+ |${generator.collectorTerm}.collect(${projection.resultTerm});
+ |""".stripMargin
+ }
+ else {
+ val filterCondition = generator.generateExpression(
+ calcProgram.expandLocalRef(calcProgram.getCondition))
+ // only filter
+ // NOTE(review): projection is assigned unconditionally above, so this null
+ // check looks unreachable unless generateResultExpression can return null —
+ // confirm against CodeGenerator.
+ if (projection == null) {
+ // conversion
+ if (inputType != returnType) {
+ val conversion = generator.generateConverterResultExpression(
+ returnType,
+ rowType.getFieldNames)
+
+ s"""
+ |${filterCondition.code}
+ |if (${filterCondition.resultTerm}) {
+ | ${conversion.code}
+ | ${generator.collectorTerm}.collect(${conversion.resultTerm});
+ |}
+ |""".stripMargin
+ }
+ // no conversion
+ else {
+ s"""
+ |${filterCondition.code}
+ |if (${filterCondition.resultTerm}) {
+ | ${generator.collectorTerm}.collect(${generator.input1Term});
+ |}
+ |""".stripMargin
+ }
+ }
+ // both filter and projection
+ else {
+ s"""
+ |${filterCondition.code}
+ |if (${filterCondition.resultTerm}) {
+ | ${projection.code}
+ | ${generator.collectorTerm}.collect(${projection.resultTerm});
+ |}
+ |""".stripMargin
+ }
+ }
+ }
+
+ // Wraps a generated FlatMapFunction (name + source code) into a FlatMapRunner
+ // that compiles and runs it.
+ private[flink] def calcMapFunction(
+ genFunction: GeneratedFunction[FlatMapFunction[Any, Any]]): RichFlatMapFunction[Any, Any] = {
+
+ new FlatMapRunner[Any, Any](
+ genFunction.name,
+ genFunction.code,
+ genFunction.returnType)
+ }
+
+ // Renders the program's filter condition via the supplied expression printer,
+ // or returns the empty string when the program has no condition.
+ private[flink] def conditionToString(
+ calcProgram: RexProgram,
+ expression: (RexNode, List[String], Option[List[RexNode]]) => String): String = {
+
+ val cond = calcProgram.getCondition
+ val inFields = calcProgram.getInputRowType.getFieldNames.asScala.toList
+ val localExprs = calcProgram.getExprList.asScala.toList
+
+ if (cond != null) {
+ expression(cond, inFields, Some(localExprs))
+ } else {
+ ""
+ }
+ }
+
+ // Renders the projection list, pairing each rendered expression with its
+ // output field name and emitting "expr AS name" when the two differ.
+ private[flink] def selectionToString(
+ calcProgram: RexProgram,
+ expression: (RexNode, List[String], Option[List[RexNode]]) => String): String = {
+
+ val proj = calcProgram.getProjectList.asScala.toList
+ val inFields = calcProgram.getInputRowType.getFieldNames.asScala.toList
+ val localExprs = calcProgram.getExprList.asScala.toList
+ val outFields = calcProgram.getOutputRowType.getFieldNames.asScala.toList
+
+ proj
+ .map(expression(_, inFields, Some(localExprs)))
+ .zip(outFields).map { case (e, o) => {
+ if (e != o) {
+ e + " AS " + o
+ } else {
+ e
+ }
+ }
+ }.mkString(", ")
+ }
+
+ // Combines the condition and selection renderings into a single operator
+ // description: "where: (...), select: (...)" (the where-part is omitted when
+ // the program has no condition).
+ private[flink] def calcOpName(
+ calcProgram: RexProgram,
+ expression: (RexNode, List[String], Option[List[RexNode]]) => String) = {
+
+ val conditionStr = conditionToString(calcProgram, expression)
+ val selectionStr = selectionToString(calcProgram, expression)
+
+ s"${if (calcProgram.getCondition != null) {
+ s"where: ($conditionStr), "
+ } else {
+ ""
+ }}select: ($selectionStr)"
+ }
+
+ // Full display string for a Calc node: the operator description wrapped in
+ // "Calc(...)".
+ private[flink] def calcToString(
+ calcProgram: RexProgram,
+ expression: (RexNode, List[String], Option[List[RexNode]]) => String) = {
+
+ val name = calcOpName(calcProgram, expression)
+ s"Calc($name)"
+ }
+}