Posted to commits@beam.apache.org by ec...@apache.org on 2019/06/26 15:22:41 UTC

[beam] 03/07: Implement reduce part of CombineGlobally translation with windowing

This is an automated email from the ASF dual-hosted git repository.

echauchot pushed a commit to branch spark-runner_structured-streaming
in repository https://gitbox.apache.org/repos/asf/beam.git

commit bba08b4201a1dc53ae117d9bd495b117923e189d
Author: Etienne Chauchot <ec...@apache.org>
AuthorDate: Thu Jun 13 11:23:52 2019 +0200

    Implement reduce part of CombineGlobally translation with windowing
---
 .../batch/AggregatorCombinerGlobally.java          | 165 +++++++++++++++++----
 .../batch/CombineGloballyTranslatorBatch.java      |  19 +--
 2 files changed, 144 insertions(+), 40 deletions(-)
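
The core idea of the change: a single input element can belong to several windows (e.g. with
sliding windows), so the aggregator keeps one accumulator per merged window instead of a single
global accumulator. The following minimal, self-contained sketch (plain Java, with a running sum
and String window labels as hypothetical stand-ins for Beam's CombineFn and window types)
illustrates the per-window accumulation pattern used in the diff below:

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    public class PerWindowAccumulationSketch {
      public static void main(String[] args) {
        // A hypothetical element assigned to two overlapping windows, as with sliding windows.
        long value = 3L;
        List<String> windows = Arrays.asList("[00:00, 00:10)", "[00:05, 00:15)");

        // One accumulator (here, a running sum) per window, keyed by window.
        Map<String, Long> accumPerWindow = new HashMap<>();
        for (String window : windows) {
          accumPerWindow.merge(window, value, Long::sum);
        }

        // The single element contributed to both windows' accumulators:
        // prints both per-window sums, e.g. {[00:00, 00:10)=3, [00:05, 00:15)=3}
        System.out.println(accumPerWindow);
      }
    }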

diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombinerGlobally.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombinerGlobally.java
index 2f8293b..0d13218 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombinerGlobally.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombinerGlobally.java
@@ -18,60 +18,173 @@
 package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 
 import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
 import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.RowHelpers;
 import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.transforms.windowing.WindowFn;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.WindowingStrategy;
 import org.apache.spark.sql.Encoder;
-import org.apache.spark.sql.Row;
 import org.apache.spark.sql.expressions.Aggregator;
+import org.joda.time.Instant;
+import scala.Tuple2;
 
-/** An {@link Aggregator} for the Spark Batch Runner. */
-class AggregatorCombinerGlobally<InputT, AccumT, OutputT>
-    extends Aggregator<InputT, AccumT, OutputT> {
+/**
+ * An {@link Aggregator} for the Spark Batch Runner. It does not use ReduceFnRunner for window
+ * merging, because ReduceFnRunner is based on state, which requires a keyed collection. The
+ * accumulator is an {@code Iterable<WindowedValue<AccumT>>} because an {@code InputT} can be in
+ * multiple windows. So, when accumulating {@code InputT} values, we create one accumulator per
+ * input window.
+ */
+class AggregatorCombinerGlobally<InputT, AccumT, OutputT, W extends BoundedWindow>
+    extends Aggregator<WindowedValue<InputT>, Iterable<WindowedValue<AccumT>>, WindowedValue<OutputT>> {
 
   private final Combine.CombineFn<InputT, AccumT, OutputT> combineFn;
+  private final WindowingStrategy<InputT, W> windowingStrategy;
+  private final TimestampCombiner timestampCombiner;
 
-  public AggregatorCombinerGlobally(Combine.CombineFn<InputT, AccumT, OutputT> combineFn) {
+  public AggregatorCombinerGlobally(
+      Combine.CombineFn<InputT, AccumT, OutputT> combineFn,
+      WindowingStrategy<?, ?> windowingStrategy) {
     this.combineFn = combineFn;
+    this.windowingStrategy = (WindowingStrategy<InputT, W>) windowingStrategy;
+    this.timestampCombiner = windowingStrategy.getTimestampCombiner();
   }
 
-  @Override
-  public AccumT zero() {
-    return combineFn.createAccumulator();
+  @Override
+  public Iterable<WindowedValue<AccumT>> zero() {
+    return new ArrayList<>();
   }
 
-  @Override
-  public AccumT reduce(AccumT accumulator, InputT input) {
-    // because of generic type InputT, spark cannot infer an input type.
-    // it would pass Integer as input if we had a Aggregator<Integer, ..., ...>
-    // without the type inference it stores input in a GenericRowWithSchema
-    Row row = (Row) input;
-    InputT t = RowHelpers.extractObjectFromRow(row);
-    return combineFn.addInput(accumulator, t);
+  @Override
+  public Iterable<WindowedValue<AccumT>> reduce(
+      Iterable<WindowedValue<AccumT>> accumulators, WindowedValue<InputT> input) {
+
+    // Concatenate the accumulator windows with the input windows and merge them.
+    Collection<W> inputWindows = (Collection<W>) input.getWindows();
+    Set<W> windows = collectAccumulatorsWindows(accumulators);
+    windows.addAll(inputWindows);
+    Map<W, W> windowToMergeResult;
+    try {
+      windowToMergeResult = mergeWindows(windowingStrategy, windows);
+    } catch (Exception e) {
+      throw new RuntimeException("Unable to merge accumulator windows and input windows", e);
+    }
+
+    // Iterate through the input windows and, for each, add the input to the accumulator
+    // associated with its merged window, creating that accumulator if needed.
+    // Maintain a map of the accumulators for use as output.
+    Map<W, Tuple2<AccumT, Instant>> mapState = new HashMap<>();
+    for (W inputWindow : inputWindows) {
+      W mergedWindow = windowToMergeResult.get(inputWindow);
+      mergedWindow = mergedWindow == null ? inputWindow : mergedWindow;
+      Tuple2<AccumT, Instant> accumAndInstant = mapState.get(mergedWindow);
+      // If there is no accumulator associated with this window yet, create one.
+      if (accumAndInstant == null) {
+        AccumT accum = combineFn.addInput(combineFn.createAccumulator(), input.getValue());
+        Instant windowTimestamp =
+            timestampCombiner.assign(
+                mergedWindow,
+                windowingStrategy.getWindowFn().getOutputTime(input.getTimestamp(), mergedWindow));
+        mapState.put(mergedWindow, new Tuple2<>(accum, windowTimestamp));
+      } else {
+        AccumT updatedAccum = combineFn.addInput(accumAndInstant._1, input.getValue());
+        Instant updatedTimestamp =
+            timestampCombiner.combine(
+                accumAndInstant._2,
+                timestampCombiner.assign(
+                    mergedWindow,
+                    windowingStrategy
+                        .getWindowFn()
+                        .getOutputTime(input.getTimestamp(), mergedWindow)));
+        // Tuple2 is immutable: put the updated accumulator and timestamp back in the map,
+        // otherwise the update is lost.
+        mapState.put(mergedWindow, new Tuple2<>(updatedAccum, updatedTimestamp));
+      }
+    }
+    // Output the accumulators map.
+    List<WindowedValue<AccumT>> result = new ArrayList<>();
+    for (Map.Entry<W, Tuple2<AccumT, Instant>> entry : mapState.entrySet()) {
+      AccumT accumulator = entry.getValue()._1;
+      Instant windowTimestamp = entry.getValue()._2;
+      W window = entry.getKey();
+      result.add(WindowedValue.of(accumulator, windowTimestamp, window, PaneInfo.NO_FIRING));
+    }
+    return result;
   }
 
-  @Override
-  public AccumT merge(AccumT accumulator1, AccumT accumulator2) {
+  @Override
+  public Iterable<WindowedValue<AccumT>> merge(
+      Iterable<WindowedValue<AccumT>> accumulators1,
+      Iterable<WindowedValue<AccumT>> accumulators2) {
+    // TODO
+    /*
     ArrayList<AccumT> accumulators = new ArrayList<>();
     accumulators.add(accumulator1);
     accumulators.add(accumulator2);
     return combineFn.mergeAccumulators(accumulators);
+*/
+    return null;
   }
 
-  @Override
-  public OutputT finish(AccumT reduction) {
-    return combineFn.extractOutput(reduction);
+  @Override
+  public WindowedValue<OutputT> finish(Iterable<WindowedValue<AccumT>> reduction) {
+    // TODO
+    //    return combineFn.extractOutput(reduction);
+    return null;
   }
 
-  @Override
-  public Encoder<AccumT> bufferEncoder() {
+  @Override
+  public Encoder<Iterable<WindowedValue<AccumT>>> bufferEncoder() {
     // TODO replace with accumulatorCoder if possible
     return EncoderHelpers.genericEncoder();
   }
 
-  @Override
-  public Encoder<OutputT> outputEncoder() {
+  @Override
+  public Encoder<WindowedValue<OutputT>> outputEncoder() {
     // TODO replace with outputCoder if possible
     return EncoderHelpers.genericEncoder();
   }
+
+  private Set<W> collectAccumulatorsWindows(Iterable<WindowedValue<AccumT>> accumulators) {
+    Set<W> windows = new HashSet<>();
+    for (WindowedValue<?> accumulator : accumulators) {
+      // An accumulator has exactly one window associated with it.
+      W accumulatorWindow = (W) accumulator.getWindows().iterator().next();
+      windows.add(accumulatorWindow);
+    }
+    return windows;
+  }
+
+  private Map<W, W> mergeWindows(WindowingStrategy<InputT, W> windowingStrategy, Set<W> windows)
+      throws Exception {
+    WindowFn<InputT, W> windowFn = windowingStrategy.getWindowFn();
+
+    if (windowFn.isNonMerging()) {
+      // Return an empty map, indicating that no window was merged.
+      return Collections.emptyMap();
+    }
+
+    Map<W, W> windowToMergeResult = new HashMap<>();
+    windowFn.mergeWindows(new MergeContextImpl(windowFn, windows, windowToMergeResult));
+    return windowToMergeResult;
+  }
+
+  private class MergeContextImpl extends WindowFn<InputT, W>.MergeContext {
+
+    private Set<W> windows;
+    private Map<W, W> windowToMergeResult;
+
+    MergeContextImpl(WindowFn<InputT, W> windowFn, Set<W> windows, Map<W, W> windowToMergeResult) {
+      windowFn.super();
+      this.windows = windows;
+      this.windowToMergeResult = windowToMergeResult;
+    }
+
+    @Override
+    public Collection<W> windows() {
+      return windows;
+    }
+
+    @Override
+    public void merge(Collection<W> toBeMerged, W mergeResult) throws Exception {
+      for (W w : toBeMerged) {
+        windowToMergeResult.put(w, mergeResult);
+      }
+    }
+  }
+
 }
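
Note on the TODOs above: merge() and finish() are left unimplemented in this commit (03/07 of the
series). One plausible completion, sketched here as an assumption rather than the commit's actual
follow-up, groups the windowed accumulators by window and delegates to the CombineFn. It omits the
additional window re-merging across the two sides that a merging WindowFn (e.g. sessions) would
require, and it assumes the class's existing fields (combineFn, timestampCombiner) and imports,
plus java.util.Arrays:

      @Override
      public Iterable<WindowedValue<AccumT>> merge(
          Iterable<WindowedValue<AccumT>> accumulators1,
          Iterable<WindowedValue<AccumT>> accumulators2) {
        // Group accumulator values and timestamps by window; each accumulator carries one window.
        Map<BoundedWindow, List<AccumT>> byWindow = new HashMap<>();
        Map<BoundedWindow, Instant> timestamps = new HashMap<>();
        for (Iterable<WindowedValue<AccumT>> side : Arrays.asList(accumulators1, accumulators2)) {
          for (WindowedValue<AccumT> wv : side) {
            BoundedWindow window = wv.getWindows().iterator().next();
            byWindow.computeIfAbsent(window, w -> new ArrayList<>()).add(wv.getValue());
            timestamps.merge(window, wv.getTimestamp(), (a, b) -> timestampCombiner.combine(a, b));
          }
        }
        // Merge each window's accumulators with the CombineFn.
        List<WindowedValue<AccumT>> result = new ArrayList<>();
        for (Map.Entry<BoundedWindow, List<AccumT>> e : byWindow.entrySet()) {
          result.add(
              WindowedValue.of(
                  combineFn.mergeAccumulators(e.getValue()),
                  timestamps.get(e.getKey()),
                  e.getKey(),
                  PaneInfo.NO_FIRING));
        }
        return result;
      }

      @Override
      public WindowedValue<OutputT> finish(Iterable<WindowedValue<AccumT>> reduction) {
        // The Aggregator signature allows a single output, which fits the single-window
        // (e.g. global window) case: extract the output from the only accumulator.
        WindowedValue<AccumT> only = reduction.iterator().next();
        return WindowedValue.of(
            combineFn.extractOutput(only.getValue()),
            only.getTimestamp(),
            only.getWindows().iterator().next(),
            only.getPane());
      }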
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java
index 53651cf..f18572b 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java
@@ -26,6 +26,7 @@ import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.WindowingStrategy;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 
@@ -50,25 +51,15 @@ class CombineGloballyTranslatorBatch<InputT, AccumT, OutputT>
     @SuppressWarnings("unchecked")
     final Combine.CombineFn<InputT, AccumT, OutputT> combineFn =
         (Combine.CombineFn<InputT, AccumT, OutputT>) combineTransform.getFn();
-
+    WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy();
     Dataset<WindowedValue<InputT>> inputDataset = context.getDataset(input);
 
-    //TODO merge windows instead of doing unwindow/window to comply with beam model
-    Dataset<InputT> unWindowedDataset =
-        inputDataset.map(WindowingHelpers.unwindowMapFunction(), EncoderHelpers.genericEncoder());
-
     Dataset<Row> combinedRowDataset =
-        unWindowedDataset.agg(new AggregatorCombinerGlobally<>(combineFn).toColumn());
-
-    Dataset<OutputT> combinedDataset =
-        combinedRowDataset.map(
-            RowHelpers.extractObjectFromRowMapFunction(), EncoderHelpers.genericEncoder());
+        inputDataset.agg(new AggregatorCombinerGlobally<>(combineFn, windowingStrategy).toColumn());
 
-    // Window the result into global window.
     Dataset<WindowedValue<OutputT>> outputDataset =
-        combinedDataset.map(
-            WindowingHelpers.windowMapFunction(), EncoderHelpers.windowedValueEncoder());
-
+        combinedRowDataset.map(
+            RowHelpers.extractObjectFromRowMapFunction(), EncoderHelpers.windowedValueEncoder());
     context.putDataset(output, outputDataset);
   }
 }
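
For reference, the windowToMergeResult map that mergeWindows()/MergeContextImpl produce can be
observed directly with a merging WindowFn. A minimal standalone sketch, assuming only the Beam SDK
on the classpath and using Beam's Sessions WindowFn (overlapping session windows merge into their
union):

    import java.util.Arrays;
    import java.util.Collection;
    import java.util.HashMap;
    import java.util.Map;
    import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
    import org.apache.beam.sdk.transforms.windowing.Sessions;
    import org.joda.time.Duration;
    import org.joda.time.Instant;

    public class MergeWindowsSketch {
      public static void main(String[] args) throws Exception {
        Sessions sessions = Sessions.withGapDuration(Duration.standardMinutes(10));
        // Two overlapping session windows: [0s, 5s) and [4s, 9s).
        IntervalWindow w1 = new IntervalWindow(new Instant(0), new Instant(5_000));
        IntervalWindow w2 = new IntervalWindow(new Instant(4_000), new Instant(9_000));
        Collection<IntervalWindow> windows = Arrays.asList(w1, w2);

        Map<IntervalWindow, IntervalWindow> windowToMergeResult = new HashMap<>();
        sessions.mergeWindows(
            sessions.new MergeContext() {
              @Override
              public Collection<IntervalWindow> windows() {
                return windows;
              }

              @Override
              public void merge(Collection<IntervalWindow> toBeMerged, IntervalWindow mergeResult) {
                // Record the merge result for every original window, as MergeContextImpl does.
                for (IntervalWindow w : toBeMerged) {
                  windowToMergeResult.put(w, mergeResult);
                }
              }
            });

        // Both original windows map to the merged window [0s, 9s).
        System.out.println(windowToMergeResult);
      }
    }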