You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tw...@apache.org on 2017/02/20 14:39:26 UTC
flink git commit: [FLINK-5571] [table] add open and close methods for
UserDefinedFunction
Repository: flink
Updated Branches:
refs/heads/master 0bdf3a74c -> b820fd3ca
[FLINK-5571] [table] add open and close methods for UserDefinedFunction
This closes #3176.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/b820fd3c
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/b820fd3c
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/b820fd3c
Branch: refs/heads/master
Commit: b820fd3ca038e411bc7f43e1c35637bf62981fe5
Parents: 0bdf3a7
Author: godfreyhe <go...@163.com>
Authored: Fri Jan 20 14:42:12 2017 +0800
Committer: twalthr <tw...@apache.org>
Committed: Mon Feb 20 15:38:34 2017 +0100
----------------------------------------------------------------------
docs/dev/table_api.md | 92 +++++++++++++++++
.../flink/table/codegen/CodeGenerator.scala | 78 +++++++++++---
.../flink/table/functions/FunctionContext.scala | 66 ++++++++++++
.../table/functions/UserDefinedFunction.scala | 17 ++-
.../table/runtime/CorrelateFlatMapRunner.scala | 7 ++
.../flink/table/runtime/FlatMapRunner.scala | 7 ++
.../table/api/scala/batch/sql/CalcITCase.scala | 6 +-
.../api/scala/batch/table/CalcITCase.scala | 8 +-
.../table/api/scala/stream/sql/SqlITCase.scala | 6 +-
.../api/scala/stream/table/CalcITCase.scala | 8 +-
.../UserDefinedScalarFunctionTest.scala | 31 +++++-
.../expressions/utils/ExpressionTestBase.scala | 40 ++++++-
.../utils/UserDefinedScalarFunctions.scala | 97 ++++++++++++++++-
.../runtime/dataset/DataSetCalcITCase.scala | 103 +++++++++++++++++++
.../dataset/DataSetCorrelateITCase.scala | 52 +++++++++-
.../datastream/DataStreamCalcITCase.scala | 81 +++++++++++++++
.../datastream/DataStreamCorrelateITCase.scala | 67 ++++++++++--
.../utils/UserDefinedFunctionTestUtils.scala | 53 ++++++++++
.../table/utils/UserDefinedTableFunctions.scala | 58 ++++++++++-
19 files changed, 828 insertions(+), 49 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/docs/dev/table_api.md
----------------------------------------------------------------------
diff --git a/docs/dev/table_api.md b/docs/dev/table_api.md
index 99ae711..22fd636 100644
--- a/docs/dev/table_api.md
+++ b/docs/dev/table_api.md
@@ -4724,6 +4724,11 @@ ELEMENT(ARRAY)
</div>
</div>
+{% top %}
+
+User-defined Functions
+----------------
+
### User-defined Scalar Functions
If a required scalar function is not contained in the built-in functions, it is possible to define custom, user-defined scalar functions for both the Table API and SQL. A user-defined scalar functions maps zero, one, or multiple scalar values to a new scalar value.
@@ -4933,6 +4938,93 @@ class CustomTypeSplit extends TableFunction[Row] {
</div>
</div>
+### Advanced Function Features
+
+Sometimes it might be necessary for a user-defined function to get global runtime information or do some setup/clean-up work before the actual work. User-defined functions provide `open()` and `close()` methods that can be overridden and provide functionality similar to the methods in `RichFunction` of the DataSet or DataStream API.
+
+The `open()` method is called once before the evaluation method. The `close()` method is called after the last call to the evaluation method.
+
+The `open()` method provides a `FunctionContext` that contains information about the context in which user-defined functions are executed, such as the metric group, the distributed cache files, or the global job parameters.
+
+The following information can be obtained by calling the corresponding methods of `FunctionContext`:
+
+| Method | Description |
+| :------------------------------------ | :----------------------------------------------------- |
+| `getMetricGroup()` | Metric group for this parallel subtask. |
+| `getCachedFile(name)` | Local temporary file copy of a distributed cache file. |
+| `getJobParameter(name, defaultValue)` | Global job parameter value associated with given key. |
+
+The following example snippet shows how to use `FunctionContext` in a scalar function for accessing a global job parameter:
+
+<div class="codetabs" markdown="1">
+<div data-lang="java" markdown="1">
+{% highlight java %}
+public class HashCode extends ScalarFunction {
+
+ private int factor = 0;
+
+ @Override
+ public void open(FunctionContext context) throws Exception {
+ // access "hashcode_factor" parameter
+ // "12" would be the default value if parameter does not exist
+ factor = Integer.valueOf(context.getJobParameter("hashcode_factor", "12"));
+ }
+
+ public int eval(String s) {
+ return s.hashCode() * factor;
+ }
+}
+
+ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env);
+
+// set job parameter
+Configuration conf = new Configuration();
+conf.setString("hashcode_factor", "31");
+env.getConfig().setGlobalJobParameters(conf);
+
+// register the function
+tableEnv.registerFunction("hashCode", new HashCode());
+
+// use the function in Java Table API
+myTable.select("string, string.hashCode(), hashCode(string)");
+
+// use the function in SQL
+tableEnv.sql("SELECT string, HASHCODE(string) FROM MyTable");
+{% endhighlight %}
+</div>
+
+<div data-lang="scala" markdown="1">
+{% highlight scala %}
+object hashCode extends ScalarFunction {
+
+ var hashcode_factor = 12;
+
+ override def open(context: FunctionContext): Unit = {
+ // access "hashcode_factor" parameter
+ // "12" would be the default value if parameter does not exist
+ hashcode_factor = context.getJobParameter("hashcode_factor", "12").toInt
+ }
+
+ def eval(s: String): Int = {
+ s.hashCode() * hashcode_factor
+ }
+}
+
+val tableEnv = TableEnvironment.getTableEnvironment(env)
+
+// use the function in Scala Table API
+myTable.select('string, hashCode('string))
+
+// register and use the function in SQL
+tableEnv.registerFunction("hashCode", hashCode)
+tableEnv.sql("SELECT string, HASHCODE(string) FROM MyTable");
+{% endhighlight %}
+
+</div>
+</div>
+
+
### Limitations
The following operations are not supported yet:
http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
index c679bd8..441b1c0 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
@@ -25,12 +25,13 @@ import org.apache.calcite.rex._
import org.apache.calcite.sql.SqlOperator
import org.apache.calcite.sql.`type`.SqlTypeName._
import org.apache.calcite.sql.fun.SqlStdOperatorTable._
-import org.apache.flink.api.common.functions.{FlatJoinFunction, FlatMapFunction, Function, MapFunction}
+import org.apache.flink.api.common.functions._
import org.apache.flink.api.common.io.GenericInputFormat
import org.apache.flink.api.common.typeinfo.{AtomicType, SqlTimeTypeInfo, TypeInformation}
import org.apache.flink.api.common.typeutils.CompositeType
import org.apache.flink.api.java.typeutils.{GenericTypeInfo, PojoTypeInfo, RowTypeInfo, TupleTypeInfo}
import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
+import org.apache.flink.configuration.Configuration
import org.apache.flink.table.api.TableConfig
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.CodeGenUtils._
@@ -38,7 +39,7 @@ import org.apache.flink.table.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
import org.apache.flink.table.codegen.Indenter.toISC
import org.apache.flink.table.codegen.calls.FunctionGenerator
import org.apache.flink.table.codegen.calls.ScalarOperators._
-import org.apache.flink.table.functions.UserDefinedFunction
+import org.apache.flink.table.functions.{FunctionContext, UserDefinedFunction}
import org.apache.flink.table.runtime.TableFunctionCollector
import org.apache.flink.table.typeutils.TypeCheckUtils._
import org.apache.flink.types.Row
@@ -122,6 +123,14 @@ class CodeGenerator(
// we use a LinkedHashSet to keep the insertion order
private val reusableInitStatements = mutable.LinkedHashSet[String]()
+ // set of open statements for RichFunction that will be added only once
+ // we use a LinkedHashSet to keep the insertion order
+ private val reusableOpenStatements = mutable.LinkedHashSet[String]()
+
+ // set of close statements for RichFunction that will be added only once
+ // we use a LinkedHashSet to keep the insertion order
+ private val reusableCloseStatements = mutable.LinkedHashSet[String]()
+
// set of statements that will be added only once per record
// we use a LinkedHashSet to keep the insertion order
private val reusablePerRecordStatements = mutable.LinkedHashSet[String]()
@@ -150,6 +159,20 @@ class CodeGenerator(
}
/**
+ * @return code block of statements that need to be placed in the open() method of RichFunction
+ */
+ def reuseOpenCode(): String = {
+ reusableOpenStatements.mkString("", "\n", "\n")
+ }
+
+ /**
+ * @return code block of statements that need to be placed in the close() method of RichFunction
+ */
+ def reuseCloseCode(): String = {
+ reusableCloseStatements.mkString("", "\n", "\n")
+ }
+
+ /**
* @return code block of statements that need to be placed in the SAM of the Function
*/
def reusePerRecordCode(): String = {
@@ -240,27 +263,33 @@ class CodeGenerator(
// manual casting here
val samHeader =
// FlatMapFunction
- if (clazz == classOf[FlatMapFunction[_,_]]) {
+ if (clazz == classOf[FlatMapFunction[_, _]]) {
+ val baseClass = classOf[RichFlatMapFunction[_, _]]
val inputTypeTerm = boxedTypeTermForTypeInfo(input1)
- (s"void flatMap(Object _in1, org.apache.flink.util.Collector $collectorTerm)",
+ (baseClass,
+ s"void flatMap(Object _in1, org.apache.flink.util.Collector $collectorTerm)",
List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;"))
}
// MapFunction
- else if (clazz == classOf[MapFunction[_,_]]) {
+ else if (clazz == classOf[MapFunction[_, _]]) {
+ val baseClass = classOf[RichMapFunction[_, _]]
val inputTypeTerm = boxedTypeTermForTypeInfo(input1)
- ("Object map(Object _in1)",
+ (baseClass,
+ "Object map(Object _in1)",
List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;"))
}
// FlatJoinFunction
- else if (clazz == classOf[FlatJoinFunction[_,_,_]]) {
+ else if (clazz == classOf[FlatJoinFunction[_, _, _]]) {
+ val baseClass = classOf[RichFlatJoinFunction[_, _, _]]
val inputTypeTerm1 = boxedTypeTermForTypeInfo(input1)
val inputTypeTerm2 = boxedTypeTermForTypeInfo(input2.getOrElse(
- throw new CodeGenException("Input 2 for FlatJoinFunction should not be null")))
- (s"void join(Object _in1, Object _in2, org.apache.flink.util.Collector $collectorTerm)",
+ throw new CodeGenException("Input 2 for FlatJoinFunction should not be null")))
+ (baseClass,
+ s"void join(Object _in1, Object _in2, org.apache.flink.util.Collector $collectorTerm)",
List(s"$inputTypeTerm1 $input1Term = ($inputTypeTerm1) _in1;",
- s"$inputTypeTerm2 $input2Term = ($inputTypeTerm2) _in2;"))
+ s"$inputTypeTerm2 $input2Term = ($inputTypeTerm2) _in2;"))
}
else {
// TODO more functions
@@ -269,7 +298,7 @@ class CodeGenerator(
val funcCode = j"""
public class $funcName
- implements ${clazz.getCanonicalName} {
+ extends ${samHeader._1.getCanonicalName} {
${reuseMemberCode()}
@@ -280,12 +309,22 @@ class CodeGenerator(
${reuseConstructorCode(funcName)}
@Override
- public ${samHeader._1} throws Exception {
- ${samHeader._2.mkString("\n")}
+ public void open(${classOf[Configuration].getCanonicalName} parameters) throws Exception {
+ ${reuseOpenCode()}
+ }
+
+ @Override
+ public ${samHeader._2} throws Exception {
+ ${samHeader._3.mkString("\n")}
${reusePerRecordCode()}
${reuseInputUnboxingCode()}
$bodyCode
}
+
+ @Override
+ public void close() throws Exception {
+ ${reuseCloseCode()}
+ }
}
""".stripMargin
@@ -1480,6 +1519,19 @@ class CodeGenerator(
|$fieldTerm = ($classQualifier) $constructorTerm.newInstance();
""".stripMargin
reusableInitStatements.add(constructorAccessibility)
+
+ val openFunction =
+ s"""
+ |$fieldTerm.open(new ${classOf[FunctionContext].getCanonicalName}(getRuntimeContext()));
+ """.stripMargin
+ reusableOpenStatements.add(openFunction)
+
+ val closeFunction =
+ s"""
+ |$fieldTerm.close();
+ """.stripMargin
+ reusableCloseStatements.add(closeFunction)
+
fieldTerm
}
http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/FunctionContext.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/FunctionContext.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/FunctionContext.scala
new file mode 100644
index 0000000..beeb686
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/FunctionContext.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions
+
+import java.io.File
+
+import org.apache.flink.api.common.functions.RuntimeContext
+import org.apache.flink.metrics.MetricGroup
+
+/**
+ * A FunctionContext allows obtaining global runtime information about the context in which the
+ * user-defined function is executed. The information includes the metric group,
+ * the distributed cache files, and the global job parameters.
+ *
+ * @param context the runtime context in which the Flink Function is executed
+ */
+class FunctionContext(context: RuntimeContext) {
+
+ /**
+ * Returns the metric group for this parallel subtask.
+ *
+ * @return metric group for this parallel subtask.
+ */
+ def getMetricGroup: MetricGroup = context.getMetricGroup
+
+ /**
+ * Gets the local temporary file copy of a distributed cache file.
+ *
+ * @param name distributed cache file name
+ * @return local temporary file copy of a distributed cache file.
+ */
+ def getCachedFile(name: String): File = context.getDistributedCache.getFile(name)
+
+ /**
+ * Gets the global job parameter value associated with the given key as a string.
+ *
+ * @param key key pointing to the associated value
+ * @param defaultValue default value which is returned in case global job parameter is null
+ * or there is no value associated with the given key
+ * @return (default) value associated with the given key
+ */
+ def getJobParameter(key: String, defaultValue: String): String = {
+ val conf = context.getExecutionConfig.getGlobalJobParameters
+ if (conf != null && conf.toMap.containsKey(key)) {
+ conf.toMap.get(key)
+ } else {
+ defaultValue
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala
index b99ab8d..c313d80 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala
@@ -23,5 +23,20 @@ package org.apache.flink.table.functions
*
* User-defined functions must have a default constructor and must be instantiable during runtime.
*/
-trait UserDefinedFunction {
+abstract class UserDefinedFunction {
+ /**
+ * Setup method for user-defined function. It can be used for initialization work.
+ *
+ * By default, this method does nothing.
+ */
+ @throws(classOf[Exception])
+ def open(context: FunctionContext): Unit = {}
+
+ /**
+ * Tear-down method for user-defined function. It can be used for clean up work.
+ *
+ * By default, this method does nothing.
+ */
+ @throws(classOf[Exception])
+ def close(): Unit = {}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala
index 4e803da..a0415e1 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala
@@ -18,6 +18,7 @@
package org.apache.flink.table.runtime
+import org.apache.flink.api.common.functions.util.FunctionUtils
import org.apache.flink.api.common.functions.{FlatMapFunction, RichFlatMapFunction}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.ResultTypeQueryable
@@ -52,6 +53,8 @@ class CorrelateFlatMapRunner[IN, OUT](
val constructor = flatMapClazz.getConstructor(classOf[TableFunctionCollector[_]])
LOG.debug("Instantiating FlatMapFunction.")
function = constructor.newInstance(collector).asInstanceOf[FlatMapFunction[IN, OUT]]
+ FunctionUtils.setFunctionRuntimeContext(function, getRuntimeContext)
+ FunctionUtils.openFunction(function, parameters)
}
override def flatMap(in: IN, out: Collector[OUT]): Unit = {
@@ -62,4 +65,8 @@ class CorrelateFlatMapRunner[IN, OUT](
}
override def getProducedType: TypeInformation[OUT] = returnType
+
+ override def close(): Unit = {
+ FunctionUtils.closeFunction(function)
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala
index a7bd980..b446306 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala
@@ -18,6 +18,7 @@
package org.apache.flink.table.runtime
+import org.apache.flink.api.common.functions.util.FunctionUtils
import org.apache.flink.api.common.functions.{FlatMapFunction, RichFlatMapFunction}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.ResultTypeQueryable
@@ -43,10 +44,16 @@ class FlatMapRunner[IN, OUT](
val clazz = compile(getRuntimeContext.getUserCodeClassLoader, name, code)
LOG.debug("Instantiating FlatMapFunction.")
function = clazz.newInstance()
+ FunctionUtils.setFunctionRuntimeContext(function, getRuntimeContext)
+ FunctionUtils.openFunction(function, parameters)
}
override def flatMap(in: IN, out: Collector[OUT]): Unit =
function.flatMap(in, out)
override def getProducedType: TypeInformation[OUT] = returnType
+
+ override def close(): Unit = {
+ FunctionUtils.closeFunction(function)
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/CalcITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/CalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/CalcITCase.scala
index 3710642..00f4782 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/CalcITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/CalcITCase.scala
@@ -23,11 +23,11 @@ import java.sql.{Date, Time, Timestamp}
import java.util
import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.util.CollectionDataSets
+import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.scala.batch.sql.FilterITCase.MyHashCode
-import org.apache.flink.table.api.scala.batch.utils.{TableProgramsCollectionTestBase, TableProgramsTestBase}
import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
-import org.apache.flink.table.api.scala._
-import org.apache.flink.api.scala.util.CollectionDataSets
+import org.apache.flink.table.api.scala.batch.utils.{TableProgramsCollectionTestBase, TableProgramsTestBase}
import org.apache.flink.table.api.{TableEnvironment, ValidationException}
import org.apache.flink.table.functions.ScalarFunction
import org.apache.flink.test.util.TestBaseUtils
http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala
index 2f853f3..b78dd91 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala
@@ -22,15 +22,15 @@ import java.sql.{Date, Time, Timestamp}
import java.util
import org.apache.flink.api.scala._
-import org.apache.flink.table.api.scala.batch.utils.{TableProgramsCollectionTestBase, TableProgramsTestBase}
-import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
-import org.apache.flink.table.api.scala._
import org.apache.flink.api.scala.util.CollectionDataSets
-import org.apache.flink.types.Row
import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
+import org.apache.flink.table.api.scala.batch.utils.{TableProgramsCollectionTestBase, TableProgramsTestBase}
import org.apache.flink.table.expressions.Literal
import org.apache.flink.table.functions.ScalarFunction
import org.apache.flink.test.util.TestBaseUtils
+import org.apache.flink.types.Row
import org.junit._
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
index 97e76fa..70bec72 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
@@ -19,12 +19,12 @@
package org.apache.flink.table.api.scala.stream.sql
import org.apache.flink.api.scala._
-import org.apache.flink.types.Row
-import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData}
-import org.apache.flink.table.api.scala._
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData}
+import org.apache.flink.types.Row
import org.junit.Assert._
import org.junit._
http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/CalcITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/CalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/CalcITCase.scala
index f541eb4..5969e91 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/CalcITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/CalcITCase.scala
@@ -19,13 +19,13 @@
package org.apache.flink.table.api.scala.stream.table
import org.apache.flink.api.scala._
-import org.apache.flink.types.Row
-import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData}
-import org.apache.flink.table.api.scala._
-import org.apache.flink.table.expressions.Literal
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData}
import org.apache.flink.table.api.{TableEnvironment, TableException}
+import org.apache.flink.table.expressions.Literal
+import org.apache.flink.types.Row
import org.junit.Assert._
import org.junit.Test
http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
index da8c748..a6c1760 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
@@ -179,7 +179,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
"Func12(f8)",
"+0 00:00:01.000")
}
-
+
@Test
def testJavaBoxedPrimitives(): Unit = {
val JavaFunc0 = new JavaFunc0()
@@ -211,6 +211,30 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
"null and 15 and null")
}
+ @Test
+ def testRichFunctions(): Unit = {
+ val richFunc0 = new RichFunc0
+ val richFunc1 = new RichFunc1
+ val richFunc2 = new RichFunc2
+ testAllApis(
+ richFunc0('f0),
+ "RichFunc0(f0)",
+ "RichFunc0(f0)",
+ "43")
+
+ testAllApis(
+ richFunc1('f0),
+ "RichFunc1(f0)",
+ "RichFunc1(f0)",
+ "42")
+
+ testAllApis(
+ richFunc2('f1),
+ "RichFunc2(f1)",
+ "RichFunc2(f1)",
+ "#Test")
+ }
+
// ----------------------------------------------------------------------------------------------
override def testData: Any = {
@@ -256,7 +280,10 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
"Func11" -> Func11,
"Func12" -> Func12,
"JavaFunc0" -> new JavaFunc0,
- "JavaFunc1" -> new JavaFunc1
+ "JavaFunc1" -> new JavaFunc1,
+ "RichFunc0" -> new RichFunc0,
+ "RichFunc1" -> new RichFunc1,
+ "RichFunc2" -> new RichFunc2
)
}
http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala
index 679942c..30da5ba 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala
@@ -19,17 +19,23 @@
package org.apache.flink.table.expressions.utils
import org.apache.calcite.plan.hep.{HepMatchOrder, HepPlanner, HepProgramBuilder}
+import java.util
+import java.util.concurrent.Future
import org.apache.calcite.rex.RexNode
import org.apache.calcite.sql.`type`.SqlTypeName._
import org.apache.calcite.sql2rel.RelDecorrelator
import org.apache.calcite.tools.{Programs, RelBuilder}
-import org.apache.flink.api.common.functions.{Function, MapFunction}
+import org.apache.flink.api.common.TaskInfo
+import org.apache.flink.api.common.accumulators.Accumulator
+import org.apache.flink.api.common.functions._
+import org.apache.flink.api.common.functions.util.RuntimeUDFContext
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.api.java.{DataSet => JDataSet}
import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment}
-import org.apache.flink.api.java.typeutils.RowTypeInfo
-import org.apache.flink.types.Row
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.core.fs.Path
import org.apache.flink.table.api.{BatchTableEnvironment, TableConfig, TableEnvironment}
import org.apache.flink.table.calcite.FlinkPlannerImpl
import org.apache.flink.table.codegen.{CodeGenerator, Compiler, GeneratedFunction}
@@ -37,6 +43,7 @@ import org.apache.flink.table.expressions.{Expression, ExpressionParser}
import org.apache.flink.table.functions.ScalarFunction
import org.apache.flink.table.plan.nodes.dataset.{DataSetCalc, DataSetConvention}
import org.apache.flink.table.plan.rules.FlinkRuleSets
+import org.apache.flink.types.Row
import org.junit.Assert._
import org.junit.{After, Before}
import org.mockito.Mockito._
@@ -69,7 +76,8 @@ abstract class ExpressionTestBase {
new HepPlanner(builder.build, context._2.getFrameworkConfig.getContext)
}
- private def prepareContext(typeInfo: TypeInformation[Any]): (RelBuilder, TableEnvironment) = {
+ private def prepareContext(typeInfo: TypeInformation[Any])
+ : (RelBuilder, TableEnvironment, ExecutionEnvironment) = {
// create DataSetTable
val dataSetMock = mock(classOf[DataSet[Any]])
val jDataSetMock = mock(classOf[JDataSet[Any]])
@@ -85,7 +93,7 @@ abstract class ExpressionTestBase {
val relBuilder = tEnv.getRelBuilder
relBuilder.scan(tableName)
- (relBuilder, tEnv)
+ (relBuilder, tEnv, env)
}
def testData: Any
@@ -130,8 +138,30 @@ abstract class ExpressionTestBase {
// compile and evaluate
val clazz = new TestCompiler[MapFunction[Any, Row], Row]().compile(genFunc)
val mapper = clazz.newInstance()
+
+ val isRichFunction = mapper.isInstanceOf[RichFunction]
+
+ // call setRuntimeContext method and open method for RichFunction
+ if (isRichFunction) {
+ val richMapper = mapper.asInstanceOf[RichMapFunction[_, _]]
+ val t = new RuntimeUDFContext(
+ new TaskInfo("ExpressionTest", 1, 0, 1, 1),
+ null,
+ context._3.getConfig,
+ new util.HashMap[String, Future[Path]](),
+ new util.HashMap[String, Accumulator[_, _]](),
+ null)
+ richMapper.setRuntimeContext(t)
+ richMapper.open(new Configuration())
+ }
+
val result = mapper.map(testData)
+ // call close method for RichFunction
+ if (isRichFunction) {
+ mapper.asInstanceOf[RichMapFunction[_, _]].close()
+ }
+
// compare
testExprs
.zipWithIndex
http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
index 4e9b6d3..f0b347d 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
@@ -22,7 +22,11 @@ import java.sql.{Date, Time, Timestamp}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.table.api.Types
-import org.apache.flink.table.functions.ScalarFunction
+import org.apache.flink.table.functions.{ScalarFunction, FunctionContext}
+import org.junit.Assert
+
+import scala.collection.mutable
+import scala.io.Source
case class SimplePojo(name: String, age: Int)
@@ -119,3 +123,94 @@ object Func12 extends ScalarFunction {
Types.INTERVAL_MILLIS
}
}
+
+/**
+  * Scalar function for tests that checks the lifecycle call order: open() must be called
+  * exactly once before eval(), and close() exactly once after open().
+  */
+class RichFunc0 extends ScalarFunction {
+  var openCalled = false
+  var closeCalled = false
+
+  /** Records the open() call; fails on a second open() or on an open() after close(). */
+  override def open(context: FunctionContext): Unit = {
+    super.open(context)
+    // Assert.fail throws, so the flag below is only set on the first call
+    if (openCalled) Assert.fail("Open called more than once.")
+    openCalled = true
+    if (closeCalled) Assert.fail("Close called before open.")
+  }
+
+  /** Returns `index + 1` after verifying that the function is open and not yet closed. */
+  def eval(index: Int): Int = {
+    if (!openCalled) Assert.fail("Open was not called before eval.")
+    if (closeCalled) Assert.fail("Close called before eval.")
+    index + 1
+  }
+
+  /** Records the close() call; fails on a second close() or on a close() without open(). */
+  override def close(): Unit = {
+    super.close()
+    // Assert.fail throws, so the flag below is only set on the first call
+    if (closeCalled) Assert.fail("Close called more than once.")
+    closeCalled = true
+    if (!openCalled) Assert.fail("Open was not called before close.")
+  }
+}
+
+/**
+  * Scalar function for tests that adds the integer job parameter "int.value" to its argument.
+  * The summand is Int.MaxValue while the function is not open.
+  */
+class RichFunc1 extends ScalarFunction {
+  var added = Int.MaxValue
+
+  /** Reads the summand from the job parameter "int.value", falling back to "0". */
+  override def open(context: FunctionContext): Unit = {
+    val configured = context.getJobParameter("int.value", "0")
+    added = configured.toInt
+  }
+
+  /** Returns the argument increased by the configured summand. */
+  def eval(index: Int): Int = index + added
+
+  /** Resets the summand to its initial marker value. */
+  override def close(): Unit = {
+    added = Int.MaxValue
+  }
+}
+
+/**
+  * Scalar function for tests that prepends the job parameter "string.value" and a '#' to its
+  * argument. The prefix is "ERROR_VALUE" while the function is not open.
+  */
+class RichFunc2 extends ScalarFunction {
+  var prefix = "ERROR_VALUE"
+
+  /** Reads the prefix from the job parameter "string.value", falling back to "". */
+  override def open(context: FunctionContext): Unit = {
+    val configured = context.getJobParameter("string.value", "")
+    prefix = configured
+  }
+
+  /** Returns the prefix, a '#' and the given value concatenated. */
+  def eval(value: String): String = s"$prefix#$value"
+
+  /** Resets the prefix to its initial marker value. */
+  override def close(): Unit = {
+    prefix = "ERROR_VALUE"
+  }
+}
+
+/**
+  * Scalar function for tests that checks whether a value is one of the words listed in the
+  * cached file registered under the name "words" (one word per line).
+  */
+class RichFunc3 extends ScalarFunction {
+  private val words = mutable.HashSet[String]()
+
+  /** Loads the trimmed lines of the cached file "words" into the word set. */
+  override def open(context: FunctionContext): Unit = {
+    val file = context.getCachedFile("words")
+    val source = Source.fromFile(file.getCanonicalPath)
+    try {
+      source.getLines().foreach(line => words.add(line.trim))
+    } finally {
+      // release the file handle even if reading fails
+      source.close()
+    }
+  }
+
+  /** Returns true if the given value is contained in the loaded word set. */
+  def eval(value: String): Boolean = {
+    words.contains(value)
+  }
+
+  /** Drops all loaded words. */
+  override def close(): Unit = {
+    words.clear()
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCalcITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCalcITCase.scala
new file mode 100644
index 0000000..f0b3b44
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCalcITCase.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.dataset
+
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.util.CollectionDataSets
+import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsClusterTestBase
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
+import org.apache.flink.table.expressions.utils.{RichFunc1, RichFunc2, RichFunc3}
+import org.apache.flink.table.utils._
+import org.apache.flink.test.util.TestBaseUtils
+import org.apache.flink.types.Row
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+
+import scala.collection.JavaConverters._
+
+/**
+  * Integration tests for user-defined scalar functions that implement open()/close() in
+  * DataSet programs, using job parameters and the distributed cache.
+  */
+@RunWith(classOf[Parameterized])
+class DataSetCalcITCase(
+    configMode: TableConfigMode)
+  extends TableProgramsClusterTestBase(configMode) {
+
+  /** RichFunc2 must see the job parameter "string.value" that is set on the environment. */
+  @Test
+  def testUserDefinedScalarFunctionWithParameter(): Unit = {
+    val execEnv = ExecutionEnvironment.getExecutionEnvironment
+    val tableEnv = TableEnvironment.getTableEnvironment(execEnv)
+    tableEnv.registerFunction("RichFunc2", new RichFunc2)
+    UserDefinedFunctionTestUtils.setJobParameters(execEnv, Map("string.value" -> "ABC"))
+
+    val input = CollectionDataSets.getSmall3TupleDataSet(execEnv)
+    tableEnv.registerDataSet("t1", input, 'a, 'b, 'c)
+
+    val actual = tableEnv
+      .sql("SELECT c FROM t1 where RichFunc2(c)='ABC#Hello'")
+      .toDataSet[Row]
+      .collect()
+
+    TestBaseUtils.compareResultAsText(actual.asJava, "Hello")
+  }
+
+  /** RichFunc3 must be able to read the file registered as cached file "words". */
+  @Test
+  def testUserDefinedScalarFunctionWithDistributedCache(): Unit = {
+    val cachedFilePath = UserDefinedFunctionTestUtils.writeCacheFile("test_words", "Hello\nWord")
+    val execEnv = ExecutionEnvironment.getExecutionEnvironment
+    execEnv.registerCachedFile(cachedFilePath, "words")
+    val tableEnv = TableEnvironment.getTableEnvironment(execEnv)
+    tableEnv.registerFunction("RichFunc3", new RichFunc3)
+
+    val input = CollectionDataSets.getSmall3TupleDataSet(execEnv)
+    tableEnv.registerDataSet("t1", input, 'a, 'b, 'c)
+
+    val actual = tableEnv
+      .sql("SELECT c FROM t1 where RichFunc3(c)=true")
+      .toDataSet[Row]
+      .collect()
+
+    TestBaseUtils.compareResultAsText(actual.asJava, "Hello")
+  }
+
+  /** Several rich scalar functions in one predicate must all be opened. */
+  @Test
+  def testMultipleUserDefinedScalarFunctions(): Unit = {
+    val execEnv = ExecutionEnvironment.getExecutionEnvironment
+    val tableEnv = TableEnvironment.getTableEnvironment(execEnv)
+    tableEnv.registerFunction("RichFunc1", new RichFunc1)
+    tableEnv.registerFunction("RichFunc2", new RichFunc2)
+    UserDefinedFunctionTestUtils.setJobParameters(execEnv, Map("string.value" -> "Abc"))
+
+    val input = CollectionDataSets.getSmall3TupleDataSet(execEnv)
+    tableEnv.registerDataSet("t1", input, 'a, 'b, 'c)
+
+    val query = "SELECT c FROM t1 where " +
+      "RichFunc2(c)='Abc#Hello' or RichFunc1(a)=3 and b=2"
+
+    val actual = tableEnv.sql(query).toDataSet[Row].collect()
+
+    TestBaseUtils.compareResultAsText(actual.asJava, "Hello\nHello world")
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala
index 818f52b..cd1ffb5 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala
@@ -20,14 +20,16 @@ package org.apache.flink.table.runtime.dataset
import java.sql.{Date, Timestamp}
import org.apache.flink.api.scala._
-import org.apache.flink.types.Row
-import org.apache.flink.table.api.scala.batch.utils.TableProgramsClusterTestBase
-import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
-import org.apache.flink.table.api.scala._
+import org.apache.flink.api.scala.util.CollectionDataSets
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.java.utils.UserDefinedTableFunctions.JavaTableFunc0
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsClusterTestBase
+import org.apache.flink.table.expressions.utils.RichFunc2
import org.apache.flink.table.utils._
import org.apache.flink.test.util.TestBaseUtils
+import org.apache.flink.types.Row
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
@@ -147,7 +149,7 @@ class DataSetCorrelateITCase(
}
@Test
- def testUDTFWithScalarFunction(): Unit = {
+ def testUserDefinedTableFunctionWithScalarFunction(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tableEnv = TableEnvironment.getTableEnvironment(env, config)
val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
@@ -185,6 +187,46 @@ class DataSetCorrelateITCase(
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
+  /**
+    * RichTableFunc1 must see the job parameter "word_separator" that is set on the
+    * environment and split column 'c with it.
+    */
+  @Test
+  def testUserDefinedTableFunctionWithParameter(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    // pass the parameterized table config, as the other tests of this class do
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+    val richTableFunc1 = new RichTableFunc1
+    tEnv.registerFunction("RichTableFunc1", richTableFunc1)
+    UserDefinedFunctionTestUtils.setJobParameters(env, Map("word_separator" -> "#"))
+
+    val result = testData(env)
+      .toTable(tEnv, 'a, 'b, 'c)
+      .join(richTableFunc1('c) as 's)
+      .select('a, 's)
+
+    val expected = "1,Jack\n" + "1,22\n" + "2,John\n" + "2,19\n" + "3,Anna\n" + "3,44"
+    val results = result.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  /**
+    * A rich table function applied to the result of a rich scalar function: both must see
+    * their job parameters ("word_separator" and "string.value").
+    */
+  @Test
+  def testUserDefinedTableFunctionWithScalarFunctionWithParameters(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    // pass the parameterized table config, as the other tests of this class do
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+    val richTableFunc1 = new RichTableFunc1
+    tEnv.registerFunction("RichTableFunc1", richTableFunc1)
+    val richFunc2 = new RichFunc2
+    tEnv.registerFunction("RichFunc2", richFunc2)
+    UserDefinedFunctionTestUtils.setJobParameters(
+      env,
+      Map("word_separator" -> "#", "string.value" -> "test"))
+
+    val result = CollectionDataSets.getSmall3TupleDataSet(env)
+      .toTable(tEnv, 'a, 'b, 'c)
+      .join(richTableFunc1(richFunc2('c)) as 's)
+      .select('a, 's)
+
+    val expected = "1,Hi\n1,test\n2,Hello\n2,test\n3,Hello world\n3,test"
+    val results = result.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
private def testData(
env: ExecutionEnvironment)
: DataSet[(Int, Long, String)] = {
http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCalcITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCalcITCase.scala
new file mode 100644
index 0000000..1d48f2c
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCalcITCase.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.datastream
+
+import org.apache.flink.api.scala._
+import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
+import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
+import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData}
+import org.apache.flink.table.expressions.utils.{RichFunc1, RichFunc2}
+import org.apache.flink.table.utils.UserDefinedFunctionTestUtils
+import org.apache.flink.types.Row
+import org.junit.Assert._
+import org.junit.Test
+
+import scala.collection.mutable
+
+/**
+  * Integration tests for user-defined scalar functions that implement open()/close() in
+  * DataStream programs, using job parameters.
+  */
+class DataStreamCalcITCase extends StreamingMultipleProgramsTestBase {
+
+  /** RichFunc2 must see the job parameter "string.value" that is set on the environment. */
+  @Test
+  def testUserDefinedFunctionWithParameter(): Unit = {
+    val execEnv = StreamExecutionEnvironment.getExecutionEnvironment
+    val tableEnv = TableEnvironment.getTableEnvironment(execEnv)
+    tableEnv.registerFunction("RichFunc2", new RichFunc2)
+    UserDefinedFunctionTestUtils.setJobParameters(execEnv, Map("string.value" -> "ABC"))
+
+    // start from an empty result list for this test
+    StreamITCase.testResults = mutable.MutableList()
+
+    val filtered = StreamTestData.get3TupleDataStream(execEnv)
+      .toTable(tableEnv, 'a, 'b, 'c)
+      .where("RichFunc2(c)='ABC#Hello'")
+      .select('c)
+
+    filtered.toDataStream[Row].addSink(new StreamITCase.StringSink)
+    execEnv.execute()
+
+    assertEquals(mutable.MutableList("Hello").sorted, StreamITCase.testResults.sorted)
+  }
+
+  /** Several rich scalar functions in one predicate must all be opened. */
+  @Test
+  def testMultipleUserDefinedFunctions(): Unit = {
+    val execEnv = StreamExecutionEnvironment.getExecutionEnvironment
+    val tableEnv = TableEnvironment.getTableEnvironment(execEnv)
+    tableEnv.registerFunction("RichFunc1", new RichFunc1)
+    tableEnv.registerFunction("RichFunc2", new RichFunc2)
+    UserDefinedFunctionTestUtils.setJobParameters(execEnv, Map("string.value" -> "Abc"))
+
+    // start from an empty result list for this test
+    StreamITCase.testResults = mutable.MutableList()
+
+    val filtered = StreamTestData.get3TupleDataStream(execEnv)
+      .toTable(tableEnv, 'a, 'b, 'c)
+      .where("RichFunc2(c)='Abc#Hello' || RichFunc1(a)=3 && b=2")
+      .select('c)
+
+    filtered.toDataStream[Row].addSink(new StreamITCase.StringSink)
+    execEnv.execute()
+
+    val expectedRows = mutable.MutableList("Hello", "Hello world")
+    assertEquals(expectedRows.sorted, StreamITCase.testResults.sorted)
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala
index eb20517..f8a697d 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala
@@ -18,13 +18,14 @@
package org.apache.flink.table.runtime.datastream
import org.apache.flink.api.scala._
-import org.apache.flink.types.Row
-import org.apache.flink.table.api.scala.stream.utils.StreamITCase
-import org.apache.flink.table.api.scala._
-import org.apache.flink.table.utils.TableFunc0
import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData}
+import org.apache.flink.table.expressions.utils.RichFunc2
+import org.apache.flink.table.utils.{RichTableFunc1, TableFunc0, UserDefinedFunctionTestUtils}
+import org.apache.flink.types.Row
import org.junit.Assert._
import org.junit.Test
@@ -76,9 +77,63 @@ class DataStreamCorrelateITCase extends StreamingMultipleProgramsTestBase {
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
+  /**
+    * RichTableFunc1 must see the job parameter "word_separator" (a blank here) and split
+    * column 'c with it.
+    */
+  @Test
+  def testUserDefinedTableFunctionWithParameter(): Unit = {
+    val execEnv = StreamExecutionEnvironment.getExecutionEnvironment
+    val tableEnv = TableEnvironment.getTableEnvironment(execEnv)
+    val splitFunc = new RichTableFunc1
+    tableEnv.registerFunction("RichTableFunc1", splitFunc)
+    UserDefinedFunctionTestUtils.setJobParameters(execEnv, Map("word_separator" -> " "))
+    // start from an empty result list for this test
+    StreamITCase.testResults = mutable.MutableList()
+
+    val joined = StreamTestData.getSmall3TupleDataStream(execEnv)
+      .toTable(tableEnv, 'a, 'b, 'c)
+      .join(splitFunc('c) as 's)
+      .select('a, 's)
+
+    joined.toDataStream[Row].addSink(new StreamITCase.StringSink)
+    execEnv.execute()
+
+    val expectedRows = mutable.MutableList("3,Hello", "3,world")
+    assertEquals(expectedRows.sorted, StreamITCase.testResults.sorted)
+  }
+
+  /**
+    * A rich table function applied to the result of a rich scalar function: both must see
+    * their job parameters ("word_separator" and "string.value").
+    */
+  @Test
+  def testUserDefinedTableFunctionWithUserDefinedScalarFunction(): Unit = {
+    val execEnv = StreamExecutionEnvironment.getExecutionEnvironment
+    val tableEnv = TableEnvironment.getTableEnvironment(execEnv)
+    val splitFunc = new RichTableFunc1
+    val prefixFunc = new RichFunc2
+    tableEnv.registerFunction("RichTableFunc1", splitFunc)
+    tableEnv.registerFunction("RichFunc2", prefixFunc)
+    UserDefinedFunctionTestUtils.setJobParameters(
+      execEnv,
+      Map("word_separator" -> "#", "string.value" -> "test"))
+    // start from an empty result list for this test
+    StreamITCase.testResults = mutable.MutableList()
+
+    val joined = StreamTestData.getSmall3TupleDataStream(execEnv)
+      .toTable(tableEnv, 'a, 'b, 'c)
+      .join(splitFunc(prefixFunc('c)) as 's)
+      .select('a, 's)
+
+    joined.toDataStream[Row].addSink(new StreamITCase.StringSink)
+    execEnv.execute()
+
+    val expectedRows = mutable.MutableList(
+      "1,Hi", "1,test",
+      "2,Hello", "2,test",
+      "3,Hello world", "3,test")
+    assertEquals(expectedRows.sorted, StreamITCase.testResults.sorted)
+  }
+
private def testData(
- env: StreamExecutionEnvironment)
- : DataStream[(Int, Long, String)] = {
+ env: StreamExecutionEnvironment)
+ : DataStream[(Int, Long, String)] = {
val data = new mutable.MutableList[(Int, Long, String)]
data.+=((1, 1L, "Jack#22"))
http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedFunctionTestUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedFunctionTestUtils.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedFunctionTestUtils.scala
new file mode 100644
index 0000000..deaedc9
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedFunctionTestUtils.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.utils
+
+import java.io.File
+
+import com.google.common.base.Charsets
+import com.google.common.io.Files
+import org.apache.flink.api.scala.ExecutionEnvironment
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
+
+/** Helpers for tests of user-defined functions that use job parameters or cached files. */
+object UserDefinedFunctionTestUtils {
+
+  /** Sets the given key/value pairs as global job parameters of a batch environment. */
+  def setJobParameters(env: ExecutionEnvironment, parameters: Map[String, String]): Unit = {
+    env.getConfig.setGlobalJobParameters(toConfiguration(parameters))
+  }
+
+  /** Sets the given key/value pairs as global job parameters of a streaming environment. */
+  def setJobParameters(env: StreamExecutionEnvironment, parameters: Map[String, String]): Unit = {
+    env.getConfig.setGlobalJobParameters(toConfiguration(parameters))
+  }
+
+  /**
+    * Writes the contents as UTF-8 into a new temporary file that is deleted on JVM exit.
+    *
+    * @param fileName part of the temporary file's name prefix
+    * @param contents text to write
+    * @return absolute path of the written file
+    */
+  def writeCacheFile(fileName: String, contents: String): String = {
+    // the suffix is appended verbatim, so it has to include the dot
+    val tempFile = File.createTempFile(this.getClass.getName + "-" + fileName, ".tmp")
+    tempFile.deleteOnExit()
+    Files.write(contents, tempFile, Charsets.UTF_8)
+    tempFile.getAbsolutePath
+  }
+
+  /** Copies the given key/value pairs into a new Configuration. */
+  private def toConfiguration(parameters: Map[String, String]): Configuration = {
+    val conf = new Configuration()
+    parameters.foreach { case (k, v) => conf.setString(k, v) }
+    conf
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala
index 54861ea..5db9d5f 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala
@@ -21,9 +21,11 @@ import java.lang.Boolean
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.Tuple3
-import org.apache.flink.types.Row
import org.apache.flink.api.java.typeutils.RowTypeInfo
-import org.apache.flink.table.functions.TableFunction
+import org.apache.flink.table.api.ValidationException
+import org.apache.flink.table.functions.{TableFunction, FunctionContext}
+import org.apache.flink.types.Row
+import org.junit.Assert
case class SimpleUser(name: String, age: Int)
@@ -115,3 +117,55 @@ object ObjectTableFunction extends TableFunction[Integer] {
collect(b)
}
}
+
+/**
+  * Table function for tests that checks the lifecycle call order (open() before eval(),
+  * close() after open()) and emits every input that does not contain a '#'.
+  */
+class RichTableFunc0 extends TableFunction[String] {
+  var openCalled = false
+  var closeCalled = false
+
+  /** Records the open() call; fails if close() was already called. */
+  override def open(context: FunctionContext): Unit = {
+    super.open(context)
+    if (closeCalled) Assert.fail("Close called before open.")
+    openCalled = true
+  }
+
+  /** Emits the input unless it contains a '#', after verifying the lifecycle state. */
+  def eval(str: String): Unit = {
+    if (!openCalled) Assert.fail("Open was not called before eval.")
+    if (closeCalled) Assert.fail("Close called before eval.")
+
+    val hasSeparator = str.contains("#")
+    if (!hasSeparator) {
+      collect(str)
+    }
+  }
+
+  /** Records the close() call; fails if open() was never called. */
+  override def close(): Unit = {
+    super.close()
+    if (!openCalled) Assert.fail("Open was not called before close.")
+    closeCalled = true
+  }
+}
+
+/**
+  * Table function for tests that splits its input at the separator given by the job
+  * parameter "word_separator" and emits the parts. Inputs without the separator emit nothing.
+  */
+class RichTableFunc1 extends TableFunction[String] {
+  var separator: Option[String] = None
+
+  /** Reads the separator from the job parameter "word_separator", falling back to "". */
+  override def open(context: FunctionContext): Unit = {
+    separator = Some(context.getJobParameter("word_separator", ""))
+  }
+
+  /**
+    * Emits the parts of the input split at the separator, if the input contains it.
+    *
+    * @throws ValidationException if no separator is set, i.e. the function is not open
+    */
+  def eval(str: String): Unit = {
+    val sep = separator.getOrElse(throw new ValidationException("no separator"))
+    if (str.contains(sep)) {
+      // String.split expects a regex; quote the separator so that it is matched literally,
+      // exactly like the contains() check above
+      str.split(java.util.regex.Pattern.quote(sep)).foreach(collect)
+    }
+  }
+
+  /** Clears the separator. */
+  override def close(): Unit = {
+    separator = None
+  }
+}