You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@impala.apache.org by kw...@apache.org on 2017/04/26 23:22:21 UTC

[09/10] incubator-impala git commit: IMPALA-5251: Fix propagation of input exprs' types in 2-phase agg

IMPALA-5251: Fix propagation of input exprs' types in 2-phase agg

Since commit d2d3f4c (on asf-master), TAggregateExpr contains
the logical input types of the Aggregate Expr. The reason they
are included is that merging aggregate expressions will have
input types of the intermediate values which aren't necessarily
the same as the input types. For instance, NDV() uses a binary
blob as its intermediate value and it's passed to its merge
aggregate expressions as a StringVal but the input type of NDV()
in the query could be DecimalVal. In this case, we consider
DecimalVal as the logical input type while StringVal is the
intermediate type. The logical input types are accessed by the
BE via GetConstFnAttr() during interpretation and constant
propagation during codegen.

To handle distinct aggregate expressions (e.g. select count(distinct)),
the FE uses 2-phase aggregation by introducing an extra phase of
split/merge aggregation in which the distinct aggregate expressions'
inputs are converted and added to the group-by expressions in the first
phase while the non-distinct aggregate expressions go through the normal
split/merge treatment.

The bug is that the existing code incorrectly propagates the intermediate
types of the non-grouping aggregate expressions as the logical input types
to the merging aggregate expressions in the second phase of aggregation.
The input aggregate expressions for the non-distinct aggregate expressions
in the second phase aggregation are already merging aggregate expressions
(from phase one) in which case we should not treat its input types as
logical input types.

This change fixes the problem above by checking if the input aggregate
expression passed to FunctionCallExpr.createMergeAggCall() is already
a merging aggregate expression. If so, it will use the logical input
types recorded in its 'mergeAggInputFn_' as references for its logical
input types instead of the aggregate expression input types themselves.

Change-Id: I158303b20d1afdff23c67f3338b9c4af2ad80691
Reviewed-on: http://gerrit.cloudera.org:8080/6724
Reviewed-by: Alex Behm <al...@cloudera.com>
Tested-by: Impala Public Jenkins


Project: http://git-wip-us.apache.org/repos/asf/incubator-impala/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-impala/commit/42ca45e8
Tree: http://git-wip-us.apache.org/repos/asf/incubator-impala/tree/42ca45e8
Diff: http://git-wip-us.apache.org/repos/asf/incubator-impala/diff/42ca45e8

Branch: refs/heads/master
Commit: 42ca45e8307ba4c831ad7ac8da86bbbd957fe4cd
Parents: e78d71e
Author: Michael Ho <kw...@cloudera.com>
Authored: Tue Apr 25 00:10:08 2017 -0700
Committer: Impala Public Jenkins <im...@gerrit.cloudera.org>
Committed: Wed Apr 26 21:40:32 2017 +0000

----------------------------------------------------------------------
 be/src/testutil/test-udas.cc                    | 93 ++++++++++----------
 .../impala/analysis/FunctionCallExpr.java       | 13 ++-
 .../queries/PlannerTest/aggregation.test        | 40 +++++++++
 .../functional-query/queries/QueryTest/uda.test | 52 +++++++++++
 tests/query_test/test_udfs.py                   |  5 ++
 5 files changed, 152 insertions(+), 51 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/42ca45e8/be/src/testutil/test-udas.cc
----------------------------------------------------------------------
diff --git a/be/src/testutil/test-udas.cc b/be/src/testutil/test-udas.cc
index 549f2f0..806a971 100644
--- a/be/src/testutil/test-udas.cc
+++ b/be/src/testutil/test-udas.cc
@@ -57,36 +57,30 @@ StringVal AggFinalize(FunctionContext*, const StringVal& v) {
 
 // Defines AggIntermediate(int) returns BIGINT intermediate STRING
 void AggIntermediate(FunctionContext* context, const IntVal&, StringVal*) {}
-void AggIntermediateUpdate(FunctionContext* context, const IntVal&, StringVal*) {
+static void ValidateFunctionContext(const FunctionContext* context) {
   assert(context->GetNumArgs() == 1);
   assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT);
   assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING);
   assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
 }
+void AggIntermediateUpdate(FunctionContext* context, const IntVal&, StringVal*) {
+  ValidateFunctionContext(context);
+}
 void AggIntermediateInit(FunctionContext* context, StringVal*) {
-  assert(context->GetNumArgs() == 1);
-  assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT);
-  assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING);
-  assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
+  ValidateFunctionContext(context);
 }
 void AggIntermediateMerge(FunctionContext* context, const StringVal&, StringVal*) {
-  assert(context->GetNumArgs() == 1);
-  assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT);
-  assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING);
-  assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
+  ValidateFunctionContext(context);
 }
 BigIntVal AggIntermediateFinalize(FunctionContext* context, const StringVal&) {
-  assert(context->GetNumArgs() == 1);
-  assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT);
-  assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING);
-  assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
+  ValidateFunctionContext(context);
   return BigIntVal::null();
 }
 
 // Defines AggDecimalIntermediate(DECIMAL(1,2), INT) returns DECIMAL(5,6)
 // intermediate DECIMAL(3,4)
 // Useful to test that type parameters are plumbed through.
-void AggDecimalIntermediateUpdate(FunctionContext* context, const DecimalVal&, const IntVal&, DecimalVal*) {
+static void ValidateFunctionContext2(const FunctionContext* context) {
   assert(context->GetNumArgs() == 2);
   assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL);
   assert(context->GetArgType(0)->precision == 2);
@@ -99,45 +93,50 @@ void AggDecimalIntermediateUpdate(FunctionContext* context, const DecimalVal&, c
   assert(context->GetReturnType().precision == 6);
   assert(context->GetReturnType().scale == 5);
 }
+void AggDecimalIntermediateUpdate(FunctionContext* context, const DecimalVal&,
+    const IntVal&, DecimalVal*) {
+  ValidateFunctionContext2(context);
+}
 void AggDecimalIntermediateInit(FunctionContext* context, DecimalVal*) {
-  assert(context->GetNumArgs() == 2);
-  assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL);
-  assert(context->GetArgType(0)->precision == 2);
-  assert(context->GetArgType(0)->scale == 1);
-  assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT);
-  assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL);
-  assert(context->GetIntermediateType().precision == 4);
-  assert(context->GetIntermediateType().scale == 3);
-  assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL);
-  assert(context->GetReturnType().precision == 6);
-  assert(context->GetReturnType().scale == 5);
+  ValidateFunctionContext2(context);
 }
-void AggDecimalIntermediateMerge(FunctionContext* context, const DecimalVal&, DecimalVal*) {
-  assert(context->GetNumArgs() == 2);
-  assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL);
-  assert(context->GetArgType(0)->precision == 2);
-  assert(context->GetArgType(0)->scale == 1);
-  assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT);
-  assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL);
-  assert(context->GetIntermediateType().precision == 4);
-  assert(context->GetIntermediateType().scale == 3);
-  assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL);
-  assert(context->GetReturnType().precision == 6);
-  assert(context->GetReturnType().scale == 5);
+void AggDecimalIntermediateMerge(FunctionContext* context, const DecimalVal&,
+    DecimalVal*) {
+  ValidateFunctionContext2(context);
 }
 DecimalVal AggDecimalIntermediateFinalize(FunctionContext* context, const DecimalVal&) {
-  assert(context->GetNumArgs() == 2);
+  ValidateFunctionContext2(context);
+  return DecimalVal::null();
+}
+
+// Defines AggStringIntermediate(DECIMAL(20,10), BIGINT, STRING) returns DECIMAL(20,0)
+// intermediate STRING.
+// Useful to test decimal input types with string as intermediate types.
+static void ValidateFunctionContext3(const FunctionContext* context) {
+  assert(context->GetNumArgs() == 3);
   assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL);
-  assert(context->GetArgType(0)->precision == 2);
-  assert(context->GetArgType(0)->scale == 1);
-  assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT);
-  assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL);
-  assert(context->GetIntermediateType().precision == 4);
-  assert(context->GetIntermediateType().scale == 3);
+  assert(context->GetArgType(0)->precision == 20);
+  assert(context->GetArgType(0)->scale == 10);
+  assert(context->GetArgType(1)->type == FunctionContext::TYPE_BIGINT);
+  assert(context->GetArgType(2)->type == FunctionContext::TYPE_STRING);
+  assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING);
   assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL);
-  assert(context->GetReturnType().precision == 6);
-  assert(context->GetReturnType().scale == 5);
-  return DecimalVal::null();
+  assert(context->GetReturnType().precision == 20);
+  assert(context->GetReturnType().scale == 0);
+}
+void AggStringIntermediateUpdate(FunctionContext* context, const DecimalVal&,
+    const BigIntVal&, const StringVal&, StringVal*) {
+  ValidateFunctionContext3(context);
+}
+void AggStringIntermediateInit(FunctionContext* context, StringVal*) {
+  ValidateFunctionContext3(context);
+}
+void AggStringIntermediateMerge(FunctionContext* context, const StringVal&, StringVal*) {
+  ValidateFunctionContext3(context);
+}
+DecimalVal AggStringIntermediateFinalize(FunctionContext* context, const StringVal&) {
+  ValidateFunctionContext3(context);
+  return DecimalVal(100);
 }
 
 // Defines MemTest(bigint) return bigint

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/42ca45e8/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java
----------------------------------------------------------------------
diff --git a/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java b/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java
index c9d098d..1e06254 100644
--- a/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java
+++ b/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java
@@ -102,8 +102,12 @@ public class FunctionCallExpr extends Expr {
       FunctionCallExpr agg, List<Expr> params) {
     Preconditions.checkState(agg.isAnalyzed());
     Preconditions.checkState(agg.isAggregateFunction());
+    // If the input aggregate function is already a merge aggregate function (due to
+    // 2-phase aggregation), its input types will be the intermediate value types. The
+    // original input argument exprs are in 'agg.mergeAggInputFn_' so use it instead.
+    FunctionCallExpr mergeAggInputFn = agg.isMergeAggFn() ? agg.mergeAggInputFn_ : agg;
     FunctionCallExpr result = new FunctionCallExpr(
-        agg.fnName_, new FunctionParams(false, params), agg);
+        agg.fnName_, new FunctionParams(false, params), mergeAggInputFn);
     // Inherit the function object from 'agg'.
     result.fn_ = agg.fn_;
     result.type_ = agg.type_;
@@ -127,8 +131,8 @@ public class FunctionCallExpr extends Expr {
     fnName_ = other.fnName_;
     isAnalyticFnCall_ = other.isAnalyticFnCall_;
     isInternalFnCall_ = other.isInternalFnCall_;
-    mergeAggInputFn_ =
-        other.mergeAggInputFn_ == null ? null : (FunctionCallExpr)other.mergeAggInputFn_.clone();
+    mergeAggInputFn_ = other.mergeAggInputFn_ == null ?
+        null : (FunctionCallExpr)other.mergeAggInputFn_.clone();
     // Clone the params in a way that keeps the children_ and the params.exprs()
     // in sync. The children have already been cloned in the super c'tor.
     if (other.params_.isStar()) {
@@ -574,7 +578,8 @@ public class FunctionCallExpr extends Expr {
   void validateMergeAggFn(FunctionCallExpr inputAggFn) {
     Preconditions.checkState(isMergeAggFn());
     List<Expr> copiedInputExprs = mergeAggInputFn_.getChildren();
-    List<Expr> inputExprs = inputAggFn.getChildren();
+    List<Expr> inputExprs = inputAggFn.isMergeAggFn() ?
+        inputAggFn.mergeAggInputFn_.getChildren() : inputAggFn.getChildren();
     Preconditions.checkState(copiedInputExprs.size() == inputExprs.size());
     for (int i = 0; i < inputExprs.size(); ++i) {
       Type copiedInputType = copiedInputExprs.get(i).getType();

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/42ca45e8/testdata/workloads/functional-planner/queries/PlannerTest/aggregation.test
----------------------------------------------------------------------
diff --git a/testdata/workloads/functional-planner/queries/PlannerTest/aggregation.test b/testdata/workloads/functional-planner/queries/PlannerTest/aggregation.test
index a1177b0..b5c3970 100644
--- a/testdata/workloads/functional-planner/queries/PlannerTest/aggregation.test
+++ b/testdata/workloads/functional-planner/queries/PlannerTest/aggregation.test
@@ -554,6 +554,46 @@ PLAN-ROOT SINK
 01:SCAN HDFS [functional.alltypes]
    partitions=24/24 files=24 size=478.45KB
 ====
+# Mixed distinct and non-distinct agg with intermediate type different from input type
+# Regression test for IMPALA-5251 to exercise validateMergeAggFn() in FunctionCallExpr.
+select avg(l_quantity), ndv(l_discount), count(distinct l_partkey)
+from tpch_parquet.lineitem;
+---- PLAN
+PLAN-ROOT SINK
+|
+02:AGGREGATE [FINALIZE]
+|  output: count(l_partkey), avg:merge(l_quantity), ndv:merge(l_discount)
+|
+01:AGGREGATE
+|  output: avg(l_quantity), ndv(l_discount)
+|  group by: l_partkey
+|
+00:SCAN HDFS [tpch_parquet.lineitem]
+   partitions=1/1 files=3 size=193.74MB
+---- DISTRIBUTEDPLAN
+PLAN-ROOT SINK
+|
+06:AGGREGATE [FINALIZE]
+|  output: count:merge(l_partkey), avg:merge(l_quantity), ndv:merge(l_discount)
+|
+05:EXCHANGE [UNPARTITIONED]
+|
+02:AGGREGATE
+|  output: count(l_partkey), avg:merge(l_quantity), ndv:merge(l_discount)
+|
+04:AGGREGATE
+|  output: avg:merge(l_quantity), ndv:merge(l_discount)
+|  group by: l_partkey
+|
+03:EXCHANGE [HASH(l_partkey)]
+|
+01:AGGREGATE [STREAMING]
+|  output: avg(l_quantity), ndv(l_discount)
+|  group by: l_partkey
+|
+00:SCAN HDFS [tpch_parquet.lineitem]
+   partitions=1/1 files=3 size=193.74MB
+====
 # test that aggregations are not placed below an unpartitioned exchange with a limit
 select count(*) from (select * from functional.alltypes limit 10) t
 ---- PLAN

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/42ca45e8/testdata/workloads/functional-query/queries/QueryTest/uda.test
----------------------------------------------------------------------
diff --git a/testdata/workloads/functional-query/queries/QueryTest/uda.test b/testdata/workloads/functional-query/queries/QueryTest/uda.test
index 3a9bbbe..932b94a 100644
--- a/testdata/workloads/functional-query/queries/QueryTest/uda.test
+++ b/testdata/workloads/functional-query/queries/QueryTest/uda.test
@@ -88,3 +88,55 @@ from functional.decimal_tbl
 NULL,5
 ---- TYPES
 decimal,bigint
+====
+---- QUERY
+# Test that all types are exposed via the FunctionContext correctly.
+# This includes distinct aggregate expression to test IMPALA-5251.
+# It also relies on asserts in the UDA funciton.
+select
+   agg_string_intermediate(cast(c1 as decimal(20,10)), 1000, "foobar"),
+   agg_decimal_intermediate(cast(c3 as decimal(2,1)), 2),
+   agg_intermediate(int_col),
+   avg(c2),
+   min(c3-c1),
+   max(c1+c3),
+   count(distinct int_col),
+   sum(distinct int_col)
+from
+   functional.alltypesagg,
+   functional.decimal_tiny
+---- RESULTS
+100,NULL,NULL,160.49989,-10.0989,11.8989,999,499500
+---- TYPES
+decimal,decimal,bigint,decimal,decimal,decimal,bigint,bigint
+====
+---- QUERY
+# Test that all types are exposed via the FunctionContext correctly.
+# This includes distinct aggregate expression to test IMPALA-5251.
+# It also relies on asserts in the UDA funciton.
+select
+   agg_string_intermediate(cast(c1 as decimal(20,10)), 1000, "foobar"),
+   agg_decimal_intermediate(cast(c3 as decimal(2,1)), 2),
+   agg_intermediate(int_col),
+   ndv(c2),
+   sum(distinct c1)/count(distinct c1)
+from
+   functional.alltypesagg,
+   functional.decimal_tiny
+group by
+   year,month,day
+---- RESULTS
+100,NULL,NULL,99,5.4994
+100,NULL,NULL,99,5.4994
+100,NULL,NULL,99,5.4994
+100,NULL,NULL,99,5.4994
+100,NULL,NULL,99,5.4994
+100,NULL,NULL,99,5.4994
+100,NULL,NULL,99,5.4994
+100,NULL,NULL,99,5.4994
+100,NULL,NULL,99,5.4994
+100,NULL,NULL,99,5.4994
+100,NULL,NULL,99,5.4994
+---- TYPES
+decimal,decimal,bigint,bigint,decimal
+====
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/42ca45e8/tests/query_test/test_udfs.py
----------------------------------------------------------------------
diff --git a/tests/query_test/test_udfs.py b/tests/query_test/test_udfs.py
index 56ce233..ec24c9f 100644
--- a/tests/query_test/test_udfs.py
+++ b/tests/query_test/test_udfs.py
@@ -103,6 +103,11 @@ create aggregate function {database}.agg_decimal_intermediate(decimal(2,1), int)
 returns decimal(6,5) intermediate decimal(4,3) location '{location}'
 init_fn='AggDecimalIntermediateInit' update_fn='AggDecimalIntermediateUpdate'
 merge_fn='AggDecimalIntermediateMerge' finalize_fn='AggDecimalIntermediateFinalize';
+
+create aggregate function {database}.agg_string_intermediate(decimal(20,10), bigint, string)
+returns decimal(20,0) intermediate string location '{location}'
+init_fn='AggStringIntermediateInit' update_fn='AggStringIntermediateUpdate'
+merge_fn='AggStringIntermediateMerge' finalize_fn='AggStringIntermediateFinalize';
 """
 
   # Create test UDF functions in {database} from library {location}