Posted to commits@flink.apache.org by tw...@apache.org on 2016/12/07 15:57:20 UTC
[1/5] flink git commit: [FLINK-4469] [table] Add support for user defined table function in Table API & SQL
Repository: flink
Updated Branches:
refs/heads/master c024b0b6c -> 684defbf3
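In short, this commit lets users register a TableFunction and apply it per input row, via crossApply/outerApply in the Table API or LATERAL TABLE in SQL. A minimal sketch assembled from the tests in this patch (Split mirrors TableFunc1 below; env, table, and the registered table name MyTable are assumptions):

    import org.apache.flink.api.scala._
    import org.apache.flink.api.scala.table._
    import org.apache.flink.api.table.TableEnvironment
    import org.apache.flink.api.table.functions.TableFunction

    // emits one row per "#"-separated part of the input string
    class Split extends TableFunction[String] {
      def eval(str: String): Unit = str.split("#").foreach(collect)
    }

    val tEnv = TableEnvironment.getTableEnvironment(env)
    val split = new Split

    // Table API: crossApply drops rows for which the function emits nothing;
    // outerApply keeps them and pads the function's columns with nulls
    val result = table.crossApply(split('c) as 's).select('c, 's)

    // SQL: the equivalent LATERAL TABLE syntax
    tEnv.registerFunction("split", split)
    tEnv.sql("SELECT c, s FROM MyTable, LATERAL TABLE(split(c)) AS T(s)")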
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionITCase.scala
new file mode 100644
index 0000000..f19f7f9
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionITCase.scala
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.scala.stream
+
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.stream.utils.StreamITCase
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.table.expressions.utils.{TableFunc0, TableFunc1}
+import org.apache.flink.api.table.{Row, TableEnvironment}
+import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
+import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
+import org.junit.Assert._
+import org.junit.Test
+
+import scala.collection.mutable
+
+class UserDefinedTableFunctionITCase extends StreamingMultipleProgramsTestBase {
+
+ @Test
+ def testSQLCrossApply(): Unit = {
+
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.clear
+
+ val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
+ tEnv.registerTable("MyTable", t)
+
+ tEnv.registerFunction("split", new TableFunc0)
+
+ val sqlQuery = "SELECT MyTable.c, t.n, t.a FROM MyTable, LATERAL TABLE(split(c)) AS t(n,a)"
+
+ val result = tEnv.sql(sqlQuery).toDataStream[Row]
+ result.addSink(new StreamITCase.StringSink)
+ env.execute()
+
+ val expected = mutable.MutableList(
+ "Jack#22,Jack,22", "John#19,John,19", "Anna#44,Anna,44")
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
+
+ @Test
+ def testSQLOuterApply(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.clear
+
+ val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
+ tEnv.registerTable("MyTable", t)
+
+ tEnv.registerFunction("split", new TableFunc0)
+
+ val sqlQuery = "SELECT MyTable.c, t.n, t.a FROM MyTable " +
+ "LEFT JOIN LATERAL TABLE(split(c)) AS t(n,a) ON TRUE"
+
+ val result = tEnv.sql(sqlQuery).toDataStream[Row]
+ result.addSink(new StreamITCase.StringSink)
+ env.execute()
+
+ val expected = mutable.MutableList(
+ "nosharp,null,null", "Jack#22,Jack,22",
+ "John#19,John,19", "Anna#44,Anna,44")
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
+
+ @Test
+ def testTableAPICrossApply(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.clear
+
+ val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
+ val func0 = new TableFunc0
+
+ val result = t
+ .crossApply(func0('c) as('d, 'e))
+ .select('c, 'd, 'e)
+ .toDataStream[Row]
+
+ result.addSink(new StreamITCase.StringSink)
+ env.execute()
+
+ val expected = mutable.MutableList("Jack#22,Jack,22", "John#19,John,19", "Anna#44,Anna,44")
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
+
+ @Test
+ def testTableAPIOuterApply(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.clear
+
+ val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
+ val func0 = new TableFunc0
+
+ val result = t
+ .outerApply(func0('c) as('d, 'e))
+ .select('c, 'd, 'e)
+ .toDataStream[Row]
+
+ result.addSink(new StreamITCase.StringSink)
+ env.execute()
+
+ val expected = mutable.MutableList(
+ "nosharp,null,null", "Jack#22,Jack,22",
+ "John#19,John,19", "Anna#44,Anna,44")
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
+
+ @Test
+ def testTableAPIWithFilter(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.clear
+
+ val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
+ val func0 = new TableFunc0
+
+ val result = t
+ .crossApply(func0('c) as('name, 'age))
+ .select('c, 'name, 'age)
+ .filter('age > 20)
+ .toDataStream[Row]
+
+ result.addSink(new StreamITCase.StringSink)
+ env.execute()
+
+ val expected = mutable.MutableList("Jack#22,Jack,22", "Anna#44,Anna,44")
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
+
+ @Test
+ def testTableAPIWithScalarFunction(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.clear
+
+ val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
+ val func1 = new TableFunc1
+
+ val result = t
+ .crossApply(func1('c.substring(2)) as 's)
+ .select('c, 's)
+ .toDataStream[Row]
+
+ result.addSink(new StreamITCase.StringSink)
+ env.execute()
+
+ val expected = mutable.MutableList("Jack#22,ack", "Jack#22,22", "John#19,ohn",
+ "John#19,19", "Anna#44,nna", "Anna#44,44")
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
+
+ private def getSmall3TupleDataStream(
+ env: StreamExecutionEnvironment)
+ : DataStream[(Int, Long, String)] = {
+
+ val data = new mutable.MutableList[(Int, Long, String)]
+ data.+=((1, 1L, "Jack#22"))
+ data.+=((2, 2L, "John#19"))
+ data.+=((3, 2L, "Anna#44"))
+ data.+=((4, 3L, "nosharp"))
+ env.fromCollection(data)
+ }
+
+}
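To make the cross/outer apply difference exercised above concrete: "nosharp" contains no '#', so TableFunc0 emits no rows for it, and only the outer apply keeps that input row:

    input c      cross apply (c, n, a)     outer apply (c, n, a)
    Jack#22      Jack#22,Jack,22           Jack#22,Jack,22
    John#19      John#19,John,19           John#19,John,19
    Anna#44      Anna#44,Anna,44           Anna#44,Anna,44
    nosharp      (no output row)           nosharp,null,null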
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionTest.scala
new file mode 100644
index 0000000..bc01819
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionTest.scala
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.scala.stream
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.table.expressions.utils._
+import org.apache.flink.api.table.typeutils.RowTypeInfo
+import org.apache.flink.api.table.utils.TableTestBase
+import org.apache.flink.api.table.utils.TableTestUtil._
+import org.apache.flink.api.table._
+import org.apache.flink.streaming.api.environment.{StreamExecutionEnvironment => JavaExecutionEnv}
+import org.apache.flink.streaming.api.datastream.{DataStream => JDataStream}
+import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment => ScalaExecutionEnv}
+import org.junit.Assert.{assertTrue, fail}
+import org.junit.Test
+import org.mockito.Mockito._
+
+class UserDefinedTableFunctionTest extends TableTestBase {
+
+ @Test
+ def testTableAPI(): Unit = {
+ // mock
+ val ds = mock(classOf[DataStream[Row]])
+ val jDs = mock(classOf[JDataStream[Row]])
+ val typeInfo: TypeInformation[Row] = new RowTypeInfo(Seq(Types.INT, Types.LONG, Types.STRING))
+ when(ds.javaStream).thenReturn(jDs)
+ when(jDs.getType).thenReturn(typeInfo)
+
+ // Scala environment
+ val env = mock(classOf[ScalaExecutionEnv])
+ val tableEnv = TableEnvironment.getTableEnvironment(env)
+ val in1 = ds.toTable(tableEnv).as('a, 'b, 'c)
+
+ // Java environment
+ val javaEnv = mock(classOf[JavaExecutionEnv])
+ val javaTableEnv = TableEnvironment.getTableEnvironment(javaEnv)
+ val in2 = javaTableEnv.fromDataStream(jDs).as("a, b, c")
+
+ // test cross apply
+ val func1 = new TableFunc1
+ javaTableEnv.registerFunction("func1", func1)
+ var scalaTable = in1.crossApply(func1('c) as ('s)).select('c, 's)
+ var javaTable = in2.crossApply("func1(c) as (s)").select("c, s")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test outer apply
+ scalaTable = in1.outerApply(func1('c) as ('s)).select('c, 's)
+ javaTable = in2.outerApply("func1(c) as (s)").select("c, s")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test overloading
+ scalaTable = in1.crossApply(func1('c, "$") as ('s)).select('c, 's)
+ javaTable = in2.crossApply("func1(c, '$') as (s)").select("c, s")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test custom result type
+ val func2 = new TableFunc2
+ javaTableEnv.registerFunction("func2", func2)
+ scalaTable = in1.crossApply(func2('c) as ('name, 'len)).select('c, 'name, 'len)
+ javaTable = in2.crossApply("func2(c) as (name, len)").select("c, name, len")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test hierarchy generic type
+ val hierarchy = new HierarchyTableFunction
+ javaTableEnv.registerFunction("hierarchy", hierarchy)
+ scalaTable = in1.crossApply(hierarchy('c) as ('name, 'adult, 'len))
+ .select('c, 'name, 'len, 'adult)
+ javaTable = in2.crossApply("hierarchy(c) as (name, adult, len)")
+ .select("c, name, len, adult")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test pojo type
+ val pojo = new PojoTableFunc
+ javaTableEnv.registerFunction("pojo", pojo)
+ scalaTable = in1.crossApply(pojo('c))
+ .select('c, 'name, 'age)
+ javaTable = in2.crossApply("pojo(c)")
+ .select("c, name, age")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test with filter
+ scalaTable = in1.crossApply(func2('c) as ('name, 'len))
+ .select('c, 'name, 'len).filter('len > 2)
+ javaTable = in2.crossApply("func2(c) as (name, len)")
+ .select("c, name, len").filter("len > 2")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test with scalar function
+ scalaTable = in1.crossApply(func1('c.substring(2)) as ('s))
+ .select('a, 'c, 's)
+ javaTable = in2.crossApply("func1(substring(c, 2)) as (s)")
+ .select("a, c, s")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // check scala object is forbidden
+ expectExceptionThrown(
+ tableEnv.registerFunction("func3", ObjectTableFunction), "Scala object")
+ expectExceptionThrown(
+ javaTableEnv.registerFunction("func3", ObjectTableFunction), "Scala object")
+ expectExceptionThrown(
+ in1.crossApply(ObjectTableFunction('a, 1)), "Scala object")
+
+ }
+
+
+ @Test
+ def testInvalidTableFunction(): Unit = {
+ // mock
+ val util = streamTestUtil()
+ val t = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ val tEnv = TableEnvironment.getTableEnvironment(mock(classOf[JavaExecutionEnv]))
+
+ //=================== check scala object is forbidden =====================
+ // Scala table environment register
+ expectExceptionThrown(util.addFunction("udtf", ObjectTableFunction), "Scala object")
+ // Java table environment register
+ expectExceptionThrown(tEnv.registerFunction("udtf", ObjectTableFunction), "Scala object")
+ // Scala Table API directly call
+ expectExceptionThrown(t.crossApply(ObjectTableFunction('a, 1)), "Scala object")
+
+
+ //============ throw exception when table function is not registered =========
+ // Java Table API call
+ expectExceptionThrown(t.crossApply("nonexist(a)"), "Undefined function: NONEXIST")
+ // SQL API call
+ expectExceptionThrown(
+ util.tEnv.sql("SELECT * FROM MyTable, LATERAL TABLE(nonexist(a))"),
+ "No match found for function signature nonexist(<NUMERIC>)")
+
+
+ //========= throw exception when the called function is a scalar function ====
+ util.addFunction("func0", Func0)
+ // Java Table API call
+ expectExceptionThrown(
+ t.crossApply("func0(a)"),
+ "only accept TableFunction",
+ classOf[TableException])
+ // SQL API call
+ // NOTE: this throws an AssertionError rather than a ValidationException; possibly a Calcite bug
+ expectExceptionThrown(
+ util.tEnv.sql("SELECT * FROM MyTable, LATERAL TABLE(func0(a))"),
+ null,
+ classOf[AssertionError])
+
+ //========== throw exception when the parameters is not correct ===============
+ // Java Table API call
+ util.addFunction("func2", new TableFunc2)
+ expectExceptionThrown(
+ t.crossApply("func2(c, c)"),
+ "Given parameters of function 'FUNC2' do not match any signature")
+ // SQL API call
+ expectExceptionThrown(
+ util.tEnv.sql("SELECT * FROM MyTable, LATERAL TABLE(func2(c, c))"),
+ "No match found for function signature func2(<CHARACTER>, <CHARACTER>)")
+ }
+
+ private def expectExceptionThrown(
+ function: => Unit,
+ keywords: String,
+ clazz: Class[_ <: Throwable] = classOf[ValidationException])
+ : Unit = {
+ try {
+ function
+ fail(s"Expected a $clazz, but no exception is thrown.")
+ } catch {
+ case e if e.getClass == clazz =>
+ if (keywords != null) {
+ assertTrue(
+ s"The exception message '${e.getMessage}' doesn't contain keyword '$keywords'",
+ e.getMessage.contains(keywords))
+ }
+ case e: Throwable => fail(s"Expected ${clazz.getSimpleName} to be thrown, but got $e.")
+ }
+ }
+
+ @Test
+ def testSQLWithCrossApply(): Unit = {
+ val util = streamTestUtil()
+ val func1 = new TableFunc1
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func1", func1)
+
+ val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c)) AS T(s)"
+
+ val expected = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", "func1($cor0.c)"),
+ term("function", func1.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS s")
+ )
+
+ util.verifySql(sqlQuery, expected)
+
+ // test overloading
+
+ val sqlQuery2 = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c, '$')) AS T(s)"
+
+ val expected2 = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", "func1($cor0.c, '$')"),
+ term("function", func1.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS s")
+ )
+
+ util.verifySql(sqlQuery2, expected2)
+ }
+
+ @Test
+ def testSQLWithOuterApply(): Unit = {
+ val util = streamTestUtil()
+ val func1 = new TableFunc1
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func1", func1)
+
+ val sqlQuery = "SELECT c, s FROM MyTable LEFT JOIN LATERAL TABLE(func1(c)) AS T(s) ON TRUE"
+
+ val expected = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", "func1($cor0.c)"),
+ term("function", func1.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+ term("joinType", "LEFT")
+ ),
+ term("select", "c", "f0 AS s")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testSQLWithCustomType(): Unit = {
+ val util = streamTestUtil()
+ val func2 = new TableFunc2
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func2", func2)
+
+ val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len)"
+
+ val expected = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", "func2($cor0.c)"),
+ term("function", func2.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
+ "VARCHAR(2147483647) f0, INTEGER f1)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS name", "f1 AS len")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testSQLWithHierarchyType(): Unit = {
+ val util = streamTestUtil()
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ val function = new HierarchyTableFunction
+ util.addFunction("hierarchy", function)
+
+ val sqlQuery = "SELECT c, T.* FROM MyTable, LATERAL TABLE(hierarchy(c)) AS T(name, adult, len)"
+
+ val expected = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", "hierarchy($cor0.c)"),
+ term("function", function.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," +
+ " VARCHAR(2147483647) f0, BOOLEAN f1, INTEGER f2)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS name", "f1 AS adult", "f2 AS len")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testSQLWithPojoType(): Unit = {
+ val util = streamTestUtil()
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ val function = new PojoTableFunc
+ util.addFunction("pojo", function)
+
+ val sqlQuery = "SELECT c, name, age FROM MyTable, LATERAL TABLE(pojo(c))"
+
+ val expected = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", "pojo($cor0.c)"),
+ term("function", function.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," +
+ " INTEGER age, VARCHAR(2147483647) name)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "name", "age")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testSQLWithFilter(): Unit = {
+ val util = streamTestUtil()
+ val func2 = new TableFunc2
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func2", func2)
+
+ val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len) " +
+ "WHERE len > 2"
+
+ val expected = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", "func2($cor0.c)"),
+ term("function", func2.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
+ "VARCHAR(2147483647) f0, INTEGER f1)"),
+ term("joinType", "INNER"),
+ term("condition", ">($1, 2)")
+ ),
+ term("select", "c", "f0 AS name", "f1 AS len")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+
+ @Test
+ def testSQLWithScalarFunction(): Unit = {
+ val util = streamTestUtil()
+ val func1 = new TableFunc1
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func1", func1)
+
+ val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(SUBSTRING(c, 2))) AS T(s)"
+
+ val expected = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", "func1(SUBSTRING($cor0.c, 2))"),
+ term("function", func1.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS s")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/UserDefinedScalarFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/UserDefinedScalarFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/UserDefinedScalarFunctionTest.scala
index 95cb331..ffe3cd3 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/UserDefinedScalarFunctionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/UserDefinedScalarFunctionTest.scala
@@ -24,7 +24,7 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.scala.table._
import org.apache.flink.api.table.expressions.utils._
-import org.apache.flink.api.table.functions.UserDefinedFunction
+import org.apache.flink.api.table.functions.ScalarFunction
import org.apache.flink.api.table.typeutils.RowTypeInfo
import org.apache.flink.api.table.{Row, Types}
import org.junit.Test
@@ -208,7 +208,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
)).asInstanceOf[TypeInformation[Any]]
}
- override def functions: Map[String, UserDefinedFunction] = Map(
+ override def functions: Map[String, ScalarFunction] = Map(
"Func0" -> Func0,
"Func1" -> Func1,
"Func2" -> Func2,
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionTestBase.scala
index 84b61da..958fd25 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionTestBase.scala
@@ -30,7 +30,7 @@ import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment}
import org.apache.flink.api.table._
import org.apache.flink.api.table.codegen.{CodeGenerator, Compiler, GeneratedFunction}
import org.apache.flink.api.table.expressions.{Expression, ExpressionParser}
-import org.apache.flink.api.table.functions.UserDefinedFunction
+import org.apache.flink.api.table.functions.ScalarFunction
import org.apache.flink.api.table.plan.nodes.dataset.{DataSetCalc, DataSetConvention}
import org.apache.flink.api.table.plan.rules.FlinkRuleSets
import org.apache.flink.api.table.typeutils.RowTypeInfo
@@ -79,7 +79,7 @@ abstract class ExpressionTestBase {
def typeInfo: TypeInformation[Any]
- def functions: Map[String, UserDefinedFunction] = Map()
+ def functions: Map[String, ScalarFunction] = Map()
@Before
def resetTestExprs() = {
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/UserDefinedTableFunctions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/UserDefinedTableFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/UserDefinedTableFunctions.scala
new file mode 100644
index 0000000..1e6bdb8
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/UserDefinedTableFunctions.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.expressions.utils
+
+import java.lang.Boolean
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.tuple.Tuple3
+import org.apache.flink.api.table.Row
+import org.apache.flink.api.table.functions.TableFunction
+import org.apache.flink.api.table.typeutils.RowTypeInfo
+
+
+case class SimpleUser(name: String, age: Int)
+
+class TableFunc0 extends TableFunction[SimpleUser] {
+ // expects input elements in the format "<string>#<int>"
+ def eval(user: String): Unit = {
+ if (user.contains("#")) {
+ val splits = user.split("#")
+ collect(SimpleUser(splits(0), splits(1).toInt))
+ }
+ }
+}
+
+class TableFunc1 extends TableFunction[String] {
+ def eval(str: String): Unit = {
+ if (str.contains("#")){
+ str.split("#").foreach(collect)
+ }
+ }
+
+ def eval(str: String, prefix: String): Unit = {
+ if (str.contains("#")) {
+ str.split("#").foreach(s => collect(prefix + s))
+ }
+ }
+}
+
+
+class TableFunc2 extends TableFunction[Row] {
+ def eval(str: String): Unit = {
+ if (str.contains("#")) {
+ str.split("#").foreach({ s =>
+ val row = new Row(2)
+ row.setField(0, s)
+ row.setField(1, s.length)
+ collect(row)
+ })
+ }
+ }
+
+ override def getResultType: TypeInformation[Row] = {
+ new RowTypeInfo(Seq(BasicTypeInfo.STRING_TYPE_INFO,
+ BasicTypeInfo.INT_TYPE_INFO))
+ }
+}
+
+class HierarchyTableFunction extends SplittableTableFunction[Boolean, Integer] {
+ def eval(user: String) {
+ if (user.contains("#")) {
+ val splits = user.split("#")
+ val age = splits(1).toInt
+ collect(new Tuple3[String, Boolean, Integer](splits(0), age >= 20, age))
+ }
+ }
+}
+
+abstract class SplittableTableFunction[A, B] extends TableFunction[Tuple3[String, A, B]] {}
+
+class PojoTableFunc extends TableFunction[PojoUser] {
+ def eval(user: String) {
+ if (user.contains("#")) {
+ val splits = user.split("#")
+ collect(new PojoUser(splits(0), splits(1).toInt))
+ }
+ }
+}
+
+class PojoUser() {
+ var name: String = _
+ var age: Int = 0
+
+ def this(name: String, age: Int) {
+ this()
+ this.name = name
+ this.age = age
+ }
+}
+
+// ----------------------------------------------------------------------------------------------
+// Invalid Table Functions
+// ----------------------------------------------------------------------------------------------
+
+
+// this is used to check whether scala object is forbidden
+object ObjectTableFunction extends TableFunction[Integer] {
+ def eval(a: Int, b: Int): Unit = {
+ collect(a)
+ collect(b)
+ }
+}
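A short sketch of how two of these helpers surface in SQL, mirroring testSQLWithCustomType and testSQLWithPojoType above (tEnv and a registered MyTable with a String column c are assumed):

    // TableFunc2 overrides getResultType, so its Row output is exposed as two
    // typed columns (VARCHAR f0, INTEGER f1) that the alias T(name, len) renames
    tEnv.registerFunction("func2", new TableFunc2)
    tEnv.sql("SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len)")

    // PojoTableFunc needs no column alias: the field names (name, age)
    // are extracted from the PojoUser class
    tEnv.registerFunction("pojo", new PojoTableFunc)
    tEnv.sql("SELECT c, name, age FROM MyTable, LATERAL TABLE(pojo(c))")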
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
index 539bb61..73f50f5 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
@@ -24,6 +24,7 @@ import org.apache.flink.api.java.{DataSet => JDataSet}
import org.apache.flink.api.scala.table._
import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment}
import org.apache.flink.api.table.expressions.Expression
+import org.apache.flink.api.table.functions.{ScalarFunction, TableFunction}
import org.apache.flink.api.table.{Table, TableEnvironment}
import org.apache.flink.streaming.api.datastream.{DataStream => JDataStream}
import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
@@ -43,6 +44,12 @@ class TableTestBase {
StreamTableTestUtil()
}
+ def verifyTableEquals(expected: Table, actual: Table): Unit = {
+ assertEquals("Logical Plan do not match",
+ RelOptUtil.toString(expected.getRelNode),
+ RelOptUtil.toString(actual.getRelNode))
+ }
+
}
abstract class TableTestUtil {
@@ -54,6 +61,9 @@ abstract class TableTestUtil {
}
def addTable[T: TypeInformation](name: String, fields: Expression*): Table
+ def addFunction[T: TypeInformation](name: String, function: TableFunction[T]): Unit
+ def addFunction(name: String, function: ScalarFunction): Unit
+
def verifySql(query: String, expected: String): Unit
def verifyTable(resultTable: Table, expected: String): Unit
@@ -119,6 +129,17 @@ case class BatchTableTestUtil() extends TableTestUtil {
t
}
+ def addFunction[T: TypeInformation](
+ name: String,
+ function: TableFunction[T])
+ : Unit = {
+ tEnv.registerFunction(name, function)
+ }
+
+ def addFunction(name: String, function: ScalarFunction): Unit = {
+ tEnv.registerFunction(name, function)
+ }
+
def verifySql(query: String, expected: String): Unit = {
verifyTable(tEnv.sql(query), expected)
}
@@ -164,6 +185,17 @@ case class StreamTableTestUtil() extends TableTestUtil {
t
}
+ def addFunction[T: TypeInformation](
+ name: String,
+ function: TableFunction[T])
+ : Unit = {
+ tEnv.registerFunction(name, function)
+ }
+
+ def addFunction(name: String, function: ScalarFunction): Unit = {
+ tEnv.registerFunction(name, function)
+ }
+
def verifySql(query: String, expected: String): Unit = {
verifyTable(tEnv.sql(query), expected)
}
[2/5] flink git commit: [FLINK-4469] [table] Add support for user defined table function in Table API & SQL
Posted by tw...@apache.org.
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala
index e7416f7..932baeb 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -19,14 +19,18 @@
package org.apache.flink.api.table.functions.utils
+import java.lang.reflect.{Method, Modifier}
import java.sql.{Date, Time, Timestamp}
import com.google.common.primitives.Primitives
+import org.apache.calcite.sql.SqlFunction
import org.apache.flink.api.common.functions.InvalidTypesException
-import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
+import org.apache.flink.api.common.typeutils.CompositeType
import org.apache.flink.api.java.typeutils.TypeExtractor
-import org.apache.flink.api.table.ValidationException
-import org.apache.flink.api.table.functions.{ScalarFunction, UserDefinedFunction}
+import org.apache.flink.api.table.{FlinkTypeFactory, TableException, ValidationException}
+import org.apache.flink.api.table.functions.{ScalarFunction, TableFunction, UserDefinedFunction}
+import org.apache.flink.api.table.plan.schema.FlinkTableFunctionImpl
import org.apache.flink.util.InstantiationUtil
object UserDefinedFunctionUtils {
@@ -62,101 +66,167 @@ object UserDefinedFunctionUtils {
.getOrElse(throw ValidationException("Function class needs a default constructor."))
}
+ /**
+ * Checks whether the given class is a Scala object. Implementing a [[TableFunction]] as
+ * a Scala object is forbidden because the shared singleton instance poses concurrency risks.
+ */
+ def checkNotSingleton(clazz: Class[_]): Unit = {
+ // TODO this way of detecting a singleton is fragile; improve it if possible
+ if (clazz.getFields.map(_.getName) contains "MODULE$") {
+ throw new ValidationException(
+ s"TableFunction implemented by class ${clazz.getCanonicalName} " +
+ s"is a Scala object, it is forbidden since concurrent risks.")
+ }
+ }
+
// ----------------------------------------------------------------------------------------------
- // Utilities for ScalarFunction
+ // Utilities for eval methods
// ----------------------------------------------------------------------------------------------
/**
- * Prints one signature consisting of classes.
+ * Returns signatures matching the given signature of [[TypeInformation]].
+ * Elements of the signature can be null (act as a wildcard).
*/
- def signatureToString(signature: Array[Class[_]]): String =
- "(" + signature.map { clazz =>
- if (clazz == null) {
- "null"
- } else {
- clazz.getCanonicalName
- }
- }.mkString(", ") + ")"
+ def getSignature(
+ function: UserDefinedFunction,
+ signature: Seq[TypeInformation[_]])
+ : Option[Array[Class[_]]] = {
+ // We compare the raw Java classes not the TypeInformation.
+ // TypeInformation does not matter during runtime (e.g. within a MapFunction).
+ val actualSignature = typeInfoToClass(signature)
+ val signatures = getSignatures(function)
+
+ signatures
+ // go over all signatures and find one matching actual signature
+ .find { curSig =>
+ // match parameters of signature to actual parameters
+ actualSignature.length == curSig.length &&
+ curSig.zipWithIndex.forall { case (clazz, i) =>
+ parameterTypeEquals(actualSignature(i), clazz)
+ }
+ }
+ }
/**
- * Prints one signature consisting of TypeInformation.
+ * Returns eval method matching the given signature of [[TypeInformation]].
*/
- def signatureToString(signature: Seq[TypeInformation[_]]): String = {
- signatureToString(typeInfoToClass(signature))
+ def getEvalMethod(
+ function: UserDefinedFunction,
+ signature: Seq[TypeInformation[_]])
+ : Option[Method] = {
+ // We compare the raw Java classes not the TypeInformation.
+ // TypeInformation does not matter during runtime (e.g. within a MapFunction).
+ val actualSignature = typeInfoToClass(signature)
+ val evalMethods = checkAndExtractEvalMethods(function)
+
+ evalMethods
+ // go over all eval methods and find one matching
+ .find { cur =>
+ val signatures = cur.getParameterTypes
+ // match parameters of signature to actual parameters
+ actualSignature.length == signatures.length &&
+ signatures.zipWithIndex.forall { case (clazz, i) =>
+ parameterTypeEquals(actualSignature(i), clazz)
+ }
+ }
}
/**
- * Extracts type classes of [[TypeInformation]] in a null-aware way.
+ * Extracts "eval" methods and throws a [[ValidationException]] if no implementation
+ * can be found.
*/
- def typeInfoToClass(typeInfos: Seq[TypeInformation[_]]): Array[Class[_]] =
- typeInfos.map { typeInfo =>
- if (typeInfo == null) {
- null
- } else {
- typeInfo.getTypeClass
+ def checkAndExtractEvalMethods(function: UserDefinedFunction): Array[Method] = {
+ val methods = function
+ .getClass
+ .getDeclaredMethods
+ .filter { m =>
+ val modifiers = m.getModifiers
+ m.getName == "eval" && Modifier.isPublic(modifiers) && !Modifier.isAbstract(modifiers)
}
- }.toArray
+ if (methods.isEmpty) {
+ throw new ValidationException(
+ s"Function class '${function.getClass.getCanonicalName}' does not implement at least " +
+ s"one method named 'eval' which is public and not abstract.")
+ } else {
+ methods
+ }
+ }
+
+ def getSignatures(function: UserDefinedFunction): Array[Array[Class[_]]] = {
+ checkAndExtractEvalMethods(function).map(_.getParameterTypes)
+ }
+
+ // ----------------------------------------------------------------------------------------------
+ // Utilities for sql functions
+ // ----------------------------------------------------------------------------------------------
/**
- * Compares parameter candidate classes with expected classes. If true, the parameters match.
- * Candidate can be null (acts as a wildcard).
+ * Creates a [[SqlFunction]] for a [[ScalarFunction]].
+ * @param name function name
+ * @param function scalar function
+ * @param typeFactory type factory
+ * @return the ScalarSqlFunction
*/
- def parameterTypeEquals(candidate: Class[_], expected: Class[_]): Boolean =
- candidate == null ||
- candidate == expected ||
- expected.isPrimitive && Primitives.wrap(expected) == candidate ||
- candidate == classOf[Date] && expected == classOf[Int] ||
- candidate == classOf[Time] && expected == classOf[Int] ||
- candidate == classOf[Timestamp] && expected == classOf[Long]
+ def createScalarSqlFunction(
+ name: String,
+ function: ScalarFunction,
+ typeFactory: FlinkTypeFactory)
+ : SqlFunction = {
+ new ScalarSqlFunction(name, function, typeFactory)
+ }
/**
- * Returns signatures matching the given signature of [[TypeInformation]].
- * Elements of the signature can be null (act as a wildcard).
+ * Creates a [[SqlFunction]] for each eval method of a [[TableFunction]].
+ * @param name function name
+ * @param tableFunction table function
+ * @param resultType the type information of returned table
+ * @param typeFactory type factory
+ * @return the TableSqlFunction
*/
- def getSignature(
- scalarFunction: ScalarFunction,
- signature: Seq[TypeInformation[_]])
- : Option[Array[Class[_]]] = {
- // We compare the raw Java classes not the TypeInformation.
- // TypeInformation does not matter during runtime (e.g. within a MapFunction).
- val actualSignature = typeInfoToClass(signature)
+ def createTableSqlFunctions(
+ name: String,
+ tableFunction: TableFunction[_],
+ resultType: TypeInformation[_],
+ typeFactory: FlinkTypeFactory)
+ : Seq[SqlFunction] = {
+ val (fieldNames, fieldIndexes, _) = UserDefinedFunctionUtils.getFieldInfo(resultType)
+ val evalMethods = checkAndExtractEvalMethods(tableFunction)
- scalarFunction
- .getSignatures
- // go over all signatures and find one matching actual signature
- .find { curSig =>
- // match parameters of signature to actual parameters
- actualSignature.length == curSig.length &&
- curSig.zipWithIndex.forall { case (clazz, i) =>
- parameterTypeEquals(actualSignature(i), clazz)
- }
- }
+ evalMethods.map { method =>
+ val function = new FlinkTableFunctionImpl(resultType, fieldIndexes, fieldNames, method)
+ TableSqlFunction(name, tableFunction, resultType, typeFactory, function)
+ }
}
+ // ----------------------------------------------------------------------------------------------
+ // Utilities for scalar functions
+ // ----------------------------------------------------------------------------------------------
+
/**
* Internal method of [[ScalarFunction#getResultType()]] that does some pre-checking and uses
* [[TypeExtractor]] as default return type inference.
*/
def getResultType(
- scalarFunction: ScalarFunction,
+ function: ScalarFunction,
signature: Array[Class[_]])
: TypeInformation[_] = {
// find method for signature
- val evalMethod = scalarFunction.getEvalMethods
+ val evalMethod = checkAndExtractEvalMethods(function)
.find(m => signature.sameElements(m.getParameterTypes))
.getOrElse(throw new ValidationException("Given signature is invalid."))
- val userDefinedTypeInfo = scalarFunction.getResultType(signature)
+ val userDefinedTypeInfo = function.getResultType(signature)
if (userDefinedTypeInfo != null) {
- userDefinedTypeInfo
+ userDefinedTypeInfo
} else {
try {
TypeExtractor.getForClass(evalMethod.getReturnType)
} catch {
case ite: InvalidTypesException =>
- throw new ValidationException(s"Return type of scalar function '$this' cannot be " +
- s"automatically determined. Please provide type information manually.")
+ throw new ValidationException(
+ s"Return type of scalar function '${function.getClass.getCanonicalName}' cannot be " +
+ s"automatically determined. Please provide type information manually.")
}
}
}
@@ -165,21 +235,100 @@ object UserDefinedFunctionUtils {
* Returns the return type of the evaluation method matching the given signature.
*/
def getResultTypeClass(
- scalarFunction: ScalarFunction,
+ function: ScalarFunction,
signature: Array[Class[_]])
: Class[_] = {
// find method for signature
- val evalMethod = scalarFunction.getEvalMethods
+ val evalMethod = checkAndExtractEvalMethods(function)
.find(m => signature.sameElements(m.getParameterTypes))
.getOrElse(throw new IllegalArgumentException("Given signature is invalid."))
evalMethod.getReturnType
}
+ // ----------------------------------------------------------------------------------------------
+ // Miscellaneous
+ // ----------------------------------------------------------------------------------------------
+
/**
- * Prints all signatures of a [[ScalarFunction]].
+ * Returns field names, field indexes, and field types for a given [[TypeInformation]].
+ *
+ * Field names are automatically extracted for
+ * [[org.apache.flink.api.common.typeutils.CompositeType]].
+ *
+ * @param inputType The TypeInformation to extract the field names and positions from.
+ * @return A tuple of three arrays holding the field names, field indexes, and field types.
*/
- def signaturesToString(scalarFunction: ScalarFunction): String = {
- scalarFunction.getSignatures.map(signatureToString).mkString(", ")
+ def getFieldInfo(inputType: TypeInformation[_])
+ : (Array[String], Array[Int], Array[TypeInformation[_]]) = {
+
+ val fieldNames: Array[String] = inputType match {
+ case t: CompositeType[_] => t.getFieldNames
+ case a: AtomicType[_] => Array("f0")
+ case tpe =>
+ throw new TableException(s"Currently only support CompositeType and AtomicType. " +
+ s"Type $tpe lacks explicit field naming")
+ }
+ val fieldIndexes = fieldNames.indices.toArray
+ val fieldTypes: Array[TypeInformation[_]] = fieldNames.map { i =>
+ inputType match {
+ case t: CompositeType[_] => t.getTypeAt(i).asInstanceOf[TypeInformation[_]]
+ case a: AtomicType[_] => a.asInstanceOf[TypeInformation[_]]
+ case tpe =>
+ throw new TableException(s"Currently only support CompositeType and AtomicType.")
+ }
+ }
+ (fieldNames, fieldIndexes, fieldTypes)
}
+ /**
+ * Prints one signature consisting of classes.
+ */
+ def signatureToString(signature: Array[Class[_]]): String =
+ signature.map { clazz =>
+ if (clazz == null) {
+ "null"
+ } else {
+ clazz.getCanonicalName
+ }
+ }.mkString("(", ", ", ")")
+
+ /**
+ * Prints one signature consisting of TypeInformation.
+ */
+ def signatureToString(signature: Seq[TypeInformation[_]]): String = {
+ signatureToString(typeInfoToClass(signature))
+ }
+
+ /**
+ * Prints all eval methods signatures of a class.
+ */
+ def signaturesToString(function: UserDefinedFunction): String = {
+ getSignatures(function).map(signatureToString).mkString(", ")
+ }
+
+ /**
+ * Extracts type classes of [[TypeInformation]] in a null-aware way.
+ */
+ private def typeInfoToClass(typeInfos: Seq[TypeInformation[_]]): Array[Class[_]] =
+ typeInfos.map { typeInfo =>
+ if (typeInfo == null) {
+ null
+ } else {
+ typeInfo.getTypeClass
+ }
+ }.toArray
+
+
+ /**
+ * Compares parameter candidate classes with expected classes. If true, the parameters match.
+ * Candidate can be null (acts as a wildcard).
+ */
+ private def parameterTypeEquals(candidate: Class[_], expected: Class[_]): Boolean =
+ candidate == null ||
+ candidate == expected ||
+ expected.isPrimitive && Primitives.wrap(expected) == candidate ||
+ candidate == classOf[Date] && expected == classOf[Int] ||
+ candidate == classOf[Time] && expected == classOf[Int] ||
+ candidate == classOf[Timestamp] && expected == classOf[Long]
+
}
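To illustrate the matching rules of getEvalMethod and parameterTypeEquals, a sketch with a hypothetical MyFunc (boxed argument classes match primitive parameter types, and a null TypeInformation element acts as a wildcard):

    import org.apache.flink.api.common.typeinfo.BasicTypeInfo
    import org.apache.flink.api.table.functions.TableFunction
    import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils

    // hypothetical example class with overloaded eval methods
    class MyFunc extends TableFunction[String] {
      def eval(s: String): Unit = collect(s)
      def eval(s: String, n: Int): Unit = collect(s * n)
    }

    // resolves to eval(String, int): INT_TYPE_INFO carries java.lang.Integer,
    // which Primitives.wrap() reconciles with the primitive int parameter
    val method = UserDefinedFunctionUtils.getEvalMethod(
      new MyFunc,
      Seq(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO))
    // method: Option[java.lang.reflect.Method], Some(...) on a match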
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
index cd22f6a..f6ddeef 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
@@ -122,10 +122,10 @@ object ProjectionTranslator {
case prop: WindowProperty =>
val name = propNames(prop)
Alias(UnresolvedFieldReference(name), tableEnv.createUniqueAttributeName())
- case n @ Alias(agg: Aggregation, name) =>
+ case n @ Alias(agg: Aggregation, name, _) =>
val aName = aggNames(agg)
Alias(UnresolvedFieldReference(aName), name)
- case n @ Alias(prop: WindowProperty, name) =>
+ case n @ Alias(prop: WindowProperty, name, _) =>
val pName = propNames(prop)
Alias(UnresolvedFieldReference(pName), name)
case l: LeafExpression => l
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala
index ecf1996..4dc2ab7 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala
@@ -17,9 +17,13 @@
*/
package org.apache.flink.api.table.plan.logical
+import java.lang.reflect.Method
+import java.util
+
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.RelDataType
-import org.apache.calcite.rel.logical.LogicalProject
+import org.apache.calcite.rel.core.CorrelationId
+import org.apache.calcite.rel.logical.{LogicalProject, LogicalTableFunctionScan}
import org.apache.calcite.rex.{RexInputRef, RexNode}
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
@@ -27,6 +31,10 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.operators.join.JoinType
import org.apache.flink.api.table._
import org.apache.flink.api.table.expressions._
+import org.apache.flink.api.table.functions.TableFunction
+import org.apache.flink.api.table.functions.utils.TableSqlFunction
+import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils._
+import org.apache.flink.api.table.plan.schema.FlinkTableFunctionImpl
import org.apache.flink.api.table.typeutils.TypeConverter
import org.apache.flink.api.table.validate.{ValidationFailure, ValidationSuccess}
@@ -216,7 +224,7 @@ case class Aggregate(
relBuilder.aggregate(
relBuilder.groupKey(groupingExpressions.map(_.toRexNode(relBuilder)).asJava),
aggregateExpressions.map {
- case Alias(agg: Aggregation, name) => agg.toAggCall(name)(relBuilder)
+ case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder)
case _ => throw new RuntimeException("This should never happen.")
}.asJava)
}
@@ -361,7 +369,8 @@ case class Join(
left: LogicalNode,
right: LogicalNode,
joinType: JoinType,
- condition: Option[Expression]) extends BinaryNode {
+ condition: Option[Expression],
+ correlated: Boolean) extends BinaryNode {
override def output: Seq[Attribute] = {
left.output ++ right.output
@@ -411,22 +420,31 @@ case class Join(
right)
}
val resolvedCondition = node.condition.map(_.postOrderTransform(partialFunction))
- Join(node.left, node.right, node.joinType, resolvedCondition)
+ Join(node.left, node.right, node.joinType, resolvedCondition, correlated)
}
override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
left.construct(relBuilder)
right.construct(relBuilder)
+
+ val corSet = mutable.Set[CorrelationId]()
+
+ if (correlated) {
+ corSet += relBuilder.peek().getCluster.createCorrel()
+ }
+
relBuilder.join(
TypeConverter.flinkJoinTypeToRelType(joinType),
- condition.map(_.toRexNode(relBuilder)).getOrElse(relBuilder.literal(true)))
+ condition.map(_.toRexNode(relBuilder)).getOrElse(relBuilder.literal(true)),
+ corSet.asJava)
}
private def ambiguousName: Set[String] =
left.output.map(_.name).toSet.intersect(right.output.map(_.name).toSet)
override def validate(tableEnv: TableEnvironment): LogicalNode = {
- if (tableEnv.isInstanceOf[StreamTableEnvironment]) {
+ if (tableEnv.isInstanceOf[StreamTableEnvironment]
+ && !right.isInstanceOf[LogicalTableFunctionCall]) {
failValidation(s"Join on stream tables is currently not supported.")
}
@@ -551,11 +569,11 @@ case class WindowAggregate(
window,
relBuilder.groupKey(groupingExpressions.map(_.toRexNode(relBuilder)).asJava),
propertyExpressions.map {
- case Alias(prop: WindowProperty, name) => prop.toNamedWindowProperty(name)(relBuilder)
+ case Alias(prop: WindowProperty, name, _) => prop.toNamedWindowProperty(name)(relBuilder)
case _ => throw new RuntimeException("This should never happen.")
},
aggregateExpressions.map {
- case Alias(agg: Aggregation, name) => agg.toAggCall(name)(relBuilder)
+ case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder)
case _ => throw new RuntimeException("This should never happen.")
}.asJava)
}
@@ -605,3 +623,71 @@ case class WindowAggregate(
resolvedWindowAggregate
}
}
+
+
+/**
+ * LogicalNode for calling a user-defined table function.
+ * @param functionName function name
+ * @param tableFunction table function to be called (might be overloaded)
+ * @param parameters actual parameters
+ * @param fieldNames output field names
+ * @param child child logical node
+ */
+case class LogicalTableFunctionCall(
+ functionName: String,
+ tableFunction: TableFunction[_],
+ parameters: Seq[Expression],
+ resultType: TypeInformation[_],
+ fieldNames: Array[String],
+ child: LogicalNode)
+ extends UnaryNode {
+
+ val (_, fieldIndexes, fieldTypes) = getFieldInfo(resultType)
+ var evalMethod: Method = _
+
+ override def output: Seq[Attribute] = fieldNames.zip(fieldTypes).map {
+ case (n, t) => ResolvedFieldReference(n, t)
+ }
+
+ override def validate(tableEnv: TableEnvironment): LogicalNode = {
+ val node = super.validate(tableEnv).asInstanceOf[LogicalTableFunctionCall]
+ // check not Scala object
+ checkNotSingleton(tableFunction.getClass)
+ // check could be instantiated
+ checkForInstantiation(tableFunction.getClass)
+ // look for a signature that matches the input types
+ val signature = node.parameters.map(_.resultType)
+ val foundMethod = getEvalMethod(tableFunction, signature)
+ if (foundMethod.isEmpty) {
+ failValidation(
+ s"Given parameters of function '$functionName' do not match any signature. \n" +
+ s"Actual: ${signatureToString(signature)} \n" +
+ s"Expected: ${signaturesToString(tableFunction)}")
+ } else {
+ node.evalMethod = foundMethod.get
+ }
+ node
+ }
+
+ override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
+ val fieldIndexes = getFieldInfo(resultType)._2
+ val function = new FlinkTableFunctionImpl(resultType, fieldIndexes, fieldNames, evalMethod)
+ val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
+ val sqlFunction = TableSqlFunction(
+ tableFunction.toString,
+ tableFunction,
+ resultType,
+ typeFactory,
+ function)
+
+ val scan = LogicalTableFunctionScan.create(
+ relBuilder.peek().getCluster,
+ new util.ArrayList[RelNode](),
+ relBuilder.call(sqlFunction, parameters.map(_.toRexNode(relBuilder)).asJava),
+ function.getElementType(null),
+ function.getRowType(relBuilder.getTypeFactory, null),
+ null)
+
+ relBuilder.push(scan)
+ }
+}
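For orientation, a Table API call such as in1.crossApply(func1('c) as ('s)) builds roughly the following node before validate() resolves the eval method (field values are illustrative, and the child accessor is an assumption):

    LogicalTableFunctionCall(
      functionName = "func1",                          // illustrative name
      tableFunction = func1,                           // the TableFunction[String] instance
      parameters = Seq(UnresolvedFieldReference("c")),
      resultType = BasicTypeInfo.STRING_TYPE_INFO,
      fieldNames = Array("s"),
      child = in1.logicalPlan)                         // assumed accessor on Table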
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala
new file mode 100644
index 0000000..9745be1
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.nodes
+
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rex.{RexCall, RexNode}
+import org.apache.calcite.sql.SemiJoinType
+import org.apache.flink.api.common.functions.FlatMapFunction
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.codegen.{CodeGenerator, GeneratedExpression, GeneratedFunction}
+import org.apache.flink.api.table.codegen.CodeGenUtils.primitiveDefaultValue
+import org.apache.flink.api.table.functions.utils.TableSqlFunction
+import org.apache.flink.api.table.runtime.FlatMapRunner
+import org.apache.flink.api.table.typeutils.TypeConverter._
+import org.apache.flink.api.table.{TableConfig, TableException}
+
+import scala.collection.JavaConverters._
+
+/**
+ * Common logic for cross/outer applying a user-defined table function.
+ */
+trait FlinkCorrelate {
+
+ private[flink] def functionBody(
+ generator: CodeGenerator,
+ udtfTypeInfo: TypeInformation[Any],
+ rowType: RelDataType,
+ rexCall: RexCall,
+ condition: Option[RexNode],
+ config: TableConfig,
+ joinType: SemiJoinType,
+ expectedType: Option[TypeInformation[Any]]): String = {
+
+ val returnType = determineReturnType(
+ rowType,
+ expectedType,
+ config.getNullCheck,
+ config.getEfficientTypeUsage)
+
+ val (input1AccessExprs, input2AccessExprs) = generator.generateCorrelateAccessExprs
+
+ val call = generator.generateExpression(rexCall)
+ var body =
+ s"""
+ |${call.code}
+ |java.util.Iterator iter = ${call.resultTerm}.getRowsIterator();
+ """.stripMargin
+
+ if (joinType == SemiJoinType.INNER) {
+ // cross apply
+ body +=
+ s"""
+ |if (!iter.hasNext()) {
+ | return;
+ |}
+ """.stripMargin
+ } else if (joinType == SemiJoinType.LEFT) {
+ // outer apply
+
+ // in case of an outer apply, if the table function returns no rows,
+ // emit the input row with all table function fields set to null
+ val input2NullExprs = input2AccessExprs.map { x =>
+ GeneratedExpression(
+ primitiveDefaultValue(x.resultType),
+ GeneratedExpression.ALWAYS_NULL,
+ "",
+ x.resultType)
+ }
+ val outerResultExpr = generator.generateResultExpression(
+ input1AccessExprs ++ input2NullExprs, returnType, rowType.getFieldNames.asScala)
+ body +=
+ s"""
+ |if (!iter.hasNext()) {
+ | ${outerResultExpr.code}
+ | ${generator.collectorTerm}.collect(${outerResultExpr.resultTerm});
+ | return;
+ |}
+ """.stripMargin
+ } else {
+ throw TableException(s"Unsupported SemiJoinType: $joinType for correlate join.")
+ }
+
+ val crossResultExpr = generator.generateResultExpression(
+ input1AccessExprs ++ input2AccessExprs,
+ returnType,
+ rowType.getFieldNames.asScala)
+
+ val projection = if (condition.isEmpty) {
+ s"""
+ |${crossResultExpr.code}
+ |${generator.collectorTerm}.collect(${crossResultExpr.resultTerm});
+ """.stripMargin
+ } else {
+ val filterGenerator = new CodeGenerator(config, false, udtfTypeInfo)
+ filterGenerator.input1Term = filterGenerator.input2Term
+ val filterCondition = filterGenerator.generateExpression(condition.get)
+ s"""
+ |${filterGenerator.reuseInputUnboxingCode()}
+ |${filterCondition.code}
+ |if (${filterCondition.resultTerm}) {
+ | ${crossResultExpr.code}
+ | ${generator.collectorTerm}.collect(${crossResultExpr.resultTerm});
+ |}
+ |""".stripMargin
+ }
+
+ val outputTypeClass = udtfTypeInfo.getTypeClass.getCanonicalName
+ body +=
+ s"""
+ |while (iter.hasNext()) {
+ | $outputTypeClass ${generator.input2Term} = ($outputTypeClass) iter.next();
+ | $projection
+ |}
+ """.stripMargin
+ body
+ }
+
+ private[flink] def correlateMapFunction(
+ genFunction: GeneratedFunction[FlatMapFunction[Any, Any]])
+ : FlatMapRunner[Any, Any] = {
+
+ new FlatMapRunner[Any, Any](
+ genFunction.name,
+ genFunction.code,
+ genFunction.returnType)
+ }
+
+ private[flink] def selectToString(rowType: RelDataType): String = {
+ rowType.getFieldNames.asScala.mkString(",")
+ }
+
+ private[flink] def correlateOpName(
+ rexCall: RexCall,
+ sqlFunction: TableSqlFunction,
+ rowType: RelDataType)
+ : String = {
+
+ s"correlate: ${correlateToString(rexCall, sqlFunction)}, select: ${selectToString(rowType)}"
+ }
+
+ private[flink] def correlateToString(rexCall: RexCall, sqlFunction: TableSqlFunction): String = {
+ val udtfName = sqlFunction.getName
+ val operands = rexCall.getOperands.asScala.map(_.toString).mkString(",")
+ s"table($udtfName($operands))"
+ }
+
+}
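For reference, the following hand-written FlatMapFunction is a minimal sketch of the semantics that the generated body above implements; the CorrelateSketch name and the split-on-'#' stand-in for the UDTF invocation are illustrative assumptions, not part of this commit:

import org.apache.flink.api.common.functions.FlatMapFunction
import org.apache.flink.util.Collector

// leftOuter = true mirrors SemiJoinType.LEFT (outer apply); false mirrors INNER.
class CorrelateSketch(leftOuter: Boolean) extends FlatMapFunction[String, (String, String)] {
  override def flatMap(in: String, out: Collector[(String, String)]): Unit = {
    val iter = in.split("#").iterator // stands in for the generated UDTF invocation
    if (!iter.hasNext) {
      if (leftOuter) {
        out.collect((in, null)) // outer apply pads the UDTF side with nulls
      }
      return // cross apply emits nothing when the UDTF result is empty
    }
    while (iter.hasNext) {
      out.collect((in, iter.next())) // cross result, optionally guarded by the condition
    }
  }
}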
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala
new file mode 100644
index 0000000..4aa7fea
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.nodes.dataset
+
+import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet}
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.logical.LogicalTableFunctionScan
+import org.apache.calcite.rel.metadata.RelMetadataQuery
+import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
+import org.apache.calcite.rex.{RexNode, RexCall}
+import org.apache.calcite.sql.SemiJoinType
+import org.apache.flink.api.common.functions.FlatMapFunction
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.DataSet
+import org.apache.flink.api.table.BatchTableEnvironment
+import org.apache.flink.api.table.codegen.CodeGenerator
+import org.apache.flink.api.table.functions.utils.TableSqlFunction
+import org.apache.flink.api.table.plan.nodes.FlinkCorrelate
+import org.apache.flink.api.table.typeutils.TypeConverter._
+
+/**
+ * Flink RelNode that executes a cross/outer apply of a user-defined
+ * table function on a DataSet.
+ */
+class DataSetCorrelate(
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ inputNode: RelNode,
+ scan: LogicalTableFunctionScan,
+ condition: Option[RexNode],
+ relRowType: RelDataType,
+ joinRowType: RelDataType,
+ joinType: SemiJoinType,
+ ruleDescription: String)
+ extends SingleRel(cluster, traitSet, inputNode)
+ with FlinkCorrelate
+ with DataSetRel {
+
+ override def deriveRowType() = relRowType
+
+ override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
+ val rowCnt = metadata.getRowCount(getInput) * 1.5
+ planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * 0.5)
+ }
+
+ override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
+ new DataSetCorrelate(
+ cluster,
+ traitSet,
+ inputs.get(0),
+ scan,
+ condition,
+ relRowType,
+ joinRowType,
+ joinType,
+ ruleDescription)
+ }
+
+ override def toString: String = {
+ val rexCall = scan.getCall.asInstanceOf[RexCall]
+ val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+ correlateToString(rexCall, sqlFunction)
+ }
+
+ override def explainTerms(pw: RelWriter): RelWriter = {
+ val rexCall = scan.getCall.asInstanceOf[RexCall]
+ val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+ super.explainTerms(pw)
+ .item("invocation", scan.getCall)
+ .item("function", sqlFunction.getTableFunction.getClass.getCanonicalName)
+ .item("rowType", relRowType)
+ .item("joinType", joinType)
+ .itemIf("condition", condition.orNull, condition.isDefined)
+ }
+
+ override def translateToPlan(
+ tableEnv: BatchTableEnvironment,
+ expectedType: Option[TypeInformation[Any]])
+ : DataSet[Any] = {
+
+ val config = tableEnv.getConfig
+ val returnType = determineReturnType(
+ getRowType,
+ expectedType,
+ config.getNullCheck,
+ config.getEfficientTypeUsage)
+
+ // do not need to specify input type
+ val inputDS = inputNode.asInstanceOf[DataSetRel].translateToPlan(tableEnv)
+
+ val funcRel = scan.asInstanceOf[LogicalTableFunctionScan]
+ val rexCall = funcRel.getCall.asInstanceOf[RexCall]
+ val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+ val pojoFieldMapping = sqlFunction.getPojoFieldMapping
+ val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]]
+
+ val generator = new CodeGenerator(
+ config,
+ false,
+ inputDS.getType,
+ Some(udtfTypeInfo),
+ None,
+ Some(pojoFieldMapping))
+
+ val body = functionBody(
+ generator,
+ udtfTypeInfo,
+ getRowType,
+ rexCall,
+ condition,
+ config,
+ joinType,
+ expectedType)
+
+ val genFunction = generator.generateFunction(
+ ruleDescription,
+ classOf[FlatMapFunction[Any, Any]],
+ body,
+ returnType)
+
+ val mapFunc = correlateMapFunction(genFunction)
+
+ inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, relRowType))
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamCorrelate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamCorrelate.scala
new file mode 100644
index 0000000..b0bc48a
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamCorrelate.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.nodes.datastream
+
+import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.logical.LogicalTableFunctionScan
+import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
+import org.apache.calcite.rex.{RexCall, RexNode}
+import org.apache.calcite.sql.SemiJoinType
+import org.apache.flink.api.common.functions.FlatMapFunction
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.StreamTableEnvironment
+import org.apache.flink.api.table.codegen.CodeGenerator
+import org.apache.flink.api.table.functions.utils.TableSqlFunction
+import org.apache.flink.api.table.plan.nodes.FlinkCorrelate
+import org.apache.flink.api.table.typeutils.TypeConverter._
+import org.apache.flink.streaming.api.datastream.DataStream
+
+/**
+ * Flink RelNode that executes a cross/outer apply of a user-defined
+ * table function on a DataStream.
+ */
+class DataStreamCorrelate(
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ inputNode: RelNode,
+ scan: LogicalTableFunctionScan,
+ condition: Option[RexNode],
+ relRowType: RelDataType,
+ joinRowType: RelDataType,
+ joinType: SemiJoinType,
+ ruleDescription: String)
+ extends SingleRel(cluster, traitSet, inputNode)
+ with FlinkCorrelate
+ with DataStreamRel {
+ override def deriveRowType() = relRowType
+
+ override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
+ new DataStreamCorrelate(
+ cluster,
+ traitSet,
+ inputs.get(0),
+ scan,
+ condition,
+ relRowType,
+ joinRowType,
+ joinType,
+ ruleDescription)
+ }
+
+ override def toString: String = {
+ val rexCall = scan.getCall.asInstanceOf[RexCall]
+ val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+ correlateToString(rexCall, sqlFunction)
+ }
+
+ override def explainTerms(pw: RelWriter): RelWriter = {
+ val rexCall = scan.getCall.asInstanceOf[RexCall]
+ val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+ super.explainTerms(pw)
+ .item("invocation", scan.getCall)
+ .item("function", sqlFunction.getTableFunction.getClass.getCanonicalName)
+ .item("rowType", relRowType)
+ .item("joinType", joinType)
+ .itemIf("condition", condition.orNull, condition.isDefined)
+ }
+
+ override def translateToPlan(
+ tableEnv: StreamTableEnvironment,
+ expectedType: Option[TypeInformation[Any]])
+ : DataStream[Any] = {
+
+ val config = tableEnv.getConfig
+ val returnType = determineReturnType(
+ getRowType,
+ expectedType,
+ config.getNullCheck,
+ config.getEfficientTypeUsage)
+
+ // do not need to specify input type
+ val inputDS = inputNode.asInstanceOf[DataStreamRel].translateToPlan(tableEnv)
+
+ val funcRel = scan.asInstanceOf[LogicalTableFunctionScan]
+ val rexCall = funcRel.getCall.asInstanceOf[RexCall]
+ val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+ val pojoFieldMapping = sqlFunction.getPojoFieldMapping
+ val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]]
+
+ val generator = new CodeGenerator(
+ config,
+ false,
+ inputDS.getType,
+ Some(udtfTypeInfo),
+ None,
+ Some(pojoFieldMapping))
+
+ val body = functionBody(
+ generator,
+ udtfTypeInfo,
+ getRowType,
+ rexCall,
+ condition,
+ config,
+ joinType,
+ expectedType)
+
+ val genFunction = generator.generateFunction(
+ ruleDescription,
+ classOf[FlatMapFunction[Any, Any]],
+ body,
+ returnType)
+
+ val mapFunc = correlateMapFunction(genFunction)
+
+ inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, relRowType))
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
index 9e20df4..6847425 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
@@ -108,6 +108,7 @@ object FlinkRuleSets {
DataSetMinusRule.INSTANCE,
DataSetSortRule.INSTANCE,
DataSetValuesRule.INSTANCE,
+ DataSetCorrelateRule.INSTANCE,
BatchTableSourceScanRule.INSTANCE
)
@@ -151,6 +152,7 @@ object FlinkRuleSets {
DataStreamScanRule.INSTANCE,
DataStreamUnionRule.INSTANCE,
DataStreamValuesRule.INSTANCE,
+ DataStreamCorrelateRule.INSTANCE,
StreamTableSourceScanRule.INSTANCE
)
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala
new file mode 100644
index 0000000..e6cf0cf
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.rules.dataSet
+
+import org.apache.calcite.plan.volcano.RelSubset
+import org.apache.calcite.plan.{Convention, RelOptRule, RelOptRuleCall, RelTraitSet}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.calcite.rel.logical.{LogicalFilter, LogicalCorrelate, LogicalTableFunctionScan}
+import org.apache.calcite.rex.RexNode
+import org.apache.flink.api.table.plan.nodes.dataset.{DataSetConvention, DataSetCorrelate}
+
+/**
+ * Rule to convert a LogicalCorrelate into a DataSetCorrelate.
+ */
+class DataSetCorrelateRule
+ extends ConverterRule(
+ classOf[LogicalCorrelate],
+ Convention.NONE,
+ DataSetConvention.INSTANCE,
+ "DataSetCorrelateRule")
+ {
+
+ override def matches(call: RelOptRuleCall): Boolean = {
+ val join: LogicalCorrelate = call.rel(0).asInstanceOf[LogicalCorrelate]
+ val right = join.getRight.asInstanceOf[RelSubset].getOriginal
+
+ right match {
+ // right node is a table function
+ case scan: LogicalTableFunctionScan => true
+ // a filter is pushed above the table function
+ case filter: LogicalFilter =>
+ filter.getInput.asInstanceOf[RelSubset].getOriginal
+ .isInstanceOf[LogicalTableFunctionScan]
+ case _ => false
+ }
+ }
+
+ override def convert(rel: RelNode): RelNode = {
+ val join: LogicalCorrelate = rel.asInstanceOf[LogicalCorrelate]
+ val traitSet: RelTraitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE)
+ val convInput: RelNode = RelOptRule.convert(join.getInput(0), DataSetConvention.INSTANCE)
+ val right: RelNode = join.getInput(1)
+
+ def convertToCorrelate(relNode: RelNode, condition: Option[RexNode]): DataSetCorrelate = {
+ relNode match {
+ case rel: RelSubset =>
+ convertToCorrelate(rel.getRelList.get(0), condition)
+
+ case filter: LogicalFilter =>
+ convertToCorrelate(
+ filter.getInput.asInstanceOf[RelSubset].getOriginal,
+ Some(filter.getCondition))
+
+ case scan: LogicalTableFunctionScan =>
+ new DataSetCorrelate(
+ rel.getCluster,
+ traitSet,
+ convInput,
+ scan,
+ condition,
+ rel.getRowType,
+ join.getRowType,
+ join.getJoinType,
+ description)
+ }
+ }
+ convertToCorrelate(right, None)
+ }
+ }
+
+object DataSetCorrelateRule {
+ val INSTANCE: RelOptRule = new DataSetCorrelateRule
+}
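The recursion in convertToCorrelate can be summarized with a toy ADT; a minimal sketch under assumed names (Node, Subset, Filter, Scan are illustrative stand-ins, not Calcite types):

object UnwrapSketch extends App {
  sealed trait Node
  case class Subset(inner: Node) extends Node                     // plays the role of RelSubset
  case class Filter(condition: String, input: Node) extends Node // LogicalFilter
  case object Scan extends Node                                   // LogicalTableFunctionScan

  def unwrap(node: Node, condition: Option[String]): (Node, Option[String]) = node match {
    case Subset(inner)    => unwrap(inner, condition) // look through the subset
    case Filter(c, input) => unwrap(input, Some(c))   // keep the pushed-down condition
    case Scan             => (Scan, condition)        // reached the table function scan
  }

  println(unwrap(Subset(Filter("len > 2", Subset(Scan))), None)) // (Scan,Some(len > 2))
}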
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamCorrelateRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamCorrelateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamCorrelateRule.scala
new file mode 100644
index 0000000..bb52fd7
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamCorrelateRule.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.rules.datastream
+
+import org.apache.calcite.plan.volcano.RelSubset
+import org.apache.calcite.plan.{Convention, RelOptRule, RelOptRuleCall, RelTraitSet}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.calcite.rel.logical.{LogicalFilter, LogicalCorrelate, LogicalTableFunctionScan}
+import org.apache.calcite.rex.RexNode
+import org.apache.flink.api.table.plan.nodes.datastream.{DataStreamCorrelate, DataStreamConvention}
+
+/**
+ * Rule to convert a LogicalCorrelate into a DataStreamCorrelate.
+ */
+class DataStreamCorrelateRule
+ extends ConverterRule(
+ classOf[LogicalCorrelate],
+ Convention.NONE,
+ DataStreamConvention.INSTANCE,
+ "DataStreamCorrelateRule")
+{
+
+ override def matches(call: RelOptRuleCall): Boolean = {
+ val join: LogicalCorrelate = call.rel(0).asInstanceOf[LogicalCorrelate]
+ val right = join.getRight.asInstanceOf[RelSubset].getOriginal
+
+ right match {
+ // right node is a table function
+ case scan: LogicalTableFunctionScan => true
+ // a filter is pushed above the table function
+ case filter: LogicalFilter =>
+ filter.getInput.asInstanceOf[RelSubset].getOriginal
+ .isInstanceOf[LogicalTableFunctionScan]
+ case _ => false
+ }
+ }
+
+ override def convert(rel: RelNode): RelNode = {
+ val join: LogicalCorrelate = rel.asInstanceOf[LogicalCorrelate]
+ val traitSet: RelTraitSet = rel.getTraitSet.replace(DataStreamConvention.INSTANCE)
+ val convInput: RelNode = RelOptRule.convert(join.getInput(0), DataStreamConvention.INSTANCE)
+ val right: RelNode = join.getInput(1)
+
+ def convertToCorrelate(relNode: RelNode, condition: Option[RexNode]): DataStreamCorrelate = {
+ relNode match {
+ case rel: RelSubset =>
+ convertToCorrelate(rel.getRelList.get(0), condition)
+
+ case filter: LogicalFilter =>
+ convertToCorrelate(filter.getInput.asInstanceOf[RelSubset].getOriginal,
+ Some(filter.getCondition))
+
+ case scan: LogicalTableFunctionScan =>
+ new DataStreamCorrelate(
+ rel.getCluster,
+ traitSet,
+ convInput,
+ scan,
+ condition,
+ rel.getRowType,
+ join.getRowType,
+ join.getJoinType,
+ description)
+ }
+ }
+ convertToCorrelate(right, None)
+ }
+
+}
+
+object DataStreamCorrelateRule {
+ val INSTANCE: RelOptRule = new DataStreamCorrelateRule
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/FlinkTableFunctionImpl.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/FlinkTableFunctionImpl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/FlinkTableFunctionImpl.scala
new file mode 100644
index 0000000..540a5c8
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/FlinkTableFunctionImpl.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.schema
+
+import java.lang.reflect.{Method, Type}
+import java.util
+
+import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory}
+import org.apache.calcite.schema.TableFunction
+import org.apache.calcite.schema.impl.ReflectiveFunctionBase
+import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
+import org.apache.flink.api.common.typeutils.CompositeType
+import org.apache.flink.api.table.{FlinkTypeFactory, TableException}
+
+/**
+ * This is heavily inspired by Calcite's [[org.apache.calcite.schema.impl.TableFunctionImpl]].
+ * We need it in order to create a [[org.apache.flink.api.table.functions.utils.TableSqlFunction]].
+ * The main difference is that we override [[getRowType()]] and [[getElementType()]].
+ */
+class FlinkTableFunctionImpl[T](
+ val typeInfo: TypeInformation[T],
+ val fieldIndexes: Array[Int],
+ val fieldNames: Array[String],
+ val evalMethod: Method)
+ extends ReflectiveFunctionBase(evalMethod)
+ with TableFunction {
+
+ if (fieldIndexes.length != fieldNames.length) {
+ throw new TableException(
+ "Number of field indexes and field names must be equal.")
+ }
+
+ // check uniqueness of field names
+ if (fieldNames.length != fieldNames.toSet.size) {
+ throw new TableException(
+ "Table field names must be unique.")
+ }
+
+ val fieldTypes: Array[TypeInformation[_]] =
+ typeInfo match {
+ case cType: CompositeType[T] =>
+ if (fieldNames.length != cType.getArity) {
+ throw new TableException(
+ s"Arity of type (" + cType.getFieldNames.deep + ") " +
+ "not equal to number of field names " + fieldNames.deep + ".")
+ }
+ fieldIndexes.map(cType.getTypeAt(_).asInstanceOf[TypeInformation[_]])
+ case aType: AtomicType[T] =>
+ if (fieldIndexes.length != 1 || fieldIndexes(0) != 0) {
+ throw new TableException(
+ "Non-composite input type may have only a single field and its index must be 0.")
+ }
+ Array(aType)
+ }
+
+ override def getElementType(arguments: util.List[AnyRef]): Type = classOf[Array[Object]]
+
+ override def getRowType(
+ typeFactory: RelDataTypeFactory,
+ arguments: util.List[AnyRef]): RelDataType = {
+ val flinkTypeFactory = typeFactory.asInstanceOf[FlinkTypeFactory]
+ val builder = flinkTypeFactory.builder
+ fieldNames
+ .zip(fieldTypes)
+ .foreach { f =>
+ builder.add(f._1, flinkTypeFactory.createTypeFromTypeInfo(f._2)).nullable(true)
+ }
+ builder.build
+ }
+}
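To illustrate the fieldTypes derivation above, a minimal self-contained sketch (FieldTypesSketch and the concrete tuple type are assumptions for the example):

import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
import org.apache.flink.api.java.typeutils.TupleTypeInfo

object FieldTypesSketch extends App {
  // composite result type (String, Integer): fields are addressed by index
  val composite = new TupleTypeInfo[JTuple2[String, Integer]](
    BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO)
  val fieldTypes: Array[TypeInformation[_]] =
    Array(0, 1).map(i => composite.getTypeAt(i).asInstanceOf[TypeInformation[_]])

  // atomic result type: exactly one implicit field, which must sit at index 0
  val atomic: TypeInformation[_] = BasicTypeInfo.STRING_TYPE_INFO

  println(fieldTypes.mkString(", ") + " | " + atomic)
}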
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
index c45e871..a75f2fc 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
@@ -20,7 +20,8 @@ package org.apache.flink.api.table
import org.apache.calcite.rel.RelNode
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.operators.join.JoinType
-import org.apache.flink.api.table.expressions.{Asc, Expression, ExpressionParser, Ordering}
+import org.apache.flink.api.table.plan.logical.Minus
+import org.apache.flink.api.table.expressions.{Alias, Asc, Call, Expression, ExpressionParser, Ordering, TableFunctionCall}
import org.apache.flink.api.table.plan.ProjectionTranslator._
import org.apache.flink.api.table.plan.logical._
import org.apache.flink.api.table.sinks.TableSink
@@ -400,7 +401,8 @@ class Table(
}
new Table(
tableEnv,
- Join(this.logicalPlan, right.logicalPlan, joinType, joinPredicate).validate(tableEnv))
+ Join(this.logicalPlan, right.logicalPlan, joinType, joinPredicate, correlated = false)
+ .validate(tableEnv))
}
/**
@@ -609,6 +611,126 @@ class Table(
}
/**
+ * Cross Apply returns only those rows from the outer table (the table on the left of the
+ * Apply operator) for which the table-valued function (on the right side of the
+ * operator) produces at least one row.
+ *
+ * Cross Apply is equivalent to an Inner Join, but it works with a table-valued function.
+ *
+ * Example:
+ *
+ * {{{
+ * class MySplitUDTF extends TableFunction[String] {
+ * def eval(str: String): Unit = {
+ * str.split("#").foreach(collect)
+ * }
+ * }
+ *
+ * val split = new MySplitUDTF()
+ * table.crossApply(split('c) as ('s)).select('a,'b,'c,'s)
+ * }}}
+ */
+ def crossApply(udtf: Expression): Table = {
+ applyInternal(udtf, JoinType.INNER)
+ }
+
+ /**
+ * Cross Apply returns only those rows from the outer table (the table on the left of the
+ * Apply operator) for which the table-valued function (on the right side of the
+ * operator) produces at least one row.
+ *
+ * Cross Apply is equivalent to an Inner Join, but it works with a table-valued function.
+ *
+ * Example:
+ *
+ * {{{
+ * class MySplitUDTF extends TableFunction[String] {
+ * def eval(str: String): Unit = {
+ * str.split("#").foreach(collect)
+ * }
+ * }
+ *
+ * val split = new MySplitUDTF()
+ * table.crossApply("split(c) as (s)").select("a, b, c, s")
+ * }}}
+ */
+ def crossApply(udtf: String): Table = {
+ applyInternal(udtf, JoinType.INNER)
+ }
+
+ /**
+ * Outer Apply returns all rows from the outer table (the table on the left of the Apply
+ * operator); where the table-valued function (on the right side of the operator)
+ * produces no matching rows, its fields are filled with NULL values.
+ *
+ * Outer Apply is equivalent to a Left Outer Join, but it works with a table-valued function.
+ *
+ * Example:
+ *
+ * {{{
+ * class MySplitUDTF extends TableFunction[String] {
+ * def eval(str: String): Unit = {
+ * str.split("#").foreach(collect)
+ * }
+ * }
+ *
+ * val split = new MySplitUDTF()
+ * table.outerApply(split('c) as ('s)).select('a,'b,'c,'s)
+ * }}}
+ */
+ def outerApply(udtf: Expression): Table = {
+ applyInternal(udtf, JoinType.LEFT_OUTER)
+ }
+
+ /**
+ * Outer Apply returns all rows from the outer table (the table on the left of the Apply
+ * operator); where the table-valued function (on the right side of the operator)
+ * produces no matching rows, its fields are filled with NULL values.
+ *
+ * Outer Apply is equivalent to a Left Outer Join, but it works with a table-valued function.
+ *
+ * Example:
+ *
+ * {{{
+ * val split = new MySplitUDTF()
+ * table.outerApply("split(c) as (s)").select("a, b, c, s")
+ * }}}
+ */
+ def outerApply(udtf: String): Table = {
+ applyInternal(udtf, JoinType.LEFT_OUTER)
+ }
+
+ private def applyInternal(udtfString: String, joinType: JoinType): Table = {
+ val udtf = ExpressionParser.parseExpression(udtfString)
+ applyInternal(udtf, joinType)
+ }
+
+ private def applyInternal(udtf: Expression, joinType: JoinType): Table = {
+ var alias: Option[Seq[String]] = None
+
+ // unwrap the Expression until a TableFunctionCall is reached
+ def unwrap(expr: Expression): TableFunctionCall = expr match {
+ case Alias(child, name, extraNames) =>
+ alias = Some(Seq(name) ++ extraNames)
+ unwrap(child)
+ case Call(name, args) =>
+ val function = tableEnv.getFunctionCatalog.lookupFunction(name, args)
+ unwrap(function)
+ case c: TableFunctionCall => c
+ case _ => throw new TableException("Cross/Outer Apply only accepts a TableFunction.")
+ }
+
+ val call = unwrap(udtf)
+ .as(alias)
+ .toLogicalTableFunctionCall(this.logicalPlan)
+ .validate(tableEnv)
+
+ new Table(
+ tableEnv,
+ Join(this.logicalPlan, call, joinType, None, correlated = true).validate(tableEnv))
+ }
+
+ /**
* Writes the [[Table]] to a [[TableSink]]. A [[TableSink]] defines an external storage location.
*
* A batch [[Table]] can only be written to a
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
index 679733c..4029a7d 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
@@ -23,8 +23,8 @@ import org.apache.calcite.sql.util.{ChainedSqlOperatorTable, ListSqlOperatorTabl
import org.apache.calcite.sql.{SqlFunction, SqlOperator, SqlOperatorTable}
import org.apache.flink.api.table.ValidationException
import org.apache.flink.api.table.expressions._
-import org.apache.flink.api.table.functions.ScalarFunction
-import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils
+import org.apache.flink.api.table.functions.{ScalarFunction, TableFunction}
+import org.apache.flink.api.table.functions.utils.{TableSqlFunction, UserDefinedFunctionUtils}
import scala.collection.JavaConversions._
import scala.collection.mutable
@@ -47,6 +47,20 @@ class FunctionCatalog {
sqlFunctions += sqlFunction
}
+ /** Registers multiple SQL functions at once. All functions must share the same name. */
+ def registerSqlFunctions(functions: Seq[SqlFunction]): Unit = {
+ if (functions.nonEmpty) {
+ val name = functions.head.getName
+ // check that all functions share the same name
+ if (functions.forall(_.getName == name)) {
+ sqlFunctions --= sqlFunctions.filter(_.getName == name)
+ sqlFunctions ++= functions
+ } else {
+ throw ValidationException("The sql functions request to register have different name.")
+ }
+ }
+ }
+
def getSqlOperatorTable: SqlOperatorTable =
ChainedSqlOperatorTable.of(
new BasicOperatorTable(),
@@ -59,14 +73,9 @@ class FunctionCatalog {
def lookupFunction(name: String, children: Seq[Expression]): Expression = {
val funcClass = functionBuilders
.getOrElse(name.toLowerCase, throw ValidationException(s"Undefined function: $name"))
- withChildren(funcClass, children)
- }
- /**
- * Instantiate a function using the provided `children`.
- */
- private def withChildren(func: Class[_], children: Seq[Expression]): Expression = {
- func match {
+ // Instantiate a function using the provided `children`
+ funcClass match {
// user-defined scalar function call
case sf if classOf[ScalarFunction].isAssignableFrom(sf) =>
@@ -75,10 +84,20 @@ class FunctionCatalog {
case Failure(e) => throw ValidationException(e.getMessage)
}
+ // user-defined table function call
+ case tf if classOf[TableFunction[_]].isAssignableFrom(tf) =>
+ val tableSqlFunction = sqlFunctions
+ .find(f => f.getName.equalsIgnoreCase(name) && f.isInstanceOf[TableSqlFunction])
+ .getOrElse(throw ValidationException(s"Unregistered table sql function: $name"))
+ .asInstanceOf[TableSqlFunction]
+ val typeInfo = tableSqlFunction.getRowTypeInfo
+ val function = tableSqlFunction.getTableFunction
+ TableFunctionCall(name, function, children, typeInfo)
+
// general expression call
case expression if classOf[Expression].isAssignableFrom(expression) =>
// try to find a constructor accepts `Seq[Expression]`
- Try(func.getDeclaredConstructor(classOf[Seq[_]])) match {
+ Try(funcClass.getDeclaredConstructor(classOf[Seq[_]])) match {
case Success(seqCtor) =>
Try(seqCtor.newInstance(children).asInstanceOf[Expression]) match {
case Success(expr) => expr
@@ -87,14 +106,14 @@ class FunctionCatalog {
case Failure(e) =>
val childrenClass = Seq.fill(children.length)(classOf[Expression])
// try to find a constructor matching the exact number of children
- Try(func.getDeclaredConstructor(childrenClass: _*)) match {
+ Try(funcClass.getDeclaredConstructor(childrenClass: _*)) match {
case Success(ctor) =>
Try(ctor.newInstance(children: _*).asInstanceOf[Expression]) match {
case Success(expr) => expr
case Failure(exception) => throw ValidationException(exception.getMessage)
}
case Failure(exception) =>
- throw ValidationException(s"Invalid number of arguments for function $func")
+ throw ValidationException(s"Invalid number of arguments for function $funcClass")
}
}
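The constructor-reflection fallback for general expression calls can be distilled into a small sketch; Concat and instantiate are hypothetical names mirroring the Try-based lookup above:

import scala.util.{Failure, Success, Try}

object InstantiateSketch extends App {
  case class Concat(children: Seq[String]) // stand-in for an Expression class

  def instantiate(funcClass: Class[_], children: Seq[String]): AnyRef =
    Try(funcClass.getDeclaredConstructor(classOf[Seq[_]])) match {
      case Success(seqCtor) =>
        // constructor accepting the children as one Seq argument
        seqCtor.newInstance(children).asInstanceOf[AnyRef]
      case Failure(_) =>
        // fall back to a constructor with one parameter per child
        val argClasses = Seq.fill(children.length)(classOf[String])
        funcClass.getDeclaredConstructor(argClasses: _*)
          .newInstance(children: _*).asInstanceOf[AnyRef]
    }

  println(instantiate(classOf[Concat], Seq("a", "b"))) // Concat(List(a, b))
}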
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionITCase.scala
new file mode 100644
index 0000000..7e0d0ff
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionITCase.scala
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.scala.batch
+
+import org.apache.flink.api.scala.batch.utils.TableProgramsTestBase
+import org.apache.flink.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.table.expressions.utils._
+import org.apache.flink.api.table.{Row, Table, TableEnvironment}
+import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
+import org.apache.flink.test.util.TestBaseUtils
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+@RunWith(classOf[Parameterized])
+class UserDefinedTableFunctionITCase(
+ mode: TestExecutionMode,
+ configMode: TableConfigMode)
+ extends TableProgramsTestBase(mode, configMode) {
+
+ @Test
+ def testSQLCrossApply(): Unit = {
+ val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+ val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+ tableEnv.registerTable("MyTable", in)
+ tableEnv.registerFunction("split", new TableFunc1)
+
+ val sqlQuery = "SELECT MyTable.c, t.s FROM MyTable, LATERAL TABLE(split(c)) AS t(s)"
+
+ val result = tableEnv.sql(sqlQuery).toDataSet[Row]
+ val results = result.collect()
+ val expected: String = "Jack#22,Jack\n" + "Jack#22,22\n" + "John#19,John\n" + "John#19,19\n" +
+ "Anna#44,Anna\n" + "Anna#44,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testSQLOuterApply(): Unit = {
+ val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+ val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+ tableEnv.registerTable("MyTable", in)
+ tableEnv.registerFunction("split", new TableFunc2)
+
+ val sqlQuery = "SELECT MyTable.c, t.a, t.b FROM MyTable LEFT JOIN LATERAL TABLE(split(c)) " +
+ "AS t(a,b) ON TRUE"
+
+ val result = tableEnv.sql(sqlQuery).toDataSet[Row]
+ val results = result.collect()
+ val expected: String = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" +
+ "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n" + "nosharp,null,null"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testTableAPICrossApply(): Unit = {
+ val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+ val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+
+ val func1 = new TableFunc1
+ val result = in.crossApply(func1('c) as ('s)).select('c, 's).toDataSet[Row]
+ val results = result.collect()
+ val expected: String = "Jack#22,Jack\n" + "Jack#22,22\n" + "John#19,John\n" + "John#19,19\n" +
+ "Anna#44,Anna\n" + "Anna#44,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+
+ // with overloading
+ val result2 = in.crossApply(func1('c, "$") as ('s)).select('c, 's).toDataSet[Row]
+ val results2 = result2.collect()
+ val expected2: String = "Jack#22,$Jack\n" + "Jack#22,$22\n" + "John#19,$John\n" +
+ "John#19,$19\n" + "Anna#44,$Anna\n" + "Anna#44,$44\n"
+ TestBaseUtils.compareResultAsText(results2.asJava, expected2)
+ }
+
+ @Test
+ def testTableAPIOuterApply(): Unit = {
+ val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+ val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+ val func2 = new TableFunc2
+ val result = in.outerApply(func2('c) as ('s, 'l)).select('c, 's, 'l).toDataSet[Row]
+ val results = result.collect()
+ val expected: String = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" +
+ "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n" + "nosharp,null,null"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testCustomReturnType(): Unit = {
+ val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+ val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+ val func2 = new TableFunc2
+
+ val result = in
+ .crossApply(func2('c) as ('name, 'len))
+ .select('c, 'name, 'len)
+ .toDataSet[Row]
+
+ val results = result.collect()
+ val expected: String = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" +
+ "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testHierarchyType(): Unit = {
+ val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+ val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+
+ val hierarchy = new HierarchyTableFunction
+ val result = in
+ .crossApply(hierarchy('c) as ('name, 'adult, 'len))
+ .select('c, 'name, 'adult, 'len)
+ .toDataSet[Row]
+
+ val results = result.collect()
+ val expected: String = "Jack#22,Jack,true,22\n" + "John#19,John,false,19\n" +
+ "Anna#44,Anna,true,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testPojoType(): Unit = {
+ val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+ val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+
+ val pojo = new PojoTableFunc()
+ val result = in
+ .crossApply(pojo('c))
+ .select('c, 'name, 'age)
+ .toDataSet[Row]
+
+ val results = result.collect()
+ val expected: String = "Jack#22,Jack,22\n" + "John#19,John,19\n" + "Anna#44,Anna,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testTableAPIWithFilter(): Unit = {
+ val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+ val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+ val func0 = new TableFunc0
+
+ val result = in
+ .crossApply(func0('c) as ('name, 'age))
+ .select('c, 'name, 'age)
+ .filter('age > 20)
+ .toDataSet[Row]
+
+ val results = result.collect()
+ val expected: String = "Jack#22,Jack,22\n" + "Anna#44,Anna,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testUDTFWithScalarFunction(): Unit = {
+ val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
+ val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
+ val func1 = new TableFunc1
+
+ val result = in
+ .crossApply(func1('c.substring(2)) as 's)
+ .select('c, 's)
+ .toDataSet[Row]
+
+ val results = result.collect()
+ val expected: String = "Jack#22,ack\n" + "Jack#22,22\n" + "John#19,ohn\n" + "John#19,19\n" +
+ "Anna#44,nna\n" + "Anna#44,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ private def getSmall3TupleDataSet(env: ExecutionEnvironment): DataSet[(Int, Long, String)] = {
+ val data = new mutable.MutableList[(Int, Long, String)]
+ data.+=((1, 1L, "Jack#22"))
+ data.+=((2, 2L, "John#19"))
+ data.+=((3, 2L, "Anna#44"))
+ data.+=((4, 3L, "nosharp"))
+ env.fromCollection(data)
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionTest.scala
new file mode 100644
index 0000000..7e236d1
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionTest.scala
@@ -0,0 +1,320 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.scala.batch
+
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment => ScalaExecutionEnv, _}
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.{DataSet => JDataSet, ExecutionEnvironment => JavaExecutionEnv}
+import org.apache.flink.api.table.expressions.utils.{HierarchyTableFunction, PojoTableFunc, TableFunc1, TableFunc2}
+import org.apache.flink.api.table.typeutils.RowTypeInfo
+import org.apache.flink.api.table.utils.TableTestBase
+import org.apache.flink.api.table.utils.TableTestUtil._
+import org.apache.flink.api.table.{Row, TableEnvironment, Types}
+import org.junit.Test
+import org.mockito.Mockito._
+
+
+class UserDefinedTableFunctionTest extends TableTestBase {
+
+ @Test
+ def testTableAPI(): Unit = {
+ // mock
+ val ds = mock(classOf[DataSet[Row]])
+ val jDs = mock(classOf[JDataSet[Row]])
+ val typeInfo: TypeInformation[Row] = new RowTypeInfo(Seq(Types.INT, Types.LONG, Types.STRING))
+ when(ds.javaSet).thenReturn(jDs)
+ when(jDs.getType).thenReturn(typeInfo)
+
+ // Scala environment
+ val env = mock(classOf[ScalaExecutionEnv])
+ val tableEnv = TableEnvironment.getTableEnvironment(env)
+ val in1 = ds.toTable(tableEnv).as('a, 'b, 'c)
+
+ // Java environment
+ val javaEnv = mock(classOf[JavaExecutionEnv])
+ val javaTableEnv = TableEnvironment.getTableEnvironment(javaEnv)
+ val in2 = javaTableEnv.fromDataSet(jDs).as("a, b, c")
+ javaTableEnv.registerTable("MyTable", in2)
+
+ // test cross apply
+ val func1 = new TableFunc1
+ javaTableEnv.registerFunction("func1", func1)
+ var scalaTable = in1.crossApply(func1('c) as ('s)).select('c, 's)
+ var javaTable = in2.crossApply("func1(c) as (s)").select("c, s")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test outer apply
+ scalaTable = in1.outerApply(func1('c) as ('s)).select('c, 's)
+ javaTable = in2.outerApply("func1(c) as (s)").select("c, s")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test overloading
+ scalaTable = in1.crossApply(func1('c, "$") as ('s)).select('c, 's)
+ javaTable = in2.crossApply("func1(c, '$') as (s)").select("c, s")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test custom result type
+ val func2 = new TableFunc2
+ javaTableEnv.registerFunction("func2", func2)
+ scalaTable = in1.crossApply(func2('c) as ('name, 'len)).select('c, 'name, 'len)
+ javaTable = in2.crossApply("func2(c) as (name, len)").select("c, name, len")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test hierarchy generic type
+ val hierarchy = new HierarchyTableFunction
+ javaTableEnv.registerFunction("hierarchy", hierarchy)
+ scalaTable = in1.crossApply(hierarchy('c) as ('name, 'adult, 'len))
+ .select('c, 'name, 'len, 'adult)
+ javaTable = in2.crossApply("hierarchy(c) as (name, adult, len)")
+ .select("c, name, len, adult")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test pojo type
+ val pojo = new PojoTableFunc
+ javaTableEnv.registerFunction("pojo", pojo)
+ scalaTable = in1.crossApply(pojo('c))
+ .select('c, 'name, 'age)
+ javaTable = in2.crossApply("pojo(c)")
+ .select("c, name, age")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test with filter
+ scalaTable = in1.crossApply(func2('c) as ('name, 'len))
+ .select('c, 'name, 'len).filter('len > 2)
+ javaTable = in2.crossApply("func2(c) as (name, len)")
+ .select("c, name, len").filter("len > 2")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test with scalar function
+ scalaTable = in1.crossApply(func1('c.substring(2)) as ('s))
+ .select('a, 'c, 's)
+ javaTable = in2.crossApply("func1(substring(c, 2)) as (s)")
+ .select("a, c, s")
+ verifyTableEquals(scalaTable, javaTable)
+ }
+
+ @Test
+ def testSQLWithCrossApply(): Unit = {
+ val util = batchTestUtil()
+ val func1 = new TableFunc1
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func1", func1)
+
+ val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c)) AS T(s)"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "func1($cor0.c)"),
+ term("function", func1.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS s")
+ )
+
+ util.verifySql(sqlQuery, expected)
+
+ // test overloading
+
+ val sqlQuery2 = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c, '$')) AS T(s)"
+
+ val expected2 = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "func1($cor0.c, '$')"),
+ term("function", func1.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS s")
+ )
+
+ util.verifySql(sqlQuery2, expected2)
+ }
+
+ @Test
+ def testSQLWithOuterApply(): Unit = {
+ val util = batchTestUtil()
+ val func1 = new TableFunc1
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func1", func1)
+
+ val sqlQuery = "SELECT c, s FROM MyTable LEFT JOIN LATERAL TABLE(func1(c)) AS T(s) ON TRUE"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "func1($cor0.c)"),
+ term("function", func1.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+ term("joinType", "LEFT")
+ ),
+ term("select", "c", "f0 AS s")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testSQLWithCustomType(): Unit = {
+ val util = batchTestUtil()
+ val func2 = new TableFunc2
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func2", func2)
+
+ val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len)"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "func2($cor0.c)"),
+ term("function", func2.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
+ "VARCHAR(2147483647) f0, INTEGER f1)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS name", "f1 AS len")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testSQLWithHierarchyType(): Unit = {
+ val util = batchTestUtil()
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ val function = new HierarchyTableFunction
+ util.addFunction("hierarchy", function)
+
+ val sqlQuery = "SELECT c, T.* FROM MyTable, LATERAL TABLE(hierarchy(c)) AS T(name, adult, len)"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "hierarchy($cor0.c)"),
+ term("function", function.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," +
+ " VARCHAR(2147483647) f0, BOOLEAN f1, INTEGER f2)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS name", "f1 AS adult", "f2 AS len")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testSQLWithPojoType(): Unit = {
+ val util = batchTestUtil()
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ val function = new PojoTableFunc
+ util.addFunction("pojo", function)
+
+ val sqlQuery = "SELECT c, name, age FROM MyTable, LATERAL TABLE(pojo(c))"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "pojo($cor0.c)"),
+ term("function", function.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," +
+ " INTEGER age, VARCHAR(2147483647) name)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "name", "age")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testSQLWithFilter(): Unit = {
+ val util = batchTestUtil()
+ val func2 = new TableFunc2
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func2", func2)
+
+ val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len) " +
+ "WHERE len > 2"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "func2($cor0.c)"),
+ term("function", func2.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
+ "VARCHAR(2147483647) f0, INTEGER f1)"),
+ term("joinType", "INNER"),
+ term("condition", ">($1, 2)")
+ ),
+ term("select", "c", "f0 AS name", "f1 AS len")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+
+ @Test
+ def testSQLWithScalarFunction(): Unit = {
+ val util = batchTestUtil()
+ val func1 = new TableFunc1
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func1", func1)
+
+ val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(SUBSTRING(c, 2))) AS T(s)"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "func1(SUBSTRING($cor0.c, 2))"),
+ term("function", func1.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS s")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+}
[3/5] flink git commit: [FLINK-4469] [table] Add support for user
defined table function in Table API & SQL
Posted by tw...@apache.org.
[FLINK-4469] [table] Add support for user defined table function in Table API & SQL
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/e139f59c
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/e139f59c
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/e139f59c
Branch: refs/heads/master
Commit: e139f59ce97875338c5ee74bb0432389b3f343bf
Parents: c024b0b
Author: Jark Wu <wu...@alibaba-inc.com>
Authored: Tue Oct 18 11:15:07 2016 +0800
Committer: twalthr <tw...@apache.org>
Committed: Wed Dec 7 16:27:42 2016 +0100
----------------------------------------------------------------------
.../api/java/table/BatchTableEnvironment.scala | 15 +
.../api/java/table/StreamTableEnvironment.scala | 15 +
.../api/scala/table/BatchTableEnvironment.scala | 12 +
.../scala/table/StreamTableEnvironment.scala | 11 +
.../scala/table/TableFunctionCallBuilder.scala | 39 ++
.../flink/api/scala/table/expressionDsl.scala | 5 +-
.../flink/api/table/FlinkTypeFactory.scala | 14 +-
.../flink/api/table/TableEnvironment.scala | 50 ++-
.../flink/api/table/codegen/CodeGenerator.scala | 95 +++--
.../table/codegen/calls/FunctionGenerator.scala | 369 +++++++++++++++++
.../table/codegen/calls/ScalarFunctions.scala | 359 -----------------
.../codegen/calls/TableFunctionCallGen.scala | 82 ++++
.../table/expressions/ExpressionParser.scala | 4 +-
.../flink/api/table/expressions/call.scala | 83 +++-
.../api/table/expressions/fieldExpression.scala | 6 +-
.../api/table/functions/ScalarFunction.scala | 44 +-
.../api/table/functions/TableFunction.scala | 121 ++++++
.../table/functions/UserDefinedFunction.scala | 36 +-
.../functions/utils/ScalarSqlFunction.scala | 6 +-
.../functions/utils/TableSqlFunction.scala | 119 ++++++
.../utils/UserDefinedFunctionUtils.scala | 275 ++++++++++---
.../api/table/plan/ProjectionTranslator.scala | 4 +-
.../api/table/plan/logical/operators.scala | 102 ++++-
.../api/table/plan/nodes/FlinkCorrelate.scala | 162 ++++++++
.../plan/nodes/dataset/DataSetCorrelate.scala | 139 +++++++
.../nodes/datastream/DataStreamCorrelate.scala | 133 ++++++
.../api/table/plan/rules/FlinkRuleSets.scala | 2 +
.../rules/dataSet/DataSetCorrelateRule.scala | 90 +++++
.../datastream/DataStreamCorrelateRule.scala | 89 ++++
.../plan/schema/FlinkTableFunctionImpl.scala | 84 ++++
.../org/apache/flink/api/table/table.scala | 126 +++++-
.../api/table/validate/FunctionCatalog.scala | 43 +-
.../batch/UserDefinedTableFunctionITCase.scala | 212 ++++++++++
.../batch/UserDefinedTableFunctionTest.scala | 320 +++++++++++++++
.../stream/UserDefinedTableFunctionITCase.scala | 181 +++++++++
.../stream/UserDefinedTableFunctionTest.scala | 402 +++++++++++++++++++
.../UserDefinedScalarFunctionTest.scala | 4 +-
.../expressions/utils/ExpressionTestBase.scala | 4 +-
.../utils/UserDefinedTableFunctions.scala | 116 ++++++
.../flink/api/table/utils/TableTestBase.scala | 32 ++
40 files changed, 3414 insertions(+), 591 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala
index a4f40d5..b353377 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala
@@ -21,6 +21,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.TypeExtractor
import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
import org.apache.flink.api.table.expressions.ExpressionParser
+import org.apache.flink.api.table.functions.TableFunction
import org.apache.flink.api.table.{Table, TableConfig}
/**
@@ -162,4 +163,18 @@ class BatchTableEnvironment(
translate[T](table)(typeInfo)
}
+ /**
+ * Registers a [[TableFunction]] under a unique name in the TableEnvironment's catalog.
+ * Registered functions can be referenced in Table API and SQL queries.
+ *
+ * @param name The name under which the function is registered.
+ * @param tf The TableFunction to register
+ */
+ def registerFunction[T](name: String, tf: TableFunction[T]): Unit = {
+ implicit val typeInfo: TypeInformation[T] = TypeExtractor
+ .createTypeInfo(tf, classOf[TableFunction[_]], tf.getClass, 0)
+ .asInstanceOf[TypeInformation[T]]
+
+ registerTableFunctionInternal[T](name, tf)
+ }
}
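For orientation, here is a minimal end-to-end sketch of the new registration path. The Split class is illustrative (modeled on the TableFunc0-style splitters used in this commit's tests); only registerFunction and the LATERAL TABLE syntax are taken from the diff itself.

    import org.apache.flink.api.scala._
    import org.apache.flink.api.scala.table._
    import org.apache.flink.api.table.{Row, TableEnvironment}
    import org.apache.flink.api.table.functions.TableFunction

    // illustrative UDTF: splits "name#age" strings into (name, age) rows
    class Split extends TableFunction[(String, Int)] {
      def eval(str: String): Unit = {
        val parts = str.split("#")
        if (parts.length == 2) {
          collect((parts(0), parts(1).toInt))
        }
      }
    }

    val env = ExecutionEnvironment.getExecutionEnvironment
    val tEnv = TableEnvironment.getTableEnvironment(env)
    // assumes a table "MyTable" with a String column 'c was registered before
    tEnv.registerFunction("split", new Split)
    val result = tEnv
      .sql("SELECT c, n, a FROM MyTable, LATERAL TABLE(split(c)) AS T(n, a)")
      .toDataSet[Row]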
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala
index f8dbc37..367cb82 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala
@@ -19,6 +19,7 @@ package org.apache.flink.api.java.table
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.TypeExtractor
+import org.apache.flink.api.table.functions.TableFunction
import org.apache.flink.api.table.{TableConfig, Table}
import org.apache.flink.api.table.expressions.ExpressionParser
import org.apache.flink.streaming.api.datastream.DataStream
@@ -164,4 +165,18 @@ class StreamTableEnvironment(
translate[T](table)(typeInfo)
}
+ /**
+ * Registers a [[TableFunction]] under a unique name in the TableEnvironment's catalog.
+ * Registered functions can be referenced in Table API and SQL queries.
+ *
+ * @param name The name under which the function is registered.
+ * @param tf The TableFunction to register
+ */
+ def registerFunction[T](name: String, tf: TableFunction[T]): Unit = {
+ implicit val typeInfo: TypeInformation[T] = TypeExtractor
+ .createTypeInfo(tf, classOf[TableFunction[_]], tf.getClass, 0)
+ .asInstanceOf[TypeInformation[T]]
+
+ registerTableFunctionInternal[T](name, tf)
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/BatchTableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/BatchTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/BatchTableEnvironment.scala
index adb444b..36885d2 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/BatchTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/BatchTableEnvironment.scala
@@ -20,6 +20,7 @@ package org.apache.flink.api.scala.table
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.scala._
import org.apache.flink.api.table.expressions.Expression
+import org.apache.flink.api.table.functions.TableFunction
import org.apache.flink.api.table.{TableConfig, Table}
import scala.reflect.ClassTag
@@ -139,4 +140,15 @@ class BatchTableEnvironment(
wrap[T](translate(table))(ClassTag.AnyRef.asInstanceOf[ClassTag[T]])
}
+ /**
+ * Registers a [[TableFunction]] under a unique name in the TableEnvironment's catalog.
+ * Registered functions can be referenced in SQL queries.
+ *
+ * @param name The name under which the function is registered.
+ * @param tf The TableFunction to register
+ */
+ def registerFunction[T: TypeInformation](name: String, tf: TableFunction[T]): Unit = {
+ registerTableFunctionInternal(name, tf)
+ }
+
}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/StreamTableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/StreamTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/StreamTableEnvironment.scala
index e106178..dde69d5 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/StreamTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/StreamTableEnvironment.scala
@@ -18,6 +18,7 @@
package org.apache.flink.api.scala.table
import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.functions.TableFunction
import org.apache.flink.api.table.{TableConfig, Table}
import org.apache.flink.api.table.expressions.Expression
import org.apache.flink.streaming.api.scala.{StreamExecutionEnvironment, DataStream}
@@ -142,4 +143,14 @@ class StreamTableEnvironment(
asScalaStream(translate(table))
}
+ /**
+ * Registers a [[TableFunction]] under a unique name in the TableEnvironment's catalog.
+ * Registered functions can be referenced in SQL queries.
+ *
+ * @param name The name under which the function is registered.
+ * @param tf The TableFunction to register
+ */
+ def registerFunction[T: TypeInformation](name: String, tf: TableFunction[T]): Unit = {
+ registerTableFunctionInternal(name, tf)
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableFunctionCallBuilder.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableFunctionCallBuilder.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableFunctionCallBuilder.scala
new file mode 100644
index 0000000..2261b70
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableFunctionCallBuilder.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.scala.table
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.expressions.{Expression, TableFunctionCall}
+import org.apache.flink.api.table.functions.TableFunction
+
+case class TableFunctionCallBuilder[T: TypeInformation](udtf: TableFunction[T]) {
+ /**
+ * Creates a call to a [[TableFunction]] in Scala Table API.
+ *
+ * @param params actual parameters of the function
+ * @return [[TableFunctionCall]]
+ */
+ def apply(params: Expression*): Expression = {
+ val resultType = if (udtf.getResultType == null) {
+ implicitly[TypeInformation[T]]
+ } else {
+ udtf.getResultType
+ }
+ TableFunctionCall(udtf.getClass.getSimpleName, udtf, params, resultType)
+ }
+}
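Together with the implicit conversion added to expressionDsl.scala below, this builder is what lets a bare TableFunction instance be applied to expressions in the Scala DSL. A sketch of the resulting call style; crossApply is the Table method added in table.scala elsewhere in this commit, so the surrounding calls are an assumption here:

    val split = new Split  // the illustrative UDTF from the earlier sketch

    // split('c) goes through TableFunctionCallBuilder.apply and yields a
    // TableFunctionCall expression; `as` attaches the field aliases
    val result = table
      .crossApply(split('c) as ('name, 'age))
      .select('c, 'name, 'age)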
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
index fee43d8..cc4c68d 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
@@ -24,6 +24,7 @@ import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation}
import org.apache.flink.api.table.expressions.ExpressionUtils.{toMilliInterval, toMonthInterval, toRowInterval}
import org.apache.flink.api.table.expressions.TimeIntervalUnit.TimeIntervalUnit
import org.apache.flink.api.table.expressions._
+import org.apache.flink.api.table.functions.TableFunction
import scala.language.implicitConversions
@@ -97,7 +98,7 @@ trait ImplicitExpressionOperations {
def cast(toType: TypeInformation[_]) = Cast(expr, toType)
- def as(name: Symbol) = Alias(expr, name.name)
+ def as(name: Symbol, extraNames: Symbol*) = Alias(expr, name.name, extraNames.map(_.name))
def asc = Asc(expr)
def desc = Desc(expr)
@@ -539,6 +540,8 @@ trait ImplicitExpressionConversions {
implicit def sqlDate2Literal(sqlDate: Date): Expression = Literal(sqlDate)
implicit def sqlTime2Literal(sqlTime: Time): Expression = Literal(sqlTime)
implicit def sqlTimestamp2Literal(sqlTimestamp: Timestamp): Expression = Literal(sqlTimestamp)
+ implicit def UDTF2TableFunctionCall[T: TypeInformation](udtf: TableFunction[T]):
+ TableFunctionCallBuilder[T] = TableFunctionCallBuilder(udtf)
}
// ------------------------------------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala
index 12dace4..bb11576 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala
@@ -26,7 +26,7 @@ import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.calcite.sql.`type`.SqlTypeName._
import org.apache.calcite.sql.parser.SqlParserPos
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
-import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation}
+import org.apache.flink.api.common.typeinfo.{NothingTypeInfo, SqlTimeTypeInfo, TypeInformation}
import org.apache.flink.api.common.typeutils.CompositeType
import org.apache.flink.api.java.typeutils.ValueTypeInfo._
import org.apache.flink.api.table.FlinkTypeFactory.typeInfoToSqlTypeName
@@ -115,9 +115,9 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp
}
override def createTypeWithNullability(
- relDataType: RelDataType,
- nullable: Boolean)
- : RelDataType = relDataType match {
+ relDataType: RelDataType,
+ nullable: Boolean)
+ : RelDataType = relDataType match {
case composite: CompositeRelDataType =>
// at the moment we do not care about nullability
composite
@@ -172,8 +172,7 @@ object FlinkTypeFactory {
case typeName if DAY_INTERVAL_TYPES.contains(typeName) => TimeIntervalTypeInfo.INTERVAL_MILLIS
case NULL =>
- throw TableException("Type NULL is not supported. " +
- "Null values must have a supported type.")
+ throw TableException("Type NULL is not supported. Null values must have a supported type.")
// symbol for special flags e.g. TRIM's BOTH, LEADING, TRAILING
// are represented as integer
@@ -188,6 +187,9 @@ object FlinkTypeFactory {
val compositeRelDataType = relDataType.asInstanceOf[CompositeRelDataType]
compositeRelDataType.compositeType
+ // ROW and CURSOR are only used in the UDTF case; their type info is never used, it is just a placeholder
+ case ROW | CURSOR => new NothingTypeInfo
+
case _@t =>
throw TableException(s"Type is not supported: $t")
}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala
index 7b2b738..8cabadb 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala
@@ -40,7 +40,8 @@ import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
import org.apache.flink.api.scala.{ExecutionEnvironment => ScalaBatchExecEnv}
import org.apache.flink.api.table.codegen.ExpressionReducer
import org.apache.flink.api.table.expressions.{Alias, Expression, UnresolvedFieldReference}
-import org.apache.flink.api.table.functions.{ScalarFunction, UserDefinedFunction}
+import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils.{checkForInstantiation, checkNotSingleton, createTableSqlFunctions, createScalarSqlFunction}
+import org.apache.flink.api.table.functions.{TableFunction, ScalarFunction}
import org.apache.flink.api.table.plan.cost.DataSetCostFactory
import org.apache.flink.api.table.plan.schema.RelTable
import org.apache.flink.api.table.sinks.TableSink
@@ -153,21 +154,42 @@ abstract class TableEnvironment(val config: TableConfig) {
protected def getBuiltInRuleSet: RuleSet
/**
- * Registers a [[UserDefinedFunction]] under a unique name. Replaces already existing
+ * Registers a [[ScalarFunction]] under a unique name. Replaces already existing
* user-defined functions under this name.
*/
- def registerFunction(name: String, function: UserDefinedFunction): Unit = {
- function match {
- case sf: ScalarFunction =>
- // register in Table API
- functionCatalog.registerFunction(name, function.getClass)
+ def registerFunction(name: String, function: ScalarFunction): Unit = {
+ // check that the function can be instantiated
+ checkForInstantiation(function.getClass)
- // register in SQL API
- functionCatalog.registerSqlFunction(sf.getSqlFunction(name, typeFactory))
+ // register in Table API
+ functionCatalog.registerFunction(name, function.getClass)
- case _ =>
- throw new TableException("Unsupported user-defined function type.")
+ // register in SQL API
+ functionCatalog.registerSqlFunction(createScalarSqlFunction(name, function, typeFactory))
+ }
+
+ /**
+ * Registers a [[TableFunction]] under a unique name. Replaces already existing
+ * user-defined functions under this name.
+ */
+ private[flink] def registerTableFunctionInternal[T: TypeInformation](
+ name: String, function: TableFunction[T]): Unit = {
+ // check that the function is not a Scala object (singleton)
+ checkNotSingleton(function.getClass)
+ // check that the function can be instantiated
+ checkForInstantiation(function.getClass)
+
+ val typeInfo: TypeInformation[_] = if (function.getResultType != null) {
+ function.getResultType
+ } else {
+ implicitly[TypeInformation[T]]
}
+
+ // register in Table API
+ functionCatalog.registerFunction(name, function.getClass)
+ // register in SQL API
+ val sqlFunctions = createTableSqlFunctions(name, function, typeInfo, typeFactory)
+ functionCatalog.registerSqlFunctions(sqlFunctions)
}
/**
@@ -364,7 +386,7 @@ abstract class TableEnvironment(val config: TableConfig) {
case t: TupleTypeInfo[A] =>
exprs.zipWithIndex.map {
case (UnresolvedFieldReference(name), idx) => (idx, name)
- case (Alias(UnresolvedFieldReference(origName), name), _) =>
+ case (Alias(UnresolvedFieldReference(origName), name, _), _) =>
val idx = t.getFieldIndex(origName)
if (idx < 0) {
throw new TableException(s"$origName is not a field of type $t")
@@ -376,7 +398,7 @@ abstract class TableEnvironment(val config: TableConfig) {
case c: CaseClassTypeInfo[A] =>
exprs.zipWithIndex.map {
case (UnresolvedFieldReference(name), idx) => (idx, name)
- case (Alias(UnresolvedFieldReference(origName), name), _) =>
+ case (Alias(UnresolvedFieldReference(origName), name, _), _) =>
val idx = c.getFieldIndex(origName)
if (idx < 0) {
throw new TableException(s"$origName is not a field of type $c")
@@ -393,7 +415,7 @@ abstract class TableEnvironment(val config: TableConfig) {
throw new TableException(s"$name is not a field of type $p")
}
(idx, name)
- case Alias(UnresolvedFieldReference(origName), name) =>
+ case Alias(UnresolvedFieldReference(origName), name, _) =>
val idx = p.getFieldIndex(origName)
if (idx < 0) {
throw new TableException(s"$origName is not a field of type $p")
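Note the precedence in registerTableFunctionInternal: an overridden getResultType wins over the implicitly derived TypeInformation[T]. A small sketch (the class and types are illustrative; only the getResultType hook consulted by the registration code above is taken from the diff):

    import org.apache.flink.api.common.typeinfo.TypeInformation
    import org.apache.flink.api.scala._
    import org.apache.flink.api.table.functions.TableFunction

    class SplitWithExplicitType extends TableFunction[(String, Int)] {
      def eval(str: String): Unit = collect((str, str.length))

      // registration will use this type instead of the extracted one
      override def getResultType: TypeInformation[(String, Int)] =
        createTypeInformation[(String, Int)]
    }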
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
index 2a8ef44..9e4f569 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
@@ -33,7 +33,7 @@ import org.apache.flink.api.java.typeutils.{GenericTypeInfo, PojoTypeInfo, Tuple
import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
import org.apache.flink.api.table.codegen.CodeGenUtils._
import org.apache.flink.api.table.codegen.Indenter.toISC
-import org.apache.flink.api.table.codegen.calls.ScalarFunctions
+import org.apache.flink.api.table.codegen.calls.FunctionGenerator
import org.apache.flink.api.table.codegen.calls.ScalarOperators._
import org.apache.flink.api.table.functions.UserDefinedFunction
import org.apache.flink.api.table.typeutils.{RowTypeInfo, TypeConverter}
@@ -50,16 +50,19 @@ import scala.collection.mutable
* @param nullableInput input(s) can be null.
* @param input1 type information about the first input of the Function
* @param input2 type information about the second input if the Function is binary
- * @param inputPojoFieldMapping additional mapping information if input1 is a POJO (POJO types
- * have no deterministic field order). We assume that input2 is
- * converted before and thus is never a POJO.
+ * @param input1PojoFieldMapping additional mapping information if input1 is a POJO (POJO types
+ * have no deterministic field order).
+ * @param input2PojoFieldMapping additional mapping information if input2 is a POJO (POJO types
+ * have no deterministic field order).
+ *
*/
class CodeGenerator(
config: TableConfig,
nullableInput: Boolean,
input1: TypeInformation[Any],
input2: Option[TypeInformation[Any]] = None,
- inputPojoFieldMapping: Option[Array[Int]] = None)
+ input1PojoFieldMapping: Option[Array[Int]] = None,
+ input2PojoFieldMapping: Option[Array[Int]] = None)
extends RexVisitor[GeneratedExpression] {
// check if nullCheck is enabled when inputs can be null
@@ -67,18 +70,19 @@ class CodeGenerator(
throw new CodeGenException("Null check must be enabled if entire rows can be null.")
}
- // check for POJO input mapping
+ // check for POJO input1 mapping
input1 match {
case pt: PojoTypeInfo[_] =>
- inputPojoFieldMapping.getOrElse(
- throw new CodeGenException("No input mapping is specified for input of type POJO."))
+ input1PojoFieldMapping.getOrElse(
+ throw new CodeGenException("No input mapping is specified for input1 of type POJO."))
case _ => // ok
}
- // check that input2 is never a POJO
+ // check for POJO input2 mapping
input2 match {
case Some(pt: PojoTypeInfo[_]) =>
- throw new CodeGenException("Second input must not be a POJO type.")
+ input2PojoFieldMapping.getOrElse(
+ throw new CodeGenException("No input mapping is specified for input2 of type POJO."))
case _ => // ok
}
@@ -156,22 +160,22 @@ class CodeGenerator(
/**
* @return term of the (casted and possibly boxed) first input
*/
- def input1Term = "in1"
+ var input1Term = "in1"
/**
* @return term of the (casted and possibly boxed) second input
*/
- def input2Term = "in2"
+ var input2Term = "in2"
/**
* @return term of the (casted) output collector
*/
- def collectorTerm = "c"
+ var collectorTerm = "c"
/**
* @return term of the output record (possibly defined in the member area e.g. Row, Tuple)
*/
- def outRecordTerm = "out"
+ var outRecordTerm = "out"
/**
* @return returns if null checking is enabled
@@ -334,11 +338,11 @@ class CodeGenerator(
resultFieldNames: Seq[String])
: GeneratedExpression = {
val input1AccessExprs = for (i <- 0 until input1.getArity)
- yield generateInputAccess(input1, input1Term, i)
+ yield generateInputAccess(input1, input1Term, i, input1PojoFieldMapping)
val input2AccessExprs = input2 match {
case Some(ti) => for (i <- 0 until ti.getArity)
- yield generateInputAccess(ti, input2Term, i)
+ yield generateInputAccess(ti, input2Term, i, input2PojoFieldMapping)
case None => Seq() // add nothing
}
@@ -346,6 +350,23 @@ class CodeGenerator(
}
/**
+ * Generates an expression from the left input and the right table function.
+ */
+ def generateCorrelateAccessExprs: (Seq[GeneratedExpression], Seq[GeneratedExpression]) = {
+ val input1AccessExprs = for (i <- 0 until input1.getArity)
+ yield generateInputAccess(input1, input1Term, i, input1PojoFieldMapping)
+
+ val input2AccessExprs = input2 match {
+ case Some(ti) => for (i <- 0 until ti.getArity)
+ // use generateFieldAccess instead of generateInputAccess so that the generated table
+ // function's field access code lands inside the while loop instead of at the top of the function body
+ yield generateFieldAccess(ti, input2Term, i, input2PojoFieldMapping)
+ case None => throw new CodeGenException("type information of input2 must not be null")
+ }
+ (input1AccessExprs, input2AccessExprs)
+ }
+
+ /**
* Generates an expression from a sequence of RexNode. If objects or variables can be reused,
* they will be added to reusable code sections internally. The evaluation result
* may be stored in the global result variable (see [[outRecordTerm]]).
@@ -594,9 +615,11 @@ class CodeGenerator(
override def visitInputRef(inputRef: RexInputRef): GeneratedExpression = {
// if inputRef index is within size of input1 we work with input1, input2 otherwise
val input = if (inputRef.getIndex < input1.getArity) {
- (input1, input1Term)
+ (input1, input1Term, input1PojoFieldMapping)
} else {
- (input2.getOrElse(throw new CodeGenException("Invalid input access.")), input2Term)
+ (input2.getOrElse(throw new CodeGenException("Invalid input access.")),
+ input2Term,
+ input2PojoFieldMapping)
}
val index = if (input._2 == input1Term) {
@@ -605,13 +628,17 @@ class CodeGenerator(
inputRef.getIndex - input1.getArity
}
- generateInputAccess(input._1, input._2, index)
+ generateInputAccess(input._1, input._2, index, input._3)
}
override def visitFieldAccess(rexFieldAccess: RexFieldAccess): GeneratedExpression = {
val refExpr = rexFieldAccess.getReferenceExpr.accept(this)
val index = rexFieldAccess.getField.getIndex
- val fieldAccessExpr = generateFieldAccess(refExpr.resultType, refExpr.resultTerm, index)
+ val fieldAccessExpr = generateFieldAccess(
+ refExpr.resultType,
+ refExpr.resultTerm,
+ index,
+ input1PojoFieldMapping)
val resultTerm = newName("result")
val nullTerm = newName("isNull")
@@ -753,8 +780,9 @@ class CodeGenerator(
}
}
- override def visitCorrelVariable(correlVariable: RexCorrelVariable): GeneratedExpression =
- throw new CodeGenException("Correlating variables are not supported yet.")
+ override def visitCorrelVariable(correlVariable: RexCorrelVariable): GeneratedExpression = {
+ GeneratedExpression(input1Term, GeneratedExpression.NEVER_NULL, "", input1)
+ }
override def visitLocalRef(localRef: RexLocalRef): GeneratedExpression =
throw new CodeGenException("Local variables are not supported yet.")
@@ -948,7 +976,7 @@ class CodeGenerator(
// advanced scalar functions
case sqlOperator: SqlOperator =>
- val callGen = ScalarFunctions.getCallGenerator(
+ val callGen = FunctionGenerator.getCallGenerator(
sqlOperator,
operands.map(_.resultType),
resultType)
@@ -977,7 +1005,8 @@ class CodeGenerator(
private def generateInputAccess(
inputType: TypeInformation[Any],
inputTerm: String,
- index: Int)
+ index: Int,
+ pojoFieldMapping: Option[Array[Int]])
: GeneratedExpression = {
// if input has been used before, we can reuse the code that
// has already been generated
@@ -989,10 +1018,10 @@ class CodeGenerator(
// generate input access and unboxing if necessary
case None =>
val expr = if (nullableInput) {
- generateNullableInputFieldAccess(inputType, inputTerm, index)
+ generateNullableInputFieldAccess(inputType, inputTerm, index, pojoFieldMapping)
}
else {
- generateFieldAccess(inputType, inputTerm, index)
+ generateFieldAccess(inputType, inputTerm, index, pojoFieldMapping)
}
reusableInputUnboxingExprs((inputTerm, index)) = expr
@@ -1005,7 +1034,8 @@ class CodeGenerator(
private def generateNullableInputFieldAccess(
inputType: TypeInformation[Any],
inputTerm: String,
- index: Int)
+ index: Int,
+ pojoFieldMapping: Option[Array[Int]])
: GeneratedExpression = {
val resultTerm = newName("result")
val nullTerm = newName("isNull")
@@ -1013,7 +1043,7 @@ class CodeGenerator(
val fieldType = inputType match {
case ct: CompositeType[_] =>
val fieldIndex = if (ct.isInstanceOf[PojoTypeInfo[_]]) {
- inputPojoFieldMapping.get(index)
+ pojoFieldMapping.get(index)
}
else {
index
@@ -1024,7 +1054,7 @@ class CodeGenerator(
}
val resultTypeTerm = primitiveTypeTermForTypeInfo(fieldType)
val defaultValue = primitiveDefaultValue(fieldType)
- val fieldAccessExpr = generateFieldAccess(inputType, inputTerm, index)
+ val fieldAccessExpr = generateFieldAccess(inputType, inputTerm, index, pojoFieldMapping)
val inputCheckCode =
s"""
@@ -1047,12 +1077,13 @@ class CodeGenerator(
private def generateFieldAccess(
inputType: TypeInformation[_],
inputTerm: String,
- index: Int)
+ index: Int,
+ pojoFieldMapping: Option[Array[Int]])
: GeneratedExpression = {
inputType match {
case ct: CompositeType[_] =>
- val fieldIndex = if (ct.isInstanceOf[PojoTypeInfo[_]] && inputPojoFieldMapping.nonEmpty) {
- inputPojoFieldMapping.get(index)
+ val fieldIndex = if (ct.isInstanceOf[PojoTypeInfo[_]] && pojoFieldMapping.nonEmpty) {
+ pojoFieldMapping.get(index)
}
else {
index
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/FunctionGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/FunctionGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/FunctionGenerator.scala
new file mode 100644
index 0000000..9b144ba
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/FunctionGenerator.scala
@@ -0,0 +1,369 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.codegen.calls
+
+import java.lang.reflect.Method
+
+import org.apache.calcite.avatica.util.TimeUnitRange
+import org.apache.calcite.sql.SqlOperator
+import org.apache.calcite.sql.fun.SqlStdOperatorTable._
+import org.apache.calcite.sql.fun.SqlTrimFunction
+import org.apache.calcite.util.BuiltInMethod
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, SqlTimeTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils.GenericTypeInfo
+import org.apache.flink.api.table.functions.utils.{TableSqlFunction, ScalarSqlFunction}
+
+import scala.collection.mutable
+
+/**
+ * Global hub for user-defined and built-in advanced SQL functions.
+ */
+object FunctionGenerator {
+
+ private val sqlFunctions: mutable.Map[(SqlOperator, Seq[TypeInformation[_]]), CallGenerator] =
+ mutable.Map()
+
+ // ----------------------------------------------------------------------------------------------
+ // String functions
+ // ----------------------------------------------------------------------------------------------
+
+ addSqlFunctionMethod(
+ SUBSTRING,
+ Seq(STRING_TYPE_INFO, INT_TYPE_INFO, INT_TYPE_INFO),
+ STRING_TYPE_INFO,
+ BuiltInMethod.SUBSTRING.method)
+
+ addSqlFunctionMethod(
+ SUBSTRING,
+ Seq(STRING_TYPE_INFO, INT_TYPE_INFO),
+ STRING_TYPE_INFO,
+ BuiltInMethod.SUBSTRING.method)
+
+ addSqlFunction(
+ TRIM,
+ Seq(new GenericTypeInfo(classOf[SqlTrimFunction.Flag]), STRING_TYPE_INFO, STRING_TYPE_INFO),
+ new TrimCallGen())
+
+ addSqlFunctionMethod(
+ CHAR_LENGTH,
+ Seq(STRING_TYPE_INFO),
+ INT_TYPE_INFO,
+ BuiltInMethod.CHAR_LENGTH.method)
+
+ addSqlFunctionMethod(
+ CHARACTER_LENGTH,
+ Seq(STRING_TYPE_INFO),
+ INT_TYPE_INFO,
+ BuiltInMethod.CHAR_LENGTH.method)
+
+ addSqlFunctionMethod(
+ UPPER,
+ Seq(STRING_TYPE_INFO),
+ STRING_TYPE_INFO,
+ BuiltInMethod.UPPER.method)
+
+ addSqlFunctionMethod(
+ LOWER,
+ Seq(STRING_TYPE_INFO),
+ STRING_TYPE_INFO,
+ BuiltInMethod.LOWER.method)
+
+ addSqlFunctionMethod(
+ INITCAP,
+ Seq(STRING_TYPE_INFO),
+ STRING_TYPE_INFO,
+ BuiltInMethod.INITCAP.method)
+
+ addSqlFunctionMethod(
+ LIKE,
+ Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
+ BOOLEAN_TYPE_INFO,
+ BuiltInMethod.LIKE.method)
+
+ addSqlFunctionMethod(
+ LIKE,
+ Seq(STRING_TYPE_INFO, STRING_TYPE_INFO, STRING_TYPE_INFO),
+ BOOLEAN_TYPE_INFO,
+ BuiltInMethods.LIKE_WITH_ESCAPE)
+
+ addSqlFunctionNotMethod(
+ NOT_LIKE,
+ Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
+ BuiltInMethod.LIKE.method)
+
+ addSqlFunctionMethod(
+ SIMILAR_TO,
+ Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
+ BOOLEAN_TYPE_INFO,
+ BuiltInMethod.SIMILAR.method)
+
+ addSqlFunctionMethod(
+ SIMILAR_TO,
+ Seq(STRING_TYPE_INFO, STRING_TYPE_INFO, STRING_TYPE_INFO),
+ BOOLEAN_TYPE_INFO,
+ BuiltInMethods.SIMILAR_WITH_ESCAPE)
+
+ addSqlFunctionNotMethod(
+ NOT_SIMILAR_TO,
+ Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
+ BuiltInMethod.SIMILAR.method)
+
+ addSqlFunctionMethod(
+ POSITION,
+ Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
+ INT_TYPE_INFO,
+ BuiltInMethod.POSITION.method)
+
+ addSqlFunctionMethod(
+ OVERLAY,
+ Seq(STRING_TYPE_INFO, STRING_TYPE_INFO, INT_TYPE_INFO),
+ STRING_TYPE_INFO,
+ BuiltInMethod.OVERLAY.method)
+
+ addSqlFunctionMethod(
+ OVERLAY,
+ Seq(STRING_TYPE_INFO, STRING_TYPE_INFO, INT_TYPE_INFO, INT_TYPE_INFO),
+ STRING_TYPE_INFO,
+ BuiltInMethod.OVERLAY.method)
+
+ // ----------------------------------------------------------------------------------------------
+ // Arithmetic functions
+ // ----------------------------------------------------------------------------------------------
+
+ addSqlFunctionMethod(
+ LOG10,
+ Seq(DOUBLE_TYPE_INFO),
+ DOUBLE_TYPE_INFO,
+ BuiltInMethods.LOG10)
+
+ addSqlFunctionMethod(
+ LN,
+ Seq(DOUBLE_TYPE_INFO),
+ DOUBLE_TYPE_INFO,
+ BuiltInMethods.LN)
+
+ addSqlFunctionMethod(
+ EXP,
+ Seq(DOUBLE_TYPE_INFO),
+ DOUBLE_TYPE_INFO,
+ BuiltInMethods.EXP)
+
+ addSqlFunctionMethod(
+ POWER,
+ Seq(DOUBLE_TYPE_INFO, DOUBLE_TYPE_INFO),
+ DOUBLE_TYPE_INFO,
+ BuiltInMethods.POWER)
+
+ addSqlFunctionMethod(
+ POWER,
+ Seq(DOUBLE_TYPE_INFO, BIG_DEC_TYPE_INFO),
+ DOUBLE_TYPE_INFO,
+ BuiltInMethods.POWER_DEC)
+
+ addSqlFunction(
+ ABS,
+ Seq(DOUBLE_TYPE_INFO),
+ new MultiTypeMethodCallGen(BuiltInMethods.ABS))
+
+ addSqlFunction(
+ ABS,
+ Seq(BIG_DEC_TYPE_INFO),
+ new MultiTypeMethodCallGen(BuiltInMethods.ABS_DEC))
+
+ addSqlFunction(
+ FLOOR,
+ Seq(DOUBLE_TYPE_INFO),
+ new FloorCeilCallGen(BuiltInMethod.FLOOR.method))
+
+ addSqlFunction(
+ FLOOR,
+ Seq(BIG_DEC_TYPE_INFO),
+ new FloorCeilCallGen(BuiltInMethod.FLOOR.method))
+
+ addSqlFunction(
+ CEIL,
+ Seq(DOUBLE_TYPE_INFO),
+ new FloorCeilCallGen(BuiltInMethod.CEIL.method))
+
+ addSqlFunction(
+ CEIL,
+ Seq(BIG_DEC_TYPE_INFO),
+ new FloorCeilCallGen(BuiltInMethod.CEIL.method))
+
+ // ----------------------------------------------------------------------------------------------
+ // Temporal functions
+ // ----------------------------------------------------------------------------------------------
+
+ addSqlFunctionMethod(
+ EXTRACT_DATE,
+ Seq(new GenericTypeInfo(classOf[TimeUnitRange]), LONG_TYPE_INFO),
+ LONG_TYPE_INFO,
+ BuiltInMethod.UNIX_DATE_EXTRACT.method)
+
+ addSqlFunctionMethod(
+ EXTRACT_DATE,
+ Seq(new GenericTypeInfo(classOf[TimeUnitRange]), SqlTimeTypeInfo.DATE),
+ LONG_TYPE_INFO,
+ BuiltInMethod.UNIX_DATE_EXTRACT.method)
+
+ addSqlFunction(
+ FLOOR,
+ Seq(SqlTimeTypeInfo.DATE, new GenericTypeInfo(classOf[TimeUnitRange])),
+ new FloorCeilCallGen(
+ BuiltInMethod.FLOOR.method,
+ Some(BuiltInMethod.UNIX_DATE_FLOOR.method)))
+
+ addSqlFunction(
+ FLOOR,
+ Seq(SqlTimeTypeInfo.TIME, new GenericTypeInfo(classOf[TimeUnitRange])),
+ new FloorCeilCallGen(
+ BuiltInMethod.FLOOR.method,
+ Some(BuiltInMethod.UNIX_DATE_FLOOR.method)))
+
+ addSqlFunction(
+ FLOOR,
+ Seq(SqlTimeTypeInfo.TIMESTAMP, new GenericTypeInfo(classOf[TimeUnitRange])),
+ new FloorCeilCallGen(
+ BuiltInMethod.FLOOR.method,
+ Some(BuiltInMethod.UNIX_TIMESTAMP_FLOOR.method)))
+
+ addSqlFunction(
+ CEIL,
+ Seq(SqlTimeTypeInfo.DATE, new GenericTypeInfo(classOf[TimeUnitRange])),
+ new FloorCeilCallGen(
+ BuiltInMethod.CEIL.method,
+ Some(BuiltInMethod.UNIX_DATE_CEIL.method)))
+
+ addSqlFunction(
+ CEIL,
+ Seq(SqlTimeTypeInfo.TIME, new GenericTypeInfo(classOf[TimeUnitRange])),
+ new FloorCeilCallGen(
+ BuiltInMethod.CEIL.method,
+ Some(BuiltInMethod.UNIX_DATE_CEIL.method)))
+
+ addSqlFunction(
+ CEIL,
+ Seq(SqlTimeTypeInfo.TIMESTAMP, new GenericTypeInfo(classOf[TimeUnitRange])),
+ new FloorCeilCallGen(
+ BuiltInMethod.CEIL.method,
+ Some(BuiltInMethod.UNIX_TIMESTAMP_CEIL.method)))
+
+ addSqlFunction(
+ CURRENT_DATE,
+ Seq(),
+ new CurrentTimePointCallGen(SqlTimeTypeInfo.DATE, local = false))
+
+ addSqlFunction(
+ CURRENT_TIME,
+ Seq(),
+ new CurrentTimePointCallGen(SqlTimeTypeInfo.TIME, local = false))
+
+ addSqlFunction(
+ CURRENT_TIMESTAMP,
+ Seq(),
+ new CurrentTimePointCallGen(SqlTimeTypeInfo.TIMESTAMP, local = false))
+
+ addSqlFunction(
+ LOCALTIME,
+ Seq(),
+ new CurrentTimePointCallGen(SqlTimeTypeInfo.TIME, local = true))
+
+ addSqlFunction(
+ LOCALTIMESTAMP,
+ Seq(),
+ new CurrentTimePointCallGen(SqlTimeTypeInfo.TIMESTAMP, local = true))
+
+ // ----------------------------------------------------------------------------------------------
+
+ /**
+ * Returns a [[CallGenerator]] that generates all required code for calling the given
+ * [[SqlOperator]].
+ *
+ * @param sqlOperator SQL operator (might be overloaded)
+ * @param operandTypes actual operand types
+ * @param resultType expected return type
+ * @return [[CallGenerator]]
+ */
+ def getCallGenerator(
+ sqlOperator: SqlOperator,
+ operandTypes: Seq[TypeInformation[_]],
+ resultType: TypeInformation[_])
+ : Option[CallGenerator] = sqlOperator match {
+
+ // user-defined scalar function
+ case ssf: ScalarSqlFunction =>
+ Some(
+ new ScalarFunctionCallGen(
+ ssf.getScalarFunction,
+ operandTypes,
+ resultType
+ )
+ )
+
+ // user-defined table function
+ case tsf: TableSqlFunction =>
+ Some(
+ new TableFunctionCallGen(
+ tsf.getTableFunction,
+ operandTypes,
+ resultType
+ )
+ )
+
+ // built-in scalar function
+ case _ =>
+ sqlFunctions.get((sqlOperator, operandTypes))
+ .orElse(sqlFunctions.find(entry => entry._1._1 == sqlOperator
+ && entry._1._2.length == operandTypes.length
+ && entry._1._2.zip(operandTypes).forall {
+ case (x: BasicTypeInfo[_], y: BasicTypeInfo[_]) => y.shouldAutocastTo(x) || x == y
+ case _ => false
+ }).map(_._2))
+ }
+
+ // ----------------------------------------------------------------------------------------------
+
+ private def addSqlFunctionMethod(
+ sqlOperator: SqlOperator,
+ operandTypes: Seq[TypeInformation[_]],
+ returnType: TypeInformation[_],
+ method: Method)
+ : Unit = {
+ sqlFunctions((sqlOperator, operandTypes)) = new MethodCallGen(returnType, method)
+ }
+
+ private def addSqlFunctionNotMethod(
+ sqlOperator: SqlOperator,
+ operandTypes: Seq[TypeInformation[_]],
+ method: Method)
+ : Unit = {
+ sqlFunctions((sqlOperator, operandTypes)) =
+ new NotCallGenerator(new MethodCallGen(BOOLEAN_TYPE_INFO, method))
+ }
+
+ private def addSqlFunction(
+ sqlOperator: SqlOperator,
+ operandTypes: Seq[TypeInformation[_]],
+ callGenerator: CallGenerator)
+ : Unit = {
+ sqlFunctions((sqlOperator, operandTypes)) = callGenerator
+ }
+
+}
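The built-in branch of getCallGenerator tries an exact operand-type match first and then falls back to any registered signature of the same arity that the actual operand types can auto-cast to. A distilled, standalone illustration of that two-step lookup (toy string "types" stand in for Flink's TypeInformation and shouldAutocastTo):

    val registered = Map[Seq[String], String](Seq("DOUBLE") -> "LOG10(DOUBLE)")

    // toy cast rule standing in for BasicTypeInfo.shouldAutocastTo
    def autocasts(from: String, to: String): Boolean =
      from == "INT" && to == "DOUBLE"

    def lookup(actual: Seq[String]): Option[String] =
      registered.get(actual).orElse(
        registered.collectFirst {
          case (sig, gen)
            if sig.length == actual.length &&
              sig.zip(actual).forall { case (x, y) => x == y || autocasts(y, x) } =>
            gen
        })

    lookup(Seq("DOUBLE"))  // Some(LOG10(DOUBLE)) via the exact match
    lookup(Seq("INT"))     // Some(LOG10(DOUBLE)) via the auto-cast fallback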
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctions.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctions.scala
deleted file mode 100644
index e7c436a..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctions.scala
+++ /dev/null
@@ -1,359 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.api.table.codegen.calls
-
-import java.lang.reflect.Method
-
-import org.apache.calcite.avatica.util.TimeUnitRange
-import org.apache.calcite.sql.SqlOperator
-import org.apache.calcite.sql.fun.SqlStdOperatorTable._
-import org.apache.calcite.sql.fun.SqlTrimFunction
-import org.apache.calcite.util.BuiltInMethod
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
-import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, SqlTimeTypeInfo, TypeInformation}
-import org.apache.flink.api.java.typeutils.GenericTypeInfo
-import org.apache.flink.api.table.functions.utils.ScalarSqlFunction
-
-import scala.collection.mutable
-
-/**
- * Global hub for user-defined and built-in advanced SQL scalar functions.
- */
-object ScalarFunctions {
-
- private val sqlFunctions: mutable.Map[(SqlOperator, Seq[TypeInformation[_]]), CallGenerator] =
- mutable.Map()
-
- // ----------------------------------------------------------------------------------------------
- // String functions
- // ----------------------------------------------------------------------------------------------
-
- addSqlFunctionMethod(
- SUBSTRING,
- Seq(STRING_TYPE_INFO, INT_TYPE_INFO, INT_TYPE_INFO),
- STRING_TYPE_INFO,
- BuiltInMethod.SUBSTRING.method)
-
- addSqlFunctionMethod(
- SUBSTRING,
- Seq(STRING_TYPE_INFO, INT_TYPE_INFO),
- STRING_TYPE_INFO,
- BuiltInMethod.SUBSTRING.method)
-
- addSqlFunction(
- TRIM,
- Seq(new GenericTypeInfo(classOf[SqlTrimFunction.Flag]), STRING_TYPE_INFO, STRING_TYPE_INFO),
- new TrimCallGen())
-
- addSqlFunctionMethod(
- CHAR_LENGTH,
- Seq(STRING_TYPE_INFO),
- INT_TYPE_INFO,
- BuiltInMethod.CHAR_LENGTH.method)
-
- addSqlFunctionMethod(
- CHARACTER_LENGTH,
- Seq(STRING_TYPE_INFO),
- INT_TYPE_INFO,
- BuiltInMethod.CHAR_LENGTH.method)
-
- addSqlFunctionMethod(
- UPPER,
- Seq(STRING_TYPE_INFO),
- STRING_TYPE_INFO,
- BuiltInMethod.UPPER.method)
-
- addSqlFunctionMethod(
- LOWER,
- Seq(STRING_TYPE_INFO),
- STRING_TYPE_INFO,
- BuiltInMethod.LOWER.method)
-
- addSqlFunctionMethod(
- INITCAP,
- Seq(STRING_TYPE_INFO),
- STRING_TYPE_INFO,
- BuiltInMethod.INITCAP.method)
-
- addSqlFunctionMethod(
- LIKE,
- Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
- BOOLEAN_TYPE_INFO,
- BuiltInMethod.LIKE.method)
-
- addSqlFunctionMethod(
- LIKE,
- Seq(STRING_TYPE_INFO, STRING_TYPE_INFO, STRING_TYPE_INFO),
- BOOLEAN_TYPE_INFO,
- BuiltInMethods.LIKE_WITH_ESCAPE)
-
- addSqlFunctionNotMethod(
- NOT_LIKE,
- Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
- BuiltInMethod.LIKE.method)
-
- addSqlFunctionMethod(
- SIMILAR_TO,
- Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
- BOOLEAN_TYPE_INFO,
- BuiltInMethod.SIMILAR.method)
-
- addSqlFunctionMethod(
- SIMILAR_TO,
- Seq(STRING_TYPE_INFO, STRING_TYPE_INFO, STRING_TYPE_INFO),
- BOOLEAN_TYPE_INFO,
- BuiltInMethods.SIMILAR_WITH_ESCAPE)
-
- addSqlFunctionNotMethod(
- NOT_SIMILAR_TO,
- Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
- BuiltInMethod.SIMILAR.method)
-
- addSqlFunctionMethod(
- POSITION,
- Seq(STRING_TYPE_INFO, STRING_TYPE_INFO),
- INT_TYPE_INFO,
- BuiltInMethod.POSITION.method)
-
- addSqlFunctionMethod(
- OVERLAY,
- Seq(STRING_TYPE_INFO, STRING_TYPE_INFO, INT_TYPE_INFO),
- STRING_TYPE_INFO,
- BuiltInMethod.OVERLAY.method)
-
- addSqlFunctionMethod(
- OVERLAY,
- Seq(STRING_TYPE_INFO, STRING_TYPE_INFO, INT_TYPE_INFO, INT_TYPE_INFO),
- STRING_TYPE_INFO,
- BuiltInMethod.OVERLAY.method)
-
- // ----------------------------------------------------------------------------------------------
- // Arithmetic functions
- // ----------------------------------------------------------------------------------------------
-
- addSqlFunctionMethod(
- LOG10,
- Seq(DOUBLE_TYPE_INFO),
- DOUBLE_TYPE_INFO,
- BuiltInMethods.LOG10)
-
- addSqlFunctionMethod(
- LN,
- Seq(DOUBLE_TYPE_INFO),
- DOUBLE_TYPE_INFO,
- BuiltInMethods.LN)
-
- addSqlFunctionMethod(
- EXP,
- Seq(DOUBLE_TYPE_INFO),
- DOUBLE_TYPE_INFO,
- BuiltInMethods.EXP)
-
- addSqlFunctionMethod(
- POWER,
- Seq(DOUBLE_TYPE_INFO, DOUBLE_TYPE_INFO),
- DOUBLE_TYPE_INFO,
- BuiltInMethods.POWER)
-
- addSqlFunctionMethod(
- POWER,
- Seq(DOUBLE_TYPE_INFO, BIG_DEC_TYPE_INFO),
- DOUBLE_TYPE_INFO,
- BuiltInMethods.POWER_DEC)
-
- addSqlFunction(
- ABS,
- Seq(DOUBLE_TYPE_INFO),
- new MultiTypeMethodCallGen(BuiltInMethods.ABS))
-
- addSqlFunction(
- ABS,
- Seq(BIG_DEC_TYPE_INFO),
- new MultiTypeMethodCallGen(BuiltInMethods.ABS_DEC))
-
- addSqlFunction(
- FLOOR,
- Seq(DOUBLE_TYPE_INFO),
- new FloorCeilCallGen(BuiltInMethod.FLOOR.method))
-
- addSqlFunction(
- FLOOR,
- Seq(BIG_DEC_TYPE_INFO),
- new FloorCeilCallGen(BuiltInMethod.FLOOR.method))
-
- addSqlFunction(
- CEIL,
- Seq(DOUBLE_TYPE_INFO),
- new FloorCeilCallGen(BuiltInMethod.CEIL.method))
-
- addSqlFunction(
- CEIL,
- Seq(BIG_DEC_TYPE_INFO),
- new FloorCeilCallGen(BuiltInMethod.CEIL.method))
-
- // ----------------------------------------------------------------------------------------------
- // Temporal functions
- // ----------------------------------------------------------------------------------------------
-
- addSqlFunctionMethod(
- EXTRACT_DATE,
- Seq(new GenericTypeInfo(classOf[TimeUnitRange]), LONG_TYPE_INFO),
- LONG_TYPE_INFO,
- BuiltInMethod.UNIX_DATE_EXTRACT.method)
-
- addSqlFunctionMethod(
- EXTRACT_DATE,
- Seq(new GenericTypeInfo(classOf[TimeUnitRange]), SqlTimeTypeInfo.DATE),
- LONG_TYPE_INFO,
- BuiltInMethod.UNIX_DATE_EXTRACT.method)
-
- addSqlFunction(
- FLOOR,
- Seq(SqlTimeTypeInfo.DATE, new GenericTypeInfo(classOf[TimeUnitRange])),
- new FloorCeilCallGen(
- BuiltInMethod.FLOOR.method,
- Some(BuiltInMethod.UNIX_DATE_FLOOR.method)))
-
- addSqlFunction(
- FLOOR,
- Seq(SqlTimeTypeInfo.TIME, new GenericTypeInfo(classOf[TimeUnitRange])),
- new FloorCeilCallGen(
- BuiltInMethod.FLOOR.method,
- Some(BuiltInMethod.UNIX_DATE_FLOOR.method)))
-
- addSqlFunction(
- FLOOR,
- Seq(SqlTimeTypeInfo.TIMESTAMP, new GenericTypeInfo(classOf[TimeUnitRange])),
- new FloorCeilCallGen(
- BuiltInMethod.FLOOR.method,
- Some(BuiltInMethod.UNIX_TIMESTAMP_FLOOR.method)))
-
- addSqlFunction(
- CEIL,
- Seq(SqlTimeTypeInfo.DATE, new GenericTypeInfo(classOf[TimeUnitRange])),
- new FloorCeilCallGen(
- BuiltInMethod.CEIL.method,
- Some(BuiltInMethod.UNIX_DATE_CEIL.method)))
-
- addSqlFunction(
- CEIL,
- Seq(SqlTimeTypeInfo.TIME, new GenericTypeInfo(classOf[TimeUnitRange])),
- new FloorCeilCallGen(
- BuiltInMethod.CEIL.method,
- Some(BuiltInMethod.UNIX_DATE_CEIL.method)))
-
- addSqlFunction(
- CEIL,
- Seq(SqlTimeTypeInfo.TIMESTAMP, new GenericTypeInfo(classOf[TimeUnitRange])),
- new FloorCeilCallGen(
- BuiltInMethod.CEIL.method,
- Some(BuiltInMethod.UNIX_TIMESTAMP_CEIL.method)))
-
- addSqlFunction(
- CURRENT_DATE,
- Seq(),
- new CurrentTimePointCallGen(SqlTimeTypeInfo.DATE, local = false))
-
- addSqlFunction(
- CURRENT_TIME,
- Seq(),
- new CurrentTimePointCallGen(SqlTimeTypeInfo.TIME, local = false))
-
- addSqlFunction(
- CURRENT_TIMESTAMP,
- Seq(),
- new CurrentTimePointCallGen(SqlTimeTypeInfo.TIMESTAMP, local = false))
-
- addSqlFunction(
- LOCALTIME,
- Seq(),
- new CurrentTimePointCallGen(SqlTimeTypeInfo.TIME, local = true))
-
- addSqlFunction(
- LOCALTIMESTAMP,
- Seq(),
- new CurrentTimePointCallGen(SqlTimeTypeInfo.TIMESTAMP, local = true))
-
- // ----------------------------------------------------------------------------------------------
-
- /**
- * Returns a [[CallGenerator]] that generates all required code for calling the given
- * [[SqlOperator]].
- *
- * @param sqlOperator SQL operator (might be overloaded)
- * @param operandTypes actual operand types
- * @param resultType expected return type
- * @return [[CallGenerator]]
- */
- def getCallGenerator(
- sqlOperator: SqlOperator,
- operandTypes: Seq[TypeInformation[_]],
- resultType: TypeInformation[_])
- : Option[CallGenerator] = sqlOperator match {
-
- // user-defined scalar function
- case ssf: ScalarSqlFunction =>
- Some(
- new ScalarFunctionCallGen(
- ssf.getScalarFunction,
- operandTypes,
- resultType
- )
- )
-
- // built-in scalar function
- case _ =>
- sqlFunctions.get((sqlOperator, operandTypes))
- .orElse(sqlFunctions.find(entry => entry._1._1 == sqlOperator
- && entry._1._2.length == operandTypes.length
- && entry._1._2.zip(operandTypes).forall {
- case (x: BasicTypeInfo[_], y: BasicTypeInfo[_]) => y.shouldAutocastTo(x) || x == y
- case _ => false
- }).map(_._2))
- }
-
- // ----------------------------------------------------------------------------------------------
-
- private def addSqlFunctionMethod(
- sqlOperator: SqlOperator,
- operandTypes: Seq[TypeInformation[_]],
- returnType: TypeInformation[_],
- method: Method)
- : Unit = {
- sqlFunctions((sqlOperator, operandTypes)) = new MethodCallGen(returnType, method)
- }
-
- private def addSqlFunctionNotMethod(
- sqlOperator: SqlOperator,
- operandTypes: Seq[TypeInformation[_]],
- method: Method)
- : Unit = {
- sqlFunctions((sqlOperator, operandTypes)) =
- new NotCallGenerator(new MethodCallGen(BOOLEAN_TYPE_INFO, method))
- }
-
- private def addSqlFunction(
- sqlOperator: SqlOperator,
- operandTypes: Seq[TypeInformation[_]],
- callGenerator: CallGenerator)
- : Unit = {
- sqlFunctions((sqlOperator, operandTypes)) = callGenerator
- }
-
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala
new file mode 100644
index 0000000..27cb43f
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.codegen.calls
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.codegen.CodeGenUtils._
+import org.apache.flink.api.table.codegen.{CodeGenException, CodeGenerator, GeneratedExpression}
+import org.apache.flink.api.table.functions.TableFunction
+import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils._
+
+/**
+ * Generates a call to user-defined [[TableFunction]].
+ *
+ * @param tableFunction user-defined [[TableFunction]] that might be overloaded
+ * @param signature actual signature with which the function is called
+ * @param returnType actual return type required by the surrounding call
+ */
+class TableFunctionCallGen(
+ tableFunction: TableFunction[_],
+ signature: Seq[TypeInformation[_]],
+ returnType: TypeInformation[_])
+ extends CallGenerator {
+
+ override def generate(
+ codeGenerator: CodeGenerator,
+ operands: Seq[GeneratedExpression])
+ : GeneratedExpression = {
+ // determine function signature
+ val matchingSignature = getSignature(tableFunction, signature)
+ .getOrElse(throw new CodeGenException("No matching signature found."))
+
+ // convert parameters for function (output boxing)
+ val parameters = matchingSignature
+ .zip(operands)
+ .map { case (paramClass, operandExpr) =>
+ if (paramClass.isPrimitive) {
+ operandExpr
+ } else {
+ val boxedTypeTerm = boxedTypeTermForTypeInfo(operandExpr.resultType)
+ val boxedExpr = codeGenerator.generateOutputFieldBoxing(operandExpr)
+ val exprOrNull: String = if (codeGenerator.nullCheck) {
+ s"${boxedExpr.nullTerm} ? null : ($boxedTypeTerm) ${boxedExpr.resultTerm}"
+ } else {
+ boxedExpr.resultTerm
+ }
+ boxedExpr.copy(resultTerm = exprOrNull)
+ }
+ }
+
+ // generate function call
+ val functionReference = codeGenerator.addReusableFunction(tableFunction)
+ val functionCallCode =
+ s"""
+ |${parameters.map(_.code).mkString("\n")}
+ |$functionReference.clear();
+ |$functionReference.eval(${parameters.map(_.resultTerm).mkString(", ")});
+ |""".stripMargin
+
+ // the call produces no scalar result; the function reference is returned
+ // so that the correlate code can access the buffered rows
+ GeneratedExpression(
+ functionReference,
+ GeneratedExpression.NEVER_NULL,
+ functionCallCode,
+ returnType)
+ }
+}
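For a call like split(c), the functionCallCode template above instantiates to roughly the following generated code (identifiers are illustrative; in reality they are freshly generated terms):

    // ...code that evaluates the operand c into a boxed result term...
    function_split.clear();           // drop rows buffered for the previous input row
    function_split.eval(result$1);    // buffer the rows produced for this input row

The returned GeneratedExpression uses the function reference itself as its result term, so the surrounding correlate code can read the buffered rows from it afterwards.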
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
index 6b6c129..6cd63ff 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
@@ -447,7 +447,9 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
lazy val alias: PackratParser[Expression] = logic ~ AS ~ fieldReference ^^ {
case e ~ _ ~ name => Alias(e, name.name)
- } | logic
+ } | logic ~ AS ~ "(" ~ rep1sep(fieldReference, ",") ~ ")" ^^ {
+ case e ~ _ ~ _ ~ names ~ _ => Alias(e, names.head.name, names.drop(1).map(_.name))
+ } | logic
lazy val expression: PackratParser[Expression] = alias |
failure("Invalid expression.")
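The new second branch makes multi-field aliases parseable from expression strings as well, which the string-based Table API relies on. Roughly (the function and field names are placeholders for a registered UDTF):

    // parses into an Alias carrying "name" plus the extra name "age";
    // single-name aliases still take the first branch
    val expr = ExpressionParser.parseExpression("split(c) as (name, age)")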
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala
index 39367be..3e8d8b1 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala
@@ -19,10 +19,12 @@ package org.apache.flink.api.table.expressions
import org.apache.calcite.rex.RexNode
import org.apache.calcite.tools.RelBuilder
-import org.apache.flink.api.table.functions.ScalarFunction
-import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils.{getResultType, getSignature, signatureToString, signaturesToString}
-import org.apache.flink.api.table.validate.{ValidationResult, ValidationFailure, ValidationSuccess}
-import org.apache.flink.api.table.{FlinkTypeFactory, UnresolvedException}
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.functions.{ScalarFunction, TableFunction}
+import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils._
+import org.apache.flink.api.table.plan.logical.{LogicalNode, LogicalTableFunctionCall}
+import org.apache.flink.api.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess}
+import org.apache.flink.api.table.{FlinkTypeFactory, UnresolvedException, ValidationException}
/**
* General expression for unresolved function calls. The function can be a built-in
@@ -63,11 +65,15 @@ case class ScalarFunctionCall(
override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
relBuilder.call(
- scalarFunction.getSqlFunction(scalarFunction.toString, typeFactory),
+ createScalarSqlFunction(
+ scalarFunction.getClass.getCanonicalName,
+ scalarFunction,
+ typeFactory),
parameters.map(_.toRexNode): _*)
}
- override def toString = s"$scalarFunction(${parameters.mkString(", ")})"
+ override def toString =
+ s"${scalarFunction.getClass.getCanonicalName}(${parameters.mkString(", ")})"
override private[flink] def resultType = getResultType(scalarFunction, foundSignature.get)
@@ -85,3 +91,68 @@ case class ScalarFunctionCall(
}
}
+
+
+/**
+ * Expression for calling a user-defined table function with actual parameters.
+ *
+ * @param functionName function name
+ * @param tableFunction user-defined table function
+ * @param parameters actual parameters of function
+ * @param resultType type information of returned table
+ */
+case class TableFunctionCall(
+ functionName: String,
+ tableFunction: TableFunction[_],
+ parameters: Seq[Expression],
+ resultType: TypeInformation[_])
+ extends Expression {
+
+ private var aliases: Option[Seq[String]] = None
+
+ override private[flink] def children: Seq[Expression] = parameters
+
+ /**
+ * Assigns aliases to the fields returned by this table function so that the following
+ * `select()` clause can refer to them.
+ *
+ * @param aliasList aliases for the fields returned by this table function
+ * @return this table function call
+ */
+ private[flink] def as(aliasList: Option[Seq[String]]): TableFunctionCall = {
+ this.aliases = aliasList
+ this
+ }
+
+ /**
+ * Converts an API class to a logical node for planning.
+ */
+ private[flink] def toLogicalTableFunctionCall(child: LogicalNode): LogicalTableFunctionCall = {
+ val originNames = getFieldInfo(resultType)._1
+
+ // determine the final field names
+ val fieldNames = if (aliases.isDefined) {
+ val aliasList = aliases.get
+ if (aliasList.length != originNames.length) {
+ throw ValidationException(
+ s"List of column aliases must have same degree as table; " +
+ s"the returned table of function '$functionName' has ${originNames.length} " +
+ s"columns (${originNames.mkString(",")}), " +
+ s"whereas alias list has ${aliasList.length} columns")
+ } else {
+ aliasList.toArray
+ }
+ } else {
+ originNames
+ }
+
+ LogicalTableFunctionCall(
+ functionName,
+ tableFunction,
+ parameters,
+ resultType,
+ fieldNames,
+ child)
+ }
+}
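For a concrete feel of the degree check above, a hedged sketch (the UDTF, its argument
expressions, and its two-column result type are placeholders passed in from outside;
the imports are those of call.scala):

    import org.apache.flink.api.common.typeinfo.TypeInformation
    import org.apache.flink.api.table.functions.TableFunction

    def aliasExample(
        udtf: TableFunction[_],
        params: Seq[Expression],
        rowType: TypeInformation[_]): TableFunctionCall = {
      val call = TableFunctionCall("func2", udtf, params, rowType)
      // OK if rowType has two columns: the alias list degree matches
      call.as(Some(Seq("name", "len")))
      // call.as(Some(Seq("name"))) would later fail in toLogicalTableFunctionCall:
      //   "List of column aliases must have same degree as table; ..."
    }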
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/fieldExpression.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/fieldExpression.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/fieldExpression.scala
index c7817bf..e651bb3 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/fieldExpression.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/fieldExpression.scala
@@ -67,7 +67,7 @@ case class ResolvedFieldReference(
}
}
-case class Alias(child: Expression, name: String)
+case class Alias(child: Expression, name: String, extraNames: Seq[String] = Seq())
extends UnaryExpression with NamedExpression {
override def toString = s"$child as '$name"
@@ -80,7 +80,7 @@ case class Alias(child: Expression, name: String)
override private[flink] def makeCopy(anyRefs: Array[AnyRef]): this.type = {
val child: Expression = anyRefs.head.asInstanceOf[Expression]
- copy(child, name).asInstanceOf[this.type]
+ copy(child, name, extraNames).asInstanceOf[this.type]
}
override private[flink] def toAttribute: Attribute = {
@@ -94,6 +94,8 @@ case class Alias(child: Expression, name: String)
override private[flink] def validateInput(): ValidationResult = {
if (name == "*") {
ValidationFailure("Alias can not accept '*' as name.")
+ } else if (extraNames.nonEmpty) {
+ ValidationFailure("Invalid call to Alias with multiple names.")
} else {
ValidationSuccess
}
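The extra names are thus only meaningful when a downstream translation unpacks them; the
table function call is the one consumer. A short sketch of both cases, mirroring the Table
API tests later in this commit (TableFunc2 is the two-column test UDTF):

    import org.apache.flink.api.scala._
    import org.apache.flink.api.scala.table._
    import org.apache.flink.api.table.Table
    import org.apache.flink.api.table.utils.TableFunc2

    def aliasUsage(table: Table, func: TableFunc2): Table = {
      // legal: the alias list becomes the UDTF's result field names
      table.crossApply(func('c) as ('name, 'len)).select('name, 'len)
      // illegal: a plain projection alias takes exactly one name, e.g.
      // table.select('c as ('x, 'y)) fails validation with
      // "Invalid call to Alias with multiple names."
    }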
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala
index 5f9d834..86d9d66 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala
@@ -60,47 +60,6 @@ abstract class ScalarFunction extends UserDefinedFunction {
ScalarFunctionCall(this, params)
}
- // ----------------------------------------------------------------------------------------------
-
- private val evalMethods = checkAndExtractEvalMethods()
- private lazy val signatures = evalMethods.map(_.getParameterTypes)
-
- /**
- * Extracts evaluation methods and throws a [[ValidationException]] if no implementation
- * can be found.
- */
- private def checkAndExtractEvalMethods(): Array[Method] = {
- val methods = getClass.asSubclass(classOf[ScalarFunction])
- .getDeclaredMethods
- .filter { m =>
- val modifiers = m.getModifiers
- m.getName == "eval" && Modifier.isPublic(modifiers) && !Modifier.isAbstract(modifiers)
- }
-
- if (methods.isEmpty) {
- throw new ValidationException(s"Scalar function class '$this' does not implement at least " +
- s"one method named 'eval' which is public and not abstract.")
- } else {
- methods
- }
- }
-
- /**
- * Returns all found evaluation methods of the possibly overloaded function.
- */
- private[flink] final def getEvalMethods: Array[Method] = evalMethods
-
- /**
- * Returns all found signature of the possibly overloaded function.
- */
- private[flink] final def getSignatures: Array[Array[Class[_]]] = signatures
-
- override private[flink] final def createSqlFunction(
- name: String,
- typeFactory: FlinkTypeFactory)
- : SqlFunction = {
- new ScalarSqlFunction(name, this, typeFactory)
- }
// ----------------------------------------------------------------------------------------------
@@ -135,7 +94,8 @@ abstract class ScalarFunction extends UserDefinedFunction {
TypeExtractor.getForClass(c)
} catch {
case ite: InvalidTypesException =>
- throw new ValidationException(s"Parameter types of scalar function '$this' cannot be " +
+ throw new ValidationException(
+ s"Parameter types of scalar function '${this.getClass.getCanonicalName}' cannot be " +
s"automatically determined. Please provide type information manually.")
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/TableFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/TableFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/TableFunction.scala
new file mode 100644
index 0000000..98a2921
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/TableFunction.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.functions
+
+import java.util
+
+import org.apache.flink.api.common.functions.InvalidTypesException
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.TypeExtractor
+import org.apache.flink.api.table.ValidationException
+
+/**
+ * Base class for a user-defined table function (UDTF). A user-defined table function takes
+ * zero, one, or multiple scalar values as input and returns any number of rows as output.
+ *
+ * The behavior of a [[TableFunction]] can be defined by implementing a custom evaluation
+ * method. An evaluation method must be declared public and named "eval". Evaluation methods
+ * can also be overloaded by implementing multiple methods named "eval".
+ *
+ * User-defined functions must have a default constructor and must be instantiable during runtime.
+ *
+ * By default the result type of an evaluation method is determined by Flink's type extraction
+ * facilities. This is sufficient for basic types or simple POJOs but might be wrong for more
+ * complex, custom, or composite types. In these cases [[TypeInformation]] of the result type
+ * can be manually defined by overriding [[getResultType()]].
+ *
+ * Internally, the Table/SQL API code generation works with primitive values as much as possible.
+ * To keep the runtime overhead of a user-defined table function low, it is therefore
+ * recommended to declare parameters and result types as primitive types instead of their boxed
+ * classes. DATE/TIME is equal to int, TIMESTAMP is equal to long.
+ *
+ * Example:
+ *
+ * {{{
+ *
+ * public class Split extends TableFunction<String> {
+ *
+ * // implement an "eval" method with as many parameters as you want
+ * public void eval(String str) {
+ * for (String s : str.split(" ")) {
+ * collect(s); // use collect(...) to emit an output row
+ * }
+ * }
+ *
+ * // the eval method can be overloaded here ...
+ * }
+ *
+ * val tEnv: TableEnvironment = ...
+ * val table: Table = ... // schema: [a: String]
+ *
+ * // for Scala users
+ * val split = new Split()
+ * table.crossApply(split('a) as ('s)).select('a, 's)
+ *
+ * // for Java users
+ * tEnv.registerFunction("split", new Split()) // register table function first
+ * table.crossApply("split(a) as (s)").select("a, s")
+ *
+ * // for SQL users
+ * tEnv.registerFunction("split", new Split()) // register table function first
+ * tEnv.sql("SELECT a, s FROM MyTable, LATERAL TABLE(split(a)) as T(s)")
+ *
+ * }}}
+ *
+ * @tparam T The type of the output row
+ */
+abstract class TableFunction[T] extends UserDefinedFunction {
+
+ private val rows: util.ArrayList[T] = new util.ArrayList[T]()
+
+ /**
+ * Emit an output row.
+ *
+ * @param row the output row
+ */
+ protected def collect(row: T): Unit = {
+ // buffer rows for now; a later implementation may process them immediately instead
+ rows.add(row)
+ }
+
+ /**
+ * Internal use. Get an iterator of the buffered rows.
+ */
+ def getRowsIterator = rows.iterator()
+
+ /**
+ * Internal use. Clear buffered rows.
+ */
+ def clear() = rows.clear()
+
+ // ----------------------------------------------------------------------------------------------
+
+ /**
+ * Returns the result type of the evaluation method with a given signature.
+ *
+ * This method needs to be overridden in case Flink's type extraction facilities are not
+ * sufficient to extract the [[TypeInformation]] based on the return type of the evaluation
+ * method. Flink's type extraction facilities can handle basic types or
+ * simple POJOs but might be wrong for more complex, custom, or composite types.
+ *
+ * @return [[TypeInformation]] of result type or null if Flink should determine the type
+ */
+ def getResultType: TypeInformation[T] = null
+
+}
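To complement the Java example in the Scaladoc above, a minimal Scala sketch of a UDTF with
a custom result type (this mirrors what the TableFunc2 test function in this commit does;
the class name here is hypothetical):

    import org.apache.flink.api.common.typeinfo.TypeInformation
    import org.apache.flink.api.table.functions.TableFunction
    import org.apache.flink.api.table.typeutils.RowTypeInfo
    import org.apache.flink.api.table.{Row, Types}

    class SplitWithLength extends TableFunction[Row] {

      def eval(str: String): Unit = {
        str.split(" ").foreach { s =>
          val row = new Row(2)
          row.setField(0, s)        // the token
          row.setField(1, s.length) // its length
          collect(row)              // emit one output row per token
        }
      }

      // Row is a generic composite type, so Flink's type extraction cannot
      // determine the field types automatically; declare them explicitly
      override def getResultType: TypeInformation[Row] =
        new RowTypeInfo(Seq(Types.STRING, Types.INT))
    }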
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/UserDefinedFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/UserDefinedFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/UserDefinedFunction.scala
index 62afef0..cdf6b07 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/UserDefinedFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/UserDefinedFunction.scala
@@ -15,47 +15,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.flink.api.table.functions
-import org.apache.calcite.sql.SqlFunction
-import org.apache.flink.api.table.FlinkTypeFactory
-import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils.checkForInstantiation
-
-import scala.collection.mutable
-
/**
* Base class for all user-defined functions such as scalar functions, table functions,
* or aggregation functions.
*
* User-defined functions must have a default constructor and must be instantiable during runtime.
*/
-abstract class UserDefinedFunction {
-
- // we cache SQL functions to reduce amount of created objects
- // (i.e. for type inference, validation, etc.)
- private val cachedSqlFunctions = mutable.HashMap[String, SqlFunction]()
-
- // check if function can be instantiated
- checkForInstantiation(this.getClass)
-
- /**
- * Returns the corresponding [[SqlFunction]]. Creates an instance if not already created.
- */
- private[flink] final def getSqlFunction(
- name: String,
- typeFactory: FlinkTypeFactory)
- : SqlFunction = {
- cachedSqlFunctions.getOrElseUpdate(name, createSqlFunction(name, typeFactory))
- }
-
- /**
- * Creates corresponding [[SqlFunction]].
- */
- private[flink] def createSqlFunction(
- name: String,
- typeFactory: FlinkTypeFactory)
- : SqlFunction
-
- override def toString = getClass.getCanonicalName
+trait UserDefinedFunction {
}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala
index 531313e..0a987aa 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala
@@ -26,7 +26,7 @@ import org.apache.calcite.sql.parser.SqlParserPos
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.table.functions.ScalarFunction
import org.apache.flink.api.table.functions.utils.ScalarSqlFunction.{createOperandTypeChecker, createOperandTypeInference, createReturnTypeInference}
-import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils.{getResultType, getSignature, signatureToString, signaturesToString}
+import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils.{getResultType, getSignature, getSignatures, signatureToString, signaturesToString}
import org.apache.flink.api.table.{FlinkTypeFactory, ValidationException}
import scala.collection.JavaConverters._
@@ -123,6 +123,8 @@ object ScalarSqlFunction {
name: String,
scalarFunction: ScalarFunction)
: SqlOperandTypeChecker = {
+
+ val signatures = getSignatures(scalarFunction)
/**
* Operand type checker based on [[ScalarFunction]] given information.
*/
@@ -132,7 +134,7 @@ object ScalarSqlFunction {
}
override def getOperandCountRange: SqlOperandCountRange = {
- val signatureLengths = scalarFunction.getSignatures.map(_.length)
+ val signatureLengths = signatures.map(_.length)
SqlOperandCountRanges.between(signatureLengths.min, signatureLengths.max)
}
http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/TableSqlFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/TableSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/TableSqlFunction.scala
new file mode 100644
index 0000000..6eadfbc
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/TableSqlFunction.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.functions.utils
+
+import com.google.common.base.Predicate
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.sql._
+import org.apache.calcite.sql.`type`._
+import org.apache.calcite.sql.parser.SqlParserPos
+import org.apache.calcite.sql.validate.SqlUserDefinedTableFunction
+import org.apache.calcite.util.Util
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.functions.TableFunction
+import org.apache.flink.api.table.FlinkTypeFactory
+import org.apache.flink.api.table.plan.schema.FlinkTableFunctionImpl
+
+import scala.collection.JavaConverters._
+import java.util
+
+
+/**
+ * Calcite wrapper for user-defined table functions.
+ */
+class TableSqlFunction(
+ name: String,
+ udtf: TableFunction[_],
+ rowTypeInfo: TypeInformation[_],
+ returnTypeInference: SqlReturnTypeInference,
+ operandTypeInference: SqlOperandTypeInference,
+ operandTypeChecker: SqlOperandTypeChecker,
+ paramTypes: util.List[RelDataType],
+ functionImpl: FlinkTableFunctionImpl[_])
+ extends SqlUserDefinedTableFunction(
+ new SqlIdentifier(name, SqlParserPos.ZERO),
+ returnTypeInference,
+ operandTypeInference,
+ operandTypeChecker,
+ paramTypes,
+ functionImpl) {
+
+ /**
+ * Get the user-defined table function
+ */
+ def getTableFunction = udtf
+
+ /**
+ * Get the type information of the table returned by the table function
+ */
+ def getRowTypeInfo = rowTypeInfo
+
+ /**
+ * Get additional mapping information if the returned table type is a POJO
+ * (POJO types have no deterministic field order)
+ */
+ def getPojoFieldMapping = functionImpl.fieldIndexes
+
+}
+
+object TableSqlFunction {
+ /**
+ * Utility function to create a [[TableSqlFunction]].
+ *
+ * @param name function name (used by SQL parser)
+ * @param udtf user-defined table function to be called
+ * @param rowTypeInfo the row type information generated by the table function
+ * @param typeFactory type factory for converting between Flink's and Calcite's types
+ * @param functionImpl Calcite table function schema
+ * @return [[TableSqlFunction]]
+ */
+ def apply(
+ name: String,
+ udtf: TableFunction[_],
+ rowTypeInfo: TypeInformation[_],
+ typeFactory: FlinkTypeFactory,
+ functionImpl: FlinkTableFunctionImpl[_]): TableSqlFunction = {
+
+ val argTypes: util.List[RelDataType] = new util.ArrayList[RelDataType]
+ val typeFamilies: util.List[SqlTypeFamily] = new util.ArrayList[SqlTypeFamily]
+ // derives operands' data types and type families
+ functionImpl.getParameters.asScala.foreach{ o =>
+ val relType: RelDataType = o.getType(typeFactory)
+ argTypes.add(relType)
+ typeFamilies.add(Util.first(relType.getSqlTypeName.getFamily, SqlTypeFamily.ANY))
+ }
+ // predicate deciding whether the parameter at position 'input' is optional
+ val optional: Predicate[Integer] = new Predicate[Integer]() {
+ def apply(input: Integer): Boolean = {
+ functionImpl.getParameters.get(input).isOptional
+ }
+ }
+ // create type check for the operands
+ val typeChecker: FamilyOperandTypeChecker = OperandTypes.family(typeFamilies, optional)
+
+ new TableSqlFunction(
+ name,
+ udtf,
+ rowTypeInfo,
+ ReturnTypes.CURSOR,
+ InferTypes.explicit(argTypes),
+ typeChecker,
+ argTypes,
+ functionImpl)
+ }
+}
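Once registered, this wrapper is what lets Calcite resolve the UDTF inside a LATERAL TABLE
clause. A usage sketch mirroring the streaming tests below (SplitWithLength is the
hypothetical UDTF sketched after TableFunction.scala above):

    import org.apache.flink.api.scala._
    import org.apache.flink.api.scala.table._
    import org.apache.flink.api.table.TableEnvironment
    import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment

    def example(): Unit = {
      val env = StreamExecutionEnvironment.getExecutionEnvironment
      val tEnv = TableEnvironment.getTableEnvironment(env)

      val t = env.fromElements((1, "Jack 22"), (2, "John 19")).toTable(tEnv).as('a, 'c)
      tEnv.registerTable("MyTable", t)
      // registration creates a TableSqlFunction via the factory above
      tEnv.registerFunction("split", new SplitWithLength)

      val result = tEnv.sql(
        "SELECT c, word, len FROM MyTable, LATERAL TABLE(split(c)) AS T(word, len)")
    }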
[4/5] flink git commit: [FLINK-4469] [table] Minor improvements
Posted by tw...@apache.org.
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/UserDefinedTableFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/UserDefinedTableFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/UserDefinedTableFunctionTest.scala
new file mode 100644
index 0000000..a9f3f7b
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/UserDefinedTableFunctionTest.scala
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.scala.batch.table
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.{DataSet => JDataSet, ExecutionEnvironment => JavaExecutionEnv}
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment => ScalaExecutionEnv, _}
+import org.apache.flink.api.table.typeutils.RowTypeInfo
+import org.apache.flink.api.table.utils.TableTestUtil._
+import org.apache.flink.api.table.utils.{PojoTableFunc, TableFunc2, _}
+import org.apache.flink.api.table.{Row, TableEnvironment, Types}
+import org.junit.Test
+import org.mockito.Mockito._
+
+class UserDefinedTableFunctionTest extends TableTestBase {
+
+ @Test
+ def testJavaScalaTableAPIEquality(): Unit = {
+ // mock
+ val ds = mock(classOf[DataSet[Row]])
+ val jDs = mock(classOf[JDataSet[Row]])
+ val typeInfo: TypeInformation[Row] = new RowTypeInfo(Seq(Types.INT, Types.LONG, Types.STRING))
+ when(ds.javaSet).thenReturn(jDs)
+ when(jDs.getType).thenReturn(typeInfo)
+
+ // Scala environment
+ val env = mock(classOf[ScalaExecutionEnv])
+ val tableEnv = TableEnvironment.getTableEnvironment(env)
+ val in1 = ds.toTable(tableEnv).as('a, 'b, 'c)
+
+ // Java environment
+ val javaEnv = mock(classOf[JavaExecutionEnv])
+ val javaTableEnv = TableEnvironment.getTableEnvironment(javaEnv)
+ val in2 = javaTableEnv.fromDataSet(jDs).as("a, b, c")
+ javaTableEnv.registerTable("MyTable", in2)
+
+ // test cross apply
+ val func1 = new TableFunc1
+ javaTableEnv.registerFunction("func1", func1)
+ var scalaTable = in1.crossApply(func1('c) as 's).select('c, 's)
+ var javaTable = in2.crossApply("func1(c).as(s)").select("c, s")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test outer apply
+ scalaTable = in1.outerApply(func1('c) as 's).select('c, 's)
+ javaTable = in2.outerApply("as(func1(c), s)").select("c, s")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test overloading
+ scalaTable = in1.crossApply(func1('c, "$") as 's).select('c, 's)
+ javaTable = in2.crossApply("func1(c, '$') as (s)").select("c, s")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test custom result type
+ val func2 = new TableFunc2
+ javaTableEnv.registerFunction("func2", func2)
+ scalaTable = in1.crossApply(func2('c) as ('name, 'len)).select('c, 'name, 'len)
+ javaTable = in2.crossApply("func2(c).as(name, len)").select("c, name, len")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test hierarchy generic type
+ val hierarchy = new HierarchyTableFunction
+ javaTableEnv.registerFunction("hierarchy", hierarchy)
+ scalaTable = in1.crossApply(hierarchy('c) as ('name, 'adult, 'len))
+ .select('c, 'name, 'len, 'adult)
+ javaTable = in2.crossApply("AS(hierarchy(c), name, adult, len)")
+ .select("c, name, len, adult")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test pojo type
+ val pojo = new PojoTableFunc
+ javaTableEnv.registerFunction("pojo", pojo)
+ scalaTable = in1.crossApply(pojo('c))
+ .select('c, 'name, 'age)
+ javaTable = in2.crossApply("pojo(c)")
+ .select("c, name, age")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test with filter
+ scalaTable = in1.crossApply(func2('c) as ('name, 'len))
+ .select('c, 'name, 'len).filter('len > 2)
+ javaTable = in2.crossApply("func2(c) as (name, len)")
+ .select("c, name, len").filter("len > 2")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test with scalar function
+ scalaTable = in1.crossApply(func1('c.substring(2)) as 's)
+ .select('a, 'c, 's)
+ javaTable = in2.crossApply("func1(substring(c, 2)) as (s)")
+ .select("a, c, s")
+ verifyTableEquals(scalaTable, javaTable)
+ }
+
+ @Test
+ def testCrossApply(): Unit = {
+ val util = batchTestUtil()
+ val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ val function = util.addFunction("func1", new TableFunc1)
+
+ val result1 = table.crossApply(function('c) as 's).select('c, 's)
+
+ val expected1 = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", s"$function($$2)"),
+ term("function", function),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) s)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "s")
+ )
+
+ util.verifyTable(result1, expected1)
+
+ // test overloading
+
+ val result2 = table.crossApply(function('c, "$") as 's).select('c, 's)
+
+ val expected2 = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", s"$function($$2, '$$')"),
+ term("function", function),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) s)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "s")
+ )
+
+ util.verifyTable(result2, expected2)
+ }
+
+ @Test
+ def testOuterApply(): Unit = {
+ val util = batchTestUtil()
+ val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ val function = util.addFunction("func1", new TableFunc1)
+
+ val result = table.outerApply(function('c) as 's).select('c, 's)
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", s"$function($$2)"),
+ term("function", function),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) s)"),
+ term("joinType", "LEFT")
+ ),
+ term("select", "c", "s")
+ )
+
+ util.verifyTable(result, expected)
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionITCase.scala
deleted file mode 100644
index f19f7f9..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionITCase.scala
+++ /dev/null
@@ -1,181 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.api.scala.stream
-
-import org.apache.flink.api.scala._
-import org.apache.flink.api.scala.stream.utils.StreamITCase
-import org.apache.flink.api.scala.table._
-import org.apache.flink.api.table.expressions.utils.{TableFunc0, TableFunc1}
-import org.apache.flink.api.table.{Row, TableEnvironment}
-import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
-import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
-import org.junit.Assert._
-import org.junit.Test
-
-import scala.collection.mutable
-
-class UserDefinedTableFunctionITCase extends StreamingMultipleProgramsTestBase {
-
- @Test
- def testSQLCrossApply(): Unit = {
-
- val env = StreamExecutionEnvironment.getExecutionEnvironment
- val tEnv = TableEnvironment.getTableEnvironment(env)
- StreamITCase.clear
-
- val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
- tEnv.registerTable("MyTable", t)
-
- tEnv.registerFunction("split", new TableFunc0)
-
- val sqlQuery = "SELECT MyTable.c, t.n, t.a FROM MyTable, LATERAL TABLE(split(c)) AS t(n,a)"
-
- val result = tEnv.sql(sqlQuery).toDataStream[Row]
- result.addSink(new StreamITCase.StringSink)
- env.execute()
-
- val expected = mutable.MutableList(
- "Jack#22,Jack,22", "John#19,John,19", "Anna#44,Anna,44")
- assertEquals(expected.sorted, StreamITCase.testResults.sorted)
- }
-
- @Test
- def testSQLOuterApply(): Unit = {
- val env = StreamExecutionEnvironment.getExecutionEnvironment
- val tEnv = TableEnvironment.getTableEnvironment(env)
- StreamITCase.clear
-
- val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
- tEnv.registerTable("MyTable", t)
-
- tEnv.registerFunction("split", new TableFunc0)
-
- val sqlQuery = "SELECT MyTable.c, t.n, t.a FROM MyTable " +
- "LEFT JOIN LATERAL TABLE(split(c)) AS t(n,a) ON TRUE"
-
- val result = tEnv.sql(sqlQuery).toDataStream[Row]
- result.addSink(new StreamITCase.StringSink)
- env.execute()
-
- val expected = mutable.MutableList(
- "nosharp,null,null", "Jack#22,Jack,22",
- "John#19,John,19", "Anna#44,Anna,44")
- assertEquals(expected.sorted, StreamITCase.testResults.sorted)
- }
-
- @Test
- def testTableAPICrossApply(): Unit = {
- val env = StreamExecutionEnvironment.getExecutionEnvironment
- val tEnv = TableEnvironment.getTableEnvironment(env)
- StreamITCase.clear
-
- val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
- val func0 = new TableFunc0
-
- val result = t
- .crossApply(func0('c) as('d, 'e))
- .select('c, 'd, 'e)
- .toDataStream[Row]
-
- result.addSink(new StreamITCase.StringSink)
- env.execute()
-
- val expected = mutable.MutableList("Jack#22,Jack,22", "John#19,John,19", "Anna#44,Anna,44")
- assertEquals(expected.sorted, StreamITCase.testResults.sorted)
- }
-
- @Test
- def testTableAPIOuterApply(): Unit = {
- val env = StreamExecutionEnvironment.getExecutionEnvironment
- val tEnv = TableEnvironment.getTableEnvironment(env)
- StreamITCase.clear
-
- val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
- val func0 = new TableFunc0
-
- val result = t
- .outerApply(func0('c) as('d, 'e))
- .select('c, 'd, 'e)
- .toDataStream[Row]
-
- result.addSink(new StreamITCase.StringSink)
- env.execute()
-
- val expected = mutable.MutableList(
- "nosharp,null,null", "Jack#22,Jack,22",
- "John#19,John,19", "Anna#44,Anna,44")
- assertEquals(expected.sorted, StreamITCase.testResults.sorted)
- }
-
- @Test
- def testTableAPIWithFilter(): Unit = {
- val env = StreamExecutionEnvironment.getExecutionEnvironment
- val tEnv = TableEnvironment.getTableEnvironment(env)
- StreamITCase.clear
-
- val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
- val func0 = new TableFunc0
-
- val result = t
- .crossApply(func0('c) as('name, 'age))
- .select('c, 'name, 'age)
- .filter('age > 20)
- .toDataStream[Row]
-
- result.addSink(new StreamITCase.StringSink)
- env.execute()
-
- val expected = mutable.MutableList("Jack#22,Jack,22", "Anna#44,Anna,44")
- assertEquals(expected.sorted, StreamITCase.testResults.sorted)
- }
-
- @Test
- def testTableAPIWithScalarFunction(): Unit = {
- val env = StreamExecutionEnvironment.getExecutionEnvironment
- val tEnv = TableEnvironment.getTableEnvironment(env)
- StreamITCase.clear
-
- val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
- val func1 = new TableFunc1
-
- val result = t
- .crossApply(func1('c.substring(2)) as 's)
- .select('c, 's)
- .toDataStream[Row]
-
- result.addSink(new StreamITCase.StringSink)
- env.execute()
-
- val expected = mutable.MutableList("Jack#22,ack", "Jack#22,22", "John#19,ohn",
- "John#19,19", "Anna#44,nna", "Anna#44,44")
- assertEquals(expected.sorted, StreamITCase.testResults.sorted)
- }
-
- private def getSmall3TupleDataStream(
- env: StreamExecutionEnvironment)
- : DataStream[(Int, Long, String)] = {
-
- val data = new mutable.MutableList[(Int, Long, String)]
- data.+=((1, 1L, "Jack#22"))
- data.+=((2, 2L, "John#19"))
- data.+=((3, 2L, "Anna#44"))
- data.+=((4, 3L, "nosharp"))
- env.fromCollection(data)
- }
-
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionTest.scala
deleted file mode 100644
index bc01819..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionTest.scala
+++ /dev/null
@@ -1,402 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.api.scala.stream
-
-import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.api.scala._
-import org.apache.flink.api.scala.table._
-import org.apache.flink.api.table.expressions.utils._
-import org.apache.flink.api.table.typeutils.RowTypeInfo
-import org.apache.flink.api.table.utils.TableTestBase
-import org.apache.flink.api.table.utils.TableTestUtil._
-import org.apache.flink.api.table._
-import org.apache.flink.streaming.api.environment.{StreamExecutionEnvironment => JavaExecutionEnv}
-import org.apache.flink.streaming.api.datastream.{DataStream => JDataStream}
-import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment => ScalaExecutionEnv}
-import org.junit.Assert.{assertTrue, fail}
-import org.junit.Test
-import org.mockito.Mockito._
-
-class UserDefinedTableFunctionTest extends TableTestBase {
-
- @Test
- def testTableAPI(): Unit = {
- // mock
- val ds = mock(classOf[DataStream[Row]])
- val jDs = mock(classOf[JDataStream[Row]])
- val typeInfo: TypeInformation[Row] = new RowTypeInfo(Seq(Types.INT, Types.LONG, Types.STRING))
- when(ds.javaStream).thenReturn(jDs)
- when(jDs.getType).thenReturn(typeInfo)
-
- // Scala environment
- val env = mock(classOf[ScalaExecutionEnv])
- val tableEnv = TableEnvironment.getTableEnvironment(env)
- val in1 = ds.toTable(tableEnv).as('a, 'b, 'c)
-
- // Java environment
- val javaEnv = mock(classOf[JavaExecutionEnv])
- val javaTableEnv = TableEnvironment.getTableEnvironment(javaEnv)
- val in2 = javaTableEnv.fromDataStream(jDs).as("a, b, c")
-
- // test cross apply
- val func1 = new TableFunc1
- javaTableEnv.registerFunction("func1", func1)
- var scalaTable = in1.crossApply(func1('c) as ('s)).select('c, 's)
- var javaTable = in2.crossApply("func1(c) as (s)").select("c, s")
- verifyTableEquals(scalaTable, javaTable)
-
- // test outer apply
- scalaTable = in1.outerApply(func1('c) as ('s)).select('c, 's)
- javaTable = in2.outerApply("func1(c) as (s)").select("c, s")
- verifyTableEquals(scalaTable, javaTable)
-
- // test overloading
- scalaTable = in1.crossApply(func1('c, "$") as ('s)).select('c, 's)
- javaTable = in2.crossApply("func1(c, '$') as (s)").select("c, s")
- verifyTableEquals(scalaTable, javaTable)
-
- // test custom result type
- val func2 = new TableFunc2
- javaTableEnv.registerFunction("func2", func2)
- scalaTable = in1.crossApply(func2('c) as ('name, 'len)).select('c, 'name, 'len)
- javaTable = in2.crossApply("func2(c) as (name, len)").select("c, name, len")
- verifyTableEquals(scalaTable, javaTable)
-
- // test hierarchy generic type
- val hierarchy = new HierarchyTableFunction
- javaTableEnv.registerFunction("hierarchy", hierarchy)
- scalaTable = in1.crossApply(hierarchy('c) as ('name, 'adult, 'len))
- .select('c, 'name, 'len, 'adult)
- javaTable = in2.crossApply("hierarchy(c) as (name, adult, len)")
- .select("c, name, len, adult")
- verifyTableEquals(scalaTable, javaTable)
-
- // test pojo type
- val pojo = new PojoTableFunc
- javaTableEnv.registerFunction("pojo", pojo)
- scalaTable = in1.crossApply(pojo('c))
- .select('c, 'name, 'age)
- javaTable = in2.crossApply("pojo(c)")
- .select("c, name, age")
- verifyTableEquals(scalaTable, javaTable)
-
- // test with filter
- scalaTable = in1.crossApply(func2('c) as ('name, 'len))
- .select('c, 'name, 'len).filter('len > 2)
- javaTable = in2.crossApply("func2(c) as (name, len)")
- .select("c, name, len").filter("len > 2")
- verifyTableEquals(scalaTable, javaTable)
-
- // test with scalar function
- scalaTable = in1.crossApply(func1('c.substring(2)) as ('s))
- .select('a, 'c, 's)
- javaTable = in2.crossApply("func1(substring(c, 2)) as (s)")
- .select("a, c, s")
- verifyTableEquals(scalaTable, javaTable)
-
- // check scala object is forbidden
- expectExceptionThrown(
- tableEnv.registerFunction("func3", ObjectTableFunction), "Scala object")
- expectExceptionThrown(
- javaTableEnv.registerFunction("func3", ObjectTableFunction), "Scala object")
- expectExceptionThrown(
- in1.crossApply(ObjectTableFunction('a, 1)),"Scala object")
-
- }
-
-
- @Test
- def testInvalidTableFunction(): Unit = {
- // mock
- val util = streamTestUtil()
- val t = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
- val tEnv = TableEnvironment.getTableEnvironment(mock(classOf[JavaExecutionEnv]))
-
- //=================== check scala object is forbidden =====================
- // Scala table environment register
- expectExceptionThrown(util.addFunction("udtf", ObjectTableFunction), "Scala object")
- // Java table environment register
- expectExceptionThrown(tEnv.registerFunction("udtf", ObjectTableFunction), "Scala object")
- // Scala Table API directly call
- expectExceptionThrown(t.crossApply(ObjectTableFunction('a, 1)), "Scala object")
-
-
- //============ throw exception when table function is not registered =========
- // Java Table API call
- expectExceptionThrown(t.crossApply("nonexist(a)"), "Undefined function: NONEXIST")
- // SQL API call
- expectExceptionThrown(
- util.tEnv.sql("SELECT * FROM MyTable, LATERAL TABLE(nonexist(a))"),
- "No match found for function signature nonexist(<NUMERIC>)")
-
-
- //========= throw exception when the called function is a scalar function ====
- util.addFunction("func0", Func0)
- // Java Table API call
- expectExceptionThrown(
- t.crossApply("func0(a)"),
- "only accept TableFunction",
- classOf[TableException])
- // SQL API call
- // NOTE: it doesn't throw an exception but an AssertionError, maybe a Calcite bug
- expectExceptionThrown(
- util.tEnv.sql("SELECT * FROM MyTable, LATERAL TABLE(func0(a))"),
- null,
- classOf[AssertionError])
-
- //========== throw exception when the parameters is not correct ===============
- // Java Table API call
- util.addFunction("func2", new TableFunc2)
- expectExceptionThrown(
- t.crossApply("func2(c, c)"),
- "Given parameters of function 'FUNC2' do not match any signature")
- // SQL API call
- expectExceptionThrown(
- util.tEnv.sql("SELECT * FROM MyTable, LATERAL TABLE(func2(c, c))"),
- "No match found for function signature func2(<CHARACTER>, <CHARACTER>)")
- }
-
- private def expectExceptionThrown(
- function: => Unit,
- keywords: String,
- clazz: Class[_ <: Throwable] = classOf[ValidationException])
- : Unit = {
- try {
- function
- fail(s"Expected a $clazz, but no exception is thrown.")
- } catch {
- case e if e.getClass == clazz =>
- if (keywords != null) {
- assertTrue(
- s"The exception message '${e.getMessage}' doesn't contain keyword '$keywords'",
- e.getMessage.contains(keywords))
- }
- case e: Throwable => fail(s"Expected throw ${clazz.getSimpleName}, but is $e.")
- }
- }
-
- @Test
- def testSQLWithCrossApply(): Unit = {
- val util = streamTestUtil()
- val func1 = new TableFunc1
- util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
- util.addFunction("func1", func1)
-
- val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c)) AS T(s)"
-
- val expected = unaryNode(
- "DataStreamCalc",
- unaryNode(
- "DataStreamCorrelate",
- streamTableNode(0),
- term("invocation", "func1($cor0.c)"),
- term("function", func1.getClass.getCanonicalName),
- term("rowType",
- "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
- term("joinType", "INNER")
- ),
- term("select", "c", "f0 AS s")
- )
-
- util.verifySql(sqlQuery, expected)
-
- // test overloading
-
- val sqlQuery2 = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c, '$')) AS T(s)"
-
- val expected2 = unaryNode(
- "DataStreamCalc",
- unaryNode(
- "DataStreamCorrelate",
- streamTableNode(0),
- term("invocation", "func1($cor0.c, '$')"),
- term("function", func1.getClass.getCanonicalName),
- term("rowType",
- "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
- term("joinType", "INNER")
- ),
- term("select", "c", "f0 AS s")
- )
-
- util.verifySql(sqlQuery2, expected2)
- }
-
- @Test
- def testSQLWithOuterApply(): Unit = {
- val util = streamTestUtil()
- val func1 = new TableFunc1
- util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
- util.addFunction("func1", func1)
-
- val sqlQuery = "SELECT c, s FROM MyTable LEFT JOIN LATERAL TABLE(func1(c)) AS T(s) ON TRUE"
-
- val expected = unaryNode(
- "DataStreamCalc",
- unaryNode(
- "DataStreamCorrelate",
- streamTableNode(0),
- term("invocation", "func1($cor0.c)"),
- term("function", func1.getClass.getCanonicalName),
- term("rowType",
- "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
- term("joinType", "LEFT")
- ),
- term("select", "c", "f0 AS s")
- )
-
- util.verifySql(sqlQuery, expected)
- }
-
- @Test
- def testSQLWithCustomType(): Unit = {
- val util = streamTestUtil()
- val func2 = new TableFunc2
- util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
- util.addFunction("func2", func2)
-
- val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len)"
-
- val expected = unaryNode(
- "DataStreamCalc",
- unaryNode(
- "DataStreamCorrelate",
- streamTableNode(0),
- term("invocation", "func2($cor0.c)"),
- term("function", func2.getClass.getCanonicalName),
- term("rowType",
- "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
- "VARCHAR(2147483647) f0, INTEGER f1)"),
- term("joinType", "INNER")
- ),
- term("select", "c", "f0 AS name", "f1 AS len")
- )
-
- util.verifySql(sqlQuery, expected)
- }
-
- @Test
- def testSQLWithHierarchyType(): Unit = {
- val util = streamTestUtil()
- util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
- val function = new HierarchyTableFunction
- util.addFunction("hierarchy", function)
-
- val sqlQuery = "SELECT c, T.* FROM MyTable, LATERAL TABLE(hierarchy(c)) AS T(name, adult, len)"
-
- val expected = unaryNode(
- "DataStreamCalc",
- unaryNode(
- "DataStreamCorrelate",
- streamTableNode(0),
- term("invocation", "hierarchy($cor0.c)"),
- term("function", function.getClass.getCanonicalName),
- term("rowType",
- "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," +
- " VARCHAR(2147483647) f0, BOOLEAN f1, INTEGER f2)"),
- term("joinType", "INNER")
- ),
- term("select", "c", "f0 AS name", "f1 AS adult", "f2 AS len")
- )
-
- util.verifySql(sqlQuery, expected)
- }
-
- @Test
- def testSQLWithPojoType(): Unit = {
- val util = streamTestUtil()
- util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
- val function = new PojoTableFunc
- util.addFunction("pojo", function)
-
- val sqlQuery = "SELECT c, name, age FROM MyTable, LATERAL TABLE(pojo(c))"
-
- val expected = unaryNode(
- "DataStreamCalc",
- unaryNode(
- "DataStreamCorrelate",
- streamTableNode(0),
- term("invocation", "pojo($cor0.c)"),
- term("function", function.getClass.getCanonicalName),
- term("rowType",
- "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," +
- " INTEGER age, VARCHAR(2147483647) name)"),
- term("joinType", "INNER")
- ),
- term("select", "c", "name", "age")
- )
-
- util.verifySql(sqlQuery, expected)
- }
-
- @Test
- def testSQLWithFilter(): Unit = {
- val util = streamTestUtil()
- val func2 = new TableFunc2
- util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
- util.addFunction("func2", func2)
-
- val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len) " +
- "WHERE len > 2"
-
- val expected = unaryNode(
- "DataStreamCalc",
- unaryNode(
- "DataStreamCorrelate",
- streamTableNode(0),
- term("invocation", "func2($cor0.c)"),
- term("function", func2.getClass.getCanonicalName),
- term("rowType",
- "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
- "VARCHAR(2147483647) f0, INTEGER f1)"),
- term("joinType", "INNER"),
- term("condition", ">($1, 2)")
- ),
- term("select", "c", "f0 AS name", "f1 AS len")
- )
-
- util.verifySql(sqlQuery, expected)
- }
-
-
- @Test
- def testSQLWithScalarFunction(): Unit = {
- val util = streamTestUtil()
- val func1 = new TableFunc1
- util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
- util.addFunction("func1", func1)
-
- val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(SUBSTRING(c, 2))) AS T(s)"
-
- val expected = unaryNode(
- "DataStreamCalc",
- unaryNode(
- "DataStreamCorrelate",
- streamTableNode(0),
- term("invocation", "func1(SUBSTRING($cor0.c, 2))"),
- term("function", func1.getClass.getCanonicalName),
- term("rowType",
- "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
- term("joinType", "INNER")
- ),
- term("select", "c", "f0 AS s")
- )
-
- util.verifySql(sqlQuery, expected)
- }
-
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/sql/UserDefinedTableFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/sql/UserDefinedTableFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/sql/UserDefinedTableFunctionTest.scala
new file mode 100644
index 0000000..c2ded28
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/sql/UserDefinedTableFunctionTest.scala
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.scala.stream.sql
+
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.table.utils.{HierarchyTableFunction, PojoTableFunc, TableFunc2}
+import org.apache.flink.api.table.utils._
+import org.apache.flink.api.table.utils.TableTestUtil._
+import org.junit.Test
+
+class UserDefinedTableFunctionTest extends TableTestBase {
+
+ @Test
+ def testCrossApply(): Unit = {
+ val util = streamTestUtil()
+ val func1 = new TableFunc1
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func1", func1)
+
+ val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c)) AS T(s)"
+
+ val expected = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", "func1($cor0.c)"),
+ term("function", func1.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS s")
+ )
+
+ util.verifySql(sqlQuery, expected)
+
+ // test overloading
+
+ val sqlQuery2 = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c, '$')) AS T(s)"
+
+ val expected2 = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", "func1($cor0.c, '$')"),
+ term("function", func1.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS s")
+ )
+
+ util.verifySql(sqlQuery2, expected2)
+ }
+
+ @Test
+ def testOuterApply(): Unit = {
+ val util = streamTestUtil()
+ val func1 = new TableFunc1
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func1", func1)
+
+ val sqlQuery = "SELECT c, s FROM MyTable LEFT JOIN LATERAL TABLE(func1(c)) AS T(s) ON TRUE"
+
+ val expected = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", "func1($cor0.c)"),
+ term("function", func1.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+ term("joinType", "LEFT")
+ ),
+ term("select", "c", "f0 AS s")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testCustomType(): Unit = {
+ val util = streamTestUtil()
+ val func2 = new TableFunc2
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func2", func2)
+
+ val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len)"
+
+ val expected = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", "func2($cor0.c)"),
+ term("function", func2.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
+ "VARCHAR(2147483647) f0, INTEGER f1)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS name", "f1 AS len")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testHierarchyType(): Unit = {
+ val util = streamTestUtil()
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ val function = new HierarchyTableFunction
+ util.addFunction("hierarchy", function)
+
+ val sqlQuery = "SELECT c, T.* FROM MyTable, LATERAL TABLE(hierarchy(c)) AS T(name, adult, len)"
+
+ val expected = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", "hierarchy($cor0.c)"),
+ term("function", function.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," +
+ " VARCHAR(2147483647) f0, BOOLEAN f1, INTEGER f2)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS name", "f1 AS adult", "f2 AS len")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testPojoType(): Unit = {
+ val util = streamTestUtil()
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ val function = new PojoTableFunc
+ util.addFunction("pojo", function)
+
+ val sqlQuery = "SELECT c, name, age FROM MyTable, LATERAL TABLE(pojo(c))"
+
+ val expected = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", "pojo($cor0.c)"),
+ term("function", function.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," +
+ " INTEGER age, VARCHAR(2147483647) name)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "name", "age")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testFilter(): Unit = {
+ val util = streamTestUtil()
+ val func2 = new TableFunc2
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func2", func2)
+
+ val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len) " +
+ "WHERE len > 2"
+
+ val expected = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", "func2($cor0.c)"),
+ term("function", func2.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
+ "VARCHAR(2147483647) f0, INTEGER f1)"),
+ term("joinType", "INNER"),
+ term("condition", ">($1, 2)")
+ ),
+ term("select", "c", "f0 AS name", "f1 AS len")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testScalarFunction(): Unit = {
+ val util = streamTestUtil()
+ val func1 = new TableFunc1
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func1", func1)
+
+ val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(SUBSTRING(c, 2))) AS T(s)"
+
+ val expected = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", "func1(SUBSTRING($cor0.c, 2))"),
+ term("function", func1.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS s")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/UserDefinedTableFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/UserDefinedTableFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/UserDefinedTableFunctionTest.scala
new file mode 100644
index 0000000..bc28d67
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/UserDefinedTableFunctionTest.scala
@@ -0,0 +1,385 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.scala.stream.table
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.table._
+import org.apache.flink.api.table.expressions.utils._
+import org.apache.flink.api.table.typeutils.RowTypeInfo
+import org.apache.flink.api.table.utils.TableTestUtil._
+import org.apache.flink.api.table.utils._
+import org.apache.flink.streaming.api.datastream.{DataStream => JDataStream}
+import org.apache.flink.streaming.api.environment.{StreamExecutionEnvironment => JavaExecutionEnv}
+import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment => ScalaExecutionEnv}
+import org.junit.Assert.{assertTrue, fail}
+import org.junit.Test
+import org.mockito.Mockito._
+
+class UserDefinedTableFunctionTest extends TableTestBase {
+
+ @Test
+ def testJavaScalaTableAPIEquality(): Unit = {
+ // mock
+ val ds = mock(classOf[DataStream[Row]])
+ val jDs = mock(classOf[JDataStream[Row]])
+ val typeInfo: TypeInformation[Row] = new RowTypeInfo(Seq(Types.INT, Types.LONG, Types.STRING))
+ when(ds.javaStream).thenReturn(jDs)
+ when(jDs.getType).thenReturn(typeInfo)
+
+ // Scala environment
+ val env = mock(classOf[ScalaExecutionEnv])
+ val tableEnv = TableEnvironment.getTableEnvironment(env)
+ val in1 = ds.toTable(tableEnv).as('a, 'b, 'c)
+
+ // Java environment
+ val javaEnv = mock(classOf[JavaExecutionEnv])
+ val javaTableEnv = TableEnvironment.getTableEnvironment(javaEnv)
+ val in2 = javaTableEnv.fromDataStream(jDs).as("a, b, c")
+
+ // test cross apply
+ val func1 = new TableFunc1
+ javaTableEnv.registerFunction("func1", func1)
+ var scalaTable = in1.crossApply(func1('c) as 's).select('c, 's)
+ var javaTable = in2.crossApply("func1(c).as(s)").select("c, s")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test outer apply
+ scalaTable = in1.outerApply(func1('c) as 's).select('c, 's)
+ javaTable = in2.outerApply("as(func1(c), s)").select("c, s")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test overloading
+ scalaTable = in1.crossApply(func1('c, "$") as 's).select('c, 's)
+ javaTable = in2.crossApply("func1(c, '$') as (s)").select("c, s")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test custom result type
+ val func2 = new TableFunc2
+ javaTableEnv.registerFunction("func2", func2)
+ scalaTable = in1.crossApply(func2('c) as ('name, 'len)).select('c, 'name, 'len)
+ javaTable = in2.crossApply("func2(c).as(name, len)").select("c, name, len")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test hierarchy generic type
+ val hierarchy = new HierarchyTableFunction
+ javaTableEnv.registerFunction("hierarchy", hierarchy)
+ scalaTable = in1.crossApply(hierarchy('c) as ('name, 'adult, 'len))
+ .select('c, 'name, 'len, 'adult)
+ javaTable = in2.crossApply("AS(hierarchy(c), name, adult, len)")
+ .select("c, name, len, adult")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test pojo type
+ val pojo = new PojoTableFunc
+ javaTableEnv.registerFunction("pojo", pojo)
+ scalaTable = in1.crossApply(pojo('c))
+ .select('c, 'name, 'age)
+ javaTable = in2.crossApply("pojo(c)")
+ .select("c, name, age")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test with filter
+ scalaTable = in1.crossApply(func2('c) as ('name, 'len))
+ .select('c, 'name, 'len).filter('len > 2)
+ javaTable = in2.crossApply("func2(c) as (name, len)")
+ .select("c, name, len").filter("len > 2")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // test with scalar function
+ scalaTable = in1.crossApply(func1('c.substring(2)) as 's)
+ .select('a, 'c, 's)
+ javaTable = in2.crossApply("func1(substring(c, 2)) as (s)")
+ .select("a, c, s")
+ verifyTableEquals(scalaTable, javaTable)
+
+ // check that Scala objects are forbidden
+ expectExceptionThrown(
+ tableEnv.registerFunction("func3", ObjectTableFunction), "Scala object")
+ expectExceptionThrown(
+ javaTableEnv.registerFunction("func3", ObjectTableFunction), "Scala object")
+ expectExceptionThrown(
+ in1.crossApply(ObjectTableFunction('a, 1)), "Scala object")
+
+ }
+
+ @Test
+ def testInvalidTableFunction(): Unit = {
+ // mock
+ val util = streamTestUtil()
+ val t = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ val tEnv = TableEnvironment.getTableEnvironment(mock(classOf[JavaExecutionEnv]))
+
+ //=================== check that Scala objects are forbidden ==============
+ // Scala table environment register
+ expectExceptionThrown(util.addFunction("udtf", ObjectTableFunction), "Scala object")
+ // Java table environment register
+ expectExceptionThrown(tEnv.registerFunction("udtf", ObjectTableFunction), "Scala object")
+ // Scala Table API directly call
+ expectExceptionThrown(t.crossApply(ObjectTableFunction('a, 1)), "Scala object")
+
+
+ //============ throw exception when table function is not registered =========
+ // Java Table API call
+ expectExceptionThrown(t.crossApply("nonexist(a)"), "Undefined function: NONEXIST")
+ // SQL API call
+ expectExceptionThrown(
+ util.tEnv.sql("SELECT * FROM MyTable, LATERAL TABLE(nonexist(a))"),
+ "No match found for function signature nonexist(<NUMERIC>)")
+
+
+ //========= throw exception when the called function is a scalar function ====
+ util.addFunction("func0", Func0)
+ // Java Table API call
+ expectExceptionThrown(
+ t.crossApply("func0(a)"),
+ "only accept expressions that define table functions",
+ classOf[TableException])
+ // SQL API call
+ // NOTE: this fails with an AssertionError instead of an exception, possibly a Calcite bug
+ expectExceptionThrown(
+ util.tEnv.sql("SELECT * FROM MyTable, LATERAL TABLE(func0(a))"),
+ null,
+ classOf[AssertionError])
+
+ //========== throw exception when the parameters are not correct ==============
+ // Java Table API call
+ util.addFunction("func2", new TableFunc2)
+ expectExceptionThrown(
+ t.crossApply("func2(c, c)"),
+ "Given parameters of function 'FUNC2' do not match any signature")
+ // SQL API call
+ expectExceptionThrown(
+ util.tEnv.sql("SELECT * FROM MyTable, LATERAL TABLE(func2(c, c))"),
+ "No match found for function signature func2(<CHARACTER>, <CHARACTER>)")
+ }
+
+ @Test
+ def testCrossApply(): Unit = {
+ val util = streamTestUtil()
+ val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ val function = util.addFunction("func1", new TableFunc1)
+
+ val result1 = table.crossApply(function('c) as 's).select('c, 's)
+
+ val expected1 = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", s"$function($$2)"),
+ term("function", function),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) s)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "s")
+ )
+
+ util.verifyTable(result1, expected1)
+
+ // test overloading
+
+ val result2 = table.crossApply(function('c, "$") as 's).select('c, 's)
+
+ val expected2 = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", s"$function($$2, '$$')"),
+ term("function", function),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) s)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "s")
+ )
+
+ util.verifyTable(result2, expected2)
+ }
+
+ @Test
+ def testOuterApply(): Unit = {
+ val util = streamTestUtil()
+ val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ val function = util.addFunction("func1", new TableFunc1)
+
+ val result = table.outerApply(function('c) as 's).select('c, 's)
+
+ val expected = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", s"$function($$2)"),
+ term("function", function),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) s)"),
+ term("joinType", "LEFT")
+ ),
+ term("select", "c", "s")
+ )
+
+ util.verifyTable(result, expected)
+ }
+
+ @Test
+ def testCustomType(): Unit = {
+ val util = streamTestUtil()
+ val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ val function = util.addFunction("func2", new TableFunc2)
+
+ val result = table.crossApply(function('c) as ('name, 'len)).select('c, 'name, 'len)
+
+ val expected = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", s"$function($$2)"),
+ term("function", function),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
+ "VARCHAR(2147483647) name, INTEGER len)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "name", "len")
+ )
+
+ util.verifyTable(result, expected)
+ }
+
+ @Test
+ def testHierarchyType(): Unit = {
+ val util = streamTestUtil()
+ val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ val function = util.addFunction("hierarchy", new HierarchyTableFunction)
+
+ val result = table.crossApply(function('c) as ('name, 'adult, 'len))
+
+ val expected = unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", s"$function($$2)"),
+ term("function", function),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," +
+ " VARCHAR(2147483647) name, BOOLEAN adult, INTEGER len)"),
+ term("joinType", "INNER")
+ )
+
+ util.verifyTable(result, expected)
+ }
+
+ @Test
+ def testPojoType(): Unit = {
+ val util = streamTestUtil()
+ val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ val function = util.addFunction("pojo", new PojoTableFunc)
+
+ val result = table.crossApply(function('c))
+
+ val expected = unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", s"$function($$2)"),
+ term("function", function),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
+ "INTEGER age, VARCHAR(2147483647) name)"),
+ term("joinType", "INNER")
+ )
+
+ util.verifyTable(result, expected)
+ }
+
+ @Test
+ def testFilter(): Unit = {
+ val util = streamTestUtil()
+ val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ val function = util.addFunction("func2", new TableFunc2)
+
+ val result = table
+ .crossApply(function('c) as ('name, 'len))
+ .select('c, 'name, 'len)
+ .filter('len > 2)
+
+ val expected = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", s"$function($$2)"),
+ term("function", function),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
+ "VARCHAR(2147483647) name, INTEGER len)"),
+ term("joinType", "INNER"),
+ term("condition", ">($1, 2)")
+ ),
+ term("select", "c", "name", "len")
+ )
+
+ util.verifyTable(result, expected)
+ }
+
+ @Test
+ def testScalarFunction(): Unit = {
+ val util = streamTestUtil()
+ val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ val function = util.addFunction("func1", new TableFunc1)
+
+ val result = table.crossApply(function('c.substring(2)) as 's)
+
+ val expected = unaryNode(
+ "DataStreamCorrelate",
+ streamTableNode(0),
+ term("invocation", s"$function(SUBSTRING($$2, 2, CHAR_LENGTH($$2)))"),
+ term("function", function),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) s)"),
+ term("joinType", "INNER")
+ )
+
+ util.verifyTable(result, expected)
+ }
+
+ // ----------------------------------------------------------------------------------------------
+
+ private def expectExceptionThrown(
+ function: => Unit,
+ keywords: String,
+ clazz: Class[_ <: Throwable] = classOf[ValidationException])
+ : Unit = {
+ try {
+ function
+ fail(s"Expected a $clazz, but no exception is thrown.")
+ } catch {
+ case e if e.getClass == clazz =>
+ if (keywords != null) {
+ assertTrue(
+ s"The exception message '${e.getMessage}' doesn't contain keyword '$keywords'",
+ e.getMessage.contains(keywords))
+ }
+ case e: Throwable => fail(s"Expected ${clazz.getSimpleName} to be thrown, but got $e.")
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/UserDefinedTableFunctions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/UserDefinedTableFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/UserDefinedTableFunctions.scala
deleted file mode 100644
index 1e6bdb8..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/UserDefinedTableFunctions.scala
+++ /dev/null
@@ -1,116 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.api.table.expressions.utils
-
-import java.lang.Boolean
-import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
-import org.apache.flink.api.java.tuple.Tuple3
-import org.apache.flink.api.table.Row
-import org.apache.flink.api.table.functions.TableFunction
-import org.apache.flink.api.table.typeutils.RowTypeInfo
-
-
-case class SimpleUser(name: String, age: Int)
-
-class TableFunc0 extends TableFunction[SimpleUser] {
- // make sure input element's format is "<string>#<int>"
- def eval(user: String): Unit = {
- if (user.contains("#")) {
- val splits = user.split("#")
- collect(SimpleUser(splits(0), splits(1).toInt))
- }
- }
-}
-
-class TableFunc1 extends TableFunction[String] {
- def eval(str: String): Unit = {
- if (str.contains("#")){
- str.split("#").foreach(collect)
- }
- }
-
- def eval(str: String, prefix: String): Unit = {
- if (str.contains("#")) {
- str.split("#").foreach(s => collect(prefix + s))
- }
- }
-}
-
-
-class TableFunc2 extends TableFunction[Row] {
- def eval(str: String): Unit = {
- if (str.contains("#")) {
- str.split("#").foreach({ s =>
- val row = new Row(2)
- row.setField(0, s)
- row.setField(1, s.length)
- collect(row)
- })
- }
- }
-
- override def getResultType: TypeInformation[Row] = {
- new RowTypeInfo(Seq(BasicTypeInfo.STRING_TYPE_INFO,
- BasicTypeInfo.INT_TYPE_INFO))
- }
-}
-
-class HierarchyTableFunction extends SplittableTableFunction[Boolean, Integer] {
- def eval(user: String) {
- if (user.contains("#")) {
- val splits = user.split("#")
- val age = splits(1).toInt
- collect(new Tuple3[String, Boolean, Integer](splits(0), age >= 20, age))
- }
- }
-}
-
-abstract class SplittableTableFunction[A, B] extends TableFunction[Tuple3[String, A, B]] {}
-
-class PojoTableFunc extends TableFunction[PojoUser] {
- def eval(user: String) {
- if (user.contains("#")) {
- val splits = user.split("#")
- collect(new PojoUser(splits(0), splits(1).toInt))
- }
- }
-}
-
-class PojoUser() {
- var name: String = _
- var age: Int = 0
-
- def this(name: String, age: Int) {
- this()
- this.name = name
- this.age = age
- }
-}
-
-// ----------------------------------------------------------------------------------------------
-// Invalid Table Functions
-// ----------------------------------------------------------------------------------------------
-
-
-// this is used to check whether scala object is forbidden
-object ObjectTableFunction extends TableFunction[Integer] {
- def eval(a: Int, b: Int): Unit = {
- collect(a)
- collect(b)
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/dataset/DataSetCorrelateITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/dataset/DataSetCorrelateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/dataset/DataSetCorrelateITCase.scala
new file mode 100644
index 0000000..cc551f9
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/dataset/DataSetCorrelateITCase.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.runtime.dataset
+
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.batch.utils.TableProgramsTestBase
+import org.apache.flink.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.table.expressions.utils._
+import org.apache.flink.api.table.utils._
+import org.apache.flink.api.table.{Row, TableEnvironment}
+import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
+import org.apache.flink.test.util.TestBaseUtils
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+@RunWith(classOf[Parameterized])
+class DataSetCorrelateITCase(
+ mode: TestExecutionMode,
+ configMode: TableConfigMode)
+ extends TableProgramsTestBase(mode, configMode) {
+
+ @Test
+ def testCrossApply(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv = TableEnvironment.getTableEnvironment(env, config)
+ val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
+
+ val func1 = new TableFunc1
+ val result = in.crossApply(func1('c) as 's).select('c, 's).toDataSet[Row]
+ val results = result.collect()
+ val expected = "Jack#22,Jack\n" + "Jack#22,22\n" + "John#19,John\n" + "John#19,19\n" +
+ "Anna#44,Anna\n" + "Anna#44,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+
+ // with overloading
+ val result2 = in.crossApply(func1('c, "$") as 's).select('c, 's).toDataSet[Row]
+ val results2 = result2.collect()
+ val expected2 = "Jack#22,$Jack\n" + "Jack#22,$22\n" + "John#19,$John\n" +
+ "John#19,$19\n" + "Anna#44,$Anna\n" + "Anna#44,$44\n"
+ TestBaseUtils.compareResultAsText(results2.asJava, expected2)
+ }
+
+ @Test
+ def testOuterApply(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv = TableEnvironment.getTableEnvironment(env, config)
+ val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
+
+ val func2 = new TableFunc2
+ val result = in.outerApply(func2('c) as ('s, 'l)).select('c, 's, 'l).toDataSet[Row]
+ val results = result.collect()
+ val expected = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" +
+ "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n" + "nosharp,null,null"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testWithFilter(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv = TableEnvironment.getTableEnvironment(env, config)
+ val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
+ val func0 = new TableFunc0
+
+ val result = in
+ .crossApply(func0('c) as ('name, 'age))
+ .select('c, 'name, 'age)
+ .filter('age > 20)
+ .toDataSet[Row]
+
+ val results = result.collect()
+ val expected = "Jack#22,Jack,22\n" + "Anna#44,Anna,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testCustomReturnType(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv = TableEnvironment.getTableEnvironment(env, config)
+ val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
+ val func2 = new TableFunc2
+
+ val result = in
+ .crossApply(func2('c) as ('name, 'len))
+ .select('c, 'name, 'len)
+ .toDataSet[Row]
+
+ val results = result.collect()
+ val expected = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" +
+ "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testHierarchyType(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv = TableEnvironment.getTableEnvironment(env, config)
+ val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
+
+ val hierarchy = new HierarchyTableFunction
+ val result = in
+ .crossApply(hierarchy('c) as ('name, 'adult, 'len))
+ .select('c, 'name, 'adult, 'len)
+ .toDataSet[Row]
+
+ val results = result.collect()
+ val expected = "Jack#22,Jack,true,22\n" + "John#19,John,false,19\n" +
+ "Anna#44,Anna,true,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testPojoType(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv = TableEnvironment.getTableEnvironment(env, config)
+ val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
+
+ val pojo = new PojoTableFunc()
+ val result = in
+ .crossApply(pojo('c))
+ .select('c, 'name, 'age)
+ .toDataSet[Row]
+
+ val results = result.collect()
+ val expected = "Jack#22,Jack,22\n" + "John#19,John,19\n" + "Anna#44,Anna,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ @Test
+ def testUDTFWithScalarFunction(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tableEnv = TableEnvironment.getTableEnvironment(env, config)
+ val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
+ val func1 = new TableFunc1
+
+ val result = in
+ .crossApply(func1('c.substring(2)) as 's)
+ .select('c, 's)
+ .toDataSet[Row]
+
+ val results = result.collect()
+ val expected = "Jack#22,ack\n" + "Jack#22,22\n" + "John#19,ohn\n" + "John#19,19\n" +
+ "Anna#44,nna\n" + "Anna#44,44\n"
+ TestBaseUtils.compareResultAsText(results.asJava, expected)
+ }
+
+ private def testData(
+ env: ExecutionEnvironment)
+ : DataSet[(Int, Long, String)] = {
+
+ val data = new mutable.MutableList[(Int, Long, String)]
+ data.+=((1, 1L, "Jack#22"))
+ data.+=((2, 2L, "John#19"))
+ data.+=((3, 2L, "Anna#44"))
+ data.+=((4, 3L, "nosharp"))
+ env.fromCollection(data)
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/datastream/DataStreamCorrelateITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/datastream/DataStreamCorrelateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/datastream/DataStreamCorrelateITCase.scala
new file mode 100644
index 0000000..c2c523a
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/datastream/DataStreamCorrelateITCase.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.runtime.datastream
+
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.stream.utils.StreamITCase
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.table.utils.TableFunc0
+import org.apache.flink.api.table.{Row, TableEnvironment}
+import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
+import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
+import org.junit.Assert._
+import org.junit.Test
+
+import scala.collection.mutable
+
+class DataStreamCorrelateITCase extends StreamingMultipleProgramsTestBase {
+
+ @Test
+ def testCrossApply(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.clear
+
+ val t = testData(env).toTable(tEnv).as('a, 'b, 'c)
+ val func0 = new TableFunc0
+
+ val result = t
+ .crossApply(func0('c) as ('d, 'e))
+ .select('c, 'd, 'e)
+ .toDataStream[Row]
+
+ result.addSink(new StreamITCase.StringSink)
+ env.execute()
+
+ val expected = mutable.MutableList("Jack#22,Jack,22", "John#19,John,19", "Anna#44,Anna,44")
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
+
+ @Test
+ def testOuterApply(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.clear
+
+ val t = testData(env).toTable(tEnv).as('a, 'b, 'c)
+ val func0 = new TableFunc0
+
+ val result = t
+ .outerApply(func0('c) as ('d, 'e))
+ .select('c, 'd, 'e)
+ .toDataStream[Row]
+
+ result.addSink(new StreamITCase.StringSink)
+ env.execute()
+
+ val expected = mutable.MutableList(
+ "nosharp,null,null", "Jack#22,Jack,22",
+ "John#19,John,19", "Anna#44,Anna,44")
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
+
+ private def testData(
+ env: StreamExecutionEnvironment)
+ : DataStream[(Int, Long, String)] = {
+
+ val data = new mutable.MutableList[(Int, Long, String)]
+ data.+=((1, 1L, "Jack#22"))
+ data.+=((2, 2L, "John#19"))
+ data.+=((3, 2L, "Anna#44"))
+ data.+=((4, 3L, "nosharp"))
+ env.fromCollection(data)
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
index 73f50f5..4eaba90 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
@@ -45,9 +45,10 @@ class TableTestBase {
}
def verifyTableEquals(expected: Table, actual: Table): Unit = {
- assertEquals("Logical Plan do not match",
- RelOptUtil.toString(expected.getRelNode),
- RelOptUtil.toString(actual.getRelNode))
+ assertEquals(
+ "Logical plans do not match",
+ RelOptUtil.toString(expected.getRelNode),
+ RelOptUtil.toString(actual.getRelNode))
}
}
@@ -61,7 +62,7 @@ abstract class TableTestUtil {
}
def addTable[T: TypeInformation](name: String, fields: Expression*): Table
- def addFunction[T: TypeInformation](name: String, function: TableFunction[T]): Unit
+ def addFunction[T: TypeInformation](name: String, function: TableFunction[T]): TableFunction[T]
def addFunction(name: String, function: ScalarFunction): Unit
def verifySql(query: String, expected: String): Unit
@@ -132,8 +133,9 @@ case class BatchTableTestUtil() extends TableTestUtil {
def addFunction[T: TypeInformation](
name: String,
function: TableFunction[T])
- : Unit = {
+ : TableFunction[T] = {
tEnv.registerFunction(name, function)
+ function
}
def addFunction(name: String, function: ScalarFunction): Unit = {
@@ -188,8 +190,9 @@ case class StreamTableTestUtil() extends TableTestUtil {
def addFunction[T: TypeInformation](
name: String,
function: TableFunction[T])
- : Unit = {
+ : TableFunction[T] = {
tEnv.registerFunction(name, function)
+ function
}
def addFunction(name: String, function: ScalarFunction): Unit = {
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/UserDefinedTableFunctions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/UserDefinedTableFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/UserDefinedTableFunctions.scala
new file mode 100644
index 0000000..3da3857
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/UserDefinedTableFunctions.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.utils
+
+import java.lang.Boolean
+
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.tuple.Tuple3
+import org.apache.flink.api.table.Row
+import org.apache.flink.api.table.functions.TableFunction
+import org.apache.flink.api.table.typeutils.RowTypeInfo
+
+
+case class SimpleUser(name: String, age: Int)
+
+class TableFunc0 extends TableFunction[SimpleUser] {
+ // make sure input element's format is "<string>#<int>"
+ def eval(user: String): Unit = {
+ if (user.contains("#")) {
+ val splits = user.split("#")
+ collect(SimpleUser(splits(0), splits(1).toInt))
+ }
+ }
+}
+
+class TableFunc1 extends TableFunction[String] {
+ def eval(str: String): Unit = {
+ if (str.contains("#")){
+ str.split("#").foreach(collect)
+ }
+ }
+
+ def eval(str: String, prefix: String): Unit = {
+ if (str.contains("#")) {
+ str.split("#").foreach(s => collect(prefix + s))
+ }
+ }
+}
+
+
+class TableFunc2 extends TableFunction[Row] {
+ def eval(str: String): Unit = {
+ if (str.contains("#")) {
+ str.split("#").foreach({ s =>
+ val row = new Row(2)
+ row.setField(0, s)
+ row.setField(1, s.length)
+ collect(row)
+ })
+ }
+ }
+
+ override def getResultType: TypeInformation[Row] = {
+ new RowTypeInfo(Seq(BasicTypeInfo.STRING_TYPE_INFO,
+ BasicTypeInfo.INT_TYPE_INFO))
+ }
+}
+
+class HierarchyTableFunction extends SplittableTableFunction[Boolean, Integer] {
+ def eval(user: String) {
+ if (user.contains("#")) {
+ val splits = user.split("#")
+ val age = splits(1).toInt
+ collect(new Tuple3[String, Boolean, Integer](splits(0), age >= 20, age))
+ }
+ }
+}
+
+abstract class SplittableTableFunction[A, B] extends TableFunction[Tuple3[String, A, B]] {}
+
+class PojoTableFunc extends TableFunction[PojoUser] {
+ def eval(user: String) {
+ if (user.contains("#")) {
+ val splits = user.split("#")
+ collect(new PojoUser(splits(0), splits(1).toInt))
+ }
+ }
+}
+
+class PojoUser() {
+ var name: String = _
+ var age: Int = 0
+
+ def this(name: String, age: Int) {
+ this()
+ this.name = name
+ this.age = age
+ }
+}
+
+// ----------------------------------------------------------------------------------------------
+// Invalid Table Functions
+// ----------------------------------------------------------------------------------------------
+
+
+// used to check that registering a Scala object as a table function is forbidden
+object ObjectTableFunction extends TableFunction[Integer] {
+ def eval(a: Int, b: Int): Unit = {
+ collect(a)
+ collect(b)
+ }
+}
[5/5] flink git commit: [FLINK-4469] [table] Minor improvements
Posted by tw...@apache.org.
[FLINK-4469] [table] Minor improvements
- Fixed typos
- Removed the implicit conversion via TableFunctionCallBuilder
- Fixed bugs around expression parser aliases and static eval methods
- Refactored tests
This closes #2653.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/684defbf
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/684defbf
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/684defbf
Branch: refs/heads/master
Commit: 684defbf33168e34657bc1a25607adb53be248c5
Parents: e139f59
Author: twalthr <tw...@apache.org>
Authored: Tue Dec 6 17:46:54 2016 +0100
Committer: twalthr <tw...@apache.org>
Committed: Wed Dec 7 16:55:37 2016 +0100
----------------------------------------------------------------------
docs/dev/table_api.md | 22 +
.../api/java/table/BatchTableEnvironment.scala | 3 +-
.../api/java/table/StreamTableEnvironment.scala | 3 +-
.../api/scala/table/BatchTableEnvironment.scala | 6 +-
.../scala/table/TableFunctionCallBuilder.scala | 39 --
.../flink/api/scala/table/expressionDsl.scala | 10 +-
.../flink/api/table/TableEnvironment.scala | 10 +-
.../flink/api/table/codegen/CodeGenerator.scala | 13 +-
.../codegen/calls/TableFunctionCallGen.scala | 3 +-
.../table/expressions/ExpressionParser.scala | 12 +-
.../flink/api/table/expressions/call.scala | 9 +-
.../api/table/functions/ScalarFunction.scala | 7 +-
.../api/table/functions/TableFunction.scala | 31 +-
.../functions/utils/ScalarSqlFunction.scala | 2 +-
.../functions/utils/TableSqlFunction.scala | 15 +-
.../utils/UserDefinedFunctionUtils.scala | 18 +-
.../api/table/plan/ProjectionTranslator.scala | 4 +-
.../api/table/plan/logical/operators.scala | 23 +-
.../api/table/plan/nodes/FlinkCorrelate.scala | 7 +-
.../plan/nodes/dataset/DataSetCorrelate.scala | 2 +-
.../nodes/datastream/DataStreamCorrelate.scala | 3 +-
.../rules/dataSet/DataSetCorrelateRule.scala | 7 +-
.../datastream/DataStreamCorrelateRule.scala | 12 +-
.../org/apache/flink/api/table/table.scala | 56 ++-
.../api/table/validate/FunctionCatalog.scala | 10 +-
.../src/test/resources/log4j-test.properties | 2 +-
.../batch/UserDefinedTableFunctionITCase.scala | 212 ----------
.../batch/UserDefinedTableFunctionTest.scala | 320 ---------------
.../sql/UserDefinedTableFunctionTest.scala | 238 +++++++++++
.../table/UserDefinedTableFunctionTest.scala | 179 +++++++++
.../stream/UserDefinedTableFunctionITCase.scala | 181 ---------
.../stream/UserDefinedTableFunctionTest.scala | 402 -------------------
.../sql/UserDefinedTableFunctionTest.scala | 237 +++++++++++
.../table/UserDefinedTableFunctionTest.scala | 385 ++++++++++++++++++
.../utils/UserDefinedTableFunctions.scala | 116 ------
.../dataset/DataSetCorrelateITCase.scala | 177 ++++++++
.../datastream/DataStreamCorrelateITCase.scala | 90 +++++
.../flink/api/table/utils/TableTestBase.scala | 15 +-
.../table/utils/UserDefinedTableFunctions.scala | 117 ++++++
39 files changed, 1611 insertions(+), 1387 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/docs/dev/table_api.md
----------------------------------------------------------------------
diff --git a/docs/dev/table_api.md b/docs/dev/table_api.md
index 848f9e4..6cf0dee 100644
--- a/docs/dev/table_api.md
+++ b/docs/dev/table_api.md
@@ -1494,6 +1494,17 @@ Both the Table API and SQL come with a set of built-in functions for data transf
<tr>
<td>
{% highlight java %}
+ANY.as(name [, name ]* )
+{% endhighlight %}
+ </td>
+ <td>
+ <p>Specifies a name for an expression, i.e. a field. Additional names can be specified if the expression expands to multiple fields.</p>
+ </td>
+ </tr>
+
+ <tr>
+ <td>
+ {% highlight java %}
ANY.isNull
{% endhighlight %}
</td>
@@ -2045,6 +2056,17 @@ COMPOSITE.get(INT)
<tr>
<td>
{% highlight scala %}
+ANY.as(name [, name ]* )
+{% endhighlight %}
+ </td>
+ <td>
+ <p>Specifies a name for an expression, i.e. a field. Additional names can be specified if the expression expands to multiple fields.</p>
+ </td>
+ </tr>
+
+ <tr>
+ <td>
+ {% highlight scala %}
ANY.isNull
{% endhighlight %}
</td>
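The multi-name form of "as" matters when an expression expands to several fields,
as a table function call does. A minimal Scala sketch, assuming the TableFunc2
helper that this commit adds to the test utilities (it emits a string and its
length per input row):

// one alias per field that func2('c) expands to
val func2 = new TableFunc2
val result = table
  .crossApply(func2('c) as ('name, 'len))
  .select('c, 'name, 'len)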
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala
index b353377..3517338 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala
@@ -168,7 +168,8 @@ class BatchTableEnvironment(
* Registered functions can be referenced in Table API and SQL queries.
*
* @param name The name under which the function is registered.
- * @param tf The TableFunction to register
+ * @param tf The TableFunction to register.
+ * @tparam T The type of the output row.
*/
def registerFunction[T](name: String, tf: TableFunction[T]): Unit = {
implicit val typeInfo: TypeInformation[T] = TypeExtractor
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala
index 367cb82..83293e3 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala
@@ -170,7 +170,8 @@ class StreamTableEnvironment(
* Registered functions can be referenced in Table API and SQL queries.
*
* @param name The name under which the function is registered.
- * @param tf The TableFunction to register
+ * @param tf The TableFunction to register.
+ * @tparam T The type of the output row.
*/
def registerFunction[T](name: String, tf: TableFunction[T]): Unit = {
implicit val typeInfo: TypeInformation[T] = TypeExtractor
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/BatchTableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/BatchTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/BatchTableEnvironment.scala
index 36885d2..f4bfe31 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/BatchTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/BatchTableEnvironment.scala
@@ -142,13 +142,13 @@ class BatchTableEnvironment(
/**
* Registers a [[TableFunction]] under a unique name in the TableEnvironment's catalog.
- * Registered functions can be referenced in SQL queries.
+ * Registered functions can be referenced in Table API and SQL queries.
*
* @param name The name under which the function is registered.
- * @param tf The TableFunction to register
+ * @param tf The TableFunction to register.
+ * @tparam T The type of the output row.
*/
def registerFunction[T: TypeInformation](name: String, tf: TableFunction[T]): Unit = {
registerTableFunctionInternal(name, tf)
}
-
}
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableFunctionCallBuilder.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableFunctionCallBuilder.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableFunctionCallBuilder.scala
deleted file mode 100644
index 2261b70..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableFunctionCallBuilder.scala
+++ /dev/null
@@ -1,39 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.api.scala.table
-
-import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.api.table.expressions.{Expression, TableFunctionCall}
-import org.apache.flink.api.table.functions.TableFunction
-
-case class TableFunctionCallBuilder[T: TypeInformation](udtf: TableFunction[T]) {
- /**
- * Creates a call to a [[TableFunction]] in Scala Table API.
- *
- * @param params actual parameters of function
- * @return [[TableFunctionCall]]
- */
- def apply(params: Expression*): Expression = {
- val resultType = if (udtf.getResultType == null) {
- implicitly[TypeInformation[T]]
- } else {
- udtf.getResultType
- }
- TableFunctionCall(udtf.getClass.getSimpleName, udtf, params, resultType)
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
index cc4c68d..175ce2e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
@@ -24,7 +24,6 @@ import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation}
import org.apache.flink.api.table.expressions.ExpressionUtils.{toMilliInterval, toMonthInterval, toRowInterval}
import org.apache.flink.api.table.expressions.TimeIntervalUnit.TimeIntervalUnit
import org.apache.flink.api.table.expressions._
-import org.apache.flink.api.table.functions.TableFunction
import scala.language.implicitConversions
@@ -98,6 +97,13 @@ trait ImplicitExpressionOperations {
def cast(toType: TypeInformation[_]) = Cast(expr, toType)
+ /**
+ * Specifies a name for an expression, i.e. a field.
+ *
+ * @param name name for one field
+ * @param extraNames additional names if the expression expands to multiple fields
+ * @return field with an alias
+ */
def as(name: Symbol, extraNames: Symbol*) = Alias(expr, name.name, extraNames.map(_.name))
def asc = Asc(expr)
@@ -540,8 +546,6 @@ trait ImplicitExpressionConversions {
implicit def sqlDate2Literal(sqlDate: Date): Expression = Literal(sqlDate)
implicit def sqlTime2Literal(sqlTime: Time): Expression = Literal(sqlTime)
implicit def sqlTimestamp2Literal(sqlTimestamp: Timestamp): Expression = Literal(sqlTimestamp)
- implicit def UDTF2TableFunctionCall[T: TypeInformation](udtf: TableFunction[T]):
- TableFunctionCallBuilder[T] = TableFunctionCallBuilder(udtf)
}
// ------------------------------------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala
index 8cabadb..b6d0e31 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala
@@ -24,8 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger
import org.apache.calcite.config.Lex
import org.apache.calcite.plan.RelOptPlanner
import org.apache.calcite.rel.`type`.RelDataType
-import org.apache.calcite.rex.RexExecutorImpl
-import org.apache.calcite.schema.{SchemaPlus, Schemas}
+import org.apache.calcite.schema.SchemaPlus
import org.apache.calcite.schema.impl.AbstractTable
import org.apache.calcite.sql.SqlOperatorTable
import org.apache.calcite.sql.parser.SqlParser
@@ -158,7 +157,7 @@ abstract class TableEnvironment(val config: TableConfig) {
* user-defined functions under this name.
*/
def registerFunction(name: String, function: ScalarFunction): Unit = {
- // check could be instantiated
+ // check if the class can be instantiated
checkForInstantiation(function.getClass)
// register in Table API
@@ -174,9 +173,9 @@ abstract class TableEnvironment(val config: TableConfig) {
*/
private[flink] def registerTableFunctionInternal[T: TypeInformation](
name: String, function: TableFunction[T]): Unit = {
- // check not Scala object
+ // check that the class is not a Scala object
checkNotSingleton(function.getClass)
- // check could be instantiated
+ // check if the class can be instantiated
checkForInstantiation(function.getClass)
val typeInfo: TypeInformation[_] = if (function.getResultType != null) {
@@ -187,6 +186,7 @@ abstract class TableEnvironment(val config: TableConfig) {
// register in Table API
functionCatalog.registerFunction(name, function.getClass)
+
// register in SQL API
val sqlFunctions = createTableSqlFunctions(name, function, typeInfo, typeFactory)
functionCatalog.registerSqlFunctions(sqlFunctions)
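Registering once makes the function available to both APIs. A rough usage sketch,
assuming the TableFunc0 helper from this commit's test utilities and a table
"MyTable" registered as in the ITCases:

val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)
tEnv.registerFunction("split", new TableFunc0)  // registered for Table API and SQL
val result = tEnv.sql(
  "SELECT MyTable.c, t.name, t.age FROM MyTable, LATERAL TABLE(split(c)) AS t(name, age)")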
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
index 9e4f569..f7d6863 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
@@ -32,6 +32,7 @@ import org.apache.flink.api.common.typeutils.CompositeType
import org.apache.flink.api.java.typeutils.{GenericTypeInfo, PojoTypeInfo, TupleTypeInfo}
import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
import org.apache.flink.api.table.codegen.CodeGenUtils._
+import org.apache.flink.api.table.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
import org.apache.flink.api.table.codegen.Indenter.toISC
import org.apache.flink.api.table.codegen.calls.FunctionGenerator
import org.apache.flink.api.table.codegen.calls.ScalarOperators._
@@ -358,10 +359,11 @@ class CodeGenerator(
val input2AccessExprs = input2 match {
case Some(ti) => for (i <- 0 until ti.getArity)
- // use generateFieldAccess instead of generateInputAccess to avoid the generated table
- // function's field access code is put on the top of function body rather than the while loop
+ // use generateFieldAccess instead of generateInputAccess so that the generated table
+ // function's field access code ends up inside the while loop rather than at the top
+ // of the function body
yield generateFieldAccess(ti, input2Term, i, input2PojoFieldMapping)
- case None => throw new CodeGenException("type information of input2 must not be null")
+ case None => throw new CodeGenException("Type information of input2 must not be null.")
}
(input1AccessExprs, input2AccessExprs)
}
@@ -781,7 +783,7 @@ class CodeGenerator(
}
override def visitCorrelVariable(correlVariable: RexCorrelVariable): GeneratedExpression = {
- GeneratedExpression(input1Term, GeneratedExpression.NEVER_NULL, "", input1)
+ GeneratedExpression(input1Term, NEVER_NULL, NO_CODE, input1)
}
override def visitLocalRef(localRef: RexLocalRef): GeneratedExpression =
@@ -1019,8 +1021,7 @@ class CodeGenerator(
case None =>
val expr = if (nullableInput) {
generateNullableInputFieldAccess(inputType, inputTerm, index, pojoFieldMapping)
- }
- else {
+ } else {
generateFieldAccess(inputType, inputTerm, index, pojoFieldMapping)
}
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala
index 27cb43f..37e70e4 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala
@@ -20,6 +20,7 @@ package org.apache.flink.api.table.codegen.calls
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.table.codegen.CodeGenUtils._
+import org.apache.flink.api.table.codegen.GeneratedExpression.NEVER_NULL
import org.apache.flink.api.table.codegen.{CodeGenException, CodeGenerator, GeneratedExpression}
import org.apache.flink.api.table.functions.TableFunction
import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils._
@@ -75,7 +76,7 @@ class TableFunctionCallGen(
// has no result
GeneratedExpression(
functionReference,
- GeneratedExpression.NEVER_NULL,
+ NEVER_NULL,
functionCallCode,
returnType)
}
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
index 6cd63ff..a926717 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
@@ -204,8 +204,8 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
}
lazy val suffixAs: PackratParser[Expression] =
- composite ~ "." ~ AS ~ "(" ~ fieldReference ~ ")" ^^ {
- case e ~ _ ~ _ ~ _ ~ target ~ _ => Alias(e, target.name)
+ composite ~ "." ~ AS ~ "(" ~ rep1sep(fieldReference, ",") ~ ")" ^^ {
+ case e ~ _ ~ _ ~ _ ~ target ~ _ => Alias(e, target.head.name, target.tail.map(_.name))
}
lazy val suffixTrim = composite ~ "." ~ TRIM ~ "(" ~ trimMode ~ "," ~ expression ~ ")" ^^ {
@@ -325,8 +325,8 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
}
lazy val prefixAs: PackratParser[Expression] =
- AS ~ "(" ~ expression ~ "," ~ fieldReference ~ ")" ^^ {
- case _ ~ _ ~ e ~ _ ~ target ~ _ => Alias(e, target.name)
+ AS ~ "(" ~ expression ~ "," ~ rep1sep(fieldReference, ",") ~ ")" ^^ {
+ case _ ~ _ ~ e ~ _ ~ target ~ _ => Alias(e, target.head.name, target.tail.map(_.name))
}
lazy val prefixIf: PackratParser[Expression] =
@@ -447,8 +447,8 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
lazy val alias: PackratParser[Expression] = logic ~ AS ~ fieldReference ^^ {
case e ~ _ ~ name => Alias(e, name.name)
- } | logic ~ AS ~ "(" ~ rep1sep(fieldReference, ",") ~ ")" ^^ {
- case e ~ _ ~ _ ~ names ~ _ => Alias(e, names.head.name, names.drop(1).map(_.name))
+ } | logic ~ AS ~ "(" ~ rep1sep(fieldReference, ",") ~ ")" ^^ {
+ case e ~ _ ~ _ ~ names ~ _ => Alias(e, names.head.name, names.tail.map(_.name))
} | logic
lazy val expression: PackratParser[Expression] = alias |
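With rep1sep in the suffix and prefix rules, all three string-expression alias
forms used by the Java Table API tests should parse to the same multi-name Alias.
A hedged sketch, assuming ExpressionParser.parseExpression stays the entry point:

val e1 = ExpressionParser.parseExpression("func2(c).as(name, len)")   // suffix form
val e2 = ExpressionParser.parseExpression("as(func2(c), name, len)")  // prefix form
val e3 = ExpressionParser.parseExpression("func2(c) as (name, len)")  // infix form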
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala
index 3e8d8b1..3bb9dac 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala
@@ -89,10 +89,8 @@ case class ScalarFunctionCall(
ValidationSuccess
}
}
-
}
-
/**
*
* Expression for calling a user-defined table function with actual parameters.
@@ -114,10 +112,10 @@ case class TableFunctionCall(
override private[flink] def children: Seq[Expression] = parameters
/**
- * Assigns an alias for this table function returned fields that the following `select()` clause
+ * Assigns an alias for this table function's returned fields that the following operator
* can refer to.
*
- * @param aliasList alias for this table function returned fields
+ * @param aliasList alias for this table function's returned fields
* @return this table function call
*/
private[flink] def as(aliasList: Option[Seq[String]]): TableFunctionCall = {
@@ -155,4 +153,7 @@ case class TableFunctionCall(
fieldNames,
child)
}
+
+ override def toString =
+ s"${tableFunction.getClass.getCanonicalName}(${parameters.mkString(", ")})"
}
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala
index 86d9d66..2e16096 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala
@@ -18,15 +18,11 @@
package org.apache.flink.api.table.functions
-import java.lang.reflect.{Method, Modifier}
-
-import org.apache.calcite.sql.SqlFunction
import org.apache.flink.api.common.functions.InvalidTypesException
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.TypeExtractor
+import org.apache.flink.api.table.ValidationException
import org.apache.flink.api.table.expressions.{Expression, ScalarFunctionCall}
-import org.apache.flink.api.table.functions.utils.ScalarSqlFunction
-import org.apache.flink.api.table.{FlinkTypeFactory, ValidationException}
/**
* Base class for a user-defined scalar function. A user-defined scalar functions maps zero, one,
@@ -60,6 +56,7 @@ abstract class ScalarFunction extends UserDefinedFunction {
ScalarFunctionCall(this, params)
}
+ override def toString: String = getClass.getCanonicalName
// ----------------------------------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/TableFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/TableFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/TableFunction.scala
index 98a2921..3a56efb 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/TableFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/TableFunction.scala
@@ -20,18 +20,16 @@ package org.apache.flink.api.table.functions
import java.util
-import org.apache.flink.api.common.functions.InvalidTypesException
import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.api.java.typeutils.TypeExtractor
-import org.apache.flink.api.table.ValidationException
+import org.apache.flink.api.table.expressions.{Expression, TableFunctionCall}
/**
 * Base class for a user-defined table function (UDTF). A user-defined table function works on
* zero, one, or multiple scalar values as input and returns multiple rows as output.
*
* The behavior of a [[TableFunction]] can be defined by implementing a custom evaluation
- * method. An evaluation method must be declared publicly and named "eval". Evaluation methods
- * can also be overloaded by implementing multiple methods named "eval".
+ * method. An evaluation method must be declared publicly, not static, and named "eval".
+ * Evaluation methods can also be overloaded by implementing multiple methods named "eval".
*
* User-defined functions must have a default constructor and must be instantiable during runtime.
*
@@ -51,14 +49,14 @@ import org.apache.flink.api.table.ValidationException
*
* public class Split extends TableFunction<String> {
*
- * // implement an "eval" method with several parameters you want
+ * // implement an "eval" method with as many parameters as you want
* public void eval(String str) {
* for (String s : str.split(" ")) {
* collect(s); // use collect(...) to emit an output row
* }
* }
*
- * // can overloading eval method here ...
+ * // you can overload the eval method here ...
* }
*
* val tEnv: TableEnvironment = ...
@@ -82,6 +80,25 @@ import org.apache.flink.api.table.ValidationException
*/
abstract class TableFunction[T] extends UserDefinedFunction {
+ /**
+ * Creates a call to a [[TableFunction]] in Scala Table API.
+ *
+ * @param params actual parameters of function
+ * @return [[Expression]] in form of a [[TableFunctionCall]]
+ */
+ final def apply(params: Expression*)(implicit typeInfo: TypeInformation[T]): Expression = {
+ val resultType = if (getResultType == null) {
+ typeInfo
+ } else {
+ getResultType
+ }
+ TableFunctionCall(getClass.getSimpleName, this, params, resultType)
+ }
+
+ override def toString: String = getClass.getCanonicalName
+
+ // ----------------------------------------------------------------------------------------------
+
private val rows: util.ArrayList[T] = new util.ArrayList[T]()
/**
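The new apply method makes a UDTF instance directly callable in the Scala Table API; the implicit TypeInformation supplies the result type whenever getResultType is not overridden (i.e. returns null). A sketch under that assumption:

    import org.apache.flink.api.table.functions.TableFunction

    class Split extends TableFunction[String] {
      def eval(str: String): Unit = str.split("#").foreach(collect)
    }

    val split = new Split
    // apply(...) builds a TableFunctionCall named after the class ("Split"),
    // typed by the implicit TypeInformation[String]
    // table.crossApply(split('c) as 's).select('c, 's)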
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala
index 0a987aa..7953b25 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala
@@ -125,6 +125,7 @@ object ScalarSqlFunction {
: SqlOperandTypeChecker = {
val signatures = getSignatures(scalarFunction)
+
/**
* Operand type checker based on [[ScalarFunction]] given information.
*/
@@ -178,5 +179,4 @@ object ScalarSqlFunction {
}
}
}
-
}
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/TableSqlFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/TableSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/TableSqlFunction.scala
index 6eadfbc..738238d 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/TableSqlFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/TableSqlFunction.scala
@@ -33,7 +33,6 @@ import org.apache.flink.api.table.plan.schema.FlinkTableFunctionImpl
import scala.collection.JavaConverters._
import java.util
-
/**
* Calcite wrapper for user-defined table functions.
*/
@@ -55,31 +54,33 @@ class TableSqlFunction(
functionImpl) {
/**
- * Get the user-defined table function
+ * Get the user-defined table function.
*/
def getTableFunction = udtf
/**
- * Get the returned table type information of the table function
+ * Get the type information of the table returned by the table function.
*/
def getRowTypeInfo = rowTypeInfo
/**
* Get additional mapping information if the returned table type is a POJO
- * (POJO types have no deterministic field order)
+ * (POJO types have no deterministic field order).
*/
def getPojoFieldMapping = functionImpl.fieldIndexes
}
object TableSqlFunction {
+
/**
- * Util function to create a [[TableSqlFunction]]
+ * Util function to create a [[TableSqlFunction]].
+ *
* @param name function name (used by SQL parser)
- * @param udtf user defined table function to be called
+ * @param udtf user-defined table function to be called
* @param rowTypeInfo the row type information generated by the table function
+ * @param typeFactory type factory for converting between Flink's and Calcite's types
- * @param functionImpl calcite table function schema
+ * @param functionImpl Calcite table function schema
* @return [[TableSqlFunction]]
*/
def apply(
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala
index 932baeb..4899691 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -141,13 +141,17 @@ object UserDefinedFunctionUtils {
.getDeclaredMethods
.filter { m =>
val modifiers = m.getModifiers
- m.getName == "eval" && Modifier.isPublic(modifiers) && !Modifier.isAbstract(modifiers)
+ m.getName == "eval" &&
+ Modifier.isPublic(modifiers) &&
+ !Modifier.isAbstract(modifiers) &&
+ !(function.isInstanceOf[TableFunction[_]] && Modifier.isStatic(modifiers))
}
if (methods.isEmpty) {
throw new ValidationException(
s"Function class '${function.getClass.getCanonicalName}' does not implement at least " +
- s"one method named 'eval' which is public and not abstract.")
+ s"one method named 'eval' which is public, not abstract and " +
+ s"(in case of table functions) not static.")
} else {
methods
}
@@ -158,7 +162,7 @@ object UserDefinedFunctionUtils {
}
// ----------------------------------------------------------------------------------------------
- // Utilities for sql functions
+ // Utilities for SQL functions
// ----------------------------------------------------------------------------------------------
/**
@@ -255,7 +259,7 @@ object UserDefinedFunctionUtils {
* Field names are automatically extracted for
* [[org.apache.flink.api.common.typeutils.CompositeType]].
*
- * @param inputType The TypeInformation extract the field names and positions from.
+ * @param inputType The TypeInformation to extract the field names and positions from.
* @return A tuple of two arrays holding the field names and corresponding field positions.
*/
def getFieldInfo(inputType: TypeInformation[_])
@@ -265,8 +269,8 @@ object UserDefinedFunctionUtils {
case t: CompositeType[_] => t.getFieldNames
case a: AtomicType[_] => Array("f0")
case tpe =>
- throw new TableException(s"Currently only support CompositeType and AtomicType. " +
- s"Type $tpe lacks explicit field naming")
+ throw new TableException(s"Currently only CompositeType and AtomicType are supported. " +
+ s"Type $tpe lacks explicit field naming")
}
val fieldIndexes = fieldNames.indices.toArray
val fieldTypes: Array[TypeInformation[_]] = fieldNames.map { i =>
@@ -274,7 +278,7 @@ object UserDefinedFunctionUtils {
case t: CompositeType[_] => t.getTypeAt(i).asInstanceOf[TypeInformation[_]]
case a: AtomicType[_] => a.asInstanceOf[TypeInformation[_]]
case tpe =>
- throw new TableException(s"Currently only support CompositeType and AtomicType.")
+ throw new TableException(s"Currently only CompositeType and AtomicType are supported.")
}
}
(fieldNames, fieldIndexes, fieldTypes)
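The tightened filter above now also rejects static eval methods, but only for table functions. A self-contained restatement of the reflective check (not the helper itself):

    import java.lang.reflect.Modifier

    // returns the eval methods the Table API accepts for a given function instance
    def acceptedEvalMethods(function: AnyRef, isTableFunction: Boolean) =
      function.getClass.getDeclaredMethods.filter { m =>
        val mods = m.getModifiers
        m.getName == "eval" &&
          Modifier.isPublic(mods) &&
          !Modifier.isAbstract(mods) &&
          !(isTableFunction && Modifier.isStatic(mods))
      }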
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
index f6ddeef..c093f1a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
@@ -49,9 +49,7 @@ object ProjectionTranslator {
val replaced = exprs
.map(replaceAggregationsAndProperties(_, tableEnv, aggNames, propNames))
- .map {
- case e: Expression => UnresolvedAlias(e)
- }
+ .map(UnresolvedAlias)
val aggs = aggNames.map( a => Alias(a._1, a._2)).toSeq
val props = propNames.map( p => Alias(p._1, p._2)).toSeq
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala
index 4dc2ab7..438698a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala
@@ -428,7 +428,6 @@ case class Join(
right.construct(relBuilder)
val corSet = mutable.Set[CorrelationId]()
-
if (correlated) {
corSet += relBuilder.peek().getCluster.createCorrel()
}
@@ -624,9 +623,9 @@ case class WindowAggregate(
}
}
-
/**
 * LogicalNode for calling a user-defined table function.
+ *
* @param functionName function name
* @param tableFunction table function to be called (might be overloaded)
* @param parameters actual parameters
@@ -634,16 +633,16 @@ case class WindowAggregate(
* @param child child logical node
*/
case class LogicalTableFunctionCall(
- functionName: String,
- tableFunction: TableFunction[_],
- parameters: Seq[Expression],
- resultType: TypeInformation[_],
- fieldNames: Array[String],
- child: LogicalNode)
+ functionName: String,
+ tableFunction: TableFunction[_],
+ parameters: Seq[Expression],
+ resultType: TypeInformation[_],
+ fieldNames: Array[String],
+ child: LogicalNode)
extends UnaryNode {
- val (_, fieldIndexes, fieldTypes) = getFieldInfo(resultType)
- var evalMethod: Method = _
+ private val (_, fieldIndexes, fieldTypes) = getFieldInfo(resultType)
+ private var evalMethod: Method = _
override def output: Seq[Attribute] = fieldNames.zip(fieldTypes).map {
case (n, t) => ResolvedFieldReference(n, t)
@@ -651,9 +650,9 @@ case class LogicalTableFunctionCall(
override def validate(tableEnv: TableEnvironment): LogicalNode = {
val node = super.validate(tableEnv).asInstanceOf[LogicalTableFunctionCall]
- // check not Scala object
+ // check if not Scala object
checkNotSingleton(tableFunction.getClass)
- // check could be instantiated
+ // check if class could be instantiated
checkForInstantiation(tableFunction.getClass)
// look for a signature that matches the input types
val signature = node.parameters.map(_.resultType)
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala
index 9745be1..93a8f53 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala
@@ -24,6 +24,7 @@ import org.apache.flink.api.common.functions.FlatMapFunction
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.table.codegen.{CodeGenerator, GeneratedExpression, GeneratedFunction}
import org.apache.flink.api.table.codegen.CodeGenUtils.primitiveDefaultValue
+import org.apache.flink.api.table.codegen.GeneratedExpression.{ALWAYS_NULL, NO_CODE}
import org.apache.flink.api.table.functions.utils.TableSqlFunction
import org.apache.flink.api.table.runtime.FlatMapRunner
import org.apache.flink.api.table.typeutils.TypeConverter._
@@ -73,12 +74,12 @@ trait FlinkCorrelate {
// outer apply
// in case of outer apply and the returned row of table function is empty,
- // fill null to all fields of the row
+ // fill all fields of the row with null
val input2NullExprs = input2AccessExprs.map { x =>
GeneratedExpression(
primitiveDefaultValue(x.resultType),
- GeneratedExpression.ALWAYS_NULL,
- "",
+ ALWAYS_NULL,
+ NO_CODE,
x.resultType)
}
val outerResultExpr = generator.generateResultExpression(
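The NO_CODE/ALWAYS_NULL expressions above implement the outer-apply contract: when the UDTF emits no rows for an input row, the function's output fields are padded with nulls. A plain-Scala sketch of the semantics (not the generated code):

    // conceptual equivalent of the generated flat map for an OUTER apply:
    def outerApply[I, O](rows: Seq[I], udtf: I => Seq[O]): Seq[(I, Option[O])] =
      rows.flatMap { row =>
        val emitted = udtf(row)
        if (emitted.isEmpty) Seq((row, None))      // ALWAYS_NULL branch
        else emitted.map(o => (row, Some(o)))      // regular correlate branch
      }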
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala
index 4aa7fea..3cddf8b 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala
@@ -99,7 +99,7 @@ class DataSetCorrelate(
config.getNullCheck,
config.getEfficientTypeUsage)
- // do not need to specify input type
+ // we do not need to specify input type
val inputDS = inputNode.asInstanceOf[DataSetRel].translateToPlan(tableEnv)
val funcRel = scan.asInstanceOf[LogicalTableFunctionScan]
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamCorrelate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamCorrelate.scala
index b0bc48a..028cb10 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamCorrelate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamCorrelate.scala
@@ -48,6 +48,7 @@ class DataStreamCorrelate(
extends SingleRel(cluster, traitSet, inputNode)
with FlinkCorrelate
with DataStreamRel {
+
override def deriveRowType() = relRowType
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
@@ -92,7 +93,7 @@ class DataStreamCorrelate(
config.getNullCheck,
config.getEfficientTypeUsage)
- // do not need to specify input type
+ // we do not need to specify input type
val inputDS = inputNode.asInstanceOf[DataStreamRel].translateToPlan(tableEnv)
val funcRel = scan.asInstanceOf[LogicalTableFunctionScan]
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala
index e6cf0cf..39756be 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala
@@ -33,8 +33,7 @@ class DataSetCorrelateRule
classOf[LogicalCorrelate],
Convention.NONE,
DataSetConvention.INSTANCE,
- "DataSetCorrelateRule")
- {
+ "DataSetCorrelateRule") {
override def matches(call: RelOptRuleCall): Boolean = {
val join: LogicalCorrelate = call.rel(0).asInstanceOf[LogicalCorrelate]
@@ -46,7 +45,9 @@ class DataSetCorrelateRule
case scan: LogicalTableFunctionScan => true
// a filter is pushed above the table function
case filter: LogicalFilter =>
- filter.getInput.asInstanceOf[RelSubset].getOriginal
+ filter
+ .getInput.asInstanceOf[RelSubset]
+ .getOriginal
.isInstanceOf[LogicalTableFunctionScan]
case _ => false
}
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamCorrelateRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamCorrelateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamCorrelateRule.scala
index bb52fd7..554c6c1 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamCorrelateRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamCorrelateRule.scala
@@ -33,8 +33,7 @@ class DataStreamCorrelateRule
classOf[LogicalCorrelate],
Convention.NONE,
DataStreamConvention.INSTANCE,
- "DataStreamCorrelateRule")
-{
+ "DataStreamCorrelateRule") {
override def matches(call: RelOptRuleCall): Boolean = {
val join: LogicalCorrelate = call.rel(0).asInstanceOf[LogicalCorrelate]
@@ -45,7 +44,9 @@ class DataStreamCorrelateRule
case scan: LogicalTableFunctionScan => true
// a filter is pushed above the table function
case filter: LogicalFilter =>
- filter.getInput.asInstanceOf[RelSubset].getOriginal
+ filter
+ .getInput.asInstanceOf[RelSubset]
+ .getOriginal
.isInstanceOf[LogicalTableFunctionScan]
case _ => false
}
@@ -63,8 +64,9 @@ class DataStreamCorrelateRule
convertToCorrelate(rel.getRelList.get(0), condition)
case filter: LogicalFilter =>
- convertToCorrelate(filter.getInput.asInstanceOf[RelSubset].getOriginal,
- Some(filter.getCondition))
+ convertToCorrelate(
+ filter.getInput.asInstanceOf[RelSubset].getOriginal,
+ Some(filter.getCondition))
case scan: LogicalTableFunctionScan =>
new DataStreamCorrelate(
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
index a75f2fc..b421c8e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
@@ -611,9 +611,9 @@ class Table(
}
/**
- * The Cross Apply returns rows from the outer table (table on the left of the Apply operator)
- * that produces matching values from the table-valued function (which is on the right side of
- * the operator).
+ * The Cross Apply operator returns rows from the outer table (table on the left of the
+ * operator) for which the table-valued function (defined in the expression on the right
+ * side of the operator) produces matching values.
*
* The Cross Apply is equivalent to Inner Join, but it works with a table-valued function.
*
@@ -635,23 +635,25 @@ class Table(
}
/**
- * The Cross Apply returns rows from the outer table (table on the left of the Apply operator)
- * that produces matching values from the table-valued function (which is on the right side of
- * the operator).
+ * The Cross Apply operator returns rows from the outer table (table on the left of the
+ * operator) for which the table-valued function (defined in the expression on the right
+ * side of the operator) produces matching values.
*
* The Cross Apply is equivalent to Inner Join, but it works with a table-valued function.
*
* Example:
*
* {{{
- * class MySplitUDTF extends TableFunction[String] {
- * def eval(str: String): Unit = {
- * str.split("#").foreach(collect)
+ * class MySplitUDTF extends TableFunction<String> {
+ * public void eval(String str) {
+ * str.split("#").forEach(this::collect);
* }
* }
*
- * val split = new MySplitUDTF()
- * table.crossApply("split(c) as (s)").select("a, b, c, s")
+ * TableFunction<String> split = new MySplitUDTF();
+ * tableEnv.registerFunction("split", split);
+ *
+ * table.crossApply("split(c) as (s)").select("a, b, c, s");
* }}}
*/
def crossApply(udtf: String): Table = {
@@ -659,9 +661,10 @@ class Table(
}
/**
- * The Outer Apply returns all the rows from the outer table (table on the left of the Apply
- * operator), and rows that do not matches the condition from the table-valued function (which
- * is on the right side of the operator), NULL values are displayed.
+ * The Outer Apply operator returns all the rows from the outer table (table on the left of
+ * the Apply operator), joined with the matching rows produced by the table-valued function
+ * (which is defined in the expression on the right side of the operator).
+ * Rows without a match are padded with null values.
*
* The Outer Apply is equivalent to Left Outer Join, but it works with a table-valued function.
*
@@ -683,17 +686,26 @@ class Table(
}
/**
- * The Outer Apply returns all the rows from the outer table (table on the left of the Apply
- * operator), and rows that do not matches the condition from the table-valued function (which
- * is on the right side of the operator), NULL values are displayed.
+ * The Outer Apply operator returns all the rows from the outer table (table on the left of
+ * the Apply operator), joined with the matching rows produced by the table-valued function
+ * (which is defined in the expression on the right side of the operator).
+ * Rows without a match are padded with null values.
*
* The Outer Apply is equivalent to Left Outer Join, but it works with a table-valued function.
*
* Example:
*
* {{{
- * val split = new MySplitUDTF()
- * table.outerApply("split(c) as (s)").select("a, b, c, s")
+ * class MySplitUDTF extends TableFunction<String> {
+ * public void eval(String str) {
+ * str.split("#").forEach(this::collect);
+ * }
+ * }
+ *
+ * TableFunction<String> split = new MySplitUDTF();
+ * tableEnv.registerFunction("split", split);
+ *
+ * table.outerApply("split(c) as (s)").select("a, b, c, s");
* }}}
*/
def outerApply(udtf: String): Table = {
@@ -708,7 +720,7 @@ class Table(
private def applyInternal(udtf: Expression, joinType: JoinType): Table = {
var alias: Option[Seq[String]] = None
- // unwrap an Expression until get a TableFunctionCall
+ // unwrap an Expression until we get a TableFunctionCall
def unwrap(expr: Expression): TableFunctionCall = expr match {
case Alias(child, name, extraNames) =>
alias = Some(Seq(name) ++ extraNames)
@@ -717,7 +729,9 @@ class Table(
val function = tableEnv.getFunctionCatalog.lookupFunction(name, args)
unwrap(function)
case c: TableFunctionCall => c
- case _ => throw new TableException("Cross/Outer Apply only accept TableFunction")
+ case _ =>
+ throw new TableException(
+ "Cross/Outer Apply operators only accept expressions that define table functions.")
}
val call = unwrap(udtf)
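The reworked docs show the String-based (Java) variant; the equivalent Scala Table API calls, as exercised by the tests in this commit:

    val split = new MySplitUDTF

    // inner semantics: input rows without a match are dropped
    table.crossApply(split('c) as 's).select('a, 'b, 'c, 's)

    // left-outer semantics: input rows without a match are kept, 's is null
    table.outerApply(split('c) as 's).select('a, 'b, 'c, 's)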
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
index 4029a7d..dc68b89 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
@@ -47,16 +47,18 @@ class FunctionCatalog {
sqlFunctions += sqlFunction
}
- /** Register multiple sql functions at one time. The functions has the same name. **/
+ /**
+ * Register multiple SQL functions at the same time. The functions have the same name.
+ */
def registerSqlFunctions(functions: Seq[SqlFunction]): Unit = {
if (functions.nonEmpty) {
val name = functions.head.getName
- // check all name is the same in the functions
+ // check that all functions have the same name
if (functions.forall(_.getName == name)) {
sqlFunctions --= sqlFunctions.filter(_.getName == name)
sqlFunctions ++= functions
} else {
- throw ValidationException("The sql functions request to register have different name.")
+ throw ValidationException("The SQL functions to be registered have different names.")
}
}
}
@@ -88,7 +90,7 @@ class FunctionCatalog {
case tf if classOf[TableFunction[_]].isAssignableFrom(tf) =>
val tableSqlFunction = sqlFunctions
.find(f => f.getName.equalsIgnoreCase(name) && f.isInstanceOf[TableSqlFunction])
- .getOrElse(throw ValidationException(s"Unregistered table sql function: $name"))
+ .getOrElse(throw ValidationException(s"Undefined table function: $name"))
.asInstanceOf[TableSqlFunction]
val typeInfo = tableSqlFunction.getRowTypeInfo
val function = tableSqlFunction.getTableFunction
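registerSqlFunctions exists because an overloaded eval yields several SqlFunction instances sharing one name, and re-registration must replace the whole group at once. A hedged sketch of the contract (fn1 and fn2 are hypothetical stand-ins for the functions created from two overloads of one UDTF):

    import org.apache.calcite.sql.SqlFunction
    import org.apache.flink.api.table.validate.FunctionCatalog

    // both functions carry the same name, e.g. "split"; a previously
    // registered group under that name is dropped and replaced as a whole.
    // Passing functions with different names throws a ValidationException.
    def registerOverloads(catalog: FunctionCatalog, fn1: SqlFunction, fn2: SqlFunction): Unit =
      catalog.registerSqlFunctions(Seq(fn1, fn2))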
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/test/resources/log4j-test.properties
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/resources/log4j-test.properties b/flink-libraries/flink-table/src/test/resources/log4j-test.properties
index 4c74d85..f713aa8 100644
--- a/flink-libraries/flink-table/src/test/resources/log4j-test.properties
+++ b/flink-libraries/flink-table/src/test/resources/log4j-test.properties
@@ -18,7 +18,7 @@
# Set root logger level to OFF to not flood build logs
# set manually to INFO for debugging purposes
-log4j.rootLogger=OFF, testlogger
+log4j.rootLogger=INFO, testlogger
# A1 is set to be a ConsoleAppender.
log4j.appender.testlogger=org.apache.log4j.ConsoleAppender
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionITCase.scala
deleted file mode 100644
index 7e0d0ff..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionITCase.scala
+++ /dev/null
@@ -1,212 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.api.scala.batch
-
-import org.apache.flink.api.scala.batch.utils.TableProgramsTestBase
-import org.apache.flink.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
-import org.apache.flink.api.scala._
-import org.apache.flink.api.scala.table._
-import org.apache.flink.api.table.expressions.utils._
-import org.apache.flink.api.table.{Row, Table, TableEnvironment}
-import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
-import org.apache.flink.test.util.TestBaseUtils
-import org.junit.Test
-import org.junit.runner.RunWith
-import org.junit.runners.Parameterized
-
-import scala.collection.JavaConverters._
-import scala.collection.mutable
-
-@RunWith(classOf[Parameterized])
-class UserDefinedTableFunctionITCase(
- mode: TestExecutionMode,
- configMode: TableConfigMode)
- extends TableProgramsTestBase(mode, configMode) {
-
- @Test
- def testSQLCrossApply(): Unit = {
- val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
- val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
- val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
- tableEnv.registerTable("MyTable", in)
- tableEnv.registerFunction("split", new TableFunc1)
-
- val sqlQuery = "SELECT MyTable.c, t.s FROM MyTable, LATERAL TABLE(split(c)) AS t(s)"
-
- val result = tableEnv.sql(sqlQuery).toDataSet[Row]
- val results = result.collect()
- val expected: String = "Jack#22,Jack\n" + "Jack#22,22\n" + "John#19,John\n" + "John#19,19\n" +
- "Anna#44,Anna\n" + "Anna#44,44\n"
- TestBaseUtils.compareResultAsText(results.asJava, expected)
- }
-
- @Test
- def testSQLOuterApply(): Unit = {
- val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
- val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
- val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
- tableEnv.registerTable("MyTable", in)
- tableEnv.registerFunction("split", new TableFunc2)
-
- val sqlQuery = "SELECT MyTable.c, t.a, t.b FROM MyTable LEFT JOIN LATERAL TABLE(split(c)) " +
- "AS t(a,b) ON TRUE"
-
- val result = tableEnv.sql(sqlQuery).toDataSet[Row]
- val results = result.collect()
- val expected: String = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" +
- "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n" + "nosharp,null,null"
- TestBaseUtils.compareResultAsText(results.asJava, expected)
- }
-
- @Test
- def testTableAPICrossApply(): Unit = {
- val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
- val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
- val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
-
- val func1 = new TableFunc1
- val result = in.crossApply(func1('c) as ('s)).select('c, 's).toDataSet[Row]
- val results = result.collect()
- val expected: String = "Jack#22,Jack\n" + "Jack#22,22\n" + "John#19,John\n" + "John#19,19\n" +
- "Anna#44,Anna\n" + "Anna#44,44\n"
- TestBaseUtils.compareResultAsText(results.asJava, expected)
-
- // with overloading
- val result2 = in.crossApply(func1('c, "$") as ('s)).select('c, 's).toDataSet[Row]
- val results2 = result2.collect()
- val expected2: String = "Jack#22,$Jack\n" + "Jack#22,$22\n" + "John#19,$John\n" +
- "John#19,$19\n" + "Anna#44,$Anna\n" + "Anna#44,$44\n"
- TestBaseUtils.compareResultAsText(results2.asJava, expected2)
- }
-
-
- @Test
- def testTableAPIOuterApply(): Unit = {
- val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
- val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
- val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
- val func2 = new TableFunc2
- val result = in.outerApply(func2('c) as ('s, 'l)).select('c, 's, 'l).toDataSet[Row]
- val results = result.collect()
- val expected: String = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" +
- "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n" + "nosharp,null,null"
- TestBaseUtils.compareResultAsText(results.asJava, expected)
- }
-
-
- @Test
- def testCustomReturnType(): Unit = {
- val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
- val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
- val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
- val func2 = new TableFunc2
-
- val result = in
- .crossApply(func2('c) as ('name, 'len))
- .select('c, 'name, 'len)
- .toDataSet[Row]
-
- val results = result.collect()
- val expected: String = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" +
- "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n"
- TestBaseUtils.compareResultAsText(results.asJava, expected)
- }
-
- @Test
- def testHierarchyType(): Unit = {
- val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
- val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
- val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
-
- val hierarchy = new HierarchyTableFunction
- val result = in
- .crossApply(hierarchy('c) as ('name, 'adult, 'len))
- .select('c, 'name, 'adult, 'len)
- .toDataSet[Row]
-
- val results = result.collect()
- val expected: String = "Jack#22,Jack,true,22\n" + "John#19,John,false,19\n" +
- "Anna#44,Anna,true,44\n"
- TestBaseUtils.compareResultAsText(results.asJava, expected)
- }
-
- @Test
- def testPojoType(): Unit = {
- val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
- val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
- val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
-
- val pojo = new PojoTableFunc()
- val result = in
- .crossApply(pojo('c))
- .select('c, 'name, 'age)
- .toDataSet[Row]
-
- val results = result.collect()
- val expected: String = "Jack#22,Jack,22\n" + "John#19,John,19\n" + "Anna#44,Anna,44\n"
- TestBaseUtils.compareResultAsText(results.asJava, expected)
- }
-
-
- @Test
- def testTableAPIWithFilter(): Unit = {
- val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
- val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
- val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
- val func0 = new TableFunc0
-
- val result = in
- .crossApply(func0('c) as ('name, 'age))
- .select('c, 'name, 'age)
- .filter('age > 20)
- .toDataSet[Row]
-
- val results = result.collect()
- val expected: String = "Jack#22,Jack,22\n" + "Anna#44,Anna,44\n"
- TestBaseUtils.compareResultAsText(results.asJava, expected)
- }
-
-
- @Test
- def testUDTFWithScalarFunction(): Unit = {
- val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
- val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env)
- val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c)
- val func1 = new TableFunc1
-
- val result = in
- .crossApply(func1('c.substring(2)) as 's)
- .select('c, 's)
- .toDataSet[Row]
-
- val results = result.collect()
- val expected: String = "Jack#22,ack\n" + "Jack#22,22\n" + "John#19,ohn\n" + "John#19,19\n" +
- "Anna#44,nna\n" + "Anna#44,44\n"
- TestBaseUtils.compareResultAsText(results.asJava, expected)
- }
-
-
- private def getSmall3TupleDataSet(env: ExecutionEnvironment): DataSet[(Int, Long, String)] = {
- val data = new mutable.MutableList[(Int, Long, String)]
- data.+=((1, 1L, "Jack#22"))
- data.+=((2, 2L, "John#19"))
- data.+=((3, 2L, "Anna#44"))
- data.+=((4, 3L, "nosharp"))
- env.fromCollection(data)
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionTest.scala
deleted file mode 100644
index 7e236d1..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionTest.scala
+++ /dev/null
@@ -1,320 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.api.scala.batch
-
-import org.apache.flink.api.scala.table._
-import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment => ScalaExecutionEnv, _}
-import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.api.java.{DataSet => JDataSet, ExecutionEnvironment => JavaExecutionEnv}
-import org.apache.flink.api.table.expressions.utils.{HierarchyTableFunction, PojoTableFunc, TableFunc1, TableFunc2}
-import org.apache.flink.api.table.typeutils.RowTypeInfo
-import org.apache.flink.api.table.utils.TableTestBase
-import org.apache.flink.api.table.utils.TableTestUtil._
-import org.apache.flink.api.table.{Row, TableEnvironment, Types}
-import org.junit.Test
-import org.mockito.Mockito._
-
-
-class UserDefinedTableFunctionTest extends TableTestBase {
-
- @Test
- def testTableAPI(): Unit = {
- // mock
- val ds = mock(classOf[DataSet[Row]])
- val jDs = mock(classOf[JDataSet[Row]])
- val typeInfo: TypeInformation[Row] = new RowTypeInfo(Seq(Types.INT, Types.LONG, Types.STRING))
- when(ds.javaSet).thenReturn(jDs)
- when(jDs.getType).thenReturn(typeInfo)
-
- // Scala environment
- val env = mock(classOf[ScalaExecutionEnv])
- val tableEnv = TableEnvironment.getTableEnvironment(env)
- val in1 = ds.toTable(tableEnv).as('a, 'b, 'c)
-
- // Java environment
- val javaEnv = mock(classOf[JavaExecutionEnv])
- val javaTableEnv = TableEnvironment.getTableEnvironment(javaEnv)
- val in2 = javaTableEnv.fromDataSet(jDs).as("a, b, c")
- javaTableEnv.registerTable("MyTable", in2)
-
- // test cross apply
- val func1 = new TableFunc1
- javaTableEnv.registerFunction("func1", func1)
- var scalaTable = in1.crossApply(func1('c) as ('s)).select('c, 's)
- var javaTable = in2.crossApply("func1(c) as (s)").select("c, s")
- verifyTableEquals(scalaTable, javaTable)
-
- // test outer apply
- scalaTable = in1.outerApply(func1('c) as ('s)).select('c, 's)
- javaTable = in2.outerApply("func1(c) as (s)").select("c, s")
- verifyTableEquals(scalaTable, javaTable)
-
- // test overloading
- scalaTable = in1.crossApply(func1('c, "$") as ('s)).select('c, 's)
- javaTable = in2.crossApply("func1(c, '$') as (s)").select("c, s")
- verifyTableEquals(scalaTable, javaTable)
-
- // test custom result type
- val func2 = new TableFunc2
- javaTableEnv.registerFunction("func2", func2)
- scalaTable = in1.crossApply(func2('c) as ('name, 'len)).select('c, 'name, 'len)
- javaTable = in2.crossApply("func2(c) as (name, len)").select("c, name, len")
- verifyTableEquals(scalaTable, javaTable)
-
- // test hierarchy generic type
- val hierarchy = new HierarchyTableFunction
- javaTableEnv.registerFunction("hierarchy", hierarchy)
- scalaTable = in1.crossApply(hierarchy('c) as ('name, 'adult, 'len))
- .select('c, 'name, 'len, 'adult)
- javaTable = in2.crossApply("hierarchy(c) as (name, adult, len)")
- .select("c, name, len, adult")
- verifyTableEquals(scalaTable, javaTable)
-
- // test pojo type
- val pojo = new PojoTableFunc
- javaTableEnv.registerFunction("pojo", pojo)
- scalaTable = in1.crossApply(pojo('c))
- .select('c, 'name, 'age)
- javaTable = in2.crossApply("pojo(c)")
- .select("c, name, age")
- verifyTableEquals(scalaTable, javaTable)
-
- // test with filter
- scalaTable = in1.crossApply(func2('c) as ('name, 'len))
- .select('c, 'name, 'len).filter('len > 2)
- javaTable = in2.crossApply("func2(c) as (name, len)")
- .select("c, name, len").filter("len > 2")
- verifyTableEquals(scalaTable, javaTable)
-
- // test with scalar function
- scalaTable = in1.crossApply(func1('c.substring(2)) as ('s))
- .select('a, 'c, 's)
- javaTable = in2.crossApply("func1(substring(c, 2)) as (s)")
- .select("a, c, s")
- verifyTableEquals(scalaTable, javaTable)
- }
-
- @Test
- def testSQLWithCrossApply(): Unit = {
- val util = batchTestUtil()
- val func1 = new TableFunc1
- util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
- util.addFunction("func1", func1)
-
- val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c)) AS T(s)"
-
- val expected = unaryNode(
- "DataSetCalc",
- unaryNode(
- "DataSetCorrelate",
- batchTableNode(0),
- term("invocation", "func1($cor0.c)"),
- term("function", func1.getClass.getCanonicalName),
- term("rowType",
- "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
- term("joinType", "INNER")
- ),
- term("select", "c", "f0 AS s")
- )
-
- util.verifySql(sqlQuery, expected)
-
- // test overloading
-
- val sqlQuery2 = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c, '$')) AS T(s)"
-
- val expected2 = unaryNode(
- "DataSetCalc",
- unaryNode(
- "DataSetCorrelate",
- batchTableNode(0),
- term("invocation", "func1($cor0.c, '$')"),
- term("function", func1.getClass.getCanonicalName),
- term("rowType",
- "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
- term("joinType", "INNER")
- ),
- term("select", "c", "f0 AS s")
- )
-
- util.verifySql(sqlQuery2, expected2)
- }
-
- @Test
- def testSQLWithOuterApply(): Unit = {
- val util = batchTestUtil()
- val func1 = new TableFunc1
- util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
- util.addFunction("func1", func1)
-
- val sqlQuery = "SELECT c, s FROM MyTable LEFT JOIN LATERAL TABLE(func1(c)) AS T(s) ON TRUE"
-
- val expected = unaryNode(
- "DataSetCalc",
- unaryNode(
- "DataSetCorrelate",
- batchTableNode(0),
- term("invocation", "func1($cor0.c)"),
- term("function", func1.getClass.getCanonicalName),
- term("rowType",
- "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
- term("joinType", "LEFT")
- ),
- term("select", "c", "f0 AS s")
- )
-
- util.verifySql(sqlQuery, expected)
- }
-
- @Test
- def testSQLWithCustomType(): Unit = {
- val util = batchTestUtil()
- val func2 = new TableFunc2
- util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
- util.addFunction("func2", func2)
-
- val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len)"
-
- val expected = unaryNode(
- "DataSetCalc",
- unaryNode(
- "DataSetCorrelate",
- batchTableNode(0),
- term("invocation", "func2($cor0.c)"),
- term("function", func2.getClass.getCanonicalName),
- term("rowType",
- "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
- "VARCHAR(2147483647) f0, INTEGER f1)"),
- term("joinType", "INNER")
- ),
- term("select", "c", "f0 AS name", "f1 AS len")
- )
-
- util.verifySql(sqlQuery, expected)
- }
-
- @Test
- def testSQLWithHierarchyType(): Unit = {
- val util = batchTestUtil()
- util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
- val function = new HierarchyTableFunction
- util.addFunction("hierarchy", function)
-
- val sqlQuery = "SELECT c, T.* FROM MyTable, LATERAL TABLE(hierarchy(c)) AS T(name, adult, len)"
-
- val expected = unaryNode(
- "DataSetCalc",
- unaryNode(
- "DataSetCorrelate",
- batchTableNode(0),
- term("invocation", "hierarchy($cor0.c)"),
- term("function", function.getClass.getCanonicalName),
- term("rowType",
- "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," +
- " VARCHAR(2147483647) f0, BOOLEAN f1, INTEGER f2)"),
- term("joinType", "INNER")
- ),
- term("select", "c", "f0 AS name", "f1 AS adult", "f2 AS len")
- )
-
- util.verifySql(sqlQuery, expected)
- }
-
- @Test
- def testSQLWithPojoType(): Unit = {
- val util = batchTestUtil()
- util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
- val function = new PojoTableFunc
- util.addFunction("pojo", function)
-
- val sqlQuery = "SELECT c, name, age FROM MyTable, LATERAL TABLE(pojo(c))"
-
- val expected = unaryNode(
- "DataSetCalc",
- unaryNode(
- "DataSetCorrelate",
- batchTableNode(0),
- term("invocation", "pojo($cor0.c)"),
- term("function", function.getClass.getCanonicalName),
- term("rowType",
- "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," +
- " INTEGER age, VARCHAR(2147483647) name)"),
- term("joinType", "INNER")
- ),
- term("select", "c", "name", "age")
- )
-
- util.verifySql(sqlQuery, expected)
- }
-
- @Test
- def testSQLWithFilter(): Unit = {
- val util = batchTestUtil()
- val func2 = new TableFunc2
- util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
- util.addFunction("func2", func2)
-
- val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len) " +
- "WHERE len > 2"
-
- val expected = unaryNode(
- "DataSetCalc",
- unaryNode(
- "DataSetCorrelate",
- batchTableNode(0),
- term("invocation", "func2($cor0.c)"),
- term("function", func2.getClass.getCanonicalName),
- term("rowType",
- "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
- "VARCHAR(2147483647) f0, INTEGER f1)"),
- term("joinType", "INNER"),
- term("condition", ">($1, 2)")
- ),
- term("select", "c", "f0 AS name", "f1 AS len")
- )
-
- util.verifySql(sqlQuery, expected)
- }
-
-
- @Test
- def testSQLWithScalarFunction(): Unit = {
- val util = batchTestUtil()
- val func1 = new TableFunc1
- util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
- util.addFunction("func1", func1)
-
- val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(SUBSTRING(c, 2))) AS T(s)"
-
- val expected = unaryNode(
- "DataSetCalc",
- unaryNode(
- "DataSetCorrelate",
- batchTableNode(0),
- term("invocation", "func1(SUBSTRING($cor0.c, 2))"),
- term("function", func1.getClass.getCanonicalName),
- term("rowType",
- "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
- term("joinType", "INNER")
- ),
- term("select", "c", "f0 AS s")
- )
-
- util.verifySql(sqlQuery, expected)
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/684defbf/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/UserDefinedTableFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/UserDefinedTableFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/UserDefinedTableFunctionTest.scala
new file mode 100644
index 0000000..1c505ba
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/UserDefinedTableFunctionTest.scala
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.scala.batch.sql
+
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.table.utils.{HierarchyTableFunction, PojoTableFunc, TableFunc2}
+import org.apache.flink.api.table.utils._
+import org.apache.flink.api.table.utils.TableTestUtil._
+import org.junit.Test
+
+class UserDefinedTableFunctionTest extends TableTestBase {
+
+ @Test
+ def testCrossApply(): Unit = {
+ val util = batchTestUtil()
+ val func1 = new TableFunc1
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func1", func1)
+
+ val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c)) AS T(s)"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "func1($cor0.c)"),
+ term("function", func1.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS s")
+ )
+
+ util.verifySql(sqlQuery, expected)
+
+ // test overloading
+
+ val sqlQuery2 = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c, '$')) AS T(s)"
+
+ val expected2 = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "func1($cor0.c, '$')"),
+ term("function", func1.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS s")
+ )
+
+ util.verifySql(sqlQuery2, expected2)
+ }
+
+ @Test
+ def testOuterApply(): Unit = {
+ val util = batchTestUtil()
+ val func1 = new TableFunc1
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func1", func1)
+
+ val sqlQuery = "SELECT c, s FROM MyTable LEFT JOIN LATERAL TABLE(func1(c)) AS T(s) ON TRUE"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "func1($cor0.c)"),
+ term("function", func1.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+ term("joinType", "LEFT")
+ ),
+ term("select", "c", "f0 AS s")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testCustomType(): Unit = {
+ val util = batchTestUtil()
+ val func2 = new TableFunc2
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func2", func2)
+
+ val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len)"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "func2($cor0.c)"),
+ term("function", func2.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
+ "VARCHAR(2147483647) f0, INTEGER f1)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS name", "f1 AS len")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testHierarchyType(): Unit = {
+ val util = batchTestUtil()
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ val function = new HierarchyTableFunction
+ util.addFunction("hierarchy", function)
+
+ val sqlQuery = "SELECT c, T.* FROM MyTable, LATERAL TABLE(hierarchy(c)) AS T(name, adult, len)"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "hierarchy($cor0.c)"),
+ term("function", function.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," +
+ " VARCHAR(2147483647) f0, BOOLEAN f1, INTEGER f2)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS name", "f1 AS adult", "f2 AS len")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testPojoType(): Unit = {
+ val util = batchTestUtil()
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ val function = new PojoTableFunc
+ util.addFunction("pojo", function)
+
+ val sqlQuery = "SELECT c, name, age FROM MyTable, LATERAL TABLE(pojo(c))"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "pojo($cor0.c)"),
+ term("function", function.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," +
+ " INTEGER age, VARCHAR(2147483647) name)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "name", "age")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testFilter(): Unit = {
+ val util = batchTestUtil()
+ val func2 = new TableFunc2
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func2", func2)
+
+ val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len) " +
+ "WHERE len > 2"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "func2($cor0.c)"),
+ term("function", func2.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
+ "VARCHAR(2147483647) f0, INTEGER f1)"),
+ term("joinType", "INNER"),
+ term("condition", ">($1, 2)")
+ ),
+ term("select", "c", "f0 AS name", "f1 AS len")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+
+
+ @Test
+ def testScalarFunction(): Unit = {
+ val util = batchTestUtil()
+ val func1 = new TableFunc1
+ util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func1", func1)
+
+ val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(SUBSTRING(c, 2))) AS T(s)"
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetCorrelate",
+ batchTableNode(0),
+ term("invocation", "func1(SUBSTRING($cor0.c, 2))"),
+ term("function", func1.getClass.getCanonicalName),
+ term("rowType",
+ "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"),
+ term("joinType", "INNER")
+ ),
+ term("select", "c", "f0 AS s")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
+}