Posted to commits@beam.apache.org by ar...@apache.org on 2019/03/20 09:16:50 UTC
[beam] branch spark-runner_structured-streaming updated: Added SideInput support
This is an automated email from the ASF dual-hosted git repository.
aromanenko pushed a commit to branch spark-runner_structured-streaming
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/spark-runner_structured-streaming by this push:
new 49ab275 Added SideInput support
49ab275 is described below
commit 49ab27554bc6fc44f5f5f23c5d0a6535fb4a158d
Author: Alexey Romanenko <ar...@gmail.com>
AuthorDate: Tue Mar 19 19:33:11 2019 +0100
Added SideInput support
---
.../translation/TranslationContext.java | 5 +
.../translation/batch/DoFnFunction.java | 11 +-
.../translation/batch/ParDoTranslatorBatch.java | 48 +++++--
.../batch/functions/NoOpSideInputReader.java | 56 --------
.../batch/functions/SparkSideInputReader.java | 148 +++++++++++++++++++++
.../translation/helpers/CoderHelpers.java | 47 +++++++
.../translation/helpers/SideInputBroadcast.java | 28 ++++
.../translation/batch/ParDoTest.java | 80 +++++++++--
8 files changed, 339 insertions(+), 84 deletions(-)
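
Taken together, the changes work as follows: at ParDo translation time each side input Dataset is collected to the driver, its windowed values are coder-encoded to byte arrays and registered in a SideInputBroadcast holder as Spark broadcast variables; at execution time the new SparkSideInputReader decodes those bytes on the executor, partitions the values by window, and materializes the side input through the view's ViewFn. The per-file changes below follow that order.
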
diff --git a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TranslationContext.java b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TranslationContext.java
index 013ef75..d2ace25 100644
--- a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TranslationContext.java
+++ b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TranslationContext.java
@@ -139,6 +139,11 @@ public class TranslationContext {
}
}
+ @SuppressWarnings("unchecked")
+ public <T> Dataset<T> getSideInputDataSet(PCollectionView<?> value) {
+ return (Dataset<T>) broadcastDataSets.get(value);
+ }
+
// --------------------------------------------------------------------------------------------
// PCollections methods
// --------------------------------------------------------------------------------------------
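
The accessor above reads from the broadcastDataSets map, so registration of a side input's Dataset is assumed to happen elsewhere in TranslationContext (not shown in this hunk). A minimal sketch of that counterpart, assuming broadcastDataSets is a Map<PCollectionView<?>, Dataset<?>> and with setSideInputDataSet as a hypothetical name:

    // Hypothetical counterpart (not part of this diff): registers the dataset that backs a
    // view so that getSideInputDataSet() can find it during ParDo translation.
    public <T> void setSideInputDataSet(PCollectionView<?> value, Dataset<WindowedValue<T>> set) {
      if (!broadcastDataSets.containsKey(value)) {
        broadcastDataSets.put(value, set);
      }
    }
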
diff --git a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java
index 0409a79..4449082 100644
--- a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java
+++ b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java
@@ -28,11 +28,11 @@ import java.util.Map;
import org.apache.beam.runners.core.DoFnRunner;
import org.apache.beam.runners.core.DoFnRunners;
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
-import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.NoOpSideInputReader;
+import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SparkSideInputReader;
import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.NoOpStepContext;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
@@ -62,6 +62,7 @@ public class DoFnFunction<InputT, OutputT>
private final TupleTag<OutputT> mainOutputTag;
private final Coder<InputT> inputCoder;
private final Map<TupleTag<?>, Coder<?>> outputCoderMap;
+ private final SideInputBroadcast broadcastStateData;
public DoFnFunction(
DoFn<InputT, OutputT> doFn,
@@ -71,7 +72,8 @@ public class DoFnFunction<InputT, OutputT>
List<TupleTag<?>> additionalOutputTags,
TupleTag<OutputT> mainOutputTag,
Coder<InputT> inputCoder,
- Map<TupleTag<?>, Coder<?>> outputCoderMap) {
+ Map<TupleTag<?>, Coder<?>> outputCoderMap,
+ SideInputBroadcast broadcastStateData) {
this.doFn = doFn;
this.sideInputs = sideInputs;
@@ -81,6 +83,7 @@ public class DoFnFunction<InputT, OutputT>
this.mainOutputTag = mainOutputTag;
this.inputCoder = inputCoder;
this.outputCoderMap = outputCoderMap;
+ this.broadcastStateData = broadcastStateData;
}
@Override
@@ -93,7 +96,7 @@ public class DoFnFunction<InputT, OutputT>
DoFnRunners.simpleRunner(
serializableOptions.get(),
doFn,
- new NoOpSideInputReader(sideInputs),
+ new SparkSideInputReader(sideInputs, broadcastStateData),
outputManager,
mainOutputTag,
additionalOutputTags,
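
With the reader swapped in, the DoFnRunner consults it whenever the user's DoFn calls ProcessContext.sideInput. A simplified sketch of the lookup contract the runner relies on (lookupSideInput is illustrative, not a Beam method; sideInputs and broadcastStateData are the fields above):

    // Illustrative only: how a side input lookup flows through the new reader.
    private <ViewT> ViewT lookupSideInput(
        PCollectionView<ViewT> view, BoundedWindow window, SparkSideInputReader reader) {
      if (!reader.contains(view)) {
        throw new IllegalArgumentException(
            "Unknown side input: " + view.getTagInternal().getId());
      }
      // Decodes the broadcast bytes and materializes the value for the element's window.
      return reader.get(view, window);
    }
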
diff --git a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
index 443ed67..651901a 100644
--- a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
+++ b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
@@ -21,18 +21,20 @@ import static com.google.common.base.Preconditions.checkState;
import com.google.common.collect.Lists;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.beam.runners.core.construction.ParDoTranslation;
import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
import org.apache.beam.runners.spark.structuredstreaming.translation.TranslationContext;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
@@ -40,6 +42,7 @@ import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FilterFunction;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
@@ -72,12 +75,6 @@ class ParDoTranslatorBatch<InputT, OutputT>
signature.stateDeclarations().size() > 0 || signature.timerDeclarations().size() > 0;
checkState(!stateful, "States and timers are not supported for the moment.");
- // TODO: add support of SideInputs
- List<PCollectionView<?>> sideInputs = getSideInputs(context);
- System.out.println("sideInputs = " + sideInputs);
- final boolean hasSideInputs = sideInputs != null && sideInputs.size() > 0;
- checkState(!hasSideInputs, "SideInputs are not supported for the moment.");
-
// Init main variables
Dataset<WindowedValue<InputT>> inputDataSet = context.getDataset(context.getInput());
Map<TupleTag<?>, PValue> outputs = context.getOutputs();
@@ -88,11 +85,14 @@ class ParDoTranslatorBatch<InputT, OutputT>
// construct a map from side input to WindowingStrategy so that
// the DoFn runner can map main-input windows to side input windows
+ List<PCollectionView<?>> sideInputs = getSideInputs(context);
Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputStrategies = new HashMap<>();
for (PCollectionView<?> sideInput : sideInputs) {
sideInputStrategies.put(sideInput, sideInput.getPCollection().getWindowingStrategy());
}
+ SideInputBroadcast broadcastStateData = createBroadcastSideInputs(sideInputs, context);
+
Map<TupleTag<?>, Coder<?>> outputCoderMap = context.getOutputCoders();
Coder<InputT> inputCoder = ((PCollection<InputT>) context.getInput()).getCoder();
@@ -106,7 +106,9 @@ class ParDoTranslatorBatch<InputT, OutputT>
outputTags,
mainOutputTag,
inputCoder,
- outputCoderMap);
+ outputCoderMap,
+ broadcastStateData);
Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> allOutputs =
inputDataSet.mapPartitions(doFnWrapper, EncoderHelpers.tuple2Encoder());
@@ -116,6 +118,32 @@ class ParDoTranslatorBatch<InputT, OutputT>
}
}
+ @SuppressWarnings("unchecked")
+ private static SideInputBroadcast createBroadcastSideInputs(
+ List<PCollectionView<?>> sideInputs, TranslationContext context) {
+ JavaSparkContext jsc =
+ JavaSparkContext.fromSparkContext(context.getSparkSession().sparkContext());
+
+ SideInputBroadcast sideInputBroadcast = new SideInputBroadcast();
+ for (PCollectionView<?> input : sideInputs) {
+ Coder<? extends BoundedWindow> windowCoder =
+ input.getPCollection().getWindowingStrategy().getWindowFn().windowCoder();
+ Coder<WindowedValue<?>> windowedValueCoder =
+ (Coder<WindowedValue<?>>)
+ (Coder<?>) WindowedValue.getFullCoder(input.getPCollection().getCoder(), windowCoder);
+
+ Dataset<WindowedValue<?>> broadcastSet = context.getSideInputDataSet(input);
+ List<WindowedValue<?>> valuesList = broadcastSet.collectAsList();
+ List<byte[]> codedValues = new ArrayList<>();
+ for (WindowedValue<?> v : valuesList) {
+ codedValues.add(CoderHelpers.toByteArray(v, windowedValueCoder));
+ }
+
+ sideInputBroadcast.add(
+ input.getTagInternal().getId(), jsc.broadcast(codedValues), windowedValueCoder);
+ }
+ return sideInputBroadcast;
+ }
+
private List<PCollectionView<?>> getSideInputs(TranslationContext context) {
List<PCollectionView<?>> sideInputs;
try {
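
createBroadcastSideInputs collects each side input Dataset on the driver and broadcasts coder-encoded bytes rather than the objects themselves, so the values only need a Beam coder, not Java serializability. The round-trip in isolation, reusing the local names from the method above:

    // Driver side: encode each windowed value with the full WindowedValue coder, then broadcast.
    List<byte[]> codedValues = new ArrayList<>();
    for (WindowedValue<?> v : broadcastSet.collectAsList()) {
      codedValues.add(CoderHelpers.toByteArray(v, windowedValueCoder));
    }
    Broadcast<List<byte[]>> broadcast = jsc.broadcast(codedValues);

    // Executor side (see SparkSideInputReader.get below): decode on first use.
    List<WindowedValue<?>> decoded = new ArrayList<>();
    for (byte[] bytes : broadcast.getValue()) {
      decoded.add((WindowedValue<?>) CoderHelpers.fromByteArray(bytes, windowedValueCoder));
    }
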
diff --git a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/NoOpSideInputReader.java b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/NoOpSideInputReader.java
deleted file mode 100644
index eca9d95..0000000
--- a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/NoOpSideInputReader.java
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions;
-
-import java.util.HashMap;
-import java.util.Map;
-import javax.annotation.Nullable;
-import org.apache.beam.runners.core.SideInputReader;
-import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
-import org.apache.beam.sdk.values.PCollectionView;
-import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.sdk.values.WindowingStrategy;
-
-/**
- * TODO: Need to be implemented
- *
- * <p>A {@link SideInputReader} for the Spark Batch Runner.
- */
-public class NoOpSideInputReader implements SideInputReader {
- private final Map<TupleTag<?>, WindowingStrategy<?, ?>> sideInputs;
-
- public NoOpSideInputReader(Map<PCollectionView<?>, WindowingStrategy<?, ?>> indexByView) {
- sideInputs = new HashMap<>();
- }
-
- @Nullable
- @Override
- public <T> T get(PCollectionView<T> view, BoundedWindow window) {
- return null;
- }
-
- @Override
- public <T> boolean contains(PCollectionView<T> view) {
- return sideInputs.containsKey(view.getTagInternal());
- }
-
- @Override
- public boolean isEmpty() {
- return sideInputs.isEmpty();
- }
-}
diff --git a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/SparkSideInputReader.java b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/SparkSideInputReader.java
new file mode 100644
index 0000000..91b4f83
--- /dev/null
+++ b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/SparkSideInputReader.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkNotNull;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.core.InMemoryMultimapSideInputView;
+import org.apache.beam.runners.core.SideInputReader;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.transforms.Materializations;
+import org.apache.beam.sdk.transforms.ViewFn;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.WindowingStrategy;
+
+/** A {@link SideInputReader} for the Spark Batch Runner. */
+public class SparkSideInputReader implements SideInputReader {
+ /** A {@link Materializations.MultimapView} which always returns an empty iterable. */
+ private static final Materializations.MultimapView EMPTY_MULTIMAP_VIEW =
+ o -> Collections.emptyList();
+
+ private final Map<TupleTag<?>, WindowingStrategy<?, ?>> sideInputs;
+ private final SideInputBroadcast broadcastStateData;
+
+ public SparkSideInputReader(
+ Map<PCollectionView<?>, WindowingStrategy<?, ?>> indexByView,
+ SideInputBroadcast broadcastStateData) {
+ for (PCollectionView<?> view : indexByView.keySet()) {
+ checkArgument(
+ Materializations.MULTIMAP_MATERIALIZATION_URN.equals(
+ view.getViewFn().getMaterialization().getUrn()),
+ "This handler is only capable of dealing with %s materializations "
+ + "but was asked to handle %s for PCollectionView with tag %s.",
+ Materializations.MULTIMAP_MATERIALIZATION_URN,
+ view.getViewFn().getMaterialization().getUrn(),
+ view.getTagInternal().getId());
+ }
+ sideInputs = new HashMap<>();
+ for (Map.Entry<PCollectionView<?>, WindowingStrategy<?, ?>> entry : indexByView.entrySet()) {
+ sideInputs.put(entry.getKey().getTagInternal(), entry.getValue());
+ }
+ this.broadcastStateData = broadcastStateData;
+ }
+
+ @Nullable
+ @Override
+ public <T> T get(PCollectionView<T> view, BoundedWindow window) {
+ checkNotNull(view, "View passed to sideInput cannot be null");
+ TupleTag<?> tag = view.getTagInternal();
+ checkNotNull(sideInputs.get(tag), "Side input for " + view + " not available.");
+
+ List<byte[]> sideInputsValues =
+ (List<byte[]>) broadcastStateData.getBroadcastValue(tag.getId()).getValue();
+ Coder<?> coder = broadcastStateData.getCoder(tag.getId());
+
+ List<WindowedValue<?>> decodedValues = new ArrayList<>();
+ for (byte[] value : sideInputsValues) {
+ decodedValues.add((WindowedValue<?>) CoderHelpers.fromByteArray(value, coder));
+ }
+
+ Map<BoundedWindow, T> sideInputValues = initializeBroadcastVariable(decodedValues, view);
+ T result = sideInputValues.get(window);
+ if (result == null) {
+ ViewFn<Materializations.MultimapView, T> viewFn =
+ (ViewFn<Materializations.MultimapView, T>) view.getViewFn();
+ result = viewFn.apply(EMPTY_MULTIMAP_VIEW);
+ }
+ return result;
+ }
+
+ @Override
+ public <T> boolean contains(PCollectionView<T> view) {
+ return sideInputs.containsKey(view.getTagInternal());
+ }
+
+ @Override
+ public boolean isEmpty() {
+ return sideInputs.isEmpty();
+ }
+
+ public <T> Map<BoundedWindow, T> initializeBroadcastVariable(
+ Iterable<WindowedValue<?>> inputValues, PCollectionView<T> view) {
+
+ // first partition into windows
+ Map<BoundedWindow, List<WindowedValue<KV<?, ?>>>> partitionedElements = new HashMap<>();
+ for (WindowedValue<KV<?, ?>> value :
+ (Iterable<WindowedValue<KV<?, ?>>>) (Iterable) inputValues) {
+ for (BoundedWindow window : value.getWindows()) {
+ List<WindowedValue<KV<?, ?>>> windowedValues =
+ partitionedElements.computeIfAbsent(window, k -> new ArrayList<>());
+ windowedValues.add(value);
+ }
+ }
+
+ Map<BoundedWindow, T> resultMap = new HashMap<>();
+
+ for (Map.Entry<BoundedWindow, List<WindowedValue<KV<?, ?>>>> elements :
+ partitionedElements.entrySet()) {
+
+ ViewFn<Materializations.MultimapView, T> viewFn =
+ (ViewFn<Materializations.MultimapView, T>) view.getViewFn();
+ Coder keyCoder = ((KvCoder<?, ?>) view.getCoderInternal()).getKeyCoder();
+ resultMap.put(
+ elements.getKey(),
+ (T)
+ viewFn.apply(
+ InMemoryMultimapSideInputView.fromIterable(
+ keyCoder,
+ (Iterable)
+ elements.getValue().stream()
+ .map(WindowedValue::getValue)
+ .collect(Collectors.toList()))));
+ }
+
+ return resultMap;
+ }
+}
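
Note that get() decodes the whole broadcast and rebuilds the per-window map on every sideInput() call. A possible follow-up, not part of this commit, would be to memoize the materialized map per view tag:

    // Hypothetical cache, keyed by view tag id: one decode per executor instead of one per get().
    private final Map<String, Map<BoundedWindow, ?>> materialized = new HashMap<>();

    @SuppressWarnings("unchecked")
    private <T> Map<BoundedWindow, T> materializedFor(
        PCollectionView<T> view, List<WindowedValue<?>> decodedValues) {
      return (Map<BoundedWindow, T>)
          materialized.computeIfAbsent(
              view.getTagInternal().getId(),
              id -> initializeBroadcastVariable(decodedValues, view));
    }
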
diff --git a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java
new file mode 100644
index 0000000..6764dd8
--- /dev/null
+++ b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import org.apache.beam.sdk.coders.Coder;
+
+/** Serialization utility class. */
+public final class CoderHelpers {
+ private CoderHelpers() {}
+
+ /**
+ * Utility method for serializing an object using the specified coder.
+ *
+ * @param value Value to serialize.
+ * @param coder Coder to serialize with.
+ * @param <T> type of value that is serialized
+ * @return Byte array representing serialized object.
+ */
+ public static <T> byte[] toByteArray(T value, Coder<T> coder) {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ try {
+ coder.encode(value, baos, new Coder.Context(true));
+ } catch (IOException e) {
+ throw new IllegalStateException("Error encoding value: " + value, e);
+ }
+ return baos.toByteArray();
+ }
+
+ /**
+ * Utility method for deserializing a byte array using the specified coder.
+ *
+ * @param serialized bytearray to be deserialized.
+ * @param coder Coder to deserialize with.
+ * @param <T> Type of object to be returned.
+ * @return Deserialized object.
+ */
+ public static <T> T fromByteArray(byte[] serialized, Coder<T> coder) {
+ ByteArrayInputStream bais = new ByteArrayInputStream(serialized);
+ try {
+ return coder.decode(bais, new Coder.Context(true));
+ } catch (IOException e) {
+ throw new IllegalStateException("Error decoding bytes for coder: " + coder, e);
+ }
+ }
+}
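
The Coder.Context(true) argument selects the outer (whole-stream) context, which matches the one-value-per-byte-array usage here. A quick round-trip with a stock Beam coder, illustrative only:

    import org.apache.beam.sdk.coders.VarIntCoder;

    byte[] bytes = CoderHelpers.toByteArray(42, VarIntCoder.of());
    Integer roundTripped = CoderHelpers.fromByteArray(bytes, VarIntCoder.of());
    // roundTripped == 42; any IOException from the coder surfaces as IllegalStateException.
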
diff --git a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/SideInputBroadcast.java b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/SideInputBroadcast.java
new file mode 100644
index 0000000..a67a595
--- /dev/null
+++ b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/SideInputBroadcast.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.spark.broadcast.Broadcast;
+
+/** Holds side input values as Spark broadcast variables of coder-serialized bytes, keyed by view tag. */
+public class SideInputBroadcast implements Serializable {
+
+ private final Map<String, Broadcast<?>> bcast = new HashMap<>();
+ private final Map<String, Coder<?>> coder = new HashMap<>();
+
+ public SideInputBroadcast() {}
+
+ public void add(String key, Broadcast<?> bcast, Coder<?> coder) {
+ this.bcast.put(key, bcast);
+ this.coder.put(key, coder);
+ }
+
+ public Broadcast<?> getBroadcastValue(String key) {
+ return bcast.get(key);
+ }
+
+ public Coder<?> getCoder(String key) {
+ return coder.get(key);
+ }
+}
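
Population and lookup mirror createBroadcastSideInputs and SparkSideInputReader.get above; for example, using the same names as in this commit:

    SideInputBroadcast holder = new SideInputBroadcast();
    holder.add(view.getTagInternal().getId(), jsc.broadcast(codedValues), windowedValueCoder);
    // ...later, on an executor:
    String tagId = view.getTagInternal().getId();
    List<byte[]> bytes = (List<byte[]>) holder.getBroadcastValue(tagId).getValue();
    Coder<?> valueCoder = holder.getCoder(tagId);
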
diff --git a/runners/spark-structured-streaming/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java b/runners/spark-structured-streaming/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java
index c028dc0..b7a682d 100644
--- a/runners/spark-structured-streaming/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java
+++ b/runners/spark-structured-streaming/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java
@@ -19,6 +19,7 @@ package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
import java.io.Serializable;
import java.util.List;
+import java.util.Map;
import org.apache.beam.runners.spark.structuredstreaming.SparkPipelineOptions;
import org.apache.beam.runners.spark.structuredstreaming.SparkRunner;
@@ -29,6 +30,7 @@ import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.junit.BeforeClass;
@@ -89,24 +91,74 @@ public class ParDoTest implements Serializable {
}
@Test
- public void testSideInput() {
- PCollection<Integer> input = pipeline.apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
- final PCollectionView<List<Integer>> sideInput =
- input.apply(View.asList());
+ public void testSideInputAsList() {
+ PCollection<Integer> sideInput = pipeline.apply("Create sideInput", Create.of(101, 102, 103));
+ final PCollectionView<List<Integer>> sideInputView = sideInput.apply(View.asList());
+ PCollection<Integer> input =
+ pipeline.apply("Create input", Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
input.apply(
ParDo.of(
- new DoFn<Integer, Integer>() {
- @ProcessElement
- public void processElement(ProcessContext context) {
- List<Integer> list = context.sideInput(sideInput);
+ new DoFn<Integer, Integer>() {
+ @ProcessElement
+ public void processElement(ProcessContext context) {
+ List<Integer> sideInputValue = context.sideInput(sideInputView);
+ Integer val = context.element();
+ context.output(val);
+ System.out.println(
+ "ParDo1: val = " + val + ", sideInputValue = " + sideInputValue);
+ }
+ })
+ .withSideInputs(sideInputView));
- Integer val = context.element();
- context.output(val);
- System.out.println("ParDo1: val = " + val + ", sideInput = " + list);
- }
- })
- .withSideInputs(sideInput));
+ pipeline.run();
+ }
+
+ @Test
+ public void testSideInputAsSingleton() {
+ PCollection<Integer> sideInput = pipeline.apply("Create sideInput", Create.of(101));
+ final PCollectionView<Integer> sideInputView = sideInput.apply(View.asSingleton());
+
+ PCollection<Integer> input =
+ pipeline.apply("Create input", Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
+ input.apply(
+ ParDo.of(
+ new DoFn<Integer, Integer>() {
+ @ProcessElement
+ public void processElement(ProcessContext context) {
+ Integer sideInputValue = context.sideInput(sideInputView);
+ Integer val = context.element();
+ context.output(val);
+ System.out.println(
+ "ParDo1: val = " + val + ", sideInputValue = " + sideInputValue);
+ }
+ })
+ .withSideInputs(sideInputView));
+
+ pipeline.run();
+ }
+
+ @Test
+ public void testSideInputAsMap() {
+ PCollection<KV<String, Integer>> sideInput =
+ pipeline.apply("Create sideInput", Create.of(KV.of("key1", 1), KV.of("key2", 2)));
+ final PCollectionView<Map<String, Integer>> sideInputView = sideInput.apply(View.asMap());
+
+ PCollection<Integer> input =
+ pipeline.apply("Create input", Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
+ input.apply(
+ ParDo.of(
+ new DoFn<Integer, Integer>() {
+ @ProcessElement
+ public void processElement(ProcessContext context) {
+ Map<String, Integer> sideInputValue = context.sideInput(sideInputView);
+ Integer val = context.element();
+ context.output(val);
+ System.out.println(
+ "ParDo1: val = " + val + ", sideInputValue = " + sideInputValue);
+ }
+ })
+ .withSideInputs(sideInputView));
pipeline.run();
}
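
The three tests above exercise List, singleton, and Map views but only print the values they observe. A sketch of an assertion-based variant of testSideInputAsSingleton using PAssert from org.apache.beam.sdk.testing (assuming PAssert is usable on this runner):

    PCollection<Integer> output =
        input.apply(
            ParDo.of(
                    new DoFn<Integer, Integer>() {
                      @ProcessElement
                      public void processElement(ProcessContext context) {
                        // Add the singleton side input (101) to each element.
                        context.output(context.element() + context.sideInput(sideInputView));
                      }
                    })
                .withSideInputs(sideInputView));
    PAssert.that(output).containsInAnyOrder(102, 103, 104, 105, 106, 107, 108, 109, 110, 111);
    pipeline.run();
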