You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ja...@apache.org on 2017/02/22 02:03:39 UTC
flink git commit: [FLINK-5795] [table] Improve UDF&UDTF to support
constructor with parameter
Repository: flink
Updated Branches:
refs/heads/master 11c868f91 -> 45e01cf23
[FLINK-5795] [table] Improve UDF&UDTF to support constructor with parameter
this closes #3330
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/45e01cf2
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/45e01cf2
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/45e01cf2
Branch: refs/heads/master
Commit: 45e01cf2321dda58f572d8b9dbe64947c6725ad1
Parents: 11c868f
Author: 金竹 <ji...@alibaba-inc.com>
Authored: Tue Feb 14 14:43:41 2017 +0800
Committer: Jark Wu <wu...@alibaba-inc.com>
Committed: Wed Feb 22 10:02:04 2017 +0800
----------------------------------------------------------------------
.../flink/table/codegen/CodeGenerator.scala | 19 +-
.../apache/flink/table/expressions/call.scala | 6 +-
.../table/functions/UserDefinedFunction.scala | 11 +-
.../utils/UserDefinedFunctionUtils.scala | 33 +--
.../flink/table/plan/logical/operators.scala | 2 +-
.../flink/table/validate/FunctionCatalog.scala | 12 +-
.../flink/table/CompositeFlatteningTest.scala | 8 +-
.../scala/batch/table/FieldProjectionTest.scala | 4 +-
.../table/UserDefinedTableFunctionTest.scala | 6 +-
.../table/UserDefinedTableFunctionTest.scala | 16 +-
.../utils/UserDefinedScalarFunctions.scala | 7 +
.../dataset/DataSetCorrelateITCase.scala | 241 ----------------
.../DataSetUserDefinedFunctionITCase.scala | 288 +++++++++++++++++++
.../DataSetUserDefinedFunctionITCase.scala | 206 +++++++++++++
.../datastream/DataStreamCorrelateITCase.scala | 146 ----------
.../table/utils/UserDefinedTableFunctions.scala | 36 ++-
16 files changed, 595 insertions(+), 446 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
index 441b1c0..6658645 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
@@ -40,6 +40,7 @@ import org.apache.flink.table.codegen.Indenter.toISC
import org.apache.flink.table.codegen.calls.FunctionGenerator
import org.apache.flink.table.codegen.calls.ScalarOperators._
import org.apache.flink.table.functions.{FunctionContext, UserDefinedFunction}
+import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
import org.apache.flink.table.runtime.TableFunctionCollector
import org.apache.flink.table.typeutils.TypeCheckUtils._
import org.apache.flink.types.Row
@@ -1494,15 +1495,14 @@ class CodeGenerator(
/**
* Adds a reusable [[UserDefinedFunction]] to the member area of the generated [[Function]].
- * The [[UserDefinedFunction]] must have a default constructor, however, it does not have
- * to be public.
*
* @param function [[UserDefinedFunction]] object to be instantiated during runtime
* @return member variable term
*/
def addReusableFunction(function: UserDefinedFunction): String = {
val classQualifier = function.getClass.getCanonicalName
- val fieldTerm = s"function_${classQualifier.replace('.', '$')}"
+ val functionSerializedData = UserDefinedFunctionUtils.serialize(function)
+ val fieldTerm = s"function_${function.functionIdentifier}"
val fieldFunction =
s"""
@@ -1510,15 +1510,14 @@ class CodeGenerator(
|""".stripMargin
reusableMemberStatements.add(fieldFunction)
- val constructorTerm = s"constructor_${classQualifier.replace('.', '$')}"
- val constructorAccessibility =
+ val functionDeserialization =
s"""
- |java.lang.reflect.Constructor $constructorTerm =
- | $classQualifier.class.getDeclaredConstructor();
- |$constructorTerm.setAccessible(true);
- |$fieldTerm = ($classQualifier) $constructorTerm.newInstance();
+ |$fieldTerm = ($classQualifier)
+ |${UserDefinedFunctionUtils.getClass.getName.stripSuffix("$")}
+ |.deserialize("$functionSerializedData");
""".stripMargin
- reusableInitStatements.add(constructorAccessibility)
+
+ reusableInitStatements.add(functionDeserialization)
val openFunction =
s"""
http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala
index ef2cf4e..40db13e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala
@@ -20,12 +20,12 @@ package org.apache.flink.table.expressions
import org.apache.calcite.rex.RexNode
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.table.api.{UnresolvedException, ValidationException}
import org.apache.flink.table.calcite.FlinkTypeFactory
-import org.apache.flink.table.functions.{ScalarFunction, TableFunction}
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
+import org.apache.flink.table.functions.{ScalarFunction, TableFunction}
import org.apache.flink.table.plan.logical.{LogicalNode, LogicalTableFunctionCall}
import org.apache.flink.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess}
-import org.apache.flink.table.api.{UnresolvedException, ValidationException}
/**
* General expression for unresolved function calls. The function can be a built-in
@@ -67,7 +67,7 @@ case class ScalarFunctionCall(
val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
relBuilder.call(
createScalarSqlFunction(
- scalarFunction.getClass.getCanonicalName,
+ scalarFunction.functionIdentifier,
scalarFunction,
typeFactory),
parameters.map(_.toRexNode): _*)
http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala
index c313d80..e9e01ee 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala
@@ -17,13 +17,13 @@
*/
package org.apache.flink.table.functions
+import org.apache.commons.codec.digest.DigestUtils
+import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.serialize
/**
* Base class for all user-defined functions such as scalar functions, table functions,
* or aggregation functions.
- *
- * User-defined functions must have a default constructor and must be instantiable during runtime.
*/
-abstract class UserDefinedFunction {
+abstract class UserDefinedFunction extends Serializable {
/**
* Setup method for user-defined function. It can be used for initialization work.
*
@@ -39,4 +39,9 @@ abstract class UserDefinedFunction {
*/
@throws(classOf[Exception])
def close(): Unit = {}
+
+ final def functionIdentifier: String = {
+ val md5 = DigestUtils.md5Hex(serialize(this))
+ getClass.getCanonicalName.replace('.', '$').concat("$").concat(md5)
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
index f324dc1..16a6717b 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -23,6 +23,7 @@ import java.lang.{Long => JLong, Integer => JInt}
import java.lang.reflect.{Method, Modifier}
import java.sql.{Date, Time, Timestamp}
+import org.apache.commons.codec.binary.Base64
import com.google.common.primitives.Primitives
import org.apache.calcite.sql.SqlFunction
import org.apache.flink.api.common.functions.InvalidTypesException
@@ -37,15 +38,6 @@ import org.apache.flink.util.InstantiationUtil
object UserDefinedFunctionUtils {
/**
- * Instantiates a user-defined function.
- */
- def instantiate[T <: UserDefinedFunction](clazz: Class[T]): T = {
- val constructor = clazz.getDeclaredConstructor()
- constructor.setAccessible(true)
- constructor.newInstance()
- }
-
- /**
* Checks if a user-defined function can be easily instantiated.
*/
def checkForInstantiation(clazz: Class[_]): Unit = {
@@ -59,12 +51,6 @@ object UserDefinedFunctionUtils {
else if (InstantiationUtil.isNonStaticInnerClass(clazz)) {
throw ValidationException("The class is an inner class, but not statically accessible.")
}
-
- // check for default constructor (can be private)
- clazz
- .getDeclaredConstructors
- .find(_.getParameterTypes.isEmpty)
- .getOrElse(throw ValidationException("Function class needs a default constructor."))
}
/**
@@ -168,7 +154,7 @@ object UserDefinedFunctionUtils {
/**
* Create [[SqlFunction]] for a [[ScalarFunction]]
- *
+ *
* @param name function name
* @param function scalar function
* @param typeFactory type factory
@@ -184,7 +170,7 @@ object UserDefinedFunctionUtils {
/**
* Create [[SqlFunction]]s for a [[TableFunction]]'s every eval method
- *
+ *
* @param name function name
* @param tableFunction table function
* @param resultType the type information of returned table
@@ -311,7 +297,6 @@ object UserDefinedFunctionUtils {
}
}.toArray
-
/**
* Compares parameter candidate classes with expected classes. If true, the parameters match.
* Candidate can be null (acts as a wildcard).
@@ -324,4 +309,16 @@ object UserDefinedFunctionUtils {
candidate == classOf[Time] && (expected == classOf[Int] || expected == classOf[JInt]) ||
candidate == classOf[Timestamp] && (expected == classOf[Long] || expected == classOf[JLong])
+ @throws[Exception]
+ def serialize(function: UserDefinedFunction): String = {
+ val byteArray = InstantiationUtil.serializeObject(function)
+ Base64.encodeBase64URLSafeString(byteArray)
+ }
+
+ @throws[Exception]
+ def deserialize(data: String): UserDefinedFunction = {
+ val byteData = Base64.decodeBase64(data)
+ InstantiationUtil
+ .deserializeObject[UserDefinedFunction](byteData, Thread.currentThread.getContextClassLoader)
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
index 20f810a..1b5eafb 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
@@ -694,7 +694,7 @@ case class LogicalTableFunctionCall(
val function = new FlinkTableFunctionImpl(resultType, fieldIndexes, fieldNames, evalMethod)
val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
val sqlFunction = TableSqlFunction(
- tableFunction.toString,
+ tableFunction.functionIdentifier,
tableFunction,
resultType,
typeFactory,
http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
index 207eba1..94237f7 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
@@ -23,8 +23,8 @@ import org.apache.calcite.sql.util.{ChainedSqlOperatorTable, ListSqlOperatorTabl
import org.apache.calcite.sql.{SqlFunction, SqlOperator, SqlOperatorTable}
import org.apache.flink.table.api.ValidationException
import org.apache.flink.table.expressions._
-import org.apache.flink.table.functions.utils.{TableSqlFunction, UserDefinedFunctionUtils}
import org.apache.flink.table.functions.{EventTimeExtractor, RowTime, ScalarFunction, TableFunction}
+import org.apache.flink.table.functions.utils.{TableSqlFunction, ScalarSqlFunction}
import scala.collection.JavaConversions._
import scala.collection.mutable
@@ -81,11 +81,11 @@ class FunctionCatalog {
// user-defined scalar function call
case sf if classOf[ScalarFunction].isAssignableFrom(sf) =>
- Try(UserDefinedFunctionUtils.instantiate(sf.asInstanceOf[Class[ScalarFunction]])) match {
- case Success(scalarFunction) => ScalarFunctionCall(scalarFunction, children)
- case Failure(e) => throw ValidationException(e.getMessage)
- }
-
+ val scalarSqlFunction = sqlFunctions
+ .find(f => f.getName.equalsIgnoreCase(name) && f.isInstanceOf[ScalarSqlFunction])
+ .getOrElse(throw ValidationException(s"Undefined scalar function: $name"))
+ .asInstanceOf[ScalarSqlFunction]
+ ScalarFunctionCall(scalarSqlFunction.getScalarFunction, children)
// user-defined table function call
case tf if classOf[TableFunction[_]].isAssignableFrom(tf) =>
val tableSqlFunction = sqlFunctions
http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/CompositeFlatteningTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/CompositeFlatteningTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/CompositeFlatteningTest.scala
index 0055fc2..f5f5ff1 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/CompositeFlatteningTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/CompositeFlatteningTest.scala
@@ -119,10 +119,10 @@ class CompositeFlatteningTest extends TableTestBase {
"DataSetCalc",
batchTableNode(0),
term("select",
- "org.apache.flink.table.CompositeFlatteningTest.giveMeCaseClass$().my AS _c0",
- "org.apache.flink.table.CompositeFlatteningTest.giveMeCaseClass$().clazz AS _c1",
- "org.apache.flink.table.CompositeFlatteningTest.giveMeCaseClass$().my AS _c2",
- "org.apache.flink.table.CompositeFlatteningTest.giveMeCaseClass$().clazz AS _c3"
+ s"${giveMeCaseClass.functionIdentifier}().my AS _c0",
+ s"${giveMeCaseClass.functionIdentifier}().clazz AS _c1",
+ s"${giveMeCaseClass.functionIdentifier}().my AS _c2",
+ s"${giveMeCaseClass.functionIdentifier}().clazz AS _c3"
)
)
http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala
index d053b9f..0066ad2 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala
@@ -103,7 +103,7 @@ class FieldProjectionTest extends TableTestBase {
val expected = unaryNode(
"DataSetCalc",
batchTableNode(0),
- term("select", s"${MyHashCode.getClass.getCanonicalName}(c) AS _c0", "b")
+ term("select", s"${MyHashCode.functionIdentifier}(c) AS _c0", "b")
)
util.verifyTable(resultTable, expected)
@@ -212,7 +212,7 @@ class FieldProjectionTest extends TableTestBase {
unaryNode(
"DataSetCalc",
batchTableNode(0),
- term("select", "a", "c", s"${MyHashCode.getClass.getCanonicalName}(c) AS k")
+ term("select", "a", "c", s"${MyHashCode.functionIdentifier}(c) AS k")
),
term("groupBy", "k"),
term("select", "k", "SUM(a) AS TMP_0")
http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/UserDefinedTableFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/UserDefinedTableFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/UserDefinedTableFunctionTest.scala
index f8d9c92..2dbcccf 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/UserDefinedTableFunctionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/UserDefinedTableFunctionTest.scala
@@ -120,7 +120,7 @@ class UserDefinedTableFunctionTest extends TableTestBase {
unaryNode(
"DataSetCorrelate",
batchTableNode(0),
- term("invocation", s"$function($$2)"),
+ term("invocation", s"${function.functionIdentifier}($$2)"),
term("function", function),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) s)"),
@@ -140,7 +140,7 @@ class UserDefinedTableFunctionTest extends TableTestBase {
unaryNode(
"DataSetCorrelate",
batchTableNode(0),
- term("invocation", s"$function($$2, '$$')"),
+ term("invocation", s"${function.functionIdentifier}($$2, '$$')"),
term("function", function),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) s)"),
@@ -165,7 +165,7 @@ class UserDefinedTableFunctionTest extends TableTestBase {
unaryNode(
"DataSetCorrelate",
batchTableNode(0),
- term("invocation", s"$function($$2)"),
+ term("invocation", s"${function.functionIdentifier}($$2)"),
term("function", function),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) s)"),
http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/UserDefinedTableFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/UserDefinedTableFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/UserDefinedTableFunctionTest.scala
index 168f9ec..56b9fdb 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/UserDefinedTableFunctionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/UserDefinedTableFunctionTest.scala
@@ -183,7 +183,7 @@ class UserDefinedTableFunctionTest extends TableTestBase {
unaryNode(
"DataStreamCorrelate",
streamTableNode(0),
- term("invocation", s"$function($$2)"),
+ term("invocation", s"${function.functionIdentifier}($$2)"),
term("function", function),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) s)"),
@@ -203,7 +203,7 @@ class UserDefinedTableFunctionTest extends TableTestBase {
unaryNode(
"DataStreamCorrelate",
streamTableNode(0),
- term("invocation", s"$function($$2, '$$')"),
+ term("invocation", s"${function.functionIdentifier}($$2, '$$')"),
term("function", function),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) s)"),
@@ -228,7 +228,7 @@ class UserDefinedTableFunctionTest extends TableTestBase {
unaryNode(
"DataStreamCorrelate",
streamTableNode(0),
- term("invocation", s"$function($$2)"),
+ term("invocation", s"${function.functionIdentifier}($$2)"),
term("function", function),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) s)"),
@@ -253,7 +253,7 @@ class UserDefinedTableFunctionTest extends TableTestBase {
unaryNode(
"DataStreamCorrelate",
streamTableNode(0),
- term("invocation", s"$function($$2)"),
+ term("invocation", s"${function.functionIdentifier}($$2)"),
term("function", function),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
@@ -277,7 +277,7 @@ class UserDefinedTableFunctionTest extends TableTestBase {
val expected = unaryNode(
"DataStreamCorrelate",
streamTableNode(0),
- term("invocation", s"$function($$2)"),
+ term("invocation", s"${function.functionIdentifier}($$2)"),
term("function", function),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," +
@@ -299,7 +299,7 @@ class UserDefinedTableFunctionTest extends TableTestBase {
val expected = unaryNode(
"DataStreamCorrelate",
streamTableNode(0),
- term("invocation", s"$function($$2)"),
+ term("invocation", s"${function.functionIdentifier}($$2)"),
term("function", function),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
@@ -326,7 +326,7 @@ class UserDefinedTableFunctionTest extends TableTestBase {
unaryNode(
"DataStreamCorrelate",
streamTableNode(0),
- term("invocation", s"$function($$2)"),
+ term("invocation", s"${function.functionIdentifier}($$2)"),
term("function", function),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
@@ -351,7 +351,7 @@ class UserDefinedTableFunctionTest extends TableTestBase {
val expected = unaryNode(
"DataStreamCorrelate",
streamTableNode(0),
- term("invocation", s"$function(SUBSTRING($$2, 2, CHAR_LENGTH($$2)))"),
+ term("invocation", s"${function.functionIdentifier}(SUBSTRING($$2, 2, CHAR_LENGTH($$2)))"),
term("function", function),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) s)"),
http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
index f0b347d..4fee3b2 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
@@ -214,3 +214,10 @@ class RichFunc3 extends ScalarFunction {
words.clear()
}
}
+
+class Func13(prefix: String) extends ScalarFunction {
+ def eval(a: String): String = {
+ s"$prefix-$a"
+ }
+}
+
http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala
deleted file mode 100644
index cd1ffb5..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala
+++ /dev/null
@@ -1,241 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.runtime.dataset
-
-import java.sql.{Date, Timestamp}
-
-import org.apache.flink.api.scala._
-import org.apache.flink.api.scala.util.CollectionDataSets
-import org.apache.flink.table.api.TableEnvironment
-import org.apache.flink.table.api.java.utils.UserDefinedTableFunctions.JavaTableFunc0
-import org.apache.flink.table.api.scala._
-import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
-import org.apache.flink.table.api.scala.batch.utils.TableProgramsClusterTestBase
-import org.apache.flink.table.expressions.utils.RichFunc2
-import org.apache.flink.table.utils._
-import org.apache.flink.test.util.TestBaseUtils
-import org.apache.flink.types.Row
-import org.junit.Test
-import org.junit.runner.RunWith
-import org.junit.runners.Parameterized
-
-import scala.collection.JavaConverters._
-import scala.collection.mutable
-
-@RunWith(classOf[Parameterized])
-class DataSetCorrelateITCase(
- configMode: TableConfigMode)
- extends TableProgramsClusterTestBase(configMode) {
-
- @Test
- def testCrossJoin(): Unit = {
- val env = ExecutionEnvironment.getExecutionEnvironment
- val tableEnv = TableEnvironment.getTableEnvironment(env, config)
- val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
-
- val func1 = new TableFunc1
- val result = in.join(func1('c) as 's).select('c, 's).toDataSet[Row]
- val results = result.collect()
- val expected = "Jack#22,Jack\n" + "Jack#22,22\n" + "John#19,John\n" + "John#19,19\n" +
- "Anna#44,Anna\n" + "Anna#44,44\n"
- TestBaseUtils.compareResultAsText(results.asJava, expected)
-
- // with overloading
- val result2 = in.join(func1('c, "$") as 's).select('c, 's).toDataSet[Row]
- val results2 = result2.collect()
- val expected2 = "Jack#22,$Jack\n" + "Jack#22,$22\n" + "John#19,$John\n" +
- "John#19,$19\n" + "Anna#44,$Anna\n" + "Anna#44,$44\n"
- TestBaseUtils.compareResultAsText(results2.asJava, expected2)
- }
-
- @Test
- def testLeftOuterJoin(): Unit = {
- val env = ExecutionEnvironment.getExecutionEnvironment
- val tableEnv = TableEnvironment.getTableEnvironment(env, config)
- val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
-
- val func2 = new TableFunc2
- val result = in.leftOuterJoin(func2('c) as ('s, 'l)).select('c, 's, 'l).toDataSet[Row]
- val results = result.collect()
- val expected = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" +
- "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n" + "nosharp,null,null"
- TestBaseUtils.compareResultAsText(results.asJava, expected)
- }
-
- @Test
- def testWithFilter(): Unit = {
- val env = ExecutionEnvironment.getExecutionEnvironment
- val tableEnv = TableEnvironment.getTableEnvironment(env, config)
- val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
- val func0 = new TableFunc0
-
- val result = in
- .join(func0('c) as ('name, 'age))
- .select('c, 'name, 'age)
- .filter('age > 20)
- .toDataSet[Row]
-
- val results = result.collect()
- val expected = "Jack#22,Jack,22\n" + "Anna#44,Anna,44\n"
- TestBaseUtils.compareResultAsText(results.asJava, expected)
- }
-
- @Test
- def testCustomReturnType(): Unit = {
- val env = ExecutionEnvironment.getExecutionEnvironment
- val tableEnv = TableEnvironment.getTableEnvironment(env, config)
- val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
- val func2 = new TableFunc2
-
- val result = in
- .join(func2('c) as ('name, 'len))
- .select('c, 'name, 'len)
- .toDataSet[Row]
-
- val results = result.collect()
- val expected = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" +
- "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n"
- TestBaseUtils.compareResultAsText(results.asJava, expected)
- }
-
- @Test
- def testHierarchyType(): Unit = {
- val env = ExecutionEnvironment.getExecutionEnvironment
- val tableEnv = TableEnvironment.getTableEnvironment(env, config)
- val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
-
- val hierarchy = new HierarchyTableFunction
- val result = in
- .join(hierarchy('c) as ('name, 'adult, 'len))
- .select('c, 'name, 'adult, 'len)
- .toDataSet[Row]
-
- val results = result.collect()
- val expected = "Jack#22,Jack,true,22\n" + "John#19,John,false,19\n" +
- "Anna#44,Anna,true,44\n"
- TestBaseUtils.compareResultAsText(results.asJava, expected)
- }
-
- @Test
- def testPojoType(): Unit = {
- val env = ExecutionEnvironment.getExecutionEnvironment
- val tableEnv = TableEnvironment.getTableEnvironment(env, config)
- val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
-
- val pojo = new PojoTableFunc()
- val result = in
- .join(pojo('c))
- .select('c, 'name, 'age)
- .toDataSet[Row]
-
- val results = result.collect()
- val expected = "Jack#22,Jack,22\n" + "John#19,John,19\n" + "Anna#44,Anna,44\n"
- TestBaseUtils.compareResultAsText(results.asJava, expected)
- }
-
- @Test
- def testUserDefinedTableFunctionWithScalarFunction(): Unit = {
- val env = ExecutionEnvironment.getExecutionEnvironment
- val tableEnv = TableEnvironment.getTableEnvironment(env, config)
- val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
- val func1 = new TableFunc1
-
- val result = in
- .join(func1('c.substring(2)) as 's)
- .select('c, 's)
- .toDataSet[Row]
-
- val results = result.collect()
- val expected = "Jack#22,ack\n" + "Jack#22,22\n" + "John#19,ohn\n" + "John#19,19\n" +
- "Anna#44,nna\n" + "Anna#44,44\n"
- TestBaseUtils.compareResultAsText(results.asJava, expected)
- }
-
- @Test
- def testLongAndTemporalTypes(): Unit = {
- val env = ExecutionEnvironment.getExecutionEnvironment
- val tableEnv = TableEnvironment.getTableEnvironment(env, config)
- val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
- val func0 = new JavaTableFunc0
-
- val result = in
- .where('a === 1)
- .select(Date.valueOf("1990-10-14") as 'x,
- 1000L as 'y,
- Timestamp.valueOf("1990-10-14 12:10:10") as 'z)
- .join(func0('x, 'y, 'z) as 's)
- .select('s)
- .toDataSet[Row]
-
- val results = result.collect()
- val expected = "1000\n" + "655906210000\n" + "7591\n"
- TestBaseUtils.compareResultAsText(results.asJava, expected)
- }
-
- @Test
- def testUserDefinedTableFunctionWithParameter(): Unit = {
- val env = ExecutionEnvironment.getExecutionEnvironment
- val tEnv = TableEnvironment.getTableEnvironment(env)
- val richTableFunc1 = new RichTableFunc1
- tEnv.registerFunction("RichTableFunc1", richTableFunc1)
- UserDefinedFunctionTestUtils.setJobParameters(env, Map("word_separator" -> "#"))
-
- val result = testData(env)
- .toTable(tEnv, 'a, 'b, 'c)
- .join(richTableFunc1('c) as 's)
- .select('a, 's)
-
- val expected = "1,Jack\n" + "1,22\n" + "2,John\n" + "2,19\n" + "3,Anna\n" + "3,44"
- val results = result.toDataSet[Row].collect()
- TestBaseUtils.compareResultAsText(results.asJava, expected)
- }
-
- @Test
- def testUserDefinedTableFunctionWithScalarFunctionWithParameters(): Unit = {
- val env = ExecutionEnvironment.getExecutionEnvironment
- val tEnv = TableEnvironment.getTableEnvironment(env)
- val richTableFunc1 = new RichTableFunc1
- tEnv.registerFunction("RichTableFunc1", richTableFunc1)
- val richFunc2 = new RichFunc2
- tEnv.registerFunction("RichFunc2", richFunc2)
- UserDefinedFunctionTestUtils.setJobParameters(
- env,
- Map("word_separator" -> "#", "string.value" -> "test"))
-
- val result = CollectionDataSets.getSmall3TupleDataSet(env)
- .toTable(tEnv, 'a, 'b, 'c)
- .join(richTableFunc1(richFunc2('c)) as 's)
- .select('a, 's)
-
- val expected = "1,Hi\n1,test\n2,Hello\n2,test\n3,Hello world\n3,test"
- val results = result.toDataSet[Row].collect()
- TestBaseUtils.compareResultAsText(results.asJava, expected)
- }
-
- private def testData(
- env: ExecutionEnvironment)
- : DataSet[(Int, Long, String)] = {
-
- val data = new mutable.MutableList[(Int, Long, String)]
- data.+=((1, 1L, "Jack#22"))
- data.+=((2, 2L, "John#19"))
- data.+=((3, 2L, "Anna#44"))
- data.+=((4, 3L, "nosharp"))
- env.fromCollection(data)
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala
new file mode 100644
index 0000000..d268594
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala
@@ -0,0 +1,288 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.dataset
+
+import java.sql.{Date, Timestamp}
+
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.util.CollectionDataSets
+import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.api.java.utils.UserDefinedTableFunctions.JavaTableFunc0
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsClusterTestBase
+import org.apache.flink.table.expressions.utils.{Func13, RichFunc2}
+import org.apache.flink.table.utils._
+import org.apache.flink.test.util.TestBaseUtils
+import org.apache.flink.types.Row
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+@RunWith(classOf[Parameterized])
+class DataSetUserDefinedFunctionITCase(
+ configMode: TableConfigMode)
+ extends TableProgramsClusterTestBase(configMode) {
+
+ @Test
+ def testCrossJoin(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv = TableEnvironment.getTableEnvironment(env, config)
+ val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
+
+ val func1 = new TableFunc1
+ val result = in.join(func1('c) as 's).select('c, 's).toDataSet[Row]
+ val results = result.collect()
+ val expected = "Jack#22,Jack\n" + "Jack#22,22\n" + "John#19,John\n" + "John#19,19\n" +
+ "Anna#44,Anna\n" + "Anna#44,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+
+ // with overloading
+ val result2 = in.join(func1('c, "$") as 's).select('c, 's).toDataSet[Row]
+ val results2 = result2.collect()
+ val expected2 = "Jack#22,$Jack\n" + "Jack#22,$22\n" + "John#19,$John\n" +
+ "John#19,$19\n" + "Anna#44,$Anna\n" + "Anna#44,$44\n"
+ TestBaseUtils.compareResultAsText(results2.asJava, expected2)
+ }
+
+ @Test
+ def testLeftOuterJoin(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv = TableEnvironment.getTableEnvironment(env, config)
+ val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
+
+ val func2 = new TableFunc2
+ val result = in.leftOuterJoin(func2('c) as ('s, 'l)).select('c, 's, 'l).toDataSet[Row]
+ val results = result.collect()
+ val expected = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" +
+ "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n" + "nosharp,null,null"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testWithFilter(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv = TableEnvironment.getTableEnvironment(env, config)
+ val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
+ val func0 = new TableFunc0
+
+ val result = in
+ .join(func0('c) as ('name, 'age))
+ .select('c, 'name, 'age)
+ .filter('age > 20)
+ .toDataSet[Row]
+
+ val results = result.collect()
+ val expected = "Jack#22,Jack,22\n" + "Anna#44,Anna,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testCustomReturnType(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv = TableEnvironment.getTableEnvironment(env, config)
+ val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
+ val func2 = new TableFunc2
+
+ val result = in
+ .join(func2('c) as ('name, 'len))
+ .select('c, 'name, 'len)
+ .toDataSet[Row]
+
+ val results = result.collect()
+ val expected = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" +
+ "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testHierarchyType(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv = TableEnvironment.getTableEnvironment(env, config)
+ val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
+
+ val hierarchy = new HierarchyTableFunction
+ val result = in
+ .join(hierarchy('c) as ('name, 'adult, 'len))
+ .select('c, 'name, 'adult, 'len)
+ .toDataSet[Row]
+
+ val results = result.collect()
+ val expected = "Jack#22,Jack,true,22\n" + "John#19,John,false,19\n" +
+ "Anna#44,Anna,true,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testPojoType(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv = TableEnvironment.getTableEnvironment(env, config)
+ val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
+
+ val pojo = new PojoTableFunc()
+ val result = in
+ .join(pojo('c))
+ .select('c, 'name, 'age)
+ .toDataSet[Row]
+
+ val results = result.collect()
+ val expected = "Jack#22,Jack,22\n" + "John#19,John,19\n" + "Anna#44,Anna,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testUserDefinedTableFunctionWithScalarFunction(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv = TableEnvironment.getTableEnvironment(env, config)
+ val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
+ val func1 = new TableFunc1
+
+ val result = in
+ .join(func1('c.substring(2)) as 's)
+ .select('c, 's)
+ .toDataSet[Row]
+
+ val results = result.collect()
+ val expected = "Jack#22,ack\n" + "Jack#22,22\n" + "John#19,ohn\n" + "John#19,19\n" +
+ "Anna#44,nna\n" + "Anna#44,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testLongAndTemporalTypes(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv = TableEnvironment.getTableEnvironment(env, config)
+ val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
+ val func0 = new JavaTableFunc0
+
+ val result = in
+ .where('a === 1)
+ .select(Date.valueOf("1990-10-14") as 'x,
+ 1000L as 'y,
+ Timestamp.valueOf("1990-10-14 12:10:10") as 'z)
+ .join(func0('x, 'y, 'z) as 's)
+ .select('s)
+ .toDataSet[Row]
+
+ val results = result.collect()
+ val expected = "1000\n" + "655906210000\n" + "7591\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testUserDefinedTableFunctionWithParameter(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env, config)
+ val richTableFunc1 = new RichTableFunc1
+ tEnv.registerFunction("RichTableFunc1", richTableFunc1)
+ UserDefinedFunctionTestUtils.setJobParameters(env, Map("word_separator" -> "#"))
+
+ val result = testData(env)
+ .toTable(tEnv, 'a, 'b, 'c)
+ .join(richTableFunc1('c) as 's)
+ .select('a, 's)
+
+ val expected = "1,Jack\n" + "1,22\n" + "2,John\n" + "2,19\n" + "3,Anna\n" + "3,44"
+ val results = result.toDataSet[Row].collect()
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testUserDefinedTableFunctionWithScalarFunctionWithParameters(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env, config)
+ val richTableFunc1 = new RichTableFunc1
+ tEnv.registerFunction("RichTableFunc1", richTableFunc1)
+ val richFunc2 = new RichFunc2
+ tEnv.registerFunction("RichFunc2", richFunc2)
+ UserDefinedFunctionTestUtils.setJobParameters(
+ env,
+ Map("word_separator" -> "#", "string.value" -> "test"))
+
+ val result = CollectionDataSets.getSmall3TupleDataSet(env)
+ .toTable(tEnv, 'a, 'b, 'c)
+ .join(richTableFunc1(richFunc2('c)) as 's)
+ .select('a, 's)
+
+ val expected = "1,Hi\n1,test\n2,Hello\n2,test\n3,Hello world\n3,test"
+ val results = result.toDataSet[Row].collect()
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testTableFunctionConstructorWithParams(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv = TableEnvironment.getTableEnvironment(env, config)
+ val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
+ val func30 = new TableFunc3(null)
+ val func31 = new TableFunc3("OneConf_")
+ val func32 = new TableFunc3("TwoConf_")
+
+ val result = in
+ .join(func30('c) as ('d, 'e))
+ .select('c, 'd, 'e)
+ .join(func31('c) as ('f, 'g))
+ .select('c, 'd, 'e, 'f, 'g)
+ .join(func32('c) as ('h, 'i))
+ .select('c, 'd, 'f, 'h, 'e, 'g, 'i)
+ .toDataSet[Row]
+
+ val results = result.collect()
+
+ val expected = "Anna#44,Anna,OneConf_Anna,TwoConf_Anna,44,44,44\n" +
+ "Jack#22,Jack,OneConf_Jack,TwoConf_Jack,22,22,22\n" +
+ "John#19,John,OneConf_John,TwoConf_John,19,19,19\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testScalarFunctionConstructorWithParams(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv = TableEnvironment.getTableEnvironment(env, config)
+ val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
+
+ val func0 = new Func13("default")
+ val func1 = new Func13("Sunny")
+ val func2 = new Func13("kevin2")
+
+ val result = in.select(func0('c), func1('c), func2('c))
+
+ val results = result.collect()
+
+ val expected = "default-Anna#44,Sunny-Anna#44,kevin2-Anna#44\n" +
+ "default-Jack#22,Sunny-Jack#22,kevin2-Jack#22\n" +
+ "default-John#19,Sunny-John#19,kevin2-John#19\n" +
+ "default-nosharp,Sunny-nosharp,kevin2-nosharp"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ private def testData(
+ env: ExecutionEnvironment)
+ : DataSet[(Int, Long, String)] = {
+
+ val data = new mutable.MutableList[(Int, Long, String)]
+ data.+=((1, 1L, "Jack#22"))
+ data.+=((2, 2L, "John#19"))
+ data.+=((3, 2L, "Anna#44"))
+ data.+=((4, 3L, "nosharp"))
+ env.fromCollection(data)
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataSetUserDefinedFunctionITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataSetUserDefinedFunctionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataSetUserDefinedFunctionITCase.scala
new file mode 100644
index 0000000..21b87e9
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataSetUserDefinedFunctionITCase.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.datastream
+
+import org.apache.flink.api.scala._
+import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
+import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
+import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData}
+import org.apache.flink.table.expressions.utils.{Func13, RichFunc2}
+import org.apache.flink.table.utils.{RichTableFunc1, TableFunc0, TableFunc3, UserDefinedFunctionTestUtils}
+import org.apache.flink.types.Row
+import org.junit.Assert._
+import org.junit.Test
+
+import scala.collection.mutable
+
+class DataSetUserDefinedFunctionITCase extends StreamingMultipleProgramsTestBase {
+
+ @Test
+ def testCrossJoin(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.clear
+
+ val t = testData(env).toTable(tEnv).as('a, 'b, 'c)
+ val func0 = new TableFunc0
+
+ val result = t
+ .join(func0('c) as ('d, 'e))
+ .select('c, 'd, 'e)
+ .toDataStream[Row]
+
+ result.addSink(new StreamITCase.StringSink)
+ env.execute()
+
+ val expected = mutable.MutableList("Jack#22,Jack,22", "John#19,John,19", "Anna#44,Anna,44")
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
+
+ @Test
+ def testLeftOuterJoin(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.clear
+
+ val t = testData(env).toTable(tEnv).as('a, 'b, 'c)
+ val func0 = new TableFunc0
+
+ val result = t
+ .leftOuterJoin(func0('c) as ('d, 'e))
+ .select('c, 'd, 'e)
+ .toDataStream[Row]
+
+ result.addSink(new StreamITCase.StringSink)
+ env.execute()
+
+ val expected = mutable.MutableList(
+ "nosharp,null,null", "Jack#22,Jack,22",
+ "John#19,John,19", "Anna#44,Anna,44")
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
+
+ @Test
+ def testUserDefinedTableFunctionWithParameter(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ val tableFunc1 = new RichTableFunc1
+ tEnv.registerFunction("RichTableFunc1", tableFunc1)
+ UserDefinedFunctionTestUtils.setJobParameters(env, Map("word_separator" -> " "))
+ StreamITCase.testResults = mutable.MutableList()
+
+ val result = StreamTestData.getSmall3TupleDataStream(env)
+ .toTable(tEnv, 'a, 'b, 'c)
+ .join(tableFunc1('c) as 's)
+ .select('a, 's)
+
+ val results = result.toDataStream[Row]
+ results.addSink(new StreamITCase.StringSink)
+ env.execute()
+
+ val expected = mutable.MutableList("3,Hello", "3,world")
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
+
+ @Test
+ def testUserDefinedTableFunctionWithUserDefinedScalarFunction(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ val tableFunc1 = new RichTableFunc1
+ val richFunc2 = new RichFunc2
+ tEnv.registerFunction("RichTableFunc1", tableFunc1)
+ tEnv.registerFunction("RichFunc2", richFunc2)
+ UserDefinedFunctionTestUtils.setJobParameters(
+ env,
+ Map("word_separator" -> "#", "string.value" -> "test"))
+ StreamITCase.testResults = mutable.MutableList()
+
+ val result = StreamTestData.getSmall3TupleDataStream(env)
+ .toTable(tEnv, 'a, 'b, 'c)
+ .join(tableFunc1(richFunc2('c)) as 's)
+ .select('a, 's)
+
+ val results = result.toDataStream[Row]
+ results.addSink(new StreamITCase.StringSink)
+ env.execute()
+
+ val expected = mutable.MutableList(
+ "1,Hi",
+ "1,test",
+ "2,Hello",
+ "2,test",
+ "3,Hello world",
+ "3,test")
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
+
+ @Test
+ def testTableFunctionConstructorWithParams(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.clear
+
+ val t = testData(env).toTable(tEnv).as('a, 'b, 'c)
+ val config = Map("key1" -> "value1", "key2" -> "value2")
+ val func30 = new TableFunc3(null)
+ val func31 = new TableFunc3("OneConf_")
+ val func32 = new TableFunc3("TwoConf_", config)
+
+ val result = t
+ .join(func30('c) as ('d, 'e))
+ .select('c, 'd, 'e)
+ .join(func31('c) as ('f, 'g))
+ .select('c, 'd, 'e, 'f, 'g)
+ .join(func32('c) as ('h, 'i))
+ .select('c, 'd, 'f, 'h, 'e, 'g, 'i)
+ .toDataStream[Row]
+
+ result.addSink(new StreamITCase.StringSink)
+ env.execute()
+
+ val expected = mutable.MutableList(
+ "Anna#44,Anna,OneConf_Anna,TwoConf__key=key1_value=value1_Anna,44,44,44",
+ "Anna#44,Anna,OneConf_Anna,TwoConf__key=key2_value=value2_Anna,44,44,44",
+ "Jack#22,Jack,OneConf_Jack,TwoConf__key=key1_value=value1_Jack,22,22,22",
+ "Jack#22,Jack,OneConf_Jack,TwoConf__key=key2_value=value2_Jack,22,22,22",
+ "John#19,John,OneConf_John,TwoConf__key=key1_value=value1_John,19,19,19",
+ "John#19,John,OneConf_John,TwoConf__key=key2_value=value2_John,19,19,19"
+ )
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
+
+ @Test
+ def testScalarFunctionConstructorWithParams(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.clear
+
+ val t = testData(env).toTable(tEnv).as('a, 'b, 'c)
+ val func0 = new Func13("default")
+ val func1 = new Func13("Sunny")
+ val func2 = new Func13("kevin2")
+
+ val result = t.select(func0('c), func1('c), func2('c))
+
+ result.addSink(new StreamITCase.StringSink)
+ env.execute()
+
+ val expected = mutable.MutableList(
+ "default-Anna#44,Sunny-Anna#44,kevin2-Anna#44",
+ "default-Jack#22,Sunny-Jack#22,kevin2-Jack#22",
+ "default-John#19,Sunny-John#19,kevin2-John#19",
+ "default-nosharp,Sunny-nosharp,kevin2-nosharp"
+ )
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
+
+ private def testData(
+ env: StreamExecutionEnvironment)
+ : DataStream[(Int, Long, String)] = {
+
+ val data = new mutable.MutableList[(Int, Long, String)]
+ data.+=((1, 1L, "Jack#22"))
+ data.+=((2, 2L, "John#19"))
+ data.+=((3, 2L, "Anna#44"))
+ data.+=((4, 3L, "nosharp"))
+ env.fromCollection(data)
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala
deleted file mode 100644
index f8a697d..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala
+++ /dev/null
@@ -1,146 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.runtime.datastream
-
-import org.apache.flink.api.scala._
-import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
-import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
-import org.apache.flink.table.api.TableEnvironment
-import org.apache.flink.table.api.scala._
-import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData}
-import org.apache.flink.table.expressions.utils.RichFunc2
-import org.apache.flink.table.utils.{RichTableFunc1, TableFunc0, UserDefinedFunctionTestUtils}
-import org.apache.flink.types.Row
-import org.junit.Assert._
-import org.junit.Test
-
-import scala.collection.mutable
-
-class DataStreamCorrelateITCase extends StreamingMultipleProgramsTestBase {
-
- @Test
- def testCrossJoin(): Unit = {
- val env = StreamExecutionEnvironment.getExecutionEnvironment
- val tEnv = TableEnvironment.getTableEnvironment(env)
- StreamITCase.clear
-
- val t = testData(env).toTable(tEnv).as('a, 'b, 'c)
- val func0 = new TableFunc0
-
- val result = t
- .join(func0('c) as('d, 'e))
- .select('c, 'd, 'e)
- .toDataStream[Row]
-
- result.addSink(new StreamITCase.StringSink)
- env.execute()
-
- val expected = mutable.MutableList("Jack#22,Jack,22", "John#19,John,19", "Anna#44,Anna,44")
- assertEquals(expected.sorted, StreamITCase.testResults.sorted)
- }
-
- @Test
- def testLeftOuterJoin(): Unit = {
- val env = StreamExecutionEnvironment.getExecutionEnvironment
- val tEnv = TableEnvironment.getTableEnvironment(env)
- StreamITCase.clear
-
- val t = testData(env).toTable(tEnv).as('a, 'b, 'c)
- val func0 = new TableFunc0
-
- val result = t
- .leftOuterJoin(func0('c) as('d, 'e))
- .select('c, 'd, 'e)
- .toDataStream[Row]
-
- result.addSink(new StreamITCase.StringSink)
- env.execute()
-
- val expected = mutable.MutableList(
- "nosharp,null,null", "Jack#22,Jack,22",
- "John#19,John,19", "Anna#44,Anna,44")
- assertEquals(expected.sorted, StreamITCase.testResults.sorted)
- }
-
- @Test
- def testUserDefinedTableFunctionWithParameter(): Unit = {
- val env = StreamExecutionEnvironment.getExecutionEnvironment
- val tEnv = TableEnvironment.getTableEnvironment(env)
- val tableFunc1 = new RichTableFunc1
- tEnv.registerFunction("RichTableFunc1", tableFunc1)
- UserDefinedFunctionTestUtils.setJobParameters(env, Map("word_separator" -> " "))
- StreamITCase.testResults = mutable.MutableList()
-
- val result = StreamTestData.getSmall3TupleDataStream(env)
- .toTable(tEnv, 'a, 'b, 'c)
- .join(tableFunc1('c) as 's)
- .select('a, 's)
-
- val results = result.toDataStream[Row]
- results.addSink(new StreamITCase.StringSink)
- env.execute()
-
- val expected = mutable.MutableList("3,Hello", "3,world")
- assertEquals(expected.sorted, StreamITCase.testResults.sorted)
- }
-
- @Test
- def testUserDefinedTableFunctionWithUserDefinedScalarFunction(): Unit = {
- val env = StreamExecutionEnvironment.getExecutionEnvironment
- val tEnv = TableEnvironment.getTableEnvironment(env)
- val tableFunc1 = new RichTableFunc1
- val richFunc2 = new RichFunc2
- tEnv.registerFunction("RichTableFunc1", tableFunc1)
- tEnv.registerFunction("RichFunc2", richFunc2)
- UserDefinedFunctionTestUtils.setJobParameters(
- env,
- Map("word_separator" -> "#", "string.value" -> "test"))
- StreamITCase.testResults = mutable.MutableList()
-
- val result = StreamTestData.getSmall3TupleDataStream(env)
- .toTable(tEnv, 'a, 'b, 'c)
- .join(tableFunc1(richFunc2('c)) as 's)
- .select('a, 's)
-
- val results = result.toDataStream[Row]
- results.addSink(new StreamITCase.StringSink)
- env.execute()
-
- val expected = mutable.MutableList(
- "1,Hi",
- "1,test",
- "2,Hello",
- "2,test",
- "3,Hello world",
- "3,test")
- assertEquals(expected.sorted, StreamITCase.testResults.sorted)
- }
-
- private def testData(
- env: StreamExecutionEnvironment)
- : DataStream[(Int, Long, String)] = {
-
- val data = new mutable.MutableList[(Int, Long, String)]
- data.+=((1, 1L, "Jack#22"))
- data.+=((2, 2L, "John#19"))
- data.+=((3, 2L, "Anna#44"))
- data.+=((4, 3L, "nosharp"))
- env.fromCollection(data)
- }
-
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala
index 5db9d5f..88917a2 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala
@@ -54,7 +54,6 @@ class TableFunc1 extends TableFunction[String] {
}
}
-
class TableFunc2 extends TableFunction[Row] {
def eval(str: String): Unit = {
if (str.contains("#")) {
@@ -73,6 +72,41 @@ class TableFunc2 extends TableFunction[Row] {
}
}
+class TableFunc3(data: String, conf: Map[String, String]) extends TableFunction[SimpleUser] {
+
+ def this(data: String) {
+ this(data, null)
+ }
+
+ def eval(user: String): Unit = {
+ if (user.contains("#")) {
+ val splits = user.split("#")
+ if (null != data) {
+ if (null != conf && conf.nonEmpty) {
+ val it = conf.keys.iterator
+ while (it.hasNext) {
+ val key = it.next()
+ val value = conf(key)
+ collect(
+ SimpleUser(
+ data.concat("_key=")
+ .concat(key)
+ .concat("_value=")
+ .concat(value)
+ .concat("_")
+ .concat(splits(0)),
+ splits(1).toInt))
+ }
+ } else {
+ collect(SimpleUser(data.concat(splits(0)), splits(1).toInt))
+ }
+ } else {
+ collect(SimpleUser(splits(0), splits(1).toInt))
+ }
+ }
+ }
+}
+
class HierarchyTableFunction extends SplittableTableFunction[Boolean, Integer] {
def eval(user: String) {
if (user.contains("#")) {