You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text copy).
Posted to commits@iotdb.apache.org by xi...@apache.org on 2022/05/31 15:45:08 UTC

[iotdb] branch xingtanzjr/agg_distribution_0531 updated: add unit tests for slidingWindowNode distribution plan

This is an automated email from the ASF dual-hosted git repository.

xingtanzjr pushed a commit to branch xingtanzjr/agg_distribution_0531
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/xingtanzjr/agg_distribution_0531 by this push:
     new c660cee52b add unit tests for slidingWindowNode distribution plan
c660cee52b is described below

commit c660cee52bfe5e3935de0c88baffd203e024ae09
Author: Jinrui.Zhang <xi...@gmail.com>
AuthorDate: Tue May 31 23:44:58 2022 +0800

    add unit tests for slidingWindowNode distribution plan
---
 .../planner/distribution/ExchangeNodeAdder.java    |  18 ++
 .../plan/planner/distribution/SourceRewriter.java  |   7 +-
 .../node/process/SlidingWindowAggregationNode.java |   4 +
 .../distribution/AggregationDistributionTest.java  | 277 +++++++++++++++++++++
 4 files changed, 303 insertions(+), 3 deletions(-)

diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/ExchangeNodeAdder.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/ExchangeNodeAdder.java
index 4de9558101..3373722007 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/ExchangeNodeAdder.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/ExchangeNodeAdder.java
@@ -37,6 +37,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ExchangeNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.GroupByLevelNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.LastQueryMergeNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.MultiChildNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.SlidingWindowAggregationNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TimeJoinNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedLastQueryScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedSeriesAggregationScanNode;
@@ -250,6 +251,23 @@ public class ExchangeNodeAdder extends PlanVisitor<PlanNode, NodeGroupContext> {
     return newNode;
   }
 
+  @Override
+  public PlanNode visitSlidingWindowAggregation(
+      SlidingWindowAggregationNode node, NodeGroupContext context) {
+    return processOneChildNode(node, context);
+  }
+
+  private PlanNode processOneChildNode(PlanNode node, NodeGroupContext context) {
+    PlanNode newNode = node.clone();
+    PlanNode child = visit(node.getChildren().get(0), context);
+    newNode.addChild(child);
+    TRegionReplicaSet dataRegion = context.getNodeDistribution(child.getPlanNodeId()).region;
+    context.putNodeDistribution(
+        newNode.getPlanNodeId(),
+        new NodeDistribution(NodeDistributionType.SAME_WITH_ALL_CHILDREN, dataRegion));
+    return newNode;
+  }
+
   private TRegionReplicaSet calculateDataRegionByChildren(
       List<PlanNode> children, NodeGroupContext context) {
     // Step 1: calculate the count of children group by DataRegion.
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/SourceRewriter.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/SourceRewriter.java
index c3f85386b9..6c20206302 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/SourceRewriter.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/SourceRewriter.java
@@ -460,9 +460,9 @@ public class SourceRewriter extends SimplePlanNodeRewriter<DistributionPlanConte
       SlidingWindowAggregationNode node, DistributionPlanContext context) {
     DistributionPlanContext childContext = context.copy().setRoot(false);
     PlanNode child = visit(node.getChild(), childContext);
-    node.getChildren().clear();
-    node.addChild(child);
-    return super.visitSlidingWindowAggregation(node, context);
+    PlanNode newRoot = node.clone();
+    newRoot.addChild(child);
+    return newRoot;
   }
 
   private PlanNode planAggregationWithTimeJoin(TimeJoinNode root, DistributionPlanContext context) {
@@ -562,6 +562,7 @@ public class SourceRewriter extends SimplePlanNodeRewriter<DistributionPlanConte
                 new TimeJoinNode(
                     context.queryContext.getQueryId().genPlanNodeId(), root.getScanOrder());
             sourceNodes.forEach(timeJoinNode::addChild);
+            parentOfGroup.addChild(timeJoinNode);
           }
           groups.add(parentOfGroup);
         });
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/SlidingWindowAggregationNode.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/SlidingWindowAggregationNode.java
index 5efbacbd4e..9362e40227 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/SlidingWindowAggregationNode.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/SlidingWindowAggregationNode.java
@@ -179,4 +179,8 @@ public class SlidingWindowAggregationNode extends ProcessNode {
   public int hashCode() {
     return Objects.hash(super.hashCode(), aggregationDescriptorList, groupByTimeParameter, child);
   }
+
+  public String toString() {
+    return String.format("SlidingWindowAggregationNode-%s", getPlanNodeId());
+  }
 }
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/AggregationDistributionTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/AggregationDistributionTest.java
index cfa79db920..e47068a7ff 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/AggregationDistributionTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/AggregationDistributionTest.java
@@ -35,8 +35,10 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.FragmentInstance;
 import org.apache.iotdb.db.mpp.plan.planner.plan.LogicalQueryPlan;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.AggregationNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.GroupByLevelNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.LimitNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.SlidingWindowAggregationNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TimeJoinNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedSeriesScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.SeriesAggregationScanNode;
@@ -45,10 +47,12 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.SeriesScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.AggregationDescriptor;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.AggregationStep;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.GroupByLevelDescriptor;
+import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.GroupByTimeParameter;
 import org.apache.iotdb.db.mpp.plan.statement.component.OrderBy;
 import org.apache.iotdb.db.query.aggregation.AggregationType;
 import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
 
+import org.junit.Assert;
 import org.junit.Test;
 
 import java.util.ArrayList;
@@ -57,6 +61,7 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.stream.Collectors;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
@@ -101,6 +106,54 @@ public class AggregationDistributionTest {
     root.getChildren().forEach(child -> verifyAggregationStep(expected, child));
   }
 
+  @Test
+  public void testTimeJoinAggregationWithSlidingWindow() throws IllegalPathException {
+    QueryId queryId = new QueryId("test_query_time_join_agg_with_sliding");
+    TimeJoinNode timeJoinNode = new TimeJoinNode(queryId.genPlanNodeId(), OrderBy.TIMESTAMP_ASC);
+    String d1s1Path = "root.sg.d1.s1";
+    timeJoinNode.addChild(genAggregationSourceNode(queryId, d1s1Path, AggregationType.COUNT));
+
+    String d3s1Path = "root.sg.d333.s1";
+    timeJoinNode.addChild(genAggregationSourceNode(queryId, d3s1Path, AggregationType.COUNT));
+
+    SlidingWindowAggregationNode slidingWindowAggregationNode =
+        genSlidingWindowAggregationNode(
+            queryId,
+            Arrays.asList(new PartialPath(d1s1Path), new PartialPath(d3s1Path)),
+            AggregationType.COUNT,
+            AggregationStep.PARTIAL,
+            null);
+
+    slidingWindowAggregationNode.addChild(timeJoinNode);
+
+    Analysis analysis = Util.constructAnalysis();
+    MPPQueryContext context =
+        new MPPQueryContext("", queryId, null, new TEndPoint(), new TEndPoint());
+    DistributionPlanner planner =
+        new DistributionPlanner(
+            analysis, new LogicalQueryPlan(context, slidingWindowAggregationNode));
+    DistributedQueryPlan plan = planner.planFragments();
+    assertEquals(3, plan.getInstances().size());
+    Map<String, AggregationStep> expectedStep = new HashMap<>();
+    expectedStep.put(d1s1Path, AggregationStep.PARTIAL);
+    expectedStep.put(d3s1Path, AggregationStep.PARTIAL);
+    List<FragmentInstance> fragmentInstances = plan.getInstances();
+    fragmentInstances.forEach(f -> verifyAggregationStep(expectedStep, f.getFragment().getRoot()));
+    AggregationNode aggregationNode =
+        (AggregationNode)
+            fragmentInstances
+                .get(0)
+                .getFragment()
+                .getRoot()
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0);
+    aggregationNode
+        .getAggregationDescriptorList()
+        .forEach(d -> Assert.assertEquals(AggregationStep.PARTIAL, d.getStep()));
+  }
+
   @Test
   public void testTimeJoinAggregationMultiPerRegion() throws IllegalPathException {
     QueryId queryId = new QueryId("test_query_time_join_aggregation");
@@ -234,6 +287,89 @@ public class AggregationDistributionTest {
         (GroupByLevelNode) fragmentInstances.get(1).getFragment().getRoot().getChildren().get(0));
   }
 
+  @Test
+  public void testGroupByLevelNodeWithSlidingWindow() throws IllegalPathException {
+    QueryId queryId = new QueryId("test_group_by_level_with_sliding_window");
+    String d3s1Path = "root.sg.d333.s1";
+    String d4s1Path = "root.sg.d4444.s1";
+    String groupedPath = "root.sg.*.s1";
+
+    SlidingWindowAggregationNode slidingWindowAggregationNode =
+        genSlidingWindowAggregationNode(
+            queryId,
+            Arrays.asList(new PartialPath(d3s1Path), new PartialPath(d4s1Path)),
+            AggregationType.COUNT,
+            AggregationStep.PARTIAL,
+            null);
+    TimeJoinNode timeJoinNode = new TimeJoinNode(queryId.genPlanNodeId(), OrderBy.TIMESTAMP_ASC);
+    timeJoinNode.addChild(genAggregationSourceNode(queryId, d3s1Path, AggregationType.COUNT));
+    timeJoinNode.addChild(genAggregationSourceNode(queryId, d4s1Path, AggregationType.COUNT));
+    slidingWindowAggregationNode.addChild(timeJoinNode);
+
+    GroupByLevelNode groupByLevelNode =
+        new GroupByLevelNode(
+            new PlanNodeId("TestGroupByLevelNode"),
+            Collections.singletonList(slidingWindowAggregationNode),
+            Collections.singletonList(
+                new GroupByLevelDescriptor(
+                    AggregationType.COUNT,
+                    AggregationStep.FINAL,
+                    Arrays.asList(
+                        new TimeSeriesOperand(new PartialPath(d3s1Path)),
+                        new TimeSeriesOperand(new PartialPath(d4s1Path))),
+                    new TimeSeriesOperand(new PartialPath(groupedPath)))),
+            null,
+            OrderBy.TIMESTAMP_ASC);
+
+    Analysis analysis = Util.constructAnalysis();
+    MPPQueryContext context =
+        new MPPQueryContext("", queryId, null, new TEndPoint(), new TEndPoint());
+    DistributionPlanner planner =
+        new DistributionPlanner(analysis, new LogicalQueryPlan(context, groupByLevelNode));
+    DistributedQueryPlan plan = planner.planFragments();
+    assertEquals(2, plan.getInstances().size());
+    Map<String, AggregationStep> expectedStep = new HashMap<>();
+    expectedStep.put(d3s1Path, AggregationStep.PARTIAL);
+    expectedStep.put(d4s1Path, AggregationStep.PARTIAL);
+    List<FragmentInstance> fragmentInstances = plan.getInstances();
+    fragmentInstances.forEach(f -> verifyAggregationStep(expectedStep, f.getFragment().getRoot()));
+
+    Map<String, List<String>> expectedDescriptorValue = new HashMap<>();
+    expectedDescriptorValue.put(groupedPath, Arrays.asList(groupedPath, d3s1Path, d4s1Path));
+    verifyGroupByLevelDescriptor(
+        expectedDescriptorValue,
+        (GroupByLevelNode) fragmentInstances.get(0).getFragment().getRoot().getChildren().get(0));
+
+    Map<String, List<String>> expectedDescriptorValue2 = new HashMap<>();
+    expectedDescriptorValue2.put(groupedPath, Arrays.asList(d3s1Path, d4s1Path));
+    verifyGroupByLevelDescriptor(
+        expectedDescriptorValue2,
+        (GroupByLevelNode) fragmentInstances.get(1).getFragment().getRoot().getChildren().get(0));
+
+    verifySlidingWindowDescriptor(
+        Arrays.asList(d3s1Path, d4s1Path),
+        (SlidingWindowAggregationNode)
+            fragmentInstances
+                .get(0)
+                .getFragment()
+                .getRoot()
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0));
+    verifySlidingWindowDescriptor(
+        Arrays.asList(d3s1Path, d4s1Path),
+        (SlidingWindowAggregationNode)
+            fragmentInstances
+                .get(1)
+                .getFragment()
+                .getRoot()
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0));
+  }
+
   @Test
   public void testGroupByLevelTwoSeries() throws IllegalPathException {
     QueryId queryId = new QueryId("test_group_by_level_two_series");
@@ -349,6 +485,118 @@ public class AggregationDistributionTest {
         (GroupByLevelNode) fragmentInstances.get(2).getFragment().getRoot().getChildren().get(0));
   }
 
+  @Test
+  public void testGroupByLevelWithSliding2Series2Devices3Regions() throws IllegalPathException {
+    QueryId queryId = new QueryId("test_group_by_level_two_series");
+    String d1s1Path = "root.sg.d1.s1";
+    String d1s2Path = "root.sg.d1.s2";
+    String d2s1Path = "root.sg.d22.s1";
+    String groupedPathS1 = "root.sg.*.s1";
+    String groupedPathS2 = "root.sg.*.s2";
+
+    TimeJoinNode timeJoinNode = new TimeJoinNode(queryId.genPlanNodeId(), OrderBy.TIMESTAMP_ASC);
+    timeJoinNode.addChild(genAggregationSourceNode(queryId, d1s1Path, AggregationType.COUNT));
+    timeJoinNode.addChild(genAggregationSourceNode(queryId, d1s2Path, AggregationType.COUNT));
+    timeJoinNode.addChild(genAggregationSourceNode(queryId, d2s1Path, AggregationType.COUNT));
+
+    SlidingWindowAggregationNode slidingWindowAggregationNode =
+        genSlidingWindowAggregationNode(
+            queryId,
+            Arrays.asList(
+                new PartialPath(d1s1Path), new PartialPath(d1s2Path), new PartialPath(d2s1Path)),
+            AggregationType.COUNT,
+            AggregationStep.PARTIAL,
+            null);
+    slidingWindowAggregationNode.addChild(timeJoinNode);
+
+    GroupByLevelNode groupByLevelNode =
+        new GroupByLevelNode(
+            new PlanNodeId("TestGroupByLevelNode"),
+            Collections.singletonList(slidingWindowAggregationNode),
+            Arrays.asList(
+                new GroupByLevelDescriptor(
+                    AggregationType.COUNT,
+                    AggregationStep.FINAL,
+                    Arrays.asList(
+                        new TimeSeriesOperand(new PartialPath(d1s1Path)),
+                        new TimeSeriesOperand(new PartialPath(d2s1Path))),
+                    new TimeSeriesOperand(new PartialPath(groupedPathS1))),
+                new GroupByLevelDescriptor(
+                    AggregationType.COUNT,
+                    AggregationStep.FINAL,
+                    Collections.singletonList(new TimeSeriesOperand(new PartialPath(d1s2Path))),
+                    new TimeSeriesOperand(new PartialPath(groupedPathS2)))),
+            null,
+            OrderBy.TIMESTAMP_ASC);
+    Analysis analysis = Util.constructAnalysis();
+    MPPQueryContext context =
+        new MPPQueryContext("", queryId, null, new TEndPoint(), new TEndPoint());
+    DistributionPlanner planner =
+        new DistributionPlanner(analysis, new LogicalQueryPlan(context, groupByLevelNode));
+    DistributedQueryPlan plan = planner.planFragments();
+    assertEquals(3, plan.getInstances().size());
+    Map<String, AggregationStep> expectedStep = new HashMap<>();
+    expectedStep.put(d1s1Path, AggregationStep.PARTIAL);
+    expectedStep.put(d1s2Path, AggregationStep.PARTIAL);
+    expectedStep.put(d2s1Path, AggregationStep.PARTIAL);
+    List<FragmentInstance> fragmentInstances = plan.getInstances();
+    fragmentInstances.forEach(f -> verifyAggregationStep(expectedStep, f.getFragment().getRoot()));
+
+    Map<String, List<String>> expectedDescriptorValue = new HashMap<>();
+    expectedDescriptorValue.put(groupedPathS1, Arrays.asList(groupedPathS1, d1s1Path));
+    expectedDescriptorValue.put(groupedPathS2, Arrays.asList(groupedPathS2, d1s2Path));
+    verifyGroupByLevelDescriptor(
+        expectedDescriptorValue,
+        (GroupByLevelNode) fragmentInstances.get(0).getFragment().getRoot().getChildren().get(0));
+
+    Map<String, List<String>> expectedDescriptorValue2 = new HashMap<>();
+    expectedDescriptorValue2.put(groupedPathS1, Collections.singletonList(d2s1Path));
+    verifyGroupByLevelDescriptor(
+        expectedDescriptorValue2,
+        (GroupByLevelNode) fragmentInstances.get(1).getFragment().getRoot().getChildren().get(0));
+
+    Map<String, List<String>> expectedDescriptorValue3 = new HashMap<>();
+    expectedDescriptorValue3.put(groupedPathS1, Collections.singletonList(d1s1Path));
+    expectedDescriptorValue3.put(groupedPathS2, Collections.singletonList(d1s2Path));
+    verifyGroupByLevelDescriptor(
+        expectedDescriptorValue3,
+        (GroupByLevelNode) fragmentInstances.get(2).getFragment().getRoot().getChildren().get(0));
+
+    verifySlidingWindowDescriptor(
+        Arrays.asList(d1s1Path, d1s2Path),
+        (SlidingWindowAggregationNode)
+            fragmentInstances
+                .get(0)
+                .getFragment()
+                .getRoot()
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0));
+    verifySlidingWindowDescriptor(
+        Collections.singletonList(d2s1Path),
+        (SlidingWindowAggregationNode)
+            fragmentInstances
+                .get(1)
+                .getFragment()
+                .getRoot()
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0));
+    verifySlidingWindowDescriptor(
+        Arrays.asList(d1s1Path, d1s2Path),
+        (SlidingWindowAggregationNode)
+            fragmentInstances
+                .get(2)
+                .getFragment()
+                .getRoot()
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0));
+  }
+
   @Test
   public void testAggregation1Series1Region() throws IllegalPathException {
     QueryId queryId = new QueryId("test_aggregation_1_series_1_region");
@@ -378,6 +626,35 @@ public class AggregationDistributionTest {
     }
   }
 
+  private void verifySlidingWindowDescriptor(
+      List<String> expected, SlidingWindowAggregationNode node) {
+    List<AggregationDescriptor> descriptorList = node.getAggregationDescriptorList();
+    assertEquals(expected.size(), descriptorList.size());
+    Map<String, Integer> verification = new HashMap<>();
+    descriptorList.forEach(
+        d -> verification.put(d.getInputExpressions().get(0).getExpressionString(), 1));
+    assertEquals(expected.size(), verification.size());
+    expected.forEach(v -> assertEquals(1, (int) verification.get(v)));
+  }
+
+  private SlidingWindowAggregationNode genSlidingWindowAggregationNode(
+      QueryId queryId,
+      List<PartialPath> paths,
+      AggregationType type,
+      AggregationStep step,
+      GroupByTimeParameter groupByTimeParameter) {
+    return new SlidingWindowAggregationNode(
+        queryId.genPlanNodeId(),
+        paths.stream()
+            .map(
+                path ->
+                    new AggregationDescriptor(
+                        type, step, Collections.singletonList(new TimeSeriesOperand(path))))
+            .collect(Collectors.toList()),
+        groupByTimeParameter,
+        OrderBy.TIMESTAMP_ASC);
+  }
+
   private SeriesAggregationSourceNode genAggregationSourceNode(
       QueryId queryId, String path, AggregationType type) throws IllegalPathException {
     List<AggregationDescriptor> descriptors = new ArrayList<>();