Posted to commits@flink.apache.org by tw...@apache.org on 2016/12/07 15:57:21 UTC
[2/5] flink git commit: [FLINK-4469] [table] Add support for user defined table function in Table API & SQL
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala
index e7416f7..932baeb 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -19,14 +19,18 @@
package org.apache.flink.api.table.functions.utils
+import java.lang.reflect.{Method, Modifier}
import java.sql.{Date, Time, Timestamp}
import com.google.common.primitives.Primitives
+import org.apache.calcite.sql.SqlFunction
import org.apache.flink.api.common.functions.InvalidTypesException
-import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
+import org.apache.flink.api.common.typeutils.CompositeType
import org.apache.flink.api.java.typeutils.TypeExtractor
-import org.apache.flink.api.table.ValidationException
-import org.apache.flink.api.table.functions.{ScalarFunction, UserDefinedFunction}
+import org.apache.flink.api.table.{FlinkTypeFactory, TableException, ValidationException}
+import org.apache.flink.api.table.functions.{ScalarFunction, TableFunction, UserDefinedFunction}
+import org.apache.flink.api.table.plan.schema.FlinkTableFunctionImpl
import org.apache.flink.util.InstantiationUtil
object UserDefinedFunctionUtils {
@@ -62,101 +66,167 @@ object UserDefinedFunctionUtils {
.getOrElse(throw ValidationException("Function class needs a default constructor."))
}
+ /**
+ * Checks whether the given class is a Scala object. Implementing a [[TableFunction]]
+ * as a Scala object is forbidden because the shared singleton instance poses concurrency risks.
+ */
+ def checkNotSingleton(clazz: Class[_]): Unit = {
+ // TODO checking for MODULE$ is not a robust way to detect a singleton; improve it further.
+ if (clazz.getFields.map(_.getName) contains "MODULE$") {
+ throw new ValidationException(
+ s"TableFunction implemented by class ${clazz.getCanonicalName} " +
+ s"is a Scala object, it is forbidden since concurrent risks.")
+ }
+ }
+
// ----------------------------------------------------------------------------------------------
- // Utilities for ScalarFunction
+ // Utilities for eval methods
// ----------------------------------------------------------------------------------------------
/**
- * Prints one signature consisting of classes.
+ * Returns signatures matching the given signature of [[TypeInformation]].
+ * Elements of the signature can be null (a null element acts as a wildcard).
*/
- def signatureToString(signature: Array[Class[_]]): String =
- "(" + signature.map { clazz =>
- if (clazz == null) {
- "null"
- } else {
- clazz.getCanonicalName
- }
- }.mkString(", ") + ")"
+ def getSignature(
+ function: UserDefinedFunction,
+ signature: Seq[TypeInformation[_]])
+ : Option[Array[Class[_]]] = {
+ // We compare the raw Java classes not the TypeInformation.
+ // TypeInformation does not matter during runtime (e.g. within a MapFunction).
+ val actualSignature = typeInfoToClass(signature)
+ val signatures = getSignatures(function)
+
+ signatures
+ // go over all signatures and find one matching actual signature
+ .find { curSig =>
+ // match parameters of signature to actual parameters
+ actualSignature.length == curSig.length &&
+ curSig.zipWithIndex.forall { case (clazz, i) =>
+ parameterTypeEquals(actualSignature(i), clazz)
+ }
+ }
+ }
/**
- * Prints one signature consisting of TypeInformation.
+ * Returns eval method matching the given signature of [[TypeInformation]].
*/
- def signatureToString(signature: Seq[TypeInformation[_]]): String = {
- signatureToString(typeInfoToClass(signature))
+ def getEvalMethod(
+ function: UserDefinedFunction,
+ signature: Seq[TypeInformation[_]])
+ : Option[Method] = {
+ // We compare the raw Java classes not the TypeInformation.
+ // TypeInformation does not matter during runtime (e.g. within a MapFunction).
+ val actualSignature = typeInfoToClass(signature)
+ val evalMethods = checkAndExtractEvalMethods(function)
+
+ evalMethods
+ // go over all eval methods and find one matching
+ .find { cur =>
+ val signatures = cur.getParameterTypes
+ // match parameters of signature to actual parameters
+ actualSignature.length == signatures.length &&
+ signatures.zipWithIndex.forall { case (clazz, i) =>
+ parameterTypeEquals(actualSignature(i), clazz)
+ }
+ }
}
/**
- * Extracts type classes of [[TypeInformation]] in a null-aware way.
+ * Extracts "eval" methods and throws a [[ValidationException]] if no implementation
+ * can be found.
*/
- def typeInfoToClass(typeInfos: Seq[TypeInformation[_]]): Array[Class[_]] =
- typeInfos.map { typeInfo =>
- if (typeInfo == null) {
- null
- } else {
- typeInfo.getTypeClass
+ def checkAndExtractEvalMethods(function: UserDefinedFunction): Array[Method] = {
+ val methods = function
+ .getClass
+ .getDeclaredMethods
+ .filter { m =>
+ val modifiers = m.getModifiers
+ m.getName == "eval" && Modifier.isPublic(modifiers) && !Modifier.isAbstract(modifiers)
}
- }.toArray
+ if (methods.isEmpty) {
+ throw new ValidationException(
+ s"Function class '${function.getClass.getCanonicalName}' does not implement at least " +
+ s"one method named 'eval' which is public and not abstract.")
+ } else {
+ methods
+ }
+ }
+
+ def getSignatures(function: UserDefinedFunction): Array[Array[Class[_]]] = {
+ checkAndExtractEvalMethods(function).map(_.getParameterTypes)
+ }
+
+ // ----------------------------------------------------------------------------------------------
+ // Utilities for SQL functions
+ // ----------------------------------------------------------------------------------------------
/**
- * Compares parameter candidate classes with expected classes. If true, the parameters match.
- * Candidate can be null (acts as a wildcard).
+ * Creates a [[SqlFunction]] for a [[ScalarFunction]].
+ * @param name function name
+ * @param function scalar function
+ * @param typeFactory type factory
+ * @return the ScalarSqlFunction
*/
- def parameterTypeEquals(candidate: Class[_], expected: Class[_]): Boolean =
- candidate == null ||
- candidate == expected ||
- expected.isPrimitive && Primitives.wrap(expected) == candidate ||
- candidate == classOf[Date] && expected == classOf[Int] ||
- candidate == classOf[Time] && expected == classOf[Int] ||
- candidate == classOf[Timestamp] && expected == classOf[Long]
+ def createScalarSqlFunction(
+ name: String,
+ function: ScalarFunction,
+ typeFactory: FlinkTypeFactory)
+ : SqlFunction = {
+ new ScalarSqlFunction(name, function, typeFactory)
+ }
/**
- * Returns signatures matching the given signature of [[TypeInformation]].
- * Elements of the signature can be null (act as a wildcard).
+ * Creates a [[SqlFunction]] for every eval method of a [[TableFunction]].
+ * @param name function name
+ * @param tableFunction table function
+ * @param resultType the type information of the returned table
+ * @param typeFactory type factory
+ * @return the TableSqlFunction
*/
- def getSignature(
- scalarFunction: ScalarFunction,
- signature: Seq[TypeInformation[_]])
- : Option[Array[Class[_]]] = {
- // We compare the raw Java classes not the TypeInformation.
- // TypeInformation does not matter during runtime (e.g. within a MapFunction).
- val actualSignature = typeInfoToClass(signature)
+ def createTableSqlFunctions(
+ name: String,
+ tableFunction: TableFunction[_],
+ resultType: TypeInformation[_],
+ typeFactory: FlinkTypeFactory)
+ : Seq[SqlFunction] = {
+ val (fieldNames, fieldIndexes, _) = UserDefinedFunctionUtils.getFieldInfo(resultType)
+ val evalMethods = checkAndExtractEvalMethods(tableFunction)
- scalarFunction
- .getSignatures
- // go over all signatures and find one matching actual signature
- .find { curSig =>
- // match parameters of signature to actual parameters
- actualSignature.length == curSig.length &&
- curSig.zipWithIndex.forall { case (clazz, i) =>
- parameterTypeEquals(actualSignature(i), clazz)
- }
- }
+ evalMethods.map { method =>
+ val function = new FlinkTableFunctionImpl(resultType, fieldIndexes, fieldNames, method)
+ TableSqlFunction(name, tableFunction, resultType, typeFactory, function)
+ }
}
+ // ----------------------------------------------------------------------------------------------
+ // Utilities for scalar functions
+ // ----------------------------------------------------------------------------------------------
+
/**
* Internal method of [[ScalarFunction#getResultType()]] that does some pre-checking and uses
* [[TypeExtractor]] as default return type inference.
*/
def getResultType(
- scalarFunction: ScalarFunction,
+ function: ScalarFunction,
signature: Array[Class[_]])
: TypeInformation[_] = {
// find method for signature
- val evalMethod = scalarFunction.getEvalMethods
+ val evalMethod = checkAndExtractEvalMethods(function)
.find(m => signature.sameElements(m.getParameterTypes))
.getOrElse(throw new ValidationException("Given signature is invalid."))
- val userDefinedTypeInfo = scalarFunction.getResultType(signature)
+ val userDefinedTypeInfo = function.getResultType(signature)
if (userDefinedTypeInfo != null) {
- userDefinedTypeInfo
+ userDefinedTypeInfo
} else {
try {
TypeExtractor.getForClass(evalMethod.getReturnType)
} catch {
case ite: InvalidTypesException =>
- throw new ValidationException(s"Return type of scalar function '$this' cannot be " +
- s"automatically determined. Please provide type information manually.")
+ throw new ValidationException(
+ s"Return type of scalar function '${function.getClass.getCanonicalName}' cannot be " +
+ s"automatically determined. Please provide type information manually.")
}
}
}
@@ -165,21 +235,100 @@ object UserDefinedFunctionUtils {
* Returns the return type of the evaluation method matching the given signature.
*/
def getResultTypeClass(
- scalarFunction: ScalarFunction,
+ function: ScalarFunction,
signature: Array[Class[_]])
: Class[_] = {
// find method for signature
- val evalMethod = scalarFunction.getEvalMethods
+ val evalMethod = checkAndExtractEvalMethods(function)
.find(m => signature.sameElements(m.getParameterTypes))
.getOrElse(throw new IllegalArgumentException("Given signature is invalid."))
evalMethod.getReturnType
}
+ // ----------------------------------------------------------------------------------------------
+ // Miscellaneous
+ // ----------------------------------------------------------------------------------------------
+
/**
- * Prints all signatures of a [[ScalarFunction]].
+ * Returns field names and field positions for a given [[TypeInformation]].
+ *
+ * Field names are automatically extracted for
+ * [[org.apache.flink.api.common.typeutils.CompositeType]].
+ *
+ * @param inputType The TypeInformation to extract the field names and positions from.
+ * @return A tuple of two arrays holding the field names and corresponding field positions.
*/
- def signaturesToString(scalarFunction: ScalarFunction): String = {
- scalarFunction.getSignatures.map(signatureToString).mkString(", ")
+ def getFieldInfo(inputType: TypeInformation[_])
+ : (Array[String], Array[Int], Array[TypeInformation[_]]) = {
+
+ val fieldNames: Array[String] = inputType match {
+ case t: CompositeType[_] => t.getFieldNames
+ case a: AtomicType[_] => Array("f0")
+ case tpe =>
+ throw new TableException(s"Currently only support CompositeType and AtomicType. " +
+ s"Type $tpe lacks explicit field naming")
+ }
+ val fieldIndexes = fieldNames.indices.toArray
+ val fieldTypes: Array[TypeInformation[_]] = fieldNames.map { i =>
+ inputType match {
+ case t: CompositeType[_] => t.getTypeAt(i).asInstanceOf[TypeInformation[_]]
+ case a: AtomicType[_] => a.asInstanceOf[TypeInformation[_]]
+ case tpe =>
+ throw new TableException(s"Currently only support CompositeType and AtomicType.")
+ }
+ }
+ (fieldNames, fieldIndexes, fieldTypes)
}
+ /**
+ * Prints one signature consisting of classes.
+ */
+ def signatureToString(signature: Array[Class[_]]): String =
+ signature.map { clazz =>
+ if (clazz == null) {
+ "null"
+ } else {
+ clazz.getCanonicalName
+ }
+ }.mkString("(", ", ", ")")
+
+ /**
+ * Prints one signature consisting of TypeInformation.
+ */
+ def signatureToString(signature: Seq[TypeInformation[_]]): String = {
+ signatureToString(typeInfoToClass(signature))
+ }
+
+ /**
+ * Prints the signatures of all eval methods of a class.
+ */
+ def signaturesToString(function: UserDefinedFunction): String = {
+ getSignatures(function).map(signatureToString).mkString(", ")
+ }
+
+ /**
+ * Extracts type classes of [[TypeInformation]] in a null-aware way.
+ */
+ private def typeInfoToClass(typeInfos: Seq[TypeInformation[_]]): Array[Class[_]] =
+ typeInfos.map { typeInfo =>
+ if (typeInfo == null) {
+ null
+ } else {
+ typeInfo.getTypeClass
+ }
+ }.toArray
+
+
+ /**
+ * Compares a candidate parameter class with the expected class; returns true if they match.
+ * The candidate can be null (acts as a wildcard).
+ */
+ private def parameterTypeEquals(candidate: Class[_], expected: Class[_]): Boolean =
+ candidate == null ||
+ candidate == expected ||
+ expected.isPrimitive && Primitives.wrap(expected) == candidate ||
+ candidate == classOf[Date] && expected == classOf[Int] ||
+ candidate == classOf[Time] && expected == classOf[Int] ||
+ candidate == classOf[Timestamp] && expected == classOf[Long]
+
}
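Note: the utilities above resolve overloaded `eval` methods reflectively. checkAndExtractEvalMethods collects every public, non-abstract method named 'eval', and getEvalMethod picks the overload whose parameter classes match the actual argument types (a null argument type acts as a wildcard). A minimal sketch of a table function these utilities would operate on (the class below is illustrative only, not part of this commit):

  import org.apache.flink.api.table.functions.TableFunction

  // Illustrative only. Both eval overloads below are discovered by
  // checkAndExtractEvalMethods; getEvalMethod then selects the overload
  // whose parameter classes match the actual call.
  class SplitUDTF extends TableFunction[String] {
    def eval(str: String): Unit =
      str.split("#").foreach(collect)

    def eval(str: String, separator: String): Unit =
      str.split(separator).foreach(collect)
  }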
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
index cd22f6a..f6ddeef 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
@@ -122,10 +122,10 @@ object ProjectionTranslator {
case prop: WindowProperty =>
val name = propNames(prop)
Alias(UnresolvedFieldReference(name), tableEnv.createUniqueAttributeName())
- case n @ Alias(agg: Aggregation, name) =>
+ case n @ Alias(agg: Aggregation, name, _) =>
val aName = aggNames(agg)
Alias(UnresolvedFieldReference(aName), name)
- case n @ Alias(prop: WindowProperty, name) =>
+ case n @ Alias(prop: WindowProperty, name, _) =>
val pName = propNames(prop)
Alias(UnresolvedFieldReference(pName), name)
case l: LeafExpression => l
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala
index ecf1996..4dc2ab7 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala
@@ -17,9 +17,13 @@
*/
package org.apache.flink.api.table.plan.logical
+import java.lang.reflect.Method
+import java.util
+
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.RelDataType
-import org.apache.calcite.rel.logical.LogicalProject
+import org.apache.calcite.rel.core.CorrelationId
+import org.apache.calcite.rel.logical.{LogicalProject, LogicalTableFunctionScan}
import org.apache.calcite.rex.{RexInputRef, RexNode}
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
@@ -27,6 +31,10 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.operators.join.JoinType
import org.apache.flink.api.table._
import org.apache.flink.api.table.expressions._
+import org.apache.flink.api.table.functions.TableFunction
+import org.apache.flink.api.table.functions.utils.TableSqlFunction
+import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils._
+import org.apache.flink.api.table.plan.schema.FlinkTableFunctionImpl
import org.apache.flink.api.table.typeutils.TypeConverter
import org.apache.flink.api.table.validate.{ValidationFailure, ValidationSuccess}
@@ -216,7 +224,7 @@ case class Aggregate(
relBuilder.aggregate(
relBuilder.groupKey(groupingExpressions.map(_.toRexNode(relBuilder)).asJava),
aggregateExpressions.map {
- case Alias(agg: Aggregation, name) => agg.toAggCall(name)(relBuilder)
+ case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder)
case _ => throw new RuntimeException("This should never happen.")
}.asJava)
}
@@ -361,7 +369,8 @@ case class Join(
left: LogicalNode,
right: LogicalNode,
joinType: JoinType,
- condition: Option[Expression]) extends BinaryNode {
+ condition: Option[Expression],
+ correlated: Boolean) extends BinaryNode {
override def output: Seq[Attribute] = {
left.output ++ right.output
@@ -411,22 +420,31 @@ case class Join(
right)
}
val resolvedCondition = node.condition.map(_.postOrderTransform(partialFunction))
- Join(node.left, node.right, node.joinType, resolvedCondition)
+ Join(node.left, node.right, node.joinType, resolvedCondition, correlated)
}
override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
left.construct(relBuilder)
right.construct(relBuilder)
+
+ val corSet = mutable.Set[CorrelationId]()
+
+ if (correlated) {
+ corSet += relBuilder.peek().getCluster.createCorrel()
+ }
+
relBuilder.join(
TypeConverter.flinkJoinTypeToRelType(joinType),
- condition.map(_.toRexNode(relBuilder)).getOrElse(relBuilder.literal(true)))
+ condition.map(_.toRexNode(relBuilder)).getOrElse(relBuilder.literal(true)),
+ corSet.asJava)
}
private def ambiguousName: Set[String] =
left.output.map(_.name).toSet.intersect(right.output.map(_.name).toSet)
override def validate(tableEnv: TableEnvironment): LogicalNode = {
- if (tableEnv.isInstanceOf[StreamTableEnvironment]) {
+ if (tableEnv.isInstanceOf[StreamTableEnvironment]
+ && !right.isInstanceOf[LogicalTableFunctionCall]) {
failValidation(s"Join on stream tables is currently not supported.")
}
@@ -551,11 +569,11 @@ case class WindowAggregate(
window,
relBuilder.groupKey(groupingExpressions.map(_.toRexNode(relBuilder)).asJava),
propertyExpressions.map {
- case Alias(prop: WindowProperty, name) => prop.toNamedWindowProperty(name)(relBuilder)
+ case Alias(prop: WindowProperty, name, _) => prop.toNamedWindowProperty(name)(relBuilder)
case _ => throw new RuntimeException("This should never happen.")
},
aggregateExpressions.map {
- case Alias(agg: Aggregation, name) => agg.toAggCall(name)(relBuilder)
+ case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder)
case _ => throw new RuntimeException("This should never happen.")
}.asJava)
}
@@ -605,3 +623,71 @@ case class WindowAggregate(
resolvedWindowAggregate
}
}
+
+
+/**
+ * LogicalNode for calling a user-defined table function.
+ * @param functionName function name
+ * @param tableFunction table function to be called (might be overloaded)
+ * @param parameters actual parameters
+ * @param resultType type information of the rows produced by the table function
+ * @param fieldNames output field names
+ * @param child child logical node
+ */
+case class LogicalTableFunctionCall(
+ functionName: String,
+ tableFunction: TableFunction[_],
+ parameters: Seq[Expression],
+ resultType: TypeInformation[_],
+ fieldNames: Array[String],
+ child: LogicalNode)
+ extends UnaryNode {
+
+ val (_, fieldIndexes, fieldTypes) = getFieldInfo(resultType)
+ var evalMethod: Method = _
+
+ override def output: Seq[Attribute] = fieldNames.zip(fieldTypes).map {
+ case (n, t) => ResolvedFieldReference(n, t)
+ }
+
+ override def validate(tableEnv: TableEnvironment): LogicalNode = {
+ val node = super.validate(tableEnv).asInstanceOf[LogicalTableFunctionCall]
+ // check not Scala object
+ checkNotSingleton(tableFunction.getClass)
+ // check could be instantiated
+ checkForInstantiation(tableFunction.getClass)
+ // look for a signature that matches the input types
+ val signature = node.parameters.map(_.resultType)
+ val foundMethod = getEvalMethod(tableFunction, signature)
+ if (foundMethod.isEmpty) {
+ failValidation(
+ s"Given parameters of function '$functionName' do not match any signature. \n" +
+ s"Actual: ${signatureToString(signature)} \n" +
+ s"Expected: ${signaturesToString(tableFunction)}")
+ } else {
+ node.evalMethod = foundMethod.get
+ }
+ node
+ }
+
+ override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
+ val fieldIndexes = getFieldInfo(resultType)._2
+ val function = new FlinkTableFunctionImpl(resultType, fieldIndexes, fieldNames, evalMethod)
+ val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
+ val sqlFunction = TableSqlFunction(
+ tableFunction.toString,
+ tableFunction,
+ resultType,
+ typeFactory,
+ function)
+
+ val scan = LogicalTableFunctionScan.create(
+ relBuilder.peek().getCluster,
+ new util.ArrayList[RelNode](),
+ relBuilder.call(sqlFunction, parameters.map(_.toRexNode(relBuilder)).asJava),
+ function.getElementType(null),
+ function.getRowType(relBuilder.getTypeFactory, null),
+ null)
+
+ relBuilder.push(scan)
+ }
+}
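A note on the validation path above: checkNotSingleton rejects table functions defined as Scala objects. A Scala object compiles to a class carrying a static MODULE$ field, which is what the check detects; because the object is a single shared instance, concurrent invocations could race on any mutable state. An illustrative sketch of the rejected pattern (the object itself is hypothetical):

  import org.apache.flink.api.table.functions.TableFunction

  // Illustrative only: validation fails for this definition because the
  // compiled class exposes a static MODULE$ field (Scala singleton).
  object SingletonSplitUDTF extends TableFunction[String] {
    def eval(str: String): Unit = str.split("#").foreach(collect)
  }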
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala
new file mode 100644
index 0000000..9745be1
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.nodes
+
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rex.{RexCall, RexNode}
+import org.apache.calcite.sql.SemiJoinType
+import org.apache.flink.api.common.functions.FlatMapFunction
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.codegen.{CodeGenerator, GeneratedExpression, GeneratedFunction}
+import org.apache.flink.api.table.codegen.CodeGenUtils.primitiveDefaultValue
+import org.apache.flink.api.table.functions.utils.TableSqlFunction
+import org.apache.flink.api.table.runtime.FlatMapRunner
+import org.apache.flink.api.table.typeutils.TypeConverter._
+import org.apache.flink.api.table.{TableConfig, TableException}
+
+import scala.collection.JavaConverters._
+
+/**
+ * Common logic for RelNodes that cross/outer apply a user-defined table function.
+ */
+trait FlinkCorrelate {
+
+ private[flink] def functionBody(
+ generator: CodeGenerator,
+ udtfTypeInfo: TypeInformation[Any],
+ rowType: RelDataType,
+ rexCall: RexCall,
+ condition: Option[RexNode],
+ config: TableConfig,
+ joinType: SemiJoinType,
+ expectedType: Option[TypeInformation[Any]]): String = {
+
+ val returnType = determineReturnType(
+ rowType,
+ expectedType,
+ config.getNullCheck,
+ config.getEfficientTypeUsage)
+
+ val (input1AccessExprs, input2AccessExprs) = generator.generateCorrelateAccessExprs
+
+ val call = generator.generateExpression(rexCall)
+ var body =
+ s"""
+ |${call.code}
+ |java.util.Iterator iter = ${call.resultTerm}.getRowsIterator();
+ """.stripMargin
+
+ if (joinType == SemiJoinType.INNER) {
+ // cross apply
+ body +=
+ s"""
+ |if (!iter.hasNext()) {
+ | return;
+ |}
+ """.stripMargin
+ } else if (joinType == SemiJoinType.LEFT) {
+ // outer apply
+
+ // in case of outer apply, if the table function returns no rows,
+ // emit one row with all table function fields set to null
+ val input2NullExprs = input2AccessExprs.map { x =>
+ GeneratedExpression(
+ primitiveDefaultValue(x.resultType),
+ GeneratedExpression.ALWAYS_NULL,
+ "",
+ x.resultType)
+ }
+ val outerResultExpr = generator.generateResultExpression(
+ input1AccessExprs ++ input2NullExprs, returnType, rowType.getFieldNames.asScala)
+ body +=
+ s"""
+ |if (!iter.hasNext()) {
+ | ${outerResultExpr.code}
+ | ${generator.collectorTerm}.collect(${outerResultExpr.resultTerm});
+ | return;
+ |}
+ """.stripMargin
+ } else {
+ throw TableException(s"Unsupported SemiJoinType: $joinType for correlate join.")
+ }
+
+ val crossResultExpr = generator.generateResultExpression(
+ input1AccessExprs ++ input2AccessExprs,
+ returnType,
+ rowType.getFieldNames.asScala)
+
+ val projection = if (condition.isEmpty) {
+ s"""
+ |${crossResultExpr.code}
+ |${generator.collectorTerm}.collect(${crossResultExpr.resultTerm});
+ """.stripMargin
+ } else {
+ val filterGenerator = new CodeGenerator(config, false, udtfTypeInfo)
+ filterGenerator.input1Term = filterGenerator.input2Term
+ val filterCondition = filterGenerator.generateExpression(condition.get)
+ s"""
+ |${filterGenerator.reuseInputUnboxingCode()}
+ |${filterCondition.code}
+ |if (${filterCondition.resultTerm}) {
+ | ${crossResultExpr.code}
+ | ${generator.collectorTerm}.collect(${crossResultExpr.resultTerm});
+ |}
+ |""".stripMargin
+ }
+
+ val outputTypeClass = udtfTypeInfo.getTypeClass.getCanonicalName
+ body +=
+ s"""
+ |while (iter.hasNext()) {
+ | $outputTypeClass ${generator.input2Term} = ($outputTypeClass) iter.next();
+ | $projection
+ |}
+ """.stripMargin
+ body
+ }
+
+ private[flink] def correlateMapFunction(
+ genFunction: GeneratedFunction[FlatMapFunction[Any, Any]])
+ : FlatMapRunner[Any, Any] = {
+
+ new FlatMapRunner[Any, Any](
+ genFunction.name,
+ genFunction.code,
+ genFunction.returnType)
+ }
+
+ private[flink] def selectToString(rowType: RelDataType): String = {
+ rowType.getFieldNames.asScala.mkString(",")
+ }
+
+ private[flink] def correlateOpName(
+ rexCall: RexCall,
+ sqlFunction: TableSqlFunction,
+ rowType: RelDataType)
+ : String = {
+
+ s"correlate: ${correlateToString(rexCall, sqlFunction)}, select: ${selectToString(rowType)}"
+ }
+
+ private[flink] def correlateToString(rexCall: RexCall, sqlFunction: TableSqlFunction): String = {
+ val udtfName = sqlFunction.getName
+ val operands = rexCall.getOperands.asScala.map(_.toString).mkString(",")
+ s"table($udtfName($operands))"
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala
new file mode 100644
index 0000000..4aa7fea
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.nodes.dataset
+
+import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet}
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.logical.LogicalTableFunctionScan
+import org.apache.calcite.rel.metadata.RelMetadataQuery
+import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
+import org.apache.calcite.rex.{RexNode, RexCall}
+import org.apache.calcite.sql.SemiJoinType
+import org.apache.flink.api.common.functions.FlatMapFunction
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.DataSet
+import org.apache.flink.api.table.BatchTableEnvironment
+import org.apache.flink.api.table.codegen.CodeGenerator
+import org.apache.flink.api.table.functions.utils.TableSqlFunction
+import org.apache.flink.api.table.plan.nodes.FlinkCorrelate
+import org.apache.flink.api.table.typeutils.TypeConverter._
+
+/**
+ * Flink RelNode that executes a cross/outer apply of a user-defined table function (batch).
+ */
+class DataSetCorrelate(
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ inputNode: RelNode,
+ scan: LogicalTableFunctionScan,
+ condition: Option[RexNode],
+ relRowType: RelDataType,
+ joinRowType: RelDataType,
+ joinType: SemiJoinType,
+ ruleDescription: String)
+ extends SingleRel(cluster, traitSet, inputNode)
+ with FlinkCorrelate
+ with DataSetRel {
+
+ override def deriveRowType() = relRowType
+
+ override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
+ val rowCnt = metadata.getRowCount(getInput) * 1.5
+ planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * 0.5)
+ }
+
+ override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
+ new DataSetCorrelate(
+ cluster,
+ traitSet,
+ inputs.get(0),
+ scan,
+ condition,
+ relRowType,
+ joinRowType,
+ joinType,
+ ruleDescription)
+ }
+
+ override def toString: String = {
+ val rexCall = scan.getCall.asInstanceOf[RexCall]
+ val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+ correlateToString(rexCall, sqlFunction)
+ }
+
+ override def explainTerms(pw: RelWriter): RelWriter = {
+ val rexCall = scan.getCall.asInstanceOf[RexCall]
+ val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+ super.explainTerms(pw)
+ .item("invocation", scan.getCall)
+ .item("function", sqlFunction.getTableFunction.getClass.getCanonicalName)
+ .item("rowType", relRowType)
+ .item("joinType", joinType)
+ .itemIf("condition", condition.orNull, condition.isDefined)
+ }
+
+ override def translateToPlan(
+ tableEnv: BatchTableEnvironment,
+ expectedType: Option[TypeInformation[Any]])
+ : DataSet[Any] = {
+
+ val config = tableEnv.getConfig
+ val returnType = determineReturnType(
+ getRowType,
+ expectedType,
+ config.getNullCheck,
+ config.getEfficientTypeUsage)
+
+ // do not need to specify input type
+ val inputDS = inputNode.asInstanceOf[DataSetRel].translateToPlan(tableEnv)
+
+ val funcRel = scan.asInstanceOf[LogicalTableFunctionScan]
+ val rexCall = funcRel.getCall.asInstanceOf[RexCall]
+ val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+ val pojoFieldMapping = sqlFunction.getPojoFieldMapping
+ val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]]
+
+ val generator = new CodeGenerator(
+ config,
+ false,
+ inputDS.getType,
+ Some(udtfTypeInfo),
+ None,
+ Some(pojoFieldMapping))
+
+ val body = functionBody(
+ generator,
+ udtfTypeInfo,
+ getRowType,
+ rexCall,
+ condition,
+ config,
+ joinType,
+ expectedType)
+
+ val genFunction = generator.generateFunction(
+ ruleDescription,
+ classOf[FlatMapFunction[Any, Any]],
+ body,
+ returnType)
+
+ val mapFunc = correlateMapFunction(genFunction)
+
+ inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, relRowType))
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamCorrelate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamCorrelate.scala
new file mode 100644
index 0000000..b0bc48a
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamCorrelate.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.nodes.datastream
+
+import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.logical.LogicalTableFunctionScan
+import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
+import org.apache.calcite.rex.{RexCall, RexNode}
+import org.apache.calcite.sql.SemiJoinType
+import org.apache.flink.api.common.functions.FlatMapFunction
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.StreamTableEnvironment
+import org.apache.flink.api.table.codegen.CodeGenerator
+import org.apache.flink.api.table.functions.utils.TableSqlFunction
+import org.apache.flink.api.table.plan.nodes.FlinkCorrelate
+import org.apache.flink.api.table.typeutils.TypeConverter._
+import org.apache.flink.streaming.api.datastream.DataStream
+
+/**
+ * Flink RelNode that executes a cross/outer apply of a user-defined table function (streaming).
+ */
+class DataStreamCorrelate(
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ inputNode: RelNode,
+ scan: LogicalTableFunctionScan,
+ condition: Option[RexNode],
+ relRowType: RelDataType,
+ joinRowType: RelDataType,
+ joinType: SemiJoinType,
+ ruleDescription: String)
+ extends SingleRel(cluster, traitSet, inputNode)
+ with FlinkCorrelate
+ with DataStreamRel {
+ override def deriveRowType() = relRowType
+
+ override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
+ new DataStreamCorrelate(
+ cluster,
+ traitSet,
+ inputs.get(0),
+ scan,
+ condition,
+ relRowType,
+ joinRowType,
+ joinType,
+ ruleDescription)
+ }
+
+ override def toString: String = {
+ val rexCall = scan.getCall.asInstanceOf[RexCall]
+ val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+ correlateToString(rexCall, sqlFunction)
+ }
+
+ override def explainTerms(pw: RelWriter): RelWriter = {
+ val rexCall = scan.getCall.asInstanceOf[RexCall]
+ val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+ super.explainTerms(pw)
+ .item("invocation", scan.getCall)
+ .item("function", sqlFunction.getTableFunction.getClass.getCanonicalName)
+ .item("rowType", relRowType)
+ .item("joinType", joinType)
+ .itemIf("condition", condition.orNull, condition.isDefined)
+ }
+
+ override def translateToPlan(
+ tableEnv: StreamTableEnvironment,
+ expectedType: Option[TypeInformation[Any]])
+ : DataStream[Any] = {
+
+ val config = tableEnv.getConfig
+ val returnType = determineReturnType(
+ getRowType,
+ expectedType,
+ config.getNullCheck,
+ config.getEfficientTypeUsage)
+
+ // do not need to specify input type
+ val inputDS = inputNode.asInstanceOf[DataStreamRel].translateToPlan(tableEnv)
+
+ val funcRel = scan.asInstanceOf[LogicalTableFunctionScan]
+ val rexCall = funcRel.getCall.asInstanceOf[RexCall]
+ val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+ val pojoFieldMapping = sqlFunction.getPojoFieldMapping
+ val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]]
+
+ val generator = new CodeGenerator(
+ config,
+ false,
+ inputDS.getType,
+ Some(udtfTypeInfo),
+ None,
+ Some(pojoFieldMapping))
+
+ val body = functionBody(
+ generator,
+ udtfTypeInfo,
+ getRowType,
+ rexCall,
+ condition,
+ config,
+ joinType,
+ expectedType)
+
+ val genFunction = generator.generateFunction(
+ ruleDescription,
+ classOf[FlatMapFunction[Any, Any]],
+ body,
+ returnType)
+
+ val mapFunc = correlateMapFunction(genFunction)
+
+ inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, relRowType))
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
index 9e20df4..6847425 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
@@ -108,6 +108,7 @@ object FlinkRuleSets {
DataSetMinusRule.INSTANCE,
DataSetSortRule.INSTANCE,
DataSetValuesRule.INSTANCE,
+ DataSetCorrelateRule.INSTANCE,
BatchTableSourceScanRule.INSTANCE
)
@@ -151,6 +152,7 @@ object FlinkRuleSets {
DataStreamScanRule.INSTANCE,
DataStreamUnionRule.INSTANCE,
DataStreamValuesRule.INSTANCE,
+ DataStreamCorrelateRule.INSTANCE,
StreamTableSourceScanRule.INSTANCE
)
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala
new file mode 100644
index 0000000..e6cf0cf
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.rules.dataSet
+
+import org.apache.calcite.plan.volcano.RelSubset
+import org.apache.calcite.plan.{Convention, RelOptRule, RelOptRuleCall, RelTraitSet}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.calcite.rel.logical.{LogicalFilter, LogicalCorrelate, LogicalTableFunctionScan}
+import org.apache.calcite.rex.RexNode
+import org.apache.flink.api.table.plan.nodes.dataset.{DataSetConvention, DataSetCorrelate}
+
+/**
+ * Rule to convert a LogicalCorrelate into a DataSetCorrelate.
+ */
+class DataSetCorrelateRule
+ extends ConverterRule(
+ classOf[LogicalCorrelate],
+ Convention.NONE,
+ DataSetConvention.INSTANCE,
+ "DataSetCorrelateRule")
+ {
+
+ override def matches(call: RelOptRuleCall): Boolean = {
+ val join: LogicalCorrelate = call.rel(0).asInstanceOf[LogicalCorrelate]
+ val right = join.getRight.asInstanceOf[RelSubset].getOriginal
+
+ right match {
+ // right node is a table function
+ case scan: LogicalTableFunctionScan => true
+ // a filter is pushed above the table function
+ case filter: LogicalFilter =>
+ filter.getInput.asInstanceOf[RelSubset].getOriginal
+ .isInstanceOf[LogicalTableFunctionScan]
+ case _ => false
+ }
+ }
+
+ override def convert(rel: RelNode): RelNode = {
+ val join: LogicalCorrelate = rel.asInstanceOf[LogicalCorrelate]
+ val traitSet: RelTraitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE)
+ val convInput: RelNode = RelOptRule.convert(join.getInput(0), DataSetConvention.INSTANCE)
+ val right: RelNode = join.getInput(1)
+
+ def convertToCorrelate(relNode: RelNode, condition: Option[RexNode]): DataSetCorrelate = {
+ relNode match {
+ case rel: RelSubset =>
+ convertToCorrelate(rel.getRelList.get(0), condition)
+
+ case filter: LogicalFilter =>
+ convertToCorrelate(
+ filter.getInput.asInstanceOf[RelSubset].getOriginal,
+ Some(filter.getCondition))
+
+ case scan: LogicalTableFunctionScan =>
+ new DataSetCorrelate(
+ rel.getCluster,
+ traitSet,
+ convInput,
+ scan,
+ condition,
+ rel.getRowType,
+ join.getRowType,
+ join.getJoinType,
+ description)
+ }
+ }
+ convertToCorrelate(right, None)
+ }
+ }
+
+object DataSetCorrelateRule {
+ val INSTANCE: RelOptRule = new DataSetCorrelateRule
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamCorrelateRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamCorrelateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamCorrelateRule.scala
new file mode 100644
index 0000000..bb52fd7
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamCorrelateRule.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.rules.datastream
+
+import org.apache.calcite.plan.volcano.RelSubset
+import org.apache.calcite.plan.{Convention, RelOptRule, RelOptRuleCall, RelTraitSet}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.calcite.rel.logical.{LogicalFilter, LogicalCorrelate, LogicalTableFunctionScan}
+import org.apache.calcite.rex.RexNode
+import org.apache.flink.api.table.plan.nodes.datastream.{DataStreamCorrelate, DataStreamConvention}
+
+/**
+ * Rule to convert a LogicalCorrelate into a DataStreamCorrelate.
+ */
+class DataStreamCorrelateRule
+ extends ConverterRule(
+ classOf[LogicalCorrelate],
+ Convention.NONE,
+ DataStreamConvention.INSTANCE,
+ "DataStreamCorrelateRule")
+{
+
+ override def matches(call: RelOptRuleCall): Boolean = {
+ val join: LogicalCorrelate = call.rel(0).asInstanceOf[LogicalCorrelate]
+ val right = join.getRight.asInstanceOf[RelSubset].getOriginal
+
+ right match {
+ // right node is a table function
+ case scan: LogicalTableFunctionScan => true
+ // a filter is pushed above the table function
+ case filter: LogicalFilter =>
+ filter.getInput.asInstanceOf[RelSubset].getOriginal
+ .isInstanceOf[LogicalTableFunctionScan]
+ case _ => false
+ }
+ }
+
+ override def convert(rel: RelNode): RelNode = {
+ val join: LogicalCorrelate = rel.asInstanceOf[LogicalCorrelate]
+ val traitSet: RelTraitSet = rel.getTraitSet.replace(DataStreamConvention.INSTANCE)
+ val convInput: RelNode = RelOptRule.convert(join.getInput(0), DataStreamConvention.INSTANCE)
+ val right: RelNode = join.getInput(1)
+
+ def convertToCorrelate(relNode: RelNode, condition: Option[RexNode]): DataStreamCorrelate = {
+ relNode match {
+ case rel: RelSubset =>
+ convertToCorrelate(rel.getRelList.get(0), condition)
+
+ case filter: LogicalFilter =>
+ convertToCorrelate(filter.getInput.asInstanceOf[RelSubset].getOriginal,
+ Some(filter.getCondition))
+
+ case scan: LogicalTableFunctionScan =>
+ new DataStreamCorrelate(
+ rel.getCluster,
+ traitSet,
+ convInput,
+ scan,
+ condition,
+ rel.getRowType,
+ join.getRowType,
+ join.getJoinType,
+ description)
+ }
+ }
+ convertToCorrelate(right, None)
+ }
+
+}
+
+object DataStreamCorrelateRule {
+ val INSTANCE: RelOptRule = new DataStreamCorrelateRule
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/FlinkTableFunctionImpl.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/FlinkTableFunctionImpl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/FlinkTableFunctionImpl.scala
new file mode 100644
index 0000000..540a5c8
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/FlinkTableFunctionImpl.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.schema
+
+import java.lang.reflect.{Method, Type}
+import java.util
+
+import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory}
+import org.apache.calcite.schema.TableFunction
+import org.apache.calcite.schema.impl.ReflectiveFunctionBase
+import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
+import org.apache.flink.api.common.typeutils.CompositeType
+import org.apache.flink.api.table.{FlinkTypeFactory, TableException}
+
+/**
+ * This is heavily inspired by Calcite's [[org.apache.calcite.schema.impl.TableFunctionImpl]].
+ * We need it in order to create a [[org.apache.flink.api.table.functions.utils.TableSqlFunction]].
+ * The main difference is that we override [[getRowType()]] and [[getElementType()]].
+ */
+class FlinkTableFunctionImpl[T](
+ val typeInfo: TypeInformation[T],
+ val fieldIndexes: Array[Int],
+ val fieldNames: Array[String],
+ val evalMethod: Method)
+ extends ReflectiveFunctionBase(evalMethod)
+ with TableFunction {
+
+ if (fieldIndexes.length != fieldNames.length) {
+ throw new TableException(
+ "Number of field indexes and field names must be equal.")
+ }
+
+ // check uniqueness of field names
+ if (fieldNames.length != fieldNames.toSet.size) {
+ throw new TableException(
+ "Table field names must be unique.")
+ }
+
+ val fieldTypes: Array[TypeInformation[_]] =
+ typeInfo match {
+ case cType: CompositeType[T] =>
+ if (fieldNames.length != cType.getArity) {
+ throw new TableException(
+ s"Arity of type (" + cType.getFieldNames.deep + ") " +
+ "not equal to number of field names " + fieldNames.deep + ".")
+ }
+ fieldIndexes.map(cType.getTypeAt(_).asInstanceOf[TypeInformation[_]])
+ case aType: AtomicType[T] =>
+ if (fieldIndexes.length != 1 || fieldIndexes(0) != 0) {
+ throw new TableException(
+ "Non-composite input type may have only a single field and its index must be 0.")
+ }
+ Array(aType)
+ }
+
+ override def getElementType(arguments: util.List[AnyRef]): Type = classOf[Array[Object]]
+
+ override def getRowType(typeFactory: RelDataTypeFactory,
+ arguments: util.List[AnyRef]): RelDataType = {
+ val flinkTypeFactory = typeFactory.asInstanceOf[FlinkTypeFactory]
+ val builder = flinkTypeFactory.builder
+ fieldNames
+ .zip(fieldTypes)
+ .foreach { f =>
+ builder.add(f._1, flinkTypeFactory.createTypeFromTypeInfo(f._2)).nullable(true)
+ }
+ builder.build
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
index c45e871..a75f2fc 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
@@ -20,7 +20,8 @@ package org.apache.flink.api.table
import org.apache.calcite.rel.RelNode
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.operators.join.JoinType
-import org.apache.flink.api.table.expressions.{Asc, Expression, ExpressionParser, Ordering}
+import org.apache.flink.api.table.plan.logical.Minus
+import org.apache.flink.api.table.expressions.{Alias, Asc, Call, Expression, ExpressionParser, Ordering, TableFunctionCall}
import org.apache.flink.api.table.plan.ProjectionTranslator._
import org.apache.flink.api.table.plan.logical._
import org.apache.flink.api.table.sinks.TableSink
@@ -400,7 +401,8 @@ class Table(
}
new Table(
tableEnv,
- Join(this.logicalPlan, right.logicalPlan, joinType, joinPredicate).validate(tableEnv))
+ Join(this.logicalPlan, right.logicalPlan, joinType, joinPredicate, correlated = false)
+ .validate(tableEnv))
}
/**
@@ -609,6 +611,126 @@ class Table(
}
/**
+ * Cross Apply returns rows from the outer table (the table on the left of the Apply
+ * operator) for which the table-valued function (on the right side of the operator)
+ * produces matching values.
+ *
+ * Cross Apply is equivalent to an Inner Join, but it works with a table-valued function.
+ *
+ * Example:
+ *
+ * {{{
+ * class MySplitUDTF extends TableFunction[String] {
+ * def eval(str: String): Unit = {
+ * str.split("#").foreach(collect)
+ * }
+ * }
+ *
+ * val split = new MySplitUDTF()
+ * table.crossApply(split('c) as ('s)).select('a,'b,'c,'s)
+ * }}}
+ */
+ def crossApply(udtf: Expression): Table = {
+ applyInternal(udtf, JoinType.INNER)
+ }
+
+ /**
+ * Cross Apply returns rows from the outer table (the table on the left of the Apply
+ * operator) for which the table-valued function (on the right side of the operator)
+ * produces matching values.
+ *
+ * Cross Apply is equivalent to an Inner Join, but it works with a table-valued function.
+ *
+ * Example:
+ *
+ * {{{
+ * class MySplitUDTF extends TableFunction[String] {
+ * def eval(str: String): Unit = {
+ * str.split("#").foreach(collect)
+ * }
+ * }
+ *
+ * val split = new MySplitUDTF()
+ * table.crossApply("split(c) as (s)").select("a, b, c, s")
+ * }}}
+ */
+ def crossApply(udtf: String): Table = {
+ applyInternal(udtf, JoinType.INNER)
+ }
+
+ /**
+ * Outer Apply returns all rows from the outer table (the table on the left of the Apply
+ * operator); for rows without matching values from the table-valued function (on the
+ * right side of the operator), NULL values are filled in.
+ *
+ * Outer Apply is equivalent to a Left Outer Join, but it works with a table-valued function.
+ *
+ * Example:
+ *
+ * {{{
+ * class MySplitUDTF extends TableFunction[String] {
+ * def eval(str: String): Unit = {
+ * str.split("#").foreach(collect)
+ * }
+ * }
+ *
+ * val split = new MySplitUDTF()
+ * table.outerApply(split('c) as ('s)).select('a,'b,'c,'s)
+ * }}}
+ */
+ def outerApply(udtf: Expression): Table = {
+ applyInternal(udtf, JoinType.LEFT_OUTER)
+ }
+
+ /**
+ * Outer Apply returns all rows from the outer table (the table on the left of the Apply
+ * operator); for rows without matching values from the table-valued function (on the
+ * right side of the operator), NULL values are filled in.
+ *
+ * Outer Apply is equivalent to a Left Outer Join, but it works with a table-valued function.
+ *
+ * Example:
+ *
+ * {{{
+ * val split = new MySplitUDTF()
+ * table.outerApply("split(c) as (s)").select("a, b, c, s")
+ * }}}
+ */
+ def outerApply(udtf: String): Table = {
+ applyInternal(udtf, JoinType.LEFT_OUTER)
+ }
+
+ private def applyInternal(udtfString: String, joinType: JoinType): Table = {
+ val udtf = ExpressionParser.parseExpression(udtfString)
+ applyInternal(udtf, joinType)
+ }
+
+ private def applyInternal(udtf: Expression, joinType: JoinType): Table = {
+ var alias: Option[Seq[String]] = None
+
+ // unwrap the Expression until a TableFunctionCall is reached
+ def unwrap(expr: Expression): TableFunctionCall = expr match {
+ case Alias(child, name, extraNames) =>
+ alias = Some(Seq(name) ++ extraNames)
+ unwrap(child)
+ case Call(name, args) =>
+ val function = tableEnv.getFunctionCatalog.lookupFunction(name, args)
+ unwrap(function)
+ case c: TableFunctionCall => c
+ case _ => throw new TableException("Cross/Outer Apply only accepts a TableFunction")
+ }
+
+ val call = unwrap(udtf)
+ .as(alias)
+ .toLogicalTableFunctionCall(this.logicalPlan)
+ .validate(tableEnv)
+
+ new Table(
+ tableEnv,
+ Join(this.logicalPlan, call, joinType, None, correlated = true).validate(tableEnv))
+ }
+
+ /**
* Writes the [[Table]] to a [[TableSink]]. A [[TableSink]] defines an external storage location.
*
* A batch [[Table]] can only be written to a
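Putting the new operators together with the MySplitUDTF from the Scaladoc examples above, a short usage sketch (table and field names are illustrative):

  // Illustrative only, assuming a table with fields 'a, 'b, 'c and the
  // MySplitUDTF shown in the Scaladoc examples above.
  val split = new MySplitUDTF()

  // inner-join semantics: rows whose 'c produces no function output are dropped
  val crossed = table.crossApply(split('c) as ('s)).select('a, 'b, 'c, 's)

  // left-outer semantics: rows whose 'c produces no output keep NULL for 's
  val outer = table.outerApply(split('c) as ('s)).select('a, 'b, 'c, 's)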
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
index 679733c..4029a7d 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
@@ -23,8 +23,8 @@ import org.apache.calcite.sql.util.{ChainedSqlOperatorTable, ListSqlOperatorTabl
import org.apache.calcite.sql.{SqlFunction, SqlOperator, SqlOperatorTable}
import org.apache.flink.api.table.ValidationException
import org.apache.flink.api.table.expressions._
-import org.apache.flink.api.table.functions.ScalarFunction
-import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils
+import org.apache.flink.api.table.functions.{ScalarFunction, TableFunction}
+import org.apache.flink.api.table.functions.utils.{TableSqlFunction, UserDefinedFunctionUtils}
import scala.collection.JavaConversions._
import scala.collection.mutable
@@ -47,6 +47,20 @@ class FunctionCatalog {
sqlFunctions += sqlFunction
}
+ /** Registers multiple SQL functions at once. All functions must share the same name. */
+ def registerSqlFunctions(functions: Seq[SqlFunction]): Unit = {
+ if (functions.nonEmpty) {
+ val name = functions.head.getName
+ // check that all functions share the same name
+ if (functions.forall(_.getName == name)) {
+ sqlFunctions --= sqlFunctions.filter(_.getName == name)
+ sqlFunctions ++= functions
+ } else {
+ throw ValidationException("The SQL functions to be registered have different names.")
+ }
+ }
+ }
+
def getSqlOperatorTable: SqlOperatorTable =
ChainedSqlOperatorTable.of(
new BasicOperatorTable(),
@@ -59,14 +73,9 @@ class FunctionCatalog {
def lookupFunction(name: String, children: Seq[Expression]): Expression = {
val funcClass = functionBuilders
.getOrElse(name.toLowerCase, throw ValidationException(s"Undefined function: $name"))
- withChildren(funcClass, children)
- }
- /**
- * Instantiate a function using the provided `children`.
- */
- private def withChildren(func: Class[_], children: Seq[Expression]): Expression = {
- func match {
+ // Instantiate a function using the provided `children`
+ funcClass match {
// user-defined scalar function call
case sf if classOf[ScalarFunction].isAssignableFrom(sf) =>
@@ -75,10 +84,20 @@ class FunctionCatalog {
case Failure(e) => throw ValidationException(e.getMessage)
}
+ // user-defined table function call
+ case tf if classOf[TableFunction[_]].isAssignableFrom(tf) =>
+ val tableSqlFunction = sqlFunctions
+ .find(f => f.getName.equalsIgnoreCase(name) && f.isInstanceOf[TableSqlFunction])
+ .getOrElse(throw ValidationException(s"Unregistered table sql function: $name"))
+ .asInstanceOf[TableSqlFunction]
+ val typeInfo = tableSqlFunction.getRowTypeInfo
+ val function = tableSqlFunction.getTableFunction
+ TableFunctionCall(name, function, children, typeInfo)
+
// general expression call
case expression if classOf[Expression].isAssignableFrom(expression) =>
// try to find a constructor accepts `Seq[Expression]`
- Try(func.getDeclaredConstructor(classOf[Seq[_]])) match {
+ Try(funcClass.getDeclaredConstructor(classOf[Seq[_]])) match {
case Success(seqCtor) =>
Try(seqCtor.newInstance(children).asInstanceOf[Expression]) match {
case Success(expr) => expr
@@ -87,14 +106,14 @@ class FunctionCatalog {
case Failure(e) =>
val childrenClass = Seq.fill(children.length)(classOf[Expression])
// try to find a constructor matching the exact number of children
- Try(func.getDeclaredConstructor(childrenClass: _*)) match {
+ Try(funcClass.getDeclaredConstructor(childrenClass: _*)) match {
case Success(ctor) =>
Try(ctor.newInstance(children: _*).asInstanceOf[Expression]) match {
case Success(expr) => expr
case Failure(exception) => throw ValidationException(exception.getMessage)
}
case Failure(exception) =>
- throw ValidationException(s"Invalid number of arguments for function $func")
+ throw ValidationException(s"Invalid number of arguments for function $funcClass")
}
}
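
The registerSqlFunctions hunk above replaces whatever was previously registered under the shared name instead of accumulating duplicates. A standalone model of that semantics (a sketch only, independent of the Calcite SqlFunction type; names are illustrative):

    import scala.collection.mutable

    case class Fn(name: String, arity: Int)
    val registered = mutable.ListBuffer(Fn("split", 1))

    def registerAll(fns: Seq[Fn]): Unit = {
      if (fns.nonEmpty) {
        val name = fns.head.name
        // all overloads passed in one call must share a single name
        require(fns.forall(_.name == name),
          "The SQL functions to be registered have different names.")
        // drop previously registered overloads of that name, then add the new ones
        registered --= registered.filter(_.name == name)
        registered ++= fns
      }
    }

    registerAll(Seq(Fn("split", 1), Fn("split", 2)))
    // registered now holds exactly the two new "split" overloads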
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionITCase.scala
new file mode 100644
index 0000000..7e0d0ff
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionITCase.scala
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.scala.batch
+
+import org.apache.flink.api.scala.batch.utils.TableProgramsTestBase
+import org.apache.flink.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.table.expressions.utils._
+import org.apache.flink.api.table.{Row, Table, TableEnvironment}
+import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
+import org.apache.flink.test.util.TestBaseUtils
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+@RunWith(classOf[Parameterized])
+class UserDefinedTableFunctionITCase(
+ mode: TestExecutionMode,
+ configMode: TableConfigMode)
+ extends TableProgramsTestBase(mode, configMode) {
+
+ @Test
+ def testSQLCrossApply(): Unit = {
+ val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+ val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+ tableEnv.registerTable("MyTable", in)
+ tableEnv.registerFunction("split", new TableFunc1)
+
+ val sqlQuery = "SELECT MyTable.c, t.s FROM MyTable, LATERAL TABLE(split(c)) AS t(s)"
+
+ val result = tableEnv.sql(sqlQuery).toDataSet[Row]
+ val results = result.collect()
+ val expected: String = "Jack#22,Jack\n" + "Jack#22,22\n" + "John#19,John\n" + "John#19,19\n" +
+ "Anna#44,Anna\n" + "Anna#44,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testSQLOuterApply(): Unit = {
+ val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+ val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+ tableEnv.registerTable("MyTable", in)
+ tableEnv.registerFunction("split", new TableFunc2)
+
+ val sqlQuery = "SELECT MyTable.c, t.a, t.b FROM MyTable LEFT JOIN LATERAL TABLE(split(c)) " +
+ "AS t(a,b) ON TRUE"
+
+ val result = tableEnv.sql(sqlQuery).toDataSet[Row]
+ val results = result.collect()
+ val expected: String = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" +
+ "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n" + "nosharp,null,null"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testTableAPICrossApply(): Unit = {
+ val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+ val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+
+ val func1 = new TableFunc1
+ val result = in.crossApply(func1('c) as ('s)).select('c, 's).toDataSet[Row]
+ val results = result.collect()
+ val expected: String = "Jack#22,Jack\n" + "Jack#22,22\n" + "John#19,John\n" + "John#19,19\n" +
+ "Anna#44,Anna\n" + "Anna#44,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+
+ // with overloading
+ val result2 = in.crossApply(func1('c, "$") as ('s)).select('c, 's).toDataSet[Row]
+ val results2 = result2.collect()
+ val expected2: String = "Jack#22,$Jack\n" + "Jack#22,$22\n" + "John#19,$John\n" +
+ "John#19,$19\n" + "Anna#44,$Anna\n" + "Anna#44,$44\n"
+ TestBaseUtils.compareResultAsText(results2.asJava, expected2)
+ }
+
+ @Test
+ def testTableAPIOuterApply(): Unit = {
+ val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+ val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+ val func2 = new TableFunc2
+ val result = in.outerApply(func2('c) as ('s, 'l)).select('c, 's, 'l).toDataSet[Row]
+ val results = result.collect()
+ val expected: String = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" +
+ "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n" + "nosharp,null,null"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testCustomReturnType(): Unit = {
+ val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+ val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+ val func2 = new TableFunc2
+
+ val result = in
+ .crossApply(func2('c) as ('name, 'len))
+ .select('c, 'name, 'len)
+ .toDataSet[Row]
+
+ val results = result.collect()
+ val expected: String = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" +
+ "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testHierarchyType(): Unit = {
+ val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+ val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+
+ val hierarchy = new HierarchyTableFunction
+ val result = in
+ .crossApply(hierarchy('c) as ('name, 'adult, 'len))
+ .select('c, 'name, 'adult, 'len)
+ .toDataSet[Row]
+
+ val results = result.collect()
+ val expected: String = "Jack#22,Jack,true,22\n" + "John#19,John,false,19\n" +
+ "Anna#44,Anna,true,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testPojoType(): Unit = {
+ val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+ val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+
+ val pojo = new PojoTableFunc()
+ val result = in
+ .crossApply(pojo('c))
+ .select('c, 'name, 'age)
+ .toDataSet[Row]
+
+ val results = result.collect()
+ val expected: String = "Jack#22,Jack,22\n" + "John#19,John,19\n" + "Anna#44,Anna,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testTableAPIWithFilter(): Unit = {
+ val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+ val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+ val func0 = new TableFunc0
+
+ val result = in
+ .crossApply(func0('c) as ('name, 'age))
+ .select('c, 'name, 'age)
+ .filter('age > 20)
+ .toDataSet[Row]
+
+ val results = result.collect()
+ val expected: String = "Jack#22,Jack,22\n" + "Anna#44,Anna,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testUDTFWithScalarFunction(): Unit = {
+ val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+ val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+ val func1 = new TableFunc1
+
+ val result = in
+ .crossApply(func1('c.substring(2)) as 's)
+ .select('c, 's)
+ .toDataSet[Row]
+
+ val results = result.collect()
+ val expected: String = "Jack#22,ack\n" + "Jack#22,22\n" + "John#19,ohn\n" + "John#19,19\n" +
+ "Anna#44,nna\n" + "Anna#44,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ private def getSmall3TupleDataSet(env: ExecutionEnvironment): DataSet[(Int, Long, String)] = {
+ val data = new mutable.MutableList[(Int, Long, String)]
+ data.+=((1, 1L, "Jack#22"))
+ data.+=((2, 2L, "John#19"))
+ data.+=((3, 2L, "Anna#44"))
+ data.+=((4, 3L, "nosharp"))
+ env.fromCollection(data)
+ }
+}
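
The helper UDTFs used above (TableFunc0, TableFunc1, TableFunc2, HierarchyTableFunction, PojoTableFunc) are added elsewhere in this patch. Judging from the expected results in these tests, the two most-used helpers behave roughly like the following sketch (the real TableFunc2 may instead emit Row with an overridden result type, per the "custom return type" tests):

    import org.apache.flink.api.table.functions.TableFunction

    class TableFunc1 extends TableFunction[String] {
      // one output row per '#'-separated token; nothing for "nosharp"-style input
      def eval(str: String): Unit =
        if (str.contains("#")) str.split("#").foreach(collect)
      // overloaded variant used by the "$" tests: prefixes every token
      def eval(str: String, prefix: String): Unit =
        if (str.contains("#")) str.split("#").foreach(s => collect(prefix + s))
    }

    class TableFunc2 extends TableFunction[(String, Int)] {
      // emits each token with its length, e.g. ("Jack", 4) and ("22", 2)
      def eval(str: String): Unit =
        if (str.contains("#")) str.split("#").foreach(s => collect((s, s.length)))
    }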
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionTest.scala
new file mode 100644
index 0000000..7e236d1
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionTest.scala
@@ -0,0 +1,320 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.scala.batch
+
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment => ScalaExecutionEnv, _}
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.{DataSet => JDataSet, ExecutionEnvironment => JavaExecutionEnv}
+import org.apache.flink.api.table.expressions.utils.{HierarchyTableFunction, PojoTableFunc, TableFunc1, TableFunc2}
+import org.apache.flink.api.table.typeutils.RowTypeInfo
+import org.apache.flink.api.table.utils.TableTestBase
+import org.apache.flink.api.table.utils.TableTestUtil._
+import org.apache.flink.api.table.{Row, TableEnvironment, Types}
+import org.junit.Test
+import org.mockito.Mockito._
+
+
+class UserDefinedTableFunctionTest extends TableTestBase {
+
+ @Test
+ def testTableAPI(): Unit = {
+ // mock
+ val ds = mock(classOf[DataSet[Row]])
+ val jDs = mock(classOf[JDataSet[Row]])
+ val typeInfo: TypeInformation[Row] = new RowTypeInfo(Seq(Types.INT, Types.LONG, Types.STRING))
+ when(ds.javaSet).thenReturn(jDs)
+ when(jDs.getType).thenReturn(typeInfo)
+
+ // Scala environment
+ val env = mock(classOf[ScalaExecutionEnv])
+ val tableEnv = TableEnvironment.getTableEnvironment(env)
+ val in1 = ds.toTable(tableEnv).as('a, 'b, 'c)
+
+ // Java environment
+ val javaEnv = mock(classOf[JavaExecutionEnv])
+ val javaTableEnv = TableEnvironment.getTableEnvironment(javaEnv)
+ val in2 = javaTableEnv.fromDataSet(jDs).as("a, b, c")
+ javaTableEnv.registerTable("MyTable", in2)
+
+ // test cross apply
+ val func1 = new TableFunc1
+ javaTableEnv.registerFunction("func1", func1)
+ var scalaTable = in1.crossApply(func1('c) as ('s)).select('c, 's)
+ var javaTable = in2.crossApply("func1(c) as (s)").select("c, s")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test outer apply
+ scalaTable = in1.outerApply(func1('c) as ('s)).select('c, 's)
+ javaTable = in2.outerApply("func1(c) as (s)").select("c, s")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test overloading
+ scalaTable = in1.crossApply(func1('c, "$") as ('s)).select('c, 's)
+ javaTable = in2.crossApply("func1(c, '$') as (s)").select("c, s")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test custom result type
+ val func2 = new TableFunc2
+ javaTableEnv.registerFunction("func2", func2)
+ scalaTable = in1.crossApply(func2('c) as ('name, 'len)).select('c, 'name, 'len)
+ javaTable = in2.crossApply("func2(c) as (name, len)").select("c, name, len")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test hierarchy generic type
+ val hierarchy = new HierarchyTableFunction
+ javaTableEnv.registerFunction("hierarchy", hierarchy)
+ scalaTable = in1.crossApply(hierarchy('c) as ('name, 'adult, 'len))
+ .select('c, 'name, 'len, 'adult)
+ javaTable = in2.crossApply("hierarchy(c) as (name, adult, len)")
+ .select("c, name, len, adult")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test pojo type
+ val pojo = new PojoTableFunc
+ javaTableEnv.registerFunction("pojo", pojo)
+ scalaTable = in1.crossApply(pojo('c))
+ .select('c, 'name, 'age)
+ javaTable = in2.crossApply("pojo(c)")
+ .select("c, name, age")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test with filter
+ scalaTable = in1.crossApply(func2('c) as ('name, 'len))
+ .select('c, 'name, 'len).filter('len > 2)
+ javaTable = in2.crossApply("func2(c) as (name, len)")
+ .select("c, name, len").filter("len > 2")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test with scalar function
+ scalaTable = in1.crossApply(func1('c.substring(2)) as ('s))
+ .select('a, 'c, 's)
+ javaTable = in2.crossApply("func1(substring(c, 2)) as (s)")
+ .select("a, c, s")
+ verifyTableEquals(scalaTable, javaTable)
+ }
+
+ @Test
+ def testSQLWithCrossApply(): Unit = {
+ val util = batchTestUtil()
+ val func1 = new TableFunc1
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func1", func1)
+
+ val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c)) AS T(s)"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "func1($cor0.c)"),
+ term("function", func1.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS s")
+ )
+
+ util.verifySql(sqlQuery, expected)
+
+ // test overloading
+
+ val sqlQuery2 = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c, '$')) AS T(s)"
+
+ val expected2 = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "func1($cor0.c, '$')"),
+ term("function", func1.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS s")
+ )
+
+ util.verifySql(sqlQuery2, expected2)
+ }
+
+ @Test
+ def testSQLWithOuterApply(): Unit = {
+ val util = batchTestUtil()
+ val func1 = new TableFunc1
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func1", func1)
+
+ val sqlQuery = "SELECT c, s FROM MyTable LEFT JOIN LATERAL TABLE(func1(c)) AS T(s) ON TRUE"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "func1($cor0.c)"),
+ term("function", func1.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+ term("joinType", "LEFT")
+ ),
+ term("select", "c", "f0 AS s")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testSQLWithCustomType(): Unit = {
+ val util = batchTestUtil()
+ val func2 = new TableFunc2
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func2", func2)
+
+ val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len)"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "func2($cor0.c)"),
+ term("function", func2.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
+ "VARCHAR(2147483647) f0, INTEGER f1)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS name", "f1 AS len")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testSQLWithHierarchyType(): Unit = {
+ val util = batchTestUtil()
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ val function = new HierarchyTableFunction
+ util.addFunction("hierarchy", function)
+
+ val sqlQuery = "SELECT c, T.* FROM MyTable, LATERAL TABLE(hierarchy(c)) AS T(name, adult, len)"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "hierarchy($cor0.c)"),
+ term("function", function.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," +
+ " VARCHAR(2147483647) f0, BOOLEAN f1, INTEGER f2)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS name", "f1 AS adult", "f2 AS len")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testSQLWithPojoType(): Unit = {
+ val util = batchTestUtil()
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ val function = new PojoTableFunc
+ util.addFunction("pojo", function)
+
+ val sqlQuery = "SELECT c, name, age FROM MyTable, LATERAL TABLE(pojo(c))"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "pojo($cor0.c)"),
+ term("function", function.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," +
+ " INTEGER age, VARCHAR(2147483647) name)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "name", "age")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testSQLWithFilter(): Unit = {
+ val util = batchTestUtil()
+ val func2 = new TableFunc2
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func2", func2)
+
+ val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len) " +
+ "WHERE len > 2"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "func2($cor0.c)"),
+ term("function", func2.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
+ "VARCHAR(2147483647) f0, INTEGER f1)"),
+ term("joinType", "INNER"),
+ term("condition", ">($1, 2)")
+ ),
+ term("select", "c", "f0 AS name", "f1 AS len")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testSQLWithScalarFunction(): Unit = {
+ val util = batchTestUtil()
+ val func1 = new TableFunc1
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func1", func1)
+
+ val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(SUBSTRING(c, 2))) AS T(s)"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "func1(SUBSTRING($cor0.c, 2))"),
+ term("function", func1.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS s")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+}