You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by fh...@apache.org on 2017/04/25 13:41:42 UTC
[2/2] flink git commit: [FLINK-6361] [table] Refactor the
AggregateFunction interface and built-in aggregates.
[FLINK-6361] [table] Refactor the AggregateFunction interface and built-in aggregates.
This closes #3762.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/bc6409d6
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/bc6409d6
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/bc6409d6
Branch: refs/heads/master
Commit: bc6409d624df54c2309c8bdb767f95de74ea1475
Parents: fe01892
Author: shaoxuan-wang <ws...@gmail.com>
Authored: Tue Apr 25 00:28:37 2017 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Tue Apr 25 14:21:05 2017 +0200
----------------------------------------------------------------------
.../org/apache/flink/table/api/Types.scala | 3 +-
.../flink/table/codegen/CodeGenerator.scala | 60 +++---
.../table/functions/AggregateFunction.scala | 157 +++++++-------
.../functions/aggfunctions/AvgAggFunction.scala | 206 ++++++++-----------
.../aggfunctions/CountAggFunction.scala | 39 ++--
.../functions/aggfunctions/MaxAggFunction.scala | 48 ++---
.../MaxAggFunctionWithRetract.scala | 86 ++++----
.../functions/aggfunctions/MinAggFunction.scala | 48 ++---
.../MinAggFunctionWithRetract.scala | 86 ++++----
.../functions/aggfunctions/SumAggFunction.scala | 89 ++++----
.../SumWithRetractAggFunction.scala | 107 +++++-----
.../table/runtime/aggregate/AggregateUtil.scala | 96 ++++++---
.../aggregate/GeneratedAggregations.scala | 32 +++
.../aggfunctions/AggFunctionTestBase.scala | 62 +++---
.../aggfunctions/AvgFunctionTest.scala | 23 ++-
.../aggfunctions/CountAggFunctionTest.scala | 6 +-
.../aggfunctions/MaxAggFunctionTest.scala | 44 ++--
.../MaxWithRetractAggFunctionTest.scala | 47 +++--
.../aggfunctions/MinAggFunctionTest.scala | 45 ++--
.../MinWithRetractAggFunctionTest.scala | 47 +++--
.../aggfunctions/SumAggFunctionTest.scala | 31 +--
.../SumWithRetractAggFunctionTest.scala | 31 ++-
...ProcessingOverRangeProcessFunctionTest.scala | 28 ++-
23 files changed, 772 insertions(+), 649 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala
index 262a452..d82b990 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala
@@ -17,9 +17,8 @@
*/
package org.apache.flink.table.api
-import org.apache.flink.api.common.typeinfo.{Types, TypeInformation}
+import org.apache.flink.api.common.typeinfo.{Types => JTypes, TypeInformation}
import org.apache.flink.table.typeutils.TimeIntervalTypeInfo
-import org.apache.flink.api.common.typeinfo.{Types => JTypes}
/**
* This class enumerates all supported types of the Table API.
http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
index 510a870..298fb70 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
@@ -265,14 +265,16 @@ class CodeGenerator(
name: String,
generator: CodeGenerator,
inputType: RelDataType,
- aggregates: Array[AggregateFunction[_ <: Any]],
+ aggregates: Array[AggregateFunction[_ <: Any, _ <: Any]],
aggFields: Array[Array[Int]],
aggMapping: Array[Int],
partialResults: Boolean,
fwdMapping: Array[Int],
mergeMapping: Option[Array[Int]],
constantFlags: Option[Array[(Int, Boolean)]],
- outputArity: Int)
+ outputArity: Int,
+ needRetract: Boolean,
+ needMerge: Boolean)
: GeneratedAggregationsFunction = {
// get unique function name
@@ -364,9 +366,16 @@ class CodeGenerator(
| ${parameters(i)});""".stripMargin
}.mkString("\n")
- j"""$sig {
- |$retract
- | }""".stripMargin
+ if (needRetract) {
+ j"""
+ |$sig {
+ |$retract
+ | }""".stripMargin
+ } else {
+ j"""
+ |$sig {
+ | }""".stripMargin
+ }
}
def genCreateAccumulators: String = {
@@ -471,11 +480,9 @@ class CodeGenerator(
j"""
| ${accTypes(i)} aAcc$i = (${accTypes(i)}) a.getField($i);
| ${accTypes(i)} bAcc$i = (${accTypes(i)}) b.getField(${mapping(i)});
- | accList$i.set(0, aAcc$i);
- | accList$i.set(1, bAcc$i);
- | a.setField(
- | $i,
- | ${aggs(i)}.merge(accList$i));
+ | accIt$i.setElement(bAcc$i);
+ | ${aggs(i)}.merge(aAcc$i, accIt$i);
+ | a.setField($i, aAcc$i);
""".stripMargin
}.mkString("\n")
val ret: String =
@@ -483,29 +490,27 @@ class CodeGenerator(
| return a;
""".stripMargin
- j"""
- |$sig {
- |$merge
- |$ret
- | }""".stripMargin
+ if (needMerge) {
+ j"""
+ |$sig {
+ |$merge
+ |$ret
+ | }""".stripMargin
+ } else {
+ j"""
+ |$sig {
+ |$ret
+ | }""".stripMargin
+ }
}
def genMergeList: String = {
{
+ val singleIterableClass = "org.apache.flink.table.runtime.aggregate.SingleElementIterable"
for (i <- accTypes.indices) yield
j"""
- | private final java.util.ArrayList<${accTypes(i)}> accList$i =
- | new java.util.ArrayList<${accTypes(i)}>(2);
- """.stripMargin
- }.mkString("\n")
- }
-
- def initMergeList: String = {
- {
- for (i <- accTypes.indices) yield
- j"""
- | accList$i.add(${aggs(i)}.createAccumulator());
- | accList$i.add(${aggs(i)}.createAccumulator());
+ | private final $singleIterableClass<${accTypes(i)}> accIt$i =
+ | new $singleIterableClass<${accTypes(i)}>();
""".stripMargin
}.mkString("\n")
}
@@ -538,7 +543,6 @@ class CodeGenerator(
| $genMergeList
| public $funcName() throws Exception {
| ${reuseInitCode()}
- | $initMergeList
| }
| ${reuseConstructorCode(funcName)}
|
http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
index a67ccaa..7a74112 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
@@ -17,36 +17,100 @@
*/
package org.apache.flink.table.functions
-import java.util.{List => JList}
-
-import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.table.api.TableException
-
/**
* Base class for User-Defined Aggregates.
*
- * @tparam T the type of the aggregation result
+ * The behavior of an [[AggregateFunction]] can be defined by implementing a series of custom
+ * methods. An [[AggregateFunction]] needs at least three methods:
+ * - createAccumulator,
+ * - accumulate, and
+ * - getValue.
+ *
+ * There are a few other methods that can be optional to have:
+ * - retract,
+ * - merge,
+ * - resetAccumulator, and
+ * - getAccumulatorType.
+ *
+ * All these methods must be declared publicly, not static and named exactly as the names
+ * mentioned above. The methods createAccumulator and getValue are defined in the
+ * [[AggregateFunction]] functions, while other methods are explained below.
+ *
+ *
+ * {{{
+ * Processes the input values and updates the provided accumulator instance. The method
+ * accumulate can be overloaded with different custom types and arguments. An AggregateFunction
+ * requires at least one accumulate() method.
+ *
+ * @param accumulator the accumulator which contains the current aggregated results
+ * @param [user defined inputs] the input value (usually obtained from newly arrived data).
+ *
+ * def accumulate(accumulator: ACC, [user defined inputs]): Unit
+ * }}}
+ *
+ *
+ * {{{
+ * Retracts the input values from the accumulator instance. The current design assumes the
+ * inputs are the values that have been previously accumulated. The method retract can be
+ * overloaded with different custom types and arguments. This function must be implemented for
+ * datastream bounded over aggregate.
+ *
+ * @param accumulator the accumulator which contains the current aggregated results
+ * @param [user defined inputs] the input value to retract (a previously accumulated value).
+ *
+ * def retract(accumulator: ACC, [user defined inputs]): Unit
+ * }}}
+ *
+ *
+ * {{{
+ * Merges a group of accumulator instances into one accumulator instance. This function must be
+ * implemented for datastream session window grouping aggregate and dataset grouping aggregate.
+ *
+ * @param accumulator the accumulator which will keep the merged aggregate results. It should
+ * be noted that the accumulator may contain the previous aggregated
+ * results. Therefore user should not replace or clean this instance in the
+ * custom merge method.
+ * @param its a [[java.lang.Iterable]] pointing to a group of accumulators that will be
+ * merged.
+
+ * def merge(accumulator: ACC, its: java.lang.Iterable[ACC]): Unit
+ * }}}
+ *
+ *
+ * {{{
+ * Resets the accumulator for this [[AggregateFunction]]. This function must be implemented for
+ * dataset grouping aggregate.
+ *
+ * @param accumulator the accumulator which needs to be reset
+
+ * def resetAccumulator(accumulator: ACC): Unit
+ * }}}
+ *
+ *
+ * {{{
+ * Returns the [[org.apache.flink.api.common.typeinfo.TypeInformation]] of the accumulator. This
+ * function is optional and can be implemented if the accumulator type cannot be automatically
+ * inferred from the instance returned by the createAccumulator method.
+ *
+ * @return the type information for the accumulator.
+
+ * def getAccumulatorType: TypeInformation[_]
+ * }}}
+ *
+ *
+ * @tparam T the type of the aggregation result
+ * @tparam ACC base class for aggregate Accumulator. The accumulator is used to keep the aggregated
+ * values which are needed to compute an aggregation result. AggregateFunction
+ * represents its state using accumulator, thereby the state of the AggregateFunction
+ * must be put into the accumulator.
*/
-abstract class AggregateFunction[T] extends UserDefinedFunction {
+abstract class AggregateFunction[T, ACC] extends UserDefinedFunction {
/**
* Creates and init the Accumulator for this [[AggregateFunction]].
*
* @return the accumulator with the initial value
*/
- def createAccumulator(): Accumulator
-
- /**
- * Retracts the input values from the accumulator instance. The current design assumes the
- * inputs are the values that have been previously accumulated.
- *
- * @param accumulator the accumulator which contains the current
- * aggregated results
- * @param input the input value (usually obtained from a new arrived data)
- */
- def retract(accumulator: Accumulator, input: Any): Unit = {
- throw TableException("Retract is an optional method. There is no default implementation. You " +
- "must implement one for yourself.")
- }
+ def createAccumulator(): ACC
/**
* Called every time when an aggregation result should be materialized.
@@ -58,54 +122,5 @@ abstract class AggregateFunction[T] extends UserDefinedFunction {
* aggregated results
* @return the aggregation result
*/
- def getValue(accumulator: Accumulator): T
-
- /**
- * Processes the input values and update the provided accumulator instance.
- *
- * @param accumulator the accumulator which contains the current
- * aggregated results
- * @param input the input value (usually obtained from a new arrived data)
- */
- def accumulate(accumulator: Accumulator, input: Any): Unit
-
- /**
- * Merges a list of accumulator instances into one accumulator instance.
- *
- * IMPORTANT: You may only return a new accumulator instance or the first accumulator of the
- * input list. If you return another instance, the result of the aggregation function might be
- * incorrect.
- *
- * @param accumulators the [[java.util.List]] of accumulators that will be merged
- * @return the resulting accumulator
- */
- def merge(accumulators: JList[Accumulator]): Accumulator
-
- /**
- * Resets the Accumulator for this [[AggregateFunction]].
- *
- * @param accumulator the accumulator which needs to be reset
- */
- def resetAccumulator(accumulator: Accumulator): Unit
-
- /**
- * Returns the [[TypeInformation]] of the accumulator.
- * This function is optional and can be implemented if the accumulator type cannot automatically
- * inferred from the instance returned by [[createAccumulator()]].
- *
- * @return The type information for the accumulator.
- */
- def getAccumulatorType: TypeInformation[_] = null
+ def getValue(accumulator: ACC): T
}
-
-/**
- * Base class for aggregate Accumulator. The accumulator is used to keep the
- * aggregated values which are needed to compute an aggregation result.
- * The state of the function must be put into the accumulator.
- *
- * TODO: We have the plan to have the accumulator and return types of
- * functions dynamically provided by the users. This needs the refactoring
- * of the AggregateFunction interface with the code generation. We will remove
- * the [[Accumulator]] once codeGen for UDAGG is completed (FLINK-5813).
- */
-trait Accumulator
http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
index 4837139..3f4e5db 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
@@ -18,15 +18,15 @@
package org.apache.flink.table.functions.aggfunctions
import java.math.{BigDecimal, BigInteger}
-import java.util.{List => JList}
+import java.lang.{Iterable => JIterable}
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
import org.apache.flink.api.java.typeutils.TupleTypeInfo
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.table.functions.AggregateFunction
/** The initial accumulator for Integral Avg aggregate function */
-class IntegralAvgAccumulator extends JTuple2[Long, Long] with Accumulator {
+class IntegralAvgAccumulator extends JTuple2[Long, Long] {
f0 = 0L //sum
f1 = 0L //count
}
@@ -36,57 +36,51 @@ class IntegralAvgAccumulator extends JTuple2[Long, Long] with Accumulator {
*
* @tparam T the type for the aggregation result
*/
-abstract class IntegralAvgAggFunction[T] extends AggregateFunction[T] {
+abstract class IntegralAvgAggFunction[T] extends AggregateFunction[T, IntegralAvgAccumulator] {
- override def createAccumulator(): Accumulator = {
+ override def createAccumulator(): IntegralAvgAccumulator = {
new IntegralAvgAccumulator
}
- override def accumulate(accumulator: Accumulator, value: Any): Unit = {
+ def accumulate(acc: IntegralAvgAccumulator, value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[Number].longValue()
- val accum = accumulator.asInstanceOf[IntegralAvgAccumulator]
- accum.f0 += v
- accum.f1 += 1L
+ acc.f0 += v
+ acc.f1 += 1L
}
}
- override def retract(accumulator: Accumulator, value: Any): Unit = {
+ def retract(acc: IntegralAvgAccumulator, value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[Number].longValue()
- val accum = accumulator.asInstanceOf[IntegralAvgAccumulator]
- accum.f0 -= v
- accum.f1 -= 1L
+ acc.f0 -= v
+ acc.f1 -= 1L
}
}
- override def getValue(accumulator: Accumulator): T = {
- val accum = accumulator.asInstanceOf[IntegralAvgAccumulator]
- if (accum.f1 == 0) {
+ override def getValue(acc: IntegralAvgAccumulator): T = {
+ if (acc.f1 == 0) {
null.asInstanceOf[T]
} else {
- resultTypeConvert(accum.f0 / accum.f1)
+ resultTypeConvert(acc.f0 / acc.f1)
}
}
- override def merge(accumulators: JList[Accumulator]): Accumulator = {
- val ret = accumulators.get(0).asInstanceOf[IntegralAvgAccumulator]
- var i: Int = 1
- while (i < accumulators.size()) {
- val a = accumulators.get(i).asInstanceOf[IntegralAvgAccumulator]
- ret.f1 += a.f1
- ret.f0 += a.f0
- i += 1
+ def merge(acc: IntegralAvgAccumulator, its: JIterable[IntegralAvgAccumulator]): Unit = {
+ val iter = its.iterator()
+ while (iter.hasNext) {
+ val a = iter.next()
+ acc.f1 += a.f1
+ acc.f0 += a.f0
}
- ret
}
- override def resetAccumulator(accumulator: Accumulator): Unit = {
- accumulator.asInstanceOf[IntegralAvgAccumulator].f0 = 0L
- accumulator.asInstanceOf[IntegralAvgAccumulator].f1 = 0L
+ def resetAccumulator(acc: IntegralAvgAccumulator): Unit = {
+ acc.f0 = 0L
+ acc.f1 = 0L
}
- override def getAccumulatorType: TypeInformation[_] = {
+ def getAccumulatorType: TypeInformation[_] = {
new TupleTypeInfo(
new IntegralAvgAccumulator().getClass,
BasicTypeInfo.LONG_TYPE_INFO,
@@ -126,7 +120,7 @@ class IntAvgAggFunction extends IntegralAvgAggFunction[Int] {
/** The initial accumulator for Big Integral Avg aggregate function */
class BigIntegralAvgAccumulator
- extends JTuple2[BigInteger, Long] with Accumulator {
+ extends JTuple2[BigInteger, Long] {
f0 = BigInteger.ZERO //sum
f1 = 0L //count
}
@@ -136,57 +130,52 @@ class BigIntegralAvgAccumulator
*
* @tparam T the type for the aggregation result
*/
-abstract class BigIntegralAvgAggFunction[T] extends AggregateFunction[T] {
+abstract class BigIntegralAvgAggFunction[T]
+ extends AggregateFunction[T, BigIntegralAvgAccumulator] {
- override def createAccumulator(): Accumulator = {
+ override def createAccumulator(): BigIntegralAvgAccumulator = {
new BigIntegralAvgAccumulator
}
- override def accumulate(accumulator: Accumulator, value: Any): Unit = {
+ def accumulate(acc: BigIntegralAvgAccumulator, value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[Long]
- val a = accumulator.asInstanceOf[BigIntegralAvgAccumulator]
- a.f0 = a.f0.add(BigInteger.valueOf(v))
- a.f1 += 1L
+ acc.f0 = acc.f0.add(BigInteger.valueOf(v))
+ acc.f1 += 1L
}
}
- override def retract(accumulator: Accumulator, value: Any): Unit = {
+ def retract(acc: BigIntegralAvgAccumulator, value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[Long]
- val a = accumulator.asInstanceOf[BigIntegralAvgAccumulator]
- a.f0 = a.f0.subtract(BigInteger.valueOf(v))
- a.f1 -= 1L
+ acc.f0 = acc.f0.subtract(BigInteger.valueOf(v))
+ acc.f1 -= 1L
}
}
- override def getValue(accumulator: Accumulator): T = {
- val a = accumulator.asInstanceOf[BigIntegralAvgAccumulator]
- if (a.f1 == 0) {
+ override def getValue(acc: BigIntegralAvgAccumulator): T = {
+ if (acc.f1 == 0) {
null.asInstanceOf[T]
} else {
- resultTypeConvert(a.f0.divide(BigInteger.valueOf(a.f1)))
+ resultTypeConvert(acc.f0.divide(BigInteger.valueOf(acc.f1)))
}
}
- override def merge(accumulators: JList[Accumulator]): Accumulator = {
- val ret = accumulators.get(0).asInstanceOf[BigIntegralAvgAccumulator]
- var i: Int = 1
- while (i < accumulators.size()) {
- val a = accumulators.get(i).asInstanceOf[BigIntegralAvgAccumulator]
- ret.f1 += a.f1
- ret.f0 = ret.f0.add(a.f0)
- i += 1
+ def merge(acc: BigIntegralAvgAccumulator, its: JIterable[BigIntegralAvgAccumulator]): Unit = {
+ val iter = its.iterator()
+ while (iter.hasNext) {
+ val a = iter.next()
+ acc.f1 += a.f1
+ acc.f0 = acc.f0.add(a.f0)
}
- ret
}
- override def resetAccumulator(accumulator: Accumulator): Unit = {
- accumulator.asInstanceOf[BigIntegralAvgAccumulator].f0 = BigInteger.ZERO
- accumulator.asInstanceOf[BigIntegralAvgAccumulator].f1 = 0
+ def resetAccumulator(acc: BigIntegralAvgAccumulator): Unit = {
+ acc.f0 = BigInteger.ZERO
+ acc.f1 = 0
}
- override def getAccumulatorType: TypeInformation[_] = {
+ def getAccumulatorType: TypeInformation[_] = {
new TupleTypeInfo(
new BigIntegralAvgAccumulator().getClass,
BasicTypeInfo.BIG_INT_TYPE_INFO,
@@ -212,7 +201,7 @@ class LongAvgAggFunction extends BigIntegralAvgAggFunction[Long] {
}
/** The initial accumulator for Floating Avg aggregate function */
-class FloatingAvgAccumulator extends JTuple2[Double, Long] with Accumulator {
+class FloatingAvgAccumulator extends JTuple2[Double, Long] {
f0 = 0 //sum
f1 = 0L //count
}
@@ -222,57 +211,51 @@ class FloatingAvgAccumulator extends JTuple2[Double, Long] with Accumulator {
*
* @tparam T the type for the aggregation result
*/
-abstract class FloatingAvgAggFunction[T] extends AggregateFunction[T] {
+abstract class FloatingAvgAggFunction[T] extends AggregateFunction[T, FloatingAvgAccumulator] {
- override def createAccumulator(): Accumulator = {
+ override def createAccumulator(): FloatingAvgAccumulator = {
new FloatingAvgAccumulator
}
- override def accumulate(accumulator: Accumulator, value: Any): Unit = {
+ def accumulate(acc: FloatingAvgAccumulator, value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[Number].doubleValue()
- val accum = accumulator.asInstanceOf[FloatingAvgAccumulator]
- accum.f0 += v
- accum.f1 += 1L
+ acc.f0 += v
+ acc.f1 += 1L
}
}
- override def retract(accumulator: Accumulator, value: Any): Unit = {
+ def retract(acc: FloatingAvgAccumulator, value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[Number].doubleValue()
- val accum = accumulator.asInstanceOf[FloatingAvgAccumulator]
- accum.f0 -= v
- accum.f1 -= 1L
+ acc.f0 -= v
+ acc.f1 -= 1L
}
}
- override def getValue(accumulator: Accumulator): T = {
- val accum = accumulator.asInstanceOf[FloatingAvgAccumulator]
- if (accum.f1 == 0) {
+ override def getValue(acc: FloatingAvgAccumulator): T = {
+ if (acc.f1 == 0) {
null.asInstanceOf[T]
} else {
- resultTypeConvert(accum.f0 / accum.f1)
+ resultTypeConvert(acc.f0 / acc.f1)
}
}
- override def merge(accumulators: JList[Accumulator]): Accumulator = {
- val ret = accumulators.get(0).asInstanceOf[FloatingAvgAccumulator]
- var i: Int = 1
- while (i < accumulators.size()) {
- val a = accumulators.get(i).asInstanceOf[FloatingAvgAccumulator]
- ret.f1 += a.f1
- ret.f0 += a.f0
- i += 1
+ def merge(acc: FloatingAvgAccumulator, its: JIterable[FloatingAvgAccumulator]): Unit = {
+ val iter = its.iterator()
+ while (iter.hasNext) {
+ val a = iter.next()
+ acc.f1 += a.f1
+ acc.f0 += a.f0
}
- ret
}
- override def resetAccumulator(accumulator: Accumulator): Unit = {
- accumulator.asInstanceOf[FloatingAvgAccumulator].f0 = 0
- accumulator.asInstanceOf[FloatingAvgAccumulator].f1 = 0L
+ def resetAccumulator(acc: FloatingAvgAccumulator): Unit = {
+ acc.f0 = 0
+ acc.f1 = 0L
}
- override def getAccumulatorType: TypeInformation[_] = {
+ def getAccumulatorType: TypeInformation[_] = {
new TupleTypeInfo(
new FloatingAvgAccumulator().getClass,
BasicTypeInfo.DOUBLE_TYPE_INFO,
@@ -304,8 +287,7 @@ class DoubleAvgAggFunction extends FloatingAvgAggFunction[Double] {
}
/** The initial accumulator for Big Decimal Avg aggregate function */
-class DecimalAvgAccumulator
- extends JTuple2[BigDecimal, Long] with Accumulator {
+class DecimalAvgAccumulator extends JTuple2[BigDecimal, Long] {
f0 = BigDecimal.ZERO //sum
f1 = 0L //count
}
@@ -313,57 +295,51 @@ class DecimalAvgAccumulator
/**
* Base class for built-in Big Decimal Avg aggregate function
*/
-class DecimalAvgAggFunction extends AggregateFunction[BigDecimal] {
+class DecimalAvgAggFunction extends AggregateFunction[BigDecimal, DecimalAvgAccumulator] {
- override def createAccumulator(): Accumulator = {
+ override def createAccumulator(): DecimalAvgAccumulator = {
new DecimalAvgAccumulator
}
- override def accumulate(accumulator: Accumulator, value: Any): Unit = {
+ def accumulate(acc: DecimalAvgAccumulator, value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[BigDecimal]
- val accum = accumulator.asInstanceOf[DecimalAvgAccumulator]
- accum.f0 = accum.f0.add(v)
- accum.f1 += 1L
+ acc.f0 = acc.f0.add(v)
+ acc.f1 += 1L
}
}
- override def retract(accumulator: Accumulator, value: Any): Unit = {
+ def retract(acc: DecimalAvgAccumulator, value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[BigDecimal]
- val accum = accumulator.asInstanceOf[DecimalAvgAccumulator]
- accum.f0 = accum.f0.subtract(v)
- accum.f1 -= 1L
+ acc.f0 = acc.f0.subtract(v)
+ acc.f1 -= 1L
}
}
- override def getValue(accumulator: Accumulator): BigDecimal = {
- val a = accumulator.asInstanceOf[DecimalAvgAccumulator]
- if (a.f1 == 0) {
+ override def getValue(acc: DecimalAvgAccumulator): BigDecimal = {
+ if (acc.f1 == 0) {
null.asInstanceOf[BigDecimal]
} else {
- a.f0.divide(BigDecimal.valueOf(a.f1))
+ acc.f0.divide(BigDecimal.valueOf(acc.f1))
}
}
- override def merge(accumulators: JList[Accumulator]): Accumulator = {
- val ret = accumulators.get(0).asInstanceOf[DecimalAvgAccumulator]
- var i: Int = 1
- while (i < accumulators.size()) {
- val a = accumulators.get(i).asInstanceOf[DecimalAvgAccumulator]
- ret.f0 = ret.f0.add(a.f0)
- ret.f1 += a.f1
- i += 1
+ def merge(acc: DecimalAvgAccumulator, its: JIterable[DecimalAvgAccumulator]): Unit = {
+ val iter = its.iterator()
+ while (iter.hasNext) {
+ val a = iter.next()
+ acc.f0 = acc.f0.add(a.f0)
+ acc.f1 += a.f1
}
- ret
}
- override def resetAccumulator(accumulator: Accumulator): Unit = {
- accumulator.asInstanceOf[DecimalAvgAccumulator].f0 = BigDecimal.ZERO
- accumulator.asInstanceOf[DecimalAvgAccumulator].f1 = 0L
+ def resetAccumulator(acc: DecimalAvgAccumulator): Unit = {
+ acc.f0 = BigDecimal.ZERO
+ acc.f1 = 0L
}
- override def getAccumulatorType: TypeInformation[_] = {
+ def getAccumulatorType: TypeInformation[_] = {
new TupleTypeInfo(
new DecimalAvgAccumulator().getClass,
BasicTypeInfo.BIG_DEC_TYPE_INFO,
http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
index 231337a..77341cd 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
@@ -17,58 +17,55 @@
*/
package org.apache.flink.table.functions.aggfunctions
-import java.util.{List => JList}
+import java.lang.{Iterable => JIterable}
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple1 => JTuple1}
import org.apache.flink.api.java.typeutils.TupleTypeInfo
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.table.functions.AggregateFunction
/** The initial accumulator for count aggregate function */
-class CountAccumulator extends JTuple1[Long] with Accumulator {
+class CountAccumulator extends JTuple1[Long] {
f0 = 0L //count
}
/**
* built-in count aggregate function
*/
-class CountAggFunction extends AggregateFunction[Long] {
+class CountAggFunction extends AggregateFunction[Long, CountAccumulator] {
- override def accumulate(accumulator: Accumulator, value: Any): Unit = {
+ def accumulate(acc: CountAccumulator, value: Any): Unit = {
if (value != null) {
- accumulator.asInstanceOf[CountAccumulator].f0 += 1L
+ acc.f0 += 1L
}
}
- override def retract(accumulator: Accumulator, value: Any): Unit = {
+ def retract(acc: CountAccumulator, value: Any): Unit = {
if (value != null) {
- accumulator.asInstanceOf[CountAccumulator].f0 -= 1L
+ acc.f0 -= 1L
}
}
- override def getValue(accumulator: Accumulator): Long = {
- accumulator.asInstanceOf[CountAccumulator].f0
+ override def getValue(acc: CountAccumulator): Long = {
+ acc.f0
}
- override def merge(accumulators: JList[Accumulator]): Accumulator = {
- val ret = accumulators.get(0).asInstanceOf[CountAccumulator]
- var i: Int = 1
- while (i < accumulators.size()) {
- ret.f0 += accumulators.get(i).asInstanceOf[CountAccumulator].f0
- i += 1
+ def merge(acc: CountAccumulator, its: JIterable[CountAccumulator]): Unit = {
+ val iter = its.iterator()
+ while (iter.hasNext) {
+ acc.f0 += iter.next().f0
}
- ret
}
- override def createAccumulator(): Accumulator = {
+ override def createAccumulator(): CountAccumulator = {
new CountAccumulator
}
- override def resetAccumulator(accumulator: Accumulator): Unit = {
- accumulator.asInstanceOf[CountAccumulator].f0 = 0L
+ def resetAccumulator(acc: CountAccumulator): Unit = {
+ acc.f0 = 0L
}
- override def getAccumulatorType(): TypeInformation[_] = {
+ def getAccumulatorType(): TypeInformation[_] = {
new TupleTypeInfo((new CountAccumulator).getClass, BasicTypeInfo.LONG_TYPE_INFO)
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
index 2e666fa..96ee8d1 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
@@ -18,69 +18,65 @@
package org.apache.flink.table.functions.aggfunctions
import java.math.BigDecimal
-import java.util.{List => JList}
+import java.lang.{Iterable => JIterable}
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
import org.apache.flink.api.java.typeutils.TupleTypeInfo
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.table.functions.AggregateFunction
/** The initial accumulator for Max aggregate function */
-class MaxAccumulator[T] extends JTuple2[T, Boolean] with Accumulator
+class MaxAccumulator[T] extends JTuple2[T, Boolean]
/**
* Base class for built-in Max aggregate function
*
* @tparam T the type for the aggregation result
*/
-abstract class MaxAggFunction[T](implicit ord: Ordering[T]) extends AggregateFunction[T] {
+abstract class MaxAggFunction[T](implicit ord: Ordering[T])
+ extends AggregateFunction[T, MaxAccumulator[T]] {
- override def createAccumulator(): Accumulator = {
+ override def createAccumulator(): MaxAccumulator[T] = {
val acc = new MaxAccumulator[T]
acc.f0 = getInitValue
acc.f1 = false
acc
}
- override def accumulate(accumulator: Accumulator, value: Any): Unit = {
+ def accumulate(acc: MaxAccumulator[T], value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[T]
- val a = accumulator.asInstanceOf[MaxAccumulator[T]]
- if (!a.f1 || ord.compare(a.f0, v) < 0) {
- a.f0 = v
- a.f1 = true
+ if (!acc.f1 || ord.compare(acc.f0, v) < 0) {
+ acc.f0 = v
+ acc.f1 = true
}
}
}
- override def getValue(accumulator: Accumulator): T = {
- val a = accumulator.asInstanceOf[MaxAccumulator[T]]
- if (a.f1) {
- a.f0
+ override def getValue(acc: MaxAccumulator[T]): T = {
+ if (acc.f1) {
+ acc.f0
} else {
null.asInstanceOf[T]
}
}
- override def merge(accumulators: JList[Accumulator]): Accumulator = {
- val ret = accumulators.get(0)
- var i: Int = 1
- while (i < accumulators.size()) {
- val a = accumulators.get(i).asInstanceOf[MaxAccumulator[T]]
+ def merge(acc: MaxAccumulator[T], its: JIterable[MaxAccumulator[T]]): Unit = {
+ val iter = its.iterator()
+ while (iter.hasNext) {
+ val a = iter.next()
if (a.f1) {
- accumulate(ret.asInstanceOf[MaxAccumulator[T]], a.f0)
+ accumulate(acc, a.f0)
}
- i += 1
}
- ret
}
- override def resetAccumulator(accumulator: Accumulator): Unit = {
- accumulator.asInstanceOf[MaxAccumulator[T]].f0 = getInitValue
- accumulator.asInstanceOf[MaxAccumulator[T]].f1 = false
+ def resetAccumulator(acc: MaxAccumulator[T]): Unit = {
+ acc.f0 = getInitValue
+ acc.f1 = false
}
- override def getAccumulatorType(): TypeInformation[_] = {
+ def getAccumulatorType(): TypeInformation[_] = {
new TupleTypeInfo(
new MaxAccumulator[T].getClass,
getValueTypeInfo,
http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala
index 14ceba2..6f18739 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala
@@ -18,15 +18,16 @@
package org.apache.flink.table.functions.aggfunctions
import java.math.BigDecimal
-import java.util.{HashMap => JHashMap, List => JList}
+import java.util.{HashMap => JHashMap}
+import java.lang.{Iterable => JIterable}
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
import org.apache.flink.api.java.typeutils.{MapTypeInfo, TupleTypeInfo}
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.table.functions.AggregateFunction
/** The initial accumulator for Max with retraction aggregate function */
-class MaxWithRetractAccumulator[T] extends JTuple2[T, JHashMap[T, Long]] with Accumulator
+class MaxWithRetractAccumulator[T] extends JTuple2[T, JHashMap[T, Long]]
/**
* Base class for built-in Max with retraction aggregate function
@@ -34,110 +35,105 @@ class MaxWithRetractAccumulator[T] extends JTuple2[T, JHashMap[T, Long]] with Ac
* @tparam T the type for the aggregation result
*/
abstract class MaxWithRetractAggFunction[T](implicit ord: Ordering[T])
- extends AggregateFunction[T] {
+ extends AggregateFunction[T, MaxWithRetractAccumulator[T]] {
- override def createAccumulator(): Accumulator = {
+ override def createAccumulator(): MaxWithRetractAccumulator[T] = {
val acc = new MaxWithRetractAccumulator[T]
acc.f0 = getInitValue //max
acc.f1 = new JHashMap[T, Long]() //store the count for each value
acc
}
- override def accumulate(accumulator: Accumulator, value: Any): Unit = {
+ def accumulate(acc: MaxWithRetractAccumulator[T], value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[T]
- val a = accumulator.asInstanceOf[MaxWithRetractAccumulator[T]]
- if (a.f1.size() == 0 || (ord.compare(a.f0, v) < 0)) {
- a.f0 = v
+ if (acc.f1.size() == 0 || (ord.compare(acc.f0, v) < 0)) {
+ acc.f0 = v
}
- if (!a.f1.containsKey(v)) {
- a.f1.put(v, 1L)
+ if (!acc.f1.containsKey(v)) {
+ acc.f1.put(v, 1L)
} else {
- var count = a.f1.get(v)
+ var count = acc.f1.get(v)
count += 1L
- a.f1.put(v, count)
+ acc.f1.put(v, count)
}
}
}
- override def retract(accumulator: Accumulator, value: Any): Unit = {
+ def retract(acc: MaxWithRetractAccumulator[T], value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[T]
- val a = accumulator.asInstanceOf[MaxWithRetractAccumulator[T]]
- var count = a.f1.get(v)
+ var count = acc.f1.get(v)
count -= 1L
if (count == 0) {
//remove the key v from the map if the number of appearance of the value v is 0
- a.f1.remove(v)
+ acc.f1.remove(v)
//if the total count is 0, we could just simply set the f0(max) to the initial value
- if (a.f1.size() == 0) {
- a.f0 = getInitValue
+ if (acc.f1.size() == 0) {
+ acc.f0 = getInitValue
return
}
//if v is the current max value, we have to iterate the map to find the 2nd biggest
// value to replace v as the max value
- if (v == a.f0) {
- val iterator = a.f1.keySet().iterator()
+ if (v == acc.f0) {
+ val iterator = acc.f1.keySet().iterator()
var key = iterator.next()
- a.f0 = key
+ acc.f0 = key
while (iterator.hasNext()) {
key = iterator.next()
- if (ord.compare(a.f0, key) < 0) {
- a.f0 = key
+ if (ord.compare(acc.f0, key) < 0) {
+ acc.f0 = key
}
}
}
} else {
- a.f1.put(v, count)
+ acc.f1.put(v, count)
}
}
}
- override def getValue(accumulator: Accumulator): T = {
- val a = accumulator.asInstanceOf[MaxWithRetractAccumulator[T]]
- if (a.f1.size() != 0) {
- a.f0
+ override def getValue(acc: MaxWithRetractAccumulator[T]): T = {
+ if (acc.f1.size() != 0) {
+ acc.f0
} else {
null.asInstanceOf[T]
}
}
- override def merge(accumulators: JList[Accumulator]): Accumulator = {
- val ret = accumulators.get(0).asInstanceOf[MaxWithRetractAccumulator[T]]
- var i: Int = 1
- while (i < accumulators.size()) {
- val a = accumulators.get(i).asInstanceOf[MaxWithRetractAccumulator[T]]
+ def merge(acc: MaxWithRetractAccumulator[T],
+ its: JIterable[MaxWithRetractAccumulator[T]]): Unit = {
+ val iter = its.iterator()
+ while (iter.hasNext) {
+ val a = iter.next()
if (a.f1.size() != 0) {
// set max element
- if (ord.compare(ret.f0, a.f0) < 0) {
- ret.f0 = a.f0
+ if (ord.compare(acc.f0, a.f0) < 0) {
+ acc.f0 = a.f0
}
// merge the count for each key
val iterator = a.f1.keySet().iterator()
while (iterator.hasNext()) {
val key = iterator.next()
- if (ret.f1.containsKey(key)) {
- ret.f1.put(key, ret.f1.get(key) + a.f1.get(key))
+ if (acc.f1.containsKey(key)) {
+ acc.f1.put(key, acc.f1.get(key) + a.f1.get(key))
} else {
- ret.f1.put(key, a.f1.get(key))
+ acc.f1.put(key, a.f1.get(key))
}
}
}
- i += 1
}
- ret
}
- override def resetAccumulator(accumulator: Accumulator): Unit = {
- accumulator.asInstanceOf[MaxWithRetractAccumulator[T]].f0 = getInitValue
- accumulator.asInstanceOf[MaxWithRetractAccumulator[T]].f1.clear()
+ def resetAccumulator(acc: MaxWithRetractAccumulator[T]): Unit = {
+ acc.f0 = getInitValue
+ acc.f1.clear()
}
- override def getAccumulatorType(): TypeInformation[_] = {
+ def getAccumulatorType(): TypeInformation[_] = {
new TupleTypeInfo(
new MaxWithRetractAccumulator[T].getClass,
getValueTypeInfo,
http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
index 75a8ebc..88d7afd 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
@@ -18,69 +18,65 @@
package org.apache.flink.table.functions.aggfunctions
import java.math.BigDecimal
-import java.util.{List => JList}
+import java.lang.{Iterable => JIterable}
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
import org.apache.flink.api.java.typeutils.TupleTypeInfo
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.table.functions.AggregateFunction
/** The initial accumulator for Min aggregate function */
-class MinAccumulator[T] extends JTuple2[T, Boolean] with Accumulator
+class MinAccumulator[T] extends JTuple2[T, Boolean]
/**
* Base class for built-in Min aggregate function
*
* @tparam T the type for the aggregation result
*/
-abstract class MinAggFunction[T](implicit ord: Ordering[T]) extends AggregateFunction[T] {
+abstract class MinAggFunction[T](implicit ord: Ordering[T])
+ extends AggregateFunction[T, MinAccumulator[T]] {
- override def createAccumulator(): Accumulator = {
+ override def createAccumulator(): MinAccumulator[T] = {
val acc = new MinAccumulator[T]
acc.f0 = getInitValue
acc.f1 = false
acc
}
- override def accumulate(accumulator: Accumulator, value: Any): Unit = {
+ def accumulate(acc: MinAccumulator[T], value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[T]
- val a = accumulator.asInstanceOf[MinAccumulator[T]]
- if (!a.f1 || ord.compare(a.f0, v) > 0) {
- a.f0 = v
- a.f1 = true
+ if (!acc.f1 || ord.compare(acc.f0, v) > 0) {
+ acc.f0 = v
+ acc.f1 = true
}
}
}
- override def getValue(accumulator: Accumulator): T = {
- val a = accumulator.asInstanceOf[MinAccumulator[T]]
- if (a.f1) {
- a.f0
+ override def getValue(acc: MinAccumulator[T]): T = {
+ if (acc.f1) {
+ acc.f0
} else {
null.asInstanceOf[T]
}
}
- override def merge(accumulators: JList[Accumulator]): Accumulator = {
- val ret = accumulators.get(0)
- var i: Int = 1
- while (i < accumulators.size()) {
- val a = accumulators.get(i).asInstanceOf[MinAccumulator[T]]
+ def merge(acc: MinAccumulator[T], its: JIterable[MinAccumulator[T]]): Unit = {
+ val iter = its.iterator()
+ while (iter.hasNext) {
+ val a = iter.next()
if (a.f1) {
- accumulate(ret.asInstanceOf[MinAccumulator[T]], a.f0)
+ accumulate(acc, a.f0)
}
- i += 1
}
- ret
}
- override def resetAccumulator(accumulator: Accumulator): Unit = {
- accumulator.asInstanceOf[MinAccumulator[T]].f0 = getInitValue
- accumulator.asInstanceOf[MinAccumulator[T]].f1 = false
+ def resetAccumulator(acc: MinAccumulator[T]): Unit = {
+ acc.f0 = getInitValue
+ acc.f1 = false
}
- override def getAccumulatorType(): TypeInformation[_] = {
+ def getAccumulatorType(): TypeInformation[_] = {
new TupleTypeInfo(
new MinAccumulator[T].getClass,
getValueTypeInfo,
http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala
index 6f2c3a1..2d3348b 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala
@@ -18,15 +18,16 @@
package org.apache.flink.table.functions.aggfunctions
import java.math.BigDecimal
-import java.util.{HashMap => JHashMap, List => JList}
+import java.util.{HashMap => JHashMap}
+import java.lang.{Iterable => JIterable}
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
import org.apache.flink.api.java.typeutils.{MapTypeInfo, TupleTypeInfo}
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.table.functions.AggregateFunction
/** The initial accumulator for Min with retraction aggregate function */
-class MinWithRetractAccumulator[T] extends JTuple2[T, JHashMap[T, Long]] with Accumulator
+class MinWithRetractAccumulator[T] extends JTuple2[T, JHashMap[T, Long]]
/**
* Base class for built-in Min with retraction aggregate function
@@ -34,110 +35,105 @@ class MinWithRetractAccumulator[T] extends JTuple2[T, JHashMap[T, Long]] with Ac
* @tparam T the type for the aggregation result
*/
abstract class MinWithRetractAggFunction[T](implicit ord: Ordering[T])
- extends AggregateFunction[T] {
+ extends AggregateFunction[T, MinWithRetractAccumulator[T]] {
- override def createAccumulator(): Accumulator = {
+ override def createAccumulator(): MinWithRetractAccumulator[T] = {
val acc = new MinWithRetractAccumulator[T]
acc.f0 = getInitValue //min
acc.f1 = new JHashMap[T, Long]() //store the count for each value
acc
}
- override def accumulate(accumulator: Accumulator, value: Any): Unit = {
+ def accumulate(acc: MinWithRetractAccumulator[T], value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[T]
- val a = accumulator.asInstanceOf[MinWithRetractAccumulator[T]]
- if (a.f1.size() == 0 || (ord.compare(a.f0, v) > 0)) {
- a.f0 = v
+ if (acc.f1.size() == 0 || (ord.compare(acc.f0, v) > 0)) {
+ acc.f0 = v
}
- if (!a.f1.containsKey(v)) {
- a.f1.put(v, 1L)
+ if (!acc.f1.containsKey(v)) {
+ acc.f1.put(v, 1L)
} else {
- var count = a.f1.get(v)
+ var count = acc.f1.get(v)
count += 1L
- a.f1.put(v, count)
+ acc.f1.put(v, count)
}
}
}
- override def retract(accumulator: Accumulator, value: Any): Unit = {
+ def retract(acc: MinWithRetractAccumulator[T], value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[T]
- val a = accumulator.asInstanceOf[MinWithRetractAccumulator[T]]
- var count = a.f1.get(v)
+ var count = acc.f1.get(v)
count -= 1L
if (count == 0) {
//remove the key v from the map if the number of appearance of the value v is 0
- a.f1.remove(v)
+ acc.f1.remove(v)
//if the total count is 0, we could just simply set the f0(min) to the initial value
- if (a.f1.size() == 0) {
- a.f0 = getInitValue
+ if (acc.f1.size() == 0) {
+ acc.f0 = getInitValue
return
}
//if v is the current min value, we have to iterate the map to find the 2nd smallest
// value to replace v as the min value
- if (v == a.f0) {
- val iterator = a.f1.keySet().iterator()
+ if (v == acc.f0) {
+ val iterator = acc.f1.keySet().iterator()
var key = iterator.next()
- a.f0 = key
+ acc.f0 = key
while (iterator.hasNext()) {
key = iterator.next()
- if (ord.compare(a.f0, key) > 0) {
- a.f0 = key
+ if (ord.compare(acc.f0, key) > 0) {
+ acc.f0 = key
}
}
}
} else {
- a.f1.put(v, count)
+ acc.f1.put(v, count)
}
}
}
- override def getValue(accumulator: Accumulator): T = {
- val a = accumulator.asInstanceOf[MinWithRetractAccumulator[T]]
- if (a.f1.size() != 0) {
- a.f0
+ override def getValue(acc: MinWithRetractAccumulator[T]): T = {
+ if (acc.f1.size() != 0) {
+ acc.f0
} else {
null.asInstanceOf[T]
}
}
- override def merge(accumulators: JList[Accumulator]): Accumulator = {
- val ret = accumulators.get(0).asInstanceOf[MinWithRetractAccumulator[T]]
- var i: Int = 1
- while (i < accumulators.size()) {
- val a = accumulators.get(i).asInstanceOf[MinWithRetractAccumulator[T]]
+ def merge(acc: MinWithRetractAccumulator[T],
+ its: JIterable[MinWithRetractAccumulator[T]]): Unit = {
+ val iter = its.iterator()
+ while (iter.hasNext) {
+ val a = iter.next()
if (a.f1.size() != 0) {
// set min element
- if (ord.compare(ret.f0, a.f0) > 0) {
- ret.f0 = a.f0
+ if (ord.compare(acc.f0, a.f0) > 0) {
+ acc.f0 = a.f0
}
// merge the count for each key
val iterator = a.f1.keySet().iterator()
while (iterator.hasNext()) {
val key = iterator.next()
- if (ret.f1.containsKey(key)) {
- ret.f1.put(key, ret.f1.get(key) + a.f1.get(key))
+ if (acc.f1.containsKey(key)) {
+ acc.f1.put(key, acc.f1.get(key) + a.f1.get(key))
} else {
- ret.f1.put(key, a.f1.get(key))
+ acc.f1.put(key, a.f1.get(key))
}
}
}
- i += 1
}
- ret
}
- override def resetAccumulator(accumulator: Accumulator): Unit = {
- accumulator.asInstanceOf[MinWithRetractAccumulator[T]].f0 = getInitValue
- accumulator.asInstanceOf[MinWithRetractAccumulator[T]].f1.clear()
+ def resetAccumulator(acc: MinWithRetractAccumulator[T]): Unit = {
+ acc.f0 = getInitValue
+ acc.f1.clear()
}
- override def getAccumulatorType(): TypeInformation[_] = {
+ def getAccumulatorType(): TypeInformation[_] = {
new TupleTypeInfo(
new MinWithRetractAccumulator[T].getClass,
getValueTypeInfo,
http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
index 8ee9862..55996ac 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
@@ -18,70 +18,65 @@
package org.apache.flink.table.functions.aggfunctions
import java.math.BigDecimal
-import java.util.{List => JList}
+import java.lang.{Iterable => JIterable}
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
import org.apache.flink.api.java.typeutils.TupleTypeInfo
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.table.functions.AggregateFunction
/** The initial accumulator for Sum aggregate function */
-class SumAccumulator[T] extends JTuple2[T, Boolean] with Accumulator
+class SumAccumulator[T] extends JTuple2[T, Boolean]
/**
* Base class for built-in Sum aggregate function
*
* @tparam T the type for the aggregation result
*/
-abstract class SumAggFunction[T: Numeric] extends AggregateFunction[T] {
+abstract class SumAggFunction[T: Numeric] extends AggregateFunction[T, SumAccumulator[T]] {
private val numeric = implicitly[Numeric[T]]
- override def createAccumulator(): Accumulator = {
+ override def createAccumulator(): SumAccumulator[T] = {
val acc = new SumAccumulator[T]()
acc.f0 = numeric.zero //sum
acc.f1 = false
acc
}
- override def accumulate(accumulator: Accumulator, value: Any): Unit = {
+ def accumulate(accumulator: SumAccumulator[T], value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[T]
- val a = accumulator.asInstanceOf[SumAccumulator[T]]
- a.f0 = numeric.plus(v, a.f0)
- a.f1 = true
+ accumulator.f0 = numeric.plus(v, accumulator.f0)
+ accumulator.f1 = true
}
}
- override def getValue(accumulator: Accumulator): T = {
- val a = accumulator.asInstanceOf[SumAccumulator[T]]
- if (a.f1) {
- a.f0
+ override def getValue(accumulator: SumAccumulator[T]): T = {
+ if (accumulator.f1) {
+ accumulator.f0
} else {
null.asInstanceOf[T]
}
}
- override def merge(accumulators: JList[Accumulator]): Accumulator = {
- val ret = accumulators.get(0).asInstanceOf[SumAccumulator[T]]
- var i: Int = 1
- while (i < accumulators.size()) {
- val a = accumulators.get(i).asInstanceOf[SumAccumulator[T]]
+ def merge(acc: SumAccumulator[T], its: JIterable[SumAccumulator[T]]): Unit = {
+ val iter = its.iterator()
+ while (iter.hasNext) {
+ val a = iter.next()
if (a.f1) {
- ret.f0 = numeric.plus(ret.f0, a.f0)
- ret.f1 = true
+ acc.f0 = numeric.plus(acc.f0, a.f0)
+ acc.f1 = true
}
- i += 1
}
- ret
}
- override def resetAccumulator(accumulator: Accumulator): Unit = {
- accumulator.asInstanceOf[SumAccumulator[T]].f0 = numeric.zero
- accumulator.asInstanceOf[SumAccumulator[T]].f1 = false
+ def resetAccumulator(acc: SumAccumulator[T]): Unit = {
+ acc.f0 = numeric.zero
+ acc.f1 = false
}
- override def getAccumulatorType(): TypeInformation[_] = {
+ def getAccumulatorType(): TypeInformation[_] = {
new TupleTypeInfo(
(new SumAccumulator).getClass,
getValueTypeInfo,
@@ -134,7 +129,7 @@ class DoubleSumAggFunction extends SumAggFunction[Double] {
}
/** The initial accumulator for Big Decimal Sum aggregate function */
-class DecimalSumAccumulator extends JTuple2[BigDecimal, Boolean] with Accumulator {
+class DecimalSumAccumulator extends JTuple2[BigDecimal, Boolean] {
f0 = BigDecimal.ZERO
f1 = false
}
@@ -142,49 +137,45 @@ class DecimalSumAccumulator extends JTuple2[BigDecimal, Boolean] with Accumulato
/**
* Built-in Big Decimal Sum aggregate function
*/
-class DecimalSumAggFunction extends AggregateFunction[BigDecimal] {
+class DecimalSumAggFunction extends AggregateFunction[BigDecimal, DecimalSumAccumulator] {
- override def createAccumulator(): Accumulator = {
+ override def createAccumulator(): DecimalSumAccumulator = {
new DecimalSumAccumulator
}
- override def accumulate(accumulator: Accumulator, value: Any): Unit = {
+ def accumulate(acc: DecimalSumAccumulator, value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[BigDecimal]
- val accum = accumulator.asInstanceOf[DecimalSumAccumulator]
- accum.f0 = accum.f0.add(v)
- accum.f1 = true
+ acc.f0 = acc.f0.add(v)
+ acc.f1 = true
}
}
- override def getValue(accumulator: Accumulator): BigDecimal = {
- if (!accumulator.asInstanceOf[DecimalSumAccumulator].f1) {
+ override def getValue(acc: DecimalSumAccumulator): BigDecimal = {
+ if (!acc.f1) {
null.asInstanceOf[BigDecimal]
} else {
- accumulator.asInstanceOf[DecimalSumAccumulator].f0
+ acc.f0
}
}
- override def merge(accumulators: JList[Accumulator]): Accumulator = {
- val ret = accumulators.get(0).asInstanceOf[DecimalSumAccumulator]
- var i: Int = 1
- while (i < accumulators.size()) {
- val a = accumulators.get(i).asInstanceOf[DecimalSumAccumulator]
+ def merge(acc: DecimalSumAccumulator, its: JIterable[DecimalSumAccumulator]): Unit = {
+ val iter = its.iterator()
+ while (iter.hasNext) {
+ val a = iter.next()
if (a.f1) {
- ret.f0 = ret.f0.add(a.f0)
- ret.f1 = true
+ acc.f0 = acc.f0.add(a.f0)
+ acc.f1 = true
}
- i += 1
}
- ret
}
- override def resetAccumulator(accumulator: Accumulator): Unit = {
- accumulator.asInstanceOf[DecimalSumAccumulator].f0 = BigDecimal.ZERO
- accumulator.asInstanceOf[DecimalSumAccumulator].f1 = false
+ def resetAccumulator(acc: DecimalSumAccumulator): Unit = {
+ acc.f0 = BigDecimal.ZERO
+ acc.f1 = false
}
- override def getAccumulatorType(): TypeInformation[_] = {
+ def getAccumulatorType(): TypeInformation[_] = {
new TupleTypeInfo(
(new DecimalSumAccumulator).getClass,
BasicTypeInfo.BIG_DEC_TYPE_INFO,
http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala
index 928be11..7f68d11 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala
@@ -18,77 +18,73 @@
package org.apache.flink.table.functions.aggfunctions
import java.math.BigDecimal
-import java.util.{List => JList}
+import java.lang.{Iterable => JIterable}
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
import org.apache.flink.api.java.typeutils.TupleTypeInfo
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.table.functions.AggregateFunction
/** The initial accumulator for Sum with retract aggregate function */
-class SumWithRetractAccumulator[T] extends JTuple2[T, Long] with Accumulator
+class SumWithRetractAccumulator[T] extends JTuple2[T, Long]
/**
* Base class for built-in Sum with retract aggregate function
*
* @tparam T the type for the aggregation result
*/
-abstract class SumWithRetractAggFunction[T: Numeric] extends AggregateFunction[T] {
+abstract class SumWithRetractAggFunction[T: Numeric]
+ extends AggregateFunction[T, SumWithRetractAccumulator[T]] {
private val numeric = implicitly[Numeric[T]]
- override def createAccumulator(): Accumulator = {
+ override def createAccumulator(): SumWithRetractAccumulator[T] = {
val acc = new SumWithRetractAccumulator[T]()
acc.f0 = numeric.zero //sum
acc.f1 = 0L //total count
acc
}
- override def accumulate(accumulator: Accumulator, value: Any): Unit = {
+ def accumulate(acc: SumWithRetractAccumulator[T], value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[T]
- val a = accumulator.asInstanceOf[SumWithRetractAccumulator[T]]
- a.f0 = numeric.plus(a.f0, v)
- a.f1 += 1
+ acc.f0 = numeric.plus(acc.f0, v)
+ acc.f1 += 1
}
}
- override def retract(accumulator: Accumulator, value: Any): Unit = {
+ def retract(acc: SumWithRetractAccumulator[T], value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[T]
- val a = accumulator.asInstanceOf[SumWithRetractAccumulator[T]]
- a.f0 = numeric.minus(a.f0, v)
- a.f1 -= 1
+ acc.f0 = numeric.minus(acc.f0, v)
+ acc.f1 -= 1
}
}
- override def getValue(accumulator: Accumulator): T = {
- val a = accumulator.asInstanceOf[SumWithRetractAccumulator[T]]
- if (a.f1 > 0) {
- a.f0
+ override def getValue(acc: SumWithRetractAccumulator[T]): T = {
+ if (acc.f1 > 0) {
+ acc.f0
} else {
null.asInstanceOf[T]
}
}
- override def merge(accumulators: JList[Accumulator]): Accumulator = {
- val ret = accumulators.get(0).asInstanceOf[SumWithRetractAccumulator[T]]
- var i: Int = 1
- while (i < accumulators.size()) {
- val a = accumulators.get(i).asInstanceOf[SumWithRetractAccumulator[T]]
- ret.f0 = numeric.plus(ret.f0, a.f0)
- ret.f1 += a.f1
- i += 1
+ def merge(acc: SumWithRetractAccumulator[T],
+ its: JIterable[SumWithRetractAccumulator[T]]): Unit = {
+ val iter = its.iterator()
+ while (iter.hasNext) {
+ val a = iter.next()
+ acc.f0 = numeric.plus(acc.f0, a.f0)
+ acc.f1 += a.f1
}
- ret
}
- override def resetAccumulator(accumulator: Accumulator): Unit = {
- accumulator.asInstanceOf[SumWithRetractAccumulator[T]].f0 = numeric.zero
- accumulator.asInstanceOf[SumWithRetractAccumulator[T]].f1 = 0L
+ def resetAccumulator(acc: SumWithRetractAccumulator[T]): Unit = {
+ acc.f0 = numeric.zero
+ acc.f1 = 0L
}
- override def getAccumulatorType(): TypeInformation[_] = {
+ def getAccumulatorType(): TypeInformation[_] = {
new TupleTypeInfo(
(new SumWithRetractAccumulator).getClass,
getValueTypeInfo,
@@ -141,7 +137,7 @@ class DoubleSumWithRetractAggFunction extends SumWithRetractAggFunction[Double]
}
/** The initial accumulator for Big Decimal Sum with retract aggregate function */
-class DecimalSumWithRetractAccumulator extends JTuple2[BigDecimal, Long] with Accumulator {
+class DecimalSumWithRetractAccumulator extends JTuple2[BigDecimal, Long] {
f0 = BigDecimal.ZERO
f1 = 0L
}
@@ -149,56 +145,53 @@ class DecimalSumWithRetractAccumulator extends JTuple2[BigDecimal, Long] with Ac
/**
* Built-in Big Decimal Sum with retract aggregate function
*/
-class DecimalSumWithRetractAggFunction extends AggregateFunction[BigDecimal] {
+class DecimalSumWithRetractAggFunction
+ extends AggregateFunction[BigDecimal, DecimalSumWithRetractAccumulator] {
- override def createAccumulator(): Accumulator = {
+ override def createAccumulator(): DecimalSumWithRetractAccumulator = {
new DecimalSumWithRetractAccumulator
}
- override def accumulate(accumulator: Accumulator, value: Any): Unit = {
+ def accumulate(acc: DecimalSumWithRetractAccumulator, value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[BigDecimal]
- val accum = accumulator.asInstanceOf[DecimalSumWithRetractAccumulator]
- accum.f0 = accum.f0.add(v)
- accum.f1 += 1L
+ acc.f0 = acc.f0.add(v)
+ acc.f1 += 1L
}
}
- override def retract(accumulator: Accumulator, value: Any): Unit = {
+ def retract(acc: DecimalSumWithRetractAccumulator, value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[BigDecimal]
- val accum = accumulator.asInstanceOf[DecimalSumWithRetractAccumulator]
- accum.f0 = accum.f0.subtract(v)
- accum.f1 -= 1L
+ acc.f0 = acc.f0.subtract(v)
+ acc.f1 -= 1L
}
}
- override def getValue(accumulator: Accumulator): BigDecimal = {
- if (accumulator.asInstanceOf[DecimalSumWithRetractAccumulator].f1 == 0) {
+ override def getValue(acc: DecimalSumWithRetractAccumulator): BigDecimal = {
+ if (acc.f1 == 0) {
null.asInstanceOf[BigDecimal]
} else {
- accumulator.asInstanceOf[DecimalSumWithRetractAccumulator].f0
+ acc.f0
}
}
- override def merge(accumulators: JList[Accumulator]): Accumulator = {
- val ret = accumulators.get(0).asInstanceOf[DecimalSumWithRetractAccumulator]
- var i: Int = 1
- while (i < accumulators.size()) {
- val a = accumulators.get(i).asInstanceOf[DecimalSumWithRetractAccumulator]
- ret.f0 = ret.f0.add(a.f0)
- ret.f1 += a.f1
- i += 1
+ def merge(acc: DecimalSumWithRetractAccumulator,
+ its: JIterable[DecimalSumWithRetractAccumulator]): Unit = {
+ val iter = its.iterator()
+ while (iter.hasNext) {
+ val a = iter.next()
+ acc.f0 = acc.f0.add(a.f0)
+ acc.f1 += a.f1
}
- ret
}
- override def resetAccumulator(accumulator: Accumulator): Unit = {
- accumulator.asInstanceOf[DecimalSumWithRetractAccumulator].f0 = BigDecimal.ZERO
- accumulator.asInstanceOf[DecimalSumWithRetractAccumulator].f1 = 0L
+ def resetAccumulator(acc: DecimalSumWithRetractAccumulator): Unit = {
+ acc.f0 = BigDecimal.ZERO
+ acc.f1 = 0L
}
- override def getAccumulatorType(): TypeInformation[_] = {
+ def getAccumulatorType(): TypeInformation[_] = {
new TupleTypeInfo(
(new DecimalSumWithRetractAccumulator).getClass,
BasicTypeInfo.BIG_DEC_TYPE_INFO,
http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index 2c503c6..a82f383 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -17,6 +17,7 @@
*/
package org.apache.flink.table.runtime.aggregate
+import java.lang.reflect.Method
import java.util
import org.apache.calcite.rel.`type`._
@@ -73,11 +74,12 @@ object AggregateUtil {
isPartitioned: Boolean,
isRowsClause: Boolean): ProcessFunction[Row, Row] = {
+ val needRetract = false
val (aggFields, aggregates) =
transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
- needRetraction = false)
+ needRetract)
val aggregationStateType: RowTypeInfo =
createDataSetAggregateBufferDataType(Array(), aggregates, inputType)
@@ -97,7 +99,9 @@ object AggregateUtil {
forwardMapping,
None,
None,
- outputArity
+ outputArity,
+ needRetract,
+ needMerge = false
)
if (isRowTimeType) {
@@ -147,11 +151,12 @@ object AggregateUtil {
isRowsClause: Boolean,
isRowTimeType: Boolean): ProcessFunction[Row, Row] = {
+ val needRetract = true
val (aggFields, aggregates) =
transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
- needRetraction = true)
+ needRetract)
val aggregationStateType: RowTypeInfo = createAccumulatorRowType(aggregates)
val inputRowType = FlinkTypeFactory.toInternalRowTypeInfo(inputType).asInstanceOf[RowTypeInfo]
@@ -171,7 +176,9 @@ object AggregateUtil {
forwardMapping,
None,
None,
- outputArity
+ outputArity,
+ needRetract,
+ needMerge = false
)
if (isRowTimeType) {
@@ -239,10 +246,11 @@ object AggregateUtil {
isParserCaseSensitive: Boolean)
: MapFunction[Row, Row] = {
+ val needRetract = false
val (aggFieldIndexes, aggregates) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
- needRetraction = false)
+ needRetract)
val mapReturnType: RowTypeInfo =
createDataSetAggregateBufferDataType(
@@ -293,7 +301,9 @@ object AggregateUtil {
groupings,
None,
None,
- outputArity
+ outputArity,
+ needRetract,
+ needMerge = false
)
new DataSetWindowAggMapFunction(
@@ -339,10 +349,11 @@ object AggregateUtil {
isParserCaseSensitive: Boolean)
: RichGroupReduceFunction[Row, Row] = {
+ val needRetract = false
val (aggFieldIndexes, aggregates) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
- needRetraction = false)
+ needRetract)
val returnType: RowTypeInfo = createDataSetAggregateBufferDataType(
groupings,
@@ -366,7 +377,9 @@ object AggregateUtil {
groupings,
Some(aggregates.indices.map(_ + groupings.length).toArray),
None,
- keysAndAggregatesArity + 1
+ keysAndAggregatesArity + 1,
+ needRetract,
+ needMerge = true
)
new DataSetSlideTimeWindowAggReduceGroupFunction(
genFunction,
@@ -447,10 +460,11 @@ object AggregateUtil {
isInputCombined: Boolean = false)
: RichGroupReduceFunction[Row, Row] = {
+ val needRetract = false
val (aggFieldIndexes, aggregates) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
- needRetraction = false)
+ needRetract)
val aggMapping = aggregates.indices.toArray.map(_ + groupings.length)
@@ -465,7 +479,9 @@ object AggregateUtil {
groupings,
Some(aggregates.indices.map(_ + groupings.length).toArray),
None,
- outputType.getFieldCount
+ outputType.getFieldCount,
+ needRetract,
+ needMerge = true
)
val genFinalAggFunction = generator.generateAggregations(
@@ -479,7 +495,9 @@ object AggregateUtil {
groupings.indices.toArray,
Some(aggregates.indices.map(_ + groupings.length).toArray),
None,
- outputType.getFieldCount
+ outputType.getFieldCount,
+ needRetract,
+ needMerge = true
)
val keysAndAggregatesArity = groupings.length + namedAggregates.length
@@ -586,10 +604,11 @@ object AggregateUtil {
inputType: RelDataType,
groupings: Array[Int]): MapPartitionFunction[Row, Row] = {
+ val needRetract = false
val (aggFieldIndexes, aggregates) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
- needRetraction = false)
+ needRetract)
val aggMapping = aggregates.indices.map(_ + groupings.length).toArray
@@ -615,7 +634,9 @@ object AggregateUtil {
groupings.indices.toArray,
Some(aggregates.indices.map(_ + groupings.length).toArray),
None,
- groupings.length + aggregates.length + 2
+ groupings.length + aggregates.length + 2,
+ needRetract,
+ needMerge = true
)
new DataSetSessionWindowAggregatePreProcessor(
@@ -654,10 +675,11 @@ object AggregateUtil {
groupings: Array[Int])
: GroupCombineFunction[Row, Row] = {
+ val needRetract = false
val (aggFieldIndexes, aggregates) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
- needRetraction = false)
+ needRetract)
val aggMapping = aggregates.indices.map(_ + groupings.length).toArray
@@ -684,7 +706,9 @@ object AggregateUtil {
groupings.indices.toArray,
Some(aggregates.indices.map(_ + groupings.length).toArray),
None,
- groupings.length + aggregates.length + 2
+ groupings.length + aggregates.length + 2,
+ needRetract,
+ needMerge = true
)
new DataSetSessionWindowAggregatePreProcessor(
@@ -715,10 +739,11 @@ object AggregateUtil {
Option[TypeInformation[Row]],
RichGroupReduceFunction[Row, Row]) = {
+ val needRetract = false
val (aggInFields, aggregates) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
- needRetraction = false)
+ needRetract)
val (gkeyOutMapping, aggOutMapping) = getOutputMappings(
namedAggregates,
@@ -760,7 +785,9 @@ object AggregateUtil {
groupings,
None,
None,
- groupings.length + aggregates.length
+ groupings.length + aggregates.length,
+ needRetract,
+ needMerge = false
)
// compute mapping of forwarded grouping keys
@@ -784,7 +811,9 @@ object AggregateUtil {
gkeyMapping,
Some(aggregates.indices.map(_ + groupings.length).toArray),
constantFlags,
- outputType.getFieldCount
+ outputType.getFieldCount,
+ needRetract,
+ needMerge = true
)
(
@@ -805,7 +834,9 @@ object AggregateUtil {
groupings,
None,
constantFlags,
- outputType.getFieldCount
+ outputType.getFieldCount,
+ needRetract,
+ needMerge = false
)
(
@@ -874,11 +905,12 @@ object AggregateUtil {
outputType: RelDataType)
: (DataStreamAggFunction[Row, Row, Row], RowTypeInfo, RowTypeInfo) = {
+ val needRetract = false
val (aggFields, aggregates) =
transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
- needRetraction = false)
+ needRetract)
val aggMapping = aggregates.indices.toArray
val outputArity = aggregates.length
@@ -894,7 +926,9 @@ object AggregateUtil {
Array(), // no fields are forwarded
None,
None,
- outputArity
+ outputArity,
+ needRetract,
+ needMerge = true
)
val aggResultTypes = namedAggregates.map(a => FlinkTypeFactory.toTypeInfo(a.left.getType))
@@ -926,7 +960,7 @@ object AggregateUtil {
* Return true if all aggregates can be partially merged. False otherwise.
*/
private[flink] def doAllSupportPartialMerge(
- aggregateList: Array[TableAggregateFunction[_ <: Any]]): Boolean = {
+ aggregateList: Array[TableAggregateFunction[_ <: Any, _ <: Any]]): Boolean = {
aggregateList.forall(ifMethodExistInFunction("merge", _))
}
@@ -1033,11 +1067,11 @@ object AggregateUtil {
aggregateCalls: Seq[AggregateCall],
inputType: RelDataType,
needRetraction: Boolean)
- : (Array[Array[Int]], Array[TableAggregateFunction[_ <: Any]]) = {
+ : (Array[Array[Int]], Array[TableAggregateFunction[_ <: Any, _ <: Any]]) = {
// store the aggregate fields of each aggregate function, by the same order of aggregates.
val aggFieldIndexes = new Array[Array[Int]](aggregateCalls.size)
- val aggregates = new Array[TableAggregateFunction[_ <: Any]](aggregateCalls.size)
+ val aggregates = new Array[TableAggregateFunction[_ <: Any, _ <: Any]](aggregateCalls.size)
// create aggregate function instances by function type and aggregate field data type.
aggregateCalls.zipWithIndex.foreach { case (aggregateCall, index) =>
@@ -1232,12 +1266,18 @@ object AggregateUtil {
}
private def createAccumulatorType(
- aggregates: Array[TableAggregateFunction[_]]): Seq[TypeInformation[_]] = {
+ aggregates: Array[TableAggregateFunction[_, _]]): Seq[TypeInformation[_]] = {
val aggTypes: Seq[TypeInformation[_]] =
aggregates.map {
agg =>
- val accType = agg.getAccumulatorType
+ val accType = try {
+ val method: Method = agg.getClass.getMethod("getAccumulatorType")
+ method.invoke(agg).asInstanceOf[TypeInformation[_]]
+ } catch {
+ case _: NoSuchMethodException => null
+ case ite: Throwable => throw new TableException("Unexpected exception:", ite)
+ }
if (accType != null) {
accType
} else {
@@ -1259,7 +1299,7 @@ object AggregateUtil {
private def createDataSetAggregateBufferDataType(
groupings: Array[Int],
- aggregates: Array[TableAggregateFunction[_]],
+ aggregates: Array[TableAggregateFunction[_, _]],
inputType: RelDataType,
windowKeyTypes: Option[Array[TypeInformation[_]]] = None): RowTypeInfo = {
@@ -1281,7 +1321,7 @@ object AggregateUtil {
}
private[flink] def createAccumulatorRowType(
- aggregates: Array[TableAggregateFunction[_]]): RowTypeInfo = {
+ aggregates: Array[TableAggregateFunction[_, _]]): RowTypeInfo = {
val aggTypes: Seq[TypeInformation[_]] = createAccumulatorType(aggregates)
http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
index bee39fa..5f48e09 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
@@ -101,3 +101,35 @@ abstract class GeneratedAggregations extends Function {
*/
def resetAccumulator(accumulators: Row)
}
+
+class SingleElementIterable[T] extends java.lang.Iterable[T] {
+  // Reusable Iterable over exactly one element; iterator() always returns the same instance.
+  class SingleElementIterator extends java.util.Iterator[T] {
+    // element: the value to hand out; newElement: true until that value has been consumed.
+    var element: T = _
+    var newElement: Boolean = false
+
+    override def hasNext: Boolean = newElement
+
+    override def next(): T = {
+      if (newElement) {
+        newElement = false
+        element
+      } else {
+        throw new java.util.NoSuchElementException
+      }
+    }
+    // Must throw: merely constructing the exception would make remove() a silent no-op.
+    override def remove(): Unit = throw new java.lang.UnsupportedOperationException
+  }
+
+  val it = new SingleElementIterator
+  // Stores the element that the next iterator() pass will return.
+  def setElement(element: T): Unit = it.element = element
+  // Re-arms the shared iterator so the stored element can be read once more, then returns it.
+  override def iterator(): java.util.Iterator[T] = {
+    it.newElement = true
+    it
+  }
+}
+
http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala
index cb1137f..39b9ec3 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala
@@ -17,9 +17,10 @@
*/
package org.apache.flink.table.functions.aggfunctions
+import java.lang.reflect.Method
import java.math.BigDecimal
import java.util.{ArrayList => JArrayList, List => JList}
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
import org.junit.Assert.assertEquals
import org.junit.Test
@@ -29,14 +30,18 @@ import org.junit.Test
*
* @tparam T the type for the aggregation result
*/
-abstract class AggFunctionTestBase[T] {
+abstract class AggFunctionTestBase[T, ACC] {
def inputValueSets: Seq[Seq[_]]
def expectedResults: Seq[T]
- def aggregator: AggregateFunction[T]
+ def aggregator: AggregateFunction[T, ACC]
- def supportRetraction: Boolean = true
+ val accType = aggregator.getClass.getMethod("createAccumulator").getReturnType
+
+ def accumulateFunc: Method = aggregator.getClass.getMethod("accumulate", accType, classOf[Any])
+
+ def retractFunc: Method = null
@Test
// test aggregate and retract functions without partial merge
@@ -47,52 +52,55 @@ abstract class AggFunctionTestBase[T] {
val result = aggregator.getValue(accumulator)
validateResult[T](expected, result)
- if (supportRetraction) {
+ if (ifMethodExistInFunction("retract", aggregator)) {
retractVals(accumulator, vals)
val expectedAccum = aggregator.createAccumulator()
//The two accumulators should be exactly same
- validateResult[Accumulator](expectedAccum, accumulator)
+ validateResult[ACC](expectedAccum, accumulator)
}
}
}
@Test
- // test aggregate functions with partial merge
def testAggregateWithMerge(): Unit = {
if (ifMethodExistInFunction("merge", aggregator)) {
+ val mergeFunc =
+ aggregator.getClass.getMethod("merge", accType, classOf[java.lang.Iterable[ACC]])
// iterate over input sets
for ((vals, expected) <- inputValueSets.zip(expectedResults)) {
//equally split the vals sequence into two sequences
val (firstVals, secondVals) = vals.splitAt(vals.length / 2)
//1. verify merge with accumulate
- val accumulators: JList[Accumulator] = new JArrayList[Accumulator]()
- accumulators.add(accumulateVals(firstVals))
+ val accumulators: JList[ACC] = new JArrayList[ACC]()
accumulators.add(accumulateVals(secondVals))
- val accumulator = aggregator.merge(accumulators)
- val result = aggregator.getValue(accumulator)
+ val acc = accumulateVals(firstVals)
+
+ mergeFunc.invoke(aggregator, acc.asInstanceOf[Object], accumulators)
+ val result = aggregator.getValue(acc)
validateResult[T](expected, result)
//2. verify merge with accumulate & retract
- if (supportRetraction) {
- retractVals(accumulator, vals)
+ if (ifMethodExistInFunction("retract", aggregator)) {
+ retractVals(acc, vals)
val expectedAccum = aggregator.createAccumulator()
//The two accumulators should be exactly same
- validateResult[Accumulator](expectedAccum, accumulator)
+ validateResult[ACC](expectedAccum, acc)
}
}
// iterate over input sets
for ((vals, expected) <- inputValueSets.zip(expectedResults)) {
//3. test partial merge with an empty accumulator
- val accumulators: JList[Accumulator] = new JArrayList[Accumulator]()
- accumulators.add(accumulateVals(vals))
+ val accumulators: JList[ACC] = new JArrayList[ACC]()
accumulators.add(aggregator.createAccumulator())
- val accumulator = aggregator.merge(accumulators)
- val result = aggregator.getValue(accumulator)
+ val acc = accumulateVals(vals)
+
+ mergeFunc.invoke(aggregator, acc.asInstanceOf[Object], accumulators)
+ val result = aggregator.getValue(acc)
validateResult[T](expected, result)
}
}
@@ -103,13 +111,14 @@ abstract class AggFunctionTestBase[T] {
def testResetAccumulator(): Unit = {
if (ifMethodExistInFunction("resetAccumulator", aggregator)) {
+ val resetAccFunc = aggregator.getClass.getMethod("resetAccumulator", accType)
// iterate over input sets
for ((vals, expected) <- inputValueSets.zip(expectedResults)) {
val accumulator = accumulateVals(vals)
- aggregator.resetAccumulator(accumulator)
+ resetAccFunc.invoke(aggregator, accumulator.asInstanceOf[Object])
val expectedAccum = aggregator.createAccumulator()
//The accumulator after reset should be exactly same as the new accumulator
- validateResult[Accumulator](expectedAccum, accumulator)
+ validateResult[ACC](expectedAccum, accumulator)
}
}
}
@@ -130,13 +139,18 @@ abstract class AggFunctionTestBase[T] {
}
}
- private def accumulateVals(vals: Seq[_]): Accumulator = {
+ private def accumulateVals(vals: Seq[_]): ACC = {
val accumulator = aggregator.createAccumulator()
- vals.foreach(v => aggregator.accumulate(accumulator, v))
+ vals.foreach(
+ v =>
+ accumulateFunc.invoke(aggregator, accumulator.asInstanceOf[Object], v.asInstanceOf[Object])
+ )
accumulator
}
- private def retractVals(accumulator:Accumulator, vals: Seq[_]) = {
- vals.foreach(v => aggregator.retract(accumulator, v))
+ private def retractVals(accumulator:ACC, vals: Seq[_]) = {
+ vals.foreach(
+ v => retractFunc.invoke(aggregator, accumulator.asInstanceOf[Object], v.asInstanceOf[Object])
+ )
}
}