You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by go...@apache.org on 2023/01/18 14:14:40 UTC

[flink] branch master updated: [FLINK-29719][hive] Supports native count function for hive dialect

This is an automated email from the ASF dual-hosted git repository.

godfrey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 606f297198a [FLINK-29719][hive] Supports native count function for hive dialect
606f297198a is described below

commit 606f297198acd74a5c1a39700bd84ad9e26e7b82
Author: fengli <ld...@163.com>
AuthorDate: Wed Jan 4 15:31:03 2023 +0800

    [FLINK-29719][hive] Supports native count function for hive dialect
    
    This closes #21596
---
 .../table/functions/hive/HiveCountAggFunction.java | 116 +++++++++++++++++++++
 .../apache/flink/table/module/hive/HiveModule.java |   6 +-
 .../connectors/hive/HiveDialectAggITCase.java      |  72 +++++++++++--
 .../connectors/hive/HiveDialectQueryPlanTest.java  |  24 ++++-
 .../explain/testCountAggFunctionFallbackPlan.out   |  35 +++++++
 .../resources/explain/testCountAggFunctionPlan.out |  27 +++++
 6 files changed, 268 insertions(+), 12 deletions(-)

diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveCountAggFunction.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveCountAggFunction.java
new file mode 100644
index 00000000000..e15a0cbaf3d
--- /dev/null
+++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveCountAggFunction.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions.hive;
+
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.TableException;
+import org.apache.flink.table.expressions.Expression;
+import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
+import org.apache.flink.table.planner.expressions.ExpressionBuilder;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.CallContext;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
+import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse;
+import static org.apache.flink.table.planner.expressions.ExpressionBuilder.literal;
+import static org.apache.flink.table.planner.expressions.ExpressionBuilder.plus;
+
+/** Native implementation of Hive's built-in {@code count} aggregate function. */
+public class HiveCountAggFunction extends HiveDeclarativeAggregateFunction {
+
+    private final UnresolvedReferenceExpression count = unresolvedRef("count");
+    private Integer arguments; // number of call arguments, fixed by the first setArguments call
+    private boolean countLiteral; // true only for a single non-NULL literal argument
+
+    @Override
+    public int operandCount() {
+        return arguments;
+    }
+
+    @Override
+    public UnresolvedReferenceExpression[] aggBufferAttributes() {
+        return new UnresolvedReferenceExpression[] {count};
+    }
+
+    @Override
+    public DataType[] getAggBufferTypes() {
+        return new DataType[] {DataTypes.BIGINT()};
+    }
+
+    @Override
+    public DataType getResultType() {
+        return DataTypes.BIGINT();
+    }
+
+    @Override
+    public Expression[] initialValuesExpressions() {
+        return new Expression[] {/* count = */ literal(0L, getResultType().notNull())};
+    }
+
+    @Override
+    public Expression[] accumulateExpressions() {
+        // count(*) and count(non-NULL literal) count every row unconditionally
+        if (arguments == 0 || countLiteral) {
+            return new Expression[] {/* count = */ plus(count, literal(1L))};
+        }
+
+        // otherwise a row is counted only if none of its arguments is NULL
+        List<Expression> operandExpressions = new ArrayList<>();
+        for (int i = 0; i < arguments; i++) {
+            operandExpressions.add(operand(i));
+        }
+        Expression operandExpression =
+                operandExpressions.stream()
+                        .map(ExpressionBuilder::isNull)
+                        .reduce(ExpressionBuilder::or)
+                        .get();
+        return new Expression[] {
+            /* count = */ ifThenElse(operandExpression, count, plus(count, literal(1L)))
+        };
+    }
+
+    @Override
+    public Expression[] retractExpressions() {
+        throw new TableException("Count aggregate function does not support retraction.");
+    }
+
+    @Override
+    public Expression[] mergeExpressions() {
+        return new Expression[] {/* count = */ plus(count, mergeOperand(count))};
+    }
+
+    @Override
+    public Expression getValueExpression() {
+        return count;
+    }
+
+    @Override
+    public void setArguments(CallContext callContext) {
+        if (arguments == null) {
+            arguments = callContext.getArgumentDataTypes().size();
+            if (arguments == 1) {
+                // a non-NULL literal (e.g. count(1)) counts every row; count(NULL) counts none
+                countLiteral = callContext.isArgumentLiteral(0) && !callContext.isArgumentNull(0);
+            }
+        }
+    }
+}
diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java
index 6ca6ca84dd9..bb598891fef 100644
--- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java
+++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java
@@ -26,6 +26,7 @@ import org.apache.flink.table.catalog.hive.client.HiveShimLoader;
 import org.apache.flink.table.catalog.hive.factories.HiveFunctionDefinitionFactory;
 import org.apache.flink.table.factories.FunctionDefinitionFactory;
 import org.apache.flink.table.functions.FunctionDefinition;
+import org.apache.flink.table.functions.hive.HiveCountAggFunction;
 import org.apache.flink.table.functions.hive.HiveMinAggFunction;
 import org.apache.flink.table.functions.hive.HiveSumAggFunction;
 import org.apache.flink.table.module.Module;
@@ -86,7 +87,7 @@ public class HiveModule implements Module {
                                     "tumble_start")));
 
     static final Set<String> BUILTIN_NATIVE_AGG_FUNC =
-            Collections.unmodifiableSet(new HashSet<>(Arrays.asList("sum", "min")));
+            Collections.unmodifiableSet(new HashSet<>(Arrays.asList("sum", "count", "min")));
 
     private final HiveFunctionDefinitionFactory factory;
     private final String hiveVersion;
@@ -206,6 +207,9 @@ public class HiveModule implements Module {
             case "sum":
                 // We override Hive's sum function by native implementation to supports hash-agg
                 return Optional.of(new HiveSumAggFunction());
+            case "count":
+                // We override Hive's count function by native implementation to support hash-agg
+                return Optional.of(new HiveCountAggFunction());
             case "min":
                 // We override Hive's min function by native implementation to supports hash-agg
                 return Optional.of(new HiveMinAggFunction());
diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java
index bd124c7d554..3af77ad72b6 100644
--- a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java
+++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java
@@ -30,7 +30,6 @@ import org.apache.flink.util.CollectionUtil;
 
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
-import org.junit.Before;
 import org.junit.BeforeClass;
 import org.junit.ClassRule;
 import org.junit.Test;
@@ -58,6 +57,8 @@ public class HiveDialectAggITCase {
         hiveCatalog.getHiveConf().setVar(HiveConf.ConfVars.HIVE_QUOTEDID_SUPPORT, "none");
         hiveCatalog.open();
         tableEnv = getTableEnvWithHiveCatalog();
+        // enable native hive agg function
+        tableEnv.getConfig().set(TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED, true);
 
         // create tables
         tableEnv.executeSql("create table foo (x int, y int)");
@@ -71,12 +72,6 @@ public class HiveDialectAggITCase {
                 .commit();
     }
 
-    @Before
-    public void before() {
-        // enable native hive agg function
-        tableEnv.getConfig().set(TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED, true);
-    }
-
     @Test
     public void testSimpleSumAggFunction() throws Exception {
         tableEnv.executeSql(
@@ -167,6 +162,69 @@ public class HiveDialectAggITCase {
         tableEnv.executeSql("drop table test_sum_group");
     }
 
+    @Test
+    public void testSimpleCount() throws Exception {
+        tableEnv.executeSql("create table test_count(a int, x string, y string, z int, d bigint)");
+        tableEnv.executeSql(
+                        "insert into test_count values (1, NULL, '2', 1, 2), "
+                                + "(1, NULL, 'b', 2, NULL), "
+                                + "(2, NULL, '4', 1, 2), "
+                                + "(2, NULL, NULL, 4, 3)")
+                .await();
+
+        // count(*) counts all 4 rows, including rows containing NULLs
+        List<Row> result =
+                CollectionUtil.iteratorToList(
+                        tableEnv.executeSql("select count(*) from test_count").collect());
+        assertThat(result.toString()).isEqualTo("[+I[4]]");
+
+        // count(1) (a non-NULL literal) also counts all 4 rows
+        List<Row> result2 =
+                CollectionUtil.iteratorToList(
+                        tableEnv.executeSql("select count(1) from test_count").collect());
+        assertThat(result2.toString()).isEqualTo("[+I[4]]");
+
+        // count(y) skips the single row where y is NULL
+        List<Row> result3 =
+                CollectionUtil.iteratorToList(
+                        tableEnv.executeSql("select count(y) from test_count").collect());
+        assertThat(result3.toString()).isEqualTo("[+I[3]]");
+
+        // count(distinct z): z takes the distinct values 1, 2 and 4
+        List<Row> result4 =
+                CollectionUtil.iteratorToList(
+                        tableEnv.executeSql("select count(distinct z) from test_count").collect());
+        assertThat(result4.toString()).isEqualTo("[+I[3]]");
+
+        // count(distinct z, d) skips (2, NULL), leaving the pairs (1, 2) and (4, 3)
+        List<Row> result5 =
+                CollectionUtil.iteratorToList(
+                        tableEnv.executeSql("select count(distinct z, d) from test_count")
+                                .collect());
+        assertThat(result5.toString()).isEqualTo("[+I[2]]");
+
+        tableEnv.executeSql("drop table test_count");
+    }
+
+    @Test
+    public void testCountAggWithGroupKey() throws Exception {
+        tableEnv.executeSql(
+                "create table test_count_group(a int, x string, y string, z int, d bigint)");
+        tableEnv.executeSql(
+                        "insert into test_count_group values (1, NULL, '2', 1, 2), "
+                                + "(1, NULL, '2', 2, NULL), "
+                                + "(2, NULL, '4', 1, 2), "
+                                + "(2, NULL, 3, 4, 3)")
+                .await();
+
+        // x is always NULL, so count(x) is 0; count(distinct z, d) skips the (2, NULL) row
+        String sql =
+                "select count(*), count(x), count(distinct y), count(distinct z, d) from test_count_group group by a";
+        List<Row> result = CollectionUtil.iteratorToList(tableEnv.executeSql(sql).collect());
+        assertThat(result.toString()).isEqualTo("[+I[2, 0, 1, 1], +I[2, 0, 2, 2]]");
+        tableEnv.executeSql("drop table test_count_group");
+    }
+
     @Test
     public void testMinAggFunction() throws Exception {
         tableEnv.executeSql(
diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryPlanTest.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryPlanTest.java
index 69b8ca9179f..48cc8b913d1 100644
--- a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryPlanTest.java
+++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryPlanTest.java
@@ -70,25 +70,41 @@ public class HiveDialectQueryPlanTest {
     @Test
     public void testSumAggFunctionPlan() {
         // test explain
-        String actualPlan = explainSql("select x, sum(y) from foo group by x");
+        String sql = "select x, sum(y) from foo group by x";
+        String actualPlan = explainSql(sql);
         assertThat(actualPlan).isEqualTo(readFromResource("/explain/testSumAggFunctionPlan.out"));
 
         // test fallback to hive sum udaf
         tableEnv.getConfig().set(TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED, false);
-        String actualSortAggPlan = explainSql("select x, sum(y) from foo group by x");
+        String actualSortAggPlan = explainSql(sql);
         assertThat(actualSortAggPlan)
                 .isEqualTo(readFromResource("/explain/testSumAggFunctionFallbackPlan.out"));
     }
 
+    @Test
+    public void testCountAggFunctionPlan() {
+        // with native agg functions on, count(*)/count(y)/count(distinct y) plan to HashAggregate
+        String sql = "select x, count(*), count(y), count(distinct y) from foo group by x";
+        String actualPlan = explainSql(sql);
+        assertThat(actualPlan).isEqualTo(readFromResource("/explain/testCountAggFunctionPlan.out"));
+
+        // with native agg functions off, count falls back to the Hive UDAF and SortAggregate
+        tableEnv.getConfig().set(TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED, false);
+        String actualSortAggPlan = explainSql(sql);
+        assertThat(actualSortAggPlan)
+                .isEqualTo(readFromResource("/explain/testCountAggFunctionFallbackPlan.out"));
+    }
+
     @Test
     public void testMinAggFunctionPlan() {
         // test explain
-        String actualPlan = explainSql("select x, min(y) from foo group by x");
+        String sql = "select x, min(y) from foo group by x";
+        String actualPlan = explainSql(sql);
         assertThat(actualPlan).isEqualTo(readFromResource("/explain/testMinAggFunctionPlan.out"));
 
         // test fallback to hive min udaf
         tableEnv.getConfig().set(TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED, false);
-        String actualSortAggPlan = explainSql("select x, min(y) from foo group by x");
+        String actualSortAggPlan = explainSql(sql);
         assertThat(actualSortAggPlan)
                 .isEqualTo(readFromResource("/explain/testMinAggFunctionFallbackPlan.out"));
     }
diff --git a/flink-connectors/flink-connector-hive/src/test/resources/explain/testCountAggFunctionFallbackPlan.out b/flink-connectors/flink-connector-hive/src/test/resources/explain/testCountAggFunctionFallbackPlan.out
new file mode 100644
index 00000000000..e356f402212
--- /dev/null
+++ b/flink-connectors/flink-connector-hive/src/test/resources/explain/testCountAggFunctionFallbackPlan.out
@@ -0,0 +1,35 @@
+== Abstract Syntax Tree ==
+LogicalProject(x=[$0], _o__c1=[$1], _o__c2=[$2], _o__c3=[$3])
++- LogicalAggregate(group=[{0}], agg#0=[count()], agg#1=[count($1)], agg#2=[count(DISTINCT $1)])
+   +- LogicalProject($f0=[$0], $f1=[$1])
+      +- LogicalTableScan(table=[[test-catalog, default, foo]])
+
+== Optimized Physical Plan ==
+SortAggregate(isMerge=[true], groupBy=[x], select=[x, Final_MIN(min$0) AS $f1, Final_MIN(min$1) AS $f2, Final_count($f3) AS $f3])
++- Sort(orderBy=[x ASC])
+   +- Exchange(distribution=[hash[x]])
+      +- LocalSortAggregate(groupBy=[x], select=[x, Partial_MIN($f1) FILTER $g_1 AS min$0, Partial_MIN($f2) FILTER $g_1 AS min$1, Partial_count(y) FILTER $g_0 AS $f3])
+         +- Calc(select=[x, y, $f1, $f2, =(CASE(=($e, 0), 0, 1), 0) AS $g_0, =(CASE(=($e, 0), 0, 1), 1) AS $g_1])
+            +- Sort(orderBy=[x ASC])
+               +- SortAggregate(isMerge=[true], groupBy=[x, y, $e], select=[x, y, $e, Final_count($f1) AS $f1, Final_count($f2) AS $f2])
+                  +- Sort(orderBy=[x ASC, y ASC, $e ASC])
+                     +- Exchange(distribution=[hash[x, y, $e]])
+                        +- LocalSortAggregate(groupBy=[x, y, $e], select=[x, y, $e, Partial_count(*) AS $f1, Partial_count(y_0) AS $f2])
+                           +- Sort(orderBy=[x ASC, y ASC, $e ASC])
+                              +- Expand(projects=[{x, y, 0 AS $e, y AS y_0}, {x, null AS y, 1 AS $e, y AS y_0}])
+                                 +- TableSourceScan(table=[[test-catalog, default, foo]], fields=[x, y])
+
+== Optimized Execution Plan ==
+SortAggregate(isMerge=[true], groupBy=[x], select=[x, Final_MIN(min$0) AS $f1, Final_MIN(min$1) AS $f2, Final_count($f3) AS $f3])
++- Sort(orderBy=[x ASC])
+   +- Exchange(distribution=[hash[x]])
+      +- LocalSortAggregate(groupBy=[x], select=[x, Partial_MIN($f1) FILTER $g_1 AS min$0, Partial_MIN($f2) FILTER $g_1 AS min$1, Partial_count(y) FILTER $g_0 AS $f3])
+         +- Calc(select=[x, y, $f1, $f2, (CASE(($e = 0), 0, 1) = 0) AS $g_0, (CASE(($e = 0), 0, 1) = 1) AS $g_1])
+            +- Sort(orderBy=[x ASC])
+               +- SortAggregate(isMerge=[true], groupBy=[x, y, $e], select=[x, y, $e, Final_count($f1) AS $f1, Final_count($f2) AS $f2])
+                  +- Sort(orderBy=[x ASC, y ASC, $e ASC])
+                     +- Exchange(distribution=[hash[x, y, $e]])
+                        +- LocalSortAggregate(groupBy=[x, y, $e], select=[x, y, $e, Partial_count(*) AS $f1, Partial_count(y_0) AS $f2])
+                           +- Sort(orderBy=[x ASC, y ASC, $e ASC])
+                              +- Expand(projects=[{x, y, 0 AS $e, y AS y_0}, {x, null AS y, 1 AS $e, y AS y_0}])
+                                 +- TableSourceScan(table=[[test-catalog, default, foo]], fields=[x, y])
diff --git a/flink-connectors/flink-connector-hive/src/test/resources/explain/testCountAggFunctionPlan.out b/flink-connectors/flink-connector-hive/src/test/resources/explain/testCountAggFunctionPlan.out
new file mode 100644
index 00000000000..fc6e8b6d8cb
--- /dev/null
+++ b/flink-connectors/flink-connector-hive/src/test/resources/explain/testCountAggFunctionPlan.out
@@ -0,0 +1,27 @@
+== Abstract Syntax Tree ==
+LogicalProject(x=[$0], _o__c1=[$1], _o__c2=[$2], _o__c3=[$3])
++- LogicalAggregate(group=[{0}], agg#0=[count()], agg#1=[count($1)], agg#2=[count(DISTINCT $1)])
+   +- LogicalProject($f0=[$0], $f1=[$1])
+      +- LogicalTableScan(table=[[test-catalog, default, foo]])
+
+== Optimized Physical Plan ==
+HashAggregate(isMerge=[true], groupBy=[x], select=[x, Final_MIN(min$0) AS $f1, Final_MIN(min$1) AS $f2, Final_count(count$2) AS $f3])
++- Exchange(distribution=[hash[x]])
+   +- LocalHashAggregate(groupBy=[x], select=[x, Partial_MIN($f1) FILTER $g_1 AS min$0, Partial_MIN($f2) FILTER $g_1 AS min$1, Partial_count(y) FILTER $g_0 AS count$2])
+      +- Calc(select=[x, y, $f1, $f2, =(CASE(=($e, 0), 0, 1), 0) AS $g_0, =(CASE(=($e, 0), 0, 1), 1) AS $g_1])
+         +- HashAggregate(isMerge=[true], groupBy=[x, y, $e], select=[x, y, $e, Final_count(count$0) AS $f1, Final_count(count$1) AS $f2])
+            +- Exchange(distribution=[hash[x, y, $e]])
+               +- LocalHashAggregate(groupBy=[x, y, $e], select=[x, y, $e, Partial_count(*) AS count$0, Partial_count(y_0) AS count$1])
+                  +- Expand(projects=[{x, y, 0 AS $e, y AS y_0}, {x, null AS y, 1 AS $e, y AS y_0}])
+                     +- TableSourceScan(table=[[test-catalog, default, foo]], fields=[x, y])
+
+== Optimized Execution Plan ==
+HashAggregate(isMerge=[true], groupBy=[x], select=[x, Final_MIN(min$0) AS $f1, Final_MIN(min$1) AS $f2, Final_count(count$2) AS $f3])
++- Exchange(distribution=[hash[x]])
+   +- LocalHashAggregate(groupBy=[x], select=[x, Partial_MIN($f1) FILTER $g_1 AS min$0, Partial_MIN($f2) FILTER $g_1 AS min$1, Partial_count(y) FILTER $g_0 AS count$2])
+      +- Calc(select=[x, y, $f1, $f2, (CASE(($e = 0), 0, 1) = 0) AS $g_0, (CASE(($e = 0), 0, 1) = 1) AS $g_1])
+         +- HashAggregate(isMerge=[true], groupBy=[x, y, $e], select=[x, y, $e, Final_count(count$0) AS $f1, Final_count(count$1) AS $f2])
+            +- Exchange(distribution=[hash[x, y, $e]])
+               +- LocalHashAggregate(groupBy=[x, y, $e], select=[x, y, $e, Partial_count(*) AS count$0, Partial_count(y_0) AS count$1])
+                  +- Expand(projects=[{x, y, 0 AS $e, y AS y_0}, {x, null AS y, 1 AS $e, y AS y_0}])
+                     +- TableSourceScan(table=[[test-catalog, default, foo]], fields=[x, y])