You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@calcite.apache.org by jh...@apache.org on 2021/05/28 18:50:49 UTC

[calcite] 02/03: [CALCITE-4620] Join on CASE causes AssertionError in RelToSqlConverter

This is an automated email from the ASF dual-hosted git repository.

jhyde pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git

commit af5444ffdfd9e836b052ee58abec9430d55d4139
Author: Julian Hyde <jh...@apache.org>
AuthorDate: Mon May 24 18:35:53 2021 -0700

    [CALCITE-4620] Join on CASE causes AssertionError in RelToSqlConverter
    
    The fix is to remove the special-purpose logic that generates
    ON conditions and instead use the general logic that is used
    for all other expressions. The general logic is more widely
    used, and therefore already handles cases such as CASE.
    (Previous bugs such as [CALCITE-3207] LIKE, [CALCITE-4610]
    Sarg, and [CALCITE-1422] IS NULL were caused by this
    unnecessary split.)
    
    We have retained the logic that reverses 't1.id = t0.id' to
    't0.id = t1.id' if 't0' is from the left side of the join.
    
    We have reverted one test case, which was accidentally
    seeing 'mandatoryColumn IS NULL' simplified to 'FALSE',
    because the simplification was not the point of the test.
---
 .../calcite/rel/rel2sql/RelToSqlConverter.java     |  18 +--
 .../apache/calcite/rel/rel2sql/SqlImplementor.java | 149 +++++++--------------
 .../calcite/rel/rel2sql/RelToSqlConverterTest.java |  49 ++++++-
 3 files changed, 101 insertions(+), 115 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java b/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java
index 918ee29..bd73cf4 100644
--- a/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java
+++ b/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java
@@ -212,18 +212,17 @@ public class RelToSqlConverter extends SqlImplementor
     final Result rightResult = visitInput(e, 1).resetAlias();
     final Context leftContext = leftResult.qualifiedContext();
     final Context rightContext = rightResult.qualifiedContext();
-    SqlNode sqlCondition = null;
+    final SqlNode sqlCondition;
     SqlLiteral condType = JoinConditionType.ON.symbol(POS);
     JoinType joinType = joinType(e.getJoinType());
     if (isCrossJoin(e)) {
+      sqlCondition = null;
       joinType = dialect.emulateJoinTypeForCrossJoin();
       condType = JoinConditionType.NONE.symbol(POS);
     } else {
-      sqlCondition = convertConditionToSqlNode(e.getCondition(),
-          leftContext,
-          rightContext,
-          e.getLeft().getRowType().getFieldCount(),
-          dialect);
+      sqlCondition =
+          convertConditionToSqlNode(e.getCondition(), leftContext,
+              rightContext);
     }
     SqlNode join =
         new SqlJoin(POS,
@@ -243,11 +242,8 @@ public class RelToSqlConverter extends SqlImplementor
     final Context rightContext = rightResult.qualifiedContext();
 
     final SqlSelect sqlSelect = leftResult.asSelect();
-    SqlNode sqlCondition = convertConditionToSqlNode(e.getCondition(),
-        leftContext,
-        rightContext,
-        e.getLeft().getRowType().getFieldCount(),
-        dialect);
+    SqlNode sqlCondition =
+        convertConditionToSqlNode(e.getCondition(), leftContext, rightContext);
     if (leftResult.neededAlias != null) {
       SqlShuttle visitor = new AliasReplacementShuttle(leftResult.neededAlias,
           e.getLeft().getRowType(), sqlSelect.getSelectList());
diff --git a/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java b/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java
index efc7d5d..9ea2f80 100644
--- a/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java
+++ b/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java
@@ -276,110 +276,21 @@ public abstract class SqlImplementor {
    * @param node            Join condition
    * @param leftContext     Left context
    * @param rightContext    Right context
-   * @param leftFieldCount  Number of fields on left result
+   *
    * @return SqlNode that represents the condition
    */
   public static SqlNode convertConditionToSqlNode(RexNode node,
       Context leftContext,
-      Context rightContext,
-      int leftFieldCount,
-      SqlDialect dialect) {
+      Context rightContext) {
     if (node.isAlwaysTrue()) {
       return SqlLiteral.createBoolean(true, POS);
     }
     if (node.isAlwaysFalse()) {
       return SqlLiteral.createBoolean(false, POS);
     }
-    if (node instanceof RexInputRef) {
-      Context joinContext = leftContext.implementor().joinContext(leftContext, rightContext);
-      return joinContext.toSql(null, node);
-    }
-    if (!(node instanceof RexCall)) {
-      throw new AssertionError(node);
-    }
-    final List<RexNode> operands;
-    final SqlOperator op;
-    final Context joinContext;
-    switch (node.getKind()) {
-    case AND:
-    case OR:
-      operands = ((RexCall) node).getOperands();
-      op = ((RexCall) node).getOperator();
-      final List<SqlNode> sqlOperands = new ArrayList<>();
-      for (RexNode operand : operands) {
-        sqlOperands.add(
-            convertConditionToSqlNode(operand, leftContext,
-                rightContext, leftFieldCount, dialect));
-      }
-      return SqlUtil.createCall(op, POS, sqlOperands);
-
-    case EQUALS:
-    case IS_DISTINCT_FROM:
-    case IS_NOT_DISTINCT_FROM:
-    case NOT_EQUALS:
-    case GREATER_THAN:
-    case GREATER_THAN_OR_EQUAL:
-    case LESS_THAN:
-    case LESS_THAN_OR_EQUAL:
-    case LIKE:
-    case NOT:
-      node = stripCastFromString(node, dialect);
-      operands = ((RexCall) node).getOperands();
-      op = ((RexCall) node).getOperator();
-      if (operands.size() == 2
-          && operands.get(0) instanceof RexInputRef
-          && operands.get(1) instanceof RexInputRef) {
-        final RexInputRef op0 = (RexInputRef) operands.get(0);
-        final RexInputRef op1 = (RexInputRef) operands.get(1);
-
-        if (op0.getIndex() < leftFieldCount
-            && op1.getIndex() >= leftFieldCount) {
-          // Arguments were of form 'op0 = op1'
-          return op.createCall(POS,
-              leftContext.field(op0.getIndex()),
-              rightContext.field(op1.getIndex() - leftFieldCount));
-        }
-        if (op1.getIndex() < leftFieldCount
-            && op0.getIndex() >= leftFieldCount) {
-          // Arguments were of form 'op1 = op0'
-          return requireNonNull(op.reverse()).createCall(POS,
-              leftContext.field(op1.getIndex()),
-              rightContext.field(op0.getIndex() - leftFieldCount));
-        }
-      }
-      joinContext =
-          leftContext.implementor().joinContext(leftContext, rightContext);
-      return joinContext.toSql(null, node);
-
-    case SEARCH:
-      final RexCall search = (RexCall) node;
-      final RexLiteral literal = (RexLiteral) search.operands.get(1);
-      final Sarg sarg = castNonNull(literal.getValueAs(Sarg.class));
-      joinContext =
-          leftContext.implementor().joinContext(leftContext, rightContext);
-      return joinContext.toSql(null, search.operands.get(0), literal.getType(),
-          sarg);
-
-    case IS_NULL:
-    case IS_NOT_NULL:
-      operands = ((RexCall) node).getOperands();
-      if (operands.size() == 1
-          && operands.get(0) instanceof RexInputRef) {
-        op = ((RexCall) node).getOperator();
-        final RexInputRef op0 = (RexInputRef) operands.get(0);
-        if (op0.getIndex() < leftFieldCount) {
-          return op.createCall(POS, leftContext.field(op0.getIndex()));
-        } else {
-          return op.createCall(POS,
-              rightContext.field(op0.getIndex() - leftFieldCount));
-        }
-      }
-      joinContext =
-          leftContext.implementor().joinContext(leftContext, rightContext);
-      return joinContext.toSql(null, node);
-    default:
-      throw new AssertionError(node);
-    }
+    final Context joinContext =
+        leftContext.implementor().joinContext(leftContext, rightContext);
+    return joinContext.toSql(null, node);
   }
 
   /** Removes cast from string.
@@ -819,8 +730,10 @@ public abstract class SqlImplementor {
       }
     }
 
-    private SqlNode callToSql(@Nullable RexProgram program, RexCall rex, boolean not) {
-      final RexCall call = (RexCall) stripCastFromString(rex, dialect);
+    private SqlNode callToSql(@Nullable RexProgram program, RexCall call0,
+        boolean not) {
+      final RexCall call1 = reverseCall(call0);
+      final RexCall call = (RexCall) stripCastFromString(call1, dialect);
       SqlOperator op = call.getOperator();
       switch (op.getKind()) {
       case SUM0:
@@ -844,9 +757,9 @@ public abstract class SqlImplementor {
       case CAST:
         // CURSOR is used inside CAST, like 'CAST ($0): CURSOR NOT NULL',
         // convert it to sql call of {@link SqlStdOperatorTable#CURSOR}.
-        RelDataType dataType = rex.getType();
+        final RelDataType dataType = call.getType();
         if (dataType.getSqlTypeName() == SqlTypeName.CURSOR) {
-          RexNode operand0 = ((RexCall) rex).operands.get(0);
+          final RexNode operand0 = call.operands.get(0);
           assert operand0 instanceof RexInputRef;
           int ordinal = ((RexInputRef) operand0).getIndex();
           SqlNode fieldOperand = field(ordinal);
@@ -865,6 +778,19 @@ public abstract class SqlImplementor {
       return SqlUtil.createCall(op, POS, nodeList);
     }
 
+    /** Reverses the order of a call, while preserving semantics, if it improves
+     * readability.
+     *
+     * <p>In the base implementation, this method does nothing;
+     * in a join context, reverses a call such as
+     * "e.deptno = d.deptno" to
+     * "d.deptno = e.deptno"
+     * if "d" is the left input to the join
+     * and "e" is the right. */
+    protected RexCall reverseCall(RexCall call) {
+      return call;
+    }
+
     /** If {@code node} is a {@link RexCall}, extracts the operator and
      * finds the corresponding inverse operator using {@link SqlOperator#not()}.
      * Returns null if {@code node} is not a {@link RexCall},
@@ -1587,7 +1513,34 @@ public abstract class SqlImplementor {
         return rightContext.field(ordinal - leftContext.fieldCount);
       }
     }
+
+    @Override protected RexCall reverseCall(RexCall call) {
+      switch (call.getKind()) {
+      case EQUALS:
+      case IS_DISTINCT_FROM:
+      case IS_NOT_DISTINCT_FROM:
+      case GREATER_THAN:
+      case GREATER_THAN_OR_EQUAL:
+      case LESS_THAN:
+      case LESS_THAN_OR_EQUAL:
+        assert call.operands.size() == 2;
+        final RexNode op0 = call.operands.get(0);
+        final RexNode op1 = call.operands.get(1);
+        if (op0 instanceof RexInputRef
+            && op1 instanceof RexInputRef
+            && ((RexInputRef) op1).getIndex() < leftContext.fieldCount
+            && ((RexInputRef) op0).getIndex() >= leftContext.fieldCount) {
+          // Call is 'right = left'; reverse it to 'left = right'
+          final SqlOperator op2 = requireNonNull(call.getOperator().reverse());
+          return (RexCall) rexBuilder.makeCall(op2, op1, op0);
+        }
+        // fall through
+      default:
+        return call;
+      }
+    }
   }
+
   /** Context for translating call of a TableFunctionScan from {@link RexNode} to
    * {@link SqlNode}. */
   class TableFunctionScanContext extends BaseContext {
diff --git a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
index 75ea3fc..cbc1269 100644
--- a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
+++ b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
@@ -2896,17 +2896,17 @@ class RelToSqlConverterTest {
         + "on \"t1\".\"product_id\" = \"t3\".\"product_id\" or "
         + "(\"t1\".\"product_id\" is not null or "
         + "\"t3\".\"product_id\" is not null)";
-    // Some of the "IS NULL" and "IS NOT NULL" are reduced to TRUE or FALSE,
-    // but not all.
-    String expected = "SELECT *\nFROM \"foodmart\".\"sales_fact_1997\"\n"
+    String expected = "SELECT *\n"
+        + "FROM \"foodmart\".\"sales_fact_1997\"\n"
         + "INNER JOIN \"foodmart\".\"customer\" "
         + "ON \"sales_fact_1997\".\"customer_id\" = \"customer\".\"customer_id\""
-        + " OR FALSE AND FALSE"
+        + " OR \"sales_fact_1997\".\"customer_id\" IS NULL"
+        + " AND \"customer\".\"customer_id\" IS NULL"
         + " OR \"customer\".\"occupation\" IS NULL\n"
         + "INNER JOIN \"foodmart\".\"product\" "
         + "ON \"sales_fact_1997\".\"product_id\" = \"product\".\"product_id\""
-        + " OR TRUE"
-        + " OR TRUE";
+        + " OR \"sales_fact_1997\".\"product_id\" IS NOT NULL"
+        + " OR \"product\".\"product_id\" IS NOT NULL";
     // The hook prevents RelBuilder from removing "FALSE AND FALSE" and such
     try (Hook.Closeable ignore =
              Hook.REL_BUILDER_SIMPLIFY.addThread(Hook.propertyJ(false))) {
@@ -2939,6 +2939,43 @@ class RelToSqlConverterTest {
   }
 
   /** Test case for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-4620">[CALCITE-4620]
+   * Join on CASE causes AssertionError in RelToSqlConverter</a>. */
+  @Test void testJoinOnCase() {
+    final String sql = "SELECT d.deptno, e.deptno\n"
+        + "FROM dept AS d LEFT JOIN emp AS e\n"
+        + " ON CASE WHEN e.job = 'PRESIDENT' THEN true ELSE d.deptno = 10 END\n"
+        + "WHERE e.job LIKE 'PRESIDENT'";
+    final String expected = "SELECT \"DEPT\".\"DEPTNO\","
+        + " \"EMP\".\"DEPTNO\" AS \"DEPTNO0\"\n"
+        + "FROM \"SCOTT\".\"DEPT\"\n"
+        + "LEFT JOIN \"SCOTT\".\"EMP\""
+        + " ON CASE WHEN \"EMP\".\"JOB\" = 'PRESIDENT' THEN TRUE"
+        + " ELSE CAST(\"DEPT\".\"DEPTNO\" AS INTEGER) = 10 END\n"
+        + "WHERE \"EMP\".\"JOB\" LIKE 'PRESIDENT'";
+    sql(sql)
+        .schema(CalciteAssert.SchemaSpec.JDBC_SCOTT)
+        .ok(expected);
+  }
+
+  @Test void testWhereCase() {
+    final String sql = "SELECT d.deptno, e.deptno\n"
+        + "FROM dept AS d LEFT JOIN emp AS e ON d.deptno = e.deptno\n"
+        + "WHERE CASE WHEN e.job = 'PRESIDENT' THEN true\n"
+        + "      ELSE d.deptno = 10 END\n";
+    final String expected = "SELECT \"DEPT\".\"DEPTNO\","
+        + " \"EMP\".\"DEPTNO\" AS \"DEPTNO0\"\n"
+        + "FROM \"SCOTT\".\"DEPT\"\n"
+        + "LEFT JOIN \"SCOTT\".\"EMP\""
+        + " ON \"DEPT\".\"DEPTNO\" = \"EMP\".\"DEPTNO\"\n"
+        + "WHERE CASE WHEN \"EMP\".\"JOB\" = 'PRESIDENT' THEN TRUE"
+        + " ELSE CAST(\"DEPT\".\"DEPTNO\" AS INTEGER) = 10 END";
+    sql(sql)
+        .schema(CalciteAssert.SchemaSpec.JDBC_SCOTT)
+        .ok(expected);
+  }
+
+  /** Test case for
    * <a href="https://issues.apache.org/jira/browse/CALCITE-1586">[CALCITE-1586]
    * JDBC adapter generates wrong SQL if UNION has more than two inputs</a>. */
   @Test void testThreeQueryUnion() {