You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by xi...@apache.org on 2022/09/28 22:20:44 UTC

[beam] branch master updated: Adds support in Samza Runner to run DoFn.processElement in parallel inside Samza tasks (#23313)

This is an automated email from the ASF dual-hosted git repository.

xinyu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 406da4ae224 Adds support in Samza Runner to run DoFn.processElement in parallel inside Samza tasks (#23313)
406da4ae224 is described below

commit 406da4ae2242b3c6a124a1c161f222403e1e67f1
Author: Xinyu Liu <xi...@gmail.com>
AuthorDate: Wed Sep 28 15:20:36 2022 -0700

    Adds support in Samza Runner to run DoFn.processElement in parallel inside Samza tasks (#23313)
---
 .../beam/runners/samza/SamzaPipelineOptions.java   |  31 ++++
 .../samza/SamzaPipelineOptionsValidator.java       |   7 +-
 .../runners/samza/runtime/AsyncDoFnRunner.java     | 118 ++++++++++++++
 .../apache/beam/runners/samza/runtime/DoFnOp.java  |  51 ++++--
 .../runners/samza/runtime/FutureCollector.java     |   7 +
 .../beam/runners/samza/runtime/OpAdapter.java      |  83 +++++-----
 .../beam/runners/samza/runtime/OpEmitter.java      |   6 +
 .../runners/samza/runtime/SamzaDoFnRunners.java    |  13 +-
 .../runners/samza/translation/ConfigBuilder.java   |  19 +++
 .../beam/runners/samza/util/FutureUtils.java       |  17 ++
 .../runners/samza/runtime/AsyncDoFnRunnerTest.java | 171 +++++++++++++++++++++
 11 files changed, 465 insertions(+), 58 deletions(-)

diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineOptions.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineOptions.java
index c0af1fab0d9..814b14f98b8 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineOptions.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineOptions.java
@@ -20,9 +20,14 @@ package org.apache.beam.runners.samza;
 import com.fasterxml.jackson.annotation.JsonIgnore;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 import org.apache.beam.sdk.options.Default;
+import org.apache.beam.sdk.options.DefaultValueFactory;
 import org.apache.beam.sdk.options.Description;
+import org.apache.beam.sdk.options.Hidden;
 import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
 import org.apache.samza.config.ConfigLoaderFactory;
 import org.apache.samza.config.loaders.PropertiesConfigLoaderFactory;
 import org.apache.samza.metrics.MetricsReporter;
@@ -129,4 +134,30 @@ public interface SamzaPipelineOptions extends PipelineOptions {
   long getMaxBundleTimeMs();
 
   void setMaxBundleTimeMs(long maxBundleTimeMs);
+
+  @Description(
+      "The number of threads to run DoFn.processElements in parallel within a bundle. Used only in non-portable mode.")
+  @Default.Integer(1)
+  int getNumThreadsForProcessElement();
+
+  void setNumThreadsForProcessElement(int numThreads);
+
+  @JsonIgnore
+  @Description(
+      "The ExecutorService instance to run DoFN.processElements in parallel within a bundle. Used only in non-portable mode.")
+  @Default.InstanceFactory(ProcessElementExecutorServiceFactory.class)
+  @Hidden
+  ExecutorService getExecutorServiceForProcessElement();
+
+  void setExecutorServiceForProcessElement(ExecutorService executorService);
+
+  class ProcessElementExecutorServiceFactory implements DefaultValueFactory<ExecutorService> {
+
+    @Override
+    public ExecutorService create(PipelineOptions options) {
+      return Executors.newFixedThreadPool(
+          options.as(SamzaPipelineOptions.class).getNumThreadsForProcessElement(),
+          new ThreadFactoryBuilder().setNameFormat("Process Element Thread-%d").build());
+    }
+  }
 }
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineOptionsValidator.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineOptionsValidator.java
index 7702b6bb41f..5beb9fe0e56 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineOptionsValidator.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineOptionsValidator.java
@@ -44,14 +44,11 @@ public class SamzaPipelineOptionsValidator {
               : pipelineOptions.getConfigOverride();
       final JobConfig jobConfig = new JobConfig(new MapConfig(configs));
 
-      // TODO: once Samza supports a better thread pool modle, e.g. thread
-      // per-task/key-range, this can be supported.
+      // Validate that the threadPoolSize is not overridden in the code
       checkArgument(
           jobConfig.getThreadPoolSize() <= 1,
           JOB_CONTAINER_THREAD_POOL_SIZE
-              + " cannot be configured to"
-              + " greater than 1 for max bundle size: "
-              + pipelineOptions.getMaxBundleSize());
+              + " config should be replaced with SamzaPipelineOptions.numThreadsForProcessElement");
     }
   }
 }
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/AsyncDoFnRunner.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/AsyncDoFnRunner.java
new file mode 100644
index 00000000000..7120696aa4f
--- /dev/null
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/AsyncDoFnRunner.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.samza.runtime;
+
+import java.util.Collection;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.stream.Collectors;
+import org.apache.beam.runners.core.DoFnRunner;
+import org.apache.beam.runners.samza.SamzaPipelineOptions;
+import org.apache.beam.sdk.state.TimeDomain;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This {@link DoFnRunner} adds the capability of executing the {@link
+ * org.apache.beam.sdk.transforms.DoFn.ProcessElement} in the thread pool, and returns the future to
+ * the collector for the underlying async execution.
+ */
+public class AsyncDoFnRunner<InT, OutT> implements DoFnRunner<InT, OutT> {
+  private static final Logger LOG = LoggerFactory.getLogger(AsyncDoFnRunner.class);
+
+  private final DoFnRunner<InT, OutT> underlying;
+  private final ExecutorService executor;
+  private final OpEmitter<OutT> emitter;
+  private final FutureCollector<OutT> futureCollector;
+
+  public static <InT, OutT> AsyncDoFnRunner<InT, OutT> create(
+      DoFnRunner<InT, OutT> runner,
+      OpEmitter<OutT> emitter,
+      FutureCollector<OutT> futureCollector,
+      SamzaPipelineOptions options) {
+
+    LOG.info("Run DoFn with " + AsyncDoFnRunner.class.getName());
+    return new AsyncDoFnRunner<>(runner, emitter, futureCollector, options);
+  }
+
+  private AsyncDoFnRunner(
+      DoFnRunner<InT, OutT> runner,
+      OpEmitter<OutT> emitter,
+      FutureCollector<OutT> futureCollector,
+      SamzaPipelineOptions options) {
+    this.underlying = runner;
+    this.executor = options.getExecutorServiceForProcessElement();
+    this.emitter = emitter;
+    this.futureCollector = futureCollector;
+  }
+
+  @Override
+  public void startBundle() {
+    underlying.startBundle();
+  }
+
+  @Override
+  public void processElement(WindowedValue<InT> elem) {
+    final CompletableFuture<Void> future =
+        CompletableFuture.runAsync(
+            () -> {
+              underlying.processElement(elem);
+            },
+            executor);
+
+    final CompletableFuture<Collection<WindowedValue<OutT>>> outputFutures =
+        future.thenApply(
+            x ->
+                emitter.collectOutput().stream()
+                    .map(OpMessage::getElement)
+                    .collect(Collectors.toList()));
+
+    futureCollector.addAll(outputFutures);
+  }
+
+  @Override
+  public <KeyT> void onTimer(
+      String timerId,
+      String timerFamilyId,
+      KeyT key,
+      BoundedWindow window,
+      Instant timestamp,
+      Instant outputTimestamp,
+      TimeDomain timeDomain) {
+    underlying.onTimer(timerId, timerFamilyId, key, window, timestamp, outputTimestamp, timeDomain);
+  }
+
+  @Override
+  public void finishBundle() {
+    underlying.finishBundle();
+  }
+
+  @Override
+  public <KeyT> void onWindowExpiration(BoundedWindow window, Instant timestamp, KeyT key) {
+    underlying.onWindowExpiration(window, timestamp, key);
+  }
+
+  @Override
+  public DoFn<InT, OutT> getFn() {
+    return underlying.getFn();
+  }
+}
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java
index 714693cc687..735ec62cd32 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java
@@ -21,12 +21,12 @@ import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Prec
 
 import java.util.ArrayList;
 import java.util.Collection;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.ServiceLoader;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionStage;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Function;
@@ -259,7 +259,9 @@ public class DoFnOp<InT, FnOutT, OutT> implements Op<InT, OutT, Void> {
               sideOutputTags,
               outputCoders,
               doFnSchemaInformation,
-              (Map<String, PCollectionView<?>>) sideInputMapping);
+              (Map<String, PCollectionView<?>>) sideInputMapping,
+              emitter,
+              outputFutureCollector);
     }
 
     this.pushbackFnRunner =
@@ -479,30 +481,50 @@ public class DoFnOp<InT, FnOutT, OutT> implements Op<InT, OutT, Void> {
   }
 
   static class FutureCollectorImpl<OutT> implements FutureCollector<OutT> {
-    private final List<CompletionStage<WindowedValue<OutT>>> outputFutures;
     private final AtomicBoolean collectorSealed;
+    private CompletionStage<Collection<WindowedValue<OutT>>> outputFuture;
 
     FutureCollectorImpl() {
-      /*
-       * Choosing synchronized list here since the concurrency is low as the message dispatch thread is single threaded.
-       * We need this guard against scenarios when watermark/finish bundle trigger outputs.
-       */
-      outputFutures = Collections.synchronizedList(new ArrayList<>());
+      outputFuture = CompletableFuture.completedFuture(new ArrayList<>());
       collectorSealed = new AtomicBoolean(true);
     }
 
     @Override
     public void add(CompletionStage<WindowedValue<OutT>> element) {
+      checkState(
+          !collectorSealed.get(),
+          "Cannot add element to an unprepared collector. Make sure prepare() is invoked before adding elements.");
+
+      // Synchronize to guard against scenarios when watermark/finish bundle trigger outputs.
+      synchronized (this) {
+        outputFuture =
+            outputFuture.thenCombine(
+                element,
+                (collection, event) -> {
+                  collection.add(event);
+                  return collection;
+                });
+      }
+    }
+
+    @Override
+    public void addAll(CompletionStage<Collection<WindowedValue<OutT>>> elements) {
       checkState(
           !collectorSealed.get(),
           "Cannot add elements to an unprepared collector. Make sure prepare() is invoked before adding elements.");
-      outputFutures.add(element);
+
+      synchronized (this) {
+        outputFuture = FutureUtils.combineFutures(outputFuture, elements);
+      }
     }
 
     @Override
     public void discard() {
       collectorSealed.compareAndSet(false, true);
-      outputFutures.clear();
+
+      synchronized (this) {
+        outputFuture = CompletableFuture.completedFuture(new ArrayList<>());
+      }
     }
 
     @Override
@@ -513,10 +535,11 @@ public class DoFnOp<InT, FnOutT, OutT> implements Op<InT, OutT, Void> {
        */
       collectorSealed.compareAndSet(false, true);
 
-      CompletionStage<Collection<WindowedValue<OutT>>> sealedOutputFuture =
-          FutureUtils.flattenFutures(outputFutures);
-      outputFutures.clear();
-      return sealedOutputFuture;
+      synchronized (this) {
+        final CompletionStage<Collection<WindowedValue<OutT>>> sealedOutputFuture = outputFuture;
+        outputFuture = CompletableFuture.completedFuture(new ArrayList<>());
+        return sealedOutputFuture;
+      }
     }
 
     @Override
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/FutureCollector.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/FutureCollector.java
index acb2ebaa713..c606b756935 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/FutureCollector.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/FutureCollector.java
@@ -36,6 +36,13 @@ public interface FutureCollector<OutT> {
    */
   void add(CompletionStage<WindowedValue<OutT>> element);
 
+  /**
+   * Outputs a collection of elements to the collector.
+   *
+   * @param elements to add to the collector
+   */
+  void addAll(CompletionStage<Collection<WindowedValue<OutT>>> elements);
+
   /**
    * Discards the elements within the collector. Once the elements have been discarded, callers need
    * to prepare the collector again before invoking {@link #add(CompletionStage)}.
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/OpAdapter.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/OpAdapter.java
index aa090d9b772..eabcd87f5f3 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/OpAdapter.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/OpAdapter.java
@@ -21,13 +21,16 @@ import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
+import java.util.Queue;
 import java.util.ServiceLoader;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionStage;
+import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.stream.Collectors;
 import java.util.stream.StreamSupport;
 import org.apache.beam.runners.samza.SamzaPipelineExceptionContext;
 import org.apache.beam.runners.samza.translation.TranslationContext;
+import org.apache.beam.runners.samza.util.FutureUtils;
 import org.apache.beam.runners.samza.util.SamzaPipelineExceptionListener;
 import org.apache.beam.sdk.util.UserCodeException;
 import org.apache.beam.sdk.util.WindowedValue;
@@ -57,9 +60,6 @@ public class OpAdapter<InT, OutT, K>
 
   private final Op<InT, OutT, K> op;
   private final String transformFullName;
-  private transient List<OpMessage<OutT>> outputList;
-  private transient CompletionStage<Collection<OpMessage<OutT>>> outputFuture;
-  private transient Instant outputWatermark;
   private transient OpEmitter<OutT> emitter;
   private transient Config config;
   private transient Context context;
@@ -77,8 +77,7 @@ public class OpAdapter<InT, OutT, K>
 
   @Override
   public final void init(Context context) {
-    this.outputList = new ArrayList<>();
-    this.emitter = new OpEmitterImpl();
+    this.emitter = new OpEmitterImpl<>();
     this.config = context.getJobContext().getConfig();
     this.context = context;
     this.exceptionListeners =
@@ -97,8 +96,6 @@ public class OpAdapter<InT, OutT, K>
 
   @Override
   public synchronized CompletionStage<Collection<OpMessage<OutT>>> apply(OpMessage<InT> message) {
-    assert outputList.isEmpty();
-
     try {
       switch (message.getType()) {
         case ELEMENT:
@@ -121,27 +118,13 @@ public class OpAdapter<InT, OutT, K>
     }
 
     CompletionStage<Collection<OpMessage<OutT>>> resultFuture =
-        CompletableFuture.completedFuture(new ArrayList<>(outputList));
-
-    if (outputFuture != null) {
-      resultFuture =
-          resultFuture.thenCombine(
-              outputFuture,
-              (res1, res2) -> {
-                res1.addAll(res2);
-                return res1;
-              });
-    }
+        CompletableFuture.completedFuture(emitter.collectOutput());
 
-    outputList.clear();
-    outputFuture = null;
-    return resultFuture;
+    return FutureUtils.combineFutures(resultFuture, emitter.collectFuture());
   }
 
   @Override
   public synchronized Collection<OpMessage<OutT>> processWatermark(long time) {
-    assert outputList.isEmpty();
-
     try {
       op.processWatermark(new Instant(time), emitter);
     } catch (Exception e) {
@@ -150,21 +133,17 @@ public class OpAdapter<InT, OutT, K>
       throw UserCodeException.wrap(e);
     }
 
-    final List<OpMessage<OutT>> results = new ArrayList<>(outputList);
-    outputList.clear();
-    return results;
+    return emitter.collectOutput();
   }
 
   @Override
   public synchronized Long getOutputWatermark() {
-    return outputWatermark != null ? outputWatermark.getMillis() : null;
+    return emitter.collectWatermark();
   }
 
   @Override
   public synchronized Collection<OpMessage<OutT>> onCallback(
       KeyedTimerData<K> keyedTimerData, long time) {
-    assert outputList.isEmpty();
-
     try {
       op.processTimer(keyedTimerData, emitter);
     } catch (Exception e) {
@@ -172,9 +151,7 @@ public class OpAdapter<InT, OutT, K>
       throw UserCodeException.wrap(e);
     }
 
-    final List<OpMessage<OutT>> results = new ArrayList<>(outputList);
-    outputList.clear();
-    return results;
+    return emitter.collectOutput();
   }
 
   @Override
@@ -182,17 +159,27 @@ public class OpAdapter<InT, OutT, K>
     op.close();
   }
 
-  private class OpEmitterImpl implements OpEmitter<OutT> {
+  private static class OpEmitterImpl<OutT> implements OpEmitter<OutT> {
+    private final Queue<OpMessage<OutT>> outputQueue;
+    private CompletionStage<Collection<OpMessage<OutT>>> outputFuture;
+    private Instant outputWatermark;
+
+    private OpEmitterImpl() {
+      outputQueue = new ConcurrentLinkedQueue<>();
+    }
+
     @Override
     public void emitElement(WindowedValue<OutT> element) {
-      outputList.add(OpMessage.ofElement(element));
+      outputQueue.add(OpMessage.ofElement(element));
     }
 
     @Override
     public void emitFuture(CompletionStage<Collection<WindowedValue<OutT>>> resultFuture) {
-      outputFuture =
+      final CompletionStage<Collection<OpMessage<OutT>>> resultFutureWrapped =
           resultFuture.thenApply(
               res -> res.stream().map(OpMessage::ofElement).collect(Collectors.toList()));
+
+      outputFuture = FutureUtils.combineFutures(outputFuture, resultFutureWrapped);
     }
 
     @Override
@@ -202,7 +189,31 @@ public class OpAdapter<InT, OutT, K>
 
     @Override
     public <T> void emitView(String id, WindowedValue<Iterable<T>> elements) {
-      outputList.add(OpMessage.ofSideInput(id, elements));
+      outputQueue.add(OpMessage.ofSideInput(id, elements));
+    }
+
+    @Override
+    public Collection<OpMessage<OutT>> collectOutput() {
+      final List<OpMessage<OutT>> outputList = new ArrayList<>();
+      OpMessage<OutT> output;
+      while ((output = outputQueue.poll()) != null) {
+        outputList.add(output);
+      }
+      return outputList;
+    }
+
+    @Override
+    public CompletionStage<Collection<OpMessage<OutT>>> collectFuture() {
+      final CompletionStage<Collection<OpMessage<OutT>>> future = outputFuture;
+      outputFuture = null;
+      return future;
+    }
+
+    @Override
+    public Long collectWatermark() {
+      final Instant watermark = outputWatermark;
+      outputWatermark = null;
+      return watermark == null ? null : watermark.getMillis();
     }
   }
 
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/OpEmitter.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/OpEmitter.java
index 951f5df6e46..cefbf0f8a2b 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/OpEmitter.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/OpEmitter.java
@@ -32,4 +32,10 @@ public interface OpEmitter<OutT> {
   void emitWatermark(Instant watermark);
 
   <T> void emitView(String id, WindowedValue<Iterable<T>> elements);
+
+  Collection<OpMessage<OutT>> collectOutput();
+
+  CompletionStage<Collection<OpMessage<OutT>>> collectFuture();
+
+  Long collectWatermark();
 }
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java
index dc9f3301105..12872b82d8f 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java
@@ -90,7 +90,9 @@ public class SamzaDoFnRunners {
       List<TupleTag<?>> sideOutputTags,
       Map<TupleTag<?>, Coder<?>> outputCoders,
       DoFnSchemaInformation doFnSchemaInformation,
-      Map<String, PCollectionView<?>> sideInputMapping) {
+      Map<String, PCollectionView<?>> sideInputMapping,
+      OpEmitter emitter,
+      FutureCollector futureCollector) {
     final KeyedInternals keyedInternals;
     final TimerInternals timerInternals;
     final StateInternals stateInternals;
@@ -133,6 +135,7 @@ public class SamzaDoFnRunners {
                 underlyingRunner, executionContext.getMetricsContainer(), transformFullName)
             : underlyingRunner;
 
+    final DoFnRunner<InT, FnOutT> doFnRunnerWithStates;
     if (keyedInternals != null) {
       final DoFnRunner<InT, FnOutT> statefulDoFnRunner =
           DoFnRunners.defaultStatefulDoFnRunner(
@@ -144,10 +147,14 @@ public class SamzaDoFnRunners {
               new StatefulDoFnRunner.TimeInternalsCleanupTimer(timerInternals, windowingStrategy),
               createStateCleaner(doFn, windowingStrategy, keyedInternals.stateInternals()));
 
-      return new DoFnRunnerWithKeyedInternals<>(statefulDoFnRunner, keyedInternals);
+      doFnRunnerWithStates = new DoFnRunnerWithKeyedInternals<>(statefulDoFnRunner, keyedInternals);
     } else {
-      return doFnRunnerWithMetrics;
+      doFnRunnerWithStates = doFnRunnerWithMetrics;
     }
+
+    return pipelineOptions.getNumThreadsForProcessElement() > 1
+        ? AsyncDoFnRunner.create(doFnRunnerWithStates, emitter, futureCollector, pipelineOptions)
+        : doFnRunnerWithStates;
   }
 
   /** Creates a {@link StepContext} that allows accessing state and timer internals. */
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ConfigBuilder.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ConfigBuilder.java
index 6ed4003cdcf..50650ece96c 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ConfigBuilder.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ConfigBuilder.java
@@ -18,6 +18,7 @@
 package org.apache.beam.runners.samza.translation;
 
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+import static org.apache.samza.config.JobConfig.JOB_CONTAINER_THREAD_POOL_SIZE;
 import static org.apache.samza.config.JobConfig.JOB_ID;
 import static org.apache.samza.config.JobConfig.JOB_NAME;
 import static org.apache.samza.config.TaskConfig.COMMIT_MS;
@@ -95,7 +96,25 @@ public class ConfigBuilder {
       config.put(ApplicationConfig.APP_ID, options.getJobInstance());
       config.put(JOB_NAME, options.getJobName());
       config.put(JOB_ID, options.getJobInstance());
+
+      // bundle-related configs
       config.put(MAX_CONCURRENCY, String.valueOf(options.getMaxBundleSize()));
+      if (options.getMaxBundleSize() > 1) {
+        final String threadPoolSizeStr = config.remove(JOB_CONTAINER_THREAD_POOL_SIZE);
+        final int threadPoolSize =
+            threadPoolSizeStr == null ? 0 : Integer.parseInt(threadPoolSizeStr);
+
+        if (threadPoolSize > 1 && options.getNumThreadsForProcessElement() <= 1) {
+          // In case the user sets the thread pool through samza config instead of options,
+          // set the bundle thread pool size based on container thread pool config
+          LOG.info(
+              "Set NumThreadsForProcessElement based on "
+                  + JOB_CONTAINER_THREAD_POOL_SIZE
+                  + " to "
+                  + threadPoolSize);
+          options.setNumThreadsForProcessElement(threadPoolSize);
+        }
+      }
 
       // remove config overrides before serialization (LISAMZA-15259)
       options.setConfigOverride(new HashMap<>());
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/util/FutureUtils.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/util/FutureUtils.java
index 09ad77bac57..a6ca0351255 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/util/FutureUtils.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/util/FutureUtils.java
@@ -47,4 +47,21 @@ public final class FutureUtils {
               return result;
             });
   }
+
+  public static <T> CompletionStage<Collection<T>> combineFutures(
+      CompletionStage<Collection<T>> future1, CompletionStage<Collection<T>> future2) {
+
+    if (future1 == null) {
+      return future2;
+    } else if (future2 == null) {
+      return future1;
+    } else {
+      return future1.thenCombine(
+          future2,
+          (c1, c2) -> {
+            c1.addAll(c2);
+            return c1;
+          });
+    }
+  }
 }
diff --git a/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/AsyncDoFnRunnerTest.java b/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/AsyncDoFnRunnerTest.java
new file mode 100644
index 00000000000..d62a28b374f
--- /dev/null
+++ b/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/AsyncDoFnRunnerTest.java
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.samza.runtime;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.state.CombiningState;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.Filter;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Sum;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.junit.Rule;
+import org.junit.Test;
+
+@SuppressWarnings({
+  "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
+  // TODO(https://github.com/apache/beam/issues/21230): Remove when new version of
+  // errorprone is released (2.11.0)
+  "unused"
+})
+/**
+ * Tests for {@link AsyncDoFnRunner}.
+ *
+ * <p>Note due to the bug in SAMZA-2761, end-of-stream can cause shutdown while there are still
+ * messages in process in asynchronous mode. As a temporary solution, we add more bundles to process
+ * in the test inputs.
+ */
+public class AsyncDoFnRunnerTest implements Serializable {
+
+  @Rule
+  public final transient TestPipeline pipeline =
+      TestPipeline.fromOptions(
+          PipelineOptionsFactory.fromArgs(
+                  "--runner=TestSamzaRunner",
+                  "--maxBundleSize=5",
+                  "--numThreadsForProcessElement=5")
+              .create());
+
+  @Test
+  public void testSimplePipeline() {
+    List<Integer> input = new ArrayList<>();
+    for (int i = 1; i < 20; i++) {
+      input.add(i);
+    }
+    PCollection<Integer> square =
+        pipeline
+            .apply(Create.of(input))
+            .apply(Filter.by(x -> x <= 5))
+            .apply(MapElements.into(TypeDescriptors.integers()).via(x -> x * x));
+
+    PAssert.that(square).containsInAnyOrder(Arrays.asList(1, 4, 9, 16, 25));
+
+    pipeline.run();
+  }
+
+  @Test
+  public void testPipelineWithState() {
+    final List<KV<String, String>> input =
+        new ArrayList<>(
+            Arrays.asList(
+                KV.of("apple", "red"),
+                KV.of("banana", "yellow"),
+                KV.of("apple", "yellow"),
+                KV.of("grape", "purple"),
+                KV.of("banana", "yellow")));
+    final Map<String, Integer> expectedCount = ImmutableMap.of("apple", 2, "banana", 2, "grape", 1);
+
+    // TODO: remove after SAMZA-2761 fix
+    for (int i = 0; i < 20; i++) {
+      input.add(KV.of("*", "*"));
+    }
+
+    final DoFn<KV<String, String>, KV<String, Integer>> fn =
+        new DoFn<KV<String, String>, KV<String, Integer>>() {
+
+          @StateId("cc")
+          private final StateSpec<CombiningState<Integer, int[], Integer>> countState =
+              StateSpecs.combiningFromInputInternal(VarIntCoder.of(), Sum.ofIntegers());
+
+          @ProcessElement
+          public void processElement(
+              ProcessContext c, @StateId("cc") CombiningState<Integer, int[], Integer> countState) {
+
+            if (c.element().getKey().equals("*")) {
+              return;
+            }
+
+            // Need explicit synchronization here
+            synchronized (this) {
+              countState.add(1);
+            }
+
+            String key = c.element().getKey();
+            int n = countState.read();
+            if (n >= expectedCount.get(key)) {
+              c.output(KV.of(key, n));
+            }
+          }
+        };
+
+    PCollection<KV<String, Integer>> counts = pipeline.apply(Create.of(input)).apply(ParDo.of(fn));
+
+    PAssert.that(counts)
+        .containsInAnyOrder(
+            expectedCount.entrySet().stream()
+                .map(entry -> KV.of(entry.getKey(), entry.getValue()))
+                .collect(Collectors.toList()));
+
+    pipeline.run();
+  }
+
+  @Test
+  public void testPipelineWithAggregation() {
+    final List<KV<String, Long>> input =
+        new ArrayList<>(
+            Arrays.asList(
+                KV.of("apple", 2L),
+                KV.of("banana", 5L),
+                KV.of("apple", 8L),
+                KV.of("grape", 10L),
+                KV.of("banana", 5L)));
+
+    // TODO: remove after SAMZA-2761 fix
+    for (int i = 0; i < 20; i++) {
+      input.add(KV.of("*", 0L));
+    }
+
+    PCollection<KV<String, Long>> sums =
+        pipeline
+            .apply(Create.of(input))
+            .apply(Filter.by(x -> !x.getKey().equals("*")))
+            .apply(Sum.longsPerKey());
+
+    PAssert.that(sums)
+        .containsInAnyOrder(
+            Arrays.asList(KV.of("apple", 10L), KV.of("banana", 10L), KV.of("grape", 10L)));
+
+    pipeline.run();
+  }
+}