You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text rendering.)
Posted to commits@iotdb.apache.org by xi...@apache.org on 2022/05/22 09:09:32 UTC

[iotdb] branch xingtanzjr/agg_distribution_plan updated: add more tests for GroupByLevelNode distribution plan

This is an automated email from the ASF dual-hosted git repository.

xingtanzjr pushed a commit to branch xingtanzjr/agg_distribution_plan
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/xingtanzjr/agg_distribution_plan by this push:
     new b262189ad1 add more tests for GroupByLevelNode distribution plan
b262189ad1 is described below

commit b262189ad1fff2b929e51d0b6b282e84a9b21aaa
Author: Jinrui.Zhang <xi...@gmail.com>
AuthorDate: Sun May 22 17:09:20 2022 +0800

    add more tests for GroupByLevelNode distribution plan
---
 .../db/mpp/plan/planner/DistributionPlanner.java   |  38 +++---
 .../db/mpp/plan/planner/LocalExecutionPlanner.java |   1 -
 .../plan/parameter/GroupByLevelDescriptor.java     |   2 +-
 .../plan/analyze/AggregationDescriptorTest.java    |   4 +-
 .../db/mpp/plan/plan/DistributionPlannerTest.java  | 135 ++++++++++++++++++++-
 5 files changed, 152 insertions(+), 28 deletions(-)

diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/DistributionPlanner.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/DistributionPlanner.java
index 13182a6d40..065110d231 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/DistributionPlanner.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/DistributionPlanner.java
@@ -19,12 +19,12 @@
 package org.apache.iotdb.db.mpp.plan.planner;
 
 import org.apache.iotdb.common.rpc.thrift.TRegionReplicaSet;
-import org.apache.iotdb.commons.exception.IllegalPathException;
 import org.apache.iotdb.commons.path.PartialPath;
 import org.apache.iotdb.db.mpp.common.MPPQueryContext;
 import org.apache.iotdb.db.mpp.common.PlanFragmentId;
 import org.apache.iotdb.db.mpp.plan.analyze.Analysis;
 import org.apache.iotdb.db.mpp.plan.analyze.QueryType;
+import org.apache.iotdb.db.mpp.plan.expression.Expression;
 import org.apache.iotdb.db.mpp.plan.planner.plan.DistributedQueryPlan;
 import org.apache.iotdb.db.mpp.plan.planner.plan.FragmentInstance;
 import org.apache.iotdb.db.mpp.plan.planner.plan.LogicalQueryPlan;
@@ -59,8 +59,6 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.AggregationDescriptor
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.AggregationStep;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.GroupByLevelDescriptor;
 import org.apache.iotdb.db.mpp.plan.statement.crud.QueryStatement;
-import org.apache.iotdb.db.query.expression.Expression;
-import org.apache.iotdb.db.query.expression.leaf.TimeSeriesOperand;
 
 import java.util.ArrayList;
 import java.util.Collections;
@@ -510,32 +508,25 @@ public class DistributionPlanner {
       }
       GroupByLevelNode handle = (GroupByLevelNode) node;
 
-      // Construct all outputColumns from children
-      List<String> childrenOutputColumns = new ArrayList<>();
+      // Collect the output column names of all children (a Set, so duplicates are dropped).
+      Set<String> childrenOutputColumns = new HashSet<>(); // no defined iteration order
       handle
           .getChildren()
           .forEach(child -> childrenOutputColumns.addAll(child.getOutputColumnNames()));
 
       // Check every OutputColumn of GroupByLevelNode and set the Expression of corresponding
       // AggregationDescriptor
-      List<String> outputColumnList = new ArrayList<>();
       List<GroupByLevelDescriptor> descriptorList = new ArrayList<>();
-      for (int i = 0; i < handle.getOutputColumnNames().size(); i++) {
-        String column = handle.getOutputColumnNames().get(i);
-        Set<Expression> originalExpressions =
-            analysis.getGroupByLevelExpressions().getOrDefault(column, new HashSet<>());
+      for (GroupByLevelDescriptor originalDescriptor : handle.getGroupByLevelDescriptors()) {
         List<Expression> descriptorExpression = new ArrayList<>();
         for (String childColumn : childrenOutputColumns) {
-          if (childColumn.equals(column)) {
-            try {
-              descriptorExpression.add(new TimeSeriesOperand(new PartialPath(childColumn)));
-            } catch (IllegalPathException e) {
-              throw new RuntimeException("error when plan distribution aggregation query", e);
-            }
+          // Matching the output expression means childColumn comes from a child GroupByLevelNode
+          if (isAggColumnMatchExpression(childColumn, originalDescriptor.getOutputExpression())) {
+            descriptorExpression.add(originalDescriptor.getOutputExpression());
             continue;
           }
-          for (Expression exp : originalExpressions) {
-            if (exp.getExpressionString().equals(childColumn)) {
+          for (Expression exp : originalDescriptor.getInputExpressions()) {
+            if (isAggColumnMatchExpression(childColumn, exp)) {
               descriptorExpression.add(exp);
             }
           }
@@ -543,16 +534,56 @@ public class DistributionPlanner {
         if (descriptorExpression.size() == 0) {
           continue;
         }
-        GroupByLevelDescriptor descriptor = handle.getGroupByLevelDescriptors().get(i).deepClone();
+        GroupByLevelDescriptor descriptor = originalDescriptor.deepClone();
         descriptor.setStep(level == 0 ? AggregationStep.FINAL : AggregationStep.PARTIAL);
         descriptor.setInputExpressions(descriptorExpression);
 
-        outputColumnList.add(column);
         descriptorList.add(descriptor);
       }
       handle.setGroupByLevelDescriptors(descriptorList);
     }
 
+    /**
+     * Checks whether an output column name of a child node refers to the given expression.
+     *
+     * <p>The expression string must occur in the column name as a whole token: for example, the
+     * column {@code count(root.sg.d1.s1)} matches {@code root.sg.d1.s1}, whereas
+     * {@code count(root.sg.d1.s10)} does not. A plain substring check would accept such prefix
+     * collisions and attach the wrong series to the descriptor.
+     *
+     * <p>TODO: (xingtanzjr) need to confirm the logic when processing UDF
+     *
+     * @param columnName an output column name of a child node; may be null
+     * @param expression the expression to look for
+     * @return true if the expression string occurs in {@code columnName} as a complete token
+     */
+    private boolean isAggColumnMatchExpression(String columnName, Expression expression) {
+      if (columnName == null) {
+        return false;
+      }
+      String target = expression.getExpressionString();
+      if (target.isEmpty()) {
+        // An empty string would "match" everywhere and cannot identify a column.
+        return false;
+      }
+      int from = columnName.indexOf(target);
+      while (from >= 0) {
+        int end = from + target.length();
+        boolean boundedBefore = from == 0 || !isPathChar(columnName.charAt(from - 1));
+        boolean boundedAfter = end == columnName.length() || !isPathChar(columnName.charAt(end));
+        if (boundedBefore && boundedAfter) {
+          return true;
+        }
+        from = columnName.indexOf(target, from + 1);
+      }
+      return false;
+    }
+
+    /** Whether {@code c} can be part of a path or identifier (approximation of the grammar). */
+    private boolean isPathChar(char c) {
+      return Character.isLetterOrDigit(c) || c == '_' || c == '.' || c == '*' || c == '`';
+    }
+
     private List<SeriesAggregationSourceNode> splitAggregationSourceByPartition(
         MultiChildNode root, DistributionPlanContext context) {
       // Step 1: split SeriesAggregationSourceNode according to data partition
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LocalExecutionPlanner.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LocalExecutionPlanner.java
index 66a12f3cc3..d071fd1faf 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LocalExecutionPlanner.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LocalExecutionPlanner.java
@@ -93,7 +93,6 @@ import org.apache.iotdb.db.mpp.execution.operator.source.ExchangeOperator;
 import org.apache.iotdb.db.mpp.execution.operator.source.SeriesAggregationScanOperator;
 import org.apache.iotdb.db.mpp.execution.operator.source.SeriesScanOperator;
 import org.apache.iotdb.db.mpp.plan.analyze.TypeProvider;
-import org.apache.iotdb.db.mpp.plan.expression.Expression;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanVisitor;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.ChildNodesSchemaScanNode;
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/GroupByLevelDescriptor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/GroupByLevelDescriptor.java
index ccb2bf21e1..2d3dc2de60 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/GroupByLevelDescriptor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/GroupByLevelDescriptor.java
@@ -19,8 +19,8 @@
 
 package org.apache.iotdb.db.mpp.plan.planner.plan.parameter;
 
-import org.apache.iotdb.db.query.aggregation.AggregationType;
 import org.apache.iotdb.db.mpp.plan.expression.Expression;
+import org.apache.iotdb.db.query.aggregation.AggregationType;
 
 import java.nio.ByteBuffer;
 import java.util.List;
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/plan/analyze/AggregationDescriptorTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/plan/analyze/AggregationDescriptorTest.java
index e542d5c568..fb7b72515b 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/plan/analyze/AggregationDescriptorTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/plan/analyze/AggregationDescriptorTest.java
@@ -22,12 +22,12 @@ package org.apache.iotdb.db.mpp.plan.analyze;
 import org.apache.iotdb.commons.exception.IllegalPathException;
 import org.apache.iotdb.commons.path.PartialPath;
 import org.apache.iotdb.db.metadata.path.MeasurementPath;
+import org.apache.iotdb.db.mpp.plan.expression.Expression;
+import org.apache.iotdb.db.mpp.plan.expression.leaf.TimeSeriesOperand;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.AggregationDescriptor;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.AggregationStep;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.GroupByLevelDescriptor;
 import org.apache.iotdb.db.query.aggregation.AggregationType;
-import org.apache.iotdb.db.query.expression.Expression;
-import org.apache.iotdb.db.query.expression.leaf.TimeSeriesOperand;
 import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
 
 import org.junit.Assert;
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/DistributionPlannerTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/DistributionPlannerTest.java
index b79c5d2f56..ba1e0df5fb 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/DistributionPlannerTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/DistributionPlannerTest.java
@@ -38,6 +38,8 @@ import org.apache.iotdb.db.mpp.common.MPPQueryContext;
 import org.apache.iotdb.db.mpp.common.QueryId;
 import org.apache.iotdb.db.mpp.plan.analyze.Analysis;
 import org.apache.iotdb.db.mpp.plan.analyze.QueryType;
+import org.apache.iotdb.db.mpp.plan.expression.Expression;
+import org.apache.iotdb.db.mpp.plan.expression.leaf.TimeSeriesOperand;
 import org.apache.iotdb.db.mpp.plan.planner.DistributionPlanner;
 import org.apache.iotdb.db.mpp.plan.planner.plan.DistributedQueryPlan;
 import org.apache.iotdb.db.mpp.plan.planner.plan.FragmentInstance;
@@ -62,8 +64,6 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.AggregationStep;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.GroupByLevelDescriptor;
 import org.apache.iotdb.db.mpp.plan.statement.component.OrderBy;
 import org.apache.iotdb.db.query.aggregation.AggregationType;
-import org.apache.iotdb.db.query.expression.Expression;
-import org.apache.iotdb.db.query.expression.leaf.TimeSeriesOperand;
 import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
 import org.apache.iotdb.tsfile.write.schema.MeasurementSchema;
 
@@ -79,6 +79,7 @@ import java.util.Map;
 import java.util.Set;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
 
 public class DistributionPlannerTest {
 
@@ -451,12 +452,164 @@ public class DistributionPlannerTest {
     DistributionPlanner planner =
         new DistributionPlanner(analysis, new LogicalQueryPlan(context, groupByLevelNode));
     DistributedQueryPlan plan = planner.planFragments();
-    assertEquals(3, plan.getInstances().size());
+    assertEquals(2, plan.getInstances().size());
     Map<String, AggregationStep> expectedStep = new HashMap<>();
     expectedStep.put(d3s1Path, AggregationStep.PARTIAL);
-    expectedStep.put(d4s1Path, AggregationStep.FINAL);
+    expectedStep.put(d4s1Path, AggregationStep.PARTIAL);
+    List<FragmentInstance> fragmentInstances = plan.getInstances();
+    fragmentInstances.forEach(f -> verifyAggregationStep(expectedStep, f.getFragment().getRoot()));
+
+    Map<String, List<String>> expectedDescriptorValue = new HashMap<>();
+    expectedDescriptorValue.put(groupedPath, Arrays.asList(groupedPath, d3s1Path, d4s1Path));
+    verifyGroupByLevelDescriptor(
+        expectedDescriptorValue, getGroupByLevelNode(fragmentInstances.get(0)));
+
+    Map<String, List<String>> expectedDescriptorValue2 = new HashMap<>();
+    expectedDescriptorValue2.put(groupedPath, Arrays.asList(d3s1Path, d4s1Path));
+    verifyGroupByLevelDescriptor(
+        expectedDescriptorValue2, getGroupByLevelNode(fragmentInstances.get(1)));
+  }
+
+  @Test
+  public void testGroupByLevelTwoSeries() throws IllegalPathException {
+    QueryId queryId = new QueryId("test_group_by_level_two_series");
+    String d1s1Path = "root.sg.d1.s1";
+    String d1s2Path = "root.sg.d1.s2";
+    String groupedPathS1 = "root.sg.*.s1";
+    String groupedPathS2 = "root.sg.*.s2";
+
+    GroupByLevelNode groupByLevelNode =
+        new GroupByLevelNode(
+            new PlanNodeId("TestGroupByLevelNode"),
+            Arrays.asList(
+                genAggregationSourceNode(queryId, d1s1Path, AggregationType.COUNT),
+                genAggregationSourceNode(queryId, d1s2Path, AggregationType.COUNT)),
+            Arrays.asList(
+                new GroupByLevelDescriptor(
+                    AggregationType.COUNT,
+                    AggregationStep.FINAL,
+                    Collections.singletonList(new TimeSeriesOperand(new PartialPath(d1s1Path))),
+                    new TimeSeriesOperand(new PartialPath(groupedPathS1))),
+                new GroupByLevelDescriptor(
+                    AggregationType.COUNT,
+                    AggregationStep.FINAL,
+                    Collections.singletonList(new TimeSeriesOperand(new PartialPath(d1s2Path))),
+                    new TimeSeriesOperand(new PartialPath(groupedPathS2)))));
+    Analysis analysis = constructAnalysis();
+    MPPQueryContext context =
+        new MPPQueryContext("", queryId, null, new TEndPoint(), new TEndPoint());
+    DistributionPlanner planner =
+        new DistributionPlanner(analysis, new LogicalQueryPlan(context, groupByLevelNode));
+    DistributedQueryPlan plan = planner.planFragments();
+    assertEquals(2, plan.getInstances().size());
+    Map<String, AggregationStep> expectedStep = new HashMap<>();
+    expectedStep.put(d1s1Path, AggregationStep.PARTIAL);
+    expectedStep.put(d1s2Path, AggregationStep.PARTIAL);
+    List<FragmentInstance> fragmentInstances = plan.getInstances();
+    fragmentInstances.forEach(f -> verifyAggregationStep(expectedStep, f.getFragment().getRoot()));
+
+    Map<String, List<String>> expectedDescriptorValue = new HashMap<>();
+    expectedDescriptorValue.put(groupedPathS1, Arrays.asList(groupedPathS1, d1s1Path));
+    expectedDescriptorValue.put(groupedPathS2, Arrays.asList(groupedPathS2, d1s2Path));
+    verifyGroupByLevelDescriptor(
+        expectedDescriptorValue, getGroupByLevelNode(fragmentInstances.get(0)));
+
+    Map<String, List<String>> expectedDescriptorValue2 = new HashMap<>();
+    expectedDescriptorValue2.put(groupedPathS1, Collections.singletonList(d1s1Path));
+    expectedDescriptorValue2.put(groupedPathS2, Collections.singletonList(d1s2Path));
+    verifyGroupByLevelDescriptor(
+        expectedDescriptorValue2, getGroupByLevelNode(fragmentInstances.get(1)));
+  }
+
+  @Test
+  public void testGroupByLevel2Series2Devices3Regions() throws IllegalPathException {
+    QueryId queryId = new QueryId("test_group_by_level_2_series_2_devices_3_regions");
+    String d1s1Path = "root.sg.d1.s1";
+    String d1s2Path = "root.sg.d1.s2";
+    String d2s1Path = "root.sg.d22.s1";
+    String groupedPathS1 = "root.sg.*.s1";
+    String groupedPathS2 = "root.sg.*.s2";
+
+    GroupByLevelNode groupByLevelNode =
+        new GroupByLevelNode(
+            new PlanNodeId("TestGroupByLevelNode"),
+            Arrays.asList(
+                genAggregationSourceNode(queryId, d1s1Path, AggregationType.COUNT),
+                genAggregationSourceNode(queryId, d1s2Path, AggregationType.COUNT),
+                genAggregationSourceNode(queryId, d2s1Path, AggregationType.COUNT)),
+            Arrays.asList(
+                new GroupByLevelDescriptor(
+                    AggregationType.COUNT,
+                    AggregationStep.FINAL,
+                    Arrays.asList(
+                        new TimeSeriesOperand(new PartialPath(d1s1Path)),
+                        new TimeSeriesOperand(new PartialPath(d2s1Path))),
+                    new TimeSeriesOperand(new PartialPath(groupedPathS1))),
+                new GroupByLevelDescriptor(
+                    AggregationType.COUNT,
+                    AggregationStep.FINAL,
+                    Collections.singletonList(new TimeSeriesOperand(new PartialPath(d1s2Path))),
+                    new TimeSeriesOperand(new PartialPath(groupedPathS2)))));
+    Analysis analysis = constructAnalysis();
+    MPPQueryContext context =
+        new MPPQueryContext("", queryId, null, new TEndPoint(), new TEndPoint());
+    DistributionPlanner planner =
+        new DistributionPlanner(analysis, new LogicalQueryPlan(context, groupByLevelNode));
+    DistributedQueryPlan plan = planner.planFragments();
+    assertEquals(3, plan.getInstances().size());
+    Map<String, AggregationStep> expectedStep = new HashMap<>();
+    expectedStep.put(d1s1Path, AggregationStep.PARTIAL);
+    expectedStep.put(d1s2Path, AggregationStep.PARTIAL);
+    expectedStep.put(d2s1Path, AggregationStep.FINAL);
     List<FragmentInstance> fragmentInstances = plan.getInstances();
     fragmentInstances.forEach(f -> verifyAggregationStep(expectedStep, f.getFragment().getRoot()));
+
+    Map<String, List<String>> expectedDescriptorValue = new HashMap<>();
+    expectedDescriptorValue.put(groupedPathS1, Arrays.asList(groupedPathS1, d1s1Path, d2s1Path));
+    expectedDescriptorValue.put(groupedPathS2, Arrays.asList(groupedPathS2, d1s2Path));
+    verifyGroupByLevelDescriptor(
+        expectedDescriptorValue, getGroupByLevelNode(fragmentInstances.get(0)));
+
+    Map<String, List<String>> expectedDescriptorValue2 = new HashMap<>();
+    expectedDescriptorValue2.put(groupedPathS1, Collections.singletonList(d1s1Path));
+    expectedDescriptorValue2.put(groupedPathS2, Collections.singletonList(d1s2Path));
+    verifyGroupByLevelDescriptor(
+        expectedDescriptorValue2, getGroupByLevelNode(fragmentInstances.get(2)));
+  }
+
+  /** Returns the first child of the fragment's root, cast to GroupByLevelNode. */
+  private GroupByLevelNode getGroupByLevelNode(FragmentInstance instance) {
+    return (GroupByLevelNode) instance.getFragment().getRoot().getChildren().get(0);
+  }
+
+  /**
+   * Asserts that {@code node} has one descriptor per key of {@code expected} and that each
+   * descriptor's input expressions equal the expected expression strings, ignoring order.
+   *
+   * @param expected output expression string mapped to its expected input expression strings
+   * @param node the GroupByLevelNode to check
+   */
+  private void verifyGroupByLevelDescriptor(
+      Map<String, List<String>> expected, GroupByLevelNode node) {
+    List<GroupByLevelDescriptor> descriptors = node.getGroupByLevelDescriptors();
+    assertEquals(expected.size(), descriptors.size());
+    for (GroupByLevelDescriptor descriptor : descriptors) {
+      String outputExpression = descriptor.getOutputExpression().getExpressionString();
+      List<String> expectedInputs = expected.get(outputExpression);
+      // Report an unexpected descriptor as an assertion failure rather than an NPE below.
+      assertTrue("unexpected output expression " + outputExpression, expectedInputs != null);
+      // Compare sorted copies so that order is ignored but duplicates and omissions are caught.
+      String[] actualSorted =
+          descriptor.getInputExpressions().stream()
+              .map(Expression::getExpressionString)
+              .sorted()
+              .toArray(String[]::new);
+      String[] expectedSorted = expectedInputs.stream().sorted().toArray(String[]::new);
+      assertEquals(
+          "inputs of " + outputExpression,
+          Arrays.asList(expectedSorted),
+          Arrays.asList(actualSorted));
+    }
   }
 
   private SeriesAggregationSourceNode genAggregationSourceNode(