You are viewing a plain-text version of this content; the hyperlink to its canonical version is not preserved in plain text.
Posted to commits@drill.apache.org by dz...@apache.org on 2023/02/23 08:54:46 UTC
[drill] branch master updated: DRILL-8403: Generated aggregate function calls are missing required filters when used with PIVOT (#2765)
This is an automated email from the ASF dual-hosted git repository.
dzamo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/drill.git
The following commit(s) were added to refs/heads/master by this push:
new 6254eb66e7 DRILL-8403: Generated aggregate function calls are missing required filters when used with PIVOT (#2765)
6254eb66e7 is described below
commit 6254eb66e747bb445fd815edc1a30bd1114f26f0
Author: Volodymyr Vysotskyi <vv...@gmail.com>
AuthorDate: Thu Feb 23 10:54:38 2023 +0200
DRILL-8403: Generated aggregate function calls are missing required filters when used with PIVOT (#2765)
---
.../planner/logical/DrillReduceAggregatesRule.java | 47 ++++++++++++++------
.../drill/exec/planner/physical/AggPrelBase.java | 51 ++++++++++++++++------
.../drill/exec/fn/impl/TestAggregateFunctions.java | 18 ++++++++
3 files changed, 89 insertions(+), 27 deletions(-)
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java
index 062fda0c34..b386361adf 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java
@@ -343,12 +343,10 @@ public class DrillReduceAggregatesRule extends RelOptRule {
SqlAggFunction sumAgg =
new DrillCalciteSqlSumEmptyIsZeroAggFunctionWrapper(
new SqlSumEmptyIsZeroAggFunction(), sumType);
- AggregateCall sumCall = AggregateCall.create(sumAgg, oldCall.isDistinct(),
- oldCall.isApproximate(), oldCall.getArgList(), -1, sumType, null);
+ AggregateCall sumCall = getAggCall(oldCall, sumAgg, sumType);
final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
final RelDataType countType = countAgg.getReturnType(typeFactory);
- AggregateCall countCall = AggregateCall.create(countAgg, oldCall.isDistinct(),
- oldCall.isApproximate(), oldCall.getArgList(), -1, countType, null);
+ AggregateCall countCall = getAggCall(oldCall, countAgg, countType);
RexNode tmpsumRef =
rexBuilder.addAggCall(
@@ -414,6 +412,21 @@ public class DrillReduceAggregatesRule extends RelOptRule {
}
}
+  /**
+   * Creates a replacement aggregate call that uses {@code aggFunction} and {@code returnType}.
+   * Every other attribute is copied from {@code oldCall}: DISTINCT, APPROXIMATE, IGNORE NULLS,
+   * the argument list, the FILTER argument, the distinct keys and the collation.
+   * Copying the filter argument is what lets reduced calls (for example AVG rewritten into
+   * SUM and COUNT) keep the filters generated for PIVOT.
+   *
+   * @param oldCall     the aggregate call being reduced
+   * @param aggFunction the aggregate function used by the new call
+   * @param returnType  return type of the new call; callers pass both SUM and COUNT types,
+   *                    so this is not necessarily a sum type
+   * @return a new, unnamed aggregate call
+   */
+  private static AggregateCall getAggCall(AggregateCall oldCall,
+                                          SqlAggFunction aggFunction,
+                                          RelDataType returnType) {
+    return AggregateCall.create(aggFunction,
+        oldCall.isDistinct(),
+        oldCall.isApproximate(),
+        oldCall.ignoreNulls(),
+        oldCall.getArgList(),
+        oldCall.filterArg,
+        oldCall.distinctKeys,
+        oldCall.getCollation(),
+        returnType,
+        null);
+  }
+
private RexNode reduceSum(
Aggregate oldAggRel,
AggregateCall oldCall,
@@ -441,12 +454,10 @@ public class DrillReduceAggregatesRule extends RelOptRule {
}
sumZeroAgg = new DrillCalciteSqlSumEmptyIsZeroAggFunctionWrapper(
new SqlSumEmptyIsZeroAggFunction(), sumType);
- AggregateCall sumZeroCall = AggregateCall.create(sumZeroAgg, oldCall.isDistinct(),
- oldCall.isApproximate(), oldCall.getArgList(), -1, sumType, null);
+ AggregateCall sumZeroCall = getAggCall(oldCall, sumZeroAgg, sumType);
final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
final RelDataType countType = countAgg.getReturnType(typeFactory);
- AggregateCall countCall = AggregateCall.create(countAgg, oldCall.isDistinct(),
- oldCall.isApproximate(), oldCall.getArgList(), -1, countType, null);
+ AggregateCall countCall = getAggCall(oldCall, countAgg, countType);
// NOTE: these references are with respect to the output
// of newAggRel
RexNode sumZeroRef =
@@ -529,8 +540,11 @@ public class DrillReduceAggregatesRule extends RelOptRule {
new SqlSumAggFunction(sumType), sumType),
oldCall.isDistinct(),
oldCall.isApproximate(),
+ oldCall.ignoreNulls(),
ImmutableIntList.of(argSquaredOrdinal),
- -1,
+ oldCall.filterArg,
+ oldCall.distinctKeys,
+ oldCall.getCollation(),
sumType,
null);
final RexNode sumArgSquared =
@@ -547,8 +561,11 @@ public class DrillReduceAggregatesRule extends RelOptRule {
new SqlSumAggFunction(sumType), sumType),
oldCall.isDistinct(),
oldCall.isApproximate(),
+ oldCall.ignoreNulls(),
ImmutableIntList.of(argOrdinal),
- -1,
+ oldCall.filterArg,
+ oldCall.distinctKeys,
+ oldCall.getCollation(),
sumType,
null);
final RexNode sumArg =
@@ -565,8 +582,7 @@ public class DrillReduceAggregatesRule extends RelOptRule {
final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
final RelDataType countType = countAgg.getReturnType(typeFactory);
- final AggregateCall countArgAggCall = AggregateCall.create(countAgg, oldCall.isDistinct(),
- oldCall.isApproximate(), oldCall.getArgList(), -1, countType, null);
+ final AggregateCall countArgAggCall = getAggCall(oldCall, countAgg, countType);
final RexNode countArg =
rexBuilder.addAggCall(
countArgAggCall,
@@ -677,7 +693,7 @@ public class DrillReduceAggregatesRule extends RelOptRule {
RelNode inputRel,
List<AggregateCall> newCalls) {
RelOptCluster cluster = inputRel.getCluster();
- return new LogicalAggregate(cluster, cluster.traitSetOf(Convention.NONE),
+ return new LogicalAggregate(cluster, cluster.traitSetOf(Convention.NONE), Collections.emptyList(),
inputRel, oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), newCalls);
}
@@ -722,8 +738,11 @@ public class DrillReduceAggregatesRule extends RelOptRule {
sumZeroAgg,
oldAggregateCall.isDistinct(),
oldAggregateCall.isApproximate(),
+ oldAggregateCall.ignoreNulls(),
oldAggregateCall.getArgList(),
- -1,
+ oldAggregateCall.filterArg,
+ oldAggregateCall.distinctKeys,
+ oldAggregateCall.getCollation(),
sumType,
oldAggregateCall.getName());
oldAggRel.getCluster().getRexBuilder()
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java
index f9a7d0e099..a8619aa8d3 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java
@@ -33,6 +33,7 @@ import org.apache.drill.exec.planner.common.DrillAggregateRelBase;
import org.apache.drill.exec.planner.physical.visitor.PrelVisitor;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.InvalidRelException;
+import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
@@ -42,6 +43,7 @@ import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
+import org.apache.calcite.util.Optionality;
import java.util.Collections;
import java.util.Iterator;
@@ -61,7 +63,7 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel
// phase
PHASE_2of2("2nd");
- private String name;
+ private final String name;
OperatorPhase(String name) {
this.name = name;
@@ -99,7 +101,7 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel
* creating a SUM whose return type is non-nullable.
*
*/
- public class SqlSumCountAggFunction extends SqlAggFunction {
+ public static class SqlSumCountAggFunction extends SqlAggFunction {
private final RelDataType type;
@@ -112,7 +114,8 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel
OperandTypes.NUMERIC,
SqlFunctionCategory.NUMERIC,
false,
- false);
+ false,
+ Optionality.FORBIDDEN);
this.type = type;
}
@@ -175,8 +178,11 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel
sumAggFun,
aggCall.e.isDistinct(),
aggCall.e.isApproximate(),
+ false,
Collections.singletonList(aggExprOrdinal),
aggCall.e.filterArg,
+ null,
+ RelCollations.EMPTY,
aggCall.e.getType(),
aggCall.e.getName());
@@ -187,8 +193,11 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel
aggCall.e.getAggregation(),
aggCall.e.isDistinct(),
aggCall.e.isApproximate(),
+ false,
Collections.singletonList(aggExprOrdinal),
aggCall.e.filterArg,
+ null,
+ RelCollations.EMPTY,
aggCall.e.getType(),
aggCall.e.getName());
@@ -202,21 +211,29 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel
List<LogicalExpression> args = Lists.newArrayList();
for (Integer i : call.getArgList()) {
LogicalExpression expr = FieldReference.getWithQuotedRef(fn.get(i));
- if (call.hasFilter()) {
- expr = IfExpression.newBuilder()
- .setIfCondition(new IfExpression.IfCondition(FieldReference.getWithQuotedRef(fn.get(call.filterArg)), expr))
- .setElse(NullExpression.INSTANCE)
- .build();
- }
+ expr = getArgumentExpression(call, fn, expr);
args.add(expr);
}
if (SqlKind.COUNT.name().equals(call.getAggregation().getName()) && args.isEmpty()) {
- args.add(new ValueExpressions.LongExpression(1L));
+ LogicalExpression expr = new ValueExpressions.LongExpression(1L);
+ expr = getArgumentExpression(call, fn, expr);
+ args.add(expr);
}
return new FunctionCall(call.getAggregation().getName().toLowerCase(), args, ExpressionPosition.UNKNOWN);
}
+  /**
+   * Applies the aggregate call's FILTER clause to one argument expression.
+   * If {@code call} has no filter, {@code expr} is returned as-is. Otherwise the result is
+   * an IF expression that yields {@code expr} when the filter column is true and NULL otherwise.
+   *
+   * @param call aggregate call whose filter, if any, is applied
+   * @param fn   input field names, looked up by the call's argument and filter ordinals
+   * @param expr argument expression to guard
+   * @return the guarded expression, or {@code expr} unchanged when there is no filter
+   */
+  private static LogicalExpression getArgumentExpression(AggregateCall call, List<String> fn,
+      LogicalExpression expr) {
+    if (!call.hasFilter()) {
+      return expr;
+    }
+    LogicalExpression filterColumn = FieldReference.getWithQuotedRef(fn.get(call.filterArg));
+    IfExpression.IfCondition whenFiltered = new IfExpression.IfCondition(filterColumn, expr);
+    return IfExpression.newBuilder()
+        .setIfCondition(whenFiltered)
+        .setElse(NullExpression.INSTANCE)
+        .build();
+  }
+
@Override
public Iterator<Prel> iterator() {
return PrelUtil.iter(getInput());
@@ -249,9 +266,17 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel
for (int arg : aggCall.getArgList()) {
arglist.add(arg + 1);
}
- aggregateCalls.add(AggregateCall.create(aggCall.getAggregation(), aggCall.isDistinct(),
- aggCall.isApproximate(), arglist, aggCall.filterArg, aggCall.type, aggCall.name));
+ aggregateCalls.add(AggregateCall.create(aggCall.getAggregation(),
+ aggCall.isDistinct(),
+ aggCall.isApproximate(),
+ false,
+ arglist,
+ aggCall.filterArg,
+ null,
+ RelCollations.EMPTY,
+ aggCall.type,
+ aggCall.name));
}
- return (Prel) copy(traitSet, children.get(0),indicator,groupingSet,groupingSets, aggregateCalls);
+ return (Prel) copy(traitSet, children.get(0), groupingSet, groupingSets, aggregateCalls);
}
}
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestAggregateFunctions.java b/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestAggregateFunctions.java
index 97f3c254b2..edf619074b 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestAggregateFunctions.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestAggregateFunctions.java
@@ -1271,4 +1271,22 @@ public class TestAggregateFunctions extends ClusterTest {
.baselineValues(5L, 5L, 5L, 5L, 5L)
.go();
}
+
+  /**
+   * Runs PIVOT over AVG aggregates. AVG is reduced into SUM and COUNT calls, and those
+   * calls must keep the filters generated for each pivot value (DRILL-8403).
+   */
+  @Test
+  public void testAggregateWithPivot() throws Exception {
+    String sourceRows = "SELECT education_level, salary, marital_status, extract(year from age(birth_date)) age\n" +
+        "FROM cp.`employee.json`";
+    String pivotClause = "PIVOT (avg(salary) avg_salary, avg(age) avg_age FOR marital_status IN ('M' married, 'S' single))";
+    String query = "SELECT * FROM (\n" + sourceRows + ")\n" + pivotClause;
+
+    testBuilder()
+        .sqlQuery(query)
+        .unOrdered()
+        .baselineColumns("education_level",
+            "married_avg_salary", "married_avg_age",
+            "single_avg_salary", "single_avg_age")
+        .baselineValues("Graduate Degree", 4038.470588235294, 101.98823529411764, 4747.176470588235, 98.65882352941176)
+        .baselineValues("Bachelors Degree", 4789.166666666667, 102.43055555555556, 4193.566433566433, 102.02797202797203)
+        .baselineValues("Partial College", 4281.381578947368, 99.25657894736842, 3785.294117647059, 101.04411764705883)
+        .baselineValues("High School Degree", 3459.2805755395684, 103.57553956834532, 3571.830985915493, 102.69014084507042)
+        .baselineValues("Partial High School", 3555.8064516129034, 101.14516129032258, 3469.7014925373132, 103.3731343283582)
+        .go();
+  }
}