Posted to commits@pinot.apache.org by ro...@apache.org on 2022/06/08 06:03:04 UTC

[pinot] 05/11: add support for project/filter pushdown (#8558)

This is an automated email from the ASF dual-hosted git repository.

rongr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git

commit 0b414974789b16aa6b54ccd4062f34bfd085db7e
Author: Rong Rong <ro...@apache.org>
AuthorDate: Fri Apr 22 08:57:29 2022 -0700

    add support for project/filter pushdown (#8558)
    
    * make planner node ready for project/filter pushdown
    
    make filter expression compilation work and generate the corresponding operator
    
    add ProjectNode compilation as well
    
    * address review comments from the previous PR
    
    - add float type
    - change serde javadoc
    - change function name for StageNode serializable
    - change function name in proto utils
    
    * add distributed hash join capability
    
    * fixing serde
    
    * refactor calcite components into sql/rex/expression 3 subclasses
    
    * address diff comments
    
    1. renamed BroadcastJoin to HashJoin
    2. relocate packages inside planner and planner.nodes to planner.logical/serde/stage based on their functionality
    3. added SqlHint strategy to planner so that it can do either hash or broadcast join
    
    * address additional review comments
    
    1. remove CalciteExpressionParser as it is not used
    2. remove rowType from the constructor of AbstractStageNode
    
    * rename RelHint
    
    Co-authored-by: Rong Rong <ro...@startree.ai>
---
 pinot-common/src/main/proto/plan.proto             |   5 +-
 pinot-query-planner/pom.xml                        |   4 +-
 .../org/apache/pinot/query/QueryEnvironment.java   |  12 +-
 .../query/parser/CalciteExpressionParser.java      | 502 ---------------------
 .../query/parser/CalciteRexExpressionParser.java   | 217 +++++++++
 .../pinot/query/parser/CalciteSqlParser.java       |   2 +-
 .../org/apache/pinot/query/parser/ParserUtils.java |  55 +++
 .../org/apache/pinot/query/planner/QueryPlan.java  |   3 +-
 .../apache/pinot/query/planner/StageMetadata.java  |   4 +-
 .../PinotRelationalHints.java}                     |  24 +-
 .../planner/{ => logical}/LogicalPlanner.java      |   2 +-
 .../planner/{ => logical}/RelToStageConverter.java |  39 +-
 .../pinot/query/planner/logical/RexExpression.java | 157 +++++++
 .../query/planner/{ => logical}/StagePlanner.java  |  31 +-
 .../partitioning/FieldSelectionKeySelector.java    |   3 +
 .../query/planner/partitioning/KeySelector.java    |   4 +-
 .../ProtoProperties.java}                          |  19 +-
 .../{nodes => }/serde/ProtoSerializable.java       |  20 +-
 .../{nodes => }/serde/ProtoSerializationUtils.java | 117 +++--
 .../{nodes => stage}/AbstractStageNode.java        |  24 +-
 .../{nodes/CalcNode.java => stage/FilterNode.java} |  24 +-
 .../query/planner/{nodes => stage}/JoinNode.java   |  12 +-
 .../{nodes => stage}/MailboxReceiveNode.java       |   9 +-
 .../planner/{nodes => stage}/MailboxSendNode.java  |  24 +-
 .../TableScanNode.java => stage/ProjectNode.java}  |  32 +-
 .../query/planner/{nodes => stage}/StageNode.java  |   2 +-
 .../StageNodeSerDeUtils.java}                      |  16 +-
 .../planner/{nodes => stage}/TableScanNode.java    |   9 +-
 .../query/rules/PinotExchangeNodeInsertRule.java   |  24 +-
 .../pinot/query/rules/PinotQueryRuleSets.java      |   2 +
 .../apache/pinot/query/QueryEnvironmentTest.java   |  80 ++--
 .../pinot/query/QueryEnvironmentTestBase.java      |  52 +++
 .../pinot/query/planner/stage/SerDeUtilsTest.java  |  81 ++++
 .../apache/pinot/query/runtime/QueryRunner.java    |   6 +-
 .../runtime/executor/WorkerQueryExecutor.java      |  33 +-
 ...castJoinOperator.java => HashJoinOperator.java} |   6 +-
 .../runtime/operator/MailboxSendOperator.java      |  52 ++-
 .../query/runtime/plan/DistributedStagePlan.java   |   2 +-
 .../runtime/plan/serde/QueryPlanSerDeUtils.java    |   8 +-
 .../query/runtime/utils/ServerRequestUtils.java    |  33 +-
 .../pinot/query/service/QueryDispatcher.java       |   2 +-
 .../pinot/query/runtime/QueryRunnerTest.java       | 170 ++-----
 .../pinot/query/service/QueryServerTest.java       |  40 +-
 43 files changed, 1071 insertions(+), 892 deletions(-)

diff --git a/pinot-common/src/main/proto/plan.proto b/pinot-common/src/main/proto/plan.proto
index 47018197fc..8e75a31a42 100644
--- a/pinot-common/src/main/proto/plan.proto
+++ b/pinot-common/src/main/proto/plan.proto
@@ -76,8 +76,9 @@ message LiteralField {
     bool boolField = 1;
     int32 intField = 2;
     int64 longField = 3;
-    double doubleField = 4;
-    string stringField = 5;
+    float floatField = 4;
+    double doubleField = 5;
+    string stringField = 6;
   }
 }
 
diff --git a/pinot-query-planner/pom.xml b/pinot-query-planner/pom.xml
index 05b9461cdc..68ee3a2057 100644
--- a/pinot-query-planner/pom.xml
+++ b/pinot-query-planner/pom.xml
@@ -70,12 +70,12 @@
     <dependency>
       <groupId>org.codehaus.janino</groupId>
       <artifactId>janino</artifactId>
-      <version>3.0.9</version>
+      <version>3.1.6</version>
     </dependency>
     <dependency>
       <groupId>org.codehaus.janino</groupId>
       <artifactId>commons-compiler</artifactId>
-      <version>3.0.9</version>
+      <version>3.1.6</version>
     </dependency>
 
     <dependency>
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
index 215c0051d7..3a19156287 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
@@ -31,6 +31,7 @@ import org.apache.calcite.prepare.PlannerImpl;
 import org.apache.calcite.prepare.Prepare;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.RelRoot;
+import org.apache.calcite.rel.hint.HintStrategyTable;
 import org.apache.calcite.rel.type.RelDataTypeFactory;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.sql.SqlKind;
@@ -43,9 +44,9 @@ import org.apache.calcite.tools.FrameworkConfig;
 import org.apache.calcite.tools.Frameworks;
 import org.apache.pinot.query.context.PlannerContext;
 import org.apache.pinot.query.parser.CalciteSqlParser;
-import org.apache.pinot.query.planner.LogicalPlanner;
 import org.apache.pinot.query.planner.QueryPlan;
-import org.apache.pinot.query.planner.StagePlanner;
+import org.apache.pinot.query.planner.logical.LogicalPlanner;
+import org.apache.pinot.query.planner.logical.StagePlanner;
 import org.apache.pinot.query.routing.WorkerManager;
 import org.apache.pinot.query.rules.PinotQueryRuleSets;
 import org.apache.pinot.query.type.TypeFactory;
@@ -157,10 +158,15 @@ public class QueryEnvironment {
     RelOptCluster cluster = RelOptCluster.create(_relOptPlanner, rexBuilder);
     SqlToRelConverter sqlToRelConverter =
         new SqlToRelConverter(_planner, _validator, _catalogReader, cluster, StandardConvertletTable.INSTANCE,
-            SqlToRelConverter.config());
+            SqlToRelConverter.config().withHintStrategyTable(getHintStrategyTable(plannerContext)));
     return sqlToRelConverter.convertQuery(parsed, false, true);
   }
 
+  // TODO: add hint strategy table based on plannerContext.
+  private HintStrategyTable getHintStrategyTable(PlannerContext plannerContext) {
+    return HintStrategyTable.builder().build();
+  }
+
   protected RelNode optimize(RelRoot relRoot, PlannerContext plannerContext) {
     // 4. optimize relNode
     // TODO: add support for traits, cost factory.
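
The hint strategy table wired in above is still empty (see the TODO). As a minimal sketch, assuming Calcite's HintPredicates API and the hint names introduced by PinotRelationalHints later in this patch, the strategies could eventually be registered along these lines (not part of this commit):

    import org.apache.calcite.rel.hint.HintPredicates;
    import org.apache.calcite.rel.hint.HintStrategyTable;

    // Sketch only: register the join-distribution hint names so Calcite keeps them
    // attached to LogicalJoin nodes during SQL-to-Rel conversion.
    static HintStrategyTable buildPinotHintStrategyTable() {
      return HintStrategyTable.builder()
          .hintStrategy("USE_HASH_DISTRIBUTE", HintPredicates.JOIN)
          .hintStrategy("USE_BROADCAST_DISTRIBUTE", HintPredicates.JOIN)
          .build();
    }
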
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteExpressionParser.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteExpressionParser.java
deleted file mode 100644
index fc75efb1ae..0000000000
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteExpressionParser.java
+++ /dev/null
@@ -1,502 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.pinot.query.parser;
-
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashSet;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Set;
-import org.apache.calcite.sql.SqlBasicCall;
-import org.apache.calcite.sql.SqlDataTypeSpec;
-import org.apache.calcite.sql.SqlIdentifier;
-import org.apache.calcite.sql.SqlKind;
-import org.apache.calcite.sql.SqlLiteral;
-import org.apache.calcite.sql.SqlNode;
-import org.apache.calcite.sql.SqlNodeList;
-import org.apache.calcite.sql.fun.SqlCase;
-import org.apache.calcite.sql.parser.SqlParseException;
-import org.apache.calcite.sql.parser.SqlParser;
-import org.apache.commons.lang3.StringUtils;
-import org.apache.pinot.common.request.Expression;
-import org.apache.pinot.common.request.ExpressionType;
-import org.apache.pinot.common.request.Function;
-import org.apache.pinot.common.utils.request.RequestUtils;
-import org.apache.pinot.segment.spi.AggregationFunctionType;
-import org.apache.pinot.sql.parsers.SqlCompilationException;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-
-/**
- * Calcite parser to convert SQL expressions into {@link Expression}.
- *
- * <p>This class is extracted from {@link org.apache.pinot.sql.parsers.CalciteSqlParser}. It only contains the
- * {@link Expression} related info, this is used for ingestion and query rewrite.
- */
-public class CalciteExpressionParser {
-  private static final Logger LOGGER = LoggerFactory.getLogger(CalciteExpressionParser.class);
-
-  private CalciteExpressionParser() {
-    // do not instantiate.
-  }
-
-  private static List<Expression> getAliasLeftExpressionsFromDistinctExpression(Function function) {
-    List<Expression> operands = function.getOperands();
-    List<Expression> expressions = new ArrayList<>(operands.size());
-    for (Expression operand : operands) {
-      if (isAsFunction(operand)) {
-        expressions.add(operand.getFunctionCall().getOperands().get(0));
-      } else {
-        expressions.add(operand);
-      }
-    }
-    return expressions;
-  }
-
-  public static boolean isAggregateExpression(Expression expression) {
-    Function functionCall = expression.getFunctionCall();
-    if (functionCall != null) {
-      String operator = functionCall.getOperator();
-      try {
-        AggregationFunctionType.getAggregationFunctionType(operator);
-        return true;
-      } catch (IllegalArgumentException e) {
-      }
-      if (functionCall.getOperandsSize() > 0) {
-        for (Expression operand : functionCall.getOperands()) {
-          if (isAggregateExpression(operand)) {
-            return true;
-          }
-        }
-      }
-    }
-    return false;
-  }
-
-  public static boolean isAsFunction(Expression expression) {
-    return expression.getFunctionCall() != null && expression.getFunctionCall().getOperator().equalsIgnoreCase("AS");
-  }
-
-  /**
-   * Extract all the identifiers from given expressions.
-   *
-   * @param expressions
-   * @param excludeAs if true, ignores the right side identifier for AS function.
-   * @return all the identifier names.
-   */
-  public static Set<String> extractIdentifiers(List<Expression> expressions, boolean excludeAs) {
-    Set<String> identifiers = new HashSet<>();
-    for (Expression expression : expressions) {
-      if (expression.getIdentifier() != null) {
-        identifiers.add(expression.getIdentifier().getName());
-      } else if (expression.getFunctionCall() != null) {
-        if (excludeAs && expression.getFunctionCall().getOperator().equalsIgnoreCase("AS")) {
-          identifiers.addAll(
-              extractIdentifiers(Arrays.asList(expression.getFunctionCall().getOperands().get(0)), true));
-          continue;
-        } else {
-          identifiers.addAll(extractIdentifiers(expression.getFunctionCall().getOperands(), excludeAs));
-        }
-      }
-    }
-    return identifiers;
-  }
-
-  /**
-   * Compiles a String expression into {@link Expression}.
-   *
-   * @param expression String expression.
-   * @return {@link Expression} equivalent of the string.
-   *
-   * @throws SqlCompilationException if String is not a valid expression.
-   */
-  public static Expression compileToExpression(String expression) {
-    SqlParser sqlParser = SqlParser.create(expression, ParserUtils.PARSER_CONFIG);
-    SqlNode sqlNode;
-    try {
-      sqlNode = sqlParser.parseExpression();
-    } catch (SqlParseException e) {
-      throw new SqlCompilationException("Caught exception while parsing expression: " + expression, e);
-    }
-    return toExpression(sqlNode);
-  }
-
-  private static List<Expression> convertDistinctSelectList(SqlNodeList selectList) {
-    List<Expression> selectExpr = new ArrayList<>();
-    selectExpr.add(convertDistinctAndSelectListToFunctionExpression(selectList));
-    return selectExpr;
-  }
-
-  private static List<Expression> convertSelectList(SqlNodeList selectList) {
-    List<Expression> selectExpr = new ArrayList<>();
-
-    final Iterator<SqlNode> iterator = selectList.iterator();
-    while (iterator.hasNext()) {
-      final SqlNode next = iterator.next();
-      selectExpr.add(toExpression(next));
-    }
-
-    return selectExpr;
-  }
-
-  private static List<Expression> convertOrderByList(SqlNodeList orderList) {
-    List<Expression> orderByExpr = new ArrayList<>();
-    final Iterator<SqlNode> iterator = orderList.iterator();
-    while (iterator.hasNext()) {
-      final SqlNode next = iterator.next();
-      orderByExpr.add(convertOrderBy(next));
-    }
-    return orderByExpr;
-  }
-
-  private static Expression convertOrderBy(SqlNode node) {
-    final SqlKind kind = node.getKind();
-    Expression expression;
-    switch (kind) {
-      case DESCENDING:
-        SqlBasicCall basicCall = (SqlBasicCall) node;
-        expression = RequestUtils.getFunctionExpression("DESC");
-        expression.getFunctionCall().addToOperands(toExpression(basicCall.getOperandList().get(0)));
-        break;
-      case IDENTIFIER:
-      default:
-        expression = RequestUtils.getFunctionExpression("ASC");
-        expression.getFunctionCall().addToOperands(toExpression(node));
-        break;
-    }
-    return expression;
-  }
-
-  /**
-   * DISTINCT is implemented as an aggregation function so need to take the select list items
-   * and convert them into a single function expression for handing over to execution engine
-   * either as a PinotQuery or BrokerRequest via conversion
-   * @param selectList select list items
-   * @return DISTINCT function expression
-   */
-  private static Expression convertDistinctAndSelectListToFunctionExpression(SqlNodeList selectList) {
-    String functionName = AggregationFunctionType.DISTINCT.getName();
-    Expression functionExpression = RequestUtils.getFunctionExpression(functionName);
-    for (SqlNode node : selectList) {
-      Expression columnExpression = toExpression(node);
-      if (columnExpression.getType() == ExpressionType.IDENTIFIER && columnExpression.getIdentifier().getName()
-          .equals("*")) {
-        throw new SqlCompilationException(
-            "Syntax error: Pinot currently does not support DISTINCT with *. Please specify each column name after "
-                + "DISTINCT keyword");
-      } else if (columnExpression.getType() == ExpressionType.FUNCTION) {
-        Function functionCall = columnExpression.getFunctionCall();
-        String function = functionCall.getOperator();
-        if (AggregationFunctionType.isAggregationFunction(function)) {
-          throw new SqlCompilationException(
-              "Syntax error: Use of DISTINCT with aggregation functions is not supported");
-        }
-      }
-      functionExpression.getFunctionCall().addToOperands(columnExpression);
-    }
-    return functionExpression;
-  }
-
-  private static Expression toExpression(SqlNode node) {
-    LOGGER.debug("Current processing SqlNode: {}, node.getKind(): {}", node, node.getKind());
-    switch (node.getKind()) {
-      case IDENTIFIER:
-        if (((SqlIdentifier) node).isStar()) {
-          return RequestUtils.getIdentifierExpression("*");
-        }
-        if (((SqlIdentifier) node).isSimple()) {
-          return RequestUtils.getIdentifierExpression(((SqlIdentifier) node).getSimple());
-        }
-        return RequestUtils.getIdentifierExpression(node.toString());
-      case LITERAL:
-        return RequestUtils.getLiteralExpression((SqlLiteral) node);
-      case AS:
-        SqlBasicCall asFuncSqlNode = (SqlBasicCall) node;
-        List<SqlNode> operands = asFuncSqlNode.getOperandList();
-        Expression leftExpr = toExpression(operands.get(0));
-        SqlNode aliasSqlNode = operands.get(1);
-        String aliasName;
-        switch (aliasSqlNode.getKind()) {
-          case IDENTIFIER:
-            aliasName = ((SqlIdentifier) aliasSqlNode).getSimple();
-            break;
-          case LITERAL:
-            aliasName = ((SqlLiteral) aliasSqlNode).toValue();
-            break;
-          default:
-            throw new SqlCompilationException("Unsupported Alias sql node - " + aliasSqlNode);
-        }
-        Expression rightExpr = RequestUtils.getIdentifierExpression(aliasName);
-        // Just return left identifier if both sides are the same identifier.
-        if (leftExpr.isSetIdentifier() && rightExpr.isSetIdentifier()) {
-          if (leftExpr.getIdentifier().getName().equals(rightExpr.getIdentifier().getName())) {
-            return leftExpr;
-          }
-        }
-        final Expression asFuncExpr = RequestUtils.getFunctionExpression(SqlKind.AS.toString());
-        asFuncExpr.getFunctionCall().addToOperands(leftExpr);
-        asFuncExpr.getFunctionCall().addToOperands(rightExpr);
-        return asFuncExpr;
-      case CASE:
-        // CASE WHEN Statement is model as a function with variable length parameters.
-        // Assume N is number of WHEN Statements, total number of parameters is (2 * N + 1).
-        // - N: Convert each WHEN Statement into a function Expression;
-        // - N: Convert each THEN Statement into an Expression;
-        // - 1: Convert ELSE Statement into an Expression.
-        SqlCase caseSqlNode = (SqlCase) node;
-        SqlNodeList whenOperands = caseSqlNode.getWhenOperands();
-        SqlNodeList thenOperands = caseSqlNode.getThenOperands();
-        SqlNode elseOperand = caseSqlNode.getElseOperand();
-        Expression caseFuncExpr = RequestUtils.getFunctionExpression(SqlKind.CASE.name());
-        for (SqlNode whenSqlNode : whenOperands.getList()) {
-          Expression whenExpression = toExpression(whenSqlNode);
-          if (isAggregateExpression(whenExpression)) {
-            throw new SqlCompilationException(
-                "Aggregation functions inside WHEN Clause is not supported - " + whenSqlNode);
-          }
-          caseFuncExpr.getFunctionCall().addToOperands(whenExpression);
-        }
-        for (SqlNode thenSqlNode : thenOperands.getList()) {
-          Expression thenExpression = toExpression(thenSqlNode);
-          if (isAggregateExpression(thenExpression)) {
-            throw new SqlCompilationException(
-                "Aggregation functions inside THEN Clause is not supported - " + thenSqlNode);
-          }
-          caseFuncExpr.getFunctionCall().addToOperands(thenExpression);
-        }
-        Expression elseExpression = toExpression(elseOperand);
-        if (isAggregateExpression(elseExpression)) {
-          throw new SqlCompilationException(
-              "Aggregation functions inside ELSE Clause is not supported - " + elseExpression);
-        }
-        caseFuncExpr.getFunctionCall().addToOperands(elseExpression);
-        return caseFuncExpr;
-      default:
-        if (node instanceof SqlDataTypeSpec) {
-          // This is to handle expression like: CAST(col AS INT)
-          return RequestUtils.getLiteralExpression(((SqlDataTypeSpec) node).getTypeName().getSimple());
-        } else {
-          return compileFunctionExpression((SqlBasicCall) node);
-        }
-    }
-  }
-
-  private static Expression compileFunctionExpression(SqlBasicCall functionNode) {
-    SqlKind functionKind = functionNode.getKind();
-    String functionName;
-    switch (functionKind) {
-      case AND:
-        return compileAndExpression(functionNode);
-      case OR:
-        return compileOrExpression(functionNode);
-      case COUNT:
-        SqlLiteral functionQuantifier = functionNode.getFunctionQuantifier();
-        if (functionQuantifier != null && functionQuantifier.toValue().equalsIgnoreCase("DISTINCT")) {
-          functionName = AggregationFunctionType.DISTINCTCOUNT.name();
-        } else {
-          functionName = AggregationFunctionType.COUNT.name();
-        }
-        break;
-      case OTHER:
-      case OTHER_FUNCTION:
-      case DOT:
-        functionName = functionNode.getOperator().getName().toUpperCase();
-        if (functionName.equals("ITEM") || functionName.equals("DOT")) {
-          // Calcite parses path expression such as "data[0][1].a.b[0]" into a chain of ITEM and/or DOT
-          // functions. Collapse this chain into an identifier.
-          StringBuffer path = new StringBuffer();
-          compilePathExpression(functionName, functionNode, path);
-          return RequestUtils.getIdentifierExpression(path.toString());
-        }
-        break;
-      default:
-        functionName = functionKind.name();
-        break;
-    }
-    // When there is no argument, set an empty list as the operands
-    List<SqlNode> childNodes = functionNode.getOperandList();
-    List<Expression> operands = new ArrayList<>(childNodes.size());
-    for (SqlNode childNode : childNodes) {
-      if (childNode instanceof SqlNodeList) {
-        for (SqlNode node : (SqlNodeList) childNode) {
-          operands.add(toExpression(node));
-        }
-      } else {
-        operands.add(toExpression(childNode));
-      }
-    }
-    validateFunction(functionName, operands);
-    Expression functionExpression = RequestUtils.getFunctionExpression(functionName);
-    functionExpression.getFunctionCall().setOperands(operands);
-    return functionExpression;
-  }
-
-  /**
-   * Convert Calcite operator tree made up of ITEM and DOT functions to an identifier. For example, the operator tree
-   * shown below will be converted to IDENTIFIER "jsoncolumn.data[0][1].a.b[0]".
-   *
-   * ├── ITEM(jsoncolumn.data[0][1].a.b[0])
-   *      ├── LITERAL (0)
-   *      └── DOT (jsoncolumn.daa[0][1].a.b)
-   *            ├── IDENTIFIER (b)
-   *            └── DOT (jsoncolumn.data[0][1].a)
-   *                  ├── IDENTIFIER (a)
-   *                  └── ITEM (jsoncolumn.data[0][1])
-   *                        ├── LITERAL (1)
-   *                        └── ITEM (jsoncolumn.data[0])
-   *                              ├── LITERAL (1)
-   *                              └── IDENTIFIER (jsoncolumn.data)
-   *
-   * @param functionName Name of the function ("DOT" or "ITEM")
-   * @param functionNode Root node of the DOT and/or ITEM operator function chain.
-   * @param path String representation of path represented by DOT and/or ITEM function chain.
-   */
-  private static void compilePathExpression(String functionName, SqlBasicCall functionNode, StringBuffer path) {
-    List<SqlNode> operands = functionNode.getOperandList();
-
-    // Compile first operand of the function (either an identifier or another DOT and/or ITEM function).
-    SqlKind kind0 = operands.get(0).getKind();
-    if (kind0 == SqlKind.IDENTIFIER) {
-      path.append(operands.get(0).toString());
-    } else if (kind0 == SqlKind.DOT || kind0 == SqlKind.OTHER_FUNCTION) {
-      SqlBasicCall function0 = (SqlBasicCall) operands.get(0);
-      String name0 = function0.getOperator().getName();
-      if (name0.equals("ITEM") || name0.equals("DOT")) {
-        compilePathExpression(name0, function0, path);
-      } else {
-        throw new SqlCompilationException("SELECT list item has bad path expression.");
-      }
-    } else {
-      throw new SqlCompilationException("SELECT list item has bad path expression.");
-    }
-
-    // Compile second operand of the function (either an identifier or literal).
-    SqlKind kind1 = operands.get(1).getKind();
-    if (kind1 == SqlKind.IDENTIFIER) {
-      path.append(".").append(((SqlIdentifier) operands.get(1)).getSimple());
-    } else if (kind1 == SqlKind.LITERAL) {
-      path.append("[").append(((SqlLiteral) operands.get(1)).toValue()).append("]");
-    } else {
-      throw new SqlCompilationException("SELECT list item has bad path expression.");
-    }
-  }
-
-  public static String canonicalize(String functionName) {
-    return StringUtils.remove(functionName, '_').toLowerCase();
-  }
-
-  public static boolean isSameFunction(String function1, String function2) {
-    return canonicalize(function1).equals(canonicalize(function2));
-  }
-
-  private static void validateFunction(String functionName, List<Expression> operands) {
-    switch (canonicalize(functionName)) {
-      case "jsonextractscalar":
-        validateJsonExtractScalarFunction(operands);
-        break;
-      case "jsonextractkey":
-        validateJsonExtractKeyFunction(operands);
-        break;
-      default:
-        break;
-    }
-  }
-
-  private static void validateJsonExtractScalarFunction(List<Expression> operands) {
-    int numOperands = operands.size();
-
-    // Check that there are exactly 3 or 4 arguments
-    if (numOperands != 3 && numOperands != 4) {
-      throw new SqlCompilationException(
-          "Expect 3 or 4 arguments for transform function: jsonExtractScalar(jsonFieldName, 'jsonPath', "
-              + "'resultsType', ['defaultValue'])");
-    }
-    if (!operands.get(1).isSetLiteral() || !operands.get(2).isSetLiteral() || (numOperands == 4 && !operands.get(3)
-        .isSetLiteral())) {
-      throw new SqlCompilationException(
-          "Expect the 2nd/3rd/4th argument of transform function: jsonExtractScalar(jsonFieldName, 'jsonPath',"
-              + " 'resultsType', ['defaultValue']) to be a single-quoted literal value.");
-    }
-  }
-
-  private static void validateJsonExtractKeyFunction(List<Expression> operands) {
-    // Check that there are exactly 2 arguments
-    if (operands.size() != 2) {
-      throw new SqlCompilationException(
-          "Expect 2 arguments are required for transform function: jsonExtractKey(jsonFieldName, 'jsonPath')");
-    }
-    if (!operands.get(1).isSetLiteral()) {
-      throw new SqlCompilationException(
-          "Expect the 2nd argument for transform function: jsonExtractKey(jsonFieldName, 'jsonPath') to be a "
-              + "single-quoted literal value.");
-    }
-  }
-
-  /**
-   * Helper method to flatten the operands for the AND expression.
-   */
-  private static Expression compileAndExpression(SqlBasicCall andNode) {
-    List<Expression> operands = new ArrayList<>();
-    for (SqlNode childNode : andNode.getOperandList()) {
-      if (childNode.getKind() == SqlKind.AND) {
-        Expression childAndExpression = compileAndExpression((SqlBasicCall) childNode);
-        operands.addAll(childAndExpression.getFunctionCall().getOperands());
-      } else {
-        operands.add(toExpression(childNode));
-      }
-    }
-    Expression andExpression = RequestUtils.getFunctionExpression(SqlKind.AND.name());
-    andExpression.getFunctionCall().setOperands(operands);
-    return andExpression;
-  }
-
-  /**
-   * Helper method to flatten the operands for the OR expression.
-   */
-  private static Expression compileOrExpression(SqlBasicCall orNode) {
-    List<Expression> operands = new ArrayList<>();
-    for (SqlNode childNode : orNode.getOperandList()) {
-      if (childNode.getKind() == SqlKind.OR) {
-        Expression childAndExpression = compileOrExpression((SqlBasicCall) childNode);
-        operands.addAll(childAndExpression.getFunctionCall().getOperands());
-      } else {
-        operands.add(toExpression(childNode));
-      }
-    }
-    Expression andExpression = RequestUtils.getFunctionExpression(SqlKind.OR.name());
-    andExpression.getFunctionCall().setOperands(operands);
-    return andExpression;
-  }
-
-  public static boolean isLiteralOnlyExpression(Expression e) {
-    if (e.getType() == ExpressionType.LITERAL) {
-      return true;
-    }
-    if (e.getType() == ExpressionType.FUNCTION) {
-      Function functionCall = e.getFunctionCall();
-      if (functionCall.getOperator().equalsIgnoreCase(SqlKind.AS.toString())) {
-        return isLiteralOnlyExpression(functionCall.getOperands().get(0));
-      }
-      return false;
-    }
-    return false;
-  }
-}
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
new file mode 100644
index 0000000000..91d573c89a
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
@@ -0,0 +1,217 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.parser;
+
+import com.google.common.base.Preconditions;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.pinot.common.request.Expression;
+import org.apache.pinot.common.request.ExpressionType;
+import org.apache.pinot.common.request.Function;
+import org.apache.pinot.common.request.PinotQuery;
+import org.apache.pinot.common.utils.request.RequestUtils;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.segment.spi.AggregationFunctionType;
+import org.apache.pinot.sql.parsers.SqlCompilationException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Calcite parser to convert SQL expressions into {@link Expression}.
+ *
+ * <p>This class is extracted from {@link org.apache.pinot.sql.parsers.CalciteSqlParser}. It contains the logic
+ * to parse {@link org.apache.calcite.rex.RexNode}, in the format of {@link RexExpression}, and convert them into
+ * Thrift {@link Expression} format.
+ */
+public class CalciteRexExpressionParser {
+  private static final Logger LOGGER = LoggerFactory.getLogger(CalciteRexExpressionParser.class);
+
+  private CalciteRexExpressionParser() {
+    // do not instantiate.
+  }
+
+  // --------------------------------------------------------------------------
+  // Relational conversion Utils
+  // --------------------------------------------------------------------------
+
+  public static List<Expression> convertSelectList(List<RexExpression> rexNodeList, PinotQuery pinotQuery) {
+    List<Expression> selectExpr = new ArrayList<>();
+
+    final Iterator<RexExpression> iterator = rexNodeList.iterator();
+    while (iterator.hasNext()) {
+      final RexExpression next = iterator.next();
+      selectExpr.add(toExpression(next, pinotQuery));
+    }
+
+    return selectExpr;
+  }
+
+  private static List<Expression> convertDistinctSelectList(RexExpression.FunctionCall rexCall, PinotQuery pinotQuery) {
+    List<Expression> selectExpr = new ArrayList<>();
+    selectExpr.add(convertDistinctAndSelectListToFunctionExpression(rexCall, pinotQuery));
+    return selectExpr;
+  }
+
+  private static List<Expression> convertOrderByList(RexExpression.FunctionCall rexCall, PinotQuery pinotQuery) {
+    Preconditions.checkState(rexCall.getKind() == SqlKind.ORDER_BY);
+    List<Expression> orderByExpr = new ArrayList<>();
+
+    final Iterator<RexExpression> iterator = rexCall.getFunctionOperands().iterator();
+    while (iterator.hasNext()) {
+      final RexExpression next = iterator.next();
+      orderByExpr.add(convertOrderBy(next, pinotQuery));
+    }
+    return orderByExpr;
+  }
+
+  private static Expression convertOrderBy(RexExpression rexNode, PinotQuery pinotQuery) {
+    final SqlKind kind = rexNode.getKind();
+    Expression expression;
+    switch (kind) {
+      case DESCENDING:
+        RexExpression.FunctionCall rexCall = (RexExpression.FunctionCall) rexNode;
+        expression = RequestUtils.getFunctionExpression("DESC");
+        expression.getFunctionCall().addToOperands(toExpression(rexCall.getFunctionOperands().get(0), pinotQuery));
+        break;
+      case IDENTIFIER:
+      default:
+        expression = RequestUtils.getFunctionExpression("ASC");
+        expression.getFunctionCall().addToOperands(toExpression(rexNode, pinotQuery));
+        break;
+    }
+    return expression;
+  }
+
+  private static Expression convertDistinctAndSelectListToFunctionExpression(RexExpression.FunctionCall rexCall,
+      PinotQuery pinotQuery) {
+    String functionName = AggregationFunctionType.DISTINCT.getName();
+    Expression functionExpression = RequestUtils.getFunctionExpression(functionName);
+    for (RexExpression node : rexCall.getFunctionOperands()) {
+      Expression columnExpression = toExpression(node, pinotQuery);
+      if (columnExpression.getType() == ExpressionType.IDENTIFIER && columnExpression.getIdentifier().getName()
+          .equals("*")) {
+        throw new SqlCompilationException(
+            "Syntax error: Pinot currently does not support DISTINCT with *. Please specify each column name after "
+                + "DISTINCT keyword");
+      } else if (columnExpression.getType() == ExpressionType.FUNCTION) {
+        Function functionCall = columnExpression.getFunctionCall();
+        String function = functionCall.getOperator();
+        if (AggregationFunctionType.isAggregationFunction(function)) {
+          throw new SqlCompilationException(
+              "Syntax error: Use of DISTINCT with aggregation functions is not supported");
+        }
+      }
+      functionExpression.getFunctionCall().addToOperands(columnExpression);
+    }
+    return functionExpression;
+  }
+
+  public static Expression toExpression(RexExpression rexNode, PinotQuery pinotQuery) {
+    LOGGER.debug("Current processing RexNode: {}, node.getKind(): {}", rexNode, rexNode.getKind());
+    switch (rexNode.getKind()) {
+      case INPUT_REF:
+        return inputRefToIdentifier((RexExpression.InputRef) rexNode, pinotQuery);
+      case LITERAL:
+        return rexLiteralToExpression((RexExpression.Literal) rexNode);
+      default:
+        return compileFunctionExpression((RexExpression.FunctionCall) rexNode, pinotQuery);
+    }
+  }
+
+  private static Expression rexLiteralToExpression(RexExpression.Literal rexLiteral) {
+    RelDataType type = rexLiteral.getDataType();
+    switch (type.getSqlTypeName()) {
+      default:
+        return RequestUtils.getLiteralExpression(rexLiteral.getValue());
+    }
+  }
+
+  private static Expression inputRefToIdentifier(RexExpression.InputRef inputRef, PinotQuery pinotQuery) {
+    List<Expression> selectList = pinotQuery.getSelectList();
+    return selectList.get(inputRef.getIndex());
+  }
+
+  private static Expression compileFunctionExpression(RexExpression.FunctionCall rexCall, PinotQuery pinotQuery) {
+    SqlKind functionKind = rexCall.getKind();
+    String functionName;
+    switch (functionKind) {
+      case AND:
+        return compileAndExpression(rexCall, pinotQuery);
+      case OR:
+        return compileOrExpression(rexCall, pinotQuery);
+      case COUNT:
+      case OTHER:
+      case OTHER_FUNCTION:
+      case DOT:
+      default:
+        functionName = functionKind.name();
+        break;
+    }
+    // When there is no argument, set an empty list as the operands
+    List<RexExpression> childNodes = rexCall.getFunctionOperands();
+    List<Expression> operands = new ArrayList<>(childNodes.size());
+    for (RexExpression childNode : childNodes) {
+      operands.add(toExpression(childNode, pinotQuery));
+    }
+    ParserUtils.validateFunction(functionName, operands);
+    Expression functionExpression = RequestUtils.getFunctionExpression(functionName);
+    functionExpression.getFunctionCall().setOperands(operands);
+    return functionExpression;
+  }
+
+  /**
+   * Helper method to flatten the operands for the AND expression.
+   */
+  private static Expression compileAndExpression(RexExpression.FunctionCall andNode, PinotQuery pinotQuery) {
+    List<Expression> operands = new ArrayList<>();
+    for (RexExpression childNode : andNode.getFunctionOperands()) {
+      if (childNode.getKind() == SqlKind.AND) {
+        Expression childAndExpression = compileAndExpression((RexExpression.FunctionCall) childNode, pinotQuery);
+        operands.addAll(childAndExpression.getFunctionCall().getOperands());
+      } else {
+        operands.add(toExpression(childNode, pinotQuery));
+      }
+    }
+    Expression andExpression = RequestUtils.getFunctionExpression(SqlKind.AND.name());
+    andExpression.getFunctionCall().setOperands(operands);
+    return andExpression;
+  }
+
+  /**
+   * Helper method to flatten the operands for the OR expression.
+   */
+  private static Expression compileOrExpression(RexExpression.FunctionCall orNode, PinotQuery pinotQuery) {
+    List<Expression> operands = new ArrayList<>();
+    for (RexExpression childNode : orNode.getFunctionOperands()) {
+      if (childNode.getKind() == SqlKind.OR) {
+        Expression childAndExpression = compileOrExpression((RexExpression.FunctionCall) childNode, pinotQuery);
+        operands.addAll(childAndExpression.getFunctionCall().getOperands());
+      } else {
+        operands.add(toExpression(childNode, pinotQuery));
+      }
+    }
+    Expression andExpression = RequestUtils.getFunctionExpression(SqlKind.OR.name());
+    andExpression.getFunctionCall().setOperands(operands);
+    return andExpression;
+  }
+}
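
To make the data flow concrete, a hedged, self-contained sketch follows (nothing below is from this commit; the column name "col1" and the helper name are hypothetical). It resolves an INPUT_REF against the select list of a PinotQuery and attaches a simple filter:

    import java.util.Arrays;
    import java.util.Collections;
    import org.apache.calcite.rel.type.RelDataTypeFactory;
    import org.apache.calcite.rel.type.RelDataTypeSystem;
    import org.apache.calcite.sql.SqlKind;
    import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
    import org.apache.calcite.sql.type.SqlTypeName;
    import org.apache.pinot.common.request.PinotQuery;
    import org.apache.pinot.common.utils.request.RequestUtils;
    import org.apache.pinot.query.parser.CalciteRexExpressionParser;
    import org.apache.pinot.query.planner.logical.RexExpression;

    // Sketch only: build "col1 > 3" as a RexExpression and convert it into the
    // Thrift Expression that PinotQuery carries as its filter.
    static PinotQuery attachGreaterThanFilter() {
      RelDataTypeFactory typeFactory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
      PinotQuery pinotQuery = new PinotQuery();
      // INPUT_REF operands are resolved against the select list, so populate it first.
      pinotQuery.setSelectList(
          Collections.singletonList(RequestUtils.getIdentifierExpression("col1")));
      RexExpression filter = new RexExpression.FunctionCall(
          SqlKind.GREATER_THAN,
          typeFactory.createSqlType(SqlTypeName.BOOLEAN),
          "GREATER_THAN",
          Arrays.asList(
              new RexExpression.InputRef(0),
              new RexExpression.Literal(
                  typeFactory.createSqlType(SqlTypeName.INTEGER), SqlTypeName.INTEGER, 3)));
      pinotQuery.setFilterExpression(CalciteRexExpressionParser.toExpression(filter, pinotQuery));
      return pinotQuery;
    }
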
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteSqlParser.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteSqlParser.java
index d67896f9b9..d2fd054833 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteSqlParser.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteSqlParser.java
@@ -57,7 +57,7 @@ import org.slf4j.LoggerFactory;
  * This class provide API to parse a SQL string into Pinot query {@link SqlNode}.
  *
  * <p>This class is extracted from {@link org.apache.pinot.sql.parsers.CalciteSqlParser}. It contains the logic
- * to parsed SQL into {@link SqlNode} and use {@link QueryRewriter} to rewrite the query with Pinot specific
+ * to parse a SQL string into {@link SqlNode} and use {@link QueryRewriter} to rewrite the query with Pinot-specific
  * contextual info.
  */
 public class CalciteSqlParser {
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/ParserUtils.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/ParserUtils.java
index 5422382509..ec6494fea2 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/ParserUtils.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/ParserUtils.java
@@ -18,11 +18,15 @@
  */
 package org.apache.pinot.query.parser;
 
+import java.util.List;
 import java.util.regex.Pattern;
 import org.apache.calcite.config.Lex;
 import org.apache.calcite.sql.parser.SqlParser;
 import org.apache.calcite.sql.parser.babel.SqlBabelParserImpl;
 import org.apache.calcite.sql.validate.SqlConformanceEnum;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.pinot.common.request.Expression;
+import org.apache.pinot.sql.parsers.SqlCompilationException;
 
 
 /**
@@ -60,4 +64,55 @@ final class ParserUtils {
   private ParserUtils() {
     // do not instantiate.
   }
+
+  public static String canonicalize(String functionName) {
+    return StringUtils.remove(functionName, '_').toLowerCase();
+  }
+
+  public static boolean isSameFunction(String function1, String function2) {
+    return canonicalize(function1).equals(canonicalize(function2));
+  }
+
+  public static void validateFunction(String functionName, List<Expression> operands) {
+    switch (canonicalize(functionName)) {
+      case "jsonextractscalar":
+        validateJsonExtractScalarFunction(operands);
+        break;
+      case "jsonextractkey":
+        validateJsonExtractKeyFunction(operands);
+        break;
+      default:
+        break;
+    }
+  }
+
+  private static void validateJsonExtractScalarFunction(List<Expression> operands) {
+    int numOperands = operands.size();
+
+    // Check that there are exactly 3 or 4 arguments
+    if (numOperands != 3 && numOperands != 4) {
+      throw new SqlCompilationException(
+          "Expect 3 or 4 arguments for transform function: jsonExtractScalar(jsonFieldName, 'jsonPath', "
+              + "'resultsType', ['defaultValue'])");
+    }
+    if (!operands.get(1).isSetLiteral() || !operands.get(2).isSetLiteral() || (numOperands == 4 && !operands.get(3)
+        .isSetLiteral())) {
+      throw new SqlCompilationException(
+          "Expect the 2nd/3rd/4th argument of transform function: jsonExtractScalar(jsonFieldName, 'jsonPath',"
+              + " 'resultsType', ['defaultValue']) to be a single-quoted literal value.");
+    }
+  }
+
+  private static void validateJsonExtractKeyFunction(List<Expression> operands) {
+    // Check that there are exactly 2 arguments
+    if (operands.size() != 2) {
+      throw new SqlCompilationException(
+          "Expect 2 arguments are required for transform function: jsonExtractKey(jsonFieldName, 'jsonPath')");
+    }
+    if (!operands.get(1).isSetLiteral()) {
+      throw new SqlCompilationException(
+          "Expect the 2nd argument for transform function: jsonExtractKey(jsonFieldName, 'jsonPath') to be a "
+              + "single-quoted literal value.");
+    }
+  }
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/QueryPlan.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/QueryPlan.java
index cb075bb9f3..a93bddc4f7 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/QueryPlan.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/QueryPlan.java
@@ -19,7 +19,8 @@
 package org.apache.pinot.query.planner;
 
 import java.util.Map;
-import org.apache.pinot.query.planner.nodes.StageNode;
+import org.apache.pinot.query.planner.logical.LogicalPlanner;
+import org.apache.pinot.query.planner.stage.StageNode;
 
 
 /**
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/StageMetadata.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/StageMetadata.java
index 90dd22a8aa..8e691003c9 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/StageMetadata.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/StageMetadata.java
@@ -24,8 +24,8 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import org.apache.pinot.core.transport.ServerInstance;
-import org.apache.pinot.query.planner.nodes.StageNode;
-import org.apache.pinot.query.planner.nodes.TableScanNode;
+import org.apache.pinot.query.planner.stage.StageNode;
+import org.apache.pinot.query.planner.stage.TableScanNode;
 
 
 /**
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/CalcNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/hints/PinotRelationalHints.java
similarity index 62%
copy from pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/CalcNode.java
copy to pinot-query-planner/src/main/java/org/apache/pinot/query/planner/hints/PinotRelationalHints.java
index 0aa8c94ec8..19a9daa54f 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/CalcNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/hints/PinotRelationalHints.java
@@ -16,23 +16,19 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.planner.nodes;
+package org.apache.pinot.query.planner.hints;
 
+import org.apache.calcite.rel.hint.RelHint;
 
 
-public class CalcNode extends AbstractStageNode {
-  private String _expression;
-
-  public CalcNode(int stageId) {
-    super(stageId);
-  }
-
-  public CalcNode(int stageId, String expression) {
-    super(stageId);
-    _expression = expression;
-  }
+/**
+ * Provide certain relational hint to query planner for better optimization.
+ */
+public class PinotRelationalHints {
+  public static final RelHint USE_HASH_DISTRIBUTE = RelHint.builder("USE_HASH_DISTRIBUTE").build();
+  public static final RelHint USE_BROADCAST_DISTRIBUTE = RelHint.builder("USE_BROADCAST_DISTRIBUTE").build();
 
-  public String getExpression() {
-    return _expression;
+  private PinotRelationalHints() {
+    // do not instantiate.
   }
 }
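
These hints are plain RelHint constants. For context, a hedged sketch of how one might be attached to a join, assuming Calcite's Hintable interface (the helper method is made up and not part of this commit):

    import java.util.Collections;
    import org.apache.calcite.rel.logical.LogicalJoin;
    import org.apache.pinot.query.planner.hints.PinotRelationalHints;

    // Sketch only: LogicalJoin implements Hintable, so a planner rule could tag it
    // with a distribution strategy that a later stage-planning step reads via getHints().
    static LogicalJoin withHashDistribute(LogicalJoin join) {
      return (LogicalJoin) join.withHints(
          Collections.singletonList(PinotRelationalHints.USE_HASH_DISTRIBUTE));
    }
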
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/LogicalPlanner.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/LogicalPlanner.java
similarity index 97%
rename from pinot-query-planner/src/main/java/org/apache/pinot/query/planner/LogicalPlanner.java
rename to pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/LogicalPlanner.java
index 9844916490..0e317560e9 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/LogicalPlanner.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/LogicalPlanner.java
@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.planner;
+package org.apache.pinot.query.planner.logical;
 
 import java.util.ArrayList;
 import java.util.List;
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/RelToStageConverter.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java
similarity index 75%
rename from pinot-query-planner/src/main/java/org/apache/pinot/query/planner/RelToStageConverter.java
rename to pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java
index 572302ef92..3750437f7f 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/RelToStageConverter.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java
@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.planner;
+package org.apache.pinot.query.planner.logical;
 
 import com.google.common.base.Preconditions;
 import java.util.Collections;
@@ -24,19 +24,21 @@ import java.util.List;
 import java.util.stream.Collectors;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.JoinRelType;
-import org.apache.calcite.rel.logical.LogicalCalc;
+import org.apache.calcite.rel.logical.LogicalFilter;
 import org.apache.calcite.rel.logical.LogicalJoin;
+import org.apache.calcite.rel.logical.LogicalProject;
 import org.apache.calcite.rel.logical.LogicalTableScan;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rel.type.RelDataTypeField;
 import org.apache.calcite.rex.RexCall;
 import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.sql.SqlKind;
-import org.apache.pinot.query.planner.nodes.CalcNode;
-import org.apache.pinot.query.planner.nodes.JoinNode;
-import org.apache.pinot.query.planner.nodes.StageNode;
-import org.apache.pinot.query.planner.nodes.TableScanNode;
 import org.apache.pinot.query.planner.partitioning.FieldSelectionKeySelector;
+import org.apache.pinot.query.planner.stage.FilterNode;
+import org.apache.pinot.query.planner.stage.JoinNode;
+import org.apache.pinot.query.planner.stage.ProjectNode;
+import org.apache.pinot.query.planner.stage.StageNode;
+import org.apache.pinot.query.planner.stage.TableScanNode;
 
 
 /**
@@ -57,27 +59,32 @@ public final class RelToStageConverter {
    * @return stage node.
    */
   public static StageNode toStageNode(RelNode node, int currentStageId) {
-    if (node instanceof LogicalCalc) {
-      return convertLogicalCal((LogicalCalc) node, currentStageId);
-    } else if (node instanceof LogicalTableScan) {
+    if (node instanceof LogicalTableScan) {
       return convertLogicalTableScan((LogicalTableScan) node, currentStageId);
     } else if (node instanceof LogicalJoin) {
       return convertLogicalJoin((LogicalJoin) node, currentStageId);
+    } else if (node instanceof LogicalProject) {
+      return convertLogicalProject((LogicalProject) node, currentStageId);
+    } else if (node instanceof LogicalFilter) {
+      return convertLogicalFilter((LogicalFilter) node, currentStageId);
     } else {
       throw new UnsupportedOperationException("Unsupported logical plan node: " + node);
     }
   }
 
+  private static StageNode convertLogicalProject(LogicalProject node, int currentStageId) {
+    return new ProjectNode(currentStageId, node.getRowType(), node.getProjects());
+  }
+
+  private static StageNode convertLogicalFilter(LogicalFilter node, int currentStageId) {
+    return new FilterNode(currentStageId, node.getRowType(), node.getCondition());
+  }
+
   private static StageNode convertLogicalTableScan(LogicalTableScan node, int currentStageId) {
     String tableName = node.getTable().getQualifiedName().get(0);
     List<String> columnNames = node.getRowType().getFieldList().stream()
         .map(RelDataTypeField::getName).collect(Collectors.toList());
-    return new TableScanNode(currentStageId, tableName, columnNames);
-  }
-
-  private static StageNode convertLogicalCal(LogicalCalc node, int currentStageId) {
-    // TODO: support actual calcNode
-    return new CalcNode(currentStageId, node.getDigest());
+    return new TableScanNode(currentStageId, node.getRowType(), tableName, columnNames);
   }
 
   private static StageNode convertLogicalJoin(LogicalJoin node, int currentStageId) {
@@ -95,7 +102,7 @@ public final class RelToStageConverter {
     FieldSelectionKeySelector leftFieldSelectionKeySelector = new FieldSelectionKeySelector(leftOperandIndex);
     FieldSelectionKeySelector rightFieldSelectionKeySelector =
           new FieldSelectionKeySelector(rightOperandIndex - leftRowType.getFieldNames().size());
-    return new JoinNode(currentStageId, joinType, Collections.singletonList(new JoinNode.JoinClause(
+    return new JoinNode(currentStageId, node.getRowType(), joinType, Collections.singletonList(new JoinNode.JoinClause(
         leftFieldSelectionKeySelector, rightFieldSelectionKeySelector)));
   }
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java
new file mode 100644
index 0000000000..899636ab4d
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java
@@ -0,0 +1,157 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner.logical;
+
+import java.math.BigDecimal;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexLiteral;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.util.NlsString;
+import org.apache.pinot.query.planner.serde.ProtoProperties;
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+
+/**
+ * {@code RexExpression} is the serializable format of the {@link RexNode}.
+ */
+public abstract class RexExpression {
+  @ProtoProperties
+  protected SqlKind _sqlKind;
+  @ProtoProperties
+  protected RelDataType _dataType;
+
+  public SqlKind getKind() {
+    return _sqlKind;
+  }
+
+  public RelDataType getDataType() {
+    return _dataType;
+  }
+
+  public static RexExpression toRexExpression(RexNode rexNode) {
+    if (rexNode instanceof RexInputRef) {
+      return new RexExpression.InputRef(((RexInputRef) rexNode).getIndex());
+    } else if (rexNode instanceof RexLiteral) {
+      RexLiteral rexLiteral = ((RexLiteral) rexNode);
+      return new RexExpression.Literal(rexLiteral.getType(), rexLiteral.getTypeName(), rexLiteral.getValue());
+    } else if (rexNode instanceof RexCall) {
+      RexCall rexCall = (RexCall) rexNode;
+      List<RexExpression> operands = rexCall.getOperands().stream().map(RexExpression::toRexExpression)
+          .collect(Collectors.toList());
+      return new RexExpression.FunctionCall(rexCall.getKind(), rexCall.getType(), rexCall.getOperator().getName(),
+          operands);
+    } else {
+      throw new IllegalArgumentException("Unsupported RexNode type with SqlKind: " + rexNode.getKind());
+    }
+  }
+
+  private static Comparable convertLiteral(Comparable value, SqlTypeName sqlTypeName, RelDataType dataType) {
+    switch (sqlTypeName) {
+      case BOOLEAN:
+        return (boolean) value;
+      case DECIMAL:
+        switch (dataType.getSqlTypeName()) {
+          case INTEGER:
+            return ((BigDecimal) value).intValue();
+          case BIGINT:
+            return ((BigDecimal) value).longValue();
+          case FLOAT:
+            return ((BigDecimal) value).floatValue();
+          case DOUBLE:
+          default:
+            return ((BigDecimal) value).doubleValue();
+        }
+      case CHAR:
+        switch (dataType.getSqlTypeName()) {
+          case VARCHAR:
+            return ((NlsString) value).getValue();
+          default:
+            return value;
+        }
+      default:
+        return value;
+    }
+  }
+
+  public static class InputRef extends RexExpression {
+    @ProtoProperties
+    private int _index;
+
+    public InputRef() {
+    }
+
+    public InputRef(int index) {
+      _sqlKind = SqlKind.INPUT_REF;
+      _index = index;
+    }
+
+    public int getIndex() {
+      return _index;
+    }
+  }
+
+  public static class Literal extends RexExpression {
+    @ProtoProperties
+    private Object _value;
+
+    public Literal() {
+    }
+
+    public Literal(RelDataType dataType, SqlTypeName sqlTypeName, @Nullable Comparable value) {
+      _sqlKind = SqlKind.LITERAL;
+      _dataType = dataType;
+      _value = convertLiteral(value, sqlTypeName, dataType);
+    }
+
+    public Object getValue() {
+      return _value;
+    }
+  }
+
+  public static class FunctionCall extends RexExpression {
+    @ProtoProperties
+    private String _functionName;
+    @ProtoProperties
+    private List<RexExpression> _functionOperands;
+
+    public FunctionCall() {
+    }
+
+    public FunctionCall(SqlKind sqlKind, RelDataType type, String functionName, List<RexExpression> functionOperands) {
+      _sqlKind = sqlKind;
+      _dataType = type;
+      _functionName = functionName;
+      _functionOperands = functionOperands;
+    }
+
+    public String getFunctionName() {
+      return _functionName;
+    }
+
+    public List<RexExpression> getFunctionOperands() {
+      return _functionOperands;
+    }
+  }
+}
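
A minimal sketch of how a Calcite condition such as col3 >= 0 would be turned into this
serializable form (not part of the commit; the RexBuilder setup and the field index are
assumptions made purely for illustration):

    import java.math.BigDecimal;
    import org.apache.calcite.rel.type.RelDataTypeFactory;
    import org.apache.calcite.rel.type.RelDataTypeSystem;
    import org.apache.calcite.rex.RexBuilder;
    import org.apache.calcite.rex.RexNode;
    import org.apache.calcite.sql.fun.SqlStdOperatorTable;
    import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
    import org.apache.calcite.sql.type.SqlTypeName;
    import org.apache.pinot.query.planner.logical.RexExpression;

    RelDataTypeFactory typeFactory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
    RexBuilder rexBuilder = new RexBuilder(typeFactory);
    // $2 >= 0, i.e. a predicate on the third column of the input row
    RexNode condition = rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL,
        rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.INTEGER), 2),
        rexBuilder.makeExactLiteral(BigDecimal.ZERO));

    // yields a FunctionCall(">=") whose operands are [InputRef(2), Literal(0)]
    RexExpression expr = RexExpression.toRexExpression(condition);
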
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/StagePlanner.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java
similarity index 79%
rename from pinot-query-planner/src/main/java/org/apache/pinot/query/planner/StagePlanner.java
rename to pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java
index f6ec38738c..b5dfab8689 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/StagePlanner.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java
@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.planner;
+package org.apache.pinot.query.planner.logical;
 
 import java.util.HashMap;
 import java.util.List;
@@ -26,9 +26,12 @@ import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.RelRoot;
 import org.apache.calcite.rel.logical.LogicalExchange;
 import org.apache.pinot.query.context.PlannerContext;
-import org.apache.pinot.query.planner.nodes.MailboxReceiveNode;
-import org.apache.pinot.query.planner.nodes.MailboxSendNode;
-import org.apache.pinot.query.planner.nodes.StageNode;
+import org.apache.pinot.query.planner.QueryPlan;
+import org.apache.pinot.query.planner.StageMetadata;
+import org.apache.pinot.query.planner.partitioning.FieldSelectionKeySelector;
+import org.apache.pinot.query.planner.stage.MailboxReceiveNode;
+import org.apache.pinot.query.planner.stage.MailboxSendNode;
+import org.apache.pinot.query.planner.stage.StageNode;
 import org.apache.pinot.query.routing.WorkerManager;
 
 
@@ -66,12 +69,13 @@ public class StagePlanner {
     // walk the plan and create stages.
     StageNode globalStageRoot = walkRelPlan(relRoot, getNewStageId());
 
-    // global root needs to send results back to the ROOT, a.k.a. the client response node.
-    // the last stage is always a broadcast-gather.
+    // global root needs to send results back to the ROOT, a.k.a. the client response node. the last stage only has
+    // one receiver, so the exchange type does not matter; default it to SINGLETON.
     StageNode globalReceiverNode =
-        new MailboxReceiveNode(0, globalStageRoot.getStageId(), RelDistribution.Type.BROADCAST_DISTRIBUTED);
-    StageNode globalSenderNode = new MailboxSendNode(globalStageRoot.getStageId(), globalReceiverNode.getStageId(),
-        RelDistribution.Type.BROADCAST_DISTRIBUTED);
+        new MailboxReceiveNode(0, relRoot.getRowType(), globalStageRoot.getStageId(),
+            RelDistribution.Type.SINGLETON);
+    StageNode globalSenderNode = new MailboxSendNode(globalStageRoot.getStageId(), relRoot.getRowType(),
+        globalReceiverNode.getStageId(), RelDistribution.Type.SINGLETON);
     globalSenderNode.addInput(globalStageRoot);
     _queryStageMap.put(globalSenderNode.getStageId(), globalSenderNode);
     StageMetadata stageMetadata = _stageMetadataMap.get(globalSenderNode.getStageId());
@@ -95,12 +99,15 @@ public class StagePlanner {
     if (isExchangeNode(node)) {
       // 1. exchangeNode always have only one input, get its input converted as a new stage root.
       StageNode nextStageRoot = walkRelPlan(node.getInput(0), getNewStageId());
-      RelDistribution.Type exchangeType = ((LogicalExchange) node).distribution.getType();
+      RelDistribution distribution = ((LogicalExchange) node).getDistribution();
+      RelDistribution.Type exchangeType = distribution.getType();
 
       // 2. make an exchange sender and receiver node pair
-      StageNode mailboxReceiver = new MailboxReceiveNode(currentStageId, nextStageRoot.getStageId(), exchangeType);
-      StageNode mailboxSender = new MailboxSendNode(nextStageRoot.getStageId(), mailboxReceiver.getStageId(),
+      StageNode mailboxReceiver = new MailboxReceiveNode(currentStageId, node.getRowType(), nextStageRoot.getStageId(),
           exchangeType);
+      StageNode mailboxSender = new MailboxSendNode(nextStageRoot.getStageId(), node.getRowType(),
+          mailboxReceiver.getStageId(), exchangeType, exchangeType == RelDistribution.Type.HASH_DISTRIBUTED
+          ? new FieldSelectionKeySelector(distribution.getKeys().get(0)) : null);
       mailboxSender.addInput(nextStageRoot);
 
       // 3. put the sender side as a completed stage.
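
In short, only a hash-distributed exchange carries a partition key selector. A rough
restatement of the new logic above, with the ternary pulled out for readability (variable
names follow the hunk, nothing new is introduced):

    RelDistribution distribution = ((LogicalExchange) node).getDistribution();
    RelDistribution.Type exchangeType = distribution.getType();
    // distribution.getKeys() lists the field indexes the rows are hashed on;
    // only the first key is used here, i.e. single-column partitioning.
    KeySelector<Object[], Object> keySelector =
        exchangeType == RelDistribution.Type.HASH_DISTRIBUTED
            ? new FieldSelectionKeySelector(distribution.getKeys().get(0))
            : null;
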
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/partitioning/FieldSelectionKeySelector.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/partitioning/FieldSelectionKeySelector.java
index 95991d558b..14f263c44f 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/partitioning/FieldSelectionKeySelector.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/partitioning/FieldSelectionKeySelector.java
@@ -18,12 +18,15 @@
  */
 package org.apache.pinot.query.planner.partitioning;
 
+import org.apache.pinot.query.planner.serde.ProtoProperties;
+
 
 /**
  * The {@code FieldSelectionKeySelector} simply extracts a column value out of a row array {@link Object[]}.
  */
 public class FieldSelectionKeySelector implements KeySelector<Object[], Object> {
 
+  @ProtoProperties
   private int _columnIndex;
 
   public FieldSelectionKeySelector() {
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/partitioning/KeySelector.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/partitioning/KeySelector.java
index eaefb77604..e6b6e598a2 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/partitioning/KeySelector.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/partitioning/KeySelector.java
@@ -18,8 +18,6 @@
  */
 package org.apache.pinot.query.planner.partitioning;
 
-import java.io.Serializable;
-
 
 /**
  * The {@code KeySelector} provides a partitioning function to encode a specific input data type into a key.
@@ -28,7 +26,7 @@ import java.io.Serializable;
  *
  * <p>Key selector should always produce the same selection hash key when the same input is provided.
  */
-public interface KeySelector<IN, OUT> extends Serializable {
+public interface KeySelector<IN, OUT> {
 
   /**
    * Extract the key out of an input data construct.
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/serde/ProtoSerializable.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoProperties.java
similarity index 63%
copy from pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/serde/ProtoSerializable.java
copy to pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoProperties.java
index 2b99003e87..5a10b91941 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/serde/ProtoSerializable.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoProperties.java
@@ -22,14 +22,19 @@
  * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
  *  @generated
  */
-package org.apache.pinot.query.planner.nodes.serde;
+package org.apache.pinot.query.planner.serde;
 
-import org.apache.pinot.common.proto.Plan;
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
 
 
-public interface ProtoSerializable {
-
-  void setObjectField(Plan.ObjectField objFields);
-
-  Plan.ObjectField getObjectField();
+/**
+ * Annotation {@code ProtoProperties} indicates that a field defined in a
+ * {@link org.apache.pinot.query.planner.stage.StageNode} should be serialized.
+ */
+@Target({ElementType.ANNOTATION_TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER})
+@Retention(RetentionPolicy.RUNTIME)
+public @interface ProtoProperties {
 }
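
As a usage sketch (the node below is hypothetical and not part of this commit), only
annotated fields participate in the proto serde:

    public class ExampleStageNode extends AbstractStageNode {
      @ProtoProperties
      private String _tableName;    // serialized into the Plan.ObjectField payload

      private int _scratchCounter;  // no annotation, skipped by the serde

      public ExampleStageNode(int stageId) {
        super(stageId);
      }
    }
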
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/serde/ProtoSerializable.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoSerializable.java
similarity index 62%
rename from pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/serde/ProtoSerializable.java
rename to pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoSerializable.java
index 2b99003e87..f10cb9dd35 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/serde/ProtoSerializable.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoSerializable.java
@@ -22,14 +22,28 @@
  * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
  *  @generated
  */
-package org.apache.pinot.query.planner.nodes.serde;
+package org.apache.pinot.query.planner.serde;
 
 import org.apache.pinot.common.proto.Plan;
 
 
+/**
+ * Interface to convert between a proto-serialized payload and an object.
+ *
+ * <p>Classes that implement {@code ProtoSerializable} should provide methods to convert to and from
+ * {@link Plan.ObjectField}.
+ */
 public interface ProtoSerializable {
 
-  void setObjectField(Plan.ObjectField objFields);
+  /**
+   * Sets the object's member variables from a serialized {@link Plan.ObjectField}.
+   * @param objFields the serialized ObjectField.
+   */
+  void fromObjectField(Plan.ObjectField objFields);
 
-  Plan.ObjectField getObjectField();
+  /**
+   * Converts the object to a serialized {@link Plan.ObjectField}.
+   * @return the serialized ObjectField.
+   */
+  Plan.ObjectField toObjectField();
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/serde/ProtoSerializationUtils.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoSerializationUtils.java
similarity index 61%
rename from pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/serde/ProtoSerializationUtils.java
rename to pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoSerializationUtils.java
index c30295a101..d0d3b1dc3e 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/serde/ProtoSerializationUtils.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoSerializationUtils.java
@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.planner.nodes.serde;
+package org.apache.pinot.query.planner.serde;
 
 import com.google.common.base.Preconditions;
 import java.lang.reflect.Field;
@@ -28,48 +28,75 @@ import java.util.Set;
 import org.apache.pinot.common.proto.Plan;
 
 
+/**
+ * Utils to automatically convert objects implementing {@link ProtoSerializable} to and from their proto form.
+ */
 @SuppressWarnings({"rawtypes", "unchecked"})
 public class ProtoSerializationUtils {
   private static final String ENUM_VALUE_KEY = "ENUM_VALUE_KEY";
+  private static final String NULL_OBJECT_CLASSNAME = "null";
+  private static final Plan.ObjectField NULL_OBJECT_VALUE = Plan.ObjectField.newBuilder()
+      .setObjectClassName(NULL_OBJECT_CLASSNAME).build();
 
   private ProtoSerializationUtils() {
     // do not instantiate.
   }
 
-  public static void fromObjectField(Object object, Plan.ObjectField objectField) {
+  /**
+   * Reflectively sets the object's fields based on the provided {@link Plan.ObjectField}.
+   *
+   * @param object the object to be set.
+   * @param objectField the proto ObjectField from which the object will be set.
+   */
+  public static void setObjectFieldToObject(Object object, Plan.ObjectField objectField) {
     Map<String, Plan.MemberVariableField> memberVariablesMap = objectField.getMemberVariablesMap();
-    try {
-      for (Map.Entry<String, Plan.MemberVariableField> e : memberVariablesMap.entrySet()) {
-        Object memberVarObject = constructMemberVariable(e.getValue());
-        if (memberVarObject != null) {
-          Field declaredField = object.getClass().getDeclaredField(e.getKey());
-          declaredField.setAccessible(true);
-          declaredField.set(object, memberVarObject);
+    for (Map.Entry<String, Plan.MemberVariableField> e : memberVariablesMap.entrySet()) {
+      try {
+        Field declaredField = object.getClass().getDeclaredField(e.getKey());
+        if (declaredField.isAnnotationPresent(ProtoProperties.class)) {
+          Object memberVarObject = constructMemberVariable(e.getValue());
+          if (memberVarObject != null) {
+            declaredField.setAccessible(true);
+            declaredField.set(object, memberVarObject);
+          }
         }
+      } catch (NoSuchFieldException | IllegalAccessException ex) {
+        throw new IllegalStateException("Unable to set Object " + object.getClass() + " on field " + e.getKey()
+            + "with object of type: " + objectField.getObjectClassName(), ex);
       }
-    } catch (NoSuchFieldException | IllegalAccessException e) {
-      throw new IllegalStateException("Unable to set Object field for: " + objectField.getObjectClassName(), e);
     }
   }
 
-  public static Plan.ObjectField toObjectField(Object object) {
-    Plan.ObjectField.Builder builder = Plan.ObjectField.newBuilder();
-    builder.setObjectClassName(object.getClass().getName());
-    // special handling for enum
-    if (object instanceof Enum) {
-      builder.putMemberVariables(ENUM_VALUE_KEY, serializeMemberVariable(((Enum) object).name()));
-    } else {
-      try {
-        for (Field field : object.getClass().getDeclaredFields()) {
-          field.setAccessible(true);
-          Object fieldObject = field.get(object);
-          builder.putMemberVariables(field.getName(), serializeMemberVariable(fieldObject));
+  /**
+   * Converts an object into a proto {@link Plan.ObjectField}.
+   *
+   * @param object object to be converted.
+   * @return the converted proto ObjectField.
+   */
+  public static Plan.ObjectField convertObjectToObjectField(Object object) {
+    if (object != null) {
+      Plan.ObjectField.Builder builder = Plan.ObjectField.newBuilder();
+      builder.setObjectClassName(object.getClass().getName());
+      // special handling for enum
+      if (object instanceof Enum) {
+        builder.putMemberVariables(ENUM_VALUE_KEY, serializeMemberVariable(((Enum) object).name()));
+      } else {
+        try {
+          for (Field field : object.getClass().getDeclaredFields()) {
+            if (field.isAnnotationPresent(ProtoProperties.class)) {
+              field.setAccessible(true);
+              Object fieldObject = field.get(object);
+              builder.putMemberVariables(field.getName(), serializeMemberVariable(fieldObject));
+            }
+          }
+        } catch (IllegalAccessException e) {
+          throw new IllegalStateException("Unable to serialize Object: " + object.getClass(), e);
         }
-      } catch (IllegalAccessException e) {
-        throw new IllegalStateException("Unable to serialize Object: " + object.getClass(), e);
       }
+      return builder.build();
+    } else {
+      return NULL_OBJECT_VALUE;
     }
-    return builder.build();
   }
 
   // --------------------------------------------------------------------------
@@ -88,6 +115,10 @@ public class ProtoSerializationUtils {
     return Plan.LiteralField.newBuilder().setLongField(val).build();
   }
 
+  private static Plan.LiteralField floatField(float val) {
+    return Plan.LiteralField.newBuilder().setFloatField(val).build();
+  }
+
   private static Plan.LiteralField doubleField(double val) {
     return Plan.LiteralField.newBuilder().setDoubleField(val).build();
   }
@@ -104,6 +135,8 @@ public class ProtoSerializationUtils {
       builder.setLiteralField(intField((Integer) fieldObject));
     } else if (fieldObject instanceof Long) {
       builder.setLiteralField(longField((Long) fieldObject));
+    } else if (fieldObject instanceof Float) {
+      builder.setLiteralField(floatField((Float) fieldObject));
     } else if (fieldObject instanceof Double) {
       builder.setLiteralField(doubleField((Double) fieldObject));
     } else if (fieldObject instanceof String) {
@@ -113,7 +146,7 @@ public class ProtoSerializationUtils {
     } else if (fieldObject instanceof Map) {
       builder.setMapField(serializeMapMemberVariable(fieldObject));
     } else {
-      builder.setObjectField(toObjectField(fieldObject));
+      builder.setObjectField(convertObjectToObjectField(fieldObject));
     }
     return builder.build();
   }
@@ -165,6 +198,8 @@ public class ProtoSerializationUtils {
         return literalField.getIntField();
       case LONGFIELD:
         return literalField.getLongField();
+      case FLOATFIELD:
+        return literalField.getFloatField();
       case DOUBLEFIELD:
         return literalField.getDoubleField();
       case STRINGFIELD:
@@ -183,7 +218,7 @@ public class ProtoSerializationUtils {
     return list;
   }
 
-  private static Object constructMap(Plan.MapField mapField) {
+  private static Map constructMap(Plan.MapField mapField) {
     Map map = new HashMap();
     for (Map.Entry<String, Plan.MemberVariableField> e : mapField.getContentMap().entrySet()) {
       map.put(e.getKey(), constructMemberVariable(e.getValue()));
@@ -192,18 +227,22 @@ public class ProtoSerializationUtils {
   }
 
   private static Object constructObject(Plan.ObjectField objectField) {
-    try {
-      Class<?> clazz = Class.forName(objectField.getObjectClassName());
-      if (clazz.isEnum()) {
-        return Enum.valueOf((Class<Enum>) clazz,
-            objectField.getMemberVariablesOrDefault(ENUM_VALUE_KEY, null).getLiteralField().getStringField());
-      } else {
-        Object obj = clazz.newInstance();
-        fromObjectField(obj, objectField);
-        return obj;
+    if (!NULL_OBJECT_CLASSNAME.equals(objectField.getObjectClassName())) {
+      try {
+        Class<?> clazz = Class.forName(objectField.getObjectClassName());
+        if (clazz.isEnum()) {
+          return Enum.valueOf((Class<Enum>) clazz,
+              objectField.getMemberVariablesOrDefault(ENUM_VALUE_KEY, null).getLiteralField().getStringField());
+        } else {
+          Object obj = clazz.newInstance();
+          setObjectFieldToObject(obj, objectField);
+          return obj;
+        }
+      } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) {
+        throw new IllegalStateException("Unable to create Object of type: " + objectField.getObjectClassName(), e);
       }
-    } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) {
-      throw new IllegalStateException("Unable to create Object of type: " + objectField.getObjectClassName(), e);
+    } else {
+      return null;
     }
   }
 }
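
A minimal sketch of the reflective round trip, using a made-up value holder (the class
and its field values are illustrative only):

    public class Holder {
      @ProtoProperties
      int _count;
      @ProtoProperties
      String _name;
    }

    Holder in = new Holder();
    in._count = 3;
    in._name = "stage";
    Plan.ObjectField payload = ProtoSerializationUtils.convertObjectToObjectField(in);

    Holder out = new Holder();
    ProtoSerializationUtils.setObjectFieldToObject(out, payload);
    // out._count == 3 and "stage".equals(out._name)
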
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/AbstractStageNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AbstractStageNode.java
similarity index 67%
rename from pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/AbstractStageNode.java
rename to pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AbstractStageNode.java
index ed1fc9ba3e..bdcccea355 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/AbstractStageNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AbstractStageNode.java
@@ -16,19 +16,25 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.planner.nodes;
+package org.apache.pinot.query.planner.stage;
 
 import java.util.ArrayList;
 import java.util.List;
+import org.apache.calcite.rel.type.RelDataType;
 import org.apache.pinot.common.proto.Plan;
-import org.apache.pinot.query.planner.nodes.serde.ProtoSerializable;
-import org.apache.pinot.query.planner.nodes.serde.ProtoSerializationUtils;
+import org.apache.pinot.query.planner.serde.ProtoProperties;
+import org.apache.pinot.query.planner.serde.ProtoSerializable;
+import org.apache.pinot.query.planner.serde.ProtoSerializationUtils;
 
 
 public abstract class AbstractStageNode implements StageNode, ProtoSerializable {
 
+  @ProtoProperties
   protected final int _stageId;
+  @ProtoProperties
   protected final List<StageNode> _inputs;
+  @ProtoProperties
+  protected RelDataType _rowType;
 
   public AbstractStageNode(int stageId) {
     _stageId = stageId;
@@ -51,12 +57,16 @@ public abstract class AbstractStageNode implements StageNode, ProtoSerializable
   }
 
   @Override
-  public void setObjectField(Plan.ObjectField objectField) {
-    ProtoSerializationUtils.fromObjectField(this, objectField);
+  public void fromObjectField(Plan.ObjectField objectField) {
+    ProtoSerializationUtils.setObjectFieldToObject(this, objectField);
   }
 
   @Override
-  public Plan.ObjectField getObjectField() {
-    return ProtoSerializationUtils.toObjectField(this);
+  public Plan.ObjectField toObjectField() {
+    return ProtoSerializationUtils.convertObjectToObjectField(this);
+  }
+
+  public RelDataType getRowType() {
+    return _rowType;
   }
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/CalcNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/FilterNode.java
similarity index 56%
rename from pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/CalcNode.java
rename to pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/FilterNode.java
index 0aa8c94ec8..52df4ed5d3 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/CalcNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/FilterNode.java
@@ -16,23 +16,29 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.planner.nodes;
+package org.apache.pinot.query.planner.stage;
 
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexNode;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.planner.serde.ProtoProperties;
 
 
-public class CalcNode extends AbstractStageNode {
-  private String _expression;
+public class FilterNode extends AbstractStageNode {
+  @ProtoProperties
+  private RexExpression _condition;
 
-  public CalcNode(int stageId) {
+  public FilterNode(int stageId) {
     super(stageId);
   }
 
-  public CalcNode(int stageId, String expression) {
-    super(stageId);
-    _expression = expression;
+  public FilterNode(int currentStageId, RelDataType rowType, RexNode condition) {
+    super(currentStageId);
+    super._rowType = rowType;
+    _condition = RexExpression.toRexExpression(condition);
   }
 
-  public String getExpression() {
-    return _expression;
+  public RexExpression getCondition() {
+    return _condition;
   }
 }
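
Tying this back to the RexExpression sketch earlier, a filter stage for WHERE col3 >= 0
would be created roughly as follows (stageId, rowType and condition are the assumed
placeholders from that sketch):

    // the Calcite RexNode condition is converted to a serializable
    // RexExpression inside the constructor
    FilterNode filter = new FilterNode(stageId, rowType, condition);
    RexExpression serializableCondition = filter.getCondition();
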
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/JoinNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/JoinNode.java
similarity index 84%
rename from pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/JoinNode.java
rename to pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/JoinNode.java
index bf380639d8..96b6c43a95 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/JoinNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/JoinNode.java
@@ -16,25 +16,29 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.planner.nodes;
+package org.apache.pinot.query.planner.stage;
 
 import java.util.List;
 import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.type.RelDataType;
 import org.apache.pinot.query.planner.partitioning.FieldSelectionKeySelector;
 import org.apache.pinot.query.planner.partitioning.KeySelector;
+import org.apache.pinot.query.planner.serde.ProtoProperties;
 
 
 public class JoinNode extends AbstractStageNode {
+  @ProtoProperties
   private JoinRelType _joinRelType;
+  @ProtoProperties
   private List<JoinClause> _criteria;
 
   public JoinNode(int stageId) {
     super(stageId);
   }
 
-  public JoinNode(int stageId, JoinRelType joinRelType, List<JoinClause> criteria
-  ) {
+  public JoinNode(int stageId, RelDataType rowType, JoinRelType joinRelType, List<JoinClause> criteria) {
     super(stageId);
+    super._rowType = rowType;
     _joinRelType = joinRelType;
     _criteria = criteria;
   }
@@ -48,7 +52,9 @@ public class JoinNode extends AbstractStageNode {
   }
 
   public static class JoinClause {
+    @ProtoProperties
     private KeySelector<Object[], Object> _leftJoinKeySelector;
+    @ProtoProperties
     private KeySelector<Object[], Object> _rightJoinKeySelector;
 
     public JoinClause() {
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/MailboxReceiveNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxReceiveNode.java
similarity index 79%
rename from pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/MailboxReceiveNode.java
rename to pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxReceiveNode.java
index 8f0c619b79..1c01d8de5c 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/MailboxReceiveNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxReceiveNode.java
@@ -16,21 +16,26 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.planner.nodes;
+package org.apache.pinot.query.planner.stage;
 
 import org.apache.calcite.rel.RelDistribution;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.pinot.query.planner.serde.ProtoProperties;
 
 
 public class MailboxReceiveNode extends AbstractStageNode {
+  @ProtoProperties
   private int _senderStageId;
+  @ProtoProperties
   private RelDistribution.Type _exchangeType;
 
   public MailboxReceiveNode(int stageId) {
     super(stageId);
   }
 
-  public MailboxReceiveNode(int stageId, int senderStageId, RelDistribution.Type exchangeType) {
+  public MailboxReceiveNode(int stageId, RelDataType rowType, int senderStageId, RelDistribution.Type exchangeType) {
     super(stageId);
+    super._rowType = rowType;
     _senderStageId = senderStageId;
     _exchangeType = exchangeType;
   }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/MailboxSendNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxSendNode.java
similarity index 56%
rename from pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/MailboxSendNode.java
rename to pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxSendNode.java
index 9867a16f61..c3f540aa0e 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/MailboxSendNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxSendNode.java
@@ -16,23 +16,39 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.planner.nodes;
+package org.apache.pinot.query.planner.stage;
 
+import javax.annotation.Nullable;
 import org.apache.calcite.rel.RelDistribution;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.pinot.query.planner.partitioning.KeySelector;
+import org.apache.pinot.query.planner.serde.ProtoProperties;
 
 
 public class MailboxSendNode extends AbstractStageNode {
+  @ProtoProperties
   private int _receiverStageId;
+  @ProtoProperties
   private RelDistribution.Type _exchangeType;
+  @ProtoProperties
+  private KeySelector<Object[], Object> _partitionKeySelector;
 
   public MailboxSendNode(int stageId) {
     super(stageId);
   }
 
-  public MailboxSendNode(int stageId, int receiverStageId, RelDistribution.Type exchangeType) {
+  public MailboxSendNode(int stageId, RelDataType rowType, int receiverStageId, RelDistribution.Type exchangeType) {
+    // When exchangeType is not HASH_DISTRIBUTED, no partitionKeySelector is needed.
+    this(stageId, rowType, receiverStageId, exchangeType, null);
+  }
+
+  public MailboxSendNode(int stageId, RelDataType rowType, int receiverStageId, RelDistribution.Type exchangeType,
+      @Nullable KeySelector<Object[], Object> partitionKeySelector) {
     super(stageId);
+    super._rowType = rowType;
     _receiverStageId = receiverStageId;
     _exchangeType = exchangeType;
+    _partitionKeySelector = partitionKeySelector;
   }
 
   public int getReceiverStageId() {
@@ -42,4 +58,8 @@ public class MailboxSendNode extends AbstractStageNode {
   public RelDistribution.Type getExchangeType() {
     return _exchangeType;
   }
+
+  public KeySelector<Object[], Object> getPartitionKeySelector() {
+    return _partitionKeySelector;
+  }
 }
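
For example, the two exchange flavors used by the planner would be constructed along
these lines (stage ids and rowType are placeholders):

    // gather everything to the single receiver of the next stage: no partition key
    MailboxSendNode gather =
        new MailboxSendNode(1, rowType, 0, RelDistribution.Type.SINGLETON);

    // hash-distributed shuffle partitioned on column index 0 of the outgoing rows
    MailboxSendNode shuffle =
        new MailboxSendNode(2, rowType, 1, RelDistribution.Type.HASH_DISTRIBUTED,
            new FieldSelectionKeySelector(0));
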
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/TableScanNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/ProjectNode.java
similarity index 51%
copy from pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/TableScanNode.java
copy to pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/ProjectNode.java
index 9375a7e986..4ee9be6c0a 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/TableScanNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/ProjectNode.java
@@ -16,30 +16,34 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.planner.nodes;
+package org.apache.pinot.query.planner.stage;
 
 import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexNode;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.planner.serde.ProtoProperties;
 
 
-public class TableScanNode extends AbstractStageNode {
-  private String _tableName;
-  private List<String> _tableScanColumns;
+public class ProjectNode extends AbstractStageNode {
+  @ProtoProperties
+  private List<RexExpression> _projects;
 
-  public TableScanNode(int stageId) {
+  public ProjectNode(int stageId) {
     super(stageId);
   }
-
-  public TableScanNode(int stageId, String tableName, List<String> tableScanColumns) {
-    super(stageId);
-    _tableName = tableName;
-    _tableScanColumns = tableScanColumns;
+  public ProjectNode(int currentStageId, RelDataType rowType, List<RexNode> projects) {
+    super(currentStageId);
+    super._rowType = rowType;
+    _projects = projects.stream().map(RexExpression::toRexExpression).collect(Collectors.toList());
   }
 
-  public String getTableName() {
-    return _tableName;
+  public List<RexExpression> getProjects() {
+    return _projects;
   }
 
-  public List<String> getTableScanColumns() {
-    return _tableScanColumns;
+  public RelDataType getRowType() {
+    return _rowType;
   }
 }
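
Similarly, a projection stage that keeps the first two input columns would look roughly
like this (stageId, rowType and fieldType are placeholders; the RexInputRef indexes refer
to columns of the input row):

    List<RexNode> projects = Arrays.asList(
        new RexInputRef(0, fieldType),
        new RexInputRef(1, fieldType));
    ProjectNode project = new ProjectNode(stageId, rowType, projects);
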
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/StageNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNode.java
similarity index 96%
rename from pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/StageNode.java
rename to pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNode.java
index cd34aca530..45e65a8c21 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/StageNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNode.java
@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.planner.nodes;
+package org.apache.pinot.query.planner.stage;
 
 import java.io.Serializable;
 import java.util.List;
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/SerDeUtils.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNodeSerDeUtils.java
similarity index 85%
rename from pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/SerDeUtils.java
rename to pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNodeSerDeUtils.java
index ad7184cdb1..3d34f6effb 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/SerDeUtils.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNodeSerDeUtils.java
@@ -16,19 +16,19 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.planner.nodes;
+package org.apache.pinot.query.planner.stage;
 
 import org.apache.pinot.common.proto.Plan;
 
 
-public final class SerDeUtils {
-  private SerDeUtils() {
+public final class StageNodeSerDeUtils {
+  private StageNodeSerDeUtils() {
     // do not instantiate.
   }
 
   public static AbstractStageNode deserializeStageNode(Plan.StageNode protoNode) {
     AbstractStageNode stageNode = newNodeInstance(protoNode.getNodeName(), protoNode.getStageId());
-    stageNode.setObjectField(protoNode.getObjectField());
+    stageNode.fromObjectField(protoNode.getObjectField());
     for (Plan.StageNode protoChild : protoNode.getInputsList()) {
       stageNode.addInput(deserializeStageNode(protoChild));
     }
@@ -39,7 +39,7 @@ public final class SerDeUtils {
     Plan.StageNode.Builder builder = Plan.StageNode.newBuilder()
         .setStageId(stageNode.getStageId())
         .setNodeName(stageNode.getClass().getSimpleName())
-        .setObjectField(stageNode.getObjectField());
+        .setObjectField(stageNode.toObjectField());
     for (StageNode childNode : stageNode.getInputs()) {
       builder.addInputs(serializeStageNode((AbstractStageNode) childNode));
     }
@@ -52,8 +52,10 @@ public final class SerDeUtils {
         return new TableScanNode(stageId);
       case "JoinNode":
         return new JoinNode(stageId);
-      case "CalcNode":
-        return new CalcNode(stageId);
+      case "ProjectNode":
+        return new ProjectNode(stageId);
+      case "FilterNode":
+        return new FilterNode(stageId);
       case "MailboxSendNode":
         return new MailboxSendNode(stageId);
       case "MailboxReceiveNode":
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/TableScanNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/TableScanNode.java
similarity index 79%
rename from pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/TableScanNode.java
rename to pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/TableScanNode.java
index 9375a7e986..9ba36d34f3 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/nodes/TableScanNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/TableScanNode.java
@@ -16,21 +16,26 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.planner.nodes;
+package org.apache.pinot.query.planner.stage;
 
 import java.util.List;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.pinot.query.planner.serde.ProtoProperties;
 
 
 public class TableScanNode extends AbstractStageNode {
+  @ProtoProperties
   private String _tableName;
+  @ProtoProperties
   private List<String> _tableScanColumns;
 
   public TableScanNode(int stageId) {
     super(stageId);
   }
 
-  public TableScanNode(int stageId, String tableName, List<String> tableScanColumns) {
+  public TableScanNode(int stageId, RelDataType rowType, String tableName, List<String> tableScanColumns) {
     super(stageId);
+    super._rowType = rowType;
     _tableName = tableName;
     _tableScanColumns = tableScanColumns;
   }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/rules/PinotExchangeNodeInsertRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/rules/PinotExchangeNodeInsertRule.java
index 2b35613f7a..e7ef083ded 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/rules/PinotExchangeNodeInsertRule.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/rules/PinotExchangeNodeInsertRule.java
@@ -19,6 +19,8 @@
 package org.apache.pinot.query.rules;
 
 import com.google.common.collect.ImmutableList;
+import java.util.Collections;
+import java.util.List;
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.plan.hep.HepRelVertex;
@@ -27,9 +29,13 @@ import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.Exchange;
 import org.apache.calcite.rel.core.Join;
 import org.apache.calcite.rel.core.RelFactories;
+import org.apache.calcite.rel.hint.RelHint;
 import org.apache.calcite.rel.logical.LogicalExchange;
 import org.apache.calcite.rel.logical.LogicalJoin;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.tools.RelBuilderFactory;
+import org.apache.pinot.query.planner.hints.PinotRelationalHints;
 
 
 /**
@@ -57,12 +63,26 @@ public class PinotExchangeNodeInsertRule extends RelOptRule {
 
   @Override
   public void onMatch(RelOptRuleCall call) {
+    // TODO: this only works for single equality JOIN. add generic condition parser
     Join join = call.rel(0);
     RelNode leftInput = join.getInput(0);
     RelNode rightInput = join.getInput(1);
 
-    RelNode leftExchange = LogicalExchange.create(leftInput, RelDistributions.SINGLETON);
-    RelNode rightExchange = LogicalExchange.create(rightInput, RelDistributions.BROADCAST_DISTRIBUTED);
+    RelNode leftExchange;
+    RelNode rightExchange;
+    List<RelHint> hints = join.getHints();
+    if (hints.contains(PinotRelationalHints.USE_HASH_DISTRIBUTE)) {
+      int leftOperandIndex = ((RexInputRef) ((RexCall) join.getCondition()).getOperands().get(0)).getIndex();
+      int rightOperandIndex = ((RexInputRef) ((RexCall) join.getCondition()).getOperands().get(1)).getIndex()
+          - join.getLeft().getRowType().getFieldNames().size();
+      leftExchange = LogicalExchange.create(leftInput,
+          RelDistributions.hash(Collections.singletonList(leftOperandIndex)));
+      rightExchange = LogicalExchange.create(rightInput,
+          RelDistributions.hash(Collections.singletonList(rightOperandIndex)));
+    } else { // if (hints.contains(PinotRelationalHints.USE_BROADCAST_JOIN))
+      leftExchange = LogicalExchange.create(leftInput, RelDistributions.SINGLETON);
+      rightExchange = LogicalExchange.create(rightInput, RelDistributions.BROADCAST_DISTRIBUTED);
+    }
 
     RelNode newJoinNode =
         new LogicalJoin(join.getCluster(), join.getTraitSet(), leftExchange, rightExchange, join.getCondition(),
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/rules/PinotQueryRuleSets.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/rules/PinotQueryRuleSets.java
index 1b4e0850ac..63c2fd799f 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/rules/PinotQueryRuleSets.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/rules/PinotQueryRuleSets.java
@@ -45,6 +45,8 @@ public class PinotQueryRuleSets {
           CoreRules.FILTER_AGGREGATE_TRANSPOSE,
           // push filter through set operation
           CoreRules.FILTER_SET_OP_TRANSPOSE,
+          // push project through join
+          CoreRules.PROJECT_JOIN_TRANSPOSE,
           // push project through set operation
           CoreRules.PROJECT_SET_OP_TRANSPOSE,
 
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTest.java
index 60c7cd11af..9f4778743b 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTest.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTest.java
@@ -19,56 +19,45 @@
 package org.apache.pinot.query;
 
 import com.google.common.collect.ImmutableList;
-import java.io.PrintWriter;
-import java.io.StringWriter;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
-import org.apache.calcite.jdbc.CalciteSchemaBuilder;
-import org.apache.calcite.rel.RelNode;
-import org.apache.calcite.rel.RelRoot;
-import org.apache.calcite.rel.RelWriter;
-import org.apache.calcite.rel.externalize.RelXmlWriter;
-import org.apache.calcite.sql.SqlExplainLevel;
 import org.apache.calcite.sql.SqlNode;
-import org.apache.pinot.core.routing.RoutingManager;
 import org.apache.pinot.core.transport.ServerInstance;
-import org.apache.pinot.query.catalog.PinotCatalog;
 import org.apache.pinot.query.context.PlannerContext;
 import org.apache.pinot.query.planner.PlannerUtils;
 import org.apache.pinot.query.planner.QueryPlan;
 import org.apache.pinot.query.planner.StageMetadata;
-import org.apache.pinot.query.routing.WorkerManager;
-import org.apache.pinot.query.type.TypeFactory;
-import org.apache.pinot.query.type.TypeSystem;
 import org.testng.Assert;
-import org.testng.annotations.BeforeClass;
+import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
 
-public class QueryEnvironmentTest {
-  private QueryEnvironment _queryEnvironment;
+public class QueryEnvironmentTest extends QueryEnvironmentTestBase {
 
-  @BeforeClass
-  public void setUp() {
-    // the port doesn't matter as we are not actually making a server call.
-    RoutingManager routingManager = QueryEnvironmentTestUtils.getMockRoutingManager(1, 2);
-    _queryEnvironment = new QueryEnvironment(new TypeFactory(new TypeSystem()),
-        CalciteSchemaBuilder.asRootSchema(new PinotCatalog(QueryEnvironmentTestUtils.mockTableCache())),
-        new WorkerManager("localhost", 3, routingManager));
+  @Test(dataProvider = "testQueryParserDataProvider")
+  public void testQueryParser(String query, String digest)
+      throws Exception {
+    PlannerContext plannerContext = new PlannerContext();
+    SqlNode sqlNode = _queryEnvironment.parse(query, plannerContext);
+    _queryEnvironment.validate(sqlNode);
+    Assert.assertEquals(sqlNode.toString(), digest);
   }
 
-  @Test
-  public void testSqlStrings()
+  @Test(dataProvider = "testQueryDataProvider")
+  public void testQueryToRel(String query)
       throws Exception {
-    testQueryParsing("SELECT * FROM a JOIN b ON a.col1 = b.col2 WHERE a.col3 >= 0",
-        "SELECT *\n" + "FROM `a`\n" + "INNER JOIN `b` ON `a`.`col1` = `b`.`col2`\n" + "WHERE `a`.`col3` >= 0");
+    try {
+      QueryPlan queryPlan = _queryEnvironment.planQuery(query);
+      Assert.assertNotNull(queryPlan);
+    } catch (RuntimeException e) {
+      Assert.fail("failed to plan query: " + query, e);
+    }
   }
 
   @Test
-  public void testQueryToStages()
+  public void testQueryAndAssertStageContentForJoin()
       throws Exception {
-    PlannerContext plannerContext = new PlannerContext();
     String query = "SELECT * FROM a JOIN b ON a.col1 = b.col2";
     QueryPlan queryPlan = _queryEnvironment.planQuery(query);
     Assert.assertEquals(queryPlan.getQueryStageMap().size(), 4);
@@ -96,28 +85,19 @@ public class QueryEnvironmentTest {
   }
 
   @Test
-  public void testQueryToRel()
-      throws Exception {
-    PlannerContext plannerContext = new PlannerContext();
-    String query = "SELECT * FROM a JOIN b ON a.col1 = b.col2 WHERE a.col3 >= 0";
-    SqlNode parsed = _queryEnvironment.parse(query, plannerContext);
-    SqlNode validated = _queryEnvironment.validate(parsed);
-    RelRoot relRoot = _queryEnvironment.toRelation(validated, plannerContext);
-    RelNode optimized = _queryEnvironment.optimize(relRoot, plannerContext);
-
-    // Assert that relational plan can be written into a ALL-ATTRIBUTE digest.
-    StringWriter sw = new StringWriter();
-    PrintWriter pw = new PrintWriter(sw);
-    RelWriter planWriter = new RelXmlWriter(pw, SqlExplainLevel.ALL_ATTRIBUTES);
-    optimized.explain(planWriter);
-    Assert.assertNotNull(sw.toString());
+  public void testQueryProjectFilterPushdownForJoin() {
+    String query = "SELECT a.col1, a.ts, b.col2, b.col3 FROM a JOIN b ON a.col1 = b.col2 "
+        + "WHERE a.col3 >= 0 AND a.col2 IN  ('a', 'b') AND b.col3 < 0";
+    QueryPlan queryPlan = _queryEnvironment.planQuery(query);
+    Assert.assertEquals(queryPlan.getQueryStageMap().size(), 4);
+    Assert.assertEquals(queryPlan.getStageMetadataMap().size(), 4);
   }
 
-  private void testQueryParsing(String query, String digest)
-      throws Exception {
-    PlannerContext plannerContext = new PlannerContext();
-    SqlNode sqlNode = _queryEnvironment.parse(query, plannerContext);
-    _queryEnvironment.validate(sqlNode);
-    Assert.assertEquals(sqlNode.toString(), digest);
+  @DataProvider(name = "testQueryParserDataProvider")
+  private Object[][] provideQueriesAndDigest() {
+    return new Object[][] {
+        new Object[]{"SELECT * FROM a JOIN b ON a.col1 = b.col2 WHERE a.col3 >= 0",
+            "SELECT *\n" + "FROM `a`\n" + "INNER JOIN `b` ON `a`.`col1` = `b`.`col2`\n" + "WHERE `a`.`col3` >= 0"},
+    };
   }
 }
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
new file mode 100644
index 0000000000..40841d7c72
--- /dev/null
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
@@ -0,0 +1,52 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query;
+
+import org.apache.calcite.jdbc.CalciteSchemaBuilder;
+import org.apache.pinot.core.routing.RoutingManager;
+import org.apache.pinot.query.catalog.PinotCatalog;
+import org.apache.pinot.query.routing.WorkerManager;
+import org.apache.pinot.query.type.TypeFactory;
+import org.apache.pinot.query.type.TypeSystem;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.DataProvider;
+
+
+public class QueryEnvironmentTestBase {
+  protected QueryEnvironment _queryEnvironment;
+
+  @BeforeClass
+  public void setUp() {
+    // the port doesn't matter as we are not actually making a server call.
+    RoutingManager routingManager = QueryEnvironmentTestUtils.getMockRoutingManager(1, 2);
+    _queryEnvironment = new QueryEnvironment(new TypeFactory(new TypeSystem()),
+        CalciteSchemaBuilder.asRootSchema(new PinotCatalog(QueryEnvironmentTestUtils.mockTableCache())),
+        new WorkerManager("localhost", 3, routingManager));
+  }
+
+  @DataProvider(name = "testQueryDataProvider")
+  protected Object[][] provideQueries() {
+    return new Object[][] {
+        new Object[]{"SELECT * FROM a JOIN b ON a.col1 = b.col2"},
+        new Object[]{"SELECT * FROM a JOIN b ON a.col1 = b.col2 WHERE a.col3 >= 0"},
+        new Object[]{"SELECT a.col1, a.ts, b.col3 FROM a JOIN b ON a.col1 = b.col2 "
+            + "WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0"},
+    };
+  }
+}
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/stage/SerDeUtilsTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/stage/SerDeUtilsTest.java
new file mode 100644
index 0000000000..21031cd303
--- /dev/null
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/stage/SerDeUtilsTest.java
@@ -0,0 +1,81 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner.stage;
+
+import java.lang.reflect.Field;
+import java.util.List;
+import java.util.Map;
+import org.apache.pinot.common.proto.Plan;
+import org.apache.pinot.query.QueryEnvironmentTestBase;
+import org.apache.pinot.query.planner.QueryPlan;
+import org.apache.pinot.query.planner.serde.ProtoProperties;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+
+public class SerDeUtilsTest extends QueryEnvironmentTestBase {
+
+  @Test(dataProvider = "testQueryDataProvider")
+  public void testQueryStagePlanSerDe(String query)
+      throws Exception {
+    QueryPlan queryPlan = _queryEnvironment.planQuery(query);
+    for (StageNode stageNode : queryPlan.getQueryStageMap().values()) {
+      Plan.StageNode serializedStageNode = StageNodeSerDeUtils.serializeStageNode((AbstractStageNode) stageNode);
+      StageNode deserializedStageNode = StageNodeSerDeUtils.deserializeStageNode(serializedStageNode);
+      Assert.assertTrue(isObjectEqual(stageNode, deserializedStageNode));
+    }
+  }
+
+  @SuppressWarnings({"rawtypes"})
+  private boolean isObjectEqual(Object left, Object right)
+      throws IllegalAccessException {
+    Class<?> clazz = left.getClass();
+    for (Field field : clazz.getDeclaredFields()) {
+      if (field.isAnnotationPresent(ProtoProperties.class)) {
+        field.setAccessible(true);
+        Object l = field.get(left);
+        Object r = field.get(right);
+        if (l instanceof List) {
+          if (((List) l).size() != ((List) r).size()) {
+            return false;
+          }
+          for (int i = 0; i < ((List) l).size(); i++) {
+            if (!isObjectEqual(((List) l).get(i), ((List) r).get(i))) {
+              return false;
+            }
+          }
+        } else if (l instanceof Map) {
+          if (((Map) l).size() != ((Map) r).size()) {
+            return false;
+          }
+          for (Object key : ((Map) l).keySet()) {
+            if (!isObjectEqual(((Map) l).get(key), ((Map) r).get(key))) {
+              return false;
+            }
+          }
+        } else {
+          if (!(l == null && r == null || l != null && l.equals(r) || isObjectEqual(l, r))) {
+            return false;
+          }
+        }
+      }
+    }
+    return true;
+  }
+}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java
index 2968ff9a23..f33cde43c2 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java
@@ -31,7 +31,7 @@ import org.apache.pinot.core.transport.ServerInstance;
 import org.apache.pinot.query.mailbox.GrpcMailboxService;
 import org.apache.pinot.query.mailbox.MailboxService;
 import org.apache.pinot.query.planner.StageMetadata;
-import org.apache.pinot.query.planner.nodes.MailboxSendNode;
+import org.apache.pinot.query.planner.stage.MailboxSendNode;
 import org.apache.pinot.query.runtime.executor.WorkerQueryExecutor;
 import org.apache.pinot.query.runtime.operator.MailboxSendOperator;
 import org.apache.pinot.query.runtime.plan.DistributedStagePlan;
@@ -100,8 +100,8 @@ public class QueryRunner {
       StageMetadata receivingStageMetadata = distributedStagePlan.getMetadataMap().get(sendNode.getReceiverStageId());
       MailboxSendOperator mailboxSendOperator =
           new MailboxSendOperator(_mailboxService, dataTable, receivingStageMetadata.getServerInstances(),
-              sendNode.getExchangeType(), _hostname, _port, serverQueryRequest.getRequestId(),
-              sendNode.getStageId());
+              sendNode.getExchangeType(), sendNode.getPartitionKeySelector(), _hostname, _port,
+              serverQueryRequest.getRequestId(), sendNode.getStageId());
       mailboxSendOperator.nextBlock();
     } else {
       _workerExecutor.processQuery(distributedStagePlan, requestMetadataMap, executorService);
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/WorkerQueryExecutor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/WorkerQueryExecutor.java
index 85c0f108b4..9a66bf86c1 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/WorkerQueryExecutor.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/WorkerQueryExecutor.java
@@ -30,13 +30,15 @@ import org.apache.pinot.core.transport.ServerInstance;
 import org.apache.pinot.core.util.trace.TraceRunnable;
 import org.apache.pinot.query.mailbox.MailboxService;
 import org.apache.pinot.query.planner.StageMetadata;
-import org.apache.pinot.query.planner.nodes.JoinNode;
-import org.apache.pinot.query.planner.nodes.MailboxReceiveNode;
-import org.apache.pinot.query.planner.nodes.MailboxSendNode;
-import org.apache.pinot.query.planner.nodes.StageNode;
+import org.apache.pinot.query.planner.stage.FilterNode;
+import org.apache.pinot.query.planner.stage.JoinNode;
+import org.apache.pinot.query.planner.stage.MailboxReceiveNode;
+import org.apache.pinot.query.planner.stage.MailboxSendNode;
+import org.apache.pinot.query.planner.stage.ProjectNode;
+import org.apache.pinot.query.planner.stage.StageNode;
 import org.apache.pinot.query.runtime.blocks.DataTableBlock;
 import org.apache.pinot.query.runtime.blocks.DataTableBlockUtils;
-import org.apache.pinot.query.runtime.operator.BroadcastJoinOperator;
+import org.apache.pinot.query.runtime.operator.HashJoinOperator;
 import org.apache.pinot.query.runtime.operator.MailboxReceiveOperator;
 import org.apache.pinot.query.runtime.operator.MailboxSendOperator;
 import org.apache.pinot.query.runtime.plan.DistributedStagePlan;
@@ -98,22 +100,27 @@ public class WorkerQueryExecutor {
   private BaseOperator<DataTableBlock> getOperator(long requestId, StageNode stageNode,
       Map<Integer, StageMetadata> metadataMap) {
     // TODO: optimize this into a framework. (physical planner)
-    if (stageNode instanceof MailboxSendNode) {
-      MailboxSendNode sendNode = (MailboxSendNode) stageNode;
-      BaseOperator<DataTableBlock> nextOperator = getOperator(requestId, sendNode.getInputs().get(0), metadataMap);
-      StageMetadata receivingStageMetadata = metadataMap.get(sendNode.getReceiverStageId());
-      return new MailboxSendOperator(_mailboxService, nextOperator, receivingStageMetadata.getServerInstances(),
-          sendNode.getExchangeType(), _hostName, _port, requestId, sendNode.getStageId());
-    } else if (stageNode instanceof MailboxReceiveNode) {
+    if (stageNode instanceof MailboxReceiveNode) {
       MailboxReceiveNode receiveNode = (MailboxReceiveNode) stageNode;
       List<ServerInstance> sendingInstances = metadataMap.get(receiveNode.getSenderStageId()).getServerInstances();
       return new MailboxReceiveOperator(_mailboxService, RelDistribution.Type.ANY, sendingInstances, _hostName, _port,
           requestId, receiveNode.getSenderStageId());
+    } else if (stageNode instanceof MailboxSendNode) {
+      MailboxSendNode sendNode = (MailboxSendNode) stageNode;
+      BaseOperator<DataTableBlock> nextOperator = getOperator(requestId, sendNode.getInputs().get(0), metadataMap);
+      StageMetadata receivingStageMetadata = metadataMap.get(sendNode.getReceiverStageId());
+      return new MailboxSendOperator(_mailboxService, nextOperator, receivingStageMetadata.getServerInstances(),
+          sendNode.getExchangeType(), sendNode.getPartitionKeySelector(), _hostName, _port, requestId,
+          sendNode.getStageId());
     } else if (stageNode instanceof JoinNode) {
       JoinNode joinNode = (JoinNode) stageNode;
       BaseOperator<DataTableBlock> leftOperator = getOperator(requestId, joinNode.getInputs().get(0), metadataMap);
       BaseOperator<DataTableBlock> rightOperator = getOperator(requestId, joinNode.getInputs().get(1), metadataMap);
-      return new BroadcastJoinOperator(leftOperator, rightOperator, joinNode.getCriteria());
+      return new HashJoinOperator(leftOperator, rightOperator, joinNode.getCriteria());
+    } else if (stageNode instanceof FilterNode) {
+      throw new UnsupportedOperationException("Unsupported!");
+    } else if (stageNode instanceof ProjectNode) {
+      throw new UnsupportedOperationException("Unsupported!");
     } else {
       throw new UnsupportedOperationException(
           String.format("Stage node type %s is not supported!", stageNode.getClass().getSimpleName()));
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/BroadcastJoinOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
similarity index 96%
rename from pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/BroadcastJoinOperator.java
rename to pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
index db5bb5289d..6cf9fddd12 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/BroadcastJoinOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
@@ -28,8 +28,8 @@ import org.apache.pinot.common.utils.DataTable;
 import org.apache.pinot.core.common.Operator;
 import org.apache.pinot.core.operator.BaseOperator;
 import org.apache.pinot.core.query.selection.SelectionOperatorUtils;
-import org.apache.pinot.query.planner.nodes.JoinNode;
 import org.apache.pinot.query.planner.partitioning.KeySelector;
+import org.apache.pinot.query.planner.stage.JoinNode;
 import org.apache.pinot.query.runtime.blocks.DataTableBlock;
 import org.apache.pinot.query.runtime.blocks.DataTableBlockUtils;
 
@@ -42,7 +42,7 @@ import org.apache.pinot.query.runtime.blocks.DataTableBlockUtils;
  *
  * <p>For each of the data block received from the left table, it will generate a joint data block.
  */
-public class BroadcastJoinOperator extends BaseOperator<DataTableBlock> {
+public class HashJoinOperator extends BaseOperator<DataTableBlock> {
   private static final String OPERATOR_NAME = "BroadcastJoinOperator";
   private static final String EXPLAIN_NAME = "BROADCAST_JOIN";
 
@@ -57,7 +57,7 @@ public class BroadcastJoinOperator extends BaseOperator<DataTableBlock> {
   private KeySelector<Object[], Object> _leftKeySelector;
   private KeySelector<Object[], Object> _rightKeySelector;
 
-  public BroadcastJoinOperator(BaseOperator<DataTableBlock> leftTableOperator,
+  public HashJoinOperator(BaseOperator<DataTableBlock> leftTableOperator,
       BaseOperator<DataTableBlock> rightTableOperator, List<JoinNode.JoinClause> criteria) {
     // TODO: this assumes right table is broadcast.
     _leftKeySelector = criteria.get(0).getLeftJoinKeySelector();
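The rename makes the mechanics explicit: the right (broadcast) input is materialized into a hash table keyed by the join key, and each left row probes that table; note that OPERATOR_NAME and EXPLAIN_NAME above still carry the broadcast naming in this commit. A self-contained sketch of that build/probe shape, using plain Java collections and hypothetical key extractors instead of Pinot's KeySelector:

    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.function.Function;

    // Hypothetical sketch of an inner hash join: build a table on the right side, probe with
    // the left side, and emit the concatenation of each matching pair of rows.
    public final class HashJoinSketch {
      private HashJoinSketch() {
      }

      public static List<Object[]> join(List<Object[]> leftRows, List<Object[]> rightRows,
          Function<Object[], Object> leftKey, Function<Object[], Object> rightKey) {
        Map<Object, List<Object[]>> buildTable = new HashMap<>();
        for (Object[] row : rightRows) {
          buildTable.computeIfAbsent(rightKey.apply(row), k -> new ArrayList<>()).add(row);
        }
        List<Object[]> joined = new ArrayList<>();
        for (Object[] left : leftRows) {
          for (Object[] right : buildTable.getOrDefault(leftKey.apply(left), Collections.emptyList())) {
            Object[] out = new Object[left.length + right.length];
            System.arraycopy(left, 0, out, 0, left.length);
            System.arraycopy(right, 0, out, left.length, right.length);
            joined.add(out);
          }
        }
        return joined;
      }
    }
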
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java
index 3971dfb325..3dd9c5ff77 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java
@@ -22,6 +22,7 @@ import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableSet;
 import com.google.protobuf.ByteString;
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Set;
 import javax.annotation.Nullable;
@@ -30,10 +31,12 @@ import org.apache.pinot.common.proto.Mailbox;
 import org.apache.pinot.common.utils.DataTable;
 import org.apache.pinot.core.common.Operator;
 import org.apache.pinot.core.operator.BaseOperator;
+import org.apache.pinot.core.query.selection.SelectionOperatorUtils;
 import org.apache.pinot.core.transport.ServerInstance;
 import org.apache.pinot.query.mailbox.MailboxService;
 import org.apache.pinot.query.mailbox.SendingMailbox;
 import org.apache.pinot.query.mailbox.StringMailboxIdentifier;
+import org.apache.pinot.query.planner.partitioning.KeySelector;
 import org.apache.pinot.query.runtime.blocks.DataTableBlock;
 import org.apache.pinot.query.runtime.blocks.DataTableBlockUtils;
 import org.slf4j.Logger;
@@ -49,10 +52,11 @@ public class MailboxSendOperator extends BaseOperator<DataTableBlock> {
   private static final String EXPLAIN_NAME = "MAILBOX_SEND";
   private static final Set<RelDistribution.Type> SUPPORTED_EXCHANGE_TYPE =
       ImmutableSet.of(RelDistribution.Type.SINGLETON, RelDistribution.Type.RANDOM_DISTRIBUTED,
-          RelDistribution.Type.BROADCAST_DISTRIBUTED);
+          RelDistribution.Type.BROADCAST_DISTRIBUTED, RelDistribution.Type.HASH_DISTRIBUTED);
 
   private final List<ServerInstance> _receivingStageInstances;
   private final RelDistribution.Type _exchangeType;
+  private final KeySelector<Object[], Object> _keySelector;
   private final String _serverHostName;
   private final int _serverPort;
   private final long _jobId;
@@ -63,11 +67,13 @@ public class MailboxSendOperator extends BaseOperator<DataTableBlock> {
 
   public MailboxSendOperator(MailboxService<Mailbox.MailboxContent> mailboxService,
       BaseOperator<DataTableBlock> dataTableBlockBaseOperator, List<ServerInstance> receivingStageInstances,
-      RelDistribution.Type exchangeType, String hostName, int port, long jobId, int stageId) {
+      RelDistribution.Type exchangeType, KeySelector<Object[], Object> keySelector, String hostName, int port,
+      long jobId, int stageId) {
     _mailboxService = mailboxService;
     _dataTableBlockBaseOperator = dataTableBlockBaseOperator;
     _receivingStageInstances = receivingStageInstances;
     _exchangeType = exchangeType;
+    _keySelector = keySelector;
     _serverHostName = hostName;
     _serverPort = port;
     _jobId = jobId;
@@ -82,12 +88,13 @@ public class MailboxSendOperator extends BaseOperator<DataTableBlock> {
    * creation of MailboxSendOperator we should not use this API.
    */
   public MailboxSendOperator(MailboxService<Mailbox.MailboxContent> mailboxService, DataTable dataTable,
-      List<ServerInstance> receivingStageInstances, RelDistribution.Type exchangeType, String hostName, int port,
-      long jobId, int stageId) {
+      List<ServerInstance> receivingStageInstances, RelDistribution.Type exchangeType,
+      KeySelector<Object[], Object> keySelector, String hostName, int port, long jobId, int stageId) {
     _mailboxService = mailboxService;
     _dataTable = dataTable;
     _receivingStageInstances = receivingStageInstances;
     _exchangeType = exchangeType;
+    _keySelector = keySelector;
     _serverHostName = hostName;
     _serverPort = port;
     _jobId = jobId;
@@ -115,13 +122,16 @@ public class MailboxSendOperator extends BaseOperator<DataTableBlock> {
   protected DataTableBlock getNextBlock() {
     DataTable dataTable;
     DataTableBlock dataTableBlock = null;
+    boolean isEndOfStream;
     if (_dataTableBlockBaseOperator != null) {
       dataTableBlock = _dataTableBlockBaseOperator.nextBlock();
       dataTable = dataTableBlock.getDataTable();
+      isEndOfStream = DataTableBlockUtils.isEndOfStream(dataTableBlock);
     } else {
       dataTable = _dataTable;
+      isEndOfStream = true;
     }
-    boolean isEndOfStream = dataTableBlock == null || DataTableBlockUtils.isEndOfStream(dataTableBlock);
+
     try {
       switch (_exchangeType) {
         // TODO: random and singleton distribution should've been selected in planning phase.
@@ -142,6 +152,13 @@ public class MailboxSendOperator extends BaseOperator<DataTableBlock> {
           }
           break;
         case HASH_DISTRIBUTED:
+          // TODO: ensure that the server instance list is sorted using the same function on both sender and receiver.
+          List<DataTable> dataTableList = constructPartitionedDataBlock(dataTable, _keySelector,
+              _receivingStageInstances.size());
+          for (int i = 0; i < _receivingStageInstances.size(); i++) {
+            sendDataTableBlock(_receivingStageInstances.get(i), dataTableList.get(i), isEndOfStream);
+          }
+          break;
         case RANGE_DISTRIBUTED:
         case ROUND_ROBIN_DISTRIBUTED:
         case ANY:
@@ -154,6 +171,31 @@ public class MailboxSendOperator extends BaseOperator<DataTableBlock> {
     return dataTableBlock;
   }
 
+  private static List<DataTable> constructPartitionedDataBlock(DataTable dataTable,
+      KeySelector<Object[], Object> keySelector, int partitionSize)
+      throws Exception {
+    List<List<Object[]>> temporaryRows = new ArrayList<>(partitionSize);
+    for (int i = 0; i < partitionSize; i++) {
+      temporaryRows.add(new ArrayList<>());
+    }
+    for (int rowId = 0; rowId < dataTable.getNumberOfRows(); rowId++) {
+      Object[] row = SelectionOperatorUtils.extractRowFromDataTable(dataTable, rowId);
+      Object key = keySelector.getKey(row);
+      // TODO: support other partitioning algorithm
+      temporaryRows.get(hashToIndex(key, partitionSize)).add(row);
+    }
+    List<DataTable> dataTableList = new ArrayList<>(partitionSize);
+    for (int i = 0; i < partitionSize; i++) {
+      List<Object[]> objects = temporaryRows.get(i);
+      dataTableList.add(SelectionOperatorUtils.getDataTableFromRows(objects, dataTable.getDataSchema()));
+    }
+    return dataTableList;
+  }
+
+  private static int hashToIndex(Object key, int partitionSize) {
+    return (key.hashCode()) % partitionSize;
+  }
+
   private void sendDataTableBlock(ServerInstance serverInstance, DataTable dataTable, boolean isEndOfStream)
       throws IOException {
     String mailboxId = toMailboxId(serverInstance);
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java
index f9ecf7f089..828a969dd0 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java
@@ -22,7 +22,7 @@ import java.util.HashMap;
 import java.util.Map;
 import org.apache.pinot.core.transport.ServerInstance;
 import org.apache.pinot.query.planner.StageMetadata;
-import org.apache.pinot.query.planner.nodes.StageNode;
+import org.apache.pinot.query.planner.stage.StageNode;
 
 
 /**
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java
index 358ebb8465..034bf561a2 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java
@@ -25,8 +25,8 @@ import org.apache.commons.lang3.StringUtils;
 import org.apache.pinot.common.proto.Worker;
 import org.apache.pinot.core.transport.ServerInstance;
 import org.apache.pinot.query.planner.StageMetadata;
-import org.apache.pinot.query.planner.nodes.AbstractStageNode;
-import org.apache.pinot.query.planner.nodes.SerDeUtils;
+import org.apache.pinot.query.planner.stage.AbstractStageNode;
+import org.apache.pinot.query.planner.stage.StageNodeSerDeUtils;
 import org.apache.pinot.query.routing.WorkerInstance;
 import org.apache.pinot.query.runtime.plan.DistributedStagePlan;
 
@@ -43,7 +43,7 @@ public class QueryPlanSerDeUtils {
   public static DistributedStagePlan deserialize(Worker.StagePlan stagePlan) {
     DistributedStagePlan distributedStagePlan = new DistributedStagePlan(stagePlan.getStageId());
     distributedStagePlan.setServerInstance(stringToInstance(stagePlan.getInstanceId()));
-    distributedStagePlan.setStageRoot(SerDeUtils.deserializeStageNode(stagePlan.getStageRoot()));
+    distributedStagePlan.setStageRoot(StageNodeSerDeUtils.deserializeStageNode(stagePlan.getStageRoot()));
     Map<Integer, Worker.StageMetadata> metadataMap = stagePlan.getStageMetadataMap();
     distributedStagePlan.getMetadataMap().putAll(protoMapToStageMetadataMap(metadataMap));
     return distributedStagePlan;
@@ -53,7 +53,7 @@ public class QueryPlanSerDeUtils {
     return Worker.StagePlan.newBuilder()
         .setStageId(distributedStagePlan.getStageId())
         .setInstanceId(instanceToString(distributedStagePlan.getServerInstance()))
-        .setStageRoot(SerDeUtils.serializeStageNode((AbstractStageNode) distributedStagePlan.getStageRoot()))
+        .setStageRoot(StageNodeSerDeUtils.serializeStageNode((AbstractStageNode) distributedStagePlan.getStageRoot()))
         .putAllStageMetadata(stageMetadataMapToProtoMap(distributedStagePlan.getMetadataMap())).build();
   }
 
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/utils/ServerRequestUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/utils/ServerRequestUtils.java
index 306faa2598..5f17e83305 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/utils/ServerRequestUtils.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/utils/ServerRequestUtils.java
@@ -29,11 +29,12 @@ import org.apache.pinot.common.request.PinotQuery;
 import org.apache.pinot.common.request.QuerySource;
 import org.apache.pinot.common.utils.request.RequestUtils;
 import org.apache.pinot.core.query.request.ServerQueryRequest;
-import org.apache.pinot.query.planner.nodes.CalcNode;
-import org.apache.pinot.query.planner.nodes.MailboxReceiveNode;
-import org.apache.pinot.query.planner.nodes.MailboxSendNode;
-import org.apache.pinot.query.planner.nodes.StageNode;
-import org.apache.pinot.query.planner.nodes.TableScanNode;
+import org.apache.pinot.query.parser.CalciteRexExpressionParser;
+import org.apache.pinot.query.planner.stage.FilterNode;
+import org.apache.pinot.query.planner.stage.MailboxSendNode;
+import org.apache.pinot.query.planner.stage.ProjectNode;
+import org.apache.pinot.query.planner.stage.StageNode;
+import org.apache.pinot.query.planner.stage.TableScanNode;
 import org.apache.pinot.query.runtime.plan.DistributedStagePlan;
 
 
@@ -88,22 +89,28 @@ public class ServerRequestUtils {
   }
 
   private static void walkStageTree(StageNode node, PinotQuery pinotQuery) {
-    if (node instanceof CalcNode) {
-      // TODO: add conversion for CalcNode, specifically filter/alias/...
-    } else if (node instanceof TableScanNode) {
+    // walkStageTree visits the inputs first (post-order); the leaf-stage plan is expected to be a sequential chain.
+    for (StageNode child : node.getInputs()) {
+      walkStageTree(child, pinotQuery);
+    }
+    if (node instanceof TableScanNode) {
       TableScanNode tableScanNode = (TableScanNode) node;
       DataSource dataSource = new DataSource();
       dataSource.setTableName(tableScanNode.getTableName());
       pinotQuery.setDataSource(dataSource);
       pinotQuery.setSelectList(tableScanNode.getTableScanColumns().stream().map(RequestUtils::getIdentifierExpression)
           .collect(Collectors.toList()));
-    } else if (node instanceof MailboxSendNode || node instanceof MailboxReceiveNode) {
-      // ignore for now. continue to child.
+    } else if (node instanceof FilterNode) {
+      pinotQuery.setFilterExpression(CalciteRexExpressionParser.toExpression(
+          ((FilterNode) node).getCondition(), pinotQuery));
+    } else if (node instanceof ProjectNode) {
+      pinotQuery.setSelectList(CalciteRexExpressionParser.convertSelectList(
+          ((ProjectNode) node).getProjects(), pinotQuery));
+    } else if (node instanceof MailboxSendNode) {
+      // TODO: MailboxSendNode should be the root of the leaf stage, but ignore it for now since it is handled
+      // separately in QueryRunner as a single-step sender.
     } else {
       throw new UnsupportedOperationException("Unsupported logical plan node: " + node);
     }
-    for (StageNode child : node.getInputs()) {
-      walkStageTree(child, pinotQuery);
-    }
   }
 }
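The reworked walkStageTree recurses into the inputs before handling the current node, so a leaf stage shaped as MailboxSend -> Project -> Filter -> TableScan is translated bottom-up: the table scan sets the data source and the initial select list, the filter then sets the filter expression, and the project finally replaces the select list. A toy sketch of that post-order visit order (PlanNode below is a hypothetical stand-in for StageNode):

    import java.util.Arrays;
    import java.util.List;

    // Hypothetical model of a stage-node chain, used only to show the post-order visit order.
    public final class PostOrderWalkSketch {
      static final class PlanNode {
        final String _name;
        final List<PlanNode> _inputs;

        PlanNode(String name, PlanNode... inputs) {
          _name = name;
          _inputs = Arrays.asList(inputs);
        }
      }

      static void walk(PlanNode node, StringBuilder order) {
        for (PlanNode input : node._inputs) {
          walk(input, order);
        }
        order.append(node._name).append(' ');
      }

      public static void main(String[] args) {
        PlanNode leafStage =
            new PlanNode("MailboxSend", new PlanNode("Project", new PlanNode("Filter", new PlanNode("TableScan"))));
        StringBuilder order = new StringBuilder();
        walk(leafStage, order);
        System.out.println(order); // prints: TableScan Filter Project MailboxSend
      }
    }
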
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/QueryDispatcher.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/QueryDispatcher.java
index 3200b317bb..7ea69b8d21 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/QueryDispatcher.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/QueryDispatcher.java
@@ -32,7 +32,7 @@ import org.apache.pinot.core.transport.ServerInstance;
 import org.apache.pinot.query.mailbox.MailboxService;
 import org.apache.pinot.query.planner.QueryPlan;
 import org.apache.pinot.query.planner.StageMetadata;
-import org.apache.pinot.query.planner.nodes.MailboxReceiveNode;
+import org.apache.pinot.query.planner.stage.MailboxReceiveNode;
 import org.apache.pinot.query.runtime.blocks.DataTableBlock;
 import org.apache.pinot.query.runtime.blocks.DataTableBlockUtils;
 import org.apache.pinot.query.runtime.operator.MailboxReceiveOperator;
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
index 9de2db834f..2316195bce 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
@@ -36,7 +36,7 @@ import org.apache.pinot.query.QueryEnvironmentTestUtils;
 import org.apache.pinot.query.QueryServerEnclosure;
 import org.apache.pinot.query.mailbox.GrpcMailboxService;
 import org.apache.pinot.query.planner.QueryPlan;
-import org.apache.pinot.query.planner.nodes.MailboxReceiveNode;
+import org.apache.pinot.query.planner.stage.MailboxReceiveNode;
 import org.apache.pinot.query.routing.WorkerInstance;
 import org.apache.pinot.query.runtime.blocks.DataTableBlock;
 import org.apache.pinot.query.runtime.blocks.DataTableBlockUtils;
@@ -47,6 +47,7 @@ import org.apache.pinot.query.service.QueryDispatcher;
 import org.testng.Assert;
 import org.testng.annotations.AfterClass;
 import org.testng.annotations.BeforeClass;
+import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
 import static org.apache.pinot.core.query.selection.SelectionOperatorUtils.extractRowFromDataTable;
@@ -101,73 +102,9 @@ public class QueryRunnerTest {
     _mailboxService.shutdown();
   }
 
-  @Test
-  public void testRunningTableScanOnlyQuery()
-      throws Exception {
-    QueryPlan queryPlan = _queryEnvironment.planQuery("SELECT * FROM b");
-    int stageRoodId = QueryEnvironmentTestUtils.getTestStageByServerCount(queryPlan, 1);
-    Map<String, String> requestMetadataMap =
-        ImmutableMap.of("REQUEST_ID", String.valueOf(RANDOM_REQUEST_ID_GEN.nextLong()));
-
-    ServerInstance serverInstance = queryPlan.getStageMetadataMap().get(stageRoodId).getServerInstances().get(0);
-    DistributedStagePlan distributedStagePlan =
-        QueryDispatcher.constructDistributedStagePlan(queryPlan, stageRoodId, serverInstance);
-
-    MailboxReceiveOperator mailboxReceiveOperator =
-        createReduceStageOperator(queryPlan.getStageMetadataMap().get(stageRoodId).getServerInstances(),
-            Long.parseLong(requestMetadataMap.get("REQUEST_ID")), stageRoodId, _reducerGrpcPort);
-
-    // execute this single stage.
-    _servers.get(serverInstance).processQuery(distributedStagePlan, requestMetadataMap);
-
-    DataTableBlock dataTableBlock;
-    // get the block back and it should have 5 rows
-    dataTableBlock = mailboxReceiveOperator.nextBlock();
-    Assert.assertEquals(dataTableBlock.getDataTable().getNumberOfRows(), 5);
-    // next block should be null as all servers finished sending.
-    dataTableBlock = mailboxReceiveOperator.nextBlock();
-    Assert.assertTrue(DataTableBlockUtils.isEndOfStream(dataTableBlock));
-  }
-
-  @Test
-  public void testRunningTableScanMultipleServer()
-      throws Exception {
-    QueryPlan queryPlan = _queryEnvironment.planQuery("SELECT * FROM a");
-    int stageRoodId = QueryEnvironmentTestUtils.getTestStageByServerCount(queryPlan, 2);
-    Map<String, String> requestMetadataMap =
-        ImmutableMap.of("REQUEST_ID", String.valueOf(RANDOM_REQUEST_ID_GEN.nextLong()));
-
-    for (ServerInstance serverInstance : queryPlan.getStageMetadataMap().get(stageRoodId).getServerInstances()) {
-      DistributedStagePlan distributedStagePlan =
-          QueryDispatcher.constructDistributedStagePlan(queryPlan, stageRoodId, serverInstance);
-
-      // execute this single stage.
-      _servers.get(serverInstance).processQuery(distributedStagePlan, requestMetadataMap);
-    }
-
-    MailboxReceiveOperator mailboxReceiveOperator =
-        createReduceStageOperator(queryPlan.getStageMetadataMap().get(stageRoodId).getServerInstances(),
-            Long.parseLong(requestMetadataMap.get("REQUEST_ID")), stageRoodId, _reducerGrpcPort);
-
-    int count = 0;
-    int rowCount = 0;
-    DataTableBlock dataTableBlock;
-    while (count < 2) { // we have 2 servers sending data.
-      dataTableBlock = mailboxReceiveOperator.nextBlock();
-      rowCount += dataTableBlock.getDataTable().getNumberOfRows();
-      count++;
-    }
-    // assert that all table A segments returned successfully.
-    Assert.assertEquals(rowCount, 15);
-    // assert that the next block is null (e.g. finished receiving).
-    dataTableBlock = mailboxReceiveOperator.nextBlock();
-    Assert.assertTrue(DataTableBlockUtils.isEndOfStream(dataTableBlock));
-  }
-
-  @Test
-  public void testJoin()
-      throws Exception {
-    QueryPlan queryPlan = _queryEnvironment.planQuery("SELECT * FROM a JOIN b on a.col1 = b.col2");
+  @Test(dataProvider = "testDataWithSqlToFinalRowCount")
+  public void testSqlWithFinalRowCountChecker(String sql, int expectedRowCount) {
+    QueryPlan queryPlan = _queryEnvironment.planQuery(sql);
     Map<String, String> requestMetadataMap =
         ImmutableMap.of("REQUEST_ID", String.valueOf(RANDOM_REQUEST_ID_GEN.nextLong()));
     MailboxReceiveOperator mailboxReceiveOperator = null;
@@ -187,85 +124,52 @@ public class QueryRunnerTest {
     }
     Preconditions.checkNotNull(mailboxReceiveOperator);
 
-    int count = 0;
-    int rowCount = 0;
-    List<Object[]> resultRows = new ArrayList<>();
-    DataTableBlock dataTableBlock;
-    while (count < 2) { // we have 2 servers sending data.
-      dataTableBlock = mailboxReceiveOperator.nextBlock();
-      if (dataTableBlock.getDataTable() != null) {
-        DataTable dataTable = dataTableBlock.getDataTable();
-        int numRows = dataTable.getNumberOfRows();
-        for (int rowId = 0; rowId < numRows; rowId++) {
-          resultRows.add(extractRowFromDataTable(dataTable, rowId));
-        }
-        rowCount += numRows;
-      }
-      count++;
-    }
-
-    // Assert that each of the 5 categories from left table is joined with right table.
-    // Specifically table A has 15 rows (10 on server1 and 5 on server2) and table B has 5 rows (all on server1),
-    // thus the final JOIN result will be 15 x 1 = 15.
-    Assert.assertEquals(rowCount, 15);
-
-    // assert that the next block is null (e.g. finished receiving).
-    dataTableBlock = mailboxReceiveOperator.nextBlock();
-    Assert.assertTrue(DataTableBlockUtils.isEndOfStream(dataTableBlock));
+    List<Object[]> resultRows = reduceMailboxReceive(mailboxReceiveOperator);
+    Assert.assertEquals(resultRows.size(), expectedRowCount);
   }
 
-  @Test
-  public void testMultipleJoin()
-      throws Exception {
-    QueryPlan queryPlan =
-        _queryEnvironment.planQuery("SELECT * FROM a JOIN b ON a.col1 = b.col2 " + "JOIN c ON a.col3 = c.col3");
-    Map<String, String> requestMetadataMap =
-        ImmutableMap.of("REQUEST_ID", String.valueOf(RANDOM_REQUEST_ID_GEN.nextLong()));
-    MailboxReceiveOperator mailboxReceiveOperator = null;
-    for (int stageId : queryPlan.getStageMetadataMap().keySet()) {
-      if (queryPlan.getQueryStageMap().get(stageId) instanceof MailboxReceiveNode) {
-        MailboxReceiveNode reduceNode = (MailboxReceiveNode) queryPlan.getQueryStageMap().get(stageId);
-        mailboxReceiveOperator = createReduceStageOperator(
-            queryPlan.getStageMetadataMap().get(reduceNode.getSenderStageId()).getServerInstances(),
-            Long.parseLong(requestMetadataMap.get("REQUEST_ID")), reduceNode.getSenderStageId(), _reducerGrpcPort);
-      } else {
-        for (ServerInstance serverInstance : queryPlan.getStageMetadataMap().get(stageId).getServerInstances()) {
-          DistributedStagePlan distributedStagePlan =
-              QueryDispatcher.constructDistributedStagePlan(queryPlan, stageId, serverInstance);
-          _servers.get(serverInstance).processQuery(distributedStagePlan, requestMetadataMap);
-        }
-      }
-    }
-    Preconditions.checkNotNull(mailboxReceiveOperator);
+  @DataProvider(name = "testDataWithSqlToFinalRowCount")
+  private Object[][] provideTestSqlAndRowCount() {
+    return new Object[][] {
+        new Object[]{"SELECT * FROM b", 5},
+        new Object[]{"SELECT * FROM a", 15},
+
+        // Specifically table A has 15 rows (10 on server1 and 5 on server2) and table B has 5 rows (all on server1),
+        // thus the final JOIN result will be 15 x 1 = 15.
+        // Next we join with table C, which has 15 rows (5 on server1 and 10 on server2). Since the data is identical,
+        // each row of A JOIN B has the same col3 value as table C. The col3 values cycle through (1, 2, 42, 1, 2), so
+        // both sides hold 6 1s, 6 2s and 3 42s, and the total result count will be 36 + 36 + 9 = 81.
+        new Object[]{"SELECT * FROM a JOIN b ON a.col1 = b.col2 JOIN c ON a.col3 = c.col3", 81},
+
+        // Specifically table A has 15 rows (10 on server1 and 5 on server2) and table B has 5 rows (all on server1),
+        // thus the final JOIN result will be 15 x 1 = 15.
+        new Object[]{"SELECT * FROM a JOIN b on a.col1 = b.col2", 15},
+
+        // Specifically table A has 15 rows (10 on server1 and 5 on server2) and table B has 5 rows (all on server1),
+        // but only 1 out of every 5 rows from table A passes the filter, while all rows in table B are selected;
+        // thus the final JOIN result will be 1 x 3 x 1 = 3.
+        new Object[]{"SELECT a.col1, a.ts, b.col2, b.col3 FROM a JOIN b ON a.col1 = b.col2 "
+            + " WHERE a.col3 >= 0 AND a.col2 = 'foo' AND b.col3 >= 0", 3},
+    };
+  }
 
-    int count = 0;
-    int rowCount = 0;
+  protected static List<Object[]> reduceMailboxReceive(MailboxReceiveOperator mailboxReceiveOperator) {
     List<Object[]> resultRows = new ArrayList<>();
     DataTableBlock dataTableBlock;
-    while (count < 2) { // we have 2 servers sending data.
+    while (true) {
       dataTableBlock = mailboxReceiveOperator.nextBlock();
+      if (DataTableBlockUtils.isEndOfStream(dataTableBlock)) {
+        break;
+      }
       if (dataTableBlock.getDataTable() != null) {
         DataTable dataTable = dataTableBlock.getDataTable();
         int numRows = dataTable.getNumberOfRows();
         for (int rowId = 0; rowId < numRows; rowId++) {
           resultRows.add(extractRowFromDataTable(dataTable, rowId));
         }
-        rowCount += numRows;
       }
-      count++;
     }
-
-    // Assert that each of the 5 categories from left table is joined with right table.
-    // Specifically table A has 15 rows (10 on server1 and 5 on server2) and table B has 5 rows (all on server1),
-    // thus the final JOIN result will be 15 x 1 = 15.
-    // Next join with table C which has (5 on server1 and 10 on server2), since data is identical. each of the row of
-    // the A JOIN B will have identical value of col3 as table C.col3 has. Since the values are cycling between
-    // (1, 2, 42, 1, 2). we will have 6 1s, 6 2s, and 3 42s, total result count will be 36 + 36 + 9 = 81
-    Assert.assertEquals(rowCount, 81);
-
-    // assert that the next block is null (e.g. finished receiving).
-    dataTableBlock = mailboxReceiveOperator.nextBlock();
-    Assert.assertTrue(DataTableBlockUtils.isEndOfStream(dataTableBlock));
+    return resultRows;
   }
 
   protected MailboxReceiveOperator createReduceStageOperator(List<ServerInstance> sendingInstances, long jobId,
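For the 81-row expectation in the data provider above, the count falls out of per-key cardinalities: both the A JOIN B result (15 rows) and table C (15 rows) hold six 1s, six 2s and three 42s in col3, and an equi-join multiplies the matching counts per key. A small stand-alone check of that arithmetic (illustrative only, not part of the test suite):

    import java.util.Map;

    // Hypothetical check: expected inner equi-join cardinality is the sum over keys of
    // leftCount(key) * rightCount(key).
    public final class JoinCardinalitySketch {
      private JoinCardinalitySketch() {
      }

      public static long innerJoinRowCount(Map<Object, Integer> leftKeyCounts, Map<Object, Integer> rightKeyCounts) {
        long total = 0;
        for (Map.Entry<Object, Integer> entry : leftKeyCounts.entrySet()) {
          total += (long) entry.getValue() * rightKeyCounts.getOrDefault(entry.getKey(), 0);
        }
        return total;
      }

      public static void main(String[] args) {
        Map<Object, Integer> col3Counts = Map.of(1, 6, 2, 6, 42, 3);
        System.out.println(innerJoinRowCount(col3Counts, col3Counts)); // prints 81 = 36 + 36 + 9
      }
    }
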
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/QueryServerTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/QueryServerTest.java
index eeed6d20d9..735c54dd0c 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/QueryServerTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/QueryServerTest.java
@@ -24,6 +24,7 @@ import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Random;
 import java.util.concurrent.ExecutorService;
 import org.apache.pinot.common.proto.PinotQueryWorkerGrpc;
 import org.apache.pinot.common.proto.Worker;
@@ -32,7 +33,7 @@ import org.apache.pinot.query.QueryEnvironment;
 import org.apache.pinot.query.QueryEnvironmentTestUtils;
 import org.apache.pinot.query.planner.QueryPlan;
 import org.apache.pinot.query.planner.StageMetadata;
-import org.apache.pinot.query.planner.nodes.StageNode;
+import org.apache.pinot.query.planner.stage.StageNode;
 import org.apache.pinot.query.routing.WorkerInstance;
 import org.apache.pinot.query.runtime.QueryRunner;
 import org.apache.pinot.query.runtime.plan.serde.QueryPlanSerDeUtils;
@@ -41,12 +42,14 @@ import org.mockito.Mockito;
 import org.testng.Assert;
 import org.testng.annotations.AfterClass;
 import org.testng.annotations.BeforeClass;
+import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
 import static org.mockito.ArgumentMatchers.any;
 
 
 public class QueryServerTest {
+  private static final Random RANDOM_REQUEST_ID_GEN = new Random();
   private static final int QUERY_SERVER_COUNT = 2;
   private final Map<Integer, QueryServer> _queryServerMap = new HashMap<>();
   private final Map<Integer, ServerInstance> _queryServerInstanceMap = new HashMap<>();
@@ -84,13 +87,14 @@ public class QueryServerTest {
   }
 
   @SuppressWarnings("unchecked")
-  @Test
-  public void testWorkerAcceptsWorkerRequestCorrect()
+  @Test(dataProvider = "testDataWithSqlToCompiledAsWorkerRequest")
+  public void testWorkerAcceptsWorkerRequestCorrect(String sql)
       throws Exception {
-    QueryPlan queryPlan = _queryEnvironment.planQuery("SELECT * FROM a JOIN b ON a.col1 = b.col2");
+    QueryPlan queryPlan = _queryEnvironment.planQuery(sql);
 
     for (int stageId : queryPlan.getStageMetadataMap().keySet()) {
       if (stageId > 0) { // we do not test reduce stage.
+        // only get one worker request out.
         Worker.QueryRequest queryRequest = getQueryRequest(queryPlan, stageId);
 
         // submit the request for testing.
@@ -99,24 +103,39 @@ public class QueryServerTest {
         StageMetadata stageMetadata = queryPlan.getStageMetadataMap().get(stageId);
 
         // ensure mock query runner received correctly deserialized payload.
+        QueryRunner mockRunner = _queryRunnerMap.get(
+            Integer.parseInt(queryRequest.getMetadataOrThrow("SERVER_INSTANCE_PORT")));
+        String requestIdStr = queryRequest.getMetadataOrThrow("REQUEST_ID");
+
         // since submitRequest is async, we need to wait for the mockRunner to receive the query payload.
-        QueryRunner mockRunner = _queryRunnerMap.get(stageMetadata.getServerInstances().get(0).getPort());
         TestUtils.waitForCondition(aVoid -> {
           try {
             Mockito.verify(mockRunner).processQuery(Mockito.argThat(distributedStagePlan -> {
               StageNode stageNode = queryPlan.getQueryStageMap().get(stageId);
-              return isStageNodesEqual(stageNode, distributedStagePlan.getStageRoot()) && isMetadataMapsEqual(
-                  stageMetadata, distributedStagePlan.getMetadataMap().get(stageId));
-            }), any(ExecutorService.class), any(Map.class));
+              return isStageNodesEqual(stageNode, distributedStagePlan.getStageRoot())
+                  && isMetadataMapsEqual(stageMetadata, distributedStagePlan.getMetadataMap().get(stageId));
+            }), any(ExecutorService.class), Mockito.argThat(requestMetadataMap ->
+                requestIdStr.equals(requestMetadataMap.get("REQUEST_ID"))));
             return true;
           } catch (Throwable t) {
             return false;
           }
-        }, 1000L, "Error verifying mock QueryRunner intercepted query payload!");
+        }, 10000L, "Error verifying mock QueryRunner intercepted query payload!");
       }
     }
   }
 
+  @DataProvider(name = "testDataWithSqlToCompiledAsWorkerRequest")
+  private Object[][] provideTestSqlToCompiledToWorkerRequest() {
+    return new Object[][] {
+        new Object[]{"SELECT * FROM b"},
+        new Object[]{"SELECT * FROM a"},
+        new Object[]{"SELECT * FROM a JOIN b ON a.col3 = b.col3"},
+        new Object[]{"SELECT a.col1, a.ts, c.col2, c.col3 FROM a JOIN c ON a.col1 = c.col2 "
+            + " WHERE (a.col3 >= 0 OR a.col2 = 'foo') AND c.col3 >= 0"},
+    };
+  }
+
   private static boolean isMetadataMapsEqual(StageMetadata left, StageMetadata right) {
     return left.getServerInstances().equals(right.getServerInstances())
         && left.getServerInstanceToSegmentsMap().equals(right.getServerInstanceToSegmentsMap())
@@ -124,6 +143,8 @@ public class QueryServerTest {
   }
 
   private static boolean isStageNodesEqual(StageNode left, StageNode right) {
+    // This only checks that the stage tree structure is correct, because the input/stageId fields are not
+    // part of the generic proto ser/de; that is tested in the query planner.
     if (left.getStageId() != right.getStageId() || left.getClass() != right.getClass()
         || left.getInputs().size() != right.getInputs().size()) {
       return false;
@@ -153,6 +174,7 @@ public class QueryServerTest {
 
     return Worker.QueryRequest.newBuilder().setStagePlan(QueryPlanSerDeUtils.serialize(
             QueryDispatcher.constructDistributedStagePlan(queryPlan, stageId, serverInstance)))
+        .putMetadata("REQUEST_ID", String.valueOf(RANDOM_REQUEST_ID_GEN.nextLong()))
         .putMetadata("SERVER_INSTANCE_HOST", serverInstance.getHostname())
         .putMetadata("SERVER_INSTANCE_PORT", String.valueOf(serverInstance.getPort())).build();
   }

