You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tw...@apache.org on 2017/11/15 15:18:31 UTC
flink git commit: [FLINK-8013] [table] Support aggregate functions with generic arrays
Repository: flink
Updated Branches:
refs/heads/master 9d3471574 -> 11218a35d
[FLINK-8013] [table] Support aggregate functions with generic arrays
This closes #5011.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/11218a35
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/11218a35
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/11218a35
Branch: refs/heads/master
Commit: 11218a35dc0fdd7439142a313e6628c51cffe689
Parents: 9d34715
Author: twalthr <tw...@apache.org>
Authored: Tue Nov 14 11:06:54 2017 +0100
Committer: twalthr <tw...@apache.org>
Committed: Wed Nov 15 15:11:39 2017 +0100
----------------------------------------------------------------------
.../api/java/typeutils/TypeExtractionUtils.java | 16 +++
.../codegen/AggregationCodeGenerator.scala | 20 ++-
.../flink/table/expressions/aggregations.scala | 6 +-
.../utils/UserDefinedFunctionUtils.scala | 9 +-
.../runtime/batch/table/AggregateITCase.scala | 25 +++-
.../flink/table/utils/TableTestBase.scala | 2 +-
.../table/utils/UserDefinedAggFunctions.scala | 126 +++++++++++++++++++
.../api/scala/util/CollectionDataSets.scala | 1 -
8 files changed, 185 insertions(+), 20 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java
index c5c2565..56fcf82 100644
--- a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java
+++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java
@@ -18,7 +18,9 @@
package org.apache.flink.api.java.typeutils;
+import java.lang.reflect.Array;
import java.lang.reflect.Constructor;
+import java.lang.reflect.GenericArrayType;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.ParameterizedType;
@@ -322,4 +324,18 @@ public class TypeExtractionUtils {
}
return false;
}
+
+	/**
+	 * Returns the raw class of the given type.
+	 *
+	 * <p>Types accepted by {@link #isClassType(Type)} are resolved via {@link #typeToClass(Type)}.
+	 * For a {@link GenericArrayType}, the raw class of its component type is resolved
+	 * recursively and the matching array class is returned, so nested generic arrays are
+	 * supported as well. Every other type is mapped to java.lang.Object.
+	 */
+	public static Class<?> getRawClass(Type t) {
+		if (t instanceof GenericArrayType) {
+			final Type genericComponent = ((GenericArrayType) t).getGenericComponentType();
+			final Class<?> rawComponent = getRawClass(genericComponent);
+			return Array.newInstance(rawComponent, 0).getClass();
+		}
+		return isClassType(t) ? typeToClass(t) : Object.class;
+	}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
index c85b111..32cbde2 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
@@ -17,17 +17,18 @@
*/
package org.apache.flink.table.codegen
-import java.lang.reflect.{Modifier, ParameterizedType}
+import java.lang.reflect.Modifier
import java.lang.{Iterable => JIterable}
import org.apache.calcite.rex.RexLiteral
import org.apache.commons.codec.binary.Base64
import org.apache.flink.api.common.state.{State, StateDescriptor}
import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.TypeExtractionUtils.{extractTypeArgument, getRawClass}
import org.apache.flink.table.api.TableConfig
import org.apache.flink.table.api.dataview._
-import org.apache.flink.table.codegen.Indenter.toISC
import org.apache.flink.table.codegen.CodeGenUtils.{newName, reflectiveFieldWriteAccess}
+import org.apache.flink.table.codegen.Indenter.toISC
import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getUserDefinedMethod, signatureToString}
@@ -175,7 +176,7 @@ class AggregationCodeGenerator(
}
if (needMerge) {
- val methods =
+ val method =
getUserDefinedMethod(a, "merge", Array(accTypeClasses(i), classOf[JIterable[Any]]))
.getOrElse(
throw new CodeGenException(
@@ -183,17 +184,14 @@ class AggregationCodeGenerator(
s"${a.getClass.getCanonicalName}'.")
)
- var iterableTypeClass = methods.getGenericParameterTypes.apply(1)
- .asInstanceOf[ParameterizedType].getActualTypeArguments.apply(0)
- // further extract iterableTypeClass if the accumulator has generic type
- iterableTypeClass match {
- case impl: ParameterizedType => iterableTypeClass = impl.getRawType
- case _ =>
- }
+ // use the TypeExtractionUtils here to support nested GenericArrayTypes and
+ // other complex types
+ val iterableGenericType = extractTypeArgument(method.getGenericParameterTypes()(1), 0)
+ val iterableTypeClass = getRawClass(iterableGenericType)
if (iterableTypeClass != accTypeClasses(i)) {
throw new CodeGenException(
- s"merge method in AggregateFunction ${a.getClass.getCanonicalName} does not have " +
+ s"Merge method in AggregateFunction ${a.getClass.getCanonicalName} does not have " +
s"the correct Iterable type. Actually: ${iterableTypeClass.toString}. " +
s"Expected: ${accTypeClasses(i).toString}")
}
http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
index 1ffcb12..3adaaa9 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
@@ -245,8 +245,10 @@ case class AggFunctionCall(
ValidationFailure(s"Given parameters do not match any signature. \n" +
s"Actual: ${signatureToString(signature)} \n" +
s"Expected: ${
- getMethodSignatures(aggregateFunction, "accumulate").drop(1)
- .map(signatureToString).mkString(", ")}")
+ getMethodSignatures(aggregateFunction, "accumulate")
+ .map(_.drop(1))
+ .map(signatureToString)
+ .mkString(", ")}")
} else {
ValidationSuccess
}
http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
index 3cd694a..4a34732 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -607,13 +607,14 @@ object UserDefinedFunctionUtils {
candidate == expected ||
expected == classOf[Object] ||
expected.isPrimitive && Primitives.wrap(expected) == candidate ||
+ // time types
candidate == classOf[Date] && (expected == classOf[Int] || expected == classOf[JInt]) ||
candidate == classOf[Time] && (expected == classOf[Int] || expected == classOf[JInt]) ||
candidate == classOf[Timestamp] && (expected == classOf[Long] || expected == classOf[JLong]) ||
- (candidate.isArray &&
- expected.isArray &&
- candidate.getComponentType.isInstanceOf[Object] &&
- expected.getComponentType == classOf[Object])
+ // arrays
+ (candidate.isArray && expected.isArray &&
+ (candidate.getComponentType == expected.getComponentType ||
+ expected.getComponentType == classOf[Object]))
@throws[Exception]
def serialize(function: UserDefinedFunction): String = {
http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala
index cf96d19..e1348f6 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala
@@ -22,12 +22,13 @@ import java.math.BigDecimal
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.util.CollectionDataSets
-import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.api.{TableEnvironment, Types}
import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinctWithMergeAndReset, WeightedAvgWithMergeAndReset}
import org.apache.flink.table.api.scala._
import org.apache.flink.table.functions.aggfunctions.CountAggFunction
import org.apache.flink.table.runtime.utils.TableProgramsCollectionTestBase
import org.apache.flink.table.runtime.utils.TableProgramsTestBase.TableConfigMode
+import org.apache.flink.table.utils.Top10
import org.apache.flink.test.util.TestBaseUtils
import org.apache.flink.types.Row
import org.junit._
@@ -392,6 +393,28 @@ class AggregationsITCase(
val results = res.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
+
+  @Test
+  def testComplexAggregate(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+    val top10 = new Top10
+
+    // per group 'b, keep the (INT id, FLOAT value) pairs with the 10 highest values of 'a
+    val aggregated = CollectionDataSets
+      .get3TupleDataSet(env)
+      .toTable(tEnv, 'a, 'b, 'c)
+      .groupBy('b)
+      .select('b, top10('b.cast(Types.INT), 'a.cast(Types.FLOAT)))
+
+    val expected = Seq(
+      "1,[(1,1.0), null, null, null, null, null, null, null, null, null]",
+      "2,[(2,3.0), (2,2.0), null, null, null, null, null, null, null, null]",
+      "3,[(3,6.0), (3,5.0), (3,4.0), null, null, null, null, null, null, null]",
+      "4,[(4,10.0), (4,9.0), (4,8.0), (4,7.0), null, null, null, null, null, null]",
+      "5,[(5,15.0), (5,14.0), (5,13.0), (5,12.0), (5,11.0), null, null, null, null, null]",
+      "6,[(6,21.0), (6,20.0), (6,19.0), (6,18.0), (6,17.0), (6,16.0), null, null, null, null]"
+    ).mkString("\n")
+
+    val collected = aggregated.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(collected.asJava, expected)
+  }
}
case class WC(word: String, frequency: Long)
http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala
index 3829314..804fad8 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala
@@ -91,7 +91,7 @@ abstract class TableTestUtil {
val actual = RelOptUtil.toString(optimized)
// we remove the charset for testing because it
// depends on the native machine (Little/Big Endian)
- val actualNoCharset = actual.replace("_UTF-16LE'", "'")
+ val actualNoCharset = actual.replace("_UTF-16LE'", "'").replace("_UTF-16BE'", "'")
assertEquals(
expected.split("\n").map(_.trim).mkString("\n"),
actualNoCharset.split("\n").map(_.trim).mkString("\n"))
http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedAggFunctions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedAggFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedAggFunctions.scala
new file mode 100644
index 0000000..7d4393c
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedAggFunctions.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.utils
+
+import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
+import java.lang.{Integer => JInt}
+import java.lang.{Float => JFloat}
+import java.util
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.{ObjectArrayTypeInfo, TupleTypeInfo}
+import org.apache.flink.table.api.Types
+
+/**
+ * User-defined aggregation function that keeps the 10 (Int ID, Float value) entries with
+ * the highest Float values. An Array[Tuple2[Int, Float]] of length 10 is used as
+ * accumulator to store the top 10 entries; it is kept sorted by value in descending
+ * order and unused positions at its end are null.
+ *
+ * The result is emitted as Array as well.
+ */
+class Top10 extends AggregateFunction[Array[JTuple2[JInt, JFloat]], Array[JTuple2[JInt, JFloat]]] {
+
+  override def createAccumulator(): Array[JTuple2[JInt, JFloat]] = {
+    new Array[JTuple2[JInt, JFloat]](10)
+  }
+
+  /**
+   * Adds a new (ID, value) entry to the top 10 entries if its value is high enough.
+   * If the list is already full, the entry is only added when its value exceeds the
+   * lowest value in the list; that lowest entry is then dropped.
+   *
+   * @param acc The current top 10
+   * @param id The ID
+   * @param value The value for the ID
+   */
+  def accumulate(acc: Array[JTuple2[JInt, JFloat]], id: Int, value: Float): Unit = {
+
+    var i = 9
+    var skipped = 0
+
+    // skip positions without records until the last entry in the top10 list
+    while (i >= 0 && acc(i) == null) {
+      i -= 1
+    }
+    // backward linear search for insert position; skipped counts the entries with a
+    // lower value that have to move down by one position
+    while (i >= 0 && value > acc(i).f1) {
+      skipped += 1
+      i -= 1
+    }
+
+    // the new entry goes to position i + 1, if that is still within the top 10
+    if (i < 9) {
+      // move entries with a lower value down by one position, but never past the last
+      // position (index 9): in a full list the lowest entry is dropped instead of being
+      // copied out of bounds
+      val moved = math.min(skipped, 8 - i)
+      if (moved > 0) {
+        System.arraycopy(acc, i + 1, acc, i + 2, moved)
+      }
+
+      // add ID to top10 list
+      acc(i + 1) = JTuple2.of(id, value)
+    }
+  }
+
+  override def getValue(acc: Array[JTuple2[JInt, JFloat]]): Array[JTuple2[JInt, JFloat]] = acc
+
+  /** Clears all entries of the accumulator so that it can be reused. */
+  def resetAccumulator(acc: Array[JTuple2[JInt, JFloat]]): Unit = {
+    util.Arrays.fill(acc.asInstanceOf[Array[Object]], null)
+  }
+
+  /**
+   * Merges a group of top 10 accumulators into acc. All accumulators are expected to be
+   * sorted by value in descending order with null entries at the end, as maintained by
+   * accumulate.
+   *
+   * @param acc The accumulator that receives the merged top 10
+   * @param its The accumulators to merge into acc
+   */
+  def merge(
+      acc: Array[JTuple2[JInt, JFloat]],
+      its: java.lang.Iterable[Array[JTuple2[JInt, JFloat]]]): Unit = {
+
+    val it = its.iterator()
+    while (it.hasNext) {
+      val acc2 = it.next()
+
+      var i = 0
+      var i2 = 0
+      while (i < 10 && i2 < 10 && acc2(i2) != null) {
+        if (acc(i) == null) {
+          // copy to empty place
+          acc(i) = acc2(i2)
+          i += 1
+          i2 += 1
+        } else if (acc(i).f1.asInstanceOf[Float] >= acc2(i2).f1.asInstanceOf[Float]) {
+          // forward to next
+          i += 1
+        } else {
+          // shift the remaining entries down by one position (dropping the last one)
+          // and copy the higher entry into the freed place
+          System.arraycopy(acc, i, acc, i + 1, 9 - i)
+          acc(i) = acc2(i2)
+          i += 1
+          i2 += 1
+        }
+      }
+    }
+  }
+
+  override def getAccumulatorType: TypeInformation[Array[JTuple2[JInt, JFloat]]] = entriesType
+
+  override def getResultType: TypeInformation[Array[JTuple2[JInt, JFloat]]] = entriesType
+
+  /** Type of an array of (Int ID, Float value) tuples; shared by accumulator and result. */
+  private def entriesType: TypeInformation[Array[JTuple2[JInt, JFloat]]] = {
+    ObjectArrayTypeInfo.getInfoFor(new TupleTypeInfo[JTuple2[JInt, JFloat]](Types.INT, Types.FLOAT))
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/11218a35/flink-tests/src/test/scala/org/apache/flink/api/scala/util/CollectionDataSets.scala
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/util/CollectionDataSets.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/util/CollectionDataSets.scala
index ec1a810..1cb5b52 100644
--- a/flink-tests/src/test/scala/org/apache/flink/api/scala/util/CollectionDataSets.scala
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/util/CollectionDataSets.scala
@@ -55,7 +55,6 @@ object CollectionDataSets {
data.+=((19, 6L, "Comment#13"))
data.+=((20, 6L, "Comment#14"))
data.+=((21, 6L, "Comment#15"))
- Random.shuffle(data)
env.fromCollection(Random.shuffle(data))
}