Posted to commits@beam.apache.org by ke...@apache.org on 2016/12/14 03:04:19 UTC

[1/3] incubator-beam git commit: [BEAM-757] Use DoFnRunner in the implementation of DoFn via FlatMapFunction.

Repository: incubator-beam
Updated Branches:
  refs/heads/master ce3aa657a -> 44b4eba51


[BEAM-757] Use DoFnRunner in the implementation of DoFn via FlatMapFunction.

Implement an AggregatorFactory for the Spark runner, to be used by the DoFnRunner.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/2be9a154
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/2be9a154
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/2be9a154

Branch: refs/heads/master
Commit: 2be9a15412faca1ae74873cf46e39abe9c4f921d
Parents: e776d1d
Author: Sela <an...@paypal.com>
Authored: Sun Dec 11 14:30:24 2016 +0200
Committer: Sela <an...@paypal.com>
Committed: Tue Dec 13 10:04:44 2016 +0200

----------------------------------------------------------------------
 .../spark/aggregators/SparkAggregators.java     |  30 +-
 .../runners/spark/translation/DoFnFunction.java | 110 +++---
 .../spark/translation/MultiDoFnFunction.java    | 135 +++----
 .../spark/translation/SparkProcessContext.java  | 375 ++++---------------
 4 files changed, 231 insertions(+), 419 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2be9a154/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/SparkAggregators.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/SparkAggregators.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/SparkAggregators.java
index 1b06691..657264f 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/SparkAggregators.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/SparkAggregators.java
@@ -21,13 +21,16 @@ package org.apache.beam.runners.spark.aggregators;
 import com.google.common.collect.ImmutableList;
 import java.util.Collection;
 import java.util.Map;
+import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
 import org.apache.beam.sdk.AggregatorValues;
 import org.apache.beam.sdk.transforms.Aggregator;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.util.ExecutionContext;
 import org.apache.spark.Accumulator;
 import org.apache.spark.api.java.JavaSparkContext;
 
 /**
- * A utility class for retrieving aggregator values.
+ * A utility class for handling Beam {@link Aggregator}s.
  */
 public class SparkAggregators {
 
@@ -94,4 +97,29 @@ public class SparkAggregators {
                               final JavaSparkContext javaSparkContext) {
     return valueOf(getNamedAggregators(javaSparkContext), name, typeClass);
   }
+
+  /**
+   * An implementation of {@link Aggregator.AggregatorFactory} for the SparkRunner.
+   */
+  public static class Factory implements Aggregator.AggregatorFactory {
+
+    private final SparkRuntimeContext runtimeContext;
+    private final Accumulator<NamedAggregators> accumulator;
+
+    public Factory(SparkRuntimeContext runtimeContext, Accumulator<NamedAggregators> accumulator) {
+      this.runtimeContext = runtimeContext;
+      this.accumulator = accumulator;
+    }
+
+    @Override
+    public <InputT, AccumT, OutputT> Aggregator<InputT, OutputT> createAggregatorForDoFn(
+        Class<?> fnClass,
+        ExecutionContext.StepContext stepContext,
+        String aggregatorName,
+        Combine.CombineFn<InputT, AccumT, OutputT> combine) {
+
+      return runtimeContext.createAggregator(accumulator, aggregatorName, combine);
+    }
+  }
+
 }
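
For context, a hedged sketch of how the new Factory can be exercised. Only SparkAggregators.Factory, NoOpStepContext, and the createAggregatorForDoFn signature come from this commit; MyDoFn, runtimeContext, accumulator, and the aggregator name are illustrative:

    // Illustrative wiring only. Note that the Spark implementation ignores
    // fnClass and stepContext and simply delegates to SparkRuntimeContext.
    Aggregator.AggregatorFactory factory =
        new SparkAggregators.Factory(runtimeContext, accumulator);
    Aggregator<Long, Long> emptyLines =
        factory.createAggregatorForDoFn(
            MyDoFn.class,
            new SparkProcessContext.NoOpStepContext(),
            "emptyLines",
            new Sum.SumLongFn());
    emptyLines.addValue(1L); // backed by the NamedAggregators Spark Accumulator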

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2be9a154/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
index f4be121..4c49a7f 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
@@ -18,14 +18,18 @@
 
 package org.apache.beam.runners.spark.translation;
 
+import java.util.Collections;
 import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
+import org.apache.beam.runners.core.DoFnRunner;
+import org.apache.beam.runners.core.DoFnRunners;
 import org.apache.beam.runners.spark.aggregators.NamedAggregators;
+import org.apache.beam.runners.spark.aggregators.SparkAggregators;
 import org.apache.beam.runners.spark.util.BroadcastHelper;
-import org.apache.beam.sdk.transforms.OldDoFn;
-import org.apache.beam.sdk.transforms.windowing.WindowFn;
+import org.apache.beam.runners.spark.util.SparkSideInputReader;
+import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.WindowingStrategy;
 import org.apache.beam.sdk.values.KV;
@@ -37,80 +41,80 @@ import org.apache.spark.api.java.function.FlatMapFunction;
 /**
  * Beam's Do functions correspond to Spark's FlatMap functions.
  *
- * @param <InputT> Input element type.
+ * @param <InputT>  Input element type.
  * @param <OutputT> Output element type.
  */
 public class DoFnFunction<InputT, OutputT>
     implements FlatMapFunction<Iterator<WindowedValue<InputT>>, WindowedValue<OutputT>> {
-  private final Accumulator<NamedAggregators> accum;
-  private final OldDoFn<InputT, OutputT> mFunction;
-  private final SparkRuntimeContext mRuntimeContext;
-  private final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> mSideInputs;
-  private final WindowFn<Object, ?> windowFn;
+
+  private final Accumulator<NamedAggregators> accumulator;
+  private final DoFn<InputT, OutputT> doFn;
+  private final SparkRuntimeContext runtimeContext;
+  private final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> sideInputs;
+  private final WindowingStrategy<?, ?> windowingStrategy;
 
   /**
-   * @param accum             The Spark Accumulator that handles the Beam Aggregators.
-   * @param fn                DoFunction to be wrapped.
-   * @param runtime           Runtime to apply function in.
-   * @param sideInputs        Side inputs used in DoFunction.
-   * @param windowFn          Input {@link WindowFn}.
+   * @param accumulator       The Spark {@link Accumulator} that backs the Beam Aggregators.
+   * @param doFn              The {@link DoFn} to be wrapped.
+   * @param runtimeContext    The {@link SparkRuntimeContext}.
+   * @param sideInputs        Side inputs used in this {@link DoFn}.
+   * @param windowingStrategy Input {@link WindowingStrategy}.
    */
-  public DoFnFunction(Accumulator<NamedAggregators> accum,
-                      OldDoFn<InputT, OutputT> fn,
-                      SparkRuntimeContext runtime,
-                      Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> sideInputs,
-                      WindowFn<Object, ?> windowFn) {
-    this.accum = accum;
-    this.mFunction = fn;
-    this.mRuntimeContext = runtime;
-    this.mSideInputs = sideInputs;
-    this.windowFn = windowFn;
+  public DoFnFunction(
+      Accumulator<NamedAggregators> accumulator,
+      DoFn<InputT, OutputT> doFn,
+      SparkRuntimeContext runtimeContext,
+      Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> sideInputs,
+      WindowingStrategy<?, ?> windowingStrategy) {
+
+    this.accumulator = accumulator;
+    this.doFn = doFn;
+    this.runtimeContext = runtimeContext;
+    this.sideInputs = sideInputs;
+    this.windowingStrategy = windowingStrategy;
   }
 
 
   @Override
-  public Iterable<WindowedValue<OutputT>> call(Iterator<WindowedValue<InputT>> iter) throws
-      Exception {
-    return new ProcCtxt(mFunction, mRuntimeContext, mSideInputs, windowFn)
-        .callWithCtxt(iter);
+  public Iterable<WindowedValue<OutputT>> call(
+      Iterator<WindowedValue<InputT>> iter) throws Exception {
+
+    DoFnOutputManager outputManager = new DoFnOutputManager();
+    DoFnRunner<InputT, OutputT> doFnRunner =
+        DoFnRunners.createDefault(
+            runtimeContext.getPipelineOptions(),
+            doFn,
+            new SparkSideInputReader(sideInputs),
+            outputManager,
+            new TupleTag<OutputT>() {},
+            Collections.<TupleTag<?>>emptyList(),
+            new SparkProcessContext.NoOpStepContext(),
+            new SparkAggregators.Factory(runtimeContext, accumulator),
+            windowingStrategy
+        );
+
+    return new SparkProcessContext<>(doFnRunner, outputManager).processPartition(iter);
   }
 
-  private class ProcCtxt extends SparkProcessContext<InputT, OutputT, WindowedValue<OutputT>> {
+  private class DoFnOutputManager
+      implements SparkProcessContext.SparkOutputManager<WindowedValue<OutputT>> {
 
     private final List<WindowedValue<OutputT>> outputs = new LinkedList<>();
 
-    ProcCtxt(OldDoFn<InputT, OutputT> fn,
-             SparkRuntimeContext runtimeContext,
-             Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> sideInputs,
-             WindowFn<Object, ?> windowFn) {
-      super(fn, runtimeContext, sideInputs, windowFn);
-    }
-
     @Override
-    protected synchronized void outputWindowedValue(WindowedValue<OutputT> o) {
-      outputs.add(o);
-    }
-
-    @Override
-    protected <T> void sideOutputWindowedValue(TupleTag<T> tag, WindowedValue<T> output) {
-      throw new UnsupportedOperationException(
-          "sideOutput is an unsupported operation for doFunctions, use a "
-              + "MultiDoFunction instead.");
-    }
-
-    @Override
-    public Accumulator<NamedAggregators> getAccumulator() {
-      return accum;
+    public void clear() {
+      outputs.clear();
     }
 
     @Override
-    protected void clearOutput() {
-      outputs.clear();
+    public Iterator<WindowedValue<OutputT>> iterator() {
+      return outputs.iterator();
     }
 
     @Override
-    protected Iterator<WindowedValue<OutputT>> getOutputIterator() {
-      return outputs.iterator();
+    @SuppressWarnings("unchecked")
+    public synchronized <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
+      outputs.add((WindowedValue<OutputT>) output);
     }
   }
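
A hedged usage sketch: the translator applies this function to each partition of an RDD of WindowedValues via mapPartitions. The RDD and all constructor arguments below are illustrative:

    JavaRDD<WindowedValue<String>> output =
        input.mapPartitions(
            new DoFnFunction<>(
                accumulator, doFn, runtimeContext, sideInputs, windowingStrategy));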
 

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2be9a154/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
index 8175beb..710c5cd 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
@@ -22,20 +22,26 @@ import com.google.common.base.Function;
 import com.google.common.collect.Iterators;
 import com.google.common.collect.LinkedListMultimap;
 import com.google.common.collect.Multimap;
+import java.util.Collections;
 import java.util.Iterator;
 import java.util.Map;
+import org.apache.beam.runners.core.DoFnRunner;
+import org.apache.beam.runners.core.DoFnRunners;
 import org.apache.beam.runners.spark.aggregators.NamedAggregators;
+import org.apache.beam.runners.spark.aggregators.SparkAggregators;
 import org.apache.beam.runners.spark.util.BroadcastHelper;
-import org.apache.beam.sdk.transforms.OldDoFn;
-import org.apache.beam.sdk.transforms.windowing.WindowFn;
+import org.apache.beam.runners.spark.util.SparkSideInputReader;
+import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.WindowingStrategy;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.spark.Accumulator;
 import org.apache.spark.api.java.function.PairFlatMapFunction;
+
 import scala.Tuple2;
 
+
 /**
  * DoFunctions ignore side outputs. MultiDoFunctions deal with side outputs by enriching the
  * underlying data with multiple TupleTags.
@@ -44,89 +50,90 @@ import scala.Tuple2;
  * @param <OutputT> Output type for DoFunction.
  */
 public class MultiDoFnFunction<InputT, OutputT>
-    implements PairFlatMapFunction<Iterator<WindowedValue<InputT>>, TupleTag<?>,
-        WindowedValue<?>> {
-  private final Accumulator<NamedAggregators> accum;
-  private final OldDoFn<InputT, OutputT> mFunction;
-  private final SparkRuntimeContext mRuntimeContext;
-  private final TupleTag<OutputT> mMainOutputTag;
-  private final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> mSideInputs;
-  private final WindowFn<Object, ?> windowFn;
+    implements PairFlatMapFunction<Iterator<WindowedValue<InputT>>, TupleTag<?>, WindowedValue<?>> {
+
+  private final Accumulator<NamedAggregators> accumulator;
+  private final DoFn<InputT, OutputT> doFn;
+  private final SparkRuntimeContext runtimeContext;
+  private final TupleTag<OutputT> mainOutputTag;
+  private final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> sideInputs;
+  private final WindowingStrategy<?, ?> windowingStrategy;
 
   /**
-   * @param accum             The Spark Accumulator that handles the Beam Aggregators.
-   * @param fn                DoFunction to be wrapped.
-   * @param runtimeContext    Runtime to apply function in.
+   * @param accumulator       The Spark {@link Accumulator} that backs the Beam Aggregators.
+   * @param doFn              The {@link DoFn} to be wrapped.
+   * @param runtimeContext    The {@link SparkRuntimeContext}.
    * @param mainOutputTag     The main output {@link TupleTag}.
-   * @param sideInputs        Side inputs used in DoFunction.
-   * @param windowFn          Input {@link WindowFn}.
+   * @param sideInputs        Side inputs used in this {@link DoFn}.
+   * @param windowingStrategy Input {@link WindowingStrategy}.
    */
-  public MultiDoFnFunction(Accumulator<NamedAggregators> accum,
-                           OldDoFn<InputT, OutputT> fn,
-                           SparkRuntimeContext runtimeContext,
-                           TupleTag<OutputT> mainOutputTag,
-                           Map<TupleTag<?>, KV<WindowingStrategy<?, ?>,
-                               BroadcastHelper<?>>> sideInputs,
-                           WindowFn<Object, ?> windowFn) {
-    this.accum = accum;
-    this.mFunction = fn;
-    this.mRuntimeContext = runtimeContext;
-    this.mMainOutputTag = mainOutputTag;
-    this.mSideInputs = sideInputs;
-    this.windowFn = windowFn;
+  public MultiDoFnFunction(
+      Accumulator<NamedAggregators> accumulator,
+      DoFn<InputT, OutputT> doFn,
+      SparkRuntimeContext runtimeContext,
+      TupleTag<OutputT> mainOutputTag,
+      Map<TupleTag<?>, KV<WindowingStrategy<?, ?>,
+      BroadcastHelper<?>>> sideInputs,
+      WindowingStrategy<?, ?> windowingStrategy) {
+
+    this.accumulator = accumulator;
+    this.doFn = doFn;
+    this.runtimeContext = runtimeContext;
+    this.mainOutputTag = mainOutputTag;
+    this.sideInputs = sideInputs;
+    this.windowingStrategy = windowingStrategy;
   }
 
   @Override
-  public Iterable<Tuple2<TupleTag<?>, WindowedValue<?>>>
-      call(Iterator<WindowedValue<InputT>> iter) throws Exception {
-    return new ProcCtxt(mFunction, mRuntimeContext, mSideInputs, windowFn)
-        .callWithCtxt(iter);
-  }
+  public Iterable<Tuple2<TupleTag<?>, WindowedValue<?>>> call(
+      Iterator<WindowedValue<InputT>> iter) throws Exception {
 
-  private class ProcCtxt
-      extends SparkProcessContext<InputT, OutputT, Tuple2<TupleTag<?>, WindowedValue<?>>> {
+    DoFnOutputManager outputManager = new DoFnOutputManager();
+    DoFnRunner<InputT, OutputT> doFnRunner =
+        DoFnRunners.createDefault(
+            runtimeContext.getPipelineOptions(),
+            doFn,
+            new SparkSideInputReader(sideInputs),
+            outputManager,
+            mainOutputTag,
+            Collections.<TupleTag<?>>emptyList(),
+            new SparkProcessContext.NoOpStepContext(),
+            new SparkAggregators.Factory(runtimeContext, accumulator),
+            windowingStrategy
+        );
 
-    private final Multimap<TupleTag<?>, WindowedValue<?>> outputs = LinkedListMultimap.create();
+    return new SparkProcessContext<>(doFnRunner, outputManager).processPartition(iter);
+  }
 
-    ProcCtxt(OldDoFn<InputT, OutputT> fn,
-             SparkRuntimeContext runtimeContext,
-             Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> sideInputs,
-             WindowFn<Object, ?> windowFn) {
-      super(fn, runtimeContext, sideInputs, windowFn);
-    }
+  private class DoFnOutputManager
+      implements SparkProcessContext.SparkOutputManager<Tuple2<TupleTag<?>, WindowedValue<?>>> {
 
-    @Override
-    protected synchronized void outputWindowedValue(WindowedValue<OutputT> o) {
-      outputs.put(mMainOutputTag, o);
-    }
+    private final Multimap<TupleTag<?>, WindowedValue<?>> outputs = LinkedListMultimap.create();
 
     @Override
-    protected <T> void sideOutputWindowedValue(TupleTag<T> tag, WindowedValue<T> output) {
-      outputs.put(tag, output);
-    }
-
-    @Override
-    public Accumulator<NamedAggregators> getAccumulator() {
-      return accum;
+    public void clear() {
+      outputs.clear();
     }
 
     @Override
-    protected void clearOutput() {
-      outputs.clear();
+    public Iterator<Tuple2<TupleTag<?>, WindowedValue<?>>> iterator() {
+      Iterator<Map.Entry<TupleTag<?>, WindowedValue<?>>> entryIter = outputs.entries().iterator();
+      return Iterators.transform(entryIter, this.<TupleTag<?>, WindowedValue<?>>entryToTupleFn());
     }
 
-    @Override
-    protected Iterator<Tuple2<TupleTag<?>, WindowedValue<?>>> getOutputIterator() {
-      return Iterators.transform(outputs.entries().iterator(),
-          new Function<Map.Entry<TupleTag<?>, WindowedValue<?>>,
-              Tuple2<TupleTag<?>, WindowedValue<?>>>() {
+    private <K, V> Function<Map.Entry<K, V>, Tuple2<K, V>> entryToTupleFn() {
+      return new Function<Map.Entry<K, V>, Tuple2<K, V>>() {
         @Override
-        public Tuple2<TupleTag<?>, WindowedValue<?>> apply(Map.Entry<TupleTag<?>,
-            WindowedValue<?>> input) {
-          return new Tuple2<TupleTag<?>, WindowedValue<?>>(input.getKey(), input.getValue());
+        public Tuple2<K, V> apply(Map.Entry<K, V> en) {
+          return new Tuple2<>(en.getKey(), en.getValue());
         }
-      });
+      };
     }
 
+    @Override
+    @SuppressWarnings("unchecked")
+    public synchronized <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
+      outputs.put(tag, output);
+    }
   }
 }
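
A hedged sketch of how a PairFlatMapFunction like this one is applied and its outputs split back out by tag (all names below are illustrative; Function here is Spark's org.apache.spark.api.java.function.Function):

    JavaPairRDD<TupleTag<?>, WindowedValue<?>> all =
        input.mapPartitionsToPair(
            new MultiDoFnFunction<>(
                accumulator, doFn, runtimeContext, mainOutputTag,
                sideInputs, windowingStrategy));
    // Keep only the main output; each side output is selected the same way by its tag.
    JavaPairRDD<TupleTag<?>, WindowedValue<?>> mainOutput =
        all.filter(new Function<Tuple2<TupleTag<?>, WindowedValue<?>>, Boolean>() {
          @Override
          public Boolean call(Tuple2<TupleTag<?>, WindowedValue<?>> kv) {
            return kv._1().equals(mainOutputTag);
          }
        });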

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2be9a154/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
index bb0ec2f..efd8202 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
@@ -18,304 +18,129 @@
 
 package org.apache.beam.runners.spark.translation;
 
-import static com.google.common.base.Preconditions.checkState;
-
 import com.google.common.collect.AbstractIterator;
-import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
-import java.util.Collection;
+import java.io.IOException;
 import java.util.Iterator;
-import java.util.Map;
-import org.apache.beam.runners.spark.aggregators.NamedAggregators;
-import org.apache.beam.runners.spark.util.BroadcastHelper;
-import org.apache.beam.runners.spark.util.SparkSideInputReader;
-import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.transforms.Aggregator;
-import org.apache.beam.sdk.transforms.Combine;
-import org.apache.beam.sdk.transforms.OldDoFn;
-import org.apache.beam.sdk.transforms.OldDoFn.RequiresWindowAccess;
+import org.apache.beam.runners.core.DoFnRunner;
+import org.apache.beam.runners.core.DoFnRunners.OutputManager;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
-import org.apache.beam.sdk.transforms.windowing.PaneInfo;
-import org.apache.beam.sdk.transforms.windowing.WindowFn;
-import org.apache.beam.sdk.util.SideInputReader;
-import org.apache.beam.sdk.util.SystemDoFnInternal;
+import org.apache.beam.sdk.util.ExecutionContext.StepContext;
 import org.apache.beam.sdk.util.TimerInternals;
-import org.apache.beam.sdk.util.UserCodeException;
 import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.util.WindowingInternals;
-import org.apache.beam.sdk.util.WindowingStrategy;
-import org.apache.beam.sdk.util.state.InMemoryStateInternals;
 import org.apache.beam.sdk.util.state.StateInternals;
-import org.apache.beam.sdk.values.KV;
-import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
-import org.apache.spark.Accumulator;
-import org.joda.time.Instant;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
 
 
 /**
- * Spark runner process context.
+ * A Spark runner process context that processes Spark partitions using Beam's
+ * {@link DoFnRunner}.
  */
-public abstract class SparkProcessContext<InputT, OutputT, ValueT>
-    extends OldDoFn<InputT, OutputT>.ProcessContext {
-  private static final Logger LOG = LoggerFactory.getLogger(SparkProcessContext.class);
+class SparkProcessContext<FnInputT, FnOutputT, OutputT> {
 
-  private final OldDoFn<InputT, OutputT> fn;
-  private final SparkRuntimeContext mRuntimeContext;
-  private final SideInputReader sideInputReader;
-  private final WindowFn<Object, ?> windowFn;
+  private final DoFnRunner<FnInputT, FnOutputT> doFnRunner;
+  private final SparkOutputManager<OutputT> outputManager;
 
-  WindowedValue<InputT> windowedValue;
+  SparkProcessContext(
+      DoFnRunner<FnInputT, FnOutputT> doFnRunner,
+      SparkOutputManager<OutputT> outputManager) {
 
-  SparkProcessContext(OldDoFn<InputT, OutputT> fn,
-                      SparkRuntimeContext runtime,
-                      Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> sideInputs,
-                      WindowFn<Object, ?> windowFn) {
-    fn.super();
-    this.fn = fn;
-    this.mRuntimeContext = runtime;
-    this.sideInputReader = new SparkSideInputReader(sideInputs);
-    this.windowFn = windowFn;
+    this.doFnRunner = doFnRunner;
+    this.outputManager = outputManager;
   }
 
-  void setup() {
-    setupDelegateAggregators();
-  }
+  Iterable<OutputT> processPartition(
+      Iterator<WindowedValue<FnInputT>> partition) throws Exception {
 
-  Iterable<ValueT> callWithCtxt(Iterator<WindowedValue<InputT>> iter) throws Exception{
-    this.setup();
-    // skip if bundle is empty.
-    if (!iter.hasNext()) {
+    // skip if partition is empty.
+    if (!partition.hasNext()) {
       return Lists.newArrayList();
     }
-    try {
-      fn.setup();
-      fn.startBundle(this);
-      return this.getOutputIterable(iter, fn);
-    } catch (Exception e) {
-      try {
-        // this teardown handles exceptions encountered in setup() and startBundle(). teardown
-        // after execution or due to exceptions in process element is called in the iterator
-        // produced by ctxt.getOutputIterable returned from this method.
-        fn.teardown();
-      } catch (Exception teardownException) {
-        LOG.error(
-            "Suppressing exception while tearing down Function {}", fn, teardownException);
-        e.addSuppressed(teardownException);
-      }
-      throw wrapUserCodeException(e);
-    }
+    // call startBundle() before beginning to process the partition.
+    doFnRunner.startBundle();
+    // process the partition; finishBundle() is called from within the output iterator.
+    return this.getOutputIterable(partition, doFnRunner);
   }
 
-  @Override
-  public PipelineOptions getPipelineOptions() {
-    return mRuntimeContext.getPipelineOptions();
+  private void clearOutput() {
+    outputManager.clear();
   }
 
-  @Override
-  public <T> T sideInput(PCollectionView<T> view) {
-    //validate element window.
-    final Collection<? extends BoundedWindow> elementWindows = windowedValue.getWindows();
-    checkState(elementWindows.size() == 1, "sideInput can only be called when the main "
-        + "input element is in exactly one window");
-    return sideInputReader.get(view, elementWindows.iterator().next());
+  private Iterator<OutputT> getOutputIterator() {
+    return outputManager.iterator();
   }
 
-  @Override
-  public <AggregatorInputT, AggregatorOutputT>
-  Aggregator<AggregatorInputT, AggregatorOutputT> createAggregatorInternal(
-      String named,
-      Combine.CombineFn<AggregatorInputT, ?, AggregatorOutputT> combineFn) {
-    return mRuntimeContext.createAggregator(getAccumulator(), named, combineFn);
-  }
-
-  public abstract Accumulator<NamedAggregators> getAccumulator();
+  private Iterable<OutputT> getOutputIterable(
+      final Iterator<WindowedValue<FnInputT>> iter,
+      final DoFnRunner<FnInputT, FnOutputT> doFnRunner) {
 
-  @Override
-  public InputT element() {
-    return windowedValue.getValue();
+    return new Iterable<OutputT>() {
+      @Override
+      public Iterator<OutputT> iterator() {
+        return new ProcCtxtIterator(iter, doFnRunner);
+      }
+    };
   }
 
-  @Override
-  public void output(OutputT output) {
-    outputWithTimestamp(output, windowedValue != null ? windowedValue.getTimestamp() : null);
-  }
+  interface SparkOutputManager<T> extends OutputManager, Iterable<T> {
 
-  @Override
-  public void outputWithTimestamp(OutputT output, Instant timestamp) {
-    if (windowedValue == null) {
-      // this is start/finishBundle.
-      outputWindowedValue(noElementWindowedValue(output, timestamp, windowFn));
-    } else {
-      outputWindowedValue(WindowedValue.of(output, timestamp, windowedValue.getWindows(),
-          windowedValue.getPane()));
-    }
-  }
+    void clear();
 
-  @Override
-  public <T> void sideOutput(TupleTag<T> tag, T output) {
-    sideOutputWithTimestamp(
-        tag, output, windowedValue != null ? windowedValue.getTimestamp() : null);
   }
 
-  @Override
-  public <T> void sideOutputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) {
-    if (windowedValue == null) {
-      // this is start/finishBundle.
-      sideOutputWindowedValue(tag, noElementWindowedValue(output, timestamp, windowFn));
-    } else {
-      sideOutputWindowedValue(tag, WindowedValue.of(output, timestamp, windowedValue.getWindows(),
-          windowedValue.getPane()));
+  static class NoOpStepContext implements StepContext {
+    @Override
+    public String getStepName() {
+      return null;
     }
-  }
 
-  protected abstract void outputWindowedValue(WindowedValue<OutputT> output);
-
-  protected abstract <T> void sideOutputWindowedValue(TupleTag<T> tag, WindowedValue<T> output);
+    @Override
+    public String getTransformName() {
+      return null;
+    }
 
-  static <T, W extends BoundedWindow> WindowedValue<T> noElementWindowedValue(
-      final T output, final Instant timestamp, WindowFn<Object, W> windowFn) {
-    WindowFn<Object, W>.AssignContext assignContext =
-        windowFn.new AssignContext() {
+    @Override
+    public void noteOutput(WindowedValue<?> output) { }
 
-          @Override
-          public Object element() {
-            return output;
-          }
+    @Override
+    public void noteSideOutput(TupleTag<?> tag, WindowedValue<?> output) { }
 
-          @Override
-          public Instant timestamp() {
-            if (timestamp != null) {
-              return timestamp;
-            }
-            throw new UnsupportedOperationException(
-                "outputWithTimestamp was called with " + "null timestamp.");
-          }
+    @Override
+    public <T, W extends BoundedWindow> void writePCollectionViewData(
+        TupleTag<?> tag,
+        Iterable<WindowedValue<T>> data,
+        Coder<Iterable<WindowedValue<T>>> dataCoder,
+        W window,
+        Coder<W> windowCoder) throws IOException { }
 
-          @Override
-          public BoundedWindow window() {
-            throw new UnsupportedOperationException(
-                "Window not available for " + "start/finishBundle output.");
-          }
-        };
-    try {
-      @SuppressWarnings("unchecked")
-      Collection<? extends BoundedWindow> windows = windowFn.assignWindows(assignContext);
-      Instant outputTimestamp = timestamp != null ? timestamp : BoundedWindow.TIMESTAMP_MIN_VALUE;
-      return WindowedValue.of(output, outputTimestamp, windows, PaneInfo.NO_FIRING);
-    } catch (Exception e) {
-      throw new RuntimeException("Failed to assign windows at start/finishBundle.", e);
+    @Override
+    public StateInternals<?> stateInternals() {
+      return null;
     }
-  }
 
-  @Override
-  public Instant timestamp() {
-    return windowedValue.getTimestamp();
-  }
-
-  @Override
-  public BoundedWindow window() {
-    if (!(fn instanceof OldDoFn.RequiresWindowAccess)) {
-      throw new UnsupportedOperationException(
-          "window() is only available in the context of a OldDoFn marked as RequiresWindowAccess.");
+    @Override
+    public TimerInternals timerInternals() {
+      return null;
     }
-    return Iterables.getOnlyElement(windowedValue.getWindows());
-  }
-
-  @Override
-  public PaneInfo pane() {
-    return windowedValue.getPane();
-  }
-
-  @Override
-  public WindowingInternals<InputT, OutputT> windowingInternals() {
-    return new WindowingInternals<InputT, OutputT>() {
-
-      @Override
-      public Collection<? extends BoundedWindow> windows() {
-        return windowedValue.getWindows();
-      }
-
-      @Override
-      public void outputWindowedValue(
-          OutputT output,
-          Instant timestamp,
-          Collection<? extends BoundedWindow> windows,
-          PaneInfo paneInfo) {
-        SparkProcessContext.this.outputWindowedValue(
-            WindowedValue.of(output, timestamp, windows, paneInfo));
-      }
-
-      @Override
-      public <SideOutputT> void sideOutputWindowedValue(
-          TupleTag<SideOutputT> tag,
-          SideOutputT output,
-          Instant timestamp,
-          Collection<? extends BoundedWindow> windows,
-          PaneInfo paneInfo) {
-        SparkProcessContext.this.sideOutputWindowedValue(
-            tag, WindowedValue.of(output, timestamp, windows, paneInfo));
-      }
-
-      @Override
-      public StateInternals stateInternals() {
-        //TODO: implement state internals.
-        // This is a temporary placeholder to get the TfIdfTest
-        // working for the initial Beam code drop.
-        return InMemoryStateInternals.forKey("DUMMY");
-      }
-
-      @Override
-      public TimerInternals timerInternals() {
-        throw new UnsupportedOperationException(
-            "WindowingInternals#timerInternals() is not yet supported.");
-      }
-
-      @Override
-      public PaneInfo pane() {
-        return windowedValue.getPane();
-      }
-
-      @Override
-      public <T> T sideInput(PCollectionView<T> view, BoundedWindow sideInputWindow) {
-        throw new UnsupportedOperationException(
-            "WindowingInternals#sideInput() is not yet supported.");
-      }
-    };
-  }
-
-  protected abstract void clearOutput();
-
-  protected abstract Iterator<ValueT> getOutputIterator();
-
-  protected Iterable<ValueT> getOutputIterable(final Iterator<WindowedValue<InputT>> iter,
-                                               final OldDoFn<InputT, OutputT> doFn) {
-    return new Iterable<ValueT>() {
-      @Override
-      public Iterator<ValueT> iterator() {
-        return new ProcCtxtIterator(iter, doFn);
-      }
-    };
   }
 
-  private class ProcCtxtIterator extends AbstractIterator<ValueT> {
+  private class ProcCtxtIterator extends AbstractIterator<OutputT> {
 
-    private final Iterator<WindowedValue<InputT>> inputIterator;
-    private final OldDoFn<InputT, OutputT> doFn;
-    private Iterator<ValueT> outputIterator;
+    private final Iterator<WindowedValue<FnInputT>> inputIterator;
+    private final DoFnRunner<FnInputT, FnOutputT> doFnRunner;
+    private Iterator<OutputT> outputIterator;
     private boolean calledFinish;
 
-    ProcCtxtIterator(Iterator<WindowedValue<InputT>> iterator, OldDoFn<InputT, OutputT> doFn) {
+    ProcCtxtIterator(
+        Iterator<WindowedValue<FnInputT>> iterator,
+        DoFnRunner<FnInputT, FnOutputT> doFnRunner) {
       this.inputIterator = iterator;
-      this.doFn = doFn;
+      this.doFnRunner = doFnRunner;
       this.outputIterator = getOutputIterator();
     }
 
     @Override
-    protected ValueT computeNext() {
+    protected OutputT computeNext() {
       // Process each element from the (input) iterator, which produces zero, one or more
       // output elements (of type OutputT) in the output iterator. Note that the output
       // collection (and iterator) is reset between each call to processElement, so the
@@ -327,72 +152,20 @@ public abstract class SparkProcessContext<InputT, OutputT, ValueT>
         } else if (inputIterator.hasNext()) {
           clearOutput();
           // grab the next element and process it.
-          windowedValue = inputIterator.next();
-          if (windowedValue.getWindows().size() <= 1
-              || (!RequiresWindowAccess.class.isAssignableFrom(doFn.getClass())
-                  && sideInputReader.isEmpty())) {
-            // if there's no reason to explode, process compacted.
-            invokeProcessElement();
-          } else {
-            // explode and process the element in each of it's assigned windows.
-            for (WindowedValue<InputT> wv: windowedValue.explodeWindows()) {
-              windowedValue = wv;
-              invokeProcessElement();
-            }
-          }
+          doFnRunner.processElement(inputIterator.next());
           outputIterator = getOutputIterator();
         } else {
           // no more input to consume, but finishBundle can produce more output
           if (!calledFinish) {
-            windowedValue = null; // clear the last element processed
             clearOutput();
-            try {
-              calledFinish = true;
-              doFn.finishBundle(SparkProcessContext.this);
-            } catch (Exception e) {
-              handleProcessingException(e);
-              throw wrapUserCodeException(e);
-            }
+            calledFinish = true;
+            doFnRunner.finishBundle();
             outputIterator = getOutputIterator();
             continue; // try to consume outputIterator from start of loop
           }
-          try {
-            doFn.teardown();
-          } catch (Exception e) {
-            LOG.error(
-                "Suppressing teardown exception that occurred after processing entire input", e);
-          }
           return endOfData();
         }
       }
     }
-
-    private void invokeProcessElement() {
-      try {
-        doFn.processElement(SparkProcessContext.this);
-      } catch (Exception e) {
-        handleProcessingException(e);
-        throw wrapUserCodeException(e);
-      }
-    }
-
-    private void handleProcessingException(Exception e) {
-      try {
-        doFn.teardown();
-      } catch (Exception e1) {
-        LOG.error("Exception while cleaning up DoFn", e1);
-        e.addSuppressed(e1);
-      }
-    }
   }
-
-
-  private RuntimeException wrapUserCodeException(Throwable t) {
-    throw UserCodeException.wrapIf(!isSystemDoFn(), t);
-  }
-
-  private boolean isSystemDoFn() {
-    return fn.getClass().isAnnotationPresent(SystemDoFnInternal.class);
-  }
-
 }
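
The rewrite keeps the lazy evaluation model: startBundle() runs eagerly in processPartition(), but elements are processed, and finishBundle() fires, only as the consumer drains the output iterator. A minimal, self-contained sketch of the Guava AbstractIterator pattern this relies on:

    // Lazily doubles each input element; endOfData() signals exhaustion, just as
    // ProcCtxtIterator.computeNext() does once finishBundle()'s output is drained.
    class DoublingIterator extends AbstractIterator<Integer> {
      private final Iterator<Integer> input;
      DoublingIterator(Iterator<Integer> input) { this.input = input; }
      @Override
      protected Integer computeNext() {
        if (!input.hasNext()) {
          return endOfData();
        }
        return input.next() * 2;
      }
    }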


[2/3] incubator-beam git commit: [BEAM-807] Replace OldDoFn with DoFn.

Posted by ke...@apache.org.
[BEAM-807] Replace OldDoFn with DoFn.

Add a custom AssignWindows implementation.

Set up and tear down the DoFn.

Add implementation for GroupAlsoByWindow via flatMap.
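
"Set up and tear down the DoFn" refers to the reflective DoFnInvoker API; a hedged sketch of the lifecycle it enables follows (the try/finally shape is illustrative; the diff below only shows invokeSetup() in SparkProcessContext):

    DoFnInvoker<InputT, OutputT> invoker = DoFnInvokers.invokerFor(doFn);
    invoker.invokeSetup();
    try {
      // Run the bundle through a DoFnRunner:
      // startBundle(), processElement() per input element, finishBundle().
    } finally {
      invoker.invokeTeardown();
    }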


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/4ffed3e0
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/4ffed3e0
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/4ffed3e0

Branch: refs/heads/master
Commit: 4ffed3e09a2f0ec3583098f6cfd53a2ddcc6f8c2
Parents: 2be9a15
Author: Sela <an...@paypal.com>
Authored: Sun Dec 11 14:32:49 2016 +0200
Committer: Sela <an...@paypal.com>
Committed: Tue Dec 13 10:05:18 2016 +0200

----------------------------------------------------------------------
 .../beam/runners/spark/examples/WordCount.java  |   6 +-
 .../runners/spark/translation/DoFnFunction.java |   2 +-
 .../translation/GroupCombineFunctions.java      |  23 +-
 .../spark/translation/MultiDoFnFunction.java    |   2 +-
 .../spark/translation/SparkAssignWindowFn.java  |  69 ++++++
 .../translation/SparkGroupAlsoByWindowFn.java   | 214 +++++++++++++++++++
 .../spark/translation/SparkProcessContext.java  |  10 +
 .../spark/translation/TransformTranslator.java  |  31 +--
 .../streaming/StreamingTransformTranslator.java |  35 ++-
 .../streaming/utils/PAssertStreaming.java       |  26 +--
 10 files changed, 345 insertions(+), 73 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/examples/WordCount.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/examples/WordCount.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/examples/WordCount.java
index b2672b5..1252d12 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/examples/WordCount.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/examples/WordCount.java
@@ -25,8 +25,8 @@ import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.transforms.Aggregator;
 import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.MapElements;
-import org.apache.beam.sdk.transforms.OldDoFn;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.SimpleFunction;
@@ -44,11 +44,11 @@ public class WordCount {
    * of-line. This DoFn tokenizes lines of text into individual words; we pass it to a ParDo in the
    * pipeline.
    */
-  static class ExtractWordsFn extends OldDoFn<String, String> {
+  static class ExtractWordsFn extends DoFn<String, String> {
     private final Aggregator<Long, Long> emptyLines =
         createAggregator("emptyLines", new Sum.SumLongFn());
 
-    @Override
+    @ProcessElement
     public void processElement(ProcessContext c) {
       if (c.element().trim().isEmpty()) {
         emptyLines.addValue(1L);
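
The migration pattern above generalizes: an OldDoFn's processElement override becomes a method annotated with @ProcessElement on the new DoFn. A minimal sketch, assuming the SDK at this commit:

    // New-style DoFn: the annotation, not the method name, marks the entry point.
    static class UppercaseFn extends DoFn<String, String> {
      @ProcessElement
      public void processElement(ProcessContext c) {
        c.output(c.element().toUpperCase());
      }
    }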

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
index 4c49a7f..6a641b5 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
@@ -93,7 +93,7 @@ public class DoFnFunction<InputT, OutputT>
             windowingStrategy
         );
 
-    return new SparkProcessContext<>(doFnRunner, outputManager).processPartition(iter);
+    return new SparkProcessContext<>(doFn, doFnRunner, outputManager).processPartition(iter);
   }
 
   private class DoFnOutputManager

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
index 421b1b0..4875b0c 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
@@ -18,11 +18,9 @@
 
 package org.apache.beam.runners.spark.translation;
 
-
 import com.google.common.collect.Lists;
 import java.util.Collections;
 import java.util.Map;
-import org.apache.beam.runners.core.GroupAlsoByWindowsViaOutputBufferDoFn;
 import org.apache.beam.runners.core.SystemReduceFn;
 import org.apache.beam.runners.spark.aggregators.NamedAggregators;
 import org.apache.beam.runners.spark.coders.CoderHelpers;
@@ -33,9 +31,7 @@ import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.transforms.CombineWithContext;
-import org.apache.beam.sdk.transforms.OldDoFn;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
-import org.apache.beam.sdk.transforms.windowing.WindowFn;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.WindowingStrategy;
 import org.apache.beam.sdk.values.KV;
@@ -59,7 +55,7 @@ public class GroupCombineFunctions {
   /**
    * Apply {@link org.apache.beam.sdk.transforms.GroupByKey} to a Spark RDD.
    */
-  public static <K, V,  W extends BoundedWindow> JavaRDD<WindowedValue<KV<K,
+  public static <K, V, W extends BoundedWindow> JavaRDD<WindowedValue<KV<K,
       Iterable<V>>>> groupByKey(JavaRDD<WindowedValue<KV<K, V>>> rdd,
                                 Accumulator<NamedAggregators> accum,
                                 KvCoder<K, V> coder,
@@ -86,15 +82,14 @@ public class GroupCombineFunctions {
             .map(WindowingHelpers.<KV<K, Iterable<WindowedValue<V>>>>windowFunction());
 
     //--- now group also by window.
-    @SuppressWarnings("unchecked")
-    WindowFn<Object, W> windowFn = (WindowFn<Object, W>) windowingStrategy.getWindowFn();
-    // GroupAlsoByWindow current uses a dummy in-memory StateInternals
-    OldDoFn<KV<K, Iterable<WindowedValue<V>>>, KV<K, Iterable<V>>> gabwDoFn =
-        new GroupAlsoByWindowsViaOutputBufferDoFn<K, V, Iterable<V>, W>(
-            windowingStrategy, new TranslationUtils.InMemoryStateInternalsFactory<K>(),
-                SystemReduceFn.<K, V, W>buffering(valueCoder));
-    return groupedByKey.mapPartitions(new DoFnFunction<>(accum, gabwDoFn, runtimeContext, null,
-        windowFn));
+    // GroupAlsoByWindow currently uses a dummy in-memory StateInternals
+    return groupedByKey.flatMap(
+        new SparkGroupAlsoByWindowFn<>(
+            windowingStrategy,
+            new TranslationUtils.InMemoryStateInternalsFactory<K>(),
+            SystemReduceFn.<K, V, W>buffering(valueCoder),
+            runtimeContext,
+            accum));
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
index 710c5cd..8a55369 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
@@ -102,7 +102,7 @@ public class MultiDoFnFunction<InputT, OutputT>
             windowingStrategy
         );
 
-    return new SparkProcessContext<>(doFnRunner, outputManager).processPartition(iter);
+    return new SparkProcessContext<>(doFn, doFnRunner, outputManager).processPartition(iter);
   }
 
   private class DoFnOutputManager

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkAssignWindowFn.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkAssignWindowFn.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkAssignWindowFn.java
new file mode 100644
index 0000000..9d7ed7d
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkAssignWindowFn.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.translation;
+
+import com.google.common.collect.Iterables;
+import java.util.Collection;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.WindowFn;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.spark.api.java.function.Function;
+import org.joda.time.Instant;
+
+
+/**
+ * An implementation of {@link org.apache.beam.runners.core.AssignWindows} for the Spark runner.
+ */
+public class SparkAssignWindowFn<T, W extends BoundedWindow>
+    implements Function<WindowedValue<T>, WindowedValue<T>> {
+
+  private WindowFn<? super T, W> fn;
+
+  public SparkAssignWindowFn(WindowFn<? super T, W> fn) {
+    this.fn = fn;
+  }
+
+  @Override
+  @SuppressWarnings("unchecked")
+  public WindowedValue<T> call(WindowedValue<T> windowedValue) throws Exception {
+    final BoundedWindow boundedWindow = Iterables.getOnlyElement(windowedValue.getWindows());
+    final T element = windowedValue.getValue();
+    final Instant timestamp = windowedValue.getTimestamp();
+    Collection<W> windows =
+        ((WindowFn<T, W>) fn).assignWindows(
+            ((WindowFn<T, W>) fn).new AssignContext() {
+                @Override
+                public T element() {
+                  return element;
+                }
+
+                @Override
+                public Instant timestamp() {
+                  return timestamp;
+                }
+
+                @Override
+                public BoundedWindow window() {
+                  return boundedWindow;
+                }
+              });
+    return WindowedValue.of(element, timestamp, windows, PaneInfo.NO_FIRING);
+  }
+}
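
A hedged usage sketch (the RDD, element type, and choice of FixedWindows are illustrative; per this commit the real wiring lives in TransformTranslator):

    // Assign each element to a fixed one-minute window with a plain Spark map().
    JavaRDD<WindowedValue<String>> windowed =
        rdd.map(new SparkAssignWindowFn<String, IntervalWindow>(
            FixedWindows.of(Duration.standardMinutes(1))));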

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGroupAlsoByWindowFn.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGroupAlsoByWindowFn.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGroupAlsoByWindowFn.java
new file mode 100644
index 0000000..87d3f50
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGroupAlsoByWindowFn.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.translation;
+
+import com.google.common.collect.Iterables;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import org.apache.beam.runners.core.GroupAlsoByWindowsDoFn;
+import org.apache.beam.runners.core.OutputWindowedValue;
+import org.apache.beam.runners.core.ReduceFnRunner;
+import org.apache.beam.runners.core.SystemReduceFn;
+import org.apache.beam.runners.core.triggers.ExecutableTriggerStateMachine;
+import org.apache.beam.runners.core.triggers.TriggerStateMachines;
+import org.apache.beam.runners.spark.aggregators.NamedAggregators;
+import org.apache.beam.sdk.transforms.Aggregator;
+import org.apache.beam.sdk.transforms.Sum;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.util.SideInputReader;
+import org.apache.beam.sdk.util.TimerInternals;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowingStrategy;
+import org.apache.beam.sdk.util.state.InMemoryTimerInternals;
+import org.apache.beam.sdk.util.state.StateInternals;
+import org.apache.beam.sdk.util.state.StateInternalsFactory;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.Accumulator;
+import org.apache.spark.api.java.function.FlatMapFunction;
+import org.joda.time.Instant;
+
+
+
+/**
+ * An implementation of {@link org.apache.beam.runners.core.GroupAlsoByWindowsViaOutputBufferDoFn}
+ * for the Spark runner.
+ */
+public class SparkGroupAlsoByWindowFn<K, InputT, W extends BoundedWindow>
+    implements FlatMapFunction<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>,
+        WindowedValue<KV<K, Iterable<InputT>>>> {
+
+  private final WindowingStrategy<?, W> windowingStrategy;
+  private final StateInternalsFactory<K> stateInternalsFactory;
+  private final SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> reduceFn;
+  private final SparkRuntimeContext runtimeContext;
+  private final Aggregator<Long, Long> droppedDueToClosedWindow;
+
+
+  public SparkGroupAlsoByWindowFn(
+      WindowingStrategy<?, W> windowingStrategy,
+      StateInternalsFactory<K> stateInternalsFactory,
+      SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> reduceFn,
+      SparkRuntimeContext runtimeContext,
+      Accumulator<NamedAggregators> accumulator) {
+    this.windowingStrategy = windowingStrategy;
+    this.stateInternalsFactory = stateInternalsFactory;
+    this.reduceFn = reduceFn;
+    this.runtimeContext = runtimeContext;
+
+    droppedDueToClosedWindow = runtimeContext.createAggregator(
+        accumulator,
+        GroupAlsoByWindowsDoFn.DROPPED_DUE_TO_CLOSED_WINDOW_COUNTER,
+        new Sum.SumLongFn());
+  }
+
+  @Override
+  public Iterable<WindowedValue<KV<K, Iterable<InputT>>>> call(
+      WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>> windowedValue) throws Exception {
+    K key = windowedValue.getValue().getKey();
+    Iterable<WindowedValue<InputT>> inputs = windowedValue.getValue().getValue();
+
+    //------ based on GroupAlsoByWindowsViaOutputBufferDoFn ------//
+
+    // Since this is batch, we know that all the data for this key is available. We can't use
+    // the timer manager from the context because it doesn't exist, so we create one and
+    // emulate the watermark, knowing that we have all the data and that it is in timestamp
+    // order.
+    InMemoryTimerInternals timerInternals = new InMemoryTimerInternals();
+    timerInternals.advanceProcessingTime(Instant.now());
+    timerInternals.advanceSynchronizedProcessingTime(Instant.now());
+    StateInternals<K> stateInternals = stateInternalsFactory.stateInternalsForKey(key);
+    GABWOutputWindowedValue<K, InputT> outputter = new GABWOutputWindowedValue<>();
+
+    ReduceFnRunner<K, InputT, Iterable<InputT>, W> reduceFnRunner =
+        new ReduceFnRunner<>(
+            key,
+            windowingStrategy,
+            ExecutableTriggerStateMachine.create(
+                TriggerStateMachines.stateMachineForTrigger(windowingStrategy.getTrigger())),
+            stateInternals,
+            timerInternals,
+            outputter,
+            new SideInputReader() {
+                @Override
+                public <T> T get(PCollectionView<T> view, BoundedWindow sideInputWindow) {
+                  throw new UnsupportedOperationException(
+                      "GroupAlsoByWindow must not have side inputs");
+                }
+
+                @Override
+                public <T> boolean contains(PCollectionView<T> view) {
+                  throw new UnsupportedOperationException(
+                      "GroupAlsoByWindow must not have side inputs");
+                }
+
+                @Override
+                public boolean isEmpty() {
+                  throw new UnsupportedOperationException(
+                      "GroupAlsoByWindow must not have side inputs");
+                }
+              },
+            droppedDueToClosedWindow,
+            reduceFn,
+            runtimeContext.getPipelineOptions());
+
+    Iterable<List<WindowedValue<InputT>>> chunks = Iterables.partition(inputs, 1000);
+    for (Iterable<WindowedValue<InputT>> chunk : chunks) {
+      // Process the chunk of elements.
+      reduceFnRunner.processElements(chunk);
+
+      // Then, since elements are sorted by their timestamp, advance the input watermark
+      // to the first element.
+      timerInternals.advanceInputWatermark(chunk.iterator().next().getTimestamp());
+      // Advance the processing times.
+      timerInternals.advanceProcessingTime(Instant.now());
+      timerInternals.advanceSynchronizedProcessingTime(Instant.now());
+
+      // Fire all the eligible timers.
+      fireEligibleTimers(timerInternals, reduceFnRunner);
+
+      // Leave the output watermark undefined. Since there's no late data in batch mode
+      // there's really no need to track it as we do for streaming.
+    }
+
+    // Finish any pending windows by advancing the input watermark to infinity.
+    timerInternals.advanceInputWatermark(BoundedWindow.TIMESTAMP_MAX_VALUE);
+
+    // Finally, advance the processing time to infinity to fire any timers.
+    timerInternals.advanceProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);
+    timerInternals.advanceSynchronizedProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);
+
+    fireEligibleTimers(timerInternals, reduceFnRunner);
+
+    reduceFnRunner.persist();
+
+    return outputter.getOutputs();
+  }
+
+  private void fireEligibleTimers(InMemoryTimerInternals timerInternals,
+      ReduceFnRunner<K, InputT, Iterable<InputT>, W> reduceFnRunner) throws Exception {
+    List<TimerInternals.TimerData> timers = new ArrayList<>();
+    while (true) {
+      TimerInternals.TimerData timer;
+      while ((timer = timerInternals.removeNextEventTimer()) != null) {
+        timers.add(timer);
+      }
+      while ((timer = timerInternals.removeNextProcessingTimer()) != null) {
+        timers.add(timer);
+      }
+      while ((timer = timerInternals.removeNextSynchronizedProcessingTimer()) != null) {
+        timers.add(timer);
+      }
+      if (timers.isEmpty()) {
+        break;
+      }
+      reduceFnRunner.onTimers(timers);
+      timers.clear();
+    }
+  }
+
+  private static class GABWOutputWindowedValue<K, V>
+      implements OutputWindowedValue<KV<K, Iterable<V>>> {
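+    // Buffers this key's windowed outputs in memory; batch evaluation collects
+    // all outputs before returning them.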
+    private final List<WindowedValue<KV<K, Iterable<V>>>> outputs = new ArrayList<>();
+
+    @Override
+    public void outputWindowedValue(
+        KV<K, Iterable<V>> output,
+        Instant timestamp,
+        Collection<? extends BoundedWindow> windows,
+        PaneInfo pane) {
+      outputs.add(WindowedValue.of(output, timestamp, windows, pane));
+    }
+
+    @Override
+    public <SideOutputT> void sideOutputWindowedValue(
+        TupleTag<SideOutputT> tag,
+        SideOutputT output,
+        Instant timestamp,
+        Collection<? extends BoundedWindow> windows, PaneInfo pane) {
+      throw new UnsupportedOperationException("GroupAlsoByWindow should not use side outputs.");
+    }
+
+    Iterable<WindowedValue<KV<K, Iterable<V>>>> getOutputs() {
+      return outputs;
+    }
+  }
+}

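The emulation above works because InMemoryTimerInternals only makes a timer removable once the corresponding clock passes it. A minimal illustration of that contract (the package locations of StateNamespaces and TimeDomain vary across Beam versions, so treat these names as approximate):

// Illustration only: a timer becomes drainable once the input watermark
// passes its timestamp, which is exactly what fireEligibleTimers relies on.
InMemoryTimerInternals timerInternals = new InMemoryTimerInternals();
timerInternals.setTimer(TimerInternals.TimerData.of(
    StateNamespaces.global(), new Instant(10), TimeDomain.EVENT_TIME));

// Not yet eligible: the watermark has not reached the timer's timestamp.
assert timerInternals.removeNextEventTimer() == null;

// Advance the emulated watermark past the timer...
timerInternals.advanceInputWatermark(new Instant(11));

// ...and it can now be drained and handed to ReduceFnRunner#onTimers.
TimerInternals.TimerData fired = timerInternals.removeNextEventTimer();
assert fired != null && fired.getTimestamp().isEqual(new Instant(10));
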
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
index efd8202..3a31cae 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
@@ -25,6 +25,8 @@ import java.util.Iterator;
 import org.apache.beam.runners.core.DoFnRunner;
 import org.apache.beam.runners.core.DoFnRunners.OutputManager;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.ExecutionContext.StepContext;
 import org.apache.beam.sdk.util.TimerInternals;
@@ -38,13 +40,16 @@ import org.apache.beam.sdk.values.TupleTag;
  */
 class SparkProcessContext<FnInputT, FnOutputT, OutputT> {
 
+  private final DoFn<FnInputT, FnOutputT> doFn;
   private final DoFnRunner<FnInputT, FnOutputT> doFnRunner;
   private final SparkOutputManager<OutputT> outputManager;
 
   SparkProcessContext(
+      DoFn<FnInputT, FnOutputT> doFn,
       DoFnRunner<FnInputT, FnOutputT> doFnRunner,
       SparkOutputManager<OutputT> outputManager) {
 
+    this.doFn = doFn;
     this.doFnRunner = doFnRunner;
     this.outputManager = outputManager;
   }
@@ -52,6 +57,9 @@ class SparkProcessContext<FnInputT, FnOutputT, OutputT> {
   Iterable<OutputT> processPartition(
       Iterator<WindowedValue<FnInputT>> partition) throws Exception {
 
+    // Set up the DoFn.
+    DoFnInvokers.invokerFor(doFn).invokeSetup();
+
     // skip if partition is empty.
     if (!partition.hasNext()) {
       return Lists.newArrayList();
@@ -160,6 +168,8 @@ class SparkProcessContext<FnInputT, FnOutputT, OutputT> {
             clearOutput();
             calledFinish = true;
             doFnRunner.finishBundle();
+            // Tear down the DoFn.
+            DoFnInvokers.invokerFor(doFn).invokeTeardown();
             outputIterator = getOutputIterator();
             continue; // try to consume outputIterator from start of loop
           }

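The SparkProcessContext change threads the raw DoFn through so its @Setup and @Teardown methods are invoked around bundle processing. In the actual code the teardown happens lazily, once the output iterator is drained; the sketch below compresses that lifecycle into straight-line form to show the intended ordering (names beyond DoFnInvokers/DoFnRunner are illustrative):

// A straight-line sketch of the per-partition DoFn lifecycle; the real code
// interleaves this with a lazy output iterator (see the diff above).
DoFnInvoker<InputT, OutputT> invoker = DoFnInvokers.invokerFor(doFn);
invoker.invokeSetup();                         // @Setup, before any bundle work
try {
  doFnRunner.startBundle();
  while (partition.hasNext()) {
    doFnRunner.processElement(partition.next());
  }
  doFnRunner.finishBundle();
} finally {
  invoker.invokeTeardown();                    // @Teardown, paired with setup
}
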
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index 964eb37..ac91892 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -32,7 +32,6 @@ import java.util.Map;
 import org.apache.avro.mapred.AvroKey;
 import org.apache.avro.mapreduce.AvroJob;
 import org.apache.avro.mapreduce.AvroKeyInputFormat;
-import org.apache.beam.runners.core.AssignWindowsDoFn;
 import org.apache.beam.runners.spark.aggregators.NamedAggregators;
 import org.apache.beam.runners.spark.aggregators.SparkAggregators;
 import org.apache.beam.runners.spark.coders.CoderHelpers;
@@ -54,13 +53,11 @@ import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.GroupByKey;
-import org.apache.beam.sdk.transforms.OldDoFn;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.Window;
-import org.apache.beam.sdk.transforms.windowing.WindowFn;
 import org.apache.beam.sdk.util.CombineFnUtil;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.WindowingStrategy;
@@ -235,16 +232,15 @@ public final class TransformTranslator {
         @SuppressWarnings("unchecked")
         JavaRDD<WindowedValue<InputT>> inRDD =
             ((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD();
-        @SuppressWarnings("unchecked")
-        final WindowFn<Object, ?> windowFn =
-            (WindowFn<Object, ?>) context.getInput(transform).getWindowingStrategy().getWindowFn();
+        WindowingStrategy<?, ?> windowingStrategy =
+            context.getInput(transform).getWindowingStrategy();
         Accumulator<NamedAggregators> accum =
             SparkAggregators.getNamedAggregators(context.getSparkContext());
         Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> sideInputs =
             TranslationUtils.getSideInputs(transform.getSideInputs(), context);
         context.putDataset(transform,
-            new BoundedDataset<>(inRDD.mapPartitions(new DoFnFunction<>(accum, transform.getFn(),
-                context.getRuntimeContext(), sideInputs, windowFn))));
+            new BoundedDataset<>(inRDD.mapPartitions(new DoFnFunction<>(accum, doFn,
+                context.getRuntimeContext(), sideInputs, windowingStrategy))));
       }
     };
   }
@@ -259,16 +255,15 @@ public final class TransformTranslator {
         @SuppressWarnings("unchecked")
         JavaRDD<WindowedValue<InputT>> inRDD =
             ((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD();
-        @SuppressWarnings("unchecked")
-        final WindowFn<Object, ?> windowFn =
-            (WindowFn<Object, ?>) context.getInput(transform).getWindowingStrategy().getWindowFn();
+        WindowingStrategy<?, ?> windowingStrategy =
+            context.getInput(transform).getWindowingStrategy();
         Accumulator<NamedAggregators> accum =
             SparkAggregators.getNamedAggregators(context.getSparkContext());
         JavaPairRDD<TupleTag<?>, WindowedValue<?>> all = inRDD
             .mapPartitionsToPair(
-                new MultiDoFnFunction<>(accum, transform.getFn(), context.getRuntimeContext(),
+                new MultiDoFnFunction<>(accum, doFn, context.getRuntimeContext(),
                 transform.getMainOutputTag(), TranslationUtils.getSideInputs(
-                    transform.getSideInputs(), context), windowFn)).cache();
+                    transform.getSideInputs(), context), windowingStrategy)).cache();
         PCollectionTuple pct = context.getOutput(transform);
         for (Map.Entry<TupleTag<?>, PCollection<?>> e : pct.getAll().entrySet()) {
           @SuppressWarnings("unchecked")
@@ -508,14 +503,8 @@ public final class TransformTranslator {
         if (TranslationUtils.skipAssignWindows(transform, context)) {
           context.putDataset(transform, new BoundedDataset<>(inRDD));
         } else {
-          @SuppressWarnings("unchecked")
-          WindowFn<? super T, W> windowFn = (WindowFn<? super T, W>) transform.getWindowFn();
-          OldDoFn<T, T> addWindowsDoFn = new AssignWindowsDoFn<>(windowFn);
-          Accumulator<NamedAggregators> accum =
-              SparkAggregators.getNamedAggregators(context.getSparkContext());
-          context.putDataset(transform,
-              new BoundedDataset<>(inRDD.mapPartitions(new DoFnFunction<>(accum, addWindowsDoFn,
-                  context.getRuntimeContext(), null, null))));
+          context.putDataset(transform, new BoundedDataset<>(
+              inRDD.map(new SparkAssignWindowFn<>(transform.getWindowFn()))));
         }
       }
     };

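Both translators now hand window assignment to the new SparkAssignWindowFn, a file added by this commit but not reproduced in this excerpt. The following is a plausible sketch, under the assumption that it is a plain Spark map function delegating to the user's WindowFn; the AssignContext members follow the 0.4.0-era API, and the real file may differ:

import com.google.common.collect.Iterables;
import java.util.Collection;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.spark.api.java.function.Function;
import org.joda.time.Instant;

/** Sketch only: re-windows each element by delegating to the WindowFn. */
public class SparkAssignWindowFn<T, W extends BoundedWindow>
    implements Function<WindowedValue<T>, WindowedValue<T>> {

  private final WindowFn<? super T, W> fn;

  public SparkAssignWindowFn(WindowFn<? super T, W> fn) {
    this.fn = fn;
  }

  @Override
  @SuppressWarnings("unchecked")
  public WindowedValue<T> call(WindowedValue<T> value) throws Exception {
    final T element = value.getValue();
    final Instant timestamp = value.getTimestamp();
    final BoundedWindow current = Iterables.getOnlyElement(value.getWindows());
    // Ask the user's WindowFn which windows this element belongs in.
    Collection<W> windows =
        ((WindowFn<T, W>) fn).assignWindows(
            ((WindowFn<T, W>) fn).new AssignContext() {
              @Override
              public T element() {
                return element;
              }

              @Override
              public Instant timestamp() {
                return timestamp;
              }

              @Override
              public BoundedWindow window() {
                return current;
              }
            });
    return WindowedValue.of(element, timestamp, windows, PaneInfo.NO_FIRING);
  }
}
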
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index 00df7d4..27204ed 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -24,7 +24,6 @@ import com.google.common.collect.Maps;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
-import org.apache.beam.runners.core.AssignWindowsDoFn;
 import org.apache.beam.runners.spark.aggregators.NamedAggregators;
 import org.apache.beam.runners.spark.aggregators.SparkAggregators;
 import org.apache.beam.runners.spark.io.ConsoleIO;
@@ -36,6 +35,7 @@ import org.apache.beam.runners.spark.translation.DoFnFunction;
 import org.apache.beam.runners.spark.translation.EvaluationContext;
 import org.apache.beam.runners.spark.translation.GroupCombineFunctions;
 import org.apache.beam.runners.spark.translation.MultiDoFnFunction;
+import org.apache.beam.runners.spark.translation.SparkAssignWindowFn;
 import org.apache.beam.runners.spark.translation.SparkKeyedCombineFn;
 import org.apache.beam.runners.spark.translation.SparkPipelineTranslator;
 import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
@@ -51,7 +51,6 @@ import org.apache.beam.sdk.transforms.CombineWithContext;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.GroupByKey;
-import org.apache.beam.sdk.transforms.OldDoFn;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -163,7 +162,7 @@ final class StreamingTransformTranslator {
   private static <T, W extends BoundedWindow> TransformEvaluator<Window.Bound<T>> window() {
     return new TransformEvaluator<Window.Bound<T>>() {
       @Override
-      public void evaluate(Window.Bound<T> transform, EvaluationContext context) {
+      public void evaluate(final Window.Bound<T> transform, EvaluationContext context) {
         @SuppressWarnings("unchecked")
         WindowFn<? super T, W> windowFn = (WindowFn<? super T, W>) transform.getWindowFn();
         @SuppressWarnings("unchecked")
@@ -189,16 +188,11 @@ final class StreamingTransformTranslator {
         if (TranslationUtils.skipAssignWindows(transform, context)) {
           context.putDataset(transform, new UnboundedDataset<>(windowedDStream));
         } else {
-          final OldDoFn<T, T> addWindowsDoFn = new AssignWindowsDoFn<>(windowFn);
-          final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
           JavaDStream<WindowedValue<T>> outStream = windowedDStream.transform(
               new Function<JavaRDD<WindowedValue<T>>, JavaRDD<WindowedValue<T>>>() {
             @Override
             public JavaRDD<WindowedValue<T>> call(JavaRDD<WindowedValue<T>> rdd) throws Exception {
-              final Accumulator<NamedAggregators> accum =
-                  SparkAggregators.getNamedAggregators(new JavaSparkContext(rdd.context()));
-              return rdd.mapPartitions(
-                new DoFnFunction<>(accum, addWindowsDoFn, runtimeContext, null, null));
+              return rdd.map(new SparkAssignWindowFn<>(transform.getWindowFn()));
             }
           });
           context.putDataset(transform, new UnboundedDataset<>(outStream));
@@ -350,13 +344,13 @@ final class StreamingTransformTranslator {
       @Override
       public void evaluate(final ParDo.Bound<InputT, OutputT> transform,
                            final EvaluationContext context) {
-        DoFn<InputT, OutputT> doFn = transform.getNewFn();
+        final DoFn<InputT, OutputT> doFn = transform.getNewFn();
         rejectStateAndTimers(doFn);
         final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
         final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> sideInputs =
             TranslationUtils.getSideInputs(transform.getSideInputs(), context);
-        final WindowFn<Object, ?> windowFn =
-            (WindowFn<Object, ?>) context.getInput(transform).getWindowingStrategy().getWindowFn();
+        final WindowingStrategy<?, ?> windowingStrategy =
+            context.getInput(transform).getWindowingStrategy();
         JavaDStream<WindowedValue<InputT>> dStream =
             ((UnboundedDataset<InputT>) context.borrowDataset(transform)).getDStream();
 
@@ -369,7 +363,7 @@ final class StreamingTransformTranslator {
             final Accumulator<NamedAggregators> accum =
                 SparkAggregators.getNamedAggregators(new JavaSparkContext(rdd.context()));
             return rdd.mapPartitions(
-                new DoFnFunction<>(accum, transform.getFn(), runtimeContext, sideInputs, windowFn));
+                new DoFnFunction<>(accum, doFn, runtimeContext, sideInputs, windowingStrategy));
           }
         });
 
@@ -384,14 +378,13 @@ final class StreamingTransformTranslator {
       @Override
       public void evaluate(final ParDo.BoundMulti<InputT, OutputT> transform,
                            final EvaluationContext context) {
-        DoFn<InputT, OutputT> doFn = transform.getNewFn();
+        final DoFn<InputT, OutputT> doFn = transform.getNewFn();
         rejectStateAndTimers(doFn);
         final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
         final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> sideInputs =
             TranslationUtils.getSideInputs(transform.getSideInputs(), context);
-        @SuppressWarnings("unchecked")
-        final WindowFn<Object, ?> windowFn =
-            (WindowFn<Object, ?>) context.getInput(transform).getWindowingStrategy().getWindowFn();
+        final WindowingStrategy<?, ?> windowingStrategy =
+            context.getInput(transform).getWindowingStrategy();
         @SuppressWarnings("unchecked")
         JavaDStream<WindowedValue<InputT>> dStream =
             ((UnboundedDataset<InputT>) context.borrowDataset(transform)).getDStream();
@@ -403,8 +396,8 @@ final class StreamingTransformTranslator {
               JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
             final Accumulator<NamedAggregators> accum =
                 SparkAggregators.getNamedAggregators(new JavaSparkContext(rdd.context()));
-            return rdd.mapPartitionsToPair(new MultiDoFnFunction<>(accum, transform.getFn(),
-                runtimeContext, transform.getMainOutputTag(), sideInputs, windowFn));
+            return rdd.mapPartitionsToPair(new MultiDoFnFunction<>(accum, doFn,
+                runtimeContext, transform.getMainOutputTag(), sideInputs, windowingStrategy));
           }
         }).cache();
         PCollectionTuple pct = context.getOutput(transform);
@@ -423,8 +416,8 @@ final class StreamingTransformTranslator {
     };
   }
 
-  private static final Map<Class<? extends PTransform>, TransformEvaluator<?>> EVALUATORS = Maps
-      .newHashMap();
+  private static final Map<Class<? extends PTransform>, TransformEvaluator<?>> EVALUATORS =
+      Maps.newHashMap();
 
   static {
     EVALUATORS.put(Read.Unbounded.class, readUnbounded());

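A small but recurring change in this file: locals such as transform and doFn gain the final keyword because they are now captured by anonymous Function classes, and under Java 7, which Beam still supported, captured locals must be declared final. A standalone illustration with no Beam types:

// Self-contained example: this fails to compile on Java 7 without 'final'.
public class CaptureExample {
  interface Fn<A, B> {
    B call(A input);
  }

  public static void main(String[] args) {
    final int bonus = 10;  // must be final to be captured by the inner class
    Fn<Integer, Integer> addBonus = new Fn<Integer, Integer>() {
      @Override
      public Integer call(Integer x) {
        return x + bonus;  // capture of the enclosing local
      }
    };
    System.out.println(addBonus.call(32));  // prints 42
  }
}
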
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/PAssertStreaming.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/PAssertStreaming.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/PAssertStreaming.java
index 471ec92..0284b3d 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/PAssertStreaming.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/PAssertStreaming.java
@@ -27,8 +27,8 @@ import org.apache.beam.runners.spark.SparkPipelineResult;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.transforms.Aggregator;
+import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupByKey;
-import org.apache.beam.sdk.transforms.OldDoFn;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.Values;
@@ -55,11 +55,12 @@ public final class PAssertStreaming implements Serializable {
    * Note that it is oblivious to windowing, so the assertion will apply indiscriminately to all
    * windows.
    */
-  public static <T> SparkPipelineResult runAndAssertContents(Pipeline p,
-                                                          PCollection<T> actual,
-                                                          T[] expected,
-                                                          Duration timeout,
-                                                          boolean stopGracefully) {
+  public static <T> SparkPipelineResult runAndAssertContents(
+      Pipeline p,
+      PCollection<T> actual,
+      T[] expected,
+      Duration timeout,
+      boolean stopGracefully) {
     // Because PAssert does not support non-global windowing, but all our data is in one window,
     // we set up the assertion directly.
     actual
@@ -86,14 +87,15 @@ public final class PAssertStreaming implements Serializable {
    * Default to stop gracefully so that tests will finish processing even if slower for reasons
    * such as a slow runtime environment.
    */
-  public static <T> SparkPipelineResult runAndAssertContents(Pipeline p,
-                                                          PCollection<T> actual,
-                                                          T[] expected,
-                                                          Duration timeout) {
+  public static <T> SparkPipelineResult runAndAssertContents(
+      Pipeline p,
+      PCollection<T> actual,
+      T[] expected,
+      Duration timeout) {
     return runAndAssertContents(p, actual, expected, timeout, true);
   }
 
-  private static class AssertDoFn<T> extends OldDoFn<Iterable<T>, Void> {
+  private static class AssertDoFn<T> extends DoFn<Iterable<T>, Void> {
     private final Aggregator<Integer, Integer> success =
         createAggregator(PAssert.SUCCESS_COUNTER, new Sum.SumIntegerFn());
     private final Aggregator<Integer, Integer> failure =
@@ -104,7 +106,7 @@ public final class PAssertStreaming implements Serializable {
       this.expected = expected;
     }
 
-    @Override
+    @ProcessElement
     public void processElement(ProcessContext c) throws Exception {
       try {
         assertThat(c.element(), containsInAnyOrder(expected));

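The AssertDoFn change is the standard OldDoFn-to-DoFn migration: extend DoFn and annotate the processing method with @ProcessElement rather than overriding. A minimal, hypothetical example of the same pattern:

import org.apache.beam.sdk.transforms.DoFn;

// Hypothetical fn showing the migration pattern applied to AssertDoFn above:
// extend DoFn (not OldDoFn) and mark the method with @ProcessElement.
public class ToUpperCaseFn extends DoFn<String, String> {
  @ProcessElement
  public void processElement(ProcessContext c) {
    c.output(c.element().toUpperCase());
  }
}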


[3/3] incubator-beam git commit: This closes #1578

Posted by ke...@apache.org.
This closes #1578


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/44b4eba5
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/44b4eba5
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/44b4eba5

Branch: refs/heads/master
Commit: 44b4eba51fad2fd4c33ad9467ac766d4b433f852
Parents: ce3aa65 4ffed3e
Author: Kenneth Knowles <kl...@google.com>
Authored: Tue Dec 13 18:54:51 2016 -0800
Committer: Kenneth Knowles <kl...@google.com>
Committed: Tue Dec 13 18:54:51 2016 -0800

----------------------------------------------------------------------
 .../spark/aggregators/SparkAggregators.java     |  30 +-
 .../beam/runners/spark/examples/WordCount.java  |   6 +-
 .../runners/spark/translation/DoFnFunction.java | 110 +++---
 .../translation/GroupCombineFunctions.java      |  23 +-
 .../spark/translation/MultiDoFnFunction.java    | 135 ++++---
 .../spark/translation/SparkAssignWindowFn.java  |  69 ++++
 .../translation/SparkGroupAlsoByWindowFn.java   | 214 +++++++++++
 .../spark/translation/SparkProcessContext.java  | 385 ++++---------------
 .../spark/translation/TransformTranslator.java  |  31 +-
 .../streaming/StreamingTransformTranslator.java |  35 +-
 .../streaming/utils/PAssertStreaming.java       |  26 +-
 11 files changed, 574 insertions(+), 490 deletions(-)
----------------------------------------------------------------------