You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tw...@apache.org on 2017/11/15 15:18:31 UTC

flink git commit: [FLINK-8013] [table] Support aggregate functions with generic arrays

Repository: flink
Updated Branches:
  refs/heads/master 9d3471574 -> 11218a35d


[FLINK-8013] [table] Support aggregate functions with generic arrays

This closes #5011.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/11218a35
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/11218a35
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/11218a35

Branch: refs/heads/master
Commit: 11218a35dc0fdd7439142a313e6628c51cffe689
Parents: 9d34715
Author: twalthr <tw...@apache.org>
Authored: Tue Nov 14 11:06:54 2017 +0100
Committer: twalthr <tw...@apache.org>
Committed: Wed Nov 15 15:11:39 2017 +0100

----------------------------------------------------------------------
 .../api/java/typeutils/TypeExtractionUtils.java |  16 +++
 .../codegen/AggregationCodeGenerator.scala      |  20 ++-
 .../flink/table/expressions/aggregations.scala  |   6 +-
 .../utils/UserDefinedFunctionUtils.scala        |   9 +-
 .../runtime/batch/table/AggregateITCase.scala   |  25 +++-
 .../flink/table/utils/TableTestBase.scala       |   2 +-
 .../table/utils/UserDefinedAggFunctions.scala   | 126 +++++++++++++++++++
 .../api/scala/util/CollectionDataSets.scala     |   1 -
 8 files changed, 185 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java
index c5c2565..56fcf82 100644
--- a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java
+++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java
@@ -18,7 +18,9 @@
 
 package org.apache.flink.api.java.typeutils;
 
+import java.lang.reflect.Array;
 import java.lang.reflect.Constructor;
+import java.lang.reflect.GenericArrayType;
 import java.lang.reflect.Method;
 import java.lang.reflect.Modifier;
 import java.lang.reflect.ParameterizedType;
@@ -322,4 +324,18 @@ public class TypeExtractionUtils {
 		}
 		return false;
 	}
+
+	/**
+	 * Returns the raw class of a type accepted by {@link #isClassType(Type)} (e.g. a parameterized
+	 * type) or of a generic array (component type resolved recursively); Object for anything else.
+	 */
+	public static Class<?> getRawClass(Type t) {
+		if (isClassType(t)) {
+			return typeToClass(t);
+		} else if (!(t instanceof GenericArrayType)) {
+			return Object.class;
+		}
+		final Type componentType = ((GenericArrayType) t).getGenericComponentType();
+		return Array.newInstance(getRawClass(componentType), 0).getClass();
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
index c85b111..32cbde2 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
@@ -17,17 +17,18 @@
  */
 package org.apache.flink.table.codegen
 
-import java.lang.reflect.{Modifier, ParameterizedType}
+import java.lang.reflect.Modifier
 import java.lang.{Iterable => JIterable}
 
 import org.apache.calcite.rex.RexLiteral
 import org.apache.commons.codec.binary.Base64
 import org.apache.flink.api.common.state.{State, StateDescriptor}
 import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.TypeExtractionUtils.{extractTypeArgument, getRawClass}
 import org.apache.flink.table.api.TableConfig
 import org.apache.flink.table.api.dataview._
-import org.apache.flink.table.codegen.Indenter.toISC
 import org.apache.flink.table.codegen.CodeGenUtils.{newName, reflectiveFieldWriteAccess}
+import org.apache.flink.table.codegen.Indenter.toISC
 import org.apache.flink.table.functions.AggregateFunction
 import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
 import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getUserDefinedMethod, signatureToString}
@@ -175,7 +176,7 @@ class AggregationCodeGenerator(
         }
 
         if (needMerge) {
-          val methods =
+          val method =
             getUserDefinedMethod(a, "merge", Array(accTypeClasses(i), classOf[JIterable[Any]]))
             .getOrElse(
               throw new CodeGenException(
@@ -183,17 +184,14 @@ class AggregationCodeGenerator(
                   s"${a.getClass.getCanonicalName}'.")
             )
 
-          var iterableTypeClass = methods.getGenericParameterTypes.apply(1)
-                                  .asInstanceOf[ParameterizedType].getActualTypeArguments.apply(0)
-          // further extract iterableTypeClass if the accumulator has generic type
-          iterableTypeClass match {
-            case impl: ParameterizedType => iterableTypeClass = impl.getRawType
-            case _ =>
-          }
+          // use the TypeExtractionUtils here to support nested GenericArrayTypes and
+          // other complex types
+          val iterableGenericType = extractTypeArgument(method.getGenericParameterTypes()(1), 0)
+          val iterableTypeClass = getRawClass(iterableGenericType)
 
           if (iterableTypeClass != accTypeClasses(i)) {
             throw new CodeGenException(
-              s"merge method in AggregateFunction ${a.getClass.getCanonicalName} does not have " +
+              s"Merge method in AggregateFunction ${a.getClass.getCanonicalName} does not have " +
                 s"the correct Iterable type. Actually: ${iterableTypeClass.toString}. " +
                 s"Expected: ${accTypeClasses(i).toString}")
           }

http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
index 1ffcb12..3adaaa9 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
@@ -245,8 +245,10 @@ case class AggFunctionCall(
       ValidationFailure(s"Given parameters do not match any signature. \n" +
                           s"Actual: ${signatureToString(signature)} \n" +
                           s"Expected: ${
-                            getMethodSignatures(aggregateFunction, "accumulate").drop(1)
-                              .map(signatureToString).mkString(", ")}")
+                            getMethodSignatures(aggregateFunction, "accumulate")
+                              .map(_.drop(1))
+                              .map(signatureToString)
+                              .mkString(", ")}")
     } else {
       ValidationSuccess
     }

http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
index 3cd694a..4a34732 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -607,13 +607,14 @@ object UserDefinedFunctionUtils {
     candidate == expected ||
     expected == classOf[Object] ||
     expected.isPrimitive && Primitives.wrap(expected) == candidate ||
+    // time types
     candidate == classOf[Date] && (expected == classOf[Int] || expected == classOf[JInt])  ||
     candidate == classOf[Time] && (expected == classOf[Int] || expected == classOf[JInt]) ||
     candidate == classOf[Timestamp] && (expected == classOf[Long] || expected == classOf[JLong]) ||
-    (candidate.isArray &&
-      expected.isArray &&
-      candidate.getComponentType.isInstanceOf[Object] &&
-      expected.getComponentType == classOf[Object])
+    // arrays
+    (candidate.isArray && expected.isArray &&
+      (candidate.getComponentType == expected.getComponentType ||
+        expected.getComponentType == classOf[Object]))
 
   @throws[Exception]
   def serialize(function: UserDefinedFunction): String = {

http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala
index cf96d19..e1348f6 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala
@@ -22,12 +22,13 @@ import java.math.BigDecimal
 
 import org.apache.flink.api.scala._
 import org.apache.flink.api.scala.util.CollectionDataSets
-import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.api.{TableEnvironment, Types}
 import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinctWithMergeAndReset, WeightedAvgWithMergeAndReset}
 import org.apache.flink.table.api.scala._
 import org.apache.flink.table.functions.aggfunctions.CountAggFunction
 import org.apache.flink.table.runtime.utils.TableProgramsCollectionTestBase
 import org.apache.flink.table.runtime.utils.TableProgramsTestBase.TableConfigMode
+import org.apache.flink.table.utils.Top10
 import org.apache.flink.test.util.TestBaseUtils
 import org.apache.flink.types.Row
 import org.junit._
@@ -392,6 +393,28 @@ class AggregationsITCase(
     val results = res.toDataSet[Row].collect()
     TestBaseUtils.compareResultAsText(results.asJava, expected)
   }
+
+  @Test
+  def testComplexAggregate(): Unit = {
+    // Top10 accumulates into and returns an Array[Tuple2[Int, Float]] (a generic array type)
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+    val top10 = new Top10
+
+    val result = CollectionDataSets.get3TupleDataSet(env)
+      .toTable(tEnv, 'a, 'b, 'c)
+      .groupBy('b)
+      .select('b, top10('b.cast(Types.INT), 'a.cast(Types.FLOAT)))
+    // per group: the (b, a) pairs ordered by 'a descending; unused slots stay null
+    val expected =
+      "1,[(1,1.0), null, null, null, null, null, null, null, null, null]\n" +
+      "2,[(2,3.0), (2,2.0), null, null, null, null, null, null, null, null]\n" +
+      "3,[(3,6.0), (3,5.0), (3,4.0), null, null, null, null, null, null, null]\n" +
+      "4,[(4,10.0), (4,9.0), (4,8.0), (4,7.0), null, null, null, null, null, null]\n" +
+      "5,[(5,15.0), (5,14.0), (5,13.0), (5,12.0), (5,11.0), null, null, null, null, null]\n" +
+      "6,[(6,21.0), (6,20.0), (6,19.0), (6,18.0), (6,17.0), (6,16.0), null, null, null, null]"
+    TestBaseUtils.compareResultAsText(result.toDataSet[Row].collect().asJava, expected)
+  }
 }
 
 case class WC(word: String, frequency: Long)

http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala
index 3829314..804fad8 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala
@@ -91,7 +91,7 @@ abstract class TableTestUtil {
     val actual = RelOptUtil.toString(optimized)
     // we remove the charset for testing because it
     // depends on the native machine (Little/Big Endian)
-    val actualNoCharset = actual.replace("_UTF-16LE'", "'")
+    val actualNoCharset = actual.replace("_UTF-16LE'", "'").replace("_UTF-16BE'", "'")
     assertEquals(
       expected.split("\n").map(_.trim).mkString("\n"),
       actualNoCharset.split("\n").map(_.trim).mkString("\n"))

http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedAggFunctions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedAggFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedAggFunctions.scala
new file mode 100644
index 0000000..7d4393c
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedAggFunctions.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.utils
+
+import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
+import java.lang.{Integer => JInt}
+import java.lang.{Float => JFloat}
+import java.util
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.{ObjectArrayTypeInfo, TupleTypeInfo}
+import org.apache.flink.table.api.Types
+
+/**
+  * User-defined aggregation function that keeps the 10 (ID, value) pairs with the highest
+  * Float values. The accumulator is an Array[Tuple2[Int, Float]] of length 10 that is kept
+  * sorted by value in descending order and filled contiguously from index 0; unused slots
+  * are null. IDs are not de-duplicated.
+  *
+  * The accumulator array itself is emitted as the result.
+  */
+class Top10 extends AggregateFunction[Array[JTuple2[JInt, JFloat]], Array[JTuple2[JInt, JFloat]]] {
+
+  override def createAccumulator(): Array[JTuple2[JInt, JFloat]] = {
+    new Array[JTuple2[JInt, JFloat]](10)
+  }
+
+  /**
+    * Inserts a new entry into the top 10 entries if its value is high enough.
+    *
+    * @param acc The current top 10, sorted by value in descending order
+    * @param id The ID
+    * @param value The value for the ID
+    */
+  def accumulate(acc: Array[JTuple2[JInt, JFloat]], id: Int, value: Float): Unit = {
+
+    var i = 9
+    var skipped = 0
+
+    // skip trailing empty slots; entries are stored contiguously from index 0, so this
+    // stops at the last (lowest) entry, or at -1 if the list is empty
+    while (i >= 0 && acc(i) == null) {
+      i -= 1
+    }
+    // backward linear search for the insert position (ties keep existing entries first)
+    while (i >= 0 && value > acc(i).f1) {
+      // check next entry
+      skipped += 1
+      i -= 1
+    }
+
+    // set if necessary (i == 9 means the list is full and the value is too low)
+    if (i < 9) {
+      // move entries with a lower value by one position; if the list is full, the last
+      // entry falls off, so at most 8 - i entries fit behind the new one
+      if (i < 8 && skipped > 0) {
+        System.arraycopy(acc, i + 1, acc, i + 2, math.min(skipped, 8 - i))
+      }
+
+      // add ID to top10 list
+      acc(i + 1) = JTuple2.of(id, value)
+    }
+  }
+
+  override def getValue(acc: Array[JTuple2[JInt, JFloat]]): Array[JTuple2[JInt, JFloat]] = acc
+
+  def resetAccumulator(acc: Array[JTuple2[JInt, JFloat]]): Unit = {
+    util.Arrays.fill(acc.asInstanceOf[Array[Object]], null)
+  }
+
+  /** Merges other sorted top 10 lists into `acc`, keeping the 10 highest values overall. */
+  def merge(
+      acc: Array[JTuple2[JInt, JFloat]],
+      its: java.lang.Iterable[Array[JTuple2[JInt, JFloat]]]): Unit = {
+
+    val it = its.iterator()
+    while (it.hasNext) {
+      val acc2 = it.next()
+
+      var i = 0
+      var i2 = 0
+      while (i < 10 && i2 < 10 && acc2(i2) != null) {
+        if (acc(i) == null) {
+          // copy to empty place
+          acc(i) = acc2(i2)
+          i += 1
+          i2 += 1
+        } else if (acc(i).f1.floatValue >= acc2(i2).f1.floatValue) {
+          // forward to next
+          i += 1
+        } else {
+          // shift and copy (the last entry of acc falls off)
+          System.arraycopy(acc, i, acc, i + 1, 9 - i)
+          acc(i) = acc2(i2)
+          i += 1
+          i2 += 1
+        }
+      }
+    }
+  }
+
+  override def getAccumulatorType: TypeInformation[Array[JTuple2[JInt, JFloat]]] = {
+    ObjectArrayTypeInfo.getInfoFor(new TupleTypeInfo[JTuple2[JInt, JFloat]](Types.INT, Types.FLOAT))
+  }
+
+  override def getResultType: TypeInformation[Array[JTuple2[JInt, JFloat]]] = {
+    ObjectArrayTypeInfo.getInfoFor(new TupleTypeInfo[JTuple2[JInt, JFloat]](Types.INT, Types.FLOAT))
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-tests/src/test/scala/org/apache/flink/api/scala/util/CollectionDataSets.scala
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/util/CollectionDataSets.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/util/CollectionDataSets.scala
index ec1a810..1cb5b52 100644
--- a/flink-tests/src/test/scala/org/apache/flink/api/scala/util/CollectionDataSets.scala
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/util/CollectionDataSets.scala
@@ -55,7 +55,6 @@ object CollectionDataSets {
     data.+=((19, 6L, "Comment#13"))
     data.+=((20, 6L, "Comment#14"))
     data.+=((21, 6L, "Comment#15"))
-    Random.shuffle(data)
     env.fromCollection(Random.shuffle(data))
   }