You are viewing a plain-text version of this content. (The canonical link to it, shown as "here" in the original, was not preserved in this plain-text copy.)
Posted to commits@flink.apache.org by tw...@apache.org on 2016/12/07 15:57:22 UTC
[3/5] flink git commit: [FLINK-4469] [table] Add support for user
defined table function in Table API & SQL
[FLINK-4469] [table] Add support for user defined table function in Table API & SQL
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/e139f59c
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/e139f59c
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/e139f59c
Branch: refs/heads/master
Commit: e139f59ce97875338c5ee74bb0432389b3f343bf
Parents: c024b0b
Author: Jark Wu <wu...@alibaba-inc.com>
Authored: Tue Oct 18 11:15:07 2016 +0800
Committer: twalthr <tw...@apache.org>
Committed: Wed Dec 7 16:27:42 2016 +0100
----------------------------------------------------------------------
.../api/java/table/BatchTableEnvironment.scala | 15 +
.../api/java/table/StreamTableEnvironment.scala | 15 +
.../api/scala/table/BatchTableEnvironment.scala | 12 +
.../scala/table/StreamTableEnvironment.scala | 11 +
.../scala/table/TableFunctionCallBuilder.scala | 39 ++
.../flink/api/scala/table/expressionDsl.scala | 5 +-
.../flink/api/table/FlinkTypeFactory.scala | 14 +-
.../flink/api/table/TableEnvironment.scala | 50 ++-
.../flink/api/table/codegen/CodeGenerator.scala | 95 +++--
.../table/codegen/calls/FunctionGenerator.scala | 369 +++++++++++++++++
.../table/codegen/calls/ScalarFunctions.scala | 359 -----------------
.../codegen/calls/TableFunctionCallGen.scala | 82 ++++
.../table/expressions/ExpressionParser.scala | 4 +-
.../flink/api/table/expressions/call.scala | 83 +++-
.../api/table/expressions/fieldExpression.scala | 6 +-
.../api/table/functions/ScalarFunction.scala | 44 +-
.../api/table/functions/TableFunction.scala | 121 ++++++
.../table/functions/UserDefinedFunction.scala | 36 +-
.../functions/utils/ScalarSqlFunction.scala | 6 +-
.../functions/utils/TableSqlFunction.scala | 119 ++++++
.../utils/UserDefinedFunctionUtils.scala | 275 ++++++++++---
.../api/table/plan/ProjectionTranslator.scala | 4 +-
.../api/table/plan/logical/operators.scala | 102 ++++-
.../api/table/plan/nodes/FlinkCorrelate.scala | 162 ++++++++
.../plan/nodes/dataset/DataSetCorrelate.scala | 139 +++++++
.../nodes/datastream/DataStreamCorrelate.scala | 133 ++++++
.../api/table/plan/rules/FlinkRuleSets.scala | 2 +
.../rules/dataSet/DataSetCorrelateRule.scala | 90 +++++
.../datastream/DataStreamCorrelateRule.scala | 89 ++++
.../plan/schema/FlinkTableFunctionImpl.scala | 84 ++++
.../org/apache/flink/api/table/table.scala | 126 +++++-
.../api/table/validate/FunctionCatalog.scala | 43 +-
.../batch/UserDefinedTableFunctionITCase.scala | 212 ++++++++++
.../batch/UserDefinedTableFunctionTest.scala | 320 +++++++++++++++
.../stream/UserDefinedTableFunctionITCase.scala | 181 +++++++++
.../stream/UserDefinedTableFunctionTest.scala | 402 +++++++++++++++++++
.../UserDefinedScalarFunctionTest.scala | 4 +-
.../expressions/utils/ExpressionTestBase.scala | 4 +-
.../utils/UserDefinedTableFunctions.scala | 116 ++++++
.../flink/api/table/utils/TableTestBase.scala | 32 ++
40 files changed, 3414 insertions(+), 591 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala
index a4f40d5..b353377 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala
@@ -21,6 +21,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.TypeExtractor
import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
import org.apache.flink.api.table.expressions.ExpressionParser
+import org.apache.flink.api.table.functions.TableFunction
import org.apache.flink.api.table.{Table, TableConfig}
/**
@@ -162,4 +163,18 @@ class BatchTableEnvironment(
translate[T](table)(typeInfo)
}
+  /**
+    * Registers a [[TableFunction]] under a unique name in the TableEnvironment's catalog.
+    * Registered functions can be referenced in Table API and SQL queries.
+    *
+    * @param name The name under which the function is registered.
+    * @param tf The TableFunction to register.
+    * @tparam T The result type of the TableFunction.
+    */
+  def registerFunction[T](name: String, tf: TableFunction[T]): Unit = {
+    // Java callers cannot supply an implicit TypeInformation, so it is extracted from type
+    // parameter 0 of the TableFunction base class and handed over explicitly.
+    val extractedType = TypeExtractor
+      .createTypeInfo(tf, classOf[TableFunction[_]], tf.getClass, 0)
+      .asInstanceOf[TypeInformation[T]]
+
+    registerTableFunctionInternal[T](name, tf)(extractedType)
+  }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala
index f8dbc37..367cb82 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala
@@ -19,6 +19,7 @@ package org.apache.flink.api.java.table
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.TypeExtractor
+import org.apache.flink.api.table.functions.TableFunction
import org.apache.flink.api.table.{TableConfig, Table}
import org.apache.flink.api.table.expressions.ExpressionParser
import org.apache.flink.streaming.api.datastream.DataStream
@@ -164,4 +165,18 @@ class StreamTableEnvironment(
translate[T](table)(typeInfo)
}
+  /**
+    * Registers a [[TableFunction]] under a unique name in the TableEnvironment's catalog.
+    * Registered functions can be referenced in Table API and SQL queries.
+    *
+    * @param name The name under which the function is registered.
+    * @param tf The TableFunction to register.
+    * @tparam T The result type of the TableFunction.
+    */
+  def registerFunction[T](name: String, tf: TableFunction[T]): Unit = {
+    // Java callers cannot supply an implicit TypeInformation, so it is extracted from type
+    // parameter 0 of the TableFunction base class and handed over explicitly.
+    val extractedType = TypeExtractor
+      .createTypeInfo(tf, classOf[TableFunction[_]], tf.getClass, 0)
+      .asInstanceOf[TypeInformation[T]]
+
+    registerTableFunctionInternal[T](name, tf)(extractedType)
+  }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/BatchTableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/BatchTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/BatchTableEnvironment.scala
index adb444b..36885d2 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/BatchTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/BatchTableEnvironment.scala
@@ -20,6 +20,7 @@ package org.apache.flink.api.scala.table
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.scala._
import org.apache.flink.api.table.expressions.Expression
+import org.apache.flink.api.table.functions.TableFunction
import org.apache.flink.api.table.{TableConfig, Table}
import scala.reflect.ClassTag
@@ -139,4 +140,15 @@ class BatchTableEnvironment(
wrap[T](translate(table))(ClassTag.AnyRef.asInstanceOf[ClassTag[T]])
}
+  /**
+    * Registers a [[TableFunction]] under a unique name in the TableEnvironment's catalog.
+    * Registered functions can be referenced in Table API and SQL queries.
+    *
+    * @param name The name under which the function is registered.
+    * @param tf The TableFunction to register.
+    * @tparam T The result type of the TableFunction. The implicit [[TypeInformation]] is used
+    *           only if `tf.getResultType` returns null.
+    */
+  def registerFunction[T: TypeInformation](name: String, tf: TableFunction[T]): Unit = {
+    registerTableFunctionInternal(name, tf)
+  }
+
}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/StreamTableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/StreamTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/StreamTableEnvironment.scala
index e106178..dde69d5 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/StreamTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/StreamTableEnvironment.scala
@@ -18,6 +18,7 @@
package org.apache.flink.api.scala.table
import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.functions.TableFunction
import org.apache.flink.api.table.{TableConfig, Table}
import org.apache.flink.api.table.expressions.Expression
import org.apache.flink.streaming.api.scala.{StreamExecutionEnvironment, DataStream}
@@ -142,4 +143,14 @@ class StreamTableEnvironment(
asScalaStream(translate(table))
}
+  /**
+    * Registers a [[TableFunction]] under a unique name in the TableEnvironment's catalog.
+    * Registered functions can be referenced in Table API and SQL queries.
+    *
+    * @param name The name under which the function is registered.
+    * @param tf The TableFunction to register.
+    * @tparam T The result type of the TableFunction. The implicit [[TypeInformation]] is used
+    *           only if `tf.getResultType` returns null.
+    */
+  def registerFunction[T: TypeInformation](name: String, tf: TableFunction[T]): Unit = {
+    registerTableFunctionInternal(name, tf)
+  }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableFunctionCallBuilder.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableFunctionCallBuilder.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableFunctionCallBuilder.scala
new file mode 100644
index 0000000..2261b70
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableFunctionCallBuilder.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.scala.table
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.expressions.{Expression, TableFunctionCall}
+import org.apache.flink.api.table.functions.TableFunction
+
+/**
+  * Wraps a [[TableFunction]] so that it can be applied to Table API expressions in Scala,
+  * producing a [[TableFunctionCall]] expression.
+  *
+  * @param udtf The table function to call.
+  * @tparam T The result type of the table function.
+  */
+case class TableFunctionCallBuilder[T: TypeInformation](udtf: TableFunction[T]) {
+
+  /**
+    * Creates a call to the wrapped [[TableFunction]] in the Scala Table API.
+    *
+    * A result type declared via `udtf.getResultType` takes precedence; the implicitly
+    * provided [[TypeInformation]] is the fallback when that returns null.
+    *
+    * @param params actual parameters of the function
+    * @return a [[TableFunctionCall]] expression
+    */
+  def apply(params: Expression*): Expression = {
+    val declaredOrImplicitType =
+      Option(udtf.getResultType).getOrElse(implicitly[TypeInformation[T]])
+    TableFunctionCall(udtf.getClass.getSimpleName, udtf, params, declaredOrImplicitType)
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
index fee43d8..cc4c68d 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
@@ -24,6 +24,7 @@ import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation}
import org.apache.flink.api.table.expressions.ExpressionUtils.{toMilliInterval, toMonthInterval, toRowInterval}
import org.apache.flink.api.table.expressions.TimeIntervalUnit.TimeIntervalUnit
import org.apache.flink.api.table.expressions._
+import org.apache.flink.api.table.functions.TableFunction
import scala.language.implicitConversions
@@ -97,7 +98,7 @@ trait ImplicitExpressionOperations {
def cast(toType: TypeInformation[_]) = Cast(expr, toType)
- def as(name: Symbol) = Alias(expr, name.name)
+  /**
+    * Specifies a name for the expression. Any extra names are handed to the [[Alias]] as well
+    * (presumably for expressions with more than one result field, such as table function
+    * calls — confirm against Alias).
+    */
+  def as(name: Symbol, extraNames: Symbol*) =
+    Alias(expr, name.name, extraNames.map(symbol => symbol.name))
def asc = Asc(expr)
def desc = Desc(expr)
@@ -539,6 +540,8 @@ trait ImplicitExpressionConversions {
implicit def sqlDate2Literal(sqlDate: Date): Expression = Literal(sqlDate)
implicit def sqlTime2Literal(sqlTime: Time): Expression = Literal(sqlTime)
implicit def sqlTimestamp2Literal(sqlTimestamp: Timestamp): Expression = Literal(sqlTimestamp)
+  /**
+    * Wraps a [[TableFunction]] in a [[TableFunctionCallBuilder]] so that the function can be
+    * applied to expressions directly in the Scala Table API.
+    */
+  implicit def UDTF2TableFunctionCall[T: TypeInformation](
+      udtf: TableFunction[T]): TableFunctionCallBuilder[T] =
+    new TableFunctionCallBuilder[T](udtf)
}
// ------------------------------------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala
index 12dace4..bb11576 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala
@@ -26,7 +26,7 @@ import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.calcite.sql.`type`.SqlTypeName._
import org.apache.calcite.sql.parser.SqlParserPos
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
-import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation}
+import org.apache.flink.api.common.typeinfo.{NothingTypeInfo, SqlTimeTypeInfo, TypeInformation}
import org.apache.flink.api.common.typeutils.CompositeType
import org.apache.flink.api.java.typeutils.ValueTypeInfo._
import org.apache.flink.api.table.FlinkTypeFactory.typeInfoToSqlTypeName
@@ -115,9 +115,9 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp
}
override def createTypeWithNullability(
- relDataType: RelDataType,
- nullable: Boolean)
- : RelDataType = relDataType match {
+ relDataType: RelDataType,
+ nullable: Boolean)
+ : RelDataType = relDataType match {
case composite: CompositeRelDataType =>
// at the moment we do not care about nullability
composite
@@ -172,8 +172,7 @@ object FlinkTypeFactory {
case typeName if DAY_INTERVAL_TYPES.contains(typeName) => TimeIntervalTypeInfo.INTERVAL_MILLIS
case NULL =>
- throw TableException("Type NULL is not supported. " +
- "Null values must have a supported type.")
+ throw TableException("Type NULL is not supported. Null values must have a supported type.")
// symbol for special flags e.g. TRIM's BOTH, LEADING, TRAILING
// are represented as integer
@@ -188,6 +187,9 @@ object FlinkTypeFactory {
val compositeRelDataType = relDataType.asInstanceOf[CompositeRelDataType]
compositeRelDataType.compositeType
+ // ROW and CURSOR for UDTF case, whose type info will never be used, just a placeholder
+ case ROW | CURSOR => new NothingTypeInfo
+
case _@t =>
throw TableException(s"Type is not supported: $t")
}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala
index 7b2b738..8cabadb 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala
@@ -40,7 +40,8 @@ import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
import org.apache.flink.api.scala.{ExecutionEnvironment => ScalaBatchExecEnv}
import org.apache.flink.api.table.codegen.ExpressionReducer
import org.apache.flink.api.table.expressions.{Alias, Expression, UnresolvedFieldReference}
-import org.apache.flink.api.table.functions.{ScalarFunction, UserDefinedFunction}
+import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils.{checkForInstantiation, checkNotSingleton, createTableSqlFunctions, createScalarSqlFunction}
+import org.apache.flink.api.table.functions.{TableFunction, ScalarFunction}
import org.apache.flink.api.table.plan.cost.DataSetCostFactory
import org.apache.flink.api.table.plan.schema.RelTable
import org.apache.flink.api.table.sinks.TableSink
@@ -153,21 +154,42 @@ abstract class TableEnvironment(val config: TableConfig) {
protected def getBuiltInRuleSet: RuleSet
/**
- * Registers a [[UserDefinedFunction]] under a unique name. Replaces already existing
+ * Registers a [[ScalarFunction]] under a unique name. Replaces already existing
* user-defined functions under this name.
+ *
+ * @param name The name under which the function is registered.
+ * @param function The ScalarFunction to register; its class is checked with
+ * checkForInstantiation before registration.
*/
- def registerFunction(name: String, function: UserDefinedFunction): Unit = {
- function match {
- case sf: ScalarFunction =>
- // register in Table API
- functionCatalog.registerFunction(name, function.getClass)
+ def registerFunction(name: String, function: ScalarFunction): Unit = {
+ // check could be instantiated
+ checkForInstantiation(function.getClass)
- // register in SQL API
- functionCatalog.registerSqlFunction(sf.getSqlFunction(name, typeFactory))
+ // register in Table API
+ functionCatalog.registerFunction(name, function.getClass)
- case _ =>
- throw new TableException("Unsupported user-defined function type.")
+ // register in SQL API
+ // the SQL-side function object is built from the instance via createScalarSqlFunction
+ functionCatalog.registerSqlFunction(createScalarSqlFunction(name, function, typeFactory))
+ }
+
+ /**
+ * Registers a [[TableFunction]] under a unique name. Replaces already existing
+ * user-defined functions under this name.
+ *
+ * @param name The name under which the function is registered.
+ * @param function The TableFunction to register; it must not be a Scala object and its
+ * class must pass checkForInstantiation.
+ * @tparam T The result type of the function. The implicit TypeInformation[T] is only used
+ * when function.getResultType returns null.
+ */
+ private[flink] def registerTableFunctionInternal[T: TypeInformation](
+ name: String, function: TableFunction[T]): Unit = {
+ // check not Scala object
+ checkNotSingleton(function.getClass)
+ // check could be instantiated
+ checkForInstantiation(function.getClass)
+
+ // an explicitly declared result type takes precedence over the implicit type information
+ val typeInfo: TypeInformation[_] = if (function.getResultType != null) {
+ function.getResultType
+ } else {
+ implicitly[TypeInformation[T]]
}
+
+ // register in Table API
+ functionCatalog.registerFunction(name, function.getClass)
+ // register in SQL API
+ // NOTE(review): createTableSqlFunctions yields several SQL functions for one name — confirm why
+ val sqlFunctions = createTableSqlFunctions(name, function, typeInfo, typeFactory)
+ functionCatalog.registerSqlFunctions(sqlFunctions)
}
/**
@@ -364,7 +386,7 @@ abstract class TableEnvironment(val config: TableConfig) {
case t: TupleTypeInfo[A] =>
exprs.zipWithIndex.map {
case (UnresolvedFieldReference(name), idx) => (idx, name)
- case (Alias(UnresolvedFieldReference(origName), name), _) =>
+ case (Alias(UnresolvedFieldReference(origName), name, _), _) =>
val idx = t.getFieldIndex(origName)
if (idx < 0) {
throw new TableException(s"$origName is not a field of type $t")
@@ -376,7 +398,7 @@ abstract class TableEnvironment(val config: TableConfig) {
case c: CaseClassTypeInfo[A] =>
exprs.zipWithIndex.map {
case (UnresolvedFieldReference(name), idx) => (idx, name)
- case (Alias(UnresolvedFieldReference(origName), name), _) =>
+ case (Alias(UnresolvedFieldReference(origName), name, _), _) =>
val idx = c.getFieldIndex(origName)
if (idx < 0) {
throw new TableException(s"$origName is not a field of type $c")
@@ -393,7 +415,7 @@ abstract class TableEnvironment(val config: TableConfig) {
throw new TableException(s"$name is not a field of type $p")
}
(idx, name)
- case Alias(UnresolvedFieldReference(origName), name) =>
+ case Alias(UnresolvedFieldReference(origName), name, _) =>
val idx = p.getFieldIndex(origName)
if (idx < 0) {
throw new TableException(s"$origName is not a field of type $p")
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
index 2a8ef44..9e4f569 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
@@ -33,7 +33,7 @@ import org.apache.flink.api.java.typeutils.{GenericTypeInfo, PojoTypeInfo, Tuple
import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
import org.apache.flink.api.table.codegen.CodeGenUtils._
import org.apache.flink.api.table.codegen.Indenter.toISC
-import org.apache.flink.api.table.codegen.calls.ScalarFunctions
+import org.apache.flink.api.table.codegen.calls.FunctionGenerator
import org.apache.flink.api.table.codegen.calls.ScalarOperators._
import org.apache.flink.api.table.functions.UserDefinedFunction
import org.apache.flink.api.table.typeutils.{RowTypeInfo, TypeConverter}
@@ -50,16 +50,19 @@ import scala.collection.mutable
* @param nullableInput input(s) can be null.
* @param input1 type information about the first input of the Function
* @param input2 type information about the second input if the Function is binary
- * @param inputPojoFieldMapping additional mapping information if input1 is a POJO (POJO types
- * have no deterministic field order). We assume that input2 is
- * converted before and thus is never a POJO.
+ * @param input1PojoFieldMapping additional mapping information if input1 is a POJO (POJO types
+ * have no deterministic field order).
+ * @param input2PojoFieldMapping additional mapping information if input2 is a POJO (POJO types
+ * have no deterministic field order).
+ *
*/
class CodeGenerator(
config: TableConfig,
nullableInput: Boolean,
input1: TypeInformation[Any],
input2: Option[TypeInformation[Any]] = None,
- inputPojoFieldMapping: Option[Array[Int]] = None)
+ input1PojoFieldMapping: Option[Array[Int]] = None,
+ input2PojoFieldMapping: Option[Array[Int]] = None)
extends RexVisitor[GeneratedExpression] {
// check if nullCheck is enabled when inputs can be null
@@ -67,18 +70,19 @@ class CodeGenerator(
throw new CodeGenException("Null check must be enabled if entire rows can be null.")
}
- // check for POJO input mapping
+ // check for POJO input1 mapping
input1 match {
case pt: PojoTypeInfo[_] =>
- inputPojoFieldMapping.getOrElse(
- throw new CodeGenException("No input mapping is specified for input of type POJO."))
+ input1PojoFieldMapping.getOrElse(
+ throw new CodeGenException("No input mapping is specified for input1 of type POJO."))
case _ => // ok
}
- // check that input2 is never a POJO
+ // check for POJO input2 mapping
input2 match {
case Some(pt: PojoTypeInfo[_]) =>
- throw new CodeGenException("Second input must not be a POJO type.")
+ input2PojoFieldMapping.getOrElse(
+ throw new CodeGenException("No input mapping is specified for input2 of type POJO."))
case _ => // ok
}
@@ -156,22 +160,22 @@ class CodeGenerator(
/**
* @return term of the (casted and possibly boxed) first input
+ * Declared as a var so that the term name can be replaced (presumably by correlate code
+ * generation — confirm which callers reassign it).
*/
- def input1Term = "in1"
+  var input1Term: String = "in1"
/**
* @return term of the (casted and possibly boxed) second input
+ * Declared as a var so that the term name can be replaced (presumably by correlate code
+ * generation — confirm which callers reassign it).
*/
- def input2Term = "in2"
+  var input2Term: String = "in2"
/**
* @return term of the (casted) output collector
+ * Declared as a var so that the term name can be replaced.
*/
- def collectorTerm = "c"
+  var collectorTerm: String = "c"
/**
* @return term of the output record (possibly defined in the member area e.g. Row, Tuple)
+ * Declared as a var so that the term name can be replaced.
*/
- def outRecordTerm = "out"
+  var outRecordTerm: String = "out"
/**
* @return returns if null checking is enabled
@@ -334,11 +338,11 @@ class CodeGenerator(
resultFieldNames: Seq[String])
: GeneratedExpression = {
val input1AccessExprs = for (i <- 0 until input1.getArity)
- yield generateInputAccess(input1, input1Term, i)
+ yield generateInputAccess(input1, input1Term, i, input1PojoFieldMapping)
val input2AccessExprs = input2 match {
case Some(ti) => for (i <- 0 until ti.getArity)
- yield generateInputAccess(ti, input2Term, i)
+ yield generateInputAccess(ti, input2Term, i, input2PojoFieldMapping)
case None => Seq() // add nothing
}
@@ -346,6 +350,23 @@ class CodeGenerator(
}
/**
+    * Generates the field access expressions needed for a correlate: one expression per field
+    * of the left input (input1) and one per field of the table function's result (input2).
+    *
+    * @return a pair of (left input field accesses, table function result field accesses)
+    * @throws CodeGenException if no type information for input2 is given
+    */
+  def generateCorrelateAccessExprs: (Seq[GeneratedExpression], Seq[GeneratedExpression]) = {
+    val input1AccessExprs = for (i <- 0 until input1.getArity)
+      yield generateInputAccess(input1, input1Term, i, input1PojoFieldMapping)
+
+    val input2AccessExprs = input2 match {
+      // Use generateFieldAccess instead of generateInputAccess: the latter caches the access
+      // code as reusable input unboxing, which would place the table function's field access
+      // at the top of the function body instead of inside the loop over its results.
+      case Some(ti) => for (i <- 0 until ti.getArity)
+        yield generateFieldAccess(ti, input2Term, i, input2PojoFieldMapping)
+      case None => throw new CodeGenException("type information of input2 must not be null")
+    }
+    (input1AccessExprs, input2AccessExprs)
+  }
+
+ /**
* Generates an expression from a sequence of RexNode. If objects or variables can be reused,
* they will be added to reusable code sections internally. The evaluation result
* may be stored in the global result variable (see [[outRecordTerm]]).
@@ -594,9 +615,11 @@ class CodeGenerator(
override def visitInputRef(inputRef: RexInputRef): GeneratedExpression = {
// if inputRef index is within size of input1 we work with input1, input2 otherwise
val input = if (inputRef.getIndex < input1.getArity) {
- (input1, input1Term)
+ (input1, input1Term, input1PojoFieldMapping)
} else {
- (input2.getOrElse(throw new CodeGenException("Invalid input access.")), input2Term)
+ (input2.getOrElse(throw new CodeGenException("Invalid input access.")),
+ input2Term,
+ input2PojoFieldMapping)
}
val index = if (input._2 == input1Term) {
@@ -605,13 +628,17 @@ class CodeGenerator(
inputRef.getIndex - input1.getArity
}
- generateInputAccess(input._1, input._2, index)
+ generateInputAccess(input._1, input._2, index, input._3)
}
override def visitFieldAccess(rexFieldAccess: RexFieldAccess): GeneratedExpression = {
val refExpr = rexFieldAccess.getReferenceExpr.accept(this)
val index = rexFieldAccess.getField.getIndex
- val fieldAccessExpr = generateFieldAccess(refExpr.resultType, refExpr.resultTerm, index)
+ val fieldAccessExpr = generateFieldAccess(
+ refExpr.resultType,
+ refExpr.resultTerm,
+ index,
+ input1PojoFieldMapping)
val resultTerm = newName("result")
val nullTerm = newName("isNull")
@@ -753,8 +780,9 @@ class CodeGenerator(
}
}
- override def visitCorrelVariable(correlVariable: RexCorrelVariable): GeneratedExpression =
- throw new CodeGenException("Correlating variables are not supported yet.")
+  /**
+    * Translates a correlation variable to the first input (input1), i.e. the left input when
+    * generating code for a correlate. The resulting expression reuses input1Term and is
+    * flagged as never null.
+    */
+  override def visitCorrelVariable(correlVariable: RexCorrelVariable): GeneratedExpression =
+    GeneratedExpression(input1Term, GeneratedExpression.NEVER_NULL, "", input1)
override def visitLocalRef(localRef: RexLocalRef): GeneratedExpression =
throw new CodeGenException("Local variables are not supported yet.")
@@ -948,7 +976,7 @@ class CodeGenerator(
// advanced scalar functions
case sqlOperator: SqlOperator =>
- val callGen = ScalarFunctions.getCallGenerator(
+ val callGen = FunctionGenerator.getCallGenerator(
sqlOperator,
operands.map(_.resultType),
resultType)
@@ -977,7 +1005,8 @@ class CodeGenerator(
private def generateInputAccess(
inputType: TypeInformation[Any],
inputTerm: String,
- index: Int)
+ index: Int,
+ pojoFieldMapping: Option[Array[Int]])
: GeneratedExpression = {
// if input has been used before, we can reuse the code that
// has already been generated
@@ -989,10 +1018,10 @@ class CodeGenerator(
// generate input access and unboxing if necessary
case None =>
val expr = if (nullableInput) {
- generateNullableInputFieldAccess(inputType, inputTerm, index)
+ generateNullableInputFieldAccess(inputType, inputTerm, index, pojoFieldMapping)
}
else {
- generateFieldAccess(inputType, inputTerm, index)
+ generateFieldAccess(inputType, inputTerm, index, pojoFieldMapping)
}
reusableInputUnboxingExprs((inputTerm, index)) = expr
@@ -1005,7 +1034,8 @@ class CodeGenerator(
private def generateNullableInputFieldAccess(
inputType: TypeInformation[Any],
inputTerm: String,
- index: Int)
+ index: Int,
+ pojoFieldMapping: Option[Array[Int]])
: GeneratedExpression = {
val resultTerm = newName("result")
val nullTerm = newName("isNull")
@@ -1013,7 +1043,7 @@ class CodeGenerator(
val fieldType = inputType match {
case ct: CompositeType[_] =>
val fieldIndex = if (ct.isInstanceOf[PojoTypeInfo[_]]) {
- inputPojoFieldMapping.get(index)
+ pojoFieldMapping.get(index)
}
else {
index
@@ -1024,7 +1054,7 @@ class CodeGenerator(
}
val resultTypeTerm = primitiveTypeTermForTypeInfo(fieldType)
val defaultValue = primitiveDefaultValue(fieldType)
- val fieldAccessExpr = generateFieldAccess(inputType, inputTerm, index)
+ val fieldAccessExpr = generateFieldAccess(inputType, inputTerm, index, pojoFieldMapping)
val inputCheckCode =
s"""
@@ -1047,12 +1077,13 @@ class CodeGenerator(
private def generateFieldAccess(
inputType: TypeInformation[_],
inputTerm: String,
- index: Int)
+ index: Int,
+ pojoFieldMapping: Option[Array[Int]])
: GeneratedExpression = {
inputType match {
case ct: CompositeType[_] =>
- val fieldIndex = if (ct.isInstanceOf[PojoTypeInfo[_]] && inputPojoFieldMapping.nonEmpty) {
- inputPojoFieldMapping.get(index)
+ val fieldIndex = if (ct.isInstanceOf[PojoTypeInfo[_]] && pojoFieldMapping.nonEmpty) {
+ pojoFieldMapping.get(index)
}
else {
index
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/FunctionGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/FunctionGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/FunctionGenerator.scala
new file mode 100644
index 0000000..9b144ba
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/FunctionGenerator.scala
@@ -0,0 +1,369 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.codegen.calls
+
+import java.lang.reflect.Method
+
+import org.apache.calcite.avatica.util.TimeUnitRange
+import org.apache.calcite.sql.SqlOperator
+import org.apache.calcite.sql.fun.SqlStdOperatorTable._
+import org.apache.calcite.sql.fun.SqlTrimFunction
+import org.apache.calcite.util.BuiltInMethod
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, SqlTimeTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils.GenericTypeInfo
+import org.apache.flink.api.table.functions.utils.{TableSqlFunction, ScalarSqlFunction}
+
+import scala.collection.mutable
+
+/**
+ * Global hub mapping SQL operators to [[CallGenerator]]s for built-in and user-defined functions.
+ */
+object FunctionGenerator {
+
+ private val sqlFunctions: mutable.Map[(SqlOperator, Seq[TypeInformation[_]]), CallGenerator] =
+ mutable.Map() // keyed by (operator, exact operand types); filled by the add* calls below
+
+ // ----------------------------------------------------------------------------------------------
+ // String functions
+ // ----------------------------------------------------------------------------------------------
+
+ addSqlFunctionMethod(
+ SUBSTRING,
+ Seq(STRING_TYPE_INFO, INT_TYPE_INFO, INT_TYPE_INFO),
+ STRING_TYPE_INFO,
+ BuiltInMethod.SUBSTRING.method)
+
+ addSqlFunctionMethod(
+ SUBSTRING,
+ Seq(STRING_TYPE_INFO, INT_TYPE_INFO),
+ STRING_TYPE_INFO,
+ BuiltInMethod.SUBSTRING.method)
+
+ addSqlFunction(
+ TRIM,
+ Seq(new GenericTypeInfo(classOf[SqlTrimFunction.Flag]), STRING_TYPE_INFO, STRING_TYPE_INFO),
+ new TrimCallGen())
+
+ addSqlFunctionMethod(
+ CHAR_LENGTH,
+ Seq(STRING_TYPE_INFO),
+ INT_TYPE_INFO,
+ BuiltInMethod.CHAR_LENGTH.method)
+
+ addSqlFunctionMethod(
+ CHARACTER_LENGTH,
+ Seq(STRING_TYPE_INFO),
+ INT_TYPE_INFO,
+ BuiltInMethod.CHAR_LENGTH.method)
+
+ addSqlFunctionMethod(
+ UPPER,
+ Seq(STRING_TYPE_INFO),
+ STRING_TYPE_INFO,
+ BuiltInMethod.UPPER.method)
+
+ addSqlFunctionMethod(
+ LOWER,
+ Seq(STRING_TYPE_INFO),
+ STRING_TYPE_INFO,
+ BuiltInMethod.LOWER.method)
+
+ addSqlFunctionMethod(
+ INITCAP,
+ Seq(STRING_TYPE_INFO),
+ STRING_TYPE_INFO,
+ BuiltInMethod.INITCAP.method)
+
+ addSqlFunctionMethod(
+ LIKE,
+ Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
+ BOOLEAN_TYPE_INFO,
+ BuiltInMethod.LIKE.method)
+
+ addSqlFunctionMethod(
+ LIKE,
+ Seq(STRING_TYPE_INFO, STRING_TYPE_INFO, STRING_TYPE_INFO),
+ BOOLEAN_TYPE_INFO,
+ BuiltInMethods.LIKE_WITH_ESCAPE)
+
+ addSqlFunctionNotMethod(
+ NOT_LIKE,
+ Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
+ BuiltInMethod.LIKE.method)
+
+ addSqlFunctionMethod(
+ SIMILAR_TO,
+ Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
+ BOOLEAN_TYPE_INFO,
+ BuiltInMethod.SIMILAR.method)
+
+ addSqlFunctionMethod(
+ SIMILAR_TO,
+ Seq(STRING_TYPE_INFO, STRING_TYPE_INFO, STRING_TYPE_INFO),
+ BOOLEAN_TYPE_INFO,
+ BuiltInMethods.SIMILAR_WITH_ESCAPE)
+
+ addSqlFunctionNotMethod(
+ NOT_SIMILAR_TO,
+ Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
+ BuiltInMethod.SIMILAR.method)
+
+ addSqlFunctionMethod(
+ POSITION,
+ Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
+ INT_TYPE_INFO,
+ BuiltInMethod.POSITION.method)
+
+ addSqlFunctionMethod(
+ OVERLAY,
+ Seq(STRING_TYPE_INFO, STRING_TYPE_INFO, INT_TYPE_INFO),
+ STRING_TYPE_INFO,
+ BuiltInMethod.OVERLAY.method)
+
+ addSqlFunctionMethod(
+ OVERLAY,
+ Seq(STRING_TYPE_INFO, STRING_TYPE_INFO, INT_TYPE_INFO, INT_TYPE_INFO),
+ STRING_TYPE_INFO,
+ BuiltInMethod.OVERLAY.method)
+
+ // ----------------------------------------------------------------------------------------------
+ // Arithmetic functions
+ // ----------------------------------------------------------------------------------------------
+
+ addSqlFunctionMethod(
+ LOG10,
+ Seq(DOUBLE_TYPE_INFO),
+ DOUBLE_TYPE_INFO,
+ BuiltInMethods.LOG10)
+
+ addSqlFunctionMethod(
+ LN,
+ Seq(DOUBLE_TYPE_INFO),
+ DOUBLE_TYPE_INFO,
+ BuiltInMethods.LN)
+
+ addSqlFunctionMethod(
+ EXP,
+ Seq(DOUBLE_TYPE_INFO),
+ DOUBLE_TYPE_INFO,
+ BuiltInMethods.EXP)
+
+ addSqlFunctionMethod(
+ POWER,
+ Seq(DOUBLE_TYPE_INFO, DOUBLE_TYPE_INFO),
+ DOUBLE_TYPE_INFO,
+ BuiltInMethods.POWER)
+
+ addSqlFunctionMethod(
+ POWER,
+ Seq(DOUBLE_TYPE_INFO, BIG_DEC_TYPE_INFO),
+ DOUBLE_TYPE_INFO,
+ BuiltInMethods.POWER_DEC)
+
+ addSqlFunction(
+ ABS,
+ Seq(DOUBLE_TYPE_INFO),
+ new MultiTypeMethodCallGen(BuiltInMethods.ABS))
+
+ addSqlFunction(
+ ABS,
+ Seq(BIG_DEC_TYPE_INFO),
+ new MultiTypeMethodCallGen(BuiltInMethods.ABS_DEC))
+
+ addSqlFunction(
+ FLOOR,
+ Seq(DOUBLE_TYPE_INFO),
+ new FloorCeilCallGen(BuiltInMethod.FLOOR.method))
+
+ addSqlFunction(
+ FLOOR,
+ Seq(BIG_DEC_TYPE_INFO),
+ new FloorCeilCallGen(BuiltInMethod.FLOOR.method))
+
+ addSqlFunction(
+ CEIL,
+ Seq(DOUBLE_TYPE_INFO),
+ new FloorCeilCallGen(BuiltInMethod.CEIL.method))
+
+ addSqlFunction(
+ CEIL,
+ Seq(BIG_DEC_TYPE_INFO),
+ new FloorCeilCallGen(BuiltInMethod.CEIL.method))
+
+ // ----------------------------------------------------------------------------------------------
+ // Temporal functions
+ // ----------------------------------------------------------------------------------------------
+
+ addSqlFunctionMethod(
+ EXTRACT_DATE,
+ Seq(new GenericTypeInfo(classOf[TimeUnitRange]), LONG_TYPE_INFO),
+ LONG_TYPE_INFO,
+ BuiltInMethod.UNIX_DATE_EXTRACT.method)
+
+ addSqlFunctionMethod(
+ EXTRACT_DATE,
+ Seq(new GenericTypeInfo(classOf[TimeUnitRange]), SqlTimeTypeInfo.DATE),
+ LONG_TYPE_INFO,
+ BuiltInMethod.UNIX_DATE_EXTRACT.method)
+
+ addSqlFunction(
+ FLOOR,
+ Seq(SqlTimeTypeInfo.DATE, new GenericTypeInfo(classOf[TimeUnitRange])),
+ new FloorCeilCallGen(
+ BuiltInMethod.FLOOR.method,
+ Some(BuiltInMethod.UNIX_DATE_FLOOR.method)))
+
+ addSqlFunction(
+ FLOOR,
+ Seq(SqlTimeTypeInfo.TIME, new GenericTypeInfo(classOf[TimeUnitRange])),
+ new FloorCeilCallGen(
+ BuiltInMethod.FLOOR.method,
+ Some(BuiltInMethod.UNIX_DATE_FLOOR.method)))
+
+ addSqlFunction(
+ FLOOR,
+ Seq(SqlTimeTypeInfo.TIMESTAMP, new GenericTypeInfo(classOf[TimeUnitRange])),
+ new FloorCeilCallGen(
+ BuiltInMethod.FLOOR.method,
+ Some(BuiltInMethod.UNIX_TIMESTAMP_FLOOR.method)))
+
+ addSqlFunction(
+ CEIL,
+ Seq(SqlTimeTypeInfo.DATE, new GenericTypeInfo(classOf[TimeUnitRange])),
+ new FloorCeilCallGen(
+ BuiltInMethod.CEIL.method,
+ Some(BuiltInMethod.UNIX_DATE_CEIL.method)))
+
+ addSqlFunction(
+ CEIL,
+ Seq(SqlTimeTypeInfo.TIME, new GenericTypeInfo(classOf[TimeUnitRange])),
+ new FloorCeilCallGen(
+ BuiltInMethod.CEIL.method,
+ Some(BuiltInMethod.UNIX_DATE_CEIL.method)))
+
+ addSqlFunction(
+ CEIL,
+ Seq(SqlTimeTypeInfo.TIMESTAMP, new GenericTypeInfo(classOf[TimeUnitRange])),
+ new FloorCeilCallGen(
+ BuiltInMethod.CEIL.method,
+ Some(BuiltInMethod.UNIX_TIMESTAMP_CEIL.method)))
+
+ addSqlFunction(
+ CURRENT_DATE,
+ Seq(),
+ new CurrentTimePointCallGen(SqlTimeTypeInfo.DATE, local = false))
+
+ addSqlFunction(
+ CURRENT_TIME,
+ Seq(),
+ new CurrentTimePointCallGen(SqlTimeTypeInfo.TIME, local = false))
+
+ addSqlFunction(
+ CURRENT_TIMESTAMP,
+ Seq(),
+ new CurrentTimePointCallGen(SqlTimeTypeInfo.TIMESTAMP, local = false))
+
+ addSqlFunction(
+ LOCALTIME,
+ Seq(),
+ new CurrentTimePointCallGen(SqlTimeTypeInfo.TIME, local = true))
+
+ addSqlFunction(
+ LOCALTIMESTAMP,
+ Seq(),
+ new CurrentTimePointCallGen(SqlTimeTypeInfo.TIMESTAMP, local = true))
+
+ // ----------------------------------------------------------------------------------------------
+
+ /**
+ * Returns a [[CallGenerator]] that generates all required code for calling the given
+ * [[SqlOperator]]. User-defined scalar and table functions always yield a generator.
+ *
+ * @param sqlOperator SQL operator (might be overloaded)
+ * @param operandTypes actual operand types
+ * @param resultType expected return type
+ * @return [[CallGenerator]], or None if no registered built-in signature matches
+ */
+ def getCallGenerator(
+ sqlOperator: SqlOperator,
+ operandTypes: Seq[TypeInformation[_]],
+ resultType: TypeInformation[_])
+ : Option[CallGenerator] = sqlOperator match {
+
+ // user-defined scalar function
+ case ssf: ScalarSqlFunction =>
+ Some(
+ new ScalarFunctionCallGen(
+ ssf.getScalarFunction,
+ operandTypes,
+ resultType
+ )
+ )
+
+ // user-defined table function; the eval overload is resolved later in TableFunctionCallGen
+ case tsf: TableSqlFunction =>
+ Some(
+ new TableFunctionCallGen(
+ tsf.getTableFunction,
+ operandTypes,
+ resultType
+ )
+ )
+
+ // built-in function: exact signature first, else a registered one the operands auto-cast to
+ case _ =>
+ sqlFunctions.get((sqlOperator, operandTypes))
+ .orElse(sqlFunctions.find(entry => entry._1._1 == sqlOperator
+ && entry._1._2.length == operandTypes.length
+ && entry._1._2.zip(operandTypes).forall {
+ case (x: BasicTypeInfo[_], y: BasicTypeInfo[_]) => y.shouldAutocastTo(x) || x == y
+ case _ => false // only BasicTypeInfo operands take part in implicit casting
+ }).map(_._2)) // NOTE(review): if several signatures match, map iteration order decides
+ }
+
+ // ----------------------------------------------------------------------------------------------
+
+ private def addSqlFunctionMethod(
+ sqlOperator: SqlOperator,
+ operandTypes: Seq[TypeInformation[_]],
+ returnType: TypeInformation[_],
+ method: Method)
+ : Unit = { // registers `method` (producing returnType) for this exact signature
+ sqlFunctions((sqlOperator, operandTypes)) = new MethodCallGen(returnType, method)
+ }
+
+ private def addSqlFunctionNotMethod(
+ sqlOperator: SqlOperator,
+ operandTypes: Seq[TypeInformation[_]],
+ method: Method)
+ : Unit = { // e.g. NOT_LIKE reuses LIKE's boolean method wrapped in a NotCallGenerator
+ sqlFunctions((sqlOperator, operandTypes)) =
+ new NotCallGenerator(new MethodCallGen(BOOLEAN_TYPE_INFO, method))
+ }
+
+ private def addSqlFunction(
+ sqlOperator: SqlOperator,
+ operandTypes: Seq[TypeInformation[_]],
+ callGenerator: CallGenerator)
+ : Unit = { // registers a custom generator; a later call for the same key overwrites it
+ sqlFunctions((sqlOperator, operandTypes)) = callGenerator
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctions.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctions.scala
deleted file mode 100644
index e7c436a..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctions.scala
+++ /dev/null
@@ -1,359 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.api.table.codegen.calls
-
-import java.lang.reflect.Method
-
-import org.apache.calcite.avatica.util.TimeUnitRange
-import org.apache.calcite.sql.SqlOperator
-import org.apache.calcite.sql.fun.SqlStdOperatorTable._
-import org.apache.calcite.sql.fun.SqlTrimFunction
-import org.apache.calcite.util.BuiltInMethod
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
-import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, SqlTimeTypeInfo, TypeInformation}
-import org.apache.flink.api.java.typeutils.GenericTypeInfo
-import org.apache.flink.api.table.functions.utils.ScalarSqlFunction
-
-import scala.collection.mutable
-
-/**
- * Global hub for user-defined and built-in advanced SQL scalar functions.
- */
-object ScalarFunctions {
-
- private val sqlFunctions: mutable.Map[(SqlOperator, Seq[TypeInformation[_]]), CallGenerator] =
- mutable.Map()
-
- // ----------------------------------------------------------------------------------------------
- // String functions
- // ----------------------------------------------------------------------------------------------
-
- addSqlFunctionMethod(
- SUBSTRING,
- Seq(STRING_TYPE_INFO, INT_TYPE_INFO, INT_TYPE_INFO),
- STRING_TYPE_INFO,
- BuiltInMethod.SUBSTRING.method)
-
- addSqlFunctionMethod(
- SUBSTRING,
- Seq(STRING_TYPE_INFO, INT_TYPE_INFO),
- STRING_TYPE_INFO,
- BuiltInMethod.SUBSTRING.method)
-
- addSqlFunction(
- TRIM,
- Seq(new GenericTypeInfo(classOf[SqlTrimFunction.Flag]), STRING_TYPE_INFO, STRING_TYPE_INFO),
- new TrimCallGen())
-
- addSqlFunctionMethod(
- CHAR_LENGTH,
- Seq(STRING_TYPE_INFO),
- INT_TYPE_INFO,
- BuiltInMethod.CHAR_LENGTH.method)
-
- addSqlFunctionMethod(
- CHARACTER_LENGTH,
- Seq(STRING_TYPE_INFO),
- INT_TYPE_INFO,
- BuiltInMethod.CHAR_LENGTH.method)
-
- addSqlFunctionMethod(
- UPPER,
- Seq(STRING_TYPE_INFO),
- STRING_TYPE_INFO,
- BuiltInMethod.UPPER.method)
-
- addSqlFunctionMethod(
- LOWER,
- Seq(STRING_TYPE_INFO),
- STRING_TYPE_INFO,
- BuiltInMethod.LOWER.method)
-
- addSqlFunctionMethod(
- INITCAP,
- Seq(STRING_TYPE_INFO),
- STRING_TYPE_INFO,
- BuiltInMethod.INITCAP.method)
-
- addSqlFunctionMethod(
- LIKE,
- Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
- BOOLEAN_TYPE_INFO,
- BuiltInMethod.LIKE.method)
-
- addSqlFunctionMethod(
- LIKE,
- Seq(STRING_TYPE_INFO, STRING_TYPE_INFO, STRING_TYPE_INFO),
- BOOLEAN_TYPE_INFO,
- BuiltInMethods.LIKE_WITH_ESCAPE)
-
- addSqlFunctionNotMethod(
- NOT_LIKE,
- Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
- BuiltInMethod.LIKE.method)
-
- addSqlFunctionMethod(
- SIMILAR_TO,
- Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
- BOOLEAN_TYPE_INFO,
- BuiltInMethod.SIMILAR.method)
-
- addSqlFunctionMethod(
- SIMILAR_TO,
- Seq(STRING_TYPE_INFO, STRING_TYPE_INFO, STRING_TYPE_INFO),
- BOOLEAN_TYPE_INFO,
- BuiltInMethods.SIMILAR_WITH_ESCAPE)
-
- addSqlFunctionNotMethod(
- NOT_SIMILAR_TO,
- Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
- BuiltInMethod.SIMILAR.method)
-
- addSqlFunctionMethod(
- POSITION,
- Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
- INT_TYPE_INFO,
- BuiltInMethod.POSITION.method)
-
- addSqlFunctionMethod(
- OVERLAY,
- Seq(STRING_TYPE_INFO, STRING_TYPE_INFO, INT_TYPE_INFO),
- STRING_TYPE_INFO,
- BuiltInMethod.OVERLAY.method)
-
- addSqlFunctionMethod(
- OVERLAY,
- Seq(STRING_TYPE_INFO, STRING_TYPE_INFO, INT_TYPE_INFO, INT_TYPE_INFO),
- STRING_TYPE_INFO,
- BuiltInMethod.OVERLAY.method)
-
- // ----------------------------------------------------------------------------------------------
- // Arithmetic functions
- // ----------------------------------------------------------------------------------------------
-
- addSqlFunctionMethod(
- LOG10,
- Seq(DOUBLE_TYPE_INFO),
- DOUBLE_TYPE_INFO,
- BuiltInMethods.LOG10)
-
- addSqlFunctionMethod(
- LN,
- Seq(DOUBLE_TYPE_INFO),
- DOUBLE_TYPE_INFO,
- BuiltInMethods.LN)
-
- addSqlFunctionMethod(
- EXP,
- Seq(DOUBLE_TYPE_INFO),
- DOUBLE_TYPE_INFO,
- BuiltInMethods.EXP)
-
- addSqlFunctionMethod(
- POWER,
- Seq(DOUBLE_TYPE_INFO, DOUBLE_TYPE_INFO),
- DOUBLE_TYPE_INFO,
- BuiltInMethods.POWER)
-
- addSqlFunctionMethod(
- POWER,
- Seq(DOUBLE_TYPE_INFO, BIG_DEC_TYPE_INFO),
- DOUBLE_TYPE_INFO,
- BuiltInMethods.POWER_DEC)
-
- addSqlFunction(
- ABS,
- Seq(DOUBLE_TYPE_INFO),
- new MultiTypeMethodCallGen(BuiltInMethods.ABS))
-
- addSqlFunction(
- ABS,
- Seq(BIG_DEC_TYPE_INFO),
- new MultiTypeMethodCallGen(BuiltInMethods.ABS_DEC))
-
- addSqlFunction(
- FLOOR,
- Seq(DOUBLE_TYPE_INFO),
- new FloorCeilCallGen(BuiltInMethod.FLOOR.method))
-
- addSqlFunction(
- FLOOR,
- Seq(BIG_DEC_TYPE_INFO),
- new FloorCeilCallGen(BuiltInMethod.FLOOR.method))
-
- addSqlFunction(
- CEIL,
- Seq(DOUBLE_TYPE_INFO),
- new FloorCeilCallGen(BuiltInMethod.CEIL.method))
-
- addSqlFunction(
- CEIL,
- Seq(BIG_DEC_TYPE_INFO),
- new FloorCeilCallGen(BuiltInMethod.CEIL.method))
-
- // ----------------------------------------------------------------------------------------------
- // Temporal functions
- // ----------------------------------------------------------------------------------------------
-
- addSqlFunctionMethod(
- EXTRACT_DATE,
- Seq(new GenericTypeInfo(classOf[TimeUnitRange]), LONG_TYPE_INFO),
- LONG_TYPE_INFO,
- BuiltInMethod.UNIX_DATE_EXTRACT.method)
-
- addSqlFunctionMethod(
- EXTRACT_DATE,
- Seq(new GenericTypeInfo(classOf[TimeUnitRange]), SqlTimeTypeInfo.DATE),
- LONG_TYPE_INFO,
- BuiltInMethod.UNIX_DATE_EXTRACT.method)
-
- addSqlFunction(
- FLOOR,
- Seq(SqlTimeTypeInfo.DATE, new GenericTypeInfo(classOf[TimeUnitRange])),
- new FloorCeilCallGen(
- BuiltInMethod.FLOOR.method,
- Some(BuiltInMethod.UNIX_DATE_FLOOR.method)))
-
- addSqlFunction(
- FLOOR,
- Seq(SqlTimeTypeInfo.TIME, new GenericTypeInfo(classOf[TimeUnitRange])),
- new FloorCeilCallGen(
- BuiltInMethod.FLOOR.method,
- Some(BuiltInMethod.UNIX_DATE_FLOOR.method)))
-
- addSqlFunction(
- FLOOR,
- Seq(SqlTimeTypeInfo.TIMESTAMP, new GenericTypeInfo(classOf[TimeUnitRange])),
- new FloorCeilCallGen(
- BuiltInMethod.FLOOR.method,
- Some(BuiltInMethod.UNIX_TIMESTAMP_FLOOR.method)))
-
- addSqlFunction(
- CEIL,
- Seq(SqlTimeTypeInfo.DATE, new GenericTypeInfo(classOf[TimeUnitRange])),
- new FloorCeilCallGen(
- BuiltInMethod.CEIL.method,
- Some(BuiltInMethod.UNIX_DATE_CEIL.method)))
-
- addSqlFunction(
- CEIL,
- Seq(SqlTimeTypeInfo.TIME, new GenericTypeInfo(classOf[TimeUnitRange])),
- new FloorCeilCallGen(
- BuiltInMethod.CEIL.method,
- Some(BuiltInMethod.UNIX_DATE_CEIL.method)))
-
- addSqlFunction(
- CEIL,
- Seq(SqlTimeTypeInfo.TIMESTAMP, new GenericTypeInfo(classOf[TimeUnitRange])),
- new FloorCeilCallGen(
- BuiltInMethod.CEIL.method,
- Some(BuiltInMethod.UNIX_TIMESTAMP_CEIL.method)))
-
- addSqlFunction(
- CURRENT_DATE,
- Seq(),
- new CurrentTimePointCallGen(SqlTimeTypeInfo.DATE, local = false))
-
- addSqlFunction(
- CURRENT_TIME,
- Seq(),
- new CurrentTimePointCallGen(SqlTimeTypeInfo.TIME, local = false))
-
- addSqlFunction(
- CURRENT_TIMESTAMP,
- Seq(),
- new CurrentTimePointCallGen(SqlTimeTypeInfo.TIMESTAMP, local = false))
-
- addSqlFunction(
- LOCALTIME,
- Seq(),
- new CurrentTimePointCallGen(SqlTimeTypeInfo.TIME, local = true))
-
- addSqlFunction(
- LOCALTIMESTAMP,
- Seq(),
- new CurrentTimePointCallGen(SqlTimeTypeInfo.TIMESTAMP, local = true))
-
- // ----------------------------------------------------------------------------------------------
-
- /**
- * Returns a [[CallGenerator]] that generates all required code for calling the given
- * [[SqlOperator]].
- *
- * @param sqlOperator SQL operator (might be overloaded)
- * @param operandTypes actual operand types
- * @param resultType expected return type
- * @return [[CallGenerator]]
- */
- def getCallGenerator(
- sqlOperator: SqlOperator,
- operandTypes: Seq[TypeInformation[_]],
- resultType: TypeInformation[_])
- : Option[CallGenerator] = sqlOperator match {
-
- // user-defined scalar function
- case ssf: ScalarSqlFunction =>
- Some(
- new ScalarFunctionCallGen(
- ssf.getScalarFunction,
- operandTypes,
- resultType
- )
- )
-
- // built-in scalar function
- case _ =>
- sqlFunctions.get((sqlOperator, operandTypes))
- .orElse(sqlFunctions.find(entry => entry._1._1 == sqlOperator
- && entry._1._2.length == operandTypes.length
- && entry._1._2.zip(operandTypes).forall {
- case (x: BasicTypeInfo[_], y: BasicTypeInfo[_]) => y.shouldAutocastTo(x) || x == y
- case _ => false
- }).map(_._2))
- }
-
- // ----------------------------------------------------------------------------------------------
-
- private def addSqlFunctionMethod(
- sqlOperator: SqlOperator,
- operandTypes: Seq[TypeInformation[_]],
- returnType: TypeInformation[_],
- method: Method)
- : Unit = {
- sqlFunctions((sqlOperator, operandTypes)) = new MethodCallGen(returnType, method)
- }
-
- private def addSqlFunctionNotMethod(
- sqlOperator: SqlOperator,
- operandTypes: Seq[TypeInformation[_]],
- method: Method)
- : Unit = {
- sqlFunctions((sqlOperator, operandTypes)) =
- new NotCallGenerator(new MethodCallGen(BOOLEAN_TYPE_INFO, method))
- }
-
- private def addSqlFunction(
- sqlOperator: SqlOperator,
- operandTypes: Seq[TypeInformation[_]],
- callGenerator: CallGenerator)
- : Unit = {
- sqlFunctions((sqlOperator, operandTypes)) = callGenerator
- }
-
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala
new file mode 100644
index 0000000..27cb43f
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.codegen.calls
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.codegen.CodeGenUtils._
+import org.apache.flink.api.table.codegen.{CodeGenException, CodeGenerator, GeneratedExpression}
+import org.apache.flink.api.table.functions.TableFunction
+import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils._
+
+/**
+ * Generates a call to a user-defined [[TableFunction]]: clear() followed by eval(...).
+ *
+ * @param tableFunction user-defined [[TableFunction]] that might be overloaded
+ * @param signature actual signature with which the function is called
+ * @param returnType actual return type required by the surrounding expression
+ */
+class TableFunctionCallGen(
+ tableFunction: TableFunction[_],
+ signature: Seq[TypeInformation[_]],
+ returnType: TypeInformation[_])
+ extends CallGenerator {
+
+ override def generate(
+ codeGenerator: CodeGenerator,
+ operands: Seq[GeneratedExpression])
+ : GeneratedExpression = {
+ // pick the eval overload matching the signature; fail code generation if none does
+ val matchingSignature = getSignature(tableFunction, signature)
+ .getOrElse(throw new CodeGenException("No matching signature found."))
+
+ // primitive params take the operand as-is; others are boxed (null if nullCheck and null)
+ val parameters = matchingSignature
+ .zip(operands)
+ .map { case (paramClass, operandExpr) =>
+ if (paramClass.isPrimitive) {
+ operandExpr
+ } else {
+ val boxedTypeTerm = boxedTypeTermForTypeInfo(operandExpr.resultType)
+ val boxedExpr = codeGenerator.generateOutputFieldBoxing(operandExpr)
+ val exprOrNull: String = if (codeGenerator.nullCheck) {
+ s"${boxedExpr.nullTerm} ? null : ($boxedTypeTerm) ${boxedExpr.resultTerm}"
+ } else {
+ boxedExpr.resultTerm
+ }
+ boxedExpr.copy(resultTerm = exprOrNull)
+ }
+ }
+
+ // emit operand code, then clear() and eval(...) on the reference from addReusableFunction
+ val functionReference = codeGenerator.addReusableFunction(tableFunction)
+ val functionCallCode =
+ s"""
+ |${parameters.map(_.code).mkString("\n")}
+ |$functionReference.clear();
+ |$functionReference.eval(${parameters.map(_.resultTerm).mkString(", ")});
+ |""".stripMargin
+
+ // eval's return value is not captured: the result term is the function reference, never null
+ GeneratedExpression(
+ functionReference,
+ GeneratedExpression.NEVER_NULL,
+ functionCallCode,
+ returnType)
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
index 6b6c129..6cd63ff 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
@@ -447,7 +447,9 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
lazy val alias: PackratParser[Expression] = logic ~ AS ~ fieldReference ^^ {
case e ~ _ ~ name => Alias(e, name.name)
- } | logic
+ } | logic ~ AS ~ "(" ~ rep1sep(fieldReference, ",") ~ ")" ^^ {
+ case e ~ _ ~ _ ~ names ~ _ => Alias(e, names.head.name, names.drop(1).map(_.name))
+ } | logic
lazy val expression: PackratParser[Expression] = alias |
failure("Invalid expression.")
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala
index 39367be..3e8d8b1 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala
@@ -19,10 +19,12 @@ package org.apache.flink.api.table.expressions
import org.apache.calcite.rex.RexNode
import org.apache.calcite.tools.RelBuilder
-import org.apache.flink.api.table.functions.ScalarFunction
-import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils.{getResultType, getSignature, signatureToString, signaturesToString}
-import org.apache.flink.api.table.validate.{ValidationResult, ValidationFailure, ValidationSuccess}
-import org.apache.flink.api.table.{FlinkTypeFactory, UnresolvedException}
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.functions.{ScalarFunction, TableFunction}
+import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils._
+import org.apache.flink.api.table.plan.logical.{LogicalNode, LogicalTableFunctionCall}
+import org.apache.flink.api.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess}
+import org.apache.flink.api.table.{FlinkTypeFactory, UnresolvedException, ValidationException}
/**
* General expression for unresolved function calls. The function can be a built-in
@@ -63,11 +65,15 @@ case class ScalarFunctionCall(
override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
relBuilder.call(
- scalarFunction.getSqlFunction(scalarFunction.toString, typeFactory),
+ createScalarSqlFunction(
+ scalarFunction.getClass.getCanonicalName,
+ scalarFunction,
+ typeFactory),
parameters.map(_.toRexNode): _*)
}
- override def toString = s"$scalarFunction(${parameters.mkString(", ")})"
+ override def toString =
+ s"${scalarFunction.getClass.getCanonicalName}(${parameters.mkString(", ")})"
override private[flink] def resultType = getResultType(scalarFunction, foundSignature.get)
@@ -85,3 +91,68 @@ case class ScalarFunctionCall(
}
}
+
+
+/**
+ * Expression for calling a user-defined table function with actual parameters.
+ * Field aliases can be attached with [[as]] before [[toLogicalTableFunctionCall]] is invoked.
+ *
+ * @param functionName function name, used in the logical node and in validation errors
+ * @param tableFunction user-defined table function
+ * @param parameters actual parameters of function
+ * @param resultType type information of the returned table; supplies the default field names
+ */
+case class TableFunctionCall(
+ functionName: String,
+ tableFunction: TableFunction[_],
+ parameters: Seq[Expression],
+ resultType: TypeInformation[_])
+ extends Expression {
+
+ private var aliases: Option[Seq[String]] = None // NOTE(review): not part of equals/hashCode/copy
+
+ override private[flink] def children: Seq[Expression] = parameters
+
+ /**
+ * Assigns aliases to the fields returned by this table function, so that a following
+ * `select()` clause can refer to them. Mutates this instance in place.
+ *
+ * @param aliasList field aliases, or None to keep the original field names
+ * @return this same (mutated) table function call
+ */
+ private[flink] def as(aliasList: Option[Seq[String]]): TableFunctionCall = {
+ this.aliases = aliasList
+ this
+ }
+
+ /**
+ * Converts this call to a [[LogicalTableFunctionCall]] on `child`; alias count must match fields.
+ */
+ private[flink] def toLogicalTableFunctionCall(child: LogicalNode): LogicalTableFunctionCall = {
+ val originNames = getFieldInfo(resultType)._1
+
+ // final field names: aliases if given (count must match the returned fields), else originals
+ val fieldNames = if (aliases.isDefined) {
+ val aliasList = aliases.get
+ if (aliasList.length != originNames.length) {
+ throw ValidationException(
+ s"List of column aliases must have same degree as table; " +
+ s"the returned table of function '$functionName' has ${originNames.length} " +
+ s"columns (${originNames.mkString(",")}), " +
+ s"whereas alias list has ${aliasList.length} columns")
+ } else {
+ aliasList.toArray
+ }
+ } else {
+ originNames
+ }
+
+ LogicalTableFunctionCall(
+ functionName,
+ tableFunction,
+ parameters,
+ resultType,
+ fieldNames,
+ child)
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/fieldExpression.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/fieldExpression.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/fieldExpression.scala
index c7817bf..e651bb3 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/fieldExpression.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/fieldExpression.scala
@@ -67,7 +67,7 @@ case class ResolvedFieldReference(
}
}
-case class Alias(child: Expression, name: String)
+case class Alias(child: Expression, name: String, extraNames: Seq[String] = Seq())
extends UnaryExpression with NamedExpression {
override def toString = s"$child as '$name"
@@ -80,7 +80,7 @@ case class Alias(child: Expression, name: String)
override private[flink] def makeCopy(anyRefs: Array[AnyRef]): this.type = {
val child: Expression = anyRefs.head.asInstanceOf[Expression]
- copy(child, name).asInstanceOf[this.type]
+ copy(child, name, extraNames).asInstanceOf[this.type]
}
override private[flink] def toAttribute: Attribute = {
@@ -94,6 +94,8 @@ case class Alias(child: Expression, name: String)
override private[flink] def validateInput(): ValidationResult = {
if (name == "*") {
ValidationFailure("Alias can not accept '*' as name.")
+ } else if (extraNames.nonEmpty) {
+ ValidationFailure("Invalid call to Alias with multiple names.")
} else {
ValidationSuccess
}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala
index 5f9d834..86d9d66 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala
@@ -60,47 +60,6 @@ abstract class ScalarFunction extends UserDefinedFunction {
ScalarFunctionCall(this, params)
}
- // ----------------------------------------------------------------------------------------------
-
- private val evalMethods = checkAndExtractEvalMethods()
- private lazy val signatures = evalMethods.map(_.getParameterTypes)
-
- /**
- * Extracts evaluation methods and throws a [[ValidationException]] if no implementation
- * can be found.
- */
- private def checkAndExtractEvalMethods(): Array[Method] = {
- val methods = getClass.asSubclass(classOf[ScalarFunction])
- .getDeclaredMethods
- .filter { m =>
- val modifiers = m.getModifiers
- m.getName == "eval" && Modifier.isPublic(modifiers) && !Modifier.isAbstract(modifiers)
- }
-
- if (methods.isEmpty) {
- throw new ValidationException(s"Scalar function class '$this' does not implement at least " +
- s"one method named 'eval' which is public and not abstract.")
- } else {
- methods
- }
- }
-
- /**
- * Returns all found evaluation methods of the possibly overloaded function.
- */
- private[flink] final def getEvalMethods: Array[Method] = evalMethods
-
- /**
- * Returns all found signature of the possibly overloaded function.
- */
- private[flink] final def getSignatures: Array[Array[Class[_]]] = signatures
-
- override private[flink] final def createSqlFunction(
- name: String,
- typeFactory: FlinkTypeFactory)
- : SqlFunction = {
- new ScalarSqlFunction(name, this, typeFactory)
- }
// ----------------------------------------------------------------------------------------------
@@ -135,7 +94,8 @@ abstract class ScalarFunction extends UserDefinedFunction {
TypeExtractor.getForClass(c)
} catch {
case ite: InvalidTypesException =>
- throw new ValidationException(s"Parameter types of scalar function '$this' cannot be " +
+ throw new ValidationException(
+ s"Parameter types of scalar function '${this.getClass.getCanonicalName}' cannot be " +
s"automatically determined. Please provide type information manually.")
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/TableFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/TableFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/TableFunction.scala
new file mode 100644
index 0000000..98a2921
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/TableFunction.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.functions
+
+import java.util
+
+import org.apache.flink.api.common.functions.InvalidTypesException
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.TypeExtractor
+import org.apache.flink.api.table.ValidationException
+
+/**
+ * Base class for a user-defined table function (UDTF). A user-defined table function works on
+ * zero, one, or multiple scalar values as input and returns multiple rows as output.
+ *
+ * The behavior of a [[TableFunction]] can be defined by implementing a custom evaluation
+ * method. An evaluation method must be declared publicly and named "eval". Evaluation methods
+ * can also be overloaded by implementing multiple methods named "eval".
+ *
+ * User-defined functions must have a default constructor and must be instantiable during runtime.
+ *
+ * By default the result type of an evaluation method is determined by Flink's type extraction
+ * facilities. This is sufficient for basic types or simple POJOs but might be wrong for more
+ * complex, custom, or composite types. In these cases [[TypeInformation]] of the result type
+ * can be manually defined by overriding [[getResultType]].
+ *
+ * Internally, the Table/SQL API code generation works with primitive values as much as possible.
+ * If a user-defined table function should not introduce much overhead during runtime, it is
+ * recommended to declare parameters and result types as primitive types instead of their boxed
+ * classes. DATE/TIME values are represented as int, TIMESTAMP values as long.
+ *
+ * Example:
+ *
+ * {{{
+ *
+ *   public class Split extends TableFunction<String> {
+ *
+ *     // implement an "eval" method with as many parameters as you want
+ *     public void eval(String str) {
+ *       for (String s : str.split(" ")) {
+ *         collect(s);   // use collect(...) to emit an output row
+ *       }
+ *     }
+ *
+ *     // further overloaded "eval" methods can be added here ...
+ *   }
+ *
+ *   val tEnv: TableEnvironment = ...
+ *   val table: Table = ...    // schema: [a: String]
+ *
+ *   // for Scala users
+ *   val split = new Split()
+ *   table.crossApply(split('a) as ('s)).select('a, 's)
+ *
+ *   // for Java users
+ *   tEnv.registerFunction("split", new Split())    // register table function first
+ *   table.crossApply("split(a) as (s)").select("a, s")
+ *
+ *   // for SQL users
+ *   tEnv.registerFunction("split", new Split())    // register table function first
+ *   tEnv.sql("SELECT a, s FROM MyTable, LATERAL TABLE(split(a)) as T(s)")
+ *
+ * }}}
+ *
+ * @tparam T The type of the output row
+ */
+abstract class TableFunction[T] extends UserDefinedFunction {
+
+  private val rows: util.ArrayList[T] = new util.ArrayList[T]()
+
+  /**
+   * Emits an output row. Rows are buffered until [[clear]] is called.
+   *
+   * @param row the output row
+   */
+  protected def collect(row: T): Unit = {
+    // buffer rows for now; they might be processed immediately in the future
+    rows.add(row)
+  }
+
+  /**
+   * Internal use. Returns an iterator over the rows buffered since the last [[clear]].
+   */
+  def getRowsIterator: util.Iterator[T] = rows.iterator()
+
+  /**
+   * Internal use. Discards all buffered rows.
+   */
+  def clear(): Unit = rows.clear()
+
+  // ----------------------------------------------------------------------------------------------
+
+  /**
+   * Returns the type information of the rows emitted by this table function.
+   *
+   * This method needs to be overridden in case Flink's type extraction facilities are not
+   * sufficient to extract the [[TypeInformation]] based on the return type of the evaluation
+   * method. Flink's type extraction facilities can handle basic types or
+   * simple POJOs but might be wrong for more complex, custom, or composite types.
+   *
+   * @return [[TypeInformation]] of result type or null if Flink should determine the type
+   */
+  def getResultType: TypeInformation[T] = null
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/UserDefinedFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/UserDefinedFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/UserDefinedFunction.scala
index 62afef0..cdf6b07 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/UserDefinedFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/UserDefinedFunction.scala
@@ -15,47 +15,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.flink.api.table.functions
-import org.apache.calcite.sql.SqlFunction
-import org.apache.flink.api.table.FlinkTypeFactory
-import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils.checkForInstantiation
-
-import scala.collection.mutable
-
/**
* Base class for all user-defined functions such as scalar functions, table functions,
* or aggregation functions.
*
* User-defined functions must have a default constructor and must be instantiable during runtime.
*/
-abstract class UserDefinedFunction {
-
- // we cache SQL functions to reduce amount of created objects
- // (i.e. for type inference, validation, etc.)
- private val cachedSqlFunctions = mutable.HashMap[String, SqlFunction]()
-
- // check if function can be instantiated
- checkForInstantiation(this.getClass)
-
- /**
- * Returns the corresponding [[SqlFunction]]. Creates an instance if not already created.
- */
- private[flink] final def getSqlFunction(
- name: String,
- typeFactory: FlinkTypeFactory)
- : SqlFunction = {
- cachedSqlFunctions.getOrElseUpdate(name, createSqlFunction(name, typeFactory))
- }
-
- /**
- * Creates corresponding [[SqlFunction]].
- */
- private[flink] def createSqlFunction(
- name: String,
- typeFactory: FlinkTypeFactory)
- : SqlFunction
-
- override def toString = getClass.getCanonicalName
+trait UserDefinedFunction { // empty base type; extended by ScalarFunction and TableFunction
}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala
index 531313e..0a987aa 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala
@@ -26,7 +26,7 @@ import org.apache.calcite.sql.parser.SqlParserPos
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.table.functions.ScalarFunction
import org.apache.flink.api.table.functions.utils.ScalarSqlFunction.{createOperandTypeChecker, createOperandTypeInference, createReturnTypeInference}
-import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils.{getResultType, getSignature, signatureToString, signaturesToString}
+import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils.{getResultType, getSignature, getSignatures, signatureToString, signaturesToString}
import org.apache.flink.api.table.{FlinkTypeFactory, ValidationException}
import scala.collection.JavaConverters._
@@ -123,6 +123,8 @@ object ScalarSqlFunction {
name: String,
scalarFunction: ScalarFunction)
: SqlOperandTypeChecker = {
+
+ val signatures = getSignatures(scalarFunction)
/**
* Operand type checker based on [[ScalarFunction]] given information.
*/
@@ -132,7 +134,7 @@ object ScalarSqlFunction {
}
override def getOperandCountRange: SqlOperandCountRange = {
- val signatureLengths = scalarFunction.getSignatures.map(_.length)
+ val signatureLengths = signatures.map(_.length)
SqlOperandCountRanges.between(signatureLengths.min, signatureLengths.max)
}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/TableSqlFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/TableSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/TableSqlFunction.scala
new file mode 100644
index 0000000..6eadfbc
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/TableSqlFunction.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.functions.utils
+
+import com.google.common.base.Predicate
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.sql._
+import org.apache.calcite.sql.`type`._
+import org.apache.calcite.sql.parser.SqlParserPos
+import org.apache.calcite.sql.validate.SqlUserDefinedTableFunction
+import org.apache.calcite.util.Util
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.functions.TableFunction
+import org.apache.flink.api.table.FlinkTypeFactory
+import org.apache.flink.api.table.plan.schema.FlinkTableFunctionImpl
+
+import scala.collection.JavaConverters._
+import java.util
+
+
+/**
+ * Calcite wrapper ([[SqlUserDefinedTableFunction]]) for user-defined table functions.
+ */
+class TableSqlFunction(
+    name: String,
+    udtf: TableFunction[_],
+    rowTypeInfo: TypeInformation[_],
+    returnTypeInference: SqlReturnTypeInference,
+    operandTypeInference: SqlOperandTypeInference,
+    operandTypeChecker: SqlOperandTypeChecker,
+    paramTypes: util.List[RelDataType],
+    functionImpl: FlinkTableFunctionImpl[_])
+  extends SqlUserDefinedTableFunction(
+    new SqlIdentifier(name, SqlParserPos.ZERO),
+    returnTypeInference,
+    operandTypeInference,
+    operandTypeChecker,
+    paramTypes,
+    functionImpl) {
+
+  /**
+   * Returns the user-defined table function wrapped by this SQL function.
+   */
+  def getTableFunction: TableFunction[_] = udtf
+
+  /**
+   * Returns the type information of the rows produced by the table function.
+   */
+  def getRowTypeInfo: TypeInformation[_] = rowTypeInfo
+
+  /**
+   * Returns additional field mapping information if the returned table type is a POJO
+   * (POJO types have no deterministic field order).
+   */
+  def getPojoFieldMapping = functionImpl.fieldIndexes
+
+}
+
+object TableSqlFunction {
+  /**
+   * Util function to create a [[TableSqlFunction]].
+   * @param name function name (used by SQL parser)
+   * @param udtf user-defined table function to be called
+   * @param rowTypeInfo the row type information generated by the table function
+   * @param typeFactory type factory for converting between Flink's and Calcite's types
+   * @param functionImpl Calcite table function schema
+   * @return [[TableSqlFunction]]
+   */
+  def apply(
+      name: String,
+      udtf: TableFunction[_],
+      rowTypeInfo: TypeInformation[_],
+      typeFactory: FlinkTypeFactory,
+      functionImpl: FlinkTableFunctionImpl[_]): TableSqlFunction = {
+
+    val argTypes: util.List[RelDataType] = new util.ArrayList[RelDataType]
+    val typeFamilies: util.List[SqlTypeFamily] = new util.ArrayList[SqlTypeFamily]
+    // derive the operands' data types and type families (ANY if the SQL type has no family)
+    functionImpl.getParameters.asScala.foreach { param =>
+      val relType: RelDataType = param.getType(typeFactory)
+      argTypes.add(relType)
+      typeFamilies.add(Util.first(relType.getSqlTypeName.getFamily, SqlTypeFamily.ANY))
+    }
+    // tells the type checker whether the operand at position 'input' is optional
+    val optional: Predicate[Integer] = new Predicate[Integer]() {
+      override def apply(input: Integer): Boolean = {
+        functionImpl.getParameters.get(input).isOptional
+      }
+    }
+    // create type check for the operands
+    val typeChecker: FamilyOperandTypeChecker = OperandTypes.family(typeFamilies, optional)
+
+    new TableSqlFunction(
+      name,
+      udtf,
+      rowTypeInfo,
+      ReturnTypes.CURSOR,
+      InferTypes.explicit(argTypes),
+      typeChecker,
+      argTypes,
+      functionImpl)
+  }
+}