You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@calcite.apache.org by jh...@apache.org on 2020/11/01 17:46:18 UTC

[calcite] 03/04: [CALCITE-4369] Support COUNTIF aggregate function for BigQuery (Aryeh Hillman)

This is an automated email from the ASF dual-hosted git repository.

jhyde pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git

commit a5801bed0d0d74ad12b6742d64ffb07b4a05f674
Author: Aryeh Hillman <ar...@google.com>
AuthorDate: Mon Oct 12 17:29:57 2020 -0700

    [CALCITE-4369] Support COUNTIF aggregate function for BigQuery (Aryeh Hillman)
    
    In SQL reference, move COUNTIF to the list of dialect-specific
    functions; during SQL-to-Rel, transform to 'COUNT(*) FILTER
    (WHERE b)' rather than 'COUNT(CASE WHEN b THEN 1 END)' (Julian Hyde).
    
    Close apache/calcite#2235
---
 .../main/java/org/apache/calcite/sql/SqlKind.java  |  5 ++-
 .../main/java/org/apache/calcite/sql/SqlUtil.java  | 12 +++++--
 .../calcite/sql/fun/SqlLibraryOperators.java       | 12 +++++++
 .../apache/calcite/sql2rel/SqlToRelConverter.java  | 15 +++++---
 .../calcite/materialize/LatticeSuggesterTest.java  |  4 ++-
 .../calcite/sql/test/SqlOperatorBaseTest.java      | 27 +++++++++++++++
 core/src/test/resources/sql/agg.iq                 | 40 ++++++++++++++++++++++
 site/_docs/reference.md                            |  1 +
 8 files changed, 107 insertions(+), 9 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/sql/SqlKind.java b/core/src/main/java/org/apache/calcite/sql/SqlKind.java
index e76727b..aa1332e 100644
--- a/core/src/main/java/org/apache/calcite/sql/SqlKind.java
+++ b/core/src/main/java/org/apache/calcite/sql/SqlKind.java
@@ -805,6 +805,9 @@ public enum SqlKind {
   /** The {@code STRING_AGG} aggregate function. */
   STRING_AGG,
 
+  /** The {@code COUNTIF} aggregate function. */
+  COUNTIF,
+
   /** The {@code ARRAY_AGG} aggregate function. */
   ARRAY_AGG,
 
@@ -1040,7 +1043,7 @@ public enum SqlKind {
           AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP, NTILE, COLLECT,
           FUSION, SINGLE_VALUE, ROW_NUMBER, RANK, PERCENT_RANK, DENSE_RANK,
           CUME_DIST, JSON_ARRAYAGG, JSON_OBJECTAGG, BIT_AND, BIT_OR, BIT_XOR,
-          LISTAGG, STRING_AGG, ARRAY_AGG, ARRAY_CONCAT_AGG,
+          LISTAGG, STRING_AGG, ARRAY_AGG, ARRAY_CONCAT_AGG, COUNTIF,
           INTERSECTION, ANY_VALUE);
 
   /**
diff --git a/core/src/main/java/org/apache/calcite/sql/SqlUtil.java b/core/src/main/java/org/apache/calcite/sql/SqlUtil.java
index 1176fbd..9f9175d 100644
--- a/core/src/main/java/org/apache/calcite/sql/SqlUtil.java
+++ b/core/src/main/java/org/apache/calcite/sql/SqlUtil.java
@@ -66,6 +66,8 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.function.Predicate;
 import java.util.stream.Collectors;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
 
 import static org.apache.calcite.util.Static.RESOURCE;
 
@@ -75,9 +77,13 @@ import static org.apache.calcite.util.Static.RESOURCE;
 public abstract class SqlUtil {
   //~ Methods ----------------------------------------------------------------
 
-  static SqlNode andExpressions(
-      SqlNode node1,
-      SqlNode node2) {
+  /** Returns the AND of two expressions.
+   *
+   * <p>If {@code node1} is null, returns {@code node2}.
+   * Flattens if either node is an AND. */
+  public static @Nonnull SqlNode andExpressions(
+      @Nullable SqlNode node1,
+      @Nonnull SqlNode node2) {
     if (node1 == null) {
       return node2;
     }
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
index 9cb9e0f..8ff1b46 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
@@ -35,6 +35,7 @@ import org.apache.calcite.sql.type.SqlTypeFamily;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.sql.type.SqlTypeTransforms;
 import org.apache.calcite.util.Litmus;
+import org.apache.calcite.util.Optionality;
 
 import com.google.common.collect.ImmutableList;
 
@@ -269,6 +270,17 @@ public abstract class SqlLibraryOperators {
   public static final SqlAggFunction LOGICAL_OR =
       new SqlMinMaxAggFunction("LOGICAL_OR", SqlKind.MAX, OperandTypes.BOOLEAN);
 
+  /** The "COUNTIF(condition) [OVER (...)]" function, in BigQuery,
+   * returns the number of rows for which {@code condition} is TRUE.
+   *
+   * <p>{@code COUNTIF(b)} is equivalent to
+   * {@code COUNT(*) FILTER (WHERE b)}. */
+  @LibraryOperator(libraries = {BIG_QUERY})
+  public static final SqlAggFunction COUNTIF =
+      SqlBasicAggFunction
+          .create(SqlKind.COUNTIF, ReturnTypes.BIGINT, OperandTypes.BOOLEAN)
+          .withDistinct(Optionality.FORBIDDEN);
+
   /** The "ARRAY_AGG(value [ ORDER BY ...])" aggregate function,
    * in BigQuery and PostgreSQL, gathers values into arrays. */
   @LibraryOperator(libraries = {POSTGRESQL, BIG_QUERY})
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
index fab36f2..39ca7dd 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
@@ -5437,6 +5437,7 @@ public class SqlToRelConverter {
       assert bb.agg == this;
       assert outerCall != null;
       final List<SqlNode> operands = call.getOperandList();
+      final SqlParserPos pos = call.getParserPosition();
       switch (call.getKind()) {
       case FILTER:
         assert filter == null;
@@ -5455,6 +5456,14 @@ public class SqlToRelConverter {
         translateAgg(call.operand(0), filter, orderList, ignoreNulls,
             outerCall);
         return;
+      case COUNTIF:
+        // COUNTIF(b)  ==> COUNT(*) FILTER (WHERE b)
+        // COUNTIF(b) FILTER (WHERE b2)  ==> COUNT(*) FILTER (WHERE b2 AND b)
+        final SqlCall call4 =
+            SqlStdOperatorTable.COUNT.createCall(pos, SqlIdentifier.star(pos));
+        final SqlNode filter2 = SqlUtil.andExpressions(filter, call.operand(0));
+        translateAgg(call4, filter2, orderList, ignoreNulls, outerCall);
+        return;
       case STRING_AGG:
         // Translate "STRING_AGG(s, sep ORDER BY x, y)"
         // as if it were "LISTAGG(s, sep) WITHIN GROUP (ORDER BY x, y)";
@@ -5469,8 +5478,7 @@ public class SqlToRelConverter {
         }
         final SqlCall call2 =
             SqlStdOperatorTable.LISTAGG.createCall(
-                call.getFunctionQuantifier(), call.getParserPosition(),
-                operands2);
+                call.getFunctionQuantifier(), pos, operands2);
         translateAgg(call2, filter, orderList, ignoreNulls, outerCall);
         return;
       case ARRAY_AGG:
@@ -5483,8 +5491,7 @@ public class SqlToRelConverter {
           orderList = (SqlNodeList) Util.last(operands);
           final SqlCall call3 =
               call.getOperator().createCall(
-                  call.getFunctionQuantifier(), call.getParserPosition(),
-                  Util.skipLast(operands));
+                  call.getFunctionQuantifier(), pos, Util.skipLast(operands));
           translateAgg(call3, filter, orderList, ignoreNulls, outerCall);
           return;
         }
diff --git a/core/src/test/java/org/apache/calcite/materialize/LatticeSuggesterTest.java b/core/src/test/java/org/apache/calcite/materialize/LatticeSuggesterTest.java
index 2fb2e4e..5ff1e95 100644
--- a/core/src/test/java/org/apache/calcite/materialize/LatticeSuggesterTest.java
+++ b/core/src/test/java/org/apache/calcite/materialize/LatticeSuggesterTest.java
@@ -620,13 +620,15 @@ class LatticeSuggesterTest {
   }
 
   /** Tests a number of features only available in BigQuery: back-ticks;
-   * GROUP BY ordinal; case-insensitive unquoted identifiers. */
+   * GROUP BY ordinal; case-insensitive unquoted identifiers;
+   * the {@code COUNTIF} aggregate function. */
   @Test void testBigQueryDialect() throws Exception {
     final Tester t = new Tester().foodmart().withEvolve(true)
         .withDialect(SqlDialect.DatabaseProduct.BIG_QUERY.getDialect())
         .withLibrary(SqlLibrary.BIG_QUERY);
 
     final String q0 = "select `product_id`,\n"
+        + "  countif(unit_sales > 1000) as num_over_thousand,\n"
         + "  SUM(unit_sales)\n"
         + "from\n"
         + "  `sales_fact_1997`"
diff --git a/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java b/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java
index 7aeda68..827ec84 100644
--- a/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java
+++ b/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java
@@ -8602,6 +8602,33 @@ public abstract class SqlOperatorBaseTest {
     tester.checkAgg("COUNT(DISTINCT 123)", stringValues, 1, 0d);
   }
 
+  @Test void testCountifFunc() {
+    tester.setFor(SqlLibraryOperators.COUNTIF, VM_FENNEL, VM_JAVA);
+    final SqlTester tester = libraryTester(SqlLibrary.BIG_QUERY);
+    tester.checkType("countif(true)", "BIGINT NOT NULL");
+    tester.checkType("countif(nullif(true,true))", "BIGINT NOT NULL");
+    tester.checkType("countif(false) filter (where true)", "BIGINT NOT NULL");
+
+    final String expectedError = "Invalid number of arguments to function "
+        + "'COUNTIF'. Was expecting 1 arguments";
+    tester.checkFails("^COUNTIF()^", expectedError, false);
+    tester.checkFails("^COUNTIF(true, false)^", expectedError, false);
+    final String expectedError2 = "Cannot apply 'COUNTIF' to arguments of "
+        + "type 'COUNTIF\\(<INTEGER>\\)'\\. Supported form\\(s\\): "
+        + "'COUNTIF\\(<BOOLEAN>\\)'";
+    tester.checkFails("^COUNTIF(1)^", expectedError2, false);
+
+    final String[] values = {"1", "2", "CAST(NULL AS INTEGER)", "1"};
+    tester.checkAgg("countif(x > 0)", values, 3, 0d);
+    tester.checkAgg("countif(x < 2)", values, 2, 0d);
+    tester.checkAgg("countif(x is not null) filter (where x < 2)",
+        values, 2, 0d);
+    tester.checkAgg("countif(x < 2) filter (where x is not null)",
+        values, 2, 0d);
+    tester.checkAgg("countif(x between 1 and 2)", values, 3, 0d);
+    tester.checkAgg("countif(x < 0)", values, 0, 0d);
+  }
+
   @Test void testApproxCountDistinctFunc() {
     tester.setFor(SqlStdOperatorTable.COUNT, VM_EXPAND);
     tester.checkFails("approx_count_distinct(^*^)", "Unknown identifier '\\*'",
diff --git a/core/src/test/resources/sql/agg.iq b/core/src/test/resources/sql/agg.iq
index 25e627b..2dba1c3 100644
--- a/core/src/test/resources/sql/agg.iq
+++ b/core/src/test/resources/sql/agg.iq
@@ -2890,4 +2890,44 @@ from emp group by gender;
 
 !ok
 
+# COUNTIF(b) (BigQuery) is equivalent to COUNT(*) FILTER (WHERE b)
+select deptno, countif(gender = 'F') as f
+from emp
+group by deptno;
++--------+---+
+| DEPTNO | F |
++--------+---+
+|     10 | 1 |
+|     20 | 0 |
+|     30 | 2 |
+|     50 | 1 |
+|     60 | 1 |
+|        | 1 |
++--------+---+
+(6 rows)
+
+!ok
+
+select countif(gender = 'F') filter (where deptno = 30) as f
+from emp;
++---+
+| F |
++---+
+| 2 |
++---+
+(1 row)
+
+!ok
+
+select countif(a > 0) + countif(a > 1) + countif(c > 1) as c
+from (select 1 as a, 2 as b, 3 as c);
++---+
+| C |
++---+
+| 2 |
++---+
+(1 row)
+
+!ok
+
 # End agg.iq
diff --git a/site/_docs/reference.md b/site/_docs/reference.md
index 3f995e7..260f8a3 100644
--- a/site/_docs/reference.md
+++ b/site/_docs/reference.md
@@ -2522,6 +2522,7 @@ Dialect-specific aggregate functions.
 | b p | ARRAY_CONCAT_AGG( [ ALL &#124; DISTINCT ] value [ ORDER BY orderItem [, orderItem ]* ] ) | Concatenates arrays into arrays
 | p | BOOL_AND(condition)                            | Synonym for `EVERY`
 | p | BOOL_OR(condition)                             | Synonym for `SOME`
+| b | COUNTIF(condition)                             | Returns the number of rows for which *condition* is TRUE; equivalent to `COUNT(*) FILTER (WHERE condition)`
 | b | LOGICAL_AND(condition)                         | Synonym for `EVERY`
 | b | LOGICAL_OR(condition)                          | Synonym for `SOME`
 | b p | STRING_AGG( [ ALL &#124; DISTINCT ] value [, separator] [ ORDER BY orderItem [, orderItem ]* ] ) | Synonym for `LISTAGG`