You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by mo...@apache.org on 2023/01/03 14:09:36 UTC

[doris] branch master updated: [fix](nereids) binding priority in agg-sort, having, group_by_key (#15240)

This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 8d0c06c897 [fix](nereids) binding priority in agg-sort, having, group_by_key (#15240)
8d0c06c897 is described below

commit 8d0c06c89717e08daf360dd39ed3e4cb6a4e7fc5
Author: minghong <en...@gmail.com>
AuthorDate: Tue Jan 3 22:09:28 2023 +0800

    [fix](nereids) binding priority in agg-sort, having, group_by_key (#15240)
    
    This PR defines order_key and having_key binding priority.
    
    1. order key priority
     ```
                    select
                            col1 * -1 as col1    # inner_col1 * -1 as alias_col1
                    from
                            t
                    order by col1;     # order by order_col1
    ```
    When binding `order_col1`, `alias_col1` has higher priority than `inner_col1`.
    
    2. having key priority
    ```
           select (a-1) as a  # inner_a - 1 as alias_a
           from bind_priority_tbl
           group by a
           having a=1;
    ```
    When binding the having key, `inner_a` has higher priority than `alias_a`.
    
    3. group by key binding priority
    ```
    SELECT date_format(b.k10,
             '%Y%m%d') AS k10
    FROM test a
    LEFT JOIN
        (SELECT k10
        FROM baseall) b
        ON a.k10 = b.k10
    GROUP BY  k10;
    ```
    group_by_key (k10) binding priority:
    
    - agg.child.output
    - agg.output
    If binding with agg.child.output fails (the slot is not found, or more than one candidate slot is found in agg.child.output), nereids tries to bind group_by_key with agg.output.
    In the above example, nereids finds 2 candidate slots (a.k10, b.k10) in agg.child.output for group_by_key (k10), so binding with agg.child.output fails. Nereids then tries to bind group_by_key with agg.output, that is, `date_format(b.k10, '%Y%m%d') AS k10`, and finally group_by_key is bound to the alias `k10`.
---
 .../doris/catalog/BuiltinAggregateFunctions.java   |   4 +
 .../org/apache/doris/catalog/FunctionHelper.java   |   2 +
 .../nereids/rules/analysis/BindSlotReference.java  | 105 ++++++++++++++++++---
 .../nereids/rules/analysis/CheckAnalysis.java      |  21 +++--
 .../rules/analysis/BindSlotReferenceTest.java      |   2 +-
 .../nereids/rules/analysis/CheckAnalysisTest.java  |  13 +++
 .../data/nereids_syntax_p0/bind_priority.out       |  13 +++
 .../suites/nereids_syntax_p0/bind_priority.groovy  |  59 ++++++++++++
 8 files changed, 199 insertions(+), 20 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
index 4a5e9c3607..c253973aba 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
@@ -40,6 +40,9 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.VarianceSamp;
 
 import com.google.common.collect.ImmutableList;
 
+import java.util.HashSet;
+import java.util.Set;
+
 /**
  * Builtin aggregate functions.
  * <p>
@@ -47,6 +50,7 @@ import com.google.common.collect.ImmutableList;
  * It helps to be clear and concise.
  */
 public class BuiltinAggregateFunctions implements FunctionHelper {
+    public static Set<String> aggFuncNames = new HashSet<>();
     public final ImmutableList<AggregateFunc> aggregateFunctions = ImmutableList.of(
             agg(ApproxCountDistinct.class, "approx_count_distinct"),
             agg(Avg.class),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionHelper.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionHelper.java
index 58c4c91c13..bbedbf39b2 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionHelper.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionHelper.java
@@ -66,6 +66,7 @@ public interface FunctionHelper {
 
     default AggregateFunc agg(Class<? extends AggregateFunction> functionClass) {
         String functionName = functionClass.getSimpleName();
+        BuiltinAggregateFunctions.aggFuncNames.add(functionName.toLowerCase());
         return new AggregateFunc(functionClass, functionName);
     }
 
@@ -75,6 +76,7 @@ public interface FunctionHelper {
      * @return AggregateFunc which contains the functionName and the AggregateFunc
      */
     default AggregateFunc agg(Class<? extends AggregateFunction> functionClass, String... functionNames) {
+        Arrays.stream(functionNames).forEach(name -> BuiltinAggregateFunctions.aggFuncNames.add(name));
         return new AggregateFunc(functionClass, functionNames);
     }
 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java
index bf171d0541..d0b9e1c599 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java
@@ -17,6 +17,7 @@
 
 package org.apache.doris.nereids.rules.analysis;
 
+import org.apache.doris.catalog.BuiltinAggregateFunctions;
 import org.apache.doris.nereids.CascadesContext;
 import org.apache.doris.nereids.analyzer.UnboundAlias;
 import org.apache.doris.nereids.analyzer.UnboundFunction;
@@ -194,11 +195,55 @@ public class BindSlotReference implements AnalysisRuleFactory {
 
                     // The columns referenced in group by are first obtained from the child's output,
                     // and then from the node's output
+                    Set<String> duplicatedSlotNames = new HashSet<>();
                     Map<String, Expression> childOutputsToExpr = agg.child().getOutput().stream()
-                            .collect(Collectors.toMap(Slot::getName, Slot::toSlot, (oldExpr, newExpr) -> oldExpr));
+                            .collect(Collectors.toMap(Slot::getName, Slot::toSlot,
+                                    (oldExpr, newExpr) -> {
+                                        duplicatedSlotNames.add(((Slot) oldExpr).getName());
+                                        return oldExpr;
+                                }));
+                    /*
+                    GroupByKey binding priority:
+                    1. child.output
+                    2. agg.output
+                    CASE 1
+                     k is not in agg.output
+                     plan:
+                         agg(group_by: k)
+                          +---child(output t1.k, t2.k)
+
+                     group_by_key: k is ambiguous, t1.k and t2.k are candidate.
+
+                    CASE 2
+                     k is in agg.output
+                     plan:
+                         agg(group_by: k, output (k+1 as k)
+                          +---child(output t1.k, t2.k)
+
+                     it is failed to bind group_by_key with child.output(ambiguous), but group_by_key can be bound with
+                     agg.output
+
+                    CASE 3
+                     group by key cannot bind with agg func
+                     plan:
+                        agg(group_by v, output sum(k) as v)
+
+                     throw AnalysisException
+                    */
+                    duplicatedSlotNames.stream().forEach(dup -> childOutputsToExpr.remove(dup));
                     Map<String, Expression> aliasNameToExpr = output.stream()
                             .filter(ne -> ne instanceof Alias)
                             .map(Alias.class::cast)
+                            //agg function cannot be bound with group_by_key
+                            .filter(alias -> ! alias.child().anyMatch(expr -> {
+                                        if (expr instanceof UnboundFunction) {
+                                            UnboundFunction unboundFunction = (UnboundFunction) expr;
+                                            return BuiltinAggregateFunctions.aggFuncNames.contains(
+                                                    unboundFunction.getName().toLowerCase());
+                                        }
+                                        return false;
+                                    }
+                            ))
                             .collect(Collectors.toMap(Alias::getName, UnaryNode::child, (oldExpr, newExpr) -> oldExpr));
                     aliasNameToExpr.entrySet().stream()
                             .forEach(e -> childOutputsToExpr.putIfAbsent(e.getKey(), e.getValue()));
@@ -218,6 +263,18 @@ public class BindSlotReference implements AnalysisRuleFactory {
                             }).collect(Collectors.toList());
 
                     List<Expression> groupBy = bind(replacedGroupBy, agg.children(), agg, ctx.cascadesContext);
+                    List<Expression> unboundGroupBys = Lists.newArrayList();
+                    boolean hasUnbound = groupBy.stream().anyMatch(
+                            expression -> {
+                                if (expression.anyMatch(UnboundSlot.class::isInstance)) {
+                                    unboundGroupBys.add(expression);
+                                    return true;
+                                }
+                                return false;
+                            });
+                    if (hasUnbound) {
+                        throw new AnalysisException("cannot bind GROUP BY KEY: " + unboundGroupBys.get(0).toSql());
+                    }
                     List<NamedExpression> newOutput = adjustNullableForAgg(agg, output);
                     return agg.withGroupByAndOutput(groupBy, newOutput);
                 })
@@ -332,12 +389,20 @@ public class BindSlotReference implements AnalysisRuleFactory {
                     Plan childPlan = having.child();
                     // We should deduplicate the slots, otherwise the binding process will fail due to the
                     // ambiguous slots exist.
-                    Set<Slot> boundSlots = Stream.concat(Stream.of(childPlan), childPlan.children().stream())
-                            .flatMap(plan -> plan.getOutput().stream())
-                            .collect(Collectors.toSet());
-                    SlotBinder binder = new SlotBinder(toScope(Lists.newArrayList(boundSlots)), having,
+                    List<Slot> childChildSlots = childPlan.children().stream()
+                            .flatMap(plan -> plan.getOutputSet().stream())
+                            .collect(Collectors.toList());
+                    SlotBinder childChildBinder = new SlotBinder(toScope(childChildSlots), having,
                             ctx.cascadesContext);
-                    Set<Expression> boundConjuncts = having.getConjuncts().stream().map(binder::bind)
+                    List<Slot> childSlots = childPlan.getOutputSet().stream()
+                            .collect(Collectors.toList());
+                    SlotBinder childBinder = new SlotBinder(toScope(childSlots), having,
+                            ctx.cascadesContext);
+                    Set<Expression> boundConjuncts = having.getConjuncts().stream().map(
+                            expr -> {
+                                expr = childChildBinder.bind(expr);
+                                return childBinder.bind(expr);
+                            })
                             .collect(Collectors.toSet());
                     return new LogicalHaving<>(boundConjuncts, having.child());
                 })
@@ -411,16 +476,30 @@ public class BindSlotReference implements AnalysisRuleFactory {
 
     private Plan bindSortWithAggregateFunction(
             LogicalSort<? extends Plan> sort, Aggregate<? extends Plan> aggregate, CascadesContext ctx) {
-        // We should deduplicate the slots, otherwise the binding process will fail due to the
-        // ambiguous slots exist.
-        Set<Slot> boundSlots = Stream.concat(Stream.of(aggregate), aggregate.children().stream())
-                .flatMap(plan -> plan.getOutput().stream())
-                .collect(Collectors.toSet());
+        // 1. We should deduplicate the slots, otherwise the binding process will fail due to the
+        //    ambiguous slots exist.
+        // 2. try to bound order-key with agg output, if failed, try to bound with output of agg.child
+        //    binding priority example:
+        //        select
+        //        col1 * -1 as col1    # inner_col1 * -1 as alias_col1
+        //        from
+        //                (
+        //                        select 1 as col1
+        //                        union
+        //                        select -2 as col1
+        //                ) t
+        //        group by col1
+        //        order by col1;     # order by order_col1
+        //    bind order_col1 with alias_col1, then, bind it with inner_col1
+        SlotBinder outputBinder = new SlotBinder(
+                toScope(aggregate.getOutputSet().stream().collect(Collectors.toList())), sort, ctx);
+        List<Slot> childOutputSlots = aggregate.child().getOutputSet().stream().collect(Collectors.toList());
+        SlotBinder childOutputBinder = new SlotBinder(toScope(childOutputSlots), sort, ctx);
         List<OrderKey> sortItemList = sort.getOrderKeys()
                 .stream()
                 .map(orderKey -> {
-                    Expression item = new SlotBinder(toScope(new ArrayList<>(boundSlots)), sort, ctx)
-                            .bind(orderKey.getExpr());
+                    Expression item = outputBinder.bind(orderKey.getExpr());
+                    item = childOutputBinder.bind(item);
                     return new OrderKey(item, orderKey.isAsc(), orderKey.isNullFirst());
                 }).collect(Collectors.toList());
         return new LogicalSort<>(sortItemList, sort.child());
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java
index 8ec06987cc..a4e14bfc80 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java
@@ -17,6 +17,8 @@
 
 package org.apache.doris.nereids.rules.analysis;
 
+import org.apache.doris.nereids.analyzer.Unbound;
+import org.apache.doris.nereids.analyzer.UnboundFunction;
 import org.apache.doris.nereids.analyzer.UnboundSlot;
 import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.rules.Rule;
@@ -71,14 +73,21 @@ public class CheckAnalysis implements AnalysisRuleFactory {
     }
 
     private void checkBound(Plan plan) {
-        Set<UnboundSlot> unboundSlots = plan.getExpressions().stream()
-                .<Set<UnboundSlot>>map(e -> e.collect(UnboundSlot.class::isInstance))
+        Set<Unbound> unbounds = plan.getExpressions().stream()
+                .<Set<Unbound>>map(e -> e.collect(Unbound.class::isInstance))
                 .flatMap(Set::stream)
                 .collect(Collectors.toSet());
-        if (!unboundSlots.isEmpty()) {
-            throw new AnalysisException(String.format("Cannot find column %s.",
-                    StringUtils.join(unboundSlots.stream()
-                            .map(UnboundSlot::toSql)
+        if (!unbounds.isEmpty()) {
+            throw new AnalysisException(String.format("unbounded object %s.",
+                    StringUtils.join(unbounds.stream()
+                            .map(unbound -> {
+                                if (unbound instanceof UnboundSlot) {
+                                    return ((UnboundSlot) unbound).toSql();
+                                } else if (unbound instanceof UnboundFunction) {
+                                    return ((UnboundFunction) unbound).toSql();
+                                }
+                                return unbound.toString();
+                            })
                             .collect(Collectors.toSet()), ", ")));
         }
     }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java
index 6537b96e60..e63618cdd7 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java
@@ -47,7 +47,7 @@ class BindSlotReferenceTest {
                 new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student));
         AnalysisException exception = Assertions.assertThrows(AnalysisException.class,
                 () -> PlanChecker.from(MemoTestUtils.createConnectContext()).analyze(project));
-        Assertions.assertEquals("Cannot find column foo.", exception.getMessage());
+        Assertions.assertEquals("unbounded object foo.", exception.getMessage());
     }
 
     @Test
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckAnalysisTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckAnalysisTest.java
index 62c8132cc4..cb455a76cf 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckAnalysisTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckAnalysisTest.java
@@ -18,6 +18,7 @@
 package org.apache.doris.nereids.rules.analysis;
 
 import org.apache.doris.nereids.CascadesContext;
+import org.apache.doris.nereids.analyzer.UnboundFunction;
 import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.And;
@@ -31,6 +32,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Lists;
 import mockit.Mocked;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
@@ -57,4 +59,15 @@ public class CheckAnalysisTest {
         Assertions.assertThrows(AnalysisException.class, () ->
                 checkAnalysis.buildRules().forEach(rule -> rule.transform(plan, cascadesContext)));
     }
+
+    @Test
+    public void testUnbound() {
+        UnboundFunction func = new UnboundFunction("now", Lists.newArrayList(new IntegerLiteral(1)));
+        Plan plan = new LogicalOneRowRelation(
+                ImmutableList.of(new Alias(func, "unboundFunction")));
+        CheckAnalysis checkAnalysis = new CheckAnalysis();
+        Assertions.assertThrows(AnalysisException.class, () ->
+                checkAnalysis.buildRules().forEach(rule -> rule.transform(plan, cascadesContext)));
+    }
+
 }
diff --git a/regression-test/data/nereids_syntax_p0/bind_priority.out b/regression-test/data/nereids_syntax_p0/bind_priority.out
new file mode 100644
index 0000000000..90d639c228
--- /dev/null
+++ b/regression-test/data/nereids_syntax_p0/bind_priority.out
@@ -0,0 +1,13 @@
+-- This file is automatically generated. You should know what you did if you want to edit this
+-- !select --
+-3
+-1
+
+-- !select --
+a	1
+all	1
+all	2
+
+-- !select --
+0
+
diff --git a/regression-test/suites/nereids_syntax_p0/bind_priority.groovy b/regression-test/suites/nereids_syntax_p0/bind_priority.groovy
new file mode 100644
index 0000000000..89e878e89b
--- /dev/null
+++ b/regression-test/suites/nereids_syntax_p0/bind_priority.groovy
@@ -0,0 +1,59 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("bind_priority") {
+    sql "SET enable_nereids_planner=true"
+
+    sql """
+        DROP TABLE IF EXISTS bind_priority_tbl
+       """
+
+    sql """CREATE TABLE IF NOT EXISTS bind_priority_tbl (a int not null, b int not null)
+        DISTRIBUTED BY HASH(a)
+        BUCKETS 1
+        PROPERTIES(
+            "replication_num"="1"
+        )
+        """
+
+    sql """
+    insert into bind_priority_tbl values(1, 2),(3, 4)
+    """
+
+    sql "SET enable_fallback_to_original_planner=false"
+
+    sql """sync"""
+
+    qt_select """
+        select a * -1 as a from bind_priority_tbl group by a order by a;
+    """
+
+    qt_select """
+        select coalesce(a, 'all') as a, count(*) as cnt from (select  null as a  union all  select  'a' as a ) t group by grouping sets ((a),()) order by a;
+    """
+
+    qt_select """
+        select (a-1) as a from bind_priority_tbl group by a having a=1;
+    """
+
+    test {
+        sql """
+            select sum(a) as v from bind_priority_tbl  group by v;
+            """
+        exception "Unexpected exception: cannot bind GROUP BY KEY: v"
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org