You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by xi...@apache.org on 2022/09/28 22:20:44 UTC
[beam] branch master updated: Adds support in Samza Runner to run DoFn.processElement in parallel inside Samza tasks (#23313)
This is an automated email from the ASF dual-hosted git repository.
xinyu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 406da4ae224 Adds support in Samza Runner to run DoFn.processElement in parallel inside Samza tasks (#23313)
406da4ae224 is described below
commit 406da4ae2242b3c6a124a1c161f222403e1e67f1
Author: Xinyu Liu <xi...@gmail.com>
AuthorDate: Wed Sep 28 15:20:36 2022 -0700
Adds support in Samza Runner to run DoFn.processElement in parallel inside Samza tasks (#23313)
---
.../beam/runners/samza/SamzaPipelineOptions.java | 31 ++++
.../samza/SamzaPipelineOptionsValidator.java | 7 +-
.../runners/samza/runtime/AsyncDoFnRunner.java | 118 ++++++++++++++
.../apache/beam/runners/samza/runtime/DoFnOp.java | 51 ++++--
.../runners/samza/runtime/FutureCollector.java | 7 +
.../beam/runners/samza/runtime/OpAdapter.java | 83 +++++-----
.../beam/runners/samza/runtime/OpEmitter.java | 6 +
.../runners/samza/runtime/SamzaDoFnRunners.java | 13 +-
.../runners/samza/translation/ConfigBuilder.java | 19 +++
.../beam/runners/samza/util/FutureUtils.java | 17 ++
.../runners/samza/runtime/AsyncDoFnRunnerTest.java | 171 +++++++++++++++++++++
11 files changed, 465 insertions(+), 58 deletions(-)
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineOptions.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineOptions.java
index c0af1fab0d9..814b14f98b8 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineOptions.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineOptions.java
@@ -20,9 +20,14 @@ package org.apache.beam.runners.samza;
import com.fasterxml.jackson.annotation.JsonIgnore;
import java.util.List;
import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
import org.apache.beam.sdk.options.Default;
+import org.apache.beam.sdk.options.DefaultValueFactory;
import org.apache.beam.sdk.options.Description;
+import org.apache.beam.sdk.options.Hidden;
import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.samza.config.ConfigLoaderFactory;
import org.apache.samza.config.loaders.PropertiesConfigLoaderFactory;
import org.apache.samza.metrics.MetricsReporter;
@@ -129,4 +134,30 @@ public interface SamzaPipelineOptions extends PipelineOptions {
long getMaxBundleTimeMs();
void setMaxBundleTimeMs(long maxBundleTimeMs);
+
+ @Description(
+ "The number of threads to run DoFn.processElements in parallel within a bundle. Used only in non-portable mode.")
+ @Default.Integer(1)
+ int getNumThreadsForProcessElement();
+
+ void setNumThreadsForProcessElement(int numThreads);
+
+ @JsonIgnore
+ @Description(
+ "The ExecutorService instance to run DoFN.processElements in parallel within a bundle. Used only in non-portable mode.")
+ @Default.InstanceFactory(ProcessElementExecutorServiceFactory.class)
+ @Hidden
+ ExecutorService getExecutorServiceForProcessElement();
+
+ void setExecutorServiceForProcessElement(ExecutorService executorService);
+
+ class ProcessElementExecutorServiceFactory implements DefaultValueFactory<ExecutorService> {
+
+ @Override
+ public ExecutorService create(PipelineOptions options) {
+ return Executors.newFixedThreadPool(
+ options.as(SamzaPipelineOptions.class).getNumThreadsForProcessElement(),
+ new ThreadFactoryBuilder().setNameFormat("Process Element Thread-%d").build());
+ }
+ }
}
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineOptionsValidator.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineOptionsValidator.java
index 7702b6bb41f..5beb9fe0e56 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineOptionsValidator.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineOptionsValidator.java
@@ -44,14 +44,11 @@ public class SamzaPipelineOptionsValidator {
: pipelineOptions.getConfigOverride();
final JobConfig jobConfig = new JobConfig(new MapConfig(configs));
- // TODO: once Samza supports a better thread pool modle, e.g. thread
- // per-task/key-range, this can be supported.
+ // Validate that the threadPoolSize is not overridden in the code
checkArgument(
jobConfig.getThreadPoolSize() <= 1,
JOB_CONTAINER_THREAD_POOL_SIZE
- + " cannot be configured to"
- + " greater than 1 for max bundle size: "
- + pipelineOptions.getMaxBundleSize());
+ + " config should be replaced with SamzaPipelineOptions.numThreadsForProcessElement");
}
}
}
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/AsyncDoFnRunner.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/AsyncDoFnRunner.java
new file mode 100644
index 00000000000..7120696aa4f
--- /dev/null
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/AsyncDoFnRunner.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.samza.runtime;
+
+import java.util.Collection;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.stream.Collectors;
+import org.apache.beam.runners.core.DoFnRunner;
+import org.apache.beam.runners.samza.SamzaPipelineOptions;
+import org.apache.beam.sdk.state.TimeDomain;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This {@link DoFnRunner} adds the capability of executing the {@link
+ * org.apache.beam.sdk.transforms.DoFn.ProcessElement} in the thread pool, and returns the future to
+ * the collector for the underlying async execution.
+ */
+public class AsyncDoFnRunner<InT, OutT> implements DoFnRunner<InT, OutT> {
+ private static final Logger LOG = LoggerFactory.getLogger(AsyncDoFnRunner.class);
+
+ private final DoFnRunner<InT, OutT> underlying;
+ private final ExecutorService executor;
+ private final OpEmitter<OutT> emitter;
+ private final FutureCollector<OutT> futureCollector;
+
+ public static <InT, OutT> AsyncDoFnRunner<InT, OutT> create(
+ DoFnRunner<InT, OutT> runner,
+ OpEmitter<OutT> emitter,
+ FutureCollector<OutT> futureCollector,
+ SamzaPipelineOptions options) {
+
+ LOG.info("Run DoFn with " + AsyncDoFnRunner.class.getName());
+ return new AsyncDoFnRunner<>(runner, emitter, futureCollector, options);
+ }
+
+ private AsyncDoFnRunner(
+ DoFnRunner<InT, OutT> runner,
+ OpEmitter<OutT> emitter,
+ FutureCollector<OutT> futureCollector,
+ SamzaPipelineOptions options) {
+ this.underlying = runner;
+ this.executor = options.getExecutorServiceForProcessElement();
+ this.emitter = emitter;
+ this.futureCollector = futureCollector;
+ }
+
+ @Override
+ public void startBundle() {
+ underlying.startBundle();
+ }
+
+ @Override
+ public void processElement(WindowedValue<InT> elem) {
+ final CompletableFuture<Void> future =
+ CompletableFuture.runAsync(
+ () -> {
+ underlying.processElement(elem);
+ },
+ executor);
+
+ final CompletableFuture<Collection<WindowedValue<OutT>>> outputFutures =
+ future.thenApply(
+ x ->
+ emitter.collectOutput().stream()
+ .map(OpMessage::getElement)
+ .collect(Collectors.toList()));
+
+ futureCollector.addAll(outputFutures);
+ }
+
+ @Override
+ public <KeyT> void onTimer(
+ String timerId,
+ String timerFamilyId,
+ KeyT key,
+ BoundedWindow window,
+ Instant timestamp,
+ Instant outputTimestamp,
+ TimeDomain timeDomain) {
+ underlying.onTimer(timerId, timerFamilyId, key, window, timestamp, outputTimestamp, timeDomain);
+ }
+
+ @Override
+ public void finishBundle() {
+ underlying.finishBundle();
+ }
+
+ @Override
+ public <KeyT> void onWindowExpiration(BoundedWindow window, Instant timestamp, KeyT key) {
+ underlying.onWindowExpiration(window, timestamp, key);
+ }
+
+ @Override
+ public DoFn<InT, OutT> getFn() {
+ return underlying.getFn();
+ }
+}
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java
index 714693cc687..735ec62cd32 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java
@@ -21,12 +21,12 @@ import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Prec
import java.util.ArrayList;
import java.util.Collection;
-import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
+import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
@@ -259,7 +259,9 @@ public class DoFnOp<InT, FnOutT, OutT> implements Op<InT, OutT, Void> {
sideOutputTags,
outputCoders,
doFnSchemaInformation,
- (Map<String, PCollectionView<?>>) sideInputMapping);
+ (Map<String, PCollectionView<?>>) sideInputMapping,
+ emitter,
+ outputFutureCollector);
}
this.pushbackFnRunner =
@@ -479,30 +481,50 @@ public class DoFnOp<InT, FnOutT, OutT> implements Op<InT, OutT, Void> {
}
static class FutureCollectorImpl<OutT> implements FutureCollector<OutT> {
- private final List<CompletionStage<WindowedValue<OutT>>> outputFutures;
private final AtomicBoolean collectorSealed;
+ private CompletionStage<Collection<WindowedValue<OutT>>> outputFuture;
FutureCollectorImpl() {
- /*
- * Choosing synchronized list here since the concurrency is low as the message dispatch thread is single threaded.
- * We need this guard against scenarios when watermark/finish bundle trigger outputs.
- */
- outputFutures = Collections.synchronizedList(new ArrayList<>());
+ outputFuture = CompletableFuture.completedFuture(new ArrayList<>());
collectorSealed = new AtomicBoolean(true);
}
@Override
public void add(CompletionStage<WindowedValue<OutT>> element) {
+ checkState(
+ !collectorSealed.get(),
+ "Cannot add element to an unprepared collector. Make sure prepare() is invoked before adding elements.");
+
+ // We need to synchronize to guard against scenarios when watermark/finish bundle trigger outputs.
+ synchronized (this) {
+ outputFuture =
+ outputFuture.thenCombine(
+ element,
+ (collection, event) -> {
+ collection.add(event);
+ return collection;
+ });
+ }
+ }
+
+ @Override
+ public void addAll(CompletionStage<Collection<WindowedValue<OutT>>> elements) {
checkState(
!collectorSealed.get(),
"Cannot add elements to an unprepared collector. Make sure prepare() is invoked before adding elements.");
- outputFutures.add(element);
+
+ synchronized (this) {
+ outputFuture = FutureUtils.combineFutures(outputFuture, elements);
+ }
}
@Override
public void discard() {
collectorSealed.compareAndSet(false, true);
- outputFutures.clear();
+
+ synchronized (this) {
+ outputFuture = CompletableFuture.completedFuture(new ArrayList<>());
+ }
}
@Override
@@ -513,10 +535,11 @@ public class DoFnOp<InT, FnOutT, OutT> implements Op<InT, OutT, Void> {
*/
collectorSealed.compareAndSet(false, true);
- CompletionStage<Collection<WindowedValue<OutT>>> sealedOutputFuture =
- FutureUtils.flattenFutures(outputFutures);
- outputFutures.clear();
- return sealedOutputFuture;
+ synchronized (this) {
+ final CompletionStage<Collection<WindowedValue<OutT>>> sealedOutputFuture = outputFuture;
+ outputFuture = CompletableFuture.completedFuture(new ArrayList<>());
+ return sealedOutputFuture;
+ }
}
@Override
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/FutureCollector.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/FutureCollector.java
index acb2ebaa713..c606b756935 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/FutureCollector.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/FutureCollector.java
@@ -36,6 +36,13 @@ public interface FutureCollector<OutT> {
*/
void add(CompletionStage<WindowedValue<OutT>> element);
+ /**
+ * Outputs a collection of elements to the collector.
+ *
+ * @param elements to add to the collector
+ */
+ void addAll(CompletionStage<Collection<WindowedValue<OutT>>> elements);
+
/**
* Discards the elements within the collector. Once the elements have been discarded, callers need
* to prepare the collector again before invoking {@link #add(CompletionStage)}.
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/OpAdapter.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/OpAdapter.java
index aa090d9b772..eabcd87f5f3 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/OpAdapter.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/OpAdapter.java
@@ -21,13 +21,16 @@ import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
+import java.util.Queue;
import java.util.ServiceLoader;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
+import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.beam.runners.samza.SamzaPipelineExceptionContext;
import org.apache.beam.runners.samza.translation.TranslationContext;
+import org.apache.beam.runners.samza.util.FutureUtils;
import org.apache.beam.runners.samza.util.SamzaPipelineExceptionListener;
import org.apache.beam.sdk.util.UserCodeException;
import org.apache.beam.sdk.util.WindowedValue;
@@ -57,9 +60,6 @@ public class OpAdapter<InT, OutT, K>
private final Op<InT, OutT, K> op;
private final String transformFullName;
- private transient List<OpMessage<OutT>> outputList;
- private transient CompletionStage<Collection<OpMessage<OutT>>> outputFuture;
- private transient Instant outputWatermark;
private transient OpEmitter<OutT> emitter;
private transient Config config;
private transient Context context;
@@ -77,8 +77,7 @@ public class OpAdapter<InT, OutT, K>
@Override
public final void init(Context context) {
- this.outputList = new ArrayList<>();
- this.emitter = new OpEmitterImpl();
+ this.emitter = new OpEmitterImpl<>();
this.config = context.getJobContext().getConfig();
this.context = context;
this.exceptionListeners =
@@ -97,8 +96,6 @@ public class OpAdapter<InT, OutT, K>
@Override
public synchronized CompletionStage<Collection<OpMessage<OutT>>> apply(OpMessage<InT> message) {
- assert outputList.isEmpty();
-
try {
switch (message.getType()) {
case ELEMENT:
@@ -121,27 +118,13 @@ public class OpAdapter<InT, OutT, K>
}
CompletionStage<Collection<OpMessage<OutT>>> resultFuture =
- CompletableFuture.completedFuture(new ArrayList<>(outputList));
-
- if (outputFuture != null) {
- resultFuture =
- resultFuture.thenCombine(
- outputFuture,
- (res1, res2) -> {
- res1.addAll(res2);
- return res1;
- });
- }
+ CompletableFuture.completedFuture(emitter.collectOutput());
- outputList.clear();
- outputFuture = null;
- return resultFuture;
+ return FutureUtils.combineFutures(resultFuture, emitter.collectFuture());
}
@Override
public synchronized Collection<OpMessage<OutT>> processWatermark(long time) {
- assert outputList.isEmpty();
-
try {
op.processWatermark(new Instant(time), emitter);
} catch (Exception e) {
@@ -150,21 +133,17 @@ public class OpAdapter<InT, OutT, K>
throw UserCodeException.wrap(e);
}
- final List<OpMessage<OutT>> results = new ArrayList<>(outputList);
- outputList.clear();
- return results;
+ return emitter.collectOutput();
}
@Override
public synchronized Long getOutputWatermark() {
- return outputWatermark != null ? outputWatermark.getMillis() : null;
+ return emitter.collectWatermark();
}
@Override
public synchronized Collection<OpMessage<OutT>> onCallback(
KeyedTimerData<K> keyedTimerData, long time) {
- assert outputList.isEmpty();
-
try {
op.processTimer(keyedTimerData, emitter);
} catch (Exception e) {
@@ -172,9 +151,7 @@ public class OpAdapter<InT, OutT, K>
throw UserCodeException.wrap(e);
}
- final List<OpMessage<OutT>> results = new ArrayList<>(outputList);
- outputList.clear();
- return results;
+ return emitter.collectOutput();
}
@Override
@@ -182,17 +159,27 @@ public class OpAdapter<InT, OutT, K>
op.close();
}
- private class OpEmitterImpl implements OpEmitter<OutT> {
+ private static class OpEmitterImpl<OutT> implements OpEmitter<OutT> {
+ private final Queue<OpMessage<OutT>> outputQueue;
+ private CompletionStage<Collection<OpMessage<OutT>>> outputFuture;
+ private Instant outputWatermark;
+
+ private OpEmitterImpl() {
+ outputQueue = new ConcurrentLinkedQueue<>();
+ }
+
@Override
public void emitElement(WindowedValue<OutT> element) {
- outputList.add(OpMessage.ofElement(element));
+ outputQueue.add(OpMessage.ofElement(element));
}
@Override
public void emitFuture(CompletionStage<Collection<WindowedValue<OutT>>> resultFuture) {
- outputFuture =
+ final CompletionStage<Collection<OpMessage<OutT>>> resultFutureWrapped =
resultFuture.thenApply(
res -> res.stream().map(OpMessage::ofElement).collect(Collectors.toList()));
+
+ outputFuture = FutureUtils.combineFutures(outputFuture, resultFutureWrapped);
}
@Override
@@ -202,7 +189,31 @@ public class OpAdapter<InT, OutT, K>
@Override
public <T> void emitView(String id, WindowedValue<Iterable<T>> elements) {
- outputList.add(OpMessage.ofSideInput(id, elements));
+ outputQueue.add(OpMessage.ofSideInput(id, elements));
+ }
+
+ @Override
+ public Collection<OpMessage<OutT>> collectOutput() {
+ final List<OpMessage<OutT>> outputList = new ArrayList<>();
+ OpMessage<OutT> output;
+ while ((output = outputQueue.poll()) != null) {
+ outputList.add(output);
+ }
+ return outputList;
+ }
+
+ @Override
+ public CompletionStage<Collection<OpMessage<OutT>>> collectFuture() {
+ final CompletionStage<Collection<OpMessage<OutT>>> future = outputFuture;
+ outputFuture = null;
+ return future;
+ }
+
+ @Override
+ public Long collectWatermark() {
+ final Instant watermark = outputWatermark;
+ outputWatermark = null;
+ return watermark == null ? null : watermark.getMillis();
}
}
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/OpEmitter.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/OpEmitter.java
index 951f5df6e46..cefbf0f8a2b 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/OpEmitter.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/OpEmitter.java
@@ -32,4 +32,10 @@ public interface OpEmitter<OutT> {
void emitWatermark(Instant watermark);
<T> void emitView(String id, WindowedValue<Iterable<T>> elements);
+
+ Collection<OpMessage<OutT>> collectOutput();
+
+ CompletionStage<Collection<OpMessage<OutT>>> collectFuture();
+
+ Long collectWatermark();
}
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java
index dc9f3301105..12872b82d8f 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java
@@ -90,7 +90,9 @@ public class SamzaDoFnRunners {
List<TupleTag<?>> sideOutputTags,
Map<TupleTag<?>, Coder<?>> outputCoders,
DoFnSchemaInformation doFnSchemaInformation,
- Map<String, PCollectionView<?>> sideInputMapping) {
+ Map<String, PCollectionView<?>> sideInputMapping,
+ OpEmitter emitter,
+ FutureCollector futureCollector) {
final KeyedInternals keyedInternals;
final TimerInternals timerInternals;
final StateInternals stateInternals;
@@ -133,6 +135,7 @@ public class SamzaDoFnRunners {
underlyingRunner, executionContext.getMetricsContainer(), transformFullName)
: underlyingRunner;
+ final DoFnRunner<InT, FnOutT> doFnRunnerWithStates;
if (keyedInternals != null) {
final DoFnRunner<InT, FnOutT> statefulDoFnRunner =
DoFnRunners.defaultStatefulDoFnRunner(
@@ -144,10 +147,14 @@ public class SamzaDoFnRunners {
new StatefulDoFnRunner.TimeInternalsCleanupTimer(timerInternals, windowingStrategy),
createStateCleaner(doFn, windowingStrategy, keyedInternals.stateInternals()));
- return new DoFnRunnerWithKeyedInternals<>(statefulDoFnRunner, keyedInternals);
+ doFnRunnerWithStates = new DoFnRunnerWithKeyedInternals<>(statefulDoFnRunner, keyedInternals);
} else {
- return doFnRunnerWithMetrics;
+ doFnRunnerWithStates = doFnRunnerWithMetrics;
}
+
+ return pipelineOptions.getNumThreadsForProcessElement() > 1
+ ? AsyncDoFnRunner.create(doFnRunnerWithStates, emitter, futureCollector, pipelineOptions)
+ : doFnRunnerWithStates;
}
/** Creates a {@link StepContext} that allows accessing state and timer internals. */
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ConfigBuilder.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ConfigBuilder.java
index 6ed4003cdcf..50650ece96c 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ConfigBuilder.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ConfigBuilder.java
@@ -18,6 +18,7 @@
package org.apache.beam.runners.samza.translation;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+import static org.apache.samza.config.JobConfig.JOB_CONTAINER_THREAD_POOL_SIZE;
import static org.apache.samza.config.JobConfig.JOB_ID;
import static org.apache.samza.config.JobConfig.JOB_NAME;
import static org.apache.samza.config.TaskConfig.COMMIT_MS;
@@ -95,7 +96,25 @@ public class ConfigBuilder {
config.put(ApplicationConfig.APP_ID, options.getJobInstance());
config.put(JOB_NAME, options.getJobName());
config.put(JOB_ID, options.getJobInstance());
+
+ // bundle-related configs
config.put(MAX_CONCURRENCY, String.valueOf(options.getMaxBundleSize()));
+ if (options.getMaxBundleSize() > 1) {
+ final String threadPoolSizeStr = config.remove(JOB_CONTAINER_THREAD_POOL_SIZE);
+ final int threadPoolSize =
+ threadPoolSizeStr == null ? 0 : Integer.parseInt(threadPoolSizeStr);
+
+ if (threadPoolSize > 1 && options.getNumThreadsForProcessElement() <= 1) {
+ // In case the user sets the thread pool through samza config instead of options,
+ // set the bundle thread pool size based on container thread pool config
+ LOG.info(
+ "Set NumThreadsForProcessElement based on "
+ + JOB_CONTAINER_THREAD_POOL_SIZE
+ + " to "
+ + threadPoolSize);
+ options.setNumThreadsForProcessElement(threadPoolSize);
+ }
+ }
// remove config overrides before serialization (LISAMZA-15259)
options.setConfigOverride(new HashMap<>());
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/util/FutureUtils.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/util/FutureUtils.java
index 09ad77bac57..a6ca0351255 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/util/FutureUtils.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/util/FutureUtils.java
@@ -47,4 +47,21 @@ public final class FutureUtils {
return result;
});
}
+
+ public static <T> CompletionStage<Collection<T>> combineFutures(
+ CompletionStage<Collection<T>> future1, CompletionStage<Collection<T>> future2) {
+
+ if (future1 == null) {
+ return future2;
+ } else if (future2 == null) {
+ return future1;
+ } else {
+ return future1.thenCombine(
+ future2,
+ (c1, c2) -> {
+ c1.addAll(c2);
+ return c1;
+ });
+ }
+ }
}
diff --git a/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/AsyncDoFnRunnerTest.java b/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/AsyncDoFnRunnerTest.java
new file mode 100644
index 00000000000..d62a28b374f
--- /dev/null
+++ b/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/AsyncDoFnRunnerTest.java
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.samza.runtime;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.state.CombiningState;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.Filter;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Sum;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.junit.Rule;
+import org.junit.Test;
+
+@SuppressWarnings({
+ "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
+ // TODO(https://github.com/apache/beam/issues/21230): Remove when new version of
+ // errorprone is released (2.11.0)
+ "unused"
+})
+/**
+ * Tests for {@link AsyncDoFnRunner}.
+ *
+ * <p>Note due to the bug in SAMZA-2761, end-of-stream can cause shutdown while there are still
+ * messages in process in asynchronous mode. As a temporary solution, we add more bundles to process
+ * in the test inputs.
+ */
+public class AsyncDoFnRunnerTest implements Serializable {
+
+ @Rule
+ public final transient TestPipeline pipeline =
+ TestPipeline.fromOptions(
+ PipelineOptionsFactory.fromArgs(
+ "--runner=TestSamzaRunner",
+ "--maxBundleSize=5",
+ "--numThreadsForProcessElement=5")
+ .create());
+
+ @Test
+ public void testSimplePipeline() {
+ List<Integer> input = new ArrayList<>();
+ for (int i = 1; i < 20; i++) {
+ input.add(i);
+ }
+ PCollection<Integer> square =
+ pipeline
+ .apply(Create.of(input))
+ .apply(Filter.by(x -> x <= 5))
+ .apply(MapElements.into(TypeDescriptors.integers()).via(x -> x * x));
+
+ PAssert.that(square).containsInAnyOrder(Arrays.asList(1, 4, 9, 16, 25));
+
+ pipeline.run();
+ }
+
+ @Test
+ public void testPipelineWithState() {
+ final List<KV<String, String>> input =
+ new ArrayList<>(
+ Arrays.asList(
+ KV.of("apple", "red"),
+ KV.of("banana", "yellow"),
+ KV.of("apple", "yellow"),
+ KV.of("grape", "purple"),
+ KV.of("banana", "yellow")));
+ final Map<String, Integer> expectedCount = ImmutableMap.of("apple", 2, "banana", 2, "grape", 1);
+
+ // TODO: remove after SAMZA-2761 fix
+ for (int i = 0; i < 20; i++) {
+ input.add(KV.of("*", "*"));
+ }
+
+ final DoFn<KV<String, String>, KV<String, Integer>> fn =
+ new DoFn<KV<String, String>, KV<String, Integer>>() {
+
+ @StateId("cc")
+ private final StateSpec<CombiningState<Integer, int[], Integer>> countState =
+ StateSpecs.combiningFromInputInternal(VarIntCoder.of(), Sum.ofIntegers());
+
+ @ProcessElement
+ public void processElement(
+ ProcessContext c, @StateId("cc") CombiningState<Integer, int[], Integer> countState) {
+
+ if (c.element().getKey().equals("*")) {
+ return;
+ }
+
+ // Need explicit synchronization here
+ synchronized (this) {
+ countState.add(1);
+ }
+
+ String key = c.element().getKey();
+ int n = countState.read();
+ if (n >= expectedCount.get(key)) {
+ c.output(KV.of(key, n));
+ }
+ }
+ };
+
+ PCollection<KV<String, Integer>> counts = pipeline.apply(Create.of(input)).apply(ParDo.of(fn));
+
+ PAssert.that(counts)
+ .containsInAnyOrder(
+ expectedCount.entrySet().stream()
+ .map(entry -> KV.of(entry.getKey(), entry.getValue()))
+ .collect(Collectors.toList()));
+
+ pipeline.run();
+ }
+
+ @Test
+ public void testPipelineWithAggregation() {
+ final List<KV<String, Long>> input =
+ new ArrayList<>(
+ Arrays.asList(
+ KV.of("apple", 2L),
+ KV.of("banana", 5L),
+ KV.of("apple", 8L),
+ KV.of("grape", 10L),
+ KV.of("banana", 5L)));
+
+ // TODO: remove after SAMZA-2761 fix
+ for (int i = 0; i < 20; i++) {
+ input.add(KV.of("*", 0L));
+ }
+
+ PCollection<KV<String, Long>> sums =
+ pipeline
+ .apply(Create.of(input))
+ .apply(Filter.by(x -> !x.getKey().equals("*")))
+ .apply(Sum.longsPerKey());
+
+ PAssert.that(sums)
+ .containsInAnyOrder(
+ Arrays.asList(KV.of("apple", 10L), KV.of("banana", 10L), KV.of("grape", 10L)));
+
+ pipeline.run();
+ }
+}