You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@flink.apache.org by mb...@apache.org on 2014/11/09 20:08:46 UTC
[2/4] incubator-flink git commit: [streaming] Aggregation rework to support field expression based aggregations for Pojo data streams
[streaming] Aggregation rework to support field expression based aggregations for Pojo data streams
Project: http://git-wip-us.apache.org/repos/asf/incubator-flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-flink/commit/7ae58042
Tree: http://git-wip-us.apache.org/repos/asf/incubator-flink/tree/7ae58042
Diff: http://git-wip-us.apache.org/repos/asf/incubator-flink/diff/7ae58042
Branch: refs/heads/master
Commit: 7ae58042787bf307304d59182ba1f0825a480917
Parents: c9f3846
Author: Gyula Fora <gy...@apache.org>
Authored: Wed Nov 5 16:23:22 2014 +0100
Committer: mbalassi <ba...@gmail.com>
Committed: Sun Nov 9 13:16:46 2014 +0100
----------------------------------------------------------------------
.../api/datastream/BatchedDataStream.java | 118 +++++++--
.../streaming/api/datastream/DataStream.java | 118 ++++++++-
.../aggregation/AggregationFunction.java | 17 +-
.../ComparableAggregationFunction.java | 83 -------
.../aggregation/ComparableAggregator.java | 243 +++++++++++++++++++
.../api/function/aggregation/Comparator.java | 104 ++++++++
.../aggregation/MaxAggregationFunction.java | 34 ---
.../aggregation/MaxByAggregationFunction.java | 39 ---
.../aggregation/MinAggregationFunction.java | 34 ---
.../aggregation/MinByAggregationFunction.java | 70 ------
.../aggregation/SumAggregationFunction.java | 159 ------------
.../api/function/aggregation/SumAggregator.java | 173 +++++++++++++
.../api/function/aggregation/SumFunction.java | 102 ++++++++
.../streaming/api/AggregationFunctionTest.java | 80 +++---
14 files changed, 877 insertions(+), 497 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7ae58042/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/BatchedDataStream.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/BatchedDataStream.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/BatchedDataStream.java
index 86cb90b..c8a49c6 100755
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/BatchedDataStream.java
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/BatchedDataStream.java
@@ -24,11 +24,9 @@ import org.apache.flink.api.common.functions.RichReduceFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.streaming.api.function.aggregation.AggregationFunction;
-import org.apache.flink.streaming.api.function.aggregation.MaxAggregationFunction;
-import org.apache.flink.streaming.api.function.aggregation.MaxByAggregationFunction;
-import org.apache.flink.streaming.api.function.aggregation.MinAggregationFunction;
-import org.apache.flink.streaming.api.function.aggregation.MinByAggregationFunction;
-import org.apache.flink.streaming.api.function.aggregation.SumAggregationFunction;
+import org.apache.flink.streaming.api.function.aggregation.AggregationFunction.AggregationType;
+import org.apache.flink.streaming.api.function.aggregation.ComparableAggregator;
+import org.apache.flink.streaming.api.function.aggregation.SumAggregator;
import org.apache.flink.streaming.api.invokable.StreamInvokable;
import org.apache.flink.streaming.api.invokable.operator.BatchGroupReduceInvokable;
import org.apache.flink.streaming.api.invokable.operator.BatchReduceInvokable;
@@ -151,11 +149,10 @@ public class BatchedDataStream<OUT> {
* The position in the data point to sum
* @return The transformed DataStream.
*/
- @SuppressWarnings("unchecked")
public SingleOutputStreamOperator<OUT, ?> sum(int positionToSum) {
dataStream.checkFieldRange(positionToSum);
- return aggregate((AggregationFunction<OUT>) SumAggregationFunction.getSumFunction(
- positionToSum, dataStream.getClassAtPos(positionToSum), dataStream.getOutputType()));
+ return aggregate((AggregationFunction<OUT>) SumAggregator.getSumFunction(positionToSum,
+ dataStream.getClassAtPos(positionToSum), dataStream.getOutputType()));
}
/**
@@ -168,6 +165,23 @@ public class BatchedDataStream<OUT> {
}
/**
+ * Applies an aggregation that that gives the sum of the pojo data stream at
+ * the given field expression. A field expression is either the name of a
+ * public field or a getter method with parentheses of the
+ * {@link DataStream}S underlying type. A dot can be used to drill down into
+ * objects, as in {@code "field1.getInnerField2()" }.
+ *
+ * @param field
+ * The field expression based on which the aggregation will be
+ * applied.
+ * @return The transformed DataStream.
+ */
+ public SingleOutputStreamOperator<OUT, ?> sum(String field) {
+ return aggregate((AggregationFunction<OUT>) SumAggregator.getSumFunction(field,
+ getOutputType()));
+ }
+
+ /**
* Applies an aggregation that that gives the minimum of every sliding
* batch/window of the data stream at the given position.
*
@@ -177,7 +191,8 @@ public class BatchedDataStream<OUT> {
*/
public SingleOutputStreamOperator<OUT, ?> min(int positionToMin) {
dataStream.checkFieldRange(positionToMin);
- return aggregate(new MinAggregationFunction<OUT>(positionToMin, dataStream.getOutputType()));
+ return aggregate(ComparableAggregator.getAggregator(positionToMin, getOutputType(),
+ AggregationType.MIN));
}
/**
@@ -209,8 +224,8 @@ public class BatchedDataStream<OUT> {
*/
public SingleOutputStreamOperator<OUT, ?> minBy(int positionToMinBy, boolean first) {
dataStream.checkFieldRange(positionToMinBy);
- return aggregate(new MinByAggregationFunction<OUT>(positionToMinBy, first,
- dataStream.getOutputType()));
+ return aggregate(ComparableAggregator.getAggregator(positionToMinBy, getOutputType(),
+ AggregationType.MINBY, first));
}
/**
@@ -232,7 +247,8 @@ public class BatchedDataStream<OUT> {
*/
public SingleOutputStreamOperator<OUT, ?> max(int positionToMax) {
dataStream.checkFieldRange(positionToMax);
- return aggregate(new MaxAggregationFunction<OUT>(positionToMax, dataStream.getOutputType()));
+ return aggregate(ComparableAggregator.getAggregator(positionToMax, getOutputType(),
+ AggregationType.MAX));
}
/**
@@ -263,8 +279,8 @@ public class BatchedDataStream<OUT> {
*/
public SingleOutputStreamOperator<OUT, ?> maxBy(int positionToMaxBy, boolean first) {
dataStream.checkFieldRange(positionToMaxBy);
- return aggregate(new MaxByAggregationFunction<OUT>(positionToMaxBy, first,
- dataStream.getOutputType()));
+ return aggregate(ComparableAggregator.getAggregator(positionToMaxBy, getOutputType(),
+ AggregationType.MAXBY, first));
}
/**
@@ -277,6 +293,80 @@ public class BatchedDataStream<OUT> {
}
/**
+ * Applies an aggregation that that gives the minimum of the pojo data
+ * stream at the given field expression. A field expression is either the
+ * name of a public field or a getter method with parentheses of the
+ * {@link DataStream}S underlying type. A dot can be used to drill down into
+ * objects, as in {@code "field1.getInnerField2()" }.
+ *
+ * @param field
+ * The field expression based on which the aggregation will be
+ * applied.
+ * @return The transformed DataStream.
+ */
+ public SingleOutputStreamOperator<OUT, ?> min(String field) {
+ return aggregate(ComparableAggregator.getAggregator(field, getOutputType(),
+ AggregationType.MIN, false));
+ }
+
+ /**
+ * Applies an aggregation that that gives the maximum of the pojo data
+ * stream at the given field expression. A field expression is either the
+ * name of a public field or a getter method with parentheses of the
+ * {@link DataStream}S underlying type. A dot can be used to drill down into
+ * objects, as in {@code "field1.getInnerField2()" }.
+ *
+ * @param field
+ * The field expression based on which the aggregation will be
+ * applied.
+ * @return The transformed DataStream.
+ */
+ public SingleOutputStreamOperator<OUT, ?> max(String field) {
+ return aggregate(ComparableAggregator.getAggregator(field, getOutputType(),
+ AggregationType.MAX, false));
+ }
+
+ /**
+ * Applies an aggregation that that gives the minimum element of the pojo
+ * data stream by the given field expression. A field expression is either
+ * the name of a public field or a getter method with parentheses of the
+ * {@link DataStream}S underlying type. A dot can be used to drill down into
+ * objects, as in {@code "field1.getInnerField2()" }.
+ *
+ * @param field
+ * The field expression based on which the aggregation will be
+ * applied.
+ * @param first
+ * If True then in case of field equality the first object will
+ * be returned
+ * @return The transformed DataStream.
+ */
+ public SingleOutputStreamOperator<OUT, ?> minBy(String field, boolean first) {
+ return aggregate(ComparableAggregator.getAggregator(field, getOutputType(),
+ AggregationType.MINBY, first));
+ }
+
+ /**
+ * Applies an aggregation that that gives the maximum element of the pojo
+ * data stream by the given field expression. A field expression is either
+ * the name of a public field or a getter method with parentheses of the
+ * {@link DataStream}S underlying type. A dot can be used to drill down into
+ * objects, as in {@code "field1.getInnerField2()" }.
+ *
+ * @param field
+ * The field expression based on which the aggregation will be
+ * applied.
+ * @param first
+ * If True then in case of field equality the first object will
+ * be returned
+ * @return The transformed DataStream.
+ */
+ public SingleOutputStreamOperator<OUT, ?> maxBy(String field, boolean first) {
+ return aggregate(ComparableAggregator.getAggregator(field, getOutputType(),
+ AggregationType.MAXBY, first));
+ }
+
+ /**
* Gets the output type.
*
* @return The output type.
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7ae58042/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
index b50c42d..991b6d7 100644
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
@@ -42,11 +42,9 @@ import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.streaming.api.JobGraphBuilder;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.function.aggregation.AggregationFunction;
-import org.apache.flink.streaming.api.function.aggregation.MaxAggregationFunction;
-import org.apache.flink.streaming.api.function.aggregation.MaxByAggregationFunction;
-import org.apache.flink.streaming.api.function.aggregation.MinAggregationFunction;
-import org.apache.flink.streaming.api.function.aggregation.MinByAggregationFunction;
-import org.apache.flink.streaming.api.function.aggregation.SumAggregationFunction;
+import org.apache.flink.streaming.api.function.aggregation.AggregationFunction.AggregationType;
+import org.apache.flink.streaming.api.function.aggregation.ComparableAggregator;
+import org.apache.flink.streaming.api.function.aggregation.SumAggregator;
import org.apache.flink.streaming.api.function.sink.PrintSinkFunction;
import org.apache.flink.streaming.api.function.sink.SinkFunction;
import org.apache.flink.streaming.api.function.sink.WriteFormatAsCsv;
@@ -763,11 +761,27 @@ public class DataStream<OUT> {
* The position in the data point to sum
* @return The transformed DataStream.
*/
- @SuppressWarnings("unchecked")
public SingleOutputStreamOperator<OUT, ?> sum(int positionToSum) {
checkFieldRange(positionToSum);
- return aggregate((AggregationFunction<OUT>) SumAggregationFunction.getSumFunction(
- positionToSum, getClassAtPos(positionToSum), getOutputType()));
+ return aggregate((AggregationFunction<OUT>) SumAggregator.getSumFunction(positionToSum,
+ getClassAtPos(positionToSum), getOutputType()));
+ }
+
+ /**
+ * Applies an aggregation that that gives the sum of the pojo data stream at
+ * the given field expression. A field expression is either the name of a
+ * public field or a getter method with parentheses of the
+ * {@link DataStream}S underlying type. A dot can be used to drill down into
+ * objects, as in {@code "field1.getInnerField2()" }.
+ *
+ * @param field
+ * The field expression based on which the aggregation will be
+ * applied.
+ * @return The transformed DataStream.
+ */
+ public SingleOutputStreamOperator<OUT, ?> sum(String field) {
+ return aggregate((AggregationFunction<OUT>) SumAggregator.getSumFunction(field,
+ getOutputType()));
}
/**
@@ -789,7 +803,82 @@ public class DataStream<OUT> {
*/
public SingleOutputStreamOperator<OUT, ?> min(int positionToMin) {
checkFieldRange(positionToMin);
- return aggregate(new MinAggregationFunction<OUT>(positionToMin, getOutputType()));
+ return aggregate(ComparableAggregator.getAggregator(positionToMin, getOutputType(),
+ AggregationType.MIN));
+ }
+
+ /**
+ * Applies an aggregation that that gives the minimum of the pojo data
+ * stream at the given field expression. A field expression is either the
+ * name of a public field or a getter method with parentheses of the
+ * {@link DataStream}S underlying type. A dot can be used to drill down into
+ * objects, as in {@code "field1.getInnerField2()" }.
+ *
+ * @param field
+ * The field expression based on which the aggregation will be
+ * applied.
+ * @return The transformed DataStream.
+ */
+ public SingleOutputStreamOperator<OUT, ?> min(String field) {
+ return aggregate(ComparableAggregator.getAggregator(field, getOutputType(),
+ AggregationType.MIN, false));
+ }
+
+ /**
+ * Applies an aggregation that that gives the maximum of the pojo data
+ * stream at the given field expression. A field expression is either the
+ * name of a public field or a getter method with parentheses of the
+ * {@link DataStream}S underlying type. A dot can be used to drill down into
+ * objects, as in {@code "field1.getInnerField2()" }.
+ *
+ * @param field
+ * The field expression based on which the aggregation will be
+ * applied.
+ * @return The transformed DataStream.
+ */
+ public SingleOutputStreamOperator<OUT, ?> max(String field) {
+ return aggregate(ComparableAggregator.getAggregator(field, getOutputType(),
+ AggregationType.MAX, false));
+ }
+
+ /**
+ * Applies an aggregation that that gives the minimum element of the pojo
+ * data stream by the given field expression. A field expression is either
+ * the name of a public field or a getter method with parentheses of the
+ * {@link DataStream}S underlying type. A dot can be used to drill down into
+ * objects, as in {@code "field1.getInnerField2()" }.
+ *
+ * @param field
+ * The field expression based on which the aggregation will be
+ * applied.
+ * @param first
+ * If True then in case of field equality the first object will
+ * be returned
+ * @return The transformed DataStream.
+ */
+ public SingleOutputStreamOperator<OUT, ?> minBy(String field, boolean first) {
+ return aggregate(ComparableAggregator.getAggregator(field, getOutputType(),
+ AggregationType.MINBY, first));
+ }
+
+ /**
+ * Applies an aggregation that that gives the maximum element of the pojo
+ * data stream by the given field expression. A field expression is either
+ * the name of a public field or a getter method with parentheses of the
+ * {@link DataStream}S underlying type. A dot can be used to drill down into
+ * objects, as in {@code "field1.getInnerField2()" }.
+ *
+ * @param field
+ * The field expression based on which the aggregation will be
+ * applied.
+ * @param first
+ * If True then in case of field equality the first object will
+ * be returned
+ * @return The transformed DataStream.
+ */
+ public SingleOutputStreamOperator<OUT, ?> maxBy(String field, boolean first) {
+ return aggregate(ComparableAggregator.getAggregator(field, getOutputType(),
+ AggregationType.MAXBY, first));
}
/**
@@ -821,7 +910,8 @@ public class DataStream<OUT> {
*/
public SingleOutputStreamOperator<OUT, ?> minBy(int positionToMinBy, boolean first) {
checkFieldRange(positionToMinBy);
- return aggregate(new MinByAggregationFunction<OUT>(positionToMinBy, first, getOutputType()));
+ return aggregate(ComparableAggregator.getAggregator(positionToMinBy, getOutputType(),
+ AggregationType.MINBY, first));
}
/**
@@ -843,7 +933,8 @@ public class DataStream<OUT> {
*/
public SingleOutputStreamOperator<OUT, ?> max(int positionToMax) {
checkFieldRange(positionToMax);
- return aggregate(new MaxAggregationFunction<OUT>(positionToMax, getOutputType()));
+ return aggregate(ComparableAggregator.getAggregator(positionToMax, getOutputType(),
+ AggregationType.MAX));
}
/**
@@ -875,7 +966,8 @@ public class DataStream<OUT> {
*/
public SingleOutputStreamOperator<OUT, ?> maxBy(int positionToMaxBy, boolean first) {
checkFieldRange(positionToMaxBy);
- return aggregate(new MaxByAggregationFunction<OUT>(positionToMaxBy, first, getOutputType()));
+ return aggregate(ComparableAggregator.getAggregator(positionToMaxBy, getOutputType(),
+ AggregationType.MAXBY, first));
}
/**
@@ -888,7 +980,7 @@ public class DataStream<OUT> {
}
/**
- * Applies an aggregation that gives the count of the data point.
+ * Applies an aggregation that gives the count of the values.
*
* @return The transformed DataStream.
*/
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7ae58042/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java
index 825b4db..d95c37e 100644
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java
@@ -18,23 +18,18 @@
package org.apache.flink.streaming.api.function.aggregation;
import org.apache.flink.api.common.functions.ReduceFunction;
-import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
-import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.java.tuple.Tuple;
public abstract class AggregationFunction<T> implements ReduceFunction<T> {
private static final long serialVersionUID = 1L;
- public int position;
- protected Tuple returnTuple;
- protected boolean isTuple;
- protected boolean isArray;
+ int position;
- public AggregationFunction(int pos, TypeInformation<?> type) {
+ public AggregationFunction(int pos) {
this.position = pos;
- this.isTuple = type.isTupleType();
- this.isArray = type instanceof BasicArrayTypeInfo || type instanceof PrimitiveArrayTypeInfo;
+ }
+
+ public static enum AggregationType {
+ SUM, MIN, MAX, MINBY, MAXBY,
}
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7ae58042/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregationFunction.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregationFunction.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregationFunction.java
deleted file mode 100644
index 383c39c..0000000
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregationFunction.java
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.streaming.api.function.aggregation;
-
-import java.lang.reflect.Array;
-
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.java.tuple.Tuple;
-
-public abstract class ComparableAggregationFunction<T> extends AggregationFunction<T> {
-
- private static final long serialVersionUID = 1L;
-
- public ComparableAggregationFunction(int positionToAggregate, TypeInformation<?> type) {
- super(positionToAggregate, type);
- }
-
- @SuppressWarnings("unchecked")
- @Override
- public T reduce(T value1, T value2) throws Exception {
- if (isTuple) {
- Tuple t1 = (Tuple) value1;
- Tuple t2 = (Tuple) value2;
-
- compare(t1, t2);
-
- return (T) returnTuple;
- } else if (isArray) {
- return compareArray(value1, value2);
- } else if (value1 instanceof Comparable) {
- if (isExtremal((Comparable<Object>) value1, value2)) {
- return value1;
- } else {
- return value2;
- }
- } else {
- throw new RuntimeException("The values " + value1 + " and " + value2
- + " cannot be compared.");
- }
- }
-
- @SuppressWarnings("unchecked")
- public T compareArray(T array1, T array2) {
- Object v1 = Array.get(array1, position);
- Object v2 = Array.get(array2, position);
- if (isExtremal((Comparable<Object>) v1, v2)) {
- Array.set(array2, position, v1);
- } else {
- Array.set(array2, position, v2);
- }
-
- return array2;
- }
-
- public <R> void compare(Tuple tuple1, Tuple tuple2) throws InstantiationException,
- IllegalAccessException {
-
- Comparable<R> o1 = tuple1.getField(position);
- R o2 = tuple2.getField(position);
-
- if (isExtremal(o1, o2)) {
- tuple2.setField(o1, position);
- }
- returnTuple = tuple2;
- }
-
- public abstract <R> boolean isExtremal(Comparable<R> o1, R o2);
-}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7ae58042/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java
new file mode 100644
index 0000000..6e2a400
--- /dev/null
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.function.aggregation;
+
+import java.lang.reflect.Array;
+import java.lang.reflect.Field;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.CompositeType;
+import org.apache.flink.api.common.typeutils.CompositeType.FlatFieldDescriptor;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.typeutils.PojoTypeInfo;
+import org.apache.flink.api.java.typeutils.runtime.PojoComparator;
+
+public abstract class ComparableAggregator<T> extends AggregationFunction<T> {
+
+ private static final long serialVersionUID = 1L;
+
+ Comparator comparator;
+ boolean byAggregate;
+ boolean first;
+
+ private ComparableAggregator(int pos, AggregationType aggregationType, boolean first) {
+ super(pos);
+ this.comparator = Comparator.getForAggregation(aggregationType);
+ this.byAggregate = (aggregationType == AggregationType.MAXBY)
+ || (aggregationType == AggregationType.MINBY);
+ this.first = first;
+ }
+
+ public static <R> AggregationFunction<R> getAggregator(int positionToAggregate,
+ TypeInformation<R> typeInfo, AggregationType aggregationType) {
+ return getAggregator(positionToAggregate, typeInfo, aggregationType, false);
+ }
+
+ public static <R> AggregationFunction<R> getAggregator(int positionToAggregate,
+ TypeInformation<R> typeInfo, AggregationType aggregationType, boolean first) {
+
+ if (typeInfo.isTupleType()) {
+ return new TupleComparableAggregator<R>(positionToAggregate, aggregationType, first);
+ } else if (typeInfo instanceof BasicArrayTypeInfo
+ || typeInfo instanceof PrimitiveArrayTypeInfo) {
+ return new ArrayComparableAggregator<R>(positionToAggregate, aggregationType, first);
+ } else {
+ return new SimpleComparableAggregator<R>(aggregationType);
+ }
+ }
+
+ public static <R> AggregationFunction<R> getAggregator(String field,
+ TypeInformation<R> typeInfo, AggregationType aggregationType, boolean first) {
+
+ return new PojoComparableAggregator<R>(field, typeInfo, aggregationType, first);
+ }
+
+ private static class TupleComparableAggregator<T> extends ComparableAggregator<T> {
+
+ private static final long serialVersionUID = 1L;
+
+ public TupleComparableAggregator(int pos, AggregationType aggregationType, boolean first) {
+ super(pos, aggregationType, first);
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public T reduce(T value1, T value2) throws Exception {
+ Tuple tuple1 = (Tuple) value1;
+ Tuple tuple2 = (Tuple) value2;
+
+ Comparable<Object> o1 = tuple1.getField(position);
+ Object o2 = tuple2.getField(position);
+
+ int c = comparator.isExtremal(o1, o2);
+
+ if (byAggregate) {
+ if (c == 1) {
+ return (T) tuple1;
+ }
+ if (first) {
+ if (c == 0) {
+ return (T) tuple1;
+ }
+ }
+
+ return (T) tuple2;
+
+ } else {
+ if (c == 1) {
+ tuple2.setField(o1, position);
+ }
+ return (T) tuple2;
+ }
+
+ }
+ }
+
+ private static class ArrayComparableAggregator<T> extends ComparableAggregator<T> {
+
+ private static final long serialVersionUID = 1L;
+
+ public ArrayComparableAggregator(int pos, AggregationType aggregationType, boolean first) {
+ super(pos, aggregationType, first);
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public T reduce(T array1, T array2) throws Exception {
+
+ Object v1 = Array.get(array1, position);
+ Object v2 = Array.get(array2, position);
+
+ int c = comparator.isExtremal((Comparable<Object>) v1, v2);
+
+ if (byAggregate) {
+ if (c == 1) {
+ return array1;
+ }
+ if (first) {
+ if (c == 0) {
+ return array1;
+ }
+ }
+
+ return array2;
+ } else {
+ if (c == 1) {
+ Array.set(array2, position, v1);
+ } else {
+ Array.set(array2, position, v2);
+ }
+
+ return array2;
+ }
+ }
+
+ }
+
+ private static class SimpleComparableAggregator<T> extends ComparableAggregator<T> {
+
+ private static final long serialVersionUID = 1L;
+
+ public SimpleComparableAggregator(AggregationType aggregationType) {
+ super(0, aggregationType, false);
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public T reduce(T value1, T value2) throws Exception {
+
+ if (comparator.isExtremal((Comparable<Object>) value1, value2) == 1) {
+ return value1;
+ } else {
+ return value2;
+ }
+ }
+
+ }
+
+ private static class PojoComparableAggregator<T> extends ComparableAggregator<T> {
+
+ private static final long serialVersionUID = 1L;
+ PojoComparator<T> pojoComparator;
+
+ public PojoComparableAggregator(String field, TypeInformation<?> typeInfo,
+ AggregationType aggregationType, boolean first) {
+ super(0, aggregationType, first);
+ if (!(typeInfo instanceof CompositeType<?>)) {
+ throw new IllegalArgumentException(
+ "Key expressions are only supported on POJO types and Tuples. "
+ + "A type is considered a POJO if all its fields are public, or have both getters and setters defined");
+ }
+
+ @SuppressWarnings("unchecked")
+ CompositeType<T> cType = (CompositeType<T>) typeInfo;
+
+ List<FlatFieldDescriptor> fieldDescriptors = new ArrayList<FlatFieldDescriptor>();
+ cType.getKey(field, 0, fieldDescriptors);
+
+ int logicalKeyPosition = fieldDescriptors.get(0).getPosition();
+
+ if (cType instanceof PojoTypeInfo) {
+ pojoComparator = (PojoComparator<T>) cType.createComparator(
+ new int[] { logicalKeyPosition }, new boolean[] { false }, 0);
+ } else {
+ throw new IllegalArgumentException(
+ "Key expressions are only supported on POJO types. "
+ + "A type is considered a POJO if all its fields are public, or have both getters and setters defined");
+ }
+ }
+
+ @Override
+ public T reduce(T value1, T value2) throws Exception {
+
+ Field[] keyFields = pojoComparator.getKeyFields();
+ Object field1 = pojoComparator.accessField(keyFields[0], value1);
+ Object field2 = pojoComparator.accessField(keyFields[0], value2);
+
+ @SuppressWarnings("unchecked")
+ int c = comparator.isExtremal((Comparable<Object>) field1, field2);
+
+ if (byAggregate) {
+ if (c == 1) {
+ return value1;
+ }
+ if (first) {
+ if (c == 0) {
+ return value1;
+ }
+ }
+
+ return value2;
+ } else {
+ if (c == 1) {
+ keyFields[0].set(value2, field1);
+ } else {
+ keyFields[0].set(value2, field2);
+ }
+
+ return value2;
+ }
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7ae58042/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/Comparator.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/Comparator.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/Comparator.java
new file mode 100644
index 0000000..f56774b
--- /dev/null
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/Comparator.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.function.aggregation;
+
+import java.io.Serializable;
+
+import org.apache.flink.streaming.api.function.aggregation.AggregationFunction.AggregationType;
+
+/**
+ * Strategy deciding whether one field value is "more extremal" than another
+ * for a min/max style aggregation. Instances are {@link Serializable} so they
+ * can be shipped inside the aggregating reduce function.
+ */
+public abstract class Comparator implements Serializable {
+
+ private static final long serialVersionUID = 1L;
+
+ /**
+ * Compares {@code o1} against {@code o2} under this comparator's ordering.
+ *
+ * @return {@code 1} if {@code o1} is strictly more extremal than {@code o2}.
+ *         MIN/MAX comparators return {@code 0} in every other case;
+ *         MINBY/MAXBY comparators return {@code 0} when the values compare
+ *         equal and {@code -1} when {@code o2} is strictly more extremal.
+ */
+ public abstract <R> int isExtremal(Comparable<R> o1, R o2);
+
+ /**
+ * Returns the comparator implementing the given aggregation type.
+ *
+ * @param type one of MAX, MIN, MAXBY or MINBY
+ * @return a new comparator instance for {@code type}
+ * @throws IllegalArgumentException for any other aggregation type
+ */
+ public static Comparator getForAggregation(AggregationType type) {
+ final Comparator selected;
+ switch (type) {
+ case MAX:
+ selected = new MaxComparator();
+ break;
+ case MIN:
+ selected = new MinComparator();
+ break;
+ case MAXBY:
+ selected = new MaxByComparator();
+ break;
+ case MINBY:
+ selected = new MinByComparator();
+ break;
+ default:
+ throw new IllegalArgumentException("Unsupported aggregation type.");
+ }
+ return selected;
+ }
+
+ /** MAX: returns 1 only when {@code o1} is strictly greater, otherwise 0. */
+ private static final class MaxComparator extends Comparator {
+
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public <R> int isExtremal(Comparable<R> o1, R o2) {
+ if (o1.compareTo(o2) > 0) {
+ return 1;
+ }
+ return 0;
+ }
+ }
+
+ /** MIN: returns 1 only when {@code o1} is strictly smaller, otherwise 0. */
+ private static final class MinComparator extends Comparator {
+
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public <R> int isExtremal(Comparable<R> o1, R o2) {
+ if (o1.compareTo(o2) < 0) {
+ return 1;
+ }
+ return 0;
+ }
+ }
+
+ /** MAXBY: 1 if {@code o1} is greater, 0 on a tie, -1 if {@code o1} is smaller. */
+ private static final class MaxByComparator extends Comparator {
+
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public <R> int isExtremal(Comparable<R> o1, R o2) {
+ return Integer.signum(o1.compareTo(o2));
+ }
+ }
+
+ /** MINBY: 1 if {@code o1} is smaller, 0 on a tie, -1 if {@code o1} is greater. */
+ private static final class MinByComparator extends Comparator {
+
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public <R> int isExtremal(Comparable<R> o1, R o2) {
+ // Negating the sign flips the ordering: a smaller o1 yields +1.
+ return -Integer.signum(o1.compareTo(o2));
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7ae58042/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MaxAggregationFunction.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MaxAggregationFunction.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MaxAggregationFunction.java
deleted file mode 100644
index d013162..0000000
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MaxAggregationFunction.java
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.streaming.api.function.aggregation;
-
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-
-public class MaxAggregationFunction<T> extends ComparableAggregationFunction<T> {
-
- private static final long serialVersionUID = 1L;
-
- public MaxAggregationFunction(int pos, TypeInformation<?> type) {
- super(pos, type);
- }
-
- @Override
- public <R> boolean isExtremal(Comparable<R> o1, R o2) {
- return o1.compareTo(o2) > 0;
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7ae58042/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MaxByAggregationFunction.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MaxByAggregationFunction.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MaxByAggregationFunction.java
deleted file mode 100644
index 4679028..0000000
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MaxByAggregationFunction.java
+++ /dev/null
@@ -1,39 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.streaming.api.function.aggregation;
-
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-
-public class MaxByAggregationFunction<T> extends MinByAggregationFunction<T> {
-
- private static final long serialVersionUID = 1L;
-
- public MaxByAggregationFunction(int pos, boolean first, TypeInformation<?> type) {
- super(pos, first, type);
- }
-
- @Override
- public <R> boolean isExtremal(Comparable<R> o1, R o2) {
- if (first) {
- return o1.compareTo(o2) >= 0;
- } else {
- return o1.compareTo(o2) > 0;
- }
-
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7ae58042/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MinAggregationFunction.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MinAggregationFunction.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MinAggregationFunction.java
deleted file mode 100644
index 83c20c7..0000000
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MinAggregationFunction.java
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.streaming.api.function.aggregation;
-
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-
-public class MinAggregationFunction<T> extends ComparableAggregationFunction<T> {
-
- private static final long serialVersionUID = 1L;
-
- public MinAggregationFunction(int pos, TypeInformation<?> type) {
- super(pos, type);
- }
-
- @Override
- public <R> boolean isExtremal(Comparable<R> o1, R o2) {
- return o1.compareTo(o2) < 0;
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7ae58042/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MinByAggregationFunction.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MinByAggregationFunction.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MinByAggregationFunction.java
deleted file mode 100644
index 31d6b37..0000000
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MinByAggregationFunction.java
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.streaming.api.function.aggregation;
-
-import java.lang.reflect.Array;
-
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.java.tuple.Tuple;
-
-public class MinByAggregationFunction<T> extends ComparableAggregationFunction<T> {
-
- private static final long serialVersionUID = 1L;
- protected boolean first;
-
- public MinByAggregationFunction(int pos, boolean first, TypeInformation<?> type) {
- super(pos, type);
- this.first = first;
- }
-
- @Override
- public <R> void compare(Tuple tuple1, Tuple tuple2) throws InstantiationException,
- IllegalAccessException {
-
- Comparable<R> o1 = tuple1.getField(position);
- R o2 = tuple2.getField(position);
-
- if (isExtremal(o1, o2)) {
- returnTuple = tuple1;
- } else {
- returnTuple = tuple2;
- }
- }
-
- @Override
- @SuppressWarnings("unchecked")
- public T compareArray(T array1, T array2) {
- Object v1 = Array.get(array1, position);
- Object v2 = Array.get(array2, position);
- if (isExtremal((Comparable<Object>) v1, v2)) {
- return array1;
- } else {
- return array2;
- }
- }
-
- @Override
- public <R> boolean isExtremal(Comparable<R> o1, R o2) {
- if (first) {
- return o1.compareTo(o2) <= 0;
- } else {
- return o1.compareTo(o2) < 0;
- }
-
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7ae58042/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumAggregationFunction.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumAggregationFunction.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumAggregationFunction.java
deleted file mode 100644
index cd50072..0000000
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumAggregationFunction.java
+++ /dev/null
@@ -1,159 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.streaming.api.function.aggregation;
-
-import java.lang.reflect.Array;
-
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.java.tuple.Tuple;
-
-public abstract class SumAggregationFunction<T> extends AggregationFunction<T> {
-
- private static final long serialVersionUID = 1L;
-
- public SumAggregationFunction(int pos, TypeInformation<?> type) {
- super(pos, type);
- }
-
- @SuppressWarnings("unchecked")
- @Override
- public T reduce(T value1, T value2) throws Exception {
- if (isTuple) {
- Tuple tuple1 = (Tuple) value1;
- Tuple tuple2 = (Tuple) value2;
-
- returnTuple = tuple2;
- returnTuple.setField(add(tuple1.getField(position), tuple2.getField(position)),
- position);
-
- return (T) returnTuple;
- } else if (isArray) {
- Object v1 = Array.get(value1, position);
- Object v2 = Array.get(value2, position);
- Array.set(value2, position, add(v1, v2));
- return value2;
- } else {
- return (T) add(value1, value2);
- }
- }
-
- protected abstract Object add(Object value1, Object value2);
-
- @SuppressWarnings("rawtypes")
- public static <T> SumAggregationFunction getSumFunction(int pos, Class<T> classAtPos,
- TypeInformation<?> typeInfo) {
-
- if (classAtPos == Integer.class) {
- return new IntSum<T>(pos, typeInfo);
- } else if (classAtPos == Long.class) {
- return new LongSum<T>(pos, typeInfo);
- } else if (classAtPos == Short.class) {
- return new ShortSum<T>(pos, typeInfo);
- } else if (classAtPos == Double.class) {
- return new DoubleSum<T>(pos, typeInfo);
- } else if (classAtPos == Float.class) {
- return new FloatSum<T>(pos, typeInfo);
- } else if (classAtPos == Byte.class) {
- return new ByteSum<T>(pos, typeInfo);
- } else {
- throw new RuntimeException("DataStream cannot be summed because the class "
- + classAtPos.getSimpleName() + " does not support the + operator.");
- }
-
- }
-
- private static class IntSum<T> extends SumAggregationFunction<T> {
- private static final long serialVersionUID = 1L;
-
- public IntSum(int pos, TypeInformation<?> type) {
- super(pos, type);
- }
-
- @Override
- protected Object add(Object value1, Object value2) {
- return (Integer) value1 + (Integer) value2;
- }
- }
-
- private static class LongSum<T> extends SumAggregationFunction<T> {
- private static final long serialVersionUID = 1L;
-
- public LongSum(int pos, TypeInformation<?> type) {
- super(pos, type);
- }
-
- @Override
- protected Object add(Object value1, Object value2) {
- return (Long) value1 + (Long) value2;
- }
- }
-
- private static class DoubleSum<T> extends SumAggregationFunction<T> {
-
- private static final long serialVersionUID = 1L;
-
- public DoubleSum(int pos, TypeInformation<?> type) {
- super(pos, type);
- }
-
- @Override
- protected Object add(Object value1, Object value2) {
- return (Double) value1 + (Double) value2;
- }
- }
-
- private static class ShortSum<T> extends SumAggregationFunction<T> {
- private static final long serialVersionUID = 1L;
-
- public ShortSum(int pos, TypeInformation<?> type) {
- super(pos, type);
- }
-
- @Override
- protected Object add(Object value1, Object value2) {
- return (Short) value1 + (Short) value2;
- }
- }
-
- private static class FloatSum<T> extends SumAggregationFunction<T> {
- private static final long serialVersionUID = 1L;
-
- public FloatSum(int pos, TypeInformation<?> type) {
- super(pos, type);
- }
-
- @Override
- protected Object add(Object value1, Object value2) {
- return (Float) value1 + (Float) value2;
- }
- }
-
- private static class ByteSum<T> extends SumAggregationFunction<T> {
- private static final long serialVersionUID = 1L;
-
- public ByteSum(int pos, TypeInformation<?> type) {
- super(pos, type);
- }
-
- @Override
- protected Object add(Object value1, Object value2) {
- return (Byte) value1 + (Byte) value2;
- }
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7ae58042/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumAggregator.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumAggregator.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumAggregator.java
new file mode 100644
index 0000000..384b4f6
--- /dev/null
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumAggregator.java
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.function.aggregation;
+
+import java.lang.reflect.Array;
+import java.lang.reflect.Field;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.CompositeType;
+import org.apache.flink.api.common.typeutils.CompositeType.FlatFieldDescriptor;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.typeutils.PojoTypeInfo;
+import org.apache.flink.api.java.typeutils.runtime.PojoComparator;
+
+/**
+ * Factory for the {@link ReduceFunction}s that implement streaming sum
+ * aggregations. For tuples, arrays and POJOs the returned function adds the
+ * aggregated field of the first argument into the same field of the second
+ * argument, overwriting it in place, and returns the second argument. For
+ * plain values it returns the sum itself.
+ */
+public abstract class SumAggregator {
+
+ /**
+ * Creates a positional sum aggregator.
+ *
+ * @param pos position of the summed field in a tuple or array; not used for
+ *        plain (non-tuple, non-array) element types
+ * @param clazz class of the summed values, used to select the {@link SumFunction}
+ * @param typeInfo type information of the stream elements
+ * @return a tuple, array or plain-value sum function depending on {@code typeInfo}
+ * @throws RuntimeException if {@code clazz} is not a supported numeric class
+ */
+ public static <T> ReduceFunction<T> getSumFunction(int pos, Class<?> clazz,
+ TypeInformation<T> typeInfo) {
+
+ boolean tupleInput = typeInfo.isTupleType();
+ boolean arrayInput = typeInfo instanceof BasicArrayTypeInfo
+ || typeInfo instanceof PrimitiveArrayTypeInfo;
+ SumFunction adder = SumFunction.getForClass(clazz);
+
+ if (tupleInput) {
+ return new TupleSumAggregator<T>(pos, adder);
+ }
+ if (arrayInput) {
+ return new ArraySumAggregator<T>(pos, adder);
+ }
+ return new SimpleSumAggregator<T>(adder);
+ }
+
+ /**
+ * Creates a sum aggregator over a field expression of a POJO type.
+ *
+ * @param field field expression selecting the summed field
+ * @param typeInfo type information of the stream elements; must be a POJO type
+ * @return the POJO sum function
+ * @throws IllegalArgumentException if {@code typeInfo} is not a POJO type
+ */
+ public static <T> ReduceFunction<T> getSumFunction(String field, TypeInformation<T> typeInfo) {
+ return new PojoSumAggregator<T>(field, typeInfo);
+ }
+
+ /** Sums the field at {@code position} of two tuples into the second tuple. */
+ private static class TupleSumAggregator<T> extends AggregationFunction<T> {
+
+ private static final long serialVersionUID = 1L;
+
+ private final SumFunction adder;
+
+ public TupleSumAggregator(int pos, SumFunction adder) {
+ super(pos);
+ this.adder = adder;
+ }
+
+ @Override
+ public T reduce(T value1, T value2) throws Exception {
+ Tuple accumulated = (Tuple) value1;
+ Tuple current = (Tuple) value2;
+ Object previous = accumulated.getField(position);
+ Object incoming = current.getField(position);
+ current.setField(adder.add(previous, incoming), position);
+ return value2;
+ }
+ }
+
+ /** Sums the slot at {@code position} of two arrays into the second array. */
+ private static class ArraySumAggregator<T> extends AggregationFunction<T> {
+
+ private static final long serialVersionUID = 1L;
+
+ private final SumFunction adder;
+
+ public ArraySumAggregator(int pos, SumFunction adder) {
+ super(pos);
+ this.adder = adder;
+ }
+
+ @Override
+ public T reduce(T value1, T value2) throws Exception {
+ Object left = Array.get(value1, position);
+ Object right = Array.get(value2, position);
+ Object total = adder.add(left, right);
+ Array.set(value2, position, total);
+ return value2;
+ }
+ }
+
+ /** Sums two plain (non-composite) values and returns the result. */
+ private static class SimpleSumAggregator<T> extends AggregationFunction<T> {
+
+ private static final long serialVersionUID = 1L;
+
+ private final SumFunction adder;
+
+ public SimpleSumAggregator(SumFunction adder) {
+ super(0);
+ this.adder = adder;
+ }
+
+ @Override
+ public T reduce(T value1, T value2) throws Exception {
+ @SuppressWarnings("unchecked")
+ T total = (T) adder.add(value1, value2);
+ return total;
+ }
+ }
+
+ /** Sums one field of two POJOs, selected by a field expression, into the second POJO. */
+ private static class PojoSumAggregator<T> extends AggregationFunction<T> {
+
+ private static final long serialVersionUID = 1L;
+ private final SumFunction adder;
+ private final PojoComparator<T> comparator;
+
+ public PojoSumAggregator(String field, TypeInformation<?> type) {
+ super(0);
+ if (!(type instanceof CompositeType<?>)) {
+ throw new IllegalArgumentException(
+ "Key expressions are only supported on POJO types and Tuples. "
+ + "A type is considered a POJO if all its fields are public, or have both getters and setters defined");
+ }
+
+ @SuppressWarnings("unchecked")
+ CompositeType<T> cType = (CompositeType<T>) type;
+
+ List<FlatFieldDescriptor> descriptors = new ArrayList<FlatFieldDescriptor>();
+ cType.getKey(field, 0, descriptors);
+
+ // Only the first flat field selected by the expression is aggregated.
+ FlatFieldDescriptor keyDescriptor = descriptors.get(0);
+ int logicalKeyPosition = keyDescriptor.getPosition();
+ adder = SumFunction.getForClass(keyDescriptor.getType().getTypeClass());
+
+ if (!(cType instanceof PojoTypeInfo)) {
+ throw new IllegalArgumentException(
+ "Key expressions are only supported on POJO types. "
+ + "A type is considered a POJO if all its fields are public, or have both getters and setters defined");
+ }
+ comparator = (PojoComparator<T>) cType.createComparator(
+ new int[] { logicalKeyPosition }, new boolean[] { false }, 0);
+ }
+
+ @Override
+ public T reduce(T value1, T value2) throws Exception {
+ // The comparator exposes the reflective Field for the selected key.
+ Field keyField = comparator.getKeyFields()[0];
+ Object total = adder.add(comparator.accessField(keyField, value1),
+ comparator.accessField(keyField, value2));
+ keyField.set(value2, total);
+ return value2;
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7ae58042/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumFunction.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumFunction.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumFunction.java
new file mode 100644
index 0000000..1ac236d
--- /dev/null
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumFunction.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.function.aggregation;
+
+import java.io.Serializable;
+
+/**
+ * Boxed-number addition used by the streaming sum aggregations.
+ * <p>
+ * Each implementation adds two values of one boxed numeric type and returns a
+ * result of that same boxed type. This lets the result be written back into a
+ * tuple field, array slot or POJO field of the original type. Integral
+ * overflow wraps around, as in Java's primitive arithmetic.
+ */
+public abstract class SumFunction implements Serializable {
+
+ private static final long serialVersionUID = 1L;
+
+ /**
+ * Adds two boxed values.
+ *
+ * @param o1 first operand; must be an instance of the class this function was created for
+ * @param o2 second operand; same type as {@code o1}
+ * @return the sum, boxed in the same wrapper type as the operands
+ */
+ public abstract Object add(Object o1, Object o2);
+
+ /**
+ * Returns the adder for the given boxed numeric class.
+ *
+ * @param clazz one of Integer, Long, Short, Double, Float or Byte
+ * @return a new sum function for {@code clazz}
+ * @throws RuntimeException if {@code clazz} is not one of the supported classes
+ */
+ public static SumFunction getForClass(Class<?> clazz) {
+
+ if (clazz == Integer.class) {
+ return new IntSum();
+ } else if (clazz == Long.class) {
+ return new LongSum();
+ } else if (clazz == Short.class) {
+ return new ShortSum();
+ } else if (clazz == Double.class) {
+ return new DoubleSum();
+ } else if (clazz == Float.class) {
+ return new FloatSum();
+ } else if (clazz == Byte.class) {
+ return new ByteSum();
+ } else {
+ throw new RuntimeException("DataStream cannot be summed because the class "
+ + clazz.getSimpleName() + " does not support the + operator.");
+ }
+ }
+
+ private static class IntSum extends SumFunction {
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public Object add(Object value1, Object value2) {
+ return (Integer) value1 + (Integer) value2;
+ }
+ }
+
+ private static class LongSum extends SumFunction {
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public Object add(Object value1, Object value2) {
+ return (Long) value1 + (Long) value2;
+ }
+ }
+
+ private static class DoubleSum extends SumFunction {
+
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public Object add(Object value1, Object value2) {
+ return (Double) value1 + (Double) value2;
+ }
+ }
+
+ private static class ShortSum extends SumFunction {
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public Object add(Object value1, Object value2) {
+ // short + short is promoted to int (binary numeric promotion); narrow it
+ // back so a Short is returned instead of an Integer, which a Short tuple
+ // field or a short POJO field (via Field.set) would not accept.
+ return (short) ((Short) value1 + (Short) value2);
+ }
+ }
+
+ private static class FloatSum extends SumFunction {
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public Object add(Object value1, Object value2) {
+ return (Float) value1 + (Float) value2;
+ }
+ }
+
+ private static class ByteSum extends SumFunction {
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public Object add(Object value1, Object value2) {
+ // byte + byte is promoted to int; narrow back so a Byte is returned.
+ return (byte) ((Byte) value1 + (Byte) value2);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7ae58042/flink-addons/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/AggregationFunctionTest.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/AggregationFunctionTest.java b/flink-addons/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/AggregationFunctionTest.java
index 95c6f71..0fbf72a 100644
--- a/flink-addons/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/AggregationFunctionTest.java
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/AggregationFunctionTest.java
@@ -23,15 +23,14 @@ import static org.junit.Assert.fail;
import java.util.ArrayList;
import java.util.List;
+import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.streaming.api.function.aggregation.MaxAggregationFunction;
-import org.apache.flink.streaming.api.function.aggregation.MaxByAggregationFunction;
-import org.apache.flink.streaming.api.function.aggregation.MinAggregationFunction;
-import org.apache.flink.streaming.api.function.aggregation.MinByAggregationFunction;
-import org.apache.flink.streaming.api.function.aggregation.SumAggregationFunction;
+import org.apache.flink.streaming.api.function.aggregation.AggregationFunction.AggregationType;
+import org.apache.flink.streaming.api.function.aggregation.ComparableAggregator;
+import org.apache.flink.streaming.api.function.aggregation.SumAggregator;
import org.apache.flink.streaming.api.invokable.operator.GroupedReduceInvokable;
import org.apache.flink.streaming.api.invokable.operator.StreamReduceInvokable;
import org.apache.flink.streaming.util.MockInvokable;
@@ -87,22 +86,22 @@ public class AggregationFunctionTest {
expectedGroupMaxList.add(new Tuple2<Integer, Integer>(i % 3, i));
}
- TypeInformation<?> type1 = TypeExtractor.getForObject(new Tuple2<Integer, Integer>(0, 0));
- TypeInformation<?> type2 = TypeExtractor.getForObject(2);
-
- @SuppressWarnings("unchecked")
- SumAggregationFunction<Tuple2<Integer, Integer>> sumFunction = SumAggregationFunction
- .getSumFunction(1, Integer.class, type1);
- @SuppressWarnings("unchecked")
- SumAggregationFunction<Integer> sumFunction0 = SumAggregationFunction.getSumFunction(0,
- Integer.class, type2);
- MinAggregationFunction<Tuple2<Integer, Integer>> minFunction = new MinAggregationFunction<Tuple2<Integer, Integer>>(
- 1, type1);
- MinAggregationFunction<Integer> minFunction0 = new MinAggregationFunction<Integer>(0, type2);
- MaxAggregationFunction<Tuple2<Integer, Integer>> maxFunction = new MaxAggregationFunction<Tuple2<Integer, Integer>>(
- 1, type1);
- MaxAggregationFunction<Integer> maxFunction0 = new MaxAggregationFunction<Integer>(0, type2);
-
+ TypeInformation<Tuple2<Integer, Integer>> type1 = TypeExtractor
+ .getForObject(new Tuple2<Integer, Integer>(0, 0));
+ TypeInformation<Integer> type2 = TypeExtractor.getForObject(2);
+
+ ReduceFunction<Tuple2<Integer, Integer>> sumFunction = SumAggregator.getSumFunction(1,
+ Integer.class, type1);
+ ReduceFunction<Integer> sumFunction0 = SumAggregator
+ .getSumFunction(0, Integer.class, type2);
+ ReduceFunction<Tuple2<Integer, Integer>> minFunction = ComparableAggregator
+ .getAggregator(1, type1, AggregationType.MIN);
+ ReduceFunction<Integer> minFunction0 = ComparableAggregator.getAggregator(0,
+ type2, AggregationType.MIN);
+ ReduceFunction<Tuple2<Integer, Integer>> maxFunction = ComparableAggregator
+ .getAggregator(1, type1, AggregationType.MAX);
+ ReduceFunction<Integer> maxFunction0 = ComparableAggregator.getAggregator(0,
+ type2, AggregationType.MAX);
List<Tuple2<Integer, Integer>> sumList = MockInvokable.createAndExecute(
new StreamReduceInvokable<Tuple2<Integer, Integer>>(sumFunction), getInputList());
@@ -157,15 +156,15 @@ public class AggregationFunctionTest {
// Nothing to do here
}
- MaxByAggregationFunction<Tuple2<Integer, Integer>> maxByFunctionFirst = new MaxByAggregationFunction<Tuple2<Integer, Integer>>(
- 0, true, type1);
- MaxByAggregationFunction<Tuple2<Integer, Integer>> maxByFunctionLast = new MaxByAggregationFunction<Tuple2<Integer, Integer>>(
- 0, false, type1);
+ ReduceFunction<Tuple2<Integer, Integer>> maxByFunctionFirst = ComparableAggregator
+ .getAggregator(0, type1, AggregationType.MAXBY, true);
+ ReduceFunction<Tuple2<Integer, Integer>> maxByFunctionLast = ComparableAggregator
+ .getAggregator(0, type1, AggregationType.MAXBY, false);
- MinByAggregationFunction<Tuple2<Integer, Integer>> minByFunctionFirst = new MinByAggregationFunction<Tuple2<Integer, Integer>>(
- 0, true, type1);
- MinByAggregationFunction<Tuple2<Integer, Integer>> minByFunctionLast = new MinByAggregationFunction<Tuple2<Integer, Integer>>(
- 0, false, type1);
+ ReduceFunction<Tuple2<Integer, Integer>> minByFunctionFirst = ComparableAggregator
+ .getAggregator(0, type1, AggregationType.MINBY, true);
+ ReduceFunction<Tuple2<Integer, Integer>> minByFunctionLast = ComparableAggregator
+ .getAggregator(0, type1, AggregationType.MINBY, false);
List<Tuple2<Integer, Integer>> maxByFirstExpected = new ArrayList<Tuple2<Integer, Integer>>();
maxByFirstExpected.add(new Tuple2<Integer, Integer>(0, 0));
@@ -228,17 +227,18 @@ public class AggregationFunctionTest {
@Test
public void minMaxByTest() {
- TypeInformation<?> type1 = TypeExtractor.getForObject(new Tuple2<Integer, Integer>(0, 0));
-
- MaxByAggregationFunction<Tuple2<Integer, Integer>> maxByFunctionFirst = new MaxByAggregationFunction<Tuple2<Integer, Integer>>(
- 0, true, type1);
- MaxByAggregationFunction<Tuple2<Integer, Integer>> maxByFunctionLast = new MaxByAggregationFunction<Tuple2<Integer, Integer>>(
- 0, false, type1);
-
- MinByAggregationFunction<Tuple2<Integer, Integer>> minByFunctionFirst = new MinByAggregationFunction<Tuple2<Integer, Integer>>(
- 0, true, type1);
- MinByAggregationFunction<Tuple2<Integer, Integer>> minByFunctionLast = new MinByAggregationFunction<Tuple2<Integer, Integer>>(
- 0, false, type1);
+ TypeInformation<Tuple2<Integer, Integer>> type1 = TypeExtractor
+ .getForObject(new Tuple2<Integer, Integer>(0, 0));
+
+ ReduceFunction<Tuple2<Integer, Integer>> maxByFunctionFirst = ComparableAggregator
+ .getAggregator(0, type1, AggregationType.MAXBY, true);
+ ReduceFunction<Tuple2<Integer, Integer>> maxByFunctionLast = ComparableAggregator
+ .getAggregator(0, type1, AggregationType.MAXBY, false);
+
+ ReduceFunction<Tuple2<Integer, Integer>> minByFunctionFirst = ComparableAggregator
+ .getAggregator(0, type1, AggregationType.MINBY, true);
+ ReduceFunction<Tuple2<Integer, Integer>> minByFunctionLast = ComparableAggregator
+ .getAggregator(0, type1, AggregationType.MINBY, false);
List<Tuple2<Integer, Integer>> maxByFirstExpected = new ArrayList<Tuple2<Integer, Integer>>();
maxByFirstExpected.add(new Tuple2<Integer, Integer>(0, 0));