You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@drill.apache.org by hs...@apache.org on 2016/04/27 06:48:52 UTC

drill git commit: DRILL-4529: Force $SUM0 to be used when Window Sum is supposed to return a non-nullable type

Repository: drill
Updated Branches:
  refs/heads/master 8176fbca6 -> 5705d4509


DRILL-4529: Force $SUM0 to be used when Window Sum is supposed to return a non-nullable type


Project: http://git-wip-us.apache.org/repos/asf/drill/repo
Commit: http://git-wip-us.apache.org/repos/asf/drill/commit/5705d450
Tree: http://git-wip-us.apache.org/repos/asf/drill/tree/5705d450
Diff: http://git-wip-us.apache.org/repos/asf/drill/diff/5705d450

Branch: refs/heads/master
Commit: 5705d45095bd89fb9d1e7b3b3e12c34e74930c4c
Parents: 8176fbc
Author: Hsuan-Yi Chu <hs...@usc.edu>
Authored: Sun Mar 27 16:18:21 2016 -0700
Committer: Hsuan-Yi Chu <hs...@usc.edu>
Committed: Tue Apr 26 16:22:21 2016 -0700

----------------------------------------------------------------------
 .../apache/drill/exec/planner/PlannerPhase.java |  3 +-
 .../logical/DrillReduceAggregatesRule.java      | 93 +++++++++++++++++---
 .../drill/TestFunctionsWithTypeExpoQueries.java | 12 ++-
 .../apache/drill/exec/TestWindowFunctions.java  |  6 +-
 4 files changed, 98 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/drill/blob/5705d450/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java
----------------------------------------------------------------------
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java
index 2875bcf..22a8b6f 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java
@@ -141,7 +141,8 @@ public enum PlannerPhase {
     public RuleSet getRules(OptimizerRulesContext context, Collection<StoragePlugin> plugins) {
       return PlannerPhase.mergedRuleSets(
           RuleSets.ofList(
-              DrillReduceAggregatesRule.INSTANCE_SUM),
+              DrillReduceAggregatesRule.INSTANCE_SUM,
+              DrillReduceAggregatesRule.INSTANCE_WINDOW_SUM),
           getStorageRules(context, plugins, this)
           );
     }

http://git-wip-us.apache.org/repos/asf/drill/blob/5705d450/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java
----------------------------------------------------------------------
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java
index dd2fc14..243e4db 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java
@@ -31,6 +31,7 @@ import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import org.apache.calcite.rel.InvalidRelException;
 import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.Window;
 import org.apache.calcite.rel.logical.LogicalAggregate;
 import org.apache.calcite.sql.SqlOperator;
 import org.apache.calcite.sql.SqlOperatorBinding;
@@ -40,7 +41,6 @@ import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.util.trace.CalciteTrace;
 import org.apache.drill.exec.planner.physical.PlannerSettings;
 import org.apache.drill.exec.planner.sql.DrillCalciteSqlAggFunctionWrapper;
-import org.apache.drill.exec.planner.sql.DrillCalciteSqlWrapper;
 import org.apache.drill.exec.planner.sql.DrillSqlOperator;
 import org.apache.calcite.rel.core.AggregateCall;
 import org.apache.calcite.rel.RelNode;
@@ -82,6 +82,9 @@ public class DrillReduceAggregatesRule extends RelOptRule {
   public static final DrillConvertSumToSumZero INSTANCE_SUM =
       new DrillConvertSumToSumZero(operand(DrillAggregateRel.class, any()));
 
+  public static final DrillConvertWindowSumToSumZero INSTANCE_WINDOW_SUM =
+          new DrillConvertWindowSumToSumZero(operand(DrillWindowRel.class, any()));
+
   private static final DrillSqlOperator CastHighOp = new DrillSqlOperator("CastHigh", 1, false,
       new SqlReturnTypeInference() {
         @Override
@@ -695,12 +698,7 @@ public class DrillReduceAggregatesRule extends RelOptRule {
     public boolean matches(RelOptRuleCall call) {
       DrillAggregateRel oldAggRel = (DrillAggregateRel) call.rels[0];
       for (AggregateCall aggregateCall : oldAggRel.getAggCallList()) {
-        final SqlAggFunction sqlAggFunction = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(aggregateCall.getAggregation());
-        if(sqlAggFunction instanceof SqlSumAggFunction
-            && !aggregateCall.getType().isNullable()) {
-          // If SUM(x) is not nullable, the validator must have determined that
-          // nulls are impossible (because the group is never empty and x is never
-          // null). Therefore we translate to SUM0(x).
+        if(isConversionToSumZeroNeeded(aggregateCall.getAggregation(), aggregateCall.getType())) {
           return true;
         }
       }
@@ -714,10 +712,7 @@ public class DrillReduceAggregatesRule extends RelOptRule {
       final Map<AggregateCall, RexNode> aggCallMapping = Maps.newHashMap();
       final List<AggregateCall> newAggregateCalls = Lists.newArrayList();
       for (AggregateCall oldAggregateCall : oldAggRel.getAggCallList()) {
-        final SqlAggFunction sqlAggFunction = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(
-            oldAggregateCall.getAggregation());
-        if(sqlAggFunction instanceof SqlSumAggFunction
-            && !oldAggregateCall.getType().isNullable()) {
+        if(isConversionToSumZeroNeeded(oldAggregateCall.getAggregation(), oldAggregateCall.getType())) {
           final RelDataType argType = oldAggregateCall.getType();
           final RelDataType sumType = oldAggRel.getCluster().getTypeFactory()
               .createTypeWithNullability(argType, argType.isNullable());
@@ -756,5 +751,81 @@ public class DrillReduceAggregatesRule extends RelOptRule {
       }
     }
   }
+
+  private static class DrillConvertWindowSumToSumZero extends RelOptRule {
+    public DrillConvertWindowSumToSumZero(RelOptRuleOperand operand) {
+      super(operand);
+    }
+
+    @Override
+    public boolean matches(RelOptRuleCall call) {
+      final DrillWindowRel oldWinRel = (DrillWindowRel) call.rels[0];
+      for(Window.Group group : oldWinRel.groups) {
+        for(Window.RexWinAggCall rexWinAggCall : group.aggCalls) {
+          if(isConversionToSumZeroNeeded(rexWinAggCall.getOperator(), rexWinAggCall.getType())) {
+            return true;
+          }
+        }
+      }
+      return false;
+    }
+
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+      final DrillWindowRel oldWinRel = (DrillWindowRel) call.rels[0];
+      final ImmutableList.Builder<Window.Group> builder = ImmutableList.builder();
+
+      for(Window.Group group : oldWinRel.groups) {
+        final List<Window.RexWinAggCall> aggCalls = Lists.newArrayList();
+        for(Window.RexWinAggCall rexWinAggCall : group.aggCalls) {
+          if(isConversionToSumZeroNeeded(rexWinAggCall.getOperator(), rexWinAggCall.getType())) {
+            final RelDataType argType = rexWinAggCall.getType();
+            final RelDataType sumType = oldWinRel.getCluster().getTypeFactory()
+                .createTypeWithNullability(argType, argType.isNullable());
+            final SqlAggFunction sumZeroAgg = new DrillCalciteSqlAggFunctionWrapper(
+                new SqlSumEmptyIsZeroAggFunction(), sumType);
+            final Window.RexWinAggCall sumZeroCall =
+                new Window.RexWinAggCall(
+                    sumZeroAgg,
+                    sumType,
+                    rexWinAggCall.operands,
+                    rexWinAggCall.ordinal);
+            aggCalls.add(sumZeroCall);
+          } else {
+            aggCalls.add(rexWinAggCall);
+          }
+        }
+
+        final Window.Group newGroup = new Window.Group(
+            group.keys,
+            group.isRows,
+            group.lowerBound,
+            group.upperBound,
+            group.orderKeys,
+            aggCalls);
+        builder.add(newGroup);
+      }
+
+      call.transformTo(new DrillWindowRel(
+          oldWinRel.getCluster(),
+          oldWinRel.getTraitSet(),
+          oldWinRel.getInput(),
+          oldWinRel.constants,
+          oldWinRel.getRowType(),
+          builder.build()));
+    }
+  }
+
+  private static boolean isConversionToSumZeroNeeded(SqlOperator sqlOperator, RelDataType type) {
+    sqlOperator = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(sqlOperator);
+    if(sqlOperator instanceof SqlSumAggFunction
+        && !type.isNullable()) {
+      // If SUM(x) is not nullable, the validator must have determined that
+      // nulls are impossible (because the group is never empty and x is never
+      // null). Therefore we translate to SUM0(x).
+      return true;
+    }
+    return false;
+  }
 }
 

http://git-wip-us.apache.org/repos/asf/drill/blob/5705d450/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java
----------------------------------------------------------------------
diff --git a/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java b/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java
index 5d16edd..43b594b 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java
@@ -22,7 +22,6 @@ import org.apache.commons.lang3.tuple.Pair;
 import org.apache.drill.common.expression.SchemaPath;
 import org.apache.drill.common.types.TypeProtos;
 import org.apache.drill.common.util.FileUtils;
-import org.junit.Ignore;
 import org.junit.Test;
 
 import java.util.List;
@@ -709,4 +708,15 @@ public class TestFunctionsWithTypeExpoQueries extends BaseTestQuery {
         .build()
         .run();
   }
+
+  @Test // DRILL-4529
+  public void testWindowSumConstant() throws Exception {
+    final String query = "select sum(1) over w as col \n" +
+        "from cp.`tpch/region.parquet` \n" +
+        "window w as (partition by r_regionkey)";
+
+    final String[] expectedPlan = {"\\$SUM0"};
+    final String[] excludedPlan = {};
+    PlanTestBase.testPlanMatchingPatterns(query, expectedPlan, excludedPlan);
+  }
 }

http://git-wip-us.apache.org/repos/asf/drill/blob/5705d450/exec/java-exec/src/test/java/org/apache/drill/exec/TestWindowFunctions.java
----------------------------------------------------------------------
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/TestWindowFunctions.java b/exec/java-exec/src/test/java/org/apache/drill/exec/TestWindowFunctions.java
index 8055774..1d9900c 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/TestWindowFunctions.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/TestWindowFunctions.java
@@ -580,7 +580,7 @@ public class TestWindowFunctions extends BaseTestQuery {
         "window w as(partition by position_id order by employee_id)";
 
     // Validate the plan
-    final String[] expectedPlan = {"Window.*partition \\{0\\} order by \\[1\\].*RANK\\(\\), SUM\\(\\$2\\), SUM\\(\\$1\\), SUM\\(\\$3\\)",
+    final String[] expectedPlan = {"Window.*partition \\{0\\} order by \\[1\\].*RANK\\(\\), \\$SUM0\\(\\$2\\), SUM\\(\\$1\\), \\$SUM0\\(\\$3\\)",
         "Scan.*columns=\\[`position_id`, `employee_id`\\]"};
     final String[] excludedPatterns = {"Scan.*columns=\\[`\\*`\\]"};
     PlanTestBase.testPlanMatchingPatterns(query, expectedPlan, excludedPatterns);
@@ -705,10 +705,10 @@ public class TestWindowFunctions extends BaseTestQuery {
         "order by 1, 2, 3, 4", root);
 
     // Validate the plan
-    final String[] expectedPlan = {"Window.*SUM\\(\\$3\\).*\n" +
+    final String[] expectedPlan = {"Window.*\\$SUM0\\(\\$3\\).*\n" +
         ".*SelectionVectorRemover.*\n" +
         ".*Sort.*\n" +
-        ".*Window.*SUM\\(\\$2\\).*"
+        ".*Window.*\\$SUM0\\(\\$2\\).*"
     };
     PlanTestBase.testPlanMatchingPatterns(query, expectedPlan, new String[]{});