You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@druid.apache.org by fj...@apache.org on 2018/08/26 23:15:30 UTC
[incubator-druid] branch 0.12.3 updated: Support projection after sorting in SQL (#5788) (#6228)
This is an automated email from the ASF dual-hosted git repository.
fjy pushed a commit to branch 0.12.3
in repository https://gitbox.apache.org/repos/asf/incubator-druid.git
The following commit(s) were added to refs/heads/0.12.3 by this push:
new bc07320 Support projection after sorting in SQL (#5788) (#6228)
bc07320 is described below
commit bc07320ae73a0ba03512fb4b647bf98dfdce6aef
Author: Gian Merlino <gi...@gmail.com>
AuthorDate: Sun Aug 26 16:15:27 2018 -0700
Support projection after sorting in SQL (#5788) (#6228)
* Add sort project
* add more test
* address comments
---
.../druid/sql/calcite/aggregation/Aggregation.java | 3 +-
.../java/io/druid/sql/calcite/rel/DruidQuery.java | 190 +++++++++++++++------
.../io/druid/sql/calcite/rel/DruidQueryRel.java | 8 +-
.../io/druid/sql/calcite/rel/DruidSemiJoin.java | 8 +-
.../druid/sql/calcite/rel/PartialDruidQuery.java | 120 +++++++++----
.../java/io/druid/sql/calcite/rel/SortProject.java | 112 ++++++++++++
.../java/io/druid/sql/calcite/rule/DruidRules.java | 38 ++++-
.../druid/sql/calcite/rule/DruidSemiJoinRule.java | 10 +-
.../io/druid/sql/calcite/CalciteQueryTest.java | 189 ++++++++++++++++++++
9 files changed, 581 insertions(+), 97 deletions(-)
diff --git a/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregation.java b/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregation.java
index 2532c8d..09436b9 100644
--- a/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregation.java
+++ b/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregation.java
@@ -36,6 +36,7 @@ import io.druid.sql.calcite.filtration.Filtration;
import io.druid.sql.calcite.table.RowSignature;
import javax.annotation.Nullable;
+import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
@@ -112,7 +113,7 @@ public class Aggregation
public static Aggregation create(final PostAggregator postAggregator)
{
- return new Aggregation(ImmutableList.of(), ImmutableList.of(), postAggregator);
+ return new Aggregation(Collections.emptyList(), Collections.emptyList(), postAggregator);
}
public static Aggregation create(
diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java
index 9740f68..2f6fde5 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java
@@ -89,6 +89,7 @@ import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
+import java.util.OptionalInt;
import java.util.TreeSet;
import java.util.stream.Collectors;
@@ -105,9 +106,11 @@ public class DruidQuery
private final DimFilter filter;
private final SelectProjection selectProjection;
private final Grouping grouping;
+ private final SortProject sortProject;
+ private final DefaultLimitSpec limitSpec;
private final RowSignature outputRowSignature;
private final RelDataType outputRowType;
- private final DefaultLimitSpec limitSpec;
+
private final Query query;
public DruidQuery(
@@ -129,15 +132,22 @@ public class DruidQuery
this.selectProjection = computeSelectProjection(partialQuery, plannerContext, sourceRowSignature);
this.grouping = computeGrouping(partialQuery, plannerContext, sourceRowSignature, rexBuilder, finalizeAggregations);
+ final RowSignature sortingInputRowSignature;
+
if (this.selectProjection != null) {
- this.outputRowSignature = this.selectProjection.getOutputRowSignature();
+ sortingInputRowSignature = this.selectProjection.getOutputRowSignature();
} else if (this.grouping != null) {
- this.outputRowSignature = this.grouping.getOutputRowSignature();
+ sortingInputRowSignature = this.grouping.getOutputRowSignature();
} else {
- this.outputRowSignature = sourceRowSignature;
+ sortingInputRowSignature = sourceRowSignature;
}
- this.limitSpec = computeLimitSpec(partialQuery, this.outputRowSignature);
+ this.sortProject = computeSortProject(partialQuery, plannerContext, sortingInputRowSignature, grouping);
+
+ // outputRowSignature is used only for scan and select query, and thus sort and grouping must be null
+ this.outputRowSignature = sortProject == null ? sortingInputRowSignature : sortProject.getOutputRowSignature();
+
+ this.limitSpec = computeLimitSpec(partialQuery, sortingInputRowSignature);
this.query = computeQuery();
}
@@ -237,7 +247,7 @@ public class DruidQuery
)
{
final Aggregate aggregate = partialQuery.getAggregate();
- final Project postProject = partialQuery.getPostProject();
+ final Project aggregateProject = partialQuery.getAggregateProject();
if (aggregate == null) {
return null;
@@ -268,49 +278,27 @@ public class DruidQuery
plannerContext
);
- if (postProject == null) {
+ if (aggregateProject == null) {
return Grouping.create(dimensions, aggregations, havingFilter, aggregateRowSignature);
} else {
- final List<String> rowOrder = new ArrayList<>();
-
- int outputNameCounter = 0;
- for (final RexNode postAggregatorRexNode : postProject.getChildExps()) {
- // Attempt to convert to PostAggregator.
- final DruidExpression postAggregatorExpression = Expressions.toDruidExpression(
- plannerContext,
- aggregateRowSignature,
- postAggregatorRexNode
- );
-
- if (postAggregatorExpression == null) {
- throw new CannotBuildQueryException(postProject, postAggregatorRexNode);
- }
-
- if (postAggregatorDirectColumnIsOk(aggregateRowSignature, postAggregatorExpression, postAggregatorRexNode)) {
- // Direct column access, without any type cast as far as Druid's runtime is concerned.
- // (There might be a SQL-level type cast that we don't care about)
- rowOrder.add(postAggregatorExpression.getDirectColumn());
- } else {
- final String postAggregatorName = "p" + outputNameCounter++;
- final PostAggregator postAggregator = new ExpressionPostAggregator(
- postAggregatorName,
- postAggregatorExpression.getExpression(),
- null,
- plannerContext.getExprMacroTable()
- );
- aggregations.add(Aggregation.create(postAggregator));
- rowOrder.add(postAggregator.getName());
- }
- }
+ final ProjectRowOrderAndPostAggregations projectRowOrderAndPostAggregations = computePostAggregations(
+ plannerContext,
+ aggregateRowSignature,
+ aggregateProject,
+ 0
+ );
+ projectRowOrderAndPostAggregations.postAggregations.forEach(
+ postAggregator -> aggregations.add(Aggregation.create(postAggregator))
+ );
// Remove literal dimensions that did not appear in the projection. This is useful for queries
// like "SELECT COUNT(*) FROM tbl GROUP BY 'dummy'" which some tools can generate, and for which we don't
// actually want to include a dimension 'dummy'.
- final ImmutableBitSet postProjectBits = RelOptUtil.InputFinder.bits(postProject.getChildExps(), null);
+ final ImmutableBitSet aggregateProjectBits = RelOptUtil.InputFinder.bits(aggregateProject.getChildExps(), null);
for (int i = dimensions.size() - 1; i >= 0; i--) {
final DimensionExpression dimension = dimensions.get(i);
if (Parser.parse(dimension.getDruidExpression().getExpression(), plannerContext.getExprMacroTable())
- .isLiteral() && !postProjectBits.get(i)) {
+ .isLiteral() && !aggregateProjectBits.get(i)) {
dimensions.remove(i);
}
}
@@ -319,11 +307,98 @@ public class DruidQuery
dimensions,
aggregations,
havingFilter,
- RowSignature.from(rowOrder, postProject.getRowType())
+ RowSignature.from(projectRowOrderAndPostAggregations.rowOrder, aggregateProject.getRowType())
);
}
}
+ /**
+ * Builds the {@link SortProject} for the projection that follows the sort stage of the partial query.
+ *
+ * @param partialQuery partial query whose {@code getSortProject()} is inspected
+ * @param plannerContext passed through to {@code computePostAggregations}
+ * @param sortingInputRowSignature row signature the projection's expressions are resolved against; it is also
+ * stored as the input signature of the returned object
+ * @param grouping source of the post-aggregators that already exist; only read when a sort project is present
+ *
+ * @return the sort projection, or null when the partial query has no sort project
+ */
+ @Nullable
+ private SortProject computeSortProject(
+ PartialDruidQuery partialQuery,
+ PlannerContext plannerContext,
+ RowSignature sortingInputRowSignature,
+ Grouping grouping
+ )
+ {
+ final Project sortProject = partialQuery.getSortProject();
+ if (sortProject == null) {
+ return null;
+ } else {
+ // NOTE(review): grouping is dereferenced without a null check; presumably a sort project can only exist
+ // together with an aggregation -- confirm against PartialDruidQuery's stage validation.
+ final List<PostAggregator> postAggregators = grouping.getPostAggregators();
+ // Find the largest numeric suffix among the existing post-aggregator names so that new names do not collide.
+ // NOTE(review): this assumes every name is a single character followed by an integer (e.g. "p0"); any other
+ // shape makes Integer.parseInt throw NumberFormatException -- verify how these names are generated.
+ final OptionalInt maybeMaxCounter = postAggregators
+ .stream()
+ .mapToInt(postAggregator -> Integer.parseInt(postAggregator.getName().substring(1)))
+ .max();
+
+ final ProjectRowOrderAndPostAggregations projectRowOrderAndPostAggregations = computePostAggregations(
+ plannerContext,
+ sortingInputRowSignature,
+ sortProject,
+ maybeMaxCounter.orElse(-1) + 1 // 0 if max doesn't exist
+ );
+
+ return new SortProject(
+ sortingInputRowSignature,
+ projectRowOrderAndPostAggregations.postAggregations,
+ RowSignature.from(projectRowOrderAndPostAggregations.rowOrder, sortProject.getRowType())
+ );
+ }
+ }
+
+ /**
+ * Simple holder for the two results of {@code computePostAggregations}: the output column names of a
+ * projection, in order, and the post-aggregators that were created for it. The lists are stored as given.
+ */
+ private static class ProjectRowOrderAndPostAggregations
+ {
+ // Output column name for each projected expression, in projection order.
+ private final List<String> rowOrder;
+ // Post-aggregators created for expressions that are not plain column references.
+ private final List<PostAggregator> postAggregations;
+
+ ProjectRowOrderAndPostAggregations(List<String> rowOrder, List<PostAggregator> postAggregations)
+ {
+ this.rowOrder = rowOrder;
+ this.postAggregations = postAggregations;
+ }
+ }
+
+ /**
+ * Translates each expression of a projection into either a direct column reference or a new
+ * {@link ExpressionPostAggregator}.
+ *
+ * @param plannerContext used for expression conversion and for its expression macro table
+ * @param inputRowSignature row signature the projection's expressions are resolved against
+ * @param project projection whose child expressions are translated, in order
+ * @param outputNameCounter first number to use for generated post-aggregator names ("p" + counter); it is
+ * incremented once per post-aggregator created
+ *
+ * @return the output column names in projection order together with the post-aggregators that were created
+ *
+ * @throws CannotBuildQueryException if an expression cannot be converted to a Druid expression
+ */
+ private static ProjectRowOrderAndPostAggregations computePostAggregations(
+ PlannerContext plannerContext,
+ RowSignature inputRowSignature,
+ Project project,
+ int outputNameCounter
+ )
+ {
+ final List<String> rowOrder = new ArrayList<>();
+ final List<PostAggregator> aggregations = new ArrayList<>();
+
+ for (final RexNode postAggregatorRexNode : project.getChildExps()) {
+ // Attempt to convert to PostAggregator.
+ final DruidExpression postAggregatorExpression = Expressions.toDruidExpression(
+ plannerContext,
+ inputRowSignature,
+ postAggregatorRexNode
+ );
+
+ if (postAggregatorExpression == null) {
+ throw new CannotBuildQueryException(project, postAggregatorRexNode);
+ }
+
+ if (postAggregatorDirectColumnIsOk(inputRowSignature, postAggregatorExpression, postAggregatorRexNode)) {
+ // Direct column access, without any type cast as far as Druid's runtime is concerned.
+ // (There might be a SQL-level type cast that we don't care about)
+ rowOrder.add(postAggregatorExpression.getDirectColumn());
+ } else {
+ // Anything else becomes an expression post-aggregator with a generated name; that name is the output column.
+ final String postAggregatorName = "p" + outputNameCounter++;
+ final PostAggregator postAggregator = new ExpressionPostAggregator(
+ postAggregatorName,
+ postAggregatorExpression.getExpression(),
+ null,
+ plannerContext.getExprMacroTable()
+ );
+ aggregations.add(postAggregator);
+ rowOrder.add(postAggregator.getName());
+ }
+ }
+
+ return new ProjectRowOrderAndPostAggregations(rowOrder, aggregations);
+ }
+
/**
* Returns dimensions corresponding to {@code aggregate.getGroupSet()}, in the same order.
*
@@ -548,18 +623,20 @@ public class DruidQuery
{
final List<VirtualColumn> retVal = new ArrayList<>();
- if (grouping != null) {
- if (includeDimensions) {
- for (DimensionExpression dimensionExpression : grouping.getDimensions()) {
- retVal.addAll(dimensionExpression.getVirtualColumns(macroTable));
+ if (selectProjection != null) {
+ retVal.addAll(selectProjection.getVirtualColumns());
+ } else {
+ if (grouping != null) {
+ if (includeDimensions) {
+ for (DimensionExpression dimensionExpression : grouping.getDimensions()) {
+ retVal.addAll(dimensionExpression.getVirtualColumns(macroTable));
+ }
}
- }
- for (Aggregation aggregation : grouping.getAggregations()) {
- retVal.addAll(aggregation.getVirtualColumns());
+ for (Aggregation aggregation : grouping.getAggregations()) {
+ retVal.addAll(aggregation.getVirtualColumns());
+ }
}
- } else if (selectProjection != null) {
- retVal.addAll(selectProjection.getVirtualColumns());
}
return VirtualColumns.create(retVal);
@@ -575,6 +652,11 @@ public class DruidQuery
return limitSpec;
}
+ /**
+ * Returns the sort projection computed when this query was constructed; null when the partial query had no
+ * projection after its sort stage.
+ */
+ public SortProject getSortProject()
+ {
+ return sortProject;
+ }
+
public RelDataType getOutputRowType()
{
return outputRowType;
@@ -675,7 +757,6 @@ public class DruidQuery
if (limitSpec != null) {
// If there is a limit spec, timeseries cannot LIMIT; and must be ORDER BY time (or nothing).
-
if (limitSpec.isLimited()) {
return null;
}
@@ -805,6 +886,11 @@ public class DruidQuery
final Filtration filtration = Filtration.create(filter).optimize(sourceRowSignature);
+ final List<PostAggregator> postAggregators = new ArrayList<>(grouping.getPostAggregators());
+ if (sortProject != null) {
+ postAggregators.addAll(sortProject.getPostAggregators());
+ }
+
return new GroupByQuery(
dataSource,
filtration.getQuerySegmentSpec(),
@@ -813,7 +899,7 @@ public class DruidQuery
Granularities.ALL,
grouping.getDimensionSpecs(),
grouping.getAggregatorFactories(),
- grouping.getPostAggregators(),
+ postAggregators,
grouping.getHavingFilter() != null ? new DimFilterHavingSpec(grouping.getHavingFilter(), true) : null,
limitSpec,
ImmutableSortedMap.copyOf(plannerContext.getQueryContext())
diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java
index c304e5b..a62096a 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java
@@ -225,14 +225,18 @@ public class DruidQueryRel extends DruidRel<DruidQueryRel>
cost += COST_PER_COLUMN * partialQuery.getAggregate().getAggCallList().size();
}
- if (partialQuery.getPostProject() != null) {
- cost += COST_PER_COLUMN * partialQuery.getPostProject().getChildExps().size();
+ if (partialQuery.getAggregateProject() != null) {
+ cost += COST_PER_COLUMN * partialQuery.getAggregateProject().getChildExps().size();
}
if (partialQuery.getSort() != null && partialQuery.getSort().fetch != null) {
cost *= COST_LIMIT_MULTIPLIER;
}
+ if (partialQuery.getSortProject() != null) {
+ cost += COST_PER_COLUMN * partialQuery.getSortProject().getChildExps().size();
+ }
+
if (partialQuery.getHavingFilter() != null) {
cost *= COST_HAVING_MULTIPLIER;
}
diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java b/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java
index ecfd8bb..7c0d8b6 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java
@@ -358,8 +358,12 @@ public class DruidSemiJoin extends DruidRel<DruidSemiJoin>
newPartialQuery = newPartialQuery.withHavingFilter(leftPartialQuery.getHavingFilter());
}
- if (leftPartialQuery.getPostProject() != null) {
- newPartialQuery = newPartialQuery.withPostProject(leftPartialQuery.getPostProject());
+ if (leftPartialQuery.getAggregateProject() != null) {
+ newPartialQuery = newPartialQuery.withAggregateProject(leftPartialQuery.getAggregateProject());
+ }
+
+ if (leftPartialQuery.getSortProject() != null) {
+ newPartialQuery = newPartialQuery.withSortProject(leftPartialQuery.getSortProject());
}
if (leftPartialQuery.getSort() != null) {
diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/PartialDruidQuery.java b/sql/src/main/java/io/druid/sql/calcite/rel/PartialDruidQuery.java
index 01c960c..d7d0f77 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rel/PartialDruidQuery.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rel/PartialDruidQuery.java
@@ -46,8 +46,9 @@ public class PartialDruidQuery
private final Sort selectSort;
private final Aggregate aggregate;
private final Filter havingFilter;
- private final Project postProject;
+ private final Project aggregateProject;
private final Sort sort;
+ private final Project sortProject;
public enum Stage
{
@@ -57,8 +58,9 @@ public class PartialDruidQuery
SELECT_SORT,
AGGREGATE,
HAVING_FILTER,
- POST_PROJECT,
- SORT
+ AGGREGATE_PROJECT,
+ SORT,
+ SORT_PROJECT
}
public PartialDruidQuery(
@@ -67,9 +69,10 @@ public class PartialDruidQuery
final Project selectProject,
final Sort selectSort,
final Aggregate aggregate,
- final Project postProject,
+ final Project aggregateProject,
final Filter havingFilter,
- final Sort sort
+ final Sort sort,
+ final Project sortProject
)
{
this.scan = Preconditions.checkNotNull(scan, "scan");
@@ -77,14 +80,15 @@ public class PartialDruidQuery
this.selectProject = selectProject;
this.selectSort = selectSort;
this.aggregate = aggregate;
- this.postProject = postProject;
+ this.aggregateProject = aggregateProject;
this.havingFilter = havingFilter;
this.sort = sort;
+ this.sortProject = sortProject;
}
public static PartialDruidQuery create(final RelNode scanRel)
{
- return new PartialDruidQuery(scanRel, null, null, null, null, null, null, null);
+ return new PartialDruidQuery(scanRel, null, null, null, null, null, null, null, null);
}
public RelNode getScan()
@@ -117,9 +121,9 @@ public class PartialDruidQuery
return havingFilter;
}
- public Project getPostProject()
+ public Project getAggregateProject()
{
- return postProject;
+ return aggregateProject;
}
public Sort getSort()
@@ -127,6 +131,11 @@ public class PartialDruidQuery
return sort;
}
+ /**
+ * Returns the projection recorded for the SORT_PROJECT stage, or null if none has been set.
+ */
+ public Project getSortProject()
+ {
+ return sortProject;
+ }
+
public PartialDruidQuery withWhereFilter(final Filter newWhereFilter)
{
validateStage(Stage.WHERE_FILTER);
@@ -136,9 +145,10 @@ public class PartialDruidQuery
selectProject,
selectSort,
aggregate,
- postProject,
+ aggregateProject,
havingFilter,
- sort
+ sort,
+ sortProject
);
}
@@ -151,9 +161,10 @@ public class PartialDruidQuery
newSelectProject,
selectSort,
aggregate,
- postProject,
+ aggregateProject,
havingFilter,
- sort
+ sort,
+ sortProject
);
}
@@ -166,9 +177,10 @@ public class PartialDruidQuery
selectProject,
newSelectSort,
aggregate,
- postProject,
+ aggregateProject,
havingFilter,
- sort
+ sort,
+ sortProject
);
}
@@ -181,9 +193,10 @@ public class PartialDruidQuery
selectProject,
selectSort,
newAggregate,
- postProject,
+ aggregateProject,
havingFilter,
- sort
+ sort,
+ sortProject
);
}
@@ -196,24 +209,26 @@ public class PartialDruidQuery
selectProject,
selectSort,
aggregate,
- postProject,
+ aggregateProject,
newHavingFilter,
- sort
+ sort,
+ sortProject
);
}
- public PartialDruidQuery withPostProject(final Project newPostProject)
+ public PartialDruidQuery withAggregateProject(final Project newAggregateProject)
{
- validateStage(Stage.POST_PROJECT);
+ validateStage(Stage.AGGREGATE_PROJECT);
return new PartialDruidQuery(
scan,
whereFilter,
selectProject,
selectSort,
aggregate,
- newPostProject,
+ newAggregateProject,
havingFilter,
- sort
+ sort,
+ sortProject
);
}
@@ -226,9 +241,26 @@ public class PartialDruidQuery
selectProject,
selectSort,
aggregate,
- postProject,
+ aggregateProject,
+ havingFilter,
+ newSort,
+ sortProject
+ );
+ }
+
+ public PartialDruidQuery withSortProject(final Project newSortProject)
+ {
+ validateStage(Stage.SORT_PROJECT);
+ return new PartialDruidQuery(
+ scan,
+ whereFilter,
+ selectProject,
+ selectSort,
+ aggregate,
+ aggregateProject,
havingFilter,
- newSort
+ sort,
+ newSortProject
);
}
@@ -266,6 +298,9 @@ public class PartialDruidQuery
} else if (stage.compareTo(Stage.AGGREGATE) >= 0 && selectSort != null) {
// Cannot do any aggregations after a select + sort.
return false;
+ } else if (stage.compareTo(Stage.SORT) > 0 && sort == null) {
+ // Cannot add sort project without a sort
+ return false;
} else {
// Looks good.
return true;
@@ -278,12 +313,15 @@ public class PartialDruidQuery
*
* @return stage
*/
+ @SuppressWarnings("VariableNotUsedInsideIf")
public Stage stage()
{
- if (sort != null) {
+ if (sortProject != null) {
+ return Stage.SORT_PROJECT;
+ } else if (sort != null) {
return Stage.SORT;
- } else if (postProject != null) {
- return Stage.POST_PROJECT;
+ } else if (aggregateProject != null) {
+ return Stage.AGGREGATE_PROJECT;
} else if (havingFilter != null) {
return Stage.HAVING_FILTER;
} else if (aggregate != null) {
@@ -309,10 +347,12 @@ public class PartialDruidQuery
final Stage currentStage = stage();
switch (currentStage) {
+ case SORT_PROJECT:
+ return sortProject;
case SORT:
return sort;
- case POST_PROJECT:
- return postProject;
+ case AGGREGATE_PROJECT:
+ return aggregateProject;
case HAVING_FILTER:
return havingFilter;
case AGGREGATE:
@@ -353,14 +393,25 @@ public class PartialDruidQuery
Objects.equals(selectSort, that.selectSort) &&
Objects.equals(aggregate, that.aggregate) &&
Objects.equals(havingFilter, that.havingFilter) &&
- Objects.equals(postProject, that.postProject) &&
- Objects.equals(sort, that.sort);
+ Objects.equals(aggregateProject, that.aggregateProject) &&
+ Objects.equals(sort, that.sort) &&
+ Objects.equals(sortProject, that.sortProject);
}
@Override
public int hashCode()
{
- return Objects.hash(scan, whereFilter, selectProject, selectSort, aggregate, havingFilter, postProject, sort);
+ return Objects.hash(
+ scan,
+ whereFilter,
+ selectProject,
+ selectSort,
+ aggregate,
+ havingFilter,
+ aggregateProject,
+ sort,
+ sortProject
+ );
}
@Override
@@ -373,8 +424,9 @@ public class PartialDruidQuery
", selectSort=" + selectSort +
", aggregate=" + aggregate +
", havingFilter=" + havingFilter +
- ", postProject=" + postProject +
+ ", aggregateProject=" + aggregateProject +
", sort=" + sort +
+ ", sortProject=" + sortProject +
'}';
}
}
diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/SortProject.java b/sql/src/main/java/io/druid/sql/calcite/rel/SortProject.java
new file mode 100644
index 0000000..c00aff9
--- /dev/null
+++ b/sql/src/main/java/io/druid/sql/calcite/rel/SortProject.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to Metamarkets Group Inc. (Metamarkets) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. Metamarkets licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package io.druid.sql.calcite.rel;
+
+import com.google.common.base.Preconditions;
+import io.druid.java.util.common.ISE;
+import io.druid.query.aggregation.PostAggregator;
+import io.druid.sql.calcite.table.RowSignature;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Objects;
+import java.util.Set;
+
+/**
+ * Describes a projection applied after sorting: the row signature it reads, the post-aggregators that compute
+ * its non-trivial expressions, and the row signature it produces. The constructor checks that the field names
+ * are consistent; equality, hash code and string form are based on all three parts.
+ */
+public class SortProject
+{
+ private final RowSignature inputRowSignature;
+ private final List<PostAggregator> postAggregators;
+ private final RowSignature outputRowSignature;
+
+ /**
+ * Stores the three parts after validating them.
+ *
+ * All arguments are checked with {@code Preconditions.checkNotNull}. An {@link ISE} is thrown when a field name
+ * occurs twice among the input fields and post-aggregator names, when a post-aggregator is null, or when an
+ * output field is neither an input field nor a post-aggregator name.
+ */
+ SortProject(
+ RowSignature inputRowSignature,
+ List<PostAggregator> postAggregators,
+ RowSignature outputRowSignature
+ )
+ {
+ this.inputRowSignature = Preconditions.checkNotNull(inputRowSignature, "inputRowSignature");
+ this.postAggregators = Preconditions.checkNotNull(postAggregators, "postAggregators");
+ this.outputRowSignature = Preconditions.checkNotNull(outputRowSignature, "outputRowSignature");
+
+ // Verify no collisions.
+ final Set<String> seen = new HashSet<>();
+ inputRowSignature.getRowOrder().forEach(field -> {
+ if (!seen.add(field)) {
+ throw new ISE("Duplicate field name: %s", field);
+ }
+ });
+
+ for (PostAggregator postAggregator : postAggregators) {
+ // Note: in this branch the formatted argument is always null.
+ if (postAggregator == null) {
+ throw new ISE("aggregation[%s] is not a postAggregator", postAggregator);
+ }
+ if (!seen.add(postAggregator.getName())) {
+ throw new ISE("Duplicate field name: %s", postAggregator.getName());
+ }
+ }
+
+ // Verify that items in the output signature exist.
+ outputRowSignature.getRowOrder().forEach(field -> {
+ if (!seen.contains(field)) {
+ throw new ISE("Missing field in rowOrder: %s", field);
+ }
+ });
+ }
+
+ /**
+ * Returns the stored post-aggregator list; no copy is made here.
+ */
+ public List<PostAggregator> getPostAggregators()
+ {
+ return postAggregators;
+ }
+
+ /**
+ * Returns the row signature produced by this projection.
+ */
+ public RowSignature getOutputRowSignature()
+ {
+ return outputRowSignature;
+ }
+
+ @Override
+ public boolean equals(Object o)
+ {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ SortProject sortProject = (SortProject) o;
+ return Objects.equals(inputRowSignature, sortProject.inputRowSignature) &&
+ Objects.equals(postAggregators, sortProject.postAggregators) &&
+ Objects.equals(outputRowSignature, sortProject.outputRowSignature);
+ }
+
+ @Override
+ public int hashCode()
+ {
+ return Objects.hash(inputRowSignature, postAggregators, outputRowSignature);
+ }
+
+ @Override
+ public String toString()
+ {
+ return "SortProject{" +
+ "inputRowSignature=" + inputRowSignature +
+ ", postAggregators=" + postAggregators +
+ ", outputRowSignature=" + outputRowSignature +
+ '}';
+ }
+}
diff --git a/sql/src/main/java/io/druid/sql/calcite/rule/DruidRules.java b/sql/src/main/java/io/druid/sql/calcite/rule/DruidRules.java
index b565aba..c2c6208 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rule/DruidRules.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rule/DruidRules.java
@@ -68,8 +68,8 @@ public class DruidRules
),
new DruidQueryRule<>(
Project.class,
- PartialDruidQuery.Stage.POST_PROJECT,
- PartialDruidQuery::withPostProject
+ PartialDruidQuery.Stage.AGGREGATE_PROJECT,
+ PartialDruidQuery::withAggregateProject
),
new DruidQueryRule<>(
Filter.class,
@@ -81,10 +81,16 @@ public class DruidRules
PartialDruidQuery.Stage.SORT,
PartialDruidQuery::withSort
),
+ new DruidQueryRule<>(
+ Project.class,
+ PartialDruidQuery.Stage.SORT_PROJECT,
+ PartialDruidQuery::withSortProject
+ ),
DruidOuterQueryRule.AGGREGATE,
DruidOuterQueryRule.FILTER_AGGREGATE,
DruidOuterQueryRule.FILTER_PROJECT_AGGREGATE,
- DruidOuterQueryRule.PROJECT_AGGREGATE
+ DruidOuterQueryRule.PROJECT_AGGREGATE,
+ DruidOuterQueryRule.AGGREGATE_SORT_PROJECT
);
}
@@ -227,6 +233,32 @@ public class DruidRules
}
};
+ /**
+ * Outer-query rule matching a Project on top of a Sort on top of an Aggregate on top of a {@link DruidRel}.
+ * On a match it builds a {@link DruidOuterQueryRel} over the matched DruidRel whose partial query carries the
+ * aggregate, the sort and the projection (as the sort project), and registers it only if it is a valid Druid
+ * query.
+ */
+ public static RelOptRule AGGREGATE_SORT_PROJECT = new DruidOuterQueryRule(
+ operand(Project.class, operand(Sort.class, operand(Aggregate.class, operand(DruidRel.class, any())))),
+ "AGGREGATE_SORT_PROJECT"
+ )
+ {
+ @Override
+ public void onMatch(RelOptRuleCall call)
+ {
+ // Indices follow the operand tree above, outermost first.
+ final Project sortProject = call.rel(0);
+ final Sort sort = call.rel(1);
+ final Aggregate aggregate = call.rel(2);
+ final DruidRel druidRel = call.rel(3);
+
+ // The outer partial query starts from the inner rel's leaf and stacks the three matched stages on top.
+ final DruidOuterQueryRel outerQueryRel = DruidOuterQueryRel.create(
+ druidRel,
+ PartialDruidQuery.create(druidRel.getPartialDruidQuery().leafRel())
+ .withAggregate(aggregate)
+ .withSort(sort)
+ .withSortProject(sortProject)
+ );
+ if (outerQueryRel.isValidDruidQuery()) {
+ call.transformTo(outerQueryRel);
+ }
+ }
+ };
+
public DruidOuterQueryRule(final RelOptRuleOperand op, final String description)
{
super(op, StringUtils.format("%s(%s)", DruidOuterQueryRel.class.getSimpleName(), description));
diff --git a/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java b/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java
index 5376ff1..9ef0430 100644
--- a/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java
+++ b/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java
@@ -24,6 +24,7 @@ import com.google.common.base.Predicates;
import io.druid.sql.calcite.planner.PlannerConfig;
import io.druid.sql.calcite.rel.DruidRel;
import io.druid.sql.calcite.rel.DruidSemiJoin;
+import io.druid.sql.calcite.rel.PartialDruidQuery;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
@@ -115,15 +116,18 @@ public class DruidSemiJoinRule extends RelOptRule
return;
}
- final Project rightPostProject = right.getPartialDruidQuery().getPostProject();
+ final PartialDruidQuery rightQuery = right.getPartialDruidQuery();
+ final Project rightProject = rightQuery.getSortProject() != null ?
+ rightQuery.getSortProject() :
+ rightQuery.getAggregateProject();
int i = 0;
for (int joinRef : joinInfo.rightSet()) {
final int aggregateRef;
- if (rightPostProject == null) {
+ if (rightProject == null) {
aggregateRef = joinRef;
} else {
- final RexNode projectExp = rightPostProject.getChildExps().get(joinRef);
+ final RexNode projectExp = rightProject.getChildExps().get(joinRef);
if (projectExp.isA(SqlKind.INPUT_REF)) {
aggregateRef = ((RexInputRef) projectExp).getIndex();
} else {
diff --git a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java
index 5bfa8a3..dccdc5f 100644
--- a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java
@@ -73,6 +73,7 @@ import io.druid.query.groupby.GroupByQuery;
import io.druid.query.groupby.having.DimFilterHavingSpec;
import io.druid.query.groupby.orderby.DefaultLimitSpec;
import io.druid.query.groupby.orderby.OrderByColumnSpec;
+import io.druid.query.groupby.orderby.OrderByColumnSpec.Direction;
import io.druid.query.lookup.RegisteredLookupExtractionFn;
import io.druid.query.ordering.StringComparator;
import io.druid.query.ordering.StringComparators;
@@ -123,6 +124,7 @@ import org.junit.rules.TemporaryFolder;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -6446,6 +6448,193 @@ public class CalciteQueryTest extends CalciteTestBase
);
}
+ /**
+ * Selecting a subset of the columns of a sorted GROUP BY subquery: the expected plan is a single GroupByQuery
+ * with both dimensions, the count aggregator, an ascending numeric order-by on "a0" and no post-aggregators.
+ */
+ @Test
+ public void testProjectAfterSort() throws Exception
+ {
+ testQuery(
+ "select dim1 from (select dim1, dim2, count(*) cnt from druid.foo group by dim1, dim2 order by cnt)",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(QSS(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setDimensions(
+ DIMS(
+ new DefaultDimensionSpec("dim1", "d0"),
+ new DefaultDimensionSpec("dim2", "d1")
+ )
+ )
+ .setAggregatorSpecs(AGGS(new CountAggregatorFactory("a0")))
+ .setLimitSpec(
+ new DefaultLimitSpec(
+ Collections.singletonList(
+ new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
+ ),
+ Integer.MAX_VALUE
+ )
+ )
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{""},
+ new Object[]{"1"},
+ new Object[]{"10.1"},
+ new Object[]{"2"},
+ new Object[]{"abc"},
+ new Object[]{"def"}
+ )
+ );
+ }
+
+ /**
+ * Projecting an expression ({@code s / cnt}) plus reordered columns over a sorted GROUP BY subquery: the
+ * expected plan is a single GroupByQuery in which the expression appears as the post-aggregator "p0".
+ */
+ @Test
+ public void testProjectAfterSort2() throws Exception
+ {
+ testQuery(
+ "select s / cnt, dim1, dim2, s from (select dim1, dim2, count(*) cnt, sum(m2) s from druid.foo group by dim1, dim2 order by cnt)",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(QSS(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setDimensions(
+ DIMS(
+ new DefaultDimensionSpec("dim1", "d0"),
+ new DefaultDimensionSpec("dim2", "d1")
+ )
+ )
+ .setAggregatorSpecs(
+ AGGS(new CountAggregatorFactory("a0"), new DoubleSumAggregatorFactory("a1", "m2"))
+ )
+ .setPostAggregatorSpecs(Collections.singletonList(EXPRESSION_POST_AGG("p0", "(\"a1\" / \"a0\")")))
+ .setLimitSpec(
+ new DefaultLimitSpec(
+ Collections.singletonList(
+ new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
+ ),
+ Integer.MAX_VALUE
+ )
+ )
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{1.0, "", "a", 1.0},
+ new Object[]{4.0, "1", "a", 4.0},
+ new Object[]{2.0, "10.1", "", 2.0},
+ new Object[]{3.0, "2", "", 3.0},
+ new Object[]{6.0, "abc", "", 6.0},
+ new Object[]{5.0, "def", "abc", 5.0}
+ )
+ );
+ }
+
+ /**
+ * Same shape as {@code testProjectAfterSort}, but the subquery groups by {@code dim1} twice: the expected
+ * GroupByQuery contains only the single dimension "d0".
+ */
+ @Test
+ public void testProjectAfterSort3() throws Exception
+ {
+ testQuery(
+ "select dim1 from (select dim1, dim1, count(*) cnt from druid.foo group by dim1, dim1 order by cnt)",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(QSS(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setDimensions(
+ DIMS(
+ new DefaultDimensionSpec("dim1", "d0")
+ )
+ )
+ .setAggregatorSpecs(AGGS(new CountAggregatorFactory("a0")))
+ .setLimitSpec(
+ new DefaultLimitSpec(
+ Collections.singletonList(
+ new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
+ ),
+ Integer.MAX_VALUE
+ )
+ )
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{""},
+ new Object[]{"1"},
+ new Object[]{"10.1"},
+ new Object[]{"2"},
+ new Object[]{"abc"},
+ new Object[]{"def"}
+ )
+ );
+ }
+
+ /**
+ * Projection over a sorted GROUP BY whose input is itself a GROUP BY: the expected plan is an outer
+ * GroupByQuery (dimensions "_d0"/"_d1", count "a0", ascending order-by on "a0") whose data source is the
+ * inner GroupByQuery over the table.
+ */
+ @Test
+ public void testSortProjectAfterNestedGroupBy() throws Exception
+ {
+ testQuery(
+ "SELECT "
+ + " cnt "
+ + "FROM ("
+ + " SELECT "
+ + " __time, "
+ + " dim1, "
+ + " COUNT(m2) AS cnt "
+ + " FROM ("
+ + " SELECT "
+ + " __time, "
+ + " m2, "
+ + " dim1 "
+ + " FROM druid.foo "
+ + " GROUP BY __time, m2, dim1 "
+ + " ) "
+ + " GROUP BY __time, dim1 "
+ + " ORDER BY cnt"
+ + ")",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(QSS(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setDimensions(DIMS(
+ new DefaultDimensionSpec("__time", "d0", ValueType.LONG),
+ new DefaultDimensionSpec("dim1", "d1"),
+ new DefaultDimensionSpec("m2", "d2", ValueType.DOUBLE)
+ ))
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ )
+ .setInterval(QSS(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setDimensions(DIMS(
+ new DefaultDimensionSpec("d0", "_d0", ValueType.LONG),
+ new DefaultDimensionSpec("d1", "_d1", ValueType.STRING)
+ ))
+ .setAggregatorSpecs(AGGS(
+ new CountAggregatorFactory("a0")
+ ))
+ .setLimitSpec(
+ new DefaultLimitSpec(
+ Collections.singletonList(
+ new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
+ ),
+ Integer.MAX_VALUE
+ )
+ )
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{1L},
+ new Object[]{1L},
+ new Object[]{1L},
+ new Object[]{1L},
+ new Object[]{1L},
+ new Object[]{1L}
+ )
+ );
+ }
+
private void testQuery(
final String sql,
final List<Query> expectedQueries,
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@druid.apache.org
For additional commands, e-mail: commits-help@druid.apache.org