You are viewing a plain-text version of this content. The link to the canonical version is not preserved in this plain-text rendering.
Posted to commits@iotdb.apache.org by xi...@apache.org on 2022/11/23 15:03:03 UTC

[iotdb] branch xingtanzjr/device_agg_optimize updated: optimize the distribution plan for aggregation with align by device

This is an automated email from the ASF dual-hosted git repository.

xingtanzjr pushed a commit to branch xingtanzjr/device_agg_optimize
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/xingtanzjr/device_agg_optimize by this push:
     new b6f3bdf95b optimize the distribution plan for aggregation with align by device
b6f3bdf95b is described below

commit b6f3bdf95b9b62f5fd064ef4268fd52eac6bfce2
Author: Jinrui.Zhang <xi...@gmail.com>
AuthorDate: Wed Nov 23 23:02:51 2022 +0800

    optimize the distribution plan for aggregation with align by device
---
 .../planner/distribution/DistributionPlanner.java  |   2 +-
 .../planner/distribution/ExchangeNodeAdder.java    | 111 +++++++++++++++++++++
 .../distribution/AggregationDistributionTest.java  |  34 ++++++-
 3 files changed, 141 insertions(+), 6 deletions(-)

diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/DistributionPlanner.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/DistributionPlanner.java
index 60882f5cdf..7aafb9198c 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/DistributionPlanner.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/DistributionPlanner.java
@@ -55,7 +55,7 @@ public class DistributionPlanner {
   }
 
   public PlanNode addExchangeNode(PlanNode root) {
-    ExchangeNodeAdder adder = new ExchangeNodeAdder();
+    ExchangeNodeAdder adder = new ExchangeNodeAdder(this.analysis);
     return adder.visit(root, new NodeGroupContext(context));
   }
 
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/ExchangeNodeAdder.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/ExchangeNodeAdder.java
index 60c0a6a9b6..7a1a186861 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/ExchangeNodeAdder.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/ExchangeNodeAdder.java
@@ -21,6 +21,7 @@ package org.apache.iotdb.db.mpp.plan.planner.distribution;
 
 import org.apache.iotdb.common.rpc.thrift.TRegionReplicaSet;
 import org.apache.iotdb.commons.partition.DataPartition;
+import org.apache.iotdb.db.mpp.plan.analyze.Analysis;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanVisitor;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.WritePlanNode;
@@ -52,9 +53,12 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.LastQueryScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.SeriesAggregationScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.SeriesScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.SourceNode;
+import org.apache.iotdb.db.mpp.plan.statement.crud.QueryStatement;
 
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
@@ -62,6 +66,13 @@ import java.util.stream.Collectors;
 import static com.google.common.collect.ImmutableList.toImmutableList;
 
 public class ExchangeNodeAdder extends PlanVisitor<PlanNode, NodeGroupContext> {
+
+  private final Analysis analysis;
+
+  public ExchangeNodeAdder(Analysis analysis) {
+    this.analysis = analysis;
+  }
+
   @Override
   public PlanNode visitPlan(PlanNode node, NodeGroupContext context) {
     // TODO: (xingtanzjr) we apply no action for IWritePlanNode currently
@@ -183,6 +194,10 @@ public class ExchangeNodeAdder extends PlanVisitor<PlanNode, NodeGroupContext> {
 
   @Override
   public PlanNode visitDeviceView(DeviceViewNode node, NodeGroupContext context) {
+    // A temporary way to decrease the FragmentInstance for aggregation with device view.
+    if (isAggregationQuery()) {
+      return processDeviceViewWithAggregation(node, context);
+    }
     return processMultiChildNode(node, context);
   }
 
@@ -241,6 +256,98 @@ public class ExchangeNodeAdder extends PlanVisitor<PlanNode, NodeGroupContext> {
     return processMultiChildNode(node, context);
   }
 
+  private PlanNode processDeviceViewWithAggregation(DeviceViewNode node, NodeGroupContext context) {
+    // group all the children by DataRegion distribution
+    Map<TRegionReplicaSet, DeviceViewGroup> deviceViewGroupMap = new HashMap<>();
+    for (int i = 0; i < node.getDevices().size(); i++) {
+      String device = node.getDevices().get(i);
+      PlanNode rawChildNode = node.getChildren().get(i);
+      PlanNode visitedChild = visit(rawChildNode, context);
+      TRegionReplicaSet region = context.getNodeDistribution(visitedChild.getPlanNodeId()).region;
+      DeviceViewGroup group = deviceViewGroupMap.computeIfAbsent(region, DeviceViewGroup::new);
+      group.addChild(device, visitedChild);
+    }
+
+    // Generate DeviceViewNode for each group
+    List<DeviceViewNode> deviceViewNodeList = new ArrayList<>();
+    for (DeviceViewGroup group : deviceViewGroupMap.values()) {
+      DeviceViewNode deviceViewNode =
+          new DeviceViewNode(
+              context.queryContext.getQueryId().genPlanNodeId(),
+              node.getMergeOrderParameter(),
+              node.getOutputColumnNames(),
+              node.getDeviceToMeasurementIndexesMap());
+      for (int i = 0; i < group.devices.size(); i++) {
+        deviceViewNode.addChildDeviceNode(group.devices.get(i), group.children.get(i));
+      }
+      context.putNodeDistribution(
+          deviceViewNode.getPlanNodeId(),
+          new NodeDistribution(
+              NodeDistributionType.SAME_WITH_ALL_CHILDREN,
+              context.getNodeDistribution(deviceViewNode.getChildren().get(0).getPlanNodeId())
+                  .region));
+      deviceViewNodeList.add(deviceViewNode);
+    }
+
+    if (deviceViewNodeList.size() == 1) {
+      return deviceViewNodeList.get(0);
+    }
+
+    DeviceMergeNode deviceMergeNode =
+        new DeviceMergeNode(
+            context.queryContext.getQueryId().genPlanNodeId(),
+            node.getMergeOrderParameter(),
+            node.getDevices());
+
+    // Each child of deviceMergeNode has different TRegionReplicaSet, so we can select any one from
+    // its child
+    deviceMergeNode.addChild(deviceViewNodeList.get(0));
+    context.putNodeDistribution(
+        deviceMergeNode.getPlanNodeId(),
+        new NodeDistribution(
+            NodeDistributionType.SAME_WITH_SOME_CHILD,
+            context.getNodeDistribution(deviceViewNodeList.get(0).getPlanNodeId()).region));
+
+    // Add ExchangeNode for any other child except first one
+    for (int i = 1; i < deviceViewNodeList.size(); i++) {
+      PlanNode child = deviceViewNodeList.get(i);
+      ExchangeNode exchangeNode =
+          new ExchangeNode(context.queryContext.getQueryId().genPlanNodeId());
+      exchangeNode.setChild(child);
+      exchangeNode.setOutputColumnNames(child.getOutputColumnNames());
+      deviceMergeNode.addChild(exchangeNode);
+    }
+    return deviceMergeNode;
+  }
+
+  private static class DeviceViewGroup {
+    public TRegionReplicaSet regionReplicaSet;
+    public List<PlanNode> children;
+    public List<String> devices;
+
+    public DeviceViewGroup(TRegionReplicaSet regionReplicaSet) {
+      this.regionReplicaSet = regionReplicaSet;
+      this.children = new LinkedList<>();
+      this.devices = new LinkedList<>();
+    }
+
+    public void addChild(String device, PlanNode child) {
+      devices.add(device);
+      children.add(child);
+    }
+
+    public int hashCode() {
+      return regionReplicaSet.hashCode();
+    }
+
+    public boolean equals(Object o) {
+      if (o instanceof DeviceViewGroup) {
+        return regionReplicaSet.equals(((DeviceViewGroup) o).regionReplicaSet);
+      }
+      return false;
+    }
+  }
+
   private PlanNode processMultiChildNode(MultiChildProcessNode node, NodeGroupContext context) {
     MultiChildProcessNode newNode = (MultiChildProcessNode) node.clone();
     List<PlanNode> visitedChildren = new ArrayList<>();
@@ -350,6 +457,10 @@ public class ExchangeNodeAdder extends PlanVisitor<PlanNode, NodeGroupContext> {
     return true;
   }
 
+  private boolean isAggregationQuery() {
+    return ((QueryStatement) analysis.getStatement()).isAggregationQuery();
+  }
+
   public PlanNode visit(PlanNode node, NodeGroupContext context) {
     return node.accept(this, context);
   }
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/AggregationDistributionTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/AggregationDistributionTest.java
index 6db6b2fa36..8074203fc2 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/AggregationDistributionTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/AggregationDistributionTest.java
@@ -35,6 +35,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.LogicalQueryPlan;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.AggregationNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.DeviceMergeNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.DeviceViewNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.GroupByLevelNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.SlidingWindowAggregationNode;
@@ -712,8 +713,8 @@ public class AggregationDistributionTest {
   }
 
   @Test
-  public void testAlignByDevice2Device2Region() throws IllegalPathException {
-    QueryId queryId = new QueryId("test_align_by_device_2_device_2_region");
+  public void testAlignByDevice2Device3Region() {
+    QueryId queryId = new QueryId("test_align_by_device_2_device_3_region");
     MPPQueryContext context =
         new MPPQueryContext("", queryId, null, new TEndPoint(), new TEndPoint());
     String sql = "select count(s1), count(s2) from root.sg.d1,root.sg.d22 align by device";
@@ -729,11 +730,34 @@ public class AggregationDistributionTest {
         plan.getInstances().get(1).getFragment().getPlanNodeTree().getChildren().get(0);
     PlanNode f3Root =
         plan.getInstances().get(2).getFragment().getPlanNodeTree().getChildren().get(0);
+    assertTrue(f1Root instanceof DeviceMergeNode);
+    assertTrue(f2Root instanceof TimeJoinNode);
+    assertTrue(f3Root instanceof DeviceViewNode);
+    assertTrue(f3Root.getChildren().get(0) instanceof AggregationNode);
+    assertTrue(f1Root.getChildren().get(0) instanceof DeviceViewNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(0) instanceof AggregationNode);
+    assertEquals(3, f1Root.getChildren().get(0).getChildren().get(0).getChildren().size());
+  }
+
+  @Test
+  public void testAlignByDevice2Device2Region() {
+    QueryId queryId = new QueryId("test_align_by_device_2_device_2_region");
+    MPPQueryContext context =
+        new MPPQueryContext("", queryId, null, new TEndPoint(), new TEndPoint());
+    String sql = "select count(s1), count(s2) from root.sg.d333,root.sg.d4444 align by device";
+    Analysis analysis = Util.analyze(sql, context);
+    PlanNode logicalPlanNode = Util.genLogicalPlan(analysis, context);
+    DistributionPlanner planner =
+        new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
+    DistributedQueryPlan plan = planner.planFragments();
+    assertEquals(3, plan.getInstances().size());
+    PlanNode f1Root =
+        plan.getInstances().get(0).getFragment().getPlanNodeTree().getChildren().get(0);
+    PlanNode f2Root =
+        plan.getInstances().get(1).getFragment().getPlanNodeTree().getChildren().get(0);
     assertTrue(f1Root instanceof DeviceViewNode);
     assertTrue(f2Root instanceof TimeJoinNode);
-    assertTrue(f3Root instanceof AggregationNode);
-    assertTrue(f1Root.getChildren().get(0) instanceof AggregationNode);
-    assertEquals(3, f1Root.getChildren().get(0).getChildren().size());
+    assertEquals(2, f1Root.getChildren().size());
   }
 
   private void verifyGroupByLevelDescriptor(