You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by ro...@apache.org on 2023/07/11 02:49:06 UTC
[pinot] branch master updated: [multistage] Fix aggregate type issues (#11068)
This is an automated email from the ASF dual-hosted git repository.
rongr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 19dce0e282 [multistage] Fix aggregate type issues (#11068)
19dce0e282 is described below
commit 19dce0e282ccd40f3da0fa2c8168efbd812f60fd
Author: Rong Rong <ro...@apache.org>
AuthorDate: Mon Jul 10 19:49:00 2023 -0700
[multistage] Fix aggregate type issues (#11068)
1. cleaned up type issues for agg
- pinot sql agg functions are no longer pre-registered with calcite, which caused type issues during SqlValidator validation
- pinot sql agg functions are now created dynamically so that type inference is applied properly (e.g. inferring from ARG0 cannot be applied statically b/c ARG0 would then refer to the input to the intermediate type, not the original DIRECT type)
2. revisit intermediate type inference when registering built-in agg (such as MIN/MAX) instead of doing raw casting from number/boolean to double/int when extracting value from input blocks
- follow up by registering explicit intermediate result type
- also follow up when `ReturnTypes.explicit(*)` is not possible for the final result type, by working around it with the validation type available from the RelBuilder rule call object.
3. fixed boolean handling (except bool_and/bool_or which is handled separately by #11033)
4. re-adjusted direct project rule, by using calcite AGG_PROJECT_EXPAND template
- follow up needed to ensure the direct aggregate, (see TODO#1, 2)
5. re-adjusted agg-exchange-rule
- no longer generates multiple aggs; now we only generate DIRECT/LEAF/FINAL (INTERMEDIATE not supported yet)
- un-generalized the rule NOT to reuse convert-agg-node and build-agg-call. They don't share logic, thus the generic version causes more problems with various nullable fields.
- removed unnecessary logic for discovering field-type/field-name/field-used, which already exists in calcite RelOptUtil
6. modified the generated test plans, removing unnecessary additional aggs
---------
Co-authored-by: Rong Rong <ro...@startree.ai>
---
.../PinotAggregateExchangeNodeInsertRule.java | 464 ++++++++-------------
.../rules/PinotAggregateReduceFunctionsRule.java | 427 -------------------
.../calcite/rel/rules/PinotQueryRuleSets.java | 2 +-
.../apache/calcite/sql/fun/PinotOperatorTable.java | 140 +------
.../apache/pinot/query/QueryCompilationTest.java | 1 -
.../pinot/query/QueryEnvironmentTestBase.java | 5 +-
.../src/test/resources/queries/AggregatePlans.json | 183 ++------
.../src/test/resources/queries/GroupByPlans.json | 308 ++++----------
.../src/test/resources/queries/JoinPlans.json | 77 ++--
.../src/test/resources/queries/OrderByPlans.json | 32 +-
.../test/resources/queries/PinotHintablePlans.json | 64 +--
.../resources/queries/WindowFunctionPlans.json | 228 ++++------
.../query/runtime/operator/AggregateOperator.java | 47 ++-
.../query/runtime/operator/FilterOperator.java | 4 +-
.../query/runtime/operator/HashJoinOperator.java | 4 +-
.../LeafStageTransferableBlockOperator.java | 44 +-
.../operator/MultistageAggregationExecutor.java | 109 ++---
.../operator/MultistageGroupByExecutor.java | 128 ++----
.../query/runtime/operator/TransformOperator.java | 5 +-
.../runtime/operator/operands/FilterOperand.java | 6 +-
.../operator/utils/FunctionInvokeUtils.java | 45 --
.../query/runtime/operator/utils/TypeUtils.java | 88 ++++
.../runtime/operator/AggregateOperatorTest.java | 40 +-
.../pinot/segment/spi/AggregationFunctionType.java | 186 ++++-----
24 files changed, 786 insertions(+), 1851 deletions(-)
diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
index 199619d14b..7792a0b4e0 100644
--- a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
+++ b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
@@ -26,7 +26,6 @@ import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
-import javax.annotation.Nullable;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.hep.HepRelVertex;
@@ -39,49 +38,45 @@ import org.apache.calcite.rel.hint.PinotHintOptions;
import org.apache.calcite.rel.hint.PinotHintStrategyTable;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.logical.LogicalAggregate;
-import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.PinotLogicalExchange;
-import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.PinotSqlAggFunction;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.PinotOperatorTable;
-import org.apache.calcite.sql.type.SqlReturnTypeInference;
+import org.apache.calcite.sql.type.OperandTypes;
+import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
+import org.apache.calcite.util.Util;
+import org.apache.calcite.util.mapping.Mapping;
+import org.apache.calcite.util.mapping.MappingType;
+import org.apache.calcite.util.mapping.Mappings;
import org.apache.pinot.query.planner.plannode.AggregateNode.AggType;
import org.apache.pinot.segment.spi.AggregationFunctionType;
/**
- * Special rule for Pinot, this rule is fixed to generate a 3-stage aggregation split between the
- * (1) non-data-locale Pinot server agg stage, (2) the data-locale Pinot intermediate agg stage, and
- * (3) final result Pinot final agg stage.
+ * Special rule for Pinot, this rule is fixed to generate a 2-stage aggregation split between the
+ * (1) non-data-locale Pinot server agg stage, and (2) the data-locale Pinot intermediate agg stage.
*
* Pinot uses special intermediate data representation for partially aggregated results, thus we can't use
* {@link org.apache.calcite.rel.rules.AggregateReduceFunctionsRule} to reduce complex aggregation.
*
- * This rule is here to introduces Pinot-special aggregation splits. In-general, all aggregations are split into
- * final-stage AGG, intermediate-stage AGG, and server-stage AGG with the same naming. E.g.
+ * This rule is here to introduces Pinot-special aggregation splits. In-general there are several options:
+ * <ul>
+ * <li>`aggName`__DIRECT</li>
+ * <li>`aggName`__LEAF + `aggName`__FINAL</li>
+ * <li>`aggName`__LEAF [+ `aggName`__INTERMEDIATE] + `aggName`__FINAL</li>
+ * </ul>
*
- * COUNT(*) transforms into: COUNT(*)_SERVER --> COUNT(*)_INTERMEDIATE --> COUNT(*)_FINAL, where
- * COUNT(*)_SERVER produces TUPLE[ COUNT(data), GROUP_BY_KEY ]
- * COUNT(*)_INTERMEDIATE produces TUPLE[ SUM(COUNT(*)_SERVER), GROUP_BY_KEY ] (intermediate result here is the count)
- * COUNT(*)_FINAL produces the final TUPLE[ FINAL_COUNT, GROUP_BY_KEY ]
- *
- * Taking an example of a function which has a different intermediate object representation than the final result:
- * KURTOSIS(*) transforms into: 4THMOMENT(*)_SERVER --> 4THMOMENT(*)_INTERMEDIATE --> KURTOSIS(*)_FINAL, where
- * FOURTHMOMENT(*)_SERVER produces TUPLE[ 4THMOMENT(data) object, GROUP_BY_KEY ] (input: rowType, output: object)
- * FOURTHMOMENT(*)_INTERMEDIATE produces TUPLE[ 4THMOMENT(4THMOMENT(*)_SERVER), GROUP_BY_KEY ] (input, output: object)
- * KURTOSIS(*)_FINAL produces the final TUPLE[ KURTOSIS(4THMOMENT(*)_INTERMEDIATE), GROUP_BY_KEY ]
- * (input: object, output: double)
- *
- * However, the suffix _SERVER/_INTERMEDIATE/_FINAL is merely a SQL hint to the Aggregate operator and will be
- * translated into correct, actual operator chain during Physical plan.
+ * for example:
+ * - COUNT(*) with a GROUP_BY_KEY transforms into: COUNT(*)__LEAF --> COUNT(*)__FINAL, where
+ * - COUNT(*)__LEAF produces TUPLE[ SUM(1), GROUP_BY_KEY ]
+ * - COUNT(*)__FINAL produces TUPLE[ SUM(COUNT(*)__LEAF), GROUP_BY_KEY ]
*/
public class PinotAggregateExchangeNodeInsertRule extends RelOptRule {
public static final PinotAggregateExchangeNodeInsertRule INSTANCE =
@@ -117,336 +112,241 @@ public class PinotAggregateExchangeNodeInsertRule extends RelOptRule {
Aggregate oldAggRel = call.rel(0);
ImmutableList<RelHint> oldHints = oldAggRel.getHints();
- // If the "is_partitioned_by_group_by_keys" aggregate hint option is set, just add additional hints indicating
- // this is a single stage aggregation and intermediate stage. This only applies to GROUP BY aggregations.
+ Aggregate newAgg;
if (!oldAggRel.getGroupSet().isEmpty() && PinotHintStrategyTable.containsHintOption(oldHints,
PinotHintOptions.AGGREGATE_HINT_OPTIONS, PinotHintOptions.AggregateOptions.IS_PARTITIONED_BY_GROUP_BY_KEYS)) {
- // 1. attach intermediate agg and skip leaf stage RelHints to original agg.
+ // ------------------------------------------------------------------------
+ // If the "is_partitioned_by_group_by_keys" aggregate hint option is set, just add additional hints indicating
+ // this is a single stage aggregation.
ImmutableList<RelHint> newLeafAggHints =
new ImmutableList.Builder<RelHint>().addAll(oldHints).add(createAggHint(AggType.DIRECT)).build();
- Aggregate newAgg =
+ newAgg =
new LogicalAggregate(oldAggRel.getCluster(), oldAggRel.getTraitSet(), newLeafAggHints, oldAggRel.getInput(),
oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), oldAggRel.getAggCallList());
- call.transformTo(newAgg);
- return;
- }
-
- // If "is_skip_leaf_stage_group_by" SQLHint option is passed, the leaf stage aggregation is skipped.
- if (!oldAggRel.getGroupSet().isEmpty() && PinotHintStrategyTable.containsHintOption(oldHints,
+ } else if (!oldAggRel.getGroupSet().isEmpty() && PinotHintStrategyTable.containsHintOption(oldHints,
PinotHintOptions.AGGREGATE_HINT_OPTIONS,
PinotHintOptions.AggregateOptions.SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION)) {
- // This is not the default path. Use this group by optimization to skip leaf stage aggregation when aggregating
- // at leaf level could be wasted effort. eg: when cardinality of group by column is very high.
- Aggregate newAgg = (Aggregate) createPlanWithoutLeafAggregation(call);
- call.transformTo(newAgg);
- return;
+ // ------------------------------------------------------------------------
+ // If "is_skip_leaf_stage_group_by" SQLHint option is passed, the leaf stage aggregation is skipped.
+ newAgg = (Aggregate) createPlanWithExchangeDirectAggregation(call);
+ } else {
+ // ------------------------------------------------------------------------
+ newAgg = (Aggregate) createPlanWithLeafExchangeFinalAggregate(call);
}
+ call.transformTo(newAgg);
+ }
+ /**
+ * Aggregate node will be split into LEAF + exchange + FINAL.
+ * optionally we can insert INTERMEDIATE to reduce hotspot in the future.
+ */
+ private RelNode createPlanWithLeafExchangeFinalAggregate(RelOptRuleCall call) {
+ // TODO: add optional intermediate stage here when hinted.
+ Aggregate oldAggRel = call.rel(0);
// 1. attach leaf agg RelHint to original agg. Perform any aggregation call conversions necessary
- ImmutableList<RelHint> newLeafAggHints =
- new ImmutableList.Builder<RelHint>().addAll(oldHints).add(createAggHint(AggType.LEAF)).build();
- Aggregate newLeafAgg =
- new LogicalAggregate(oldAggRel.getCluster(), oldAggRel.getTraitSet(), newLeafAggHints, oldAggRel.getInput(),
- oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), convertLeafAggCalls(oldAggRel));
-
+ Aggregate leafAgg = convertAggForLeafInput(oldAggRel);
// 2. attach exchange.
List<Integer> groupSetIndices = ImmutableIntList.range(0, oldAggRel.getGroupCount());
PinotLogicalExchange exchange;
if (groupSetIndices.size() == 0) {
- exchange = PinotLogicalExchange.create(newLeafAgg, RelDistributions.hash(Collections.emptyList()));
+ exchange = PinotLogicalExchange.create(leafAgg, RelDistributions.hash(Collections.emptyList()));
} else {
- exchange = PinotLogicalExchange.create(newLeafAgg, RelDistributions.hash(groupSetIndices));
+ exchange = PinotLogicalExchange.create(leafAgg, RelDistributions.hash(groupSetIndices));
}
+ // 3. attach final agg stage.
+ return convertAggFromIntermediateInput(call, oldAggRel, exchange, AggType.FINAL);
+ }
- // 3. attach intermediate agg stage.
- RelNode newAggNode = makeNewIntermediateAgg(call, oldAggRel, exchange, AggType.INTERMEDIATE, null, null);
+ /**
+ * Use this group by optimization to skip leaf stage aggregation when aggregating at leaf level is not desired.
+ * Many situation could be wasted effort to do group-by on leaf, eg: when cardinality of group by column is very high.
+ */
+ private RelNode createPlanWithExchangeDirectAggregation(RelOptRuleCall call) {
+ Aggregate oldAggRel = call.rel(0);
+ ImmutableList<RelHint> oldHints = oldAggRel.getHints();
+ ImmutableList<RelHint> newHints =
+ new ImmutableList.Builder<RelHint>().addAll(oldHints).add(createAggHint(AggType.DIRECT)).build();
- // 4. attach final agg stage if aggregations are present.
- RelNode transformToAgg = newAggNode;
- if (oldAggRel.getAggCallList() != null && oldAggRel.getAggCallList().size() > 0) {
- transformToAgg = makeNewFinalAgg(call, oldAggRel, newAggNode);
+ // create project when there's none below the aggregate to reduce exchange overhead
+ RelNode childRel = ((HepRelVertex) oldAggRel.getInput()).getCurrentRel();
+ if (!(childRel instanceof Project)) {
+ return convertAggForExchangeDirectAggregate(call, newHints);
+ } else {
+ // create normal exchange
+ List<Integer> groupSetIndices = new ArrayList<>();
+ oldAggRel.getGroupSet().forEach(groupSetIndices::add);
+ PinotLogicalExchange exchange = PinotLogicalExchange.create(childRel, RelDistributions.hash(groupSetIndices));
+ return new LogicalAggregate(oldAggRel.getCluster(), oldAggRel.getTraitSet(), newHints, exchange,
+ oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), oldAggRel.getAggCallList());
}
+ }
- call.transformTo(transformToAgg);
+ /**
+ * The following is copied from {@link AggregateExtractProjectRule#onMatch(RelOptRuleCall)}
+ * with modification to insert an exchange in between the Aggregate and Project
+ */
+ private RelNode convertAggForExchangeDirectAggregate(RelOptRuleCall call, ImmutableList<RelHint> newHints) {
+ final Aggregate aggregate = call.rel(0);
+ final RelNode input = aggregate.getInput();
+ // Compute which input fields are used.
+ // 1. group fields are always used
+ final ImmutableBitSet.Builder inputFieldsUsed =
+ aggregate.getGroupSet().rebuild();
+ // 2. agg functions
+ for (AggregateCall aggCall : aggregate.getAggCallList()) {
+ for (int i : aggCall.getArgList()) {
+ inputFieldsUsed.set(i);
+ }
+ if (aggCall.filterArg >= 0) {
+ inputFieldsUsed.set(aggCall.filterArg);
+ }
+ }
+ final RelBuilder relBuilder1 = call.builder().push(input);
+ final List<RexNode> projects = new ArrayList<>();
+ final Mapping mapping =
+ Mappings.create(MappingType.INVERSE_SURJECTION,
+ aggregate.getInput().getRowType().getFieldCount(),
+ inputFieldsUsed.cardinality());
+ int j = 0;
+ for (int i : inputFieldsUsed.build()) {
+ projects.add(relBuilder1.field(i));
+ mapping.set(i, j++);
+ }
+ relBuilder1.project(projects);
+ final ImmutableBitSet newGroupSet =
+ Mappings.apply(mapping, aggregate.getGroupSet());
+ Project project = (Project) relBuilder1.build();
+
+ // ------------------------------------------------------------------------
+ PinotLogicalExchange exchange = PinotLogicalExchange.create(project, RelDistributions.hash(newGroupSet.asList()));
+ // ------------------------------------------------------------------------
+
+ final RelBuilder relBuilder2 = call.builder().push(exchange);
+ final List<ImmutableBitSet> newGroupSets =
+ aggregate.getGroupSets().stream()
+ .map(bitSet -> Mappings.apply(mapping, bitSet))
+ .collect(Util.toImmutableList());
+ final List<RelBuilder.AggCall> newAggCallList =
+ aggregate.getAggCallList().stream()
+ .map(aggCall -> relBuilder2.aggregateCall(aggCall, mapping))
+ .collect(Util.toImmutableList());
+ final RelBuilder.GroupKey groupKey =
+ relBuilder2.groupKey(newGroupSet, newGroupSets);
+ relBuilder2.aggregate(groupKey, newAggCallList).hints(newHints);
+ return relBuilder2.build();
}
- private List<AggregateCall> convertLeafAggCalls(Aggregate oldAggRel) {
+ private Aggregate convertAggForLeafInput(Aggregate oldAggRel) {
List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
List<AggregateCall> newCalls = new ArrayList<>();
for (AggregateCall oldCall : oldCalls) {
- newCalls.add(buildAggregateCall(oldAggRel.getInput(), oldCall.getArgList(), oldAggRel.getGroupCount(), oldCall,
- oldCall, false));
+ newCalls.add(buildAggregateCall(oldAggRel.getInput(), oldCall, oldCall.getArgList(), oldAggRel.getGroupCount(),
+ AggType.LEAF));
}
- return newCalls;
+ ImmutableList<RelHint> oldHints = oldAggRel.getHints();
+ ImmutableList<RelHint> newHints =
+ new ImmutableList.Builder<RelHint>().addAll(oldHints).add(createAggHint(AggType.LEAF)).build();
+ return new LogicalAggregate(oldAggRel.getCluster(), oldAggRel.getTraitSet(), newHints, oldAggRel.getInput(),
+ oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), newCalls);
}
- private RelNode makeNewIntermediateAgg(RelOptRuleCall ruleCall, Aggregate oldAggRel, PinotLogicalExchange exchange,
- AggType aggType, @Nullable List<Integer> argList, @Nullable List<Integer> groupByList) {
-
+ private RelNode convertAggFromIntermediateInput(RelOptRuleCall ruleCall, Aggregate oldAggRel,
+ PinotLogicalExchange exchange, AggType aggType) {
// add the exchange as the input node to the relation builder.
RelBuilder relBuilder = ruleCall.builder();
relBuilder.push(exchange);
- // make input ref to the exchange after the leaf aggregate.
+ // make input ref to the exchange after the leaf aggregate, all groups should be at the front
RexBuilder rexBuilder = exchange.getCluster().getRexBuilder();
final int nGroups = oldAggRel.getGroupCount();
for (int i = 0; i < nGroups; i++) {
rexBuilder.makeInputRef(oldAggRel, i);
}
- // create new aggregate function calls from exchange input.
- List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
List<AggregateCall> newCalls = new ArrayList<>();
Map<AggregateCall, RexNode> aggCallMapping = new HashMap<>();
- for (int oldCallIndex = 0; oldCallIndex < oldCalls.size(); oldCallIndex++) {
- AggregateCall oldCall = oldCalls.get(oldCallIndex);
- convertIntermediateAggCall(rexBuilder, oldAggRel, oldCallIndex, oldCall, newCalls, aggCallMapping,
- aggType, argList, exchange);
- }
-
- // create new aggregate relation.
- ImmutableList<RelHint> orgHints = oldAggRel.getHints();
- // if the aggregation isn't split between intermediate and final stages, indicate that this is a single stage
- // aggregation so that the execution engine knows whether to aggregate or merge
- ImmutableList<RelHint> newIntermediateAggHints =
- new ImmutableList.Builder<RelHint>().addAll(orgHints).add(createAggHint(aggType)).build();
- ImmutableBitSet groupSet = groupByList == null ? ImmutableBitSet.range(nGroups) : ImmutableBitSet.of(groupByList);
- relBuilder.aggregate(
- relBuilder.groupKey(groupSet, ImmutableList.of(groupSet)),
- newCalls);
- relBuilder.hints(newIntermediateAggHints);
- return relBuilder.build();
- }
-
- /**
- * convert aggregate call based on the intermediate stage input.
- *
- * <p>Note that the intermediate stage input only supports splittable aggregators such as SUM/MIN/MAX.
- * All non-splittable aggregator must be converted into splittable aggregator first.
- */
- private static void convertIntermediateAggCall(RexBuilder rexBuilder, Aggregate oldAggRel, int oldCallIndex,
- AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping,
- AggType aggType, List<Integer> argList, PinotLogicalExchange exchange) {
- final int nGroups = oldAggRel.getGroupCount();
- final SqlAggFunction oldAggregation = oldCall.getAggregation();
-
- List<Integer> newArgList;
- if (aggType.isInputIntermediateFormat()) {
- // Make sure COUNT in the intermediate stage takes an argument
- List<Integer> oldArgList = (oldAggregation.getKind() == SqlKind.COUNT && !oldCall.isDistinct())
- ? Collections.singletonList(oldCallIndex)
- : oldCall.getArgList();
- newArgList = convertArgList(nGroups + oldCallIndex, oldArgList);
- } else {
- newArgList = oldCall.getArgList().size() == 0 ? Collections.emptyList()
- : Collections.singletonList(argList.get(oldCallIndex));
- }
- AggregateCall newCall = buildAggregateCall(exchange, newArgList, nGroups, oldCall, oldCall, false);
- rexBuilder.addAggCall(newCall, nGroups, newCalls, aggCallMapping, oldAggRel.getInput()::fieldIsNullable);
- }
-
- private RelNode makeNewFinalAgg(RelOptRuleCall ruleCall, Aggregate oldAggRel, RelNode newIntAggNode) {
- // add the intermediate agg node as the input node to the relation builder.
- RelBuilder relBuilder = ruleCall.builder();
- relBuilder.push(newIntAggNode);
-
- Aggregate aggIntNode = (Aggregate) newIntAggNode;
-
- // make input ref to the intermediate agg node after the leaf aggregate.
- RexBuilder rexBuilder = newIntAggNode.getCluster().getRexBuilder();
- final int nGroups = aggIntNode.getGroupCount();
- for (int i = 0; i < nGroups; i++) {
- rexBuilder.makeInputRef(aggIntNode, i);
- }
-
- // create new aggregate function calls from intermediate agg input.
+ // create new aggregate function calls from exchange input, all aggCalls are followed one by one from exchange
+ // b/c the exchange produces intermediate results, thus the input to the newCall will be indexed at
+ // [nGroup + oldCallIndex]
List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
- List<AggregateCall> oldCallsIntAgg = aggIntNode.getAggCallList();
- List<AggregateCall> newCalls = new ArrayList<>();
- Map<AggregateCall, RexNode> aggCallMapping = new HashMap<>();
-
for (int oldCallIndex = 0; oldCallIndex < oldCalls.size(); oldCallIndex++) {
AggregateCall oldCall = oldCalls.get(oldCallIndex);
- AggregateCall oldCallIntAgg = oldCallsIntAgg.get(oldCallIndex);
- convertFinalAggCall(rexBuilder, aggIntNode, oldCallIndex, oldCall, oldCallIntAgg, newCalls, aggCallMapping);
+ // intermediate stage input only supports single argument inputs.
+ List<Integer> argList = Collections.singletonList(nGroups + oldCallIndex);
+ AggregateCall newCall = buildAggregateCall(exchange, oldCall, argList, nGroups, aggType);
+ rexBuilder.addAggCall(newCall, nGroups, newCalls, aggCallMapping, oldAggRel.getInput()::fieldIsNullable);
}
// create new aggregate relation.
ImmutableList<RelHint> orgHints = oldAggRel.getHints();
- ImmutableList<RelHint> newIntermediateAggHints =
- new ImmutableList.Builder<RelHint>().addAll(orgHints).add(createAggHint(AggType.FINAL)).build();
+ ImmutableList<RelHint> newAggHint =
+ new ImmutableList.Builder<RelHint>().addAll(orgHints).add(createAggHint(aggType)).build();
ImmutableBitSet groupSet = ImmutableBitSet.range(nGroups);
- relBuilder.aggregate(
- relBuilder.groupKey(groupSet, ImmutableList.of(groupSet)),
- newCalls);
- relBuilder.hints(newIntermediateAggHints);
+ relBuilder.aggregate(relBuilder.groupKey(groupSet, ImmutableList.of(groupSet)), newCalls);
+ relBuilder.hints(newAggHint);
return relBuilder.build();
}
- /**
- * convert aggregate call based on the final stage input.
- */
- private static void convertFinalAggCall(RexBuilder rexBuilder, Aggregate inputAggRel, int oldCallIndex,
- AggregateCall oldCall, AggregateCall oldCallIntAgg, List<AggregateCall> newCalls, Map<AggregateCall,
- RexNode> aggCallMapping) {
- final int nGroups = inputAggRel.getGroupCount();
- final SqlAggFunction oldAggregation = oldCallIntAgg.getAggregation();
- // Make sure COUNT in the final stage takes an argument
- List<Integer> oldArgList = (oldAggregation.getKind() == SqlKind.COUNT && !oldCallIntAgg.isDistinct())
- ? Collections.singletonList(oldCallIndex)
- : oldCallIntAgg.getArgList();
- List<Integer> newArgList = convertArgList(nGroups + oldCallIndex, oldArgList);
- AggregateCall newCall = buildAggregateCall(inputAggRel, newArgList, nGroups, oldCallIntAgg, oldCall, true);
- rexBuilder.addAggCall(newCall, nGroups, newCalls, aggCallMapping, inputAggRel.getInput()::fieldIsNullable);
- }
-
- private static AggregateCall buildAggregateCall(RelNode input, List<Integer> newArgList, int numberGroups,
- AggregateCall inputCall, AggregateCall functionNameCall, boolean isFinalStage) {
- final SqlAggFunction oldAggregation = functionNameCall.getAggregation();
- final SqlKind aggKind = oldAggregation.getKind();
- String functionName = getFunctionNameFromAggregateCall(functionNameCall);
+ private static AggregateCall buildAggregateCall(RelNode inputNode, AggregateCall orgAggCall, List<Integer> argList,
+ int numberGroups, AggType aggType) {
+ final SqlAggFunction oldAggFunction = orgAggCall.getAggregation();
+ final SqlKind aggKind = oldAggFunction.getKind();
+ String functionName = getFunctionNameFromAggregateCall(orgAggCall);
+ AggregationFunctionType functionType = AggregationFunctionType.getAggregationFunctionType(functionName);
// Check only the supported AGG functions are provided.
- validateAggregationFunctionIsSupported(functionName, aggKind);
-
- AggregationFunctionType type = AggregationFunctionType.getAggregationFunctionType(functionName);
- // Use the actual function name and return type for final stage to ensure that for aggregation functions that share
- // leaf and intermediate functions, we can correctly extract the correct final result. e.g. KURTOSIS and SKEWNESS
- // both use FOURTHMOMENT
- String aggregationFunctionName = isFinalStage ? type.getName().toUpperCase(Locale.ROOT)
- : type.getIntermediateFunctionName().toUpperCase(Locale.ROOT);
- SqlReturnTypeInference returnTypeInference = isFinalStage ? type.getSqlReturnTypeInference()
- : type.getSqlIntermediateReturnTypeInference();
- SqlAggFunction sqlAggFunction =
- new PinotSqlAggFunction(aggregationFunctionName, type.getSqlIdentifier(), type.getSqlKind(),
- returnTypeInference, type.getSqlOperandTypeInference(), type.getSqlOperandTypeChecker(),
- type.getSqlFunctionCategory());
+ validateAggregationFunctionIsSupported(functionType.getName(), aggKind);
+ // create the aggFunction
+ SqlAggFunction sqlAggFunction;
+ if (functionType.getIntermediateReturnTypeInference() != null) {
+ switch (aggType) {
+ case LEAF:
+ sqlAggFunction = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null,
+ functionType.getSqlKind(), functionType.getIntermediateReturnTypeInference(), null,
+ functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory());
+ break;
+ case INTERMEDIATE:
+ sqlAggFunction = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null,
+ functionType.getSqlKind(), functionType.getIntermediateReturnTypeInference(), null,
+ OperandTypes.ANY, functionType.getSqlFunctionCategory());
+ break;
+ case FINAL:
+ sqlAggFunction = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null,
+ functionType.getSqlKind(), ReturnTypes.explicit(orgAggCall.getType()), null,
+ OperandTypes.ANY, functionType.getSqlFunctionCategory());
+ break;
+ default:
+ throw new UnsupportedOperationException("Unsuppoted aggType: " + aggType + " for " + functionName);
+ }
+ } else {
+ sqlAggFunction = oldAggFunction;
+ }
return AggregateCall.create(sqlAggFunction,
- functionName.equals("distinctCount") || inputCall.isDistinct(),
- inputCall.isApproximate(),
- inputCall.ignoreNulls(),
- newArgList,
- inputCall.filterArg,
- inputCall.distinctKeys,
- inputCall.collation,
+ functionName.equals("distinctCount") || orgAggCall.isDistinct(),
+ orgAggCall.isApproximate(),
+ orgAggCall.ignoreNulls(),
+ argList,
+ orgAggCall.filterArg,
+ orgAggCall.distinctKeys,
+ orgAggCall.collation,
numberGroups,
- input,
+ inputNode,
null,
null);
}
- private static List<Integer> convertArgList(int oldCallIndexWithShift, List<Integer> argList) {
- Preconditions.checkArgument(argList.size() <= 1,
- "Unable to convert call as the argList contains more than 1 argument");
- return argList.size() == 1 ? Collections.singletonList(oldCallIndexWithShift) : Collections.emptyList();
- }
-
private static String getFunctionNameFromAggregateCall(AggregateCall aggregateCall) {
return aggregateCall.getAggregation().getName().equalsIgnoreCase("COUNT") && aggregateCall.isDistinct()
? "distinctCount" : aggregateCall.getAggregation().getName();
}
private static void validateAggregationFunctionIsSupported(String functionName, SqlKind aggKind) {
- Preconditions.checkState(PinotOperatorTable.isAggregationFunctionRegisteredWithOperatorTable(functionName)
- || PinotOperatorTable.isAggregationKindSupported(aggKind),
+ Preconditions.checkState(PinotOperatorTable.isAggregationFunctionRegisteredWithOperatorTable(functionName),
String.format("Failed to create aggregation. Unsupported SQL aggregation kind: %s or function name: %s. "
+ "Only splittable aggregation functions are supported!", aggKind, functionName));
}
- private RelNode createPlanWithoutLeafAggregation(RelOptRuleCall call) {
- Aggregate oldAggRel = call.rel(0);
- RelNode childRel = ((HepRelVertex) oldAggRel.getInput()).getCurrentRel();
- LogicalProject project;
-
- List<Integer> newAggArgColumns = new ArrayList<>();
- List<Integer> newAggGroupByColumns = new ArrayList<>();
-
- // 1. Create the LogicalProject node if it does not exist. This is to send only the relevant columns over
- // the wire for intermediate aggregation.
- if (childRel instanceof Project) {
- // Avoid creating a new LogicalProject if the child node of aggregation is already a project node.
- project = (LogicalProject) childRel;
- newAggArgColumns = fetchNewAggArgCols(oldAggRel.getAggCallList());
- newAggGroupByColumns = oldAggRel.getGroupSet().asList();
- } else {
- // Create a leaf stage project. This is done so that only the required columns are sent over the wire for
- // intermediate aggregation. If there are multiple aggregations on the same column, the column is projected
- // only once.
- project = createLogicalProjectForAggregate(oldAggRel, newAggArgColumns, newAggGroupByColumns);
- }
-
- // 2. Create an exchange on top of the LogicalProject.
- PinotLogicalExchange exchange = PinotLogicalExchange.create(project, RelDistributions.hash(newAggGroupByColumns));
-
- // 3. Create an intermediate stage aggregation.
- RelNode newAggNode =
- makeNewIntermediateAgg(call, oldAggRel, exchange, AggType.LEAF, newAggArgColumns, newAggGroupByColumns);
-
- // 4. Create the final agg stage node on top of the intermediate agg if aggregations are present.
- RelNode transformToAgg = newAggNode;
- if (oldAggRel.getAggCallList() != null && oldAggRel.getAggCallList().size() > 0) {
- transformToAgg = makeNewFinalAgg(call, oldAggRel, newAggNode);
- }
- return transformToAgg;
- }
-
- private LogicalProject createLogicalProjectForAggregate(Aggregate oldAggRel, List<Integer> newAggArgColumns,
- List<Integer> newAggGroupByCols) {
- RelNode childRel = ((HepRelVertex) oldAggRel.getInput()).getCurrentRel();
- RexBuilder childRexBuilder = childRel.getCluster().getRexBuilder();
- List<RelDataTypeField> fieldList = childRel.getRowType().getFieldList();
-
- List<RexNode> projectColRexNodes = new ArrayList<>();
- List<String> projectColNames = new ArrayList<>();
- // Maintains a mapping from the column to the corresponding index in projectColRexNodes.
- Map<Integer, Integer> projectSet = new HashMap<>();
-
- int projectIndex = 0;
- for (int group : oldAggRel.getGroupSet().asSet()) {
- projectColNames.add(fieldList.get(group).getName());
- projectColRexNodes.add(childRexBuilder.makeInputRef(childRel, group));
- projectSet.put(group, projectColRexNodes.size() - 1);
- newAggGroupByCols.add(projectIndex++);
- }
-
- List<AggregateCall> oldAggCallList = oldAggRel.getAggCallList();
- for (AggregateCall aggregateCall : oldAggCallList) {
- List<Integer> argList = aggregateCall.getArgList();
- if (argList.size() == 0) {
- newAggArgColumns.add(-1);
- continue;
- }
- for (Integer col : argList) {
- if (!projectSet.containsKey(col)) {
- projectColRexNodes.add(childRexBuilder.makeInputRef(childRel, col));
- projectColNames.add(fieldList.get(col).getName());
- projectSet.put(col, projectColRexNodes.size() - 1);
- newAggArgColumns.add(projectColRexNodes.size() - 1);
- } else {
- newAggArgColumns.add(projectSet.get(col));
- }
- }
- }
-
- return LogicalProject.create(childRel, Collections.emptyList(), projectColRexNodes, projectColNames);
- }
-
- private List<Integer> fetchNewAggArgCols(List<AggregateCall> oldAggCallList) {
- List<Integer> newAggArgColumns = new ArrayList<>();
-
- for (AggregateCall aggregateCall : oldAggCallList) {
- if (aggregateCall.getArgList().size() == 0) {
- // This can be true for COUNT. Add a placeholder value which will be ignored.
- newAggArgColumns.add(-1);
- continue;
- }
- newAggArgColumns.addAll(aggregateCall.getArgList());
- }
-
- return newAggArgColumns;
- }
-
private static RelHint createAggHint(AggType aggType) {
return RelHint.builder(PinotHintOptions.INTERNAL_AGG_OPTIONS)
.hintOption(PinotHintOptions.InternalAggregateOptions.AGG_TYPE, aggType.name())
diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateReduceFunctionsRule.java b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateReduceFunctionsRule.java
deleted file mode 100644
index a0d9634667..0000000000
--- a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateReduceFunctionsRule.java
+++ /dev/null
@@ -1,427 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.calcite.rel.rules;
-
-import com.google.common.collect.ImmutableSet;
-import java.math.BigDecimal;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Locale;
-import java.util.Map;
-import java.util.Set;
-import org.apache.calcite.plan.RelOptRule;
-import org.apache.calcite.plan.RelOptRuleCall;
-import org.apache.calcite.rel.core.Aggregate;
-import org.apache.calcite.rel.core.AggregateCall;
-import org.apache.calcite.rel.type.RelDataType;
-import org.apache.calcite.rel.type.RelDataTypeFactory;
-import org.apache.calcite.rex.RexBuilder;
-import org.apache.calcite.rex.RexNode;
-import org.apache.calcite.sql.PinotSqlAggFunction;
-import org.apache.calcite.sql.SqlAggFunction;
-import org.apache.calcite.sql.SqlFunction;
-import org.apache.calcite.sql.SqlFunctionCategory;
-import org.apache.calcite.sql.SqlKind;
-import org.apache.calcite.sql.fun.SqlStdOperatorTable;
-import org.apache.calcite.tools.RelBuilder;
-import org.apache.calcite.tools.RelBuilderFactory;
-import org.apache.calcite.util.CompositeList;
-import org.apache.calcite.util.Util;
-import org.apache.pinot.segment.spi.AggregationFunctionType;
-
-
-/**
- * Note: This class copies the logic for reducing SUM and AVG from {@link AggregateReduceFunctionsRule} with some
- * changes to use our Pinot defined operand type checkers and return types. This is necessary otherwise the v1
- * aggregations won't work in the v2 engine due to type issues. Once we fix the return types for the v1 aggregation
- * functions the logic for AVG and SUM can be removed. We also had to resort to using an AVG_REDUCE scalar function
- * due to null handling issues with DIVIDE (returning null on count = 0 via a CASE statement was also not possible
- * as the types of the columns were all non-null and Calcite marks nullable and non-nullable columns as incompatible).
- *
- * We added additional logic to handle typecasting MIN / MAX functions for EVERY / SOME aggregation functions in Calcite
- * which internally uses MIN / MAX with boolean return types. This was necessary because the v1 aggregations for
- * MIN / MAX always return DOUBLE and this caused type issues for certain queries that utilize Calcite's EVERY / SOME
- * aggregation functions.
- *
- * Planner rule that reduces aggregate functions in
- * {@link org.apache.calcite.rel.core.Aggregate}s to simpler forms.
- *
- * <p>Rewrites:
- * <ul>
- *
- * <li>AVG(x) → SUM(x) / COUNT(x)
- *
- * </ul>
- *
- * <p>Since many of these rewrites introduce multiple occurrences of simpler
- * forms like {@code COUNT(x)}, the rule gathers common sub-expressions as it
- * goes.
- *
- * @see CoreRules#AGGREGATE_REDUCE_FUNCTIONS
- */
-public class PinotAggregateReduceFunctionsRule
- extends RelOptRule {
-
- public static final PinotAggregateReduceFunctionsRule INSTANCE =
- new PinotAggregateReduceFunctionsRule(PinotRuleUtils.PINOT_REL_FACTORY);
- //~ Static fields/initializers ---------------------------------------------
-
- protected PinotAggregateReduceFunctionsRule(RelBuilderFactory factory) {
- super(operand(Aggregate.class, any()), factory, null);
- }
-
- private final Set<SqlKind> _functionsToReduce = ImmutableSet.<SqlKind>builder().addAll(SqlKind.AVG_AGG_FUNCTIONS)
- .add(SqlKind.SUM).add(SqlKind.MAX).add(SqlKind.MIN).build();
-
- //~ Constructors -----------------------------------------------------------
-
-
-
- //~ Methods ----------------------------------------------------------------
-
- @Override public boolean matches(RelOptRuleCall call) {
- if (!super.matches(call)) {
- return false;
- }
- Aggregate oldAggRel = (Aggregate) call.rels[0];
- return containsAvgStddevVarCall(oldAggRel.getAggCallList());
- }
-
- @Override public void onMatch(RelOptRuleCall ruleCall) {
- Aggregate oldAggRel = (Aggregate) ruleCall.rels[0];
- reduceAggs(ruleCall, oldAggRel);
- }
-
- /**
- * Returns whether any of the aggregates are calls to AVG, STDDEV_*, VAR_*.
- *
- * @param aggCallList List of aggregate calls
- */
- private boolean containsAvgStddevVarCall(List<AggregateCall> aggCallList) {
- return aggCallList.stream().anyMatch(this::canReduce);
- }
-
- /** Returns whether this rule can reduce a given aggregate function call. */
- public boolean canReduce(AggregateCall call) {
- return _functionsToReduce.contains(call.getAggregation().getKind());
- }
-
- /**
- * Reduces calls to functions AVG, SUM, MIN, MAX if the function is
- * present in {@link PinotAggregateReduceFunctionsRule#_functionsToReduce}
- *
- * <p>It handles newly generated common subexpressions since this was done
- * at the sql2rel stage.
- */
- private void reduceAggs(
- RelOptRuleCall ruleCall,
- Aggregate oldAggRel) {
- RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
-
- List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
- final int groupCount = oldAggRel.getGroupCount();
-
- final List<AggregateCall> newCalls = new ArrayList<>();
- final Map<AggregateCall, RexNode> aggCallMapping = new HashMap<>();
-
- final List<RexNode> projList = new ArrayList<>();
-
- // pass through group key
- for (int i = 0; i < groupCount; i++) {
- projList.add(rexBuilder.makeInputRef(oldAggRel, i));
- }
-
- // List of input expressions. If a particular aggregate needs more, it
- // will add an expression to the end, and we will create an extra
- // project.
- final RelBuilder relBuilder = ruleCall.builder();
- relBuilder.push(oldAggRel.getInput());
- final List<RexNode> inputExprs = new ArrayList<>(relBuilder.fields());
-
- // create new aggregate function calls and rest of project list together
- for (AggregateCall oldCall : oldCalls) {
- projList.add(
- reduceAgg(
- oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs));
- }
-
- final int extraArgCount =
- inputExprs.size() - relBuilder.peek().getRowType().getFieldCount();
- if (extraArgCount > 0) {
- relBuilder.project(inputExprs,
- CompositeList.of(
- relBuilder.peek().getRowType().getFieldNames(),
- Collections.nCopies(extraArgCount, null)));
- }
- newAggregateRel(relBuilder, oldAggRel, newCalls);
- newCalcRel(relBuilder, oldAggRel.getRowType(), projList);
- ruleCall.transformTo(relBuilder.build());
- }
-
- private RexNode reduceAgg(
- Aggregate oldAggRel,
- AggregateCall oldCall,
- List<AggregateCall> newCalls,
- Map<AggregateCall, RexNode> aggCallMapping,
- List<RexNode> inputExprs) {
- if (canReduce(oldCall)) {
- final Integer y;
- final Integer x;
- final SqlKind kind = oldCall.getAggregation().getKind();
- switch (kind) {
- case SUM:
- // replace original SUM(x) with
- // case COUNT(x) when 0 then null else SUM0(x) end
- return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
- case AVG:
- // replace original AVG(x) with SUM(x) / COUNT(x) via an AVG_REDUCE scalar function for null handling
- return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs);
- case MIN:
- case MAX:
- // typecast to oldCall type (BOOLEAN) if needed to handle EVERY / SOME aggregations. This is essentially a
- // no-op typecasting for normal MIN / MAX aggregations.
- return reduceMinMax(oldAggRel, oldCall, newCalls, aggCallMapping);
- default:
- throw Util.unexpected(kind);
- }
- } else {
- // anything else: preserve original call
- RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
- final int nGroups = oldAggRel.getGroupCount();
- return rexBuilder.addAggCall(oldCall,
- nGroups,
- newCalls,
- aggCallMapping,
- oldAggRel.getInput()::fieldIsNullable);
- }
- }
-
- private static RexNode reduceAvg(
- Aggregate oldAggRel,
- AggregateCall oldCall,
- List<AggregateCall> newCalls,
- Map<AggregateCall, RexNode> aggCallMapping,
- @SuppressWarnings("unused") List<RexNode> inputExprs) {
- final int nGroups = oldAggRel.getGroupCount();
- final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
-
- AggregationFunctionType functionTypeSum = AggregationFunctionType.SUM;
- SqlAggFunction sumAggFunc = new PinotSqlAggFunction(functionTypeSum.getName().toUpperCase(),
- functionTypeSum.getSqlIdentifier(), functionTypeSum.getSqlKind(), functionTypeSum.getSqlReturnTypeInference(),
- functionTypeSum.getSqlOperandTypeInference(), functionTypeSum.getSqlOperandTypeChecker(),
- functionTypeSum.getSqlFunctionCategory());
-
- final AggregateCall sumCall =
- AggregateCall.create(sumAggFunc,
- oldCall.isDistinct(),
- oldCall.isApproximate(),
- oldCall.ignoreNulls(),
- oldCall.getArgList(),
- oldCall.filterArg,
- oldCall.distinctKeys,
- oldCall.collation,
- oldAggRel.getGroupCount(),
- oldAggRel.getInput(),
- null,
- null);
- final AggregateCall countCall =
- AggregateCall.create(SqlStdOperatorTable.COUNT,
- oldCall.isDistinct(),
- oldCall.isApproximate(),
- oldCall.ignoreNulls(),
- oldCall.getArgList(),
- oldCall.filterArg,
- oldCall.distinctKeys,
- oldCall.collation,
- oldAggRel.getGroupCount(),
- oldAggRel.getInput(),
- null,
- null);
-
- // NOTE: these references are with respect to the output
- // of newAggRel
- RexNode numeratorRef =
- rexBuilder.addAggCall(sumCall,
- nGroups,
- newCalls,
- aggCallMapping,
- oldAggRel.getInput()::fieldIsNullable);
- final RexNode denominatorRef =
- rexBuilder.addAggCall(countCall,
- nGroups,
- newCalls,
- aggCallMapping,
- oldAggRel.getInput()::fieldIsNullable);
-
- final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
- final RelDataType avgType = typeFactory.createTypeWithNullability(
- oldCall.getType(), numeratorRef.getType().isNullable());
- numeratorRef = rexBuilder.ensureType(avgType, numeratorRef, true);
-
- // TODO: Find a way to correctly use the DIVIDE binary operator instead of reduce function with a COUNT = 0
- // check. Today this does not work due to all types being declared as "NOT NULL", but returning null violates
- // this
- // Special casing AVG to use a scalar function for reducing the results.
- AggregationFunctionType type =
- AggregationFunctionType.getAggregationFunctionType(oldCall.getAggregation().getName());
- SqlFunction function = new SqlFunction(type.getReduceFunctionName().toUpperCase(Locale.ROOT),
- SqlKind.OTHER_FUNCTION, type.getSqlReduceReturnTypeInference(), null,
- type.getSqlReduceOperandTypeChecker(), SqlFunctionCategory.USER_DEFINED_FUNCTION);
- List<RexNode> functionArgs = Arrays.asList(numeratorRef, denominatorRef);
-
- // Use our own reducer instead of divide for null/0 count handling
- final RexNode reduceRef = rexBuilder.makeCall(function, functionArgs);
- return rexBuilder.makeCast(oldCall.getType(), reduceRef);
- }
-
- private static RexNode reduceSum(
- Aggregate oldAggRel,
- AggregateCall oldCall,
- List<AggregateCall> newCalls,
- Map<AggregateCall, RexNode> aggCallMapping) {
- final int nGroups = oldAggRel.getGroupCount();
- RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
-
- AggregationFunctionType functionTypeSum = AggregationFunctionType.SUM0;
- SqlAggFunction sumAggFunc = new PinotSqlAggFunction(functionTypeSum.getName().toUpperCase(),
- functionTypeSum.getSqlIdentifier(), functionTypeSum.getSqlKind(), functionTypeSum.getSqlReturnTypeInference(),
- functionTypeSum.getSqlOperandTypeInference(), functionTypeSum.getSqlOperandTypeChecker(),
- functionTypeSum.getSqlFunctionCategory());
-
- final AggregateCall sumZeroCall =
- AggregateCall.create(sumAggFunc, oldCall.isDistinct(),
- oldCall.isApproximate(), oldCall.ignoreNulls(),
- oldCall.getArgList(), oldCall.filterArg, oldCall.distinctKeys,
- oldCall.collation, oldAggRel.getGroupCount(), oldAggRel.getInput(),
- null, oldCall.name);
- final AggregateCall countCall =
- AggregateCall.create(SqlStdOperatorTable.COUNT,
- oldCall.isDistinct(),
- oldCall.isApproximate(),
- oldCall.ignoreNulls(),
- oldCall.getArgList(),
- oldCall.filterArg,
- oldCall.distinctKeys,
- oldCall.collation,
- oldAggRel.getGroupCount(),
- oldAggRel,
- null,
- null);
-
- // NOTE: these references are with respect to the output
- // of newAggRel
- RexNode sumZeroRef =
- rexBuilder.addAggCall(sumZeroCall,
- nGroups,
- newCalls,
- aggCallMapping,
- oldAggRel.getInput()::fieldIsNullable);
- if (!oldCall.getType().isNullable()) {
- // If SUM(x) is not nullable, the validator must have determined that
- // nulls are impossible (because the group is never empty and x is never
- // null). Therefore we translate to SUM0(x).
- return sumZeroRef;
- }
- RexNode countRef =
- rexBuilder.addAggCall(countCall,
- nGroups,
- newCalls,
- aggCallMapping,
- oldAggRel.getInput()::fieldIsNullable);
- return rexBuilder.makeCall(SqlStdOperatorTable.CASE,
- rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
- countRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO)),
- rexBuilder.makeNullLiteral(sumZeroRef.getType()),
- sumZeroRef);
- }
-
- private static RexNode reduceMinMax(
- Aggregate oldAggRel,
- AggregateCall oldCall,
- List<AggregateCall> newCalls,
- Map<AggregateCall, RexNode> aggCallMapping) {
- final int nGroups = oldAggRel.getGroupCount();
- final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
- String functionName = oldCall.getAggregation().getName();
-
- AggregationFunctionType functionType = AggregationFunctionType.getAggregationFunctionType(functionName);
- SqlAggFunction aggFunc = new PinotSqlAggFunction(functionType.getName().toUpperCase(),
- functionType.getSqlIdentifier(), functionType.getSqlKind(), functionType.getSqlReturnTypeInference(),
- functionType.getSqlOperandTypeInference(), functionType.getSqlOperandTypeChecker(),
- functionType.getSqlFunctionCategory());
-
- final AggregateCall newCall =
- AggregateCall.create(aggFunc,
- oldCall.isDistinct(),
- oldCall.isApproximate(),
- oldCall.ignoreNulls(),
- oldCall.getArgList(),
- oldCall.filterArg,
- oldCall.distinctKeys,
- oldCall.collation,
- oldAggRel.getGroupCount(),
- oldAggRel.getInput(),
- null,
- null);
-
- RexNode ref =
- rexBuilder.addAggCall(newCall,
- nGroups,
- newCalls,
- aggCallMapping,
- oldAggRel.getInput()::fieldIsNullable);
- return rexBuilder.makeCast(oldCall.getType(), ref);
- }
-
- /**
- * Does a shallow clone of oldAggRel and updates aggCalls. Could be refactored
- * into Aggregate and subclasses - but it's only needed for some
- * subclasses.
- *
- * @param relBuilder Builder of relational expressions; at the top of its
- * stack is its input
- * @param oldAggregate LogicalAggregate to clone.
- * @param newCalls New list of AggregateCalls
- */
- protected void newAggregateRel(RelBuilder relBuilder,
- Aggregate oldAggregate,
- List<AggregateCall> newCalls) {
- relBuilder.aggregate(
- relBuilder.groupKey(oldAggregate.getGroupSet(), oldAggregate.getGroupSets()),
- newCalls);
- }
-
- /**
- * Adds a calculation with the expressions to compute the original aggregate
- * calls from the decomposed ones.
- *
- * @param relBuilder Builder of relational expressions; at the top of its
- * stack is its input
- * @param rowType The output row type of the original aggregate.
- * @param exprs The expressions to compute the original aggregate calls
- */
- protected void newCalcRel(RelBuilder relBuilder,
- RelDataType rowType,
- List<RexNode> exprs) {
- relBuilder.project(exprs, rowType.getFieldNames());
- }
-}
diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
index 76ca322edf..ffde1200aa 100644
--- a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
+++ b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
@@ -86,7 +86,7 @@ public class PinotQueryRuleSets {
CoreRules.AGGREGATE_UNION_AGGREGATE,
// reduce aggregate functions like AVG, STDDEV_POP etc.
- PinotAggregateReduceFunctionsRule.INSTANCE
+ CoreRules.AGGREGATE_REDUCE_FUNCTIONS
);
// Filter pushdown rules run using a RuleCollection since we want to push down a filter as much as possible in a
diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java b/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java
index dd772c349c..8cd1289f5d 100644
--- a/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java
+++ b/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java
@@ -28,12 +28,10 @@ import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.calcite.sql.PinotSqlAggFunction;
import org.apache.calcite.sql.SqlFunction;
-import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.validate.SqlNameMatchers;
import org.apache.calcite.util.Util;
-import org.apache.commons.lang3.StringUtils;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
@@ -60,17 +58,9 @@ public class PinotOperatorTable extends SqlStdOperatorTable {
Arrays.stream(AggregationFunctionType.values()).filter(func -> func.getSqlKind() != null)
.flatMap(func -> Stream.of(func.getSqlKind())).collect(Collectors.toSet());
- private static final Set<String> AGGREGATION_REDUCE_SUPPORTED_FUNCTIONS =
- Arrays.stream(AggregationFunctionType.values())
- .filter(func -> func.getReduceFunctionName() != null)
- .flatMap(func -> Stream.of(func.name(), func.getName(), func.getName().toUpperCase(),
- func.getName().toLowerCase()))
- .collect(Collectors.toSet());
-
- // TODO: This is needed until all aggregation functions are registered with the operator table. SqlKind cannot be
- // null for all registered calcite functions
private static final Set<String> CALCITE_OPERATOR_TABLE_REGISTERED_FUNCTIONS =
- Arrays.stream(AggregationFunctionType.values()).filter(func -> func.getSqlKind() != null)
+ Arrays.stream(AggregationFunctionType.values())
+ .filter(func -> func.getSqlKind() != null) // TODO: remove this once all V1 AGG functions are registered
.flatMap(func -> Stream.of(func.name(), func.getName(), func.getName().toUpperCase(),
func.getName().toLowerCase()))
.collect(Collectors.toSet());
@@ -88,24 +78,8 @@ public class PinotOperatorTable extends SqlStdOperatorTable {
return _instance;
}
- public static boolean isAggregationKindSupported(SqlKind sqlKind) {
- return KINDS.contains(sqlKind);
- }
-
- public static boolean isAggregationReduceSupported(String functionName) {
- if (AGGREGATION_REDUCE_SUPPORTED_FUNCTIONS.contains(functionName)) {
- return true;
- }
- String upperCaseFunctionName = AggregationFunctionType.getNormalizedAggregationFunctionName(functionName);
- return AGGREGATION_REDUCE_SUPPORTED_FUNCTIONS.contains(upperCaseFunctionName);
- }
-
public static boolean isAggregationFunctionRegisteredWithOperatorTable(String functionName) {
- if (CALCITE_OPERATOR_TABLE_REGISTERED_FUNCTIONS.contains(functionName)) {
- return true;
- }
- String upperCaseFunctionName = AggregationFunctionType.getNormalizedAggregationFunctionName(functionName);
- return CALCITE_OPERATOR_TABLE_REGISTERED_FUNCTIONS.contains(upperCaseFunctionName);
+ return CALCITE_OPERATOR_TABLE_REGISTERED_FUNCTIONS.contains(functionName);
}
/**
@@ -125,10 +99,7 @@ public class PinotOperatorTable extends SqlStdOperatorTable {
if (SqlFunction.class.isAssignableFrom(field.getType())) {
SqlFunction op = (SqlFunction) field.get(this);
if (op != null && notRegistered(op)) {
- if (!isPinotAggregationFunction(op.getName())) {
- // Register the standard Calcite functions and those defined in this class only
- register(op);
- }
+ register(op);
}
} else if (
SqlOperator.class.isAssignableFrom(field.getType())) {
@@ -142,89 +113,32 @@ public class PinotOperatorTable extends SqlStdOperatorTable {
}
}
- // Walk through all the Pinot aggregation types and register those that are supported in multistage and which
- // aren't standard Calcite functions such as SUM / MIN / MAX / COUNT etc.
+ // Walk through all the Pinot aggregation types and
+ // 1. register those that are supported in multistage in addition to calcite standard opt table.
+ // 2. register special handling that differs from calcite standard.
for (AggregationFunctionType aggregationFunctionType : AggregationFunctionType.values()) {
- if (aggregationFunctionType.isNativeCalciteAggregationFunctionType()
- || aggregationFunctionType.getSqlKind() == null) {
- // Skip registering functions which are standard Calcite functions and functions which are not yet supported
- // in multistage
- continue;
- }
-
- // Register the aggregation function with Calcite along with all alternative names
- List<PinotSqlAggFunction> sqlAggFunctions = new ArrayList<>();
- PinotSqlAggFunction aggFunction = generatePinotSqlAggFunction(aggregationFunctionType.getName(),
- aggregationFunctionType, false);
- sqlAggFunctions.add(aggFunction);
- List<String> alternativeFunctionNames = aggregationFunctionType.getAlternativeNames();
- if (alternativeFunctionNames == null || alternativeFunctionNames.size() == 0) {
- // If no alternative function names are specified, generate one which converts camel case to have underscores
- // as delimiters instead. E.g. boolAnd -> BOOL_AND
- String alternativeFunctionName =
- convertCamelCaseToUseUnderscores(aggregationFunctionType.getName());
- PinotSqlAggFunction function = generatePinotSqlAggFunction(alternativeFunctionName, aggregationFunctionType,
- false);
- sqlAggFunctions.add(function);
- } else {
+ if (aggregationFunctionType.getSqlKind() != null) {
+ // 1. Register the aggregation function with Calcite
+ registerAggregateFunction(aggregationFunctionType.getName(), aggregationFunctionType);
+ // 2. Register the aggregation function with Calcite on all alternative names
+ List<String> alternativeFunctionNames = aggregationFunctionType.getAlternativeNames();
for (String alternativeFunctionName : alternativeFunctionNames) {
- PinotSqlAggFunction function = generatePinotSqlAggFunction(alternativeFunctionName, aggregationFunctionType,
- false);
- sqlAggFunctions.add(function);
- }
- }
- for (PinotSqlAggFunction sqlAggFunction : sqlAggFunctions) {
- if (notRegistered(sqlAggFunction)) {
- register(sqlAggFunction);
- }
- }
-
- if (!StringUtils.isEmpty(aggregationFunctionType.getIntermediateFunctionName())
- && !StringUtils.equals(aggregationFunctionType.getIntermediateFunctionName(),
- aggregationFunctionType.getName())) {
- // Register the intermediate function with Calcite if the name differs from the main function name
- PinotSqlAggFunction intermediateAggFunction =
- generatePinotSqlAggFunction(aggregationFunctionType.getIntermediateFunctionName(), aggregationFunctionType,
- true);
- if (notRegistered(intermediateAggFunction)) {
- register(intermediateAggFunction);
- }
- }
-
- if (!StringUtils.isEmpty(aggregationFunctionType.getReduceFunctionName())) {
- // If a reduce function name is available, register this as a SqlFunction with Calcite
- SqlFunction function = new SqlFunction(
- aggregationFunctionType.getReduceFunctionName(),
- SqlKind.OTHER_FUNCTION,
- aggregationFunctionType.getSqlReduceReturnTypeInference(),
- null,
- aggregationFunctionType.getSqlReduceOperandTypeChecker(),
- SqlFunctionCategory.USER_DEFINED_FUNCTION);
- if (notRegistered(function)) {
- register(function);
+ registerAggregateFunction(alternativeFunctionName, aggregationFunctionType);
}
}
}
}
- private static String convertCamelCaseToUseUnderscores(String functionName) {
- // Skip functions that have numbers for now and return their name as is
- return functionName.matches(".*\\d.*")
- ? functionName
- : functionName.replaceAll("(.)(\\p{Upper}+|\\d+)", "$1_$2");
- }
-
- private static PinotSqlAggFunction generatePinotSqlAggFunction(String functionName,
- AggregationFunctionType aggregationFunctionType, boolean isIntermediateStageFunction) {
- return new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT),
- aggregationFunctionType.getSqlIdentifier(),
- aggregationFunctionType.getSqlKind(),
- isIntermediateStageFunction
- ? aggregationFunctionType.getSqlIntermediateReturnTypeInference()
- : aggregationFunctionType.getSqlReturnTypeInference(),
- aggregationFunctionType.getSqlOperandTypeInference(),
- aggregationFunctionType.getSqlOperandTypeChecker(),
- aggregationFunctionType.getSqlFunctionCategory());
+ private void registerAggregateFunction(String functionName, AggregationFunctionType functionType) {
+ // register function behavior that's different from Calcite
+ if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) {
+ PinotSqlAggFunction sqlAggFunction = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null,
+ functionType.getSqlKind(), functionType.getReturnTypeInference(), null,
+ functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory());
+ if (notRegistered(sqlAggFunction)) {
+ register(sqlAggFunction);
+ }
+ }
}
private boolean notRegistered(SqlFunction op) {
@@ -240,12 +154,4 @@ public class PinotOperatorTable extends SqlStdOperatorTable {
SqlNameMatchers.withCaseSensitive(false));
return operatorList.size() == 0;
}
-
- private boolean isPinotAggregationFunction(String name) {
- AggregationFunctionType aggFunctionType = null;
- if (isAggregationFunctionRegisteredWithOperatorTable(name)) {
- aggFunctionType = AggregationFunctionType.getAggregationFunctionType(name);
- }
- return aggFunctionType != null && !aggFunctionType.isNativeCalciteAggregationFunctionType();
- }
}
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
index 1d73743b09..05e837ee77 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
@@ -437,7 +437,6 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
new Object[]{"EXPLAIN PLAN EXCLUDING ATTRIBUTES AS DOT FOR SELECT col1, COUNT(*) FROM a GROUP BY col1",
"Execution Plan\n"
+ "digraph {\n"
- + "\"LogicalAggregate\\n\" -> \"LogicalAggregate\\n\" [label=\"0\"]\n"
+ "\"PinotLogicalExchange\\n\" -> \"LogicalAggregate\\n\" [label=\"0\"]\n"
+ "\"LogicalAggregate\\n\" -> \"PinotLogicalExchange\\n\" [label=\"0\"]\n"
+ "\"LogicalTableScan\\n\" -> \"LogicalAggregate\\n\" [label=\"0\"]\n"
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
index f4c2ceafa7..135e1cb208 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
@@ -93,8 +93,9 @@ public class QueryEnvironmentTestBase {
new Object[]{"SELECT SUM(a.col3), COUNT(*) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a'"},
new Object[]{"SELECT AVG(a.col3), SUM(a.col3), COUNT(a.col3) FROM a"},
new Object[]{"SELECT a.col1, AVG(a.col3), SUM(a.col3), COUNT(a.col3) FROM a GROUP BY a.col1"},
- new Object[]{"SELECT BOOL_AND(a.col5), BOOL_OR(a.col5) FROM a"},
- new Object[]{"SELECT a.col3, BOOL_AND(a.col5), BOOL_OR(a.col5) FROM a GROUP BY a.col3"},
+ // TODO: support BOOL_AND and BOOL_OR as MIN/MAX
+// new Object[]{"SELECT BOOL_AND(a.col5), BOOL_OR(a.col5) FROM a"},
+// new Object[]{"SELECT a.col3, BOOL_AND(a.col5), BOOL_OR(a.col5) FROM a GROUP BY a.col3"},
new Object[]{"SELECT KURTOSIS(a.col2), COUNT(DISTINCT a.col3), SKEWNESS(a.col3) FROM a"},
new Object[]{"SELECT a.col1, KURTOSIS(a.col2), SKEWNESS(a.col3) FROM a GROUP BY a.col1"},
new Object[]{"SELECT COUNT(a.col3), AVG(a.col3), SUM(a.col3), MIN(a.col3), MAX(a.col3) FROM a"},
diff --git a/pinot-query-planner/src/test/resources/queries/AggregatePlans.json b/pinot-query-planner/src/test/resources/queries/AggregatePlans.json
index 0cc18fe67f..102fb8681c 100644
--- a/pinot-query-planner/src/test/resources/queries/AggregatePlans.json
+++ b/pinot-query-planner/src/test/resources/queries/AggregatePlans.json
@@ -6,14 +6,13 @@
"sql": "EXPLAIN PLAN FOR SELECT AVG(a.col4) as avg FROM a WHERE a.col3 >= 0 AND a.col2 = 'pink floyd'",
"output": [
"Execution Plan",
- "\nLogicalProject(avg=[CAST(AVG_REDUCE(CAST($0):DECIMAL(1000, 0) NOT NULL, $1)):DECIMAL(1000, 0)])",
+ "\nLogicalProject(avg=[/(CASE(=($1, 0), null:DECIMAL(1000, 0), $0), $1)])",
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
- "\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
- "\n PinotLogicalExchange(distribution=[hash])",
- "\n LogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
- "\n LogicalProject(col2=[$1], col3=[$2], col4=[$3])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash])",
+ "\n LogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+ "\n LogicalProject(col2=[$1], col3=[$2], col4=[$3])",
+ "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -22,14 +21,13 @@
"sql": "EXPLAIN PLAN FOR SELECT AVG(a.col4) as avg, SUM(a.col4) as sum, MAX(a.col4) as max FROM a WHERE a.col3 >= 0 AND a.col2 = 'pink floyd'",
"output": [
"Execution Plan",
- "\nLogicalProject(avg=[CAST(AVG_REDUCE(CAST($0):DECIMAL(1000, 0) NOT NULL, $1)):DECIMAL(1000, 0)], sum=[$0], max=[$2])",
+ "\nLogicalProject(avg=[/(CASE(=($1, 0), null:DECIMAL(1000, 0), $0), $1)], sum=[CASE(=($1, 0), null:DECIMAL(1000, 0), $0)], max=[$2])",
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)], agg#2=[MAX($2)])",
- "\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)], agg#2=[MAX($2)])",
- "\n PinotLogicalExchange(distribution=[hash])",
- "\n LogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()], agg#2=[MAX($2)])",
- "\n LogicalProject(col2=[$1], col3=[$2], col4=[$3])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash])",
+ "\n LogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()], agg#2=[MAX($2)])",
+ "\n LogicalProject(col2=[$1], col3=[$2], col4=[$3])",
+ "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -38,28 +36,12 @@
"sql": "EXPLAIN PLAN FOR SELECT AVG(a.col3) as avg, COUNT(*) as count FROM a WHERE a.col3 >= 0 AND a.col2 = 'pink floyd'",
"output": [
"Execution Plan",
- "\nLogicalProject(avg=[CAST(AVG_REDUCE($0, $1)):DOUBLE], count=[$1])",
+ "\nLogicalProject(avg=[/(CAST(CASE(=($1, 0), null:INTEGER, $0)):DOUBLE, $1)], count=[$1])",
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
- "\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
- "\n PinotLogicalExchange(distribution=[hash])",
- "\n LogicalAggregate(group=[{}], agg#0=[$SUM0($1)], agg#1=[COUNT()])",
- "\n LogicalProject(col2=[$1], col3=[$2])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
- "\n LogicalTableScan(table=[[a]])",
- "\n"
- ]
- },
- {
- "description": "Select aggregates with filters and select alias",
- "sql": "EXPLAIN PLAN FOR SELECT KURTOSIS(a.col3) as kurtosis, DISTINCTCOUNT(a.col1) as dcount, MIN(a.col6), BOOL_AND(a.col5) FROM a WHERE a.col3 >= 0 AND a.col2 = 'iron maiden'",
- "output": [
- "Execution Plan",
- "\nLogicalAggregate(group=[{}], agg#0=[KURTOSIS($0)], agg#1=[DISTINCTCOUNT($1)], agg#2=[MIN($2)], agg#3=[BOOLAND($3)])",
- "\n LogicalAggregate(group=[{}], agg#0=[FOURTHMOMENT($0)], agg#1=[DISTINCTCOUNT($1)], agg#2=[MIN($2)], agg#3=[BOOLAND($3)])",
"\n PinotLogicalExchange(distribution=[hash])",
- "\n LogicalAggregate(group=[{}], agg#0=[FOURTHMOMENT($2)], agg#1=[DISTINCTCOUNT($0)], agg#2=[MIN($4)], agg#3=[BOOLAND($3)])",
- "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2], col5=[$4], col6=[$5])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'iron maiden'))])",
+ "\n LogicalAggregate(group=[{}], agg#0=[$SUM0($1)], agg#1=[COUNT()])",
+ "\n LogicalProject(col2=[$1], col3=[$2])",
+ "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -69,7 +51,7 @@
"sql": "EXPLAIN PLAN FOR SELECT SUM(a.col3), COUNT(a.col1) FROM a",
"output": [
"Execution Plan",
- "\nLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
+ "\nLogicalProject(EXPR$0=[CASE(=($1, 0), null:INTEGER, $0)], EXPR$1=[$1])",
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
"\n PinotLogicalExchange(distribution=[hash])",
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
@@ -77,26 +59,12 @@
"\n"
]
},
- {
- "description": "Select aggregates for COUNT variations",
- "sql": "EXPLAIN PLAN FOR SELECT DISTINCTCOUNT(a.col3), COUNT(DISTINCT a.col3), COUNT(a.col3), COUNT(*), DISTINCTCOUNT(a.col5) FROM a",
- "output": [
- "Execution Plan",
- "\nLogicalProject(EXPR$0=[$0], EXPR$1=[$1], EXPR$2=[$2], EXPR$3=[$2], EXPR$4=[$3])",
- "\n LogicalAggregate(group=[{}], agg#0=[DISTINCTCOUNT($0)], agg#1=[DISTINCTCOUNT(DISTINCT $1)], agg#2=[COUNT($2)], agg#3=[DISTINCTCOUNT($3)])",
- "\n LogicalAggregate(group=[{}], agg#0=[DISTINCTCOUNT($0)], agg#1=[DISTINCTCOUNT(DISTINCT $1)], agg#2=[COUNT($2)], agg#3=[DISTINCTCOUNT($3)])",
- "\n PinotLogicalExchange(distribution=[hash])",
- "\n LogicalAggregate(group=[{}], agg#0=[DISTINCTCOUNT($2)], agg#1=[DISTINCTCOUNT(DISTINCT $2)], agg#2=[COUNT()], agg#3=[DISTINCTCOUNT($4)])",
- "\n LogicalTableScan(table=[[a]])",
- "\n"
- ]
- },
{
"description": "Select aggregates with filters",
"sql": "EXPLAIN PLAN FOR SELECT SUM(a.col3), COUNT(*) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a'",
"output": [
"Execution Plan",
- "\nLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
+ "\nLogicalProject(EXPR$0=[CASE(=($1, 0), null:INTEGER, $0)], EXPR$1=[$1])",
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
"\n PinotLogicalExchange(distribution=[hash])",
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($1)], agg#1=[COUNT()])",
@@ -106,93 +74,12 @@
"\n"
]
},
- {
- "description": "Select transform inside aggregate with filters",
- "sql": "EXPLAIN PLAN FOR SELECT SUM(ADD(a.col3, a.col6)), AVG(MOD(a.col6, a.col3)), MIN(ABS(a.col6)) FROM a WHERE a.col3 >= 0 AND a.col2 = 'hooloovoo'",
- "output": [
- "Execution Plan",
- "\nLogicalProject(EXPR$0=[$0], EXPR$1=[CAST(AVG_REDUCE($1, $2)):DOUBLE], EXPR$2=[$3])",
- "\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[$SUM0($1)], agg#2=[COUNT($2)], agg#3=[MIN($3)])",
- "\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[$SUM0($1)], agg#2=[COUNT($2)], agg#3=[MIN($3)])",
- "\n PinotLogicalExchange(distribution=[hash])",
- "\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[$SUM0($1)], agg#2=[COUNT()], agg#3=[MIN($2)])",
- "\n LogicalProject($f0=[ADD($2, $5)], $f1=[MOD($5, $2)], $f2=[ABS($5)])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'hooloovoo'))])",
- "\n LogicalTableScan(table=[[a]])",
- "\n"
- ]
- },
- {
- "description": "Select ORDER BY on aggregate with filters",
- "sql": "EXPLAIN PLAN FOR SELECT SUM(a.col6), AVG(a.col3), KURTOSIS(ABS(a.col6)) FROM a WHERE a.col3 >= 0 AND a.col2 = 'zaphoid beeblebrox' ORDER BY SUM(a.col6), AVG(a.col3) DESC",
- "output": [
- "Execution Plan",
- "\nLogicalSort(sort0=[$0], sort1=[$1], dir0=[ASC], dir1=[DESC], offset=[0])",
- "\n PinotLogicalSortExchange(distribution=[hash], collation=[[0, 1 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
- "\n LogicalSort(sort0=[$0], sort1=[$1], dir0=[ASC], dir1=[DESC])",
- "\n LogicalProject(EXPR$0=[$0], EXPR$1=[CAST(AVG_REDUCE($1, $2)):DOUBLE], EXPR$2=[$3])",
- "\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[$SUM0($1)], agg#2=[COUNT($2)], agg#3=[KURTOSIS($3)])",
- "\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[$SUM0($1)], agg#2=[COUNT($2)], agg#3=[FOURTHMOMENT($3)])",
- "\n PinotLogicalExchange(distribution=[hash])",
- "\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[$SUM0($1)], agg#2=[COUNT()], agg#3=[FOURTHMOMENT($2)])",
- "\n LogicalProject(col6=[$5], col3=[$2], $f2=[ABS($5)])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'zaphoid beeblebrox'))])",
- "\n LogicalTableScan(table=[[a]])",
- "\n"
- ]
- },
- {
- "description": "Select transform inside aggregate with transform inside filters",
- "sql": "EXPLAIN PLAN FOR SELECT SKEWNESS(ADD(a.col3, a.col6)), DISTINCTCOUNT(ABS(a.col6)) FROM a WHERE a.col3 >= 0 AND REVERSE(a.col2) = 'oovoolooh'",
- "output": [
- "Execution Plan",
- "\nLogicalAggregate(group=[{}], agg#0=[SKEWNESS($0)], agg#1=[DISTINCTCOUNT($1)])",
- "\n LogicalAggregate(group=[{}], agg#0=[FOURTHMOMENT($0)], agg#1=[DISTINCTCOUNT($1)])",
- "\n PinotLogicalExchange(distribution=[hash])",
- "\n LogicalAggregate(group=[{}], agg#0=[FOURTHMOMENT($0)], agg#1=[DISTINCTCOUNT($1)])",
- "\n LogicalProject($f0=[ADD($2, $5)], $f1=[ABS($5)])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =(REVERSE($1), 'oovoolooh'))])",
- "\n LogicalTableScan(table=[[a]])",
- "\n"
- ]
- },
- {
- "description": "Select more aggregates with filters",
- "sql": "EXPLAIN PLAN FOR SELECT SKEWNESS(a.col3), AVG(a.col3), BOOL_OR(a.col5), SUM(a.col3), AVG(a.col6), MAX(a.col6) FROM a WHERE a.col3 >= 0 AND a.col2 = 'rolling stones'",
- "output": [
- "Execution Plan",
- "\nLogicalProject(EXPR$0=[$0], EXPR$1=[CAST(AVG_REDUCE($1, $2)):DOUBLE], EXPR$2=[$3], EXPR$3=[$1], EXPR$4=[CAST(AVG_REDUCE($4, $2)):DOUBLE], EXPR$5=[$5])",
- "\n LogicalAggregate(group=[{}], agg#0=[SKEWNESS($0)], agg#1=[$SUM0($1)], agg#2=[COUNT($2)], agg#3=[BOOLOR($3)], agg#4=[$SUM0($4)], agg#5=[MAX($5)])",
- "\n LogicalAggregate(group=[{}], agg#0=[FOURTHMOMENT($0)], agg#1=[$SUM0($1)], agg#2=[COUNT($2)], agg#3=[BOOLOR($3)], agg#4=[$SUM0($4)], agg#5=[MAX($5)])",
- "\n PinotLogicalExchange(distribution=[hash])",
- "\n LogicalAggregate(group=[{}], agg#0=[FOURTHMOMENT($1)], agg#1=[$SUM0($1)], agg#2=[COUNT()], agg#3=[BOOLOR($2)], agg#4=[$SUM0($3)], agg#5=[MAX($3)])",
- "\n LogicalProject(col2=[$1], col3=[$2], col5=[$4], col6=[$5])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'rolling stones'))])",
- "\n LogicalTableScan(table=[[a]])",
- "\n"
- ]
- },
- {
- "description": "Select kurtosis and skewness with limit",
- "sql": "EXPLAIN PLAN FOR SELECT SKEWNESS(a.col3), KURTOSIS(a.col3) FROM a LIMIT 100",
- "output": [
- "Execution Plan",
- "\nLogicalSort(fetch=[100])",
- "\n PinotLogicalSortExchange(distribution=[hash], collation=[[]], isSortOnSender=[false], isSortOnReceiver=[false])",
- "\n LogicalAggregate(group=[{}], agg#0=[SKEWNESS($0)], agg#1=[KURTOSIS($1)])",
- "\n LogicalAggregate(group=[{}], agg#0=[FOURTHMOMENT($0)], agg#1=[FOURTHMOMENT($1)])",
- "\n PinotLogicalExchange(distribution=[hash])",
- "\n LogicalAggregate(group=[{}], agg#0=[FOURTHMOMENT($2)], agg#1=[FOURTHMOMENT($2)])",
- "\n LogicalTableScan(table=[[a]])",
- "\n"
- ]
- },
{
"description": "Select aggregates with filters and select alias",
"sql": "EXPLAIN PLAN FOR SELECT SUM(a.col3) as sum, COUNT(*) as count FROM a WHERE a.col3 >= 0 AND a.col2 = 'pink floyd'",
"output": [
"Execution Plan",
- "\nLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
+ "\nLogicalProject(sum=[CASE(=($1, 0), null:INTEGER, $0)], count=[$1])",
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
"\n PinotLogicalExchange(distribution=[hash])",
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($1)], agg#1=[COUNT()])",
@@ -207,43 +94,27 @@
"sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ AVG(a.col3) as avg, COUNT(*) as count FROM a WHERE a.col3 >= 0 AND a.col2 = 'pink floyd'",
"output": [
"Execution Plan",
- "\nLogicalProject(avg=[CAST(AVG_REDUCE($0, $1)):DOUBLE], count=[$1])",
- "\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
- "\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
- "\n PinotLogicalExchange(distribution=[hash])",
- "\n LogicalAggregate(group=[{}], agg#0=[$SUM0($1)], agg#1=[COUNT()])",
- "\n LogicalProject(col2=[$1], col3=[$2])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
- "\n LogicalTableScan(table=[[a]])",
- "\n"
- ]
- },
- {
- "description": "Select aggregates with filters and select alias. The group by aggregate hint should be a no-op.",
- "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ SUM(a.col3) as sum, COUNT(*) as count FROM a WHERE a.col3 >= 0 AND a.col2 = 'queen'",
- "output": [
- "Execution Plan",
- "\nLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
+ "\nLogicalProject(avg=[/(CAST(CASE(=($1, 0), null:INTEGER, $0)):DOUBLE, $1)], count=[$1])",
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
"\n PinotLogicalExchange(distribution=[hash])",
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($1)], agg#1=[COUNT()])",
"\n LogicalProject(col2=[$1], col3=[$2])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'queen'))])",
+ "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
},
{
"description": "Select aggregates with filters and select alias. The group by aggregate hint should be a no-op.",
- "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ KURTOSIS(a.col3) as kurtosis, SKEWNESS(a.col6) as skewness, COUNT(DISTINCT a.col6), COUNT(DISTINCT a.col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'metallica'",
+ "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ SUM(a.col3) as sum, COUNT(*) as count FROM a WHERE a.col3 >= 0 AND a.col2 = 'pink floyd'",
"output": [
"Execution Plan",
- "\nLogicalAggregate(group=[{}], agg#0=[KURTOSIS($0)], agg#1=[SKEWNESS($1)], agg#2=[DISTINCTCOUNT(DISTINCT $2)], agg#3=[DISTINCTCOUNT(DISTINCT $3)])",
- "\n LogicalAggregate(group=[{}], agg#0=[FOURTHMOMENT($0)], agg#1=[FOURTHMOMENT($1)], agg#2=[DISTINCTCOUNT(DISTINCT $2)], agg#3=[DISTINCTCOUNT(DISTINCT $3)])",
+ "\nLogicalProject(sum=[CASE(=($1, 0), null:INTEGER, $0)], count=[$1])",
+ "\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
"\n PinotLogicalExchange(distribution=[hash])",
- "\n LogicalAggregate(group=[{}], agg#0=[FOURTHMOMENT($1)], agg#1=[FOURTHMOMENT($2)], agg#2=[DISTINCTCOUNT(DISTINCT $2)], agg#3=[DISTINCTCOUNT(DISTINCT $1)])",
- "\n LogicalProject(col2=[$1], col3=[$2], col6=[$5])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'metallica'))])",
+ "\n LogicalAggregate(group=[{}], agg#0=[$SUM0($1)], agg#1=[COUNT()])",
+ "\n LogicalProject(col2=[$1], col3=[$2])",
+ "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
diff --git a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
index 3c586b369d..0b19599217 100644
--- a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
@@ -7,55 +7,9 @@
"output": [
"Execution Plan",
"\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])",
- "\n LogicalTableScan(table=[[a]])",
- "\n"
- ]
- },
- {
- "description": "Group by with select and aggregates with filters and select alias",
- "sql": "EXPLAIN PLAN FOR SELECT a.col2, KURTOSIS(a.col3) as kurtosis, DISTINCTCOUNT(a.col1) as dcount, MIN(a.col6), BOOL_AND(a.col5) FROM a WHERE a.col3 >= 0 AND a.col2 = 'linkin park' GROUP BY a.col2",
- "output": [
- "Execution Plan",
- "\nLogicalAggregate(group=[{0}], agg#0=[KURTOSIS($1)], agg#1=[DISTINCTCOUNT($2)], agg#2=[MIN($3)], agg#3=[BOOLAND($4)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[FOURTHMOMENT($1)], agg#1=[DISTINCTCOUNT($2)], agg#2=[MIN($3)], agg#3=[BOOLAND($4)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{1}], agg#0=[FOURTHMOMENT($2)], agg#1=[DISTINCTCOUNT($0)], agg#2=[MIN($4)], agg#3=[BOOLAND($3)])",
- "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2], col5=[$4], col6=[$5])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'linkin park'))])",
- "\n LogicalTableScan(table=[[a]])",
- "\n"
- ]
- },
- {
- "description": "Group by with elect and aggregates for COUNT variations",
- "sql": "EXPLAIN PLAN FOR SELECT a.col1, a.col2, DISTINCTCOUNT(a.col3), COUNT(DISTINCT a.col3), COUNT(a.col3), COUNT(*), DISTINCTCOUNT(a.col5) FROM a GROUP BY a.col1, a.col2",
- "output": [
- "Execution Plan",
- "\nLogicalProject(col1=[$0], col2=[$1], EXPR$2=[$2], EXPR$3=[$3], EXPR$4=[$4], EXPR$5=[$4], EXPR$6=[$5])",
- "\n LogicalAggregate(group=[{0, 1}], agg#0=[DISTINCTCOUNT($2)], agg#1=[DISTINCTCOUNT(DISTINCT $3)], agg#2=[COUNT($4)], agg#3=[DISTINCTCOUNT($5)])",
- "\n LogicalAggregate(group=[{0, 1}], agg#0=[DISTINCTCOUNT($2)], agg#1=[DISTINCTCOUNT(DISTINCT $3)], agg#2=[COUNT($4)], agg#3=[DISTINCTCOUNT($5)])",
- "\n PinotLogicalExchange(distribution=[hash[0, 1]])",
- "\n LogicalAggregate(group=[{0, 1}], agg#0=[DISTINCTCOUNT($2)], agg#1=[DISTINCTCOUNT(DISTINCT $2)], agg#2=[COUNT()], agg#3=[DISTINCTCOUNT($4)])",
- "\n LogicalTableScan(table=[[a]])",
- "\n"
- ]
- },
- {
- "description": "Group by with select and aggregates with filters",
- "sql": "EXPLAIN PLAN FOR SELECT a.col5, SKEWNESS(a.col3), AVG(a.col3), BOOL_OR(a.col5), SUM(a.col3), AVG(a.col6), MAX(a.col6) FROM a WHERE a.col3 >= 0 AND a.col2 = 'rolling stones' GROUP BY a.col5",
- "output": [
- "Execution Plan",
- "\nLogicalProject(col5=[$0], EXPR$1=[$1], EXPR$2=[AVG_REDUCE($2, $3)], EXPR$3=[$4], EXPR$4=[$2], EXPR$5=[AVG_REDUCE($5, $3)], EXPR$6=[$6])",
- "\n LogicalAggregate(group=[{0}], agg#0=[SKEWNESS($1)], agg#1=[$SUM0($2)], agg#2=[COUNT($3)], agg#3=[BOOLOR($4)], agg#4=[$SUM0($5)], agg#5=[MAX($6)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[FOURTHMOMENT($1)], agg#1=[$SUM0($2)], agg#2=[COUNT($3)], agg#3=[BOOLOR($4)], agg#4=[$SUM0($5)], agg#5=[MAX($6)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{2}], agg#0=[FOURTHMOMENT($1)], agg#1=[$SUM0($1)], agg#2=[COUNT()], agg#3=[BOOLOR($2)], agg#4=[$SUM0($3)], agg#5=[MAX($3)])",
- "\n LogicalProject(col2=[$1], col3=[$2], col5=[$4], col6=[$5])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'rolling stones'))])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -64,12 +18,11 @@
"sql": "EXPLAIN PLAN FOR SELECT a.col1, SUM(a.col3), AVG(a.col3), MAX(a.col3), MIN(a.col3) FROM a GROUP BY a.col1",
"output": [
"Execution Plan",
- "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[AVG_REDUCE($1, $2)], EXPR$3=[$3], EXPR$4=[$4])",
+ "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[/(CAST($1):DOUBLE NOT NULL, $2)], EXPR$3=[$3], EXPR$4=[$4])",
"\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[COUNT()], agg#2=[MAX($2)], agg#3=[MIN($2)])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[COUNT()], agg#2=[MAX($2)], agg#3=[MIN($2)])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -79,12 +32,11 @@
"output": [
"Execution Plan",
"\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])",
- "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])",
+ "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+ "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -95,12 +47,11 @@
"output": [
"Execution Plan",
"\nLogicalAggregate(group=[{0}], agg#0=[COUNT($1)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[COUNT($1)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{0}], agg#0=[COUNT()])",
- "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[COUNT()])",
+ "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+ "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -111,12 +62,11 @@
"Execution Plan",
"\nLogicalProject(col2=[$1], col1=[$0], EXPR$2=[$2])",
"\n LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
- "\n LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0, 1]])",
- "\n LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
- "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($0, 'a'))])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0, 1]])",
+ "\n LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
+ "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+ "\n LogicalFilter(condition=[AND(>=($2, 0), =($0, 'a'))])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -126,31 +76,13 @@
"output": [
"Execution Plan",
"\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[$2])",
- "\n LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(AVG_REDUCE($2, $1), 5))])",
+ "\n LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])",
"\n LogicalAggregate(group=[{0}], agg#0=[COUNT($1)], agg#1=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[COUNT($1)], agg#1=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)])",
- "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
- "\n LogicalTableScan(table=[[a]])",
- "\n"
- ]
- },
- {
- "description": "Group by with having clause and different aggregation functions",
- "sql": "EXPLAIN PLAN FOR SELECT a.col1, KURTOSIS(a.col3), BOOL_OR(a.col5) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col1 HAVING COUNT(*) > 10 AND MAX(a.col3) >= 0 AND MIN(a.col3) < 20 AND SKEWNESS(a.col3) <= 10 AND AVG(a.col3) = 5",
- "output": [
- "Execution Plan",
- "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[$2])",
- "\n LogicalFilter(condition=[AND(>($3, 10), >=($4, 0), <($5, 20), <=($6, 10), =(AVG_REDUCE($7, $3), 5))])",
- "\n LogicalAggregate(group=[{0}], agg#0=[KURTOSIS($1)], agg#1=[BOOLOR($2)], agg#2=[COUNT($3)], agg#3=[MAX($4)], agg#4=[MIN($5)], agg#5=[SKEWNESS($6)], agg#6=[$SUM0($7)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[FOURTHMOMENT($1)], agg#1=[BOOLOR($2)], agg#2=[COUNT($3)], agg#3=[MAX($4)], agg#4=[MIN($5)], agg#5=[FOURTHMOMENT($6)], agg#6=[$SUM0($7)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{0}], agg#0=[FOURTHMOMENT($2)], agg#1=[BOOLOR($3)], agg#2=[COUNT()], agg#3=[MAX($2)], agg#4=[MIN($2)], agg#5=[FOURTHMOMENT($2)], agg#6=[$SUM0($2)])",
- "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2], col5=[$4])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)])",
+ "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+ "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -160,72 +92,25 @@
"output": [
"Execution Plan",
"\nLogicalProject(value1=[$0], count=[$1], SUM=[$2])",
- "\n LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(AVG_REDUCE($2, $1), 5))])",
+ "\n LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])",
"\n LogicalAggregate(group=[{0}], agg#0=[COUNT($1)], agg#1=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[COUNT($1)], agg#1=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)])",
- "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
- "\n LogicalTableScan(table=[[a]])",
- "\n"
- ]
- },
- {
- "description": "Group by aggregation kurtosis and skewness with limit",
- "sql": "EXPLAIN PLAN FOR SELECT a.col6, SKEWNESS(a.col3), KURTOSIS(a.col3) FROM a GROUP BY a.col6 LIMIT 100",
- "output": [
- "Execution Plan",
- "\nLogicalSort(offset=[0], fetch=[100])",
- "\n PinotLogicalSortExchange(distribution=[hash], collation=[[]], isSortOnSender=[false], isSortOnReceiver=[false])",
- "\n LogicalSort(fetch=[100])",
- "\n LogicalAggregate(group=[{0}], agg#0=[SKEWNESS($1)], agg#1=[KURTOSIS($2)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[FOURTHMOMENT($1)], agg#1=[FOURTHMOMENT($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{5}], agg#0=[FOURTHMOMENT($2)], agg#1=[FOURTHMOMENT($2)])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)])",
+ "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+ "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
},
- {
- "description": "Group by transform inside aggregation and transform in group by",
- "sql": "EXPLAIN PLAN FOR SELECT REVERSE(a.col1), SKEWNESS(ADD(a.col3, a.col6)), DISTINCTCOUNT(CONCAT(a.col1, '-', 'a.col2')) FROM a GROUP BY REVERSE(a.col1)",
- "output": [
- "Execution Plan",
- "\nLogicalAggregate(group=[{0}], agg#0=[SKEWNESS($1)], agg#1=[DISTINCTCOUNT($2)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[FOURTHMOMENT($1)], agg#1=[DISTINCTCOUNT($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{0}], agg#0=[FOURTHMOMENT($1)], agg#1=[DISTINCTCOUNT($2)])",
- "\n LogicalProject(EXPR$0=[REVERSE($0)], $f1=[ADD($2, $5)], $f2=[CONCAT($0, '-', 'a.col2')])",
- "\n LogicalTableScan(table=[[a]])",
- "\n"
- ]
- },
- {
- "description": "Group by transform inside aggregation and transform in group by with filter with transform",
- "sql": "EXPLAIN PLAN FOR SELECT REVERSE(a.col1), SKEWNESS(ADD(a.col3, a.col6)), DISTINCTCOUNT(CONCAT(a.col1, '-', 'a.col2')) FROM a WHERE a.col3 >= 42 AND SUBSTR(a.col2, 0, 4) != 'prod' GROUP BY REVERSE(a.col1)",
- "output": [
- "Execution Plan",
- "\nLogicalAggregate(group=[{0}], agg#0=[SKEWNESS($1)], agg#1=[DISTINCTCOUNT($2)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[FOURTHMOMENT($1)], agg#1=[DISTINCTCOUNT($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{0}], agg#0=[FOURTHMOMENT($1)], agg#1=[DISTINCTCOUNT($2)])",
- "\n LogicalProject(EXPR$0=[REVERSE($0)], $f1=[ADD($2, $5)], $f2=[CONCAT($0, '-', 'a.col2')])",
- "\n LogicalFilter(condition=[AND(>=($2, 42), <>(SUBSTR($1, 0, 4), 'prod'))])",
- "\n LogicalTableScan(table=[[a]])",
- "\n"
- ]
- },
{
"description": "SQL hint based group by optimization with select and aggregate column",
"sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, SUM(a.col3) FROM a GROUP BY a.col1",
"output": [
"Execution Plan",
- "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col1=[$0], col3=[$2])",
- "\n LogicalTableScan(table=[[a]])",
+ "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalProject(col1=[$0], col3=[$2])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -234,12 +119,11 @@
"sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, AVG(a.col3) FROM a GROUP BY a.col1",
"output": [
"Execution Plan",
- "\nLogicalProject(col1=[$0], EXPR$1=[AVG_REDUCE($1, $2)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT()])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col1=[$0], col3=[$2])",
- "\n LogicalTableScan(table=[[a]])",
+ "\nLogicalProject(col1=[$0], EXPR$1=[/(CAST($1):DOUBLE NOT NULL, $2)])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT()])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalProject(col1=[$0], col3=[$2])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -248,12 +132,11 @@
"sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, SUM(a.col3), AVG(a.col3), MAX(a.col3), MIN(a.col3) FROM a GROUP BY a.col1",
"output": [
"Execution Plan",
- "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[AVG_REDUCE($1, $2)], EXPR$3=[$3], EXPR$4=[$4])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT()], agg#2=[MAX($1)], agg#3=[MIN($1)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col1=[$0], col3=[$2])",
- "\n LogicalTableScan(table=[[a]])",
+ "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[/(CAST($1):DOUBLE NOT NULL, $2)], EXPR$3=[$3], EXPR$4=[$4])",
+ "\n LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)], agg#1=[COUNT()], EXPR$3=[MAX($1)], EXPR$4=[MIN($1)])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalProject(col1=[$0], col3=[$2])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -262,12 +145,11 @@
"sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, SUM(a.col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col1",
"output": [
"Execution Plan",
- "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
- "\n LogicalTableScan(table=[[a]])",
+ "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($2)])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+ "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -276,12 +158,11 @@
"sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, SUM(a.col3), MAX(a.col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col1",
"output": [
"Execution Plan",
- "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[MAX($2)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[MAX($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
- "\n LogicalTableScan(table=[[a]])",
+ "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($2)], EXPR$2=[MAX($2)])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+ "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -291,12 +172,11 @@
"notes": "TODO: Needs follow up. Project should only keep a.col1 since the other columns are pushed to the filter, but it currently keeps them all",
"output": [
"Execution Plan",
- "\nLogicalAggregate(group=[{0}], agg#0=[COUNT($1)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[COUNT()])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
- "\n LogicalTableScan(table=[[a]])",
+ "\nLogicalAggregate(group=[{0}], EXPR$1=[COUNT()])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+ "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -306,12 +186,11 @@
"output": [
"Execution Plan",
"\nLogicalProject(col2=[$1], col1=[$0], EXPR$2=[$2])",
- "\n LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
- "\n LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0, 1]])",
- "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($0, 'a'))])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n LogicalAggregate(group=[{0, 1}], EXPR$2=[$SUM0($2)])",
+ "\n PinotLogicalExchange(distribution=[hash[0, 1]])",
+ "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+ "\n LogicalFilter(condition=[AND(>=($2, 0), =($0, 'a'))])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -321,13 +200,12 @@
"output": [
"Execution Plan",
"\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[$2])",
- "\n LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(AVG_REDUCE($2, $1), 5))])",
- "\n LogicalAggregate(group=[{0}], agg#0=[COUNT($1)], agg#1=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])",
+ "\n LogicalAggregate(group=[{0}], EXPR$1=[COUNT()], EXPR$2=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+ "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -337,13 +215,12 @@
"output": [
"Execution Plan",
"\nLogicalProject(col1=[$0], EXPR$1=[$1])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), <($3, 20), <=($1, 10), =(AVG_REDUCE($1, $4), 5))])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[MAX($2)], agg#2=[MIN($3)], agg#3=[COUNT($4)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[MAX($2)], agg#2=[MIN($2)], agg#3=[COUNT()])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n LogicalFilter(condition=[AND(>=($2, 0), <($3, 20), <=($1, 10), =(/(CAST($1):DOUBLE NOT NULL, $4), 5))])",
+ "\n LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($2)], agg#1=[MAX($2)], agg#2=[MIN($2)], agg#3=[COUNT()])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+ "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -353,27 +230,12 @@
"output": [
"Execution Plan",
"\nLogicalProject(value1=[$0], count=[$1], SUM=[$2])",
- "\n LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(AVG_REDUCE($2, $1), 5))])",
- "\n LogicalAggregate(group=[{0}], agg#0=[COUNT($1)], agg#1=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
- "\n LogicalTableScan(table=[[a]])",
- "\n"
- ]
- },
- {
- "description": "SQL hint based group by optimization with select and aggregates with filters and select alias",
- "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ ts, KURTOSIS(a.col3) as kurtosis, SKEWNESS(a.col6) as skewness, COUNT(DISTINCT a.col6), COUNT(DISTINCT a.col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'metallica' GROUP BY ts",
- "output": [
- "Execution Plan",
- "\nLogicalAggregate(group=[{0}], agg#0=[KURTOSIS($1)], agg#1=[SKEWNESS($2)], agg#2=[DISTINCTCOUNT(DISTINCT $3)], agg#3=[DISTINCTCOUNT(DISTINCT $4)])",
- "\n LogicalAggregate(group=[{3}], agg#0=[FOURTHMOMENT($1)], agg#1=[FOURTHMOMENT($2)], agg#2=[DISTINCTCOUNT(DISTINCT $2)], agg#3=[DISTINCTCOUNT(DISTINCT $1)])",
- "\n PinotLogicalExchange(distribution=[hash[3]])",
- "\n LogicalProject(col2=[$1], col3=[$2], col6=[$5], ts=[$6])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'metallica'))])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])",
+ "\n LogicalAggregate(group=[{0}], count=[COUNT()], SUM=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+ "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
}
diff --git a/pinot-query-planner/src/test/resources/queries/JoinPlans.json b/pinot-query-planner/src/test/resources/queries/JoinPlans.json
index 6d7b76c4e9..03f8c9dbf7 100644
--- a/pinot-query-planner/src/test/resources/queries/JoinPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/JoinPlans.json
@@ -116,20 +116,19 @@
"sql": "EXPLAIN PLAN FOR SELECT a.col1, AVG(b.col3) FROM a JOIN b ON a.col1 = b.col2 WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0 GROUP BY a.col1",
"output": [
"Execution Plan",
- "\nLogicalProject(col1=[$0], EXPR$1=[AVG_REDUCE($1, $2)])",
+ "\nLogicalProject(col1=[$0], EXPR$1=[/(CAST($1):DOUBLE NOT NULL, $2)])",
"\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
- "\n LogicalJoin(condition=[=($0, $1)], joinType=[inner])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col1=[$0])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
- "\n LogicalTableScan(table=[[a]])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col2=[$1], col3=[$2])",
- "\n LogicalFilter(condition=[<($2, 0)])",
- "\n LogicalTableScan(table=[[b]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+ "\n LogicalJoin(condition=[=($0, $1)], joinType=[inner])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalProject(col1=[$0])",
+ "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+ "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalProject(col2=[$1], col3=[$2])",
+ "\n LogicalFilter(condition=[<($2, 0)])",
+ "\n LogicalTableScan(table=[[b]])",
"\n"
]
},
@@ -289,32 +288,26 @@
"\n LogicalFilter(condition=[=($1, 'test')])",
"\n LogicalTableScan(table=[[a]])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col3=[$0], $f1=[CAST($1):BOOLEAN NOT NULL])",
- "\n LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
- "\n LogicalProject(col3=[$2], $f1=[true])",
- "\n LogicalFilter(condition=[=($0, 'foo')])",
- "\n LogicalTableScan(table=[[b]])",
+ "\n LogicalProject(col3=[$2], $f1=[true])",
+ "\n LogicalFilter(condition=[=($0, 'foo')])",
+ "\n LogicalTableScan(table=[[b]])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col3=[$0], $f1=[CAST($1):BOOLEAN NOT NULL])",
- "\n LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
- "\n LogicalProject(col3=[$2], $f1=[true])",
- "\n LogicalFilter(condition=[=($0, 'bar')])",
- "\n LogicalTableScan(table=[[b]])",
+ "\n LogicalProject(col3=[$2], $f1=[true])",
+ "\n LogicalFilter(condition=[=($0, 'bar')])",
+ "\n LogicalTableScan(table=[[b]])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col3=[$0], $f1=[CAST($1):BOOLEAN NOT NULL])",
- "\n LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
- "\n LogicalProject(col3=[$2], $f1=[true])",
- "\n LogicalFilter(condition=[=($0, 'foobar')])",
- "\n LogicalTableScan(table=[[b]])",
+ "\n LogicalProject(col3=[$2], $f1=[true])",
+ "\n LogicalFilter(condition=[=($0, 'foobar')])",
+ "\n LogicalTableScan(table=[[b]])",
"\n"
]
},
@@ -328,12 +321,11 @@
"\n PinotLogicalExchange(distribution=[hash[0, 1]])",
"\n LogicalTableScan(table=[[a]])",
"\n PinotLogicalExchange(distribution=[hash[0, 1]])",
- "\n LogicalProject(col1=[$0], col2=[$1], col10=[CAST($0):VARCHAR], col20=[CAST($1):VARCHAR], $f2=[CAST($2):DOUBLE], EXPR$3=[*(0.5:DECIMAL(2, 1), $2)])",
+ "\n LogicalProject(col1=[$0], col2=[$1], col10=[CAST($0):VARCHAR], col20=[CAST($1):VARCHAR], $f2=[CAST($2):INTEGER], EXPR$3=[*(0.5:DECIMAL(2, 1), $2)])",
"\n LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
- "\n LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0, 1]])",
- "\n LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
- "\n LogicalTableScan(table=[[b]])",
+ "\n PinotLogicalExchange(distribution=[hash[0, 1]])",
+ "\n LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
+ "\n LogicalTableScan(table=[[b]])",
"\n"
]
},
@@ -349,10 +341,9 @@
"\n LogicalTableScan(table=[[a]])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{1}], agg#0=[$SUM0($2)])",
- "\n LogicalTableScan(table=[[b]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{1}], agg#0=[$SUM0($2)])",
+ "\n LogicalTableScan(table=[[b]])",
"\n"
]
}
diff --git a/pinot-query-planner/src/test/resources/queries/OrderByPlans.json b/pinot-query-planner/src/test/resources/queries/OrderByPlans.json
index f205f0bb89..a6b391ca1c 100644
--- a/pinot-query-planner/src/test/resources/queries/OrderByPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/OrderByPlans.json
@@ -60,10 +60,9 @@
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[0]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalSort(sort0=[$0], dir0=[ASC])",
"\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -75,11 +74,10 @@
"\nLogicalSort(sort0=[$0], dir0=[ASC], offset=[0])",
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[0]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalSort(sort0=[$0], dir0=[ASC])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col1=[$0], col3=[$2])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalProject(col1=[$0], col3=[$2])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -92,10 +90,9 @@
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[0]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalSort(sort0=[$0], dir0=[ASC])",
"\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -107,11 +104,10 @@
"\nLogicalSort(sort0=[$0], dir0=[ASC], offset=[0])",
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[0]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalSort(sort0=[$0], dir0=[ASC])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col1=[$0], col3=[$2])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n LogicalAggregate(group=[{0}], sum=[$SUM0($1)])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalProject(col1=[$0], col3=[$2])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
}
diff --git a/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json b/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json
index 08e4b2579a..f5ba9d3182 100644
--- a/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json
+++ b/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json
@@ -20,7 +20,7 @@
"sql": "EXPLAIN PLAN FOR SELECT /*+ joinOptions(is_colocated_by_join_keys='true'), aggOptions(is_partitioned_by_group_by_keys='true') */a.col1, AVG(b.col3) FROM a JOIN b ON a.col1 = b.col2 WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0 GROUP BY a.col1",
"output": [
"Execution Plan",
- "\nLogicalProject(col1=[$0], EXPR$1=[AVG_REDUCE($1, $2)])",
+ "\nLogicalProject(col1=[$0], EXPR$1=[/(CAST($1):DOUBLE NOT NULL, $2)])",
"\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
"\n LogicalJoin(condition=[=($0, $1)], joinType=[inner])",
"\n PinotLogicalExchange(distribution=[single])",
@@ -104,17 +104,16 @@
"output": [
"Execution Plan",
"\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{1}], agg#0=[$SUM0($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalJoin(condition=[=($0, $3)], joinType=[semi])",
- "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
- "\n LogicalTableScan(table=[[a]])",
- "\n PinotLogicalExchange(distribution=[broadcast], relExchangeType=[PIPELINE_BREAKER])",
- "\n LogicalProject(col2=[$1], col3=[$2])",
- "\n LogicalFilter(condition=[>($2, 0)])",
- "\n LogicalTableScan(table=[[b]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{1}], agg#0=[$SUM0($2)])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalJoin(condition=[=($0, $3)], joinType=[semi])",
+ "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+ "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[broadcast], relExchangeType=[PIPELINE_BREAKER])",
+ "\n LogicalProject(col2=[$1], col3=[$2])",
+ "\n LogicalFilter(condition=[>($2, 0)])",
+ "\n LogicalTableScan(table=[[b]])",
"\n"
]
},
@@ -124,12 +123,11 @@
"output": [
"Execution Plan",
"\nLogicalProject(col2=[$1], col1=[$0], EXPR$2=[$2])",
- "\n LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
- "\n LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0, 1]])",
- "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($0, 'a'))])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n LogicalAggregate(group=[{0, 1}], EXPR$2=[$SUM0($2)])",
+ "\n PinotLogicalExchange(distribution=[hash[0, 1]])",
+ "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+ "\n LogicalFilter(condition=[AND(>=($2, 0), =($0, 'a'))])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -139,13 +137,12 @@
"output": [
"Execution Plan",
"\nLogicalProject(col2=[$0], EXPR$1=[$1], EXPR$2=[$2], EXPR$3=[$3])",
- "\n LogicalFilter(condition=[AND(>($1, 10), >=($4, 0), <($5, 20), <=($2, 10), =(AVG_REDUCE($2, $1), 5))])",
- "\n LogicalAggregate(group=[{0}], agg#0=[COUNT($1)], agg#1=[$SUM0($2)], agg#2=[$SUM0($3)], agg#3=[MAX($4)], agg#4=[MIN($5)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($1)], agg#2=[$SUM0($2)], agg#3=[MAX($1)], agg#4=[MIN($1)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col2=[$1], col3=[$2], $f2=[CAST($0):DECIMAL(1000, 500) NOT NULL])",
- "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n LogicalFilter(condition=[AND(>($1, 10), >=($4, 0), <($5, 20), <=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])",
+ "\n LogicalAggregate(group=[{0}], EXPR$1=[COUNT()], EXPR$2=[$SUM0($1)], EXPR$3=[$SUM0($2)], agg#3=[MAX($1)], agg#4=[MIN($1)])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalProject(col2=[$1], col3=[$2], $f2=[CAST($0):DECIMAL(1000, 500) NOT NULL])",
+ "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -173,23 +170,6 @@
"\n LogicalTableScan(table=[[a]])",
"\n"
]
- },
- {
- "description": "Join and agg options",
- "sql": "EXPLAIN PLAN FOR SELECT /*+ joinOptions(is_colocated_by_join_keys='true'), aggOptions(is_partitioned_by_group_by_keys='true') */ a.col3, a.col1, SUM(b.col3) FROM a JOIN b ON a.col3 = b.col3 GROUP BY a.col3, a.col1",
- "output": [
- "Execution Plan",
- "\nLogicalProject(col3=[$1], col1=[$0], EXPR$2=[$2])",
- "\n LogicalAggregate(group=[{0, 1}], EXPR$2=[$SUM0($2)])",
- "\n LogicalJoin(condition=[=($1, $2)], joinType=[inner])",
- "\n PinotLogicalExchange(distribution=[single])",
- "\n LogicalProject(col1=[$0], col3=[$2])",
- "\n LogicalTableScan(table=[[a]])",
- "\n PinotLogicalExchange(distribution=[single])",
- "\n LogicalProject(col3=[$2])",
- "\n LogicalTableScan(table=[[b]])",
- "\n"
- ]
}
]
}
diff --git a/pinot-query-planner/src/test/resources/queries/WindowFunctionPlans.json b/pinot-query-planner/src/test/resources/queries/WindowFunctionPlans.json
index ed3f067cac..262a73b52b 100644
--- a/pinot-query-planner/src/test/resources/queries/WindowFunctionPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/WindowFunctionPlans.json
@@ -264,10 +264,9 @@
"\n LogicalWindow(window#0=[window(aggs [MIN($0)])])",
"\n PinotLogicalExchange(distribution=[hash])",
"\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -279,12 +278,11 @@
"\nLogicalProject(EXPR$0=[$1], EXPR$1=[/(CAST($2):DOUBLE NOT NULL, $3)])",
"\n LogicalWindow(window#0=[window(aggs [SUM($0), COUNT($0)])])",
"\n PinotLogicalExchange(distribution=[hash])",
- "\n LogicalProject(col3=[$0], EXPR$0=[AVG_REDUCE($1, $2)])",
+ "\n LogicalProject(col3=[$0], EXPR$0=[/(CAST($1):DOUBLE NOT NULL, $2)])",
"\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -299,12 +297,11 @@
"\n LogicalProject(EXPR$0=[$1], EXPR$1=[/(CAST($2):DOUBLE NOT NULL, $3)], col3=[$0])",
"\n LogicalWindow(window#0=[window(aggs [SUM($0), COUNT($0)])])",
"\n PinotLogicalExchange(distribution=[hash])",
- "\n LogicalProject(col3=[$0], EXPR$0=[AVG_REDUCE($1, $2)])",
+ "\n LogicalProject(col3=[$0], EXPR$0=[/(CAST($1):DOUBLE NOT NULL, $2)])",
"\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -319,12 +316,11 @@
"\n LogicalProject(EXPR$0=[$1], EXPR$1=[$2], col3=[$0])",
"\n LogicalWindow(window#0=[window( rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])",
"\n PinotLogicalExchange(distribution=[hash])",
- "\n LogicalProject(col3=[$0], EXPR$0=[AVG_REDUCE($1, $2)])",
+ "\n LogicalProject(col3=[$0], EXPR$0=[/(CAST($1):DOUBLE NOT NULL, $2)])",
"\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -525,10 +521,9 @@
"\n LogicalWindow(window#0=[window(aggs [MIN($0), MAX($0)])])",
"\n PinotLogicalExchange(distribution=[hash])",
"\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -537,15 +532,14 @@
"sql": "EXPLAIN PLAN FOR SELECT AVG(a.col3), AVG(a.col3) OVER(), SUM(a.col3) OVER() FROM a GROUP BY a.col3",
"output": [
"Execution Plan",
- "\nLogicalProject(EXPR$0=[$1], EXPR$1=[/(CAST($2):DOUBLE NOT NULL, $3)], EXPR$2=[$4])",
- "\n LogicalWindow(window#0=[window(aggs [SUM($0), COUNT($0), SUM($0)])])",
+ "\nLogicalProject(EXPR$0=[$1], EXPR$1=[/(CAST($2):DOUBLE NOT NULL, $3)], EXPR$2=[$2])",
+ "\n LogicalWindow(window#0=[window(aggs [SUM($0), COUNT($0)])])",
"\n PinotLogicalExchange(distribution=[hash])",
- "\n LogicalProject(col3=[$0], EXPR$0=[AVG_REDUCE($1, $2)])",
+ "\n LogicalProject(col3=[$0], EXPR$0=[/(CAST($1):DOUBLE NOT NULL, $2)])",
"\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -557,15 +551,14 @@
"\nLogicalSort(sort0=[$3], dir0=[ASC], offset=[0])",
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[3]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalSort(sort0=[$3], dir0=[ASC])",
- "\n LogicalProject(EXPR$0=[$1], EXPR$1=[/(CAST($2):DOUBLE NOT NULL, $3)], EXPR$2=[$4], col3=[$0])",
- "\n LogicalWindow(window#0=[window(aggs [SUM($0), COUNT($0), SUM($0)])])",
+ "\n LogicalProject(EXPR$0=[$1], EXPR$1=[/(CAST($2):DOUBLE NOT NULL, $3)], EXPR$2=[$2], col3=[$0])",
+ "\n LogicalWindow(window#0=[window(aggs [SUM($0), COUNT($0)])])",
"\n PinotLogicalExchange(distribution=[hash])",
- "\n LogicalProject(col3=[$0], EXPR$0=[AVG_REDUCE($1, $2)])",
+ "\n LogicalProject(col3=[$0], EXPR$0=[/(CAST($1):DOUBLE NOT NULL, $2)])",
"\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -861,10 +854,9 @@
"\n LogicalWindow(window#0=[window(partition {0} aggs [MIN($0)])])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -876,12 +868,11 @@
"\nLogicalProject(EXPR$0=[$1], EXPR$1=[/(CAST($2):DOUBLE NOT NULL, $3)])",
"\n LogicalWindow(window#0=[window(partition {0} aggs [SUM($0), COUNT($0)])])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col3=[$0], EXPR$0=[AVG_REDUCE($1, $2)])",
+ "\n LogicalProject(col3=[$0], EXPR$0=[/(CAST($1):DOUBLE NOT NULL, $2)])",
"\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -896,12 +887,11 @@
"\n LogicalProject(EXPR$0=[$1], EXPR$1=[/(CAST($2):DOUBLE NOT NULL, $3)], col3=[$0])",
"\n LogicalWindow(window#0=[window(partition {0} aggs [SUM($0), COUNT($0)])])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col3=[$0], EXPR$0=[AVG_REDUCE($1, $2)])",
+ "\n LogicalProject(col3=[$0], EXPR$0=[/(CAST($1):DOUBLE NOT NULL, $2)])",
"\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -1044,8 +1034,8 @@
"\nLogicalSort(sort0=[$0], dir0=[ASC], offset=[0])",
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[0]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalSort(sort0=[$0], dir0=[ASC])",
- "\n LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($4):DOUBLE NOT NULL, $5)])",
- "\n LogicalWindow(window#0=[window(partition {0, 1} aggs [SUM($2), SUM($2), COUNT($2)])])",
+ "\n LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($3):DOUBLE NOT NULL, $4)])",
+ "\n LogicalWindow(window#0=[window(partition {0, 1} aggs [SUM($2), COUNT($2)])])",
"\n PinotLogicalExchange(distribution=[hash[0, 1]])",
"\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
"\n LogicalTableScan(table=[[a]])",
@@ -1093,8 +1083,8 @@
"\nLogicalSort(sort0=[$3], dir0=[ASC], offset=[0])",
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[3]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalSort(sort0=[$3], dir0=[ASC])",
- "\n LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($4):DOUBLE NOT NULL, $5)], col2=[$1])",
- "\n LogicalWindow(window#0=[window(partition {0, 1} aggs [SUM($2), SUM($2), COUNT($2)])])",
+ "\n LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($3):DOUBLE NOT NULL, $4)], col2=[$1])",
+ "\n LogicalWindow(window#0=[window(partition {0, 1} aggs [SUM($2), COUNT($2)])])",
"\n PinotLogicalExchange(distribution=[hash[0, 1]])",
"\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
"\n LogicalTableScan(table=[[a]])",
@@ -1194,10 +1184,9 @@
"\n LogicalWindow(window#0=[window(partition {0} aggs [MIN($0), SUM($0)])])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -1209,12 +1198,11 @@
"\nLogicalProject(EXPR$0=[$1], EXPR$1=[/(CAST($2):DOUBLE NOT NULL, $3)], EXPR$2=[$4])",
"\n LogicalWindow(window#0=[window(partition {0} aggs [SUM($0), COUNT($0), MAX($0)])])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col3=[$0], EXPR$0=[AVG_REDUCE($1, $2)])",
+ "\n LogicalProject(col3=[$0], EXPR$0=[/(CAST($1):DOUBLE NOT NULL, $2)])",
"\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -1229,12 +1217,11 @@
"\n LogicalProject(EXPR$0=[$1], EXPR$1=[/(CAST($2):DOUBLE NOT NULL, $3)], EXPR$2=[$4], col3=[$0])",
"\n LogicalWindow(window#0=[window(partition {0} aggs [SUM($0), COUNT($0), MAX($0)])])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col3=[$0], EXPR$0=[AVG_REDUCE($1, $2)])",
+ "\n LogicalProject(col3=[$0], EXPR$0=[/(CAST($1):DOUBLE NOT NULL, $2)])",
"\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -1839,8 +1826,8 @@
"\nLogicalSort(sort0=[$0], dir0=[DESC], offset=[0])",
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[0 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalSort(sort0=[$0], dir0=[DESC])",
- "\n LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($4):DOUBLE NOT NULL, $5)])",
- "\n LogicalWindow(window#0=[window(order by [1, 0 DESC] aggs [SUM($2), SUM($2), COUNT($2)])])",
+ "\n LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($3):DOUBLE NOT NULL, $4)])",
+ "\n LogicalWindow(window#0=[window(order by [1, 0 DESC] aggs [SUM($2), COUNT($2)])])",
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[1, 0 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
"\n LogicalTableScan(table=[[a]])",
@@ -1887,8 +1874,8 @@
"\nLogicalSort(sort0=[$0], dir0=[DESC], offset=[0], fetch=[10])",
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[0 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalSort(sort0=[$0], dir0=[DESC], fetch=[10])",
- "\n LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($4):DOUBLE NOT NULL, $5)])",
- "\n LogicalWindow(window#0=[window(order by [1, 0 DESC] aggs [SUM($2), SUM($2), COUNT($2)])])",
+ "\n LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($3):DOUBLE NOT NULL, $4)])",
+ "\n LogicalWindow(window#0=[window(order by [1, 0 DESC] aggs [SUM($2), COUNT($2)])])",
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[1, 0 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
"\n LogicalTableScan(table=[[a]])",
@@ -2519,8 +2506,8 @@
"\nLogicalSort(sort0=[$3], sort1=[$0], dir0=[ASC], dir1=[ASC], offset=[0])",
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[3, 0]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalSort(sort0=[$3], sort1=[$0], dir0=[ASC], dir1=[ASC])",
- "\n LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($4):DOUBLE NOT NULL, $5)], col2=[$1])",
- "\n LogicalWindow(window#0=[window(partition {0, 1} order by [1, 0] aggs [SUM($2), SUM($2), COUNT($2)])])",
+ "\n LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($3):DOUBLE NOT NULL, $4)], col2=[$1])",
+ "\n LogicalWindow(window#0=[window(partition {0, 1} order by [1, 0] aggs [SUM($2), COUNT($2)])])",
"\n PinotLogicalExchange(distribution=[hash[0, 1]])",
"\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
"\n LogicalTableScan(table=[[a]])",
@@ -2567,8 +2554,8 @@
"\nLogicalSort(sort0=[$3], sort1=[$0], dir0=[ASC], dir1=[ASC], offset=[0], fetch=[10])",
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[3, 0]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalSort(sort0=[$3], sort1=[$0], dir0=[ASC], dir1=[ASC], fetch=[10])",
- "\n LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($4):DOUBLE NOT NULL, $5)], col2=[$1])",
- "\n LogicalWindow(window#0=[window(partition {0, 1} order by [1, 0] aggs [SUM($2), SUM($2), COUNT($2)])])",
+ "\n LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($3):DOUBLE NOT NULL, $4)], col2=[$1])",
+ "\n LogicalWindow(window#0=[window(partition {0, 1} order by [1, 0] aggs [SUM($2), COUNT($2)])])",
"\n PinotLogicalExchange(distribution=[hash[0, 1]])",
"\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
"\n LogicalTableScan(table=[[a]])",
@@ -3186,8 +3173,8 @@
"\nLogicalSort(sort0=[$3], sort1=[$0], dir0=[ASC], dir1=[DESC], offset=[0])",
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[3, 0 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalSort(sort0=[$3], sort1=[$0], dir0=[ASC], dir1=[DESC])",
- "\n LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($4):DOUBLE NOT NULL, $5)], col2=[$1])",
- "\n LogicalWindow(window#0=[window(partition {0, 1} order by [2, 0] aggs [SUM($2), SUM($2), COUNT($2)])])",
+ "\n LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($3):DOUBLE NOT NULL, $4)], col2=[$1])",
+ "\n LogicalWindow(window#0=[window(partition {0, 1} order by [2, 0] aggs [SUM($2), COUNT($2)])])",
"\n PinotLogicalSortExchange(distribution=[hash[0, 1]], collation=[[2, 0]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
"\n LogicalTableScan(table=[[a]])",
@@ -3234,8 +3221,8 @@
"\nLogicalSort(sort0=[$3], sort1=[$0], dir0=[ASC], dir1=[DESC], offset=[0], fetch=[10])",
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[3, 0 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalSort(sort0=[$3], sort1=[$0], dir0=[ASC], dir1=[DESC], fetch=[10])",
- "\n LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($4):DOUBLE NOT NULL, $5)], col2=[$1])",
- "\n LogicalWindow(window#0=[window(partition {0, 1} order by [2, 0] aggs [SUM($2), SUM($2), COUNT($2)])])",
+ "\n LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($3):DOUBLE NOT NULL, $4)], col2=[$1])",
+ "\n LogicalWindow(window#0=[window(partition {0, 1} order by [2, 0] aggs [SUM($2), COUNT($2)])])",
"\n PinotLogicalSortExchange(distribution=[hash[0, 1]], collation=[[2, 0]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
"\n LogicalTableScan(table=[[a]])",
@@ -3395,10 +3382,9 @@
"\n LogicalWindow(window#0=[window(order by [2 DESC, 0] aggs [SUM($1), COUNT($1)])])",
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[2 DESC, 0]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)])",
- "\n LogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0, 1]])",
- "\n LogicalAggregate(group=[{0, 2}], agg#0=[COUNT()])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0, 1]])",
+ "\n LogicalAggregate(group=[{0, 2}], agg#0=[COUNT()])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -3411,10 +3397,9 @@
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[1 DESC, 0]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalProject(col1=[$0], EXPR$1=[$2])",
"\n LogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)])",
- "\n LogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0, 1]])",
- "\n LogicalAggregate(group=[{0, 2}], agg#0=[COUNT()])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0, 1]])",
+ "\n LogicalAggregate(group=[{0, 2}], agg#0=[COUNT()])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -3427,10 +3412,9 @@
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[1 DESC, 0]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalProject(col1=[$0], EXPR$1=[$2])",
"\n LogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)])",
- "\n LogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0, 1]])",
- "\n LogicalAggregate(group=[{0, 2}], agg#0=[COUNT()])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0, 1]])",
+ "\n LogicalAggregate(group=[{0, 2}], agg#0=[COUNT()])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -3443,10 +3427,9 @@
"\n LogicalWindow(window#0=[window(partition {0} order by [2 DESC, 0] aggs [MAX($1)])])",
"\n PinotLogicalSortExchange(distribution=[hash[0]], collation=[[2 DESC, 0]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)])",
- "\n LogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0, 1]])",
- "\n LogicalAggregate(group=[{0, 2}], agg#0=[COUNT()])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0, 1]])",
+ "\n LogicalAggregate(group=[{0, 2}], agg#0=[COUNT()])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -3478,24 +3461,6 @@
"\n"
]
},
- {
- "description": "Window function CTE: row_number WITH statement having OVER with PARTITION BY ORDER BY and aggregations and transforms",
- "sql": "EXPLAIN PLAN FOR WITH windowfunc AS (SELECT REVERSE(a.col1) as rev, ROW_NUMBER() OVER(PARTITION BY a.col2 ORDER BY a.col3) as rownum, a.col6 from a) SELECT rev, a.rownum, MAX(a.col6) FROM windowfunc AS a where a.rownum < 5 GROUP BY rev, a.rownum",
- "output": [
- "Execution Plan",
- "\nLogicalAggregate(group=[{0, 1}], agg#0=[MAX($2)])",
- "\n LogicalAggregate(group=[{0, 1}], agg#0=[MAX($2)])",
- "\n PinotLogicalExchange(distribution=[hash[0, 1]])",
- "\n LogicalAggregate(group=[{0, 1}], agg#0=[MAX($2)])",
- "\n LogicalProject(rev=[$3], rownum=[$4], col6=[$2])",
- "\n LogicalFilter(condition=[<($4, 5)])",
- "\n LogicalWindow(window#0=[window(partition {0} order by [1] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])",
- "\n PinotLogicalSortExchange(distribution=[hash[0]], collation=[[1]], isSortOnSender=[false], isSortOnReceiver=[true])",
- "\n LogicalProject(col2=[$1], col3=[$2], col6=[$5], $3=[REVERSE($0)])",
- "\n LogicalTableScan(table=[[a]])",
- "\n"
- ]
- },
{
"description": "Window function subquery: row_number having OVER with PARTITION BY ORDER BY",
"sql": "EXPLAIN PLAN FOR SELECT row_number, col2, col3 FROM (SELECT ROW_NUMBER() OVER(PARTITION BY a.col2 ORDER BY a.col3 DESC) as row_number, a.col2, a.col3 FROM a) WHERE row_number <= 10",
@@ -3519,10 +3484,9 @@
"\n LogicalWindow(window#0=[window(order by [1 DESC] aggs [RANK()])])",
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[1 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
"\n LogicalAggregate(group=[{0}], agg#0=[COUNT($1)])",
- "\n LogicalAggregate(group=[{0}], agg#0=[COUNT($1)])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalAggregate(group=[{0}], agg#0=[COUNT()])",
- "\n LogicalTableScan(table=[[a]])",
+ "\n PinotLogicalExchange(distribution=[hash[0]])",
+ "\n LogicalAggregate(group=[{0}], agg#0=[COUNT()])",
+ "\n LogicalTableScan(table=[[a]])",
"\n"
]
},
@@ -3569,32 +3533,6 @@
"\n LogicalTableScan(table=[[b]])",
"\n"
]
- },
- {
- "description": "Window function subquery with join using row_number and aggregation",
- "sql": "EXPLAIN PLAN FOR SELECT row_number, col2, SUM(col3), KURTOSIS(MULT(col3, col6)) FROM (SELECT a.col2 as col2, a.col3 as col3, a.col6 as col6, ROW_NUMBER() OVER(PARTITION BY a.col2 ORDER BY a.col3 DESC) as row_number FROM a INNER JOIN b ON a.col1 = b.col2 WHERE a.col3 > 100 AND b.col1 IN ('douglas adams', 'brandon sanderson')) where row_number <= 3 GROUP BY row_number, col2",
- "output": [
- "Execution Plan",
- "\nLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], agg#1=[KURTOSIS($3)])",
- "\n LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], agg#1=[FOURTHMOMENT($3)])",
- "\n PinotLogicalExchange(distribution=[hash[0, 1]])",
- "\n LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], agg#1=[FOURTHMOMENT($3)])",
- "\n LogicalProject(row_number=[$3], col2=[$0], col3=[$1], $f3=[MULT($1, $2)])",
- "\n LogicalFilter(condition=[<=($3, 3)])",
- "\n LogicalWindow(window#0=[window(partition {0} order by [1 DESC] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])",
- "\n PinotLogicalSortExchange(distribution=[hash[0]], collation=[[1 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
- "\n LogicalProject(col2=[$1], col3=[$2], col6=[$3])",
- "\n LogicalJoin(condition=[=($0, $4)], joinType=[inner])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2], col6=[$5])",
- "\n LogicalFilter(condition=[>($2, 100)])",
- "\n LogicalTableScan(table=[[a]])",
- "\n PinotLogicalExchange(distribution=[hash[0]])",
- "\n LogicalProject(col2=[$1])",
- "\n LogicalFilter(condition=[OR(=($0, 'brandon sanderson':VARCHAR(17)), =($0, 'douglas adams':VARCHAR(17)))])",
- "\n LogicalTableScan(table=[[b]])",
- "\n"
- ]
}
]
},
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
index 7dd9a546c2..1cc0bc5d19 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
@@ -19,17 +19,21 @@
package org.apache.pinot.query.runtime.operator;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.pinot.common.datablock.DataBlock;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.FunctionContext;
import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.common.IntermediateStageBlockValSet;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionFactory;
import org.apache.pinot.query.planner.logical.RexExpression;
@@ -71,7 +75,7 @@ public class AggregateOperator extends MultiStageOperator {
// Map that maintains the mapping between columnName and the column ordinal index. It is used to fetch the required
// column value from row-based container and fetch the input datatype for the column.
- private final HashMap<String, Integer> _colNameToIndexMap;
+ private final Map<String, Integer> _colNameToIndexMap;
private TransferableBlock _upstreamErrorBlock;
private boolean _readyToConstruct;
@@ -112,10 +116,12 @@ public class AggregateOperator extends MultiStageOperator {
// Initialize the appropriate executor.
if (!groupSet.isEmpty()) {
_isGroupByAggregation = true;
- _groupByExecutor = new MultistageGroupByExecutor(groupByExpr, aggFunctions, aggType, _colNameToIndexMap);
+ _groupByExecutor = new MultistageGroupByExecutor(groupByExpr, aggFunctions, aggType, _colNameToIndexMap,
+ _resultSchema);
} else {
_isGroupByAggregation = false;
- _aggregationExecutor = new MultistageAggregationExecutor(aggFunctions, aggType, _colNameToIndexMap);
+ _aggregationExecutor = new MultistageAggregationExecutor(aggFunctions, aggType, _colNameToIndexMap,
+ _resultSchema);
}
}
@@ -267,4 +273,39 @@ public class AggregateOperator extends MultiStageOperator {
return exprContext;
}
+
+ static Map<ExpressionContext, BlockValSet> getBlockValSetMap(AggregationFunction aggFunction,
+ TransferableBlock block, DataSchema inputDataSchema, Map<String, Integer> colNameToIndexMap) {
+ List<ExpressionContext> expressions = aggFunction.getInputExpressions();
+ int numExpressions = expressions.size();
+ if (numExpressions == 0) {
+ return Collections.emptyMap();
+ }
+
+ Preconditions.checkState(numExpressions == 1, "Cannot handle more than one identifier in aggregation function.");
+ ExpressionContext expression = expressions.get(0);
+ Preconditions.checkState(expression.getType().equals(ExpressionContext.Type.IDENTIFIER));
+ int index = colNameToIndexMap.get(expression.getIdentifier());
+
+ DataSchema.ColumnDataType dataType = inputDataSchema.getColumnDataType(index);
+ Preconditions.checkState(block.getType().equals(DataBlock.Type.ROW), "Datablock type is not ROW");
+ // TODO: If the previous block is not mailbox received, this method is not efficient. Then getDataBlock() will
+ // convert the unserialized format to serialized format of BaseDataBlock. Then it will convert it back to column
+ // value primitive type.
+ return Collections.singletonMap(expression,
+ new IntermediateStageBlockValSet(dataType, block.getDataBlock(), index));
+ }
+
+ static Object extractValueFromRow(AggregationFunction aggregationFunction, Object[] row,
+ Map<String, Integer> colNameToIndexMap) {
+ List<ExpressionContext> expressions = aggregationFunction.getInputExpressions();
+ Preconditions.checkState(expressions.size() == 1, "Only support single expression, got: %s", expressions.size());
+ ExpressionContext expr = expressions.get(0);
+ ExpressionContext.Type exprType = expr.getType();
+ if (exprType == ExpressionContext.Type.IDENTIFIER) {
+ return row[colNameToIndexMap.get(expr.getIdentifier())];
+ }
+ Preconditions.checkState(exprType == ExpressionContext.Type.LITERAL, "Unsupported expression type: %s", exprType);
+ return expr.getLiteral().getValue();
+ }
}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/FilterOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/FilterOperator.java
index 2851aa528a..d2ec50129d 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/FilterOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/FilterOperator.java
@@ -28,7 +28,7 @@ import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
import org.apache.pinot.query.runtime.operator.operands.TransformOperand;
-import org.apache.pinot.query.runtime.operator.utils.FunctionInvokeUtils;
+import org.apache.pinot.query.runtime.operator.utils.TypeUtils;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -100,7 +100,7 @@ public class FilterOperator extends MultiStageOperator {
List<Object[]> resultRows = new ArrayList<>();
List<Object[]> container = block.getContainer();
for (Object[] row : container) {
- if ((Boolean) FunctionInvokeUtils.convert(_filterOperand.apply(row), DataSchema.ColumnDataType.BOOLEAN)) {
+ if ((Boolean) TypeUtils.convert(_filterOperand.apply(row), DataSchema.ColumnDataType.BOOLEAN)) {
resultRows.add(row);
}
}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
index 61798f3dc7..8d5e5b1a06 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
@@ -38,7 +38,7 @@ import org.apache.pinot.query.planner.plannode.JoinNode;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
import org.apache.pinot.query.runtime.operator.operands.TransformOperand;
-import org.apache.pinot.query.runtime.operator.utils.FunctionInvokeUtils;
+import org.apache.pinot.query.runtime.operator.utils.TypeUtils;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -245,7 +245,7 @@ public class HashJoinOperator extends MultiStageOperator {
// TODO: Optimize this to avoid unnecessary object copy.
Object[] resultRow = joinRow(leftRow, rightRow);
if (_joinClauseEvaluators.isEmpty() || _joinClauseEvaluators.stream().allMatch(
- evaluator -> (Boolean) FunctionInvokeUtils.convert(evaluator.apply(resultRow),
+ evaluator -> (Boolean) TypeUtils.convert(evaluator.apply(resultRow),
DataSchema.ColumnDataType.BOOLEAN))) {
rows.add(resultRow);
hasMatchForLeftRow = true;
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/LeafStageTransferableBlockOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/LeafStageTransferableBlockOperator.java
index 64e9117a92..1f7069ea45 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/LeafStageTransferableBlockOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/LeafStageTransferableBlockOperator.java
@@ -43,6 +43,7 @@ import org.apache.pinot.core.operator.blocks.results.SelectionResultsBlock;
import org.apache.pinot.core.query.request.ServerQueryRequest;
import org.apache.pinot.core.query.selection.SelectionOperatorUtils;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
+import org.apache.pinot.query.runtime.operator.utils.TypeUtils;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -254,12 +255,12 @@ public class LeafStageTransferableBlockOperator extends MultiStageOperator {
List<Object[]> extractedRows = new ArrayList<>(resultRows.size());
if (resultRows instanceof List) {
for (Object[] row : resultRows) {
- extractedRows.add(canonicalizeRow(row, desiredDataSchema, columnIndices));
+ extractedRows.add(TypeUtils.canonicalizeRow(row, desiredDataSchema, columnIndices));
}
} else if (resultRows instanceof PriorityQueue) {
PriorityQueue<Object[]> priorityQueue = (PriorityQueue<Object[]>) resultRows;
while (!priorityQueue.isEmpty()) {
- extractedRows.add(canonicalizeRow(priorityQueue.poll(), desiredDataSchema, columnIndices));
+ extractedRows.add(TypeUtils.canonicalizeRow(priorityQueue.poll(), desiredDataSchema, columnIndices));
}
}
return new TransferableBlock(extractedRows, desiredDataSchema, DataBlock.Type.ROW);
@@ -277,12 +278,12 @@ public class LeafStageTransferableBlockOperator extends MultiStageOperator {
List<Object[]> extractedRows = new ArrayList<>(resultRows.size());
if (resultRows instanceof List) {
for (Object[] orgRow : resultRows) {
- extractedRows.add(canonicalizeRow(orgRow, desiredDataSchema));
+ extractedRows.add(TypeUtils.canonicalizeRow(orgRow, desiredDataSchema));
}
} else if (resultRows instanceof PriorityQueue) {
PriorityQueue<Object[]> priorityQueue = (PriorityQueue<Object[]>) resultRows;
while (!priorityQueue.isEmpty()) {
- extractedRows.add(canonicalizeRow(priorityQueue.poll(), desiredDataSchema));
+ extractedRows.add(TypeUtils.canonicalizeRow(priorityQueue.poll(), desiredDataSchema));
}
} else {
throw new UnsupportedOperationException("Unsupported collection type: " + resultRows.getClass());
@@ -299,41 +300,6 @@ public class LeafStageTransferableBlockOperator extends MultiStageOperator {
return true;
}
- /**
- * This util is used to canonicalize row generated from V1 engine, which is stored using
- * {@link DataSchema#getStoredColumnDataTypes()} format. However, the transferable block ser/de stores data in the
- * {@link DataSchema#getColumnDataTypes()} format.
- *
- * @param row un-canonicalize row.
- * @param dataSchema data schema desired for the row.
- * @return canonicalize row.
- */
- private static Object[] canonicalizeRow(Object[] row, DataSchema dataSchema) {
- Object[] resultRow = new Object[row.length];
- for (int colId = 0; colId < row.length; colId++) {
- Object value = row[colId];
- if (value != null) {
- if (dataSchema.getColumnDataType(colId) == DataSchema.ColumnDataType.OBJECT) {
- resultRow[colId] = value;
- } else {
- resultRow[colId] = dataSchema.getColumnDataType(colId).convert(value);
- }
- }
- }
- return resultRow;
- }
-
- private static Object[] canonicalizeRow(Object[] row, DataSchema dataSchema, int[] columnIndices) {
- Object[] resultRow = new Object[columnIndices.length];
- for (int colId = 0; colId < columnIndices.length; colId++) {
- Object value = row[columnIndices[colId]];
- if (value != null) {
- resultRow[colId] = dataSchema.getColumnDataType(colId).convert(value);
- }
- }
- return resultRow;
- }
-
private static boolean isDataSchemaColumnTypesCompatible(DataSchema.ColumnDataType[] desiredTypes,
DataSchema.ColumnDataType[] givenTypes) {
if (desiredTypes.length != givenTypes.length) {
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java
index 8ea4ab0a91..19e4f66cc6 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java
@@ -18,19 +18,17 @@
*/
package org.apache.pinot.query.runtime.operator;
-import com.google.common.base.Preconditions;
import java.util.Collections;
import java.util.List;
import java.util.Map;
-import org.apache.pinot.common.datablock.DataBlock;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
-import org.apache.pinot.core.common.IntermediateStageBlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.query.planner.plannode.AggregateNode.AggType;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
+import org.apache.pinot.query.runtime.operator.utils.TypeUtils;
/**
@@ -41,23 +39,23 @@ public class MultistageAggregationExecutor {
// The identifier operands for the aggregation function only store the column name. This map contains mapping
// from column name to their index.
private final Map<String, Integer> _colNameToIndexMap;
+ private final DataSchema _resultSchema;
private final AggregationFunction[] _aggFunctions;
// Result holders for each mode.
private final AggregationResultHolder[] _aggregateResultHolder;
private final Object[] _mergeResultHolder;
- private final Object[] _finalResultHolder;
- public MultistageAggregationExecutor(AggregationFunction[] aggFunctions, AggType aggType,
- Map<String, Integer> colNameToIndexMap) {
+ public MultistageAggregationExecutor(AggregationFunction[] aggFunctions,
+ AggType aggType, Map<String, Integer> colNameToIndexMap, DataSchema resultSchema) {
_aggFunctions = aggFunctions;
_aggType = aggType;
_colNameToIndexMap = colNameToIndexMap;
+ _resultSchema = resultSchema;
_aggregateResultHolder = new AggregationResultHolder[aggFunctions.length];
_mergeResultHolder = new Object[aggFunctions.length];
- _finalResultHolder = new Object[aggFunctions.length];
for (int i = 0; i < _aggFunctions.length; i++) {
_aggregateResultHolder[i] = _aggFunctions[i].createAggregationResultHolder();
@@ -70,10 +68,8 @@ public class MultistageAggregationExecutor {
public void processBlock(TransferableBlock block, DataSchema inputDataSchema) {
if (!_aggType.isInputIntermediateFormat()) {
processAggregate(block, inputDataSchema);
- } else if (_aggType.isOutputIntermediateFormat()) {
- processMerge(block);
} else {
- collectResult(block);
+ processMerge(block);
}
}
@@ -93,35 +89,37 @@ public class MultistageAggregationExecutor {
* Fetches the result.
*/
public List<Object[]> getResult() {
- int numFunctions = _aggFunctions.length;
- Object[] row = new Object[numFunctions];
- for (int i = 0; i < numFunctions; i++) {
- AggregationFunction func = _aggFunctions[i];
- if (!_aggType.isInputIntermediateFormat()) {
- Object intermediateResult = func.extractAggregationResult(_aggregateResultHolder[i]);
- if (_aggType.isOutputIntermediateFormat()) {
- row[i] = intermediateResult;
- } else {
- Object finalResult = func.extractFinalResult(intermediateResult);
- row[i] = finalResult == null ? null : func.getFinalResultColumnType().convert(finalResult);
- }
- } else {
- if (_aggType.isOutputIntermediateFormat()) {
- row[i] = _mergeResultHolder[i];
- } else {
- Object finalResult = func.extractFinalResult(_finalResultHolder[i]);
- row[i] = finalResult == null ? null : func.getFinalResultColumnType().convert(finalResult);
- }
+ Object[] row = new Object[_aggFunctions.length];
+ for (int i = 0; i < _aggFunctions.length; i++) {
+ AggregationFunction aggFunction = _aggFunctions[i];
+ Object value;
+ switch (_aggType) {
+ case LEAF:
+ value = aggFunction.extractAggregationResult(_aggregateResultHolder[i]);
+ break;
+ case INTERMEDIATE:
+ value = _mergeResultHolder[i];
+ break;
+ case FINAL:
+ value = aggFunction.extractFinalResult(_mergeResultHolder[i]);
+ break;
+ case DIRECT:
+ Object intermediate = aggFunction.extractAggregationResult(_aggregateResultHolder[i]);
+ value = aggFunction.extractFinalResult(intermediate);
+ break;
+ default:
+ throw new UnsupportedOperationException("Unsupported aggTyp: " + _aggType);
}
+ row[i] = value;
}
- return Collections.singletonList(row);
+ return Collections.singletonList(TypeUtils.canonicalizeRow(row, _resultSchema));
}
private void processAggregate(TransferableBlock block, DataSchema inputDataSchema) {
for (int i = 0; i < _aggFunctions.length; i++) {
AggregationFunction aggregationFunction = _aggFunctions[i];
Map<ExpressionContext, BlockValSet> blockValSetMap =
- getBlockValSetMap(aggregationFunction, block, inputDataSchema);
+ AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap);
aggregationFunction.aggregate(block.getNumRows(), _aggregateResultHolder[i], blockValSetMap);
}
}
@@ -131,7 +129,8 @@ public class MultistageAggregationExecutor {
for (int i = 0; i < _aggFunctions.length; i++) {
for (Object[] row : container) {
- Object intermediateResultToMerge = extractValueFromRow(_aggFunctions[i], row);
+ Object intermediateResultToMerge =
+ AggregateOperator.extractValueFromRow(_aggFunctions[i], row, _colNameToIndexMap);
// Not all V1 aggregation functions have null-handling logic. Handle null values before calling merge.
if (intermediateResultToMerge == null) {
@@ -147,50 +146,4 @@ public class MultistageAggregationExecutor {
}
}
}
-
- private void collectResult(TransferableBlock block) {
- List<Object[]> container = block.getContainer();
- assert container.size() == 1;
- Object[] row = container.get(0);
- for (int i = 0; i < _aggFunctions.length; i++) {
- _finalResultHolder[i] = extractValueFromRow(_aggFunctions[i], row);
- }
- }
-
- private Map<ExpressionContext, BlockValSet> getBlockValSetMap(AggregationFunction aggFunction,
- TransferableBlock block, DataSchema inputDataSchema) {
- List<ExpressionContext> expressions = aggFunction.getInputExpressions();
- int numExpressions = expressions.size();
- if (numExpressions == 0) {
- return Collections.emptyMap();
- }
-
- Preconditions.checkState(numExpressions == 1, "Cannot handle more than one identifier in aggregation function.");
- ExpressionContext expression = expressions.get(0);
- Preconditions.checkState(expression.getType().equals(ExpressionContext.Type.IDENTIFIER));
- int index = _colNameToIndexMap.get(expression.getIdentifier());
-
- DataSchema.ColumnDataType dataType = inputDataSchema.getColumnDataType(index);
- Preconditions.checkState(block.getType().equals(DataBlock.Type.ROW), "Datablock type is not ROW");
- // TODO: If the previous block is not mailbox received, this method is not efficient. Then getDataBlock() will
- // convert the unserialized format to serialized format of BaseDataBlock. Then it will convert it back to column
- // value primitive type.
- return Collections.singletonMap(expression,
- new IntermediateStageBlockValSet(dataType, block.getDataBlock(), index));
- }
-
- private Object extractValueFromRow(AggregationFunction aggregationFunction, Object[] row) {
- // TODO: Add support to handle aggregation functions where:
- // 1. The identifier need not be the first argument
- // 2. There are more than one identifiers.
- List<ExpressionContext> expressions = aggregationFunction.getInputExpressions();
- Preconditions.checkState(expressions.size() == 1, "Only support single expression, got: %s", expressions.size());
- ExpressionContext expr = expressions.get(0);
- ExpressionContext.Type exprType = expr.getType();
- if (exprType == ExpressionContext.Type.IDENTIFIER) {
- return row[_colNameToIndexMap.get(expr.getIdentifier())];
- }
- Preconditions.checkState(exprType == ExpressionContext.Type.LITERAL, "Unsupported expression type: %s", exprType);
- return expr.getLiteral().getValue();
- }
}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
index 03d7638aa4..5eacba025b 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
@@ -18,23 +18,21 @@
*/
package org.apache.pinot.query.runtime.operator;
-import com.google.common.base.Preconditions;
import java.util.ArrayList;
-import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import org.apache.pinot.common.datablock.DataBlock;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
-import org.apache.pinot.core.common.IntermediateStageBlockValSet;
import org.apache.pinot.core.data.table.Key;
import org.apache.pinot.core.plan.maker.InstancePlanMakerImplV2;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.query.planner.plannode.AggregateNode.AggType;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
+import org.apache.pinot.query.runtime.operator.utils.TypeUtils;
+
/**
@@ -45,6 +43,7 @@ public class MultistageGroupByExecutor {
// The identifier operands for the aggregation function only store the column name. This map contains mapping
// between column name to their index which is used in v2 engine.
private final Map<String, Integer> _colNameToIndexMap;
+ private final DataSchema _resultSchema;
private final List<ExpressionContext> _groupSet;
private final AggregationFunction[] _aggFunctions;
@@ -52,22 +51,21 @@ public class MultistageGroupByExecutor {
// Group By Result holders for each mode
private final GroupByResultHolder[] _aggregateResultHolders;
private final Map<Integer, Object[]> _mergeResultHolder;
- private final List<Object[]> _finalResultHolder;
// Mapping from the row-key to a zero based integer index. This is used when we invoke the v1 aggregation functions
// because they use the zero based integer indexes to store results.
private final Map<Key, Integer> _groupKeyToIdMap;
public MultistageGroupByExecutor(List<ExpressionContext> groupByExpr, AggregationFunction[] aggFunctions,
- AggType aggType, Map<String, Integer> colNameToIndexMap) {
+ AggType aggType, Map<String, Integer> colNameToIndexMap, DataSchema resultSchema) {
_aggType = aggType;
_colNameToIndexMap = colNameToIndexMap;
_groupSet = groupByExpr;
_aggFunctions = aggFunctions;
+ _resultSchema = resultSchema;
_aggregateResultHolders = new GroupByResultHolder[_aggFunctions.length];
_mergeResultHolder = new HashMap<>();
- _finalResultHolder = new ArrayList<>();
_groupKeyToIdMap = new HashMap<>();
@@ -84,10 +82,8 @@ public class MultistageGroupByExecutor {
public void processBlock(TransferableBlock block, DataSchema inputDataSchema) {
if (!_aggType.isInputIntermediateFormat()) {
processAggregate(block, inputDataSchema);
- } else if (_aggType.isOutputIntermediateFormat()) {
- processMerge(block);
} else {
- collectResult(block);
+ processMerge(block);
}
}
@@ -95,10 +91,6 @@ public class MultistageGroupByExecutor {
* Fetches the result.
*/
public List<Object[]> getResult() {
- if (_aggType == AggType.FINAL) {
- return extractFinalGroupByResult();
- }
-
List<Object[]> rows = new ArrayList<>(_groupKeyToIdMap.size());
int numKeys = _groupSet.size();
int numFunctions = _aggFunctions.length;
@@ -111,39 +103,27 @@ public class MultistageGroupByExecutor {
for (int i = 0; i < numFunctions; i++) {
AggregationFunction func = _aggFunctions[i];
int index = numKeys + i;
- if (!_aggType.isInputIntermediateFormat()) {
- Object intermediateResult = func.extractGroupByResult(_aggregateResultHolders[i], groupId);
- if (_aggType.isOutputIntermediateFormat()) {
- row[index] = intermediateResult;
- } else {
- Object finalResult = func.extractFinalResult(intermediateResult);
- row[index] = finalResult == null ? null : func.getFinalResultColumnType().convert(finalResult);
- }
- } else {
- assert _aggType == AggType.INTERMEDIATE;
- row[index] = _mergeResultHolder.get(groupId)[i];
+ Object value;
+ switch (_aggType) {
+ case LEAF:
+ value = func.extractGroupByResult(_aggregateResultHolders[i], groupId);
+ break;
+ case INTERMEDIATE:
+ value = _mergeResultHolder.get(groupId)[i];
+ break;
+ case FINAL:
+ value = func.extractFinalResult(_mergeResultHolder.get(groupId)[i]);
+ break;
+ case DIRECT:
+ Object intermediate = _aggFunctions[i].extractGroupByResult(_aggregateResultHolders[i], groupId);
+ value = func.extractFinalResult(intermediate);
+ break;
+ default:
+ throw new UnsupportedOperationException("Unsupported aggTyp: " + _aggType);
}
+ row[index] = value;
}
- rows.add(row);
- }
- return rows;
- }
-
- private List<Object[]> extractFinalGroupByResult() {
- List<Object[]> rows = new ArrayList<>(_finalResultHolder.size());
- int numKeys = _groupSet.size();
- int numFunctions = _aggFunctions.length;
- int numColumns = numKeys + numFunctions;
- for (Object[] resultRow : _finalResultHolder) {
- Object[] row = new Object[numColumns];
- System.arraycopy(resultRow, 0, row, 0, numKeys);
- for (int i = 0; i < numFunctions; i++) {
- AggregationFunction func = _aggFunctions[i];
- int index = numKeys + i;
- Object finalResult = func.extractFinalResult(resultRow[index]);
- row[index] = finalResult == null ? null : func.getFinalResultColumnType().convert(finalResult);
- }
- rows.add(row);
+ rows.add(TypeUtils.canonicalizeRow(row, _resultSchema));
}
return rows;
}
@@ -154,7 +134,7 @@ public class MultistageGroupByExecutor {
for (int i = 0; i < _aggFunctions.length; i++) {
AggregationFunction aggregationFunction = _aggFunctions[i];
Map<ExpressionContext, BlockValSet> blockValSetMap =
- getBlockValSetMap(aggregationFunction, block, inputDataSchema);
+ AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap);
GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i];
groupByResultHolder.ensureCapacity(_groupKeyToIdMap.size());
aggregationFunction.aggregateGroupBySV(block.getNumRows(), intKeys, groupByResultHolder, blockValSetMap);
@@ -172,7 +152,8 @@ public class MultistageGroupByExecutor {
if (!_mergeResultHolder.containsKey(rowKey)) {
_mergeResultHolder.put(rowKey, new Object[_aggFunctions.length]);
}
- Object intermediateResultToMerge = extractValueFromRow(_aggFunctions[i], row);
+ Object intermediateResultToMerge =
+ AggregateOperator.extractValueFromRow(_aggFunctions[i], row, _colNameToIndexMap);
// Not all V1 aggregation functions have null-handling. So handle null values and call merge only if necessary.
if (intermediateResultToMerge == null) {
@@ -189,22 +170,6 @@ public class MultistageGroupByExecutor {
}
}
- private void collectResult(TransferableBlock block) {
- List<Object[]> container = block.getContainer();
- for (Object[] row : container) {
- assert row.length == _groupSet.size() + _aggFunctions.length;
- Object[] resultRow = new Object[row.length];
- System.arraycopy(row, 0, resultRow, 0, _groupSet.size());
-
- for (int i = 0; i < _aggFunctions.length; i++) {
- int index = _groupSet.size() + i;
- resultRow[index] = extractValueFromRow(_aggFunctions[i], row);
- }
-
- _finalResultHolder.add(resultRow);
- }
- }
-
/**
* Creates the group by key for each row. Converts the key into a 0-index based int value that can be used by
* GroupByAggregationResultHolders used in v1 aggregations.
@@ -226,41 +191,4 @@ public class MultistageGroupByExecutor {
}
return rowIntKeys;
}
-
- private Map<ExpressionContext, BlockValSet> getBlockValSetMap(AggregationFunction aggFunction,
- TransferableBlock block, DataSchema inputDataSchema) {
- List<ExpressionContext> expressions = aggFunction.getInputExpressions();
- int numExpressions = expressions.size();
- if (numExpressions == 0) {
- return Collections.emptyMap();
- }
-
- Preconditions.checkState(numExpressions == 1, "Cannot handle more than one identifier in aggregation function.");
- ExpressionContext expression = expressions.get(0);
- Preconditions.checkState(expression.getType().equals(ExpressionContext.Type.IDENTIFIER));
- int index = _colNameToIndexMap.get(expression.getIdentifier());
-
- DataSchema.ColumnDataType dataType = inputDataSchema.getColumnDataType(index);
- Preconditions.checkState(block.getType().equals(DataBlock.Type.ROW), "Datablock type is not ROW");
- // TODO: If the previous block is not mailbox received, this method is not efficient. Then getDataBlock() will
- // convert the unserialized format to serialized format of BaseDataBlock. Then it will convert it back to column
- // value primitive type.
- return Collections.singletonMap(expression,
- new IntermediateStageBlockValSet(dataType, block.getDataBlock(), index));
- }
-
- private Object extractValueFromRow(AggregationFunction aggregationFunction, Object[] row) {
- // TODO: Add support to handle aggregation functions where:
- // 1. The identifier need not be the first argument
- // 2. There are more than one identifiers.
- List<ExpressionContext> expressions = aggregationFunction.getInputExpressions();
- Preconditions.checkState(expressions.size() == 1, "Only support single expression, got: %s", expressions.size());
- ExpressionContext expr = expressions.get(0);
- ExpressionContext.Type exprType = expr.getType();
- if (exprType == ExpressionContext.Type.IDENTIFIER) {
- return row[_colNameToIndexMap.get(expr.getIdentifier())];
- }
- Preconditions.checkState(exprType == ExpressionContext.Type.LITERAL, "Unsupported expression type: %s", exprType);
- return expr.getLiteral().getValue();
- }
}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/TransformOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/TransformOperator.java
index e53bcfdac2..a913d95295 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/TransformOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/TransformOperator.java
@@ -29,7 +29,7 @@ import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
import org.apache.pinot.query.runtime.operator.operands.TransformOperand;
-import org.apache.pinot.query.runtime.operator.utils.FunctionInvokeUtils;
+import org.apache.pinot.query.runtime.operator.utils.TypeUtils;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -113,8 +113,7 @@ public class TransformOperator extends MultiStageOperator {
for (Object[] row : container) {
Object[] resultRow = new Object[_resultColumnSize];
for (int i = 0; i < _resultColumnSize; i++) {
- resultRow[i] =
- FunctionInvokeUtils.convert(_transformOperandsList.get(i).apply(row), _resultSchema.getColumnDataType(i));
+ resultRow[i] = TypeUtils.convert(_transformOperandsList.get(i).apply(row), _resultSchema.getColumnDataType(i));
}
resultRows.add(resultRow);
}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FilterOperand.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FilterOperand.java
index ab69ce67b2..dbe3c179ca 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FilterOperand.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FilterOperand.java
@@ -24,7 +24,7 @@ import java.util.List;
import java.util.function.IntPredicate;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.query.planner.logical.RexExpression;
-import org.apache.pinot.query.runtime.operator.utils.FunctionInvokeUtils;
+import org.apache.pinot.query.runtime.operator.utils.TypeUtils;
import org.apache.pinot.spi.utils.BooleanUtils;
@@ -160,8 +160,8 @@ public abstract class FilterOperand extends TransformOperand {
return false;
}
if (_requireCasting) {
- v1 = (Comparable) FunctionInvokeUtils.convert(v1, _commonCastType);
- v2 = (Comparable) FunctionInvokeUtils.convert(v2, _commonCastType);
+ v1 = (Comparable) TypeUtils.convert(v1, _commonCastType);
+ v2 = (Comparable) TypeUtils.convert(v2, _commonCastType);
assert v1 != null && v2 != null;
}
return _comparisonResultPredicate.test(v1.compareTo(v2));
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/FunctionInvokeUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/FunctionInvokeUtils.java
deleted file mode 100644
index 851748d2b2..0000000000
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/FunctionInvokeUtils.java
+++ /dev/null
@@ -1,45 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.pinot.query.runtime.operator.utils;
-
-import javax.annotation.Nullable;
-import org.apache.pinot.common.utils.DataSchema;
-
-
-public class FunctionInvokeUtils {
- private FunctionInvokeUtils() {
- }
-
- /**
- * Convert result to the appropriate column data type according to the desired {@link DataSchema.ColumnDataType}
- * of the {@link org.apache.pinot.core.common.Operator}.
- *
- * @param inputObj input entry
- * @param columnDataType desired column data type
- * @return converted entry
- */
- @Nullable
- public static Object convert(@Nullable Object inputObj, DataSchema.ColumnDataType columnDataType) {
- if (columnDataType.isNumber() && columnDataType != DataSchema.ColumnDataType.BIG_DECIMAL) {
- return inputObj == null ? null : columnDataType.convert(inputObj);
- } else {
- return inputObj;
- }
- }
-}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/TypeUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/TypeUtils.java
new file mode 100644
index 0000000000..d60ad0bdb7
--- /dev/null
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/TypeUtils.java
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator.utils;
+
+import javax.annotation.Nullable;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.spi.utils.BooleanUtils;
+
+
+public class TypeUtils {
+ private TypeUtils() {
+ }
+
+ /**
+ * Convert result to the appropriate column data type according to the desired {@link DataSchema.ColumnDataType}
+ * of the {@link org.apache.pinot.core.common.Operator}.
+ *
+ * @param inputObj input entry
+ * @param columnDataType desired column data type
+ * @return converted entry
+ */
+ @Nullable
+ public static Object convert(@Nullable Object inputObj, DataSchema.ColumnDataType columnDataType) {
+ if (columnDataType.isNumber() && columnDataType != DataSchema.ColumnDataType.BIG_DECIMAL) {
+ return inputObj == null ? null : columnDataType.convert(inputObj);
+ } else {
+ return inputObj;
+ }
+ }
+
+ /**
+ * Canonicalizes a row generated by the V1 engine, whose values are stored in the
+ * {@link DataSchema#getStoredColumnDataTypes()} format, whereas the transferable block ser/de stores data in the
+ * {@link DataSchema#getColumnDataTypes()} format.
+ *
+ * @param row un-canonicalized row.
+ * @param dataSchema data schema desired for the row.
+ * @return canonicalized row.
+ */
+ public static Object[] canonicalizeRow(Object[] row, DataSchema dataSchema) {
+ Object[] resultRow = new Object[row.length];
+ for (int colId = 0; colId < row.length; colId++) {
+ Object value = row[colId];
+ if (value != null) {
+ if (dataSchema.getColumnDataType(colId) == DataSchema.ColumnDataType.OBJECT) {
+ resultRow[colId] = value;
+ } else if (dataSchema.getColumnDataType(colId) == DataSchema.ColumnDataType.BOOLEAN) {
+ resultRow[colId] = BooleanUtils.toBoolean(value);
+ } else {
+ resultRow[colId] = dataSchema.getColumnDataType(colId).convert(value);
+ }
+ }
+ }
+ return resultRow;
+ }
+
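+ /**
+ * Canonicalizes a row whose column indices do not match the Calcite order.
+ * Unlike the two-argument overload, OBJECT and BOOLEAN columns are not special-cased before conversion here.
+ * see: {@link TypeUtils#canonicalizeRow(Object[], DataSchema)}
+ */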
+ public static Object[] canonicalizeRow(Object[] row, DataSchema dataSchema, int[] columnIndices) {
+ Object[] resultRow = new Object[columnIndices.length];
+ for (int colId = 0; colId < columnIndices.length; colId++) {
+ Object value = row[columnIndices[colId]];
+ if (value != null) {
+ resultRow[colId] = dataSchema.getColumnDataType(colId).convert(value);
+ }
+ }
+ return resultRow;
+ }
+}
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
index 6ecb29efb4..8048bad5f4 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
@@ -41,6 +41,7 @@ import org.testng.annotations.Test;
import static org.apache.pinot.common.utils.DataSchema.ColumnDataType.DOUBLE;
import static org.apache.pinot.common.utils.DataSchema.ColumnDataType.INT;
+import static org.apache.pinot.common.utils.DataSchema.ColumnDataType.LONG;
import static org.apache.pinot.common.utils.DataSchema.ColumnDataType.STRING;
public class AggregateOperatorTest {
@@ -74,8 +75,8 @@ public class AggregateOperatorTest {
Mockito.when(_input.nextBlock())
.thenReturn(TransferableBlockUtils.getErrorTransferableBlock(new Exception("foo!")));
- DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, INT});
- DataSchema outSchema = new DataSchema(new String[]{"sum"}, new ColumnDataType[]{DOUBLE});
+ DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, DOUBLE});
+ DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, DOUBLE});
AggregateOperator operator =
new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
AggType.INTERMEDIATE);
@@ -97,10 +98,10 @@ public class AggregateOperatorTest {
Mockito.when(_input.nextBlock()).thenReturn(TransferableBlockUtils.getEndOfStreamTransferableBlock());
DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, INT});
- DataSchema outSchema = new DataSchema(new String[]{"sum"}, new ColumnDataType[]{DOUBLE});
+ DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, DOUBLE});
AggregateOperator operator =
new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
- AggType.INTERMEDIATE);
+ AggType.LEAF);
// When:
TransferableBlock block = operator.nextBlock();
@@ -121,10 +122,10 @@ public class AggregateOperatorTest {
.thenReturn(TransferableBlockUtils.getNoOpTransferableBlock())
.thenReturn(TransferableBlockUtils.getEndOfStreamTransferableBlock());
- DataSchema outSchema = new DataSchema(new String[]{"sum"}, new ColumnDataType[]{DOUBLE});
+ DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, DOUBLE});
AggregateOperator operator =
new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
- AggType.INTERMEDIATE);
+ AggType.LEAF);
// When:
TransferableBlock block1 = operator.nextBlock(); // build when reading NoOp block
@@ -142,11 +143,11 @@ public class AggregateOperatorTest {
List<RexExpression> calls = ImmutableList.of(getSum(new RexExpression.InputRef(1)));
List<RexExpression> group = ImmutableList.of(new RexExpression.InputRef(0));
- DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, INT});
- Mockito.when(_input.nextBlock()).thenReturn(OperatorTestUtil.block(inSchema, new Object[]{2, 1}))
+ DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, DOUBLE});
+ Mockito.when(_input.nextBlock()).thenReturn(OperatorTestUtil.block(inSchema, new Object[]{2, 1.0}))
.thenReturn(TransferableBlockUtils.getEndOfStreamTransferableBlock());
- DataSchema outSchema = new DataSchema(new String[]{"sum"}, new ColumnDataType[]{DOUBLE});
+ DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, DOUBLE});
AggregateOperator operator =
new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
AggType.INTERMEDIATE);
@@ -158,25 +159,25 @@ public class AggregateOperatorTest {
// Then:
Mockito.verify(_input, Mockito.times(2)).nextBlock();
Assert.assertTrue(block1.getNumRows() > 0, "First block is the result");
- Assert.assertEquals(block1.getContainer().get(0), new Object[]{2, 1},
- "Expected two columns (group by key, agg value)");
+ Assert.assertEquals(block1.getContainer().get(0), new Object[]{2, 1.0},
+ "Expected two columns (group by key, agg value), agg value is intermediate type");
Assert.assertTrue(block2.isEndOfStreamBlock(), "Second block is EOS (done processing)");
}
@Test
public void shouldAggregateSingleInputBlockWithLiteralInput() {
// Given:
- List<RexExpression> calls = ImmutableList.of(getSum(new RexExpression.Literal(FieldSpec.DataType.INT, 1)));
+ List<RexExpression> calls = ImmutableList.of(getSum(new RexExpression.Literal(FieldSpec.DataType.DOUBLE, 1.0)));
List<RexExpression> group = ImmutableList.of(new RexExpression.InputRef(0));
- DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, INT});
- Mockito.when(_input.nextBlock()).thenReturn(OperatorTestUtil.block(inSchema, new Object[]{2, 3}))
+ DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, DOUBLE});
+ Mockito.when(_input.nextBlock()).thenReturn(OperatorTestUtil.block(inSchema, new Object[]{2, 3.0}))
.thenReturn(TransferableBlockUtils.getEndOfStreamTransferableBlock());
- DataSchema outSchema = new DataSchema(new String[]{"sum"}, new ColumnDataType[]{DOUBLE});
+ DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, LONG});
AggregateOperator operator =
new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
- AggType.INTERMEDIATE);
+ AggType.FINAL);
// When:
TransferableBlock block1 = operator.nextBlock();
@@ -186,7 +187,7 @@ public class AggregateOperatorTest {
Mockito.verify(_input, Mockito.times(2)).nextBlock();
Assert.assertTrue(block1.getNumRows() > 0, "First block is the result");
// second value is 1 (the literal) instead of 3 (the col val)
- Assert.assertEquals(block1.getContainer().get(0), new Object[]{2, 1},
+ Assert.assertEquals(block1.getContainer().get(0), new Object[]{2, 1L},
"Expected two columns (group by key, agg value)");
Assert.assertTrue(block2.isEndOfStreamBlock(), "Second block is EOS (done processing)");
}
@@ -196,9 +197,10 @@ public class AggregateOperatorTest {
MultiStageOperator upstreamOperator = OperatorTestUtil.getOperator(OperatorTestUtil.OP_1);
// Create an aggregation call with sum for first column and group by second column.
RexExpression.FunctionCall agg = getSum(new RexExpression.InputRef(0));
- DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, INT});
+ DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{STRING, INT});
+ DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{STRING, DOUBLE});
AggregateOperator sum0GroupBy1 = new AggregateOperator(OperatorTestUtil.getDefaultContext(), upstreamOperator,
- OperatorTestUtil.getDataSchema(OperatorTestUtil.OP_1), inSchema, Collections.singletonList(agg),
+ outSchema, inSchema, Collections.singletonList(agg),
Collections.singletonList(new RexExpression.InputRef(1)), AggType.LEAF);
TransferableBlock result = sum0GroupBy1.getNextBlock();
while (result.isNoOpBlock()) {
diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
index f377ca188e..70f5fcc6ef 100644
--- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
+++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
@@ -24,14 +24,11 @@ import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
-import org.apache.calcite.config.CalciteSystemProperty;
import org.apache.calcite.sql.SqlFunctionCategory;
-import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
-import org.apache.calcite.sql.type.SqlOperandTypeInference;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.commons.lang.StringUtils;
@@ -50,32 +47,37 @@ import org.apache.pinot.spi.utils.CommonConstants;
*/
public enum AggregationFunctionType {
// Aggregation functions for single-valued columns
- COUNT("count", Collections.emptyList(), true, null, SqlKind.COUNT, ReturnTypes.BIGINT, null,
- CalciteSystemProperty.STRICT.value() ? OperandTypes.ANY : OperandTypes.ONE_OR_MORE, SqlFunctionCategory.NUMERIC,
- null, ReturnTypes.BIGINT, null, null, null),
- MIN("min", Collections.emptyList(), false, null, SqlKind.MIN, ReturnTypes.DOUBLE, null,
- OperandTypes.COMPARABLE_ORDERED, SqlFunctionCategory.SYSTEM, null, ReturnTypes.DOUBLE, null, null, null),
- MAX("max", Collections.emptyList(), false, null, SqlKind.MAX, ReturnTypes.DOUBLE, null,
- OperandTypes.COMPARABLE_ORDERED, SqlFunctionCategory.SYSTEM, null, ReturnTypes.DOUBLE, null, null, null),
- // In multistage SUM is reduced via the PinotAvgSumAggregateReduceFunctionsRule, need not set up anything for reduce
- SUM("sum", Collections.emptyList(), false, null, SqlKind.SUM, ReturnTypes.DOUBLE, null, OperandTypes.NUMERIC,
- SqlFunctionCategory.NUMERIC, null, ReturnTypes.DOUBLE, null, null, null),
- SUM0("$sum0", Collections.emptyList(), false, null, SqlKind.SUM0, ReturnTypes.DOUBLE, null, OperandTypes.NUMERIC,
- SqlFunctionCategory.NUMERIC, null, ReturnTypes.DOUBLE, null, null, null),
+ COUNT("count", null, SqlKind.COUNT, SqlFunctionCategory.NUMERIC, OperandTypes.ONE_OR_MORE,
+ ReturnTypes.explicit(SqlTypeName.BIGINT), ReturnTypes.explicit(SqlTypeName.BIGINT)),
+ // TODO: min/max only support NUMERIC operands in Pinot, whereas Calcite supports COMPARABLE_ORDERED
+ MIN("min", null, SqlKind.MIN, SqlFunctionCategory.SYSTEM, OperandTypes.NUMERIC, ReturnTypes.ARG0_NULLABLE_IF_EMPTY,
+ ReturnTypes.explicit(SqlTypeName.DOUBLE)),
+ MAX("max", null, SqlKind.MAX, SqlFunctionCategory.SYSTEM, OperandTypes.NUMERIC, ReturnTypes.ARG0_NULLABLE_IF_EMPTY,
+ ReturnTypes.explicit(SqlTypeName.DOUBLE)),
+ SUM("sum", null, SqlKind.SUM, SqlFunctionCategory.NUMERIC, OperandTypes.NUMERIC, ReturnTypes.AGG_SUM,
+ ReturnTypes.explicit(SqlTypeName.DOUBLE)),
+ SUM0("$sum0", null, SqlKind.SUM0, SqlFunctionCategory.NUMERIC, OperandTypes.NUMERIC,
+ ReturnTypes.AGG_SUM_EMPTY_IS_ZERO, ReturnTypes.explicit(SqlTypeName.DOUBLE)),
SUMPRECISION("sumPrecision"),
- AVG("avg", Collections.emptyList(), true, null, SqlKind.AVG, ReturnTypes.AVG_AGG_FUNCTION, null,
- OperandTypes.NUMERIC, SqlFunctionCategory.NUMERIC, null, null, "AVG_REDUCE", ReturnTypes.AVG_AGG_FUNCTION,
- OperandTypes.NUMERIC_NUMERIC),
+ AVG("avg"),
MODE("mode"),
+
FIRSTWITHTIME("firstWithTime"),
LASTWITHTIME("lastWithTime"),
MINMAXRANGE("minMaxRange"),
- DISTINCTCOUNT("distinctCount", Collections.emptyList(), false, null, SqlKind.OTHER_FUNCTION, ReturnTypes.BIGINT,
- null, OperandTypes.ANY, SqlFunctionCategory.USER_DEFINED_FUNCTION, null,
- ReturnTypes.explicit(SqlTypeName.OTHER), null, null, null),
+ /**
+ * For all functions in the distinct count family:
+ * (1) distinct_count only supports a single argument;
+ * (2) count(distinct ...) supports multiple arguments and will be converted into DISTINCT + COUNT.
+ */
+ DISTINCTCOUNT("distinctCount", null, SqlKind.OTHER_FUNCTION,
+ SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.ANY, ReturnTypes.BIGINT,
+ ReturnTypes.explicit(SqlTypeName.OTHER)),
DISTINCTCOUNTBITMAP("distinctCountBitmap"),
SEGMENTPARTITIONEDDISTINCTCOUNT("segmentPartitionedDistinctCount"),
- DISTINCTCOUNTHLL("distinctCountHLL"),
+ DISTINCTCOUNTHLL("distinctCountHLL", Collections.emptyList(), SqlKind.OTHER_FUNCTION,
+ SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.ANY, ReturnTypes.BIGINT,
+ ReturnTypes.explicit(SqlTypeName.OTHER)),
DISTINCTCOUNTRAWHLL("distinctCountRawHLL"),
DISTINCTCOUNTSMARTHLL("distinctCountSmartHLL"),
FASTHLL("fastHLL"),
@@ -83,6 +85,7 @@ public enum AggregationFunctionType {
DISTINCTCOUNTRAWTHETASKETCH("distinctCountRawThetaSketch"),
DISTINCTSUM("distinctSum"),
DISTINCTAVG("distinctAvg"),
+
PERCENTILE("percentile"),
PERCENTILEEST("percentileEst"),
PERCENTILERAWEST("percentileRawEst"),
@@ -91,20 +94,21 @@ public enum AggregationFunctionType {
PERCENTILESMARTTDIGEST("percentileSmartTDigest"),
PERCENTILEKLL("percentileKLL"),
PERCENTILERAWKLL("percentileRawKLL"),
+
IDSET("idSet"),
+
HISTOGRAM("histogram"),
+
COVARPOP("covarPop"),
COVARSAMP("covarSamp"),
VARPOP("varPop"),
VARSAMP("varSamp"),
STDDEVPOP("stdDevPop"),
STDDEVSAMP("stdDevSamp"),
- SKEWNESS("skewness", Collections.emptyList(), false, null, SqlKind.OTHER_FUNCTION, ReturnTypes.DOUBLE, null,
- OperandTypes.NUMERIC, SqlFunctionCategory.USER_DEFINED_FUNCTION, "fourthMoment",
- ReturnTypes.explicit(SqlTypeName.OTHER), null, null, null),
- KURTOSIS("kurtosis", Collections.emptyList(), false, null, SqlKind.OTHER_FUNCTION, ReturnTypes.DOUBLE, null,
- OperandTypes.NUMERIC, SqlFunctionCategory.USER_DEFINED_FUNCTION, "fourthMoment",
- ReturnTypes.explicit(SqlTypeName.OTHER), null, null, null),
+ SKEWNESS("skewness", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
+ OperandTypes.NUMERIC, ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)),
+ KURTOSIS("kurtosis", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
+ OperandTypes.NUMERIC, ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)),
FOURTHMOMENT("fourthMoment"),
// DataSketches Tuple Sketch support
@@ -141,10 +145,10 @@ public enum AggregationFunctionType {
PERCENTILERAWKLLMV("percentileRawKLLMV"),
// boolean aggregate functions
- BOOLAND("boolAnd", Collections.emptyList(), false, null, SqlKind.OTHER_FUNCTION, ReturnTypes.BOOLEAN, null,
- OperandTypes.BOOLEAN, SqlFunctionCategory.USER_DEFINED_FUNCTION, null, ReturnTypes.INTEGER, null, null, null),
- BOOLOR("boolOr", Collections.emptyList(), false, null, SqlKind.OTHER_FUNCTION, ReturnTypes.BOOLEAN, null,
- OperandTypes.BOOLEAN, SqlFunctionCategory.USER_DEFINED_FUNCTION, null, ReturnTypes.INTEGER, null, null, null),
+ BOOLAND("boolAnd", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
+ OperandTypes.BOOLEAN, ReturnTypes.BOOLEAN, ReturnTypes.explicit(SqlTypeName.INTEGER)),
+ BOOLOR("boolOr", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
+ OperandTypes.BOOLEAN, ReturnTypes.BOOLEAN, ReturnTypes.explicit(SqlTypeName.INTEGER)),
// argMin and argMax
ARGMIN("argMin"),
@@ -160,61 +164,64 @@ public enum AggregationFunctionType {
private static final Set<String> NAMES = Arrays.stream(values()).flatMap(func -> Stream.of(func.name(),
func.getName(), func.getName().toLowerCase())).collect(Collectors.toSet());
+ // --------------------------------------------------------------------------
+ // Function signature used by Calcite.
+ // --------------------------------------------------------------------------
private final String _name;
private final List<String> _alternativeNames;
- private final boolean _isNativeCalciteAggregationFunctionType;
// Fields for registering the aggregation function with Calcite in multistage. These are typically used for the
// user facing aggregation functions and the return and operand types should reflect that which is user facing.
- private final SqlIdentifier _sqlIdentifier;
private final SqlKind _sqlKind;
- private final SqlReturnTypeInference _sqlReturnTypeInference;
- private final SqlOperandTypeInference _sqlOperandTypeInference;
- private final SqlOperandTypeChecker _sqlOperandTypeChecker;
private final SqlFunctionCategory _sqlFunctionCategory;
- // Fields for the intermediate stage functions used in multistage. These intermediate stage aggregation functions
- // are typically internal to Pinot and are used to handle the complex types used in the intermediate and final
- // aggregation stages. The intermediate function name may be the same as the user facing function name.
- private final String _intermediateFunctionName;
- private final SqlReturnTypeInference _sqlIntermediateReturnTypeInference;
-
- // Fields for the Calcite aggregation reduce step in multistage. The aggregation reduce is only used for special
- // functions like AVG which is split into SUM / COUNT. This is needed for proper null handling if COUNT is 0.
- private final String _reduceFunctionName;
- private final SqlReturnTypeInference _sqlReduceReturnTypeInference;
- private final SqlOperandTypeChecker _sqlReduceOperandTypeChecker;
+ // override options for Pinot aggregate functions that expect a different return type or operand type
+ private final SqlReturnTypeInference _returnTypeInference;
+ private final SqlOperandTypeChecker _operandTypeChecker;
+ // override options for Pinot aggregate rules to insert intermediate results whose type differs from the return type.
+ private final SqlReturnTypeInference _intermediateReturnTypeInference;
/**
* Constructor to use for aggregation functions which are only supported in v1 engine today
*/
AggregationFunctionType(String name) {
- this(name, Collections.emptyList(), false, null, null, null, null, null, null, null, null, null, null, null);
+ this(name, null, null, null);
}
/**
* Constructor to use for aggregation functions which are supported in both v1 and multistage engines
*/
- AggregationFunctionType(String name, List<String> alternativeNames, boolean isNativeCalciteAggregationFunctionType,
- SqlIdentifier sqlIdentifier, SqlKind sqlKind, SqlReturnTypeInference sqlReturnTypeInference,
- SqlOperandTypeInference sqlOperandTypeInference, SqlOperandTypeChecker sqlOperandTypeChecker,
- SqlFunctionCategory sqlFunctionCategory, String intermediateFunctionName,
- SqlReturnTypeInference sqlIntermediateReturnTypeInference, String reduceFunctionName,
- SqlReturnTypeInference sqlReduceReturnTypeInference, SqlOperandTypeChecker sqlReduceOperandTypeChecker) {
+ AggregationFunctionType(String name, List<String> alternativeNames, SqlKind sqlKind,
+ SqlFunctionCategory sqlFunctionCategory) {
+ this(name, alternativeNames, sqlKind, sqlFunctionCategory, null, null, null);
+ }
+
+ /**
+ * Constructor to use for aggregation functions which are supported in both v1 and multistage engines
+ * and require overriding Calcite behaviors.
+ */
+ AggregationFunctionType(String name, List<String> alternativeNames,
+ SqlKind sqlKind, SqlFunctionCategory sqlFunctionCategory, SqlOperandTypeChecker operandTypeChecker,
+ SqlReturnTypeInference returnTypeInference) {
+ this(name, alternativeNames, sqlKind, sqlFunctionCategory, operandTypeChecker, returnTypeInference, null);
+ }
+
+ AggregationFunctionType(String name, List<String> alternativeNames,
+ SqlKind sqlKind, SqlFunctionCategory sqlFunctionCategory, SqlOperandTypeChecker operandTypeChecker,
+ SqlReturnTypeInference returnTypeInference, SqlReturnTypeInference intermediateReturnTypeInference) {
_name = name;
- _alternativeNames = alternativeNames;
- _isNativeCalciteAggregationFunctionType = isNativeCalciteAggregationFunctionType;
- _sqlIdentifier = sqlIdentifier;
+ if (alternativeNames == null || alternativeNames.size() == 0) {
+ _alternativeNames = Collections.singletonList(getUnderscoreSplitAggregationFunctionName(_name));
+ } else {
+ _alternativeNames = alternativeNames;
+ }
_sqlKind = sqlKind;
- _sqlReturnTypeInference = sqlReturnTypeInference;
- _sqlOperandTypeInference = sqlOperandTypeInference;
- _sqlOperandTypeChecker = sqlOperandTypeChecker;
_sqlFunctionCategory = sqlFunctionCategory;
- _intermediateFunctionName = intermediateFunctionName == null ? name : intermediateFunctionName;
- _sqlIntermediateReturnTypeInference = sqlIntermediateReturnTypeInference;
- _reduceFunctionName = reduceFunctionName;
- _sqlReduceReturnTypeInference = sqlReduceReturnTypeInference;
- _sqlReduceOperandTypeChecker = sqlReduceOperandTypeChecker;
+
+ _returnTypeInference = returnTypeInference;
+ _operandTypeChecker = operandTypeChecker;
+ _intermediateReturnTypeInference = intermediateReturnTypeInference == null ? _returnTypeInference
+ : intermediateReturnTypeInference;
}
public String getName() {
@@ -225,54 +232,26 @@ public enum AggregationFunctionType {
return _alternativeNames;
}
- public boolean isNativeCalciteAggregationFunctionType() {
- return _isNativeCalciteAggregationFunctionType;
- }
-
- public SqlIdentifier getSqlIdentifier() {
- return _sqlIdentifier;
- }
-
public SqlKind getSqlKind() {
return _sqlKind;
}
- public SqlReturnTypeInference getSqlReturnTypeInference() {
- return _sqlReturnTypeInference;
+ public SqlReturnTypeInference getIntermediateReturnTypeInference() {
+ return _intermediateReturnTypeInference;
}
- public SqlOperandTypeInference getSqlOperandTypeInference() {
- return _sqlOperandTypeInference;
+ public SqlReturnTypeInference getReturnTypeInference() {
+ return _returnTypeInference;
}
- public SqlOperandTypeChecker getSqlOperandTypeChecker() {
- return _sqlOperandTypeChecker;
+ public SqlOperandTypeChecker getOperandTypeChecker() {
+ return _operandTypeChecker;
}
public SqlFunctionCategory getSqlFunctionCategory() {
return _sqlFunctionCategory;
}
- public String getIntermediateFunctionName() {
- return _intermediateFunctionName;
- }
-
- public SqlReturnTypeInference getSqlIntermediateReturnTypeInference() {
- return _sqlIntermediateReturnTypeInference;
- }
-
- public String getReduceFunctionName() {
- return _reduceFunctionName;
- }
-
- public SqlReturnTypeInference getSqlReduceReturnTypeInference() {
- return _sqlReduceReturnTypeInference;
- }
-
- public SqlOperandTypeChecker getSqlReduceOperandTypeChecker() {
- return _sqlReduceOperandTypeChecker;
- }
-
public static boolean isAggregationFunction(String functionName) {
if (NAMES.contains(functionName)) {
return true;
@@ -293,6 +272,13 @@ public enum AggregationFunctionType {
return StringUtils.remove(StringUtils.remove(functionName, '_').toUpperCase(), "$");
}
+ public static String getUnderscoreSplitAggregationFunctionName(String functionName) {
+ // Function names that contain digits are not split for now; return them as is
+ return functionName.matches(".*\\d.*")
+ ? functionName
+ : functionName.replaceAll("(.)(\\p{Upper}+|\\d+)", "$1_$2");
+ }
+
/**
* Returns the corresponding aggregation function type for the given function name.
* <p>NOTE: Underscores in the function name are ignored.
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org