You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ke...@apache.org on 2016/12/14 03:04:20 UTC
[2/3] incubator-beam git commit: [BEAM-807] Replace OldDoFn with DoFn.
[BEAM-807] Replace OldDoFn with DoFn.
Add a custom AssignWindows implementation.
Setup and teardown DoFn.
Add implementation for GroupAlsoByWindow via flatMap.
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/4ffed3e0
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/4ffed3e0
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/4ffed3e0
Branch: refs/heads/master
Commit: 4ffed3e09a2f0ec3583098f6cfd53a2ddcc6f8c2
Parents: 2be9a15
Author: Sela <an...@paypal.com>
Authored: Sun Dec 11 14:32:49 2016 +0200
Committer: Sela <an...@paypal.com>
Committed: Tue Dec 13 10:05:18 2016 +0200
----------------------------------------------------------------------
.../beam/runners/spark/examples/WordCount.java | 6 +-
.../runners/spark/translation/DoFnFunction.java | 2 +-
.../translation/GroupCombineFunctions.java | 23 +-
.../spark/translation/MultiDoFnFunction.java | 2 +-
.../spark/translation/SparkAssignWindowFn.java | 69 ++++++
.../translation/SparkGroupAlsoByWindowFn.java | 214 +++++++++++++++++++
.../spark/translation/SparkProcessContext.java | 10 +
.../spark/translation/TransformTranslator.java | 31 +--
.../streaming/StreamingTransformTranslator.java | 35 ++-
.../streaming/utils/PAssertStreaming.java | 26 +--
10 files changed, 345 insertions(+), 73 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/examples/WordCount.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/examples/WordCount.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/examples/WordCount.java
index b2672b5..1252d12 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/examples/WordCount.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/examples/WordCount.java
@@ -25,8 +25,8 @@ import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Aggregator;
import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.MapElements;
-import org.apache.beam.sdk.transforms.OldDoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SimpleFunction;
@@ -44,11 +44,11 @@ public class WordCount {
* of-line. This DoFn tokenizes lines of text into individual words; we pass it to a ParDo in the
* pipeline.
*/
- static class ExtractWordsFn extends OldDoFn<String, String> {
+ static class ExtractWordsFn extends DoFn<String, String> {
private final Aggregator<Long, Long> emptyLines =
createAggregator("emptyLines", new Sum.SumLongFn());
- @Override
+ @ProcessElement
public void processElement(ProcessContext c) {
if (c.element().trim().isEmpty()) {
emptyLines.addValue(1L);
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
index 4c49a7f..6a641b5 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
@@ -93,7 +93,7 @@ public class DoFnFunction<InputT, OutputT>
windowingStrategy
);
- return new SparkProcessContext<>(doFnRunner, outputManager).processPartition(iter);
+ return new SparkProcessContext<>(doFn, doFnRunner, outputManager).processPartition(iter);
}
private class DoFnOutputManager
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
index 421b1b0..4875b0c 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
@@ -18,11 +18,9 @@
package org.apache.beam.runners.spark.translation;
-
import com.google.common.collect.Lists;
import java.util.Collections;
import java.util.Map;
-import org.apache.beam.runners.core.GroupAlsoByWindowsViaOutputBufferDoFn;
import org.apache.beam.runners.core.SystemReduceFn;
import org.apache.beam.runners.spark.aggregators.NamedAggregators;
import org.apache.beam.runners.spark.coders.CoderHelpers;
@@ -33,9 +31,7 @@ import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.transforms.CombineWithContext;
-import org.apache.beam.sdk.transforms.OldDoFn;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
-import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.WindowingStrategy;
import org.apache.beam.sdk.values.KV;
@@ -59,7 +55,7 @@ public class GroupCombineFunctions {
/**
* Apply {@link org.apache.beam.sdk.transforms.GroupByKey} to a Spark RDD.
*/
- public static <K, V, W extends BoundedWindow> JavaRDD<WindowedValue<KV<K,
+ public static <K, V, W extends BoundedWindow> JavaRDD<WindowedValue<KV<K,
Iterable<V>>>> groupByKey(JavaRDD<WindowedValue<KV<K, V>>> rdd,
Accumulator<NamedAggregators> accum,
KvCoder<K, V> coder,
@@ -86,15 +82,14 @@ public class GroupCombineFunctions {
.map(WindowingHelpers.<KV<K, Iterable<WindowedValue<V>>>>windowFunction());
//--- now group also by window.
- @SuppressWarnings("unchecked")
- WindowFn<Object, W> windowFn = (WindowFn<Object, W>) windowingStrategy.getWindowFn();
- // GroupAlsoByWindow current uses a dummy in-memory StateInternals
- OldDoFn<KV<K, Iterable<WindowedValue<V>>>, KV<K, Iterable<V>>> gabwDoFn =
- new GroupAlsoByWindowsViaOutputBufferDoFn<K, V, Iterable<V>, W>(
- windowingStrategy, new TranslationUtils.InMemoryStateInternalsFactory<K>(),
- SystemReduceFn.<K, V, W>buffering(valueCoder));
- return groupedByKey.mapPartitions(new DoFnFunction<>(accum, gabwDoFn, runtimeContext, null,
- windowFn));
+ // GroupAlsoByWindow currently uses a dummy in-memory StateInternals
+ return groupedByKey.flatMap(
+ new SparkGroupAlsoByWindowFn<>(
+ windowingStrategy,
+ new TranslationUtils.InMemoryStateInternalsFactory<K>(),
+ SystemReduceFn.<K, V, W>buffering(valueCoder),
+ runtimeContext,
+ accum));
}
/**
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
index 710c5cd..8a55369 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
@@ -102,7 +102,7 @@ public class MultiDoFnFunction<InputT, OutputT>
windowingStrategy
);
- return new SparkProcessContext<>(doFnRunner, outputManager).processPartition(iter);
+ return new SparkProcessContext<>(doFn, doFnRunner, outputManager).processPartition(iter);
}
private class DoFnOutputManager
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkAssignWindowFn.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkAssignWindowFn.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkAssignWindowFn.java
new file mode 100644
index 0000000..9d7ed7d
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkAssignWindowFn.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.translation;
+
+import com.google.common.collect.Iterables;
+import java.util.Collection;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.WindowFn;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.spark.api.java.function.Function;
+import org.joda.time.Instant;
+
+
+/**
+ * A Spark-runner implementation of {@link org.apache.beam.runners.core.AssignWindows}: a {@link Function} that re-emits each value in the windows chosen by the given {@link WindowFn}.
+ */
+public class SparkAssignWindowFn<T, W extends BoundedWindow>
+ implements Function<WindowedValue<T>, WindowedValue<T>> {
+
+ private WindowFn<? super T, W> fn;
+
+ public SparkAssignWindowFn(WindowFn<? super T, W> fn) {
+ this.fn = fn;
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public WindowedValue<T> call(WindowedValue<T> windowedValue) throws Exception {
+ final BoundedWindow boundedWindow = Iterables.getOnlyElement(windowedValue.getWindows());
+ final T element = windowedValue.getValue();
+ final Instant timestamp = windowedValue.getTimestamp();
+ Collection<W> windows =
+ ((WindowFn<T, W>) fn).assignWindows(
+ ((WindowFn<T, W>) fn).new AssignContext() {
+ @Override
+ public T element() {
+ return element;
+ }
+
+ @Override
+ public Instant timestamp() {
+ return timestamp;
+ }
+
+ @Override
+ public BoundedWindow window() {
+ return boundedWindow;
+ }
+ });
+ return WindowedValue.of(element, timestamp, windows, PaneInfo.NO_FIRING);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGroupAlsoByWindowFn.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGroupAlsoByWindowFn.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGroupAlsoByWindowFn.java
new file mode 100644
index 0000000..87d3f50
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGroupAlsoByWindowFn.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.translation;
+
+import com.google.common.collect.Iterables;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import org.apache.beam.runners.core.GroupAlsoByWindowsDoFn;
+import org.apache.beam.runners.core.OutputWindowedValue;
+import org.apache.beam.runners.core.ReduceFnRunner;
+import org.apache.beam.runners.core.SystemReduceFn;
+import org.apache.beam.runners.core.triggers.ExecutableTriggerStateMachine;
+import org.apache.beam.runners.core.triggers.TriggerStateMachines;
+import org.apache.beam.runners.spark.aggregators.NamedAggregators;
+import org.apache.beam.sdk.transforms.Aggregator;
+import org.apache.beam.sdk.transforms.Sum;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.util.SideInputReader;
+import org.apache.beam.sdk.util.TimerInternals;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowingStrategy;
+import org.apache.beam.sdk.util.state.InMemoryTimerInternals;
+import org.apache.beam.sdk.util.state.StateInternals;
+import org.apache.beam.sdk.util.state.StateInternalsFactory;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.Accumulator;
+import org.apache.spark.api.java.function.FlatMapFunction;
+import org.joda.time.Instant;
+
+
+
+/**
+ * A Spark {@link FlatMapFunction} based on {@link org.apache.beam.runners.core.GroupAlsoByWindowsViaOutputBufferDoFn}:
+ * for one key, it runs the grouped values through a {@link ReduceFnRunner} and returns the collected outputs.
+ */
+public class SparkGroupAlsoByWindowFn<K, InputT, W extends BoundedWindow>
+ implements FlatMapFunction<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>,
+ WindowedValue<KV<K, Iterable<InputT>>>> {
+
+ private final WindowingStrategy<?, W> windowingStrategy;
+ private final StateInternalsFactory<K> stateInternalsFactory;
+ private final SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> reduceFn;
+ private final SparkRuntimeContext runtimeContext;
+ private final Aggregator<Long, Long> droppedDueToClosedWindow;
+
+
+ public SparkGroupAlsoByWindowFn(
+ WindowingStrategy<?, W> windowingStrategy,
+ StateInternalsFactory<K> stateInternalsFactory,
+ SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> reduceFn,
+ SparkRuntimeContext runtimeContext,
+ Accumulator<NamedAggregators> accumulator) {
+ this.windowingStrategy = windowingStrategy;
+ this.stateInternalsFactory = stateInternalsFactory;
+ this.reduceFn = reduceFn;
+ this.runtimeContext = runtimeContext;
+
+ droppedDueToClosedWindow = runtimeContext.createAggregator(
+ accumulator,
+ GroupAlsoByWindowsDoFn.DROPPED_DUE_TO_CLOSED_WINDOW_COUNTER,
+ new Sum.SumLongFn());
+ }
+
+ @Override
+ public Iterable<WindowedValue<KV<K, Iterable<InputT>>>> call(
+ WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>> windowedValue) throws Exception {
+ K key = windowedValue.getValue().getKey();
+ Iterable<WindowedValue<InputT>> inputs = windowedValue.getValue().getValue();
+
+ //------ based on GroupAlsoByWindowsViaOutputBufferDoFn ------//
+
+ // Used in batch mode, where all the data for this key is known to be available. The timer
+ // manager from the context can't be used because it doesn't exist, so one is created here to
+ // emulate the watermark, relying on all data being present and in timestamp order.
+ InMemoryTimerInternals timerInternals = new InMemoryTimerInternals();
+ timerInternals.advanceProcessingTime(Instant.now());
+ timerInternals.advanceSynchronizedProcessingTime(Instant.now());
+ StateInternals<K> stateInternals = stateInternalsFactory.stateInternalsForKey(key);
+ GABWOutputWindowedValue<K, InputT> outputter = new GABWOutputWindowedValue<>();
+
+ ReduceFnRunner<K, InputT, Iterable<InputT>, W> reduceFnRunner =
+ new ReduceFnRunner<>(
+ key,
+ windowingStrategy,
+ ExecutableTriggerStateMachine.create(
+ TriggerStateMachines.stateMachineForTrigger(windowingStrategy.getTrigger())),
+ stateInternals,
+ timerInternals,
+ outputter,
+ new SideInputReader() {
+ @Override
+ public <T> T get(PCollectionView<T> view, BoundedWindow sideInputWindow) {
+ throw new UnsupportedOperationException(
+ "GroupAlsoByWindow must not have side inputs");
+ }
+
+ @Override
+ public <T> boolean contains(PCollectionView<T> view) {
+ throw new UnsupportedOperationException(
+ "GroupAlsoByWindow must not have side inputs");
+ }
+
+ @Override
+ public boolean isEmpty() {
+ throw new UnsupportedOperationException(
+ "GroupAlsoByWindow must not have side inputs");
+ }
+ },
+ droppedDueToClosedWindow,
+ reduceFn,
+ runtimeContext.getPipelineOptions());
+
+ Iterable<List<WindowedValue<InputT>>> chunks = Iterables.partition(inputs, 1000);
+ for (Iterable<WindowedValue<InputT>> chunk : chunks) {
+ // Process the chunk of elements.
+ reduceFnRunner.processElements(chunk);
+
+ // Then, since elements are sorted by their timestamp, advance the input watermark
+ // to the timestamp of the first element of this chunk.
+ timerInternals.advanceInputWatermark(chunk.iterator().next().getTimestamp());
+ // Advance the processing times.
+ timerInternals.advanceProcessingTime(Instant.now());
+ timerInternals.advanceSynchronizedProcessingTime(Instant.now());
+
+ // Fire all the eligible timers.
+ fireEligibleTimers(timerInternals, reduceFnRunner);
+
+ // Leave the output watermark undefined. Since there's no late data in batch mode
+ // there's really no need to track it as we do for streaming.
+ }
+
+ // Finish any pending windows by advancing the input watermark to TIMESTAMP_MAX_VALUE.
+ timerInternals.advanceInputWatermark(BoundedWindow.TIMESTAMP_MAX_VALUE);
+
+ // Finally, advance both processing times to TIMESTAMP_MAX_VALUE to fire any remaining timers.
+ timerInternals.advanceProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);
+ timerInternals.advanceSynchronizedProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);
+
+ fireEligibleTimers(timerInternals, reduceFnRunner);
+
+ reduceFnRunner.persist();
+
+ return outputter.getOutputs();
+ }
+
+ private void fireEligibleTimers(InMemoryTimerInternals timerInternals,
+ ReduceFnRunner<K, InputT, Iterable<InputT>, W> reduceFnRunner) throws Exception {
+ List<TimerInternals.TimerData> timers = new ArrayList<>();
+ while (true) {
+ TimerInternals.TimerData timer;
+ while ((timer = timerInternals.removeNextEventTimer()) != null) {
+ timers.add(timer);
+ }
+ while ((timer = timerInternals.removeNextProcessingTimer()) != null) {
+ timers.add(timer);
+ }
+ while ((timer = timerInternals.removeNextSynchronizedProcessingTimer()) != null) {
+ timers.add(timer);
+ }
+ if (timers.isEmpty()) {
+ break;
+ }
+ reduceFnRunner.onTimers(timers);
+ timers.clear();
+ }
+ }
+
+ private static class GABWOutputWindowedValue<K, V>
+ implements OutputWindowedValue<KV<K, Iterable<V>>> {
+ private final List<WindowedValue<KV<K, Iterable<V>>>> outputs = new ArrayList<>();
+
+ @Override
+ public void outputWindowedValue(
+ KV<K, Iterable<V>> output,
+ Instant timestamp,
+ Collection<? extends BoundedWindow> windows,
+ PaneInfo pane) {
+ outputs.add(WindowedValue.of(output, timestamp, windows, pane));
+ }
+
+ @Override
+ public <SideOutputT> void sideOutputWindowedValue(
+ TupleTag<SideOutputT> tag,
+ SideOutputT output,
+ Instant timestamp,
+ Collection<? extends BoundedWindow> windows, PaneInfo pane) {
+ throw new UnsupportedOperationException("GroupAlsoByWindow should not use side outputs.");
+ }
+
+ Iterable<WindowedValue<KV<K, Iterable<V>>>> getOutputs() {
+ return outputs;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
index efd8202..3a31cae 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
@@ -25,6 +25,8 @@ import java.util.Iterator;
import org.apache.beam.runners.core.DoFnRunner;
import org.apache.beam.runners.core.DoFnRunners.OutputManager;
import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.ExecutionContext.StepContext;
import org.apache.beam.sdk.util.TimerInternals;
@@ -38,13 +40,16 @@ import org.apache.beam.sdk.values.TupleTag;
*/
class SparkProcessContext<FnInputT, FnOutputT, OutputT> {
+ private final DoFn<FnInputT, FnOutputT> doFn;
private final DoFnRunner<FnInputT, FnOutputT> doFnRunner;
private final SparkOutputManager<OutputT> outputManager;
SparkProcessContext(
+ DoFn<FnInputT, FnOutputT> doFn,
DoFnRunner<FnInputT, FnOutputT> doFnRunner,
SparkOutputManager<OutputT> outputManager) {
+ this.doFn = doFn;
this.doFnRunner = doFnRunner;
this.outputManager = outputManager;
}
@@ -52,6 +57,9 @@ class SparkProcessContext<FnInputT, FnOutputT, OutputT> {
Iterable<OutputT> processPartition(
Iterator<WindowedValue<FnInputT>> partition) throws Exception {
+ // setup DoFn.
+ DoFnInvokers.invokerFor(doFn).invokeSetup();
+
// skip if partition is empty.
if (!partition.hasNext()) {
return Lists.newArrayList();
@@ -160,6 +168,8 @@ class SparkProcessContext<FnInputT, FnOutputT, OutputT> {
clearOutput();
calledFinish = true;
doFnRunner.finishBundle();
+ // teardown DoFn.
+ DoFnInvokers.invokerFor(doFn).invokeTeardown();
outputIterator = getOutputIterator();
continue; // try to consume outputIterator from start of loop
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index 964eb37..ac91892 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -32,7 +32,6 @@ import java.util.Map;
import org.apache.avro.mapred.AvroKey;
import org.apache.avro.mapreduce.AvroJob;
import org.apache.avro.mapreduce.AvroKeyInputFormat;
-import org.apache.beam.runners.core.AssignWindowsDoFn;
import org.apache.beam.runners.spark.aggregators.NamedAggregators;
import org.apache.beam.runners.spark.aggregators.SparkAggregators;
import org.apache.beam.runners.spark.coders.CoderHelpers;
@@ -54,13 +53,11 @@ import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.transforms.GroupByKey;
-import org.apache.beam.sdk.transforms.OldDoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.Window;
-import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.util.CombineFnUtil;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.WindowingStrategy;
@@ -235,16 +232,15 @@ public final class TransformTranslator {
@SuppressWarnings("unchecked")
JavaRDD<WindowedValue<InputT>> inRDD =
((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD();
- @SuppressWarnings("unchecked")
- final WindowFn<Object, ?> windowFn =
- (WindowFn<Object, ?>) context.getInput(transform).getWindowingStrategy().getWindowFn();
+ WindowingStrategy<?, ?> windowingStrategy =
+ context.getInput(transform).getWindowingStrategy();
Accumulator<NamedAggregators> accum =
SparkAggregators.getNamedAggregators(context.getSparkContext());
Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> sideInputs =
TranslationUtils.getSideInputs(transform.getSideInputs(), context);
context.putDataset(transform,
- new BoundedDataset<>(inRDD.mapPartitions(new DoFnFunction<>(accum, transform.getFn(),
- context.getRuntimeContext(), sideInputs, windowFn))));
+ new BoundedDataset<>(inRDD.mapPartitions(new DoFnFunction<>(accum, doFn,
+ context.getRuntimeContext(), sideInputs, windowingStrategy))));
}
};
}
@@ -259,16 +255,15 @@ public final class TransformTranslator {
@SuppressWarnings("unchecked")
JavaRDD<WindowedValue<InputT>> inRDD =
((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD();
- @SuppressWarnings("unchecked")
- final WindowFn<Object, ?> windowFn =
- (WindowFn<Object, ?>) context.getInput(transform).getWindowingStrategy().getWindowFn();
+ WindowingStrategy<?, ?> windowingStrategy =
+ context.getInput(transform).getWindowingStrategy();
Accumulator<NamedAggregators> accum =
SparkAggregators.getNamedAggregators(context.getSparkContext());
JavaPairRDD<TupleTag<?>, WindowedValue<?>> all = inRDD
.mapPartitionsToPair(
- new MultiDoFnFunction<>(accum, transform.getFn(), context.getRuntimeContext(),
+ new MultiDoFnFunction<>(accum, doFn, context.getRuntimeContext(),
transform.getMainOutputTag(), TranslationUtils.getSideInputs(
- transform.getSideInputs(), context), windowFn)).cache();
+ transform.getSideInputs(), context), windowingStrategy)).cache();
PCollectionTuple pct = context.getOutput(transform);
for (Map.Entry<TupleTag<?>, PCollection<?>> e : pct.getAll().entrySet()) {
@SuppressWarnings("unchecked")
@@ -508,14 +503,8 @@ public final class TransformTranslator {
if (TranslationUtils.skipAssignWindows(transform, context)) {
context.putDataset(transform, new BoundedDataset<>(inRDD));
} else {
- @SuppressWarnings("unchecked")
- WindowFn<? super T, W> windowFn = (WindowFn<? super T, W>) transform.getWindowFn();
- OldDoFn<T, T> addWindowsDoFn = new AssignWindowsDoFn<>(windowFn);
- Accumulator<NamedAggregators> accum =
- SparkAggregators.getNamedAggregators(context.getSparkContext());
- context.putDataset(transform,
- new BoundedDataset<>(inRDD.mapPartitions(new DoFnFunction<>(accum, addWindowsDoFn,
- context.getRuntimeContext(), null, null))));
+ context.putDataset(transform, new BoundedDataset<>(
+ inRDD.map(new SparkAssignWindowFn<>(transform.getWindowFn()))));
}
}
};
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index 00df7d4..27204ed 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -24,7 +24,6 @@ import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
-import org.apache.beam.runners.core.AssignWindowsDoFn;
import org.apache.beam.runners.spark.aggregators.NamedAggregators;
import org.apache.beam.runners.spark.aggregators.SparkAggregators;
import org.apache.beam.runners.spark.io.ConsoleIO;
@@ -36,6 +35,7 @@ import org.apache.beam.runners.spark.translation.DoFnFunction;
import org.apache.beam.runners.spark.translation.EvaluationContext;
import org.apache.beam.runners.spark.translation.GroupCombineFunctions;
import org.apache.beam.runners.spark.translation.MultiDoFnFunction;
+import org.apache.beam.runners.spark.translation.SparkAssignWindowFn;
import org.apache.beam.runners.spark.translation.SparkKeyedCombineFn;
import org.apache.beam.runners.spark.translation.SparkPipelineTranslator;
import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
@@ -51,7 +51,6 @@ import org.apache.beam.sdk.transforms.CombineWithContext;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.transforms.GroupByKey;
-import org.apache.beam.sdk.transforms.OldDoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -163,7 +162,7 @@ final class StreamingTransformTranslator {
private static <T, W extends BoundedWindow> TransformEvaluator<Window.Bound<T>> window() {
return new TransformEvaluator<Window.Bound<T>>() {
@Override
- public void evaluate(Window.Bound<T> transform, EvaluationContext context) {
+ public void evaluate(final Window.Bound<T> transform, EvaluationContext context) {
@SuppressWarnings("unchecked")
WindowFn<? super T, W> windowFn = (WindowFn<? super T, W>) transform.getWindowFn();
@SuppressWarnings("unchecked")
@@ -189,16 +188,11 @@ final class StreamingTransformTranslator {
if (TranslationUtils.skipAssignWindows(transform, context)) {
context.putDataset(transform, new UnboundedDataset<>(windowedDStream));
} else {
- final OldDoFn<T, T> addWindowsDoFn = new AssignWindowsDoFn<>(windowFn);
- final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
JavaDStream<WindowedValue<T>> outStream = windowedDStream.transform(
new Function<JavaRDD<WindowedValue<T>>, JavaRDD<WindowedValue<T>>>() {
@Override
public JavaRDD<WindowedValue<T>> call(JavaRDD<WindowedValue<T>> rdd) throws Exception {
- final Accumulator<NamedAggregators> accum =
- SparkAggregators.getNamedAggregators(new JavaSparkContext(rdd.context()));
- return rdd.mapPartitions(
- new DoFnFunction<>(accum, addWindowsDoFn, runtimeContext, null, null));
+ return rdd.map(new SparkAssignWindowFn<>(transform.getWindowFn()));
}
});
context.putDataset(transform, new UnboundedDataset<>(outStream));
@@ -350,13 +344,13 @@ final class StreamingTransformTranslator {
@Override
public void evaluate(final ParDo.Bound<InputT, OutputT> transform,
final EvaluationContext context) {
- DoFn<InputT, OutputT> doFn = transform.getNewFn();
+ final DoFn<InputT, OutputT> doFn = transform.getNewFn();
rejectStateAndTimers(doFn);
final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> sideInputs =
TranslationUtils.getSideInputs(transform.getSideInputs(), context);
- final WindowFn<Object, ?> windowFn =
- (WindowFn<Object, ?>) context.getInput(transform).getWindowingStrategy().getWindowFn();
+ final WindowingStrategy<?, ?> windowingStrategy =
+ context.getInput(transform).getWindowingStrategy();
JavaDStream<WindowedValue<InputT>> dStream =
((UnboundedDataset<InputT>) context.borrowDataset(transform)).getDStream();
@@ -369,7 +363,7 @@ final class StreamingTransformTranslator {
final Accumulator<NamedAggregators> accum =
SparkAggregators.getNamedAggregators(new JavaSparkContext(rdd.context()));
return rdd.mapPartitions(
- new DoFnFunction<>(accum, transform.getFn(), runtimeContext, sideInputs, windowFn));
+ new DoFnFunction<>(accum, doFn, runtimeContext, sideInputs, windowingStrategy));
}
});
@@ -384,14 +378,13 @@ final class StreamingTransformTranslator {
@Override
public void evaluate(final ParDo.BoundMulti<InputT, OutputT> transform,
final EvaluationContext context) {
- DoFn<InputT, OutputT> doFn = transform.getNewFn();
+ final DoFn<InputT, OutputT> doFn = transform.getNewFn();
rejectStateAndTimers(doFn);
final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> sideInputs =
TranslationUtils.getSideInputs(transform.getSideInputs(), context);
- @SuppressWarnings("unchecked")
- final WindowFn<Object, ?> windowFn =
- (WindowFn<Object, ?>) context.getInput(transform).getWindowingStrategy().getWindowFn();
+ final WindowingStrategy<?, ?> windowingStrategy =
+ context.getInput(transform).getWindowingStrategy();
@SuppressWarnings("unchecked")
JavaDStream<WindowedValue<InputT>> dStream =
((UnboundedDataset<InputT>) context.borrowDataset(transform)).getDStream();
@@ -403,8 +396,8 @@ final class StreamingTransformTranslator {
JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
final Accumulator<NamedAggregators> accum =
SparkAggregators.getNamedAggregators(new JavaSparkContext(rdd.context()));
- return rdd.mapPartitionsToPair(new MultiDoFnFunction<>(accum, transform.getFn(),
- runtimeContext, transform.getMainOutputTag(), sideInputs, windowFn));
+ return rdd.mapPartitionsToPair(new MultiDoFnFunction<>(accum, doFn,
+ runtimeContext, transform.getMainOutputTag(), sideInputs, windowingStrategy));
}
}).cache();
PCollectionTuple pct = context.getOutput(transform);
@@ -423,8 +416,8 @@ final class StreamingTransformTranslator {
};
}
- private static final Map<Class<? extends PTransform>, TransformEvaluator<?>> EVALUATORS = Maps
- .newHashMap();
+ private static final Map<Class<? extends PTransform>, TransformEvaluator<?>> EVALUATORS =
+ Maps.newHashMap();
static {
EVALUATORS.put(Read.Unbounded.class, readUnbounded());
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/PAssertStreaming.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/PAssertStreaming.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/PAssertStreaming.java
index 471ec92..0284b3d 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/PAssertStreaming.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/PAssertStreaming.java
@@ -27,8 +27,8 @@ import org.apache.beam.runners.spark.SparkPipelineResult;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.transforms.Aggregator;
+import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
-import org.apache.beam.sdk.transforms.OldDoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.transforms.Values;
@@ -55,11 +55,12 @@ public final class PAssertStreaming implements Serializable {
* Note that it is oblivious to windowing, so the assertion will apply indiscriminately to all
* windows.
*/
- public static <T> SparkPipelineResult runAndAssertContents(Pipeline p,
- PCollection<T> actual,
- T[] expected,
- Duration timeout,
- boolean stopGracefully) {
+ public static <T> SparkPipelineResult runAndAssertContents(
+ Pipeline p,
+ PCollection<T> actual,
+ T[] expected,
+ Duration timeout,
+ boolean stopGracefully) {
// Because PAssert does not support non-global windowing, but all our data is in one window,
// we set up the assertion directly.
actual
@@ -86,14 +87,15 @@ public final class PAssertStreaming implements Serializable {
* Default to stop gracefully so that tests will finish processing even if slower for reasons
* such as a slow runtime environment.
*/
- public static <T> SparkPipelineResult runAndAssertContents(Pipeline p,
- PCollection<T> actual,
- T[] expected,
- Duration timeout) {
+ public static <T> SparkPipelineResult runAndAssertContents(
+ Pipeline p,
+ PCollection<T> actual,
+ T[] expected,
+ Duration timeout) {
return runAndAssertContents(p, actual, expected, timeout, true);
}
- private static class AssertDoFn<T> extends OldDoFn<Iterable<T>, Void> {
+ private static class AssertDoFn<T> extends DoFn<Iterable<T>, Void> {
private final Aggregator<Integer, Integer> success =
createAggregator(PAssert.SUCCESS_COUNTER, new Sum.SumIntegerFn());
private final Aggregator<Integer, Integer> failure =
@@ -104,7 +106,7 @@ public final class PAssertStreaming implements Serializable {
this.expected = expected;
}
- @Override
+ @ProcessElement
public void processElement(ProcessContext c) throws Exception {
try {
assertThat(c.element(), containsInAnyOrder(expected));