You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by se...@apache.org on 2017/01/25 19:56:08 UTC
[3/4] flink git commit: [FLINK-5630] [streaming api] Followups to the
AggregateFunction
[FLINK-5630] [streaming api] Followups to the AggregateFunction
- Add a RichAggregateFunction
- Document generic type parameters
- Allow different input/output types for the cases where an additional window apply function is specified
- Add the aggregate() methods to the Scala API
- Add the window translation tests
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/1542260d
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/1542260d
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/1542260d
Branch: refs/heads/master
Commit: 1542260d52238e87de4fa040e6079465777e8263
Parents: 6f5c7d8
Author: Stephan Ewen <se...@apache.org>
Authored: Tue Jan 24 20:08:28 2017 +0100
Committer: Stephan Ewen <se...@apache.org>
Committed: Wed Jan 25 19:56:23 2017 +0100
----------------------------------------------------------------------
.../api/common/functions/AggregateFunction.java | 66 ++++-
.../common/functions/RichAggregateFunction.java | 53 ++++
.../api/datastream/AllWindowedStream.java | 223 +++++++++++++++-
.../api/datastream/WindowedStream.java | 52 +++-
.../AggregateApplyAllWindowFunction.java | 21 +-
.../windowing/AggregateApplyWindowFunction.java | 20 +-
.../windowing/AllWindowTranslationTest.java | 216 ++++++++++++++++
.../windowing/WindowTranslationTest.java | 231 ++++++++++++++++-
.../streaming/api/scala/AllWindowedStream.scala | 98 ++++++-
.../streaming/api/scala/WindowedStream.scala | 86 ++++++-
.../api/scala/AllWindowTranslationTest.scala | 226 ++++++++++++++++-
.../api/scala/WindowTranslationTest.scala | 254 ++++++++++++++++++-
12 files changed, 1500 insertions(+), 46 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/1542260d/flink-core/src/main/java/org/apache/flink/api/common/functions/AggregateFunction.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/functions/AggregateFunction.java b/flink-core/src/main/java/org/apache/flink/api/common/functions/AggregateFunction.java
index 507be63..3c79396 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/functions/AggregateFunction.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/functions/AggregateFunction.java
@@ -18,14 +18,37 @@
package org.apache.flink.api.common.functions;
+import org.apache.flink.annotation.PublicEvolving;
+
import java.io.Serializable;
/**
+ * The {@code AggregateFunction} is a flexible aggregation function, characterized by the
+ * following features:
+ *
+ * <ul>
+ * <li>The aggregates may use different types for input values, intermediate aggregates,
+ * and result type, to support a wide range of aggregation types.</li>
+ *
+ * <li>Support for distributive aggregations: Different intermediate aggregates can be
+ * merged together, to allow for pre-aggregation/final-aggregation optimizations.</li>
+ * </ul>
+ *
+ * <p>The {@code AggregateFunction}'s intermediate aggregate (in-progress aggregation state)
+ * is called the <i>accumulator</i>. Values are added to the accumulator, and final aggregates are
+ * obtained by finalizing the accumulator state. This supports aggregation functions where the
+ * intermediate state needs to be different from the aggregated values and the final result type,
+ * for example <i>average</i> (which typically keeps a count and a sum).
+ * Merging intermediate aggregates (partial aggregates) means merging the accumulators.
+ *
+ * <p>The AggregateFunction itself is stateless. To allow a single AggregateFunction
+ * instance to maintain multiple aggregates (such as one aggregate per key), the
+ * AggregateFunction creates a new accumulator whenever a new aggregation is started.
*
* <p>Aggregation functions must be {@link Serializable} because they are sent around
* between distributed processes during distributed execution.
*
- * <p>An example how to use this interface is below:
+ * <h1>Example: Average and Weighted Average</h1>
*
* <pre>{@code
* // the accumulator, which holds the state of the in-flight aggregate
@@ -81,14 +104,55 @@ import java.io.Serializable;
* }
* }
* }</pre>
+ *
+ * @param <IN> The type of the values that are aggregated (input values)
+ * @param <ACC> The type of the accumulator (intermediate aggregate state).
+ * @param <OUT> The type of the aggregated result
*/
+@PublicEvolving
public interface AggregateFunction<IN, ACC, OUT> extends Function, Serializable {
+ /**
+ * Creates a new accumulator, starting a new aggregate.
+ *
+ * <p>The new accumulator is typically meaningless unless a value is added
+ * via {@link #add(Object, Object)}.
+ *
+ * <p>The accumulator is the state of a running aggregation. When a program has multiple
+ * aggregates in progress (such as per key and window), the size of the state (per key and
+ * window) is the size of the accumulator.
+ *
+ * @return A new accumulator, corresponding to an empty aggregate.
+ */
ACC createAccumulator();
+ /**
+ * Adds the given value to the given accumulator.
+ *
+ * @param value The value to add
+ * @param accumulator The accumulator to add the value to
+ */
void add(IN value, ACC accumulator);
+ /**
+ * Gets the result of the aggregation from the accumulator.
+ *
+ * @param accumulator The accumulator of the aggregation
+ * @return The final aggregation result.
+ */
OUT getResult(ACC accumulator);
+ /**
+ * Merges two accumulators, returning an accumulator with the merged state.
+ *
+ * <p>This function may reuse any of the given accumulators as the target for the merge
+ * and return that. The assumption is that the given accumulators will not be used any
+ * more after having been passed to this function.
+ *
+ * @param a An accumulator to merge
+ * @param b Another accumulator to merge
+ *
+ * @return The accumulator with the merged state
+ */
ACC merge(ACC a, ACC b);
}
http://git-wip-us.apache.org/repos/asf/flink/blob/1542260d/flink-core/src/main/java/org/apache/flink/api/common/functions/RichAggregateFunction.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/functions/RichAggregateFunction.java b/flink-core/src/main/java/org/apache/flink/api/common/functions/RichAggregateFunction.java
new file mode 100644
index 0000000..caf2557
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/api/common/functions/RichAggregateFunction.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.common.functions;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+/**
+ * Rich variant of the {@link AggregateFunction}. As a {@link RichFunction}, it gives access to the
+ * {@link RuntimeContext} and provides setup and teardown methods:
+ * {@link RichFunction#open(org.apache.flink.configuration.Configuration)} and
+ * {@link RichFunction#close()}.
+ *
+ * @see AggregateFunction
+ *
+ * @param <IN> The type of the values that are aggregated (input values)
+ * @param <ACC> The type of the accumulator (intermediate aggregate state).
+ * @param <OUT> The type of the aggregated result
+ */
+@PublicEvolving
+public abstract class RichAggregateFunction<IN, ACC, OUT>
+ extends AbstractRichFunction
+ implements AggregateFunction<IN, ACC, OUT> {
+
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public abstract ACC createAccumulator();
+
+ @Override
+ public abstract void add(IN value, ACC accumulator);
+
+ @Override
+ public abstract OUT getResult(ACC accumulator);
+
+ @Override
+ public abstract ACC merge(ACC a, ACC b);
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/1542260d/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/AllWindowedStream.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/AllWindowedStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/AllWindowedStream.java
index 5de1774..c3c7424 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/AllWindowedStream.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/AllWindowedStream.java
@@ -20,9 +20,11 @@ package org.apache.flink.streaming.api.datastream;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.annotation.Public;
+import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.FoldFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichFunction;
+import org.apache.flink.api.common.state.AggregatingStateDescriptor;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
@@ -35,6 +37,7 @@ import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.aggregation.AggregationFunction;
import org.apache.flink.streaming.api.functions.aggregation.ComparableAggregator;
import org.apache.flink.streaming.api.functions.aggregation.SumAggregator;
+import org.apache.flink.streaming.api.functions.windowing.AggregateApplyAllWindowFunction;
import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
import org.apache.flink.streaming.api.functions.windowing.PassThroughAllWindowFunction;
import org.apache.flink.streaming.api.functions.windowing.FoldApplyAllWindowFunction;
@@ -54,6 +57,7 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
/**
* A {@code AllWindowedStream} represents a data stream where the stream of
@@ -180,7 +184,7 @@ public class AllWindowedStream<T, W extends Window> {
function = input.getExecutionEnvironment().clean(function);
String callLocation = Utils.getCallLocationName();
- String udfName = "WindowedStream." + callLocation;
+ String udfName = "AllWindowedStream." + callLocation;
return reduce(function, new PassThroughAllWindowFunction<W, T>());
}
@@ -198,7 +202,10 @@ public class AllWindowedStream<T, W extends Window> {
* @return The data stream that is the result of applying the window function to the window.
*/
@PublicEvolving
- public <R> SingleOutputStreamOperator<R> reduce(ReduceFunction<T> reduceFunction, AllWindowFunction<T, R, W> function) {
+ public <R> SingleOutputStreamOperator<R> reduce(
+ ReduceFunction<T> reduceFunction,
+ AllWindowFunction<T, R, W> function) {
+
TypeInformation<T> inType = input.getType();
TypeInformation<R> resultType = TypeExtractor.getUnaryOperatorReturnType(
function, AllWindowFunction.class, true, true, inType, null, false);
@@ -230,7 +237,7 @@ public class AllWindowedStream<T, W extends Window> {
reduceFunction = input.getExecutionEnvironment().clean(reduceFunction);
String callLocation = Utils.getCallLocationName();
- String udfName = "WindowedStream." + callLocation;
+ String udfName = "AllWindowedStream." + callLocation;
String opName;
KeySelector<T, Byte> keySel = input.getKeySelector();
@@ -279,6 +286,202 @@ public class AllWindowedStream<T, W extends Window> {
return input.transform(opName, resultType, operator).forceNonParallel();
}
+ // ------------------------------------------------------------------------
+ // AggregateFunction
+ // ------------------------------------------------------------------------
+
+ /**
+ * Applies the given {@code AggregateFunction} to each window. The AggregateFunction
+ * aggregates all elements of a window into a single result element. The stream of these
+ * result elements (one per window) is interpreted as a regular non-windowed stream.
+ *
+ * @param function The aggregation function.
+ * @return The data stream that is the result of applying the aggregation function to the window.
+ *
+ * @param <ACC> The type of the AggregateFunction's accumulator
+ * @param <R> The type of the elements in the resulting stream, equal to the
+ * AggregateFunction's result type
+ */
+ public <ACC, R> SingleOutputStreamOperator<R> aggregate(AggregateFunction<T, ACC, R> function) {
+ checkNotNull(function, "function");
+
+ if (function instanceof RichFunction) {
+ throw new UnsupportedOperationException("This aggregation function cannot be a RichFunction.");
+ }
+
+ TypeInformation<ACC> accumulatorType = TypeExtractor.getAggregateFunctionAccumulatorType(
+ function, input.getType(), null, false);
+
+ TypeInformation<R> resultType = TypeExtractor.getAggregateFunctionReturnType(
+ function, input.getType(), null, false);
+
+ return aggregate(function, accumulatorType, resultType);
+ }
+
+ /**
+ * Applies the given {@code AggregateFunction} to each window. The AggregateFunction
+ * aggregates all elements of a window into a single result element. The stream of these
+ * result elements (one per window) is interpreted as a regular non-windowed stream.
+ *
+ * @param function The aggregation function.
+ * @return The data stream that is the result of applying the aggregation function to the window.
+ *
+ * @param <ACC> The type of the AggregateFunction's accumulator
+ * @param <R> The type of the elements in the resulting stream, equal to the
+ * AggregateFunction's result type
+ */
+ public <ACC, R> SingleOutputStreamOperator<R> aggregate(
+ AggregateFunction<T, ACC, R> function,
+ TypeInformation<ACC> accumulatorType,
+ TypeInformation<R> resultType) {
+
+ checkNotNull(function, "function");
+ checkNotNull(accumulatorType, "accumulatorType");
+ checkNotNull(resultType, "resultType");
+
+ if (function instanceof RichFunction) {
+ throw new UnsupportedOperationException("This aggregation function cannot be a RichFunction.");
+ }
+
+ return aggregate(function, new PassThroughAllWindowFunction<W, R>(),
+ accumulatorType, resultType, resultType);
+ }
+
+ /**
+ * Applies the given window function to each window. The window function is called for each
+ * evaluation of the window. The output of the window function is
+ * interpreted as a regular non-windowed stream.
+ *
+ * <p>Arriving data is incrementally aggregated using the given aggregate function. This means
+ * that the window function typically has only a single value to process when called.
+ *
+ * @param aggFunction The aggregate function that is used for incremental aggregation.
+ * @param windowFunction The window function.
+ *
+ * @return The data stream that is the result of applying the window function to the window.
+ *
+ * @param <ACC> The type of the AggregateFunction's accumulator
+ * @param <V> The type of AggregateFunction's result, and the AllWindowFunction's input
+ * @param <R> The type of the elements in the resulting stream, equal to the
+ * AllWindowFunction's result type
+ */
+ public <ACC, V, R> SingleOutputStreamOperator<R> aggregate(
+ AggregateFunction<T, ACC, V> aggFunction,
+ AllWindowFunction<V, R, W> windowFunction) {
+
+ checkNotNull(aggFunction, "aggFunction");
+ checkNotNull(windowFunction, "windowFunction");
+
+ TypeInformation<ACC> accumulatorType = TypeExtractor.getAggregateFunctionAccumulatorType(
+ aggFunction, input.getType(), null, false);
+
+ TypeInformation<V> aggResultType = TypeExtractor.getAggregateFunctionReturnType(
+ aggFunction, input.getType(), null, false);
+
+ TypeInformation<R> resultType = TypeExtractor.getUnaryOperatorReturnType(
+ windowFunction, AllWindowFunction.class, true, true, aggResultType, null, false);
+
+ return aggregate(aggFunction, windowFunction, accumulatorType, aggResultType, resultType);
+ }
+
+ /**
+ * Applies the given window function to each window. The window function is called for each
+ * evaluation of the window. The output of the window function is
+ * interpreted as a regular non-windowed stream.
+ *
+ * <p>Arriving data is incrementally aggregated using the given aggregate function. This means
+ * that the window function typically has only a single value to process when called.
+ *
+ * @param aggregateFunction The aggregation function that is used for incremental aggregation.
+ * @param windowFunction The window function.
+ * @param accumulatorType Type information for the internal accumulator type of the aggregation function
+ * @param resultType Type information for the result type of the window function
+ *
+ * @return The data stream that is the result of applying the window function to the window.
+ *
+ * @param <ACC> The type of the AggregateFunction's accumulator
+ * @param <V> The type of AggregateFunction's result, and the AllWindowFunction's input
+ * @param <R> The type of the elements in the resulting stream, equal to the
+ * AllWindowFunction's result type
+ */
+ public <ACC, V, R> SingleOutputStreamOperator<R> aggregate(
+ AggregateFunction<T, ACC, V> aggregateFunction,
+ AllWindowFunction<V, R, W> windowFunction,
+ TypeInformation<ACC> accumulatorType,
+ TypeInformation<V> aggregateResultType,
+ TypeInformation<R> resultType) {
+
+ checkNotNull(aggregateFunction, "aggregateFunction");
+ checkNotNull(windowFunction, "windowFunction");
+ checkNotNull(accumulatorType, "accumulatorType");
+ checkNotNull(aggregateResultType, "aggregateResultType");
+ checkNotNull(resultType, "resultType");
+
+ if (aggregateFunction instanceof RichFunction) {
+ throw new UnsupportedOperationException("This aggregate function cannot be a RichFunction.");
+ }
+
+ //clean the closures
+ windowFunction = input.getExecutionEnvironment().clean(windowFunction);
+ aggregateFunction = input.getExecutionEnvironment().clean(aggregateFunction);
+
+ final String callLocation = Utils.getCallLocationName();
+ final String udfName = "AllWindowedStream." + callLocation;
+
+ final String opName;
+ final KeySelector<T, Byte> keySel = input.getKeySelector();
+
+ OneInputStreamOperator<T, R> operator;
+
+ if (evictor != null) {
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ TypeSerializer<StreamRecord<T>> streamRecordSerializer =
+ (TypeSerializer<StreamRecord<T>>) new StreamElementSerializer(
+ input.getType().createSerializer(getExecutionEnvironment().getConfig()));
+
+ ListStateDescriptor<StreamRecord<T>> stateDesc =
+ new ListStateDescriptor<>("window-contents", streamRecordSerializer);
+
+ opName = "TriggerWindow(" + windowAssigner + ", " + stateDesc + ", " + trigger + ", " + evictor + ", " + udfName + ")";
+
+ operator =
+ new EvictingWindowOperator<>(windowAssigner,
+ windowAssigner.getWindowSerializer(getExecutionEnvironment().getConfig()),
+ keySel,
+ input.getKeyType().createSerializer(getExecutionEnvironment().getConfig()),
+ stateDesc,
+ new InternalIterableAllWindowFunction<>(
+ new AggregateApplyAllWindowFunction<>(aggregateFunction, windowFunction)),
+ trigger,
+ evictor,
+ allowedLateness);
+
+ } else {
+ AggregatingStateDescriptor<T, ACC, V> stateDesc = new AggregatingStateDescriptor<>(
+ "window-contents",
+ aggregateFunction,
+ accumulatorType.createSerializer(getExecutionEnvironment().getConfig()));
+
+ opName = "TriggerWindow(" + windowAssigner + ", " + stateDesc + ", " + trigger + ", " + udfName + ")";
+
+ operator = new WindowOperator<>(
+ windowAssigner,
+ windowAssigner.getWindowSerializer(getExecutionEnvironment().getConfig()),
+ keySel,
+ input.getKeyType().createSerializer(getExecutionEnvironment().getConfig()),
+ stateDesc,
+ new InternalSingleValueAllWindowFunction<>(windowFunction),
+ trigger,
+ allowedLateness);
+ }
+
+ return input.transform(opName, resultType, operator).forceNonParallel();
+ }
+
+ // ------------------------------------------------------------------------
+ // FoldFunction
+ // ------------------------------------------------------------------------
+
/**
* Applies the given fold function to each window. The window function is called for each
* evaluation of the window for each key individually. The output of the reduce function is
@@ -374,7 +577,7 @@ public class AllWindowedStream<T, W extends Window> {
foldFunction = input.getExecutionEnvironment().clean(foldFunction);
String callLocation = Utils.getCallLocationName();
- String udfName = "WindowedStream." + callLocation;
+ String udfName = "AllWindowedStream." + callLocation;
String opName;
KeySelector<T, Byte> keySel = input.getKeySelector();
@@ -422,6 +625,10 @@ public class AllWindowedStream<T, W extends Window> {
return input.transform(opName, resultType, operator).forceNonParallel();
}
+ // ------------------------------------------------------------------------
+ // Apply (Window Function)
+ // ------------------------------------------------------------------------
+
/**
* Applies the given window function to each window. The window function is called for each
* evaluation of the window for each key individually. The output of the window function is
@@ -460,7 +667,7 @@ public class AllWindowedStream<T, W extends Window> {
function = input.getExecutionEnvironment().clean(function);
String callLocation = Utils.getCallLocationName();
- String udfName = "WindowedStream." + callLocation;
+ String udfName = "AllWindowedStream." + callLocation;
String opName;
KeySelector<T, Byte> keySel = input.getKeySelector();
@@ -557,7 +764,7 @@ public class AllWindowedStream<T, W extends Window> {
reduceFunction = input.getExecutionEnvironment().clean(reduceFunction);
String callLocation = Utils.getCallLocationName();
- String udfName = "WindowedStream." + callLocation;
+ String udfName = "AllWindowedStream." + callLocation;
String opName;
KeySelector<T, Byte> keySel = input.getKeySelector();
@@ -660,7 +867,7 @@ public class AllWindowedStream<T, W extends Window> {
foldFunction = input.getExecutionEnvironment().clean(foldFunction);
String callLocation = Utils.getCallLocationName();
- String udfName = "WindowedStream." + callLocation;
+ String udfName = "AllWindowedStream." + callLocation;
String opName;
KeySelector<T, Byte> keySel = input.getKeySelector();
@@ -709,7 +916,7 @@ public class AllWindowedStream<T, W extends Window> {
}
// ------------------------------------------------------------------------
- // Aggregations on the keyed windows
+ // Aggregations on the non-keyed windows
// ------------------------------------------------------------------------
/**
http://git-wip-us.apache.org/repos/asf/flink/blob/1542260d/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/WindowedStream.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/WindowedStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/WindowedStream.java
index c74bad7..3fbdda8 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/WindowedStream.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/WindowedStream.java
@@ -515,12 +515,16 @@ public class WindowedStream<T, K, W extends Window> {
// ------------------------------------------------------------------------
/**
- * Applies the given fold function to each window. The window function is called for each
- * evaluation of the window for each key individually. The output of the reduce function is
- * interpreted as a regular non-windowed stream.
+ * Applies the given aggregation function to each window. The aggregation function is called for
+ * each element, aggregating values incrementally and keeping the state to one accumulator
+ * per key and window.
*
- * @param function The fold function.
+ * @param function The aggregation function.
* @return The data stream that is the result of applying the fold function to the window.
+ *
+ * @param <ACC> The type of the AggregateFunction's accumulator
+ * @param <R> The type of the elements in the resulting stream, equal to the
+ * AggregateFunction's result type
*/
public <ACC, R> SingleOutputStreamOperator<R> aggregate(AggregateFunction<T, ACC, R> function) {
checkNotNull(function, "function");
@@ -545,6 +549,10 @@ public class WindowedStream<T, K, W extends Window> {
*
* @param function The aggregation function.
* @return The data stream that is the result of applying the aggregation function to the window.
+ *
+ * @param <ACC> The type of the AggregateFunction's accumulator
+ * @param <R> The type of the elements in the resulting stream, equal to the
+ * AggregateFunction's result type
*/
public <ACC, R> SingleOutputStreamOperator<R> aggregate(
AggregateFunction<T, ACC, R> function,
@@ -559,7 +567,8 @@ public class WindowedStream<T, K, W extends Window> {
throw new UnsupportedOperationException("This aggregation function cannot be a RichFunction.");
}
- return aggregate(function, new PassThroughWindowFunction<K, W, R>(), accumulatorType, resultType);
+ return aggregate(function, new PassThroughWindowFunction<K, W, R>(),
+ accumulatorType, resultType, resultType);
}
/**
@@ -574,10 +583,15 @@ public class WindowedStream<T, K, W extends Window> {
* @param windowFunction The window function.
*
* @return The data stream that is the result of applying the window function to the window.
+ *
+ * @param <ACC> The type of the AggregateFunction's accumulator
+ * @param <V> The type of AggregateFunction's result, and the WindowFunction's input
+ * @param <R> The type of the elements in the resulting stream, equal to the
+ * WindowFunction's result type
*/
- public <ACC, R> SingleOutputStreamOperator<R> aggregate(
- AggregateFunction<T, ACC, R> aggFunction,
- WindowFunction<R, R, K, W> windowFunction) {
+ public <ACC, V, R> SingleOutputStreamOperator<R> aggregate(
+ AggregateFunction<T, ACC, V> aggFunction,
+ WindowFunction<V, R, K, W> windowFunction) {
checkNotNull(aggFunction, "aggFunction");
checkNotNull(windowFunction, "windowFunction");
@@ -585,10 +599,13 @@ public class WindowedStream<T, K, W extends Window> {
TypeInformation<ACC> accumulatorType = TypeExtractor.getAggregateFunctionAccumulatorType(
aggFunction, input.getType(), null, false);
- TypeInformation<R> resultType = TypeExtractor.getAggregateFunctionReturnType(
+ TypeInformation<V> aggResultType = TypeExtractor.getAggregateFunctionReturnType(
aggFunction, input.getType(), null, false);
- return aggregate(aggFunction, windowFunction, accumulatorType, resultType);
+ TypeInformation<R> resultType = TypeExtractor.getUnaryOperatorReturnType(
+ windowFunction, WindowFunction.class, true, true, aggResultType, null, false);
+
+ return aggregate(aggFunction, windowFunction, accumulatorType, aggResultType, resultType);
}
/**
@@ -605,16 +622,23 @@ public class WindowedStream<T, K, W extends Window> {
* @param resultType Type information for the result type of the window function
*
* @return The data stream that is the result of applying the window function to the window.
+ *
+ * @param <ACC> The type of the AggregateFunction's accumulator
+ * @param <V> The type of AggregateFunction's result, and the WindowFunction's input
+ * @param <R> The type of the elements in the resulting stream, equal to the
+ * WindowFunction's result type
*/
- public <ACC, R> SingleOutputStreamOperator<R> aggregate(
- AggregateFunction<T, ACC, R> aggregateFunction,
- WindowFunction<R, R, K, W> windowFunction,
+ public <ACC, V, R> SingleOutputStreamOperator<R> aggregate(
+ AggregateFunction<T, ACC, V> aggregateFunction,
+ WindowFunction<V, R, K, W> windowFunction,
TypeInformation<ACC> accumulatorType,
+ TypeInformation<V> aggregateResultType,
TypeInformation<R> resultType) {
checkNotNull(aggregateFunction, "aggregateFunction");
checkNotNull(windowFunction, "windowFunction");
checkNotNull(accumulatorType, "accumulatorType");
+ checkNotNull(aggregateResultType, "aggregateResultType");
checkNotNull(resultType, "resultType");
if (aggregateFunction instanceof RichFunction) {
@@ -654,7 +678,7 @@ public class WindowedStream<T, K, W extends Window> {
allowedLateness);
} else {
- AggregatingStateDescriptor<T, ACC, R> stateDesc = new AggregatingStateDescriptor<>("window-contents",
+ AggregatingStateDescriptor<T, ACC, V> stateDesc = new AggregatingStateDescriptor<>("window-contents",
aggregateFunction, accumulatorType.createSerializer(getExecutionEnvironment().getConfig()));
opName = "TriggerWindow(" + windowAssigner + ", " + stateDesc + ", " + trigger + ", " + udfName + ")";
http://git-wip-us.apache.org/repos/asf/flink/blob/1542260d/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/windowing/AggregateApplyAllWindowFunction.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/windowing/AggregateApplyAllWindowFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/windowing/AggregateApplyAllWindowFunction.java
index 1b9fa88..929e336 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/windowing/AggregateApplyAllWindowFunction.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/windowing/AggregateApplyAllWindowFunction.java
@@ -26,18 +26,29 @@ import org.apache.flink.util.Collector;
import java.util.Collections;
+/**
+ * An {@link AllWindowFunction} that composes an {@link AggregateFunction} and an {@link AllWindowFunction}.
+ * Upon invocation, this first applies {@code AggregateFunction} to the input, and then
+ * finally the {@code AllWindowFunction} to the single result element.
+ *
+ * @param <W> The window type
+ * @param <T> The type of the input to the AggregateFunction
+ * @param <ACC> The type of the AggregateFunction's accumulator
+ * @param <V> The type of the AggregateFunction's result, and the input to the AllWindowFunction
+ * @param <R> The result type of the AllWindowFunction
+ */
@Internal
-public class AggregateApplyAllWindowFunction<W extends Window, T, ACC, R>
- extends WrappingFunction<AllWindowFunction<R, R, W>>
+public class AggregateApplyAllWindowFunction<W extends Window, T, ACC, V, R>
+ extends WrappingFunction<AllWindowFunction<V, R, W>>
implements AllWindowFunction<T, R, W> {
private static final long serialVersionUID = 1L;
- private final AggregateFunction<T, ACC, R> aggFunction;
+ private final AggregateFunction<T, ACC, V> aggFunction;
public AggregateApplyAllWindowFunction(
- AggregateFunction<T, ACC, R> aggFunction,
- AllWindowFunction<R, R, W> windowFunction) {
+ AggregateFunction<T, ACC, V> aggFunction,
+ AllWindowFunction<V, R, W> windowFunction) {
super(windowFunction);
this.aggFunction = aggFunction;
http://git-wip-us.apache.org/repos/asf/flink/blob/1542260d/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/windowing/AggregateApplyWindowFunction.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/windowing/AggregateApplyWindowFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/windowing/AggregateApplyWindowFunction.java
index 5200bc2..73e1f0f 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/windowing/AggregateApplyWindowFunction.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/windowing/AggregateApplyWindowFunction.java
@@ -25,16 +25,28 @@ import org.apache.flink.util.Collector;
import java.util.Collections;
+/**
+ * A {@link WindowFunction} that composes an {@link AggregateFunction} and {@link WindowFunction}.
+ * Upon invocation, this first applies {@code AggregateFunction} to the input, and then
+ * finally the {@code WindowFunction} to the single result element.
+ *
+ * @param <K> The key type
+ * @param <W> The window type
+ * @param <T> The type of the input to the AggregateFunction
+ * @param <ACC> The type of the AggregateFunction's accumulator
+ * @param <V> The type of the AggregateFunction's result, and the input to the WindowFunction
+ * @param <R> The result type of the WindowFunction
+ */
@Internal
-public class AggregateApplyWindowFunction<K, W extends Window, T, ACC, R>
- extends WrappingFunction<WindowFunction<R, R, K, W>>
+public class AggregateApplyWindowFunction<K, W extends Window, T, ACC, V, R>
+ extends WrappingFunction<WindowFunction<V, R, K, W>>
implements WindowFunction<T, R, K, W> {
private static final long serialVersionUID = 1L;
- private final AggregateFunction<T, ACC, R> aggFunction;
+ private final AggregateFunction<T, ACC, V> aggFunction;
- public AggregateApplyWindowFunction(AggregateFunction<T, ACC, R> aggFunction, WindowFunction<R, R, K, W> windowFunction) {
+ public AggregateApplyWindowFunction(AggregateFunction<T, ACC, V> aggFunction, WindowFunction<V, R, K, W> windowFunction) {
super(windowFunction);
this.aggFunction = aggFunction;
}
http://git-wip-us.apache.org/repos/asf/flink/blob/1542260d/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AllWindowTranslationTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AllWindowTranslationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AllWindowTranslationTest.java
index 3d4de5d..b6c1618 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AllWindowTranslationTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AllWindowTranslationTest.java
@@ -18,10 +18,13 @@
package org.apache.flink.streaming.runtime.operators.windowing;
import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.FoldFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichAggregateFunction;
import org.apache.flink.api.common.functions.RichFoldFunction;
import org.apache.flink.api.common.functions.RichReduceFunction;
+import org.apache.flink.api.common.state.AggregatingStateDescriptor;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
@@ -74,6 +77,10 @@ import static org.junit.Assert.fail;
@SuppressWarnings("serial")
public class AllWindowTranslationTest {
+ // ------------------------------------------------------------------------
+ // rich function tests
+ // ------------------------------------------------------------------------
+
/**
* .reduce() does not support RichReduceFunction, since the reduce function is used internally
* in a {@code ReducingState}.
@@ -101,6 +108,24 @@ public class AllWindowTranslationTest {
}
/**
+ * .aggregate() does not support RichAggregateFunction, since the AggregateFunction is used internally
+ * in an {@code AggregatingState}.
+ */
+ @Test(expected = UnsupportedOperationException.class)
+ public void testAggregateWithRichFunctionFails() throws Exception {
+ StreamExecutionEnvironment environment = StreamExecutionEnvironment.getExecutionEnvironment();
+ environment.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+ DataStream<Tuple2<String, Integer>> input =
+ environment.fromElements(Tuple2.of("hello", 1), Tuple2.of("hello", 2));
+
+ // declaring the aggregation with a rich function must already throw
+ input
+ .windowAll(SlidingEventTimeWindows.of(Time.of(1, TimeUnit.SECONDS), Time.of(100, TimeUnit.MILLISECONDS)))
+ .aggregate(new DummyRichAggregationFunction<Tuple2<String, Integer>>());
+ fail("exception was not thrown");
+ }
+
+ /**
* .fold() does not support RichFoldFunction, since the fold function is used internally
* in a {@code FoldingState}.
*/
@@ -126,6 +151,9 @@ public class AllWindowTranslationTest {
fail("exception was not thrown");
}
+ // ------------------------------------------------------------------------
+ // Merging Windows Support
+ // ------------------------------------------------------------------------
@Test
public void testSessionWithFoldFails() throws Exception {
@@ -206,6 +234,10 @@ public class AllWindowTranslationTest {
fail("The trigger call should fail.");
}
+ // ------------------------------------------------------------------------
+ // reduce() translation tests
+ // ------------------------------------------------------------------------
+
@Test
@SuppressWarnings("rawtypes")
public void testReduceEventTime() throws Exception {
@@ -392,6 +424,126 @@ public class AllWindowTranslationTest {
processElementAndEnsureOutput(operator, winOperator.getKeySelector(), BasicTypeInfo.STRING_TYPE_INFO, new Tuple2<>("hello", 1));
}
+ // ------------------------------------------------------------------------
+ // aggregate() translation tests
+ // ------------------------------------------------------------------------
+
+ @Test
+ public void testAggregateEventTime() throws Exception {
+ // build a non-keyed, event-time sliding window aggregation
+ StreamExecutionEnvironment environment = StreamExecutionEnvironment.getExecutionEnvironment();
+ environment.setStreamTimeCharacteristic(TimeCharacteristic.IngestionTime);
+ DataStream<Tuple2<String, Integer>> input =
+ environment.fromElements(Tuple2.of("hello", 1), Tuple2.of("hello", 2));
+
+ DataStream<Tuple2<String, Integer>> aggregated = input
+ .windowAll(SlidingEventTimeWindows.of(Time.of(1, TimeUnit.SECONDS), Time.of(100, TimeUnit.MILLISECONDS)))
+ .aggregate(new DummyAggregationFunction());
+
+ OneInputTransformation<Tuple2<String, Integer>, Tuple2<String, Integer>> transformation =
+ (OneInputTransformation<Tuple2<String, Integer>, Tuple2<String, Integer>>) aggregated.getTransformation();
+ OneInputStreamOperator<Tuple2<String, Integer>, Tuple2<String, Integer>> streamOperator =
+ transformation.getOperator();
+ Assert.assertTrue(streamOperator instanceof WindowOperator);
+
+ // the aggregate function must be translated into an AggregatingState descriptor
+ WindowOperator<String, Tuple2<String, Integer>, ?, ?, ?> windowOperator =
+ (WindowOperator<String, Tuple2<String, Integer>, ?, ?, ?>) streamOperator;
+ Assert.assertTrue(windowOperator.getStateDescriptor() instanceof AggregatingStateDescriptor);
+ Assert.assertTrue(windowOperator.getWindowAssigner() instanceof SlidingEventTimeWindows);
+ Assert.assertTrue(windowOperator.getTrigger() instanceof EventTimeTrigger);
+
+ processElementAndEnsureOutput(windowOperator, windowOperator.getKeySelector(), BasicTypeInfo.STRING_TYPE_INFO, new Tuple2<>("hello", 1));
+ }
+
+ @Test
+ public void testAggregateProcessingTime() throws Exception {
+ // build a non-keyed, processing-time sliding window aggregation
+ StreamExecutionEnvironment environment = StreamExecutionEnvironment.getExecutionEnvironment();
+ environment.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+ DataStream<Tuple2<String, Integer>> input =
+ environment.fromElements(Tuple2.of("hello", 1), Tuple2.of("hello", 2));
+
+ DataStream<Tuple2<String, Integer>> aggregated = input
+ .windowAll(SlidingProcessingTimeWindows.of(Time.of(1, TimeUnit.SECONDS), Time.of(100, TimeUnit.MILLISECONDS)))
+ .aggregate(new DummyAggregationFunction());
+
+ OneInputTransformation<Tuple2<String, Integer>, Tuple2<String, Integer>> transformation =
+ (OneInputTransformation<Tuple2<String, Integer>, Tuple2<String, Integer>>) aggregated.getTransformation();
+ OneInputStreamOperator<Tuple2<String, Integer>, Tuple2<String, Integer>> streamOperator =
+ transformation.getOperator();
+ Assert.assertTrue(streamOperator instanceof WindowOperator);
+
+ // the aggregate function must be translated into an AggregatingState descriptor
+ WindowOperator<String, Tuple2<String, Integer>, ?, ?, ?> windowOperator =
+ (WindowOperator<String, Tuple2<String, Integer>, ?, ?, ?>) streamOperator;
+ Assert.assertTrue(windowOperator.getStateDescriptor() instanceof AggregatingStateDescriptor);
+ Assert.assertTrue(windowOperator.getWindowAssigner() instanceof SlidingProcessingTimeWindows);
+ Assert.assertTrue(windowOperator.getTrigger() instanceof ProcessingTimeTrigger);
+
+ processElementAndEnsureOutput(windowOperator, windowOperator.getKeySelector(), BasicTypeInfo.STRING_TYPE_INFO, new Tuple2<>("hello", 1));
+ }
+
+ @Test
+ public void testAggregateWithWindowFunctionEventTime() throws Exception {
+ // non-keyed aggregate followed by a window function that changes the output type
+ StreamExecutionEnvironment environment = StreamExecutionEnvironment.getExecutionEnvironment();
+ environment.setStreamTimeCharacteristic(TimeCharacteristic.IngestionTime);
+ DataStream<Tuple2<String, Integer>> input =
+ environment.fromElements(Tuple2.of("hello", 1), Tuple2.of("hello", 2));
+
+ DataStream<Tuple3<String, String, Integer>> result = input
+ .windowAll(TumblingEventTimeWindows.of(Time.of(1, TimeUnit.SECONDS)))
+ .aggregate(new DummyAggregationFunction(), new TestAllWindowFunction());
+
+ OneInputTransformation<Tuple2<String, Integer>, Tuple3<String, String, Integer>> transformation =
+ (OneInputTransformation<Tuple2<String, Integer>, Tuple3<String, String, Integer>>) result.getTransformation();
+ OneInputStreamOperator<Tuple2<String, Integer>, Tuple3<String, String, Integer>> streamOperator =
+ transformation.getOperator();
+ Assert.assertTrue(streamOperator instanceof WindowOperator);
+
+ // pre-aggregation must still be backed by an AggregatingState descriptor
+ WindowOperator<String, Tuple2<String, Integer>, ?, ?, ?> windowOperator =
+ (WindowOperator<String, Tuple2<String, Integer>, ?, ?, ?>) streamOperator;
+ Assert.assertTrue(windowOperator.getStateDescriptor() instanceof AggregatingStateDescriptor);
+ Assert.assertTrue(windowOperator.getWindowAssigner() instanceof TumblingEventTimeWindows);
+ Assert.assertTrue(windowOperator.getTrigger() instanceof EventTimeTrigger);
+
+ processElementAndEnsureOutput(streamOperator, windowOperator.getKeySelector(), BasicTypeInfo.STRING_TYPE_INFO, new Tuple2<>("hello", 1));
+ }
+
+ @Test
+ public void testAggregateWithWindowFunctionProcessingTime() throws Exception {
+ // non-keyed aggregate followed by a window function that changes the output type
+ StreamExecutionEnvironment environment = StreamExecutionEnvironment.getExecutionEnvironment();
+ environment.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+ DataStream<Tuple2<String, Integer>> input =
+ environment.fromElements(Tuple2.of("hello", 1), Tuple2.of("hello", 2));
+
+ DataStream<Tuple3<String, String, Integer>> result = input
+ .windowAll(TumblingProcessingTimeWindows.of(Time.of(1, TimeUnit.SECONDS)))
+ .aggregate(new DummyAggregationFunction(), new TestAllWindowFunction());
+
+ OneInputTransformation<Tuple2<String, Integer>, Tuple3<String, String, Integer>> transformation =
+ (OneInputTransformation<Tuple2<String, Integer>, Tuple3<String, String, Integer>>) result.getTransformation();
+ OneInputStreamOperator<Tuple2<String, Integer>, Tuple3<String, String, Integer>> streamOperator =
+ transformation.getOperator();
+ Assert.assertTrue(streamOperator instanceof WindowOperator);
+
+ // pre-aggregation must still be backed by an AggregatingState descriptor
+ WindowOperator<String, Tuple2<String, Integer>, ?, ?, ?> windowOperator =
+ (WindowOperator<String, Tuple2<String, Integer>, ?, ?, ?>) streamOperator;
+ Assert.assertTrue(windowOperator.getStateDescriptor() instanceof AggregatingStateDescriptor);
+ Assert.assertTrue(windowOperator.getWindowAssigner() instanceof TumblingProcessingTimeWindows);
+ Assert.assertTrue(windowOperator.getTrigger() instanceof ProcessingTimeTrigger);
+
+ processElementAndEnsureOutput(streamOperator, windowOperator.getKeySelector(), BasicTypeInfo.STRING_TYPE_INFO, new Tuple2<>("hello", 1));
+ }
+
+ // ------------------------------------------------------------------------
+ // fold() translation tests
+ // ------------------------------------------------------------------------
+
@Test
@SuppressWarnings("rawtypes")
public void testFoldEventTime() throws Exception {
@@ -548,6 +700,9 @@ public class AllWindowTranslationTest {
processElementAndEnsureOutput(winOperator, winOperator.getKeySelector(), BasicTypeInfo.STRING_TYPE_INFO, new Tuple2<>("hello", 1));
}
+ // ------------------------------------------------------------------------
+ // apply() translation tests
+ // ------------------------------------------------------------------------
@Test
@SuppressWarnings("rawtypes")
@@ -854,4 +1009,65 @@ public class AllWindowTranslationTest {
return accumulator;
}
}
+
+ /** Test aggregate function: the accumulator holds the last added element; merge keeps the first accumulator. */
+ private static class DummyAggregationFunction
+ implements AggregateFunction<Tuple2<String, Integer>, Tuple2<String, Integer>, Tuple2<String, Integer>> {
+ @Override
+ public Tuple2<String, Integer> createAccumulator() {
+ return Tuple2.of("", 0);
+ }
+
+ @Override
+ public void add(Tuple2<String, Integer> value, Tuple2<String, Integer> accumulator) {
+ // overwrite the accumulator in place with the incoming element
+ accumulator.setFields(value.f0, value.f1);
+ }
+
+ @Override
+ public Tuple2<String, Integer> getResult(Tuple2<String, Integer> accumulator) {
+ return accumulator;
+ }
+
+ @Override
+ public Tuple2<String, Integer> merge(Tuple2<String, Integer> a, Tuple2<String, Integer> b) {
+ return a;
+ }
+ }
+
+ /** Minimal RichAggregateFunction that does no work; lets tests pass a rich function to aggregate(). */
+ private static class DummyRichAggregationFunction<T> extends RichAggregateFunction<T, T, T> {
+
+ @Override
+ public T createAccumulator() {
+ return null;
+ }
+ @Override
+ public void add(T value, T accumulator) {
+ // intentionally a no-op
+ }
+ @Override
+ public T getResult(T accumulator) {
+ return accumulator;
+ }
+ @Override
+ public T merge(T a, T b) {
+ return a;
+ }
+ }
+
+ /** Emits every window element as (f0, f0, f1), duplicating the String field; the window is ignored. */
+ private static class TestAllWindowFunction
+ implements AllWindowFunction<Tuple2<String, Integer>, Tuple3<String, String, Integer>, TimeWindow> {
+
+ @Override
+ public void apply(TimeWindow window, Iterable<Tuple2<String, Integer>> values,
+ Collector<Tuple3<String, String, Integer>> out) throws Exception {
+ for (Tuple2<String, Integer> element : values) {
+ String name = element.f0;
+ Integer count = element.f1;
+ out.collect(new Tuple3<>(name, name, count));
+ }
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/1542260d/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowTranslationTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowTranslationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowTranslationTest.java
index 492d275..f72a2f1 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowTranslationTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowTranslationTest.java
@@ -18,10 +18,13 @@
package org.apache.flink.streaming.runtime.operators.windowing;
import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.FoldFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichAggregateFunction;
import org.apache.flink.api.common.functions.RichFoldFunction;
import org.apache.flink.api.common.functions.RichReduceFunction;
+import org.apache.flink.api.common.state.AggregatingStateDescriptor;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
@@ -74,6 +77,10 @@ import static org.junit.Assert.fail;
@SuppressWarnings("serial")
public class WindowTranslationTest {
+ // ------------------------------------------------------------------------
+ // Rich Pre-Aggregation Functions
+ // ------------------------------------------------------------------------
+
/**
* .reduce() does not support RichReduceFunction, since the reduce function is used internally
* in a {@code ReducingState}.
@@ -89,7 +96,6 @@ public class WindowTranslationTest {
.keyBy(0)
.window(SlidingEventTimeWindows.of(Time.of(1, TimeUnit.SECONDS), Time.of(100, TimeUnit.MILLISECONDS)))
.reduce(new RichReduceFunction<Tuple2<String, Integer>>() {
- private static final long serialVersionUID = -6448847205314995812L;
@Override
public Tuple2<String, Integer> reduce(Tuple2<String, Integer> value1,
@@ -102,6 +108,25 @@ public class WindowTranslationTest {
}
/**
+ * .aggregate() does not support RichAggregateFunction, since the AggregateFunction is used internally
+ * in an {@code AggregatingState}.
+ */
+ @Test(expected = UnsupportedOperationException.class)
+ public void testAggregateWithRichFunctionFails() throws Exception {
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+
+ DataStream<Tuple2<String, Integer>> source = env.fromElements(Tuple2.of("hello", 1), Tuple2.of("hello", 2));
+ env.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+
+ source
+ .keyBy(0)
+ .window(SlidingEventTimeWindows.of(Time.of(1, TimeUnit.SECONDS), Time.of(100, TimeUnit.MILLISECONDS)))
+ .aggregate(new DummyRichAggregationFunction<Tuple2<String, Integer>>());
+ // declaring the aggregation with a rich function must already have thrown
+ fail("exception was not thrown");
+ }
+
+ /**
* .fold() does not support RichFoldFunction, since the fold function is used internally
* in a {@code FoldingState}.
*/
@@ -116,7 +141,6 @@ public class WindowTranslationTest {
.keyBy(0)
.window(SlidingEventTimeWindows.of(Time.of(1, TimeUnit.SECONDS), Time.of(100, TimeUnit.MILLISECONDS)))
.fold(new Tuple2<>("", 0), new RichFoldFunction<Tuple2<String, Integer>, Tuple2<String, Integer>>() {
- private static final long serialVersionUID = -6448847205314995812L;
@Override
public Tuple2<String, Integer> fold(Tuple2<String, Integer> value1,
@@ -128,6 +152,9 @@ public class WindowTranslationTest {
fail("exception was not thrown");
}
+ // ------------------------------------------------------------------------
+ // Merging Windows Support
+ // ------------------------------------------------------------------------
@Test
public void testSessionWithFoldFails() throws Exception {
@@ -137,7 +164,6 @@ public class WindowTranslationTest {
WindowedStream<String, String, TimeWindow> windowedStream = env.fromElements("Hello", "Ciao")
.keyBy(new KeySelector<String, String>() {
- private static final long serialVersionUID = -3298887124448443076L;
@Override
public String getKey(String value) throws Exception {
@@ -224,6 +250,11 @@ public class WindowTranslationTest {
fail("The trigger call should fail.");
}
+
+ // ------------------------------------------------------------------------
+ // Reduce Translation Tests
+ // ------------------------------------------------------------------------
+
@Test
@SuppressWarnings("rawtypes")
public void testReduceEventTime() throws Exception {
@@ -416,6 +447,132 @@ public class WindowTranslationTest {
processElementAndEnsureOutput(operator, winOperator.getKeySelector(), BasicTypeInfo.STRING_TYPE_INFO, new Tuple2<>("hello", 1));
}
+ // ------------------------------------------------------------------------
+ // Aggregate Translation Tests
+ // ------------------------------------------------------------------------
+
+ @Test
+ public void testAggregateEventTime() throws Exception {
+ // build a keyed, event-time sliding window aggregation
+ StreamExecutionEnvironment environment = StreamExecutionEnvironment.getExecutionEnvironment();
+ environment.setStreamTimeCharacteristic(TimeCharacteristic.IngestionTime);
+ DataStream<Tuple2<String, Integer>> input =
+ environment.fromElements(Tuple2.of("hello", 1), Tuple2.of("hello", 2));
+
+ DataStream<Tuple2<String, Integer>> aggregated = input
+ .keyBy(new TupleKeySelector())
+ .window(SlidingEventTimeWindows.of(Time.of(1, TimeUnit.SECONDS), Time.of(100, TimeUnit.MILLISECONDS)))
+ .aggregate(new DummyAggregationFunction());
+
+ OneInputTransformation<Tuple2<String, Integer>, Tuple2<String, Integer>> transformation =
+ (OneInputTransformation<Tuple2<String, Integer>, Tuple2<String, Integer>>) aggregated.getTransformation();
+ OneInputStreamOperator<Tuple2<String, Integer>, Tuple2<String, Integer>> streamOperator =
+ transformation.getOperator();
+ Assert.assertTrue(streamOperator instanceof WindowOperator);
+
+ // the aggregate function must be translated into an AggregatingState descriptor
+ WindowOperator<String, Tuple2<String, Integer>, ?, ?, ?> windowOperator =
+ (WindowOperator<String, Tuple2<String, Integer>, ?, ?, ?>) streamOperator;
+ Assert.assertTrue(windowOperator.getStateDescriptor() instanceof AggregatingStateDescriptor);
+ Assert.assertTrue(windowOperator.getWindowAssigner() instanceof SlidingEventTimeWindows);
+ Assert.assertTrue(windowOperator.getTrigger() instanceof EventTimeTrigger);
+
+ processElementAndEnsureOutput(windowOperator, windowOperator.getKeySelector(), BasicTypeInfo.STRING_TYPE_INFO, new Tuple2<>("hello", 1));
+ }
+
+ @Test
+ public void testAggregateProcessingTime() throws Exception {
+ // build a keyed, processing-time sliding window aggregation
+ StreamExecutionEnvironment environment = StreamExecutionEnvironment.getExecutionEnvironment();
+ environment.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+ DataStream<Tuple2<String, Integer>> input =
+ environment.fromElements(Tuple2.of("hello", 1), Tuple2.of("hello", 2));
+
+ DataStream<Tuple2<String, Integer>> aggregated = input
+ .keyBy(new TupleKeySelector())
+ .window(SlidingProcessingTimeWindows.of(Time.of(1, TimeUnit.SECONDS), Time.of(100, TimeUnit.MILLISECONDS)))
+ .aggregate(new DummyAggregationFunction());
+
+ OneInputTransformation<Tuple2<String, Integer>, Tuple2<String, Integer>> transformation =
+ (OneInputTransformation<Tuple2<String, Integer>, Tuple2<String, Integer>>) aggregated.getTransformation();
+ OneInputStreamOperator<Tuple2<String, Integer>, Tuple2<String, Integer>> streamOperator =
+ transformation.getOperator();
+ Assert.assertTrue(streamOperator instanceof WindowOperator);
+
+ // the aggregate function must be translated into an AggregatingState descriptor
+ WindowOperator<String, Tuple2<String, Integer>, ?, ?, ?> windowOperator =
+ (WindowOperator<String, Tuple2<String, Integer>, ?, ?, ?>) streamOperator;
+ Assert.assertTrue(windowOperator.getStateDescriptor() instanceof AggregatingStateDescriptor);
+ Assert.assertTrue(windowOperator.getWindowAssigner() instanceof SlidingProcessingTimeWindows);
+ Assert.assertTrue(windowOperator.getTrigger() instanceof ProcessingTimeTrigger);
+
+ processElementAndEnsureOutput(windowOperator, windowOperator.getKeySelector(), BasicTypeInfo.STRING_TYPE_INFO, new Tuple2<>("hello", 1));
+ }
+
+ @Test
+ public void testAggregateWithWindowFunctionEventTime() throws Exception {
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+ env.setStreamTimeCharacteristic(TimeCharacteristic.IngestionTime);
+
+ DataStream<Tuple2<String, Integer>> source = env.fromElements(Tuple2.of("hello", 1), Tuple2.of("hello", 2));
+
+ // DummyAggregationFunction produces Tuple2 results that TestWindowFunction turns into Tuple3,
+ // so the aggregate's result type differs from the window function's output type
+ DataStream<Tuple3<String, String, Integer>> window = source
+ .keyBy(new TupleKeySelector())
+ .window(TumblingEventTimeWindows.of(Time.of(1, TimeUnit.SECONDS)))
+ .aggregate(new DummyAggregationFunction(), new TestWindowFunction());
+
+ OneInputTransformation<Tuple2<String, Integer>, Tuple3<String, String, Integer>> transform =
+ (OneInputTransformation<Tuple2<String, Integer>, Tuple3<String, String, Integer>>) window.getTransformation();
+
+ OneInputStreamOperator<Tuple2<String, Integer>, Tuple3<String, String, Integer>> operator = transform.getOperator();
+
+ Assert.assertTrue(operator instanceof WindowOperator);
+ WindowOperator<String, Tuple2<String, Integer>, ?, ?, ?> winOperator =
+ (WindowOperator<String, Tuple2<String, Integer>, ?, ?, ?>) operator;
+
+ Assert.assertTrue(winOperator.getTrigger() instanceof EventTimeTrigger);
+ Assert.assertTrue(winOperator.getWindowAssigner() instanceof TumblingEventTimeWindows);
+ Assert.assertTrue(winOperator.getStateDescriptor() instanceof AggregatingStateDescriptor);
+
+ processElementAndEnsureOutput(
+ operator, winOperator.getKeySelector(), BasicTypeInfo.STRING_TYPE_INFO, new Tuple2<>("hello", 1));
+ }
+
+ @Test
+ public void testAggregateWithWindowFunctionProcessingTime() throws Exception {
+ // keyed aggregate followed by a window function that changes the output type
+ StreamExecutionEnvironment environment = StreamExecutionEnvironment.getExecutionEnvironment();
+ environment.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+ DataStream<Tuple2<String, Integer>> input =
+ environment.fromElements(Tuple2.of("hello", 1), Tuple2.of("hello", 2));
+
+ DataStream<Tuple3<String, String, Integer>> result = input
+ .keyBy(new TupleKeySelector())
+ .window(TumblingProcessingTimeWindows.of(Time.of(1, TimeUnit.SECONDS)))
+ .aggregate(new DummyAggregationFunction(), new TestWindowFunction());
+
+ OneInputTransformation<Tuple2<String, Integer>, Tuple3<String, String, Integer>> transformation =
+ (OneInputTransformation<Tuple2<String, Integer>, Tuple3<String, String, Integer>>) result.getTransformation();
+ OneInputStreamOperator<Tuple2<String, Integer>, Tuple3<String, String, Integer>> streamOperator =
+ transformation.getOperator();
+ Assert.assertTrue(streamOperator instanceof WindowOperator);
+
+ // pre-aggregation must still be backed by an AggregatingState descriptor
+ WindowOperator<String, Tuple2<String, Integer>, ?, ?, ?> windowOperator =
+ (WindowOperator<String, Tuple2<String, Integer>, ?, ?, ?>) streamOperator;
+ Assert.assertTrue(windowOperator.getStateDescriptor() instanceof AggregatingStateDescriptor);
+ Assert.assertTrue(windowOperator.getWindowAssigner() instanceof TumblingProcessingTimeWindows);
+ Assert.assertTrue(windowOperator.getTrigger() instanceof ProcessingTimeTrigger);
+
+ processElementAndEnsureOutput(streamOperator, windowOperator.getKeySelector(), BasicTypeInfo.STRING_TYPE_INFO, new Tuple2<>("hello", 1));
+ }
+
+ // ------------------------------------------------------------------------
+ // Fold Translation Tests
+ // ------------------------------------------------------------------------
+
@Test
@SuppressWarnings("rawtypes")
public void testFoldEventTime() throws Exception {
@@ -577,6 +734,10 @@ public class WindowTranslationTest {
processElementAndEnsureOutput(winOperator, winOperator.getKeySelector(), BasicTypeInfo.STRING_TYPE_INFO, new Tuple2<>("hello", 1));
}
+ // ------------------------------------------------------------------------
+ // Apply Translation Tests
+ // ------------------------------------------------------------------------
+
@Test
@SuppressWarnings("rawtypes")
public void testApplyEventTime() throws Exception {
@@ -873,7 +1034,6 @@ public class WindowTranslationTest {
// ------------------------------------------------------------------------
public static class DummyReducer implements ReduceFunction<Tuple2<String, Integer>> {
- private static final long serialVersionUID = 1L;
@Override
public Tuple2<String, Integer> reduce(Tuple2<String, Integer> value1, Tuple2<String, Integer> value2) throws Exception {
@@ -890,13 +1050,72 @@ public class WindowTranslationTest {
}
}
+ /** Test aggregate function: the accumulator holds the last added element; merge keeps the first accumulator. */
+ private static class DummyAggregationFunction
+ implements AggregateFunction<Tuple2<String, Integer>, Tuple2<String, Integer>, Tuple2<String, Integer>> {
+ @Override
+ public Tuple2<String, Integer> createAccumulator() {
+ return Tuple2.of("", 0);
+ }
+
+ @Override
+ public void add(Tuple2<String, Integer> value, Tuple2<String, Integer> accumulator) {
+ // overwrite the accumulator in place with the incoming element
+ accumulator.setFields(value.f0, value.f1);
+ }
+
+ @Override
+ public Tuple2<String, Integer> getResult(Tuple2<String, Integer> accumulator) {
+ return accumulator;
+ }
+
+ @Override
+ public Tuple2<String, Integer> merge(Tuple2<String, Integer> a, Tuple2<String, Integer> b) {
+ return a;
+ }
+ }
+
+ /** Minimal RichAggregateFunction that does no work; lets tests pass a rich function to aggregate(). */
+ private static class DummyRichAggregationFunction<T> extends RichAggregateFunction<T, T, T> {
+
+ @Override
+ public T createAccumulator() {
+ return null;
+ }
+ @Override
+ public void add(T value, T accumulator) {
+ // intentionally a no-op
+ }
+ @Override
+ public T getResult(T accumulator) {
+ return accumulator;
+ }
+ @Override
+ public T merge(T a, T b) {
+ return a;
+ }
+ }
+
+ /** Emits every window element as (f0, f0, f1), duplicating the String field; key and window are ignored. */
+ private static class TestWindowFunction
+ implements WindowFunction<Tuple2<String, Integer>, Tuple3<String, String, Integer>, String, TimeWindow> {
+
+ @Override
+ public void apply(String key, TimeWindow window, Iterable<Tuple2<String, Integer>> values,
+ Collector<Tuple3<String, String, Integer>> out) throws Exception {
+ for (Tuple2<String, Integer> element : values) {
+ String name = element.f0;
+ Integer count = element.f1;
+ out.collect(new Tuple3<>(name, name, count));
+ }
+ }
+ }
+
private static class TupleKeySelector implements KeySelector<Tuple2<String, Integer>, String> {
- private static final long serialVersionUID = 1L;
@Override
public String getKey(Tuple2<String, Integer> value) throws Exception {
return value.f0;
}
}
-
}
http://git-wip-us.apache.org/repos/asf/flink/blob/1542260d/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/AllWindowedStream.scala
----------------------------------------------------------------------
diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/AllWindowedStream.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/AllWindowedStream.scala
index 324689a..7f52252 100644
--- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/AllWindowedStream.scala
+++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/AllWindowedStream.scala
@@ -19,7 +19,7 @@
package org.apache.flink.streaming.api.scala
import org.apache.flink.annotation.{PublicEvolving, Public}
-import org.apache.flink.api.common.functions.{FoldFunction, ReduceFunction}
+import org.apache.flink.api.common.functions.{AggregateFunction, FoldFunction, ReduceFunction}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.streaming.api.datastream.{AllWindowedStream => JavaAllWStream}
import org.apache.flink.streaming.api.functions.aggregation.AggregationFunction.AggregationType
@@ -32,6 +32,8 @@ import org.apache.flink.streaming.api.windowing.triggers.Trigger
import org.apache.flink.streaming.api.windowing.windows.Window
import org.apache.flink.util.Collector
+import org.apache.flink.util.Preconditions.checkNotNull
+
/**
* A [[AllWindowedStream]] represents a data stream where the stream of
* elements is split into windows based on a
@@ -92,9 +94,11 @@ class AllWindowedStream[T, W <: Window](javaStream: JavaAllWStream[T, W]) {
}
// ------------------------------------------------------------------------
- // Operations on the keyed windows
+ // Operations on the windows
// ------------------------------------------------------------------------
+ // ---------------------------- reduce() ------------------------------------
+
/**
* Applies a reduce function to the window. The window function is called for each evaluation
* of the window for each key individually. The output of the reduce function is interpreted
@@ -195,6 +199,94 @@ class AllWindowedStream[T, W <: Window](javaStream: JavaAllWStream[T, W]) {
asScalaStream(javaStream.reduce(reducer, applyFunction, returnType))
}
+ // --------------------------- aggregate() ----------------------------------
+
+  /**
+   * Applies the given aggregation function to each window. The aggregation function
+   * is called for each element, aggregating values incrementally and keeping the state to
+   * one accumulator per window.
+   *
+   * @param aggregateFunction The aggregation function.
+   * @tparam ACC The type of the AggregateFunction's accumulator
+   * @tparam R The type of the elements in the resulting stream, equal to the
+   *           AggregateFunction's result type
+   * @return The data stream that is the result of applying the aggregate function to the window.
+   */
+  def aggregate[ACC: TypeInformation, R: TypeInformation]
+      (aggregateFunction: AggregateFunction[T, ACC, R]): DataStream[R] = {
+
+    // validate before the function is cleaned and handed to the Java API
+    checkNotNull(aggregateFunction, "AggregationFunction must not be null")
+
+    val accumulatorType: TypeInformation[ACC] = implicitly[TypeInformation[ACC]]
+    val resultType: TypeInformation[R] = implicitly[TypeInformation[R]]
+
+    asScalaStream(javaStream.aggregate(
+      clean(aggregateFunction), accumulatorType, resultType))
+  }
+
+  /**
+   * Applies the given window function to each window. The window function is called for each
+   * evaluation of the window. The output of the window function is
+   * interpreted as a regular non-windowed stream.
+   *
+   * Arriving data is pre-aggregated using the given aggregation function.
+   *
+   * @param preAggregator The aggregation function that is used for pre-aggregation
+   * @param windowFunction The window function.
+   * @tparam ACC The type of the AggregateFunction's accumulator
+   * @tparam V The type of the AggregateFunction's result, and the window function's input
+   * @tparam R The type of the elements in the resulting stream, equal to the
+   *           window function's result type
+   * @return The data stream that is the result of applying the window function to the window.
+   */
+  def aggregate[ACC: TypeInformation, V: TypeInformation, R: TypeInformation]
+      (preAggregator: AggregateFunction[T, ACC, V],
+       windowFunction: AllWindowFunction[V, R, W]): DataStream[R] = {
+
+    checkNotNull(preAggregator, "AggregationFunction must not be null")
+    checkNotNull(windowFunction, "Window function must not be null")
+
+    val cleanedPreAggregator = clean(preAggregator)
+    val cleanedWindowFunction = clean(windowFunction)
+
+    // wrap the Scala window function so it can be passed to the Java stream's aggregate()
+    val applyFunction = new ScalaAllWindowFunctionWrapper[V, R, W](cleanedWindowFunction)
+
+    val accumulatorType: TypeInformation[ACC] = implicitly[TypeInformation[ACC]]
+    val aggregationResultType: TypeInformation[V] = implicitly[TypeInformation[V]]
+    val resultType: TypeInformation[R] = implicitly[TypeInformation[R]]
+
+    asScalaStream(javaStream.aggregate(
+      cleanedPreAggregator, applyFunction,
+      accumulatorType, aggregationResultType, resultType))
+  }
+
+  /**
+   * Applies the given window function to each window. The window function is called for each
+   * evaluation of the window. The output of the window function is
+   * interpreted as a regular non-windowed stream.
+   *
+   * Arriving data is pre-aggregated using the given aggregation function.
+   *
+   * @param preAggregator The aggregation function that is used for pre-aggregation
+   * @param windowFunction The window function.
+   * @tparam ACC The type of the AggregateFunction's accumulator
+   * @tparam V The type of the AggregateFunction's result, and the window function's input
+   * @tparam R The type of the elements in the resulting stream, equal to the
+   *           window function's result type
+   * @return The data stream that is the result of applying the window function to the window.
+   */
+  def aggregate[ACC: TypeInformation, V: TypeInformation, R: TypeInformation]
+      (preAggregator: AggregateFunction[T, ACC, V],
+       windowFunction: (W, Iterable[V], Collector[R]) => Unit): DataStream[R] = {
+
+    checkNotNull(preAggregator, "AggregationFunction must not be null")
+    checkNotNull(windowFunction, "Window function must not be null")
+
+    val cleanedPreAggregator = clean(preAggregator)
+    val cleanedWindowFunction = clean(windowFunction)
+
+    // wrap the Scala closure so it can be passed to the Java stream's aggregate()
+    val applyFunction = new ScalaAllWindowFunction[V, R, W](cleanedWindowFunction)
+
+    val accumulatorType: TypeInformation[ACC] = implicitly[TypeInformation[ACC]]
+    val aggregationResultType: TypeInformation[V] = implicitly[TypeInformation[V]]
+    val resultType: TypeInformation[R] = implicitly[TypeInformation[R]]
+
+    asScalaStream(javaStream.aggregate(
+      cleanedPreAggregator, applyFunction,
+      accumulatorType, aggregationResultType, resultType))
+  }
+
+ // ----------------------------- fold() -------------------------------------
+
/**
* Applies the given fold function to each window. The window function is called for each
* evaluation of the window for each key individually. The output of the reduce function is
@@ -298,6 +390,8 @@ class AllWindowedStream[T, W <: Window](javaStream: JavaAllWStream[T, W]) {
asScalaStream(javaStream.fold(initialValue, folder, applyFunction, accType, returnType))
}
+ // ---------------------------- apply() -------------------------------------
+
/**
* Applies the given window function to each window. The window function is called for each
* evaluation of the window for each key individually. The output of the window function is
http://git-wip-us.apache.org/repos/asf/flink/blob/1542260d/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala
----------------------------------------------------------------------
diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala
index db187ea..ab27820 100644
--- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala
+++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala
@@ -19,7 +19,7 @@
package org.apache.flink.streaming.api.scala
import org.apache.flink.annotation.{PublicEvolving, Public}
-import org.apache.flink.api.common.functions.{FoldFunction, ReduceFunction}
+import org.apache.flink.api.common.functions.{AggregateFunction, FoldFunction, ReduceFunction}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.streaming.api.datastream.{WindowedStream => JavaWStream}
import org.apache.flink.streaming.api.functions.aggregation.AggregationFunction.AggregationType
@@ -98,6 +98,8 @@ class WindowedStream[T, K, W <: Window](javaStream: JavaWStream[T, K, W]) {
// Operations on the keyed windows
// ------------------------------------------------------------------------
+ // --------------------------- reduce() -----------------------------------
+
/**
* Applies a reduce function to the window. The window function is called for each evaluation
* of the window for each key individually. The output of the reduce function is interpreted
@@ -196,6 +198,86 @@ class WindowedStream[T, K, W <: Window](javaStream: JavaWStream[T, K, W]) {
asScalaStream(javaStream.reduce(reducer, applyFunction, implicitly[TypeInformation[R]]))
}
+ // -------------------------- aggregate() ---------------------------------
+
+  /**
+   * Applies the given aggregation function to each window and key. The aggregation function
+   * is called for each element, aggregating values incrementally and keeping the state to
+   * one accumulator per key and window.
+   *
+   * @param aggregateFunction The aggregation function.
+   * @tparam ACC The type of the AggregateFunction's accumulator
+   * @tparam R The type of the elements in the resulting stream, equal to the
+   *           AggregateFunction's result type
+   * @return The data stream that is the result of applying the aggregate function to the window.
+   */
+  def aggregate[ACC: TypeInformation, R: TypeInformation]
+      (aggregateFunction: AggregateFunction[T, ACC, R]): DataStream[R] = {
+
+    // Fail fast on a null function, mirroring the checks in AllWindowedStream.aggregate().
+    if (aggregateFunction == null) {
+      throw new NullPointerException("AggregationFunction must not be null")
+    }
+
+    val accumulatorType: TypeInformation[ACC] = implicitly[TypeInformation[ACC]]
+    val resultType: TypeInformation[R] = implicitly[TypeInformation[R]]
+
+    asScalaStream(javaStream.aggregate(
+      clean(aggregateFunction), accumulatorType, resultType))
+  }
+
+  /**
+   * Applies the given window function to each window. The window function is called for each
+   * evaluation of the window for each key individually. The output of the window function is
+   * interpreted as a regular non-windowed stream.
+   *
+   * Arriving data is pre-aggregated using the given aggregation function.
+   *
+   * @param preAggregator The aggregation function that is used for pre-aggregation
+   * @param windowFunction The window function.
+   * @tparam ACC The type of the AggregateFunction's accumulator
+   * @tparam V The type of the AggregateFunction's result, and the WindowFunction's input
+   * @tparam R The type of the elements in the resulting stream, equal to the
+   *           WindowFunction's result type
+   * @return The data stream that is the result of applying the window function to the window.
+   */
+  def aggregate[ACC: TypeInformation, V: TypeInformation, R: TypeInformation]
+      (preAggregator: AggregateFunction[T, ACC, V],
+       windowFunction: WindowFunction[V, R, K, W]): DataStream[R] = {
+
+    // Fail fast on null arguments, mirroring the checks in AllWindowedStream.aggregate().
+    if (preAggregator == null) {
+      throw new NullPointerException("AggregationFunction must not be null")
+    }
+    if (windowFunction == null) {
+      throw new NullPointerException("Window function must not be null")
+    }
+
+    val cleanedPreAggregator = clean(preAggregator)
+    val cleanedWindowFunction = clean(windowFunction)
+
+    // wrap the Scala window function so it can be passed to the Java stream's aggregate()
+    val applyFunction = new ScalaWindowFunctionWrapper[V, R, K, W](cleanedWindowFunction)
+
+    val accumulatorType: TypeInformation[ACC] = implicitly[TypeInformation[ACC]]
+    val aggregationResultType: TypeInformation[V] = implicitly[TypeInformation[V]]
+    val resultType: TypeInformation[R] = implicitly[TypeInformation[R]]
+
+    asScalaStream(javaStream.aggregate(
+      cleanedPreAggregator, applyFunction,
+      accumulatorType, aggregationResultType, resultType))
+  }
+
+  /**
+   * Applies the given window function to each window. The window function is called for each
+   * evaluation of the window for each key individually. The output of the window function is
+   * interpreted as a regular non-windowed stream.
+   *
+   * Arriving data is pre-aggregated using the given aggregation function.
+   *
+   * @param preAggregator The aggregation function that is used for pre-aggregation
+   * @param windowFunction The window function.
+   * @tparam ACC The type of the AggregateFunction's accumulator
+   * @tparam V The type of the AggregateFunction's result, and the window function's input
+   * @tparam R The type of the elements in the resulting stream, equal to the
+   *           window function's result type
+   * @return The data stream that is the result of applying the window function to the window.
+   */
+  def aggregate[ACC: TypeInformation, V: TypeInformation, R: TypeInformation]
+      (preAggregator: AggregateFunction[T, ACC, V],
+       windowFunction: (K, W, Iterable[V], Collector[R]) => Unit): DataStream[R] = {
+
+    // Fail fast on null arguments, mirroring the checks in AllWindowedStream.aggregate().
+    if (preAggregator == null) {
+      throw new NullPointerException("AggregationFunction must not be null")
+    }
+    if (windowFunction == null) {
+      throw new NullPointerException("Window function must not be null")
+    }
+
+    val cleanedPreAggregator = clean(preAggregator)
+    val cleanedWindowFunction = clean(windowFunction)
+
+    // wrap the Scala closure so it can be passed to the Java stream's aggregate()
+    val applyFunction = new ScalaWindowFunction[V, R, K, W](cleanedWindowFunction)
+
+    val accumulatorType: TypeInformation[ACC] = implicitly[TypeInformation[ACC]]
+    val aggregationResultType: TypeInformation[V] = implicitly[TypeInformation[V]]
+    val resultType: TypeInformation[R] = implicitly[TypeInformation[R]]
+
+    asScalaStream(javaStream.aggregate(
+      cleanedPreAggregator, applyFunction,
+      accumulatorType, aggregationResultType, resultType))
+  }
+
+ // ---------------------------- fold() ------------------------------------
+
/**
* Applies the given fold function to each window. The window function is called for each
* evaluation of the window for each key individually. The output of the reduce function is
@@ -297,6 +379,8 @@ class WindowedStream[T, K, W <: Window](javaStream: JavaWStream[T, K, W]) {
asScalaStream(javaStream.fold(initialValue, folder, applyFunction, accType, resultType))
}
+ // ---------------------------- apply() -------------------------------------
+
/**
* Applies the given window function to each window. The window function is called for each
* evaluation of the window for each key individually. The output of the window function is
http://git-wip-us.apache.org/repos/asf/flink/blob/1542260d/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/AllWindowTranslationTest.scala
----------------------------------------------------------------------
diff --git a/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/AllWindowTranslationTest.scala b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/AllWindowTranslationTest.scala
index c738955..7e067a0 100644
--- a/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/AllWindowTranslationTest.scala
+++ b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/AllWindowTranslationTest.scala
@@ -20,13 +20,13 @@ package org.apache.flink.streaming.api.scala
import org.apache.flink.api.common.ExecutionConfig
-import org.apache.flink.api.common.functions.{FoldFunction, RichFoldFunction, RichReduceFunction}
-import org.apache.flink.api.common.state.{FoldingStateDescriptor, ListStateDescriptor, ReducingStateDescriptor}
+import org.apache.flink.api.common.functions._
+import org.apache.flink.api.common.state.{AggregatingStateDescriptor, FoldingStateDescriptor, ListStateDescriptor, ReducingStateDescriptor}
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.functions.KeySelector
import org.apache.flink.streaming.api.TimeCharacteristic
import org.apache.flink.streaming.api.operators.OneInputStreamOperator
-import org.apache.flink.streaming.api.scala.function.AllWindowFunction
+import org.apache.flink.streaming.api.scala.function.{WindowFunction, AllWindowFunction}
import org.apache.flink.streaming.api.transformations.OneInputTransformation
import org.apache.flink.streaming.api.windowing.assigners._
import org.apache.flink.streaming.api.windowing.evictors.CountEvictor
@@ -49,6 +49,10 @@ import org.junit.Test
*/
class AllWindowTranslationTest {
+ // ------------------------------------------------------------------------
+ // rich function tests
+ // ------------------------------------------------------------------------
+
/**
* .reduce() does not support [[RichReduceFunction]], since the reduce function is used
* internally in a [[org.apache.flink.api.common.state.ReducingState]].
@@ -70,6 +74,24 @@ class AllWindowTranslationTest {
}
/**
+   * .aggregate() does not support [[RichAggregateFunction]], since the aggregate function
+   * is used internally in a [[org.apache.flink.api.common.state.AggregatingState]].
+   */
+  @Test(expected = classOf[UnsupportedOperationException])
+  def testAggregateWithRichFunctionFails() {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val source = env.fromElements(("hello", 1), ("hello", 2))
+
+    env.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime)
+
+    source
+      .windowAll(SlidingEventTimeWindows.of(Time.seconds(1), Time.milliseconds(100)))
+      .aggregate(new DummyRichAggregator())
+
+    // only reached if aggregate() accepted the rich function without throwing
+    fail("exception was not thrown")
+  }
+ /**
* .fold() does not support [[RichFoldFunction]], since the reduce function is used internally
* in a [[org.apache.flink.api.common.state.FoldingState]].
*/
@@ -89,6 +111,10 @@ class AllWindowTranslationTest {
fail("exception was not thrown")
}
+ // ------------------------------------------------------------------------
+ // merging window precondition
+ // ------------------------------------------------------------------------
+
@Test
def testSessionWithFoldFails() {
// verify that fold does not work with merging windows
@@ -148,6 +174,10 @@ class AllWindowTranslationTest {
fail("The trigger call should fail.")
}
+ // ------------------------------------------------------------------------
+ // reduce() translation tests
+ // ------------------------------------------------------------------------
+
@Test
def testReduceEventTime() {
val env = StreamExecutionEnvironment.getExecutionEnvironment
@@ -399,6 +429,182 @@ class AllWindowTranslationTest {
("hello", 1))
}
+ // ------------------------------------------------------------------------
+ // aggregate() translation tests
+ // ------------------------------------------------------------------------
+
+  /**
+   * Verifies that aggregate() on sliding event-time all-windows is translated into a
+   * [[WindowOperator]] with an [[EventTimeTrigger]], a [[SlidingEventTimeWindows]] assigner and
+   * an [[AggregatingStateDescriptor]], then pushes one element through
+   * processElementAndEnsureOutput().
+   */
+  @Test
+  def testAggregateEventTime() {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStreamTimeCharacteristic(TimeCharacteristic.IngestionTime)
+
+    val source = env.fromElements(("hello", 1), ("hello", 2))
+
+    val window1 = source
+      .windowAll(SlidingEventTimeWindows.of(Time.seconds(1), Time.milliseconds(100)))
+      .aggregate(new DummyAggregator())
+
+    val transform = window1
+      .javaStream
+      .getTransformation
+      .asInstanceOf[OneInputTransformation[(String, Int), (String, Int)]]
+
+    val operator = transform.getOperator
+    assertTrue(operator.isInstanceOf[WindowOperator[_, _, _, _, _ <: Window]])
+
+    val winOperator = operator
+      .asInstanceOf[WindowOperator[String, (String, Int), _, (String, Int), _ <: Window]]
+
+    assertTrue(winOperator.getTrigger.isInstanceOf[EventTimeTrigger])
+    assertTrue(winOperator.getWindowAssigner.isInstanceOf[SlidingEventTimeWindows])
+    assertTrue(winOperator.getStateDescriptor.isInstanceOf[AggregatingStateDescriptor[_, _, _]])
+
+    processElementAndEnsureOutput[String, (String, Int), (String, Int)](
+      winOperator,
+      winOperator.getKeySelector,
+      BasicTypeInfo.STRING_TYPE_INFO,
+      ("hello", 1))
+  }
+
+  /**
+   * Verifies that aggregate() on sliding processing-time all-windows is translated into a
+   * [[WindowOperator]] with a [[ProcessingTimeTrigger]], a [[SlidingProcessingTimeWindows]]
+   * assigner and an [[AggregatingStateDescriptor]], then pushes one element through
+   * processElementAndEnsureOutput().
+   */
+  @Test
+  def testAggregateProcessingTime() {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime)
+
+    val source = env.fromElements(("hello", 1), ("hello", 2))
+
+    val window1 = source
+      .windowAll(SlidingProcessingTimeWindows.of(Time.seconds(1), Time.milliseconds(100)))
+      .aggregate(new DummyAggregator())
+
+    val transform = window1
+      .javaStream
+      .getTransformation
+      .asInstanceOf[OneInputTransformation[(String, Int), (String, Int)]]
+
+    val operator = transform.getOperator
+    assertTrue(operator.isInstanceOf[WindowOperator[_, _, _, _, _ <: Window]])
+
+    val winOperator = operator
+      .asInstanceOf[WindowOperator[String, (String, Int), _, (String, Int), _ <: Window]]
+
+    assertTrue(winOperator.getTrigger.isInstanceOf[ProcessingTimeTrigger])
+    assertTrue(winOperator.getWindowAssigner.isInstanceOf[SlidingProcessingTimeWindows])
+    assertTrue(winOperator.getStateDescriptor.isInstanceOf[AggregatingStateDescriptor[_, _, _]])
+
+    processElementAndEnsureOutput[String, (String, Int), (String, Int)](
+      winOperator,
+      winOperator.getKeySelector,
+      BasicTypeInfo.STRING_TYPE_INFO,
+      ("hello", 1))
+  }
+
+  /**
+   * Verifies that aggregate() combined with an [[AllWindowFunction]] on tumbling event-time
+   * all-windows is translated into a [[WindowOperator]] with an [[EventTimeTrigger]], a
+   * [[TumblingEventTimeWindows]] assigner and an [[AggregatingStateDescriptor]], then pushes
+   * one element through processElementAndEnsureOutput().
+   */
+  @Test
+  def testAggregateWithWindowFunctionEventTime() {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStreamTimeCharacteristic(TimeCharacteristic.IngestionTime)
+
+    val source = env.fromElements(("hello", 1), ("hello", 2))
+
+    val window1 = source
+      .windowAll(TumblingEventTimeWindows.of(Time.seconds(1)))
+      .aggregate(new DummyAggregator(), new TestAllWindowFunction())
+
+    val transform = window1
+      .javaStream
+      .getTransformation
+      .asInstanceOf[OneInputTransformation[(String, Int), (String, Int)]]
+
+    val operator = transform.getOperator
+    assertTrue(operator.isInstanceOf[WindowOperator[_, _, _, _, _ <: Window]])
+
+    val winOperator = operator
+      .asInstanceOf[WindowOperator[String, (String, Int), _, (String, Int), _ <: Window]]
+
+    assertTrue(winOperator.getTrigger.isInstanceOf[EventTimeTrigger])
+    assertTrue(winOperator.getWindowAssigner.isInstanceOf[TumblingEventTimeWindows])
+    assertTrue(winOperator.getStateDescriptor.isInstanceOf[AggregatingStateDescriptor[_, _, _]])
+
+    processElementAndEnsureOutput[String, (String, Int), (String, Int)](
+      winOperator,
+      winOperator.getKeySelector,
+      BasicTypeInfo.STRING_TYPE_INFO,
+      ("hello", 1))
+  }
+
+  /**
+   * Verifies that aggregate() combined with an [[AllWindowFunction]] on tumbling
+   * processing-time all-windows is translated into a [[WindowOperator]] with a
+   * [[ProcessingTimeTrigger]], a [[TumblingProcessingTimeWindows]] assigner and an
+   * [[AggregatingStateDescriptor]], then pushes one element through
+   * processElementAndEnsureOutput().
+   */
+  @Test
+  def testAggregateWithWindowFunctionProcessingTime() {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime)
+
+    val source = env.fromElements(("hello", 1), ("hello", 2))
+
+    val window1 = source
+      .windowAll(TumblingProcessingTimeWindows.of(Time.seconds(1)))
+      .aggregate(new DummyAggregator(), new TestAllWindowFunction())
+
+    val transform = window1
+      .javaStream
+      .getTransformation
+      .asInstanceOf[OneInputTransformation[(String, Int), (String, Int)]]
+
+    val operator = transform.getOperator
+    assertTrue(operator.isInstanceOf[WindowOperator[_, _, _, _, _ <: Window]])
+
+    val winOperator = operator
+      .asInstanceOf[WindowOperator[String, (String, Int), _, (String, Int), _ <: Window]]
+
+    assertTrue(winOperator.getTrigger.isInstanceOf[ProcessingTimeTrigger])
+    assertTrue(winOperator.getWindowAssigner.isInstanceOf[TumblingProcessingTimeWindows])
+    assertTrue(winOperator.getStateDescriptor.isInstanceOf[AggregatingStateDescriptor[_, _, _]])
+
+    processElementAndEnsureOutput[String, (String, Int), (String, Int)](
+      winOperator,
+      winOperator.getKeySelector,
+      BasicTypeInfo.STRING_TYPE_INFO,
+      ("hello", 1))
+  }
+
+  /**
+   * Same checks as testAggregateWithWindowFunctionEventTime, but the window function is given
+   * as a Scala closure (window, elements, collector) instead of an [[AllWindowFunction]]
+   * instance; the closure forwards every pre-aggregated element to the collector.
+   */
+  @Test
+  def testAggregateWithWindowFunctionEventTimeWithScalaFunction() {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStreamTimeCharacteristic(TimeCharacteristic.IngestionTime)
+
+    val source = env.fromElements(("hello", 1), ("hello", 2))
+
+    val window1 = source
+      .windowAll(TumblingEventTimeWindows.of(Time.seconds(1)))
+      .aggregate(
+        new DummyAggregator(),
+        { (_, in: Iterable[(String, Int)], out: Collector[(String, Int)]) => {
+          in foreach { x => out.collect(x)}
+        } })
+
+    val transform = window1
+      .javaStream
+      .getTransformation
+      .asInstanceOf[OneInputTransformation[(String, Int), (String, Int)]]
+
+    val operator = transform.getOperator
+    assertTrue(operator.isInstanceOf[WindowOperator[_, _, _, _, _ <: Window]])
+
+    val winOperator = operator
+      .asInstanceOf[WindowOperator[String, (String, Int), _, (String, Int), _ <: Window]]
+
+    assertTrue(winOperator.getTrigger.isInstanceOf[EventTimeTrigger])
+    assertTrue(winOperator.getWindowAssigner.isInstanceOf[TumblingEventTimeWindows])
+    assertTrue(winOperator.getStateDescriptor.isInstanceOf[AggregatingStateDescriptor[_, _, _]])
+
+    processElementAndEnsureOutput[String, (String, Int), (String, Int)](
+      winOperator,
+      winOperator.getKeySelector,
+      BasicTypeInfo.STRING_TYPE_INFO,
+      ("hello", 1))
+  }
+
+ // ------------------------------------------------------------------------
+ // fold() translation tests
+ // ------------------------------------------------------------------------
@Test
def testFoldEventTime() {
@@ -662,6 +868,9 @@ class AllWindowTranslationTest {
("hello", 1))
}
+ // ------------------------------------------------------------------------
+ // apply() translation tests
+ // ------------------------------------------------------------------------
@Test
def testApplyEventTime() {
@@ -1033,3 +1242,14 @@ class AllWindowTranslationTest {
testHarness.close()
}
}
+
+/**
+ * Pass-through [[AllWindowFunction]] used by the aggregate() translation tests: every
+ * element of the evaluated window is emitted unchanged.
+ */
+class TestAllWindowFunction extends AllWindowFunction[(String, Int), (String, Int), TimeWindow] {
+
+  /** Emits each window element to the collector, in iteration order. */
+  override def apply(
+      window: TimeWindow,
+      input: Iterable[(String, Int)],
+      out: Collector[(String, Int)]): Unit = {
+
+    for (element <- input) {
+      out.collect(element)
+    }
+  }
+}