You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tw...@apache.org on 2017/03/13 09:33:39 UTC
flink git commit: [FLINK-5881] [table] ScalarFunction(UDF) should
support variable types and variable arguments
Repository: flink
Updated Branches:
refs/heads/master 354a13edf -> 9b179beae
[FLINK-5881] [table] ScalarFunction(UDF) should support variable types and variable arguments
This closes #3389.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/9b179bea
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/9b179bea
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/9b179bea
Branch: refs/heads/master
Commit: 9b179beaea2b623ad3637e417f6d8014b696d038
Parents: 354a13e
Author: Zhuoluo Yang <zh...@alibaba-inc.com>
Authored: Wed Feb 22 18:53:34 2017 +0800
Committer: twalthr <tw...@apache.org>
Committed: Mon Mar 13 10:29:35 2017 +0100
----------------------------------------------------------------------
.../codegen/calls/ScalarFunctionCallGen.scala | 17 ++--
.../functions/utils/ScalarSqlFunction.scala | 26 ++++--
.../utils/UserDefinedFunctionUtils.scala | 74 +++++++++------
.../java/utils/UserDefinedScalarFunctions.java | 20 +++++
.../UserDefinedScalarFunctionTest.scala | 95 +++++++++++++++++++-
.../utils/UserDefinedScalarFunctions.scala | 36 ++++++++
6 files changed, 229 insertions(+), 39 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/9b179bea/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala
index 7ff18eb..b0b4e09 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala
@@ -43,15 +43,22 @@ class ScalarFunctionCallGen(
codeGenerator: CodeGenerator,
operands: Seq[GeneratedExpression])
: GeneratedExpression = {
- // determine function signature and result class
- val matchingSignature = getSignature(scalarFunction, signature)
+ // determine function method and result class
+ val matchingMethod = getEvalMethod(scalarFunction, signature)
.getOrElse(throw new CodeGenException("No matching signature found."))
+ val matchingSignature = matchingMethod.getParameterTypes
val resultClass = getResultTypeClass(scalarFunction, matchingSignature)
+ // pair each operand with its parameter class; operands beyond the signature use the varargs component type
+ var paramToOperands = matchingSignature.zip(operands)
+ if (operands.length > matchingSignature.length) {
+ operands.drop(matchingSignature.length).foreach(op =>
+ paramToOperands = paramToOperands :+ (matchingSignature.last.getComponentType, op)
+ )
+ }
+
// convert parameters for function (output boxing)
- val parameters = matchingSignature
- .zip(operands)
- .map { case (paramClass, operandExpr) =>
+ val parameters = paramToOperands.map { case (paramClass, operandExpr) =>
if (paramClass.isPrimitive) {
operandExpr
} else if (ClassUtils.isPrimitiveWrapper(paramClass)
http://git-wip-us.apache.org/repos/asf/flink/blob/9b179bea/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
index dc6d41f..e2cd272 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
@@ -113,9 +113,15 @@ object ScalarSqlFunction {
.getParameterTypes(foundSignature)
.map(typeFactory.createTypeFromTypeInfo)
- inferredTypes.zipWithIndex.foreach {
- case (inferredType, i) =>
- operandTypes(i) = inferredType
+ for (i <- operandTypes.indices) {
+ if (i < inferredTypes.length - 1) {
+ operandTypes(i) = inferredTypes(i)
+ } else if (null != inferredTypes.last.getComponentType) {
+ // last parameter has a component type (an array, e.g. varargs): operands from here on take that component type
+ operandTypes(i) = inferredTypes.last.getComponentType
+ } else {
+ operandTypes(i) = inferredTypes.last
+ }
}
}
}
@@ -137,8 +143,18 @@ object ScalarSqlFunction {
}
override def getOperandCountRange: SqlOperandCountRange = { // operand count range accepted across all eval signatures
- val signatureLengths = signatures.map(_.length)
- SqlOperandCountRanges.between(signatureLengths.min, signatureLengths.max)
+ var min = 255 // sentinel, lowered by every signature below
+ var max = -1 // sentinel, raised by every signature below
+ signatures.foreach( sig => {
+ var len = sig.length
+ if (len > 0 && sig(sig.length - 1).isArray) { // NOTE(review): any trailing array parameter is treated as varargs, declared varargs or not - confirm
+ max = 254 // varargs: raise the upper bound to 254, per the method parameter limit in JVM spec 4.3.3
+ len = sig.length - 1 // the trailing array may receive zero operands, so it does not count towards the minimum
+ }
+ max = Math.max(len, max)
+ min = Math.min(len, min)
+ })
+ SqlOperandCountRanges.between(min, max)
}
override def checkOperandTypes(
http://git-wip-us.apache.org/repos/asf/flink/blob/9b179bea/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
index 21d28b5..c1cfe06 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -78,20 +78,7 @@ object UserDefinedFunctionUtils {
function: UserDefinedFunction,
signature: Seq[TypeInformation[_]])
: Option[Array[Class[_]]] = {
- // We compare the raw Java classes not the TypeInformation.
- // TypeInformation does not matter during runtime (e.g. within a MapFunction).
- val actualSignature = typeInfoToClass(signature)
- val signatures = getSignatures(function)
-
- signatures
- // go over all signatures and find one matching actual signature
- .find { curSig =>
- // match parameters of signature to actual parameters
- actualSignature.length == curSig.length &&
- curSig.zipWithIndex.forall { case (clazz, i) =>
- parameterTypeEquals(actualSignature(i), clazz)
- }
- }
+ getEvalMethod(function, signature).map(_.getParameterTypes)
}
/**
@@ -106,16 +93,52 @@ object UserDefinedFunctionUtils {
val actualSignature = typeInfoToClass(signature)
val evalMethods = checkAndExtractEvalMethods(function)
- evalMethods
- // go over all eval methods and find one matching
- .find { cur =>
- val signatures = cur.getParameterTypes
- // match parameters of signature to actual parameters
- actualSignature.length == signatures.length &&
- signatures.zipWithIndex.forall { case (clazz, i) =>
- parameterTypeEquals(actualSignature(i), clazz)
+ val filtered = evalMethods
+ // go over all eval methods and filter out matching methods
+ .filter {
+ case cur if !cur.isVarArgs =>
+ val signatures = cur.getParameterTypes
+ // match parameters of signature to actual parameters
+ actualSignature.length == signatures.length &&
+ signatures.zipWithIndex.forall { case (clazz, i) =>
+ parameterTypeEquals(actualSignature(i), clazz)
+ }
+ case cur if cur.isVarArgs =>
+ val signatures = cur.getParameterTypes
+ actualSignature.zipWithIndex.forall {
+ // non-varargs
+ case (clazz, i) if i < signatures.length - 1 =>
+ parameterTypeEquals(clazz, signatures(i))
+ // varargs
+ case (clazz, i) if i >= signatures.length - 1 =>
+ parameterTypeEquals(clazz, signatures.last.getComponentType)
+ } || (actualSignature.isEmpty && signatures.length == 1) // empty varargs
+ }
+
+ // if a fixed-arity method matches, prefer it over varargs methods, as the compiler would
+ val fixedMethodsCount = filtered.count(!_.isVarArgs)
+ val found = filtered.filter { cur =>
+ fixedMethodsCount > 0 && !cur.isVarArgs ||
+ fixedMethodsCount == 0 && cur.isVarArgs
+ }
+
+ // detect Scala-style varargs declared without @varargs (last parameter is a scala.collection.Seq)
+ if (found.isEmpty &&
+ evalMethods.exists { evalMethod =>
+ val signatures = evalMethod.getParameterTypes
+ signatures.zipWithIndex.forall {
+ case (clazz, i) if i < signatures.length - 1 =>
+ parameterTypeEquals(actualSignature(i), clazz)
+ case (clazz, i) if i == signatures.length - 1 =>
+ clazz.getName.equals("scala.collection.Seq")
}
+ }) {
+ throw new ValidationException("Scala-style variable arguments in 'eval' methods are not " +
+ "supported. Please add a @scala.annotation.varargs annotation.")
+ } else if (found.length > 1) {
+ throw new ValidationException("Found multiple 'eval' methods which match the signature.")
}
+ found.headOption
}
/**
@@ -133,7 +156,7 @@ object UserDefinedFunctionUtils {
/**
* Extracts "eval" methods and throws a [[ValidationException]] if no implementation
- * can be found.
+ * can be found or if an implementation does not match the requirements.
*/
def checkAndExtractEvalMethods(function: UserDefinedFunction): Array[Method] = {
val methods = function
@@ -152,9 +175,9 @@ object UserDefinedFunctionUtils {
s"Function class '${function.getClass.getCanonicalName}' does not implement at least " +
s"one method named 'eval' which is public, not abstract and " +
s"(in case of table functions) not static.")
- } else {
- methods
}
+
+ methods
}
def getSignatures(function: UserDefinedFunction): Array[Array[Class[_]]] = {
@@ -317,6 +340,7 @@ object UserDefinedFunctionUtils {
private def parameterTypeEquals(candidate: Class[_], expected: Class[_]): Boolean =
candidate == null ||
candidate == expected ||
+ expected == classOf[Object] ||
expected.isPrimitive && Primitives.wrap(expected) == candidate ||
candidate == classOf[Date] && (expected == classOf[Int] || expected == classOf[JInt]) ||
candidate == classOf[Time] && (expected == classOf[Int] || expected == classOf[JInt]) ||
http://git-wip-us.apache.org/repos/asf/flink/blob/9b179bea/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java
index e817f06..56f866d 100644
--- a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java
+++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java
@@ -33,4 +33,24 @@ public class UserDefinedScalarFunctions {
}
}
+ public static class JavaFunc2 extends ScalarFunction { // test UDF: a String followed by boxed Integer varargs
+ public String eval(String s, Integer... a) { // returns s with the product of all varargs values appended
+ int m = 1; // the product of an empty varargs list stays 1
+ for (int n : a) { // elements are auto-unboxed; a null element would throw a NullPointerException
+ m *= n;
+ }
+ return s + m;
+ }
+ }
+
+ public static class JavaFunc3 extends ScalarFunction { // test UDF: a varargs eval overloaded with a fixed-arity eval
+ public int eval(String a, int... b) { // returns the number of varargs values
+ return b.length;
+ }
+ // Fixed-arity overload: returns its argument unchanged.
+ public String eval(String c) {
+ return c;
+ }
+ }
+
}
http://git-wip-us.apache.org/repos/asf/flink/blob/9b179bea/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
index a6c1760..51583c3 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
@@ -24,8 +24,8 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.types.Row
-import org.apache.flink.table.api.Types
-import org.apache.flink.table.api.java.utils.UserDefinedScalarFunctions.{JavaFunc0, JavaFunc1}
+import org.apache.flink.table.api.{Types, ValidationException}
+import org.apache.flink.table.api.java.utils.UserDefinedScalarFunctions.{JavaFunc0, JavaFunc1, JavaFunc2, JavaFunc3}
import org.apache.flink.table.api.scala._
import org.apache.flink.table.expressions.utils._
import org.apache.flink.table.functions.ScalarFunction
@@ -181,6 +181,85 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
}
@Test
+ def testVariableArgs(): Unit = { // covers varargs eval methods of Scala and Java scalar UDFs
+ testAllApis(
+ Func14(1, 2, 3, 4),
+ "Func14(1, 2, 3, 4)",
+ "Func14(1, 2, 3, 4)",
+ "10")
+
+ // No arguments: the varargs list is empty, so the sum is 0
+ testAllApis(
+ Func14(),
+ "Func14()",
+ "Func14()",
+ "0")
+
+ // Overloading: with a single String the fixed-arity eval(String) is chosen over the varargs eval
+ testAllApis(
+ Func15("Hello"),
+ "Func15('Hello')",
+ "Func15('Hello')",
+ "Hello"
+ )
+ // Same overload resolution when the String comes from field f1 ("Test")
+ testAllApis(
+ Func15('f1),
+ "Func15(f1)",
+ "Func15(f1)",
+ "Test"
+ )
+ // Extra Int arguments select the varargs eval, which appends the number of varargs values
+ testAllApis(
+ Func15("Hello", 1, 2, 3),
+ "Func15('Hello', 1, 2, 3)",
+ "Func15('Hello', 1, 2, 3)",
+ "Hello3"
+ )
+ // Field f9 holds a Seq[String]; it is passed as one ordinary argument
+ testAllApis(
+ Func16('f9),
+ "Func16(f9)",
+ "Func16(f9)",
+ "Hello, World"
+ )
+ // Scala varargs without @varargs (Func17) are expected to fail with a ValidationException
+ try {
+ testAllApis(
+ Func17("Hello", "World"),
+ "Func17('Hello', 'World')",
+ "Func17('Hello', 'World')",
+ "Hello, World"
+ )
+ throw new RuntimeException("Shouldn't be reached here!")
+ } catch {
+ case ex: ValidationException =>
+ // expected: Scala-style varargs without @varargs are rejected
+ }
+ // Java varargs with boxed Integer elements: 1 * 3 * 5 * 7 = 105
+ val JavaFunc2 = new JavaFunc2
+ testAllApis(
+ JavaFunc2("Hi", 1, 3, 5, 7),
+ "JavaFunc2('Hi', 1, 3, 5, 7)",
+ "JavaFunc2('Hi', 1, 3, 5, 7)",
+ "Hi105")
+
+ // Java overloading: eval(String) is preferred over eval(String, int...)
+ val JavaFunc3 = new JavaFunc3
+ testAllApis(
+ JavaFunc3("Hi"),
+ "JavaFunc3('Hi')",
+ "JavaFunc3('Hi')",
+ "Hi")
+
+ testAllApis(
+ JavaFunc3('f1),
+ "JavaFunc3(f1)",
+ "JavaFunc3(f1)",
+ "Test")
+ }
+
+ @Test
def testJavaBoxedPrimitives(): Unit = {
val JavaFunc0 = new JavaFunc0()
val JavaFunc1 = new JavaFunc1()
@@ -238,7 +317,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
// ----------------------------------------------------------------------------------------------
override def testData: Any = {
- val testData = new Row(9)
+ val testData = new Row(10)
testData.setField(0, 42)
testData.setField(1, "Test")
testData.setField(2, null)
@@ -248,6 +327,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
testData.setField(6, Timestamp.valueOf("1990-10-14 12:10:10"))
testData.setField(7, 12)
testData.setField(8, 1000L)
+ testData.setField(9, Seq("Hello", "World"))
testData
}
@@ -261,7 +341,8 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
Types.TIME,
Types.TIMESTAMP,
Types.INTERVAL_MONTHS,
- Types.INTERVAL_MILLIS
+ Types.INTERVAL_MILLIS,
+ TypeInformation.of(classOf[Seq[String]])
).asInstanceOf[TypeInformation[Any]]
}
@@ -279,8 +360,14 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
"Func10" -> Func10,
"Func11" -> Func11,
"Func12" -> Func12,
+ "Func14" -> Func14,
+ "Func15" -> Func15,
+ "Func16" -> Func16,
+ "Func17" -> Func17,
"JavaFunc0" -> new JavaFunc0,
"JavaFunc1" -> new JavaFunc1,
+ "JavaFunc2" -> new JavaFunc2,
+ "JavaFunc3" -> new JavaFunc3,
"RichFunc0" -> new RichFunc0,
"RichFunc1" -> new RichFunc1,
"RichFunc2" -> new RichFunc2
http://git-wip-us.apache.org/repos/asf/flink/blob/9b179bea/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
index 1258137..e858187 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
@@ -28,6 +28,8 @@ import org.junit.Assert
import scala.collection.mutable
import scala.io.Source
+import scala.annotation.varargs
+
case class SimplePojo(name: String, age: Int)
object Func0 extends ScalarFunction {
@@ -227,3 +229,37 @@ class Func13(prefix: String) extends ScalarFunction {
}
}
+object Func14 extends ScalarFunction {
+ // Test UDF with only a varargs parameter: returns the sum of its arguments (0 when called with none).
+ @varargs
+ def eval(a: Int*): Int = {
+ a.sum
+ }
+}
+
+object Func15 extends ScalarFunction {
+ // Test UDF that overloads a varargs eval with a fixed-arity eval taking the same leading String.
+ @varargs
+ def eval(a: String, b: Int*): String = { // appends the number of varargs values to a
+ a + b.length
+ }
+ // Fixed-arity overload: returns its argument unchanged.
+ def eval(a: String): String = {
+ a
+ }
+}
+
+object Func16 extends ScalarFunction {
+ // Test UDF taking one Seq[String] as an ordinary (non-varargs) parameter; joins the elements with ", ".
+ def eval(a: Seq[String]): String = {
+ a.mkString(", ")
+ }
+}
+
+object Func17 extends ScalarFunction {
+ // Test UDF declaring Scala varargs without the @varargs annotation; used as the negative case.
+ // Without @varargs, looking up this eval method is expected to throw a ValidationException
+ def eval(a: String*): String = {
+ a.mkString(", ")
+ }
+}