You are viewing a plain-text version of this content. (The link to the canonical version was a hyperlink and is not preserved in plain text.)
Posted to commits@iotdb.apache.org by xi...@apache.org on 2022/05/19 08:08:46 UTC

[iotdb] 02/02: complete basic GroupByLevel distribution

This is an automated email from the ASF dual-hosted git repository.

xingtanzjr pushed a commit to branch xingtanzjr/agg_distribution_plan
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 8da9a7b64c79646b534890a4b8172bcc594e8f3a
Author: Jinrui.Zhang <xi...@gmail.com>
AuthorDate: Thu May 19 16:08:36 2022 +0800

    complete basic GroupByLevel distribution
---
 .../db/mpp/plan/planner/DistributionPlanner.java   | 158 +++++++++++++++++----
 .../plan/parameter/AggregationDescriptor.java      |   6 +-
 .../db/mpp/plan/plan/DistributionPlannerTest.java  |   1 -
 3 files changed, 135 insertions(+), 30 deletions(-)

diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/DistributionPlanner.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/DistributionPlanner.java
index 26ac5e456e..12d3bc8493 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/DistributionPlanner.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/DistributionPlanner.java
@@ -19,6 +19,7 @@
 package org.apache.iotdb.db.mpp.plan.planner;
 
 import org.apache.iotdb.common.rpc.thrift.TRegionReplicaSet;
+import org.apache.iotdb.commons.exception.IllegalPathException;
 import org.apache.iotdb.commons.path.PartialPath;
 import org.apache.iotdb.db.mpp.common.MPPQueryContext;
 import org.apache.iotdb.db.mpp.common.PlanFragmentId;
@@ -42,6 +43,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.SchemaQueryM
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.SchemaQueryScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.AggregationNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ExchangeNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.GroupByLevelNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.MultiChildNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TimeJoinNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.FragmentSinkNode;
@@ -55,6 +57,8 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.SourceNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.AggregationDescriptor;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.AggregationStep;
 import org.apache.iotdb.db.mpp.plan.statement.crud.QueryStatement;
+import org.apache.iotdb.db.query.expression.Expression;
+import org.apache.iotdb.db.query.expression.leaf.TimeSeriesOperand;
 
 import java.util.ArrayList;
 import java.util.Collections;
@@ -413,10 +417,12 @@ public class DistributionPlanner {
     private PlanNode planAggregationWithTimeJoin(
         TimeJoinNode root, DistributionPlanContext context) {
 
-      // Step 1: construct AggregationDescriptor for AggregationNode
+      List<SeriesAggregationSourceNode> sources = splitAggregationSourceByPartition(root, context);
+      Map<TRegionReplicaSet, List<SeriesAggregationSourceNode>> sourceGroup =
+          sources.stream().collect(Collectors.groupingBy(SourceNode::getRegionReplicaSet));
+
+      // construct AggregationDescriptor for AggregationNode
       List<AggregationDescriptor> rootAggDescriptorList = new ArrayList<>();
-      List<SeriesAggregationSourceNode> sources = new ArrayList<>();
-      Map<PartialPath, Integer> regionCountPerSeries = new HashMap<>();
       for (PlanNode child : root.getChildren()) {
         SeriesAggregationSourceNode handle = (SeriesAggregationSourceNode) child;
         handle
@@ -429,32 +435,7 @@ public class DistributionPlanner {
                           AggregationStep.FINAL,
                           descriptor.getInputExpressions()));
                 });
-        List<TRegionReplicaSet> dataDistribution =
-            analysis.getPartitionInfo(handle.getPartitionPath(), handle.getPartitionTimeFilter());
-        for (TRegionReplicaSet dataRegion : dataDistribution) {
-          SeriesAggregationSourceNode split = (SeriesAggregationSourceNode) handle.clone();
-          split.setPlanNodeId(context.queryContext.getQueryId().genPlanNodeId());
-          split.setRegionReplicaSet(dataRegion);
-          // Let each split reference different object of AggregationDescriptorList
-          split.setAggregationDescriptorList(
-              handle.getAggregationDescriptorList().stream()
-                  .map(AggregationDescriptor::deepClone)
-                  .collect(Collectors.toList()));
-          sources.add(split);
-        }
-        regionCountPerSeries.put(handle.getPartitionPath(), dataDistribution.size());
-      }
-
-      // Step 2: change the step for each SeriesAggregationSourceNode according to its split count
-      for (SeriesAggregationSourceNode node : sources) {
-        boolean isFinal = regionCountPerSeries.get(node.getPartitionPath()) == 1;
-        node.getAggregationDescriptorList()
-            .forEach(d -> d.setStep(isFinal ? AggregationStep.FINAL : AggregationStep.PARTIAL));
       }
-
-      Map<TRegionReplicaSet, List<SeriesAggregationSourceNode>> sourceGroup =
-          sources.stream().collect(Collectors.groupingBy(SourceNode::getRegionReplicaSet));
-
       AggregationNode aggregationNode =
           new AggregationNode(
               context.queryContext.getQueryId().genPlanNodeId(), rootAggDescriptorList);
@@ -482,6 +463,127 @@ public class DistributionPlanner {
       return aggregationNode;
     }
 
+    public PlanNode visitGroupByLevel(GroupByLevelNode root, DistributionPlanContext context) {
+      // Build a GroupByLevelNode tree over the region-split sources and return its root (newRoot)
+      List<SeriesAggregationSourceNode> sources = splitAggregationSourceByPartition(root, context);
+      Map<TRegionReplicaSet, List<SeriesAggregationSourceNode>> sourceGroup =
+          sources.stream().collect(Collectors.groupingBy(SourceNode::getRegionReplicaSet));
+
+      GroupByLevelNode newRoot = (GroupByLevelNode) root.clone();
+      final boolean[] addParent = {false};
+      sourceGroup.forEach(
+          (dataRegion, sourceNodes) -> {
+            if (sourceNodes.size() == 1) {
+              // A region with a single split is attached to newRoot directly
+              newRoot.addChild(sourceNodes.get(0));
+            } else if (!addParent[0]) {
+              // The first region with several splits is merged by newRoot itself
+              sourceNodes.forEach(newRoot::addChild);
+              addParent[0] = true;
+            } else {
+              // Other multi-split regions get their own GroupByLevelNode, cloned from root so
+              // the params stay consistent, but with a freshly generated PlanNodeId
+              GroupByLevelNode parentOfGroup = (GroupByLevelNode) root.clone();
+              parentOfGroup.setPlanNodeId(context.queryContext.getQueryId().genPlanNodeId());
+              sourceNodes.forEach(parentOfGroup::addChild);
+              newRoot.addChild(parentOfGroup);
+            }
+          });
+
+      // Recompute columns/descriptors bottom-up; level 0 (newRoot) gets the FINAL step
+      calculateGroupByLevelNodeAttributes(newRoot, 0);
+      return newRoot;
+    }
+
+    private void calculateGroupByLevelNodeAttributes(PlanNode node, int level) {
+      if (node == null) {
+        return;
+      }
+      node.getChildren().forEach(child -> calculateGroupByLevelNodeAttributes(child, level + 1));
+      if (!(node instanceof GroupByLevelNode)) {
+        return;
+      }
+      GroupByLevelNode handle = (GroupByLevelNode) node;
+
+      // Gather the output columns of all (already processed) children, in child order
+      List<String> childrenOutputColumns = new ArrayList<>();
+      handle
+          .getChildren()
+          .forEach(child -> childrenOutputColumns.addAll(child.getOutputColumnNames()));
+
+      // Keep output columns fed by a child column (directly or via an analysed expression).
+      // NOTE(review): this node's lists are rewritten in place; confirm clone() doesn't share them
+      List<String> outputColumnList = new ArrayList<>();
+      List<AggregationDescriptor> descriptorList = new ArrayList<>();
+      for (int i = 0; i < handle.getOutputColumnNames().size(); i++) {
+        String column = handle.getOutputColumnNames().get(i);
+        Set<Expression> originalExpressions =
+            analysis.getAggregationExpressions().getOrDefault(column, new HashSet<>());
+        List<Expression> descriptorExpression = new ArrayList<>();
+        for (String childColumn : childrenOutputColumns) {
+          if (childColumn.equals(column)) {
+            try {
+              descriptorExpression.add(new TimeSeriesOperand(new PartialPath(childColumn)));
+            } catch (IllegalPathException e) {
+              throw new RuntimeException("error when plan distribution aggregation query", e);
+            }
+            continue;
+          }
+          for (Expression exp : originalExpressions) {
+            if (exp.getExpressionString().equals(childColumn)) {
+              descriptorExpression.add(exp);
+            }
+          }
+        }
+        if (descriptorExpression.isEmpty()) {
+          continue;
+        }
+        AggregationDescriptor descriptor = handle.getAggregationDescriptorList().get(i).deepClone();
+        descriptor.setStep(level == 0 ? AggregationStep.FINAL : AggregationStep.PARTIAL);
+        descriptor.setInputExpressions(descriptorExpression);
+
+        outputColumnList.add(column);
+        descriptorList.add(descriptor);
+      }
+      handle.getOutputColumnNames().clear();
+      handle.getOutputColumnNames().addAll(outputColumnList);
+      handle.getAggregationDescriptorList().clear();
+      handle.getAggregationDescriptorList().addAll(descriptorList);
+    }
+
+    private List<SeriesAggregationSourceNode> splitAggregationSourceByPartition(
+        MultiChildNode root, DistributionPlanContext context) {
+      // Phase 1: clone every child once per region returned by analysis.getPartitionInfo
+      List<SeriesAggregationSourceNode> splits = new ArrayList<>();
+      Map<PartialPath, Integer> splitCountByPath = new HashMap<>();
+      for (PlanNode childNode : root.getChildren()) {
+        SeriesAggregationSourceNode proto = (SeriesAggregationSourceNode) childNode;
+        List<TRegionReplicaSet> regions =
+            analysis.getPartitionInfo(proto.getPartitionPath(), proto.getPartitionTimeFilter());
+        regions.forEach(
+            region -> {
+              SeriesAggregationSourceNode copy = (SeriesAggregationSourceNode) proto.clone();
+              copy.setPlanNodeId(context.queryContext.getQueryId().genPlanNodeId());
+              copy.setRegionReplicaSet(region);
+              // Give each copy its own descriptor objects so the step set below stays local
+              List<AggregationDescriptor> owned = new ArrayList<>();
+              proto.getAggregationDescriptorList().forEach(d -> owned.add(d.deepClone()));
+              copy.setAggregationDescriptorList(owned);
+              splits.add(copy);
+            });
+        // One entry per path; a later child with the same path overwrites the count
+        splitCountByPath.put(proto.getPartitionPath(), regions.size());
+      }
+
+      // Phase 2: a series found in exactly one region is FINAL at the source, else PARTIAL
+      for (SeriesAggregationSourceNode split : splits) {
+        boolean singleRegion = splitCountByPath.get(split.getPartitionPath()) == 1;
+        AggregationStep step = singleRegion ? AggregationStep.FINAL : AggregationStep.PARTIAL;
+        split.getAggregationDescriptorList().forEach(d -> d.setStep(step));
+      }
+      return splits;
+    }
+
     public PlanNode visit(PlanNode node, DistributionPlanContext context) {
       return node.accept(this, context);
     }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/AggregationDescriptor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/AggregationDescriptor.java
index c6933fd19d..f1667f7657 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/AggregationDescriptor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/AggregationDescriptor.java
@@ -42,7 +42,7 @@ public class AggregationDescriptor {
    *
    * <p>example: select sum(s1) from root.sg.d1; expression [root.sg.d1.s1] will be in this field.
    */
-  private final List<Expression> inputExpressions;
+  private List<Expression> inputExpressions;
 
   private String parametersString;
 
@@ -122,6 +122,10 @@ public class AggregationDescriptor {
     this.step = step;
   }
 
+  public void setInputExpressions(List<Expression> expressions) {
+    inputExpressions = expressions; // keeps the caller's list by reference (no copy)
+  }
+
   public AggregationDescriptor deepClone() {
     return new AggregationDescriptor(
         this.getAggregationType(), this.step, this.getInputExpressions());
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/DistributionPlannerTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/DistributionPlannerTest.java
index bc88569410..a1025fe686 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/DistributionPlannerTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/DistributionPlannerTest.java
@@ -44,7 +44,6 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.FragmentInstance;
 import org.apache.iotdb.db.mpp.plan.planner.plan.LogicalQueryPlan;
 import org.apache.iotdb.db.mpp.plan.planner.plan.SubPlan;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNode;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeUtil;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.SchemaQueryMergeNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.TimeSeriesSchemaScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ExchangeNode;