You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain-text copy.
Posted to commits@kylin.apache.org by li...@apache.org on 2017/03/31 07:00:06 UTC

[16/21] kylin git commit: #KYLIN-490 support multiple column distinct count

#KYLIN-490 support multiple column distinct count

Signed-off-by: Hongbin Ma <ma...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/kylin/repo
Commit: http://git-wip-us.apache.org/repos/asf/kylin/commit/88a1c71d
Tree: http://git-wip-us.apache.org/repos/asf/kylin/tree/88a1c71d
Diff: http://git-wip-us.apache.org/repos/asf/kylin/diff/88a1c71d

Branch: refs/heads/KYLIN-2501
Commit: 88a1c71dde855c693b230f67b92c4cd067d43b2b
Parents: f72a3f6
Author: Roger Shi <ro...@hotmail.com>
Authored: Wed Mar 22 19:22:22 2017 +0800
Committer: Hongbin Ma <ma...@apache.org>
Committed: Mon Mar 27 15:20:40 2017 +0800

----------------------------------------------------------------------
 .../kylin/measure/ParamAsMeasureCount.java      |  30 +++++
 .../BitmapIntersectDistinctCountAggFunc.java    |   9 +-
 .../measure/percentile/PercentileAggFunc.java   |   9 +-
 .../kylin/metadata/model/FunctionDesc.java      |  62 ++++++---
 .../kylin/metadata/model/ParameterDesc.java     | 135 +++++++++++++++++--
 .../resources/query/sql_distinct/query08.sql    |  24 ++++
 .../kylin/query/relnode/OLAPAggregateRel.java   |  86 +++++++++---
 7 files changed, 304 insertions(+), 51 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kylin/blob/88a1c71d/core-metadata/src/main/java/org/apache/kylin/measure/ParamAsMeasureCount.java
----------------------------------------------------------------------
diff --git a/core-metadata/src/main/java/org/apache/kylin/measure/ParamAsMeasureCount.java b/core-metadata/src/main/java/org/apache/kylin/measure/ParamAsMeasureCount.java
new file mode 100644
index 0000000..b9bcd10
--- /dev/null
+++ b/core-metadata/src/main/java/org/apache/kylin/measure/ParamAsMeasureCount.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+
+package org.apache.kylin.measure;
+
+public interface ParamAsMeasureCount {
+    /**
+     * Returns how many of an aggregate call's leading parameters identify the measure.
+     * Negative values are intended for var-args functions.
+     * @return 0 ==> all parameters;
+     *         positive N ==> the first N parameters;
+     *         negative -N ==> all parameters except the last N
+     */
+    int getParamAsMeasureCount();
+}

http://git-wip-us.apache.org/repos/asf/kylin/blob/88a1c71d/core-metadata/src/main/java/org/apache/kylin/measure/bitmap/BitmapIntersectDistinctCountAggFunc.java
----------------------------------------------------------------------
diff --git a/core-metadata/src/main/java/org/apache/kylin/measure/bitmap/BitmapIntersectDistinctCountAggFunc.java b/core-metadata/src/main/java/org/apache/kylin/measure/bitmap/BitmapIntersectDistinctCountAggFunc.java
index cd4d306..a1e2665 100644
--- a/core-metadata/src/main/java/org/apache/kylin/measure/bitmap/BitmapIntersectDistinctCountAggFunc.java
+++ b/core-metadata/src/main/java/org/apache/kylin/measure/bitmap/BitmapIntersectDistinctCountAggFunc.java
@@ -17,6 +17,8 @@
 */
 package org.apache.kylin.measure.bitmap;
 
+import org.apache.kylin.measure.ParamAsMeasureCount;
+
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
@@ -27,9 +29,14 @@ import java.util.Map;
  * Example: intersect_count(uuid, event, array['A', 'B', 'C']), meaning find the count of uuid in all A/B/C 3 bitmaps
  *          requires an bitmap count distinct measure of uuid, and an dimension of event
  */
-public class BitmapIntersectDistinctCountAggFunc {
+public class BitmapIntersectDistinctCountAggFunc implements ParamAsMeasureCount {
     private static final BitmapCounterFactory factory = RoaringBitmapCounterFactory.INSTANCE;
 
+    @Override
+    public int getParamAsMeasureCount() {
+        return -2; // all args but the last two; for intersect_count(uuid, event, array[...]) that is just uuid
+    }
+
     public static class RetentionPartialResult {
         Map<Object, BitmapCounter> map;
         List keyList;

http://git-wip-us.apache.org/repos/asf/kylin/blob/88a1c71d/core-metadata/src/main/java/org/apache/kylin/measure/percentile/PercentileAggFunc.java
----------------------------------------------------------------------
diff --git a/core-metadata/src/main/java/org/apache/kylin/measure/percentile/PercentileAggFunc.java b/core-metadata/src/main/java/org/apache/kylin/measure/percentile/PercentileAggFunc.java
index ad02019..d3cec8f 100644
--- a/core-metadata/src/main/java/org/apache/kylin/measure/percentile/PercentileAggFunc.java
+++ b/core-metadata/src/main/java/org/apache/kylin/measure/percentile/PercentileAggFunc.java
@@ -18,7 +18,9 @@
 
 package org.apache.kylin.measure.percentile;
 
-public class PercentileAggFunc {
+import org.apache.kylin.measure.ParamAsMeasureCount;
+
+public class PercentileAggFunc implements ParamAsMeasureCount{
     public static PercentileCounter init() {
         return null;
     }
@@ -41,4 +43,9 @@ public class PercentileAggFunc {
     public static double result(PercentileCounter counter) {
         return counter == null ? 0L : counter.getResultEstimate();
     }
+
+    @Override
+    public int getParamAsMeasureCount() {
+        return 1; // only the first argument identifies the measure
+    }
 }

http://git-wip-us.apache.org/repos/asf/kylin/blob/88a1c71d/core-metadata/src/main/java/org/apache/kylin/metadata/model/FunctionDesc.java
----------------------------------------------------------------------
diff --git a/core-metadata/src/main/java/org/apache/kylin/metadata/model/FunctionDesc.java b/core-metadata/src/main/java/org/apache/kylin/metadata/model/FunctionDesc.java
index cbd7574..61c5fac 100644
--- a/core-metadata/src/main/java/org/apache/kylin/metadata/model/FunctionDesc.java
+++ b/core-metadata/src/main/java/org/apache/kylin/metadata/model/FunctionDesc.java
@@ -18,22 +18,26 @@
 
 package org.apache.kylin.metadata.model;
 
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.kylin.measure.MeasureType;
+import org.apache.kylin.measure.MeasureTypeFactory;
+import org.apache.kylin.measure.basic.BasicMeasureType;
+import org.apache.kylin.metadata.datatype.DataType;
+
 import com.fasterxml.jackson.annotation.JsonAutoDetect;
 import com.fasterxml.jackson.annotation.JsonAutoDetect.Visibility;
 import com.fasterxml.jackson.annotation.JsonInclude;
 import com.fasterxml.jackson.annotation.JsonProperty;
+import com.google.common.base.Joiner;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
-import org.apache.kylin.measure.MeasureType;
-import org.apache.kylin.measure.MeasureTypeFactory;
-import org.apache.kylin.measure.basic.BasicMeasureType;
-import org.apache.kylin.metadata.datatype.DataType;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.LinkedHashMap;
-import java.util.Map;
-import java.util.Set;
 
 /**
  */
@@ -48,7 +52,7 @@ public class FunctionDesc implements Serializable {
         r.returnDataType = DataType.getType(returnType);
         return r;
     }
-    
+
     public static final String FUNC_SUM = "SUM";
     public static final String FUNC_MIN = "MIN";
     public static final String FUNC_MAX = "MAX";
@@ -95,7 +99,7 @@ public class FunctionDesc implements Serializable {
             }
         }
 
-        if(parameter != null)
+        if (parameter != null)
             parameter.setColRefs(colRefs);
     }
 
@@ -140,6 +144,8 @@ public class FunctionDesc implements Serializable {
             return getParameter().getValue();
         } else if (isCount()) {
             return "_KY_" + "COUNT__"; // ignores parameter, count(*), count(1), count(col) are all the same
+        } else if (isCountDistinct()) {
+            return "_KY_" + getFullExpressionInAlphabetOrder().replaceAll("[(),. ]", "_");
         } else {
             return "_KY_" + getFullExpression().replaceAll("[(),. ]", "_");
         }
@@ -197,6 +203,25 @@ public class FunctionDesc implements Serializable {
         return sb.toString();
     }
 
+    /**
+     * Builds "EXPRESSION(p1,p2,...)" with the parameter values sorted alphabetically,
+     * for funcs whose parameter order is irrelevant (e.g. count distinct).
+     */
+    public String getFullExpressionInAlphabetOrder() {
+        StringBuilder sb = new StringBuilder(expression);
+        sb.append("(");
+        ParameterDesc localParam = parameter;
+        List<String> flatParams = Lists.newArrayList();
+        while (localParam != null) {
+            flatParams.add(localParam.getValue());
+            localParam = localParam.getNextParameter();
+        }
+        Collections.sort(flatParams); // sorting makes the result independent of parameter order
+        sb.append(Joiner.on(",").join(flatParams));
+        sb.append(")");
+        return sb.toString();
+    }
+
     public boolean isDimensionAsMetric() {
         return isDimensionAsMetric;
     }
@@ -264,13 +289,20 @@ public class FunctionDesc implements Serializable {
                 return false;
         } else if (!expression.equals(other.expression))
             return false;
-        // NOTE: don't check the parameter of count()
-        if (isCount() == false) {
+        if (isCountDistinct()) {
+            // for count distinct func, param's order doesn't matter
+            if (parameter == null) {
+                if (other.parameter != null)
+                    return false;
+            } else {
+                return parameter.equalInArbitraryOrder(other.parameter);
+            }
+        } else if (!isCount()) { // NOTE: don't check the parameter of count()
             if (parameter == null) {
                 if (other.parameter != null)
                     return false;
             } else {
-                 if (!parameter.equals(other.parameter))
+                if (!parameter.equals(other.parameter))
                     return false;
             }
         }

http://git-wip-us.apache.org/repos/asf/kylin/blob/88a1c71d/core-metadata/src/main/java/org/apache/kylin/metadata/model/ParameterDesc.java
----------------------------------------------------------------------
diff --git a/core-metadata/src/main/java/org/apache/kylin/metadata/model/ParameterDesc.java b/core-metadata/src/main/java/org/apache/kylin/metadata/model/ParameterDesc.java
index 8ad20a8..5ba2f14 100644
--- a/core-metadata/src/main/java/org/apache/kylin/metadata/model/ParameterDesc.java
+++ b/core-metadata/src/main/java/org/apache/kylin/metadata/model/ParameterDesc.java
@@ -18,17 +18,19 @@
 
 package org.apache.kylin.metadata.model;
 
+import java.io.Serializable;
+import java.io.UnsupportedEncodingException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Set;
+
 import com.fasterxml.jackson.annotation.JsonAutoDetect;
 import com.fasterxml.jackson.annotation.JsonAutoDetect.Visibility;
 import com.fasterxml.jackson.annotation.JsonInclude;
 import com.fasterxml.jackson.annotation.JsonProperty;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Iterables;
-
-import java.io.Serializable;
-import java.io.UnsupportedEncodingException;
-import java.util.Arrays;
-import java.util.List;
+import com.google.common.collect.Sets;
 
 /**
  */
@@ -38,9 +40,9 @@ public class ParameterDesc implements Serializable {
     public static ParameterDesc newInstance(Object... objs) {
         if (objs.length == 0)
             throw new IllegalArgumentException();
-        
+
         ParameterDesc r = new ParameterDesc();
-        
+
         Object obj = objs[0];
         if (obj instanceof TblColRef) {
             TblColRef col = (TblColRef) obj;
@@ -51,7 +53,7 @@ public class ParameterDesc implements Serializable {
             r.type = FunctionDesc.PARAMETER_TYPE_CONSTANT;
             r.value = (String) obj;
         }
-        
+
         if (objs.length >= 2) {
             r.nextParameter = newInstance(Arrays.copyOfRange(objs, 1, objs.length));
             if (r.nextParameter.colRefs.size() > 0) {
@@ -63,7 +65,7 @@ public class ParameterDesc implements Serializable {
         }
         return r;
     }
-    
+
     @JsonProperty("type")
     private String type;
     @JsonProperty("value")
@@ -74,6 +76,15 @@ public class ParameterDesc implements Serializable {
     private ParameterDesc nextParameter;
 
     private List<TblColRef> colRefs = ImmutableList.of();
+    private transient Set<PlainParameter> plainParameters = null; // transient: PlainParameter is not Serializable
+
+    // Lazily computed, order-insensitive view of the chain. NOTE(review): not reset by setValue()/setColRefs().
+    public Set<PlainParameter> getPlainParameters() {
+        if (plainParameters == null) {
+            plainParameters = PlainParameter.createFromParameterDesc(this);
+        }
+        return plainParameters;
+    }
 
     public String getType() {
         return type;
@@ -86,7 +97,7 @@ public class ParameterDesc implements Serializable {
     public String getValue() {
         return value;
     }
-    
+
     void setValue(String value) {
         this.value = value;
     }
@@ -94,7 +105,7 @@ public class ParameterDesc implements Serializable {
     public List<TblColRef> getColRefs() {
         return colRefs;
     }
-    
+
     void setColRefs(List<TblColRef> colRefs) {
         this.colRefs = colRefs;
     }
@@ -118,7 +129,7 @@ public class ParameterDesc implements Serializable {
 
         if (type != null ? !type.equals(that.type) : that.type != null)
             return false;
-        
+
         ParameterDesc p = this, q = that;
         int refi = 0, refj = 0;
         for (; p != null && q != null; p = p.nextParameter, q = q.nextParameter) {
@@ -138,10 +149,24 @@ public class ParameterDesc implements Serializable {
                     return false;
             }
         }
-        
+
         return p == null && q == null;
     }
 
+    public boolean equalInArbitraryOrder(Object o) { // order-insensitive counterpart of equals()
+        if (this == o)
+            return true;
+        if (o == null || getClass() != o.getClass())
+            return false;
+        // NOTE(review): params equal here may still differ in hashCode(); confirm no hash-based caller relies on both
+        ParameterDesc that = (ParameterDesc) o;
+
+        Set<PlainParameter> thisPlainParams = this.getPlainParameters();
+        Set<PlainParameter> thatPlainParams = that.getPlainParameters();
+
+        return thisPlainParams.containsAll(thatPlainParams) && thatPlainParams.containsAll(thisPlainParams); // set equality: duplicates collapse
+    }
+
     @Override
     public int hashCode() {
         int result = type != null ? type.hashCode() : 0;
@@ -154,4 +179,88 @@ public class ParameterDesc implements Serializable {
         return "ParameterDesc [type=" + type + ", value=" + value + ", nextParam=" + nextParameter + "]";
     }
 
+    /**
+     * PlainParameter presents one node of a ParameterDesc chain as a flat value object.
+     * Compared to ParameterDesc its advantages are:
+     * 1. a Set of them compares parameters without considering order
+     * 2. each one can be compared on its own, independent of the rest of the chain
+     */
+    private static class PlainParameter {
+        private String type;
+        private String value;
+        private TblColRef colRef = null; // set only for column-type parameters
+
+        private PlainParameter() {
+        }
+
+        public boolean isColumnType() {
+            return FunctionDesc.PARAMETER_TYPE_COLUMN.equals(type);
+        }
+
+        static Set<PlainParameter> createFromParameterDesc(ParameterDesc parameterDesc) {
+            Set<PlainParameter> result = Sets.newHashSet();
+            ParameterDesc local = parameterDesc;
+            List<TblColRef> totalColRef = parameterDesc.colRefs; // head's colRefs, consumed in order by column params
+            int colIndex = 0; // primitive int: no boxing on each increment
+            while (local != null) {
+                if (local.isColumnType()) {
+                    result.add(createSingleColumnParameter(local, totalColRef.get(colIndex++))); // NOTE(review): throws if colRefs is shorter (default is empty)
+                } else {
+                    result.add(createSingleValueParameter(local));
+                }
+                local = local.nextParameter;
+            }
+            return result;
+        }
+
+        static PlainParameter createSingleValueParameter(ParameterDesc parameterDesc) {
+            PlainParameter single = new PlainParameter();
+            single.type = parameterDesc.type;
+            single.value = parameterDesc.value;
+            return single;
+        }
+
+        static PlainParameter createSingleColumnParameter(ParameterDesc parameterDesc, TblColRef colRef) {
+            PlainParameter single = new PlainParameter();
+            single.type = parameterDesc.type;
+            single.value = parameterDesc.value;
+            single.colRef = colRef;
+            return single;
+        }
+
+        @Override
+        public int hashCode() {
+            Object key = isColumnType() ? colRef : value; // mirrors the field equals() compares for this type
+            int result = 31 * (type != null ? type.hashCode() : 0) + (key != null ? key.hashCode() : 0);
+            return result;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o)
+                return true;
+            if (o == null || getClass() != o.getClass())
+                return false;
+
+            PlainParameter that = (PlainParameter) o;
+
+            if (type != null ? !type.equals(that.type) : that.type != null)
+                return false;
+
+            if (this.isColumnType()) { // columns compare by resolved TblColRef, others by literal value
+                if (!that.isColumnType())
+                    return false;
+                if (this.colRef == null ? that.colRef != null : !this.colRef.equals(that.colRef)) {
+                    return false;
+                }
+            } else {
+                if (that.isColumnType())
+                    return false;
+                if (this.value == null ? that.value != null : !this.value.equals(that.value))
+                    return false;
+            }
+
+            return true;
+        }
+    }
 }

http://git-wip-us.apache.org/repos/asf/kylin/blob/88a1c71d/kylin-it/src/test/resources/query/sql_distinct/query08.sql
----------------------------------------------------------------------
diff --git a/kylin-it/src/test/resources/query/sql_distinct/query08.sql b/kylin-it/src/test/resources/query/sql_distinct/query08.sql
new file mode 100644
index 0000000..60f02e7
--- /dev/null
+++ b/kylin-it/src/test/resources/query/sql_distinct/query08.sql
@@ -0,0 +1,24 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements.  See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership.  The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License.  You may obtain a copy of the License at
+--
+--     http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+-- Multi-column count distinct (KYLIN-490): distinct (seller_id, lstg_format_name) combinations per cal_dt
+select cal_dt,
+ sum(price) as GMV, 
+ count(1) as TRANS_CNT, 
+ count(distinct seller_id, lstg_format_name) as DIST_SELLER_FORMAT
+ from test_kylin_fact 
+ group by cal_dt

http://git-wip-us.apache.org/repos/asf/kylin/blob/88a1c71d/query/src/main/java/org/apache/kylin/query/relnode/OLAPAggregateRel.java
----------------------------------------------------------------------
diff --git a/query/src/main/java/org/apache/kylin/query/relnode/OLAPAggregateRel.java b/query/src/main/java/org/apache/kylin/query/relnode/OLAPAggregateRel.java
index 8d7c597..2c75a14 100644
--- a/query/src/main/java/org/apache/kylin/query/relnode/OLAPAggregateRel.java
+++ b/query/src/main/java/org/apache/kylin/query/relnode/OLAPAggregateRel.java
@@ -18,6 +18,7 @@
 
 package org.apache.kylin.query.relnode;
 
+import java.lang.reflect.Method;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
@@ -55,6 +56,7 @@ import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.calcite.util.Util;
 import org.apache.kylin.measure.MeasureTypeFactory;
+import org.apache.kylin.measure.ParamAsMeasureCount;
 import org.apache.kylin.metadata.model.FunctionDesc;
 import org.apache.kylin.metadata.model.MeasureDesc;
 import org.apache.kylin.metadata.model.ParameterDesc;
@@ -71,6 +73,7 @@ import com.google.common.collect.Sets;
 public class OLAPAggregateRel extends Aggregate implements OLAPRel {
 
     private final static Map<String, String> AGGR_FUNC_MAP = new HashMap<String, String>();
+    private final static Map<String, Integer> AGGR_FUNC_PARAM_AS_MEASTURE_MAP = new HashMap<String, Integer>(); // UDAF name -> getParamAsMeasureCount()
 
     static {
         AGGR_FUNC_MAP.put("SUM", "SUM");
@@ -84,6 +87,15 @@ public class OLAPAggregateRel extends Aggregate implements OLAPRel {
         for (String udaf : udafFactories.keySet()) {
             AGGR_FUNC_MAP.put(udaf, udafFactories.get(udaf).getAggrFunctionName());
         }
+
+        Map<String, Class<?>> udafs = MeasureTypeFactory.getUDAFs();
+        for (String func : udafs.keySet()) {
+            try {
+                AGGR_FUNC_PARAM_AS_MEASTURE_MAP.put(func, ((ParamAsMeasureCount) (udafs.get(func).newInstance())).getParamAsMeasureCount());
+            } catch (Exception e) {
+                throw new RuntimeException(e);
+            }
+        }
     }
 
     private static String getSqlFuncName(AggregateCall aggCall) {
@@ -235,12 +247,27 @@ public class OLAPAggregateRel extends Aggregate implements OLAPRel {
         this.aggregations = new ArrayList<FunctionDesc>();
         for (AggregateCall aggCall : this.rewriteAggCalls) {
             ParameterDesc parameter = null;
+            // By default all args are included, UDFs can define their own in getParamAsMeasureCount method.
             if (!aggCall.getArgList().isEmpty()) {
-                // TODO: Currently only get the column of first param
-                int index = aggCall.getArgList().get(0);
-                TblColRef column = inputColumnRowType.getColumnByIndex(index);
-                if (!column.isInnerColumn()) {
-                    parameter = ParameterDesc.newInstance(column);
+                List<TblColRef> columns = Lists.newArrayList();
+                String funcName = getSqlFuncName(aggCall);
+                int columnsCount = aggCall.getArgList().size();
+                if (AGGR_FUNC_PARAM_AS_MEASTURE_MAP.containsKey(funcName)) {
+                    int asMeasureCnt = AGGR_FUNC_PARAM_AS_MEASTURE_MAP.get(funcName);
+                    if (asMeasureCnt > 0) {
+                        columnsCount = asMeasureCnt;
+                    } else {
+                        columnsCount += asMeasureCnt;
+                    }
+                }
+                for (Integer index : aggCall.getArgList().subList(0, columnsCount)) {
+                    TblColRef column = inputColumnRowType.getColumnByIndex(index);
+                    if (!column.isInnerColumn()) {
+                        columns.add(column);
+                    }
+                }
+                if (!columns.isEmpty()) {
+                    parameter = ParameterDesc.newInstance(columns.toArray(new TblColRef[columns.size()]));
                 }
             }
             String expression = getAggrFuncName(aggCall);
@@ -341,10 +368,11 @@ public class OLAPAggregateRel extends Aggregate implements OLAPRel {
 
             AggregateCall aggCall = this.rewriteAggCalls.get(i);
             if (!aggCall.getArgList().isEmpty()) {
-                int index = aggCall.getArgList().get(0);
-                TblColRef column = inputColumnRowType.getColumnByIndex(index);
-                if (!column.isInnerColumn()) {
-                    this.context.metricsColumns.add(column);
+                for (Integer index : aggCall.getArgList()) {
+                    TblColRef column = inputColumnRowType.getColumnByIndex(index);
+                    if (!column.isInnerColumn()) {
+                        this.context.metricsColumns.add(column);
+                    }
                 }
             }
         }
@@ -385,18 +413,6 @@ public class OLAPAggregateRel extends Aggregate implements OLAPRel {
             return aggCall;
         }
 
-        // rebuild parameters
-        List<Integer> newArgList = Lists.newArrayList(aggCall.getArgList());
-        if (func.needRewriteField()) {
-            RelDataTypeField field = getInput().getRowType().getField(func.getRewriteFieldName(), true, false);
-            if (newArgList.isEmpty()) {
-                newArgList.add(field.getIndex());
-            } else {
-                // only the first column got overwritten
-                newArgList.set(0, field.getIndex());
-            }
-        }
-
         // rebuild function
         String callName = getSqlFuncName(aggCall);
         RelDataType fieldType = aggCall.getType();
@@ -408,12 +424,40 @@ public class OLAPAggregateRel extends Aggregate implements OLAPRel {
             newAgg = createCustomAggFunction(callName, fieldType, udafMap.get(callName));
         }
 
+        // rebuild parameters
+        List<Integer> newArgList = Lists.newArrayList(aggCall.getArgList());
+        if (udafMap != null && udafMap.containsKey(callName)) {
+            newArgList = truncArgList(newArgList, udafMap.get(callName));
+        }
+        if (func.needRewriteField()) {
+            RelDataTypeField field = getInput().getRowType().getField(func.getRewriteFieldName(), true, false);
+            if (newArgList.isEmpty()) {
+                newArgList.add(field.getIndex());
+            } else {
+                // TODO: only the first column got overwritten
+                newArgList.set(0, field.getIndex());
+            }
+        }
+
         // rebuild aggregate call
         @SuppressWarnings("deprecation")
         AggregateCall newAggCall = new AggregateCall(newAgg, false, newArgList, fieldType, callName);
         return newAggCall;
     }
 
+    /**
+     * Truncates argList to the smallest public "add" method arity of udafClazz, minus one (presumably the accumulator).
+     */
+    private List<Integer> truncArgList(List<Integer> argList, Class<?> udafClazz) {
+        int argListLength = argList.size();
+        for (Method method : udafClazz.getMethods()) {
+            if (method.getName().equals("add")) {
+                argListLength = Math.min(method.getParameterTypes().length - 1, argListLength);
+            }
+        }
+        return argList.subList(0, argListLength); // subList view; full length when no "add" method exists
+    }
+
     private SqlAggFunction createCustomAggFunction(String funcName, RelDataType returnType, Class<?> customAggFuncClz) {
         RelDataTypeFactory typeFactory = getCluster().getTypeFactory();
         SqlIdentifier sqlIdentifier = new SqlIdentifier(funcName, new SqlParserPos(1, 1));