You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@druid.apache.org by fj...@apache.org on 2018/08/26 23:15:30 UTC

[incubator-druid] branch 0.12.3 updated: Support projection after sorting in SQL (#5788) (#6228)

This is an automated email from the ASF dual-hosted git repository.

fjy pushed a commit to branch 0.12.3
in repository https://gitbox.apache.org/repos/asf/incubator-druid.git


The following commit(s) were added to refs/heads/0.12.3 by this push:
     new bc07320  Support projection after sorting in SQL (#5788) (#6228)
bc07320 is described below

commit bc07320ae73a0ba03512fb4b647bf98dfdce6aef
Author: Gian Merlino <gi...@gmail.com>
AuthorDate: Sun Aug 26 16:15:27 2018 -0700

    Support projection after sorting in SQL (#5788) (#6228)
    
    * Add sort project
    
    * Add more tests
    
    * Address review comments
---
 .../druid/sql/calcite/aggregation/Aggregation.java |   3 +-
 .../java/io/druid/sql/calcite/rel/DruidQuery.java  | 190 +++++++++++++++------
 .../io/druid/sql/calcite/rel/DruidQueryRel.java    |   8 +-
 .../io/druid/sql/calcite/rel/DruidSemiJoin.java    |   8 +-
 .../druid/sql/calcite/rel/PartialDruidQuery.java   | 120 +++++++++----
 .../java/io/druid/sql/calcite/rel/SortProject.java | 112 ++++++++++++
 .../java/io/druid/sql/calcite/rule/DruidRules.java |  38 ++++-
 .../druid/sql/calcite/rule/DruidSemiJoinRule.java  |  10 +-
 .../io/druid/sql/calcite/CalciteQueryTest.java     | 189 ++++++++++++++++++++
 9 files changed, 581 insertions(+), 97 deletions(-)

diff --git a/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregation.java b/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregation.java
index 2532c8d..09436b9 100644
--- a/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregation.java
+++ b/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregation.java
@@ -36,6 +36,7 @@ import io.druid.sql.calcite.filtration.Filtration;
 import io.druid.sql.calcite.table.RowSignature;
 
 import javax.annotation.Nullable;
+import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 import java.util.Set;
@@ -112,7 +113,7 @@ public class Aggregation
 
   public static Aggregation create(final PostAggregator postAggregator)
   {
-    return new Aggregation(ImmutableList.of(), ImmutableList.of(), postAggregator);
+    return new Aggregation(Collections.emptyList(), Collections.emptyList(), postAggregator);
   }
 
   public static Aggregation create(
diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java
index 9740f68..2f6fde5 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java
@@ -89,6 +89,7 @@ import javax.annotation.Nullable;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.OptionalInt;
 import java.util.TreeSet;
 import java.util.stream.Collectors;
 
@@ -105,9 +106,11 @@ public class DruidQuery
   private final DimFilter filter;
   private final SelectProjection selectProjection;
   private final Grouping grouping;
+  private final SortProject sortProject;
+  private final DefaultLimitSpec limitSpec;
   private final RowSignature outputRowSignature;
   private final RelDataType outputRowType;
-  private final DefaultLimitSpec limitSpec;
+
   private final Query query;
 
   public DruidQuery(
@@ -129,15 +132,22 @@ public class DruidQuery
     this.selectProjection = computeSelectProjection(partialQuery, plannerContext, sourceRowSignature);
     this.grouping = computeGrouping(partialQuery, plannerContext, sourceRowSignature, rexBuilder, finalizeAggregations);
 
+    final RowSignature sortingInputRowSignature;
+
     if (this.selectProjection != null) {
-      this.outputRowSignature = this.selectProjection.getOutputRowSignature();
+      sortingInputRowSignature = this.selectProjection.getOutputRowSignature();
     } else if (this.grouping != null) {
-      this.outputRowSignature = this.grouping.getOutputRowSignature();
+      sortingInputRowSignature = this.grouping.getOutputRowSignature();
     } else {
-      this.outputRowSignature = sourceRowSignature;
+      sortingInputRowSignature = sourceRowSignature;
     }
 
-    this.limitSpec = computeLimitSpec(partialQuery, this.outputRowSignature);
+    this.sortProject = computeSortProject(partialQuery, plannerContext, sortingInputRowSignature, grouping);
+
+    // outputRowSignature is used only for scan and select query, and thus sort and grouping must be null
+    this.outputRowSignature = sortProject == null ? sortingInputRowSignature : sortProject.getOutputRowSignature();
+
+    this.limitSpec = computeLimitSpec(partialQuery, sortingInputRowSignature);
     this.query = computeQuery();
   }
 
@@ -237,7 +247,7 @@ public class DruidQuery
   )
   {
     final Aggregate aggregate = partialQuery.getAggregate();
-    final Project postProject = partialQuery.getPostProject();
+    final Project aggregateProject = partialQuery.getAggregateProject();
 
     if (aggregate == null) {
       return null;
@@ -268,49 +278,27 @@ public class DruidQuery
         plannerContext
     );
 
-    if (postProject == null) {
+    if (aggregateProject == null) {
       return Grouping.create(dimensions, aggregations, havingFilter, aggregateRowSignature);
     } else {
-      final List<String> rowOrder = new ArrayList<>();
-
-      int outputNameCounter = 0;
-      for (final RexNode postAggregatorRexNode : postProject.getChildExps()) {
-        // Attempt to convert to PostAggregator.
-        final DruidExpression postAggregatorExpression = Expressions.toDruidExpression(
-            plannerContext,
-            aggregateRowSignature,
-            postAggregatorRexNode
-        );
-
-        if (postAggregatorExpression == null) {
-          throw new CannotBuildQueryException(postProject, postAggregatorRexNode);
-        }
-
-        if (postAggregatorDirectColumnIsOk(aggregateRowSignature, postAggregatorExpression, postAggregatorRexNode)) {
-          // Direct column access, without any type cast as far as Druid's runtime is concerned.
-          // (There might be a SQL-level type cast that we don't care about)
-          rowOrder.add(postAggregatorExpression.getDirectColumn());
-        } else {
-          final String postAggregatorName = "p" + outputNameCounter++;
-          final PostAggregator postAggregator = new ExpressionPostAggregator(
-              postAggregatorName,
-              postAggregatorExpression.getExpression(),
-              null,
-              plannerContext.getExprMacroTable()
-          );
-          aggregations.add(Aggregation.create(postAggregator));
-          rowOrder.add(postAggregator.getName());
-        }
-      }
+      final ProjectRowOrderAndPostAggregations projectRowOrderAndPostAggregations = computePostAggregations(
+          plannerContext,
+          aggregateRowSignature,
+          aggregateProject,
+          0
+      );
+      projectRowOrderAndPostAggregations.postAggregations.forEach(
+          postAggregator -> aggregations.add(Aggregation.create(postAggregator))
+      );
 
       // Remove literal dimensions that did not appear in the projection. This is useful for queries
       // like "SELECT COUNT(*) FROM tbl GROUP BY 'dummy'" which some tools can generate, and for which we don't
       // actually want to include a dimension 'dummy'.
-      final ImmutableBitSet postProjectBits = RelOptUtil.InputFinder.bits(postProject.getChildExps(), null);
+      final ImmutableBitSet aggregateProjectBits = RelOptUtil.InputFinder.bits(aggregateProject.getChildExps(), null);
       for (int i = dimensions.size() - 1; i >= 0; i--) {
         final DimensionExpression dimension = dimensions.get(i);
         if (Parser.parse(dimension.getDruidExpression().getExpression(), plannerContext.getExprMacroTable())
-                  .isLiteral() && !postProjectBits.get(i)) {
+                  .isLiteral() && !aggregateProjectBits.get(i)) {
           dimensions.remove(i);
         }
       }
@@ -319,11 +307,98 @@ public class DruidQuery
           dimensions,
           aggregations,
           havingFilter,
-          RowSignature.from(rowOrder, postProject.getRowType())
+          RowSignature.from(projectRowOrderAndPostAggregations.rowOrder, aggregateProject.getRowType())
       );
     }
   }
 
+  @Nullable
+  private SortProject computeSortProject(
+      PartialDruidQuery partialQuery,
+      PlannerContext plannerContext,
+      RowSignature sortingInputRowSignature,
+      Grouping grouping
+  )
+  {
+    final Project sortProject = partialQuery.getSortProject();
+    if (sortProject == null) {
+      return null;
+    } else {
+      final List<PostAggregator> postAggregators = grouping.getPostAggregators();
+      final OptionalInt maybeMaxCounter = postAggregators
+          .stream()
+          .mapToInt(postAggregator -> Integer.parseInt(postAggregator.getName().substring(1)))
+          .max();
+
+      final ProjectRowOrderAndPostAggregations projectRowOrderAndPostAggregations = computePostAggregations(
+          plannerContext,
+          sortingInputRowSignature,
+          sortProject,
+          maybeMaxCounter.orElse(-1) + 1 // 0 if max doesn't exist
+      );
+
+      return new SortProject(
+          sortingInputRowSignature,
+          projectRowOrderAndPostAggregations.postAggregations,
+          RowSignature.from(projectRowOrderAndPostAggregations.rowOrder, sortProject.getRowType())
+      );
+    }
+  }
+
+  private static class ProjectRowOrderAndPostAggregations
+  {
+    private final List<String> rowOrder;
+    private final List<PostAggregator> postAggregations;
+
+    ProjectRowOrderAndPostAggregations(List<String> rowOrder, List<PostAggregator> postAggregations)
+    {
+      this.rowOrder = rowOrder;
+      this.postAggregations = postAggregations;
+    }
+  }
+
+  private static ProjectRowOrderAndPostAggregations computePostAggregations(
+      PlannerContext plannerContext,
+      RowSignature inputRowSignature,
+      Project project,
+      int outputNameCounter
+  )
+  {
+    final List<String> rowOrder = new ArrayList<>();
+    final List<PostAggregator> aggregations = new ArrayList<>();
+
+    for (final RexNode postAggregatorRexNode : project.getChildExps()) {
+      // Attempt to convert to PostAggregator.
+      final DruidExpression postAggregatorExpression = Expressions.toDruidExpression(
+          plannerContext,
+          inputRowSignature,
+          postAggregatorRexNode
+      );
+
+      if (postAggregatorExpression == null) {
+        throw new CannotBuildQueryException(project, postAggregatorRexNode);
+      }
+
+      if (postAggregatorDirectColumnIsOk(inputRowSignature, postAggregatorExpression, postAggregatorRexNode)) {
+        // Direct column access, without any type cast as far as Druid's runtime is concerned.
+        // (There might be a SQL-level type cast that we don't care about)
+        rowOrder.add(postAggregatorExpression.getDirectColumn());
+      } else {
+        final String postAggregatorName = "p" + outputNameCounter++;
+        final PostAggregator postAggregator = new ExpressionPostAggregator(
+            postAggregatorName,
+            postAggregatorExpression.getExpression(),
+            null,
+            plannerContext.getExprMacroTable()
+        );
+        aggregations.add(postAggregator);
+        rowOrder.add(postAggregator.getName());
+      }
+    }
+
+    return new ProjectRowOrderAndPostAggregations(rowOrder, aggregations);
+  }
+
   /**
    * Returns dimensions corresponding to {@code aggregate.getGroupSet()}, in the same order.
    *
@@ -548,18 +623,20 @@ public class DruidQuery
   {
     final List<VirtualColumn> retVal = new ArrayList<>();
 
-    if (grouping != null) {
-      if (includeDimensions) {
-        for (DimensionExpression dimensionExpression : grouping.getDimensions()) {
-          retVal.addAll(dimensionExpression.getVirtualColumns(macroTable));
+    if (selectProjection != null) {
+      retVal.addAll(selectProjection.getVirtualColumns());
+    } else {
+      if (grouping != null) {
+        if (includeDimensions) {
+          for (DimensionExpression dimensionExpression : grouping.getDimensions()) {
+            retVal.addAll(dimensionExpression.getVirtualColumns(macroTable));
+          }
         }
-      }
 
-      for (Aggregation aggregation : grouping.getAggregations()) {
-        retVal.addAll(aggregation.getVirtualColumns());
+        for (Aggregation aggregation : grouping.getAggregations()) {
+          retVal.addAll(aggregation.getVirtualColumns());
+        }
       }
-    } else if (selectProjection != null) {
-      retVal.addAll(selectProjection.getVirtualColumns());
     }
 
     return VirtualColumns.create(retVal);
@@ -575,6 +652,11 @@ public class DruidQuery
     return limitSpec;
   }
 
+  public SortProject getSortProject()
+  {
+    return sortProject;
+  }
+
   public RelDataType getOutputRowType()
   {
     return outputRowType;
@@ -675,7 +757,6 @@ public class DruidQuery
 
       if (limitSpec != null) {
         // If there is a limit spec, timeseries cannot LIMIT; and must be ORDER BY time (or nothing).
-
         if (limitSpec.isLimited()) {
           return null;
         }
@@ -805,6 +886,11 @@ public class DruidQuery
 
     final Filtration filtration = Filtration.create(filter).optimize(sourceRowSignature);
 
+    final List<PostAggregator> postAggregators = new ArrayList<>(grouping.getPostAggregators());
+    if (sortProject != null) {
+      postAggregators.addAll(sortProject.getPostAggregators());
+    }
+
     return new GroupByQuery(
         dataSource,
         filtration.getQuerySegmentSpec(),
@@ -813,7 +899,7 @@ public class DruidQuery
         Granularities.ALL,
         grouping.getDimensionSpecs(),
         grouping.getAggregatorFactories(),
-        grouping.getPostAggregators(),
+        postAggregators,
         grouping.getHavingFilter() != null ? new DimFilterHavingSpec(grouping.getHavingFilter(), true) : null,
         limitSpec,
         ImmutableSortedMap.copyOf(plannerContext.getQueryContext())
diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java
index c304e5b..a62096a 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java
@@ -225,14 +225,18 @@ public class DruidQueryRel extends DruidRel<DruidQueryRel>
       cost += COST_PER_COLUMN * partialQuery.getAggregate().getAggCallList().size();
     }
 
-    if (partialQuery.getPostProject() != null) {
-      cost += COST_PER_COLUMN * partialQuery.getPostProject().getChildExps().size();
+    if (partialQuery.getAggregateProject() != null) {
+      cost += COST_PER_COLUMN * partialQuery.getAggregateProject().getChildExps().size();
     }
 
     if (partialQuery.getSort() != null && partialQuery.getSort().fetch != null) {
       cost *= COST_LIMIT_MULTIPLIER;
     }
 
+    if (partialQuery.getSortProject() != null) {
+      cost += COST_PER_COLUMN * partialQuery.getSortProject().getChildExps().size();
+    }
+
     if (partialQuery.getHavingFilter() != null) {
       cost *= COST_HAVING_MULTIPLIER;
     }
diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java b/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java
index ecfd8bb..7c0d8b6 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java
@@ -358,8 +358,12 @@ public class DruidSemiJoin extends DruidRel<DruidSemiJoin>
         newPartialQuery = newPartialQuery.withHavingFilter(leftPartialQuery.getHavingFilter());
       }
 
-      if (leftPartialQuery.getPostProject() != null) {
-        newPartialQuery = newPartialQuery.withPostProject(leftPartialQuery.getPostProject());
+      if (leftPartialQuery.getAggregateProject() != null) {
+        newPartialQuery = newPartialQuery.withAggregateProject(leftPartialQuery.getAggregateProject());
+      }
+
+      if (leftPartialQuery.getSortProject() != null) {
+        newPartialQuery = newPartialQuery.withSortProject(leftPartialQuery.getSortProject());
       }
 
       if (leftPartialQuery.getSort() != null) {
diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/PartialDruidQuery.java b/sql/src/main/java/io/druid/sql/calcite/rel/PartialDruidQuery.java
index 01c960c..d7d0f77 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rel/PartialDruidQuery.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rel/PartialDruidQuery.java
@@ -46,8 +46,9 @@ public class PartialDruidQuery
   private final Sort selectSort;
   private final Aggregate aggregate;
   private final Filter havingFilter;
-  private final Project postProject;
+  private final Project aggregateProject;
   private final Sort sort;
+  private final Project sortProject;
 
   public enum Stage
   {
@@ -57,8 +58,9 @@ public class PartialDruidQuery
     SELECT_SORT,
     AGGREGATE,
     HAVING_FILTER,
-    POST_PROJECT,
-    SORT
+    AGGREGATE_PROJECT,
+    SORT,
+    SORT_PROJECT
   }
 
   public PartialDruidQuery(
@@ -67,9 +69,10 @@ public class PartialDruidQuery
       final Project selectProject,
       final Sort selectSort,
       final Aggregate aggregate,
-      final Project postProject,
+      final Project aggregateProject,
       final Filter havingFilter,
-      final Sort sort
+      final Sort sort,
+      final Project sortProject
   )
   {
     this.scan = Preconditions.checkNotNull(scan, "scan");
@@ -77,14 +80,15 @@ public class PartialDruidQuery
     this.selectProject = selectProject;
     this.selectSort = selectSort;
     this.aggregate = aggregate;
-    this.postProject = postProject;
+    this.aggregateProject = aggregateProject;
     this.havingFilter = havingFilter;
     this.sort = sort;
+    this.sortProject = sortProject;
   }
 
   public static PartialDruidQuery create(final RelNode scanRel)
   {
-    return new PartialDruidQuery(scanRel, null, null, null, null, null, null, null);
+    return new PartialDruidQuery(scanRel, null, null, null, null, null, null, null, null);
   }
 
   public RelNode getScan()
@@ -117,9 +121,9 @@ public class PartialDruidQuery
     return havingFilter;
   }
 
-  public Project getPostProject()
+  public Project getAggregateProject()
   {
-    return postProject;
+    return aggregateProject;
   }
 
   public Sort getSort()
@@ -127,6 +131,11 @@ public class PartialDruidQuery
     return sort;
   }
 
+  public Project getSortProject()
+  {
+    return sortProject;
+  }
+
   public PartialDruidQuery withWhereFilter(final Filter newWhereFilter)
   {
     validateStage(Stage.WHERE_FILTER);
@@ -136,9 +145,10 @@ public class PartialDruidQuery
         selectProject,
         selectSort,
         aggregate,
-        postProject,
+        aggregateProject,
         havingFilter,
-        sort
+        sort,
+        sortProject
     );
   }
 
@@ -151,9 +161,10 @@ public class PartialDruidQuery
         newSelectProject,
         selectSort,
         aggregate,
-        postProject,
+        aggregateProject,
         havingFilter,
-        sort
+        sort,
+        sortProject
     );
   }
 
@@ -166,9 +177,10 @@ public class PartialDruidQuery
         selectProject,
         newSelectSort,
         aggregate,
-        postProject,
+        aggregateProject,
         havingFilter,
-        sort
+        sort,
+        sortProject
     );
   }
 
@@ -181,9 +193,10 @@ public class PartialDruidQuery
         selectProject,
         selectSort,
         newAggregate,
-        postProject,
+        aggregateProject,
         havingFilter,
-        sort
+        sort,
+        sortProject
     );
   }
 
@@ -196,24 +209,26 @@ public class PartialDruidQuery
         selectProject,
         selectSort,
         aggregate,
-        postProject,
+        aggregateProject,
         newHavingFilter,
-        sort
+        sort,
+        sortProject
     );
   }
 
-  public PartialDruidQuery withPostProject(final Project newPostProject)
+  public PartialDruidQuery withAggregateProject(final Project newAggregateProject)
   {
-    validateStage(Stage.POST_PROJECT);
+    validateStage(Stage.AGGREGATE_PROJECT);
     return new PartialDruidQuery(
         scan,
         whereFilter,
         selectProject,
         selectSort,
         aggregate,
-        newPostProject,
+        newAggregateProject,
         havingFilter,
-        sort
+        sort,
+        sortProject
     );
   }
 
@@ -226,9 +241,26 @@ public class PartialDruidQuery
         selectProject,
         selectSort,
         aggregate,
-        postProject,
+        aggregateProject,
+        havingFilter,
+        newSort,
+        sortProject
+    );
+  }
+
+  public PartialDruidQuery withSortProject(final Project newSortProject)
+  {
+    validateStage(Stage.SORT_PROJECT);
+    return new PartialDruidQuery(
+        scan,
+        whereFilter,
+        selectProject,
+        selectSort,
+        aggregate,
+        aggregateProject,
         havingFilter,
-        newSort
+        sort,
+        newSortProject
     );
   }
 
@@ -266,6 +298,9 @@ public class PartialDruidQuery
     } else if (stage.compareTo(Stage.AGGREGATE) >= 0 && selectSort != null) {
       // Cannot do any aggregations after a select + sort.
       return false;
+    } else if (stage.compareTo(Stage.SORT) > 0 && sort == null) {
+      // Cannot add sort project without a sort
+      return false;
     } else {
       // Looks good.
       return true;
@@ -278,12 +313,15 @@ public class PartialDruidQuery
    *
    * @return stage
    */
+  @SuppressWarnings("VariableNotUsedInsideIf")
   public Stage stage()
   {
-    if (sort != null) {
+    if (sortProject != null) {
+      return Stage.SORT_PROJECT;
+    } else if (sort != null) {
       return Stage.SORT;
-    } else if (postProject != null) {
-      return Stage.POST_PROJECT;
+    } else if (aggregateProject != null) {
+      return Stage.AGGREGATE_PROJECT;
     } else if (havingFilter != null) {
       return Stage.HAVING_FILTER;
     } else if (aggregate != null) {
@@ -309,10 +347,12 @@ public class PartialDruidQuery
     final Stage currentStage = stage();
 
     switch (currentStage) {
+      case SORT_PROJECT:
+        return sortProject;
       case SORT:
         return sort;
-      case POST_PROJECT:
-        return postProject;
+      case AGGREGATE_PROJECT:
+        return aggregateProject;
       case HAVING_FILTER:
         return havingFilter;
       case AGGREGATE:
@@ -353,14 +393,25 @@ public class PartialDruidQuery
            Objects.equals(selectSort, that.selectSort) &&
            Objects.equals(aggregate, that.aggregate) &&
            Objects.equals(havingFilter, that.havingFilter) &&
-           Objects.equals(postProject, that.postProject) &&
-           Objects.equals(sort, that.sort);
+           Objects.equals(aggregateProject, that.aggregateProject) &&
+           Objects.equals(sort, that.sort) &&
+           Objects.equals(sortProject, that.sortProject);
   }
 
   @Override
   public int hashCode()
   {
-    return Objects.hash(scan, whereFilter, selectProject, selectSort, aggregate, havingFilter, postProject, sort);
+    return Objects.hash(
+        scan,
+        whereFilter,
+        selectProject,
+        selectSort,
+        aggregate,
+        havingFilter,
+        aggregateProject,
+        sort,
+        sortProject
+    );
   }
 
   @Override
@@ -373,8 +424,9 @@ public class PartialDruidQuery
            ", selectSort=" + selectSort +
            ", aggregate=" + aggregate +
            ", havingFilter=" + havingFilter +
-           ", postProject=" + postProject +
+           ", aggregateProject=" + aggregateProject +
            ", sort=" + sort +
+           ", sortProject=" + sortProject +
            '}';
   }
 }
diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/SortProject.java b/sql/src/main/java/io/druid/sql/calcite/rel/SortProject.java
new file mode 100644
index 0000000..c00aff9
--- /dev/null
+++ b/sql/src/main/java/io/druid/sql/calcite/rel/SortProject.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to Metamarkets Group Inc. (Metamarkets) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. Metamarkets licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package io.druid.sql.calcite.rel;
+
+import com.google.common.base.Preconditions;
+import io.druid.java.util.common.ISE;
+import io.druid.query.aggregation.PostAggregator;
+import io.druid.sql.calcite.table.RowSignature;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Objects;
+import java.util.Set;
+
+public class SortProject
+{
+  private final RowSignature inputRowSignature;
+  private final List<PostAggregator> postAggregators;
+  private final RowSignature outputRowSignature;
+
+  SortProject(
+      RowSignature inputRowSignature,
+      List<PostAggregator> postAggregators,
+      RowSignature outputRowSignature
+  )
+  {
+    this.inputRowSignature = Preconditions.checkNotNull(inputRowSignature, "inputRowSignature");
+    this.postAggregators = Preconditions.checkNotNull(postAggregators, "postAggregators");
+    this.outputRowSignature = Preconditions.checkNotNull(outputRowSignature, "outputRowSignature");
+
+    // Verify no collisions.
+    final Set<String> seen = new HashSet<>();
+    inputRowSignature.getRowOrder().forEach(field -> {
+      if (!seen.add(field)) {
+        throw new ISE("Duplicate field name: %s", field);
+      }
+    });
+
+    for (PostAggregator postAggregator : postAggregators) {
+      if (postAggregator == null) {
+        throw new ISE("aggregation[%s] is not a postAggregator", postAggregator);
+      }
+      if (!seen.add(postAggregator.getName())) {
+        throw new ISE("Duplicate field name: %s", postAggregator.getName());
+      }
+    }
+
+    // Verify that items in the output signature exist.
+    outputRowSignature.getRowOrder().forEach(field -> {
+      if (!seen.contains(field)) {
+        throw new ISE("Missing field in rowOrder: %s", field);
+      }
+    });
+  }
+
+  public List<PostAggregator> getPostAggregators()
+  {
+    return postAggregators;
+  }
+
+  public RowSignature getOutputRowSignature()
+  {
+    return outputRowSignature;
+  }
+
+  @Override
+  public boolean equals(Object o)
+  {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    SortProject sortProject = (SortProject) o;
+    return Objects.equals(inputRowSignature, sortProject.inputRowSignature) &&
+           Objects.equals(postAggregators, sortProject.postAggregators) &&
+           Objects.equals(outputRowSignature, sortProject.outputRowSignature);
+  }
+
+  @Override
+  public int hashCode()
+  {
+    return Objects.hash(inputRowSignature, postAggregators, outputRowSignature);
+  }
+
+  @Override
+  public String toString()
+  {
+    return "SortProject{" +
+           "inputRowSignature=" + inputRowSignature +
+           ", postAggregators=" + postAggregators +
+           ", outputRowSignature=" + outputRowSignature +
+           '}';
+  }
+}
diff --git a/sql/src/main/java/io/druid/sql/calcite/rule/DruidRules.java b/sql/src/main/java/io/druid/sql/calcite/rule/DruidRules.java
index b565aba..c2c6208 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rule/DruidRules.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rule/DruidRules.java
@@ -68,8 +68,8 @@ public class DruidRules
         ),
         new DruidQueryRule<>(
             Project.class,
-            PartialDruidQuery.Stage.POST_PROJECT,
-            PartialDruidQuery::withPostProject
+            PartialDruidQuery.Stage.AGGREGATE_PROJECT,
+            PartialDruidQuery::withAggregateProject
         ),
         new DruidQueryRule<>(
             Filter.class,
@@ -81,10 +81,16 @@ public class DruidRules
             PartialDruidQuery.Stage.SORT,
             PartialDruidQuery::withSort
         ),
+        new DruidQueryRule<>(
+            Project.class,
+            PartialDruidQuery.Stage.SORT_PROJECT,
+            PartialDruidQuery::withSortProject
+        ),
         DruidOuterQueryRule.AGGREGATE,
         DruidOuterQueryRule.FILTER_AGGREGATE,
         DruidOuterQueryRule.FILTER_PROJECT_AGGREGATE,
-        DruidOuterQueryRule.PROJECT_AGGREGATE
+        DruidOuterQueryRule.PROJECT_AGGREGATE,
+        DruidOuterQueryRule.AGGREGATE_SORT_PROJECT
     );
   }
 
@@ -227,6 +233,32 @@ public class DruidRules
       }
     };
 
+    public static RelOptRule AGGREGATE_SORT_PROJECT = new DruidOuterQueryRule(
+        operand(Project.class, operand(Sort.class, operand(Aggregate.class, operand(DruidRel.class, any())))),
+        "AGGREGATE_SORT_PROJECT"
+    )
+    {
+      @Override
+      public void onMatch(RelOptRuleCall call)
+      {
+        final Project sortProject = call.rel(0);
+        final Sort sort = call.rel(1);
+        final Aggregate aggregate = call.rel(2);
+        final DruidRel druidRel = call.rel(3);
+
+        final DruidOuterQueryRel outerQueryRel = DruidOuterQueryRel.create(
+            druidRel,
+            PartialDruidQuery.create(druidRel.getPartialDruidQuery().leafRel())
+                             .withAggregate(aggregate)
+                             .withSort(sort)
+                             .withSortProject(sortProject)
+        );
+        if (outerQueryRel.isValidDruidQuery()) {
+          call.transformTo(outerQueryRel);
+        }
+      }
+    };
+
     public DruidOuterQueryRule(final RelOptRuleOperand op, final String description)
     {
       super(op, StringUtils.format("%s(%s)", DruidOuterQueryRel.class.getSimpleName(), description));
diff --git a/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java b/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java
index 5376ff1..9ef0430 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java
@@ -24,6 +24,7 @@ import com.google.common.base.Predicates;
 import io.druid.sql.calcite.planner.PlannerConfig;
 import io.druid.sql.calcite.rel.DruidRel;
 import io.druid.sql.calcite.rel.DruidSemiJoin;
+import io.druid.sql.calcite.rel.PartialDruidQuery;
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.plan.RelOptUtil;
@@ -115,15 +116,18 @@ public class DruidSemiJoinRule extends RelOptRule
       return;
     }
 
-    final Project rightPostProject = right.getPartialDruidQuery().getPostProject();
+    final PartialDruidQuery rightQuery = right.getPartialDruidQuery();
+    final Project rightProject = rightQuery.getSortProject() != null ?
+                                 rightQuery.getSortProject() :
+                                 rightQuery.getAggregateProject();
     int i = 0;
     for (int joinRef : joinInfo.rightSet()) {
       final int aggregateRef;
 
-      if (rightPostProject == null) {
+      if (rightProject == null) {
         aggregateRef = joinRef;
       } else {
-        final RexNode projectExp = rightPostProject.getChildExps().get(joinRef);
+        final RexNode projectExp = rightProject.getChildExps().get(joinRef);
         if (projectExp.isA(SqlKind.INPUT_REF)) {
           aggregateRef = ((RexInputRef) projectExp).getIndex();
         } else {
diff --git a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java
index 5bfa8a3..dccdc5f 100644
--- a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java
@@ -73,6 +73,7 @@ import io.druid.query.groupby.GroupByQuery;
 import io.druid.query.groupby.having.DimFilterHavingSpec;
 import io.druid.query.groupby.orderby.DefaultLimitSpec;
 import io.druid.query.groupby.orderby.OrderByColumnSpec;
+import io.druid.query.groupby.orderby.OrderByColumnSpec.Direction;
 import io.druid.query.lookup.RegisteredLookupExtractionFn;
 import io.druid.query.ordering.StringComparator;
 import io.druid.query.ordering.StringComparators;
@@ -123,6 +124,7 @@ import org.junit.rules.TemporaryFolder;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 
@@ -6446,6 +6448,193 @@ public class CalciteQueryTest extends CalciteTestBase
     );
   }
 
+  @Test
+  public void testProjectAfterSort() throws Exception
+  {
+    testQuery(
+        "select dim1 from (select dim1, dim2, count(*) cnt from druid.foo group by dim1, dim2 order by cnt)",
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(CalciteTests.DATASOURCE1)
+                        .setInterval(QSS(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimensions(
+                            DIMS(
+                                new DefaultDimensionSpec("dim1", "d0"),
+                                new DefaultDimensionSpec("dim2", "d1")
+                            )
+                        )
+                        .setAggregatorSpecs(AGGS(new CountAggregatorFactory("a0")))
+                        .setLimitSpec(
+                            new DefaultLimitSpec(
+                                Collections.singletonList(
+                                    new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
+                                ),
+                                Integer.MAX_VALUE
+                            )
+                        )
+                        .setContext(QUERY_CONTEXT_DEFAULT)
+                        .build()
+        ),
+        ImmutableList.of(
+            new Object[]{""},
+            new Object[]{"1"},
+            new Object[]{"10.1"},
+            new Object[]{"2"},
+            new Object[]{"abc"},
+            new Object[]{"def"}
+        )
+    );
+  }
+
+  @Test
+  public void testProjectAfterSort2() throws Exception
+  {
+    testQuery(
+        "select s / cnt, dim1, dim2, s from (select dim1, dim2, count(*) cnt, sum(m2) s from druid.foo group by dim1, dim2 order by cnt)",
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(CalciteTests.DATASOURCE1)
+                        .setInterval(QSS(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimensions(
+                            DIMS(
+                                new DefaultDimensionSpec("dim1", "d0"),
+                                new DefaultDimensionSpec("dim2", "d1")
+                            )
+                        )
+                        .setAggregatorSpecs(
+                            AGGS(new CountAggregatorFactory("a0"), new DoubleSumAggregatorFactory("a1", "m2"))
+                        )
+                        .setPostAggregatorSpecs(Collections.singletonList(EXPRESSION_POST_AGG("p0", "(\"a1\" / \"a0\")")))
+                        .setLimitSpec(
+                            new DefaultLimitSpec(
+                                Collections.singletonList(
+                                    new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
+                                ),
+                                Integer.MAX_VALUE
+                            )
+                        )
+                        .setContext(QUERY_CONTEXT_DEFAULT)
+                        .build()
+        ),
+        ImmutableList.of(
+            new Object[]{1.0, "", "a", 1.0},
+            new Object[]{4.0, "1", "a", 4.0},
+            new Object[]{2.0, "10.1", "", 2.0},
+            new Object[]{3.0, "2", "", 3.0},
+            new Object[]{6.0, "abc", "", 6.0},
+            new Object[]{5.0, "def", "abc", 5.0}
+        )
+    );
+  }
+
+  @Test
+  public void testProjectAfterSort3() throws Exception
+  {
+    testQuery(
+        "select dim1 from (select dim1, dim1, count(*) cnt from druid.foo group by dim1, dim1 order by cnt)",
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(CalciteTests.DATASOURCE1)
+                        .setInterval(QSS(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimensions(
+                            DIMS(
+                                new DefaultDimensionSpec("dim1", "d0")
+                            )
+                        )
+                        .setAggregatorSpecs(AGGS(new CountAggregatorFactory("a0")))
+                        .setLimitSpec(
+                            new DefaultLimitSpec(
+                                Collections.singletonList(
+                                    new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
+                                ),
+                                Integer.MAX_VALUE
+                            )
+                        )
+                        .setContext(QUERY_CONTEXT_DEFAULT)
+                        .build()
+        ),
+        ImmutableList.of(
+            new Object[]{""},
+            new Object[]{"1"},
+            new Object[]{"10.1"},
+            new Object[]{"2"},
+            new Object[]{"abc"},
+            new Object[]{"def"}
+        )
+    );
+  }
+
+  @Test
+  public void testSortProjectAfterNestedGroupBy() throws Exception
+  {
+    testQuery(
+        "SELECT "
+        + "  cnt "
+        + "FROM ("
+        + "  SELECT "
+        + "    __time, "
+        + "    dim1, "
+        + "    COUNT(m2) AS cnt "
+        + "  FROM ("
+        + "    SELECT "
+        + "        __time, "
+        + "        m2, "
+        + "        dim1 "
+        + "    FROM druid.foo "
+        + "    GROUP BY __time, m2, dim1 "
+        + "  ) "
+        + "  GROUP BY __time, dim1 "
+        + "  ORDER BY cnt"
+        + ")",
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(
+                            GroupByQuery.builder()
+                                        .setDataSource(CalciteTests.DATASOURCE1)
+                                        .setInterval(QSS(Filtration.eternity()))
+                                        .setGranularity(Granularities.ALL)
+                                        .setDimensions(DIMS(
+                                            new DefaultDimensionSpec("__time", "d0", ValueType.LONG),
+                                            new DefaultDimensionSpec("dim1", "d1"),
+                                            new DefaultDimensionSpec("m2", "d2", ValueType.DOUBLE)
+                                        ))
+                                        .setContext(QUERY_CONTEXT_DEFAULT)
+                                        .build()
+                        )
+                        .setInterval(QSS(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimensions(DIMS(
+                            new DefaultDimensionSpec("d0", "_d0", ValueType.LONG),
+                            new DefaultDimensionSpec("d1", "_d1", ValueType.STRING)
+                        ))
+                        .setAggregatorSpecs(AGGS(
+                            new CountAggregatorFactory("a0")
+                        ))
+                        .setLimitSpec(
+                            new DefaultLimitSpec(
+                                Collections.singletonList(
+                                    new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
+                                ),
+                                Integer.MAX_VALUE
+                            )
+                        )
+                        .setContext(QUERY_CONTEXT_DEFAULT)
+                        .build()
+        ),
+        ImmutableList.of(
+            new Object[]{1L},
+            new Object[]{1L},
+            new Object[]{1L},
+            new Object[]{1L},
+            new Object[]{1L},
+            new Object[]{1L}
+        )
+    );
+  }
+
   private void testQuery(
       final String sql,
       final List<Query> expectedQueries,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@druid.apache.org
For additional commands, e-mail: commits-help@druid.apache.org