You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by fh...@apache.org on 2016/12/16 15:46:31 UTC
[02/47] flink git commit: [FLINK-4704] [table] Refactor package
structure of flink-table.
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/ScalarOperatorsTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/ScalarOperatorsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/ScalarOperatorsTest.scala
new file mode 100644
index 0000000..098feba
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/ScalarOperatorsTest.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.expressions
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.types.Row
+import org.apache.flink.table.api.{Types, ValidationException}
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.expressions.utils.ExpressionTestBase
+import org.junit.Test
+
+/**
+ * Tests for scalar operators: casting, arithmetic, boolean logic, comparison, string
+ * concatenation, conditional expressions (`?` and SQL CASE), null literals and nested
+ * expressions.
+ *
+ * Field references `f0`..`f10` refer to the row built in [[testData]], whose field types are
+ * declared in [[typeInfo]]. Most cases are stated twice: once with the Scala expression DSL and
+ * once as the equivalent Java string expression.
+ */
+class ScalarOperatorsTest extends ExpressionTestBase {
+
+  @Test
+  def testCasting(): Unit = {
+    // * -> String
+    testTableApi('f2.cast(Types.STRING), "f2.cast(STRING)", "1")
+    testTableApi('f5.cast(Types.STRING), "f5.cast(STRING)", "1.0")
+    testTableApi('f3.cast(Types.STRING), "f3.cast(STRING)", "1")
+    testTableApi('f6.cast(Types.STRING), "f6.cast(STRING)", "true")
+    // NUMERIC TYPE -> Boolean (1 becomes true, 0.0 becomes false)
+    testTableApi('f2.cast(Types.BOOLEAN), "f2.cast(BOOLEAN)", "true")
+    testTableApi('f7.cast(Types.BOOLEAN), "f7.cast(BOOLEAN)", "false")
+    testTableApi('f3.cast(Types.BOOLEAN), "f3.cast(BOOLEAN)", "true")
+    // NUMERIC TYPE -> NUMERIC TYPE
+    testTableApi('f2.cast(Types.DOUBLE), "f2.cast(DOUBLE)", "1.0")
+    testTableApi('f7.cast(Types.INT), "f7.cast(INT)", "0")
+    testTableApi('f3.cast(Types.SHORT), "f3.cast(SHORT)", "1")
+    // Boolean -> NUMERIC TYPE
+    testTableApi('f6.cast(Types.DOUBLE), "f6.cast(DOUBLE)", "1.0")
+    // identity casting
+    testTableApi('f2.cast(Types.INT), "f2.cast(INT)", "1")
+    testTableApi('f7.cast(Types.DOUBLE), "f7.cast(DOUBLE)", "0.0")
+    testTableApi('f3.cast(Types.LONG), "f3.cast(LONG)", "1")
+    testTableApi('f6.cast(Types.BOOLEAN), "f6.cast(BOOLEAN)", "true")
+    // NUMERIC TYPE -> other BASIC TYPE (INT -> INT is already covered by identity casting)
+    testTableApi('f2.cast(Types.BYTE), "f2.cast(BYTE)", "1")
+    testTableApi('f2.cast(Types.SHORT), "f2.cast(SHORT)", "1")
+    testTableApi('f2.cast(Types.LONG), "f2.cast(LONG)", "1")
+    testTableApi('f3.cast(Types.DOUBLE), "f3.cast(DOUBLE)", "1.0")
+    testTableApi('f3.cast(Types.FLOAT), "f3.cast(FLOAT)", "1.0")
+    testTableApi('f5.cast(Types.BOOLEAN), "f5.cast(BOOLEAN)", "true")
+
+    // numeric auto cast in arithmetic
+    testTableApi('f0 + 1, "f0 + 1", "2")
+    testTableApi('f1 + 1, "f1 + 1", "2")
+    testTableApi('f2 + 1L, "f2 + 1L", "2")
+    testTableApi('f3 + 1.0f, "f3 + 1.0f", "2.0")
+    testTableApi('f3 + 1.0d, "f3 + 1.0d", "2.0")
+    testTableApi('f5 + 1, "f5 + 1", "2.0")
+    testTableApi('f4 + 'f0, "f4 + f0", "2.0")
+
+    // numeric auto cast in comparison
+    testTableApi(
+      'f0 > 0 && 'f1 > 0 && 'f2 > 0L && 'f4 > 0.0f && 'f5 > 0.0d && 'f3 > 0,
+      "f0 > 0 && f1 > 0 && f2 > 0L && f4 > 0.0f && f5 > 0.0d && f3 > 0",
+      "true")
+  }
+
+  @Test
+  def testArithmetic(): Unit = {
+    // math arithmetic
+    testTableApi('f8 - 5, "f8 - 5", "0")
+    testTableApi('f8 + 5, "f8 + 5", "10")
+    testTableApi('f8 / 2, "f8 / 2", "2")
+    testTableApi('f8 * 2, "f8 * 2", "10")
+    testTableApi('f8 % 2, "f8 % 2", "1")
+    testTableApi(-'f8, "-f8", "-5")
+    testTableApi(3.toExpr + 'f8, "3 + f8", "8")
+
+    // boolean arithmetic
+    testTableApi('f6 && true, "f6 && true", "true")
+    testTableApi('f6 && false, "f6 && false", "false")
+    testTableApi('f6 || false, "f6 || false", "true")
+    testTableApi(!'f6, "!f6", "false")
+
+    // comparison
+    testTableApi('f8 > 'f2, "f8 > f2", "true")
+    testTableApi('f8 >= 'f8, "f8 >= f8", "true")
+    testTableApi('f8 < 'f2, "f8 < f2", "false")
+    testTableApi('f8.isNull, "f8.isNull", "false")
+    testTableApi('f8.isNotNull, "f8.isNotNull", "true")
+    testTableApi(12.toExpr <= 'f8, "12 <= f8", "false")
+
+    // string arithmetic: '+' with a string operand concatenates
+    testTableApi(42.toExpr + 'f10 + 'f9, "42 + f10 + f9", "42String10")
+    testTableApi('f10 + 'f9, "f10 + f9", "String10")
+  }
+
+  @Test
+  def testOtherExpressions(): Unit = {
+    // null
+    testAllApis(Null(Types.INT), "Null(INT)", "CAST(NULL AS INT)", "null")
+    testAllApis(
+      Null(Types.STRING) === "",
+      "Null(STRING) === ''",
+      "CAST(NULL AS VARCHAR) = ''",
+      "null")
+
+    // if
+    testTableApi(('f6 && true).?("true", "false"), "(f6 && true).?('true', 'false')", "true")
+    testTableApi(false.?("true", "false"), "false.?('true', 'false')", "false")
+    testTableApi(
+      true.?(true.?(true.?(10, 4), 4), 4),
+      "true.?(true.?(true.?(10, 4), 4), 4)",
+      "10")
+    // prefix form of the Java expression syntax; the Scala side mirrors the same conditional
+    // (rather than a constant) so both APIs evaluate an equivalent expression
+    testTableApi(('f6 && true).?("true", "false"), "?((f6 && true), 'true', 'false')", "true")
+    testSqlApi("CASE 11 WHEN 1 THEN 'a' ELSE 'b' END", "b")
+    testSqlApi("CASE 2 WHEN 1 THEN 'a' ELSE 'b' END", "b")
+    // all branches are CHAR literals, so the result type is CHAR(17) (the length of
+    // 'none of the above') and the 6-character result comes back blank-padded
+    testSqlApi(
+      "CASE 1 WHEN 1, 2 THEN '1 or 2' WHEN 2 THEN 'not possible' WHEN 3, 2 THEN '3' " +
+      "ELSE 'none of the above' END",
+      "1 or 2" + " " * 11)
+    testSqlApi("CASE WHEN 'a'='a' THEN 1 END", "1")
+    testSqlApi("CASE 2 WHEN 1 THEN 'a' WHEN 2 THEN 'bcd' END", "bcd")
+    testSqlApi("CASE f2 WHEN 1 THEN 11 WHEN 2 THEN 4 ELSE NULL END", "11")
+    testSqlApi("CASE f7 WHEN 1 THEN 11 WHEN 2 THEN 4 ELSE NULL END", "null")
+    testSqlApi("CASE 42 WHEN 1 THEN 'a' WHEN 2 THEN 'bcd' END", "null")
+    testSqlApi("CASE 1 WHEN 1 THEN true WHEN 2 THEN false ELSE NULL END", "true")
+
+    // case insensitive as
+    testTableApi(5 as 'test, "5 As test", "5")
+
+    // complex expressions
+    testTableApi('f0.isNull.isNull, "f0.isNull().isNull", "false")
+    testTableApi(
+      'f8.abs() + 'f8.abs().abs().abs().abs(),
+      "f8.abs() + f8.abs().abs().abs().abs()",
+      "10")
+    testTableApi(
+      'f8.cast(Types.STRING) + 'f8.cast(Types.STRING),
+      "f8.cast(STRING) + f8.cast(STRING)",
+      "55")
+    testTableApi('f8.isNull.cast(Types.INT), "CAST(ISNULL(f8), INT)", "0")
+    testTableApi(
+      'f8.cast(Types.INT).abs().isNull === false,
+      "ISNULL(CAST(f8, INT).abs()) === false",
+      "true")
+    testTableApi(
+      (((true === true) || false).cast(Types.STRING) + "X ").trim(),
+      "((((true) === true) || false).cast(STRING) + 'X ').trim",
+      "trueX")
+    testTableApi(12.isNull, "12.isNull", "false")
+  }
+
+  /** Branches of `?` with incompatible types (INT vs. STRING) must be rejected (Scala API). */
+  @Test(expected = classOf[ValidationException])
+  def testIfInvalidTypesScala(): Unit = {
+    testTableApi(('f6 && true).?(5, "false"), "FAIL", "FAIL")
+  }
+
+  /** Branches of `?` with incompatible types must be rejected (Java string API). */
+  @Test(expected = classOf[ValidationException])
+  def testIfInvalidTypesJava(): Unit = {
+    testTableApi("FAIL", "(f8 && true).?(5, 'false')", "FAIL")
+  }
+
+  /** Equality between a string and an integer must be rejected. */
+  @Test(expected = classOf[ValidationException])
+  def testInvalidStringComparison1(): Unit = {
+    testTableApi("w" === 4, "FAIL", "FAIL")
+  }
+
+  /** Ordering comparison between a string and an integer must be rejected. */
+  @Test(expected = classOf[ValidationException])
+  def testInvalidStringComparison2(): Unit = {
+    testTableApi("w" > 4.toExpr, "FAIL", "FAIL")
+  }
+
+  // ----------------------------------------------------------------------------------------------
+
+  /** Single input row; field i is referenced as `fi` in the expressions above. */
+  override def testData: Any = {
+    val data = new Row(11)
+    data.setField(0, 1: Byte)
+    data.setField(1, 1: Short)
+    data.setField(2, 1)
+    data.setField(3, 1L)
+    data.setField(4, 1.0f)
+    data.setField(5, 1.0d)
+    data.setField(6, true)
+    data.setField(7, 0.0d)
+    data.setField(8, 5)
+    data.setField(9, 10)
+    data.setField(10, "String")
+    data
+  }
+
+  /** Field types of [[testData]], in field order. */
+  override def typeInfo: TypeInformation[Any] = {
+    new RowTypeInfo(
+      Types.BYTE,
+      Types.SHORT,
+      Types.INT,
+      Types.LONG,
+      Types.FLOAT,
+      Types.DOUBLE,
+      Types.BOOLEAN,
+      Types.DOUBLE,
+      Types.INT,
+      Types.INT,
+      Types.STRING
+    ).asInstanceOf[TypeInformation[Any]]
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/SqlExpressionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/SqlExpressionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/SqlExpressionTest.scala
new file mode 100644
index 0000000..e0f45d4
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/SqlExpressionTest.scala
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.expressions
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.types.Row
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.table.expressions.utils.ExpressionTestBase
+import org.junit.Test
+
+/**
+ * Tests all SQL expressions that are currently supported according to the documentation.
+ * These tests should be kept in sync with the documentation to reduce confusion due to the
+ * large number of SQL functions.
+ *
+ * The tests do not cover every parameter combination of a function.
+ * They are rather a function existence test and a simple functional test.
+ *
+ * The tests are split up and ordered like the sections in the documentation.
+ */
+class SqlExpressionTest extends ExpressionTestBase {
+
+  @Test
+  def testComparisonFunctions(): Unit = {
+    testSqlApi("1 = 1", "true")
+    testSqlApi("1 <> 1", "false")
+    testSqlApi("5 > 2", "true")
+    testSqlApi("2 >= 2", "true")
+    testSqlApi("5 < 2", "false")
+    testSqlApi("2 <= 2", "true")
+    testSqlApi("1 IS NULL", "false")
+    testSqlApi("1 IS NOT NULL", "true")
+    // NULLIF(1,1) yields a typed NULL
+    testSqlApi("NULLIF(1,1) IS DISTINCT FROM NULLIF(1,1)", "false")
+    testSqlApi("NULLIF(1,1) IS NOT DISTINCT FROM NULLIF(1,1)", "true")
+    testSqlApi("12 BETWEEN 11 AND 13", "true")
+    testSqlApi("12 BETWEEN ASYMMETRIC 13 AND 11", "false")
+    testSqlApi("12 BETWEEN SYMMETRIC 13 AND 11", "true")
+    testSqlApi("12 NOT BETWEEN 11 AND 13", "false")
+    testSqlApi("12 NOT BETWEEN ASYMMETRIC 13 AND 11", "true")
+    testSqlApi("12 NOT BETWEEN SYMMETRIC 13 AND 11", "false")
+    testSqlApi("'TEST' LIKE '%EST'", "true")
+    //testSqlApi("'%EST' LIKE '.%EST' ESCAPE '.'", "true") // TODO
+    testSqlApi("'TEST' NOT LIKE '%EST'", "false")
+    //testSqlApi("'%EST' NOT LIKE '.%EST' ESCAPE '.'", "false") // TODO
+    testSqlApi("'TEST' SIMILAR TO '.EST'", "true")
+    //testSqlApi("'TEST' SIMILAR TO ':.EST' ESCAPE ':'", "true") // TODO
+    testSqlApi("'TEST' NOT SIMILAR TO '.EST'", "false")
+    //testSqlApi("'TEST' NOT SIMILAR TO ':.EST' ESCAPE ':'", "false") // TODO
+    testSqlApi("'TEST' IN ('west', 'TEST', 'rest')", "true")
+    testSqlApi("'TEST' IN ('west', 'rest')", "false")
+    testSqlApi("'TEST' NOT IN ('west', 'TEST', 'rest')", "false")
+    testSqlApi("'TEST' NOT IN ('west', 'rest')", "true")
+
+    // sub-query functions are not listed here
+  }
+
+  @Test
+  def testLogicalFunctions(): Unit = {
+    testSqlApi("TRUE OR FALSE", "true")
+    testSqlApi("TRUE AND FALSE", "false")
+    testSqlApi("NOT TRUE", "false")
+    testSqlApi("TRUE IS FALSE", "false")
+    testSqlApi("TRUE IS NOT FALSE", "true")
+    testSqlApi("TRUE IS TRUE", "true")
+    testSqlApi("TRUE IS NOT TRUE", "false")
+    // NULLIF(TRUE,TRUE) yields a NULL boolean, i.e. UNKNOWN
+    testSqlApi("NULLIF(TRUE,TRUE) IS UNKNOWN", "true")
+    testSqlApi("NULLIF(TRUE,TRUE) IS NOT UNKNOWN", "false")
+  }
+
+  @Test
+  def testArithmeticFunctions(): Unit = {
+    testSqlApi("+5", "5")
+    testSqlApi("-5", "-5")
+    testSqlApi("5+5", "10")
+    testSqlApi("5-5", "0")
+    testSqlApi("5*5", "25")
+    testSqlApi("5/5", "1")
+    testSqlApi("POWER(5, 5)", "3125.0")
+    testSqlApi("ABS(-5)", "5")
+    testSqlApi("MOD(-26, 5)", "-1")
+    testSqlApi("SQRT(4)", "2.0")
+    testSqlApi("LN(1)", "0.0")
+    testSqlApi("LOG10(1)", "0.0")
+    testSqlApi("EXP(0)", "1.0")
+    testSqlApi("CEIL(2.5)", "3")
+    testSqlApi("FLOOR(2.5)", "2")
+  }
+
+  @Test
+  def testStringFunctions(): Unit = {
+    testSqlApi("'test' || 'string'", "teststring")
+    testSqlApi("CHAR_LENGTH('string')", "6")
+    testSqlApi("CHARACTER_LENGTH('string')", "6")
+    testSqlApi("UPPER('string')", "STRING")
+    testSqlApi("LOWER('STRING')", "string")
+    testSqlApi("POSITION('STR' IN 'STRING')", "1")
+    testSqlApi("TRIM(BOTH ' STRING ')", "STRING")
+    testSqlApi("TRIM(LEADING 'x' FROM 'xxxxSTRINGxxxx')", "STRINGxxxx")
+    testSqlApi("TRIM(TRAILING 'x' FROM 'xxxxSTRINGxxxx')", "xxxxSTRING")
+    testSqlApi(
+      "OVERLAY('This is a old string' PLACING 'new' FROM 11 FOR 3)",
+      "This is a new string")
+    testSqlApi("SUBSTRING('hello world', 2)", "ello world")
+    testSqlApi("SUBSTRING('hello world', 2, 3)", "ell")
+    testSqlApi("INITCAP('hello world')", "Hello World")
+  }
+
+  @Test
+  def testConditionalFunctions(): Unit = {
+    testSqlApi("CASE 2 WHEN 1, 2 THEN 2 ELSE 3 END", "2")
+    testSqlApi("CASE WHEN 1 = 2 THEN 2 WHEN 1 = 1 THEN 3 ELSE 3 END", "3")
+    testSqlApi("NULLIF(1, 1)", "null")
+    testSqlApi("COALESCE(NULL, 5)", "5")
+  }
+
+  @Test
+  def testTypeConversionFunctions(): Unit = {
+    testSqlApi("CAST(2 AS DOUBLE)", "2.0")
+  }
+
+  @Test
+  def testValueConstructorFunctions(): Unit = {
+    // TODO we need a special code path that flattens ROW types
+    // testSqlApi("ROW('hello world', 12)", "hello world") // test base only returns field 0
+    // testSqlApi("('hello world', 12)", "hello world") // test base only returns field 0
+    testSqlApi("ARRAY[TRUE, FALSE][2]", "false")
+    testSqlApi("ARRAY[TRUE, TRUE]", "[true, true]")
+  }
+
+  @Test
+  def testDateTimeFunctions(): Unit = {
+    testSqlApi("DATE '1990-10-14'", "1990-10-14")
+    testSqlApi("TIME '12:12:12'", "12:12:12")
+    testSqlApi("TIMESTAMP '1990-10-14 12:12:12.123'", "1990-10-14 12:12:12.123")
+    testSqlApi("INTERVAL '10 00:00:00.004' DAY TO SECOND", "+10 00:00:00.004")
+    testSqlApi("INTERVAL '10 00:12' DAY TO MINUTE", "+10 00:12:00.000")
+    testSqlApi("INTERVAL '2-10' YEAR TO MONTH", "+2-10")
+    testSqlApi("EXTRACT(DAY FROM DATE '1990-12-01')", "1")
+    testSqlApi("EXTRACT(DAY FROM INTERVAL '19 12:10:10.123' DAY TO SECOND(3))", "19")
+    testSqlApi("QUARTER(DATE '2016-04-12')", "2")
+  }
+
+  @Test
+  def testArrayFunctions(): Unit = {
+    testSqlApi("CARDINALITY(ARRAY[TRUE, TRUE, FALSE])", "3")
+    testSqlApi("ELEMENT(ARRAY['HELLO WORLD'])", "HELLO WORLD")
+  }
+
+  // ----------------------------------------------------------------------------------------------
+
+  /** Empty input row: every expression in this class is built from literals only. */
+  override def testData: Any = new Row(0)
+
+  /** Row type with no fields, matching [[testData]]. */
+  override def typeInfo: TypeInformation[Any] =
+    new RowTypeInfo().asInstanceOf[TypeInformation[Any]]
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/TemporalTypesTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/TemporalTypesTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/TemporalTypesTest.scala
new file mode 100644
index 0000000..840bec1
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/TemporalTypesTest.scala
@@ -0,0 +1,573 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.expressions
+
+import java.sql.{Date, Time, Timestamp}
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.types.Row
+import org.apache.flink.table.api.Types
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.expressions.utils.ExpressionTestBase
+import org.junit.Test
+
+/**
+ * Tests temporal types: DATE, TIME and TIMESTAMP literals, year-month (INTERVAL_MONTHS) and
+ * day-time (INTERVAL_MILLIS) interval literals, reading them from input fields, casting between
+ * temporal and numeric types, comparison, and interval arithmetic.
+ *
+ * Field references `f0`..`f10` refer to the row built in [[testData]], whose field types are
+ * declared in [[typeInfo]].
+ */
+class TemporalTypesTest extends ExpressionTestBase {
+
+  @Test
+  def testTimePointLiterals(): Unit = {
+    testAllApis(
+      "1990-10-14".toDate,
+      "'1990-10-14'.toDate",
+      "DATE '1990-10-14'",
+      "1990-10-14")
+
+    testTableApi(
+      Date.valueOf("2040-09-11"),
+      "'2040-09-11'.toDate",
+      "2040-09-11")
+
+    testAllApis(
+      "1500-04-30".cast(Types.DATE),
+      "'1500-04-30'.cast(DATE)",
+      "CAST('1500-04-30' AS DATE)",
+      "1500-04-30")
+
+    testAllApis(
+      "15:45:59".toTime,
+      "'15:45:59'.toTime",
+      "TIME '15:45:59'",
+      "15:45:59")
+
+    testTableApi(
+      Time.valueOf("00:00:00"),
+      "'00:00:00'.toTime",
+      "00:00:00")
+
+    testAllApis(
+      "1:30:00".cast(Types.TIME),
+      "'1:30:00'.cast(TIME)",
+      "CAST('1:30:00' AS TIME)",
+      "01:30:00")
+
+    testAllApis(
+      "1990-10-14 23:00:00.123".toTimestamp,
+      "'1990-10-14 23:00:00.123'.toTimestamp",
+      "TIMESTAMP '1990-10-14 23:00:00.123'",
+      "1990-10-14 23:00:00.123")
+
+    testTableApi(
+      Timestamp.valueOf("2040-09-11 00:00:00.000"),
+      "'2040-09-11 00:00:00.000'.toTimestamp",
+      "2040-09-11 00:00:00.0")
+
+    testAllApis(
+      "1500-04-30 12:00:00".cast(Types.TIMESTAMP),
+      "'1500-04-30 12:00:00'.cast(TIMESTAMP)",
+      "CAST('1500-04-30 12:00:00' AS TIMESTAMP)",
+      "1500-04-30 12:00:00.0")
+  }
+
+  @Test
+  def testTimeIntervalLiterals(): Unit = {
+    testAllApis(
+      1.year,
+      "1.year",
+      "INTERVAL '1' YEAR",
+      "+1-00")
+
+    testAllApis(
+      1.month,
+      "1.month",
+      "INTERVAL '1' MONTH",
+      "+0-01")
+
+    testAllApis(
+      12.days,
+      "12.days",
+      "INTERVAL '12' DAY",
+      "+12 00:00:00.000")
+
+    testAllApis(
+      1.hour,
+      "1.hour",
+      "INTERVAL '1' HOUR",
+      "+0 01:00:00.000")
+
+    testAllApis(
+      3.minutes,
+      "3.minutes",
+      "INTERVAL '3' MINUTE",
+      "+0 00:03:00.000")
+
+    testAllApis(
+      3.seconds,
+      "3.seconds",
+      "INTERVAL '3' SECOND",
+      "+0 00:00:03.000")
+
+    testAllApis(
+      3.millis,
+      "3.millis",
+      "INTERVAL '0.003' SECOND",
+      "+0 00:00:00.003")
+  }
+
+  @Test
+  def testTimePointInput(): Unit = {
+    testAllApis(
+      'f0,
+      "f0",
+      "f0",
+      "1990-10-14")
+
+    testAllApis(
+      'f1,
+      "f1",
+      "f1",
+      "10:20:45")
+
+    testAllApis(
+      'f2,
+      "f2",
+      "f2",
+      "1990-10-14 10:20:45.123")
+  }
+
+  @Test
+  def testTimeIntervalInput(): Unit = {
+    testAllApis(
+      'f9,
+      "f9",
+      "f9",
+      "+2-00")
+
+    testAllApis(
+      'f10,
+      "f10",
+      "f10",
+      "+0 00:00:12.000")
+  }
+
+  @Test
+  def testTimePointCasting(): Unit = {
+    testAllApis(
+      'f0.cast(Types.TIMESTAMP),
+      "f0.cast(TIMESTAMP)",
+      "CAST(f0 AS TIMESTAMP)",
+      "1990-10-14 00:00:00.0")
+
+    testAllApis(
+      'f1.cast(Types.TIMESTAMP),
+      "f1.cast(TIMESTAMP)",
+      "CAST(f1 AS TIMESTAMP)",
+      "1970-01-01 10:20:45.0")
+
+    testAllApis(
+      'f2.cast(Types.DATE),
+      "f2.cast(DATE)",
+      "CAST(f2 AS DATE)",
+      "1990-10-14")
+
+    testAllApis(
+      'f2.cast(Types.TIME),
+      "f2.cast(TIME)",
+      "CAST(f2 AS TIME)",
+      "10:20:45")
+
+    // INT -> DATE interprets the value as days since 1970-01-01
+    testTableApi(
+      'f7.cast(Types.DATE),
+      "f7.cast(DATE)",
+      "2002-11-09")
+
+    testTableApi(
+      'f7.cast(Types.DATE).cast(Types.INT),
+      "f7.cast(DATE).cast(INT)",
+      "12000")
+
+    // INT -> TIME interprets the value as milliseconds of the day
+    testTableApi(
+      'f7.cast(Types.TIME),
+      "f7.cast(TIME)",
+      "00:00:12")
+
+    testTableApi(
+      'f7.cast(Types.TIME).cast(Types.INT),
+      "f7.cast(TIME).cast(INT)",
+      "12000")
+
+    // LONG -> TIMESTAMP interprets the value as epoch milliseconds
+    testTableApi(
+      'f8.cast(Types.TIMESTAMP),
+      "f8.cast(TIMESTAMP)",
+      "2016-06-27 07:23:33.0")
+
+    testTableApi(
+      'f8.cast(Types.TIMESTAMP).cast(Types.LONG),
+      "f8.cast(TIMESTAMP).cast(LONG)",
+      "1467012213000")
+  }
+
+  @Test
+  def testTimeIntervalCasting(): Unit = {
+    testTableApi(
+      'f7.cast(Types.INTERVAL_MONTHS),
+      "f7.cast(INTERVAL_MONTHS)",
+      "+1000-00")
+
+    testTableApi(
+      'f8.cast(Types.INTERVAL_MILLIS),
+      "f8.cast(INTERVAL_MILLIS)",
+      "+16979 07:23:33.000")
+  }
+
+  @Test
+  def testTimePointComparison(): Unit = {
+    testAllApis(
+      'f0 < 'f3,
+      "f0 < f3",
+      "f0 < f3",
+      "false")
+
+    testAllApis(
+      'f0 < 'f4,
+      "f0 < f4",
+      "f0 < f4",
+      "true")
+
+    testAllApis(
+      'f1 < 'f5,
+      "f1 < f5",
+      "f1 < f5",
+      "false")
+
+    testAllApis(
+      'f0.cast(Types.TIMESTAMP) !== 'f2,
+      "f0.cast(TIMESTAMP) !== f2",
+      "CAST(f0 AS TIMESTAMP) <> f2",
+      "true")
+
+    testAllApis(
+      'f0.cast(Types.TIMESTAMP) === 'f6,
+      "f0.cast(TIMESTAMP) === f6",
+      "CAST(f0 AS TIMESTAMP) = f6",
+      "true")
+  }
+
+  @Test
+  def testTimeIntervalArithmetic(): Unit = {
+
+    // interval months comparison
+
+    testAllApis(
+      12.months < 24.months,
+      "12.months < 24.months",
+      "INTERVAL '12' MONTH < INTERVAL '24' MONTH",
+      "true")
+
+    testAllApis(
+      8.years === 8.years,
+      "8.years === 8.years",
+      "INTERVAL '8' YEAR = INTERVAL '8' YEAR",
+      "true")
+
+    // interval millis comparison
+
+    testAllApis(
+      8.millis > 10.millis,
+      "8.millis > 10.millis",
+      "INTERVAL '0.008' SECOND > INTERVAL '0.010' SECOND",
+      "false")
+
+    testAllApis(
+      8.millis === 8.millis,
+      "8.millis === 8.millis",
+      "INTERVAL '0.008' SECOND = INTERVAL '0.008' SECOND",
+      "true")
+
+    // interval months addition/subtraction
+
+    testAllApis(
+      8.years + 10.months,
+      "8.years + 10.months",
+      "INTERVAL '8' YEAR + INTERVAL '10' MONTH",
+      "+8-10")
+
+    testAllApis(
+      2.years - 12.months,
+      "2.years - 12.months",
+      "INTERVAL '2' YEAR - INTERVAL '12' MONTH",
+      "+1-00")
+
+    testAllApis(
+      -2.years,
+      "-2.years",
+      "-INTERVAL '2' YEAR",
+      "-2-00")
+
+    // interval millis addition/subtraction
+
+    testAllApis(
+      8.hours + 10.minutes + 12.seconds + 5.millis,
+      "8.hours + 10.minutes + 12.seconds + 5.millis",
+      "INTERVAL '8' HOUR + INTERVAL '10' MINUTE + INTERVAL '12.005' SECOND",
+      "+0 08:10:12.005")
+
+    testAllApis(
+      1.minute - 10.seconds,
+      "1.minute - 10.seconds",
+      "INTERVAL '1' MINUTE - INTERVAL '10' SECOND",
+      "+0 00:00:50.000")
+
+    testAllApis(
+      -10.seconds,
+      "-10.seconds",
+      "-INTERVAL '10' SECOND",
+      "-0 00:00:10.000")
+
+    // addition to date
+
+    // interval millis
+    testAllApis(
+      'f0 + 2.days,
+      "f0 + 2.days",
+      "f0 + INTERVAL '2' DAY",
+      "1990-10-16")
+
+    // interval millis
+    testAllApis(
+      30.days + 'f0,
+      "30.days + f0",
+      "INTERVAL '30' DAY + f0",
+      "1990-11-13")
+
+    // interval months
+    testAllApis(
+      'f0 + 2.months,
+      "f0 + 2.months",
+      "f0 + INTERVAL '2' MONTH",
+      "1990-12-14")
+
+    // interval months
+    testAllApis(
+      2.months + 'f0,
+      "2.months + f0",
+      "INTERVAL '2' MONTH + f0",
+      "1990-12-14")
+
+    // addition to time (the result wraps around within the day)
+
+    // interval millis
+    testAllApis(
+      'f1 + 12.hours,
+      "f1 + 12.hours",
+      "f1 + INTERVAL '12' HOUR",
+      "22:20:45")
+
+    // interval millis
+    testAllApis(
+      12.hours + 'f1,
+      "12.hours + f1",
+      "INTERVAL '12' HOUR + f1",
+      "22:20:45")
+
+    // addition to timestamp
+
+    // interval millis
+    testAllApis(
+      'f2 + 10.days + 4.millis,
+      "f2 + 10.days + 4.millis",
+      "f2 + INTERVAL '10 00:00:00.004' DAY TO SECOND",
+      "1990-10-24 10:20:45.127")
+
+    // interval millis
+    testAllApis(
+      10.days + 'f2 + 4.millis,
+      "10.days + f2 + 4.millis",
+      "INTERVAL '10 00:00:00.004' DAY TO SECOND + f2",
+      "1990-10-24 10:20:45.127")
+
+    // interval months
+    testAllApis(
+      'f2 + 10.years,
+      "f2 + 10.years",
+      "f2 + INTERVAL '10' YEAR",
+      "2000-10-14 10:20:45.123")
+
+    // interval months
+    testAllApis(
+      10.years + 'f2,
+      "10.years + f2",
+      "INTERVAL '10' YEAR + f2",
+      "2000-10-14 10:20:45.123")
+
+    // subtraction from date
+
+    // interval millis
+    testAllApis(
+      'f0 - 2.days,
+      "f0 - 2.days",
+      "f0 - INTERVAL '2' DAY",
+      "1990-10-12")
+
+    // interval millis
+    testAllApis(
+      -30.days + 'f0,
+      "-30.days + f0",
+      "INTERVAL '-30' DAY + f0",
+      "1990-09-14")
+
+    // interval months
+    testAllApis(
+      'f0 - 2.months,
+      "f0 - 2.months",
+      "f0 - INTERVAL '2' MONTH",
+      "1990-08-14")
+
+    // interval months
+    testAllApis(
+      -2.months + 'f0,
+      "-2.months + f0",
+      "-INTERVAL '2' MONTH + f0",
+      "1990-08-14")
+
+    // subtraction from time (the result wraps around within the day)
+
+    // interval millis
+    testAllApis(
+      'f1 - 12.hours,
+      "f1 - 12.hours",
+      "f1 - INTERVAL '12' HOUR",
+      "22:20:45")
+
+    // interval millis
+    testAllApis(
+      -12.hours + 'f1,
+      "-12.hours + f1",
+      "INTERVAL '-12' HOUR + f1",
+      "22:20:45")
+
+    // subtraction from timestamp
+
+    // interval millis
+    testAllApis(
+      'f2 - 10.days - 4.millis,
+      "f2 - 10.days - 4.millis",
+      "f2 - INTERVAL '10 00:00:00.004' DAY TO SECOND",
+      "1990-10-04 10:20:45.119")
+
+    // interval millis
+    testAllApis(
+      -10.days + 'f2 - 4.millis,
+      "-10.days + f2 - 4.millis",
+      "INTERVAL '-10 00:00:00.004' DAY TO SECOND + f2",
+      "1990-10-04 10:20:45.119")
+
+    // interval months
+    testAllApis(
+      'f2 - 10.years,
+      "f2 - 10.years",
+      "f2 - INTERVAL '10' YEAR",
+      "1980-10-14 10:20:45.123")
+
+    // interval months
+    testAllApis(
+      -10.years + 'f2,
+      "-10.years + f2",
+      "INTERVAL '-10' YEAR + f2",
+      "1980-10-14 10:20:45.123")
+
+    // casting
+
+    testAllApis(
+      -'f9.cast(Types.INTERVAL_MONTHS),
+      "-f9.cast(INTERVAL_MONTHS)",
+      "-CAST(f9 AS INTERVAL YEAR)",
+      "-2-00")
+
+    testAllApis(
+      -'f10.cast(Types.INTERVAL_MILLIS),
+      "-f10.cast(INTERVAL_MILLIS)",
+      "-CAST(f10 AS INTERVAL SECOND)",
+      "-0 00:00:12.000")
+
+    // addition/subtraction of interval millis and interval months
+
+    testAllApis(
+      'f0 + 2.days + 1.month,
+      "f0 + 2.days + 1.month",
+      "f0 + INTERVAL '2' DAY + INTERVAL '1' MONTH",
+      "1990-11-16")
+
+    testAllApis(
+      'f0 - 2.days - 1.month,
+      "f0 - 2.days - 1.month",
+      "f0 - INTERVAL '2' DAY - INTERVAL '1' MONTH",
+      "1990-09-12")
+
+    testAllApis(
+      'f2 + 2.days + 1.month,
+      "f2 + 2.days + 1.month",
+      "f2 + INTERVAL '2' DAY + INTERVAL '1' MONTH",
+      "1990-11-16 10:20:45.123")
+
+    testAllApis(
+      'f2 - 2.days - 1.month,
+      "f2 - 2.days - 1.month",
+      "f2 - INTERVAL '2' DAY - INTERVAL '1' MONTH",
+      "1990-09-12 10:20:45.123")
+  }
+
+  // ----------------------------------------------------------------------------------------------
+
+  /** Single input row; field i is referenced as `fi` in the expressions above. */
+  override def testData: Any = {
+    val data = new Row(11)
+    data.setField(0, Date.valueOf("1990-10-14"))
+    data.setField(1, Time.valueOf("10:20:45"))
+    data.setField(2, Timestamp.valueOf("1990-10-14 10:20:45.123"))
+    // one day before / after f0, for comparisons
+    data.setField(3, Date.valueOf("1990-10-13"))
+    data.setField(4, Date.valueOf("1990-10-15"))
+    data.setField(5, Time.valueOf("00:00:00"))
+    // midnight of f0's day, equal to f0 cast to TIMESTAMP
+    data.setField(6, Timestamp.valueOf("1990-10-14 00:00:00.0"))
+    // read as days when cast to DATE, as milliseconds when cast to TIME, as months for intervals
+    data.setField(7, 12000)
+    // epoch milliseconds (2016-06-27 07:23:33 per the casting tests)
+    data.setField(8, 1467012213000L)
+    // 24 months, rendered as "+2-00"
+    data.setField(9, 24)
+    // 12000 milliseconds, rendered as "+0 00:00:12.000"
+    data.setField(10, 12000L)
+    data
+  }
+
+  /** Field types of [[testData]], in field order. */
+  override def typeInfo: TypeInformation[Any] = {
+    new RowTypeInfo(
+      Types.DATE,
+      Types.TIME,
+      Types.TIMESTAMP,
+      Types.DATE,
+      Types.DATE,
+      Types.TIME,
+      Types.TIMESTAMP,
+      Types.INT,
+      Types.LONG,
+      Types.INTERVAL_MONTHS,
+      Types.INTERVAL_MILLIS).asInstanceOf[TypeInformation[Any]]
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
new file mode 100644
index 0000000..26bbd44
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.expressions
+
+import java.sql.{Date, Time, Timestamp}
+
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.types.Row
+import org.apache.flink.table.api.Types
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.expressions.utils._
+import org.apache.flink.table.functions.ScalarFunction
+import org.junit.Test
+
+/**
+ * Tests user-defined scalar functions (`Func0`..`Func12`, imported from `expressions.utils`).
+ * The functions are called through the Scala DSL, through Java string expressions and, because
+ * they are registered by name in [[functions]], through SQL.
+ *
+ * Covered areas are parameter passing, null arguments, nesting, overload resolution, and
+ * temporal/interval arguments.
+ */
+class UserDefinedScalarFunctionTest extends ExpressionTestBase {
+
+  @Test
+  def testParameters(): Unit = {
+    testAllApis(
+      Func0('f0),
+      "Func0(f0)",
+      "Func0(f0)",
+      "42")
+
+    testAllApis(
+      Func1('f0),
+      "Func1(f0)",
+      "Func1(f0)",
+      "43")
+
+    testAllApis(
+      Func2('f0, 'f1, 'f3),
+      "Func2(f0, f1, f3)",
+      "Func2(f0, f1, f3)",
+      "42 and Test and SimplePojo(Bob,36)")
+
+    testAllApis(
+      Func0(123),
+      "Func0(123)",
+      "Func0(123)",
+      "123")
+
+    testAllApis(
+      Func6('f4, 'f5, 'f6),
+      "Func6(f4, f5, f6)",
+      "Func6(f4, f5, f6)",
+      "(1990-10-14,12:10:10,1990-10-14 12:10:10.0)")
+  }
+
+  @Test
+  def testNullableParameters(): Unit = {
+    testAllApis(
+      Func3(Null(INT_TYPE_INFO), Null(STRING_TYPE_INFO)),
+      "Func3(Null(INT), Null(STRING))",
+      "Func3(NULL, NULL)",
+      "null and null")
+
+    testAllApis(
+      Func3(Null(INT_TYPE_INFO), "Test"),
+      "Func3(Null(INT), 'Test')",
+      "Func3(NULL, 'Test')",
+      "null and Test")
+
+    testAllApis(
+      Func3(42, Null(STRING_TYPE_INFO)),
+      "Func3(42, Null(STRING))",
+      "Func3(42, NULL)",
+      "42 and null")
+
+    testAllApis(
+      Func0(Null(INT_TYPE_INFO)),
+      "Func0(Null(INT))",
+      "Func0(NULL)",
+      "-1")
+  }
+
+  @Test
+  def testResults(): Unit = {
+    testAllApis(
+      Func4(),
+      "Func4()",
+      "Func4()",
+      "null")
+
+    testAllApis(
+      Func5(),
+      "Func5()",
+      "Func5()",
+      "-1")
+  }
+
+  @Test
+  def testNesting(): Unit = {
+    testAllApis(
+      Func0(Func0('f0)),
+      "Func0(Func0(f0))",
+      "Func0(Func0(f0))",
+      "42")
+
+    testAllApis(
+      Func7(Func7(Func7(1, 1), Func7(1, 1)), Func7(Func7(1, 1), Func7(1, 1))),
+      "Func7(Func7(Func7(1, 1), Func7(1, 1)), Func7(Func7(1, 1), Func7(1, 1)))",
+      "Func7(Func7(Func7(1, 1), Func7(1, 1)), Func7(Func7(1, 1), Func7(1, 1)))",
+      "8")
+  }
+
+  @Test
+  def testOverloadedParameters(): Unit = {
+    testAllApis(
+      Func8(1),
+      "Func8(1)",
+      "Func8(1)",
+      "a")
+
+    testAllApis(
+      Func8(1, 1),
+      "Func8(1, 1)",
+      "Func8(1, 1)",
+      "b")
+
+    testAllApis(
+      Func8("a", "a"),
+      "Func8('a', 'a')",
+      "Func8('a', 'a')",
+      "c")
+  }
+
+  @Test
+  def testTimePointsOnPrimitives(): Unit = {
+    testAllApis(
+      Func9('f4, 'f5, 'f6),
+      "Func9(f4, f5, f6)",
+      "Func9(f4, f5, f6)",
+      "7591 and 43810000 and 655906210000")
+
+    testAllApis(
+      Func10('f6),
+      "Func10(f6)",
+      "Func10(f6)",
+      "1990-10-14 12:10:10.0")
+  }
+
+  @Test
+  def testTimeIntervalsOnPrimitives(): Unit = {
+    testAllApis(
+      Func11('f7, 'f8),
+      "Func11(f7, f8)",
+      "Func11(f7, f8)",
+      "12 and 1000")
+
+    testAllApis(
+      Func12('f8),
+      "Func12(f8)",
+      "Func12(f8)",
+      "+0 00:00:01.000")
+  }
+
+  // ----------------------------------------------------------------------------------------------
+
+  /** Single input row; field i is referenced as `fi` in the expressions above. */
+  override def testData: Any = {
+    val data = new Row(9)
+    data.setField(0, 42)
+    data.setField(1, "Test")
+    data.setField(2, null)
+    data.setField(3, SimplePojo("Bob", 36))
+    data.setField(4, Date.valueOf("1990-10-14"))
+    data.setField(5, Time.valueOf("12:10:10"))
+    data.setField(6, Timestamp.valueOf("1990-10-14 12:10:10"))
+    // 12 months (INTERVAL_MONTHS)
+    data.setField(7, 12)
+    // 1000 milliseconds (INTERVAL_MILLIS)
+    data.setField(8, 1000L)
+    data
+  }
+
+  /** Field types of [[testData]], in field order. */
+  override def typeInfo: TypeInformation[Any] = {
+    new RowTypeInfo(
+      Types.INT,
+      Types.STRING,
+      Types.BOOLEAN,
+      TypeInformation.of(classOf[SimplePojo]),
+      Types.DATE,
+      Types.TIME,
+      Types.TIMESTAMP,
+      Types.INTERVAL_MONTHS,
+      Types.INTERVAL_MILLIS
+    ).asInstanceOf[TypeInformation[Any]]
+  }
+
+  /** Registers every tested function under its class name so SQL and string APIs can call it. */
+  override def functions: Map[String, ScalarFunction] = Map(
+    "Func0" -> Func0,
+    "Func1" -> Func1,
+    "Func2" -> Func2,
+    "Func3" -> Func3,
+    "Func4" -> Func4,
+    "Func5" -> Func5,
+    "Func6" -> Func6,
+    "Func7" -> Func7,
+    "Func8" -> Func8,
+    "Func9" -> Func9,
+    "Func10" -> Func10,
+    "Func11" -> Func11,
+    "Func12" -> Func12
+  )
+}
+
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala
new file mode 100644
index 0000000..8555632
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.expressions.utils
+
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rex.RexNode
+import org.apache.calcite.sql.`type`.SqlTypeName._
+import org.apache.calcite.sql2rel.RelDecorrelator
+import org.apache.calcite.tools.{Programs, RelBuilder}
+import org.apache.flink.api.common.functions.{Function, MapFunction}
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.{DataSet => JDataSet}
+import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment}
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.types.Row
+import org.apache.flink.table.api.{BatchTableEnvironment, TableConfig, TableEnvironment}
+import org.apache.flink.table.calcite.FlinkPlannerImpl
+import org.apache.flink.table.codegen.{CodeGenerator, Compiler, GeneratedFunction}
+import org.apache.flink.table.expressions.{Expression, ExpressionParser}
+import org.apache.flink.table.functions.ScalarFunction
+import org.apache.flink.table.plan.nodes.dataset.{DataSetCalc, DataSetConvention}
+import org.apache.flink.table.plan.rules.FlinkRuleSets
+import org.junit.Assert._
+import org.junit.{After, Before}
+import org.mockito.Mockito._
+
+import scala.collection.mutable
+
+/**
+  * Base test class for expression tests.
+  *
+  * Test methods register expressions via [[testAllApis]], [[testTableApi]] or [[testSqlApi]].
+  * After each test, [[evaluateExprs]] generates code for all registered expressions, compiles
+  * it, applies it to [[testData]] (typed by [[typeInfo]]) and compares every result, cast to
+  * a string, with the expected string.
+  */
+abstract class ExpressionTestBase {
+
+  // (expression, expected string result) pairs registered by the current test method
+  private val testExprs = mutable.ArrayBuffer[(RexNode, String)]()
+
+  // setup test utils
+  private val tableName = "testTable"
+  // NOTE: prepareContext() reads the abstract members `typeInfo` and `functions` while this
+  // base class is still being constructed. Subclasses should implement them as `def`s; a
+  // subclass `val` would not have been initialized yet at this point.
+  private val context = prepareContext(typeInfo)
+  private val planner = new FlinkPlannerImpl(
+    context._2.getFrameworkConfig,
+    context._2.getPlanner,
+    context._2.getTypeFactory)
+  private val optProgram = Programs.ofRules(FlinkRuleSets.DATASET_OPT_RULES)
+
+  /**
+    * Registers a mocked DataSet of the given type under [[tableName]], plus all [[functions]],
+    * in a TableEnvironment created for an ExecutionEnvironment.
+    *
+    * @return a RelBuilder on which a scan of the test table has been pushed, and the
+    *         TableEnvironment itself
+    */
+  private def prepareContext(typeInfo: TypeInformation[Any]): (RelBuilder, TableEnvironment) = {
+    // create DataSetTable
+    val dataSetMock = mock(classOf[DataSet[Any]])
+    val jDataSetMock = mock(classOf[JDataSet[Any]])
+    when(dataSetMock.javaSet).thenReturn(jDataSetMock)
+    when(jDataSetMock.getType).thenReturn(typeInfo)
+
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    tEnv.registerDataSet(tableName, dataSetMock)
+    functions.foreach { case (name, function) => tEnv.registerFunction(name, function) }
+
+    // prepare RelBuilder
+    val relBuilder = tEnv.getRelBuilder
+    relBuilder.scan(tableName)
+
+    (relBuilder, tEnv)
+  }
+
+  /** The single input record all expressions are evaluated against. */
+  def testData: Any
+
+  /** Type of [[testData]]. */
+  def typeInfo: TypeInformation[Any]
+
+  /** User-defined scalar functions to register, keyed by name. None by default. */
+  def functions: Map[String, ScalarFunction] = Map()
+
+  @Before
+  def resetTestExprs(): Unit = {
+    testExprs.clear()
+  }
+
+  /**
+    * Compiles all expressions registered by the current test into one generated
+    * MapFunction, applies it to [[testData]] and asserts that each result (cast to VARCHAR,
+    * with null rendered as "null") equals its expected string.
+    */
+  @After
+  def evaluateExprs(): Unit = {
+    val relBuilder = context._1
+    val config = new TableConfig()
+    val generator = new CodeGenerator(config, false, typeInfo)
+
+    // cast expressions to String
+    val stringTestExprs = testExprs.map { case (expr, _) => relBuilder.cast(expr, VARCHAR) }.toSeq
+
+    // generate code: one String field per registered expression
+    val resultType = new RowTypeInfo(Seq.fill(testExprs.size)(STRING_TYPE_INFO): _*)
+    val genExpr = generator.generateResultExpression(
+      resultType,
+      resultType.getFieldNames,
+      stringTestExprs)
+
+    val bodyCode =
+      s"""
+        |${genExpr.code}
+        |return ${genExpr.resultTerm};
+        |""".stripMargin
+
+    val genFunc = generator.generateFunction[MapFunction[Any, String]](
+      "TestFunction",
+      classOf[MapFunction[Any, String]],
+      bodyCode,
+      resultType.asInstanceOf[TypeInformation[Any]])
+
+    // compile and evaluate
+    val clazz = new TestCompiler[MapFunction[Any, String]]().compile(genFunc)
+    val mapper = clazz.newInstance()
+    val result = mapper.map(testData).asInstanceOf[Row]
+
+    // compare
+    testExprs
+      .zipWithIndex
+      .foreach {
+        case ((expr, expected), index) =>
+          val actual = result.getField(index)
+          assertEquals(
+            s"Wrong result for: $expr",
+            expected,
+            if (actual == null) "null" else actual)
+      }
+  }
+
+  /**
+    * Optimizes a logical plan into a [[DataSetCalc]] and records its first projected
+    * expression (local references expanded) together with the expected result.
+    * Shared by the SQL and Table API entry points.
+    */
+  private def addOptimizedTestExpr(converted: RelNode, expected: String): Unit = {
+    // create DataSetCalc
+    val decorPlan = RelDecorrelator.decorrelateQuery(converted)
+    val flinkOutputProps = converted.getTraitSet.replace(DataSetConvention.INSTANCE).simplify()
+    val dataSetCalc = optProgram.run(context._2.getPlanner, decorPlan, flinkOutputProps)
+
+    // extract RexNode
+    val calcProgram = dataSetCalc
+      .asInstanceOf[DataSetCalc]
+      .calcProgram
+    val expanded = calcProgram.expandLocalRef(calcProgram.getProjectList.get(0))
+
+    testExprs += ((expanded, expected))
+  }
+
+  private def addSqlTestExpr(sqlExpr: String, expected: String): Unit = {
+    // create RelNode from SQL expression
+    val parsed = planner.parse(s"SELECT $sqlExpr FROM $tableName")
+    val validated = planner.validate(parsed)
+    val converted = planner.rel(validated).rel
+
+    addOptimizedTestExpr(converted, expected)
+  }
+
+  private def addTableApiTestExpr(tableApiExpr: Expression, expected: String): Unit = {
+    // create RelNode from Table API expression
+    val converted = context._2
+      .asInstanceOf[BatchTableEnvironment]
+      .scan(tableName)
+      .select(tableApiExpr)
+      .getRelNode
+
+    addOptimizedTestExpr(converted, expected)
+  }
+
+  private def addTableApiTestExpr(tableApiString: String, expected: String): Unit = {
+    addTableApiTestExpr(ExpressionParser.parseExpression(tableApiString), expected)
+  }
+
+  /**
+    * Registers the same expression in Scala Table API, string Table API and SQL form.
+    *
+    * @param expr      Scala Table API expression
+    * @param exprString string Table API expression (parsed with ExpressionParser)
+    * @param sqlExpr   SQL expression, placed in `SELECT <sqlExpr> FROM testTable`
+    * @param expected  expected result as a string
+    */
+  def testAllApis(
+      expr: Expression,
+      exprString: String,
+      sqlExpr: String,
+      expected: String)
+    : Unit = {
+    addTableApiTestExpr(expr, expected)
+    addTableApiTestExpr(exprString, expected)
+    addSqlTestExpr(sqlExpr, expected)
+  }
+
+  /** Registers an expression in Scala Table API and string Table API form only. */
+  def testTableApi(
+      expr: Expression,
+      exprString: String,
+      expected: String)
+    : Unit = {
+    addTableApiTestExpr(expr, expected)
+    addTableApiTestExpr(exprString, expected)
+  }
+
+  /** Registers an expression in SQL form only. */
+  def testSqlApi(
+      sqlExpr: String,
+      expected: String)
+    : Unit = {
+    addSqlTestExpr(sqlExpr, expected)
+  }
+
+  // ----------------------------------------------------------------------------------------------
+
+  // TestCompiler that uses current class loader
+  class TestCompiler[T <: Function] extends Compiler[T] {
+    def compile(genFunc: GeneratedFunction[T]): Class[T] =
+      compile(getClass.getClassLoader, genFunc.name, genFunc.code)
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
new file mode 100644
index 0000000..4e9b6d3
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.expressions.utils
+
+import java.sql.{Date, Time, Timestamp}
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.table.api.Types
+import org.apache.flink.table.functions.ScalarFunction
+
+/** Small composite value used as a UDF argument type (see Func2) and as a test data field. */
+case class SimplePojo(name: String, age: Int)
+
+/** Identity UDF on a primitive Int argument. */
+object Func0 extends ScalarFunction {
+  def eval(index: Int): Int = index
+}
+
+/** Increments a boxed Integer by one (the argument is unboxed, so null is not supported). */
+object Func1 extends ScalarFunction {
+  def eval(index: Integer): Integer = Integer.valueOf(index.intValue + 1)
+}
+
+/** Joins a boxed Integer, a String and a SimplePojo with " and " (nulls print as "null"). */
+object Func2 extends ScalarFunction {
+  def eval(index: Integer, str: String, pojo: SimplePojo): String =
+    Seq[Any](index, str, pojo).mkString(" and ")
+}
+
+/** Joins a boxed Integer and a String with " and " (nulls print as "null"). */
+object Func3 extends ScalarFunction {
+  def eval(index: Integer, str: String): String =
+    Seq[Any](index, str).mkString(" and ")
+}
+
+/** Parameterless UDF that always returns a null Integer. */
+object Func4 extends ScalarFunction {
+  def eval(): Integer = null
+}
+
+/** Parameterless UDF that always returns the primitive Int -1. */
+object Func5 extends ScalarFunction {
+  def eval(): Int = -1
+}
+
+/** Returns its date, time and timestamp arguments unchanged as a tuple. */
+object Func6 extends ScalarFunction {
+  def eval(date: Date, time: Time, timestamp: Timestamp): (Date, Time, Timestamp) =
+    Tuple3(date, time, timestamp)
+}
+
+/** Adds two boxed Integers (both are unboxed, so nulls are not supported). */
+object Func7 extends ScalarFunction {
+  def eval(a: Integer, b: Integer): Integer = Integer.valueOf(a.intValue + b.intValue)
+}
+
+/** Overloaded UDF: the returned letter identifies which `eval` variant was chosen. */
+object Func8 extends ScalarFunction {
+  def eval(a: Int): String = "a"
+
+  def eval(a: Int, b: Int): String = "b"
+
+  def eval(a: String, b: String): String = "c"
+}
+
+/** Joins two Ints and a Long with " and ". */
+object Func9 extends ScalarFunction {
+  def eval(a: Int, b: Int, c: Long): String =
+    Seq[Any](a, b, c).mkString(" and ")
+}
+
+/** Identity on a Long whose declared SQL result type is TIMESTAMP. */
+object Func10 extends ScalarFunction {
+  def eval(c: Long): Long = c
+
+  override def getResultType(signature: Array[Class[_]]): TypeInformation[_] =
+    Types.TIMESTAMP
+}
+
+/** Joins an Int and a Long with " and ". */
+object Func11 extends ScalarFunction {
+  def eval(a: Int, b: Long): String =
+    Seq[Any](a, b).mkString(" and ")
+}
+
+/** Identity on a Long whose declared SQL result type is INTERVAL_MILLIS. */
+object Func12 extends ScalarFunction {
+  def eval(a: Long): Long = a
+
+  override def getResultType(signature: Array[Class[_]]): TypeInformation[_] =
+    Types.INTERVAL_MILLIS
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/util/RexProgramProjectExtractorTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/util/RexProgramProjectExtractorTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/util/RexProgramProjectExtractorTest.scala
new file mode 100644
index 0000000..f7385ac
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/util/RexProgramProjectExtractorTest.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.rules.util
+
+import java.math.BigDecimal
+
+import org.apache.calcite.adapter.java.JavaTypeFactory
+import org.apache.calcite.jdbc.JavaTypeFactoryImpl
+import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeSystem}
+import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, DOUBLE, INTEGER, VARCHAR}
+import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder}
+import org.apache.calcite.sql.fun.SqlStdOperatorTable
+import org.apache.flink.table.plan.rules.util.RexProgramProjectExtractor
+
+import scala.collection.JavaConverters._
+import RexProgramProjectExtractor._
+import org.junit.{Assert, Before, Test}
+
+/**
+  * Tests for [[RexProgramProjectExtractor]].
+  *
+  * The fixture program reads the input row (name, id, amount, price), projects `amount` and
+  * `amount * price`, and has the condition `amount * price < 100 AND id > 6`.
+  */
+class RexProgramProjectExtractorTest {
+  private var typeFactory: JavaTypeFactory = _
+  private var rexBuilder: RexBuilder = _
+  private var allFieldTypes: Seq[RelDataType] = _
+  private val allFieldNames = List("name", "id", "amount", "price")
+
+  @Before
+  def setUp(): Unit = {
+    typeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT)
+    rexBuilder = new RexBuilder(typeFactory)
+    allFieldTypes = List(VARCHAR, BIGINT, INTEGER, DOUBLE).map(typeFactory.createSqlType(_))
+  }
+
+  @Test
+  def testExtractRefInputFields(): Unit = {
+    val usedFields = extractRefInputFields(buildRexProgram)
+    // amount (2), price (3) and id (1) are referenced; JUnit takes (expected, actual)
+    Assert.assertArrayEquals(Array(2, 3, 1), usedFields)
+  }
+
+  @Test
+  def testRewriteRexProgram(): Unit = {
+    val originRexProgram = buildRexProgram
+    Assert.assertEquals(
+      List(
+        "$0",
+        "$1",
+        "$2",
+        "$3",
+        "*($t2, $t3)",
+        "100",
+        "<($t4, $t5)",
+        "6",
+        ">($t1, $t7)",
+        "AND($t6, $t8)"),
+      extractExprStrList(originRexProgram))
+    // use amount, id, price fields to create a new RexProgram
+    val usedFields = Array(2, 3, 1)
+    val types = usedFields.map(allFieldTypes(_)).toList.asJava
+    val names = usedFields.map(allFieldNames(_)).toList.asJava
+    val inputRowType = typeFactory.createStructType(types, names)
+    val newRexProgram = rewriteRexProgram(originRexProgram, inputRowType, usedFields, rexBuilder)
+    Assert.assertEquals(
+      List(
+        "$0",
+        "$1",
+        "$2",
+        "*($t0, $t1)",
+        "100",
+        "<($t3, $t4)",
+        "6",
+        ">($t2, $t6)",
+        "AND($t5, $t7)"),
+      extractExprStrList(newRexProgram))
+  }
+
+  private def buildRexProgram: RexProgram = {
+    val types = allFieldTypes.asJava
+    val names = allFieldNames.asJava
+    val inputRowType = typeFactory.createStructType(types, names)
+    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
+    val t0 = rexBuilder.makeInputRef(types.get(2), 2)
+    val t1 = rexBuilder.makeInputRef(types.get(1), 1)
+    val t2 = rexBuilder.makeInputRef(types.get(3), 3)
+    val t3 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, t0, t2))
+    val t4 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+    val t5 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(6L))
+    // project: amount, amount * price
+    builder.addProject(t0, "amount")
+    builder.addProject(t3, "total")
+    // condition: amount * price < 100 and id > 6
+    val t6 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t3, t4))
+    val t7 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t5))
+    val t8 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(t6, t7).asJava))
+    builder.addCondition(t8)
+    builder.getProgram
+  }
+
+  /**
+    * Renders every expression of the program's expression list as a string.
+    *
+    * @param rexProgram input RexProgram instance to analyze
+    * @return the `toString` of each expression in the program's expression list, in order
+    */
+  private def extractExprStrList(rexProgram: RexProgram): List[String] = {
+    rexProgram.getExprList.asScala.map(_.toString).toList
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/AggregateTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/AggregateTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/AggregateTestBase.scala
new file mode 100644
index 0000000..0ca101d
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/AggregateTestBase.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.aggregate
+
+import java.math.BigDecimal
+import org.apache.flink.types.Row
+
+import org.junit.Test
+import org.junit.Assert.assertEquals
+
+/**
+  * Base class for aggregate tests.
+  *
+  * Each input set of [[inputValueSets]] is prepared into rows and aggregated. If the
+  * aggregator supports partial aggregation, the rows are split into two partial aggregates
+  * that are then merged. The evaluated result is compared with the matching entry of
+  * [[expectedResults]]; BigDecimal results are compared by value, ignoring scale.
+  *
+  * @tparam T result type of the aggregate under test
+  */
+abstract class AggregateTestBase[T] {
+
+  // offset passed to setAggOffsetInRow; the rows reserve this many leading fields
+  private val offset = 2
+  private val rowArity: Int = offset + aggregator.intermediateDataType.length
+
+  def inputValueSets: Seq[Seq[_]]
+
+  def expectedResults: Seq[T]
+
+  def aggregator: Aggregate[T]
+
+  private def createAggregator(): Aggregate[T] = {
+    val agg = aggregator
+    agg.setAggOffsetInRow(offset)
+    agg
+  }
+
+  private def createRow(): Row = {
+    new Row(rowArity)
+  }
+
+  @Test
+  def testAggregate(): Unit = {
+    // zip would silently drop unmatched entries, so require a result for every input set
+    assertEquals(
+      "inputValueSets and expectedResults must have the same number of entries",
+      expectedResults.size.toLong,
+      inputValueSets.size.toLong)
+
+    // iterate over input sets
+    for ((vals, expected) <- inputValueSets.zip(expectedResults)) {
+
+      // prepare mapper
+      val rows: Seq[Row] = prepare(vals)
+
+      val result = if (aggregator.supportPartial) {
+        // test with combiner
+        val (firstVals, secondVals) = rows.splitAt(rows.length / 2)
+        val combined = partialAgg(firstVals) :: partialAgg(secondVals) :: Nil
+        finalAgg(combined)
+      } else {
+        // test without combiner
+        finalAgg(rows)
+      }
+
+      (expected, result) match {
+        case (e: BigDecimal, r: BigDecimal) =>
+          // BigDecimal.equals() compares value and scale, but only the value matters here
+          assert(e.compareTo(r) == 0, s"Expected $e but was $r")
+        case _ =>
+          assertEquals(expected, result)
+      }
+    }
+  }
+
+  /** Prepares one row per input value with a fresh aggregator. */
+  private def prepare(vals: Seq[_]): Seq[Row] = {
+
+    val agg = createAggregator()
+
+    vals.map { v =>
+      val row = createRow()
+      agg.prepare(v, row)
+      row
+    }
+  }
+
+  /** Merges the given rows into a freshly initiated buffer and returns that buffer. */
+  private def partialAgg(rows: Seq[Row]): Row = {
+
+    val agg = createAggregator()
+    val aggBuf = createRow()
+
+    agg.initiate(aggBuf)
+    rows.foreach(v => agg.merge(v, aggBuf))
+
+    aggBuf
+  }
+
+  /** Merges the given rows into a freshly initiated buffer and evaluates the result. */
+  private def finalAgg(rows: Seq[Row]): T = {
+
+    val agg = createAggregator()
+    val aggBuf = createRow()
+
+    agg.initiate(aggBuf)
+    rows.foreach(v => agg.merge(v, aggBuf))
+
+    // evaluate the buffer that was just merged instead of re-aggregating the rows
+    agg.evaluate(aggBuf)
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/AvgAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/AvgAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/AvgAggregateTest.scala
new file mode 100644
index 0000000..a72d08b
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/AvgAggregateTest.scala
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.aggregate
+
+import java.math.BigDecimal
+
+/**
+  * Shared AVG test data for numeric types, built from the subclass's [[minVal]] and
+  * [[maxVal]]. Nulls are placed via `null.asInstanceOf[T]`, which stays a real null here
+  * because `T` is generic.
+  */
+abstract class AvgAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
+
+  private val numeric: Numeric[T] = implicitly[Numeric[T]]
+
+  def minVal: T
+  def maxVal: T
+
+  override def inputValueSets: Seq[Seq[T]] = {
+    val nullValue = null.asInstanceOf[T]
+    val onlyMin = Seq(minVal, minVal, nullValue, minVal, minVal, nullValue, minVal, minVal, minVal)
+    val onlyMax = Seq(maxVal, maxVal, nullValue, maxVal, maxVal, nullValue, maxVal, maxVal, maxVal)
+    // the non-null values sum to zero
+    val cancelling = Seq(
+      minVal,
+      maxVal,
+      nullValue,
+      numeric.fromInt(0),
+      numeric.negate(maxVal),
+      numeric.negate(minVal),
+      nullValue)
+    val onlyNulls = Seq.fill(6)(nullValue)
+    Seq(onlyMin, onlyMax, cancelling, onlyNulls)
+  }
+
+  override def expectedResults: Seq[T] =
+    Seq(minVal, maxVal, numeric.fromInt(0), null.asInstanceOf[T])
+}
+
+/** AVG over Byte values one step inside the Byte range. */
+class ByteAvgAggregateTest extends AvgAggregateTestBase[Byte] {
+
+  override def aggregator: ByteAvgAggregate = new ByteAvgAggregate()
+
+  override def minVal: Byte = (Byte.MinValue + 1).toByte
+  override def maxVal: Byte = (Byte.MaxValue - 1).toByte
+}
+
+/** AVG over Short values one step inside the Short range. */
+class ShortAvgAggregateTest extends AvgAggregateTestBase[Short] {
+
+  override def aggregator: ShortAvgAggregate = new ShortAvgAggregate()
+
+  override def minVal: Short = (Short.MinValue + 1).toShort
+  override def maxVal: Short = (Short.MaxValue - 1).toShort
+}
+
+/** AVG over Int values one step inside the Int range. */
+class IntAvgAggregateTest extends AvgAggregateTestBase[Int] {
+
+  override def aggregator: IntAvgAggregate = new IntAvgAggregate()
+
+  override def minVal: Int = Int.MinValue + 1
+  override def maxVal: Int = Int.MaxValue - 1
+}
+
+/** AVG over Long values one step inside the Long range. */
+class LongAvgAggregateTest extends AvgAggregateTestBase[Long] {
+
+  override def aggregator: LongAvgAggregate = new LongAvgAggregate()
+
+  override def minVal: Long = Long.MinValue + 1
+  override def maxVal: Long = Long.MaxValue - 1
+}
+
+/** AVG over Float values at the extremes of the Float range. */
+class FloatAvgAggregateTest extends AvgAggregateTestBase[Float] {
+
+  override def aggregator: FloatAvgAggregate = new FloatAvgAggregate()
+
+  override def minVal: Float = Float.MinValue
+  override def maxVal: Float = Float.MaxValue
+}
+
+/**
+  * AVG over Double values.
+  *
+  * NOTE(review): uses the Float range widened to Double, not Double.MinValue/MaxValue.
+  * Seven Double.MaxValue summands would overflow to Infinity, while seven Float.MaxValue
+  * summands stay finite in a Double; presumably that is the intent — confirm against
+  * DoubleAvgAggregate's accumulator.
+  */
+class DoubleAvgAggregateTest extends AvgAggregateTestBase[Double] {
+
+  override def aggregator: DoubleAvgAggregate = new DoubleAvgAggregate()
+
+  override def minVal: Double = Float.MinValue
+  override def maxVal: Double = Float.MaxValue
+}
+
+/** AVG over BigDecimal values whose non-null entries cancel out, plus an all-null set. */
+class DecimalAvgAggregateTest extends AggregateTestBase[BigDecimal] {
+
+  override def aggregator: Aggregate[BigDecimal] = new DecimalAvgAggregate()
+
+  override def inputValueSets: Seq[Seq[_]] = {
+    def dec(value: String): BigDecimal = new BigDecimal(value)
+    // the non-null values sum to zero
+    val cancelling = Seq(
+      dec("987654321000000"),
+      dec("-0.000000000012345"),
+      null,
+      dec("0.000000000012345"),
+      dec("-987654321000000"),
+      null,
+      dec("0"))
+    val onlyNulls = Seq.fill(4)(null)
+    Seq(cancelling, onlyNulls)
+  }
+
+  override def expectedResults: Seq[BigDecimal] = Seq(BigDecimal.ZERO, null)
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/CountAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/CountAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/CountAggregateTest.scala
new file mode 100644
index 0000000..55f73b4
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/CountAggregateTest.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.aggregate
+
+/** COUNT test: the expected counts exclude null values. */
+class CountAggregateTest extends AggregateTestBase[Long] {
+
+  override def aggregator: Aggregate[Long] = new CountAggregate()
+
+  override def inputValueSets: Seq[Seq[_]] = {
+    val sixValuesThreeNulls = Seq("a", "b", null, "c", null, "d", "e", null, "f")
+    val onlyNulls = Seq.fill(6)(null)
+    Seq(sixValuesThreeNulls, onlyNulls)
+  }
+
+  override def expectedResults: Seq[Long] = List(6L, 0L)
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/MaxAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/MaxAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/MaxAggregateTest.scala
new file mode 100644
index 0000000..1bf879d
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/MaxAggregateTest.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.aggregate
+
+import java.math.BigDecimal
+
+/**
+  * Shared MAX test data for numeric types: a mixed set containing [[maxVal]], which is the
+  * expected maximum, and an all-null set, which is expected to yield null. Nulls stay real
+  * nulls because `T` is generic here.
+  */
+abstract class MaxAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
+
+  private val numeric: Numeric[T] = implicitly[Numeric[T]]
+
+  def minVal: T
+  def maxVal: T
+
+  override def inputValueSets: Seq[Seq[T]] = {
+    val nullValue = null.asInstanceOf[T]
+    def num(i: Int): T = numeric.fromInt(i)
+    val mixed = Seq(
+      num(1), nullValue, maxVal, num(-99), num(3), num(56),
+      num(0), minVal, num(-20), num(17), nullValue)
+    Seq(mixed, Seq.fill(6)(nullValue))
+  }
+
+  override def expectedResults: Seq[T] = Seq(maxVal, null.asInstanceOf[T])
+}
+
+/** MAX over Byte values one step inside the Byte range. */
+class ByteMaxAggregateTest extends MaxAggregateTestBase[Byte] {
+
+  override def aggregator: Aggregate[Byte] = new ByteMaxAggregate()
+
+  override def minVal: Byte = (Byte.MinValue + 1).toByte
+  override def maxVal: Byte = (Byte.MaxValue - 1).toByte
+}
+
+/** MAX over Short values one step inside the Short range. */
+class ShortMaxAggregateTest extends MaxAggregateTestBase[Short] {
+
+  override def aggregator: Aggregate[Short] = new ShortMaxAggregate()
+
+  override def minVal: Short = (Short.MinValue + 1).toShort
+  override def maxVal: Short = (Short.MaxValue - 1).toShort
+}
+
+/** MAX over Int values one step inside the Int range. */
+class IntMaxAggregateTest extends MaxAggregateTestBase[Int] {
+
+  override def aggregator: Aggregate[Int] = new IntMaxAggregate()
+
+  override def minVal: Int = Int.MinValue + 1
+  override def maxVal: Int = Int.MaxValue - 1
+}
+
+/** MAX over Long values one step inside the Long range. */
+class LongMaxAggregateTest extends MaxAggregateTestBase[Long] {
+
+  override def aggregator: Aggregate[Long] = new LongMaxAggregate()
+
+  override def minVal: Long = Long.MinValue + 1
+  override def maxVal: Long = Long.MaxValue - 1
+}
+
+/** MAX over Float values, with extremes at half of the Float range. */
+class FloatMaxAggregateTest extends MaxAggregateTestBase[Float] {
+
+  override def aggregator: Aggregate[Float] = new FloatMaxAggregate()
+
+  override def minVal: Float = Float.MinValue / 2
+  override def maxVal: Float = Float.MaxValue / 2
+}
+
+/** MAX over Double values, with extremes at half of the Double range. */
+class DoubleMaxAggregateTest extends MaxAggregateTestBase[Double] {
+
+  override def aggregator: Aggregate[Double] = new DoubleMaxAggregate()
+
+  override def minVal: Double = Double.MinValue / 2
+  override def maxVal: Double = Double.MaxValue / 2
+}
+
+/**
+  * MAX over Boolean values.
+  *
+  * `null.asInstanceOf[Boolean]` evaluates to `false` outside a generic context, so it cannot
+  * be used to put nulls into these sequences. Real nulls are used instead, so that null
+  * inputs and the all-null result are actually exercised.
+  */
+class BooleanMaxAggregateTest extends AggregateTestBase[Boolean] {
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq(
+      false,
+      false,
+      false
+    ),
+    Seq(
+      true,
+      true,
+      true
+    ),
+    Seq(
+      true,
+      false,
+      null,
+      true,
+      false,
+      true,
+      null
+    ),
+    Seq(
+      null,
+      null,
+      null
+    )
+  )
+
+  // The cast only changes the static element type (it is unchecked), so the trailing entry
+  // stays a real null; the base class compares expected and actual results as objects.
+  override def expectedResults: Seq[Boolean] = Seq[Any](
+    false,
+    true,
+    true,
+    null
+  ).asInstanceOf[Seq[Boolean]]
+
+  override def aggregator: Aggregate[Boolean] = new BooleanMaxAggregate()
+}
+
+/** MAX over BigDecimal values of varying scale, plus an all-null set. */
+class DecimalMaxAggregateTest extends AggregateTestBase[BigDecimal] {
+
+  override def aggregator: Aggregate[BigDecimal] = new DecimalMaxAggregate()
+
+  override def inputValueSets: Seq[Seq[_]] = {
+    def dec(value: String): BigDecimal = new BigDecimal(value)
+    val mixed = Seq(
+      dec("1"),
+      dec("1000.000001"),
+      dec("-1"),
+      dec("-999.998999"),
+      null,
+      dec("0"),
+      dec("-999.999"),
+      null,
+      dec("999.999"))
+    val onlyNulls = Seq.fill(5)(null)
+    Seq(mixed, onlyNulls)
+  }
+
+  override def expectedResults: Seq[BigDecimal] = Seq(new BigDecimal("1000.000001"), null)
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/MinAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/MinAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/MinAggregateTest.scala
new file mode 100644
index 0000000..3e2404d
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/MinAggregateTest.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.aggregate
+
+import java.math.BigDecimal
+
+/**
+ * Shared min-aggregate scenarios for numeric types.
+ *
+ * The first input set mixes ordinary values, both extremes and nulls, and is expected to yield
+ * `minVal`; the second set contains only nulls and is expected to yield null.
+ *
+ * @tparam T numeric type handled by the aggregate under test
+ */
+abstract class MinAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
+
+  private val numeric: Numeric[T] = implicitly[Numeric[T]]
+
+  /** Smallest input value; must lie below -99 so it is the minimum of the first set. */
+  def minVal: T
+
+  /** Largest input value. */
+  def maxVal: T
+
+  override def inputValueSets: Seq[Seq[T]] = {
+    // T is an unspecialized type parameter, so this stays a null reference
+    val absent = null.asInstanceOf[T]
+    def num(i: Int): T = numeric.fromInt(i)
+
+    Seq(
+      Seq(num(1), absent, maxVal, num(-99), num(3), num(56), num(0), minVal, num(-20), num(17),
+        absent),
+      Seq.fill(6)(absent)
+    )
+  }
+
+  override def expectedResults: Seq[T] = Seq(minVal, null.asInstanceOf[T])
+}
+
+/** Min-aggregate test for [[ByteMinAggregate]]; extremes sit one step inside the Byte range. */
+class ByteMinAggregateTest extends MinAggregateTestBase[Byte] {
+
+  override def aggregator: Aggregate[Byte] = new ByteMinAggregate()
+
+  override def minVal: Byte = (Byte.MinValue + 1).toByte
+  override def maxVal: Byte = (Byte.MaxValue - 1).toByte
+}
+
+/** Min-aggregate test for [[ShortMinAggregate]]; extremes sit one step inside the Short range. */
+class ShortMinAggregateTest extends MinAggregateTestBase[Short] {
+
+  override def aggregator: Aggregate[Short] = new ShortMinAggregate()
+
+  override def minVal: Short = (Short.MinValue + 1).toShort
+  override def maxVal: Short = (Short.MaxValue - 1).toShort
+}
+
+/** Min-aggregate test for [[IntMinAggregate]]; extremes sit one step inside the Int range. */
+class IntMinAggregateTest extends MinAggregateTestBase[Int] {
+
+  override def aggregator: Aggregate[Int] = new IntMinAggregate()
+
+  override def minVal: Int = Int.MinValue + 1
+  override def maxVal: Int = Int.MaxValue - 1
+}
+
+/** Min-aggregate test for [[LongMinAggregate]]; extremes sit one step inside the Long range. */
+class LongMinAggregateTest extends MinAggregateTestBase[Long] {
+
+  override def aggregator: Aggregate[Long] = new LongMinAggregate()
+
+  override def minVal: Long = Long.MinValue + 1
+  override def maxVal: Long = Long.MaxValue - 1
+}
+
+/** Min-aggregate test for [[FloatMinAggregate]]; extremes are half of the Float limits. */
+class FloatMinAggregateTest extends MinAggregateTestBase[Float] {
+
+  override def aggregator: Aggregate[Float] = new FloatMinAggregate()
+
+  override def minVal: Float = Float.MinValue / 2
+  override def maxVal: Float = Float.MaxValue / 2
+}
+
+/** Min-aggregate test for [[DoubleMinAggregate]]; extremes are half of the Double limits. */
+class DoubleMinAggregateTest extends MinAggregateTestBase[Double] {
+
+  override def aggregator: Aggregate[Double] = new DoubleMinAggregate()
+
+  override def minVal: Double = Double.MinValue / 2
+  override def maxVal: Double = Double.MaxValue / 2
+}
+
+/**
+ * Min-aggregate test for [[BooleanMinAggregate]]; the expected results imply the ordering
+ * false < true.
+ */
+class BooleanMinAggregateTest extends AggregateTestBase[Boolean] {
+
+ // Four input sets: all false, all true, mixed with nulls, and all nulls. Presumably each set is
+ // checked against the `expectedResults` entry at the same index (see AggregateTestBase).
+ // NOTE(review): `null.asInstanceOf[Boolean]` is used as a NULL marker. In a primitive Boolean
+ // context Scala evaluates it to `false`; confirm it stays a null reference once boxed into these
+ // Seqs, otherwise the null cases silently exercise `false` instead of NULL.
+ override def inputValueSets: Seq[Seq[Boolean]] = Seq(
+ Seq(
+ false,
+ false,
+ false
+ ),
+ Seq(
+ true,
+ true,
+ true
+ ),
+ Seq(
+ true,
+ false,
+ null.asInstanceOf[Boolean],
+ true,
+ false,
+ true,
+ null.asInstanceOf[Boolean]
+ ),
+ Seq(
+ null.asInstanceOf[Boolean],
+ null.asInstanceOf[Boolean],
+ null.asInstanceOf[Boolean]
+ )
+ )
+
+ override def expectedResults: Seq[Boolean] = Seq(
+ false, // all inputs false
+ true, // all inputs true
+ false, // mixed set contains at least one false
+ null.asInstanceOf[Boolean] // only nulls: the result is expected to be null
+ )
+
+ // aggregate under test
+ override def aggregator: Aggregate[Boolean] = new BooleanMinAggregate()
+}
+
+/**
+ * Min-aggregate test for [[DecimalMinAggregate]] over `java.math.BigDecimal`. The first input
+ * set mixes values that differ only in low-order digits with nulls; the second contains only
+ * nulls and is expected to yield null.
+ */
+class DecimalMinAggregateTest extends AggregateTestBase[BigDecimal] {
+
+  /** Parses a decimal literal; a null literal stands for a null input value. */
+  private def decimal(literal: String): BigDecimal =
+    if (literal == null) null else new BigDecimal(literal)
+
+  override def aggregator: Aggregate[BigDecimal] = new DecimalMinAggregate()
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq("1", "1000", "-1", "-999.998999", null, "0", "-999.999", null, "999.999")
+      .map(decimal),
+    Seq.fill(5)(null)
+  )
+
+  override def expectedResults: Seq[BigDecimal] = Seq(decimal("-999.999"), null)
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/SumAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/SumAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/SumAggregateTest.scala
new file mode 100644
index 0000000..c085334
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/SumAggregateTest.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.aggregate
+
+import java.math.BigDecimal
+
+/**
+ * Shared sum-aggregate scenarios for numeric types.
+ *
+ * The first input set starts with `-maxVal` and ends with `maxVal`, so the extremes cancel and
+ * the remaining values (1 + 2 + 3 + 4 + 5 - 10 - 20 + 17) sum to 2; the interleaved nulls must
+ * not contribute. The second input set contains only nulls and is expected to yield null.
+ *
+ * @tparam T numeric type handled by the aggregate under test
+ */
+abstract class SumAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
+
+  private val numeric: Numeric[T] = implicitly[Numeric[T]]
+
+  /** Large positive value; both it and its negation are fed to the aggregate. */
+  def maxVal: T
+
+  // A def rather than a val: a val would invoke the abstract `maxVal` while this base class is
+  // still being constructed, which reads an uninitialized field if a subclass overrides `maxVal`
+  // with a val instead of a def.
+  private def minVal: T = numeric.negate(maxVal)
+
+  override def inputValueSets: Seq[Seq[T]] = Seq(
+    Seq(
+      minVal,
+      numeric.fromInt(1),
+      null.asInstanceOf[T],
+      numeric.fromInt(2),
+      numeric.fromInt(3),
+      numeric.fromInt(4),
+      numeric.fromInt(5),
+      numeric.fromInt(-10),
+      numeric.fromInt(-20),
+      numeric.fromInt(17),
+      null.asInstanceOf[T],
+      maxVal
+    ),
+    Seq(
+      null.asInstanceOf[T],
+      null.asInstanceOf[T],
+      null.asInstanceOf[T],
+      null.asInstanceOf[T],
+      null.asInstanceOf[T],
+      null.asInstanceOf[T]
+    )
+  )
+
+  override def expectedResults: Seq[T] = Seq(
+    numeric.fromInt(2),
+    null.asInstanceOf[T]
+  )
+}
+
+/** Sum-aggregate test for [[ByteSumAggregate]] using +/-(Byte.MaxValue / 2) as extremes. */
+class ByteSumAggregateTest extends SumAggregateTestBase[Byte] {
+
+  override def aggregator: Aggregate[Byte] = new ByteSumAggregate()
+
+  override def maxVal: Byte = (Byte.MaxValue / 2).toByte
+}
+
+/** Sum-aggregate test for [[ShortSumAggregate]] using +/-(Short.MaxValue / 2) as extremes. */
+class ShortSumAggregateTest extends SumAggregateTestBase[Short] {
+
+  override def aggregator: Aggregate[Short] = new ShortSumAggregate()
+
+  override def maxVal: Short = (Short.MaxValue / 2).toShort
+}
+
+/** Sum-aggregate test for [[IntSumAggregate]] using +/-(Int.MaxValue / 2) as extremes. */
+class IntSumAggregateTest extends SumAggregateTestBase[Int] {
+
+  override def aggregator: Aggregate[Int] = new IntSumAggregate()
+
+  override def maxVal: Int = Int.MaxValue / 2
+}
+
+/** Sum-aggregate test for [[LongSumAggregate]] using +/-(Long.MaxValue / 2) as extremes. */
+class LongSumAggregateTest extends SumAggregateTestBase[Long] {
+
+  override def aggregator: Aggregate[Long] = new LongSumAggregate()
+
+  override def maxVal: Long = Long.MaxValue / 2
+}
+
+/** Sum-aggregate test for [[FloatSumAggregate]] using +/-12345.6789f as extremes. */
+class FloatSumAggregateTest extends SumAggregateTestBase[Float] {
+
+  override def aggregator: Aggregate[Float] = new FloatSumAggregate()
+
+  override def maxVal: Float = 12345.6789f
+}
+
+/** Sum-aggregate test for [[DoubleSumAggregate]] using +/-12345.6789d as extremes. */
+class DoubleSumAggregateTest extends SumAggregateTestBase[Double] {
+
+  override def aggregator: Aggregate[Double] = new DoubleSumAggregate()
+
+  override def maxVal: Double = 12345.6789d
+}
+
+/**
+ * Sum-aggregate test for [[DecimalSumAggregate]] over `java.math.BigDecimal`. The first input
+ * set contains cancelling pairs (-1000/1000, 999.999/-999.999) and tiny fractions so that only
+ * 1 + 2 + 3 + 4 + 0.000000000002 - 0.000000000001 remains; the second contains only nulls.
+ */
+class DecimalSumAggregateTest extends AggregateTestBase[BigDecimal] {
+
+  /** Parses a decimal literal; a null literal stands for a null input value. */
+  private def decimal(literal: String): BigDecimal =
+    if (literal == null) null else new BigDecimal(literal)
+
+  override def aggregator: Aggregate[BigDecimal] = new DecimalSumAggregate()
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq(
+      "1", "2", "3", null, "0", "-1000", "0.000000000002", "1000", "-0.000000000001",
+      "999.999", null, "4", "-999.999", null
+    ).map(decimal),
+    Seq.fill(5)(null)
+  )
+
+  override def expectedResults: Seq[BigDecimal] = Seq(decimal("10.000000000001"), null)
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala
new file mode 100644
index 0000000..550669e
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.dataset
+
+import org.apache.flink.api.scala._
+import org.apache.flink.types.Row
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.utils._
+import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
+import org.apache.flink.test.util.TestBaseUtils
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+/**
+ * Integration tests for joining a DataSet-backed table with table functions (UDTFs): cross
+ * join, left outer join, filtering on UDTF output, custom/hierarchical/POJO result types and
+ * UDTFs applied to scalar-function results.
+ */
+@RunWith(classOf[Parameterized])
+class DataSetCorrelateITCase(
+    mode: TestExecutionMode,
+    configMode: TableConfigMode)
+  extends TableProgramsTestBase(mode, configMode) {
+
+  @Test
+  def testCrossJoin(): Unit = {
+    val in = inputTable()
+    val func1 = new TableFunc1
+
+    val result = in.join(func1('c) as 's).select('c, 's).toDataSet[Row]
+    val results = result.collect()
+    val expected = "Jack#22,Jack\n" + "Jack#22,22\n" + "John#19,John\n" + "John#19,19\n" +
+      "Anna#44,Anna\n" + "Anna#44,44\n"
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+
+    // with overloading
+    val result2 = in.join(func1('c, "$") as 's).select('c, 's).toDataSet[Row]
+    val results2 = result2.collect()
+    val expected2 = "Jack#22,$Jack\n" + "Jack#22,$22\n" + "John#19,$John\n" +
+      "John#19,$19\n" + "Anna#44,$Anna\n" + "Anna#44,$44\n"
+    TestBaseUtils.compareResultAsText(results2.asJava, expected2)
+  }
+
+  @Test
+  def testLeftOuterJoin(): Unit = {
+    val in = inputTable()
+    val func2 = new TableFunc2
+
+    val result = in.leftOuterJoin(func2('c) as ('s, 'l)).select('c, 's, 'l).toDataSet[Row]
+    val results = result.collect()
+    val expected = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" +
+      "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n" + "nosharp,null,null"
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testWithFilter(): Unit = {
+    val in = inputTable()
+    val func0 = new TableFunc0
+
+    val result = in
+      .join(func0('c) as ('name, 'age))
+      .select('c, 'name, 'age)
+      .filter('age > 20)
+      .toDataSet[Row]
+
+    val results = result.collect()
+    val expected = "Jack#22,Jack,22\n" + "Anna#44,Anna,44\n"
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testCustomReturnType(): Unit = {
+    val in = inputTable()
+    val func2 = new TableFunc2
+
+    val result = in
+      .join(func2('c) as ('name, 'len))
+      .select('c, 'name, 'len)
+      .toDataSet[Row]
+
+    val results = result.collect()
+    val expected = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" +
+      "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n"
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testHierarchyType(): Unit = {
+    val in = inputTable()
+    val hierarchy = new HierarchyTableFunction
+
+    val result = in
+      .join(hierarchy('c) as ('name, 'adult, 'len))
+      .select('c, 'name, 'adult, 'len)
+      .toDataSet[Row]
+
+    val results = result.collect()
+    val expected = "Jack#22,Jack,true,22\n" + "John#19,John,false,19\n" +
+      "Anna#44,Anna,true,44\n"
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testPojoType(): Unit = {
+    val in = inputTable()
+    val pojo = new PojoTableFunc()
+
+    val result = in
+      .join(pojo('c))
+      .select('c, 'name, 'age)
+      .toDataSet[Row]
+
+    val results = result.collect()
+    val expected = "Jack#22,Jack,22\n" + "John#19,John,19\n" + "Anna#44,Anna,44\n"
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testUDTFWithScalarFunction(): Unit = {
+    val in = inputTable()
+    val func1 = new TableFunc1
+
+    val result = in
+      .join(func1('c.substring(2)) as 's)
+      .select('c, 's)
+      .toDataSet[Row]
+
+    val results = result.collect()
+    val expected = "Jack#22,ack\n" + "Jack#22,22\n" + "John#19,ohn\n" + "John#19,19\n" +
+      "Anna#44,nna\n" + "Anna#44,44\n"
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  /**
+   * Creates a fresh execution environment and table environment (using the parameterized
+   * `config`) and returns [[testData]] as a table with fields `a`, `b`, `c`.
+   */
+  private def inputTable() = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tableEnv = TableEnvironment.getTableEnvironment(env, config)
+    testData(env).toTable(tableEnv).as('a, 'b, 'c)
+  }
+
+  /**
+   * Four input rows. The last string contains no '#': the cross-join tests expect no output for
+   * it, while the left outer join expects it padded with nulls.
+   */
+  private def testData(env: ExecutionEnvironment): DataSet[(Int, Long, String)] =
+    env.fromCollection(Seq(
+      (1, 1L, "Jack#22"),
+      (2, 2L, "John#19"),
+      (3, 2L, "Anna#44"),
+      (4, 3L, "nosharp")))
+}