You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ec...@apache.org on 2019/10/24 10:08:53 UTC

[beam] 29/37: Apply new Encoders to AggregatorCombiner

This is an automated email from the ASF dual-hosted git repository.

echauchot pushed a commit to branch spark-runner_structured-streaming
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 29f7e93c954cc26425a052c0f1c19ec6e6c9fe66
Author: Etienne Chauchot <ec...@apache.org>
AuthorDate: Fri Sep 27 11:55:20 2019 +0200

    Apply new Encoders to AggregatorCombiner
---
 .../translation/batch/AggregatorCombiner.java      | 22 +++++++++++++++++-----
 .../batch/CombinePerKeyTranslatorBatch.java        | 20 ++++++++++++++++----
 2 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombiner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombiner.java
index 0e3229e..d14569a 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombiner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombiner.java
@@ -27,6 +27,8 @@ import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
 import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
@@ -52,13 +54,25 @@ class AggregatorCombiner<K, InputT, AccumT, OutputT, W extends BoundedWindow>
   private final Combine.CombineFn<InputT, AccumT, OutputT> combineFn;
   private WindowingStrategy<InputT, W> windowingStrategy;
   private TimestampCombiner timestampCombiner;
+  private IterableCoder<WindowedValue<AccumT>> accumulatorCoder;
+  private IterableCoder<WindowedValue<OutputT>> outputCoder;
 
   public AggregatorCombiner(
       Combine.CombineFn<InputT, AccumT, OutputT> combineFn,
-      WindowingStrategy<?, ?> windowingStrategy) {
+      WindowingStrategy<?, ?> windowingStrategy,
+      Coder<AccumT> accumulatorCoder,
+      Coder<OutputT> outputCoder) {
     this.combineFn = combineFn;
     this.windowingStrategy = (WindowingStrategy<InputT, W>) windowingStrategy;
     this.timestampCombiner = windowingStrategy.getTimestampCombiner();
+    this.accumulatorCoder =
+        IterableCoder.of(
+            WindowedValue.FullWindowedValueCoder.of(
+                accumulatorCoder, windowingStrategy.getWindowFn().windowCoder()));
+    this.outputCoder =
+        IterableCoder.of(
+            WindowedValue.FullWindowedValueCoder.of(
+                outputCoder, windowingStrategy.getWindowFn().windowCoder()));
   }
 
   @Override
@@ -142,14 +156,12 @@ class AggregatorCombiner<K, InputT, AccumT, OutputT, W extends BoundedWindow>
 
   @Override
   public Encoder<Iterable<WindowedValue<AccumT>>> bufferEncoder() {
-    // TODO replace with accumulatorCoder if possible
-    return EncoderHelpers.genericEncoder();
+    return EncoderHelpers.fromBeamCoder(accumulatorCoder);
   }
 
   @Override
   public Encoder<Iterable<WindowedValue<OutputT>>> outputEncoder() {
-    // TODO replace with outputCoder if possible
-    return EncoderHelpers.genericEncoder();
+    return EncoderHelpers.fromBeamCoder(outputCoder);
   }
 
   private Set<W> collectAccumulatorsWindows(Iterable<WindowedValue<AccumT>> accumulators) {
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java
index 33b037a..be238b5 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java
@@ -23,6 +23,7 @@ import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTr
 import org.apache.beam.runners.spark.structuredstreaming.translation.TranslationContext;
 import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
 import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.KVHelpers;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.transforms.Combine;
@@ -58,20 +59,31 @@ class CombinePerKeyTranslatorBatch<K, InputT, AccumT, OutputT>
 
     Dataset<WindowedValue<KV<K, InputT>>> inputDataset = context.getDataset(input);
 
-    Coder<K> keyCoder = (Coder<K>) input.getCoder().getCoderArguments().get(0);
-    Coder<OutputT> outputTCoder = (Coder<OutputT>) output.getCoder().getCoderArguments().get(1);
+    KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) input.getCoder();
+    Coder<K> keyCoder = inputCoder.getKeyCoder();
+    KvCoder<K, OutputT> outputKVCoder = (KvCoder<K, OutputT>) output.getCoder();
+    Coder<OutputT> outputCoder = outputKVCoder.getValueCoder();
 
     KeyValueGroupedDataset<K, WindowedValue<KV<K, InputT>>> groupedDataset =
         inputDataset.groupByKey(KVHelpers.extractKey(), EncoderHelpers.fromBeamCoder(keyCoder));
 
+    Coder<AccumT> accumulatorCoder = null;
+    try {
+      accumulatorCoder =
+          combineFn.getAccumulatorCoder(
+              input.getPipeline().getCoderRegistry(), inputCoder.getValueCoder());
+    } catch (CannotProvideCoderException e) {
+      throw new RuntimeException(e);
+    }
+
     Dataset<Tuple2<K, Iterable<WindowedValue<OutputT>>>> combinedDataset =
         groupedDataset.agg(
             new AggregatorCombiner<K, InputT, AccumT, OutputT, BoundedWindow>(
-                    combineFn, windowingStrategy)
+                    combineFn, windowingStrategy, accumulatorCoder, outputCoder)
                 .toColumn());
 
     // expand the list into separate elements and put the key back into the elements
-    Coder<KV<K, OutputT>> kvCoder = KvCoder.of(keyCoder, outputTCoder);
+    Coder<KV<K, OutputT>> kvCoder = KvCoder.of(keyCoder, outputCoder);
     WindowedValue.WindowedValueCoder<KV<K, OutputT>> wvCoder =
         WindowedValue.FullWindowedValueCoder.of(
             kvCoder, input.getWindowingStrategy().getWindowFn().windowCoder());