Posted to commits@flink.apache.org by fh...@apache.org on 2017/03/02 20:32:42 UTC
[1/3] flink git commit: [FLINK-5768] [table] Refactor DataSet and
DataStream aggregations to use UDAGG interface.
Repository: flink
Updated Branches:
refs/heads/master 7fe0eb477 -> 438276de8
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala
index 627b25b..5ba3e34 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala
@@ -51,7 +51,7 @@ abstract class AggFunctionTestBase[T] {
// test aggregate functions with partial merge
def testAggregateWithMerge(): Unit = {
- if (ifMethodExitInFunction("merge", aggregator)) {
+ if (ifMethodExistInFunction("merge", aggregator)) {
// iterate over input sets
for ((vals, expected) <- inputValueSets.zip(expectedResults)) {
//equally split the vals sequence into two sequences
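For context, testAggregateWithMerge splits each input set into two halves, accumulates each half into its own accumulator, merges the two, and expects the merged result to equal the single-pass result. A minimal sketch of that check, assuming the createAccumulator/accumulate/merge/getValue signatures shown later in this commit (and imports of java.util.ArrayList and org.apache.flink.table.functions.{Accumulator, AggregateFunction}):

    def checkMerge[T](agg: AggregateFunction[T], vals: Seq[Any], expected: T): Unit = {
      val (left, right) = vals.splitAt(vals.size / 2)

      // accumulate each half into its own accumulator
      val accLeft = agg.createAccumulator()
      left.foreach(v => agg.accumulate(accLeft, v))
      val accRight = agg.createAccumulator()
      right.foreach(v => agg.accumulate(accRight, v))

      // merging the partial accumulators must reproduce the one-pass result
      val partials = new java.util.ArrayList[Accumulator]()
      partials.add(accLeft)
      partials.add(accRight)
      assert(agg.getValue(agg.merge(partials)) == expected)
    }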
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetWindowAggregateITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetWindowAggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetWindowAggregateITCase.scala
index f13f350..071f0ee 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetWindowAggregateITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetWindowAggregateITCase.scala
@@ -18,6 +18,8 @@
package org.apache.flink.table.runtime.dataset
+import java.math.BigDecimal
+
import org.apache.flink.api.scala._
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.scala._
@@ -37,20 +39,22 @@ class DataSetWindowAggregateITCase(configMode: TableConfigMode)
extends TableProgramsClusterTestBase(configMode) {
val data = List(
- (1L, 1, "Hi"),
- (2L, 2, "Hallo"),
- (3L, 2, "Hello"),
- (6L, 3, "Hello"),
- (4L, 5, "Hello"),
- (16L, 4, "Hello world"),
- (8L, 3, "Hello world"))
+ (1L, 1, 1d, 1f, new BigDecimal("1"), "Hi"),
+ (2L, 2, 2d, 2f, new BigDecimal("2"), "Hallo"),
+ (3L, 2, 2d, 2f, new BigDecimal("2"), "Hello"),
+ (6L, 3, 3d, 3f, new BigDecimal("3"), "Hello"),
+ (4L, 5, 5d, 5f, new BigDecimal("5"), "Hello"),
+ (16L, 4, 4d, 4f, new BigDecimal("4"), "Hello world"),
+ (8L, 3, 3d, 3f, new BigDecimal("3"), "Hello world"))
@Test(expected = classOf[UnsupportedOperationException])
def testAllEventTimeTumblingWindowOverCount(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)
- val table = env.fromCollection(data).toTable(tEnv, 'long, 'int, 'string)
+ val table = env
+ .fromCollection(data)
+ .toTable(tEnv, 'long, 'int, 'double, 'float, 'bigdec, 'string)
// Count tumbling non-grouping windows on event-time are currently not supported
table
@@ -65,14 +69,20 @@ class DataSetWindowAggregateITCase(configMode: TableConfigMode)
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)
- val table = env.fromCollection(data).toTable(tEnv, 'long, 'int, 'string)
+ val table = env
+ .fromCollection(data)
+ .toTable(tEnv, 'long, 'int, 'double, 'float, 'bigdec, 'string)
val windowedTable = table
.window(Tumble over 2.rows on 'long as 'w)
.groupBy('w, 'string)
- .select('string, 'int.sum)
+ .select('string, 'int.sum, 'int.count, 'int.max, 'int.min, 'int.avg,
+ 'double.sum, 'double.count, 'double.max, 'double.min, 'double.avg,
+ 'float.sum, 'float.count, 'float.max, 'float.min, 'float.avg,
+ 'bigdec.sum, 'bigdec.count, 'bigdec.max, 'bigdec.min, 'bigdec.avg)
- val expected = "Hello,7\n" + "Hello world,7\n"
+ val expected = "Hello,7,2,5,2,3,7.0,2,5.0,2.0,3.5,7.0,2,5.0,2.0,3.5,7,2,5,2,3.5\n" +
+ "Hello world,7,2,4,3,3,7.0,2,4.0,3.0,3.5,7.0,2,4.0,3.0,3.5,7,2,4,3,3.5\n"
val results = windowedTable.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
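The expected strings can be verified by hand: for the "Hello" group the emitted 2-row window holds the int values 2 and 5 (with matching double, float, and decimal columns), and each type contributes sum, count, max, min, avg in that order:

    // sanity check of the "Hello,7,2,5,2,3,...,3.5" line above
    val window = Seq(2, 5)
    val sum    = window.sum                          // 7
    val count  = window.size                         // 2
    val max    = window.max                          // 5
    val min    = window.min                          // 2
    val intAvg = sum / count                         // 3   (integral avg truncates)
    val dblAvg = window.map(_.toDouble).sum / count  // 3.5 (double/float/decimal avg)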
@@ -82,7 +92,9 @@ class DataSetWindowAggregateITCase(configMode: TableConfigMode)
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)
- val table = env.fromCollection(data).toTable(tEnv, 'long, 'int, 'string)
+ val table = env
+ .fromCollection(data)
+ .toTable(tEnv, 'long, 'int, 'double, 'float, 'bigdec, 'string)
val windowedTable = table
.window(Tumble over 5.milli on 'long as 'w)
@@ -105,7 +117,9 @@ class DataSetWindowAggregateITCase(configMode: TableConfigMode)
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)
- val table = env.fromCollection(data).toTable(tEnv, 'long, 'int, 'string)
+ val table = env
+ .fromCollection(data)
+ .toTable(tEnv, 'long, 'int, 'double, 'float, 'bigdec, 'string)
val windowedTable = table
.window(Tumble over 5.milli on 'long as 'w)
@@ -125,7 +139,9 @@ class DataSetWindowAggregateITCase(configMode: TableConfigMode)
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)
- val table = env.fromCollection(data).toTable(tEnv, 'long, 'int, 'string)
+ val table = env
+ .fromCollection(data)
+ .toTable(tEnv, 'long, 'int, 'double, 'float, 'bigdec, 'string)
val windowedTable = table
.window(Session withGap 7.milli on 'long as 'w)
.groupBy('string, 'w)
@@ -146,7 +162,9 @@ class DataSetWindowAggregateITCase(configMode: TableConfigMode)
// Non-grouping Session windows on event-time are currently not supported
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)
- val table = env.fromCollection(data).toTable(tEnv, 'long, 'int, 'string)
+ val table = env
+ .fromCollection(data)
+ .toTable(tEnv, 'long, 'int, 'double, 'float, 'bigdec, 'string)
val windowedTable = table
.window(Session withGap 7.milli on 'long as 'w)
.groupBy('w)
@@ -158,7 +176,9 @@ class DataSetWindowAggregateITCase(configMode: TableConfigMode)
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)
- val table = env.fromCollection(data).toTable(tEnv, 'long, 'int, 'string)
+ val table = env
+ .fromCollection(data)
+ .toTable(tEnv, 'long, 'int, 'double, 'float, 'bigdec, 'string)
table
.window(Tumble over 5.milli on 'long as 'w)
.groupBy('w, 'string)
[3/3] flink git commit: [FLINK-5768] [table] Refactor DataSet and
DataStream aggregations to use UDAGG interface.
Posted by fh...@apache.org.
[FLINK-5768] [table] Refactor DataSet and DataStream aggregations to use UDAGG interface.
- DataStream aggregates use the new WindowedStream.aggregate() operator.
This closes #3423.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/438276de
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/438276de
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/438276de
Branch: refs/heads/master
Commit: 438276de8fab4f1a8f2b62b6452c2e5b2998ce5a
Parents: 7fe0eb4
Author: shaoxuan-wang <ws...@gmail.com>
Authored: Mon Feb 27 19:09:30 2017 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Thu Mar 2 21:31:18 2017 +0100
----------------------------------------------------------------------
.../table/functions/AggregateFunction.scala | 13 +-
.../functions/aggfunctions/AvgAggFunction.scala | 91 ++--
.../aggfunctions/CountAggFunction.scala | 19 +-
.../functions/aggfunctions/MaxAggFunction.scala | 66 ++-
.../functions/aggfunctions/MinAggFunction.scala | 66 ++-
.../functions/aggfunctions/SumAggFunction.scala | 76 ++-
.../utils/UserDefinedFunctionUtils.scala | 4 +-
.../plan/nodes/dataset/DataSetAggregate.scala | 25 +-
.../nodes/dataset/DataSetWindowAggregate.scala | 28 +-
.../nodes/datastream/DataStreamAggregate.scala | 152 ++----
.../aggregate/AggregateAggFunction.scala | 79 ++++
.../AggregateAllTimeWindowFunction.scala | 52 ---
.../aggregate/AggregateAllWindowFunction.scala | 41 --
.../aggregate/AggregateMapFunction.scala | 22 +-
.../AggregateReduceCombineFunction.scala | 89 ++--
.../AggregateReduceGroupFunction.scala | 96 ++--
.../aggregate/AggregateTimeWindowFunction.scala | 57 ---
.../table/runtime/aggregate/AggregateUtil.scala | 464 ++++++++-----------
.../aggregate/AggregateWindowFunction.scala | 46 --
...ionWindowAggregateCombineGroupFunction.scala | 88 ++--
...sionWindowAggregateReduceGroupFunction.scala | 104 +++--
...umbleCountWindowAggReduceGroupFunction.scala | 49 +-
...mbleTimeWindowAggReduceCombineFunction.scala | 58 ++-
...TumbleTimeWindowAggReduceGroupFunction.scala | 60 ++-
.../DataSetWindowAggregateMapFunction.scala | 18 +-
...rementalAggregateAllTimeWindowFunction.scala | 24 +-
.../IncrementalAggregateAllWindowFunction.scala | 30 +-
.../IncrementalAggregateReduceFunction.scala | 63 ---
...IncrementalAggregateTimeWindowFunction.scala | 32 +-
.../IncrementalAggregateWindowFunction.scala | 40 +-
.../scala/stream/table/AggregationsITCase.scala | 10 +-
.../aggfunctions/AggFunctionTestBase.scala | 2 +-
.../dataset/DataSetWindowAggregateITCase.scala | 52 ++-
33 files changed, 1050 insertions(+), 1066 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
index e15a8c4..178b439 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
@@ -19,12 +19,14 @@ package org.apache.flink.table.functions
import java.util.{List => JList}
+import org.apache.flink.api.common.typeinfo.TypeInformation
+
/**
* Base class for User-Defined Aggregates.
*
* @tparam T the type of the aggregation result
*/
-trait AggregateFunction[T] extends UserDefinedFunction {
+abstract class AggregateFunction[T] extends UserDefinedFunction {
/**
* Create and init the Accumulator for this [[AggregateFunction]].
*
@@ -61,6 +63,15 @@ trait AggregateFunction[T] extends UserDefinedFunction {
* @return the resulting accumulator
*/
def merge(accumulators: JList[Accumulator]): Accumulator
+
+ /**
+ * Returns the [[TypeInformation]] of the accumulator.
+ * This function is optional and can be implemented if the accumulator type cannot be automatically
+ * inferred from the instance returned by [[createAccumulator()]].
+ *
+ * @return The type information for the accumulator.
+ */
+ def getAccumulatorType(): TypeInformation[_] = null
}
/**
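To illustrate the refactored interface, here is a minimal user-defined aggregate against the abstract class above. MyLongSum and LongSumAccumulator are hypothetical names; the method signatures are the ones defined in this file (JTuple1 and JList being the usual org.apache.flink.api.java.tuple and java.util aliases):

    class LongSumAccumulator extends JTuple1[Long] with Accumulator {
      f0 = 0L // running sum
    }

    class MyLongSum extends AggregateFunction[Long] {

      override def createAccumulator(): Accumulator = new LongSumAccumulator

      override def accumulate(accumulator: Accumulator, value: Any): Unit = {
        if (value != null) {
          accumulator.asInstanceOf[LongSumAccumulator].f0 += value.asInstanceOf[Long]
        }
      }

      override def getValue(accumulator: Accumulator): Long =
        accumulator.asInstanceOf[LongSumAccumulator].f0

      override def merge(accumulators: JList[Accumulator]): Accumulator = {
        val ret = new LongSumAccumulator
        var i = 0
        while (i < accumulators.size()) {
          ret.f0 += accumulators.get(i).asInstanceOf[LongSumAccumulator].f0
          i += 1
        }
        ret
      }
    }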
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
index f4c0b7b..534bb03 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
@@ -19,9 +19,18 @@ package org.apache.flink.table.functions.aggfunctions
import java.math.{BigDecimal, BigInteger}
import java.util.{List => JList}
+
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
+import org.apache.flink.api.java.typeutils.TupleTypeInfo
import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+/** The initial accumulator for Integral Avg aggregate function */
+class IntegralAvgAccumulator extends JTuple2[Long, Long] with Accumulator {
+ f0 = 0L //sum
+ f1 = 0L //count
+}
+
/**
* Base class for built-in Integral Avg aggregate function
*
@@ -29,12 +38,6 @@ import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
*/
abstract class IntegralAvgAggFunction[T] extends AggregateFunction[T] {
- /** The initial accumulator for Integral Avg aggregate function */
- class IntegralAvgAccumulator extends JTuple2[Long, Long] with Accumulator {
- f0 = 0 //sum
- f1 = 0 //count
- }
-
override def createAccumulator(): Accumulator = {
new IntegralAvgAccumulator
}
@@ -44,7 +47,7 @@ abstract class IntegralAvgAggFunction[T] extends AggregateFunction[T] {
val v = value.asInstanceOf[Number].longValue()
val accum = accumulator.asInstanceOf[IntegralAvgAccumulator]
accum.f0 += v
- accum.f1 += 1
+ accum.f1 += 1L
}
}
@@ -69,6 +72,13 @@ abstract class IntegralAvgAggFunction[T] extends AggregateFunction[T] {
ret
}
+ override def getAccumulatorType(): TypeInformation[_] = {
+ new TupleTypeInfo(
+ new IntegralAvgAccumulator().getClass,
+ BasicTypeInfo.LONG_TYPE_INFO,
+ BasicTypeInfo.LONG_TYPE_INFO)
+ }
+
/**
* Convert the intermediate result to the expected aggregation result type
*
@@ -100,6 +110,13 @@ class IntAvgAggFunction extends IntegralAvgAggFunction[Int] {
override def resultTypeConvert(value: Long): Int = value.toInt
}
+/** The initial accumulator for Big Integral Avg aggregate function */
+class BigIntegralAvgAccumulator
+ extends JTuple2[BigInteger, Long] with Accumulator {
+ f0 = BigInteger.ZERO //sum
+ f1 = 0L //count
+}
+
/**
* Base Class for Built-in Big Integral Avg aggregate function
*
@@ -107,13 +124,6 @@ class IntAvgAggFunction extends IntegralAvgAggFunction[Int] {
*/
abstract class BigIntegralAvgAggFunction[T] extends AggregateFunction[T] {
- /** The initial accumulator for Big Integral Avg aggregate function */
- class BigIntegralAvgAccumulator
- extends JTuple2[BigInteger, Long] with Accumulator {
- f0 = BigInteger.ZERO //sum
- f1 = 0 //count
- }
-
override def createAccumulator(): Accumulator = {
new BigIntegralAvgAccumulator
}
@@ -123,7 +133,7 @@ abstract class BigIntegralAvgAggFunction[T] extends AggregateFunction[T] {
val v = value.asInstanceOf[Long]
val a = accumulator.asInstanceOf[BigIntegralAvgAccumulator]
a.f0 = a.f0.add(BigInteger.valueOf(v))
- a.f1 += 1
+ a.f1 += 1L
}
}
@@ -148,6 +158,13 @@ abstract class BigIntegralAvgAggFunction[T] extends AggregateFunction[T] {
ret
}
+ override def getAccumulatorType(): TypeInformation[_] = {
+ new TupleTypeInfo(
+ new BigIntegralAvgAccumulator().getClass,
+ BasicTypeInfo.BIG_INT_TYPE_INFO,
+ BasicTypeInfo.LONG_TYPE_INFO)
+ }
+
/**
* Convert the intermediate result to the expected aggregation result type
*
@@ -166,6 +183,12 @@ class LongAvgAggFunction extends BigIntegralAvgAggFunction[Long] {
override def resultTypeConvert(value: BigInteger): Long = value.longValue()
}
+/** The initial accumulator for Floating Avg aggregate function */
+class FloatingAvgAccumulator extends JTuple2[Double, Long] with Accumulator {
+ f0 = 0 //sum
+ f1 = 0L //count
+}
+
/**
* Base class for built-in Floating Avg aggregate function
*
@@ -173,12 +196,6 @@ class LongAvgAggFunction extends BigIntegralAvgAggFunction[Long] {
*/
abstract class FloatingAvgAggFunction[T] extends AggregateFunction[T] {
- /** The initial accumulator for Floating Avg aggregate function */
- class FloatingAvgAccumulator extends JTuple2[Double, Long] with Accumulator {
- f0 = 0 //sum
- f1 = 0 //count
- }
-
override def createAccumulator(): Accumulator = {
new FloatingAvgAccumulator
}
@@ -188,7 +205,7 @@ abstract class FloatingAvgAggFunction[T] extends AggregateFunction[T] {
val v = value.asInstanceOf[Number].doubleValue()
val accum = accumulator.asInstanceOf[FloatingAvgAccumulator]
accum.f0 += v
- accum.f1 += 1
+ accum.f1 += 1L
}
}
@@ -213,6 +230,13 @@ abstract class FloatingAvgAggFunction[T] extends AggregateFunction[T] {
ret
}
+ override def getAccumulatorType(): TypeInformation[_] = {
+ new TupleTypeInfo(
+ new FloatingAvgAccumulator().getClass,
+ BasicTypeInfo.DOUBLE_TYPE_INFO,
+ BasicTypeInfo.LONG_TYPE_INFO)
+ }
+
/**
* Convert the intermediate result to the expected aggregation result type
*
@@ -237,18 +261,18 @@ class DoubleAvgAggFunction extends FloatingAvgAggFunction[Double] {
override def resultTypeConvert(value: Double): Double = value
}
+/** The initial accumulator for Big Decimal Avg aggregate function */
+class DecimalAvgAccumulator
+ extends JTuple2[BigDecimal, Long] with Accumulator {
+ f0 = BigDecimal.ZERO //sum
+ f1 = 0L //count
+}
+
/**
* Base class for built-in Big Decimal Avg aggregate function
*/
class DecimalAvgAggFunction extends AggregateFunction[BigDecimal] {
- /** The initial accumulator for Big Decimal Avg aggregate function */
- class DecimalAvgAccumulator
- extends JTuple2[BigDecimal, Long] with Accumulator {
- f0 = BigDecimal.ZERO //sum
- f1 = 0 //count
- }
-
override def createAccumulator(): Accumulator = {
new DecimalAvgAccumulator
}
@@ -262,7 +286,7 @@ class DecimalAvgAggFunction extends AggregateFunction[BigDecimal] {
} else {
accum.f0 = accum.f0.add(v)
}
- accum.f1 += 1
+ accum.f1 += 1L
}
}
@@ -286,4 +310,11 @@ class DecimalAvgAggFunction extends AggregateFunction[BigDecimal] {
}
ret
}
+
+ override def getAccumulatorType(): TypeInformation[_] = {
+ new TupleTypeInfo(
+ new DecimalAvgAccumulator().getClass,
+ BasicTypeInfo.BIG_DEC_TYPE_INFO,
+ BasicTypeInfo.LONG_TYPE_INFO)
+ }
}
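The pattern above recurs through the rest of the commit: accumulator classes move from inner classes to the top level (so they carry no reference to an enclosing function instance when serialized as window state), and each function reports an explicit TupleTypeInfo for its accumulator fields. A usage sketch, assuming getValue divides the running sum by the count:

    val avg = new DoubleAvgAggFunction
    val acc = avg.createAccumulator()
    Seq(1.0, 2.0, 6.0).foreach(v => avg.accumulate(acc, v))
    // acc (a FloatingAvgAccumulator) now holds f0 = 9.0 (sum) and f1 = 3 (count)
    println(avg.getValue(acc)) // 3.0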
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
index 8b903d1..cf884ed 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
@@ -18,22 +18,25 @@
package org.apache.flink.table.functions.aggfunctions
import java.util.{List => JList}
+
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple1 => JTuple1}
+import org.apache.flink.api.java.typeutils.TupleTypeInfo
import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+/** The initial accumulator for count aggregate function */
+class CountAccumulator extends JTuple1[Long] with Accumulator {
+ f0 = 0L //count
+}
+
/**
* built-in count aggregate function
*/
class CountAggFunction extends AggregateFunction[Long] {
- /** The initial accumulator for count aggregate function */
- class CountAccumulator extends JTuple1[Long] with Accumulator {
- f0 = 0 //count
- }
-
override def accumulate(accumulator: Accumulator, value: Any): Unit = {
if (value != null) {
- accumulator.asInstanceOf[CountAccumulator].f0 += 1
+ accumulator.asInstanceOf[CountAccumulator].f0 += 1L
}
}
@@ -54,4 +57,8 @@ class CountAggFunction extends AggregateFunction[Long] {
override def createAccumulator(): Accumulator = {
new CountAccumulator
}
+
+ override def getAccumulatorType(): TypeInformation[_] = {
+ new TupleTypeInfo((new CountAccumulator).getClass, BasicTypeInfo.LONG_TYPE_INFO)
+ }
}
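A usage sketch: COUNT only counts non-null inputs, per the null guard in accumulate above:

    val count = new CountAggFunction
    val acc = count.createAccumulator()
    Seq("a", null, "b").foreach(v => count.accumulate(acc, v))
    println(count.getValue(acc)) // 2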
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
index 20041ee..62ff88c 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
@@ -19,9 +19,18 @@ package org.apache.flink.table.functions.aggfunctions
import java.math.BigDecimal
import java.util.{List => JList}
+
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
+import org.apache.flink.api.java.typeutils.TupleTypeInfo
import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+/** The initial accumulator for Max aggregate function */
+class MaxAccumulator[T] extends JTuple2[T, Boolean] with Accumulator {
+ f0 = 0.asInstanceOf[T] //max
+ f1 = false
+}
+
/**
* Base class for built-in Max aggregate function
*
@@ -29,20 +38,14 @@ import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
*/
abstract class MaxAggFunction[T](implicit ord: Ordering[T]) extends AggregateFunction[T] {
- /** The initial accumulator for Max aggregate function */
- class MaxAccumulator extends JTuple2[T, Boolean] with Accumulator {
- f0 = 0.asInstanceOf[T] //max
- f1 = false
- }
-
override def createAccumulator(): Accumulator = {
- new MaxAccumulator
+ new MaxAccumulator[T]
}
override def accumulate(accumulator: Accumulator, value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[T]
- val a = accumulator.asInstanceOf[MaxAccumulator]
+ val a = accumulator.asInstanceOf[MaxAccumulator[T]]
if (!a.f1 || ord.compare(a.f0, v) < 0) {
a.f0 = v
a.f1 = true
@@ -51,7 +54,7 @@ abstract class MaxAggFunction[T](implicit ord: Ordering[T]) extends AggregateFun
}
override def getValue(accumulator: Accumulator): T = {
- val a = accumulator.asInstanceOf[MaxAccumulator]
+ val a = accumulator.asInstanceOf[MaxAccumulator[T]]
if (a.f1) {
a.f0
} else {
@@ -63,50 +66,73 @@ abstract class MaxAggFunction[T](implicit ord: Ordering[T]) extends AggregateFun
val ret = accumulators.get(0)
var i: Int = 1
while (i < accumulators.size()) {
- val a = accumulators.get(i).asInstanceOf[MaxAccumulator]
+ val a = accumulators.get(i).asInstanceOf[MaxAccumulator[T]]
if (a.f1) {
- accumulate(ret.asInstanceOf[MaxAccumulator], a.f0)
+ accumulate(ret.asInstanceOf[MaxAccumulator[T]], a.f0)
}
i += 1
}
ret
}
+
+ override def getAccumulatorType(): TypeInformation[_] = {
+ new TupleTypeInfo(
+ new MaxAccumulator[T].getClass,
+ getValueTypeInfo,
+ BasicTypeInfo.BOOLEAN_TYPE_INFO)
+ }
+
+ def getValueTypeInfo: TypeInformation[_]
}
/**
* Built-in Byte Max aggregate function
*/
-class ByteMaxAggFunction extends MaxAggFunction[Byte]
+class ByteMaxAggFunction extends MaxAggFunction[Byte] {
+ override def getValueTypeInfo = BasicTypeInfo.BYTE_TYPE_INFO
+}
/**
* Built-in Short Max aggregate function
*/
-class ShortMaxAggFunction extends MaxAggFunction[Short]
+class ShortMaxAggFunction extends MaxAggFunction[Short] {
+ override def getValueTypeInfo = BasicTypeInfo.SHORT_TYPE_INFO
+}
/**
* Built-in Int Max aggregate function
*/
-class IntMaxAggFunction extends MaxAggFunction[Int]
+class IntMaxAggFunction extends MaxAggFunction[Int] {
+ override def getValueTypeInfo = BasicTypeInfo.INT_TYPE_INFO
+}
/**
* Built-in Long Max aggregate function
*/
-class LongMaxAggFunction extends MaxAggFunction[Long]
+class LongMaxAggFunction extends MaxAggFunction[Long] {
+ override def getValueTypeInfo = BasicTypeInfo.LONG_TYPE_INFO
+}
/**
* Built-in Float Max aggregate function
*/
-class FloatMaxAggFunction extends MaxAggFunction[Float]
+class FloatMaxAggFunction extends MaxAggFunction[Float] {
+ override def getValueTypeInfo = BasicTypeInfo.FLOAT_TYPE_INFO
+}
/**
* Built-in Double Max aggregate function
*/
-class DoubleMaxAggFunction extends MaxAggFunction[Double]
+class DoubleMaxAggFunction extends MaxAggFunction[Double] {
+ override def getValueTypeInfo = BasicTypeInfo.DOUBLE_TYPE_INFO
+}
/**
* Built-in Boolean Max aggregate function
*/
-class BooleanMaxAggFunction extends MaxAggFunction[Boolean]
+class BooleanMaxAggFunction extends MaxAggFunction[Boolean] {
+ override def getValueTypeInfo = BasicTypeInfo.BOOLEAN_TYPE_INFO
+}
/**
* Built-in Big Decimal Max aggregate function
@@ -116,11 +142,13 @@ class DecimalMaxAggFunction extends MaxAggFunction[BigDecimal] {
override def accumulate(accumulator: Accumulator, value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[BigDecimal]
- val accum = accumulator.asInstanceOf[MaxAccumulator]
+ val accum = accumulator.asInstanceOf[MaxAccumulator[BigDecimal]]
if (!accum.f1 || accum.f0.compareTo(v) < 0) {
accum.f0 = v
accum.f1 = true
}
}
}
+
+ override def getValueTypeInfo = BasicTypeInfo.BIG_DEC_TYPE_INFO
}
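The Boolean f1 field distinguishes "no value seen yet" from a genuine result, so the 0 sentinel in the generic accumulator never leaks out; MinAggFunction below follows exactly the same pattern. A usage sketch:

    val max = new IntMaxAggFunction
    val acc = max.createAccumulator() // f0 = 0, f1 = false
    max.accumulate(acc, -5)           // f0 = -5, f1 = true
    max.accumulate(acc, null)         // null inputs are ignored
    println(max.getValue(acc))        // -5, not the 0 sentinel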
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
index 16461ae..cddb873 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
@@ -19,9 +19,18 @@ package org.apache.flink.table.functions.aggfunctions
import java.math.BigDecimal
import java.util.{List => JList}
+
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
+import org.apache.flink.api.java.typeutils.TupleTypeInfo
import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+/** The initial accumulator for Min aggregate function */
+class MinAccumulator[T] extends JTuple2[T, Boolean] with Accumulator {
+ f0 = 0.asInstanceOf[T] //min
+ f1 = false
+}
+
/**
* Base class for built-in Min aggregate function
*
@@ -29,20 +38,14 @@ import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
*/
abstract class MinAggFunction[T](implicit ord: Ordering[T]) extends AggregateFunction[T] {
- /** The initial accumulator for Min aggregate function */
- class MinAccumulator extends JTuple2[T, Boolean] with Accumulator {
- f0 = 0.asInstanceOf[T] //min
- f1 = false
- }
-
override def createAccumulator(): Accumulator = {
- new MinAccumulator
+ new MinAccumulator[T]
}
override def accumulate(accumulator: Accumulator, value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[T]
- val a = accumulator.asInstanceOf[MinAccumulator]
+ val a = accumulator.asInstanceOf[MinAccumulator[T]]
if (!a.f1 || ord.compare(a.f0, v) > 0) {
a.f0 = v
a.f1 = true
@@ -51,7 +54,7 @@ abstract class MinAggFunction[T](implicit ord: Ordering[T]) extends AggregateFun
}
override def getValue(accumulator: Accumulator): T = {
- val a = accumulator.asInstanceOf[MinAccumulator]
+ val a = accumulator.asInstanceOf[MinAccumulator[T]]
if (a.f1) {
a.f0
} else {
@@ -63,50 +66,73 @@ abstract class MinAggFunction[T](implicit ord: Ordering[T]) extends AggregateFun
val ret = accumulators.get(0)
var i: Int = 1
while (i < accumulators.size()) {
- val a = accumulators.get(i).asInstanceOf[MinAccumulator]
+ val a = accumulators.get(i).asInstanceOf[MinAccumulator[T]]
if (a.f1) {
- accumulate(ret.asInstanceOf[MinAccumulator], a.f0)
+ accumulate(ret.asInstanceOf[MinAccumulator[T]], a.f0)
}
i += 1
}
ret
}
+
+ override def getAccumulatorType(): TypeInformation[_] = {
+ new TupleTypeInfo(
+ new MinAccumulator[T].getClass,
+ getValueTypeInfo,
+ BasicTypeInfo.BOOLEAN_TYPE_INFO)
+ }
+
+ def getValueTypeInfo: TypeInformation[_]
}
/**
* Built-in Byte Min aggregate function
*/
-class ByteMinAggFunction extends MinAggFunction[Byte]
+class ByteMinAggFunction extends MinAggFunction[Byte] {
+ override def getValueTypeInfo = BasicTypeInfo.BYTE_TYPE_INFO
+}
/**
* Built-in Short Min aggregate function
*/
-class ShortMinAggFunction extends MinAggFunction[Short]
+class ShortMinAggFunction extends MinAggFunction[Short] {
+ override def getValueTypeInfo = BasicTypeInfo.SHORT_TYPE_INFO
+}
/**
* Built-in Int Min aggregate function
*/
-class IntMinAggFunction extends MinAggFunction[Int]
+class IntMinAggFunction extends MinAggFunction[Int] {
+ override def getValueTypeInfo = BasicTypeInfo.INT_TYPE_INFO
+}
/**
* Built-in Long Min aggregate function
*/
-class LongMinAggFunction extends MinAggFunction[Long]
+class LongMinAggFunction extends MinAggFunction[Long] {
+ override def getValueTypeInfo = BasicTypeInfo.LONG_TYPE_INFO
+}
/**
* Built-in Float Min aggregate function
*/
-class FloatMinAggFunction extends MinAggFunction[Float]
+class FloatMinAggFunction extends MinAggFunction[Float] {
+ override def getValueTypeInfo = BasicTypeInfo.FLOAT_TYPE_INFO
+}
/**
* Built-in Double Min aggregate function
*/
-class DoubleMinAggFunction extends MinAggFunction[Double]
+class DoubleMinAggFunction extends MinAggFunction[Double] {
+ override def getValueTypeInfo = BasicTypeInfo.DOUBLE_TYPE_INFO
+}
/**
* Built-in Boolean Min aggregate function
*/
-class BooleanMinAggFunction extends MinAggFunction[Boolean]
+class BooleanMinAggFunction extends MinAggFunction[Boolean] {
+ override def getValueTypeInfo = BasicTypeInfo.BOOLEAN_TYPE_INFO
+}
/**
* Built-in Big Decimal Min aggregate function
@@ -116,11 +142,13 @@ class DecimalMinAggFunction extends MinAggFunction[BigDecimal] {
override def accumulate(accumulator: Accumulator, value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[BigDecimal]
- val accum = accumulator.asInstanceOf[MinAccumulator]
+ val accum = accumulator.asInstanceOf[MinAccumulator[BigDecimal]]
if (!accum.f1 || accum.f0.compareTo(v) > 0) {
accum.f0 = v
accum.f1 = true
}
}
}
+
+ override def getValueTypeInfo = BasicTypeInfo.BIG_DEC_TYPE_INFO
}
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
index b04d8c0..78fdb8e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
@@ -19,9 +19,15 @@ package org.apache.flink.table.functions.aggfunctions
import java.math.BigDecimal
import java.util.{List => JList}
+
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
+import org.apache.flink.api.java.typeutils.TupleTypeInfo
import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+/** The initial accumulator for Sum aggregate function */
+class SumAccumulator[T] extends JTuple2[T, Boolean] with Accumulator
+
/**
* Base class for built-in Sum aggregate function
*
@@ -29,29 +35,26 @@ import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
*/
abstract class SumAggFunction[T: Numeric] extends AggregateFunction[T] {
- /** The initial accumulator for Sum aggregate function */
- class SumAccumulator extends JTuple2[T, Boolean] with Accumulator {
- f0 = numeric.zero //sum
- f1 = false
- }
-
private val numeric = implicitly[Numeric[T]]
override def createAccumulator(): Accumulator = {
- new SumAccumulator
+ val acc = new SumAccumulator[T]()
+ acc.f0 = numeric.zero //sum
+ acc.f1 = false
+ acc
}
override def accumulate(accumulator: Accumulator, value: Any): Unit = {
if (value != null) {
val v = value.asInstanceOf[T]
- val a = accumulator.asInstanceOf[SumAccumulator]
+ val a = accumulator.asInstanceOf[SumAccumulator[T]]
a.f0 = numeric.plus(v, a.f0)
a.f1 = true
}
}
override def getValue(accumulator: Accumulator): T = {
- val a = accumulator.asInstanceOf[SumAccumulator]
+ val a = accumulator.asInstanceOf[SumAccumulator[T]]
if (a.f1) {
a.f0
} else {
@@ -60,10 +63,10 @@ abstract class SumAggFunction[T: Numeric] extends AggregateFunction[T] {
}
override def merge(accumulators: JList[Accumulator]): Accumulator = {
- val ret = createAccumulator().asInstanceOf[SumAccumulator]
+ val ret = createAccumulator().asInstanceOf[SumAccumulator[T]]
var i: Int = 0
while (i < accumulators.size()) {
- val a = accumulators.get(i).asInstanceOf[SumAccumulator]
+ val a = accumulators.get(i).asInstanceOf[SumAccumulator[T]]
if (a.f1) {
ret.f0 = numeric.plus(ret.f0, a.f0)
ret.f1 = true
@@ -72,50 +75,70 @@ abstract class SumAggFunction[T: Numeric] extends AggregateFunction[T] {
}
ret
}
+
+ override def getAccumulatorType(): TypeInformation[_] = {
+ new TupleTypeInfo(
+ (new SumAccumulator).getClass,
+ getValueTypeInfo,
+ BasicTypeInfo.BOOLEAN_TYPE_INFO)
+ }
+
+ def getValueTypeInfo: TypeInformation[_]
}
/**
* Built-in Byte Sum aggregate function
*/
-class ByteSumAggFunction extends SumAggFunction[Byte]
+class ByteSumAggFunction extends SumAggFunction[Byte] {
+ override def getValueTypeInfo = BasicTypeInfo.BYTE_TYPE_INFO
+}
/**
* Built-in Short Sum aggregate function
*/
-class ShortSumAggFunction extends SumAggFunction[Short]
+class ShortSumAggFunction extends SumAggFunction[Short] {
+ override def getValueTypeInfo = BasicTypeInfo.SHORT_TYPE_INFO
+}
/**
* Built-in Int Sum aggregate function
*/
-class IntSumAggFunction extends SumAggFunction[Int]
+class IntSumAggFunction extends SumAggFunction[Int] {
+ override def getValueTypeInfo = BasicTypeInfo.INT_TYPE_INFO
+}
/**
* Built-in Long Sum aggregate function
*/
-class LongSumAggFunction extends SumAggFunction[Long]
+class LongSumAggFunction extends SumAggFunction[Long] {
+ override def getValueTypeInfo = BasicTypeInfo.LONG_TYPE_INFO
+}
/**
* Built-in Float Sum aggregate function
*/
-class FloatSumAggFunction extends SumAggFunction[Float]
+class FloatSumAggFunction extends SumAggFunction[Float] {
+ override def getValueTypeInfo = BasicTypeInfo.FLOAT_TYPE_INFO
+}
/**
* Built-in Double Sum aggregate function
*/
-class DoubleSumAggFunction extends SumAggFunction[Double]
+class DoubleSumAggFunction extends SumAggFunction[Double] {
+ override def getValueTypeInfo = BasicTypeInfo.DOUBLE_TYPE_INFO
+}
+/** The initial accumulator for Big Decimal Sum aggregate function */
+class DecimalSumAccumulator extends JTuple2[BigDecimal, Boolean] with Accumulator {
+ f0 = BigDecimal.ZERO
+ f1 = false
+}
/**
* Built-in Big Decimal Sum aggregate function
*/
class DecimalSumAggFunction extends AggregateFunction[BigDecimal] {
- /** The initial accumulator for Big Decimal Sum aggregate function */
- class DecimalSumAccumulator extends JTuple2[BigDecimal, Boolean] with Accumulator {
- f0 = BigDecimal.ZERO
- f1 = false
- }
-
override def createAccumulator(): Accumulator = {
new DecimalSumAccumulator
}
@@ -150,4 +173,11 @@ class DecimalSumAggFunction extends AggregateFunction[BigDecimal] {
}
ret
}
+
+ override def getAccumulatorType(): TypeInformation[_] = {
+ new TupleTypeInfo(
+ (new DecimalSumAccumulator).getClass,
+ BasicTypeInfo.BIG_DEC_TYPE_INFO,
+ BasicTypeInfo.BOOLEAN_TYPE_INFO)
+ }
}
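One design note on the Sum variant: unlike the other top-level accumulators, SumAccumulator[T] declares no field initializers, because the zero element depends on the Numeric[T] evidence that only the enclosing function holds. Initialization therefore moves into createAccumulator, as in the hunk above:

    // the function, not the accumulator class, supplies the type-specific zero
    override def createAccumulator(): Accumulator = {
      val acc = new SumAccumulator[T]()
      acc.f0 = numeric.zero // sum starts at the zero of Numeric[T]
      acc.f1 = false        // no value accumulated yet
      acc
    }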
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
index aec4fbb..21d28b5 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -119,9 +119,9 @@ object UserDefinedFunctionUtils {
}
/**
- * Check if a given method exits in the given function
+ * Check if a given method exists in the given function
*/
- def ifMethodExitInFunction(method: String, function: UserDefinedFunction): Boolean = {
+ def ifMethodExistInFunction(method: String, function: UserDefinedFunction): Boolean = {
val methods = function
.getClass
.getMethods
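The hunk is truncated here; a plausible minimal form of the renamed check, assuming it simply scans the function's public methods by name:

    def ifMethodExistInFunction(method: String, function: UserDefinedFunction): Boolean = {
      function
        .getClass
        .getMethods
        .exists(_.getName == method)
    }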
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
index 206e562..a88bcfe 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
@@ -44,9 +44,7 @@ class DataSetAggregate(
inputType: RelDataType,
grouping: Array[Int],
inGroupingSet: Boolean)
- extends SingleRel(cluster, traitSet, inputNode)
- with CommonAggregate
- with DataSetRel {
+ extends SingleRel(cluster, traitSet, inputNode) with CommonAggregate with DataSetRel {
override def deriveRowType(): RelDataType = rowRelDataType
@@ -63,11 +61,13 @@ class DataSetAggregate(
}
override def toString: String = {
- s"Aggregate(${ if (!grouping.isEmpty) {
- s"groupBy: (${groupingToString(inputType, grouping)}), "
- } else {
- ""
- }}select: (${aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil)}))"
+ s"Aggregate(${
+ if (!grouping.isEmpty) {
+ s"groupBy: (${groupingToString(inputType, grouping)}), "
+ } else {
+ ""
+ }
+ }select: (${aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil)}))"
}
override def explainTerms(pw: RelWriter): RelWriter = {
@@ -76,7 +76,7 @@ class DataSetAggregate(
.item("select", aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil))
}
- override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
+ override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
val child = this.getInput
val rowCnt = metadata.getRowCount(child)
@@ -87,8 +87,6 @@ class DataSetAggregate(
override def translateToPlan(tableEnv: BatchTableEnvironment): DataSet[Row] = {
- val config = tableEnv.getConfig
-
val groupingKeys = grouping.indices.toArray
val mapFunction = AggregateUtil.createPrepareMapFunction(
@@ -107,9 +105,7 @@ class DataSetAggregate(
val aggString = aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil)
val prepareOpName = s"prepare select: ($aggString)"
- val mappedInput = inputDS
- .map(mapFunction)
- .name(prepareOpName)
+ val mappedInput = inputDS.map(mapFunction).name(prepareOpName)
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
@@ -127,6 +123,7 @@ class DataSetAggregate(
else {
// global aggregation
val aggOpName = s"select:($aggString)"
+
mappedInput.asInstanceOf[DataSet[Row]]
.reduceGroup(groupReduceFunction)
.returns(rowTypeInfo)
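Schematically, translateToPlan keeps the two-phase DataSet shape: a prepare map projects grouping keys and aggregate inputs, then a grouped or global reduceGroup applies the aggregate functions. Only the global branch is visible in this hunk; the grouped branch is sketched from context:

    val mappedInput = inputDS.map(mapFunction).name(prepareOpName)

    val result =
      if (groupingKeys.nonEmpty) {
        mappedInput
          .groupBy(groupingKeys: _*)        // grouped aggregation, one group per key
          .reduceGroup(groupReduceFunction)
          .returns(rowTypeInfo)
          .name(aggOpName)
      } else {
        mappedInput
          .reduceGroup(groupReduceFunction) // global aggregation over all rows
          .returns(rowTypeInfo)
          .name(aggOpName)
      }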
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala
index 48de822..597be8c 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala
@@ -47,9 +47,7 @@ class DataSetWindowAggregate(
rowRelDataType: RelDataType,
inputType: RelDataType,
grouping: Array[Int])
- extends SingleRel(cluster, traitSet, inputNode)
- with CommonAggregate
- with DataSetRel {
+ extends SingleRel(cluster, traitSet, inputNode) with CommonAggregate with DataSetRel {
override def deriveRowType() = rowRelDataType
@@ -97,7 +95,7 @@ class DataSetWindowAggregate(
namedProperties))
}
- override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
+ override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
val child = this.getInput
val rowCnt = metadata.getRowCount(child)
val rowSize = this.estimateRowSize(child.getRowType)
@@ -136,8 +134,8 @@ class DataSetWindowAggregate(
private def createEventTimeTumblingWindowDataSet(
inputDS: DataSet[Row],
isTimeWindow: Boolean,
- isParserCaseSensitive: Boolean)
- : DataSet[Row] = {
+ isParserCaseSensitive: Boolean): DataSet[Row] = {
+
val mapFunction = createDataSetWindowPrepareMapFunction(
window,
namedAggregates,
@@ -191,8 +189,7 @@ class DataSetWindowAggregate(
private[this] def createEventTimeSessionWindowDataSet(
inputDS: DataSet[Row],
- isParserCaseSensitive: Boolean)
- : DataSet[Row] = {
+ isParserCaseSensitive: Boolean): DataSet[Row] = {
val groupingKeys = grouping.indices.toArray
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType)
@@ -207,10 +204,7 @@ class DataSetWindowAggregate(
inputType,
isParserCaseSensitive)
- val mappedInput =
- inputDS
- .map(mapFunction)
- .name(prepareOperatorName)
+ val mappedInput = inputDS.map(mapFunction).name(prepareOperatorName)
val mapReturnType = mapFunction.asInstanceOf[ResultTypeQueryable[Row]].getProducedType
@@ -218,7 +212,7 @@ class DataSetWindowAggregate(
val rowTimeFieldPos = mapReturnType.getArity - 1
// do incremental aggregation
- if (doAllSupportPartialAggregation(
+ if (doAllSupportPartialMerge(
namedAggregates.map(_.getKey),
inputType,
grouping.length)) {
@@ -267,10 +261,10 @@ class DataSetWindowAggregate(
namedProperties)
mappedInput.groupBy(groupingKeys: _*)
- .sortGroup(rowTimeFieldPos, Order.ASCENDING)
- .reduceGroup(groupReduceFunction)
- .returns(rowTypeInfo)
- .name(aggregateOperatorName)
+ .sortGroup(rowTimeFieldPos, Order.ASCENDING)
+ .reduceGroup(groupReduceFunction)
+ .returns(rowTypeInfo)
+ .name(aggregateOperatorName)
}
}
// non-grouping window
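The renamed doAllSupportPartialMerge gates the session-window strategy: incremental combine-then-reduce when every aggregate can merge partial accumulators, otherwise the single sorted reduceGroup shown just above. Schematically:

    if (doAllSupportPartialMerge(namedAggregates.map(_.getKey), inputType, grouping.length)) {
      // a combiner pre-aggregates rows into partial accumulators per session,
      // and the final reducer merges those accumulators
    } else {
      // fallback: sort each group on the row time and aggregate in one pass,
      // as in the sortGroup(...).reduceGroup(...) pipeline above
    }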
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
index c21d008..50f8281 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
@@ -50,9 +50,7 @@ class DataStreamAggregate(
rowRelDataType: RelDataType,
inputType: RelDataType,
grouping: Array[Int])
- extends SingleRel(cluster, traitSet, inputNode)
- with CommonAggregate
- with DataStreamRel {
+ extends SingleRel(cluster, traitSet, inputNode) with CommonAggregate with DataStreamRel {
override def deriveRowType(): RelDataType = rowRelDataType
@@ -91,12 +89,13 @@ class DataStreamAggregate(
super.explainTerms(pw)
.itemIf("groupBy", groupingToString(inputType, grouping), !grouping.isEmpty)
.item("window", window)
- .item("select", aggregationToString(
- inputType,
- grouping,
- getRowType,
- namedAggregates,
- namedProperties))
+ .item(
+ "select", aggregationToString(
+ inputType,
+ grouping,
+ getRowType,
+ namedAggregates,
+ namedProperties))
}
override def translateToPlan(tableEnv: StreamTableEnvironment): DataStream[Row] = {
@@ -113,116 +112,61 @@ class DataStreamAggregate(
namedAggregates,
namedProperties)
- val prepareOpName = s"prepare select: ($aggString)"
val keyedAggOpName = s"groupBy: (${groupingToString(inputType, grouping)}), " +
s"window: ($window), " +
s"select: ($aggString)"
val nonKeyedAggOpName = s"window: ($window), select: ($aggString)"
- val mapFunction = AggregateUtil.createPrepareMapFunction(
- namedAggregates,
- grouping,
- inputType)
-
- val mappedInput = inputDS.map(mapFunction).name(prepareOpName)
-
-
- // check whether all aggregates support partial aggregate
- if (AggregateUtil.doAllSupportPartialAggregation(
- namedAggregates.map(_.getKey),
- inputType,
- grouping.length)) {
- // do Incremental Aggregation
- val reduceFunction = AggregateUtil.createIncrementalAggregateReduceFunction(
- namedAggregates,
- inputType,
- getRowType,
- grouping)
- // grouped / keyed aggregation
- if (groupingKeys.length > 0) {
- val windowFunction = AggregateUtil.createWindowIncrementalAggregationFunction(
- window,
- namedAggregates,
- inputType,
- rowRelDataType,
- grouping,
- namedProperties)
-
- val keyedStream = mappedInput.keyBy(groupingKeys: _*)
- val windowedStream =
- createKeyedWindowedStream(window, keyedStream)
+ // grouped / keyed aggregation
+ if (groupingKeys.length > 0) {
+ val windowFunction = AggregateUtil.createAggregationGroupWindowFunction(
+ window,
+ groupingKeys.length,
+ namedAggregates.size,
+ rowRelDataType.getFieldCount,
+ namedProperties)
+
+ val keyedStream = inputDS.keyBy(groupingKeys: _*)
+ val windowedStream =
+ createKeyedWindowedStream(window, keyedStream)
.asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]]
- windowedStream
- .reduce(reduceFunction, windowFunction)
- .returns(rowTypeInfo)
- .name(keyedAggOpName)
- }
- // global / non-keyed aggregation
- else {
- val windowFunction = AggregateUtil.createAllWindowIncrementalAggregationFunction(
- window,
+ val (aggFunction, accumulatorRowType, aggResultRowType) =
+ AggregateUtil.createDataStreamAggregateFunction(
namedAggregates,
inputType,
rowRelDataType,
- grouping,
- namedProperties)
+ grouping)
- val windowedStream =
- createNonKeyedWindowedStream(window, mappedInput)
- .asInstanceOf[AllWindowedStream[Row, DataStreamWindow]]
-
- windowedStream
- .reduce(reduceFunction, windowFunction)
- .returns(rowTypeInfo)
- .name(nonKeyedAggOpName)
- }
+ windowedStream
+ .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo)
+ .name(keyedAggOpName)
}
+ // global / non-keyed aggregation
else {
- // do non-Incremental Aggregation
- // grouped / keyed aggregation
- if (groupingKeys.length > 0) {
-
- val windowFunction = AggregateUtil.createWindowAggregationFunction(
- window,
- namedAggregates,
- inputType,
- rowRelDataType,
- grouping,
- namedProperties)
+ val windowFunction = AggregateUtil.createAggregationAllWindowFunction(
+ window,
+ rowRelDataType.getFieldCount,
+ namedProperties)
- val keyedStream = mappedInput.keyBy(groupingKeys: _*)
- val windowedStream =
- createKeyedWindowedStream(window, keyedStream)
- .asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]]
+ val windowedStream =
+ createNonKeyedWindowedStream(window, inputDS)
+ .asInstanceOf[AllWindowedStream[Row, DataStreamWindow]]
- windowedStream
- .apply(windowFunction)
- .returns(rowTypeInfo)
- .name(keyedAggOpName)
- }
- // global / non-keyed aggregation
- else {
- val windowFunction = AggregateUtil.createAllWindowAggregationFunction(
- window,
+ val (aggFunction, accumulatorRowType, aggResultRowType) =
+ AggregateUtil.createDataStreamAggregateFunction(
namedAggregates,
inputType,
rowRelDataType,
- grouping,
- namedProperties)
-
- val windowedStream =
- createNonKeyedWindowedStream(window, mappedInput)
- .asInstanceOf[AllWindowedStream[Row, DataStreamWindow]]
+ grouping)
- windowedStream
- .apply(windowFunction)
- .returns(rowTypeInfo)
+ windowedStream
+ .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo)
.name(nonKeyedAggOpName)
- }
}
}
}
+
object DataStreamAggregate {
@@ -242,8 +186,8 @@ object DataStreamAggregate {
// TODO: EventTimeTumblingGroupWindow should sort the stream on event time
// before applying the windowing logic. Otherwise, this would be the same as a
// ProcessingTimeTumblingGroupWindow
- throw new UnsupportedOperationException("Event-time grouping windows on row intervals are " +
- "currently not supported.")
+ throw new UnsupportedOperationException(
+ "Event-time grouping windows on row intervals are currently not supported.")
case ProcessingTimeSlidingGroupWindow(_, size, slide) if isTimeInterval(size.resultType) =>
stream.window(SlidingProcessingTimeWindows.of(asTime(size), asTime(slide)))
@@ -258,8 +202,8 @@ object DataStreamAggregate {
// TODO: EventTimeTumblingGroupWindow should sort the stream on event time
// before applying the windowing logic. Otherwise, this would be the same as a
// ProcessingTimeTumblingGroupWindow
- throw new UnsupportedOperationException("Event-time grouping windows on row intervals are " +
- "currently not supported.")
+ throw new UnsupportedOperationException(
+ "Event-time grouping windows on row intervals are currently not supported.")
case ProcessingTimeSessionGroupWindow(_, gap: Expression) =>
stream.window(ProcessingTimeSessionWindows.withGap(asTime(gap)))
@@ -284,8 +228,8 @@ object DataStreamAggregate {
// TODO: EventTimeTumblingGroupWindow should sort the stream on event time
// before applying the windowing logic. Otherwise, this would be the same as a
// ProcessingTimeTumblingGroupWindow
- throw new UnsupportedOperationException("Event-time grouping windows on row intervals are " +
- "currently not supported.")
+ throw new UnsupportedOperationException(
+ "Event-time grouping windows on row intervals are currently not supported.")
case ProcessingTimeSlidingGroupWindow(_, size, slide) if isTimeInterval(size.resultType) =>
stream.windowAll(SlidingProcessingTimeWindows.of(asTime(size), asTime(slide)))
@@ -300,8 +244,8 @@ object DataStreamAggregate {
// TODO: EventTimeTumblingGroupWindow should sort the stream on event time
// before applying the windowing logic. Otherwise, this would be the same as a
// ProcessingTimeTumblingGroupWindow
- throw new UnsupportedOperationException("Event-time grouping windows on row intervals are " +
- "currently not supported.")
+ throw new UnsupportedOperationException(
+ "Event-time grouping windows on row intervals are currently not supported.")
case ProcessingTimeSessionGroupWindow(_, gap) =>
stream.windowAll(ProcessingTimeSessionWindows.withGap(asTime(gap)))
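The net shape of the rewritten streaming path: the prepare-map stage is gone, and both branches hand an AggregateAggFunction (new file below) to WindowedStream.aggregate, which keeps one accumulator row per window. The keyed branch, condensed from the diff above:

    val keyedStream = inputDS.keyBy(groupingKeys: _*)
    val windowedStream =
      createKeyedWindowedStream(window, keyedStream)
        .asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]]

    windowedStream
      .aggregate(
        aggFunction,        // AggregateAggFunction: createAccumulator/add/getResult/merge
        windowFunction,     // attaches window start/end properties to the result
        accumulatorRowType, // TypeInformation of the accumulator Row
        aggResultRowType,   // TypeInformation of the aggregate result Row
        rowTypeInfo)        // TypeInformation of the final output Row
      .name(keyedAggOpName)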
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAggFunction.scala
new file mode 100644
index 0000000..4d1579b
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAggFunction.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.aggregate
+
+import java.util.{ArrayList => JArrayList, List => JList}
+import org.apache.flink.api.common.functions.{AggregateFunction => DataStreamAggFunc}
+import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.types.Row
+
+/**
+ * Aggregate Function used for the aggregate operator in
+ * [[org.apache.flink.streaming.api.datastream.WindowedStream]]
+ *
+ * @param aggregates the list of all [[org.apache.flink.table.functions.AggregateFunction]]
+ * used for this aggregation
+ * @param aggFields the position (in the input Row) of the input value for each aggregate
+ */
+class AggregateAggFunction(
+ private val aggregates: Array[AggregateFunction[_]],
+ private val aggFields: Array[Int])
+ extends DataStreamAggFunc[Row, Row, Row] {
+
+ val aggsWithIdx: Array[(AggregateFunction[_], Int)] = aggregates.zipWithIndex
+
+ override def createAccumulator(): Row = {
+ val accumulatorRow: Row = new Row(aggregates.length)
+ aggsWithIdx.foreach { case (agg, i) =>
+ accumulatorRow.setField(i, agg.createAccumulator())
+ }
+ accumulatorRow
+ }
+
+ override def add(value: Row, accumulatorRow: Row) = {
+
+ aggsWithIdx.foreach { case (agg, i) =>
+ val acc = accumulatorRow.getField(i).asInstanceOf[Accumulator]
+ val v = value.getField(aggFields(i))
+ agg.accumulate(acc, v)
+ }
+ }
+
+ override def getResult(accumulatorRow: Row): Row = {
+ val output = new Row(aggFields.length)
+
+ aggsWithIdx.foreach { case (agg, i) =>
+ output.setField(i, agg.getValue(accumulatorRow.getField(i).asInstanceOf[Accumulator]))
+ }
+ output
+ }
+
+ override def merge(aAccumulatorRow: Row, bAccumulatorRow: Row): Row = {
+
+ aggsWithIdx.foreach { case (agg, i) =>
+ val aAcc = aAccumulatorRow.getField(i).asInstanceOf[Accumulator]
+ val bAcc = bAccumulatorRow.getField(i).asInstanceOf[Accumulator]
+ val accumulators: JList[Accumulator] = new JArrayList[Accumulator]()
+ accumulators.add(aAcc)
+ accumulators.add(bAcc)
+ aAccumulatorRow.setField(i, agg.merge(accumulators))
+ }
+ aAccumulatorRow
+ }
+}
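A usage sketch for the new wrapper, driving it by hand with two built-in aggregates (the field positions are illustrative):

    // count of field 0, sum of field 1
    val wrapper = new AggregateAggFunction(
      Array[AggregateFunction[_]](new CountAggFunction, new IntSumAggFunction),
      Array(0, 1))

    val acc = wrapper.createAccumulator() // a Row with one accumulator per aggregate

    val input = new Row(2)
    input.setField(0, "Hello")
    input.setField(1, 7)
    wrapper.add(input, acc)               // accumulates the row into both aggregates

    val out = wrapper.getResult(acc)      // Row(1, 7): count = 1, sum = 7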
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllTimeWindowFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllTimeWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllTimeWindowFunction.scala
deleted file mode 100644
index 89f3b41..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllTimeWindowFunction.scala
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.runtime.aggregate
-
-import java.lang.Iterable
-
-import org.apache.flink.api.common.functions.RichGroupReduceFunction
-import org.apache.flink.types.Row
-import org.apache.flink.configuration.Configuration
-import org.apache.flink.streaming.api.windowing.windows.TimeWindow
-import org.apache.flink.util.Collector
-
-class AggregateAllTimeWindowFunction(
- groupReduceFunction: RichGroupReduceFunction[Row, Row],
- windowStartPos: Option[Int],
- windowEndPos: Option[Int])
- extends AggregateAllWindowFunction[TimeWindow](groupReduceFunction) {
-
- private var collector: TimeWindowPropertyCollector = _
-
- override def open(parameters: Configuration): Unit = {
- collector = new TimeWindowPropertyCollector(windowStartPos, windowEndPos)
- super.open(parameters)
- }
-
- override def apply(window: TimeWindow, input: Iterable[Row], out: Collector[Row]): Unit = {
-
- // set collector and window
- collector.wrappedCollector = out
- collector.windowStart = window.getStart
- collector.windowEnd = window.getEnd
-
- // call wrapped reduce function with property collector
- super.apply(window, input, collector)
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllWindowFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllWindowFunction.scala
deleted file mode 100644
index 10a06da..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllWindowFunction.scala
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.runtime.aggregate
-
-import java.lang.Iterable
-
-import org.apache.flink.api.common.functions.RichGroupReduceFunction
-import org.apache.flink.types.Row
-import org.apache.flink.configuration.Configuration
-import org.apache.flink.streaming.api.functions.windowing.RichAllWindowFunction
-import org.apache.flink.streaming.api.windowing.windows.Window
-import org.apache.flink.util.Collector
-
-class AggregateAllWindowFunction[W <: Window](
- groupReduceFunction: RichGroupReduceFunction[Row, Row])
- extends RichAllWindowFunction[Row, Row, W] {
-
- override def open(parameters: Configuration): Unit = {
- groupReduceFunction.open(parameters)
- }
-
- override def apply(window: W, input: Iterable[Row], out: Collector[Row]): Unit = {
- groupReduceFunction.reduce(input, out)
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala
index 0033ff7..d936fbb 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala
@@ -22,34 +22,36 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.ResultTypeQueryable
import org.apache.flink.types.Row
import org.apache.flink.configuration.Configuration
+import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.util.Preconditions
class AggregateMapFunction[IN, OUT](
- private val aggregates: Array[Aggregate[_]],
+ private val aggregates: Array[AggregateFunction[_]],
private val aggFields: Array[Int],
private val groupingKeys: Array[Int],
@transient private val returnType: TypeInformation[OUT])
- extends RichMapFunction[IN, OUT]
- with ResultTypeQueryable[OUT] {
-
+ extends RichMapFunction[IN, OUT] with ResultTypeQueryable[OUT] {
+
private var output: Row = _
-
+
override def open(config: Configuration) {
Preconditions.checkNotNull(aggregates)
Preconditions.checkNotNull(aggFields)
Preconditions.checkArgument(aggregates.length == aggFields.length)
- val partialRowLength = groupingKeys.length +
- aggregates.map(_.intermediateDataType.length).sum
+ val partialRowLength = groupingKeys.length + aggregates.length
output = new Row(partialRowLength)
}
override def map(value: IN): OUT = {
-
+
val input = value.asInstanceOf[Row]
for (i <- aggregates.indices) {
- val fieldValue = input.getField(aggFields(i))
- aggregates(i).prepare(fieldValue, output)
+ val agg = aggregates(i)
+ val accumulator = agg.createAccumulator()
+ agg.accumulate(accumulator, input.getField(aggFields(i)))
+ output.setField(groupingKeys.length + i, accumulator)
}
+
for (i <- groupingKeys.indices) {
output.setField(i, input.getField(groupingKeys(i)))
}
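With the UDAGG interface, the map-side preparation no longer writes partial sums directly into the row; each record instead gets a fresh accumulator per aggregate, accumulated once and stored after the group keys. A rough, self-contained sketch of that shape (toy types, not the Flink API):

object PrepareMapSketch {
  final case class CountAcc(var count: Long)

  def prepare(input: Array[Any], groupingKeys: Array[Int], aggField: Int): Array[Any] = {
    val acc = CountAcc(0L)                       // agg.createAccumulator()
    if (input(aggField) != null) acc.count += 1  // agg.accumulate(acc, value)
    val output = new Array[Any](groupingKeys.length + 1)
    for (i <- groupingKeys.indices) {
      output(i) = input(groupingKeys(i))         // group keys first
    }
    output(groupingKeys.length) = acc            // accumulators after the keys
    output
  }

  def main(args: Array[String]): Unit = {
    println(prepare(Array[Any]("Hello", 7), Array(0), aggField = 1).mkString("|"))
    // Hello|CountAcc(1)
  }
}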
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala
index 5237ecf..06ac8fb 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala
@@ -19,61 +19,84 @@
package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
+import java.util.{ArrayList => JArrayList}
import org.apache.flink.api.common.functions.CombineFunction
+import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
import org.apache.flink.types.Row
-import scala.collection.JavaConversions._
-
/**
- * It wraps the aggregate logic inside of
- * [[org.apache.flink.api.java.operators.GroupReduceOperator]] and
- * [[org.apache.flink.api.java.operators.GroupCombineOperator]]
- *
- * @param aggregates The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
- * and output Row.
- * @param aggregateMapping The index mapping between aggregate function list and aggregated value
- * index in output Row.
- * @param groupingSetsMapping The index mapping of keys in grouping sets between intermediate
- * Row and output Row.
- */
+ * It wraps the aggregate logic inside of
+ * [[org.apache.flink.api.java.operators.GroupReduceOperator]] and
+ * [[org.apache.flink.api.java.operators.GroupCombineOperator]]
+ *
+ * @param aggregates The aggregate functions.
+ * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
+ * and output Row.
+ * @param aggregateMapping The index mapping between aggregate function list and aggregated
+ * value index in output Row.
+ * @param groupingSetsMapping The index mapping of keys in grouping sets between intermediate
+ * Row and output Row.
+ * @param finalRowArity the arity of the final resulting row
+ */
class AggregateReduceCombineFunction(
- private val aggregates: Array[Aggregate[_ <: Any]],
+ private val aggregates: Array[AggregateFunction[_ <: Any]],
private val groupKeysMapping: Array[(Int, Int)],
private val aggregateMapping: Array[(Int, Int)],
private val groupingSetsMapping: Array[(Int, Int)],
- private val intermediateRowArity: Int,
private val finalRowArity: Int)
extends AggregateReduceGroupFunction(
aggregates,
groupKeysMapping,
aggregateMapping,
groupingSetsMapping,
- intermediateRowArity,
- finalRowArity)
- with CombineFunction[Row, Row] {
+ finalRowArity) with CombineFunction[Row, Row] {
/**
- * For sub-grouped intermediate aggregate Rows, merge all of them into aggregate buffer,
- *
- * @param records Sub-grouped intermediate aggregate Rows iterator.
- * @return Combined intermediate aggregate Row.
- *
- */
+ * For sub-grouped intermediate aggregate Rows, merge all of them into the aggregate buffer.
+ *
+ * @param records Sub-grouped intermediate aggregate Rows iterator.
+ * @return Combined intermediate aggregate Row.
+ *
+ */
override def combine(records: Iterable[Row]): Row = {
- // Initiate intermediate aggregate value.
- aggregates.foreach(_.initiate(aggregateBuffer))
-
- // Merge intermediate aggregate value to buffer.
+ // merge intermediate aggregate value to buffer.
var last: Row = null
- records.foreach((record) => {
- aggregates.foreach(_.merge(record, aggregateBuffer))
+ accumulatorList.foreach(_.clear())
+
+ val iterator = records.iterator()
+
+ var count: Int = 0
+ while (iterator.hasNext) {
+ val record = iterator.next()
+ count += 1
+ // per each aggregator, collect its accumulators to a list
+ for (i <- aggregates.indices) {
+ accumulatorList(i).add(record.getField(groupKeysMapping.length + i)
+ .asInstanceOf[Accumulator])
+ }
+ // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one
+ // accumulator
+ if (count > maxMergeLen) {
+ count = 0
+ for (i <- aggregates.indices) {
+ val agg = aggregates(i)
+ val accumulator = agg.merge(accumulatorList(i))
+ accumulatorList(i).clear()
+ accumulatorList(i).add(accumulator)
+ }
+ }
last = record
- })
+ }
+
+ for (i <- aggregates.indices) {
+ val agg = aggregates(i)
+ aggregateBuffer.setField(groupKeysMapping.length + i, agg.merge(accumulatorList(i)))
+ }
- // Set group keys to aggregateBuffer.
+ // set group keys to aggregateBuffer.
for (i <- groupKeysMapping.indices) {
aggregateBuffer.setField(i, last.getField(i))
}
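Instead of merging record by record into a mutable buffer, the new combine collects accumulators per aggregate and collapses the batch via merge whenever it grows past maxMergeLen, which keeps memory bounded for large groups. A generic sketch of the pattern, with mergeAll standing in for AggregateFunction.merge (illustrative code, not part of the patch):

import java.util.{ArrayList => JArrayList}

object BoundedMergeSketch {
  def boundedMerge[A](accs: Iterator[A], maxMergeLen: Int)(mergeAll: JArrayList[A] => A): A = {
    val buffer = new JArrayList[A]()
    var count = 0
    while (accs.hasNext) {
      buffer.add(accs.next())
      count += 1
      if (count > maxMergeLen) {
        val merged = mergeAll(buffer) // collapse the batch into one accumulator
        buffer.clear()
        buffer.add(merged)
        count = 0
      }
    }
    mergeAll(buffer)                  // final merge of whatever is left
  }

  def main(args: Array[String]): Unit = {
    val sumAll = (xs: JArrayList[Long]) => {
      var s = 0L; val it = xs.iterator(); while (it.hasNext) s += it.next(); s
    }
    println(boundedMerge((1L to 100L).iterator, maxMergeLen = 16)(sumAll)) // 5050
  }
}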
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala
index c147629..23b5236 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala
@@ -18,43 +18,48 @@
package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
+import java.util.{ArrayList => JArrayList}
import org.apache.flink.api.common.functions.RichGroupReduceFunction
import org.apache.flink.configuration.Configuration
+import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
import org.apache.flink.types.Row
import org.apache.flink.util.{Collector, Preconditions}
-import scala.collection.JavaConversions._
-
/**
- * It wraps the aggregate logic inside of
- * [[org.apache.flink.api.java.operators.GroupReduceOperator]].
- *
- * @param aggregates The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
- * and output Row.
- * @param aggregateMapping The index mapping between aggregate function list and aggregated value
- * index in output Row.
- * @param groupingSetsMapping The index mapping of keys in grouping sets between intermediate
- * Row and output Row.
- */
+ * It wraps the aggregate logic inside of
+ * [[org.apache.flink.api.java.operators.GroupReduceOperator]].
+ *
+ * @param aggregates The aggregate functions.
+ * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
+ * and output Row.
+ * @param aggregateMapping The index mapping between aggregate function list and aggregated
+ * value index in output Row.
+ * @param groupingSetsMapping The index mapping of keys in grouping sets between intermediate
+ * Row and output Row.
+ * @param finalRowArity The arity of the final resulting row
+ */
class AggregateReduceGroupFunction(
- private val aggregates: Array[Aggregate[_ <: Any]],
+ private val aggregates: Array[AggregateFunction[_ <: Any]],
private val groupKeysMapping: Array[(Int, Int)],
private val aggregateMapping: Array[(Int, Int)],
private val groupingSetsMapping: Array[(Int, Int)],
- private val intermediateRowArity: Int,
private val finalRowArity: Int)
extends RichGroupReduceFunction[Row, Row] {
protected var aggregateBuffer: Row = _
private var output: Row = _
private var intermediateGroupKeys: Option[Array[Int]] = None
+ protected val maxMergeLen = 16
+ val accumulatorList = Array.fill(aggregates.length) {
+ new JArrayList[Accumulator]()
+ }
override def open(config: Configuration) {
Preconditions.checkNotNull(aggregates)
Preconditions.checkNotNull(groupKeysMapping)
- aggregateBuffer = new Row(intermediateRowArity)
+ aggregateBuffer = new Row(aggregates.length + groupKeysMapping.length)
output = new Row(finalRowArity)
if (!groupingSetsMapping.isEmpty) {
intermediateGroupKeys = Some(groupKeysMapping.map(_._1))
@@ -62,25 +67,44 @@ class AggregateReduceGroupFunction(
}
/**
- * For grouped intermediate aggregate Rows, merge all of them into aggregate buffer,
- * calculate aggregated values output by aggregate buffer, and set them into output
- * Row based on the mapping relation between intermediate aggregate data and output data.
- *
- * @param records Grouped intermediate aggregate Rows iterator.
- * @param out The collector to hand results to.
- *
- */
+ * For grouped intermediate aggregate Rows, merge all of them into the aggregate buffer,
+ * calculate the aggregated values from the buffer, and set them into the output
+ * Row based on the mapping between intermediate aggregate data and output data.
+ *
+ * @param records Grouped intermediate aggregate Rows iterator.
+ * @param out The collector to hand results to.
+ *
+ */
override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = {
- // Initiate intermediate aggregate value.
- aggregates.foreach(_.initiate(aggregateBuffer))
-
- // Merge intermediate aggregate value to buffer.
+ // merge intermediate aggregate value to buffer.
var last: Row = null
- records.foreach((record) => {
- aggregates.foreach(_.merge(record, aggregateBuffer))
+ accumulatorList.foreach(_.clear())
+
+ val iterator = records.iterator()
+
+ var count: Int = 0
+ while (iterator.hasNext) {
+ val record = iterator.next()
+ count += 1
+ // per each aggregator, collect its accumulators to a list
+ for (i <- aggregates.indices) {
+ accumulatorList(i).add(record.getField(groupKeysMapping.length + i)
+ .asInstanceOf[Accumulator])
+ }
+ // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one
+ // accumulator
+ if (count > maxMergeLen) {
+ count = 0
+ for (i <- aggregates.indices) {
+ val agg = aggregates(i)
+ val accumulator = agg.merge(accumulatorList(i))
+ accumulatorList(i).clear()
+ accumulatorList(i).add(accumulator)
+ }
+ }
last = record
- })
+ }
// Set group keys value to final output.
groupKeysMapping.foreach {
@@ -88,10 +112,14 @@ class AggregateReduceGroupFunction(
output.setField(after, last.getField(previous))
}
- // Evaluate final aggregate value and set to output.
+ // get final aggregate value and set to output.
aggregateMapping.foreach {
- case (after, previous) =>
- output.setField(after, aggregates(previous).evaluate(aggregateBuffer))
+ case (after, previous) => {
+ val agg = aggregates(previous)
+ val accumulator = agg.merge(accumulatorList(previous))
+ val result = agg.getValue(accumulator)
+ output.setField(after, result)
+ }
}
// Evaluate additional values of grouping sets
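The reduce ends by merging each aggregate's buffered accumulators and projecting the result with getValue. A toy aggregate that makes the accumulate/merge/getValue contract concrete (shapes mirror this patch; class and method bodies are illustrative, not a Flink class):

import java.util.{ArrayList => JArrayList}

object LongSumSketch {
  final case class SumAcc(var sum: Long)

  def createAccumulator(): SumAcc = SumAcc(0L)
  def accumulate(acc: SumAcc, value: Long): Unit = acc.sum += value
  def merge(accs: JArrayList[SumAcc]): SumAcc = {
    val out = createAccumulator()
    val it = accs.iterator()
    while (it.hasNext) out.sum += it.next().sum
    out
  }
  def getValue(acc: SumAcc): Long = acc.sum

  def main(args: Array[String]): Unit = {
    val buffered = new JArrayList[SumAcc]()
    buffered.add(SumAcc(3L)); buffered.add(SumAcc(4L))
    println(getValue(merge(buffered))) // 7: merge buffered accumulators, then project
  }
}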
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateTimeWindowFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateTimeWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateTimeWindowFunction.scala
deleted file mode 100644
index 8f96848..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateTimeWindowFunction.scala
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.runtime.aggregate
-
-import java.lang.Iterable
-
-import org.apache.flink.api.common.functions.RichGroupReduceFunction
-import org.apache.flink.api.java.tuple.Tuple
-import org.apache.flink.types.Row
-import org.apache.flink.configuration.Configuration
-import org.apache.flink.streaming.api.windowing.windows.TimeWindow
-import org.apache.flink.util.Collector
-
-class AggregateTimeWindowFunction(
- groupReduceFunction: RichGroupReduceFunction[Row, Row],
- windowStartPos: Option[Int],
- windowEndPos: Option[Int])
- extends AggregateWindowFunction[TimeWindow](groupReduceFunction) {
-
- private var collector: TimeWindowPropertyCollector = _
-
- override def open(parameters: Configuration): Unit = {
- collector = new TimeWindowPropertyCollector(windowStartPos, windowEndPos)
- super.open(parameters)
- }
-
- override def apply(
- key: Tuple,
- window: TimeWindow,
- input: Iterable[Row],
- out: Collector[Row]): Unit = {
-
- // set collector and window
- collector.wrappedCollector = out
- collector.windowStart = window.getStart
- collector.windowEnd = window.getEnd
-
- // call wrapped reduce function with property collector
- super.apply(key, window, input, collector)
- }
-}
[2/3] flink git commit: [FLINK-5768] [table] Refactor DataSet and
DataStream aggregations to use UDAGG interface.
Posted by fh...@apache.org.
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index cd473ee..40468ad 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -25,8 +25,8 @@ import org.apache.calcite.sql.{SqlAggFunction, SqlKind}
import org.apache.calcite.sql.`type`.SqlTypeName._
import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.calcite.sql.fun._
-import org.apache.flink.api.common.functions.{MapFunction, RichGroupReduceFunction,RichGroupCombineFunction}
-import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.common.functions.{InvalidTypesException, MapFunction, RichGroupCombineFunction, RichGroupReduceFunction, AggregateFunction => ApiAggregateFunction}
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeHint, TypeInformation}
import org.apache.flink.api.java.tuple.Tuple
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.table.calcite.{FlinkRelBuilder, FlinkTypeFactory}
@@ -37,6 +37,9 @@ import org.apache.flink.table.typeutils.TypeCheckUtils._
import org.apache.flink.streaming.api.functions.windowing.{AllWindowFunction, WindowFunction}
import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow}
import org.apache.flink.table.api.{TableException, Types}
+import org.apache.flink.table.functions.aggfunctions._
+import org.apache.flink.table.functions.{AggregateFunction => TableAggregateFunction}
+import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
import org.apache.flink.table.typeutils.{RowIntervalTypeInfo, TimeIntervalTypeInfo}
import org.apache.flink.types.Row
@@ -54,15 +57,15 @@ object AggregateUtil {
* organized by the following format:
*
* {{{
- * avg(x) aggOffsetInRow = 2 count(z) aggOffsetInRow = 5
- * | |
- * v v
- * +---------+---------+--------+--------+--------+--------+
- * |groupKey1|groupKey2| sum1 | count1 | sum2 | count2 |
- * +---------+---------+--------+--------+--------+--------+
+ * avg(x) count(z)
+ * | |
+ * v v
+ * +---------+---------+-----------------+------------------+------------------+
+ * |groupKey1|groupKey2| AvgAccumulator | SumAccumulator | CountAccumulator |
+ * +---------+---------+-----------------+------------------+------------------+
* ^
* |
- * sum(y) aggOffsetInRow = 4
+ * sum(y)
* }}}
*
*/
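The buffer documented above is typed as the group-key field types followed by one TypeInformation per accumulator, which is what createDataSetAggregateBufferDataType assembles later in this patch. A small sketch of composing such a RowTypeInfo; the accumulator types below are placeholders, the real ones come from the aggregate functions:

import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.typeutils.RowTypeInfo

object BufferTypeSketch {
  def main(args: Array[String]): Unit = {
    // group-key types first ...
    val groupKeyTypes: Seq[TypeInformation[_]] =
      Seq(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO)
    // ... then one TypeInformation per accumulator (placeholders here)
    val accumulatorTypes: Seq[TypeInformation[_]] =
      Seq(BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO)
    println(new RowTypeInfo((groupKeyTypes ++ accumulatorTypes): _*))
  }
}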
@@ -70,15 +73,15 @@ object AggregateUtil {
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
groupings: Array[Int],
inputType: RelDataType)
- : MapFunction[Row, Row] = {
+ : MapFunction[Row, Row] = {
- val (aggFieldIndexes,aggregates) = transformToAggregateFunctions(
+ val (aggFieldIndexes, aggregates) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
groupings.length)
val mapReturnType: RowTypeInfo =
- createAggregateBufferDataType(groupings, aggregates, inputType)
+ createDataSetAggregateBufferDataType(groupings, aggregates, inputType)
val mapFunction = new AggregateMapFunction[Row, Row](
aggregates,
@@ -89,7 +92,6 @@ object AggregateUtil {
mapFunction
}
-
/**
* Create a [[org.apache.flink.api.common.functions.MapFunction]] that prepares for aggregates.
* The output of the function contains the grouping keys, the timestamp, and the intermediate
@@ -98,17 +100,16 @@ object AggregateUtil {
* event-time, the timestamp is not aligned and used to sort.
*
* The output is stored in Row by the following format:
- *
* {{{
- * avg(x) aggOffsetInRow = 2 count(z) aggOffsetInRow = 5
- * | |
- * v v
- * +---------+---------+--------+--------+--------+--------+--------+
- * |groupKey1|groupKey2| sum1 | count1 | sum2 | count2 | rowtime|
- * +---------+---------+--------+--------+--------+--------+--------+
- * ^ ^
- * | |
- * sum(y) aggOffsetInRow = 4 rowtime to group or sort
+ * avg(x) count(z)
+ * | |
+ * v v
+ * +---------+---------+----------------+----------------+------------------+-------+
+ * |groupKey1|groupKey2| AvgAccumulator | SumAccumulator | CountAccumulator |rowtime|
+ * +---------+---------+----------------+----------------+------------------+-------+
+ * ^ ^
+ * | |
+ * sum(y) rowtime to group or sort
* }}}
*
* NOTE: this function is only used for time based window on batch tables.
@@ -119,7 +120,7 @@ object AggregateUtil {
groupings: Array[Int],
inputType: RelDataType,
isParserCaseSensitive: Boolean)
- : MapFunction[Row, Row] = {
+ : MapFunction[Row, Row] = {
val (aggFieldIndexes, aggregates) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
@@ -127,7 +128,11 @@ object AggregateUtil {
groupings.length)
val mapReturnType: RowTypeInfo =
- createAggregateBufferDataType(groupings, aggregates, inputType, Some(Array(Types.LONG)))
+ createDataSetAggregateBufferDataType(
+ groupings,
+ aggregates,
+ inputType,
+ Some(Array(Types.LONG)))
val (timeFieldPos, tumbleTimeWindowSize) = window match {
case EventTimeTumblingGroupWindow(_, time, size) =>
@@ -175,9 +180,6 @@ object AggregateUtil {
inputType,
groupings.length)._2
- val intermediateRowArity = groupings.length +
- aggregates.map(_.intermediateDataType.length).sum
-
// the mapping relation between field index of intermediate aggregate Row and output Row.
val groupingOffsetMapping = getGroupKeysMapping(inputType, outputType, groupings)
@@ -196,30 +198,26 @@ object AggregateUtil {
case EventTimeTumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) =>
// tumbling time window
val (startPos, endPos) = computeWindowStartEndPropertyPos(properties)
- if (aggregates.forall(_.supportPartial)) {
+ if (doAllSupportPartialMerge(aggregates)) {
// for incremental aggregations
new DataSetTumbleTimeWindowAggReduceCombineFunction(
- intermediateRowArity,
asLong(size),
startPos,
endPos,
aggregates,
groupingOffsetMapping,
aggOffsetMapping,
- intermediateRowArity + 1, // the additional field is used to store the time attribute
outputType.getFieldCount)
}
else {
// for non-incremental aggregations
new DataSetTumbleTimeWindowAggReduceGroupFunction(
- intermediateRowArity,
asLong(size),
startPos,
endPos,
aggregates,
groupingOffsetMapping,
aggOffsetMapping,
- intermediateRowArity + 1, // the additional field is used to store the time attribute
outputType.getFieldCount)
}
case EventTimeTumblingGroupWindow(_, _, size) =>
@@ -229,7 +227,6 @@ object AggregateUtil {
aggregates,
groupingOffsetMapping,
aggOffsetMapping,
- intermediateRowArity + 1,// the additional field is used to store the time attribute
outputType.getFieldCount)
case EventTimeSessionGroupWindow(_, _, gap) =>
@@ -238,8 +235,6 @@ object AggregateUtil {
aggregates,
groupingOffsetMapping,
aggOffsetMapping,
- // the additional two fields are used to store window-start and window-end attributes
- intermediateRowArity + 2,
outputType.getFieldCount,
startPos,
endPos,
@@ -255,19 +250,16 @@ object AggregateUtil {
* for aggregates.
* The function returns intermediate aggregate values of all aggregate function which are
* organized by the following format:
- *
* {{{
- * avg(x) aggOffsetInRow = 2 count(z) aggOffsetInRow = 5
- * | | windowEnd(max(rowtime)
- * | | |
- * v v v
- * +---------+---------+--------+--------+--------+--------+-----------+---------+
- * |groupKey1|groupKey2| sum1 | count1 | sum2 | count2 |windowStart|windowEnd|
- * +---------+---------+--------+--------+--------+--------+-----------+---------+
- * ^ ^
- * | |
- * sum(y) aggOffsetInRow = 4 windowStart(min(rowtime))
- *
+ * avg(x) windowEnd(max(rowtime))
+ * | |
+ * v v
+ * +---------+---------+----------------+----------------+-------------+-----------+
+ * |groupKey1|groupKey2| AvgAccumulator | SumAccumulator | windowStart | windowEnd |
+ * +---------+---------+----------------+----------------+-------------+-----------+
+ * ^ ^
+ * | |
+ * sum(y) windowStart(min(rowtime))
* }}}
*
*/
@@ -276,20 +268,17 @@ object AggregateUtil {
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
inputType: RelDataType,
groupings: Array[Int])
- : RichGroupCombineFunction[Row,Row] = {
+ : RichGroupCombineFunction[Row, Row] = {
val aggregates = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
groupings.length)._2
- val intermediateRowArity = groupings.length +
- aggregates.map(_.intermediateDataType.length).sum
-
window match {
case EventTimeSessionGroupWindow(_, _, gap) =>
val combineReturnType: RowTypeInfo =
- createAggregateBufferDataType(
+ createDataSetAggregateBufferDataType(
groupings,
aggregates,
inputType,
@@ -298,8 +287,6 @@ object AggregateUtil {
new DataSetSessionWindowAggregateCombineGroupFunction(
aggregates,
groupings,
- // the addition two fields are used to store window-start and window-end attributes
- intermediateRowArity + 2,
asLong(gap),
combineReturnType)
case _ =>
@@ -324,10 +311,10 @@ object AggregateUtil {
inGroupingSet: Boolean)
: RichGroupReduceFunction[Row, Row] = {
- val aggregates = transformToAggregateFunctions(
+ val (aggFieldIndex, aggregates) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
- groupings.length)._2
+ groupings.length)
val (groupingOffsetMapping, aggOffsetMapping) =
getGroupingOffsetAndAggOffsetMapping(
@@ -342,19 +329,13 @@ object AggregateUtil {
Array()
}
- val allPartialAggregate: Boolean = aggregates.forall(_.supportPartial)
-
- val intermediateRowArity = groupings.length +
- aggregates.map(_.intermediateDataType.length).sum
-
val groupReduceFunction =
- if (allPartialAggregate) {
+ if (doAllSupportPartialMerge(aggregates)) {
new AggregateReduceCombineFunction(
aggregates,
groupingOffsetMapping,
aggOffsetMapping,
groupingSetsMapping,
- intermediateRowArity,
outputType.getFieldCount)
}
else {
@@ -363,199 +344,109 @@ object AggregateUtil {
groupingOffsetMapping,
aggOffsetMapping,
groupingSetsMapping,
- intermediateRowArity,
outputType.getFieldCount)
}
groupReduceFunction
}
/**
- * Create a [[org.apache.flink.api.common.functions.ReduceFunction]] for incremental window
- * aggregation.
- *
- */
- private[flink] def createIncrementalAggregateReduceFunction(
- namedAggregates: Seq[CalcitePair[AggregateCall, String]],
- inputType: RelDataType,
- outputType: RelDataType,
- groupings: Array[Int])
- : IncrementalAggregateReduceFunction = {
-
- val aggregates = transformToAggregateFunctions(
- namedAggregates.map(_.getKey),inputType,groupings.length)._2
-
- val groupingOffsetMapping =
- getGroupingOffsetAndAggOffsetMapping(
- namedAggregates,
- inputType,
- outputType,
- groupings)._1
-
- val intermediateRowArity = groupings.length + aggregates.map(_.intermediateDataType.length).sum
- val reduceFunction = new IncrementalAggregateReduceFunction(
- aggregates,
- groupingOffsetMapping,
- intermediateRowArity)
- reduceFunction
- }
-
- /**
- * Create an [[AllWindowFunction]] to compute non-partitioned group window aggregates.
+ * Create an [[AllWindowFunction]] for non-partitioned window aggregates.
*/
- private[flink] def createAllWindowAggregationFunction(
+ private[flink] def createAggregationAllWindowFunction(
window: LogicalWindow,
- namedAggregates: Seq[CalcitePair[AggregateCall, String]],
- inputType: RelDataType,
- outputType: RelDataType,
- groupings: Array[Int],
+ finalRowArity: Int,
properties: Seq[NamedWindowProperty])
: AllWindowFunction[Row, Row, DataStreamWindow] = {
- val aggFunction =
- createAggregateGroupReduceFunction(
- namedAggregates,
- inputType,
- outputType,
- groupings,
- inGroupingSet = false)
-
if (isTimeWindow(window)) {
val (startPos, endPos) = computeWindowStartEndPropertyPos(properties)
- new AggregateAllTimeWindowFunction(aggFunction, startPos, endPos)
- .asInstanceOf[AllWindowFunction[Row, Row, DataStreamWindow]]
+ new IncrementalAggregateAllTimeWindowFunction(
+ startPos,
+ endPos,
+ finalRowArity)
+ .asInstanceOf[AllWindowFunction[Row, Row, DataStreamWindow]]
} else {
- new AggregateAllWindowFunction(aggFunction)
+ new IncrementalAggregateAllWindowFunction(
+ finalRowArity)
}
}
/**
- * Create a [[WindowFunction]] to compute partitioned group window aggregates.
- *
+ * Create a [[WindowFunction]] for group window aggregates.
*/
- private[flink] def createWindowAggregationFunction(
+ private[flink] def createAggregationGroupWindowFunction(
window: LogicalWindow,
- namedAggregates: Seq[CalcitePair[AggregateCall, String]],
- inputType: RelDataType,
- outputType: RelDataType,
- groupings: Array[Int],
+ numGroupingKeys: Int,
+ numAggregates: Int,
+ finalRowArity: Int,
properties: Seq[NamedWindowProperty])
: WindowFunction[Row, Row, Tuple, DataStreamWindow] = {
- val aggFunction =
- createAggregateGroupReduceFunction(
- namedAggregates,
- inputType,
- outputType,
- groupings,
- inGroupingSet = false)
-
if (isTimeWindow(window)) {
val (startPos, endPos) = computeWindowStartEndPropertyPos(properties)
- new AggregateTimeWindowFunction(aggFunction, startPos, endPos)
- .asInstanceOf[WindowFunction[Row, Row, Tuple, DataStreamWindow]]
+ new IncrementalAggregateTimeWindowFunction(
+ numGroupingKeys,
+ numAggregates,
+ startPos,
+ endPos,
+ finalRowArity)
+ .asInstanceOf[WindowFunction[Row, Row, Tuple, DataStreamWindow]]
} else {
- new AggregateWindowFunction(aggFunction)
+ new IncrementalAggregateWindowFunction(
+ numGroupingKeys,
+ numAggregates,
+ finalRowArity)
}
}
- /**
- * Create an [[AllWindowFunction]] to finalize incrementally pre-computed non-partitioned
- * window aggregates.
- */
- private[flink] def createAllWindowIncrementalAggregationFunction(
- window: LogicalWindow,
+ private[flink] def createDataStreamAggregateFunction(
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
inputType: RelDataType,
outputType: RelDataType,
- groupings: Array[Int],
- properties: Seq[NamedWindowProperty])
- : AllWindowFunction[Row, Row, DataStreamWindow] = {
+ groupKeysIndex: Array[Int])
+ : (ApiAggregateFunction[Row, Row, Row], RowTypeInfo, RowTypeInfo) = {
- val aggregates = transformToAggregateFunctions(
- namedAggregates.map(_.getKey),inputType,groupings.length)._2
+ val (aggFields, aggregates) =
+ transformToAggregateFunctions(namedAggregates.map(_.getKey), inputType, groupKeysIndex.length)
- val (groupingOffsetMapping, aggOffsetMapping) =
- getGroupingOffsetAndAggOffsetMapping(
- namedAggregates,
- inputType,
- outputType,
- groupings)
-
- val finalRowArity = outputType.getFieldCount
+ val aggregateMapping = getAggregateMapping(namedAggregates, outputType)
- if (isTimeWindow(window)) {
- val (startPos, endPos) = computeWindowStartEndPropertyPos(properties)
- new IncrementalAggregateAllTimeWindowFunction(
- aggregates,
- groupingOffsetMapping,
- aggOffsetMapping,
- finalRowArity,
- startPos,
- endPos)
- .asInstanceOf[AllWindowFunction[Row, Row, DataStreamWindow]]
- } else {
- new IncrementalAggregateAllWindowFunction(
- aggregates,
- groupingOffsetMapping,
- aggOffsetMapping,
- finalRowArity)
+ if (aggregateMapping.length != namedAggregates.length) {
+ throw new TableException(
+ "Could not find output field in input data type or aggregate functions.")
}
- }
- /**
- * Create a [[WindowFunction]] to finalize incrementally pre-computed window aggregates.
- */
- private[flink] def createWindowIncrementalAggregationFunction(
- window: LogicalWindow,
- namedAggregates: Seq[CalcitePair[AggregateCall, String]],
- inputType: RelDataType,
- outputType: RelDataType,
- groupings: Array[Int],
- properties: Seq[NamedWindowProperty])
- : WindowFunction[Row, Row, Tuple, DataStreamWindow] = {
+ val aggResultTypes = namedAggregates.map(a => FlinkTypeFactory.toTypeInfo(a.left.getType))
- val aggregates = transformToAggregateFunctions(
- namedAggregates.map(_.getKey),inputType,groupings.length)._2
+ val accumulatorRowType = createAccumulatorRowType(inputType, aggregates)
+ val aggResultRowType = new RowTypeInfo(aggResultTypes: _*)
+ val aggFunction = new AggregateAggFunction(aggregates, aggFields)
- val (groupingOffsetMapping, aggOffsetMapping) =
- getGroupingOffsetAndAggOffsetMapping(
- namedAggregates,
- inputType,
- outputType,
- groupings)
-
- val finalRowArity = outputType.getFieldCount
-
- if (isTimeWindow(window)) {
- val (startPos, endPos) = computeWindowStartEndPropertyPos(properties)
- new IncrementalAggregateTimeWindowFunction(
- aggregates,
- groupingOffsetMapping,
- aggOffsetMapping,
- finalRowArity,
- startPos,
- endPos)
- .asInstanceOf[WindowFunction[Row, Row, Tuple, DataStreamWindow]]
- } else {
- new IncrementalAggregateWindowFunction(
- aggregates,
- groupingOffsetMapping,
- aggOffsetMapping,
- finalRowArity)
- }
+ (aggFunction, accumulatorRowType, aggResultRowType)
}
/**
- * Return true if all aggregates can be partially computed. False otherwise.
+ * Return true if all aggregates can be partially merged. False otherwise.
*/
- private[flink] def doAllSupportPartialAggregation(
+ private[flink] def doAllSupportPartialMerge(
aggregateCalls: Seq[AggregateCall],
inputType: RelDataType,
groupKeysCount: Int): Boolean = {
- transformToAggregateFunctions(
+
+ val aggregateList = transformToAggregateFunctions(
aggregateCalls,
inputType,
- groupKeysCount)._2.forall(_.supportPartial)
+ groupKeysCount)._2
+
+ doAllSupportPartialMerge(aggregateList)
+ }
+
+ /**
+ * Return true if all aggregates can be partially merged. False otherwise.
+ */
+ private[flink] def doAllSupportPartialMerge(
+ aggregateList: Array[TableAggregateFunction[_ <: Any]]): Boolean = {
+ aggregateList.forall(ifMethodExistInFunction("merge", _))
}
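ifMethodExistInFunction comes from UserDefinedFunctionUtils; conceptually it asks whether the user's aggregate declares its own merge method, which is what makes partial (combine-side) aggregation possible. An illustrative approximation using plain Java reflection, not the actual implementation:

object MergeCheckSketch {
  def hasMethodNamed(name: String, function: AnyRef): Boolean =
    function.getClass.getMethods.exists(_.getName == name)

  def main(args: Array[String]): Unit = {
    class WithMerge { def merge(other: WithMerge): WithMerge = this }
    class WithoutMerge
    println(hasMethodNamed("merge", new WithMerge))    // true
    println(hasMethodNamed("merge", new WithoutMerge)) // false
  }
}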
/**
@@ -601,10 +492,10 @@ object AggregateUtil {
// map from field -> i$field or field -> i$field_0
val groupingFields = inputFields.map(inputFieldName => {
- val base = "i$" + inputFieldName
- var name = base
- var i = 0
- while (inputFields.contains(name)) {
+ val base = "i$" + inputFieldName
+ var name = base
+ var i = 0
+ while (inputFields.contains(name)) {
name = base + "_" + i // if i$XXX is already a field it will be suffixed by _NUMBER
i = i + 1
}
@@ -642,7 +533,7 @@ object AggregateUtil {
}
private[flink] def computeWindowStartEndPropertyPos(
- properties: Seq[NamedWindowProperty]): (Option[Int], Option[Int]) = {
+ properties: Seq[NamedWindowProperty]): (Option[Int], Option[Int]) = {
val propPos = properties.foldRight((None: Option[Int], None: Option[Int], 0)) {
(p, x) => p match {
@@ -663,15 +554,15 @@ object AggregateUtil {
}
private def transformToAggregateFunctions(
- aggregateCalls: Seq[AggregateCall],
- inputType: RelDataType,
- groupKeysCount: Int): (Array[Int], Array[Aggregate[_ <: Any]]) = {
+ aggregateCalls: Seq[AggregateCall],
+ inputType: RelDataType,
+ groupKeysCount: Int): (Array[Int], Array[TableAggregateFunction[_ <: Any]]) = {
// store the aggregate fields of each aggregate function, by the same order of aggregates.
val aggFieldIndexes = new Array[Int](aggregateCalls.size)
- val aggregates = new Array[Aggregate[_ <: Any]](aggregateCalls.size)
+ val aggregates = new Array[TableAggregateFunction[_ <: Any]](aggregateCalls.size)
- // set the start offset of aggregate buffer value to group keys' length,
+ // set the start offset of aggregate buffer value to group keys' length,
// as all the group keys would be moved to the start fields of intermediate
// aggregate data.
var aggOffset = groupKeysCount
@@ -696,19 +587,19 @@ object AggregateUtil {
case _: SqlSumAggFunction | _: SqlSumEmptyIsZeroAggFunction => {
aggregates(index) = sqlTypeName match {
case TINYINT =>
- new ByteSumAggregate
+ new ByteSumAggFunction
case SMALLINT =>
- new ShortSumAggregate
+ new ShortSumAggFunction
case INTEGER =>
- new IntSumAggregate
+ new IntSumAggFunction
case BIGINT =>
- new LongSumAggregate
+ new LongSumAggFunction
case FLOAT =>
- new FloatSumAggregate
+ new FloatSumAggFunction
case DOUBLE =>
- new DoubleSumAggregate
+ new DoubleSumAggFunction
case DECIMAL =>
- new DecimalSumAggregate
+ new DecimalSumAggFunction
case sqlType: SqlTypeName =>
throw new TableException("Sum aggregate does no support type:" + sqlType)
}
@@ -716,19 +607,19 @@ object AggregateUtil {
case _: SqlAvgAggFunction => {
aggregates(index) = sqlTypeName match {
case TINYINT =>
- new ByteAvgAggregate
+ new ByteAvgAggFunction
case SMALLINT =>
- new ShortAvgAggregate
+ new ShortAvgAggFunction
case INTEGER =>
- new IntAvgAggregate
+ new IntAvgAggFunction
case BIGINT =>
- new LongAvgAggregate
+ new LongAvgAggFunction
case FLOAT =>
- new FloatAvgAggregate
+ new FloatAvgAggFunction
case DOUBLE =>
- new DoubleAvgAggregate
+ new DoubleAvgAggFunction
case DECIMAL =>
- new DecimalAvgAggregate
+ new DecimalAvgAggFunction
case sqlType: SqlTypeName =>
throw new TableException("Avg aggregate does no support type:" + sqlType)
}
@@ -737,84 +628,114 @@ object AggregateUtil {
aggregates(index) = if (sqlMinMaxFunction.getKind == SqlKind.MIN) {
sqlTypeName match {
case TINYINT =>
- new ByteMinAggregate
+ new ByteMinAggFunction
case SMALLINT =>
- new ShortMinAggregate
+ new ShortMinAggFunction
case INTEGER =>
- new IntMinAggregate
+ new IntMinAggFunction
case BIGINT =>
- new LongMinAggregate
+ new LongMinAggFunction
case FLOAT =>
- new FloatMinAggregate
+ new FloatMinAggFunction
case DOUBLE =>
- new DoubleMinAggregate
+ new DoubleMinAggFunction
case DECIMAL =>
- new DecimalMinAggregate
+ new DecimalMinAggFunction
case BOOLEAN =>
- new BooleanMinAggregate
+ new BooleanMinAggFunction
case sqlType: SqlTypeName =>
throw new TableException("Min aggregate does no support type:" + sqlType)
}
} else {
sqlTypeName match {
case TINYINT =>
- new ByteMaxAggregate
+ new ByteMaxAggFunction
case SMALLINT =>
- new ShortMaxAggregate
+ new ShortMaxAggFunction
case INTEGER =>
- new IntMaxAggregate
+ new IntMaxAggFunction
case BIGINT =>
- new LongMaxAggregate
+ new LongMaxAggFunction
case FLOAT =>
- new FloatMaxAggregate
+ new FloatMaxAggFunction
case DOUBLE =>
- new DoubleMaxAggregate
+ new DoubleMaxAggFunction
case DECIMAL =>
- new DecimalMaxAggregate
+ new DecimalMaxAggFunction
case BOOLEAN =>
- new BooleanMaxAggregate
+ new BooleanMaxAggFunction
case sqlType: SqlTypeName =>
throw new TableException("Max aggregate does no support type:" + sqlType)
}
}
}
case _: SqlCountAggFunction =>
- aggregates(index) = new CountAggregate
+ aggregates(index) = new CountAggFunction
case unSupported: SqlAggFunction =>
throw new TableException("unsupported Function: " + unSupported.getName)
}
- setAggregateDataOffset(index)
- }
-
- // set the aggregate intermediate data start index in Row, and update current value.
- def setAggregateDataOffset(index: Int): Unit = {
- aggregates(index).setAggOffsetInRow(aggOffset)
- aggOffset += aggregates(index).intermediateDataType.length
}
(aggFieldIndexes, aggregates)
}
- private def createAggregateBufferDataType(
- groupings: Array[Int],
- aggregates: Array[Aggregate[_]],
- inputType: RelDataType,
- windowKeyTypes: Option[Array[TypeInformation[_]]] = None): RowTypeInfo = {
+ private def createAccumulatorType(
+ inputType: RelDataType,
+ aggregates: Array[TableAggregateFunction[_]]): Seq[TypeInformation[_]] = {
+
+ val aggTypes: Seq[TypeInformation[_]] =
+ aggregates.map {
+ agg =>
+ val accType = agg.getAccumulatorType()
+ if (accType != null) {
+ accType
+ } else {
+ val accumulator = agg.createAccumulator()
+ try {
+ TypeInformation.of(accumulator.getClass)
+ } catch {
+ case ite: InvalidTypesException =>
+ throw new TableException(
+ "Cannot infer type of accumulator. " +
+ "You can override AggregateFunction.getAccumulatorType() to specify the type.",
+ ite)
+ }
+ }
+ }
+
+ aggTypes
+ }
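The helper above resolves each accumulator's TypeInformation in two steps: prefer the type the UDAGG declares via getAccumulatorType(), otherwise infer it reflectively from a freshly created accumulator instance. A condensed, self-contained sketch of that fallback (the exception type and names below are illustrative):

import org.apache.flink.api.common.functions.InvalidTypesException
import org.apache.flink.api.common.typeinfo.TypeInformation

object AccTypeResolutionSketch {
  // declared: what the UDAGG's getAccumulatorType() returned (may be null);
  // accumulator: a fresh instance from createAccumulator().
  def resolveAccType(declared: TypeInformation[_], accumulator: AnyRef): TypeInformation[_] =
    if (declared != null) {
      declared
    } else {
      try {
        TypeInformation.of(accumulator.getClass) // reflective inference
      } catch {
        case ite: InvalidTypesException =>
          throw new RuntimeException(
            "Cannot infer accumulator type; override getAccumulatorType().", ite)
      }
    }

  def main(args: Array[String]): Unit = {
    println(resolveAccType(null, java.lang.Long.valueOf(7L)))
  }
}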
+
+ private def createDataSetAggregateBufferDataType(
+ groupings: Array[Int],
+ aggregates: Array[TableAggregateFunction[_]],
+ inputType: RelDataType,
+ windowKeyTypes: Option[Array[TypeInformation[_]]] = None): RowTypeInfo = {
// get the field data types of group keys.
- val groupingTypes: Seq[TypeInformation[_]] = groupings
- .map(inputType.getFieldList.get(_).getType)
- .map(FlinkTypeFactory.toTypeInfo)
+ val groupingTypes: Seq[TypeInformation[_]] =
+ groupings
+ .map(inputType.getFieldList.get(_).getType)
+ .map(FlinkTypeFactory.toTypeInfo)
// get all field data types of all intermediate aggregates
- val aggTypes: Seq[TypeInformation[_]] = aggregates.flatMap(_.intermediateDataType)
+ val aggTypes: Seq[TypeInformation[_]] = createAccumulatorType(inputType, aggregates)
// concat group key types, aggregation types, and window key types
- val allFieldTypes:Seq[TypeInformation[_]] = windowKeyTypes match {
+ val allFieldTypes: Seq[TypeInformation[_]] = windowKeyTypes match {
case None => groupingTypes ++: aggTypes
case _ => groupingTypes ++: aggTypes ++: windowKeyTypes.get
}
- new RowTypeInfo(allFieldTypes :_*)
+ new RowTypeInfo(allFieldTypes: _*)
+ }
+
+ private def createAccumulatorRowType(
+ inputType: RelDataType,
+ aggregates: Array[TableAggregateFunction[_]]): RowTypeInfo = {
+
+ val aggTypes: Seq[TypeInformation[_]] = createAccumulatorType(inputType, aggregates)
+
+ new RowTypeInfo(aggTypes: _*)
}
// Find the mapping between the index of aggregate list and aggregated value index in output Row.
@@ -826,12 +747,12 @@ object AggregateUtil {
// field index in output Row.
var aggOffsetMapping = ArrayBuffer[(Int, Int)]()
- outputType.getFieldList.zipWithIndex.foreach{
+ outputType.getFieldList.zipWithIndex.foreach {
case (outputFieldType, outputIndex) =>
namedAggregates.zipWithIndex.foreach {
case (namedAggCall, aggregateIndex) =>
if (namedAggCall.getValue.equals(outputFieldType.getName) &&
- namedAggCall.getKey.getType.equals(outputFieldType.getType)) {
+ namedAggCall.getKey.getType.equals(outputFieldType.getType)) {
aggOffsetMapping += ((outputIndex, aggregateIndex))
}
}
@@ -856,7 +777,7 @@ object AggregateUtil {
// find the field index in input data type.
case (inputFieldType, inputIndex) =>
if (outputFieldType.getName.equals(inputFieldType.getName) &&
- outputFieldType.getType.equals(inputFieldType.getType)) {
+ outputFieldType.getType.equals(inputFieldType.getType)) {
// as aggregated field in output data type would not have a matched field in
// input data, so if inputIndex is not -1, it must be a group key. Then we can
// find the field index in buffer data by the group keys index mapping between
@@ -906,6 +827,5 @@ object AggregateUtil {
case Literal(value: Long, RowIntervalTypeInfo.INTERVAL_ROWS) => value
case _ => throw new IllegalArgumentException()
}
-
}
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateWindowFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateWindowFunction.scala
deleted file mode 100644
index 5491b1d..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateWindowFunction.scala
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.runtime.aggregate
-
-import java.lang.Iterable
-
-import org.apache.flink.api.common.functions.RichGroupReduceFunction
-import org.apache.flink.api.java.tuple.Tuple
-import org.apache.flink.types.Row
-import org.apache.flink.configuration.Configuration
-import org.apache.flink.streaming.api.functions.windowing.RichWindowFunction
-import org.apache.flink.streaming.api.windowing.windows.Window
-import org.apache.flink.util.Collector
-
-class AggregateWindowFunction[W <: Window](groupReduceFunction: RichGroupReduceFunction[Row, Row])
- extends RichWindowFunction[Row, Row, Tuple, W] {
-
- override def open(parameters: Configuration): Unit = {
- groupReduceFunction.open(parameters)
- }
-
- override def apply(
- key: Tuple,
- window: W,
- input: Iterable[Row],
- out: Collector[Row]): Unit = {
-
- groupReduceFunction.reduce(input, out)
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateCombineGroupFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateCombineGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateCombineGroupFunction.scala
index f1d91a3..47fa0f1 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateCombineGroupFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateCombineGroupFunction.scala
@@ -18,40 +18,44 @@
package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
+import java.util.{ArrayList => JArrayList}
import org.apache.flink.api.common.functions.RichGroupCombineFunction
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.ResultTypeQueryable
import org.apache.flink.types.Row
import org.apache.flink.configuration.Configuration
+import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
import org.apache.flink.util.{Collector, Preconditions}
/**
* This wraps the aggregate logic inside of
* [[org.apache.flink.api.java.operators.GroupCombineOperator]].
*
- * @param aggregates The aggregate functions.
- * @param groupingKeys The indexes of the grouping fields.
- * @param intermediateRowArity The intermediate row field count.
- * @param gap Session time window gap.
+ * @param aggregates The aggregate functions.
+ * @param groupingKeys The indexes of the grouping fields.
+ * @param gap Session time window gap.
* @param intermediateRowType Intermediate row data type.
*/
class DataSetSessionWindowAggregateCombineGroupFunction(
- aggregates: Array[Aggregate[_ <: Any]],
+ aggregates: Array[AggregateFunction[_ <: Any]],
groupingKeys: Array[Int],
- intermediateRowArity: Int,
gap: Long,
@transient intermediateRowType: TypeInformation[Row])
- extends RichGroupCombineFunction[Row,Row] with ResultTypeQueryable[Row] {
+ extends RichGroupCombineFunction[Row, Row] with ResultTypeQueryable[Row] {
private var aggregateBuffer: Row = _
- private var rowTimeFieldPos = 0
+ private var accumStartPos: Int = groupingKeys.length
+ private var rowTimeFieldPos = accumStartPos + aggregates.length
+ private val maxMergeLen = 16
+ val accumulatorList = Array.fill(aggregates.length) {
+ new JArrayList[Accumulator]()
+ }
override def open(config: Configuration) {
Preconditions.checkNotNull(aggregates)
Preconditions.checkNotNull(groupingKeys)
- aggregateBuffer = new Row(intermediateRowArity)
- rowTimeFieldPos = intermediateRowArity - 2
+ aggregateBuffer = new Row(rowTimeFieldPos + 2)
}
/**
@@ -59,7 +63,7 @@ class DataSetSessionWindowAggregateCombineGroupFunction(
* (current row's rowtime - previous row's rowtime > gap), and then merge data (within a unified window)
* into an aggregate buffer.
*
- * @param records Sub-grouped intermediate aggregate Rows.
+ * @param records Sub-grouped intermediate aggregate Rows.
* @return Combined intermediate aggregate Row.
*
*/
@@ -68,10 +72,15 @@ class DataSetSessionWindowAggregateCombineGroupFunction(
var windowStart: java.lang.Long = null
var windowEnd: java.lang.Long = null
var currentRowTime: java.lang.Long = null
+ accumulatorList.foreach(_.clear())
val iterator = records.iterator()
+
+ var count: Int = 0
while (iterator.hasNext) {
val record = iterator.next()
+ count += 1
currentRowTime = record.getField(rowTimeFieldPos).asInstanceOf[Long]
// initial traversal or opening a new window
if (windowEnd == null || (windowEnd != null && (currentRowTime > windowEnd))) {
@@ -79,7 +88,11 @@ class DataSetSessionWindowAggregateCombineGroupFunction(
// calculate the current window and open a new window.
if (windowEnd != null) {
// emit the current window's merged data
- doCollect(out, windowStart, windowEnd)
+ doCollect(out, accumulatorList, windowStart, windowEnd)
+
+ // clear the accumulator list for all aggregate
+ accumulatorList.foreach(_.clear())
+ count = 0
} else {
// set group keys to aggregateBuffer.
for (i <- groupingKeys.indices) {
@@ -87,36 +100,59 @@ class DataSetSessionWindowAggregateCombineGroupFunction(
}
}
- // initiate intermediate aggregate value.
- aggregates.foreach(_.initiate(aggregateBuffer))
windowStart = record.getField(rowTimeFieldPos).asInstanceOf[Long]
}
- // merge intermediate aggregate value to the buffered value.
- aggregates.foreach(_.merge(record, aggregateBuffer))
+ // collect the accumulators for each aggregate
+ for (i <- aggregates.indices) {
+ accumulatorList(i).add(record.getField(accumStartPos + i).asInstanceOf[Accumulator])
+ }
+
+ // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one
+ // accumulator
+ if (count > maxMergeLen) {
+ count = 0
+ for (i <- aggregates.indices) {
+ val agg = aggregates(i)
+ val accumulator = agg.merge(accumulatorList(i))
+ accumulatorList(i).clear()
+ accumulatorList(i).add(accumulator)
+ }
+ }
// the current rowtime is the last rowtime of the next calculation.
windowEnd = currentRowTime + gap
}
// emit the merged data of the current window.
- doCollect(out, windowStart, windowEnd)
+ doCollect(out, accumulatorList, windowStart, windowEnd)
}
/**
* Emit the merged data of the current window.
- * @param windowStart the window's start attribute value is the min (rowtime)
- * of all rows in the window.
- * @param windowEnd the window's end property value is max (rowtime) + gap
- * for all rows in the window.
+ *
+ * @param out the collection of the aggregate results
+ * @param accumulatorList an array (indexed by aggregate index) of the accumulator lists for
+ * each aggregate
+ * @param windowStart the window's start attribute value is the min (rowtime)
+ * of all rows in the window.
+ * @param windowEnd the window's end property value is max (rowtime) + gap
+ * for all rows in the window.
*/
def doCollect(
- out: Collector[Row],
- windowStart: Long,
- windowEnd: Long): Unit = {
+ out: Collector[Row],
+ accumulatorList: Array[JArrayList[Accumulator]],
+ windowStart: Long,
+ windowEnd: Long): Unit = {
+
+ // merge the accumulators into one accumulator
+ for (i <- aggregates.indices) {
+ aggregateBuffer.setField(accumStartPos + i, aggregates(i).merge(accumulatorList(i)))
+ }
- // intermediate Row WindowStartPos is rowtime pos .
+ // intermediate Row WindowStartPos is rowtime pos.
aggregateBuffer.setField(rowTimeFieldPos, windowStart)
- // intermediate Row WindowEndPos is rowtime pos + 1 .
+
+ // intermediate Row WindowEndPos is rowtime pos + 1.
aggregateBuffer.setField(rowTimeFieldPos + 1, windowEnd)
out.collect(aggregateBuffer)
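The combine walks rowtime-sorted records, extends the current session's end to rowtime + gap, and closes the window as soon as a record's rowtime exceeds that end. A compact, Flink-free sketch of the gap logic (illustrative only):

object SessionGapSketch {
  // Split rowtime-sorted timestamps into session windows separated by gaps
  // larger than `gap`; mirrors the windowEnd = rowtime + gap bookkeeping above.
  def sessions(rowtimes: Seq[Long], gap: Long): Seq[(Long, Long)] = {
    var windowStart = -1L
    var windowEnd = -1L
    val out = Seq.newBuilder[(Long, Long)]
    for (t <- rowtimes) {
      if (windowEnd < 0 || t > windowEnd) {                    // first row or gap exceeded
        if (windowEnd >= 0) out += ((windowStart, windowEnd))  // emit the closed window
        windowStart = t
      }
      windowEnd = t + gap                                      // extend the current session
    }
    if (windowEnd >= 0) out += ((windowStart, windowEnd))
    out.result()
  }

  def main(args: Array[String]): Unit = {
    println(sessions(Seq(1L, 2L, 3L, 10L, 11L), gap = 3L))
    // List((1,6), (10,14))
  }
}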
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala
index 99d241d..1570671 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala
@@ -18,10 +18,12 @@
package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
+import java.util.{ArrayList => JArrayList}
import org.apache.flink.api.common.functions.RichGroupReduceFunction
import org.apache.flink.types.Row
import org.apache.flink.configuration.Configuration
+import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
import org.apache.flink.util.{Collector, Preconditions}
/**
@@ -30,49 +32,51 @@ import org.apache.flink.util.{Collector, Preconditions}
* on batch.
*
* Note:
- *
- * This can handle two input types (depending if input is combined or not):
+ *
+ * This can handle two input types (depending on whether the input is combined or not):
*
* 1. when partial aggregate is not supported, the input data structure of reduce is
- * |groupKey1|groupKey2|sum1|count1|sum2|count2|rowTime|
+ * |groupKey1|groupKey2|sum1|count1|sum2|count2|rowTime|
* 2. when partial aggregate is supported, the input data structure of reduce is
- * |groupKey1|groupKey2|sum1|count1|sum2|count2|windowStart|windowEnd|
+ * |groupKey1|groupKey2|sum1|count1|sum2|count2|windowStart|windowEnd|
*
- * @param aggregates The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
- * and output Row.
- * @param aggregateMapping The index mapping between aggregate function list and aggregated value
- * index in output Row.
- * @param intermediateRowArity The intermediate row field count.
- * @param finalRowArity The output row field count.
+ * @param aggregates The aggregate functions.
+ * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
+ * and output Row.
+ * @param aggregateMapping The index mapping between aggregate function list and
+ * aggregated value index in output Row.
+ * @param finalRowArity The output row field count.
* @param finalRowWindowStartPos The relative window-start field position.
- * @param finalRowWindowEndPos The relative window-end field position.
- * @param gap Session time window gap.
+ * @param finalRowWindowEndPos The relative window-end field position.
+ * @param gap Session time window gap.
*/
class DataSetSessionWindowAggregateReduceGroupFunction(
- aggregates: Array[Aggregate[_ <: Any]],
+ aggregates: Array[AggregateFunction[_ <: Any]],
groupKeysMapping: Array[(Int, Int)],
aggregateMapping: Array[(Int, Int)],
- intermediateRowArity: Int,
finalRowArity: Int,
finalRowWindowStartPos: Option[Int],
finalRowWindowEndPos: Option[Int],
- gap:Long,
+ gap: Long,
isInputCombined: Boolean)
extends RichGroupReduceFunction[Row, Row] {
private var aggregateBuffer: Row = _
- private var intermediateRowWindowStartPos = 0
- private var intermediateRowWindowEndPos = 0
private var output: Row = _
private var collector: TimeWindowPropertyCollector = _
+ private var accumStartPos: Int = groupKeysMapping.length
+ private var intermediateRowArity: Int = accumStartPos + aggregates.length + 2
+ private var intermediateRowWindowStartPos = intermediateRowArity - 2
+ private var intermediateRowWindowEndPos = intermediateRowArity - 1
+ private val maxMergeLen = 16
+ val accumulatorList = Array.fill(aggregates.length) {
+ new JArrayList[Accumulator]()
+ }
override def open(config: Configuration) {
Preconditions.checkNotNull(aggregates)
Preconditions.checkNotNull(groupKeysMapping)
aggregateBuffer = new Row(intermediateRowArity)
- intermediateRowWindowStartPos = intermediateRowArity - 2
- intermediateRowWindowEndPos = intermediateRowArity - 1
output = new Row(finalRowArity)
collector = new TimeWindowPropertyCollector(finalRowWindowStartPos, finalRowWindowEndPos)
}
@@ -91,11 +95,15 @@ class DataSetSessionWindowAggregateReduceGroupFunction(
var windowStart: java.lang.Long = null
var windowEnd: java.lang.Long = null
- var currentRowTime:java.lang.Long = null
+ var currentRowTime: java.lang.Long = null
+ accumulatorList.foreach(_.clear())
val iterator = records.iterator()
+
+ var count: Int = 0
while (iterator.hasNext) {
val record = iterator.next()
+ count += 1
currentRowTime = record.getField(intermediateRowWindowStartPos).asInstanceOf[Long]
// initial traversal or opening a new window
if (null == windowEnd ||
@@ -104,7 +112,11 @@ class DataSetSessionWindowAggregateReduceGroupFunction(
// calculate the current window and open a new window
if (null != windowEnd) {
// evaluate and emit the current window's result.
- doEvaluateAndCollect(out, windowStart, windowEnd)
+ doEvaluateAndCollect(out, accumulatorList, windowStart, windowEnd)
+
+ // clear the accumulator list for all aggregates
+ accumulatorList.foreach(_.clear())
+ count = 0
} else {
// set group keys value to final output.
groupKeysMapping.foreach {
@@ -112,13 +124,26 @@ class DataSetSessionWindowAggregateReduceGroupFunction(
output.setField(after, record.getField(previous))
}
}
- // initiate intermediate aggregate value.
- aggregates.foreach(_.initiate(aggregateBuffer))
+
windowStart = record.getField(intermediateRowWindowStartPos).asInstanceOf[Long]
}
- // merge intermediate aggregate value to the buffered value.
- aggregates.foreach(_.merge(record, aggregateBuffer))
+ // collect the accumulators for each aggregate
+ for (i <- aggregates.indices) {
+ accumulatorList(i).add(record.getField(accumStartPos + i).asInstanceOf[Accumulator])
+ }
+
+ // if the number of buffered accumulators exceeds maxMergeLen, merge them into one
+ // accumulator
+ if (count > maxMergeLen) {
+ count = 0
+ for (i <- aggregates.indices) {
+ val agg = aggregates(i)
+ val accumulator = agg.merge(accumulatorList(i))
+ accumulatorList(i).clear()
+ accumulatorList(i).add(accumulator)
+ }
+ }
windowEnd = if (isInputCombined) {
// partial aggregate is supported
@@ -129,25 +154,32 @@ class DataSetSessionWindowAggregateReduceGroupFunction(
}
}
// evaluate and emit the current window's result.
- doEvaluateAndCollect(out, windowStart, windowEnd)
+ doEvaluateAndCollect(out, accumulatorList, windowStart, windowEnd)
}
/**
* Evaluate and emit the data of the current window.
- * @param windowStart the window's start attribute value is the min (rowtime)
- * of all rows in the window.
- * @param windowEnd the window's end property value is max (rowtime) + gap
- * for all rows in the window.
+ *
+ * @param out the collector for the aggregate results
+ * @param accumulatorList an array (indexed by aggregate position) of accumulator lists,
+ * one list per aggregate function
+ * @param windowStart the window's start attribute value is the min (rowtime) of all rows
+ * in the window.
+ * @param windowEnd the window's end property value is max (rowtime) + gap for all rows
+ * in the window.
*/
def doEvaluateAndCollect(
- out: Collector[Row],
- windowStart: Long,
- windowEnd: Long): Unit = {
+ out: Collector[Row],
+ accumulatorList: Array[JArrayList[Accumulator]],
+ windowStart: Long,
+ windowEnd: Long): Unit = {
- // evaluate final aggregate value and set to output.
+ // merge the accumulators and then get the value for the final output
aggregateMapping.foreach {
case (after, previous) =>
- output.setField(after, aggregates(previous).evaluate(aggregateBuffer))
+ val agg = aggregates(previous)
+ val accum = agg.merge(accumulatorList(previous))
+ output.setField(after, agg.getValue(accum))
}
// adds TimeWindow properties to output then emit output
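
The session semantics implemented above, stripped of the aggregate plumbing: input rows
arrive sorted by rowtime, and a row whose rowtime exceeds the current window end
(last rowtime + gap) closes the session. A sketch over plain Long timestamps (the
sessionize name and Seq-based signature are illustrative only):

    // assign session windows over rowtime-sorted input; returns (start, end) pairs
    // where end = last rowtime in the session + gap
    def sessionize(rowtimes: Seq[Long], gap: Long): Seq[(Long, Long)] = {
      val windows = Seq.newBuilder[(Long, Long)]
      var windowStart: Option[Long] = None
      var windowEnd: Option[Long] = None
      for (t <- rowtimes) {
        if (windowEnd.exists(t > _)) {
          // gap exceeded: emit the current session and open a new one
          windows += ((windowStart.get, windowEnd.get))
          windowStart = None
        }
        if (windowStart.isEmpty) {
          windowStart = Some(t)
        }
        windowEnd = Some(t + gap)
      }
      for (s <- windowStart; e <- windowEnd) windows += ((s, e))
      windows.result()
    }

    // e.g. sessionize(Seq(1L, 2L, 3L, 4L, 6L, 8L, 16L), 5L) == Seq((1L, 13L), (16L, 21L))
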
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala
index 40dad17..b722330 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala
@@ -18,9 +18,11 @@
package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
+import java.util.{ArrayList => JArrayList}
import org.apache.flink.api.common.functions.RichGroupReduceFunction
import org.apache.flink.configuration.Configuration
+import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
import org.apache.flink.types.Row
import org.apache.flink.util.{Collector, Preconditions}
@@ -29,26 +31,30 @@ import org.apache.flink.util.{Collector, Preconditions}
* [[org.apache.flink.api.java.operators.GroupReduceOperator]].
* It is only used for tumbling count-window on batch.
*
- * @param windowSize Tumble count window size
- * @param aggregates The aggregate functions.
+ * @param windowSize Tumble count window size
+ * @param aggregates The aggregate functions.
* @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
* and output Row.
* @param aggregateMapping The index mapping between aggregate function list and aggregated value
* index in output Row.
- * @param intermediateRowArity The intermediate row field count
- * @param finalRowArity The output row field count
+ * @param finalRowArity The output row field count
*/
class DataSetTumbleCountWindowAggReduceGroupFunction(
private val windowSize: Long,
- private val aggregates: Array[Aggregate[_ <: Any]],
+ private val aggregates: Array[AggregateFunction[_ <: Any]],
private val groupKeysMapping: Array[(Int, Int)],
private val aggregateMapping: Array[(Int, Int)],
- private val intermediateRowArity: Int,
private val finalRowArity: Int)
extends RichGroupReduceFunction[Row, Row] {
private var aggregateBuffer: Row = _
private var output: Row = _
+ private val accumStartPos: Int = groupKeysMapping.length
+ private val intermediateRowArity: Int = accumStartPos + aggregates.length + 1
+ private val maxMergeLen = 16
+ val accumulatorList = Array.fill(aggregates.length) {
+ new JArrayList[Accumulator]()
+ }
override def open(config: Configuration) {
Preconditions.checkNotNull(aggregates)
@@ -60,30 +66,49 @@ class DataSetTumbleCountWindowAggReduceGroupFunction(
override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = {
var count: Long = 0
+ accumulatorList.foreach(_.clear())
val iterator = records.iterator()
while (iterator.hasNext) {
val record = iterator.next()
+
if (count == 0) {
- // initiate intermediate aggregate value.
- aggregates.foreach(_.initiate(aggregateBuffer))
+ // clear the accumulator list for all aggregates
+ accumulatorList.foreach(_.clear())
}
- // merge intermediate aggregate value to buffer.
- aggregates.foreach(_.merge(record, aggregateBuffer))
+ // collect the accumulators for each aggregate
+ for (i <- aggregates.indices) {
+ accumulatorList(i).add(record.getField(accumStartPos + i).asInstanceOf[Accumulator])
+ }
count += 1
+
+ // after every maxMergeLen accumulators, merge them into one
+ if (count % maxMergeLen == 0) {
+ for (i <- aggregates.indices) {
+ val agg = aggregates(i)
+ val accumulator = agg.merge(accumulatorList(i))
+ accumulatorList(i).clear()
+ accumulatorList(i).add(accumulator)
+ }
+ }
+
if (windowSize == count) {
// set group keys value to final output.
groupKeysMapping.foreach {
case (after, previous) =>
output.setField(after, record.getField(previous))
}
- // evaluate final aggregate value and set to output.
+
+ // merge the accumulators and then get the value for the final output
aggregateMapping.foreach {
case (after, previous) =>
- output.setField(after, aggregates(previous).evaluate(aggregateBuffer))
+ val agg = aggregates(previous)
+ val accumulator = agg.merge(accumulatorList(previous))
+ output.setField(after, agg.getValue(accumulator))
}
+
// emit the output
out.collect(output)
count = 0
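
The count-window contract this reduce function implements, in isolation: one result per
windowSize consecutive records of a group, with any incomplete trailing window dropped.
A sketch over plain values (tumbleCountSum is an illustrative name, with a sum standing
in for the merged aggregates):

    // emit one sum per complete window of `windowSize` input values;
    // a trailing window with fewer than `windowSize` values is not emitted,
    // matching the reduce function above, which only collects when
    // windowSize == count
    def tumbleCountSum(values: Iterator[Long], windowSize: Int): Iterator[Long] =
      values.grouped(windowSize)
        .filter(_.size == windowSize)
        .map(_.sum)

    // tumbleCountSum(Iterator(1L, 2L, 3L, 4L, 5L), 2) yields 3, 7 (the 5 is dropped)
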
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala
index a72c9ca..d507a58 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala
@@ -18,8 +18,10 @@
package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
+import java.util.{ArrayList => JArrayList}
import org.apache.flink.api.common.functions.CombineFunction
+import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
import org.apache.flink.types.Row
/**
@@ -28,68 +30,86 @@ import org.apache.flink.types.Row
* [[org.apache.flink.api.java.operators.GroupCombineOperator]].
* It is used for tumbling time-window on batch.
*
- * @param rowtimePos The rowtime field index in input row
- * @param windowSize Tumbling time window size
- * @param windowStartPos The relative window-start field position to the last field of output row
- * @param windowEndPos The relative window-end field position to the last field of output row
- * @param aggregates The aggregate functions.
+ * @param windowSize Tumbling time window size
+ * @param windowStartPos The relative window-start field position to the last field of output row
+ * @param windowEndPos The relative window-end field position to the last field of output row
+ * @param aggregates The aggregate functions.
* @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
* and output Row.
* @param aggregateMapping The index mapping between aggregate function list and aggregated value
* index in output Row.
- * @param intermediateRowArity The intermediate row field count
- * @param finalRowArity The output row field count
+ * @param finalRowArity The output row field count
*/
class DataSetTumbleTimeWindowAggReduceCombineFunction(
- rowtimePos: Int,
windowSize: Long,
windowStartPos: Option[Int],
windowEndPos: Option[Int],
- aggregates: Array[Aggregate[_ <: Any]],
+ aggregates: Array[AggregateFunction[_ <: Any]],
groupKeysMapping: Array[(Int, Int)],
aggregateMapping: Array[(Int, Int)],
- intermediateRowArity: Int,
finalRowArity: Int)
extends DataSetTumbleTimeWindowAggReduceGroupFunction(
- rowtimePos,
windowSize,
windowStartPos,
windowEndPos,
aggregates,
groupKeysMapping,
aggregateMapping,
- intermediateRowArity,
finalRowArity)
- with CombineFunction[Row, Row] {
+ with CombineFunction[Row, Row] {
/**
* For sub-grouped intermediate aggregate Rows, merge all of them into the aggregate buffer.
*
- * @param records Sub-grouped intermediate aggregate Rows iterator.
+ * @param records Sub-grouped intermediate aggregate Rows iterator.
* @return Combined intermediate aggregate Row.
*
*/
override def combine(records: Iterable[Row]): Row = {
- // initiate intermediate aggregate value.
- aggregates.foreach(_.initiate(aggregateBuffer))
-
- // merge intermediate aggregate value to buffer.
var last: Row = null
+ accumulatorList.foreach(_.clear())
val iterator = records.iterator()
+
+ var count: Int = 0
while (iterator.hasNext) {
val record = iterator.next()
- aggregates.foreach(_.merge(record, aggregateBuffer))
+ count += 1
+ // for each aggregator, collect its accumulator into a list
+ for (i <- aggregates.indices) {
+ accumulatorList(i).add(record.getField(groupKeysMapping.length + i)
+ .asInstanceOf[Accumulator])
+ }
+ // if the number of buffered accumulators exceeds maxMergeLen, merge them into one
+ // accumulator
+ if (count > maxMergeLen) {
+ count = 0
+ for (i <- aggregates.indices) {
+ val agg = aggregates(i)
+ val accumulator = agg.merge(accumulatorList(i))
+ accumulatorList(i).clear()
+ accumulatorList(i).add(accumulator)
+ }
+ }
last = record
}
+ // for each aggregator, merge the list of accumulators into one and save the result to the
+ // intermediate aggregate buffer
+ for (i <- aggregates.indices) {
+ val agg = aggregates(i)
+ aggregateBuffer.setField(groupKeysMapping.length + i, agg.merge(accumulatorList(i)))
+ }
+
// set group keys to aggregateBuffer.
for (i <- groupKeysMapping.indices) {
aggregateBuffer.setField(i, last.getField(i))
}
// set the rowtime attribute
+ val rowtimePos = groupKeysMapping.length + aggregates.length
+
aggregateBuffer.setField(rowtimePos, last.getField(rowtimePos))
aggregateBuffer
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala
index a4c03b9..63d2aeb 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala
@@ -18,10 +18,11 @@
package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
+import java.util.{ArrayList => JArrayList}
import org.apache.flink.api.common.functions.RichGroupReduceFunction
import org.apache.flink.configuration.Configuration
-import org.apache.flink.streaming.api.windowing.windows.TimeWindow
+import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
import org.apache.flink.types.Row
import org.apache.flink.util.{Collector, Preconditions}
@@ -30,33 +31,36 @@ import org.apache.flink.util.{Collector, Preconditions}
* [[org.apache.flink.api.java.operators.GroupReduceOperator]]. It is used for tumbling time-window
* on batch.
*
- * @param rowtimePos The rowtime field index in input row
- * @param windowSize Tumbling time window size
- * @param windowStartPos The relative window-start field position to the last field of output row
- * @param windowEndPos The relative window-end field position to the last field of output row
- * @param aggregates The aggregate functions.
+ * @param windowSize Tumbling time window size
+ * @param windowStartPos The relative window-start field position to the last field of output row
+ * @param windowEndPos The relative window-end field position to the last field of output row
+ * @param aggregates The aggregate functions.
* @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
* and output Row.
* @param aggregateMapping The index mapping between aggregate function list and aggregated value
* index in output Row.
- * @param intermediateRowArity The intermediate row field count
- * @param finalRowArity The output row field count
+ * @param finalRowArity The output row field count
*/
class DataSetTumbleTimeWindowAggReduceGroupFunction(
- rowtimePos: Int,
windowSize: Long,
windowStartPos: Option[Int],
windowEndPos: Option[Int],
- aggregates: Array[Aggregate[_ <: Any]],
+ aggregates: Array[AggregateFunction[_ <: Any]],
groupKeysMapping: Array[(Int, Int)],
aggregateMapping: Array[(Int, Int)],
- intermediateRowArity: Int,
finalRowArity: Int)
extends RichGroupReduceFunction[Row, Row] {
private var collector: TimeWindowPropertyCollector = _
protected var aggregateBuffer: Row = _
private var output: Row = _
+ private val accumStartPos: Int = groupKeysMapping.length
+ private val rowtimePos: Int = accumStartPos + aggregates.length
+ private val intermediateRowArity: Int = rowtimePos + 1
+ protected val maxMergeLen = 16
+ val accumulatorList = Array.fill(aggregates.length) {
+ new JArrayList[Accumulator]()
+ }
override def open(config: Configuration) {
Preconditions.checkNotNull(aggregates)
@@ -68,16 +72,30 @@ class DataSetTumbleTimeWindowAggReduceGroupFunction(
override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = {
- // initiate intermediate aggregate value.
- aggregates.foreach(_.initiate(aggregateBuffer))
-
- // merge intermediate aggregate value to buffer.
var last: Row = null
+ accumulatorList.foreach(_.clear())
val iterator = records.iterator()
+
+ var count: Int = 0
while (iterator.hasNext) {
val record = iterator.next()
- aggregates.foreach(_.merge(record, aggregateBuffer))
+ count += 1
+ // for each aggregator, collect its accumulator into a list
+ for (i <- aggregates.indices) {
+ accumulatorList(i).add(record.getField(accumStartPos + i).asInstanceOf[Accumulator])
+ }
+ // if the number of buffered accumulators exceeds maxMergeLen, merge them into one
+ // accumulator
+ if (count > maxMergeLen) {
+ count = 0
+ for (i <- aggregates.indices) {
+ val agg = aggregates(i)
+ val accumulator = agg.merge(accumulatorList(i))
+ accumulatorList(i).clear()
+ accumulatorList(i).add(accumulator)
+ }
+ }
last = record
}
@@ -87,10 +105,14 @@ class DataSetTumbleTimeWindowAggReduceGroupFunction(
output.setField(after, last.getField(previous))
}
- // evaluate final aggregate value and set to output.
+ // get the final aggregate value and set it to the output.
aggregateMapping.foreach {
- case (after, previous) =>
- output.setField(after, aggregates(previous).evaluate(aggregateBuffer))
+ case (after, previous) => {
+ val agg = aggregates(previous)
+ val accumulator = agg.merge(accumulatorList(previous))
+ val result = agg.getValue(accumulator)
+ output.setField(after, result)
+ }
}
// get window start timestamp
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggregateMapFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggregateMapFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggregateMapFunction.scala
index 5c3d374..68088fc 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggregateMapFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggregateMapFunction.scala
@@ -24,6 +24,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.ResultTypeQueryable
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.windowing.windows.TimeWindow
+import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.types.Row
import org.apache.flink.util.Preconditions
@@ -34,13 +35,13 @@ import org.apache.flink.util.Preconditions
* append an (aligned) rowtime field to the end of the output row.
*/
class DataSetWindowAggregateMapFunction(
- private val aggregates: Array[Aggregate[_]],
+ private val aggregates: Array[AggregateFunction[_]],
private val aggFields: Array[Int],
private val groupingKeys: Array[Int],
- private val timeFieldPos: Int, // time field position in input row
+ private val timeFieldPos: Int, // time field position in input row
private val tumbleTimeWindowSize: Option[Long],
@transient private val returnType: TypeInformation[Row])
- extends RichMapFunction[Row, Row] with ResultTypeQueryable[Row] {
+ extends RichMapFunction[Row, Row] with ResultTypeQueryable[Row] {
private var output: Row = _
// rowtime index in the buffer output row
@@ -51,18 +52,22 @@ class DataSetWindowAggregateMapFunction(
Preconditions.checkNotNull(aggFields)
Preconditions.checkArgument(aggregates.length == aggFields.length)
// add one more arity to store rowtime
- val partialRowLength = groupingKeys.length +
- aggregates.map(_.intermediateDataType.length).sum + 1
+ val partialRowLength = groupingKeys.length + aggregates.length + 1
// set rowtime to the last field of the output row
rowtimeIndex = partialRowLength - 1
output = new Row(partialRowLength)
}
override def map(input: Row): Row = {
+
for (i <- aggregates.indices) {
+ val agg = aggregates(i)
val fieldValue = input.getField(aggFields(i))
- aggregates(i).prepare(fieldValue, output)
+ val accumulator = agg.createAccumulator()
+ agg.accumulate(accumulator, fieldValue)
+ output.setField(groupingKeys.length + i, accumulator)
}
+
for (i <- groupingKeys.indices) {
output.setField(i, input.getField(groupingKeys(i)))
}
@@ -103,3 +108,4 @@ class DataSetWindowAggregateMapFunction(
returnType
}
}
+
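
Taken together with the reduce-side changes, this map function establishes the
accumulator lifecycle: the map side calls createAccumulator() and accumulate() once per
input row and stores the accumulator in the partial row; the reduce side later merges
the collected accumulators and calls getValue(). A self-contained sketch with a
hypothetical count aggregate (CountAcc and CountAgg are illustrative, not Flink classes):

    import java.util.{ArrayList => JArrayList}

    class CountAcc(var count: Long)

    object CountAgg {
      def createAccumulator(): CountAcc = new CountAcc(0L)
      // count non-null values, the usual COUNT(field) semantics
      def accumulate(acc: CountAcc, value: Any): Unit =
        if (value != null) acc.count += 1
      def merge(accs: JArrayList[CountAcc]): CountAcc = {
        val res = createAccumulator()
        for (i <- 0 until accs.size()) res.count += accs.get(i).count
        res
      }
      def getValue(acc: CountAcc): Long = acc.count
    }

    object LifecycleDemo extends App {
      // map phase: one fresh accumulator per input value
      val partials = new JArrayList[CountAcc]()
      for (v <- Seq("a", "b", null, "c")) {
        val acc = CountAgg.createAccumulator()
        CountAgg.accumulate(acc, v)
        partials.add(acc)
      }
      // reduce phase: merge the buffered accumulators, then read the value
      println(CountAgg.getValue(CountAgg.merge(partials))) // prints 3
    }
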
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala
index ed49dc3..51c614d 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala
@@ -23,28 +23,20 @@ import org.apache.flink.types.Row
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.windowing.windows.TimeWindow
import org.apache.flink.util.Collector
+
/**
*
* Computes the final aggregate value from incrementally computed aggregates.
*
- * @param aggregates The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
- * and output Row.
- * @param aggregateMapping The index mapping between aggregate function list and aggregated value
- * index in output Row.
+ * @param windowStartPos the relative window-start field position to the last field of output row
+ * @param windowEndPos the relative window-end field position to the last field of output row
* @param finalRowArity The arity of the final output row.
*/
class IncrementalAggregateAllTimeWindowFunction(
- private val aggregates: Array[Aggregate[_ <: Any]],
- private val groupKeysMapping: Array[(Int, Int)],
- private val aggregateMapping: Array[(Int, Int)],
- private val finalRowArity: Int,
private val windowStartPos: Option[Int],
- private val windowEndPos: Option[Int])
+ private val windowEndPos: Option[Int],
+ private val finalRowArity: Int)
extends IncrementalAggregateAllWindowFunction[TimeWindow](
- aggregates,
- groupKeysMapping,
- aggregateMapping,
finalRowArity) {
private var collector: TimeWindowPropertyCollector = _
@@ -55,9 +47,9 @@ class IncrementalAggregateAllTimeWindowFunction(
}
override def apply(
- window: TimeWindow,
- records: Iterable[Row],
- out: Collector[Row]): Unit = {
+ window: TimeWindow,
+ records: Iterable[Row],
+ out: Collector[Row]): Unit = {
// set collector and window
collector.wrappedCollector = out
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllWindowFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllWindowFunction.scala
index 3c41a62..00aba1f 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllWindowFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllWindowFunction.scala
@@ -28,25 +28,15 @@ import org.apache.flink.util.{Collector, Preconditions}
/**
* Computes the final aggregate value from incrementally computed aggregates.
*
- * @param aggregates The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
- * and output Row.
- * @param aggregateMapping The index mapping between aggregate function list and aggregated value
- * index in output Row.
- * @param finalRowArity The arity of the final output row.
+ * @param finalRowArity The arity of the final output row.
*/
class IncrementalAggregateAllWindowFunction[W <: Window](
- private val aggregates: Array[Aggregate[_ <: Any]],
- private val groupKeysMapping: Array[(Int, Int)],
- private val aggregateMapping: Array[(Int, Int)],
private val finalRowArity: Int)
extends RichAllWindowFunction[Row, Row, W] {
private var output: Row = _
override def open(parameters: Configuration): Unit = {
- Preconditions.checkNotNull(aggregates)
- Preconditions.checkNotNull(groupKeysMapping)
output = new Row(finalRowArity)
}
@@ -55,25 +45,15 @@ class IncrementalAggregateAllWindowFunction[W <: Window](
* Row based on the mapping relation between intermediate aggregate data and output data.
*/
override def apply(
- window: W,
- records: Iterable[Row],
- out: Collector[Row]): Unit = {
+ window: W,
+ records: Iterable[Row],
+ out: Collector[Row]): Unit = {
val iterator = records.iterator
if (iterator.hasNext) {
val record = iterator.next()
- // Set group keys value to final output.
- groupKeysMapping.foreach {
- case (after, previous) =>
- output.setField(after, record.getField(previous))
- }
- // Evaluate final aggregate value and set to output.
- aggregateMapping.foreach {
- case (after, previous) =>
- output.setField(after, aggregates(previous).evaluate(record))
- }
- out.collect(output)
+ out.collect(record)
}
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateReduceFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateReduceFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateReduceFunction.scala
deleted file mode 100644
index 14b44e8..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateReduceFunction.scala
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.runtime.aggregate
-
-import org.apache.flink.api.common.functions.ReduceFunction
-import org.apache.flink.types.Row
-import org.apache.flink.util.Preconditions
-
-/**
- * Incrementally computes group window aggregates.
- *
- * @param aggregates The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
- * and output Row.
- */
-class IncrementalAggregateReduceFunction(
- private val aggregates: Array[Aggregate[_]],
- private val groupKeysMapping: Array[(Int, Int)],
- private val intermediateRowArity: Int)
- extends ReduceFunction[Row] {
-
- Preconditions.checkNotNull(aggregates)
- Preconditions.checkNotNull(groupKeysMapping)
-
- /**
- * For Incremental intermediate aggregate Rows, merge value1 and value2
- * into aggregate buffer, return aggregate buffer.
- *
- * @param value1 The first value to combined.
- * @param value2 The second value to combined.
- * @return accumulatorRow A resulting row that combines two input values.
- *
- */
- override def reduce(value1: Row, value2: Row): Row = {
-
- // TODO: once FLINK-5105 is solved, we can avoid creating a new row for each invocation
- // and directly merge value1 and value2.
- val accumulatorRow = new Row(intermediateRowArity)
-
- // copy all fields of value1 into accumulatorRow
- (0 until intermediateRowArity)
- .foreach(i => accumulatorRow.setField(i, value1.getField(i)))
- // merge value2 to accumulatorRow
- aggregates.foreach(_.merge(value2, accumulatorRow))
-
- accumulatorRow
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateTimeWindowFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateTimeWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateTimeWindowFunction.scala
index a6626d9..dccb4f6 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateTimeWindowFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateTimeWindowFunction.scala
@@ -28,24 +28,20 @@ import org.apache.flink.util.Collector
/**
* Computes the final aggregate value from incrementally computed aggregates.
*
- * @param aggregates The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
- * and output Row.
- * @param aggregateMapping The index mapping between aggregate function list and aggregated value
- * index in output Row.
- * @param finalRowArity The arity of the final output row.
+ * @param windowStartPos the relative window-start field position to the last field of output row
+ * @param windowEndPos the relative window-end field position to the last field of output row
+ * @param finalRowArity The arity of the final output row.
*/
class IncrementalAggregateTimeWindowFunction(
- private val aggregates: Array[Aggregate[_ <: Any]],
- private val groupKeysMapping: Array[(Int, Int)],
- private val aggregateMapping: Array[(Int, Int)],
- private val finalRowArity: Int,
+ private val numGroupingKey: Int,
+ private val numAggregates: Int,
private val windowStartPos: Option[Int],
- private val windowEndPos: Option[Int])
+ private val windowEndPos: Option[Int],
+ private val finalRowArity: Int)
extends IncrementalAggregateWindowFunction[TimeWindow](
- aggregates,
- groupKeysMapping,
- aggregateMapping, finalRowArity) {
+ numGroupingKey,
+ numAggregates,
+ finalRowArity) {
private var collector: TimeWindowPropertyCollector = _
@@ -55,10 +51,10 @@ class IncrementalAggregateTimeWindowFunction(
}
override def apply(
- key: Tuple,
- window: TimeWindow,
- records: Iterable[Row],
- out: Collector[Row]): Unit = {
+ key: Tuple,
+ window: TimeWindow,
+ records: Iterable[Row],
+ out: Collector[Row]): Unit = {
// set collector and window
collector.wrappedCollector = out
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateWindowFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateWindowFunction.scala
index 30f7a7b..a4d4837 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateWindowFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateWindowFunction.scala
@@ -24,30 +24,24 @@ import org.apache.flink.types.Row
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.functions.windowing.RichWindowFunction
import org.apache.flink.streaming.api.windowing.windows.Window
-import org.apache.flink.util.{Collector, Preconditions}
+import org.apache.flink.util.Collector
/**
* Computes the final aggregate value from incrementally computed aggregates.
*
- * @param aggregates The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
- * and output Row.
- * @param aggregateMapping The index mapping between aggregate function list and aggregated value
- * index in output Row.
- * @param finalRowArity The arity of the final output row.
+ * @param numGroupingKey The number of grouping keys.
+ * @param numAggregates The number of aggregates.
+ * @param finalRowArity The arity of the final output row.
*/
class IncrementalAggregateWindowFunction[W <: Window](
- private val aggregates: Array[Aggregate[_ <: Any]],
- private val groupKeysMapping: Array[(Int, Int)],
- private val aggregateMapping: Array[(Int, Int)],
+ private val numGroupingKey: Int,
+ private val numAggregates: Int,
private val finalRowArity: Int)
extends RichWindowFunction[Row, Row, Tuple, W] {
private var output: Row = _
override def open(parameters: Configuration): Unit = {
- Preconditions.checkNotNull(aggregates)
- Preconditions.checkNotNull(groupKeysMapping)
output = new Row(finalRowArity)
}
@@ -56,25 +50,23 @@ class IncrementalAggregateWindowFunction[W <: Window](
* Row based on the mapping relation between intermediate aggregate data and output data.
*/
override def apply(
- key: Tuple,
- window: W,
- records: Iterable[Row],
- out: Collector[Row]): Unit = {
+ key: Tuple,
+ window: W,
+ records: Iterable[Row],
+ out: Collector[Row]): Unit = {
val iterator = records.iterator
if (iterator.hasNext) {
val record = iterator.next()
- // Set group keys value to final output.
- groupKeysMapping.foreach {
- case (after, previous) =>
- output.setField(after, record.getField(previous))
+
+ for (i <- 0 until numGroupingKey) {
+ output.setField(i, key.getField(i))
}
- // Evaluate final aggregate value and set to output.
- aggregateMapping.foreach {
- case (after, previous) =>
- output.setField(after, aggregates(previous).evaluate(record))
+ for (i <- 0 until numAggregates) {
+ output.setField(numGroupingKey + i, record.getField(i))
}
+
out.collect(output)
}
}
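
The window function's job after this refactoring is reduced to row assembly: grouping
keys are read from the window key, pre-aggregated values from the record, and both are
copied into fixed positions of the output row. A sketch with plain arrays standing in
for Flink's Tuple and Row types (assembleOutput is an illustrative name):

    // output layout: | grouping keys (numGroupingKey) | aggregate values (numAggregates) |
    def assembleOutput(
        key: Array[Any],
        record: Array[Any],
        numGroupingKey: Int,
        numAggregates: Int): Array[Any] = {
      val output = new Array[Any](numGroupingKey + numAggregates)
      for (i <- 0 until numGroupingKey) output(i) = key(i)
      for (i <- 0 until numAggregates) output(numGroupingKey + i) = record(i)
      output
    }
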
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala
index a243db7..818cd0e 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala
@@ -133,17 +133,17 @@ class AggregationsITCase extends StreamingMultipleProgramsTestBase {
val windowedTable = table
.window(Tumble over 5.milli on 'rowtime as 'w)
.groupBy('w, 'string)
- .select('string, 'int.count, 'int.avg, 'w.start, 'w.end)
+ .select('string, 'int.count, 'int.avg, 'int.min, 'int.max, 'int.sum, 'w.start, 'w.end)
val results = windowedTable.toDataStream[Row]
results.addSink(new StreamITCase.StringSink)
env.execute()
val expected = Seq(
- "Hello world,1,3,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01",
- "Hello world,1,3,1970-01-01 00:00:00.015,1970-01-01 00:00:00.02",
- "Hello,2,2,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005",
- "Hi,1,1,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005")
+ "Hello world,1,3,3,3,3,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01",
+ "Hello world,1,3,3,3,3,1970-01-01 00:00:00.015,1970-01-01 00:00:00.02",
+ "Hello,2,2,2,2,4,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005",
+ "Hi,1,1,1,1,1,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005")
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}