You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by xi...@apache.org on 2022/09/09 21:19:01 UTC

[beam] branch master updated: Consolidate Samza TranslationContext and PortableTranslationContext (#23072)

This is an automated email from the ASF dual-hosted git repository.

xinyu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new d17914862da Consolidate Samza TranslationContext and PortableTranslationContext (#23072)
d17914862da is described below

commit d17914862dadf94376f02b01e056e517b1224236
Author: Bharath Kumarasubramanian <bh...@apache.org>
AuthorDate: Fri Sep 9 14:18:54 2022 -0700

    Consolidate Samza TranslationContext and PortableTranslationContext (#23072)
---
 .../samza/translation/GroupByKeyTranslator.java    |  2 +-
 .../translation/ParDoBoundMultiTranslator.java     |  2 +-
 .../translation/PortableTranslationContext.java    | 51 ++++------------------
 .../samza/translation/ReshuffleTranslator.java     |  2 +-
 .../SamzaPortablePipelineTranslator.java           |  8 +---
 .../samza/translation/TranslationContext.java      | 15 +++++--
 6 files changed, 26 insertions(+), 54 deletions(-)

diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/GroupByKeyTranslator.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/GroupByKeyTranslator.java
index f12675863ce..2982d35b806 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/GroupByKeyTranslator.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/GroupByKeyTranslator.java
@@ -155,7 +155,7 @@ class GroupByKeyTranslator<K, InputT, OutputT>
       WindowedValue.WindowedValueCoder<KV<K, InputT>> windowedInputCoder,
       TupleTag<KV<K, OutputT>> outputTag,
       PortableTranslationContext ctx) {
-    final boolean needRepartition = ctx.getSamzaPipelineOptions().getMaxSourceParallelism() > 1;
+    final boolean needRepartition = ctx.getPipelineOptions().getMaxSourceParallelism() > 1;
     final Coder<BoundedWindow> windowCoder = windowingStrategy.getWindowFn().windowCoder();
     final KvCoder<K, InputT> kvInputCoder = (KvCoder<K, InputT>) windowedInputCoder.getValueCoder();
     final Coder<WindowedValue<KV<K, InputT>>> elementCoder =
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java
index b64d5871f80..e032a0d8345 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java
@@ -510,7 +510,7 @@ class ParDoBoundMultiTranslator<InT, OutT>
             coder.withValueCoder(IterableCoder.of(coder.getValueCoder())),
             ctx.getTransformId(),
             getSideInputUniqueId(sideInputId),
-            ctx.getSamzaPipelineOptions());
+            ctx.getPipelineOptions());
 
     return broadcastSideInput;
   }
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/PortableTranslationContext.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/PortableTranslationContext.java
index 1f8fecc12da..776ee80878d 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/PortableTranslationContext.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/PortableTranslationContext.java
@@ -18,26 +18,20 @@
 package org.apache.beam.runners.samza.translation;
 
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
 import java.util.stream.Collectors;
 import org.apache.beam.runners.core.construction.graph.PipelineNode;
 import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
 import org.apache.beam.runners.samza.SamzaPipelineOptions;
 import org.apache.beam.runners.samza.runtime.OpMessage;
-import org.apache.beam.runners.samza.util.HashIdGenerator;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 import org.apache.samza.application.descriptors.StreamApplicationDescriptor;
 import org.apache.samza.operators.KV;
 import org.apache.samza.operators.MessageStream;
-import org.apache.samza.operators.OutputStream;
 import org.apache.samza.system.descriptors.InputDescriptor;
-import org.apache.samza.system.descriptors.OutputDescriptor;
-import org.apache.samza.table.Table;
-import org.apache.samza.table.descriptors.TableDescriptor;
 
 /**
  * Helper that keeps the mapping from BEAM PCollection id to Samza {@link MessageStream}. It also
@@ -48,26 +42,16 @@ import org.apache.samza.table.descriptors.TableDescriptor;
   "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
   "nullness" // TODO(https://github.com/apache/beam/issues/20497)
 })
-public class PortableTranslationContext {
+public class PortableTranslationContext extends TranslationContext {
   private final Map<String, MessageStream<?>> messageStreams = new HashMap<>();
-  private final StreamApplicationDescriptor appDescriptor;
   private final JobInfo jobInfo;
-  private final SamzaPipelineOptions options;
-  private final Set<String> registeredInputStreams = new HashSet<>();
-  private final Map<String, Table> registeredTables = new HashMap<>();
-  private final HashIdGenerator idGenerator = new HashIdGenerator();
 
   private PipelineNode.PTransformNode currentTransform;
 
   public PortableTranslationContext(
       StreamApplicationDescriptor appDescriptor, SamzaPipelineOptions options, JobInfo jobInfo) {
+    super(appDescriptor, Collections.emptyMap(), options);
     this.jobInfo = jobInfo;
-    this.appDescriptor = appDescriptor;
-    this.options = options;
-  }
-
-  public SamzaPipelineOptions getSamzaPipelineOptions() {
-    return this.options;
   }
 
   public <T> List<MessageStream<OpMessage<T>>> getAllInputMessageStreams(
@@ -106,45 +90,28 @@ public class PortableTranslationContext {
     messageStreams.put(id, stream);
   }
 
-  /** Get output stream by output descriptor. */
-  public <OutT> OutputStream<OutT> getOutputStream(OutputDescriptor<OutT, ?> outputDescriptor) {
-    return appDescriptor.getOutputStream(outputDescriptor);
-  }
-
   /** Register an input stream with certain config id. */
   public <T> void registerInputMessageStream(
       String id, InputDescriptor<KV<?, OpMessage<T>>, ?> inputDescriptor) {
-    // we want to register it with the Samza graph only once per i/o stream
-    final String streamId = inputDescriptor.getStreamId();
-    if (registeredInputStreams.contains(streamId)) {
-      return;
-    }
-    final MessageStream<OpMessage<T>> stream =
-        appDescriptor.getInputStream(inputDescriptor).map(org.apache.samza.operators.KV::getValue);
-
-    registerMessageStream(id, stream);
-    registeredInputStreams.add(streamId);
+    registerInputMessageStreams(id, Collections.singletonList(inputDescriptor));
   }
 
-  @SuppressWarnings("unchecked")
-  public <K, V> Table<KV<K, V>> getTable(TableDescriptor<K, V, ?> tableDesc) {
-    return registeredTables.computeIfAbsent(
-        tableDesc.getTableId(), id -> appDescriptor.getTable(tableDesc));
+  public <T> void registerInputMessageStreams(
+      String id, List<? extends InputDescriptor<KV<?, OpMessage<T>>, ?>> inputDescriptors) {
+    registerInputMessageStreams(id, inputDescriptors, this::registerMessageStream);
   }
 
   public void setCurrentTransform(PipelineNode.PTransformNode currentTransform) {
     this.currentTransform = currentTransform;
   }
 
+  @Override
   public void clearCurrentTransform() {
     this.currentTransform = null;
   }
 
+  @Override
   public String getTransformFullName() {
     return currentTransform.getTransform().getUniqueName();
   }
-
-  public String getTransformId() {
-    return idGenerator.getId(currentTransform.getTransform().getUniqueName());
-  }
 }
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ReshuffleTranslator.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ReshuffleTranslator.java
index 62bc2224f35..e82020238ce 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ReshuffleTranslator.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ReshuffleTranslator.java
@@ -84,7 +84,7 @@ public class ReshuffleTranslator<K, InT, OutT>
             ((KvCoder<K, InT>) windowedInputCoder.getValueCoder()).getKeyCoder(),
             windowedInputCoder,
             "rshfl-" + ctx.getTransformId(),
-            ctx.getSamzaPipelineOptions().getMaxSourceParallelism() > 1);
+            ctx.getPipelineOptions().getMaxSourceParallelism() > 1);
 
     ctx.registerMessageStream(outputId, outputStream);
   }
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPortablePipelineTranslator.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPortablePipelineTranslator.java
index e5bc2cd8f91..055dc4f2d72 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPortablePipelineTranslator.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPortablePipelineTranslator.java
@@ -59,9 +59,7 @@ public class SamzaPortablePipelineTranslator {
   private SamzaPortablePipelineTranslator() {}
 
   public static void translate(RunnerApi.Pipeline pipeline, PortableTranslationContext ctx) {
-    QueryablePipeline queryablePipeline =
-        QueryablePipeline.forTransforms(
-            pipeline.getRootTransformIdsList(), pipeline.getComponents());
+    QueryablePipeline queryablePipeline = QueryablePipeline.forPipeline(pipeline);
 
     for (PipelineNode.PTransformNode transform :
         queryablePipeline.getTopologicallyOrderedTransforms()) {
@@ -78,9 +76,7 @@ public class SamzaPortablePipelineTranslator {
 
   public static void createConfig(
       RunnerApi.Pipeline pipeline, ConfigBuilder configBuilder, SamzaPipelineOptions options) {
-    QueryablePipeline queryablePipeline =
-        QueryablePipeline.forTransforms(
-            pipeline.getRootTransformIdsList(), pipeline.getComponents());
+    QueryablePipeline queryablePipeline = QueryablePipeline.forPipeline(pipeline);
     for (PipelineNode.PTransformNode transform :
         queryablePipeline.getTopologicallyOrderedTransforms()) {
       TransformTranslator<?> translator =
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/TranslationContext.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/TranslationContext.java
index f9193ddfaaf..84885945049 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/TranslationContext.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/TranslationContext.java
@@ -24,6 +24,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.UUID;
+import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import org.apache.beam.runners.core.construction.TransformInputs;
 import org.apache.beam.runners.samza.SamzaPipelineOptions;
@@ -109,6 +110,13 @@ public class TranslationContext {
    */
   public <OutT> void registerInputMessageStreams(
       PValue pvalue, List<? extends InputDescriptor<KV<?, OpMessage<OutT>>, ?>> inputDescriptors) {
+    registerInputMessageStreams(pvalue, inputDescriptors, this::registerMessageStream);
+  }
+
+  protected <KeyT, OutT> void registerInputMessageStreams(
+      KeyT key,
+      List<? extends InputDescriptor<KV<?, OpMessage<OutT>>, ?>> inputDescriptors,
+      BiConsumer<KeyT, MessageStream<OpMessage<OutT>>> registerFunction) {
     final Set<MessageStream<OpMessage<OutT>>> streamsToMerge = new HashSet<>();
     for (InputDescriptor<KV<?, OpMessage<OutT>>, ?> inputDescriptor : inputDescriptors) {
       final String streamId = inputDescriptor.getStreamId();
@@ -119,7 +127,7 @@ public class TranslationContext {
         LOG.info(
             String.format(
                 "Stream id %s has already been mapped to %s stream. Mapping %s to the same message stream.",
-                streamId, messageStream, pvalue));
+                streamId, messageStream, key));
         streamsToMerge.add(messageStream);
       } else {
         final MessageStream<OpMessage<OutT>> typedStream =
@@ -128,7 +136,8 @@ public class TranslationContext {
         streamsToMerge.add(typedStream);
       }
     }
-    registerMessageStream(pvalue, MessageStream.mergeAll(streamsToMerge));
+
+    registerFunction.accept(key, MessageStream.mergeAll(streamsToMerge));
   }
 
   public <OutT> void registerMessageStream(PValue pvalue, MessageStream<OpMessage<OutT>> stream) {
@@ -237,7 +246,7 @@ public class TranslationContext {
   }
 
   public String getTransformId() {
-    return idGenerator.getId(currentTransform.getFullName());
+    return idGenerator.getId(getTransformFullName());
   }
 
   /** The dummy stream created will only be used in Beam tests. */