You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@drill.apache.org by dz...@apache.org on 2023/02/23 08:54:46 UTC

[drill] branch master updated: DRILL-8403: Generated aggregate function calls are missing required filters when used with PIVOT (#2765)

This is an automated email from the ASF dual-hosted git repository.

dzamo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/drill.git


The following commit(s) were added to refs/heads/master by this push:
     new 6254eb66e7 DRILL-8403: Generated aggregate function calls are missing required filters when used with PIVOT (#2765)
6254eb66e7 is described below

commit 6254eb66e747bb445fd815edc1a30bd1114f26f0
Author: Volodymyr Vysotskyi <vv...@gmail.com>
AuthorDate: Thu Feb 23 10:54:38 2023 +0200

    DRILL-8403: Generated aggregate function calls are missing required filters when used with PIVOT (#2765)
---
 .../planner/logical/DrillReduceAggregatesRule.java | 47 ++++++++++++++------
 .../drill/exec/planner/physical/AggPrelBase.java   | 51 ++++++++++++++++------
 .../drill/exec/fn/impl/TestAggregateFunctions.java | 18 ++++++++
 3 files changed, 89 insertions(+), 27 deletions(-)

diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java
index 062fda0c34..b386361adf 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java
@@ -343,12 +343,10 @@ public class DrillReduceAggregatesRule extends RelOptRule {
     SqlAggFunction sumAgg =
         new DrillCalciteSqlSumEmptyIsZeroAggFunctionWrapper(
           new SqlSumEmptyIsZeroAggFunction(), sumType);
-    AggregateCall sumCall = AggregateCall.create(sumAgg, oldCall.isDistinct(),
-        oldCall.isApproximate(), oldCall.getArgList(), -1, sumType, null);
+    AggregateCall sumCall = getAggCall(oldCall, sumAgg, sumType);
     final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
     final RelDataType countType = countAgg.getReturnType(typeFactory);
-    AggregateCall countCall = AggregateCall.create(countAgg, oldCall.isDistinct(),
-        oldCall.isApproximate(), oldCall.getArgList(), -1, countType, null);
+    AggregateCall countCall = getAggCall(oldCall, countAgg, countType);
 
     RexNode tmpsumRef =
         rexBuilder.addAggCall(
@@ -414,6 +412,21 @@ public class DrillReduceAggregatesRule extends RelOptRule {
     }
   }
 
+  private static AggregateCall getAggCall(AggregateCall oldCall, // builds a new call that copies the flags, arguments and filter of oldCall
+      SqlAggFunction aggFunction, // aggregate function the new call invokes
+      RelDataType sumType) { // result type of the new call; despite the name, callers also pass the COUNT type here
+    return AggregateCall.create(aggFunction,
+        oldCall.isDistinct(),
+        oldCall.isApproximate(),
+        oldCall.ignoreNulls(),
+        oldCall.getArgList(),
+        oldCall.filterArg, // propagate the original filter argument; the replaced call sites passed -1 here
+        oldCall.distinctKeys,
+        oldCall.getCollation(),
+        sumType,
+        null); // last argument left null, exactly as in the call sites this helper replaces
+  }
+
   private RexNode reduceSum(
       Aggregate oldAggRel,
       AggregateCall oldCall,
@@ -441,12 +454,10 @@ public class DrillReduceAggregatesRule extends RelOptRule {
     }
     sumZeroAgg = new DrillCalciteSqlSumEmptyIsZeroAggFunctionWrapper(
         new SqlSumEmptyIsZeroAggFunction(), sumType);
-    AggregateCall sumZeroCall = AggregateCall.create(sumZeroAgg, oldCall.isDistinct(),
-        oldCall.isApproximate(), oldCall.getArgList(), -1, sumType, null);
+    AggregateCall sumZeroCall = getAggCall(oldCall, sumZeroAgg, sumType);
     final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
     final RelDataType countType = countAgg.getReturnType(typeFactory);
-    AggregateCall countCall = AggregateCall.create(countAgg, oldCall.isDistinct(),
-        oldCall.isApproximate(), oldCall.getArgList(), -1, countType, null);
+    AggregateCall countCall = getAggCall(oldCall, countAgg, countType);
     // NOTE:  these references are with respect to the output
     // of newAggRel
     RexNode sumZeroRef =
@@ -529,8 +540,11 @@ public class DrillReduceAggregatesRule extends RelOptRule {
                 new SqlSumAggFunction(sumType), sumType),
             oldCall.isDistinct(),
             oldCall.isApproximate(),
+            oldCall.ignoreNulls(),
             ImmutableIntList.of(argSquaredOrdinal),
-            -1,
+            oldCall.filterArg,
+            oldCall.distinctKeys,
+            oldCall.getCollation(),
             sumType,
             null);
     final RexNode sumArgSquared =
@@ -547,8 +561,11 @@ public class DrillReduceAggregatesRule extends RelOptRule {
                 new SqlSumAggFunction(sumType), sumType),
             oldCall.isDistinct(),
             oldCall.isApproximate(),
+            oldCall.ignoreNulls(),
             ImmutableIntList.of(argOrdinal),
-            -1,
+            oldCall.filterArg,
+            oldCall.distinctKeys,
+            oldCall.getCollation(),
             sumType,
             null);
     final RexNode sumArg =
@@ -565,8 +582,7 @@ public class DrillReduceAggregatesRule extends RelOptRule {
 
     final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
     final RelDataType countType = countAgg.getReturnType(typeFactory);
-    final AggregateCall countArgAggCall = AggregateCall.create(countAgg, oldCall.isDistinct(),
-        oldCall.isApproximate(), oldCall.getArgList(), -1, countType, null);
+    final AggregateCall countArgAggCall = getAggCall(oldCall, countAgg, countType);
     final RexNode countArg =
         rexBuilder.addAggCall(
             countArgAggCall,
@@ -677,7 +693,7 @@ public class DrillReduceAggregatesRule extends RelOptRule {
       RelNode inputRel,
       List<AggregateCall> newCalls) {
     RelOptCluster cluster = inputRel.getCluster();
-    return new LogicalAggregate(cluster, cluster.traitSetOf(Convention.NONE),
+    return new LogicalAggregate(cluster, cluster.traitSetOf(Convention.NONE), Collections.emptyList(),
         inputRel, oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), newCalls);
   }
 
@@ -722,8 +738,11 @@ public class DrillReduceAggregatesRule extends RelOptRule {
                   sumZeroAgg,
                   oldAggregateCall.isDistinct(),
                   oldAggregateCall.isApproximate(),
+                  oldAggregateCall.ignoreNulls(),
                   oldAggregateCall.getArgList(),
-                  -1,
+                  oldAggregateCall.filterArg,
+                  oldAggregateCall.distinctKeys,
+                  oldAggregateCall.getCollation(),
                   sumType,
                   oldAggregateCall.getName());
           oldAggRel.getCluster().getRexBuilder()
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java
index f9a7d0e099..a8619aa8d3 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java
@@ -33,6 +33,7 @@ import org.apache.drill.exec.planner.common.DrillAggregateRelBase;
 import org.apache.drill.exec.planner.physical.visitor.PrelVisitor;
 import org.apache.calcite.rel.core.AggregateCall;
 import org.apache.calcite.rel.InvalidRelException;
+import org.apache.calcite.rel.RelCollations;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
@@ -42,6 +43,7 @@ import org.apache.calcite.sql.SqlFunctionCategory;
 import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.sql.type.OperandTypes;
 import org.apache.calcite.sql.type.ReturnTypes;
+import org.apache.calcite.util.Optionality;
 
 import java.util.Collections;
 import java.util.Iterator;
@@ -61,7 +63,7 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel
     // phase
     PHASE_2of2("2nd");
 
-    private String name;
+    private final String name;
 
     OperatorPhase(String name) {
       this.name = name;
@@ -99,7 +101,7 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel
    * creating a SUM whose return type is non-nullable.
    *
    */
-  public class SqlSumCountAggFunction extends SqlAggFunction {
+  public static class SqlSumCountAggFunction extends SqlAggFunction {
 
     private final RelDataType type;
 
@@ -112,7 +114,8 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel
           OperandTypes.NUMERIC,
           SqlFunctionCategory.NUMERIC,
           false,
-          false);
+          false,
+          Optionality.FORBIDDEN);
 
       this.type = type;
     }
@@ -175,8 +178,11 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel
                   sumAggFun,
                   aggCall.e.isDistinct(),
                   aggCall.e.isApproximate(),
+                  false,
                   Collections.singletonList(aggExprOrdinal),
                   aggCall.e.filterArg,
+                  null,
+                  RelCollations.EMPTY,
                   aggCall.e.getType(),
                   aggCall.e.getName());
 
@@ -187,8 +193,11 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel
                   aggCall.e.getAggregation(),
                   aggCall.e.isDistinct(),
                   aggCall.e.isApproximate(),
+                  false,
                   Collections.singletonList(aggExprOrdinal),
                   aggCall.e.filterArg,
+                  null,
+                  RelCollations.EMPTY,
                   aggCall.e.getType(),
                   aggCall.e.getName());
 
@@ -202,21 +211,29 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel
     List<LogicalExpression> args = Lists.newArrayList();
     for (Integer i : call.getArgList()) {
       LogicalExpression expr = FieldReference.getWithQuotedRef(fn.get(i));
-      if (call.hasFilter()) {
-        expr = IfExpression.newBuilder()
-          .setIfCondition(new IfExpression.IfCondition(FieldReference.getWithQuotedRef(fn.get(call.filterArg)), expr))
-          .setElse(NullExpression.INSTANCE)
-          .build();
-      }
+      expr = getArgumentExpression(call, fn, expr);
       args.add(expr);
     }
 
     if (SqlKind.COUNT.name().equals(call.getAggregation().getName()) && args.isEmpty()) {
-      args.add(new ValueExpressions.LongExpression(1L));
+      LogicalExpression expr = new ValueExpressions.LongExpression(1L);
+      expr = getArgumentExpression(call, fn, expr);
+      args.add(expr);
     }
     return new FunctionCall(call.getAggregation().getName().toLowerCase(), args, ExpressionPosition.UNKNOWN);
   }
 
+  private static LogicalExpression getArgumentExpression(AggregateCall call, List<String> fn, // fn is indexed by call.filterArg to obtain the filter field name
+      LogicalExpression expr) { // expr is the aggregate argument to guard with the call's filter, if any
+    if (call.hasFilter()) { // filtered call: wrap the argument in an if/else expression
+      return IfExpression.newBuilder()
+          .setIfCondition(new IfExpression.IfCondition(FieldReference.getWithQuotedRef(fn.get(call.filterArg)), expr)) // condition is a reference to the filter field, result is expr
+          .setElse(NullExpression.INSTANCE) // otherwise the argument becomes NULL
+          .build();
+    }
+    return expr; // no filter on the call: the argument is returned unchanged
+  }
+
   @Override
   public Iterator<Prel> iterator() {
     return PrelUtil.iter(getInput());
@@ -249,9 +266,17 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel
       for (int arg : aggCall.getArgList()) {
         arglist.add(arg + 1);
       }
-      aggregateCalls.add(AggregateCall.create(aggCall.getAggregation(), aggCall.isDistinct(),
-              aggCall.isApproximate(), arglist, aggCall.filterArg, aggCall.type, aggCall.name));
+      aggregateCalls.add(AggregateCall.create(aggCall.getAggregation(),
+          aggCall.isDistinct(),
+          aggCall.isApproximate(),
+          false,
+          arglist,
+          aggCall.filterArg,
+          null,
+          RelCollations.EMPTY,
+          aggCall.type,
+          aggCall.name));
     }
-    return (Prel) copy(traitSet, children.get(0),indicator,groupingSet,groupingSets, aggregateCalls);
+    return (Prel) copy(traitSet, children.get(0), groupingSet, groupingSets, aggregateCalls);
   }
 }
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestAggregateFunctions.java b/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestAggregateFunctions.java
index 97f3c254b2..edf619074b 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestAggregateFunctions.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestAggregateFunctions.java
@@ -1271,4 +1271,22 @@ public class TestAggregateFunctions extends ClusterTest {
       .baselineValues(5L, 5L, 5L, 5L, 5L)
       .go();
   }
+
+  @Test // added with the DRILL-8403 fix (aggregate calls generated for PIVOT were missing their filters)
+  public void testAggregateWithPivot() throws Exception {
+    String query = "SELECT * FROM (\n" +
+        "SELECT education_level, salary, marital_status, extract(year from age(birth_date)) age\n" +
+        "FROM cp.`employee.json`)\n" +
+        "PIVOT (avg(salary) avg_salary, avg(age) avg_age FOR marital_status IN ('M' married, 'S' single))"; // two AVG aggregates pivoted over two marital_status values
+    testBuilder()
+      .sqlQuery(query)
+      .unOrdered()
+      .baselineColumns("education_level", "married_avg_salary", "married_avg_age", "single_avg_salary", "single_avg_age") // one column per pivot alias / aggregate alias pair
+      .baselineValues("Graduate Degree", 4038.470588235294, 101.98823529411764, 4747.176470588235, 98.65882352941176) // each baseline row: education_level followed by the four averages
+      .baselineValues("Bachelors Degree", 4789.166666666667, 102.43055555555556, 4193.566433566433, 102.02797202797203)
+      .baselineValues("Partial College", 4281.381578947368, 99.25657894736842, 3785.294117647059, 101.04411764705883)
+      .baselineValues("High School Degree", 3459.2805755395684, 103.57553956834532, 3571.830985915493, 102.69014084507042)
+      .baselineValues("Partial High School", 3555.8064516129034, 101.14516129032258, 3469.7014925373132, 103.3731343283582)
+      .go();
+  }
 }