You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tw...@apache.org on 2017/01/30 13:52:29 UTC

flink git commit: [FLINK-5678] [table] Fix user-defined functions not supporting all types of parameters

Repository: flink
Updated Branches:
  refs/heads/master 50b665677 -> 126fb1779


[FLINK-5678] [table] Fix user-defined functions not supporting all types of parameters

This closes #3233.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/126fb177
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/126fb177
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/126fb177

Branch: refs/heads/master
Commit: 126fb1779b84a934dd6bb632c96a7666ac8c521e
Parents: 50b6656
Author: Jark Wu <wu...@alibaba-inc.com>
Authored: Sun Jan 29 18:36:28 2017 +0600
Committer: twalthr <tw...@apache.org>
Committed: Mon Jan 30 14:51:57 2017 +0100

----------------------------------------------------------------------
 .../codegen/calls/ScalarFunctionCallGen.scala   | 12 +++++++
 .../codegen/calls/TableFunctionCallGen.scala    | 12 +++++++
 .../flink/table/expressions/literals.scala      |  5 +++
 .../utils/UserDefinedFunctionUtils.scala        | 12 +++----
 .../java/utils/UserDefinedScalarFunctions.java  | 36 ++++++++++++++++++++
 .../java/utils/UserDefinedTableFunctions.java   | 31 +++++++++++++++++
 .../UserDefinedScalarFunctionTest.scala         | 36 +++++++++++++++++++-
 .../dataset/DataSetCorrelateITCase.scala        | 26 +++++++++++++-
 8 files changed, 162 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/126fb177/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala
index ac840df..7ff18eb 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala
@@ -18,11 +18,13 @@
 
 package org.apache.flink.table.codegen.calls
 
+import org.apache.commons.lang3.ClassUtils
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.table.codegen.CodeGenUtils._
 import org.apache.flink.table.codegen.{CodeGenException, CodeGenerator, GeneratedExpression}
 import org.apache.flink.table.functions.ScalarFunction
 import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
+import org.apache.flink.table.typeutils.TypeCheckUtils
 
 /**
   * Generates a call to user-defined [[ScalarFunction]].
@@ -52,6 +54,16 @@ class ScalarFunctionCallGen(
         .map { case (paramClass, operandExpr) =>
           if (paramClass.isPrimitive) {
             operandExpr
+          } else if (ClassUtils.isPrimitiveWrapper(paramClass)
+              && TypeCheckUtils.isTemporal(operandExpr.resultType)) {
+            // we use primitives to represent temporal types internally, so no casting needed here
+            val exprOrNull: String = if (codeGenerator.nullCheck) {
+              s"${operandExpr.nullTerm} ? null : " +
+                s"(${paramClass.getCanonicalName}) ${operandExpr.resultTerm}"
+            } else {
+              operandExpr.resultTerm
+            }
+            operandExpr.copy(resultTerm = exprOrNull)
           } else {
             val boxedTypeTerm = boxedTypeTermForTypeInfo(operandExpr.resultType)
             val boxedExpr = codeGenerator.generateOutputFieldBoxing(operandExpr)

http://git-wip-us.apache.org/repos/asf/flink/blob/126fb177/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala
index 6e44f55..890b6bd 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala
@@ -18,12 +18,14 @@
 
 package org.apache.flink.table.codegen.calls
 
+import org.apache.commons.lang3.ClassUtils
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.table.codegen.CodeGenUtils._
 import org.apache.flink.table.codegen.GeneratedExpression.NEVER_NULL
 import org.apache.flink.table.codegen.{CodeGenException, CodeGenerator, GeneratedExpression}
 import org.apache.flink.table.functions.TableFunction
 import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
+import org.apache.flink.table.typeutils.TypeCheckUtils
 
 /**
   * Generates a call to user-defined [[TableFunction]].
@@ -52,6 +54,16 @@ class TableFunctionCallGen(
         .map { case (paramClass, operandExpr) =>
           if (paramClass.isPrimitive) {
             operandExpr
+          } else if (ClassUtils.isPrimitiveWrapper(paramClass)
+              && TypeCheckUtils.isTemporal(operandExpr.resultType)) {
+            // we use primitives to represent temporal types internally, so no casting needed here
+            val exprOrNull: String = if (codeGenerator.nullCheck) {
+              s"${operandExpr.nullTerm} ? null : " +
+                s"(${paramClass.getCanonicalName}) ${operandExpr.resultTerm}"
+            } else {
+              operandExpr.resultTerm
+            }
+            operandExpr.copy(resultTerm = exprOrNull)
           } else {
             val boxedTypeTerm = boxedTypeTermForTypeInfo(operandExpr.resultType)
             val boxedExpr = codeGenerator.generateOutputFieldBoxing(operandExpr)

http://git-wip-us.apache.org/repos/asf/flink/blob/126fb177/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/literals.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/literals.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/literals.scala
index ccdfc2d..916fe73 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/literals.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/literals.scala
@@ -68,6 +68,11 @@ case class Literal(value: Any, resultType: TypeInformation[_]) extends LeafExpre
         val decType = relBuilder.getTypeFactory.createSqlType(SqlTypeName.DECIMAL)
         relBuilder.getRexBuilder.makeExactLiteral(bigDecValue, decType)
 
+      // create BIGINT literals for long type
+      case BasicTypeInfo.LONG_TYPE_INFO =>
+        val bigint = java.math.BigDecimal.valueOf(value.asInstanceOf[Long])
+        relBuilder.getRexBuilder.makeBigintLiteral(bigint)
+
       // date/time
       case SqlTimeTypeInfo.DATE =>
         relBuilder.getRexBuilder.makeDateLiteral(dateToCalendar)

http://git-wip-us.apache.org/repos/asf/flink/blob/126fb177/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
index fa4668d..f324dc1 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -19,17 +19,17 @@
 
 package org.apache.flink.table.functions.utils
 
+import java.lang.{Long => JLong, Integer => JInt}
 import java.lang.reflect.{Method, Modifier}
 import java.sql.{Date, Time, Timestamp}
 
 import com.google.common.primitives.Primitives
 import org.apache.calcite.sql.SqlFunction
 import org.apache.flink.api.common.functions.InvalidTypesException
-import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
-import org.apache.flink.api.common.typeutils.CompositeType
+import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.typeutils.TypeExtractor
 import org.apache.flink.table.calcite.FlinkTypeFactory
-import org.apache.flink.table.api.{TableEnvironment, TableException, ValidationException}
+import org.apache.flink.table.api.{TableEnvironment, ValidationException}
 import org.apache.flink.table.functions.{ScalarFunction, TableFunction, UserDefinedFunction}
 import org.apache.flink.table.plan.schema.FlinkTableFunctionImpl
 import org.apache.flink.util.InstantiationUtil
@@ -320,8 +320,8 @@ object UserDefinedFunctionUtils {
   candidate == null ||
     candidate == expected ||
     expected.isPrimitive && Primitives.wrap(expected) == candidate ||
-    candidate == classOf[Date] && expected == classOf[Int] ||
-    candidate == classOf[Time] && expected == classOf[Int] ||
-    candidate == classOf[Timestamp] && expected == classOf[Long]
+    candidate == classOf[Date] && (expected == classOf[Int] || expected == classOf[JInt])  ||
+    candidate == classOf[Time] && (expected == classOf[Int] || expected == classOf[JInt]) ||
+    candidate == classOf[Timestamp] && (expected == classOf[Long] || expected == classOf[JLong])
 
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/126fb177/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java
new file mode 100644
index 0000000..e817f06
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.api.java.utils;
+
+import org.apache.flink.table.functions.ScalarFunction;
+
+public class UserDefinedScalarFunctions {
+
+	public static class JavaFunc0 extends ScalarFunction {
+		public long eval(Long l) {
+			return l + 1;
+		}
+	}
+
+	public static class JavaFunc1 extends ScalarFunction {
+		public String eval(Integer a, int b,  Long c) {
+			return a + " and " + b + " and " + c;
+		}
+	}
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/126fb177/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedTableFunctions.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedTableFunctions.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedTableFunctions.java
new file mode 100644
index 0000000..3af8646
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedTableFunctions.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.api.java.utils;
+
+import org.apache.flink.table.functions.TableFunction;
+
+public class UserDefinedTableFunctions {
+
+	public static class JavaTableFunc0 extends TableFunction<Long> {
+		public void eval(Integer a, Long b, Long c) {
+			collect(a.longValue());
+			collect(b);
+			collect(c);
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/126fb177/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
index 26bbd44..da8c748 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
@@ -25,6 +25,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.typeutils.RowTypeInfo
 import org.apache.flink.types.Row
 import org.apache.flink.table.api.Types
+import org.apache.flink.table.api.java.utils.UserDefinedScalarFunctions.{JavaFunc0, JavaFunc1}
 import org.apache.flink.table.api.scala._
 import org.apache.flink.table.expressions.utils._
 import org.apache.flink.table.functions.ScalarFunction
@@ -179,6 +180,37 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
       "+0 00:00:01.000")
   }
 
+  @Test
+  def testJavaBoxedPrimitives(): Unit = {
+    val JavaFunc0 = new JavaFunc0()
+    val JavaFunc1 = new JavaFunc1()
+
+    testAllApis(
+      JavaFunc0('f8),
+      "JavaFunc0(f8)",
+      "JavaFunc0(f8)",
+      "1001"
+    )
+
+    testTableApi(
+      JavaFunc0(1000L),
+      "JavaFunc0(1000L)",
+      "1001"
+    )
+
+    testAllApis(
+      JavaFunc1('f4, 'f5, 'f6),
+      "JavaFunc1(f4, f5, f6)",
+      "JavaFunc1(f4, f5, f6)",
+      "7591 and 43810000 and 655906210000")
+
+    testAllApis(
+      JavaFunc1(Null(Types.TIME), 15, Null(Types.TIMESTAMP)),
+      "JavaFunc1(Null(TIME), 15, Null(TIMESTAMP))",
+      "JavaFunc1(NULL, 15, NULL)",
+      "null and 15 and null")
+  }
+
   // ----------------------------------------------------------------------------------------------
 
   override def testData: Any = {
@@ -222,7 +254,9 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
     "Func9" -> Func9,
     "Func10" -> Func10,
     "Func11" -> Func11,
-    "Func12" -> Func12
+    "Func12" -> Func12,
+    "JavaFunc0" -> new JavaFunc0,
+    "JavaFunc1" -> new JavaFunc1
   )
 }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/126fb177/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala
index 783a457..818f52b 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala
@@ -17,12 +17,15 @@
  */
 package org.apache.flink.table.runtime.dataset
 
+import java.sql.{Date, Timestamp}
+
 import org.apache.flink.api.scala._
 import org.apache.flink.types.Row
-import org.apache.flink.table.api.scala.batch.utils.{TableProgramsClusterTestBase, TableProgramsCollectionTestBase}
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsClusterTestBase
 import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
 import org.apache.flink.table.api.scala._
 import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.api.java.utils.UserDefinedTableFunctions.JavaTableFunc0
 import org.apache.flink.table.utils._
 import org.apache.flink.test.util.TestBaseUtils
 import org.junit.Test
@@ -161,6 +164,27 @@ class DataSetCorrelateITCase(
     TestBaseUtils.compareResultAsText(results.asJava, expected)
   }
 
+  @Test
+  def testLongAndTemporalTypes(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tableEnv = TableEnvironment.getTableEnvironment(env, config)
+    val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
+    val func0 = new JavaTableFunc0
+
+    val result = in
+        .where('a === 1)
+        .select(Date.valueOf("1990-10-14") as 'x,
+                1000L as 'y,
+                Timestamp.valueOf("1990-10-14 12:10:10") as 'z)
+        .join(func0('x, 'y, 'z) as 's)
+        .select('s)
+        .toDataSet[Row]
+
+    val results = result.collect()
+    val expected = "1000\n" + "655906210000\n" + "7591\n"
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
   private def testData(
       env: ExecutionEnvironment)
     : DataSet[(Int, Long, String)] = {