You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@drill.apache.org by me...@apache.org on 2015/06/19 20:23:35 UTC
drill git commit: DRILL-3254: Fix wrong results while using certain
types of window functions that are rewritten using AvgVarianceConvertlet
Added a custom Drill Convertlet to correctly reduce these aggregates.
Repository: drill
Updated Branches:
refs/heads/master 710f82942 -> 6ebfbb9d0
DRILL-3254: Fix wrong results while using certain types of window functions that are rewritten using AvgVarianceConvertlet
Added a custom Drill Convertlet to correctly reduce these aggregates.
Project: http://git-wip-us.apache.org/repos/asf/drill/repo
Commit: http://git-wip-us.apache.org/repos/asf/drill/commit/6ebfbb9d
Tree: http://git-wip-us.apache.org/repos/asf/drill/tree/6ebfbb9d
Diff: http://git-wip-us.apache.org/repos/asf/drill/diff/6ebfbb9d
Branch: refs/heads/master
Commit: 6ebfbb9d0fc0b87b032f5e5d5cb0825f5464426e
Parents: 710f829
Author: Mehant Baid <me...@gmail.com>
Authored: Thu Jun 18 17:22:48 2015 -0700
Committer: Mehant Baid <me...@gmail.com>
Committed: Fri Jun 19 11:21:48 2015 -0700
----------------------------------------------------------------------
.../planner/sql/DrillAvgVarianceConvertlet.java | 156 +++++++++++++++++++
.../exec/planner/sql/DrillConvertletTable.java | 6 +
.../apache/drill/exec/TestWindowFunctions.java | 28 ++++
3 files changed, 190 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/drill/blob/6ebfbb9d/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillAvgVarianceConvertlet.java
----------------------------------------------------------------------
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillAvgVarianceConvertlet.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillAvgVarianceConvertlet.java
new file mode 100644
index 0000000..4c0618d
--- /dev/null
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillAvgVarianceConvertlet.java
@@ -0,0 +1,156 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.drill.exec.planner.sql;
+
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlLiteral;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlNumericLiteral;
+import org.apache.calcite.sql.fun.SqlAvgAggFunction;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.parser.SqlParserPos;
+import org.apache.calcite.sql2rel.SqlRexContext;
+import org.apache.calcite.sql2rel.SqlRexConvertlet;
+import org.apache.calcite.util.Util;
+
+/*
+ * This class is adapted from Calcite's AvgVarianceConvertlet. The difference is that
+ * we add a cast to double before we perform the division. We need a separate implementation
+ * from Calcite's code because, when injecting a similar cast, Calcite looks
+ * at the output type of the aggregate function, which is 'ANY' at that point, and
+ * injects a cast to 'ANY', which does not solve the problem.
+ */
+public class DrillAvgVarianceConvertlet implements SqlRexConvertlet {
+
+ private final SqlAvgAggFunction.Subtype subtype;
+ private static final DrillSqlOperator CastHighOp = new DrillSqlOperator("CastHigh", 1, false);
+
+ public DrillAvgVarianceConvertlet(SqlAvgAggFunction.Subtype subtype) {
+ this.subtype = subtype;
+ }
+
+ public RexNode convertCall(SqlRexContext cx, SqlCall call) {
+ assert call.operandCount() == 1;
+ final SqlNode arg = call.operand(0);
+ final SqlNode expr;
+ switch (subtype) {
+ case AVG:
+ expr = expandAvg(arg);
+ break;
+ case STDDEV_POP:
+ expr = expandVariance(arg, true, true);
+ break;
+ case STDDEV_SAMP:
+ expr = expandVariance(arg, false, true);
+ break;
+ case VAR_POP:
+ expr = expandVariance(arg, true, false);
+ break;
+ case VAR_SAMP:
+ expr = expandVariance(arg, false, false);
+ break;
+ default:
+ throw Util.unexpected(subtype);
+ }
+ RelDataType type =
+ cx.getValidator().getValidatedNodeType(call);
+ RexNode rex = cx.convertExpression(expr);
+ return cx.getRexBuilder().ensureType(type, rex, true);
+ }
+
+ private SqlNode expandAvg(
+ final SqlNode arg) {
+ final SqlParserPos pos = SqlParserPos.ZERO;
+ final SqlNode sum =
+ SqlStdOperatorTable.SUM.createCall(pos, arg);
+ final SqlNode count =
+ SqlStdOperatorTable.COUNT.createCall(pos, arg);
+ final SqlNode sumAsDouble =
+ CastHighOp.createCall(pos, sum);
+ return SqlStdOperatorTable.DIVIDE.createCall(
+ pos, sumAsDouble, count);
+ }
+
+ private SqlNode expandVariance(
+ final SqlNode arg,
+ boolean biased,
+ boolean sqrt) {
+ /* stddev_pop(x) ==>
+ * power(
+ * (sum(x * x) - sum(x) * sum(x) / count(x))
+ * / count(x),
+ * .5)
+
+ * stddev_samp(x) ==>
+ * power(
+ * (sum(x * x) - sum(x) * sum(x) / count(x))
+ * / (count(x) - 1),
+ * .5)
+
+ * var_pop(x) ==>
+ * (sum(x * x) - sum(x) * sum(x) / count(x))
+ * / count(x)
+
+ * var_samp(x) ==>
+ * (sum(x * x) - sum(x) * sum(x) / count(x))
+ * / (count(x) - 1)
+ */
+ final SqlParserPos pos = SqlParserPos.ZERO;
+ final SqlNode argSquared =
+ SqlStdOperatorTable.MULTIPLY.createCall(pos, arg, arg);
+ final SqlNode sumArgSquared =
+ SqlStdOperatorTable.SUM.createCall(pos, argSquared);
+ final SqlNode sum =
+ SqlStdOperatorTable.SUM.createCall(pos, arg);
+ final SqlNode sumSquared =
+ SqlStdOperatorTable.MULTIPLY.createCall(pos, sum, sum);
+ final SqlNode count =
+ SqlStdOperatorTable.COUNT.createCall(pos, arg);
+ final SqlNode avgSumSquared =
+ SqlStdOperatorTable.DIVIDE.createCall(
+ pos, sumSquared, count);
+ final SqlNode diff =
+ SqlStdOperatorTable.MINUS.createCall(
+ pos, sumArgSquared, avgSumSquared);
+ final SqlNode denominator;
+ if (biased) {
+ denominator = count;
+ } else {
+ final SqlNumericLiteral one =
+ SqlLiteral.createExactNumeric("1", pos);
+ denominator =
+ SqlStdOperatorTable.MINUS.createCall(
+ pos, count, one);
+ }
+ final SqlNode diffAsDouble =
+ CastHighOp.createCall(pos, diff);
+ final SqlNode div =
+ SqlStdOperatorTable.DIVIDE.createCall(
+ pos, diffAsDouble, denominator);
+ SqlNode result = div;
+ if (sqrt) {
+ final SqlNumericLiteral half =
+ SqlLiteral.createExactNumeric("0.5", pos);
+ result =
+ SqlStdOperatorTable.POWER.createCall(pos, div, half);
+ }
+ return result;
+ }
+}
http://git-wip-us.apache.org/repos/asf/drill/blob/6ebfbb9d/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java
----------------------------------------------------------------------
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java
index 78404d4..8fddd14 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java
@@ -21,6 +21,7 @@ import java.util.HashMap;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.fun.SqlAvgAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql2rel.SqlRexConvertlet;
import org.apache.calcite.sql2rel.SqlRexConvertletTable;
@@ -33,6 +34,11 @@ public class DrillConvertletTable implements SqlRexConvertletTable{
static {
// Use custom convertlet for extract function
map.put(SqlStdOperatorTable.EXTRACT, DrillExtractConvertlet.INSTANCE);
+ map.put(SqlStdOperatorTable.AVG, new DrillAvgVarianceConvertlet(SqlAvgAggFunction.Subtype.AVG));
+ map.put(SqlStdOperatorTable.STDDEV_POP, new DrillAvgVarianceConvertlet(SqlAvgAggFunction.Subtype.STDDEV_POP));
+ map.put(SqlStdOperatorTable.STDDEV_SAMP, new DrillAvgVarianceConvertlet(SqlAvgAggFunction.Subtype.STDDEV_SAMP));
+ map.put(SqlStdOperatorTable.VAR_POP, new DrillAvgVarianceConvertlet(SqlAvgAggFunction.Subtype.VAR_POP));
+ map.put(SqlStdOperatorTable.VAR_SAMP, new DrillAvgVarianceConvertlet(SqlAvgAggFunction.Subtype.VAR_SAMP));
}
/*
http://git-wip-us.apache.org/repos/asf/drill/blob/6ebfbb9d/exec/java-exec/src/test/java/org/apache/drill/exec/TestWindowFunctions.java
----------------------------------------------------------------------
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/TestWindowFunctions.java b/exec/java-exec/src/test/java/org/apache/drill/exec/TestWindowFunctions.java
index 9d660c3..1c8b0db 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/TestWindowFunctions.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/TestWindowFunctions.java
@@ -247,4 +247,32 @@ public class TestWindowFunctions extends BaseTestQuery {
}
}
+ /* Verify that aggregate functions which are reduced
+ * (e.g. avg(x) = sum(x)/count(x)) return results of the correct
+ * data type (double).
+ */
+ @Test
+ public void testAvgVarianceWindowFunctions() throws Exception {
+ final String avgQuery = "select avg(n_nationkey) over (partition by n_nationkey) col1 " +
+ "from cp.`tpch/nation.parquet` " +
+ "where n_nationkey = 1";
+
+ testBuilder()
+ .sqlQuery(avgQuery)
+ .unOrdered()
+ .baselineColumns("col1")
+ .baselineValues(1.0d)
+ .go();
+
+ final String varianceQuery = "select var_pop(n_nationkey) over (partition by n_nationkey) col1 " +
+ "from cp.`tpch/nation.parquet` " +
+ "where n_nationkey = 1";
+
+ testBuilder()
+ .sqlQuery(varianceQuery)
+ .unOrdered()
+ .baselineColumns("col1")
+ .baselineValues(0.0d)
+ .go();
+ }
}