You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by xi...@apache.org on 2022/09/09 21:19:01 UTC
[beam] branch master updated: Consolidate Samza TranslationContext and PortableTranslationContext (#23072)
This is an automated email from the ASF dual-hosted git repository.
xinyu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new d17914862da Consolidate Samza TranslationContext and PortableTranslationContext (#23072)
d17914862da is described below
commit d17914862dadf94376f02b01e056e517b1224236
Author: Bharath Kumarasubramanian <bh...@apache.org>
AuthorDate: Fri Sep 9 14:18:54 2022 -0700
Consolidate Samza TranslationContext and PortableTranslationContext (#23072)
---
.../samza/translation/GroupByKeyTranslator.java | 2 +-
.../translation/ParDoBoundMultiTranslator.java | 2 +-
.../translation/PortableTranslationContext.java | 51 ++++------------------
.../samza/translation/ReshuffleTranslator.java | 2 +-
.../SamzaPortablePipelineTranslator.java | 8 +---
.../samza/translation/TranslationContext.java | 15 +++++--
6 files changed, 26 insertions(+), 54 deletions(-)
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/GroupByKeyTranslator.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/GroupByKeyTranslator.java
index f12675863ce..2982d35b806 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/GroupByKeyTranslator.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/GroupByKeyTranslator.java
@@ -155,7 +155,7 @@ class GroupByKeyTranslator<K, InputT, OutputT>
WindowedValue.WindowedValueCoder<KV<K, InputT>> windowedInputCoder,
TupleTag<KV<K, OutputT>> outputTag,
PortableTranslationContext ctx) {
- final boolean needRepartition = ctx.getSamzaPipelineOptions().getMaxSourceParallelism() > 1;
+ final boolean needRepartition = ctx.getPipelineOptions().getMaxSourceParallelism() > 1;
final Coder<BoundedWindow> windowCoder = windowingStrategy.getWindowFn().windowCoder();
final KvCoder<K, InputT> kvInputCoder = (KvCoder<K, InputT>) windowedInputCoder.getValueCoder();
final Coder<WindowedValue<KV<K, InputT>>> elementCoder =
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java
index b64d5871f80..e032a0d8345 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java
@@ -510,7 +510,7 @@ class ParDoBoundMultiTranslator<InT, OutT>
coder.withValueCoder(IterableCoder.of(coder.getValueCoder())),
ctx.getTransformId(),
getSideInputUniqueId(sideInputId),
- ctx.getSamzaPipelineOptions());
+ ctx.getPipelineOptions());
return broadcastSideInput;
}
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/PortableTranslationContext.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/PortableTranslationContext.java
index 1f8fecc12da..776ee80878d 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/PortableTranslationContext.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/PortableTranslationContext.java
@@ -18,26 +18,20 @@
package org.apache.beam.runners.samza.translation;
import java.util.Collection;
+import java.util.Collections;
import java.util.HashMap;
-import java.util.HashSet;
import java.util.List;
import java.util.Map;
-import java.util.Set;
import java.util.stream.Collectors;
import org.apache.beam.runners.core.construction.graph.PipelineNode;
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
import org.apache.beam.runners.samza.SamzaPipelineOptions;
import org.apache.beam.runners.samza.runtime.OpMessage;
-import org.apache.beam.runners.samza.util.HashIdGenerator;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.samza.application.descriptors.StreamApplicationDescriptor;
import org.apache.samza.operators.KV;
import org.apache.samza.operators.MessageStream;
-import org.apache.samza.operators.OutputStream;
import org.apache.samza.system.descriptors.InputDescriptor;
-import org.apache.samza.system.descriptors.OutputDescriptor;
-import org.apache.samza.table.Table;
-import org.apache.samza.table.descriptors.TableDescriptor;
/**
* Helper that keeps the mapping from BEAM PCollection id to Samza {@link MessageStream}. It also
@@ -48,26 +42,16 @@ import org.apache.samza.table.descriptors.TableDescriptor;
"rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
-public class PortableTranslationContext {
+public class PortableTranslationContext extends TranslationContext {
private final Map<String, MessageStream<?>> messageStreams = new HashMap<>();
- private final StreamApplicationDescriptor appDescriptor;
private final JobInfo jobInfo;
- private final SamzaPipelineOptions options;
- private final Set<String> registeredInputStreams = new HashSet<>();
- private final Map<String, Table> registeredTables = new HashMap<>();
- private final HashIdGenerator idGenerator = new HashIdGenerator();
private PipelineNode.PTransformNode currentTransform;
public PortableTranslationContext(
StreamApplicationDescriptor appDescriptor, SamzaPipelineOptions options, JobInfo jobInfo) {
+ super(appDescriptor, Collections.emptyMap(), options);
this.jobInfo = jobInfo;
- this.appDescriptor = appDescriptor;
- this.options = options;
- }
-
- public SamzaPipelineOptions getSamzaPipelineOptions() {
- return this.options;
}
public <T> List<MessageStream<OpMessage<T>>> getAllInputMessageStreams(
@@ -106,45 +90,28 @@ public class PortableTranslationContext {
messageStreams.put(id, stream);
}
- /** Get output stream by output descriptor. */
- public <OutT> OutputStream<OutT> getOutputStream(OutputDescriptor<OutT, ?> outputDescriptor) {
- return appDescriptor.getOutputStream(outputDescriptor);
- }
-
/** Register an input stream with certain config id. */
public <T> void registerInputMessageStream(
String id, InputDescriptor<KV<?, OpMessage<T>>, ?> inputDescriptor) {
- // we want to register it with the Samza graph only once per i/o stream
- final String streamId = inputDescriptor.getStreamId();
- if (registeredInputStreams.contains(streamId)) {
- return;
- }
- final MessageStream<OpMessage<T>> stream =
- appDescriptor.getInputStream(inputDescriptor).map(org.apache.samza.operators.KV::getValue);
-
- registerMessageStream(id, stream);
- registeredInputStreams.add(streamId);
+ registerInputMessageStreams(id, Collections.singletonList(inputDescriptor));
}
- @SuppressWarnings("unchecked")
- public <K, V> Table<KV<K, V>> getTable(TableDescriptor<K, V, ?> tableDesc) {
- return registeredTables.computeIfAbsent(
- tableDesc.getTableId(), id -> appDescriptor.getTable(tableDesc));
+ public <T> void registerInputMessageStreams(
+ String id, List<? extends InputDescriptor<KV<?, OpMessage<T>>, ?>> inputDescriptors) {
+ registerInputMessageStreams(id, inputDescriptors, this::registerMessageStream);
}
public void setCurrentTransform(PipelineNode.PTransformNode currentTransform) {
this.currentTransform = currentTransform;
}
+ @Override
public void clearCurrentTransform() {
this.currentTransform = null;
}
+ @Override
public String getTransformFullName() {
return currentTransform.getTransform().getUniqueName();
}
-
- public String getTransformId() {
- return idGenerator.getId(currentTransform.getTransform().getUniqueName());
- }
}
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ReshuffleTranslator.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ReshuffleTranslator.java
index 62bc2224f35..e82020238ce 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ReshuffleTranslator.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ReshuffleTranslator.java
@@ -84,7 +84,7 @@ public class ReshuffleTranslator<K, InT, OutT>
((KvCoder<K, InT>) windowedInputCoder.getValueCoder()).getKeyCoder(),
windowedInputCoder,
"rshfl-" + ctx.getTransformId(),
- ctx.getSamzaPipelineOptions().getMaxSourceParallelism() > 1);
+ ctx.getPipelineOptions().getMaxSourceParallelism() > 1);
ctx.registerMessageStream(outputId, outputStream);
}
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPortablePipelineTranslator.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPortablePipelineTranslator.java
index e5bc2cd8f91..055dc4f2d72 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPortablePipelineTranslator.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPortablePipelineTranslator.java
@@ -59,9 +59,7 @@ public class SamzaPortablePipelineTranslator {
private SamzaPortablePipelineTranslator() {}
public static void translate(RunnerApi.Pipeline pipeline, PortableTranslationContext ctx) {
- QueryablePipeline queryablePipeline =
- QueryablePipeline.forTransforms(
- pipeline.getRootTransformIdsList(), pipeline.getComponents());
+ QueryablePipeline queryablePipeline = QueryablePipeline.forPipeline(pipeline);
for (PipelineNode.PTransformNode transform :
queryablePipeline.getTopologicallyOrderedTransforms()) {
@@ -78,9 +76,7 @@ public class SamzaPortablePipelineTranslator {
public static void createConfig(
RunnerApi.Pipeline pipeline, ConfigBuilder configBuilder, SamzaPipelineOptions options) {
- QueryablePipeline queryablePipeline =
- QueryablePipeline.forTransforms(
- pipeline.getRootTransformIdsList(), pipeline.getComponents());
+ QueryablePipeline queryablePipeline = QueryablePipeline.forPipeline(pipeline);
for (PipelineNode.PTransformNode transform :
queryablePipeline.getTopologicallyOrderedTransforms()) {
TransformTranslator<?> translator =
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/TranslationContext.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/TranslationContext.java
index f9193ddfaaf..84885945049 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/TranslationContext.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/TranslationContext.java
@@ -24,6 +24,7 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
+import java.util.function.BiConsumer;
import java.util.function.Consumer;
import org.apache.beam.runners.core.construction.TransformInputs;
import org.apache.beam.runners.samza.SamzaPipelineOptions;
@@ -109,6 +110,13 @@ public class TranslationContext {
*/
public <OutT> void registerInputMessageStreams(
PValue pvalue, List<? extends InputDescriptor<KV<?, OpMessage<OutT>>, ?>> inputDescriptors) {
+ registerInputMessageStreams(pvalue, inputDescriptors, this::registerMessageStream);
+ }
+
+ protected <KeyT, OutT> void registerInputMessageStreams(
+ KeyT key,
+ List<? extends InputDescriptor<KV<?, OpMessage<OutT>>, ?>> inputDescriptors,
+ BiConsumer<KeyT, MessageStream<OpMessage<OutT>>> registerFunction) {
final Set<MessageStream<OpMessage<OutT>>> streamsToMerge = new HashSet<>();
for (InputDescriptor<KV<?, OpMessage<OutT>>, ?> inputDescriptor : inputDescriptors) {
final String streamId = inputDescriptor.getStreamId();
@@ -119,7 +127,7 @@ public class TranslationContext {
LOG.info(
String.format(
"Stream id %s has already been mapped to %s stream. Mapping %s to the same message stream.",
- streamId, messageStream, pvalue));
+ streamId, messageStream, key));
streamsToMerge.add(messageStream);
} else {
final MessageStream<OpMessage<OutT>> typedStream =
@@ -128,7 +136,8 @@ public class TranslationContext {
streamsToMerge.add(typedStream);
}
}
- registerMessageStream(pvalue, MessageStream.mergeAll(streamsToMerge));
+
+ registerFunction.accept(key, MessageStream.mergeAll(streamsToMerge));
}
public <OutT> void registerMessageStream(PValue pvalue, MessageStream<OpMessage<OutT>> stream) {
@@ -237,7 +246,7 @@ public class TranslationContext {
}
public String getTransformId() {
- return idGenerator.getId(currentTransform.getFullName());
+ return idGenerator.getId(getTransformFullName());
}
/** The dummy stream created will only be used in Beam tests. */