You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iotdb.apache.org by xi...@apache.org on 2022/05/31 13:05:58 UTC

[iotdb] 01/01: complete sliding window distribution planning

This is an automated email from the ASF dual-hosted git repository.

xingtanzjr pushed a commit to branch xingtanzjr/agg_distribution_0531
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit ff7bbffa2c40b027060a44403855a56e149e61e1
Author: Jinrui.Zhang <xi...@gmail.com>
AuthorDate: Tue May 31 21:05:38 2022 +0800

    complete sliding window distribution planning
---
 .../distribution/DistributionPlanContext.java      |  11 ++
 .../plan/planner/distribution/SourceRewriter.java  | 181 ++++++++++++++++-----
 .../node/process/SlidingWindowAggregationNode.java |   6 +-
 3 files changed, 154 insertions(+), 44 deletions(-)

diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/DistributionPlanContext.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/DistributionPlanContext.java
index 6f9e16e6ee..88f13eddf4 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/DistributionPlanContext.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/DistributionPlanContext.java
@@ -22,9 +22,20 @@ package org.apache.iotdb.db.mpp.plan.planner.distribution;
 import org.apache.iotdb.db.mpp.common.MPPQueryContext;
 
 public class DistributionPlanContext {
+  protected boolean isRoot;
   protected MPPQueryContext queryContext;
 
   protected DistributionPlanContext(MPPQueryContext queryContext) {
+    this.isRoot = true;
     this.queryContext = queryContext;
   }
+
+  protected DistributionPlanContext copy() {
+    return new DistributionPlanContext(queryContext).setRoot(isRoot); // keep isRoot as well
+  }
+
+  protected DistributionPlanContext setRoot(boolean isRoot) {
+    this.isRoot = isRoot;
+    return this;
+  }
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/SourceRewriter.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/SourceRewriter.java
index 7cbc01469a..c3f85386b9 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/SourceRewriter.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/SourceRewriter.java
@@ -37,6 +37,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.DeviceViewNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.GroupByLevelNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.LastQueryMergeNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.MultiChildNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.SlidingWindowAggregationNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TimeJoinNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedLastQueryScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedSeriesAggregationScanNode;
@@ -52,6 +53,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.AggregationStep;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.GroupByLevelDescriptor;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -452,6 +454,17 @@ public class SourceRewriter extends SimplePlanNodeRewriter<DistributionPlanConte
     return false;
   }
 
+  // For SlidingWindowAggregationNode roots: rewrite the child with isRoot=false, then call super.
+  @Override
+  public PlanNode visitSlidingWindowAggregation(
+      SlidingWindowAggregationNode node, DistributionPlanContext context) {
+    DistributionPlanContext childContext = context.copy().setRoot(false);
+    PlanNode child = visit(node.getChild(), childContext);
+    node.getChildren().clear();
+    node.addChild(child);
+    return super.visitSlidingWindowAggregation(node, context);
+  }
+
   private PlanNode planAggregationWithTimeJoin(TimeJoinNode root, DistributionPlanContext context) {
 
     List<SeriesAggregationSourceNode> sources = splitAggregationSourceByPartition(root, context);
@@ -469,7 +482,7 @@ public class SourceRewriter extends SimplePlanNodeRewriter<DistributionPlanConte
                 rootAggDescriptorList.add(
                     new AggregationDescriptor(
                         descriptor.getAggregationType(),
-                        AggregationStep.FINAL,
+                        context.isRoot ? AggregationStep.FINAL : AggregationStep.PARTIAL,
                         descriptor.getInputExpressions()));
               });
     }
@@ -512,6 +525,63 @@ public class SourceRewriter extends SimplePlanNodeRewriter<DistributionPlanConte
     Map<TRegionReplicaSet, List<SeriesAggregationSourceNode>> sourceGroup =
         sources.stream().collect(Collectors.groupingBy(SourceNode::getRegionReplicaSet));
 
+    boolean containsSlidingWindow =
+        root.getChildren().size() == 1
+            && root.getChildren().get(0) instanceof SlidingWindowAggregationNode;
+
+    GroupByLevelNode newRoot =
+        containsSlidingWindow
+            ? groupSourcesForGroupByLevelWithSlidingWindow(
+                root,
+                (SlidingWindowAggregationNode) root.getChildren().get(0),
+                sourceGroup,
+                context)
+            : groupSourcesForGroupByLevel(root, sourceGroup, context);
+
+    // Then, we calculate the attributes for GroupByLevelNode in each level
+    calculateGroupByLevelNodeAttributes(newRoot, 0);
+    return newRoot;
+  }
+
+  // One SlidingWindowAggregationNode per region; the first is a direct child, others are wrapped.
+  private GroupByLevelNode groupSourcesForGroupByLevelWithSlidingWindow(
+      GroupByLevelNode root,
+      SlidingWindowAggregationNode slidingWindowNode,
+      Map<TRegionReplicaSet, List<SeriesAggregationSourceNode>> sourceGroup,
+      DistributionPlanContext context) {
+    GroupByLevelNode newRoot = (GroupByLevelNode) root.clone();
+    List<SlidingWindowAggregationNode> groups = new ArrayList<>();
+    for (List<SeriesAggregationSourceNode> sourceNodes : sourceGroup.values()) {
+      SlidingWindowAggregationNode parentOfGroup =
+          (SlidingWindowAggregationNode) slidingWindowNode.clone();
+      parentOfGroup.setPlanNodeId(context.queryContext.getQueryId().genPlanNodeId());
+      if (sourceNodes.size() == 1) {
+        parentOfGroup.addChild(sourceNodes.get(0));
+      } else {
+        TimeJoinNode timeJoinNode =
+            new TimeJoinNode(
+                context.queryContext.getQueryId().genPlanNodeId(), root.getScanOrder());
+        sourceNodes.forEach(timeJoinNode::addChild);
+        parentOfGroup.addChild(timeJoinNode); // the join must be attached, or its inputs are lost
+      }
+      groups.add(parentOfGroup);
+    }
+    if (!groups.isEmpty()) {
+      newRoot.addChild(groups.get(0));
+    }
+    for (int i = 1; i < groups.size(); i++) {
+      GroupByLevelNode parent = (GroupByLevelNode) root.clone();
+      parent.setPlanNodeId(context.queryContext.getQueryId().genPlanNodeId());
+      parent.addChild(groups.get(i));
+      newRoot.addChild(parent);
+    }
+    return newRoot;
+  }
+
+  private GroupByLevelNode groupSourcesForGroupByLevel(
+      GroupByLevelNode root,
+      Map<TRegionReplicaSet, List<SeriesAggregationSourceNode>> sourceGroup,
+      DistributionPlanContext context) {
     GroupByLevelNode newRoot = (GroupByLevelNode) root.clone();
     final boolean[] addParent = {false};
     sourceGroup.forEach(
@@ -523,8 +593,6 @@ public class SourceRewriter extends SimplePlanNodeRewriter<DistributionPlanConte
               sourceNodes.forEach(newRoot::addChild);
               addParent[0] = true;
             } else {
-              // We clone a TimeJoinNode from root to make the params to be consistent.
-              // But we need to assign a new ID to it
               GroupByLevelNode parentOfGroup = (GroupByLevelNode) root.clone();
               parentOfGroup.setPlanNodeId(context.queryContext.getQueryId().genPlanNodeId());
               sourceNodes.forEach(parentOfGroup::addChild);
@@ -532,55 +600,69 @@ public class SourceRewriter extends SimplePlanNodeRewriter<DistributionPlanConte
             }
           }
         });
-
-    // Then, we calculate the attributes for GroupByLevelNode in each level
-    calculateGroupByLevelNodeAttributes(newRoot, 0);
     return newRoot;
   }
 
+  // TODO: (xingtanzjr) consider to implement the descriptor construction in every class
   private void calculateGroupByLevelNodeAttributes(PlanNode node, int level) {
     if (node == null) {
       return;
     }
     node.getChildren().forEach(child -> calculateGroupByLevelNodeAttributes(child, level + 1));
-    if (!(node instanceof GroupByLevelNode)) {
-      return;
-    }
-    GroupByLevelNode handle = (GroupByLevelNode) node;
 
     // Construct all outputColumns from children. Using Set here to avoid duplication
     Set<String> childrenOutputColumns = new HashSet<>();
-    handle
-        .getChildren()
-        .forEach(child -> childrenOutputColumns.addAll(child.getOutputColumnNames()));
-
-    // Check every OutputColumn of GroupByLevelNode and set the Expression of corresponding
-    // AggregationDescriptor
-    List<GroupByLevelDescriptor> descriptorList = new ArrayList<>();
-    for (GroupByLevelDescriptor originalDescriptor : handle.getGroupByLevelDescriptors()) {
-      List<Expression> descriptorExpression = new ArrayList<>();
-      for (String childColumn : childrenOutputColumns) {
-        // If this condition matched, the childColumn should come from GroupByLevelNode
-        if (isAggColumnMatchExpression(childColumn, originalDescriptor.getOutputExpression())) {
-          descriptorExpression.add(originalDescriptor.getOutputExpression());
-          continue;
-        }
-        for (Expression exp : originalDescriptor.getInputExpressions()) {
-          if (isAggColumnMatchExpression(childColumn, exp)) {
-            descriptorExpression.add(exp);
+    node.getChildren().forEach(child -> childrenOutputColumns.addAll(child.getOutputColumnNames()));
+
+    if (node instanceof SlidingWindowAggregationNode) {
+      SlidingWindowAggregationNode handle = (SlidingWindowAggregationNode) node;
+      List<AggregationDescriptor> descriptorList = new ArrayList<>();
+      for (AggregationDescriptor originalDescriptor : handle.getAggregationDescriptorList()) {
+        boolean keep = false;
+        for (String childColumn : childrenOutputColumns) {
+          for (Expression exp : originalDescriptor.getInputExpressions()) {
+            if (isAggColumnMatchExpression(childColumn, exp)) {
+              keep = true;
+            }
           }
         }
+        if (keep) {
+          descriptorList.add(originalDescriptor);
+        }
       }
-      if (descriptorExpression.size() == 0) {
-        continue;
-      }
-      GroupByLevelDescriptor descriptor = originalDescriptor.deepClone();
-      descriptor.setStep(level == 0 ? AggregationStep.FINAL : AggregationStep.PARTIAL);
-      descriptor.setInputExpressions(descriptorExpression);
+      handle.setAggregationDescriptorList(descriptorList);
+    }
+
+    if (node instanceof GroupByLevelNode) {
+      GroupByLevelNode handle = (GroupByLevelNode) node;
+      // Check every OutputColumn of GroupByLevelNode and set the Expression of corresponding
+      // AggregationDescriptor
+      List<GroupByLevelDescriptor> descriptorList = new ArrayList<>();
+      for (GroupByLevelDescriptor originalDescriptor : handle.getGroupByLevelDescriptors()) {
+        List<Expression> descriptorExpression = new ArrayList<>();
+        for (String childColumn : childrenOutputColumns) {
+          // If this condition matched, the childColumn should come from GroupByLevelNode
+          if (isAggColumnMatchExpression(childColumn, originalDescriptor.getOutputExpression())) {
+            descriptorExpression.add(originalDescriptor.getOutputExpression());
+            continue;
+          }
+          for (Expression exp : originalDescriptor.getInputExpressions()) {
+            if (isAggColumnMatchExpression(childColumn, exp)) {
+              descriptorExpression.add(exp);
+            }
+          }
+        }
+        if (descriptorExpression.size() == 0) {
+          continue;
+        }
+        GroupByLevelDescriptor descriptor = originalDescriptor.deepClone();
+        descriptor.setStep(level == 0 ? AggregationStep.FINAL : AggregationStep.PARTIAL);
+        descriptor.setInputExpressions(descriptorExpression);
 
-      descriptorList.add(descriptor);
+        descriptorList.add(descriptor);
+      }
+      handle.setGroupByLevelDescriptors(descriptorList);
     }
-    handle.setGroupByLevelDescriptors(descriptorList);
   }
 
   // TODO: (xingtanzjr) need to confirm the logic when processing UDF
@@ -592,26 +674,27 @@ public class SourceRewriter extends SimplePlanNodeRewriter<DistributionPlanConte
   }
 
   private List<SeriesAggregationSourceNode> splitAggregationSourceByPartition(
-      MultiChildNode root, DistributionPlanContext context) {
+      PlanNode root, DistributionPlanContext context) {
+    // Step 0: get all SeriesAggregationSourceNode in PlanNodeTree
+    List<SeriesAggregationSourceNode> rawSources = findAggregationSourceNode(root);
     // Step 1: split SeriesAggregationSourceNode according to data partition
     List<SeriesAggregationSourceNode> sources = new ArrayList<>();
     Map<PartialPath, Integer> regionCountPerSeries = new HashMap<>();
-    for (PlanNode child : root.getChildren()) {
-      SeriesAggregationSourceNode handle = (SeriesAggregationSourceNode) child;
+    for (SeriesAggregationSourceNode child : rawSources) {
       List<TRegionReplicaSet> dataDistribution =
-          analysis.getPartitionInfo(handle.getPartitionPath(), handle.getPartitionTimeFilter());
+          analysis.getPartitionInfo(child.getPartitionPath(), child.getPartitionTimeFilter());
       for (TRegionReplicaSet dataRegion : dataDistribution) {
-        SeriesAggregationSourceNode split = (SeriesAggregationSourceNode) handle.clone();
+        SeriesAggregationSourceNode split = (SeriesAggregationSourceNode) child.clone();
         split.setPlanNodeId(context.queryContext.getQueryId().genPlanNodeId());
         split.setRegionReplicaSet(dataRegion);
         // Let each split reference different object of AggregationDescriptorList
         split.setAggregationDescriptorList(
-            handle.getAggregationDescriptorList().stream()
+            child.getAggregationDescriptorList().stream()
                 .map(AggregationDescriptor::deepClone)
                 .collect(Collectors.toList()));
         sources.add(split);
       }
-      regionCountPerSeries.put(handle.getPartitionPath(), dataDistribution.size());
+      regionCountPerSeries.put(child.getPartitionPath(), dataDistribution.size());
     }
 
     // Step 2: change the step for each SeriesAggregationSourceNode according to its split count
@@ -626,6 +709,18 @@ public class SourceRewriter extends SimplePlanNodeRewriter<DistributionPlanConte
     return sources;
   }
 
+  private List<SeriesAggregationSourceNode> findAggregationSourceNode(PlanNode node) {
+    if (node == null) {
+      return new ArrayList<>();
+    }
+    if (node instanceof SeriesAggregationSourceNode) {
+      return Collections.singletonList((SeriesAggregationSourceNode) node);
+    }
+    List<SeriesAggregationSourceNode> ret = new ArrayList<>();
+    node.getChildren().forEach(child -> ret.addAll(findAggregationSourceNode(child)));
+    return ret;
+  }
+
   public PlanNode visit(PlanNode node, DistributionPlanContext context) {
     return node.accept(this, context);
   }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/SlidingWindowAggregationNode.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/SlidingWindowAggregationNode.java
index d0df526b93..5efbacbd4e 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/SlidingWindowAggregationNode.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/SlidingWindowAggregationNode.java
@@ -40,7 +40,7 @@ public class SlidingWindowAggregationNode extends ProcessNode {
 
   // The list of aggregate functions, each AggregateDescriptor will be output as one column of
   // result TsBlock
-  private final List<AggregationDescriptor> aggregationDescriptorList;
+  private List<AggregationDescriptor> aggregationDescriptorList;
 
   // The parameter of `group by time`.
   private final GroupByTimeParameter groupByTimeParameter;
@@ -74,6 +74,10 @@ public class SlidingWindowAggregationNode extends ProcessNode {
     return aggregationDescriptorList;
   }
 
+  public void setAggregationDescriptorList(List<AggregationDescriptor> aggregationDescriptorList) {
+    this.aggregationDescriptorList = aggregationDescriptorList;
+  }
+
   public GroupByTimeParameter getGroupByTimeParameter() {
     return groupByTimeParameter;
   }