You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink itself was lost in the plain-text conversion).
Posted to commits@flink.apache.org by tw...@apache.org on 2018/01/09 09:11:42 UTC
[1/2] flink git commit: [FLINK-6094] [table] Add checks for hashCode/equals and little code cleanup
Repository: flink
Updated Branches:
refs/heads/master 11287fbf6 -> 49c6d10f1
[FLINK-6094] [table] Add checks for hashCode/equals and little code cleanup
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/49c6d10f
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/49c6d10f
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/49c6d10f
Branch: refs/heads/master
Commit: 49c6d10f186fb722d2a4003ce4d2219c01f55871
Parents: 9623b25
Author: twalthr <tw...@apache.org>
Authored: Mon Jan 8 14:27:34 2018 +0100
Committer: twalthr <tw...@apache.org>
Committed: Tue Jan 9 09:48:32 2018 +0100
----------------------------------------------------------------------
.../DataStreamGroupWindowAggregate.scala | 2 -
.../plan/nodes/datastream/DataStreamJoin.scala | 13 ++---
.../table/plan/util/UpdatingPlanChecker.scala | 50 ++++++++---------
.../flink/table/runtime/CRowKeySelector.scala | 4 ++
.../table/runtime/join/NonWindowInnerJoin.scala | 21 +++++---
.../flink/table/typeutils/TypeCheckUtils.scala | 26 +++++----
.../table/validation/JoinValidationTest.scala | 56 +++++++++++++++++---
.../table/plan/UpdatingPlanCheckerTest.scala | 7 ++-
.../table/typeutils/TypeCheckUtilsTest.scala | 22 ++++----
9 files changed, 128 insertions(+), 73 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/49c6d10f/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala
index 7a6b333..d527dc8 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala
@@ -70,8 +70,6 @@ class DataStreamGroupWindowAggregate(
def getWindowProperties: Seq[NamedWindowProperty] = namedProperties
- def getWindowAlias: String = window.aliasAttribute.toString
-
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
new DataStreamGroupWindowAggregate(
window,
http://git-wip-us.apache.org/repos/asf/flink/blob/49c6d10f/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamJoin.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamJoin.scala
index 576c2bc..853006f 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamJoin.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamJoin.scala
@@ -53,8 +53,8 @@ class DataStreamJoin(
schema: RowSchema,
ruleDescription: String)
extends BiRel(cluster, traitSet, leftNode, rightNode)
- with CommonJoin
- with DataStreamRel {
+ with CommonJoin
+ with DataStreamRel {
override def deriveRowType(): RelDataType = schema.relDataType
@@ -123,8 +123,8 @@ class DataStreamJoin(
} else {
throw TableException(
"Equality join predicate on incompatible types.\n" +
- s"\tLeft: ${left},\n" +
- s"\tRight: ${right},\n" +
+ s"\tLeft: $left,\n" +
+ s"\tRight: $right,\n" +
s"\tCondition: (${joinConditionToString(schema.relDataType,
joinCondition, getExpressionString)})"
)
@@ -138,8 +138,9 @@ class DataStreamJoin(
val (connectOperator, nullCheck) = joinType match {
case JoinRelType.INNER => (leftDataStream.connect(rightDataStream), false)
- case _ => throw TableException(s"An Unsupported JoinType [ $joinType ]. Currently only " +
- s"non-window inner joins with at least one equality predicate are supported")
+ case _ =>
+ throw TableException(s"Unsupported join type '$joinType'. Currently only " +
+ s"non-window inner joins with at least one equality predicate are supported")
}
val generator = new FunctionCodeGenerator(
http://git-wip-us.apache.org/repos/asf/flink/blob/49c6d10f/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/UpdatingPlanChecker.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/UpdatingPlanChecker.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/UpdatingPlanChecker.scala
index 9ec097a..56465cc 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/UpdatingPlanChecker.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/UpdatingPlanChecker.scala
@@ -21,9 +21,9 @@ import org.apache.calcite.rel.core.JoinRelType
import org.apache.calcite.rel.{RelNode, RelVisitor}
import org.apache.calcite.rex.{RexCall, RexInputRef, RexNode}
import org.apache.calcite.sql.SqlKind
+import org.apache.flink.table.expressions.ProctimeAttribute
import org.apache.flink.table.plan.nodes.datastream._
-import _root_.scala.collection.JavaConverters._
import _root_.scala.collection.JavaConversions._
import scala.collection.mutable
@@ -66,7 +66,7 @@ object UpdatingPlanChecker {
// belong to the same group, i.e., pk1. Here we use the lexicographic smallest attribute as
// the common group id. A node can have keys if it generates the keys by itself or it
// forwards keys from its input(s).
- def visit(node: RelNode): Option[List[(String, String)]] = {
+ def visit(node: RelNode): Option[Seq[(String, String)]] = {
node match {
case c: DataStreamCalc =>
val inputKeys = visit(node.getInput(0))
@@ -74,7 +74,7 @@ object UpdatingPlanChecker {
if (inputKeys.isDefined) {
// track keys forward
val inNames = c.getInput.getRowType.getFieldNames
- val inOutNames = c.getProgram.getNamedProjects.asScala
+ val inOutNames = c.getProgram.getNamedProjects
.map(p => {
c.getProgram.expandLocalRef(p.left) match {
// output field is forwarded input field
@@ -102,7 +102,8 @@ object UpdatingPlanChecker {
val inputKeysMap = inputKeys.get.toMap
val inOutGroups = inputKeysAndOutput
- .map(e => (inputKeysMap(e._1), e._2)).sorted.reverse.toMap
+ .map(e => (inputKeysMap(e._1), e._2))
+ .toMap
// get output keys
val outputKeys = inputKeysAndOutput
@@ -111,7 +112,7 @@ object UpdatingPlanChecker {
// check if all keys have been preserved
if (outputKeys.map(_._2).distinct.length == inputKeys.get.map(_._2).distinct.length) {
// all key have been preserved (but possibly renamed)
- Some(outputKeys.toList)
+ Some(outputKeys)
} else {
// some (or all) keys have been removed. Keys are no longer unique and removed
None
@@ -125,18 +126,19 @@ object UpdatingPlanChecker {
visit(node.getInput(0))
case a: DataStreamGroupAggregate =>
// get grouping keys
- val groupKeys = a.getRowType.getFieldNames.asScala.take(a.getGroupings.length)
- Some(groupKeys.map(e => (e, e)).toList)
+ val groupKeys = a.getRowType.getFieldNames.take(a.getGroupings.length)
+ Some(groupKeys.map(e => (e, e)))
case w: DataStreamGroupWindowAggregate =>
// get grouping keys
val groupKeys =
- w.getRowType.getFieldNames.asScala.take(w.getGroupings.length).toArray
- // get window start and end time
- val windowStartEnd = w.getWindowProperties.map(_.name)
+ w.getRowType.getFieldNames.take(w.getGroupings.length).toArray
+ // proctime is not a valid key
+ val windowProperties = w.getWindowProperties
+ .filter(!_.property.isInstanceOf[ProctimeAttribute])
+ .map(_.name)
// we have only a unique key if at least one window property is selected
- if (windowStartEnd.nonEmpty) {
- val smallestAttribute = windowStartEnd.min
- Some((groupKeys.map(e => (e, e)) ++ windowStartEnd.map((_, smallestAttribute))).toList)
+ if (windowProperties.nonEmpty) {
+ Some(groupKeys.map(e => (e, e)) ++ windowProperties.map(e => (e, e)))
} else {
None
}
@@ -144,7 +146,7 @@ object UpdatingPlanChecker {
case j: DataStreamJoin =>
val joinType = j.getJoinType
joinType match {
- case JoinRelType.INNER => {
+ case JoinRelType.INNER =>
// get key(s) for inner join
val lInKeys = visit(j.getLeft)
val rInKeys = visit(j.getRight)
@@ -170,18 +172,17 @@ object UpdatingPlanChecker {
.map(rInNames.get(_))
.map(rInNamesToJoinNamesMap(_))
- val inKeys: List[(String, String)] = lInKeys.get ++ rInKeys.get
+ val inKeys: Seq[(String, String)] = lInKeys.get ++ rInKeys.get
.map(e => (rInNamesToJoinNamesMap(e._1), rInNamesToJoinNamesMap(e._2)))
getOutputKeysForInnerJoin(
joinNames,
inKeys,
- lJoinKeys.zip(rJoinKeys).toList
+ lJoinKeys.zip(rJoinKeys)
)
}
- }
- case _ => throw new UnsupportedOperationException(
- s"An Unsupported JoinType [ $joinType ]")
+ case _ =>
+ throw new UnsupportedOperationException(s"Unsupported join type '$joinType'")
}
case _: DataStreamRel =>
// anything else does not forward keys, so we can stop
@@ -199,9 +200,9 @@ object UpdatingPlanChecker {
*/
def getOutputKeysForInnerJoin(
inNames: Seq[String],
- inKeys: List[(String, String)],
- joinKeys: List[(String, String)])
- : Option[List[(String, String)]] = {
+ inKeys: Seq[(String, String)],
+ joinKeys: Seq[(String, String)])
+ : Option[Seq[(String, String)]] = {
val nameToGroups = mutable.HashMap.empty[String,String]
@@ -210,7 +211,7 @@ object UpdatingPlanChecker {
val ga: String = findGroup(nameA)
val gb: String = findGroup(nameB)
if (!ga.equals(gb)) {
- if(ga.compare(gb) < 0) {
+ if (ga.compare(gb) < 0) {
nameToGroups += (gb -> ga)
} else {
nameToGroups += (ga -> gb)
@@ -242,14 +243,13 @@ object UpdatingPlanChecker {
// merge groups
joinKeys.foreach(e => merge(e._1, e._2))
// make sure all name point to the group name directly
- inNames.foreach(findGroup(_))
+ inNames.foreach(findGroup)
val outputGroups = inKeys.map(e => nameToGroups(e._1)).distinct
Some(
inNames
.filter(e => outputGroups.contains(nameToGroups(e)))
.map(e => (e, nameToGroups(e)))
- .toList
)
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/49c6d10f/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowKeySelector.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowKeySelector.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowKeySelector.scala
index 216a7f9..327476a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowKeySelector.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowKeySelector.scala
@@ -22,6 +22,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.functions.KeySelector
import org.apache.flink.api.java.typeutils.ResultTypeQueryable
import org.apache.flink.table.runtime.types.CRow
+import org.apache.flink.table.typeutils.TypeCheckUtils.validateEqualsHashCode
import org.apache.flink.types.Row
/**
@@ -33,6 +34,9 @@ class CRowKeySelector(
extends KeySelector[CRow, Row]
with ResultTypeQueryable[Row] {
+ // check if type implements proper equals/hashCode
+ validateEqualsHashCode("grouping", returnType)
+
override def getKey(value: CRow): Row = {
Row.project(value.row, keyFields)
}
http://git-wip-us.apache.org/repos/asf/flink/blob/49c6d10f/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowInnerJoin.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowInnerJoin.scala
index 841cd15..6fef701 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowInnerJoin.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowInnerJoin.scala
@@ -22,17 +22,19 @@ package org.apache.flink.table.runtime.join
import org.apache.flink.api.common.functions.FlatJoinFunction
import org.apache.flink.api.common.state._
import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
import org.apache.flink.api.java.typeutils.TupleTypeInfo
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.functions.co.CoProcessFunction
import org.apache.flink.table.api.{StreamQueryConfig, Types}
+import org.apache.flink.table.codegen.Compiler
import org.apache.flink.table.runtime.CRowWrappingMultiOutputCollector
import org.apache.flink.table.runtime.types.CRow
+import org.apache.flink.table.typeutils.TypeCheckUtils
+import org.apache.flink.table.typeutils.TypeCheckUtils.validateEqualsHashCode
+import org.apache.flink.table.util.Logging
import org.apache.flink.types.Row
import org.apache.flink.util.Collector
-import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
-import org.apache.flink.table.codegen.Compiler
-import org.apache.flink.table.util.Logging
/**
@@ -53,8 +55,12 @@ class NonWindowInnerJoin(
genJoinFuncCode: String,
queryConfig: StreamQueryConfig)
extends CoProcessFunction[CRow, CRow, CRow]
- with Compiler[FlatJoinFunction[Row, Row, Row]]
- with Logging {
+ with Compiler[FlatJoinFunction[Row, Row, Row]]
+ with Logging {
+
+ // check if input types implement proper equals/hashCode
+ validateEqualsHashCode("join", leftType)
+ validateEqualsHashCode("join", rightType)
// state to hold left stream element
private var leftState: MapState[Row, JTuple2[Int, Long]] = _
@@ -116,7 +122,7 @@ class NonWindowInnerJoin(
ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
out: Collector[CRow]): Unit = {
- processElement(valueC, ctx, out, leftTimer, leftState, rightState, true)
+ processElement(valueC, ctx, out, leftTimer, leftState, rightState, isLeft = true)
}
/**
@@ -132,7 +138,7 @@ class NonWindowInnerJoin(
ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
out: Collector[CRow]): Unit = {
- processElement(valueC, ctx, out, rightTimer, rightState, leftState, false)
+ processElement(valueC, ctx, out, rightTimer, rightState, leftState, isLeft = false)
}
@@ -168,7 +174,6 @@ class NonWindowInnerJoin(
}
}
-
def getNewExpiredTime(
curProcessTime: Long,
oldExpiredTime: Long): Long = {
http://git-wip-us.apache.org/repos/asf/flink/blob/49c6d10f/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala
index 278ae18..7df121f 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala
@@ -109,52 +109,56 @@ object TypeCheckUtils {
/**
* Checks whether a type implements own hashCode() and equals() methods for storing an instance
- * in Flink's state.
+ * in Flink's state or performing a keyBy operation.
*
+ * @param name name of the operation
* @param t type information to be validated
*/
- def validateStateType(t: TypeInformation[_]): Unit = t match {
+ def validateEqualsHashCode(name: String, t: TypeInformation[_]): Unit = t match {
+
// make sure that a POJO class is a valid state type
case pt: PojoTypeInfo[_] =>
// we don't check the types recursively to give a chance of wrapping
// proper hashCode/equals methods around an immutable type
- validateStateType(pt.getClass)
+ validateEqualsHashCode(name, pt.getClass)
// recursively check composite types
case ct: CompositeType[_] =>
- validateStateType(t.getTypeClass)
+ validateEqualsHashCode(name, t.getTypeClass)
// we check recursively for entering Flink types such as tuples and rows
for (i <- 0 until ct.getArity) {
val subtype = ct.getTypeAt(i)
- validateStateType(subtype)
+ validateEqualsHashCode(name, subtype)
}
// check other type information only based on the type class
case _: TypeInformation[_] =>
- validateStateType(t.getTypeClass)
+ validateEqualsHashCode(name, t.getTypeClass)
}
/**
* Checks whether a class implements own hashCode() and equals() methods for storing an instance
- * in Flink's state.
+ * in Flink's state or performing a keyBy operation.
*
+ * @param name name of the operation
* @param c class to be validated
*/
- def validateStateType(c: Class[_]): Unit = {
+ def validateEqualsHashCode(name: String, c: Class[_]): Unit = {
+
// skip primitives
if (!c.isPrimitive) {
// check the component type of arrays
if (c.isArray) {
- validateStateType(c.getComponentType)
+ validateEqualsHashCode(name, c.getComponentType)
}
// check type for methods
else {
if (c.getMethod("hashCode").getDeclaringClass eq classOf[Object]) {
throw new ValidationException(
- s"Type '${c.getCanonicalName}' cannot be used in a stateful operation because it " +
+ s"Type '${c.getCanonicalName}' cannot be used in a $name operation because it " +
s"does not implement a proper hashCode() method.")
}
if (c.getMethod("equals", classOf[Object]).getDeclaringClass eq classOf[Object]) {
throw new ValidationException(
- s"Type '${c.getCanonicalName}' cannot be used in a stateful operation because it " +
+ s"Type '${c.getCanonicalName}' cannot be used in a $name operation because it " +
s"does not implement a proper equals() method.")
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/49c6d10f/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/JoinValidationTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/JoinValidationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/JoinValidationTest.scala
index b354929..9cb3fbf 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/JoinValidationTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/JoinValidationTest.scala
@@ -20,8 +20,9 @@ package org.apache.flink.table.api.stream.table.validation
import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
-import org.apache.flink.table.api.{TableEnvironment, TableException, ValidationException}
import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.stream.table.validation.JoinValidationTest.WithoutEqualsHashCode
+import org.apache.flink.table.api.{TableEnvironment, TableException, ValidationException}
import org.apache.flink.table.runtime.utils.StreamTestData
import org.apache.flink.table.utils.TableTestBase
import org.apache.flink.types.Row
@@ -30,6 +31,26 @@ import org.junit.Test
class JoinValidationTest extends TableTestBase {
/**
+ * Generic type cannot be used as key of map state.
+ */
+ @Test(expected = classOf[ValidationException])
+ def testInvalidStateTypes(): Unit = {
+ val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment
+ val tenv = TableEnvironment.getTableEnvironment(env)
+ val ds = env.fromElements(new WithoutEqualsHashCode) // no equals/hashCode
+ val t = tenv.fromDataStream(ds)
+
+ val left = t.select('f0 as 'l)
+ val right = t.select('f0 as 'r)
+
+ val resultTable = left.join(right)
+ .where('l === 'r)
+ .select('l)
+
+ resultTable.toRetractStream[Row]
+ }
+
+ /**
* At least one equi-join predicate required.
*/
@Test(expected = classOf[TableException])
@@ -109,13 +130,12 @@ class JoinValidationTest extends TableTestBase {
util.verifyTable(resultTable, "")
}
-
- private val util = streamTestUtil()
- private val ds1 = util.addTable[(Int, Long, String)]("Table3",'a, 'b, 'c)
- private val ds2 = util.addTable[(Int, Long, Int, String, Long)]("Table5", 'd, 'e, 'f, 'g, 'h)
-
@Test(expected = classOf[ValidationException])
def testJoinNonExistingKey(): Unit = {
+ val util = streamTestUtil()
+ val ds1 = util.addTable[(Int, Long, String)]("Table3",'a, 'b, 'c)
+ val ds2 = util.addTable[(Int, Long, Int, String, Long)]("Table5", 'd, 'e, 'f, 'g, 'h)
+
ds1.join(ds2)
// must fail. Field 'foo does not exist
.where('foo === 'e)
@@ -124,6 +144,10 @@ class JoinValidationTest extends TableTestBase {
@Test(expected = classOf[ValidationException])
def testJoinWithNonMatchingKeyTypes(): Unit = {
+ val util = streamTestUtil()
+ val ds1 = util.addTable[(Int, Long, String)]("Table3",'a, 'b, 'c)
+ val ds2 = util.addTable[(Int, Long, Int, String, Long)]("Table5", 'd, 'e, 'f, 'g, 'h)
+
ds1.join(ds2)
// must fail. Field 'a is Int, and 'g is String
.where('a === 'g)
@@ -133,6 +157,10 @@ class JoinValidationTest extends TableTestBase {
@Test(expected = classOf[ValidationException])
def testJoinWithAmbiguousFields(): Unit = {
+ val util = streamTestUtil()
+ val ds1 = util.addTable[(Int, Long, String)]("Table3",'a, 'b, 'c)
+ val ds2 = util.addTable[(Int, Long, Int, String, Long)]("Table5", 'd, 'e, 'f, 'g, 'h)
+
ds1.join(ds2.select('d, 'e, 'f, 'g, 'h as 'c))
// must fail. Both inputs share the same field 'c
.where('a === 'd)
@@ -141,6 +169,10 @@ class JoinValidationTest extends TableTestBase {
@Test(expected = classOf[TableException])
def testNoEqualityJoinPredicate1(): Unit = {
+ val util = streamTestUtil()
+ val ds1 = util.addTable[(Int, Long, String)]("Table3",'a, 'b, 'c)
+ val ds2 = util.addTable[(Int, Long, Int, String, Long)]("Table5", 'd, 'e, 'f, 'g, 'h)
+
ds1.join(ds2)
// must fail. No equality join predicate
.where('d === 'f)
@@ -150,6 +182,10 @@ class JoinValidationTest extends TableTestBase {
@Test(expected = classOf[TableException])
def testNoEqualityJoinPredicate2(): Unit = {
+ val util = streamTestUtil()
+ val ds1 = util.addTable[(Int, Long, String)]("Table3",'a, 'b, 'c)
+ val ds2 = util.addTable[(Int, Long, Int, String, Long)]("Table5", 'd, 'e, 'f, 'g, 'h)
+
ds1.join(ds2)
// must fail. No equality join predicate
.where('a < 'd)
@@ -159,6 +195,10 @@ class JoinValidationTest extends TableTestBase {
@Test(expected = classOf[ValidationException])
def testNoEquiJoin(): Unit = {
+ val util = streamTestUtil()
+ val ds1 = util.addTable[(Int, Long, String)]("Table3",'a, 'b, 'c)
+ val ds2 = util.addTable[(Int, Long, Int, String, Long)]("Table5", 'd, 'e, 'f, 'g, 'h)
+
ds2.join(ds1, 'b < 'd).select('c, 'g)
}
@@ -189,3 +229,7 @@ class JoinValidationTest extends TableTestBase {
in1.join(in2).where("a === d").select("g.count")
}
}
+
+object JoinValidationTest {
+ class WithoutEqualsHashCode
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/49c6d10f/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/UpdatingPlanCheckerTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/UpdatingPlanCheckerTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/UpdatingPlanCheckerTest.scala
index 6fb19fc..a648724 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/UpdatingPlanCheckerTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/UpdatingPlanCheckerTest.scala
@@ -26,7 +26,6 @@ import org.apache.flink.table.api.scala._
import org.apache.flink.api.scala._
import org.junit.Test
-
class UpdatingPlanCheckerTest {
@Test
@@ -94,9 +93,9 @@ class UpdatingPlanCheckerTest {
val resultTable = table
.window(Tumble over 5.milli on 'proctime as 'w)
.groupBy('w, 'a)
- .select('a, 'b.count, 'w.start as 'start)
+ .select('a, 'b.count, 'w.proctime as 'p, 'w.start as 's, 'w.end as 'e)
- util.verifyTableUniqueKey(resultTable, Seq("a", "start"))
+ util.verifyTableUniqueKey(resultTable, Seq("a", "s", "e"))
}
@Test
@@ -217,7 +216,7 @@ class UpdatePlanCheckerUtil extends StreamTableTestUtil {
val actual = UpdatingPlanChecker.getUniqueKeyFields(optimized)
if (actual.isDefined) {
- assertEquals(expected.sorted, actual.get.toList.sorted)
+ assertEquals(expected.sorted, actual.get.toSeq.sorted)
} else {
assertEquals(expected.sorted, Nil)
}
http://git-wip-us.apache.org/repos/asf/flink/blob/49c6d10f/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/typeutils/TypeCheckUtilsTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/typeutils/TypeCheckUtilsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/typeutils/TypeCheckUtilsTest.scala
index 65a7dbd..645e608 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/typeutils/TypeCheckUtilsTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/typeutils/TypeCheckUtilsTest.scala
@@ -21,34 +21,34 @@ package org.apache.flink.table.typeutils
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.typeutils.{Types => ScalaTypes}
import org.apache.flink.table.api.{Types, ValidationException}
-import org.apache.flink.table.typeutils.TypeCheckUtils.validateStateType
+import org.apache.flink.table.typeutils.TypeCheckUtils.validateEqualsHashCode
import org.junit.Test
class TypeCheckUtilsTest {
@Test
def testValidateStateType(): Unit = {
- validateStateType(Types.STRING)
- validateStateType(Types.LONG)
- validateStateType(Types.SQL_TIMESTAMP)
- validateStateType(Types.ROW(Types.LONG, Types.DECIMAL))
- validateStateType(ScalaTypes.CASE_CLASS[(Long, Int)])
- validateStateType(Types.OBJECT_ARRAY(Types.LONG))
- validateStateType(Types.PRIMITIVE_ARRAY(Types.LONG))
+ validateEqualsHashCode("", Types.STRING)
+ validateEqualsHashCode("", Types.LONG)
+ validateEqualsHashCode("", Types.SQL_TIMESTAMP)
+ validateEqualsHashCode("", Types.ROW(Types.LONG, Types.DECIMAL))
+ validateEqualsHashCode("", ScalaTypes.CASE_CLASS[(Long, Int)])
+ validateEqualsHashCode("", Types.OBJECT_ARRAY(Types.LONG))
+ validateEqualsHashCode("", Types.PRIMITIVE_ARRAY(Types.LONG))
}
@Test(expected = classOf[ValidationException])
def testInvalidType(): Unit = {
- validateStateType(ScalaTypes.NOTHING)
+ validateEqualsHashCode("", ScalaTypes.NOTHING)
}
@Test(expected = classOf[ValidationException])
def testInvalidType2(): Unit = {
- validateStateType(Types.ROW(ScalaTypes.NOTHING))
+ validateEqualsHashCode("", Types.ROW(ScalaTypes.NOTHING))
}
@Test(expected = classOf[ValidationException])
def testInvalidType3(): Unit = {
- validateStateType(Types.OBJECT_ARRAY[Nothing](ScalaTypes.NOTHING))
+ validateEqualsHashCode("", Types.OBJECT_ARRAY[Nothing](ScalaTypes.NOTHING))
}
}
[2/2] flink git commit: [FLINK-6094] [table] Implement stream-stream proctime non-window inner join
Posted by tw...@apache.org.
[FLINK-6094] [table] Implement stream-stream proctime non-window inner join
This closes #4471.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/9623b252
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/9623b252
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/9623b252
Branch: refs/heads/master
Commit: 9623b252a97cc7a8a48a1e2ee18df3abe56bc9d9
Parents: 11287fb
Author: 军长 <he...@alibaba-inc.com>
Authored: Sun Jul 30 18:45:45 2017 +0800
Committer: twalthr <tw...@apache.org>
Committed: Tue Jan 9 09:48:32 2018 +0100
----------------------------------------------------------------------
.../DataStreamGroupWindowAggregate.scala | 2 +
.../plan/nodes/datastream/DataStreamJoin.scala | 196 +++++++++++++
.../flink/table/plan/rules/FlinkRuleSets.scala | 1 +
.../rules/datastream/DataStreamJoinRule.scala | 111 +++++++
.../datastream/DataStreamWindowJoinRule.scala | 2 +-
.../table/plan/util/UpdatingPlanChecker.scala | 165 +++++++++--
.../CRowWrappingMultiOutputCollector.scala | 50 ++++
.../table/runtime/join/NonWindowInnerJoin.scala | 286 +++++++++++++++++++
.../sql/validation/JoinValidationTest.scala | 35 ++-
.../table/validation/JoinValidationTest.scala | 104 ++++++-
.../flink/table/plan/RetractionRulesTest.scala | 34 +++
.../table/plan/UpdatingPlanCheckerTest.scala | 225 +++++++++++++++
.../table/runtime/harness/JoinHarnessTest.scala | 224 ++++++++++++++-
.../runtime/harness/NonWindowHarnessTest.scala | 4 +-
.../table/runtime/stream/sql/JoinITCase.scala | 55 ++++
.../table/runtime/stream/table/JoinITCase.scala | 243 ++++++++++++++++
.../runtime/stream/table/TableSinkITCase.scala | 87 +++---
17 files changed, 1739 insertions(+), 85 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/9623b252/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala
index d527dc8..7a6b333 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala
@@ -70,6 +70,8 @@ class DataStreamGroupWindowAggregate(
def getWindowProperties: Seq[NamedWindowProperty] = namedProperties
+ def getWindowAlias: String = window.aliasAttribute.toString
+
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
new DataStreamGroupWindowAggregate(
window,
http://git-wip-us.apache.org/repos/asf/flink/blob/9623b252/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamJoin.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamJoin.scala
new file mode 100644
index 0000000..576c2bc
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamJoin.scala
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.nodes.datastream
+
+import org.apache.calcite.plan._
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.core.{JoinInfo, JoinRelType}
+import org.apache.calcite.rel.{BiRel, RelNode, RelWriter}
+import org.apache.calcite.rex.RexNode
+import org.apache.flink.api.common.functions.FlatJoinFunction
+import org.apache.flink.streaming.api.datastream.DataStream
+import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment, TableException}
+import org.apache.flink.table.codegen.FunctionCodeGenerator
+import org.apache.flink.table.plan.nodes.CommonJoin
+import org.apache.flink.table.plan.schema.RowSchema
+import org.apache.flink.table.runtime.CRowKeySelector
+import org.apache.flink.table.runtime.join.NonWindowInnerJoin
+import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
+import org.apache.flink.types.Row
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
+
+/**
+  * RelNode for a non-windowed stream join.
+  *
+  * Both inputs are keyed by the equi-join key pairs of `joinInfo`. Any remaining (non-equi)
+  * predicates are evaluated by a generated [[FlatJoinFunction]] that runs inside a
+  * [[NonWindowInnerJoin]] operator. Only `JoinRelType.INNER` joins can currently be translated.
+  *
+  * @param cluster         cluster this node belongs to
+  * @param traitSet        traits of this node
+  * @param leftNode        left input
+  * @param rightNode       right input
+  * @param joinCondition   complete join condition
+  * @param joinInfo        analyzed join condition (equi-key pairs and remaining predicates)
+  * @param joinType        type of the join
+  * @param leftSchema      schema of the left input
+  * @param rightSchema     schema of the right input
+  * @param schema          schema of the join result
+  * @param ruleDescription description of the creating rule; passed to the code generator when
+  *                        the join function is generated
+  */
+class DataStreamJoin(
+    cluster: RelOptCluster,
+    traitSet: RelTraitSet,
+    leftNode: RelNode,
+    rightNode: RelNode,
+    joinCondition: RexNode,
+    joinInfo: JoinInfo,
+    joinType: JoinRelType,
+    leftSchema: RowSchema,
+    rightSchema: RowSchema,
+    schema: RowSchema,
+    ruleDescription: String)
+  extends BiRel(cluster, traitSet, leftNode, rightNode)
+  with CommonJoin
+  with DataStreamRel {
+
+  override def deriveRowType(): RelDataType = schema.relDataType
+
+  // The join operator handles retraction messages (CRow.change == false), so updates of the
+  // inputs are requested as retractions.
+  override def needsUpdatesAsRetraction: Boolean = true
+
+  override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
+    new DataStreamJoin(
+      cluster,
+      traitSet,
+      inputs.get(0),
+      inputs.get(1),
+      joinCondition,
+      joinInfo,
+      joinType,
+      leftSchema,
+      rightSchema,
+      schema,
+      ruleDescription)
+  }
+
+  /** Returns the analyzed join condition (equi-key pairs and remaining predicates). */
+  def getJoinInfo: JoinInfo = joinInfo
+
+  /** Returns the type of this join. */
+  def getJoinType: JoinRelType = joinType
+
+  override def toString: String = {
+    joinToString(
+      schema.relDataType,
+      joinCondition,
+      joinType,
+      getExpressionString)
+  }
+
+  override def explainTerms(pw: RelWriter): RelWriter = {
+    joinExplainTerms(
+      super.explainTerms(pw),
+      schema.relDataType,
+      joinCondition,
+      joinType,
+      getExpressionString)
+  }
+
+  override def translateToPlan(
+      tableEnv: StreamTableEnvironment,
+      queryConfig: StreamQueryConfig): DataStream[CRow] = {
+
+    val config = tableEnv.getConfig
+    val returnType = schema.typeInfo
+    val keyPairs = joinInfo.pairs().toList
+
+    // Both inputs are keyed by the equi-join fields, so the two fields of every key pair
+    // must have the same SQL type.
+    val leftFields = left.getRowType.getFieldList
+    val rightFields = right.getRowType.getFieldList
+    keyPairs.foreach { pair =>
+      val leftKeyType = leftFields.get(pair.source).getType.getSqlTypeName
+      val rightKeyType = rightFields.get(pair.target).getType.getSqlTypeName
+      if (leftKeyType != rightKeyType) {
+        val conditionString = joinConditionToString(
+          schema.relDataType,
+          joinCondition,
+          getExpressionString)
+        throw TableException(
+          "Equality join predicate on incompatible types.\n" +
+            s"\tLeft: ${left},\n" +
+            s"\tRight: ${right},\n" +
+            s"\tCondition: ($conditionString)"
+        )
+      }
+    }
+
+    // Field indices of the equi-join keys on each side; the i-th pair forms the i-th key
+    // component. NOTE(review): nothing here checks that there is at least one pair.
+    val leftKeys: Array[Int] = keyPairs.map(_.source).toArray
+    val rightKeys: Array[Int] = keyPairs.map(_.target).toArray
+
+    val leftDataStream =
+      left.asInstanceOf[DataStreamRel].translateToPlan(tableEnv, queryConfig)
+    val rightDataStream =
+      right.asInstanceOf[DataStreamRel].translateToPlan(tableEnv, queryConfig)
+
+    // Only inner joins are supported; their join code is generated without null checks.
+    val (connectOperator, nullCheck) = joinType match {
+      case JoinRelType.INNER => (leftDataStream.connect(rightDataStream), false)
+      case _ => throw TableException(s"An Unsupported JoinType [ $joinType ]. Currently only " +
+        s"non-window inner joins with at least one equality predicate are supported")
+    }
+
+    val generator = new FunctionCodeGenerator(
+      config,
+      nullCheck,
+      leftSchema.typeInfo,
+      Some(rightSchema.typeInfo))
+    val conversion = generator.generateConverterResultExpression(
+      schema.typeInfo,
+      schema.fieldNames)
+
+    // The generated join function emits the converted result row, guarded by the remaining
+    // non-equi predicates if the condition is not a pure equi-join.
+    val body = if (joinInfo.isEqui) {
+      // only equality condition
+      s"""
+         |${conversion.code}
+         |${generator.collectorTerm}.collect(${conversion.resultTerm});
+         |""".stripMargin
+    } else {
+      val nonEquiPredicates = joinInfo.getRemaining(this.cluster.getRexBuilder)
+      val condition = generator.generateExpression(nonEquiPredicates)
+      s"""
+         |${condition.code}
+         |if (${condition.resultTerm}) {
+         | ${conversion.code}
+         | ${generator.collectorTerm}.collect(${conversion.resultTerm});
+         |}
+         |""".stripMargin
+    }
+
+    val genFunction = generator.generateFunction(
+      ruleDescription,
+      classOf[FlatJoinFunction[Row, Row, Row]],
+      body,
+      returnType)
+
+    val joinFunction = new NonWindowInnerJoin(
+      leftSchema.typeInfo,
+      rightSchema.typeInfo,
+      CRowTypeInfo(returnType),
+      genFunction.name,
+      genFunction.code,
+      queryConfig)
+
+    val joinOpName = joinToString(getRowType, joinCondition, joinType, getExpressionString)
+    connectOperator
+      .keyBy(
+        new CRowKeySelector(leftKeys, leftSchema.projectedTypeInfo(leftKeys)),
+        new CRowKeySelector(rightKeys, rightSchema.projectedTypeInfo(rightKeys)))
+      .process(joinFunction)
+      .name(joinOpName)
+      .returns(CRowTypeInfo(returnType))
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/9623b252/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
index 10d6881..b8a96bf 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
@@ -210,6 +210,7 @@ object FlinkRuleSets {
DataStreamValuesRule.INSTANCE,
DataStreamCorrelateRule.INSTANCE,
DataStreamWindowJoinRule.INSTANCE,
+ DataStreamJoinRule.INSTANCE,
StreamTableSourceScanRule.INSTANCE
)
http://git-wip-us.apache.org/repos/asf/flink/blob/9623b252/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamJoinRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamJoinRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamJoinRule.scala
new file mode 100644
index 0000000..072acb3
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamJoinRule.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.rules.datastream
+
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.calcite.rex.{RexCall, RexInputRef, RexNode}
+import org.apache.flink.table.api.TableConfig
+import org.apache.flink.table.calcite.FlinkTypeFactory
+import org.apache.flink.table.plan.nodes.FlinkConventions
+import org.apache.flink.table.plan.nodes.datastream.DataStreamJoin
+import org.apache.flink.table.plan.nodes.logical.FlinkLogicalJoin
+import org.apache.flink.table.plan.schema.RowSchema
+import org.apache.flink.table.runtime.join.WindowJoinUtil
+
+import scala.collection.JavaConverters._
+
+/**
+  * Converts a [[FlinkLogicalJoin]] into a [[DataStreamJoin]], i.e., a non-windowed stream join.
+  *
+  * The rule only matches joins where all of the following hold:
+  * - no window bounds can be extracted from the join condition;
+  * - the remaining predicates do not access time attributes;
+  * - the join's row type contains no rowtime attributes.
+  */
+class DataStreamJoinRule
+  extends ConverterRule(
+    classOf[FlinkLogicalJoin],
+    FlinkConventions.LOGICAL,
+    FlinkConventions.DATASTREAM,
+    "DataStreamJoinRule") {
+
+  /**
+    * Checks if an expression accesses a time attribute.
+    *
+    * @param expr The expression to check.
+    * @param inputType The input type of the expression.
+    * @return True, if the expression accesses a time attribute. False otherwise.
+    */
+  def accessesTimeAttribute(expr: RexNode, inputType: RelDataType): Boolean = {
+    expr match {
+      case i: RexInputRef =>
+        val accessedType = inputType.getFieldList.get(i.getIndex).getType
+        FlinkTypeFactory.isTimeIndicatorType(accessedType)
+      case c: RexCall =>
+        c.operands.asScala.exists(accessesTimeAttribute(_, inputType))
+      case _ => false
+    }
+  }
+
+  override def matches(call: RelOptRuleCall): Boolean = {
+    val join: FlinkLogicalJoin = call.rel(0).asInstanceOf[FlinkLogicalJoin]
+    val joinInfo = join.analyzeCondition
+
+    val (windowBounds, remainingPreds) = WindowJoinUtil.extractWindowBoundsFromPredicate(
+      joinInfo.getRemaining(join.getCluster.getRexBuilder),
+      join.getLeft.getRowType.getFieldCount,
+      join.getRowType,
+      join.getCluster.getRexBuilder,
+      TableConfig.DEFAULT)
+
+    // remaining (non-equi, non-window) predicates must not access time attributes
+    val remainingPredsAccessTime =
+      remainingPreds.exists(accessesTimeAttribute(_, join.getRowType))
+
+    // The join's row type must not contain event-time (rowtime) attributes: a non-window join
+    // is unbounded, so we would not know how long watermarks have to be held back.
+    val rowTimeAttrInOutput = join.getRowType.getFieldList.asScala
+      .exists(f => FlinkTypeFactory.isRowtimeIndicatorType(f.getType))
+
+    windowBounds.isEmpty && !remainingPredsAccessTime && !rowTimeAttrInOutput
+  }
+
+  override def convert(rel: RelNode): RelNode = {
+
+    val join: FlinkLogicalJoin = rel.asInstanceOf[FlinkLogicalJoin]
+    val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASTREAM)
+    val convLeft: RelNode = RelOptRule.convert(join.getInput(0), FlinkConventions.DATASTREAM)
+    val convRight: RelNode = RelOptRule.convert(join.getInput(1), FlinkConventions.DATASTREAM)
+    val joinInfo = join.analyzeCondition
+    val leftRowSchema = new RowSchema(convLeft.getRowType)
+    val rightRowSchema = new RowSchema(convRight.getRowType)
+
+    new DataStreamJoin(
+      rel.getCluster,
+      traitSet,
+      convLeft,
+      convRight,
+      join.getCondition,
+      joinInfo,
+      join.getJoinType,
+      leftRowSchema,
+      rightRowSchema,
+      new RowSchema(rel.getRowType),
+      description)
+  }
+}
+
+object DataStreamJoinRule {
+
+  /** Shared instance of the rule, as registered in the DataStream optimization rule set. */
+  val INSTANCE: RelOptRule = new DataStreamJoinRule()
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/9623b252/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala
index a7358c7..3dfae99 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala
@@ -36,7 +36,7 @@ class DataStreamWindowJoinRule
classOf[FlinkLogicalJoin],
FlinkConventions.LOGICAL,
FlinkConventions.DATASTREAM,
- "DataStreamJoinRule") {
+ "DataStreamWindowJoinRule") {
override def matches(call: RelOptRuleCall): Boolean = {
val join: FlinkLogicalJoin = call.rel(0).asInstanceOf[FlinkLogicalJoin]
http://git-wip-us.apache.org/repos/asf/flink/blob/9623b252/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/UpdatingPlanChecker.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/UpdatingPlanChecker.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/UpdatingPlanChecker.scala
index 6a160f6..9ec097a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/UpdatingPlanChecker.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/UpdatingPlanChecker.scala
@@ -17,12 +17,15 @@
*/
package org.apache.flink.table.plan.util
+import org.apache.calcite.rel.core.JoinRelType
import org.apache.calcite.rel.{RelNode, RelVisitor}
import org.apache.calcite.rex.{RexCall, RexInputRef, RexNode}
import org.apache.calcite.sql.SqlKind
import org.apache.flink.table.plan.nodes.datastream._
import _root_.scala.collection.JavaConverters._
+import _root_.scala.collection.JavaConversions._
+import scala.collection.mutable
object UpdatingPlanChecker {
@@ -37,8 +40,7 @@ object UpdatingPlanChecker {
/** Extracts the unique keys of the table produced by the plan. */
def getUniqueKeyFields(plan: RelNode): Option[Array[String]] = {
val keyExtractor = new UniqueKeyExtractor
- keyExtractor.go(plan)
- keyExtractor.keys
+ keyExtractor.visit(plan).map(_.map(_._1).toArray)
}
private class AppendOnlyValidator extends RelVisitor {
@@ -56,16 +58,20 @@ object UpdatingPlanChecker {
}
/** Identifies unique key fields in the output of a RelNode. */
- private class UniqueKeyExtractor extends RelVisitor {
+ private class UniqueKeyExtractor {
- var keys: Option[Array[String]] = None
-
- override def visit(node: RelNode, ordinal: Int, parent: RelNode): Unit = {
+ // visit() function will return a tuple, the first element is the name of a key field, the
+ // second is a group name that is shared by all equivalent key fields. The group names are
+ // used to identify same keys, for example: select('pk as pk1, 'pk as pk2), both pk1 and pk2
+ // belong to the same group, i.e., pk1. Here we use the lexicographic smallest attribute as
+ // the common group id. A node can have keys if it generates the keys by itself or it
+ // forwards keys from its input(s).
+ def visit(node: RelNode): Option[List[(String, String)]] = {
node match {
case c: DataStreamCalc =>
- super.visit(node, ordinal, parent)
+ val inputKeys = visit(node.getInput(0))
// check if input has keys
- if (keys.isDefined) {
+ if (inputKeys.isDefined) {
// track keys forward
val inNames = c.getInput.getRowType.getFieldNames
val inOutNames = c.getProgram.getNamedProjects.asScala
@@ -91,23 +97,36 @@ object UpdatingPlanChecker {
.map(io => (inNames.get(io._1), io._2))
// filter by input keys
- val outKeys = inOutNames.filter(io => keys.get.contains(io._1)).map(_._2)
+ val inputKeysAndOutput = inOutNames
+ .filter(io => inputKeys.get.map(e => e._1).contains(io._1))
+
+ val inputKeysMap = inputKeys.get.toMap
+ val inOutGroups = inputKeysAndOutput
+ .map(e => (inputKeysMap(e._1), e._2)).sorted.reverse.toMap
+
+ // get output keys
+ val outputKeys = inputKeysAndOutput
+ .map(io => (io._2, inOutGroups(inputKeysMap(io._1))))
+
// check if all keys have been preserved
- if (outKeys.nonEmpty && outKeys.length == keys.get.length) {
+ if (outputKeys.map(_._2).distinct.length == inputKeys.get.map(_._2).distinct.length) {
// all key have been preserved (but possibly renamed)
- keys = Some(outKeys.toArray)
+ Some(outputKeys.toList)
} else {
// some (or all) keys have been removed. Keys are no longer unique and removed
- keys = None
+ None
}
+ } else {
+ None
}
+
case _: DataStreamOverAggregate =>
- super.visit(node, ordinal, parent)
- // keys are always forwarded by Over aggregate
+ // keys are always forwarded by Over aggregate
+ visit(node.getInput(0))
case a: DataStreamGroupAggregate =>
// get grouping keys
val groupKeys = a.getRowType.getFieldNames.asScala.take(a.getGroupings.length)
- keys = Some(groupKeys.toArray)
+ Some(groupKeys.map(e => (e, e)).toList)
case w: DataStreamGroupWindowAggregate =>
// get grouping keys
val groupKeys =
@@ -116,14 +135,122 @@ object UpdatingPlanChecker {
val windowStartEnd = w.getWindowProperties.map(_.name)
// we have only a unique key if at least one window property is selected
if (windowStartEnd.nonEmpty) {
- keys = Some(groupKeys ++ windowStartEnd)
+ val smallestAttribute = windowStartEnd.min
+ Some((groupKeys.map(e => (e, e)) ++ windowStartEnd.map((_, smallestAttribute))).toList)
+ } else {
+ None
+ }
+
+ case j: DataStreamJoin =>
+ val joinType = j.getJoinType
+ joinType match {
+ case JoinRelType.INNER => {
+ // get key(s) for inner join
+ val lInKeys = visit(j.getLeft)
+ val rInKeys = visit(j.getRight)
+ if (lInKeys.isEmpty || rInKeys.isEmpty) {
+ None
+ } else {
+ // Output of inner join must have keys if left and right both contain key(s).
+ // Key groups from both side will be merged by join equi-predicates
+ val lInNames: Seq[String] = j.getLeft.getRowType.getFieldNames
+ val rInNames: Seq[String] = j.getRight.getRowType.getFieldNames
+ val joinNames = j.getRowType.getFieldNames
+
+ // if right field names equal to left field names, calcite will rename right
+ // field names. For example, T1(pk, a) join T2(pk, b), calcite will rename T2(pk, b)
+ // to T2(pk0, b).
+ val rInNamesToJoinNamesMap = rInNames
+ .zip(joinNames.subList(lInNames.size, joinNames.length))
+ .toMap
+
+ val lJoinKeys: Seq[String] = j.getJoinInfo.leftKeys
+ .map(lInNames.get(_))
+ val rJoinKeys: Seq[String] = j.getJoinInfo.rightKeys
+ .map(rInNames.get(_))
+ .map(rInNamesToJoinNamesMap(_))
+
+ val inKeys: List[(String, String)] = lInKeys.get ++ rInKeys.get
+ .map(e => (rInNamesToJoinNamesMap(e._1), rInNamesToJoinNamesMap(e._2)))
+
+ getOutputKeysForInnerJoin(
+ joinNames,
+ inKeys,
+ lJoinKeys.zip(rJoinKeys).toList
+ )
+ }
+ }
+ case _ => throw new UnsupportedOperationException(
+ s"An Unsupported JoinType [ $joinType ]")
}
case _: DataStreamRel =>
- // anything else does not forward keys or might duplicate key, so we can stop
- keys = None
+ // anything else does not forward keys, so we can stop
+ None
}
}
- }
+    /**
+      * Derives the unique keys of a non-window inner join's output from the keys of its inputs.
+      *
+      * All field names form a union-find forest. Each input key field starts out pointing at
+      * its key group. Each equi-join pair then merges the groups of its two fields, and the
+      * lexicographically smaller group name stays the representative. The output keys are all
+      * fields that end up in the group of any input key.
+      *
+      * @param inNames  field names of the join's output row
+      * @param inKeys   key fields of both inputs as (field name, group name) pairs
+      * @param joinKeys equi-join key pairs as (left field name, right field name)
+      * @return the output key fields as (field name, group name) pairs
+      */
+    def getOutputKeysForInnerJoin(
+        inNames: Seq[String],
+        inKeys: List[(String, String)],
+        joinKeys: List[(String, String)])
+      : Option[List[(String, String)]] = {
+
+      // parent of every field name in the union-find forest; a root maps to itself
+      val parentOf = mutable.HashMap.empty[String, String]
+
+      // Returns the root group of `name` and re-points every node on the path directly to it.
+      def findGroup(name: String): String = {
+        var root = name
+        while (parentOf(root) != root) {
+          root = parentOf(root)
+        }
+        var node = name
+        while (parentOf(node) != root) {
+          val next = parentOf(node)
+          parentOf(node) = root
+          node = next
+        }
+        root
+      }
+
+      // Unites the groups of two names; the lexicographically smaller group name wins.
+      def merge(nameA: String, nameB: String): Unit = {
+        val groupA = findGroup(nameA)
+        val groupB = findGroup(nameB)
+        if (groupA != groupB) {
+          if (groupA < groupB) {
+            parentOf(groupB) = groupA
+          } else {
+            parentOf(groupA) = groupB
+          }
+        }
+      }
+
+      for (name <- inNames) parentOf(name) = name
+      for ((key, group) <- inKeys) parentOf(key) = group
+      for ((leftKey, rightKey) <- joinKeys) merge(leftKey, rightKey)
+      // compress every path so that each name maps directly to its group
+      inNames.foreach(name => findGroup(name))
+
+      val outputGroups = inKeys.map { case (key, _) => parentOf(key) }.distinct
+      val outputKeys = inNames
+        .filter(name => outputGroups.contains(parentOf(name)))
+        .map(name => (name, parentOf(name)))
+        .toList
+      Some(outputKeys)
+    }
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/9623b252/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowWrappingMultiOutputCollector.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowWrappingMultiOutputCollector.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowWrappingMultiOutputCollector.scala
new file mode 100644
index 0000000..d551111
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowWrappingMultiOutputCollector.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime
+
+import org.apache.flink.table.runtime.types.CRow
+import org.apache.flink.types.Row
+import org.apache.flink.util.Collector
+
+/**
+  * A [[Collector]] for [[Row]]s that wraps every collected row into one reusable [[CRow]] and
+  * forwards that [[CRow]] to a downstream collector a configurable number of times.
+  *
+  * The downstream collector, the change flag of the emitted [[CRow]] and the repetition count
+  * are configured via `setCollector`, `setChange` and `setTimes`.
+  */
+class CRowWrappingMultiOutputCollector() extends Collector[Row] {
+
+  // downstream collector that receives the wrapped rows
+  private var out: Collector[CRow] = _
+  // single wrapper instance; its row is overwritten on every collect() call
+  private val outCRow: CRow = new CRow()
+  // number of times each collected row is forwarded; values <= 0 forward nothing
+  private var times: Long = 0L
+
+  def setCollector(collector: Collector[CRow]): Unit = {
+    out = collector
+  }
+
+  def setChange(change: Boolean): Unit = {
+    outCRow.change = change
+  }
+
+  def setTimes(times: Long): Unit = {
+    this.times = times
+  }
+
+  override def collect(record: Row): Unit = {
+    outCRow.row = record
+    var emitted: Long = 0L
+    while (emitted < times) {
+      out.collect(outCRow)
+      emitted += 1
+    }
+  }
+
+  override def close(): Unit = {
+    out.close()
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/9623b252/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowInnerJoin.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowInnerJoin.scala
new file mode 100644
index 0000000..841cd15
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowInnerJoin.scala
@@ -0,0 +1,286 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.flink.table.runtime.join
+
+import org.apache.flink.api.common.functions.FlatJoinFunction
+import org.apache.flink.api.common.state._
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.TupleTypeInfo
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction
+import org.apache.flink.table.api.{StreamQueryConfig, Types}
+import org.apache.flink.table.runtime.CRowWrappingMultiOutputCollector
+import org.apache.flink.table.runtime.types.CRow
+import org.apache.flink.types.Row
+import org.apache.flink.util.Collector
+import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
+import org.apache.flink.table.codegen.Compiler
+import org.apache.flink.table.util.Logging
+
+
+/**
+  * Joins the records of a left and a right stream without a time window. Only used for inner
+  * joins.
+  *
+  * The function uses keyed state, so both inputs must be keyed (`DataStreamJoin` keys them by
+  * their equi-join fields). For every key, the state of each side maps a row to a tuple of
+  * (number of times the row is currently present, expiration time of the row).
+  *
+  * An incoming record is added to the state of its own side. A retraction instead decrements
+  * the stored count. The record is then joined with every row stored in the state of the other
+  * side. Each joined row is emitted as often as the matching other-side row is present, and
+  * carries the change flag of the incoming record.
+  *
+  * If state cleaning is enabled, rows get an expiration time derived from the idle state
+  * retention times. Expired rows are removed by processing-time timers, or while the other
+  * side is being joined.
+  *
+  * @param leftType        the input type of the left stream
+  * @param rightType       the input type of the right stream
+  * @param resultType      the output type of the join (currently not used by this function)
+  * @param genJoinFuncName the name of the generated function that evaluates the remaining
+  *                        non-equi condition and emits the joined row
+  * @param genJoinFuncCode the code of the generated function that evaluates the remaining
+  *                        non-equi condition and emits the joined row
+  * @param queryConfig     the configuration for the query to generate; provides the idle state
+  *                        retention times
+  */
+class NonWindowInnerJoin(
+    leftType: TypeInformation[Row],
+    rightType: TypeInformation[Row],
+    resultType: TypeInformation[CRow],
+    genJoinFuncName: String,
+    genJoinFuncCode: String,
+    queryConfig: StreamQueryConfig)
+  extends CoProcessFunction[CRow, CRow, CRow]
+  with Compiler[FlatJoinFunction[Row, Row, Row]]
+  with Logging {
+
+  // left stream rows -> (number of occurrences, expiration time)
+  private var leftState: MapState[Row, JTuple2[Int, Long]] = _
+  // right stream rows -> (number of occurrences, expiration time)
+  private var rightState: MapState[Row, JTuple2[Int, Long]] = _
+  // wraps joined rows into CRows and emits them as often as the matched row occurs
+  private var cRowWrapper: CRowWrappingMultiOutputCollector = _
+
+  private val minRetentionTime: Long = queryConfig.getMinIdleStateRetentionTime
+  private val maxRetentionTime: Long = queryConfig.getMaxIdleStateRetentionTime
+  // state clean-up is only enabled if the minimum idle state retention time is greater than 1
+  private val stateCleaningEnabled: Boolean = minRetentionTime > 1
+
+  // state to record last timer of left stream, 0 means no timer
+  private var leftTimer: ValueState[Long] = _
+  // state to record last timer of right stream, 0 means no timer
+  private var rightTimer: ValueState[Long] = _
+
+  // generated function that evaluates the non-equi condition and emits the joined row
+  private var joinFunction: FlatJoinFunction[Row, Row, Row] = _
+
+  override def open(parameters: Configuration): Unit = {
+    LOG.debug(s"Compiling JoinFunction: $genJoinFuncName \n\n " +
+      s"Code:\n$genJoinFuncCode")
+    val clazz = compile(
+      getRuntimeContext.getUserCodeClassLoader,
+      genJoinFuncName,
+      genJoinFuncCode)
+    LOG.debug("Instantiating JoinFunction.")
+    joinFunction = clazz.newInstance()
+
+    // Initialize left and right state. The first element of the tuple counts how often the
+    // row is present, the second element is the expiration time of the row.
+    val tupleTypeInfo = new TupleTypeInfo[JTuple2[Int, Long]](Types.INT, Types.LONG)
+    val leftStateDescriptor = new MapStateDescriptor[Row, JTuple2[Int, Long]](
+      "left", leftType, tupleTypeInfo)
+    val rightStateDescriptor = new MapStateDescriptor[Row, JTuple2[Int, Long]](
+      "right", rightType, tupleTypeInfo)
+    leftState = getRuntimeContext.getMapState(leftStateDescriptor)
+    rightState = getRuntimeContext.getMapState(rightStateDescriptor)
+
+    // initialize timer state (descriptor names are part of the state layout; keep them stable)
+    val valueStateDescriptor1 = new ValueStateDescriptor[Long]("timervaluestate1", classOf[Long])
+    leftTimer = getRuntimeContext.getState(valueStateDescriptor1)
+    val valueStateDescriptor2 = new ValueStateDescriptor[Long]("timervaluestate2", classOf[Long])
+    rightTimer = getRuntimeContext.getState(valueStateDescriptor2)
+
+    cRowWrapper = new CRowWrappingMultiOutputCollector()
+  }
+
+  /**
+    * Processes a left stream record: updates the left state and joins the record with the
+    * right state.
+    *
+    * @param valueC The input value.
+    * @param ctx    The ctx to register timer or get current time
+    * @param out    The collector for returning result values.
+    */
+  override def processElement1(
+      valueC: CRow,
+      ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
+      out: Collector[CRow]): Unit = {
+
+    processElement(valueC, ctx, out, leftTimer, leftState, rightState, true)
+  }
+
+  /**
+    * Processes a right stream record: updates the right state and joins the record with the
+    * left state.
+    *
+    * @param valueC The input value.
+    * @param ctx    The ctx to register timer or get current time
+    * @param out    The collector for returning result values.
+    */
+  override def processElement2(
+      valueC: CRow,
+      ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
+      out: Collector[CRow]): Unit = {
+
+    processElement(valueC, ctx, out, rightTimer, rightState, leftState, false)
+  }
+
+  /**
+    * Called when a processing-time timer fires. Removes expired rows from the state of every
+    * side whose registered clean-up timer equals `timestamp`.
+    *
+    * @param timestamp The timestamp of the firing timer.
+    * @param ctx       The ctx to register timer or get current time
+    * @param out       The collector for returning result values.
+    */
+  override def onTimer(
+      timestamp: Long,
+      ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext,
+      out: Collector[CRow]): Unit = {
+
+    if (stateCleaningEnabled && leftTimer.value == timestamp) {
+      expireOutTimeRow(
+        timestamp,
+        leftState,
+        leftTimer,
+        ctx
+      )
+    }
+
+    if (stateCleaningEnabled && rightTimer.value == timestamp) {
+      expireOutTimeRow(
+        timestamp,
+        rightState,
+        rightTimer,
+        ctx
+      )
+    }
+  }
+
+  /**
+    * Computes the expiration time of a row that is accessed at `curProcessTime`.
+    *
+    * If state cleaning is enabled and the row would expire within `minRetentionTime`, the
+    * expiration time is moved to `curProcessTime + maxRetentionTime`. Otherwise the old
+    * expiration time is kept.
+    */
+  def getNewExpiredTime(
+      curProcessTime: Long,
+      oldExpiredTime: Long): Long = {
+
+    if (stateCleaningEnabled && curProcessTime + minRetentionTime > oldExpiredTime) {
+      curProcessTime + maxRetentionTime
+    } else {
+      oldExpiredTime
+    }
+  }
+
+  /**
+    * Puts or retracts an element from the input stream into state and searches the other state
+    * for records to join with. Records will be expired in state if state retention time has
+    * been specified.
+    */
+  def processElement(
+      value: CRow,
+      ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
+      out: Collector[CRow],
+      timerState: ValueState[Long],
+      currentSideState: MapState[Row, JTuple2[Int, Long]],
+      otherSideState: MapState[Row, JTuple2[Int, Long]],
+      isLeft: Boolean): Unit = {
+
+    val inputRow = value.row
+    cRowWrapper.setCollector(out)
+    // joined rows inherit the change flag (insert / retraction) of the incoming record
+    cRowWrapper.setChange(value.change)
+
+    val curProcessTime = ctx.timerService.currentProcessingTime
+    val oldCntAndExpiredTime = currentSideState.get(inputRow)
+    val cntAndExpiredTime = if (null == oldCntAndExpiredTime) {
+      JTuple2.of(0, -1L)
+    } else {
+      oldCntAndExpiredTime
+    }
+
+    cntAndExpiredTime.f1 = getNewExpiredTime(curProcessTime, cntAndExpiredTime.f1)
+    // register a clean-up timer only if this side currently has none (0 means no timer)
+    if (stateCleaningEnabled && timerState.value() == 0) {
+      timerState.update(cntAndExpiredTime.f1)
+      ctx.timerService().registerProcessingTimeTimer(cntAndExpiredTime.f1)
+    }
+
+    // update current side stream state
+    if (!value.change) {
+      cntAndExpiredTime.f0 = cntAndExpiredTime.f0 - 1
+      if (cntAndExpiredTime.f0 <= 0) {
+        currentSideState.remove(inputRow)
+      } else {
+        currentSideState.put(inputRow, cntAndExpiredTime)
+      }
+    } else {
+      cntAndExpiredTime.f0 = cntAndExpiredTime.f0 + 1
+      currentSideState.put(inputRow, cntAndExpiredTime)
+    }
+
+    // join with all rows of the other side
+    val otherSideIterator = otherSideState.iterator()
+    while (otherSideIterator.hasNext) {
+      val otherSideEntry = otherSideIterator.next()
+      val otherSideRow = otherSideEntry.getKey
+      val otherCntAndExpiredTime = otherSideEntry.getValue
+      // emit the joined row once per occurrence of the other side's row
+      cRowWrapper.setTimes(otherCntAndExpiredTime.f0)
+      if (isLeft) {
+        joinFunction.join(inputRow, otherSideRow, cRowWrapper)
+      } else {
+        joinFunction.join(otherSideRow, inputRow, cRowWrapper)
+      }
+      // clear expired data. Note: clear after join to keep closer to the original semantics
+      if (stateCleaningEnabled && curProcessTime >= otherCntAndExpiredTime.f1) {
+        otherSideIterator.remove()
+      }
+    }
+  }
+
+  /**
+    * Removes records which are expired from the state. Registers a new timer if the state still
+    * holds records after the clean-up; otherwise clears the state and the timer state.
+    */
+  private def expireOutTimeRow(
+      curTime: Long,
+      rowMapState: MapState[Row, JTuple2[Int, Long]],
+      timerState: ValueState[Long],
+      ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext): Unit = {
+
+    val rowMapIter = rowMapState.iterator()
+    var validTimestamp: Boolean = false
+
+    while (rowMapIter.hasNext) {
+      val mapEntry = rowMapIter.next()
+      val recordExpiredTime = mapEntry.getValue.f1
+      if (recordExpiredTime <= curTime) {
+        rowMapIter.remove()
+      } else {
+        // we found a timestamp that is still valid
+        validTimestamp = true
+      }
+    }
+
+    // If the state has non-expired timestamps, register a new timer.
+    // Otherwise clean the complete state for this input.
+    if (validTimestamp) {
+      val cleanupTime = curTime + maxRetentionTime
+      ctx.timerService.registerProcessingTimeTimer(cleanupTime)
+      timerState.update(cleanupTime)
+    } else {
+      timerState.clear()
+      rowMapState.clear()
+    }
+  }
+
+}
+
http://git-wip-us.apache.org/repos/asf/flink/blob/9623b252/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/validation/JoinValidationTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/validation/JoinValidationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/validation/JoinValidationTest.scala
index 9f7078c..141c817 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/validation/JoinValidationTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/validation/JoinValidationTest.scala
@@ -30,16 +30,6 @@ class JoinValidationTest extends TableTestBase {
streamUtil.addTable[(Int, String, Long)]("MyTable", 'a, 'b, 'c.rowtime, 'proctime.proctime)
streamUtil.addTable[(Int, String, Long)]("MyTable2", 'a, 'b, 'c.rowtime, 'proctime.proctime)
- /** There should exist time conditions **/
- @Test(expected = classOf[TableException])
- def testWindowJoinUnExistTimeCondition() = {
- val sql =
- """
- |SELECT t2.a
- |FROM MyTable t1 JOIN MyTable2 t2 ON t1.a = t2.a""".stripMargin
- streamUtil.verifySql(sql, "n/a")
- }
-
/** There should exist exactly two time conditions **/
@Test(expected = classOf[TableException])
def testWindowJoinSingleTimeCondition() = {
@@ -121,4 +111,29 @@ class JoinValidationTest extends TableTestBase {
streamUtil.verifySql(sql, "n/a")
}
+ /** Validates that no rowtime attribute is in the output schema for non-window inner join **/
+ @Test(expected = classOf[TableException])
+ def testNoRowtimeAttributeInResultForNonWindowInnerJoin(): Unit = {
+ // SELECT * keeps the rowtime attribute 'c of both MyTable and MyTable2 in the join result
+ val sql =
+ """
+ |SELECT *
+ |FROM MyTable t1, MyTable2 t2
+ |WHERE t1.a = t2.a
+ | """.stripMargin
+
+ streamUtil.verifySql(sql, "n/a")
+ }
+
+ /**
+ * Validates that no proctime attribute is in remaining predicate for non-window inner join.
+ *
+ * Both inputs are projected to non-rowtime fields before the join: MyTable and MyTable2 declare
+ * 'c as rowtime attribute, and with SELECT * the join would already be rejected because a
+ * rowtime attribute is part of its output, so the proctime predicate would never be checked.
+ */
+ @Test(expected = classOf[TableException])
+ def testNoProctimeAttributeInResultForNonWindowInnerJoin(): Unit = {
+ val sql =
+ """
+ |SELECT t1.a
+ |FROM (SELECT a, proctime FROM MyTable) t1,
+ | (SELECT a, proctime FROM MyTable2) t2
+ |WHERE t1.a = t2.a AND t1.proctime > t2.proctime
+ | """.stripMargin
+
+ streamUtil.verifySql(sql, "n/a")
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/9623b252/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/JoinValidationTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/JoinValidationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/JoinValidationTest.scala
index e924e6e..b354929 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/JoinValidationTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/JoinValidationTest.scala
@@ -19,14 +19,14 @@
package org.apache.flink.table.api.stream.table.validation
import org.apache.flink.api.scala._
-import org.apache.flink.table.api.TableException
+import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
+import org.apache.flink.table.api.{TableEnvironment, TableException, ValidationException}
import org.apache.flink.table.api.scala._
+import org.apache.flink.table.runtime.utils.StreamTestData
import org.apache.flink.table.utils.TableTestBase
+import org.apache.flink.types.Row
import org.junit.Test
-/**
- * Currently only time-windowed inner joins can be processed in a streaming fashion.
- */
class JoinValidationTest extends TableTestBase {
/**
@@ -47,6 +47,22 @@ class JoinValidationTest extends TableTestBase {
}
/**
+ * At least one equi-join predicate is required for a non-window inner join.
+ */
+ @Test(expected = classOf[TableException])
+ def testNonWindowInnerJoinWithoutEquiPredicate(): Unit = {
+ val util = streamTestUtil()
+ val left = util.addTable[(Long, Int, String)]('a, 'b, 'c)
+ val right = util.addTable[(Long, Int, String)]('d, 'e, 'f)
+
+ // join without any predicate, i.e. no equi-join predicate at all
+ val resultTable = left.join(right)
+ .select('a, 'e)
+
+ val expected = ""
+ util.verifyTable(resultTable, expected)
+ }
+
+ /**
* There must be complete window-bounds.
*/
@Test(expected = classOf[TableException])
@@ -92,4 +108,84 @@ class JoinValidationTest extends TableTestBase {
util.verifyTable(resultTable, "")
}
+
+
+ // Shared fixtures for the validation tests below. JUnit 4 creates a new instance of the test
+ // class per test method, so these tables are registered afresh for every test. The util is
+ // called `streamUtil` so that it is not shadowed by the local `util` values of other tests.
+ private val streamUtil = streamTestUtil()
+ private val ds1 = streamUtil.addTable[(Int, Long, String)]("Table3", 'a, 'b, 'c)
+ private val ds2 =
+ streamUtil.addTable[(Int, Long, Int, String, Long)]("Table5", 'd, 'e, 'f, 'g, 'h)
+
+ /** A join predicate that references an unknown field must be rejected. */
+ @Test(expected = classOf[ValidationException])
+ def testJoinNonExistingKey(): Unit = {
+ ds1.join(ds2)
+ // must fail. Field 'foo does not exist
+ .where('foo === 'e)
+ .select('c, 'g)
+ }
+
+ /** Equating fields of incompatible types (Int and String) must be rejected. */
+ @Test(expected = classOf[ValidationException])
+ def testJoinWithNonMatchingKeyTypes(): Unit = {
+ ds1.join(ds2)
+ // must fail. Field 'a is Int, and 'g is String
+ .where('a === 'g)
+ .select('c, 'g)
+ }
+
+
+ /** Joining two inputs that expose the same field name ('c) must be rejected. */
+ @Test(expected = classOf[ValidationException])
+ def testJoinWithAmbiguousFields(): Unit = {
+ ds1.join(ds2.select('d, 'e, 'f, 'g, 'h as 'c))
+ // must fail. Both inputs share the same field 'c
+ .where('a === 'd)
+ .select('c, 'g)
+ }
+
+ /**
+ * The only predicate ('d === 'f) references fields of the right input only, so there is no
+ * equi-join predicate between the two inputs.
+ */
+ @Test(expected = classOf[TableException])
+ def testNoEqualityJoinPredicate1(): Unit = {
+ ds1.join(ds2)
+ // must fail. No equality join predicate
+ .where('d === 'f)
+ .select('c, 'g)
+ // presumably the exception is raised when the plan is translated here
+ .toRetractStream[Row]
+ }
+
+ /** The only predicate between the inputs ('a < 'd) is not an equality predicate. */
+ @Test(expected = classOf[TableException])
+ def testNoEqualityJoinPredicate2(): Unit = {
+ ds1.join(ds2)
+ // must fail. No equality join predicate
+ .where('a < 'd)
+ .select('c, 'g)
+ // presumably the exception is raised when the plan is translated here
+ .toRetractStream[Row]
+ }
+
+ /** A join given only a non-equality condition ('b < 'd) is expected to fail validation. */
+ @Test(expected = classOf[ValidationException])
+ def testNoEquiJoin(): Unit = {
+ ds2.join(ds1, 'b < 'd).select('c, 'g)
+ }
+
+ /** Joining tables that are bound to different TableEnvironments must be rejected. */
+ @Test(expected = classOf[ValidationException])
+ def testJoinTablesFromDifferentEnvs(): Unit = {
+ val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv1 = TableEnvironment.getTableEnvironment(env)
+ val tEnv2 = TableEnvironment.getTableEnvironment(env)
+ // local names differ from the ds1/ds2 fields of this class to avoid shadowing them
+ val stream3 = StreamTestData.get3TupleDataStream(env)
+ val stream5 = StreamTestData.get5TupleDataStream(env)
+ val in1 = tEnv1.fromDataStream(stream3, 'a, 'b, 'c)
+ // field names are disjoint from in1 ('h instead of 'c) so that an ambiguous-field
+ // ValidationException cannot mask a missing TableEnvironment check
+ val in2 = tEnv2.fromDataStream(stream5, 'd, 'e, 'f, 'g, 'h)
+
+ // Must fail. Tables are bound to different TableEnvironments.
+ in1.join(in2).where('b === 'e).select('c, 'g)
+ }
+
+ /**
+ * Joining tables that are bound to different TableEnvironments must be rejected when the
+ * join is expressed with string expressions (Java API style).
+ */
+ @Test(expected = classOf[ValidationException])
+ def testJoinTablesFromDifferentEnvsJava(): Unit = {
+ val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv1 = TableEnvironment.getTableEnvironment(env)
+ val tEnv2 = TableEnvironment.getTableEnvironment(env)
+ // local names differ from the ds1/ds2 fields of this class to avoid shadowing them
+ val stream3 = StreamTestData.get3TupleDataStream(env)
+ val stream5 = StreamTestData.get5TupleDataStream(env)
+ val in1 = tEnv1.fromDataStream(stream3, 'a, 'b, 'c)
+ // field names are disjoint from in1 ('h instead of 'c) so that an ambiguous-field
+ // ValidationException cannot mask a missing TableEnvironment check
+ val in2 = tEnv2.fromDataStream(stream5, 'd, 'e, 'f, 'g, 'h)
+ // Must fail. Tables are bound to different TableEnvironments.
+ in1.join(in2).where("a === d").select("g.count")
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/9623b252/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RetractionRulesTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RetractionRulesTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RetractionRulesTest.scala
index ba3c314..3541f9f 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RetractionRulesTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RetractionRulesTest.scala
@@ -269,8 +269,42 @@ class RetractionRulesTest extends TableTestBase {
util.verifyTableTrait(resultTable, expected)
}
+
+ /**
+ * Checks the traits derived for a non-window join whose left input is a group aggregation
+ * (grouped by 'a) and whose right input is a plain table scan.
+ */
+ @Test
+ def testJoin(): Unit = {
+ val util = streamTestForRetractionUtil()
+ val lTable = util.addTable[(Int, Int)]('a, 'b)
+ val rTable = util.addTable[(Int, Int)]('bb, 'c)
+
+ val lTableWithPk = lTable
+ .groupBy('a)
+ .select('a, 'b.max as 'b)
+
+ val resultTable = lTableWithPk
+ .join(rTable)
+ .where('b === 'bb)
+ .select('a, 'b, 'c)
+
+ // expected plan tree; each node carries the trait string compared by verifyTableTrait
+ // (presumably "<update as retraction>, <accumulating mode>" - confirm in the util)
+ val expected =
+ unaryNode(
+ "DataStreamCalc",
+ binaryNode(
+ "DataStreamJoin",
+ unaryNode(
+ "DataStreamGroupAggregate",
+ "DataStreamScan(true, Acc)",
+ "true, AccRetract"
+ ),
+ "DataStreamScan(true, Acc)",
+ "false, AccRetract"
+ ),
+ "false, AccRetract"
+ )
+ util.verifyTableTrait(resultTable, expected)
+ }
}
+
class StreamTableTestForRetractionUtil extends StreamTableTestUtil {
def verifySqlTrait(query: String, expected: String): Unit = {
http://git-wip-us.apache.org/repos/asf/flink/blob/9623b252/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/UpdatingPlanCheckerTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/UpdatingPlanCheckerTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/UpdatingPlanCheckerTest.scala
new file mode 100644
index 0000000..6fb19fc
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/UpdatingPlanCheckerTest.scala
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan
+
+import org.apache.flink.table.api.Table
+import org.apache.flink.table.plan.util.UpdatingPlanChecker
+import org.apache.flink.table.utils.StreamTableTestUtil
+import org.junit.Assert._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.api.scala._
+import org.junit.Test
+
+
+/**
+ * Tests for [[UpdatingPlanChecker.getUniqueKeyFields]]. Each test builds a streaming Table API
+ * query and checks which output fields are reported as unique key fields; an empty expected
+ * list means that no unique key is expected.
+ */
+class UpdatingPlanCheckerTest {
+
+ @Test
+ def testSelect(): Unit = {
+ // a plain projection of the source table is expected to have no unique key
+ val util = new UpdatePlanCheckerUtil()
+ val table = util.addTable[(String, Int)]('a, 'b)
+ val resultTable = table.select('a, 'b)
+
+ util.verifyTableUniqueKey(resultTable, Nil)
+ }
+
+ @Test
+ def testGroupByWithoutKey(): Unit = {
+ // the grouping key 'a is not projected, so no unique key is expected
+ val util = new UpdatePlanCheckerUtil()
+ val table = util.addTable[(String, Int)]('a, 'b)
+
+ val resultTable = table
+ .groupBy('a)
+ .select('b.count)
+
+ util.verifyTableUniqueKey(resultTable, Nil)
+ }
+
+ @Test
+ def testGroupByWithoutKey2(): Unit = {
+ // only part of the grouping key ('a of 'a, 'b) is projected, so no unique key is expected
+ val util = new UpdatePlanCheckerUtil()
+ val table = util.addTable[(String, Int, Int)]('a, 'b, 'c)
+
+ val resultTable = table
+ .groupBy('a, 'b)
+ .select('a, 'c.count)
+
+ util.verifyTableUniqueKey(resultTable, Nil)
+ }
+
+ @Test
+ def testGroupBy(): Unit = {
+ // the projected grouping key 'a is expected to be the unique key
+ val util = new UpdatePlanCheckerUtil()
+ val table = util.addTable[(String, Int)]('a, 'b)
+
+ val resultTable = table
+ .groupBy('a)
+ .select('a, 'b.count)
+
+ util.verifyTableUniqueKey(resultTable, Seq("a"))
+ }
+
+ @Test
+ def testGroupByWithDuplicateKey(): Unit = {
+ // both aliases of the grouping key are expected as key fields
+ val util = new UpdatePlanCheckerUtil()
+ val table = util.addTable[(String, Int)]('a, 'b)
+
+ val resultTable = table
+ .groupBy('a)
+ .select('a as 'a1, 'a as 'a2, 'b.count)
+
+ util.verifyTableUniqueKey(resultTable, Seq("a1", "a2"))
+ }
+
+ @Test
+ def testGroupWindow(): Unit = {
+ // for a tumbling window, the grouping key and the window start are expected as key fields
+ val util = new UpdatePlanCheckerUtil()
+ val table = util.addTable[(String, Int)]('a, 'b, 'proctime.proctime)
+
+ val resultTable = table
+ .window(Tumble over 5.milli on 'proctime as 'w)
+ .groupBy('w, 'a)
+ .select('a, 'b.count, 'w.start as 'start)
+
+ util.verifyTableUniqueKey(resultTable, Seq("a", "start"))
+ }
+
+ @Test
+ def testForwardBothKeysForJoin1(): Unit = {
+ // expected: the key aliases of both inputs (l1-l3, r2, r3) plus 'l4 and 'r5, which the join
+ // predicate equates with the key field 'r3 directly ('l4) or via 'l4 ('r5)
+ val util = new UpdatePlanCheckerUtil()
+ val table = util.addTable[(Int, Int)]('pk, 'a)
+
+ val lTableWithPk = table
+ .groupBy('pk)
+ .select('pk as 'l1, 'pk as 'l2, 'pk as 'l3, 'a.max as 'l4, 'a.min as 'l5)
+
+ val rTableWithPk = table
+ .groupBy('pk)
+ .select('pk as 'r2, 'pk as 'r3, 'a.max as 'r1, 'a.min as 'r4, 'a.count as 'r5)
+
+ val resultTable = lTableWithPk
+ .join(rTableWithPk)
+ .where('l2 === 'r2 && 'l4 === 'r3 && 'l4 === 'r5 && 'l5 === 'r4)
+ .select('l1, 'l2, 'l3, 'l4, 'l5, 'r1, 'r2, 'r3, 'r4, 'r5)
+
+ util.verifyTableUniqueKey(resultTable, Seq("l1", "l2", "l3", "l4", "r2", "r3", "r5"))
+ }
+
+ @Test
+ def testForwardBothKeysForJoin2(): Unit = {
+ // the only join predicate equates two non-key fields, so only the key aliases of both
+ // inputs are expected
+ val util = new UpdatePlanCheckerUtil()
+ val table = util.addTable[(Int, Int)]('pk, 'a)
+
+ val lTableWithPk = table
+ .groupBy('pk)
+ .select('pk as 'l1, 'pk as 'l2, 'pk as 'l3, 'a.max as 'l4, 'a.min as 'l5)
+
+ val rTableWithPk = table
+ .groupBy('pk)
+ .select('pk as 'r2, 'pk as 'r3, 'a.max as 'r1, 'a.min as 'r4, 'a.count as 'r5)
+
+ val resultTable = lTableWithPk
+ .join(rTableWithPk)
+ .where('l5 === 'r4)
+ .select('l1, 'l2, 'l3, 'l4, 'l5, 'r1, 'r2, 'r3, 'r4, 'r5)
+
+ util.verifyTableUniqueKey(resultTable, Seq("l1", "l2", "l3", "r2", "r3"))
+ }
+
+ @Test
+ def testJoinKeysEqualsLeftKeys(): Unit = {
+ // the left key 'leftpk is equated with 'righta; expected keys: 'rightpk and 'righta
+ val util = new UpdatePlanCheckerUtil()
+ val table = util.addTable[(Int, Int)]('pk, 'a)
+
+ val lTableWithPk = table
+ .groupBy('pk)
+ .select('pk as 'leftpk, 'a.max as 'lefta)
+
+ val rTableWithPk = table
+ .groupBy('pk)
+ .select('pk as 'rightpk, 'a.max as 'righta)
+
+ val resultTable = lTableWithPk
+ .join(rTableWithPk)
+ .where('leftpk === 'righta)
+ .select('rightpk, 'lefta, 'righta)
+
+ util.verifyTableUniqueKey(resultTable, Seq("rightpk", "righta"))
+ }
+
+ @Test
+ def testJoinKeysEqualsRightKeys(): Unit = {
+ // the right key 'rightpk is equated with 'lefta; expected keys: 'leftpk and 'lefta
+ val util = new UpdatePlanCheckerUtil()
+ val table = util.addTable[(Int, Int)]('pk, 'a)
+
+ val lTableWithPk = table
+ .groupBy('pk)
+ .select('pk as 'leftpk, 'a.max as 'lefta)
+
+ val rTableWithPk = table
+ .groupBy('pk)
+ .select('pk as 'rightpk, 'a.max as 'righta)
+
+ val resultTable = lTableWithPk
+ .join(rTableWithPk)
+ .where('lefta === 'rightpk)
+ .select('leftpk, 'lefta, 'righta)
+
+ util.verifyTableUniqueKey(resultTable, Seq("leftpk", "lefta"))
+ }
+
+
+ @Test
+ def testNonKeysJoin(): Unit = {
+ // neither input is aggregated, so the join result is expected to have no unique key
+ val util = new UpdatePlanCheckerUtil()
+ val table = util.addTable[(Int, Int)]('a, 'b)
+
+ val lTable = table
+ .select('a as 'a, 'b as 'b)
+
+ val rTable = table
+ .select('a as 'aa, 'b as 'bb)
+
+ val resultTable = lTable
+ .join(rTable)
+ .where('a === 'aa)
+ .select('a, 'aa, 'b, 'bb)
+
+ util.verifyTableUniqueKey(resultTable, Nil)
+ }
+}
+
+
+/**
+ * Test utility that optimizes a query and compares the unique key fields derived for its result
+ * by [[UpdatingPlanChecker.getUniqueKeyFields]] with an expected list of field names.
+ */
+class UpdatePlanCheckerUtil extends StreamTableTestUtil {
+
+ /** Verifies the unique key fields of the result of the given SQL query. */
+ def verifySqlUniqueKey(query: String, expected: Seq[String]): Unit = {
+ verifyTableUniqueKey(tableEnv.sql(query), expected)
+ }
+
+ /**
+ * Optimizes `resultTable` with `updatesAsRetraction = false` and asserts that the derived
+ * unique key fields equal `expected`, ignoring order. If no unique key is derived, `expected`
+ * must be empty.
+ */
+ def verifyTableUniqueKey(resultTable: Table, expected: Seq[String]): Unit = {
+ val relNode = resultTable.getRelNode
+ val optimized = tableEnv.optimize(relNode, updatesAsRetraction = false)
+ // a missing unique key is compared as an empty list of key fields
+ val actualKeys = UpdatingPlanChecker.getUniqueKeyFields(optimized)
+ .map(_.toList.sorted)
+ .getOrElse(Nil)
+
+ assertEquals(expected.sorted, actualKeys)
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/9623b252/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala
index facdbd4..0407496 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala
@@ -17,24 +17,34 @@
*/
package org.apache.flink.table.runtime.harness
-import java.lang.{Long => JLong}
+import java.lang.{Long => JLong, Integer => JInt}
import java.util.concurrent.ConcurrentLinkedQueue
+import org.apache.flink.api.common.time.Time
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.streaming.api.operators.co.KeyedCoProcessOperator
import org.apache.flink.streaming.api.watermark.Watermark
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness
import org.apache.flink.table.api.Types
-import org.apache.flink.table.runtime.harness.HarnessTestBase.{RowResultSortComparator, RowResultSortComparatorWithWatermarks, TupleRowKeySelector}
+import org.apache.flink.table.runtime.harness.HarnessTestBase.RowResultSortComparatorWithWatermarks
import org.apache.flink.table.runtime.join.{ProcTimeBoundedStreamInnerJoin, RowTimeBoundedStreamInnerJoin}
import org.apache.flink.table.runtime.operators.KeyedCoProcessOperatorWithWatermarkDelay
-import org.apache.flink.table.runtime.types.CRow
+import org.apache.flink.table.api.StreamQueryConfig
+import org.apache.flink.table.runtime.harness.HarnessTestBase.{RowResultSortComparator, TupleRowKeySelector}
+import org.apache.flink.table.runtime.join.NonWindowInnerJoin
+import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
import org.apache.flink.types.Row
-import org.junit.Assert.assertEquals
+import org.junit.Assert.{assertEquals, assertTrue}
import org.junit.Test
class JoinHarnessTest extends HarnessTestBase {
+ // Idle state retention used by the non-window join harness tests: 2 ms minimum, 4 ms maximum.
+ protected val queryConfig: StreamQueryConfig =
+ new StreamQueryConfig().withIdleStateRetentionTime(Time.milliseconds(2), Time.milliseconds(4))
+
private val rowType = Types.ROW(
Types.LONG,
Types.STRING)
@@ -243,7 +253,7 @@ class JoinHarnessTest extends HarnessTestBase {
testHarness.close()
}
- /** a.c1 >= b.rowtime - 10 and a.rowtime <= b.rowtime + 20 **/
+ /** a.rowtime >= b.rowtime - 10 and a.rowtime <= b.rowtime + 20 **/
@Test
def testRowTimeJoinWithCommonBounds() {
@@ -423,4 +433,208 @@ class JoinHarnessTest extends HarnessTestBase {
checkWaterMark = true)
testHarness.close()
}
+
+ /**
+ * Tests [[NonWindowInnerJoin]] with the idle state retention of `queryConfig`
+ * (2 ms min / 4 ms max). Both inputs are keyed on their first field. The test tracks the
+ * number of keyed state entries and processing-time timers while processing time advances,
+ * until all state has been cleaned up.
+ *
+ * NOTE(review): `rowType` declares (LONG, STRING) for both inputs, but the input rows below
+ * carry an Integer in the first field (`1: JInt`). This presumably only works because the
+ * harness keeps state on the heap without serializing it - confirm, and consider an
+ * (INT, STRING) input type.
+ */
+ @Test
+ def testNonWindowInnerJoin() {
+
+ // result row: the fields of the left row followed by the fields of the right row
+ val joinReturnType = CRowTypeInfo(new RowTypeInfo(
+ Array[TypeInformation[_]](
+ INT_TYPE_INFO,
+ STRING_TYPE_INFO,
+ INT_TYPE_INFO,
+ STRING_TYPE_INFO),
+ Array("a", "b", "c", "d")))
+
+ val joinProcessFunc = new NonWindowInnerJoin(
+ rowType,
+ rowType,
+ joinReturnType,
+ "TestJoinFunction",
+ funcCode,
+ queryConfig)
+
+ val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] =
+ new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc)
+ // both inputs are keyed on field 0
+ val testHarness: KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow] =
+ new KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow](
+ operator,
+ new TupleRowKeySelector[Integer](0),
+ new TupleRowKeySelector[Integer](0),
+ BasicTypeInfo.INT_TYPE_INFO,
+ 1, 1, 0)
+
+ testHarness.open()
+
+ // left stream input: three rows with key 1 and one row with key 2
+ testHarness.setProcessingTime(1)
+ testHarness.processElement1(new StreamRecord(
+ CRow(Row.of(1: JInt, "aaa"), true)))
+ assertEquals(1, testHarness.numProcessingTimeTimers())
+ assertEquals(2, testHarness.numKeyedStateEntries())
+ testHarness.setProcessingTime(2)
+ testHarness.processElement1(new StreamRecord(
+ CRow(Row.of(1: JInt, "aaa"), true)))
+ testHarness.processElement1(new StreamRecord(
+ CRow(Row.of(2: JInt, "bbb"), true)))
+ assertEquals(2, testHarness.numProcessingTimeTimers())
+ assertEquals(4, testHarness.numKeyedStateEntries())
+ testHarness.setProcessingTime(3)
+ testHarness.processElement1(new StreamRecord(
+ CRow(Row.of(1: JInt, "aaa"), true)))
+ assertEquals(4, testHarness.numKeyedStateEntries())
+ assertEquals(2, testHarness.numProcessingTimeTimers())
+
+ // right stream input; each right row is joined with the stored left rows of its key
+ testHarness.processElement2(new StreamRecord(
+ CRow(Row.of(1: JInt, "Hi1"), true)))
+ assertEquals(6, testHarness.numKeyedStateEntries())
+ assertEquals(3, testHarness.numProcessingTimeTimers())
+ testHarness.setProcessingTime(4)
+ testHarness.processElement2(new StreamRecord(
+ CRow(Row.of(2: JInt, "Hello1"), true)))
+ assertEquals(8, testHarness.numKeyedStateEntries())
+ assertEquals(4, testHarness.numProcessingTimeTimers())
+
+ // the left stream records with key 1 have expired, so "Hi2" produces no output
+ testHarness.setProcessingTime(5)
+ testHarness.processElement2(new StreamRecord(
+ CRow(Row.of(1: JInt, "Hi2"), true)))
+ assertEquals(6, testHarness.numKeyedStateEntries())
+ assertEquals(3, testHarness.numProcessingTimeTimers())
+
+ // all left stream records have expired
+ testHarness.setProcessingTime(6)
+ assertEquals(4, testHarness.numKeyedStateEntries())
+ assertEquals(2, testHarness.numProcessingTimeTimers())
+
+ // the right stream record with key 2 has expired
+ testHarness.setProcessingTime(8)
+ assertEquals(2, testHarness.numKeyedStateEntries())
+ assertEquals(1, testHarness.numProcessingTimeTimers())
+
+ testHarness.setProcessingTime(10)
+ assertTrue(testHarness.numKeyedStateEntries() > 0)
+ // all right stream records have expired; no state and no timers remain
+ testHarness.setProcessingTime(11)
+ assertEquals(0, testHarness.numKeyedStateEntries())
+ assertEquals(0, testHarness.numProcessingTimeTimers())
+
+ val result = testHarness.getOutput
+
+ // (1, "aaa") was stored three times, so "Hi1" is joined three times
+ val expectedOutput = new ConcurrentLinkedQueue[Object]()
+
+ expectedOutput.add(new StreamRecord(
+ CRow(Row.of(1: JInt, "aaa", 1: JInt, "Hi1"), true)))
+ expectedOutput.add(new StreamRecord(
+ CRow(Row.of(1: JInt, "aaa", 1: JInt, "Hi1"), true)))
+ expectedOutput.add(new StreamRecord(
+ CRow(Row.of(1: JInt, "aaa", 1: JInt, "Hi1"), true)))
+ expectedOutput.add(new StreamRecord(
+ CRow(Row.of(2: JInt, "bbb", 2: JInt, "Hello1"), true)))
+
+ verify(expectedOutput, result, new RowResultSortComparator())
+
+ testHarness.close()
+ }
+
+
+ /**
+ * Tests [[NonWindowInnerJoin]] with retraction messages (CRow change flag `false`) on both
+ * inputs. It uses the idle state retention of `queryConfig` (2 ms min / 4 ms max), and both
+ * inputs are keyed on their first field. The test tracks keyed state entries and
+ * processing-time timers until all state has been cleaned up.
+ */
+ @Test
+ def testNonWindowInnerJoinWithRetract() {
+
+ // result row: the fields of the left row followed by the fields of the right row
+ val joinReturnType = CRowTypeInfo(new RowTypeInfo(
+ Array[TypeInformation[_]](
+ INT_TYPE_INFO,
+ STRING_TYPE_INFO,
+ INT_TYPE_INFO,
+ STRING_TYPE_INFO),
+ Array("a", "b", "c", "d")))
+
+ val joinProcessFunc = new NonWindowInnerJoin(
+ rowType,
+ rowType,
+ joinReturnType,
+ "TestJoinFunction",
+ funcCode,
+ queryConfig)
+
+ val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] =
+ new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc)
+ // both inputs are keyed on field 0
+ val testHarness: KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow] =
+ new KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow](
+ operator,
+ new TupleRowKeySelector[Integer](0),
+ new TupleRowKeySelector[Integer](0),
+ BasicTypeInfo.INT_TYPE_INFO,
+ 1, 1, 0)
+
+ testHarness.open()
+
+ // left stream input
+ testHarness.setProcessingTime(1)
+ testHarness.processElement1(new StreamRecord(
+ CRow(Row.of(1: JInt, "aaa"), true)))
+ assertEquals(1, testHarness.numProcessingTimeTimers())
+ assertEquals(2, testHarness.numKeyedStateEntries())
+ testHarness.setProcessingTime(2)
+ testHarness.processElement1(new StreamRecord(
+ CRow(Row.of(1: JInt, "aaa"), true)))
+ testHarness.processElement1(new StreamRecord(
+ CRow(Row.of(2: JInt, "bbb"), true)))
+ assertEquals(2, testHarness.numProcessingTimeTimers())
+ assertEquals(4, testHarness.numKeyedStateEntries())
+ // retract one of the two (1, "aaa") rows; the number of state entries does not change,
+ // presumably because the row is still stored once with a decremented count
+ testHarness.setProcessingTime(3)
+ testHarness.processElement1(new StreamRecord(
+ CRow(Row.of(1: JInt, "aaa"), false)))
+ assertEquals(4, testHarness.numKeyedStateEntries())
+ assertEquals(2, testHarness.numProcessingTimeTimers())
+
+ // right stream: insert and retract (1, "Hi1"); joined with the remaining (1, "aaa") this
+ // emits one insert and one retraction
+ testHarness.processElement2(new StreamRecord(
+ CRow(Row.of(1: JInt, "Hi1"), true)))
+ testHarness.processElement2(new StreamRecord(
+ CRow(Row.of(1: JInt, "Hi1"), false)))
+ assertEquals(5, testHarness.numKeyedStateEntries())
+ assertEquals(3, testHarness.numProcessingTimeTimers())
+ testHarness.setProcessingTime(4)
+ testHarness.processElement2(new StreamRecord(
+ CRow(Row.of(2: JInt, "Hello1"), true)))
+ assertEquals(7, testHarness.numKeyedStateEntries())
+ assertEquals(4, testHarness.numProcessingTimeTimers())
+
+ // retract the remaining (1, "aaa") row
+ testHarness.processElement1(new StreamRecord(
+ CRow(Row.of(1: JInt, "aaa"), false)))
+ // no left stream record with key 1 is left, so "Hi2" produces no output
+ testHarness.setProcessingTime(5)
+ testHarness.processElement2(new StreamRecord(
+ CRow(Row.of(1: JInt, "Hi2"), true)))
+ testHarness.processElement2(new StreamRecord(
+ CRow(Row.of(1: JInt, "Hi2"), false)))
+ assertEquals(5, testHarness.numKeyedStateEntries())
+ assertEquals(3, testHarness.numProcessingTimeTimers())
+
+ // all left stream records have expired
+ testHarness.setProcessingTime(6)
+ assertEquals(3, testHarness.numKeyedStateEntries())
+ assertEquals(2, testHarness.numProcessingTimeTimers())
+
+ // the right stream record with key 2 has expired; no state and no timers remain
+ testHarness.setProcessingTime(8)
+ assertEquals(0, testHarness.numKeyedStateEntries())
+ assertEquals(0, testHarness.numProcessingTimeTimers())
+
+ val result = testHarness.getOutput
+
+ val expectedOutput = new ConcurrentLinkedQueue[Object]()
+
+ expectedOutput.add(new StreamRecord(
+ CRow(Row.of(1: JInt, "aaa", 1: JInt, "Hi1"), true)))
+ expectedOutput.add(new StreamRecord(
+ CRow(Row.of(1: JInt, "aaa", 1: JInt, "Hi1"), false)))
+ expectedOutput.add(new StreamRecord(
+ CRow(Row.of(2: JInt, "bbb", 2: JInt, "Hello1"), true)))
+
+ verify(expectedOutput, result, new RowResultSortComparator())
+
+ testHarness.close()
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/9623b252/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/NonWindowHarnessTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/NonWindowHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/NonWindowHarnessTest.scala
index dd14d7e..ad50761 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/NonWindowHarnessTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/NonWindowHarnessTest.scala
@@ -37,7 +37,7 @@ class NonWindowHarnessTest extends HarnessTestBase {
new StreamQueryConfig().withIdleStateRetentionTime(Time.seconds(2), Time.seconds(3))
@Test
- def testProcTimeNonWindow(): Unit = {
+ def testNonWindow(): Unit = {
val processFunction = new KeyedProcessOperator[String, CRow, CRow](
new GroupAggProcessFunction(
@@ -97,7 +97,7 @@ class NonWindowHarnessTest extends HarnessTestBase {
}
@Test
- def testProcTimeNonWindowWithRetract(): Unit = {
+ def testNonWindowWithRetract(): Unit = {
val processFunction = new KeyedProcessOperator[String, CRow, CRow](
new GroupAggProcessFunction(
http://git-wip-us.apache.org/repos/asf/flink/blob/9623b252/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala
index 85929e8..1c00521 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala
@@ -30,6 +30,7 @@ import org.apache.flink.table.api.{TableEnvironment, Types}
import org.apache.flink.table.expressions.Null
import org.apache.flink.table.runtime.utils.{StreamITCase, StreamingWithStateTestBase}
import org.apache.flink.types.Row
+import org.junit.Assert._
import org.junit._
import scala.collection.mutable
@@ -461,6 +462,60 @@ class JoinITCase extends StreamingWithStateTestBase {
StreamITCase.compareWithList(expected)
}
+ /** Tests a non-window inner join with an equi-join predicate and a non-equi predicate. **/
+ @Test
+ def testNonWindowInnerJoin(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ env.setStateBackend(getStateBackend)
+ StreamITCase.clear
+
+ val data1 = new mutable.MutableList[(Int, Long, String)]
+ data1.+=((1, 1L, "Hi1"))
+ data1.+=((1, 2L, "Hi2"))
+ data1.+=((1, 2L, "Hi2"))
+ data1.+=((1, 5L, "Hi3"))
+ data1.+=((2, 7L, "Hi5"))
+ data1.+=((1, 9L, "Hi6"))
+ data1.+=((1, 8L, "Hi8"))
+ data1.+=((3, 8L, "Hi9"))
+
+ val data2 = new mutable.MutableList[(Int, Long, String)]
+ data2.+=((1, 1L, "HiHi"))
+ data2.+=((2, 2L, "HeHe"))
+ data2.+=((3, 2L, "HeHe"))
+
+ // in both inputs, key value 3 is replaced by NULL
+ val t1 = env.fromCollection(data1).toTable(tEnv, 'a, 'b, 'c)
+ .select(('a === 3) ? (Null(Types.INT), 'a) as 'a, 'b, 'c)
+ val t2 = env.fromCollection(data2).toTable(tEnv, 'a, 'b, 'c)
+ .select(('a === 3) ? (Null(Types.INT), 'a) as 'a, 'b, 'c)
+
+ tEnv.registerTable("T1", t1)
+ tEnv.registerTable("T2", t2)
+
+ val sqlQuery =
+ """
+ |SELECT t2.a, t2.c, t1.c
+ |FROM T1 as t1 JOIN T2 as t2 ON
+ | t1.a = t2.a AND
+ | t1.b > t2.b
+ |""".stripMargin
+
+ val result = tEnv.sql(sqlQuery).toAppendStream[Row]
+ result.addSink(new StreamITCase.StringSink[Row])
+ env.execute()
+
+ // NOTE(review): the last expected row joins the NULL key of T1 ("Hi9") with the NULL key of
+ // T2 ("HeHe"). In SQL, NULL = NULL is not TRUE, so a standard inner join would not produce
+ // this row - confirm that this expectation (and the join's null-key handling) is intended.
+ val expected = mutable.MutableList(
+ "1,HiHi,Hi2",
+ "1,HiHi,Hi2",
+ "1,HiHi,Hi3",
+ "1,HiHi,Hi6",
+ "1,HiHi,Hi8",
+ "2,HeHe,Hi5",
+ "null,HeHe,Hi9")
+
+ assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+ }
}
private class Row4WatermarkExtractor
http://git-wip-us.apache.org/repos/asf/flink/blob/9623b252/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/JoinITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/JoinITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/JoinITCase.scala
new file mode 100644
index 0000000..8916c82
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/JoinITCase.scala
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.stream.table
+
+import org.apache.flink.api.scala._
+import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
+import org.apache.flink.table.api.{StreamQueryConfig, TableEnvironment, TableException}
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.runtime.utils.{StreamITCase, StreamingWithStateTestBase}
+import org.junit.Assert._
+import org.junit.Test
+import org.apache.flink.api.common.time.Time
+import org.apache.flink.table.functions.aggfunctions.CountAggFunction
+import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinct, WeightedAvg}
+import org.apache.flink.types.Row
+
+import scala.collection.mutable
+
+/**
+  * Integration tests for stream-stream joins of the Table API whose join predicates contain no
+  * time bounds.
+  *
+  * Outer joins are expected to be rejected with a [[TableException]] when the joined table is
+  * translated into a DataStream.
+  */
+class JoinITCase extends StreamingWithStateTestBase {
+
+  // Idle state retention (min 1 hour, max 2 hours) for the stateful operators under test.
+  private val queryConfig = new StreamQueryConfig()
+  queryConfig.withIdleStateRetentionTime(Time.hours(1), Time.hours(2))
+
+  /**
+    * Creates the execution and table environments used by a test, configures the test state
+    * backend and clears results collected by [[StreamITCase]] in previous tests.
+    */
+  private def createEnvironments() = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    env.setStateBackend(getStateBackend)
+    StreamITCase.clear
+    (env, tEnv)
+  }
+
+  /** Creates the two single-row input tables joined by the unsupported outer join tests. */
+  private def createOuterJoinInputs() = {
+    val (env, tEnv) = createEnvironments()
+    val leftTable = env.fromCollection(List((1, 2))).toTable(tEnv, 'a, 'b)
+    val rightTable = env.fromCollection(List((1, 2))).toTable(tEnv, 'bb, 'c)
+    (leftTable, rightTable)
+  }
+
+  @Test
+  def testOutputWithPk(): Unit = {
+    // data input
+    val data1 = List(
+      (0, 0),
+      (1, 0),
+      (1, 1),
+      (2, 2),
+      (3, 3),
+      (4, 4),
+      (5, 4),
+      (5, 5),
+      (5, null),
+      (6, null)
+    )
+
+    val data2 = List(
+      (0L, 0),
+      (1L, 1),
+      (2L, 0),
+      (2L, 1),
+      (2L, 2),
+      (3L, 3),
+      (4L, 4),
+      (5L, 4),
+      (5L, 5),
+      (6L, 6),
+      (7L, null),
+      (8L, null)
+    )
+
+    val (env, tEnv) = createEnvironments()
+
+    val leftTable = env.fromCollection(data1).toTable(tEnv, 'a, 'b)
+    val rightTable = env.fromCollection(data2).toTable(tEnv, 'bb, 'c)
+
+    // group both inputs so that each side is keyed ('a on the left, 'bb on the right)
+    val leftTableWithPk = leftTable
+      .groupBy('a)
+      .select('a, 'b.count as 'b)
+
+    val rightTableWithPk = rightTable
+      .groupBy('bb)
+      .select('bb, 'c.count as 'c)
+
+    leftTableWithPk
+      .join(rightTableWithPk, 'b === 'bb)
+      .select('a, 'b, 'c)
+      .writeToSink(new TestUpsertSink(Array("a", "b"), false), queryConfig)
+
+    env.execute()
+    val results = RowCollector.getAndClearValues
+    // fields 0 and 1 ('a, 'b) form the upsert key of the result
+    val upserted = RowCollector.upsertResults(results, Array(0, 1))
+
+    val expected = Seq("0,1,1", "1,2,3", "2,1,1", "3,1,1", "4,1,1", "5,2,3", "6,0,1")
+    assertEquals(expected.sorted, upserted.sorted)
+  }
+
+  @Test
+  def testOutputWithoutPk(): Unit = {
+    // data input
+    val data1 = List(
+      (0, 0),
+      (1, 0),
+      (1, 1),
+      (2, 2),
+      (3, 3),
+      (4, 4),
+      (5, 4),
+      (5, 5)
+    )
+
+    val data2 = List(
+      (1, 1, 1),
+      (1, 1, 1),
+      (1, 1, 1),
+      (1, 1, 1),
+      (2, 2, 2),
+      (3, 3, 3),
+      (4, 4, 4),
+      (5, 5, 5),
+      (5, 5, 5),
+      (6, 6, 6)
+    )
+
+    val (env, tEnv) = createEnvironments()
+
+    val leftTable = env.fromCollection(data1).toTable(tEnv, 'a, 'b)
+    val rightTable = env.fromCollection(data2).toTable(tEnv, 'bb, 'c, 'd)
+
+    // only the left side is grouped; the right side has no key, so the result is retracted
+    val leftTableWithPk = leftTable
+      .groupBy('a)
+      .select('a, 'b.max as 'b)
+
+    leftTableWithPk
+      .join(rightTable, 'a === 'bb && ('a < 4 || 'a > 4))
+      .select('a, 'b, 'c, 'd)
+      .writeToSink(new TestRetractSink, queryConfig)
+
+    env.execute()
+    val results = RowCollector.getAndClearValues
+    val retracted = RowCollector.retractResults(results)
+    val expected = Seq("1,1,1,1", "1,1,1,1", "1,1,1,1", "1,1,1,1", "2,2,2,2", "3,3,3,3",
+      "5,5,5,5", "5,5,5,5")
+    assertEquals(expected.sorted, retracted.sorted)
+  }
+
+  @Test
+  def testJoinWithProcTimeAttributeOutput(): Unit = {
+
+    val data1 = List(
+      (1L, 1, "LEFT:Hi"),
+      (2L, 2, "LEFT:Hello"),
+      (4L, 2, "LEFT:Hello"),
+      (8L, 3, "LEFT:Hello world"),
+      (16L, 3, "LEFT:Hello world"))
+
+    val data2 = List(
+      (1L, 1, "RIGHT:Hi"),
+      (2L, 2, "RIGHT:Hello"),
+      (4L, 2, "RIGHT:Hello"),
+      (8L, 3, "RIGHT:Hello world"),
+      (16L, 3, "RIGHT:Hello world"))
+
+    val (env, tEnv) = createEnvironments()
+
+    // only the left input declares a processing-time attribute
+    val table1 = env.fromCollection(data1)
+      .toTable(tEnv, 'long_l, 'int_l, 'string_l, 'proctime_l.proctime)
+    val table2 = env.fromCollection(data2).toTable(tEnv, 'long_r, 'int_r, 'string_r)
+    val countFun = new CountAggFunction
+    val weightAvgFun = new WeightedAvg
+    val countDistinct = new CountDistinct
+
+    // forward the proctime attribute of the left input through the join ...
+    val table = table1
+      .join(table2, 'long_l === 'long_r)
+      .select('long_l as 'long, 'int_r as 'int, 'string_r as 'string, 'proctime_l as 'proctime)
+
+    // ... and use it as the time attribute of a tumbling window
+    val windowedTable = table
+      .window(Tumble over 5.milli on 'proctime as 'w)
+      .groupBy('w, 'string)
+      .select('string, countFun('string), 'int.avg, weightAvgFun('long, 'int),
+        weightAvgFun('int, 'int), 'int.min, 'int.max, 'int.sum, 'w.start, 'w.end,
+        countDistinct('long))
+
+    val results = windowedTable.toAppendStream[Row]
+    results.addSink(new StreamITCase.StringSink[Row])
+    env.execute()
+
+    // Processing-time windows produce non-deterministic results, so the output is not verified.
+    // The test only checks that the program can be translated and executed.
+  }
+
+  @Test(expected = classOf[TableException])
+  def testLeftOuterJoin(): Unit = {
+    val (leftTable, rightTable) = createOuterJoinInputs()
+    leftTable.leftOuterJoin(rightTable, 'a === 'bb).toAppendStream[Row]
+  }
+
+  @Test(expected = classOf[TableException])
+  def testRightOuterJoin(): Unit = {
+    val (leftTable, rightTable) = createOuterJoinInputs()
+    leftTable.rightOuterJoin(rightTable, 'a === 'bb).toAppendStream[Row]
+  }
+
+  @Test(expected = classOf[TableException])
+  def testFullOuterJoin(): Unit = {
+    val (leftTable, rightTable) = createOuterJoinInputs()
+    leftTable.fullOuterJoin(rightTable, 'a === 'bb).toAppendStream[Row]
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/9623b252/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSinkITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSinkITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSinkITCase.scala
index f1badee..bda823e 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSinkITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSinkITCase.scala
@@ -164,7 +164,7 @@ class TableSinkITCase extends AbstractTestBase {
env.execute()
val results = RowCollector.getAndClearValues
- val retracted = restractResults(results).sorted
+ val retracted = RowCollector.retractResults(results).sorted
val expected = List(
"2,1,1",
"5,1,2",
@@ -200,7 +200,7 @@ class TableSinkITCase extends AbstractTestBase {
"Received retraction messages for append only table",
results.exists(!_.f0))
- val retracted = restractResults(results).sorted
+ val retracted = RowCollector.retractResults(results).sorted
val expected = List(
"1970-01-01 00:00:00.005,4,8",
"1970-01-01 00:00:00.01,5,18",
@@ -238,7 +238,7 @@ class TableSinkITCase extends AbstractTestBase {
results.exists(_.f0 == false)
)
- val retracted = upsertResults(results, Array(0, 2)).sorted
+ val retracted = RowCollector.upsertResults(results, Array(0, 2)).sorted
val expected = List(
"1,5,true",
"7,1,true",
@@ -270,7 +270,7 @@ class TableSinkITCase extends AbstractTestBase {
"Received retraction messages for append only table",
results.exists(!_.f0))
- val retracted = upsertResults(results, Array(0, 1, 2)).sorted
+ val retracted = RowCollector.upsertResults(results, Array(0, 1, 2)).sorted
val expected = List(
"1,1970-01-01 00:00:00.005,1",
"2,1970-01-01 00:00:00.005,2",
@@ -308,7 +308,7 @@ class TableSinkITCase extends AbstractTestBase {
"Received retraction messages for append only table",
results.exists(!_.f0))
- val retracted = upsertResults(results, Array(0, 1, 2)).sorted
+ val retracted = RowCollector.upsertResults(results, Array(0, 1, 2)).sorted
val expected = List(
"1970-01-01 00:00:00.0,1970-01-01 00:00:00.005,1,1",
"1970-01-01 00:00:00.0,1970-01-01 00:00:00.005,2,2",
@@ -531,45 +531,6 @@ class TableSinkITCase extends AbstractTestBase {
r.toRetractStream[Row]
}
-
- /** Converts a list of retraction messages into a list of final results. */
- private def restractResults(results: List[JTuple2[JBool, Row]]): List[String] = {
-
- val retracted = results
- .foldLeft(Map[String, Int]()){ (m: Map[String, Int], v: JTuple2[JBool, Row]) =>
- val cnt = m.getOrElse(v.f1.toString, 0)
- if (v.f0) {
- m + (v.f1.toString -> (cnt + 1))
- } else {
- m + (v.f1.toString -> (cnt - 1))
- }
- }.filter{ case (_, c: Int) => c != 0 }
-
- assertFalse(
- "Received retracted rows which have not been accumulated.",
- retracted.exists{ case (_, c: Int) => c < 0})
-
- retracted.flatMap { case (r: String, c: Int) => (0 until c).map(_ => r) }.toList
- }
-
- /** Converts a list of upsert messages into a list of final results. */
- private def upsertResults(results: List[JTuple2[JBool, Row]], keys: Array[Int]): List[String] = {
-
- def getKeys(r: Row): List[String] =
- keys.foldLeft(List[String]())((k, i) => r.getField(i).toString :: k)
-
- val upserted = results.foldLeft(Map[String, String]()){ (o: Map[String, String], r) =>
- val key = getKeys(r.f1).mkString("")
- if (r.f0) {
- o + (key -> r.f1.toString)
- } else {
- o - key
- }
- }
-
- upserted.values.toList
- }
-
}
private[flink] class TestAppendSink extends AppendStreamTableSink[Row] {
@@ -692,4 +653,42 @@ object RowCollector {
sink.clear()
out
}
+
+  /**
+    * Converts a list of retraction messages into a list of final results.
+    *
+    * Every accumulate message (`true`) adds one occurrence of its row and every retraction
+    * message (`false`) removes one. Rows whose net count is zero are dropped. Each remaining
+    * row is repeated as often as its net count.
+    *
+    * Fails the test if some row was retracted more often than it was accumulated.
+    */
+  def retractResults(results: List[JTuple2[JBool, Row]]): List[String] = {
+
+    val netCounts = results.foldLeft(Map.empty[String, Int]) { (acc, msg) =>
+      val row = msg.f1.toString
+      val delta = if (msg.f0) 1 else -1
+      acc.updated(row, acc.getOrElse(row, 0) + delta)
+    }
+    val nonZero = netCounts.filter { case (_, n) => n != 0 }
+
+    assertFalse(
+      "Received retracted rows which have not been accumulated.",
+      nonZero.exists { case (_, n) => n < 0 })
+
+    nonZero.toList.flatMap { case (row, n) => List.fill(n)(row) }
+  }
+
+  /**
+    * Converts a list of upsert messages into a list of final results.
+    *
+    * An upsert message (`true`) inserts or replaces the row stored for its key. A delete message
+    * (`false`) removes the row stored for its key.
+    *
+    * @param results upsert and delete messages in the order they were received
+    * @param keys    positions of the key fields within each row
+    * @return the string representation of the last upserted row of every key that was not
+    *         deleted afterwards, in no particular order
+    */
+  def upsertResults(results: List[JTuple2[JBool, Row]], keys: Array[Int]): List[String] = {
+
+    // The key is the list of key field values itself. Concatenating their string forms without
+    // a separator would let distinct keys collide, e.g. (1, 11) and (11, 1), and calling
+    // toString on a null key field would throw a NullPointerException.
+    def getKey(r: Row): List[AnyRef] = keys.toList.map(i => r.getField(i))
+
+    val upserted = results.foldLeft(Map[List[AnyRef], String]()) { (o, r) =>
+      val key = getKey(r.f1)
+      if (r.f0) {
+        o + (key -> r.f1.toString)
+      } else {
+        o - key
+      }
+    }
+
+    upserted.values.toList
+  }
}