You are viewing a plain-text version of this content. The canonical version is linked from the original mailing-list archive page (that hyperlink was lost in the plain-text conversion).
Posted to commits@pinot.apache.org by ro...@apache.org on 2023/02/18 18:07:49 UTC

[pinot] branch master updated: [multistage] support SEMI/ANTI join (#10294)

This is an automated email from the ASF dual-hosted git repository.

rongr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 1033a11b6d [multistage] support SEMI/ANTI join (#10294)
1033a11b6d is described below

commit 1033a11b6df14e83baac176c185beeefeafb3eeb
Author: Rong Rong <ro...@apache.org>
AuthorDate: Sat Feb 18 10:07:42 2023 -0800

    [multistage] support SEMI/ANTI join (#10294)
    
    * [multistage] support SEMI/ANTI join
    
    
    Co-authored-by: Rong Rong <ro...@startree.ai>
---
 .../calcite/rel/rules/PinotQueryRuleSets.java      |   1 +
 .../query/planner/logical/RelToStageConverter.java |   3 +-
 .../apache/pinot/query/planner/stage/JoinNode.java |  19 +++-
 .../src/test/resources/queries/JoinPlans.json      |  15 +++
 .../query/runtime/operator/HashJoinOperator.java   |  74 ++++++++------
 .../runtime/operator/HashJoinOperatorTest.java     | 108 +++++++++++++--------
 6 files changed, 148 insertions(+), 72 deletions(-)

diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
index f93e16b637..5ce3bd1c55 100644
--- a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
+++ b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
@@ -71,6 +71,7 @@ public class PinotQueryRuleSets {
 
           // join rules
           CoreRules.JOIN_PUSH_EXPRESSIONS,
+          CoreRules.PROJECT_TO_SEMI_JOIN,
 
           // convert non-all union into all-union + distinct
           CoreRules.UNION_TO_DISTINCT,
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java
index eb24da7e8e..2bc1ebdae0 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java
@@ -130,7 +130,8 @@ public final class RelToStageConverter {
     JoinInfo joinInfo = node.analyzeCondition();
     FieldSelectionKeySelector leftFieldSelectionKeySelector = new FieldSelectionKeySelector(joinInfo.leftKeys);
     FieldSelectionKeySelector rightFieldSelectionKeySelector = new FieldSelectionKeySelector(joinInfo.rightKeys);
-    return new JoinNode(currentStageId, toDataSchema(node.getRowType()), joinType,
+    return new JoinNode(currentStageId, toDataSchema(node.getRowType()), toDataSchema(node.getLeft().getRowType()),
+        toDataSchema(node.getRight().getRowType()), joinType,
         new JoinNode.JoinKeys(leftFieldSelectionKeySelector, rightFieldSelectionKeySelector),
         joinInfo.nonEquiConditions.stream().map(RexExpression::toRexExpression).collect(Collectors.toList()));
   }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/JoinNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/JoinNode.java
index af9b4e03ed..3b127d6568 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/JoinNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/JoinNode.java
@@ -18,6 +18,7 @@
  */
 package org.apache.pinot.query.planner.stage;
 
+import java.util.Arrays;
 import java.util.List;
 import org.apache.calcite.rel.core.JoinRelType;
 import org.apache.pinot.common.utils.DataSchema;
@@ -34,14 +35,20 @@ public class JoinNode extends AbstractStageNode {
   private JoinKeys _joinKeys;
   @ProtoProperties
   private List<RexExpression> _joinClause;
+  @ProtoProperties
+  private List<String> _leftColumnNames;
+  @ProtoProperties
+  private List<String> _rightColumnNames;
 
   public JoinNode(int stageId) {
     super(stageId);
   }
 
-  public JoinNode(int stageId, DataSchema dataSchema, JoinRelType joinRelType, JoinKeys joinKeys,
-      List<RexExpression> joinClause) {
+  public JoinNode(int stageId, DataSchema dataSchema, DataSchema leftSchema, DataSchema rightSchema,
+      JoinRelType joinRelType, JoinKeys joinKeys, List<RexExpression> joinClause) {
     super(stageId, dataSchema);
+    _leftColumnNames = Arrays.asList(leftSchema.getColumnNames());
+    _rightColumnNames = Arrays.asList(rightSchema.getColumnNames());
     _joinRelType = joinRelType;
     _joinKeys = joinKeys;
     _joinClause = joinClause;
@@ -59,6 +66,14 @@ public class JoinNode extends AbstractStageNode {
     return _joinClause;
   }
 
+  public List<String> getLeftColumnNames() {
+    return _leftColumnNames;
+  }
+
+  public List<String> getRightColumnNames() {
+    return _rightColumnNames;
+  }
+
   @Override
   public String explain() {
     return "JOIN";
diff --git a/pinot-query-planner/src/test/resources/queries/JoinPlans.json b/pinot-query-planner/src/test/resources/queries/JoinPlans.json
index c3dc7a99f2..d9ffeaf0da 100644
--- a/pinot-query-planner/src/test/resources/queries/JoinPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/JoinPlans.json
@@ -221,6 +221,21 @@
           "\n          LogicalTableScan(table=[[b]])",
           "\n"
         ]
+      },
+      {
+        "description": "Semi join with IN clause",
+        "sql": "EXPLAIN PLAN FOR SELECT col1, col2 FROM a WHERE col3 IN (SELECT col3 FROM b)",
+        "output": [
+          "Execution Plan",
+          "\nLogicalProject(col1=[$2], col2=[$0])",
+          "\n  LogicalJoin(condition=[=($1, $4)], joinType=[semi])",
+          "\n    LogicalExchange(distribution=[hash[1]])",
+          "\n      LogicalProject(col2=[$0], col3=[$1], col1=[$2])",
+          "\n        LogicalTableScan(table=[[a]])",
+          "\n    LogicalExchange(distribution=[hash[1]])",
+          "\n      LogicalTableScan(table=[[b]])",
+          "\n"
+        ]
       }
     ]
   },
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
index 453982e9a1..1c27281a96 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
@@ -60,8 +60,8 @@ public class HashJoinOperator extends MultiStageOperator {
   private static final String EXPLAIN_NAME = "HASH_JOIN";
   private static final Logger LOGGER = LoggerFactory.getLogger(AggregateOperator.class);
 
-  private static final Set<JoinRelType> SUPPORTED_JOIN_TYPES =
-      ImmutableSet.of(JoinRelType.INNER, JoinRelType.LEFT, JoinRelType.RIGHT, JoinRelType.FULL);
+  private static final Set<JoinRelType> SUPPORTED_JOIN_TYPES = ImmutableSet.of(
+      JoinRelType.INNER, JoinRelType.LEFT, JoinRelType.RIGHT, JoinRelType.FULL, JoinRelType.SEMI, JoinRelType.ANTI);
 
   private final HashMap<Key, List<Object[]>> _broadcastRightTable;
 
@@ -101,8 +101,8 @@ public class HashJoinOperator extends MultiStageOperator {
     Preconditions.checkState(_leftRowSize > 0, "leftRowSize has to be greater than zero:" + _leftRowSize);
     _resultSchema = node.getDataSchema();
     _resultRowSize = _resultSchema.size();
-    Preconditions.checkState(_resultRowSize > _leftRowSize,
-        "Result row size" + _leftRowSize + " has to be greater than left row size:" + _leftRowSize);
+    Preconditions.checkState(_resultRowSize >= _leftRowSize,
+        "Result row size" + _leftRowSize + " has to be greater than or equal to left row size:" + _leftRowSize);
     _leftTableOperator = leftTableOperator;
     _rightTableOperator = rightTableOperator;
     _joinClauseEvaluators = new ArrayList<>(node.getJoinClauses().size());
@@ -215,32 +215,48 @@ public class HashJoinOperator extends MultiStageOperator {
     List<Object[]> container = leftBlock.isEndOfStreamBlock() ? new ArrayList<>() : leftBlock.getContainer();
     for (Object[] leftRow : container) {
       Key key = new Key(_leftKeySelector.getKey(leftRow));
-      // NOTE: Empty key selector will always give same hash code.
-      List<Object[]> matchedRightRows = _broadcastRightTable.getOrDefault(key, null);
-      if (matchedRightRows == null) {
-        if (needUnmatchedLeftRows()) {
-          rows.add(joinRow(leftRow, null));
-        }
-        continue;
-      }
-      boolean hasMatchForLeftRow = false;
-      for (int i = 0; i < matchedRightRows.size(); i++) {
-        Object[] rightRow = matchedRightRows.get(i);
-        // TODO: Optimize this to avoid unnecessary object copy.
-        Object[] resultRow = joinRow(leftRow, rightRow);
-        if (_joinClauseEvaluators.isEmpty() || _joinClauseEvaluators.stream().allMatch(
-            evaluator -> (Boolean) FunctionInvokeUtils.convert(evaluator.apply(resultRow),
-                DataSchema.ColumnDataType.BOOLEAN))) {
-          rows.add(resultRow);
-          hasMatchForLeftRow = true;
-          if (_matchedRightRows != null) {
-            HashSet<Integer> matchedRows = _matchedRightRows.computeIfAbsent(key, k -> new HashSet<>());
-            matchedRows.add(i);
+      switch (_joinType) {
+        case SEMI:
+          // SEMI-JOIN only checks existence of the key
+          if (_broadcastRightTable.containsKey(key)) {
+            rows.add(joinRow(leftRow, null));
           }
-        }
-      }
-      if (!hasMatchForLeftRow && needUnmatchedLeftRows()) {
-        rows.add(joinRow(leftRow, null));
+          break;
+        case ANTI:
+          // ANTI-JOIN only checks non-existence of the key
+          if (!_broadcastRightTable.containsKey(key)) {
+            rows.add(joinRow(leftRow, null));
+          }
+          break;
+        default: // INNER, LEFT, RIGHT, FULL
+          // NOTE: Empty key selector will always give same hash code.
+          List<Object[]> matchedRightRows = _broadcastRightTable.getOrDefault(key, null);
+          if (matchedRightRows == null) {
+            if (needUnmatchedLeftRows()) {
+              rows.add(joinRow(leftRow, null));
+            }
+            continue;
+          }
+          boolean hasMatchForLeftRow = false;
+          for (int i = 0; i < matchedRightRows.size(); i++) {
+            Object[] rightRow = matchedRightRows.get(i);
+            // TODO: Optimize this to avoid unnecessary object copy.
+            Object[] resultRow = joinRow(leftRow, rightRow);
+            if (_joinClauseEvaluators.isEmpty() || _joinClauseEvaluators.stream().allMatch(
+                evaluator -> (Boolean) FunctionInvokeUtils.convert(evaluator.apply(resultRow),
+                    DataSchema.ColumnDataType.BOOLEAN))) {
+              rows.add(resultRow);
+              hasMatchForLeftRow = true;
+              if (_matchedRightRows != null) {
+                HashSet<Integer> matchedRows = _matchedRightRows.computeIfAbsent(key, k -> new HashSet<>());
+                matchedRows.add(i);
+              }
+            }
+          }
+          if (!hasMatchForLeftRow && needUnmatchedLeftRows()) {
+            rows.add(joinRow(leftRow, null));
+          }
+          break;
       }
     }
     return new TransferableBlock(rows, _resultSchema, DataBlock.Type.ROW);
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/HashJoinOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/HashJoinOperatorTest.java
index 4075237249..b9a95ba9a0 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/HashJoinOperatorTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/HashJoinOperatorTest.java
@@ -88,8 +88,8 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT,
             DataSchema.ColumnDataType.STRING
         });
-    JoinNode node =
-        new JoinNode(1, resultSchema, JoinRelType.INNER, getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses);
+    JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.INNER,
+        getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses);
     HashJoinOperator joinOnString = new HashJoinOperator(_leftOperator, _rightOperator, leftSchema, node, 1, 2);
 
     TransferableBlock result = joinOnString.nextBlock();
@@ -125,8 +125,8 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT,
             DataSchema.ColumnDataType.STRING
         });
-    JoinNode node =
-        new JoinNode(1, resultSchema, JoinRelType.INNER, getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
+    JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.INNER,
+        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
     HashJoinOperator joinOnInt = new HashJoinOperator(_leftOperator, _rightOperator, leftSchema, node, 1, 2);
     TransferableBlock result = joinOnInt.nextBlock();
     while (result.isNoOpBlock()) {
@@ -159,8 +159,8 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT,
             DataSchema.ColumnDataType.STRING
         });
-    JoinNode node = new JoinNode(1, resultSchema, JoinRelType.INNER, getJoinKeys(new ArrayList<>(), new ArrayList<>()),
-        joinClauses);
+    JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.INNER,
+        getJoinKeys(new ArrayList<>(), new ArrayList<>()), joinClauses);
     HashJoinOperator joinOnInt = new HashJoinOperator(_leftOperator, _rightOperator, leftSchema, node, 1, 2);
     TransferableBlock result = joinOnInt.nextBlock();
     while (result.isNoOpBlock()) {
@@ -200,8 +200,8 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT,
             DataSchema.ColumnDataType.STRING
         });
-    JoinNode node =
-        new JoinNode(1, resultSchema, JoinRelType.LEFT, getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses);
+    JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.LEFT,
+        getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses);
     HashJoinOperator join = new HashJoinOperator(_leftOperator, _rightOperator, leftSchema, node, 1, 2);
 
     TransferableBlock result = join.nextBlock();
@@ -234,8 +234,8 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.STRING
         });
     List<RexExpression> joinClauses = new ArrayList<>();
-    JoinNode node =
-        new JoinNode(1, resultSchema, JoinRelType.INNER, getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
+    JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.INNER,
+        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
     HashJoinOperator join = new HashJoinOperator(_leftOperator, _rightOperator, leftSchema, node, 1, 2);
 
     TransferableBlock result = join.nextBlock();
@@ -265,8 +265,8 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT,
             DataSchema.ColumnDataType.STRING
         });
-    JoinNode node =
-        new JoinNode(1, resultSchema, JoinRelType.LEFT, getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
+    JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.LEFT,
+        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
     HashJoinOperator join = new HashJoinOperator(_leftOperator, _rightOperator, leftSchema, node, 1, 2);
 
     TransferableBlock result = join.nextBlock();
@@ -299,8 +299,8 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT,
             DataSchema.ColumnDataType.STRING
         });
-    JoinNode node =
-        new JoinNode(1, resultSchema, JoinRelType.INNER, getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
+    JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.INNER,
+        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
     HashJoinOperator join = new HashJoinOperator(_leftOperator, _rightOperator, leftSchema, node, 1, 2);
 
     TransferableBlock result = join.nextBlock();
@@ -337,8 +337,8 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT,
             DataSchema.ColumnDataType.STRING
         });
-    JoinNode node = new JoinNode(1, resultSchema, JoinRelType.INNER, getJoinKeys(new ArrayList<>(), new ArrayList<>()),
-        joinClauses);
+    JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.INNER,
+        getJoinKeys(new ArrayList<>(), new ArrayList<>()), joinClauses);
     HashJoinOperator join = new HashJoinOperator(_leftOperator, _rightOperator, leftSchema, node, 1, 2);
     TransferableBlock result = join.nextBlock();
     while (result.isNoOpBlock()) {
@@ -375,8 +375,8 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT,
             DataSchema.ColumnDataType.STRING
         });
-    JoinNode node = new JoinNode(1, resultSchema, JoinRelType.INNER, getJoinKeys(new ArrayList<>(), new ArrayList<>()),
-        joinClauses);
+    JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.INNER,
+        getJoinKeys(new ArrayList<>(), new ArrayList<>()), joinClauses);
     HashJoinOperator join = new HashJoinOperator(_leftOperator, _rightOperator, leftSchema, node, 1, 2);
     TransferableBlock result = join.nextBlock();
     while (result.isNoOpBlock()) {
@@ -409,8 +409,8 @@ public class HashJoinOperatorTest {
         DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT,
         DataSchema.ColumnDataType.STRING
     });
-    JoinNode node =
-        new JoinNode(1, resultSchema, JoinRelType.RIGHT, getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
+    JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.RIGHT,
+        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
     HashJoinOperator joinOnNum = new HashJoinOperator(_leftOperator, _rightOperator, leftSchema, node, 1, 2);
     TransferableBlock result = joinOnNum.nextBlock();
     while (result.isNoOpBlock()) {
@@ -438,16 +438,16 @@ public class HashJoinOperatorTest {
     Assert.assertTrue(result.isSuccessfulEndOfStreamBlock());
   }
 
-  @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = ".*SEMI is not supported.*")
-  public void shouldThrowOnSemiJoin() {
+  @Test
+  public void shouldHandleSemiJoin() {
     DataSchema leftSchema = new DataSchema(new String[]{"int_col", "string_col"}, new DataSchema.ColumnDataType[]{
         DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING
     });
     DataSchema rightSchema = new DataSchema(new String[]{"int_col", "string_col"}, new DataSchema.ColumnDataType[]{
         DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING
     });
-    Mockito.when(_leftOperator.nextBlock())
-        .thenReturn(OperatorTestUtil.block(leftSchema, new Object[]{1, "Aa"}, new Object[]{2, "BB"}))
+    Mockito.when(_leftOperator.nextBlock()).thenReturn(
+            OperatorTestUtil.block(leftSchema, new Object[]{1, "Aa"}, new Object[]{2, "BB"}, new Object[]{4, "CC"}))
         .thenReturn(TransferableBlockUtils.getEndOfStreamTransferableBlock());
     Mockito.when(_rightOperator.nextBlock()).thenReturn(
             OperatorTestUtil.block(rightSchema, new Object[]{2, "Aa"}, new Object[]{2, "BB"}, new Object[]{3, "BB"}))
@@ -458,9 +458,24 @@ public class HashJoinOperatorTest {
         DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT,
         DataSchema.ColumnDataType.STRING
     });
-    JoinNode node =
-        new JoinNode(1, resultSchema, JoinRelType.SEMI, getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses);
+    JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.SEMI,
+        getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses);
     HashJoinOperator join = new HashJoinOperator(_leftOperator, _rightOperator, leftSchema, node, 1, 2);
+    TransferableBlock result = join.nextBlock();
+    while (result.isNoOpBlock()) {
+      result = join.nextBlock();
+    }
+    List<Object[]> resultRows = result.getContainer();
+    List<Object[]> expectedRows = ImmutableList.of(new Object[]{1, "Aa", null, null},
+        new Object[]{2, "BB", null, null});
+    Assert.assertEquals(resultRows.size(), expectedRows.size());
+    Assert.assertEquals(resultRows.get(0), expectedRows.get(0));
+    Assert.assertEquals(resultRows.get(1), expectedRows.get(1));
+    result = join.nextBlock();
+    while (result.isNoOpBlock()) {
+      result = join.nextBlock();
+    }
+    Assert.assertTrue(result.isSuccessfulEndOfStreamBlock());
   }
 
   @Test
@@ -482,8 +497,8 @@ public class HashJoinOperatorTest {
         DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT,
         DataSchema.ColumnDataType.STRING
     });
-    JoinNode node =
-        new JoinNode(1, resultSchema, JoinRelType.FULL, getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
+    JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.FULL,
+        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
     HashJoinOperator join = new HashJoinOperator(_leftOperator, _rightOperator, leftSchema, node, 1, 2);
     TransferableBlock result = join.nextBlock();
     while (result.isNoOpBlock()) {
@@ -514,16 +529,16 @@ public class HashJoinOperatorTest {
     Assert.assertTrue(result.isSuccessfulEndOfStreamBlock());
   }
 
-  @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = ".*ANTI is not supported.*")
-  public void shouldThrowOnAntiJoin() {
+  @Test
+  public void shouldHandleAntiJoin() {
     DataSchema leftSchema = new DataSchema(new String[]{"int_col", "string_col"}, new DataSchema.ColumnDataType[]{
         DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING
     });
     DataSchema rightSchema = new DataSchema(new String[]{"int_col", "string_col"}, new DataSchema.ColumnDataType[]{
         DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING
     });
-    Mockito.when(_leftOperator.nextBlock())
-        .thenReturn(OperatorTestUtil.block(leftSchema, new Object[]{1, "Aa"}, new Object[]{2, "BB"}))
+    Mockito.when(_leftOperator.nextBlock()).thenReturn(
+            OperatorTestUtil.block(leftSchema, new Object[]{1, "Aa"}, new Object[]{2, "BB"}, new Object[]{4, "CC"}))
         .thenReturn(TransferableBlockUtils.getEndOfStreamTransferableBlock());
     Mockito.when(_rightOperator.nextBlock()).thenReturn(
             OperatorTestUtil.block(rightSchema, new Object[]{2, "Aa"}, new Object[]{2, "BB"}, new Object[]{3, "BB"}))
@@ -534,9 +549,22 @@ public class HashJoinOperatorTest {
         DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT,
         DataSchema.ColumnDataType.STRING
     });
-    JoinNode node =
-        new JoinNode(1, resultSchema, JoinRelType.ANTI, getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses);
+    JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.ANTI,
+        getJoinKeys(Arrays.asList(1), Arrays.asList(1)), joinClauses);
     HashJoinOperator join = new HashJoinOperator(_leftOperator, _rightOperator, leftSchema, node, 1, 2);
+    TransferableBlock result = join.nextBlock();
+    while (result.isNoOpBlock()) {
+      result = join.nextBlock();
+    }
+    List<Object[]> resultRows = result.getContainer();
+    List<Object[]> expectedRows = ImmutableList.of(new Object[]{4, "CC", null, null});
+    Assert.assertEquals(resultRows.size(), expectedRows.size());
+    Assert.assertEquals(resultRows.get(0), expectedRows.get(0));
+    result = join.nextBlock();
+    while (result.isNoOpBlock()) {
+      result = join.nextBlock();
+    }
+    Assert.assertTrue(result.isSuccessfulEndOfStreamBlock());
   }
 
   @Test
@@ -559,8 +587,8 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT,
             DataSchema.ColumnDataType.STRING
         });
-    JoinNode node =
-        new JoinNode(1, resultSchema, JoinRelType.INNER, getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
+    JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.INNER,
+        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
     HashJoinOperator join = new HashJoinOperator(_leftOperator, _rightOperator, leftSchema, node, 1, 2);
 
     TransferableBlock result = join.nextBlock();
@@ -592,8 +620,8 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT,
             DataSchema.ColumnDataType.STRING
         });
-    JoinNode node =
-        new JoinNode(1, resultSchema, JoinRelType.INNER, getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
+    JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.INNER,
+        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
     HashJoinOperator join = new HashJoinOperator(_leftOperator, _rightOperator, leftSchema, node, 1, 2);
 
     TransferableBlock result = join.nextBlock();
@@ -628,8 +656,8 @@ public class HashJoinOperatorTest {
             DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT,
             DataSchema.ColumnDataType.STRING
         });
-    JoinNode node =
-        new JoinNode(1, resultSchema, JoinRelType.INNER, getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
+    JoinNode node = new JoinNode(1, resultSchema, leftSchema, rightSchema, JoinRelType.INNER,
+        getJoinKeys(Arrays.asList(0), Arrays.asList(0)), joinClauses);
     HashJoinOperator join = new HashJoinOperator(_leftOperator, _rightOperator, leftSchema, node, 1, 2);
 
     TransferableBlock result = join.nextBlock(); // first no-op consumes first right data block.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org