You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ji...@apache.org on 2017/07/22 11:30:45 UTC
flink git commit: [FLINK-7194] [table] Add default implementations
for type hints to UDAGG interface.
Repository: flink
Updated Branches:
refs/heads/master c472309c7 -> ea1edfb46
[FLINK-7194] [table] Add default implementations for type hints to UDAGG interface.
This closes #4379
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/ea1edfb4
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/ea1edfb4
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/ea1edfb4
Branch: refs/heads/master
Commit: ea1edfb46f674035fd920c70100f60575600405f
Parents: c472309
Author: Fabian Hueske <fh...@apache.org>
Authored: Thu Jul 20 15:09:06 2017 +0200
Committer: Jincheng Sun <ji...@apache.org>
Committed: Sat Jul 22 06:55:44 2017 +0800
----------------------------------------------------------------------
.../table/functions/AggregateFunction.scala | 64 +++++++-------
.../functions/aggfunctions/AvgAggFunction.scala | 16 ++--
.../aggfunctions/CountAggFunction.scala | 13 +--
.../functions/aggfunctions/MaxAggFunction.scala | 4 +-
.../MaxAggFunctionWithRetract.scala | 8 +-
.../functions/aggfunctions/MinAggFunction.scala | 4 +-
.../MinAggFunctionWithRetract.scala | 8 +-
.../functions/aggfunctions/SumAggFunction.scala | 8 +-
.../SumWithRetractAggFunction.scala | 8 +-
.../utils/UserDefinedFunctionUtils.scala | 90 +++++++++++++-------
.../table/api/stream/sql/AggregateTest.scala | 2 +-
.../aggfunctions/CountAggFunctionTest.scala | 8 +-
12 files changed, 129 insertions(+), 104 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
index f90860b..8f50971 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
@@ -17,6 +17,8 @@
*/
package org.apache.flink.table.functions
+import org.apache.flink.api.common.typeinfo.TypeInformation
+
/**
* Base class for User-Defined Aggregates.
*
@@ -28,9 +30,8 @@ package org.apache.flink.table.functions
*
* There are a few other methods that can be optional to have:
* - retract,
- * - merge,
- * - resetAccumulator, and
- * - getAccumulatorType.
+ * - merge, and
+ * - resetAccumulator
*
* All these methods muse be declared publicly, not static and named exactly as the names
* mentioned above. The methods createAccumulator and getValue are defined in the
@@ -72,7 +73,7 @@ package org.apache.flink.table.functions
* custom merge method.
* @param its an [[java.lang.Iterable]] pointed to a group of accumulators that will be
* merged.
-
+ *
* def merge(accumulator: ACC, its: java.lang.Iterable[ACC]): Unit
* }}}
*
@@ -82,39 +83,16 @@ package org.apache.flink.table.functions
* dataset grouping aggregate.
*
* @param accumulator the accumulator which needs to be reset
-
- * def resetAccumulator(accumulator: ACC): Unit
- * }}}
- *
*
- * {{{
- * Returns the [[org.apache.flink.api.common.typeinfo.TypeInformation]] of the accumulator. This
- * function is optional and can be implemented if the accumulator type cannot be automatically
- * inferred from the instance returned by createAccumulator method.
- *
- * @return the type information for the accumulator.
-
- * def getAccumulatorType: TypeInformation[_]
- * }}}
- *
- *
- * {{{
- * Returns the [[org.apache.flink.api.common.typeinfo.TypeInformation]] of the return value. This
- * function is optional and needed in case Flink's type extraction facilities are not sufficient
- * to extract the TypeInformation. Flink's type extraction facilities can handle basic types or
- * simple POJOs but might be wrong for more complex, custom, or composite types.
- *
- * @return the type information for the return value.
- *
- * def getResultType: TypeInformation[_]
+ * def resetAccumulator(accumulator: ACC): Unit
* }}}
*
*
* @tparam T the type of the aggregation result
- * @tparam ACC base class for aggregate Accumulator. The accumulator is used to keep the aggregated
- * values which are needed to compute an aggregation result. AggregateFunction
- * represents its state using accumulator, thereby the state of the AggregateFunction
- * must be put into the accumulator.
+ * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the
+ * aggregated values which are needed to compute an aggregation result.
+ * AggregateFunction represents its state using accumulator, thereby the state of the
+ * AggregateFunction must be put into the accumulator.
*/
abstract class AggregateFunction[T, ACC] extends UserDefinedFunction {
/**
@@ -136,8 +114,26 @@ abstract class AggregateFunction[T, ACC] extends UserDefinedFunction {
*/
def getValue(accumulator: ACC): T
- /**
- * whether this aggregate only used in OVER clause
+ /**
+ * Returns true if this AggregateFunction can only be applied in an OVER window.
+ *
+ * @return true if the AggregateFunction requires an OVER window, false otherwise.
*/
def requiresOver: Boolean = false
+
+ /**
+ * Returns the TypeInformation of the AggregateFunction's result.
+ *
+ * @return The TypeInformation of the AggregateFunction's result or null if the result type
+ * should be automatically inferred.
+ */
+ def getResultType: TypeInformation[T] = null
+
+ /**
+ * Returns the TypeInformation of the AggregateFunction's accumulator.
+ *
+ * @return The TypeInformation of the AggregateFunction's accumulator or null if the
+ * accumulator type should be automatically inferred.
+ */
+ def getAccumulatorType: TypeInformation[ACC] = null
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
index 3f4e5db..b651c42 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
@@ -80,9 +80,9 @@ abstract class IntegralAvgAggFunction[T] extends AggregateFunction[T, IntegralAv
acc.f1 = 0L
}
- def getAccumulatorType: TypeInformation[_] = {
+ override def getAccumulatorType: TypeInformation[IntegralAvgAccumulator] = {
new TupleTypeInfo(
- new IntegralAvgAccumulator().getClass,
+ classOf[IntegralAvgAccumulator],
BasicTypeInfo.LONG_TYPE_INFO,
BasicTypeInfo.LONG_TYPE_INFO)
}
@@ -175,9 +175,9 @@ abstract class BigIntegralAvgAggFunction[T]
acc.f1 = 0
}
- def getAccumulatorType: TypeInformation[_] = {
+ override def getAccumulatorType: TypeInformation[BigIntegralAvgAccumulator] = {
new TupleTypeInfo(
- new BigIntegralAvgAccumulator().getClass,
+ classOf[BigIntegralAvgAccumulator],
BasicTypeInfo.BIG_INT_TYPE_INFO,
BasicTypeInfo.LONG_TYPE_INFO)
}
@@ -255,9 +255,9 @@ abstract class FloatingAvgAggFunction[T] extends AggregateFunction[T, FloatingAv
acc.f1 = 0L
}
- def getAccumulatorType: TypeInformation[_] = {
+ override def getAccumulatorType: TypeInformation[FloatingAvgAccumulator] = {
new TupleTypeInfo(
- new FloatingAvgAccumulator().getClass,
+ classOf[FloatingAvgAccumulator],
BasicTypeInfo.DOUBLE_TYPE_INFO,
BasicTypeInfo.LONG_TYPE_INFO)
}
@@ -339,9 +339,9 @@ class DecimalAvgAggFunction extends AggregateFunction[BigDecimal, DecimalAvgAccu
acc.f1 = 0L
}
- def getAccumulatorType: TypeInformation[_] = {
+ override def getAccumulatorType: TypeInformation[DecimalAvgAccumulator] = {
new TupleTypeInfo(
- new DecimalAvgAccumulator().getClass,
+ classOf[DecimalAvgAccumulator],
BasicTypeInfo.BIG_DEC_TYPE_INFO,
BasicTypeInfo.LONG_TYPE_INFO)
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
index 2b8ec14..c94e053 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
@@ -18,6 +18,7 @@
package org.apache.flink.table.functions.aggfunctions
import java.lang.{Iterable => JIterable}
+import java.lang.{Long => JLong}
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple1 => JTuple1}
@@ -32,7 +33,7 @@ class CountAccumulator extends JTuple1[Long] {
/**
* built-in count aggregate function
*/
-class CountAggFunction extends AggregateFunction[Long, CountAccumulator] {
+class CountAggFunction extends AggregateFunction[JLong, CountAccumulator] {
def accumulate(acc: CountAccumulator, value: Any): Unit = {
if (value != null) {
@@ -46,7 +47,7 @@ class CountAggFunction extends AggregateFunction[Long, CountAccumulator] {
}
}
- override def getValue(acc: CountAccumulator): Long = {
+ override def getValue(acc: CountAccumulator): JLong = {
acc.f0
}
@@ -65,10 +66,10 @@ class CountAggFunction extends AggregateFunction[Long, CountAccumulator] {
acc.f0 = 0L
}
- def getAccumulatorType(): TypeInformation[_] = {
- new TupleTypeInfo((new CountAccumulator).getClass, BasicTypeInfo.LONG_TYPE_INFO)
+ override def getAccumulatorType: TypeInformation[CountAccumulator] = {
+ new TupleTypeInfo(classOf[CountAccumulator], BasicTypeInfo.LONG_TYPE_INFO)
}
- def getResultType(): TypeInformation[_] =
- BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[_]]
+ override def getResultType: TypeInformation[JLong] =
+ BasicTypeInfo.LONG_TYPE_INFO
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
index 96ee8d1..0789bee 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
@@ -76,9 +76,9 @@ abstract class MaxAggFunction[T](implicit ord: Ordering[T])
acc.f1 = false
}
- def getAccumulatorType(): TypeInformation[_] = {
+ override def getAccumulatorType: TypeInformation[MaxAccumulator[T]] = {
new TupleTypeInfo(
- new MaxAccumulator[T].getClass,
+ classOf[MaxAccumulator[T]],
getValueTypeInfo,
BasicTypeInfo.BOOLEAN_TYPE_INFO)
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala
index 6f18739..c79c06a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala
@@ -82,7 +82,7 @@ abstract class MaxWithRetractAggFunction[T](implicit ord: Ordering[T])
val iterator = acc.f1.keySet().iterator()
var key = iterator.next()
acc.f0 = key
- while (iterator.hasNext()) {
+ while (iterator.hasNext) {
key = iterator.next()
if (ord.compare(acc.f0, key) < 0) {
acc.f0 = key
@@ -116,7 +116,7 @@ abstract class MaxWithRetractAggFunction[T](implicit ord: Ordering[T])
}
// merge the count for each key
val iterator = a.f1.keySet().iterator()
- while (iterator.hasNext()) {
+ while (iterator.hasNext) {
val key = iterator.next()
if (acc.f1.containsKey(key)) {
acc.f1.put(key, acc.f1.get(key) + a.f1.get(key))
@@ -133,9 +133,9 @@ abstract class MaxWithRetractAggFunction[T](implicit ord: Ordering[T])
acc.f1.clear()
}
- def getAccumulatorType(): TypeInformation[_] = {
+ override def getAccumulatorType: TypeInformation[MaxWithRetractAccumulator[T]] = {
new TupleTypeInfo(
- new MaxWithRetractAccumulator[T].getClass,
+ classOf[MaxWithRetractAccumulator[T]],
getValueTypeInfo,
new MapTypeInfo(getValueTypeInfo, BasicTypeInfo.LONG_TYPE_INFO))
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
index 88d7afd..d2132c2 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
@@ -76,9 +76,9 @@ abstract class MinAggFunction[T](implicit ord: Ordering[T])
acc.f1 = false
}
- def getAccumulatorType(): TypeInformation[_] = {
+ override def getAccumulatorType: TypeInformation[MinAccumulator[T]] = {
new TupleTypeInfo(
- new MinAccumulator[T].getClass,
+ classOf[MinAccumulator[T]],
getValueTypeInfo,
BasicTypeInfo.BOOLEAN_TYPE_INFO)
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala
index 2d3348b..faa6725 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala
@@ -82,7 +82,7 @@ abstract class MinWithRetractAggFunction[T](implicit ord: Ordering[T])
val iterator = acc.f1.keySet().iterator()
var key = iterator.next()
acc.f0 = key
- while (iterator.hasNext()) {
+ while (iterator.hasNext) {
key = iterator.next()
if (ord.compare(acc.f0, key) > 0) {
acc.f0 = key
@@ -116,7 +116,7 @@ abstract class MinWithRetractAggFunction[T](implicit ord: Ordering[T])
}
// merge the count for each key
val iterator = a.f1.keySet().iterator()
- while (iterator.hasNext()) {
+ while (iterator.hasNext) {
val key = iterator.next()
if (acc.f1.containsKey(key)) {
acc.f1.put(key, acc.f1.get(key) + a.f1.get(key))
@@ -133,9 +133,9 @@ abstract class MinWithRetractAggFunction[T](implicit ord: Ordering[T])
acc.f1.clear()
}
- def getAccumulatorType(): TypeInformation[_] = {
+ override def getAccumulatorType: TypeInformation[MinWithRetractAccumulator[T]] = {
new TupleTypeInfo(
- new MinWithRetractAccumulator[T].getClass,
+ classOf[MinWithRetractAccumulator[T]],
getValueTypeInfo,
new MapTypeInfo(getValueTypeInfo, BasicTypeInfo.LONG_TYPE_INFO))
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
index 43fc7ff..5c0b14b 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
@@ -76,9 +76,9 @@ abstract class SumAggFunction[T: Numeric] extends AggregateFunction[T, SumAccumu
acc.f1 = false
}
- def getAccumulatorType: TypeInformation[_] = {
+ override def getAccumulatorType: TypeInformation[SumAccumulator[T]] = {
new TupleTypeInfo(
- (new SumAccumulator).getClass,
+ classOf[SumAccumulator[T]],
getValueTypeInfo,
BasicTypeInfo.BOOLEAN_TYPE_INFO)
}
@@ -175,9 +175,9 @@ class DecimalSumAggFunction extends AggregateFunction[BigDecimal, DecimalSumAccu
acc.f1 = false
}
- def getAccumulatorType: TypeInformation[_] = {
+ override def getAccumulatorType: TypeInformation[DecimalSumAccumulator] = {
new TupleTypeInfo(
- (new DecimalSumAccumulator).getClass,
+ classOf[DecimalSumAccumulator],
BasicTypeInfo.BIG_DEC_TYPE_INFO,
BasicTypeInfo.BOOLEAN_TYPE_INFO)
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala
index 7f68d11..fc51b9b 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala
@@ -84,9 +84,9 @@ abstract class SumWithRetractAggFunction[T: Numeric]
acc.f1 = 0L
}
- def getAccumulatorType(): TypeInformation[_] = {
+ override def getAccumulatorType: TypeInformation[SumWithRetractAccumulator[T]] = {
new TupleTypeInfo(
- (new SumWithRetractAccumulator).getClass,
+ classOf[SumWithRetractAccumulator[T]],
getValueTypeInfo,
BasicTypeInfo.LONG_TYPE_INFO)
}
@@ -191,9 +191,9 @@ class DecimalSumWithRetractAggFunction
acc.f1 = 0L
}
- def getAccumulatorType(): TypeInformation[_] = {
+ override def getAccumulatorType: TypeInformation[DecimalSumWithRetractAccumulator] = {
new TupleTypeInfo(
- (new DecimalSumWithRetractAccumulator).getClass,
+ classOf[DecimalSumWithRetractAccumulator],
BasicTypeInfo.BIG_DEC_TYPE_INFO,
BasicTypeInfo.LONG_TYPE_INFO)
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
index 5e34586..47469d1 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -307,64 +307,90 @@ object UserDefinedFunctionUtils {
// ----------------------------------------------------------------------------------------------
/**
- * Internal method of AggregateFunction#getResultType() that does some pre-checking and uses
- * [[TypeExtractor]] as default return type inference.
+ * Tries to infer the TypeInformation of an AggregateFunction's return type.
+ *
+ * @param aggregateFunction The AggregateFunction for which the return type is inferred.
+ * @param extractedType The implicitly inferred type of the result type.
+ *
+ * @return The inferred result type of the AggregateFunction.
*/
def getResultTypeOfAggregateFunction(
aggregateFunction: AggregateFunction[_, _],
extractedType: TypeInformation[_] = null)
: TypeInformation[_] = {
- getParameterTypeOfAggregateFunction(aggregateFunction, "getResultType", 0, extractedType)
+
+ val resultType = aggregateFunction.getResultType
+ if (resultType != null) {
+ resultType
+ } else if (extractedType != null) {
+ extractedType
+ } else {
+ try {
+ extractTypeFromAggregateFunction(aggregateFunction, 0)
+ } catch {
+ case ite: InvalidTypesException =>
+ throw new TableException(
+ s"Cannot infer generic type of ${aggregateFunction.getClass}. " +
+ "You can override AggregateFunction.getResultType() to specify the type.",
+ ite
+ )
+ }
+ }
}
/**
- * Internal method of AggregateFunction#getAccumulatorType() that does some pre-checking
- * and uses [[TypeExtractor]] as default return type inference.
+ * Tries to infer the TypeInformation of an AggregateFunction's accumulator type.
+ *
+ * @param aggregateFunction The AggregateFunction for which the accumulator type is inferred.
+ * @param extractedType The implicitly inferred type of the accumulator type.
+ *
+ * @return The inferred accumulator type of the AggregateFunction.
*/
def getAccumulatorTypeOfAggregateFunction(
aggregateFunction: AggregateFunction[_, _],
extractedType: TypeInformation[_] = null)
: TypeInformation[_] = {
- getParameterTypeOfAggregateFunction(aggregateFunction, "getAccumulatorType", 1, extractedType)
- }
-
- private def getParameterTypeOfAggregateFunction(
- aggregateFunction: AggregateFunction[_, _],
- getTypeMethod: String,
- parameterTypePos: Int,
- extractedType: TypeInformation[_] = null)
- : TypeInformation[_] = {
- val resultType = try {
- val method: Method = aggregateFunction.getClass.getMethod(getTypeMethod)
- method.invoke(aggregateFunction).asInstanceOf[TypeInformation[_]]
- } catch {
- case _: NoSuchMethodException => null
- case ite: Throwable => throw new TableException("Unexpected exception:", ite)
- }
- if (resultType != null) {
- resultType
+ val accType = aggregateFunction.getAccumulatorType
+ if (accType != null) {
+ accType
} else if (extractedType != null) {
extractedType
} else {
try {
- TypeExtractor
- .createTypeInfo(aggregateFunction,
- classOf[AggregateFunction[_, _]],
- aggregateFunction.getClass,
- parameterTypePos)
- .asInstanceOf[TypeInformation[_]]
+ extractTypeFromAggregateFunction(aggregateFunction, 1)
} catch {
case ite: InvalidTypesException =>
throw new TableException(
- s"Cannot infer generic type of ${aggregateFunction.getClass}. " +
- s"You can override AggregateFunction.$getTypeMethod() to specify the type.",
- ite)
+ s"Cannot infer generic type of ${aggregateFunction.getClass}. " +
+ "You can override AggregateFunction.getAccumulatorType() to specify the type.",
+ ite
+ )
}
}
}
/**
+ * Internal method to extract a type from an AggregateFunction's type parameters.
+ *
+ * @param aggregateFunction The AggregateFunction for which the type is extracted.
+ * @param parameterTypePos The position of the type parameter for which the type is extracted.
+ *
+ * @return The extracted type.
+ */
+ @throws(classOf[InvalidTypesException])
+ private def extractTypeFromAggregateFunction(
+ aggregateFunction: AggregateFunction[_, _],
+ parameterTypePos: Int): TypeInformation[_] = {
+
+ TypeExtractor.createTypeInfo(
+ aggregateFunction,
+ classOf[AggregateFunction[_, _]],
+ aggregateFunction.getClass,
+ parameterTypePos).asInstanceOf[TypeInformation[_]]
+ }
+
+ /**
* Internal method of [[ScalarFunction#getResultType()]] that does some pre-checking and uses
* [[TypeExtractor]] as default return type inference.
*/
http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/AggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/AggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/AggregateTest.scala
index 70d1d21..76d33c2 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/AggregateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/AggregateTest.scala
@@ -153,5 +153,5 @@ class MyAgg2 extends AggregateFunction[Long, Row] {
override def getValue(accumulator: Row): Long = 1L
- def getAccumulatorType: TypeInformation[_] = new RowTypeInfo(Types.LONG, Types.INT)
+ override def getAccumulatorType: TypeInformation[Row] = new RowTypeInfo(Types.LONG, Types.INT)
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CountAggFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CountAggFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CountAggFunctionTest.scala
index f9dd474..87aaff9 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CountAggFunctionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CountAggFunctionTest.scala
@@ -18,22 +18,24 @@
package org.apache.flink.table.runtime.aggfunctions
+import java.lang.{Long => JLong}
+
import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.table.functions.aggfunctions.{CountAccumulator, CountAggFunction}
/**
* Test case for built-in count aggregate function
*/
-class CountAggFunctionTest extends AggFunctionTestBase[Long, CountAccumulator] {
+class CountAggFunctionTest extends AggFunctionTestBase[JLong, CountAccumulator] {
override def inputValueSets: Seq[Seq[_]] = Seq(
Seq("a", "b", null, "c", null, "d", "e", null, "f"),
Seq(null, null, null, null, null, null)
)
- override def expectedResults: Seq[Long] = Seq(6L, 0L)
+ override def expectedResults: Seq[JLong] = Seq(6L, 0L)
- override def aggregator: AggregateFunction[Long, CountAccumulator] = new CountAggFunction()
+ override def aggregator: AggregateFunction[JLong, CountAccumulator] = new CountAggFunction()
override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
}