You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@drill.apache.org by me...@apache.org on 2015/06/19 20:23:35 UTC

drill git commit: DRILL-3254: Fix wrong results from the AVG, STDDEV_* and VAR_* window functions, which are rewritten using AvgVarianceConvertlet. Added a custom Drill convertlet to correctly reduce these aggregates.

Repository: drill
Updated Branches:
  refs/heads/master 710f82942 -> 6ebfbb9d0


DRILL-3254: Fix wrong results from the AVG, STDDEV_* and VAR_* window functions, which are rewritten using AvgVarianceConvertlet.
Added a custom Drill convertlet to correctly reduce these aggregates.


Project: http://git-wip-us.apache.org/repos/asf/drill/repo
Commit: http://git-wip-us.apache.org/repos/asf/drill/commit/6ebfbb9d
Tree: http://git-wip-us.apache.org/repos/asf/drill/tree/6ebfbb9d
Diff: http://git-wip-us.apache.org/repos/asf/drill/diff/6ebfbb9d

Branch: refs/heads/master
Commit: 6ebfbb9d0fc0b87b032f5e5d5cb0825f5464426e
Parents: 710f829
Author: Mehant Baid <me...@gmail.com>
Authored: Thu Jun 18 17:22:48 2015 -0700
Committer: Mehant Baid <me...@gmail.com>
Committed: Fri Jun 19 11:21:48 2015 -0700

----------------------------------------------------------------------
 .../planner/sql/DrillAvgVarianceConvertlet.java | 156 +++++++++++++++++++
 .../exec/planner/sql/DrillConvertletTable.java  |   6 +
 .../apache/drill/exec/TestWindowFunctions.java  |  28 ++++
 3 files changed, 190 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/drill/blob/6ebfbb9d/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillAvgVarianceConvertlet.java
----------------------------------------------------------------------
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillAvgVarianceConvertlet.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillAvgVarianceConvertlet.java
new file mode 100644
index 0000000..4c0618d
--- /dev/null
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillAvgVarianceConvertlet.java
@@ -0,0 +1,156 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.drill.exec.planner.sql;
+
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlLiteral;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlNumericLiteral;
+import org.apache.calcite.sql.fun.SqlAvgAggFunction;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.parser.SqlParserPos;
+import org.apache.calcite.sql2rel.SqlRexContext;
+import org.apache.calcite.sql2rel.SqlRexConvertlet;
+import org.apache.calcite.util.Util;
+
+/**
+ * Drill version of Calcite's AvgVarianceConvertlet, which rewrites AVG, STDDEV_POP,
+ * STDDEV_SAMP, VAR_POP and VAR_SAMP into SUM/COUNT arithmetic.
+ *
+ * <p>Drill needs its own copy because, when Calcite injects a cast before the division,
+ * it uses the output type of the aggregate function, which is still 'ANY' at that point,
+ * so the injected cast to 'ANY' does not solve the problem. This version applies Drill's
+ * CastHigh function (a cast to double for numeric input) explicitly:
+ * <ul>
+ *   <li>AVG: the SUM is widened before it is divided by COUNT;</li>
+ *   <li>STDDEV and VAR: the argument itself is widened before it is squared or summed,
+ *   so that sum(x) * sum(x) / count(x) is not truncated by integer division and
+ *   x * x is not computed in the argument's (possibly integer) type.</li>
+ * </ul>
+ */
+public class DrillAvgVarianceConvertlet implements SqlRexConvertlet {
+
+  /** Drill's CastHigh function, used to widen operands before any division. */
+  private static final DrillSqlOperator CAST_HIGH_OP = new DrillSqlOperator("CastHigh", 1, false);
+
+  /** Which aggregate this instance expands. */
+  private final SqlAvgAggFunction.Subtype subtype;
+
+  /**
+   * @param subtype the aggregate to expand: AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP or VAR_SAMP
+   */
+  public DrillAvgVarianceConvertlet(SqlAvgAggFunction.Subtype subtype) {
+    this.subtype = subtype;
+  }
+
+  /**
+   * Expands a single-operand AVG/STDDEV/VAR call into SUM/COUNT arithmetic and converts
+   * it to a {@link RexNode} of the type the validator derived for the call.
+   */
+  @Override
+  public RexNode convertCall(SqlRexContext cx, SqlCall call) {
+    assert call.operandCount() == 1;
+    final SqlNode arg = call.operand(0);
+    final SqlNode expr;
+    switch (subtype) {
+      case AVG:
+        expr = expandAvg(arg);
+        break;
+      case STDDEV_POP:
+        expr = expandVariance(arg, true, true);
+        break;
+      case STDDEV_SAMP:
+        expr = expandVariance(arg, false, true);
+        break;
+      case VAR_POP:
+        expr = expandVariance(arg, true, false);
+        break;
+      case VAR_SAMP:
+        expr = expandVariance(arg, false, false);
+        break;
+      default:
+        throw Util.unexpected(subtype);
+    }
+    final RelDataType type = cx.getValidator().getValidatedNodeType(call);
+    final RexNode rex = cx.convertExpression(expr);
+    // Make the expanded expression match the call's validated type.
+    return cx.getRexBuilder().ensureType(type, rex, true);
+  }
+
+  /** Expands avg(x) ==> CastHigh(sum(x)) / count(x). */
+  private SqlNode expandAvg(final SqlNode arg) {
+    final SqlParserPos pos = SqlParserPos.ZERO;
+    final SqlNode sum = SqlStdOperatorTable.SUM.createCall(pos, arg);
+    final SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, arg);
+    // Widen the sum so the division below is not an integer division.
+    final SqlNode sumAsDouble = CAST_HIGH_OP.createCall(pos, sum);
+    return SqlStdOperatorTable.DIVIDE.createCall(pos, sumAsDouble, count);
+  }
+
+  /**
+   * Expands the variance family. With x' = CastHigh(x):
+   * <pre>
+   * var_pop(x)     ==> (sum(x' * x') - sum(x') * sum(x') / count(x)) / count(x)
+   * var_samp(x)    ==> (sum(x' * x') - sum(x') * sum(x') / count(x)) / (count(x) - 1)
+   * stddev_pop(x)  ==> power(var_pop(x), .5)
+   * stddev_samp(x) ==> power(var_samp(x), .5)
+   * </pre>
+   *
+   * @param arg the aggregate's single operand
+   * @param biased true to divide by count(x) (population), false by count(x) - 1 (sample)
+   * @param sqrt true to take the square root (standard deviation), false for the variance
+   */
+  private SqlNode expandVariance(
+      final SqlNode arg,
+      boolean biased,
+      boolean sqrt) {
+    final SqlParserPos pos = SqlParserPos.ZERO;
+    // Widen the argument before any arithmetic. Widening only the final difference is too
+    // late: for integer input sum(x) * sum(x) / count(x) would be an integer division
+    // (var_pop over {0, 1} would give 0.5 instead of 0.25), and x * x would be computed
+    // in the integer type.
+    final SqlNode argAsDouble = CAST_HIGH_OP.createCall(pos, arg);
+    final SqlNode argSquared =
+        SqlStdOperatorTable.MULTIPLY.createCall(pos, argAsDouble, argAsDouble);
+    final SqlNode sumArgSquared = SqlStdOperatorTable.SUM.createCall(pos, argSquared);
+    final SqlNode sum = SqlStdOperatorTable.SUM.createCall(pos, argAsDouble);
+    final SqlNode sumSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, sum, sum);
+    final SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, arg);
+    final SqlNode avgSumSquared = SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquared, count);
+    final SqlNode diff = SqlStdOperatorTable.MINUS.createCall(pos, sumArgSquared, avgSumSquared);
+    final SqlNode denominator;
+    if (biased) {
+      denominator = count;
+    } else {
+      // NOTE(review): for a single-row group this is 0; standard SQL returns NULL for
+      // var_samp/stddev_samp there -- confirm what Drill's division produces.
+      final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos);
+      denominator = SqlStdOperatorTable.MINUS.createCall(pos, count, one);
+    }
+    // diff is already widened via argAsDouble; the cast keeps the final division non-integer.
+    final SqlNode diffAsDouble = CAST_HIGH_OP.createCall(pos, diff);
+    final SqlNode div = SqlStdOperatorTable.DIVIDE.createCall(pos, diffAsDouble, denominator);
+    if (!sqrt) {
+      return div;
+    }
+    final SqlNumericLiteral half = SqlLiteral.createExactNumeric("0.5", pos);
+    return SqlStdOperatorTable.POWER.createCall(pos, div, half);
+  }
+}

http://git-wip-us.apache.org/repos/asf/drill/blob/6ebfbb9d/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java
----------------------------------------------------------------------
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java
index 78404d4..8fddd14 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java
@@ -21,6 +21,7 @@ import java.util.HashMap;
 
 import org.apache.calcite.sql.SqlCall;
 import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.fun.SqlAvgAggFunction;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.sql2rel.SqlRexConvertlet;
 import org.apache.calcite.sql2rel.SqlRexConvertletTable;
@@ -33,6 +34,11 @@ public class DrillConvertletTable implements SqlRexConvertletTable{
   static {
     // Use custom convertlet for extract function
     map.put(SqlStdOperatorTable.EXTRACT, DrillExtractConvertlet.INSTANCE);
+    map.put(SqlStdOperatorTable.AVG, new DrillAvgVarianceConvertlet(SqlAvgAggFunction.Subtype.AVG)); // AVG/STDDEV/VAR: cast to double before dividing
+    map.put(SqlStdOperatorTable.STDDEV_POP, new DrillAvgVarianceConvertlet(SqlAvgAggFunction.Subtype.STDDEV_POP));
+    map.put(SqlStdOperatorTable.STDDEV_SAMP, new DrillAvgVarianceConvertlet(SqlAvgAggFunction.Subtype.STDDEV_SAMP));
+    map.put(SqlStdOperatorTable.VAR_POP, new DrillAvgVarianceConvertlet(SqlAvgAggFunction.Subtype.VAR_POP));
+    map.put(SqlStdOperatorTable.VAR_SAMP, new DrillAvgVarianceConvertlet(SqlAvgAggFunction.Subtype.VAR_SAMP));
   }
 
   /*

http://git-wip-us.apache.org/repos/asf/drill/blob/6ebfbb9d/exec/java-exec/src/test/java/org/apache/drill/exec/TestWindowFunctions.java
----------------------------------------------------------------------
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/TestWindowFunctions.java b/exec/java-exec/src/test/java/org/apache/drill/exec/TestWindowFunctions.java
index 9d660c3..1c8b0db 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/TestWindowFunctions.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/TestWindowFunctions.java
@@ -247,4 +247,32 @@ public class TestWindowFunctions extends BaseTestQuery {
     }
   }
 
+  /**
+   * Aggregates reduced during planning (e.g. avg(x) becomes sum(x) / count(x)) must
+   * still return DOUBLE values when used as window functions.
+   */
+  @Test
+  public void testAvgVarianceWindowFunctions() throws Exception {
+    // Both queries keep only n_nationkey = 1, so every value in the partition is 1.
+    final String avgQuery = "select avg(n_nationkey) over (partition by n_nationkey) col1 " +
+        "from cp.`tpch/nation.parquet` " +
+        "where n_nationkey = 1";
+    final String varianceQuery = "select var_pop(n_nationkey) over (partition by n_nationkey) col1 " +
+        "from cp.`tpch/nation.parquet` " +
+        "where n_nationkey = 1";
+
+    // The average of values that are all 1 is 1.0.
+    testBuilder()
+        .sqlQuery(avgQuery)
+        .unOrdered()
+        .baselineColumns("col1").baselineValues(1.0d)
+        .go();
+
+    // Their population variance is 0.0.
+    testBuilder()
+        .sqlQuery(varianceQuery)
+        .unOrdered()
+        .baselineColumns("col1").baselineValues(0.0d)
+        .go();
+  }
 }