You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@calcite.apache.org by hy...@apache.org on 2020/06/08 23:40:54 UTC

[calcite] branch master updated: [CALCITE-4007] MergeJoin collation check should not be limited to join key's order

This is an automated email from the ASF dual-hosted git repository.

hyuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/master by this push:
     new dcc76ce  [CALCITE-4007] MergeJoin collation check should not be limited to join key's order
dcc76ce is described below

commit dcc76cede53c7971bc9c3755d9261e766aa63b66
Author: Haisheng Yuan <h....@alibaba-inc.com>
AuthorDate: Sun May 31 12:00:32 2020 -0500

    [CALCITE-4007] MergeJoin collation check should not be limited to join key's order
    
    Given a MergeJoin on foo.a=bar.a and foo.b=bar.b, the collation check on
    MergeJoin always requires sorting by (a,b), but after 1.23.0 Calcite can
    generate a MergeJoin on a collation of (b,a), or even (a,b,c) or (b,a,c),
    all of which are legitimate. We should relax the check condition.
    
    This also fixed CALCITE-4050.
    
    Close #2010
---
 .../adapter/enumerable/EnumerableMergeJoin.java    | 102 +++++++++++++--------
 .../java/org/apache/calcite/rel/RelCollation.java  |   2 +-
 .../java/org/apache/calcite/rel/RelCollations.java |  36 ++++++++
 .../calcite/rel/metadata/RelMdCollation.java       |   4 -
 .../org/apache/calcite/util/ImmutableIntList.java  |  12 +++
 .../org/apache/calcite/rel/RelCollationTest.java   |  16 ++++
 .../apache/calcite/test/SqlHintsConverterTest.java |   7 +-
 .../apache/calcite/test/SqlHintsConverterTest.xml  |   6 +-
 8 files changed, 140 insertions(+), 45 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeJoin.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeJoin.java
index 4105682..f9f4d1a 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeJoin.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeJoin.java
@@ -44,6 +44,7 @@ import org.apache.calcite.util.BuiltInMethod;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.calcite.util.ImmutableIntList;
 import org.apache.calcite.util.Pair;
+import org.apache.calcite.util.Util;
 import org.apache.calcite.util.mapping.Mappings;
 
 import com.google.common.collect.ImmutableList;
@@ -69,6 +70,39 @@ public class EnumerableMergeJoin extends Join implements EnumerableRel {
       Set<CorrelationId> variablesSet,
       JoinRelType joinType) {
     super(cluster, traits, ImmutableList.of(), left, right, condition, variablesSet, joinType);
+    final List<RelCollation> leftCollations =
+        left.getTraitSet().getTraits(RelCollationTraitDef.INSTANCE);
+    final List<RelCollation> rightCollations =
+        right.getTraitSet().getTraits(RelCollationTraitDef.INSTANCE);
+
+    // If the join keys are not distinct, the sanity check doesn't apply.
+    // e.g. t1.a=t2.b and t1.a=t2.c
+    boolean isDistinct = Util.isDistinct(joinInfo.leftKeys)
+        && Util.isDistinct(joinInfo.rightKeys);
+
+    if (!RelCollations.containsOrderless(leftCollations, joinInfo.leftKeys)
+        || !RelCollations.containsOrderless(rightCollations, joinInfo.rightKeys)) {
+      if (isDistinct) {
+        throw new RuntimeException("wrong collation in left or right input");
+      }
+    }
+
+    final List<RelCollation> collations =
+        traits.getTraits(RelCollationTraitDef.INSTANCE);
+    assert collations != null && collations.size() > 0;
+    ImmutableIntList rightKeys = joinInfo.rightKeys
+        .incr(left.getRowType().getFieldCount());
+    // Currently there is very limited ability to represent equivalent traits,
+    // due to a flaw in RelCompositeTrait, so the following case is perfectly
+    // legitimate, but not yet supported:
+    // SELECT * FROM foo JOIN bar ON foo.a = bar.c AND foo.b = bar.d;
+    // MergeJoin has collation on [a, d], or [b, c]
+    if (!RelCollations.containsOrderless(collations, joinInfo.leftKeys)
+        && !RelCollations.containsOrderless(collations, rightKeys)) {
+      if (isDistinct) {
+        throw new RuntimeException("wrong collation for mergejoin");
+      }
+    }
     if (!isMergeJoinSupported(joinType)) {
       throw new UnsupportedOperationException(
           "EnumerableMergeJoin unsupported for join type " + joinType);
@@ -107,26 +141,18 @@ public class EnumerableMergeJoin extends Join implements EnumerableRel {
     ImmutableBitSet rightKeySet = ImmutableBitSet.of(joinInfo.rightKeys)
         .shift(left.getRowType().getFieldCount());
 
-    Map<Integer, Integer> keyMap = new HashMap<>();
-    final int keyCount = leftKeySet.cardinality();
-    for (int i = 0; i < keyCount; i++) {
-      keyMap.put(joinInfo.leftKeys.get(i), joinInfo.rightKeys.get(i));
-    }
-    Mappings.TargetMapping mapping = Mappings.target(keyMap,
-        left.getRowType().getFieldCount(),
-        right.getRowType().getFieldCount());
-
     // Only consider exact key match for now
     if (reqKeySet.equals(leftKeySet)) {
-      RelCollation rightCollation = RexUtil.apply(mapping, collation);
+      Mappings.TargetMapping mapping = buildMapping(true);
+      RelCollation rightCollation = collation.apply(mapping);
       return Pair.of(
           required, ImmutableList.of(required,
           required.replace(rightCollation)));
     } else if (reqKeySet.equals(rightKeySet)) {
       RelCollation rightCollation = RelCollations.shift(collation,
           -left.getRowType().getFieldCount());
-      Mappings.TargetMapping invMapping = mapping.inverse();
-      RelCollation leftCollation = RexUtil.apply(invMapping, rightCollation);
+      Mappings.TargetMapping mapping = buildMapping(false);
+      RelCollation leftCollation = rightCollation.apply(mapping);
       return Pair.of(
           required, ImmutableList.of(
           required.replace(leftCollation),
@@ -148,44 +174,32 @@ public class EnumerableMergeJoin extends Join implements EnumerableRel {
     if (colCount > keyCount) {
       collation = RelCollations.of(collation.getFieldCollations().subList(0, keyCount));
     }
+
+    ImmutableIntList sourceKeys = childId == 0 ? joinInfo.leftKeys : joinInfo.rightKeys;
+    ImmutableBitSet keySet = ImmutableBitSet.of(sourceKeys);
     ImmutableBitSet childCollationKeys = ImmutableBitSet.of(
         RelCollations.ordinals(collation));
-
-    Map<Integer, Integer> keyMap = new HashMap<>();
-    for (int i = 0; i < keyCount; i++) {
-      keyMap.put(joinInfo.leftKeys.get(i), joinInfo.rightKeys.get(i));
+    if (!childCollationKeys.equals(keySet)) {
+      return null;
     }
 
-    Mappings.TargetMapping mapping = Mappings.target(keyMap,
-        left.getRowType().getFieldCount(),
-        right.getRowType().getFieldCount());
+    Mappings.TargetMapping mapping = buildMapping(childId == 0);
+    RelCollation targetCollation = collation.apply(mapping);
 
     if (childId == 0) {
       // traits from left child
-      ImmutableBitSet keySet = ImmutableBitSet.of(joinInfo.leftKeys);
-      if (!childCollationKeys.equals(keySet)) {
-        return null;
-      }
-      RelCollation rightCollation = RexUtil.apply(mapping, collation);
       RelTraitSet joinTraits = getTraitSet().replace(collation);
-
       // Forget about the equiv keys for the moment
-      return Pair.of(
-          joinTraits, ImmutableList.of(childTraits,
-          right.getTraitSet().replace(rightCollation)));
+      return Pair.of(joinTraits,
+          ImmutableList.of(childTraits,
+          right.getTraitSet().replace(targetCollation)));
     } else {
       // traits from right child
       assert childId == 1;
-      ImmutableBitSet keySet = ImmutableBitSet.of(joinInfo.rightKeys);
-      if (!childCollationKeys.equals(keySet)) {
-        return null;
-      }
-      RelCollation leftCollation = RexUtil.apply(mapping.inverse(), collation);
-      RelTraitSet joinTraits = getTraitSet().replace(leftCollation);
-
+      RelTraitSet joinTraits = getTraitSet().replace(targetCollation);
       // Forget about the equiv keys for the moment
-      return Pair.of(
-          joinTraits, ImmutableList.of(left.getTraitSet().replace(leftCollation),
+      return Pair.of(joinTraits,
+          ImmutableList.of(joinTraits,
           childTraits.replace(collation)));
     }
   }
@@ -194,6 +208,20 @@ public class EnumerableMergeJoin extends Join implements EnumerableRel {
     return DeriveMode.BOTH;
   }
 
+  private Mappings.TargetMapping buildMapping(boolean left2Right) {
+    ImmutableIntList sourceKeys = left2Right ? joinInfo.leftKeys : joinInfo.rightKeys;
+    ImmutableIntList targetKeys = left2Right ? joinInfo.rightKeys : joinInfo.leftKeys;
+    Map<Integer, Integer> keyMap = new HashMap<>();
+    for (int i = 0; i < joinInfo.leftKeys.size(); i++) {
+      keyMap.put(sourceKeys.get(i), targetKeys.get(i));
+    }
+
+    Mappings.TargetMapping mapping = Mappings.target(keyMap,
+        (left2Right ? left : right).getRowType().getFieldCount(),
+        (left2Right ? right : left).getRowType().getFieldCount());
+    return mapping;
+  }
+
   public static EnumerableMergeJoin create(RelNode left, RelNode right,
       RexNode condition, ImmutableIntList leftKeys,
       ImmutableIntList rightKeys, JoinRelType joinType) {
diff --git a/core/src/main/java/org/apache/calcite/rel/RelCollation.java b/core/src/main/java/org/apache/calcite/rel/RelCollation.java
index 2cf5156..35e101d 100644
--- a/core/src/main/java/org/apache/calcite/rel/RelCollation.java
+++ b/core/src/main/java/org/apache/calcite/rel/RelCollation.java
@@ -39,7 +39,7 @@ public interface RelCollation extends RelMultipleTrait {
   /**
    * Returns the ordinals of the key columns.
    */
-  default @Nonnull List<Integer> getKeys() {
+  default @Nonnull ImmutableIntList getKeys() {
     final List<RelFieldCollation> collations = getFieldCollations();
     final int size = collations.size();
     final int[] keys = new int[size];
diff --git a/core/src/main/java/org/apache/calcite/rel/RelCollations.java b/core/src/main/java/org/apache/calcite/rel/RelCollations.java
index 8f1bfb1..05914bb 100644
--- a/core/src/main/java/org/apache/calcite/rel/RelCollations.java
+++ b/core/src/main/java/org/apache/calcite/rel/RelCollations.java
@@ -17,6 +17,7 @@
 package org.apache.calcite.rel;
 
 import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.calcite.util.ImmutableIntList;
 import org.apache.calcite.util.Util;
 import org.apache.calcite.util.mapping.Mappings;
@@ -191,6 +192,41 @@ public class RelCollations {
     return false;
   }
 
+  /** Returns whether the leading keys of a collation are exactly the given
+   * keys, regardless of their order.
+   *
+   * @param collation Collation
+   * @param keys List of keys
+   * @return Whether the collation's leading keys match the given keys
+   */
+  private static boolean containsOrderless(RelCollation collation,
+      List<Integer> keys) {
+    final List<Integer> distinctKeys = Util.distinctList(keys);
+    final ImmutableBitSet keysBitSet = ImmutableBitSet.of(distinctKeys);
+    List<Integer> colKeys = Util.distinctList(collation.getKeys());
+    if (colKeys.size() < distinctKeys.size()) {
+      return false;
+    }
+    ImmutableBitSet bitset = ImmutableBitSet.of(
+        colKeys.subList(0, distinctKeys.size()));
+    return bitset.equals(keysBitSet);
+  }
+
+  /**
+   * Returns whether at least one of a list of collations contains the given
+   * list of keys, regardless of their order.
+   */
+  public static boolean containsOrderless(List<RelCollation> collations,
+      List<Integer> keys) {
+    final List<Integer> distinctKeys = Util.distinctList(keys);
+    for (RelCollation collation : collations) {
+      if (containsOrderless(collation, distinctKeys)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
   public static RelCollation shift(RelCollation collation, int offset) {
     if (offset == 0) {
       return collation; // save some effort
diff --git a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdCollation.java b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdCollation.java
index 5b7b8ba..1a5b42b 100644
--- a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdCollation.java
+++ b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdCollation.java
@@ -454,8 +454,6 @@ public class RelMdCollation
         : "EnumerableMergeJoin unsupported for join type " + joinType;
 
     final ImmutableList<RelCollation> leftCollations = mq.collations(left);
-    assert RelCollations.contains(leftCollations, leftKeys)
-        : "cannot merge join: left input is not sorted on left keys";
     if (!joinType.projectsRight()) {
       return leftCollations;
     }
@@ -464,8 +462,6 @@ public class RelMdCollation
     builder.addAll(leftCollations);
 
     final ImmutableList<RelCollation> rightCollations = mq.collations(right);
-    assert RelCollations.contains(rightCollations, rightKeys)
-        : "cannot merge join: right input is not sorted on right keys";
     final int leftFieldCount = left.getRowType().getFieldCount();
     for (RelCollation collation : rightCollations) {
       builder.add(RelCollations.shift(collation, leftFieldCount));
diff --git a/core/src/main/java/org/apache/calcite/util/ImmutableIntList.java b/core/src/main/java/org/apache/calcite/util/ImmutableIntList.java
index de5d7d4..9fb4d77 100644
--- a/core/src/main/java/org/apache/calcite/util/ImmutableIntList.java
+++ b/core/src/main/java/org/apache/calcite/util/ImmutableIntList.java
@@ -264,6 +264,18 @@ public class ImmutableIntList extends FlatLists.AbstractFlatList<Integer> {
     return ImmutableIntList.copyOf(Iterables.concat(this, list));
   }
 
+  /**
+   * Adds {@code offset} to each element of the list and
+   * returns the result as a new int list.
+   */
+  public ImmutableIntList incr(int offset) {
+    final int[] integers = new int[ints.length];
+    for (int i = 0; i < ints.length; i++) {
+      integers[i] = ints[i] + offset;
+    }
+    return new ImmutableIntList(integers);
+  }
+
   /** Special sub-class of {@link ImmutableIntList} that is always
    * empty and has only one instance. */
   private static class EmptyImmutableIntList extends ImmutableIntList {
diff --git a/core/src/test/java/org/apache/calcite/rel/RelCollationTest.java b/core/src/test/java/org/apache/calcite/rel/RelCollationTest.java
index 08f9d3d..baaea3c 100644
--- a/core/src/test/java/org/apache/calcite/rel/RelCollationTest.java
+++ b/core/src/test/java/org/apache/calcite/rel/RelCollationTest.java
@@ -20,6 +20,8 @@ import org.apache.calcite.util.ImmutableIntList;
 import org.apache.calcite.util.mapping.Mapping;
 import org.apache.calcite.util.mapping.Mappings;
 
+import com.google.common.collect.Lists;
+
 import org.junit.jupiter.api.Test;
 
 import java.util.ArrayList;
@@ -82,6 +84,20 @@ class RelCollationTest {
         is(true));
   }
 
+  /** Unit test for {@link RelCollations#containsOrderless(List, List)}. */
+  @Test void testCollationContainsOrderless() {
+    final List<RelCollation> collations = Lists.newArrayList(collation(2, 3, 1));
+    assertThat(RelCollations.containsOrderless(collations, Arrays.asList(2, 2)), is(true));
+    assertThat(RelCollations.containsOrderless(collations, Arrays.asList(2, 3)), is(true));
+    assertThat(RelCollations.containsOrderless(collations, Arrays.asList(3, 2)), is(true));
+    assertThat(RelCollations.containsOrderless(collations, Arrays.asList(3, 2, 1)), is(true));
+    assertThat(RelCollations.containsOrderless(collations, Arrays.asList(3, 2, 1, 0)), is(false));
+    assertThat(RelCollations.containsOrderless(collations, Arrays.asList(2, 3, 0)), is(false));
+    assertThat(RelCollations.containsOrderless(collations, Arrays.asList(1)), is(false));
+    assertThat(RelCollations.containsOrderless(collations, Arrays.asList(3, 1)), is(false));
+    assertThat(RelCollations.containsOrderless(collations, Arrays.asList(0)), is(false));
+  }
+
   /**
    * Unit test for {@link org.apache.calcite.rel.RelCollationImpl#compareTo}.
    */
diff --git a/core/src/test/java/org/apache/calcite/test/SqlHintsConverterTest.java b/core/src/test/java/org/apache/calcite/test/SqlHintsConverterTest.java
index adc40b2..4e7a3f6 100644
--- a/core/src/test/java/org/apache/calcite/test/SqlHintsConverterTest.java
+++ b/core/src/test/java/org/apache/calcite/test/SqlHintsConverterTest.java
@@ -30,7 +30,9 @@ import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.plan.hep.HepPlanner;
 import org.apache.calcite.plan.hep.HepProgram;
 import org.apache.calcite.plan.hep.HepProgramBuilder;
+import org.apache.calcite.plan.volcano.AbstractConverter;
 import org.apache.calcite.plan.volcano.VolcanoPlanner;
+import org.apache.calcite.rel.RelCollationTraitDef;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.RelShuttleImpl;
 import org.apache.calcite.rel.RelVisitor;
@@ -440,6 +442,7 @@ class SqlHintsConverterTest extends SqlToRelTestBase {
         + "from emp join dept on emp.deptno = dept.deptno";
     RelOptPlanner planner = new VolcanoPlanner();
     planner.addRelTraitDef(ConventionTraitDef.INSTANCE);
+    planner.addRelTraitDef(RelCollationTraitDef.INSTANCE);
     Tester tester1 = tester.withDecorrelation(true)
         .withClusterFactory(
             relOptCluster -> RelOptCluster.create(planner, relOptCluster.getRexBuilder()));
@@ -448,7 +451,9 @@ class SqlHintsConverterTest extends SqlToRelTestBase {
         EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE,
         EnumerableRules.ENUMERABLE_JOIN_RULE,
         EnumerableRules.ENUMERABLE_PROJECT_RULE,
-        EnumerableRules.ENUMERABLE_TABLE_SCAN_RULE);
+        EnumerableRules.ENUMERABLE_TABLE_SCAN_RULE,
+        EnumerableRules.ENUMERABLE_SORT_RULE,
+        AbstractConverter.ExpandConversionRule.INSTANCE);
     Program program = Programs.of(ruleSet);
     RelTraitSet toTraits = rel
         .getCluster()
diff --git a/core/src/test/resources/org/apache/calcite/test/SqlHintsConverterTest.xml b/core/src/test/resources/org/apache/calcite/test/SqlHintsConverterTest.xml
index effae41..658e7cf 100644
--- a/core/src/test/resources/org/apache/calcite/test/SqlHintsConverterTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/SqlHintsConverterTest.xml
@@ -268,8 +268,10 @@ from emp join dept on emp.deptno = dept.deptno]]>
             <![CDATA[
 EnumerableProject(ENAME=[$1], JOB=[$2], SAL=[$5], NAME=[$10])
   EnumerableMergeJoin(condition=[=($7, $9)], joinType=[inner])
-    EnumerableTableScan(table=[[CATALOG, SALES, EMP]])
-    EnumerableTableScan(table=[[CATALOG, SALES, DEPT]])
+    EnumerableSort(sort0=[$7], dir0=[ASC])
+      EnumerableTableScan(table=[[CATALOG, SALES, EMP]])
+    EnumerableSort(sort0=[$0], dir0=[ASC])
+      EnumerableTableScan(table=[[CATALOG, SALES, DEPT]])
 ]]>
         </Resource>
     </TestCase>