You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@calcite.apache.org by li...@apache.org on 2022/11/30 13:26:58 UTC

[calcite] branch main updated: [CALCITE-5230] Return type of PERCENTILE_DISC should be the same as sort expression

This is an automated email from the ASF dual-hosted git repository.

libenchao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 9161a6f529 [CALCITE-5230] Return type of PERCENTILE_DISC should be the same as sort expression
9161a6f529 is described below

commit 9161a6f529775076eece3cb78bf4a898411e15ee
Author: Itiel Sadeh <it...@sqreamtech.com>
AuthorDate: Tue Aug 9 14:29:55 2022 +0300

    [CALCITE-5230] Return type of PERCENTILE_DISC should be the same as sort expression
    
    Close #2868
---
 .../org/apache/calcite/rel/core/Aggregate.java     | 17 +++++++
 .../org/apache/calcite/rel/core/AggregateCall.java |  9 ++++
 .../org/apache/calcite/sql/SqlOperatorBinding.java |  8 +++
 .../apache/calcite/sql/SqlWithinGroupOperator.java | 59 +++++++++++++++++++---
 .../calcite/sql/fun/SqlStdOperatorTable.java       |  7 ++-
 .../org/apache/calcite/sql/type/ReturnTypes.java   |  3 ++
 .../org/apache/calcite/test/SqlValidatorTest.java  |  2 +-
 .../org/apache/calcite/test/SqlOperatorTest.java   |  2 +-
 8 files changed, 95 insertions(+), 12 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/rel/core/Aggregate.java b/core/src/main/java/org/apache/calcite/rel/core/Aggregate.java
index abb56ba0b4..568a1b45d4 100644
--- a/core/src/main/java/org/apache/calcite/rel/core/Aggregate.java
+++ b/core/src/main/java/org/apache/calcite/rel/core/Aggregate.java
@@ -622,4 +622,21 @@ public abstract class Aggregate extends SingleRel implements Hintable {
       return SqlUtil.newContextException(SqlParserPos.ZERO, e);
     }
   }
+
+  /** Binding that supplies the collation type used to infer PERCENTILE_DISC's return type. */
+  public static class PercentileDiscAggCallBinding extends AggCallBinding {
+    private final RelDataType collationType; // value handed back by getCollationType()
+
+    PercentileDiscAggCallBinding(RelDataTypeFactory typeFactory, SqlAggFunction aggFunction,
+        List<RelDataType> operands, RelDataType collationType, int groupCount,
+        boolean filter) {
+      super(typeFactory, aggFunction, operands, groupCount, filter);
+      assert aggFunction.isPercentile(); // checked only when assertions are enabled
+      this.collationType = collationType;
+    }
+
+    @Override public RelDataType getCollationType() {
+      return collationType;
+    }
+  }
 }
diff --git a/core/src/main/java/org/apache/calcite/rel/core/AggregateCall.java b/core/src/main/java/org/apache/calcite/rel/core/AggregateCall.java
index b3bff0ffb5..230b79fcc1 100644
--- a/core/src/main/java/org/apache/calcite/rel/core/AggregateCall.java
+++ b/core/src/main/java/org/apache/calcite/rel/core/AggregateCall.java
@@ -23,6 +23,7 @@ import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rel.type.RelDataTypeFactory;
 import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.sql.type.SqlTypeUtil;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.calcite.util.Optionality;
@@ -428,6 +429,14 @@ public class AggregateCall {
       Aggregate aggregateRelBase) {
     final RelDataType rowType = aggregateRelBase.getInput().getRowType();
 
+    if (aggFunction.getKind() == SqlKind.PERCENTILE_DISC) {
+      assert collation.getKeys().size() == 1;
+      return new Aggregate.PercentileDiscAggCallBinding(
+          aggregateRelBase.getCluster().getTypeFactory(), aggFunction,
+          SqlTypeUtil.projectTypes(rowType, argList),
+          SqlTypeUtil.projectTypes(rowType, collation.getKeys()).get(0),
+          aggregateRelBase.getGroupCount(), hasFilter());
+    }
     return new Aggregate.AggCallBinding(
         aggregateRelBase.getCluster().getTypeFactory(), aggFunction,
         SqlTypeUtil.projectTypes(rowType, argList),
diff --git a/core/src/main/java/org/apache/calcite/sql/SqlOperatorBinding.java b/core/src/main/java/org/apache/calcite/sql/SqlOperatorBinding.java
index 9711385f6f..6a08a56d1d 100644
--- a/core/src/main/java/org/apache/calcite/sql/SqlOperatorBinding.java
+++ b/core/src/main/java/org/apache/calcite/sql/SqlOperatorBinding.java
@@ -216,6 +216,14 @@ public abstract class SqlOperatorBinding {
     return SqlMonotonicity.NOT_MONOTONIC;
   }
 
+
+  /** Returns the type of the collation (sort expression) of this binding.
+   * This base implementation always throws {@link UnsupportedOperationException};
+   * bindings that carry a collation type override it. */
+  public RelDataType getCollationType() {
+    throw new UnsupportedOperationException();
+  }
+
   /**
    * Collects the types of the bound operands into a list.
    *
diff --git a/core/src/main/java/org/apache/calcite/sql/SqlWithinGroupOperator.java b/core/src/main/java/org/apache/calcite/sql/SqlWithinGroupOperator.java
index 596d9518ed..9024a21bd7 100644
--- a/core/src/main/java/org/apache/calcite/sql/SqlWithinGroupOperator.java
+++ b/core/src/main/java/org/apache/calcite/sql/SqlWithinGroupOperator.java
@@ -19,10 +19,14 @@ package org.apache.calcite.sql;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.sql.type.OperandTypes;
 import org.apache.calcite.sql.type.ReturnTypes;
+import org.apache.calcite.sql.type.SqlTypeUtil;
 import org.apache.calcite.sql.validate.SqlValidator;
+import org.apache.calcite.sql.validate.SqlValidatorNamespace;
 import org.apache.calcite.sql.validate.SqlValidatorScope;
 import org.apache.calcite.sql.validate.SqlValidatorUtil;
 
+import org.checkerframework.checker.nullness.qual.Nullable;
+
 import java.util.Objects;
 
 import static org.apache.calcite.util.Static.RESOURCE;
@@ -62,12 +66,13 @@ public class SqlWithinGroupOperator extends SqlBinaryOperator {
       SqlValidatorScope operandScope) {
     assert call.getOperator() == this;
     assert call.operandCount() == 2;
+
     final SqlValidatorUtil.FlatAggregate flat = SqlValidatorUtil.flatten(call);
-    if (!flat.aggregateCall.getOperator().isAggregator()) {
-      throw validator.newValidationError(call,
-          RESOURCE.withinGroupNotAllowed(
-              flat.aggregateCall.getOperator().getName()));
+    final SqlOperator operator = flat.aggregateCall.getOperator();
+    if (!operator.isAggregator()) {
+      throw validator.newValidationError(call, RESOURCE.withinGroupNotAllowed(operator.getName()));
     }
+
     for (SqlNode order : Objects.requireNonNull(flat.orderList)) {
       Objects.requireNonNull(validator.deriveType(scope, order));
     }
@@ -79,7 +84,49 @@ public class SqlWithinGroupOperator extends SqlBinaryOperator {
       SqlValidator validator,
       SqlValidatorScope scope,
       SqlCall call) {
-    // Validate type of the inner aggregate call
-    return validateOperands(validator, scope, call);
+
+    SqlCall inner = call.operand(0);
+    final SqlOperator operator = inner.getOperator();
+    if (!operator.isAggregator()) {
+      throw validator.newValidationError(call, RESOURCE.withinGroupNotAllowed(operator.getName()));
+    }
+
+    if (inner.getOperator().getKind() == SqlKind.PERCENTILE_DISC) {
+      // We first check the percentile call operands, and then derive the correct type using
+      // PercentileDiscCallBinding (See CALCITE-5230).
+      SqlCallBinding opBinding =
+          new PercentileDiscCallBinding(validator, scope, inner, getCollationColumn(call));
+      inner.getOperator().checkOperandTypes(opBinding, true);
+      RelDataType ret = inner.getOperator().inferReturnType(opBinding);
+      validator.setValidatedNodeType(inner, ret);
+      return ret;
+    } else {
+      return validateOperands(validator, scope, call);
+    }
+  }
+
+  private SqlNode getCollationColumn(SqlCall call) {
+    return ((SqlNodeList) call.operand(1)).get(0); // first item of operand 1 (ORDER BY list?)
+  }
+
+  /** Call binding used for PERCENTILE_DISC return type inference. Its collation type is
+   * the type of the namespace the validator returns for the collation node, when non-null,
+   * and otherwise the type derived for that node. */
+  public static class PercentileDiscCallBinding extends SqlCallBinding {
+    private final SqlNode collationColumn;
+
+    private PercentileDiscCallBinding(SqlValidator validator,
+        @Nullable SqlValidatorScope scope,
+        SqlCall call,
+        SqlNode collation) {
+      super(validator, scope, call);
+      this.collationColumn = collation;
+    }
+
+    @Override public RelDataType getCollationType() {
+      final RelDataType type = SqlTypeUtil.deriveType(this, collationColumn);
+      final SqlValidatorNamespace namespace = super.getValidator().getNamespace(collationColumn);
+      return namespace != null ? namespace.getType() : type; // namespace type wins if present
+    }
   }
 }
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
index 18d145e667..e2fdcebe5c 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
@@ -2299,13 +2299,12 @@ public class SqlStdOperatorTable extends ReflectiveSqlOperatorTable {
    * {@code PERCENTILE_DISC} inverse distribution aggregate function.
    *
    * <p>The argument must be a numeric literal in the range 0 to 1 inclusive
-   * (representing a percentage), and the return type is {@code DOUBLE}.
-   * (The return type should determined by the type of the {@code ORDER BY}
-   * expression, but this cannot be determined by the function itself.)
+   * (representing a percentage), and the return type is the type of the
+   * {@code ORDER BY} expression.
    */
   public static final SqlAggFunction PERCENTILE_DISC =
       SqlBasicAggFunction
-          .create(SqlKind.PERCENTILE_DISC, ReturnTypes.DOUBLE,
+          .create(SqlKind.PERCENTILE_DISC, ReturnTypes.PERCENTILE_DISC,
               OperandTypes.UNIT_INTERVAL_NUMERIC_LITERAL)
           .withFunctionType(SqlFunctionCategory.SYSTEM)
           .withGroupOrder(Optionality.MANDATORY)
diff --git a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
index 4aede093be..17571c2e3d 100644
--- a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
+++ b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
@@ -1001,4 +1001,7 @@ public abstract class ReturnTypes {
       return relDataType;
     }
   };
+
+  public static final SqlReturnTypeInference PERCENTILE_DISC = opBinding ->
+      opBinding.getCollationType(); // return type is the binding's collation type
 }
diff --git a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
index b25fd2b666..5709ccc0de 100644
--- a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
+++ b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
@@ -7929,7 +7929,7 @@ public class SqlValidatorTest extends SqlValidatorTestCase {
         + "from emp\n"
         + "group by deptno";
     sql(sql)
-        .type("RecordType(DOUBLE NOT NULL C, DOUBLE NOT NULL D) NOT NULL");
+        .type("RecordType(DOUBLE NOT NULL C, INTEGER NOT NULL D) NOT NULL");
   }
 
   /** Tests that {@code PERCENTILE_CONT} only allows numeric fields. */
diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
index 70861df1c5..0d7b530566 100644
--- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
+++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
@@ -7894,7 +7894,7 @@ public class SqlOperatorTest {
     final SqlOperatorFixture f = fixture();
     f.setFor(SqlStdOperatorTable.PERCENTILE_DISC, VM_FENNEL, VM_JAVA);
     f.checkType("percentile_disc(0.25) within group (order by 1)",
-        "DOUBLE NOT NULL");
+        "INTEGER NOT NULL");
     f.checkFails("percentile_disc(0.25) within group (^order by 'a'^)",
         "Invalid type 'CHAR' in ORDER BY clause of 'PERCENTILE_DISC' function. "
             + "Only NUMERIC types are supported", false);