Posted to commits@flink.apache.org by ja...@apache.org on 2022/08/10 02:23:13 UTC
[flink] branch master updated: [FLINK-26929][table-runtime] Introduce adaptive hash join strategy for batch hash join (#20365)
This is an automated email from the ASF dual-hosted git repository.
jark pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 10acb641a03 [FLINK-26929][table-runtime] Introduce adaptive hash join strategy for batch hash join (#20365)
10acb641a03 is described below
commit 10acb641a036ae9d171452695b660c46d97c865a
Author: Ron <ld...@163.com>
AuthorDate: Wed Aug 10 10:23:05 2022 +0800
[FLINK-26929][table-runtime] Introduce adaptive hash join strategy for batch hash join (#20365)
---
.../planner/plan/nodes/exec/ExecNodeConfig.java | 4 +-
.../plan/nodes/exec/batch/BatchExecHashJoin.java | 60 ++-
.../nodes/exec/batch/BatchExecSortMergeJoin.java | 61 +--
.../plan/utils/SorMergeJoinOperatorUtil.java | 83 ++++
.../planner/codegen/LongHashJoinGenerator.scala | 110 ++++-
.../flink/table/planner/plan/utils/SortUtil.scala | 12 +
.../codegen/LongAdaptiveHashJoinGeneratorTest.java | 134 ++++++
.../planner/codegen/LongHashJoinGeneratorTest.java | 58 +--
.../runtime/hashtable/BaseHybridHashTable.java | 11 +
.../runtime/hashtable/BinaryHashPartition.java | 10 +-
.../table/runtime/hashtable/BinaryHashTable.java | 122 ++++-
.../table/runtime/hashtable/LongHashPartition.java | 4 +-
.../runtime/hashtable/LongHybridHashTable.java | 123 ++++-
.../io/BinaryRowChannelInputViewIterator.java | 6 +-
...shPartitionChannelReaderInputViewIterator.java} | 35 +-
.../runtime/operators/join/HashJoinOperator.java | 112 ++++-
...oinOperator.java => SortMergeJoinFunction.java} | 120 +++--
.../operators/join/SortMergeJoinOperator.java | 500 +--------------------
.../runtime/hashtable/BinaryHashTableTest.java | 61 ++-
.../table/runtime/hashtable/LongHashTableTest.java | 61 ++-
.../join/Int2AdaptiveHashJoinOperatorTest.java | 450 +++++++++++++++++++
.../operators/join/Int2HashJoinOperatorTest.java | 287 +-----------
...Test.java => Int2HashJoinOperatorTestBase.java} | 415 ++++++-----------
.../join/Int2SortMergeJoinOperatorTest.java | 8 +-
.../operators/join/SortMergeJoinIteratorTest.java | 2 +-
.../join/String2HashJoinOperatorTest.java | 98 +++-
.../join/String2SortMergeJoinOperatorTest.java | 108 ++---
.../apache/flink/table/runtime/util/JoinUtil.java | 43 ++
.../runtime/util/UniformBinaryRowGenerator.java | 2 +-
29 files changed, 1739 insertions(+), 1361 deletions(-)
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/ExecNodeConfig.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/ExecNodeConfig.java
index b33396fbec0..f56165f66c0 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/ExecNodeConfig.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/ExecNodeConfig.java
@@ -19,6 +19,7 @@
package org.apache.flink.table.planner.plan.nodes.exec;
import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.configuration.ConfigOption;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.table.api.TableConfig;
@@ -40,7 +41,8 @@ public final class ExecNodeConfig implements ReadableConfig {
private final ReadableConfig nodeConfig;
- ExecNodeConfig(TableConfig tableConfig, ReadableConfig nodeConfig) {
+ @VisibleForTesting
+ public ExecNodeConfig(TableConfig tableConfig, ReadableConfig nodeConfig) {
this.nodeConfig = nodeConfig;
this.tableConfig = tableConfig;
}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecHashJoin.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecHashJoin.java
index a4f989264fd..85b3e017d50 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecHashJoin.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecHashJoin.java
@@ -37,11 +37,13 @@ import org.apache.flink.table.planner.plan.nodes.exec.SingleTransformationTransl
import org.apache.flink.table.planner.plan.nodes.exec.spec.JoinSpec;
import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil;
import org.apache.flink.table.planner.plan.utils.JoinUtil;
+import org.apache.flink.table.planner.plan.utils.SorMergeJoinOperatorUtil;
import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
import org.apache.flink.table.runtime.generated.GeneratedProjection;
import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
import org.apache.flink.table.runtime.operators.join.HashJoinOperator;
import org.apache.flink.table.runtime.operators.join.HashJoinType;
+import org.apache.flink.table.runtime.operators.join.SortMergeJoinFunction;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
@@ -162,7 +164,7 @@ public class BatchExecHashJoin extends ExecNodeBase<RowData>
probeTransform = rightInputTransform;
probeProj = rightProj;
probeType = rightType;
- probeRowCount = estimatedLeftRowCount;
+ probeRowCount = estimatedRightRowCount;
probeKeys = rightKeys;
} else {
buildTransform = rightInputTransform;
@@ -189,6 +191,28 @@ public class BatchExecHashJoin extends ExecNodeBase<RowData>
joinType.isRightOuter(),
joinType == FlinkJoinType.SEMI,
joinType == FlinkJoinType.ANTI);
+
+ long externalBufferMemory =
+ config.get(ExecutionConfigOptions.TABLE_EXEC_RESOURCE_EXTERNAL_BUFFER_MEMORY)
+ .getBytes();
+ long managedMemory = getLargeManagedMemory(joinType, config);
+
+ // sort merge join function
+ SortMergeJoinFunction sortMergeJoinFunction =
+ SorMergeJoinOperatorUtil.getSortMergeJoinFunction(
+ planner.getFlinkContext().getClassLoader(),
+ config,
+ joinType,
+ leftType,
+ rightType,
+ leftKeys,
+ rightKeys,
+ keyType,
+ leftIsBuild,
+ joinSpec.getFilterNulls(),
+ condFunc,
+ 1.0 * externalBufferMemory / managedMemory);
+
if (LongHashJoinGenerator.support(hashJoinType, keyType, joinSpec.getFilterNulls())) {
operator =
LongHashJoinGenerator.gen(
@@ -203,12 +227,15 @@ public class BatchExecHashJoin extends ExecNodeBase<RowData>
buildRowSize,
buildRowCount,
reverseJoin,
- condFunc);
+ condFunc,
+ leftIsBuild,
+ sortMergeJoinFunction);
} else {
operator =
SimpleOperatorFactory.of(
HashJoinOperator.newHashJoinOperator(
hashJoinType,
+ leftIsBuild,
condFunc,
reverseJoin,
joinSpec.getFilterNulls(),
@@ -218,12 +245,10 @@ public class BatchExecHashJoin extends ExecNodeBase<RowData>
buildRowSize,
buildRowCount,
probeRowCount,
- keyType));
+ keyType,
+ sortMergeJoinFunction));
}
- long managedMemory =
- config.get(ExecutionConfigOptions.TABLE_EXEC_RESOURCE_HASH_JOIN_MEMORY).getBytes();
-
return ExecNodeUtil.createTwoInputTransformation(
buildTransform,
probeTransform,
@@ -234,4 +259,27 @@ public class BatchExecHashJoin extends ExecNodeBase<RowData>
probeTransform.getParallelism(),
managedMemory);
}
+
+ private long getLargeManagedMemory(FlinkJoinType joinType, ExecNodeConfig config) {
+ long hashJoinManagedMemory =
+ config.get(ExecutionConfigOptions.TABLE_EXEC_RESOURCE_HASH_JOIN_MEMORY).getBytes();
+
+ // The memory used by the SortMergeJoinIterator to buffer the matched rows; each side
+ // needs this memory if it is a full outer join
+ long externalBufferMemory =
+ config.get(ExecutionConfigOptions.TABLE_EXEC_RESOURCE_EXTERNAL_BUFFER_MEMORY)
+ .getBytes();
+ // The memory used by BinaryExternalSorter for sorting; both the left and right side need it
+ long sortMemory =
+ config.get(ExecutionConfigOptions.TABLE_EXEC_RESOURCE_SORT_MEMORY).getBytes();
+ int externalBufferNum = 1;
+ if (joinType == FlinkJoinType.FULL) {
+ externalBufferNum = 2;
+ }
+ long sortMergeJoinManagedMemory = externalBufferMemory * externalBufferNum + sortMemory * 2;
+
+ // Since the hash join may fall back to sort merge join at runtime, reserve the larger of
+ // the two managed memory sizes here
+ return Math.max(hashJoinManagedMemory, sortMergeJoinManagedMemory);
+ }
}
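For illustration, here is a minimal, self-contained sketch of the managed-memory sizing that getLargeManagedMemory performs above. The option values are assumed example figures, not the real ExecutionConfigOptions defaults:

    public class ManagedMemorySizingSketch {
        public static void main(String[] args) {
            long hashJoinManagedMemory = 128L << 20; // assumed hash-join memory: 128 MB
            long externalBufferMemory = 10L << 20;   // assumed external-buffer memory: 10 MB
            long sortMemory = 128L << 20;            // assumed sort memory: 128 MB
            boolean isFullOuterJoin = false;         // assumed join type

            // A FULL outer join buffers matched rows on both sides, so it needs two
            // external buffers; every other join type needs one.
            int externalBufferNum = isFullOuterJoin ? 2 : 1;
            // Both inputs must be sorted, so the sort memory is always counted twice.
            long sortMergeJoinManagedMemory =
                    externalBufferMemory * externalBufferNum + sortMemory * 2;

            // The hash join may fall back to sort merge join at runtime, so the operator
            // reserves enough managed memory for either path:
            // max(128 MB, 10 MB + 256 MB) = 266 MB.
            long managedMemory = Math.max(hashJoinManagedMemory, sortMergeJoinManagedMemory);
            System.out.println(managedMemory >> 20); // prints 266
        }
    }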
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecSortMergeJoin.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecSortMergeJoin.java
index 99b7ffa89c7..a04b90bd9d8 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecSortMergeJoin.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecSortMergeJoin.java
@@ -23,9 +23,6 @@ import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
import org.apache.flink.table.api.config.ExecutionConfigOptions;
import org.apache.flink.table.data.RowData;
-import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
-import org.apache.flink.table.planner.codegen.ProjectionCodeGenerator;
-import org.apache.flink.table.planner.codegen.sort.SortCodeGenerator;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
@@ -33,12 +30,12 @@ import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeConfig;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeContext;
import org.apache.flink.table.planner.plan.nodes.exec.InputProperty;
import org.apache.flink.table.planner.plan.nodes.exec.SingleTransformationTranslator;
-import org.apache.flink.table.planner.plan.nodes.exec.spec.SortSpec;
import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil;
import org.apache.flink.table.planner.plan.utils.JoinUtil;
-import org.apache.flink.table.planner.plan.utils.SortUtil;
+import org.apache.flink.table.planner.plan.utils.SorMergeJoinOperatorUtil;
import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
+import org.apache.flink.table.runtime.operators.join.SortMergeJoinFunction;
import org.apache.flink.table.runtime.operators.join.SortMergeJoinOperator;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.logical.LogicalType;
@@ -130,44 +127,20 @@ public class BatchExecSortMergeJoin extends ExecNodeBase<RowData>
long managedMemory = externalBufferMemory * externalBufferNum + sortMemory * 2;
- SortCodeGenerator leftSortGen =
- newSortGen(config, planner.getFlinkContext().getClassLoader(), leftKeys, leftType);
- SortCodeGenerator rightSortGen =
- newSortGen(
- config, planner.getFlinkContext().getClassLoader(), rightKeys, rightType);
-
- int[] keyPositions = IntStream.range(0, leftKeys.length).toArray();
- SortMergeJoinOperator operator =
- new SortMergeJoinOperator(
- 1.0 * externalBufferMemory / managedMemory,
+ SortMergeJoinFunction sortMergeJoinFunction =
+ SorMergeJoinOperatorUtil.getSortMergeJoinFunction(
+ planner.getFlinkContext().getClassLoader(),
+ config,
joinType,
+ leftType,
+ rightType,
+ leftKeys,
+ rightKeys,
+ keyType,
leftIsSmaller,
+ filterNulls,
condFunc,
- ProjectionCodeGenerator.generateProjection(
- new CodeGeneratorContext(
- config, planner.getFlinkContext().getClassLoader()),
- "SMJProjection",
- leftType,
- keyType,
- leftKeys),
- ProjectionCodeGenerator.generateProjection(
- new CodeGeneratorContext(
- config, planner.getFlinkContext().getClassLoader()),
- "SMJProjection",
- rightType,
- keyType,
- rightKeys),
- leftSortGen.generateNormalizedKeyComputer("LeftComputer"),
- leftSortGen.generateRecordComparator("LeftComparator"),
- rightSortGen.generateNormalizedKeyComputer("RightComputer"),
- rightSortGen.generateRecordComparator("RightComparator"),
- newSortGen(
- config,
- planner.getFlinkContext().getClassLoader(),
- keyPositions,
- keyType)
- .generateRecordComparator("KeyComparator"),
- filterNulls);
+ 1.0 * externalBufferMemory / managedMemory);
Transformation<RowData> leftInputTransform =
(Transformation<RowData>) leftInputEdge.translateToPlan(planner);
@@ -178,15 +151,9 @@ public class BatchExecSortMergeJoin extends ExecNodeBase<RowData>
rightInputTransform,
createTransformationName(config),
createTransformationDescription(config),
- SimpleOperatorFactory.of(operator),
+ SimpleOperatorFactory.of(new SortMergeJoinOperator(sortMergeJoinFunction)),
InternalTypeInfo.of(getOutputType()),
rightInputTransform.getParallelism(),
managedMemory);
}
-
- private SortCodeGenerator newSortGen(
- ExecNodeConfig config, ClassLoader classLoader, int[] originalKeys, RowType inputType) {
- SortSpec sortSpec = SortUtil.getAscendingSortSpec(originalKeys);
- return new SortCodeGenerator(config, classLoader, inputType, sortSpec);
- }
}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/SorMergeJoinOperatorUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/SorMergeJoinOperatorUtil.java
new file mode 100644
index 00000000000..634ceb345dd
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/SorMergeJoinOperatorUtil.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.utils;
+
+import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
+import org.apache.flink.table.planner.codegen.ProjectionCodeGenerator;
+import org.apache.flink.table.planner.codegen.sort.SortCodeGenerator;
+import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeConfig;
+import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
+import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
+import org.apache.flink.table.runtime.operators.join.SortMergeJoinFunction;
+import org.apache.flink.table.runtime.operators.join.SortMergeJoinOperator;
+import org.apache.flink.table.types.logical.RowType;
+
+import java.util.stream.IntStream;
+
+/** Utility for {@link SortMergeJoinOperator}. */
+public class SorMergeJoinOperatorUtil {
+
+ public static SortMergeJoinFunction getSortMergeJoinFunction(
+ ClassLoader classLoader,
+ ExecNodeConfig config,
+ FlinkJoinType joinType,
+ RowType leftType,
+ RowType rightType,
+ int[] leftKeys,
+ int[] rightKeys,
+ RowType keyType,
+ boolean leftIsSmaller,
+ boolean[] filterNulls,
+ GeneratedJoinCondition condFunc,
+ double externalBufferMemRatio) {
+ int[] keyPositions = IntStream.range(0, leftKeys.length).toArray();
+
+ SortCodeGenerator leftSortGen =
+ SortUtil.newSortGen(config, classLoader, leftKeys, leftType);
+ SortCodeGenerator rightSortGen =
+ SortUtil.newSortGen(config, classLoader, rightKeys, rightType);
+
+ return new SortMergeJoinFunction(
+ externalBufferMemRatio,
+ joinType,
+ leftIsSmaller,
+ condFunc,
+ ProjectionCodeGenerator.generateProjection(
+ new CodeGeneratorContext(config, classLoader),
+ "SMJProjection",
+ leftType,
+ keyType,
+ leftKeys),
+ ProjectionCodeGenerator.generateProjection(
+ new CodeGeneratorContext(config, classLoader),
+ "SMJProjection",
+ rightType,
+ keyType,
+ rightKeys),
+ leftSortGen.generateNormalizedKeyComputer("LeftComputer"),
+ leftSortGen.generateRecordComparator("LeftComparator"),
+ rightSortGen.generateNormalizedKeyComputer("RightComputer"),
+ rightSortGen.generateRecordComparator("RightComparator"),
+ SortUtil.newSortGen(config, classLoader, keyPositions, keyType)
+ .generateRecordComparator("KeyComparator"),
+ filterNulls);
+ }
+
+ private SorMergeJoinOperatorUtil() {}
+}
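Both BatchExecHashJoin and BatchExecSortMergeJoin now obtain their SortMergeJoinFunction through this utility. As a hedged sketch (illustrative class and method names, not part of the commit), the contract that a caller drives on the returned function looks like this, based on the calls that HashJoinOperator and the generated long hash join make in this commit:

    import org.apache.flink.table.data.binary.BinaryRowData;
    import org.apache.flink.table.runtime.hashtable.ProbeIterator;
    import org.apache.flink.table.runtime.operators.join.SortMergeJoinFunction;
    import org.apache.flink.table.runtime.util.RowIterator;

    // Sketch of the driver contract for a SortMergeJoinFunction produced by this
    // utility, assuming the build side is the left input (leftIsBuild == true);
    // otherwise the processElement1/processElement2 calls are swapped.
    class SortMergeJoinFallbackDriverSketch {
        static void drive(
                SortMergeJoinFunction smj,
                RowIterator<BinaryRowData> buildIter, // assumed: spilled build-side rows
                ProbeIterator probeIter) throws Exception {
            // open() is called lazily, only once the fallback actually triggers, with
            // the task's runtime services and the operator's managed memory:
            // smj.open(containingTask, operatorConfig, collector, memorySize,
            //          runtimeContext, metricGroup);

            // Feed all build rows, then all probe rows.
            while (buildIter.advanceNext()) {
                smj.processElement1(buildIter.getRow());
            }
            BinaryRowData probeRow;
            while ((probeRow = probeIter.next()) != null) {
                smj.processElement2(probeRow);
            }
            // Signal end of both inputs so the function sorts and emits the result,
            // then release its resources.
            smj.endInput(1);
            smj.endInput(2);
            smj.close();
        }
    }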
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LongHashJoinGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LongHashJoinGenerator.scala
index 7d9c3ac9634..abea1abce62 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LongHashJoinGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LongHashJoinGenerator.scala
@@ -24,10 +24,11 @@ import org.apache.flink.table.data.utils.JoinedRowData
import org.apache.flink.table.planner.codegen.CodeGenUtils._
import org.apache.flink.table.planner.codegen.OperatorCodeGenerator.{generateCollect, INPUT_SELECTION}
import org.apache.flink.table.runtime.generated.{GeneratedJoinCondition, GeneratedProjection}
-import org.apache.flink.table.runtime.hashtable.{LongHashPartition, LongHybridHashTable}
+import org.apache.flink.table.runtime.hashtable.{LongHashPartition, LongHybridHashTable, ProbeIterator}
import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory
-import org.apache.flink.table.runtime.operators.join.HashJoinType
+import org.apache.flink.table.runtime.operators.join.{HashJoinType, SortMergeJoinFunction}
import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer
+import org.apache.flink.table.runtime.util.{RowIterator, StreamRecordCollector}
import org.apache.flink.table.types.logical._
import org.apache.flink.table.types.logical.LogicalTypeRoot._
@@ -108,7 +109,9 @@ object LongHashJoinGenerator {
buildRowSize: Int,
buildRowCount: Long,
reverseJoinFunction: Boolean,
- condFunc: GeneratedJoinCondition): CodeGenOperatorFactory[RowData] = {
+ condFunc: GeneratedJoinCondition,
+ leftIsBuild: Boolean,
+ sortMergeJoinFunction: SortMergeJoinFunction): CodeGenOperatorFactory[RowData] = {
val buildSer = new BinaryRowDataSerializer(buildType.getFieldCount)
val probeSer = new BinaryRowDataSerializer(probeType.getFieldCount)
@@ -143,6 +146,17 @@ object LongHashJoinGenerator {
ctx.addReusableOpenStatement(s"condFunc.open(new ${className[Configuration]}());")
ctx.addReusableCloseStatement(s"condFunc.close();")
+ val leftIsBuildTerm = newName("leftIsBuild")
+ ctx.addReusableMember(s"private final boolean $leftIsBuildTerm = $leftIsBuild;")
+
+ val smjFunctionTerm = className[SortMergeJoinFunction]
+ ctx.addReusableMember(s"private final $smjFunctionTerm sortMergeJoinFunction;")
+ val smjFunctionRefs = ctx.addReusableObject(Array(sortMergeJoinFunction), "smjFunctionRefs")
+ ctx.addReusableInitStatement(s"sortMergeJoinFunction = $smjFunctionRefs[0];")
+
+ val fallbackSMJ = newName("fallbackSMJ")
+ ctx.addReusableMember(s"private transient boolean $fallbackSMJ = false;")
+
val gauge = classOf[Gauge[_]].getCanonicalName
ctx.addReusableOpenStatement(
s"""
@@ -300,11 +314,91 @@ object LongHashJoinGenerator {
|}
""".stripMargin)
+ // fall back to sort merge join in the probe phase
+ val rowIter = classOf[RowIterator[_]].getCanonicalName
+ ctx.addReusableMember(s"""
+ |private void fallbackSMJProcessPartition() throws Exception {
+ | if(!table.getPartitionsPendingForSMJ().isEmpty()) {
+ | LOG.info(
+ | "Fallback to sort merge join to process spilled partitions.");
+ | initialSortMergeJoinFunction();
+ | $fallbackSMJ = true;
+ |
+ | for(${classOf[LongHashPartition].getCanonicalName} p :
+ | table.getPartitionsPendingForSMJ()) {
+ | $rowIter<$BINARY_ROW> buildSideIter =
+ | table.getSpilledPartitionBuildSideIter(p);
+ | while (buildSideIter.advanceNext()) {
+ | processSortMergeJoinElement1(buildSideIter.getRow());
+ | }
+ |
+ | ${classOf[ProbeIterator].getCanonicalName} probeIter =
+ | table.getSpilledPartitionProbeSideIter(p);
+ | $BINARY_ROW probeNext;
+ | while ((probeNext = probeIter.next()) != null) {
+ | processSortMergeJoinElement2(probeNext);
+ | }
+ | }
+ |
+ | closeHashTable();
+ |
+ | sortMergeJoinFunction.endInput(1);
+ | sortMergeJoinFunction.endInput(2);
+ | LOG.info("Finish sort merge join for spilled partitions.");
+ | }
+ |}
+ """.stripMargin)
+
+ val collector = classOf[StreamRecordCollector[_]].getCanonicalName
+ ctx.addReusableMember(s"""
+ |private void initialSortMergeJoinFunction() throws Exception {
+ | sortMergeJoinFunction.open(
+ | this.getContainingTask(),
+ | this.getOperatorConfig(),
+ | new $collector<$ROW_DATA>(output),
+ | this.computeMemorySize(),
+ | this.getRuntimeContext(),
+ | this.getMetricGroup());
+ |}
+ """.stripMargin)
+
+ ctx.addReusableMember(
+ s"""
+ |private void processSortMergeJoinElement1($ROW_DATA rowData) throws Exception {
+ | if($leftIsBuild) {
+ | sortMergeJoinFunction.processElement1(rowData);
+ | } else {
+ | sortMergeJoinFunction.processElement2(rowData);
+ | }
+ |}
+ """.stripMargin)
+
+ ctx.addReusableMember(
+ s"""
+ |private void processSortMergeJoinElement2($ROW_DATA rowData) throws Exception {
+ | if($leftIsBuild) {
+ | sortMergeJoinFunction.processElement2(rowData);
+ | } else {
+ | sortMergeJoinFunction.processElement1(rowData);
+ | }
+ |}
+ """.stripMargin)
+
+ ctx.addReusableMember(s"""
+ |private void closeHashTable() {
+ | if (this.table != null) {
+ | this.table.close();
+ | this.table.free();
+ | this.table = null;
+ | }
+ |}
+ """.stripMargin)
+
ctx.addReusableCloseStatement(s"""
- |if (this.table != null) {
- | this.table.close();
- | this.table.free();
- | this.table = null;
+ |closeHashTable();
+ |
+ |if ($fallbackSMJ) {
+ | sortMergeJoinFunction.close();
|}
""".stripMargin)
@@ -352,6 +446,8 @@ object LongHashJoinGenerator {
| joinWithNextKey();
|}
|LOG.info("Finish rebuild phase.");
+ |
+ |fallbackSMJProcessPartition();
""".stripMargin)
)
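The generated helpers above swap the two inputs depending on leftIsBuild, because the hash table's build side may be either logical input, while SortMergeJoinFunction always expects the left input on processElement1. The same routing, rendered as plain Java for readability (a sketch of what the emitted code does, not the generator's literal output):

    import org.apache.flink.table.data.RowData;
    import org.apache.flink.table.runtime.operators.join.SortMergeJoinFunction;

    // Plain-Java rendering of the generated routing helpers (sketch only).
    class AdaptiveJoinRoutingSketch {
        private final boolean leftIsBuild; // codegen parameter
        private final SortMergeJoinFunction sortMergeJoinFunction;

        AdaptiveJoinRoutingSketch(boolean leftIsBuild, SortMergeJoinFunction smj) {
            this.leftIsBuild = leftIsBuild;
            this.sortMergeJoinFunction = smj;
        }

        // Receives hash-table build-side rows during the fallback.
        void processSortMergeJoinElement1(RowData row) throws Exception {
            if (leftIsBuild) {
                sortMergeJoinFunction.processElement1(row); // build side is the left input
            } else {
                sortMergeJoinFunction.processElement2(row); // build side is the right input
            }
        }

        // Receives hash-table probe-side rows during the fallback.
        void processSortMergeJoinElement2(RowData row) throws Exception {
            if (leftIsBuild) {
                sortMergeJoinFunction.processElement2(row);
            } else {
                sortMergeJoinFunction.processElement1(row);
            }
        }
    }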
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/SortUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/SortUtil.scala
index 74272a4717e..e3c7cfc5e86 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/SortUtil.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/SortUtil.scala
@@ -20,7 +20,10 @@ package org.apache.flink.table.planner.plan.utils
import org.apache.flink.api.common.operators.Order
import org.apache.flink.table.api.TableException
import org.apache.flink.table.planner.calcite.FlinkPlannerImpl
+import org.apache.flink.table.planner.codegen.sort.SortCodeGenerator
+import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeConfig
import org.apache.flink.table.planner.plan.nodes.exec.spec.SortSpec
+import org.apache.flink.table.types.logical.RowType
import org.apache.calcite.rel.`type`._
import org.apache.calcite.rel.{RelCollation, RelFieldCollation}
@@ -119,6 +122,15 @@ object SortUtil {
builder.build()
}
+ def newSortGen(
+ config: ExecNodeConfig,
+ classLoader: ClassLoader,
+ originalKeys: Array[Int],
+ inputType: RowType): SortCodeGenerator = {
+ val sortSpec = SortUtil.getAscendingSortSpec(originalKeys)
+ new SortCodeGenerator(config, classLoader, inputType, sortSpec)
+ }
+
def directionToOrder(direction: Direction): Order = {
direction match {
case Direction.ASCENDING | Direction.STRICTLY_ASCENDING => Order.ASCENDING
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/LongAdaptiveHashJoinGeneratorTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/LongAdaptiveHashJoinGeneratorTest.java
new file mode 100644
index 00000000000..6dfe26b54a5
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/LongAdaptiveHashJoinGeneratorTest.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.codegen;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeConfig;
+import org.apache.flink.table.planner.plan.utils.SorMergeJoinOperatorUtil;
+import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
+import org.apache.flink.table.runtime.generated.JoinCondition;
+import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
+import org.apache.flink.table.runtime.operators.join.HashJoinType;
+import org.apache.flink.table.runtime.operators.join.Int2AdaptiveHashJoinOperatorTest;
+import org.apache.flink.table.runtime.operators.join.SortMergeJoinFunction;
+import org.apache.flink.table.types.logical.IntType;
+import org.apache.flink.table.types.logical.RowType;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Test for adaptive {@link LongHashJoinGenerator}. */
+public class LongAdaptiveHashJoinGeneratorTest extends Int2AdaptiveHashJoinOperatorTest {
+
+ @Override
+ public Object newOperator(
+ long memorySize,
+ FlinkJoinType flinkJoinType,
+ HashJoinType hashJoinType,
+ boolean buildLeft,
+ boolean reverseJoinFunction) {
+ return getLongHashJoinOperator(flinkJoinType, hashJoinType, buildLeft, reverseJoinFunction);
+ }
+
+ @Override
+ public void testBuildLeftAntiJoinFallbackToSMJ() {}
+
+ @Override
+ public void testBuildLeftSemiJoinFallbackToSMJ() {}
+
+ @Override
+ public void testBuildFirstHashLeftOutJoinFallbackToSMJ() {}
+
+ @Override
+ public void testBuildSecondHashRightOutJoinFallbackToSMJ() {}
+
+ @Override
+ public void testBuildFirstHashFullOutJoinFallbackToSMJ() {}
+
+ static Object getLongHashJoinOperator(
+ FlinkJoinType flinkJoinType,
+ HashJoinType hashJoinType,
+ boolean buildLeft,
+ boolean reverseJoinFunction) {
+ RowType keyType = RowType.of(new IntType());
+ boolean[] filterNulls = new boolean[] {true};
+ assertThat(LongHashJoinGenerator.support(hashJoinType, keyType, filterNulls)).isTrue();
+
+ RowType buildType = RowType.of(new IntType(), new IntType());
+ RowType probeType = RowType.of(new IntType(), new IntType());
+ int[] buildKeyMapping = new int[] {0};
+ int[] probeKeyMapping = new int[] {0};
+ GeneratedJoinCondition condFunc =
+ new GeneratedJoinCondition(
+ MyJoinCondition.class.getCanonicalName(), "", new Object[0]) {
+ @Override
+ public JoinCondition newInstance(ClassLoader classLoader) {
+ return new MyJoinCondition(new Object[0]);
+ }
+ };
+
+ SortMergeJoinFunction sortMergeJoinFunction;
+ if (buildLeft) {
+ sortMergeJoinFunction =
+ SorMergeJoinOperatorUtil.getSortMergeJoinFunction(
+ Thread.currentThread().getContextClassLoader(),
+ new ExecNodeConfig(TableConfig.getDefault(), new Configuration()),
+ flinkJoinType,
+ buildType,
+ probeType,
+ buildKeyMapping,
+ probeKeyMapping,
+ keyType,
+ buildLeft,
+ filterNulls,
+ condFunc,
+ 0);
+ } else {
+ sortMergeJoinFunction =
+ SorMergeJoinOperatorUtil.getSortMergeJoinFunction(
+ Thread.currentThread().getContextClassLoader(),
+ new ExecNodeConfig(TableConfig.getDefault(), new Configuration()),
+ flinkJoinType,
+ probeType,
+ buildType,
+ probeKeyMapping,
+ buildKeyMapping,
+ keyType,
+ buildLeft,
+ filterNulls,
+ condFunc,
+ 0);
+ }
+ return LongHashJoinGenerator.gen(
+ new Configuration(),
+ Thread.currentThread().getContextClassLoader(),
+ hashJoinType,
+ keyType,
+ buildType,
+ probeType,
+ buildKeyMapping,
+ probeKeyMapping,
+ 20,
+ 10000,
+ reverseJoinFunction,
+ condFunc,
+ buildLeft,
+ sortMergeJoinFunction);
+ }
+}
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/LongHashJoinGeneratorTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/LongHashJoinGeneratorTest.java
index d1ce2c9e77f..0aaa40e5403 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/LongHashJoinGeneratorTest.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/LongHashJoinGeneratorTest.java
@@ -18,75 +18,47 @@
package org.apache.flink.table.planner.codegen;
-import org.apache.flink.api.common.functions.AbstractRichFunction;
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.table.data.RowData;
-import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
-import org.apache.flink.table.runtime.generated.JoinCondition;
+import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
import org.apache.flink.table.runtime.operators.join.HashJoinType;
import org.apache.flink.table.runtime.operators.join.Int2HashJoinOperatorTest;
-import org.apache.flink.table.types.logical.IntType;
-import org.apache.flink.table.types.logical.RowType;
import org.junit.Test;
-import static org.assertj.core.api.Assertions.assertThat;
-
/** Test for {@link LongHashJoinGenerator}. */
public class LongHashJoinGeneratorTest extends Int2HashJoinOperatorTest {
@Override
- public Object newOperator(long memorySize, HashJoinType type, boolean reverseJoinFunction) {
- RowType keyType = RowType.of(new IntType());
- assertThat(LongHashJoinGenerator.support(type, keyType, new boolean[] {true})).isTrue();
- return LongHashJoinGenerator.gen(
- new Configuration(),
- Thread.currentThread().getContextClassLoader(),
- type,
- keyType,
- RowType.of(new IntType(), new IntType()),
- RowType.of(new IntType(), new IntType()),
- new int[] {0},
- new int[] {0},
- 20,
- 10000,
- reverseJoinFunction,
- new GeneratedJoinCondition(
- MyJoinCondition.class.getCanonicalName(), "", new Object[0]));
+ public Object newOperator(
+ long memorySize,
+ FlinkJoinType flinkJoinType,
+ HashJoinType hashJoinType,
+ boolean buildLeft,
+ boolean reverseJoinFunction) {
+ return LongAdaptiveHashJoinGeneratorTest.getLongHashJoinOperator(
+ flinkJoinType, hashJoinType, buildLeft, reverseJoinFunction);
}
@Test
@Override
- public void testBuildLeftSemiJoin() throws Exception {}
+ public void testBuildLeftSemiJoin() {}
@Test
@Override
- public void testBuildSecondHashFullOutJoin() throws Exception {}
+ public void testBuildSecondHashFullOutJoin() {}
@Test
@Override
- public void testBuildSecondHashRightOutJoin() throws Exception {}
+ public void testBuildSecondHashRightOutJoin() {}
@Test
@Override
- public void testBuildLeftAntiJoin() throws Exception {}
+ public void testBuildLeftAntiJoin() {}
@Test
@Override
- public void testBuildFirstHashLeftOutJoin() throws Exception {}
+ public void testBuildFirstHashLeftOutJoin() {}
@Test
@Override
- public void testBuildFirstHashFullOutJoin() throws Exception {}
-
- /** Test cond. */
- public static class MyJoinCondition extends AbstractRichFunction implements JoinCondition {
-
- public MyJoinCondition(Object[] reference) {}
-
- @Override
- public boolean apply(RowData in1, RowData in2) {
- return true;
- }
- }
+ public void testBuildFirstHashFullOutJoin() {}
}
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BaseHybridHashTable.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BaseHybridHashTable.java
index 5c5a3ddf518..f7e691057c7 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BaseHybridHashTable.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BaseHybridHashTable.java
@@ -387,6 +387,17 @@ public abstract class BaseHybridHashTable implements MemorySegmentPool {
return;
}
+ // clear the current build side channel, if there is one
+ if (this.currentSpilledBuildSide != null) {
+ try {
+ this.currentSpilledBuildSide.getChannel().closeAndDelete();
+ } catch (Throwable t) {
+ LOG.warn(
+ "Could not close and delete the temp file for the current spilled partition build side.",
+ t);
+ }
+ }
+
// clear the current probe side channel, if there is one
if (this.currentSpilledProbeSide != null) {
try {
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashPartition.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashPartition.java
index ab89c9aa01e..eaea42361f1 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashPartition.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashPartition.java
@@ -60,7 +60,15 @@ public class BinaryHashPartition extends AbstractPagedInputView implements Seeka
final int partitionNumber; // the number of the partition
long probeSideRecordCounter; // number of probe-side records in this partition
+
+ /**
+ * These buffers are null when the BinaryHashPartition is first created; they are assigned in
+ * two cases: 1) when the build stage ends, if this partition stayed in memory, all the segments
+ * occupied by this partition are returned to it; 2) when the build stage ends, if this partition
+ * has spilled to disk, the data read back from disk later is assigned to it.
+ */
private MemorySegment[] partitionBuffers;
+
private int currentBufferNum;
private int finalBufferLimit;
@@ -621,7 +629,7 @@ public class BinaryHashPartition extends AbstractPagedInputView implements Seeka
}
}
- final long getPointer() {
+ long getPointer() {
return this.currentPointer;
}
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashTable.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashTable.java
index fe5f90c5dc8..26cfb0bcd9a 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashTable.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashTable.java
@@ -95,6 +95,12 @@ public class BinaryHashTable extends BaseHybridHashTable {
/** The partitions that are built by processing the current partition. */
final ArrayList<BinaryHashPartition> partitionsBeingBuilt;
+ /**
+ * The partitions that have been spilled previously and are pending to be processed by the
+ * sort merge join operator.
+ */
+ private final List<BinaryHashPartition> partitionsPendingForSMJ;
+
/**
* BitSet which is used to mark whether the element (in build side) has successfully matched during
* probe phase. As there are 9 elements in each bucket, we assign 2 bytes to BitSet.
@@ -193,6 +199,7 @@ public class BinaryHashTable extends BaseHybridHashTable {
this.partitionsBeingBuilt = new ArrayList<>();
this.partitionsPending = new ArrayList<>();
+ this.partitionsPendingForSMJ = new ArrayList<>();
createPartitions(initPartitionFanOut, 0);
}
@@ -399,10 +406,37 @@ public class BinaryHashTable extends BaseHybridHashTable {
this.probeMatchedPhase = true;
this.buildIterVisited = false;
+ final int nextRecursionLevel = p.getRecursionLevel() + 1;
+ if (nextRecursionLevel == 2) {
+ LOG.info("Recursive hash join: partition number is " + p.getPartitionNumber());
+ } else if (nextRecursionLevel > MAX_RECURSION_DEPTH) {
+ LOG.info(
+ "Partition number [{}] recursive level more than {}, process the partition using SortMergeJoin later.",
+ p.getPartitionNumber(),
+ MAX_RECURSION_DEPTH);
+ // if the partition has spilled to disk more than three times, process it by sort merge
+ // join later
+ this.partitionsPendingForSMJ.add(p);
+ // also need to remove it from pending list
+ this.partitionsPending.remove(0);
+ // recursively get the next partition
+ return prepareNextPartition();
+ }
// build the next table; memory must be allocated after this call
- buildTableFromSpilledPartition(p);
+ buildTableFromSpilledPartition(p, nextRecursionLevel);
// set the probe side
+ setPartitionProbeReader(p);
+
+ // unregister the pending partition
+ this.partitionsPending.remove(0);
+ this.currentRecursionDepth = p.getRecursionLevel() + 1;
+
+ // recursively get the next
+ return nextMatching();
+ }
+
+ private void setPartitionProbeReader(BinaryHashPartition p) throws IOException {
ChannelWithMeta channelWithMeta =
new ChannelWithMeta(
p.probeSideBuffer.getChannel().getChannelID(),
@@ -424,27 +458,11 @@ public class BinaryHashTable extends BaseHybridHashTable {
new ArrayList<>(),
this.binaryProbeSideSerializer);
this.probeIterator.set(probeReader);
- this.probeIterator.setReuse(binaryProbeSideSerializer.createInstance());
-
- // unregister the pending partition
- this.partitionsPending.remove(0);
- this.currentRecursionDepth = p.getRecursionLevel() + 1;
-
- // recursively get the next
- return nextMatching();
+ this.probeIterator.setReuse(this.binaryProbeSideSerializer.createInstance());
}
- private void buildTableFromSpilledPartition(final BinaryHashPartition p) throws IOException {
-
- final int nextRecursionLevel = p.getRecursionLevel() + 1;
- if (nextRecursionLevel == 2) {
- LOG.info("Recursive hash join: partition number is " + p.getPartitionNumber());
- } else if (nextRecursionLevel > MAX_RECURSION_DEPTH) {
- throw new RuntimeException(
- "Hash join exceeded maximum number of recursions, without reducing "
- + "partitions enough to be memory resident. Probably cause: Too many duplicate keys.");
- }
-
+ private void buildTableFromSpilledPartition(
+ final BinaryHashPartition p, final int nextRecursionLevel) throws IOException {
if (p.getBuildSideBlockCount() > p.getProbeSideBlockCount()) {
LOG.info(
String.format(
@@ -630,6 +648,16 @@ public class BinaryHashTable extends BaseHybridHashTable {
for (final BinaryHashPartition p : this.partitionsPending) {
p.clearAllMemory(this.internalPool);
}
+
+ // clear the partitions that are pending to be processed by the sort merge join operator
+ for (final BinaryHashPartition p : this.partitionsPendingForSMJ) {
+ try {
+ p.clearAllMemory(this.internalPool);
+ } catch (Exception e) {
+ LOG.error("Error during partition cleanup.", e);
+ }
+ }
+ this.partitionsPendingForSMJ.clear();
}
/**
@@ -659,7 +687,6 @@ public class BinaryHashTable extends BaseHybridHashTable {
this.currentEnumerator.next(),
this.buildSpillReturnBuffers);
this.buildSpillRetBufferNumbers += numBuffersFreed;
-
LOG.info(
String.format(
"Grace hash join: Ran out memory, choosing partition "
@@ -675,11 +702,62 @@ public class BinaryHashTable extends BaseHybridHashTable {
}
numSpillFiles++;
spillInBytes += numBuffersFreed * segmentSize;
- // The bloomFilter is built after the data is spilled, so that we can use enough memory.
+ // The bloomFilter is built from the bucket area after the data is spilled, so that it can
+ // use enough memory.
p.buildBloomFilterAndFreeBucket();
return largestPartNum;
}
+ public List<BinaryHashPartition> getPartitionsPendingForSMJ() {
+ return this.partitionsPendingForSMJ;
+ }
+
+ public RowIterator getSpilledPartitionBuildSideIter(BinaryHashPartition p) throws IOException {
+ // close build side channel of last processed partition
+ if (this.currentSpilledBuildSide != null) {
+ try {
+ this.currentSpilledBuildSide.getChannel().closeAndDelete();
+ } catch (Throwable t) {
+ LOG.warn(
+ "Could not close and delete the temp file for the current spilled partition build side.",
+ t);
+ }
+ this.currentSpilledBuildSide = null;
+ }
+
+ this.currentSpilledBuildSide =
+ createInputView(
+ p.getBuildSideChannel().getChannelID(),
+ p.getBuildSideBlockCount(),
+ p.getLastSegmentLimit());
+ this.buildIterator =
+ new WrappedRowIterator<>(
+ new BinaryRowChannelInputViewIterator(
+ this.currentSpilledBuildSide, this.binaryBuildSideSerializer),
+ this.binaryBuildSideSerializer.createInstance());
+ return this.buildIterator;
+ }
+
+ public ProbeIterator getSpilledPartitionProbeSideIter(BinaryHashPartition p)
+ throws IOException {
+ // close probe side channel of last processed partition
+ if (this.currentSpilledProbeSide != null) {
+ try {
+ this.currentSpilledProbeSide.getChannel().closeAndDelete();
+ } catch (Throwable t) {
+ LOG.warn(
+ "Could not close and delete the temp file for the current spilled partition probe side.",
+ t);
+ }
+ this.currentSpilledProbeSide = null;
+ }
+
+ // get the probe side iterator
+ this.probeIterator = new ProbeIterator(this.binaryProbeSideSerializer.createInstance());
+ setPartitionProbeReader(p);
+ return this.probeIterator;
+ }
+
boolean applyCondition(BinaryRowData candidate) {
BinaryRowData buildKey = buildSideProjection.apply(candidate);
// They come from Projection, so we can make sure it is in byte[].
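Putting the split hunks back together, the refactored prepareNextPartition flow reads roughly as follows (a condensed sketch of the code above, with MAX_RECURSION_DEPTH being 3, matching the "more than three times" comments):

    // Condensed sketch of the new recursion handling in prepareNextPartition.
    private boolean prepareNextPartitionSketch(BinaryHashPartition p) throws IOException {
        int nextRecursionLevel = p.getRecursionLevel() + 1;
        if (nextRecursionLevel > MAX_RECURSION_DEPTH) {
            // The partition was spilled and re-partitioned three times without
            // becoming memory-resident, which points to heavily skewed keys:
            // instead of throwing, as the old code did, park it for the sort
            // merge join fallback.
            partitionsPendingForSMJ.add(p);
            partitionsPending.remove(0);
            return prepareNextPartition();
        }
        // Otherwise rebuild the in-memory table from the spilled partition and
        // probe it again at the next recursion level.
        buildTableFromSpilledPartition(p, nextRecursionLevel);
        setPartitionProbeReader(p);
        partitionsPending.remove(0);
        currentRecursionDepth = nextRecursionLevel;
        return nextMatching();
    }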
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/LongHashPartition.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/LongHashPartition.java
index 951562687cb..306420272bf 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/LongHashPartition.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/LongHashPartition.java
@@ -849,7 +849,7 @@ public class LongHashPartition extends AbstractPagedInputView implements Seekabl
}
}
- final long getPointer() {
+ long getPointer() {
return this.currentPointer;
}
@@ -879,7 +879,7 @@ public class LongHashPartition extends AbstractPagedInputView implements Seekabl
return available < 8 + serializer.getFixedLengthPartSize();
}
- static void deserializeFromPages(
+ public static void deserializeFromPages(
BinaryRowData reuse,
ChannelReaderInputView inView,
BinaryRowDataSerializer buildSideSerializer)
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/LongHybridHashTable.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/LongHybridHashTable.java
index 600953a7e15..75081ea81b7 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/LongHybridHashTable.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/LongHybridHashTable.java
@@ -28,8 +28,10 @@ import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.binary.BinaryRowData;
import org.apache.flink.table.runtime.io.ChannelWithMeta;
+import org.apache.flink.table.runtime.io.LongHashPartitionChannelReaderInputViewIterator;
import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer;
import org.apache.flink.table.runtime.util.FileChannelUtil;
+import org.apache.flink.table.runtime.util.RowIterator;
import org.apache.flink.util.MathUtils;
import java.io.EOFException;
@@ -52,6 +54,12 @@ public abstract class LongHybridHashTable extends BaseHybridHashTable {
private final BinaryRowDataSerializer probeSideSerializer;
private final ArrayList<LongHashPartition> partitionsBeingBuilt;
private final ArrayList<LongHashPartition> partitionsPending;
+ /**
+ * The partitions that have been spilled previously and are pending to be processed by the
+ * sort merge join operator.
+ */
+ private final List<LongHashPartition> partitionsPendingForSMJ;
+
private ProbeIterator probeIterator;
private LongHashPartition.MatchIterator matchIterator;
@@ -85,6 +93,7 @@ public abstract class LongHybridHashTable extends BaseHybridHashTable {
this.partitionsBeingBuilt = new ArrayList<>();
this.partitionsPending = new ArrayList<>();
+ this.partitionsPendingForSMJ = new ArrayList<>();
createPartitions(initPartitionFanOut, 0);
}
@@ -184,7 +193,7 @@ public abstract class LongHybridHashTable extends BaseHybridHashTable {
/** After build end, try to use dense mode. */
private void tryDenseMode() {
-
+ // if some partitions have spilled to disk, always use hash mode
if (numSpillFiles != 0) {
return;
}
@@ -212,7 +221,7 @@ public abstract class LongHybridHashTable extends BaseHybridHashTable {
long range = maxKey - minKey + 1;
- // 1.range is negative mean: range is to big to overflow
+ // 1. a negative range means the computation overflowed because the range is too big
// 2.range is zero, maybe the max is Long.Max, and the min is Long.Min,
// so we should not use dense mode too.
if (range > 0 && (range <= recordCount * 4 || range <= segmentSize / 8)) {
@@ -221,8 +230,7 @@ public abstract class LongHybridHashTable extends BaseHybridHashTable {
int buffers = (int) Math.ceil(((double) (range * 8)) / segmentSize);
// TODO MemoryManager needs to support flexible larger segment, so that the index area
- // of the
- // build side is placed on a segment to avoid the overhead of addressing.
+ // of the build side is placed on a segment to avoid the overhead of addressing.
MemorySegment[] denseBuckets = new MemorySegment[buffers];
for (int i = 0; i < buffers; i++) {
MemorySegment seg = getNextBuffer();
@@ -362,10 +370,38 @@ public abstract class LongHybridHashTable extends BaseHybridHashTable {
return prepareNextPartition();
}
+ final int nextRecursionLevel = p.getRecursionLevel() + 1;
+ if (nextRecursionLevel == 2) {
+ LOG.info("Recursive hash join: partition number is " + p.getPartitionNumber());
+ } else if (nextRecursionLevel > MAX_RECURSION_DEPTH) {
+ LOG.info(
+ "Partition number [{}] recursive level more than {}, process the partition using SortMergeJoin later.",
+ p.getPartitionNumber(),
+ MAX_RECURSION_DEPTH);
+ // if the partition has spilled to disk more than three times, process it by sort merge
+ // join later
+ this.partitionsPendingForSMJ.add(p);
+ // also need to remove it from pending list
+ this.partitionsPending.remove(0);
+ // recursively get the next partition
+ return prepareNextPartition();
+ }
+
// build the next table; memory must be allocated after this call
- buildTableFromSpilledPartition(p);
+ buildTableFromSpilledPartition(p, nextRecursionLevel);
// set the probe side
+ setPartitionProbeReader(p);
+
+ // unregister the pending partition
+ this.partitionsPending.remove(0);
+ this.currentRecursionDepth = p.getRecursionLevel() + 1;
+
+ // recursively get the next
+ return nextMatching();
+ }
+
+ private void setPartitionProbeReader(LongHashPartition p) throws IOException {
ChannelWithMeta channelWithMeta =
new ChannelWithMeta(
p.probeSideBuffer.getChannel().getChannelID(),
@@ -386,26 +422,10 @@ public abstract class LongHybridHashTable extends BaseHybridHashTable {
this.currentSpilledProbeSide, new ArrayList<>(), this.probeSideSerializer);
this.probeIterator.set(probeReader);
this.probeIterator.setReuse(probeSideSerializer.createInstance());
-
- // unregister the pending partition
- this.partitionsPending.remove(0);
- this.currentRecursionDepth = p.getRecursionLevel() + 1;
-
- // recursively get the next
- return nextMatching();
}
- private void buildTableFromSpilledPartition(final LongHashPartition p) throws IOException {
-
- final int nextRecursionLevel = p.getRecursionLevel() + 1;
- if (nextRecursionLevel == 2) {
- LOG.info("Recursive hash join: partition number is " + p.getPartitionNumber());
- } else if (nextRecursionLevel > MAX_RECURSION_DEPTH) {
- throw new RuntimeException(
- "Hash join exceeded maximum number of recursions, without reducing "
- + "partitions enough to be memory resident. Probably cause: Too many duplicate keys.");
- }
-
+ private void buildTableFromSpilledPartition(
+ final LongHashPartition p, final int nextRecursionLevel) throws IOException {
if (p.getBuildSideBlockCount() > p.getProbeSideBlockCount()) {
LOG.info(
String.format(
@@ -561,6 +581,53 @@ public abstract class LongHybridHashTable extends BaseHybridHashTable {
return largestPartNum;
}
+ public List<LongHashPartition> getPartitionsPendingForSMJ() {
+ return this.partitionsPendingForSMJ;
+ }
+
+ public RowIterator getSpilledPartitionBuildSideIter(LongHashPartition p) throws IOException {
+ // close build side channel of last processed partition
+ if (this.currentSpilledBuildSide != null) {
+ try {
+ this.currentSpilledBuildSide.getChannel().closeAndDelete();
+ } catch (Throwable t) {
+ LOG.warn(
+ "Could not close and delete the temp file for the current spilled partition build side.",
+ t);
+ }
+ this.currentSpilledBuildSide = null;
+ }
+
+ this.currentSpilledBuildSide =
+ createInputView(
+ p.getBuildSideChannel().getChannelID(),
+ p.getBuildSideBlockCount(),
+ p.getLastSegmentLimit());
+ return new WrappedRowIterator<>(
+ new LongHashPartitionChannelReaderInputViewIterator(
+ this.currentSpilledBuildSide, this.buildSideSerializer),
+ this.buildSideSerializer.createInstance());
+ }
+
+ public ProbeIterator getSpilledPartitionProbeSideIter(LongHashPartition p) throws IOException {
+ // close probe side channel of last processed partition
+ if (this.currentSpilledProbeSide != null) {
+ try {
+ this.currentSpilledProbeSide.getChannel().closeAndDelete();
+ } catch (Throwable t) {
+ LOG.warn(
+ "Could not close and delete the temp file for the current spilled partition probe side.",
+ t);
+ }
+ this.currentSpilledProbeSide = null;
+ }
+
+ // get the probe side iterator
+ this.probeIterator = new ProbeIterator(this.probeSideSerializer.createInstance());
+ setPartitionProbeReader(p);
+ return this.probeIterator;
+ }
+
@Override
protected void clearPartitions() {
this.probeIterator = null;
@@ -579,6 +646,16 @@ public abstract class LongHybridHashTable extends BaseHybridHashTable {
for (final LongHashPartition p : this.partitionsPending) {
p.clearAllMemory(this.internalPool);
}
+
+ // clear the partitions that are pending to be processed by the sort merge join operator
+ for (final LongHashPartition p : this.partitionsPendingForSMJ) {
+ try {
+ p.clearAllMemory(this.internalPool);
+ } catch (Exception e) {
+ LOG.error("Error during partition cleanup.", e);
+ }
+ }
+ this.partitionsPendingForSMJ.clear();
}
public boolean compressionEnable() {
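As a worked example of the dense-mode eligibility check in tryDenseMode above, with assumed figures (the key range, record count, and segment size are illustrative, not taken from the commit):

    public class DenseModeSketch {
        public static void main(String[] args) {
            long minKey = 1L;
            long maxKey = 1_000_000L;
            long recordCount = 400_000L;
            int segmentSize = 32 * 1024; // assumed 32 KB memory segments

            long range = maxKey - minKey + 1;
            // A negative range means maxKey - minKey overflowed (the keys span
            // nearly the whole long domain); zero can occur for Long.MAX/Long.MIN
            // bounds. Both cases rule out dense mode.
            boolean dense = range > 0
                    && (range <= recordCount * 4 || range <= segmentSize / 8);
            int buffers = (int) Math.ceil(((double) (range * 8)) / segmentSize);
            // Here range = 1_000_000 <= 4 * 400_000, so dense mode applies and the
            // dense index needs ceil(1_000_000 * 8 / 32768) = 245 segments.
            System.out.println(dense + ", " + buffers); // prints true, 245
        }
    }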
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/io/BinaryRowChannelInputViewIterator.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/io/BinaryRowChannelInputViewIterator.java
index 1779ab27241..30456442622 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/io/BinaryRowChannelInputViewIterator.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/io/BinaryRowChannelInputViewIterator.java
@@ -34,11 +34,11 @@ import java.util.List;
* BinaryRowDataSerializer#deserializeFromPages}.
*/
public class BinaryRowChannelInputViewIterator implements MutableObjectIterator<BinaryRowData> {
- private final ChannelReaderInputView inView;
+ protected final ChannelReaderInputView inView;
- private final BinaryRowDataSerializer serializer;
+ protected final BinaryRowDataSerializer serializer;
- private final List<MemorySegment> freeMemTarget;
+ protected final List<MemorySegment> freeMemTarget;
public BinaryRowChannelInputViewIterator(
ChannelReaderInputView inView, BinaryRowDataSerializer serializer) {
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/io/BinaryRowChannelInputViewIterator.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/io/LongHashPartitionChannelReaderInputViewIterator.java
similarity index 61%
copy from flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/io/BinaryRowChannelInputViewIterator.java
copy to flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/io/LongHashPartitionChannelReaderInputViewIterator.java
index 1779ab27241..b7e2cf62212 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/io/BinaryRowChannelInputViewIterator.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/io/LongHashPartitionChannelReaderInputViewIterator.java
@@ -21,43 +21,30 @@ package org.apache.flink.table.runtime.io;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.io.disk.iomanager.ChannelReaderInputView;
import org.apache.flink.table.data.binary.BinaryRowData;
+import org.apache.flink.table.runtime.hashtable.LongHashPartition;
import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer;
-import org.apache.flink.util.MutableObjectIterator;
import java.io.EOFException;
import java.io.IOException;
-import java.util.ArrayList;
import java.util.List;
/**
* A simple iterator over the input read through an I/O channel. Use {@link
- * BinaryRowDataSerializer#deserializeFromPages}.
+ * LongHashPartition#deserializeFromPages}.
*/
-public class BinaryRowChannelInputViewIterator implements MutableObjectIterator<BinaryRowData> {
- private final ChannelReaderInputView inView;
+public class LongHashPartitionChannelReaderInputViewIterator
+ extends BinaryRowChannelInputViewIterator {
- private final BinaryRowDataSerializer serializer;
-
- private final List<MemorySegment> freeMemTarget;
-
- public BinaryRowChannelInputViewIterator(
+ public LongHashPartitionChannelReaderInputViewIterator(
ChannelReaderInputView inView, BinaryRowDataSerializer serializer) {
- this(inView, new ArrayList<>(), serializer);
- }
-
- public BinaryRowChannelInputViewIterator(
- ChannelReaderInputView inView,
- List<MemorySegment> freeMemTarget,
- BinaryRowDataSerializer serializer) {
- this.inView = inView;
- this.freeMemTarget = freeMemTarget;
- this.serializer = serializer;
+ super(inView, serializer);
}
@Override
public BinaryRowData next(BinaryRowData reuse) throws IOException {
try {
- return this.serializer.deserializeFromPages(reuse, this.inView);
+ LongHashPartition.deserializeFromPages(reuse, inView, serializer);
+ return reuse;
} catch (EOFException eofex) {
final List<MemorySegment> freeMem = this.inView.close();
if (this.freeMemTarget != null) {
@@ -66,10 +53,4 @@ public class BinaryRowChannelInputViewIterator implements MutableObjectIterator<
return null;
}
}
-
- @Override
- public BinaryRowData next() throws IOException {
- throw new UnsupportedOperationException(
- "This method is disabled due to performance issue!");
- }
}
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/HashJoinOperator.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/HashJoinOperator.java
index 08312d34d7e..59b522a4521 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/HashJoinOperator.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/HashJoinOperator.java
@@ -32,7 +32,9 @@ import org.apache.flink.table.data.utils.JoinedRowData;
import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
import org.apache.flink.table.runtime.generated.GeneratedProjection;
import org.apache.flink.table.runtime.generated.JoinCondition;
+import org.apache.flink.table.runtime.hashtable.BinaryHashPartition;
import org.apache.flink.table.runtime.hashtable.BinaryHashTable;
+import org.apache.flink.table.runtime.hashtable.ProbeIterator;
import org.apache.flink.table.runtime.operators.TableStreamOperator;
import org.apache.flink.table.runtime.typeutils.AbstractRowDataSerializer;
import org.apache.flink.table.runtime.util.RowIterator;
@@ -54,6 +56,13 @@ import static org.apache.flink.util.Preconditions.checkState;
* <p>The join operator implements the logic of a join operator at runtime. It uses a
* hybrid-hash-join internally to match the records with equal keys. The build side of the hash is
* the first input of the match. It supports all join types in {@link HashJoinType}.
+ *
+ * <p>Note: To deal with data skew, or with too much data in the hash table, a fallback to
+ * sort merge join is introduced here. If some partitions are spilled to disk more than three
+ * times during the hash join, it falls back to sort merge join by default to improve
+ * stability. In the future we will support a more flexible adaptive hash join strategy, for
+ * example falling back to sort merge join early during the build phase once the size of the
+ * data written to disk reaches a certain threshold.
*/
public abstract class HashJoinOperator extends TableStreamOperator<RowData>
implements TwoInputStreamOperator<RowData, RowData, RowData>,
@@ -65,6 +74,8 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
private final HashJoinParameter parameter;
private final boolean reverseJoinFunction;
private final HashJoinType type;
+ private final boolean leftIsBuild;
+ private final SortMergeJoinFunction sortMergeJoinFunction;
private transient BinaryHashTable table;
transient Collector<RowData> collector;
@@ -75,10 +86,15 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
private transient boolean buildEnd;
private transient JoinCondition condition;
+ // Flag indicating whether the join has fallen back to sort merge join in the probe phase
+ private transient boolean fallbackSMJ;
+
HashJoinOperator(HashJoinParameter parameter) {
this.parameter = parameter;
this.type = parameter.type;
+ this.leftIsBuild = parameter.leftIsBuild;
this.reverseJoinFunction = parameter.reverseJoinFunction;
+ this.sortMergeJoinFunction = parameter.sortMergeJoinFunction;
}
@Override
@@ -132,6 +148,7 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
this.probeSideNullRow = new GenericRowData(probeSerializer.getArity());
this.joinedRow = new JoinedRowData();
this.buildEnd = false;
+ this.fallbackSMJ = false;
getMetricGroup().gauge("memoryUsedSizeInBytes", table::getUsedMemoryInBytes);
getMetricGroup().gauge("numSpillFiles", table::getNumSpillFiles);
@@ -177,6 +194,10 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
joinWithNextKey();
}
LOG.info("Finish rebuild phase.");
+
+ // switch to sort merge join to process the remaining partitions whose
+ // recursive level > 3
+ fallbackSMJProcessPartition();
break;
}
}
@@ -214,16 +235,90 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
@Override
public void close() throws Exception {
super.close();
+ closeHashTable();
+ condition.close();
+
+ // If we fell back to sort merge join during the hash join, also close the sort merge join function
+ if (fallbackSMJ) {
+ sortMergeJoinFunction.close();
+ }
+ }
+
+ private void closeHashTable() {
if (this.table != null) {
this.table.close();
this.table.free();
this.table = null;
}
- condition.close();
+ }
+
+ /**
+ * If partitions that spilled to disk more than three times still exist when the hash join ends,
+ * the keys in these partitions are heavily skewed, so fall back to the sort merge join algorithm
+ * to process them.
+ */
+ private void fallbackSMJProcessPartition() throws Exception {
+ if (!table.getPartitionsPendingForSMJ().isEmpty()) {
+ // initialize the sort merge join function
+ LOG.info("Fallback to sort merge join to process spilled partitions.");
+ initialSortMergeJoinFunction();
+ fallbackSMJ = true;
+
+ for (BinaryHashPartition p : table.getPartitionsPendingForSMJ()) {
+ // process build side
+ RowIterator<BinaryRowData> buildSideIter =
+ table.getSpilledPartitionBuildSideIter(p);
+ while (buildSideIter.advanceNext()) {
+ processSortMergeJoinElement1(buildSideIter.getRow());
+ }
+
+ // process probe side
+ ProbeIterator probeIter = table.getSpilledPartitionProbeSideIter(p);
+ BinaryRowData probeNext;
+ while ((probeNext = probeIter.next()) != null) {
+ processSortMergeJoinElement2(probeNext);
+ }
+ }
+
+ // close the HashTable
+ closeHashTable();
+
+ // finish build and probe
+ sortMergeJoinFunction.endInput(1);
+ sortMergeJoinFunction.endInput(2);
+ LOG.info("Finish sort merge join for spilled partitions.");
+ }
+ }
+
+ private void initialSortMergeJoinFunction() throws Exception {
+ sortMergeJoinFunction.open(
+ this.getContainingTask(),
+ this.getOperatorConfig(),
+ (StreamRecordCollector) this.collector,
+ this.computeMemorySize(),
+ this.getRuntimeContext(),
+ this.getMetricGroup());
+ }
+
+ private void processSortMergeJoinElement1(RowData rowData) throws Exception {
+ if (leftIsBuild) {
+ sortMergeJoinFunction.processElement1(rowData);
+ } else {
+ sortMergeJoinFunction.processElement2(rowData);
+ }
+ }
+
+ private void processSortMergeJoinElement2(RowData rowData) throws Exception {
+ if (leftIsBuild) {
+ sortMergeJoinFunction.processElement2(rowData);
+ } else {
+ sortMergeJoinFunction.processElement1(rowData);
+ }
}
public static HashJoinOperator newHashJoinOperator(
HashJoinType type,
+ boolean leftIsBuild,
GeneratedJoinCondition condFuncCode,
boolean reverseJoinFunction,
boolean[] filterNullKeys,
@@ -233,10 +328,12 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
int buildRowSize,
long buildRowCount,
long probeRowCount,
- RowType keyType) {
+ RowType keyType,
+ SortMergeJoinFunction sortMergeJoinFunction) {
HashJoinParameter parameter =
new HashJoinParameter(
type,
+ leftIsBuild,
condFuncCode,
reverseJoinFunction,
filterNullKeys,
@@ -246,7 +343,8 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
buildRowSize,
buildRowCount,
probeRowCount,
- keyType);
+ keyType,
+ sortMergeJoinFunction);
switch (type) {
case INNER:
return new InnerHashJoinOperator(parameter);
@@ -270,6 +368,7 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
static class HashJoinParameter implements Serializable {
HashJoinType type;
+ boolean leftIsBuild;
GeneratedJoinCondition condFuncCode;
boolean reverseJoinFunction;
boolean[] filterNullKeys;
@@ -280,9 +379,11 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
long buildRowCount;
long probeRowCount;
RowType keyType;
+ SortMergeJoinFunction sortMergeJoinFunction;
HashJoinParameter(
HashJoinType type,
+ boolean leftIsBuild,
GeneratedJoinCondition condFuncCode,
boolean reverseJoinFunction,
boolean[] filterNullKeys,
@@ -292,8 +393,10 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
int buildRowSize,
long buildRowCount,
long probeRowCount,
- RowType keyType) {
+ RowType keyType,
+ SortMergeJoinFunction sortMergeJoinFunction) {
this.type = type;
+ this.leftIsBuild = leftIsBuild;
this.condFuncCode = condFuncCode;
this.reverseJoinFunction = reverseJoinFunction;
this.filterNullKeys = filterNullKeys;
@@ -304,6 +407,7 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
this.buildRowCount = buildRowCount;
this.probeRowCount = probeRowCount;
this.keyType = keyType;
+ this.sortMergeJoinFunction = sortMergeJoinFunction;
}
}
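
The processSortMergeJoinElement1/2 helpers above exist because the hash table's build side can be either the left or the right input of the logical join, while SortMergeJoinFunction fixes input 1 as the left side. A hedged sketch of just that routing decision (the TwoInputSink interface and class names are illustrative, not Flink's API):

/** Illustrative sink with two logical inputs, standing in for SortMergeJoinFunction. */
interface TwoInputSink<T> {
    void processElement1(T element) throws Exception; // left input
    void processElement2(T element) throws Exception; // right input
}

/** Routes build/probe rows to the correct logical input depending on which side was built. */
class BuildProbeRouter<T> {
    private final TwoInputSink<T> sink;
    private final boolean leftIsBuild;

    BuildProbeRouter(TwoInputSink<T> sink, boolean leftIsBuild) {
        this.sink = sink;
        this.leftIsBuild = leftIsBuild;
    }

    void processBuildRow(T row) throws Exception {
        // build rows come from the left input iff the left side was built
        if (leftIsBuild) {
            sink.processElement1(row);
        } else {
            sink.processElement2(row);
        }
    }

    void processProbeRow(T row) throws Exception {
        // probe rows are always the opposite side of the build rows
        if (leftIsBuild) {
            sink.processElement2(row);
        } else {
            sink.processElement1(row);
        }
    }

    public static void main(String[] args) throws Exception {
        TwoInputSink<String> printSink = new TwoInputSink<>() {
            public void processElement1(String e) { System.out.println("input1: " + e); }
            public void processElement2(String e) { System.out.println("input2: " + e); }
        };
        BuildProbeRouter<String> router = new BuildProbeRouter<>(printSink, false);
        router.processBuildRow("buildRow"); // -> input2, because the right side was built
        router.processProbeRow("probeRow"); // -> input1
    }
}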
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinOperator.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinFunction.java
similarity index 85%
copy from flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinOperator.java
copy to flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinFunction.java
index 18caea96374..79a8fc19c56 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinOperator.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinFunction.java
@@ -1,12 +1,13 @@
/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -17,13 +18,14 @@
package org.apache.flink.table.runtime.operators.join;
+import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.metrics.Gauge;
+import org.apache.flink.metrics.groups.OperatorMetricGroup;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.memory.MemoryManager;
-import org.apache.flink.streaming.api.operators.BoundedMultiInput;
-import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
-import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.runtime.tasks.StreamTask;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
@@ -36,7 +38,6 @@ import org.apache.flink.table.runtime.generated.GeneratedRecordComparator;
import org.apache.flink.table.runtime.generated.JoinCondition;
import org.apache.flink.table.runtime.generated.Projection;
import org.apache.flink.table.runtime.generated.RecordComparator;
-import org.apache.flink.table.runtime.operators.TableStreamOperator;
import org.apache.flink.table.runtime.operators.sort.BinaryExternalSorter;
import org.apache.flink.table.runtime.typeutils.AbstractRowDataSerializer;
import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer;
@@ -46,21 +47,13 @@ import org.apache.flink.table.runtime.util.StreamRecordCollector;
import org.apache.flink.util.Collector;
import org.apache.flink.util.MutableObjectIterator;
+import java.io.Serializable;
import java.util.BitSet;
import static org.apache.flink.util.Preconditions.checkNotNull;
-/**
- * An implementation that realizes the joining through a sort-merge join strategy. 1.In most cases,
- * its performance is weaker than HashJoin. 2.It is more stable than HashJoin, and most of the data
- * can be sorted stably. 3.SortMergeJoin should be the best choice if sort can be omitted in the
- * case of multi-level join cascade with the same key.
- *
- * <p>NOTE: SEMI and ANTI join output input1 instead of input2. (Contrary to {@link
- * HashJoinOperator}).
- */
-public class SortMergeJoinOperator extends TableStreamOperator<RowData>
- implements TwoInputStreamOperator<RowData, RowData, RowData>, BoundedMultiInput {
+/** This function implements the main logic of the sort merge join and can be reused by other operators, such as {@link HashJoinOperator} when it falls back. */
+public class SortMergeJoinFunction implements Serializable {
private final double externalBufferMemRatio;
private final FlinkJoinType type;
@@ -77,6 +70,7 @@ public class SortMergeJoinOperator extends TableStreamOperator<RowData>
private GeneratedRecordComparator comparator2;
private GeneratedRecordComparator genKeyComparator;
+ private transient StreamTask<?, ?> taskContainer;
private transient long externalBufferMemory;
private transient MemoryManager memManager;
private transient IOManager ioManager;
@@ -95,7 +89,7 @@ public class SortMergeJoinOperator extends TableStreamOperator<RowData>
private transient RowData rightNullRow;
private transient JoinedRowData joinedRow;
- public SortMergeJoinOperator(
+ public SortMergeJoinFunction(
double externalBufferMemRatio,
FlinkJoinType type,
boolean leftIsSmaller,
@@ -122,29 +116,31 @@ public class SortMergeJoinOperator extends TableStreamOperator<RowData>
this.filterNulls = filterNulls;
}
- @Override
- public void open() throws Exception {
- super.open();
-
- Configuration conf = getContainingTask().getJobConfiguration();
+ public void open(
+ StreamTask<?, ?> taskContainer,
+ StreamConfig operatorConfig,
+ StreamRecordCollector collector,
+ long totalMemory,
+ RuntimeContext runtimeContext,
+ OperatorMetricGroup operatorMetricGroup)
+ throws Exception {
+ this.taskContainer = taskContainer;
isFinished = new boolean[] {false, false};
- collector = new StreamRecordCollector<>(output);
+ this.collector = collector;
- ClassLoader cl = getUserCodeClassloader();
+ ClassLoader cl = taskContainer.getUserCodeClassLoader();
AbstractRowDataSerializer inputSerializer1 =
- (AbstractRowDataSerializer) getOperatorConfig().getTypeSerializerIn1(cl);
+ (AbstractRowDataSerializer) operatorConfig.getTypeSerializerIn1(cl);
this.serializer1 = new BinaryRowDataSerializer(inputSerializer1.getArity());
AbstractRowDataSerializer inputSerializer2 =
- (AbstractRowDataSerializer) getOperatorConfig().getTypeSerializerIn2(cl);
+ (AbstractRowDataSerializer) operatorConfig.getTypeSerializerIn2(cl);
this.serializer2 = new BinaryRowDataSerializer(inputSerializer2.getArity());
- this.memManager = this.getContainingTask().getEnvironment().getMemoryManager();
- this.ioManager = this.getContainingTask().getEnvironment().getIOManager();
-
- long totalMemory = computeMemorySize();
+ this.memManager = taskContainer.getEnvironment().getMemoryManager();
+ this.ioManager = taskContainer.getEnvironment().getIOManager();
externalBufferMemory = (long) (totalMemory * externalBufferMemRatio);
externalBufferMemory =
@@ -162,10 +158,11 @@ public class SortMergeJoinOperator extends TableStreamOperator<RowData>
+ ", please increase manage memory of task manager.");
}
+ Configuration conf = taskContainer.getJobConfiguration();
// sorter1
this.sorter1 =
new BinaryExternalSorter(
- this.getContainingTask(),
+ taskContainer,
memManager,
totalSortMem / 2,
ioManager,
@@ -179,7 +176,7 @@ public class SortMergeJoinOperator extends TableStreamOperator<RowData>
// sorter2
this.sorter2 =
new BinaryExternalSorter(
- this.getContainingTask(),
+ taskContainer,
memManager,
totalSortMem / 2,
ioManager,
@@ -192,7 +189,7 @@ public class SortMergeJoinOperator extends TableStreamOperator<RowData>
keyComparator = genKeyComparator.newInstance(cl);
this.condFunc = condFuncCode.newInstance(cl);
- condFunc.setRuntimeContext(getRuntimeContext());
+ condFunc.setRuntimeContext(runtimeContext);
condFunc.open(new Configuration());
projection1 = projectionCode1.newInstance(cl);
@@ -211,37 +208,28 @@ public class SortMergeJoinOperator extends TableStreamOperator<RowData>
projectionCode2 = null;
genKeyComparator = null;
- getMetricGroup()
- .gauge(
- "memoryUsedSizeInBytes",
- (Gauge<Long>)
- () ->
- sorter1.getUsedMemoryInBytes()
- + sorter2.getUsedMemoryInBytes());
-
- getMetricGroup()
- .gauge(
- "numSpillFiles",
- (Gauge<Long>)
- () -> sorter1.getNumSpillFiles() + sorter2.getNumSpillFiles());
-
- getMetricGroup()
- .gauge(
- "spillInBytes",
- (Gauge<Long>) () -> sorter1.getSpillInBytes() + sorter2.getSpillInBytes());
+ operatorMetricGroup.gauge(
+ "memoryUsedSizeInBytes",
+ (Gauge<Long>)
+ () -> sorter1.getUsedMemoryInBytes() + sorter2.getUsedMemoryInBytes());
+
+ operatorMetricGroup.gauge(
+ "numSpillFiles",
+ (Gauge<Long>) () -> sorter1.getNumSpillFiles() + sorter2.getNumSpillFiles());
+
+ operatorMetricGroup.gauge(
+ "spillInBytes",
+ (Gauge<Long>) () -> sorter1.getSpillInBytes() + sorter2.getSpillInBytes());
}
- @Override
- public void processElement1(StreamRecord<RowData> element) throws Exception {
- this.sorter1.write(element.getValue());
+ public void processElement1(RowData element) throws Exception {
+ this.sorter1.write(element);
}
- @Override
- public void processElement2(StreamRecord<RowData> element) throws Exception {
- this.sorter2.write(element.getValue());
+ public void processElement2(RowData element) throws Exception {
+ this.sorter2.write(element);
}
- @Override
public void endInput(int inputId) throws Exception {
isFinished[inputId - 1] = true;
if (isAllFinished()) {
@@ -521,7 +509,7 @@ public class SortMergeJoinOperator extends TableStreamOperator<RowData>
private ResettableExternalBuffer newBuffer(BinaryRowDataSerializer serializer) {
LazyMemorySegmentPool pool =
new LazyMemorySegmentPool(
- this.getContainingTask(),
+ taskContainer,
memManager,
(int) (externalBufferMemory / memManager.getPageSize()));
return new ResettableExternalBuffer(
@@ -535,9 +523,7 @@ public class SortMergeJoinOperator extends TableStreamOperator<RowData>
return isFinished[0] && isFinished[1];
}
- @Override
public void close() throws Exception {
- super.close();
if (this.sorter1 != null) {
this.sorter1.close();
}
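
The open(...) method above splits the operator's managed memory between the external buffer(s) and the two sorters: a FULL outer join buffers both sides and therefore reserves the buffer memory twice. A small worked sketch of that arithmetic (the constant and the sample numbers are illustrative, not Flink's actual values):

/** Illustrative re-computation of the sort/buffer memory split performed in SortMergeJoinFunction#open. */
public class SmjMemoryBudget {
    // assumed stand-in for ResettableExternalBuffer.MIN_NUM_MEMORY
    static final long MIN_BUFFER_MEMORY = 32 * 1024;

    static long[] split(long totalMemory, double externalBufferMemRatio, boolean isFullOuterJoin) {
        long externalBufferMemory =
                Math.max((long) (totalMemory * externalBufferMemRatio), MIN_BUFFER_MEMORY);
        // a FULL outer join buffers both sides, so it reserves two buffers
        long totalSortMem =
                totalMemory - (isFullOuterJoin ? externalBufferMemory * 2 : externalBufferMemory);
        if (totalSortMem < 0) {
            throw new IllegalArgumentException("Memory size is too small: " + totalMemory);
        }
        // each sorter gets half of what remains
        return new long[] {externalBufferMemory, totalSortMem / 2, totalSortMem / 2};
    }

    public static void main(String[] args) {
        long[] budget = split(64L * 1024 * 1024, 0.1, true);
        System.out.printf("buffer=%d, sorter1=%d, sorter2=%d%n", budget[0], budget[1], budget[2]);
    }
}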
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinOperator.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinOperator.java
index 18caea96374..61c452a0f79 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinOperator.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinOperator.java
@@ -17,38 +17,12 @@
package org.apache.flink.table.runtime.operators.join;
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.metrics.Gauge;
-import org.apache.flink.runtime.io.disk.iomanager.IOManager;
-import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.streaming.api.operators.BoundedMultiInput;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.apache.flink.table.api.TableException;
-import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
-import org.apache.flink.table.data.binary.BinaryRowData;
-import org.apache.flink.table.data.utils.JoinedRowData;
-import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
-import org.apache.flink.table.runtime.generated.GeneratedNormalizedKeyComputer;
-import org.apache.flink.table.runtime.generated.GeneratedProjection;
-import org.apache.flink.table.runtime.generated.GeneratedRecordComparator;
-import org.apache.flink.table.runtime.generated.JoinCondition;
-import org.apache.flink.table.runtime.generated.Projection;
-import org.apache.flink.table.runtime.generated.RecordComparator;
import org.apache.flink.table.runtime.operators.TableStreamOperator;
-import org.apache.flink.table.runtime.operators.sort.BinaryExternalSorter;
-import org.apache.flink.table.runtime.typeutils.AbstractRowDataSerializer;
-import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer;
-import org.apache.flink.table.runtime.util.LazyMemorySegmentPool;
-import org.apache.flink.table.runtime.util.ResettableExternalBuffer;
import org.apache.flink.table.runtime.util.StreamRecordCollector;
-import org.apache.flink.util.Collector;
-import org.apache.flink.util.MutableObjectIterator;
-
-import java.util.BitSet;
-
-import static org.apache.flink.util.Preconditions.checkNotNull;
/**
* An implementation that realizes the joining through a sort-merge join strategy. 1.In most cases,
@@ -62,488 +36,44 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
public class SortMergeJoinOperator extends TableStreamOperator<RowData>
implements TwoInputStreamOperator<RowData, RowData, RowData>, BoundedMultiInput {
- private final double externalBufferMemRatio;
- private final FlinkJoinType type;
- private final boolean leftIsSmaller;
- private final boolean[] filterNulls;
-
- // generated code to cook
- private GeneratedJoinCondition condFuncCode;
- private GeneratedProjection projectionCode1;
- private GeneratedProjection projectionCode2;
- private GeneratedNormalizedKeyComputer computer1;
- private GeneratedRecordComparator comparator1;
- private GeneratedNormalizedKeyComputer computer2;
- private GeneratedRecordComparator comparator2;
- private GeneratedRecordComparator genKeyComparator;
+ private final SortMergeJoinFunction sortMergeJoinFunction;
- private transient long externalBufferMemory;
- private transient MemoryManager memManager;
- private transient IOManager ioManager;
- private transient BinaryRowDataSerializer serializer1;
- private transient BinaryRowDataSerializer serializer2;
- private transient BinaryExternalSorter sorter1;
- private transient BinaryExternalSorter sorter2;
- private transient Collector<RowData> collector;
- private transient boolean[] isFinished;
- private transient JoinCondition condFunc;
- private transient RecordComparator keyComparator;
- private transient Projection<RowData, BinaryRowData> projection1;
- private transient Projection<RowData, BinaryRowData> projection2;
-
- private transient RowData leftNullRow;
- private transient RowData rightNullRow;
- private transient JoinedRowData joinedRow;
-
- public SortMergeJoinOperator(
- double externalBufferMemRatio,
- FlinkJoinType type,
- boolean leftIsSmaller,
- GeneratedJoinCondition condFuncCode,
- GeneratedProjection projectionCode1,
- GeneratedProjection projectionCode2,
- GeneratedNormalizedKeyComputer computer1,
- GeneratedRecordComparator comparator1,
- GeneratedNormalizedKeyComputer computer2,
- GeneratedRecordComparator comparator2,
- GeneratedRecordComparator genKeyComparator,
- boolean[] filterNulls) {
- this.externalBufferMemRatio = externalBufferMemRatio;
- this.type = type;
- this.leftIsSmaller = leftIsSmaller;
- this.condFuncCode = condFuncCode;
- this.projectionCode1 = projectionCode1;
- this.projectionCode2 = projectionCode2;
- this.computer1 = checkNotNull(computer1);
- this.comparator1 = checkNotNull(comparator1);
- this.computer2 = checkNotNull(computer2);
- this.comparator2 = checkNotNull(comparator2);
- this.genKeyComparator = checkNotNull(genKeyComparator);
- this.filterNulls = filterNulls;
+ public SortMergeJoinOperator(SortMergeJoinFunction sortMergeJoinFunction) {
+ this.sortMergeJoinFunction = sortMergeJoinFunction;
}
@Override
public void open() throws Exception {
super.open();
- Configuration conf = getContainingTask().getJobConfiguration();
-
- isFinished = new boolean[] {false, false};
-
- collector = new StreamRecordCollector<>(output);
-
- ClassLoader cl = getUserCodeClassloader();
- AbstractRowDataSerializer inputSerializer1 =
- (AbstractRowDataSerializer) getOperatorConfig().getTypeSerializerIn1(cl);
- this.serializer1 = new BinaryRowDataSerializer(inputSerializer1.getArity());
-
- AbstractRowDataSerializer inputSerializer2 =
- (AbstractRowDataSerializer) getOperatorConfig().getTypeSerializerIn2(cl);
- this.serializer2 = new BinaryRowDataSerializer(inputSerializer2.getArity());
-
- this.memManager = this.getContainingTask().getEnvironment().getMemoryManager();
- this.ioManager = this.getContainingTask().getEnvironment().getIOManager();
-
- long totalMemory = computeMemorySize();
-
- externalBufferMemory = (long) (totalMemory * externalBufferMemRatio);
- externalBufferMemory =
- Math.max(externalBufferMemory, ResettableExternalBuffer.MIN_NUM_MEMORY);
-
- long totalSortMem =
- totalMemory
- - (type.equals(FlinkJoinType.FULL)
- ? externalBufferMemory * 2
- : externalBufferMemory);
- if (totalSortMem < 0) {
- throw new TableException(
- "Memory size is too small: "
- + totalMemory
- + ", please increase manage memory of task manager.");
- }
-
- // sorter1
- this.sorter1 =
- new BinaryExternalSorter(
- this.getContainingTask(),
- memManager,
- totalSortMem / 2,
- ioManager,
- inputSerializer1,
- serializer1,
- computer1.newInstance(cl),
- comparator1.newInstance(cl),
- conf);
- this.sorter1.startThreads();
-
- // sorter2
- this.sorter2 =
- new BinaryExternalSorter(
- this.getContainingTask(),
- memManager,
- totalSortMem / 2,
- ioManager,
- inputSerializer2,
- serializer2,
- computer2.newInstance(cl),
- comparator2.newInstance(cl),
- conf);
- this.sorter2.startThreads();
-
- keyComparator = genKeyComparator.newInstance(cl);
- this.condFunc = condFuncCode.newInstance(cl);
- condFunc.setRuntimeContext(getRuntimeContext());
- condFunc.open(new Configuration());
-
- projection1 = projectionCode1.newInstance(cl);
- projection2 = projectionCode2.newInstance(cl);
-
- this.leftNullRow = new GenericRowData(serializer1.getArity());
- this.rightNullRow = new GenericRowData(serializer2.getArity());
- this.joinedRow = new JoinedRowData();
-
- condFuncCode = null;
- computer1 = null;
- comparator1 = null;
- computer2 = null;
- comparator2 = null;
- projectionCode1 = null;
- projectionCode2 = null;
- genKeyComparator = null;
-
- getMetricGroup()
- .gauge(
- "memoryUsedSizeInBytes",
- (Gauge<Long>)
- () ->
- sorter1.getUsedMemoryInBytes()
- + sorter2.getUsedMemoryInBytes());
-
- getMetricGroup()
- .gauge(
- "numSpillFiles",
- (Gauge<Long>)
- () -> sorter1.getNumSpillFiles() + sorter2.getNumSpillFiles());
-
- getMetricGroup()
- .gauge(
- "spillInBytes",
- (Gauge<Long>) () -> sorter1.getSpillInBytes() + sorter2.getSpillInBytes());
+ // initialize sort merge join function
+ this.sortMergeJoinFunction.open(
+ this.getContainingTask(),
+ this.getOperatorConfig(),
+ new StreamRecordCollector(output),
+ this.computeMemorySize(),
+ this.getRuntimeContext(),
+ this.getMetricGroup());
}
@Override
public void processElement1(StreamRecord<RowData> element) throws Exception {
- this.sorter1.write(element.getValue());
+ this.sortMergeJoinFunction.processElement1(element.getValue());
}
@Override
public void processElement2(StreamRecord<RowData> element) throws Exception {
- this.sorter2.write(element.getValue());
+ this.sortMergeJoinFunction.processElement2(element.getValue());
}
@Override
public void endInput(int inputId) throws Exception {
- isFinished[inputId - 1] = true;
- if (isAllFinished()) {
- doSortMergeJoin();
- }
- }
-
- private void doSortMergeJoin() throws Exception {
- MutableObjectIterator iterator1 = sorter1.getIterator();
- MutableObjectIterator iterator2 = sorter2.getIterator();
-
- if (type.equals(FlinkJoinType.INNER)) {
- if (!leftIsSmaller) {
- try (SortMergeInnerJoinIterator joinIterator =
- new SortMergeInnerJoinIterator(
- serializer1,
- serializer2,
- projection1,
- projection2,
- keyComparator,
- iterator1,
- iterator2,
- newBuffer(serializer2),
- filterNulls)) {
- innerJoin(joinIterator, false);
- }
- } else {
- try (SortMergeInnerJoinIterator joinIterator =
- new SortMergeInnerJoinIterator(
- serializer2,
- serializer1,
- projection2,
- projection1,
- keyComparator,
- iterator2,
- iterator1,
- newBuffer(serializer1),
- filterNulls)) {
- innerJoin(joinIterator, true);
- }
- }
- } else if (type.equals(FlinkJoinType.LEFT)) {
- try (SortMergeOneSideOuterJoinIterator joinIterator =
- new SortMergeOneSideOuterJoinIterator(
- serializer1,
- serializer2,
- projection1,
- projection2,
- keyComparator,
- iterator1,
- iterator2,
- newBuffer(serializer2),
- filterNulls)) {
- oneSideOuterJoin(joinIterator, false, rightNullRow);
- }
- } else if (type.equals(FlinkJoinType.RIGHT)) {
- try (SortMergeOneSideOuterJoinIterator joinIterator =
- new SortMergeOneSideOuterJoinIterator(
- serializer2,
- serializer1,
- projection2,
- projection1,
- keyComparator,
- iterator2,
- iterator1,
- newBuffer(serializer1),
- filterNulls)) {
- oneSideOuterJoin(joinIterator, true, leftNullRow);
- }
- } else if (type.equals(FlinkJoinType.FULL)) {
- try (SortMergeFullOuterJoinIterator fullOuterJoinIterator =
- new SortMergeFullOuterJoinIterator(
- serializer1,
- serializer2,
- projection1,
- projection2,
- keyComparator,
- iterator1,
- iterator2,
- newBuffer(serializer1),
- newBuffer(serializer2),
- filterNulls)) {
- fullOuterJoin(fullOuterJoinIterator);
- }
- } else if (type.equals(FlinkJoinType.SEMI)) {
- try (SortMergeInnerJoinIterator joinIterator =
- new SortMergeInnerJoinIterator(
- serializer1,
- serializer2,
- projection1,
- projection2,
- keyComparator,
- iterator1,
- iterator2,
- newBuffer(serializer2),
- filterNulls)) {
- while (joinIterator.nextInnerJoin()) {
- RowData probeRow = joinIterator.getProbeRow();
- boolean matched = false;
- try (ResettableExternalBuffer.BufferIterator iter =
- joinIterator.getMatchBuffer().newIterator()) {
- while (iter.advanceNext()) {
- RowData row = iter.getRow();
- if (condFunc.apply(probeRow, row)) {
- matched = true;
- break;
- }
- }
- }
- if (matched) {
- collector.collect(probeRow);
- }
- }
- }
- } else if (type.equals(FlinkJoinType.ANTI)) {
- try (SortMergeOneSideOuterJoinIterator joinIterator =
- new SortMergeOneSideOuterJoinIterator(
- serializer1,
- serializer2,
- projection1,
- projection2,
- keyComparator,
- iterator1,
- iterator2,
- newBuffer(serializer2),
- filterNulls)) {
- while (joinIterator.nextOuterJoin()) {
- RowData probeRow = joinIterator.getProbeRow();
- ResettableExternalBuffer matchBuffer = joinIterator.getMatchBuffer();
- boolean matched = false;
- if (matchBuffer != null) {
- try (ResettableExternalBuffer.BufferIterator iter =
- matchBuffer.newIterator()) {
- while (iter.advanceNext()) {
- RowData row = iter.getRow();
- if (condFunc.apply(probeRow, row)) {
- matched = true;
- break;
- }
- }
- }
- }
- if (!matched) {
- collector.collect(probeRow);
- }
- }
- }
- } else {
- throw new RuntimeException("Not support type: " + type);
- }
- }
-
- private void innerJoin(SortMergeInnerJoinIterator iterator, boolean reverseInvoke)
- throws Exception {
- while (iterator.nextInnerJoin()) {
- RowData probeRow = iterator.getProbeRow();
- ResettableExternalBuffer.BufferIterator iter = iterator.getMatchBuffer().newIterator();
- while (iter.advanceNext()) {
- RowData row = iter.getRow();
- joinWithCondition(probeRow, row, reverseInvoke);
- }
- iter.close();
- }
- }
-
- private void oneSideOuterJoin(
- SortMergeOneSideOuterJoinIterator iterator, boolean reverseInvoke, RowData buildNullRow)
- throws Exception {
- while (iterator.nextOuterJoin()) {
- RowData probeRow = iterator.getProbeRow();
- boolean found = false;
-
- if (iterator.getMatchKey() != null) {
- ResettableExternalBuffer.BufferIterator iter =
- iterator.getMatchBuffer().newIterator();
- while (iter.advanceNext()) {
- RowData row = iter.getRow();
- found |= joinWithCondition(probeRow, row, reverseInvoke);
- }
- iter.close();
- }
-
- if (!found) {
- collect(probeRow, buildNullRow, reverseInvoke);
- }
- }
- }
-
- private void fullOuterJoin(SortMergeFullOuterJoinIterator iterator) throws Exception {
- BitSet bitSet = new BitSet();
-
- while (iterator.nextOuterJoin()) {
-
- bitSet.clear();
- BinaryRowData matchKey = iterator.getMatchKey();
- ResettableExternalBuffer buffer1 = iterator.getBuffer1();
- ResettableExternalBuffer buffer2 = iterator.getBuffer2();
-
- if (matchKey == null && buffer1.size() > 0) { // left outer join.
- ResettableExternalBuffer.BufferIterator iter = buffer1.newIterator();
- while (iter.advanceNext()) {
- RowData row1 = iter.getRow();
- collector.collect(joinedRow.replace(row1, rightNullRow));
- }
- iter.close();
- } else if (matchKey == null && buffer2.size() > 0) { // right outer join.
- ResettableExternalBuffer.BufferIterator iter = buffer2.newIterator();
- while (iter.advanceNext()) {
- RowData row2 = iter.getRow();
- collector.collect(joinedRow.replace(leftNullRow, row2));
- }
- iter.close();
- } else if (matchKey != null) { // match join.
- ResettableExternalBuffer.BufferIterator iter1 = buffer1.newIterator();
- while (iter1.advanceNext()) {
- RowData row1 = iter1.getRow();
- boolean found = false;
- int index = 0;
- ResettableExternalBuffer.BufferIterator iter2 = buffer2.newIterator();
- while (iter2.advanceNext()) {
- RowData row2 = iter2.getRow();
- if (condFunc.apply(row1, row2)) {
- collector.collect(joinedRow.replace(row1, row2));
- found = true;
- bitSet.set(index);
- }
- index++;
- }
- iter2.close();
- if (!found) {
- collector.collect(joinedRow.replace(row1, rightNullRow));
- }
- }
- iter1.close();
-
- // row2 outer
- int index = 0;
- ResettableExternalBuffer.BufferIterator iter2 = buffer2.newIterator();
- while (iter2.advanceNext()) {
- RowData row2 = iter2.getRow();
- if (!bitSet.get(index)) {
- collector.collect(joinedRow.replace(leftNullRow, row2));
- }
- index++;
- }
- iter2.close();
- } else { // bug...
- throw new RuntimeException("There is a bug.");
- }
- }
- }
-
- private boolean joinWithCondition(RowData row1, RowData row2, boolean reverseInvoke)
- throws Exception {
- if (reverseInvoke) {
- if (condFunc.apply(row2, row1)) {
- collector.collect(joinedRow.replace(row2, row1));
- return true;
- }
- } else {
- if (condFunc.apply(row1, row2)) {
- collector.collect(joinedRow.replace(row1, row2));
- return true;
- }
- }
- return false;
- }
-
- private void collect(RowData row1, RowData row2, boolean reverseInvoke) {
- if (reverseInvoke) {
- collector.collect(joinedRow.replace(row2, row1));
- } else {
- collector.collect(joinedRow.replace(row1, row2));
- }
- }
-
- private ResettableExternalBuffer newBuffer(BinaryRowDataSerializer serializer) {
- LazyMemorySegmentPool pool =
- new LazyMemorySegmentPool(
- this.getContainingTask(),
- memManager,
- (int) (externalBufferMemory / memManager.getPageSize()));
- return new ResettableExternalBuffer(
- ioManager,
- pool,
- serializer,
- false /* we don't use newIterator(int beginRow), so don't need use this optimization*/);
- }
-
- private boolean isAllFinished() {
- return isFinished[0] && isFinished[1];
+ this.sortMergeJoinFunction.endInput(inputId);
}
@Override
public void close() throws Exception {
super.close();
- if (this.sorter1 != null) {
- this.sorter1.close();
- }
- if (this.sorter2 != null) {
- this.sorter2.close();
- }
- condFunc.close();
+ this.sortMergeJoinFunction.close();
}
}
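
After this refactoring, SortMergeJoinOperator keeps no join logic of its own; every lifecycle method forwards to the extracted SortMergeJoinFunction, which HashJoinOperator can drive directly when it falls back. A minimal sketch of this delegation pattern (plain Java; the interface and class names are illustrative):

/** Illustrative lifecycle, standing in for the open/process/endInput/close protocol. */
interface JoinFunctionLike<T> {
    void open() throws Exception;
    void process(int inputId, T element) throws Exception;
    void endInput(int inputId) throws Exception;
    void close() throws Exception;
}

/** A thin operator that owns no join logic itself and only delegates, like the new SortMergeJoinOperator. */
class DelegatingJoinOperator<T> {
    private final JoinFunctionLike<T> function;

    DelegatingJoinOperator(JoinFunctionLike<T> function) {
        this.function = function;
    }

    void open() throws Exception { function.open(); }
    void processElement1(T e) throws Exception { function.process(1, e); }
    void processElement2(T e) throws Exception { function.process(2, e); }
    void endInput(int inputId) throws Exception { function.endInput(inputId); }
    void close() throws Exception { function.close(); }

    public static void main(String[] args) throws Exception {
        JoinFunctionLike<String> fn = new JoinFunctionLike<>() {
            public void open() { System.out.println("open"); }
            public void process(int id, String e) { System.out.println("input" + id + ": " + e); }
            public void endInput(int id) { System.out.println("endInput " + id); }
            public void close() { System.out.println("close"); }
        };
        DelegatingJoinOperator<String> op = new DelegatingJoinOperator<>(fn);
        op.open();
        op.processElement1("a");
        op.processElement2("b");
        op.endInput(1);
        op.endInput(2);
        op.close();
    }
}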
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/hashtable/BinaryHashTableTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/hashtable/BinaryHashTableTest.java
index 3686a62ab0a..b9868059497 100644
--- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/hashtable/BinaryHashTableTest.java
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/hashtable/BinaryHashTableTest.java
@@ -562,7 +562,7 @@ public class BinaryHashTableTest {
* fits into memory by itself and needs to be repartitioned in the recursion again.
*/
@Test
- public void testFailingHashJoinTooManyRecursions() throws IOException {
+ public void testSpillingHashJoinWithTooManyRecursions() throws IOException {
// the following two values are known to have a hash-code collision on the first recursion
// level.
// we use them to make sure one partition grows over-proportionally large
@@ -613,12 +613,61 @@ public class BinaryHashTableTest {
896 * PAGE_SIZE,
ioManager);
- try {
- join(table, buildInput, probeInput);
- fail("Hash Join must have failed due to too many recursions.");
- } catch (Exception ex) {
- // expected
+ // create the map for validating the results
+ HashMap<Integer, Long> map = new HashMap<>(numKeys);
+
+ BinaryRowData buildRow = buildSideSerializer.createInstance();
+ while ((buildRow = buildInput.next(buildRow)) != null) {
+ table.putBuildRow(buildRow);
}
+ table.endBuild();
+
+ BinaryRowData probeRow = probeSideSerializer.createInstance();
+ while ((probeRow = probeInput.next(probeRow)) != null) {
+ if (table.tryProbe(probeRow)) {
+ testJoin(table, map);
+ }
+ }
+
+ while (table.nextMatching()) {
+ testJoin(table, map);
+ }
+
+ // The partitions which spilled to disk more than 3 times can't be joined here
+ assertThat(map.size()).as("Wrong number of records in join result.").isLessThan(numKeys);
+
+ // There exist two partitions which spilled to disk more than 3 times
+ assertThat(table.getPartitionsPendingForSMJ().size())
+ .as("Wrong number of spilled partition.")
+ .isEqualTo(2);
+
+ Map<Integer, Integer> spilledPartitionBuildSideKeys = new HashMap<>();
+ Map<Integer, Integer> spilledPartitionProbeSideKeys = new HashMap<>();
+ for (BinaryHashPartition p : table.getPartitionsPendingForSMJ()) {
+ RowIterator<BinaryRowData> buildIter = table.getSpilledPartitionBuildSideIter(p);
+ while (buildIter.advanceNext()) {
+ Integer key = buildIter.getRow().getInt(0);
+ spilledPartitionBuildSideKeys.put(
+ key, spilledPartitionBuildSideKeys.getOrDefault(key, 0) + 1);
+ }
+
+ ProbeIterator probeIter = table.getSpilledPartitionProbeSideIter(p);
+ BinaryRowData rowData;
+ while ((rowData = probeIter.next()) != null) {
+ Integer key = rowData.getInt(0);
+ spilledPartitionProbeSideKeys.put(
+ key, spilledPartitionProbeSideKeys.getOrDefault(key, 0) + 1);
+ }
+ }
+
+ // assert that the spilled partitions contain the keys repeatedValue1 and repeatedValue2
+ Integer buildKeyCnt = repeatedValueCount + buildValsPerKey;
+ assertThat(spilledPartitionBuildSideKeys).containsEntry(repeatedValue1, buildKeyCnt);
+ assertThat(spilledPartitionBuildSideKeys).containsEntry(repeatedValue2, buildKeyCnt);
+
+ Integer probeKeyCnt = repeatedValueCount + probeValsPerKey;
+ assertThat(spilledPartitionProbeSideKeys).containsEntry(repeatedValue1, probeKeyCnt);
+ assertThat(spilledPartitionProbeSideKeys).containsEntry(repeatedValue2, probeKeyCnt);
table.close();
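
The rewritten assertions derive the expected per-key counts from how the inputs are constructed: each colliding key is emitted repeatedValueCount times by the constants iterator and, because it also lies in the uniform key range, buildValsPerKey (respectively probeValsPerKey) more times by the uniform generator. A short sketch re-deriving those numbers (the constants here are illustrative stand-ins; the test defines its own values outside the hunk shown):

/** Re-derives the expected spilled-partition key counts asserted in the test above. */
public class ExpectedSpillCounts {
    public static void main(String[] args) {
        // illustrative stand-ins; the real test defines these constants outside the hunk shown
        int repeatedValueCount = 3_000_000;
        int buildValsPerKey = 3;
        int probeValsPerKey = 10;

        // constants-iterator occurrences + uniform-generator occurrences per colliding key
        int buildKeyCnt = repeatedValueCount + buildValsPerKey;
        int probeKeyCnt = repeatedValueCount + probeValsPerKey;

        System.out.println("expected build-side rows per colliding key: " + buildKeyCnt);
        System.out.println("expected probe-side rows per colliding key: " + probeKeyCnt);
    }
}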
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/hashtable/LongHashTableTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/hashtable/LongHashTableTest.java
index 4ae6506ec50..8ec312ee392 100644
--- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/hashtable/LongHashTableTest.java
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/hashtable/LongHashTableTest.java
@@ -433,7 +433,7 @@ public class LongHashTableTest {
* fits into memory by itself and needs to be repartitioned in the recursion again.
*/
@TestTemplate
- void testFailingHashJoinTooManyRecursions() throws IOException {
+ void testSpillingHashJoinWithTooManyRecursions() throws IOException {
// the following two values are known to have a hash-code collision on the first recursion
// level.
// we use them to make sure one partition grows over-proportionally large
@@ -477,12 +477,61 @@ public class LongHashTableTest {
MutableObjectIterator<BinaryRowData> probeInput = new UnionIterator<>(probes);
final MyHashTable table = new MyHashTable(896 * PAGE_SIZE);
- try {
- join(table, buildInput, probeInput);
- fail("Hash Join must have failed due to too many recursions.");
- } catch (Exception ex) {
- // expected
+ // create the map for validating the results
+ HashMap<Integer, Long> map = new HashMap<>(numKeys);
+
+ BinaryRowData buildRow = buildSideSerializer.createInstance();
+ while ((buildRow = buildInput.next(buildRow)) != null) {
+ table.putBuildRow(buildRow);
}
+ table.endBuild();
+
+ BinaryRowData probeRow = probeSideSerializer.createInstance();
+ while ((probeRow = probeInput.next(probeRow)) != null) {
+ if (table.tryProbe(probeRow)) {
+ testJoin(table, map);
+ }
+ }
+
+ while (table.nextMatching()) {
+ testJoin(table, map);
+ }
+
+ // The partitions which spilled to disk more than 3 times can't be joined here
+ assertThat(map.size()).as("Wrong number of records in join result.").isLessThan(numKeys);
+
+ // There exist two partitions which spilled to disk more than 3 times
+ assertThat(table.getPartitionsPendingForSMJ().size())
+ .as("Wrong number of spilled partition.")
+ .isEqualTo(2);
+
+ Map<Integer, Integer> spilledPartitionBuildSideKeys = new HashMap<>();
+ Map<Integer, Integer> spilledPartitionProbeSideKeys = new HashMap<>();
+ for (LongHashPartition p : table.getPartitionsPendingForSMJ()) {
+ RowIterator<BinaryRowData> buildIter = table.getSpilledPartitionBuildSideIter(p);
+ while (buildIter.advanceNext()) {
+ Integer key = buildIter.getRow().getInt(0);
+ spilledPartitionBuildSideKeys.put(
+ key, spilledPartitionBuildSideKeys.getOrDefault(key, 0) + 1);
+ }
+
+ ProbeIterator probeIter = table.getSpilledPartitionProbeSideIter(p);
+ BinaryRowData rowData;
+ while ((rowData = probeIter.next()) != null) {
+ Integer key = rowData.getInt(0);
+ spilledPartitionProbeSideKeys.put(
+ key, spilledPartitionProbeSideKeys.getOrDefault(key, 0) + 1);
+ }
+ }
+
+ // assert that the spilled partitions contain the keys repeatedValue1 and repeatedValue2
+ Integer buildKeyCnt = repeatedValueCount + buildValsPerKey;
+ assertThat(spilledPartitionBuildSideKeys).containsEntry(repeatedValue1, buildKeyCnt);
+ assertThat(spilledPartitionBuildSideKeys).containsEntry(repeatedValue2, buildKeyCnt);
+
+ Integer probeKeyCnt = repeatedValueCount + probeValsPerKey;
+ assertThat(spilledPartitionProbeSideKeys).containsEntry(repeatedValue1, probeKeyCnt);
+ assertThat(spilledPartitionProbeSideKeys).containsEntry(repeatedValue2, probeKeyCnt);
table.close();
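
The new Int2AdaptiveHashJoinOperatorTest below asserts expected output sizes computed from the generator parameters. For the build-first inner join case, every build key also occurs on the probe side (numKeys2 > numKeys1), while the two colliding keys lie outside the probe key range and add no matches, so each of the numKeys1 keys contributes buildValsPerKey x probeValsPerKey joined rows. A worked check with the numbers from that test:

/** Worked check of the inner-join expected output size used in the adaptive hash join test below. */
public class InnerJoinExpectedSize {
    public static void main(String[] args) {
        int numKeys1 = 100_000;  // distinct build-side keys
        int buildValsPerKey = 3; // build rows per key
        int probeValsPerKey = 2; // probe rows per key
        // every matching key yields the cross product of its build and probe rows
        int expectOutSize = numKeys1 * buildValsPerKey * probeValsPerKey;
        System.out.println(expectOutSize); // 600000
    }
}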
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2AdaptiveHashJoinOperatorTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2AdaptiveHashJoinOperatorTest.java
new file mode 100644
index 00000000000..5ca980c7ca9
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2AdaptiveHashJoinOperatorTest.java
@@ -0,0 +1,450 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.join;
+
+import org.apache.flink.runtime.operators.testutils.UnionIterator;
+import org.apache.flink.table.data.binary.BinaryRowData;
+import org.apache.flink.table.runtime.hashtable.BinaryHashTableTest;
+import org.apache.flink.table.runtime.util.UniformBinaryRowGenerator;
+import org.apache.flink.util.MutableObjectIterator;
+
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/** Tests for the adaptive {@link HashJoinOperator} that falls back to sort merge join. */
+public class Int2AdaptiveHashJoinOperatorTest extends Int2HashJoinOperatorTestBase {
+
+ // ---------------------- build first inner join -----------------------------------------
+ // ------------------- fallback to sort merge join in build or probe phase ---------------
+ @Test
+ public void testBuildFirstHashInnerJoinFallbackToSMJ() throws Exception {
+ // the following two values are known to have a hash-code collision on the first recursion
+ // level. we use them to make sure one partition grows over-proportionally large
+ final int repeatedValue1 = 405590;
+ final int repeatedValue2 = 928820;
+ final int repeatedValueCountBuild = 100000;
+
+ final int numKeys1 = 100000;
+ final int numKeys2 = 160000;
+ final int buildValsPerKey = 3;
+ final int probeValsPerKey = 2;
+
+ // create a build input of numKeys1 keys with buildValsPerKey values per key, plus
+ // 2 * repeatedValueCountBuild rows across the two colliding keys
+ MutableObjectIterator<BinaryRowData> build1 =
+ new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, false);
+ MutableObjectIterator<BinaryRowData> build2 =
+ new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+ repeatedValue1, 17, repeatedValueCountBuild);
+ MutableObjectIterator<BinaryRowData> build3 =
+ new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+ repeatedValue2, 23, repeatedValueCountBuild);
+ List<MutableObjectIterator<BinaryRowData>> builds = new ArrayList<>();
+ builds.add(build1);
+ builds.add(build2);
+ builds.add(build3);
+ MutableObjectIterator<BinaryRowData> buildInput = new UnionIterator<>(builds);
+
+ MutableObjectIterator<BinaryRowData> probeInput =
+ new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+ // output the build side rows whose keys are matched by the probe side
+ int expectOutSize = numKeys1 * buildValsPerKey * probeValsPerKey;
+ buildJoin(buildInput, probeInput, false, false, true, expectOutSize, numKeys1, -1);
+ }
+
+ // ---------------------- build first left out join -----------------------------------------
+ // -------------------- fallback to sort merge join in build or probe phase -----------------
+ @Test
+ public void testBuildFirstHashLeftOutJoinFallbackToSMJ() throws Exception {
+ // the following two values are known to have a hash-code collision on the first recursion
+ // level. we use them to make sure one partition grows over-proportionally large
+ final int repeatedValue1 = 40559;
+ final int repeatedValue2 = 92882;
+ final int repeatedValueCountBuild = 100000;
+
+ final int numKeys1 = 100000;
+ final int numKeys2 = 50000;
+ final int buildValsPerKey = 3;
+ final int probeValsPerKey = 2;
+
+ // create a build input of numKeys1 keys with buildValsPerKey values per key, plus
+ // 2 * repeatedValueCountBuild rows across the two colliding keys
+ MutableObjectIterator<BinaryRowData> build1 =
+ new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, false);
+ MutableObjectIterator<BinaryRowData> build2 =
+ new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+ repeatedValue1, 17, repeatedValueCountBuild);
+ MutableObjectIterator<BinaryRowData> build3 =
+ new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+ repeatedValue2, 23, repeatedValueCountBuild);
+ List<MutableObjectIterator<BinaryRowData>> builds = new ArrayList<>();
+ builds.add(build1);
+ builds.add(build2);
+ builds.add(build3);
+ MutableObjectIterator<BinaryRowData> buildInput = new UnionIterator<>(builds);
+
+ MutableObjectIterator<BinaryRowData> probeInput =
+ new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+ // output the build side rows, which are on the left
+ int expectOutSize =
+ (numKeys1 - numKeys2) * buildValsPerKey * probeValsPerKey
+ + (numKeys1 - numKeys2) * buildValsPerKey
+ + repeatedValueCountBuild * probeValsPerKey
+ + repeatedValueCountBuild;
+ buildJoin(buildInput, probeInput, true, false, true, expectOutSize, numKeys1, -1);
+ }
+
+ // ---------------------- build first right out join -----------------------------------------
+ // --------------------- fallback to sort merge join in build or probe phase -----------------
+ @Test
+ public void testBuildFirstHashRightOutJoinFallbackToSMJ() throws Exception {
+ // the following two values are known to have a hash-code collision on the first recursion
+ // level. we use them to make sure one partition grows over-proportionally large
+ final int repeatedValue1 = 40559;
+ final int repeatedValue2 = 92882;
+ final int repeatedValueCountBuild = 100000;
+
+ final int numKeys1 = 100000;
+ final int numKeys2 = 50000;
+ final int buildValsPerKey = 3;
+ final int probeValsPerKey = 1;
+
+ // create a build input of numKeys1 keys with buildValsPerKey values per key, plus
+ // 2 * repeatedValueCountBuild rows across the two colliding keys
+ MutableObjectIterator<BinaryRowData> build1 =
+ new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, false);
+ MutableObjectIterator<BinaryRowData> build2 =
+ new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+ repeatedValue1, 17, repeatedValueCountBuild);
+ MutableObjectIterator<BinaryRowData> build3 =
+ new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+ repeatedValue2, 23, repeatedValueCountBuild);
+ List<MutableObjectIterator<BinaryRowData>> builds = new ArrayList<>();
+ builds.add(build1);
+ builds.add(build2);
+ builds.add(build3);
+ MutableObjectIterator<BinaryRowData> buildInput = new UnionIterator<>(builds);
+
+ MutableObjectIterator<BinaryRowData> probeInput =
+ new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+ // output the probe side rows, which are on the right
+ int expectOutSize = numKeys2 * probeValsPerKey * buildValsPerKey + repeatedValueCountBuild;
+ buildJoin(buildInput, probeInput, false, true, true, expectOutSize, numKeys2, -1);
+ }
+
+ // ---------------------- build first full out join -----------------------------------------
+ // --------------------- fallback to sort merge join in build or probe phase ----------------
+ @Test
+ public void testBuildFirstHashFullOutJoinFallbackToSMJ() throws Exception {
+ // the following two values are known to have a hash-code collision on the first recursion
+ // level. we use them to make sure one partition grows over-proportionally large
+ final int repeatedValue1 = 40559;
+ final int repeatedValue2 = 92882;
+ final int repeatedValueCountBuild = 100000;
+
+ final int numKeys1 = 100000;
+ final int numKeys2 = 150000;
+ final int buildValsPerKey = 3;
+ final int probeValsPerKey = 1;
+
+ // create a build input of numKeys1 keys with buildValsPerKey values per key, plus
+ // 2 * repeatedValueCountBuild rows across the two colliding keys
+ MutableObjectIterator<BinaryRowData> build1 =
+ new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, false);
+ MutableObjectIterator<BinaryRowData> build2 =
+ new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+ repeatedValue1, 17, repeatedValueCountBuild);
+ MutableObjectIterator<BinaryRowData> build3 =
+ new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+ repeatedValue2, 23, repeatedValueCountBuild);
+ List<MutableObjectIterator<BinaryRowData>> builds = new ArrayList<>();
+ builds.add(build1);
+ builds.add(build2);
+ builds.add(build3);
+ MutableObjectIterator<BinaryRowData> buildInput = new UnionIterator<>(builds);
+
+ MutableObjectIterator<BinaryRowData> probeInput =
+ new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+ // output both the build side and the probe side rows
+ int expectOutSize =
+ numKeys1 * buildValsPerKey * probeValsPerKey
+ + repeatedValueCountBuild * 2
+ + numKeys2
+ - numKeys1;
+ buildJoin(buildInput, probeInput, true, true, true, expectOutSize, numKeys2, -1);
+ }
+
+ // ---------------------- build second left out join -----------------------------------------
+ // ---------------------- switch to sort merge join in build or probe phase ------------------
+ @Test
+ public void testBuildSecondHashLeftOutJoinFallbackToSMJ() throws Exception {
+ // the following two values are known to have a hash-code collision on the first recursion
+ // level. we use them to make sure one partition grows over-proportionally large
+ final int repeatedValue1 = 40559;
+ final int repeatedValue2 = 92882;
+ final int repeatedValueCountBuild = 100000;
+
+ final int numKeys1 = 100000;
+ final int numKeys2 = 50000;
+ final int buildValsPerKey = 3;
+ final int probeValsPerKey = 1;
+
+ // create a build input of numKeys1 keys with buildValsPerKey values per key, plus
+ // 2 * repeatedValueCountBuild rows across the two colliding keys
+ MutableObjectIterator<BinaryRowData> build1 =
+ new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, false);
+ MutableObjectIterator<BinaryRowData> build2 =
+ new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+ repeatedValue1, 17, repeatedValueCountBuild);
+ MutableObjectIterator<BinaryRowData> build3 =
+ new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+ repeatedValue2, 23, repeatedValueCountBuild);
+ List<MutableObjectIterator<BinaryRowData>> builds = new ArrayList<>();
+ builds.add(build1);
+ builds.add(build2);
+ builds.add(build3);
+ MutableObjectIterator<BinaryRowData> buildInput = new UnionIterator<>(builds);
+
+ MutableObjectIterator<BinaryRowData> probeInput =
+ new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+ // output the probe side rows, which are on the left
+ int expectOutSize = numKeys2 * probeValsPerKey * buildValsPerKey + repeatedValueCountBuild;
+ buildJoin(buildInput, probeInput, true, false, false, expectOutSize, numKeys2, -1);
+ }
+
+ // ---------------------- build second right out join -----------------------------------------
+ // ---------------------- switch to sort merge join in build or probe phase -------------------
+ @Test
+ public void testBuildSecondHashRightOutJoinFallbackToSMJ() throws Exception {
+ // the following two values are known to have a hash-code collision on the first recursion
+ // level. we use them to make sure one partition grows over-proportionally large
+ final int repeatedValue1 = 40559;
+ final int repeatedValue2 = 92882;
+ final int repeatedValueCountBuild = 100000;
+
+ final int numKeys1 = 100000;
+ final int numKeys2 = 50000;
+ final int buildValsPerKey = 3;
+ final int probeValsPerKey = 1;
+
+ // create a build input of numKeys1 keys with buildValsPerKey values per key, plus
+ // 2 * repeatedValueCountBuild rows across the two colliding keys
+ MutableObjectIterator<BinaryRowData> build1 =
+ new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, false);
+ MutableObjectIterator<BinaryRowData> build2 =
+ new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+ repeatedValue1, 17, repeatedValueCountBuild);
+ MutableObjectIterator<BinaryRowData> build3 =
+ new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+ repeatedValue2, 23, repeatedValueCountBuild);
+ List<MutableObjectIterator<BinaryRowData>> builds = new ArrayList<>();
+ builds.add(build1);
+ builds.add(build2);
+ builds.add(build3);
+ MutableObjectIterator<BinaryRowData> buildInput = new UnionIterator<>(builds);
+
+ MutableObjectIterator<BinaryRowData> probeInput =
+ new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+ // output the build side rows, which are on the right
+ int expectOutSize = numKeys1 * buildValsPerKey + repeatedValueCountBuild * 2;
+ buildJoin(buildInput, probeInput, false, true, false, expectOutSize, numKeys1, -1);
+ }
+
+ // ---------------------- switch to sort merge join in build or probe phase ------------------
+ @Test
+ public void testSemiJoinFallbackToSMJ() throws Exception {
+ // the following two values are known to have a hash-code collision on the first recursion
+ // level. we use them to make sure one partition grows over-proportionally large
+ final int repeatedValue1 = 40559;
+ final int repeatedValue2 = 92882;
+ final int repeatedValueCountBuild = 100000;
+
+ final int numKeys1 = 100000;
+ final int numKeys2 = 100000;
+ final int buildValsPerKey = 3;
+ final int probeValsPerKey = 1;
+
+ // create a build input of numKeys1 keys with buildValsPerKey values per key, plus
+ // 2 * repeatedValueCountBuild rows across the two colliding keys
+ MutableObjectIterator<BinaryRowData> build1 =
+ new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, false);
+ MutableObjectIterator<BinaryRowData> build2 =
+ new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+ repeatedValue1, 17, repeatedValueCountBuild);
+ MutableObjectIterator<BinaryRowData> build3 =
+ new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+ repeatedValue2, 23, repeatedValueCountBuild);
+ List<MutableObjectIterator<BinaryRowData>> builds = new ArrayList<>();
+ builds.add(build1);
+ builds.add(build2);
+ builds.add(build3);
+ MutableObjectIterator<BinaryRowData> buildInput = new UnionIterator<>(builds);
+
+ MutableObjectIterator<BinaryRowData> probeInput =
+ new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+ Object operator =
+ newOperator(33 * 32 * 1024, FlinkJoinType.SEMI, HashJoinType.SEMI, false, false);
+
+ // output the probe side rows, which are on the left
+ joinAndAssert(operator, buildInput, probeInput, numKeys2, numKeys2, 0, true);
+ }
+
+ // ---------------------- fallback to sort merge join in build or probe phase ------------------
+ @Test
+ public void testAntiJoinFallbackToSMJ() throws Exception {
+ // the following two values are known to have a hash-code collision on the first recursion
+ // level. we use them to make sure one partition grows over-proportionally large
+ final int repeatedValue1 = 40559;
+ final int repeatedValue2 = 92882;
+ final int repeatedValueCountBuild = 100000;
+
+ final int numKeys1 = 100000;
+ final int numKeys2 = 160000;
+ final int buildValsPerKey = 3;
+ final int probeValsPerKey = 1;
+
+ // create a build input of numKeys1 keys with buildValsPerKey values per key, plus
+ // 2 * repeatedValueCountBuild rows across the two colliding keys
+ MutableObjectIterator<BinaryRowData> build1 =
+ new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, false);
+ MutableObjectIterator<BinaryRowData> build2 =
+ new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+ repeatedValue1, 17, repeatedValueCountBuild);
+ MutableObjectIterator<BinaryRowData> build3 =
+ new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+ repeatedValue2, 23, repeatedValueCountBuild);
+ List<MutableObjectIterator<BinaryRowData>> builds = new ArrayList<>();
+ builds.add(build1);
+ builds.add(build2);
+ builds.add(build3);
+ MutableObjectIterator<BinaryRowData> buildInput = new UnionIterator<>(builds);
+
+ MutableObjectIterator<BinaryRowData> probeInput =
+ new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+ Object operator =
+ newOperator(33 * 32 * 1024, FlinkJoinType.ANTI, HashJoinType.ANTI, false, false);
+
+ // output the probe side, which is the left input
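+ // (only the numKeys2 - numKeys1 probe keys without a build-side match survive
+ // the anti join)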
+ int expectOutSize = numKeys2 - numKeys1;
+ joinAndAssert(operator, buildInput, probeInput, expectOutSize, expectOutSize, 0, true);
+ }
+
+ // ---------------------- fallback to sort merge join in build or probe phase ------------------
+ @Test
+ public void testBuildLeftSemiJoinFallbackToSMJ() throws Exception {
+ // the following two values are known to have a hash-code collision on the first recursion
+ // level. We use them to make sure one partition grows disproportionately large.
+ final int repeatedValue1 = 405;
+ final int repeatedValue2 = 928;
+ final int repeatedValueCountBuild = 100000;
+
+ final int numKeys1 = 100000;
+ final int numKeys2 = 1000;
+ final int buildValsPerKey = 1;
+ final int probeValsPerKey = 3;
+
+ // create a build input with 100k keys and one value per key, plus 200k rows
+ // spread over the two colliding keys
+ MutableObjectIterator<BinaryRowData> build1 =
+ new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, false);
+ MutableObjectIterator<BinaryRowData> build2 =
+ new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+ repeatedValue1, 17, repeatedValueCountBuild);
+ MutableObjectIterator<BinaryRowData> build3 =
+ new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+ repeatedValue2, 23, repeatedValueCountBuild);
+ List<MutableObjectIterator<BinaryRowData>> builds = new ArrayList<>();
+ builds.add(build1);
+ builds.add(build2);
+ builds.add(build3);
+ MutableObjectIterator<BinaryRowData> buildInput = new UnionIterator<>(builds);
+
+ MutableObjectIterator<BinaryRowData> probeInput =
+ new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+ Object operator =
+ newOperator(
+ 33 * 32 * 1024,
+ FlinkJoinType.SEMI,
+ HashJoinType.BUILD_LEFT_SEMI,
+ true,
+ false);
+
+ // output the build-side rows (left input) that have a match on the probe side
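+ // (one uniform build row per probe key gives numKeys2 rows; both heavy keys
+ // lie below numKeys2, so their repeatedValueCountBuild rows each match as well)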
+ int expectOutSize = numKeys2 + repeatedValueCountBuild * 2;
+ joinAndAssert(operator, buildInput, probeInput, expectOutSize, numKeys1, -1, true);
+ }
+
+ // ---------------------- fallback to sort merge join in build or probe phase ------------------
+ @Test
+ public void testBuildLeftAntiJoinFallbackToSMJ() throws Exception {
+ // the following two values are known to have a hash-code collision on the first recursion
+ // level. We use them to make sure one partition grows disproportionately large.
+ final int repeatedValue1 = 40559;
+ final int repeatedValue2 = 92882;
+ final int repeatedValueCountBuild = 100000;
+
+ final int numKeys1 = 500000;
+ final int numKeys2 = 100000;
+ final int buildValsPerKey = 1;
+ final int probeValsPerKey = 3;
+
+ // create a build input with 500k keys and one value per key, plus 200k rows
+ // spread over the two colliding keys
+ MutableObjectIterator<BinaryRowData> build1 =
+ new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, false);
+ MutableObjectIterator<BinaryRowData> build2 =
+ new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+ repeatedValue1, 17, repeatedValueCountBuild);
+ MutableObjectIterator<BinaryRowData> build3 =
+ new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+ repeatedValue2, 23, repeatedValueCountBuild);
+ List<MutableObjectIterator<BinaryRowData>> builds = new ArrayList<>();
+ builds.add(build1);
+ builds.add(build2);
+ builds.add(build3);
+ MutableObjectIterator<BinaryRowData> buildInput = new UnionIterator<>(builds);
+
+ MutableObjectIterator<BinaryRowData> probeInput =
+ new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+ Object operator =
+ newOperator(
+ 33 * 32 * 1024,
+ FlinkJoinType.ANTI,
+ HashJoinType.BUILD_LEFT_ANTI,
+ true,
+ false);
+
+ // output the build-side rows (left input) that have no match on the probe side
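+ // (the numKeys1 - numKeys2 build keys above the probe key range never match;
+ // both heavy keys lie inside the probe range, so their rows are filtered out)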
+ int expectOutSize = numKeys1 - numKeys2;
+ joinAndAssert(
+ operator, buildInput, probeInput, expectOutSize, numKeys1 - numKeys2, -1, true);
+ }
+}
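
The tests above exercise the new adaptive path: the hash join keeps processing
a partition as long as it fits in memory, and a partition that still cannot be
built after spilling is handed to the sort-merge join instead. Broadly, the
decision looks like the following toy sketch (all names here are hypothetical,
not the patch's actual API):

    /** Toy illustration of the adaptive fallback decision; names are made up. */
    final class AdaptiveFallbackSketch {
        static final int MAX_RECURSION_DEPTH = 3;

        /** True if the partition should be joined with sort-merge instead of re-hashing. */
        static boolean fallBackToSortMerge(
                int recursionLevel, long partitionBytes, long memoryBudget) {
            // stop recursive re-partitioning once the depth limit is reached
            // while the partition still does not fit into the memory budget
            return recursionLevel >= MAX_RECURSION_DEPTH && partitionBytes > memoryBudget;
        }
    }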
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTest.java
index af589b95954..dd4da788526 100644
--- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTest.java
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTest.java
@@ -18,42 +18,14 @@
package org.apache.flink.table.runtime.operators.join;
-import org.apache.flink.api.common.functions.AbstractRichFunction;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.core.memory.ManagedMemoryUseCase;
-import org.apache.flink.runtime.jobgraph.OperatorID;
-import org.apache.flink.streaming.api.operators.StreamOperator;
-import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
-import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.apache.flink.streaming.runtime.tasks.TwoInputStreamTask;
-import org.apache.flink.streaming.runtime.tasks.TwoInputStreamTaskTestHarness;
-import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.binary.BinaryRowData;
-import org.apache.flink.table.data.utils.JoinedRowData;
-import org.apache.flink.table.data.writer.BinaryRowWriter;
-import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
-import org.apache.flink.table.runtime.generated.GeneratedProjection;
-import org.apache.flink.table.runtime.generated.JoinCondition;
-import org.apache.flink.table.runtime.generated.Projection;
-import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.runtime.util.UniformBinaryRowGenerator;
-import org.apache.flink.table.types.logical.IntType;
-import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.util.MutableObjectIterator;
import org.junit.Test;
-import java.io.Serializable;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Queue;
-import java.util.Random;
-
-import static java.lang.Long.valueOf;
-import static org.assertj.core.api.Assertions.assertThat;
-
/** Random test for {@link HashJoinOperator}. */
-public class Int2HashJoinOperatorTest implements Serializable {
+public class Int2HashJoinOperatorTest extends Int2HashJoinOperatorTestBase {
// ---------------------- build first inner join -----------------------------------------
@Test
@@ -233,8 +205,8 @@ public class Int2HashJoinOperatorTest implements Serializable {
MutableObjectIterator<BinaryRowData> probeInput =
new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
- HashJoinType type = HashJoinType.SEMI;
- Object operator = newOperator(33 * 32 * 1024, type, false);
+ Object operator =
+ newOperator(33 * 32 * 1024, FlinkJoinType.SEMI, HashJoinType.SEMI, false, false);
joinAndAssert(operator, buildInput, probeInput, 90, 9, 45, true);
}
@@ -250,14 +222,13 @@ public class Int2HashJoinOperatorTest implements Serializable {
MutableObjectIterator<BinaryRowData> probeInput =
new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
- HashJoinType type = HashJoinType.ANTI;
- Object operator = newOperator(33 * 32 * 1024, type, false);
+ Object operator =
+ newOperator(33 * 32 * 1024, FlinkJoinType.ANTI, HashJoinType.ANTI, false, false);
joinAndAssert(operator, buildInput, probeInput, 10, 1, 45, true);
}
@Test
public void testBuildLeftSemiJoin() throws Exception {
-
int numKeys1 = 10;
int numKeys2 = 9;
int buildValsPerKey = 10;
@@ -267,14 +238,18 @@ public class Int2HashJoinOperatorTest implements Serializable {
MutableObjectIterator<BinaryRowData> probeInput =
new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
- HashJoinType type = HashJoinType.BUILD_LEFT_SEMI;
- Object operator = newOperator(33 * 32 * 1024, type, false);
+ Object operator =
+ newOperator(
+ 33 * 32 * 1024,
+ FlinkJoinType.SEMI,
+ HashJoinType.BUILD_LEFT_SEMI,
+ true,
+ false);
joinAndAssert(operator, buildInput, probeInput, 90, 9, 45, true);
}
@Test
public void testBuildLeftAntiJoin() throws Exception {
-
int numKeys1 = 10;
int numKeys2 = 9;
int buildValsPerKey = 10;
@@ -284,237 +259,13 @@ public class Int2HashJoinOperatorTest implements Serializable {
MutableObjectIterator<BinaryRowData> probeInput =
new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
- HashJoinType type = HashJoinType.BUILD_LEFT_ANTI;
- Object operator = newOperator(33 * 32 * 1024, type, false);
+ Object operator =
+ newOperator(
+ 33 * 32 * 1024,
+ FlinkJoinType.ANTI,
+ HashJoinType.BUILD_LEFT_ANTI,
+ true,
+ false);
joinAndAssert(operator, buildInput, probeInput, 10, 1, 45, true);
}
-
- private void buildJoin(
- MutableObjectIterator<BinaryRowData> buildInput,
- MutableObjectIterator<BinaryRowData> probeInput,
- boolean leftOut,
- boolean rightOut,
- boolean buildLeft,
- int expectOutSize,
- int expectOutKeySize,
- int expectOutVal)
- throws Exception {
- HashJoinType type = HashJoinType.of(buildLeft, leftOut, rightOut);
- Object operator = newOperator(33 * 32 * 1024, type, !buildLeft);
- joinAndAssert(
- operator,
- buildInput,
- probeInput,
- expectOutSize,
- expectOutKeySize,
- expectOutVal,
- false);
- }
-
- @SuppressWarnings("unchecked")
- static void joinAndAssert(
- Object operator,
- MutableObjectIterator<BinaryRowData> input1,
- MutableObjectIterator<BinaryRowData> input2,
- int expectOutSize,
- int expectOutKeySize,
- int expectOutVal,
- boolean semiJoin)
- throws Exception {
- InternalTypeInfo<RowData> typeInfo =
- InternalTypeInfo.ofFields(new IntType(), new IntType());
- InternalTypeInfo<RowData> rowDataTypeInfo =
- InternalTypeInfo.ofFields(
- new IntType(), new IntType(), new IntType(), new IntType());
- TwoInputStreamTaskTestHarness<BinaryRowData, BinaryRowData, JoinedRowData> testHarness =
- new TwoInputStreamTaskTestHarness<>(
- TwoInputStreamTask::new,
- 2,
- 1,
- new int[] {1, 2},
- typeInfo,
- (TypeInformation) typeInfo,
- rowDataTypeInfo);
- testHarness.memorySize = 36 * 1024 * 1024;
- testHarness.getExecutionConfig().enableObjectReuse();
- testHarness.setupOutputForSingletonOperatorChain();
- if (operator instanceof StreamOperator) {
- testHarness.getStreamConfig().setStreamOperator((StreamOperator<?>) operator);
- } else {
- testHarness
- .getStreamConfig()
- .setStreamOperatorFactory((StreamOperatorFactory<?>) operator);
- }
- testHarness.getStreamConfig().setOperatorID(new OperatorID());
- testHarness
- .getStreamConfig()
- .setManagedMemoryFractionOperatorOfUseCase(ManagedMemoryUseCase.OPERATOR, 0.99);
-
- testHarness.invoke();
- testHarness.waitForTaskRunning();
-
- Random random = new Random();
- do {
- BinaryRowData row1 = null;
- BinaryRowData row2 = null;
-
- if (random.nextInt(2) == 0) {
- row1 = input1.next();
- if (row1 == null) {
- row2 = input2.next();
- }
- } else {
- row2 = input2.next();
- if (row2 == null) {
- row1 = input1.next();
- }
- }
-
- if (row1 == null && row2 == null) {
- break;
- }
-
- if (row1 != null) {
- testHarness.processElement(new StreamRecord<>(row1), 0, 0);
- } else {
- testHarness.processElement(new StreamRecord<>(row2), 1, 0);
- }
- } while (true);
-
- testHarness.endInput(0, 0);
- testHarness.endInput(1, 0);
-
- testHarness.waitForInputProcessing();
- testHarness.waitForTaskCompletion();
-
- Queue<Object> actual = testHarness.getOutput();
-
- assertThat(actual).as("Output was not correct.").hasSize(expectOutSize);
-
- // Don't verify the output value when experOutVal is -1
- if (expectOutVal != -1) {
- if (semiJoin) {
- HashMap<Integer, Long> map = new HashMap<>(expectOutKeySize);
-
- for (Object o : actual) {
- StreamRecord<RowData> record = (StreamRecord<RowData>) o;
- RowData row = record.getValue();
- int key = row.getInt(0);
- int val = row.getInt(1);
- Long contained = map.get(key);
- if (contained == null) {
- contained = (long) val;
- } else {
- contained = valueOf(contained + val);
- }
- map.put(key, contained);
- }
-
- assertThat(map).as("Wrong number of keys").hasSize(expectOutKeySize);
- for (Map.Entry<Integer, Long> entry : map.entrySet()) {
- long val = entry.getValue();
- int key = entry.getKey();
-
- assertThat(val)
- .as("Wrong number of values in per-key cross product for key " + key)
- .isEqualTo(expectOutVal);
- }
- } else {
- // create the map for validating the results
- HashMap<Integer, Long> map = new HashMap<>(expectOutKeySize);
-
- for (Object o : actual) {
- StreamRecord<RowData> record = (StreamRecord<RowData>) o;
- RowData row = record.getValue();
- int key = row.isNullAt(0) ? row.getInt(2) : row.getInt(0);
-
- int val1 = 0;
- int val2 = 0;
- if (!row.isNullAt(1)) {
- val1 = row.getInt(1);
- }
- if (!row.isNullAt(3)) {
- val2 = row.getInt(3);
- }
- int val = val1 + val2;
-
- Long contained = map.get(key);
- if (contained == null) {
- contained = (long) val;
- } else {
- contained = valueOf(contained + val);
- }
- map.put(key, contained);
- }
-
- assertThat(map).as("Wrong number of keys").hasSize(expectOutKeySize);
- for (Map.Entry<Integer, Long> entry : map.entrySet()) {
- long val = entry.getValue();
- int key = entry.getKey();
-
- assertThat(val)
- .as("Wrong number of values in per-key cross product for key " + key)
- .isEqualTo(expectOutVal);
- }
- }
- }
- }
-
- /** my projection. */
- public static final class MyProjection implements Projection<RowData, BinaryRowData> {
-
- BinaryRowData innerRow = new BinaryRowData(1);
- BinaryRowWriter writer = new BinaryRowWriter(innerRow);
-
- @Override
- public BinaryRowData apply(RowData row) {
- writer.reset();
- if (row.isNullAt(0)) {
- writer.setNullAt(0);
- } else {
- writer.writeInt(0, row.getInt(0));
- }
- writer.complete();
- return innerRow;
- }
- }
-
- public Object newOperator(long memorySize, HashJoinType type, boolean reverseJoinFunction) {
- return HashJoinOperator.newHashJoinOperator(
- type,
- new GeneratedJoinCondition("", "", new Object[0]) {
- @Override
- public JoinCondition newInstance(ClassLoader classLoader) {
- return new TrueCondition();
- }
- },
- reverseJoinFunction,
- new boolean[] {true},
- new GeneratedProjection("", "", new Object[0]) {
- @Override
- public Projection newInstance(ClassLoader classLoader) {
- return new MyProjection();
- }
- },
- new GeneratedProjection("", "", new Object[0]) {
- @Override
- public Projection newInstance(ClassLoader classLoader) {
- return new MyProjection();
- }
- },
- false,
- 20,
- 10000,
- 10000,
- RowType.of(new IntType()));
- }
-
- /** Test util. */
- public static class TrueCondition extends AbstractRichFunction implements JoinCondition {
-
- @Override
- public boolean apply(RowData in1, RowData in2) {
- return true;
- }
- }
}
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTestBase.java
similarity index 52%
copy from flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTest.java
copy to flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTestBase.java
index af589b95954..0988d272dfc 100644
--- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTest.java
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTestBase.java
@@ -32,17 +32,20 @@ import org.apache.flink.table.data.binary.BinaryRowData;
import org.apache.flink.table.data.utils.JoinedRowData;
import org.apache.flink.table.data.writer.BinaryRowWriter;
import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
+import org.apache.flink.table.runtime.generated.GeneratedNormalizedKeyComputer;
import org.apache.flink.table.runtime.generated.GeneratedProjection;
+import org.apache.flink.table.runtime.generated.GeneratedRecordComparator;
import org.apache.flink.table.runtime.generated.JoinCondition;
+import org.apache.flink.table.runtime.generated.NormalizedKeyComputer;
import org.apache.flink.table.runtime.generated.Projection;
+import org.apache.flink.table.runtime.generated.RecordComparator;
+import org.apache.flink.table.runtime.operators.sort.IntNormalizedKeyComputer;
+import org.apache.flink.table.runtime.operators.sort.IntRecordComparator;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
-import org.apache.flink.table.runtime.util.UniformBinaryRowGenerator;
import org.apache.flink.table.types.logical.IntType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.util.MutableObjectIterator;
-import org.junit.Test;
-
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
@@ -50,246 +53,13 @@ import java.util.Queue;
import java.util.Random;
import static java.lang.Long.valueOf;
+import static org.apache.flink.table.runtime.util.JoinUtil.getJoinType;
import static org.assertj.core.api.Assertions.assertThat;
-/** Random test for {@link HashJoinOperator}. */
-public class Int2HashJoinOperatorTest implements Serializable {
-
- // ---------------------- build first inner join -----------------------------------------
- @Test
- public void testBuildFirstHashInnerJoin() throws Exception {
-
- int numKeys = 100;
- int buildValsPerKey = 3;
- int probeValsPerKey = 10;
- MutableObjectIterator<BinaryRowData> buildInput =
- new UniformBinaryRowGenerator(numKeys, buildValsPerKey, false);
- MutableObjectIterator<BinaryRowData> probeInput =
- new UniformBinaryRowGenerator(numKeys, probeValsPerKey, true);
-
- buildJoin(
- buildInput,
- probeInput,
- false,
- false,
- true,
- numKeys * buildValsPerKey * probeValsPerKey,
- numKeys,
- 165);
- }
-
- // ---------------------- build first left out join -----------------------------------------
- @Test
- public void testBuildFirstHashLeftOutJoin() throws Exception {
-
- int numKeys1 = 9;
- int numKeys2 = 10;
- int buildValsPerKey = 3;
- int probeValsPerKey = 10;
- MutableObjectIterator<BinaryRowData> buildInput =
- new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
- MutableObjectIterator<BinaryRowData> probeInput =
- new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
-
- buildJoin(
- buildInput,
- probeInput,
- true,
- false,
- true,
- numKeys1 * buildValsPerKey * probeValsPerKey,
- numKeys1,
- 165);
- }
-
- // ---------------------- build first right out join -----------------------------------------
- @Test
- public void testBuildFirstHashRightOutJoin() throws Exception {
-
- int numKeys1 = 9;
- int numKeys2 = 10;
- int buildValsPerKey = 3;
- int probeValsPerKey = 10;
- MutableObjectIterator<BinaryRowData> buildInput =
- new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
- MutableObjectIterator<BinaryRowData> probeInput =
- new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
-
- buildJoin(buildInput, probeInput, false, true, true, 280, numKeys2, -1);
- }
-
- // ---------------------- build first full out join -----------------------------------------
- @Test
- public void testBuildFirstHashFullOutJoin() throws Exception {
-
- int numKeys1 = 9;
- int numKeys2 = 10;
- int buildValsPerKey = 3;
- int probeValsPerKey = 10;
- MutableObjectIterator<BinaryRowData> buildInput =
- new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
- MutableObjectIterator<BinaryRowData> probeInput =
- new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
-
- buildJoin(buildInput, probeInput, true, true, true, 280, numKeys2, -1);
- }
-
- // ---------------------- build second inner join -----------------------------------------
- @Test
- public void testBuildSecondHashInnerJoin() throws Exception {
-
- int numKeys = 100;
- int buildValsPerKey = 3;
- int probeValsPerKey = 10;
- MutableObjectIterator<BinaryRowData> buildInput =
- new UniformBinaryRowGenerator(numKeys, buildValsPerKey, false);
- MutableObjectIterator<BinaryRowData> probeInput =
- new UniformBinaryRowGenerator(numKeys, probeValsPerKey, true);
-
- buildJoin(
- buildInput,
- probeInput,
- false,
- false,
- false,
- numKeys * buildValsPerKey * probeValsPerKey,
- numKeys,
- 165);
- }
+/** Base test class for {@link HashJoinOperator}. */
+public abstract class Int2HashJoinOperatorTestBase implements Serializable {
- // ---------------------- build second left out join -----------------------------------------
- @Test
- public void testBuildSecondHashLeftOutJoin() throws Exception {
-
- int numKeys1 = 10;
- int numKeys2 = 9;
- int buildValsPerKey = 3;
- int probeValsPerKey = 10;
- MutableObjectIterator<BinaryRowData> buildInput =
- new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
- MutableObjectIterator<BinaryRowData> probeInput =
- new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
-
- buildJoin(
- buildInput,
- probeInput,
- true,
- false,
- false,
- numKeys2 * buildValsPerKey * probeValsPerKey,
- numKeys2,
- 165);
- }
-
- // ---------------------- build second right out join -----------------------------------------
- @Test
- public void testBuildSecondHashRightOutJoin() throws Exception {
-
- int numKeys1 = 9;
- int numKeys2 = 10;
- int buildValsPerKey = 3;
- int probeValsPerKey = 10;
- MutableObjectIterator<BinaryRowData> buildInput =
- new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
- MutableObjectIterator<BinaryRowData> probeInput =
- new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
-
- buildJoin(
- buildInput,
- probeInput,
- false,
- true,
- false,
- numKeys1 * buildValsPerKey * probeValsPerKey,
- numKeys2,
- -1);
- }
-
- // ---------------------- build second full out join -----------------------------------------
- @Test
- public void testBuildSecondHashFullOutJoin() throws Exception {
-
- int numKeys1 = 9;
- int numKeys2 = 10;
- int buildValsPerKey = 3;
- int probeValsPerKey = 10;
- MutableObjectIterator<BinaryRowData> buildInput =
- new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
- MutableObjectIterator<BinaryRowData> probeInput =
- new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
-
- buildJoin(buildInput, probeInput, true, true, false, 280, numKeys2, -1);
- }
-
- @Test
- public void testSemiJoin() throws Exception {
-
- int numKeys1 = 9;
- int numKeys2 = 10;
- int buildValsPerKey = 3;
- int probeValsPerKey = 10;
- MutableObjectIterator<BinaryRowData> buildInput =
- new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
- MutableObjectIterator<BinaryRowData> probeInput =
- new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
-
- HashJoinType type = HashJoinType.SEMI;
- Object operator = newOperator(33 * 32 * 1024, type, false);
- joinAndAssert(operator, buildInput, probeInput, 90, 9, 45, true);
- }
-
- @Test
- public void testAntiJoin() throws Exception {
-
- int numKeys1 = 9;
- int numKeys2 = 10;
- int buildValsPerKey = 3;
- int probeValsPerKey = 10;
- MutableObjectIterator<BinaryRowData> buildInput =
- new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
- MutableObjectIterator<BinaryRowData> probeInput =
- new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
-
- HashJoinType type = HashJoinType.ANTI;
- Object operator = newOperator(33 * 32 * 1024, type, false);
- joinAndAssert(operator, buildInput, probeInput, 10, 1, 45, true);
- }
-
- @Test
- public void testBuildLeftSemiJoin() throws Exception {
-
- int numKeys1 = 10;
- int numKeys2 = 9;
- int buildValsPerKey = 10;
- int probeValsPerKey = 3;
- MutableObjectIterator<BinaryRowData> buildInput =
- new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
- MutableObjectIterator<BinaryRowData> probeInput =
- new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
-
- HashJoinType type = HashJoinType.BUILD_LEFT_SEMI;
- Object operator = newOperator(33 * 32 * 1024, type, false);
- joinAndAssert(operator, buildInput, probeInput, 90, 9, 45, true);
- }
-
- @Test
- public void testBuildLeftAntiJoin() throws Exception {
-
- int numKeys1 = 10;
- int numKeys2 = 9;
- int buildValsPerKey = 10;
- int probeValsPerKey = 3;
- MutableObjectIterator<BinaryRowData> buildInput =
- new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
- MutableObjectIterator<BinaryRowData> probeInput =
- new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
-
- HashJoinType type = HashJoinType.BUILD_LEFT_ANTI;
- Object operator = newOperator(33 * 32 * 1024, type, false);
- joinAndAssert(operator, buildInput, probeInput, 10, 1, 45, true);
- }
-
- private void buildJoin(
+ public void buildJoin(
MutableObjectIterator<BinaryRowData> buildInput,
MutableObjectIterator<BinaryRowData> probeInput,
boolean leftOut,
@@ -299,8 +69,10 @@ public class Int2HashJoinOperatorTest implements Serializable {
int expectOutKeySize,
int expectOutVal)
throws Exception {
- HashJoinType type = HashJoinType.of(buildLeft, leftOut, rightOut);
- Object operator = newOperator(33 * 32 * 1024, type, !buildLeft);
+ FlinkJoinType flinkJoinType = getJoinType(leftOut, rightOut);
+ HashJoinType hashJoinType = HashJoinType.of(buildLeft, leftOut, rightOut);
+ Object operator =
+ newOperator(33 * 32 * 1024, flinkJoinType, hashJoinType, buildLeft, !buildLeft);
joinAndAssert(
operator,
buildInput,
@@ -311,8 +83,122 @@ public class Int2HashJoinOperatorTest implements Serializable {
false);
}
+ public Object newOperator(
+ long memorySize,
+ FlinkJoinType flinkJoinType,
+ HashJoinType hashJoinType,
+ boolean buildLeft,
+ boolean reverseJoinFunction) {
+ GeneratedJoinCondition condFuncCode =
+ new GeneratedJoinCondition("", "", new Object[0]) {
+ @Override
+ public JoinCondition newInstance(ClassLoader classLoader) {
+ return new TrueCondition();
+ }
+ };
+ GeneratedProjection buildProjectionCode =
+ new GeneratedProjection("", "", new Object[0]) {
+ @Override
+ public Projection newInstance(ClassLoader classLoader) {
+ return new MyProjection();
+ }
+ };
+ GeneratedProjection probeProjectionCode =
+ new GeneratedProjection("", "", new Object[0]) {
+ @Override
+ public Projection newInstance(ClassLoader classLoader) {
+ return new MyProjection();
+ }
+ };
+ GeneratedNormalizedKeyComputer computer1 =
+ new GeneratedNormalizedKeyComputer("", "") {
+ @Override
+ public NormalizedKeyComputer newInstance(ClassLoader classLoader) {
+ return new IntNormalizedKeyComputer();
+ }
+ };
+ GeneratedRecordComparator comparator1 =
+ new GeneratedRecordComparator("", "", new Object[0]) {
+ @Override
+ public RecordComparator newInstance(ClassLoader classLoader) {
+ return new IntRecordComparator();
+ }
+ };
+
+ GeneratedNormalizedKeyComputer computer2 =
+ new GeneratedNormalizedKeyComputer("", "") {
+ @Override
+ public NormalizedKeyComputer newInstance(ClassLoader classLoader) {
+ return new IntNormalizedKeyComputer();
+ }
+ };
+ GeneratedRecordComparator comparator2 =
+ new GeneratedRecordComparator("", "", new Object[0]) {
+ @Override
+ public RecordComparator newInstance(ClassLoader classLoader) {
+ return new IntRecordComparator();
+ }
+ };
+ GeneratedRecordComparator genKeyComparator =
+ new GeneratedRecordComparator("", "", new Object[0]) {
+ @Override
+ public RecordComparator newInstance(ClassLoader classLoader) {
+ return new IntRecordComparator();
+ }
+ };
+ boolean[] filterNulls = new boolean[] {true};
+
+ SortMergeJoinFunction sortMergeJoinFunction;
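+ // the SMJ fallback function always receives its inputs in (left, right)
+ // order, so when the build side is the right input the probe-side
+ // projection, sorter and comparator are passed first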
+ if (buildLeft) {
+ sortMergeJoinFunction =
+ new SortMergeJoinFunction(
+ 0,
+ flinkJoinType,
+ buildLeft,
+ condFuncCode,
+ buildProjectionCode,
+ probeProjectionCode,
+ computer1,
+ comparator1,
+ computer2,
+ comparator2,
+ genKeyComparator,
+ filterNulls);
+ } else {
+ sortMergeJoinFunction =
+ new SortMergeJoinFunction(
+ 0,
+ flinkJoinType,
+ buildLeft,
+ condFuncCode,
+ probeProjectionCode,
+ buildProjectionCode,
+ computer2,
+ comparator2,
+ computer1,
+ comparator1,
+ genKeyComparator,
+ filterNulls);
+ }
+
+ return HashJoinOperator.newHashJoinOperator(
+ hashJoinType,
+ buildLeft,
+ condFuncCode,
+ reverseJoinFunction,
+ filterNulls,
+ buildProjectionCode,
+ probeProjectionCode,
+ false,
+ 20,
+ 10000,
+ 10000,
+ RowType.of(new IntType()),
+ sortMergeJoinFunction);
+ }
+
@SuppressWarnings("unchecked")
- static void joinAndAssert(
+ public static void joinAndAssert(
Object operator,
MutableObjectIterator<BinaryRowData> input1,
MutableObjectIterator<BinaryRowData> input2,
@@ -335,7 +221,7 @@ public class Int2HashJoinOperatorTest implements Serializable {
typeInfo,
(TypeInformation) typeInfo,
rowDataTypeInfo);
- testHarness.memorySize = 36 * 1024 * 1024;
+ testHarness.memorySize = 3 * 1024 * 1024;
testHarness.getExecutionConfig().enableObjectReuse();
testHarness.setupOutputForSingletonOperatorChain();
if (operator instanceof StreamOperator) {
@@ -479,36 +365,6 @@ public class Int2HashJoinOperatorTest implements Serializable {
}
}
- public Object newOperator(long memorySize, HashJoinType type, boolean reverseJoinFunction) {
- return HashJoinOperator.newHashJoinOperator(
- type,
- new GeneratedJoinCondition("", "", new Object[0]) {
- @Override
- public JoinCondition newInstance(ClassLoader classLoader) {
- return new TrueCondition();
- }
- },
- reverseJoinFunction,
- new boolean[] {true},
- new GeneratedProjection("", "", new Object[0]) {
- @Override
- public Projection newInstance(ClassLoader classLoader) {
- return new MyProjection();
- }
- },
- new GeneratedProjection("", "", new Object[0]) {
- @Override
- public Projection newInstance(ClassLoader classLoader) {
- return new MyProjection();
- }
- },
- false,
- 20,
- 10000,
- 10000,
- RowType.of(new IntType()));
- }
-
/** Test util. */
public static class TrueCondition extends AbstractRichFunction implements JoinCondition {
@@ -517,4 +373,15 @@ public class Int2HashJoinOperatorTest implements Serializable {
return true;
}
}
+
+ /** Test join condition that always returns true. */
+ public static class MyJoinCondition extends AbstractRichFunction implements JoinCondition {
+
+ public MyJoinCondition(Object[] reference) {}
+
+ @Override
+ public boolean apply(RowData in1, RowData in2) {
+ return true;
+ }
+ }
}
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2SortMergeJoinOperatorTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2SortMergeJoinOperatorTest.java
index 01f6d0288dd..9deb124cb5c 100644
--- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2SortMergeJoinOperatorTest.java
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2SortMergeJoinOperatorTest.java
@@ -32,7 +32,7 @@ import org.apache.flink.table.runtime.generated.JoinCondition;
import org.apache.flink.table.runtime.generated.NormalizedKeyComputer;
import org.apache.flink.table.runtime.generated.Projection;
import org.apache.flink.table.runtime.generated.RecordComparator;
-import org.apache.flink.table.runtime.operators.join.Int2HashJoinOperatorTest.MyProjection;
+import org.apache.flink.table.runtime.operators.join.Int2HashJoinOperatorTestBase.MyProjection;
import org.apache.flink.table.runtime.operators.sort.IntNormalizedKeyComputer;
import org.apache.flink.table.runtime.operators.sort.IntRecordComparator;
import org.apache.flink.table.runtime.util.UniformBinaryRowGenerator;
@@ -211,7 +211,11 @@ public class Int2SortMergeJoinOperatorTest {
}
static StreamOperator newOperator(FlinkJoinType type, boolean leftIsSmaller) {
- return new SortMergeJoinOperator(
+ return new SortMergeJoinOperator(getJoinFunction(type, leftIsSmaller));
+ }
+
+ public static SortMergeJoinFunction getJoinFunction(FlinkJoinType type, boolean leftIsSmaller) {
+ return new SortMergeJoinFunction(
0,
type,
leftIsSmaller,
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinIteratorTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinIteratorTest.java
index a938ea882df..60d01c9466c 100644
--- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinIteratorTest.java
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinIteratorTest.java
@@ -27,7 +27,7 @@ import org.apache.flink.runtime.memory.MemoryManagerBuilder;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.binary.BinaryRowData;
import org.apache.flink.table.data.writer.BinaryRowWriter;
-import org.apache.flink.table.runtime.operators.join.Int2HashJoinOperatorTest.MyProjection;
+import org.apache.flink.table.runtime.operators.join.Int2HashJoinOperatorTestBase.MyProjection;
import org.apache.flink.table.runtime.operators.sort.IntRecordComparator;
import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer;
import org.apache.flink.table.runtime.util.LazyMemorySegmentPool;
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/String2HashJoinOperatorTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/String2HashJoinOperatorTest.java
index 31eb491376d..ed6ff982eba 100644
--- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/String2HashJoinOperatorTest.java
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/String2HashJoinOperatorTest.java
@@ -31,10 +31,17 @@ import org.apache.flink.table.data.binary.BinaryRowData;
import org.apache.flink.table.data.utils.JoinedRowData;
import org.apache.flink.table.data.writer.BinaryRowWriter;
import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
+import org.apache.flink.table.runtime.generated.GeneratedNormalizedKeyComputer;
import org.apache.flink.table.runtime.generated.GeneratedProjection;
+import org.apache.flink.table.runtime.generated.GeneratedRecordComparator;
import org.apache.flink.table.runtime.generated.JoinCondition;
+import org.apache.flink.table.runtime.generated.NormalizedKeyComputer;
import org.apache.flink.table.runtime.generated.Projection;
+import org.apache.flink.table.runtime.generated.RecordComparator;
+import org.apache.flink.table.runtime.operators.sort.StringNormalizedKeyComputer;
+import org.apache.flink.table.runtime.operators.sort.StringRecordComparator;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
+import org.apache.flink.table.runtime.util.JoinUtil;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.VarCharType;
@@ -81,8 +88,10 @@ public class String2HashJoinOperatorTest implements Serializable {
}
private void init(boolean leftOut, boolean rightOut, boolean buildLeft) throws Exception {
- HashJoinType type = HashJoinType.of(buildLeft, leftOut, rightOut);
- HashJoinOperator operator = newOperator(33 * 32 * 1024, type, !buildLeft);
+ FlinkJoinType flinkJoinType = JoinUtil.getJoinType(leftOut, rightOut);
+ HashJoinType hashJoinType = HashJoinType.of(buildLeft, leftOut, rightOut);
+ HashJoinOperator operator =
+ newOperator(33 * 32 * 1024, flinkJoinType, hashJoinType, !buildLeft);
testHarness =
new TwoInputStreamTaskTestHarness<>(
TwoInputStreamTask::new,
@@ -325,33 +334,98 @@ public class String2HashJoinOperatorTest implements Serializable {
}
private HashJoinOperator newOperator(
- long memorySize, HashJoinType type, boolean reverseJoinFunction) {
- return HashJoinOperator.newHashJoinOperator(
- type,
+ long memorySize,
+ FlinkJoinType flinkJoinType,
+ HashJoinType hashJoinType,
+ boolean reverseJoinFunction) {
+ boolean buildLeft = false;
+ GeneratedJoinCondition condFuncCode =
new GeneratedJoinCondition("", "", new Object[0]) {
@Override
public JoinCondition newInstance(ClassLoader classLoader) {
- return new Int2HashJoinOperatorTest.TrueCondition();
+ return new Int2HashJoinOperatorTestBase.TrueCondition();
}
- },
- reverseJoinFunction,
- new boolean[] {true},
+ };
+ GeneratedProjection buildProjectionCode =
new GeneratedProjection("", "", new Object[0]) {
@Override
public Projection newInstance(ClassLoader classLoader) {
return new MyProjection();
}
- },
+ };
+ GeneratedProjection probeProjectionCode =
new GeneratedProjection("", "", new Object[0]) {
@Override
public Projection newInstance(ClassLoader classLoader) {
return new MyProjection();
}
- },
+ };
+ GeneratedNormalizedKeyComputer computer1 =
+ new GeneratedNormalizedKeyComputer("", "") {
+ @Override
+ public NormalizedKeyComputer newInstance(ClassLoader classLoader) {
+ return new StringNormalizedKeyComputer();
+ }
+ };
+ GeneratedRecordComparator comparator1 =
+ new GeneratedRecordComparator("", "", new Object[0]) {
+ @Override
+ public RecordComparator newInstance(ClassLoader classLoader) {
+ return new StringRecordComparator();
+ }
+ };
+
+ GeneratedNormalizedKeyComputer computer2 =
+ new GeneratedNormalizedKeyComputer("", "") {
+ @Override
+ public NormalizedKeyComputer newInstance(ClassLoader classLoader) {
+ return new StringNormalizedKeyComputer();
+ }
+ };
+ GeneratedRecordComparator comparator2 =
+ new GeneratedRecordComparator("", "", new Object[0]) {
+ @Override
+ public RecordComparator newInstance(ClassLoader classLoader) {
+ return new StringRecordComparator();
+ }
+ };
+ GeneratedRecordComparator genKeyComparator =
+ new GeneratedRecordComparator("", "", new Object[0]) {
+ @Override
+ public RecordComparator newInstance(ClassLoader classLoader) {
+ return new StringRecordComparator();
+ }
+ };
+ boolean[] filterNulls = new boolean[] {true};
+
+ SortMergeJoinFunction sortMergeJoinFunction =
+ new SortMergeJoinFunction(
+ 0,
+ flinkJoinType,
+ buildLeft,
+ condFuncCode,
+ probeProjectionCode,
+ buildProjectionCode,
+ computer2,
+ comparator2,
+ computer1,
+ comparator1,
+ genKeyComparator,
+ filterNulls);
+
+ return HashJoinOperator.newHashJoinOperator(
+ hashJoinType,
+ buildLeft,
+ condFuncCode,
+ reverseJoinFunction,
+ filterNulls,
+ buildProjectionCode,
+ probeProjectionCode,
false,
20,
10000,
10000,
- RowType.of(VarCharType.STRING_TYPE));
+ RowType.of(VarCharType.STRING_TYPE),
+ sortMergeJoinFunction);
}
}
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/String2SortMergeJoinOperatorTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/String2SortMergeJoinOperatorTest.java
index 719cecfba8a..cb1ea8357ca 100644
--- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/String2SortMergeJoinOperatorTest.java
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/String2SortMergeJoinOperatorTest.java
@@ -185,58 +185,60 @@ public class String2SortMergeJoinOperatorTest {
}
static StreamOperator newOperator(FlinkJoinType type, boolean leftIsSmaller) {
- return new SortMergeJoinOperator(
- 0,
- type,
- leftIsSmaller,
- new GeneratedJoinCondition("", "", new Object[0]) {
- @Override
- public JoinCondition newInstance(ClassLoader classLoader) {
- return new Int2HashJoinOperatorTest.TrueCondition();
- }
- },
- new GeneratedProjection("", "", new Object[0]) {
- @Override
- public Projection newInstance(ClassLoader classLoader) {
- return new MyProjection();
- }
- },
- new GeneratedProjection("", "", new Object[0]) {
- @Override
- public Projection newInstance(ClassLoader classLoader) {
- return new MyProjection();
- }
- },
- new GeneratedNormalizedKeyComputer("", "") {
- @Override
- public NormalizedKeyComputer newInstance(ClassLoader classLoader) {
- return new StringNormalizedKeyComputer();
- }
- },
- new GeneratedRecordComparator("", "", new Object[0]) {
- @Override
- public RecordComparator newInstance(ClassLoader classLoader) {
- return new StringRecordComparator();
- }
- },
- new GeneratedNormalizedKeyComputer("", "") {
- @Override
- public NormalizedKeyComputer newInstance(ClassLoader classLoader) {
- return new StringNormalizedKeyComputer();
- }
- },
- new GeneratedRecordComparator("", "", new Object[0]) {
- @Override
- public RecordComparator newInstance(ClassLoader classLoader) {
- return new StringRecordComparator();
- }
- },
- new GeneratedRecordComparator("", "", new Object[0]) {
- @Override
- public RecordComparator newInstance(ClassLoader classLoader) {
- return new StringRecordComparator();
- }
- },
- new boolean[] {true});
+ SortMergeJoinFunction sortMergeJoinFunction =
+ new SortMergeJoinFunction(
+ 0,
+ type,
+ leftIsSmaller,
+ new GeneratedJoinCondition("", "", new Object[0]) {
+ @Override
+ public JoinCondition newInstance(ClassLoader classLoader) {
+ return new Int2HashJoinOperatorTest.TrueCondition();
+ }
+ },
+ new GeneratedProjection("", "", new Object[0]) {
+ @Override
+ public Projection newInstance(ClassLoader classLoader) {
+ return new MyProjection();
+ }
+ },
+ new GeneratedProjection("", "", new Object[0]) {
+ @Override
+ public Projection newInstance(ClassLoader classLoader) {
+ return new MyProjection();
+ }
+ },
+ new GeneratedNormalizedKeyComputer("", "") {
+ @Override
+ public NormalizedKeyComputer newInstance(ClassLoader classLoader) {
+ return new StringNormalizedKeyComputer();
+ }
+ },
+ new GeneratedRecordComparator("", "", new Object[0]) {
+ @Override
+ public RecordComparator newInstance(ClassLoader classLoader) {
+ return new StringRecordComparator();
+ }
+ },
+ new GeneratedNormalizedKeyComputer("", "") {
+ @Override
+ public NormalizedKeyComputer newInstance(ClassLoader classLoader) {
+ return new StringNormalizedKeyComputer();
+ }
+ },
+ new GeneratedRecordComparator("", "", new Object[0]) {
+ @Override
+ public RecordComparator newInstance(ClassLoader classLoader) {
+ return new StringRecordComparator();
+ }
+ },
+ new GeneratedRecordComparator("", "", new Object[0]) {
+ @Override
+ public RecordComparator newInstance(ClassLoader classLoader) {
+ return new StringRecordComparator();
+ }
+ },
+ new boolean[] {true});
+ return new SortMergeJoinOperator(sortMergeJoinFunction);
}
}
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/util/JoinUtil.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/util/JoinUtil.java
new file mode 100644
index 00000000000..cef98c84113
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/util/JoinUtil.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.util;
+
+import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
+
+import static org.apache.flink.table.runtime.operators.join.FlinkJoinType.FULL;
+import static org.apache.flink.table.runtime.operators.join.FlinkJoinType.INNER;
+import static org.apache.flink.table.runtime.operators.join.FlinkJoinType.LEFT;
+import static org.apache.flink.table.runtime.operators.join.FlinkJoinType.RIGHT;
+
+/** Utilities for join tests. */
+public class JoinUtil {
+
+ public static FlinkJoinType getJoinType(boolean leftOuter, boolean rightOuter) {
+ if (leftOuter && rightOuter) {
+ return FULL;
+ } else if (leftOuter) {
+ return LEFT;
+ } else if (rightOuter) {
+ return RIGHT;
+ } else {
+ return INNER;
+ }
+ }
+
+ private JoinUtil() {}
+}
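
For reference, a minimal usage sketch of the new test utility (a toy snippet,
not part of the patch):

    // map the (leftOuter, rightOuter) flags used by the tests onto FlinkJoinType
    FlinkJoinType inner = JoinUtil.getJoinType(false, false); // INNER
    FlinkJoinType left = JoinUtil.getJoinType(true, false);   // LEFT
    FlinkJoinType right = JoinUtil.getJoinType(false, true);  // RIGHT
    FlinkJoinType full = JoinUtil.getJoinType(true, true);    // FULL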
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/util/UniformBinaryRowGenerator.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/util/UniformBinaryRowGenerator.java
index 5113475fc38..74b26e9b085 100644
--- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/util/UniformBinaryRowGenerator.java
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/util/UniformBinaryRowGenerator.java
@@ -23,7 +23,7 @@ import org.apache.flink.table.data.writer.BinaryRowWriter;
import org.apache.flink.types.IntValue;
import org.apache.flink.util.MutableObjectIterator;
-/** Uniform genarator for binary row. */
+/** Uniform generator for binary row. */
public class UniformBinaryRowGenerator implements MutableObjectIterator<BinaryRowData> {
int numKeys;