Posted to commits@beam.apache.org by ar...@apache.org on 2019/03/20 09:16:50 UTC

[beam] branch spark-runner_structured-streaming updated: Added SideInput support

This is an automated email from the ASF dual-hosted git repository.

aromanenko pushed a commit to branch spark-runner_structured-streaming
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/spark-runner_structured-streaming by this push:
     new 49ab275  Added SideInput support
49ab275 is described below

commit 49ab27554bc6fc44f5f5f23c5d0a6535fb4a158d
Author: Alexey Romanenko <ar...@gmail.com>
AuthorDate: Tue Mar 19 19:33:11 2019 +0100

    Added SideInput support
---
 .../translation/TranslationContext.java            |   5 +
 .../translation/batch/DoFnFunction.java            |  11 +-
 .../translation/batch/ParDoTranslatorBatch.java    |  48 +++++--
 .../batch/functions/NoOpSideInputReader.java       |  56 --------
 .../batch/functions/SparkSideInputReader.java      | 148 +++++++++++++++++++++
 .../translation/helpers/CoderHelpers.java          |  47 +++++++
 .../translation/helpers/SideInputBroadcast.java    |  28 ++++
 .../translation/batch/ParDoTest.java               |  80 +++++++++--
 8 files changed, 339 insertions(+), 84 deletions(-)

diff --git a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TranslationContext.java b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TranslationContext.java
index 013ef75..d2ace25 100644
--- a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TranslationContext.java
+++ b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TranslationContext.java
@@ -139,6 +139,11 @@ public class TranslationContext {
     }
   }
 
+  @SuppressWarnings("unchecked")
+  public <T> Dataset<T> getSideInputDataSet(PCollectionView<?> view) {
+    return (Dataset<T>) broadcastDataSets.get(view);
+  }
+
   // --------------------------------------------------------------------------------------------
   //  PCollections methods
   // --------------------------------------------------------------------------------------------
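
For context: this getter reads from the broadcastDataSets map that the translator
fills in when it visits the transform producing a side input. A hedged sketch of
the expected call pattern (the registration method name here is assumed, not
shown in this diff):

    // When translating the side-input producer (method name hypothetical):
    // context.setSideInputDataset(view, dataset);

    // When translating the consuming ParDo (added by this commit):
    Dataset<WindowedValue<?>> sideInputDataset = context.getSideInputDataSet(view);
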
diff --git a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java
index 0409a79..4449082 100644
--- a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java
+++ b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java
@@ -28,11 +28,11 @@ import java.util.Map;
 import org.apache.beam.runners.core.DoFnRunner;
 import org.apache.beam.runners.core.DoFnRunners;
 import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
-import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.NoOpSideInputReader;
 import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.NoOpStepContext;
+import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SparkSideInputReader;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
@@ -62,6 +62,7 @@ public class DoFnFunction<InputT, OutputT>
   private final TupleTag<OutputT> mainOutputTag;
   private final Coder<InputT> inputCoder;
   private final Map<TupleTag<?>, Coder<?>> outputCoderMap;
+  private final SideInputBroadcast broadcastStateData;
 
   public DoFnFunction(
       DoFn<InputT, OutputT> doFn,
@@ -71,7 +72,8 @@ public class DoFnFunction<InputT, OutputT>
       List<TupleTag<?>> additionalOutputTags,
       TupleTag<OutputT> mainOutputTag,
       Coder<InputT> inputCoder,
-      Map<TupleTag<?>, Coder<?>> outputCoderMap) {
+      Map<TupleTag<?>, Coder<?>> outputCoderMap,
+      SideInputBroadcast broadcastStateData) {
 
     this.doFn = doFn;
     this.sideInputs = sideInputs;
@@ -81,6 +83,7 @@ public class DoFnFunction<InputT, OutputT>
     this.mainOutputTag = mainOutputTag;
     this.inputCoder = inputCoder;
     this.outputCoderMap = outputCoderMap;
+    this.broadcastStateData = broadcastStateData;
   }
 
   @Override
@@ -93,7 +96,7 @@ public class DoFnFunction<InputT, OutputT>
         DoFnRunners.simpleRunner(
             serializableOptions.get(),
             doFn,
-            new NoOpSideInputReader(sideInputs),
+            new SparkSideInputReader(sideInputs, broadcastStateData),
             outputManager,
             mainOutputTag,
             additionalOutputTags,
diff --git a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
index 443ed67..651901a 100644
--- a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
+++ b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
@@ -21,18 +21,20 @@ import static com.google.common.base.Preconditions.checkState;
 
 import com.google.common.collect.Lists;
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import org.apache.beam.runners.core.construction.ParDoTranslation;
 import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
 import org.apache.beam.runners.spark.structuredstreaming.translation.TranslationContext;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers;
 import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionTuple;
@@ -40,6 +42,7 @@ import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.FilterFunction;
 import org.apache.spark.api.java.function.MapFunction;
 import org.apache.spark.sql.Dataset;
@@ -72,12 +75,6 @@ class ParDoTranslatorBatch<InputT, OutputT>
         signature.stateDeclarations().size() > 0 || signature.timerDeclarations().size() > 0;
     checkState(!stateful, "States and timers are not supported for the moment.");
 
-    // TODO: add support of SideInputs
-    List<PCollectionView<?>> sideInputs = getSideInputs(context);
-    System.out.println("sideInputs = " + sideInputs);
-    final boolean hasSideInputs = sideInputs != null && sideInputs.size() > 0;
-    checkState(!hasSideInputs, "SideInputs are not supported for the moment.");
-
     // Init main variables
     Dataset<WindowedValue<InputT>> inputDataSet = context.getDataset(context.getInput());
     Map<TupleTag<?>, PValue> outputs = context.getOutputs();
@@ -88,11 +85,14 @@ class ParDoTranslatorBatch<InputT, OutputT>
 
     // construct a map from side input to WindowingStrategy so that
     // the DoFn runner can map main-input windows to side input windows
+    List<PCollectionView<?>> sideInputs = getSideInputs(context);
     Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputStrategies = new HashMap<>();
     for (PCollectionView<?> sideInput : sideInputs) {
       sideInputStrategies.put(sideInput, sideInput.getPCollection().getWindowingStrategy());
     }
 
+    SideInputBroadcast broadcastStateData = createBroadcastSideInputs(sideInputs, context);
+
     Map<TupleTag<?>, Coder<?>> outputCoderMap = context.getOutputCoders();
     Coder<InputT> inputCoder = ((PCollection<InputT>) context.getInput()).getCoder();
 
@@ -106,7 +106,9 @@ class ParDoTranslatorBatch<InputT, OutputT>
             outputTags,
             mainOutputTag,
             inputCoder,
-            outputCoderMap);
+            outputCoderMap,
+            broadcastStateData);
 
     Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> allOutputs =
         inputDataSet.mapPartitions(doFnWrapper, EncoderHelpers.tuple2Encoder());
@@ -116,6 +118,32 @@ class ParDoTranslatorBatch<InputT, OutputT>
     }
   }
 
+  private static SideInputBroadcast createBroadcastSideInputs(
+      List<PCollectionView<?>> sideInputs, TranslationContext context) {
+    JavaSparkContext jsc =
+        JavaSparkContext.fromSparkContext(context.getSparkSession().sparkContext());
+
+    SideInputBroadcast sideInputBroadcast = new SideInputBroadcast();
+    for (PCollectionView<?> input : sideInputs) {
+      Coder<? extends BoundedWindow> windowCoder =
+          input.getPCollection().getWindowingStrategy().getWindowFn().windowCoder();
+      Coder<WindowedValue<?>> windowedValueCoder =
+          (Coder<WindowedValue<?>>)
+              (Coder<?>) WindowedValue.getFullCoder(input.getPCollection().getCoder(), windowCoder);
+
+      Dataset<WindowedValue<?>> broadcastSet = context.getSideInputDataSet(input);
+      List<WindowedValue<?>> valuesList = broadcastSet.collectAsList();
+      List<byte[]> codedValues = new ArrayList<>();
+      for (WindowedValue<?> v : valuesList) {
+        codedValues.add(CoderHelpers.toByteArray(v, windowedValueCoder));
+      }
+
+      sideInputBroadcast.add(
+          input.getTagInternal().getId(), jsc.broadcast(codedValues), windowedValueCoder);
+    }
+    return sideInputBroadcast;
+  }
+
   private List<PCollectionView<?>> getSideInputs(TranslationContext context) {
     List<PCollectionView<?>> sideInputs;
     try {
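
Note that createBroadcastSideInputs deliberately ships coder-encoded bytes
rather than the WindowedValue objects themselves, so Spark's serializer never
has to handle arbitrary Beam element types. A condensed sketch of the round
trip, assuming the same names as in the method above (broadcastSet,
windowedValueCoder, jsc):

    // Driver: materialize the side-input Dataset and encode each element.
    List<byte[]> codedValues = new ArrayList<>();
    for (WindowedValue<?> v : broadcastSet.collectAsList()) {
      codedValues.add(CoderHelpers.toByteArray(v, windowedValueCoder));
    }
    Broadcast<List<byte[]>> broadcast = jsc.broadcast(codedValues);

    // Executor (see SparkSideInputReader below): decode on access.
    for (byte[] bytes : broadcast.getValue()) {
      WindowedValue<?> value = CoderHelpers.fromByteArray(bytes, windowedValueCoder);
    }
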
diff --git a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/NoOpSideInputReader.java b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/NoOpSideInputReader.java
deleted file mode 100644
index eca9d95..0000000
--- a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/NoOpSideInputReader.java
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions;
-
-import java.util.HashMap;
-import java.util.Map;
-import javax.annotation.Nullable;
-import org.apache.beam.runners.core.SideInputReader;
-import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
-import org.apache.beam.sdk.values.PCollectionView;
-import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.sdk.values.WindowingStrategy;
-
-/**
- * TODO: Need to be implemented
- *
- * <p>A {@link SideInputReader} for the Spark Batch Runner.
- */
-public class NoOpSideInputReader implements SideInputReader {
-  private final Map<TupleTag<?>, WindowingStrategy<?, ?>> sideInputs;
-
-  public NoOpSideInputReader(Map<PCollectionView<?>, WindowingStrategy<?, ?>> indexByView) {
-    sideInputs = new HashMap<>();
-  }
-
-  @Nullable
-  @Override
-  public <T> T get(PCollectionView<T> view, BoundedWindow window) {
-    return null;
-  }
-
-  @Override
-  public <T> boolean contains(PCollectionView<T> view) {
-    return sideInputs.containsKey(view.getTagInternal());
-  }
-
-  @Override
-  public boolean isEmpty() {
-    return sideInputs.isEmpty();
-  }
-}
diff --git a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/SparkSideInputReader.java b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/SparkSideInputReader.java
new file mode 100644
index 0000000..91b4f83
--- /dev/null
+++ b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/SparkSideInputReader.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkNotNull;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.core.InMemoryMultimapSideInputView;
+import org.apache.beam.runners.core.SideInputReader;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.transforms.Materializations;
+import org.apache.beam.sdk.transforms.ViewFn;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.WindowingStrategy;
+
+/** A {@link SideInputReader} for the Spark Batch Runner. */
+public class SparkSideInputReader implements SideInputReader {
+  /** A {@link Materializations.MultimapView} which always returns an empty iterable. */
+  private static final Materializations.MultimapView EMPTY_MULTIMAP_VIEW =
+      o -> Collections.emptyList();
+
+  private final Map<TupleTag<?>, WindowingStrategy<?, ?>> sideInputs;
+  private final SideInputBroadcast broadcastStateData;
+
+  public SparkSideInputReader(
+      Map<PCollectionView<?>, WindowingStrategy<?, ?>> indexByView,
+      SideInputBroadcast broadcastStateData) {
+    for (PCollectionView<?> view : indexByView.keySet()) {
+      checkArgument(
+          Materializations.MULTIMAP_MATERIALIZATION_URN.equals(
+              view.getViewFn().getMaterialization().getUrn()),
+          "This handler is only capable of dealing with %s materializations "
+              + "but was asked to handle %s for PCollectionView with tag %s.",
+          Materializations.MULTIMAP_MATERIALIZATION_URN,
+          view.getViewFn().getMaterialization().getUrn(),
+          view.getTagInternal().getId());
+    }
+    sideInputs = new HashMap<>();
+    for (Map.Entry<PCollectionView<?>, WindowingStrategy<?, ?>> entry : indexByView.entrySet()) {
+      sideInputs.put(entry.getKey().getTagInternal(), entry.getValue());
+    }
+    this.broadcastStateData = broadcastStateData;
+  }
+
+  @Nullable
+  @Override
+  public <T> T get(PCollectionView<T> view, BoundedWindow window) {
+    checkNotNull(view, "View passed to sideInput cannot be null");
+    TupleTag<?> tag = view.getTagInternal();
+    checkNotNull(sideInputs.get(tag), "Side input for %s is not available.", view);
+
+    List<byte[]> encodedValues =
+        (List<byte[]>) broadcastStateData.getBroadcastValue(tag.getId()).getValue();
+    Coder<?> coder = broadcastStateData.getCoder(tag.getId());
+
+    List<WindowedValue<?>> decodedValues = new ArrayList<>();
+    for (byte[] value : encodedValues) {
+      decodedValues.add((WindowedValue<?>) CoderHelpers.fromByteArray(value, coder));
+    }
+
+    Map<BoundedWindow, T> sideInputsByWindow = initializeBroadcastVariable(decodedValues, view);
+    T result = sideInputsByWindow.get(window);
+    if (result == null) {
+      ViewFn<Materializations.MultimapView, T> viewFn =
+          (ViewFn<Materializations.MultimapView, T>) view.getViewFn();
+      result = viewFn.apply(EMPTY_MULTIMAP_VIEW);
+    }
+    return result;
+  }
+
+  @Override
+  public <T> boolean contains(PCollectionView<T> view) {
+    return sideInputs.containsKey(view.getTagInternal());
+  }
+
+  @Override
+  public boolean isEmpty() {
+    return sideInputs.isEmpty();
+  }
+
+  public <T> Map<BoundedWindow, T> initializeBroadcastVariable(
+      Iterable<WindowedValue<?>> inputValues, PCollectionView<T> view) {
+
+    // first partition into windows
+    Map<BoundedWindow, List<WindowedValue<KV<?, ?>>>> partitionedElements = new HashMap<>();
+    for (WindowedValue<KV<?, ?>> value :
+        (Iterable<WindowedValue<KV<?, ?>>>) (Iterable) inputValues) {
+      for (BoundedWindow window : value.getWindows()) {
+        List<WindowedValue<KV<?, ?>>> windowedValues =
+            partitionedElements.computeIfAbsent(window, k -> new ArrayList<>());
+        windowedValues.add(value);
+      }
+    }
+
+    Map<BoundedWindow, T> resultMap = new HashMap<>();
+
+    for (Map.Entry<BoundedWindow, List<WindowedValue<KV<?, ?>>>> elements :
+        partitionedElements.entrySet()) {
+
+      ViewFn<Materializations.MultimapView, T> viewFn =
+          (ViewFn<Materializations.MultimapView, T>) view.getViewFn();
+      Coder keyCoder = ((KvCoder<?, ?>) view.getCoderInternal()).getKeyCoder();
+      resultMap.put(
+          elements.getKey(),
+          (T)
+              viewFn.apply(
+                  InMemoryMultimapSideInputView.fromIterable(
+                      keyCoder,
+                      (Iterable)
+                          elements.getValue().stream()
+                              .map(WindowedValue::getValue)
+                              .collect(Collectors.toList()))));
+    }
+
+    return resultMap;
+  }
+}
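
To make the multimap materialization concrete: a View.asMap() side input
arrives here as WindowedValue<KV<K, V>> elements, and the ViewFn turns the
per-window bag of KVs into the Map the DoFn observes. A minimal sketch of that
last step for a Map<String, Integer> view, mirroring initializeBroadcastVariable
(view and kvsInWindow are assumed to be in scope):

    // ViewFn translates the multimap primitive view into the user-facing type.
    ViewFn<Materializations.MultimapView, Map<String, Integer>> viewFn =
        (ViewFn<Materializations.MultimapView, Map<String, Integer>>) view.getViewFn();
    // The view's internal coder is a KvCoder; only the key coder is needed here.
    Coder<String> keyCoder = ((KvCoder<String, Integer>) view.getCoderInternal()).getKeyCoder();
    Map<String, Integer> materialized =
        viewFn.apply(InMemoryMultimapSideInputView.fromIterable(keyCoder, kvsInWindow));
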
diff --git a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java
new file mode 100644
index 0000000..6764dd8
--- /dev/null
+++ b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import org.apache.beam.sdk.coders.Coder;
+
+/** Serialization utility class. */
+public final class CoderHelpers {
+  private CoderHelpers() {}
+
+  /**
+   * Utility method for serializing an object using the specified coder.
+   *
+   * @param value Value to serialize.
+   * @param coder Coder to serialize with.
+   * @param <T> type of value that is serialized
+   * @return Byte array representing serialized object.
+   */
+  public static <T> byte[] toByteArray(T value, Coder<T> coder) {
+    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    try {
+      coder.encode(value, baos, new Coder.Context(true));
+    } catch (IOException e) {
+      throw new IllegalStateException("Error encoding value: " + value, e);
+    }
+    return baos.toByteArray();
+  }
+
+  /**
+   * Utility method for deserializing a byte array using the specified coder.
+   *
+   * @param serialized bytearray to be deserialized.
+   * @param coder Coder to deserialize with.
+   * @param <T> Type of object to be returned.
+   * @return Deserialized object.
+   */
+  public static <T> T fromByteArray(byte[] serialized, Coder<T> coder) {
+    ByteArrayInputStream bais = new ByteArrayInputStream(serialized);
+    try {
+      return coder.decode(bais, new Coder.Context(true));
+    } catch (IOException e) {
+      throw new IllegalStateException("Error decoding bytes for coder: " + coder, e);
+    }
+  }
+}
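
A minimal round trip through these helpers, using Beam's stock VarIntCoder
(any coder works the same way):

    import org.apache.beam.sdk.coders.VarIntCoder;

    // Encode on the driver...
    byte[] bytes = CoderHelpers.toByteArray(42, VarIntCoder.of());
    // ...decode on an executor after the broadcast.
    Integer restored = CoderHelpers.fromByteArray(bytes, VarIntCoder.of());
    // restored == 42
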
diff --git a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/SideInputBroadcast.java b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/SideInputBroadcast.java
new file mode 100644
index 0000000..a67a595
--- /dev/null
+++ b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/SideInputBroadcast.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.spark.broadcast.Broadcast;
+
+/**
+ * Holds the broadcast variables (coder-encoded side-input values) and their coders, keyed by the
+ * side input's tag id.
+ */
+public class SideInputBroadcast implements Serializable {
+
+  private final Map<String, Broadcast<?>> bcast = new HashMap<>();
+  private final Map<String, Coder<?>> coder = new HashMap<>();
+
+  public SideInputBroadcast() {}
+
+  public void add(String key, Broadcast<?> bcast, Coder<?> coder) {
+    this.bcast.put(key, bcast);
+    this.coder.put(key, coder);
+  }
+
+  public Broadcast<?> getBroadcastValue(String key) {
+    return bcast.get(key);
+  }
+
+  public Coder<?> getCoder(String key) {
+    return coder.get(key);
+  }
+}
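
The intended lifecycle, sketched (driver registers, executors read; view, jsc,
codedValues and windowedValueCoder are assumed in scope, as in
ParDoTranslatorBatch above):

    // Driver side: one broadcast per side input, keyed by the view's tag id.
    SideInputBroadcast sideInputBroadcast = new SideInputBroadcast();
    sideInputBroadcast.add(
        view.getTagInternal().getId(), jsc.broadcast(codedValues), windowedValueCoder);

    // Executor side (see SparkSideInputReader): fetch bytes and coder by tag id.
    String tagId = view.getTagInternal().getId();
    List<byte[]> values = (List<byte[]>) sideInputBroadcast.getBroadcastValue(tagId).getValue();
    Coder<?> coder = sideInputBroadcast.getCoder(tagId);
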
diff --git a/runners/spark-structured-streaming/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java b/runners/spark-structured-streaming/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java
index c028dc0..b7a682d 100644
--- a/runners/spark-structured-streaming/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java
+++ b/runners/spark-structured-streaming/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java
@@ -19,6 +19,7 @@ package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 
 import java.io.Serializable;
 import java.util.List;
+import java.util.Map;
 
 import org.apache.beam.runners.spark.structuredstreaming.SparkPipelineOptions;
 import org.apache.beam.runners.spark.structuredstreaming.SparkRunner;
@@ -29,6 +30,7 @@ import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.junit.BeforeClass;
@@ -89,24 +91,74 @@ public class ParDoTest implements Serializable {
   }
 
   @Test
-  public void testSideInput() {
-    PCollection<Integer> input = pipeline.apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
-    final PCollectionView<List<Integer>> sideInput =
-        input.apply(View.asList());
+  public void testSideInputAsList() {
+    PCollection<Integer> sideInput = pipeline.apply("Create sideInput", Create.of(101, 102, 103));
+    final PCollectionView<List<Integer>> sideInputView = sideInput.apply(View.asList());
 
+    PCollection<Integer> input =
+        pipeline.apply("Create input", Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
     input.apply(
         ParDo.of(
-            new DoFn<Integer, Integer>() {
-              @ProcessElement
-              public void processElement(ProcessContext context) {
-                List<Integer> list = context.sideInput(sideInput);
+                new DoFn<Integer, Integer>() {
+                  @ProcessElement
+                  public void processElement(ProcessContext context) {
+                    List<Integer> sideInputValue = context.sideInput(sideInputView);
+                    Integer val = context.element();
+                    context.output(val);
+                    System.out.println(
+                        "ParDo1: val = " + val + ", sideInputValue = " + sideInputValue);
+                  }
+                })
+            .withSideInputs(sideInputView));
 
-                Integer val = context.element();
-                context.output(val);
-                System.out.println("ParDo1: val = " + val + ", sideInput = " + list);
-              }
-            })
-            .withSideInputs(sideInput));
+    pipeline.run();
+  }
+
+  @Test
+  public void testSideInputAsSingleton() {
+    PCollection<Integer> sideInput = pipeline.apply("Create sideInput", Create.of(101));
+    final PCollectionView<Integer> sideInputView = sideInput.apply(View.asSingleton());
+
+    PCollection<Integer> input =
+        pipeline.apply("Create input", Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
+    input.apply(
+        ParDo.of(
+                new DoFn<Integer, Integer>() {
+                  @ProcessElement
+                  public void processElement(ProcessContext context) {
+                    Integer sideInputValue = context.sideInput(sideInputView);
+                    Integer val = context.element();
+                    context.output(val);
+                    System.out.println(
+                        "ParDo1: val = " + val + ", sideInputValue = " + sideInputValue);
+                  }
+                })
+            .withSideInputs(sideInputView));
+
+    pipeline.run();
+  }
+
+  @Test
+  public void testSideInputAsMap() {
+    PCollection<KV<String, Integer>> sideInput =
+        pipeline.apply("Create sideInput", Create.of(KV.of("key1", 1), KV.of("key2", 2)));
+    final PCollectionView<Map<String, Integer>> sideInputView = sideInput.apply(View.asMap());
+
+    PCollection<Integer> input =
+        pipeline.apply("Create input", Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
+    input.apply(
+        ParDo.of(
+                new DoFn<Integer, Integer>() {
+                  @ProcessElement
+                  public void processElement(ProcessContext context) {
+                    Map<String, Integer> sideInputValue = context.sideInput(sideInputView);
+                    Integer val = context.element();
+                    context.output(val);
+                    System.out.println(
+                        "ParDo1: val = " + val + ", sideInputValue = " + sideInputValue);
+                  }
+                })
+            .withSideInputs(sideInputView));
 
     pipeline.run();
   }