Posted to commits@flink.apache.org by fh...@apache.org on 2018/12/13 04:16:51 UTC

[flink] 01/03: [FLINK-11136] [table] Fix the merge logic of DISTINCT aggregates.

This is an automated email from the ASF dual-hosted git repository.

fhueske pushed a commit to branch release-1.7
in repository https://gitbox.apache.org/repos/asf/flink.git

commit ff1821a6d2f8317d0c344719b14350ac362143d9
Author: Dian Fu <fu...@alibaba-inc.com>
AuthorDate: Mon Dec 10 21:33:02 2018 +0800

    [FLINK-11136] [table] Fix the merge logic of DISTINCT aggregates.
    
    This closes #7284.
---
 .../flink/table/codegen/AggregationCodeGenerator.scala  | 17 ++++++++++++++++-
 .../flink/table/runtime/stream/sql/SqlITCase.scala      |  7 ++++---
 2 files changed, 20 insertions(+), 4 deletions(-)
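Before this change, the generated merge() code for DISTINCT aggregates handed the whole distinct-key Row `k` to accumulate() as a single argument, even though accumulate() expects the individual argument values. The new parametersCodeForDistinctMerge builds, per aggregate, the comma-separated argument expressions: constant arguments keep their generated constant term, and input-field arguments are read back out of the key Row with a boxed cast. A minimal Scala sketch of that unpacking idea (the name unpackDistinctKey is an assumption for illustration; the real generator emits Java source strings instead of calling a helper):

    import org.apache.flink.types.Row

    // Read each argument of one aggregate back out of its distinct-key Row,
    // one field per argument, in the order the arguments were keyed.
    def unpackDistinctKey(k: Row): Seq[AnyRef] =
      (0 until k.getArity).map(i => k.getField(i))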

diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
index 566e3d7..57cc815 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
@@ -142,6 +142,21 @@ class AggregationCodeGenerator(
       fields.mkString(", ")
     }
 
+    val parametersCodeForDistinctMerge = aggFields.map { inFields =>
+      val fields = inFields.filter(_ > -1).zipWithIndex.map { case (f, i) =>
+        // index to constant
+        if (f >= physicalInputTypes.length) {
+          constantFields(f - physicalInputTypes.length)
+        }
+        // index to input field
+        else {
+          s"(${CodeGenUtils.boxedTypeTermForTypeInfo(physicalInputTypes(f))}) k.getField($i)"
+        }
+      }
+
+      fields.mkString(", ")
+    }
+
     // get method signatures
     val classes = UserDefinedFunctionUtils.typeInfoToClass(physicalInputTypes)
     val constantClasses = UserDefinedFunctionUtils.typeInfoToClass(constantTypes)
@@ -643,7 +658,7 @@ class AggregationCodeGenerator(
                |          (${classOf[Row].getCanonicalName}) entry.getKey();
                |      Long v = (Long) entry.getValue();
                |      if (aDistinctAcc$i.add(k, v)) {
-               |        ${aggs(i)}.accumulate(aAcc$i, k);
+               |        ${aggs(i)}.accumulate(aAcc$i, ${parametersCodeForDistinctMerge(i)});
                |      }
                |    }
                |    a.setField($i, aDistinctAcc$i);
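For a single DISTINCT argument mapping to an input column of type Long (an assumed example; the generated member names below are placeholders, and the exact cast term comes from CodeGenUtils.boxedTypeTermForTypeInfo), the change to the generated accumulate call amounts to:

    // before: the whole distinct-key Row was handed to the aggregate
    //   agg0.accumulate(aAcc0, k);
    // after: the key Row is unpacked into the boxed argument value
    //   agg0.accumulate(aAcc0, (java.lang.Long) k.getField(0));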
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
index 46dde8e..ddc2a68 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
@@ -78,6 +78,7 @@ class SqlITCase extends StreamingWithStateTestBase {
 
     val sqlQuery = "SELECT c, " +
       "  COUNT(DISTINCT b)," +
+      "  SUM(DISTINCT b)," +
       "  SESSION_END(rowtime, INTERVAL '0.005' SECOND) " +
       "FROM MyTable " +
       "GROUP BY SESSION(rowtime, INTERVAL '0.005' SECOND), c "
@@ -87,9 +88,9 @@ class SqlITCase extends StreamingWithStateTestBase {
     env.execute()
 
     val expected = Seq(
-      "Hello World,1,1970-01-01 00:00:00.014", // window starts at [9L] till {14L}
-      "Hello,1,1970-01-01 00:00:00.021",       // window starts at [16L] till {21L}, not merged
-      "Hello,3,1970-01-01 00:00:00.015"        // window starts at [1L,2L],
+      "Hello World,1,9,1970-01-01 00:00:00.014", // window starts at [9L] till {14L}
+      "Hello,1,16,1970-01-01 00:00:00.021",       // window starts at [16L] till {21L}, not merged
+      "Hello,3,6,1970-01-01 00:00:00.015"        // window starts at [1L,2L],
                                                //   merged with [8L,10L], by [4L], till {15L}
     )
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
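The updated test adds SUM(DISTINCT b) next to COUNT(DISTINCT b), presumably because SUM actually consumes the argument value, so the wrongly typed argument in the old generated merge path surfaces as a failure when session windows merge. Pulled out of the test for readability, the query under test reads (table and column names as in the test's MyTable source):

    val sqlQuery =
      """
        |SELECT c,
        |  COUNT(DISTINCT b),
        |  SUM(DISTINCT b),
        |  SESSION_END(rowtime, INTERVAL '0.005' SECOND)
        |FROM MyTable
        |GROUP BY SESSION(rowtime, INTERVAL '0.005' SECOND), c
      """.stripMargin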