You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ec...@apache.org on 2022/09/22 10:30:15 UTC

[beam] branch master updated: Improved pipeline translation in SparkStructuredStreamingRunner (#22446)

This is an automated email from the ASF dual-hosted git repository.

echauchot pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 762edd7f3a6 Improved pipeline translation in SparkStructuredStreamingRunner (#22446)
762edd7f3a6 is described below

commit 762edd7f3a64f076dbee156fa48b8a7e5e6a512f
Author: Moritz Mack <mm...@talend.com>
AuthorDate: Thu Sep 22 12:30:03 2022 +0200

    Improved pipeline translation in SparkStructuredStreamingRunner (#22446)
    
    * Closes #22445: Improved pipeline translation in SparkStructuredStreamingRunner (also closes #22382)
---
 .../translation/helpers/EncoderFactory.java        |  12 +-
 .../translation/utils/ScalaInterop.java}           |  27 +-
 .../spark/structuredstreaming/Constants.java       |  25 -
 .../SparkStructuredStreamingRunner.java            |  51 +-
 .../io/BoundedDatasetFactory.java                  | 324 +++++++++++
 .../streaming => io}/package-info.java             |   4 +-
 .../metrics/WithMetricsSupport.java                |   2 +-
 .../translation/AbstractTranslationContext.java    | 235 --------
 .../translation/PipelineTranslator.java            |  57 +-
 .../translation/TransformTranslator.java           | 198 ++++++-
 .../translation/TranslationContext.java            | 124 ++++-
 .../translation/batch/AggregatorCombiner.java      | 270 ----------
 .../translation/batch/Aggregators.java             | 591 +++++++++++++++++++++
 .../batch/CombineGloballyTranslatorBatch.java      | 121 +++++
 .../batch/CombinePerKeyTranslatorBatch.java        | 181 ++++---
 .../CreatePCollectionViewTranslatorBatch.java      |  24 +-
 .../translation/batch/DatasetSourceBatch.java      | 240 ---------
 .../translation/batch/DoFnFunction.java            | 164 ------
 .../batch/DoFnMapPartitionsFactory.java            | 224 ++++++++
 .../translation/batch/FlattenTranslatorBatch.java  |  60 +--
 .../translation/batch/GroupByKeyHelpers.java       | 106 ++++
 .../batch/GroupByKeyTranslatorBatch.java           | 298 +++++++++--
 .../translation/batch/ImpulseTranslatorBatch.java  |  26 +-
 .../translation/batch/ParDoTranslatorBatch.java    | 315 ++++++-----
 .../translation/batch/PipelineTranslatorBatch.java |  28 +-
 .../translation/batch/ProcessContext.java          | 138 -----
 .../batch/ReadSourceTranslatorBatch.java           |  76 +--
 .../batch/ReshuffleTranslatorBatch.java            |  30 --
 .../batch/WindowAssignTranslatorBatch.java         |  90 +++-
 .../translation/helpers/CoderHelpers.java          |  10 +-
 .../translation/helpers/EncoderFactory.java        |  71 ++-
 .../translation/helpers/EncoderHelpers.java        | 546 ++++++++++++++++++-
 .../translation/helpers/MultiOutputCoder.java      |  84 ---
 .../translation/helpers/RowHelpers.java            |  75 ---
 .../translation/helpers/SchemaHelpers.java         |  39 --
 .../translation/helpers/WindowingHelpers.java      |  82 ---
 .../streaming/DatasetSourceStreaming.java          |  25 -
 .../streaming/PipelineTranslatorStreaming.java     |  93 ----
 .../streaming/ReadSourceTranslatorStreaming.java   |  87 ---
 .../translation/utils/ScalaInterop.java            | 114 ++++
 .../aggregators/metrics/sink/InMemoryMetrics.java  |   2 +-
 .../translation/batch/AggregatorsTest.java         | 370 +++++++++++++
 .../{CombineTest.java => CombineGloballyTest.java} | 129 ++---
 .../{CombineTest.java => CombinePerKeyTest.java}   |  92 ++--
 .../translation/batch/ComplexSourceTest.java       |  15 +-
 .../translation/batch/FlattenTest.java             |  12 +-
 .../translation/batch/GroupByKeyTest.java          | 152 ++++--
 .../translation/batch/ParDoTest.java               |  54 +-
 .../translation/batch/SimpleSourceTest.java        |  12 +-
 .../translation/batch/WindowAssignTest.java        |  12 +-
 .../translation/helpers/EncoderHelpersTest.java    | 210 +++++++-
 .../runners/spark/SparkCommonPipelineOptions.java  |   6 +
 .../beam/runners/spark/SparkPipelineOptions.java   |   6 -
 53 files changed, 3978 insertions(+), 2361 deletions(-)

diff --git a/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java b/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java
index 54b400f08d0..2b86ec839c9 100644
--- a/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java
+++ b/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java
@@ -17,26 +17,24 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
 
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.listOf;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf;
+
 import org.apache.spark.sql.Encoder;
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
 import org.apache.spark.sql.catalyst.expressions.Expression;
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke;
 import org.apache.spark.sql.types.DataType;
-import scala.collection.immutable.List;
-import scala.collection.immutable.Nil$;
-import scala.collection.mutable.WrappedArray;
 import scala.reflect.ClassTag$;
 
 public class EncoderFactory {
 
   static <T> Encoder<T> create(
       Expression serializer, Expression deserializer, Class<? super T> clazz) {
-    // TODO Isolate usage of Scala APIs in utility https://github.com/apache/beam/issues/22382
-    List<Expression> serializers = Nil$.MODULE$.$colon$colon(serializer);
     return new ExpressionEncoder<>(
         SchemaHelpers.binarySchema(),
         false,
-        serializers,
+        listOf(serializer),
         deserializer,
         ClassTag$.MODULE$.apply(clazz));
   }
@@ -46,6 +44,6 @@ public class EncoderFactory {
    * input arg is {@code null}.
    */
   static Expression invokeIfNotNull(Class<?> cls, String fun, DataType type, Expression... args) {
-    return new StaticInvoke(cls, type, fun, new WrappedArray.ofRef<>(args), true, true);
+    return new StaticInvoke(cls, type, fun, seqOf(args), true, true);
   }
 }
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/KVHelpers.java b/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java
similarity index 61%
rename from runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/KVHelpers.java
rename to runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java
index 2406c0f49ab..c5bc71af602 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/KVHelpers.java
+++ b/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java
@@ -15,17 +15,35 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
+package org.apache.beam.runners.spark.structuredstreaming.translation.utils;
 
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.KV;
-import org.apache.spark.api.java.function.MapFunction;
+import scala.collection.Seq;
+import scala.collection.immutable.List;
+import scala.collection.immutable.Nil$;
+import scala.collection.mutable.WrappedArray;
 
-/** Helper functions for working with {@link org.apache.beam.sdk.values.KV}. */
-public final class KVHelpers {
+/** Utilities for easier interoperability with the Spark Scala API. */
+public final class ScalaInterop {
+  private ScalaInterop() {}
 
-  /** A Spark {@link MapFunction} for extracting the key out of a {@link KV} for GBK for example. */
-  public static <K, V> MapFunction<WindowedValue<KV<K, V>>, K> extractKey() {
-    return wv -> wv.getValue().getKey();
+  /**
+   * Wraps the given elements into a Scala {@link Seq}.
+   *
+   * <p>The returned sequence wraps the array {@code t} directly; it is not copied.
+   */
+  @SafeVarargs
+  public static <T> Seq<T> seqOf(T... t) {
+    return new WrappedArray.ofRef<>(t);
+  }
+
+  /** Returns an immutable Scala {@link List} containing only {@code t}. */
+  public static <T> Seq<T> listOf(T t) {
+    return emptyList().$colon$colon(t);
+  }
+
+  /** Returns the shared empty immutable Scala {@link List}. */
+  @SuppressWarnings("unchecked") // Nil is a List[Nothing], safe for any element type
+  public static <T> List<T> emptyList() {
+    return (List<T>) Nil$.MODULE$;
   }
 }
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/Constants.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/Constants.java
deleted file mode 100644
index 08c187ce6c6..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/Constants.java
+++ /dev/null
@@ -1,25 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming;
-
-public class Constants {
-
-  public static final String BEAM_SOURCE_OPTION = "beam-source";
-  public static final String DEFAULT_PARALLELISM = "default-parallelism";
-  public static final String PIPELINE_OPTIONS = "pipeline-options";
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java
index b1de9e941e4..68f54ac93bf 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java
@@ -29,10 +29,9 @@ import org.apache.beam.runners.spark.structuredstreaming.metrics.AggregatorMetri
 import org.apache.beam.runners.spark.structuredstreaming.metrics.CompositeSource;
 import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator;
 import org.apache.beam.runners.spark.structuredstreaming.metrics.SparkBeamMetricSource;
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
 import org.apache.beam.runners.spark.structuredstreaming.translation.PipelineTranslator;
+import org.apache.beam.runners.spark.structuredstreaming.translation.TranslationContext;
 import org.apache.beam.runners.spark.structuredstreaming.translation.batch.PipelineTranslatorBatch;
-import org.apache.beam.runners.spark.structuredstreaming.translation.streaming.PipelineTranslatorStreaming;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.PipelineRunner;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
@@ -41,6 +40,7 @@ import org.apache.beam.sdk.options.ExperimentalOptions;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.options.PipelineOptionsValidator;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
 import org.apache.spark.SparkEnv$;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.metrics.MetricsSystem;
@@ -48,24 +48,34 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * SparkStructuredStreamingRunner is based on spark structured streaming framework and is no more
- * based on RDD/DStream API. See
- * https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html It is still
- * experimental, its coverage of the Beam model is partial. The SparkStructuredStreamingRunner
- * translate operations defined on a pipeline to a representation executable by Spark, and then
- * submitting the job to Spark to be executed. If we wanted to run a Beam pipeline with the default
- * options of a single threaded spark instance in local mode, we would do the following:
+ * A Spark runner built on top of Spark's SQL Engine (<a
+ * href="https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html">Structured
+ * Streaming framework</a>).
  *
- * <p>{@code Pipeline p = [logic for pipeline creation] SparkStructuredStreamingPipelineResult
- * result = (SparkStructuredStreamingPipelineResult) p.run(); }
+ * <p><b>This runner is experimental; its coverage of the Beam model is still partial. Due to
+ * limitations of the Structured Streaming framework (e.g. lack of support for multiple stateful
+ * operators), streaming mode is not yet supported by this runner.</b>
+ *
+ * <p>The runner translates transforms defined on a Beam pipeline to Spark {@code Dataset}
+ * transformations (leveraging the high-level Dataset API) and submits them to Spark for execution.
+ *
+ * <p>To run a Beam pipeline with the default options using Spark's local mode, we would do the
+ * following:
+ *
+ * <pre>{@code
+ * Pipeline p = [logic for pipeline creation]
+ * PipelineResult result = p.run();
+ * }</pre>
  *
  * <p>To create a pipeline runner to run against a different spark cluster, with a custom master url
  * we would do the following:
  *
- * <p>{@code Pipeline p = [logic for pipeline creation] SparkStructuredStreamingPipelineOptions
- * options = SparkPipelineOptionsFactory.create(); options.setSparkMaster("spark://host:port");
- * SparkStructuredStreamingPipelineResult result = (SparkStructuredStreamingPipelineResult) p.run();
- * }
+ * <pre>{@code
+ * Pipeline p = [logic for pipeline creation]
+ * SparkCommonPipelineOptions options = p.getOptions().as(SparkCommonPipelineOptions.class);
+ * options.setSparkMaster("spark://host:port");
+ * PipelineResult result = p.run();
+ * }</pre>
  */
 @SuppressWarnings({
   "nullness" // TODO(https://github.com/apache/beam/issues/20497)
@@ -135,7 +145,7 @@ public final class SparkStructuredStreamingRunner
     AggregatorsAccumulator.clear();
     MetricsAccumulator.clear();
 
-    final AbstractTranslationContext translationContext = translatePipeline(pipeline);
+    final TranslationContext translationContext = translatePipeline(pipeline);
 
     final ExecutorService executorService = Executors.newSingleThreadExecutor();
     final Future<?> submissionFuture =
@@ -169,8 +179,10 @@ public final class SparkStructuredStreamingRunner
     return result;
   }
 
-  private AbstractTranslationContext translatePipeline(Pipeline pipeline) {
+  private TranslationContext translatePipeline(Pipeline pipeline) {
     PipelineTranslator.detectTranslationMode(pipeline, options);
+    Preconditions.checkArgument(
+        !options.isStreaming(), "%s does not support streaming pipelines.", getClass().getName());
 
     // Default to using the primitive versions of Read.Bounded and Read.Unbounded for non-portable
     // execution.
@@ -182,10 +194,7 @@ public final class SparkStructuredStreamingRunner
 
     PipelineTranslator.replaceTransforms(pipeline, options);
     prepareFilesToStage(options);
-    PipelineTranslator pipelineTranslator =
-        options.isStreaming()
-            ? new PipelineTranslatorStreaming(options)
-            : new PipelineTranslatorBatch(options);
+    PipelineTranslator pipelineTranslator = new PipelineTranslatorBatch(options);
 
     final JavaSparkContext jsc =
         JavaSparkContext.fromSparkContext(
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java
new file mode 100644
index 00000000000..83dc98f3c10
--- /dev/null
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java
@@ -0,0 +1,339 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.io;
+
+import static java.util.stream.Collectors.toList;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList;
+import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+import static scala.collection.JavaConverters.asScalaIterator;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.io.Serializable;
+import java.io.UncheckedIOException;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.IntSupplier;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.sdk.io.BoundedSource;
+import org.apache.beam.sdk.io.BoundedSource.BoundedReader;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.AbstractIterator;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
+import org.apache.spark.InterruptibleIterator;
+import org.apache.spark.Partition;
+import org.apache.spark.SparkContext;
+import org.apache.spark.TaskContext;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer;
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
+import org.apache.spark.sql.connector.catalog.SupportsRead;
+import org.apache.spark.sql.connector.catalog.Table;
+import org.apache.spark.sql.connector.catalog.TableCapability;
+import org.apache.spark.sql.connector.read.Batch;
+import org.apache.spark.sql.connector.read.InputPartition;
+import org.apache.spark.sql.connector.read.PartitionReader;
+import org.apache.spark.sql.connector.read.PartitionReaderFactory;
+import org.apache.spark.sql.connector.read.Scan;
+import org.apache.spark.sql.connector.read.ScanBuilder;
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
+import org.apache.spark.util.TaskCompletionListener;
+import scala.Option;
+import scala.collection.Iterator;
+import scala.reflect.ClassTag;
+
+/** Factory for Spark {@link Dataset Datasets} reading from a Beam {@link BoundedSource}. */
+public final class BoundedDatasetFactory {
+  private BoundedDatasetFactory() {}
+
+  /**
+   * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link Table}.
+   *
+   * <p>Unfortunately tables are expected to return an {@link InternalRow}, requiring serialization.
+   * This makes this approach for the time being significantly less performant than creating a
+   * dataset from an RDD.
+   */
+  public static <T> Dataset<WindowedValue<T>> createDatasetFromRows(
+      SparkSession session,
+      BoundedSource<T> source,
+      SerializablePipelineOptions options,
+      Encoder<WindowedValue<T>> encoder) {
+    Params<T> params = new Params<>(encoder, options, session.sparkContext().defaultParallelism());
+    BeamTable<T> table = new BeamTable<>(source, params);
+    LogicalPlan logicalPlan = DataSourceV2Relation.create(table, Option.empty(), Option.empty());
+    return Dataset.ofRows(session, logicalPlan).as(encoder);
+  }
+
+  /**
+   * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link RDD}.
+   *
+   * <p>This is currently the most efficient approach as it avoids any serialization overhead.
+   */
+  public static <T> Dataset<WindowedValue<T>> createDatasetFromRDD(
+      SparkSession session,
+      BoundedSource<T> source,
+      SerializablePipelineOptions options,
+      Encoder<WindowedValue<T>> encoder) {
+    Params<T> params = new Params<>(encoder, options, session.sparkContext().defaultParallelism());
+    RDD<WindowedValue<T>> rdd = new BoundedRDD<>(session.sparkContext(), source, params);
+    return session.createDataset(rdd, encoder);
+  }
+
+  /** An {@link RDD} for a bounded Beam source. */
+  private static class BoundedRDD<T> extends RDD<WindowedValue<T>> {
+    final BoundedSource<T> source;
+    final Params<T> params;
+
+    public BoundedRDD(SparkContext sc, BoundedSource<T> source, Params<T> params) {
+      super(sc, emptyList(), ClassTag.apply(WindowedValue.class));
+      this.source = source;
+      this.params = params;
+    }
+
+    @Override
+    public Iterator<WindowedValue<T>> compute(Partition split, TaskContext context) {
+      SourcePartitionIterator<T> iterator =
+          new SourcePartitionIterator<>((SourcePartition<T>) split, params);
+      // Close the reader once the task completes, even if the partition was not fully consumed
+      // (e.g. due to a limit) or the task failed or was killed.
+      context.addTaskCompletionListener((TaskCompletionListener) ctx -> iterator.closeUnchecked());
+      return new InterruptibleIterator<>(context, asScalaIterator(iterator));
+    }
+
+    @Override
+    public Partition[] getPartitions() {
+      return SourcePartition.partitionsOf(source, params).toArray(new Partition[0]);
+    }
+  }
+
+  /** A Spark {@link Table} for a bounded Beam source supporting batch reads only. */
+  private static class BeamTable<T> implements Table, SupportsRead {
+    final BoundedSource<T> source;
+    final Params<T> params;
+
+    BeamTable(BoundedSource<T> source, Params<T> params) {
+      this.source = source;
+      this.params = params;
+    }
+
+    public Encoder<WindowedValue<T>> getEncoder() {
+      return params.encoder;
+    }
+
+    @Override
+    public ScanBuilder newScanBuilder(CaseInsensitiveStringMap ignored) {
+      return () ->
+          new Scan() {
+            @Override
+            public StructType readSchema() {
+              return params.encoder.schema();
+            }
+
+            @Override
+            public Batch toBatch() {
+              return new BeamBatch<>(source, params);
+            }
+          };
+    }
+
+    @Override
+    public String name() {
+      return "BeamSource<" + source.getClass().getName() + ">";
+    }
+
+    @Override
+    public StructType schema() {
+      return params.encoder.schema();
+    }
+
+    @Override
+    public Set<TableCapability> capabilities() {
+      return ImmutableSet.of(TableCapability.BATCH_READ);
+    }
+
+    private static class BeamBatch<T> implements Batch, Serializable {
+      final BoundedSource<T> source;
+      final Params<T> params;
+
+      private BeamBatch(BoundedSource<T> source, Params<T> params) {
+        this.source = source;
+        this.params = params;
+      }
+
+      @Override
+      public InputPartition[] planInputPartitions() {
+        return SourcePartition.partitionsOf(source, params).toArray(new InputPartition[0]);
+      }
+
+      @Override
+      public PartitionReaderFactory createReaderFactory() {
+        return p -> new BeamPartitionReader<>(((SourcePartition<T>) p), params);
+      }
+    }
+
+    private static class BeamPartitionReader<T> implements PartitionReader<InternalRow> {
+      final SourcePartitionIterator<T> iterator;
+      final Serializer<WindowedValue<T>> serializer;
+      transient @Nullable InternalRow next;
+
+      BeamPartitionReader(SourcePartition<T> partition, Params<T> params) {
+        iterator = new SourcePartitionIterator<>(partition, params);
+        serializer = ((ExpressionEncoder<WindowedValue<T>>) params.encoder).createSerializer();
+      }
+
+      @Override
+      public boolean next() throws IOException {
+        if (iterator.hasNext()) {
+          next = serializer.apply(iterator.next());
+          return true;
+        }
+        return false;
+      }
+
+      @Override
+      public InternalRow get() {
+        if (next == null) {
+          throw new IllegalStateException("Next not available");
+        }
+        return next;
+      }
+
+      @Override
+      public void close() throws IOException {
+        next = null;
+        iterator.close();
+      }
+    }
+  }
+
+  /** A Spark partition wrapping the partitioned Beam {@link BoundedSource}. */
+  private static class SourcePartition<T> implements Partition, InputPartition {
+    final BoundedSource<T> source;
+    final int index;
+
+    SourcePartition(BoundedSource<T> source, IntSupplier idxSupplier) {
+      this.source = source;
+      this.index = idxSupplier.getAsInt();
+    }
+
+    static <T> List<SourcePartition<T>> partitionsOf(BoundedSource<T> source, Params<T> params) {
+      try {
+        PipelineOptions options = params.options.get();
+        long desiredSize = source.getEstimatedSizeBytes(options) / params.numPartitions;
+        List<BoundedSource<T>> split = (List<BoundedSource<T>>) source.split(desiredSize, options);
+        IntSupplier idxSupplier = new AtomicInteger(0)::getAndIncrement;
+        return split.stream().map(s -> new SourcePartition<>(s, idxSupplier)).collect(toList());
+      } catch (Exception e) {
+        throw new RuntimeException(
+            "Error splitting BoundedSource " + source.getClass().getCanonicalName(), e);
+      }
+    }
+
+    @Override
+    public int index() {
+      return index;
+    }
+
+    @Override
+    public int hashCode() {
+      return index;
+    }
+  }
+
+  /** A partition iterator on a partitioned Beam {@link BoundedSource}. */
+  private static class SourcePartitionIterator<T> extends AbstractIterator<WindowedValue<T>>
+      implements Closeable {
+    BoundedReader<T> reader;
+    boolean started = false;
+
+    public SourcePartitionIterator(SourcePartition<T> partition, Params<T> params) {
+      try {
+        reader = partition.source.createReader(params.options.get());
+      } catch (IOException e) {
+        throw new RuntimeException("Failed to create reader from a BoundedSource.", e);
+      }
+    }
+
+    @Override
+    @SuppressWarnings("nullness") // ok, reader not used any longer
+    public void close() throws IOException {
+      if (reader != null) {
+        endOfData();
+        try {
+          reader.close();
+        } finally {
+          reader = null;
+        }
+      }
+    }
+
+    /** Like {@link #close()}, but wraps an {@link IOException} into an unchecked exception. */
+    void closeUnchecked() {
+      try {
+        close();
+      } catch (IOException e) {
+        throw new UncheckedIOException("Failed to close reader of a BoundedSource.", e);
+      }
+    }
+
+    @Override
+    protected WindowedValue<T> computeNext() {
+      try {
+        if (started ? reader.advance() : start()) {
+          return timestampedValueInGlobalWindow(reader.getCurrent(), reader.getCurrentTimestamp());
+        } else {
+          close();
+          return endOfData();
+        }
+      } catch (IOException e) {
+        throw new RuntimeException("Failed to start or advance reader.", e);
+      }
+    }
+
+    private boolean start() throws IOException {
+      started = true;
+      return reader.start();
+    }
+  }
+
+  /** Shared parameters. */
+  private static class Params<T> implements Serializable {
+    final Encoder<WindowedValue<T>> encoder;
+    final SerializablePipelineOptions options;
+    final int numPartitions;
+
+    Params(
+        Encoder<WindowedValue<T>> encoder, SerializablePipelineOptions options, int numPartitions) {
+      checkArgument(numPartitions > 0, "Number of partitions must be greater than zero.");
+      this.encoder = encoder;
+      this.options = options;
+      this.numPartitions = numPartitions;
+    }
+  }
+}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/package-info.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/package-info.java
similarity index 83%
rename from runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/package-info.java
rename to runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/package-info.java
index 67f3613e056..23de70c705b 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/package-info.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/package-info.java
@@ -16,5 +16,5 @@
  * limitations under the License.
  */
 
-/** Internal utilities to translate Beam pipelines to Spark streaming. */
-package org.apache.beam.runners.spark.structuredstreaming.translation.streaming;
+/** Spark-specific I/O utilities, such as creating Spark Datasets from Beam sources. */
+package org.apache.beam.runners.spark.structuredstreaming.io;
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/WithMetricsSupport.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/WithMetricsSupport.java
index d48a229996f..c9233a128c1 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/WithMetricsSupport.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/WithMetricsSupport.java
@@ -36,7 +36,6 @@ import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Ordering
  * <p>{@link MetricRegistry} is not an interface, so this is not a by-the-book decorator. That said,
  * it delegates all metric related getters to the "decorated" instance.
  */
-@SuppressWarnings({"rawtypes"}) // required by interface
 public class WithMetricsSupport extends MetricRegistry {
 
   private final MetricRegistry internalMetricRegistry;
@@ -70,6 +69,7 @@ public class WithMetricsSupport extends MetricRegistry {
   }
 
   @Override
+  @SuppressWarnings({"rawtypes"}) // raw Gauge type dictated by MetricRegistry#getGauges
   public SortedMap<String, Gauge> getGauges(final MetricFilter filter) {
     ImmutableSortedMap.Builder<String, Gauge> builder =
         new ImmutableSortedMap.Builder<>(Ordering.from(String.CASE_INSENSITIVE_ORDER));
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/AbstractTranslationContext.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/AbstractTranslationContext.java
deleted file mode 100644
index aed287ba6d5..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/AbstractTranslationContext.java
+++ /dev/null
@@ -1,235 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation;
-
-import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.stream.Collectors;
-import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
-import org.apache.beam.runners.core.construction.TransformInputs;
-import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.VoidCoder;
-import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PCollectionView;
-import org.apache.beam.sdk.values.PValue;
-import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-import org.apache.spark.api.java.function.ForeachFunction;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.ForeachWriter;
-import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.streaming.DataStreamWriter;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/**
- * Base class that gives a context for {@link PTransform} translation: keeping track of the
- * datasets, the {@link SparkSession}, the current transform being translated.
- */
-@SuppressWarnings({
-  "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public abstract class AbstractTranslationContext {
-
-  private static final Logger LOG = LoggerFactory.getLogger(AbstractTranslationContext.class);
-
-  /** All the datasets of the DAG. */
-  private final Map<PValue, Dataset<?>> datasets;
-  /** datasets that are not used as input to other datasets (leaves of the DAG). */
-  private final Set<Dataset<?>> leaves;
-
-  private final SerializablePipelineOptions serializablePipelineOptions;
-
-  @SuppressFBWarnings("URF_UNREAD_FIELD") // make spotbugs happy
-  private AppliedPTransform<?, ?, ?> currentTransform;
-
-  @SuppressFBWarnings("URF_UNREAD_FIELD") // make spotbugs happy
-  private final SparkSession sparkSession;
-
-  private final Map<PCollectionView<?>, Dataset<?>> broadcastDataSets;
-
-  public AbstractTranslationContext(SparkStructuredStreamingPipelineOptions options) {
-    this.sparkSession = SparkSessionFactory.getOrCreateSession(options);
-    this.serializablePipelineOptions = new SerializablePipelineOptions(options);
-    this.datasets = new HashMap<>();
-    this.leaves = new HashSet<>();
-    this.broadcastDataSets = new HashMap<>();
-  }
-
-  public SparkSession getSparkSession() {
-    return sparkSession;
-  }
-
-  public SerializablePipelineOptions getSerializableOptions() {
-    return serializablePipelineOptions;
-  }
-
-  // --------------------------------------------------------------------------------------------
-  //  Transforms methods
-  // --------------------------------------------------------------------------------------------
-  public void setCurrentTransform(AppliedPTransform<?, ?, ?> currentTransform) {
-    this.currentTransform = currentTransform;
-  }
-
-  public AppliedPTransform<?, ?, ?> getCurrentTransform() {
-    return currentTransform;
-  }
-
-  // --------------------------------------------------------------------------------------------
-  //  Datasets methods
-  // --------------------------------------------------------------------------------------------
-  @SuppressWarnings("unchecked")
-  public <T> Dataset<T> emptyDataset() {
-    return (Dataset<T>) sparkSession.emptyDataset(EncoderHelpers.fromBeamCoder(VoidCoder.of()));
-  }
-
-  @SuppressWarnings("unchecked")
-  public <T> Dataset<WindowedValue<T>> getDataset(PValue value) {
-    Dataset<?> dataset = datasets.get(value);
-    // assume that the Dataset is used as an input if retrieved here. So it is not a leaf anymore
-    leaves.remove(dataset);
-    return (Dataset<WindowedValue<T>>) dataset;
-  }
-
-  /**
-   * TODO: All these 3 methods (putDataset*) are temporary and they are used only for generics type
-   * checking. We should unify them in the future.
-   */
-  public void putDatasetWildcard(PValue value, Dataset<WindowedValue<?>> dataset) {
-    if (!datasets.containsKey(value)) {
-      datasets.put(value, dataset);
-      leaves.add(dataset);
-    }
-  }
-
-  public <T> void putDataset(PValue value, Dataset<WindowedValue<T>> dataset) {
-    if (!datasets.containsKey(value)) {
-      datasets.put(value, dataset);
-      leaves.add(dataset);
-    }
-  }
-
-  public <ViewT, ElemT> void setSideInputDataset(
-      PCollectionView<ViewT> value, Dataset<WindowedValue<ElemT>> set) {
-    if (!broadcastDataSets.containsKey(value)) {
-      broadcastDataSets.put(value, set);
-    }
-  }
-
-  @SuppressWarnings("unchecked")
-  public <T> Dataset<T> getSideInputDataSet(PCollectionView<?> value) {
-    return (Dataset<T>) broadcastDataSets.get(value);
-  }
-
-  // --------------------------------------------------------------------------------------------
-  //  PCollections methods
-  // --------------------------------------------------------------------------------------------
-  public PValue getInput() {
-    return Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(currentTransform));
-  }
-
-  public Map<TupleTag<?>, PCollection<?>> getInputs() {
-    return currentTransform.getInputs();
-  }
-
-  public PValue getOutput() {
-    return Iterables.getOnlyElement(currentTransform.getOutputs().values());
-  }
-
-  public Map<TupleTag<?>, PCollection<?>> getOutputs() {
-    return currentTransform.getOutputs();
-  }
-
-  @SuppressWarnings("unchecked")
-  public Map<TupleTag<?>, Coder<?>> getOutputCoders() {
-    return currentTransform.getOutputs().entrySet().stream()
-        .filter(e -> e.getValue() instanceof PCollection)
-        .collect(Collectors.toMap(Map.Entry::getKey, e -> ((PCollection) e.getValue()).getCoder()));
-  }
-
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline methods
-  // --------------------------------------------------------------------------------------------
-
-  /** Starts the pipeline. */
-  public void startPipeline() {
-    SparkStructuredStreamingPipelineOptions options =
-        serializablePipelineOptions.get().as(SparkStructuredStreamingPipelineOptions.class);
-    int datasetIndex = 0;
-    for (Dataset<?> dataset : leaves) {
-      if (options.isStreaming()) {
-        // TODO: deal with Beam Discarding, Accumulating and Accumulating & Retracting	outputmodes
-        // with DatastreamWriter.outputMode
-        DataStreamWriter<?> dataStreamWriter = dataset.writeStream();
-        // spark sets a default checkpoint dir if not set.
-        if (options.getCheckpointDir() != null) {
-          dataStreamWriter =
-              dataStreamWriter.option("checkpointLocation", options.getCheckpointDir());
-        }
-        launchStreaming(dataStreamWriter.foreach(new NoOpForeachWriter<>()));
-      } else {
-        if (options.getTestMode()) {
-          LOG.debug("**** dataset {} catalyst execution plans ****", ++datasetIndex);
-          dataset.explain(true);
-        }
-        // apply a dummy fn just to apply foreach action that will trigger the pipeline run in
-        // spark
-        dataset.foreach((ForeachFunction) t -> {});
-      }
-    }
-  }
-
-  public abstract void launchStreaming(DataStreamWriter<?> dataStreamWriter);
-
-  public static void printDatasetContent(Dataset<WindowedValue> dataset) {
-    // cannot use dataset.show because dataset schema is binary so it will print binary
-    // code.
-    List<WindowedValue> windowedValues = dataset.collectAsList();
-    for (WindowedValue windowedValue : windowedValues) {
-      LOG.debug("**** dataset content {} ****", windowedValue.toString());
-    }
-  }
-
-  private static class NoOpForeachWriter<T> extends ForeachWriter<T> {
-
-    @Override
-    public boolean open(long partitionId, long epochId) {
-      return false;
-    }
-
-    @Override
-    public void process(T value) {
-      // do nothing
-    }
-
-    @Override
-    public void close(Throwable errorOrNull) {
-      // do nothing
-    }
-  }
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java
index 0f851d9588d..0fa48fc3d3e 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java
@@ -17,24 +17,26 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
+import java.io.IOException;
+import java.io.UncheckedIOException;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
-import org.apache.beam.runners.spark.structuredstreaming.translation.batch.PipelineTranslatorBatch;
-import org.apache.beam.runners.spark.structuredstreaming.translation.streaming.PipelineTranslatorStreaming;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.options.StreamingOptions;
+import org.apache.beam.sdk.runners.AppliedPTransform;
 import org.apache.beam.sdk.runners.TransformHierarchy;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PInput;
+import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
+import org.checkerframework.checker.nullness.qual.Nullable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
  * {@link Pipeline.PipelineVisitor} that translates the Beam operators to their Spark counterparts.
  * It also does the pipeline preparation: mode detection, transforms replacement, classpath
- * preparation. If we have a streaming job, it is instantiated as a {@link
- * PipelineTranslatorStreaming}. If we have a batch job, it is instantiated as a {@link
- * PipelineTranslatorBatch}.
+ * preparation.
  */
 @SuppressWarnings({
   "nullness" // TODO(https://github.com/apache/beam/issues/20497)
@@ -42,7 +44,7 @@ import org.slf4j.LoggerFactory;
 public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
   private int depth = 0;
   private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class);
-  protected AbstractTranslationContext translationContext;
+  protected TranslationContext translationContext;
 
   // --------------------------------------------------------------------------------------------
   //  Pipeline preparation methods
@@ -123,22 +125,25 @@ public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaul
   }
 
-  /** Get a {@link TransformTranslator} for the given {@link TransformHierarchy.Node}. */
-  protected abstract TransformTranslator<?> getTransformTranslator(TransformHierarchy.Node node);
+  /** Returns the {@link TransformTranslator} for {@code transform}, or {@code null} if none. */
+  protected abstract @Nullable <
+          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(
+          @Nullable TransformT transform);
 
   /** Apply the given TransformTranslator to the given node. */
-  private <T extends PTransform<?, ?>> void applyTransformTranslator(
-      TransformHierarchy.Node node, TransformTranslator<?> transformTranslator) {
+  private <InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+      void applyTransformTranslator(
+          TransformHierarchy.Node node,
+          TransformT transform,
+          TransformTranslator<InT, OutT, TransformT> transformTranslator) {
     // create the applied PTransform on the translationContext
-    translationContext.setCurrentTransform(node.toAppliedPTransform(getPipeline()));
-
-    // avoid type capture
-    @SuppressWarnings("unchecked")
-    T typedTransform = (T) node.getTransform();
-    @SuppressWarnings("unchecked")
-    TransformTranslator<T> typedTransformTranslator = (TransformTranslator<T>) transformTranslator;
-
-    // apply the transformTranslator
-    typedTransformTranslator.translateTransform(typedTransform, translationContext);
+    AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
+        (AppliedPTransform) node.toAppliedPTransform(getPipeline());
+    try {
+      transformTranslator.translate(transform, appliedTransform, translationContext);
+    } catch (IOException e) {
+      throw new UncheckedIOException("Failed to translate " + node.getFullName(), e);
+    }
   }
 
   // --------------------------------------------------------------------------------------------
@@ -164,10 +169,12 @@ public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaul
     LOG.debug("{} enterCompositeTransform- {}", genSpaces(depth), node.getFullName());
     depth++;
 
-    TransformTranslator<?> transformTranslator = getTransformTranslator(node);
+    PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+    TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> transformTranslator =
+        getTransformTranslator(transform);
 
     if (transformTranslator != null) {
-      applyTransformTranslator(node, transformTranslator);
+      applyTransformTranslator(node, transform, transformTranslator);
       LOG.debug("{} translated- {}", genSpaces(depth), node.getFullName());
       return CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
     } else {
@@ -187,16 +194,19 @@ public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaul
 
     // get the transformation corresponding to the node we are
     // currently visiting and translate it into its Spark alternative.
-    TransformTranslator<?> transformTranslator = getTransformTranslator(node);
+    PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+    TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> transformTranslator =
+        getTransformTranslator(transform);
+
     if (transformTranslator == null) {
       String transformUrn = PTransformTranslation.urnForTransform(node.getTransform());
       throw new UnsupportedOperationException(
           "The transform " + transformUrn + " is currently not supported.");
     }
-    applyTransformTranslator(node, transformTranslator);
+    applyTransformTranslator(node, transform, transformTranslator);
   }
 
-  public AbstractTranslationContext getTranslationContext() {
+  public TranslationContext getTranslationContext() {
     return translationContext;
   }
 }
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java
index 61580aed219..d991a0d9148 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java
@@ -17,15 +17,213 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
-import java.io.Serializable;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderFor;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables.getOnlyElement;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.core.construction.TransformInputs;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.runners.AppliedPTransform;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PInput;
+import org.apache.beam.sdk.values.POutput;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.broadcast.Broadcast;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.SparkSession;
+import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
+import scala.Tuple2;
+import scala.reflect.ClassTag;
+
+/**
+ * Supports translation between a Beam transform and Spark's operations on Datasets.
+ *
+ * <p>WARNING: Do not make this class serializable! It could easily hide situations where
+ * unnecessary references leak into Spark closures.
+ */
+public abstract class TransformTranslator<
+    InT extends PInput, OutT extends POutput, TransformT extends PTransform<? extends InT, OutT>> {
+
+  /** Translates {@code transform} using the per-transform {@link Context}. */
+  protected abstract void translate(TransformT transform, Context cxt) throws IOException;
+
+  public final void translate(
+      TransformT transform,
+      AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform,
+      TranslationContext cxt)
+      throws IOException {
+    translate(transform, new Context(appliedTransform, cxt));
+  }
+
+  /**
+   * Per-transform view of the {@link TranslationContext}: the inputs and outputs of the transform
+   * being translated, plus access to datasets, encoders and the Spark session.
+   */
+  protected class Context {
+    private final AppliedPTransform<InT, OutT, PTransform<InT, OutT>> transform;
+    private final TranslationContext cxt;
+    private @MonotonicNonNull InT pIn = null;
+    private @MonotonicNonNull OutT pOut = null;
+
+    protected Context(
+        AppliedPTransform<InT, OutT, PTransform<InT, OutT>> transform, TranslationContext cxt) {
+      this.transform = transform;
+      this.cxt = cxt;
+    }
+
+    public InT getInput() {
+      if (pIn == null) {
+        pIn = (InT) getOnlyElement(TransformInputs.nonAdditionalInputs(transform));
+      }
+      return pIn;
+    }
+
+    public Map<TupleTag<?>, PCollection<?>> getInputs() {
+      return transform.getInputs();
+    }
+
+    public OutT getOutput() {
+      if (pOut == null) {
+        pOut = (OutT) getOnlyElement(transform.getOutputs().values());
+      }
+      return pOut;
+    }
+
+    public <T> PCollection<T> getOutput(TupleTag<T> tag) {
+      PCollection<T> pc = (PCollection<T>) transform.getOutputs().get(tag);
+      if (pc == null) {
+        throw new IllegalStateException("No output for tag " + tag);
+      }
+      return pc;
+    }
+
+    public Map<TupleTag<?>, PCollection<?>> getOutputs() {
+      return transform.getOutputs();
+    }
+
+    public AppliedPTransform<InT, OutT, PTransform<InT, OutT>> getCurrentTransform() {
+      return transform;
+    }
+
+    public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
+      return cxt.getDataset(pCollection);
+    }
+
+    public <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
+      cxt.putDataset(pCollection, dataset);
+    }
+
+    public SerializablePipelineOptions getSerializableOptions() {
+      return cxt.getSerializableOptions();
+    }
+
+    public PipelineOptions getOptions() {
+      return cxt.getSerializableOptions().get();
+    }
+
+    // FIXME: ViewT and ElemT are unrelated, so the dataset type is not checked against the view.
+    public <ViewT, ElemT> void setSideInputDataset(
+        PCollectionView<ViewT> value, Dataset<WindowedValue<ElemT>> set) {
+      cxt.setSideInputDataset(value, set);
+    }
+
+    public <T> Dataset<T> getSideInputDataset(PCollectionView<?> sideInput) {
+      return cxt.getSideInputDataSet(sideInput);
+    }
+
+    public <T> Dataset<WindowedValue<T>> createDataset(
+        List<WindowedValue<T>> data, Encoder<WindowedValue<T>> enc) {
+      return data.isEmpty()
+          ? cxt.getSparkSession().emptyDataset(enc)
+          : cxt.getSparkSession().createDataset(data, enc);
+    }
+
+    public <T> Broadcast<T> broadcast(T value) {
+      return cxt.getSparkSession().sparkContext().broadcast(value, (ClassTag) ClassTag.AnyRef());
+    }
+
+    public SparkSession getSparkSession() {
+      return cxt.getSparkSession();
+    }
+
+    public <T> Encoder<T> encoderOf(Coder<T> coder) {
+      return coder instanceof KvCoder ? kvEncoderOf((KvCoder) coder) : getOrCreateEncoder(coder);
+    }
+
+    /**
+     * Returns the cached encoder for {@link KV}s encoded by {@code coder}.
+     *
+     * <p>Key and value encoders are resolved up front rather than inside the load function, so
+     * the load function does not re-enter the encoder cache while it is being populated.
+     */
+    public <K, V> Encoder<KV<K, V>> kvEncoderOf(KvCoder<K, V> coder) {
+      Encoder<K> keyEnc = keyEncoderOf(coder);
+      Encoder<V> valueEnc = valueEncoderOf(coder);
+      return cxt.encoderOf(coder, c -> kvEncoder(keyEnc, valueEnc));
+    }
+
+    public <K, V> Encoder<K> keyEncoderOf(KvCoder<K, V> coder) {
+      return getOrCreateEncoder(coder.getKeyCoder());
+    }
+
+    public <K, V> Encoder<V> valueEncoderOf(KvCoder<K, V> coder) {
+      return getOrCreateEncoder(coder.getValueCoder());
+    }
+
+    public <T> Encoder<WindowedValue<T>> windowedEncoder(Coder<T> coder) {
+      return windowedValueEncoder(encoderOf(coder), windowEncoder());
+    }
+
+    public <T> Encoder<WindowedValue<T>> windowedEncoder(Encoder<T> enc) {
+      return windowedValueEncoder(enc, windowEncoder());
+    }
+
+    public <T1, T2> Encoder<Tuple2<T1, T2>> tupleEncoder(Encoder<T1> e1, Encoder<T2> e2) {
+      return Encoders.tuple(e1, e2);
+    }
+
+    public <T, W extends BoundedWindow> Encoder<WindowedValue<T>> windowedEncoder(
+        Coder<T> coder, Coder<W> windowCoder) {
+      return windowedValueEncoder(encoderOf(coder), getOrCreateWindowEncoder(windowCoder));
+    }
+
+    /** Returns the encoder for windows of the main input; fails if the transform has no inputs. */
+    public Encoder<BoundedWindow> windowEncoder() {
+      checkState(
+          !transform.getInputs().isEmpty(), "Transform has no inputs, cannot get windowCoder!");
+      Coder<BoundedWindow> coder =
+          ((PCollection) getInput()).getWindowingStrategy().getWindowFn().windowCoder();
+      return cxt.encoderOf(coder, c -> encoderFor(c));
+    }
+
+    /** Returns the cached encoder for windows encoded by {@code coder}. */
+    private <W extends BoundedWindow> Encoder<BoundedWindow> getOrCreateWindowEncoder(
+        Coder<W> coder) {
+      return cxt.encoderOf((Coder<BoundedWindow>) coder, c -> encoderFor(c));
+    }
 
-/** Supports translation between a Beam transform, and Spark's operations on Datasets. */
-@SuppressWarnings({
-  "rawtypes" // TODO(https://github.com/apache/beam/issues/20447)
-})
-public interface TransformTranslator<TransformT extends PTransform> extends Serializable {
+    private <T> Encoder<T> getOrCreateEncoder(Coder<T> coder) {
+      return cxt.encoderOf(coder, c -> encoderFor(c));
+    }
+  }
 
-  /** Base class for translators of {@link PTransform}. */
-  void translateTransform(TransformT transform, AbstractTranslationContext context);
+  /** Returns the window coder of the windowing strategy of {@code pc}. */
+  protected <T> Coder<BoundedWindow> windowCoder(PCollection<T> pc) {
+    return (Coder) pc.getWindowingStrategy().getWindowFn().windowCoder();
+  }
 }
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TranslationContext.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TranslationContext.java
index 12cb2d2fef0..617aa67c5fe 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TranslationContext.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TranslationContext.java
@@ -17,27 +17,135 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
-import java.util.concurrent.TimeoutException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Function;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
 import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
-import org.apache.spark.sql.streaming.DataStreamWriter;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.util.Preconditions;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.spark.api.java.function.ForeachFunction;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
- * Subclass of {@link
- * org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext} that
- * address spark breaking changes.
+ * Translation state shared by all {@link PTransform} translators: the {@link SparkSession}, the
+ * datasets of all translated {@link PCollection}s, side input datasets and cached encoders.
  */
-public class TranslationContext extends AbstractTranslationContext {
+public class TranslationContext {
+
+  private static final Logger LOG = LoggerFactory.getLogger(TranslationContext.class);
+
+  /** All the datasets of the DAG. */
+  private final Map<PValue, Dataset<?>> datasets;
+  /** Datasets that are not used as input to other datasets (leaves of the DAG). */
+  private final Set<Dataset<?>> leaves;
+
+  private final SerializablePipelineOptions serializablePipelineOptions;
+
+  private final SparkSession sparkSession;
+
+  private final Map<PCollectionView<?>, Dataset<?>> broadcastDataSets;
+
+  /** Encoders cached by coder, cleared once the pipeline is started. */
+  private final Map<Coder<?>, Encoder<?>> encoders;
 
   public TranslationContext(SparkStructuredStreamingPipelineOptions options) {
-    super(options);
+    this.sparkSession = SparkSessionFactory.getOrCreateSession(options);
+    this.serializablePipelineOptions = new SerializablePipelineOptions(options);
+    this.datasets = new HashMap<>();
+    this.leaves = new HashSet<>();
+    this.broadcastDataSets = new HashMap<>();
+    this.encoders = new HashMap<>();
+  }
+
+  public SparkSession getSparkSession() {
+    return sparkSession;
+  }
+
+  public SerializablePipelineOptions getSerializableOptions() {
+    return serializablePipelineOptions;
+  }
+
+  /** Returns the cached encoder for {@code coder}, creating it with {@code loadFn} if absent. */
+  @SuppressWarnings("unchecked") // encoders are keyed by coder, hence the value type matches
+  public <T> Encoder<T> encoderOf(Coder<T> coder, Function<Coder<T>, Encoder<T>> loadFn) {
+    // Deliberately not using computeIfAbsent: loadFn may call back into encoderOf (e.g. a KV
+    // encoder resolving its key and value encoders), and HashMap.computeIfAbsent throws a
+    // ConcurrentModificationException (Java 9+) if the mapping function modifies the map.
+    Encoder<T> encoder = (Encoder<T>) encoders.get(coder);
+    if (encoder == null) {
+      encoder = loadFn.apply(coder);
+      encoders.put(coder, encoder);
+    }
+    return encoder;
+  }
+
+  @SuppressWarnings("unchecked") // can't be avoided
+  public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
+    Dataset<?> dataset = Preconditions.checkStateNotNull(datasets.get(pCollection));
+    // assume that the Dataset is used as an input if retrieved here. So it is not a leaf anymore
+    leaves.remove(dataset);
+    return (Dataset<WindowedValue<T>>) dataset;
+  }
+
+  public <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
+    if (!datasets.containsKey(pCollection)) {
+      datasets.put(pCollection, dataset);
+      leaves.add(dataset);
+    }
+  }
+
+  public <ViewT, ElemT> void setSideInputDataset(
+      PCollectionView<ViewT> value, Dataset<WindowedValue<ElemT>> set) {
+    if (!broadcastDataSets.containsKey(value)) {
+      broadcastDataSets.put(value, set);
+    }
+  }
+
+  @SuppressWarnings("unchecked") // can't be avoided
+  public <T> Dataset<T> getSideInputDataSet(PCollectionView<?> value) {
+    return (Dataset<T>) Preconditions.checkStateNotNull(broadcastDataSets.get(value));
+  }
+
+  /**
+   * Starts the batch pipeline, streaming is not supported.
+   *
+   * @see org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner
+   */
+  public void startPipeline() {
+    encoders.clear();
+
+    SparkStructuredStreamingPipelineOptions options =
+        serializablePipelineOptions.get().as(SparkStructuredStreamingPipelineOptions.class);
+    int datasetIndex = 0;
+    for (Dataset<?> dataset : leaves) {
+      if (options.getTestMode()) {
+        LOG.debug("**** dataset {} catalyst execution plans ****", ++datasetIndex);
+        dataset.explain(true);
+      }
+      // force evaluation using a dummy foreach action
+      dataset.foreach((ForeachFunction) t -> {});
+    }
   }
 
-  @Override
-  public void launchStreaming(DataStreamWriter<?> dataStreamWriter) {
-    try {
-      dataStreamWriter.start();
-    } catch (TimeoutException e) {
-      throw new RuntimeException("A timeout occurred when running the streaming pipeline", e);
+  public static <T> void printDatasetContent(Dataset<WindowedValue<T>> dataset) {
+    // cannot use dataset.show because dataset schema is binary so it will print binary
+    // code.
+    List<WindowedValue<T>> windowedValues = dataset.collectAsList();
+    for (WindowedValue<?> windowedValue : windowedValues) {
+      System.out.println(windowedValue);
     }
   }
 }
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombiner.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombiner.java
deleted file mode 100644
index d0f46ea807c..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombiner.java
+++ /dev/null
@@ -1,270 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
-
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.stream.Collectors;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.CoderException;
-import org.apache.beam.sdk.coders.IterableCoder;
-import org.apache.beam.sdk.transforms.Combine;
-import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
-import org.apache.beam.sdk.transforms.windowing.PaneInfo;
-import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
-import org.apache.beam.sdk.transforms.windowing.WindowFn;
-import org.apache.beam.sdk.util.CoderUtils;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.KV;
-import org.apache.beam.sdk.values.WindowingStrategy;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
-import org.apache.spark.sql.Encoder;
-import org.apache.spark.sql.expressions.Aggregator;
-import org.joda.time.Instant;
-import scala.Tuple2;
-
-/**
- * An {@link Aggregator} for the Spark Batch Runner.
- *
- * <p>The accumulator is an {@code Iterable<WindowedValue<AccumT>>} because an {@code InputT} can be
- * in multiple windows. So, when accumulating {@code InputT} values, we create one accumulator per
- * input window.
- */
-class AggregatorCombiner<K, InputT, AccumT, OutputT, W extends BoundedWindow>
-    extends Aggregator<
-        WindowedValue<KV<K, InputT>>,
-        Iterable<WindowedValue<AccumT>>,
-        Iterable<WindowedValue<OutputT>>> {
-
-  private final Combine.CombineFn<InputT, AccumT, OutputT> combineFn;
-  private WindowingStrategy<InputT, W> windowingStrategy;
-  private TimestampCombiner timestampCombiner;
-  // also used to copy accumulators that belong to more than one window (see merge)
-  private Coder<AccumT> accumulatorCoder;
-  // Beam coders of the aggregation buffer and the output; exposed to Spark as encoders via
-  // bufferEncoder() / outputEncoder()
-  private IterableCoder<WindowedValue<AccumT>> bufferEncoder;
-  private IterableCoder<WindowedValue<OutputT>> outputCoder;
-
-  /**
-   * Creates the aggregator.
-   *
-   * @param combineFn Beam combine function applied per window
-   * @param windowingStrategy windowing strategy of the input; provides the window coder, the
-   *     timestamp combiner and the window merging behavior
-   * @param accumulatorCoder coder of {@code AccumT}
-   * @param outputCoder coder of {@code OutputT}
-   */
-  public AggregatorCombiner(
-      Combine.CombineFn<InputT, AccumT, OutputT> combineFn,
-      WindowingStrategy<?, ?> windowingStrategy,
-      Coder<AccumT> accumulatorCoder,
-      Coder<OutputT> outputCoder) {
-    this.combineFn = combineFn;
-    this.windowingStrategy = (WindowingStrategy<InputT, W>) windowingStrategy;
-    this.timestampCombiner = windowingStrategy.getTimestampCombiner();
-    this.accumulatorCoder = accumulatorCoder;
-    this.bufferEncoder =
-        IterableCoder.of(
-            WindowedValue.FullWindowedValueCoder.of(
-                accumulatorCoder, windowingStrategy.getWindowFn().windowCoder()));
-    this.outputCoder =
-        IterableCoder.of(
-            WindowedValue.FullWindowedValueCoder.of(
-                outputCoder, windowingStrategy.getWindowFn().windowCoder()));
-  }
-
-  /** Returns an empty buffer, i.e. no windowed accumulators yet. */
-  @Override
-  public Iterable<WindowedValue<AccumT>> zero() {
-    return new ArrayList<>();
-  }
-
-  /**
-   * Wraps a single input into a one-element buffer whose accumulator keeps the timestamp, windows
-   * and pane of the input.
-   */
-  private Iterable<WindowedValue<AccumT>> createAccumulator(WindowedValue<KV<K, InputT>> inputWv) {
-    // need to create an accumulator because combineFn can modify its input accumulator.
-    AccumT accumulator = combineFn.createAccumulator();
-    AccumT accumT = combineFn.addInput(accumulator, inputWv.getValue().getValue());
-    return Lists.newArrayList(
-        WindowedValue.of(accumT, inputWv.getTimestamp(), inputWv.getWindows(), inputWv.getPane()));
-  }
-
-  /** Adds one input to the buffer by merging it in as a one-element buffer. */
-  @Override
-  public Iterable<WindowedValue<AccumT>> reduce(
-      Iterable<WindowedValue<AccumT>> accumulators, WindowedValue<KV<K, InputT>> inputWv) {
-    return merge(accumulators, createAccumulator(inputWv));
-  }
-
-  /**
-   * Merges two buffers. The windows of all accumulators are merged first (only if the windowing
-   * strategy needs merging), then all accumulators that end up in the same window are combined
-   * into a single accumulator for that window.
-   */
-  @Override
-  public Iterable<WindowedValue<AccumT>> merge(
-      Iterable<WindowedValue<AccumT>> accumulators1,
-      Iterable<WindowedValue<AccumT>> accumulators2) {
-
-    // merge the windows of all the accumulators
-    Iterable<WindowedValue<AccumT>> accumulators = Iterables.concat(accumulators1, accumulators2);
-    Set<W> accumulatorsWindows = collectAccumulatorsWindows(accumulators);
-    Map<W, W> windowToMergeResult;
-    try {
-      windowToMergeResult = mergeWindows(windowingStrategy, accumulatorsWindows);
-    } catch (Exception e) {
-      throw new RuntimeException("Unable to merge accumulators windows", e);
-    }
-
-    // group accumulators by their merged window
-    Map<W, List<Tuple2<AccumT, Instant>>> mergedWindowToAccumulators = new HashMap<>();
-    for (WindowedValue<AccumT> accumulatorWv : accumulators) {
-      // Encode a version of the accumulator if it is in multiple windows. The combineFn is able to
-      // mutate the accumulator instance and this could lead to incorrect results if the same
-      // instance is merged across multiple windows so we decode a new instance as needed. This
-      // prevents issues during merging of accumulators.
-      byte[] encodedAccumT = null;
-      if (accumulatorWv.getWindows().size() > 1) {
-        try {
-          encodedAccumT = CoderUtils.encodeToByteArray(accumulatorCoder, accumulatorWv.getValue());
-        } catch (CoderException e) {
-          throw new RuntimeException(
-              String.format(
-                  "Unable to encode accumulator %s with coder %s.",
-                  accumulatorWv.getValue(), accumulatorCoder),
-              e);
-        }
-      }
-      for (BoundedWindow accumulatorWindow : accumulatorWv.getWindows()) {
-        // windows that were not merged have no entry and are kept as they are
-        W mergedWindowForAccumulator = windowToMergeResult.get(accumulatorWindow);
-        mergedWindowForAccumulator =
-            (mergedWindowForAccumulator == null)
-                ? (W) accumulatorWindow
-                : mergedWindowForAccumulator;
-
-        // Decode a copy of the accumulator when necessary.
-        AccumT accumT;
-        if (encodedAccumT != null) {
-          try {
-            accumT = CoderUtils.decodeFromByteArray(accumulatorCoder, encodedAccumT);
-          } catch (CoderException e) {
-            // NOTE(review): this is the decode path, although the message says "encode".
-            throw new RuntimeException(
-                String.format(
-                    "Unable to encode accumulator %s with coder %s.",
-                    accumulatorWv.getValue(), accumulatorCoder),
-                e);
-          }
-        } else {
-          accumT = accumulatorWv.getValue();
-        }
-
-        // we need only the timestamp and the AccumT, we create a tuple
-        // (the timestamp is assigned relative to the merged window)
-        Tuple2<AccumT, Instant> accumAndInstant =
-            new Tuple2<>(
-                accumT,
-                timestampCombiner.assign(mergedWindowForAccumulator, accumulatorWv.getTimestamp()));
-        if (mergedWindowToAccumulators.get(mergedWindowForAccumulator) == null) {
-          mergedWindowToAccumulators.put(
-              mergedWindowForAccumulator, Lists.newArrayList(accumAndInstant));
-        } else {
-          mergedWindowToAccumulators.get(mergedWindowForAccumulator).add(accumAndInstant);
-        }
-      }
-    }
-    // merge the accumulators for each mergedWindow
-    List<WindowedValue<AccumT>> result = new ArrayList<>();
-    for (Map.Entry<W, List<Tuple2<AccumT, Instant>>> entry :
-        mergedWindowToAccumulators.entrySet()) {
-      W mergedWindow = entry.getKey();
-      List<Tuple2<AccumT, Instant>> accumsAndInstantsForMergedWindow = entry.getValue();
-
-      // we need to create the first accumulator because combineFn.mergerAccumulators can modify the
-      // first accumulator
-      AccumT first = combineFn.createAccumulator();
-      Iterable<AccumT> accumulatorsToMerge =
-          Iterables.concat(
-              Collections.singleton(first),
-              accumsAndInstantsForMergedWindow.stream()
-                  .map(x -> x._1())
-                  .collect(Collectors.toList()));
-      // the merged accumulator gets the combined timestamps and a NO_FIRING pane
-      result.add(
-          WindowedValue.of(
-              combineFn.mergeAccumulators(accumulatorsToMerge),
-              timestampCombiner.combine(
-                  accumsAndInstantsForMergedWindow.stream()
-                      .map(x -> x._2())
-                      .collect(Collectors.toList())),
-              mergedWindow,
-              PaneInfo.NO_FIRING));
-    }
-    return result;
-  }
-
-  /** Extracts the output of every windowed accumulator, keeping its timestamp, windows and pane. */
-  @Override
-  public Iterable<WindowedValue<OutputT>> finish(Iterable<WindowedValue<AccumT>> reduction) {
-    List<WindowedValue<OutputT>> result = new ArrayList<>();
-    for (WindowedValue<AccumT> windowedValue : reduction) {
-      result.add(windowedValue.withValue(combineFn.extractOutput(windowedValue.getValue())));
-    }
-    return result;
-  }
-
-  /** Spark encoder of the aggregation buffer, derived from the Beam buffer coder. */
-  @Override
-  public Encoder<Iterable<WindowedValue<AccumT>>> bufferEncoder() {
-    return EncoderHelpers.fromBeamCoder(bufferEncoder);
-  }
-
-  /** Spark encoder of the output, derived from the Beam output coder. */
-  @Override
-  public Encoder<Iterable<WindowedValue<OutputT>>> outputEncoder() {
-    return EncoderHelpers.fromBeamCoder(outputCoder);
-  }
-
-  /** Collects the distinct windows of all given accumulators. */
-  private Set<W> collectAccumulatorsWindows(Iterable<WindowedValue<AccumT>> accumulators) {
-    Set<W> windows = new HashSet<>();
-    for (WindowedValue<?> accumulator : accumulators) {
-      for (BoundedWindow untypedWindow : accumulator.getWindows()) {
-        @SuppressWarnings("unchecked")
-        W window = (W) untypedWindow;
-        windows.add(window);
-      }
-    }
-    return windows;
-  }
-
-  /**
-   * Maps every window that gets merged to its merge result. Windows that are not merged have no
-   * entry; the map is empty if the windowing strategy doesn't need merging.
-   */
-  private Map<W, W> mergeWindows(WindowingStrategy<InputT, W> windowingStrategy, Set<W> windows)
-      throws Exception {
-    WindowFn<InputT, W> windowFn = windowingStrategy.getWindowFn();
-
-    if (!windowingStrategy.needsMerge()) {
-      // Return an empty map, indicating that every window is not merged.
-      return Collections.emptyMap();
-    }
-
-    Map<W, W> windowToMergeResult = new HashMap<>();
-    windowFn.mergeWindows(new MergeContextImpl(windowFn, windows, windowToMergeResult));
-    return windowToMergeResult;
-  }
-
-  /** Merge context that records each merged window's result in {@code windowToMergeResult}. */
-  private class MergeContextImpl extends WindowFn<InputT, W>.MergeContext {
-
-    private Set<W> windows;
-    private Map<W, W> windowToMergeResult;
-
-    MergeContextImpl(WindowFn<InputT, W> windowFn, Set<W> windows, Map<W, W> windowToMergeResult) {
-      windowFn.super();
-      this.windows = windows;
-      this.windowToMergeResult = windowToMergeResult;
-    }
-
-    @Override
-    public Collection<W> windows() {
-      return windows;
-    }
-
-    @Override
-    public void merge(Collection<W> toBeMerged, W mergeResult) throws Exception {
-      for (W w : toBeMerged) {
-        windowToMergeResult.put(w, mergeResult);
-      }
-    }
-  }
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/Aggregators.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/Aggregators.java
new file mode 100644
index 00000000000..45026f9d8bd
--- /dev/null
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/Aggregators.java
@@ -0,0 +1,591 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
+
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderOf;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.mapEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.mutablePairEncoder;
+import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators.peekingIterator;
+
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.function.BiFunction;
+import java.util.function.BinaryOperator;
+import java.util.function.Function;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.Sessions;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.transforms.windowing.WindowFn;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Collections2;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.PeekingIterator;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.expressions.Aggregator;
+import org.apache.spark.util.MutablePair;
+import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.checkerframework.checker.nullness.qual.PolyNull;
+import org.joda.time.Instant;
+
+public class Aggregators {
+
+  /**
+   * Creates a Spark {@link Aggregator} for plain values, backed by a Beam {@link CombineFn}. The
+   * aggregator ignores windows entirely.
+   *
+   * @param fn Beam combine function
+   * @param valueFn extracts the {@link CombineFn} input from an {@link Aggregator} input
+   * @param accEnc encoder of the accumulator, which is used as aggregation buffer
+   * @param outEnc encoder of the result
+   * @param <ValT> {@link CombineFn} input type
+   * @param <AccT> {@link CombineFn} accumulator type
+   * @param <ResT> {@link CombineFn} / {@link Aggregator} result type
+   * @param <InT> {@link Aggregator} input type
+   */
+  public static <ValT, AccT, ResT, InT> Aggregator<InT, ?, ResT> value(
+      CombineFn<ValT, AccT, ResT> fn,
+      Fun1<InT, ValT> valueFn,
+      Encoder<AccT> accEnc,
+      Encoder<ResT> outEnc) {
+    ValueAggregator<ValT, AccT, ResT, InT> aggregator =
+        new ValueAggregator<>(fn, valueFn, accEnc, outEnc);
+    return aggregator;
+  }
+
+  /**
+   * Creates a windowed Spark {@link Aggregator} that is specialised for the {@link WindowFn} of
+   * the given windowing strategy:
+   *
+   * <ul>
+   *   <li>{@link Sessions}: a sessions specific aggregator working on a sorted buffer,
+   *   <li>any other merging window function: a generic merging aggregator,
+   *   <li>non merging window functions: a non-merging aggregator.
+   * </ul>
+   *
+   * @param <ValT> {@link CombineFn} input type
+   * @param <AccT> {@link CombineFn} accumulator type
+   * @param <ResT> {@link CombineFn} / {@link Aggregator} result type
+   * @param <InT> {@link Aggregator} input type
+   */
+  public static <ValT, AccT, ResT, InT>
+      Aggregator<WindowedValue<InT>, ?, Collection<WindowedValue<ResT>>> windowedValue(
+          CombineFn<ValT, AccT, ResT> fn,
+          Fun1<WindowedValue<InT>, ValT> valueFn,
+          WindowingStrategy<?, ?> windowing,
+          Encoder<BoundedWindow> windowEnc,
+          Encoder<AccT> accEnc,
+          Encoder<WindowedValue<ResT>> outEnc) {
+    if (windowing.needsMerge()) {
+      if (Sessions.class.equals(windowing.getWindowFn().getClass())) {
+        // the sessions aggregator expects an encoder of IntervalWindow, hence the raw cast
+        return new SessionsAggregator<>(
+            fn, valueFn, windowing, (Encoder) windowEnc, accEnc, outEnc);
+      }
+      return new MergingWindowedAggregator<>(fn, valueFn, windowing, windowEnc, accEnc, outEnc);
+    }
+    return new NonMergingWindowedAggregator<>(fn, valueFn, windowing, windowEnc, accEnc, outEnc);
+  }
+
+  /**
+   * Simple value {@link Aggregator} that is not window aware.
+   *
+   * @param <ValT> {@link CombineFn} input type
+   * @param <AccT> {@link CombineFn} accumulator type
+   * @param <ResT> {@link CombineFn} / {@link Aggregator} result type
+   * @param <InT> {@link Aggregator} input type
+   */
+  private static class ValueAggregator<ValT, AccT, ResT, InT>
+      extends CombineFnAggregator<ValT, AccT, ResT, InT, AccT, ResT> {
+
+    public ValueAggregator(
+        CombineFn<ValT, AccT, ResT> fn,
+        Fun1<InT, ValT> valueFn,
+        Encoder<AccT> accEnc,
+        Encoder<ResT> outEnc) {
+      super(fn, valueFn, accEnc, outEnc);
+    }
+
+    @Override
+    public AccT zero() {
+      return emptyAcc();
+    }
+
+    @Override
+    public AccT reduce(AccT buff, InT in) {
+      return addToAcc(buff, value(in));
+    }
+
+    @Override
+    public AccT merge(AccT b1, AccT b2) {
+      return mergeAccs(b1, b2);
+    }
+
+    @Override
+    public ResT finish(AccT buff) {
+      return extract(buff);
+    }
+  }
+
+  /**
+   * Specialized windowed Spark {@link Aggregator} for Beam {@link WindowFn}s of type {@link
+   * Sessions}. The aggregator uses a {@link TreeMap} as buffer to maintain ordering of the {@link
+   * IntervalWindow}s and merge these more efficiently.
+   *
+   * <p>For efficiency, this aggregator re-implements {@link
+   * Sessions#mergeWindows(WindowFn.MergeContext)} to leverage the already sorted buffer.
+   *
+   * <p>Both {@code reduce} and {@code merge} merge every group of intersecting windows into a
+   * single window. The code relies on the buffered windows never intersecting each other, which
+   * is why only neighbouring windows of the sorted buffer have to be inspected.
+   *
+   * @param <ValT> {@link CombineFn} input type
+   * @param <AccT> {@link CombineFn} accumulator type
+   * @param <ResT> {@link CombineFn} / {@link Aggregator} result type
+   * @param <InT> {@link Aggregator} input type
+   */
+  private static class SessionsAggregator<ValT, AccT, ResT, InT>
+      extends WindowedAggregator<
+          ValT,
+          AccT,
+          ResT,
+          InT,
+          IntervalWindow,
+          TreeMap<IntervalWindow, MutablePair<Instant, AccT>>> {
+
+    SessionsAggregator(
+        CombineFn<ValT, AccT, ResT> combineFn,
+        Fun1<WindowedValue<InT>, ValT> valueFn,
+        WindowingStrategy<?, ?> windowing,
+        Encoder<IntervalWindow> windowEnc,
+        Encoder<AccT> accEnc,
+        Encoder<WindowedValue<ResT>> outEnc) {
+      super(combineFn, valueFn, windowing, windowEnc, accEnc, outEnc, (Class) TreeMap.class);
+      checkArgument(windowing.getWindowFn().getClass().equals(Sessions.class));
+    }
+
+    /** Returns an empty, sorted buffer. */
+    @Override
+    public final TreeMap<IntervalWindow, MutablePair<Instant, AccT>> zero() {
+      return new TreeMap<>();
+    }
+
+    /**
+     * Adds the input to the buffer once per input window. Each input window is extended to span
+     * all buffered windows it intersects; those windows are removed and their accumulators are
+     * merged into the accumulator of the extended window.
+     */
+    @Override
+    @SuppressWarnings("keyfor")
+    public TreeMap<IntervalWindow, MutablePair<Instant, AccT>> reduce(
+        TreeMap<IntervalWindow, MutablePair<Instant, AccT>> buff, WindowedValue<InT> input) {
+      for (IntervalWindow window : (Collection<IntervalWindow>) input.getWindows()) {
+        @MonotonicNonNull MutablePair<Instant, AccT> acc = null;
+        // first / last buffered window merged into the input window (bounds of the range to remove)
+        @MonotonicNonNull IntervalWindow first = null, last = null;
+        // start with window before or equal to new window (if exists)
+        @Nullable Entry<IntervalWindow, MutablePair<Instant, AccT>> lower = buff.floorEntry(window);
+        if (lower != null && window.intersects(lower.getKey())) {
+          // if intersecting, init accumulator and extend window to span both
+          acc = lower.getValue();
+          window = window.span(lower.getKey());
+          first = last = lower.getKey();
+        }
+        // merge following windows in order if they intersect, then stop
+        // (tailMap is exclusive, so an already spanned lower window is not visited again)
+        for (Entry<IntervalWindow, MutablePair<Instant, AccT>> entry :
+            buff.tailMap(window, false).entrySet()) {
+          MutablePair<Instant, AccT> entryAcc = entry.getValue();
+          IntervalWindow entryWindow = entry.getKey();
+          if (window.intersects(entryWindow)) {
+            // extend window and merge accumulators
+            window = window.span(entryWindow);
+            acc = acc == null ? entryAcc : mergeAccs(window, acc, entryAcc);
+            if (first == null) {
+              // there was no previous (lower) window intersecting the input window
+              first = last = entryWindow;
+            } else {
+              last = entryWindow;
+            }
+          } else {
+            break; // stop, later windows won't intersect either
+          }
+        }
+        if (first != null && last != null) {
+          // remove entire subset from first to last after it got merged into acc
+          buff.navigableKeySet().subSet(first, true, last, true).clear();
+        }
+        // add input and get accumulator for new (potentially merged) window
+        buff.put(window, addToAcc(window, acc, value(input), input.getTimestamp()));
+      }
+      return buff;
+    }
+
+    /**
+     * Merges two sorted buffers in a single pass over both, always consuming the smaller next
+     * window and merging it into the current window while they intersect. Returns one of the
+     * inputs unchanged if the other one is empty; otherwise a new buffer is returned.
+     */
+    @Override
+    public TreeMap<IntervalWindow, MutablePair<Instant, AccT>> merge(
+        TreeMap<IntervalWindow, MutablePair<Instant, AccT>> b1,
+        TreeMap<IntervalWindow, MutablePair<Instant, AccT>> b2) {
+      if (b1.isEmpty()) {
+        return b2;
+      } else if (b2.isEmpty()) {
+        return b1;
+      }
+      // Init new tree map to merge both buffers
+      TreeMap<IntervalWindow, MutablePair<Instant, AccT>> res = zero();
+      PeekingIterator<Entry<IntervalWindow, MutablePair<Instant, AccT>>> it1 =
+          peekingIterator(b1.entrySet().iterator());
+      PeekingIterator<Entry<IntervalWindow, MutablePair<Instant, AccT>>> it2 =
+          peekingIterator(b2.entrySet().iterator());
+
+      // current (possibly extended) window and its accumulator, not yet stored in res
+      @Nullable MutablePair<Instant, AccT> acc = null;
+      @Nullable IntervalWindow window = null;
+      while (it1.hasNext() || it2.hasNext()) {
+        // pick iterator with the smallest window ahead and forward it (it1 wins ties)
+        Entry<IntervalWindow, MutablePair<Instant, AccT>> nextMin =
+            (it1.hasNext() && it2.hasNext())
+                ? it1.peek().getKey().compareTo(it2.peek().getKey()) <= 0 ? it1.next() : it2.next()
+                : it1.hasNext() ? it1.next() : it2.next();
+        if (window != null && window.intersects(nextMin.getKey())) {
+          // extend window and merge accumulators if intersecting
+          window = window.span(nextMin.getKey());
+          acc = mergeAccs(window, acc, nextMin.getValue());
+        } else {
+          // store window / accumulator if necessary and continue with next minimum
+          if (window != null && acc != null) {
+            res.put(window, acc);
+          }
+          acc = nextMin.getValue();
+          window = nextMin.getKey();
+        }
+      }
+      // store the last pending window
+      if (window != null && acc != null) {
+        res.put(window, acc);
+      }
+      return res;
+    }
+  }
+
+  /**
+   * Merging windowed Spark {@link Aggregator} using a Map of {@link BoundedWindow}s as aggregation
+   * buffer. When reducing new input, a windowed accumulator is created for each new window of the
+   * input that doesn't overlap with existing windows. Otherwise, if the window is known or
+   * overlaps, the window is extended accordingly and accumulators are merged.
+   *
+   * <p>Window merging is delegated to {@link WindowFn#mergeWindows(WindowFn.MergeContext)}. When
+   * accumulators of several windows are merged into a target window, timestamps are resolved
+   * against that target window rather than against the individual source windows. This way,
+   * window dependent {@link TimestampCombiner}s such as {@link TimestampCombiner#END_OF_WINDOW}
+   * yield the timestamp of the merged window.
+   *
+   * @param <ValT> {@link CombineFn} input type
+   * @param <AccT> {@link CombineFn} accumulator type
+   * @param <ResT> {@link CombineFn} / {@link Aggregator} result type
+   * @param <InT> {@link Aggregator} input type
+   */
+  private static class MergingWindowedAggregator<ValT, AccT, ResT, InT>
+      extends NonMergingWindowedAggregator<ValT, AccT, ResT, InT> {
+
+    private final WindowFn<ValT, BoundedWindow> windowFn;
+
+    public MergingWindowedAggregator(
+        CombineFn<ValT, AccT, ResT> combineFn,
+        Fun1<WindowedValue<InT>, ValT> valueFn,
+        WindowingStrategy<?, ?> windowing,
+        Encoder<BoundedWindow> windowEnc,
+        Encoder<AccT> accEnc,
+        Encoder<WindowedValue<ResT>> outEnc) {
+      super(combineFn, valueFn, windowing, windowEnc, accEnc, outEnc);
+      windowFn = (WindowFn<ValT, BoundedWindow>) windowing.getWindowFn();
+    }
+
+    /**
+     * Adds {@code value} to the buffer for each of the given windows. Windows are first merged
+     * with buffered windows as required by the {@link WindowFn}; the remaining windows are handled
+     * like in the non-merging case.
+     */
+    @Override
+    protected Map<BoundedWindow, MutablePair<Instant, AccT>> reduce(
+        Map<BoundedWindow, MutablePair<Instant, AccT>> buff,
+        Collection<BoundedWindow> windows,
+        ValT value,
+        Instant timestamp) {
+      if (buff.isEmpty()) {
+        // no windows yet to be merged, use the non-merging behavior of super
+        return super.reduce(buff, windows, value, timestamp);
+      }
+      // Merge multiple windows into one target window using the reducer function if the window
+      // already exists. Otherwise, the input value is added to the accumulator. Merged windows are
+      // removed from the accumulator map. Timestamps are resolved against the target window.
+      Function<BoundedWindow, ReduceFn<AccT>> accFn =
+          target ->
+              (acc, w) -> {
+                MutablePair<Instant, AccT> accW = buff.remove(w);
+                return (accW != null)
+                    ? mergeAccs(target, acc, accW)
+                    : addToAcc(target, acc, value, timestamp);
+              };
+      Set<BoundedWindow> unmerged = mergeWindows(buff, ImmutableSet.copyOf(windows), accFn);
+      if (!unmerged.isEmpty()) {
+        // remaining windows don't have to be merged
+        return super.reduce(buff, unmerged, value, timestamp);
+      }
+      return buff;
+    }
+
+    /**
+     * Merges two buffers. Windows that have to be merged according to the {@link WindowFn} are
+     * combined into {@code b1} and removed from both buffers; the remaining windows of {@code b2}
+     * are then merged like in the non-merging case.
+     */
+    @Override
+    public Map<BoundedWindow, MutablePair<Instant, AccT>> merge(
+        Map<BoundedWindow, MutablePair<Instant, AccT>> b1,
+        Map<BoundedWindow, MutablePair<Instant, AccT>> b2) {
+      // Merge multiple windows into one target window using the reducer function. Merged windows
+      // are removed from both accumulator maps. Timestamps are resolved against the target window.
+      Function<BoundedWindow, ReduceFn<AccT>> reduceFn =
+          target ->
+              (acc, w) ->
+                  mergeAccs(target, mergeAccs(target, acc, b1.remove(w)), b2.remove(w));
+
+      Set<BoundedWindow> unmerged = b2.keySet();
+      unmerged = mergeWindows(b1, unmerged, reduceFn);
+      if (!unmerged.isEmpty()) {
+        // keep only unmerged windows in 2nd accumulator map, continue using "non-merging" merge
+        b2.keySet().retainAll(unmerged);
+        return super.merge(b1, b2);
+      }
+      return b1;
+    }
+
+    /** Reduce function to merge multiple windowed accumulator values into one target window. */
+    private interface ReduceFn<AccT>
+        extends BiFunction<MutablePair<Instant, AccT>, BoundedWindow, MutablePair<Instant, AccT>> {}
+
+    /**
+     * Attempt to merge windows of accumulator map with additional windows using the reducer
+     * function. The reducer function must support {@code null} as zero value.
+     *
+     * @return The subset of additional windows that don't require a merge.
+     */
+    private Set<BoundedWindow> mergeWindows(
+        Map<BoundedWindow, MutablePair<Instant, AccT>> buff,
+        Set<BoundedWindow> newWindows,
+        Function<BoundedWindow, ReduceFn<AccT>> reduceFn) {
+      try {
+        // copy, newWindows may be a live view of a map modified by the reducer function
+        Set<BoundedWindow> newUnmerged = new HashSet<>(newWindows);
+        windowFn.mergeWindows(
+            windowFn.new MergeContext() {
+              @Override
+              public Collection<BoundedWindow> windows() {
+                return Sets.union(buff.keySet(), newWindows);
+              }
+
+              @Override
+              public void merge(Collection<BoundedWindow> merges, BoundedWindow target) {
+                // sequential stream: the combiner only satisfies the Stream#reduce signature
+                buff.put(
+                    target, merges.stream().reduce(null, reduceFn.apply(target), combiner(target)));
+                newUnmerged.removeAll(merges);
+              }
+            });
+        return newUnmerged;
+      } catch (Exception e) {
+        throw new RuntimeException("Unable to merge accumulators windows", e);
+      }
+    }
+  }
+
+  /**
+   * Non-merging windowed Spark {@link Aggregator} using a Map of {@link BoundedWindow}s as
+   * aggregation buffer. When reducing new input, a windowed accumulator is created for each new
+   * window of the input. Otherwise, if the window is known, the accumulators are merged.
+   *
+   * @param <ValT> {@link CombineFn} input type
+   * @param <AccT> {@link CombineFn} accumulator type
+   * @param <ResT> {@link CombineFn} / {@link Aggregator} result type
+   * @param <InT> {@link Aggregator} input type
+   */
+  private static class NonMergingWindowedAggregator<ValT, AccT, ResT, InT>
+      extends WindowedAggregator<
+          ValT, AccT, ResT, InT, BoundedWindow, Map<BoundedWindow, MutablePair<Instant, AccT>>> {
+
+    public NonMergingWindowedAggregator(
+        CombineFn<ValT, AccT, ResT> combineFn,
+        Fun1<WindowedValue<InT>, ValT> valueFn,
+        WindowingStrategy<?, ?> windowing,
+        Encoder<BoundedWindow> windowEnc,
+        Encoder<AccT> accEnc,
+        Encoder<WindowedValue<ResT>> outEnc) {
+      super(combineFn, valueFn, windowing, windowEnc, accEnc, outEnc, (Class) Map.class);
+    }
+
+    /** Returns an empty buffer. */
+    @Override
+    public Map<BoundedWindow, MutablePair<Instant, AccT>> zero() {
+      return new HashMap<>();
+    }
+
+    /** Adds the value of {@code input} to the accumulator of each of its windows. */
+    @Override
+    public final Map<BoundedWindow, MutablePair<Instant, AccT>> reduce(
+        Map<BoundedWindow, MutablePair<Instant, AccT>> buff, WindowedValue<InT> input) {
+      Collection<BoundedWindow> windows = (Collection<BoundedWindow>) input.getWindows();
+      return reduce(buff, windows, value(input), input.getTimestamp());
+    }
+
+    /** Adds {@code value} to the accumulator of each window, creating accumulators as needed. */
+    protected Map<BoundedWindow, MutablePair<Instant, AccT>> reduce(
+        Map<BoundedWindow, MutablePair<Instant, AccT>> buff,
+        Collection<BoundedWindow> windows,
+        ValT value,
+        Instant timestamp) {
+      // for each window add the value to the accumulator
+      for (BoundedWindow window : windows) {
+        buff.compute(window, (w, acc) -> addToAcc(w, acc, value, timestamp));
+      }
+      return buff;
+    }
+
+    /**
+     * Merges the smaller buffer into the larger one, merging the accumulators of windows present
+     * in both. Returns the buffer the other one was merged into.
+     */
+    @Override
+    public Map<BoundedWindow, MutablePair<Instant, AccT>> merge(
+        Map<BoundedWindow, MutablePair<Instant, AccT>> b1,
+        Map<BoundedWindow, MutablePair<Instant, AccT>> b2) {
+      if (b1.isEmpty()) {
+        return b2;
+      } else if (b2.isEmpty()) {
+        return b1;
+      }
+      // Swap locally rather than calling merge(b2, b1): that call would be dispatched to an
+      // overriding merge (e.g. MergingWindowedAggregator, which delegates here via super.merge)
+      // and redo its window merging on buffers that were already merged.
+      Map<BoundedWindow, MutablePair<Instant, AccT>> larger = b2.size() > b1.size() ? b2 : b1;
+      Map<BoundedWindow, MutablePair<Instant, AccT>> smaller = larger == b1 ? b2 : b1;
+      // merge entries of the smaller agg buffer map into the larger by merging the accumulators
+      smaller.forEach((w, acc) -> larger.merge(w, acc, combiner(w)));
+      return larger;
+    }
+  }
+
+  /**
+   * Abstract base of a Spark {@link Aggregator} on {@link WindowedValue}s using a map keyed by
+   * window {@code W} as aggregation buffer.
+   *
+   * @param <ValT> {@link CombineFn} input type
+   * @param <AccT> {@link CombineFn} accumulator type
+   * @param <ResT> {@link CombineFn} result type, emitted as one windowed value per window
+   * @param <InT> {@link Aggregator} input type
+   * @param <W> bounded window type
+   * @param <MapT> aggregation buffer type: map of window {@code W} to (timestamp, accumulator)
+   */
+  private abstract static class WindowedAggregator<
+          ValT,
+          AccT,
+          ResT,
+          InT,
+          W extends @NonNull BoundedWindow,
+          MapT extends Map<W, @NonNull MutablePair<Instant, AccT>>>
+      extends CombineFnAggregator<
+          ValT, AccT, ResT, WindowedValue<InT>, MapT, Collection<WindowedValue<ResT>>> {
+    private final TimestampCombiner tsCombiner; // taken from the windowing strategy
+
+    public WindowedAggregator(
+        CombineFn<ValT, AccT, ResT> combineFn,
+        Fun1<WindowedValue<InT>, ValT> valueFn,
+        WindowingStrategy<?, ?> windowing,
+        Encoder<W> windowEnc,
+        Encoder<AccT> accEnc,
+        Encoder<WindowedValue<ResT>> outEnc,
+        Class<MapT> clazz) {
+      super(
+          combineFn,
+          valueFn,
+          mapEncoder(windowEnc, mutablePairEncoder(encoderOf(Instant.class), accEnc), clazz),
+          collectionEncoder(outEnc)); // output: collection of windowed results
+      tsCombiner = windowing.getTimestampCombiner();
+    }
+
+    protected final Instant resolveTimestamp(BoundedWindow w, Instant t1, Instant t2) {
+      return tsCombiner.merge(w, t1, t2); // per the windowing strategy's TimestampCombiner
+    }
+
+    /** Init accumulator with initial input value and timestamp. */
+    protected final MutablePair<Instant, AccT> initAcc(ValT value, Instant timestamp) {
+      return new MutablePair<>(timestamp, addToAcc(emptyAcc(), value));
+    }
+
+    /** Merge timestamped accumulators in place into {@code a1}; a null side yields the other. */
+    protected final <T extends MutablePair<Instant, AccT>> @PolyNull T mergeAccs(
+        W window, @PolyNull T a1, @PolyNull T a2) {
+      if (a1 == null || a2 == null) {
+        return a1 == null ? a2 : a1;
+      }
+      return (T) a1.update(resolveTimestamp(window, a1._1, a2._1), mergeAccs(a1._2, a2._2));
+    }
+
+    @SuppressWarnings("nullness") // may return null (mergeAccs is @PolyNull)
+    protected BinaryOperator<MutablePair<Instant, AccT>> combiner(W target) {
+      return (a1, a2) -> mergeAccs(target, a1, a2);
+    }
+
+    /** Add an input value to a nullable accumulator, creating a new one if {@code acc} is null. */
+    protected final MutablePair<Instant, AccT> addToAcc(
+        W window, @Nullable MutablePair<Instant, AccT> acc, ValT val, Instant ts) {
+      if (acc == null) {
+        return initAcc(val, ts);
+      }
+      return acc.update(resolveTimestamp(window, acc._1, ts), addToAcc(acc._2, val));
+    }
+
+    @Override
+    @SuppressWarnings("nullness") // entries are non null; result is a lazily transformed view
+    public final Collection<WindowedValue<ResT>> finish(MapT buffer) {
+      return Collections2.transform(buffer.entrySet(), this::windowedValue);
+    }
+
+    private WindowedValue<ResT> windowedValue(Entry<W, MutablePair<Instant, AccT>> e) {
+      return WindowedValue.of(extract(e.getValue()._2), e.getValue()._1, e.getKey(), NO_FIRING);
+    }
+  }
+
+  /**
+   * Abstract base of Spark {@link Aggregator}s delegating to a Beam {@link CombineFn}.
+   *
+   * @param <ValT> {@link CombineFn} input type
+   * @param <AccT> {@link CombineFn} accumulator type
+   * @param <ResT> {@link CombineFn} result type
+   * @param <InT> {@link Aggregator} input type
+   * @param <BuffT> {@link Aggregator} buffer type
+   * @param <OutT> {@link Aggregator} output type
+   */
+  private abstract static class CombineFnAggregator<ValT, AccT, ResT, InT, BuffT, OutT>
+      extends Aggregator<InT, BuffT, OutT> {
+    private final CombineFn<ValT, AccT, ResT> fn;
+    private final Fun1<InT, ValT> valueFn; // extracts the CombineFn input from an aggregator input
+    private final Encoder<BuffT> bufferEnc;
+    private final Encoder<OutT> outputEnc;
+
+    public CombineFnAggregator(
+        CombineFn<ValT, AccT, ResT> fn,
+        Fun1<InT, ValT> valueFn,
+        Encoder<BuffT> bufferEnc,
+        Encoder<OutT> outputEnc) {
+      this.fn = fn;
+      this.valueFn = valueFn;
+      this.bufferEnc = bufferEnc;
+      this.outputEnc = outputEnc;
+    }
+
+    protected final ValT value(InT in) {
+      return valueFn.apply(in);
+    }
+
+    protected final AccT emptyAcc() {
+      return fn.createAccumulator();
+    }
+
+    protected final AccT mergeAccs(AccT a1, AccT a2) {
+      return fn.mergeAccumulators(ImmutableList.of(a1, a2)); // pairwise merge
+    }
+
+    protected final AccT addToAcc(AccT acc, ValT val) {
+      return fn.addInput(acc, val);
+    }
+
+    protected final ResT extract(AccT acc) {
+      return fn.extractOutput(acc);
+    }
+
+    @Override
+    public Encoder<BuffT> bufferEncoder() {
+      return bufferEnc;
+    }
+
+    @Override
+    public Encoder<OutT> outputEncoder() {
+      return outputEnc;
+    }
+  }
+}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java
new file mode 100644
index 00000000000..5bc017134e9
--- /dev/null
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
+
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.value;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1;
+import static scala.collection.Iterator.single;
+
+import java.util.Collection;
+import java.util.Map;
+import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.expressions.Aggregator;
+import scala.collection.Iterator;
+
+/**
+ * Translator for {@link Combine.Globally} using a Spark {@link Aggregator}.
+ *
+ * <p>To minimize the amount of data shuffled, this first reduces the data per partition using
+ * {@link Aggregator#reduce}, gathers the partial results (using {@code coalesce(1)}) and finally
+ * merges these using {@link Aggregator#merge}.
+ *
+ * <p>TODOs:
+ * <ul><li>any missing features?</ul>
+ */
+class CombineGloballyTranslatorBatch<InT, AccT, OutT>
+    extends TransformTranslator<PCollection<InT>, PCollection<OutT>, Combine.Globally<InT, OutT>> {
+
+  @Override
+  public void translate(Combine.Globally<InT, OutT> transform, Context cxt) {
+    WindowingStrategy<?, ?> windowing = cxt.getInput().getWindowingStrategy();
+    CombineFn<InT, AccT, OutT> combineFn = (CombineFn<InT, AccT, OutT>) transform.getFn();
+
+    Coder<InT> inputCoder = cxt.getInput().getCoder();
+    Coder<OutT> outputCoder = cxt.getOutput().getCoder();
+    Coder<AccT> accumCoder = accumulatorCoder(combineFn, inputCoder, cxt);
+
+    Encoder<OutT> outEnc = cxt.encoderOf(outputCoder);
+    Encoder<AccT> accEnc = cxt.encoderOf(accumCoder);
+    Encoder<WindowedValue<OutT>> wvOutEnc = cxt.windowedEncoder(outEnc);
+
+    Dataset<WindowedValue<InT>> dataset = cxt.getDataset(cxt.getInput());
+
+    final Dataset<WindowedValue<OutT>> result;
+    if (GroupByKeyHelpers.eligibleForGlobalGroupBy(windowing, true)) {
+      Aggregator<InT, ?, OutT> agg = Aggregators.value(combineFn, v -> v, accEnc, outEnc);
+
+      // Drop window and restore afterwards, produces single global aggregation result
+      result = aggregate(dataset, agg, value(), windowedValue(), wvOutEnc);
+    } else {
+      Aggregator<WindowedValue<InT>, ?, Collection<WindowedValue<OutT>>> agg =
+          Aggregators.windowedValue(
+              combineFn, value(), windowing, cxt.windowEncoder(), accEnc, wvOutEnc);
+
+      // Produces aggregation result per window
+      result =
+          aggregate(dataset, agg, v -> v, fun1(out -> ScalaInterop.scalaIterator(out)), wvOutEnc);
+    }
+    cxt.putDataset(cxt.getOutput(), result);
+  }
+
+  /**
+   * Aggregate dataset globally without using key.
+   *
+   * <p>There is no global, typed version of {@link Dataset#agg(Map)} on datasets. This reduces each
+   * partition to one buffer first, then merges all buffers in a single partition for the result.
+   */
+  private static <InT, OutT, AggInT, BuffT, AggOutT> Dataset<WindowedValue<OutT>> aggregate(
+      Dataset<WindowedValue<InT>> ds,
+      Aggregator<AggInT, BuffT, AggOutT> agg,
+      Fun1<WindowedValue<InT>, AggInT> valueFn,
+      Fun1<AggOutT, Iterator<WindowedValue<OutT>>> finishFn,
+      Encoder<WindowedValue<OutT>> enc) {
+    // reduce each partition to a single aggregation buffer
+    Fun1<Iterator<WindowedValue<InT>>, Iterator<BuffT>> reduce =
+        fun1(it -> single(it.map(valueFn).foldLeft(agg.zero(), agg::reduce)));
+    // merge all partition buffers (falling back to zero if there are none) and finish
+    Fun1<Iterator<BuffT>, Iterator<WindowedValue<OutT>>> merge =
+        fun1(it -> finishFn.apply(agg.finish(it.hasNext() ? it.reduce(agg::merge) : agg.zero())));
+
+    return ds.mapPartitions(reduce, agg.bufferEncoder()).coalesce(1).mapPartitions(merge, enc);
+  }
+
+  private Coder<AccT> accumulatorCoder(
+      CombineFn<InT, AccT, OutT> fn, Coder<InT> valueCoder, Context cxt) {
+    try { // resolve the accumulator coder via the pipeline's coder registry
+      return fn.getAccumulatorCoder(cxt.getInput().getPipeline().getCoderRegistry(), valueCoder);
+    } catch (CannotProvideCoderException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  private static <T> Fun1<T, Iterator<WindowedValue<T>>> windowedValue() {
+    return v -> single(WindowedValue.valueInGlobalWindow(v)); // restore the global window
+  }
+}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java
index 2b0cf8be995..f990cd114f9 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java
@@ -17,98 +17,141 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 
-import java.util.ArrayList;
-import java.util.List;
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGlobalGroupBy;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGroupByWindow;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.explodeWindowedKey;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.value;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueKey;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueValue;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.windowedKV;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1;
+
+import java.util.Collection;
 import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.KVHelpers;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1;
 import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderRegistry;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.transforms.Combine;
-import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.WindowingStrategy;
-import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.KeyValueGroupedDataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.expressions.Aggregator;
 import scala.Tuple2;
+import scala.collection.TraversableOnce;
 
-@SuppressWarnings({
-  "rawtypes" // TODO(https://github.com/apache/beam/issues/20447)
-})
-class CombinePerKeyTranslatorBatch<K, InputT, AccumT, OutputT>
-    implements TransformTranslator<
-        PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>>> {
+/**
+ * Translator for {@link Combine.PerKey} using {@link Dataset#groupByKey} with a Spark {@link
+ * Aggregator}.
+ *
+ * <ul>
+ *   <li>When using the default global window, window information is dropped and restored after the
+ *       aggregation.
+ *   <li>For non-merging windows, windows are exploded and moved into a composite key for better
+ *       distribution. After the aggregation, windowed values are restored from the composite key.
+ *   <li>All other cases use an aggregator on windowed values that is optimized for the current
+ *       windowing strategy.
+ * </ul>
+ *
+ * <p>TODOs:
+ * <ul><li>combine with context (CombineFnWithContext)?
+ * <li>combine with sideInputs?
+ * <li>are there other missing features?</ul>
+ */
+class CombinePerKeyTranslatorBatch<K, InT, AccT, OutT>
+    extends TransformTranslator<
+        PCollection<KV<K, InT>>, PCollection<KV<K, OutT>>, Combine.PerKey<K, InT, OutT>> {
 
   @Override
-  public void translateTransform(
-      PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>> transform,
-      AbstractTranslationContext context) {
+  public void translate(Combine.PerKey<K, InT, OutT> transform, Context cxt) {
+    WindowingStrategy<?, ?> windowing = cxt.getInput().getWindowingStrategy();
+    CombineFn<InT, AccT, OutT> combineFn = (CombineFn<InT, AccT, OutT>) transform.getFn();
+
+    KvCoder<K, InT> inputCoder = (KvCoder<K, InT>) cxt.getInput().getCoder();
+    KvCoder<K, OutT> outputCoder = (KvCoder<K, OutT>) cxt.getOutput().getCoder();
+
+    Encoder<K> keyEnc = cxt.keyEncoderOf(inputCoder);
+    Encoder<KV<K, InT>> inputEnc = cxt.encoderOf(inputCoder);
+    Encoder<WindowedValue<KV<K, OutT>>> wvOutputEnc = cxt.windowedEncoder(outputCoder);
+    Encoder<AccT> accumEnc = accumEncoder(combineFn, inputCoder.getValueCoder(), cxt);
+
+    final Dataset<WindowedValue<KV<K, OutT>>> result;
+
+    boolean globalGroupBy = eligibleForGlobalGroupBy(windowing, true);
+    boolean groupByWindow = eligibleForGroupByWindow(windowing, true);
 
-    Combine.PerKey combineTransform = (Combine.PerKey) transform;
-    @SuppressWarnings("unchecked")
-    final PCollection<KV<K, InputT>> input = (PCollection<KV<K, InputT>>) context.getInput();
-    @SuppressWarnings("unchecked")
-    final PCollection<KV<K, OutputT>> output = (PCollection<KV<K, OutputT>>) context.getOutput();
-    @SuppressWarnings("unchecked")
-    final Combine.CombineFn<InputT, AccumT, OutputT> combineFn =
-        (Combine.CombineFn<InputT, AccumT, OutputT>) combineTransform.getFn();
-    WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy();
+    if (globalGroupBy || groupByWindow) {
+      Aggregator<KV<K, InT>, ?, OutT> valueAgg =
+          Aggregators.value(combineFn, KV::getValue, accumEnc, cxt.valueEncoderOf(outputCoder));
 
-    Dataset<WindowedValue<KV<K, InputT>>> inputDataset = context.getDataset(input);
+      if (globalGroupBy) {
+        // Drop window and group by key globally to run the aggregation (combineFn), afterwards the
+        // global window is restored
+        result =
+            cxt.getDataset(cxt.getInput())
+                .groupByKey(valueKey(), keyEnc)
+                .mapValues(value(), inputEnc)
+                .agg(valueAgg.toColumn())
+                .map(globalKV(), wvOutputEnc);
+      } else {
+        Encoder<Tuple2<BoundedWindow, K>> windowedKeyEnc =
+            cxt.tupleEncoder(cxt.windowEncoder(), keyEnc);
 
-    KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) input.getCoder();
-    Coder<K> keyCoder = inputCoder.getKeyCoder();
-    KvCoder<K, OutputT> outputKVCoder = (KvCoder<K, OutputT>) output.getCoder();
-    Coder<OutputT> outputCoder = outputKVCoder.getValueCoder();
+        // Group by window and key to run the aggregation (combineFn)
+        result =
+            cxt.getDataset(cxt.getInput())
+                .flatMap(explodeWindowedKey(value()), cxt.tupleEncoder(windowedKeyEnc, inputEnc))
+                .groupByKey(fun1(Tuple2::_1), windowedKeyEnc)
+                .mapValues(fun1(Tuple2::_2), inputEnc)
+                .agg(valueAgg.toColumn())
+                .map(windowedKV(), wvOutputEnc);
+      }
+    } else {
+      // Optimized aggregator for non-merging and session window functions, all others depend on
+      // windowFn.mergeWindows
+      Aggregator<WindowedValue<KV<K, InT>>, ?, Collection<WindowedValue<OutT>>> aggregator =
+          Aggregators.windowedValue(
+              combineFn,
+              valueValue(),
+              windowing,
+              cxt.windowEncoder(),
+              accumEnc,
+              cxt.windowedEncoder(outputCoder.getValueCoder()));
+      result =
+          cxt.getDataset(cxt.getInput())
+              .groupByKey(valueKey(), keyEnc)
+              .agg(aggregator.toColumn())
+              .flatMap(explodeWindows(), wvOutputEnc);
+    }
+
+    cxt.putDataset(cxt.getOutput(), result);
+  }
+
+  private static <K, V>
+      Fun1<Tuple2<K, Collection<WindowedValue<V>>>, TraversableOnce<WindowedValue<KV<K, V>>>>
+          explodeWindows() { // one output per windowed result, with the key re-attached
+    return t ->
+        ScalaInterop.scalaIterator(t._2).map(wv -> wv.withValue(KV.of(t._1, wv.getValue())));
+  }
 
-    KeyValueGroupedDataset<K, WindowedValue<KV<K, InputT>>> groupedDataset =
-        inputDataset.groupByKey(KVHelpers.extractKey(), EncoderHelpers.fromBeamCoder(keyCoder));
+  private static <K, V> Fun1<Tuple2<K, V>, WindowedValue<KV<K, V>>> globalKV() {
+    return t -> WindowedValue.valueInGlobalWindow(KV.of(t._1, t._2)); // restore global window
+  }
 
-    Coder<AccumT> accumulatorCoder = null;
+  private Encoder<AccT> accumEncoder(
+      CombineFn<InT, AccT, OutT> fn, Coder<InT> valueCoder, Context cxt) {
     try {
-      accumulatorCoder =
-          combineFn.getAccumulatorCoder(
-              input.getPipeline().getCoderRegistry(), inputCoder.getValueCoder());
+      CoderRegistry registry = cxt.getInput().getPipeline().getCoderRegistry();
+      return cxt.encoderOf(fn.getAccumulatorCoder(registry, valueCoder));
     } catch (CannotProvideCoderException e) {
       throw new RuntimeException(e);
     }
-
-    Dataset<Tuple2<K, Iterable<WindowedValue<OutputT>>>> combinedDataset =
-        groupedDataset.agg(
-            new AggregatorCombiner<K, InputT, AccumT, OutputT, BoundedWindow>(
-                    combineFn, windowingStrategy, accumulatorCoder, outputCoder)
-                .toColumn());
-
-    // expand the list into separate elements and put the key back into the elements
-    WindowedValue.WindowedValueCoder<KV<K, OutputT>> wvCoder =
-        WindowedValue.FullWindowedValueCoder.of(
-            outputKVCoder, input.getWindowingStrategy().getWindowFn().windowCoder());
-    Dataset<WindowedValue<KV<K, OutputT>>> outputDataset =
-        combinedDataset.flatMap(
-            (FlatMapFunction<
-                    Tuple2<K, Iterable<WindowedValue<OutputT>>>, WindowedValue<KV<K, OutputT>>>)
-                tuple2 -> {
-                  K key = tuple2._1();
-                  Iterable<WindowedValue<OutputT>> windowedValues = tuple2._2();
-                  List<WindowedValue<KV<K, OutputT>>> result = new ArrayList<>();
-                  for (WindowedValue<OutputT> windowedValue : windowedValues) {
-                    KV<K, OutputT> kv = KV.of(key, windowedValue.getValue());
-                    result.add(
-                        WindowedValue.of(
-                            kv,
-                            windowedValue.getTimestamp(),
-                            windowedValue.getWindows(),
-                            windowedValue.getPane()));
-                  }
-                  return result.iterator();
-                },
-            EncoderHelpers.fromBeamCoder(wvCoder));
-    context.putDataset(output, outputDataset);
   }
 }
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CreatePCollectionViewTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CreatePCollectionViewTranslatorBatch.java
index ae1eeced328..27151292023 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CreatePCollectionViewTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CreatePCollectionViewTranslatorBatch.java
@@ -19,39 +19,25 @@ package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 
 import java.io.IOException;
 import org.apache.beam.runners.core.construction.CreatePCollectionViewTranslation;
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
 import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
-import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.spark.sql.Dataset;
 
 class CreatePCollectionViewTranslatorBatch<ElemT, ViewT>
-    implements TransformTranslator<PTransform<PCollection<ElemT>, PCollection<ElemT>>> {
+    extends TransformTranslator<
+        PCollection<ElemT>, PCollection<ElemT>, View.CreatePCollectionView<ElemT, ViewT>> {
 
   @Override
-  public void translateTransform(
-      PTransform<PCollection<ElemT>, PCollection<ElemT>> transform,
-      AbstractTranslationContext context) {
+  public void translate(View.CreatePCollectionView<ElemT, ViewT> transform, Context context) {
 
     Dataset<WindowedValue<ElemT>> inputDataSet = context.getDataset(context.getInput());
 
-    @SuppressWarnings("unchecked")
-    AppliedPTransform<
-            PCollection<ElemT>,
-            PCollection<ElemT>,
-            PTransform<PCollection<ElemT>, PCollection<ElemT>>>
-        application =
-            (AppliedPTransform<
-                    PCollection<ElemT>,
-                    PCollection<ElemT>,
-                    PTransform<PCollection<ElemT>, PCollection<ElemT>>>)
-                context.getCurrentTransform();
     PCollectionView<ViewT> input;
     try {
-      input = CreatePCollectionViewTranslation.getView(application);
+      input = CreatePCollectionViewTranslation.getView(context.getCurrentTransform());
     } catch (IOException e) {
       throw new RuntimeException(e);
     }
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DatasetSourceBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DatasetSourceBatch.java
deleted file mode 100644
index 46bde96c30c..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DatasetSourceBatch.java
+++ /dev/null
@@ -1,240 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
-
-import static org.apache.beam.runners.spark.structuredstreaming.Constants.BEAM_SOURCE_OPTION;
-import static org.apache.beam.runners.spark.structuredstreaming.Constants.DEFAULT_PARALLELISM;
-import static org.apache.beam.runners.spark.structuredstreaming.Constants.PIPELINE_OPTIONS;
-import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
-
-import java.io.IOException;
-import java.io.Serializable;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
-import org.apache.beam.runners.core.serialization.Base64Serializer;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.RowHelpers;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SchemaHelpers;
-import org.apache.beam.sdk.io.BoundedSource;
-import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
-import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.connector.catalog.SupportsRead;
-import org.apache.spark.sql.connector.catalog.Table;
-import org.apache.spark.sql.connector.catalog.TableCapability;
-import org.apache.spark.sql.connector.catalog.TableProvider;
-import org.apache.spark.sql.connector.expressions.Transform;
-import org.apache.spark.sql.connector.read.Batch;
-import org.apache.spark.sql.connector.read.InputPartition;
-import org.apache.spark.sql.connector.read.PartitionReader;
-import org.apache.spark.sql.connector.read.PartitionReaderFactory;
-import org.apache.spark.sql.connector.read.Scan;
-import org.apache.spark.sql.connector.read.ScanBuilder;
-import org.apache.spark.sql.types.StructType;
-import org.apache.spark.sql.util.CaseInsensitiveStringMap;
-
-/**
- * Spark DataSourceV2 API was removed in Spark3. This is a Beam source wrapper using the new spark 3
- * source API.
- */
-public class DatasetSourceBatch implements TableProvider {
-
-  private static final StructType BINARY_SCHEMA = SchemaHelpers.binarySchema();
-
-  public DatasetSourceBatch() {}
-
-  @Override
-  public StructType inferSchema(CaseInsensitiveStringMap options) {
-    return BINARY_SCHEMA;
-  }
-
-  @Override
-  public boolean supportsExternalMetadata() {
-    return true;
-  }
-
-  @Override
-  public Table getTable(
-      StructType schema, Transform[] partitioning, Map<String, String> properties) {
-    return new DatasetSourceBatchTable();
-  }
-
-  private static class DatasetSourceBatchTable implements SupportsRead {
-
-    @Override
-    public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
-      return new ScanBuilder() {
-
-        @Override
-        public Scan build() {
-          return new Scan() { // scan for Batch reading
-
-            @Override
-            public StructType readSchema() {
-              return BINARY_SCHEMA;
-            }
-
-            @Override
-            public Batch toBatch() {
-              return new BeamBatch<>(options);
-            }
-          };
-        }
-      };
-    }
-
-    @Override
-    public String name() {
-      return "BeamSource";
-    }
-
-    @Override
-    public StructType schema() {
-      return BINARY_SCHEMA;
-    }
-
-    @Override
-    public Set<TableCapability> capabilities() {
-      final ImmutableSet<TableCapability> capabilities =
-          ImmutableSet.of(TableCapability.BATCH_READ);
-      return capabilities;
-    }
-
-    private static class BeamBatch<T> implements Batch, Serializable {
-
-      private final int numPartitions;
-      private final BoundedSource<T> source;
-      private final SerializablePipelineOptions serializablePipelineOptions;
-
-      private BeamBatch(CaseInsensitiveStringMap options) {
-        if (Strings.isNullOrEmpty(options.get(BEAM_SOURCE_OPTION))) {
-          throw new RuntimeException("Beam source was not set in DataSource options");
-        }
-        this.source =
-            Base64Serializer.deserializeUnchecked(
-                options.get(BEAM_SOURCE_OPTION), BoundedSource.class);
-
-        if (Strings.isNullOrEmpty(DEFAULT_PARALLELISM)) {
-          throw new RuntimeException("Spark default parallelism was not set in DataSource options");
-        }
-        this.numPartitions = Integer.parseInt(options.get(DEFAULT_PARALLELISM));
-        checkArgument(numPartitions > 0, "Number of partitions must be greater than zero.");
-
-        if (Strings.isNullOrEmpty(options.get(PIPELINE_OPTIONS))) {
-          throw new RuntimeException("Beam pipelineOptions were not set in DataSource options");
-        }
-        this.serializablePipelineOptions =
-            new SerializablePipelineOptions(options.get(PIPELINE_OPTIONS));
-      }
-
-      @Override
-      public InputPartition[] planInputPartitions() {
-        PipelineOptions options = serializablePipelineOptions.get();
-        long desiredSizeBytes;
-
-        try {
-          desiredSizeBytes = source.getEstimatedSizeBytes(options) / numPartitions;
-          List<? extends BoundedSource<T>> splits = source.split(desiredSizeBytes, options);
-          InputPartition[] result = new InputPartition[splits.size()];
-          int i = 0;
-          for (BoundedSource<T> split : splits) {
-            result[i++] = new BeamInputPartition<>(split);
-          }
-          return result;
-        } catch (Exception e) {
-          throw new RuntimeException(
-              "Error in splitting BoundedSource " + source.getClass().getCanonicalName(), e);
-        }
-      }
-
-      @Override
-      public PartitionReaderFactory createReaderFactory() {
-        return new PartitionReaderFactory() {
-
-          @Override
-          public PartitionReader<InternalRow> createReader(InputPartition partition) {
-            return new BeamPartitionReader<T>(
-                ((BeamInputPartition<T>) partition).getSource(), serializablePipelineOptions);
-          }
-        };
-      }
-
-      private static class BeamInputPartition<T> implements InputPartition {
-
-        private final BoundedSource<T> source;
-
-        private BeamInputPartition(BoundedSource<T> source) {
-          this.source = source;
-        }
-
-        public BoundedSource<T> getSource() {
-          return source;
-        }
-      }
-
-      private static class BeamPartitionReader<T> implements PartitionReader<InternalRow> {
-
-        private final BoundedSource<T> source;
-        private final BoundedSource.BoundedReader<T> reader;
-        private boolean started;
-        private boolean closed;
-
-        BeamPartitionReader(
-            BoundedSource<T> source, SerializablePipelineOptions serializablePipelineOptions) {
-          this.started = false;
-          this.closed = false;
-          this.source = source;
-          // reader is not serializable so lazy initialize it
-          try {
-            reader =
-                source.createReader(serializablePipelineOptions.get().as(PipelineOptions.class));
-          } catch (IOException e) {
-            throw new RuntimeException("Error creating BoundedReader ", e);
-          }
-        }
-
-        @Override
-        public boolean next() throws IOException {
-          if (!started) {
-            started = true;
-            return reader.start();
-          } else {
-            return !closed && reader.advance();
-          }
-        }
-
-        @Override
-        public InternalRow get() {
-          WindowedValue<T> windowedValue =
-              WindowedValue.timestampedValueInGlobalWindow(
-                  reader.getCurrent(), reader.getCurrentTimestamp());
-          return RowHelpers.storeWindowedValueInRow(windowedValue, source.getOutputCoder());
-        }
-
-        @Override
-        public void close() throws IOException {
-          closed = true;
-          reader.close();
-        }
-      }
-    }
-  }
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java
deleted file mode 100644
index 42a809fdd97..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java
+++ /dev/null
@@ -1,164 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
-
-import java.util.Collections;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import org.apache.beam.runners.core.DoFnRunner;
-import org.apache.beam.runners.core.DoFnRunners;
-import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
-import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsContainerStepMapAccumulator;
-import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.NoOpStepContext;
-import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SparkSideInputReader;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast;
-import org.apache.beam.runners.spark.structuredstreaming.translation.utils.CachedSideInputReader;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
-import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.PCollectionView;
-import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.sdk.values.WindowingStrategy;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Function;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.LinkedListMultimap;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap;
-import org.apache.spark.api.java.function.MapPartitionsFunction;
-import scala.Tuple2;
-
-/**
- * Encapsulates a {@link DoFn} inside a Spark {@link
- * org.apache.spark.api.java.function.MapPartitionsFunction}.
- *
- * <p>We get a mapping from {@link org.apache.beam.sdk.values.TupleTag} to output index and must tag
- * all outputs with the output number. Afterwards a filter will filter out those elements that are
- * not to be in a specific output.
- */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public class DoFnFunction<InputT, OutputT>
-    implements MapPartitionsFunction<WindowedValue<InputT>, Tuple2<TupleTag<?>, WindowedValue<?>>> {
-
-  private final MetricsContainerStepMapAccumulator metricsAccum;
-  private final String stepName;
-  private final DoFn<InputT, OutputT> doFn;
-  private transient boolean wasSetupCalled;
-  private final WindowingStrategy<?, ?> windowingStrategy;
-  private final Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs;
-  private final SerializablePipelineOptions serializableOptions;
-  private final List<TupleTag<?>> additionalOutputTags;
-  private final TupleTag<OutputT> mainOutputTag;
-  private final Coder<InputT> inputCoder;
-  private final Map<TupleTag<?>, Coder<?>> outputCoderMap;
-  private final SideInputBroadcast broadcastStateData;
-  private DoFnSchemaInformation doFnSchemaInformation;
-  private Map<String, PCollectionView<?>> sideInputMapping;
-
-  public DoFnFunction(
-      MetricsContainerStepMapAccumulator metricsAccum,
-      String stepName,
-      DoFn<InputT, OutputT> doFn,
-      WindowingStrategy<?, ?> windowingStrategy,
-      Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs,
-      SerializablePipelineOptions serializableOptions,
-      List<TupleTag<?>> additionalOutputTags,
-      TupleTag<OutputT> mainOutputTag,
-      Coder<InputT> inputCoder,
-      Map<TupleTag<?>, Coder<?>> outputCoderMap,
-      SideInputBroadcast broadcastStateData,
-      DoFnSchemaInformation doFnSchemaInformation,
-      Map<String, PCollectionView<?>> sideInputMapping) {
-    this.metricsAccum = metricsAccum;
-    this.stepName = stepName;
-    this.doFn = doFn;
-    this.windowingStrategy = windowingStrategy;
-    this.sideInputs = sideInputs;
-    this.serializableOptions = serializableOptions;
-    this.additionalOutputTags = additionalOutputTags;
-    this.mainOutputTag = mainOutputTag;
-    this.inputCoder = inputCoder;
-    this.outputCoderMap = outputCoderMap;
-    this.broadcastStateData = broadcastStateData;
-    this.doFnSchemaInformation = doFnSchemaInformation;
-    this.sideInputMapping = sideInputMapping;
-  }
-
-  @Override
-  public Iterator<Tuple2<TupleTag<?>, WindowedValue<?>>> call(Iterator<WindowedValue<InputT>> iter)
-      throws Exception {
-    if (!wasSetupCalled && iter.hasNext()) {
-      DoFnInvokers.tryInvokeSetupFor(doFn, serializableOptions.get());
-      wasSetupCalled = true;
-    }
-
-    DoFnOutputManager outputManager = new DoFnOutputManager();
-
-    DoFnRunner<InputT, OutputT> doFnRunner =
-        DoFnRunners.simpleRunner(
-            serializableOptions.get(),
-            doFn,
-            CachedSideInputReader.of(new SparkSideInputReader(sideInputs, broadcastStateData)),
-            outputManager,
-            mainOutputTag,
-            additionalOutputTags,
-            new NoOpStepContext(),
-            inputCoder,
-            outputCoderMap,
-            windowingStrategy,
-            doFnSchemaInformation,
-            sideInputMapping);
-
-    DoFnRunnerWithMetrics<InputT, OutputT> doFnRunnerWithMetrics =
-        new DoFnRunnerWithMetrics<>(stepName, doFnRunner, metricsAccum);
-
-    return new ProcessContext<>(
-            doFn, doFnRunnerWithMetrics, outputManager, Collections.emptyIterator())
-        .processPartition(iter)
-        .iterator();
-  }
-
-  private class DoFnOutputManager
-      implements ProcessContext.ProcessOutputManager<Tuple2<TupleTag<?>, WindowedValue<?>>> {
-
-    private final Multimap<TupleTag<?>, WindowedValue<?>> outputs = LinkedListMultimap.create();
-
-    @Override
-    public void clear() {
-      outputs.clear();
-    }
-
-    @Override
-    public Iterator<Tuple2<TupleTag<?>, WindowedValue<?>>> iterator() {
-      Iterator<Map.Entry<TupleTag<?>, WindowedValue<?>>> entryIter = outputs.entries().iterator();
-      return Iterators.transform(entryIter, this.entryToTupleFn());
-    }
-
-    private <K, V> Function<Map.Entry<K, V>, Tuple2<K, V>> entryToTupleFn() {
-      return en -> new Tuple2<>(en.getKey(), en.getValue());
-    }
-
-    @Override
-    public synchronized <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
-      outputs.put(tag, output);
-    }
-  }
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnMapPartitionsFactory.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnMapPartitionsFactory.java
new file mode 100644
index 00000000000..c02e07319af
--- /dev/null
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnMapPartitionsFactory.java
@@ -0,0 +1,259 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
+
+import static java.util.function.Function.identity;
+import static java.util.stream.Collectors.toCollection;
+import static java.util.stream.Collectors.toMap;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.scalaIterator;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists.newArrayListWithCapacity;
+
+import java.io.Serializable;
+import java.util.ArrayDeque;
+import java.util.Collection;
+import java.util.Deque;
+import java.util.List;
+import java.util.Map;
+import org.apache.beam.runners.core.DoFnRunner;
+import org.apache.beam.runners.core.DoFnRunners;
+import org.apache.beam.runners.core.DoFnRunners.OutputManager;
+import org.apache.beam.runners.core.SideInputReader;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator;
+import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.NoOpStepContext;
+import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SparkSideInputReader;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.CachedSideInputReader;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun2;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.AbstractIterator;
+import org.apache.spark.api.java.function.MapPartitionsFunction;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import scala.collection.Iterator;
+
+/**
+ * Encapsulates a {@link DoFn} inside a Spark {@link
+ * org.apache.spark.api.java.function.MapPartitionsFunction}.
+ */
+class DoFnMapPartitionsFactory<InT, OutT> implements Serializable {
+  private final String stepName;
+
+  private final DoFn<InT, OutT> doFn;
+  private final DoFnSchemaInformation doFnSchema;
+  private final SerializablePipelineOptions options;
+
+  private final Coder<InT> coder;
+  private final WindowingStrategy<?, ?> windowingStrategy;
+  private final TupleTag<OutT> mainOutput;
+  private final List<TupleTag<?>> additionalOutputs;
+  private final Map<TupleTag<?>, Coder<?>> outputCoders;
+
+  private final Map<String, PCollectionView<?>> sideInputs;
+  private final Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputWindows;
+  private final SideInputBroadcast broadcastStateData;
+
+  DoFnMapPartitionsFactory(
+      String stepName,
+      DoFn<InT, OutT> doFn,
+      DoFnSchemaInformation doFnSchema,
+      SerializablePipelineOptions options,
+      PCollection<InT> input,
+      TupleTag<OutT> mainOutput,
+      Map<TupleTag<?>, PCollection<?>> outputs,
+      Map<String, PCollectionView<?>> sideInputs,
+      SideInputBroadcast broadcastStateData) {
+    this.stepName = stepName;
+    this.doFn = doFn;
+    this.doFnSchema = doFnSchema;
+    this.options = options;
+    this.coder = input.getCoder();
+    this.windowingStrategy = input.getWindowingStrategy();
+    this.mainOutput = mainOutput;
+    this.additionalOutputs = additionalOutputs(outputs, mainOutput);
+    this.outputCoders = outputCoders(outputs);
+    this.sideInputs = sideInputs;
+    this.sideInputWindows = sideInputWindows(sideInputs.values());
+    this.broadcastStateData = broadcastStateData;
+  }
+
+  /**
+   * Creates the function to map a partition (the counterpart of a Spark {@link
+   * MapPartitionsFunction}) using the provided output function.
+   *
+   * <p>Empty partitions map to an empty iterator without invoking any {@link DoFn} lifecycle
+   * methods.
+   *
+   * @param outputFn maps each output (tag and windowed value) of the {@link DoFn} to an
+   *     element of the resulting partition
+   */
+  <OutputT extends @NonNull Object> Fun1<Iterator<WindowedValue<InT>>, Iterator<OutputT>> create(
+      Fun2<TupleTag<?>, WindowedValue<?>, OutputT> outputFn) {
+    return it ->
+        it.hasNext()
+            ? scalaIterator(new DoFnPartitionIt<>(outputFn, it))
+            : (Iterator<OutputT>) Iterator.empty();
+  }
+
+  // FIXME Add support for TimerInternals.TimerData
+  /**
+   * Partition iterator that lazily processes each element from the (input) iterator on demand
+   * producing zero, one or more output elements as output (via an internal buffer).
+   *
+   * <p>When initializing the iterator for a partition {@code setup} followed by {@code startBundle}
+   * is called. Once the partition is exhausted, {@code finishBundle} is called and, after all
+   * buffered output has been consumed, {@code teardown}. If {@code startBundle}, processing an
+   * element or {@code finishBundle} fails, {@code teardown} is invoked before rethrowing.
+   */
+  private class DoFnPartitionIt<FnInT extends InT, OutputT> extends AbstractIterator<OutputT> {
+    private final Deque<OutputT> buffer;
+    private final DoFnRunner<InT, OutT> doFnRunner;
+    private final Iterator<WindowedValue<FnInT>> partitionIt;
+
+    private boolean isBundleFinished;
+
+    DoFnPartitionIt(
+        Fun2<TupleTag<?>, WindowedValue<?>, OutputT> outputFn,
+        Iterator<WindowedValue<FnInT>> partitionIt) {
+      this.buffer = new ArrayDeque<>();
+      this.doFnRunner = metricsRunner(simpleRunner(outputFn, buffer));
+      this.partitionIt = partitionIt;
+      // Before starting to iterate over the partition, invoke setup and then startBundle
+      DoFnInvokers.tryInvokeSetupFor(doFn, options.get());
+      try {
+        doFnRunner.startBundle();
+      } catch (RuntimeException re) {
+        throw teardownAfterFailure(re);
+      }
+    }
+
+    @Override
+    protected OutputT computeNext() {
+      try {
+        // Process input until output is buffered or the bundle has been finished.
+        while (buffer.isEmpty() && !isBundleFinished) {
+          if (partitionIt.hasNext()) {
+            // grab the next element and process it.
+            doFnRunner.processElement((WindowedValue<InT>) partitionIt.next());
+          } else {
+            isBundleFinished = true;
+            doFnRunner.finishBundle(); // finishBundle can produce more output
+          }
+        }
+      } catch (RuntimeException re) {
+        throw teardownAfterFailure(re);
+      }
+      if (!buffer.isEmpty()) {
+        return buffer.remove();
+      }
+      // Partition exhausted and bundle finished. Teardown is invoked outside of the try block
+      // above so that a failing teardown is not invoked a second time.
+      DoFnInvokers.invokerFor(doFn).invokeTeardown();
+      return endOfData();
+    }
+
+    /**
+     * Invokes {@code teardown} after {@code cause} occurred. Any failure of {@code teardown} is
+     * added as suppressed exception to {@code cause} so that the original failure is not masked.
+     *
+     * @return {@code cause} to be rethrown by the caller
+     */
+    private RuntimeException teardownAfterFailure(RuntimeException cause) {
+      try {
+        DoFnInvokers.invokerFor(doFn).invokeTeardown();
+      } catch (RuntimeException e) {
+        cause.addSuppressed(e);
+      }
+      return cause;
+    }
+  }
+
+  /**
+   * Creates a {@link DoFnRunner} that maps every output using {@code outputFn} and appends the
+   * result to {@code buffer}.
+   */
+  private <OutputT> DoFnRunner<InT, OutT> simpleRunner(
+      Fun2<TupleTag<?>, WindowedValue<?>, OutputT> outputFn, Deque<OutputT> buffer) {
+    OutputManager outputManager =
+        new OutputManager() {
+          @Override
+          public <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
+            buffer.add(outputFn.apply(tag, output));
+          }
+        };
+    SideInputReader sideInputReader =
+        CachedSideInputReader.of(new SparkSideInputReader(sideInputWindows, broadcastStateData));
+    return DoFnRunners.simpleRunner(
+        options.get(),
+        doFn,
+        sideInputReader,
+        outputManager,
+        mainOutput,
+        additionalOutputs,
+        new NoOpStepContext(),
+        coder,
+        outputCoders,
+        windowingStrategy,
+        doFnSchema,
+        sideInputs);
+  }
+
+  /** Wraps {@code runner} in a {@link DoFnRunnerWithMetrics} for this step. */
+  private DoFnRunner<InT, OutT> metricsRunner(DoFnRunner<InT, OutT> runner) {
+    return new DoFnRunnerWithMetrics<>(stepName, runner, MetricsAccumulator.getInstance());
+  }
+
+  /** Maps each side input view to the windowing strategy of its underlying PCollection. */
+  private static Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputWindows(
+      Collection<PCollectionView<?>> views) {
+    return views.stream().collect(toMap(identity(), DoFnMapPartitionsFactory::windowingStrategy));
+  }
+
+  /** Returns the windowing strategy of the {@link PCollection} backing {@code view}. */
+  private static WindowingStrategy<?, ?> windowingStrategy(PCollectionView<?> view) {
+    PCollection<?> pc = view.getPCollection();
+    if (pc == null) {
+      throw new IllegalStateException("PCollection not available for " + view);
+    }
+    return pc.getWindowingStrategy();
+  }
+
+  /** Returns the tags of all outputs except {@code mainOutput}. */
+  private static List<TupleTag<?>> additionalOutputs(
+      Map<TupleTag<?>, PCollection<?>> outputs, TupleTag<?> mainOutput) {
+    // capacity is only a hint, but must not be negative if there are no outputs at all
+    int capacity = Math.max(outputs.size() - 1, 0);
+    return outputs.keySet().stream()
+        .filter(t -> !t.equals(mainOutput))
+        .collect(toCollection(() -> newArrayListWithCapacity(capacity)));
+  }
+
+  /** Maps each output tag to the coder of its {@link PCollection}. */
+  private static Map<TupleTag<?>, Coder<?>> outputCoders(Map<TupleTag<?>, PCollection<?>> outputs) {
+    return outputs.entrySet().stream()
+        .collect(toMap(Map.Entry::getKey, e -> e.getValue().getCoder()));
+  }
+}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTranslatorBatch.java
index db361f7753e..eb071304998 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTranslatorBatch.java
@@ -17,49 +17,47 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 
-import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1;
 
 import java.util.Collection;
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
+import java.util.Iterator;
 import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
-import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionList;
-import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
 
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
 class FlattenTranslatorBatch<T>
-    implements TransformTranslator<PTransform<PCollectionList<T>, PCollection<T>>> {
+    extends TransformTranslator<PCollectionList<T>, PCollection<T>, Flatten.PCollections<T>> {
 
   @Override
-  public void translateTransform(
-      PTransform<PCollectionList<T>, PCollection<T>> transform,
-      AbstractTranslationContext context) {
-    Collection<PCollection<?>> pcollectionList = context.getInputs().values();
-    Dataset<WindowedValue<T>> result = null;
-    if (pcollectionList.isEmpty()) {
-      result = context.emptyDataset();
-    } else {
-      for (PValue pValue : pcollectionList) {
-        checkArgument(
-            pValue instanceof PCollection,
-            "Got non-PCollection input to flatten: %s of type %s",
-            pValue,
-            pValue.getClass().getSimpleName());
-        @SuppressWarnings("unchecked")
-        PCollection<T> pCollection = (PCollection<T>) pValue;
-        Dataset<WindowedValue<T>> current = context.getDataset(pCollection);
-        if (result == null) {
-          result = current;
-        } else {
-          result = result.union(current);
-        }
+  public void translate(Flatten.PCollections<T> transform, Context cxt) {
+    Collection<PCollection<?>> pCollections = cxt.getInputs().values();
+    Coder<T> outputCoder = cxt.getOutput().getCoder();
+    Encoder<WindowedValue<T>> outputEnc =
+        cxt.windowedEncoder(outputCoder, windowCoder(cxt.getOutput()));
+
+    Dataset<WindowedValue<T>> result;
+    Iterator<PCollection<T>> pcIt = (Iterator) pCollections.iterator();
+    if (pcIt.hasNext()) {
+      result = getDataset(pcIt.next(), outputCoder, outputEnc, cxt);
+      while (pcIt.hasNext()) {
+        result = result.union(getDataset(pcIt.next(), outputCoder, outputEnc, cxt));
       }
+    } else {
+      result = cxt.createDataset(ImmutableList.of(), outputEnc);
     }
-    context.putDataset(context.getOutput(), result);
+    cxt.putDataset(cxt.getOutput(), result);
+  }
+
+  private Dataset<WindowedValue<T>> getDataset(
+      PCollection<T> pc, Coder<T> coder, Encoder<WindowedValue<T>> enc, Context cxt) {
+    Dataset<WindowedValue<T>> current = cxt.getDataset(pc);
+    // if coders don't match, map using the identity function to re-encode the dataset with enc
+    return pc.getCoder().equals(coder) ? current : current.map(fun1(v -> v), enc);
   }
 }
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java
new file mode 100644
index 00000000000..28ab07114c6
--- /dev/null
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
+
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple;
+import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING;
+import static org.apache.beam.sdk.transforms.windowing.TimestampCombiner.END_OF_WINDOW;
+
+import java.util.Collection;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import scala.Tuple2;
+import scala.collection.TraversableOnce;
+
+/**
+ * Package-private helpers for translating grouping transforms based on {@code groupByKey}, e.g.
+ * {@link GroupByKeyTranslatorBatch} or {@link CombinePerKeyTranslatorBatch}.
+ */
+class GroupByKeyHelpers {
+
+  private GroupByKeyHelpers() {}
+
+  /**
+   * Checks if it's possible to use an optimized {@code groupByKey} that also moves the window
+   * into the key. Requires non-merging windows and a window coder consistent with equals.
+   *
+   * @param windowing The windowing strategy
+   * @param endOfWindowOnly Whether to limit this optimization to {@link
+   *     TimestampCombiner#END_OF_WINDOW}.
+   */
+  static boolean eligibleForGroupByWindow(
+      WindowingStrategy<?, ?> windowing, boolean endOfWindowOnly) {
+    return !windowing.needsMerge()
+        && (!endOfWindowOnly || windowing.getTimestampCombiner() == END_OF_WINDOW)
+        && windowing.getWindowFn().windowCoder().consistentWithEquals();
+  }
+
+  /**
+   * Checks if it's possible to use an optimized {@code groupByKey} for the global window.
+   *
+   * @param windowing The windowing strategy
+   * @param endOfWindowOnly Whether to limit this optimization to {@link
+   *     TimestampCombiner#END_OF_WINDOW}.
+   */
+  static boolean eligibleForGlobalGroupBy(
+      WindowingStrategy<?, ?> windowing, boolean endOfWindowOnly) {
+    return windowing.getWindowFn() instanceof GlobalWindows
+        && (!endOfWindowOnly || windowing.getTimestampCombiner() == END_OF_WINDOW);
+  }
+
+  /**
+   * Explodes a windowed {@link KV} assigned to potentially multiple {@link BoundedWindow}s into
+   * one tuple of composite key {@code (BoundedWindow, Key)} and value per window.
+   */
+  static <K, V, T>
+      Fun1<WindowedValue<KV<K, V>>, TraversableOnce<Tuple2<Tuple2<BoundedWindow, K>, T>>>
+          explodeWindowedKey(Fun1<WindowedValue<KV<K, V>>, T> valueFn) {
+    return v -> {
+      T value = valueFn.apply(v);
+      K key = v.getValue().getKey();
+      Collection<BoundedWindow> windows = (Collection<BoundedWindow>) v.getWindows();
+      return ScalaInterop.scalaIterator(windows).map(w -> tuple(tuple(w, key), value));
+    };
+  }
+
+  static <K, V> Fun1<Tuple2<Tuple2<BoundedWindow, K>, V>, WindowedValue<KV<K, V>>> windowedKV() {
+    return t -> windowedKV(t._1, t._2);
+  }
+
+  static <K, V> WindowedValue<KV<K, V>> windowedKV(Tuple2<BoundedWindow, K> key, V value) {
+    return WindowedValue.of(KV.of(key._2, value), key._1.maxTimestamp(), key._1, NO_FIRING);
+  }
+
+  static <V> Fun1<WindowedValue<V>, V> value() {
+    return v -> v.getValue();
+  }
+
+  static <K, V> Fun1<WindowedValue<KV<K, V>>, V> valueValue() {
+    return v -> v.getValue().getValue();
+  }
+
+  static <K, V> Fun1<WindowedValue<KV<K, V>>, K> valueKey() {
+    return v -> v.getValue().getKey();
+  }
+}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java
index 6391ba4600c..61306cb993c 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java
@@ -17,74 +17,282 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 
+import static org.apache.beam.repackaged.core.org.apache.commons.lang3.ArrayUtils.EMPTY_BYTE_ARRAY;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGlobalGroupBy;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGroupByWindow;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.explodeWindowedKey;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueKey;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueValue;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.windowedKV;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers.toByteArray;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderOf;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.concat;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun2;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.javaIterator;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.listOf;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf;
+import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
+import static org.apache.spark.sql.functions.col;
+import static org.apache.spark.sql.functions.collect_list;
+import static org.apache.spark.sql.functions.explode;
+import static org.apache.spark.sql.functions.max;
+import static org.apache.spark.sql.functions.min;
+import static org.apache.spark.sql.functions.struct;
+
 import java.io.Serializable;
 import org.apache.beam.runners.core.InMemoryStateInternals;
-import org.apache.beam.runners.core.StateInternals;
+import org.apache.beam.runners.core.ReduceFnRunner;
 import org.apache.beam.runners.core.StateInternalsFactory;
 import org.apache.beam.runners.core.SystemReduceFn;
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
 import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
 import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.GroupAlsoByWindowViaOutputBufferFn;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.KVHelpers;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
-import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.spark.sql.Column;
 import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.KeyValueGroupedDataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.TypedColumn;
+import org.apache.spark.sql.catalyst.expressions.CreateArray;
+import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.Literal;
+import org.apache.spark.sql.catalyst.expressions.Literal$;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import scala.Tuple2;
+import scala.collection.Iterator;
+import scala.collection.Seq;
+import scala.collection.immutable.List;
 
+/**
+ * Translator for {@link GroupByKey} using {@link Dataset#groupByKey} with the built-in aggregation
+ * function {@code collect_list} when applicable.
+ *
+ * <p>Note: Using {@code collect_list} isn't any worse than using {@link ReduceFnRunner}. In the
+ * latter case the entire group (iterator) has to be loaded into memory as well. Either way there's
+ * a risk of OOM errors. When disabling {@link #useCollectList}, a more memory sensitive iterable is
+ * used that can be traversed just once. Attempting to traverse the iterable again will throw.
+ *
+ * <ul>
+ *   <li>When using the default global window, window information is dropped and restored after the
+ *       aggregation.
+ *   <li>For non-merging windows, windows are exploded and moved into a composite key for better
+ *       distribution. Though, to keep the amount of shuffled data low, this is only done if values
+ *       are assigned to a single window or if there are only few keys and distributing data is
+ *       important. After the aggregation, windowed values are restored from the composite key.
+ *   <li>All other cases are implemented using the SDK {@link ReduceFnRunner}.
+ * </ul>
+ */
 class GroupByKeyTranslatorBatch<K, V>
-    implements TransformTranslator<
-        PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>>> {
+    extends TransformTranslator<
+        PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, GroupByKey<K, V>> {
+
+  /** Literal of binary encoded Pane info. */
+  private static final Expression PANE_NO_FIRING = lit(toByteArray(NO_FIRING, PaneInfoCoder.of()));
+
+  /** Defaults for value in single global window. */
+  private static final List<Expression> GLOBAL_WINDOW_DETAILS =
+      windowDetails(lit(new byte[][] {EMPTY_BYTE_ARRAY}));
+
+  private boolean useCollectList = true;
+
+  public GroupByKeyTranslatorBatch() {}
+
+  public GroupByKeyTranslatorBatch(boolean useCollectList) {
+    this.useCollectList = useCollectList;
+  }
 
   @Override
-  public void translateTransform(
-      PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>> transform,
-      AbstractTranslationContext context) {
-
-    @SuppressWarnings("unchecked")
-    final PCollection<KV<K, V>> inputPCollection = (PCollection<KV<K, V>>) context.getInput();
-    Dataset<WindowedValue<KV<K, V>>> input = context.getDataset(inputPCollection);
-    WindowingStrategy<?, ?> windowingStrategy = inputPCollection.getWindowingStrategy();
-    KvCoder<K, V> kvCoder = (KvCoder<K, V>) inputPCollection.getCoder();
-    Coder<V> valueCoder = kvCoder.getValueCoder();
-
-    // group by key only
-    Coder<K> keyCoder = kvCoder.getKeyCoder();
-    KeyValueGroupedDataset<K, WindowedValue<KV<K, V>>> groupByKeyOnly =
-        input.groupByKey(KVHelpers.extractKey(), EncoderHelpers.fromBeamCoder(keyCoder));
-
-    // group also by windows
-    WindowedValue.FullWindowedValueCoder<KV<K, Iterable<V>>> outputCoder =
-        WindowedValue.FullWindowedValueCoder.of(
-            KvCoder.of(keyCoder, IterableCoder.of(valueCoder)),
-            windowingStrategy.getWindowFn().windowCoder());
-    Dataset<WindowedValue<KV<K, Iterable<V>>>> output =
-        groupByKeyOnly.flatMapGroups(
-            new GroupAlsoByWindowViaOutputBufferFn<>(
-                windowingStrategy,
-                new InMemoryStateInternalsFactory<>(),
-                SystemReduceFn.buffering(valueCoder),
-                context.getSerializableOptions()),
-            EncoderHelpers.fromBeamCoder(outputCoder));
-
-    context.putDataset(context.getOutput(), output);
+  public void translate(GroupByKey<K, V> transform, Context cxt) {
+    WindowingStrategy<?, ?> windowing = cxt.getInput().getWindowingStrategy();
+    TimestampCombiner tsCombiner = windowing.getTimestampCombiner();
+
+    Dataset<WindowedValue<KV<K, V>>> input = cxt.getDataset(cxt.getInput());
+
+    KvCoder<K, V> inputCoder = (KvCoder<K, V>) cxt.getInput().getCoder();
+    KvCoder<K, Iterable<V>> outputCoder = (KvCoder<K, Iterable<V>>) cxt.getOutput().getCoder();
+
+    Encoder<V> valueEnc = cxt.valueEncoderOf(inputCoder);
+    Encoder<K> keyEnc = cxt.keyEncoderOf(inputCoder);
+
+    // In batch we can ignore triggering and allowed lateness parameters
+    final Dataset<WindowedValue<KV<K, Iterable<V>>>> result;
+
+    if (useCollectList && eligibleForGlobalGroupBy(windowing, false)) {
+      // Collects all values per key in memory. This might be problematic if there are only a few
+      // keys or if the key distribution is highly skewed.
+      result =
+          input
+              .groupBy(col("value.key").as("key"))
+              .agg(collect_list(col("value.value")).as("values"), timestampAggregator(tsCombiner))
+              .select(
+                  inGlobalWindow(
+                      keyValue(col("key").as(keyEnc), col("values").as(iterableEnc(valueEnc))),
+                      windowTimestamp(tsCombiner)));
+
+    } else if (eligibleForGlobalGroupBy(windowing, true)) {
+      // Produces an iterable that can be traversed exactly once. However, on the plus side, data is
+      // not collected in memory until serialized or done by the user.
+      // Values are timestamped at the end of the global window as required by END_OF_WINDOW;
+      // WindowedValue#valueInGlobalWindow would use the minimum timestamp instead.
+      result =
+          input
+              .groupByKey(valueKey(), keyEnc)
+              .mapValues(valueValue(), valueEnc)
+              .mapGroups(
+                  fun2((k, it) -> windowedKV(globalWindowKey(k), iterableOnce(it))),
+                  cxt.windowedEncoder(outputCoder));
+
+    } else if (useCollectList
+        && eligibleForGroupByWindow(windowing, false)
+        && (windowing.getWindowFn().assignsToOneWindow() || transform.fewKeys())) {
+      // Using the window as part of the key should help to better distribute the data. However, if
+      // values are assigned to multiple windows, more data is shuffled around. If there are only a
+      // few keys, this is still valuable.
+      // Collects all values per key & window in memory.
+      result =
+          input
+              .select(explode(col("windows")).as("window"), col("value"), col("timestamp"))
+              .groupBy(col("value.key"), col("window"))
+              .agg(collect_list(col("value.value")).as("values"), timestampAggregator(tsCombiner))
+              .select(
+                  inSingleWindow(
+                      keyValue(col("key").as(keyEnc), col("values").as(iterableEnc(valueEnc))),
+                      col("window").as(cxt.windowEncoder()),
+                      windowTimestamp(tsCombiner)));
+
+    } else if (eligibleForGroupByWindow(windowing, true)
+        && (windowing.getWindowFn().assignsToOneWindow() || transform.fewKeys())) {
+      // Using the window as part of the key should help to better distribute the data. However, if
+      // values are assigned to multiple windows, more data is shuffled around. If there are only a
+      // few keys, this is still valuable.
+      // Produces an iterable that can be traversed exactly once. However, on the plus side, data is
+      // not collected in memory until serialized or done by the user.
+      Encoder<Tuple2<BoundedWindow, K>> windowedKeyEnc =
+          cxt.tupleEncoder(cxt.windowEncoder(), keyEnc);
+      result =
+          input
+              .flatMap(explodeWindowedKey(valueValue()), cxt.tupleEncoder(windowedKeyEnc, valueEnc))
+              .groupByKey(fun1(Tuple2::_1), windowedKeyEnc)
+              .mapValues(fun1(Tuple2::_2), valueEnc)
+              .mapGroups(
+                  fun2((wKey, it) -> windowedKV(wKey, iterableOnce(it))),
+                  cxt.windowedEncoder(outputCoder));
+
+    } else {
+      // Collects all values per key in memory. This might be problematic if there are only a few
+      // keys or if the key distribution is highly skewed.
+
+      // FIXME Revisit this case, implementation is far from ideal:
+      // - iterator traversed at least twice, forcing materialization in memory
+
+      // group by key, then by windows
+      result =
+          input
+              .groupByKey(valueKey(), keyEnc)
+              .flatMapGroups(
+                  new GroupAlsoByWindowViaOutputBufferFn<>(
+                      windowing,
+                      (SerStateInternalsFactory) key -> InMemoryStateInternals.forKey(key),
+                      SystemReduceFn.buffering(inputCoder.getValueCoder()),
+                      cxt.getSerializableOptions()),
+                  cxt.windowedEncoder(outputCoder));
+    }
+
+    cxt.putDataset(cxt.getOutput(), result);
+  }
+
+  /** Serializable In-memory state internals factory. */
+  private interface SerStateInternalsFactory<K> extends StateInternalsFactory<K>, Serializable {}
+
+  private Encoder<Iterable<V>> iterableEnc(Encoder<V> enc) {
+    // safe to use list encoder with collect list
+    return (Encoder) collectionEncoder(enc);
+  }
+
+  private static Column[] timestampAggregator(TimestampCombiner tsCombiner) {
+    if (tsCombiner.equals(TimestampCombiner.END_OF_WINDOW)) {
+      return new Column[0]; // no aggregation needed
+    }
+    Column agg =
+        tsCombiner.equals(TimestampCombiner.EARLIEST)
+            ? min(col("timestamp"))
+            : max(col("timestamp"));
+    return new Column[] {agg.as("timestamp")};
+  }
+
+  private static Expression windowTimestamp(TimestampCombiner tsCombiner) {
+    if (tsCombiner.equals(TimestampCombiner.END_OF_WINDOW)) {
+      // null will be set to END_OF_WINDOW by the respective deserializer
+      return litNull(DataTypes.LongType);
+    }
+    return col("timestamp").expr();
   }
 
   /**
-   * In-memory state internals factory.
-   *
-   * @param <K> State key type.
+   * Java {@link Iterable} from Scala {@link Iterator} that can be iterated just once so that we
+   * don't have to load all data into memory.
    */
-  static class InMemoryStateInternalsFactory<K> implements StateInternalsFactory<K>, Serializable {
-    @Override
-    public StateInternals stateInternalsForKey(K key) {
-      return InMemoryStateInternals.forKey(key);
-    }
+  private static <T extends @NonNull Object> Iterable<T> iterableOnce(Iterator<T> it) {
+    return () -> {
+      checkState(!it.isEmpty(), "Iterator on values can only be consumed once!");
+      return javaIterator(it);
+    };
+  }
+
+  /** Composite key {@code (window, key)} of the {@link GlobalWindow} and the given key. */
+  private static <KeyT> Tuple2<BoundedWindow, KeyT> globalWindowKey(KeyT key) {
+    return new Tuple2<>(GlobalWindow.INSTANCE, key);
+  }
+
+  private <T> TypedColumn<?, KV<K, T>> keyValue(TypedColumn<?, K> key, TypedColumn<?, T> value) {
+    return struct(key.as("key"), value.as("value")).as(kvEncoder(key.encoder(), value.encoder()));
+  }
+
+  private static <InT, T> TypedColumn<InT, WindowedValue<T>> inGlobalWindow(
+      TypedColumn<?, T> value, Expression ts) {
+    List<Expression> fields = concat(timestampedValue(value, ts), GLOBAL_WINDOW_DETAILS);
+    Encoder<WindowedValue<T>> enc =
+        windowedValueEncoder(value.encoder(), encoderOf(GlobalWindow.class));
+    return (TypedColumn<InT, WindowedValue<T>>) new Column(new CreateNamedStruct(fields)).as(enc);
+  }
+
+  public static <InT, T> TypedColumn<InT, WindowedValue<T>> inSingleWindow(
+      TypedColumn<?, T> value, TypedColumn<?, ? extends BoundedWindow> window, Expression ts) {
+    Expression windows = new CreateArray(listOf(window.expr()));
+    Seq<Expression> fields = concat(timestampedValue(value, ts), windowDetails(windows));
+    Encoder<WindowedValue<T>> enc = windowedValueEncoder(value.encoder(), window.encoder());
+    return (TypedColumn<InT, WindowedValue<T>>) new Column(new CreateNamedStruct(fields)).as(enc);
+  }
+
+  private static List<Expression> timestampedValue(Column value, Expression ts) {
+    return seqOf(lit("value"), value.expr(), lit("timestamp"), ts).toList();
+  }
+
+  private static List<Expression> windowDetails(Expression windows) {
+    return seqOf(lit("windows"), windows, lit("pane"), PANE_NO_FIRING).toList();
+  }
+
+  private static <T extends @NonNull Object> Expression lit(T t) {
+    return Literal$.MODULE$.apply(t);
+  }
+
+  @SuppressWarnings("nullness") // NULL literal
+  private static Expression litNull(DataType dataType) {
+    return new Literal(null, dataType);
   }
 }
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ImpulseTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ImpulseTranslatorBatch.java
index 65f496c772b..cf0d2e7ab09 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ImpulseTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ImpulseTranslatorBatch.java
@@ -17,33 +17,27 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 
-import java.util.Collections;
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
+import static org.apache.beam.repackaged.core.org.apache.commons.lang3.ArrayUtils.EMPTY_BYTE_ARRAY;
+
 import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
 import org.apache.beam.sdk.coders.ByteArrayCoder;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.Impulse;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.spark.sql.Dataset;
 
 public class ImpulseTranslatorBatch
-    implements TransformTranslator<PTransform<PBegin, PCollection<byte[]>>> {
+    extends TransformTranslator<PBegin, PCollection<byte[]>, Impulse> {
 
   @Override
-  public void translateTransform(
-      PTransform<PBegin, PCollection<byte[]>> transform, AbstractTranslationContext context) {
-    Coder<WindowedValue<byte[]>> windowedValueCoder =
-        WindowedValue.FullWindowedValueCoder.of(ByteArrayCoder.of(), GlobalWindow.Coder.INSTANCE);
+  public void translate(Impulse impulse, Context context) {
     Dataset<WindowedValue<byte[]>> dataset =
-        context
-            .getSparkSession()
-            .createDataset(
-                Collections.singletonList(WindowedValue.valueInGlobalWindow(new byte[0])),
-                EncoderHelpers.fromBeamCoder(windowedValueCoder));
-    context.putDataset(context.getOutput(), dataset);
+        context.createDataset(
+            ImmutableList.of(WindowedValue.valueInGlobalWindow(EMPTY_BYTE_ARRAY)),
+            context.windowedEncoder(ByteArrayCoder.of(), GlobalWindow.Coder.INSTANCE));
+    context.putDataset(context.getOutput(), dataset);
   }
 }
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
index 52c2d5ae642..131b285be13 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
@@ -17,64 +17,81 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 
+import static java.util.stream.Collectors.toList;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.oneOfEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple;
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
+import static org.apache.spark.sql.functions.col;
+import static org.apache.spark.storage.StorageLevel.MEMORY_ONLY;
 
 import java.io.IOException;
+import java.util.AbstractMap.SimpleImmutableEntry;
 import java.util.ArrayList;
-import java.util.HashMap;
+import java.util.Collection;
 import java.util.List;
 import java.util.Map;
+import java.util.Map.Entry;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.core.DoFnRunners;
 import org.apache.beam.runners.core.construction.ParDoTranslation;
-import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator;
-import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsContainerStepMapAccumulator;
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
 import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
 import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.MultiOutputCoder;
 import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast;
 import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
-import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.PCollectionView;
-import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.sdk.values.WindowingStrategy;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.FilterFunction;
-import org.apache.spark.api.java.function.MapFunction;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Streams;
+import org.apache.spark.rdd.RDD;
 import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.TypedColumn;
+import org.apache.spark.storage.StorageLevel;
+import scala.Function1;
 import scala.Tuple2;
+import scala.collection.Iterator;
+import scala.reflect.ClassTag;
 
 /**
- * TODO: Add support for state and timers.
+ * Translator for {@link ParDo.MultiOutput} based on {@link DoFnRunners#simpleRunner}.
  *
- * @param <InputT>
- * @param <OutputT>
+ * <p>Each tag is encoded as an individual column with its own schema and encoder.
+ *
+ * <p>TODO: Add support for state and timers.
+ *
+ * <p>TODO: Add support for SplittableDoFn.
  */
-@SuppressWarnings({
-  "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
 class ParDoTranslatorBatch<InputT, OutputT>
-    implements TransformTranslator<PTransform<PCollection<InputT>, PCollectionTuple>> {
+    extends TransformTranslator<
+        PCollection<? extends InputT>, PCollectionTuple, ParDo.MultiOutput<InputT, OutputT>> {
+
+  private static final ClassTag<WindowedValue<Object>> WINDOWED_VALUE_CTAG =
+      ClassTag.apply(WindowedValue.class); // ClassTag for RDD#map (MEMORY_ONLY path)
+
+  private static final ClassTag<Tuple2<Integer, WindowedValue<Object>>> TUPLE2_CTAG =
+      ClassTag.apply(Tuple2.class); // ClassTag for RDD#mapPartitions (MEMORY_ONLY path)
 
   @Override
-  public void translateTransform(
-      PTransform<PCollection<InputT>, PCollectionTuple> transform,
-      AbstractTranslationContext context) {
-    String stepName = context.getCurrentTransform().getFullName();
+  public void translate(ParDo.MultiOutput<InputT, OutputT> transform, Context cxt)
+      throws IOException {
+    String stepName = cxt.getCurrentTransform().getFullName();
+
+    SparkCommonPipelineOptions opts = cxt.getOptions().as(SparkCommonPipelineOptions.class);
+    StorageLevel storageLevel = StorageLevel.fromString(opts.getStorageLevel());
 
     // Check for not supported advanced features
     // TODO: add support of Splittable DoFn
-    DoFn<InputT, OutputT> doFn = getDoFn(context);
+    DoFn<InputT, OutputT> doFn = transform.getFn();
     checkState(
         !DoFnSignatures.isSplittable(doFn),
         "Not expected to directly translate splittable DoFn, should have been overridden: %s",
@@ -86,98 +103,124 @@ class ParDoTranslatorBatch<InputT, OutputT>
 
     checkState(
         !DoFnSignatures.requiresTimeSortedInput(doFn),
-        "@RequiresTimeSortedInput is not " + "supported for the moment");
-
-    DoFnSchemaInformation doFnSchemaInformation =
-        ParDoTranslation.getSchemaInformation(context.getCurrentTransform());
-
-    // Init main variables
-    PValue input = context.getInput();
-    Dataset<WindowedValue<InputT>> inputDataSet = context.getDataset(input);
-    Map<TupleTag<?>, PCollection<?>> outputs = context.getOutputs();
-    TupleTag<?> mainOutputTag = getTupleTag(context);
-    List<TupleTag<?>> outputTags = new ArrayList<>(outputs.keySet());
-    WindowingStrategy<?, ?> windowingStrategy =
-        ((PCollection<InputT>) input).getWindowingStrategy();
-    Coder<InputT> inputCoder = ((PCollection<InputT>) input).getCoder();
-    Coder<? extends BoundedWindow> windowCoder = windowingStrategy.getWindowFn().windowCoder();
-
-    // construct a map from side input to WindowingStrategy so that
-    // the DoFn runner can map main-input windows to side input windows
-    List<PCollectionView<?>> sideInputs = getSideInputs(context);
-    Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputStrategies = new HashMap<>();
-    for (PCollectionView<?> sideInput : sideInputs) {
-      sideInputStrategies.put(sideInput, sideInput.getPCollection().getWindowingStrategy());
-    }
-
-    SideInputBroadcast broadcastStateData = createBroadcastSideInputs(sideInputs, context);
+        "@RequiresTimeSortedInput is not supported for the moment");
 
-    Map<TupleTag<?>, Coder<?>> outputCoderMap = context.getOutputCoders();
-    MetricsContainerStepMapAccumulator metricsAccum = MetricsAccumulator.getInstance();
+    TupleTag<OutputT> mainOutputTag = transform.getMainOutputTag();
 
-    List<TupleTag<?>> additionalOutputTags = new ArrayList<>();
-    for (TupleTag<?> tag : outputTags) {
-      if (!tag.equals(mainOutputTag)) {
-        additionalOutputTags.add(tag);
-      }
-    }
+    DoFnSchemaInformation doFnSchema =
+        ParDoTranslation.getSchemaInformation(cxt.getCurrentTransform());
 
-    Map<String, PCollectionView<?>> sideInputMapping =
-        ParDoTranslation.getSideInputMapping(context.getCurrentTransform());
-    @SuppressWarnings("unchecked")
-    DoFnFunction<InputT, OutputT> doFnWrapper =
-        new DoFnFunction(
-            metricsAccum,
+    PCollection<InputT> input = (PCollection<InputT>) cxt.getInput();
+    DoFnMapPartitionsFactory<InputT, OutputT> factory =
+        new DoFnMapPartitionsFactory<>(
             stepName,
             doFn,
-            windowingStrategy,
-            sideInputStrategies,
-            context.getSerializableOptions(),
-            additionalOutputTags,
+            doFnSchema,
+            cxt.getSerializableOptions(),
+            input,
             mainOutputTag,
-            inputCoder,
-            outputCoderMap,
-            broadcastStateData,
-            doFnSchemaInformation,
-            sideInputMapping);
-
-    MultiOutputCoder multipleOutputCoder =
-        MultiOutputCoder.of(SerializableCoder.of(TupleTag.class), outputCoderMap, windowCoder);
-    Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> allOutputs =
-        inputDataSet.mapPartitions(doFnWrapper, EncoderHelpers.fromBeamCoder(multipleOutputCoder));
-    if (outputs.entrySet().size() > 1) {
-      allOutputs.persist();
-      for (Map.Entry<TupleTag<?>, PCollection<?>> output : outputs.entrySet()) {
-        pruneOutputFilteredByTag(context, allOutputs, output, windowCoder);
+            cxt.getOutputs(),
+            transform.getSideInputs(),
+            createBroadcastSideInputs(transform.getSideInputs().values(), cxt));
+
+    Dataset<WindowedValue<InputT>> inputDs = cxt.getDataset(input);
+    if (cxt.getOutputs().size() > 1) {
+      // In case of multiple outputs / tags, map each tag to a column by index.
+      // At the end split the result into multiple datasets selecting one column each.
+      Map<TupleTag<?>, Integer> tags = ImmutableMap.copyOf(zipwithIndex(cxt.getOutputs().keySet()));
+
+      List<Encoder<WindowedValue<Object>>> encoders =
+          createEncoders(cxt.getOutputs(), tags.keySet(), cxt);
+
+      Function1<Iterator<WindowedValue<InputT>>, Iterator<Tuple2<Integer, WindowedValue<Object>>>>
+          doFnMapper = factory.create((tag, v) -> tuple(tags.get(tag), (WindowedValue<Object>) v));
+
+      // FIXME What's the strategy to unpersist Datasets / RDDs?
+
+      // If using storage level MEMORY_ONLY, it's best to persist the dataset as RDD to avoid any
+      // serialization / use of encoders. Persisting a Dataset, even if using a "deserialized"
+      // storage level, involves converting the data to the internal representation (InternalRow)
+      // by use of an encoder.
+      // For any other storage level, persist as Dataset, so we can select columns by TupleTag
+      // individually without restoring the entire row.
+      if (MEMORY_ONLY().equals(storageLevel)) {
+
+        RDD<Tuple2<Integer, WindowedValue<Object>>> allTagsRDD =
+            inputDs.rdd().mapPartitions(doFnMapper, false, TUPLE2_CTAG);
+        allTagsRDD.persist(storageLevel);
+
+        // divide into separate output datasets per tag
+        for (Entry<TupleTag<?>, Integer> e : tags.entrySet()) {
+          TupleTag<Object> key = (TupleTag<Object>) e.getKey();
+          Integer id = e.getValue();
+
+          RDD<WindowedValue<Object>> rddByTag =
+              allTagsRDD
+                  .filter(fun1(t -> t._1.equals(id)))
+                  .map(fun1(Tuple2::_2), WINDOWED_VALUE_CTAG);
+
+          cxt.putDataset(
+              cxt.getOutput(key), cxt.getSparkSession().createDataset(rddByTag, encoders.get(id)));
+        }
+      } else {
+        // Persist as wide rows with one column per TupleTag to support different schemas
+        Dataset<Tuple2<Integer, WindowedValue<Object>>> allTagsDS =
+            inputDs.mapPartitions(doFnMapper, oneOfEncoder(encoders));
+        allTagsDS.persist(storageLevel);
+
+        // divide into separate output datasets per tag
+        for (Entry<TupleTag<?>, Integer> e : tags.entrySet()) {
+          TupleTag<Object> key = (TupleTag<Object>) e.getKey();
+          Integer id = e.getValue();
+
+          // Resolve specific column matching the tuple tag (by id)
+          TypedColumn<Tuple2<Integer, WindowedValue<Object>>, WindowedValue<Object>> col =
+              (TypedColumn) col(id.toString()).as(encoders.get(id));
+
+          cxt.putDataset(cxt.getOutput(key), allTagsDS.filter(col.isNotNull()).select(col));
+        }
       }
     } else {
-      Coder<OutputT> outputCoder = ((PCollection<OutputT>) outputs.get(mainOutputTag)).getCoder();
-      Coder<WindowedValue<?>> windowedValueCoder =
-          (Coder<WindowedValue<?>>) (Coder<?>) WindowedValue.getFullCoder(outputCoder, windowCoder);
-      Dataset<WindowedValue<?>> outputDataset =
-          allOutputs.map(
-              (MapFunction<Tuple2<TupleTag<?>, WindowedValue<?>>, WindowedValue<?>>)
-                  value -> value._2,
-              EncoderHelpers.fromBeamCoder(windowedValueCoder));
-      context.putDatasetWildcard(outputs.entrySet().iterator().next().getValue(), outputDataset);
+      PCollection<OutputT> output = cxt.getOutput(mainOutputTag);
+      Dataset<WindowedValue<OutputT>> mainDS =
+          inputDs.mapPartitions(
+              factory.create((tag, value) -> (WindowedValue<OutputT>) value),
+              cxt.windowedEncoder(output.getCoder()));
+
+      cxt.putDataset(output, mainDS);
+    }
+  }
+
+  private List<Encoder<WindowedValue<Object>>> createEncoders(
+      Map<TupleTag<?>, PCollection<?>> outputs, Iterable<TupleTag<?>> tagsInOrder, Context ctx) {
+    return Streams.stream(tagsInOrder)
+        .map(t -> ctx.windowedEncoder(getCoder(outputs.get(t), t)))
+        .collect(toList());
+  }
+
+  private Coder<Object> getCoder(@Nullable PCollection<?> output, TupleTag<?> outputTag) {
+    if (output == null) {
+      throw new NullPointerException("No PCollection for tag " + outputTag);
     }
+    return (Coder<Object>) output.getCoder();
   }
 
-  private static SideInputBroadcast createBroadcastSideInputs(
-      List<PCollectionView<?>> sideInputs, AbstractTranslationContext context) {
-    JavaSparkContext jsc =
-        JavaSparkContext.fromSparkContext(context.getSparkSession().sparkContext());
+  // FIXME Better ways? Each side input dataset is collected and broadcast as a list of byte[].
+  private SideInputBroadcast createBroadcastSideInputs(
+      Collection<PCollectionView<?>> sideInputs, Context context) {
 
     SideInputBroadcast sideInputBroadcast = new SideInputBroadcast();
     for (PCollectionView<?> sideInput : sideInputs) {
+      PCollection<?> pc = sideInput.getPCollection();
+      if (pc == null) {
+        throw new NullPointerException("PCollection for SideInput is null");
+      }
       Coder<? extends BoundedWindow> windowCoder =
-          sideInput.getPCollection().getWindowingStrategy().getWindowFn().windowCoder();
-
+          pc.getWindowingStrategy().getWindowFn().windowCoder();
       Coder<WindowedValue<?>> windowedValueCoder =
           (Coder<WindowedValue<?>>)
-              (Coder<?>)
-                  WindowedValue.getFullCoder(sideInput.getPCollection().getCoder(), windowCoder);
-      Dataset<WindowedValue<?>> broadcastSet = context.getSideInputDataSet(sideInput);
+              (Coder<?>) WindowedValue.getFullCoder(pc.getCoder(), windowCoder);
+      Dataset<WindowedValue<?>> broadcastSet = context.getSideInputDataset(sideInput);
       List<WindowedValue<?>> valuesList = broadcastSet.collectAsList();
       List<byte[]> codedValues = new ArrayList<>();
       for (WindowedValue<?> v : valuesList) {
@@ -185,73 +228,17 @@ class ParDoTranslatorBatch<InputT, OutputT>
       }
 
       sideInputBroadcast.add(
-          sideInput.getTagInternal().getId(), jsc.broadcast(codedValues), windowedValueCoder);
+          sideInput.getTagInternal().getId(), context.broadcast(codedValues), windowedValueCoder);
     }
     return sideInputBroadcast;
   }
 
-  private List<PCollectionView<?>> getSideInputs(AbstractTranslationContext context) {
-    List<PCollectionView<?>> sideInputs;
-    try {
-      sideInputs = ParDoTranslation.getSideInputs(context.getCurrentTransform());
-    } catch (IOException e) {
-      throw new RuntimeException(e);
-    }
-    return sideInputs;
-  }
-
-  private TupleTag<?> getTupleTag(AbstractTranslationContext context) {
-    TupleTag<?> mainOutputTag;
-    try {
-      mainOutputTag = ParDoTranslation.getMainOutputTag(context.getCurrentTransform());
-    } catch (IOException e) {
-      throw new RuntimeException(e);
-    }
-    return mainOutputTag;
-  }
-
-  @SuppressWarnings("unchecked")
-  private DoFn<InputT, OutputT> getDoFn(AbstractTranslationContext context) {
-    DoFn<InputT, OutputT> doFn;
-    try {
-      doFn = (DoFn<InputT, OutputT>) ParDoTranslation.getDoFn(context.getCurrentTransform());
-    } catch (IOException e) {
-      throw new RuntimeException(e);
-    }
-    return doFn;
-  }
-
-  private void pruneOutputFilteredByTag(
-      AbstractTranslationContext context,
-      Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> allOutputs,
-      Map.Entry<TupleTag<?>, PCollection<?>> output,
-      Coder<? extends BoundedWindow> windowCoder) {
-    Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> filteredDataset =
-        allOutputs.filter(new DoFnFilterFunction(output.getKey()));
-    Coder<WindowedValue<?>> windowedValueCoder =
-        (Coder<WindowedValue<?>>)
-            (Coder<?>)
-                WindowedValue.getFullCoder(
-                    ((PCollection<OutputT>) output.getValue()).getCoder(), windowCoder);
-    Dataset<WindowedValue<?>> outputDataset =
-        filteredDataset.map(
-            (MapFunction<Tuple2<TupleTag<?>, WindowedValue<?>>, WindowedValue<?>>)
-                value -> value._2,
-            EncoderHelpers.fromBeamCoder(windowedValueCoder));
-    context.putDatasetWildcard(output.getValue(), outputDataset);
-  }
-
-  static class DoFnFilterFunction implements FilterFunction<Tuple2<TupleTag<?>, WindowedValue<?>>> {
-
-    private final TupleTag<?> key;
-
-    DoFnFilterFunction(TupleTag<?> key) {
-      this.key = key;
-    }
-
-    @Override
-    public boolean call(Tuple2<TupleTag<?>, WindowedValue<?>> value) {
-      return value._1.equals(key);
+  private static <T> Collection<Entry<T, Integer>> zipwithIndex(Collection<T> col) {
+    ArrayList<Entry<T, Integer>> zipped = new ArrayList<>(col.size());
+    int i = 0;
+    for (T t : col) {
+      zipped.add(new SimpleImmutableEntry<>(t, i++));
     }
+    return zipped;
   }
 }
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/PipelineTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/PipelineTranslatorBatch.java
index 5789db6cd30..62d79632cfa 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/PipelineTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/PipelineTranslatorBatch.java
@@ -25,7 +25,6 @@ import org.apache.beam.runners.spark.structuredstreaming.translation.PipelineTra
 import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
 import org.apache.beam.runners.spark.structuredstreaming.translation.TranslationContext;
 import org.apache.beam.sdk.Pipeline;
-import org.apache.beam.sdk.runners.TransformHierarchy;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.GroupByKey;
@@ -34,6 +33,8 @@ import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.values.PInput;
+import org.apache.beam.sdk.values.POutput;
 import org.checkerframework.checker.nullness.qual.Nullable;
 
 /**
@@ -41,10 +42,6 @@ import org.checkerframework.checker.nullness.qual.Nullable;
  * only the components specific to batch: registry of batch {@link TransformTranslator} and registry
  * lookup code.
  */
-@SuppressWarnings({
-  "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
 public class PipelineTranslatorBatch extends PipelineTranslator {
 
   // --------------------------------------------------------------------------------------------
@@ -65,23 +62,24 @@ public class PipelineTranslatorBatch extends PipelineTranslator {
 
   static {
     TRANSFORM_TRANSLATORS.put(Impulse.class, new ImpulseTranslatorBatch());
-    TRANSFORM_TRANSLATORS.put(Combine.PerKey.class, new CombinePerKeyTranslatorBatch());
-    TRANSFORM_TRANSLATORS.put(GroupByKey.class, new GroupByKeyTranslatorBatch());
+    TRANSFORM_TRANSLATORS.put(Combine.PerKey.class, new CombinePerKeyTranslatorBatch<>());
+    TRANSFORM_TRANSLATORS.put(Combine.Globally.class, new CombineGloballyTranslatorBatch<>());
+    TRANSFORM_TRANSLATORS.put(GroupByKey.class, new GroupByKeyTranslatorBatch<>());
 
     // TODO: Do we need to have a dedicated translator for {@code Reshuffle} if it's deprecated?
     // TRANSFORM_TRANSLATORS.put(Reshuffle.class, new ReshuffleTranslatorBatch());
 
-    TRANSFORM_TRANSLATORS.put(Flatten.PCollections.class, new FlattenTranslatorBatch());
+    TRANSFORM_TRANSLATORS.put(Flatten.PCollections.class, new FlattenTranslatorBatch<>());
 
-    TRANSFORM_TRANSLATORS.put(Window.Assign.class, new WindowAssignTranslatorBatch());
+    TRANSFORM_TRANSLATORS.put(Window.Assign.class, new WindowAssignTranslatorBatch<>());
 
-    TRANSFORM_TRANSLATORS.put(ParDo.MultiOutput.class, new ParDoTranslatorBatch());
+    TRANSFORM_TRANSLATORS.put(ParDo.MultiOutput.class, new ParDoTranslatorBatch<>());
 
     TRANSFORM_TRANSLATORS.put(
-        SplittableParDo.PrimitiveBoundedRead.class, new ReadSourceTranslatorBatch());
+        SplittableParDo.PrimitiveBoundedRead.class, new ReadSourceTranslatorBatch<>());
 
     TRANSFORM_TRANSLATORS.put(
-        View.CreatePCollectionView.class, new CreatePCollectionViewTranslatorBatch());
+        View.CreatePCollectionView.class, new CreatePCollectionViewTranslatorBatch<>());
   }
 
   public PipelineTranslatorBatch(SparkStructuredStreamingPipelineOptions options) {
@@ -90,8 +88,10 @@ public class PipelineTranslatorBatch extends PipelineTranslator {
 
   /** Returns a translator for the given node, if it is possible, otherwise null. */
   @Override
-  protected TransformTranslator<?> getTransformTranslator(TransformHierarchy.Node node) {
-    @Nullable PTransform<?, ?> transform = node.getTransform();
+  @Nullable
+  protected <InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(
+          @Nullable TransformT transform) {
     // Root of the graph is null
     if (transform == null) {
       return null;
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ProcessContext.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ProcessContext.java
deleted file mode 100644
index db64bfd19f3..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ProcessContext.java
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
-
-import java.util.ArrayList;
-import java.util.Iterator;
-import org.apache.beam.runners.core.DoFnRunner;
-import org.apache.beam.runners.core.DoFnRunners.OutputManager;
-import org.apache.beam.runners.core.TimerInternals;
-import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.AbstractIterator;
-
-/** Spark runner process context processes Spark partitions using Beam's {@link DoFnRunner}. */
-class ProcessContext<FnInputT, FnOutputT, OutputT> {
-
-  private final DoFn<FnInputT, FnOutputT> doFn;
-  private final DoFnRunner<FnInputT, FnOutputT> doFnRunner;
-  private final ProcessOutputManager<OutputT> outputManager;
-  private final Iterator<TimerInternals.TimerData> timerDataIterator;
-
-  ProcessContext(
-      DoFn<FnInputT, FnOutputT> doFn,
-      DoFnRunner<FnInputT, FnOutputT> doFnRunner,
-      ProcessOutputManager<OutputT> outputManager,
-      Iterator<TimerInternals.TimerData> timerDataIterator) {
-
-    this.doFn = doFn;
-    this.doFnRunner = doFnRunner;
-    this.outputManager = outputManager;
-    this.timerDataIterator = timerDataIterator;
-  }
-
-  Iterable<OutputT> processPartition(Iterator<WindowedValue<FnInputT>> partition) {
-
-    // skip if partition is empty.
-    if (!partition.hasNext()) {
-      return new ArrayList<>();
-    }
-
-    // process the partition; finishBundle() is called from within the output iterator.
-    return this.getOutputIterable(partition, doFnRunner);
-  }
-
-  private void clearOutput() {
-    outputManager.clear();
-  }
-
-  private Iterator<OutputT> getOutputIterator() {
-    return outputManager.iterator();
-  }
-
-  private Iterable<OutputT> getOutputIterable(
-      final Iterator<WindowedValue<FnInputT>> iter,
-      final DoFnRunner<FnInputT, FnOutputT> doFnRunner) {
-    return () -> new ProcCtxtIterator(iter, doFnRunner);
-  }
-
-  interface ProcessOutputManager<T> extends OutputManager, Iterable<T> {
-    void clear();
-  }
-
-  private class ProcCtxtIterator extends AbstractIterator<OutputT> {
-
-    private final Iterator<WindowedValue<FnInputT>> inputIterator;
-    private final DoFnRunner<FnInputT, FnOutputT> doFnRunner;
-    private Iterator<OutputT> outputIterator;
-    private boolean isBundleStarted;
-    private boolean isBundleFinished;
-
-    ProcCtxtIterator(
-        Iterator<WindowedValue<FnInputT>> iterator, DoFnRunner<FnInputT, FnOutputT> doFnRunner) {
-      this.inputIterator = iterator;
-      this.doFnRunner = doFnRunner;
-      this.outputIterator = getOutputIterator();
-    }
-
-    @Override
-    protected OutputT computeNext() {
-      try {
-        // Process each element from the (input) iterator, which produces, zero, one or more
-        // output elements (of type V) in the output iterator. Note that the output
-        // collection (and iterator) is reset between each call to processElement, so the
-        // collection only holds the output values for each call to processElement, rather
-        // than for the whole partition (which would use too much memory).
-        if (!isBundleStarted) {
-          isBundleStarted = true;
-          // call startBundle() before beginning to process the partition.
-          doFnRunner.startBundle();
-        }
-
-        while (true) {
-          if (outputIterator.hasNext()) {
-            return outputIterator.next();
-          }
-
-          clearOutput();
-          if (inputIterator.hasNext()) {
-            // grab the next element and process it.
-            doFnRunner.processElement(inputIterator.next());
-            outputIterator = getOutputIterator();
-          } else if (timerDataIterator.hasNext()) {
-            outputIterator = getOutputIterator();
-          } else {
-            // no more input to consume, but finishBundle can produce more output
-            if (!isBundleFinished) {
-              isBundleFinished = true;
-              doFnRunner.finishBundle();
-              outputIterator = getOutputIterator();
-              continue; // try to consume outputIterator from start of loop
-            }
-            DoFnInvokers.invokerFor(doFn).invokeTeardown();
-            return endOfData();
-          }
-        }
-      } catch (final RuntimeException re) {
-        DoFnInvokers.invokerFor(doFn).invokeTeardown();
-        throw re;
-      }
-    }
-  }
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReadSourceTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReadSourceTranslatorBatch.java
index ebeb8a96eda..30b599c7e5e 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReadSourceTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReadSourceTranslatorBatch.java
@@ -17,72 +17,38 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 
-import static org.apache.beam.runners.spark.structuredstreaming.Constants.BEAM_SOURCE_OPTION;
-import static org.apache.beam.runners.spark.structuredstreaming.Constants.DEFAULT_PARALLELISM;
-import static org.apache.beam.runners.spark.structuredstreaming.Constants.PIPELINE_OPTIONS;
-
 import java.io.IOException;
-import org.apache.beam.runners.core.construction.ReadTranslation;
-import org.apache.beam.runners.core.serialization.Base64Serializer;
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.core.construction.SplittableParDo;
+import org.apache.beam.runners.spark.structuredstreaming.io.BoundedDatasetFactory;
 import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.RowHelpers;
 import org.apache.beam.sdk.io.BoundedSource;
-import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
+import org.apache.spark.sql.Encoder;
 import org.apache.spark.sql.SparkSession;
 
+/**
+ * Translator for a {@link SplittableParDo.PrimitiveBoundedRead} that creates a Dataset via an RDD
+ * to avoid an additional serialization roundtrip.
+ */
 class ReadSourceTranslatorBatch<T>
-    implements TransformTranslator<PTransform<PBegin, PCollection<T>>> {
+    extends TransformTranslator<PBegin, PCollection<T>, SplittableParDo.PrimitiveBoundedRead<T>> {
 
-  private static final String sourceProviderClass = DatasetSourceBatch.class.getCanonicalName();
-
-  @SuppressWarnings("unchecked")
   @Override
-  public void translateTransform(
-      PTransform<PBegin, PCollection<T>> transform, AbstractTranslationContext context) {
-    AppliedPTransform<PBegin, PCollection<T>, PTransform<PBegin, PCollection<T>>> rootTransform =
-        (AppliedPTransform<PBegin, PCollection<T>, PTransform<PBegin, PCollection<T>>>)
-            context.getCurrentTransform();
-
-    BoundedSource<T> source;
-    try {
-      source = ReadTranslation.boundedSourceFromTransform(rootTransform);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
-    }
-    SparkSession sparkSession = context.getSparkSession();
-
-    String serializedSource = Base64Serializer.serializeUnchecked(source);
-    Dataset<Row> rowDataset =
-        sparkSession
-            .read()
-            .format(sourceProviderClass)
-            .option(BEAM_SOURCE_OPTION, serializedSource)
-            .option(
-                DEFAULT_PARALLELISM,
-                String.valueOf(context.getSparkSession().sparkContext().defaultParallelism()))
-            .option(PIPELINE_OPTIONS, context.getSerializableOptions().toString())
-            .load();
-
-    // extract windowedValue from Row
-    WindowedValue.FullWindowedValueCoder<T> windowedValueCoder =
-        WindowedValue.FullWindowedValueCoder.of(
-            source.getOutputCoder(), GlobalWindow.Coder.INSTANCE);
-
-    Dataset<WindowedValue<T>> dataset =
-        rowDataset.map(
-            RowHelpers.extractWindowedValueFromRowMapFunction(windowedValueCoder),
-            EncoderHelpers.fromBeamCoder(windowedValueCoder));
-
-    PCollection<T> output = (PCollection<T>) context.getOutput();
-    context.putDataset(output, dataset);
+  public void translate(SplittableParDo.PrimitiveBoundedRead<T> transform, Context cxt)
+      throws IOException {
+    SparkSession session = cxt.getSparkSession();
+    BoundedSource<T> source = transform.getSource();
+    SerializablePipelineOptions options = cxt.getSerializableOptions();
+
+    Encoder<WindowedValue<T>> encoder =
+        cxt.windowedEncoder(source.getOutputCoder(), GlobalWindow.Coder.INSTANCE);
+
+    cxt.putDataset(
+        cxt.getOutput(),
+        BoundedDatasetFactory.createDatasetFromRDD(session, source, options, encoder));
   }
 }
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReshuffleTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReshuffleTranslatorBatch.java
deleted file mode 100644
index a88d5454667..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReshuffleTranslatorBatch.java
+++ /dev/null
@@ -1,30 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
-
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
-import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
-import org.apache.beam.sdk.transforms.Reshuffle;
-
-/** TODO: Should be removed if {@link Reshuffle} won't be translated. */
-class ReshuffleTranslatorBatch<K, InputT> implements TransformTranslator<Reshuffle<K, InputT>> {
-
-  @Override
-  public void translateTransform(
-      Reshuffle<K, InputT> transform, AbstractTranslationContext context) {}
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTranslatorBatch.java
index 875a983b401..3b993a3ce19 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTranslatorBatch.java
@@ -17,45 +17,83 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
+import java.util.Collection;
 import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.WindowingHelpers;
-import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.transforms.windowing.WindowFn;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.apache.spark.api.java.function.MapFunction;
 import org.apache.spark.sql.Dataset;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.joda.time.Instant;
 
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
 class WindowAssignTranslatorBatch<T>
-    implements TransformTranslator<PTransform<PCollection<T>, PCollection<T>>> {
+    extends TransformTranslator<PCollection<T>, PCollection<T>, Window.Assign<T>> {
 
   @Override
-  public void translateTransform(
-      PTransform<PCollection<T>, PCollection<T>> transform, AbstractTranslationContext context) {
-
-    Window.Assign<T> assignTransform = (Window.Assign<T>) transform;
-    @SuppressWarnings("unchecked")
-    final PCollection<T> input = (PCollection<T>) context.getInput();
-    @SuppressWarnings("unchecked")
-    final PCollection<T> output = (PCollection<T>) context.getOutput();
-
-    Dataset<WindowedValue<T>> inputDataset = context.getDataset(input);
-    if (WindowingHelpers.skipAssignWindows(assignTransform, context)) {
-      context.putDataset(output, inputDataset);
+  public void translate(Window.Assign<T> transform, Context cxt) {
+    WindowFn<T, ?> windowFn = transform.getWindowFn();
+    PCollection<T> input = cxt.getInput();
+    Dataset<WindowedValue<T>> inputDataset = cxt.getDataset(input);
+
+    if (windowFn == null || skipAssignWindows(windowFn, input)) {
+      cxt.putDataset(cxt.getOutput(), inputDataset);
     } else {
-      WindowFn<T, ?> windowFn = assignTransform.getWindowFn();
-      WindowedValue.FullWindowedValueCoder<T> windowedValueCoder =
-          WindowedValue.FullWindowedValueCoder.of(input.getCoder(), windowFn.windowCoder());
       Dataset<WindowedValue<T>> outputDataset =
           inputDataset.map(
-              WindowingHelpers.assignWindowsMapFunction(windowFn),
-              EncoderHelpers.fromBeamCoder(windowedValueCoder));
-      context.putDataset(output, outputDataset);
+              assignWindows(windowFn),
+              cxt.windowedEncoder(input.getCoder(), windowFn.windowCoder()));
+
+      cxt.putDataset(cxt.getOutput(), outputDataset);
     }
   }
+
+  /**
+   * Returns whether assigning {@code newFn} to the elements of {@code input} can be skipped.
+   *
+   * <p>Assigning windows is redundant if both the input's current {@code WindowFn} and the new
+   * one are {@link GlobalWindows}. A missing (null) WindowFn, meaning the user only changes
+   * triggering or allowed lateness, is handled separately by the caller.
+   */
+  private boolean skipAssignWindows(WindowFn<T, ?> newFn, PCollection<T> input) {
+    WindowFn<?, ?> currentFn = input.getWindowingStrategy().getWindowFn();
+    return currentFn instanceof GlobalWindows && newFn instanceof GlobalWindows;
+  }
+
+  private static <T, W extends @NonNull BoundedWindow>
+      MapFunction<WindowedValue<T>, WindowedValue<T>> assignWindows(WindowFn<T, W> windowFn) {
+    return value -> {
+      final BoundedWindow window = getOnlyWindow(value);
+      final T element = value.getValue();
+      final Instant timestamp = value.getTimestamp();
+      Collection<W> windows =
+          windowFn.assignWindows(
+              windowFn.new AssignContext() {
+
+                @Override
+                public T element() {
+                  return element;
+                }
+
+                @Override
+                public @NonNull Instant timestamp() {
+                  return timestamp;
+                }
+
+                @Override
+                public @NonNull BoundedWindow window() {
+                  return window;
+                }
+              });
+      return WindowedValue.of(element, timestamp, windows, value.getPane());
+    };
+  }
+
+  private static <T> BoundedWindow getOnlyWindow(WindowedValue<T> wv) {
+    return Iterables.getOnlyElement((Iterable<BoundedWindow>) wv.getWindows());
+  }
 }
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java
index fe3f39ef51e..f8c63bc34f1 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java
@@ -17,10 +17,9 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
 
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.util.CoderUtils;
 
 /** Serialization utility class. */
 public final class CoderHelpers {
@@ -35,13 +34,11 @@ public final class CoderHelpers {
    * @return Byte array representing serialized object.
    */
   public static <T> byte[] toByteArray(T value, Coder<T> coder) {
-    ByteArrayOutputStream baos = new ByteArrayOutputStream();
     try {
-      coder.encode(value, baos);
+      return CoderUtils.encodeToByteArray(coder, value);
     } catch (IOException e) {
       throw new IllegalStateException("Error encoding value: " + value, e);
     }
-    return baos.toByteArray();
   }
 
   /**
@@ -53,9 +50,8 @@ public final class CoderHelpers {
    * @return Deserialized object.
    */
   public static <T> T fromByteArray(byte[] serialized, Coder<T> coder) {
-    ByteArrayInputStream bais = new ByteArrayInputStream(serialized);
     try {
-      return coder.decode(bais);
+      return CoderUtils.decodeFromByteArray(coder, serialized);
     } catch (IOException e) {
       throw new IllegalStateException("Error decoding bytes for coder: " + coder, e);
     }
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java
index c7d69c0b8ad..e70cc7253f8 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java
@@ -17,13 +17,17 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
 
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf;
+
 import java.lang.reflect.Constructor;
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
 import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.objects.Invoke;
+import org.apache.spark.sql.catalyst.expressions.objects.NewInstance;
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke;
 import org.apache.spark.sql.types.DataType;
-import scala.collection.immutable.Nil$;
-import scala.collection.mutable.WrappedArray;
+import scala.Option;
 import scala.reflect.ClassTag;
 
 public class EncoderFactory {
@@ -31,6 +35,12 @@ public class EncoderFactory {
   private static final Constructor<StaticInvoke> STATIC_INVOKE_CONSTRUCTOR =
       (Constructor<StaticInvoke>) StaticInvoke.class.getConstructors()[0];
 
+  private static final Constructor<Invoke> INVOKE_CONSTRUCTOR =
+      (Constructor<Invoke>) Invoke.class.getConstructors()[0];
+
+  private static final Constructor<NewInstance> NEW_INSTANCE_CONSTRUCTOR =
+      (Constructor<NewInstance>) NewInstance.class.getConstructors()[0];
+
   static <T> ExpressionEncoder<T> create(
       Expression serializer, Expression deserializer, Class<? super T> clazz) {
     return new ExpressionEncoder<>(serializer, deserializer, ClassTag.apply(clazz));
@@ -39,21 +49,68 @@ public class EncoderFactory {
   /**
    * Invoke method {@code fun} on Class {@code cls}, immediately propagating {@code null} if any
    * input arg is {@code null}.
-   *
-   * <p>To address breaking interfaces between various version of Spark 3 these are created
-   * reflectively. This is fine as it's just needed once to create the query plan.
    */
   static Expression invokeIfNotNull(Class<?> cls, String fun, DataType type, Expression... args) {
+    return invoke(cls, fun, type, true, args);
+  }
+
+  /** Invoke method {@code fun} on Class {@code cls}. */
+  static Expression invoke(Class<?> cls, String fun, DataType type, Expression... args) {
+    return invoke(cls, fun, type, false, args);
+  }
+
+  private static Expression invoke(
+      Class<?> cls, String fun, DataType type, boolean propagateNull, Expression... args) {
+    try {
+      // To address breaking interfaces between various versions of Spark 3, expressions are
+      // created reflectively. This is fine as it's just needed once to create the query plan.
       switch (STATIC_INVOKE_CONSTRUCTOR.getParameterCount()) {
         case 6:
           // Spark 3.1.x
           return STATIC_INVOKE_CONSTRUCTOR.newInstance(
-              cls, type, fun, new WrappedArray.ofRef<>(args), true, true);
+              cls, type, fun, seqOf(args), propagateNull, true);
         case 8:
           // Spark 3.2.x, 3.3.x
           return STATIC_INVOKE_CONSTRUCTOR.newInstance(
-              cls, type, fun, new WrappedArray.ofRef<>(args), Nil$.MODULE$, true, true, true);
+              cls, type, fun, seqOf(args), emptyList(), propagateNull, true, true);
+        default:
+          throw new RuntimeException("Unsupported version of Spark");
+      }
+    } catch (IllegalArgumentException | ReflectiveOperationException ex) {
+      throw new RuntimeException(ex);
+    }
+  }
+
+  /** Invoke method {@code fun} on {@code obj} with provided {@code args}. */
+  static Expression invoke(
+      Expression obj, String fun, DataType type, boolean nullable, Expression... args) {
+    try {
+      // Invoke's constructor differs between Spark 3 versions, so it is created reflectively.
+      // Dispatch on Invoke's own arity; StaticInvoke's signature may evolve independently.
+      switch (INVOKE_CONSTRUCTOR.getParameterCount()) {
+        case 6: // Spark 3.1.x
+          return INVOKE_CONSTRUCTOR.newInstance(obj, fun, type, seqOf(args), false, nullable);
+        case 8: // Spark 3.2.x, 3.3.x
+          return INVOKE_CONSTRUCTOR.newInstance(
+              obj, fun, type, seqOf(args), emptyList(), false, nullable, true);
+        default:
+          throw new RuntimeException("Unsupported version of Spark");
+      }
+    } catch (IllegalArgumentException | ReflectiveOperationException ex) {
+      throw new RuntimeException(ex);
+    }
+  }
+
+  static Expression newInstance(Class<?> cls, DataType type, Expression... args) {
+    try {
+      // To address breaking interfaces between various version of Spark 3,  expressions are
+      // created reflectively. This is fine as it's just needed once to create the query plan.
+      switch (NEW_INSTANCE_CONSTRUCTOR.getParameterCount()) {
+        case 5:
+          return NEW_INSTANCE_CONSTRUCTOR.newInstance(cls, seqOf(args), true, type, Option.empty());
+        case 6:
+          return NEW_INSTANCE_CONSTRUCTOR.newInstance(
+              cls, seqOf(args), emptyList(), true, type, Option.empty());
         default:
           throw new RuntimeException("Unsupported version of Spark");
       }
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
index 68738cf0308..f89f6bdb9d3 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
@@ -17,44 +17,488 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
 
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderFactory.invoke;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderFactory.invokeIfNotNull;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderFactory.newInstance;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.match;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.replace;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple;
 import static org.apache.spark.sql.types.DataTypes.BinaryType;
+import static org.apache.spark.sql.types.DataTypes.IntegerType;
+import static org.apache.spark.sql.types.DataTypes.LongType;
 
+import java.math.BigDecimal;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Function;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
 import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.catalyst.SerializerBuildHelper;
+import org.apache.spark.sql.catalyst.SerializerBuildHelper.MapElementInformation;
 import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
 import org.apache.spark.sql.catalyst.expressions.BoundReference;
+import org.apache.spark.sql.catalyst.expressions.Coalesce;
+import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct;
+import org.apache.spark.sql.catalyst.expressions.EqualTo;
 import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.GetStructField;
+import org.apache.spark.sql.catalyst.expressions.If;
+import org.apache.spark.sql.catalyst.expressions.IsNotNull;
+import org.apache.spark.sql.catalyst.expressions.IsNull;
 import org.apache.spark.sql.catalyst.expressions.Literal;
+import org.apache.spark.sql.catalyst.expressions.Literal$;
+import org.apache.spark.sql.catalyst.expressions.MapKeys;
+import org.apache.spark.sql.catalyst.expressions.MapValues;
+import org.apache.spark.sql.catalyst.expressions.objects.MapObjects$;
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.types.ArrayType;
 import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.MapType;
 import org.apache.spark.sql.types.ObjectType;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.util.MutablePair;
 import org.checkerframework.checker.nullness.qual.NonNull;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.Instant;
+import scala.Option;
+import scala.Some;
+import scala.Tuple2;
+import scala.collection.IndexedSeq;
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
 
+/** {@link Encoders} utility class. */
 public class EncoderHelpers {
   private static final DataType OBJECT_TYPE = new ObjectType(Object.class);
+  private static final DataType TUPLE2_TYPE = new ObjectType(Tuple2.class);
+  private static final DataType WINDOWED_VALUE = new ObjectType(WindowedValue.class);
+  private static final DataType KV_TYPE = new ObjectType(KV.class);
+  private static final DataType MUTABLE_PAIR_TYPE = new ObjectType(MutablePair.class);
+
+  // Boxed primitive types. Collections / maps of these types can be (de)serialized as a whole
+  // without (de)serializing each member individually.
+  private static final Set<Class<?>> PRIMITIV_TYPES =
+      ImmutableSet.of(
+          Boolean.class,
+          Byte.class,
+          Short.class,
+          Integer.class,
+          Long.class,
+          Float.class,
+          Double.class);
+
+  // Default encoders by class, populated lazily via computeIfAbsent. The cache is static and so
+  // shared by all users of this class within the JVM, hence it must be safe for concurrent access
+  // (a plain HashMap is not). Classes without a default encoder (factory returns null) are not
+  // cached. The encoder factory never accesses this map itself, so computeIfAbsent cannot recurse.
+  private static final Map<Class<?>, Encoder<?>> DEFAULT_ENCODERS = new ConcurrentHashMap<>();
+
+  // Factory for default encoders by class, returns null if there is no default for a class
+  private static final Function<Class<?>, @Nullable Encoder<?>> ENCODER_FACTORY =
+      cls -> {
+        // Beam specific types
+        if (cls.equals(PaneInfo.class)) {
+          return paneInfoEncoder();
+        } else if (cls.equals(GlobalWindow.class)) {
+          return binaryEncoder(GlobalWindow.Coder.INSTANCE, false);
+        } else if (cls.equals(IntervalWindow.class)) {
+          return binaryEncoder(IntervalWindowCoder.of(), false);
+        } else if (cls.equals(Instant.class)) {
+          return instantEncoder();
+        // Spark's built-in encoders
+        } else if (cls.equals(String.class)) {
+          return Encoders.STRING();
+        } else if (cls.equals(Boolean.class)) {
+          return Encoders.BOOLEAN();
+        } else if (cls.equals(Integer.class)) {
+          return Encoders.INT();
+        } else if (cls.equals(Long.class)) {
+          return Encoders.LONG();
+        } else if (cls.equals(Float.class)) {
+          return Encoders.FLOAT();
+        } else if (cls.equals(Double.class)) {
+          return Encoders.DOUBLE();
+        } else if (cls.equals(BigDecimal.class)) {
+          return Encoders.DECIMAL();
+        } else if (cls.equals(byte[].class)) {
+          return Encoders.BINARY();
+        } else if (cls.equals(Byte.class)) {
+          return Encoders.BYTE();
+        } else if (cls.equals(Short.class)) {
+          return Encoders.SHORT();
+        }
+        return null;
+      };
+
+  /**
+   * Returns the cached default {@link Encoder} for {@code cls}, creating it on first use.
+   *
+   * @return the default encoder, or {@code null} if no default is available for {@code cls}
+   */
+  @SuppressWarnings("unchecked") // encoders are cached under the class they encode
+  private static <T> @Nullable Encoder<T> getOrCreateDefaultEncoder(Class<? super T> cls) {
+    return (Encoder<T>) DEFAULT_ENCODERS.computeIfAbsent(cls, ENCODER_FACTORY);
+  }
+
+  /**
+   * Gets or creates a default {@link Encoder} for {@link T}.
+   *
+   * @param cls class to look up the default encoder for
+   * @throws IllegalArgumentException if no default encoder is available for {@code cls}
+   */
+  public static <T> Encoder<T> encoderOf(Class<? super T> cls) {
+    Encoder<T> enc = getOrCreateDefaultEncoder(cls);
+    if (enc == null) {
+      throw new IllegalArgumentException("No default encoder available for class " + cls);
+    }
+    return enc;
+  }
+
+  /**
+   * Returns a Spark {@link Encoder} for {@link T} backed by the given Beam {@link Coder}.
+   *
+   * <p>If a default Spark {@link Encoder} exists for the coder's encoded raw type, that one is
+   * returned instead. Otherwise values are encoded as {@link DataTypes#BinaryType BinaryType} using
+   * the {@link Coder}.
+   *
+   * @param coder Beam {@link Coder}
+   */
+  public static <T> Encoder<T> encoderFor(Coder<T> coder) {
+    Class<? super T> rawType = coder.getEncodedTypeDescriptor().getRawType();
+    Encoder<T> defaultEnc = getOrCreateDefaultEncoder(rawType);
+    if (defaultEnc == null) {
+      // no Spark default available, delegate to the Beam coder
+      return binaryEncoder(coder, true);
+    }
+    return defaultEnc;
+  }
+
+  /**
+   * Returns a Spark {@link Encoder} for {@link WindowedValue}s of {@link StructType} with the
+   * fields {@code value}, {@code timestamp}, {@code windows} and {@code pane}.
+   *
+   * @param value {@link Encoder} for field `{@code value}`.
+   * @param window {@link Encoder} for each individual window within field `{@code windows}`
+   */
+  public static <T, W extends BoundedWindow> Encoder<WindowedValue<T>> windowedValueEncoder(
+      Encoder<T> value, Encoder<W> window) {
+    Encoder<Instant> timestampEnc = encoderOf(Instant.class);
+    Encoder<PaneInfo> paneEnc = encoderOf(PaneInfo.class);
+    Encoder<Collection<W>> windowsEnc = collectionEncoder(window);
+    Expression ser =
+        serializeWindowedValue(
+            rootRef(WINDOWED_VALUE, true), value, timestampEnc, windowsEnc, paneEnc);
+    Expression deser =
+        deserializeWindowedValue(
+            rootCol(ser.dataType()), value, timestampEnc, windowsEnc, paneEnc);
+    return EncoderFactory.create(ser, deser, WindowedValue.class);
+  }
+
+  /**
+   * Returns a one-of Spark {@link Encoder} of {@link StructType} that holds one column / field per
+   * alternative, named by the alternative's index and encoded with its own {@link Encoder}.
+   *
+   * <p>Externally a value is a tuple {@code (index, data)}, where the index selects the {@link
+   * Encoder} from the provided list.
+   *
+   * @param encoders {@link Encoder}s for each alternative.
+   */
+  public static <T> Encoder<Tuple2<Integer, T>> oneOfEncoder(List<Encoder<T>> encoders) {
+    Expression ser = serializeOneOf(rootRef(TUPLE2_TYPE, true), encoders);
+    return EncoderFactory.create(
+        ser, deserializeOneOf(rootCol(ser.dataType()), encoders), Tuple2.class);
+  }
+
+  /**
+   * Returns a Spark {@link Encoder} for {@link KV} of {@link StructType} with the fields {@code
+   * key} and {@code value}.
+   *
+   * @param key {@link Encoder} for field `{@code key}`.
+   * @param value {@link Encoder} for field `{@code value}`
+   */
+  public static <K, V> Encoder<KV<K, V>> kvEncoder(Encoder<K> key, Encoder<V> value) {
+    Expression ser = serializeKV(rootRef(KV_TYPE, true), key, value);
+    return EncoderFactory.create(
+        ser, deserializeKV(rootCol(ser.dataType()), key, value), KV.class);
+  }
+
+  /**
+   * Returns a Spark {@link Encoder} of {@link ArrayType} for Java {@link Collection}s whose
+   * elements may be null.
+   *
+   * @param enc {@link Encoder} for the collection elements
+   */
+  public static <T> Encoder<Collection<T>> collectionEncoder(Encoder<T> enc) {
+    boolean nullableElements = true;
+    return collectionEncoder(enc, nullableElements);
+  }
+
+  /**
+   * Returns a Spark {@link Encoder} of {@link ArrayType} for Java {@link Collection}s.
+   *
+   * @param enc {@link Encoder} for the collection elements
+   * @param nullable whether collection elements may be null
+   */
+  public static <T> Encoder<Collection<T>> collectionEncoder(Encoder<T> enc, boolean nullable) {
+    Expression in = rootRef(new ObjectType(Collection.class), true);
+    Expression ser = serializeSeq(in, enc, nullable);
+    Expression deser = deserializeSeq(rootCol(ser.dataType()), enc, nullable, true);
+    return EncoderFactory.create(ser, deser, Collection.class);
+  }
+
+  /**
+   * Returns a Spark {@link Encoder} of {@link MapType} that deserializes to {@link MapT}.
+   *
+   * @param key {@link Encoder} for the map keys
+   * @param value {@link Encoder} for the map values
+   * @param cls Specific class to use, supported are {@link HashMap} and {@link TreeMap}
+   */
+  public static <MapT extends Map<K, V>, K, V> Encoder<MapT> mapEncoder(
+      Encoder<K> key, Encoder<V> value, Class<MapT> cls) {
+    Expression in = rootRef(new ObjectType(cls), true);
+    Expression ser = mapSerializer(in, key, value);
+    Expression deser = mapDeserializer(rootCol(ser.dataType()), key, value, cls);
+    return EncoderFactory.create(ser, deser, cls);
+  }
+
+  /**
+   * Returns a Spark {@link Encoder} for Spark's {@link MutablePair} of {@link StructType} with the
+   * fields `{@code _1}` and `{@code _2}`.
+   *
+   * <p>Intended for places such as aggregators.
+   *
+   * @param enc1 {@link Encoder} for `{@code _1}`
+   * @param enc2 {@link Encoder} for `{@code _2}`
+   */
+  public static <T1, T2> Encoder<MutablePair<T1, T2>> mutablePairEncoder(
+      Encoder<T1> enc1, Encoder<T2> enc2) {
+    Expression ser = serializeMutablePair(rootRef(MUTABLE_PAIR_TYPE, true), enc1, enc2);
+    Expression deser = deserializeMutablePair(rootCol(ser.dataType()), enc1, enc2);
+    return EncoderFactory.create(ser, deser, MutablePair.class);
+  }
+
+  /**
+   * Returns a Spark {@link Encoder} for {@link PaneInfo} of {@link DataTypes#BinaryType
+   * BinaryType}, converting via {@link Utils#paneInfoToBytes} / {@link Utils#paneInfoFromBytes}.
+   */
+  private static Encoder<PaneInfo> paneInfoEncoder() {
+    DataType paneType = new ObjectType(PaneInfo.class);
+    Expression toBytes =
+        invokeIfNotNull(Utils.class, "paneInfoToBytes", BinaryType, rootRef(paneType, false));
+    Expression fromBytes =
+        invokeIfNotNull(Utils.class, "paneInfoFromBytes", paneType, rootCol(BinaryType));
+    return EncoderFactory.create(toBytes, fromBytes, PaneInfo.class);
+  }
 
   /**
-   * Wrap a Beam coder into a Spark Encoder using Catalyst Expression Encoders (which uses java code
-   * generation).
+   * Creates a Spark {@link Encoder} for Joda {@link Instant} of {@link DataTypes#LongType
+   * LongType}.
+   *
+   * <p>Instants are stored as epoch milliseconds; a {@code null} instant is encoded as {@code null}
+   * and decoded back to {@code null} (both directions are wrapped in {@code nullSafe}).
    */
-  public static <T> Encoder<T> fromBeamCoder(Coder<T> coder) {
-    Class<? super T> clazz = coder.getEncodedTypeDescriptor().getRawType();
-    // Class T could be private, therefore use OBJECT_TYPE to not risk an IllegalAccessError
+  private static Encoder<Instant> instantEncoder() {
+    DataType type = new ObjectType(Instant.class);
+    // input reference to the (nullable) Instant being serialized
+    Expression instant = rootRef(type, true);
+    // input column holding the epoch millis being deserialized
+    Expression millis = rootCol(LongType);
     return EncoderFactory.create(
-        beamSerializer(rootRef(OBJECT_TYPE, true), coder),
-        beamDeserializer(rootCol(BinaryType), coder),
-        clazz);
+        nullSafe(instant, invoke(instant, "getMillis", LongType, false)),
+        nullSafe(millis, invoke(Instant.class, "ofEpochMilli", type, millis)),
+        Instant.class);
+  }
+
+  /**
+   * Builds a Spark {@link Encoder} for {@link T} that stores values as {@link DataTypes#BinaryType
+   * BinaryType}, using the given Beam {@link Coder} to convert to and from bytes.
+   *
+   * @param coder Beam {@link Coder} used for (de)serialization
+   * @param nullable whether null items are allowed
+   */
+  private static <T> Encoder<T> binaryEncoder(Coder<T> coder, boolean nullable) {
+    Literal coderLit = lit(coder, Coder.class);
+    // Generated code must not reference T directly (it may be private and cause an
+    // IllegalAccessError), hence values are typed as OBJECT_TYPE.
+    Expression toBytes =
+        invokeIfNotNull(
+            CoderHelpers.class,
+            "toByteArray",
+            BinaryType,
+            rootRef(OBJECT_TYPE, nullable),
+            coderLit);
+    Expression fromBytes =
+        invokeIfNotNull(
+            CoderHelpers.class, "fromByteArray", OBJECT_TYPE, rootCol(BinaryType), coderLit);
+    Class<? super T> rawType = coder.getEncodedTypeDescriptor().getRawType();
+    return EncoderFactory.create(toBytes, fromBytes, rawType);
+  }
+
+  /**
+   * Serializes a {@link WindowedValue} into a struct with the fields {@code value}, {@code
+   * timestamp}, {@code windows} and {@code pane}, each produced by the respective field encoder.
+   */
+  private static <T, W extends BoundedWindow> Expression serializeWindowedValue(
+      Expression in,
+      Encoder<T> valueEnc,
+      Encoder<Instant> timestampEnc,
+      Encoder<Collection<W>> windowsEnc,
+      Encoder<PaneInfo> paneEnc) {
+    Expression value = serializeField(in, valueEnc, "getValue");
+    Expression timestamp = serializeField(in, timestampEnc, "getTimestamp");
+    Expression windows = serializeField(in, windowsEnc, "getWindows");
+    Expression pane = serializeField(in, paneEnc, "getPane");
+    return serializerObject(
+        in,
+        tuple("value", value),
+        tuple("timestamp", timestamp),
+        tuple("windows", windows),
+        tuple("pane", pane));
+  }
+
+  /** Combines the named field serializers into a single struct serializer for {@code in}. */
+  private static Expression serializerObject(Expression in, Tuple2<String, Expression>... fields) {
+    return SerializerBuildHelper.createSerializerForObject(in, seqOf(fields));
+  }
+
+  /**
+   * Deserializes a struct with the fields {@code value} (0), {@code timestamp} (1), {@code windows}
+   * (2) and {@code pane} (3) into a {@link WindowedValue}; a {@code null} pane yields {@code null}.
+   */
+  private static <T, W extends BoundedWindow> Expression deserializeWindowedValue(
+      Expression in,
+      Encoder<T> valueEnc,
+      Encoder<Instant> timestampEnc,
+      Encoder<Collection<W>> windowsEnc,
+      Encoder<PaneInfo> paneEnc) {
+    Expression value = deserializeField(in, valueEnc, 0, "value");
+    Expression windows = deserializeField(in, windowsEnc, 2, "windows");
+    Expression storedTimestamp = deserializeField(in, timestampEnc, 1, "timestamp");
+    Expression pane = deserializeField(in, paneEnc, 3, "pane");
+    // Without a stored timestamp, fall back to the end of the (only) window (maxTimestamp).
+    Expression windowEnd =
+        invoke(Utils.class, "maxTimestamp", storedTimestamp.dataType(), windows);
+    Expression timestamp = ifNotNull(storedTimestamp, windowEnd);
+    Expression windowedValue =
+        invoke(WindowedValue.class, "of", WINDOWED_VALUE, value, timestamp, windows, pane);
+    return nullSafe(pane, windowedValue);
+  }
+
+  /** Serializes a {@link MutablePair} into a struct with the fields {@code _1} and {@code _2}. */
+  private static <K, V> Expression serializeMutablePair(
+      Expression in, Encoder<K> enc1, Encoder<V> enc2) {
+    Expression first = serializeField(in, enc1, "_1");
+    Expression second = serializeField(in, enc2, "_2");
+    return serializerObject(in, tuple("_1", first), tuple("_2", second));
   }
 
-  /** Catalyst Expression that serializes elements using Beam {@link Coder}. */
-  private static <T> Expression beamSerializer(Expression obj, Coder<T> coder) {
-    Expression[] args = {obj, lit(coder, Coder.class)};
-    return EncoderFactory.invokeIfNotNull(CoderHelpers.class, "toByteArray", BinaryType, args);
+  /**
+   * Deserializes the struct fields {@code _1} (ordinal 0) and {@code _2} (ordinal 1) of {@code in}
+   * and combines them into a {@link MutablePair} via {@code MutablePair.apply}.
+   */
+  private static <K, V> Expression deserializeMutablePair(
+      Expression in, Encoder<K> enc1, Encoder<V> enc2) {
+    Expression field1 = deserializeField(in, enc1, 0, "_1");
+    Expression field2 = deserializeField(in, enc2, 1, "_2");
+    return invoke(MutablePair.class, "apply", MUTABLE_PAIR_TYPE, field1, field2);
   }
 
-  /** Catalyst Expression that deserializes elements using Beam {@link Coder}. */
-  private static <T> Expression beamDeserializer(Expression bytes, Coder<T> coder) {
-    Expression[] args = {bytes, lit(coder, Coder.class)};
-    return EncoderFactory.invokeIfNotNull(CoderHelpers.class, "fromByteArray", OBJECT_TYPE, args);
+  /**
+   * Serializes a {@link KV} into a struct with the fields {@code key} and {@code value}, applying
+   * {@code keyEnc} / {@code valueEnc} to the results of {@code getKey()} / {@code getValue()}.
+   */
+  private static <K, V> Expression serializeKV(
+      Expression in, Encoder<K> keyEnc, Encoder<V> valueEnc) {
+    return serializerObject(
+        in,
+        tuple("key", serializeField(in, keyEnc, "getKey")),
+        tuple("value", serializeField(in, valueEnc, "getValue")));
+  }
+
+  /**
+   * Deserializes the struct fields {@code key} (0) and {@code value} (1) of {@code in} into a
+   * {@link KV}.
+   */
+  private static <K, V> Expression deserializeKV(
+      Expression in, Encoder<K> keyEnc, Encoder<V> valueEnc) {
+    return invoke(
+        KV.class,
+        "of",
+        KV_TYPE,
+        deserializeField(in, keyEnc, 0, "key"),
+        deserializeField(in, valueEnc, 1, "value"));
+  }
+
+  /**
+   * Serializes a {@code Tuple2<Integer, T>} into a struct with one field per encoder, named by the
+   * encoder's index. Only the field whose index equals {@code _1} is populated, all others are
+   * {@code null}.
+   */
+  public static <T> Expression serializeOneOf(Expression in, List<Encoder<T>> encoders) {
+    Expression typeIdx = invoke(in, "_1", IntegerType, false);
+    int numEncoders = encoders.size();
+    // field names at even positions, field values at the following odd positions
+    Expression[] namesAndValues = new Expression[2 * numEncoders];
+    for (int idx = 0; idx < numEncoders; idx++) {
+      namesAndValues[2 * idx] = lit(String.valueOf(idx));
+      namesAndValues[2 * idx + 1] = serializeOneOfField(in, typeIdx, encoders.get(idx), idx);
+    }
+    return new CreateNamedStruct(seqOf(namesAndValues));
+  }
+
+  /**
+   * Deserializes a struct created by {@link #serializeOneOf} back into a {@code Tuple2<Integer,
+   * T>} by coalescing the per-alternative results.
+   */
+  public static <T> Expression deserializeOneOf(Expression in, List<Encoder<T>> encoders) {
+    Expression[] alternatives = new Expression[encoders.size()];
+    int idx = 0;
+    for (Encoder<T> enc : encoders) {
+      alternatives[idx] = deserializeOneOfField(in, enc, idx);
+      idx++;
+    }
+    return new Coalesce(seqOf(alternatives));
+  }
+
+  /** Serializes {@code _2} of {@code in} with {@code enc} if {@code type} equals {@code typeIdx}. */
+  private static <T> Expression serializeOneOfField(
+      Expression in, Expression type, Encoder<T> enc, int typeIdx) {
+    Expression isSelected = new EqualTo(type, lit(typeIdx));
+    Expression data = invoke(in, "_2", deserializedType(enc), false);
+    return new If(isSelected, serialize(data, enc), lit(null, serializedType(enc)));
+  }
+
+  /** Deserializes struct field {@code idx} into {@code (idx, value)}, or null if it is null. */
+  private static <T> Expression deserializeOneOfField(Expression in, Encoder<T> enc, int idx) {
+    GetStructField field = new GetStructField(in, idx, Option.empty());
+    Expression pair = newInstance(Tuple2.class, TUPLE2_TYPE, lit(idx), deserialize(field, enc));
+    return new If(new IsNull(field), lit(null, TUPLE2_TYPE), pair);
+  }
+
+  /** Serializes the result of calling {@code getterName} on {@code in} using {@code enc}. */
+  private static <T> Expression serializeField(Expression in, Encoder<T> enc, String getterName) {
+    // reuse type and nullability of the input reference of the field's encoder
+    Expression inputRef = serializer(enc).collect(match(BoundReference.class)).head();
+    Expression getter = invoke(in, getterName, inputRef.dataType(), inputRef.nullable());
+    return serialize(getter, enc);
+  }
+
+  /** Deserializes the struct field at ordinal {@code idx} (named {@code name}) using {@code enc}. */
+  private static <T> Expression deserializeField(
+      Expression in, Encoder<T> enc, int idx, String name) {
+    Expression field = new GetStructField(in, idx, new Some<>(name));
+    return deserialize(field, enc);
+  }
+
+  /**
+   * Serializes a Java {@link Map} using the {@code key} / {@code value} encoders for its entries.
+   *
+   * <p>Note: Currently this doesn't support nullable primitive values.
+   */
+  private static <K, V> Expression mapSerializer(Expression map, Encoder<K> key, Encoder<V> value) {
+    DataType keyType = deserializedType(key);
+    DataType valueType = deserializedType(value);
+    return SerializerBuildHelper.createSerializerForMap(
+        map,
+        new MapElementInformation(keyType, false, e -> serialize(e, key)),
+        new MapElementInformation(valueType, false, e -> serialize(e, value)));
+  }
+
+  /**
+   * Deserializes a Spark map into a Java map assignable to {@code cls}.
+   *
+   * <p>A {@link HashMap} is created if {@code cls} is a supertype of {@link HashMap} (e.g. {@link
+   * Map}), otherwise a {@link TreeMap} is created if {@code cls} is a supertype of {@link TreeMap}
+   * (e.g. {@link TreeMap} itself or a sorted / navigable map interface).
+   *
+   * @throws IllegalArgumentException if neither {@link HashMap} nor {@link TreeMap} can be assigned
+   *     to {@code cls}
+   */
+  private static <MapT extends Map<K, V>, K, V> Expression mapDeserializer(
+      Expression in, Encoder<K> key, Encoder<V> value, Class<MapT> cls) {
+    boolean useHashMap = cls.isAssignableFrom(HashMap.class);
+    Preconditions.checkArgument(
+        useHashMap || cls.isAssignableFrom(TreeMap.class),
+        "Unsupported map class %s, expected a supertype of HashMap or TreeMap",
+        cls.getName());
+    Expression keys = deserializeSeq(new MapKeys(in), key, false, false);
+    Expression values = deserializeSeq(new MapValues(in), value, false, false);
+    String fn = useHashMap ? "toMap" : "toTreeMap";
+    return invoke(
+        Utils.class, fn, new ObjectType(cls), keys, values, mapItemType(key), mapItemType(value));
+  }
+
+  /**
+   * Literal of the element {@link DataType} passed to {@link Utils} to convert map keys / values:
+   * the serialized type for primitive types (avoids boxing), otherwise the deserialized type.
+   */
+  private static Literal mapItemType(Encoder<?> enc) {
+    return lit(isPrimitiveEnc(enc) ? serializedType(enc) : deserializedType(enc), DataType.class);
+  }
+
+  /**
+   * Serializes a Java {@link Collection} into a Spark array. Collections of boxed primitives are
+   * converted as a whole generic array, otherwise each element is serialized using {@code enc}.
+   */
+  private static <T> Expression serializeSeq(Expression in, Encoder<T> enc, boolean nullable) {
+    if (!isPrimitiveEnc(enc)) {
+      Expression seq = invoke(Utils.class, "toSeq", new ObjectType(Seq.class), in);
+      return MapObjects$.MODULE$.apply(
+          exp -> serialize(exp, enc), seq, deserializedType(enc), nullable, Option.empty());
+    }
+    Expression array = invoke(in, "toArray", new ObjectType(Object[].class), false);
+    return SerializerBuildHelper.createSerializerForGenericArray(
+        array, serializedType(enc), nullable);
+  }
+
+  /**
+   * Deserializes a Spark array. If {@code asJava} is set the result is a Java {@link List},
+   * otherwise primitive arrays are passed through unchanged and other element types are mapped
+   * using {@code enc}.
+   */
+  private static <T> Expression deserializeSeq(
+      Expression in, Encoder<T> enc, boolean nullable, boolean asJava) {
+    // the input is in serialized form, so elements have the serializer's result type
+    DataType elementType = serializedType(enc);
+    if (!isPrimitiveEnc(enc)) {
+      Option<Class<?>> customCls = asJava ? Option.apply(List.class) : Option.empty();
+      return MapObjects$.MODULE$.apply(
+          exp -> deserialize(exp, enc), in, elementType, nullable, customCls);
+    }
+    if (!asJava) {
+      return in;
+    }
+    ObjectType listType = new ObjectType(List.class);
+    return invoke(Utils.class, "toList", listType, in, lit(elementType, DataType.class));
+  }
+
+  /** Whether {@code enc} encodes one of the boxed primitive types in {@code PRIMITIV_TYPES}. */
+  private static <T> boolean isPrimitiveEnc(Encoder<T> enc) {
+    Class<?> runtimeCls = enc.clsTag().runtimeClass();
+    return PRIMITIV_TYPES.contains(runtimeCls);
+  }
+
+  /** Rebinds the serializer of {@code enc} so that it reads its input from {@code input}. */
+  private static <T> Expression serialize(Expression input, Encoder<T> enc) {
+    Expression template = serializer(enc);
+    return template.transformUp(replace(BoundReference.class, input));
+  }
+
+  /** Rebinds the deserializer of {@code enc} so that it reads its input from {@code input}. */
+  private static <T> Expression deserialize(Expression input, Encoder<T> enc) {
+    Expression template = deserializer(enc);
+    return template.transformUp(replace(GetColumnByOrdinal.class, input));
+  }
+
+  /** The object serializer of {@code enc}, which must be an {@link ExpressionEncoder}. */
+  private static <T> Expression serializer(Encoder<T> enc) {
+    ExpressionEncoder<T> exprEnc = (ExpressionEncoder<T>) enc;
+    return exprEnc.objSerializer();
+  }
+
+  /** The object deserializer of {@code enc}, which must be an {@link ExpressionEncoder}. */
+  private static <T> Expression deserializer(Encoder<T> enc) {
+    ExpressionEncoder<T> exprEnc = (ExpressionEncoder<T>) enc;
+    return exprEnc.objDeserializer();
+  }
+
+  /** The data type produced by the serializer of {@code enc}. */
+  private static <T> DataType serializedType(Encoder<T> enc) {
+    return serializer(enc).dataType();
+  }
+
+  /** The data type produced by the deserializer of {@code enc}. */
+  private static <T> DataType deserializedType(Encoder<T> enc) {
+    return deserializer(enc).dataType();
   }
 
   private static Expression rootRef(DataType dt, boolean nullable) {
@@ -65,7 +509,77 @@ public class EncoderHelpers {
     return new GetColumnByOrdinal(0, dt);
   }
 
+  /** Returns a null literal of {@code out}'s type if {@code in} is null, otherwise {@code out}. */
+  private static Expression nullSafe(Expression in, Expression out) {
+    Expression nullOut = lit(null, out.dataType());
+    return new If(new IsNull(in), nullOut, out);
+  }
+
+  /** Returns {@code expr} if it is not null, otherwise {@code otherwise}. */
+  private static Expression ifNotNull(Expression expr, Expression otherwise) {
+    Expression notNull = new IsNotNull(expr);
+    return new If(notNull, expr, otherwise);
+  }
+
+  /** Creates a literal for the non-null value {@code t} using Spark's {@code Literal.apply}. */
+  private static <T extends @NonNull Object> Expression lit(T t) {
+    Literal literal = Literal$.MODULE$.apply(t);
+    return literal;
+  }
+
+  /** Creates a literal of {@code dataType} holding {@code t}, which may be {@code null}. */
+  @SuppressWarnings("nullness") // literal NULL is allowed
+  private static <T> Expression lit(@Nullable T t, DataType dataType) {
+    Literal typed = new Literal(t, dataType);
+    return typed;
+  }
+
+  /** Creates a {@link Literal} holding {@code obj}, typed as {@link ObjectType} of {@code cls}. */
   private static <T extends @NonNull Object> Literal lit(T obj, Class<? extends T> cls) {
     return Literal.fromObject(obj, new ObjectType(cls));
   }
+
+  /** Encoder / expression utils that are called from generated code. */
+  public static class Utils {
+
+    /** Decodes a {@link PaneInfo} from bytes using {@link PaneInfoCoder}. */
+    public static PaneInfo paneInfoFromBytes(byte[] bytes) {
+      PaneInfo pane = CoderHelpers.fromByteArray(bytes, PaneInfoCoder.of());
+      return pane;
+    }
+
+    /** Encodes a {@link PaneInfo} to bytes using {@link PaneInfoCoder}. */
+    public static byte[] paneInfoToBytes(PaneInfo pane) {
+      byte[] bytes = CoderHelpers.toByteArray(pane, PaneInfoCoder.of());
+      return bytes;
+    }
+
+    /** The end (max timestamp) of the single window; fails unless there is exactly one window. */
+    public static Instant maxTimestamp(Iterable<BoundedWindow> windows) {
+      BoundedWindow window = Iterables.getOnlyElement(windows);
+      return window.maxTimestamp();
+    }
+
+    /** Converts {@code arrayData} with elements of {@code type} into a Java {@link List}. */
+    public static List<Object> toList(ArrayData arrayData, DataType type) {
+      IndexedSeq<Object> elements = arrayData.toSeq(type);
+      return JavaConverters.seqAsJavaList(elements);
+    }
+
+    /** Converts {@code arrayData} into a Scala {@link Seq} of objects. */
+    public static Seq<Object> toSeq(ArrayData arrayData) {
+      Seq<Object> seq = arrayData.toSeq(OBJECT_TYPE);
+      return seq;
+    }
+
+    /** Converts a Java {@link Collection} into a Scala {@link Seq}, wrapping lists directly. */
+    public static Seq<Object> toSeq(Collection<Object> col) {
+      if (!(col instanceof List)) {
+        return JavaConverters.collectionAsScalaIterable(col).toSeq();
+      }
+      return JavaConverters.asScalaBuffer((List<Object>) col);
+    }
+
+    /** Zips {@code keys} and {@code values} into a new {@link TreeMap}. */
+    public static TreeMap<Object, Object> toTreeMap(
+        ArrayData keys, ArrayData values, DataType keyType, DataType valueType) {
+      TreeMap<Object, Object> map = new TreeMap<>();
+      return toMap(map, keys, values, keyType, valueType);
+    }
+
+    /** Zips {@code keys} and {@code values} into a new, presized {@link HashMap}. */
+    public static HashMap<Object, Object> toMap(
+        ArrayData keys, ArrayData values, DataType keyType, DataType valueType) {
+      int expectedSize = keys.numElements();
+      HashMap<Object, Object> map = Maps.newHashMapWithExpectedSize(expectedSize);
+      return toMap(map, keys, values, keyType, valueType);
+    }
+
+    /** Puts the i-th key of {@code keys} with the i-th value of {@code values} into {@code map}. */
+    private static <MapT extends Map<Object, Object>> MapT toMap(
+        MapT map, ArrayData keys, ArrayData values, DataType keyType, DataType valueType) {
+      IndexedSeq<Object> keySeq = keys.toSeq(keyType);
+      IndexedSeq<Object> valueSeq = values.toSeq(valueType);
+      int size = keySeq.size();
+      for (int i = 0; i < size; i++) {
+        map.put(keySeq.apply(i), valueSeq.apply(i));
+      }
+      return map;
+    }
+  }
 }
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/MultiOutputCoder.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/MultiOutputCoder.java
deleted file mode 100644
index f77fcea6796..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/MultiOutputCoder.java
+++ /dev/null
@@ -1,84 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
-
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStream;
-import java.util.Map;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.CoderException;
-import org.apache.beam.sdk.coders.CustomCoder;
-import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.TupleTag;
-import scala.Tuple2;
-
-/**
- * Coder to serialize and deserialize {@code Tuple2<TupleTag<T>, WindowedValue<T>>} to be used in
- * Spark encoders while applying {@link org.apache.beam.sdk.transforms.DoFn}.
- *
- * <p>The {@link TupleTag} is encoded first using {@code tupleTagCoder}; it selects the value coder
- * from {@code coderMap}, which is combined with {@code windowCoder} into a full windowed value
- * coder for the {@link WindowedValue}.
- *
- * @param <T> type of the elements in the collection
- */
-@SuppressWarnings({
-  "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public class MultiOutputCoder<T> extends CustomCoder<Tuple2<TupleTag<T>, WindowedValue<T>>> {
-  Coder<TupleTag> tupleTagCoder;
-  Map<TupleTag<?>, Coder<?>> coderMap;
-  Coder<? extends BoundedWindow> windowCoder;
-
-  public static MultiOutputCoder of(
-      Coder<TupleTag> tupleTagCoder,
-      Map<TupleTag<?>, Coder<?>> coderMap,
-      Coder<? extends BoundedWindow> windowCoder) {
-    return new MultiOutputCoder(tupleTagCoder, coderMap, windowCoder);
-  }
-
-  private MultiOutputCoder(
-      Coder<TupleTag> tupleTagCoder,
-      Map<TupleTag<?>, Coder<?>> coderMap,
-      Coder<? extends BoundedWindow> windowCoder) {
-    this.tupleTagCoder = tupleTagCoder;
-    this.coderMap = coderMap;
-    this.windowCoder = windowCoder;
-  }
-
-  @Override
-  public void encode(Tuple2<TupleTag<T>, WindowedValue<T>> tuple2, OutputStream outStream)
-      throws IOException {
-    TupleTag<T> tupleTag = tuple2._1();
-    tupleTagCoder.encode(tupleTag, outStream);
-    Coder<T> valueCoder = (Coder<T>) coderMap.get(tupleTag);
-    WindowedValue.FullWindowedValueCoder<T> wvCoder =
-        WindowedValue.FullWindowedValueCoder.of(valueCoder, windowCoder);
-    wvCoder.encode(tuple2._2(), outStream);
-  }
-
-  @Override
-  public Tuple2<TupleTag<T>, WindowedValue<T>> decode(InputStream inStream)
-      throws CoderException, IOException {
-    TupleTag<T> tupleTag = (TupleTag<T>) tupleTagCoder.decode(inStream);
-    Coder<T> valueCoder = (Coder<T>) coderMap.get(tupleTag);
-    WindowedValue.FullWindowedValueCoder<T> wvCoder =
-        WindowedValue.FullWindowedValueCoder.of(valueCoder, windowCoder);
-    WindowedValue<T> wv = wvCoder.decode(inStream);
-    return Tuple2.apply(tupleTag, wv);
-  }
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/RowHelpers.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/RowHelpers.java
deleted file mode 100644
index 9b5d5da2b2c..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/RowHelpers.java
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
-
-import static scala.collection.JavaConversions.asScalaBuffer;
-
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.spark.api.java.function.MapFunction;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.catalyst.InternalRow;
-
-/**
- * Helper functions for working with {@link Row}.
- *
- * <p>NOTE(review): this utility class declares no private constructor, so it can be instantiated.
- */
-public final class RowHelpers {
-
-  /**
-   * A Spark {@link MapFunction} for extracting a {@link WindowedValue} from a Row in which the
-   * {@link WindowedValue} was serialized to bytes using its {@link
-   * WindowedValue.WindowedValueCoder}.
-   *
-   * @param <T> The type of the object.
-   * @param windowedValueCoder coder used to decode the bytes stored in column 0 of each row
-   * @return A {@link MapFunction} that accepts a {@link Row} and returns its {@link WindowedValue}.
-   */
-  public static <T> MapFunction<Row, WindowedValue<T>> extractWindowedValueFromRowMapFunction(
-      WindowedValue.WindowedValueCoder<T> windowedValueCoder) {
-    return (MapFunction<Row, WindowedValue<T>>)
-        value -> {
-          // there is only one value put in each Row by the InputPartitionReader
-          byte[] bytes = (byte[]) value.get(0);
-          return windowedValueCoder.decode(new ByteArrayInputStream(bytes));
-        };
-  }
-
-  /**
-   * Serializes a windowedValue to bytes using a {@link WindowedValue.FullWindowedValueCoder} built
-   * from {@code coder} and the global-window coder, and stores the bytes in an InternalRow.
-   *
-   * <p>NOTE(review): because the window coder is always {@code GlobalWindow.Coder.INSTANCE}, this
-   * presumably only works for values in the global window — confirm with callers.
-   *
-   * @param windowedValue the value to serialize
-   * @param coder coder for the value inside {@code windowedValue}
-   * @return a single-column row whose only field is the encoded byte array
-   * @throws RuntimeException wrapping any {@link IOException} raised while encoding
-   */
-  public static <T> InternalRow storeWindowedValueInRow(
-      WindowedValue<T> windowedValue, Coder<T> coder) {
-    List<Object> list = new ArrayList<>();
-    // serialize the windowedValue to bytes array to comply with dataset binary schema
-    WindowedValue.FullWindowedValueCoder<T> windowedValueCoder =
-        WindowedValue.FullWindowedValueCoder.of(coder, GlobalWindow.Coder.INSTANCE);
-    ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
-    try {
-      windowedValueCoder.encode(windowedValue, byteArrayOutputStream);
-      byte[] bytes = byteArrayOutputStream.toByteArray();
-      list.add(bytes);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
-    }
-    // the list holds exactly one element (the encoded bytes), i.e. a one-column row
-    return InternalRow.apply(asScalaBuffer(list).toList());
-  }
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/SchemaHelpers.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/SchemaHelpers.java
deleted file mode 100644
index 71dca5264dd..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/SchemaHelpers.java
+++ /dev/null
@@ -1,39 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
-
-import org.apache.spark.sql.types.DataTypes;
-import org.apache.spark.sql.types.Metadata;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
-
-/** Spark SQL schema helpers for the Spark Batch Runner. */
-public class SchemaHelpers {
-  /** Schema with a single nullable binary column named {@code binaryStructField}. */
-  private static final StructType BINARY_SCHEMA =
-      new StructType(
-          new StructField[] {
-            StructField.apply("binaryStructField", DataTypes.BinaryType, true, Metadata.empty())
-          });
-
-  /** Returns the shared single-column binary schema. */
-  public static StructType binarySchema() {
-    // We use a binary schema for now because:
-    // - using an empty schema raises an IndexOutOfBoundsException
-    // - using a NullType schema stores null in the elements
-    return BINARY_SCHEMA;
-  }
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/WindowingHelpers.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/WindowingHelpers.java
deleted file mode 100644
index 5085eb9f796..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/WindowingHelpers.java
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
-
-import java.util.Collection;
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
-import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
-import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
-import org.apache.beam.sdk.transforms.windowing.Window;
-import org.apache.beam.sdk.transforms.windowing.WindowFn;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-import org.apache.spark.api.java.function.MapFunction;
-import org.joda.time.Instant;
-
-/** Helper functions for working with windows. */
-public final class WindowingHelpers {
-
-  /**
-   * Checks if the window transformation should be applied or skipped.
-   *
-   * <p>Avoid running assign windows if both source and destination are global window or if the user
-   * has not specified the WindowFn (meaning they are just messing with triggering or allowed
-   * lateness).
-   *
-   * @param transform the window assignment being translated
-   * @param context translation context; its current input must be a {@link PCollection}
-   * @return {@code true} if the transform has no WindowFn, or both the input's and the requested
-   *     WindowFn are {@link GlobalWindows}
-   */
-  @SuppressWarnings("unchecked")
-  public static <T, W extends BoundedWindow> boolean skipAssignWindows(
-      Window.Assign<T> transform, AbstractTranslationContext context) {
-    WindowFn<? super T, W> windowFnToApply = (WindowFn<? super T, W>) transform.getWindowFn();
-    PCollection<T> input = (PCollection<T>) context.getInput();
-    WindowFn<?, ?> windowFnOfInput = input.getWindowingStrategy().getWindowFn();
-    return windowFnToApply == null
-        || (windowFnOfInput instanceof GlobalWindows && windowFnToApply instanceof GlobalWindows);
-  }
-
-  /**
-   * Returns a Spark {@link MapFunction} that re-assigns the windows of each value using {@code
-   * windowFn}, keeping the element, timestamp and pane unchanged.
-   *
-   * <p>Each input value must be in exactly one window: {@code Iterables.getOnlyElement} fails
-   * otherwise.
-   *
-   * @param windowFn the window function to apply
-   * @return a map function producing the re-windowed value
-   */
-  public static <T, W extends BoundedWindow>
-      MapFunction<WindowedValue<T>, WindowedValue<T>> assignWindowsMapFunction(
-          WindowFn<T, W> windowFn) {
-    return (MapFunction<WindowedValue<T>, WindowedValue<T>>)
-        windowedValue -> {
-          final BoundedWindow boundedWindow = Iterables.getOnlyElement(windowedValue.getWindows());
-          final T element = windowedValue.getValue();
-          final Instant timestamp = windowedValue.getTimestamp();
-          // expose the current element/timestamp/window to the WindowFn
-          Collection<W> windows =
-              windowFn.assignWindows(
-                  windowFn.new AssignContext() {
-
-                    @Override
-                    public T element() {
-                      return element;
-                    }
-
-                    @Override
-                    public Instant timestamp() {
-                      return timestamp;
-                    }
-
-                    @Override
-                    public BoundedWindow window() {
-                      return boundedWindow;
-                    }
-                  });
-          return WindowedValue.of(element, timestamp, windows, windowedValue.getPane());
-        };
-  }
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/DatasetSourceStreaming.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/DatasetSourceStreaming.java
deleted file mode 100644
index 5eb60f68cb3..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/DatasetSourceStreaming.java
+++ /dev/null
@@ -1,25 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.streaming;
-
-/**
- * Empty class whose canonical name {@code ReadSourceTranslatorStreaming} passes as the format of
- * its streaming read.
- *
- * <p>The Spark Structured Streaming framework does not support more than one aggregation in
- * streaming mode because of its watermark implementation. As a consequence, this runner does not
- * support streaming mode yet; see https://github.com/apache/beam/issues/20241.
- *
- * <p>NOTE(review): the class implements no Spark data source interface, so loading it as a format
- * presumably fails at runtime.
- */
-class DatasetSourceStreaming {}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/PipelineTranslatorStreaming.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/PipelineTranslatorStreaming.java
deleted file mode 100644
index 73d99efa463..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/PipelineTranslatorStreaming.java
+++ /dev/null
@@ -1,93 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.streaming;
-
-import java.util.HashMap;
-import java.util.Map;
-import org.apache.beam.runners.core.construction.SplittableParDo;
-import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
-import org.apache.beam.runners.spark.structuredstreaming.translation.PipelineTranslator;
-import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
-import org.apache.beam.runners.spark.structuredstreaming.translation.TranslationContext;
-import org.apache.beam.sdk.Pipeline;
-import org.apache.beam.sdk.runners.TransformHierarchy;
-import org.apache.beam.sdk.transforms.PTransform;
-import org.checkerframework.checker.nullness.qual.Nullable;
-
-/**
- * {@link PipelineTranslator} for executing a {@link Pipeline} in Spark in streaming mode. This
- * contains only the components specific to streaming: registry of streaming {@link
- * TransformTranslator} and registry lookup code.
- */
-@SuppressWarnings({
-  "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public class PipelineTranslatorStreaming extends PipelineTranslator {
-  // --------------------------------------------------------------------------------------------
-  //  Transform Translator Registry
-  // --------------------------------------------------------------------------------------------
-
-  /** Translators keyed by the exact {@link PTransform} class they handle. */
-  @SuppressWarnings("rawtypes")
-  private static final Map<Class<? extends PTransform>, TransformTranslator> TRANSFORM_TRANSLATORS =
-      new HashMap<>();
-
-  // TODO: add the ability to have more than one TransformTranslator per URN,
-  // dynamically chosen by a predicate evaluated on the PCollection
-  // obtainable through node.getInputs.getValue()
-  // See
-  // https://github.com/seznam/euphoria/blob/master/euphoria-spark/src/main/java/cz/seznam/euphoria/spark/SparkFlowTranslator.java#L83
-  // And
-  // https://github.com/seznam/euphoria/blob/master/euphoria-spark/src/main/java/cz/seznam/euphoria/spark/SparkFlowTranslator.java#L106
-
-  // Only the unbounded primitive read is registered; the other translators are commented out.
-  static {
-    //    TRANSFORM_TRANSLATORS.put(Combine.PerKey.class, new CombinePerKeyTranslatorBatch());
-    //    TRANSFORM_TRANSLATORS.put(Combine.Globally.class, new CombineGloballyTranslatorBatch());
-    //    TRANSFORM_TRANSLATORS.put(GroupByKey.class, new GroupByKeyTranslatorBatch());
-
-    // TODO: Do we need to have a dedicated translator for {@code Reshuffle} if it's deprecated?
-    // TRANSFORM_TRANSLATORS.put(Reshuffle.class, new ReshuffleTranslatorBatch());
-
-    //    TRANSFORM_TRANSLATORS.put(Flatten.PCollections.class, new FlattenTranslatorBatch());
-    //
-    //    TRANSFORM_TRANSLATORS.put(Window.Assign.class, new WindowAssignTranslatorBatch());
-    //
-    //    TRANSFORM_TRANSLATORS.put(ParDo.MultiOutput.class, new ParDoTranslatorBatch());
-
-    TRANSFORM_TRANSLATORS.put(
-        SplittableParDo.PrimitiveUnboundedRead.class, new ReadSourceTranslatorStreaming());
-
-    //    TRANSFORM_TRANSLATORS
-    //        .put(View.CreatePCollectionView.class, new CreatePCollectionViewTranslatorBatch());
-  }
-
-  /** Creates a streaming translator whose translation context is built from {@code options}. */
-  public PipelineTranslatorStreaming(SparkStructuredStreamingPipelineOptions options) {
-    translationContext = new TranslationContext(options);
-  }
-
-  /**
-   * Returns a translator for the given node, if it is possible, otherwise null.
-   *
-   * @return the translator registered for the transform's exact class, or {@code null} for the
-   *     root of the graph or an unregistered transform
-   */
-  @Override
-  protected TransformTranslator<?> getTransformTranslator(TransformHierarchy.Node node) {
-    @Nullable PTransform<?, ?> transform = node.getTransform();
-    // Root of the graph is null
-    if (transform == null) {
-      return null;
-    }
-    return TRANSFORM_TRANSLATORS.get(transform.getClass());
-  }
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/ReadSourceTranslatorStreaming.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/ReadSourceTranslatorStreaming.java
deleted file mode 100644
index 8abc8771a4e..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/ReadSourceTranslatorStreaming.java
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.streaming;
-
-import static org.apache.beam.runners.spark.structuredstreaming.Constants.BEAM_SOURCE_OPTION;
-import static org.apache.beam.runners.spark.structuredstreaming.Constants.DEFAULT_PARALLELISM;
-import static org.apache.beam.runners.spark.structuredstreaming.Constants.PIPELINE_OPTIONS;
-
-import java.io.IOException;
-import org.apache.beam.runners.core.construction.ReadTranslation;
-import org.apache.beam.runners.core.serialization.Base64Serializer;
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
-import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.RowHelpers;
-import org.apache.beam.sdk.io.UnboundedSource;
-import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.PBegin;
-import org.apache.beam.sdk.values.PCollection;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-
-/**
- * Translates an unbounded read into a Spark streaming {@link Dataset} of {@link WindowedValue}s.
- *
- * <p>The {@link UnboundedSource} is serialized into a reader option; each resulting {@link Row}
- * is decoded back into a windowed value in the global window.
- */
-class ReadSourceTranslatorStreaming<T>
-    implements TransformTranslator<PTransform<PBegin, PCollection<T>>> {
-
-  /** Format name handed to {@code readStream()}. */
-  private static final String sourceProviderClass = DatasetSourceStreaming.class.getCanonicalName();
-
-  @SuppressWarnings("unchecked")
-  @Override
-  public void translateTransform(
-      PTransform<PBegin, PCollection<T>> transform, AbstractTranslationContext context) {
-    AppliedPTransform<PBegin, PCollection<T>, PTransform<PBegin, PCollection<T>>> rootTransform =
-        (AppliedPTransform<PBegin, PCollection<T>, PTransform<PBegin, PCollection<T>>>)
-            context.getCurrentTransform();
-
-    UnboundedSource<T, UnboundedSource.CheckpointMark> source;
-    try {
-      source = ReadTranslation.unboundedSourceFromTransform(rootTransform);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
-    }
-    SparkSession sparkSession = context.getSparkSession();
-
-    String serializedSource = Base64Serializer.serializeUnchecked(source);
-    // NOTE(review): options are passed via toString(); confirm the reader can parse that form.
-    Dataset<Row> rowDataset =
-        sparkSession
-            .readStream()
-            .format(sourceProviderClass)
-            .option(BEAM_SOURCE_OPTION, serializedSource)
-            .option(
-                DEFAULT_PARALLELISM,
-                String.valueOf(context.getSparkSession().sparkContext().defaultParallelism()))
-            .option(PIPELINE_OPTIONS, context.getSerializableOptions().toString())
-            .load();
-
-    // extract windowedValue from Row
-    WindowedValue.FullWindowedValueCoder<T> windowedValueCoder =
-        WindowedValue.FullWindowedValueCoder.of(
-            source.getOutputCoder(), GlobalWindow.Coder.INSTANCE);
-    Dataset<WindowedValue<T>> dataset =
-        rowDataset.map(
-            RowHelpers.extractWindowedValueFromRowMapFunction(windowedValueCoder),
-            EncoderHelpers.fromBeamCoder(windowedValueCoder));
-
-    PCollection<T> output = (PCollection<T>) context.getOutput();
-    context.putDataset(output, dataset);
-  }
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java
new file mode 100644
index 00000000000..1908cbf2bba
--- /dev/null
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.utils;
+
+import java.io.Serializable;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import scala.Function1;
+import scala.Function2;
+import scala.PartialFunction;
+import scala.Tuple2;
+import scala.collection.Iterator;
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
+import scala.collection.immutable.List;
+import scala.collection.immutable.Nil$;
+import scala.collection.mutable.WrappedArray;
+
+/** Utilities for easier interoperability with the Spark Scala API. */
+public class ScalaInterop {
+  private ScalaInterop() {}
+
+  /**
+   * Scala {@link Seq} of the given elements.
+   *
+   * <p>The sequence wraps the varargs array directly (no copy is made), so later changes to that
+   * array are visible through the sequence.
+   */
+  public static <T> Seq<T> seqOf(T... t) {
+    return new WrappedArray.ofRef<>(t);
+  }
+
+  /**
+   * Concatenates two immutable Scala lists.
+   *
+   * @return a list with the elements of {@code a} followed by the elements of {@code b}
+   */
+  public static <T> List<T> concat(List<T> a, List<T> b) {
+    // Java spelling of Scala's `a ::: b` (a right-associative operator, invoked on b).
+    return b.$colon$colon$colon(a);
+  }
+
+  /** Immutable single-element Scala list containing {@code t}. */
+  public static <T> Seq<T> listOf(T t) {
+    return emptyList().$colon$colon(t);
+  }
+
+  /** The empty immutable Scala list ({@code Nil}). */
+  @SuppressWarnings("unchecked") // Nil is a List[Nothing] and therefore a valid List[T] for any T
+  public static <T> List<T> emptyList() {
+    return (List<T>) Nil$.MODULE$;
+  }
+
+  /** Scala {@link Iterator} of Java {@link Iterable}. */
+  public static <T extends @NonNull Object> Iterator<T> scalaIterator(Iterable<T> iterable) {
+    return scalaIterator(iterable.iterator());
+  }
+
+  /** Scala {@link Iterator} of Java {@link java.util.Iterator}. */
+  public static <T extends @NonNull Object> Iterator<T> scalaIterator(java.util.Iterator<T> it) {
+    return JavaConverters.asScalaIterator(it);
+  }
+
+  /** Java {@link java.util.Iterator} of Scala {@link Iterator}. */
+  public static <T extends @NonNull Object> java.util.Iterator<T> javaIterator(Iterator<T> it) {
+    return JavaConverters.asJavaIterator(it);
+  }
+
+  /** Scala {@link Tuple2} holding {@code (t1, t2)}. */
+  public static <T1, T2> Tuple2<T1, T2> tuple(T1 t1, T2 t2) {
+    return new Tuple2<>(t1, t2);
+  }
+
+  /**
+   * Partial function that maps every instance of {@code clazz} to the constant {@code replace}.
+   *
+   * <p>Values that are not instances of {@code clazz} (including {@code null}) lie outside the
+   * function's domain.
+   */
+  public static <T extends @NonNull Object, V> PartialFunction<T, T> replace(
+      Class<V> clazz, T replace) {
+    return new PartialFunction<T, T>() {
+
+      @Override
+      public boolean isDefinedAt(T x) {
+        return clazz.isInstance(x);
+      }
+
+      @Override
+      public T apply(T x) {
+        return replace;
+      }
+    };
+  }
+
+  /**
+   * Partial function that selects instances of {@code clazz} and returns them typed as {@code V}.
+   *
+   * <p>Values that are not instances of {@code clazz} (including {@code null}) lie outside the
+   * function's domain.
+   */
+  public static <T extends @NonNull Object, V> PartialFunction<T, V> match(Class<V> clazz) {
+    return new PartialFunction<T, V>() {
+
+      @Override
+      public boolean isDefinedAt(T x) {
+        return clazz.isInstance(x);
+      }
+
+      @Override
+      public V apply(T x) {
+        // Checked cast: applying outside the domain fails fast, never leaks a mistyped value.
+        return clazz.cast(x);
+      }
+    };
+  }
+
+  /** Identity function that types its argument (typically a lambda) as a {@link Fun1}. */
+  public static <T, V> Fun1<T, V> fun1(Fun1<T, V> fun) {
+    return fun;
+  }
+
+  /** Identity function that types its argument (typically a lambda) as a {@link Fun2}. */
+  public static <T1, T2, V> Fun2<T1, T2, V> fun2(Fun2<T1, T2, V> fun) {
+    return fun;
+  }
+
+  /** Scala {@link Function1} that is also {@link Serializable}. */
+  public interface Fun1<T, V> extends Function1<T, V>, Serializable {}
+
+  /** Scala {@link Function2} that is also {@link Serializable}. */
+  public interface Fun2<T1, T2, V> extends Function2<T1, T2, V>, Serializable {}
+}
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetrics.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetrics.java
index f994f7712b3..7f2eaa10e80 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetrics.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetrics.java
@@ -49,7 +49,7 @@ public class InMemoryMetrics implements Sink {
     internalMetricRegistry = metricRegistry;
   }
 
-  @SuppressWarnings({"TypeParameterUnusedInFormals", "rawtypes"})
+  @SuppressWarnings({"TypeParameterUnusedInFormals", "rawtypes"}) // because of getGauges
   public static <T> T valueOf(final String name) {
     // this might fail in case we have multiple aggregators with the same suffix after
     // the last dot, but it should be good enough for tests.
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorsTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorsTest.java
new file mode 100644
index 00000000000..b5b07db0e38
--- /dev/null
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorsTest.java
@@ -0,0 +1,370 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
+
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderFor;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.equalTo;
+import static org.joda.time.Duration.standardMinutes;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Map;
+import java.util.TreeMap;
+import java.util.stream.Collectors;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.Sessions;
+import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.transforms.windowing.WindowFn;
+import org.apache.beam.sdk.transforms.windowing.WindowMappingFn;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Streams;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.expressions.Aggregator;
+import org.apache.spark.util.MutablePair;
+import org.hamcrest.Matcher;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Test;
+import org.junit.experimental.runners.Enclosed;
+import org.junit.runner.RunWith;
+
+@RunWith(Enclosed.class)
+public class AggregatorsTest {
+
+  // Fixed, easily readable reference instant; the tests express all timestamps as minute
+  // offsets from it (see at(int)).
+  private static final Instant NOW = Instant.parse("2000-01-01T00:00Z");
+
+  /** Tests for NonMergingWindowedAggregator in {@link Aggregators}. */
+  public static class NonMergingWindowedAggregatorTest {
+
+    // 15 minute windows starting every 5 minutes, so each timestamp falls into three windows
+    private final SlidingWindows sliding =
+        SlidingWindows.of(standardMinutes(15)).every(standardMinutes(5));
+
+    private final Aggregator<
+            WindowedValue<Integer>,
+            Map<IntervalWindow, MutablePair<Instant, Integer>>,
+            Collection<WindowedValue<Integer>>>
+        agg = windowedAgg(sliding);
+
+    @Test
+    public void testReduce() {
+      Map<IntervalWindow, MutablePair<Instant, Integer>> acc;
+
+      // a value at minute 10 is added to each of the three sliding windows containing it
+      acc = agg.reduce(agg.zero(), windowedValue(1, at(10)));
+      assertThat(
+          acc,
+          equalsToMap(
+              KV.of(intervalWindow(0, 15), pair(at(10), 1)),
+              KV.of(intervalWindow(5, 20), pair(at(10), 1)),
+              KV.of(intervalWindow(10, 25), pair(at(10), 1))));
+
+      // a value at minute 16 updates the two windows it shares and opens [15, 30)
+      acc = agg.reduce(acc, windowedValue(2, at(16)));
+      assertThat(
+          acc,
+          equalsToMap(
+              KV.of(intervalWindow(0, 15), pair(at(10), 1)),
+              KV.of(intervalWindow(5, 20), pair(at(16), 3)),
+              KV.of(intervalWindow(10, 25), pair(at(16), 3)),
+              KV.of(intervalWindow(15, 30), pair(at(16), 2))));
+    }
+
+    @Test
+    public void testMerge() {
+      Map<IntervalWindow, MutablePair<Instant, Integer>> acc;
+
+      assertThat(agg.merge(agg.zero(), agg.zero()), equalTo(agg.zero()));
+
+      acc = mapOf(KV.of(intervalWindow(0, 15), pair(at(0), 1)));
+
+      // merging with zero() leaves the accumulator unchanged, in either argument position
+      assertThat(agg.merge(acc, agg.zero()), equalTo(acc));
+      assertThat(agg.merge(agg.zero(), acc), equalTo(acc));
+
+      acc = agg.merge(acc, acc);
+      assertThat(acc, equalsToMap(KV.of(intervalWindow(0, 15), pair(at(0), 1 + 1))));
+
+      acc = agg.merge(acc, mapOf(KV.of(intervalWindow(5, 20), pair(at(5), 3))));
+      assertThat(
+          acc,
+          equalsToMap(
+              KV.of(intervalWindow(0, 15), pair(at(0), 1 + 1)),
+              KV.of(intervalWindow(5, 20), pair(at(5), 3))));
+
+      acc = agg.merge(mapOf(KV.of(intervalWindow(10, 25), pair(at(10), 4))), acc);
+      assertThat(
+          acc,
+          equalsToMap(
+              KV.of(intervalWindow(0, 15), pair(at(0), 1 + 1)),
+              KV.of(intervalWindow(5, 20), pair(at(5), 3)),
+              KV.of(intervalWindow(10, 25), pair(at(10), 4))));
+    }
+
+    /** Wraps {@code value} at {@code ts} in all sliding windows that contain {@code ts}. */
+    private WindowedValue<Integer> windowedValue(Integer value, Instant ts) {
+      return WindowedValue.of(value, ts, sliding.assignWindows(ts), PaneInfo.NO_FIRING);
+    }
+  }
+
+  /**
+   * Shared implementation of tests for SessionsAggregator and MergingWindowedAggregator in {@link
+   * Aggregators}.
+   *
+   * <p>The aggregator under test is created with {@code TimestampCombiner.LATEST}, so a merged
+   * session keeps the latest timestamp of its inputs.
+   *
+   * @param <AccT> accumulator type of the aggregator under test
+   */
+  public abstract static class AbstractSessionsTest<
+      AccT extends Map<IntervalWindow, MutablePair<Instant, Integer>>> {
+
+    static final Duration SESSIONS_GAP = standardMinutes(15);
+
+    final Aggregator<WindowedValue<Integer>, AccT, Collection<WindowedValue<Integer>>> agg;
+
+    AbstractSessionsTest(WindowFn<?, ?> windowFn) {
+      agg = windowedAgg(windowFn);
+    }
+
+    /** Creates an accumulator of the type used by the aggregator under test. */
+    abstract AccT accOf(KV<IntervalWindow, MutablePair<Instant, Integer>>... entries);
+
+    @Test
+    public void testReduce() {
+      AccT acc;
+
+      acc = agg.reduce(agg.zero(), sessionValue(10, at(0)));
+      assertThat(acc, equalsToMap(KV.of(sessionWindow(0), pair(at(0), 10))));
+
+      // 2nd session after 1st
+      acc = agg.reduce(acc, sessionValue(7, at(20)));
+      assertThat(
+          acc,
+          equalsToMap(
+              KV.of(sessionWindow(0), pair(at(0), 10)), KV.of(sessionWindow(20), pair(at(20), 7))));
+
+      // merge into 2nd session
+      acc = agg.reduce(acc, sessionValue(6, at(18)));
+      assertThat(
+          acc,
+          equalsToMap(
+              KV.of(sessionWindow(0), pair(at(0), 10)),
+              KV.of(sessionWindow(18, 35), pair(at(20), 7 + 6))));
+
+      // merge into 2nd session
+      acc = agg.reduce(acc, sessionValue(5, at(21)));
+      assertThat(
+          acc,
+          equalsToMap(
+              KV.of(sessionWindow(0), pair(at(0), 10)),
+              KV.of(sessionWindow(18, 36), pair(at(21), 7 + 6 + 5))));
+
+      // 3rd session after 2nd
+      acc = agg.reduce(acc, sessionValue(2, at(45)));
+      assertThat(
+          acc,
+          equalsToMap(
+              KV.of(sessionWindow(0), pair(at(0), 10)),
+              KV.of(sessionWindow(18, 36), pair(at(21), 7 + 6 + 5)),
+              KV.of(sessionWindow(45), pair(at(45), 2))));
+
+      // merge with 1st and 2nd
+      acc = agg.reduce(acc, sessionValue(1, at(10)));
+      assertThat(
+          acc,
+          equalsToMap(
+              KV.of(sessionWindow(0, 36), pair(at(21), 10 + 7 + 6 + 5 + 1)),
+              KV.of(sessionWindow(45), pair(at(45), 2))));
+    }
+
+    @Test
+    public void testMerge() {
+      AccT acc;
+
+      assertThat(agg.merge(agg.zero(), agg.zero()), equalTo(agg.zero()));
+
+      acc = accOf(KV.of(sessionWindow(0), pair(at(0), 1)));
+
+      assertThat(agg.merge(acc, agg.zero()), equalTo(acc));
+      assertThat(agg.merge(agg.zero(), acc), equalTo(acc));
+
+      acc = agg.merge(acc, acc);
+      assertThat(acc, equalsToMap(KV.of(sessionWindow(0), pair(at(0), 1 + 1))));
+
+      acc = agg.merge(acc, accOf(KV.of(sessionWindow(20), pair(at(20), 2))));
+      assertThat(
+          acc,
+          equalsToMap(
+              KV.of(sessionWindow(0), pair(at(0), 1 + 1)),
+              KV.of(sessionWindow(20), pair(at(20), 2))));
+
+      acc = agg.merge(accOf(KV.of(sessionWindow(40), pair(at(40), 3))), acc);
+      assertThat(
+          acc,
+          equalsToMap(
+              KV.of(sessionWindow(0), pair(at(0), 1 + 1)),
+              KV.of(sessionWindow(20), pair(at(20), 2)),
+              KV.of(sessionWindow(40), pair(at(40), 3))));
+
+      // [10, 25) overlaps both [0, 15) and [20, 35), joining them into one session
+      acc = agg.merge(acc, accOf(KV.of(sessionWindow(10), pair(at(10), 4))));
+      assertThat(
+          acc,
+          equalsToMap(
+              KV.of(sessionWindow(0, 35), pair(at(20), 1 + 1 + 2 + 4)),
+              KV.of(sessionWindow(40), pair(at(40), 3))));
+
+      acc = agg.merge(accOf(KV.of(sessionWindow(5, 45), pair(at(30), 5))), acc);
+      assertThat(
+          acc, equalsToMap(KV.of(sessionWindow(0, 55), pair(at(40), 1 + 1 + 2 + 4 + 3 + 5))));
+    }
+
+    /** Wraps {@code value} at {@code ts} in its initial, not yet merged session window. */
+    private static WindowedValue<Integer> sessionValue(Integer value, Instant ts) {
+      return WindowedValue.of(value, ts, new IntervalWindow(ts, SESSIONS_GAP), PaneInfo.NO_FIRING);
+    }
+
+    /** Unmerged session window of length {@code SESSIONS_GAP} starting at {@code fromMinutes}. */
+    private static IntervalWindow sessionWindow(int fromMinutes) {
+      return new IntervalWindow(at(fromMinutes), SESSIONS_GAP);
+    }
+
+    /** Merged session window spanning the given minute offsets. */
+    private static IntervalWindow sessionWindow(int fromMinutes, int toMinutes) {
+      return intervalWindow(fromMinutes, toMinutes);
+    }
+  }
+
+  /**
+   * Tests for the specialized SessionsAggregator in {@link Aggregators}, driven by Beam's own
+   * {@link Sessions} window function.
+   */
+  public static class SessionsAggregatorTest
+      extends AbstractSessionsTest<TreeMap<IntervalWindow, MutablePair<Instant, Integer>>> {
+
+    public SessionsAggregatorTest() {
+      super(Sessions.withGapDuration(SESSIONS_GAP));
+    }
+
+    /** This aggregator's accumulator type is a {@link TreeMap} keyed by window. */
+    @Override
+    TreeMap<IntervalWindow, MutablePair<Instant, Integer>> accOf(
+        KV<IntervalWindow, MutablePair<Instant, Integer>>... windowsAndPairs) {
+      TreeMap<IntervalWindow, MutablePair<Instant, Integer>> sorted = new TreeMap<>();
+      sorted.putAll(mapOf(windowsAndPairs));
+      return sorted;
+    }
+  }
+
+  /** Tests for MergingWindowedAggregator in {@link Aggregators}. */
+  public static class MergingWindowedAggregatorTest
+      extends AbstractSessionsTest<Map<IntervalWindow, MutablePair<Instant, Integer>>> {
+
+    public MergingWindowedAggregatorTest() {
+      super(new CustomSessions<>());
+    }
+
+    /** This aggregator's accumulator is a plain {@link Map} keyed by window. */
+    @Override
+    Map<IntervalWindow, MutablePair<Instant, Integer>> accOf(
+        KV<IntervalWindow, MutablePair<Instant, Integer>>... windowsAndPairs) {
+      return mapOf(windowsAndPairs);
+    }
+
+    /**
+     * Session-like {@link WindowFn} that forwards every call to a {@link Sessions} delegate with
+     * gap {@code SESSIONS_GAP}, without being a {@link Sessions} instance itself. Used to test the
+     * MergingWindowedAggregator.
+     */
+    private static class CustomSessions<T> extends WindowFn<T, IntervalWindow> {
+      private final Sessions delegate = Sessions.withGapDuration(SESSIONS_GAP);
+
+      @Override
+      public Coder<IntervalWindow> windowCoder() {
+        return delegate.windowCoder();
+      }
+
+      @Override
+      public WindowMappingFn<IntervalWindow> getDefaultWindowMappingFn() {
+        return delegate.getDefaultWindowMappingFn();
+      }
+
+      @Override
+      public boolean isCompatible(WindowFn<?, ?> other) {
+        return delegate.isCompatible(other);
+      }
+
+      @Override
+      public Collection<IntervalWindow> assignWindows(
+          WindowFn<T, IntervalWindow>.AssignContext context) {
+        return delegate.assignWindows((WindowFn.AssignContext) context);
+      }
+
+      @Override
+      public void mergeWindows(WindowFn<T, IntervalWindow>.MergeContext context) throws Exception {
+        delegate.mergeWindows((WindowFn<Object, IntervalWindow>.MergeContext) context);
+      }
+    }
+  }
+
+  /** Interval window from {@code fromMinutes} to {@code toMinutes} after {@link #NOW}. */
+  private static IntervalWindow intervalWindow(int fromMinutes, int toMinutes) {
+    Instant start = at(fromMinutes);
+    Instant end = at(toMinutes);
+    return new IntervalWindow(start, end);
+  }
+
+  /** The instant {@code minutes} minutes after {@link #NOW}. */
+  private static Instant at(int minutes) {
+    Duration offset = Duration.standardMinutes(minutes);
+    return NOW.plus(offset);
+  }
+
+  /** Matcher for a map holding exactly the given (window, (timestamp, value)) entries. */
+  @SafeVarargs
+  private static Matcher<Map<IntervalWindow, MutablePair<Instant, Integer>>> equalsToMap(
+      KV<IntervalWindow, MutablePair<Instant, Integer>>... entries) {
+    return equalTo(mapOf(entries));
+  }
+
+  /**
+   * Collects the given entries into a map keyed by window. Windows must be distinct: {@link
+   * Collectors#toMap} throws {@link IllegalStateException} on duplicate keys.
+   */
+  @SafeVarargs
+  private static Map<IntervalWindow, MutablePair<Instant, Integer>> mapOf(
+      KV<IntervalWindow, MutablePair<Instant, Integer>>... entries) {
+    return Arrays.stream(entries).collect(Collectors.toMap(KV::getKey, KV::getValue));
+  }
+
+  /** Creates a (timestamp, value) pair, the per-window entry type of the accumulators. */
+  private static MutablePair<Instant, Integer> pair(Instant ts, int value) {
+    return new MutablePair<>(ts, value);
+  }
+
+  /**
+   * Creates the {@link Aggregators#windowedValue} aggregator for a {@link SimpleSum} over integer
+   * values, windowed by {@code windowFn} with {@link TimestampCombiner#LATEST}.
+   *
+   * <p>The concrete accumulator type is hidden behind a wildcard here, so the result is cast
+   * (unchecked) to the caller-chosen {@code AccT}.
+   */
+  private static <AccT>
+      Aggregator<WindowedValue<Integer>, AccT, Collection<WindowedValue<Integer>>> windowedAgg(
+          WindowFn<?, ?> windowFn) {
+    Encoder<Integer> valueEncoder = EncoderHelpers.encoderOf(Integer.class);
+    Encoder<BoundedWindow> windowEncoder = encoderFor((Coder) IntervalWindow.getCoder());
+    Encoder<WindowedValue<Integer>> windowedEncoder =
+        windowedValueEncoder(valueEncoder, windowEncoder);
+
+    WindowingStrategy<?, ?> strategy =
+        WindowingStrategy.of(windowFn).withTimestampCombiner(TimestampCombiner.LATEST);
+
+    Aggregator<WindowedValue<Integer>, ?, Collection<WindowedValue<Integer>>> aggregator =
+        Aggregators.windowedValue(
+            new SimpleSum(),
+            WindowedValue::getValue,
+            strategy,
+            windowEncoder,
+            valueEncoder,
+            windowedEncoder);
+    return (Aggregator) aggregator;
+  }
+
+  /** Integer sum {@link Combine.CombineFn} whose accumulator is simply the running total. */
+  private static final class SimpleSum extends Combine.CombineFn<Integer, Integer, Integer> {
+
+    @Override
+    public Integer createAccumulator() {
+      return 0;
+    }
+
+    @Override
+    public Integer addInput(Integer acc, Integer input) {
+      return acc + input;
+    }
+
+    @Override
+    public Integer mergeAccumulators(Iterable<Integer> accs) {
+      // 0 is the identity of addition, so no Optional handling is needed for empty input
+      return Streams.stream(accs).reduce(0, Integer::sum);
+    }
+
+    @Override
+    public Integer extractOutput(Integer acc) {
+      return acc;
+    }
+  }
+}
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTest.java
similarity index 56%
copy from runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineTest.java
copy to runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTest.java
index 52e60a3db54..dca8b664bd3 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTest.java
@@ -18,43 +18,44 @@
 package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 
 import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
 import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
 import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
-import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Combine.BinaryCombineFn;
+import org.apache.beam.sdk.transforms.CombineFnBase;
 import org.apache.beam.sdk.transforms.Count;
 import org.apache.beam.sdk.transforms.Create;
-import org.apache.beam.sdk.transforms.SerializableBiFunction;
 import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.Sessions;
 import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
 import org.apache.beam.sdk.transforms.windowing.Window;
-import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TimestampedValue;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
-import org.junit.BeforeClass;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
-/** Test class for beam to spark {@link org.apache.beam.sdk.transforms.Combine} translation. */
+/**
+ * Test class for beam to spark {@link Combine#globally(CombineFnBase.GlobalCombineFn)} translation.
+ */
 @RunWith(JUnit4.class)
-public class CombineTest implements Serializable {
-  private static Pipeline pipeline;
+public class CombineGloballyTest implements Serializable {
+  @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());
 
-  @BeforeClass
-  public static void beforeClass() {
+  /**
+   * Options for the per-test {@link TestPipeline}: run on {@link SparkStructuredStreamingRunner}
+   * with test mode enabled.
+   */
+  private static PipelineOptions testOptions() {
     SparkStructuredStreamingPipelineOptions options =
         PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
     options.setRunner(SparkStructuredStreamingRunner.class);
     options.setTestMode(true);
-    pipeline = Pipeline.create(options);
+    return options;
   }
 
   @Test
@@ -79,65 +80,59 @@ public class CombineTest implements Serializable {
                     TimestampedValue.of(5, new Instant(11)),
                     TimestampedValue.of(6, new Instant(12))))
             .apply(Window.into(FixedWindows.of(Duration.millis(10))))
-            .apply(Combine.globally(Sum.ofIntegers()).withoutDefaults());
+            .apply(Sum.integersGlobally().withoutDefaults());
     PAssert.that(input).containsInAnyOrder(7, 14);
+    pipeline.run();
   }
 
   @Test
-  public void testCombinePerKey() {
-    List<KV<Integer, Integer>> elems = new ArrayList<>();
-    elems.add(KV.of(1, 1));
-    elems.add(KV.of(1, 3));
-    elems.add(KV.of(1, 5));
-    elems.add(KV.of(2, 2));
-    elems.add(KV.of(2, 4));
-    elems.add(KV.of(2, 6));
-
-    PCollection<KV<Integer, Integer>> input =
-        pipeline.apply(Create.of(elems)).apply(Sum.integersPerKey());
-    PAssert.that(input).containsInAnyOrder(KV.of(1, 9), KV.of(2, 12));
+  public void testCombineGloballyWithSlidingWindows() {
+    PCollection<Integer> input =
+        pipeline
+            .apply(
+                Create.timestamped(
+                    TimestampedValue.of(1, new Instant(1)),
+                    TimestampedValue.of(3, new Instant(2)),
+                    TimestampedValue.of(5, new Instant(3)),
+                    TimestampedValue.of(2, new Instant(1)),
+                    TimestampedValue.of(4, new Instant(2)),
+                    TimestampedValue.of(6, new Instant(3))))
+            .apply(Window.into(SlidingWindows.of(Duration.millis(3)).every(Duration.millis(1))))
+            .apply(Sum.integersGlobally().withoutDefaults());
+    // 3ms windows every 1ms: [-1,2), [0,3), [1,4), [2,5), [3,6) hold the elements at t=1,
+    // t=1..2, t=1..3, t=2..3 and t=3 respectively, one sum per window.
+    PAssert.that(input)
+        .containsInAnyOrder(1 + 2, 1 + 2 + 3 + 4, 1 + 3 + 5 + 2 + 4 + 6, 3 + 4 + 5 + 6, 5 + 6);
     pipeline.run();
   }
 
   @Test
-  public void testCombinePerKeyPreservesWindowing() {
-    PCollection<KV<Integer, Integer>> input =
+  public void testCombineGloballyWithMergingWindows() {
+    PCollection<Integer> input =
         pipeline
             .apply(
                 Create.timestamped(
-                    TimestampedValue.of(KV.of(1, 1), new Instant(1)),
-                    TimestampedValue.of(KV.of(1, 3), new Instant(2)),
-                    TimestampedValue.of(KV.of(1, 5), new Instant(11)),
-                    TimestampedValue.of(KV.of(2, 2), new Instant(3)),
-                    TimestampedValue.of(KV.of(2, 4), new Instant(11)),
-                    TimestampedValue.of(KV.of(2, 6), new Instant(12))))
-            .apply(Window.into(FixedWindows.of(Duration.millis(10))))
-            .apply(Sum.integersPerKey());
-    PAssert.that(input).containsInAnyOrder(KV.of(1, 4), KV.of(1, 5), KV.of(2, 2), KV.of(2, 10));
+                    TimestampedValue.of(2, new Instant(5)),
+                    TimestampedValue.of(4, new Instant(11)),
+                    TimestampedValue.of(6, new Instant(12))))
+            .apply(Window.into(Sessions.withGapDuration(Duration.millis(5))))
+            .apply(Sum.integersGlobally().withoutDefaults());
+
+    // 5ms gap: the sessions of t=11 and t=12 overlap and merge, t=5 stays a session of its own
+    PAssert.that(input).containsInAnyOrder(2 /*window [5-10)*/, 10 /*window [11-17)*/);
     pipeline.run();
   }
 
   @Test
-  public void testCombinePerKeyWithSlidingWindows() {
-    PCollection<KV<Integer, Integer>> input =
+  public void testCountGloballyWithSlidingWindows() {
+    PCollection<String> input =
         pipeline
             .apply(
                 Create.timestamped(
-                    TimestampedValue.of(KV.of(1, 1), new Instant(1)),
-                    TimestampedValue.of(KV.of(1, 3), new Instant(2)),
-                    TimestampedValue.of(KV.of(1, 5), new Instant(3)),
-                    TimestampedValue.of(KV.of(1, 2), new Instant(1)),
-                    TimestampedValue.of(KV.of(1, 4), new Instant(2)),
-                    TimestampedValue.of(KV.of(1, 6), new Instant(3))))
-            .apply(Window.into(SlidingWindows.of(Duration.millis(3)).every(Duration.millis(1))))
-            .apply(Sum.integersPerKey());
-    PAssert.that(input)
-        .containsInAnyOrder(
-            KV.of(1, 1 + 2),
-            KV.of(1, 1 + 2 + 3 + 4),
-            KV.of(1, 1 + 3 + 5 + 2 + 4 + 6),
-            KV.of(1, 3 + 4 + 5 + 6),
-            KV.of(1, 5 + 6));
+                    TimestampedValue.of("a", new Instant(1)),
+                    TimestampedValue.of("a", new Instant(2)),
+                    TimestampedValue.of("a", new Instant(2))))
+            .apply(Window.into(SlidingWindows.of(Duration.millis(2)).every(Duration.millis(1))));
+    PCollection<Long> output =
+        input.apply(Combine.globally(Count.<String>combineFn()).withoutDefaults());
+    // 2ms windows every 1ms: [0,2) holds t=1, [1,3) all three elements, [2,4) the two at t=2
+    PAssert.that(output).containsInAnyOrder(1L, 3L, 2L);
     pipeline.run();
   }
 
@@ -152,35 +147,9 @@ public class CombineTest implements Serializable {
                     TimestampedValue.of(5, new Instant(3))))
             .apply(Window.into(SlidingWindows.of(Duration.millis(3)).every(Duration.millis(1))))
             .apply(
-                Combine.globally(
-                        Combine.BinaryCombineFn.of(
-                            (SerializableBiFunction<Integer, Integer, Integer>)
-                                (integer1, integer2) -> integer1 > integer2 ? integer1 : integer2))
+                Combine.globally(BinaryCombineFn.<Integer>of((i1, i2) -> i1 > i2 ? i1 : i2))
                     .withoutDefaults());
     PAssert.that(input).containsInAnyOrder(1, 3, 5, 5, 5);
     pipeline.run();
   }
-
-  @Test
-  public void testCountPerElementWithSlidingWindows() {
-    PCollection<String> input =
-        pipeline
-            .apply(
-                Create.timestamped(
-                    TimestampedValue.of("a", new Instant(1)),
-                    TimestampedValue.of("a", new Instant(2)),
-                    TimestampedValue.of("b", new Instant(3)),
-                    TimestampedValue.of("b", new Instant(4))))
-            .apply(Window.into(SlidingWindows.of(Duration.millis(2)).every(Duration.millis(1))));
-    PCollection<KV<String, Long>> output = input.apply(Count.perElement());
-    PAssert.that(output)
-        .containsInAnyOrder(
-            KV.of("a", 1L),
-            KV.of("a", 2L),
-            KV.of("a", 1L),
-            KV.of("b", 1L),
-            KV.of("b", 2L),
-            KV.of("b", 1L));
-    pipeline.run();
-  }
 }
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTest.java
similarity index 71%
rename from runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineTest.java
rename to runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTest.java
index 52e60a3db54..c8b25b3355d 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTest.java
@@ -22,65 +22,44 @@ import java.util.ArrayList;
 import java.util.List;
 import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
 import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
-import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.testing.PAssert;
-import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.CombineFnBase;
 import org.apache.beam.sdk.transforms.Count;
 import org.apache.beam.sdk.transforms.Create;
-import org.apache.beam.sdk.transforms.SerializableBiFunction;
+import org.apache.beam.sdk.transforms.Distinct;
 import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.Sessions;
 import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TimestampedValue;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
-import org.junit.BeforeClass;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
-/** Test class for beam to spark {@link org.apache.beam.sdk.transforms.Combine} translation. */
+/**
+ * Test class for beam to spark {@link
+ * org.apache.beam.sdk.transforms.Combine#perKey(CombineFnBase.GlobalCombineFn)} translation.
+ */
 @RunWith(JUnit4.class)
-public class CombineTest implements Serializable {
-  private static Pipeline pipeline;
+public class CombinePerKeyTest implements Serializable {
+  @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());
 
-  @BeforeClass
-  public static void beforeClass() {
+  /**
+   * Options for the per-test {@link TestPipeline}: run on {@link SparkStructuredStreamingRunner}
+   * with test mode enabled.
+   */
+  private static PipelineOptions testOptions() {
     SparkStructuredStreamingPipelineOptions options =
         PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
     options.setRunner(SparkStructuredStreamingRunner.class);
     options.setTestMode(true);
-    pipeline = Pipeline.create(options);
-  }
-
-  @Test
-  public void testCombineGlobally() {
-    PCollection<Integer> input =
-        pipeline.apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)).apply(Sum.integersGlobally());
-    PAssert.that(input).containsInAnyOrder(55);
-    // uses combine per key
-    pipeline.run();
-  }
-
-  @Test
-  public void testCombineGloballyPreservesWindowing() {
-    PCollection<Integer> input =
-        pipeline
-            .apply(
-                Create.timestamped(
-                    TimestampedValue.of(1, new Instant(1)),
-                    TimestampedValue.of(2, new Instant(2)),
-                    TimestampedValue.of(3, new Instant(11)),
-                    TimestampedValue.of(4, new Instant(3)),
-                    TimestampedValue.of(5, new Instant(11)),
-                    TimestampedValue.of(6, new Instant(12))))
-            .apply(Window.into(FixedWindows.of(Duration.millis(10))))
-            .apply(Combine.globally(Sum.ofIntegers()).withoutDefaults());
-    PAssert.that(input).containsInAnyOrder(7, 14);
+    return options;
   }
 
   @Test
@@ -99,6 +78,17 @@ public class CombineTest implements Serializable {
     pipeline.run();
   }
 
+  @Test
+  public void testDistinctViaCombinePerKey() {
+    // Distinct is implemented in terms of CombinePerKey
+    List<Integer> elems = Lists.newArrayList(1, 2, 3, 3, 4, 4, 4, 4, 5, 5);
+    PCollection<Integer> withDuplicates = pipeline.apply(Create.of(elems));
+    PCollection<Integer> distinctValues = withDuplicates.apply(Distinct.create());
+
+    PAssert.that(distinctValues).containsInAnyOrder(1, 2, 3, 4, 5);
+    pipeline.run();
+  }
+
   @Test
   public void testCombinePerKeyPreservesWindowing() {
     PCollection<KV<Integer, Integer>> input =
@@ -142,22 +132,26 @@ public class CombineTest implements Serializable {
   }
 
   @Test
-  public void testBinaryCombineWithSlidingWindows() {
-    PCollection<Integer> input =
+  public void testCombineByKeyWithMergingWindows() {
+    PCollection<KV<Integer, Integer>> input =
         pipeline
             .apply(
                 Create.timestamped(
-                    TimestampedValue.of(1, new Instant(1)),
-                    TimestampedValue.of(3, new Instant(2)),
-                    TimestampedValue.of(5, new Instant(3))))
-            .apply(Window.into(SlidingWindows.of(Duration.millis(3)).every(Duration.millis(1))))
-            .apply(
-                Combine.globally(
-                        Combine.BinaryCombineFn.of(
-                            (SerializableBiFunction<Integer, Integer, Integer>)
-                                (integer1, integer2) -> integer1 > integer2 ? integer1 : integer2))
-                    .withoutDefaults());
-    PAssert.that(input).containsInAnyOrder(1, 3, 5, 5, 5);
+                    TimestampedValue.of(KV.of(1, 1), new Instant(5)),
+                    TimestampedValue.of(KV.of(1, 3), new Instant(7)),
+                    TimestampedValue.of(KV.of(1, 5), new Instant(11)),
+                    TimestampedValue.of(KV.of(2, 2), new Instant(5)),
+                    TimestampedValue.of(KV.of(2, 4), new Instant(11)),
+                    TimestampedValue.of(KV.of(2, 6), new Instant(12))))
+            .apply(Window.into(Sessions.withGapDuration(Duration.millis(5))))
+            .apply(Sum.integersPerKey());
+
+    // 5ms gap per key: key 1 at t=5, 7 and 11 chains into one session, key 2 splits into two
+    PAssert.that(input)
+        .containsInAnyOrder(
+            KV.of(1, 9), // window [5-16)
+            KV.of(2, 2), // window [5-10)
+            KV.of(2, 10) // window [11-17)
+            );
     pipeline.run();
   }
 
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ComplexSourceTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ComplexSourceTest.java
index 0175d03f875..582a31a05a6 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ComplexSourceTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ComplexSourceTest.java
@@ -27,13 +27,15 @@ import java.util.ArrayList;
 import java.util.List;
 import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
 import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
-import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.io.TextIO;
+import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.values.PCollection;
 import org.junit.BeforeClass;
 import org.junit.ClassRule;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
 import org.junit.runner.RunWith;
@@ -46,15 +48,18 @@ public class ComplexSourceTest implements Serializable {
   private static File file;
   private static List<String> lines = createLines(30);
 
-  private static Pipeline pipeline;
+  @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());
 
-  @BeforeClass
-  public static void beforeClass() throws IOException {
+  private static PipelineOptions testOptions() {
     SparkStructuredStreamingPipelineOptions options =
         PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
     options.setRunner(SparkStructuredStreamingRunner.class);
     options.setTestMode(true);
-    pipeline = Pipeline.create(options);
+    return options;
+  }
+
+  @BeforeClass
+  public static void beforeClass() throws IOException {
     file = createFile(lines);
   }
 
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTest.java
index e126d06e685..50b443da9ae 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTest.java
@@ -20,14 +20,15 @@ package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 import java.io.Serializable;
 import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
 import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
-import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionList;
-import org.junit.BeforeClass;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -35,15 +36,14 @@ import org.junit.runners.JUnit4;
 /** Test class for beam to spark flatten translation. */
 @RunWith(JUnit4.class)
 public class FlattenTest implements Serializable {
-  private static Pipeline pipeline;
+  @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());
 
-  @BeforeClass
-  public static void beforeClass() {
+  private static PipelineOptions testOptions() {
     SparkStructuredStreamingPipelineOptions options =
         PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
     options.setRunner(SparkStructuredStreamingRunner.class);
     options.setTestMode(true);
-    pipeline = Pipeline.create(options);
+    return options;
   }
 
   @Test
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTest.java
index 07850232853..1a84466b319 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTest.java
@@ -17,30 +17,41 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 
+import static java.util.Arrays.stream;
+import static java.util.stream.Collectors.groupingBy;
+import static java.util.stream.Collectors.mapping;
+import static java.util.stream.Collectors.toList;
 import static org.apache.beam.sdk.testing.SerializableMatchers.containsInAnyOrder;
 import static org.hamcrest.MatcherAssert.assertThat;
 
 import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
 import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
-import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.SerializableMatcher;
+import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.Sessions;
+import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TimestampedValue;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
-import org.junit.BeforeClass;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -48,15 +59,14 @@ import org.junit.runners.JUnit4;
 /** Test class for beam to spark {@link ParDo} translation. */
 @RunWith(JUnit4.class)
 public class GroupByKeyTest implements Serializable {
-  private static Pipeline pipeline;
+  @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());
 
-  @BeforeClass
-  public static void beforeClass() {
+  private static PipelineOptions testOptions() {
     SparkStructuredStreamingPipelineOptions options =
         PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
     options.setRunner(SparkStructuredStreamingRunner.class);
     options.setTestMode(true);
-    pipeline = Pipeline.create(options);
+    return options;
   }
 
   @Test
@@ -64,54 +74,89 @@ public class GroupByKeyTest implements Serializable {
     pipeline
         .apply(
             Create.timestamped(
-                TimestampedValue.of(KV.of(1, 1), new Instant(1)),
-                TimestampedValue.of(KV.of(1, 3), new Instant(2)),
-                TimestampedValue.of(KV.of(1, 5), new Instant(11)),
-                TimestampedValue.of(KV.of(2, 2), new Instant(3)),
-                TimestampedValue.of(KV.of(2, 4), new Instant(11)),
-                TimestampedValue.of(KV.of(2, 6), new Instant(12))))
+                shuffleRandomly(
+                    TimestampedValue.of(KV.of(1, 1), new Instant(1)),
+                    TimestampedValue.of(KV.of(1, 3), new Instant(2)),
+                    TimestampedValue.of(KV.of(1, 5), new Instant(11)),
+                    TimestampedValue.of(KV.of(2, 2), new Instant(3)),
+                    TimestampedValue.of(KV.of(2, 4), new Instant(11)),
+                    TimestampedValue.of(KV.of(2, 6), new Instant(12)))))
         .apply(Window.into(FixedWindows.of(Duration.millis(10))))
         .apply(GroupByKey.create())
-        // do manual assertion for windows because Passert do not support multiple kv with same key
-        // (because multiple windows)
+        // PAssert does not support multiple KVs with the same key (due to multiple windows)
         .apply(
             ParDo.of(
-                new DoFn<KV<Integer, Iterable<Integer>>, KV<Integer, Iterable<Integer>>>() {
+                new AssertContains<>(
+                    KV.of(1, containsInAnyOrder(1, 3)), // window [0-10)
+                    KV.of(1, containsInAnyOrder(5)), // window [10-20)
+                    KV.of(2, containsInAnyOrder(4, 6)), // window [10-20)
+                    KV.of(2, containsInAnyOrder(2)) // window [0-10)
+                    )));
+    pipeline.run();
+  }
 
-                  @ProcessElement
-                  public void processElement(ProcessContext context) {
-                    KV<Integer, Iterable<Integer>> element = context.element();
-                    if (element.getKey() == 1) {
-                      if (Iterables.size(element.getValue()) == 2) {
-                        assertThat(element.getValue(), containsInAnyOrder(1, 3)); // window [0-10)
-                      } else {
-                        assertThat(element.getValue(), containsInAnyOrder(5)); // window [10-20)
-                      }
-                    } else { // key == 2
-                      if (Iterables.size(element.getValue()) == 2) {
-                        assertThat(element.getValue(), containsInAnyOrder(4, 6)); // window [10-20)
-                      } else {
-                        assertThat(element.getValue(), containsInAnyOrder(2)); // window [0-10)
-                      }
-                    }
-                    context.output(element);
-                  }
-                }));
+  @Test
+  public void testGroupByKeyExplodesMultipleWindows() {
+    pipeline
+        .apply(
+            Create.timestamped(
+                shuffleRandomly(
+                    TimestampedValue.of(KV.of(1, 1), new Instant(5)),
+                    TimestampedValue.of(KV.of(1, 3), new Instant(7)),
+                    TimestampedValue.of(KV.of(1, 5), new Instant(11)),
+                    TimestampedValue.of(KV.of(2, 2), new Instant(5)),
+                    TimestampedValue.of(KV.of(2, 4), new Instant(11)),
+                    TimestampedValue.of(KV.of(2, 6), new Instant(12)))))
+        .apply(Window.into(SlidingWindows.of(Duration.millis(10)).every(Duration.millis(5))))
+        .apply(GroupByKey.create())
+        // PAssert does not support multiple KVs with the same key (due to multiple windows)
+        .apply(
+            ParDo.of(
+                new AssertContains<>(
+                    KV.of(1, containsInAnyOrder(1, 3)), // window [0-10)
+                    KV.of(1, containsInAnyOrder(1, 3, 5)), // window [5-15)
+                    KV.of(1, containsInAnyOrder(5)), // window [10-20)
+                    KV.of(2, containsInAnyOrder(2)), // window [0-10)
+                    KV.of(2, containsInAnyOrder(2, 4, 6)), // window [5-15)
+                    KV.of(2, containsInAnyOrder(4, 6)) // window [10-20)
+                    )));
+    pipeline.run();
+  }
+
+  @Test
+  public void testGroupByKeyWithMergingWindows() {
+    pipeline
+        .apply(
+            Create.timestamped(
+                shuffleRandomly(
+                    TimestampedValue.of(KV.of(1, 1), new Instant(5)),
+                    TimestampedValue.of(KV.of(1, 3), new Instant(7)),
+                    TimestampedValue.of(KV.of(1, 5), new Instant(11)),
+                    TimestampedValue.of(KV.of(2, 2), new Instant(5)),
+                    TimestampedValue.of(KV.of(2, 4), new Instant(11)),
+                    TimestampedValue.of(KV.of(2, 6), new Instant(12)))))
+        .apply(Window.into(Sessions.withGapDuration(Duration.millis(5))))
+        .apply(GroupByKey.create())
+        // PAssert does not support multiple KVs with the same key (due to multiple windows)
+        .apply(
+            ParDo.of(
+                new AssertContains<>(
+                    KV.of(1, containsInAnyOrder(1, 3, 5)), // window [5-16)
+                    KV.of(2, containsInAnyOrder(2)), // window [5-10)
+                    KV.of(2, containsInAnyOrder(4, 6)) // window [11-17)
+                    )));
     pipeline.run();
   }
 
   @Test
   public void testGroupByKey() {
-    List<KV<Integer, Integer>> elems = new ArrayList<>();
-    elems.add(KV.of(1, 1));
-    elems.add(KV.of(1, 3));
-    elems.add(KV.of(1, 5));
-    elems.add(KV.of(2, 2));
-    elems.add(KV.of(2, 4));
-    elems.add(KV.of(2, 6));
+    List<KV<Integer, Integer>> elems =
+        shuffleRandomly(
+            KV.of(1, 1), KV.of(1, 3), KV.of(1, 5), KV.of(2, 2), KV.of(2, 4), KV.of(2, 6));
 
     PCollection<KV<Integer, Iterable<Integer>>> input =
         pipeline.apply(Create.of(elems)).apply(GroupByKey.create());
+
     PAssert.thatMap(input)
         .satisfies(
             results -> {
@@ -121,4 +166,27 @@ public class GroupByKeyTest implements Serializable {
             });
     pipeline.run();
   }
+
+  static class AssertContains<K, V> extends DoFn<KV<K, Iterable<V>>, Void> {
+    private final Map<K, List<SerializableMatcher<Iterable<? extends V>>>> byKey;
+
+    public AssertContains(KV<K, SerializableMatcher<Iterable<? extends V>>>... matchers) {
+      byKey = stream(matchers).collect(groupingBy(KV::getKey, mapping(KV::getValue, toList())));
+    }
+
+    @ProcessElement
+    public void processElement(@Element KV<K, Iterable<V>> elem) {
+      assertThat("Unexpected key: " + elem.getKey(), byKey.containsKey(elem.getKey()));
+      List<V> values = ImmutableList.copyOf(elem.getValue());
+      assertThat(
+          "Unexpected values " + values + " for key " + elem.getKey(),
+          byKey.get(elem.getKey()).stream().anyMatch(m -> m.matches(values)));
+    }
+  }
+
+  private <T> List<T> shuffleRandomly(T... elems) {
+    ArrayList<T> list = Lists.newArrayList(elems);
+    Collections.shuffle(list);
+    return list;
+  }
 }
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java
index 16d9a8b7fa8..f319173ed2b 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java
@@ -20,19 +20,24 @@ package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 import java.io.Serializable;
 import java.util.List;
 import java.util.Map;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
 import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
 import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
-import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.PCollectionView;
-import org.junit.BeforeClass;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -40,15 +45,14 @@ import org.junit.runners.JUnit4;
 /** Test class for beam to spark {@link ParDo} translation. */
 @RunWith(JUnit4.class)
 public class ParDoTest implements Serializable {
-  private static Pipeline pipeline;
+  @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());
 
-  @BeforeClass
-  public static void beforeClass() {
+  private static PipelineOptions testOptions() {
     SparkStructuredStreamingPipelineOptions options =
         PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
     options.setRunner(SparkStructuredStreamingRunner.class);
     options.setTestMode(true);
-    pipeline = Pipeline.create(options);
+    return options;
   }
 
   @Test
@@ -59,6 +63,44 @@ public class ParDoTest implements Serializable {
     pipeline.run();
   }
 
+  @Test
+  public void testPardoWithOutputTagsCachedRDD() {
+    pardoWithOutputTags("MEMORY_ONLY");
+  }
+
+  @Test
+  public void testPardoWithOutputTagsCachedDataset() {
+    pardoWithOutputTags("MEMORY_AND_DISK");
+  }
+
+  private void pardoWithOutputTags(String storageLevel) {
+    pipeline.getOptions().as(SparkCommonPipelineOptions.class).setStorageLevel(storageLevel);
+
+    TupleTag<Integer> even = new TupleTag<Integer>() {};
+    TupleTag<String> unevenAsString = new TupleTag<String>() {};
+
+    DoFn<Integer, Integer> doFn =
+        new DoFn<Integer, Integer>() {
+          @ProcessElement
+          public void processElement(@Element Integer i, MultiOutputReceiver out) {
+            if (i % 2 == 0) {
+              out.get(even).output(i);
+            } else {
+              out.get(unevenAsString).output(i.toString());
+            }
+          }
+        };
+
+    PCollectionTuple outputs =
+        pipeline
+            .apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+            .apply(ParDo.of(doFn).withOutputTags(even, TupleTagList.of(unevenAsString)));
+
+    PAssert.that(outputs.get(even)).containsInAnyOrder(2, 4, 6, 8, 10);
+    PAssert.that(outputs.get(unevenAsString)).containsInAnyOrder("1", "3", "5", "7", "9");
+    pipeline.run();
+  }
+
   @Test
   public void testTwoPardoInRow() {
     PCollection<Integer> input =
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/SimpleSourceTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/SimpleSourceTest.java
index 70cdca630b9..0f16b644222 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/SimpleSourceTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/SimpleSourceTest.java
@@ -20,12 +20,13 @@ package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 import java.io.Serializable;
 import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
 import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
-import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.values.PCollection;
-import org.junit.BeforeClass;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -33,15 +34,14 @@ import org.junit.runners.JUnit4;
 /** Test class for beam to spark source translation. */
 @RunWith(JUnit4.class)
 public class SimpleSourceTest implements Serializable {
-  private static Pipeline pipeline;
+  @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());
 
-  @BeforeClass
-  public static void beforeClass() {
+  private static PipelineOptions testOptions() {
     SparkStructuredStreamingPipelineOptions options =
         PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
     options.setRunner(SparkStructuredStreamingRunner.class);
     options.setTestMode(true);
-    pipeline = Pipeline.create(options);
+    return options;
   }
 
   @Test
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTest.java
index b8b41010a24..28efe754ddf 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTest.java
@@ -20,9 +20,10 @@ package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 import java.io.Serializable;
 import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
 import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
-import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
@@ -31,7 +32,7 @@ import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TimestampedValue;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
-import org.junit.BeforeClass;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -39,15 +40,14 @@ import org.junit.runners.JUnit4;
 /** Test class for beam to spark window assign translation. */
 @RunWith(JUnit4.class)
 public class WindowAssignTest implements Serializable {
-  private static Pipeline pipeline;
+  @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());
 
-  @BeforeClass
-  public static void beforeClass() {
+  private static PipelineOptions testOptions() {
     SparkStructuredStreamingPipelineOptions options =
         PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
     options.setRunner(SparkStructuredStreamingRunner.class);
     options.setTestMode(true);
-    pipeline = Pipeline.create(options);
+    return options;
   }
 
   @Test
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java
index c8a8fba8d28..ab6e3083c54 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java
@@ -18,32 +18,95 @@
 package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
 
 import static java.util.Arrays.asList;
-import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.fromBeamCoder;
+import static java.util.function.Function.identity;
+import static java.util.stream.Collectors.toList;
+import static java.util.stream.Collectors.toMap;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderFor;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.mapEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.oneOfEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Predicates.notNull;
+import static org.apache.spark.sql.types.DataTypes.IntegerType;
+import static org.apache.spark.sql.types.DataTypes.StringType;
+import static org.apache.spark.sql.types.DataTypes.createStructField;
+import static org.apache.spark.sql.types.DataTypes.createStructType;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.equalTo;
-import static org.junit.Assert.assertEquals;
+import static org.hamcrest.Matchers.instanceOf;
 
-import java.util.Arrays;
+import java.math.BigDecimal;
+import java.math.MathContext;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
+import java.util.TreeMap;
+import java.util.function.Function;
 import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule;
+import org.apache.beam.sdk.coders.BigDecimalCoder;
+import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
+import org.apache.beam.sdk.coders.BigEndianLongCoder;
+import org.apache.beam.sdk.coders.BigEndianShortCoder;
+import org.apache.beam.sdk.coders.BooleanCoder;
+import org.apache.beam.sdk.coders.ByteCoder;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.DelegateCoder;
+import org.apache.beam.sdk.coders.DoubleCoder;
+import org.apache.beam.sdk.coders.FloatCoder;
+import org.apache.beam.sdk.coders.InstantCoder;
+import org.apache.beam.sdk.coders.ListCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.joda.time.Instant;
 import org.junit.ClassRule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
+import scala.Tuple2;
 
 /** Test of the wrapping of Beam Coders as Spark ExpressionEncoders. */
 @RunWith(JUnit4.class)
 public class EncoderHelpersTest {
+  @ClassRule public static SparkSessionRule sessionRule = new SparkSessionRule("local[1]");
 
-  @ClassRule public static SparkSessionRule sessionRule = new SparkSessionRule();
+  private static final Encoder<GlobalWindow> windowEnc =
+      EncoderHelpers.encoderOf(GlobalWindow.class);
+
+  private static final Map<Coder<?>, List<?>> BASIC_CASES =
+      ImmutableMap.<Coder<?>, List<?>>builder()
+          .put(BooleanCoder.of(), asList(true, false, null))
+          .put(ByteCoder.of(), asList((byte) 1, null))
+          .put(BigEndianShortCoder.of(), asList((short) 1, null))
+          .put(BigEndianIntegerCoder.of(), asList(1, 2, 3, null))
+          .put(VarIntCoder.of(), asList(1, 2, 3, null))
+          .put(BigEndianLongCoder.of(), asList(1L, 2L, 3L, null))
+          .put(VarLongCoder.of(), asList(1L, 2L, 3L, null))
+          .put(FloatCoder.of(), asList((float) 1.0, (float) 2.0, null))
+          .put(DoubleCoder.of(), asList(1.0, 2.0, null))
+          .put(StringUtf8Coder.of(), asList("1", "2", null))
+          .put(BigDecimalCoder.of(), asList(bigDecimalOf(1L), bigDecimalOf(2L), null))
+          .put(InstantCoder.of(), asList(Instant.ofEpochMilli(1), null))
+          .build();
 
   private <T> Dataset<T> createDataset(List<T> data, Encoder<T> encoder) {
     Dataset<T> ds = sessionRule.getSession().createDataset(data, encoder);
@@ -52,10 +115,14 @@ public class EncoderHelpersTest {
   }
 
   @Test
-  public void beamCoderToSparkEncoderTest() {
-    List<Integer> data = Arrays.asList(1, 2, 3);
-    Dataset<Integer> dataset = createDataset(data, EncoderHelpers.fromBeamCoder(VarIntCoder.of()));
-    assertEquals(data, dataset.collectAsList());
+  public void testBeamEncoderMappings() {
+    BASIC_CASES.forEach(
+        (coder, data) -> {
+          Encoder<?> encoder = encoderFor(coder);
+          serializeAndDeserialize(data.get(0), (Encoder) encoder);
+          Dataset<?> dataset = createDataset(data, (Encoder) encoder);
+          assertThat(dataset.collect(), equalTo(data.toArray()));
+        });
   }
 
   @Test
@@ -63,10 +130,135 @@ public class EncoderHelpersTest {
     // Verify concrete types are not used in coder generation.
     // In case of private types this would cause an IllegalAccessError.
     List<PrivateString> data = asList(new PrivateString("1"), new PrivateString("2"));
-    Dataset<PrivateString> dataset = createDataset(data, fromBeamCoder(PrivateString.CODER));
+    Dataset<PrivateString> dataset = createDataset(data, encoderFor(PrivateString.CODER));
+    assertThat(dataset.collect(), equalTo(data.toArray()));
+  }
+
+  @Test
+  public void testBeamWindowedValueEncoderMappings() {
+    BASIC_CASES.forEach(
+        (coder, data) -> {
+          List<WindowedValue<?>> windowed =
+              Lists.transform(data, WindowedValue::valueInGlobalWindow);
+
+          Encoder<?> encoder = windowedValueEncoder(encoderFor(coder), windowEnc);
+          serializeAndDeserialize(windowed.get(0), (Encoder) encoder);
+
+          Dataset<?> dataset = createDataset(windowed, (Encoder) encoder);
+          assertThat(dataset.collect(), equalTo(windowed.toArray()));
+        });
+  }
+
+  @Test
+  public void testCollectionEncoder() {
+    BASIC_CASES.forEach(
+        (coder, data) -> {
+          Encoder<? extends Collection<?>> encoder = collectionEncoder(encoderFor(coder), true);
+          Collection<?> collection = Collections.unmodifiableCollection(data);
+
+          Dataset<Collection<?>> dataset = createDataset(asList(collection), (Encoder) encoder);
+          assertThat(dataset.head(), equalTo(data));
+        });
+  }
+
+  private void testMapEncoder(Class<?> cls, Function<Map<?, ?>, Map<?, ?>> decorator) {
+    BASIC_CASES.forEach(
+        (coder, data) -> {
+          Encoder<?> enc = encoderFor(coder);
+          Encoder<Map<?, ?>> mapEncoder = mapEncoder(enc, enc, (Class) cls);
+          Map<?, ?> map =
+              decorator.apply(
+                  data.stream().filter(notNull()).collect(toMap(identity(), identity())));
+
+          Dataset<Map<?, ?>> dataset = createDataset(asList(map), mapEncoder);
+          Map<?, ?> head = dataset.head();
+          assertThat(head, equalTo(map));
+          assertThat(head, instanceOf(cls));
+        });
+  }
+
+  @Test
+  public void testMapEncoder() {
+    testMapEncoder(Map.class, identity());
+  }
+
+  @Test
+  public void testHashMapEncoder() {
+    testMapEncoder(HashMap.class, identity());
+  }
+
+  @Test
+  public void testTreeMapEncoder() {
+    testMapEncoder(TreeMap.class, TreeMap::new);
+  }
+
+  @Test
+  public void testBeamBinaryEncoder() {
+    List<List<String>> data = asList(asList("a1", "a2", "a3"), asList("b1", "b2"), asList("c1"));
+
+    Encoder<List<String>> encoder = encoderFor(ListCoder.of(StringUtf8Coder.of()));
+    serializeAndDeserialize(data.get(0), encoder);
+
+    Dataset<List<String>> dataset = createDataset(data, encoder);
     assertThat(dataset.collect(), equalTo(data.toArray()));
   }
 
+  @Test
+  public void testEncoderForKVCoder() {
+    List<KV<Integer, String>> data =
+        asList(KV.of(1, "value1"), KV.of(null, "value2"), KV.of(3, null));
+
+    Encoder<KV<Integer, String>> encoder =
+        kvEncoder(encoderFor(VarIntCoder.of()), encoderFor(StringUtf8Coder.of()));
+    serializeAndDeserialize(data.get(0), encoder);
+
+    Dataset<KV<Integer, String>> dataset = createDataset(data, encoder);
+
+    StructType kvSchema =
+        createStructType(
+            new StructField[] {
+              createStructField("key", IntegerType, true),
+              createStructField("value", StringType, true)
+            });
+
+    assertThat(dataset.schema(), equalTo(kvSchema));
+    assertThat(dataset.collectAsList(), equalTo(data));
+  }
+
+  @Test
+  public void testOneOffEncoder() {
+    List<Coder<?>> coders = ImmutableList.copyOf(BASIC_CASES.keySet());
+    List<Encoder<?>> encoders = coders.stream().map(EncoderHelpers::encoderFor).collect(toList());
+
+    // build oneOf tuples of type index and corresponding value
+    List<Tuple2<Integer, ?>> data =
+        BASIC_CASES.entrySet().stream()
+            .map(e -> tuple(coders.indexOf(e.getKey()), (Object) e.getValue().get(0)))
+            .collect(toList());
+
+    // dataset is a sparse dataset with only one column set per row
+    Dataset<Tuple2<Integer, ?>> dataset = createDataset(data, oneOfEncoder((List) encoders));
+    assertThat(dataset.collectAsList(), equalTo(data));
+  }
+
+  // fix scale/precision to system default to compare using equals
+  private static BigDecimal bigDecimalOf(long l) {
+    DecimalType type = DecimalType.SYSTEM_DEFAULT();
+    return new BigDecimal(l, new MathContext(type.precision())).setScale(type.scale());
+  }
+
+  // test an explicit serialization roundtrip
+  private static <T> void serializeAndDeserialize(T data, Encoder<T> enc) {
+    ExpressionEncoder<T> bound = (ExpressionEncoder<T>) enc;
+    bound =
+        bound.resolveAndBind(bound.resolveAndBind$default$1(), bound.resolveAndBind$default$2());
+
+    InternalRow row = bound.createSerializer().apply(data);
+    T deserialized = bound.createDeserializer().apply(row);
+
+    assertThat(deserialized, equalTo(data));
+  }
+
   private static class PrivateString {
     private static final Coder<PrivateString> CODER =
         DelegateCoder.of(
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkCommonPipelineOptions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkCommonPipelineOptions.java
index 95e7865e6f9..461cfc80917 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkCommonPipelineOptions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkCommonPipelineOptions.java
@@ -50,6 +50,12 @@ public interface SparkCommonPipelineOptions
 
   void setCheckpointDir(String checkpointDir);
 
+  @Description("Batch default storage level")
+  @Default.String("MEMORY_ONLY")
+  String getStorageLevel();
+
+  void setStorageLevel(String storageLevel);
+
   @Description("Enable/disable sending aggregator values to Spark's metric sinks")
   @Default.Boolean(true)
   Boolean getEnableSparkMetricSinks();
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
index ecf933336c9..9a5229f21ae 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
@@ -34,12 +34,6 @@ public interface SparkPipelineOptions extends SparkCommonPipelineOptions {
 
   void setBatchIntervalMillis(Long batchInterval);
 
-  @Description("Batch default storage level")
-  @Default.String("MEMORY_ONLY")
-  String getStorageLevel();
-
-  void setStorageLevel(String storageLevel);
-
   @Description("Minimum time to spend on read, for each micro-batch.")
   @Default.Long(200)
   Long getMinReadTimeMillis();