You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by zh...@apache.org on 2019/10/10 14:43:51 UTC

[incubator-doris] branch master updated: Enhance the speed of avg function (#1889)

This is an automated email from the ASF dual-hosted git repository.

zhaoc pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git


The following commit(s) were added to refs/heads/master by this push:
     new e267d03  Enhance the speed of avg function (#1889)
e267d03 is described below

commit e267d031bb1402cf7fa32b8196553fd9b1266cb3
Author: EmmyMiao87 <52...@qq.com>
AuthorDate: Thu Oct 10 22:43:46 2019 +0800

    Enhance the speed of avg function (#1889)
    
    This commit enables the native avg aggregate in FE instead of rewriting avg into sum/count.
    It also fixes a bug in the decimalv2 avg function that caused a core dump in BE:
    the int128 sum could not be assigned directly because its memory must be 16-byte aligned.
    
    After this change, avg performs about as fast as sum.
---
 be/src/exprs/aggregate_functions.cpp               | 129 +++++++++++----------
 be/src/exprs/aggregate_functions.h                 |   2 +
 .../apache/doris/analysis/FunctionCallExpr.java    |   5 +-
 .../java/org/apache/doris/analysis/SelectStmt.java |  43 +------
 .../java/org/apache/doris/catalog/FunctionSet.java |   2 +-
 5 files changed, 75 insertions(+), 106 deletions(-)

diff --git a/be/src/exprs/aggregate_functions.cpp b/be/src/exprs/aggregate_functions.cpp
index 0a513f3..75de6da 100644
--- a/be/src/exprs/aggregate_functions.cpp
+++ b/be/src/exprs/aggregate_functions.cpp
@@ -177,61 +177,6 @@ void AggregateFunctions::count_remove(
     }
 }
 
-struct AvgState {
-    double sum;
-    int64_t count;
-};
-
-struct DecimalAvgState {
-    DecimalVal sum;
-    int64_t count;
-};
-
-struct DecimalV2AvgState {
-    DecimalV2Val sum;
-    int64_t count;
-};
-
-void AggregateFunctions::avg_init(FunctionContext* ctx, StringVal* dst) {
-    dst->is_null = false;
-    dst->len = sizeof(AvgState);
-    dst->ptr = ctx->allocate(dst->len);
-    memset(dst->ptr, 0, sizeof(AvgState));
-}
-
-void AggregateFunctions::decimal_avg_init(FunctionContext* ctx, StringVal* dst) {
-    dst->is_null = false;
-    dst->len = sizeof(DecimalAvgState);
-    dst->ptr = ctx->allocate(dst->len);
-    // memset(dst->ptr, 0, sizeof(DecimalAvgState));
-    DecimalAvgState* avg = reinterpret_cast<DecimalAvgState*>(dst->ptr);
-    avg->count = 0;
-    avg->sum.set_to_zero();
-}
-
-void AggregateFunctions::decimalv2_avg_init(FunctionContext* ctx, StringVal* dst) {
-    dst->is_null = false;
-    dst->len = sizeof(DecimalV2AvgState);
-    dst->ptr = ctx->allocate(dst->len);
-    // memset(dst->ptr, 0, sizeof(DecimalAvgState));
-    DecimalV2AvgState* avg = reinterpret_cast<DecimalV2AvgState*>(dst->ptr);
-    avg->count = 0;
-    avg->sum.set_to_zero();
-}
-
-
-template <typename T>
-void AggregateFunctions::avg_update(FunctionContext* ctx, const T& src, StringVal* dst) {
-    if (src.is_null) {
-        return;
-    }
-    DCHECK(dst->ptr != NULL);
-    DCHECK_EQ(sizeof(AvgState), dst->len);
-    AvgState* avg = reinterpret_cast<AvgState*>(dst->ptr);
-    avg->sum += src.val;
-    ++avg->count;
-}
-
 struct PercentileApproxState {
 public:
     PercentileApproxState() : digest(new TDigest()){
@@ -305,6 +250,59 @@ DoubleVal AggregateFunctions::percentile_approx_finalize(FunctionContext* ctx, c
     return DoubleVal(result);
 }
 
+struct AvgState {
+    double sum = 0;
+    int64_t count = 0;
+};
+
+struct DecimalAvgState {
+    DecimalVal sum;
+    int64_t count;
+};
+
+struct DecimalV2AvgState {
+    DecimalV2Val sum;
+    int64_t count = 0;
+};
+
+void AggregateFunctions::avg_init(FunctionContext* ctx, StringVal* dst) {
+    dst->is_null = false;
+    dst->len = sizeof(AvgState);
+    dst->ptr = ctx->allocate(dst->len);
+    new (dst->ptr) AvgState;
+}
+
+void AggregateFunctions::decimal_avg_init(FunctionContext* ctx, StringVal* dst) {
+    dst->is_null = false;
+    dst->len = sizeof(DecimalAvgState);
+    dst->ptr = ctx->allocate(dst->len);
+    // memset(dst->ptr, 0, sizeof(DecimalAvgState));
+    DecimalAvgState* avg = reinterpret_cast<DecimalAvgState*>(dst->ptr);
+    avg->count = 0;
+    avg->sum.set_to_zero();
+}
+
+void AggregateFunctions::decimalv2_avg_init(FunctionContext* ctx, StringVal* dst) {
+    dst->is_null = false;
+    dst->len = sizeof(DecimalV2AvgState);
+    // The memroy for int128 need to be aligned by 16.
+    // So the constructor has been used instead of allocating memory.
+    // Also, it will be release in finalize.
+    dst->ptr = (uint8_t*) new DecimalV2AvgState;
+}
+
+template <typename T>
+void AggregateFunctions::avg_update(FunctionContext* ctx, const T& src, StringVal* dst) {
+    if (src.is_null) {
+        return;
+    }
+    DCHECK(dst->ptr != NULL);
+    DCHECK_EQ(sizeof(AvgState), dst->len);
+    AvgState* avg = reinterpret_cast<AvgState*>(dst->ptr);
+    avg->sum += src.val;
+    ++avg->count;
+}
+
 void AggregateFunctions::decimal_avg_update(FunctionContext* ctx,
         const DecimalVal& src,
         StringVal* dst) {
@@ -341,6 +339,15 @@ void AggregateFunctions::decimalv2_avg_update(FunctionContext* ctx,
     ++avg->count;
 }
 
+StringVal AggregateFunctions::decimalv2_avg_serialize(
+        FunctionContext* ctx, const StringVal& src) {
+    DCHECK(!src.is_null);
+    StringVal result(ctx, src.len);
+    memcpy(result.ptr, src.ptr, src.len);
+    delete (DecimalV2AvgState*)src.ptr;
+    return result;
+}
+
 template <typename T>
 void AggregateFunctions::avg_remove(FunctionContext* ctx, const T& src, StringVal* dst) {
     // Remove doesn't need to explicitly check the number of calls to Update() or Remove()
@@ -424,16 +431,17 @@ void AggregateFunctions::decimal_avg_merge(FunctionContext* ctx, const StringVal
 
 void AggregateFunctions::decimalv2_avg_merge(FunctionContext* ctx, const StringVal& src,
         StringVal* dst) {
-    const DecimalV2AvgState* src_struct = reinterpret_cast<const DecimalV2AvgState*>(src.ptr);
+    DecimalV2AvgState src_struct;
+    memcpy(&src_struct, src.ptr, sizeof(DecimalV2AvgState));
     DCHECK(dst->ptr != NULL);
     DCHECK_EQ(sizeof(DecimalV2AvgState), dst->len);
     DecimalV2AvgState* dst_struct = reinterpret_cast<DecimalV2AvgState*>(dst->ptr);
 
     DecimalV2Value v1 = DecimalV2Value::from_decimal_val(dst_struct->sum);
-    DecimalV2Value v2 = DecimalV2Value::from_decimal_val(src_struct->sum);
+    DecimalV2Value v2 = DecimalV2Value::from_decimal_val(src_struct.sum);
     DecimalV2Value v = v1 + v2;
     v.to_decimal_val(&dst_struct->sum);
-    dst_struct->count += src_struct->count;
+    dst_struct->count += src_struct.count;
 }
 
 DoubleVal AggregateFunctions::avg_get_value(FunctionContext* ctx, const StringVal& src) {
@@ -489,11 +497,8 @@ DecimalVal AggregateFunctions::decimal_avg_finalize(FunctionContext* ctx, const
 }
 
 DecimalV2Val AggregateFunctions::decimalv2_avg_finalize(FunctionContext* ctx, const StringVal& src) {
-    if (src.is_null) {
-        return DecimalV2Val::null();
-    }
     DecimalV2Val result = decimalv2_avg_get_value(ctx, src);
-    ctx->free(src.ptr);
+    delete (DecimalV2AvgState*)src.ptr;
     return result;
 }
 
diff --git a/be/src/exprs/aggregate_functions.h b/be/src/exprs/aggregate_functions.h
index 23b4372..0e49334 100644
--- a/be/src/exprs/aggregate_functions.h
+++ b/be/src/exprs/aggregate_functions.h
@@ -118,6 +118,8 @@ dst);
             doris_udf::StringVal* dst);
     static void decimalv2_avg_merge(FunctionContext* ctx, const doris_udf::StringVal& src,
             doris_udf::StringVal* dst);
+    static doris_udf::StringVal decimalv2_avg_serialize(doris_udf::FunctionContext* ctx,
+         const doris_udf::StringVal& src);
     static void decimal_avg_remove(doris_udf::FunctionContext* ctx,
             const doris_udf::DecimalVal& src,
             doris_udf::StringVal* dst);
diff --git a/fe/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
index f67bb78..dbdc79c 100644
--- a/fe/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
+++ b/fe/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
@@ -358,9 +358,10 @@ public class FunctionCallExpr extends Expr {
         }
 
         // SUM and AVG cannot be applied to non-numeric types
-        if (fnName.getFunction().equalsIgnoreCase("sum")
+        if ((fnName.getFunction().equalsIgnoreCase("sum")
+                || fnName.getFunction().equalsIgnoreCase("avg"))
                 && ((!arg.type.isNumericType() && !arg.type.isNull()) || arg.type.isHllType())) {
-            throw new AnalysisException("SUM requires a numeric parameter: " + this.toSql());
+            throw new AnalysisException(fnName.getFunction() + " requires a numeric parameter: " + this.toSql());
         }
         if (fnName.getFunction().equalsIgnoreCase("sum_distinct")
                 && ((!arg.type.isNumericType() && !arg.type.isNull()) || arg.type.isHllType())) {
diff --git a/fe/src/main/java/org/apache/doris/analysis/SelectStmt.java b/fe/src/main/java/org/apache/doris/analysis/SelectStmt.java
index 1a1af29..ec8f6f8 100644
--- a/fe/src/main/java/org/apache/doris/analysis/SelectStmt.java
+++ b/fe/src/main/java/org/apache/doris/analysis/SelectStmt.java
@@ -742,8 +742,8 @@ public class SelectStmt extends QueryStmt {
 
     /**
      * Analyze aggregation-relevant components of the select block (Group By clause,
-     * select list, Order By clause), substitute AVG with SUM/COUNT, create the
-     * AggregationInfo, including the agg output tuple, and transform all post-agg exprs
+     * select list, Order By clause),
+     * Create the AggregationInfo, including the agg output tuple, and transform all post-agg exprs
      * given AggregationInfo's smap.
      */
     private void analyzeAggregation(Analyzer analyzer) throws AnalysisException {
@@ -867,12 +867,6 @@ public class SelectStmt extends QueryStmt {
             TreeNode.collect(sortInfo.getOrderingExprs(), Expr.isAggregatePredicate(), aggExprs);
         }
 
-        // substitute AVG before constructing AggregateInfo
-        ExprSubstitutionMap avgSMap = createAvgSMap(aggExprs, analyzer);
-
-        // Optionally rewrite all count(distinct <expr>) into equivalent NDV() calls.
-        ExprSubstitutionMap ndvSmap = avgSMap;
-
         // When DISTINCT aggregates are present, non-distinct (i.e. ALL) aggregates are
         // evaluated in two phases (see AggregateInfo for more details). In particular,
         // COUNT(c) in "SELECT COUNT(c), AGG(DISTINCT d) from R" is transformed to
@@ -884,7 +878,6 @@ public class SelectStmt extends QueryStmt {
         // i) There is no GROUP-BY clause, and
         // ii) Other DISTINCT aggregates are present.
         ExprSubstitutionMap countAllMap = createCountAllMap(aggExprs, analyzer);
-        countAllMap = ExprSubstitutionMap.compose(ndvSmap, countAllMap, analyzer);
         final ExprSubstitutionMap multiCountOrSumDistinctMap = 
                 createSumOrCountMultiDistinctSMap(aggExprs, analyzer);
         countAllMap = ExprSubstitutionMap.compose(multiCountOrSumDistinctMap, countAllMap, analyzer);
@@ -967,38 +960,6 @@ public class SelectStmt extends QueryStmt {
         }
     }
 
-    /**
-     * Build smap AVG -> SUM/COUNT;
-     * assumes that select list and having clause have been analyzed.
-     */
-    private ExprSubstitutionMap createAvgSMap(
-            ArrayList<FunctionCallExpr> aggExprs, Analyzer analyzer) throws AnalysisException {
-        ExprSubstitutionMap result = new ExprSubstitutionMap();
-        for (FunctionCallExpr aggExpr : aggExprs) {
-            if (!aggExpr.getFnName().getFunction().equalsIgnoreCase("AVG")) {
-                continue;
-            }
-            // Transform avg(TIMESTAMP) to cast(avg(cast(TIMESTAMP as DOUBLE)) as TIMESTAMP)
-            CastExpr inCastExpr = null;
-
-            List<Expr> sumInputExprs = Lists.newArrayList(aggExpr.getChild(0).clone(null));
-            List<Expr> countInputExpr = Lists.newArrayList(aggExpr.getChild(0).clone(null));
-
-            FunctionCallExpr sumExpr = new FunctionCallExpr("sum",
-              new FunctionParams(aggExpr.isDistinct(), sumInputExprs));
-            FunctionCallExpr countExpr =
-                    new FunctionCallExpr("count",
-                new FunctionParams(aggExpr.isDistinct(), countInputExpr));
-            ArithmeticExpr divExpr =
-              new ArithmeticExpr(ArithmeticExpr.Operator.DIVIDE, sumExpr, countExpr);
-
-            divExpr.analyze(analyzer);
-            result.put(aggExpr, divExpr);
-        }
-        LOG.debug("avg smap: {}", result.debugString());
-        return result;
-    }
-
 
     /**
      * Build smap count_distinct->multi_count_distinct sum_distinct->multi_count_distinct
diff --git a/fe/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/src/main/java/org/apache/doris/catalog/FunctionSet.java
index ab0a20a..adb35c4 100644
--- a/fe/src/main/java/org/apache/doris/catalog/FunctionSet.java
+++ b/fe/src/main/java/org/apache/doris/catalog/FunctionSet.java
@@ -1037,7 +1037,7 @@ public class FunctionSet {
                 prefix + "18decimalv2_avg_initEPN9doris_udf15FunctionContextEPNS1_9StringValE",
                 prefix + "20decimalv2_avg_updateEPN9doris_udf15FunctionContextERKNS1_12DecimalV2ValEPNS1_9StringValE",
                 prefix + "19decimalv2_avg_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_",
-                stringValSerializeOrFinalize,
+                prefix + "23decimalv2_avg_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
                 prefix + "23decimalv2_avg_get_valueEPN9doris_udf15FunctionContextERKNS1_9StringValE",
                 prefix + "20decimalv2_avg_removeEPN9doris_udf15FunctionContextERKNS1_12DecimalV2ValEPNS1_9StringValE",
                 prefix + "22decimalv2_avg_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE",


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org