Posted to commits@flink.apache.org by ja...@apache.org on 2022/08/10 02:23:13 UTC

[flink] branch master updated: [FLINK-26929][table-runtime] Introduce adaptive hash join strategy for batch hash join (#20365)

This is an automated email from the ASF dual-hosted git repository.

jark pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 10acb641a03 [FLINK-26929][table-runtime] Introduce adaptive hash join strategy for batch hash join (#20365)
10acb641a03 is described below

commit 10acb641a036ae9d171452695b660c46d97c865a
Author: Ron <ld...@163.com>
AuthorDate: Wed Aug 10 10:23:05 2022 +0800

    [FLINK-26929][table-runtime] Introduce adaptive hash join strategy for batch hash join (#20365)
---
 .../planner/plan/nodes/exec/ExecNodeConfig.java    |   4 +-
 .../plan/nodes/exec/batch/BatchExecHashJoin.java   |  60 ++-
 .../nodes/exec/batch/BatchExecSortMergeJoin.java   |  61 +--
 .../plan/utils/SorMergeJoinOperatorUtil.java       |  83 ++++
 .../planner/codegen/LongHashJoinGenerator.scala    | 110 ++++-
 .../flink/table/planner/plan/utils/SortUtil.scala  |  12 +
 .../codegen/LongAdaptiveHashJoinGeneratorTest.java | 134 ++++++
 .../planner/codegen/LongHashJoinGeneratorTest.java |  58 +--
 .../runtime/hashtable/BaseHybridHashTable.java     |  11 +
 .../runtime/hashtable/BinaryHashPartition.java     |  10 +-
 .../table/runtime/hashtable/BinaryHashTable.java   | 122 ++++-
 .../table/runtime/hashtable/LongHashPartition.java |   4 +-
 .../runtime/hashtable/LongHybridHashTable.java     | 123 ++++-
 .../io/BinaryRowChannelInputViewIterator.java      |   6 +-
 ...shPartitionChannelReaderInputViewIterator.java} |  35 +-
 .../runtime/operators/join/HashJoinOperator.java   | 112 ++++-
 ...oinOperator.java => SortMergeJoinFunction.java} | 120 +++--
 .../operators/join/SortMergeJoinOperator.java      | 500 +--------------------
 .../runtime/hashtable/BinaryHashTableTest.java     |  61 ++-
 .../table/runtime/hashtable/LongHashTableTest.java |  61 ++-
 .../join/Int2AdaptiveHashJoinOperatorTest.java     | 450 +++++++++++++++++++
 .../operators/join/Int2HashJoinOperatorTest.java   | 287 +-----------
 ...Test.java => Int2HashJoinOperatorTestBase.java} | 415 ++++++-----------
 .../join/Int2SortMergeJoinOperatorTest.java        |   8 +-
 .../operators/join/SortMergeJoinIteratorTest.java  |   2 +-
 .../join/String2HashJoinOperatorTest.java          |  98 +++-
 .../join/String2SortMergeJoinOperatorTest.java     | 108 ++---
 .../apache/flink/table/runtime/util/JoinUtil.java  |  43 ++
 .../runtime/util/UniformBinaryRowGenerator.java    |   2 +-
 29 files changed, 1739 insertions(+), 1361 deletions(-)
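
The gist of the change: the batch hash join keeps its existing build/probe
path, but a spilled partition whose recursion level would exceed
MAX_RECURSION_DEPTH no longer fails the job. It is parked in a pending list
and later joined by a sort merge join. A minimal sketch of that decision,
condensed from the hash table hunks below (illustrative, not the committed
code verbatim):

    // Sketch: prepareNextPartition() handling of a spilled partition p.
    final int nextRecursionLevel = p.getRecursionLevel() + 1;
    if (nextRecursionLevel > MAX_RECURSION_DEPTH) {
        // Too many re-spills (typically heavily skewed keys): instead of
        // throwing, queue the partition for the sort merge join fallback.
        partitionsPendingForSMJ.add(p);
        partitionsPending.remove(0);
        return prepareNextPartition();
    }
    // Otherwise recurse as before: rebuild an in-memory table from the spill.
    buildTableFromSpilledPartition(p, nextRecursionLevel);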

diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/ExecNodeConfig.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/ExecNodeConfig.java
index b33396fbec0..f56165f66c0 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/ExecNodeConfig.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/ExecNodeConfig.java
@@ -19,6 +19,7 @@
 package org.apache.flink.table.planner.plan.nodes.exec;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.configuration.ConfigOption;
 import org.apache.flink.configuration.ReadableConfig;
 import org.apache.flink.table.api.TableConfig;
@@ -40,7 +41,8 @@ public final class ExecNodeConfig implements ReadableConfig {
 
     private final ReadableConfig nodeConfig;
 
-    ExecNodeConfig(TableConfig tableConfig, ReadableConfig nodeConfig) {
+    @VisibleForTesting
+    public ExecNodeConfig(TableConfig tableConfig, ReadableConfig nodeConfig) {
         this.nodeConfig = nodeConfig;
         this.tableConfig = tableConfig;
     }
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecHashJoin.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecHashJoin.java
index a4f989264fd..85b3e017d50 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecHashJoin.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecHashJoin.java
@@ -37,11 +37,13 @@ import org.apache.flink.table.planner.plan.nodes.exec.SingleTransformationTransl
 import org.apache.flink.table.planner.plan.nodes.exec.spec.JoinSpec;
 import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil;
 import org.apache.flink.table.planner.plan.utils.JoinUtil;
+import org.apache.flink.table.planner.plan.utils.SorMergeJoinOperatorUtil;
 import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
 import org.apache.flink.table.runtime.generated.GeneratedProjection;
 import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
 import org.apache.flink.table.runtime.operators.join.HashJoinOperator;
 import org.apache.flink.table.runtime.operators.join.HashJoinType;
+import org.apache.flink.table.runtime.operators.join.SortMergeJoinFunction;
 import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
 import org.apache.flink.table.types.logical.LogicalType;
 import org.apache.flink.table.types.logical.RowType;
@@ -162,7 +164,7 @@ public class BatchExecHashJoin extends ExecNodeBase<RowData>
             probeTransform = rightInputTransform;
             probeProj = rightProj;
             probeType = rightType;
-            probeRowCount = estimatedLeftRowCount;
+            probeRowCount = estimatedRightRowCount;
             probeKeys = rightKeys;
         } else {
             buildTransform = rightInputTransform;
@@ -189,6 +191,28 @@ public class BatchExecHashJoin extends ExecNodeBase<RowData>
                         joinType.isRightOuter(),
                         joinType == FlinkJoinType.SEMI,
                         joinType == FlinkJoinType.ANTI);
+
+        long externalBufferMemory =
+                config.get(ExecutionConfigOptions.TABLE_EXEC_RESOURCE_EXTERNAL_BUFFER_MEMORY)
+                        .getBytes();
+        long managedMemory = getLargeManagedMemory(joinType, config);
+
+        // sort merge join function
+        SortMergeJoinFunction sortMergeJoinFunction =
+                SorMergeJoinOperatorUtil.getSortMergeJoinFunction(
+                        planner.getFlinkContext().getClassLoader(),
+                        config,
+                        joinType,
+                        leftType,
+                        rightType,
+                        leftKeys,
+                        rightKeys,
+                        keyType,
+                        leftIsBuild,
+                        joinSpec.getFilterNulls(),
+                        condFunc,
+                        1.0 * externalBufferMemory / managedMemory);
+
         if (LongHashJoinGenerator.support(hashJoinType, keyType, joinSpec.getFilterNulls())) {
             operator =
                     LongHashJoinGenerator.gen(
@@ -203,12 +227,15 @@ public class BatchExecHashJoin extends ExecNodeBase<RowData>
                             buildRowSize,
                             buildRowCount,
                             reverseJoin,
-                            condFunc);
+                            condFunc,
+                            leftIsBuild,
+                            sortMergeJoinFunction);
         } else {
             operator =
                     SimpleOperatorFactory.of(
                             HashJoinOperator.newHashJoinOperator(
                                     hashJoinType,
+                                    leftIsBuild,
                                     condFunc,
                                     reverseJoin,
                                     joinSpec.getFilterNulls(),
@@ -218,12 +245,10 @@ public class BatchExecHashJoin extends ExecNodeBase<RowData>
                                     buildRowSize,
                                     buildRowCount,
                                     probeRowCount,
-                                    keyType));
+                                    keyType,
+                                    sortMergeJoinFunction));
         }
 
-        long managedMemory =
-                config.get(ExecutionConfigOptions.TABLE_EXEC_RESOURCE_HASH_JOIN_MEMORY).getBytes();
-
         return ExecNodeUtil.createTwoInputTransformation(
                 buildTransform,
                 probeTransform,
@@ -234,4 +259,27 @@ public class BatchExecHashJoin extends ExecNodeBase<RowData>
                 probeTransform.getParallelism(),
                 managedMemory);
     }
+
+    private long getLargeManagedMemory(FlinkJoinType joinType, ExecNodeConfig config) {
+        long hashJoinManagedMemory =
+                config.get(ExecutionConfigOptions.TABLE_EXEC_RESOURCE_HASH_JOIN_MEMORY).getBytes();
+
+        // The memory used by the SortMergeJoinIterator to buffer matched rows; each side needs
+        // this memory if the join is a full outer join
+        long externalBufferMemory =
+                config.get(ExecutionConfigOptions.TABLE_EXEC_RESOURCE_EXTERNAL_BUFFER_MEMORY)
+                        .getBytes();
+        // The memory used by the BinaryExternalSorter for sorting; both sides need it
+        long sortMemory =
+                config.get(ExecutionConfigOptions.TABLE_EXEC_RESOURCE_SORT_MEMORY).getBytes();
+        int externalBufferNum = 1;
+        if (joinType == FlinkJoinType.FULL) {
+            externalBufferNum = 2;
+        }
+        long sortMergeJoinManagedMemory = externalBufferMemory * externalBufferNum + sortMemory * 2;
+
+        // Since the hash join may fall back to sort merge join, choose the larger of the two
+        // managed memory requirements
+        return Math.max(hashJoinManagedMemory, sortMergeJoinManagedMemory);
+    }
 }
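
Because the operator may end up running either algorithm, it has to reserve
the larger of the two budgets up front. A worked example, assuming (this is an
assumption; see ExecutionConfigOptions for the actual defaults) 128 MB hash
join memory, 128 MB sort memory, 10 MB external buffer memory, and a FULL
outer join:

    // Illustrative arithmetic only; the option defaults are assumed.
    long hashJoinManagedMemory = 128L << 20;   // 128 MB
    long sortMemory            = 128L << 20;   // 128 MB
    long externalBufferMemory  =  10L << 20;   //  10 MB
    int  externalBufferNum     = 2;            // 2 for FULL outer join, else 1
    long smjManagedMemory =
            externalBufferMemory * externalBufferNum + sortMemory * 2;
    // smjManagedMemory = 20 MB + 256 MB = 276 MB
    long managedMemory = Math.max(hashJoinManagedMemory, smjManagedMemory);
    // managedMemory = 276 MB, enough for the fallback if it is ever needed
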
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecSortMergeJoin.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecSortMergeJoin.java
index 99b7ffa89c7..a04b90bd9d8 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecSortMergeJoin.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecSortMergeJoin.java
@@ -23,9 +23,6 @@ import org.apache.flink.configuration.ReadableConfig;
 import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
 import org.apache.flink.table.api.config.ExecutionConfigOptions;
 import org.apache.flink.table.data.RowData;
-import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
-import org.apache.flink.table.planner.codegen.ProjectionCodeGenerator;
-import org.apache.flink.table.planner.codegen.sort.SortCodeGenerator;
 import org.apache.flink.table.planner.delegation.PlannerBase;
 import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
 import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
@@ -33,12 +30,12 @@ import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeConfig;
 import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeContext;
 import org.apache.flink.table.planner.plan.nodes.exec.InputProperty;
 import org.apache.flink.table.planner.plan.nodes.exec.SingleTransformationTranslator;
-import org.apache.flink.table.planner.plan.nodes.exec.spec.SortSpec;
 import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil;
 import org.apache.flink.table.planner.plan.utils.JoinUtil;
-import org.apache.flink.table.planner.plan.utils.SortUtil;
+import org.apache.flink.table.planner.plan.utils.SorMergeJoinOperatorUtil;
 import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
 import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
+import org.apache.flink.table.runtime.operators.join.SortMergeJoinFunction;
 import org.apache.flink.table.runtime.operators.join.SortMergeJoinOperator;
 import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
 import org.apache.flink.table.types.logical.LogicalType;
@@ -130,44 +127,20 @@ public class BatchExecSortMergeJoin extends ExecNodeBase<RowData>
 
         long managedMemory = externalBufferMemory * externalBufferNum + sortMemory * 2;
 
-        SortCodeGenerator leftSortGen =
-                newSortGen(config, planner.getFlinkContext().getClassLoader(), leftKeys, leftType);
-        SortCodeGenerator rightSortGen =
-                newSortGen(
-                        config, planner.getFlinkContext().getClassLoader(), rightKeys, rightType);
-
-        int[] keyPositions = IntStream.range(0, leftKeys.length).toArray();
-        SortMergeJoinOperator operator =
-                new SortMergeJoinOperator(
-                        1.0 * externalBufferMemory / managedMemory,
+        SortMergeJoinFunction sortMergeJoinFunction =
+                SorMergeJoinOperatorUtil.getSortMergeJoinFunction(
+                        planner.getFlinkContext().getClassLoader(),
+                        config,
                         joinType,
+                        leftType,
+                        rightType,
+                        leftKeys,
+                        rightKeys,
+                        keyType,
                         leftIsSmaller,
+                        filterNulls,
                         condFunc,
-                        ProjectionCodeGenerator.generateProjection(
-                                new CodeGeneratorContext(
-                                        config, planner.getFlinkContext().getClassLoader()),
-                                "SMJProjection",
-                                leftType,
-                                keyType,
-                                leftKeys),
-                        ProjectionCodeGenerator.generateProjection(
-                                new CodeGeneratorContext(
-                                        config, planner.getFlinkContext().getClassLoader()),
-                                "SMJProjection",
-                                rightType,
-                                keyType,
-                                rightKeys),
-                        leftSortGen.generateNormalizedKeyComputer("LeftComputer"),
-                        leftSortGen.generateRecordComparator("LeftComparator"),
-                        rightSortGen.generateNormalizedKeyComputer("RightComputer"),
-                        rightSortGen.generateRecordComparator("RightComparator"),
-                        newSortGen(
-                                        config,
-                                        planner.getFlinkContext().getClassLoader(),
-                                        keyPositions,
-                                        keyType)
-                                .generateRecordComparator("KeyComparator"),
-                        filterNulls);
+                        1.0 * externalBufferMemory / managedMemory);
 
         Transformation<RowData> leftInputTransform =
                 (Transformation<RowData>) leftInputEdge.translateToPlan(planner);
@@ -178,15 +151,9 @@ public class BatchExecSortMergeJoin extends ExecNodeBase<RowData>
                 rightInputTransform,
                 createTransformationName(config),
                 createTransformationDescription(config),
-                SimpleOperatorFactory.of(operator),
+                SimpleOperatorFactory.of(new SortMergeJoinOperator(sortMergeJoinFunction)),
                 InternalTypeInfo.of(getOutputType()),
                 rightInputTransform.getParallelism(),
                 managedMemory);
     }
-
-    private SortCodeGenerator newSortGen(
-            ExecNodeConfig config, ClassLoader classLoader, int[] originalKeys, RowType inputType) {
-        SortSpec sortSpec = SortUtil.getAscendingSortSpec(originalKeys);
-        return new SortCodeGenerator(config, classLoader, inputType, sortSpec);
-    }
 }
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/SorMergeJoinOperatorUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/SorMergeJoinOperatorUtil.java
new file mode 100644
index 00000000000..634ceb345dd
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/SorMergeJoinOperatorUtil.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.utils;
+
+import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
+import org.apache.flink.table.planner.codegen.ProjectionCodeGenerator;
+import org.apache.flink.table.planner.codegen.sort.SortCodeGenerator;
+import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeConfig;
+import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
+import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
+import org.apache.flink.table.runtime.operators.join.SortMergeJoinFunction;
+import org.apache.flink.table.runtime.operators.join.SortMergeJoinOperator;
+import org.apache.flink.table.types.logical.RowType;
+
+import java.util.stream.IntStream;
+
+/** Utility for {@link SortMergeJoinOperator}. */
+public class SorMergeJoinOperatorUtil {
+
+    public static SortMergeJoinFunction getSortMergeJoinFunction(
+            ClassLoader classLoader,
+            ExecNodeConfig config,
+            FlinkJoinType joinType,
+            RowType leftType,
+            RowType rightType,
+            int[] leftKeys,
+            int[] rightKeys,
+            RowType keyType,
+            boolean leftIsSmaller,
+            boolean[] filterNulls,
+            GeneratedJoinCondition condFunc,
+            double externalBufferMemRatio) {
+        int[] keyPositions = IntStream.range(0, leftKeys.length).toArray();
+
+        SortCodeGenerator leftSortGen =
+                SortUtil.newSortGen(config, classLoader, leftKeys, leftType);
+        SortCodeGenerator rightSortGen =
+                SortUtil.newSortGen(config, classLoader, rightKeys, rightType);
+
+        return new SortMergeJoinFunction(
+                externalBufferMemRatio,
+                joinType,
+                leftIsSmaller,
+                condFunc,
+                ProjectionCodeGenerator.generateProjection(
+                        new CodeGeneratorContext(config, classLoader),
+                        "SMJProjection",
+                        leftType,
+                        keyType,
+                        leftKeys),
+                ProjectionCodeGenerator.generateProjection(
+                        new CodeGeneratorContext(config, classLoader),
+                        "SMJProjection",
+                        rightType,
+                        keyType,
+                        rightKeys),
+                leftSortGen.generateNormalizedKeyComputer("LeftComputer"),
+                leftSortGen.generateRecordComparator("LeftComparator"),
+                rightSortGen.generateNormalizedKeyComputer("RightComputer"),
+                rightSortGen.generateRecordComparator("RightComparator"),
+                SortUtil.newSortGen(config, classLoader, keyPositions, keyType)
+                        .generateRecordComparator("KeyComparator"),
+                filterNulls);
+    }
+
+    private SorMergeJoinOperatorUtil() {}
+}
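
Both BatchExecHashJoin and BatchExecSortMergeJoin now obtain their
SortMergeJoinFunction from this single factory; the sort merge join exec node
simply wraps the result (condensed from the hunks above):

    SortMergeJoinFunction sortMergeJoinFunction =
            SorMergeJoinOperatorUtil.getSortMergeJoinFunction(
                    classLoader, config, joinType, leftType, rightType,
                    leftKeys, rightKeys, keyType, leftIsSmaller, filterNulls,
                    condFunc, 1.0 * externalBufferMemory / managedMemory);
    SimpleOperatorFactory.of(new SortMergeJoinOperator(sortMergeJoinFunction));
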
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LongHashJoinGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LongHashJoinGenerator.scala
index 7d9c3ac9634..abea1abce62 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LongHashJoinGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LongHashJoinGenerator.scala
@@ -24,10 +24,11 @@ import org.apache.flink.table.data.utils.JoinedRowData
 import org.apache.flink.table.planner.codegen.CodeGenUtils._
 import org.apache.flink.table.planner.codegen.OperatorCodeGenerator.{generateCollect, INPUT_SELECTION}
 import org.apache.flink.table.runtime.generated.{GeneratedJoinCondition, GeneratedProjection}
-import org.apache.flink.table.runtime.hashtable.{LongHashPartition, LongHybridHashTable}
+import org.apache.flink.table.runtime.hashtable.{LongHashPartition, LongHybridHashTable, ProbeIterator}
 import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory
-import org.apache.flink.table.runtime.operators.join.HashJoinType
+import org.apache.flink.table.runtime.operators.join.{HashJoinType, SortMergeJoinFunction}
 import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer
+import org.apache.flink.table.runtime.util.{RowIterator, StreamRecordCollector}
 import org.apache.flink.table.types.logical._
 import org.apache.flink.table.types.logical.LogicalTypeRoot._
 
@@ -108,7 +109,9 @@ object LongHashJoinGenerator {
       buildRowSize: Int,
       buildRowCount: Long,
       reverseJoinFunction: Boolean,
-      condFunc: GeneratedJoinCondition): CodeGenOperatorFactory[RowData] = {
+      condFunc: GeneratedJoinCondition,
+      leftIsBuild: Boolean,
+      sortMergeJoinFunction: SortMergeJoinFunction): CodeGenOperatorFactory[RowData] = {
 
     val buildSer = new BinaryRowDataSerializer(buildType.getFieldCount)
     val probeSer = new BinaryRowDataSerializer(probeType.getFieldCount)
@@ -143,6 +146,17 @@ object LongHashJoinGenerator {
     ctx.addReusableOpenStatement(s"condFunc.open(new ${className[Configuration]}());")
     ctx.addReusableCloseStatement(s"condFunc.close();")
 
+    val leftIsBuildTerm = newName("leftIsBuild")
+    ctx.addReusableMember(s"private final boolean $leftIsBuildTerm = $leftIsBuild;")
+
+    val smjFunctionTerm = className[SortMergeJoinFunction]
+    ctx.addReusableMember(s"private final $smjFunctionTerm sortMergeJoinFunction;")
+    val smjFunctionRefs = ctx.addReusableObject(Array(sortMergeJoinFunction), "smjFunctionRefs")
+    ctx.addReusableInitStatement(s"sortMergeJoinFunction = $smjFunctionRefs[0];")
+
+    val fallbackSMJ = newName("fallbackSMJ")
+    ctx.addReusableMember(s"private transient boolean $fallbackSMJ = false;")
+
     val gauge = classOf[Gauge[_]].getCanonicalName
     ctx.addReusableOpenStatement(
       s"""
@@ -300,11 +314,91 @@ object LongHashJoinGenerator {
          |}
        """.stripMargin)
 
+    // fallback to sort merge join in probe phase
+    val rowIter = classOf[RowIterator[_]].getCanonicalName
+    ctx.addReusableMember(s"""
+                             |private void fallbackSMJProcessPartition() throws Exception {
+                             |  if(!table.getPartitionsPendingForSMJ().isEmpty()) {
+                             |    LOG.info(
+                             |    "Fallback to sort merge join to process spilled partitions.");
+                             |    initialSortMergeJoinFunction();
+                             |    $fallbackSMJ = true;
+                             |
+                             |    for(${classOf[LongHashPartition].getCanonicalName} p : 
+                             |      table.getPartitionsPendingForSMJ()) {
+                             |      $rowIter<$BINARY_ROW> buildSideIter = 
+                             |      table.getSpilledPartitionBuildSideIter(p);
+                             |      while (buildSideIter.advanceNext()) {
+                             |        processSortMergeJoinElement1(buildSideIter.getRow());
+                             |      }
+                             |
+                             |      ${classOf[ProbeIterator].getCanonicalName} probeIter =
+                             |      table.getSpilledPartitionProbeSideIter(p);
+                             |      $BINARY_ROW probeNext;
+                             |      while ((probeNext = probeIter.next()) != null) {
+                             |        processSortMergeJoinElement2(probeNext);
+                             |      }
+                             |    }
+                             |
+                             |    closeHashTable();
+                             |
+                             |    sortMergeJoinFunction.endInput(1);
+                             |    sortMergeJoinFunction.endInput(2);
+                             |    LOG.info("Finish sort merge join for spilled partitions.");
+                             |  }
+                             |}
+       """.stripMargin)
+
+    val collector = classOf[StreamRecordCollector[_]].getCanonicalName
+    ctx.addReusableMember(s"""
+                             |private void initialSortMergeJoinFunction() throws Exception {
+                             |  sortMergeJoinFunction.open(
+                             |    this.getContainingTask(),
+                             |    this.getOperatorConfig(),
+                             |    new $collector<$ROW_DATA>(output),
+                             |    this.computeMemorySize(),
+                             |    this.getRuntimeContext(),
+                             |    this.getMetricGroup());
+                             |}
+       """.stripMargin)
+
+    ctx.addReusableMember(
+      s"""
+         |private void processSortMergeJoinElement1($ROW_DATA rowData) throws Exception {
+         |  if($leftIsBuild) {
+         |    sortMergeJoinFunction.processElement1(rowData);
+         |  } else {
+         |    sortMergeJoinFunction.processElement2(rowData);
+         |  }
+         |}
+       """.stripMargin)
+
+    ctx.addReusableMember(
+      s"""
+         |private void processSortMergeJoinElement2($ROW_DATA rowData) throws Exception {
+         |  if($leftIsBuild) {
+         |    sortMergeJoinFunction.processElement2(rowData);
+         |  } else {
+         |    sortMergeJoinFunction.processElement1(rowData);
+         |  }
+         |}
+       """.stripMargin)
+
+    ctx.addReusableMember(s"""
+                             |private void closeHashTable() {
+                             |  if (this.table != null) {
+                             |    this.table.close();
+                             |    this.table.free();
+                             |    this.table = null;
+                             |  }
+                             |}
+       """.stripMargin)
+
     ctx.addReusableCloseStatement(s"""
-                                     |if (this.table != null) {
-                                     |  this.table.close();
-                                     |  this.table.free();
-                                     |  this.table = null;
+                                     |closeHashTable();
+                                     |
+                                     |if ($fallbackSMJ) {
+                                     |  sortMergeJoinFunction.close();
                                      |}
        """.stripMargin)
 
@@ -352,6 +446,8 @@ object LongHashJoinGenerator {
                               |  joinWithNextKey();
                               |}
                               |LOG.info("Finish rebuild phase.");
+                              |
+                              |fallbackSMJProcessPartition();
          """.stripMargin)
     )
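
Taken together, the generated members implement a lazy fallback lifecycle: the
SortMergeJoinFunction is opened only when spilled partitions remain after the
rebuild phase, and closed only if the fallback actually ran. Condensed from
the generated code above (illustrative):

    // After the rebuild phase of the generated operator:
    if (!table.getPartitionsPendingForSMJ().isEmpty()) {
        initialSortMergeJoinFunction();  // open with the task's memory/context
        fallbackSMJ = true;
        // ... replay the build and probe sides of each pending partition ...
        closeHashTable();                // release hash table memory first
        sortMergeJoinFunction.endInput(1);
        sortMergeJoinFunction.endInput(2);
    }
    // In the operator's close():
    closeHashTable();
    if (fallbackSMJ) {
        sortMergeJoinFunction.close();
    }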
 
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/SortUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/SortUtil.scala
index 74272a4717e..e3c7cfc5e86 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/SortUtil.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/SortUtil.scala
@@ -20,7 +20,10 @@ package org.apache.flink.table.planner.plan.utils
 import org.apache.flink.api.common.operators.Order
 import org.apache.flink.table.api.TableException
 import org.apache.flink.table.planner.calcite.FlinkPlannerImpl
+import org.apache.flink.table.planner.codegen.sort.SortCodeGenerator
+import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeConfig
 import org.apache.flink.table.planner.plan.nodes.exec.spec.SortSpec
+import org.apache.flink.table.types.logical.RowType
 
 import org.apache.calcite.rel.`type`._
 import org.apache.calcite.rel.{RelCollation, RelFieldCollation}
@@ -119,6 +122,15 @@ object SortUtil {
     builder.build()
   }
 
+  def newSortGen(
+      config: ExecNodeConfig,
+      classLoader: ClassLoader,
+      originalKeys: Array[Int],
+      inputType: RowType): SortCodeGenerator = {
+    val sortSpec = SortUtil.getAscendingSortSpec(originalKeys)
+    new SortCodeGenerator(config, classLoader, inputType, sortSpec)
+  }
+
   def directionToOrder(direction: Direction): Order = {
     direction match {
       case Direction.ASCENDING | Direction.STRICTLY_ASCENDING => Order.ASCENDING
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/LongAdaptiveHashJoinGeneratorTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/LongAdaptiveHashJoinGeneratorTest.java
new file mode 100644
index 00000000000..6dfe26b54a5
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/LongAdaptiveHashJoinGeneratorTest.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.codegen;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeConfig;
+import org.apache.flink.table.planner.plan.utils.SorMergeJoinOperatorUtil;
+import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
+import org.apache.flink.table.runtime.generated.JoinCondition;
+import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
+import org.apache.flink.table.runtime.operators.join.HashJoinType;
+import org.apache.flink.table.runtime.operators.join.Int2AdaptiveHashJoinOperatorTest;
+import org.apache.flink.table.runtime.operators.join.SortMergeJoinFunction;
+import org.apache.flink.table.types.logical.IntType;
+import org.apache.flink.table.types.logical.RowType;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Test for adaptive {@link LongHashJoinGenerator}. */
+public class LongAdaptiveHashJoinGeneratorTest extends Int2AdaptiveHashJoinOperatorTest {
+
+    @Override
+    public Object newOperator(
+            long memorySize,
+            FlinkJoinType flinkJoinType,
+            HashJoinType hashJoinType,
+            boolean buildLeft,
+            boolean reverseJoinFunction) {
+        return getLongHashJoinOperator(flinkJoinType, hashJoinType, buildLeft, reverseJoinFunction);
+    }
+
+    @Override
+    public void testBuildLeftAntiJoinFallbackToSMJ() {}
+
+    @Override
+    public void testBuildLeftSemiJoinFallbackToSMJ() {}
+
+    @Override
+    public void testBuildFirstHashLeftOutJoinFallbackToSMJ() {}
+
+    @Override
+    public void testBuildSecondHashRightOutJoinFallbackToSMJ() {}
+
+    @Override
+    public void testBuildFirstHashFullOutJoinFallbackToSMJ() {}
+
+    static Object getLongHashJoinOperator(
+            FlinkJoinType flinkJoinType,
+            HashJoinType hashJoinType,
+            boolean buildLeft,
+            boolean reverseJoinFunction) {
+        RowType keyType = RowType.of(new IntType());
+        boolean[] filterNulls = new boolean[] {true};
+        assertThat(LongHashJoinGenerator.support(hashJoinType, keyType, filterNulls)).isTrue();
+
+        RowType buildType = RowType.of(new IntType(), new IntType());
+        RowType probeType = RowType.of(new IntType(), new IntType());
+        int[] buildKeyMapping = new int[] {0};
+        int[] probeKeyMapping = new int[] {0};
+        GeneratedJoinCondition condFunc =
+                new GeneratedJoinCondition(
+                        MyJoinCondition.class.getCanonicalName(), "", new Object[0]) {
+                    @Override
+                    public JoinCondition newInstance(ClassLoader classLoader) {
+                        return new MyJoinCondition(new Object[0]);
+                    }
+                };
+
+        SortMergeJoinFunction sortMergeJoinFunction;
+        if (buildLeft) {
+            sortMergeJoinFunction =
+                    SorMergeJoinOperatorUtil.getSortMergeJoinFunction(
+                            Thread.currentThread().getContextClassLoader(),
+                            new ExecNodeConfig(TableConfig.getDefault(), new Configuration()),
+                            flinkJoinType,
+                            buildType,
+                            probeType,
+                            buildKeyMapping,
+                            probeKeyMapping,
+                            keyType,
+                            buildLeft,
+                            filterNulls,
+                            condFunc,
+                            0);
+        } else {
+            sortMergeJoinFunction =
+                    SorMergeJoinOperatorUtil.getSortMergeJoinFunction(
+                            Thread.currentThread().getContextClassLoader(),
+                            new ExecNodeConfig(TableConfig.getDefault(), new Configuration()),
+                            flinkJoinType,
+                            probeType,
+                            buildType,
+                            probeKeyMapping,
+                            buildKeyMapping,
+                            keyType,
+                            buildLeft,
+                            filterNulls,
+                            condFunc,
+                            0);
+        }
+        return LongHashJoinGenerator.gen(
+                new Configuration(),
+                Thread.currentThread().getContextClassLoader(),
+                hashJoinType,
+                keyType,
+                buildType,
+                probeType,
+                buildKeyMapping,
+                probeKeyMapping,
+                20,
+                10000,
+                reverseJoinFunction,
+                condFunc,
+                buildLeft,
+                sortMergeJoinFunction);
+    }
+}
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/LongHashJoinGeneratorTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/LongHashJoinGeneratorTest.java
index d1ce2c9e77f..0aaa40e5403 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/LongHashJoinGeneratorTest.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/LongHashJoinGeneratorTest.java
@@ -18,75 +18,47 @@
 
 package org.apache.flink.table.planner.codegen;
 
-import org.apache.flink.api.common.functions.AbstractRichFunction;
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.table.data.RowData;
-import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
-import org.apache.flink.table.runtime.generated.JoinCondition;
+import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
 import org.apache.flink.table.runtime.operators.join.HashJoinType;
 import org.apache.flink.table.runtime.operators.join.Int2HashJoinOperatorTest;
-import org.apache.flink.table.types.logical.IntType;
-import org.apache.flink.table.types.logical.RowType;
 
 import org.junit.Test;
 
-import static org.assertj.core.api.Assertions.assertThat;
-
 /** Test for {@link LongHashJoinGenerator}. */
 public class LongHashJoinGeneratorTest extends Int2HashJoinOperatorTest {
 
     @Override
-    public Object newOperator(long memorySize, HashJoinType type, boolean reverseJoinFunction) {
-        RowType keyType = RowType.of(new IntType());
-        assertThat(LongHashJoinGenerator.support(type, keyType, new boolean[] {true})).isTrue();
-        return LongHashJoinGenerator.gen(
-                new Configuration(),
-                Thread.currentThread().getContextClassLoader(),
-                type,
-                keyType,
-                RowType.of(new IntType(), new IntType()),
-                RowType.of(new IntType(), new IntType()),
-                new int[] {0},
-                new int[] {0},
-                20,
-                10000,
-                reverseJoinFunction,
-                new GeneratedJoinCondition(
-                        MyJoinCondition.class.getCanonicalName(), "", new Object[0]));
+    public Object newOperator(
+            long memorySize,
+            FlinkJoinType flinkJoinType,
+            HashJoinType hashJoinType,
+            boolean buildLeft,
+            boolean reverseJoinFunction) {
+        return LongAdaptiveHashJoinGeneratorTest.getLongHashJoinOperator(
+                flinkJoinType, hashJoinType, buildLeft, reverseJoinFunction);
     }
 
     @Test
     @Override
-    public void testBuildLeftSemiJoin() throws Exception {}
+    public void testBuildLeftSemiJoin() {}
 
     @Test
     @Override
-    public void testBuildSecondHashFullOutJoin() throws Exception {}
+    public void testBuildSecondHashFullOutJoin() {}
 
     @Test
     @Override
-    public void testBuildSecondHashRightOutJoin() throws Exception {}
+    public void testBuildSecondHashRightOutJoin() {}
 
     @Test
     @Override
-    public void testBuildLeftAntiJoin() throws Exception {}
+    public void testBuildLeftAntiJoin() {}
 
     @Test
     @Override
-    public void testBuildFirstHashLeftOutJoin() throws Exception {}
+    public void testBuildFirstHashLeftOutJoin() {}
 
     @Test
     @Override
-    public void testBuildFirstHashFullOutJoin() throws Exception {}
-
-    /** Test cond. */
-    public static class MyJoinCondition extends AbstractRichFunction implements JoinCondition {
-
-        public MyJoinCondition(Object[] reference) {}
-
-        @Override
-        public boolean apply(RowData in1, RowData in2) {
-            return true;
-        }
-    }
+    public void testBuildFirstHashFullOutJoin() {}
 }
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BaseHybridHashTable.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BaseHybridHashTable.java
index 5c5a3ddf518..f7e691057c7 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BaseHybridHashTable.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BaseHybridHashTable.java
@@ -387,6 +387,17 @@ public abstract class BaseHybridHashTable implements MemorySegmentPool {
             return;
         }
 
+        // clear the current build side channel, if there is one
+        if (this.currentSpilledBuildSide != null) {
+            try {
+                this.currentSpilledBuildSide.getChannel().closeAndDelete();
+            } catch (Throwable t) {
+                LOG.warn(
+                        "Could not close and delete the temp file for the current spilled partition build side.",
+                        t);
+            }
+        }
+
         // clear the current probe side channel, if there is one
         if (this.currentSpilledProbeSide != null) {
             try {
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashPartition.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashPartition.java
index ab89c9aa01e..eaea42361f1 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashPartition.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashPartition.java
@@ -60,7 +60,15 @@ public class BinaryHashPartition extends AbstractPagedInputView implements Seeka
     final int partitionNumber; // the number of the partition
 
     long probeSideRecordCounter; // number of probe-side records in this partition
+
+    /**
+     * These buffers are null when the BinaryHashPartition is first created. They are assigned in
+     * two cases: 1) when the build stage ends and this partition fits in memory, all the segments
+     * occupied by this partition are returned to it; 2) when the build stage ends and this
+     * partition has spilled to disk, the data read back from disk later is assigned to it.
+     */
     private MemorySegment[] partitionBuffers;
+
     private int currentBufferNum;
     private int finalBufferLimit;
 
@@ -621,7 +629,7 @@ public class BinaryHashPartition extends AbstractPagedInputView implements Seeka
             }
         }
 
-        final long getPointer() {
+        long getPointer() {
             return this.currentPointer;
         }
 
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashTable.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashTable.java
index fe5f90c5dc8..26cfb0bcd9a 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashTable.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashTable.java
@@ -95,6 +95,12 @@ public class BinaryHashTable extends BaseHybridHashTable {
     /** The partitions that are built by processing the current partition. */
     final ArrayList<BinaryHashPartition> partitionsBeingBuilt;
 
+    /**
+     * The partitions that have been spilled previously and are pending to be processed by the
+     * sort merge join operator.
+     */
+    private final List<BinaryHashPartition> partitionsPendingForSMJ;
+
     /**
      * BitSet which used to mark whether the element(int build side) has successfully matched during
      * probe phase. As there are 9 elements in each bucket, we assign 2 bytes to BitSet.
@@ -193,6 +199,7 @@ public class BinaryHashTable extends BaseHybridHashTable {
 
         this.partitionsBeingBuilt = new ArrayList<>();
         this.partitionsPending = new ArrayList<>();
+        this.partitionsPendingForSMJ = new ArrayList<>();
 
         createPartitions(initPartitionFanOut, 0);
     }
@@ -399,10 +406,37 @@ public class BinaryHashTable extends BaseHybridHashTable {
         this.probeMatchedPhase = true;
         this.buildIterVisited = false;
 
+        final int nextRecursionLevel = p.getRecursionLevel() + 1;
+        if (nextRecursionLevel == 2) {
+            LOG.info("Recursive hash join: partition number is " + p.getPartitionNumber());
+        } else if (nextRecursionLevel > MAX_RECURSION_DEPTH) {
+            LOG.info(
+                    "Partition number [{}] recursive level more than {}, process the partition using SortMergeJoin later.",
+                    p.getPartitionNumber(),
+                    MAX_RECURSION_DEPTH);
+            // if the partition has spilled to disk more than three times, process it by sort merge
+            // join later
+            this.partitionsPendingForSMJ.add(p);
+            // also need to remove it from pending list
+            this.partitionsPending.remove(0);
+            // recursively get the next partition
+            return prepareNextPartition();
+        }
         // build the next table; memory must be allocated after this call
-        buildTableFromSpilledPartition(p);
+        buildTableFromSpilledPartition(p, nextRecursionLevel);
 
         // set the probe side
+        setPartitionProbeReader(p);
+
+        // unregister the pending partition
+        this.partitionsPending.remove(0);
+        this.currentRecursionDepth = p.getRecursionLevel() + 1;
+
+        // recursively get the next
+        return nextMatching();
+    }
+
+    private void setPartitionProbeReader(BinaryHashPartition p) throws IOException {
         ChannelWithMeta channelWithMeta =
                 new ChannelWithMeta(
                         p.probeSideBuffer.getChannel().getChannelID(),
@@ -424,27 +458,11 @@ public class BinaryHashTable extends BaseHybridHashTable {
                         new ArrayList<>(),
                         this.binaryProbeSideSerializer);
         this.probeIterator.set(probeReader);
-        this.probeIterator.setReuse(binaryProbeSideSerializer.createInstance());
-
-        // unregister the pending partition
-        this.partitionsPending.remove(0);
-        this.currentRecursionDepth = p.getRecursionLevel() + 1;
-
-        // recursively get the next
-        return nextMatching();
+        this.probeIterator.setReuse(this.binaryProbeSideSerializer.createInstance());
     }
 
-    private void buildTableFromSpilledPartition(final BinaryHashPartition p) throws IOException {
-
-        final int nextRecursionLevel = p.getRecursionLevel() + 1;
-        if (nextRecursionLevel == 2) {
-            LOG.info("Recursive hash join: partition number is " + p.getPartitionNumber());
-        } else if (nextRecursionLevel > MAX_RECURSION_DEPTH) {
-            throw new RuntimeException(
-                    "Hash join exceeded maximum number of recursions, without reducing "
-                            + "partitions enough to be memory resident. Probably cause: Too many duplicate keys.");
-        }
-
+    private void buildTableFromSpilledPartition(
+            final BinaryHashPartition p, final int nextRecursionLevel) throws IOException {
         if (p.getBuildSideBlockCount() > p.getProbeSideBlockCount()) {
             LOG.info(
                     String.format(
@@ -630,6 +648,16 @@ public class BinaryHashTable extends BaseHybridHashTable {
         for (final BinaryHashPartition p : this.partitionsPending) {
             p.clearAllMemory(this.internalPool);
         }
+
+        // clear the partitions that were handed off to the sort merge join operator
+        for (final BinaryHashPartition p : this.partitionsPendingForSMJ) {
+            try {
+                p.clearAllMemory(this.internalPool);
+            } catch (Exception e) {
+                LOG.error("Error during partition cleanup.", e);
+            }
+        }
+        this.partitionsPendingForSMJ.clear();
     }
 
     /**
@@ -659,7 +687,6 @@ public class BinaryHashTable extends BaseHybridHashTable {
                         this.currentEnumerator.next(),
                         this.buildSpillReturnBuffers);
         this.buildSpillRetBufferNumbers += numBuffersFreed;
-
         LOG.info(
                 String.format(
                         "Grace hash join: Ran out memory, choosing partition "
@@ -675,11 +702,62 @@ public class BinaryHashTable extends BaseHybridHashTable {
         }
         numSpillFiles++;
         spillInBytes += numBuffersFreed * segmentSize;
-        // The bloomFilter is built after the data is spilled, so that we can use enough memory.
+        // The bloomFilter is built from the bucket area after the data is spilled, so that we can
+        // use enough memory.
         p.buildBloomFilterAndFreeBucket();
         return largestPartNum;
     }
 
+    public List<BinaryHashPartition> getPartitionsPendingForSMJ() {
+        return this.partitionsPendingForSMJ;
+    }
+
+    public RowIterator getSpilledPartitionBuildSideIter(BinaryHashPartition p) throws IOException {
+        // close build side channel of last processed partition
+        if (this.currentSpilledBuildSide != null) {
+            try {
+                this.currentSpilledBuildSide.getChannel().closeAndDelete();
+            } catch (Throwable t) {
+                LOG.warn(
+                        "Could not close and delete the temp file for the current spilled partition build side.",
+                        t);
+            }
+            this.currentSpilledBuildSide = null;
+        }
+
+        this.currentSpilledBuildSide =
+                createInputView(
+                        p.getBuildSideChannel().getChannelID(),
+                        p.getBuildSideBlockCount(),
+                        p.getLastSegmentLimit());
+        this.buildIterator =
+                new WrappedRowIterator<>(
+                        new BinaryRowChannelInputViewIterator(
+                                this.currentSpilledBuildSide, this.binaryBuildSideSerializer),
+                        this.binaryBuildSideSerializer.createInstance());
+        return this.buildIterator;
+    }
+
+    public ProbeIterator getSpilledPartitionProbeSideIter(BinaryHashPartition p)
+            throws IOException {
+        // close probe side channel of last processed partition
+        if (this.currentSpilledProbeSide != null) {
+            try {
+                this.currentSpilledProbeSide.getChannel().closeAndDelete();
+            } catch (Throwable t) {
+                LOG.warn(
+                        "Could not close and delete the temp file for the current spilled partition probe side.",
+                        t);
+            }
+            this.currentSpilledProbeSide = null;
+        }
+
+        // get the probe side iterator
+        this.probeIterator = new ProbeIterator(this.binaryProbeSideSerializer.createInstance());
+        setPartitionProbeReader(p);
+        return this.probeIterator;
+    }
+
     boolean applyCondition(BinaryRowData candidate) {
         BinaryRowData buildKey = buildSideProjection.apply(candidate);
         // They come from Projection, so we can make sure it is in byte[].
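
The three new public methods above (getPartitionsPendingForSMJ,
getSpilledPartitionBuildSideIter, getSpilledPartitionProbeSideIter) are the
hooks the fallback consumes. A condensed sketch of the consuming loop, modeled
on the code the Scala generator emits for the long-key variant (operator
plumbing elided):

    // Sketch: probe-phase fallback over the partitions pending for SMJ.
    for (BinaryHashPartition p : table.getPartitionsPendingForSMJ()) {
        // replay the spilled build side into the sort merge join function
        RowIterator<BinaryRowData> buildIter =
                table.getSpilledPartitionBuildSideIter(p);
        while (buildIter.advanceNext()) {
            processSortMergeJoinElement1(buildIter.getRow());
        }
        // then replay the spilled probe side
        ProbeIterator probeIter = table.getSpilledPartitionProbeSideIter(p);
        BinaryRowData probeNext;
        while ((probeNext = probeIter.next()) != null) {
            processSortMergeJoinElement2(probeNext);
        }
    }
    closeHashTable();
    sortMergeJoinFunction.endInput(1);
    sortMergeJoinFunction.endInput(2);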
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/LongHashPartition.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/LongHashPartition.java
index 951562687cb..306420272bf 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/LongHashPartition.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/LongHashPartition.java
@@ -849,7 +849,7 @@ public class LongHashPartition extends AbstractPagedInputView implements Seekabl
             }
         }
 
-        final long getPointer() {
+        long getPointer() {
             return this.currentPointer;
         }
 
@@ -879,7 +879,7 @@ public class LongHashPartition extends AbstractPagedInputView implements Seekabl
         return available < 8 + serializer.getFixedLengthPartSize();
     }
 
-    static void deserializeFromPages(
+    public static void deserializeFromPages(
             BinaryRowData reuse,
             ChannelReaderInputView inView,
             BinaryRowDataSerializer buildSideSerializer)
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/LongHybridHashTable.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/LongHybridHashTable.java
index 600953a7e15..75081ea81b7 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/LongHybridHashTable.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/hashtable/LongHybridHashTable.java
@@ -28,8 +28,10 @@ import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.data.binary.BinaryRowData;
 import org.apache.flink.table.runtime.io.ChannelWithMeta;
+import org.apache.flink.table.runtime.io.LongHashPartitionChannelReaderInputViewIterator;
 import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer;
 import org.apache.flink.table.runtime.util.FileChannelUtil;
+import org.apache.flink.table.runtime.util.RowIterator;
 import org.apache.flink.util.MathUtils;
 
 import java.io.EOFException;
@@ -52,6 +54,12 @@ public abstract class LongHybridHashTable extends BaseHybridHashTable {
     private final BinaryRowDataSerializer probeSideSerializer;
     private final ArrayList<LongHashPartition> partitionsBeingBuilt;
     private final ArrayList<LongHashPartition> partitionsPending;
+    /**
+     * The partitions that have been spilled previously and are pending to be processed by the
+     * sort merge join operator.
+     */
+    private final List<LongHashPartition> partitionsPendingForSMJ;
+
     private ProbeIterator probeIterator;
     private LongHashPartition.MatchIterator matchIterator;
 
@@ -85,6 +93,7 @@ public abstract class LongHybridHashTable extends BaseHybridHashTable {
 
         this.partitionsBeingBuilt = new ArrayList<>();
         this.partitionsPending = new ArrayList<>();
+        this.partitionsPendingForSMJ = new ArrayList<>();
 
         createPartitions(initPartitionFanOut, 0);
     }
@@ -184,7 +193,7 @@ public abstract class LongHybridHashTable extends BaseHybridHashTable {
 
     /** After build end, try to use dense mode. */
     private void tryDenseMode() {
-
+        // if some partitions have spilled to disk, always use hash mode
         if (numSpillFiles != 0) {
             return;
         }
@@ -212,7 +221,7 @@ public abstract class LongHybridHashTable extends BaseHybridHashTable {
 
         long range = maxKey - minKey + 1;
 
-        // 1.range is negative mean: range is to big to overflow
+        // 1.range is negative mean: range is too big to overflow
         // 2.range is zero, maybe the max is Long.Max, and the min is Long.Min,
         // so we should not use dense mode too.
         if (range > 0 && (range <= recordCount * 4 || range <= segmentSize / 8)) {
@@ -221,8 +230,7 @@ public abstract class LongHybridHashTable extends BaseHybridHashTable {
             int buffers = (int) Math.ceil(((double) (range * 8)) / segmentSize);
 
             // TODO MemoryManager needs to support flexible larger segment, so that the index area
-            // of the
-            // build side is placed on a segment to avoid the overhead of addressing.
+            // of the build side is placed on a segment to avoid the overhead of addressing.
             MemorySegment[] denseBuckets = new MemorySegment[buffers];
             for (int i = 0; i < buffers; i++) {
                 MemorySegment seg = getNextBuffer();
@@ -362,10 +370,38 @@ public abstract class LongHybridHashTable extends BaseHybridHashTable {
             return prepareNextPartition();
         }
 
+        final int nextRecursionLevel = p.getRecursionLevel() + 1;
+        if (nextRecursionLevel == 2) {
+            LOG.info("Recursive hash join: partition number is " + p.getPartitionNumber());
+        } else if (nextRecursionLevel > MAX_RECURSION_DEPTH) {
+            LOG.info(
+                    "Partition number [{}] recursive level more than {}, process the partition using SortMergeJoin later.",
+                    p.getPartitionNumber(),
+                    MAX_RECURSION_DEPTH);
+            // if the partition has spilled to disk more than three times, process it by sort merge
+            // join later
+            this.partitionsPendingForSMJ.add(p);
+            // also need to remove it from pending list
+            this.partitionsPending.remove(0);
+            // recursively get the next partition
+            return prepareNextPartition();
+        }
+
         // build the next table; memory must be allocated after this call
-        buildTableFromSpilledPartition(p);
+        buildTableFromSpilledPartition(p, nextRecursionLevel);
 
         // set the probe side
+        setPartitionProbeReader(p);
+
+        // unregister the pending partition
+        this.partitionsPending.remove(0);
+        this.currentRecursionDepth = p.getRecursionLevel() + 1;
+
+        // recursively get the next
+        return nextMatching();
+    }
+
+    private void setPartitionProbeReader(LongHashPartition p) throws IOException {
         ChannelWithMeta channelWithMeta =
                 new ChannelWithMeta(
                         p.probeSideBuffer.getChannel().getChannelID(),
@@ -386,26 +422,10 @@ public abstract class LongHybridHashTable extends BaseHybridHashTable {
                         this.currentSpilledProbeSide, new ArrayList<>(), this.probeSideSerializer);
         this.probeIterator.set(probeReader);
         this.probeIterator.setReuse(probeSideSerializer.createInstance());
-
-        // unregister the pending partition
-        this.partitionsPending.remove(0);
-        this.currentRecursionDepth = p.getRecursionLevel() + 1;
-
-        // recursively get the next
-        return nextMatching();
     }
 
-    private void buildTableFromSpilledPartition(final LongHashPartition p) throws IOException {
-
-        final int nextRecursionLevel = p.getRecursionLevel() + 1;
-        if (nextRecursionLevel == 2) {
-            LOG.info("Recursive hash join: partition number is " + p.getPartitionNumber());
-        } else if (nextRecursionLevel > MAX_RECURSION_DEPTH) {
-            throw new RuntimeException(
-                    "Hash join exceeded maximum number of recursions, without reducing "
-                            + "partitions enough to be memory resident. Probably cause: Too many duplicate keys.");
-        }
-
+    private void buildTableFromSpilledPartition(
+            final LongHashPartition p, final int nextRecursionLevel) throws IOException {
         if (p.getBuildSideBlockCount() > p.getProbeSideBlockCount()) {
             LOG.info(
                     String.format(
@@ -561,6 +581,53 @@ public abstract class LongHybridHashTable extends BaseHybridHashTable {
         return largestPartNum;
     }
 
+    public List<LongHashPartition> getPartitionsPendingForSMJ() {
+        return this.partitionsPendingForSMJ;
+    }
+
+    public RowIterator getSpilledPartitionBuildSideIter(LongHashPartition p) throws IOException {
+        // close build side channel of last processed partition
+        if (this.currentSpilledBuildSide != null) {
+            try {
+                this.currentSpilledBuildSide.getChannel().closeAndDelete();
+            } catch (Throwable t) {
+                LOG.warn(
+                        "Could not close and delete the temp file for the current spilled partition build side.",
+                        t);
+            }
+            this.currentSpilledBuildSide = null;
+        }
+
+        this.currentSpilledBuildSide =
+                createInputView(
+                        p.getBuildSideChannel().getChannelID(),
+                        p.getBuildSideBlockCount(),
+                        p.getLastSegmentLimit());
+        return new WrappedRowIterator<>(
+                new LongHashPartitionChannelReaderInputViewIterator(
+                        this.currentSpilledBuildSide, this.buildSideSerializer),
+                this.buildSideSerializer.createInstance());
+    }
+
+    public ProbeIterator getSpilledPartitionProbeSideIter(LongHashPartition p) throws IOException {
+        // close probe side channel of last processed partition
+        if (this.currentSpilledProbeSide != null) {
+            try {
+                this.currentSpilledProbeSide.getChannel().closeAndDelete();
+            } catch (Throwable t) {
+                LOG.warn(
+                        "Could not close and delete the temp file for the current spilled partition probe side.",
+                        t);
+            }
+            this.currentSpilledProbeSide = null;
+        }
+
+        // get the probe side iterator
+        this.probeIterator = new ProbeIterator(this.probeSideSerializer.createInstance());
+        setPartitionProbeReader(p);
+        return this.probeIterator;
+    }
+
     @Override
     protected void clearPartitions() {
         this.probeIterator = null;
@@ -579,6 +646,16 @@ public abstract class LongHybridHashTable extends BaseHybridHashTable {
         for (final LongHashPartition p : this.partitionsPending) {
             p.clearAllMemory(this.internalPool);
         }
+
+        // clear the partitions that were processed by the sort merge join operator
+        for (final LongHashPartition p : this.partitionsPendingForSMJ) {
+            try {
+                p.clearAllMemory(this.internalPool);
+            } catch (Exception e) {
+                LOG.error("Error during partition cleanup.", e);
+            }
+        }
+        this.partitionsPendingForSMJ.clear();
     }
 
     public boolean compressionEnable() {
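
A minimal sketch (assuming an already-built LongHybridHashTable named table) of
how these new accessors are meant to be consumed; handOverToBuildSide and
handOverToProbeSide are hypothetical sinks standing in for the sort merge join
inputs:

    // drain every partition the hash join gave up on, replaying the spilled
    // build side first and then the spilled probe side of each partition
    for (LongHashPartition p : table.getPartitionsPendingForSMJ()) {
        RowIterator buildIter = table.getSpilledPartitionBuildSideIter(p);
        while (buildIter.advanceNext()) {
            handOverToBuildSide((BinaryRowData) buildIter.getRow());
        }
        ProbeIterator probeIter = table.getSpilledPartitionProbeSideIter(p);
        BinaryRowData row;
        while ((row = probeIter.next()) != null) {
            handOverToProbeSide(row);
        }
    }
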
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/io/BinaryRowChannelInputViewIterator.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/io/BinaryRowChannelInputViewIterator.java
index 1779ab27241..30456442622 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/io/BinaryRowChannelInputViewIterator.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/io/BinaryRowChannelInputViewIterator.java
@@ -34,11 +34,11 @@ import java.util.List;
  * BinaryRowDataSerializer#deserializeFromPages}.
  */
 public class BinaryRowChannelInputViewIterator implements MutableObjectIterator<BinaryRowData> {
-    private final ChannelReaderInputView inView;
+    protected final ChannelReaderInputView inView;
 
-    private final BinaryRowDataSerializer serializer;
+    protected final BinaryRowDataSerializer serializer;
 
-    private final List<MemorySegment> freeMemTarget;
+    protected final List<MemorySegment> freeMemTarget;
 
     public BinaryRowChannelInputViewIterator(
             ChannelReaderInputView inView, BinaryRowDataSerializer serializer) {
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/io/BinaryRowChannelInputViewIterator.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/io/LongHashPartitionChannelReaderInputViewIterator.java
similarity index 61%
copy from flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/io/BinaryRowChannelInputViewIterator.java
copy to flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/io/LongHashPartitionChannelReaderInputViewIterator.java
index 1779ab27241..b7e2cf62212 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/io/BinaryRowChannelInputViewIterator.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/io/LongHashPartitionChannelReaderInputViewIterator.java
@@ -21,43 +21,30 @@ package org.apache.flink.table.runtime.io;
 import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.runtime.io.disk.iomanager.ChannelReaderInputView;
 import org.apache.flink.table.data.binary.BinaryRowData;
+import org.apache.flink.table.runtime.hashtable.LongHashPartition;
 import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer;
-import org.apache.flink.util.MutableObjectIterator;
 
 import java.io.EOFException;
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.List;
 
 /**
  * A simple iterator over the input read through an I/O channel. Use {@link
- * BinaryRowDataSerializer#deserializeFromPages}.
+ * LongHashPartition#deserializeFromPages}.
  */
-public class BinaryRowChannelInputViewIterator implements MutableObjectIterator<BinaryRowData> {
-    private final ChannelReaderInputView inView;
+public class LongHashPartitionChannelReaderInputViewIterator
+        extends BinaryRowChannelInputViewIterator {
 
-    private final BinaryRowDataSerializer serializer;
-
-    private final List<MemorySegment> freeMemTarget;
-
-    public BinaryRowChannelInputViewIterator(
+    public LongHashPartitionChannelReaderInputViewIterator(
             ChannelReaderInputView inView, BinaryRowDataSerializer serializer) {
-        this(inView, new ArrayList<>(), serializer);
-    }
-
-    public BinaryRowChannelInputViewIterator(
-            ChannelReaderInputView inView,
-            List<MemorySegment> freeMemTarget,
-            BinaryRowDataSerializer serializer) {
-        this.inView = inView;
-        this.freeMemTarget = freeMemTarget;
-        this.serializer = serializer;
+        super(inView, serializer);
     }
 
     @Override
     public BinaryRowData next(BinaryRowData reuse) throws IOException {
         try {
-            return this.serializer.deserializeFromPages(reuse, this.inView);
+            LongHashPartition.deserializeFromPages(reuse, inView, serializer);
+            return reuse;
         } catch (EOFException eofex) {
             final List<MemorySegment> freeMem = this.inView.close();
             if (this.freeMemTarget != null) {
@@ -66,10 +53,4 @@ public class BinaryRowChannelInputViewIterator implements MutableObjectIterator<
             return null;
         }
     }
-
-    @Override
-    public BinaryRowData next() throws IOException {
-        throw new UnsupportedOperationException(
-                "This method is disabled due to performance issue!");
-    }
 }
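
A minimal sketch of the reuse-based read loop that callers of this iterator are
expected to drive (inView and serializer are assumed to exist already; consume
is a hypothetical per-row callback):

    LongHashPartitionChannelReaderInputViewIterator it =
            new LongHashPartitionChannelReaderInputViewIterator(inView, serializer);
    // next() returns null once the channel is exhausted; the same BinaryRowData
    // instance is reused across calls, so copy the row if it must outlive one step
    BinaryRowData reuse = serializer.createInstance();
    while ((reuse = it.next(reuse)) != null) {
        consume(reuse);
    }
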
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/HashJoinOperator.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/HashJoinOperator.java
index 08312d34d7e..59b522a4521 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/HashJoinOperator.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/HashJoinOperator.java
@@ -32,7 +32,9 @@ import org.apache.flink.table.data.utils.JoinedRowData;
 import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
 import org.apache.flink.table.runtime.generated.GeneratedProjection;
 import org.apache.flink.table.runtime.generated.JoinCondition;
+import org.apache.flink.table.runtime.hashtable.BinaryHashPartition;
 import org.apache.flink.table.runtime.hashtable.BinaryHashTable;
+import org.apache.flink.table.runtime.hashtable.ProbeIterator;
 import org.apache.flink.table.runtime.operators.TableStreamOperator;
 import org.apache.flink.table.runtime.typeutils.AbstractRowDataSerializer;
 import org.apache.flink.table.runtime.util.RowIterator;
@@ -54,6 +56,13 @@ import static org.apache.flink.util.Preconditions.checkState;
  * <p>The join operator implements the logic of a join operator at runtime. It uses a
  * hybrid-hash-join internally to match the records with equal keys. The build side of the hash is
  * the first input of the match. It supports all join types in {@link HashJoinType}.
+ *
+ * <p>Note: In order to solve the problem of data skew, or too much data in the hash table, the
+ * fallback to sort merge join mechanism is introduced here. If some partitions are spilled to disk
+ * more than three times in the process of hash join, it will fallback to sort merge join by default
+ * to improve stability. In the future, we will support more flexible adaptive hash join strategy,
+ * for example, in the process of building a hash table, if the size of data written to disk reaches
+ * a certain threshold, fallback to sort merge join in advance.
  */
 public abstract class HashJoinOperator extends TableStreamOperator<RowData>
         implements TwoInputStreamOperator<RowData, RowData, RowData>,
@@ -65,6 +74,8 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
     private final HashJoinParameter parameter;
     private final boolean reverseJoinFunction;
     private final HashJoinType type;
+    private final boolean leftIsBuild;
+    private final SortMergeJoinFunction sortMergeJoinFunction;
 
     private transient BinaryHashTable table;
     transient Collector<RowData> collector;
@@ -75,10 +86,15 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
     private transient boolean buildEnd;
     private transient JoinCondition condition;
 
+    // Flag indicating whether we fall back to sort merge join in the probe phase
+    private transient boolean fallbackSMJ;
+
     HashJoinOperator(HashJoinParameter parameter) {
         this.parameter = parameter;
         this.type = parameter.type;
+        this.leftIsBuild = parameter.leftIsBuild;
         this.reverseJoinFunction = parameter.reverseJoinFunction;
+        this.sortMergeJoinFunction = parameter.sortMergeJoinFunction;
     }
 
     @Override
@@ -132,6 +148,7 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
         this.probeSideNullRow = new GenericRowData(probeSerializer.getArity());
         this.joinedRow = new JoinedRowData();
         this.buildEnd = false;
+        this.fallbackSMJ = false;
 
         getMetricGroup().gauge("memoryUsedSizeInBytes", table::getUsedMemoryInBytes);
         getMetricGroup().gauge("numSpillFiles", table::getNumSpillFiles);
@@ -177,6 +194,10 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
                     joinWithNextKey();
                 }
                 LOG.info("Finish rebuild phase.");
+
+                // switch to sort merge join to process the remaining partitions whose
+                // recursion level > 3
+                fallbackSMJProcessPartition();
                 break;
         }
     }
@@ -214,16 +235,90 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
     @Override
     public void close() throws Exception {
         super.close();
+        closeHashTable();
+        condition.close();
+
+        // If we fell back to sort merge join during the hash join, the function also needs to be closed
+        if (fallbackSMJ) {
+            sortMergeJoinFunction.close();
+        }
+    }
+
+    private void closeHashTable() {
         if (this.table != null) {
             this.table.close();
             this.table.free();
             this.table = null;
         }
-        condition.close();
+    }
+
+    /**
+     * If there are still partitions that have spilled to disk more than three times when the
+     * hash join ends, the keys in these partitions are heavily skewed, so fall back to the sort
+     * merge join algorithm to process them.
+     */
+    private void fallbackSMJProcessPartition() throws Exception {
+        if (!table.getPartitionsPendingForSMJ().isEmpty()) {
+            // initialize sort merge join operator
+            LOG.info("Fallback to sort merge join to process spilled partitions.");
+            initialSortMergeJoinFunction();
+            fallbackSMJ = true;
+
+            for (BinaryHashPartition p : table.getPartitionsPendingForSMJ()) {
+                // process build side
+                RowIterator<BinaryRowData> buildSideIter =
+                        table.getSpilledPartitionBuildSideIter(p);
+                while (buildSideIter.advanceNext()) {
+                    processSortMergeJoinElement1(buildSideIter.getRow());
+                }
+
+                // process probe side
+                ProbeIterator probeIter = table.getSpilledPartitionProbeSideIter(p);
+                BinaryRowData probeNext;
+                while ((probeNext = probeIter.next()) != null) {
+                    processSortMergeJoinElement2(probeNext);
+                }
+            }
+
+            // close the HashTable
+            closeHashTable();
+
+            // finish build and probe
+            sortMergeJoinFunction.endInput(1);
+            sortMergeJoinFunction.endInput(2);
+            LOG.info("Finish sort merge join for spilled partitions.");
+        }
+    }
+
+    private void initialSortMergeJoinFunction() throws Exception {
+        sortMergeJoinFunction.open(
+                this.getContainingTask(),
+                this.getOperatorConfig(),
+                (StreamRecordCollector) this.collector,
+                this.computeMemorySize(),
+                this.getRuntimeContext(),
+                this.getMetricGroup());
+    }
+
+    private void processSortMergeJoinElement1(RowData rowData) throws Exception {
+        if (leftIsBuild) {
+            sortMergeJoinFunction.processElement1(rowData);
+        } else {
+            sortMergeJoinFunction.processElement2(rowData);
+        }
+    }
+
+    private void processSortMergeJoinElement2(RowData rowData) throws Exception {
+        if (leftIsBuild) {
+            sortMergeJoinFunction.processElement2(rowData);
+        } else {
+            sortMergeJoinFunction.processElement1(rowData);
+        }
     }
 
     public static HashJoinOperator newHashJoinOperator(
             HashJoinType type,
+            boolean leftIsBuild,
             GeneratedJoinCondition condFuncCode,
             boolean reverseJoinFunction,
             boolean[] filterNullKeys,
@@ -233,10 +328,12 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
             int buildRowSize,
             long buildRowCount,
             long probeRowCount,
-            RowType keyType) {
+            RowType keyType,
+            SortMergeJoinFunction sortMergeJoinFunction) {
         HashJoinParameter parameter =
                 new HashJoinParameter(
                         type,
+                        leftIsBuild,
                         condFuncCode,
                         reverseJoinFunction,
                         filterNullKeys,
@@ -246,7 +343,8 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
                         buildRowSize,
                         buildRowCount,
                         probeRowCount,
-                        keyType);
+                        keyType,
+                        sortMergeJoinFunction);
         switch (type) {
             case INNER:
                 return new InnerHashJoinOperator(parameter);
@@ -270,6 +368,7 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
 
     static class HashJoinParameter implements Serializable {
         HashJoinType type;
+        boolean leftIsBuild;
         GeneratedJoinCondition condFuncCode;
         boolean reverseJoinFunction;
         boolean[] filterNullKeys;
@@ -280,9 +379,11 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
         long buildRowCount;
         long probeRowCount;
         RowType keyType;
+        SortMergeJoinFunction sortMergeJoinFunction;
 
         HashJoinParameter(
                 HashJoinType type,
+                boolean leftIsBuild,
                 GeneratedJoinCondition condFuncCode,
                 boolean reverseJoinFunction,
                 boolean[] filterNullKeys,
@@ -292,8 +393,10 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
                 int buildRowSize,
                 long buildRowCount,
                 long probeRowCount,
-                RowType keyType) {
+                RowType keyType,
+                SortMergeJoinFunction sortMergeJoinFunction) {
             this.type = type;
+            this.leftIsBuild = leftIsBuild;
             this.condFuncCode = condFuncCode;
             this.reverseJoinFunction = reverseJoinFunction;
             this.filterNullKeys = filterNullKeys;
@@ -304,6 +407,7 @@ public abstract class HashJoinOperator extends TableStreamOperator<RowData>
             this.buildRowCount = buildRowCount;
             this.probeRowCount = probeRowCount;
             this.keyType = keyType;
+            this.sortMergeJoinFunction = sortMergeJoinFunction;
         }
     }
 
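For readers skimming the diff, the fallback path added above boils down to the
following sequence once the normal build and probe phases are done. This is a
sketch of the control flow only; replaySpilledPartitions is a hypothetical
stand-in for the partition loop in fallbackSMJProcessPartition:

    // open the sort merge join function lazily, reusing the operator's own
    // task, config, collector, memory budget, runtime context and metrics
    sortMergeJoinFunction.open(
            getContainingTask(), getOperatorConfig(), collector,
            computeMemorySize(), getRuntimeContext(), getMetricGroup());
    replaySpilledPartitions(); // build rows -> one input, probe rows -> the other,
                               // with leftIsBuild deciding which input is which
    closeHashTable();          // release the hash table memory before joining
    sortMergeJoinFunction.endInput(1);
    sortMergeJoinFunction.endInput(2); // both inputs finished -> the join runs
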
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinOperator.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinFunction.java
similarity index 85%
copy from flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinOperator.java
copy to flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinFunction.java
index 18caea96374..79a8fc19c56 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinOperator.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinFunction.java
@@ -1,12 +1,13 @@
 /*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.	See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.	You may obtain a copy of the License at
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
  *
- *		http://www.apache.org/licenses/LICENSE-2.0
+ *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
@@ -17,13 +18,14 @@
 
 package org.apache.flink.table.runtime.operators.join;
 
+import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.metrics.Gauge;
+import org.apache.flink.metrics.groups.OperatorMetricGroup;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
 import org.apache.flink.runtime.memory.MemoryManager;
-import org.apache.flink.streaming.api.operators.BoundedMultiInput;
-import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
-import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.apache.flink.table.api.TableException;
 import org.apache.flink.table.data.GenericRowData;
 import org.apache.flink.table.data.RowData;
@@ -36,7 +38,6 @@ import org.apache.flink.table.runtime.generated.GeneratedRecordComparator;
 import org.apache.flink.table.runtime.generated.JoinCondition;
 import org.apache.flink.table.runtime.generated.Projection;
 import org.apache.flink.table.runtime.generated.RecordComparator;
-import org.apache.flink.table.runtime.operators.TableStreamOperator;
 import org.apache.flink.table.runtime.operators.sort.BinaryExternalSorter;
 import org.apache.flink.table.runtime.typeutils.AbstractRowDataSerializer;
 import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer;
@@ -46,21 +47,13 @@ import org.apache.flink.table.runtime.util.StreamRecordCollector;
 import org.apache.flink.util.Collector;
 import org.apache.flink.util.MutableObjectIterator;
 
+import java.io.Serializable;
 import java.util.BitSet;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
-/**
- * An implementation that realizes the joining through a sort-merge join strategy. 1.In most cases,
- * its performance is weaker than HashJoin. 2.It is more stable than HashJoin, and most of the data
- * can be sorted stably. 3.SortMergeJoin should be the best choice if sort can be omitted in the
- * case of multi-level join cascade with the same key.
- *
- * <p>NOTE: SEMI and ANTI join output input1 instead of input2. (Contrary to {@link
- * HashJoinOperator}).
- */
-public class SortMergeJoinOperator extends TableStreamOperator<RowData>
-        implements TwoInputStreamOperator<RowData, RowData, RowData>, BoundedMultiInput {
+/** This function implements the main logic of the sort merge join. */
+public class SortMergeJoinFunction implements Serializable {
 
     private final double externalBufferMemRatio;
     private final FlinkJoinType type;
@@ -77,6 +70,7 @@ public class SortMergeJoinOperator extends TableStreamOperator<RowData>
     private GeneratedRecordComparator comparator2;
     private GeneratedRecordComparator genKeyComparator;
 
+    private transient StreamTask<?, ?> taskContainer;
     private transient long externalBufferMemory;
     private transient MemoryManager memManager;
     private transient IOManager ioManager;
@@ -95,7 +89,7 @@ public class SortMergeJoinOperator extends TableStreamOperator<RowData>
     private transient RowData rightNullRow;
     private transient JoinedRowData joinedRow;
 
-    public SortMergeJoinOperator(
+    public SortMergeJoinFunction(
             double externalBufferMemRatio,
             FlinkJoinType type,
             boolean leftIsSmaller,
@@ -122,29 +116,31 @@ public class SortMergeJoinOperator extends TableStreamOperator<RowData>
         this.filterNulls = filterNulls;
     }
 
-    @Override
-    public void open() throws Exception {
-        super.open();
-
-        Configuration conf = getContainingTask().getJobConfiguration();
+    public void open(
+            StreamTask<?, ?> taskContainer,
+            StreamConfig operatorConfig,
+            StreamRecordCollector collector,
+            long totalMemory,
+            RuntimeContext runtimeContext,
+            OperatorMetricGroup operatorMetricGroup)
+            throws Exception {
+        this.taskContainer = taskContainer;
 
         isFinished = new boolean[] {false, false};
 
-        collector = new StreamRecordCollector<>(output);
+        this.collector = collector;
 
-        ClassLoader cl = getUserCodeClassloader();
+        ClassLoader cl = taskContainer.getUserCodeClassLoader();
         AbstractRowDataSerializer inputSerializer1 =
-                (AbstractRowDataSerializer) getOperatorConfig().getTypeSerializerIn1(cl);
+                (AbstractRowDataSerializer) operatorConfig.getTypeSerializerIn1(cl);
         this.serializer1 = new BinaryRowDataSerializer(inputSerializer1.getArity());
 
         AbstractRowDataSerializer inputSerializer2 =
-                (AbstractRowDataSerializer) getOperatorConfig().getTypeSerializerIn2(cl);
+                (AbstractRowDataSerializer) operatorConfig.getTypeSerializerIn2(cl);
         this.serializer2 = new BinaryRowDataSerializer(inputSerializer2.getArity());
 
-        this.memManager = this.getContainingTask().getEnvironment().getMemoryManager();
-        this.ioManager = this.getContainingTask().getEnvironment().getIOManager();
-
-        long totalMemory = computeMemorySize();
+        this.memManager = taskContainer.getEnvironment().getMemoryManager();
+        this.ioManager = taskContainer.getEnvironment().getIOManager();
 
         externalBufferMemory = (long) (totalMemory * externalBufferMemRatio);
         externalBufferMemory =
@@ -162,10 +158,11 @@ public class SortMergeJoinOperator extends TableStreamOperator<RowData>
                             + ", please increase manage memory of task manager.");
         }
 
+        Configuration conf = taskContainer.getJobConfiguration();
         // sorter1
         this.sorter1 =
                 new BinaryExternalSorter(
-                        this.getContainingTask(),
+                        taskContainer,
                         memManager,
                         totalSortMem / 2,
                         ioManager,
@@ -179,7 +176,7 @@ public class SortMergeJoinOperator extends TableStreamOperator<RowData>
         // sorter2
         this.sorter2 =
                 new BinaryExternalSorter(
-                        this.getContainingTask(),
+                        taskContainer,
                         memManager,
                         totalSortMem / 2,
                         ioManager,
@@ -192,7 +189,7 @@ public class SortMergeJoinOperator extends TableStreamOperator<RowData>
 
         keyComparator = genKeyComparator.newInstance(cl);
         this.condFunc = condFuncCode.newInstance(cl);
-        condFunc.setRuntimeContext(getRuntimeContext());
+        condFunc.setRuntimeContext(runtimeContext);
         condFunc.open(new Configuration());
 
         projection1 = projectionCode1.newInstance(cl);
@@ -211,37 +208,28 @@ public class SortMergeJoinOperator extends TableStreamOperator<RowData>
         projectionCode2 = null;
         genKeyComparator = null;
 
-        getMetricGroup()
-                .gauge(
-                        "memoryUsedSizeInBytes",
-                        (Gauge<Long>)
-                                () ->
-                                        sorter1.getUsedMemoryInBytes()
-                                                + sorter2.getUsedMemoryInBytes());
-
-        getMetricGroup()
-                .gauge(
-                        "numSpillFiles",
-                        (Gauge<Long>)
-                                () -> sorter1.getNumSpillFiles() + sorter2.getNumSpillFiles());
-
-        getMetricGroup()
-                .gauge(
-                        "spillInBytes",
-                        (Gauge<Long>) () -> sorter1.getSpillInBytes() + sorter2.getSpillInBytes());
+        operatorMetricGroup.gauge(
+                "memoryUsedSizeInBytes",
+                (Gauge<Long>)
+                        () -> sorter1.getUsedMemoryInBytes() + sorter2.getUsedMemoryInBytes());
+
+        operatorMetricGroup.gauge(
+                "numSpillFiles",
+                (Gauge<Long>) () -> sorter1.getNumSpillFiles() + sorter2.getNumSpillFiles());
+
+        operatorMetricGroup.gauge(
+                "spillInBytes",
+                (Gauge<Long>) () -> sorter1.getSpillInBytes() + sorter2.getSpillInBytes());
     }
 
-    @Override
-    public void processElement1(StreamRecord<RowData> element) throws Exception {
-        this.sorter1.write(element.getValue());
+    public void processElement1(RowData element) throws Exception {
+        this.sorter1.write(element);
     }
 
-    @Override
-    public void processElement2(StreamRecord<RowData> element) throws Exception {
-        this.sorter2.write(element.getValue());
+    public void processElement2(RowData element) throws Exception {
+        this.sorter2.write(element);
     }
 
-    @Override
     public void endInput(int inputId) throws Exception {
         isFinished[inputId - 1] = true;
         if (isAllFinished()) {
@@ -521,7 +509,7 @@ public class SortMergeJoinOperator extends TableStreamOperator<RowData>
     private ResettableExternalBuffer newBuffer(BinaryRowDataSerializer serializer) {
         LazyMemorySegmentPool pool =
                 new LazyMemorySegmentPool(
-                        this.getContainingTask(),
+                        taskContainer,
                         memManager,
                         (int) (externalBufferMemory / memManager.getPageSize()));
         return new ResettableExternalBuffer(
@@ -535,9 +523,7 @@ public class SortMergeJoinOperator extends TableStreamOperator<RowData>
         return isFinished[0] && isFinished[1];
     }
 
-    @Override
     public void close() throws Exception {
-        super.close();
         if (this.sorter1 != null) {
             this.sorter1.close();
         }
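
To make the memory split in open() above concrete, here is an illustrative
calculation; the concrete numbers are assumptions, not values from the patch:

    // totalMemory = 128 MB, externalBufferMemRatio = 0.1, INNER join
    //   externalBufferMemory = 128 MB * 0.1     = 12.8 MB
    //     (subject to the ResettableExternalBuffer.MIN_NUM_MEMORY floor)
    //   totalSortMem         = 128 MB - 12.8 MB = 115.2 MB
    //     (a FULL join subtracts externalBufferMemory twice, one buffer per side)
    //   sorter1 and sorter2 each get totalSortMem / 2 = 57.6 MB
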
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinOperator.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinOperator.java
index 18caea96374..61c452a0f79 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinOperator.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinOperator.java
@@ -17,38 +17,12 @@
 
 package org.apache.flink.table.runtime.operators.join;
 
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.metrics.Gauge;
-import org.apache.flink.runtime.io.disk.iomanager.IOManager;
-import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.streaming.api.operators.BoundedMultiInput;
 import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.apache.flink.table.api.TableException;
-import org.apache.flink.table.data.GenericRowData;
 import org.apache.flink.table.data.RowData;
-import org.apache.flink.table.data.binary.BinaryRowData;
-import org.apache.flink.table.data.utils.JoinedRowData;
-import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
-import org.apache.flink.table.runtime.generated.GeneratedNormalizedKeyComputer;
-import org.apache.flink.table.runtime.generated.GeneratedProjection;
-import org.apache.flink.table.runtime.generated.GeneratedRecordComparator;
-import org.apache.flink.table.runtime.generated.JoinCondition;
-import org.apache.flink.table.runtime.generated.Projection;
-import org.apache.flink.table.runtime.generated.RecordComparator;
 import org.apache.flink.table.runtime.operators.TableStreamOperator;
-import org.apache.flink.table.runtime.operators.sort.BinaryExternalSorter;
-import org.apache.flink.table.runtime.typeutils.AbstractRowDataSerializer;
-import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer;
-import org.apache.flink.table.runtime.util.LazyMemorySegmentPool;
-import org.apache.flink.table.runtime.util.ResettableExternalBuffer;
 import org.apache.flink.table.runtime.util.StreamRecordCollector;
-import org.apache.flink.util.Collector;
-import org.apache.flink.util.MutableObjectIterator;
-
-import java.util.BitSet;
-
-import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
  * An implementation that realizes the joining through a sort-merge join strategy. 1.In most cases,
@@ -62,488 +36,44 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
 public class SortMergeJoinOperator extends TableStreamOperator<RowData>
         implements TwoInputStreamOperator<RowData, RowData, RowData>, BoundedMultiInput {
 
-    private final double externalBufferMemRatio;
-    private final FlinkJoinType type;
-    private final boolean leftIsSmaller;
-    private final boolean[] filterNulls;
-
-    // generated code to cook
-    private GeneratedJoinCondition condFuncCode;
-    private GeneratedProjection projectionCode1;
-    private GeneratedProjection projectionCode2;
-    private GeneratedNormalizedKeyComputer computer1;
-    private GeneratedRecordComparator comparator1;
-    private GeneratedNormalizedKeyComputer computer2;
-    private GeneratedRecordComparator comparator2;
-    private GeneratedRecordComparator genKeyComparator;
+    private final SortMergeJoinFunction sortMergeJoinFunction;
 
-    private transient long externalBufferMemory;
-    private transient MemoryManager memManager;
-    private transient IOManager ioManager;
-    private transient BinaryRowDataSerializer serializer1;
-    private transient BinaryRowDataSerializer serializer2;
-    private transient BinaryExternalSorter sorter1;
-    private transient BinaryExternalSorter sorter2;
-    private transient Collector<RowData> collector;
-    private transient boolean[] isFinished;
-    private transient JoinCondition condFunc;
-    private transient RecordComparator keyComparator;
-    private transient Projection<RowData, BinaryRowData> projection1;
-    private transient Projection<RowData, BinaryRowData> projection2;
-
-    private transient RowData leftNullRow;
-    private transient RowData rightNullRow;
-    private transient JoinedRowData joinedRow;
-
-    public SortMergeJoinOperator(
-            double externalBufferMemRatio,
-            FlinkJoinType type,
-            boolean leftIsSmaller,
-            GeneratedJoinCondition condFuncCode,
-            GeneratedProjection projectionCode1,
-            GeneratedProjection projectionCode2,
-            GeneratedNormalizedKeyComputer computer1,
-            GeneratedRecordComparator comparator1,
-            GeneratedNormalizedKeyComputer computer2,
-            GeneratedRecordComparator comparator2,
-            GeneratedRecordComparator genKeyComparator,
-            boolean[] filterNulls) {
-        this.externalBufferMemRatio = externalBufferMemRatio;
-        this.type = type;
-        this.leftIsSmaller = leftIsSmaller;
-        this.condFuncCode = condFuncCode;
-        this.projectionCode1 = projectionCode1;
-        this.projectionCode2 = projectionCode2;
-        this.computer1 = checkNotNull(computer1);
-        this.comparator1 = checkNotNull(comparator1);
-        this.computer2 = checkNotNull(computer2);
-        this.comparator2 = checkNotNull(comparator2);
-        this.genKeyComparator = checkNotNull(genKeyComparator);
-        this.filterNulls = filterNulls;
+    public SortMergeJoinOperator(SortMergeJoinFunction sortMergeJoinFunction) {
+        this.sortMergeJoinFunction = sortMergeJoinFunction;
     }
 
     @Override
     public void open() throws Exception {
         super.open();
 
-        Configuration conf = getContainingTask().getJobConfiguration();
-
-        isFinished = new boolean[] {false, false};
-
-        collector = new StreamRecordCollector<>(output);
-
-        ClassLoader cl = getUserCodeClassloader();
-        AbstractRowDataSerializer inputSerializer1 =
-                (AbstractRowDataSerializer) getOperatorConfig().getTypeSerializerIn1(cl);
-        this.serializer1 = new BinaryRowDataSerializer(inputSerializer1.getArity());
-
-        AbstractRowDataSerializer inputSerializer2 =
-                (AbstractRowDataSerializer) getOperatorConfig().getTypeSerializerIn2(cl);
-        this.serializer2 = new BinaryRowDataSerializer(inputSerializer2.getArity());
-
-        this.memManager = this.getContainingTask().getEnvironment().getMemoryManager();
-        this.ioManager = this.getContainingTask().getEnvironment().getIOManager();
-
-        long totalMemory = computeMemorySize();
-
-        externalBufferMemory = (long) (totalMemory * externalBufferMemRatio);
-        externalBufferMemory =
-                Math.max(externalBufferMemory, ResettableExternalBuffer.MIN_NUM_MEMORY);
-
-        long totalSortMem =
-                totalMemory
-                        - (type.equals(FlinkJoinType.FULL)
-                                ? externalBufferMemory * 2
-                                : externalBufferMemory);
-        if (totalSortMem < 0) {
-            throw new TableException(
-                    "Memory size is too small: "
-                            + totalMemory
-                            + ", please increase manage memory of task manager.");
-        }
-
-        // sorter1
-        this.sorter1 =
-                new BinaryExternalSorter(
-                        this.getContainingTask(),
-                        memManager,
-                        totalSortMem / 2,
-                        ioManager,
-                        inputSerializer1,
-                        serializer1,
-                        computer1.newInstance(cl),
-                        comparator1.newInstance(cl),
-                        conf);
-        this.sorter1.startThreads();
-
-        // sorter2
-        this.sorter2 =
-                new BinaryExternalSorter(
-                        this.getContainingTask(),
-                        memManager,
-                        totalSortMem / 2,
-                        ioManager,
-                        inputSerializer2,
-                        serializer2,
-                        computer2.newInstance(cl),
-                        comparator2.newInstance(cl),
-                        conf);
-        this.sorter2.startThreads();
-
-        keyComparator = genKeyComparator.newInstance(cl);
-        this.condFunc = condFuncCode.newInstance(cl);
-        condFunc.setRuntimeContext(getRuntimeContext());
-        condFunc.open(new Configuration());
-
-        projection1 = projectionCode1.newInstance(cl);
-        projection2 = projectionCode2.newInstance(cl);
-
-        this.leftNullRow = new GenericRowData(serializer1.getArity());
-        this.rightNullRow = new GenericRowData(serializer2.getArity());
-        this.joinedRow = new JoinedRowData();
-
-        condFuncCode = null;
-        computer1 = null;
-        comparator1 = null;
-        computer2 = null;
-        comparator2 = null;
-        projectionCode1 = null;
-        projectionCode2 = null;
-        genKeyComparator = null;
-
-        getMetricGroup()
-                .gauge(
-                        "memoryUsedSizeInBytes",
-                        (Gauge<Long>)
-                                () ->
-                                        sorter1.getUsedMemoryInBytes()
-                                                + sorter2.getUsedMemoryInBytes());
-
-        getMetricGroup()
-                .gauge(
-                        "numSpillFiles",
-                        (Gauge<Long>)
-                                () -> sorter1.getNumSpillFiles() + sorter2.getNumSpillFiles());
-
-        getMetricGroup()
-                .gauge(
-                        "spillInBytes",
-                        (Gauge<Long>) () -> sorter1.getSpillInBytes() + sorter2.getSpillInBytes());
+        // initialize sort merge join function
+        this.sortMergeJoinFunction.open(
+                this.getContainingTask(),
+                this.getOperatorConfig(),
+                new StreamRecordCollector(output),
+                this.computeMemorySize(),
+                this.getRuntimeContext(),
+                this.getMetricGroup());
     }
 
     @Override
     public void processElement1(StreamRecord<RowData> element) throws Exception {
-        this.sorter1.write(element.getValue());
+        this.sortMergeJoinFunction.processElement1(element.getValue());
     }
 
     @Override
     public void processElement2(StreamRecord<RowData> element) throws Exception {
-        this.sorter2.write(element.getValue());
+        this.sortMergeJoinFunction.processElement2(element.getValue());
     }
 
     @Override
     public void endInput(int inputId) throws Exception {
-        isFinished[inputId - 1] = true;
-        if (isAllFinished()) {
-            doSortMergeJoin();
-        }
-    }
-
-    private void doSortMergeJoin() throws Exception {
-        MutableObjectIterator iterator1 = sorter1.getIterator();
-        MutableObjectIterator iterator2 = sorter2.getIterator();
-
-        if (type.equals(FlinkJoinType.INNER)) {
-            if (!leftIsSmaller) {
-                try (SortMergeInnerJoinIterator joinIterator =
-                        new SortMergeInnerJoinIterator(
-                                serializer1,
-                                serializer2,
-                                projection1,
-                                projection2,
-                                keyComparator,
-                                iterator1,
-                                iterator2,
-                                newBuffer(serializer2),
-                                filterNulls)) {
-                    innerJoin(joinIterator, false);
-                }
-            } else {
-                try (SortMergeInnerJoinIterator joinIterator =
-                        new SortMergeInnerJoinIterator(
-                                serializer2,
-                                serializer1,
-                                projection2,
-                                projection1,
-                                keyComparator,
-                                iterator2,
-                                iterator1,
-                                newBuffer(serializer1),
-                                filterNulls)) {
-                    innerJoin(joinIterator, true);
-                }
-            }
-        } else if (type.equals(FlinkJoinType.LEFT)) {
-            try (SortMergeOneSideOuterJoinIterator joinIterator =
-                    new SortMergeOneSideOuterJoinIterator(
-                            serializer1,
-                            serializer2,
-                            projection1,
-                            projection2,
-                            keyComparator,
-                            iterator1,
-                            iterator2,
-                            newBuffer(serializer2),
-                            filterNulls)) {
-                oneSideOuterJoin(joinIterator, false, rightNullRow);
-            }
-        } else if (type.equals(FlinkJoinType.RIGHT)) {
-            try (SortMergeOneSideOuterJoinIterator joinIterator =
-                    new SortMergeOneSideOuterJoinIterator(
-                            serializer2,
-                            serializer1,
-                            projection2,
-                            projection1,
-                            keyComparator,
-                            iterator2,
-                            iterator1,
-                            newBuffer(serializer1),
-                            filterNulls)) {
-                oneSideOuterJoin(joinIterator, true, leftNullRow);
-            }
-        } else if (type.equals(FlinkJoinType.FULL)) {
-            try (SortMergeFullOuterJoinIterator fullOuterJoinIterator =
-                    new SortMergeFullOuterJoinIterator(
-                            serializer1,
-                            serializer2,
-                            projection1,
-                            projection2,
-                            keyComparator,
-                            iterator1,
-                            iterator2,
-                            newBuffer(serializer1),
-                            newBuffer(serializer2),
-                            filterNulls)) {
-                fullOuterJoin(fullOuterJoinIterator);
-            }
-        } else if (type.equals(FlinkJoinType.SEMI)) {
-            try (SortMergeInnerJoinIterator joinIterator =
-                    new SortMergeInnerJoinIterator(
-                            serializer1,
-                            serializer2,
-                            projection1,
-                            projection2,
-                            keyComparator,
-                            iterator1,
-                            iterator2,
-                            newBuffer(serializer2),
-                            filterNulls)) {
-                while (joinIterator.nextInnerJoin()) {
-                    RowData probeRow = joinIterator.getProbeRow();
-                    boolean matched = false;
-                    try (ResettableExternalBuffer.BufferIterator iter =
-                            joinIterator.getMatchBuffer().newIterator()) {
-                        while (iter.advanceNext()) {
-                            RowData row = iter.getRow();
-                            if (condFunc.apply(probeRow, row)) {
-                                matched = true;
-                                break;
-                            }
-                        }
-                    }
-                    if (matched) {
-                        collector.collect(probeRow);
-                    }
-                }
-            }
-        } else if (type.equals(FlinkJoinType.ANTI)) {
-            try (SortMergeOneSideOuterJoinIterator joinIterator =
-                    new SortMergeOneSideOuterJoinIterator(
-                            serializer1,
-                            serializer2,
-                            projection1,
-                            projection2,
-                            keyComparator,
-                            iterator1,
-                            iterator2,
-                            newBuffer(serializer2),
-                            filterNulls)) {
-                while (joinIterator.nextOuterJoin()) {
-                    RowData probeRow = joinIterator.getProbeRow();
-                    ResettableExternalBuffer matchBuffer = joinIterator.getMatchBuffer();
-                    boolean matched = false;
-                    if (matchBuffer != null) {
-                        try (ResettableExternalBuffer.BufferIterator iter =
-                                matchBuffer.newIterator()) {
-                            while (iter.advanceNext()) {
-                                RowData row = iter.getRow();
-                                if (condFunc.apply(probeRow, row)) {
-                                    matched = true;
-                                    break;
-                                }
-                            }
-                        }
-                    }
-                    if (!matched) {
-                        collector.collect(probeRow);
-                    }
-                }
-            }
-        } else {
-            throw new RuntimeException("Not support type: " + type);
-        }
-    }
-
-    private void innerJoin(SortMergeInnerJoinIterator iterator, boolean reverseInvoke)
-            throws Exception {
-        while (iterator.nextInnerJoin()) {
-            RowData probeRow = iterator.getProbeRow();
-            ResettableExternalBuffer.BufferIterator iter = iterator.getMatchBuffer().newIterator();
-            while (iter.advanceNext()) {
-                RowData row = iter.getRow();
-                joinWithCondition(probeRow, row, reverseInvoke);
-            }
-            iter.close();
-        }
-    }
-
-    private void oneSideOuterJoin(
-            SortMergeOneSideOuterJoinIterator iterator, boolean reverseInvoke, RowData buildNullRow)
-            throws Exception {
-        while (iterator.nextOuterJoin()) {
-            RowData probeRow = iterator.getProbeRow();
-            boolean found = false;
-
-            if (iterator.getMatchKey() != null) {
-                ResettableExternalBuffer.BufferIterator iter =
-                        iterator.getMatchBuffer().newIterator();
-                while (iter.advanceNext()) {
-                    RowData row = iter.getRow();
-                    found |= joinWithCondition(probeRow, row, reverseInvoke);
-                }
-                iter.close();
-            }
-
-            if (!found) {
-                collect(probeRow, buildNullRow, reverseInvoke);
-            }
-        }
-    }
-
-    private void fullOuterJoin(SortMergeFullOuterJoinIterator iterator) throws Exception {
-        BitSet bitSet = new BitSet();
-
-        while (iterator.nextOuterJoin()) {
-
-            bitSet.clear();
-            BinaryRowData matchKey = iterator.getMatchKey();
-            ResettableExternalBuffer buffer1 = iterator.getBuffer1();
-            ResettableExternalBuffer buffer2 = iterator.getBuffer2();
-
-            if (matchKey == null && buffer1.size() > 0) { // left outer join.
-                ResettableExternalBuffer.BufferIterator iter = buffer1.newIterator();
-                while (iter.advanceNext()) {
-                    RowData row1 = iter.getRow();
-                    collector.collect(joinedRow.replace(row1, rightNullRow));
-                }
-                iter.close();
-            } else if (matchKey == null && buffer2.size() > 0) { // right outer join.
-                ResettableExternalBuffer.BufferIterator iter = buffer2.newIterator();
-                while (iter.advanceNext()) {
-                    RowData row2 = iter.getRow();
-                    collector.collect(joinedRow.replace(leftNullRow, row2));
-                }
-                iter.close();
-            } else if (matchKey != null) { // match join.
-                ResettableExternalBuffer.BufferIterator iter1 = buffer1.newIterator();
-                while (iter1.advanceNext()) {
-                    RowData row1 = iter1.getRow();
-                    boolean found = false;
-                    int index = 0;
-                    ResettableExternalBuffer.BufferIterator iter2 = buffer2.newIterator();
-                    while (iter2.advanceNext()) {
-                        RowData row2 = iter2.getRow();
-                        if (condFunc.apply(row1, row2)) {
-                            collector.collect(joinedRow.replace(row1, row2));
-                            found = true;
-                            bitSet.set(index);
-                        }
-                        index++;
-                    }
-                    iter2.close();
-                    if (!found) {
-                        collector.collect(joinedRow.replace(row1, rightNullRow));
-                    }
-                }
-                iter1.close();
-
-                // row2 outer
-                int index = 0;
-                ResettableExternalBuffer.BufferIterator iter2 = buffer2.newIterator();
-                while (iter2.advanceNext()) {
-                    RowData row2 = iter2.getRow();
-                    if (!bitSet.get(index)) {
-                        collector.collect(joinedRow.replace(leftNullRow, row2));
-                    }
-                    index++;
-                }
-                iter2.close();
-            } else { // bug...
-                throw new RuntimeException("There is a bug.");
-            }
-        }
-    }
-
-    private boolean joinWithCondition(RowData row1, RowData row2, boolean reverseInvoke)
-            throws Exception {
-        if (reverseInvoke) {
-            if (condFunc.apply(row2, row1)) {
-                collector.collect(joinedRow.replace(row2, row1));
-                return true;
-            }
-        } else {
-            if (condFunc.apply(row1, row2)) {
-                collector.collect(joinedRow.replace(row1, row2));
-                return true;
-            }
-        }
-        return false;
-    }
-
-    private void collect(RowData row1, RowData row2, boolean reverseInvoke) {
-        if (reverseInvoke) {
-            collector.collect(joinedRow.replace(row2, row1));
-        } else {
-            collector.collect(joinedRow.replace(row1, row2));
-        }
-    }
-
-    private ResettableExternalBuffer newBuffer(BinaryRowDataSerializer serializer) {
-        LazyMemorySegmentPool pool =
-                new LazyMemorySegmentPool(
-                        this.getContainingTask(),
-                        memManager,
-                        (int) (externalBufferMemory / memManager.getPageSize()));
-        return new ResettableExternalBuffer(
-                ioManager,
-                pool,
-                serializer,
-                false /* we don't use newIterator(int beginRow), so don't need use this optimization*/);
-    }
-
-    private boolean isAllFinished() {
-        return isFinished[0] && isFinished[1];
+        this.sortMergeJoinFunction.endInput(inputId);
     }
 
     @Override
     public void close() throws Exception {
         super.close();
-        if (this.sorter1 != null) {
-            this.sorter1.close();
-        }
-        if (this.sorter2 != null) {
-            this.sorter2.close();
-        }
-        condFunc.close();
+        this.sortMergeJoinFunction.close();
     }
 }
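
With the join logic extracted into SortMergeJoinFunction, the operator is now a
thin wrapper, so the standalone operator and the hash join fallback share the
same code path. A minimal construction sketch, assuming the planner-generated
code objects (condFuncCode, the projections, key computers and comparators)
already exist; the scalar arguments are illustrative values:

    SortMergeJoinFunction joinFunction =
            new SortMergeJoinFunction(
                    0.3,                  // externalBufferMemRatio (assumed)
                    FlinkJoinType.INNER,  // join type
                    true,                 // leftIsSmaller
                    condFuncCode,
                    projectionCode1,
                    projectionCode2,
                    computer1,
                    comparator1,
                    computer2,
                    comparator2,
                    genKeyComparator,
                    new boolean[] {true}); // filterNulls
    SortMergeJoinOperator operator = new SortMergeJoinOperator(joinFunction);
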
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/hashtable/BinaryHashTableTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/hashtable/BinaryHashTableTest.java
index 3686a62ab0a..b9868059497 100644
--- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/hashtable/BinaryHashTableTest.java
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/hashtable/BinaryHashTableTest.java
@@ -562,7 +562,7 @@ public class BinaryHashTableTest {
      * fits into memory by itself and needs to be repartitioned in the recursion again.
      */
     @Test
-    public void testFailingHashJoinTooManyRecursions() throws IOException {
+    public void testSpillingHashJoinWithTooManyRecursions() throws IOException {
         // the following two values are known to have a hash-code collision on the first recursion
         // level.
         // we use them to make sure one partition grows over-proportionally large
@@ -613,12 +613,61 @@ public class BinaryHashTableTest {
                         896 * PAGE_SIZE,
                         ioManager);
 
-        try {
-            join(table, buildInput, probeInput);
-            fail("Hash Join must have failed due to too many recursions.");
-        } catch (Exception ex) {
-            // expected
+        // create the map for validating the results
+        HashMap<Integer, Long> map = new HashMap<>(numKeys);
+
+        BinaryRowData buildRow = buildSideSerializer.createInstance();
+        while ((buildRow = buildInput.next(buildRow)) != null) {
+            table.putBuildRow(buildRow);
         }
+        table.endBuild();
+
+        BinaryRowData probeRow = probeSideSerializer.createInstance();
+        while ((probeRow = probeInput.next(probeRow)) != null) {
+            if (table.tryProbe(probeRow)) {
+                testJoin(table, map);
+            }
+        }
+
+        while (table.nextMatching()) {
+            testJoin(table, map);
+        }
+
+        // Partitions that spilled to disk more than 3 times can't be joined by the hash join
+        assertThat(map.size()).as("Wrong number of records in join result.").isLessThan(numKeys);
+
+        // There are two partitions here that spilled to disk more than 3 times
+        assertThat(table.getPartitionsPendingForSMJ().size())
+                .as("Wrong number of spilled partition.")
+                .isEqualTo(2);
+
+        Map<Integer, Integer> spilledPartitionBuildSideKeys = new HashMap<>();
+        Map<Integer, Integer> spilledPartitionProbeSideKeys = new HashMap<>();
+        for (BinaryHashPartition p : table.getPartitionsPendingForSMJ()) {
+            RowIterator<BinaryRowData> buildIter = table.getSpilledPartitionBuildSideIter(p);
+            while (buildIter.advanceNext()) {
+                Integer key = buildIter.getRow().getInt(0);
+                spilledPartitionBuildSideKeys.put(
+                        key, spilledPartitionBuildSideKeys.getOrDefault(key, 0) + 1);
+            }
+
+            ProbeIterator probeIter = table.getSpilledPartitionProbeSideIter(p);
+            BinaryRowData rowData;
+            while ((rowData = probeIter.next()) != null) {
+                Integer key = rowData.getInt(0);
+                spilledPartitionProbeSideKeys.put(
+                        key, spilledPartitionProbeSideKeys.getOrDefault(key, 0) + 1);
+            }
+        }
+
+        // assert that the spilled partitions contain the keys repeatedValue1 and repeatedValue2
+        Integer buildKeyCnt = repeatedValueCount + buildValsPerKey;
+        assertThat(spilledPartitionBuildSideKeys).containsEntry(repeatedValue1, buildKeyCnt);
+        assertThat(spilledPartitionBuildSideKeys).containsEntry(repeatedValue2, buildKeyCnt);
+
+        Integer probeKeyCnt = repeatedValueCount + probeValsPerKey;
+        assertThat(spilledPartitionProbeSideKeys).containsEntry(repeatedValue1, probeKeyCnt);
+        assertThat(spilledPartitionProbeSideKeys).containsEntry(repeatedValue2, probeKeyCnt);
 
         table.close();
 
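
The rewritten test works because BinaryHashTable now keeps over-spilled partitions around instead of throwing. A hedged sketch of how a caller can drain those partitions into a sort-merge join, using only the accessors exercised above (getPartitionsPendingForSMJ, getSpilledPartitionBuildSideIter, getSpilledPartitionProbeSideIter); smjFunction and its feed methods are assumptions for illustration:

    // Sketch: after the normal probe phase, hand every partition that hit
    // the recursion limit to a sort-merge join instead of failing the job.
    for (BinaryHashPartition p : table.getPartitionsPendingForSMJ()) {
        // the spilled build side becomes one input of the sort-merge join
        RowIterator<BinaryRowData> buildIter = table.getSpilledPartitionBuildSideIter(p);
        while (buildIter.advanceNext()) {
            smjFunction.processElement1(buildIter.getRow()); // assumed API
        }
        // the spilled probe side becomes the other input
        ProbeIterator probeIter = table.getSpilledPartitionProbeSideIter(p);
        BinaryRowData row;
        while ((row = probeIter.next()) != null) {
            smjFunction.processElement2(row); // assumed API
        }
    }
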
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/hashtable/LongHashTableTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/hashtable/LongHashTableTest.java
index 4ae6506ec50..8ec312ee392 100644
--- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/hashtable/LongHashTableTest.java
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/hashtable/LongHashTableTest.java
@@ -433,7 +433,7 @@ public class LongHashTableTest {
      * fits into memory by itself and needs to be repartitioned in the recursion again.
      */
     @TestTemplate
-    void testFailingHashJoinTooManyRecursions() throws IOException {
+    void testSpillingHashJoinWithTooManyRecursions() throws IOException {
         // the following two values are known to have a hash-code collision on the first recursion
         // level.
         // we use them to make sure one partition grows over-proportionally large
@@ -477,12 +477,61 @@ public class LongHashTableTest {
         MutableObjectIterator<BinaryRowData> probeInput = new UnionIterator<>(probes);
         final MyHashTable table = new MyHashTable(896 * PAGE_SIZE);
 
-        try {
-            join(table, buildInput, probeInput);
-            fail("Hash Join must have failed due to too many recursions.");
-        } catch (Exception ex) {
-            // expected
+        // create the map for validating the results
+        HashMap<Integer, Long> map = new HashMap<>(numKeys);
+
+        BinaryRowData buildRow = buildSideSerializer.createInstance();
+        while ((buildRow = buildInput.next(buildRow)) != null) {
+            table.putBuildRow(buildRow);
         }
+        table.endBuild();
+
+        BinaryRowData probeRow = probeSideSerializer.createInstance();
+        while ((probeRow = probeInput.next(probeRow)) != null) {
+            if (table.tryProbe(probeRow)) {
+                testJoin(table, map);
+            }
+        }
+
+        while (table.nextMatching()) {
+            testJoin(table, map);
+        }
+
+        // Partitions that spilled to disk more than 3 times can't be joined by the hash join
+        assertThat(map.size()).as("Wrong number of records in join result.").isLessThan(numKeys);
+
+        // There are two partitions here that spilled to disk more than 3 times
+        assertThat(table.getPartitionsPendingForSMJ().size())
+                .as("Wrong number of spilled partition.")
+                .isEqualTo(2);
+
+        Map<Integer, Integer> spilledPartitionBuildSideKeys = new HashMap<>();
+        Map<Integer, Integer> spilledPartitionProbeSideKeys = new HashMap<>();
+        for (LongHashPartition p : table.getPartitionsPendingForSMJ()) {
+            RowIterator<BinaryRowData> buildIter = table.getSpilledPartitionBuildSideIter(p);
+            while (buildIter.advanceNext()) {
+                Integer key = buildIter.getRow().getInt(0);
+                spilledPartitionBuildSideKeys.put(
+                        key, spilledPartitionBuildSideKeys.getOrDefault(key, 0) + 1);
+            }
+
+            ProbeIterator probeIter = table.getSpilledPartitionProbeSideIter(p);
+            BinaryRowData rowData;
+            while ((rowData = probeIter.next()) != null) {
+                Integer key = rowData.getInt(0);
+                spilledPartitionProbeSideKeys.put(
+                        key, spilledPartitionProbeSideKeys.getOrDefault(key, 0) + 1);
+            }
+        }
+
+        // assert that the spilled partitions contain the keys repeatedValue1 and repeatedValue2
+        Integer buildKeyCnt = repeatedValueCount + buildValsPerKey;
+        assertThat(spilledPartitionBuildSideKeys).containsEntry(repeatedValue1, buildKeyCnt);
+        assertThat(spilledPartitionBuildSideKeys).containsEntry(repeatedValue2, buildKeyCnt);
+
+        Integer probeKeyCnt = repeatedValueCount + probeValsPerKey;
+        assertThat(spilledPartitionProbeSideKeys).containsEntry(repeatedValue1, probeKeyCnt);
+        assertThat(spilledPartitionProbeSideKeys).containsEntry(repeatedValue2, probeKeyCnt);
 
         table.close();
 
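
The containsEntry expectations in both hash table tests are plain counting: each colliding key receives rows from the uniform generator and from its own constant iterator. With hypothetical constants (the real repeatedValueCount, buildValsPerKey and probeValsPerKey are defined earlier in each test class and are not shown in this diff):

    // Hypothetical values for illustration only.
    int repeatedValueCount = 200_000;
    int buildValsPerKey = 3;
    int probeValsPerKey = 10;

    // a colliding key appears once per value in the uniform input AND
    // repeatedValueCount more times in its constant-key input
    int buildKeyCnt = repeatedValueCount + buildValsPerKey; // 200_003
    int probeKeyCnt = repeatedValueCount + probeValsPerKey; // 200_010
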
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2AdaptiveHashJoinOperatorTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2AdaptiveHashJoinOperatorTest.java
new file mode 100644
index 00000000000..5ca980c7ca9
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2AdaptiveHashJoinOperatorTest.java
@@ -0,0 +1,450 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.join;
+
+import org.apache.flink.runtime.operators.testutils.UnionIterator;
+import org.apache.flink.table.data.binary.BinaryRowData;
+import org.apache.flink.table.runtime.hashtable.BinaryHashTableTest;
+import org.apache.flink.table.runtime.util.UniformBinaryRowGenerator;
+import org.apache.flink.util.MutableObjectIterator;
+
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/** Tests for the adaptive {@link HashJoinOperator} that falls back to sort-merge join. */
+public class Int2AdaptiveHashJoinOperatorTest extends Int2HashJoinOperatorTestBase {
+
+    // ---------------------- build first inner join -----------------------------------------
+    // ------------------- fallback to sort merge join in build or probe phase ---------------
+    @Test
+    public void testBuildFirstHashInnerJoinFallbackToSMJ() throws Exception {
+        // the following two values are known to have a hash-code collision on the first recursion
+        // level. we use them to make sure one partition grows over-proportionally large
+        final int repeatedValue1 = 405590;
+        final int repeatedValue2 = 928820;
+        final int repeatedValueCountBuild = 100000;
+
+        final int numKeys1 = 100000;
+        final int numKeys2 = 160000;
+        final int buildValsPerKey = 3;
+        final int probeValsPerKey = 2;
+
+        // create a build input that gives 300k pairs (100k keys with 3 values each), plus
+        // 200k pairs across two colliding keys (100k per key)
+        MutableObjectIterator<BinaryRowData> build1 =
+                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, false);
+        MutableObjectIterator<BinaryRowData> build2 =
+                new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+                        repeatedValue1, 17, repeatedValueCountBuild);
+        MutableObjectIterator<BinaryRowData> build3 =
+                new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+                        repeatedValue2, 23, repeatedValueCountBuild);
+        List<MutableObjectIterator<BinaryRowData>> builds = new ArrayList<>();
+        builds.add(build1);
+        builds.add(build2);
+        builds.add(build3);
+        MutableObjectIterator<BinaryRowData> buildInput = new UnionIterator<>(builds);
+
+        MutableObjectIterator<BinaryRowData> probeInput =
+                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+        // inner join: only build rows whose keys also appear on the probe side produce output
+        int expectOutSize = numKeys1 * buildValsPerKey * probeValsPerKey;
+        buildJoin(buildInput, probeInput, false, false, true, expectOutSize, numKeys1, -1);
+    }
+
+    // ---------------------- build first left out join -----------------------------------------
+    // -------------------- fallback to sort merge join in build or probe phase -----------------
+    @Test
+    public void testBuildFirstHashLeftOutJoinFallbackToSMJ() throws Exception {
+        // the following two values are known to have a hash-code collision on the first recursion
+        // level. we use them to make sure one partition grows over-proportionally large
+        final int repeatedValue1 = 40559;
+        final int repeatedValue2 = 92882;
+        final int repeatedValueCountBuild = 100000;
+
+        final int numKeys1 = 100000;
+        final int numKeys2 = 50000;
+        final int buildValsPerKey = 3;
+        final int probeValsPerKey = 2;
+
+        // create a build input that gives 300k pairs (100k keys with 3 values each), plus
+        // 200k pairs across two colliding keys (100k per key)
+        MutableObjectIterator<BinaryRowData> build1 =
+                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, false);
+        MutableObjectIterator<BinaryRowData> build2 =
+                new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+                        repeatedValue1, 17, repeatedValueCountBuild);
+        MutableObjectIterator<BinaryRowData> build3 =
+                new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+                        repeatedValue2, 23, repeatedValueCountBuild);
+        List<MutableObjectIterator<BinaryRowData>> builds = new ArrayList<>();
+        builds.add(build1);
+        builds.add(build2);
+        builds.add(build3);
+        MutableObjectIterator<BinaryRowData> buildInput = new UnionIterator<>(builds);
+
+        MutableObjectIterator<BinaryRowData> probeInput =
+                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+        // left outer join: unmatched build-side (left) rows are also output, padded with nulls
+        int expectOutSize =
+                (numKeys1 - numKeys2) * buildValsPerKey * probeValsPerKey
+                        + (numKeys1 - numKeys2) * buildValsPerKey
+                        + repeatedValueCountBuild * probeValsPerKey
+                        + repeatedValueCountBuild;
+        buildJoin(buildInput, probeInput, true, false, true, expectOutSize, numKeys1, -1);
+    }
+
+    // ---------------------- build first right out join -----------------------------------------
+    // --------------------- fallback to sort merge join in build or probe phase -----------------
+    @Test
+    public void testBuildFirstHashRightOutJoinFallbackToSMJ() throws Exception {
+        // the following two values are known to have a hash-code collision on the first recursion
+        // level. we use them to make sure one partition grows over-proportionally large
+        final int repeatedValue1 = 40559;
+        final int repeatedValue2 = 92882;
+        final int repeatedValueCountBuild = 100000;
+
+        final int numKeys1 = 100000;
+        final int numKeys2 = 50000;
+        final int buildValsPerKey = 3;
+        final int probeValsPerKey = 1;
+
+        // create a build input that gives 300k pairs (100k keys with 3 values each), plus
+        // 200k pairs across two colliding keys (100k per key)
+        MutableObjectIterator<BinaryRowData> build1 =
+                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, false);
+        MutableObjectIterator<BinaryRowData> build2 =
+                new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+                        repeatedValue1, 17, repeatedValueCountBuild);
+        MutableObjectIterator<BinaryRowData> build3 =
+                new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+                        repeatedValue2, 23, repeatedValueCountBuild);
+        List<MutableObjectIterator<BinaryRowData>> builds = new ArrayList<>();
+        builds.add(build1);
+        builds.add(build2);
+        builds.add(build3);
+        MutableObjectIterator<BinaryRowData> buildInput = new UnionIterator<>(builds);
+
+        MutableObjectIterator<BinaryRowData> probeInput =
+                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+        // right outer join: every probe-side (right) row is output
+        int expectOutSize = numKeys2 * probeValsPerKey * buildValsPerKey + repeatedValueCountBuild;
+        buildJoin(buildInput, probeInput, false, true, true, expectOutSize, numKeys2, -1);
+    }
+
+    // ---------------------- build first full out join -----------------------------------------
+    // --------------------- fallback to sort merge join in build or probe phase ----------------
+    @Test
+    public void testBuildFirstHashFullOutJoinFallbackToSMJ() throws Exception {
+        // the following two values are known to have a hash-code collision on the first recursion
+        // level. we use them to make sure one partition grows over-proportionally large
+        final int repeatedValue1 = 40559;
+        final int repeatedValue2 = 92882;
+        final int repeatedValueCountBuild = 100000;
+
+        final int numKeys1 = 100000;
+        final int numKeys2 = 150000;
+        final int buildValsPerKey = 3;
+        final int probeValsPerKey = 1;
+
+        // create a build input that gives 300k pairs (100k keys with 3 values each), plus
+        // 200k pairs across two colliding keys (100k per key)
+        MutableObjectIterator<BinaryRowData> build1 =
+                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, false);
+        MutableObjectIterator<BinaryRowData> build2 =
+                new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+                        repeatedValue1, 17, repeatedValueCountBuild);
+        MutableObjectIterator<BinaryRowData> build3 =
+                new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+                        repeatedValue2, 23, repeatedValueCountBuild);
+        List<MutableObjectIterator<BinaryRowData>> builds = new ArrayList<>();
+        builds.add(build1);
+        builds.add(build2);
+        builds.add(build3);
+        MutableObjectIterator<BinaryRowData> buildInput = new UnionIterator<>(builds);
+
+        MutableObjectIterator<BinaryRowData> probeInput =
+                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+        // full outer join: matched pairs plus unmatched rows from both sides are output
+        int expectOutSize =
+                numKeys1 * buildValsPerKey * probeValsPerKey
+                        + repeatedValueCountBuild * 2
+                        + numKeys2
+                        - numKeys1;
+        buildJoin(buildInput, probeInput, true, true, true, expectOutSize, numKeys2, -1);
+    }
+
+    // ---------------------- build second left out join -----------------------------------------
+    // ---------------------- fallback to sort merge join in build or probe phase ----------------
+    @Test
+    public void testBuildSecondHashLeftOutJoinFallbackToSMJ() throws Exception {
+        // the following two values are known to have a hash-code collision on the first recursion
+        // level. we use them to make sure one partition grows over-proportionally large
+        final int repeatedValue1 = 40559;
+        final int repeatedValue2 = 92882;
+        final int repeatedValueCountBuild = 100000;
+
+        final int numKeys1 = 100000;
+        final int numKeys2 = 50000;
+        final int buildValsPerKey = 3;
+        final int probeValsPerKey = 1;
+
+        // create a build input that gives 300k pairs (100k keys with 3 values each), plus
+        // 200k pairs across two colliding keys (100k per key)
+        MutableObjectIterator<BinaryRowData> build1 =
+                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, false);
+        MutableObjectIterator<BinaryRowData> build2 =
+                new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+                        repeatedValue1, 17, repeatedValueCountBuild);
+        MutableObjectIterator<BinaryRowData> build3 =
+                new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+                        repeatedValue2, 23, repeatedValueCountBuild);
+        List<MutableObjectIterator<BinaryRowData>> builds = new ArrayList<>();
+        builds.add(build1);
+        builds.add(build2);
+        builds.add(build3);
+        MutableObjectIterator<BinaryRowData> buildInput = new UnionIterator<>(builds);
+
+        MutableObjectIterator<BinaryRowData> probeInput =
+                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+        // left outer join with build on the right: every probe-side (left) row is output
+        int expectOutSize = numKeys2 * probeValsPerKey * buildValsPerKey + repeatedValueCountBuild;
+        buildJoin(buildInput, probeInput, true, false, false, expectOutSize, numKeys2, -1);
+    }
+
+    // ---------------------- build second right out join -----------------------------------------
+    // ---------------------- fallback to sort merge join in build or probe phase -----------------
+    @Test
+    public void testBuildSecondHashRightOutJoinFallbackToSMJ() throws Exception {
+        // the following two values are known to have a hash-code collision on the first recursion
+        // level. we use them to make sure one partition grows over-proportionally large
+        final int repeatedValue1 = 40559;
+        final int repeatedValue2 = 92882;
+        final int repeatedValueCountBuild = 100000;
+
+        final int numKeys1 = 100000;
+        final int numKeys2 = 50000;
+        final int buildValsPerKey = 3;
+        final int probeValsPerKey = 1;
+
+        // create a build input that gives 300k pairs (100k keys with 3 values each), plus
+        // 200k pairs across two colliding keys (100k per key)
+        MutableObjectIterator<BinaryRowData> build1 =
+                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, false);
+        MutableObjectIterator<BinaryRowData> build2 =
+                new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+                        repeatedValue1, 17, repeatedValueCountBuild);
+        MutableObjectIterator<BinaryRowData> build3 =
+                new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+                        repeatedValue2, 23, repeatedValueCountBuild);
+        List<MutableObjectIterator<BinaryRowData>> builds = new ArrayList<>();
+        builds.add(build1);
+        builds.add(build2);
+        builds.add(build3);
+        MutableObjectIterator<BinaryRowData> buildInput = new UnionIterator<>(builds);
+
+        MutableObjectIterator<BinaryRowData> probeInput =
+                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+        // right outer join with build on the right: every build-side (right) row is output
+        int expectOutSize = numKeys1 * buildValsPerKey + repeatedValueCountBuild * 2;
+        buildJoin(buildInput, probeInput, false, true, false, expectOutSize, numKeys1, -1);
+    }
+
+    // ---------------------- fallback to sort merge join in build or probe phase ----------------
+    @Test
+    public void testSemiJoinFallbackToSMJ() throws Exception {
+        // the following two values are known to have a hash-code collision on the first recursion
+        // level. we use them to make sure one partition grows over-proportionally large
+        final int repeatedValue1 = 40559;
+        final int repeatedValue2 = 92882;
+        final int repeatedValueCountBuild = 100000;
+
+        final int numKeys1 = 100000;
+        final int numKeys2 = 100000;
+        final int buildValsPerKey = 3;
+        final int probeValsPerKey = 1;
+
+        // create a build input that gives 300k pairs (100k keys with 3 values each), plus
+        // 200k pairs across two colliding keys (100k per key)
+        MutableObjectIterator<BinaryRowData> build1 =
+                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, false);
+        MutableObjectIterator<BinaryRowData> build2 =
+                new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+                        repeatedValue1, 17, repeatedValueCountBuild);
+        MutableObjectIterator<BinaryRowData> build3 =
+                new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+                        repeatedValue2, 23, repeatedValueCountBuild);
+        List<MutableObjectIterator<BinaryRowData>> builds = new ArrayList<>();
+        builds.add(build1);
+        builds.add(build2);
+        builds.add(build3);
+        MutableObjectIterator<BinaryRowData> buildInput = new UnionIterator<>(builds);
+
+        MutableObjectIterator<BinaryRowData> probeInput =
+                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+        Object operator =
+                newOperator(33 * 32 * 1024, FlinkJoinType.SEMI, HashJoinType.SEMI, false, false);
+
+        // semi join: output probe-side (left) rows whose keys exist on the build side
+        joinAndAssert(operator, buildInput, probeInput, numKeys2, numKeys2, 0, true);
+    }
+
+    // ---------------------- fallback to sort merge join in build or probe phase ------------------
+    @Test
+    public void testAntiJoinFallbackToSMJ() throws Exception {
+        // the following two values are known to have a hash-code collision on the first recursion
+        // level. we use them to make sure one partition grows over-proportionally large
+        final int repeatedValue1 = 40559;
+        final int repeatedValue2 = 92882;
+        final int repeatedValueCountBuild = 100000;
+
+        final int numKeys1 = 100000;
+        final int numKeys2 = 160000;
+        final int buildValsPerKey = 3;
+        final int probeValsPerKey = 1;
+
+        // create a build input that gives 300k pairs (100k keys with 3 values each), plus
+        // 200k pairs across two colliding keys (100k per key)
+        MutableObjectIterator<BinaryRowData> build1 =
+                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, false);
+        MutableObjectIterator<BinaryRowData> build2 =
+                new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+                        repeatedValue1, 17, repeatedValueCountBuild);
+        MutableObjectIterator<BinaryRowData> build3 =
+                new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+                        repeatedValue2, 23, repeatedValueCountBuild);
+        List<MutableObjectIterator<BinaryRowData>> builds = new ArrayList<>();
+        builds.add(build1);
+        builds.add(build2);
+        builds.add(build3);
+        MutableObjectIterator<BinaryRowData> buildInput = new UnionIterator<>(builds);
+
+        MutableObjectIterator<BinaryRowData> probeInput =
+                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+        Object operator =
+                newOperator(33 * 32 * 1024, FlinkJoinType.ANTI, HashJoinType.ANTI, false, false);
+
+        // anti join: output probe-side (left) rows with no match on the build side
+        int expectOutSize = numKeys2 - numKeys1;
+        joinAndAssert(operator, buildInput, probeInput, expectOutSize, expectOutSize, 0, true);
+    }
+
+    // ---------------------- fallback to sort merge join in build or probe phase ------------------
+    @Test
+    public void testBuildLeftSemiJoinFallbackToSMJ() throws Exception {
+        // the following two values are known to have a hash-code collision on the first recursion
+        // level. we use them to make sure one partition grows over-proportionally large
+        final int repeatedValue1 = 405;
+        final int repeatedValue2 = 928;
+        final int repeatedValueCountBuild = 100000;
+
+        final int numKeys1 = 100000;
+        final int numKeys2 = 1000;
+        final int buildValsPerKey = 1;
+        final int probeValsPerKey = 3;
+
+        // create a build input that gives 100k pairs (100k keys with 1 value each), plus
+        // 200k pairs across two colliding keys (100k per key)
+        MutableObjectIterator<BinaryRowData> build1 =
+                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, false);
+        MutableObjectIterator<BinaryRowData> build2 =
+                new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+                        repeatedValue1, 17, repeatedValueCountBuild);
+        MutableObjectIterator<BinaryRowData> build3 =
+                new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+                        repeatedValue2, 23, repeatedValueCountBuild);
+        List<MutableObjectIterator<BinaryRowData>> builds = new ArrayList<>();
+        builds.add(build1);
+        builds.add(build2);
+        builds.add(build3);
+        MutableObjectIterator<BinaryRowData> buildInput = new UnionIterator<>(builds);
+
+        MutableObjectIterator<BinaryRowData> probeInput =
+                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+        Object operator =
+                newOperator(
+                        33 * 32 * 1024,
+                        FlinkJoinType.SEMI,
+                        HashJoinType.BUILD_LEFT_SEMI,
+                        true,
+                        false);
+
+        // build-left semi join: output build-side (left) rows that match the probe side
+        int expectOutSize = numKeys2 + repeatedValueCountBuild * 2;
+        joinAndAssert(operator, buildInput, probeInput, expectOutSize, numKeys1, -1, true);
+    }
+
+    // ---------------------- fallback to sort merge join in build or probe phase ------------------
+    @Test
+    public void testBuildLeftAntiJoinFallbackToSMJ() throws Exception {
+        // the following two values are known to have a hash-code collision on the first recursion
+        // level. we use them to make sure one partition grows over-proportionally large
+        final int repeatedValue1 = 40559;
+        final int repeatedValue2 = 92882;
+        final int repeatedValueCountBuild = 100000;
+
+        final int numKeys1 = 500000;
+        final int numKeys2 = 100000;
+        final int buildValsPerKey = 1;
+        final int probeValsPerKey = 3;
+
+        // create a build input that gives 500k pairs (500k keys with 1 value each), plus
+        // 200k pairs across two colliding keys (100k per key)
+        MutableObjectIterator<BinaryRowData> build1 =
+                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, false);
+        MutableObjectIterator<BinaryRowData> build2 =
+                new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+                        repeatedValue1, 17, repeatedValueCountBuild);
+        MutableObjectIterator<BinaryRowData> build3 =
+                new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
+                        repeatedValue2, 23, repeatedValueCountBuild);
+        List<MutableObjectIterator<BinaryRowData>> builds = new ArrayList<>();
+        builds.add(build1);
+        builds.add(build2);
+        builds.add(build3);
+        MutableObjectIterator<BinaryRowData> buildInput = new UnionIterator<>(builds);
+
+        MutableObjectIterator<BinaryRowData> probeInput =
+                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
+
+        Object operator =
+                newOperator(
+                        33 * 32 * 1024,
+                        FlinkJoinType.ANTI,
+                        HashJoinType.BUILD_LEFT_ANTI,
+                        true,
+                        false);
+
+        // build-left anti join: output build-side (left) rows with no probe-side match
+        int expectOutSize = numKeys1 - numKeys2;
+        joinAndAssert(
+                operator, buildInput, probeInput, expectOutSize, numKeys1 - numKeys2, -1, true);
+    }
+}
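
All expected sizes in this new test class are derived the same way. Taking testBuildFirstHashFullOutJoinFallbackToSMJ as a worked example (constants copied from that test):

    // numbers copied from testBuildFirstHashFullOutJoinFallbackToSMJ
    int numKeys1 = 100_000, numKeys2 = 150_000;
    int buildValsPerKey = 3, probeValsPerKey = 1;
    int repeatedValueCountBuild = 100_000;

    // every build key also exists on the probe side, so each uniform build
    // row joins the single probe row for its key
    int matched = numKeys1 * buildValsPerKey * probeValsPerKey;   // 300_000
    // the two colliding keys each add repeatedValueCountBuild build rows,
    // and each of those also joins that key's single probe row
    int collidingMatches = repeatedValueCountBuild * 2;           // 200_000
    // probe-only keys in [numKeys1, numKeys2) are emitted null-padded
    int probeOnly = numKeys2 - numKeys1;                          //  50_000
    int expectOutSize = matched + collidingMatches + probeOnly;   // 550_000
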
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTest.java
index af589b95954..dd4da788526 100644
--- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTest.java
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTest.java
@@ -18,42 +18,14 @@
 
 package org.apache.flink.table.runtime.operators.join;
 
-import org.apache.flink.api.common.functions.AbstractRichFunction;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.core.memory.ManagedMemoryUseCase;
-import org.apache.flink.runtime.jobgraph.OperatorID;
-import org.apache.flink.streaming.api.operators.StreamOperator;
-import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
-import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.apache.flink.streaming.runtime.tasks.TwoInputStreamTask;
-import org.apache.flink.streaming.runtime.tasks.TwoInputStreamTaskTestHarness;
-import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.data.binary.BinaryRowData;
-import org.apache.flink.table.data.utils.JoinedRowData;
-import org.apache.flink.table.data.writer.BinaryRowWriter;
-import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
-import org.apache.flink.table.runtime.generated.GeneratedProjection;
-import org.apache.flink.table.runtime.generated.JoinCondition;
-import org.apache.flink.table.runtime.generated.Projection;
-import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
 import org.apache.flink.table.runtime.util.UniformBinaryRowGenerator;
-import org.apache.flink.table.types.logical.IntType;
-import org.apache.flink.table.types.logical.RowType;
 import org.apache.flink.util.MutableObjectIterator;
 
 import org.junit.Test;
 
-import java.io.Serializable;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Queue;
-import java.util.Random;
-
-import static java.lang.Long.valueOf;
-import static org.assertj.core.api.Assertions.assertThat;
-
 /** Random test for {@link HashJoinOperator}. */
-public class Int2HashJoinOperatorTest implements Serializable {
+public class Int2HashJoinOperatorTest extends Int2HashJoinOperatorTestBase {
 
     // ---------------------- build first inner join -----------------------------------------
     @Test
@@ -233,8 +205,8 @@ public class Int2HashJoinOperatorTest implements Serializable {
         MutableObjectIterator<BinaryRowData> probeInput =
                 new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
 
-        HashJoinType type = HashJoinType.SEMI;
-        Object operator = newOperator(33 * 32 * 1024, type, false);
+        Object operator =
+                newOperator(33 * 32 * 1024, FlinkJoinType.SEMI, HashJoinType.SEMI, false, false);
         joinAndAssert(operator, buildInput, probeInput, 90, 9, 45, true);
     }
 
@@ -250,14 +222,13 @@ public class Int2HashJoinOperatorTest implements Serializable {
         MutableObjectIterator<BinaryRowData> probeInput =
                 new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
 
-        HashJoinType type = HashJoinType.ANTI;
-        Object operator = newOperator(33 * 32 * 1024, type, false);
+        Object operator =
+                newOperator(33 * 32 * 1024, FlinkJoinType.ANTI, HashJoinType.ANTI, false, false);
         joinAndAssert(operator, buildInput, probeInput, 10, 1, 45, true);
     }
 
     @Test
     public void testBuildLeftSemiJoin() throws Exception {
-
         int numKeys1 = 10;
         int numKeys2 = 9;
         int buildValsPerKey = 10;
@@ -267,14 +238,18 @@ public class Int2HashJoinOperatorTest implements Serializable {
         MutableObjectIterator<BinaryRowData> probeInput =
                 new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
 
-        HashJoinType type = HashJoinType.BUILD_LEFT_SEMI;
-        Object operator = newOperator(33 * 32 * 1024, type, false);
+        Object operator =
+                newOperator(
+                        33 * 32 * 1024,
+                        FlinkJoinType.SEMI,
+                        HashJoinType.BUILD_LEFT_SEMI,
+                        true,
+                        false);
         joinAndAssert(operator, buildInput, probeInput, 90, 9, 45, true);
     }
 
     @Test
     public void testBuildLeftAntiJoin() throws Exception {
-
         int numKeys1 = 10;
         int numKeys2 = 9;
         int buildValsPerKey = 10;
@@ -284,237 +259,13 @@ public class Int2HashJoinOperatorTest implements Serializable {
         MutableObjectIterator<BinaryRowData> probeInput =
                 new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
 
-        HashJoinType type = HashJoinType.BUILD_LEFT_ANTI;
-        Object operator = newOperator(33 * 32 * 1024, type, false);
+        Object operator =
+                newOperator(
+                        33 * 32 * 1024,
+                        FlinkJoinType.ANTI,
+                        HashJoinType.BUILD_LEFT_ANTI,
+                        true,
+                        false);
         joinAndAssert(operator, buildInput, probeInput, 10, 1, 45, true);
     }
-
-    private void buildJoin(
-            MutableObjectIterator<BinaryRowData> buildInput,
-            MutableObjectIterator<BinaryRowData> probeInput,
-            boolean leftOut,
-            boolean rightOut,
-            boolean buildLeft,
-            int expectOutSize,
-            int expectOutKeySize,
-            int expectOutVal)
-            throws Exception {
-        HashJoinType type = HashJoinType.of(buildLeft, leftOut, rightOut);
-        Object operator = newOperator(33 * 32 * 1024, type, !buildLeft);
-        joinAndAssert(
-                operator,
-                buildInput,
-                probeInput,
-                expectOutSize,
-                expectOutKeySize,
-                expectOutVal,
-                false);
-    }
-
-    @SuppressWarnings("unchecked")
-    static void joinAndAssert(
-            Object operator,
-            MutableObjectIterator<BinaryRowData> input1,
-            MutableObjectIterator<BinaryRowData> input2,
-            int expectOutSize,
-            int expectOutKeySize,
-            int expectOutVal,
-            boolean semiJoin)
-            throws Exception {
-        InternalTypeInfo<RowData> typeInfo =
-                InternalTypeInfo.ofFields(new IntType(), new IntType());
-        InternalTypeInfo<RowData> rowDataTypeInfo =
-                InternalTypeInfo.ofFields(
-                        new IntType(), new IntType(), new IntType(), new IntType());
-        TwoInputStreamTaskTestHarness<BinaryRowData, BinaryRowData, JoinedRowData> testHarness =
-                new TwoInputStreamTaskTestHarness<>(
-                        TwoInputStreamTask::new,
-                        2,
-                        1,
-                        new int[] {1, 2},
-                        typeInfo,
-                        (TypeInformation) typeInfo,
-                        rowDataTypeInfo);
-        testHarness.memorySize = 36 * 1024 * 1024;
-        testHarness.getExecutionConfig().enableObjectReuse();
-        testHarness.setupOutputForSingletonOperatorChain();
-        if (operator instanceof StreamOperator) {
-            testHarness.getStreamConfig().setStreamOperator((StreamOperator<?>) operator);
-        } else {
-            testHarness
-                    .getStreamConfig()
-                    .setStreamOperatorFactory((StreamOperatorFactory<?>) operator);
-        }
-        testHarness.getStreamConfig().setOperatorID(new OperatorID());
-        testHarness
-                .getStreamConfig()
-                .setManagedMemoryFractionOperatorOfUseCase(ManagedMemoryUseCase.OPERATOR, 0.99);
-
-        testHarness.invoke();
-        testHarness.waitForTaskRunning();
-
-        Random random = new Random();
-        do {
-            BinaryRowData row1 = null;
-            BinaryRowData row2 = null;
-
-            if (random.nextInt(2) == 0) {
-                row1 = input1.next();
-                if (row1 == null) {
-                    row2 = input2.next();
-                }
-            } else {
-                row2 = input2.next();
-                if (row2 == null) {
-                    row1 = input1.next();
-                }
-            }
-
-            if (row1 == null && row2 == null) {
-                break;
-            }
-
-            if (row1 != null) {
-                testHarness.processElement(new StreamRecord<>(row1), 0, 0);
-            } else {
-                testHarness.processElement(new StreamRecord<>(row2), 1, 0);
-            }
-        } while (true);
-
-        testHarness.endInput(0, 0);
-        testHarness.endInput(1, 0);
-
-        testHarness.waitForInputProcessing();
-        testHarness.waitForTaskCompletion();
-
-        Queue<Object> actual = testHarness.getOutput();
-
-        assertThat(actual).as("Output was not correct.").hasSize(expectOutSize);
-
-        // Don't verify the output value when experOutVal is -1
-        if (expectOutVal != -1) {
-            if (semiJoin) {
-                HashMap<Integer, Long> map = new HashMap<>(expectOutKeySize);
-
-                for (Object o : actual) {
-                    StreamRecord<RowData> record = (StreamRecord<RowData>) o;
-                    RowData row = record.getValue();
-                    int key = row.getInt(0);
-                    int val = row.getInt(1);
-                    Long contained = map.get(key);
-                    if (contained == null) {
-                        contained = (long) val;
-                    } else {
-                        contained = valueOf(contained + val);
-                    }
-                    map.put(key, contained);
-                }
-
-                assertThat(map).as("Wrong number of keys").hasSize(expectOutKeySize);
-                for (Map.Entry<Integer, Long> entry : map.entrySet()) {
-                    long val = entry.getValue();
-                    int key = entry.getKey();
-
-                    assertThat(val)
-                            .as("Wrong number of values in per-key cross product for key " + key)
-                            .isEqualTo(expectOutVal);
-                }
-            } else {
-                // create the map for validating the results
-                HashMap<Integer, Long> map = new HashMap<>(expectOutKeySize);
-
-                for (Object o : actual) {
-                    StreamRecord<RowData> record = (StreamRecord<RowData>) o;
-                    RowData row = record.getValue();
-                    int key = row.isNullAt(0) ? row.getInt(2) : row.getInt(0);
-
-                    int val1 = 0;
-                    int val2 = 0;
-                    if (!row.isNullAt(1)) {
-                        val1 = row.getInt(1);
-                    }
-                    if (!row.isNullAt(3)) {
-                        val2 = row.getInt(3);
-                    }
-                    int val = val1 + val2;
-
-                    Long contained = map.get(key);
-                    if (contained == null) {
-                        contained = (long) val;
-                    } else {
-                        contained = valueOf(contained + val);
-                    }
-                    map.put(key, contained);
-                }
-
-                assertThat(map).as("Wrong number of keys").hasSize(expectOutKeySize);
-                for (Map.Entry<Integer, Long> entry : map.entrySet()) {
-                    long val = entry.getValue();
-                    int key = entry.getKey();
-
-                    assertThat(val)
-                            .as("Wrong number of values in per-key cross product for key " + key)
-                            .isEqualTo(expectOutVal);
-                }
-            }
-        }
-    }
-
-    /** my projection. */
-    public static final class MyProjection implements Projection<RowData, BinaryRowData> {
-
-        BinaryRowData innerRow = new BinaryRowData(1);
-        BinaryRowWriter writer = new BinaryRowWriter(innerRow);
-
-        @Override
-        public BinaryRowData apply(RowData row) {
-            writer.reset();
-            if (row.isNullAt(0)) {
-                writer.setNullAt(0);
-            } else {
-                writer.writeInt(0, row.getInt(0));
-            }
-            writer.complete();
-            return innerRow;
-        }
-    }
-
-    public Object newOperator(long memorySize, HashJoinType type, boolean reverseJoinFunction) {
-        return HashJoinOperator.newHashJoinOperator(
-                type,
-                new GeneratedJoinCondition("", "", new Object[0]) {
-                    @Override
-                    public JoinCondition newInstance(ClassLoader classLoader) {
-                        return new TrueCondition();
-                    }
-                },
-                reverseJoinFunction,
-                new boolean[] {true},
-                new GeneratedProjection("", "", new Object[0]) {
-                    @Override
-                    public Projection newInstance(ClassLoader classLoader) {
-                        return new MyProjection();
-                    }
-                },
-                new GeneratedProjection("", "", new Object[0]) {
-                    @Override
-                    public Projection newInstance(ClassLoader classLoader) {
-                        return new MyProjection();
-                    }
-                },
-                false,
-                20,
-                10000,
-                10000,
-                RowType.of(new IntType()));
-    }
-
-    /** Test util. */
-    public static class TrueCondition extends AbstractRichFunction implements JoinCondition {
-
-        @Override
-        public boolean apply(RowData in1, RowData in2) {
-            return true;
-        }
-    }
 }
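
After the refactoring every test constructs its operator through the base class, spelling out both join types. The pattern, as used in testSemiJoin above (newOperator itself lives in Int2HashJoinOperatorTestBase, shown in the next diff):

    // Pattern used by the updated tests: both the logical FlinkJoinType
    // (which drives the sort-merge fallback) and the physical HashJoinType
    // are passed explicitly, together with the build side and whether the
    // join function's inputs are reversed.
    Object operator =
            newOperator(
                    33 * 32 * 1024,        // memory size for the join
                    FlinkJoinType.SEMI,    // logical type, used by the fallback
                    HashJoinType.SEMI,     // physical hash join type
                    false,                 // buildLeft
                    false);                // reverseJoinFunction
    joinAndAssert(operator, buildInput, probeInput, 90, 9, 45, true);
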
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTestBase.java
similarity index 52%
copy from flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTest.java
copy to flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTestBase.java
index af589b95954..0988d272dfc 100644
--- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTest.java
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2HashJoinOperatorTestBase.java
@@ -32,17 +32,20 @@ import org.apache.flink.table.data.binary.BinaryRowData;
 import org.apache.flink.table.data.utils.JoinedRowData;
 import org.apache.flink.table.data.writer.BinaryRowWriter;
 import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
+import org.apache.flink.table.runtime.generated.GeneratedNormalizedKeyComputer;
 import org.apache.flink.table.runtime.generated.GeneratedProjection;
+import org.apache.flink.table.runtime.generated.GeneratedRecordComparator;
 import org.apache.flink.table.runtime.generated.JoinCondition;
+import org.apache.flink.table.runtime.generated.NormalizedKeyComputer;
 import org.apache.flink.table.runtime.generated.Projection;
+import org.apache.flink.table.runtime.generated.RecordComparator;
+import org.apache.flink.table.runtime.operators.sort.IntNormalizedKeyComputer;
+import org.apache.flink.table.runtime.operators.sort.IntRecordComparator;
 import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
-import org.apache.flink.table.runtime.util.UniformBinaryRowGenerator;
 import org.apache.flink.table.types.logical.IntType;
 import org.apache.flink.table.types.logical.RowType;
 import org.apache.flink.util.MutableObjectIterator;
 
-import org.junit.Test;
-
 import java.io.Serializable;
 import java.util.HashMap;
 import java.util.Map;
@@ -50,246 +53,13 @@ import java.util.Queue;
 import java.util.Random;
 
 import static java.lang.Long.valueOf;
+import static org.apache.flink.table.runtime.util.JoinUtil.getJoinType;
 import static org.assertj.core.api.Assertions.assertThat;
 
-/** Random test for {@link HashJoinOperator}. */
-public class Int2HashJoinOperatorTest implements Serializable {
-
-    // ---------------------- build first inner join -----------------------------------------
-    @Test
-    public void testBuildFirstHashInnerJoin() throws Exception {
-
-        int numKeys = 100;
-        int buildValsPerKey = 3;
-        int probeValsPerKey = 10;
-        MutableObjectIterator<BinaryRowData> buildInput =
-                new UniformBinaryRowGenerator(numKeys, buildValsPerKey, false);
-        MutableObjectIterator<BinaryRowData> probeInput =
-                new UniformBinaryRowGenerator(numKeys, probeValsPerKey, true);
-
-        buildJoin(
-                buildInput,
-                probeInput,
-                false,
-                false,
-                true,
-                numKeys * buildValsPerKey * probeValsPerKey,
-                numKeys,
-                165);
-    }
-
-    // ---------------------- build first left out join -----------------------------------------
-    @Test
-    public void testBuildFirstHashLeftOutJoin() throws Exception {
-
-        int numKeys1 = 9;
-        int numKeys2 = 10;
-        int buildValsPerKey = 3;
-        int probeValsPerKey = 10;
-        MutableObjectIterator<BinaryRowData> buildInput =
-                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
-        MutableObjectIterator<BinaryRowData> probeInput =
-                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
-
-        buildJoin(
-                buildInput,
-                probeInput,
-                true,
-                false,
-                true,
-                numKeys1 * buildValsPerKey * probeValsPerKey,
-                numKeys1,
-                165);
-    }
-
-    // ---------------------- build first right out join -----------------------------------------
-    @Test
-    public void testBuildFirstHashRightOutJoin() throws Exception {
-
-        int numKeys1 = 9;
-        int numKeys2 = 10;
-        int buildValsPerKey = 3;
-        int probeValsPerKey = 10;
-        MutableObjectIterator<BinaryRowData> buildInput =
-                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
-        MutableObjectIterator<BinaryRowData> probeInput =
-                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
-
-        buildJoin(buildInput, probeInput, false, true, true, 280, numKeys2, -1);
-    }
-
-    // ---------------------- build first full out join -----------------------------------------
-    @Test
-    public void testBuildFirstHashFullOutJoin() throws Exception {
-
-        int numKeys1 = 9;
-        int numKeys2 = 10;
-        int buildValsPerKey = 3;
-        int probeValsPerKey = 10;
-        MutableObjectIterator<BinaryRowData> buildInput =
-                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
-        MutableObjectIterator<BinaryRowData> probeInput =
-                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
-
-        buildJoin(buildInput, probeInput, true, true, true, 280, numKeys2, -1);
-    }
-
-    // ---------------------- build second inner join -----------------------------------------
-    @Test
-    public void testBuildSecondHashInnerJoin() throws Exception {
-
-        int numKeys = 100;
-        int buildValsPerKey = 3;
-        int probeValsPerKey = 10;
-        MutableObjectIterator<BinaryRowData> buildInput =
-                new UniformBinaryRowGenerator(numKeys, buildValsPerKey, false);
-        MutableObjectIterator<BinaryRowData> probeInput =
-                new UniformBinaryRowGenerator(numKeys, probeValsPerKey, true);
-
-        buildJoin(
-                buildInput,
-                probeInput,
-                false,
-                false,
-                false,
-                numKeys * buildValsPerKey * probeValsPerKey,
-                numKeys,
-                165);
-    }
+/** Base test class for {@link HashJoinOperator}. */
+public abstract class Int2HashJoinOperatorTestBase implements Serializable {
 
-    // ---------------------- build second left out join -----------------------------------------
-    @Test
-    public void testBuildSecondHashLeftOutJoin() throws Exception {
-
-        int numKeys1 = 10;
-        int numKeys2 = 9;
-        int buildValsPerKey = 3;
-        int probeValsPerKey = 10;
-        MutableObjectIterator<BinaryRowData> buildInput =
-                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
-        MutableObjectIterator<BinaryRowData> probeInput =
-                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
-
-        buildJoin(
-                buildInput,
-                probeInput,
-                true,
-                false,
-                false,
-                numKeys2 * buildValsPerKey * probeValsPerKey,
-                numKeys2,
-                165);
-    }
-
-    // ---------------------- build second right out join -----------------------------------------
-    @Test
-    public void testBuildSecondHashRightOutJoin() throws Exception {
-
-        int numKeys1 = 9;
-        int numKeys2 = 10;
-        int buildValsPerKey = 3;
-        int probeValsPerKey = 10;
-        MutableObjectIterator<BinaryRowData> buildInput =
-                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
-        MutableObjectIterator<BinaryRowData> probeInput =
-                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
-
-        buildJoin(
-                buildInput,
-                probeInput,
-                false,
-                true,
-                false,
-                numKeys1 * buildValsPerKey * probeValsPerKey,
-                numKeys2,
-                -1);
-    }
-
-    // ---------------------- build second full out join -----------------------------------------
-    @Test
-    public void testBuildSecondHashFullOutJoin() throws Exception {
-
-        int numKeys1 = 9;
-        int numKeys2 = 10;
-        int buildValsPerKey = 3;
-        int probeValsPerKey = 10;
-        MutableObjectIterator<BinaryRowData> buildInput =
-                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
-        MutableObjectIterator<BinaryRowData> probeInput =
-                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
-
-        buildJoin(buildInput, probeInput, true, true, false, 280, numKeys2, -1);
-    }
-
-    @Test
-    public void testSemiJoin() throws Exception {
-
-        int numKeys1 = 9;
-        int numKeys2 = 10;
-        int buildValsPerKey = 3;
-        int probeValsPerKey = 10;
-        MutableObjectIterator<BinaryRowData> buildInput =
-                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
-        MutableObjectIterator<BinaryRowData> probeInput =
-                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
-
-        HashJoinType type = HashJoinType.SEMI;
-        Object operator = newOperator(33 * 32 * 1024, type, false);
-        joinAndAssert(operator, buildInput, probeInput, 90, 9, 45, true);
-    }
-
-    @Test
-    public void testAntiJoin() throws Exception {
-
-        int numKeys1 = 9;
-        int numKeys2 = 10;
-        int buildValsPerKey = 3;
-        int probeValsPerKey = 10;
-        MutableObjectIterator<BinaryRowData> buildInput =
-                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
-        MutableObjectIterator<BinaryRowData> probeInput =
-                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
-
-        HashJoinType type = HashJoinType.ANTI;
-        Object operator = newOperator(33 * 32 * 1024, type, false);
-        joinAndAssert(operator, buildInput, probeInput, 10, 1, 45, true);
-    }
-
-    @Test
-    public void testBuildLeftSemiJoin() throws Exception {
-
-        int numKeys1 = 10;
-        int numKeys2 = 9;
-        int buildValsPerKey = 10;
-        int probeValsPerKey = 3;
-        MutableObjectIterator<BinaryRowData> buildInput =
-                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
-        MutableObjectIterator<BinaryRowData> probeInput =
-                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
-
-        HashJoinType type = HashJoinType.BUILD_LEFT_SEMI;
-        Object operator = newOperator(33 * 32 * 1024, type, false);
-        joinAndAssert(operator, buildInput, probeInput, 90, 9, 45, true);
-    }
-
-    @Test
-    public void testBuildLeftAntiJoin() throws Exception {
-
-        int numKeys1 = 10;
-        int numKeys2 = 9;
-        int buildValsPerKey = 10;
-        int probeValsPerKey = 3;
-        MutableObjectIterator<BinaryRowData> buildInput =
-                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
-        MutableObjectIterator<BinaryRowData> probeInput =
-                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
-
-        HashJoinType type = HashJoinType.BUILD_LEFT_ANTI;
-        Object operator = newOperator(33 * 32 * 1024, type, false);
-        joinAndAssert(operator, buildInput, probeInput, 10, 1, 45, true);
-    }
-
-    private void buildJoin(
+    public void buildJoin(
             MutableObjectIterator<BinaryRowData> buildInput,
             MutableObjectIterator<BinaryRowData> probeInput,
             boolean leftOut,
@@ -299,8 +69,10 @@ public class Int2HashJoinOperatorTest implements Serializable {
             int expectOutKeySize,
             int expectOutVal)
             throws Exception {
-        HashJoinType type = HashJoinType.of(buildLeft, leftOut, rightOut);
-        Object operator = newOperator(33 * 32 * 1024, type, !buildLeft);
+        FlinkJoinType flinkJoinType = getJoinType(leftOut, rightOut);
+        HashJoinType hashJoinType = HashJoinType.of(buildLeft, leftOut, rightOut);
+        Object operator =
+                newOperator(33 * 32 * 1024, flinkJoinType, hashJoinType, buildLeft, !buildLeft);
         joinAndAssert(
                 operator,
                 buildInput,
@@ -311,8 +83,122 @@ public class Int2HashJoinOperatorTest implements Serializable {
                 false);
     }
 
+    public Object newOperator(
+            long memorySize,
+            FlinkJoinType flinkJoinType,
+            HashJoinType hashJoinType,
+            boolean buildLeft,
+            boolean reverseJoinFunction) {
+        GeneratedJoinCondition condFuncCode =
+                new GeneratedJoinCondition("", "", new Object[0]) {
+                    @Override
+                    public JoinCondition newInstance(ClassLoader classLoader) {
+                        return new TrueCondition();
+                    }
+                };
+        GeneratedProjection buildProjectionCode =
+                new GeneratedProjection("", "", new Object[0]) {
+                    @Override
+                    public Projection newInstance(ClassLoader classLoader) {
+                        return new MyProjection();
+                    }
+                };
+        GeneratedProjection probeProjectionCode =
+                new GeneratedProjection("", "", new Object[0]) {
+                    @Override
+                    public Projection newInstance(ClassLoader classLoader) {
+                        return new MyProjection();
+                    }
+                };
+        GeneratedNormalizedKeyComputer computer1 =
+                new GeneratedNormalizedKeyComputer("", "") {
+                    @Override
+                    public NormalizedKeyComputer newInstance(ClassLoader classLoader) {
+                        return new IntNormalizedKeyComputer();
+                    }
+                };
+        GeneratedRecordComparator comparator1 =
+                new GeneratedRecordComparator("", "", new Object[0]) {
+                    @Override
+                    public RecordComparator newInstance(ClassLoader classLoader) {
+                        return new IntRecordComparator();
+                    }
+                };
+
+        GeneratedNormalizedKeyComputer computer2 =
+                new GeneratedNormalizedKeyComputer("", "") {
+                    @Override
+                    public NormalizedKeyComputer newInstance(ClassLoader classLoader) {
+                        return new IntNormalizedKeyComputer();
+                    }
+                };
+        GeneratedRecordComparator comparator2 =
+                new GeneratedRecordComparator("", "", new Object[0]) {
+                    @Override
+                    public RecordComparator newInstance(ClassLoader classLoader) {
+                        return new IntRecordComparator();
+                    }
+                };
+        GeneratedRecordComparator genKeyComparator =
+                new GeneratedRecordComparator("", "", new Object[0]) {
+                    @Override
+                    public RecordComparator newInstance(ClassLoader classLoader) {
+                        return new IntRecordComparator();
+                    }
+                };
+        boolean[] filterNulls = new boolean[] {true};
+
+        SortMergeJoinFunction sortMergeJoinFunction;
+        if (buildLeft) {
+            sortMergeJoinFunction =
+                    new SortMergeJoinFunction(
+                            0,
+                            flinkJoinType,
+                            buildLeft,
+                            condFuncCode,
+                            buildProjectionCode,
+                            probeProjectionCode,
+                            computer1,
+                            comparator1,
+                            computer2,
+                            comparator2,
+                            genKeyComparator,
+                            filterNulls);
+        } else {
+            sortMergeJoinFunction =
+                    new SortMergeJoinFunction(
+                            0,
+                            flinkJoinType,
+                            buildLeft,
+                            condFuncCode,
+                            probeProjectionCode,
+                            buildProjectionCode,
+                            computer2,
+                            comparator2,
+                            computer1,
+                            comparator1,
+                            genKeyComparator,
+                            filterNulls);
+        }
+
+        return HashJoinOperator.newHashJoinOperator(
+                hashJoinType,
+                buildLeft,
+                condFuncCode,
+                reverseJoinFunction,
+                filterNulls,
+                buildProjectionCode,
+                probeProjectionCode,
+                false,
+                20,
+                10000,
+                10000,
+                RowType.of(new IntType()),
+                sortMergeJoinFunction);
+    }
+
     @SuppressWarnings("unchecked")
-    static void joinAndAssert(
+    public static void joinAndAssert(
             Object operator,
             MutableObjectIterator<BinaryRowData> input1,
             MutableObjectIterator<BinaryRowData> input2,
@@ -335,7 +221,7 @@ public class Int2HashJoinOperatorTest implements Serializable {
                         typeInfo,
                         (TypeInformation) typeInfo,
                         rowDataTypeInfo);
-        testHarness.memorySize = 36 * 1024 * 1024;
+        testHarness.memorySize = 3 * 1024 * 1024;
         testHarness.getExecutionConfig().enableObjectReuse();
         testHarness.setupOutputForSingletonOperatorChain();
         if (operator instanceof StreamOperator) {
@@ -479,36 +365,6 @@ public class Int2HashJoinOperatorTest implements Serializable {
         }
     }
 
-    public Object newOperator(long memorySize, HashJoinType type, boolean reverseJoinFunction) {
-        return HashJoinOperator.newHashJoinOperator(
-                type,
-                new GeneratedJoinCondition("", "", new Object[0]) {
-                    @Override
-                    public JoinCondition newInstance(ClassLoader classLoader) {
-                        return new TrueCondition();
-                    }
-                },
-                reverseJoinFunction,
-                new boolean[] {true},
-                new GeneratedProjection("", "", new Object[0]) {
-                    @Override
-                    public Projection newInstance(ClassLoader classLoader) {
-                        return new MyProjection();
-                    }
-                },
-                new GeneratedProjection("", "", new Object[0]) {
-                    @Override
-                    public Projection newInstance(ClassLoader classLoader) {
-                        return new MyProjection();
-                    }
-                },
-                false,
-                20,
-                10000,
-                10000,
-                RowType.of(new IntType()));
-    }
-
     /** Test util. */
     public static class TrueCondition extends AbstractRichFunction implements JoinCondition {
 
@@ -517,4 +373,15 @@ public class Int2HashJoinOperatorTest implements Serializable {
             return true;
         }
     }
+
+    /** Test join condition that always returns true. */
+    public static class MyJoinCondition extends AbstractRichFunction implements JoinCondition {
+
+        // Constructor mirrors the (Object[] references) signature of code-generated conditions.
+        public MyJoinCondition(Object[] reference) {}
+
+        @Override
+        public boolean apply(RowData in1, RowData in2) {
+            return true;
+        }
+    }
 }
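
Note for readers of the hunk above: the per-join-type tests that were deleted (semi, anti, build-left semi/anti) are not dropped; per the file summary they move into Int2AdaptiveHashJoinOperatorTest and the shared Int2HashJoinOperatorTestBase. A minimal sketch of what testSemiJoin would look like against the widened newOperator(...) signature; the buildLeft flag and the FlinkJoinType.SEMI mapping are assumptions here, since the subclass bodies sit outside this excerpt:

    // Hedged sketch, not the committed test body.
    MutableObjectIterator<BinaryRowData> buildInput =
            new UniformBinaryRowGenerator(9, 3, true);   // 9 keys x 3 values
    MutableObjectIterator<BinaryRowData> probeInput =
            new UniformBinaryRowGenerator(10, 10, true); // 10 keys x 10 values
    Object operator =
            newOperator(
                    33 * 32 * 1024,
                    FlinkJoinType.SEMI,  // assumed mapping for the fallback function
                    HashJoinType.SEMI,
                    true,                // buildLeft: assumption
                    false);              // reverseJoinFunction
    joinAndAssert(operator, buildInput, probeInput, 90, 9, 45, true);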
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2SortMergeJoinOperatorTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2SortMergeJoinOperatorTest.java
index 01f6d0288dd..9deb124cb5c 100644
--- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2SortMergeJoinOperatorTest.java
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/Int2SortMergeJoinOperatorTest.java
@@ -32,7 +32,7 @@ import org.apache.flink.table.runtime.generated.JoinCondition;
 import org.apache.flink.table.runtime.generated.NormalizedKeyComputer;
 import org.apache.flink.table.runtime.generated.Projection;
 import org.apache.flink.table.runtime.generated.RecordComparator;
-import org.apache.flink.table.runtime.operators.join.Int2HashJoinOperatorTest.MyProjection;
+import org.apache.flink.table.runtime.operators.join.Int2HashJoinOperatorTestBase.MyProjection;
 import org.apache.flink.table.runtime.operators.sort.IntNormalizedKeyComputer;
 import org.apache.flink.table.runtime.operators.sort.IntRecordComparator;
 import org.apache.flink.table.runtime.util.UniformBinaryRowGenerator;
@@ -211,7 +211,11 @@ public class Int2SortMergeJoinOperatorTest {
     }
 
     static StreamOperator newOperator(FlinkJoinType type, boolean leftIsSmaller) {
-        return new SortMergeJoinOperator(
+        return new SortMergeJoinOperator(getJoinFunction(type, leftIsSmaller));
+    }
+
+    public static SortMergeJoinFunction getJoinFunction(FlinkJoinType type, boolean leftIsSmaller) {
+        return new SortMergeJoinFunction(
                 0,
                 type,
                 leftIsSmaller,
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinIteratorTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinIteratorTest.java
index a938ea882df..60d01c9466c 100644
--- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinIteratorTest.java
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/SortMergeJoinIteratorTest.java
@@ -27,7 +27,7 @@ import org.apache.flink.runtime.memory.MemoryManagerBuilder;
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.data.binary.BinaryRowData;
 import org.apache.flink.table.data.writer.BinaryRowWriter;
-import org.apache.flink.table.runtime.operators.join.Int2HashJoinOperatorTest.MyProjection;
+import org.apache.flink.table.runtime.operators.join.Int2HashJoinOperatorTestBase.MyProjection;
 import org.apache.flink.table.runtime.operators.sort.IntRecordComparator;
 import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer;
 import org.apache.flink.table.runtime.util.LazyMemorySegmentPool;
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/String2HashJoinOperatorTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/String2HashJoinOperatorTest.java
index 31eb491376d..ed6ff982eba 100644
--- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/String2HashJoinOperatorTest.java
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/String2HashJoinOperatorTest.java
@@ -31,10 +31,17 @@ import org.apache.flink.table.data.binary.BinaryRowData;
 import org.apache.flink.table.data.utils.JoinedRowData;
 import org.apache.flink.table.data.writer.BinaryRowWriter;
 import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
+import org.apache.flink.table.runtime.generated.GeneratedNormalizedKeyComputer;
 import org.apache.flink.table.runtime.generated.GeneratedProjection;
+import org.apache.flink.table.runtime.generated.GeneratedRecordComparator;
 import org.apache.flink.table.runtime.generated.JoinCondition;
+import org.apache.flink.table.runtime.generated.NormalizedKeyComputer;
 import org.apache.flink.table.runtime.generated.Projection;
+import org.apache.flink.table.runtime.generated.RecordComparator;
+import org.apache.flink.table.runtime.operators.sort.StringNormalizedKeyComputer;
+import org.apache.flink.table.runtime.operators.sort.StringRecordComparator;
 import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
+import org.apache.flink.table.runtime.util.JoinUtil;
 import org.apache.flink.table.types.logical.RowType;
 import org.apache.flink.table.types.logical.VarCharType;
 
@@ -81,8 +88,10 @@ public class String2HashJoinOperatorTest implements Serializable {
     }
 
     private void init(boolean leftOut, boolean rightOut, boolean buildLeft) throws Exception {
-        HashJoinType type = HashJoinType.of(buildLeft, leftOut, rightOut);
-        HashJoinOperator operator = newOperator(33 * 32 * 1024, type, !buildLeft);
+        FlinkJoinType flinkJoinType = JoinUtil.getJoinType(leftOut, rightOut);
+        HashJoinType hashJoinType = HashJoinType.of(buildLeft, leftOut, rightOut);
+        HashJoinOperator operator =
+                newOperator(33 * 32 * 1024, flinkJoinType, hashJoinType, !buildLeft);
         testHarness =
                 new TwoInputStreamTaskTestHarness<>(
                         TwoInputStreamTask::new,
@@ -325,33 +334,98 @@ public class String2HashJoinOperatorTest implements Serializable {
     }
 
     private HashJoinOperator newOperator(
-            long memorySize, HashJoinType type, boolean reverseJoinFunction) {
-        return HashJoinOperator.newHashJoinOperator(
-                type,
+            long memorySize,
+            FlinkJoinType flinkJoinType,
+            HashJoinType hashJoinType,
+            boolean reverseJoinFunction) {
+        boolean buildLeft = false;
+        GeneratedJoinCondition condFuncCode =
                 new GeneratedJoinCondition("", "", new Object[0]) {
                     @Override
                     public JoinCondition newInstance(ClassLoader classLoader) {
-                        return new Int2HashJoinOperatorTest.TrueCondition();
+                        return new Int2HashJoinOperatorTestBase.TrueCondition();
                     }
-                },
-                reverseJoinFunction,
-                new boolean[] {true},
+                };
+        GeneratedProjection buildProjectionCode =
                 new GeneratedProjection("", "", new Object[0]) {
                     @Override
                     public Projection newInstance(ClassLoader classLoader) {
                         return new MyProjection();
                     }
-                },
+                };
+        GeneratedProjection probeProjectionCode =
                 new GeneratedProjection("", "", new Object[0]) {
                     @Override
                     public Projection newInstance(ClassLoader classLoader) {
                         return new MyProjection();
                     }
-                },
+                };
+        GeneratedNormalizedKeyComputer computer1 =
+                new GeneratedNormalizedKeyComputer("", "") {
+                    @Override
+                    public NormalizedKeyComputer newInstance(ClassLoader classLoader) {
+                        return new StringNormalizedKeyComputer();
+                    }
+                };
+        GeneratedRecordComparator comparator1 =
+                new GeneratedRecordComparator("", "", new Object[0]) {
+                    @Override
+                    public RecordComparator newInstance(ClassLoader classLoader) {
+                        return new StringRecordComparator();
+                    }
+                };
+
+        GeneratedNormalizedKeyComputer computer2 =
+                new GeneratedNormalizedKeyComputer("", "") {
+                    @Override
+                    public NormalizedKeyComputer newInstance(ClassLoader classLoader) {
+                        return new StringNormalizedKeyComputer();
+                    }
+                };
+        GeneratedRecordComparator comparator2 =
+                new GeneratedRecordComparator("", "", new Object[0]) {
+                    @Override
+                    public RecordComparator newInstance(ClassLoader classLoader) {
+                        return new StringRecordComparator();
+                    }
+                };
+        GeneratedRecordComparator genKeyComparator =
+                new GeneratedRecordComparator("", "", new Object[0]) {
+                    @Override
+                    public RecordComparator newInstance(ClassLoader classLoader) {
+                        return new StringRecordComparator();
+                    }
+                };
+        boolean[] filterNulls = new boolean[] {true};
+
+        SortMergeJoinFunction sortMergeJoinFunction =
+                new SortMergeJoinFunction(
+                        0,
+                        flinkJoinType,
+                        buildLeft,
+                        condFuncCode,
+                        probeProjectionCode,
+                        buildProjectionCode,
+                        computer2,
+                        comparator2,
+                        computer1,
+                        comparator1,
+                        genKeyComparator,
+                        filterNulls);
+
+        return HashJoinOperator.newHashJoinOperator(
+                hashJoinType,
+                buildLeft,
+                condFuncCode,
+                reverseJoinFunction,
+                filterNulls,
+                buildProjectionCode,
+                probeProjectionCode,
                 false,
                 20,
                 10000,
                 10000,
-                RowType.of(VarCharType.STRING_TYPE));
+                RowType.of(VarCharType.STRING_TYPE),
+                sortMergeJoinFunction);
     }
 }
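
In the rewritten init() above, the same pair of outer flags now feeds two views of the join: a FlinkJoinType for the sort-merge fallback path and a HashJoinType for the hash path. A minimal sketch with illustrative flag values (leftOut=true, rightOut=false, buildLeft=false):

    FlinkJoinType flinkJoinType = JoinUtil.getJoinType(true, false); // LEFT
    HashJoinType hashJoinType = HashJoinType.of(false, true, false); // outer on the probe side
    HashJoinOperator operator =
            newOperator(33 * 32 * 1024, flinkJoinType, hashJoinType, true);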
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/String2SortMergeJoinOperatorTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/String2SortMergeJoinOperatorTest.java
index 719cecfba8a..cb1ea8357ca 100644
--- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/String2SortMergeJoinOperatorTest.java
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/String2SortMergeJoinOperatorTest.java
@@ -185,58 +185,60 @@ public class String2SortMergeJoinOperatorTest {
     }
 
     static StreamOperator newOperator(FlinkJoinType type, boolean leftIsSmaller) {
-        return new SortMergeJoinOperator(
-                0,
-                type,
-                leftIsSmaller,
-                new GeneratedJoinCondition("", "", new Object[0]) {
-                    @Override
-                    public JoinCondition newInstance(ClassLoader classLoader) {
-                        return new Int2HashJoinOperatorTest.TrueCondition();
-                    }
-                },
-                new GeneratedProjection("", "", new Object[0]) {
-                    @Override
-                    public Projection newInstance(ClassLoader classLoader) {
-                        return new MyProjection();
-                    }
-                },
-                new GeneratedProjection("", "", new Object[0]) {
-                    @Override
-                    public Projection newInstance(ClassLoader classLoader) {
-                        return new MyProjection();
-                    }
-                },
-                new GeneratedNormalizedKeyComputer("", "") {
-                    @Override
-                    public NormalizedKeyComputer newInstance(ClassLoader classLoader) {
-                        return new StringNormalizedKeyComputer();
-                    }
-                },
-                new GeneratedRecordComparator("", "", new Object[0]) {
-                    @Override
-                    public RecordComparator newInstance(ClassLoader classLoader) {
-                        return new StringRecordComparator();
-                    }
-                },
-                new GeneratedNormalizedKeyComputer("", "") {
-                    @Override
-                    public NormalizedKeyComputer newInstance(ClassLoader classLoader) {
-                        return new StringNormalizedKeyComputer();
-                    }
-                },
-                new GeneratedRecordComparator("", "", new Object[0]) {
-                    @Override
-                    public RecordComparator newInstance(ClassLoader classLoader) {
-                        return new StringRecordComparator();
-                    }
-                },
-                new GeneratedRecordComparator("", "", new Object[0]) {
-                    @Override
-                    public RecordComparator newInstance(ClassLoader classLoader) {
-                        return new StringRecordComparator();
-                    }
-                },
-                new boolean[] {true});
+        SortMergeJoinFunction sortMergeJoinFunction =
+                new SortMergeJoinFunction(
+                        0,
+                        type,
+                        leftIsSmaller,
+                        new GeneratedJoinCondition("", "", new Object[0]) {
+                            @Override
+                            public JoinCondition newInstance(ClassLoader classLoader) {
+                                return new Int2HashJoinOperatorTest.TrueCondition();
+                            }
+                        },
+                        new GeneratedProjection("", "", new Object[0]) {
+                            @Override
+                            public Projection newInstance(ClassLoader classLoader) {
+                                return new MyProjection();
+                            }
+                        },
+                        new GeneratedProjection("", "", new Object[0]) {
+                            @Override
+                            public Projection newInstance(ClassLoader classLoader) {
+                                return new MyProjection();
+                            }
+                        },
+                        new GeneratedNormalizedKeyComputer("", "") {
+                            @Override
+                            public NormalizedKeyComputer newInstance(ClassLoader classLoader) {
+                                return new StringNormalizedKeyComputer();
+                            }
+                        },
+                        new GeneratedRecordComparator("", "", new Object[0]) {
+                            @Override
+                            public RecordComparator newInstance(ClassLoader classLoader) {
+                                return new StringRecordComparator();
+                            }
+                        },
+                        new GeneratedNormalizedKeyComputer("", "") {
+                            @Override
+                            public NormalizedKeyComputer newInstance(ClassLoader classLoader) {
+                                return new StringNormalizedKeyComputer();
+                            }
+                        },
+                        new GeneratedRecordComparator("", "", new Object[0]) {
+                            @Override
+                            public RecordComparator newInstance(ClassLoader classLoader) {
+                                return new StringRecordComparator();
+                            }
+                        },
+                        new GeneratedRecordComparator("", "", new Object[0]) {
+                            @Override
+                            public RecordComparator newInstance(ClassLoader classLoader) {
+                                return new StringRecordComparator();
+                            }
+                        },
+                        new boolean[] {true});
+        return new SortMergeJoinOperator(sortMergeJoinFunction);
     }
 }
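
The net effect of this hunk: SortMergeJoinOperator keeps no join logic of its own, and everything it previously took in its constructor now lives in the reusable SortMergeJoinFunction, the same object the adaptive hash join falls back to when the build side spills. A condensed sketch of the composition, reusing the public helper added to Int2SortMergeJoinOperatorTest earlier in this commit (the call site is hypothetical):

    SortMergeJoinFunction fn =
            Int2SortMergeJoinOperatorTest.getJoinFunction(FlinkJoinType.INNER, true);
    StreamOperator op = new SortMergeJoinOperator(fn);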
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/util/JoinUtil.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/util/JoinUtil.java
new file mode 100644
index 00000000000..cef98c84113
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/util/JoinUtil.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.util;
+
+import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
+
+import static org.apache.flink.table.runtime.operators.join.FlinkJoinType.FULL;
+import static org.apache.flink.table.runtime.operators.join.FlinkJoinType.INNER;
+import static org.apache.flink.table.runtime.operators.join.FlinkJoinType.RIGHT;
+
+/** Utility methods for join tests. */
+public class JoinUtil {
+
+    public static FlinkJoinType getJoinType(boolean leftOuter, boolean rightOuter) {
+        if (leftOuter && rightOuter) {
+            return FULL;
+        } else if (leftOuter) {
+            return FlinkJoinType.LEFT;
+        } else if (rightOuter) {
+            return RIGHT;
+        } else {
+            return INNER;
+        }
+    }
+
+    private JoinUtil() {}
+}
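
The flag-to-type mapping in the new JoinUtil is exhaustive over the two booleans; for quick reference:

    JoinUtil.getJoinType(false, false); // INNER
    JoinUtil.getJoinType(true, false);  // LEFT
    JoinUtil.getJoinType(false, true);  // RIGHT
    JoinUtil.getJoinType(true, true);   // FULL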
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/util/UniformBinaryRowGenerator.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/util/UniformBinaryRowGenerator.java
index 5113475fc38..74b26e9b085 100644
--- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/util/UniformBinaryRowGenerator.java
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/util/UniformBinaryRowGenerator.java
@@ -23,7 +23,7 @@ import org.apache.flink.table.data.writer.BinaryRowWriter;
 import org.apache.flink.types.IntValue;
 import org.apache.flink.util.MutableObjectIterator;
 
-/** Uniform genarator for binary row. */
+/** Uniform generator for binary row. */
 public class UniformBinaryRowGenerator implements MutableObjectIterator<BinaryRowData> {
 
     int numKeys;