You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lc...@apache.org on 2020/09/16 16:49:44 UTC

[beam] branch master updated: [BEAM-10670] Update Samza to be opt-out for SplittableDoFn.

This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new b981b30  [BEAM-10670] Update Samza to be opt-out for SplittableDoFn.
     new aa69ae5  Merge pull request #12617 from lukecwik/beam10670.3
b981b30 is described below

commit b981b30c9f56ec204921f4429e9c04ccda66cf99
Author: Luke Cwik <lc...@google.com>
AuthorDate: Fri Sep 11 18:29:10 2020 -0700

    [BEAM-10670] Update Samza to be opt-out for SplittableDoFn.
---
 CHANGES.md                                         |   2 +-
 runners/samza/build.gradle                         |  11 +-
 .../org/apache/beam/runners/samza/SamzaRunner.java |   2 +
 .../SplittableParDoProcessKeyedElementsOp.java     | 251 +++++++++++++++++++++
 .../samza/translation/ImpulseTranslator.java       |  24 +-
 .../translation/ParDoBoundMultiTranslator.java     |   4 +-
 .../samza/translation/SamzaPipelineTranslator.java |   3 +
 .../samza/translation/SamzaTransformOverrides.java |   6 +
 .../translation/SplittableParDoTranslators.java    | 153 +++++++++++++
 .../runtime/SamzaStoreStateInternalsTest.java      |  26 +--
 .../src/main/java/org/apache/beam/sdk/io/Read.java |   8 +-
 11 files changed, 466 insertions(+), 24 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 5701546..6ebcca9 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -51,7 +51,7 @@
 
 ## Highlights
 
-* Splittable DoFn is opt-out for Java based runners (Direct, Flink, Twister2) using `--experiments=use_deprecated_read`. For all other runners, users can opt-in using `--experiments=use_sdf_read`. (Java) ([BEAM-10670](https://issues.apache.org/jira/browse/BEAM-10670))
+* Splittable DoFn is opt-out for Java based runners (Direct, Flink, Samza, Twister2) using `--experiments=use_deprecated_read`. For all other runners, users can opt-in using `--experiments=use_sdf_read`. (Java) ([BEAM-10670](https://issues.apache.org/jira/browse/BEAM-10670))
 * New highly anticipated feature X added to Python SDK ([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)).
 * New highly anticipated feature Y added to Java SDK ([BEAM-Y](https://issues.apache.org/jira/browse/BEAM-Y)).
 
diff --git a/runners/samza/build.gradle b/runners/samza/build.gradle
index b9ee247..6ea4c59 100644
--- a/runners/samza/build.gradle
+++ b/runners/samza/build.gradle
@@ -85,8 +85,6 @@ task validatesRunner(type: Test) {
     excludeCategories 'org.apache.beam.sdk.testing.LargeKeys$Above100MB'
     excludeCategories 'org.apache.beam.sdk.testing.UsesAttemptedMetrics'
     excludeCategories 'org.apache.beam.sdk.testing.UsesCommittedMetrics'
-    excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse'
-    excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo'
     excludeCategories 'org.apache.beam.sdk.testing.UsesTestStream'
     excludeCategories 'org.apache.beam.sdk.testing.UsesMetricsPusher'
     excludeCategories 'org.apache.beam.sdk.testing.UsesParDoLifecycle'
@@ -99,6 +97,15 @@ task validatesRunner(type: Test) {
   filter {
     // TODO(BEAM-10025)
     excludeTestsMatching 'org.apache.beam.sdk.transforms.ParDoTest$TimerTests.testOutputTimestampDefaultUnbounded'
+
+    // These tests fail since there is no support for side inputs in Samza's unbounded splittable DoFn integration
+    excludeTestsMatching 'org.apache.beam.sdk.transforms.SplittableDoFnTest.testWindowedSideInputWithCheckpointsUnbounded'
+    excludeTestsMatching 'org.apache.beam.sdk.transforms.SplittableDoFnTest.testSideInputUnbounded'
+    excludeTestsMatching 'org.apache.beam.sdk.transforms.SplittableDoFnTest.testWindowedSideInputUnbounded'
+    // These tests produce the output but either the pipeline doesn't shut down or PAssert fails
+    excludeTestsMatching 'org.apache.beam.sdk.transforms.SplittableDoFnTest.testAdditionalOutputUnbounded'
+    excludeTestsMatching 'org.apache.beam.sdk.transforms.SplittableDoFnTest.testPairWithIndexBasicUnbounded'
+    excludeTestsMatching 'org.apache.beam.sdk.transforms.SplittableDoFnTest.testOutputAfterCheckpointUnbounded'
   }
 }
 
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaRunner.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaRunner.java
index 8fafb18..9bde6d3 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaRunner.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaRunner.java
@@ -23,6 +23,7 @@ import java.util.Iterator;
 import java.util.Map;
 import java.util.ServiceLoader;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.runners.core.construction.SplittableParDo;
 import org.apache.beam.runners.core.construction.renderer.PipelineDotRenderer;
 import org.apache.beam.runners.jobsubmission.PortablePipelineResult;
 import org.apache.beam.runners.samza.translation.ConfigBuilder;
@@ -106,6 +107,7 @@ public class SamzaRunner extends PipelineRunner<SamzaPipelineResult> {
 
   @Override
   public SamzaPipelineResult run(Pipeline pipeline) {
+    SplittableParDo.validateNoPrimitiveReads(pipeline);
     MetricsEnvironment.setMetricsSupported(true);
 
     if (LOG.isDebugEnabled()) {
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SplittableParDoProcessKeyedElementsOp.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SplittableParDoProcessKeyedElementsOp.java
new file mode 100644
index 0000000..c6bd8dc
--- /dev/null
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SplittableParDoProcessKeyedElementsOp.java
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.samza.runtime;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.concurrent.Executors;
+import org.apache.beam.runners.core.DoFnRunner;
+import org.apache.beam.runners.core.DoFnRunners;
+import org.apache.beam.runners.core.KeyedWorkItem;
+import org.apache.beam.runners.core.KeyedWorkItems;
+import org.apache.beam.runners.core.NullSideInputReader;
+import org.apache.beam.runners.core.OutputAndTimeBoundedSplittableProcessElementInvoker;
+import org.apache.beam.runners.core.OutputWindowedValue;
+import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems;
+import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems.ProcessElements;
+import org.apache.beam.runners.core.StateInternals;
+import org.apache.beam.runners.core.StateInternalsFactory;
+import org.apache.beam.runners.core.StepContext;
+import org.apache.beam.runners.core.TimerInternals;
+import org.apache.beam.runners.core.TimerInternals.TimerData;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.core.construction.SplittableParDo;
+import org.apache.beam.runners.core.serialization.Base64Serializer;
+import org.apache.beam.runners.samza.SamzaPipelineOptions;
+import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
+import org.apache.beam.sdk.transforms.join.RawUnionValue;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection.IsBounded;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
+import org.apache.samza.operators.Scheduler;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Samza operator for {@link SplittableParDo.ProcessKeyedElements}. */
+public class SplittableParDoProcessKeyedElementsOp<
+        InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
+    implements Op<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, RawUnionValue, byte[]> {
+  private static final Logger LOG =
+      LoggerFactory.getLogger(SplittableParDoProcessKeyedElementsOp.class);
+  private static final String TIMER_STATE_ID = "timer";
+
+  private final TupleTag<OutputT> mainOutputTag;
+  private final WindowingStrategy<?, BoundedWindow> windowingStrategy;
+  private final OutputManagerFactory<RawUnionValue> outputManagerFactory;
+  private final SplittableParDoViaKeyedWorkItems.ProcessElements<
+          InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
+      processElements;
+  private final String transformFullName;
+  private final String transformId;
+  private final IsBounded isBounded;
+
+  private transient StateInternalsFactory<byte[]> stateInternalsFactory;
+  private transient SamzaTimerInternalsFactory<byte[]> timerInternalsFactory;
+  private transient DoFnRunner<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> fnRunner;
+  private transient SamzaPipelineOptions pipelineOptions;
+
+  public SplittableParDoProcessKeyedElementsOp(
+      TupleTag<OutputT> mainOutputTag,
+      SplittableParDo.ProcessKeyedElements<InputT, OutputT, RestrictionT, WatermarkEstimatorStateT>
+          processKeyedElements,
+      WindowingStrategy<?, BoundedWindow> windowingStrategy,
+      OutputManagerFactory<RawUnionValue> outputManagerFactory,
+      String transformFullName,
+      String transformId,
+      IsBounded isBounded) {
+    this.mainOutputTag = mainOutputTag;
+    this.windowingStrategy = windowingStrategy;
+    this.outputManagerFactory = outputManagerFactory;
+    this.transformFullName = transformFullName;
+    this.transformId = transformId;
+    this.isBounded = isBounded;
+
+    this.processElements = new ProcessElements<>(processKeyedElements);
+  }
+
+  @Override
+  public void open(
+      Config config,
+      Context context,
+      Scheduler<KeyedTimerData<byte[]>> timerRegistry,
+      OpEmitter<RawUnionValue> emitter) {
+    this.pipelineOptions =
+        Base64Serializer.deserializeUnchecked(
+                config.get("beamPipelineOptions"), SerializablePipelineOptions.class)
+            .get()
+            .as(SamzaPipelineOptions.class);
+
+    final SamzaStoreStateInternals.Factory<?> nonKeyedStateInternalsFactory =
+        SamzaStoreStateInternals.createStateInternalFactory(
+            transformId, null, context.getTaskContext(), pipelineOptions, null);
+
+    final DoFnRunners.OutputManager outputManager = outputManagerFactory.create(emitter);
+
+    this.stateInternalsFactory =
+        new SamzaStoreStateInternals.Factory<>(
+            transformId,
+            Collections.singletonMap(
+                SamzaStoreStateInternals.BEAM_STORE,
+                SamzaStoreStateInternals.getBeamStore(context.getTaskContext())),
+            ByteArrayCoder.of(),
+            pipelineOptions.getStoreBatchGetSize());
+
+    this.timerInternalsFactory =
+        SamzaTimerInternalsFactory.createTimerInternalFactory(
+            ByteArrayCoder.of(),
+            timerRegistry,
+            TIMER_STATE_ID,
+            nonKeyedStateInternalsFactory,
+            windowingStrategy,
+            isBounded,
+            pipelineOptions);
+
+    final KeyedInternals<byte[]> keyedInternals =
+        new KeyedInternals<>(stateInternalsFactory, timerInternalsFactory);
+
+    SplittableParDoViaKeyedWorkItems.ProcessFn<
+            InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
+        processFn = processElements.newProcessFn(processElements.getFn());
+    DoFnInvokers.tryInvokeSetupFor(processFn);
+    processFn.setStateInternalsFactory(stateInternalsFactory);
+    processFn.setTimerInternalsFactory(timerInternalsFactory);
+    processFn.setProcessElementInvoker(
+        new OutputAndTimeBoundedSplittableProcessElementInvoker<>(
+            processElements.getFn(),
+            pipelineOptions,
+            new OutputWindowedValue<OutputT>() {
+              @Override
+              public void outputWindowedValue(
+                  OutputT output,
+                  Instant timestamp,
+                  Collection<? extends BoundedWindow> windows,
+                  PaneInfo pane) {
+                outputWindowedValue(mainOutputTag, output, timestamp, windows, pane);
+              }
+
+              @Override
+              public <AdditionalOutputT> void outputWindowedValue(
+                  TupleTag<AdditionalOutputT> tag,
+                  AdditionalOutputT output,
+                  Instant timestamp,
+                  Collection<? extends BoundedWindow> windows,
+                  PaneInfo pane) {
+                outputManager.output(tag, WindowedValue.of(output, timestamp, windows, pane));
+              }
+            },
+            NullSideInputReader.empty(),
+            Executors.newSingleThreadScheduledExecutor(Executors.defaultThreadFactory()),
+            10000,
+            Duration.standardSeconds(10),
+            () -> {
+              throw new UnsupportedOperationException("BundleFinalizer unsupported in Samza");
+            }));
+
+    final StepContext stepContext =
+        new StepContext() {
+          @Override
+          public StateInternals stateInternals() {
+            return keyedInternals.stateInternals();
+          }
+
+          @Override
+          public TimerInternals timerInternals() {
+            return keyedInternals.timerInternals();
+          }
+        };
+
+    this.fnRunner =
+        DoFnRunners.simpleRunner(
+            pipelineOptions,
+            processFn,
+            NullSideInputReader.of(Collections.emptyList()),
+            outputManager,
+            mainOutputTag,
+            Collections.emptyList(),
+            stepContext,
+            null,
+            Collections.emptyMap(),
+            windowingStrategy,
+            DoFnSchemaInformation.create(),
+            Collections.emptyMap());
+  }
+
+  @Override
+  public void processElement(
+      WindowedValue<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> inputElement,
+      OpEmitter<RawUnionValue> emitter) {
+    fnRunner.startBundle();
+    fnRunner.processElement(inputElement);
+    fnRunner.finishBundle();
+  }
+
+  @Override
+  public void processWatermark(Instant watermark, OpEmitter<RawUnionValue> emitter) {
+    timerInternalsFactory.setInputWatermark(watermark);
+
+    fnRunner.startBundle();
+    for (KeyedTimerData<byte[]> keyedTimerData : timerInternalsFactory.removeReadyTimers()) {
+      fireTimer(keyedTimerData.getKey(), keyedTimerData.getTimerData());
+    }
+    fnRunner.finishBundle();
+
+    if (timerInternalsFactory.getOutputWatermark() == null
+        || timerInternalsFactory.getOutputWatermark().isBefore(watermark)) {
+      timerInternalsFactory.setOutputWatermark(watermark);
+      emitter.emitWatermark(timerInternalsFactory.getOutputWatermark());
+    }
+  }
+
+  @Override
+  public void processTimer(
+      KeyedTimerData<byte[]> keyedTimerData, OpEmitter<RawUnionValue> emitter) {
+    fnRunner.startBundle();
+    fireTimer(keyedTimerData.getKey(), keyedTimerData.getTimerData());
+    fnRunner.finishBundle();
+
+    timerInternalsFactory.removeProcessingTimer(keyedTimerData);
+  }
+
+  private void fireTimer(byte[] key, TimerData timer) {
+    LOG.debug("Firing timer {} for key {}", timer, key);
+    fnRunner.processElement(
+        WindowedValue.valueInGlobalWindow(
+            KeyedWorkItems.timersWorkItem(key, Collections.singletonList(timer))));
+  }
+}
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ImpulseTranslator.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ImpulseTranslator.java
index 2ccbd6c..4f316fa 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ImpulseTranslator.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ImpulseTranslator.java
@@ -20,6 +20,10 @@ package org.apache.beam.runners.samza.translation;
 import org.apache.beam.runners.core.construction.graph.PipelineNode;
 import org.apache.beam.runners.core.construction.graph.QueryablePipeline;
 import org.apache.beam.runners.samza.runtime.OpMessage;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.values.PBegin;
+import org.apache.beam.sdk.values.PCollection;
 import org.apache.samza.operators.KV;
 import org.apache.samza.serializers.KVSerde;
 import org.apache.samza.serializers.NoOpSerde;
@@ -32,7 +36,25 @@ import org.apache.samza.system.descriptors.GenericSystemDescriptor;
  * {@link
  * org.apache.beam.runners.samza.translation.SamzaImpulseSystemFactory.SamzaImpulseSystemConsumer}.
  */
-public class ImpulseTranslator implements TransformTranslator {
+public class ImpulseTranslator
+    implements TransformTranslator<PTransform<PBegin, PCollection<byte[]>>> {
+
+  @Override
+  public void translate(
+      PTransform<PBegin, PCollection<byte[]>> transform, Node node, TranslationContext ctx) {
+    final PCollection<byte[]> output = ctx.getOutput(transform);
+    final String outputId = ctx.getIdForPValue(output);
+    final GenericSystemDescriptor systemDescriptor =
+        new GenericSystemDescriptor(outputId, SamzaImpulseSystemFactory.class.getName());
+
+    // The KvCoder is needed here for Samza not to crop the key.
+    final Serde<KV<?, OpMessage<byte[]>>> kvSerde = KVSerde.of(new NoOpSerde(), new NoOpSerde<>());
+    final GenericInputDescriptor<KV<?, OpMessage<byte[]>>> inputDescriptor =
+        systemDescriptor.getInputDescriptor(outputId, kvSerde);
+
+    ctx.registerInputMessageStream(output, inputDescriptor);
+  }
+
   @Override
   public void translatePortable(
       PipelineNode.PTransformNode transform,
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java
index a289cf7..78b0f9d 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java
@@ -331,7 +331,7 @@ class ParDoBoundMultiTranslator<InT, OutT>
     return config;
   }
 
-  private static class SideInputWatermarkFn<InT>
+  static class SideInputWatermarkFn<InT>
       implements FlatMapFunction<OpMessage<InT>, OpMessage<InT>>,
           WatermarkFunction<OpMessage<InT>> {
 
@@ -352,7 +352,7 @@ class ParDoBoundMultiTranslator<InT, OutT>
     }
   }
 
-  private static class RawUnionValueToValue<OutT> implements Op<RawUnionValue, OutT, Void> {
+  static class RawUnionValueToValue<OutT> implements Op<RawUnionValue, OutT, Void> {
     @Override
     public void processElement(WindowedValue<RawUnionValue> inputElement, OpEmitter<OutT> emitter) {
       @SuppressWarnings("unchecked")
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPipelineTranslator.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPipelineTranslator.java
index bfa2e10..c4424d6 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPipelineTranslator.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPipelineTranslator.java
@@ -184,6 +184,9 @@ public class SamzaPipelineTranslator {
           .put(PTransformTranslation.FLATTEN_TRANSFORM_URN, new FlattenPCollectionsTranslator())
           .put(SamzaPublishView.SAMZA_PUBLISH_VIEW_URN, new SamzaPublishViewTranslator())
           .put(PTransformTranslation.IMPULSE_TRANSFORM_URN, new ImpulseTranslator())
+          .put(
+              PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN,
+              new SplittableParDoTranslators.ProcessKeyedElements<>())
           .put(ExecutableStage.URN, new ParDoBoundMultiTranslator())
           .build();
     }
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaTransformOverrides.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaTransformOverrides.java
index b91f7b3..c9fe18c 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaTransformOverrides.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaTransformOverrides.java
@@ -34,6 +34,11 @@ public class SamzaTransformOverrides {
             PTransformOverride.of(
                 PTransformMatchers.urnEqualTo(PTransformTranslation.CREATE_VIEW_TRANSFORM_URN),
                 new SamzaPublishViewTransformOverride()))
+
+        // Note that we have a direct replacement for SplittableParDo.ProcessKeyedElements
+        // for unbounded splittable DoFns and do not need to rely on
+        // SplittableParDoViaKeyedWorkItems override. Once this direct replacement supports side
+        // inputs we can remove the SplittableParDoNaiveBounded override.
         .add(
             PTransformOverride.of(
                 PTransformMatchers.splittableParDo(), new SplittableParDo.OverrideFactory()))
@@ -41,6 +46,7 @@ public class SamzaTransformOverrides {
             PTransformOverride.of(
                 PTransformMatchers.splittableProcessKeyedBounded(),
                 new SplittableParDoNaiveBounded.OverrideFactory()))
+
         // TODO: [BEAM-5362] Support @RequiresStableInput on Samza runner
         .add(
             PTransformOverride.of(
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SplittableParDoTranslators.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SplittableParDoTranslators.java
new file mode 100644
index 0000000..87f2e8f
--- /dev/null
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SplittableParDoTranslators.java
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.samza.translation;
+
+import static org.apache.beam.runners.samza.util.SamzaPipelineTranslatorUtils.escape;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.beam.runners.core.construction.SplittableParDo;
+import org.apache.beam.runners.samza.runtime.DoFnOp;
+import org.apache.beam.runners.samza.runtime.KvToKeyedWorkItemOp;
+import org.apache.beam.runners.samza.runtime.OpAdapter;
+import org.apache.beam.runners.samza.runtime.OpMessage;
+import org.apache.beam.runners.samza.runtime.SplittableParDoProcessKeyedElementsOp;
+import org.apache.beam.runners.samza.translation.ParDoBoundMultiTranslator.RawUnionValueToValue;
+import org.apache.beam.runners.samza.util.SamzaCoders;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
+import org.apache.beam.sdk.transforms.join.RawUnionValue;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.samza.operators.MessageStream;
+import org.apache.samza.serializers.KVSerde;
+
+/** A set of translators for {@link SplittableParDo}. */
+public class SplittableParDoTranslators {
+
+  /**
+   * Translates {@link SplittableParDo.ProcessKeyedElements} to Samza {@link
+   * SplittableParDoProcessKeyedElementsOp}.
+   */
+  static class ProcessKeyedElements<InputT, OutputT, RestrictionT, WatermarkEstimatorStateT>
+      implements TransformTranslator<
+          SplittableParDo.ProcessKeyedElements<
+              InputT, OutputT, RestrictionT, WatermarkEstimatorStateT>> {
+
+    @Override
+    public void translate(
+        SplittableParDo.ProcessKeyedElements<
+                InputT, OutputT, RestrictionT, WatermarkEstimatorStateT>
+            transform,
+        Node node,
+        TranslationContext ctx) {
+      final PCollection<KV<byte[], KV<InputT, RestrictionT>>> input = ctx.getInput(transform);
+
+      final ArrayList<Map.Entry<TupleTag<?>, PValue>> outputs =
+          new ArrayList<>(node.getOutputs().entrySet());
+
+      final Map<TupleTag<?>, Integer> tagToIndexMap = new HashMap<>();
+      final Map<Integer, PCollection<?>> indexToPCollectionMap = new HashMap<>();
+
+      for (int index = 0; index < outputs.size(); ++index) {
+        final Map.Entry<TupleTag<?>, PValue> taggedOutput = outputs.get(index);
+        tagToIndexMap.put(taggedOutput.getKey(), index);
+
+        if (!(taggedOutput.getValue() instanceof PCollection)) {
+          throw new IllegalArgumentException(
+              "Expected side output to be PCollection, but was: " + taggedOutput.getValue());
+        }
+        final PCollection<?> sideOutputCollection = (PCollection<?>) taggedOutput.getValue();
+        indexToPCollectionMap.put(index, sideOutputCollection);
+      }
+
+      @SuppressWarnings("unchecked")
+      final WindowingStrategy<?, BoundedWindow> windowingStrategy =
+          (WindowingStrategy<?, BoundedWindow>) input.getWindowingStrategy();
+
+      final MessageStream<OpMessage<KV<byte[], KV<InputT, RestrictionT>>>> inputStream =
+          ctx.getMessageStream(input);
+
+      final KvCoder<byte[], KV<InputT, RestrictionT>> kvInputCoder =
+          (KvCoder<byte[], KV<InputT, RestrictionT>>) input.getCoder();
+      final Coder<WindowedValue<KV<byte[], KV<InputT, RestrictionT>>>> elementCoder =
+          SamzaCoders.of(input);
+
+      final MessageStream<OpMessage<KV<byte[], KV<InputT, RestrictionT>>>> filteredInputStream =
+          inputStream.filter(msg -> msg.getType() == OpMessage.Type.ELEMENT);
+
+      final MessageStream<OpMessage<KV<byte[], KV<InputT, RestrictionT>>>> partitionedInputStream;
+      if (!needRepartition(ctx)) {
+        partitionedInputStream = filteredInputStream;
+      } else {
+        partitionedInputStream =
+            filteredInputStream
+                .partitionBy(
+                    msg -> msg.getElement().getValue().getKey(),
+                    msg -> msg.getElement(),
+                    KVSerde.of(
+                        SamzaCoders.toSerde(kvInputCoder.getKeyCoder()),
+                        SamzaCoders.toSerde(elementCoder)),
+                    "sdf-" + escape(ctx.getTransformId()))
+                .map(kv -> OpMessage.ofElement(kv.getValue()));
+      }
+
+      final MessageStream<OpMessage<RawUnionValue>> taggedOutputStream =
+          partitionedInputStream
+              .flatMap(OpAdapter.adapt(new KvToKeyedWorkItemOp<>()))
+              .flatMap(
+                  OpAdapter.adapt(
+                      new SplittableParDoProcessKeyedElementsOp<>(
+                          transform.getMainOutputTag(),
+                          transform,
+                          windowingStrategy,
+                          new DoFnOp.MultiOutputManagerFactory(tagToIndexMap),
+                          ctx.getTransformFullName(),
+                          ctx.getTransformId(),
+                          input.isBounded())));
+
+      for (int outputIndex : tagToIndexMap.values()) {
+        @SuppressWarnings("unchecked")
+        final MessageStream<OpMessage<OutputT>> outputStream =
+            taggedOutputStream
+                .filter(
+                    message ->
+                        message.getType() != OpMessage.Type.ELEMENT
+                            || message.getElement().getValue().getUnionTag() == outputIndex)
+                .flatMap(OpAdapter.adapt(new RawUnionValueToValue()));
+
+        ctx.registerMessageStream(indexToPCollectionMap.get(outputIndex), outputStream);
+      }
+    }
+
+    private static boolean needRepartition(TranslationContext ctx) {
+      if (ctx.getPipelineOptions().getMaxSourceParallelism() == 1) {
+        // Only one task will be created, no need for repartition
+        return false;
+      }
+      return true;
+    }
+  }
+}
diff --git a/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/SamzaStoreStateInternalsTest.java b/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/SamzaStoreStateInternalsTest.java
index 97ae705..5f88e4d 100644
--- a/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/SamzaStoreStateInternalsTest.java
+++ b/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/SamzaStoreStateInternalsTest.java
@@ -32,7 +32,6 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import org.apache.beam.runners.samza.SamzaPipelineOptions;
-import org.apache.beam.runners.samza.TestSamzaRunner;
 import org.apache.beam.runners.samza.state.SamzaMapState;
 import org.apache.beam.runners.samza.state.SamzaSetState;
 import org.apache.beam.runners.samza.translation.ConfigBuilder;
@@ -67,11 +66,15 @@ import org.apache.samza.storage.kv.KeyValueStoreMetrics;
 import org.apache.samza.storage.kv.inmemory.InMemoryKeyValueStorageEngineFactory;
 import org.apache.samza.storage.kv.inmemory.InMemoryKeyValueStore;
 import org.apache.samza.system.SystemStreamPartition;
+import org.junit.Rule;
 import org.junit.Test;
 
 /** Tests for SamzaStoreStateInternals. */
 public class SamzaStoreStateInternalsTest implements Serializable {
-  public final transient TestPipeline pipeline = TestPipeline.create();
+  @Rule
+  public final transient TestPipeline pipeline =
+      TestPipeline.fromOptions(
+          PipelineOptionsFactory.fromArgs("--runner=TestSamzaRunner").create());
 
   @Test
   public void testMapStateIterator() {
@@ -125,11 +128,7 @@ public class SamzaStoreStateInternalsTest implements Serializable {
 
     PAssert.that(output).containsInAnyOrder(KV.of("a", 97), KV.of("b", 42), KV.of("c", 12));
 
-    TestSamzaRunner.fromOptions(
-            PipelineOptionsFactory.fromArgs(
-                    "--runner=org.apache.beam.runners.samza.TestSamzaRunner")
-                .create())
-        .run(pipeline);
+    pipeline.run();
   }
 
   @Test
@@ -180,11 +179,7 @@ public class SamzaStoreStateInternalsTest implements Serializable {
 
     PAssert.that(output).containsInAnyOrder(Sets.newHashSet(97, 42, 12));
 
-    TestSamzaRunner.fromOptions(
-            PipelineOptionsFactory.fromArgs(
-                    "--runner=org.apache.beam.runners.samza.TestSamzaRunner")
-                .create())
-        .run(pipeline);
+    pipeline.run();
   }
 
   /** A storage engine to create test stores. */
@@ -286,13 +281,10 @@ public class SamzaStoreStateInternalsTest implements Serializable {
                 KV.of("hello", 97), KV.of("hello", 42), KV.of("hello", 42), KV.of("hello", 12)))
         .apply(ParDo.of(fn));
 
-    SamzaPipelineOptions options = PipelineOptionsFactory.create().as(SamzaPipelineOptions.class);
-    options.setRunner(TestSamzaRunner.class);
     Map<String, String> configs = new HashMap(ConfigBuilder.localRunConfig());
     configs.put("stores.foo.factory", TestStorageEngine.class.getName());
-    options.setConfigOverride(configs);
-
-    TestSamzaRunner.fromOptions(options).run(pipeline).waitUntilFinish();
+    pipeline.getOptions().as(SamzaPipelineOptions.class).setConfigOverride(configs);
+    pipeline.run();
 
     // The test code creates 7 underlying iterators, and 1 more is created during state.clear()
     // Verify all of them are closed
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java
index aa5a959..7f0c129 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java
@@ -930,7 +930,13 @@ public class Read {
    */
   private static final Set<String> SPLITTABLE_DOFN_PREFERRED_RUNNERS =
       ImmutableSet.of(
-          "DirectRunner", "FlinkRunner", "TestFlinkRunner", "Twister2Runner", "Twister2TestRunner");
+          "DirectRunner",
+          "FlinkRunner",
+          "TestFlinkRunner",
+          "SamzaRunner",
+          "TestSamzaRunner",
+          "Twister2Runner",
+          "Twister2TestRunner");
 
   private static boolean useSdf(PipelineOptions options) {
     // TODO(BEAM-10670): Make this by default true and have runners opt-out instead.