You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tw...@apache.org on 2017/03/13 09:33:39 UTC

flink git commit: [FLINK-5881] [table] ScalarFunction(UDF) should support variable types and variable arguments

Repository: flink
Updated Branches:
  refs/heads/master 354a13edf -> 9b179beae


[FLINK-5881] [table] ScalarFunction(UDF) should support variable types and variable arguments

This closes #3389.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/9b179bea
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/9b179bea
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/9b179bea

Branch: refs/heads/master
Commit: 9b179beaea2b623ad3637e417f6d8014b696d038
Parents: 354a13e
Author: Zhuoluo Yang <zh...@alibaba-inc.com>
Authored: Wed Feb 22 18:53:34 2017 +0800
Committer: twalthr <tw...@apache.org>
Committed: Mon Mar 13 10:29:35 2017 +0100

----------------------------------------------------------------------
 .../codegen/calls/ScalarFunctionCallGen.scala   | 17 ++--
 .../functions/utils/ScalarSqlFunction.scala     | 26 ++++--
 .../utils/UserDefinedFunctionUtils.scala        | 74 +++++++++------
 .../java/utils/UserDefinedScalarFunctions.java  | 20 +++++
 .../UserDefinedScalarFunctionTest.scala         | 95 +++++++++++++++++++-
 .../utils/UserDefinedScalarFunctions.scala      | 36 ++++++++
 6 files changed, 229 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/9b179bea/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala
index 7ff18eb..b0b4e09 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala
@@ -43,15 +43,22 @@ class ScalarFunctionCallGen(
       codeGenerator: CodeGenerator,
       operands: Seq[GeneratedExpression])
     : GeneratedExpression = {
-    // determine function signature and result class
-    val matchingSignature = getSignature(scalarFunction, signature)
+    // determine function method and result class
+    val matchingMethod = getEvalMethod(scalarFunction, signature)
       .getOrElse(throw new CodeGenException("No matching signature found."))
+    val matchingSignature = matchingMethod.getParameterTypes
     val resultClass = getResultTypeClass(scalarFunction, matchingSignature)
 
+    // pair operands with parameter classes; extra operands use the varargs component type
+    var paramToOperands = matchingSignature.zip(operands)
+    if (operands.length > matchingSignature.length) {
+      operands.drop(matchingSignature.length).foreach(op =>
+        paramToOperands = paramToOperands :+ (matchingSignature.last.getComponentType, op)
+      )
+    }
+
     // convert parameters for function (output boxing)
-    val parameters = matchingSignature
-        .zip(operands)
-        .map { case (paramClass, operandExpr) =>
+    val parameters = paramToOperands.map { case (paramClass, operandExpr) =>
           if (paramClass.isPrimitive) {
             operandExpr
           } else if (ClassUtils.isPrimitiveWrapper(paramClass)

http://git-wip-us.apache.org/repos/asf/flink/blob/9b179bea/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
index dc6d41f..e2cd272 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
@@ -113,9 +113,15 @@ object ScalarSqlFunction {
           .getParameterTypes(foundSignature)
           .map(typeFactory.createTypeFromTypeInfo)
 
-        inferredTypes.zipWithIndex.foreach {
-          case (inferredType, i) =>
-            operandTypes(i) = inferredType
+        for (i <- operandTypes.indices) {
+          if (i < inferredTypes.length - 1) {
+            operandTypes(i) = inferredTypes(i)
+          } else if (null != inferredTypes.last.getComponentType) {
+            // last parameter is an array type (e.g. varargs): use its component type
+            operandTypes(i) = inferredTypes.last.getComponentType
+          } else {
+            operandTypes(i) = inferredTypes.last
+          }
         }
       }
     }
@@ -137,8 +143,18 @@ object ScalarSqlFunction {
       }
 
       override def getOperandCountRange: SqlOperandCountRange = {
-        val signatureLengths = signatures.map(_.length)
-        SqlOperandCountRanges.between(signatureLengths.min, signatureLengths.max)
+        var min = 255
+        var max = -1
+        signatures.foreach( sig => {
+          var len = sig.length
+          if (len > 0 && sig(sig.length - 1).isArray) {
+            max = 254  // according to JVM spec 4.3.3
+            len = sig.length - 1
+          }
+          max = Math.max(len, max)
+          min = Math.min(len, min)
+        })
+        SqlOperandCountRanges.between(min, max)
       }
 
       override def checkOperandTypes(

http://git-wip-us.apache.org/repos/asf/flink/blob/9b179bea/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
index 21d28b5..c1cfe06 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -78,20 +78,7 @@ object UserDefinedFunctionUtils {
       function: UserDefinedFunction,
       signature: Seq[TypeInformation[_]])
     : Option[Array[Class[_]]] = {
-    // We compare the raw Java classes not the TypeInformation.
-    // TypeInformation does not matter during runtime (e.g. within a MapFunction).
-    val actualSignature = typeInfoToClass(signature)
-    val signatures = getSignatures(function)
-
-    signatures
-      // go over all signatures and find one matching actual signature
-      .find { curSig =>
-      // match parameters of signature to actual parameters
-      actualSignature.length == curSig.length &&
-        curSig.zipWithIndex.forall { case (clazz, i) =>
-          parameterTypeEquals(actualSignature(i), clazz)
-        }
-    }
+    getEvalMethod(function, signature).map(_.getParameterTypes)
   }
 
   /**
@@ -106,16 +93,52 @@ object UserDefinedFunctionUtils {
     val actualSignature = typeInfoToClass(signature)
     val evalMethods = checkAndExtractEvalMethods(function)
 
-    evalMethods
-      // go over all eval methods and find one matching
-      .find { cur =>
-      val signatures = cur.getParameterTypes
-      // match parameters of signature to actual parameters
-      actualSignature.length == signatures.length &&
-        signatures.zipWithIndex.forall { case (clazz, i) =>
-          parameterTypeEquals(actualSignature(i), clazz)
+    val filtered = evalMethods
+      // go over all eval methods and keep only the matching ones
+      .filter {
+        case cur if !cur.isVarArgs =>
+          val signatures = cur.getParameterTypes
+          // match parameters of signature to actual parameters
+          actualSignature.length == signatures.length &&
+            signatures.zipWithIndex.forall { case (clazz, i) =>
+              parameterTypeEquals(actualSignature(i), clazz)
+          }
+        case cur if cur.isVarArgs =>
+          val signatures = cur.getParameterTypes
+          actualSignature.zipWithIndex.forall {
+            // non-varargs
+            case (clazz, i) if i < signatures.length - 1  =>
+              parameterTypeEquals(clazz, signatures(i))
+            // varargs
+            case (clazz, i) if i >= signatures.length - 1 =>
+              parameterTypeEquals(clazz, signatures.last.getComponentType)
+          } || (actualSignature.isEmpty && signatures.length == 1) // empty varargs
+    }
+
+    // if a fixed-arity method matches, prefer it over varargs methods, as the compiler would
+    val fixedMethodsCount = filtered.count(!_.isVarArgs)
+    val found = filtered.filter { cur =>
+      fixedMethodsCount > 0 && !cur.isVarArgs ||
+      fixedMethodsCount == 0 && cur.isVarArgs
+    }
+
+    // check if there is a Scala varargs annotation
+    if (found.isEmpty &&
+      evalMethods.exists { evalMethod =>
+        val signatures = evalMethod.getParameterTypes
+        signatures.zipWithIndex.forall {
+          case (clazz, i) if i < signatures.length - 1 =>
+            parameterTypeEquals(actualSignature(i), clazz)
+          case (clazz, i) if i == signatures.length - 1 =>
+            clazz.getName.equals("scala.collection.Seq")
         }
+      }) {
+      throw new ValidationException("Scala-style variable arguments in 'eval' methods are not " +
+        "supported. Please add a @scala.annotation.varargs annotation.")
+    } else if (found.length > 1) {
+      throw new ValidationException("Found multiple 'eval' methods which match the signature.")
     }
+    found.headOption
   }
 
   /**
@@ -133,7 +156,7 @@ object UserDefinedFunctionUtils {
 
   /**
     * Extracts "eval" methods and throws a [[ValidationException]] if no implementation
-    * can be found.
+    * can be found or if the implementation does not match the requirements.
     */
   def checkAndExtractEvalMethods(function: UserDefinedFunction): Array[Method] = {
     val methods = function
@@ -152,9 +175,9 @@ object UserDefinedFunctionUtils {
         s"Function class '${function.getClass.getCanonicalName}' does not implement at least " +
           s"one method named 'eval' which is public, not abstract and " +
           s"(in case of table functions) not static.")
-    } else {
-      methods
     }
+
+    methods
   }
 
   def getSignatures(function: UserDefinedFunction): Array[Array[Class[_]]] = {
@@ -317,6 +340,7 @@ object UserDefinedFunctionUtils {
   private def parameterTypeEquals(candidate: Class[_], expected: Class[_]): Boolean =
   candidate == null ||
     candidate == expected ||
+    expected == classOf[Object] ||
     expected.isPrimitive && Primitives.wrap(expected) == candidate ||
     candidate == classOf[Date] && (expected == classOf[Int] || expected == classOf[JInt])  ||
     candidate == classOf[Time] && (expected == classOf[Int] || expected == classOf[JInt]) ||

http://git-wip-us.apache.org/repos/asf/flink/blob/9b179bea/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java
index e817f06..56f866d 100644
--- a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java
+++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java
@@ -33,4 +33,24 @@ public class UserDefinedScalarFunctions {
 		}
 	}
 
+	public static class JavaFunc2 extends ScalarFunction {
+		public String eval(String s, Integer... a) {
+			int m = 1;
+			for (int n : a) {
+				m *= n;
+			}
+			return s + m;
+		}
+	}
+
+	public static class JavaFunc3 extends ScalarFunction {
+		public int eval(String a, int... b) {
+			return b.length;
+		}
+
+		public String eval(String c) {
+			return c;
+		}
+	}
+
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/9b179bea/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
index a6c1760..51583c3 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
@@ -24,8 +24,8 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.typeutils.RowTypeInfo
 import org.apache.flink.types.Row
-import org.apache.flink.table.api.Types
-import org.apache.flink.table.api.java.utils.UserDefinedScalarFunctions.{JavaFunc0, JavaFunc1}
+import org.apache.flink.table.api.{Types, ValidationException}
+import org.apache.flink.table.api.java.utils.UserDefinedScalarFunctions.{JavaFunc0, JavaFunc1, JavaFunc2, JavaFunc3}
 import org.apache.flink.table.api.scala._
 import org.apache.flink.table.expressions.utils._
 import org.apache.flink.table.functions.ScalarFunction
@@ -181,6 +181,85 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
   }
   
   @Test
+  def testVariableArgs(): Unit = {
+    testAllApis(
+      Func14(1, 2, 3, 4),
+      "Func14(1, 2, 3, 4)",
+      "Func14(1, 2, 3, 4)",
+      "10")
+
+    // Test for empty arguments
+    testAllApis(
+      Func14(),
+      "Func14()",
+      "Func14()",
+      "0")
+
+    // Test for overloading
+    testAllApis(
+      Func15("Hello"),
+      "Func15('Hello')",
+      "Func15('Hello')",
+      "Hello"
+    )
+
+    testAllApis(
+      Func15('f1),
+      "Func15(f1)",
+      "Func15(f1)",
+      "Test"
+    )
+
+    testAllApis(
+      Func15("Hello", 1, 2, 3),
+      "Func15('Hello', 1, 2, 3)",
+      "Func15('Hello', 1, 2, 3)",
+      "Hello3"
+    )
+
+    testAllApis(
+      Func16('f9),
+      "Func16(f9)",
+      "Func16(f9)",
+      "Hello, World"
+    )
+
+    try {
+      testAllApis(
+        Func17("Hello", "World"),
+        "Func17('Hello', 'World')",
+        "Func17('Hello', 'World')",
+        "Hello, World"
+      )
+      throw new RuntimeException("Shouldn't be reached here!")
+    } catch {
+      case ex: ValidationException =>
+        // ok
+    }
+
+    val JavaFunc2 = new JavaFunc2
+    testAllApis(
+      JavaFunc2("Hi", 1, 3, 5, 7),
+      "JavaFunc2('Hi', 1, 3, 5, 7)",
+      "JavaFunc2('Hi', 1, 3, 5, 7)",
+      "Hi105")
+
+    // test overloading
+    val JavaFunc3 = new JavaFunc3
+    testAllApis(
+      JavaFunc3("Hi"),
+      "JavaFunc3('Hi')",
+      "JavaFunc3('Hi')",
+      "Hi")
+
+    testAllApis(
+      JavaFunc3('f1),
+      "JavaFunc3(f1)",
+      "JavaFunc3(f1)",
+      "Test")
+  }
+
+  @Test
   def testJavaBoxedPrimitives(): Unit = {
     val JavaFunc0 = new JavaFunc0()
     val JavaFunc1 = new JavaFunc1()
@@ -238,7 +317,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
   // ----------------------------------------------------------------------------------------------
 
   override def testData: Any = {
-    val testData = new Row(9)
+    val testData = new Row(10)
     testData.setField(0, 42)
     testData.setField(1, "Test")
     testData.setField(2, null)
@@ -248,6 +327,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
     testData.setField(6, Timestamp.valueOf("1990-10-14 12:10:10"))
     testData.setField(7, 12)
     testData.setField(8, 1000L)
+    testData.setField(9, Seq("Hello", "World"))
     testData
   }
 
@@ -261,7 +341,8 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
       Types.TIME,
       Types.TIMESTAMP,
       Types.INTERVAL_MONTHS,
-      Types.INTERVAL_MILLIS
+      Types.INTERVAL_MILLIS,
+      TypeInformation.of(classOf[Seq[String]])
     ).asInstanceOf[TypeInformation[Any]]
   }
 
@@ -279,8 +360,14 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
     "Func10" -> Func10,
     "Func11" -> Func11,
     "Func12" -> Func12,
+    "Func14" -> Func14,
+    "Func15" -> Func15,
+    "Func16" -> Func16,
+    "Func17" -> Func17,
     "JavaFunc0" -> new JavaFunc0,
     "JavaFunc1" -> new JavaFunc1,
+    "JavaFunc2" -> new JavaFunc2,
+    "JavaFunc3" -> new JavaFunc3,
     "RichFunc0" -> new RichFunc0,
     "RichFunc1" -> new RichFunc1,
     "RichFunc2" -> new RichFunc2

http://git-wip-us.apache.org/repos/asf/flink/blob/9b179bea/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
index 1258137..e858187 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
@@ -28,6 +28,8 @@ import org.junit.Assert
 import scala.collection.mutable
 import scala.io.Source
 
+import scala.annotation.varargs
+
 case class SimplePojo(name: String, age: Int)
 
 object Func0 extends ScalarFunction {
@@ -227,3 +229,37 @@ class Func13(prefix: String) extends ScalarFunction {
   }
 }
 
+object Func14 extends ScalarFunction {
+
+  @varargs
+  def eval(a: Int*): Int = {
+    a.sum
+  }
+}
+
+object Func15 extends ScalarFunction {
+
+  @varargs
+  def eval(a: String, b: Int*): String = {
+    a + b.length
+  }
+
+  def eval(a: String): String = {
+    a
+  }
+}
+
+object Func16 extends ScalarFunction {
+
+  def eval(a: Seq[String]): String = {
+    a.mkString(", ")
+  }
+}
+
+object Func17 extends ScalarFunction {
+
+  // Without @varargs, we will throw an exception
+  def eval(a: String*): String = {
+    a.mkString(", ")
+  }
+}