You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@calcite.apache.org by jh...@apache.org on 2020/11/01 17:46:18 UTC
[calcite] 03/04: [CALCITE-4369] Support COUNTIF aggregate function
for BigQuery (Aryeh Hillman)
This is an automated email from the ASF dual-hosted git repository.
jhyde pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git
commit a5801bed0d0d74ad12b6742d64ffb07b4a05f674
Author: Aryeh Hillman <ar...@google.com>
AuthorDate: Mon Oct 12 17:29:57 2020 -0700
[CALCITE-4369] Support COUNTIF aggregate function for BigQuery (Aryeh Hillman)
In the SQL reference, move COUNTIF to the list of dialect-specific
functions; during SQL-to-Rel, transform to 'COUNT(*) FILTER
(WHERE b)' rather than 'COUNT(CASE WHEN b THEN 1 END)' (Julian Hyde).
Close apache/calcite#2235
---
.../main/java/org/apache/calcite/sql/SqlKind.java | 5 ++-
.../main/java/org/apache/calcite/sql/SqlUtil.java | 12 +++++--
.../calcite/sql/fun/SqlLibraryOperators.java | 12 +++++++
.../apache/calcite/sql2rel/SqlToRelConverter.java | 15 +++++---
.../calcite/materialize/LatticeSuggesterTest.java | 4 ++-
.../calcite/sql/test/SqlOperatorBaseTest.java | 27 +++++++++++++++
core/src/test/resources/sql/agg.iq | 40 ++++++++++++++++++++++
site/_docs/reference.md | 1 +
8 files changed, 107 insertions(+), 9 deletions(-)
diff --git a/core/src/main/java/org/apache/calcite/sql/SqlKind.java b/core/src/main/java/org/apache/calcite/sql/SqlKind.java
index e76727b..aa1332e 100644
--- a/core/src/main/java/org/apache/calcite/sql/SqlKind.java
+++ b/core/src/main/java/org/apache/calcite/sql/SqlKind.java
@@ -805,6 +805,9 @@ public enum SqlKind {
/** The {@code STRING_AGG} aggregate function. */
STRING_AGG,
+ /** The {@code COUNTIF} aggregate function. */
+ COUNTIF,
+
/** The {@code ARRAY_AGG} aggregate function. */
ARRAY_AGG,
@@ -1040,7 +1043,7 @@ public enum SqlKind {
AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP, NTILE, COLLECT,
FUSION, SINGLE_VALUE, ROW_NUMBER, RANK, PERCENT_RANK, DENSE_RANK,
CUME_DIST, JSON_ARRAYAGG, JSON_OBJECTAGG, BIT_AND, BIT_OR, BIT_XOR,
- LISTAGG, STRING_AGG, ARRAY_AGG, ARRAY_CONCAT_AGG,
+ LISTAGG, STRING_AGG, ARRAY_AGG, ARRAY_CONCAT_AGG, COUNTIF,
INTERSECTION, ANY_VALUE);
/**
diff --git a/core/src/main/java/org/apache/calcite/sql/SqlUtil.java b/core/src/main/java/org/apache/calcite/sql/SqlUtil.java
index 1176fbd..9f9175d 100644
--- a/core/src/main/java/org/apache/calcite/sql/SqlUtil.java
+++ b/core/src/main/java/org/apache/calcite/sql/SqlUtil.java
@@ -66,6 +66,8 @@ import java.util.Map;
import java.util.Objects;
import java.util.function.Predicate;
import java.util.stream.Collectors;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
import static org.apache.calcite.util.Static.RESOURCE;
@@ -75,9 +77,13 @@ import static org.apache.calcite.util.Static.RESOURCE;
public abstract class SqlUtil {
//~ Methods ----------------------------------------------------------------
- static SqlNode andExpressions(
- SqlNode node1,
- SqlNode node2) {
+ /** Returns the AND of two expressions.
+ *
+ * <p>If {@code node1} is null, returns {@code node2}.
+ * Flattens if either node is an AND. */
+ public static @Nonnull SqlNode andExpressions(
+ @Nullable SqlNode node1,
+ @Nonnull SqlNode node2) {
if (node1 == null) {
return node2;
}
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
index 9cb9e0f..8ff1b46 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
@@ -35,6 +35,7 @@ import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeTransforms;
import org.apache.calcite.util.Litmus;
+import org.apache.calcite.util.Optionality;
import com.google.common.collect.ImmutableList;
@@ -269,6 +270,17 @@ public abstract class SqlLibraryOperators {
public static final SqlAggFunction LOGICAL_OR =
new SqlMinMaxAggFunction("LOGICAL_OR", SqlKind.MAX, OperandTypes.BOOLEAN);
+ /** The "COUNTIF(condition) [OVER (...)]" function, in BigQuery,
+ * returns the number of rows for which {@code condition} is TRUE.
+ *
+ * <p>{@code COUNTIF(b)} is equivalent to
+ * {@code COUNT(*) FILTER (WHERE b)}. */
+ @LibraryOperator(libraries = {BIG_QUERY})
+ public static final SqlAggFunction COUNTIF =
+ SqlBasicAggFunction
+ .create(SqlKind.COUNTIF, ReturnTypes.BIGINT, OperandTypes.BOOLEAN)
+ .withDistinct(Optionality.FORBIDDEN);
+
/** The "ARRAY_AGG(value [ ORDER BY ...])" aggregate function,
* in BigQuery and PostgreSQL, gathers values into arrays. */
@LibraryOperator(libraries = {POSTGRESQL, BIG_QUERY})
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
index fab36f2..39ca7dd 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
@@ -5437,6 +5437,7 @@ public class SqlToRelConverter {
assert bb.agg == this;
assert outerCall != null;
final List<SqlNode> operands = call.getOperandList();
+ final SqlParserPos pos = call.getParserPosition();
switch (call.getKind()) {
case FILTER:
assert filter == null;
@@ -5455,6 +5456,14 @@ public class SqlToRelConverter {
translateAgg(call.operand(0), filter, orderList, ignoreNulls,
outerCall);
return;
+ case COUNTIF:
+ // COUNTIF(b) ==> COUNT(*) FILTER (WHERE b)
+ // COUNTIF(b) FILTER (WHERE b2) ==> COUNT(*) FILTER (WHERE b2 AND b)
+ final SqlCall call4 =
+ SqlStdOperatorTable.COUNT.createCall(pos, SqlIdentifier.star(pos));
+ final SqlNode filter2 = SqlUtil.andExpressions(filter, call.operand(0));
+ translateAgg(call4, filter2, orderList, ignoreNulls, outerCall);
+ return;
case STRING_AGG:
// Translate "STRING_AGG(s, sep ORDER BY x, y)"
// as if it were "LISTAGG(s, sep) WITHIN GROUP (ORDER BY x, y)";
@@ -5469,8 +5478,7 @@ public class SqlToRelConverter {
}
final SqlCall call2 =
SqlStdOperatorTable.LISTAGG.createCall(
- call.getFunctionQuantifier(), call.getParserPosition(),
- operands2);
+ call.getFunctionQuantifier(), pos, operands2);
translateAgg(call2, filter, orderList, ignoreNulls, outerCall);
return;
case ARRAY_AGG:
@@ -5483,8 +5491,7 @@ public class SqlToRelConverter {
orderList = (SqlNodeList) Util.last(operands);
final SqlCall call3 =
call.getOperator().createCall(
- call.getFunctionQuantifier(), call.getParserPosition(),
- Util.skipLast(operands));
+ call.getFunctionQuantifier(), pos, Util.skipLast(operands));
translateAgg(call3, filter, orderList, ignoreNulls, outerCall);
return;
}
diff --git a/core/src/test/java/org/apache/calcite/materialize/LatticeSuggesterTest.java b/core/src/test/java/org/apache/calcite/materialize/LatticeSuggesterTest.java
index 2fb2e4e..5ff1e95 100644
--- a/core/src/test/java/org/apache/calcite/materialize/LatticeSuggesterTest.java
+++ b/core/src/test/java/org/apache/calcite/materialize/LatticeSuggesterTest.java
@@ -620,13 +620,15 @@ class LatticeSuggesterTest {
}
/** Tests a number of features only available in BigQuery: back-ticks;
- * GROUP BY ordinal; case-insensitive unquoted identifiers. */
+ * GROUP BY ordinal; case-insensitive unquoted identifiers;
+ * the {@code COUNTIF} aggregate function. */
@Test void testBigQueryDialect() throws Exception {
final Tester t = new Tester().foodmart().withEvolve(true)
.withDialect(SqlDialect.DatabaseProduct.BIG_QUERY.getDialect())
.withLibrary(SqlLibrary.BIG_QUERY);
final String q0 = "select `product_id`,\n"
+ + " countif(unit_sales > 1000) as num_over_thousand,\n"
+ " SUM(unit_sales)\n"
+ "from\n"
+ " `sales_fact_1997`"
diff --git a/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java b/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java
index 7aeda68..827ec84 100644
--- a/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java
+++ b/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java
@@ -8602,6 +8602,33 @@ public abstract class SqlOperatorBaseTest {
tester.checkAgg("COUNT(DISTINCT 123)", stringValues, 1, 0d);
}
+ @Test void testCountifFunc() {
+ tester.setFor(SqlLibraryOperators.COUNTIF, VM_FENNEL, VM_JAVA);
+ final SqlTester tester = libraryTester(SqlLibrary.BIG_QUERY);
+ tester.checkType("countif(true)", "BIGINT NOT NULL");
+ tester.checkType("countif(nullif(true,true))", "BIGINT NOT NULL");
+ tester.checkType("countif(false) filter (where true)", "BIGINT NOT NULL");
+
+ final String expectedError = "Invalid number of arguments to function "
+ + "'COUNTIF'. Was expecting 1 arguments";
+ tester.checkFails("^COUNTIF()^", expectedError, false);
+ tester.checkFails("^COUNTIF(true, false)^", expectedError, false);
+ final String expectedError2 = "Cannot apply 'COUNTIF' to arguments of "
+ + "type 'COUNTIF\\(<INTEGER>\\)'\\. Supported form\\(s\\): "
+ + "'COUNTIF\\(<BOOLEAN>\\)'";
+ tester.checkFails("^COUNTIF(1)^", expectedError2, false);
+
+ final String[] values = {"1", "2", "CAST(NULL AS INTEGER)", "1"};
+ tester.checkAgg("countif(x > 0)", values, 3, 0d);
+ tester.checkAgg("countif(x < 2)", values, 2, 0d);
+ tester.checkAgg("countif(x is not null) filter (where x < 2)",
+ values, 2, 0d);
+ tester.checkAgg("countif(x < 2) filter (where x is not null)",
+ values, 2, 0d);
+ tester.checkAgg("countif(x between 1 and 2)", values, 3, 0d);
+ tester.checkAgg("countif(x < 0)", values, 0, 0d);
+ }
+
@Test void testApproxCountDistinctFunc() {
tester.setFor(SqlStdOperatorTable.COUNT, VM_EXPAND);
tester.checkFails("approx_count_distinct(^*^)", "Unknown identifier '\\*'",
diff --git a/core/src/test/resources/sql/agg.iq b/core/src/test/resources/sql/agg.iq
index 25e627b..2dba1c3 100644
--- a/core/src/test/resources/sql/agg.iq
+++ b/core/src/test/resources/sql/agg.iq
@@ -2890,4 +2890,44 @@ from emp group by gender;
!ok
+# COUNTIF(b) (BigQuery) is equivalent to COUNT(*) FILTER (WHERE b)
+select deptno, countif(gender = 'F') as f
+from emp
+group by deptno;
++--------+---+
+| DEPTNO | F |
++--------+---+
+| 10 | 1 |
+| 20 | 0 |
+| 30 | 2 |
+| 50 | 1 |
+| 60 | 1 |
+| | 1 |
++--------+---+
+(6 rows)
+
+!ok
+
+select countif(gender = 'F') filter (where deptno = 30) as f
+from emp;
++---+
+| F |
++---+
+| 2 |
++---+
+(1 row)
+
+!ok
+
+select countif(a > 0) + countif(a > 1) + countif(c > 1) as c
+from (select 1 as a, 2 as b, 3 as c);
++---+
+| C |
++---+
+| 2 |
++---+
+(1 row)
+
+!ok
+
# End agg.iq
diff --git a/site/_docs/reference.md b/site/_docs/reference.md
index 3f995e7..260f8a3 100644
--- a/site/_docs/reference.md
+++ b/site/_docs/reference.md
@@ -2522,6 +2522,7 @@ Dialect-specific aggregate functions.
| b p | ARRAY_CONCAT_AGG( [ ALL | DISTINCT ] value [ ORDER BY orderItem [, orderItem ]* ] ) | Concatenates arrays into arrays
| p | BOOL_AND(condition) | Synonym for `EVERY`
| p | BOOL_OR(condition) | Synonym for `SOME`
+| b | COUNTIF(condition) | Returns the number of rows for which *condition* is TRUE; equivalent to `COUNT(*) FILTER (WHERE condition)`
| b | LOGICAL_AND(condition) | Synonym for `EVERY`
| b | LOGICAL_OR(condition) | Synonym for `SOME`
| b p | STRING_AGG( [ ALL | DISTINCT ] value [, separator] [ ORDER BY orderItem [, orderItem ]* ] ) | Synonym for `LISTAGG`