Posted to commits@flink.apache.org by fh...@apache.org on 2017/03/03 13:27:33 UTC
[1/3] flink git commit: [FLINK-5927] [table] Remove old Aggregate interface, built-in functions, and tests.
Repository: flink
Updated Branches:
refs/heads/master 050f9a416 -> c31f95cab
[FLINK-5927] [table] Remove old Aggregate interface, built-in functions, and tests.
This closes #3465.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/c31f95ca
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/c31f95ca
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/c31f95ca
Branch: refs/heads/master
Commit: c31f95cab884452dba47306c2b7fb536f047b8ae
Parents: 14fab4c
Author: shaoxuan-wang <ws...@gmail.com>
Authored: Fri Mar 3 15:05:00 2017 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Fri Mar 3 14:27:08 2017 +0100
----------------------------------------------------------------------
.../table/runtime/aggregate/Aggregate.scala | 96 ------
.../table/runtime/aggregate/AvgAggregate.scala | 296 -------------------
.../runtime/aggregate/CountAggregate.scala | 55 ----
.../table/runtime/aggregate/MaxAggregate.scala | 171 -----------
.../table/runtime/aggregate/MinAggregate.scala | 171 -----------
.../table/runtime/aggregate/SumAggregate.scala | 131 --------
.../runtime/aggregate/AggregateTestBase.scala | 111 -------
.../runtime/aggregate/AvgAggregateTest.scala | 154 ----------
.../runtime/aggregate/CountAggregateTest.scala | 31 --
.../runtime/aggregate/MaxAggregateTest.scala | 177 -----------
.../runtime/aggregate/MinAggregateTest.scala | 177 -----------
.../runtime/aggregate/SumAggregateTest.scala | 137 ---------
12 files changed, 1707 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/c31f95ca/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/Aggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/Aggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/Aggregate.scala
deleted file mode 100644
index a614783..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/Aggregate.scala
+++ /dev/null
@@ -1,96 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.runtime.aggregate
-
-import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.types.Row
-
-/**
- * The interface for all Flink aggregate functions, which are expressed in terms of initiate(),
- * prepare(), merge() and evaluate(). The aggregate functions are executed in two phases:
- * -- In Map phase, use prepare() to transform aggregate field value into intermediate
- * aggregate value.
- * -- In GroupReduce phase, use merge() to merge grouped intermediate aggregate values
- * into aggregate buffer. Then use evaluate() to calculate the final aggregated value.
- * Associative, decomposable aggregate functions support partial aggregation. To optimize
- * performance, a Combine phase can be added between the Map and GroupReduce phases:
- * -- In Combine phase, use merge() to merge sub-grouped intermediate aggregate values
- * into aggregate buffer.
- *
- * The intermediate aggregate value is stored inside a Row; aggOffsetInRow is used as the start
- * field index in the Row, so different aggregate functions can share the same Row as their
- * intermediate aggregate value/aggregate buffer, because each aggregate's values are stored in
- * distinct fields of the Row without conflict. The intermediate aggregate value is required to
- * be a sequence of JVM primitives, and Flink uses intermediateDataType() to get its data types
- * on the SQL side.
- *
- * @tparam T Aggregated value type.
- */
-trait Aggregate[T] extends Serializable {
-
- /**
- * Transform the aggregate field value into intermediate aggregate data.
- *
- * @param value The value to insert into the intermediate aggregate row.
- * @param intermediate The intermediate aggregate row into which the value is inserted.
- */
- def prepare(value: Any, intermediate: Row): Unit
-
- /**
- * Initiate the intermediate aggregate value in Row.
- *
- * @param intermediate The intermediate aggregate row to initiate.
- */
- def initiate(intermediate: Row): Unit
-
- /**
- * Merge intermediate aggregate data into aggregate buffer.
- *
- * @param intermediate The intermediate aggregate row to merge.
- * @param buffer The aggregate buffer into which the intermediate is merged.
- */
- def merge(intermediate: Row, buffer: Row): Unit
-
- /**
- * Calculate the final aggregated result based on aggregate buffer.
- *
- * @param buffer The aggregate buffer from which the final aggregate is computed.
- * @return The final result of the aggregate.
- */
- def evaluate(buffer: Row): T
-
- /**
- * Intermediate aggregate value types.
- *
- * @return The types of the intermediate fields of this aggregate.
- */
- def intermediateDataType: Array[TypeInformation[_]]
-
- /**
- * Set the aggregate data offset in Row.
- *
- * @param aggOffset The offset of this aggregate in the intermediate aggregate rows.
- */
- def setAggOffsetInRow(aggOffset: Int)
-
- /**
- * Whether aggregate function support partial aggregate.
- *
- * @return True if the aggregate supports partial aggregation, False otherwise.
- */
- def supportPartial: Boolean = false
-}
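
For reference, the removed interface was driven in the order described in its class comment: prepare() per input record in the Map phase, merge() of intermediates in the Combine/GroupReduce phases, and evaluate() at the end. Below is a minimal sketch of that lifecycle (not part of this commit), assuming the removed classes such as CountAggregate are still on the classpath and using an illustrative aggregate offset of 0:

    import org.apache.flink.table.runtime.aggregate.CountAggregate
    import org.apache.flink.types.Row

    object AggregateLifecycleSketch {
      def main(args: Array[String]): Unit = {
        val agg = new CountAggregate()
        agg.setAggOffsetInRow(0)                       // aggregate writes its state to field 0
        val rowArity = agg.intermediateDataType.length

        val inputs: Seq[Any] = Seq("a", null, "b", "c")

        // Map phase: one intermediate row per input value.
        val prepared = inputs.map { v =>
          val row = new Row(rowArity)
          agg.prepare(v, row)
          row
        }

        // GroupReduce phase: merge intermediates into a buffer, then evaluate.
        val buffer = new Row(rowArity)
        agg.initiate(buffer)
        prepared.foreach(p => agg.merge(p, buffer))

        println(agg.evaluate(buffer))                  // 3, nulls are not counted
      }
    }
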
http://git-wip-us.apache.org/repos/asf/flink/blob/c31f95ca/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AvgAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AvgAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AvgAggregate.scala
deleted file mode 100644
index cb94ca1..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AvgAggregate.scala
+++ /dev/null
@@ -1,296 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.runtime.aggregate
-
-import com.google.common.math.LongMath
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo
-import org.apache.flink.types.Row
-import java.math.BigDecimal
-import java.math.BigInteger
-
-abstract class AvgAggregate[T] extends Aggregate[T] {
- protected var partialSumIndex: Int = _
- protected var partialCountIndex: Int = _
-
- override def supportPartial: Boolean = true
-
- override def setAggOffsetInRow(aggOffset: Int): Unit = {
- partialSumIndex = aggOffset
- partialCountIndex = aggOffset + 1
- }
-}
-
-abstract class IntegralAvgAggregate[T] extends AvgAggregate[T] {
-
- override def initiate(partial: Row): Unit = {
- partial.setField(partialSumIndex, 0L)
- partial.setField(partialCountIndex, 0L)
- }
-
- override def prepare(value: Any, partial: Row): Unit = {
- if (value == null) {
- partial.setField(partialSumIndex, 0L)
- partial.setField(partialCountIndex, 0L)
- } else {
- doPrepare(value, partial)
- }
- }
-
- override def merge(partial: Row, buffer: Row): Unit = {
- val partialSum = partial.getField(partialSumIndex).asInstanceOf[Long]
- val partialCount = partial.getField(partialCountIndex).asInstanceOf[Long]
- val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[Long]
- val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]
- buffer.setField(partialSumIndex, LongMath.checkedAdd(partialSum, bufferSum))
- buffer.setField(partialCountIndex, LongMath.checkedAdd(partialCount, bufferCount))
- }
-
- override def evaluate(buffer : Row): T = {
- doEvaluate(buffer).asInstanceOf[T]
- }
-
- override def intermediateDataType = Array(
- BasicTypeInfo.LONG_TYPE_INFO,
- BasicTypeInfo.LONG_TYPE_INFO)
-
- def doPrepare(value: Any, partial: Row): Unit
-
- def doEvaluate(buffer: Row): Any
-}
-
-class ByteAvgAggregate extends IntegralAvgAggregate[Byte] {
- override def doPrepare(value: Any, partial: Row): Unit = {
- val input = value.asInstanceOf[Byte]
- partial.setField(partialSumIndex, input.toLong)
- partial.setField(partialCountIndex, 1L)
- }
-
- override def doEvaluate(buffer: Row): Any = {
- val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[Long]
- val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]
- if (bufferCount == 0L) {
- null
- } else {
- (bufferSum / bufferCount).toByte
- }
- }
-}
-
-class ShortAvgAggregate extends IntegralAvgAggregate[Short] {
-
- override def doPrepare(value: Any, partial: Row): Unit = {
- val input = value.asInstanceOf[Short]
- partial.setField(partialSumIndex, input.toLong)
- partial.setField(partialCountIndex, 1L)
- }
-
- override def doEvaluate(buffer: Row): Any = {
- val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[Long]
- val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]
- if (bufferCount == 0L) {
- null
- } else {
- (bufferSum / bufferCount).toShort
- }
- }
-}
-
-class IntAvgAggregate extends IntegralAvgAggregate[Int] {
-
- override def doPrepare(value: Any, partial: Row): Unit = {
- val input = value.asInstanceOf[Int]
- partial.setField(partialSumIndex, input.toLong)
- partial.setField(partialCountIndex, 1L)
- }
-
- override def doEvaluate(buffer: Row): Any = {
- val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[Long]
- val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]
- if (bufferCount == 0L) {
- null
- } else {
- (bufferSum / bufferCount).toInt
- }
- }
-}
-
-class LongAvgAggregate extends IntegralAvgAggregate[Long] {
-
- override def intermediateDataType = Array(
- BasicTypeInfo.BIG_INT_TYPE_INFO,
- BasicTypeInfo.LONG_TYPE_INFO)
-
- override def initiate(partial: Row): Unit = {
- partial.setField(partialSumIndex, BigInteger.ZERO)
- partial.setField(partialCountIndex, 0L)
- }
-
- override def prepare(value: Any, partial: Row): Unit = {
- if (value == null) {
- partial.setField(partialSumIndex, BigInteger.ZERO)
- partial.setField(partialCountIndex, 0L)
- } else {
- doPrepare(value, partial)
- }
- }
-
- override def doPrepare(value: Any, partial: Row): Unit = {
- val input = value.asInstanceOf[Long]
- partial.setField(partialSumIndex, BigInteger.valueOf(input))
- partial.setField(partialCountIndex, 1L)
- }
-
- override def merge(partial: Row, buffer: Row): Unit = {
- val partialSum = partial.getField(partialSumIndex).asInstanceOf[BigInteger]
- val partialCount = partial.getField(partialCountIndex).asInstanceOf[Long]
- val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[BigInteger]
- val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]
- buffer.setField(partialSumIndex, partialSum.add(bufferSum))
- buffer.setField(partialCountIndex, LongMath.checkedAdd(partialCount, bufferCount))
- }
-
- override def doEvaluate(buffer: Row): Any = {
- val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[BigInteger]
- val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]
- if (bufferCount == 0L) {
- null
- } else {
- bufferSum.divide(BigInteger.valueOf(bufferCount)).longValue()
- }
- }
-}
-
-abstract class FloatingAvgAggregate[T: Numeric] extends AvgAggregate[T] {
-
- override def initiate(partial: Row): Unit = {
- partial.setField(partialSumIndex, 0D)
- partial.setField(partialCountIndex, 0L)
- }
-
- override def prepare(value: Any, partial: Row): Unit = {
- if (value == null) {
- partial.setField(partialSumIndex, 0D)
- partial.setField(partialCountIndex, 0L)
- } else {
- doPrepare(value, partial)
- }
- }
-
- override def merge(partial: Row, buffer: Row): Unit = {
- val partialSum = partial.getField(partialSumIndex).asInstanceOf[Double]
- val partialCount = partial.getField(partialCountIndex).asInstanceOf[Long]
- val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[Double]
- val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]
-
- buffer.setField(partialSumIndex, partialSum + bufferSum)
- buffer.setField(partialCountIndex, partialCount + bufferCount)
- }
-
- override def evaluate(buffer : Row): T = {
- doEvaluate(buffer).asInstanceOf[T]
- }
-
- override def intermediateDataType = Array(
- BasicTypeInfo.DOUBLE_TYPE_INFO,
- BasicTypeInfo.LONG_TYPE_INFO)
-
- def doPrepare(value: Any, partial: Row): Unit
-
- def doEvaluate(buffer: Row): Any
-}
-
-class FloatAvgAggregate extends FloatingAvgAggregate[Float] {
-
- override def doPrepare(value: Any, partial: Row): Unit = {
- val input = value.asInstanceOf[Float]
- partial.setField(partialSumIndex, input.toDouble)
- partial.setField(partialCountIndex, 1L)
- }
-
-
- override def doEvaluate(buffer: Row): Any = {
- val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[Double]
- val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]
- if (bufferCount == 0L) {
- null
- } else {
- (bufferSum / bufferCount).toFloat
- }
- }
-}
-
-class DoubleAvgAggregate extends FloatingAvgAggregate[Double] {
-
- override def doPrepare(value: Any, partial: Row): Unit = {
- val input = value.asInstanceOf[Double]
- partial.setField(partialSumIndex, input)
- partial.setField(partialCountIndex, 1L)
- }
-
- override def doEvaluate(buffer: Row): Any = {
- val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[Double]
- val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]
- if (bufferCount == 0L) {
- null
- } else {
- (bufferSum / bufferCount)
- }
- }
-}
-
-class DecimalAvgAggregate extends AvgAggregate[BigDecimal] {
-
- override def intermediateDataType = Array(
- BasicTypeInfo.BIG_DEC_TYPE_INFO,
- BasicTypeInfo.LONG_TYPE_INFO)
-
- override def initiate(partial: Row): Unit = {
- partial.setField(partialSumIndex, BigDecimal.ZERO)
- partial.setField(partialCountIndex, 0L)
- }
-
- override def prepare(value: Any, partial: Row): Unit = {
- if (value == null) {
- initiate(partial)
- } else {
- val input = value.asInstanceOf[BigDecimal]
- partial.setField(partialSumIndex, input)
- partial.setField(partialCountIndex, 1L)
- }
- }
-
- override def merge(partial: Row, buffer: Row): Unit = {
- val partialSum = partial.getField(partialSumIndex).asInstanceOf[BigDecimal]
- val partialCount = partial.getField(partialCountIndex).asInstanceOf[Long]
- val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[BigDecimal]
- val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]
- buffer.setField(partialSumIndex, partialSum.add(bufferSum))
- buffer.setField(partialCountIndex, LongMath.checkedAdd(partialCount, bufferCount))
- }
-
- override def evaluate(buffer: Row): BigDecimal = {
- val bufferCount = buffer.getField(partialCountIndex).asInstanceOf[Long]
- if (bufferCount != 0) {
- val bufferSum = buffer.getField(partialSumIndex).asInstanceOf[BigDecimal]
- bufferSum.divide(BigDecimal.valueOf(bufferCount))
- } else {
- null.asInstanceOf[BigDecimal]
- }
- }
-
-}
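
The averages above keep a (sum, count) pair as intermediate state rather than a running average, because two such pairs can be merged field-wise in any order and the division happens only once in evaluate(). A small sketch of that partial-aggregation path (not part of this commit), assuming the removed IntAvgAggregate and an illustrative offset of 0:

    import org.apache.flink.table.runtime.aggregate.IntAvgAggregate
    import org.apache.flink.types.Row

    object AvgPartialSketch {
      def main(args: Array[String]): Unit = {
        val agg = new IntAvgAggregate()
        agg.setAggOffsetInRow(0)                       // sum in field 0, count in field 1
        val arity = agg.intermediateDataType.length

        // Build one partial (sum, count) buffer for a sub-group of values.
        def partial(values: Seq[Any]): Row = {
          val buf = new Row(arity)
          agg.initiate(buf)
          values.foreach { v =>
            val r = new Row(arity)
            agg.prepare(v, r)
            agg.merge(r, buf)
          }
          buf
        }

        // Combine phase on two sub-groups, then a final merge in GroupReduce.
        val left   = partial(Seq(1, 2, 3))
        val right  = partial(Seq(4, null))
        val buffer = new Row(arity)
        agg.initiate(buffer)
        agg.merge(left, buffer)
        agg.merge(right, buffer)

        println(agg.evaluate(buffer))                  // (1+2+3+4) / 4 = 2
      }
    }
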
http://git-wip-us.apache.org/repos/asf/flink/blob/c31f95ca/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/CountAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/CountAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/CountAggregate.scala
deleted file mode 100644
index ea8e1d8..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/CountAggregate.scala
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.runtime.aggregate
-
-import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
-import org.apache.flink.types.Row
-
-class CountAggregate extends Aggregate[Long] {
- private var countIndex: Int = _
-
- override def initiate(intermediate: Row): Unit = {
- intermediate.setField(countIndex, 0L)
- }
-
- override def merge(intermediate: Row, buffer: Row): Unit = {
- val partialCount = intermediate.getField(countIndex).asInstanceOf[Long]
- val bufferCount = buffer.getField(countIndex).asInstanceOf[Long]
- buffer.setField(countIndex, partialCount + bufferCount)
- }
-
- override def evaluate(buffer: Row): Long = {
- buffer.getField(countIndex).asInstanceOf[Long]
- }
-
- override def prepare(value: Any, intermediate: Row): Unit = {
- if (value == null) {
- intermediate.setField(countIndex, 0L)
- } else {
- intermediate.setField(countIndex, 1L)
- }
- }
-
- override def intermediateDataType = Array(BasicTypeInfo.LONG_TYPE_INFO)
-
- override def supportPartial: Boolean = true
-
- override def setAggOffsetInRow(aggIndex: Int): Unit = {
- countIndex = aggIndex
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/c31f95ca/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/MaxAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/MaxAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/MaxAggregate.scala
deleted file mode 100644
index 34b25e0..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/MaxAggregate.scala
+++ /dev/null
@@ -1,171 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.runtime.aggregate
-
-import java.math.BigDecimal
-
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo
-import org.apache.flink.types.Row
-
-abstract class MaxAggregate[T](implicit ord: Ordering[T]) extends Aggregate[T] {
-
- protected var maxIndex = -1
-
- /**
- * Initiate the intermediate aggregate value in Row.
- *
- * @param intermediate The intermediate aggregate row to initiate.
- */
- override def initiate(intermediate: Row): Unit = {
- intermediate.setField(maxIndex, null)
- }
-
- /**
- * Called in the MapFunction to prepare the input of the partial aggregate.
- *
- * @param value
- * @param intermediate
- */
- override def prepare(value: Any, intermediate: Row): Unit = {
- if (value == null) {
- initiate(intermediate)
- } else {
- intermediate.setField(maxIndex, value)
- }
- }
-
- /**
- * Called in the CombineFunction and GroupReduceFunction to merge the partial
- * aggregate result into the aggregate buffer.
- *
- * @param intermediate
- * @param buffer
- */
- override def merge(intermediate: Row, buffer: Row): Unit = {
- val partialValue = intermediate.getField(maxIndex).asInstanceOf[T]
- if (partialValue != null) {
- val bufferValue = buffer.getField(maxIndex).asInstanceOf[T]
- if (bufferValue != null) {
- val max : T = if (ord.compare(partialValue, bufferValue) > 0) partialValue else bufferValue
- buffer.setField(maxIndex, max)
- } else {
- buffer.setField(maxIndex, partialValue)
- }
- }
- }
-
- /**
- * Return the final aggregated result based on aggregate buffer.
- *
- * @param buffer
- * @return
- */
- override def evaluate(buffer: Row): T = {
- buffer.getField(maxIndex).asInstanceOf[T]
- }
-
- override def supportPartial: Boolean = true
-
- override def setAggOffsetInRow(aggOffset: Int): Unit = {
- maxIndex = aggOffset
- }
-}
-
-class ByteMaxAggregate extends MaxAggregate[Byte] {
-
- override def intermediateDataType = Array(BasicTypeInfo.BYTE_TYPE_INFO)
-
-}
-
-class ShortMaxAggregate extends MaxAggregate[Short] {
-
- override def intermediateDataType = Array(BasicTypeInfo.SHORT_TYPE_INFO)
-
-}
-
-class IntMaxAggregate extends MaxAggregate[Int] {
-
- override def intermediateDataType = Array(BasicTypeInfo.INT_TYPE_INFO)
-
-}
-
-class LongMaxAggregate extends MaxAggregate[Long] {
-
- override def intermediateDataType = Array(BasicTypeInfo.LONG_TYPE_INFO)
-
-}
-
-class FloatMaxAggregate extends MaxAggregate[Float] {
-
- override def intermediateDataType = Array(BasicTypeInfo.FLOAT_TYPE_INFO)
-
-}
-
-class DoubleMaxAggregate extends MaxAggregate[Double] {
-
- override def intermediateDataType = Array(BasicTypeInfo.DOUBLE_TYPE_INFO)
-
-}
-
-class BooleanMaxAggregate extends MaxAggregate[Boolean] {
-
- override def intermediateDataType = Array(BasicTypeInfo.BOOLEAN_TYPE_INFO)
-
-}
-
-class DecimalMaxAggregate extends Aggregate[BigDecimal] {
-
- protected var minIndex: Int = _
-
- override def intermediateDataType = Array(BasicTypeInfo.BIG_DEC_TYPE_INFO)
-
- override def initiate(intermediate: Row): Unit = {
- intermediate.setField(minIndex, null)
- }
-
- override def prepare(value: Any, partial: Row): Unit = {
- if (value == null) {
- initiate(partial)
- } else {
- partial.setField(minIndex, value)
- }
- }
-
- override def merge(partial: Row, buffer: Row): Unit = {
- val partialValue = partial.getField(minIndex).asInstanceOf[BigDecimal]
- if (partialValue != null) {
- val bufferValue = buffer.getField(minIndex).asInstanceOf[BigDecimal]
- if (bufferValue != null) {
- val min = if (partialValue.compareTo(bufferValue) > 0) partialValue else bufferValue
- buffer.setField(minIndex, min)
- } else {
- buffer.setField(minIndex, partialValue)
- }
- }
- }
-
- override def evaluate(buffer: Row): BigDecimal = {
- buffer.getField(minIndex).asInstanceOf[BigDecimal]
- }
-
- override def supportPartial: Boolean = true
-
- override def setAggOffsetInRow(aggOffset: Int): Unit = {
- minIndex = aggOffset
- }
-}
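
Null handling in the classes above is uniform: prepare() keeps null inputs as null intermediates and merge() skips them, so a group containing only nulls evaluates to null instead of an artificial extreme value. A small sketch (not part of this commit), assuming the removed DecimalMaxAggregate and an illustrative offset of 0:

    import java.math.BigDecimal

    import org.apache.flink.table.runtime.aggregate.DecimalMaxAggregate
    import org.apache.flink.types.Row

    object MaxNullHandlingSketch {
      def main(args: Array[String]): Unit = {
        val agg = new DecimalMaxAggregate()
        agg.setAggOffsetInRow(0)
        val arity = agg.intermediateDataType.length

        def aggregate(values: Seq[Any]): BigDecimal = {
          val buffer = new Row(arity)
          agg.initiate(buffer)                         // buffer field starts as null
          values.foreach { v =>
            val row = new Row(arity)
            agg.prepare(v, row)                        // null inputs stay null
            agg.merge(row, buffer)                     // null intermediates are skipped
          }
          agg.evaluate(buffer)
        }

        println(aggregate(Seq(new BigDecimal("3"), null, new BigDecimal("7.5"))))  // 7.5
        println(aggregate(Seq(null, null)))                                        // null
      }
    }
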
http://git-wip-us.apache.org/repos/asf/flink/blob/c31f95ca/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/MinAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/MinAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/MinAggregate.scala
deleted file mode 100644
index 88cb058..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/MinAggregate.scala
+++ /dev/null
@@ -1,171 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.runtime.aggregate
-
-import java.math.BigDecimal
-
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo
-import org.apache.flink.types.Row
-
-abstract class MinAggregate[T](implicit ord: Ordering[T]) extends Aggregate[T] {
-
- protected var minIndex: Int = _
-
- /**
- * Initiate the intermediate aggregate value in Row.
- *
- * @param intermediate The intermediate aggregate row to initiate.
- */
- override def initiate(intermediate: Row): Unit = {
- intermediate.setField(minIndex, null)
- }
-
- /**
- * Called in the MapFunction to prepare the input of the partial aggregate.
- *
- * @param value
- * @param partial
- */
- override def prepare(value: Any, partial: Row): Unit = {
- if (value == null) {
- initiate(partial)
- } else {
- partial.setField(minIndex, value)
- }
- }
-
- /**
- * Called in the CombineFunction and GroupReduceFunction to merge the partial
- * aggregate result into the aggregate buffer.
- *
- * @param partial
- * @param buffer
- */
- override def merge(partial: Row, buffer: Row): Unit = {
- val partialValue = partial.getField(minIndex).asInstanceOf[T]
- if (partialValue != null) {
- val bufferValue = buffer.getField(minIndex).asInstanceOf[T]
- if (bufferValue != null) {
- val min : T = if (ord.compare(partialValue, bufferValue) < 0) partialValue else bufferValue
- buffer.setField(minIndex, min)
- } else {
- buffer.setField(minIndex, partialValue)
- }
- }
- }
-
- /**
- * Return the final aggregated result based on aggregate buffer.
- *
- * @param buffer
- * @return
- */
- override def evaluate(buffer: Row): T = {
- buffer.getField(minIndex).asInstanceOf[T]
- }
-
- override def supportPartial: Boolean = true
-
- override def setAggOffsetInRow(aggOffset: Int): Unit = {
- minIndex = aggOffset
- }
-}
-
-class ByteMinAggregate extends MinAggregate[Byte] {
-
- override def intermediateDataType = Array(BasicTypeInfo.BYTE_TYPE_INFO)
-
-}
-
-class ShortMinAggregate extends MinAggregate[Short] {
-
- override def intermediateDataType = Array(BasicTypeInfo.SHORT_TYPE_INFO)
-
-}
-
-class IntMinAggregate extends MinAggregate[Int] {
-
- override def intermediateDataType = Array(BasicTypeInfo.INT_TYPE_INFO)
-
-}
-
-class LongMinAggregate extends MinAggregate[Long] {
-
- override def intermediateDataType = Array(BasicTypeInfo.LONG_TYPE_INFO)
-
-}
-
-class FloatMinAggregate extends MinAggregate[Float] {
-
- override def intermediateDataType = Array(BasicTypeInfo.FLOAT_TYPE_INFO)
-
-}
-
-class DoubleMinAggregate extends MinAggregate[Double] {
-
- override def intermediateDataType = Array(BasicTypeInfo.DOUBLE_TYPE_INFO)
-
-}
-
-class BooleanMinAggregate extends MinAggregate[Boolean] {
-
- override def intermediateDataType = Array(BasicTypeInfo.BOOLEAN_TYPE_INFO)
-
-}
-
-class DecimalMinAggregate extends Aggregate[BigDecimal] {
-
- protected var minIndex: Int = _
-
- override def intermediateDataType = Array(BasicTypeInfo.BIG_DEC_TYPE_INFO)
-
- override def initiate(intermediate: Row): Unit = {
- intermediate.setField(minIndex, null)
- }
-
- override def prepare(value: Any, partial: Row): Unit = {
- if (value == null) {
- initiate(partial)
- } else {
- partial.setField(minIndex, value)
- }
- }
-
- override def merge(partial: Row, buffer: Row): Unit = {
- val partialValue = partial.getField(minIndex).asInstanceOf[BigDecimal]
- if (partialValue != null) {
- val bufferValue = buffer.getField(minIndex).asInstanceOf[BigDecimal]
- if (bufferValue != null) {
- val min = if (partialValue.compareTo(bufferValue) < 0) partialValue else bufferValue
- buffer.setField(minIndex, min)
- } else {
- buffer.setField(minIndex, partialValue)
- }
- }
- }
-
- override def evaluate(buffer: Row): BigDecimal = {
- buffer.getField(minIndex).asInstanceOf[BigDecimal]
- }
-
- override def supportPartial: Boolean = true
-
- override def setAggOffsetInRow(aggOffset: Int): Unit = {
- minIndex = aggOffset
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/c31f95ca/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/SumAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/SumAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/SumAggregate.scala
deleted file mode 100644
index cd88112..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/SumAggregate.scala
+++ /dev/null
@@ -1,131 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.runtime.aggregate
-
-import java.math.BigDecimal
-
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo
-import org.apache.flink.types.Row
-
-abstract class SumAggregate[T: Numeric]
- extends Aggregate[T] {
-
- private val numeric = implicitly[Numeric[T]]
- protected var sumIndex: Int = _
-
- override def initiate(partial: Row): Unit = {
- partial.setField(sumIndex, null)
- }
-
- override def merge(partial1: Row, buffer: Row): Unit = {
- val partialValue = partial1.getField(sumIndex).asInstanceOf[T]
- if (partialValue != null) {
- val bufferValue = buffer.getField(sumIndex).asInstanceOf[T]
- if (bufferValue != null) {
- buffer.setField(sumIndex, numeric.plus(partialValue, bufferValue))
- } else {
- buffer.setField(sumIndex, partialValue)
- }
- }
- }
-
- override def evaluate(buffer: Row): T = {
- buffer.getField(sumIndex).asInstanceOf[T]
- }
-
- override def prepare(value: Any, partial: Row): Unit = {
- if (value == null) {
- initiate(partial)
- } else {
- val input = value.asInstanceOf[T]
- partial.setField(sumIndex, input)
- }
- }
-
- override def supportPartial: Boolean = true
-
- override def setAggOffsetInRow(aggOffset: Int): Unit = {
- sumIndex = aggOffset
- }
-}
-
-class ByteSumAggregate extends SumAggregate[Byte] {
- override def intermediateDataType = Array(BasicTypeInfo.BYTE_TYPE_INFO)
-}
-
-class ShortSumAggregate extends SumAggregate[Short] {
- override def intermediateDataType = Array(BasicTypeInfo.SHORT_TYPE_INFO)
-}
-
-class IntSumAggregate extends SumAggregate[Int] {
- override def intermediateDataType = Array(BasicTypeInfo.INT_TYPE_INFO)
-}
-
-class LongSumAggregate extends SumAggregate[Long] {
- override def intermediateDataType = Array(BasicTypeInfo.LONG_TYPE_INFO)
-}
-
-class FloatSumAggregate extends SumAggregate[Float] {
- override def intermediateDataType = Array(BasicTypeInfo.FLOAT_TYPE_INFO)
-}
-
-class DoubleSumAggregate extends SumAggregate[Double] {
- override def intermediateDataType = Array(BasicTypeInfo.DOUBLE_TYPE_INFO)
-}
-
-class DecimalSumAggregate extends Aggregate[BigDecimal] {
-
- protected var sumIndex: Int = _
-
- override def intermediateDataType = Array(BasicTypeInfo.BIG_DEC_TYPE_INFO)
-
- override def initiate(partial: Row): Unit = {
- partial.setField(sumIndex, null)
- }
-
- override def merge(partial1: Row, buffer: Row): Unit = {
- val partialValue = partial1.getField(sumIndex).asInstanceOf[BigDecimal]
- if (partialValue != null) {
- val bufferValue = buffer.getField(sumIndex).asInstanceOf[BigDecimal]
- if (bufferValue != null) {
- buffer.setField(sumIndex, partialValue.add(bufferValue))
- } else {
- buffer.setField(sumIndex, partialValue)
- }
- }
- }
-
- override def evaluate(buffer: Row): BigDecimal = {
- buffer.getField(sumIndex).asInstanceOf[BigDecimal]
- }
-
- override def prepare(value: Any, partial: Row): Unit = {
- if (value == null) {
- initiate(partial)
- } else {
- val input = value.asInstanceOf[BigDecimal]
- partial.setField(sumIndex, input)
- }
- }
-
- override def supportPartial: Boolean = true
-
- override def setAggOffsetInRow(aggOffset: Int): Unit = {
- sumIndex = aggOffset
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/c31f95ca/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/AggregateTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/AggregateTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/AggregateTestBase.scala
deleted file mode 100644
index 0ca101d..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/AggregateTestBase.scala
+++ /dev/null
@@ -1,111 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.runtime.aggregate
-
-import java.math.BigDecimal
-import org.apache.flink.types.Row
-
-import org.junit.Test
-import org.junit.Assert.assertEquals
-
-abstract class AggregateTestBase[T] {
-
- private val offset = 2
- private val rowArity: Int = offset + aggregator.intermediateDataType.length
-
- def inputValueSets: Seq[Seq[_]]
-
- def expectedResults: Seq[T]
-
- def aggregator: Aggregate[T]
-
- private def createAggregator(): Aggregate[T] = {
- val agg = aggregator
- agg.setAggOffsetInRow(offset)
- agg
- }
-
- private def createRow(): Row = {
- new Row(rowArity)
- }
-
- @Test
- def testAggregate(): Unit = {
-
- // iterate over input sets
- for((vals, expected) <- inputValueSets.zip(expectedResults)) {
-
- // prepare mapper
- val rows: Seq[Row] = prepare(vals)
-
- val result = if (aggregator.supportPartial) {
- // test with combiner
- val (firstVals, secondVals) = rows.splitAt(rows.length / 2)
- val combined = partialAgg(firstVals) :: partialAgg(secondVals) :: Nil
- finalAgg(combined)
-
- } else {
- // test without combiner
- finalAgg(rows)
- }
-
- (expected, result) match {
- case (e: BigDecimal, r: BigDecimal) =>
- // BigDecimal.equals() compares value and scale, but we are only interested in the value.
- assert(e.compareTo(r) == 0)
- case _ =>
- assertEquals(expected, result)
- }
- }
- }
-
- private def prepare(vals: Seq[_]): Seq[Row] = {
-
- val agg = createAggregator()
-
- vals.map { v =>
- val row = createRow()
- agg.prepare(v, row)
- row
- }
- }
-
- private def partialAgg(rows: Seq[Row]): Row = {
-
- val agg = createAggregator()
- val aggBuf = createRow()
-
- agg.initiate(aggBuf)
- rows.foreach(v => agg.merge(v, aggBuf))
-
- aggBuf
- }
-
- private def finalAgg(rows: Seq[Row]): T = {
-
- val agg = createAggregator()
- val aggBuf = createRow()
-
- agg.initiate(aggBuf)
- rows.foreach(v => agg.merge(v, aggBuf))
-
- agg.evaluate(partialAgg(rows))
- }
-
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/c31f95ca/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/AvgAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/AvgAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/AvgAggregateTest.scala
deleted file mode 100644
index a72d08b..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/AvgAggregateTest.scala
+++ /dev/null
@@ -1,154 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.runtime.aggregate
-
-import java.math.BigDecimal
-
-abstract class AvgAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
-
- private val numeric: Numeric[T] = implicitly[Numeric[T]]
-
- def minVal: T
- def maxVal: T
-
- override def inputValueSets: Seq[Seq[T]] = Seq(
- Seq(
- minVal,
- minVal,
- null.asInstanceOf[T],
- minVal,
- minVal,
- null.asInstanceOf[T],
- minVal,
- minVal,
- minVal
- ),
- Seq(
- maxVal,
- maxVal,
- null.asInstanceOf[T],
- maxVal,
- maxVal,
- null.asInstanceOf[T],
- maxVal,
- maxVal,
- maxVal
- ),
- Seq(
- minVal,
- maxVal,
- null.asInstanceOf[T],
- numeric.fromInt(0),
- numeric.negate(maxVal),
- numeric.negate(minVal),
- null.asInstanceOf[T]
- ),
- Seq(
- null.asInstanceOf[T],
- null.asInstanceOf[T],
- null.asInstanceOf[T],
- null.asInstanceOf[T],
- null.asInstanceOf[T],
- null.asInstanceOf[T]
- )
- )
-
- override def expectedResults: Seq[T] = Seq(
- minVal,
- maxVal,
- numeric.fromInt(0),
- null.asInstanceOf[T]
- )
-}
-
-class ByteAvgAggregateTest extends AvgAggregateTestBase[Byte] {
-
- override def minVal = (Byte.MinValue + 1).toByte
- override def maxVal = (Byte.MaxValue - 1).toByte
-
- override def aggregator = new ByteAvgAggregate()
-}
-
-class ShortAvgAggregateTest extends AvgAggregateTestBase[Short] {
-
- override def minVal = (Short.MinValue + 1).toShort
- override def maxVal = (Short.MaxValue - 1).toShort
-
- override def aggregator = new ShortAvgAggregate()
-}
-
-class IntAvgAggregateTest extends AvgAggregateTestBase[Int] {
-
- override def minVal = Int.MinValue + 1
- override def maxVal = Int.MaxValue - 1
-
- override def aggregator = new IntAvgAggregate()
-}
-
-class LongAvgAggregateTest extends AvgAggregateTestBase[Long] {
-
- override def minVal = Long.MinValue + 1
- override def maxVal = Long.MaxValue - 1
-
- override def aggregator = new LongAvgAggregate()
-}
-
-class FloatAvgAggregateTest extends AvgAggregateTestBase[Float] {
-
- override def minVal = Float.MinValue
- override def maxVal = Float.MaxValue
-
- override def aggregator = new FloatAvgAggregate()
-}
-
-class DoubleAvgAggregateTest extends AvgAggregateTestBase[Double] {
-
- override def minVal = Float.MinValue
- override def maxVal = Float.MaxValue
-
- override def aggregator = new DoubleAvgAggregate()
-}
-
-class DecimalAvgAggregateTest extends AggregateTestBase[BigDecimal] {
-
- override def inputValueSets: Seq[Seq[_]] = Seq(
- Seq(
- new BigDecimal("987654321000000"),
- new BigDecimal("-0.000000000012345"),
- null,
- new BigDecimal("0.000000000012345"),
- new BigDecimal("-987654321000000"),
- null,
- new BigDecimal("0")
- ),
- Seq(
- null,
- null,
- null,
- null
- )
- )
-
- override def expectedResults: Seq[BigDecimal] = Seq(
- BigDecimal.ZERO,
- null
- )
-
- override def aggregator: Aggregate[BigDecimal] = new DecimalAvgAggregate()
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/c31f95ca/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/CountAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/CountAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/CountAggregateTest.scala
deleted file mode 100644
index 55f73b4..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/CountAggregateTest.scala
+++ /dev/null
@@ -1,31 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.runtime.aggregate
-
-class CountAggregateTest extends AggregateTestBase[Long] {
-
- override def inputValueSets: Seq[Seq[_]] = Seq(
- Seq("a", "b", null, "c", null, "d", "e", null, "f"),
- Seq(null, null, null, null, null, null)
- )
-
- override def expectedResults: Seq[Long] = Seq(6L, 0L)
-
- override def aggregator: Aggregate[Long] = new CountAggregate()
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/c31f95ca/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/MaxAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/MaxAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/MaxAggregateTest.scala
deleted file mode 100644
index 1bf879d..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/MaxAggregateTest.scala
+++ /dev/null
@@ -1,177 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.runtime.aggregate
-
-import java.math.BigDecimal
-
-abstract class MaxAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
-
- private val numeric: Numeric[T] = implicitly[Numeric[T]]
-
- def minVal: T
- def maxVal: T
-
- override def inputValueSets: Seq[Seq[T]] = Seq(
- Seq(
- numeric.fromInt(1),
- null.asInstanceOf[T],
- maxVal,
- numeric.fromInt(-99),
- numeric.fromInt(3),
- numeric.fromInt(56),
- numeric.fromInt(0),
- minVal,
- numeric.fromInt(-20),
- numeric.fromInt(17),
- null.asInstanceOf[T]
- ),
- Seq(
- null.asInstanceOf[T],
- null.asInstanceOf[T],
- null.asInstanceOf[T],
- null.asInstanceOf[T],
- null.asInstanceOf[T],
- null.asInstanceOf[T]
- )
- )
-
- override def expectedResults: Seq[T] = Seq(
- maxVal,
- null.asInstanceOf[T]
- )
-}
-
-class ByteMaxAggregateTest extends MaxAggregateTestBase[Byte] {
-
- override def minVal = (Byte.MinValue + 1).toByte
- override def maxVal = (Byte.MaxValue - 1).toByte
-
- override def aggregator: Aggregate[Byte] = new ByteMaxAggregate()
-}
-
-class ShortMaxAggregateTest extends MaxAggregateTestBase[Short] {
-
- override def minVal = (Short.MinValue + 1).toShort
- override def maxVal = (Short.MaxValue - 1).toShort
-
- override def aggregator: Aggregate[Short] = new ShortMaxAggregate()
-}
-
-class IntMaxAggregateTest extends MaxAggregateTestBase[Int] {
-
- override def minVal = Int.MinValue + 1
- override def maxVal = Int.MaxValue - 1
-
- override def aggregator: Aggregate[Int] = new IntMaxAggregate()
-}
-
-class LongMaxAggregateTest extends MaxAggregateTestBase[Long] {
-
- override def minVal = Long.MinValue + 1
- override def maxVal = Long.MaxValue - 1
-
- override def aggregator: Aggregate[Long] = new LongMaxAggregate()
-}
-
-class FloatMaxAggregateTest extends MaxAggregateTestBase[Float] {
-
- override def minVal = Float.MinValue / 2
- override def maxVal = Float.MaxValue / 2
-
- override def aggregator: Aggregate[Float] = new FloatMaxAggregate()
-}
-
-class DoubleMaxAggregateTest extends MaxAggregateTestBase[Double] {
-
- override def minVal = Double.MinValue / 2
- override def maxVal = Double.MaxValue / 2
-
- override def aggregator: Aggregate[Double] = new DoubleMaxAggregate()
-}
-
-class BooleanMaxAggregateTest extends AggregateTestBase[Boolean] {
-
- override def inputValueSets: Seq[Seq[Boolean]] = Seq(
- Seq(
- false,
- false,
- false
- ),
- Seq(
- true,
- true,
- true
- ),
- Seq(
- true,
- false,
- null.asInstanceOf[Boolean],
- true,
- false,
- true,
- null.asInstanceOf[Boolean]
- ),
- Seq(
- null.asInstanceOf[Boolean],
- null.asInstanceOf[Boolean],
- null.asInstanceOf[Boolean]
- )
- )
-
- override def expectedResults: Seq[Boolean] = Seq(
- false,
- true,
- true,
- null.asInstanceOf[Boolean]
- )
-
- override def aggregator: Aggregate[Boolean] = new BooleanMaxAggregate()
-}
-
-class DecimalMaxAggregateTest extends AggregateTestBase[BigDecimal] {
-
- override def inputValueSets: Seq[Seq[_]] = Seq(
- Seq(
- new BigDecimal("1"),
- new BigDecimal("1000.000001"),
- new BigDecimal("-1"),
- new BigDecimal("-999.998999"),
- null,
- new BigDecimal("0"),
- new BigDecimal("-999.999"),
- null,
- new BigDecimal("999.999")
- ),
- Seq(
- null,
- null,
- null,
- null,
- null
- )
- )
-
- override def expectedResults: Seq[BigDecimal] = Seq(
- new BigDecimal("1000.000001"),
- null
- )
-
- override def aggregator: Aggregate[BigDecimal] = new DecimalMaxAggregate()
-
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/c31f95ca/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/MinAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/MinAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/MinAggregateTest.scala
deleted file mode 100644
index 3e2404d..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/MinAggregateTest.scala
+++ /dev/null
@@ -1,177 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.runtime.aggregate
-
-import java.math.BigDecimal
-
-abstract class MinAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
-
- private val numeric: Numeric[T] = implicitly[Numeric[T]]
-
- def minVal: T
- def maxVal: T
-
- override def inputValueSets: Seq[Seq[T]] = Seq(
- Seq(
- numeric.fromInt(1),
- null.asInstanceOf[T],
- maxVal,
- numeric.fromInt(-99),
- numeric.fromInt(3),
- numeric.fromInt(56),
- numeric.fromInt(0),
- minVal,
- numeric.fromInt(-20),
- numeric.fromInt(17),
- null.asInstanceOf[T]
- ),
- Seq(
- null.asInstanceOf[T],
- null.asInstanceOf[T],
- null.asInstanceOf[T],
- null.asInstanceOf[T],
- null.asInstanceOf[T],
- null.asInstanceOf[T]
- )
- )
-
- override def expectedResults: Seq[T] = Seq(
- minVal,
- null.asInstanceOf[T]
- )
-}
-
-class ByteMinAggregateTest extends MinAggregateTestBase[Byte] {
-
- override def minVal = (Byte.MinValue + 1).toByte
- override def maxVal = (Byte.MaxValue - 1).toByte
-
- override def aggregator: Aggregate[Byte] = new ByteMinAggregate()
-}
-
-class ShortMinAggregateTest extends MinAggregateTestBase[Short] {
-
- override def minVal = (Short.MinValue + 1).toShort
- override def maxVal = (Short.MaxValue - 1).toShort
-
- override def aggregator: Aggregate[Short] = new ShortMinAggregate()
-}
-
-class IntMinAggregateTest extends MinAggregateTestBase[Int] {
-
- override def minVal = Int.MinValue + 1
- override def maxVal = Int.MaxValue - 1
-
- override def aggregator: Aggregate[Int] = new IntMinAggregate()
-}
-
-class LongMinAggregateTest extends MinAggregateTestBase[Long] {
-
- override def minVal = Long.MinValue + 1
- override def maxVal = Long.MaxValue - 1
-
- override def aggregator: Aggregate[Long] = new LongMinAggregate()
-}
-
-class FloatMinAggregateTest extends MinAggregateTestBase[Float] {
-
- override def minVal = Float.MinValue / 2
- override def maxVal = Float.MaxValue / 2
-
- override def aggregator: Aggregate[Float] = new FloatMinAggregate()
-}
-
-class DoubleMinAggregateTest extends MinAggregateTestBase[Double] {
-
- override def minVal = Double.MinValue / 2
- override def maxVal = Double.MaxValue / 2
-
- override def aggregator: Aggregate[Double] = new DoubleMinAggregate()
-}
-
-class BooleanMinAggregateTest extends AggregateTestBase[Boolean] {
-
- override def inputValueSets: Seq[Seq[Boolean]] = Seq(
- Seq(
- false,
- false,
- false
- ),
- Seq(
- true,
- true,
- true
- ),
- Seq(
- true,
- false,
- null.asInstanceOf[Boolean],
- true,
- false,
- true,
- null.asInstanceOf[Boolean]
- ),
- Seq(
- null.asInstanceOf[Boolean],
- null.asInstanceOf[Boolean],
- null.asInstanceOf[Boolean]
- )
- )
-
- override def expectedResults: Seq[Boolean] = Seq(
- false,
- true,
- false,
- null.asInstanceOf[Boolean]
- )
-
- override def aggregator: Aggregate[Boolean] = new BooleanMinAggregate()
-}
-
-class DecimalMinAggregateTest extends AggregateTestBase[BigDecimal] {
-
- override def inputValueSets: Seq[Seq[_]] = Seq(
- Seq(
- new BigDecimal("1"),
- new BigDecimal("1000"),
- new BigDecimal("-1"),
- new BigDecimal("-999.998999"),
- null,
- new BigDecimal("0"),
- new BigDecimal("-999.999"),
- null,
- new BigDecimal("999.999")
- ),
- Seq(
- null,
- null,
- null,
- null,
- null
- )
- )
-
- override def expectedResults: Seq[BigDecimal] = Seq(
- new BigDecimal("-999.999"),
- null
- )
-
- override def aggregator: Aggregate[BigDecimal] = new DecimalMinAggregate()
-
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/c31f95ca/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/SumAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/SumAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/SumAggregateTest.scala
deleted file mode 100644
index c085334..0000000
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/SumAggregateTest.scala
+++ /dev/null
@@ -1,137 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.runtime.aggregate
-
-import java.math.BigDecimal
-
-abstract class SumAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
-
- private val numeric: Numeric[T] = implicitly[Numeric[T]]
-
- def maxVal: T
- private val minVal = numeric.negate(maxVal)
-
- override def inputValueSets: Seq[Seq[T]] = Seq(
- Seq(
- minVal,
- numeric.fromInt(1),
- null.asInstanceOf[T],
- numeric.fromInt(2),
- numeric.fromInt(3),
- numeric.fromInt(4),
- numeric.fromInt(5),
- numeric.fromInt(-10),
- numeric.fromInt(-20),
- numeric.fromInt(17),
- null.asInstanceOf[T],
- maxVal
- ),
- Seq(
- null.asInstanceOf[T],
- null.asInstanceOf[T],
- null.asInstanceOf[T],
- null.asInstanceOf[T],
- null.asInstanceOf[T],
- null.asInstanceOf[T]
- )
- )
-
- override def expectedResults: Seq[T] = Seq(
- numeric.fromInt(2),
- null.asInstanceOf[T]
- )
-}
-
-class ByteSumAggregateTest extends SumAggregateTestBase[Byte] {
-
- override def maxVal = (Byte.MaxValue / 2).toByte
-
- override def aggregator: Aggregate[Byte] = new ByteSumAggregate
-}
-
-class ShortSumAggregateTest extends SumAggregateTestBase[Short] {
-
- override def maxVal = (Short.MaxValue / 2).toShort
-
- override def aggregator: Aggregate[Short] = new ShortSumAggregate
-}
-
-class IntSumAggregateTest extends SumAggregateTestBase[Int] {
-
- override def maxVal = Int.MaxValue / 2
-
- override def aggregator: Aggregate[Int] = new IntSumAggregate
-}
-
-class LongSumAggregateTest extends SumAggregateTestBase[Long] {
-
- override def maxVal = Long.MaxValue / 2
-
- override def aggregator: Aggregate[Long] = new LongSumAggregate
-}
-
-class FloatSumAggregateTest extends SumAggregateTestBase[Float] {
-
- override def maxVal = 12345.6789f
-
- override def aggregator: Aggregate[Float] = new FloatSumAggregate
-}
-
-class DoubleSumAggregateTest extends SumAggregateTestBase[Double] {
-
- override def maxVal = 12345.6789d
-
- override def aggregator: Aggregate[Double] = new DoubleSumAggregate
-}
-
-class DecimalSumAggregateTest extends AggregateTestBase[BigDecimal] {
-
- override def inputValueSets: Seq[Seq[_]] = Seq(
- Seq(
- new BigDecimal("1"),
- new BigDecimal("2"),
- new BigDecimal("3"),
- null,
- new BigDecimal("0"),
- new BigDecimal("-1000"),
- new BigDecimal("0.000000000002"),
- new BigDecimal("1000"),
- new BigDecimal("-0.000000000001"),
- new BigDecimal("999.999"),
- null,
- new BigDecimal("4"),
- new BigDecimal("-999.999"),
- null
- ),
- Seq(
- null,
- null,
- null,
- null,
- null
- )
- )
-
- override def expectedResults: Seq[BigDecimal] = Seq(
- new BigDecimal("10.000000000001"),
- null
- )
-
- override def aggregator: Aggregate[BigDecimal] = new DecimalSumAggregate()
-}
[2/3] flink git commit: [FLINK-5955] [table] Fix aggregations with ObjectReuse enabled by pairwise merging of accumulators.
Posted by fh...@apache.org.
[FLINK-5955] [table] Fix aggregations with ObjectReuse enabled by pairwise merging of accumulators.
This closes #3465.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/14fab4c4
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/14fab4c4
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/14fab4c4
Branch: refs/heads/master
Commit: 14fab4c412048f769209855d876221817e73ba25
Parents: 2d1721b
Author: shaoxuan-wang <ws...@gmail.com>
Authored: Fri Mar 3 13:50:29 2017 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Fri Mar 3 14:27:08 2017 +0100
----------------------------------------------------------------------
.../table/functions/AggregateFunction.scala | 7 ++-
.../AggregateReduceCombineFunction.scala | 51 ++++++++--------
.../AggregateReduceGroupFunction.scala | 52 ++++++++--------
...ionWindowAggregateCombineGroupFunction.scala | 58 +++++++++---------
...sionWindowAggregateReduceGroupFunction.scala | 62 +++++++++++---------
...umbleCountWindowAggReduceGroupFunction.scala | 47 ++++++++-------
...mbleTimeWindowAggReduceCombineFunction.scala | 40 ++++++-------
...TumbleTimeWindowAggReduceGroupFunction.scala | 49 +++++++++-------
8 files changed, 191 insertions(+), 175 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/14fab4c4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
index 178b439..e5666ce 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
@@ -58,8 +58,11 @@ abstract class AggregateFunction[T] extends UserDefinedFunction {
/**
* Merge a list of accumulator instances into one accumulator instance.
*
- * @param accumulators the [[java.util.List]] of accumulators
- * that will be merged
+ * IMPORTANT: You may only return a new accumulator instance or the first accumulator of the
+ * input list. If you return another instance, the result of the aggregation function might be
+ * incorrect.
+ *
+ * @param accumulators the [[java.util.List]] of accumulators that will be merged
* @return the resulting accumulator
*/
def merge(accumulators: JList[Accumulator]): Accumulator
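As a hedged illustration of this contract, here is a minimal count-style AggregateFunction whose merge() folds everything into the first accumulator of the list and returns it. It assumes only the members visible in this thread (createAccumulator, accumulate, getValue, merge); CountAcc and SketchCountAggFunction are illustrative names, not classes from the repository:

    import java.util.{List => JList}

    import org.apache.flink.table.functions.{Accumulator, AggregateFunction}

    // hypothetical accumulator that tracks a running count
    class CountAcc extends Accumulator {
      var count: Long = 0L
    }

    class SketchCountAggFunction extends AggregateFunction[Long] {

      override def createAccumulator(): Accumulator = new CountAcc

      override def getValue(accumulator: Accumulator): Long =
        accumulator.asInstanceOf[CountAcc].count

      override def accumulate(accumulator: Accumulator, value: Any): Unit = {
        if (value != null) {
          accumulator.asInstanceOf[CountAcc].count += 1
        }
      }

      // folds all counts into the first accumulator and returns it; returning
      // any other pre-existing instance could yield incorrect results, as the
      // IMPORTANT note above warns
      override def merge(accumulators: JList[Accumulator]): Accumulator = {
        val first = accumulators.get(0).asInstanceOf[CountAcc]
        var i = 1
        while (i < accumulators.size()) {
          first.count += accumulators.get(i).asInstanceOf[CountAcc].count
          i += 1
        }
        first
      }
    }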
http://git-wip-us.apache.org/repos/asf/flink/blob/14fab4c4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala
index 06ac8fb..6b95cb8 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala
@@ -19,9 +19,9 @@
package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
-import java.util.{ArrayList => JArrayList}
import org.apache.flink.api.common.functions.CombineFunction
+import org.apache.flink.configuration.Configuration
import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
import org.apache.flink.types.Row
@@ -53,6 +53,13 @@ class AggregateReduceCombineFunction(
groupingSetsMapping,
finalRowArity) with CombineFunction[Row, Row] {
+ var preAggOutput: Row = _
+
+ override def open(config: Configuration): Unit = {
+ super.open(config)
+ preAggOutput = new Row(aggregates.length + groupKeysMapping.length)
+ }
+
/**
* For sub-grouped intermediate aggregate Rows, merge all of them into aggregate buffer,
*
@@ -62,45 +69,41 @@ class AggregateReduceCombineFunction(
*/
override def combine(records: Iterable[Row]): Row = {
- // merge intermediate aggregate value to buffer.
var last: Row = null
- accumulatorList.foreach(_.clear())
-
val iterator = records.iterator()
- var count: Int = 0
+ // reset first accumulator in merge list
+ for (i <- aggregates.indices) {
+ val accumulator = aggregates(i).createAccumulator()
+ accumulatorList(i).set(0, accumulator)
+ }
+
while (iterator.hasNext) {
val record = iterator.next()
- count += 1
- // per each aggregator, collect its accumulators to a list
+
for (i <- aggregates.indices) {
- accumulatorList(i).add(record.getField(groupKeysMapping.length + i)
- .asInstanceOf[Accumulator])
- }
- // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one
- // accumulator
- if (count > maxMergeLen) {
- count = 0
- for (i <- aggregates.indices) {
- val agg = aggregates(i)
- val accumulator = agg.merge(accumulatorList(i))
- accumulatorList(i).clear()
- accumulatorList(i).add(accumulator)
- }
+ // insert received accumulator into acc list
+ val newAcc = record.getField(groupKeysMapping.length + i).asInstanceOf[Accumulator]
+ accumulatorList(i).set(1, newAcc)
+ // merge acc list
+ val retAcc = aggregates(i).merge(accumulatorList(i))
+ // insert result into acc list
+ accumulatorList(i).set(0, retAcc)
}
+
last = record
}
+ // set the partial merged result to the aggregateBuffer
for (i <- aggregates.indices) {
- val agg = aggregates(i)
- aggregateBuffer.setField(groupKeysMapping.length + i, agg.merge(accumulatorList(i)))
+ preAggOutput.setField(groupKeysMapping.length + i, accumulatorList(i).get(0))
}
// set group keys to aggregateBuffer.
for (i <- groupKeysMapping.indices) {
- aggregateBuffer.setField(i, last.getField(i))
+ preAggOutput.setField(i, last.getField(i))
}
- aggregateBuffer
+ preAggOutput
}
}
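The point of the rewrite above: with object reuse enabled, the Row returned by iterator.next() may be reused for the next record, so the old maxMergeLen batching, which buffered accumulator references across iterations, could end up with a list of aliases to one mutated object. The new code merges pairwise instead, through a fixed two-slot list whose slot 0 holds the running result and slot 1 the current record's accumulator. A self-contained sketch of the pattern, where SumAcc and the helper names are stand-ins rather than Flink types:

    import java.util.{ArrayList => JArrayList}

    object PairwiseMergeSketch {

      // stand-in for a Flink Accumulator
      final case class SumAcc(var sum: Long)

      def createAccumulator(): SumAcc = SumAcc(0L)

      // like AggregateFunction.merge: returns the first accumulator of the
      // list, after folding slot 1 into it
      def merge(accs: JArrayList[SumAcc]): SumAcc = {
        val first = accs.get(0)
        first.sum += accs.get(1).sum
        first
      }

      // fixed two-slot list, filled once (mirrors the rewritten open() methods)
      private val accumulatorList = new JArrayList[SumAcc](2)
      accumulatorList.add(createAccumulator())
      accumulatorList.add(createAccumulator())

      def combine(records: Iterator[SumAcc]): SumAcc = {
        // reset the running result in slot 0
        accumulatorList.set(0, createAccumulator())
        while (records.hasNext) {
          // each incoming accumulator is consumed immediately, so a runtime
          // that reuses the underlying object afterwards cannot corrupt the
          // running result held in slot 0
          accumulatorList.set(1, records.next())
          accumulatorList.set(0, merge(accumulatorList))
        }
        accumulatorList.get(0)
      }

      def main(args: Array[String]): Unit = {
        println(combine(Iterator(SumAcc(1), SumAcc(2), SumAcc(3))).sum) // prints 6
      }
    }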
http://git-wip-us.apache.org/repos/asf/flink/blob/14fab4c4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala
index 23b5236..2f75cd7 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala
@@ -48,22 +48,27 @@ class AggregateReduceGroupFunction(
private val finalRowArity: Int)
extends RichGroupReduceFunction[Row, Row] {
- protected var aggregateBuffer: Row = _
private var output: Row = _
private var intermediateGroupKeys: Option[Array[Int]] = None
- protected val maxMergeLen = 16
- val accumulatorList = Array.fill(aggregates.length) {
- new JArrayList[Accumulator]()
+
+ val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) {
+ new JArrayList[Accumulator](2)
}
override def open(config: Configuration) {
Preconditions.checkNotNull(aggregates)
Preconditions.checkNotNull(groupKeysMapping)
- aggregateBuffer = new Row(aggregates.length + groupKeysMapping.length)
output = new Row(finalRowArity)
if (!groupingSetsMapping.isEmpty) {
intermediateGroupKeys = Some(groupKeysMapping.map(_._1))
}
+
+ // init lists with two empty accumulators
+ for (i <- aggregates.indices) {
+ val accumulator = aggregates(i).createAccumulator()
+ accumulatorList(i).add(accumulator)
+ accumulatorList(i).add(accumulator)
+ }
}
/**
@@ -77,32 +82,28 @@ class AggregateReduceGroupFunction(
*/
override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = {
- // merge intermediate aggregate value to buffer.
var last: Row = null
- accumulatorList.foreach(_.clear())
-
val iterator = records.iterator()
- var count: Int = 0
+ // reset first accumulator in merge list
+ for (i <- aggregates.indices) {
+ val accumulator = aggregates(i).createAccumulator()
+ accumulatorList(i).set(0, accumulator)
+ }
+
while (iterator.hasNext) {
val record = iterator.next()
- count += 1
- // per each aggregator, collect its accumulators to a list
+
for (i <- aggregates.indices) {
- accumulatorList(i).add(record.getField(groupKeysMapping.length + i)
- .asInstanceOf[Accumulator])
- }
- // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one
- // accumulator
- if (count > maxMergeLen) {
- count = 0
- for (i <- aggregates.indices) {
- val agg = aggregates(i)
- val accumulator = agg.merge(accumulatorList(i))
- accumulatorList(i).clear()
- accumulatorList(i).add(accumulator)
- }
+ // insert received accumulator into acc list
+ val newAcc = record.getField(groupKeysMapping.length + i).asInstanceOf[Accumulator]
+ accumulatorList(i).set(1, newAcc)
+ // merge acc list
+ val retAcc = aggregates(i).merge(accumulatorList(i))
+ // insert result into acc list
+ accumulatorList(i).set(0, retAcc)
}
+
last = record
}
@@ -116,8 +117,7 @@ class AggregateReduceGroupFunction(
aggregateMapping.foreach {
case (after, previous) => {
val agg = aggregates(previous)
- val accumulator = agg.merge(accumulatorList(previous))
- val result = agg.getValue(accumulator)
+ val result = agg.getValue(accumulatorList(previous).get(0))
output.setField(after, result)
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/14fab4c4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateCombineGroupFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateCombineGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateCombineGroupFunction.scala
index 47fa0f1..88cd19f 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateCombineGroupFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateCombineGroupFunction.scala
@@ -45,17 +45,24 @@ class DataSetSessionWindowAggregateCombineGroupFunction(
extends RichGroupCombineFunction[Row, Row] with ResultTypeQueryable[Row] {
private var aggregateBuffer: Row = _
- private var accumStartPos: Int = groupingKeys.length
- private var rowTimeFieldPos = accumStartPos + aggregates.length
- private val maxMergeLen = 16
- val accumulatorList = Array.fill(aggregates.length) {
- new JArrayList[Accumulator]()
+ private val accumStartPos: Int = groupingKeys.length
+ private val rowTimeFieldPos = accumStartPos + aggregates.length
+
+ val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) {
+ new JArrayList[Accumulator](2)
}
override def open(config: Configuration) {
Preconditions.checkNotNull(aggregates)
Preconditions.checkNotNull(groupingKeys)
aggregateBuffer = new Row(rowTimeFieldPos + 2)
+
+ // init lists with two empty accumulators
+ for (i <- aggregates.indices) {
+ val accumulator = aggregates(i).createAccumulator()
+ accumulatorList(i).add(accumulator)
+ accumulatorList(i).add(accumulator)
+ }
}
/**
@@ -72,15 +79,17 @@ class DataSetSessionWindowAggregateCombineGroupFunction(
var windowStart: java.lang.Long = null
var windowEnd: java.lang.Long = null
var currentRowTime: java.lang.Long = null
- accumulatorList.foreach(_.clear())
- val iterator = records.iterator()
+ // reset first accumulator in merge list
+ for (i <- aggregates.indices) {
+ val accumulator = aggregates(i).createAccumulator()
+ accumulatorList(i).set(0, accumulator)
+ }
+ val iterator = records.iterator()
- var count: Int = 0
while (iterator.hasNext) {
val record = iterator.next()
- count += 1
currentRowTime = record.getField(rowTimeFieldPos).asInstanceOf[Long]
// initial traversal or opening a new window
if (windowEnd == null || (windowEnd != null && (currentRowTime > windowEnd))) {
@@ -90,9 +99,11 @@ class DataSetSessionWindowAggregateCombineGroupFunction(
// emit the current window's merged data
doCollect(out, accumulatorList, windowStart, windowEnd)
- // clear the accumulator list for all aggregate
- accumulatorList.foreach(_.clear())
- count = 0
+ // reset first value of accumulator list
+ for (i <- aggregates.indices) {
+ val accumulator = aggregates(i).createAccumulator()
+ accumulatorList(i).set(0, accumulator)
+ }
} else {
// set group keys to aggregateBuffer.
for (i <- groupingKeys.indices) {
@@ -103,21 +114,14 @@ class DataSetSessionWindowAggregateCombineGroupFunction(
windowStart = record.getField(rowTimeFieldPos).asInstanceOf[Long]
}
- // collect the accumulators for each aggregate
for (i <- aggregates.indices) {
- accumulatorList(i).add(record.getField(accumStartPos + i).asInstanceOf[Accumulator])
- }
-
- // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one
- // accumulator
- if (count > maxMergeLen) {
- count = 0
- for (i <- aggregates.indices) {
- val agg = aggregates(i)
- val accumulator = agg.merge(accumulatorList(i))
- accumulatorList(i).clear()
- accumulatorList(i).add(accumulator)
- }
+ // insert received accumulator into acc list
+ val newAcc = record.getField(accumStartPos + i).asInstanceOf[Accumulator]
+ accumulatorList(i).set(1, newAcc)
+ // merge acc list
+ val retAcc = aggregates(i).merge(accumulatorList(i))
+ // insert result into acc list
+ accumulatorList(i).set(0, retAcc)
}
// the current rowtime is the last rowtime of the next calculation.
@@ -146,7 +150,7 @@ class DataSetSessionWindowAggregateCombineGroupFunction(
// merge the accumulators into one accumulator
for (i <- aggregates.indices) {
- aggregateBuffer.setField(accumStartPos + i, aggregates(i).merge(accumulatorList(i)))
+ aggregateBuffer.setField(accumStartPos + i, accumulatorList(i).get(0))
}
// intermediate Row WindowStartPos is rowtime pos.
http://git-wip-us.apache.org/repos/asf/flink/blob/14fab4c4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala
index 1570671..ebef211 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala
@@ -64,13 +64,13 @@ class DataSetSessionWindowAggregateReduceGroupFunction(
private var aggregateBuffer: Row = _
private var output: Row = _
private var collector: TimeWindowPropertyCollector = _
- private var accumStartPos: Int = groupKeysMapping.length
- private var intermediateRowArity: Int = accumStartPos + aggregates.length + 2
- private var intermediateRowWindowStartPos = intermediateRowArity - 2
- private var intermediateRowWindowEndPos = intermediateRowArity - 1
- private val maxMergeLen = 16
- val accumulatorList = Array.fill(aggregates.length) {
- new JArrayList[Accumulator]()
+ private val accumStartPos: Int = groupKeysMapping.length
+ private val intermediateRowArity: Int = accumStartPos + aggregates.length + 2
+ private val intermediateRowWindowStartPos = intermediateRowArity - 2
+ private val intermediateRowWindowEndPos = intermediateRowArity - 1
+
+ val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) {
+ new JArrayList[Accumulator](2)
}
override def open(config: Configuration) {
@@ -79,6 +79,13 @@ class DataSetSessionWindowAggregateReduceGroupFunction(
aggregateBuffer = new Row(intermediateRowArity)
output = new Row(finalRowArity)
collector = new TimeWindowPropertyCollector(finalRowWindowStartPos, finalRowWindowEndPos)
+
+ // init lists with two empty accumulators
+ for (i <- aggregates.indices) {
+ val accumulator = aggregates(i).createAccumulator()
+ accumulatorList(i).add(accumulator)
+ accumulatorList(i).add(accumulator)
+ }
}
/**
@@ -96,14 +103,17 @@ class DataSetSessionWindowAggregateReduceGroupFunction(
var windowStart: java.lang.Long = null
var windowEnd: java.lang.Long = null
var currentRowTime: java.lang.Long = null
- accumulatorList.foreach(_.clear())
+
+ // reset first accumulator in merge list
+ for (i <- aggregates.indices) {
+ val accumulator = aggregates(i).createAccumulator()
+ accumulatorList(i).set(0, accumulator)
+ }
val iterator = records.iterator()
- var count: Int = 0
while (iterator.hasNext) {
val record = iterator.next()
- count += 1
currentRowTime = record.getField(intermediateRowWindowStartPos).asInstanceOf[Long]
// initial traversal or opening a new window
if (null == windowEnd ||
@@ -114,9 +124,11 @@ class DataSetSessionWindowAggregateReduceGroupFunction(
// evaluate and emit the current window's result.
doEvaluateAndCollect(out, accumulatorList, windowStart, windowEnd)
- // clear the accumulator list for all aggregate
- accumulatorList.foreach(_.clear())
- count = 0
+ // reset first accumulator in list
+ for (i <- aggregates.indices) {
+ val accumulator = aggregates(i).createAccumulator()
+ accumulatorList(i).set(0, accumulator)
+ }
} else {
// set group keys value to final output.
groupKeysMapping.foreach {
@@ -128,21 +140,14 @@ class DataSetSessionWindowAggregateReduceGroupFunction(
windowStart = record.getField(intermediateRowWindowStartPos).asInstanceOf[Long]
}
- // collect the accumulators for each aggregate
for (i <- aggregates.indices) {
- accumulatorList(i).add(record.getField(accumStartPos + i).asInstanceOf[Accumulator])
- }
-
- // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one
- // accumulator
- if (count > maxMergeLen) {
- count = 0
- for (i <- aggregates.indices) {
- val agg = aggregates(i)
- val accumulator = agg.merge(accumulatorList(i))
- accumulatorList(i).clear()
- accumulatorList(i).add(accumulator)
- }
+ // insert received accumulator into acc list
+ val newAcc = record.getField(accumStartPos + i).asInstanceOf[Accumulator]
+ accumulatorList(i).set(1, newAcc)
+ // merge acc list
+ val retAcc = aggregates(i).merge(accumulatorList(i))
+ // insert result into acc list
+ accumulatorList(i).set(0, retAcc)
}
windowEnd = if (isInputCombined) {
@@ -178,8 +183,7 @@ class DataSetSessionWindowAggregateReduceGroupFunction(
aggregateMapping.foreach {
case (after, previous) =>
val agg = aggregates(previous)
- val accum = agg.merge(accumulatorList(previous))
- output.setField(after, agg.getValue(accum))
+ output.setField(after, agg.getValue(accumulatorList(previous).get(0)))
}
// adds TimeWindow properties to output then emit output
http://git-wip-us.apache.org/repos/asf/flink/blob/14fab4c4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala
index b722330..85df1d8 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala
@@ -51,9 +51,9 @@ class DataSetTumbleCountWindowAggReduceGroupFunction(
private var output: Row = _
private val accumStartPos: Int = groupKeysMapping.length
private val intermediateRowArity: Int = accumStartPos + aggregates.length + 1
- private val maxMergeLen = 16
- val accumulatorList = Array.fill(aggregates.length) {
- new JArrayList[Accumulator]()
+
+ val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) {
+ new JArrayList[Accumulator](2)
}
override def open(config: Configuration) {
@@ -61,37 +61,41 @@ class DataSetTumbleCountWindowAggReduceGroupFunction(
Preconditions.checkNotNull(groupKeysMapping)
aggregateBuffer = new Row(intermediateRowArity)
output = new Row(finalRowArity)
+
+ // init lists with two empty accumulators
+ for (i <- aggregates.indices) {
+ val accumulator = aggregates(i).createAccumulator()
+ accumulatorList(i).add(accumulator)
+ accumulatorList(i).add(accumulator)
+ }
}
override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = {
var count: Long = 0
- accumulatorList.foreach(_.clear())
-
val iterator = records.iterator()
while (iterator.hasNext) {
- val record = iterator.next()
if (count == 0) {
- // clear the accumulator list for all aggregate
- accumulatorList.foreach(_.clear())
+ // reset first accumulator
+ for (i <- aggregates.indices) {
+ val accumulator = aggregates(i).createAccumulator()
+ accumulatorList(i).set(0, accumulator)
+ }
}
- // collect the accumulators for each aggregate
- for (i <- aggregates.indices) {
- accumulatorList(i).add(record.getField(accumStartPos + i).asInstanceOf[Accumulator])
- }
+ val record = iterator.next()
count += 1
- // for every maxMergeLen accumulators, we merge them into one
- if (count % maxMergeLen == 0) {
- for (i <- aggregates.indices) {
- val agg = aggregates(i)
- val accumulator = agg.merge(accumulatorList(i))
- accumulatorList(i).clear()
- accumulatorList(i).add(accumulator)
- }
+ for (i <- aggregates.indices) {
+ // insert received accumulator into acc list
+ val newAcc = record.getField(accumStartPos + i).asInstanceOf[Accumulator]
+ accumulatorList(i).set(1, newAcc)
+ // merge acc list
+ val retAcc = aggregates(i).merge(accumulatorList(i))
+ // insert result into acc list
+ accumulatorList(i).set(0, retAcc)
}
if (windowSize == count) {
@@ -105,8 +109,7 @@ class DataSetTumbleCountWindowAggReduceGroupFunction(
aggregateMapping.foreach {
case (after, previous) =>
val agg = aggregates(previous)
- val accumulator = agg.merge(accumulatorList(previous))
- output.setField(after, agg.getValue(accumulator))
+ output.setField(after, agg.getValue(accumulatorList(previous).get(0)))
}
// emit the output
http://git-wip-us.apache.org/repos/asf/flink/blob/14fab4c4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala
index d507a58..df8bed9 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala
@@ -18,7 +18,6 @@
package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
-import java.util.{ArrayList => JArrayList}
import org.apache.flink.api.common.functions.CombineFunction
import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
@@ -68,38 +67,33 @@ class DataSetTumbleTimeWindowAggReduceCombineFunction(
override def combine(records: Iterable[Row]): Row = {
var last: Row = null
- accumulatorList.foreach(_.clear())
-
val iterator = records.iterator()
- var count: Int = 0
+ // reset first accumulator in merge list
+ for (i <- aggregates.indices) {
+ val accumulator = aggregates(i).createAccumulator()
+ accumulatorList(i).set(0, accumulator)
+ }
+
while (iterator.hasNext) {
val record = iterator.next()
- count += 1
- // per each aggregator, collect its accumulators to a list
+
for (i <- aggregates.indices) {
- accumulatorList(i).add(record.getField(groupKeysMapping.length + i)
- .asInstanceOf[Accumulator])
- }
- // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one
- // accumulator
- if (count > maxMergeLen) {
- count = 0
- for (i <- aggregates.indices) {
- val agg = aggregates(i)
- val accumulator = agg.merge(accumulatorList(i))
- accumulatorList(i).clear()
- accumulatorList(i).add(accumulator)
- }
+ // insert received accumulator into acc list
+ val newAcc = record.getField(groupKeysMapping.length + i).asInstanceOf[Accumulator]
+ accumulatorList(i).set(1, newAcc)
+ // merge acc list
+ val retAcc = aggregates(i).merge(accumulatorList(i))
+ // insert result into acc list
+ accumulatorList(i).set(0, retAcc)
}
+
last = record
}
- // per each aggregator, merge list of accumulators into one and save the result to the
- // intermediate aggregate buffer
+ // set the partial merged result to the aggregateBuffer
for (i <- aggregates.indices) {
- val agg = aggregates(i)
- aggregateBuffer.setField(groupKeysMapping.length + i, agg.merge(accumulatorList(i)))
+ aggregateBuffer.setField(groupKeysMapping.length + i, accumulatorList(i).get(0))
}
// set group keys to aggregateBuffer.
http://git-wip-us.apache.org/repos/asf/flink/blob/14fab4c4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala
index 63d2aeb..7ce0bf1 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala
@@ -57,9 +57,10 @@ class DataSetTumbleTimeWindowAggReduceGroupFunction(
private val accumStartPos: Int = groupKeysMapping.length
private val rowtimePos: Int = accumStartPos + aggregates.length
private val intermediateRowArity: Int = rowtimePos + 1
- protected val maxMergeLen = 16
- val accumulatorList = Array.fill(aggregates.length) {
- new JArrayList[Accumulator]()
+
+
+ val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) {
+ new JArrayList[Accumulator](2)
}
override def open(config: Configuration) {
@@ -68,34 +69,39 @@ class DataSetTumbleTimeWindowAggReduceGroupFunction(
aggregateBuffer = new Row(intermediateRowArity)
output = new Row(finalRowArity)
collector = new TimeWindowPropertyCollector(windowStartPos, windowEndPos)
+
+ // init lists with two empty accumulators
+ for (i <- aggregates.indices) {
+ val accumulator = aggregates(i).createAccumulator()
+ accumulatorList(i).add(accumulator)
+ accumulatorList(i).add(accumulator)
+ }
}
override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = {
var last: Row = null
- accumulatorList.foreach(_.clear())
-
val iterator = records.iterator()
- var count: Int = 0
+ // reset first accumulator in merge list
+ for (i <- aggregates.indices) {
+ val accumulator = aggregates(i).createAccumulator()
+ accumulatorList(i).set(0, accumulator)
+ }
+
while (iterator.hasNext) {
val record = iterator.next()
- count += 1
- // per each aggregator, collect its accumulators to a list
+
for (i <- aggregates.indices) {
- accumulatorList(i).add(record.getField(accumStartPos + i).asInstanceOf[Accumulator])
- }
- // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one
- // accumulator
- if (count > maxMergeLen) {
- count = 0
- for (i <- aggregates.indices) {
- val agg = aggregates(i)
- val accumulator = agg.merge(accumulatorList(i))
- accumulatorList(i).clear()
- accumulatorList(i).add(accumulator)
- }
+ // insert received accumulator into acc list
+ val newAcc = record.getField(groupKeysMapping.length + i).asInstanceOf[Accumulator]
+ accumulatorList(i).set(1, newAcc)
+ // merge acc list
+ val retAcc = aggregates(i).merge(accumulatorList(i))
+ // insert result into acc list
+ accumulatorList(i).set(0, retAcc)
}
+
last = record
}
@@ -109,8 +115,7 @@ class DataSetTumbleTimeWindowAggReduceGroupFunction(
aggregateMapping.foreach {
case (after, previous) => {
val agg = aggregates(previous)
- val accumulator = agg.merge(accumulatorList(previous))
- val result = agg.getValue(accumulator)
+ val result = agg.getValue(accumulatorList(previous).get(0))
output.setField(after, result)
}
}
[3/3] flink git commit: [hotfix] [table] Fix initialization of accumulators for MIN and MAX aggregates.
Posted by fh...@apache.org.
[hotfix] [table] Fix initialization of accumulators for MIN and MAX aggregates.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/2d1721bb
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/2d1721bb
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/2d1721bb
Branch: refs/heads/master
Commit: 2d1721bb9b17333c3c06e3675a24d344aed3c87f
Parents: 050f9a4
Author: Fabian Hueske <fh...@apache.org>
Authored: Thu Mar 2 22:57:47 2017 +0100
Committer: Fabian Hueske <fh...@apache.org>
Committed: Fri Mar 3 14:27:08 2017 +0100
----------------------------------------------------------------------
.../functions/aggfunctions/MaxAggFunction.scala | 21 +++++++++++++++-----
.../functions/aggfunctions/MinAggFunction.scala | 21 +++++++++++++++-----
2 files changed, 32 insertions(+), 10 deletions(-)
----------------------------------------------------------------------
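A plausible reading of this diff: the removed initializer `f0 = 0.asInstanceOf[T]` runs under JVM type erasure, so the unchecked cast stores a boxed java.lang.Integer 0 for every T; reading the field back as, say, Long or BigDecimal then fails at run time. A self-contained sketch of that pitfall, with Holder and InitPitfallSketch as illustrative names:

    // Holder stands in for the old MaxAccumulator/MinAccumulator pattern
    class Holder[T] {
      var value: T = 0.asInstanceOf[T] // unchecked cast: stores a boxed Integer 0 for any T
    }

    object InitPitfallSketch {
      def main(args: Array[String]): Unit = {
        val h = new Holder[Long]
        try {
          val v: Long = h.value // unboxing the stored Integer as Long throws
          println(v)
        } catch {
          case e: ClassCastException => println(s"broken init: $e")
        }
      }
    }

The hotfix sidesteps this by moving initialization into createAccumulator(), where each concrete subclass supplies a properly typed value via getInitValue.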
http://git-wip-us.apache.org/repos/asf/flink/blob/2d1721bb/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
index 62ff88c..33cfd65 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
@@ -26,10 +26,7 @@ import org.apache.flink.api.java.typeutils.TupleTypeInfo
import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
/** The initial accumulator for Max aggregate function */
-class MaxAccumulator[T] extends JTuple2[T, Boolean] with Accumulator {
- f0 = 0.asInstanceOf[T] //max
- f1 = false
-}
+class MaxAccumulator[T] extends JTuple2[T, Boolean] with Accumulator
/**
* Base class for built-in Max aggregate function
@@ -39,7 +36,10 @@ class MaxAccumulator[T] extends JTuple2[T, Boolean] with Accumulator {
abstract class MaxAggFunction[T](implicit ord: Ordering[T]) extends AggregateFunction[T] {
override def createAccumulator(): Accumulator = {
- new MaxAccumulator[T]
+ val acc = new MaxAccumulator[T]
+ acc.f0 = getInitValue
+ acc.f1 = false
+ acc
}
override def accumulate(accumulator: Accumulator, value: Any): Unit = {
@@ -82,6 +82,8 @@ abstract class MaxAggFunction[T](implicit ord: Ordering[T]) extends AggregateFun
BasicTypeInfo.BOOLEAN_TYPE_INFO)
}
+ def getInitValue: T
+
def getValueTypeInfo: TypeInformation[_]
}
@@ -89,6 +91,7 @@ abstract class MaxAggFunction[T](implicit ord: Ordering[T]) extends AggregateFun
* Built-in Byte Max aggregate function
*/
class ByteMaxAggFunction extends MaxAggFunction[Byte] {
+ override def getInitValue: Byte = 0.toByte
override def getValueTypeInfo = BasicTypeInfo.BYTE_TYPE_INFO
}
@@ -96,6 +99,7 @@ class ByteMaxAggFunction extends MaxAggFunction[Byte] {
* Built-in Short Max aggregate function
*/
class ShortMaxAggFunction extends MaxAggFunction[Short] {
+ override def getInitValue: Short = 0.toShort
override def getValueTypeInfo = BasicTypeInfo.SHORT_TYPE_INFO
}
@@ -103,6 +107,7 @@ class ShortMaxAggFunction extends MaxAggFunction[Short] {
* Built-in Int Max aggregate function
*/
class IntMaxAggFunction extends MaxAggFunction[Int] {
+ override def getInitValue: Int = 0
override def getValueTypeInfo = BasicTypeInfo.INT_TYPE_INFO
}
@@ -110,6 +115,7 @@ class IntMaxAggFunction extends MaxAggFunction[Int] {
* Built-in Long Max aggregate function
*/
class LongMaxAggFunction extends MaxAggFunction[Long] {
+ override def getInitValue: Long = 0L
override def getValueTypeInfo = BasicTypeInfo.LONG_TYPE_INFO
}
@@ -117,6 +123,7 @@ class LongMaxAggFunction extends MaxAggFunction[Long] {
* Built-in Float Max aggregate function
*/
class FloatMaxAggFunction extends MaxAggFunction[Float] {
+ override def getInitValue: Float = 0.0f
override def getValueTypeInfo = BasicTypeInfo.FLOAT_TYPE_INFO
}
@@ -124,6 +131,7 @@ class FloatMaxAggFunction extends MaxAggFunction[Float] {
* Built-in Double Max aggregate function
*/
class DoubleMaxAggFunction extends MaxAggFunction[Double] {
+ override def getInitValue: Double = 0.0d
override def getValueTypeInfo = BasicTypeInfo.DOUBLE_TYPE_INFO
}
@@ -131,6 +139,7 @@ class DoubleMaxAggFunction extends MaxAggFunction[Double] {
* Built-in Boolean Max aggregate function
*/
class BooleanMaxAggFunction extends MaxAggFunction[Boolean] {
+ override def getInitValue = false
override def getValueTypeInfo = BasicTypeInfo.BOOLEAN_TYPE_INFO
}
@@ -150,5 +159,7 @@ class DecimalMaxAggFunction extends MaxAggFunction[BigDecimal] {
}
}
+ override def getInitValue = BigDecimal.ZERO
+
override def getValueTypeInfo = BasicTypeInfo.BIG_DEC_TYPE_INFO
}
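After this change, every concrete Min/Max function provides a typed getInitValue alongside getValueTypeInfo. For illustration only, a hypothetical Char variant following the same pattern (not part of this commit) would look like:

    package org.apache.flink.table.functions.aggfunctions

    import org.apache.flink.api.common.typeinfo.BasicTypeInfo

    /**
     * Hypothetical Char Max aggregate function, shown only to illustrate the
     * two members a concrete subclass now provides.
     */
    class CharMaxAggFunction extends MaxAggFunction[Char] {
      override def getInitValue: Char = 0.toChar
      override def getValueTypeInfo = BasicTypeInfo.CHAR_TYPE_INFO
    }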
http://git-wip-us.apache.org/repos/asf/flink/blob/2d1721bb/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
index cddb873..1b2d6b0 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
@@ -26,10 +26,7 @@ import org.apache.flink.api.java.typeutils.TupleTypeInfo
import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
/** The initial accumulator for Min aggregate function */
-class MinAccumulator[T] extends JTuple2[T, Boolean] with Accumulator {
- f0 = 0.asInstanceOf[T] //min
- f1 = false
-}
+class MinAccumulator[T] extends JTuple2[T, Boolean] with Accumulator
/**
* Base class for built-in Min aggregate function
@@ -39,7 +36,10 @@ class MinAccumulator[T] extends JTuple2[T, Boolean] with Accumulator {
abstract class MinAggFunction[T](implicit ord: Ordering[T]) extends AggregateFunction[T] {
override def createAccumulator(): Accumulator = {
- new MinAccumulator[T]
+ val acc = new MinAccumulator[T]
+ acc.f0 = getInitValue
+ acc.f1 = false
+ acc
}
override def accumulate(accumulator: Accumulator, value: Any): Unit = {
@@ -82,6 +82,8 @@ abstract class MinAggFunction[T](implicit ord: Ordering[T]) extends AggregateFun
BasicTypeInfo.BOOLEAN_TYPE_INFO)
}
+ def getInitValue: T
+
def getValueTypeInfo: TypeInformation[_]
}
@@ -89,6 +91,7 @@ abstract class MinAggFunction[T](implicit ord: Ordering[T]) extends AggregateFun
* Built-in Byte Min aggregate function
*/
class ByteMinAggFunction extends MinAggFunction[Byte] {
+ override def getInitValue: Byte = 0.toByte
override def getValueTypeInfo = BasicTypeInfo.BYTE_TYPE_INFO
}
@@ -96,6 +99,7 @@ class ByteMinAggFunction extends MinAggFunction[Byte] {
* Built-in Short Min aggregate function
*/
class ShortMinAggFunction extends MinAggFunction[Short] {
+ override def getInitValue: Short = 0.toShort
override def getValueTypeInfo = BasicTypeInfo.SHORT_TYPE_INFO
}
@@ -103,6 +107,7 @@ class ShortMinAggFunction extends MinAggFunction[Short] {
* Built-in Int Min aggregate function
*/
class IntMinAggFunction extends MinAggFunction[Int] {
+ override def getInitValue: Int = 0
override def getValueTypeInfo = BasicTypeInfo.INT_TYPE_INFO
}
@@ -110,6 +115,7 @@ class IntMinAggFunction extends MinAggFunction[Int] {
* Built-in Long Min aggregate function
*/
class LongMinAggFunction extends MinAggFunction[Long] {
+ override def getInitValue: Long = 0L
override def getValueTypeInfo = BasicTypeInfo.LONG_TYPE_INFO
}
@@ -117,6 +123,7 @@ class LongMinAggFunction extends MinAggFunction[Long] {
* Built-in Float Min aggregate function
*/
class FloatMinAggFunction extends MinAggFunction[Float] {
+ override def getInitValue: Float = 0.0f
override def getValueTypeInfo = BasicTypeInfo.FLOAT_TYPE_INFO
}
@@ -124,6 +131,7 @@ class FloatMinAggFunction extends MinAggFunction[Float] {
* Built-in Double Min aggregate function
*/
class DoubleMinAggFunction extends MinAggFunction[Double] {
+ override def getInitValue: Double = 0.0d
override def getValueTypeInfo = BasicTypeInfo.DOUBLE_TYPE_INFO
}
@@ -131,6 +139,7 @@ class DoubleMinAggFunction extends MinAggFunction[Double] {
* Built-in Boolean Min aggregate function
*/
class BooleanMinAggFunction extends MinAggFunction[Boolean] {
+ override def getInitValue: Boolean = false
override def getValueTypeInfo = BasicTypeInfo.BOOLEAN_TYPE_INFO
}
@@ -150,5 +159,7 @@ class DecimalMinAggFunction extends MinAggFunction[BigDecimal] {
}
}
+ override def getInitValue: BigDecimal = BigDecimal.ZERO
+
override def getValueTypeInfo = BasicTypeInfo.BIG_DEC_TYPE_INFO
}