Posted to commits@flink.apache.org by fh...@apache.org on 2017/10/10 21:09:57 UTC
[1/5] flink git commit: [FLINK-7410] [table] Use UserDefinedFunction.toString() to display operator names of UDFs.
Repository: flink
Updated Branches:
refs/heads/master 9829ca00d -> 427dfe42e
[FLINK-7410] [table] Use UserDefinedFunction.toString() to display operator names of UDFs.
This closes #4624.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/427dfe42
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/427dfe42
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/427dfe42
Branch: refs/heads/master
Commit: 427dfe42e2bea891b40e662bc97cdea57cdae3f5
Parents: dccdba1
Author: 军长 <he...@alibaba-inc.com>
Authored: Wed Aug 30 19:30:52 2017 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Tue Oct 10 23:09:07 2017 +0200
----------------------------------------------------------------------
.../flink/table/api/TableEnvironment.scala | 7 +++--
.../flink/table/expressions/aggregations.scala | 3 +-
.../apache/flink/table/expressions/call.scala | 1 +
.../flink/table/functions/ScalarFunction.scala | 2 --
.../flink/table/functions/TableFunction.scala | 2 --
.../table/functions/UserDefinedFunction.scala | 8 ++++-
.../table/functions/utils/AggSqlFunction.scala | 14 ++++++++-
.../functions/utils/ScalarSqlFunction.scala | 4 +++
.../functions/utils/TableSqlFunction.scala | 3 ++
.../utils/UserDefinedFunctionUtils.scala | 8 +++--
.../flink/table/plan/logical/operators.scala | 1 +
.../table/plan/nodes/CommonCorrelate.scala | 20 ++++++++----
.../plan/nodes/dataset/DataSetCorrelate.scala | 20 +++++++++---
.../nodes/datastream/DataStreamCorrelate.scala | 16 ++++++++--
.../plan/rules/logical/LogicalUnnestRule.scala | 1 +
.../utils/JavaUserDefinedAggFunctions.java | 5 +++
.../flink/table/api/TableSourceTest.scala | 2 +-
.../table/api/batch/sql/CorrelateTest.scala | 30 ++++++++++++------
.../flink/table/api/batch/table/CalcTest.scala | 12 +++----
.../table/api/batch/table/CorrelateTest.scala | 10 +++---
.../table/api/batch/table/GroupWindowTest.scala | 6 ++--
.../table/api/stream/sql/CorrelateTest.scala | 30 ++++++++++++------
.../table/api/stream/table/CorrelateTest.scala | 33 +++++++++++++-------
.../api/stream/table/GroupWindowTest.scala | 6 ++--
.../table/api/stream/table/OverWindowTest.scala | 3 +-
.../plan/ExpressionReductionRulesTest.scala | 2 +-
.../plan/TimeIndicatorConversionTest.scala | 3 +-
.../table/runtime/stream/table/CalcITCase.scala | 2 +-
28 files changed, 176 insertions(+), 78 deletions(-)
----------------------------------------------------------------------
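This commit replaces the canonical class name (and the hash-suffixed functionIdentifier) in operator names with UserDefinedFunction.toString(), which now defaults to getClass.getSimpleName and can be overridden. A minimal sketch of the resulting user-facing pattern, with a hypothetical HashCode function that is not part of this commit:

import org.apache.flink.table.functions.ScalarFunction

// Hypothetical scalar UDF. After this commit, plans and operator names
// display toString() -- by default the simple class name ("HashCode"),
// or a custom label if overridden as below.
class HashCode extends ScalarFunction {
  def eval(s: String): Int = s.hashCode

  // Displayed in explain() output and operator names instead of a
  // canonical-class-name or "HashCode$<md5>" style identifier.
  override def toString: String = "myHashCode"
}

Registered as usual (e.g. tableEnv.registerFunction("hashCode", new HashCode)), a projection on it would then show up in the plan as something like "select: myHashCode(c) AS _c0" rather than the canonical class name.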
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala
index dc82a87..54877ba 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala
@@ -331,7 +331,9 @@ abstract class TableEnvironment(val config: TableConfig) {
functionCatalog.registerFunction(name, function.getClass)
// register in SQL API
- functionCatalog.registerSqlFunction(createScalarSqlFunction(name, function, typeFactory))
+ functionCatalog.registerSqlFunction(
+ createScalarSqlFunction(name, name, function, typeFactory)
+ )
}
/**
@@ -355,7 +357,7 @@ abstract class TableEnvironment(val config: TableConfig) {
functionCatalog.registerFunction(name, function.getClass)
// register in SQL API
- val sqlFunction = createTableSqlFunction(name, function, typeInfo, typeFactory)
+ val sqlFunction = createTableSqlFunction(name, name, function, typeInfo, typeFactory)
functionCatalog.registerSqlFunction(sqlFunction)
}
@@ -384,6 +386,7 @@ abstract class TableEnvironment(val config: TableConfig) {
// register in SQL API
val sqlFunctions = createAggregateSqlFunction(
name,
+ name,
function,
resultTypeInfo,
accTypeInfo,
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
index c2d1bdf..1ffcb12 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
@@ -261,7 +261,8 @@ case class AggFunctionCall(
override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
AggSqlFunction(
- aggregateFunction.getClass.getSimpleName,
+ aggregateFunction.functionIdentifier,
+ aggregateFunction.toString,
aggregateFunction,
resultType,
accTypeInfo,
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala
index cad9ccc..8454555 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala
@@ -272,6 +272,7 @@ case class ScalarFunctionCall(
relBuilder.call(
createScalarSqlFunction(
scalarFunction.functionIdentifier,
+ scalarFunction.toString,
scalarFunction,
typeFactory),
parameters.map(_.toRexNode): _*)
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/ScalarFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/ScalarFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/ScalarFunction.scala
index 40c60ac..e41b876 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/ScalarFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/ScalarFunction.scala
@@ -56,8 +56,6 @@ abstract class ScalarFunction extends UserDefinedFunction {
ScalarFunctionCall(this, params)
}
- override def toString: String = getClass.getCanonicalName
-
// ----------------------------------------------------------------------------------------------
/**
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala
index b6e801a..ff69954 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala
@@ -81,8 +81,6 @@ import org.apache.flink.util.Collector
*/
abstract class TableFunction[T] extends UserDefinedFunction {
- override def toString: String = getClass.getCanonicalName
-
// ----------------------------------------------------------------------------------------------
/**
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala
index b841b31..15bcb17 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala
@@ -41,7 +41,7 @@ abstract class UserDefinedFunction extends Serializable {
def close(): Unit = {}
/**
- * @return true iff a call to this function is guaranteed to always return
+ * @return true if and only if a call to this function is guaranteed to always return
* the same result given the same parameters; true is assumed by default
* if user's function is not pure functional, like random(), date(), now()...
* isDeterministic must return false
@@ -52,4 +52,10 @@ abstract class UserDefinedFunction extends Serializable {
val md5 = DigestUtils.md5Hex(serialize(this))
getClass.getCanonicalName.replace('.', '$').concat("$").concat(md5)
}
+
+ /**
+ * Returns the name of the UDF that is used for plan explain and logging.
+ */
+ override def toString: String = getClass.getSimpleName
+
}
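The clarified isDeterministic contract above matters for plan optimization: deterministic calls may be pre-evaluated at planning time (see the ExpressionReductionRulesTest change below). A short sketch, assuming a hypothetical non-pure function, of how the override is meant to be used:

import org.apache.flink.table.functions.ScalarFunction

// Hypothetical non-pure UDF: each call can return a different value,
// so it must declare itself non-deterministic to keep the planner from
// constant-folding it into the plan.
class RandomTag extends ScalarFunction {
  def eval(prefix: String): String = prefix + scala.util.Random.nextInt(1000)

  override def isDeterministic: Boolean = false
}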
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala
index bb71d63..f44598b 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala
@@ -35,6 +35,7 @@ import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
* Calcite wrapper for user-defined aggregate functions.
*
* @param name function name (used by SQL parser)
+ * @param displayName name to be displayed in operator name
* @param aggregateFunction aggregate function to be called
* @param returnType the type information of returned value
* @param accType the type information of the accumulator
@@ -42,6 +43,7 @@ import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
*/
class AggSqlFunction(
name: String,
+ displayName: String,
aggregateFunction: AggregateFunction[_, _],
val returnType: TypeInformation[_],
val accType: TypeInformation[_],
@@ -62,19 +64,29 @@ class AggSqlFunction(
def getFunction: AggregateFunction[_, _] = aggregateFunction
override def isDeterministic: Boolean = aggregateFunction.isDeterministic
+
+ override def toString: String = displayName
}
object AggSqlFunction {
def apply(
name: String,
+ displayName: String,
aggregateFunction: AggregateFunction[_, _],
returnType: TypeInformation[_],
accType: TypeInformation[_],
typeFactory: FlinkTypeFactory,
requiresOver: Boolean): AggSqlFunction = {
- new AggSqlFunction(name, aggregateFunction, returnType, accType, typeFactory, requiresOver)
+ new AggSqlFunction(
+ name,
+ displayName,
+ aggregateFunction,
+ returnType,
+ accType,
+ typeFactory,
+ requiresOver)
}
private[flink] def createOperandTypeInference(
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
index 784bca7..27e093d 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
@@ -35,11 +35,13 @@ import scala.collection.JavaConverters._
* Calcite wrapper for user-defined scalar functions.
*
* @param name function name (used by SQL parser)
+ * @param displayName name to be displayed in operator name
* @param scalarFunction scalar function to be called
* @param typeFactory type factory for converting between Flink's and Calcite's types
*/
class ScalarSqlFunction(
name: String,
+ displayName: String,
scalarFunction: ScalarFunction,
typeFactory: FlinkTypeFactory)
extends SqlFunction(
@@ -53,6 +55,8 @@ class ScalarSqlFunction(
def getScalarFunction = scalarFunction
override def isDeterministic: Boolean = scalarFunction.isDeterministic
+
+ override def toString: String = displayName
}
object ScalarSqlFunction {
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/TableSqlFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/TableSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/TableSqlFunction.scala
index 6d9742c..741d15b 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/TableSqlFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/TableSqlFunction.scala
@@ -37,6 +37,7 @@ import org.apache.flink.table.functions.utils.TableSqlFunction._
*/
class TableSqlFunction(
name: String,
+ displayName: String,
tableFunction: TableFunction[_],
rowTypeInfo: TypeInformation[_],
typeFactory: FlinkTypeFactory,
@@ -66,6 +67,8 @@ class TableSqlFunction(
def getPojoFieldMapping: Array[Int] = functionImpl.fieldIndexes
override def isDeterministic: Boolean = tableFunction.isDeterministic
+
+ override def toString: String = displayName
}
object TableSqlFunction {
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
index 6a90569..3cd694a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -251,10 +251,11 @@ object UserDefinedFunctionUtils {
*/
def createScalarSqlFunction(
name: String,
+ displayName: String,
function: ScalarFunction,
typeFactory: FlinkTypeFactory)
: SqlFunction = {
- new ScalarSqlFunction(name, function, typeFactory)
+ new ScalarSqlFunction(name, displayName, function, typeFactory)
}
/**
@@ -268,13 +269,14 @@ object UserDefinedFunctionUtils {
*/
def createTableSqlFunction(
name: String,
+ displayName: String,
tableFunction: TableFunction[_],
resultType: TypeInformation[_],
typeFactory: FlinkTypeFactory)
: SqlFunction = {
val (fieldNames, fieldIndexes, _) = UserDefinedFunctionUtils.getFieldInfo(resultType)
val function = new FlinkTableFunctionImpl(resultType, fieldIndexes, fieldNames)
- new TableSqlFunction(name, tableFunction, resultType, typeFactory, function)
+ new TableSqlFunction(name, displayName, tableFunction, resultType, typeFactory, function)
}
/**
@@ -287,6 +289,7 @@ object UserDefinedFunctionUtils {
*/
def createAggregateSqlFunction(
name: String,
+ displayName: String,
aggFunction: AggregateFunction[_, _],
resultType: TypeInformation[_],
accTypeInfo: TypeInformation[_],
@@ -297,6 +300,7 @@ object UserDefinedFunctionUtils {
AggSqlFunction(
name,
+ displayName,
aggFunction,
resultType,
accTypeInfo,
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
index 559d20d..0c8efd7 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
@@ -728,6 +728,7 @@ case class LogicalTableFunctionCall(
val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
val sqlFunction = new TableSqlFunction(
tableFunction.functionIdentifier,
+ tableFunction.toString,
tableFunction,
resultType,
typeFactory,
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala
index 7c01fde..c53f090 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala
@@ -179,21 +179,29 @@ trait CommonCorrelate {
}
private[flink] def selectToString(rowType: RelDataType): String = {
- rowType.getFieldNames.asScala.mkString(",")
+ rowType.getFieldNames.asScala.mkString(", ")
}
private[flink] def correlateOpName(
+ inputType: RelDataType,
rexCall: RexCall,
sqlFunction: TableSqlFunction,
- rowType: RelDataType)
+ rowType: RelDataType,
+ expression: (RexNode, List[String], Option[List[RexNode]]) => String)
: String = {
- s"correlate: ${correlateToString(rexCall, sqlFunction)}, select: ${selectToString(rowType)}"
+ s"correlate: ${correlateToString(inputType, rexCall, sqlFunction, expression)}," +
+ s" select: ${selectToString(rowType)}"
}
- private[flink] def correlateToString(rexCall: RexCall, sqlFunction: TableSqlFunction): String = {
- val udtfName = sqlFunction.getName
- val operands = rexCall.getOperands.asScala.map(_.toString).mkString(",")
+ private[flink] def correlateToString(
+ inputType: RelDataType,
+ rexCall: RexCall,
+ sqlFunction: TableSqlFunction,
+ expression: (RexNode, List[String], Option[List[RexNode]]) => String): String = {
+ val inFields = inputType.getFieldNames.asScala.toList
+ val udtfName = sqlFunction.toString
+ val operands = rexCall.getOperands.asScala.map(expression(_, inFields, None)).mkString(", ")
s"table($udtfName($operands))"
}
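With the extra inputType and expression parameters, operands are rendered against the input field names and the UDTF is labeled with its toString() instead of its SQL name. A minimal sketch of the naming scheme (not the actual CommonCorrelate code), using the format strings above:

// Sketch of the name format produced by correlateOpName/correlateToString.
// udtfDisplayName stands in for sqlFunction.toString, operands for the
// rendered RexNodes, outFields for the output row's field names.
def correlateOpName(udtfDisplayName: String, operands: Seq[String], outFields: Seq[String]): String =
  s"correlate: table($udtfDisplayName(${operands.mkString(", ")})), select: ${outFields.mkString(", ")}"

// correlateOpName("TableFunc1", Seq("c"), Seq("a", "b", "c", "s"))
// returns "correlate: table(TableFunc1(c)), select: a, b, c, s",
// matching the expectations in the updated CorrelateTests below.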
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala
index 731d2e5..5f94562 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala
@@ -76,7 +76,7 @@ class DataSetCorrelate(
override def toString: String = {
val rexCall = scan.getCall.asInstanceOf[RexCall]
val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
- correlateToString(rexCall, sqlFunction)
+ correlateToString(joinRowType, rexCall, sqlFunction, getExpressionString)
}
override def explainTerms(pw: RelWriter): RelWriter = {
@@ -84,7 +84,11 @@ class DataSetCorrelate(
val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
super.explainTerms(pw)
.item("invocation", scan.getCall)
- .item("function", sqlFunction.getTableFunction.getClass.getCanonicalName)
+ .item("correlate", correlateToString(
+ inputNode.getRowType,
+ rexCall, sqlFunction,
+ getExpressionString))
+ .item("select", selectToString(relRowType))
.item("rowType", relRowType)
.item("joinType", joinType)
.itemIf("condition", condition.orNull, condition.isDefined)
@@ -103,8 +107,6 @@ class DataSetCorrelate(
val pojoFieldMapping = Some(sqlFunction.getPojoFieldMapping)
val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]]
- val returnType = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
-
val flatMap = generateFunction(
config,
new RowSchema(getInput.getRowType),
@@ -131,6 +133,14 @@ class DataSetCorrelate(
collector.code,
flatMap.returnType)
- inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, relRowType))
+ inputDS
+ .flatMap(mapFunc)
+ .name(correlateOpName(
+ inputNode.getRowType,
+ rexCall,
+ sqlFunction,
+ relRowType,
+ getExpressionString)
+ )
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala
index 18ab2a3..4c702ee 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala
@@ -69,7 +69,7 @@ class DataStreamCorrelate(
override def toString: String = {
val rexCall = scan.getCall.asInstanceOf[RexCall]
val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
- correlateToString(rexCall, sqlFunction)
+ correlateToString(inputSchema.relDataType, rexCall, sqlFunction, getExpressionString)
}
override def explainTerms(pw: RelWriter): RelWriter = {
@@ -77,7 +77,11 @@ class DataStreamCorrelate(
val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
super.explainTerms(pw)
.item("invocation", scan.getCall)
- .item("function", sqlFunction.getTableFunction.getClass.getCanonicalName)
+ .item("correlate", correlateToString(
+ inputSchema.relDataType,
+ rexCall, sqlFunction,
+ getExpressionString))
+ .item("select", selectToString(schema.relDataType))
.item("rowType", schema.relDataType)
.item("joinType", joinType)
.itemIf("condition", condition.orNull, condition.isDefined)
@@ -130,7 +134,13 @@ class DataStreamCorrelate(
.process(processFunc)
// preserve input parallelism to ensure that acc and retract messages remain in order
.setParallelism(inputParallelism)
- .name(correlateOpName(rexCall, sqlFunction, schema.relDataType))
+ .name(correlateOpName(
+ inputSchema.relDataType,
+ rexCall,
+ sqlFunction,
+ schema.relDataType,
+ getExpressionString)
+ )
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalUnnestRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalUnnestRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalUnnestRule.scala
index 802fd85..23dfc03 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalUnnestRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalUnnestRule.scala
@@ -84,6 +84,7 @@ class LogicalUnnestRule(
// create table function
val explodeTableFunc = UserDefinedFunctionUtils.createTableSqlFunction(
"explode",
+ "explode",
ExplodeFunctionUtil.explodeTableFuncFromType(arrayType.typeInfo),
FlinkTypeFactory.toTypeInfo(arrayType.getComponentType),
cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory])
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java
index 14f812a..61f43dc 100644
--- a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java
+++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java
@@ -110,6 +110,11 @@ public class JavaUserDefinedAggFunctions {
acc.sum += a.sum;
}
}
+
+ @Override
+ public String toString() {
+ return "myWeightedAvg";
+ }
}
/**
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/TableSourceTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/TableSourceTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/TableSourceTest.scala
index 59d2a47..486b078 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/TableSourceTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/TableSourceTest.scala
@@ -238,7 +238,7 @@ class TableSourceTest extends TableTestBase {
Array("name", "id", "amount", "price"),
"'amount > 2"),
term("select", "price", "id", "amount"),
- term("where", s"<(${func.functionIdentifier}(amount), 32)")
+ term("where", s"<(${Func0.getClass.getSimpleName}(amount), 32)")
)
util.verifyTable(result, expected)
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/CorrelateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/CorrelateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/CorrelateTest.scala
index a71f11c..6942a4e 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/CorrelateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/CorrelateTest.scala
@@ -42,7 +42,8 @@ class CorrelateTest extends TableTestBase {
"DataSetCorrelate",
batchTableNode(0),
term("invocation", "func1($cor0.c)"),
- term("function", func1.getClass.getCanonicalName),
+ term("correlate", s"table(func1($$cor0.c))"),
+ term("select", "a", "b", "c", "f0"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) f0)"),
term("joinType", "INNER")
@@ -62,7 +63,8 @@ class CorrelateTest extends TableTestBase {
"DataSetCorrelate",
batchTableNode(0),
term("invocation", "func1($cor0.c, '$')"),
- term("function", func1.getClass.getCanonicalName),
+ term("correlate", s"table(func1($$cor0.c, '$$'))"),
+ term("select", "a", "b", "c", "f0"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) f0)"),
term("joinType", "INNER")
@@ -88,7 +90,8 @@ class CorrelateTest extends TableTestBase {
"DataSetCorrelate",
batchTableNode(0),
term("invocation", "func1($cor0.c)"),
- term("function", func1.getClass.getCanonicalName),
+ term("correlate", s"table(func1($$cor0.c))"),
+ term("select", "a", "b", "c", "f0"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) f0)"),
term("joinType", "LEFT")
@@ -114,7 +117,8 @@ class CorrelateTest extends TableTestBase {
"DataSetCorrelate",
batchTableNode(0),
term("invocation", "func2($cor0.c)"),
- term("function", func2.getClass.getCanonicalName),
+ term("correlate", s"table(func2($$cor0.c))"),
+ term("select", "a", "b", "c", "f0", "f1"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, " +
"VARCHAR(65536) f0, INTEGER f1)"),
@@ -141,7 +145,8 @@ class CorrelateTest extends TableTestBase {
"DataSetCorrelate",
batchTableNode(0),
term("invocation", "hierarchy($cor0.c)"),
- term("function", function.getClass.getCanonicalName),
+ term("correlate", s"table(hierarchy($$cor0.c))"),
+ term("select", "a", "b", "c", "f0", "f1", "f2"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c," +
" VARCHAR(65536) f0, BOOLEAN f1, INTEGER f2)"),
@@ -168,7 +173,8 @@ class CorrelateTest extends TableTestBase {
"DataSetCorrelate",
batchTableNode(0),
term("invocation", "pojo($cor0.c)"),
- term("function", function.getClass.getCanonicalName),
+ term("correlate", s"table(pojo($$cor0.c))"),
+ term("select", "a", "b", "c", "age", "name"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c," +
" INTEGER age, VARCHAR(65536) name)"),
@@ -196,7 +202,8 @@ class CorrelateTest extends TableTestBase {
"DataSetCorrelate",
batchTableNode(0),
term("invocation", "func2($cor0.c)"),
- term("function", func2.getClass.getCanonicalName),
+ term("correlate", s"table(func2($$cor0.c))"),
+ term("select", "a", "b", "c", "f0", "f1"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, " +
"VARCHAR(65536) f0, INTEGER f1)"),
@@ -224,7 +231,8 @@ class CorrelateTest extends TableTestBase {
"DataSetCorrelate",
batchTableNode(0),
term("invocation", "func1(SUBSTRING($cor0.c, 2))"),
- term("function", func1.getClass.getCanonicalName),
+ term("correlate", s"table(func1(SUBSTRING($$cor0.c, 2)))"),
+ term("select", "a", "b", "c", "f0"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) f0)"),
term("joinType", "INNER")
@@ -250,7 +258,8 @@ class CorrelateTest extends TableTestBase {
"DataSetCorrelate",
batchTableNode(0),
term("invocation", "func1('hello', 'world', $cor0.c)"),
- term("function", func1.getClass.getCanonicalName),
+ term("correlate", s"table(func1('hello', 'world', $$cor0.c))"),
+ term("select", "a", "b", "c", "f0"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) f0)"),
term("joinType", "INNER")
@@ -272,7 +281,8 @@ class CorrelateTest extends TableTestBase {
"DataSetCorrelate",
batchTableNode(0),
term("invocation", "func2('hello', 'world', $cor0.c)"),
- term("function", func2.getClass.getCanonicalName),
+ term("correlate", s"table(func2('hello', 'world', $$cor0.c))"),
+ term("select", "a", "b", "c", "f0"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) f0)"),
term("joinType", "INNER")
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CalcTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CalcTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CalcTest.scala
index ee05547..ff6dcf1 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CalcTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CalcTest.scala
@@ -88,10 +88,10 @@ class CalcTest extends TableTestBase {
"DataSetCalc",
batchTableNode(0),
term("select",
- s"${giveMeCaseClass.functionIdentifier}().my AS _c0",
- s"${giveMeCaseClass.functionIdentifier}().clazz AS _c1",
- s"${giveMeCaseClass.functionIdentifier}().my AS _c2",
- s"${giveMeCaseClass.functionIdentifier}().clazz AS _c3"
+ "giveMeCaseClass$().my AS _c0",
+ "giveMeCaseClass$().clazz AS _c1",
+ "giveMeCaseClass$().my AS _c2",
+ "giveMeCaseClass$().clazz AS _c3"
)
)
@@ -171,7 +171,7 @@ class CalcTest extends TableTestBase {
val expected = unaryNode(
"DataSetCalc",
batchTableNode(0),
- term("select", s"${MyHashCode.functionIdentifier}(c) AS _c0", "b")
+ term("select", "MyHashCode$(c) AS _c0", "b")
)
util.verifyTable(resultTable, expected)
@@ -283,7 +283,7 @@ class CalcTest extends TableTestBase {
unaryNode(
"DataSetCalc",
batchTableNode(0),
- term("select", "a", "c", s"${MyHashCode.functionIdentifier}(c) AS k")
+ term("select", "a", "c", "MyHashCode$(c) AS k")
),
term("groupBy", "k"),
term("select", "k", "SUM(a) AS TMP_0")
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CorrelateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CorrelateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CorrelateTest.scala
index 15f3def..0b48070 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CorrelateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CorrelateTest.scala
@@ -21,7 +21,6 @@ package org.apache.flink.table.api.batch.table
import org.apache.flink.api.scala._
import org.apache.flink.table.api.scala._
import org.apache.flink.table.utils.TableTestUtil._
-import org.apache.flink.table.runtime.utils._
import org.apache.flink.table.utils.{TableFunc1, TableTestBase}
import org.junit.Test
@@ -41,7 +40,8 @@ class CorrelateTest extends TableTestBase {
"DataSetCorrelate",
batchTableNode(0),
term("invocation", s"${function.functionIdentifier}($$2)"),
- term("function", function),
+ term("correlate", s"table(${function.getClass.getSimpleName}(c))"),
+ term("select", "a", "b", "c", "s"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) s)"),
term("joinType", "INNER")
@@ -61,7 +61,8 @@ class CorrelateTest extends TableTestBase {
"DataSetCorrelate",
batchTableNode(0),
term("invocation", s"${function.functionIdentifier}($$2, '$$')"),
- term("function", function),
+ term("correlate", s"table(${function.getClass.getSimpleName}(c, '$$'))"),
+ term("select", "a", "b", "c", "s"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) s)"),
term("joinType", "INNER")
@@ -86,7 +87,8 @@ class CorrelateTest extends TableTestBase {
"DataSetCorrelate",
batchTableNode(0),
term("invocation", s"${function.functionIdentifier}($$2)"),
- term("function", function),
+ term("correlate", s"table(${function.getClass.getSimpleName}(c))"),
+ term("select", "a", "b", "c", "s"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) s)"),
term("joinType", "LEFT")
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/GroupWindowTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/GroupWindowTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/GroupWindowTest.scala
index e441203..6a2f1a7 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/GroupWindowTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/GroupWindowTest.scala
@@ -71,7 +71,7 @@ class GroupWindowTest extends TableTestBase {
batchTableNode(0),
term("groupBy", "string"),
term("window", TumblingGroupWindow(WindowReference("w"), 'long, 5.milli)),
- term("select", "string", "WeightedAvgWithMerge(long, int) AS TMP_0")
+ term("select", "string", "myWeightedAvg(long, int) AS TMP_0")
)
util.verifyTable(windowedTable, expected)
@@ -212,7 +212,7 @@ class GroupWindowTest extends TableTestBase {
term("groupBy", "string"),
term("window",
SlidingGroupWindow(WindowReference("w"), 'long, 8.milli, 10.milli)),
- term("select", "string", "WeightedAvgWithMerge(long, int) AS TMP_0")
+ term("select", "string", "myWeightedAvg(long, int) AS TMP_0")
)
util.verifyTable(windowedTable, expected)
@@ -310,7 +310,7 @@ class GroupWindowTest extends TableTestBase {
batchTableNode(0),
term("groupBy", "string"),
term("window", SessionGroupWindow(WindowReference("w"), 'long, 7.milli)),
- term("select", "string", "WeightedAvgWithMerge(long, int) AS TMP_0")
+ term("select", "string", "myWeightedAvg(long, int) AS TMP_0")
)
util.verifyTable(windowedTable, expected)
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/CorrelateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/CorrelateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/CorrelateTest.scala
index 955ed4b..ec61816 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/CorrelateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/CorrelateTest.scala
@@ -42,7 +42,8 @@ class CorrelateTest extends TableTestBase {
"DataStreamCorrelate",
streamTableNode(0),
term("invocation", "func1($cor0.c)"),
- term("function", func1.getClass.getCanonicalName),
+ term("correlate", s"table(func1($$cor0.c))"),
+ term("select", "a", "b", "c", "f0"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) f0)"),
term("joinType", "INNER")
@@ -62,7 +63,8 @@ class CorrelateTest extends TableTestBase {
"DataStreamCorrelate",
streamTableNode(0),
term("invocation", "func1($cor0.c, '$')"),
- term("function", func1.getClass.getCanonicalName),
+ term("correlate", s"table(func1($$cor0.c, '$$'))"),
+ term("select", "a", "b", "c", "f0"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) f0)"),
term("joinType", "INNER")
@@ -88,7 +90,8 @@ class CorrelateTest extends TableTestBase {
"DataStreamCorrelate",
streamTableNode(0),
term("invocation", "func1($cor0.c)"),
- term("function", func1.getClass.getCanonicalName),
+ term("correlate", s"table(func1($$cor0.c))"),
+ term("select", "a", "b", "c", "f0"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) f0)"),
term("joinType", "LEFT")
@@ -114,7 +117,8 @@ class CorrelateTest extends TableTestBase {
"DataStreamCorrelate",
streamTableNode(0),
term("invocation", "func2($cor0.c)"),
- term("function", func2.getClass.getCanonicalName),
+ term("correlate", s"table(func2($$cor0.c))"),
+ term("select", "a", "b", "c", "f0", "f1"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, " +
"VARCHAR(65536) f0, INTEGER f1)"),
@@ -141,7 +145,8 @@ class CorrelateTest extends TableTestBase {
"DataStreamCorrelate",
streamTableNode(0),
term("invocation", "hierarchy($cor0.c)"),
- term("function", function.getClass.getCanonicalName),
+ term("correlate", s"table(hierarchy($$cor0.c))"),
+ term("select", "a", "b", "c", "f0", "f1", "f2"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c," +
" VARCHAR(65536) f0, BOOLEAN f1, INTEGER f2)"),
@@ -168,7 +173,8 @@ class CorrelateTest extends TableTestBase {
"DataStreamCorrelate",
streamTableNode(0),
term("invocation", "pojo($cor0.c)"),
- term("function", function.getClass.getCanonicalName),
+ term("correlate", s"table(pojo($$cor0.c))"),
+ term("select", "a", "b", "c", "age", "name"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c," +
" INTEGER age, VARCHAR(65536) name)"),
@@ -196,7 +202,8 @@ class CorrelateTest extends TableTestBase {
"DataStreamCorrelate",
streamTableNode(0),
term("invocation", "func2($cor0.c)"),
- term("function", func2.getClass.getCanonicalName),
+ term("correlate", s"table(func2($$cor0.c))"),
+ term("select", "a", "b", "c", "f0", "f1"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, " +
"VARCHAR(65536) f0, INTEGER f1)"),
@@ -224,7 +231,8 @@ class CorrelateTest extends TableTestBase {
"DataStreamCorrelate",
streamTableNode(0),
term("invocation", "func1(SUBSTRING($cor0.c, 2))"),
- term("function", func1.getClass.getCanonicalName),
+ term("correlate", s"table(func1(SUBSTRING($$cor0.c, 2)))"),
+ term("select", "a", "b", "c", "f0"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) f0)"),
term("joinType", "INNER")
@@ -250,7 +258,8 @@ class CorrelateTest extends TableTestBase {
"DataStreamCorrelate",
streamTableNode(0),
term("invocation", "func1('hello', 'world', $cor0.c)"),
- term("function", func1.getClass.getCanonicalName),
+ term("correlate", s"table(func1('hello', 'world', $$cor0.c))"),
+ term("select", "a", "b", "c", "f0"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) f0)"),
term("joinType", "INNER")
@@ -272,7 +281,8 @@ class CorrelateTest extends TableTestBase {
"DataStreamCorrelate",
streamTableNode(0),
term("invocation", "func2('hello', 'world', $cor0.c)"),
- term("function", func2.getClass.getCanonicalName),
+ term("correlate", s"table(func2('hello', 'world', $$cor0.c))"),
+ term("select", "a", "b", "c", "f0"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) f0)"),
term("joinType", "INNER")
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/CorrelateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/CorrelateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/CorrelateTest.scala
index f15dea9..9d9d1db 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/CorrelateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/CorrelateTest.scala
@@ -19,8 +19,8 @@ package org.apache.flink.table.api.stream.table
import org.apache.flink.api.scala._
import org.apache.flink.table.api.scala._
+import org.apache.flink.table.expressions.utils.Func13
import org.apache.flink.table.utils.TableTestUtil._
-import org.apache.flink.table.runtime.utils._
import org.apache.flink.table.utils._
import org.junit.Test
@@ -40,7 +40,8 @@ class CorrelateTest extends TableTestBase {
"DataStreamCorrelate",
streamTableNode(0),
term("invocation", s"${function.functionIdentifier}($$2)"),
- term("function", function),
+ term("correlate", s"table(${function.getClass.getSimpleName}(c))"),
+ term("select", "a", "b", "c", "s"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) s)"),
term("joinType", "INNER")
@@ -60,7 +61,8 @@ class CorrelateTest extends TableTestBase {
"DataStreamCorrelate",
streamTableNode(0),
term("invocation", s"${function.functionIdentifier}($$2, '$$')"),
- term("function", function),
+ term("correlate", s"table(${function.getClass.getSimpleName}(c, '$$'))"),
+ term("select", "a", "b", "c", "s"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) s)"),
term("joinType", "INNER")
@@ -85,7 +87,8 @@ class CorrelateTest extends TableTestBase {
"DataStreamCorrelate",
streamTableNode(0),
term("invocation", s"${function.functionIdentifier}($$2)"),
- term("function", function),
+ term("correlate", s"table(${function.getClass.getSimpleName}(c))"),
+ term("select", "a", "b", "c", "s"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) s)"),
term("joinType", "LEFT")
@@ -101,16 +104,19 @@ class CorrelateTest extends TableTestBase {
val util = streamTestUtil()
val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
val function = util.addFunction("func2", new TableFunc2)
+ val scalarFunc = new Func13("pre")
- val result = table.join(function('c) as ('name, 'len)).select('c, 'name, 'len)
+ val result = table.join(function(scalarFunc('c)) as ('name, 'len)).select('c, 'name, 'len)
val expected = unaryNode(
"DataStreamCalc",
unaryNode(
"DataStreamCorrelate",
streamTableNode(0),
- term("invocation", s"${function.functionIdentifier}($$2)"),
- term("function", function),
+ term("invocation",
+ s"${function.functionIdentifier}(${scalarFunc.functionIdentifier}($$2))"),
+ term("correlate", s"table(${function.getClass.getSimpleName}(Func13(c)))"),
+ term("select", "a", "b", "c", "name", "len"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, " +
"VARCHAR(65536) name, INTEGER len)"),
@@ -134,7 +140,8 @@ class CorrelateTest extends TableTestBase {
"DataStreamCorrelate",
streamTableNode(0),
term("invocation", s"${function.functionIdentifier}($$2)"),
- term("function", function),
+ term("correlate", "table(HierarchyTableFunction(c))"),
+ term("select", "a", "b", "c", "name", "adult", "len"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c," +
" VARCHAR(65536) name, BOOLEAN adult, INTEGER len)"),
@@ -156,7 +163,8 @@ class CorrelateTest extends TableTestBase {
"DataStreamCorrelate",
streamTableNode(0),
term("invocation", s"${function.functionIdentifier}($$2)"),
- term("function", function),
+ term("correlate", s"table(${function.getClass.getSimpleName}(c))"),
+ term("select", "a", "b", "c", "age", "name"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, " +
"INTEGER age, VARCHAR(65536) name)"),
@@ -183,7 +191,8 @@ class CorrelateTest extends TableTestBase {
"DataStreamCorrelate",
streamTableNode(0),
term("invocation", s"${function.functionIdentifier}($$2)"),
- term("function", function),
+ term("correlate", s"table(${function.getClass.getSimpleName}(c))"),
+ term("select", "a", "b", "c", "name", "len"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, " +
"VARCHAR(65536) name, INTEGER len)"),
@@ -208,7 +217,9 @@ class CorrelateTest extends TableTestBase {
"DataStreamCorrelate",
streamTableNode(0),
term("invocation", s"${function.functionIdentifier}(SUBSTRING($$2, 2, CHAR_LENGTH($$2)))"),
- term("function", function),
+ term("correlate",
+ s"table(${function.getClass.getSimpleName}(SUBSTRING(c, 2, CHAR_LENGTH(c))))"),
+ term("select", "a", "b", "c", "s"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) s)"),
term("joinType", "INNER")
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/GroupWindowTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/GroupWindowTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/GroupWindowTest.scala
index 599c76b..260726b 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/GroupWindowTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/GroupWindowTest.scala
@@ -181,7 +181,7 @@ class GroupWindowTest extends TableTestBase {
WindowReference("w"),
'rowtime,
5.milli)),
- term("select", "string", "WeightedAvgWithMerge(long, int) AS TMP_0")
+ term("select", "string", "myWeightedAvg(long, int) AS TMP_0")
)
util.verifyTable(windowedTable, expected)
@@ -319,7 +319,7 @@ class GroupWindowTest extends TableTestBase {
streamTableNode(0),
term("groupBy", "string"),
term("window", SlidingGroupWindow(WindowReference("w"), 'rowtime, 8.milli, 10.milli)),
- term("select", "string", "WeightedAvgWithMerge(long, int) AS TMP_0")
+ term("select", "string", "myWeightedAvg(long, int) AS TMP_0")
)
util.verifyTable(windowedTable, expected)
@@ -363,7 +363,7 @@ class GroupWindowTest extends TableTestBase {
streamTableNode(0),
term("groupBy", "string"),
term("window", SessionGroupWindow(WindowReference("w"), 'rowtime, 7.milli)),
- term("select", "string", "WeightedAvgWithMerge(long, int) AS TMP_0")
+ term("select", "string", "myWeightedAvg(long, int) AS TMP_0")
)
util.verifyTable(windowedTable, expected)
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/OverWindowTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/OverWindowTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/OverWindowTest.scala
index 8b563a3..55e3ecb 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/OverWindowTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/OverWindowTest.scala
@@ -23,7 +23,6 @@ import org.apache.flink.table.api.scala._
import org.apache.flink.table.expressions.utils.Func1
import org.apache.flink.table.api.Table
import org.apache.flink.table.utils.TableTestUtil._
-import org.apache.flink.table.utils.StreamTableTestUtil
import org.apache.flink.table.utils.{StreamTableTestUtil, TableTestBase}
import org.junit.Test
@@ -65,7 +64,7 @@ class OverWindowTest extends TableTestBase {
"WeightedAvgWithRetract(c, a) AS w0$o2")
),
term("select",
- s"${plusOne.functionIdentifier}(w0$$o0) AS d",
+ s"Func1$$(w0$$o0) AS d",
"EXP(CAST(w0$o1)) AS _c1",
"+(w0$o2, 1) AS _c2",
"||('AVG:', CAST(w0$o2)) AS _c3",
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/ExpressionReductionRulesTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/ExpressionReductionRulesTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/ExpressionReductionRulesTest.scala
index b4ad9ca..ce4de14 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/ExpressionReductionRulesTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/ExpressionReductionRulesTest.scala
@@ -491,7 +491,7 @@ class ExpressionReductionRulesTest extends TableTestBase {
"DataStreamCalc",
streamTableNode(0),
term("select", "a", "b", "c"),
- term("where", s"IS NULL(${NonDeterministicNullFunc.functionIdentifier}())")
+ term("where", s"IS NULL(NonDeterministicNullFunc$$())")
)
util.verifyTable(result, expected)
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala
index cfff326..1714ec8 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala
@@ -160,7 +160,8 @@ class TimeIndicatorConversionTest extends TableTestBase {
streamTableNode(0),
term("invocation",
s"${func.functionIdentifier}(CAST($$0):TIMESTAMP(3) NOT NULL, PROCTIME($$3), '')"),
- term("function", func),
+ term("correlate", s"table(TableFunc(CAST(rowtime), PROCTIME(proctime), ''))"),
+ term("select", "rowtime", "long", "int", "proctime", "s"),
term("rowType", "RecordType(TIME ATTRIBUTE(ROWTIME) rowtime, BIGINT long, INTEGER int, " +
"TIME ATTRIBUTE(PROCTIME) proctime, VARCHAR(65536) s)"),
term("joinType", "INNER")
http://git-wip-us.apache.org/repos/asf/flink/blob/427dfe42/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/CalcITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/CalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/CalcITCase.scala
index c62349c..480d817 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/CalcITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/CalcITCase.scala
@@ -284,7 +284,7 @@ class CalcITCase extends StreamingMultipleProgramsTestBase {
val func1 = new Func13("Sunny")
val func2 = new Func13("kevin2")
- val result = t.select(func0('c), func1('c),func2('c))
+ val result = t.select(func0('c), func1('c), func2('c))
result.addSink(new StreamITCase.StringSink[Row])
env.execute()
[5/5] flink git commit: [FLINK-6233] [table] Add more tests for rowtime window join + minor refactoring.
Posted by fh...@apache.org.
[FLINK-6233] [table] Add more tests for rowtime window join + minor refactoring.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/1ea7f49a
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/1ea7f49a
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/1ea7f49a
Branch: refs/heads/master
Commit: 1ea7f49a5030ae481122d34915ca14d30b8626f5
Parents: 655d8b1
Author: Fabian Hueske <fh...@apache.org>
Authored: Tue Oct 10 14:48:24 2017 +0200
Committer: Fabian Hueske <fh...@apache.org>
Committed: Tue Oct 10 23:09:07 2017 +0200
----------------------------------------------------------------------
.../nodes/datastream/DataStreamWindowJoin.scala | 16 +-
.../datastream/DataStreamWindowJoinRule.scala | 17 +-
.../join/ProcTimeBoundedStreamInnerJoin.scala | 17 +-
.../join/RowTimeBoundedStreamInnerJoin.scala | 18 +-
.../join/TimeBoundedStreamInnerJoin.scala | 38 ++--
.../table/runtime/join/WindowJoinUtil.scala | 21 --
.../flink/table/api/stream/sql/JoinTest.scala | 93 ++++++++-
.../table/runtime/harness/HarnessTestBase.scala | 20 ++
.../table/runtime/harness/JoinHarnessTest.scala | 53 +++--
.../table/runtime/stream/sql/JoinITCase.scala | 209 ++++++++++++++-----
10 files changed, 368 insertions(+), 134 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/1ea7f49a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala
index 9358aa3..3e23006 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala
@@ -136,6 +136,11 @@ class DataStreamWindowJoin(
remainCondition,
ruleDescription)
+ val joinOpName =
+ s"where: (" +
+ s"${joinConditionToString(schema.relDataType, joinCondition, getExpressionString)}), " +
+ s"join: (${joinSelectionToString(schema.relDataType)})"
+
joinType match {
case JoinRelType.INNER =>
if (relativeWindowSize < 0) {
@@ -148,6 +153,7 @@ class DataStreamWindowJoin(
leftDataStream,
rightDataStream,
returnTypeInfo,
+ joinOpName,
joinFunction.name,
joinFunction.code,
leftKeys,
@@ -158,6 +164,7 @@ class DataStreamWindowJoin(
leftDataStream,
rightDataStream,
returnTypeInfo,
+ joinOpName,
joinFunction.name,
joinFunction.code,
leftKeys,
@@ -202,6 +209,7 @@ class DataStreamWindowJoin(
leftDataStream: DataStream[CRow],
rightDataStream: DataStream[CRow],
returnTypeInfo: TypeInformation[CRow],
+ operatorName: String,
joinFunctionName: String,
joinFunctionCode: String,
leftKeys: Array[Int],
@@ -210,7 +218,6 @@ class DataStreamWindowJoin(
val procInnerJoinFunc = new ProcTimeBoundedStreamInnerJoin(
leftLowerBound,
leftUpperBound,
- allowedLateness = 0L,
leftSchema.typeInfo,
rightSchema.typeInfo,
joinFunctionName,
@@ -220,6 +227,7 @@ class DataStreamWindowJoin(
leftDataStream.connect(rightDataStream)
.keyBy(leftKeys, rightKeys)
.process(procInnerJoinFunc)
+ .name(operatorName)
.returns(returnTypeInfo)
} else {
leftDataStream.connect(rightDataStream)
@@ -227,6 +235,7 @@ class DataStreamWindowJoin(
.process(procInnerJoinFunc)
.setParallelism(1)
.setMaxParallelism(1)
+ .name(operatorName)
.returns(returnTypeInfo)
}
}
@@ -235,6 +244,7 @@ class DataStreamWindowJoin(
leftDataStream: DataStream[CRow],
rightDataStream: DataStream[CRow],
returnTypeInfo: TypeInformation[CRow],
+ operatorName: String,
joinFunctionName: String,
joinFunctionCode: String,
leftKeys: Array[Int],
@@ -256,7 +266,7 @@ class DataStreamWindowJoin(
.connect(rightDataStream)
.keyBy(leftKeys, rightKeys)
.transform(
- "InnerRowtimeWindowJoin",
+ operatorName,
returnTypeInfo,
new KeyedCoProcessOperatorWithWatermarkDelay[Tuple, CRow, CRow, CRow](
rowTimeInnerJoinFunc,
@@ -266,7 +276,7 @@ class DataStreamWindowJoin(
leftDataStream.connect(rightDataStream)
.keyBy(new NullByteKeySelector[CRow](), new NullByteKeySelector[CRow])
.transform(
- "InnerRowtimeWindowJoin",
+ operatorName,
returnTypeInfo,
new KeyedCoProcessOperatorWithWatermarkDelay[java.lang.Byte, CRow, CRow, CRow](
rowTimeInnerJoinFunc,
http://git-wip-us.apache.org/repos/asf/flink/blob/1ea7f49a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala
index d208d2b..a7358c7 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala
@@ -41,29 +41,22 @@ class DataStreamWindowJoinRule
override def matches(call: RelOptRuleCall): Boolean = {
val join: FlinkLogicalJoin = call.rel(0).asInstanceOf[FlinkLogicalJoin]
- val (windowBounds, remainingPreds) = WindowJoinUtil.extractWindowBoundsFromPredicate(
+ val (windowBounds, _) = WindowJoinUtil.extractWindowBoundsFromPredicate(
join.getCondition,
join.getLeft.getRowType.getFieldCount,
join.getRowType,
join.getCluster.getRexBuilder,
TableConfig.DEFAULT)
- // remaining predicate must not access time attributes
- val remainingPredsAccessTime = remainingPreds.isDefined &&
- WindowJoinUtil.accessesTimeAttribute(remainingPreds.get, join.getRowType)
-
if (windowBounds.isDefined) {
if (windowBounds.get.isEventTime) {
- !remainingPredsAccessTime
+ true
} else {
- // Check that no event-time attributes are in the input.
- // The proc-time join implementation does ensure that record timestamp are correctly set.
- // It is always the timestamp of the later arriving record.
+ // Check that no event-time attributes are in the input because the processing time window
+ // join does not correctly hold back watermarks.
// We rely on projection pushdown to remove unused attributes before the join.
- val rowTimeAttrInOutput = join.getRowType.getFieldList.asScala
+ !join.getRowType.getFieldList.asScala
.exists(f => FlinkTypeFactory.isRowtimeIndicatorType(f.getType))
-
- !remainingPredsAccessTime && !rowTimeAttrInOutput
}
} else {
// the given join does not have valid window bounds. We cannot translate it.
http://git-wip-us.apache.org/repos/asf/flink/blob/1ea7f49a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala
index ab5a9c3..3bac42c 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala
@@ -29,19 +29,18 @@ import org.apache.flink.types.Row
final class ProcTimeBoundedStreamInnerJoin(
leftLowerBound: Long,
leftUpperBound: Long,
- allowedLateness: Long,
leftType: TypeInformation[Row],
rightType: TypeInformation[Row],
genJoinFuncName: String,
genJoinFuncCode: String)
- extends TimeBoundedStreamInnerJoin(
- leftLowerBound,
- leftUpperBound,
- allowedLateness,
- leftType,
- rightType,
- genJoinFuncName,
- genJoinFuncCode) {
+ extends TimeBoundedStreamInnerJoin(
+ leftLowerBound,
+ leftUpperBound,
+ allowedLateness = 0L,
+ leftType,
+ rightType,
+ genJoinFuncName,
+ genJoinFuncCode) {
override def updateOperatorTime(ctx: CoProcessFunction[CRow, CRow, CRow]#Context): Unit = {
leftOperatorTime = ctx.timerService().currentProcessingTime()
http://git-wip-us.apache.org/repos/asf/flink/blob/1ea7f49a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala
index 5cf5a53..a2d9dca 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala
@@ -36,18 +36,18 @@ final class RowTimeBoundedStreamInnerJoin(
genJoinFuncCode: String,
leftTimeIdx: Int,
rightTimeIdx: Int)
- extends TimeBoundedStreamInnerJoin(
- leftLowerBound,
- leftUpperBound,
- allowedLateness,
- leftType,
- rightType,
- genJoinFuncName,
- genJoinFuncCode) {
+ extends TimeBoundedStreamInnerJoin(
+ leftLowerBound,
+ leftUpperBound,
+ allowedLateness,
+ leftType,
+ rightType,
+ genJoinFuncName,
+ genJoinFuncCode) {
/**
* Get the maximum interval between receiving a row and emitting it (as part of a joined result).
- * Only reasonable for row time join.
+ * This is the time interval by which watermarks need to be held back.
*
* @return the maximum delay for the outputs
*/
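As a point of reference, the delay documented here is what the plan translation feeds into KeyedCoProcessOperatorWithWatermarkDelay (see the DataStreamWindowJoin hunks above), so output watermarks are held back far enough that joined results are never late. A minimal sketch of the wiring, reusing the constructor arguments from the harness test further below (rowType and funcCode are test fixtures, not production values):

  // Hedged sketch: hold back output watermarks by the join's maximum output delay.
  val joinFunc = new RowTimeBoundedStreamInnerJoin(
    -10, 20, 0, rowType, rowType, "TestJoinFunction", funcCode, 0, 0)
  val operator = new KeyedCoProcessOperatorWithWatermarkDelay[String, CRow, CRow, CRow](
    joinFunc, joinFunc.getMaxOutputDelay)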
http://git-wip-us.apache.org/repos/asf/flink/blob/1ea7f49a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala
index 7bf3d33..9625eac 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala
@@ -38,15 +38,16 @@ import org.apache.flink.util.Collector
/**
* A CoProcessFunction to execute time-bounded stream inner-join.
* Two kinds of time criteria:
- * "L.time between R.time + X and R.time + Y" or "R.time between L.time - Y and L.time - X".
+ * "L.time between R.time + X and R.time + Y" or "R.time between L.time - Y and L.time - X" where
+ * X and Y might be negative or positive and X <= Y.
*
* @param leftLowerBound the lower bound for the left stream (X in the criteria)
* @param leftUpperBound the upper bound for the left stream (Y in the criteria)
* @param allowedLateness the lateness allowed for the two streams
* @param leftType the input type of left stream
* @param rightType the input type of right stream
- * @param genJoinFuncName the function code of other non-equi conditions
- * @param genJoinFuncCode the function name of other non-equi conditions
+ * @param genJoinFuncName the name of the generated function
+ * @param genJoinFuncCode the code of the generated function that evaluates the non-window join conditions
*
*/
abstract class TimeBoundedStreamInnerJoin(
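To make the criteria concrete: for the predicate used in the ITCases below, t1.rt BETWEEN t2.rt - INTERVAL '5' SECOND AND t2.rt + INTERVAL '6' SECOND, we have X = -5000 ms and Y = 6000 ms, which by the definitions in the next hunk gives:

  // leftLowerBound (X) = -5000, leftUpperBound (Y) = 6000; X <= Y holds.
  val leftRelativeSize: Long = -leftLowerBound   // 5000 ms
  val rightRelativeSize: Long = leftUpperBound   // 6000 ms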
@@ -57,9 +58,9 @@ abstract class TimeBoundedStreamInnerJoin(
private val rightType: TypeInformation[Row],
private val genJoinFuncName: String,
private val genJoinFuncCode: String)
- extends CoProcessFunction[CRow, CRow, CRow]
- with Compiler[FlatJoinFunction[Row, Row, Row]]
- with Logging {
+ extends CoProcessFunction[CRow, CRow, CRow]
+ with Compiler[FlatJoinFunction[Row, Row, Row]]
+ with Logging {
private var cRowWrapper: CRowWrappingCollector = _
@@ -79,15 +80,16 @@ abstract class TimeBoundedStreamInnerJoin(
protected val leftRelativeSize: Long = -leftLowerBound
protected val rightRelativeSize: Long = leftUpperBound
+ // Points in time until which the respective cache has been cleaned.
private var leftExpirationTime: Long = 0L
private var rightExpirationTime: Long = 0L
+ // Current time on the respective input stream.
protected var leftOperatorTime: Long = 0L
protected var rightOperatorTime: Long = 0L
-
- // for delayed cleanup
- private val cleanupDelay = (leftRelativeSize + rightRelativeSize) / 2
+ // Minimum interval by which state is cleaned up
+ private val minCleanUpInterval = (leftRelativeSize + rightRelativeSize) / 2
if (allowedLateness < 0) {
throw new IllegalArgumentException("The allowed lateness must be non-negative.")
@@ -140,12 +142,14 @@ abstract class TimeBoundedStreamInnerJoin(
cRowValue: CRow,
ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
out: Collector[CRow]): Unit = {
+
updateOperatorTime(ctx)
val leftRow = cRowValue.row
val timeForLeftRow: Long = getTimeForLeftStream(ctx, leftRow)
val rightQualifiedLowerBound: Long = timeForLeftRow - rightRelativeSize
val rightQualifiedUpperBound: Long = timeForLeftRow + leftRelativeSize
cRowWrapper.out = out
+
// Check if we need to cache the current row.
if (rightOperatorTime < rightQualifiedUpperBound) {
// Operator time of right stream has not exceeded the upper window bound of the current
@@ -164,7 +168,7 @@ abstract class TimeBoundedStreamInnerJoin(
}
// Check if we need to join the current row against cached rows of the right input.
// The condition here should be rightMinimumTime < rightQualifiedUpperBound.
- // I use rightExpirationTime as an approximation of the rightMinimumTime here,
+ // We use rightExpirationTime as an approximation of the rightMinimumTime here,
// since rightExpirationTime <= rightMinimumTime is always true.
if (rightExpirationTime < rightQualifiedUpperBound) {
// Upper bound of current join window has not passed the cache expiration time yet.
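Continuing the X = -5000 / Y = 6000 example: a left row with timestamp 2000 is joined against the right cache within the bounds computed in processElement1:

  // rightQualifiedLowerBound = timeForLeftRow - rightRelativeSize = 2000 - 6000 = -4000
  // rightQualifiedUpperBound = timeForLeftRow + leftRelativeSize  = 2000 + 5000 =  7000
  // => the left row matches cached right rows with timestamps in [-4000, 7000].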
@@ -199,12 +203,14 @@ abstract class TimeBoundedStreamInnerJoin(
cRowValue: CRow,
ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
out: Collector[CRow]): Unit = {
+
updateOperatorTime(ctx)
val rightRow = cRowValue.row
val timeForRightRow: Long = getTimeForRightStream(ctx, rightRow)
val leftQualifiedLowerBound: Long = timeForRightRow - leftRelativeSize
val leftQualifiedUpperBound: Long = timeForRightRow + rightRelativeSize
cRowWrapper.out = out
+
// Check if we need to cache the current row.
if (leftOperatorTime < leftQualifiedUpperBound) {
// Operator time of left stream has not exceeded the upper window bound of the current
@@ -223,7 +229,7 @@ abstract class TimeBoundedStreamInnerJoin(
}
// Check if we need to join the current row against cached rows of the left input.
// The condition here should be leftMinimumTime < leftQualifiedUpperBound.
- // I use leftExpirationTime as an approximation of the leftMinimumTime here,
+ // We use leftExpirationTime as an approximation of the leftMinimumTime here,
// since leftExpirationTime <= leftMinimumTime is always true.
if (leftExpirationTime < leftQualifiedUpperBound) {
leftExpirationTime = calExpirationTime(rightOperatorTime, leftRelativeSize)
@@ -261,6 +267,7 @@ abstract class TimeBoundedStreamInnerJoin(
timestamp: Long,
ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext,
out: Collector[CRow]): Unit = {
+
updateOperatorTime(ctx)
// In the future, we should separate the left and right watermarks. Otherwise, the
// registered timer of the faster stream will be delayed, even if the watermarks have
@@ -316,11 +323,11 @@ abstract class TimeBoundedStreamInnerJoin(
rowTime: Long,
leftRow: Boolean): Unit = {
if (leftRow) {
- val cleanupTime = rowTime + leftRelativeSize + cleanupDelay + allowedLateness + 1
+ val cleanupTime = rowTime + leftRelativeSize + minCleanUpInterval + allowedLateness + 1
registerTimer(ctx, cleanupTime)
rightTimerState.update(cleanupTime)
} else {
- val cleanupTime = rowTime + rightRelativeSize + cleanupDelay + allowedLateness + 1
+ val cleanupTime = rowTime + rightRelativeSize + minCleanUpInterval + allowedLateness + 1
registerTimer(ctx, cleanupTime)
leftTimerState.update(cleanupTime)
}
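Worked through with the same example bounds (leftRelativeSize = 5000, rightRelativeSize = 6000, allowedLateness = 0):

  // minCleanUpInterval = (5000 + 6000) / 2 = 5500
  // For a left row with rowTime = 2000:
  // cleanupTime = 2000 + 5000 + 5500 + 0 + 1 = 12501
  // i.e. a cleanup timer fires no earlier than time 12501.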
@@ -361,6 +368,7 @@ abstract class TimeBoundedStreamInnerJoin(
}
}
}
+
if (earliestTimestamp > 0) {
// There are rows left in the cache. Register a timer to expire them later.
registerCleanUpTimer(
@@ -385,6 +393,8 @@ abstract class TimeBoundedStreamInnerJoin(
/**
* Return the time for the target row from the left stream.
*
+ * Requires that [[updateOperatorTime()]] has been called before.
+ *
* @param context the runtime context
* @param row the target row
* @return time for the target row
@@ -394,6 +404,8 @@ abstract class TimeBoundedStreamInnerJoin(
/**
* Return the time for the target row from the right stream.
*
+ * Requires that [[updateOperatorTime()]] has been called before.
+ *
* @param context the runtime context
* @param row the target row
* @return time for the target row
http://git-wip-us.apache.org/repos/asf/flink/blob/1ea7f49a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala
index 6f97f2a..863f342 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala
@@ -266,27 +266,6 @@ object WindowJoinUtil {
}
/**
- * Checks if an expression accesses a time attribute.
- *
- * @param expr The expression to check.
- * @param inputType The input type of the expression.
- * @return True, if the expression accesses a time attribute. False otherwise.
- */
- def accessesTimeAttribute(expr: RexNode, inputType: RelDataType): Boolean = {
- expr match {
- case i: RexInputRef =>
- val accessedType = inputType.getFieldList.get(i.getIndex).getType
- accessedType match {
- case _: TimeIndicatorRelDataType => true
- case _ => false
- }
- case c: RexCall =>
- c.operands.asScala.exists(accessesTimeAttribute(_, inputType))
- case _ => false
- }
- }
-
- /**
* Checks if an expression accesses a non-time attribute.
*
* @param expr The expression to check.
http://git-wip-us.apache.org/repos/asf/flink/blob/1ea7f49a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala
index a4234c5..53aff82 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala
@@ -20,8 +20,9 @@ package org.apache.flink.table.api.stream.sql
import org.apache.calcite.rel.logical.LogicalJoin
import org.apache.flink.api.scala._
import org.apache.flink.table.api.scala._
+import org.apache.flink.table.plan.logical.TumblingGroupWindow
import org.apache.flink.table.runtime.join.WindowJoinUtil
-import org.apache.flink.table.utils.TableTestUtil._
+import org.apache.flink.table.utils.TableTestUtil.{term, _}
import org.apache.flink.table.utils.{StreamTableTestUtil, TableTestBase}
import org.junit.Assert._
import org.junit.Test
@@ -184,6 +185,96 @@ class JoinTest extends TableTestBase {
}
@Test
+ def testRowTimeInnerJoinAndWindowAggregationOnFirst(): Unit = {
+
+ val sqlQuery =
+ """
+ |SELECT t1.b, SUM(t2.a) AS aSum, COUNT(t2.b) AS bCnt
+ |FROM MyTable t1, MyTable2 t2
+ |WHERE t1.a = t2.a AND
+ | t1.c BETWEEN t2.c - INTERVAL '10' MINUTE AND t2.c + INTERVAL '1' HOUR
+ |GROUP BY TUMBLE(t1.c, INTERVAL '6' HOUR), t1.b
+ |""".stripMargin
+
+ val expected =
+ unaryNode(
+ "DataStreamGroupWindowAggregate",
+ unaryNode(
+ "DataStreamCalc",
+ binaryNode(
+ "DataStreamWindowJoin",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "a", "b", "c")
+ ),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(1),
+ term("select", "a", "b", "c")
+ ),
+ term("where",
+ "AND(=(a, a0), >=(c, -(c0, 600000)), " +
+ "<=(c, DATETIME_PLUS(c0, 3600000)))"),
+ term("join", "a, b, c, a0, b0, c0"),
+ term("joinType", "InnerJoin")
+ ),
+ term("select", "c", "b", "a0", "b0")
+ ),
+ term("groupBy", "b"),
+ term("window", TumblingGroupWindow('w$, 'c, 21600000.millis)),
+ term("select", "b", "SUM(a0) AS aSum", "COUNT(b0) AS bCnt")
+ )
+
+ streamUtil.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testRowTimeInnerJoinAndWindowAggregationOnSecond(): Unit = {
+
+ val sqlQuery =
+ """
+ |SELECT t2.b, SUM(t1.a) AS aSum, COUNT(t1.b) AS bCnt
+ |FROM MyTable t1, MyTable2 t2
+ |WHERE t1.a = t2.a AND
+ | t1.c BETWEEN t2.c - INTERVAL '10' MINUTE AND t2.c + INTERVAL '1' HOUR
+ |GROUP BY TUMBLE(t2.c, INTERVAL '6' HOUR), t2.b
+ |""".stripMargin
+
+ val expected =
+ unaryNode(
+ "DataStreamGroupWindowAggregate",
+ unaryNode(
+ "DataStreamCalc",
+ binaryNode(
+ "DataStreamWindowJoin",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "a", "b", "c")
+ ),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(1),
+ term("select", "a", "b", "c")
+ ),
+ term("where",
+ "AND(=(a, a0), >=(c, -(c0, 600000)), " +
+ "<=(c, DATETIME_PLUS(c0, 3600000)))"),
+ term("join", "a, b, c, a0, b0, c0"),
+ term("joinType", "InnerJoin")
+ ),
+ term("select", "c0", "b0", "a", "b")
+ ),
+ term("groupBy", "b0"),
+ term("window", TumblingGroupWindow('w$, 'c0, 21600000.millis)),
+ term("select", "b0", "SUM(a) AS aSum", "COUNT(b) AS bCnt")
+ )
+
+ streamUtil.verifySql(sqlQuery, expected)
+ }
+
+ @Test
def testJoinTimeBoundary(): Unit = {
verifyTimeBoundary(
"t1.proctime between t2.proctime - interval '1' hour " +
http://git-wip-us.apache.org/repos/asf/flink/blob/1ea7f49a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
index 67164b7..942846c 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
@@ -355,6 +355,26 @@ object HarnessTestBase {
}
/**
+ * Compares watermarks by timestamp and records by their string representation;
+ * returns 0 for equal rows and orders watermarks before records.
+ */
+ class RowResultSortComparatorWithWatermarks()
+ extends Comparator[Object] with Serializable {
+
+ override def compare(o1: Object, o2: Object): Int = {
+
+ (o1, o2) match {
+ case (w1: Watermark, w2: Watermark) =>
+ w1.getTimestamp.compareTo(w2.getTimestamp)
+ case (r1: StreamRecord[CRow], r2: StreamRecord[CRow]) =>
+ r1.getValue.toString.compareTo(r2.getValue.toString)
+ case (_: Watermark, _: StreamRecord[CRow]) => -1
+ case (_: StreamRecord[CRow], _: Watermark) => 1
+ case _ => -1
+ }
+ }
+ }
+
+ /**
* Tuple row key selector that returns a specified field as the selector function
*/
class TupleRowKeySelector[T](
http://git-wip-us.apache.org/repos/asf/flink/blob/1ea7f49a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala
index 192befd..43397ae 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala
@@ -25,11 +25,12 @@ import org.apache.flink.streaming.api.watermark.Watermark
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness
import org.apache.flink.table.api.Types
-import org.apache.flink.table.runtime.harness.HarnessTestBase.{RowResultSortComparator, TupleRowKeySelector}
+import org.apache.flink.table.runtime.harness.HarnessTestBase.{RowResultSortComparator, RowResultSortComparatorWithWatermarks, TupleRowKeySelector}
import org.apache.flink.table.runtime.join.{ProcTimeBoundedStreamInnerJoin, RowTimeBoundedStreamInnerJoin}
+import org.apache.flink.table.runtime.operators.KeyedCoProcessOperatorWithWatermarkDelay
import org.apache.flink.table.runtime.types.CRow
import org.apache.flink.types.Row
-import org.junit.Assert.{assertEquals}
+import org.junit.Assert.assertEquals
import org.junit.Test
class JoinHarnessTest extends HarnessTestBase {
@@ -75,7 +76,7 @@ class JoinHarnessTest extends HarnessTestBase {
def testProcTimeJoinWithCommonBounds() {
val joinProcessFunc = new ProcTimeBoundedStreamInnerJoin(
- -10, 20, 0, rowType, rowType, "TestJoinFunction", funcCode)
+ -10, 20, rowType, rowType, "TestJoinFunction", funcCode)
val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] =
new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc)
@@ -165,7 +166,7 @@ class JoinHarnessTest extends HarnessTestBase {
def testProcTimeJoinWithNegativeBounds() {
val joinProcessFunc = new ProcTimeBoundedStreamInnerJoin(
- -10, -5, 0, rowType, rowType, "TestJoinFunction", funcCode)
+ -10, -5, rowType, rowType, "TestJoinFunction", funcCode)
val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] =
new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc)
@@ -250,7 +251,9 @@ class JoinHarnessTest extends HarnessTestBase {
-10, 20, 0, rowType, rowType, "TestJoinFunction", funcCode, 0, 0)
val operator: KeyedCoProcessOperator[String, CRow, CRow, CRow] =
- new KeyedCoProcessOperator[String, CRow, CRow, CRow](joinProcessFunc)
+ new KeyedCoProcessOperatorWithWatermarkDelay[String, CRow, CRow, CRow](
+ joinProcessFunc,
+ joinProcessFunc.getMaxOutputDelay)
val testHarness: KeyedTwoInputStreamOperatorTestHarness[String, CRow, CRow, CRow] =
new KeyedTwoInputStreamOperatorTestHarness[String, CRow, CRow, CRow](
operator,
@@ -312,23 +315,31 @@ class JoinHarnessTest extends HarnessTestBase {
assertEquals(4, testHarness.numKeyedStateEntries())
val expectedOutput = new ConcurrentLinkedQueue[Object]()
+ expectedOutput.add(new Watermark(-19))
+ // This result is produced by the late row (1, "k1").
+ expectedOutput.add(new StreamRecord(
+ CRow(Row.of(1L: JLong, "k1", 2L: JLong, "k1"), true), 0))
expectedOutput.add(new StreamRecord(
CRow(Row.of(2L: JLong, "k1", 2L: JLong, "k1"), true), 0))
expectedOutput.add(new StreamRecord(
- CRow(Row.of(5L: JLong, "k1", 2L: JLong, "k1"), true), 0))
+ CRow(Row.of(5L: JLong, "k1", 2L: JLong, "k1"), true), 0))
expectedOutput.add(new StreamRecord(
- CRow(Row.of(5L: JLong, "k1", 15L: JLong, "k1"), true), 0))
+ CRow(Row.of(5L: JLong, "k1", 15L: JLong, "k1"), true), 0))
+ expectedOutput.add(new Watermark(0))
expectedOutput.add(new StreamRecord(
- CRow(Row.of(35L: JLong, "k1", 15L: JLong, "k1"), true), 0))
+ CRow(Row.of(35L: JLong, "k1", 15L: JLong, "k1"), true), 0))
+ expectedOutput.add(new Watermark(18))
expectedOutput.add(new StreamRecord(
- CRow(Row.of(40L: JLong, "k2", 39L: JLong, "k2"), true), 0))
-
- // This result is produced by the late row (1, "k1").
- expectedOutput.add(new StreamRecord(
- CRow(Row.of(1L: JLong, "k1", 2L: JLong, "k1"), true), 0))
+ CRow(Row.of(40L: JLong, "k2", 39L: JLong, "k2"), true), 0))
+ expectedOutput.add(new Watermark(41))
val result = testHarness.getOutput
- verify(expectedOutput, result, new RowResultSortComparator())
+ verify(
+ expectedOutput,
+ result,
+ new RowResultSortComparatorWithWatermarks(),
+ checkWaterMark = true)
testHarness.close()
}
@@ -340,7 +351,9 @@ class JoinHarnessTest extends HarnessTestBase {
-10, -7, 0, rowType, rowType, "TestJoinFunction", funcCode, 0, 0)
val operator: KeyedCoProcessOperator[String, CRow, CRow, CRow] =
- new KeyedCoProcessOperator[String, CRow, CRow, CRow](joinProcessFunc)
+ new KeyedCoProcessOperatorWithWatermarkDelay[String, CRow, CRow, CRow](
+ joinProcessFunc,
+ joinProcessFunc.getMaxOutputDelay)
val testHarness: KeyedTwoInputStreamOperatorTestHarness[String, CRow, CRow, CRow] =
new KeyedTwoInputStreamOperatorTestHarness[String, CRow, CRow, CRow](
operator,
@@ -394,13 +407,21 @@ class JoinHarnessTest extends HarnessTestBase {
assertEquals(0, testHarness.numKeyedStateEntries())
val expectedOutput = new ConcurrentLinkedQueue[Object]()
+ expectedOutput.add(new Watermark(-9))
+ expectedOutput.add(new Watermark(-8))
expectedOutput.add(new StreamRecord(
CRow(Row.of(3L: JLong, "k1", 13L: JLong, "k1"), true), 0))
expectedOutput.add(new StreamRecord(
CRow(Row.of(6L: JLong, "k1", 13L: JLong, "k1"), true), 0))
+ expectedOutput.add(new Watermark(0))
+ expectedOutput.add(new Watermark(8))
val result = testHarness.getOutput
- verify(expectedOutput, result, new RowResultSortComparator())
+ verify(
+ expectedOutput,
+ result,
+ new RowResultSortComparatorWithWatermarks(),
+ checkWaterMark = true)
testHarness.close()
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/1ea7f49a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala
index 13bfbcd..015a5a2 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala
@@ -132,51 +132,54 @@ class JoinITCase extends StreamingWithStateTestBase {
env.setStateBackend(getStateBackend)
env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
StreamITCase.clear
- env.setParallelism(1)
val sqlQuery =
"""
- |SELECT t2.a, t2.c, t1.c
+ |SELECT t2.key, t2.id, t1.id
|FROM T1 as t1 join T2 as t2 ON
- | t1.a = t2.a AND
+ | t1.key = t2.key AND
| t1.rt BETWEEN t2.rt - INTERVAL '5' SECOND AND
| t2.rt + INTERVAL '6' SECOND
|""".stripMargin
- val data1 = new mutable.MutableList[(Int, Long, String, Long)]
+ val data1 = new mutable.MutableList[(String, String, Long)]
// for boundary test
- data1.+=((1, 999L, "LEFT0.999", 999L))
- data1.+=((1, 1000L, "LEFT1", 1000L))
- data1.+=((1, 2000L, "LEFT2", 2000L))
- data1.+=((1, 3000L, "LEFT3", 3000L))
- data1.+=((2, 4000L, "LEFT4", 4000L))
- data1.+=((1, 5000L, "LEFT5", 5000L))
- data1.+=((1, 6000L, "LEFT6", 6000L))
-
- val data2 = new mutable.MutableList[(Int, Long, String, Long)]
- data2.+=((1, 6000L, "RIGHT6", 6000L))
- data2.+=((2, 7000L, "RIGHT7", 7000L))
+ data1.+=(("A", "LEFT0.999", 999L))
+ data1.+=(("A", "LEFT1", 1000L))
+ data1.+=(("A", "LEFT2", 2000L))
+ data1.+=(("A", "LEFT3", 3000L))
+ data1.+=(("B", "LEFT4", 4000L))
+ data1.+=(("A", "LEFT5", 5000L))
+ data1.+=(("A", "LEFT6", 6000L))
+ // test null key
+ data1.+=((null.asInstanceOf[String], "LEFT8", 8000L))
+
+ val data2 = new mutable.MutableList[(String, String, Long)]
+ data2.+=(("A", "RIGHT6", 6000L))
+ data2.+=(("B", "RIGHT7", 7000L))
+ // test null key
+ data2.+=((null.asInstanceOf[String], "RIGHT10", 10000L))
val t1 = env.fromCollection(data1)
- .assignTimestampsAndWatermarks(new Tuple2WatermarkExtractor)
- .toTable(tEnv, 'a, 'b, 'c, 'rt.rowtime)
+ .assignTimestampsAndWatermarks(new Row3WatermarkExtractor2)
+ .toTable(tEnv, 'key, 'id, 'rt.rowtime)
val t2 = env.fromCollection(data2)
- .assignTimestampsAndWatermarks(new Tuple2WatermarkExtractor)
- .toTable(tEnv, 'a, 'b, 'c, 'rt.rowtime)
+ .assignTimestampsAndWatermarks(new Row3WatermarkExtractor2)
+ .toTable(tEnv, 'key, 'id, 'rt.rowtime)
tEnv.registerTable("T1", t1)
tEnv.registerTable("T2", t2)
- val result = tEnv.sql(sqlQuery).toAppendStream[Row]
+ val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
result.addSink(new StreamITCase.StringSink[Row])
env.execute()
val expected = new java.util.ArrayList[String]
- expected.add("1,RIGHT6,LEFT1")
- expected.add("1,RIGHT6,LEFT2")
- expected.add("1,RIGHT6,LEFT3")
- expected.add("1,RIGHT6,LEFT5")
- expected.add("1,RIGHT6,LEFT6")
- expected.add("2,RIGHT7,LEFT4")
+ expected.add("A,RIGHT6,LEFT1")
+ expected.add("A,RIGHT6,LEFT2")
+ expected.add("A,RIGHT6,LEFT3")
+ expected.add("A,RIGHT6,LEFT5")
+ expected.add("A,RIGHT6,LEFT6")
+ expected.add("B,RIGHT7,LEFT4")
StreamITCase.compareWithList(expected)
}
@@ -189,9 +192,6 @@ class JoinITCase extends StreamingWithStateTestBase {
env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
StreamITCase.clear
- // different parallelisms lead to different join results
- env.setParallelism(1)
-
val sqlQuery =
"""
|SELECT t2.a, t1.c, t2.c
@@ -215,8 +215,6 @@ class JoinITCase extends StreamingWithStateTestBase {
data1.+=((1, 4L, "LEFT4.9", 4999L))
data1.+=((1, 4L, "LEFT5", 5000L))
data1.+=((1, 10L, "LEFT6", 6000L))
- // a left late row
- data1.+=((1, 3L, "LEFT3.5", 3500L))
val data2 = new mutable.MutableList[(Int, Long, String, Long)]
// just for watermark
@@ -224,20 +222,18 @@ class JoinITCase extends StreamingWithStateTestBase {
data2.+=((1, 9L, "RIGHT6", 6000L))
data2.+=((2, 14L, "RIGHT7", 7000L))
data2.+=((1, 4L, "RIGHT8", 8000L))
- // a right late row
- data2.+=((1, 10L, "RIGHT5", 5000L))
val t1 = env.fromCollection(data1)
- .assignTimestampsAndWatermarks(new Tuple2WatermarkExtractor)
+ .assignTimestampsAndWatermarks(new Row4WatermarkExtractor)
.toTable(tEnv, 'a, 'b, 'c, 'rt.rowtime)
val t2 = env.fromCollection(data2)
- .assignTimestampsAndWatermarks(new Tuple2WatermarkExtractor)
+ .assignTimestampsAndWatermarks(new Row4WatermarkExtractor)
.toTable(tEnv, 'a, 'b, 'c, 'rt.rowtime)
tEnv.registerTable("T1", t1)
tEnv.registerTable("T2", t2)
- val result = tEnv.sql(sqlQuery).toAppendStream[Row]
+ val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
result.addSink(new StreamITCase.StringSink[Row])
env.execute()
@@ -247,34 +243,131 @@ class JoinITCase extends StreamingWithStateTestBase {
expected1+= "1,LEFT1.1,RIGHT6"
expected1+= "2,LEFT4,RIGHT7"
expected1+= "1,LEFT4.9,RIGHT6"
- // produced by the left late rows
- expected1+= "1,LEFT3.5,RIGHT6"
- expected1+= "1,LEFT3.5,RIGHT8"
- // produced by the right late rows
- expected1+= "1,LEFT3,RIGHT5"
- expected1+= "1,LEFT3.5,RIGHT5"
val expected2 = new mutable.MutableList[String]
expected2+= "1,LEFT3,RIGHT6"
expected2+= "1,LEFT1.1,RIGHT6"
expected2+= "2,LEFT4,RIGHT7"
expected2+= "1,LEFT4.9,RIGHT6"
- // produced by the left late rows
- expected2+= "1,LEFT3.5,RIGHT6"
- expected2+= "1,LEFT3.5,RIGHT8"
- // produced by the right late rows
- expected2+= "1,LEFT3,RIGHT5"
- expected2+= "1,LEFT1,RIGHT5"
- expected2+= "1,LEFT1.1,RIGHT5"
Assert.assertThat(
StreamITCase.testResults.sorted,
CoreMatchers.either(CoreMatchers.is(expected1.sorted)).
or(CoreMatchers.is(expected2.sorted)))
}
+
+ /** test rowtime inner join with window aggregation **/
+ @Test
+ def testRowTimeInnerJoinWithWindowAggregateOnFirstTime(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ env.setStateBackend(getStateBackend)
+ env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
+ StreamITCase.clear
+
+ val sqlQuery =
+ """
+ |SELECT t1.key, TUMBLE_END(t1.rt, INTERVAL '4' SECOND), COUNT(t2.key)
+ |FROM T1 AS t1 join T2 AS t2 ON
+ | t1.key = t2.key AND
+ | t1.rt BETWEEN t2.rt - INTERVAL '5' SECOND AND
+ | t2.rt + INTERVAL '5' SECOND
+ |GROUP BY TUMBLE(t1.rt, INTERVAL '4' SECOND), t1.key
+ |""".stripMargin
+
+ val data1 = new mutable.MutableList[(String, String, Long)]
+ data1.+=(("A", "L-1", 1000L)) // no joining record
+ data1.+=(("A", "L-2", 2000L)) // 1 joining record
+ data1.+=(("A", "L-3", 3000L)) // 2 joining records
+ data1.+=(("B", "L-4", 4000L)) // 1 joining record
+ data1.+=(("C", "L-5", 4000L)) // no joining record
+ data1.+=(("A", "L-6", 10000L)) // 2 joining records
+ data1.+=(("A", "L-7", 13000L)) // 1 joining record
+
+ val data2 = new mutable.MutableList[(String, String, Long)]
+ data2.+=(("A", "R-1", 7000L)) // 3 joining records
+ data2.+=(("B", "R-4", 7000L)) // 1 joining records
+ data2.+=(("A", "R-3", 8000L)) // 3 joining records
+ data2.+=(("D", "R-2", 8000L)) // no joining record
+
+ val t1 = env.fromCollection(data1)
+ .assignTimestampsAndWatermarks(new Row3WatermarkExtractor2)
+ .toTable(tEnv, 'key, 'id, 'rt.rowtime)
+ val t2 = env.fromCollection(data2)
+ .assignTimestampsAndWatermarks(new Row3WatermarkExtractor2)
+ .toTable(tEnv, 'key, 'id, 'rt.rowtime)
+
+ tEnv.registerTable("T1", t1)
+ tEnv.registerTable("T2", t2)
+
+ val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
+ result.addSink(new StreamITCase.StringSink[Row])
+ env.execute()
+ val expected = new java.util.ArrayList[String]
+ expected.add("A,1970-01-01 00:00:04.0,3")
+ expected.add("A,1970-01-01 00:00:12.0,2")
+ expected.add("A,1970-01-01 00:00:16.0,1")
+ expected.add("B,1970-01-01 00:00:08.0,1")
+ StreamITCase.compareWithList(expected)
+ }
+
+ /** test rowtime inner join with window aggregation **/
+ @Test
+ def testRowTimeInnerJoinWithWindowAggregateOnSecondTime(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ env.setStateBackend(getStateBackend)
+ env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
+ StreamITCase.clear
+
+ val sqlQuery =
+ """
+ |SELECT t2.key, TUMBLE_END(t2.rt, INTERVAL '4' SECOND), COUNT(t1.key)
+ |FROM T1 AS t1 join T2 AS t2 ON
+ | t1.key = t2.key AND
+ | t1.rt BETWEEN t2.rt - INTERVAL '5' SECOND AND
+ | t2.rt + INTERVAL '5' SECOND
+ |GROUP BY TUMBLE(t2.rt, INTERVAL '4' SECOND), t2.key
+ |""".stripMargin
+
+ val data1 = new mutable.MutableList[(String, String, Long)]
+ data1.+=(("A", "L-1", 1000L)) // no joining record
+ data1.+=(("A", "L-2", 2000L)) // 1 joining record
+ data1.+=(("A", "L-3", 3000L)) // 2 joining records
+ data1.+=(("B", "L-4", 4000L)) // 1 joining record
+ data1.+=(("C", "L-5", 4000L)) // no joining record
+ data1.+=(("A", "L-6", 10000L)) // 2 joining records
+ data1.+=(("A", "L-7", 13000L)) // 1 joining record
+
+ val data2 = new mutable.MutableList[(String, String, Long)]
+ data2.+=(("A", "R-1", 7000L)) // 3 joining records
+ data2.+=(("B", "R-4", 7000L)) // 1 joining records
+ data2.+=(("A", "R-3", 8000L)) // 3 joining records
+ data2.+=(("D", "R-2", 8000L)) // no joining record
+
+ val t1 = env.fromCollection(data1)
+ .assignTimestampsAndWatermarks(new Row3WatermarkExtractor2)
+ .toTable(tEnv, 'key, 'id, 'rt.rowtime)
+ val t2 = env.fromCollection(data2)
+ .assignTimestampsAndWatermarks(new Row3WatermarkExtractor2)
+ .toTable(tEnv, 'key, 'id, 'rt.rowtime)
+
+ tEnv.registerTable("T1", t1)
+ tEnv.registerTable("T2", t2)
+
+ val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
+ result.addSink(new StreamITCase.StringSink[Row])
+ env.execute()
+ val expected = new java.util.ArrayList[String]
+ expected.add("A,1970-01-01 00:00:08.0,3")
+ expected.add("A,1970-01-01 00:00:12.0,3")
+ expected.add("B,1970-01-01 00:00:08.0,1")
+ StreamITCase.compareWithList(expected)
+ }
+
}
-private class Tuple2WatermarkExtractor
+private class Row4WatermarkExtractor
extends AssignerWithPunctuatedWatermarks[(Int, Long, String, Long)] {
override def checkAndGetNextWatermark(
@@ -289,3 +382,19 @@ private class Tuple2WatermarkExtractor
element._4
}
}
+
+private class Row3WatermarkExtractor2
+ extends AssignerWithPunctuatedWatermarks[(String, String, Long)] {
+
+ override def checkAndGetNextWatermark(
+ lastElement: (String, String, Long),
+ extractedTimestamp: Long): Watermark = {
+ new Watermark(extractedTimestamp - 1)
+ }
+
+ override def extractTimestamp(
+ element: (String, String, Long),
+ previousElementTimestamp: Long): Long = {
+ element._3
+ }
+}
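A note on the extractor: emitting a punctuated watermark of extractedTimestamp - 1 after every element keeps the watermark just behind the element's own timestamp, so for the (mostly ascending) test data no row is dropped as late, which is why the tests above no longer need env.setParallelism(1). The wiring, as already shown in this diff (data1, tEnv and the field names come from the tests):

  val t1 = env.fromCollection(data1)
    .assignTimestampsAndWatermarks(new Row3WatermarkExtractor2)
    .toTable(tEnv, 'key, 'id, 'rt.rowtime)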
[2/5] flink git commit: [FLINK-7776] [table] Prevent emission of
identical update records in group aggregation.
Posted by fh...@apache.org.
[FLINK-7776] [table] Prevent emission of identical update records in group aggregation.
This closes #4785.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/4047be49
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/4047be49
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/4047be49
Branch: refs/heads/master
Commit: 4047be49e10cacc5e4ce932a0b8433afffa82a58
Parents: 1ea7f49
Author: Xpray <le...@gmail.com>
Authored: Mon Oct 9 18:19:01 2017 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Tue Oct 10 23:09:07 2017 +0200
----------------------------------------------------------------------
.../aggregate/GroupAggProcessFunction.scala | 10 ++++----
.../runtime/stream/table/AggregateITCase.scala | 25 ++++++++++++++++++++
2 files changed, 31 insertions(+), 4 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/4047be49/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala
index df59460..91c379f 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala
@@ -129,17 +129,19 @@ class GroupAggProcessFunction(
state.update(accumulators)
cntState.update(inputCnt)
- // if this was not the first row and we have to emit retractions
- if (generateRetraction && !firstRow) {
+ // if this was not the first row
+ if (!firstRow) {
if (prevRow.row.equals(newRow.row) && !stateCleaningEnabled) {
// newRow is the same as before and state cleaning is not enabled.
- // We do not emit retraction and acc message.
+ // We emit nothing.
// If state cleaning is enabled, we have to emit messages to prevent too early
// state eviction of downstream operators.
return
} else {
// retract previous result
- out.collect(prevRow)
+ if (generateRetraction) {
+ out.collect(prevRow)
+ }
}
}
// emit the new result
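Condensed, the new emission logic of the hunk above reads as follows (a restatement, not additional behavior):

  if (!firstRow) {
    if (prevRow.row.equals(newRow.row) && !stateCleaningEnabled) {
      return                       // identical result and no state TTL: emit nothing
    } else if (generateRetraction) {
      out.collect(prevRow)         // retract the previous result first
    }
  }
  out.collect(newRow)              // emit the new result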
http://git-wip-us.apache.org/repos/asf/flink/blob/4047be49/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
index eb3d37f..e67c784 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
@@ -193,4 +193,29 @@ class AggregateITCase extends StreamingWithStateTestBase {
// verify agg close is called
assert(JavaUserDefinedAggFunctions.isCloseCalled)
}
+
+ @Test
+ def testRemoveDuplicateRecordsWithUpsertSink(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ env.setStateBackend(getStateBackend)
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.clear
+
+ val data = new mutable.MutableList[(Int, Long, String)]
+ data.+=((1, 1L, "A"))
+ data.+=((2, 2L, "B"))
+ data.+=((3, 2L, "B"))
+ data.+=((4, 3L, "C"))
+ data.+=((5, 3L, "C"))
+
+ val t = env.fromCollection(data).toTable(tEnv, 'a, 'b, 'c)
+ .groupBy('c)
+ .select('c, 'b.max)
+
+ t.writeToSink(new TestUpsertSink(Array("c"), false))
+ env.execute()
+
+ val expected = List("(true,A,1)", "(true,B,2)", "(true,C,3)")
+ assertEquals(expected.sorted, RowCollector.getAndClearValues.map(_.toString).sorted)
+ }
}
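Tracing the expected output of this test: the query keeps max('b) per key 'c, so inputs that do not change the maximum must not produce a second upsert message:

  // (1, 1L, "A") -> max(b) for A becomes 1 -> emit (true,A,1)
  // (2, 2L, "B") -> max(b) for B becomes 2 -> emit (true,B,2)
  // (3, 2L, "B") -> max(b) for B stays 2   -> suppressed as identical
  // (4, 3L, "C") -> max(b) for C becomes 3 -> emit (true,C,3)
  // (5, 3L, "C") -> max(b) for C stays 3   -> suppressed as identical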
[3/5] flink git commit: [FLINK-6233] [table] Add inner rowtime window
join between two streams for SQL.
Posted by fh...@apache.org.
[FLINK-6233] [table] Add inner rowtime window join between two streams for SQL.
This closes #4625.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/655d8b16
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/655d8b16
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/655d8b16
Branch: refs/heads/master
Commit: 655d8b16193ac7131fa1f58fb4ba7ff96e439438
Parents: 9829ca0
Author: Xingcan Cui <xi...@gmail.com>
Authored: Wed Aug 30 13:57:38 2017 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Tue Oct 10 23:09:07 2017 +0200
----------------------------------------------------------------------
docs/dev/table/sql.md | 2 +-
.../nodes/datastream/DataStreamWindowJoin.scala | 132 +++++-
.../datastream/DataStreamWindowJoinRule.scala | 11 +-
.../join/ProcTimeBoundedStreamInnerJoin.scala | 68 +++
.../runtime/join/ProcTimeWindowInnerJoin.scala | 346 ----------------
.../join/RowTimeBoundedStreamInnerJoin.scala | 82 ++++
.../join/TimeBoundedStreamInnerJoin.scala | 412 +++++++++++++++++++
.../table/runtime/join/WindowJoinUtil.scala | 40 +-
.../flink/table/api/stream/sql/JoinTest.scala | 94 ++++-
.../table/runtime/harness/JoinHarnessTest.scala | 305 +++++++++++---
.../table/runtime/stream/sql/JoinITCase.scala | 205 ++++++++-
11 files changed, 1230 insertions(+), 467 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/docs/dev/table/sql.md
----------------------------------------------------------------------
diff --git a/docs/dev/table/sql.md b/docs/dev/table/sql.md
index b9205ab..533aa6e 100644
--- a/docs/dev/table/sql.md
+++ b/docs/dev/table/sql.md
@@ -409,7 +409,7 @@ FROM Orders LEFT JOIN Product ON Orders.productId = Product.id
</ul>
</p>
- <p><b>Note:</b> Currently, only processing time window joins and <code>INNER</code> joins are supported.</p>
+ <p><b>Note:</b> Currently, only <code>INNER</code> joins are supported.</p>
{% highlight sql %}
SELECT *
http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala
index f8015b3..9358aa3 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala
@@ -23,14 +23,20 @@ import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.{JoinInfo, JoinRelType}
import org.apache.calcite.rel.{BiRel, RelNode, RelWriter}
import org.apache.calcite.rex.RexNode
+import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.functions.NullByteKeySelector
+import org.apache.flink.api.java.tuple.Tuple
import org.apache.flink.streaming.api.datastream.DataStream
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction
import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment, TableException}
import org.apache.flink.table.plan.nodes.CommonJoin
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.table.plan.util.UpdatingPlanChecker
-import org.apache.flink.table.runtime.join.{ProcTimeWindowInnerJoin, WindowJoinUtil}
+import org.apache.flink.table.runtime.join.{ProcTimeBoundedStreamInnerJoin, RowTimeBoundedStreamInnerJoin, WindowJoinUtil}
+import org.apache.flink.table.runtime.operators.KeyedCoProcessOperatorWithWatermarkDelay
import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
+import org.apache.flink.table.util.Logging
+import org.apache.flink.util.Collector
/**
* RelNode for a time windowed stream join.
@@ -48,11 +54,14 @@ class DataStreamWindowJoin(
isRowTime: Boolean,
leftLowerBound: Long,
leftUpperBound: Long,
+ leftTimeIdx: Int,
+ rightTimeIdx: Int,
remainCondition: Option[RexNode],
ruleDescription: String)
extends BiRel(cluster, traitSet, leftNode, rightNode)
with CommonJoin
- with DataStreamRel {
+ with DataStreamRel
+ with Logging {
override def deriveRowType(): RelDataType = schema.relDataType
@@ -70,6 +79,8 @@ class DataStreamWindowJoin(
isRowTime,
leftLowerBound,
leftUpperBound,
+ leftTimeIdx,
+ rightTimeIdx,
remainCondition,
ruleDescription)
}
@@ -107,10 +118,12 @@ class DataStreamWindowJoin(
val leftDataStream = left.asInstanceOf[DataStreamRel].translateToPlan(tableEnv, queryConfig)
val rightDataStream = right.asInstanceOf[DataStreamRel].translateToPlan(tableEnv, queryConfig)
- // get the equality keys and other condition
+ // get the equi-keys and other conditions
val joinInfo = JoinInfo.of(leftNode, rightNode, joinCondition)
val leftKeys = joinInfo.leftKeys.toIntArray
val rightKeys = joinInfo.rightKeys.toIntArray
+ val relativeWindowSize = leftUpperBound - leftLowerBound
+ val returnTypeInfo = CRowTypeInfo(schema.typeInfo)
// generate join function
val joinFunction =
@@ -125,20 +138,32 @@ class DataStreamWindowJoin(
joinType match {
case JoinRelType.INNER =>
- if (isRowTime) {
- // RowTime JoinCoProcessFunction
- throw new TableException(
- "RowTime inner join between stream and stream is not supported yet.")
+ if (relativeWindowSize < 0) {
+ LOG.warn(s"The relative window size $relativeWindowSize is negative," +
+ " please check the join conditions.")
+ createEmptyInnerJoin(leftDataStream, rightDataStream, returnTypeInfo)
} else {
- // Proctime JoinCoProcessFunction
- createProcTimeInnerJoinFunction(
- leftDataStream,
- rightDataStream,
- joinFunction.name,
- joinFunction.code,
- leftKeys,
- rightKeys
- )
+ if (isRowTime) {
+ createRowTimeInnerJoin(
+ leftDataStream,
+ rightDataStream,
+ returnTypeInfo,
+ joinFunction.name,
+ joinFunction.code,
+ leftKeys,
+ rightKeys
+ )
+ } else {
+ createProcTimeInnerJoin(
+ leftDataStream,
+ rightDataStream,
+ returnTypeInfo,
+ joinFunction.name,
+ joinFunction.code,
+ leftKeys,
+ rightKeys
+ )
+ }
}
case JoinRelType.FULL =>
throw new TableException(
@@ -152,19 +177,40 @@ class DataStreamWindowJoin(
}
}
- def createProcTimeInnerJoinFunction(
+ def createEmptyInnerJoin(
+ leftDataStream: DataStream[CRow],
+ rightDataStream: DataStream[CRow],
+ returnTypeInfo: TypeInformation[CRow]): DataStream[CRow] = {
+ leftDataStream.connect(rightDataStream).process(
+ new CoProcessFunction[CRow, CRow, CRow] {
+ override def processElement1(
+ value: CRow,
+ ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
+ out: Collector[CRow]): Unit = {
+ // Do nothing.
+ }
+ override def processElement2(
+ value: CRow,
+ ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
+ out: Collector[CRow]): Unit = {
+ // Do nothing.
+ }
+ }).returns(returnTypeInfo)
+ }
+
+ def createProcTimeInnerJoin(
leftDataStream: DataStream[CRow],
rightDataStream: DataStream[CRow],
+ returnTypeInfo: TypeInformation[CRow],
joinFunctionName: String,
joinFunctionCode: String,
leftKeys: Array[Int],
rightKeys: Array[Int]): DataStream[CRow] = {
- val returnTypeInfo = CRowTypeInfo(schema.typeInfo)
-
- val procInnerJoinFunc = new ProcTimeWindowInnerJoin(
+ val procInnerJoinFunc = new ProcTimeBoundedStreamInnerJoin(
leftLowerBound,
leftUpperBound,
+ allowedLateness = 0L,
leftSchema.typeInfo,
rightSchema.typeInfo,
joinFunctionName,
@@ -184,4 +230,50 @@ class DataStreamWindowJoin(
.returns(returnTypeInfo)
}
}
+
+ def createRowTimeInnerJoin(
+ leftDataStream: DataStream[CRow],
+ rightDataStream: DataStream[CRow],
+ returnTypeInfo: TypeInformation[CRow],
+ joinFunctionName: String,
+ joinFunctionCode: String,
+ leftKeys: Array[Int],
+ rightKeys: Array[Int]): DataStream[CRow] = {
+
+ val rowTimeInnerJoinFunc = new RowTimeBoundedStreamInnerJoin(
+ leftLowerBound,
+ leftUpperBound,
+ allowedLateness = 0L,
+ leftSchema.typeInfo,
+ rightSchema.typeInfo,
+ joinFunctionName,
+ joinFunctionCode,
+ leftTimeIdx,
+ rightTimeIdx)
+
+ if (!leftKeys.isEmpty) {
+ leftDataStream
+ .connect(rightDataStream)
+ .keyBy(leftKeys, rightKeys)
+ .transform(
+ "InnerRowtimeWindowJoin",
+ returnTypeInfo,
+ new KeyedCoProcessOperatorWithWatermarkDelay[Tuple, CRow, CRow, CRow](
+ rowTimeInnerJoinFunc,
+ rowTimeInnerJoinFunc.getMaxOutputDelay)
+ )
+ } else {
+ leftDataStream.connect(rightDataStream)
+ .keyBy(new NullByteKeySelector[CRow](), new NullByteKeySelector[CRow])
+ .transform(
+ "InnerRowtimeWindowJoin",
+ returnTypeInfo,
+ new KeyedCoProcessOperatorWithWatermarkDelay[java.lang.Byte, CRow, CRow, CRow](
+ rowTimeInnerJoinFunc,
+ rowTimeInnerJoinFunc.getMaxOutputDelay)
+ )
+ .setParallelism(1)
+ .setMaxParallelism(1)
+ }
+ }
}
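Two remarks on the translation above. First, the relativeWindowSize < 0 branch covers contradictory bounds; a hypothetical query that would trigger it (upper bound before lower bound):

  // leftLowerBound = 5000, leftUpperBound = 2000
  // => relativeWindowSize = 2000 - 5000 = -3000 < 0, so createEmptyInnerJoin
  //    is used and the operator can never emit a match (a warning is logged).
  val sqlQuery =
    """
      |SELECT t1.id, t2.id FROM T1 t1, T2 t2
      |WHERE t1.key = t2.key AND
      |  t1.rt BETWEEN t2.rt + INTERVAL '5' SECOND AND t2.rt + INTERVAL '2' SECOND
      |""".stripMargin

Second, when the join has no equi-keys, both inputs are keyed with NullByteKeySelector, which maps every record to the same constant key, so the operator must run with parallelism (and max parallelism) 1, as the last hunk shows.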
http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala
index 7dfcbc5..d208d2b 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala
@@ -40,10 +40,9 @@ class DataStreamWindowJoinRule
override def matches(call: RelOptRuleCall): Boolean = {
val join: FlinkLogicalJoin = call.rel(0).asInstanceOf[FlinkLogicalJoin]
- val joinInfo = join.analyzeCondition
val (windowBounds, remainingPreds) = WindowJoinUtil.extractWindowBoundsFromPredicate(
- joinInfo.getRemaining(join.getCluster.getRexBuilder),
+ join.getCondition,
join.getLeft.getRowType.getFieldCount,
join.getRowType,
join.getCluster.getRexBuilder,
@@ -55,8 +54,7 @@ class DataStreamWindowJoinRule
if (windowBounds.isDefined) {
if (windowBounds.get.isEventTime) {
- // we cannot handle event-time window joins yet
- false
+ !remainingPredsAccessTime
} else {
// Check that no event-time attributes are in the input.
// The proc-time join implementation does ensure that record timestamp are correctly set.
@@ -80,13 +78,12 @@ class DataStreamWindowJoinRule
val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASTREAM)
val convLeft: RelNode = RelOptRule.convert(join.getInput(0), FlinkConventions.DATASTREAM)
val convRight: RelNode = RelOptRule.convert(join.getInput(1), FlinkConventions.DATASTREAM)
- val joinInfo = join.analyzeCondition
val leftRowSchema = new RowSchema(convLeft.getRowType)
val rightRowSchema = new RowSchema(convRight.getRowType)
val (windowBounds, remainCondition) =
WindowJoinUtil.extractWindowBoundsFromPredicate(
- joinInfo.getRemaining(join.getCluster.getRexBuilder),
+ join.getCondition,
leftRowSchema.arity,
join.getRowType,
join.getCluster.getRexBuilder,
@@ -105,6 +102,8 @@ class DataStreamWindowJoinRule
windowBounds.get.isEventTime,
windowBounds.get.leftLowerBound,
windowBounds.get.leftUpperBound,
+ windowBounds.get.leftTimeIdx,
+ windowBounds.get.rightTimeIdx,
remainCondition,
description)
}
http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala
new file mode 100644
index 0000000..ab5a9c3
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.join
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction
+import org.apache.flink.table.runtime.types.CRow
+import org.apache.flink.types.Row
+
+/**
+ * The function that executes a processing-time bounded stream inner-join.
+ */
+final class ProcTimeBoundedStreamInnerJoin(
+ leftLowerBound: Long,
+ leftUpperBound: Long,
+ allowedLateness: Long,
+ leftType: TypeInformation[Row],
+ rightType: TypeInformation[Row],
+ genJoinFuncName: String,
+ genJoinFuncCode: String)
+ extends TimeBoundedStreamInnerJoin(
+ leftLowerBound,
+ leftUpperBound,
+ allowedLateness,
+ leftType,
+ rightType,
+ genJoinFuncName,
+ genJoinFuncCode) {
+
+ override def updateOperatorTime(ctx: CoProcessFunction[CRow, CRow, CRow]#Context): Unit = {
+ leftOperatorTime = ctx.timerService().currentProcessingTime()
+ rightOperatorTime = leftOperatorTime
+ }
+
+ override def getTimeForLeftStream(
+ context: CoProcessFunction[CRow, CRow, CRow]#Context,
+ row: Row): Long = {
+ leftOperatorTime
+ }
+
+ override def getTimeForRightStream(
+ context: CoProcessFunction[CRow, CRow, CRow]#Context,
+ row: Row): Long = {
+ rightOperatorTime
+ }
+
+ override def registerTimer(
+ ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
+ cleanupTime: Long): Unit = {
+ ctx.timerService.registerProcessingTimeTimer(cleanupTime)
+ }
+}
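A note on this class: both operator times advance on the same processing-time clock, so a row can never be late relative to the operator time; this is why the refactoring in the 5/5 commit above fixes allowedLateness to 0 inside the class instead of exposing it. Instantiation as of this commit, taken from the harness tests:

  val joinProcessFunc = new ProcTimeBoundedStreamInnerJoin(
    -10, 20, 0, rowType, rowType, "TestJoinFunction", funcCode)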
http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeWindowInnerJoin.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeWindowInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeWindowInnerJoin.scala
deleted file mode 100644
index 8240376..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeWindowInnerJoin.scala
+++ /dev/null
@@ -1,346 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.runtime.join
-
-import java.util
-import java.util.{List => JList}
-
-import org.apache.flink.api.common.functions.FlatJoinFunction
-import org.apache.flink.api.common.state._
-import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
-import org.apache.flink.api.java.typeutils.ListTypeInfo
-import org.apache.flink.configuration.Configuration
-import org.apache.flink.streaming.api.functions.co.CoProcessFunction
-import org.apache.flink.table.codegen.Compiler
-import org.apache.flink.table.runtime.CRowWrappingCollector
-import org.apache.flink.table.runtime.types.CRow
-import org.apache.flink.table.util.Logging
-import org.apache.flink.types.Row
-import org.apache.flink.util.Collector
-
-/**
- * A CoProcessFunction that joins two streams; currently only inner-join is supported.
- *
- * @param leftLowerBound
- * the left stream lower bound, and -leftLowerBound is the right stream upper bound
- * @param leftUpperBound
- * the left stream upper bound, and -leftUpperBound is the right stream lower bound
- * @param element1Type the input type of left stream
- * @param element2Type the input type of right stream
- * @param genJoinFuncName the name of the generated join function for the non-equi condition
- * @param genJoinFuncCode the code of the generated join function for the non-equi condition
- *
- */
-class ProcTimeWindowInnerJoin(
- private val leftLowerBound: Long,
- private val leftUpperBound: Long,
- private val element1Type: TypeInformation[Row],
- private val element2Type: TypeInformation[Row],
- private val genJoinFuncName: String,
- private val genJoinFuncCode: String)
- extends CoProcessFunction[CRow, CRow, CRow]
- with Compiler[FlatJoinFunction[Row, Row, Row]]
- with Logging {
-
- private var cRowWrapper: CRowWrappingCollector = _
-
- // other condition function
- private var joinFunction: FlatJoinFunction[Row, Row, Row] = _
-
- // tmp list to store expired records
- private var removeList: JList[Long] = _
-
- // state to hold left stream element
- private var row1MapState: MapState[Long, JList[Row]] = _
- // state to hold right stream element
- private var row2MapState: MapState[Long, JList[Row]] = _
-
- // state to record last timer of left stream, 0 means no timer
- private var timerState1: ValueState[Long] = _
- // state to record last timer of right stream, 0 means no timer
- private var timerState2: ValueState[Long] = _
-
- // compute window sizes, i.e., how long to keep rows in state.
- // window size of -1 means rows do not need to be put into state.
- private val leftStreamWinSize: Long = if (leftLowerBound <= 0) -leftLowerBound else -1
- private val rightStreamWinSize: Long = if (leftUpperBound >= 0) leftUpperBound else -1
-
- override def open(config: Configuration) {
- LOG.debug(s"Compiling JoinFunction: $genJoinFuncName \n\n " +
- s"Code:\n$genJoinFuncCode")
- val clazz = compile(
- getRuntimeContext.getUserCodeClassLoader,
- genJoinFuncName,
- genJoinFuncCode)
- LOG.debug("Instantiating JoinFunction.")
- joinFunction = clazz.newInstance()
-
- removeList = new util.ArrayList[Long]()
- cRowWrapper = new CRowWrappingCollector()
- cRowWrapper.setChange(true)
-
- // initialize row state
- val rowListTypeInfo1: TypeInformation[JList[Row]] = new ListTypeInfo[Row](element1Type)
- val mapStateDescriptor1: MapStateDescriptor[Long, JList[Row]] =
- new MapStateDescriptor[Long, JList[Row]]("row1mapstate",
- BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], rowListTypeInfo1)
- row1MapState = getRuntimeContext.getMapState(mapStateDescriptor1)
-
- val rowListTypeInfo2: TypeInformation[JList[Row]] = new ListTypeInfo[Row](element2Type)
- val mapStateDescriptor2: MapStateDescriptor[Long, JList[Row]] =
- new MapStateDescriptor[Long, JList[Row]]("row2mapstate",
- BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], rowListTypeInfo2)
- row2MapState = getRuntimeContext.getMapState(mapStateDescriptor2)
-
- // initialize timer state
- val valueStateDescriptor1: ValueStateDescriptor[Long] =
- new ValueStateDescriptor[Long]("timervaluestate1", classOf[Long])
- timerState1 = getRuntimeContext.getState(valueStateDescriptor1)
-
- val valueStateDescriptor2: ValueStateDescriptor[Long] =
- new ValueStateDescriptor[Long]("timervaluestate2", classOf[Long])
- timerState2 = getRuntimeContext.getState(valueStateDescriptor2)
- }
-
- /**
- * Process left stream records
- *
- * @param valueC The input value.
- * @param ctx The ctx to register timer or get current time
- * @param out The collector for returning result values.
- *
- */
- override def processElement1(
- valueC: CRow,
- ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
- out: Collector[CRow]): Unit = {
-
- processElement(
- valueC,
- ctx,
- out,
- leftStreamWinSize,
- timerState1,
- row1MapState,
- row2MapState,
- -leftUpperBound, // right stream lower
- -leftLowerBound, // right stream upper
- isLeft = true
- )
- }
-
- /**
- * Process right stream records
- *
- * @param valueC The input value.
- * @param ctx The ctx to register timer or get current time
- * @param out The collector for returning result values.
- *
- */
- override def processElement2(
- valueC: CRow,
- ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
- out: Collector[CRow]): Unit = {
-
- processElement(
- valueC,
- ctx,
- out,
- rightStreamWinSize,
- timerState2,
- row2MapState,
- row1MapState,
- leftLowerBound, // left stream lower
- leftUpperBound, // left stream upper
- isLeft = false
- )
- }
-
- /**
- * Called when a processing-time timer fires.
- * Expires left/right records whose timestamps are earlier than (current time - window size).
- *
- * @param timestamp The timestamp of the firing timer.
- * @param ctx The ctx to register timer or get current time
- * @param out The collector for returning result values.
- */
- override def onTimer(
- timestamp: Long,
- ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext,
- out: Collector[CRow]): Unit = {
-
- if (timerState1.value == timestamp) {
- expireOutTimeRow(
- timestamp,
- leftStreamWinSize,
- row1MapState,
- timerState1,
- ctx
- )
- }
-
- if (timerState2.value == timestamp) {
- expireOutTimeRow(
- timestamp,
- rightStreamWinSize,
- row2MapState,
- timerState2,
- ctx
- )
- }
- }
-
- /**
- * Puts an element from the input stream into state, searches the other input's
- * state for records that meet the join condition and emits them, and registers
- * a timer for the current record if no timer is set yet.
- */
- private def processElement(
- valueC: CRow,
- ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
- out: Collector[CRow],
- winSize: Long,
- timerState: ValueState[Long],
- rowMapState: MapState[Long, JList[Row]],
- otherRowMapState: MapState[Long, JList[Row]],
- otherLowerBound: Long,
- otherUpperBound: Long,
- isLeft: Boolean): Unit = {
-
- cRowWrapper.out = out
-
- val row = valueC.row
-
- val curProcessTime = ctx.timerService.currentProcessingTime
- val otherLowerTime = curProcessTime + otherLowerBound
- val otherUpperTime = curProcessTime + otherUpperBound
-
- if (winSize >= 0) {
- // put row into state for later joining.
- // (winSize == 0) joins rows received in the same millisecond.
- var rowList = rowMapState.get(curProcessTime)
- if (rowList == null) {
- rowList = new util.ArrayList[Row]()
- }
- rowList.add(row)
- rowMapState.put(curProcessTime, rowList)
-
- // register a timer to remove the row from state once it is expired
- if (timerState.value == 0) {
- val cleanupTime = curProcessTime + winSize + 1
- ctx.timerService.registerProcessingTimeTimer(cleanupTime)
- timerState.update(cleanupTime)
- }
- }
-
- // join row with rows received from the other input
- val otherTimeIter = otherRowMapState.keys().iterator()
- if (isLeft) {
- // go over all timestamps in the other input's state
- while (otherTimeIter.hasNext) {
- val otherTimestamp = otherTimeIter.next()
- if (otherTimestamp < otherLowerTime) {
- // other timestamp is expired. Remove it later.
- removeList.add(otherTimestamp)
- } else if (otherTimestamp <= otherUpperTime) {
- // join row with all rows from the other input for this timestamp
- val otherRows = otherRowMapState.get(otherTimestamp)
- var i = 0
- while (i < otherRows.size) {
- joinFunction.join(row, otherRows.get(i), cRowWrapper)
- i += 1
- }
- }
- }
- } else {
- // go over all timestamps in the other input's state
- while (otherTimeIter.hasNext) {
- val otherTimestamp = otherTimeIter.next()
- if (otherTimestamp < otherLowerTime) {
- // other timestamp is expired. Remove it later.
- removeList.add(otherTimestamp)
- } else if (otherTimestamp <= otherUpperTime) {
- // join row with all rows from the other input for this timestamp
- val otherRows = otherRowMapState.get(otherTimestamp)
- var i = 0
- while (i < otherRows.size) {
- joinFunction.join(otherRows.get(i), row, cRowWrapper)
- i += 1
- }
- }
- }
- }
-
- // remove rows for expired timestamps
- var i = removeList.size - 1
- while (i >= 0) {
- otherRowMapState.remove(removeList.get(i))
- i -= 1
- }
- removeList.clear()
- }
-
- /**
- * Removes records which are outside the join window from the state.
- * Registers a new timer if the state still holds records after the clean-up.
- */
- private def expireOutTimeRow(
- curTime: Long,
- winSize: Long,
- rowMapState: MapState[Long, JList[Row]],
- timerState: ValueState[Long],
- ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext): Unit = {
-
- val expiredTime = curTime - winSize
- val keyIter = rowMapState.keys().iterator()
- var validTimestamp: Boolean = false
- // Search for expired timestamps.
- // If we find a non-expired timestamp, remember the timestamp and leave the loop.
- // This way we find all expired timestamps if they are sorted without doing a full pass.
- while (keyIter.hasNext && !validTimestamp) {
- val recordTime = keyIter.next
- if (recordTime < expiredTime) {
- removeList.add(recordTime)
- } else {
- // we found a timestamp that is still valid
- validTimestamp = true
- }
- }
-
- // If the state has non-expired timestamps, register a new timer.
- // Otherwise clean the complete state for this input.
- if (validTimestamp) {
-
- // Remove expired records from state
- var i = removeList.size - 1
- while (i >= 0) {
- rowMapState.remove(removeList.get(i))
- i -= 1
- }
- removeList.clear()
-
- val cleanupTime = curTime + winSize + 1
- ctx.timerService.registerProcessingTimeTimer(cleanupTime)
- timerState.update(cleanupTime)
- } else {
- timerState.clear()
- rowMapState.clear()
- }
- }
-}
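
(For contrast with the new design, a small runnable sketch, with illustrative names, of the per-input window-size rule the removed class used; a size of -1 marked a side whose rows never had to be put into state.)

    object OldWindowSizeSketch extends App {
      // From the removed class: how long each side's rows were kept in state.
      def winSizes(leftLowerBound: Long, leftUpperBound: Long): (Long, Long) = {
        val leftStreamWinSize = if (leftLowerBound <= 0) -leftLowerBound else -1L
        val rightStreamWinSize = if (leftUpperBound >= 0) leftUpperBound else -1L
        (leftStreamWinSize, rightStreamWinSize)
      }

      // "a.proctime BETWEEN b.proctime - 10 AND b.proctime - 5":
      println(winSizes(-10L, -5L)) // (10,-1): only left rows need state
    }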
http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala
new file mode 100644
index 0000000..5cf5a53
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.join
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction
+import org.apache.flink.table.runtime.types.CRow
+import org.apache.flink.types.Row
+
+/**
+ * The function that executes a row-time (event-time) bounded stream inner-join.
+ */
+final class RowTimeBoundedStreamInnerJoin(
+ leftLowerBound: Long,
+ leftUpperBound: Long,
+ allowedLateness: Long,
+ leftType: TypeInformation[Row],
+ rightType: TypeInformation[Row],
+ genJoinFuncName: String,
+ genJoinFuncCode: String,
+ leftTimeIdx: Int,
+ rightTimeIdx: Int)
+ extends TimeBoundedStreamInnerJoin(
+ leftLowerBound,
+ leftUpperBound,
+ allowedLateness,
+ leftType,
+ rightType,
+ genJoinFuncName,
+ genJoinFuncCode) {
+
+ /**
+ * Get the maximum interval between receiving a row and emitting it (as part of a joined result).
+ * Only meaningful for row-time joins.
+ *
+ * @return the maximum delay for the outputs
+ */
+ def getMaxOutputDelay: Long = Math.max(leftRelativeSize, rightRelativeSize) + allowedLateness
+
+ override def updateOperatorTime(ctx: CoProcessFunction[CRow, CRow, CRow]#Context): Unit = {
+ leftOperatorTime =
+ if (ctx.timerService().currentWatermark() > 0) ctx.timerService().currentWatermark()
+ else 0L
+ // We may set different operator times in the future.
+ rightOperatorTime = leftOperatorTime
+ }
+
+ override def getTimeForLeftStream(
+ context: CoProcessFunction[CRow, CRow, CRow]#Context,
+ row: Row): Long = {
+ row.getField(leftTimeIdx).asInstanceOf[Long]
+ }
+
+ override def getTimeForRightStream(
+ context: CoProcessFunction[CRow, CRow, CRow]#Context,
+ row: Row): Long = {
+ row.getField(rightTimeIdx).asInstanceOf[Long]
+ }
+
+ override def registerTimer(
+ ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
+ cleanupTime: Long): Unit = {
+ // Maybe we can register timers for different streams in the future.
+ ctx.timerService.registerEventTimeTimer(cleanupTime)
+ }
+}
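
(A worked example, with values chosen for illustration, of getMaxOutputDelay above: assume bounds of [-10 ms, +20 ms] and an allowed lateness of 5 ms.)

    object MaxOutputDelaySketch extends App {
      val leftRelativeSize = 10L   // = -leftLowerBound
      val rightRelativeSize = 20L  // = leftUpperBound
      val allowedLateness = 5L
      // A row can be held at most this long before all joined results
      // that involve it have been emitted.
      println(Math.max(leftRelativeSize, rightRelativeSize) + allowedLateness) // 25
    }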
http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala
new file mode 100644
index 0000000..7bf3d33
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala
@@ -0,0 +1,412 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.join
+
+import java.util
+import java.util.{List => JList}
+
+import org.apache.flink.api.common.functions.FlatJoinFunction
+import org.apache.flink.api.common.state._
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.ListTypeInfo
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction
+import org.apache.flink.table.api.Types
+import org.apache.flink.table.codegen.Compiler
+import org.apache.flink.table.runtime.CRowWrappingCollector
+import org.apache.flink.table.runtime.types.CRow
+import org.apache.flink.table.util.Logging
+import org.apache.flink.types.Row
+import org.apache.flink.util.Collector
+
+/**
+ * A CoProcessFunction that executes a time-bounded stream inner-join.
+ * Two kinds of time criteria are supported:
+ * "L.time between R.time + X and R.time + Y" or "R.time between L.time - Y and L.time - X".
+ *
+ * @param leftLowerBound the lower bound for the left stream (X in the criteria)
+ * @param leftUpperBound the upper bound for the left stream (Y in the criteria)
+ * @param allowedLateness the lateness allowed for the two streams
+ * @param leftType the input type of left stream
+ * @param rightType the input type of right stream
+ * @param genJoinFuncName the name of the generated join function for the non-equi conditions
+ * @param genJoinFuncCode the code of the generated join function for the non-equi conditions
+ *
+ */
+abstract class TimeBoundedStreamInnerJoin(
+ private val leftLowerBound: Long,
+ private val leftUpperBound: Long,
+ private val allowedLateness: Long,
+ private val leftType: TypeInformation[Row],
+ private val rightType: TypeInformation[Row],
+ private val genJoinFuncName: String,
+ private val genJoinFuncCode: String)
+ extends CoProcessFunction[CRow, CRow, CRow]
+ with Compiler[FlatJoinFunction[Row, Row, Row]]
+ with Logging {
+
+ private var cRowWrapper: CRowWrappingCollector = _
+
+ // the join function for other conditions
+ private var joinFunction: FlatJoinFunction[Row, Row, Row] = _
+
+ // cache to store rows from the left stream
+ private var leftCache: MapState[Long, JList[Row]] = _
+ // cache to store rows from the right stream
+ private var rightCache: MapState[Long, JList[Row]] = _
+
+ // state to record the timer on the left stream. 0 means no timer set
+ private var leftTimerState: ValueState[Long] = _
+ // state to record the timer on the right stream. 0 means no timer set
+ private var rightTimerState: ValueState[Long] = _
+
+ protected val leftRelativeSize: Long = -leftLowerBound
+ protected val rightRelativeSize: Long = leftUpperBound
+
+ private var leftExpirationTime: Long = 0L
+ private var rightExpirationTime: Long = 0L
+
+ protected var leftOperatorTime: Long = 0L
+ protected var rightOperatorTime: Long = 0L
+
+
+ // for delayed cleanup
+ private val cleanupDelay = (leftRelativeSize + rightRelativeSize) / 2
+
+ if (allowedLateness < 0) {
+ throw new IllegalArgumentException("The allowed lateness must be non-negative.")
+ }
+
+ override def open(config: Configuration) {
+ LOG.debug(s"Compiling JoinFunction: $genJoinFuncName \n\n " +
+ s"Code:\n$genJoinFuncCode")
+ val clazz = compile(
+ getRuntimeContext.getUserCodeClassLoader,
+ genJoinFuncName,
+ genJoinFuncCode)
+ LOG.debug("Instantiating JoinFunction.")
+ joinFunction = clazz.newInstance()
+
+ cRowWrapper = new CRowWrappingCollector()
+ cRowWrapper.setChange(true)
+
+ // Initialize the data caches.
+ val leftListTypeInfo: TypeInformation[JList[Row]] = new ListTypeInfo[Row](leftType)
+ val leftStateDescriptor: MapStateDescriptor[Long, JList[Row]] =
+ new MapStateDescriptor[Long, JList[Row]](
+ "InnerJoinLeftCache",
+ Types.LONG.asInstanceOf[TypeInformation[Long]],
+ leftListTypeInfo)
+ leftCache = getRuntimeContext.getMapState(leftStateDescriptor)
+
+ val rightListTypeInfo: TypeInformation[JList[Row]] = new ListTypeInfo[Row](rightType)
+ val rightStateDescriptor: MapStateDescriptor[Long, JList[Row]] =
+ new MapStateDescriptor[Long, JList[Row]](
+ "InnerJoinRightCache",
+ Types.LONG.asInstanceOf[TypeInformation[Long]],
+ rightListTypeInfo)
+ rightCache = getRuntimeContext.getMapState(rightStateDescriptor)
+
+ // Initialize the timer states.
+ val leftTimerStateDesc: ValueStateDescriptor[Long] =
+ new ValueStateDescriptor[Long]("InnerJoinLeftTimerState", classOf[Long])
+ leftTimerState = getRuntimeContext.getState(leftTimerStateDesc)
+
+ val rightTimerStateDesc: ValueStateDescriptor[Long] =
+ new ValueStateDescriptor[Long]("InnerJoinRightTimerState", classOf[Long])
+ rightTimerState = getRuntimeContext.getState(rightTimerStateDesc)
+ }
+
+ /**
+ * Process rows from the left stream.
+ */
+ override def processElement1(
+ cRowValue: CRow,
+ ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
+ out: Collector[CRow]): Unit = {
+ updateOperatorTime(ctx)
+ val leftRow = cRowValue.row
+ val timeForLeftRow: Long = getTimeForLeftStream(ctx, leftRow)
+ val rightQualifiedLowerBound: Long = timeForLeftRow - rightRelativeSize
+ val rightQualifiedUpperBound: Long = timeForLeftRow + leftRelativeSize
+ cRowWrapper.out = out
+ // Check if we need to cache the current row.
+ if (rightOperatorTime < rightQualifiedUpperBound) {
+ // Operator time of right stream has not exceeded the upper window bound of the current
+ // row. Put it into the left cache, since records arriving later on the right stream are
+ // expected to be joined with it.
+ var leftRowList = leftCache.get(timeForLeftRow)
+ if (null == leftRowList) {
+ leftRowList = new util.ArrayList[Row](1)
+ }
+ leftRowList.add(leftRow)
+ leftCache.put(timeForLeftRow, leftRowList)
+ if (rightTimerState.value == 0) {
+ // Register a timer on the RIGHT stream to remove rows.
+ registerCleanUpTimer(ctx, timeForLeftRow, leftRow = true)
+ }
+ }
+ // Check if we need to join the current row against cached rows of the right input.
+ // The exact condition would be rightMinimumTime < rightQualifiedUpperBound.
+ // We use rightExpirationTime as an approximation of rightMinimumTime here,
+ // since rightExpirationTime <= rightMinimumTime always holds.
+ if (rightExpirationTime < rightQualifiedUpperBound) {
+ // Upper bound of current join window has not passed the cache expiration time yet.
+ // There might be qualifying rows in the cache that the current row needs to be joined with.
+ rightExpirationTime = calExpirationTime(leftOperatorTime, rightRelativeSize)
+ // Join the leftRow with rows from the right cache.
+ val rightIterator = rightCache.iterator()
+ while (rightIterator.hasNext) {
+ val rightEntry = rightIterator.next
+ val rightTime = rightEntry.getKey
+ if (rightTime >= rightQualifiedLowerBound && rightTime <= rightQualifiedUpperBound) {
+ val rightRows = rightEntry.getValue
+ var i = 0
+ while (i < rightRows.size) {
+ joinFunction.join(leftRow, rightRows.get(i), cRowWrapper)
+ i += 1
+ }
+ }
+
+ if (rightTime <= rightExpirationTime) {
+ // eager remove
+ rightIterator.remove()
+ } // We could do the short-cutting optimization here once we get a state with ordered keys.
+ }
+ }
+ }
+
+ /**
+ * Process rows from the right stream.
+ */
+ override def processElement2(
+ cRowValue: CRow,
+ ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
+ out: Collector[CRow]): Unit = {
+ updateOperatorTime(ctx)
+ val rightRow = cRowValue.row
+ val timeForRightRow: Long = getTimeForRightStream(ctx, rightRow)
+ val leftQualifiedLowerBound: Long = timeForRightRow - leftRelativeSize
+ val leftQualifiedUpperBound: Long = timeForRightRow + rightRelativeSize
+ cRowWrapper.out = out
+ // Check if we need to cache the current row.
+ if (leftOperatorTime < leftQualifiedUpperBound) {
+ // Operator time of left stream has not exceeded the upper window bound of the current
+ // row. Put it into the right cache, since records arriving later on the left stream are
+ // expected to be joined with it.
+ var rightRowList = rightCache.get(timeForRightRow)
+ if (null == rightRowList) {
+ rightRowList = new util.ArrayList[Row](1)
+ }
+ rightRowList.add(rightRow)
+ rightCache.put(timeForRightRow, rightRowList)
+ if (leftTimerState.value == 0) {
+ // Register a timer on the LEFT stream to remove rows.
+ registerCleanUpTimer(ctx, timeForRightRow, leftRow = false)
+ }
+ }
+ // Check if we need to join the current row against cached rows of the left input.
+ // The exact condition would be leftMinimumTime < leftQualifiedUpperBound.
+ // We use leftExpirationTime as an approximation of leftMinimumTime here,
+ // since leftExpirationTime <= leftMinimumTime always holds.
+ if (leftExpirationTime < leftQualifiedUpperBound) {
+ leftExpirationTime = calExpirationTime(rightOperatorTime, leftRelativeSize)
+ // Join the rightRow with rows from the left cache.
+ val leftIterator = leftCache.iterator()
+ while (leftIterator.hasNext) {
+ val leftEntry = leftIterator.next
+ val leftTime = leftEntry.getKey
+ if (leftTime >= leftQualifiedLowerBound && leftTime <= leftQualifiedUpperBound) {
+ val leftRows = leftEntry.getValue
+ var i = 0
+ while (i < leftRows.size) {
+ joinFunction.join(leftRows.get(i), rightRow, cRowWrapper)
+ i += 1
+ }
+ }
+ if (leftTime <= leftExpirationTime) {
+ // eager remove
+ leftIterator.remove()
+ } // We could do the short-cutting optimization here once we get a state with ordered keys.
+ }
+ }
+ }
+
+ /**
+ * Called when a registered timer is fired.
+ * Remove rows whose timestamps are earlier than the expiration time,
+ * and register a new timer for the remaining rows.
+ *
+ * @param timestamp the timestamp of the timer
+ * @param ctx the context to register timer or get current time
+ * @param out the collector for returning result values
+ */
+ override def onTimer(
+ timestamp: Long,
+ ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext,
+ out: Collector[CRow]): Unit = {
+ updateOperatorTime(ctx)
+ // In the future, we should separate the left and right watermarks. Otherwise, the
+ // registered timer of the faster stream will be delayed, even if the watermarks have
+ // already been emitted by the source.
+ if (leftTimerState.value == timestamp) {
+ rightExpirationTime = calExpirationTime(leftOperatorTime, rightRelativeSize)
+ removeExpiredRows(
+ rightExpirationTime,
+ rightCache,
+ leftTimerState,
+ ctx,
+ removeLeft = false
+ )
+ }
+
+ if (rightTimerState.value == timestamp) {
+ leftExpirationTime = calExpirationTime(rightOperatorTime, leftRelativeSize)
+ removeExpiredRows(
+ leftExpirationTime,
+ leftCache,
+ rightTimerState,
+ ctx,
+ removeLeft = true
+ )
+ }
+ }
+
+ /**
+ * Calculate the expiration time with the given operator time and relative window size.
+ *
+ * @param operatorTime the operator time
+ * @param relativeSize the relative window size
+ * @return the expiration time for cached rows
+ */
+ private def calExpirationTime(operatorTime: Long, relativeSize: Long): Long = {
+ if (operatorTime < Long.MaxValue) {
+ operatorTime - relativeSize - allowedLateness - 1
+ } else {
+ // When operatorTime = Long.MaxValue, it means the stream has reached the end.
+ Long.MaxValue
+ }
+ }
+
+ /**
+ * Register a timer for cleaning up rows at a specified time.
+ *
+ * @param ctx the context to register timer
+ * @param rowTime time for the input row
+ * @param leftRow whether this row comes from the left stream
+ */
+ private def registerCleanUpTimer(
+ ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
+ rowTime: Long,
+ leftRow: Boolean): Unit = {
+ if (leftRow) {
+ val cleanupTime = rowTime + leftRelativeSize + cleanupDelay + allowedLateness + 1
+ registerTimer(ctx, cleanupTime)
+ rightTimerState.update(cleanupTime)
+ } else {
+ val cleanupTime = rowTime + rightRelativeSize + cleanupDelay + allowedLateness + 1
+ registerTimer(ctx, cleanupTime)
+ leftTimerState.update(cleanupTime)
+ }
+ }
+
+ /**
+ * Remove the expired rows. Register a new timer if the cache still holds valid rows
+ * after the clean-up.
+ *
+ * @param expirationTime the expiration time for this cache
+ * @param rowCache the row cache
+ * @param timerState timer state for the opposite stream
+ * @param ctx the context to register the cleanup timer
+ * @param removeLeft whether to remove the left rows
+ */
+ private def removeExpiredRows(
+ expirationTime: Long,
+ rowCache: MapState[Long, JList[Row]],
+ timerState: ValueState[Long],
+ ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext,
+ removeLeft: Boolean): Unit = {
+
+ val keysIterator = rowCache.keys().iterator()
+
+ var earliestTimestamp: Long = -1L
+ var rowTime: Long = 0L
+
+ // We remove all expired keys and do not leave the loop early.
+ // Hence, we do a full pass over the state.
+ while (keysIterator.hasNext) {
+ rowTime = keysIterator.next
+ if (rowTime <= expirationTime) {
+ keysIterator.remove()
+ } else {
+ // We find the earliest timestamp that is still valid.
+ if (rowTime < earliestTimestamp || earliestTimestamp < 0) {
+ earliestTimestamp = rowTime
+ }
+ }
+ }
+ if (earliestTimestamp > 0) {
+ // There are rows left in the cache. Register a timer to expire them later.
+ registerCleanUpTimer(
+ ctx,
+ earliestTimestamp,
+ removeLeft)
+ } else {
+ // No rows left in the cache. Clear the states so that timerState reads 0 again.
+ timerState.clear()
+ rowCache.clear()
+ }
+ }
+
+ /**
+ * Update the operator time of the two streams.
+ * Must be the first call in all processing methods (i.e., processElement(), onTimer()).
+ *
+ * @param ctx the context to acquire watermarks
+ */
+ def updateOperatorTime(ctx: CoProcessFunction[CRow, CRow, CRow]#Context): Unit
+
+ /**
+ * Return the time for the target row from the left stream.
+ *
+ * @param context the runtime context
+ * @param row the target row
+ * @return time for the target row
+ */
+ def getTimeForLeftStream(context: CoProcessFunction[CRow, CRow, CRow]#Context, row: Row): Long
+
+ /**
+ * Return the time for the target row from the right stream.
+ *
+ * @param context the runtime context
+ * @param row the target row
+ * @return time for the target row
+ */
+ def getTimeForRightStream(context: CoProcessFunction[CRow, CRow, CRow]#Context, row: Row): Long
+
+ /**
+ * Register a proctime or rowtime timer.
+ *
+ * @param ctx the context to register the timer
+ * @param cleanupTime timestamp for the timer
+ */
+ def registerTimer(
+ ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
+ cleanupTime: Long): Unit
+}
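
(To make the interval arithmetic above concrete, a self-contained sketch assuming the predicate "L.time BETWEEN R.time - 10 AND R.time + 20" and no allowed lateness; constants are illustrative.)

    object TimeBoundedJoinArithmeticSketch extends App {
      val leftLowerBound = -10L
      val leftUpperBound = 20L
      val allowedLateness = 0L
      val leftRelativeSize = -leftLowerBound                        // 10
      val rightRelativeSize = leftUpperBound                        // 20
      val cleanupDelay = (leftRelativeSize + rightRelativeSize) / 2 // 15

      // A left row with time t joins right rows whose time lies in
      // [t - rightRelativeSize, t + leftRelativeSize].
      val t = 100L
      println(s"join window: [${t - rightRelativeSize}, ${t + leftRelativeSize}]") // [80, 110]

      // Cleanup timer registered when that left row is cached, as in
      // registerCleanUpTimer(): rowTime + leftRelativeSize + cleanupDelay
      // + allowedLateness + 1.
      println(t + leftRelativeSize + cleanupDelay + allowedLateness + 1) // 126
    }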
http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala
index b566113..6f97f2a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala
@@ -39,14 +39,23 @@ import scala.collection.JavaConverters._
*/
object WindowJoinUtil {
- case class WindowBounds(isEventTime: Boolean, leftLowerBound: Long, leftUpperBound: Long)
+ case class WindowBounds(
+ isEventTime: Boolean,
+ leftLowerBound: Long,
+ leftUpperBound: Long,
+ leftTimeIdx: Int,
+ rightTimeIdx: Int)
protected case class WindowBound(bound: Long, isLeftLower: Boolean)
+
protected case class TimePredicate(
isEventTime: Boolean,
leftInputOnLeftSide: Boolean,
+ leftTimeIdx: Int,
+ rightTimeIdx: Int,
pred: RexCall)
- protected case class TimeAttributeAccess(isEventTime: Boolean, isLeftInput: Boolean)
+
+ protected case class TimeAttributeAccess(isEventTime: Boolean, isLeftInput: Boolean, idx: Int)
/**
* Extracts the window bounds from a join predicate.
@@ -116,7 +125,21 @@ object WindowJoinUtil {
Some(otherPreds.reduceLeft((l, r) => RelOptUtil.andJoinFilters(rexBuilder, l, r)))
}
- val bounds = Some(WindowBounds(timePreds.head.isEventTime, leftLowerBound, leftUpperBound))
+ val bounds = if (timePreds.head.leftInputOnLeftSide) {
+ Some(WindowBounds(
+ timePreds.head.isEventTime,
+ leftLowerBound,
+ leftUpperBound,
+ timePreds.head.leftTimeIdx,
+ timePreds.head.rightTimeIdx))
+ } else {
+ Some(WindowBounds(
+ timePreds.head.isEventTime,
+ leftLowerBound,
+ leftUpperBound,
+ timePreds.head.rightTimeIdx,
+ timePreds.head.leftTimeIdx))
+ }
(bounds, remainCondition)
}
@@ -196,8 +219,8 @@ object WindowJoinUtil {
case (Some(left), Some(right)) if left.isLeftInput == right.isLeftInput =>
// Window join predicates must reference the time attribute of both inputs.
Right(pred)
- case (Some(left), Some(_)) =>
- Left(TimePredicate(left.isEventTime, left.isLeftInput, c))
+ case (Some(left), Some(right)) =>
+ Left(TimePredicate(left.isEventTime, left.isLeftInput, left.idx, right.idx, c))
}
// not a comparison predicate.
case _ => Right(pred)
@@ -224,8 +247,11 @@ object WindowJoinUtil {
inputType.getFieldList.get(idx).getType match {
case t: TimeIndicatorRelDataType =>
// time attribute access. Remember time type and side of input
- val isLeftInput = idx < leftFieldCount
- Seq(TimeAttributeAccess(t.isEventTime, isLeftInput))
+ if (idx < leftFieldCount) {
+ Seq(TimeAttributeAccess(t.isEventTime, true, idx))
+ } else {
+ Seq(TimeAttributeAccess(t.isEventTime, false, idx - leftFieldCount))
+ }
case _ =>
// not a time attribute access.
Seq()
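
(A condensed sketch, with hypothetical names, of the new index-swapping rule above: when the left input's time attribute appears on the right-hand side of the comparison, the indices extracted from the predicate are swapped so that WindowBounds always stores them relative to the join inputs.)

    case class BoundsSketch(leftTimeIdx: Int, rightTimeIdx: Int)

    object BoundSwapSketch extends App {
      def toBounds(leftInputOnLeftSide: Boolean, lhsIdx: Int, rhsIdx: Int): BoundsSketch =
        if (leftInputOnLeftSide) BoundsSketch(lhsIdx, rhsIdx)
        else BoundsSketch(rhsIdx, lhsIdx)

      println(toBounds(leftInputOnLeftSide = false, lhsIdx = 2, rhsIdx = 0)) // BoundsSketch(0,2)
    }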
http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala
index e066fe4..a4234c5 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala
@@ -32,7 +32,7 @@ class JoinTest extends TableTestBase {
streamUtil.addTable[(Int, String, Long)]("MyTable2", 'a, 'b, 'c.rowtime, 'proctime.proctime)
@Test
- def testProcessingTimeInnerJoinWithOnClause() = {
+ def testProcessingTimeInnerJoinWithOnClause(): Unit = {
val sqlQuery =
"""
@@ -70,7 +70,45 @@ class JoinTest extends TableTestBase {
}
@Test
- def testProcessingTimeInnerJoinWithWhereClause() = {
+ def testRowTimeInnerJoinWithOnClause(): Unit = {
+
+ val sqlQuery =
+ """
+ |SELECT t1.a, t2.b
+ |FROM MyTable t1 JOIN MyTable2 t2 ON
+ | t1.a = t2.a AND
+ | t1.c BETWEEN t2.c - INTERVAL '10' SECOND AND t2.c + INTERVAL '1' HOUR
+ |""".stripMargin
+
+ val expected =
+ unaryNode(
+ "DataStreamCalc",
+ binaryNode(
+ "DataStreamWindowJoin",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "a", "c")
+ ),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(1),
+ term("select", "a", "b", "c")
+ ),
+ term("where",
+ "AND(=(a, a0), >=(c, -(c0, 10000)), " +
+ "<=(c, DATETIME_PLUS(c0, 3600000)))"),
+ term("join", "a, c, a0, b, c0"),
+ term("joinType", "InnerJoin")
+ ),
+ term("select", "a", "b")
+ )
+
+ streamUtil.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testProcessingTimeInnerJoinWithWhereClause(): Unit = {
val sqlQuery =
"""
@@ -108,6 +146,44 @@ class JoinTest extends TableTestBase {
}
@Test
+ def testRowTimeInnerJoinWithWhereClause(): Unit = {
+
+ val sqlQuery =
+ """
+ |SELECT t1.a, t2.b
+ |FROM MyTable t1, MyTable2 t2
+ |WHERE t1.a = t2.a AND
+ | t1.c BETWEEN t2.c - INTERVAL '10' MINUTE AND t2.c + INTERVAL '1' HOUR
+ |""".stripMargin
+
+ val expected =
+ unaryNode(
+ "DataStreamCalc",
+ binaryNode(
+ "DataStreamWindowJoin",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "a", "c")
+ ),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(1),
+ term("select", "a", "b", "c")
+ ),
+ term("where",
+ "AND(=(a, a0), >=(c, -(c0, 600000)), " +
+ "<=(c, DATETIME_PLUS(c0, 3600000)))"),
+ term("join", "a, c, a0, b, c0"),
+ term("joinType", "InnerJoin")
+ ),
+ term("select", "a", "b0 AS b")
+ )
+
+ streamUtil.verifySql(sqlQuery, expected)
+ }
+
+ @Test
def testJoinTimeBoundary(): Unit = {
verifyTimeBoundary(
"t1.proctime between t2.proctime - interval '1' hour " +
@@ -175,16 +251,17 @@ class JoinTest extends TableTestBase {
"SELECT t1.a, t2.c FROM MyTable3 as t1 join MyTable4 as t2 on t1.a = t2.a and " +
"t1.b >= t2.b - interval '10' second and t1.b <= t2.b - interval '5' second and " +
"t1.c > t2.c"
+ // The equi-join predicate should also be included in the remaining condition.
verifyRemainConditionConvert(
query,
- ">($2, $6)")
+ "AND(=($0, $4), >($2, $6))")
val query1 =
"SELECT t1.a, t2.c FROM MyTable3 as t1 join MyTable4 as t2 on t1.a = t2.a and " +
"t1.b >= t2.b - interval '10' second and t1.b <= t2.b - interval '5' second "
verifyRemainConditionConvert(
query1,
- "")
+ "=($0, $4)")
streamUtil.addTable[(Int, Long, Int)]("MyTable5", 'a, 'b, 'c, 'proctime.proctime)
streamUtil.addTable[(Int, Long, Int)]("MyTable6", 'a, 'b, 'c, 'proctime.proctime)
@@ -195,7 +272,7 @@ class JoinTest extends TableTestBase {
"t1.c > t2.c"
verifyRemainConditionConvert(
query2,
- ">($2, $6)")
+ "AND(=($0, $4), >($2, $6))")
}
private def verifyTimeBoundary(
@@ -209,10 +286,9 @@ class JoinTest extends TableTestBase {
val resultTable = streamUtil.tableEnv.sqlQuery(query)
val relNode = resultTable.getRelNode
val joinNode = relNode.getInput(0).asInstanceOf[LogicalJoin]
- val rexNode = joinNode.getCondition
val (windowBounds, _) =
WindowJoinUtil.extractWindowBoundsFromPredicate(
- rexNode,
+ joinNode.getCondition,
4,
joinNode.getRowType,
joinNode.getCluster.getRexBuilder,
@@ -233,11 +309,9 @@ class JoinTest extends TableTestBase {
val resultTable = streamUtil.tableEnv.sqlQuery(query)
val relNode = resultTable.getRelNode
val joinNode = relNode.getInput(0).asInstanceOf[LogicalJoin]
- val joinInfo = joinNode.analyzeCondition
- val rexNode = joinInfo.getRemaining(joinNode.getCluster.getRexBuilder)
val (_, remainCondition) =
WindowJoinUtil.extractWindowBoundsFromPredicate(
- rexNode,
+ joinNode.getCondition,
4,
joinNode.getRowType,
joinNode.getCluster.getRexBuilder,
http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala
index 065b7bc..192befd 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala
@@ -17,29 +17,26 @@
*/
package org.apache.flink.table.runtime.harness
+import java.lang.{Long => JLong}
import java.util.concurrent.ConcurrentLinkedQueue
-import java.lang.{Integer => JInt}
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
-import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
-import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.streaming.api.operators.co.KeyedCoProcessOperator
+import org.apache.flink.streaming.api.watermark.Watermark
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness
+import org.apache.flink.table.api.Types
import org.apache.flink.table.runtime.harness.HarnessTestBase.{RowResultSortComparator, TupleRowKeySelector}
-import org.apache.flink.table.runtime.join.ProcTimeWindowInnerJoin
+import org.apache.flink.table.runtime.join.{ProcTimeBoundedStreamInnerJoin, RowTimeBoundedStreamInnerJoin}
import org.apache.flink.table.runtime.types.CRow
import org.apache.flink.types.Row
+import org.junit.Assert.assertEquals
import org.junit.Test
-import org.junit.Assert.{assertEquals, assertTrue}
-class JoinHarnessTest extends HarnessTestBase{
-
- private val rT = new RowTypeInfo(Array[TypeInformation[_]](
- INT_TYPE_INFO,
- STRING_TYPE_INFO),
- Array("a", "b"))
+class JoinHarnessTest extends HarnessTestBase {
+ private val rowType = Types.ROW(
+ Types.LONG,
+ Types.STRING)
val funcCode: String =
"""
@@ -75,84 +72,88 @@ class JoinHarnessTest extends HarnessTestBase{
/** a.proctime >= b.proctime - 10 and a.proctime <= b.proctime + 20 **/
@Test
- def testNormalProcTimeJoin() {
+ def testProcTimeJoinWithCommonBounds() {
- val joinProcessFunc = new ProcTimeWindowInnerJoin(-10, 20, rT, rT, "TestJoinFunction", funcCode)
+ val joinProcessFunc = new ProcTimeBoundedStreamInnerJoin(
+ -10, 20, 0, rowType, rowType, "TestJoinFunction", funcCode)
val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] =
new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc)
val testHarness: KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow] =
new KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow](
- operator,
- new TupleRowKeySelector[Integer](0),
- new TupleRowKeySelector[Integer](0),
- BasicTypeInfo.INT_TYPE_INFO,
- 1, 1, 0)
+ operator,
+ new TupleRowKeySelector[Integer](0),
+ new TupleRowKeySelector[Integer](0),
+ Types.INT,
+ 1, 1, 0)
testHarness.open()
- // left stream input
testHarness.setProcessingTime(1)
testHarness.processElement1(new StreamRecord(
- CRow(Row.of(1: JInt, "aaa"), true), 1))
+ CRow(Row.of(1L: JLong, "1a1"), true), 1))
assertEquals(1, testHarness.numProcessingTimeTimers())
testHarness.setProcessingTime(2)
testHarness.processElement1(new StreamRecord(
- CRow(Row.of(2: JInt, "bbb"), true), 2))
+ CRow(Row.of(2L: JLong, "2a2"), true), 2))
+
+ // timers for key = 1 and key = 2
assertEquals(2, testHarness.numProcessingTimeTimers())
+
testHarness.setProcessingTime(3)
testHarness.processElement1(new StreamRecord(
- CRow(Row.of(1: JInt, "aaa2"), true), 3))
+ CRow(Row.of(1L: JLong, "1a3"), true), 3))
assertEquals(4, testHarness.numKeyedStateEntries())
+
+ // The number of timers won't increase.
assertEquals(2, testHarness.numProcessingTimeTimers())
- // right stream input and output normally
testHarness.processElement2(new StreamRecord(
- CRow(Row.of(1: JInt, "Hi1"), true), 3))
+ CRow(Row.of(1L: JLong, "1b3"), true), 3))
testHarness.setProcessingTime(4)
testHarness.processElement2(new StreamRecord(
- CRow(Row.of(2: JInt, "Hello1"), true), 4))
- assertEquals(8, testHarness.numKeyedStateEntries())
- assertEquals(4, testHarness.numProcessingTimeTimers())
+ CRow(Row.of(2L: JLong, "2b4"), true), 4))
- // expired left stream record at timestamp 1
- testHarness.setProcessingTime(12)
+ // The number of state entries should be doubled.
assertEquals(8, testHarness.numKeyedStateEntries())
assertEquals(4, testHarness.numProcessingTimeTimers())
+
+ // Test for -10 boundary (13 - 10 = 3).
+ // The left row (key = 1) with timestamp = 1 will be eagerly removed here.
+ testHarness.setProcessingTime(13)
testHarness.processElement2(new StreamRecord(
- CRow(Row.of(1: JInt, "Hi2"), true), 12))
+ CRow(Row.of(1L: JLong, "1b13"), true), 13))
- // expired right stream record at timestamp 4 and all left stream
- testHarness.setProcessingTime(25)
- assertEquals(2, testHarness.numKeyedStateEntries())
- assertEquals(1, testHarness.numProcessingTimeTimers())
+ // Test for +20 boundary (13 + 20 = 33).
+ testHarness.setProcessingTime(33)
+ assertEquals(4, testHarness.numKeyedStateEntries())
+ assertEquals(2, testHarness.numProcessingTimeTimers())
testHarness.processElement1(new StreamRecord(
- CRow(Row.of(1: JInt, "aaa3"), true), 25))
+ CRow(Row.of(1L: JLong, "1a33"), true), 33))
+
testHarness.processElement1(new StreamRecord(
- CRow(Row.of(2: JInt, "bbb2"), true), 25))
+ CRow(Row.of(2L: JLong, "2a33"), true), 33))
+
+ // The left row (key = 2) with timestamp = 2 will be eagerly removed here.
testHarness.processElement2(new StreamRecord(
- CRow(Row.of(2: JInt, "Hello2"), true), 25))
+ CRow(Row.of(2L: JLong, "2b33"), true), 33))
- testHarness.setProcessingTime(45)
- assertTrue(testHarness.numKeyedStateEntries() > 0)
- testHarness.setProcessingTime(46)
- assertEquals(0, testHarness.numKeyedStateEntries())
val result = testHarness.getOutput
val expectedOutput = new ConcurrentLinkedQueue[Object]()
expectedOutput.add(new StreamRecord(
- CRow(Row.of(1: JInt, "aaa", 1: JInt, "Hi1"), true), 3))
+ CRow(Row.of(1L: JLong, "1a1", 1L: JLong, "1b3"), true), 3))
expectedOutput.add(new StreamRecord(
- CRow(Row.of(1: JInt, "aaa2", 1: JInt, "Hi1"), true), 3))
+ CRow(Row.of(1L: JLong, "1a3", 1L: JLong, "1b3"), true), 3))
expectedOutput.add(new StreamRecord(
- CRow(Row.of(2: JInt, "bbb", 2: JInt, "Hello1"), true), 4))
+ CRow(Row.of(2L: JLong, "2a2", 2L: JLong, "2b4"), true), 4))
expectedOutput.add(new StreamRecord(
- CRow(Row.of(1: JInt, "aaa2", 1: JInt, "Hi2"), true), 12))
+ CRow(Row.of(1L: JLong, "1a3", 1L: JLong, "1b13"), true), 13))
expectedOutput.add(new StreamRecord(
- CRow(Row.of(1: JInt, "aaa3", 1: JInt, "Hi2"), true), 25))
+ CRow(Row.of(1L: JLong, "1a33", 1L: JLong, "1b13"), true), 33))
expectedOutput.add(new StreamRecord(
- CRow(Row.of(2: JInt, "bbb2", 2: JInt, "Hello2"), true), 25))
+ CRow(Row.of(2L: JLong, "2a33", 2L: JLong, "2b33"), true), 33))
verify(expectedOutput, result, new RowResultSortComparator())
@@ -161,9 +162,10 @@ class JoinHarnessTest extends HarnessTestBase{
/** a.proctime >= b.proctime - 10 and a.proctime <= b.proctime - 5 **/
@Test
- def testProcTimeJoinSingleNeedStore() {
+ def testProcTimeJoinWithNegativeBounds() {
- val joinProcessFunc = new ProcTimeWindowInnerJoin(-10, -5, rT, rT, "TestJoinFunction", funcCode)
+ val joinProcessFunc = new ProcTimeBoundedStreamInnerJoin(
+ -10, -5, 0, rowType, rowType, "TestJoinFunction", funcCode)
val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] =
new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc)
@@ -172,50 +174,58 @@ class JoinHarnessTest extends HarnessTestBase{
operator,
new TupleRowKeySelector[Integer](0),
new TupleRowKeySelector[Integer](0),
- BasicTypeInfo.INT_TYPE_INFO,
+ Types.INT,
1, 1, 0)
testHarness.open()
testHarness.setProcessingTime(1)
testHarness.processElement1(new StreamRecord(
- CRow(Row.of(1: JInt, "aaa1"), true), 1))
+ CRow(Row.of(1L: JLong, "1a1"), true), 1))
testHarness.setProcessingTime(2)
testHarness.processElement1(new StreamRecord(
- CRow(Row.of(2: JInt, "aaa2"), true), 2))
+ CRow(Row.of(2L: JLong, "2a2"), true), 2))
testHarness.setProcessingTime(3)
testHarness.processElement1(new StreamRecord(
- CRow(Row.of(1: JInt, "aaa3"), true), 3))
+ CRow(Row.of(1L: JLong, "1a3"), true), 3))
assertEquals(4, testHarness.numKeyedStateEntries())
assertEquals(2, testHarness.numProcessingTimeTimers())
- // Do not store b elements
- // not meet a.proctime <= b.proctime - 5
+ // None of the right rows will be cached.
testHarness.processElement2(new StreamRecord(
- CRow(Row.of(1: JInt, "bbb3"), true), 3))
+ CRow(Row.of(1L: JLong, "1b3"), true), 3))
assertEquals(4, testHarness.numKeyedStateEntries())
assertEquals(2, testHarness.numProcessingTimeTimers())
- // meet a.proctime <= b.proctime - 5
testHarness.setProcessingTime(7)
+
+ // Meets a.proctime <= b.proctime - 5.
+ // This row is only joined, not cached (leftOperatorTime 7 >= 7 - 5 = 2).
testHarness.processElement2(new StreamRecord(
- CRow(Row.of(2: JInt, "bbb7"), true), 7))
+ CRow(Row.of(2L: JLong, "2b7"), true), 7))
assertEquals(4, testHarness.numKeyedStateEntries())
assertEquals(2, testHarness.numProcessingTimeTimers())
- // expire record of stream a at timestamp 1
testHarness.setProcessingTime(12)
- assertEquals(4, testHarness.numKeyedStateEntries())
- assertEquals(2, testHarness.numProcessingTimeTimers())
+ // The left row (key = 1) with timestamp = 1 will be eagerly removed here.
testHarness.processElement2(new StreamRecord(
- CRow(Row.of(1: JInt, "bbb12"), true), 12))
+ CRow(Row.of(1L: JLong, "1b12"), true), 12))
+ // We add a delay (relativeWindowSize / 2) for cleaning up state.
+ // No timers will be triggered here.
testHarness.setProcessingTime(13)
+ assertEquals(4, testHarness.numKeyedStateEntries())
+ assertEquals(2, testHarness.numProcessingTimeTimers())
+
+ // Trigger the timer registered by the left row (key = 1) with timestamp = 1
+ // (1 + 10 + 2 + 0 + 1 = 14).
+ // The left row (key = 1) with timestamp = 3 will be removed here.
+ testHarness.setProcessingTime(14)
assertEquals(2, testHarness.numKeyedStateEntries())
assertEquals(1, testHarness.numProcessingTimeTimers())
- // state must be cleaned after the window timer interval has passed without new rows.
- testHarness.setProcessingTime(23)
+ // Clean up the left row (key = 2) with timestamp = 2.
+ testHarness.setProcessingTime(16)
assertEquals(0, testHarness.numKeyedStateEntries())
assertEquals(0, testHarness.numProcessingTimeTimers())
val result = testHarness.getOutput
@@ -223,13 +233,174 @@ class JoinHarnessTest extends HarnessTestBase{
val expectedOutput = new ConcurrentLinkedQueue[Object]()
expectedOutput.add(new StreamRecord(
- CRow(Row.of(2: JInt, "aaa2", 2: JInt, "bbb7"), true), 7))
+ CRow(Row.of(2L: JLong, "2a2", 2L: JLong, "2b7"), true), 7))
expectedOutput.add(new StreamRecord(
- CRow(Row.of(1: JInt, "aaa3", 1: JInt, "bbb12"), true), 12))
+ CRow(Row.of(1L: JLong, "1a3", 1L: JLong, "1b12"), true), 12))
verify(expectedOutput, result, new RowResultSortComparator())
testHarness.close()
}
+ /** a.rowtime >= b.rowtime - 10 and a.rowtime <= b.rowtime + 20 **/
+ @Test
+ def testRowTimeJoinWithCommonBounds() {
+
+ val joinProcessFunc = new RowTimeBoundedStreamInnerJoin(
+ -10, 20, 0, rowType, rowType, "TestJoinFunction", funcCode, 0, 0)
+
+ val operator: KeyedCoProcessOperator[String, CRow, CRow, CRow] =
+ new KeyedCoProcessOperator[String, CRow, CRow, CRow](joinProcessFunc)
+ val testHarness: KeyedTwoInputStreamOperatorTestHarness[String, CRow, CRow, CRow] =
+ new KeyedTwoInputStreamOperatorTestHarness[String, CRow, CRow, CRow](
+ operator,
+ new TupleRowKeySelector[String](1),
+ new TupleRowKeySelector[String](1),
+ Types.STRING,
+ 1, 1, 0)
+
+ testHarness.open()
+
+ testHarness.processWatermark1(new Watermark(1))
+ testHarness.processWatermark2(new Watermark(1))
+
+ // Test late data.
+ testHarness.processElement1(new StreamRecord[CRow](
+ CRow(Row.of(1L: JLong, "k1"), true), 0))
+
+ // Though (1L, "k1") is actually late, it will also be cached.
+ assertEquals(1, testHarness.numEventTimeTimers())
+
+ testHarness.processElement1(new StreamRecord[CRow](
+ CRow(Row.of(2L: JLong, "k1"), true), 0))
+ testHarness.processElement2(new StreamRecord[CRow](
+ CRow(Row.of(2L: JLong, "k1"), true), 0))
+
+ assertEquals(2, testHarness.numEventTimeTimers())
+ assertEquals(4, testHarness.numKeyedStateEntries())
+
+ testHarness.processElement1(new StreamRecord[CRow](
+ CRow(Row.of(5L: JLong, "k1"), true), 0))
+
+ testHarness.processElement2(new StreamRecord[CRow](
+ CRow(Row.of(15L: JLong, "k1"), true), 0))
+
+ testHarness.processWatermark1(new Watermark(20))
+ testHarness.processWatermark2(new Watermark(20))
+
+ assertEquals(4, testHarness.numKeyedStateEntries())
+
+ testHarness.processElement1(new StreamRecord[CRow](
+ CRow(Row.of(35L: JLong, "k1"), true), 0))
+
+ // The left rows with timestamp = 1, 2 and 5 will be removed here.
+ // The right rows with timestamp = 2 and 15 will be removed here.
+ testHarness.processWatermark1(new Watermark(38))
+ testHarness.processWatermark2(new Watermark(38))
+
+ testHarness.processElement1(new StreamRecord[CRow](
+ CRow(Row.of(40L: JLong, "k2"), true), 0))
+ testHarness.processElement2(new StreamRecord[CRow](
+ CRow(Row.of(39L: JLong, "k2"), true), 0))
+
+ assertEquals(6, testHarness.numKeyedStateEntries())
+
+ // The left row with timestamp = 35 will be removed here.
+ testHarness.processWatermark1(new Watermark(61))
+ testHarness.processWatermark2(new Watermark(61))
+
+ assertEquals(4, testHarness.numKeyedStateEntries())
+
+ val expectedOutput = new ConcurrentLinkedQueue[Object]()
+ expectedOutput.add(new StreamRecord(
+ CRow(Row.of(2L: JLong, "k1", 2L: JLong, "k1"), true), 0))
+ expectedOutput.add(new StreamRecord(
+ CRow(Row.of(5L: JLong, "k1", 2L: JLong, "k1"), true), 0))
+ expectedOutput.add(new StreamRecord(
+ CRow(Row.of(5L: JLong, "k1", 15L: JLong, "k1"), true), 0))
+ expectedOutput.add(new StreamRecord(
+ CRow(Row.of(35L: JLong, "k1", 15L: JLong, "k1"), true), 0))
+ expectedOutput.add(new StreamRecord(
+ CRow(Row.of(40L: JLong, "k2", 39L: JLong, "k2"), true), 0))
+
+ // This result is produced by the late row (1, "k1").
+ expectedOutput.add(new StreamRecord(
+ CRow(Row.of(1L: JLong, "k1", 2L: JLong, "k1"), true), 0))
+
+ val result = testHarness.getOutput
+ verify(expectedOutput, result, new RowResultSortComparator())
+ testHarness.close()
+ }
+
+ /** a.rowtime >= b.rowtime - 10 and a.rowtime <= b.rowtime - 7 **/
+ @Test
+ def testRowTimeJoinWithNegativeBounds() {
+
+ val joinProcessFunc = new RowTimeBoundedStreamInnerJoin(
+ -10, -7, 0, rowType, rowType, "TestJoinFunction", funcCode, 0, 0)
+
+ val operator: KeyedCoProcessOperator[String, CRow, CRow, CRow] =
+ new KeyedCoProcessOperator[String, CRow, CRow, CRow](joinProcessFunc)
+ val testHarness: KeyedTwoInputStreamOperatorTestHarness[String, CRow, CRow, CRow] =
+ new KeyedTwoInputStreamOperatorTestHarness[String, CRow, CRow, CRow](
+ operator,
+ new TupleRowKeySelector[String](1),
+ new TupleRowKeySelector[String](1),
+ Types.STRING,
+ 1, 1, 0)
+
+ testHarness.open()
+
+ testHarness.processWatermark1(new Watermark(1))
+ testHarness.processWatermark2(new Watermark(1))
+
+ // This row will not be cached.
+ testHarness.processElement2(new StreamRecord[CRow](
+ CRow(Row.of(2L: JLong, "k1"), true), 0))
+
+ assertEquals(0, testHarness.numKeyedStateEntries())
+
+ testHarness.processWatermark1(new Watermark(2))
+ testHarness.processWatermark2(new Watermark(2))
+
+ testHarness.processElement1(new StreamRecord[CRow](
+ CRow(Row.of(3L: JLong, "k1"), true), 0))
+ testHarness.processElement2(new StreamRecord[CRow](
+ CRow(Row.of(3L: JLong, "k1"), true), 0))
+
+ // Test for -10 boundary (13 - 10 = 3).
+ // This row from the right stream will be cached.
+ // The clean-up timer for the left stream is registered at 13 - 7 + 1 + 0 + 1 = 8.
+ testHarness.processElement2(new StreamRecord[CRow](
+ CRow(Row.of(13L: JLong, "k1"), true), 0))
+
+ // Test for -7 boundary (13 - 7 = 6).
+ testHarness.processElement1(new StreamRecord[CRow](
+ CRow(Row.of(6L: JLong, "k1"), true), 0))
+
+ assertEquals(4, testHarness.numKeyedStateEntries())
+
+ // Trigger the left timer with timestamp 8.
+ // The row with timestamp = 13 will be removed here (13 <= 10 + 7 - 1 = 16).
+ testHarness.processWatermark1(new Watermark(10))
+ testHarness.processWatermark2(new Watermark(10))
+
+ assertEquals(2, testHarness.numKeyedStateEntries())
+
+ // Clear the states.
+ testHarness.processWatermark1(new Watermark(18))
+ testHarness.processWatermark2(new Watermark(18))
+
+ assertEquals(0, testHarness.numKeyedStateEntries())
+
+ val expectedOutput = new ConcurrentLinkedQueue[Object]()
+ expectedOutput.add(new StreamRecord(
+ CRow(Row.of(3L: JLong, "k1", 13L: JLong, "k1"), true), 0))
+ expectedOutput.add(new StreamRecord(
+ CRow(Row.of(6L: JLong, "k1", 13L: JLong, "k1"), true), 0))
+
+ val result = testHarness.getOutput
+ verify(expectedOutput, result, new RowResultSortComparator())
+ testHarness.close()
+ }
}
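
(As a cross-check of the comments in testRowTimeJoinWithNegativeBounds, a small sketch, with illustrative constants, of the expiration computation from calExpirationTime in TimeBoundedStreamInnerJoin.)

    object RowTimeCleanupSketch extends App {
      val rightRelativeSize = -7L // = leftUpperBound for bounds [-10, -7]
      val allowedLateness = 0L

      def calExpirationTime(operatorTime: Long, relativeSize: Long): Long =
        if (operatorTime < Long.MaxValue) operatorTime - relativeSize - allowedLateness - 1
        else Long.MaxValue

      // At watermark 10 the right cache expires rows with
      // timestamp <= 10 - (-7) - 0 - 1 = 16, so the row cached at 13 is removed.
      println(calExpirationTime(10L, rightRelativeSize)) // 16
    }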
http://git-wip-us.apache.org/repos/asf/flink/blob/655d8b16/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala
index e40da7a..13bfbcd 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala
@@ -19,18 +19,22 @@
package org.apache.flink.table.runtime.stream.sql
import org.apache.flink.api.scala._
+import org.apache.flink.streaming.api.TimeCharacteristic
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
+import org.apache.flink.streaming.api.watermark.Watermark
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.scala._
import org.apache.flink.table.runtime.utils.{StreamITCase, StreamingWithStateTestBase}
import org.apache.flink.types.Row
+import org.hamcrest.CoreMatchers
import org.junit._
import scala.collection.mutable
class JoinITCase extends StreamingWithStateTestBase {
- /** test process time inner join **/
+ /** test proctime inner join **/
@Test
def testProcessTimeInnerJoin(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
@@ -39,8 +43,14 @@ class JoinITCase extends StreamingWithStateTestBase {
StreamITCase.clear
env.setParallelism(1)
- val sqlQuery = "SELECT t2.a, t2.c, t1.c from T1 as t1 join T2 as t2 on t1.a = t2.a and " +
- "t1.proctime between t2.proctime - interval '5' second and t2.proctime + interval '5' second"
+ val sqlQuery =
+ """
+ |SELECT t2.a, t2.c, t1.c
+ |FROM T1 as t1 JOIN T2 as t2 ON
+ | t1.a = t2.a AND
+ | t1.proctime BETWEEN t2.proctime - INTERVAL '5' SECOND AND
+ | t2.proctime + INTERVAL '5' SECOND
+ |""".stripMargin
val data1 = new mutable.MutableList[(Int, Long, String)]
data1.+=((1, 1L, "Hi1"))
@@ -65,19 +75,24 @@ class JoinITCase extends StreamingWithStateTestBase {
env.execute()
}
- /** test process time inner join with other condition **/
+ /** test proctime inner join with other condition **/
@Test
- def testProcessTimeInnerJoinWithOtherCondition(): Unit = {
+ def testProcessTimeInnerJoinWithOtherConditions(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)
env.setStateBackend(getStateBackend)
StreamITCase.clear
- env.setParallelism(1)
+ env.setParallelism(2)
- val sqlQuery = "SELECT t2.a, t2.c, t1.c from T1 as t1 join T2 as t2 on t1.a = t2.a and " +
- "t1.proctime between t2.proctime - interval '5' second " +
- "and t2.proctime + interval '5' second " +
- "and t1.b > t2.b and t1.b + t2.b < 14"
+ val sqlQuery =
+ """
+ |SELECT t2.a, t2.c, t1.c
+ |FROM T1 as t1 JOIN T2 as t2 ON
+ | t1.a = t2.a AND
+ | t1.proctime BETWEEN t2.proctime - INTERVAL '5' SECOND AND
+ | t2.proctime + INTERVAL '5' SECOND AND
+ | t1.b = t2.b
+ |""".stripMargin
val data1 = new mutable.MutableList[(String, Long, String)]
data1.+=(("1", 1L, "Hi1"))
@@ -91,6 +106,10 @@ class JoinITCase extends StreamingWithStateTestBase {
data2.+=(("1", 5L, "HiHi"))
data2.+=(("2", 2L, "HeHe"))
+ // For null key test
+ data1.+=((null.asInstanceOf[String], 20L, "leftNull"))
+ data2.+=((null.asInstanceOf[String], 20L, "rightNull"))
+
val t1 = env.fromCollection(data1).toTable(tEnv, 'a, 'b, 'c, 'proctime.proctime)
val t2 = env.fromCollection(data2).toTable(tEnv, 'a, 'b, 'c, 'proctime.proctime)
@@ -100,7 +119,173 @@ class JoinITCase extends StreamingWithStateTestBase {
val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
result.addSink(new StreamITCase.StringSink[Row])
env.execute()
+
+ // Assert that rows with null keys produce no join results.
+ Assert.assertFalse(StreamITCase.testResults.toString().contains("null"))
+ }
+
+ /** test rowtime inner join **/
+ @Test
+ def testRowTimeInnerJoin(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ env.setStateBackend(getStateBackend)
+ env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
+ StreamITCase.clear
+ env.setParallelism(1)
+
+ val sqlQuery =
+ """
+ |SELECT t2.a, t2.c, t1.c
+ |FROM T1 as t1 JOIN T2 as t2 ON
+ | t1.a = t2.a AND
+ | t1.rt BETWEEN t2.rt - INTERVAL '5' SECOND AND
+ | t2.rt + INTERVAL '6' SECOND
+ |""".stripMargin
+
+ val data1 = new mutable.MutableList[(Int, Long, String, Long)]
+ // for boundary test
+ data1.+=((1, 999L, "LEFT0.999", 999L))
+ data1.+=((1, 1000L, "LEFT1", 1000L))
+ data1.+=((1, 2000L, "LEFT2", 2000L))
+ data1.+=((1, 3000L, "LEFT3", 3000L))
+ data1.+=((2, 4000L, "LEFT4", 4000L))
+ data1.+=((1, 5000L, "LEFT5", 5000L))
+ data1.+=((1, 6000L, "LEFT6", 6000L))
+
+ val data2 = new mutable.MutableList[(Int, Long, String, Long)]
+ data2.+=((1, 6000L, "RIGHT6", 6000L))
+ data2.+=((2, 7000L, "RIGHT7", 7000L))
+
+ val t1 = env.fromCollection(data1)
+ .assignTimestampsAndWatermarks(new Tuple4WatermarkExtractor)
+ .toTable(tEnv, 'a, 'b, 'c, 'rt.rowtime)
+ val t2 = env.fromCollection(data2)
+ .assignTimestampsAndWatermarks(new Tuple4WatermarkExtractor)
+ .toTable(tEnv, 'a, 'b, 'c, 'rt.rowtime)
+
+ tEnv.registerTable("T1", t1)
+ tEnv.registerTable("T2", t2)
+
+ val result = tEnv.sql(sqlQuery).toAppendStream[Row]
+ result.addSink(new StreamITCase.StringSink[Row])
+ env.execute()
+ val expected = new java.util.ArrayList[String]
+ expected.add("1,RIGHT6,LEFT1")
+ expected.add("1,RIGHT6,LEFT2")
+ expected.add("1,RIGHT6,LEFT3")
+ expected.add("1,RIGHT6,LEFT5")
+ expected.add("1,RIGHT6,LEFT6")
+ expected.add("2,RIGHT7,LEFT4")
+ StreamITCase.compareWithList(expected)
}
+ /** test rowtime inner join with other conditions **/
+ @Test
+ def testRowTimeInnerJoinWithOtherConditions(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ env.setStateBackend(getStateBackend)
+ env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
+ StreamITCase.clear
+
+ // different parallelisms lead to different join results
+ env.setParallelism(1)
+
+ val sqlQuery =
+ """
+ |SELECT t2.a, t1.c, t2.c
+ |FROM T1 as t1 JOIN T2 as t2 ON
+ | t1.a = t2.a AND
+ | t1.rt > t2.rt - INTERVAL '5' SECOND AND
+ | t1.rt < t2.rt - INTERVAL '1' SECOND AND
+ | t1.b < t2.b AND
+ | t1.b > 2
+ |""".stripMargin
+
+ val data1 = new mutable.MutableList[(Int, Long, String, Long)]
+ data1.+=((1, 4L, "LEFT1", 1000L))
+ // for boundary test
+ data1.+=((1, 8L, "LEFT1.1", 1001L))
+ // filtered by the pushed-down predicate (t1.b > 2)
+ data1.+=((1, 2L, "LEFT2", 2000L))
+ data1.+=((1, 7L, "LEFT3", 3000L))
+ data1.+=((2, 5L, "LEFT4", 4000L))
+ // for boundary test
+ data1.+=((1, 4L, "LEFT4.9", 4999L))
+ data1.+=((1, 4L, "LEFT5", 5000L))
+ data1.+=((1, 10L, "LEFT6", 6000L))
+ // a left late row
+ data1.+=((1, 3L, "LEFT3.5", 3500L))
+
+ val data2 = new mutable.MutableList[(Int, Long, String, Long)]
+ // just to advance the watermark
+ data2.+=((1, 1L, "RIGHT1", 1000L))
+ data2.+=((1, 9L, "RIGHT6", 6000L))
+ data2.+=((2, 14L, "RIGHT7", 7000L))
+ data2.+=((1, 4L, "RIGHT8", 8000L))
+ // a right late row
+ data2.+=((1, 10L, "RIGHT5", 5000L))
+
+ val t1 = env.fromCollection(data1)
+ .assignTimestampsAndWatermarks(new Tuple4WatermarkExtractor)
+ .toTable(tEnv, 'a, 'b, 'c, 'rt.rowtime)
+ val t2 = env.fromCollection(data2)
+ .assignTimestampsAndWatermarks(new Tuple4WatermarkExtractor)
+ .toTable(tEnv, 'a, 'b, 'c, 'rt.rowtime)
+
+ tEnv.registerTable("T1", t1)
+ tEnv.registerTable("T2", t2)
+
+ val result = tEnv.sql(sqlQuery).toAppendStream[Row]
+ result.addSink(new StreamITCase.StringSink[Row])
+ env.execute()
+
+ // Two results are possible, depending on the processing order.
+ val expected1 = new mutable.MutableList[String]
+ expected1 += "1,LEFT3,RIGHT6"
+ expected1 += "1,LEFT1.1,RIGHT6"
+ expected1 += "2,LEFT4,RIGHT7"
+ expected1 += "1,LEFT4.9,RIGHT6"
+ // produced by the late left row
+ expected1 += "1,LEFT3.5,RIGHT6"
+ expected1 += "1,LEFT3.5,RIGHT8"
+ // produced by the late right row
+ expected1 += "1,LEFT3,RIGHT5"
+ expected1 += "1,LEFT3.5,RIGHT5"
+
+ val expected2 = new mutable.MutableList[String]
+ expected2 += "1,LEFT3,RIGHT6"
+ expected2 += "1,LEFT1.1,RIGHT6"
+ expected2 += "2,LEFT4,RIGHT7"
+ expected2 += "1,LEFT4.9,RIGHT6"
+ // produced by the late left row
+ expected2 += "1,LEFT3.5,RIGHT6"
+ expected2 += "1,LEFT3.5,RIGHT8"
+ // produced by the late right row
+ expected2 += "1,LEFT3,RIGHT5"
+ expected2 += "1,LEFT1,RIGHT5"
+ expected2 += "1,LEFT1.1,RIGHT5"
+
+ Assert.assertThat(
+ StreamITCase.testResults.sorted,
+ CoreMatchers.either(CoreMatchers.is(expected1.sorted)).
+ or(CoreMatchers.is(expected2.sorted)))
+ }
}
+private class Tuple4WatermarkExtractor
+ extends AssignerWithPunctuatedWatermarks[(Int, Long, String, Long)] {
+
+ override def checkAndGetNextWatermark(
+ lastElement: (Int, Long, String, Long),
+ extractedTimestamp: Long): Watermark = {
+ new Watermark(extractedTimestamp - 1)
+ }
+
+ override def extractTimestamp(
+ element: (Int, Long, String, Long),
+ previousElementTimestamp: Long): Long = {
+ element._4
+ }
+}
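The BETWEEN predicate in testRowTimeInnerJoin reduces to a per-pair window check: a left row at rt1 joins a right row at rt2 iff rt2 - 5s <= rt1 <= rt2 + 6s, both boundaries inclusive. A minimal sketch of that check (the helper is illustrative, not part of the join operator):

    // The window check behind "t1.rt BETWEEN t2.rt - INTERVAL '5' SECOND
    // AND t2.rt + INTERVAL '6' SECOND", with timestamps in milliseconds.
    def joins(leftTs: Long, rightTs: Long): Boolean =
      leftTs >= rightTs - 5000L && leftTs <= rightTs + 6000L

    joins(999L, 6000L)  // false: LEFT0.999 misses the lower bound by 1 ms
    joins(1000L, 6000L) // true: the boundary itself is included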
[4/5] flink git commit: [FLINK-7491] [table] Add MultiSet type and
COLLECT aggregation function to SQL.
Posted by fh...@apache.org.
[FLINK-7491] [table] Add MultiSet type and COLLECT aggregation function to SQL.
This closes #4585.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/dccdba19
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/dccdba19
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/dccdba19
Branch: refs/heads/master
Commit: dccdba199a8fbb8b5186f0952410c1b1b3dff14f
Parents: 4047be4
Author: Shuyi Chen <sh...@uber.com>
Authored: Wed Aug 23 17:54:10 2017 -0700
Committer: Fabian Hueske <fh...@apache.org>
Committed: Tue Oct 10 23:09:07 2017 +0200
----------------------------------------------------------------------
docs/dev/table/sql.md | 13 +-
.../flink/api/java/typeutils/MapTypeInfo.java | 2 +-
.../api/java/typeutils/MultisetTypeInfo.java | 91 ++++++++
.../java/typeutils/MultisetTypeInfoTest.java | 38 ++++
.../org/apache/flink/table/api/Types.scala | 11 +-
.../flink/table/calcite/FlinkTypeFactory.scala | 24 +-
.../flink/table/codegen/ExpressionReducer.scala | 9 +-
.../aggfunctions/CollectAggFunction.scala | 122 ++++++++++
.../flink/table/plan/nodes/FlinkRelNode.scala | 2 +-
.../table/plan/schema/MultisetRelDataType.scala | 50 ++++
.../table/runtime/aggregate/AggregateUtil.scala | 10 +-
.../flink/table/validate/FunctionCatalog.scala | 1 +
.../aggfunctions/CollectAggFunctionTest.scala | 226 +++++++++++++++++++
.../runtime/batch/sql/AggregateITCase.scala | 29 +++
.../table/runtime/stream/sql/SqlITCase.scala | 59 +++++
15 files changed, 677 insertions(+), 10 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/dccdba19/docs/dev/table/sql.md
----------------------------------------------------------------------
diff --git a/docs/dev/table/sql.md b/docs/dev/table/sql.md
index 533aa6e..81dabee 100644
--- a/docs/dev/table/sql.md
+++ b/docs/dev/table/sql.md
@@ -803,6 +803,7 @@ The SQL runtime is built on top of Flink's DataSet and DataStream APIs. Internal
| `Types.PRIMITIVE_ARRAY`| `ARRAY` | e.g. `int[]` |
| `Types.OBJECT_ARRAY` | `ARRAY` | e.g. `java.lang.Byte[]`|
| `Types.MAP` | `MAP` | `java.util.HashMap` |
+| `Types.MULTISET` | `MULTISET` | e.g. `java.util.HashMap<String, Integer>` for a multiset of `String` |
Advanced types such as generic types, composite types (e.g. POJOs or Tuples), and array types (object or primitive arrays) can be fields of a row.
@@ -2164,6 +2165,17 @@ VAR_SAMP(value)
<p>Returns the sample variance (square of the sample standard deviation) of the numeric field across all input values.</p>
</td>
</tr>
+
+ <tr>
+ <td>
+ {% highlight text %}
+ COLLECT(value)
+ {% endhighlight %}
+ </td>
+ <td>
+ <p>Returns a multiset of the <i>value</i>s. Null input <i>value</i>s are ignored. Returns an empty multiset if only null values were added.</p>
+ </td>
+ </tr>
</tbody>
</table>
@@ -2283,7 +2295,6 @@ The following functions are not supported yet:
- Binary string operators and functions
- System functions
-- Collection functions
- Distinct aggregate functions like COUNT DISTINCT
{% top %}
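A hedged sketch of the semantics documented in the new COLLECT row, using an illustrative table and data that are not part of the commit:

    // COLLECT(c) returns a MULTISET, surfaced as a map from element to
    // multiplicity; null values of c are ignored.
    val sqlQuery = "SELECT b, COLLECT(c) FROM MyTable GROUP BY b"
    // e.g. for b = 1 with c values "x", "x", null, "y" the result row is:
    //   1,{x=2, y=1}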
http://git-wip-us.apache.org/repos/asf/flink/blob/dccdba19/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MapTypeInfo.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MapTypeInfo.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MapTypeInfo.java
index ca04e0c..e9cd09d 100644
--- a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MapTypeInfo.java
+++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MapTypeInfo.java
@@ -93,7 +93,7 @@ public class MapTypeInfo<K, V> extends TypeInformation<Map<K, V>> {
@Override
public int getTotalFields() {
- return 2;
+ return 1;
}
@SuppressWarnings("unchecked")
http://git-wip-us.apache.org/repos/asf/flink/blob/dccdba19/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MultisetTypeInfo.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MultisetTypeInfo.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MultisetTypeInfo.java
new file mode 100644
index 0000000..27fe709
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MultisetTypeInfo.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.java.typeutils;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * A {@link TypeInformation} for the Multiset types of the Java API.
+ *
+ * @param <T> The type of the elements in the Multiset.
+ */
+@PublicEvolving
+public final class MultisetTypeInfo<T> extends MapTypeInfo<T, Integer> {
+
+ private static final long serialVersionUID = 1L;
+
+ public MultisetTypeInfo(Class<T> elementTypeClass) {
+ super(elementTypeClass, Integer.class);
+ }
+
+ public MultisetTypeInfo(TypeInformation<T> elementTypeInfo) {
+ super(elementTypeInfo, BasicTypeInfo.INT_TYPE_INFO);
+ }
+
+ // ------------------------------------------------------------------------
+ // MultisetTypeInfo specific properties
+ // ------------------------------------------------------------------------
+
+ /**
+ * Gets the type information for the elements contained in the Multiset.
+ */
+ public TypeInformation<T> getElementTypeInfo() {
+ return getKeyTypeInfo();
+ }
+
+ @Override
+ public String toString() {
+ return "Multiset<" + getKeyTypeInfo() + '>';
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == this) {
+ return true;
+ }
+ else if (obj instanceof MultisetTypeInfo) {
+ final MultisetTypeInfo<?> other = (MultisetTypeInfo<?>) obj;
+ return other.canEqual(this) && getKeyTypeInfo().equals(other.getKeyTypeInfo());
+ } else {
+ return false;
+ }
+ }
+
+ @Override
+ public int hashCode() {
+ return 31 * getKeyTypeInfo().hashCode() + 1;
+ }
+
+ @Override
+ public boolean canEqual(Object obj) {
+ return obj != null && obj.getClass() == getClass();
+ }
+
+ @SuppressWarnings("unchecked")
+ @PublicEvolving
+ public static <C> MultisetTypeInfo<C> getInfoFor(TypeInformation<C> componentInfo) {
+ checkNotNull(componentInfo);
+
+ return new MultisetTypeInfo<>(componentInfo);
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/dccdba19/flink-core/src/test/java/org/apache/flink/api/java/typeutils/MultisetTypeInfoTest.java
----------------------------------------------------------------------
diff --git a/flink-core/src/test/java/org/apache/flink/api/java/typeutils/MultisetTypeInfoTest.java b/flink-core/src/test/java/org/apache/flink/api/java/typeutils/MultisetTypeInfoTest.java
new file mode 100644
index 0000000..395f4ce
--- /dev/null
+++ b/flink-core/src/test/java/org/apache/flink/api/java/typeutils/MultisetTypeInfoTest.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.java.typeutils;
+
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeutils.TypeInformationTestBase;
+
+/**
+ * Test for {@link MultisetTypeInfo}.
+ */
+public class MultisetTypeInfoTest extends TypeInformationTestBase<MultisetTypeInfo<?>> {
+
+ @Override
+ protected MultisetTypeInfo<?>[] getTestData() {
+ return new MultisetTypeInfo<?>[] {
+ new MultisetTypeInfo<>(BasicTypeInfo.STRING_TYPE_INFO),
+ new MultisetTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+ new MultisetTypeInfo<>(Long.class)
+ };
+ }
+}
+
http://git-wip-us.apache.org/repos/asf/flink/blob/dccdba19/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala
index 2152b72..100c22b 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala
@@ -18,7 +18,7 @@
package org.apache.flink.table.api
import org.apache.flink.api.common.typeinfo.{PrimitiveArrayTypeInfo, TypeInformation, Types => JTypes}
-import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo}
+import org.apache.flink.api.java.typeutils.{MapTypeInfo, MultisetTypeInfo, ObjectArrayTypeInfo}
import org.apache.flink.table.typeutils.TimeIntervalTypeInfo
import org.apache.flink.types.Row
@@ -110,4 +110,13 @@ object Types {
def MAP(keyType: TypeInformation[_], valueType: TypeInformation[_]): TypeInformation[_] = {
new MapTypeInfo(keyType, valueType)
}
+
+ /**
+ * Generates type information for a Multiset.
+ *
+ * @param elementType type of the elements of the multiset, e.g. Types.STRING
+ */
+ def MULTISET(elementType: TypeInformation[_]): TypeInformation[_] = {
+ new MultisetTypeInfo(elementType)
+ }
}
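Types.MULTISET is a thin factory over the MultisetTypeInfo added above, which models a multiset as a map from element to Integer multiplicity. A minimal usage sketch:

    import org.apache.flink.table.api.Types

    // A multiset of strings, e.g. the result type of COLLECT on a VARCHAR
    // column; MultisetTypeInfo extends MapTypeInfo[String, Integer].
    val multisetOfString = Types.MULTISET(Types.STRING)
    println(multisetOfString) // Multiset<String>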
http://git-wip-us.apache.org/repos/asf/flink/blob/dccdba19/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala
index 1cc9f6b..768d700 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala
@@ -31,7 +31,7 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo._
import org.apache.flink.api.common.typeutils.CompositeType
import org.apache.flink.api.java.typeutils.ValueTypeInfo._
-import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo, RowTypeInfo}
+import org.apache.flink.api.java.typeutils.{MapTypeInfo, MultisetTypeInfo, ObjectArrayTypeInfo, RowTypeInfo}
import org.apache.flink.table.api.TableException
import org.apache.flink.table.calcite.FlinkTypeFactory.typeInfoToSqlTypeName
import org.apache.flink.table.plan.schema._
@@ -156,6 +156,13 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp
createTypeFromTypeInfo(mp.getValueTypeInfo, isNullable = true),
isNullable)
+ case mts: MultisetTypeInfo[_] =>
+ new MultisetRelDataType(
+ mts,
+ createTypeFromTypeInfo(mts.getElementTypeInfo, isNullable = true),
+ isNullable
+ )
+
case ti: TypeInformation[_] =>
new GenericRelDataType(
ti,
@@ -213,6 +220,14 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp
canonize(relType)
}
+ override def createMultisetType(elementType: RelDataType, maxCardinality: Long): RelDataType = {
+ val relType = new MultisetRelDataType(
+ MultisetTypeInfo.getInfoFor(FlinkTypeFactory.toTypeInfo(elementType)),
+ elementType,
+ isNullable = false)
+ canonize(relType)
+ }
+
override def createTypeWithNullability(
relDataType: RelDataType,
isNullable: Boolean): RelDataType = {
@@ -234,6 +249,9 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp
case map: MapRelDataType =>
new MapRelDataType(map.typeInfo, map.keyType, map.valueType, isNullable)
+ case multiSet: MultisetRelDataType =>
+ new MultisetRelDataType(multiSet.typeInfo, multiSet.getComponentType, isNullable)
+
case generic: GenericRelDataType =>
new GenericRelDataType(generic.typeInfo, isNullable, typeSystem)
@@ -403,6 +421,10 @@ object FlinkTypeFactory {
val mapRelDataType = relDataType.asInstanceOf[MapRelDataType]
mapRelDataType.typeInfo
+ case MULTISET if relDataType.isInstanceOf[MultisetRelDataType] =>
+ val multisetRelDataType = relDataType.asInstanceOf[MultisetRelDataType]
+ multisetRelDataType.typeInfo
+
case _@t =>
throw TableException(s"Type is not supported: $t")
}
http://git-wip-us.apache.org/repos/asf/flink/blob/dccdba19/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/ExpressionReducer.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/ExpressionReducer.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/ExpressionReducer.scala
index 3e71c99..9696ced 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/ExpressionReducer.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/ExpressionReducer.scala
@@ -74,7 +74,8 @@ class ExpressionReducer(config: TableConfig)
case (SqlTypeName.ANY, _) |
(SqlTypeName.ROW, _) |
(SqlTypeName.ARRAY, _) |
- (SqlTypeName.MAP, _) => None
+ (SqlTypeName.MAP, _) |
+ (SqlTypeName.MULTISET, _) => None
case (_, e) => Some(e)
}
@@ -112,7 +113,11 @@ class ExpressionReducer(config: TableConfig)
val unreduced = constExprs.get(i)
unreduced.getType.getSqlTypeName match {
// we insert the original expression for object literals
- case SqlTypeName.ANY | SqlTypeName.ROW | SqlTypeName.ARRAY | SqlTypeName.MAP =>
+ case SqlTypeName.ANY |
+ SqlTypeName.ROW |
+ SqlTypeName.ARRAY |
+ SqlTypeName.MAP |
+ SqlTypeName.MULTISET =>
reducedValues.add(unreduced)
case _ =>
val reducedValue = reduced.getField(reducedIdx)
http://git-wip-us.apache.org/repos/asf/flink/blob/dccdba19/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala
new file mode 100644
index 0000000..b10be61
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions.aggfunctions
+
+import java.lang.{Iterable => JIterable}
+import java.util
+
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils._
+import org.apache.flink.table.api.dataview.MapView
+import org.apache.flink.table.dataview.MapViewTypeInfo
+import org.apache.flink.table.functions.AggregateFunction
+
+import scala.collection.JavaConverters._
+
+/** The initial accumulator for the COLLECT aggregate function. */
+class CollectAccumulator[E](var map: MapView[E, Integer]) {
+ def this() {
+ this(null)
+ }
+
+ def canEqual(a: Any): Boolean = a.isInstanceOf[CollectAccumulator[E]]
+
+ override def equals(that: Any): Boolean =
+ that match {
+ case that: CollectAccumulator[E] => that.canEqual(this) && this.map == that.map
+ case _ => false
+ }
+}
+
+class CollectAggFunction[E](valueTypeInfo: TypeInformation[_])
+ extends AggregateFunction[util.Map[E, Integer], CollectAccumulator[E]] {
+
+ override def createAccumulator(): CollectAccumulator[E] = {
+ new CollectAccumulator[E](
+ new MapView[E, Integer](
+ valueTypeInfo.asInstanceOf[TypeInformation[E]],
+ BasicTypeInfo.INT_TYPE_INFO))
+ }
+
+ def accumulate(accumulator: CollectAccumulator[E], value: E): Unit = {
+ if (value != null) {
+ val currVal = accumulator.map.get(value)
+ if (currVal != null) {
+ accumulator.map.put(value, currVal + 1)
+ } else {
+ accumulator.map.put(value, 1)
+ }
+ }
+ }
+
+ override def getValue(accumulator: CollectAccumulator[E]): util.Map[E, Integer] = {
+ val iterator = accumulator.map.iterator
+ if (iterator.hasNext) {
+ val map = new util.HashMap[E, Integer]()
+ while (iterator.hasNext) {
+ val entry = iterator.next()
+ map.put(entry.getKey, entry.getValue)
+ }
+ map
+ } else {
+ Map[E, Integer]().asJava
+ }
+ }
+
+ def resetAccumulator(acc: CollectAccumulator[E]): Unit = {
+ acc.map.clear()
+ }
+
+ override def getAccumulatorType: TypeInformation[CollectAccumulator[E]] = {
+ val clazz = classOf[CollectAccumulator[E]]
+ val pojoFields = new util.ArrayList[PojoField]
+ pojoFields.add(new PojoField(clazz.getDeclaredField("map"),
+ new MapViewTypeInfo[E, Integer](
+ valueTypeInfo.asInstanceOf[TypeInformation[E]], BasicTypeInfo.INT_TYPE_INFO)))
+ new PojoTypeInfo[CollectAccumulator[E]](clazz, pojoFields)
+ }
+
+ def merge(acc: CollectAccumulator[E], its: JIterable[CollectAccumulator[E]]): Unit = {
+ val iter = its.iterator()
+ while (iter.hasNext) {
+ val mapViewIterator = iter.next().map.iterator
+ while (mapViewIterator.hasNext) {
+ val entry = mapViewIterator.next()
+ val k = entry.getKey
+ val oldValue = acc.map.get(k)
+ if (oldValue == null) {
+ acc.map.put(k, entry.getValue)
+ } else {
+ acc.map.put(k, entry.getValue + oldValue)
+ }
+ }
+ }
+ }
+
+ def retract(acc: CollectAccumulator[E], value: E): Unit = {
+ if (value != null) {
+ // A null count means the value was never accumulated; ignore the
+ // retraction instead of unboxing a null Integer below.
+ val count = acc.map.get(value)
+ if (count != null) {
+ if (count == 1) {
+ acc.map.remove(value)
+ } else {
+ acc.map.put(value, count - 1)
+ }
+ }
+ }
+ }
+}
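Because MapView falls back to a heap-backed map when it is not bound to a state backend, CollectAggFunction can be exercised directly, outside any job. A minimal sketch under that assumption (this is not one of the commit's tests):

    import org.apache.flink.api.common.typeinfo.BasicTypeInfo
    import org.apache.flink.table.functions.aggfunctions.CollectAggFunction

    val collect = new CollectAggFunction[String](BasicTypeInfo.STRING_TYPE_INFO)
    val acc = collect.createAccumulator()

    Seq("a", "a", null, "b").foreach(v => collect.accumulate(acc, v))
    collect.getValue(acc) // {a=2, b=1}: nulls ignored, duplicates counted

    collect.retract(acc, "a")
    collect.getValue(acc) // {a=1, b=1}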
http://git-wip-us.apache.org/repos/asf/flink/blob/dccdba19/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkRelNode.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkRelNode.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkRelNode.scala
index 8509a8e..f3e1a62 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkRelNode.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkRelNode.scala
@@ -94,7 +94,7 @@ trait FlinkRelNode extends RelNode {
case SqlTypeName.ARRAY =>
// 16 is an arbitrary estimate
estimateDataTypeSize(t.getComponentType) * 16
- case SqlTypeName.MAP =>
+ case SqlTypeName.MAP | SqlTypeName.MULTISET =>
// 16 is an arbitrary estimate
(estimateDataTypeSize(t.getKeyType) + estimateDataTypeSize(t.getValueType)) * 16
case SqlTypeName.ANY => 128 // 128 is an arbitrary estimate
http://git-wip-us.apache.org/repos/asf/flink/blob/dccdba19/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/MultisetRelDataType.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/MultisetRelDataType.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/MultisetRelDataType.scala
new file mode 100644
index 0000000..859fc41
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/MultisetRelDataType.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.schema
+
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.sql.`type`.MultisetSqlType
+import org.apache.flink.api.common.typeinfo.TypeInformation
+
+class MultisetRelDataType(
+ val typeInfo: TypeInformation[_],
+ elementType: RelDataType,
+ isNullable: Boolean)
+ extends MultisetSqlType(
+ elementType,
+ isNullable) {
+
+ override def toString = s"MULTISET($elementType)"
+
+ def canEqual(other: Any): Boolean = other.isInstanceOf[MultisetRelDataType]
+
+ override def equals(other: Any): Boolean = other match {
+ case that: MultisetRelDataType =>
+ super.equals(that) &&
+ (that canEqual this) &&
+ typeInfo == that.typeInfo &&
+ isNullable == that.isNullable
+ case _ => false
+ }
+
+ override def hashCode(): Int = {
+ typeInfo.hashCode()
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/dccdba19/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index 58940d0..c84b254 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -28,7 +28,7 @@ import org.apache.calcite.sql.{SqlAggFunction, SqlKind}
import org.apache.flink.api.common.functions.{MapFunction, RichGroupReduceFunction, AggregateFunction => DataStreamAggFunction, _}
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.Tuple
-import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.java.typeutils.{GenericTypeInfo, RowTypeInfo}
import org.apache.flink.streaming.api.functions.ProcessFunction
import org.apache.flink.streaming.api.functions.windowing.{AllWindowFunction, WindowFunction}
import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow}
@@ -1200,8 +1200,8 @@ object AggregateUtil {
} else {
aggFieldIndexes(index) = argList.asScala.map(i => i.intValue).toArray
}
- val sqlTypeName = inputType.getFieldList.get(aggFieldIndexes(index)(0)).getType
- .getSqlTypeName
+ val relDataType = inputType.getFieldList.get(aggFieldIndexes(index)(0)).getType
+ val sqlTypeName = relDataType.getSqlTypeName
aggregateCall.getAggregation match {
case _: SqlSumAggFunction =>
@@ -1410,6 +1410,10 @@ object AggregateUtil {
case _: SqlCountAggFunction =>
aggregates(index) = new CountAggFunction
+ case collect: SqlAggFunction if collect.getKind == SqlKind.COLLECT =>
+ aggregates(index) = new CollectAggFunction(FlinkTypeFactory.toTypeInfo(relDataType))
+ accTypes(index) = aggregates(index).getAccumulatorType
+
case udagg: AggSqlFunction =>
aggregates(index) = udagg.getFunction
accTypes(index) = udagg.accType
http://git-wip-us.apache.org/repos/asf/flink/blob/dccdba19/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
index 5254ceb..3398a93 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
@@ -319,6 +319,7 @@ class BasicOperatorTable extends ReflectiveSqlOperatorTable {
SqlStdOperatorTable.SUM,
SqlStdOperatorTable.SUM0,
SqlStdOperatorTable.COUNT,
+ SqlStdOperatorTable.COLLECT,
SqlStdOperatorTable.MIN,
SqlStdOperatorTable.MAX,
SqlStdOperatorTable.AVG,
http://git-wip-us.apache.org/repos/asf/flink/blob/dccdba19/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala
new file mode 100644
index 0000000..f85cb70
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.aggfunctions
+
+import java.util
+
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo
+import org.apache.flink.api.java.typeutils.GenericTypeInfo
+import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.aggfunctions._
+
+import scala.collection.JavaConverters._
+
+/**
+ * Test case for built-in collect aggregate functions
+ */
+class StringCollectAggFunctionTest
+ extends AggFunctionTestBase[util.Map[String, Integer], CollectAccumulator[String]] {
+
+ override def inputValueSets: Seq[Seq[_]] = Seq(
+ Seq("a", "a", "b", null, "c", null, "d", "e", null, "f"),
+ Seq(null, null, null, null, null, null)
+ )
+
+ override def expectedResults: Seq[util.Map[String, Integer]] = {
+ val map = new util.HashMap[String, Integer]()
+ map.put("a", 2)
+ map.put("b", 1)
+ map.put("c", 1)
+ map.put("d", 1)
+ map.put("e", 1)
+ map.put("f", 1)
+ Seq(map, Map[String, Integer]().asJava)
+ }
+
+ override def aggregator: AggregateFunction[
+ util.Map[String, Integer], CollectAccumulator[String]] =
+ new CollectAggFunction(BasicTypeInfo.STRING_TYPE_INFO)
+
+ override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
+class IntCollectAggFunctionTest
+ extends AggFunctionTestBase[util.Map[Int, Integer], CollectAccumulator[Int]] {
+
+ override def inputValueSets: Seq[Seq[_]] = Seq(
+ Seq(1, 1, 2, null, 3, null, 4, 5, null),
+ Seq(null, null, null, null, null, null)
+ )
+
+ override def expectedResults: Seq[util.Map[Int, Integer]] = {
+ val map = new util.HashMap[Int, Integer]()
+ map.put(1, 2)
+ map.put(2, 1)
+ map.put(3, 1)
+ map.put(4, 1)
+ map.put(5, 1)
+ Seq(map, Map[Int, Integer]().asJava)
+ }
+
+ override def aggregator: AggregateFunction[util.Map[Int, Integer], CollectAccumulator[Int]] =
+ new CollectAggFunction(BasicTypeInfo.INT_TYPE_INFO)
+
+ override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
+class ByteCollectAggFunctionTest
+ extends AggFunctionTestBase[util.Map[Byte, Integer], CollectAccumulator[Byte]] {
+
+ override def inputValueSets: Seq[Seq[_]] = Seq(
+ Seq(1.toByte, 1.toByte, 2.toByte, null, 3.toByte, null, 4.toByte, 5.toByte, null),
+ Seq(null, null, null, null, null, null)
+ )
+
+ override def expectedResults: Seq[util.Map[Byte, Integer]] = {
+ val map = new util.HashMap[Byte, Integer]()
+ map.put(1, 2)
+ map.put(2, 1)
+ map.put(3, 1)
+ map.put(4, 1)
+ map.put(5, 1)
+ Seq(map, Map[Byte, Integer]().asJava)
+ }
+
+ override def aggregator: AggregateFunction[util.Map[Byte, Integer], CollectAccumulator[Byte]] =
+ new CollectAggFunction(BasicTypeInfo.BYTE_TYPE_INFO)
+
+ override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
+class ShortCollectAggFunctionTest
+ extends AggFunctionTestBase[util.Map[Short, Integer], CollectAccumulator[Short]] {
+
+ override def inputValueSets: Seq[Seq[_]] = Seq(
+ Seq(1.toShort, 1.toShort, 2.toShort, null,
+ 3.toShort, null, 4.toShort, 5.toShort, null),
+ Seq(null, null, null, null, null, null)
+ )
+
+ override def expectedResults: Seq[util.Map[Short, Integer]] = {
+ val map = new util.HashMap[Short, Integer]()
+ map.put(1, 2)
+ map.put(2, 1)
+ map.put(3, 1)
+ map.put(4, 1)
+ map.put(5, 1)
+ Seq(map, Map[Short, Integer]().asJava)
+ }
+
+ override def aggregator: AggregateFunction[util.Map[Short, Integer], CollectAccumulator[Short]] =
+ new CollectAggFunction(BasicTypeInfo.SHORT_TYPE_INFO)
+
+ override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
+class LongCollectAggFunctionTest
+ extends AggFunctionTestBase[util.Map[Long, Integer], CollectAccumulator[Long]] {
+
+ override def inputValueSets: Seq[Seq[_]] = Seq(
+ Seq(1L, 1L, 2L, null, 3L, null, 4L, 5L, null),
+ Seq(null, null, null, null, null, null)
+ )
+
+ override def expectedResults: Seq[util.Map[Long, Integer]] = {
+ val map = new util.HashMap[Long, Integer]()
+ map.put(1, 2)
+ map.put(2, 1)
+ map.put(3, 1)
+ map.put(4, 1)
+ map.put(5, 1)
+ Seq(map, Map[Long, Integer]().asJava)
+ }
+
+ override def aggregator: AggregateFunction[util.Map[Long, Integer], CollectAccumulator[Long]] =
+ new CollectAggFunction(BasicTypeInfo.LONG_TYPE_INFO)
+
+ override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
+class FloatCollectAggFunctionTest
+ extends AggFunctionTestBase[util.Map[Float, Integer], CollectAccumulator[Float]] {
+
+ override def inputValueSets: Seq[Seq[_]] = Seq(
+ Seq(1f, 1f, 2f, null, 3.2f, null, 4f, 5f, null),
+ Seq(null, null, null, null, null, null)
+ )
+
+ override def expectedResults: Seq[util.Map[Float, Integer]] = {
+ val map = new util.HashMap[Float, Integer]()
+ map.put(1, 2)
+ map.put(2, 1)
+ map.put(3.2f, 1)
+ map.put(4, 1)
+ map.put(5, 1)
+ Seq(map, Map[Float, Integer]().asJava)
+ }
+
+ override def aggregator: AggregateFunction[util.Map[Float, Integer], CollectAccumulator[Float]] =
+ new CollectAggFunction(BasicTypeInfo.FLOAT_TYPE_INFO)
+
+ override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
+class DoubleCollectAggFunctionTest
+ extends AggFunctionTestBase[util.Map[Double, Integer], CollectAccumulator[Double]] {
+
+ override def inputValueSets: Seq[Seq[_]] = Seq(
+ Seq(1d, 1d, 2d, null, 3.2d, null, 4d, 5d),
+ Seq(null, null, null, null, null, null)
+ )
+
+ override def expectedResults: Seq[util.Map[Double, Integer]] = {
+ val map = new util.HashMap[Double, Integer]()
+ map.put(1, 2)
+ map.put(2, 1)
+ map.put(3.2d, 1)
+ map.put(4, 1)
+ map.put(5, 1)
+ Seq(map, Map[Double, Integer]().asJava)
+ }
+
+ override def aggregator: AggregateFunction[
+ util.Map[Double, Integer], CollectAccumulator[Double]] =
+ new CollectAggFunction(BasicTypeInfo.DOUBLE_TYPE_INFO)
+
+ override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
+class ObjectCollectAggFunctionTest
+ extends AggFunctionTestBase[util.Map[Object, Integer], CollectAccumulator[Object]] {
+
+ override def inputValueSets: Seq[Seq[_]] = Seq(
+ Seq(Tuple2(1, "a"), Tuple2(1, "a"), null, Tuple2(2, "b")),
+ Seq(null, null, null, null, null, null)
+ )
+
+ override def expectedResults: Seq[util.Map[Object, Integer]] = {
+ val map = new util.HashMap[Object, Integer]()
+ map.put(Tuple2(1, "a"), 2)
+ map.put(Tuple2(2, "b"), 1)
+ Seq(map, Map[Object, Integer]().asJava)
+ }
+
+ override def aggregator: AggregateFunction[
+ util.Map[Object, Integer], CollectAccumulator[Object]] =
+ new CollectAggFunction(new GenericTypeInfo[Object](classOf[Object]))
+
+ override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
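The retractFunc overrides above resolve retract reflectively because accumulate and retract are convention methods discovered by code generation rather than members of the AggregateFunction base class. A hedged illustration of that lookup, reusing the aggregator, accType and acc names from the test base:

    // Resolve retract(acc, value) by convention and invoke it, mirroring
    // what AggFunctionTestBase does for its retraction tests.
    val retract = aggregator.getClass.getMethod("retract", accType, classOf[Any])
    retract.invoke(aggregator, acc, "a")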
http://git-wip-us.apache.org/repos/asf/flink/blob/dccdba19/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala
index 465a88c..aa934c6 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala
@@ -329,6 +329,35 @@ class AggregateITCase(
}
@Test
+ def testTumbleWindowAggregateWithCollect(): Unit = {
+
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+ val sqlQuery =
+ "SELECT b, COLLECT(b)" +
+ "FROM T " +
+ "GROUP BY b, TUMBLE(ts, INTERVAL '3' SECOND)"
+
+ val ds = CollectionDataSets.get3TupleDataSet(env)
+ // create timestamps
+ .map(x => (x._1, x._2, x._3, new Timestamp(x._1 * 1000)))
+ tEnv.registerDataSet("T", ds, 'a, 'b, 'c, 'ts)
+
+ val result = tEnv.sql(sqlQuery).toDataSet[Row].collect()
+ val expected = Seq(
+ "1,{1=1}",
+ "2,{2=1}", "2,{2=1}",
+ "3,{3=1}", "3,{3=2}",
+ "4,{4=2}", "4,{4=2}",
+ "5,{5=1}", "5,{5=1}", "5,{5=3}",
+ "6,{6=1}", "6,{6=2}", "6,{6=3}"
+ ).mkString("\n")
+
+ TestBaseUtils.compareResultAsText(result.asJava, expected)
+ }
+
+ @Test
def testHopWindowAggregate(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
http://git-wip-us.apache.org/repos/asf/flink/blob/dccdba19/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
index 2c82d9c..32e3724 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
@@ -92,6 +92,65 @@ class SqlITCase extends StreamingWithStateTestBase {
assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
}
+ @Test
+ def testUnboundedGroupByCollect(): Unit = {
+
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ env.setStateBackend(getStateBackend)
+ StreamITCase.clear
+
+ val sqlQuery = "SELECT b, COLLECT(a) FROM MyTable GROUP BY b"
+
+ val t = StreamTestData.get3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
+ tEnv.registerTable("MyTable", t)
+
+ val result = tEnv.sql(sqlQuery).toRetractStream[Row]
+ result.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+ env.execute()
+
+ val expected = List(
+ "1,{1=1}",
+ "2,{2=1, 3=1}",
+ "3,{4=1, 5=1, 6=1}",
+ "4,{7=1, 8=1, 9=1, 10=1}",
+ "5,{11=1, 12=1, 13=1, 14=1, 15=1}",
+ "6,{16=1, 17=1, 18=1, 19=1, 20=1, 21=1}")
+ assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+ }
+
+ @Test
+ def testUnboundedGroupByCollectWithObject(): Unit = {
+
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ env.setStateBackend(getStateBackend)
+ StreamITCase.clear
+
+ val sqlQuery = "SELECT b, COLLECT(c) FROM MyTable GROUP BY b"
+
+ val data = List(
+ (1, 1, (12, "45.6")),
+ (2, 2, (12, "45.612")),
+ (3, 2, (13, "41.6")),
+ (4, 3, (14, "45.2136")),
+ (5, 3, (18, "42.6"))
+ )
+
+ tEnv.registerTable("MyTable",
+ env.fromCollection(data).toTable(tEnv).as('a, 'b, 'c))
+
+ val result = tEnv.sql(sqlQuery).toRetractStream[Row]
+ result.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+ env.execute()
+
+ val expected = List(
+ "1,{(12,45.6)=1}",
+ "2,{(13,41.6)=1, (12,45.612)=1}",
+ "3,{(18,42.6)=1, (14,45.2136)=1}")
+ assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+ }
+
/** test selection **/
@Test
def testSelectExpressionFromTable(): Unit = {