You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by fh...@apache.org on 2017/03/02 20:32:44 UTC
[3/3] flink git commit: [FLINK-5768] [table] Refactor DataSet and
DataStream aggregations to use UDAGG interface.
[FLINK-5768] [table] Refactor DataSet and DataStream aggregations to use UDAGG interface.
- DataStream aggregates use new WindowedStream.aggregate() operator.
This closes #3423.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/438276de
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/438276de
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/438276de
Branch: refs/heads/master
Commit: 438276de8fab4f1a8f2b62b6452c2e5b2998ce5a
Parents: 7fe0eb4
Author: shaoxuan-wang <ws...@gmail.com>
Authored: Mon Feb 27 19:09:30 2017 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Thu Mar 2 21:31:18 2017 +0100
----------------------------------------------------------------------
.../table/functions/AggregateFunction.scala | 13 +-
.../functions/aggfunctions/AvgAggFunction.scala | 91 ++--
.../aggfunctions/CountAggFunction.scala | 19 +-
.../functions/aggfunctions/MaxAggFunction.scala | 66 ++-
.../functions/aggfunctions/MinAggFunction.scala | 66 ++-
.../functions/aggfunctions/SumAggFunction.scala | 76 ++-
.../utils/UserDefinedFunctionUtils.scala | 4 +-
.../plan/nodes/dataset/DataSetAggregate.scala | 25 +-
.../nodes/dataset/DataSetWindowAggregate.scala | 28 +-
.../nodes/datastream/DataStreamAggregate.scala | 152 ++----
.../aggregate/AggregateAggFunction.scala | 79 ++++
.../AggregateAllTimeWindowFunction.scala | 52 ---
.../aggregate/AggregateAllWindowFunction.scala | 41 --
.../aggregate/AggregateMapFunction.scala | 22 +-
.../AggregateReduceCombineFunction.scala | 89 ++--
.../AggregateReduceGroupFunction.scala | 96 ++--
.../aggregate/AggregateTimeWindowFunction.scala | 57 ---
.../table/runtime/aggregate/AggregateUtil.scala | 464 ++++++++-----------
.../aggregate/AggregateWindowFunction.scala | 46 --
...ionWindowAggregateCombineGroupFunction.scala | 88 ++--
...sionWindowAggregateReduceGroupFunction.scala | 104 +++--
...umbleCountWindowAggReduceGroupFunction.scala | 49 +-
...mbleTimeWindowAggReduceCombineFunction.scala | 58 ++-
...TumbleTimeWindowAggReduceGroupFunction.scala | 60 ++-
.../DataSetWindowAggregateMapFunction.scala | 18 +-
...rementalAggregateAllTimeWindowFunction.scala | 24 +-
.../IncrementalAggregateAllWindowFunction.scala | 30 +-
.../IncrementalAggregateReduceFunction.scala | 63 ---
...IncrementalAggregateTimeWindowFunction.scala | 32 +-
.../IncrementalAggregateWindowFunction.scala | 40 +-
.../scala/stream/table/AggregationsITCase.scala | 10 +-
.../aggfunctions/AggFunctionTestBase.scala | 2 +-
.../dataset/DataSetWindowAggregateITCase.scala | 52 ++-
33 files changed, 1050 insertions(+), 1066 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
index e15a8c4..178b439 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
@@ -19,12 +19,14 @@ package org.apache.flink.table.functions
import java.util.{List => JList}
+import org.apache.flink.api.common.typeinfo.TypeInformation
+
/**
* Base class for User-Defined Aggregates.
*
* @tparam T the type of the aggregation result
*/
-trait AggregateFunction[T] extends UserDefinedFunction {
+abstract class AggregateFunction[T] extends UserDefinedFunction {
/**
* Create and init the Accumulator for this [[AggregateFunction]].
*
@@ -61,6 +63,15 @@ trait AggregateFunction[T] extends UserDefinedFunction {
* @return the resulting accumulator
*/
def merge(accumulators: JList[Accumulator]): Accumulator
+
+ /**
+ * Returns the [[TypeInformation]] of the accumulator.
+ * This function is optional and can be implemented if the accumulator type cannot be
+ * automatically inferred from the instance returned by [[createAccumulator()]].
+ *
+ * @return The type information for the accumulator.
+ */
+ def getAccumulatorType(): TypeInformation[_] = null
}
/**
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
index f4c0b7b..534bb03 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
@@ -19,9 +19,18 @@ package org.apache.flink.table.functions.aggfunctions
import java.math.{BigDecimal, BigInteger}
import java.util.{List => JList}
+
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
+import org.apache.flink.api.java.typeutils.TupleTypeInfo
import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+/** The initial accumulator for Integral Avg aggregate function */
+class IntegralAvgAccumulator extends JTuple2[Long, Long] with Accumulator {
+ f0 = 0L //sum
+ f1 = 0L //count
+}
+
/**
* Base class for built-in Integral Avg aggregate function
*
@@ -29,12 +38,6 @@ import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
*/
abstract class IntegralAvgAggFunction[T] extends AggregateFunction[T] {
- /** The initial accumulator for Integral Avg aggregate function */
- class IntegralAvgAccumulator extends JTuple2[Long, Long] with Accumulator {
- f0 = 0 //sum
- f1 = 0 //count
- }
-
override def createAccumulator(): Accumulator = {
new IntegralAvgAccumulator
}
@@ -44,7 +47,7 @@ abstract class IntegralAvgAggFunction[T] extends AggregateFunction[T] {
val v = value.asInstanceOf[Number].longValue()
val accum = accumulator.asInstanceOf[IntegralAvgAccumulator]
accum.f0 += v
- accum.f1 += 1
+ accum.f1 += 1L
}
}
@@ -69,6 +72,13 @@ abstract class IntegralAvgAggFunction[T] extends AggregateFunction[T] {
ret
}
+ override def getAccumulatorType(): TypeInformation[_] = {
+ new TupleTypeInfo(
+ new IntegralAvgAccumulator().getClass,
+ BasicTypeInfo.LONG_TYPE_INFO,
+ BasicTypeInfo.LONG_TYPE_INFO)
+ }
+
/**
* Convert the intermediate result to the expected aggregation result type
*
@@ -100,6 +110,13 @@ class IntAvgAggFunction extends IntegralAvgAggFunction[Int] {
override def resultTypeConvert(value: Long): Int = value.toInt
}
+/** The initial accumulator for Big Integral Avg aggregate function */
+class BigIntegralAvgAccumulator
+ extends JTuple2[BigInteger, Long] with Accumulator {
+ f0 = BigInteger.ZERO //sum
+ f1 = 0L //count
+}
+
/**
* Base Class for Built-in Big Integral Avg aggregate function
*
@@ -107,13 +124,6 @@ class IntAvgAggFunction extends IntegralAvgAggFunction[Int] {
*/
abstract class BigIntegralAvgAggFunction[T] extends AggregateFunction[T] {
- /** The initial accumulator for Big Integral Avg aggregate function */
- class BigIntegralAvgAccumulator
- extends JTuple2[BigInteger, Long] with Accumulator {
- f0 = BigInteger.ZERO //sum
- f1 = 0 //count
- }
-
override def createAccumulator(): Accumulator = {
new BigIntegralAvgAccumulator
}
@@ -123,7 +133,7 @@ abstract class BigIntegralAvgAggFunction[T] extends AggregateFunction[T] {
val v = value.asInstanceOf[Long]
val a = accumulator.asInstanceOf[BigIntegralAvgAccumulator]
a.f0 = a.f0.add(BigInteger.valueOf(v))
- a.f1 += 1
+ a.f1 += 1L
}
}
@@ -148,6 +158,13 @@ abstract class BigIntegralAvgAggFunction[T] extends AggregateFunction[T] {
ret
}
+ override def getAccumulatorType(): TypeInformation[_] = {
+ new TupleTypeInfo(
+ new BigIntegralAvgAccumulator().getClass,
+ BasicTypeInfo.BIG_INT_TYPE_INFO,
+ BasicTypeInfo.LONG_TYPE_INFO)
+ }
+
/**
* Convert the intermediate result to the expected aggregation result type
*
@@ -166,6 +183,12 @@ class LongAvgAggFunction extends BigIntegralAvgAggFunction[Long] {
override def resultTypeConvert(value: BigInteger): Long = value.longValue()
}
+/** The initial accumulator for Floating Avg aggregate function */
+class FloatingAvgAccumulator extends JTuple2[Double, Long] with Accumulator {
+ f0 = 0 //sum
+ f1 = 0L //count
+}
+
/**
* Base class for built-in Floating Avg aggregate function
*
@@ -173,12 +196,6 @@ class LongAvgAggFunction extends BigIntegralAvgAggFunction[Long] {
*/
abstract class FloatingAvgAggFunction[T] extends AggregateFunction[T] {
- /** The initial accumulator for Floating Avg aggregate function */
- class FloatingAvgAccumulator extends JTuple2[Double, Long] with Accumulator {
- f0 = 0 //sum
- f1 = 0 //count
- }
-
override def createAccumulator(): Accumulator = {
new FloatingAvgAccumulator
}
@@ -188,7 +205,7 @@ abstract class FloatingAvgAggFunction[T] extends AggregateFunction[T] {
val v = value.asInstanceOf[Number].doubleValue()
val accum = accumulator.asInstanceOf[FloatingAvgAccumulator]
accum.f0 += v
- accum.f1 += 1
+ accum.f1 += 1L
}
}
@@ -213,6 +230,13 @@ abstract class FloatingAvgAggFunction[T] extends AggregateFunction[T] {
ret
}
+ override def getAccumulatorType(): TypeInformation[_] = {
+ new TupleTypeInfo(
+ new FloatingAvgAccumulator().getClass,
+ BasicTypeInfo.DOUBLE_TYPE_INFO,
+ BasicTypeInfo.LONG_TYPE_INFO)
+ }
+
/**
* Convert the intermediate result to the expected aggregation result type
*
@@ -237,18 +261,18 @@ class DoubleAvgAggFunction extends FloatingAvgAggFunction[Double] {
override def resultTypeConvert(value: Double): Double = value
}
+/** The initial accumulator for Big Decimal Avg aggregate function */
+class DecimalAvgAccumulator
+ extends JTuple2[BigDecimal, Long] with Accumulator {
+ f0 = BigDecimal.ZERO //sum
+ f1 = 0L //count
+}
+
/**
* Base class for built-in Big Decimal Avg aggregate function
*/
class DecimalAvgAggFunction extends AggregateFunction[BigDecimal] {
- /** The initial accumulator for Big Decimal Avg aggregate function */
- class DecimalAvgAccumulator
- extends JTuple2[BigDecimal, Long] with Accumulator {
- f0 = BigDecimal.ZERO //sum
- f1 = 0 //count
- }
-
override def createAccumulator(): Accumulator = {
new DecimalAvgAccumulator
}
@@ -262,7 +286,7 @@ class DecimalAvgAggFunction extends AggregateFunction[BigDecimal] {
} else {
accum.f0 = accum.f0.add(v)
}
- accum.f1 += 1
+ accum.f1 += 1L
}
}
@@ -286,4 +310,11 @@ class DecimalAvgAggFunction extends AggregateFunction[BigDecimal] {
}
ret
}
+
+ override def getAccumulatorType(): TypeInformation[_] = {
+ new TupleTypeInfo(
+ new DecimalAvgAccumulator().getClass,
+ BasicTypeInfo.BIG_DEC_TYPE_INFO,
+ BasicTypeInfo.LONG_TYPE_INFO)
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
index 8b903d1..cf884ed 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
@@ -18,22 +18,25 @@
package org.apache.flink.table.functions.aggfunctions
import java.util.{List => JList}
+
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple1 => JTuple1}
+import org.apache.flink.api.java.typeutils.TupleTypeInfo
import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+/** The initial accumulator for count aggregate function */
+class CountAccumulator extends JTuple1[Long] with Accumulator {
+ f0 = 0L //count
+}
+
/**
* built-in count aggregate function
*/
class CountAggFunction extends AggregateFunction[Long] {
- /** The initial accumulator for count aggregate function */
- class CountAccumulator extends JTuple1[Long] with Accumulator {
- f0 = 0 //count
- }
-
override def accumulate(accumulator: Accumulator, value: Any): Unit = {
if (value != null) {
- accumulator.asInstanceOf[CountAccumulator].f0 += 1
+ accumulator.asInstanceOf[CountAccumulator].f0 += 1L
}
}
@@ -54,4 +57,8 @@ class CountAggFunction extends AggregateFunction[Long] {
override def createAccumulator(): Accumulator = {
new CountAccumulator
}
+
+ override def getAccumulatorType(): TypeInformation[_] = {
+ new TupleTypeInfo((new CountAccumulator).getClass, BasicTypeInfo.LONG_TYPE_INFO)
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
index 20041ee..62ff88c 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
@@ -19,9 +19,18 @@ package org.apache.flink.table.functions.aggfunctions
import java.math.BigDecimal
import java.util.{List => JList}
+
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
+import org.apache.flink.api.java.typeutils.TupleTypeInfo
import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+/** The initial accumulator for Max aggregate function */
+class MaxAccumulator[T] extends JTuple2[T, Boolean] with Accumulator {
+ f0 = 0.asInstanceOf[T] //max
+ f1 = false
+}
+
/**
* Base class for built-in Max aggregate function
*
@@ -29,20 +38,14 @@ import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
*/
abstract class MaxAggFunction[T](implicit ord: Ordering[T]) extends AggregateFunction[T] {
- /** The initial accumulator for Max aggregate function */
- class MaxAccumulator extends JTuple2[T, Boolean] with Accumulator {
- f0 = 0.asInstanceOf[T] //max
- f1 = false
- }
-
override def createAccumulator(): Accumulator = {
- new MaxAccumulator
+ new MaxAccumulator[T]
}
override def accumulate(accumulator: Accumulator, value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[T]
- val a = accumulator.asInstanceOf[MaxAccumulator]
+ val a = accumulator.asInstanceOf[MaxAccumulator[T]]
if (!a.f1 || ord.compare(a.f0, v) < 0) {
a.f0 = v
a.f1 = true
@@ -51,7 +54,7 @@ abstract class MaxAggFunction[T](implicit ord: Ordering[T]) extends AggregateFun
}
override def getValue(accumulator: Accumulator): T = {
- val a = accumulator.asInstanceOf[MaxAccumulator]
+ val a = accumulator.asInstanceOf[MaxAccumulator[T]]
if (a.f1) {
a.f0
} else {
@@ -63,50 +66,73 @@ abstract class MaxAggFunction[T](implicit ord: Ordering[T]) extends AggregateFun
val ret = accumulators.get(0)
var i: Int = 1
while (i < accumulators.size()) {
- val a = accumulators.get(i).asInstanceOf[MaxAccumulator]
+ val a = accumulators.get(i).asInstanceOf[MaxAccumulator[T]]
if (a.f1) {
- accumulate(ret.asInstanceOf[MaxAccumulator], a.f0)
+ accumulate(ret.asInstanceOf[MaxAccumulator[T]], a.f0)
}
i += 1
}
ret
}
+
+ override def getAccumulatorType(): TypeInformation[_] = {
+ new TupleTypeInfo(
+ new MaxAccumulator[T].getClass,
+ getValueTypeInfo,
+ BasicTypeInfo.BOOLEAN_TYPE_INFO)
+ }
+
+ def getValueTypeInfo: TypeInformation[_]
}
/**
* Built-in Byte Max aggregate function
*/
-class ByteMaxAggFunction extends MaxAggFunction[Byte]
+class ByteMaxAggFunction extends MaxAggFunction[Byte] {
+ override def getValueTypeInfo = BasicTypeInfo.BYTE_TYPE_INFO
+}
/**
* Built-in Short Max aggregate function
*/
-class ShortMaxAggFunction extends MaxAggFunction[Short]
+class ShortMaxAggFunction extends MaxAggFunction[Short] {
+ override def getValueTypeInfo = BasicTypeInfo.SHORT_TYPE_INFO
+}
/**
* Built-in Int Max aggregate function
*/
-class IntMaxAggFunction extends MaxAggFunction[Int]
+class IntMaxAggFunction extends MaxAggFunction[Int] {
+ override def getValueTypeInfo = BasicTypeInfo.INT_TYPE_INFO
+}
/**
* Built-in Long Max aggregate function
*/
-class LongMaxAggFunction extends MaxAggFunction[Long]
+class LongMaxAggFunction extends MaxAggFunction[Long] {
+ override def getValueTypeInfo = BasicTypeInfo.LONG_TYPE_INFO
+}
/**
* Built-in Float Max aggregate function
*/
-class FloatMaxAggFunction extends MaxAggFunction[Float]
+class FloatMaxAggFunction extends MaxAggFunction[Float] {
+ override def getValueTypeInfo = BasicTypeInfo.FLOAT_TYPE_INFO
+}
/**
* Built-in Double Max aggregate function
*/
-class DoubleMaxAggFunction extends MaxAggFunction[Double]
+class DoubleMaxAggFunction extends MaxAggFunction[Double] {
+ override def getValueTypeInfo = BasicTypeInfo.DOUBLE_TYPE_INFO
+}
/**
* Built-in Boolean Max aggregate function
*/
-class BooleanMaxAggFunction extends MaxAggFunction[Boolean]
+class BooleanMaxAggFunction extends MaxAggFunction[Boolean] {
+ override def getValueTypeInfo = BasicTypeInfo.BOOLEAN_TYPE_INFO
+}
/**
* Built-in Big Decimal Max aggregate function
@@ -116,11 +142,13 @@ class DecimalMaxAggFunction extends MaxAggFunction[BigDecimal] {
override def accumulate(accumulator: Accumulator, value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[BigDecimal]
- val accum = accumulator.asInstanceOf[MaxAccumulator]
+ val accum = accumulator.asInstanceOf[MaxAccumulator[BigDecimal]]
if (!accum.f1 || accum.f0.compareTo(v) < 0) {
accum.f0 = v
accum.f1 = true
}
}
}
+
+ override def getValueTypeInfo = BasicTypeInfo.BIG_DEC_TYPE_INFO
}
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
index 16461ae..cddb873 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
@@ -19,9 +19,18 @@ package org.apache.flink.table.functions.aggfunctions
import java.math.BigDecimal
import java.util.{List => JList}
+
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
+import org.apache.flink.api.java.typeutils.TupleTypeInfo
import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+/** The initial accumulator for Min aggregate function */
+class MinAccumulator[T] extends JTuple2[T, Boolean] with Accumulator {
+ f0 = 0.asInstanceOf[T] //min
+ f1 = false
+}
+
/**
* Base class for built-in Min aggregate function
*
@@ -29,20 +38,14 @@ import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
*/
abstract class MinAggFunction[T](implicit ord: Ordering[T]) extends AggregateFunction[T] {
- /** The initial accumulator for Min aggregate function */
- class MinAccumulator extends JTuple2[T, Boolean] with Accumulator {
- f0 = 0.asInstanceOf[T] //min
- f1 = false
- }
-
override def createAccumulator(): Accumulator = {
- new MinAccumulator
+ new MinAccumulator[T]
}
override def accumulate(accumulator: Accumulator, value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[T]
- val a = accumulator.asInstanceOf[MinAccumulator]
+ val a = accumulator.asInstanceOf[MinAccumulator[T]]
if (!a.f1 || ord.compare(a.f0, v) > 0) {
a.f0 = v
a.f1 = true
@@ -51,7 +54,7 @@ abstract class MinAggFunction[T](implicit ord: Ordering[T]) extends AggregateFun
}
override def getValue(accumulator: Accumulator): T = {
- val a = accumulator.asInstanceOf[MinAccumulator]
+ val a = accumulator.asInstanceOf[MinAccumulator[T]]
if (a.f1) {
a.f0
} else {
@@ -63,50 +66,73 @@ abstract class MinAggFunction[T](implicit ord: Ordering[T]) extends AggregateFun
val ret = accumulators.get(0)
var i: Int = 1
while (i < accumulators.size()) {
- val a = accumulators.get(i).asInstanceOf[MinAccumulator]
+ val a = accumulators.get(i).asInstanceOf[MinAccumulator[T]]
if (a.f1) {
- accumulate(ret.asInstanceOf[MinAccumulator], a.f0)
+ accumulate(ret.asInstanceOf[MinAccumulator[T]], a.f0)
}
i += 1
}
ret
}
+
+ override def getAccumulatorType(): TypeInformation[_] = {
+ new TupleTypeInfo(
+ new MinAccumulator[T].getClass,
+ getValueTypeInfo,
+ BasicTypeInfo.BOOLEAN_TYPE_INFO)
+ }
+
+ def getValueTypeInfo: TypeInformation[_]
}
/**
* Built-in Byte Min aggregate function
*/
-class ByteMinAggFunction extends MinAggFunction[Byte]
+class ByteMinAggFunction extends MinAggFunction[Byte] {
+ override def getValueTypeInfo = BasicTypeInfo.BYTE_TYPE_INFO
+}
/**
* Built-in Short Min aggregate function
*/
-class ShortMinAggFunction extends MinAggFunction[Short]
+class ShortMinAggFunction extends MinAggFunction[Short] {
+ override def getValueTypeInfo = BasicTypeInfo.SHORT_TYPE_INFO
+}
/**
* Built-in Int Min aggregate function
*/
-class IntMinAggFunction extends MinAggFunction[Int]
+class IntMinAggFunction extends MinAggFunction[Int] {
+ override def getValueTypeInfo = BasicTypeInfo.INT_TYPE_INFO
+}
/**
* Built-in Long Min aggregate function
*/
-class LongMinAggFunction extends MinAggFunction[Long]
+class LongMinAggFunction extends MinAggFunction[Long] {
+ override def getValueTypeInfo = BasicTypeInfo.LONG_TYPE_INFO
+}
/**
* Built-in Float Min aggregate function
*/
-class FloatMinAggFunction extends MinAggFunction[Float]
+class FloatMinAggFunction extends MinAggFunction[Float] {
+ override def getValueTypeInfo = BasicTypeInfo.FLOAT_TYPE_INFO
+}
/**
* Built-in Double Min aggregate function
*/
-class DoubleMinAggFunction extends MinAggFunction[Double]
+class DoubleMinAggFunction extends MinAggFunction[Double] {
+ override def getValueTypeInfo = BasicTypeInfo.DOUBLE_TYPE_INFO
+}
/**
* Built-in Boolean Min aggregate function
*/
-class BooleanMinAggFunction extends MinAggFunction[Boolean]
+class BooleanMinAggFunction extends MinAggFunction[Boolean] {
+ override def getValueTypeInfo = BasicTypeInfo.BOOLEAN_TYPE_INFO
+}
/**
* Built-in Big Decimal Min aggregate function
@@ -116,11 +142,13 @@ class DecimalMinAggFunction extends MinAggFunction[BigDecimal] {
override def accumulate(accumulator: Accumulator, value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[BigDecimal]
- val accum = accumulator.asInstanceOf[MinAccumulator]
+ val accum = accumulator.asInstanceOf[MinAccumulator[BigDecimal]]
if (!accum.f1 || accum.f0.compareTo(v) > 0) {
accum.f0 = v
accum.f1 = true
}
}
}
+
+ override def getValueTypeInfo = BasicTypeInfo.BIG_DEC_TYPE_INFO
}
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
index b04d8c0..78fdb8e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
@@ -19,9 +19,15 @@ package org.apache.flink.table.functions.aggfunctions
import java.math.BigDecimal
import java.util.{List => JList}
+
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
+import org.apache.flink.api.java.typeutils.TupleTypeInfo
import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+/** The initial accumulator for Sum aggregate function */
+class SumAccumulator[T] extends JTuple2[T, Boolean] with Accumulator
+
/**
* Base class for built-in Sum aggregate function
*
@@ -29,29 +35,26 @@ import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
*/
abstract class SumAggFunction[T: Numeric] extends AggregateFunction[T] {
- /** The initial accumulator for Sum aggregate function */
- class SumAccumulator extends JTuple2[T, Boolean] with Accumulator {
- f0 = numeric.zero //sum
- f1 = false
- }
-
private val numeric = implicitly[Numeric[T]]
override def createAccumulator(): Accumulator = {
- new SumAccumulator
+ val acc = new SumAccumulator[T]()
+ acc.f0 = numeric.zero //sum
+ acc.f1 = false
+ acc
}
override def accumulate(accumulator: Accumulator, value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[T]
- val a = accumulator.asInstanceOf[SumAccumulator]
+ val a = accumulator.asInstanceOf[SumAccumulator[T]]
a.f0 = numeric.plus(v, a.f0)
a.f1 = true
}
}
override def getValue(accumulator: Accumulator): T = {
- val a = accumulator.asInstanceOf[SumAccumulator]
+ val a = accumulator.asInstanceOf[SumAccumulator[T]]
if (a.f1) {
a.f0
} else {
@@ -60,10 +63,10 @@ abstract class SumAggFunction[T: Numeric] extends AggregateFunction[T] {
}
override def merge(accumulators: JList[Accumulator]): Accumulator = {
- val ret = createAccumulator().asInstanceOf[SumAccumulator]
+ val ret = createAccumulator().asInstanceOf[SumAccumulator[T]]
var i: Int = 0
while (i < accumulators.size()) {
- val a = accumulators.get(i).asInstanceOf[SumAccumulator]
+ val a = accumulators.get(i).asInstanceOf[SumAccumulator[T]]
if (a.f1) {
ret.f0 = numeric.plus(ret.f0, a.f0)
ret.f1 = true
@@ -72,50 +75,70 @@ abstract class SumAggFunction[T: Numeric] extends AggregateFunction[T] {
}
ret
}
+
+ override def getAccumulatorType(): TypeInformation[_] = {
+ new TupleTypeInfo(
+ (new SumAccumulator).getClass,
+ getValueTypeInfo,
+ BasicTypeInfo.BOOLEAN_TYPE_INFO)
+ }
+
+ def getValueTypeInfo: TypeInformation[_]
}
/**
* Built-in Byte Sum aggregate function
*/
-class ByteSumAggFunction extends SumAggFunction[Byte]
+class ByteSumAggFunction extends SumAggFunction[Byte] {
+ override def getValueTypeInfo = BasicTypeInfo.BYTE_TYPE_INFO
+}
/**
* Built-in Short Sum aggregate function
*/
-class ShortSumAggFunction extends SumAggFunction[Short]
+class ShortSumAggFunction extends SumAggFunction[Short] {
+ override def getValueTypeInfo = BasicTypeInfo.SHORT_TYPE_INFO
+}
/**
* Built-in Int Sum aggregate function
*/
-class IntSumAggFunction extends SumAggFunction[Int]
+class IntSumAggFunction extends SumAggFunction[Int] {
+ override def getValueTypeInfo = BasicTypeInfo.INT_TYPE_INFO
+}
/**
* Built-in Long Sum aggregate function
*/
-class LongSumAggFunction extends SumAggFunction[Long]
+class LongSumAggFunction extends SumAggFunction[Long] {
+ override def getValueTypeInfo = BasicTypeInfo.LONG_TYPE_INFO
+}
/**
* Built-in Float Sum aggregate function
*/
-class FloatSumAggFunction extends SumAggFunction[Float]
+class FloatSumAggFunction extends SumAggFunction[Float] {
+ override def getValueTypeInfo = BasicTypeInfo.FLOAT_TYPE_INFO
+}
/**
* Built-in Double Sum aggregate function
*/
-class DoubleSumAggFunction extends SumAggFunction[Double]
+class DoubleSumAggFunction extends SumAggFunction[Double] {
+ override def getValueTypeInfo = BasicTypeInfo.DOUBLE_TYPE_INFO
+}
+/** The initial accumulator for Big Decimal Sum aggregate function */
+class DecimalSumAccumulator extends JTuple2[BigDecimal, Boolean] with Accumulator {
+ f0 = BigDecimal.ZERO
+ f1 = false
+}
/**
* Built-in Big Decimal Sum aggregate function
*/
class DecimalSumAggFunction extends AggregateFunction[BigDecimal] {
- /** The initial accumulator for Big Decimal Sum aggregate function */
- class DecimalSumAccumulator extends JTuple2[BigDecimal, Boolean] with Accumulator {
- f0 = BigDecimal.ZERO
- f1 = false
- }
-
override def createAccumulator(): Accumulator = {
new DecimalSumAccumulator
}
@@ -150,4 +173,11 @@ class DecimalSumAggFunction extends AggregateFunction[BigDecimal] {
}
ret
}
+
+ override def getAccumulatorType(): TypeInformation[_] = {
+ new TupleTypeInfo(
+ (new DecimalSumAccumulator).getClass,
+ BasicTypeInfo.BIG_DEC_TYPE_INFO,
+ BasicTypeInfo.BOOLEAN_TYPE_INFO)
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
index aec4fbb..21d28b5 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -119,9 +119,9 @@ object UserDefinedFunctionUtils {
}
/**
- * Check if a given method exits in the given function
+ * Check if a given method exists in the given function
*/
- def ifMethodExitInFunction(method: String, function: UserDefinedFunction): Boolean = {
+ def ifMethodExistInFunction(method: String, function: UserDefinedFunction): Boolean = {
val methods = function
.getClass
.getMethods
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
index 206e562..a88bcfe 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
@@ -44,9 +44,7 @@ class DataSetAggregate(
inputType: RelDataType,
grouping: Array[Int],
inGroupingSet: Boolean)
- extends SingleRel(cluster, traitSet, inputNode)
- with CommonAggregate
- with DataSetRel {
+ extends SingleRel(cluster, traitSet, inputNode) with CommonAggregate with DataSetRel {
override def deriveRowType(): RelDataType = rowRelDataType
@@ -63,11 +61,13 @@ class DataSetAggregate(
}
override def toString: String = {
- s"Aggregate(${ if (!grouping.isEmpty) {
- s"groupBy: (${groupingToString(inputType, grouping)}), "
- } else {
- ""
- }}select: (${aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil)}))"
+ s"Aggregate(${
+ if (!grouping.isEmpty) {
+ s"groupBy: (${groupingToString(inputType, grouping)}), "
+ } else {
+ ""
+ }
+ }select: (${aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil)}))"
}
override def explainTerms(pw: RelWriter): RelWriter = {
@@ -76,7 +76,7 @@ class DataSetAggregate(
.item("select", aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil))
}
- override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
+ override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
val child = this.getInput
val rowCnt = metadata.getRowCount(child)
@@ -87,8 +87,6 @@ class DataSetAggregate(
override def translateToPlan(tableEnv: BatchTableEnvironment): DataSet[Row] = {
- val config = tableEnv.getConfig
-
val groupingKeys = grouping.indices.toArray
val mapFunction = AggregateUtil.createPrepareMapFunction(
@@ -107,9 +105,7 @@ class DataSetAggregate(
val aggString = aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil)
val prepareOpName = s"prepare select: ($aggString)"
- val mappedInput = inputDS
- .map(mapFunction)
- .name(prepareOpName)
+ val mappedInput = inputDS.map(mapFunction).name(prepareOpName)
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
@@ -127,6 +123,7 @@ class DataSetAggregate(
else {
// global aggregation
val aggOpName = s"select:($aggString)"
+
mappedInput.asInstanceOf[DataSet[Row]]
.reduceGroup(groupReduceFunction)
.returns(rowTypeInfo)
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala
index 48de822..597be8c 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala
@@ -47,9 +47,7 @@ class DataSetWindowAggregate(
rowRelDataType: RelDataType,
inputType: RelDataType,
grouping: Array[Int])
- extends SingleRel(cluster, traitSet, inputNode)
- with CommonAggregate
- with DataSetRel {
+ extends SingleRel(cluster, traitSet, inputNode) with CommonAggregate with DataSetRel {
override def deriveRowType() = rowRelDataType
@@ -97,7 +95,7 @@ class DataSetWindowAggregate(
namedProperties))
}
- override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
+ override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
val child = this.getInput
val rowCnt = metadata.getRowCount(child)
val rowSize = this.estimateRowSize(child.getRowType)
@@ -136,8 +134,8 @@ class DataSetWindowAggregate(
private def createEventTimeTumblingWindowDataSet(
inputDS: DataSet[Row],
isTimeWindow: Boolean,
- isParserCaseSensitive: Boolean)
- : DataSet[Row] = {
+ isParserCaseSensitive: Boolean): DataSet[Row] = {
+
val mapFunction = createDataSetWindowPrepareMapFunction(
window,
namedAggregates,
@@ -191,8 +189,7 @@ class DataSetWindowAggregate(
private[this] def createEventTimeSessionWindowDataSet(
inputDS: DataSet[Row],
- isParserCaseSensitive: Boolean)
- : DataSet[Row] = {
+ isParserCaseSensitive: Boolean): DataSet[Row] = {
val groupingKeys = grouping.indices.toArray
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType)
@@ -207,10 +204,7 @@ class DataSetWindowAggregate(
inputType,
isParserCaseSensitive)
- val mappedInput =
- inputDS
- .map(mapFunction)
- .name(prepareOperatorName)
+ val mappedInput = inputDS.map(mapFunction).name(prepareOperatorName)
val mapReturnType = mapFunction.asInstanceOf[ResultTypeQueryable[Row]].getProducedType
@@ -218,7 +212,7 @@ class DataSetWindowAggregate(
val rowTimeFieldPos = mapReturnType.getArity - 1
// do incremental aggregation
- if (doAllSupportPartialAggregation(
+ if (doAllSupportPartialMerge(
namedAggregates.map(_.getKey),
inputType,
grouping.length)) {
@@ -267,10 +261,10 @@ class DataSetWindowAggregate(
namedProperties)
mappedInput.groupBy(groupingKeys: _*)
- .sortGroup(rowTimeFieldPos, Order.ASCENDING)
- .reduceGroup(groupReduceFunction)
- .returns(rowTypeInfo)
- .name(aggregateOperatorName)
+ .sortGroup(rowTimeFieldPos, Order.ASCENDING)
+ .reduceGroup(groupReduceFunction)
+ .returns(rowTypeInfo)
+ .name(aggregateOperatorName)
}
}
// non-grouping window
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
index c21d008..50f8281 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
@@ -50,9 +50,7 @@ class DataStreamAggregate(
rowRelDataType: RelDataType,
inputType: RelDataType,
grouping: Array[Int])
- extends SingleRel(cluster, traitSet, inputNode)
- with CommonAggregate
- with DataStreamRel {
+ extends SingleRel(cluster, traitSet, inputNode) with CommonAggregate with DataStreamRel {
override def deriveRowType(): RelDataType = rowRelDataType
@@ -91,12 +89,13 @@ class DataStreamAggregate(
super.explainTerms(pw)
.itemIf("groupBy", groupingToString(inputType, grouping), !grouping.isEmpty)
.item("window", window)
- .item("select", aggregationToString(
- inputType,
- grouping,
- getRowType,
- namedAggregates,
- namedProperties))
+ .item(
+ "select", aggregationToString(
+ inputType,
+ grouping,
+ getRowType,
+ namedAggregates,
+ namedProperties))
}
override def translateToPlan(tableEnv: StreamTableEnvironment): DataStream[Row] = {
@@ -113,116 +112,61 @@ class DataStreamAggregate(
namedAggregates,
namedProperties)
- val prepareOpName = s"prepare select: ($aggString)"
val keyedAggOpName = s"groupBy: (${groupingToString(inputType, grouping)}), " +
s"window: ($window), " +
s"select: ($aggString)"
val nonKeyedAggOpName = s"window: ($window), select: ($aggString)"
- val mapFunction = AggregateUtil.createPrepareMapFunction(
- namedAggregates,
- grouping,
- inputType)
-
- val mappedInput = inputDS.map(mapFunction).name(prepareOpName)
-
-
- // check whether all aggregates support partial aggregate
- if (AggregateUtil.doAllSupportPartialAggregation(
- namedAggregates.map(_.getKey),
- inputType,
- grouping.length)) {
- // do Incremental Aggregation
- val reduceFunction = AggregateUtil.createIncrementalAggregateReduceFunction(
- namedAggregates,
- inputType,
- getRowType,
- grouping)
- // grouped / keyed aggregation
- if (groupingKeys.length > 0) {
- val windowFunction = AggregateUtil.createWindowIncrementalAggregationFunction(
- window,
- namedAggregates,
- inputType,
- rowRelDataType,
- grouping,
- namedProperties)
-
- val keyedStream = mappedInput.keyBy(groupingKeys: _*)
- val windowedStream =
- createKeyedWindowedStream(window, keyedStream)
+ // grouped / keyed aggregation
+ if (groupingKeys.length > 0) {
+ val windowFunction = AggregateUtil.createAggregationGroupWindowFunction(
+ window,
+ groupingKeys.length,
+ namedAggregates.size,
+ rowRelDataType.getFieldCount,
+ namedProperties)
+
+ val keyedStream = inputDS.keyBy(groupingKeys: _*)
+ val windowedStream =
+ createKeyedWindowedStream(window, keyedStream)
.asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]]
- windowedStream
- .reduce(reduceFunction, windowFunction)
- .returns(rowTypeInfo)
- .name(keyedAggOpName)
- }
- // global / non-keyed aggregation
- else {
- val windowFunction = AggregateUtil.createAllWindowIncrementalAggregationFunction(
- window,
+ val (aggFunction, accumulatorRowType, aggResultRowType) =
+ AggregateUtil.createDataStreamAggregateFunction(
namedAggregates,
inputType,
rowRelDataType,
- grouping,
- namedProperties)
+ grouping)
- val windowedStream =
- createNonKeyedWindowedStream(window, mappedInput)
- .asInstanceOf[AllWindowedStream[Row, DataStreamWindow]]
-
- windowedStream
- .reduce(reduceFunction, windowFunction)
- .returns(rowTypeInfo)
- .name(nonKeyedAggOpName)
- }
+ windowedStream
+ .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo)
+ .name(keyedAggOpName)
}
+ // global / non-keyed aggregation
else {
- // do non-Incremental Aggregation
- // grouped / keyed aggregation
- if (groupingKeys.length > 0) {
-
- val windowFunction = AggregateUtil.createWindowAggregationFunction(
- window,
- namedAggregates,
- inputType,
- rowRelDataType,
- grouping,
- namedProperties)
+ val windowFunction = AggregateUtil.createAggregationAllWindowFunction(
+ window,
+ rowRelDataType.getFieldCount,
+ namedProperties)
- val keyedStream = mappedInput.keyBy(groupingKeys: _*)
- val windowedStream =
- createKeyedWindowedStream(window, keyedStream)
- .asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]]
+ val windowedStream =
+ createNonKeyedWindowedStream(window, inputDS)
+ .asInstanceOf[AllWindowedStream[Row, DataStreamWindow]]
- windowedStream
- .apply(windowFunction)
- .returns(rowTypeInfo)
- .name(keyedAggOpName)
- }
- // global / non-keyed aggregation
- else {
- val windowFunction = AggregateUtil.createAllWindowAggregationFunction(
- window,
+ val (aggFunction, accumulatorRowType, aggResultRowType) =
+ AggregateUtil.createDataStreamAggregateFunction(
namedAggregates,
inputType,
rowRelDataType,
- grouping,
- namedProperties)
-
- val windowedStream =
- createNonKeyedWindowedStream(window, mappedInput)
- .asInstanceOf[AllWindowedStream[Row, DataStreamWindow]]
+ grouping)
- windowedStream
- .apply(windowFunction)
- .returns(rowTypeInfo)
+ windowedStream
+ .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo)
.name(nonKeyedAggOpName)
- }
}
}
}
+
object DataStreamAggregate {
@@ -242,8 +186,8 @@ object DataStreamAggregate {
// TODO: EventTimeTumblingGroupWindow should sort the stream on event time
// before applying the windowing logic. Otherwise, this would be the same as a
// ProcessingTimeTumblingGroupWindow
- throw new UnsupportedOperationException("Event-time grouping windows on row intervals are " +
- "currently not supported.")
+ throw new UnsupportedOperationException(
+ "Event-time grouping windows on row intervals are currently not supported.")
case ProcessingTimeSlidingGroupWindow(_, size, slide) if isTimeInterval(size.resultType) =>
stream.window(SlidingProcessingTimeWindows.of(asTime(size), asTime(slide)))
@@ -258,8 +202,8 @@ object DataStreamAggregate {
// TODO: EventTimeTumblingGroupWindow should sort the stream on event time
// before applying the windowing logic. Otherwise, this would be the same as a
// ProcessingTimeTumblingGroupWindow
- throw new UnsupportedOperationException("Event-time grouping windows on row intervals are " +
- "currently not supported.")
+ throw new UnsupportedOperationException(
+ "Event-time grouping windows on row intervals are currently not supported.")
case ProcessingTimeSessionGroupWindow(_, gap: Expression) =>
stream.window(ProcessingTimeSessionWindows.withGap(asTime(gap)))
@@ -284,8 +228,8 @@ object DataStreamAggregate {
// TODO: EventTimeTumblingGroupWindow should sort the stream on event time
// before applying the windowing logic. Otherwise, this would be the same as a
// ProcessingTimeTumblingGroupWindow
- throw new UnsupportedOperationException("Event-time grouping windows on row intervals are " +
- "currently not supported.")
+ throw new UnsupportedOperationException(
+ "Event-time grouping windows on row intervals are currently not supported.")
case ProcessingTimeSlidingGroupWindow(_, size, slide) if isTimeInterval(size.resultType) =>
stream.windowAll(SlidingProcessingTimeWindows.of(asTime(size), asTime(slide)))
@@ -300,8 +244,8 @@ object DataStreamAggregate {
// TODO: EventTimeTumblingGroupWindow should sort the stream on event time
// before applying the windowing logic. Otherwise, this would be the same as a
// ProcessingTimeTumblingGroupWindow
- throw new UnsupportedOperationException("Event-time grouping windows on row intervals are " +
- "currently not supported.")
+ throw new UnsupportedOperationException(
+ "Event-time grouping windows on row intervals are currently not supported.")
case ProcessingTimeSessionGroupWindow(_, gap) =>
stream.windowAll(ProcessingTimeSessionWindows.withGap(asTime(gap)))
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAggFunction.scala
new file mode 100644
index 0000000..4d1579b
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAggFunction.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.aggregate
+
+import java.util.{ArrayList => JArrayList, List => JList}
+import org.apache.flink.api.common.functions.{AggregateFunction => DataStreamAggFunc}
+import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.types.Row
+
+/**
+ * Aggregate Function used for the aggregate operator in
+ * [[org.apache.flink.streaming.api.datastream.WindowedStream]]
+ *
+ * Input, accumulator and result are all of type Row. The accumulator Row has one field per
+ * entry of `aggregates`; field `i` holds the [[Accumulator]] created by `aggregates(i)`.
+ *
+ *  - `createAccumulator` builds a new accumulator Row by calling `createAccumulator()` on
+ *    every aggregate.
+ *  - `add` reads field `aggFields(i)` of the input Row and accumulates it into accumulator `i`.
+ *  - `getResult` returns a new Row whose field `i` is `getValue` of accumulator `i`.
+ *  - `merge` merges accumulator `i` of both Rows with `aggregates(i).merge`, stores the result
+ *    in the first Row and returns that same (mutated) first Row.
+ *
+ * `aggregates` and `aggFields` are indexed in parallel, so they are expected to have the same
+ * length; this class does not check it. The result Row is sized by `aggFields.length` while
+ * the accumulator Row is sized by `aggregates.length`.
+ *
+ * @param aggregates the list of all [[org.apache.flink.table.functions.AggregateFunction]]
+ * used for this aggregation
+ * @param aggFields the position (in the input Row) of the input value for each aggregate
+ */
+class AggregateAggFunction(
+ private val aggregates: Array[AggregateFunction[_]],
+ private val aggFields: Array[Int])
+ extends DataStreamAggFunc[Row, Row, Row] {
+
+ val aggsWithIdx: Array[(AggregateFunction[_], Int)] = aggregates.zipWithIndex
+
+ override def createAccumulator(): Row = {
+ val accumulatorRow: Row = new Row(aggregates.length)
+ aggsWithIdx.foreach { case (agg, i) =>
+ accumulatorRow.setField(i, agg.createAccumulator())
+ }
+ accumulatorRow
+ }
+
+ override def add(value: Row, accumulatorRow: Row) = {
+
+ aggsWithIdx.foreach { case (agg, i) =>
+ val acc = accumulatorRow.getField(i).asInstanceOf[Accumulator]
+ val v = value.getField(aggFields(i)) // input field that feeds aggregate i
+ agg.accumulate(acc, v)
+ }
+ }
+
+ override def getResult(accumulatorRow: Row): Row = {
+ val output = new Row(aggFields.length)
+
+ aggsWithIdx.foreach { case (agg, i) =>
+ output.setField(i, agg.getValue(accumulatorRow.getField(i).asInstanceOf[Accumulator]))
+ }
+ output
+ }
+
+ override def merge(aAccumulatorRow: Row, bAccumulatorRow: Row): Row = {
+
+ aggsWithIdx.foreach { case (agg, i) =>
+ val aAcc = aAccumulatorRow.getField(i).asInstanceOf[Accumulator]
+ val bAcc = bAccumulatorRow.getField(i).asInstanceOf[Accumulator]
+ val accumulators: JList[Accumulator] = new JArrayList[Accumulator]() // new list per merge
+ accumulators.add(aAcc)
+ accumulators.add(bAcc)
+ aAccumulatorRow.setField(i, agg.merge(accumulators)) // result replaces a's accumulator
+ }
+ aAccumulatorRow
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllTimeWindowFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllTimeWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllTimeWindowFunction.scala
deleted file mode 100644
index 89f3b41..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllTimeWindowFunction.scala
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.runtime.aggregate
-
-import java.lang.Iterable
-
-import org.apache.flink.api.common.functions.RichGroupReduceFunction
-import org.apache.flink.types.Row
-import org.apache.flink.configuration.Configuration
-import org.apache.flink.streaming.api.windowing.windows.TimeWindow
-import org.apache.flink.util.Collector
-
-class AggregateAllTimeWindowFunction(
- groupReduceFunction: RichGroupReduceFunction[Row, Row],
- windowStartPos: Option[Int],
- windowEndPos: Option[Int])
- extends AggregateAllWindowFunction[TimeWindow](groupReduceFunction) {
-
- private var collector: TimeWindowPropertyCollector = _
-
- override def open(parameters: Configuration): Unit = {
- collector = new TimeWindowPropertyCollector(windowStartPos, windowEndPos)
- super.open(parameters)
- }
-
- override def apply(window: TimeWindow, input: Iterable[Row], out: Collector[Row]): Unit = {
-
- // set collector and window
- collector.wrappedCollector = out
- collector.windowStart = window.getStart
- collector.windowEnd = window.getEnd
-
- // call wrapped reduce function with property collector
- super.apply(window, input, collector)
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllWindowFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllWindowFunction.scala
deleted file mode 100644
index 10a06da..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllWindowFunction.scala
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.runtime.aggregate
-
-import java.lang.Iterable
-
-import org.apache.flink.api.common.functions.RichGroupReduceFunction
-import org.apache.flink.types.Row
-import org.apache.flink.configuration.Configuration
-import org.apache.flink.streaming.api.functions.windowing.RichAllWindowFunction
-import org.apache.flink.streaming.api.windowing.windows.Window
-import org.apache.flink.util.Collector
-
-class AggregateAllWindowFunction[W <: Window](
- groupReduceFunction: RichGroupReduceFunction[Row, Row])
- extends RichAllWindowFunction[Row, Row, W] {
-
- override def open(parameters: Configuration): Unit = {
- groupReduceFunction.open(parameters)
- }
-
- override def apply(window: W, input: Iterable[Row], out: Collector[Row]): Unit = {
- groupReduceFunction.reduce(input, out)
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala
index 0033ff7..d936fbb 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala
@@ -22,34 +22,36 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.ResultTypeQueryable
import org.apache.flink.types.Row
import org.apache.flink.configuration.Configuration
+import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.util.Preconditions
class AggregateMapFunction[IN, OUT](
- private val aggregates: Array[Aggregate[_]],
+ private val aggregates: Array[AggregateFunction[_]],
private val aggFields: Array[Int],
private val groupingKeys: Array[Int],
@transient private val returnType: TypeInformation[OUT])
- extends RichMapFunction[IN, OUT]
- with ResultTypeQueryable[OUT] {
-
+ extends RichMapFunction[IN, OUT] with ResultTypeQueryable[OUT] {
+
private var output: Row = _
-
+
override def open(config: Configuration) {
Preconditions.checkNotNull(aggregates)
Preconditions.checkNotNull(aggFields)
Preconditions.checkArgument(aggregates.length == aggFields.length)
- val partialRowLength = groupingKeys.length +
- aggregates.map(_.intermediateDataType.length).sum
+ val partialRowLength = groupingKeys.length + aggregates.length
output = new Row(partialRowLength)
}
override def map(value: IN): OUT = {
-
+
val input = value.asInstanceOf[Row]
for (i <- aggregates.indices) {
- val fieldValue = input.getField(aggFields(i))
- aggregates(i).prepare(fieldValue, output)
+ val agg = aggregates(i)
+ val accumulator = agg.createAccumulator()
+ agg.accumulate(accumulator, input.getField(aggFields(i)))
+ output.setField(groupingKeys.length + i, accumulator)
}
+
for (i <- groupingKeys.indices) {
output.setField(i, input.getField(groupingKeys(i)))
}
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala
index 5237ecf..06ac8fb 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala
@@ -19,61 +19,84 @@
package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
+import java.util.{ArrayList => JArrayList}
import org.apache.flink.api.common.functions.CombineFunction
+import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
import org.apache.flink.types.Row
-import scala.collection.JavaConversions._
-
/**
- * It wraps the aggregate logic inside of
- * [[org.apache.flink.api.java.operators.GroupReduceOperator]] and
- * [[org.apache.flink.api.java.operators.GroupCombineOperator]]
- *
- * @param aggregates The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
- * and output Row.
- * @param aggregateMapping The index mapping between aggregate function list and aggregated value
- * index in output Row.
- * @param groupingSetsMapping The index mapping of keys in grouping sets between intermediate
- * Row and output Row.
- */
+ * It wraps the aggregate logic inside of
+ * [[org.apache.flink.api.java.operators.GroupReduceOperator]] and
+ * [[org.apache.flink.api.java.operators.GroupCombineOperator]]
+ *
+ * @param aggregates The aggregate functions.
+ * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
+ * and output Row.
+ * @param aggregateMapping The index mapping between aggregate function list and aggregated
+ * value index in output Row.
+ * @param groupingSetsMapping The index mapping of keys in grouping sets between intermediate
+ * Row and output Row.
+ * @param finalRowArity the arity of the final resulting row
+ */
class AggregateReduceCombineFunction(
- private val aggregates: Array[Aggregate[_ <: Any]],
+ private val aggregates: Array[AggregateFunction[_ <: Any]],
private val groupKeysMapping: Array[(Int, Int)],
private val aggregateMapping: Array[(Int, Int)],
private val groupingSetsMapping: Array[(Int, Int)],
- private val intermediateRowArity: Int,
private val finalRowArity: Int)
extends AggregateReduceGroupFunction(
aggregates,
groupKeysMapping,
aggregateMapping,
groupingSetsMapping,
- intermediateRowArity,
- finalRowArity)
- with CombineFunction[Row, Row] {
+ finalRowArity) with CombineFunction[Row, Row] {
/**
- * For sub-grouped intermediate aggregate Rows, merge all of them into aggregate buffer,
- *
- * @param records Sub-grouped intermediate aggregate Rows iterator.
- * @return Combined intermediate aggregate Row.
- *
- */
+ * For sub-grouped intermediate aggregate Rows, merge all of them into aggregate buffer.
+ *
+ * @param records Sub-grouped intermediate aggregate Rows iterator.
+ * @return Combined intermediate aggregate Row.
+ *
+ */
override def combine(records: Iterable[Row]): Row = {
- // Initiate intermediate aggregate value.
- aggregates.foreach(_.initiate(aggregateBuffer))
-
- // Merge intermediate aggregate value to buffer.
+ // merge intermediate aggregate value to buffer.
var last: Row = null
- records.foreach((record) => {
- aggregates.foreach(_.merge(record, aggregateBuffer))
+ accumulatorList.foreach(_.clear())
+
+ val iterator = records.iterator()
+
+ var count: Int = 0
+ while (iterator.hasNext) {
+ val record = iterator.next()
+ count += 1
+ // per each aggregator, collect its accumulators to a list
+ for (i <- aggregates.indices) {
+ accumulatorList(i).add(record.getField(groupKeysMapping.length + i)
+ .asInstanceOf[Accumulator])
+ }
+ // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one
+ // accumulator
+ if (count > maxMergeLen) {
+ count = 0
+ for (i <- aggregates.indices) {
+ val agg = aggregates(i)
+ val accumulator = agg.merge(accumulatorList(i))
+ accumulatorList(i).clear()
+ accumulatorList(i).add(accumulator)
+ }
+ }
last = record
- })
+ }
+
+ for (i <- aggregates.indices) {
+ val agg = aggregates(i)
+ aggregateBuffer.setField(groupKeysMapping.length + i, agg.merge(accumulatorList(i)))
+ }
- // Set group keys to aggregateBuffer.
+ // set group keys to aggregateBuffer.
for (i <- groupKeysMapping.indices) {
aggregateBuffer.setField(i, last.getField(i))
}
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala
index c147629..23b5236 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala
@@ -18,43 +18,48 @@
package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
+import java.util.{ArrayList => JArrayList}
import org.apache.flink.api.common.functions.RichGroupReduceFunction
import org.apache.flink.configuration.Configuration
+import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
import org.apache.flink.types.Row
import org.apache.flink.util.{Collector, Preconditions}
-import scala.collection.JavaConversions._
-
/**
- * It wraps the aggregate logic inside of
- * [[org.apache.flink.api.java.operators.GroupReduceOperator]].
- *
- * @param aggregates The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
- * and output Row.
- * @param aggregateMapping The index mapping between aggregate function list and aggregated value
- * index in output Row.
- * @param groupingSetsMapping The index mapping of keys in grouping sets between intermediate
- * Row and output Row.
- */
+ * It wraps the aggregate logic inside of
+ * [[org.apache.flink.api.java.operators.GroupReduceOperator]].
+ *
+ * @param aggregates The aggregate functions.
+ * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
+ * and output Row.
+ * @param aggregateMapping The index mapping between aggregate function list and aggregated
+ * value index in output Row.
+ * @param groupingSetsMapping The index mapping of keys in grouping sets between intermediate
+ * Row and output Row.
+ * @param finalRowArity The arity of the final resulting row
+ */
class AggregateReduceGroupFunction(
- private val aggregates: Array[Aggregate[_ <: Any]],
+ private val aggregates: Array[AggregateFunction[_ <: Any]],
private val groupKeysMapping: Array[(Int, Int)],
private val aggregateMapping: Array[(Int, Int)],
private val groupingSetsMapping: Array[(Int, Int)],
- private val intermediateRowArity: Int,
private val finalRowArity: Int)
extends RichGroupReduceFunction[Row, Row] {
protected var aggregateBuffer: Row = _
private var output: Row = _
private var intermediateGroupKeys: Option[Array[Int]] = None
+ protected val maxMergeLen = 16
+ val accumulatorList = Array.fill(aggregates.length) {
+ new JArrayList[Accumulator]()
+ }
override def open(config: Configuration) {
Preconditions.checkNotNull(aggregates)
Preconditions.checkNotNull(groupKeysMapping)
- aggregateBuffer = new Row(intermediateRowArity)
+ aggregateBuffer = new Row(aggregates.length + groupKeysMapping.length)
output = new Row(finalRowArity)
if (!groupingSetsMapping.isEmpty) {
intermediateGroupKeys = Some(groupKeysMapping.map(_._1))
@@ -62,25 +67,44 @@ class AggregateReduceGroupFunction(
}
/**
- * For grouped intermediate aggregate Rows, merge all of them into aggregate buffer,
- * calculate aggregated values output by aggregate buffer, and set them into output
- * Row based on the mapping relation between intermediate aggregate data and output data.
- *
- * @param records Grouped intermediate aggregate Rows iterator.
- * @param out The collector to hand results to.
- *
- */
+ * For grouped intermediate aggregate Rows, merge all of them into aggregate buffer,
+ * calculate aggregated values output by aggregate buffer, and set them into output
+ * Row based on the mapping relation between intermediate aggregate data and output data.
+ *
+ * @param records Grouped intermediate aggregate Rows iterator.
+ * @param out The collector to hand results to.
+ *
+ */
override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = {
- // Initiate intermediate aggregate value.
- aggregates.foreach(_.initiate(aggregateBuffer))
-
- // Merge intermediate aggregate value to buffer.
+ // merge intermediate aggregate value to buffer.
var last: Row = null
- records.foreach((record) => {
- aggregates.foreach(_.merge(record, aggregateBuffer))
+ accumulatorList.foreach(_.clear())
+
+ val iterator = records.iterator()
+
+ var count: Int = 0
+ while (iterator.hasNext) {
+ val record = iterator.next()
+ count += 1
+ // for each aggregate function, add this record's accumulator to that function's list
+ for (i <- aggregates.indices) {
+ accumulatorList(i).add(record.getField(groupKeysMapping.length + i)
+ .asInstanceOf[Accumulator])
+ }
+ // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one
+ // accumulator
+ if (count > maxMergeLen) {
+ count = 0
+ for (i <- aggregates.indices) {
+ val agg = aggregates(i)
+ val accumulator = agg.merge(accumulatorList(i))
+ accumulatorList(i).clear()
+ accumulatorList(i).add(accumulator)
+ }
+ }
last = record
- })
+ }
// Set group keys value to final output.
groupKeysMapping.foreach {
@@ -88,10 +112,14 @@ class AggregateReduceGroupFunction(
output.setField(after, last.getField(previous))
}
- // Evaluate final aggregate value and set to output.
+ // get final aggregate value and set to output.
aggregateMapping.foreach {
- case (after, previous) =>
- output.setField(after, aggregates(previous).evaluate(aggregateBuffer))
+ case (after, previous) => {
+ val agg = aggregates(previous)
+ val accumulator = agg.merge(accumulatorList(previous))
+ val result = agg.getValue(accumulator)
+ output.setField(after, result)
+ }
}
// Evaluate additional values of grouping sets
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateTimeWindowFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateTimeWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateTimeWindowFunction.scala
deleted file mode 100644
index 8f96848..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateTimeWindowFunction.scala
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.runtime.aggregate
-
-import java.lang.Iterable
-
-import org.apache.flink.api.common.functions.RichGroupReduceFunction
-import org.apache.flink.api.java.tuple.Tuple
-import org.apache.flink.types.Row
-import org.apache.flink.configuration.Configuration
-import org.apache.flink.streaming.api.windowing.windows.TimeWindow
-import org.apache.flink.util.Collector
-
-class AggregateTimeWindowFunction(
- groupReduceFunction: RichGroupReduceFunction[Row, Row],
- windowStartPos: Option[Int],
- windowEndPos: Option[Int])
- extends AggregateWindowFunction[TimeWindow](groupReduceFunction) {
-
- private var collector: TimeWindowPropertyCollector = _
-
- override def open(parameters: Configuration): Unit = {
- collector = new TimeWindowPropertyCollector(windowStartPos, windowEndPos)
- super.open(parameters)
- }
-
- override def apply(
- key: Tuple,
- window: TimeWindow,
- input: Iterable[Row],
- out: Collector[Row]): Unit = {
-
- // set collector and window
- collector.wrappedCollector = out
- collector.windowStart = window.getStart
- collector.windowEnd = window.getEnd
-
- // call wrapped reduce function with property collector
- super.apply(key, window, input, collector)
- }
-}