You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by fh...@apache.org on 2017/05/09 16:50:55 UTC

[2/4] flink git commit: [FLINK-6436] [table] Fix code-gen bug when using a scalar UDF in a UDTF join condition.

[FLINK-6436] [table] Fix code-gen bug when using a scalar UDF in a UDTF join condition.

This closes #3815.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/e2cb2215
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/e2cb2215
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/e2cb2215

Branch: refs/heads/master
Commit: e2cb2215917e33d35fff5b07ed6a64c05e14abce
Parents: f26a911
Author: godfreyhe <go...@163.com>
Authored: Wed May 3 20:59:34 2017 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Tue May 9 18:50:20 2017 +0200

----------------------------------------------------------------------
 .../table/plan/nodes/CommonCorrelate.scala      | 13 ++++++++++-
 .../utils/UserDefinedScalarFunctions.scala      |  9 ++++++--
 .../DataSetUserDefinedFunctionITCase.scala      | 24 +++++++++++++++++---
 .../DataStreamUserDefinedFunctionITCase.scala   | 22 ++++++++++++++++--
 4 files changed, 60 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/e2cb2215/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala
index 44a109e3..c95f2f7 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala
@@ -18,7 +18,7 @@
 package org.apache.flink.table.plan.nodes
 
 import org.apache.calcite.rel.`type`.RelDataType
-import org.apache.calcite.rex.{RexCall, RexNode}
+import org.apache.calcite.rex.{RexCall, RexInputRef, RexNode, RexShuttle}
 import org.apache.calcite.sql.SemiJoinType
 import org.apache.flink.api.common.functions.FlatMapFunction
 import org.apache.flink.api.common.typeinfo.TypeInformation
@@ -143,6 +143,17 @@ trait CommonCorrelate[T] {
          |getCollector().collect(${crossResultExpr.resultTerm});
          |""".stripMargin
     } else {
+
+      // adjust indices of InputRefs to adhere to schema expected by generator
+      val changeInputRefIndexShuttle = new RexShuttle {
+        override def visitInputRef(inputRef: RexInputRef): RexNode = {
+          new RexInputRef(inputSchema.physicalArity + inputRef.getIndex, inputRef.getType)
+        }
+      }
+      // Run generateExpression to add init statements (ScalarFunctions) of condition to generator.
+      //   The generated expression is discarded.
+      generator.generateExpression(condition.get.accept(changeInputRefIndexShuttle))
+
       val filterGenerator = new CodeGenerator(config, false, udtfTypeInfo, None, pojoFieldMapping)
       filterGenerator.input1Term = filterGenerator.input2Term
       val filterCondition = filterGenerator.generateExpression(condition.get)

http://git-wip-us.apache.org/repos/asf/flink/blob/e2cb2215/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
index 8972a77..5285569 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
@@ -25,11 +25,10 @@ import org.apache.flink.table.api.Types
 import org.apache.flink.table.functions.{ScalarFunction, FunctionContext}
 import org.junit.Assert
 
+import scala.annotation.varargs
 import scala.collection.mutable
 import scala.io.Source
 
-import scala.annotation.varargs
-
 case class SimplePojo(name: String, age: Int)
 
 object Func0 extends ScalarFunction {
@@ -263,3 +262,9 @@ object Func17 extends ScalarFunction {
     a.mkString(", ")
   }
 }
+
+object Func18 extends ScalarFunction {
+  def eval(str: String, prefix: String): Boolean = {
+    str.startsWith(prefix)
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/e2cb2215/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala
index 20bbf8b..b69dd49 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala
@@ -24,9 +24,9 @@ import org.apache.flink.api.scala.util.CollectionDataSets
 import org.apache.flink.table.api.TableEnvironment
 import org.apache.flink.table.api.java.utils.UserDefinedTableFunctions.JavaTableFunc0
 import org.apache.flink.table.api.scala._
-import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
 import org.apache.flink.table.api.scala.batch.utils.TableProgramsClusterTestBase
-import org.apache.flink.table.expressions.utils.{Func13, RichFunc2}
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
+import org.apache.flink.table.expressions.utils.{Func1, Func13, Func18, RichFunc2}
 import org.apache.flink.table.utils._
 import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
 import org.apache.flink.test.util.TestBaseUtils
@@ -143,7 +143,7 @@ class DataSetUserDefinedFunctionITCase(
     val pojo = new PojoTableFunc()
     val result = in
       .join(pojo('c))
-      .where(('age > 20))
+      .where('age > 20)
       .select('c, 'name, 'age)
       .toDataSet[Row]
 
@@ -171,6 +171,24 @@ class DataSetUserDefinedFunctionITCase(
   }
 
   @Test
+  def testUserDefinedTableFunctionWithScalarFunctionInCondition(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tableEnv = TableEnvironment.getTableEnvironment(env, config)
+    val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
+    val func0 = new TableFunc0
+
+    val result = in
+      .join(func0('c))
+      .where(Func18('name, "J") && (Func1('a) < 3) && Func1('age) > 20)
+      .select('c, 'name, 'age)
+      .toDataSet[Row]
+
+    val results = result.collect()
+    val expected = "Jack#22,Jack,22"
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
   def testLongAndTemporalTypes(): Unit = {
     val env = ExecutionEnvironment.getExecutionEnvironment
     val tableEnv = TableEnvironment.getTableEnvironment(env, config)

http://git-wip-us.apache.org/repos/asf/flink/blob/e2cb2215/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamUserDefinedFunctionITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamUserDefinedFunctionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamUserDefinedFunctionITCase.scala
index 2e8a065..b3d9c6f 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamUserDefinedFunctionITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamUserDefinedFunctionITCase.scala
@@ -23,7 +23,7 @@ import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
 import org.apache.flink.table.api.TableEnvironment
 import org.apache.flink.table.api.scala._
 import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData}
-import org.apache.flink.table.expressions.utils.{Func13, RichFunc2}
+import org.apache.flink.table.expressions.utils.{Func13, Func18, RichFunc2}
 import org.apache.flink.table.utils._
 import org.apache.flink.types.Row
 import org.junit.Assert._
@@ -51,7 +51,7 @@ class DataStreamUserDefinedFunctionITCase extends StreamingMultipleProgramsTestB
       .join(func0('c) as('d, 'e))
       .select('c, 'd, 'e)
       .join(pojoFunc0('c))
-      .where(('age > 20))
+      .where('age > 20)
       .select('c, 'name, 'age)
       .toDataStream[Row]
 
@@ -82,6 +82,24 @@ class DataStreamUserDefinedFunctionITCase extends StreamingMultipleProgramsTestB
   }
 
   @Test
+  def testUserDefinedTableFunctionWithScalarFunction(): Unit = {
+    val t = testData(env).toTable(tEnv).as('a, 'b, 'c)
+    val func0 = new TableFunc0
+
+    val result = t
+      .join(func0('c) as('d, 'e))
+      .where(Func18('d, "J"))
+      .select('c, 'd, 'e)
+      .toDataStream[Row]
+
+    result.addSink(new StreamITCase.StringSink)
+    env.execute()
+
+    val expected = mutable.MutableList("Jack#22,Jack,22", "John#19,John,19")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
+
+  @Test
   def testUserDefinedTableFunctionWithParameter(): Unit = {
     val tableFunc1 = new RichTableFunc1
     tEnv.registerFunction("RichTableFunc1", tableFunc1)