Posted to commits@beam.apache.org by mm...@apache.org on 2022/07/22 08:11:55 UTC

[beam] branch master updated: Fixes #22156: Fix Spark3 runner to compile against Spark 3.2/3.3 and add version tests to verify compatibility going forward (#22157)

This is an automated email from the ASF dual-hosted git repository.

mmack pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 72127f93f45 Fixes #22156: Fix Spark3 runner to compile against Spark 3.2/3.3 and add version tests to verify compatibility going forward (#22157)
72127f93f45 is described below

commit 72127f93f45229cdb62117635563cd8e1709b94b
Author: Moritz Mack <mm...@talend.com>
AuthorDate: Fri Jul 22 10:11:48 2022 +0200

    Fixes #22156: Fix Spark3 runner to compile against Spark 3.2/3.3 and add version tests to verify compatibility going forward (#22157)
---
 .../job_PreCommit_Java_Spark3_Versions.groovy      |  34 +--
 build.gradle.kts                                   |   1 +
 .../translation/helpers/EncoderFactory.java        |  43 ++--
 runners/spark/3/build.gradle                       |  34 +++
 .../translation/batch/DatasetSourceBatch.java      |   2 +-
 .../translation/helpers/EncoderFactory.java        |  59 +++--
 runners/spark/spark_runner.gradle                  |  48 ++--
 .../runners/spark/metrics/AggregatorMetric.java    |  47 +++-
 .../beam/runners/spark/metrics/BeamMetricSet.java  |  59 +++++
 .../runners/spark/metrics/SparkBeamMetric.java     |  93 +++++---
 .../runners/spark/metrics/WithMetricsSupport.java  | 123 ++--------
 .../beam/runners/spark/metrics/sink/CsvSink.java   |  59 ++++-
 .../runners/spark/metrics/sink/GraphiteSink.java   |  66 +++++-
 .../SparkStructuredStreamingPipelineOptions.java   |   7 +
 .../SparkStructuredStreamingPipelineResult.java    |  39 +--
 .../SparkStructuredStreamingRunner.java            |   9 +-
 .../metrics/AggregatorMetric.java                  |  47 +++-
 .../structuredstreaming/metrics/BeamMetricSet.java |  60 +++++
 .../metrics/SparkBeamMetric.java                   |  86 ++++---
 .../metrics/WithMetricsSupport.java                | 127 ++--------
 .../metrics/sink/CodahaleCsvSink.java              |  58 ++++-
 .../metrics/sink/CodahaleGraphiteSink.java         |  61 ++++-
 .../translation/AbstractTranslationContext.java    |  25 +-
 .../translation/SparkSessionFactory.java           |  71 ++++++
 .../translation/helpers/EncoderHelpers.java        | 261 +++------------------
 .../aggregators/metrics/sink/InMemoryMetrics.java  |  33 +--
 .../metrics/sink/SparkMetricsSinkTest.java         |   6 +-
 .../runners/spark/metrics/SparkBeamMetricTest.java |  18 +-
 .../structuredstreaming/SparkSessionRule.java      |  21 +-
 .../aggregators/metrics/sink/InMemoryMetrics.java  |  36 ++-
 .../metrics/sink/SparkMetricsSinkTest.java         |  45 ++--
 ...eamMetricTest.java => SparkBeamMetricTest.java} |  22 +-
 .../translation/helpers/EncoderHelpersTest.java    |  59 ++++-
 .../spark/src/test/resources/metrics.properties    |  68 ------
 .../site/content/en/documentation/runners/spark.md |   6 +-
 35 files changed, 990 insertions(+), 843 deletions(-)

diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineOptions.java b/.test-infra/jenkins/job_PreCommit_Java_Spark3_Versions.groovy
similarity index 56%
copy from runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineOptions.java
copy to .test-infra/jenkins/job_PreCommit_Java_Spark3_Versions.groovy
index bc585d8a31e..f13c4c0a1e2 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineOptions.java
+++ b/.test-infra/jenkins/job_PreCommit_Java_Spark3_Versions.groovy
@@ -15,21 +15,23 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.beam.runners.spark.structuredstreaming;
 
-import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
-import org.apache.beam.sdk.options.Default;
-import org.apache.beam.sdk.options.PipelineOptions;
+import PrecommitJobBuilder
 
-/**
- * Spark runner {@link PipelineOptions} handles Spark execution-related configurations, such as the
- * master address, and other user-related knobs.
- */
-public interface SparkStructuredStreamingPipelineOptions extends SparkCommonPipelineOptions {
-
-  /** Set to true to run the job in test mode. */
-  @Default.Boolean(false)
-  boolean getTestMode();
-
-  void setTestMode(boolean testMode);
-}
+PrecommitJobBuilder builder = new PrecommitJobBuilder(
+    scope: this,
+    nameBase: 'Java_Spark3_Versions',
+    gradleTask: ':runners:spark:3:sparkVersionsTest',
+    gradleSwitches: [
+      '-PdisableSpotlessCheck=true'
+    ], // spotless checked in separate pre-commit
+    triggerPathPatterns: [
+      '^runners/spark/.*$',
+    ],
+    timeoutMins: 120,
+    )
+builder.build {
+  publishers {
+    archiveJunit('**/build/test-results/**/*.xml')
+  }
+}
\ No newline at end of file
diff --git a/build.gradle.kts b/build.gradle.kts
index 7ea18895f77..fe92d64e32b 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -220,6 +220,7 @@ tasks.register("javaHadoopVersionsTest") {
   dependsOn(":sdks:java:io:parquet:hadoopVersionsTest")
   dependsOn(":sdks:java:extensions:sorter:hadoopVersionsTest")
   dependsOn(":runners:spark:2:hadoopVersionsTest")
+  dependsOn(":runners:spark:3:hadoopVersionsTest")
 }
 
 tasks.register("sqlPostCommit") {
diff --git a/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java b/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java
index 325d15075b6..54b400f08d0 100644
--- a/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java
+++ b/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java
@@ -17,38 +17,35 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
 
-import static org.apache.spark.sql.types.DataTypes.BinaryType;
-
-import java.util.Collections;
-import java.util.List;
-import org.apache.beam.sdk.coders.Coder;
 import org.apache.spark.sql.Encoder;
-import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal;
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
-import org.apache.spark.sql.catalyst.expressions.BoundReference;
-import org.apache.spark.sql.catalyst.expressions.Cast;
 import org.apache.spark.sql.catalyst.expressions.Expression;
-import org.apache.spark.sql.types.ObjectType;
-import scala.collection.JavaConversions;
-import scala.reflect.ClassTag;
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke;
+import org.apache.spark.sql.types.DataType;
+import scala.collection.immutable.List;
+import scala.collection.immutable.Nil$;
+import scala.collection.mutable.WrappedArray;
 import scala.reflect.ClassTag$;
 
 public class EncoderFactory {
 
-  public static <T> Encoder<T> fromBeamCoder(Coder<T> coder) {
-    Class<? super T> clazz = coder.getEncodedTypeDescriptor().getRawType();
-    ClassTag<T> classTag = ClassTag$.MODULE$.apply(clazz);
-    List<Expression> serializers =
-        Collections.singletonList(
-            new EncoderHelpers.EncodeUsingBeamCoder<>(
-                new BoundReference(0, new ObjectType(clazz), true), coder));
-
+  static <T> Encoder<T> create(
+      Expression serializer, Expression deserializer, Class<? super T> clazz) {
+    // TODO Isolate usage of Scala APIs in utility https://github.com/apache/beam/issues/22382
+    List<Expression> serializers = Nil$.MODULE$.$colon$colon(serializer);
     return new ExpressionEncoder<>(
         SchemaHelpers.binarySchema(),
         false,
-        JavaConversions.collectionAsScalaIterable(serializers).toSeq(),
-        new EncoderHelpers.DecodeUsingBeamCoder<>(
-            new Cast(new GetColumnByOrdinal(0, BinaryType), BinaryType), classTag, coder),
-        classTag);
+        serializers,
+        deserializer,
+        ClassTag$.MODULE$.apply(clazz));
+  }
+
+  /**
+   * Invoke method {@code fun} on Class {@code cls}, immediately propagating {@code null} if any
+   * input arg is {@code null}.
+   */
+  static Expression invokeIfNotNull(Class<?> cls, String fun, DataType type, Expression... args) {
+    return new StaticInvoke(cls, type, fun, new WrappedArray.ofRef<>(args), true, true);
   }
 }
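
The new create() builds its single-element serializer list by prepending onto Scala's empty list (Nil.::) instead of going through the removed JavaConversions helper. A minimal, self-contained sketch of that Scala-from-Java idiom, with a plain String standing in for the Catalyst Expression (class and value names are made up):

    import scala.collection.immutable.List;
    import scala.collection.immutable.Nil$;

    public class ScalaListFromJavaSketch {
      public static void main(String[] args) {
        // Nil.::(x) prepends x to the empty list, yielding List(x); this is the same
        // call chain EncoderFactory.create uses to wrap the serializer expression.
        List<String> single = Nil$.MODULE$.$colon$colon("serializer-placeholder");
        System.out.println(single); // prints: List(serializer-placeholder)
      }
    }
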
diff --git a/runners/spark/3/build.gradle b/runners/spark/3/build.gradle
index 1641dc5e01f..2bb5d67eef5 100644
--- a/runners/spark/3/build.gradle
+++ b/runners/spark/3/build.gradle
@@ -28,3 +28,37 @@ project.ext {
 
 // Load the main build script which contains all build logic.
 apply from: "$basePath/spark_runner.gradle"
+
+
+def sparkVersions = [
+    "330": "3.3.0",
+    "321": "3.2.1"
+]
+
+sparkVersions.each { kv ->
+  configurations.create("sparkVersion$kv.key")
+  configurations."sparkVersion$kv.key" {
+    resolutionStrategy {
+      spark.components.each { component -> force "$component:$kv.value" }
+    }
+  }
+
+  dependencies {
+    spark.components.each { component -> "sparkVersion$kv.key" "$component:$kv.value" }
+  }
+
+  tasks.register("sparkVersion${kv.key}Test", Test) {
+    group = "Verification"
+    description = "Verifies code compatibility with Spark $kv.value"
+    classpath = configurations."sparkVersion$kv.key" + sourceSets.test.runtimeClasspath
+    systemProperties test.systemProperties
+
+    include "**/*.class"
+    maxParallelForks 4
+  }
+}
+
+tasks.register("sparkVersionsTest") {
+  group = "Verification"
+  dependsOn sparkVersions.collect{k,v -> "sparkVersion${k}Test"}
+}
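
With this in place each declared Spark release gets its own test task, so a single version can be checked in isolation (for example ./gradlew :runners:spark:3:sparkVersion321Test, assuming the task names produced by the loop above), while :runners:spark:3:sparkVersionsTest, which the new Jenkins job above invokes, runs the whole matrix.
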
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DatasetSourceBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DatasetSourceBatch.java
index f2fd80005fa..46bde96c30c 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DatasetSourceBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DatasetSourceBatch.java
@@ -34,8 +34,8 @@ import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.Sch
 import org.apache.beam.sdk.io.BoundedSource;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
-import org.apache.parquet.Strings;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.connector.catalog.SupportsRead;
 import org.apache.spark.sql.connector.catalog.Table;
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java
index 39a71507453..c7d69c0b8ad 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java
@@ -17,33 +17,48 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
 
-import static org.apache.spark.sql.types.DataTypes.BinaryType;
-
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.spark.sql.Encoder;
-import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal;
+import java.lang.reflect.Constructor;
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
-import org.apache.spark.sql.catalyst.expressions.BoundReference;
-import org.apache.spark.sql.catalyst.expressions.Cast;
 import org.apache.spark.sql.catalyst.expressions.Expression;
-import org.apache.spark.sql.types.ObjectType;
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke;
+import org.apache.spark.sql.types.DataType;
+import scala.collection.immutable.Nil$;
+import scala.collection.mutable.WrappedArray;
 import scala.reflect.ClassTag;
-import scala.reflect.ClassTag$;
 
 public class EncoderFactory {
+  // default constructor to reflectively create static invoke expressions
+  private static final Constructor<StaticInvoke> STATIC_INVOKE_CONSTRUCTOR =
+      (Constructor<StaticInvoke>) StaticInvoke.class.getConstructors()[0];
+
+  static <T> ExpressionEncoder<T> create(
+      Expression serializer, Expression deserializer, Class<? super T> clazz) {
+    return new ExpressionEncoder<>(serializer, deserializer, ClassTag.apply(clazz));
+  }
 
-  public static <T> Encoder<T> fromBeamCoder(Coder<T> coder) {
-    Class<? super T> clazz = coder.getEncodedTypeDescriptor().getRawType();
-    ClassTag<T> classTag = ClassTag$.MODULE$.apply(clazz);
-    Expression serializer =
-        new EncoderHelpers.EncodeUsingBeamCoder<>(
-            new BoundReference(0, new ObjectType(clazz), true), coder);
-    Expression deserializer =
-        new EncoderHelpers.DecodeUsingBeamCoder<>(
-            new Cast(
-                new GetColumnByOrdinal(0, BinaryType), BinaryType, scala.Option.<String>empty()),
-            classTag,
-            coder);
-    return new ExpressionEncoder<>(serializer, deserializer, classTag);
+  /**
+   * Invoke method {@code fun} on Class {@code cls}, immediately propagating {@code null} if any
+   * input arg is {@code null}.
+   *
+   * <p>To address breaking interfaces between various versions of Spark 3, these are created
+   * reflectively. This is fine as it's just needed once to create the query plan.
+   */
+  static Expression invokeIfNotNull(Class<?> cls, String fun, DataType type, Expression... args) {
+    try {
+      switch (STATIC_INVOKE_CONSTRUCTOR.getParameterCount()) {
+        case 6:
+          // Spark 3.1.x
+          return STATIC_INVOKE_CONSTRUCTOR.newInstance(
+              cls, type, fun, new WrappedArray.ofRef<>(args), true, true);
+        case 8:
+          // Spark 3.2.x, 3.3.x
+          return STATIC_INVOKE_CONSTRUCTOR.newInstance(
+              cls, type, fun, new WrappedArray.ofRef<>(args), Nil$.MODULE$, true, true, true);
+        default:
+          throw new RuntimeException("Unsupported version of Spark");
+      }
+    } catch (IllegalArgumentException | ReflectiveOperationException ex) {
+      throw new RuntimeException(ex);
+    }
   }
 }
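
The constructor-arity switch is the actual compatibility mechanism: StaticInvoke gained additional parameters in Spark 3.2, so the factory picks the sole public constructor at runtime rather than linking against one signature at compile time. A hedged, self-contained sketch of the same dispatch pattern, using a made-up Widget class instead of Catalyst's StaticInvoke:

    import java.lang.reflect.Constructor;

    public class ConstructorAritySketch {
      // Hypothetical class whose constructor arity changes between library versions.
      static class Widget {
        Widget(String name, int size) {}
      }

      @SuppressWarnings("unchecked")
      static Widget newWidget(String name, int size) {
        // Only one declared constructor is expected, as with StaticInvoke above.
        Constructor<Widget> ctor =
            (Constructor<Widget>) Widget.class.getDeclaredConstructors()[0];
        try {
          switch (ctor.getParameterCount()) {
            case 2: // "older" two-argument signature
              return ctor.newInstance(name, size);
            case 3: // "newer" signature with an extra flag, defaulted to true
              return ctor.newInstance(name, size, true);
            default:
              throw new IllegalStateException("Unsupported Widget version");
          }
        } catch (ReflectiveOperationException ex) {
          throw new RuntimeException(ex);
        }
      }

      public static void main(String[] args) {
        System.out.println(newWidget("example", 1));
      }
    }
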
diff --git a/runners/spark/spark_runner.gradle b/runners/spark/spark_runner.gradle
index 2fac13b0bf7..78b4cdb4521 100644
--- a/runners/spark/spark_runner.gradle
+++ b/runners/spark/spark_runner.gradle
@@ -17,7 +17,6 @@
  */
 
 import groovy.json.JsonOutput
-import java.util.stream.Collectors
 
 apply plugin: 'org.apache.beam.module'
 applyJavaNature(
@@ -123,6 +122,19 @@ test {
   if(project.hasProperty("rerun-tests")) { 	outputs.upToDateWhen {false} }
 }
 
+class SparkComponents {
+  List<String> components
+}
+
+extensions.create('spark', SparkComponents)
+spark.components = [
+    "org.apache.spark:spark-core_$spark_scala_version",
+    "org.apache.spark:spark-network-common_$spark_scala_version",
+    "org.apache.spark:spark-sql_$spark_scala_version",
+    "org.apache.spark:spark-streaming_$spark_scala_version",
+    "org.apache.spark:spark-catalyst_$spark_scala_version"
+]
+
 dependencies {
   implementation project(path: ":model:pipeline", configuration: "shadow")
   implementation project(path: ":sdks:java:core", configuration: "shadow")
@@ -140,13 +152,11 @@ dependencies {
   implementation project(":sdks:java:fn-execution")
   implementation library.java.vendored_grpc_1_43_2
   implementation library.java.vendored_guava_26_0_jre
-  implementation "com.codahale.metrics:metrics-core:3.0.1"
-  provided "org.apache.spark:spark-core_$spark_scala_version:$spark_version"
-  provided "org.apache.spark:spark-network-common_$spark_scala_version:$spark_version"
+  implementation "io.dropwizard.metrics:metrics-core:3.1.5" // version used by Spark 2.4
+  spark.components.each { component ->
+    provided "$component:$spark_version"
+  }
   permitUnusedDeclared "org.apache.spark:spark-network-common_$spark_scala_version:$spark_version"
-  provided "org.apache.spark:spark-sql_$spark_scala_version:$spark_version"
-  provided "org.apache.spark:spark-streaming_$spark_scala_version:$spark_version"
-  provided "org.apache.spark:spark-catalyst_$spark_scala_version:$spark_version"
   if (project.property("spark_scala_version").equals("2.11")) {
     compileOnly "org.scala-lang:scala-library:2.11.12"
     runtimeOnly library.java.jackson_module_scala_2_11
@@ -154,19 +164,15 @@ dependencies {
     compileOnly "org.scala-lang:scala-library:2.12.15"
     runtimeOnly library.java.jackson_module_scala_2_12
   }
-  if (project.property("spark_version").equals("3.1.2")) {
-    compileOnly "org.apache.parquet:parquet-common:1.10.1"
-  }
   // Force paranamer 2.8 to avoid issues when using Scala 2.12
   runtimeOnly "com.thoughtworks.paranamer:paranamer:2.8"
   provided library.java.hadoop_common
   provided library.java.commons_io
   provided library.java.hamcrest
   provided "com.esotericsoftware:kryo-shaded:4.0.2"
-  testImplementation "org.apache.spark:spark-core_$spark_scala_version:$spark_version"
-  testImplementation "org.apache.spark:spark-network-common_$spark_scala_version:$spark_version"
-  testImplementation "org.apache.spark:spark-sql_$spark_scala_version:$spark_version"
-  testImplementation "org.apache.spark:spark-streaming_$spark_scala_version:$spark_version"
+  spark.components.each { component ->
+    testImplementation "$component:$spark_version"
+  }
   testImplementation project(":sdks:java:io:kafka")
   testImplementation project(path: ":sdks:java:core", configuration: "shadowTest")
   // SparkStateInternalsTest extends abstract StateInternalsTest
@@ -194,21 +200,15 @@ dependencies {
 def gcpProject = project.findProperty('gcpProject') ?: 'apache-beam-testing'
 def tempLocation = project.findProperty('tempLocation') ?: 'gs://temp-storage-for-end-to-end-tests'
 
-configurations.testRuntimeClasspath {
-  // Testing the Spark runner causes a StackOverflowError if slf4j-jdk14 is on the classpath
+configurations.all {
+  // Prevent StackOverflowError if slf4j-jdk14 is on the classpath
   exclude group: "org.slf4j", module: "slf4j-jdk14"
+  // Avoid any transitive usage of the old codahale group to make dependency resolution deterministic
+  exclude group: "com.codahale.metrics", module: "metrics-core"
 }
 
-configurations.validatesRunner {
-  // Testing the Spark runner causes a StackOverflowError if slf4j-jdk14 is on the classpath
-  exclude group: "org.slf4j", module: "slf4j-jdk14"
-}
-
-
 hadoopVersions.each { kv ->
   configurations."hadoopVersion$kv.key" {
-    // Testing the Spark runner causes a StackOverflowError if slf4j-jdk14 is on the classpath
-    exclude group: "org.slf4j", module: "slf4j-jdk14"
     resolutionStrategy {
       force "org.apache.hadoop:hadoop-common:$kv.value"
     }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/AggregatorMetric.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/AggregatorMetric.java
index 18a3785ae3d..41db37c92af 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/AggregatorMetric.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/AggregatorMetric.java
@@ -17,23 +17,58 @@
  */
 package org.apache.beam.runners.spark.metrics;
 
+import com.codahale.metrics.Gauge;
 import com.codahale.metrics.Metric;
+import com.codahale.metrics.MetricFilter;
+import java.util.HashMap;
+import java.util.Map;
 import org.apache.beam.runners.spark.aggregators.NamedAggregators;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
-/** An adapter between the {@link NamedAggregators} and Codahale's {@link Metric} interface. */
-public class AggregatorMetric implements Metric {
+/** An adapter between the {@link NamedAggregators} and the Dropwizard {@link Metric} interface. */
+public class AggregatorMetric extends BeamMetricSet {
+
+  private static final Logger LOG = LoggerFactory.getLogger(AggregatorMetric.class);
 
   private final NamedAggregators namedAggregators;
 
-  private AggregatorMetric(final NamedAggregators namedAggregators) {
+  private AggregatorMetric(NamedAggregators namedAggregators) {
     this.namedAggregators = namedAggregators;
   }
 
-  public static AggregatorMetric of(final NamedAggregators namedAggregators) {
+  public static AggregatorMetric of(NamedAggregators namedAggregators) {
     return new AggregatorMetric(namedAggregators);
   }
 
-  NamedAggregators getNamedAggregators() {
-    return namedAggregators;
+  @Override
+  public Map<String, Gauge<Double>> getValue(String prefix, MetricFilter filter) {
+    Map<String, Gauge<Double>> metrics = new HashMap<>();
+    for (Map.Entry<String, ?> entry : namedAggregators.renderAll().entrySet()) {
+      String name = prefix + "." + entry.getKey();
+      Object rawValue = entry.getValue();
+      if (rawValue != null) {
+        try {
+          Gauge<Double> gauge = staticGauge(rawValue);
+          if (filter.matches(name, gauge)) {
+            metrics.put(name, gauge);
+          }
+        } catch (NumberFormatException e) {
+          LOG.warn(
+              "Metric `{}` of type {} can't be reported, conversion to double failed.",
+              name,
+              rawValue.getClass().getSimpleName(),
+              e);
+        }
+      }
+    }
+    return metrics;
+  }
+
+  // Metric type is assumed to be compatible with Double
+  protected Gauge<Double> staticGauge(Object rawValue) throws NumberFormatException {
+    return rawValue instanceof Number
+        ? super.staticGauge((Number) rawValue)
+        : super.staticGauge(Double.parseDouble(rawValue.toString()));
   }
 }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/BeamMetricSet.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/BeamMetricSet.java
new file mode 100644
index 00000000000..2e2970fdc7f
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/BeamMetricSet.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.metrics;
+
+import com.codahale.metrics.Gauge;
+import com.codahale.metrics.MetricFilter;
+import java.util.Map;
+
+/**
+ * {@link BeamMetricSet} is a {@link Gauge} that returns a map of multiple metrics which get
+ * flattened in {@link WithMetricsSupport#getGauges()} for usage in {@link
+ * org.apache.spark.metrics.sink.Sink Spark metric sinks}.
+ *
+ * <p>Note: Recent versions of Dropwizard {@link com.codahale.metrics.MetricRegistry MetricRegistry}
+ * do not allow registering arbitrary implementations of {@link com.codahale.metrics.Metric Metrics}
+ * and require usage of {@link Gauge} here.
+ */
+// TODO: turn into MetricRegistry https://github.com/apache/beam/issues/22384
+abstract class BeamMetricSet implements Gauge<Map<String, Gauge<Double>>> {
+
+  @Override
+  public final Map<String, Gauge<Double>> getValue() {
+    return getValue("", MetricFilter.ALL);
+  }
+
+  protected abstract Map<String, Gauge<Double>> getValue(String prefix, MetricFilter filter);
+
+  protected Gauge<Double> staticGauge(Number number) {
+    return new ConstantGauge(number.doubleValue());
+  }
+
+  private static class ConstantGauge implements Gauge<Double> {
+    private final double value;
+
+    ConstantGauge(double value) {
+      this.value = value;
+    }
+
+    @Override
+    public Double getValue() {
+      return value;
+    }
+  }
+}
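
Recent Dropwizard registries reject arbitrary Metric implementations, which is why the Beam metrics are now exposed as a single Gauge whose value is itself a map of named gauges. A hedged sketch of that shape with a hypothetical constant value, registered in a plain MetricRegistry:

    import com.codahale.metrics.Gauge;
    import com.codahale.metrics.MetricRegistry;
    import java.util.Collections;
    import java.util.Map;

    public class NestedGaugeSketch {
      public static void main(String[] args) {
        MetricRegistry registry = new MetricRegistry();
        // A gauge-of-gauges, analogous to BeamMetricSet (names and value made up).
        Gauge<Double> constant = () -> 42.0;
        Gauge<Map<String, Gauge<Double>>> nested =
            () -> Collections.singletonMap("elements.count", constant);
        registry.register("beam", nested);
        // A decorator like WithMetricsSupport flattens such entries into
        // "beam.elements.count" when a Spark sink polls getGauges().
        System.out.println(registry.getGauges().get("beam").getValue());
      }
    }
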
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkBeamMetric.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkBeamMetric.java
index 298db0fc68a..1eb8349a513 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkBeamMetric.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkBeamMetric.java
@@ -17,13 +17,17 @@
  */
 package org.apache.beam.runners.spark.metrics;
 
-import static java.util.stream.Collectors.toList;
 import static org.apache.beam.runners.core.metrics.MetricsContainerStepMap.asAttemptedOnlyMetricResults;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Predicates.not;
 
+import com.codahale.metrics.Gauge;
 import com.codahale.metrics.Metric;
-import java.util.ArrayList;
+import com.codahale.metrics.MetricFilter;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import javax.annotation.Nullable;
 import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
 import org.apache.beam.sdk.metrics.DistributionResult;
 import org.apache.beam.sdk.metrics.GaugeResult;
@@ -33,61 +37,72 @@ import org.apache.beam.sdk.metrics.MetricQueryResults;
 import org.apache.beam.sdk.metrics.MetricResult;
 import org.apache.beam.sdk.metrics.MetricResults;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Streams;
 
 /**
- * An adapter between the {@link MetricsContainerStepMap} and Codahale's {@link Metric} interface.
+ * An adapter between the {@link MetricsContainerStepMap} and the Dropwizard {@link Metric}
+ * interface.
  */
-public class SparkBeamMetric implements Metric {
+class SparkBeamMetric extends BeamMetricSet {
+
   private static final String ILLEGAL_CHARACTERS = "[^A-Za-z0-9-]";
 
-  static Map<String, ?> renderAll(MetricResults metricResults) {
-    Map<String, Object> metrics = new HashMap<>();
-    MetricQueryResults metricQueryResults = metricResults.allMetrics();
-    for (MetricResult<Long> metricResult : metricQueryResults.getCounters()) {
-      metrics.put(renderName(metricResult), metricResult.getAttempted());
+  @Override
+  public Map<String, Gauge<Double>> getValue(String prefix, MetricFilter filter) {
+    MetricResults metricResults =
+        asAttemptedOnlyMetricResults(MetricsAccumulator.getInstance().value());
+    Map<String, Gauge<Double>> metrics = new HashMap<>();
+    MetricQueryResults allMetrics = metricResults.allMetrics();
+    for (MetricResult<Long> metricResult : allMetrics.getCounters()) {
+      putFiltered(metrics, filter, renderName(prefix, metricResult), metricResult.getAttempted());
     }
-    for (MetricResult<DistributionResult> metricResult : metricQueryResults.getDistributions()) {
+    for (MetricResult<DistributionResult> metricResult : allMetrics.getDistributions()) {
       DistributionResult result = metricResult.getAttempted();
-      metrics.put(renderName(metricResult) + ".count", result.getCount());
-      metrics.put(renderName(metricResult) + ".sum", result.getSum());
-      metrics.put(renderName(metricResult) + ".min", result.getMin());
-      metrics.put(renderName(metricResult) + ".max", result.getMax());
-      metrics.put(renderName(metricResult) + ".mean", result.getMean());
+      String baseName = renderName(prefix, metricResult);
+      putFiltered(metrics, filter, baseName + ".count", result.getCount());
+      putFiltered(metrics, filter, baseName + ".sum", result.getSum());
+      putFiltered(metrics, filter, baseName + ".min", result.getMin());
+      putFiltered(metrics, filter, baseName + ".max", result.getMax());
+      putFiltered(metrics, filter, baseName + ".mean", result.getMean());
     }
-    for (MetricResult<GaugeResult> metricResult : metricQueryResults.getGauges()) {
-      metrics.put(renderName(metricResult), metricResult.getAttempted().getValue());
+    for (MetricResult<GaugeResult> metricResult : allMetrics.getGauges()) {
+      putFiltered(
+          metrics,
+          filter,
+          renderName(prefix, metricResult),
+          metricResult.getAttempted().getValue());
     }
     return metrics;
   }
 
-  Map<String, ?> renderAll() {
-    MetricResults metricResults =
-        asAttemptedOnlyMetricResults(MetricsAccumulator.getInstance().value());
-    return renderAll(metricResults);
-  }
-
   @VisibleForTesting
-  static String renderName(MetricResult<?> metricResult) {
+  @SuppressWarnings("nullness") // ok to have nullable elements on stream
+  static String renderName(String prefix, MetricResult<?> metricResult) {
     MetricKey key = metricResult.getKey();
     MetricName name = key.metricName();
     String step = key.stepName();
+    return Streams.concat(
+            Stream.of(prefix),
+            Stream.of(stripSuffix(normalizePart(step))),
+            Stream.of(name.getNamespace(), name.getName()).map(SparkBeamMetric::normalizePart))
+        .filter(not(Strings::isNullOrEmpty))
+        .collect(Collectors.joining("."));
+  }
 
-    ArrayList<String> pieces = new ArrayList<>();
-
-    if (step != null) {
-      step = step.replaceAll(ILLEGAL_CHARACTERS, "_");
-      if (step.endsWith("_")) {
-        step = step.substring(0, step.length() - 1);
-      }
-      pieces.add(step);
-    }
+  private static @Nullable String normalizePart(@Nullable String str) {
+    return str != null ? str.replaceAll(ILLEGAL_CHARACTERS, "_") : null;
+  }
 
-    pieces.addAll(
-        ImmutableList.of(name.getNamespace(), name.getName()).stream()
-            .map(str -> str.replaceAll(ILLEGAL_CHARACTERS, "_"))
-            .collect(toList()));
+  private static @Nullable String stripSuffix(@Nullable String str) {
+    return str != null && str.endsWith("_") ? str.substring(0, str.length() - 1) : str;
+  }
 
-    return String.join(".", pieces);
+  private void putFiltered(
+      Map<String, Gauge<Double>> metrics, MetricFilter filter, String name, Number value) {
+    Gauge<Double> metric = staticGauge(value);
+    if (filter.matches(name, metric)) {
+      metrics.put(name, metric);
+    }
   }
 }
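
The reworked renderName now also folds in the registration prefix and applies the same sanitation to every part: characters outside [A-Za-z0-9-] become underscores and a trailing underscore on the step name is trimmed. A standalone re-implementation of just that naming rule, with made-up metric identifiers, shows the resulting convention:

    import java.util.stream.Collectors;
    import java.util.stream.Stream;

    public class MetricNameSketch {
      private static final String ILLEGAL_CHARACTERS = "[^A-Za-z0-9-]";

      public static void main(String[] args) {
        // Hypothetical inputs: registry prefix, step name, metric namespace and name.
        String prefix = "beam";
        String step = "source/Read.out0";
        String namespace = "org.example.pipeline";
        String name = "elementCount";
        String rendered =
            Stream.of(prefix, stripSuffix(normalize(step)), normalize(namespace), normalize(name))
                .filter(s -> s != null && !s.isEmpty())
                .collect(Collectors.joining("."));
        System.out.println(rendered); // beam.source_Read_out0.org_example_pipeline.elementCount
      }

      private static String normalize(String s) {
        return s.replaceAll(ILLEGAL_CHARACTERS, "_");
      }

      private static String stripSuffix(String s) {
        return s.endsWith("_") ? s.substring(0, s.length() - 1) : s;
      }
    }
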
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/WithMetricsSupport.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/WithMetricsSupport.java
index 1d551f0b12e..a0fc714100a 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/WithMetricsSupport.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/WithMetricsSupport.java
@@ -21,24 +21,13 @@ import com.codahale.metrics.Counter;
 import com.codahale.metrics.Gauge;
 import com.codahale.metrics.Histogram;
 import com.codahale.metrics.Meter;
-import com.codahale.metrics.Metric;
 import com.codahale.metrics.MetricFilter;
 import com.codahale.metrics.MetricRegistry;
 import com.codahale.metrics.Timer;
-import java.util.HashMap;
 import java.util.Map;
 import java.util.SortedMap;
-import org.apache.beam.runners.spark.aggregators.NamedAggregators;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Function;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Predicate;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Predicates;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.FluentIterable;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSortedMap;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Ordering;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 /**
  * A {@link MetricRegistry} decorator-like that supports {@link AggregatorMetric} and {@link
@@ -47,15 +36,9 @@ import org.slf4j.LoggerFactory;
  * <p>{@link MetricRegistry} is not an interface, so this is not a by-the-book decorator. That said,
  * it delegates all metric related getters to the "decorated" instance.
  */
-@SuppressWarnings({
-  "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
-  "keyfor",
-  "nullness"
-}) // TODO(https://github.com/apache/beam/issues/20497)
+@SuppressWarnings({"rawtypes"}) // required by interface
 public class WithMetricsSupport extends MetricRegistry {
 
-  private static final Logger LOG = LoggerFactory.getLogger(WithMetricsSupport.class);
-
   private final MetricRegistry internalMetricRegistry;
 
   private WithMetricsSupport(final MetricRegistry internalMetricRegistry) {
@@ -88,95 +71,21 @@ public class WithMetricsSupport extends MetricRegistry {
 
   @Override
   public SortedMap<String, Gauge> getGauges(final MetricFilter filter) {
-    return new ImmutableSortedMap.Builder<String, Gauge>(
-            Ordering.from(String.CASE_INSENSITIVE_ORDER))
-        .putAll(internalMetricRegistry.getGauges(filter))
-        .putAll(extractGauges(internalMetricRegistry, filter))
-        .build();
-  }
-
-  private Map<String, Gauge> extractGauges(
-      final MetricRegistry metricRegistry, final MetricFilter filter) {
-    Map<String, Gauge> gauges = new HashMap<>();
-
-    // find the AggregatorMetric metrics from within all currently registered metrics
-    final Optional<Map<String, Gauge>> aggregatorMetrics =
-        FluentIterable.from(metricRegistry.getMetrics().entrySet())
-            .firstMatch(isAggregatorMetric())
-            .transform(aggregatorMetricToGauges());
-
-    // find the SparkBeamMetric metrics from within all currently registered metrics
-    final Optional<Map<String, Gauge>> beamMetrics =
-        FluentIterable.from(metricRegistry.getMetrics().entrySet())
-            .firstMatch(isSparkBeamMetric())
-            .transform(beamMetricToGauges());
-
-    if (aggregatorMetrics.isPresent()) {
-      gauges.putAll(Maps.filterEntries(aggregatorMetrics.get(), matches(filter)));
-    }
-
-    if (beamMetrics.isPresent()) {
-      gauges.putAll(Maps.filterEntries(beamMetrics.get(), matches(filter)));
-    }
-
-    return gauges;
-  }
-
-  private Function<Map.Entry<String, Metric>, Map<String, Gauge>> aggregatorMetricToGauges() {
-    return entry -> {
-      final NamedAggregators agg = ((AggregatorMetric) entry.getValue()).getNamedAggregators();
-      final String parentName = entry.getKey();
-      final Map<String, Gauge> gaugeMap = Maps.transformEntries(agg.renderAll(), toGauge());
-      final Map<String, Gauge> fullNameGaugeMap = Maps.newLinkedHashMap();
-      for (Map.Entry<String, Gauge> gaugeEntry : gaugeMap.entrySet()) {
-        fullNameGaugeMap.put(parentName + "." + gaugeEntry.getKey(), gaugeEntry.getValue());
+    ImmutableSortedMap.Builder<String, Gauge> builder =
+        new ImmutableSortedMap.Builder<>(Ordering.from(String.CASE_INSENSITIVE_ORDER));
+
+    Map<String, Gauge> gauges =
+        internalMetricRegistry.getGauges(
+            (n, m) -> filter.matches(n, m) || m instanceof BeamMetricSet);
+
+    for (Map.Entry<String, Gauge> entry : gauges.entrySet()) {
+      Gauge gauge = entry.getValue();
+      if (gauge instanceof BeamMetricSet) {
+        builder.putAll(((BeamMetricSet) gauge).getValue(entry.getKey(), filter));
+      } else {
+        builder.put(entry.getKey(), gauge);
       }
-      return Maps.filterValues(fullNameGaugeMap, Predicates.notNull());
-    };
-  }
-
-  private Function<Map.Entry<String, Metric>, Map<String, Gauge>> beamMetricToGauges() {
-    return entry -> {
-      final Map<String, ?> metrics = ((SparkBeamMetric) entry.getValue()).renderAll();
-      final String parentName = entry.getKey();
-      final Map<String, Gauge> gaugeMap = Maps.transformEntries(metrics, toGauge());
-      final Map<String, Gauge> fullNameGaugeMap = Maps.newLinkedHashMap();
-      for (Map.Entry<String, Gauge> gaugeEntry : gaugeMap.entrySet()) {
-        fullNameGaugeMap.put(parentName + "." + gaugeEntry.getKey(), gaugeEntry.getValue());
-      }
-      return Maps.filterValues(fullNameGaugeMap, Predicates.notNull());
-    };
-  }
-
-  private Maps.EntryTransformer<String, Object, Gauge> toGauge() {
-    return (name, rawValue) ->
-        () -> {
-          // at the moment the metric's type is assumed to be
-          // compatible with Double. While far from perfect, it seems reasonable at
-          // this point in time
-          try {
-            return Double.parseDouble(rawValue.toString());
-          } catch (final Exception e) {
-            LOG.warn(
-                "Failed reporting metric with name [{}], of type [{}], since it could not be"
-                    + " converted to double",
-                name,
-                rawValue.getClass().getSimpleName(),
-                e);
-            return null;
-          }
-        };
-  }
-
-  private Predicate<Map.Entry<String, Gauge>> matches(final MetricFilter filter) {
-    return entry -> filter.matches(entry.getKey(), entry.getValue());
-  }
-
-  private Predicate<Map.Entry<String, Metric>> isAggregatorMetric() {
-    return metricEntry -> (metricEntry.getValue() instanceof AggregatorMetric);
-  }
-
-  private Predicate<Map.Entry<String, Metric>> isSparkBeamMetric() {
-    return metricEntry -> (metricEntry.getValue() instanceof SparkBeamMetric);
+    }
+    return builder.build();
   }
 }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/sink/CsvSink.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/sink/CsvSink.java
index d87cbd2fd34..d880cd3cf9e 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/sink/CsvSink.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/sink/CsvSink.java
@@ -18,22 +18,69 @@
 package org.apache.beam.runners.spark.metrics.sink;
 
 import com.codahale.metrics.MetricRegistry;
-import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
 import java.util.Properties;
 import org.apache.beam.runners.spark.metrics.AggregatorMetric;
 import org.apache.beam.runners.spark.metrics.WithMetricsSupport;
+import org.apache.spark.SecurityManager;
 import org.apache.spark.metrics.sink.Sink;
 
 /**
- * A Spark {@link Sink} that is tailored to report {@link AggregatorMetric} metrics to a CSV file.
+ * A {@link Sink} for <a href="https://spark.apache.org/docs/latest/monitoring.html#metrics">Spark's
+ * metric system</a> that is tailored to report {@link AggregatorMetric}s to a CSV file.
+ *
+ * <p>The sink is configured using Spark configuration parameters, for example:
+ *
+ * <pre>{@code
+ * "spark.metrics.conf.*.sink.csv.class"="org.apache.beam.runners.spark.metrics.sink.CsvSink"
+ * "spark.metrics.conf.*.sink.csv.directory"="<output_directory>"
+ * "spark.metrics.conf.*.sink.csv.period"=10
+ * "spark.metrics.conf.*.sink.csv.unit"=seconds
+ * }</pre>
  */
-// Intentionally overriding parent name because inheritors should replace the parent.
-@SuppressFBWarnings("NM_SAME_SIMPLE_NAME_AS_SUPERCLASS")
-public class CsvSink extends org.apache.spark.metrics.sink.CsvSink {
+public class CsvSink implements Sink {
+
+  // Initialized reflectively as done by Spark's MetricsSystem
+  private final org.apache.spark.metrics.sink.CsvSink delegate;
+
+  /** Constructor for Spark 3.1.x and earlier. */
   public CsvSink(
       final Properties properties,
       final MetricRegistry metricRegistry,
       final org.apache.spark.SecurityManager securityMgr) {
-    super(properties, WithMetricsSupport.forRegistry(metricRegistry), securityMgr);
+    try {
+      delegate =
+          org.apache.spark.metrics.sink.CsvSink.class
+              .getConstructor(Properties.class, MetricRegistry.class, SecurityManager.class)
+              .newInstance(properties, WithMetricsSupport.forRegistry(metricRegistry), securityMgr);
+    } catch (ReflectiveOperationException ex) {
+      throw new RuntimeException(ex);
+    }
+  }
+
+  /** Constructor for Spark 3.2.x and later. */
+  public CsvSink(final Properties properties, final MetricRegistry metricRegistry) {
+    try {
+      delegate =
+          org.apache.spark.metrics.sink.CsvSink.class
+              .getConstructor(Properties.class, MetricRegistry.class)
+              .newInstance(properties, WithMetricsSupport.forRegistry(metricRegistry));
+    } catch (ReflectiveOperationException ex) {
+      throw new RuntimeException(ex);
+    }
+  }
+
+  @Override
+  public void start() {
+    delegate.start();
+  }
+
+  @Override
+  public void stop() {
+    delegate.stop();
+  }
+
+  @Override
+  public void report() {
+    delegate.report();
   }
 }
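
The Javadoc above configures the sink through a metrics properties file; the same keys can equally be set programmatically before the Spark context is created. A hedged illustration using SparkConf (the output directory is a placeholder):

    import org.apache.spark.SparkConf;

    public class CsvSinkConfSketch {
      public static void main(String[] args) {
        SparkConf conf = new SparkConf()
            .set("spark.metrics.conf.*.sink.csv.class",
                "org.apache.beam.runners.spark.metrics.sink.CsvSink")
            .set("spark.metrics.conf.*.sink.csv.directory", "/tmp/beam-metrics")
            .set("spark.metrics.conf.*.sink.csv.period", "10")
            .set("spark.metrics.conf.*.sink.csv.unit", "seconds");
        // Hand this configuration to the SparkContext/SparkSession the runner uses.
        System.out.println(conf.toDebugString());
      }
    }
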
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/sink/GraphiteSink.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/sink/GraphiteSink.java
index eca1b2ba4b8..0b21554069d 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/sink/GraphiteSink.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/sink/GraphiteSink.java
@@ -18,20 +18,72 @@
 package org.apache.beam.runners.spark.metrics.sink;
 
 import com.codahale.metrics.MetricRegistry;
-import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
 import java.util.Properties;
 import org.apache.beam.runners.spark.metrics.AggregatorMetric;
 import org.apache.beam.runners.spark.metrics.WithMetricsSupport;
+import org.apache.spark.SecurityManager;
 import org.apache.spark.metrics.sink.Sink;
 
-/** A Spark {@link Sink} that is tailored to report {@link AggregatorMetric} metrics to Graphite. */
-// Intentionally overriding parent name because inheritors should replace the parent.
-@SuppressFBWarnings("NM_SAME_SIMPLE_NAME_AS_SUPERCLASS")
-public class GraphiteSink extends org.apache.spark.metrics.sink.GraphiteSink {
+/**
+ * A {@link Sink} for <a href="https://spark.apache.org/docs/latest/monitoring.html#metrics">Spark's
+ * metric system</a> that is tailored to report {@link AggregatorMetric}s to Graphite.
+ *
+ * <p>The sink is configured using Spark configuration parameters, for example:
+ *
+ * <pre>{@code
+ * "spark.metrics.conf.*.sink.graphite.class"="org.apache.beam.runners.spark.metrics.sink.GraphiteSink"
+ * "spark.metrics.conf.*.sink.graphite.host"="<graphite_hostname>"
+ * "spark.metrics.conf.*.sink.graphite.port"=<graphite_listening_port>
+ * "spark.metrics.conf.*.sink.graphite.period"=10
+ * "spark.metrics.conf.*.sink.graphite.unit"=seconds
+ * "spark.metrics.conf.*.sink.graphite.prefix"="<optional_prefix>"
+ * "spark.metrics.conf.*.sink.graphite.regex"="<optional_regex_to_send_matching_metrics>"
+ * }</pre>
+ */
+public class GraphiteSink implements Sink {
+
+  // Initialized reflectively as done by Spark's MetricsSystem
+  private final org.apache.spark.metrics.sink.GraphiteSink delegate;
+
+  /** Constructor for Spark 3.1.x and earlier. */
   public GraphiteSink(
       final Properties properties,
       final MetricRegistry metricRegistry,
-      final org.apache.spark.SecurityManager securityMgr) {
-    super(properties, WithMetricsSupport.forRegistry(metricRegistry), securityMgr);
+      final SecurityManager securityMgr) {
+    try {
+      delegate =
+          org.apache.spark.metrics.sink.GraphiteSink.class
+              .getConstructor(Properties.class, MetricRegistry.class, SecurityManager.class)
+              .newInstance(properties, WithMetricsSupport.forRegistry(metricRegistry), securityMgr);
+    } catch (ReflectiveOperationException ex) {
+      throw new RuntimeException(ex);
+    }
+  }
+
+  /** Constructor for Spark 3.2.x and later. */
+  public GraphiteSink(final Properties properties, final MetricRegistry metricRegistry) {
+    try {
+      delegate =
+          org.apache.spark.metrics.sink.GraphiteSink.class
+              .getConstructor(Properties.class, MetricRegistry.class)
+              .newInstance(properties, WithMetricsSupport.forRegistry(metricRegistry));
+    } catch (ReflectiveOperationException ex) {
+      throw new RuntimeException(ex);
+    }
+  }
+
+  @Override
+  public void start() {
+    delegate.start();
+  }
+
+  @Override
+  public void stop() {
+    delegate.stop();
+  }
+
+  @Override
+  public void report() {
+    delegate.report();
   }
 }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineOptions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineOptions.java
index bc585d8a31e..3371a403b2c 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineOptions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineOptions.java
@@ -19,6 +19,7 @@ package org.apache.beam.runners.spark.structuredstreaming;
 
 import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
 import org.apache.beam.sdk.options.Default;
+import org.apache.beam.sdk.options.Description;
 import org.apache.beam.sdk.options.PipelineOptions;
 
 /**
@@ -32,4 +33,10 @@ public interface SparkStructuredStreamingPipelineOptions extends SparkCommonPipe
   boolean getTestMode();
 
   void setTestMode(boolean testMode);
+
+  @Description("Enable if the runner should use the currently active Spark session.")
+  @Default.Boolean(false)
+  boolean getUseActiveSparkSession();
+
+  void setUseActiveSparkSession(boolean value);
 }
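
The new flag lets the structured streaming runner attach to a SparkSession that the embedding application already created, instead of creating and later stopping its own. A hedged sketch of setting it through the standard options factory:

    import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
    import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;

    public class ActiveSessionOptionsSketch {
      public static void main(String[] args) {
        SparkStructuredStreamingPipelineOptions options =
            PipelineOptionsFactory.as(SparkStructuredStreamingPipelineOptions.class);
        options.setRunner(SparkStructuredStreamingRunner.class);
        // Reuse the session already active in this JVM; per the runner change below,
        // the terminal-state callback is then a no-op and the session stays open.
        options.setUseActiveSparkSession(true);
      }
    }
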
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineResult.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineResult.java
index 663c87af4c6..1392ae8f0c7 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineResult.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineResult.java
@@ -20,7 +20,6 @@ package org.apache.beam.runners.spark.structuredstreaming;
 import static org.apache.beam.runners.core.metrics.MetricsContainerStepMap.asAttemptedOnlyMetricResults;
 
 import java.io.IOException;
-import java.util.Objects;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
@@ -31,8 +30,6 @@ import org.apache.beam.sdk.PipelineResult;
 import org.apache.beam.sdk.metrics.MetricResults;
 import org.apache.beam.sdk.util.UserCodeException;
 import org.apache.spark.SparkException;
-import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.streaming.StreamingQuery;
 import org.joda.time.Duration;
 
 /** Represents a Spark pipeline execution result. */
@@ -43,19 +40,16 @@ import org.joda.time.Duration;
 public class SparkStructuredStreamingPipelineResult implements PipelineResult {
 
   final Future pipelineExecution;
-  final SparkSession sparkSession;
-  PipelineResult.State state;
+  final Runnable onTerminalState;
 
-  boolean isStreaming;
+  PipelineResult.State state;
 
   SparkStructuredStreamingPipelineResult(
-      final Future<?> pipelineExecution, final SparkSession sparkSession) {
+      final Future<?> pipelineExecution, final Runnable onTerminalState) {
     this.pipelineExecution = pipelineExecution;
-    this.sparkSession = sparkSession;
+    this.onTerminalState = onTerminalState;
     // pipelineExecution is expected to have started executing eagerly.
     this.state = State.RUNNING;
-    // TODO: Implement results on a streaming pipeline. Currently does not stream.
-    this.isStreaming = false;
   }
 
   private static RuntimeException runtimeExceptionFrom(final Throwable e) {
@@ -79,29 +73,10 @@ public class SparkStructuredStreamingPipelineResult implements PipelineResult {
     return runtimeExceptionFrom(e);
   }
 
-  protected void stop() {
-    try {
-      // TODO: await any outstanding queries on the session if this is streaming.
-      if (isStreaming) {
-        for (StreamingQuery query : sparkSession.streams().active()) {
-          query.stop();
-        }
-      }
-    } catch (Exception e) {
-      throw beamExceptionFrom(e);
-    } finally {
-      sparkSession.stop();
-      if (Objects.equals(state, State.RUNNING)) {
-        this.state = State.STOPPED;
-      }
-    }
-  }
-
   private State awaitTermination(Duration duration)
       throws TimeoutException, ExecutionException, InterruptedException {
     pipelineExecution.get(duration.getMillis(), TimeUnit.MILLISECONDS);
     // Throws an exception if the job is not finished successfully in the given time.
-    // TODO: all streaming functionality
     return PipelineResult.State.DONE;
   }
 
@@ -149,7 +124,11 @@ public class SparkStructuredStreamingPipelineResult implements PipelineResult {
     State oldState = this.state;
     this.state = newState;
     if (!oldState.isTerminal() && newState.isTerminal()) {
-      stop();
+      try {
+        onTerminalState.run();
+      } catch (Exception e) {
+        throw beamExceptionFrom(e);
+      }
     }
   }
 }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java
index d66e0c77186..b1de9e941e4 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java
@@ -146,10 +146,12 @@ public final class SparkStructuredStreamingRunner
             });
     executorService.shutdown();
 
-    // TODO: Streaming.
+    Runnable onTerminalState =
+        options.getUseActiveSparkSession()
+            ? () -> {}
+            : () -> translationContext.getSparkSession().stop();
     SparkStructuredStreamingPipelineResult result =
-        new SparkStructuredStreamingPipelineResult(
-            submissionFuture, translationContext.getSparkSession());
+        new SparkStructuredStreamingPipelineResult(submissionFuture, onTerminalState);
 
     if (options.getEnableSparkMetricSinks()) {
       registerMetricsSource(options.getAppName());
@@ -162,7 +164,6 @@ public final class SparkStructuredStreamingRunner
 
     if (options.getTestMode()) {
       result.waitUntilFinish();
-      result.stop();
     }
 
     return result;
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/AggregatorMetric.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/AggregatorMetric.java
index 55590a66fef..74bea7f5255 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/AggregatorMetric.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/AggregatorMetric.java
@@ -17,23 +17,58 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.metrics;
 
+import com.codahale.metrics.Gauge;
 import com.codahale.metrics.Metric;
+import com.codahale.metrics.MetricFilter;
+import java.util.HashMap;
+import java.util.Map;
 import org.apache.beam.runners.spark.structuredstreaming.aggregators.NamedAggregators;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
-/** An adapter between the {@link NamedAggregators} and Codahale's {@link Metric} interface. */
-public class AggregatorMetric implements Metric {
+/** An adapter between the {@link NamedAggregators} and the Dropwizard {@link Metric} interface. */
+public class AggregatorMetric extends BeamMetricSet {
+
+  private static final Logger LOG = LoggerFactory.getLogger(AggregatorMetric.class);
 
   private final NamedAggregators namedAggregators;
 
-  private AggregatorMetric(final NamedAggregators namedAggregators) {
+  private AggregatorMetric(NamedAggregators namedAggregators) {
     this.namedAggregators = namedAggregators;
   }
 
-  public static AggregatorMetric of(final NamedAggregators namedAggregators) {
+  public static AggregatorMetric of(NamedAggregators namedAggregators) {
     return new AggregatorMetric(namedAggregators);
   }
 
-  NamedAggregators getNamedAggregators() {
-    return namedAggregators;
+  @Override
+  public Map<String, Gauge<Double>> getValue(String prefix, MetricFilter filter) {
+    Map<String, Gauge<Double>> metrics = new HashMap<>();
+    for (Map.Entry<String, ?> entry : namedAggregators.renderAll().entrySet()) {
+      String name = prefix + "." + entry.getKey();
+      Object rawValue = entry.getValue();
+      if (rawValue != null) {
+        try {
+          Gauge<Double> gauge = staticGauge(rawValue);
+          if (filter.matches(name, gauge)) {
+            metrics.put(name, gauge);
+          }
+        } catch (NumberFormatException e) {
+          LOG.warn(
+              "Metric `{}` of type {} can't be reported, conversion to double failed.",
+              name,
+              rawValue.getClass().getSimpleName(),
+              e);
+        }
+      }
+    }
+    return metrics;
+  }
+
+  // Metric type is assumed to be compatible with Double
+  protected Gauge<Double> staticGauge(Object rawValue) throws NumberFormatException {
+    return rawValue instanceof Number
+        ? super.staticGauge((Number) rawValue)
+        : super.staticGauge(Double.parseDouble(rawValue.toString()));
   }
 }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/BeamMetricSet.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/BeamMetricSet.java
new file mode 100644
index 00000000000..7095036f28a
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/BeamMetricSet.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.metrics;
+
+import com.codahale.metrics.Gauge;
+import com.codahale.metrics.MetricFilter;
+import java.util.Map;
+import org.apache.beam.runners.spark.metrics.WithMetricsSupport;
+
+/**
+ * {@link BeamMetricSet} is a {@link Gauge} that returns a map of multiple metrics which get
+ * flattened in {@link WithMetricsSupport#getGauges()} for usage in {@link
+ * org.apache.spark.metrics.sink.Sink Spark metric sinks}.
+ *
+ * <p>Note: Recent versions of Dropwizard {@link com.codahale.metrics.MetricRegistry MetricRegistry}
+ * do not allow registering arbitrary implementations of {@link com.codahale.metrics.Metric Metrics}
+ * and require usage of {@link Gauge} here.
+ */
+// TODO: turn into MetricRegistry https://github.com/apache/beam/issues/22384
+abstract class BeamMetricSet implements Gauge<Map<String, Gauge<Double>>> {
+
+  @Override
+  public final Map<String, Gauge<Double>> getValue() {
+    return getValue("", MetricFilter.ALL);
+  }
+
+  protected abstract Map<String, Gauge<Double>> getValue(String prefix, MetricFilter filter);
+
+  protected Gauge<Double> staticGauge(Number number) {
+    return new ConstantGauge(number.doubleValue());
+  }
+
+  private static class ConstantGauge implements Gauge<Double> {
+    private final double value;
+
+    ConstantGauge(double value) {
+      this.value = value;
+    }
+
+    @Override
+    public Double getValue() {
+      return value;
+    }
+  }
+}
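
Editor's note, a hedged sketch of how this contract is meant to be used (not part of the commit): a concrete BeamMetricSet only renders its metrics into a name-to-gauge map, and WithMetricsSupport (changed further down in this diff) flattens that map into the registry so ordinary Dropwizard sinks see plain gauges. Assuming a hypothetical subclass in the same package:

    // Hypothetical subclass, for illustration only.
    import com.codahale.metrics.Gauge;
    import com.codahale.metrics.MetricFilter;
    import java.util.HashMap;
    import java.util.Map;

    class ExampleMetricSet extends BeamMetricSet {
      @Override
      protected Map<String, Gauge<Double>> getValue(String prefix, MetricFilter filter) {
        Map<String, Gauge<Double>> gauges = new HashMap<>();
        Gauge<Double> gauge = staticGauge(42); // constant gauge holding 42.0
        String name = prefix + ".example";
        if (filter.matches(name, gauge)) {
          gauges.put(name, gauge);
        }
        return gauges;
      }
    }
    // Registered under e.g. "driver.beam", a sink would then see the flattened gauge
    // "driver.beam.example" -> 42.0.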
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetric.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetric.java
index de146c60f97..0cecae4a25b 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetric.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetric.java
@@ -17,14 +17,17 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.metrics;
 
-import static java.util.stream.Collectors.toList;
 import static org.apache.beam.runners.core.metrics.MetricsContainerStepMap.asAttemptedOnlyMetricResults;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Predicates.not;
 
+import com.codahale.metrics.Gauge;
 import com.codahale.metrics.Metric;
-import java.util.ArrayList;
+import com.codahale.metrics.MetricFilter;
 import java.util.HashMap;
 import java.util.Map;
-import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import javax.annotation.Nullable;
 import org.apache.beam.sdk.metrics.DistributionResult;
 import org.apache.beam.sdk.metrics.GaugeResult;
 import org.apache.beam.sdk.metrics.MetricKey;
@@ -33,57 +36,72 @@ import org.apache.beam.sdk.metrics.MetricQueryResults;
 import org.apache.beam.sdk.metrics.MetricResult;
 import org.apache.beam.sdk.metrics.MetricResults;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Streams;
 
 /**
- * An adapter between the {@link MetricsContainerStepMap} and Codahale's {@link Metric} interface.
+ * An adapter between the {@link SparkMetricsContainerStepMap} and the Dropwizard {@link Metric}
+ * interface.
  */
-class SparkBeamMetric implements Metric {
+class SparkBeamMetric extends BeamMetricSet {
+
   private static final String ILLEGAL_CHARACTERS = "[^A-Za-z0-9-]";
 
-  Map<String, ?> renderAll() {
-    Map<String, Object> metrics = new HashMap<>();
+  @Override
+  public Map<String, Gauge<Double>> getValue(String prefix, MetricFilter filter) {
     MetricResults metricResults =
         asAttemptedOnlyMetricResults(MetricsAccumulator.getInstance().value());
-    MetricQueryResults metricQueryResults = metricResults.allMetrics();
-    for (MetricResult<Long> metricResult : metricQueryResults.getCounters()) {
-      metrics.put(renderName(metricResult), metricResult.getAttempted());
+    Map<String, Gauge<Double>> metrics = new HashMap<>();
+    MetricQueryResults allMetrics = metricResults.allMetrics();
+    for (MetricResult<Long> metricResult : allMetrics.getCounters()) {
+      putFiltered(metrics, filter, renderName(prefix, metricResult), metricResult.getAttempted());
     }
-    for (MetricResult<DistributionResult> metricResult : metricQueryResults.getDistributions()) {
+    for (MetricResult<DistributionResult> metricResult : allMetrics.getDistributions()) {
       DistributionResult result = metricResult.getAttempted();
-      metrics.put(renderName(metricResult) + ".count", result.getCount());
-      metrics.put(renderName(metricResult) + ".sum", result.getSum());
-      metrics.put(renderName(metricResult) + ".min", result.getMin());
-      metrics.put(renderName(metricResult) + ".max", result.getMax());
-      metrics.put(renderName(metricResult) + ".mean", result.getMean());
+      String baseName = renderName(prefix, metricResult);
+      putFiltered(metrics, filter, baseName + ".count", result.getCount());
+      putFiltered(metrics, filter, baseName + ".sum", result.getSum());
+      putFiltered(metrics, filter, baseName + ".min", result.getMin());
+      putFiltered(metrics, filter, baseName + ".max", result.getMax());
+      putFiltered(metrics, filter, baseName + ".mean", result.getMean());
     }
-    for (MetricResult<GaugeResult> metricResult : metricQueryResults.getGauges()) {
-      metrics.put(renderName(metricResult), metricResult.getAttempted().getValue());
+    for (MetricResult<GaugeResult> metricResult : allMetrics.getGauges()) {
+      putFiltered(
+          metrics,
+          filter,
+          renderName(prefix, metricResult),
+          metricResult.getAttempted().getValue());
     }
     return metrics;
   }
 
   @VisibleForTesting
-  String renderName(MetricResult<?> metricResult) {
+  @SuppressWarnings("nullness") // ok to have nullable elements on stream
+  static String renderName(String prefix, MetricResult<?> metricResult) {
     MetricKey key = metricResult.getKey();
     MetricName name = key.metricName();
     String step = key.stepName();
+    return Streams.concat(
+            Stream.of(prefix), // prefix is not cleaned, should it be?
+            Stream.of(stripSuffix(normalizePart(step))),
+            Stream.of(name.getNamespace(), name.getName()).map(SparkBeamMetric::normalizePart))
+        .filter(not(Strings::isNullOrEmpty))
+        .collect(Collectors.joining("."));
+  }
 
-    ArrayList<String> pieces = new ArrayList<>();
-
-    if (step != null) {
-      step = step.replaceAll(ILLEGAL_CHARACTERS, "_");
-      if (step.endsWith("_")) {
-        step = step.substring(0, step.length() - 1);
-      }
-      pieces.add(step);
-    }
+  private static @Nullable String normalizePart(@Nullable String str) {
+    return str != null ? str.replaceAll(ILLEGAL_CHARACTERS, "_") : null;
+  }
 
-    pieces.addAll(
-        ImmutableList.of(name.getNamespace(), name.getName()).stream()
-            .map(str -> str.replaceAll(ILLEGAL_CHARACTERS, "_"))
-            .collect(toList()));
+  private static @Nullable String stripSuffix(@Nullable String str) {
+    return str != null && str.endsWith("_") ? str.substring(0, str.length() - 1) : str;
+  }
 
-    return String.join(".", pieces);
+  private void putFiltered(
+      Map<String, Gauge<Double>> metrics, MetricFilter filter, String name, Number value) {
+    Gauge<Double> metric = staticGauge(value);
+    if (filter.matches(name, metric)) {
+      metrics.put(name, metric);
+    }
   }
 }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/WithMetricsSupport.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/WithMetricsSupport.java
index c1c7b293ff7..d48a229996f 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/WithMetricsSupport.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/WithMetricsSupport.java
@@ -21,41 +21,24 @@ import com.codahale.metrics.Counter;
 import com.codahale.metrics.Gauge;
 import com.codahale.metrics.Histogram;
 import com.codahale.metrics.Meter;
-import com.codahale.metrics.Metric;
 import com.codahale.metrics.MetricFilter;
 import com.codahale.metrics.MetricRegistry;
 import com.codahale.metrics.Timer;
-import java.util.HashMap;
 import java.util.Map;
 import java.util.SortedMap;
-import org.apache.beam.runners.spark.structuredstreaming.aggregators.NamedAggregators;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Function;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Predicate;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Predicates;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.FluentIterable;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSortedMap;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Ordering;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 /**
- * A {@link MetricRegistry} decorator-like that supports {@link AggregatorMetric} and {@link
- * SparkBeamMetric} as {@link Gauge Gauges}.
+ * A {@link MetricRegistry} decorator-like that supports {@link BeamMetricSet}s as {@link Gauge
+ * Gauges}.
  *
  * <p>{@link MetricRegistry} is not an interface, so this is not a by-the-book decorator. That said,
  * it delegates all metric related getters to the "decorated" instance.
  */
-@SuppressWarnings({
-  "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
-  "keyfor",
-  "nullness"
-}) // TODO(https://github.com/apache/beam/issues/20497)
+@SuppressWarnings({"rawtypes"}) // required by interface
 public class WithMetricsSupport extends MetricRegistry {
 
-  private static final Logger LOG = LoggerFactory.getLogger(WithMetricsSupport.class);
-
   private final MetricRegistry internalMetricRegistry;
 
   private WithMetricsSupport(final MetricRegistry internalMetricRegistry) {
@@ -88,95 +71,21 @@ public class WithMetricsSupport extends MetricRegistry {
 
   @Override
   public SortedMap<String, Gauge> getGauges(final MetricFilter filter) {
-    return new ImmutableSortedMap.Builder<String, Gauge>(
-            Ordering.from(String.CASE_INSENSITIVE_ORDER))
-        .putAll(internalMetricRegistry.getGauges(filter))
-        .putAll(extractGauges(internalMetricRegistry, filter))
-        .build();
-  }
-
-  private Map<String, Gauge> extractGauges(
-      final MetricRegistry metricRegistry, final MetricFilter filter) {
-    Map<String, Gauge> gauges = new HashMap<>();
-
-    // find the AggregatorMetric metrics from within all currently registered metrics
-    final Optional<Map<String, Gauge>> aggregatorMetrics =
-        FluentIterable.from(metricRegistry.getMetrics().entrySet())
-            .firstMatch(isAggregatorMetric())
-            .transform(aggregatorMetricToGauges());
-
-    // find the SparkBeamMetric metrics from within all currently registered metrics
-    final Optional<Map<String, Gauge>> beamMetrics =
-        FluentIterable.from(metricRegistry.getMetrics().entrySet())
-            .firstMatch(isSparkBeamMetric())
-            .transform(beamMetricToGauges());
-
-    if (aggregatorMetrics.isPresent()) {
-      gauges.putAll(Maps.filterEntries(aggregatorMetrics.get(), matches(filter)));
-    }
-
-    if (beamMetrics.isPresent()) {
-      gauges.putAll(Maps.filterEntries(beamMetrics.get(), matches(filter)));
-    }
-
-    return gauges;
-  }
-
-  private Function<Map.Entry<String, Metric>, Map<String, Gauge>> aggregatorMetricToGauges() {
-    return entry -> {
-      final NamedAggregators agg = ((AggregatorMetric) entry.getValue()).getNamedAggregators();
-      final String parentName = entry.getKey();
-      final Map<String, Gauge> gaugeMap = Maps.transformEntries(agg.renderAll(), toGauge());
-      final Map<String, Gauge> fullNameGaugeMap = Maps.newLinkedHashMap();
-      for (Map.Entry<String, Gauge> gaugeEntry : gaugeMap.entrySet()) {
-        fullNameGaugeMap.put(parentName + "." + gaugeEntry.getKey(), gaugeEntry.getValue());
+    ImmutableSortedMap.Builder<String, Gauge> builder =
+        new ImmutableSortedMap.Builder<>(Ordering.from(String.CASE_INSENSITIVE_ORDER));
+
+    Map<String, Gauge> gauges =
+        internalMetricRegistry.getGauges(
+            (n, m) -> filter.matches(n, m) || m instanceof BeamMetricSet);
+
+    for (Map.Entry<String, Gauge> entry : gauges.entrySet()) {
+      Gauge gauge = entry.getValue();
+      if (gauge instanceof BeamMetricSet) {
+        builder.putAll(((BeamMetricSet) gauge).getValue(entry.getKey(), filter));
+      } else {
+        builder.put(entry.getKey(), gauge);
       }
-      return Maps.filterValues(fullNameGaugeMap, Predicates.notNull());
-    };
-  }
-
-  private Function<Map.Entry<String, Metric>, Map<String, Gauge>> beamMetricToGauges() {
-    return entry -> {
-      final Map<String, ?> metrics = ((SparkBeamMetric) entry.getValue()).renderAll();
-      final String parentName = entry.getKey();
-      final Map<String, Gauge> gaugeMap = Maps.transformEntries(metrics, toGauge());
-      final Map<String, Gauge> fullNameGaugeMap = Maps.newLinkedHashMap();
-      for (Map.Entry<String, Gauge> gaugeEntry : gaugeMap.entrySet()) {
-        fullNameGaugeMap.put(parentName + "." + gaugeEntry.getKey(), gaugeEntry.getValue());
-      }
-      return Maps.filterValues(fullNameGaugeMap, Predicates.notNull());
-    };
-  }
-
-  private Maps.EntryTransformer<String, Object, Gauge> toGauge() {
-    return (name, rawValue) ->
-        () -> {
-          // at the moment the metric's type is assumed to be
-          // compatible with Double. While far from perfect, it seems reasonable at
-          // this point in time
-          try {
-            return Double.parseDouble(rawValue.toString());
-          } catch (final Exception e) {
-            LOG.warn(
-                "Failed reporting metric with name [{}], of type [{}], since it could not be"
-                    + " converted to double",
-                name,
-                rawValue.getClass().getSimpleName(),
-                e);
-            return null;
-          }
-        };
-  }
-
-  private Predicate<Map.Entry<String, Gauge>> matches(final MetricFilter filter) {
-    return entry -> filter.matches(entry.getKey(), entry.getValue());
-  }
-
-  private Predicate<Map.Entry<String, Metric>> isAggregatorMetric() {
-    return metricEntry -> (metricEntry.getValue() instanceof AggregatorMetric);
-  }
-
-  private Predicate<Map.Entry<String, Metric>> isSparkBeamMetric() {
-    return metricEntry -> (metricEntry.getValue() instanceof SparkBeamMetric);
+    }
+    return builder.build();
   }
 }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/sink/CodahaleCsvSink.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/sink/CodahaleCsvSink.java
index 7c1f209619f..c8f9139a2eb 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/sink/CodahaleCsvSink.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/sink/CodahaleCsvSink.java
@@ -21,16 +21,66 @@ import com.codahale.metrics.MetricRegistry;
 import java.util.Properties;
 import org.apache.beam.runners.spark.structuredstreaming.metrics.AggregatorMetric;
 import org.apache.beam.runners.spark.structuredstreaming.metrics.WithMetricsSupport;
+import org.apache.spark.SecurityManager;
 import org.apache.spark.metrics.sink.Sink;
 
 /**
- * A Spark {@link Sink} that is tailored to report {@link AggregatorMetric} metrics to a CSV file.
+ * A {@link Sink} for <a href="https://spark.apache.org/docs/latest/monitoring.html#metrics">Spark's
+ * metric system</a> that is tailored to report {@link AggregatorMetric}s to a CSV file.
+ *
+ * <p>The sink is configured using Spark configuration parameters, for example:
+ *
+ * <pre>{@code
+ * "spark.metrics.conf.*.sink.csv.class"="org.apache.beam.runners.spark.structuredstreaming.metrics.sink.CodahaleCsvSink"
+ * "spark.metrics.conf.*.sink.csv.directory"="<output_directory>"
+ * "spark.metrics.conf.*.sink.csv.period"=10
+ * "spark.metrics.conf.*.sink.csv.unit"=seconds
+ * }</pre>
  */
-public class CodahaleCsvSink extends org.apache.spark.metrics.sink.CsvSink {
+public class CodahaleCsvSink implements Sink {
+
+  // Initialized reflectively as done by Spark's MetricsSystem
+  private final org.apache.spark.metrics.sink.CsvSink delegate;
+
+  /** Constructor for Spark 3.1.x and earlier. */
   public CodahaleCsvSink(
       final Properties properties,
       final MetricRegistry metricRegistry,
-      final org.apache.spark.SecurityManager securityMgr) {
-    super(properties, WithMetricsSupport.forRegistry(metricRegistry), securityMgr);
+      final SecurityManager securityMgr) {
+    try {
+      delegate =
+          org.apache.spark.metrics.sink.CsvSink.class
+              .getConstructor(Properties.class, MetricRegistry.class, SecurityManager.class)
+              .newInstance(properties, WithMetricsSupport.forRegistry(metricRegistry), securityMgr);
+    } catch (ReflectiveOperationException ex) {
+      throw new RuntimeException(ex);
+    }
+  }
+
+  /** Constructor for Spark 3.2.x and later. */
+  public CodahaleCsvSink(final Properties properties, final MetricRegistry metricRegistry) {
+    try {
+      delegate =
+          org.apache.spark.metrics.sink.CsvSink.class
+              .getConstructor(Properties.class, MetricRegistry.class)
+              .newInstance(properties, WithMetricsSupport.forRegistry(metricRegistry));
+    } catch (ReflectiveOperationException ex) {
+      throw new RuntimeException(ex);
+    }
+  }
+
+  @Override
+  public void start() {
+    delegate.start();
+  }
+
+  @Override
+  public void stop() {
+    delegate.stop();
+  }
+
+  @Override
+  public void report() {
+    delegate.report();
   }
 }
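
Editor's note, an illustrative sketch only (not part of the commit): besides a metrics properties file, the same configuration keys shown in the javadoc above can be set programmatically on the SparkConf; the output directory below is a placeholder.

    // Illustrative only: configuring the CSV sink via SparkConf instead of a metrics properties file.
    import org.apache.spark.SparkConf;

    class CsvSinkConfExample {
      static SparkConf metricsConf() {
        return new SparkConf()
            .set(
                "spark.metrics.conf.*.sink.csv.class",
                "org.apache.beam.runners.spark.structuredstreaming.metrics.sink.CodahaleCsvSink")
            .set("spark.metrics.conf.*.sink.csv.directory", "/tmp/beam-metrics") // placeholder path
            .set("spark.metrics.conf.*.sink.csv.period", "10")
            .set("spark.metrics.conf.*.sink.csv.unit", "seconds");
      }
    }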
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/sink/CodahaleGraphiteSink.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/sink/CodahaleGraphiteSink.java
index 1dc4644615b..5640c965740 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/sink/CodahaleGraphiteSink.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/sink/CodahaleGraphiteSink.java
@@ -21,14 +21,69 @@ import com.codahale.metrics.MetricRegistry;
 import java.util.Properties;
 import org.apache.beam.runners.spark.structuredstreaming.metrics.AggregatorMetric;
 import org.apache.beam.runners.spark.structuredstreaming.metrics.WithMetricsSupport;
+import org.apache.spark.SecurityManager;
 import org.apache.spark.metrics.sink.Sink;
 
-/** A Spark {@link Sink} that is tailored to report {@link AggregatorMetric} metrics to Graphite. */
-public class CodahaleGraphiteSink extends org.apache.spark.metrics.sink.GraphiteSink {
+/**
+ * A {@link Sink} for <a href="https://spark.apache.org/docs/latest/monitoring.html#metrics">Spark's
+ * metric system</a> that is tailored to report {@link AggregatorMetric}s to Graphite.
+ *
+ * <p>The sink is configured using Spark configuration parameters, for example:
+ *
+ * <pre>{@code
+ * "spark.metrics.conf.*.sink.graphite.class"="org.apache.beam.runners.spark.structuredstreaming.metrics.sink.CodahaleGraphiteSink"
+ * "spark.metrics.conf.*.sink.graphite.host"="<graphite_hostname>"
+ * "spark.metrics.conf.*.sink.graphite.port"=<graphite_listening_port>
+ * "spark.metrics.conf.*.sink.graphite.period"=10
+ * "spark.metrics.conf.*.sink.graphite.unit"=seconds
+ * "spark.metrics.conf.*.sink.graphite.prefix"="<optional_prefix>"
+ * "spark.metrics.conf.*.sink.graphite.regex"="<optional_regex_to_send_matching_metrics>"
+ * }</pre>
+ */
+public class CodahaleGraphiteSink implements Sink {
+
+  // Initialized reflectively as done by Spark's MetricsSystem
+  private final org.apache.spark.metrics.sink.GraphiteSink delegate;
+
+  /** Constructor for Spark 3.1.x and earlier. */
   public CodahaleGraphiteSink(
       final Properties properties,
       final MetricRegistry metricRegistry,
       final org.apache.spark.SecurityManager securityMgr) {
-    super(properties, WithMetricsSupport.forRegistry(metricRegistry), securityMgr);
+    try {
+      delegate =
+          org.apache.spark.metrics.sink.GraphiteSink.class
+              .getConstructor(Properties.class, MetricRegistry.class, SecurityManager.class)
+              .newInstance(properties, WithMetricsSupport.forRegistry(metricRegistry), securityMgr);
+    } catch (ReflectiveOperationException ex) {
+      throw new RuntimeException(ex);
+    }
+  }
+
+  /** Constructor for Spark 3.2.x and later. */
+  public CodahaleGraphiteSink(final Properties properties, final MetricRegistry metricRegistry) {
+    try {
+      delegate =
+          org.apache.spark.metrics.sink.GraphiteSink.class
+              .getConstructor(Properties.class, MetricRegistry.class)
+              .newInstance(properties, WithMetricsSupport.forRegistry(metricRegistry));
+    } catch (ReflectiveOperationException ex) {
+      throw new RuntimeException(ex);
+    }
+  }
+
+  @Override
+  public void start() {
+    delegate.start();
+  }
+
+  @Override
+  public void stop() {
+    delegate.stop();
+  }
+
+  @Override
+  public void report() {
+    delegate.report();
   }
 }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/AbstractTranslationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/AbstractTranslationContext.java
index 766065fd7d7..aed287ba6d5 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/AbstractTranslationContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/AbstractTranslationContext.java
@@ -38,7 +38,6 @@ import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.function.ForeachFunction;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.ForeachWriter;
@@ -75,29 +74,7 @@ public abstract class AbstractTranslationContext {
   private final Map<PCollectionView<?>, Dataset<?>> broadcastDataSets;
 
   public AbstractTranslationContext(SparkStructuredStreamingPipelineOptions options) {
-    SparkConf sparkConf = new SparkConf();
-    sparkConf.setMaster(options.getSparkMaster());
-    sparkConf.setAppName(options.getAppName());
-    if (options.getFilesToStage() != null && !options.getFilesToStage().isEmpty()) {
-      sparkConf.setJars(options.getFilesToStage().toArray(new String[0]));
-    }
-
-    // By default, Spark defines 200 as a number of sql partitions. This seems too much for local
-    // mode, so try to align with value of "sparkMaster" option in this case.
-    // We should not overwrite this value (or any user-defined spark configuration value) if the
-    // user has already configured it.
-    String sparkMaster = options.getSparkMaster();
-    if (sparkMaster != null
-        && sparkMaster.startsWith("local[")
-        && System.getProperty("spark.sql.shuffle.partitions") == null) {
-      int numPartitions =
-          Integer.parseInt(sparkMaster.substring("local[".length(), sparkMaster.length() - 1));
-      if (numPartitions > 0) {
-        sparkConf.set("spark.sql.shuffle.partitions", String.valueOf(numPartitions));
-      }
-    }
-
-    this.sparkSession = SparkSession.builder().config(sparkConf).getOrCreate();
+    this.sparkSession = SparkSessionFactory.getOrCreateSession(options);
     this.serializablePipelineOptions = new SerializablePipelineOptions(options);
     this.datasets = new HashMap<>();
     this.leaves = new HashSet<>();
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/SparkSessionFactory.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/SparkSessionFactory.java
new file mode 100644
index 00000000000..d8430f5f130
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/SparkSessionFactory.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation;
+
+import java.util.List;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
+import org.apache.spark.SparkConf;
+import org.apache.spark.sql.SparkSession;
+
+public class SparkSessionFactory {
+
+  /**
+   * Gets the active {@link SparkSession} or creates a new one using the given {@link
+   * SparkStructuredStreamingPipelineOptions}.
+   */
+  public static SparkSession getOrCreateSession(SparkStructuredStreamingPipelineOptions options) {
+    if (options.getUseActiveSparkSession()) {
+      return SparkSession.active();
+    }
+    return sessionBuilder(options.getSparkMaster(), options.getAppName(), options.getFilesToStage())
+        .getOrCreate();
+  }
+
+  /** Creates a Spark session builder with some optimizations for local mode, e.g. in tests. */
+  public static SparkSession.Builder sessionBuilder(String master) {
+    return sessionBuilder(master, null, null);
+  }
+
+  private static SparkSession.Builder sessionBuilder(
+      String master, @Nullable String appName, @Nullable List<String> jars) {
+    SparkConf sparkConf = new SparkConf();
+    sparkConf.setMaster(master);
+    if (appName != null) {
+      sparkConf.setAppName(appName);
+    }
+    if (jars != null && !jars.isEmpty()) {
+      sparkConf.setJars(jars.toArray(new String[0]));
+    }
+
+    // By default, Spark uses 200 SQL shuffle partitions, which is far too many for local mode,
+    // so align the value with the parallelism of the "sparkMaster" option in that case.
+    // This value (or any other user-defined Spark configuration value) must not be overwritten
+    // if the user has already configured it.
+    if (master != null
+        && master.startsWith("local[")
+        && System.getProperty("spark.sql.shuffle.partitions") == null) {
+      int numPartitions =
+          Integer.parseInt(master.substring("local[".length(), master.length() - 1));
+      if (numPartitions > 0) {
+        sparkConf.set("spark.sql.shuffle.partitions", String.valueOf(numPartitions));
+      }
+    }
+    return SparkSession.builder().config(sparkConf);
+  }
+}
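
Editor's note, a short usage sketch (not part of the commit) tying the pieces together: building a local session through the factory and observing the shuffle-partition alignment described in the comment above; it assumes the class runs in (or imports from) the factory's package.

    // Illustrative only: the factory aligns spark.sql.shuffle.partitions with local[N].
    import org.apache.spark.sql.SparkSession;

    class SessionFactoryExample {
      public static void main(String[] args) {
        SparkSession session = SparkSessionFactory.sessionBuilder("local[4]").getOrCreate();
        // Unless the user configured it explicitly, this prints "4" instead of Spark's default of 200.
        System.out.println(session.conf().get("spark.sql.shuffle.partitions"));
        session.stop();
      }
    }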
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
index 7b2f109d736..68738cf0308 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
@@ -19,256 +19,53 @@ package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
 
 import static org.apache.spark.sql.types.DataTypes.BinaryType;
 
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Objects;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.spark.sql.Encoder;
-import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal;
+import org.apache.spark.sql.catalyst.expressions.BoundReference;
 import org.apache.spark.sql.catalyst.expressions.Expression;
-import org.apache.spark.sql.catalyst.expressions.NonSQLExpression;
-import org.apache.spark.sql.catalyst.expressions.UnaryExpression;
-import org.apache.spark.sql.catalyst.expressions.codegen.Block;
-import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator;
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext;
-import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode;
+import org.apache.spark.sql.catalyst.expressions.Literal;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.ObjectType;
-import org.checkerframework.checker.nullness.qual.Nullable;
-import scala.StringContext;
-import scala.collection.JavaConversions;
-import scala.reflect.ClassTag;
+import org.checkerframework.checker.nullness.qual.NonNull;
 
-/** {@link Encoders} utility class. */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
 public class EncoderHelpers {
+  private static final DataType OBJECT_TYPE = new ObjectType(Object.class);
+
   /**
    * Wrap a Beam coder into a Spark Encoder using Catalyst Expression Encoders (which uses java code
    * generation).
    */
   public static <T> Encoder<T> fromBeamCoder(Coder<T> coder) {
-    return EncoderFactory.fromBeamCoder(coder);
+    Class<? super T> clazz = coder.getEncodedTypeDescriptor().getRawType();
+    // Class T may be private; use OBJECT_TYPE instead to avoid risking an IllegalAccessError
+    return EncoderFactory.create(
+        beamSerializer(rootRef(OBJECT_TYPE, true), coder),
+        beamDeserializer(rootCol(BinaryType), coder),
+        clazz);
   }
 
-  /**
-   * Catalyst Expression that serializes elements using Beam {@link Coder}.
-   *
-   * @param <T>: Type of elements ot be serialized.
-   */
-  public static class EncodeUsingBeamCoder<T> extends UnaryExpression
-      implements NonSQLExpression, Serializable {
-
-    private final Expression child;
-    private final Coder<T> coder;
-
-    public EncodeUsingBeamCoder(Expression child, Coder<T> coder) {
-      this.child = child;
-      this.coder = coder;
-    }
-
-    @Override
-    public Expression child() {
-      return child;
-    }
-
-    @Override
-    public ExprCode doGenCode(CodegenContext ctx, ExprCode ev) {
-      String accessCode = ctx.addReferenceObj("coder", coder, coder.getClass().getName());
-      ExprCode input = child.genCode(ctx);
-      String javaType = CodeGenerator.javaType(dataType());
-
-      List<String> parts = new ArrayList<>();
-      List<Object> args = new ArrayList<>();
-      /*
-        CODE GENERATED
-        final ${javaType} ${ev.value} = org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.EncodeUsingBeamCoder.encode(${input.isNull()}, ${input.value}, ${coder});
-      */
-      parts.add("final ");
-      args.add(javaType);
-      parts.add(" ");
-      args.add(ev.value());
-      parts.add(
-          " = org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.EncodeUsingBeamCoder.encode(");
-      args.add(input.isNull());
-      parts.add(", ");
-      args.add(input.value());
-      parts.add(", ");
-      args.add(accessCode);
-      parts.add(");");
-
-      StringContext sc =
-          new StringContext(JavaConversions.collectionAsScalaIterable(parts).toSeq());
-      Block code =
-          new Block.BlockHelper(sc).code(JavaConversions.collectionAsScalaIterable(args).toSeq());
-
-      return ev.copy(input.code().$plus(code), input.isNull(), ev.value());
-    }
-
-    @Override
-    public DataType dataType() {
-      return BinaryType;
-    }
-
-    @Override
-    public Object productElement(int n) {
-      switch (n) {
-        case 0:
-          return child;
-        case 1:
-          return coder;
-        default:
-          throw new ArrayIndexOutOfBoundsException("productElement out of bounds");
-      }
-    }
-
-    @Override
-    public int productArity() {
-      return 2;
-    }
-
-    @Override
-    public boolean canEqual(Object that) {
-      return (that instanceof EncodeUsingBeamCoder);
-    }
-
-    @Override
-    public boolean equals(@Nullable Object o) {
-      if (this == o) {
-        return true;
-      }
-      if (o == null || getClass() != o.getClass()) {
-        return false;
-      }
-      EncodeUsingBeamCoder<?> that = (EncodeUsingBeamCoder<?>) o;
-      return child.equals(that.child) && coder.equals(that.coder);
-    }
-
-    @Override
-    public int hashCode() {
-      return Objects.hash(super.hashCode(), child, coder);
-    }
-
-    /**
-     * Convert value to byte array (invoked by generated code in {@link #doGenCode(CodegenContext,
-     * ExprCode)}).
-     */
-    public static <T> byte[] encode(boolean isNull, @Nullable T value, Coder<T> coder) {
-      return isNull ? null : CoderHelpers.toByteArray(value, coder);
-    }
+  /** Catalyst Expression that serializes elements using Beam {@link Coder}. */
+  private static <T> Expression beamSerializer(Expression obj, Coder<T> coder) {
+    Expression[] args = {obj, lit(coder, Coder.class)};
+    return EncoderFactory.invokeIfNotNull(CoderHelpers.class, "toByteArray", BinaryType, args);
   }
 
-  /**
-   * Catalyst Expression that deserializes elements using Beam {@link Coder}.
-   *
-   * @param <T>: Type of elements ot be serialized.
-   */
-  public static class DecodeUsingBeamCoder<T> extends UnaryExpression
-      implements NonSQLExpression, Serializable {
-
-    private final Expression child;
-    private final ClassTag<T> classTag;
-    private final Coder<T> coder;
-
-    public DecodeUsingBeamCoder(Expression child, ClassTag<T> classTag, Coder<T> coder) {
-      this.child = child;
-      this.classTag = classTag;
-      this.coder = coder;
-    }
-
-    @Override
-    public Expression child() {
-      return child;
-    }
-
-    @Override
-    public ExprCode doGenCode(CodegenContext ctx, ExprCode ev) {
-      String accessCode = ctx.addReferenceObj("coder", coder, coder.getClass().getName());
-      ExprCode input = child.genCode(ctx);
-      String javaType = CodeGenerator.javaType(dataType());
-
-      List<String> parts = new ArrayList<>();
-      List<Object> args = new ArrayList<>();
-      /*
-        CODE GENERATED:
-        final ${javaType} ${ev.value} = (${javaType}) org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.DecodeUsingBeamCoder.decode(${input.value}, ${coder});
-      */
-      parts.add("final ");
-      args.add(javaType);
-      parts.add(" ");
-      args.add(ev.value());
-      parts.add(" = (");
-      args.add(javaType);
-      parts.add(
-          ") org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.DecodeUsingBeamCoder.decode(");
-      args.add(input.isNull());
-      parts.add(", ");
-      args.add(input.value());
-      parts.add(", ");
-      args.add(accessCode);
-      parts.add(");");
-
-      StringContext sc =
-          new StringContext(JavaConversions.collectionAsScalaIterable(parts).toSeq());
-      Block code =
-          new Block.BlockHelper(sc).code(JavaConversions.collectionAsScalaIterable(args).toSeq());
-      return ev.copy(input.code().$plus(code), input.isNull(), ev.value());
-    }
-
-    @Override
-    public DataType dataType() {
-      return new ObjectType(classTag.runtimeClass());
-    }
-
-    @Override
-    public Object productElement(int n) {
-      switch (n) {
-        case 0:
-          return child;
-        case 1:
-          return classTag;
-        case 2:
-          return coder;
-        default:
-          throw new ArrayIndexOutOfBoundsException("productElement out of bounds");
-      }
-    }
-
-    @Override
-    public int productArity() {
-      return 3;
-    }
-
-    @Override
-    public boolean canEqual(Object that) {
-      return (that instanceof DecodeUsingBeamCoder);
-    }
+  /** Catalyst Expression that deserializes elements using Beam {@link Coder}. */
+  private static <T> Expression beamDeserializer(Expression bytes, Coder<T> coder) {
+    Expression[] args = {bytes, lit(coder, Coder.class)};
+    return EncoderFactory.invokeIfNotNull(CoderHelpers.class, "fromByteArray", OBJECT_TYPE, args);
+  }
 
-    @Override
-    public boolean equals(@Nullable Object o) {
-      if (this == o) {
-        return true;
-      }
-      if (o == null || getClass() != o.getClass()) {
-        return false;
-      }
-      DecodeUsingBeamCoder<?> that = (DecodeUsingBeamCoder<?>) o;
-      return child.equals(that.child) && classTag.equals(that.classTag) && coder.equals(that.coder);
-    }
+  private static Expression rootRef(DataType dt, boolean nullable) {
+    return new BoundReference(0, dt, nullable);
+  }
 
-    @Override
-    public int hashCode() {
-      return Objects.hash(super.hashCode(), child, classTag, coder);
-    }
+  private static Expression rootCol(DataType dt) {
+    return new GetColumnByOrdinal(0, dt);
+  }
 
-    /**
-     * Convert value from byte array (invoked by generated code in {@link #doGenCode(CodegenContext,
-     * ExprCode)}).
-     */
-    public static <T> T decode(boolean isNull, byte @Nullable [] serialized, Coder<T> coder) {
-      return isNull ? null : CoderHelpers.fromByteArray(serialized, coder);
-    }
+  private static <T extends @NonNull Object> Literal lit(T obj, Class<? extends T> cls) {
+    return Literal.fromObject(obj, new ObjectType(cls));
   }
 }
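
Editor's note, a brief hedged usage sketch (not part of the commit, complementing the VarIntCoder test further down in this diff): fromBeamCoder also works for composite coders such as a KvCoder, yielding a Spark Encoder for Beam KV values backed entirely by Beam serialization. It assumes access to EncoderHelpers via its package or an import.

    // Illustrative only: deriving a Spark Encoder from a composite Beam coder.
    import java.util.Arrays;
    import org.apache.beam.sdk.coders.KvCoder;
    import org.apache.beam.sdk.coders.StringUtf8Coder;
    import org.apache.beam.sdk.coders.VarIntCoder;
    import org.apache.beam.sdk.values.KV;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Encoder;
    import org.apache.spark.sql.SparkSession;

    class BeamEncoderExample {
      static Dataset<KV<String, Integer>> createDataset(SparkSession session) {
        Encoder<KV<String, Integer>> encoder =
            EncoderHelpers.fromBeamCoder(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()));
        return session.createDataset(Arrays.asList(KV.of("a", 1), KV.of("b", 2)), encoder);
      }
    }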
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/InMemoryMetrics.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/InMemoryMetrics.java
index a4b3e5425c2..b69275b2e39 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/InMemoryMetrics.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/InMemoryMetrics.java
@@ -17,11 +17,13 @@
  */
 package org.apache.beam.runners.spark.aggregators.metrics.sink;
 
+import com.codahale.metrics.Gauge;
 import com.codahale.metrics.MetricFilter;
 import com.codahale.metrics.MetricRegistry;
+import java.util.Collection;
 import java.util.Properties;
 import org.apache.beam.runners.spark.metrics.WithMetricsSupport;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Predicates;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 import org.apache.spark.metrics.sink.Sink;
 
 /** An in-memory {@link Sink} implementation for tests. */
@@ -30,6 +32,7 @@ public class InMemoryMetrics implements Sink {
   private static WithMetricsSupport extendedMetricsRegistry;
   private static MetricRegistry internalMetricRegistry;
 
+  // Constructor for Spark 3.1.x and earlier
   @SuppressWarnings("UnusedParameters")
   public InMemoryMetrics(
       final Properties properties,
@@ -39,26 +42,24 @@ public class InMemoryMetrics implements Sink {
     internalMetricRegistry = metricRegistry;
   }
 
-  @SuppressWarnings("TypeParameterUnusedInFormals")
-  public static <T> T valueOf(final String name) {
-    final T retVal;
+  // Constructor for Spark >= 3.2
+  @SuppressWarnings("UnusedParameters")
+  public InMemoryMetrics(final Properties properties, final MetricRegistry metricRegistry) {
+    extendedMetricsRegistry = WithMetricsSupport.forRegistry(metricRegistry);
+    internalMetricRegistry = metricRegistry;
+  }
 
+  @SuppressWarnings({"TypeParameterUnusedInFormals", "rawtypes"})
+  public static <T> T valueOf(final String name) {
     // this might fail in case we have multiple aggregators with the same suffix after
     // the last dot, but it should be good enough for tests.
-    if (extendedMetricsRegistry != null
-        && extendedMetricsRegistry.getGauges().keySet().stream()
-            .anyMatch(Predicates.containsPattern(name + "$")::apply)) {
-      String key =
-          extendedMetricsRegistry.getGauges().keySet().stream()
-              .filter(Predicates.containsPattern(name + "$")::apply)
-              .findFirst()
-              .get();
-      retVal = (T) extendedMetricsRegistry.getGauges().get(key).getValue();
+    if (extendedMetricsRegistry != null) {
+      Collection<Gauge> matches =
+          extendedMetricsRegistry.getGauges((n, m) -> n.endsWith(name)).values();
+      return matches.isEmpty() ? null : (T) Iterables.getOnlyElement(matches).getValue();
     } else {
-      retVal = null;
+      return null;
     }
-
-    return retVal;
   }
 
   @SuppressWarnings("WeakerAccess")
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/SparkMetricsSinkTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/SparkMetricsSinkTest.java
index 0d067b53eb2..edf164b26b8 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/SparkMetricsSinkTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/SparkMetricsSinkTest.java
@@ -34,6 +34,7 @@ import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TimestampedValue;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
@@ -51,7 +52,10 @@ import org.junit.rules.ExternalResource;
  * streaming modes.
  */
 public class SparkMetricsSinkTest {
-  @ClassRule public static SparkContextRule contextRule = new SparkContextRule();
+  @ClassRule
+  public static SparkContextRule contextRule =
+      new SparkContextRule(
+          KV.of("spark.metrics.conf.*.sink.memory.class", InMemoryMetrics.class.getName()));
 
   @Rule public ExternalResource inMemoryMetricsSink = new InMemoryMetricsSinkRule();
 
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/metrics/SparkBeamMetricTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/metrics/SparkBeamMetricTest.java
index 1851e1db130..df96a657831 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/metrics/SparkBeamMetricTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/metrics/SparkBeamMetricTest.java
@@ -27,6 +27,7 @@ import org.junit.Test;
 
 /** Test SparkBeamMetric. */
 public class SparkBeamMetricTest {
+
   @Test
   public void testRenderName() {
     MetricResult<Object> metricResult =
@@ -35,10 +36,25 @@ public class SparkBeamMetricTest {
                 "myStep.one.two(three)", MetricName.named("myNameSpace//", "myName()")),
             123,
             456);
-    String renderedName = SparkBeamMetric.renderName(metricResult);
+    String renderedName = SparkBeamMetric.renderName("", metricResult);
     assertThat(
         "Metric name was not rendered correctly",
         renderedName,
         equalTo("myStep_one_two_three.myNameSpace__.myName__"));
   }
+
+  @Test
+  public void testRenderNameWithPrefix() {
+    MetricResult<Object> metricResult =
+        MetricResult.create(
+            MetricKey.create(
+                "myStep.one.two(three)", MetricName.named("myNameSpace//", "myName()")),
+            123,
+            456);
+    String renderedName = SparkBeamMetric.renderName("prefix", metricResult);
+    assertThat(
+        "Metric name was not rendered correctly",
+        renderedName,
+        equalTo("prefix.myStep_one_two_three.myNameSpace__.myName__"));
+  }
 }
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkSessionRule.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkSessionRule.java
index f68df83ac07..33eef26dddd 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkSessionRule.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkSessionRule.java
@@ -23,6 +23,9 @@ import java.io.Serializable;
 import java.util.Arrays;
 import java.util.Map;
 import javax.annotation.Nullable;
+import org.apache.beam.runners.spark.structuredstreaming.translation.SparkSessionFactory;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.values.KV;
 import org.apache.spark.sql.SparkSession;
 import org.junit.rules.ExternalResource;
@@ -34,13 +37,12 @@ public class SparkSessionRule extends ExternalResource implements Serializable {
   private transient @Nullable SparkSession session = null;
 
   public SparkSessionRule(String sparkMaster, Map<String, String> sparkConfig) {
-    builder = SparkSession.builder();
+    builder = SparkSessionFactory.sessionBuilder(sparkMaster);
     sparkConfig.forEach(builder::config);
-    builder.master(sparkMaster);
   }
 
   public SparkSessionRule(KV<String, String>... sparkConfig) {
-    this("local", sparkConfig);
+    this("local[2]", sparkConfig);
   }
 
   public SparkSessionRule(String sparkMaster, KV<String, String>... sparkConfig) {
@@ -54,6 +56,19 @@ public class SparkSessionRule extends ExternalResource implements Serializable {
     return session;
   }
 
+  public PipelineOptions createPipelineOptions() {
+    return configure(TestPipeline.testingPipelineOptions());
+  }
+
+  public PipelineOptions configure(PipelineOptions options) {
+    SparkStructuredStreamingPipelineOptions opts =
+        options.as(SparkStructuredStreamingPipelineOptions.class);
+    opts.setUseActiveSparkSession(true);
+    opts.setRunner(SparkStructuredStreamingRunner.class);
+    opts.setTestMode(true);
+    return opts;
+  }
+
   @Override
   public Statement apply(Statement base, Description description) {
     builder.appName(description.getDisplayName());
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetrics.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetrics.java
index 8649e91c761..f994f7712b3 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetrics.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetrics.java
@@ -17,22 +17,22 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.aggregators.metrics.sink;
 
+import com.codahale.metrics.Gauge;
 import com.codahale.metrics.MetricFilter;
 import com.codahale.metrics.MetricRegistry;
+import java.util.Collection;
 import java.util.Properties;
 import org.apache.beam.runners.spark.structuredstreaming.metrics.WithMetricsSupport;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Predicates;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 import org.apache.spark.metrics.sink.Sink;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
 
 /** An in-memory {@link Sink} implementation for tests. */
-@RunWith(JUnit4.class)
 public class InMemoryMetrics implements Sink {
 
   private static WithMetricsSupport extendedMetricsRegistry;
   private static MetricRegistry internalMetricRegistry;
 
+  // Constructor for Spark 3.1.x and earlier
   @SuppressWarnings("UnusedParameters")
   public InMemoryMetrics(
       final Properties properties,
@@ -42,26 +42,24 @@ public class InMemoryMetrics implements Sink {
     internalMetricRegistry = metricRegistry;
   }
 
-  @SuppressWarnings("TypeParameterUnusedInFormals")
-  public static <T> T valueOf(final String name) {
-    final T retVal;
+  // Constructor for Spark >= 3.2
+  @SuppressWarnings("UnusedParameters")
+  public InMemoryMetrics(final Properties properties, final MetricRegistry metricRegistry) {
+    extendedMetricsRegistry = WithMetricsSupport.forRegistry(metricRegistry);
+    internalMetricRegistry = metricRegistry;
+  }
 
+  @SuppressWarnings({"TypeParameterUnusedInFormals", "rawtypes"})
+  public static <T> T valueOf(final String name) {
     // this might fail in case we have multiple aggregators with the same suffix after
     // the last dot, but it should be good enough for tests.
-    if (extendedMetricsRegistry != null
-        && extendedMetricsRegistry.getGauges().keySet().stream()
-            .anyMatch(Predicates.containsPattern(name + "$")::apply)) {
-      String key =
-          extendedMetricsRegistry.getGauges().keySet().stream()
-              .filter(Predicates.containsPattern(name + "$")::apply)
-              .findFirst()
-              .get();
-      retVal = (T) extendedMetricsRegistry.getGauges().get(key).getValue();
+    if (extendedMetricsRegistry != null) {
+      Collection<Gauge> matches =
+          extendedMetricsRegistry.getGauges((n, m) -> n.endsWith(name)).values();
+      return matches.isEmpty() ? null : (T) Iterables.getOnlyElement(matches).getValue();
     } else {
-      retVal = null;
+      return null;
     }
-
-    return retVal;
   }
 
   @SuppressWarnings("WeakerAccess")
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/SparkMetricsSinkTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/SparkMetricsSinkTest.java
index 40b5036fe94..2f02656dc37 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/SparkMetricsSinkTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/SparkMetricsSinkTest.java
@@ -21,51 +21,39 @@ import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.nullValue;
 
-import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
-import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
+import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule;
 import org.apache.beam.runners.spark.structuredstreaming.examples.WordCount;
-import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
-import org.junit.BeforeClass;
-import org.junit.Ignore;
+import org.junit.ClassRule;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExternalResource;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
 
-/**
- * TODO: add testInStreamingMode() once streaming support will be implemented.
- *
- * <p>A test that verifies Beam metrics are reported to Spark's metrics sink in both batch and
- * streaming modes.
- */
-@Ignore("Has been failing since at least c350188ef7a8704c7336f3c20a1ab2144abbcd4a")
-@RunWith(JUnit4.class)
+/** A test that verifies Beam metrics are reported to Spark's metrics sink in batch mode. */
 public class SparkMetricsSinkTest {
-  @Rule public ExternalResource inMemoryMetricsSink = new InMemoryMetricsSinkRule();
+
+  @ClassRule
+  public static final SparkSessionRule SESSION =
+      new SparkSessionRule(
+          KV.of("spark.metrics.conf.*.sink.memory.class", InMemoryMetrics.class.getName()));
+
+  @Rule public final ExternalResource inMemoryMetricsSink = new InMemoryMetricsSinkRule();
+
+  @Rule
+  public final TestPipeline pipeline = TestPipeline.fromOptions(SESSION.createPipelineOptions());
 
   private static final ImmutableList<String> WORDS =
       ImmutableList.of("hi there", "hi", "hi sue bob", "hi sue", "", "bob hi");
   private static final ImmutableSet<String> EXPECTED_COUNTS =
       ImmutableSet.of("hi: 5", "there: 1", "sue: 2", "bob: 2");
-  private static Pipeline pipeline;
-
-  @BeforeClass
-  public static void beforeClass() {
-    SparkStructuredStreamingPipelineOptions options =
-        PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
-    options.setRunner(SparkStructuredStreamingRunner.class);
-    options.setTestMode(true);
-    pipeline = Pipeline.create(options);
-  }
 
   @Test
   public void testInBatchMode() throws Exception {
@@ -76,9 +64,10 @@ public class SparkMetricsSinkTest {
             .apply(Create.of(WORDS).withCoder(StringUtf8Coder.of()))
             .apply(new WordCount.CountWords())
             .apply(MapElements.via(new WordCount.FormatAsTextFn()));
+
     PAssert.that(output).containsInAnyOrder(EXPECTED_COUNTS);
     pipeline.run();
 
-    assertThat(InMemoryMetrics.<Double>valueOf("emptyLines"), is(1d));
+    assertThat(InMemoryMetrics.valueOf("emptyLines"), is(1d));
   }
 }
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/metrics/BeamMetricTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetricTest.java
similarity index 71%
rename from runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/metrics/BeamMetricTest.java
rename to runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetricTest.java
index a6989348e16..fd0aa35e5c8 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/metrics/BeamMetricTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetricTest.java
@@ -24,12 +24,9 @@ import org.apache.beam.sdk.metrics.MetricKey;
 import org.apache.beam.sdk.metrics.MetricName;
 import org.apache.beam.sdk.metrics.MetricResult;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
 
 /** Test BeamMetric. */
-@RunWith(JUnit4.class)
-public class BeamMetricTest {
+public class SparkBeamMetricTest {
   @Test
   public void testRenderName() {
     MetricResult<Object> metricResult =
@@ -38,10 +35,25 @@ public class BeamMetricTest {
                 "myStep.one.two(three)", MetricName.named("myNameSpace//", "myName()")),
             123,
             456);
-    String renderedName = new SparkBeamMetric().renderName(metricResult);
+    String renderedName = SparkBeamMetric.renderName("", metricResult);
     assertThat(
         "Metric name was not rendered correctly",
         renderedName,
         equalTo("myStep_one_two_three.myNameSpace__.myName__"));
   }
+
+  @Test
+  public void testRenderNameWithPrefix() {
+    MetricResult<Object> metricResult =
+        MetricResult.create(
+            MetricKey.create(
+                "myStep.one.two(three)", MetricName.named("myNameSpace//", "myName()")),
+            123,
+            456);
+    String renderedName = SparkBeamMetric.renderName("prefix", metricResult);
+    assertThat(
+        "Metric name was not rendered correctly",
+        renderedName,
+        equalTo("prefix.myStep_one_two_three.myNameSpace__.myName__"));
+  }
 }
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java
index 3151a5fe956..c8a8fba8d28 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java
@@ -17,13 +17,23 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
 
+import static java.util.Arrays.asList;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.fromBeamCoder;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.equalTo;
 import static org.junit.Assert.assertEquals;
 
 import java.util.Arrays;
 import java.util.List;
+import java.util.Objects;
 import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.DelegateCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.values.TypeDescriptor;
 import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
 import org.junit.ClassRule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -35,13 +45,54 @@ public class EncoderHelpersTest {
 
   @ClassRule public static SparkSessionRule sessionRule = new SparkSessionRule();
 
+  private <T> Dataset<T> createDataset(List<T> data, Encoder<T> encoder) {
+    Dataset<T> ds = sessionRule.getSession().createDataset(data, encoder);
+    ds.printSchema();
+    return ds;
+  }
+
   @Test
   public void beamCoderToSparkEncoderTest() {
     List<Integer> data = Arrays.asList(1, 2, 3);
-    Dataset<Integer> dataset =
-        sessionRule
-            .getSession()
-            .createDataset(data, EncoderHelpers.fromBeamCoder(VarIntCoder.of()));
+    Dataset<Integer> dataset = createDataset(data, EncoderHelpers.fromBeamCoder(VarIntCoder.of()));
     assertEquals(data, dataset.collectAsList());
   }
+
+  @Test
+  public void testBeamEncoderOfPrivateType() {
+    // Verify concrete types are not used in coder generation.
+    // In case of private types this would cause an IllegalAccessError.
+    List<PrivateString> data = asList(new PrivateString("1"), new PrivateString("2"));
+    Dataset<PrivateString> dataset = createDataset(data, fromBeamCoder(PrivateString.CODER));
+    assertThat(dataset.collect(), equalTo(data.toArray()));
+  }
+
+  private static class PrivateString {
+    private static final Coder<PrivateString> CODER =
+        DelegateCoder.of(
+            StringUtf8Coder.of(),
+            str -> str.string,
+            PrivateString::new,
+            new TypeDescriptor<PrivateString>() {});
+
+    private final String string;
+
+    public PrivateString(String string) {
+      this.string = string;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+      PrivateString that = (PrivateString) o;
+      return Objects.equals(string, that.string);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(string);
+    }
+  }
 }
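
Outside of the test rule, using fromBeamCoder looks roughly like the sketch below. This is a minimal, hedged example: the local master, application name, and StringUtf8Coder are arbitrary choices for illustration, not anything mandated by the runner.

    import java.util.Arrays;
    import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
    import org.apache.beam.sdk.coders.StringUtf8Coder;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Encoder;
    import org.apache.spark.sql.SparkSession;

    class FromBeamCoderExample {
      public static void main(String[] args) {
        SparkSession session =
            SparkSession.builder().master("local[2]").appName("beam-encoder-demo").getOrCreate();

        // Wrap a Beam coder as a Spark SQL Encoder and build a Dataset with it.
        Encoder<String> encoder = EncoderHelpers.fromBeamCoder(StringUtf8Coder.of());
        Dataset<String> dataset = session.createDataset(Arrays.asList("a", "b", "c"), encoder);
        dataset.show();

        session.stop();
      }
    }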
diff --git a/runners/spark/src/test/resources/metrics.properties b/runners/spark/src/test/resources/metrics.properties
deleted file mode 100644
index 78705c253df..00000000000
--- a/runners/spark/src/test/resources/metrics.properties
+++ /dev/null
@@ -1,68 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-# The "org.apache.beam.runners.spark.metrics.sink.XSink"
-# (a.k.a Beam.XSink) is only configured for the driver, the executors are set with a Spark native
-# implementation "org.apache.spark.metrics.sink.XSink" (a.k.a Spark.XSink).
-# This is due to sink class loading behavior, which is different on the driver and executor nodes.
-# Since Beam aggregators and metrics are reported via Spark accumulators and thus make their way to
-# the driver, we only need the "Beam.XSink" on the driver side. Executor nodes can keep
-# reporting Spark native metrics using the traditional Spark.XSink.
-#
-# The current sink configuration pattern is therefore:
-#
-# driver.**.class   = Beam.XSink
-# executor.**.class = Spark.XSink
-
-
-# ************* A metrics sink for tests *************
-*.sink.memory.class=org.apache.beam.runners.spark.aggregators.metrics.sink.InMemoryMetrics
-# ************* End of InMemoryMetrics sink configuration section *************
-
-
-# ************* A sample configuration for outputting metrics to Graphite *************
-
-#driver.sink.graphite.class=org.apache.beam.runners.spark.metrics.sink.GraphiteSink
-#driver.sink.graphite.host=YOUR_HOST
-#driver.sink.graphite.port=2003
-#driver.sink.graphite.prefix=spark
-#driver.sink.graphite.period=1
-#driver.sink.graphite.unit=SECONDS
-
-#executor.sink.graphite.class=org.apache.spark.metrics.sink.GraphiteSink
-#executor.sink.graphite.host=YOUR_HOST
-#executor.sink.graphite.port=2003
-#executor.sink.graphite.prefix=spark
-#executor.sink.graphite.period=1
-#executor.sink.graphite.unit=SECONDS
-
-# ************* End of Graphite sink configuration section *************
-
-
-# ************* A sample configuration for outputting metrics to a CSV file. *************
-
-#driver.sink.csv.class=org.apache.beam.runners.spark.metrics.sink.CsvSink
-#driver.sink.csv.directory=/tmp/spark-metrics
-#driver.sink.csv.period=1
-#driver.sink.graphite.unit=SECONDS
-
-#executor.sink.csv.class=org.apache.spark.metrics.sink.CsvSink
-#executor.sink.csv.directory=/tmp/spark-metrics
-#executor.sink.csv.period=1
-#executor.sink.graphite.unit=SECONDS
-
-# ************* End of CSV sink configuration section *************
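
With this test-only metrics.properties file removed, the same in-memory sink presumably has to be wired up programmatically in the tests. A hedged sketch of one way to do that, assuming Spark's documented support for passing metrics settings as configuration properties prefixed with spark.metrics.conf. (the sink name "memory" is arbitrary):

    import org.apache.spark.SparkConf;

    class InMemoryMetricsConfSketch {
      static SparkConf withInMemoryMetricsSink(SparkConf conf) {
        // Equivalent to the "*.sink.memory.class=..." line from the deleted properties file.
        return conf.set(
            "spark.metrics.conf.*.sink.memory.class",
            "org.apache.beam.runners.spark.aggregators.metrics.sink.InMemoryMetrics");
      }
    }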
diff --git a/website/www/site/content/en/documentation/runners/spark.md b/website/www/site/content/en/documentation/runners/spark.md
index abc1031840b..b5caeace2b2 100644
--- a/website/www/site/content/en/documentation/runners/spark.md
+++ b/website/www/site/content/en/documentation/runners/spark.md
@@ -443,7 +443,11 @@ You can monitor a running Spark job using the Spark [Web Interfaces](https://spa
 Spark also has a history server to [view after the fact](https://spark.apache.org/docs/latest/monitoring.html#viewing-after-the-fact).
 {{< paragraph class="language-java" >}}
 Metrics are also available via [REST API](https://spark.apache.org/docs/latest/monitoring.html#rest-api).
-Spark provides a [metrics system](https://spark.apache.org/docs/latest/monitoring.html#metrics) that allows reporting Spark metrics to a variety of Sinks. The Spark runner reports user-defined Beam Aggregators using this same metrics system and currently supports <code>GraphiteSink</code> and <code>CSVSink</code>, and providing support for additional Sinks supported by Spark is easy and straight-forward.
+Spark provides a [metrics system](https://spark.apache.org/docs/latest/monitoring.html#metrics) that allows reporting Spark metrics to a variety of Sinks.
+The Spark runner reports user-defined Beam Aggregators using this same metrics system and currently supports
+[GraphiteSink](https://beam.apache.org/releases/javadoc/{{< param release_latest >}}/org/apache/beam/runners/spark/metrics/sink/GraphiteSink.html)
+and [CSVSink](https://beam.apache.org/releases/javadoc/{{< param release_latest >}}/org/apache/beam/runners/spark/metrics/sink/CsvSink.html).
+Providing support for additional Sinks supported by Spark is straightforward.
 {{< /paragraph >}}
 {{< paragraph class="language-py" >}}Spark metrics are not yet supported on the portable runner.{{< /paragraph >}}
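
Independent of any Spark metrics sink, Beam metrics can also be queried straight from the PipelineResult, which can be handy when verifying that the sink wiring above actually reports anything. A small hedged example using the standard Beam metrics API (the namespace and the surrounding class are placeholders for illustration):

    import org.apache.beam.sdk.PipelineResult;
    import org.apache.beam.sdk.metrics.MetricNameFilter;
    import org.apache.beam.sdk.metrics.MetricQueryResults;
    import org.apache.beam.sdk.metrics.MetricResult;
    import org.apache.beam.sdk.metrics.MetricsFilter;

    class QueryBeamMetrics {
      // Print the attempted value of every counter in the given namespace.
      static void printCounters(PipelineResult result) {
        MetricQueryResults metrics =
            result.metrics().queryMetrics(
                MetricsFilter.builder()
                    .addNameFilter(MetricNameFilter.inNamespace("myNameSpace"))
                    .build());
        for (MetricResult<Long> counter : metrics.getCounters()) {
          System.out.println(counter.getName() + ": " + counter.getAttempted());
        }
      }
    }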