You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by si...@apache.org on 2022/06/30 22:58:43 UTC

[pinot] branch master updated: Add data schema to stage nodes (#8985)

This is an automated email from the ASF dual-hosted git repository.

siddteotia pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new d86d4abc63 Add data schema to stage nodes (#8985)
d86d4abc63 is described below

commit d86d4abc63e84857f3d2980799feadf86c9ad080
Author: Rong Rong <wa...@gmail.com>
AuthorDate: Thu Jun 30 15:58:38 2022 -0700

    Add data schema to stage nodes (#8985)
    
    * initial commit to add data schema to stage nodes
    
    * fix merge conflicts
    
    Co-authored-by: Rong Rong <ro...@startree.ai>
---
 pinot-common/src/main/proto/plan.proto             |  4 +-
 .../query/planner/logical/RelToStageConverter.java | 60 +++++++++++++++++++---
 .../pinot/query/planner/logical/StagePlanner.java  | 14 ++---
 .../query/planner/stage/AbstractStageNode.java     | 20 +++++++-
 .../pinot/query/planner/stage/AggregateNode.java   |  5 +-
 .../pinot/query/planner/stage/FilterNode.java      |  5 +-
 .../apache/pinot/query/planner/stage/JoinNode.java |  5 +-
 .../query/planner/stage/MailboxReceiveNode.java    |  5 +-
 .../pinot/query/planner/stage/MailboxSendNode.java |  9 ++--
 .../pinot/query/planner/stage/ProjectNode.java     |  5 +-
 .../pinot/query/planner/stage/StageNode.java       |  7 ++-
 .../query/planner/stage/StageNodeSerDeUtils.java   | 17 ++++++
 .../pinot/query/planner/stage/TableScanNode.java   |  5 +-
 .../pinot/query/planner/stage/SerDeUtilsTest.java  |  3 ++
 14 files changed, 132 insertions(+), 32 deletions(-)

diff --git a/pinot-common/src/main/proto/plan.proto b/pinot-common/src/main/proto/plan.proto
index 8e75a31a42..783e37b8b0 100644
--- a/pinot-common/src/main/proto/plan.proto
+++ b/pinot-common/src/main/proto/plan.proto
@@ -43,7 +43,10 @@ message StageNode {
   int32 stageId = 1;
   string nodeName = 2;
   repeated StageNode inputs = 3;
   ObjectField objectField = 4;
+  // New fields take fresh numbers; objectField keeps number 4 so plans serialized before this change still parse.
+  repeated string columnNames = 5;
+  repeated string columnDataTypes = 6;
 }
 
 // MemberVariableField defines the serialized format of the member variables of a class object.
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java
index 23f8fb6db8..5cf10a879c 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java
@@ -29,8 +29,11 @@ import org.apache.calcite.rel.logical.LogicalFilter;
 import org.apache.calcite.rel.logical.LogicalJoin;
 import org.apache.calcite.rel.logical.LogicalProject;
 import org.apache.calcite.rel.logical.LogicalTableScan;
+import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.rel.type.RelRecordType;
 import org.apache.calcite.rex.RexCall;
+import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.planner.PlannerUtils;
 import org.apache.pinot.query.planner.partitioning.FieldSelectionKeySelector;
 import org.apache.pinot.query.planner.stage.AggregateNode;
@@ -75,22 +78,23 @@ public final class RelToStageConverter {
   }
 
   private static StageNode convertLogicalAggregate(LogicalAggregate node, int currentStageId) {
-    return new AggregateNode(currentStageId, node.getAggCallList(), node.getGroupSet());
+    return new AggregateNode(currentStageId, toDataSchema(node.getRowType()), node.getAggCallList(),
+        node.getGroupSet());
   }
 
   private static StageNode convertLogicalProject(LogicalProject node, int currentStageId) {
-    return new ProjectNode(currentStageId, node.getProjects());
+    return new ProjectNode(currentStageId, toDataSchema(node.getRowType()), node.getProjects());
   }
 
   private static StageNode convertLogicalFilter(LogicalFilter node, int currentStageId) {
-    return new FilterNode(currentStageId, node.getCondition());
+    return new FilterNode(currentStageId, toDataSchema(node.getRowType()), node.getCondition());
   }
 
   private static StageNode convertLogicalTableScan(LogicalTableScan node, int currentStageId) {
     String tableName = node.getTable().getQualifiedName().get(0);
     List<String> columnNames = node.getRowType().getFieldList().stream()
         .map(RelDataTypeField::getName).collect(Collectors.toList());
-    return new TableScanNode(currentStageId, tableName, columnNames);
+    return new TableScanNode(currentStageId, toDataSchema(node.getRowType()), tableName, columnNames);
   }
 
   private static StageNode convertLogicalJoin(LogicalJoin node, int currentStageId) {
@@ -104,7 +108,71 @@ public final class RelToStageConverter {
 
     FieldSelectionKeySelector leftFieldSelectionKeySelector = new FieldSelectionKeySelector(predicateColumns.get(0));
     FieldSelectionKeySelector rightFieldSelectionKeySelector = new FieldSelectionKeySelector(predicateColumns.get(1));
-    return new JoinNode(currentStageId, joinType, Collections.singletonList(new JoinNode.JoinClause(
-        leftFieldSelectionKeySelector, rightFieldSelectionKeySelector)));
+    return new JoinNode(currentStageId, toDataSchema(node.getRowType()), joinType, Collections.singletonList(
+        new JoinNode.JoinClause(leftFieldSelectionKeySelector, rightFieldSelectionKeySelector)));
+  }
+
+  /**
+   * Converts a Calcite row type into a Pinot {@link DataSchema}.
+   *
+   * @param rowType row type of a relational node; only {@link RelRecordType} is supported
+   * @return schema whose column names and data types follow the field order of {@code rowType}
+   * @throws IllegalArgumentException if {@code rowType} is not a {@link RelRecordType}
+   * @throws IllegalStateException if a field has a SQL type with no Pinot column data type mapping
+   */
+  private static DataSchema toDataSchema(RelDataType rowType) {
+    if (!(rowType instanceof RelRecordType)) {
+      throw new IllegalArgumentException("Unsupported RelDataType: " + rowType);
+    }
+    // Read names and types from the same field list so the two arrays cannot get out of step.
+    List<RelDataTypeField> fields = rowType.getFieldList();
+    String[] columnNames = new String[fields.size()];
+    DataSchema.ColumnDataType[] columnDataTypes = new DataSchema.ColumnDataType[fields.size()];
+    for (int i = 0; i < fields.size(); i++) {
+      RelDataTypeField field = fields.get(i);
+      columnNames[i] = field.getName();
+      columnDataTypes[i] = convertColumnDataType(field);
+    }
+    return new DataSchema(columnNames, columnDataTypes);
+  }
+
+  /**
+   * Maps the SQL type of a Calcite field to the Pinot column data type used in stage schemas.
+   *
+   * @param relDataTypeField field whose SQL type name is inspected
+   * @return the corresponding {@link DataSchema.ColumnDataType}
+   * @throws IllegalStateException if the SQL type has no mapping below
+   */
+  private static DataSchema.ColumnDataType convertColumnDataType(RelDataTypeField relDataTypeField) {
+    switch (relDataTypeField.getType().getSqlTypeName()) {
+      case BOOLEAN:
+        return DataSchema.ColumnDataType.BOOLEAN;
+      case TINYINT:
+      case SMALLINT:
+      case INTEGER:
+        return DataSchema.ColumnDataType.INT;
+      case BIGINT:
+        return DataSchema.ColumnDataType.LONG;
+      case DECIMAL:
+        return DataSchema.ColumnDataType.BIG_DECIMAL;
+      case FLOAT:
+        return DataSchema.ColumnDataType.FLOAT;
+      case REAL:
+      case DOUBLE:
+        return DataSchema.ColumnDataType.DOUBLE;
+      case DATE:
+      case TIME:
+      case TIMESTAMP:
+        return DataSchema.ColumnDataType.TIMESTAMP;
+      // Calcite types character literals as CHAR, so fixed-length strings must map to STRING as well.
+      case CHAR:
+      case VARCHAR:
+        return DataSchema.ColumnDataType.STRING;
+      case BINARY:
+      case VARBINARY:
+        return DataSchema.ColumnDataType.BYTES;
+      default:
+        throw new IllegalStateException("Unexpected RelDataTypeField: " + relDataTypeField.getType());
+    }
   }
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java
index 8d9bbbc51d..02c2e7fb28 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java
@@ -72,9 +72,10 @@ public class StagePlanner {
     // global root needs to send results back to the ROOT, a.k.a. the client response node. the last stage only has one
     // receiver so doesn't matter what the exchange type is. setting it to SINGLETON by default.
     StageNode globalReceiverNode =
-        new MailboxReceiveNode(0, globalStageRoot.getStageId(), RelDistribution.Type.SINGLETON);
-    StageNode globalSenderNode = new MailboxSendNode(globalStageRoot.getStageId(), globalReceiverNode.getStageId(),
-        RelDistribution.Type.SINGLETON);
+        new MailboxReceiveNode(0, globalStageRoot.getDataSchema(), globalStageRoot.getStageId(),
+            RelDistribution.Type.SINGLETON);
+    StageNode globalSenderNode = new MailboxSendNode(globalStageRoot.getStageId(), globalStageRoot.getDataSchema(),
+        globalReceiverNode.getStageId(), RelDistribution.Type.SINGLETON);
     globalSenderNode.addInput(globalStageRoot);
     _queryStageMap.put(globalSenderNode.getStageId(), globalSenderNode);
     StageMetadata stageMetadata = _stageMetadataMap.get(globalSenderNode.getStageId());
@@ -103,9 +104,10 @@ public class StagePlanner {
       RelDistribution.Type exchangeType = distribution.getType();
 
       // 2. make an exchange sender and receiver node pair
-      StageNode mailboxReceiver = new MailboxReceiveNode(currentStageId, nextStageRoot.getStageId(), exchangeType);
-      StageNode mailboxSender = new MailboxSendNode(nextStageRoot.getStageId(), mailboxReceiver.getStageId(),
-          exchangeType, exchangeType == RelDistribution.Type.HASH_DISTRIBUTED
+      StageNode mailboxReceiver = new MailboxReceiveNode(currentStageId, nextStageRoot.getDataSchema(),
+          nextStageRoot.getStageId(), exchangeType);
+      StageNode mailboxSender = new MailboxSendNode(nextStageRoot.getStageId(), nextStageRoot.getDataSchema(),
+          mailboxReceiver.getStageId(), exchangeType, exchangeType == RelDistribution.Type.HASH_DISTRIBUTED
           ? new FieldSelectionKeySelector(distributionKeys) : null);
       mailboxSender.addInput(nextStageRoot);
 
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AbstractStageNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AbstractStageNode.java
index 0f84a10f68..1de069f0b4 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AbstractStageNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AbstractStageNode.java
@@ -21,6 +21,7 @@ package org.apache.pinot.query.planner.stage;
 import java.util.ArrayList;
 import java.util.List;
 import org.apache.pinot.common.proto.Plan;
+import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.planner.serde.ProtoSerializable;
 import org.apache.pinot.query.planner.serde.ProtoSerializationUtils;
 
@@ -29,12 +30,23 @@ public abstract class AbstractStageNode implements StageNode, ProtoSerializable
 
   protected final int _stageId;
   protected final List<StageNode> _inputs;
+  protected DataSchema _dataSchema;
 
   public AbstractStageNode(int stageId) {
+    this(stageId, null);
+  }
+
+  public AbstractStageNode(int stageId, DataSchema dataSchema) {
     _stageId = stageId;
+    _dataSchema = dataSchema;
     _inputs = new ArrayList<>();
   }
 
+  @Override
+  public int getStageId() {
+    return _stageId;
+  }
+
   @Override
   public List<StageNode> getInputs() {
     return _inputs;
@@ -46,8 +58,17 @@ public abstract class AbstractStageNode implements StageNode, ProtoSerializable
   }
 
   @Override
-  public int getStageId() {
-    return _stageId;
+  public DataSchema getDataSchema() {
+    return _dataSchema;
+  }
+
+  /**
+   * Replaces the data schema of this node. Deserialization relies on this because it instantiates nodes
+   * from a stage id alone and attaches the schema afterwards.
+   */
+  @Override
+  public void setDataSchema(DataSchema dataSchema) {
+    _dataSchema = dataSchema;
   }
 
   @Override
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AggregateNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AggregateNode.java
index ae41d14a79..d0a28b0cbd 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AggregateNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AggregateNode.java
@@ -24,6 +24,7 @@ import java.util.List;
 import java.util.stream.Collectors;
 import org.apache.calcite.rel.core.AggregateCall;
 import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.planner.logical.RexExpression;
 import org.apache.pinot.query.planner.serde.ProtoProperties;
 
@@ -38,8 +39,8 @@ public class AggregateNode extends AbstractStageNode {
     super(stageId);
   }
 
-  public AggregateNode(int stageId, List<AggregateCall> aggCalls, ImmutableBitSet groupSet) {
-    super(stageId);
+  public AggregateNode(int stageId, DataSchema dataSchema, List<AggregateCall> aggCalls, ImmutableBitSet groupSet) {
+    super(stageId, dataSchema);
     _aggCalls = aggCalls.stream().map(RexExpression::toRexExpression).collect(Collectors.toList());
     _groupSet = new ArrayList<>(groupSet.cardinality());
     Iterator<Integer> groupSetIt = groupSet.iterator();
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/FilterNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/FilterNode.java
index c169a61970..2281d4e712 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/FilterNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/FilterNode.java
@@ -19,6 +19,7 @@
 package org.apache.pinot.query.planner.stage;
 
 import org.apache.calcite.rex.RexNode;
+import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.planner.logical.RexExpression;
 import org.apache.pinot.query.planner.serde.ProtoProperties;
 
@@ -31,8 +32,8 @@ public class FilterNode extends AbstractStageNode {
     super(stageId);
   }
 
-  public FilterNode(int currentStageId, RexNode condition) {
-    super(currentStageId);
+  public FilterNode(int currentStageId, DataSchema dataSchema, RexNode condition) {
+    super(currentStageId, dataSchema);
     _condition = RexExpression.toRexExpression(condition);
   }
 
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/JoinNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/JoinNode.java
index 0f9e007871..eeec1b31e9 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/JoinNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/JoinNode.java
@@ -20,6 +20,7 @@ package org.apache.pinot.query.planner.stage;
 
 import java.util.List;
 import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.planner.partitioning.FieldSelectionKeySelector;
 import org.apache.pinot.query.planner.partitioning.KeySelector;
 import org.apache.pinot.query.planner.serde.ProtoProperties;
@@ -35,8 +36,8 @@ public class JoinNode extends AbstractStageNode {
     super(stageId);
   }
 
-  public JoinNode(int stageId, JoinRelType joinRelType, List<JoinClause> criteria) {
-    super(stageId);
+  public JoinNode(int stageId, DataSchema dataSchema, JoinRelType joinRelType, List<JoinClause> criteria) {
+    super(stageId, dataSchema);
     _joinRelType = joinRelType;
     _criteria = criteria;
   }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxReceiveNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxReceiveNode.java
index edadf30570..abba178865 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxReceiveNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxReceiveNode.java
@@ -19,6 +19,7 @@
 package org.apache.pinot.query.planner.stage;
 
 import org.apache.calcite.rel.RelDistribution;
+import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.planner.serde.ProtoProperties;
 
 
@@ -32,9 +33,9 @@ public class MailboxReceiveNode extends AbstractStageNode {
     super(stageId);
   }
 
-  public MailboxReceiveNode(int stageId, int senderStageId,
+  public MailboxReceiveNode(int stageId, DataSchema dataSchema, int senderStageId,
       RelDistribution.Type exchangeType) {
-    super(stageId);
+    super(stageId, dataSchema);
     _senderStageId = senderStageId;
     _exchangeType = exchangeType;
   }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxSendNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxSendNode.java
index 1400b61f82..962dbc73c4 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxSendNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/MailboxSendNode.java
@@ -20,6 +20,7 @@ package org.apache.pinot.query.planner.stage;
 
 import javax.annotation.Nullable;
 import org.apache.calcite.rel.RelDistribution;
+import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.planner.partitioning.KeySelector;
 import org.apache.pinot.query.planner.serde.ProtoProperties;
 
@@ -36,15 +37,15 @@ public class MailboxSendNode extends AbstractStageNode {
     super(stageId);
   }
 
-  public MailboxSendNode(int stageId, int receiverStageId,
+  public MailboxSendNode(int stageId, DataSchema dataSchema, int receiverStageId,
       RelDistribution.Type exchangeType) {
     // When exchangeType is not HASH_DISTRIBUTE, no partitionKeySelector is needed.
-    this(stageId, receiverStageId, exchangeType, null);
+    this(stageId, dataSchema, receiverStageId, exchangeType, null);
   }
 
-  public MailboxSendNode(int stageId, int receiverStageId,
+  public MailboxSendNode(int stageId, DataSchema dataSchema, int receiverStageId,
       RelDistribution.Type exchangeType, @Nullable KeySelector<Object[], Object[]> partitionKeySelector) {
-    super(stageId);
+    super(stageId, dataSchema);
     _receiverStageId = receiverStageId;
     _exchangeType = exchangeType;
     _partitionKeySelector = partitionKeySelector;
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/ProjectNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/ProjectNode.java
index 9a026aae18..1b1f88b6d0 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/ProjectNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/ProjectNode.java
@@ -21,6 +21,7 @@ package org.apache.pinot.query.planner.stage;
 import java.util.List;
 import java.util.stream.Collectors;
 import org.apache.calcite.rex.RexNode;
+import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.planner.logical.RexExpression;
 import org.apache.pinot.query.planner.serde.ProtoProperties;
 
@@ -32,8 +33,8 @@ public class ProjectNode extends AbstractStageNode {
   public ProjectNode(int stageId) {
     super(stageId);
   }
-  public ProjectNode(int currentStageId, List<RexNode> projects) {
-    super(currentStageId);
+  public ProjectNode(int currentStageId, DataSchema dataSchema, List<RexNode> projects) {
+    super(currentStageId, dataSchema);
     _projects = projects.stream().map(RexExpression::toRexExpression).collect(Collectors.toList());
   }
 
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNode.java
index 45e65a8c21..6efa59ce2d 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNode.java
@@ -20,6 +20,7 @@ package org.apache.pinot.query.planner.stage;
 
 import java.io.Serializable;
 import java.util.List;
+import org.apache.pinot.common.utils.DataSchema;
 
 
 /**
@@ -32,9 +33,13 @@ import java.util.List;
  */
 public interface StageNode extends Serializable {
 
+  int getStageId();
+
   List<StageNode> getInputs();
 
   void addInput(StageNode stageNode);
 
-  int getStageId();
+  DataSchema getDataSchema();
+
+  void setDataSchema(DataSchema dataSchema);
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNodeSerDeUtils.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNodeSerDeUtils.java
index 8d341a207c..c76792371f 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNodeSerDeUtils.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/StageNodeSerDeUtils.java
@@ -19,6 +19,7 @@
 package org.apache.pinot.query.planner.stage;
 
 import org.apache.pinot.common.proto.Plan;
+import org.apache.pinot.common.utils.DataSchema;
 
 
 public final class StageNodeSerDeUtils {
@@ -28,6 +29,7 @@ public final class StageNodeSerDeUtils {
 
   public static AbstractStageNode deserializeStageNode(Plan.StageNode protoNode) {
     AbstractStageNode stageNode = newNodeInstance(protoNode.getNodeName(), protoNode.getStageId());
+    stageNode.setDataSchema(extractDataSchema(protoNode));
     stageNode.fromObjectField(protoNode.getObjectField());
     for (Plan.StageNode protoChild : protoNode.getInputsList()) {
       stageNode.addInput(deserializeStageNode(protoChild));
@@ -40,12 +42,43 @@ public final class StageNodeSerDeUtils {
         .setStageId(stageNode.getStageId())
         .setNodeName(stageNode.getClass().getSimpleName())
         .setObjectField(stageNode.toObjectField());
+    // Nodes built through the stage-id-only constructor carry no schema; emit no columns for them.
+    DataSchema dataSchema = stageNode.getDataSchema();
+    if (dataSchema != null) {
+      for (int i = 0; i < dataSchema.getColumnNames().length; i++) {
+        builder.addColumnNames(dataSchema.getColumnName(i));
+        builder.addColumnDataTypes(dataSchema.getColumnDataType(i).name());
+      }
+    }
     for (StageNode childNode : stageNode.getInputs()) {
       builder.addInputs(serializeStageNode((AbstractStageNode) childNode));
     }
     return builder.build();
   }
 
+  /**
+   * Rebuilds a {@link DataSchema} from the parallel column name and column data type lists of a proto node.
+   *
+   * @param protoNode serialized stage node
+   * @return schema with one column per list entry, in list order
+   * @throws IllegalArgumentException if the two lists differ in length, or a type name is not a
+   *     {@link DataSchema.ColumnDataType} constant
+   */
+  private static DataSchema extractDataSchema(Plan.StageNode protoNode) {
+    int numColumns = protoNode.getColumnNamesCount();
+    if (protoNode.getColumnDataTypesCount() != numColumns) {
+      throw new IllegalArgumentException("Mismatched schema in stage node: " + numColumns + " column names but "
+          + protoNode.getColumnDataTypesCount() + " column data types");
+    }
+    String[] columnNames = new String[numColumns];
+    DataSchema.ColumnDataType[] columnDataTypes = new DataSchema.ColumnDataType[numColumns];
+    for (int i = 0; i < numColumns; i++) {
+      columnNames[i] = protoNode.getColumnNames(i);
+      columnDataTypes[i] = DataSchema.ColumnDataType.valueOf(protoNode.getColumnDataTypes(i));
+    }
+    return new DataSchema(columnNames, columnDataTypes);
+  }
+
   private static AbstractStageNode newNodeInstance(String nodeName, int stageId) {
     switch (nodeName) {
       case "TableScanNode":
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/TableScanNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/TableScanNode.java
index 7151f84f56..01dda45d3a 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/TableScanNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/TableScanNode.java
@@ -19,6 +19,7 @@
 package org.apache.pinot.query.planner.stage;
 
 import java.util.List;
+import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.planner.serde.ProtoProperties;
 
 
@@ -32,8 +33,8 @@ public class TableScanNode extends AbstractStageNode {
     super(stageId);
   }
 
-  public TableScanNode(int stageId, String tableName, List<String> tableScanColumns) {
-    super(stageId);
+  public TableScanNode(int stageId, DataSchema dataSchema, String tableName, List<String> tableScanColumns) {
+    super(stageId, dataSchema);
     _tableName = tableName;
     _tableScanColumns = tableScanColumns;
   }
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/stage/SerDeUtilsTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/stage/SerDeUtilsTest.java
index db9294252c..7bd9fcb6ad 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/stage/SerDeUtilsTest.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/stage/SerDeUtilsTest.java
@@ -39,6 +39,9 @@ public class SerDeUtilsTest extends QueryEnvironmentTestBase {
       Plan.StageNode serializedStageNode = StageNodeSerDeUtils.serializeStageNode((AbstractStageNode) stageNode);
       StageNode deserializedStageNode = StageNodeSerDeUtils.deserializeStageNode(serializedStageNode);
       Assert.assertTrue(isObjectEqual(stageNode, deserializedStageNode));
+      Assert.assertEquals(deserializedStageNode.getStageId(), stageNode.getStageId());
+      Assert.assertEquals(deserializedStageNode.getDataSchema(), stageNode.getDataSchema());
+      Assert.assertEquals(deserializedStageNode.getInputs().size(), stageNode.getInputs().size());
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org