Posted to commits@flink.apache.org by tw...@apache.org on 2016/12/07 15:57:21 UTC

[2/5] flink git commit: [FLINK-4469] [table] Add support for user defined table function in Table API & SQL

http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala
index e7416f7..932baeb 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -19,14 +19,18 @@
 
 package org.apache.flink.api.table.functions.utils
 
+import java.lang.reflect.{Method, Modifier}
 import java.sql.{Date, Time, Timestamp}
 
 import com.google.common.primitives.Primitives
+import org.apache.calcite.sql.SqlFunction
 import org.apache.flink.api.common.functions.InvalidTypesException
-import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
+import org.apache.flink.api.common.typeutils.CompositeType
 import org.apache.flink.api.java.typeutils.TypeExtractor
-import org.apache.flink.api.table.ValidationException
-import org.apache.flink.api.table.functions.{ScalarFunction, UserDefinedFunction}
+import org.apache.flink.api.table.{FlinkTypeFactory, TableException, ValidationException}
+import org.apache.flink.api.table.functions.{ScalarFunction, TableFunction, UserDefinedFunction}
+import org.apache.flink.api.table.plan.schema.FlinkTableFunctionImpl
 import org.apache.flink.util.InstantiationUtil
 
 object UserDefinedFunctionUtils {
@@ -62,101 +66,167 @@ object UserDefinedFunctionUtils {
       .getOrElse(throw ValidationException("Function class needs a default constructor."))
   }
 
+  /**
+    * Checks whether the given class is a Scala object. Implementing a [[TableFunction]] as a
+    * Scala object is forbidden because the shared singleton instance poses concurrency risks.
+    */
+  def checkNotSingleton(clazz: Class[_]): Unit = {
+    // TODO checking for MODULE$ is not a robust way to detect a singleton; improve it later
+    if (clazz.getFields.map(_.getName) contains "MODULE$") {
+      throw new ValidationException(
+        s"TableFunction implemented by class ${clazz.getCanonicalName} " +
+          s"is a Scala object, it is forbidden since concurrent risks.")
+    }
+  }
+
   // ----------------------------------------------------------------------------------------------
-  // Utilities for ScalarFunction
+  // Utilities for eval methods
   // ----------------------------------------------------------------------------------------------
 
   /**
-    * Prints one signature consisting of classes.
+    * Returns the signature that matches the given signature of [[TypeInformation]].
+    * Elements of the signature can be null (acting as wildcards).
     */
-  def signatureToString(signature: Array[Class[_]]): String =
-    "(" + signature.map { clazz =>
-      if (clazz == null) {
-        "null"
-      } else {
-        clazz.getCanonicalName
-      }
-    }.mkString(", ") + ")"
+  def getSignature(
+      function: UserDefinedFunction,
+      signature: Seq[TypeInformation[_]])
+    : Option[Array[Class[_]]] = {
+    // We compare the raw Java classes not the TypeInformation.
+    // TypeInformation does not matter during runtime (e.g. within a MapFunction).
+    val actualSignature = typeInfoToClass(signature)
+    val signatures = getSignatures(function)
+
+    signatures
+      // go over all signatures and find one matching actual signature
+      .find { curSig =>
+        // match parameters of signature to actual parameters
+        actualSignature.length == curSig.length &&
+          curSig.zipWithIndex.forall { case (clazz, i) =>
+            parameterTypeEquals(actualSignature(i), clazz)
+          }
+      }
+  }
 
   /**
-    * Prints one signature consisting of TypeInformation.
+    * Returns the eval method matching the given signature of [[TypeInformation]].
     */
-  def signatureToString(signature: Seq[TypeInformation[_]]): String = {
-    signatureToString(typeInfoToClass(signature))
+  def getEvalMethod(
+      function: UserDefinedFunction,
+      signature: Seq[TypeInformation[_]])
+    : Option[Method] = {
+    // We compare the raw Java classes not the TypeInformation.
+    // TypeInformation does not matter during runtime (e.g. within a MapFunction).
+    val actualSignature = typeInfoToClass(signature)
+    val evalMethods = checkAndExtractEvalMethods(function)
+
+    evalMethods
+      // go over all eval methods and find one matching
+      .find { cur =>
+        val signatures = cur.getParameterTypes
+        // match parameters of signature to actual parameters
+        actualSignature.length == signatures.length &&
+          signatures.zipWithIndex.forall { case (clazz, i) =>
+            parameterTypeEquals(actualSignature(i), clazz)
+          }
+      }
   }
 
   /**
-    * Extracts type classes of [[TypeInformation]] in a null-aware way.
+    * Extracts "eval" methods and throws a [[ValidationException]] if no implementation
+    * can be found.
     */
-  def typeInfoToClass(typeInfos: Seq[TypeInformation[_]]): Array[Class[_]] =
-    typeInfos.map { typeInfo =>
-      if (typeInfo == null) {
-        null
-      } else {
-        typeInfo.getTypeClass
+  def checkAndExtractEvalMethods(function: UserDefinedFunction): Array[Method] = {
+    val methods = function
+      .getClass
+      .getDeclaredMethods
+      .filter { m =>
+        val modifiers = m.getModifiers
+        m.getName == "eval" && Modifier.isPublic(modifiers) && !Modifier.isAbstract(modifiers)
       }
-    }.toArray
 
+    if (methods.isEmpty) {
+      throw new ValidationException(
+        s"Function class '${function.getClass.getCanonicalName}' does not implement at least " +
+          s"one method named 'eval' which is public and not abstract.")
+    } else {
+      methods
+    }
+  }
+
+  def getSignatures(function: UserDefinedFunction): Array[Array[Class[_]]] = {
+    checkAndExtractEvalMethods(function).map(_.getParameterTypes)
+  }
+
+  // ----------------------------------------------------------------------------------------------
+  // Utilities for SQL functions
+  // ----------------------------------------------------------------------------------------------
 
   /**
-    * Compares parameter candidate classes with expected classes. If true, the parameters match.
-    * Candidate can be null (acts as a wildcard).
+    * Creates a [[SqlFunction]] for a [[ScalarFunction]].
+    * @param name function name
+    * @param function scalar function
+    * @param typeFactory type factory
+    * @return the ScalarSqlFunction
     */
-  def parameterTypeEquals(candidate: Class[_], expected: Class[_]): Boolean =
-    candidate == null ||
-      candidate == expected ||
-      expected.isPrimitive && Primitives.wrap(expected) == candidate ||
-      candidate == classOf[Date] && expected == classOf[Int] ||
-      candidate == classOf[Time] && expected == classOf[Int] ||
-      candidate == classOf[Timestamp] && expected == classOf[Long]
+  def createScalarSqlFunction(
+      name: String,
+      function: ScalarFunction,
+      typeFactory: FlinkTypeFactory)
+    : SqlFunction = {
+    new ScalarSqlFunction(name, function, typeFactory)
+  }
 
   /**
-    * Returns signatures matching the given signature of [[TypeInformation]].
-    * Elements of the signature can be null (act as a wildcard).
+    * Creates a [[SqlFunction]] for each eval method of a [[TableFunction]].
+    * @param name function name
+    * @param tableFunction table function
+    * @param resultType the type information of the returned table
+    * @param typeFactory type factory
+    * @return the TableSqlFunction
     */
-  def getSignature(
-      scalarFunction: ScalarFunction,
-      signature: Seq[TypeInformation[_]])
-    : Option[Array[Class[_]]] = {
-    // We compare the raw Java classes not the TypeInformation.
-    // TypeInformation does not matter during runtime (e.g. within a MapFunction).
-    val actualSignature = typeInfoToClass(signature)
+  def createTableSqlFunctions(
+      name: String,
+      tableFunction: TableFunction[_],
+      resultType: TypeInformation[_],
+      typeFactory: FlinkTypeFactory)
+    : Seq[SqlFunction] = {
+    val (fieldNames, fieldIndexes, _) = UserDefinedFunctionUtils.getFieldInfo(resultType)
+    val evalMethods = checkAndExtractEvalMethods(tableFunction)
 
-    scalarFunction
-      .getSignatures
-      // go over all signatures and find one matching actual signature
-      .find { curSig =>
-        // match parameters of signature to actual parameters
-        actualSignature.length == curSig.length &&
-          curSig.zipWithIndex.forall { case (clazz, i) =>
-            parameterTypeEquals(actualSignature(i), clazz)
-          }
-      }
+    evalMethods.map { method =>
+      val function = new FlinkTableFunctionImpl(resultType, fieldIndexes, fieldNames, method)
+      TableSqlFunction(name, tableFunction, resultType, typeFactory, function)
+    }
   }
 
+  // ----------------------------------------------------------------------------------------------
+  // Utilities for scalar functions
+  // ----------------------------------------------------------------------------------------------
+
   /**
     * Internal method of [[ScalarFunction#getResultType()]] that does some pre-checking and uses
     * [[TypeExtractor]] as default return type inference.
     */
   def getResultType(
-      scalarFunction: ScalarFunction,
+      function: ScalarFunction,
       signature: Array[Class[_]])
     : TypeInformation[_] = {
     // find method for signature
-    val evalMethod = scalarFunction.getEvalMethods
+    val evalMethod = checkAndExtractEvalMethods(function)
       .find(m => signature.sameElements(m.getParameterTypes))
       .getOrElse(throw new ValidationException("Given signature is invalid."))
 
-    val userDefinedTypeInfo = scalarFunction.getResultType(signature)
+    val userDefinedTypeInfo = function.getResultType(signature)
     if (userDefinedTypeInfo != null) {
-        userDefinedTypeInfo
+      userDefinedTypeInfo
     } else {
       try {
         TypeExtractor.getForClass(evalMethod.getReturnType)
       } catch {
         case ite: InvalidTypesException =>
-          throw new ValidationException(s"Return type of scalar function '$this' cannot be " +
-            s"automatically determined. Please provide type information manually.")
+          throw new ValidationException(
+            s"Return type of scalar function '${function.getClass.getCanonicalName}' cannot be " +
+              s"automatically determined. Please provide type information manually.")
       }
     }
   }
@@ -165,21 +235,100 @@ object UserDefinedFunctionUtils {
     * Returns the return type of the evaluation method matching the given signature.
     */
   def getResultTypeClass(
-      scalarFunction: ScalarFunction,
+      function: ScalarFunction,
       signature: Array[Class[_]])
     : Class[_] = {
     // find method for signature
-    val evalMethod = scalarFunction.getEvalMethods
+    val evalMethod = checkAndExtractEvalMethods(function)
       .find(m => signature.sameElements(m.getParameterTypes))
       .getOrElse(throw new IllegalArgumentException("Given signature is invalid."))
     evalMethod.getReturnType
   }
 
+  // ----------------------------------------------------------------------------------------------
+  // Miscellaneous
+  // ----------------------------------------------------------------------------------------------
+
   /**
-    * Prints all signatures of a [[ScalarFunction]].
+    * Returns field names and field positions for a given [[TypeInformation]].
+    *
+    * Field names are automatically extracted for
+    * [[org.apache.flink.api.common.typeutils.CompositeType]].
+    *
+    * @param inputType The TypeInformation to extract the field names and positions from.
+    * @return A tuple of three arrays holding the field names, field indexes, and field types.
     */
-  def signaturesToString(scalarFunction: ScalarFunction): String = {
-    scalarFunction.getSignatures.map(signatureToString).mkString(", ")
+  def getFieldInfo(inputType: TypeInformation[_])
+    : (Array[String], Array[Int], Array[TypeInformation[_]]) = {
+
+    val fieldNames: Array[String] = inputType match {
+      case t: CompositeType[_] => t.getFieldNames
+      case a: AtomicType[_] => Array("f0")
+      case tpe =>
+        throw new TableException(s"Currently only support CompositeType and AtomicType. " +
+                                   s"Type $tpe lacks explicit field naming")
+    }
+    val fieldIndexes = fieldNames.indices.toArray
+    val fieldTypes: Array[TypeInformation[_]] = fieldNames.map { fieldName =>
+      inputType match {
+        case t: CompositeType[_] => t.getTypeAt(fieldName).asInstanceOf[TypeInformation[_]]
+        case a: AtomicType[_] => a.asInstanceOf[TypeInformation[_]]
+        case tpe =>
+          throw new TableException("Currently, only CompositeType and AtomicType are supported.")
+      }
+    }
+    (fieldNames, fieldIndexes, fieldTypes)
   }
 
+  /**
+    * Returns a string representation of a signature consisting of classes.
+    */
+  def signatureToString(signature: Array[Class[_]]): String =
+    signature.map { clazz =>
+      if (clazz == null) {
+        "null"
+      } else {
+        clazz.getCanonicalName
+      }
+    }.mkString("(", ", ", ")")
+
+  /**
+    * Returns a string representation of a signature consisting of TypeInformation.
+    */
+  def signatureToString(signature: Seq[TypeInformation[_]]): String = {
+    signatureToString(typeInfoToClass(signature))
+  }
+
+  /**
+    * Returns a string representation of all eval method signatures of a class.
+    */
+  def signaturesToString(function: UserDefinedFunction): String = {
+    getSignatures(function).map(signatureToString).mkString(", ")
+  }
+
+  /**
+    * Extracts type classes of [[TypeInformation]] in a null-aware way.
+    */
+  private def typeInfoToClass(typeInfos: Seq[TypeInformation[_]]): Array[Class[_]] =
+    typeInfos.map { typeInfo =>
+      if (typeInfo == null) {
+        null
+      } else {
+        typeInfo.getTypeClass
+      }
+    }.toArray
+
+  /**
+    * Compares parameter candidate classes with expected classes. If true, the parameters match.
+    * Candidate can be null (acts as a wildcard).
+    */
+  private def parameterTypeEquals(candidate: Class[_], expected: Class[_]): Boolean =
+    candidate == null ||
+      candidate == expected ||
+      expected.isPrimitive && Primitives.wrap(expected) == candidate ||
+      candidate == classOf[Date] && expected == classOf[Int] ||
+      candidate == classOf[Time] && expected == classOf[Int] ||
+      candidate == classOf[Timestamp] && expected == classOf[Long]
+
 }

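For illustration, a minimal self-contained Scala sketch of the reflection logic used above (toy classes, not Flink's API): eval methods are discovered as public, non-abstract declared methods named "eval", and Scala objects are detected through the synthetic MODULE$ field.

    import java.lang.reflect.Modifier

    class MySplit {
      def eval(s: String): Unit = ()
      def eval(s: String, max: Int): Unit = ()   // overloads are allowed
    }

    object SingletonSplit {
      def eval(s: String): Unit = ()
    }

    object ReflectionSketch extends App {
      def evalMethods(clazz: Class[_]) =
        clazz.getDeclaredMethods.filter { m =>
          m.getName == "eval" &&
            Modifier.isPublic(m.getModifiers) && !Modifier.isAbstract(m.getModifiers)
        }

      def isScalaObject(clazz: Class[_]): Boolean =
        clazz.getFields.map(_.getName).contains("MODULE$")

      println(evalMethods(classOf[MySplit]).length)    // 2 -> both overloads found
      println(isScalaObject(SingletonSplit.getClass))  // true -> would be rejected
    }
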
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
index cd22f6a..f6ddeef 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
@@ -122,10 +122,10 @@ object ProjectionTranslator {
       case prop: WindowProperty =>
         val name = propNames(prop)
         Alias(UnresolvedFieldReference(name), tableEnv.createUniqueAttributeName())
-      case n @ Alias(agg: Aggregation, name) =>
+      case n @ Alias(agg: Aggregation, name, _) =>
         val aName = aggNames(agg)
         Alias(UnresolvedFieldReference(aName), name)
-      case n @ Alias(prop: WindowProperty, name) =>
+      case n @ Alias(prop: WindowProperty, name, _) =>
         val pName = propNames(prop)
         Alias(UnresolvedFieldReference(pName), name)
       case l: LeafExpression => l

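The recurring pattern change above (and in operators.scala below) follows from Alias gaining a third field for extra output names; a toy model of the shape change (simplified classes, not Flink's actual expression tree):

    sealed trait Expr
    case class FieldRef(name: String) extends Expr
    // Alias now also carries extra names for multi-column aliases like "as ('s, 't)"
    case class Alias(child: Expr, name: String, extraNames: Seq[String] = Nil) extends Expr

    object AliasPatternSketch extends App {
      val e: Expr = Alias(FieldRef("c"), "s", Seq("t"))
      e match {
        // every existing two-argument pattern needs a wildcard for the new field
        case Alias(_, name, _) => println(s"primary output name: $name")
        case _ => ()
      }
    }
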
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala
index ecf1996..4dc2ab7 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala
@@ -17,9 +17,13 @@
  */
 package org.apache.flink.api.table.plan.logical
 
+import java.lang.reflect.Method
+import java.util
+
 import org.apache.calcite.rel.RelNode
 import org.apache.calcite.rel.`type`.RelDataType
-import org.apache.calcite.rel.logical.LogicalProject
+import org.apache.calcite.rel.core.CorrelationId
+import org.apache.calcite.rel.logical.{LogicalProject, LogicalTableFunctionScan}
 import org.apache.calcite.rex.{RexInputRef, RexNode}
 import org.apache.calcite.tools.RelBuilder
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
@@ -27,6 +31,10 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.operators.join.JoinType
 import org.apache.flink.api.table._
 import org.apache.flink.api.table.expressions._
+import org.apache.flink.api.table.functions.TableFunction
+import org.apache.flink.api.table.functions.utils.TableSqlFunction
+import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils._
+import org.apache.flink.api.table.plan.schema.FlinkTableFunctionImpl
 import org.apache.flink.api.table.typeutils.TypeConverter
 import org.apache.flink.api.table.validate.{ValidationFailure, ValidationSuccess}
 
@@ -216,7 +224,7 @@ case class Aggregate(
     relBuilder.aggregate(
       relBuilder.groupKey(groupingExpressions.map(_.toRexNode(relBuilder)).asJava),
       aggregateExpressions.map {
-        case Alias(agg: Aggregation, name) => agg.toAggCall(name)(relBuilder)
+        case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder)
         case _ => throw new RuntimeException("This should never happen.")
       }.asJava)
   }
@@ -361,7 +369,8 @@ case class Join(
     left: LogicalNode,
     right: LogicalNode,
     joinType: JoinType,
-    condition: Option[Expression]) extends BinaryNode {
+    condition: Option[Expression],
+    correlated: Boolean) extends BinaryNode {
 
   override def output: Seq[Attribute] = {
     left.output ++ right.output
@@ -411,22 +420,31 @@ case class Join(
         right)
     }
     val resolvedCondition = node.condition.map(_.postOrderTransform(partialFunction))
-    Join(node.left, node.right, node.joinType, resolvedCondition)
+    Join(node.left, node.right, node.joinType, resolvedCondition, correlated)
   }
 
   override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
     left.construct(relBuilder)
     right.construct(relBuilder)
+
+    val corSet = mutable.Set[CorrelationId]()
+
+    if (correlated) {
+      corSet += relBuilder.peek().getCluster.createCorrel()
+    }
+
     relBuilder.join(
       TypeConverter.flinkJoinTypeToRelType(joinType),
-      condition.map(_.toRexNode(relBuilder)).getOrElse(relBuilder.literal(true)))
+      condition.map(_.toRexNode(relBuilder)).getOrElse(relBuilder.literal(true)),
+      corSet.asJava)
   }
 
   private def ambiguousName: Set[String] =
     left.output.map(_.name).toSet.intersect(right.output.map(_.name).toSet)
 
   override def validate(tableEnv: TableEnvironment): LogicalNode = {
-    if (tableEnv.isInstanceOf[StreamTableEnvironment]) {
+    if (tableEnv.isInstanceOf[StreamTableEnvironment]
+      && !right.isInstanceOf[LogicalTableFunctionCall]) {
       failValidation(s"Join on stream tables is currently not supported.")
     }
 
@@ -551,11 +569,11 @@ case class WindowAggregate(
       window,
       relBuilder.groupKey(groupingExpressions.map(_.toRexNode(relBuilder)).asJava),
       propertyExpressions.map {
-        case Alias(prop: WindowProperty, name) => prop.toNamedWindowProperty(name)(relBuilder)
+        case Alias(prop: WindowProperty, name, _) => prop.toNamedWindowProperty(name)(relBuilder)
         case _ => throw new RuntimeException("This should never happen.")
       },
       aggregateExpressions.map {
-        case Alias(agg: Aggregation, name) => agg.toAggCall(name)(relBuilder)
+        case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder)
         case _ => throw new RuntimeException("This should never happen.")
       }.asJava)
   }
@@ -605,3 +623,71 @@ case class WindowAggregate(
     resolvedWindowAggregate
   }
 }
+
+
+/**
+  * LogicalNode for calling a user-defined table function.
+  * @param functionName function name
+  * @param tableFunction table function to be called (might be overloaded)
+  * @param parameters actual parameters
+  * @param resultType type information of the table returned by the table function
+  * @param fieldNames output field names
+  * @param child child logical node
+  */
+case class LogicalTableFunctionCall(
+    functionName: String,
+    tableFunction: TableFunction[_],
+    parameters: Seq[Expression],
+    resultType: TypeInformation[_],
+    fieldNames: Array[String],
+    child: LogicalNode)
+  extends UnaryNode {
+
+  val (_, fieldIndexes, fieldTypes) = getFieldInfo(resultType)
+  var evalMethod: Method = _
+
+  override def output: Seq[Attribute] = fieldNames.zip(fieldTypes).map {
+    case (n, t) => ResolvedFieldReference(n, t)
+  }
+
+  override def validate(tableEnv: TableEnvironment): LogicalNode = {
+    val node = super.validate(tableEnv).asInstanceOf[LogicalTableFunctionCall]
+    // check that the table function is not a Scala object
+    checkNotSingleton(tableFunction.getClass)
+    // check that the table function can be instantiated
+    checkForInstantiation(tableFunction.getClass)
+    // look for a signature that matches the input types
+    val signature = node.parameters.map(_.resultType)
+    val foundMethod = getEvalMethod(tableFunction, signature)
+    if (foundMethod.isEmpty) {
+      failValidation(
+        s"Given parameters of function '$functionName' do not match any signature. \n" +
+          s"Actual: ${signatureToString(signature)} \n" +
+          s"Expected: ${signaturesToString(tableFunction)}")
+    } else {
+      node.evalMethod = foundMethod.get
+    }
+    node
+  }
+
+  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
+    val fieldIndexes = getFieldInfo(resultType)._2
+    val function = new FlinkTableFunctionImpl(resultType, fieldIndexes, fieldNames, evalMethod)
+    val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
+    val sqlFunction = TableSqlFunction(
+      tableFunction.toString,
+      tableFunction,
+      resultType,
+      typeFactory,
+      function)
+
+    val scan = LogicalTableFunctionScan.create(
+      relBuilder.peek().getCluster,
+      new util.ArrayList[RelNode](),
+      relBuilder.call(sqlFunction, parameters.map(_.toRexNode(relBuilder)).asJava),
+      function.getElementType(null),
+      function.getRowType(relBuilder.getTypeFactory, null),
+      null)
+
+    relBuilder.push(scan)
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala
new file mode 100644
index 0000000..9745be1
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.nodes
+
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rex.{RexCall, RexNode}
+import org.apache.calcite.sql.SemiJoinType
+import org.apache.flink.api.common.functions.FlatMapFunction
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.codegen.{CodeGenerator, GeneratedExpression, GeneratedFunction}
+import org.apache.flink.api.table.codegen.CodeGenUtils.primitiveDefaultValue
+import org.apache.flink.api.table.functions.utils.TableSqlFunction
+import org.apache.flink.api.table.runtime.FlatMapRunner
+import org.apache.flink.api.table.typeutils.TypeConverter._
+import org.apache.flink.api.table.{TableConfig, TableException}
+
+import scala.collection.JavaConverters._
+
+/**
+  * Generates the code for cross/outer apply of a user-defined table function.
+  */
+trait FlinkCorrelate {
+
+  private[flink] def functionBody(
+      generator: CodeGenerator,
+      udtfTypeInfo: TypeInformation[Any],
+      rowType: RelDataType,
+      rexCall: RexCall,
+      condition: Option[RexNode],
+      config: TableConfig,
+      joinType: SemiJoinType,
+      expectedType: Option[TypeInformation[Any]]): String = {
+
+    val returnType = determineReturnType(
+      rowType,
+      expectedType,
+      config.getNullCheck,
+      config.getEfficientTypeUsage)
+
+    val (input1AccessExprs, input2AccessExprs) = generator.generateCorrelateAccessExprs
+
+    val call = generator.generateExpression(rexCall)
+    var body =
+      s"""
+         |${call.code}
+         |java.util.Iterator iter = ${call.resultTerm}.getRowsIterator();
+       """.stripMargin
+
+    if (joinType == SemiJoinType.INNER) {
+      // cross apply
+      body +=
+        s"""
+           |if (!iter.hasNext()) {
+           |  return;
+           |}
+        """.stripMargin
+    } else if (joinType == SemiJoinType.LEFT) {
+      // outer apply
+
+      // in case of outer apply, if the table function returns no row,
+      // all fields of the row are filled with nulls
+      val input2NullExprs = input2AccessExprs.map { x =>
+        GeneratedExpression(
+          primitiveDefaultValue(x.resultType),
+          GeneratedExpression.ALWAYS_NULL,
+          "",
+          x.resultType)
+      }
+      val outerResultExpr = generator.generateResultExpression(
+        input1AccessExprs ++ input2NullExprs, returnType, rowType.getFieldNames.asScala)
+      body +=
+        s"""
+           |if (!iter.hasNext()) {
+           |  ${outerResultExpr.code}
+           |  ${generator.collectorTerm}.collect(${outerResultExpr.resultTerm});
+           |  return;
+           |}
+        """.stripMargin
+    } else {
+      throw TableException(s"Unsupported SemiJoinType: $joinType for correlate join.")
+    }
+
+    val crossResultExpr = generator.generateResultExpression(
+      input1AccessExprs ++ input2AccessExprs,
+      returnType,
+      rowType.getFieldNames.asScala)
+
+    val projection = if (condition.isEmpty) {
+      s"""
+         |${crossResultExpr.code}
+         |${generator.collectorTerm}.collect(${crossResultExpr.resultTerm});
+       """.stripMargin
+    } else {
+      val filterGenerator = new CodeGenerator(config, false, udtfTypeInfo)
+      filterGenerator.input1Term = filterGenerator.input2Term
+      val filterCondition = filterGenerator.generateExpression(condition.get)
+      s"""
+         |${filterGenerator.reuseInputUnboxingCode()}
+         |${filterCondition.code}
+         |if (${filterCondition.resultTerm}) {
+         |  ${crossResultExpr.code}
+         |  ${generator.collectorTerm}.collect(${crossResultExpr.resultTerm});
+         |}
+         |""".stripMargin
+    }
+
+    val outputTypeClass = udtfTypeInfo.getTypeClass.getCanonicalName
+    body +=
+      s"""
+         |while (iter.hasNext()) {
+         |  $outputTypeClass ${generator.input2Term} = ($outputTypeClass) iter.next();
+         |  $projection
+         |}
+       """.stripMargin
+    body
+  }
+
+  private[flink] def correlateMapFunction(
+      genFunction: GeneratedFunction[FlatMapFunction[Any, Any]])
+    : FlatMapRunner[Any, Any] = {
+
+    new FlatMapRunner[Any, Any](
+      genFunction.name,
+      genFunction.code,
+      genFunction.returnType)
+  }
+
+  private[flink] def selectToString(rowType: RelDataType): String = {
+    rowType.getFieldNames.asScala.mkString(",")
+  }
+
+  private[flink] def correlateOpName(
+      rexCall: RexCall,
+      sqlFunction: TableSqlFunction,
+      rowType: RelDataType)
+    : String = {
+
+    s"correlate: ${correlateToString(rexCall, sqlFunction)}, select: ${selectToString(rowType)}"
+  }
+
+  private[flink] def correlateToString(rexCall: RexCall, sqlFunction: TableSqlFunction): String = {
+    val udtfName = sqlFunction.getName
+    val operands = rexCall.getOperands.asScala.map(_.toString).mkString(",")
+    s"table($udtfName($operands))"
+  }
+
+}

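The string-assembled body above boils down to a simple per-row control flow; a hand-written Scala equivalent of the outer-apply case (illustrative only, with assumed generic row types standing in for the generated code):

    object CorrelateSemanticsSketch extends App {
      // joinRow(left, None) stands for the null-padded row built from
      // input2NullExprs; joinRow(left, Some(r)) for the cross result row.
      // Cross apply differs only in returning early instead of emitting nulls.
      def outerApply[IN, T, OUT](
          input: IN,
          udtfRows: Iterator[T],
          joinRow: (IN, Option[T]) => OUT,
          collect: OUT => Unit): Unit = {
        if (!udtfRows.hasNext) {
          collect(joinRow(input, None))
        } else {
          udtfRows.foreach(r => collect(joinRow(input, Some(r))))
        }
      }

      outerApply[String, String, (String, Option[String])](
        "a#b", Iterator("a", "b"), (l, r) => (l, r), println)   // two joined rows
      outerApply[String, String, (String, Option[String])](
        "x", Iterator.empty, (l, r) => (l, r), println)         // ("x", None)
    }
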
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala
new file mode 100644
index 0000000..4aa7fea
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.nodes.dataset
+
+import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet}
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.logical.LogicalTableFunctionScan
+import org.apache.calcite.rel.metadata.RelMetadataQuery
+import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
+import org.apache.calcite.rex.{RexNode, RexCall}
+import org.apache.calcite.sql.SemiJoinType
+import org.apache.flink.api.common.functions.FlatMapFunction
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.DataSet
+import org.apache.flink.api.table.BatchTableEnvironment
+import org.apache.flink.api.table.codegen.CodeGenerator
+import org.apache.flink.api.table.functions.utils.TableSqlFunction
+import org.apache.flink.api.table.plan.nodes.FlinkCorrelate
+import org.apache.flink.api.table.typeutils.TypeConverter._
+
+/**
+  * Flink RelNode that executes a cross/outer apply of a user-defined table function.
+  */
+class DataSetCorrelate(
+    cluster: RelOptCluster,
+    traitSet: RelTraitSet,
+    inputNode: RelNode,
+    scan: LogicalTableFunctionScan,
+    condition: Option[RexNode],
+    relRowType: RelDataType,
+    joinRowType: RelDataType,
+    joinType: SemiJoinType,
+    ruleDescription: String)
+  extends SingleRel(cluster, traitSet, inputNode)
+  with FlinkCorrelate
+  with DataSetRel {
+
+  override def deriveRowType() = relRowType
+
+  override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
+    val rowCnt = metadata.getRowCount(getInput) * 1.5
+    planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * 0.5)
+  }
+
+  override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
+    new DataSetCorrelate(
+      cluster,
+      traitSet,
+      inputs.get(0),
+      scan,
+      condition,
+      relRowType,
+      joinRowType,
+      joinType,
+      ruleDescription)
+  }
+
+  override def toString: String = {
+    val rexCall = scan.getCall.asInstanceOf[RexCall]
+    val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+    correlateToString(rexCall, sqlFunction)
+  }
+
+  override def explainTerms(pw: RelWriter): RelWriter = {
+    val rexCall = scan.getCall.asInstanceOf[RexCall]
+    val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+    super.explainTerms(pw)
+      .item("invocation", scan.getCall)
+      .item("function", sqlFunction.getTableFunction.getClass.getCanonicalName)
+      .item("rowType", relRowType)
+      .item("joinType", joinType)
+      .itemIf("condition", condition.orNull, condition.isDefined)
+  }
+
+  override def translateToPlan(
+      tableEnv: BatchTableEnvironment,
+      expectedType: Option[TypeInformation[Any]])
+    : DataSet[Any] = {
+
+    val config = tableEnv.getConfig
+    val returnType = determineReturnType(
+      getRowType,
+      expectedType,
+      config.getNullCheck,
+      config.getEfficientTypeUsage)
+
+    // do not need to specify input type
+    val inputDS = inputNode.asInstanceOf[DataSetRel].translateToPlan(tableEnv)
+
+    val funcRel = scan.asInstanceOf[LogicalTableFunctionScan]
+    val rexCall = funcRel.getCall.asInstanceOf[RexCall]
+    val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+    val pojoFieldMapping = sqlFunction.getPojoFieldMapping
+    val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]]
+
+    val generator = new CodeGenerator(
+      config,
+      false,
+      inputDS.getType,
+      Some(udtfTypeInfo),
+      None,
+      Some(pojoFieldMapping))
+
+    val body = functionBody(
+      generator,
+      udtfTypeInfo,
+      getRowType,
+      rexCall,
+      condition,
+      config,
+      joinType,
+      expectedType)
+
+    val genFunction = generator.generateFunction(
+      ruleDescription,
+      classOf[FlatMapFunction[Any, Any]],
+      body,
+      returnType)
+
+    val mapFunc = correlateMapFunction(genFunction)
+
+    inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, relRowType))
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamCorrelate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamCorrelate.scala
new file mode 100644
index 0000000..b0bc48a
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamCorrelate.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.nodes.datastream
+
+import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.logical.LogicalTableFunctionScan
+import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
+import org.apache.calcite.rex.{RexCall, RexNode}
+import org.apache.calcite.sql.SemiJoinType
+import org.apache.flink.api.common.functions.FlatMapFunction
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.StreamTableEnvironment
+import org.apache.flink.api.table.codegen.CodeGenerator
+import org.apache.flink.api.table.functions.utils.TableSqlFunction
+import org.apache.flink.api.table.plan.nodes.FlinkCorrelate
+import org.apache.flink.api.table.typeutils.TypeConverter._
+import org.apache.flink.streaming.api.datastream.DataStream
+
+/**
+  * Flink RelNode that executes a cross/outer apply of a user-defined table function.
+  */
+class DataStreamCorrelate(
+    cluster: RelOptCluster,
+    traitSet: RelTraitSet,
+    inputNode: RelNode,
+    scan: LogicalTableFunctionScan,
+    condition: Option[RexNode],
+    relRowType: RelDataType,
+    joinRowType: RelDataType,
+    joinType: SemiJoinType,
+    ruleDescription: String)
+  extends SingleRel(cluster, traitSet, inputNode)
+  with FlinkCorrelate
+  with DataStreamRel {
+
+  override def deriveRowType() = relRowType
+
+  override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
+    new DataStreamCorrelate(
+      cluster,
+      traitSet,
+      inputs.get(0),
+      scan,
+      condition,
+      relRowType,
+      joinRowType,
+      joinType,
+      ruleDescription)
+  }
+
+  override def toString: String = {
+    val rexCall = scan.getCall.asInstanceOf[RexCall]
+    val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+    correlateToString(rexCall, sqlFunction)
+  }
+
+  override def explainTerms(pw: RelWriter): RelWriter = {
+    val rexCall = scan.getCall.asInstanceOf[RexCall]
+    val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+    super.explainTerms(pw)
+      .item("invocation", scan.getCall)
+      .item("function", sqlFunction.getTableFunction.getClass.getCanonicalName)
+      .item("rowType", relRowType)
+      .item("joinType", joinType)
+      .itemIf("condition", condition.orNull, condition.isDefined)
+  }
+
+  override def translateToPlan(
+      tableEnv: StreamTableEnvironment,
+      expectedType: Option[TypeInformation[Any]])
+    : DataStream[Any] = {
+
+    val config = tableEnv.getConfig
+    val returnType = determineReturnType(
+      getRowType,
+      expectedType,
+      config.getNullCheck,
+      config.getEfficientTypeUsage)
+
+    // do not need to specify input type
+    val inputDS = inputNode.asInstanceOf[DataStreamRel].translateToPlan(tableEnv)
+
+    val funcRel = scan.asInstanceOf[LogicalTableFunctionScan]
+    val rexCall = funcRel.getCall.asInstanceOf[RexCall]
+    val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+    val pojoFieldMapping = sqlFunction.getPojoFieldMapping
+    val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]]
+
+    val generator = new CodeGenerator(
+      config,
+      false,
+      inputDS.getType,
+      Some(udtfTypeInfo),
+      None,
+      Some(pojoFieldMapping))
+
+    val body = functionBody(
+      generator,
+      udtfTypeInfo,
+      getRowType,
+      rexCall,
+      condition,
+      config,
+      joinType,
+      expectedType)
+
+    val genFunction = generator.generateFunction(
+      ruleDescription,
+      classOf[FlatMapFunction[Any, Any]],
+      body,
+      returnType)
+
+    val mapFunc = correlateMapFunction(genFunction)
+
+    inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, relRowType))
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
index 9e20df4..6847425 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
@@ -108,6 +108,7 @@ object FlinkRuleSets {
     DataSetMinusRule.INSTANCE,
     DataSetSortRule.INSTANCE,
     DataSetValuesRule.INSTANCE,
+    DataSetCorrelateRule.INSTANCE,
     BatchTableSourceScanRule.INSTANCE
   )
 
@@ -151,6 +152,7 @@ object FlinkRuleSets {
       DataStreamScanRule.INSTANCE,
       DataStreamUnionRule.INSTANCE,
       DataStreamValuesRule.INSTANCE,
+      DataStreamCorrelateRule.INSTANCE,
       StreamTableSourceScanRule.INSTANCE
   )
 

http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala
new file mode 100644
index 0000000..e6cf0cf
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.rules.dataSet
+
+import org.apache.calcite.plan.volcano.RelSubset
+import org.apache.calcite.plan.{Convention, RelOptRule, RelOptRuleCall, RelTraitSet}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.calcite.rel.logical.{LogicalFilter, LogicalCorrelate, LogicalTableFunctionScan}
+import org.apache.calcite.rex.RexNode
+import org.apache.flink.api.table.plan.nodes.dataset.{DataSetConvention, DataSetCorrelate}
+
+/**
+  * Rule to convert a LogicalCorrelate into a DataSetCorrelate.
+  */
+class DataSetCorrelateRule
+  extends ConverterRule(
+    classOf[LogicalCorrelate],
+    Convention.NONE,
+    DataSetConvention.INSTANCE,
+    "DataSetCorrelateRule")
+{
+
+  override def matches(call: RelOptRuleCall): Boolean = {
+    val join: LogicalCorrelate = call.rel(0).asInstanceOf[LogicalCorrelate]
+    val right = join.getRight.asInstanceOf[RelSubset].getOriginal
+
+    right match {
+      // right node is a table function
+      case scan: LogicalTableFunctionScan => true
+      // a filter is pushed above the table function
+      case filter: LogicalFilter =>
+        filter.getInput.asInstanceOf[RelSubset].getOriginal
+          .isInstanceOf[LogicalTableFunctionScan]
+      case _ => false
+    }
+  }
+
+  override def convert(rel: RelNode): RelNode = {
+    val join: LogicalCorrelate = rel.asInstanceOf[LogicalCorrelate]
+    val traitSet: RelTraitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE)
+    val convInput: RelNode = RelOptRule.convert(join.getInput(0), DataSetConvention.INSTANCE)
+    val right: RelNode = join.getInput(1)
+
+    def convertToCorrelate(relNode: RelNode, condition: Option[RexNode]): DataSetCorrelate = {
+      relNode match {
+        case rel: RelSubset =>
+          convertToCorrelate(rel.getRelList.get(0), condition)
+
+        case filter: LogicalFilter =>
+          convertToCorrelate(
+            filter.getInput.asInstanceOf[RelSubset].getOriginal,
+            Some(filter.getCondition))
+
+        case scan: LogicalTableFunctionScan =>
+          new DataSetCorrelate(
+            rel.getCluster,
+            traitSet,
+            convInput,
+            scan,
+            condition,
+            rel.getRowType,
+            join.getRowType,
+            join.getJoinType,
+            description)
+      }
+    }
+    convertToCorrelate(right, None)
+  }
+}
+
+object DataSetCorrelateRule {
+  val INSTANCE: RelOptRule = new DataSetCorrelateRule
+}

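Both correlate rules peel an optional LogicalFilter off the right input and carry its condition into the correlate node; a toy model of that recursion (simplified node classes, not Calcite's):

    sealed trait Rel
    case class Subset(original: Rel) extends Rel            // stands in for RelSubset
    case class Filter(input: Rel, condition: String) extends Rel
    case class Scan(call: String) extends Rel               // the table function scan

    object UnwrapSketch extends App {
      def unwrap(rel: Rel, cond: Option[String]): (Scan, Option[String]) = rel match {
        case Subset(original) => unwrap(original, cond)
        case Filter(input, c) => unwrap(input, Some(c))
        case scan: Scan       => (scan, cond)
      }

      // a filter above the scan: its condition travels into the correlate
      println(unwrap(Subset(Filter(Subset(Scan("split($c)")), "s = 'a'")), None))
    }
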
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamCorrelateRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamCorrelateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamCorrelateRule.scala
new file mode 100644
index 0000000..bb52fd7
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamCorrelateRule.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.rules.datastream
+
+import org.apache.calcite.plan.volcano.RelSubset
+import org.apache.calcite.plan.{Convention, RelOptRule, RelOptRuleCall, RelTraitSet}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.calcite.rel.logical.{LogicalFilter, LogicalCorrelate, LogicalTableFunctionScan}
+import org.apache.calcite.rex.RexNode
+import org.apache.flink.api.table.plan.nodes.datastream.{DataStreamCorrelate, DataStreamConvention}
+
+/**
+  * Rule to convert a LogicalCorrelate into a DataStreamCorrelate.
+  */
+class DataStreamCorrelateRule
+  extends ConverterRule(
+    classOf[LogicalCorrelate],
+    Convention.NONE,
+    DataStreamConvention.INSTANCE,
+    "DataStreamCorrelateRule")
+{
+
+  override def matches(call: RelOptRuleCall): Boolean = {
+    val join: LogicalCorrelate = call.rel(0).asInstanceOf[LogicalCorrelate]
+    val right = join.getRight.asInstanceOf[RelSubset].getOriginal
+
+    right match {
+      // right node is a table function
+      case scan: LogicalTableFunctionScan => true
+      // a filter is pushed above the table function
+      case filter: LogicalFilter =>
+        filter.getInput.asInstanceOf[RelSubset].getOriginal
+          .isInstanceOf[LogicalTableFunctionScan]
+      case _ => false
+    }
+  }
+
+  override def convert(rel: RelNode): RelNode = {
+    val join: LogicalCorrelate = rel.asInstanceOf[LogicalCorrelate]
+    val traitSet: RelTraitSet = rel.getTraitSet.replace(DataStreamConvention.INSTANCE)
+    val convInput: RelNode = RelOptRule.convert(join.getInput(0), DataStreamConvention.INSTANCE)
+    val right: RelNode = join.getInput(1)
+
+    def convertToCorrelate(relNode: RelNode, condition: Option[RexNode]): DataStreamCorrelate = {
+      relNode match {
+        case rel: RelSubset =>
+          convertToCorrelate(rel.getRelList.get(0), condition)
+
+        case filter: LogicalFilter =>
+          convertToCorrelate(
+            filter.getInput.asInstanceOf[RelSubset].getOriginal,
+            Some(filter.getCondition))
+
+        case scan: LogicalTableFunctionScan =>
+          new DataStreamCorrelate(
+            rel.getCluster,
+            traitSet,
+            convInput,
+            scan,
+            condition,
+            rel.getRowType,
+            join.getRowType,
+            join.getJoinType,
+            description)
+      }
+    }
+    convertToCorrelate(right, None)
+  }
+
+}
+
+object DataStreamCorrelateRule {
+  val INSTANCE: RelOptRule = new DataStreamCorrelateRule
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/FlinkTableFunctionImpl.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/FlinkTableFunctionImpl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/FlinkTableFunctionImpl.scala
new file mode 100644
index 0000000..540a5c8
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/FlinkTableFunctionImpl.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.schema
+
+import java.lang.reflect.{Method, Type}
+import java.util
+
+import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory}
+import org.apache.calcite.schema.TableFunction
+import org.apache.calcite.schema.impl.ReflectiveFunctionBase
+import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
+import org.apache.flink.api.common.typeutils.CompositeType
+import org.apache.flink.api.table.{FlinkTypeFactory, TableException}
+
+/**
+  * This is heavily inspired by Calcite's [[org.apache.calcite.schema.impl.TableFunctionImpl]].
+  * We need it in order to create a [[org.apache.flink.api.table.functions.utils.TableSqlFunction]].
+  * The main difference is that we override [[getRowType()]] and [[getElementType()]].
+  */
+class FlinkTableFunctionImpl[T](
+    val typeInfo: TypeInformation[T],
+    val fieldIndexes: Array[Int],
+    val fieldNames: Array[String],
+    val evalMethod: Method)
+  extends ReflectiveFunctionBase(evalMethod)
+  with TableFunction {
+
+  if (fieldIndexes.length != fieldNames.length) {
+    throw new TableException(
+      "Number of field indexes and field names must be equal.")
+  }
+
+  // check uniqueness of field names
+  if (fieldNames.length != fieldNames.toSet.size) {
+    throw new TableException(
+      "Table field names must be unique.")
+  }
+
+  val fieldTypes: Array[TypeInformation[_]] =
+    typeInfo match {
+      case cType: CompositeType[T] =>
+        if (fieldNames.length != cType.getArity) {
+          throw new TableException(
+            s"Arity of type (" + cType.getFieldNames.deep + ") " +
+              "not equal to number of field names " + fieldNames.deep + ".")
+        }
+        fieldIndexes.map(cType.getTypeAt(_).asInstanceOf[TypeInformation[_]])
+      case aType: AtomicType[T] =>
+        if (fieldIndexes.length != 1 || fieldIndexes(0) != 0) {
+          throw new TableException(
+            "Non-composite input type may have only a single field and its index must be 0.")
+        }
+        Array(aType)
+    }
+
+  override def getElementType(arguments: util.List[AnyRef]): Type = classOf[Array[Object]]
+
+  override def getRowType(
+      typeFactory: RelDataTypeFactory,
+      arguments: util.List[AnyRef]): RelDataType = {
+    val flinkTypeFactory = typeFactory.asInstanceOf[FlinkTypeFactory]
+    val builder = flinkTypeFactory.builder
+    fieldNames
+      .zip(fieldTypes)
+      .foreach { f =>
+        builder.add(f._1, flinkTypeFactory.createTypeFromTypeInfo(f._2)).nullable(true)
+      }
+    builder.build
+  }
+}

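The field-type derivation above distinguishes composite result types from atomic ones; a self-contained sketch of that rule (toy type classes, not Flink's TypeInformation):

    sealed trait TypeInfo
    case class Composite(fields: Vector[TypeInfo]) extends TypeInfo
    case object StringType extends TypeInfo

    object FieldTypesSketch extends App {
      def fieldTypes(t: TypeInfo, fieldIndexes: Array[Int]): Array[TypeInfo] = t match {
        case Composite(fields) => fieldIndexes.map(fields)
        case atomic =>
          // a non-composite type is exposed as a single field at index 0
          require(fieldIndexes.sameElements(Array(0)),
            "Non-composite input type may have only a single field and its index must be 0.")
          Array(atomic)
      }

      println(fieldTypes(Composite(Vector(StringType, StringType)), Array(0, 1)).length) // 2
      println(fieldTypes(StringType, Array(0)).length)                                   // 1
    }
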
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
index c45e871..a75f2fc 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
@@ -20,7 +20,8 @@ package org.apache.flink.api.table
 import org.apache.calcite.rel.RelNode
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.operators.join.JoinType
-import org.apache.flink.api.table.expressions.{Asc, Expression, ExpressionParser, Ordering}
+import org.apache.flink.api.table.plan.logical.Minus
+import org.apache.flink.api.table.expressions.{Alias, Asc, Call, Expression, ExpressionParser, Ordering, TableFunctionCall}
 import org.apache.flink.api.table.plan.ProjectionTranslator._
 import org.apache.flink.api.table.plan.logical._
 import org.apache.flink.api.table.sinks.TableSink
@@ -400,7 +401,8 @@ class Table(
     }
     new Table(
       tableEnv,
-      Join(this.logicalPlan, right.logicalPlan, joinType, joinPredicate).validate(tableEnv))
+      Join(this.logicalPlan, right.logicalPlan, joinType, joinPredicate, correlated = false)
+        .validate(tableEnv))
   }
 
   /**
@@ -609,6 +611,126 @@ class Table(
   }
 
   /**
+    * The Cross Apply returns rows from the outer table (table on the left of the Apply operator)
+    * that produce matching values from the table-valued function (which is on the right side of
+    * the operator).
+    *
+    * The Cross Apply is equivalent to Inner Join, but it works with a table-valued function.
+    *
+    * Example:
+    *
+    * {{{
+    *   class MySplitUDTF extends TableFunction[String] {
+    *     def eval(str: String): Unit = {
+    *       str.split("#").foreach(collect)
+    *     }
+    *   }
+    *
+    *   val split = new MySplitUDTF()
+    *   table.crossApply(split('c) as ('s)).select('a,'b,'c,'s)
+    * }}}
+    */
+  def crossApply(udtf: Expression): Table = {
+    applyInternal(udtf, JoinType.INNER)
+  }
+
+  /**
+    * The Cross Apply returns the rows from the outer table (the table on the left of the Apply
+    * operator) for which the table-valued function (on the right side of the operator) produces
+    * matching values.
+    *
+    * The Cross Apply is equivalent to an Inner Join, but it works with a table-valued function.
+    *
+    * Example:
+    *
+    * {{{
+    *   class MySplitUDTF extends TableFunction[String] {
+    *     def eval(str: String): Unit = {
+    *       str.split("#").foreach(collect)
+    *     }
+    *   }
+    *
+    *   val split = new MySplitUDTF()
+    *   table.crossApply("split(c) as (s)").select("a, b, c, s")
+    * }}}
+    */
+  def crossApply(udtf: String): Table = {
+    applyInternal(udtf, JoinType.INNER)
+  }
+
+  /**
+    * The Outer Apply returns all the rows from the outer table (the table on the left of the
+    * Apply operator); for rows for which the table-valued function (on the right side of the
+    * operator) produces no matching values, NULL values are returned.
+    *
+    * The Outer Apply is equivalent to a Left Outer Join, but it works with a table-valued
+    * function.
+    *
+    * Example:
+    *
+    * {{{
+    *   class MySplitUDTF extends TableFunction[String] {
+    *     def eval(str: String): Unit = {
+    *       str.split("#").foreach(collect)
+    *     }
+    *   }
+    *
+    *   val split = new MySplitUDTF()
+    *   table.outerApply(split('c) as ('s)).select('a,'b,'c,'s)
+    * }}}
+    */
+  def outerApply(udtf: Expression): Table = {
+    applyInternal(udtf, JoinType.LEFT_OUTER)
+  }
+
+  /**
+    * The Outer Apply returns all the rows from the outer table (the table on the left of the
+    * Apply operator); for rows for which the table-valued function (on the right side of the
+    * operator) produces no matching values, NULL values are returned.
+    *
+    * The Outer Apply is equivalent to a Left Outer Join, but it works with a table-valued
+    * function.
+    *
+    * Example:
+    *
+    * {{{
+    *   val split = new MySplitUDTF()
+    *   table.outerApply("split(c) as (s)").select("a, b, c, s")
+    * }}}
+    */
+  def outerApply(udtf: String): Table = {
+    applyInternal(udtf, JoinType.LEFT_OUTER)
+  }
+
+  private def applyInternal(udtfString: String, joinType: JoinType): Table = {
+    val udtf = ExpressionParser.parseExpression(udtfString)
+    applyInternal(udtf, joinType)
+  }
+
+  private def applyInternal(udtf: Expression, joinType: JoinType): Table = {
+    var alias: Option[Seq[String]] = None
+
+    // unwrap the Expression until we get a TableFunctionCall
+    def unwrap(expr: Expression): TableFunctionCall = expr match {
+      case Alias(child, name, extraNames) =>
+        alias = Some(Seq(name) ++ extraNames)
+        unwrap(child)
+      case Call(name, args) =>
+        val function = tableEnv.getFunctionCatalog.lookupFunction(name, args)
+        unwrap(function)
+      case c: TableFunctionCall => c
+      case _ => throw new TableException("Cross/Outer Apply only accepts a TableFunction")
+    }
+
+    val call = unwrap(udtf)
+      .as(alias)
+      .toLogicalTableFunctionCall(this.logicalPlan)
+      .validate(tableEnv)
+
+    new Table(
+      tableEnv,
+      Join(this.logicalPlan, call, joinType, None, correlated = true).validate(tableEnv))
+  }
+
+  /**
     * Writes the [[Table]] to a [[TableSink]]. A [[TableSink]] defines an external storage location.
     *
     * A batch [[Table]] can only be written to a

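For orientation, the unwrap recursion inside applyInternal can be modeled with a small
self-contained stand-in. These are plain Scala types, not the actual Flink expression
classes, and the alias bookkeeping is omitted:

    // stand-in expression hierarchy, for illustration only
    sealed trait Expr
    case class AliasExpr(child: Expr, names: Seq[String]) extends Expr
    case class CallExpr(name: String) extends Expr
    case class TableFuncCall(name: String) extends Expr
    case class LiteralExpr(value: Any) extends Expr

    // peel alias layers, resolve calls through a catalog lookup,
    // and stop as soon as a table function call is reached
    def unwrap(expr: Expr, lookup: String => Expr): TableFuncCall = expr match {
      case AliasExpr(child, _) => unwrap(child, lookup)
      case CallExpr(name)      => unwrap(lookup(name), lookup)
      case c: TableFuncCall    => c
      case _ => throw new IllegalArgumentException("Cross/Outer Apply only accepts a TableFunction")
    }

    // example: unwrap(AliasExpr(CallExpr("split"), Seq("s")), _ => TableFuncCall("split"))
    // returns TableFuncCall("split")
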
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
index 679733c..4029a7d 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
@@ -23,8 +23,8 @@ import org.apache.calcite.sql.util.{ChainedSqlOperatorTable, ListSqlOperatorTabl
 import org.apache.calcite.sql.{SqlFunction, SqlOperator, SqlOperatorTable}
 import org.apache.flink.api.table.ValidationException
 import org.apache.flink.api.table.expressions._
-import org.apache.flink.api.table.functions.ScalarFunction
-import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils
+import org.apache.flink.api.table.functions.{ScalarFunction, TableFunction}
+import org.apache.flink.api.table.functions.utils.{TableSqlFunction, UserDefinedFunctionUtils}
 
 import scala.collection.JavaConversions._
 import scala.collection.mutable
@@ -47,6 +47,20 @@ class FunctionCatalog {
     sqlFunctions += sqlFunction
   }
 
+  /** Registers multiple SQL functions at once. The functions must all have the same name. */
+  def registerSqlFunctions(functions: Seq[SqlFunction]): Unit = {
+    if (functions.nonEmpty) {
+      val name = functions.head.getName
+      // check that all functions have the same name
+      if (functions.forall(_.getName == name)) {
+        sqlFunctions --= sqlFunctions.filter(_.getName == name)
+        sqlFunctions ++= functions
+      } else {
+        throw ValidationException("The SQL functions to be registered have different names.")
+      }
+    }
+  }
+
   def getSqlOperatorTable: SqlOperatorTable =
     ChainedSqlOperatorTable.of(
       new BasicOperatorTable(),
@@ -59,14 +73,9 @@ class FunctionCatalog {
   def lookupFunction(name: String, children: Seq[Expression]): Expression = {
     val funcClass = functionBuilders
       .getOrElse(name.toLowerCase, throw ValidationException(s"Undefined function: $name"))
-    withChildren(funcClass, children)
-  }
 
-  /**
-    * Instantiate a function using the provided `children`.
-    */
-  private def withChildren(func: Class[_], children: Seq[Expression]): Expression = {
-    func match {
+    // Instantiate a function using the provided `children`
+    funcClass match {
 
       // user-defined scalar function call
       case sf if classOf[ScalarFunction].isAssignableFrom(sf) =>
@@ -75,10 +84,20 @@ class FunctionCatalog {
           case Failure(e) => throw ValidationException(e.getMessage)
         }
 
+      // user-defined table function call
+      case tf if classOf[TableFunction[_]].isAssignableFrom(tf) =>
+        val tableSqlFunction = sqlFunctions
+          .find(f => f.getName.equalsIgnoreCase(name) && f.isInstanceOf[TableSqlFunction])
+          .getOrElse(throw ValidationException(s"Unregistered table sql function: $name"))
+          .asInstanceOf[TableSqlFunction]
+        val typeInfo = tableSqlFunction.getRowTypeInfo
+        val function = tableSqlFunction.getTableFunction
+        TableFunctionCall(name, function, children, typeInfo)
+
       // general expression call
       case expression if classOf[Expression].isAssignableFrom(expression) =>
         // try to find a constructor accepts `Seq[Expression]`
-        Try(func.getDeclaredConstructor(classOf[Seq[_]])) match {
+        Try(funcClass.getDeclaredConstructor(classOf[Seq[_]])) match {
           case Success(seqCtor) =>
             Try(seqCtor.newInstance(children).asInstanceOf[Expression]) match {
               case Success(expr) => expr
@@ -87,14 +106,14 @@ class FunctionCatalog {
           case Failure(e) =>
             val childrenClass = Seq.fill(children.length)(classOf[Expression])
             // try to find a constructor matching the exact number of children
-            Try(func.getDeclaredConstructor(childrenClass: _*)) match {
+            Try(funcClass.getDeclaredConstructor(childrenClass: _*)) match {
               case Success(ctor) =>
                 Try(ctor.newInstance(children: _*).asInstanceOf[Expression]) match {
                   case Success(expr) => expr
                   case Failure(exception) => throw ValidationException(exception.getMessage)
                 }
               case Failure(exception) =>
-                throw ValidationException(s"Invalid number of arguments for function $func")
+                throw ValidationException(s"Invalid number of arguments for function $funcClass")
             }
         }
 

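The general-expression branch of lookupFunction tries two constructor shapes via
reflection: first a constructor taking Seq[Expression], then one taking exactly as many
Expression arguments as there are children. A standalone sketch of that strategy, with
assumed demo classes rather than the real Flink ones:

    import scala.util.{Failure, Success, Try}

    trait Expression
    case class Literal(value: Int) extends Expression
    class Plus(val children: Seq[Expression]) extends Expression

    def instantiate(funcClass: Class[_], children: Seq[Expression]): Expression =
      Try(funcClass.getDeclaredConstructor(classOf[Seq[_]])) match {
        // first preference: a constructor taking Seq[Expression]
        case Success(seqCtor) =>
          seqCtor.newInstance(children).asInstanceOf[Expression]
        // fallback: a constructor taking exactly children.length Expression arguments
        case Failure(_) =>
          val childrenClass = Seq.fill(children.length)(classOf[Expression])
          val ctor = funcClass.getDeclaredConstructor(childrenClass: _*)
          ctor.newInstance(children: _*).asInstanceOf[Expression]
      }

    // instantiate(classOf[Plus], Seq(Literal(1), Literal(2))) yields a Plus instance
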
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionITCase.scala
new file mode 100644
index 0000000..7e0d0ff
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionITCase.scala
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.scala.batch
+
+import org.apache.flink.api.scala.batch.utils.TableProgramsTestBase
+import org.apache.flink.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.table.expressions.utils._
+import org.apache.flink.api.table.{Row, Table, TableEnvironment}
+import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
+import org.apache.flink.test.util.TestBaseUtils
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+@RunWith(classOf[Parameterized])
+class UserDefinedTableFunctionITCase(
+  mode: TestExecutionMode,
+  configMode: TableConfigMode)
+  extends TableProgramsTestBase(mode, configMode) {
+
+  @Test
+  def testSQLCrossApply(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+    val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+    tableEnv.registerTable("MyTable", in)
+    tableEnv.registerFunction("split", new TableFunc1)
+
+    val sqlQuery = "SELECT MyTable.c, t.s FROM MyTable, LATERAL TABLE(split(c)) AS t(s)"
+
+    val result = tableEnv.sql(sqlQuery).toDataSet[Row]
+    val results = result.collect()
+    val expected: String = "Jack#22,Jack\n" + "Jack#22,22\n" + "John#19,John\n" + "John#19,19\n" +
+      "Anna#44,Anna\n" + "Anna#44,44\n"
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testSQLOuterApply(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+    val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+    tableEnv.registerTable("MyTable", in)
+    tableEnv.registerFunction("split", new TableFunc2)
+
+    val sqlQuery = "SELECT MyTable.c, t.a, t.b  FROM MyTable LEFT JOIN LATERAL TABLE(split(c)) " +
+      "AS t(a,b) ON TRUE"
+
+    val result = tableEnv.sql(sqlQuery).toDataSet[Row]
+    val results = result.collect()
+    val expected: String = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" +
+      "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n" + "nosharp,null,null"
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testTableAPICrossApply(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+    val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+
+    val func1 = new TableFunc1
+    val result = in.crossApply(func1('c) as ('s)).select('c, 's).toDataSet[Row]
+    val results = result.collect()
+    val expected: String = "Jack#22,Jack\n" + "Jack#22,22\n" + "John#19,John\n" + "John#19,19\n" +
+      "Anna#44,Anna\n" + "Anna#44,44\n"
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+
+    // with overloading
+    val result2 = in.crossApply(func1('c, "$") as ('s)).select('c, 's).toDataSet[Row]
+    val results2 = result2.collect()
+    val expected2: String = "Jack#22,$Jack\n" + "Jack#22,$22\n" + "John#19,$John\n" +
+      "John#19,$19\n" + "Anna#44,$Anna\n" + "Anna#44,$44\n"
+    TestBaseUtils.compareResultAsText(results2.asJava, expected2)
+  }
+
+
+  @Test
+  def testTableAPIOuterApply(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+    val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+    val func2 = new TableFunc2
+    val result = in.outerApply(func2('c) as ('s, 'l)).select('c, 's, 'l).toDataSet[Row]
+    val results = result.collect()
+    val expected: String = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" +
+      "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n" + "nosharp,null,null"
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+
+  @Test
+  def testCustomReturnType(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+    val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+    val func2 = new TableFunc2
+
+    val result = in
+      .crossApply(func2('c) as ('name, 'len))
+      .select('c, 'name, 'len)
+      .toDataSet[Row]
+
+    val results = result.collect()
+    val expected: String = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" +
+      "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n"
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testHierarchyType(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+    val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+
+    val hierarchy = new HierarchyTableFunction
+    val result = in
+      .crossApply(hierarchy('c) as ('name, 'adult, 'len))
+      .select('c, 'name, 'adult, 'len)
+      .toDataSet[Row]
+
+    val results = result.collect()
+    val expected: String = "Jack#22,Jack,true,22\n" + "John#19,John,false,19\n" +
+      "Anna#44,Anna,true,44\n"
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testPojoType(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+    val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+
+    val pojo = new PojoTableFunc()
+    val result = in
+      .crossApply(pojo('c))
+      .select('c, 'name, 'age)
+      .toDataSet[Row]
+
+    val results = result.collect()
+    val expected: String = "Jack#22,Jack,22\n" + "John#19,John,19\n" + "Anna#44,Anna,44\n"
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+
+  @Test
+  def testTableAPIWithFilter(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+    val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+    val func0 = new TableFunc0
+
+    val result = in
+      .crossApply(func0('c) as ('name, 'age))
+      .select('c, 'name, 'age)
+      .filter('age > 20)
+      .toDataSet[Row]
+
+    val results = result.collect()
+    val expected: String = "Jack#22,Jack,22\n" + "Anna#44,Anna,44\n"
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+
+  @Test
+  def testUDTFWithScalarFunction(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+    val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+    val func1 = new TableFunc1
+
+    val result = in
+      .crossApply(func1('c.substring(2)) as 's)
+      .select('c, 's)
+      .toDataSet[Row]
+
+    val results = result.collect()
+    val expected: String = "Jack#22,ack\n" + "Jack#22,22\n" + "John#19,ohn\n" + "John#19,19\n" +
+      "Anna#44,nna\n" + "Anna#44,44\n"
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+
+  private def getSmall3TupleDataSet(env: ExecutionEnvironment): DataSet[(Int, Long, String)] = {
+    val data = new mutable.MutableList[(Int, Long, String)]
+    data.+=((1, 1L, "Jack#22"))
+    data.+=((2, 2L, "John#19"))
+    data.+=((3, 2L, "Anna#44"))
+    data.+=((4, 3L, "nosharp"))
+    env.fromCollection(data)
+  }
+}
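
The helper UDTFs used above (TableFunc0/1/2, HierarchyTableFunction, PojoTableFunc) live
in org.apache.flink.api.table.expressions.utils and are added elsewhere in this commit.
A plausible sketch of TableFunc2, consistent with the (token, length) pairs in the
expected output and the empty result for "nosharp"; this is hypothetical and the
committed class may differ:

    import org.apache.flink.api.table.functions.TableFunction

    class TableFunc2 extends TableFunction[(String, Int)] {
      // hypothetical: emit one (token, token length) row per '#'-separated token;
      // inputs without '#' emit nothing, hence the NULLs under outer apply
      def eval(str: String): Unit = {
        if (str.contains("#")) {
          str.split("#").foreach(s => collect((s, s.length)))
        }
      }
    }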

http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionTest.scala
new file mode 100644
index 0000000..7e236d1
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionTest.scala
@@ -0,0 +1,320 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.scala.batch
+
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment => ScalaExecutionEnv, _}
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.{DataSet => JDataSet, ExecutionEnvironment => JavaExecutionEnv}
+import org.apache.flink.api.table.expressions.utils.{HierarchyTableFunction, PojoTableFunc, TableFunc1, TableFunc2}
+import org.apache.flink.api.table.typeutils.RowTypeInfo
+import org.apache.flink.api.table.utils.TableTestBase
+import org.apache.flink.api.table.utils.TableTestUtil._
+import org.apache.flink.api.table.{Row, TableEnvironment, Types}
+import org.junit.Test
+import org.mockito.Mockito._
+
+
+class UserDefinedTableFunctionTest extends TableTestBase {
+
+  @Test
+  def testTableAPI(): Unit = {
+    // mock
+    val ds = mock(classOf[DataSet[Row]])
+    val jDs = mock(classOf[JDataSet[Row]])
+    val typeInfo: TypeInformation[Row] = new RowTypeInfo(Seq(Types.INT, Types.LONG, Types.STRING))
+    when(ds.javaSet).thenReturn(jDs)
+    when(jDs.getType).thenReturn(typeInfo)
+
+    // Scala environment
+    val env = mock(classOf[ScalaExecutionEnv])
+    val tableEnv = TableEnvironment.getTableEnvironment(env)
+    val in1 = ds.toTable(tableEnv).as('a, 'b, 'c)
+
+    // Java environment
+    val javaEnv = mock(classOf[JavaExecutionEnv])
+    val javaTableEnv = TableEnvironment.getTableEnvironment(javaEnv)
+    val in2 = javaTableEnv.fromDataSet(jDs).as("a, b, c")
+    javaTableEnv.registerTable("MyTable", in2)
+
+    // test cross apply
+    val func1 = new TableFunc1
+    javaTableEnv.registerFunction("func1", func1)
+    var scalaTable = in1.crossApply(func1('c) as ('s)).select('c, 's)
+    var javaTable = in2.crossApply("func1(c) as (s)").select("c, s")
+    verifyTableEquals(scalaTable, javaTable)
+
+    // test outer apply
+    scalaTable = in1.outerApply(func1('c) as ('s)).select('c, 's)
+    javaTable = in2.outerApply("func1(c) as (s)").select("c, s")
+    verifyTableEquals(scalaTable, javaTable)
+
+    // test overloading
+    scalaTable = in1.crossApply(func1('c, "$") as ('s)).select('c, 's)
+    javaTable = in2.crossApply("func1(c, '$') as (s)").select("c, s")
+    verifyTableEquals(scalaTable, javaTable)
+
+    // test custom result type
+    val func2 = new TableFunc2
+    javaTableEnv.registerFunction("func2", func2)
+    scalaTable = in1.crossApply(func2('c) as ('name, 'len)).select('c, 'name, 'len)
+    javaTable = in2.crossApply("func2(c) as (name, len)").select("c, name, len")
+    verifyTableEquals(scalaTable, javaTable)
+
+    // test hierarchy generic type
+    val hierarchy = new HierarchyTableFunction
+    javaTableEnv.registerFunction("hierarchy", hierarchy)
+    scalaTable = in1.crossApply(hierarchy('c) as ('name, 'adult, 'len))
+      .select('c, 'name, 'len, 'adult)
+    javaTable = in2.crossApply("hierarchy(c) as (name, adult, len)")
+      .select("c, name, len, adult")
+    verifyTableEquals(scalaTable, javaTable)
+
+    // test pojo type
+    val pojo = new PojoTableFunc
+    javaTableEnv.registerFunction("pojo", pojo)
+    scalaTable = in1.crossApply(pojo('c))
+      .select('c, 'name, 'age)
+    javaTable = in2.crossApply("pojo(c)")
+      .select("c, name, age")
+    verifyTableEquals(scalaTable, javaTable)
+
+    // test with filter
+    scalaTable = in1.crossApply(func2('c) as ('name, 'len))
+      .select('c, 'name, 'len).filter('len > 2)
+    javaTable = in2.crossApply("func2(c) as (name, len)")
+      .select("c, name, len").filter("len > 2")
+    verifyTableEquals(scalaTable, javaTable)
+
+    // test with scalar function
+    scalaTable = in1.crossApply(func1('c.substring(2)) as ('s))
+      .select('a, 'c, 's)
+    javaTable = in2.crossApply("func1(substring(c, 2)) as (s)")
+      .select("a, c, s")
+    verifyTableEquals(scalaTable, javaTable)
+  }
+
+  @Test
+  def testSQLWithCrossApply(): Unit = {
+    val util = batchTestUtil()
+    val func1 = new TableFunc1
+    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+    util.addFunction("func1", func1)
+
+    val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c)) AS T(s)"
+
+    val expected = unaryNode(
+      "DataSetCalc",
+      unaryNode(
+        "DataSetCorrelate",
+        batchTableNode(0),
+        term("invocation", "func1($cor0.c)"),
+        term("function", func1.getClass.getCanonicalName),
+        term("rowType",
+             "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+        term("joinType", "INNER")
+      ),
+      term("select", "c", "f0 AS s")
+    )
+
+    util.verifySql(sqlQuery, expected)
+
+    // test overloading
+
+    val sqlQuery2 = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c, '$')) AS T(s)"
+
+    val expected2 = unaryNode(
+      "DataSetCalc",
+      unaryNode(
+        "DataSetCorrelate",
+        batchTableNode(0),
+        term("invocation", "func1($cor0.c, '$')"),
+        term("function", func1.getClass.getCanonicalName),
+        term("rowType",
+             "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+        term("joinType", "INNER")
+      ),
+      term("select", "c", "f0 AS s")
+    )
+
+    util.verifySql(sqlQuery2, expected2)
+  }
+
+  @Test
+  def testSQLWithOuterApply(): Unit = {
+    val util = batchTestUtil()
+    val func1 = new TableFunc1
+    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+    util.addFunction("func1", func1)
+
+    val sqlQuery = "SELECT c, s FROM MyTable LEFT JOIN LATERAL TABLE(func1(c)) AS T(s) ON TRUE"
+
+    val expected = unaryNode(
+      "DataSetCalc",
+      unaryNode(
+        "DataSetCorrelate",
+        batchTableNode(0),
+        term("invocation", "func1($cor0.c)"),
+        term("function", func1.getClass.getCanonicalName),
+        term("rowType",
+             "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+        term("joinType", "LEFT")
+      ),
+      term("select", "c", "f0 AS s")
+    )
+
+    util.verifySql(sqlQuery, expected)
+  }
+
+  @Test
+  def testSQLWithCustomType(): Unit = {
+    val util = batchTestUtil()
+    val func2 = new TableFunc2
+    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+    util.addFunction("func2", func2)
+
+    val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len)"
+
+    val expected = unaryNode(
+      "DataSetCalc",
+      unaryNode(
+        "DataSetCorrelate",
+        batchTableNode(0),
+        term("invocation", "func2($cor0.c)"),
+        term("function", func2.getClass.getCanonicalName),
+        term("rowType",
+             "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
+               "VARCHAR(2147483647) f0, INTEGER f1)"),
+        term("joinType", "INNER")
+      ),
+      term("select", "c", "f0 AS name", "f1 AS len")
+    )
+
+    util.verifySql(sqlQuery, expected)
+  }
+
+  @Test
+  def testSQLWithHierarchyType(): Unit = {
+    val util = batchTestUtil()
+    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+    val function = new HierarchyTableFunction
+    util.addFunction("hierarchy", function)
+
+    val sqlQuery = "SELECT c, T.* FROM MyTable, LATERAL TABLE(hierarchy(c)) AS T(name, adult, len)"
+
+    val expected = unaryNode(
+      "DataSetCalc",
+      unaryNode(
+        "DataSetCorrelate",
+        batchTableNode(0),
+        term("invocation", "hierarchy($cor0.c)"),
+        term("function", function.getClass.getCanonicalName),
+        term("rowType",
+             "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," +
+               " VARCHAR(2147483647) f0, BOOLEAN f1, INTEGER f2)"),
+        term("joinType", "INNER")
+      ),
+      term("select", "c", "f0 AS name", "f1 AS adult", "f2 AS len")
+    )
+
+    util.verifySql(sqlQuery, expected)
+  }
+
+  @Test
+  def testSQLWithPojoType(): Unit = {
+    val util = batchTestUtil()
+    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+    val function = new PojoTableFunc
+    util.addFunction("pojo", function)
+
+    val sqlQuery = "SELECT c, name, age FROM MyTable, LATERAL TABLE(pojo(c))"
+
+    val expected = unaryNode(
+      "DataSetCalc",
+      unaryNode(
+        "DataSetCorrelate",
+        batchTableNode(0),
+        term("invocation", "pojo($cor0.c)"),
+        term("function", function.getClass.getCanonicalName),
+        term("rowType",
+             "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," +
+               " INTEGER age, VARCHAR(2147483647) name)"),
+        term("joinType", "INNER")
+      ),
+      term("select", "c", "name", "age")
+    )
+
+    util.verifySql(sqlQuery, expected)
+  }
+
+  @Test
+  def testSQLWithFilter(): Unit = {
+    val util = batchTestUtil()
+    val func2 = new TableFunc2
+    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+    util.addFunction("func2", func2)
+
+    val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len) " +
+      "WHERE len > 2"
+
+    val expected = unaryNode(
+      "DataSetCalc",
+      unaryNode(
+        "DataSetCorrelate",
+        batchTableNode(0),
+        term("invocation", "func2($cor0.c)"),
+        term("function", func2.getClass.getCanonicalName),
+        term("rowType",
+             "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
+               "VARCHAR(2147483647) f0, INTEGER f1)"),
+        term("joinType", "INNER"),
+        term("condition", ">($1, 2)")
+      ),
+      term("select", "c", "f0 AS name", "f1 AS len")
+    )
+
+    util.verifySql(sqlQuery, expected)
+  }
+
+
+  @Test
+  def testSQLWithScalarFunction(): Unit = {
+    val util = batchTestUtil()
+    val func1 = new TableFunc1
+    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+    util.addFunction("func1", func1)
+
+    val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(SUBSTRING(c, 2))) AS T(s)"
+
+    val expected = unaryNode(
+      "DataSetCalc",
+      unaryNode(
+        "DataSetCorrelate",
+        batchTableNode(0),
+        term("invocation", "func1(SUBSTRING($cor0.c, 2))"),
+        term("function", func1.getClass.getCanonicalName),
+        term("rowType",
+             "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+        term("joinType", "INNER")
+      ),
+      term("select", "c", "f0 AS s")
+    )
+
+    util.verifySql(sqlQuery, expected)
+  }
+}
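
The expected plans in these tests are assembled with the unaryNode/term helpers from
TableTestUtil, also added in this commit. As an assumed sketch only, they roughly
concatenate a node name with name=[values] attribute terms:

    // assumed shape of the TableTestUtil string helpers, for orientation only
    def term(name: String, values: AnyRef*): String =
      s"$name=[${values.mkString(", ")}]"

    def unaryNode(node: String, input: String, terms: String*): String =
      s"$node($input, ${terms.mkString(", ")})"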