You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by am...@apache.org on 2016/03/15 19:48:05 UTC
[08/23] incubator-beam git commit: [BEAM-11] Spark runner directory
structure and pom setup.
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/EvaluationContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/EvaluationContext.java
new file mode 100644
index 0000000..836987f
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/EvaluationContext.java
@@ -0,0 +1,284 @@
+/*
+ * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+
+package org.apache.beam.runners.spark;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.coders.Coder;
+import com.google.cloud.dataflow.sdk.runners.AggregatorRetrievalException;
+import com.google.cloud.dataflow.sdk.runners.AggregatorValues;
+import com.google.cloud.dataflow.sdk.transforms.Aggregator;
+import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.util.WindowedValue;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PCollectionView;
+import com.google.cloud.dataflow.sdk.values.PInput;
+import com.google.cloud.dataflow.sdk.values.POutput;
+import com.google.cloud.dataflow.sdk.values.PValue;
+import com.google.common.base.Function;
+import com.google.common.collect.Iterables;
+import org.apache.beam.runners.spark.coders.CoderHelpers;
+import org.apache.spark.api.java.JavaRDDLike;
+import org.apache.spark.api.java.JavaSparkContext;
+
+
+/**
+ * Evaluation context for running a pipeline on Spark. It records the RDD (or in-memory
+ * values) registered for each {@link PValue}, tracks the transform that is currently being
+ * evaluated, and implements {@link EvaluationResult} so that results can be read back.
+ */
+public class EvaluationContext implements EvaluationResult {
+  private final JavaSparkContext jsc;
+  private final Pipeline pipeline;
+  private final SparkRuntimeContext runtime;
+  /** Holder registered for each PValue via {@link #setRDD} or {@link #setOutputRDDFromValues}. */
+  private final Map<PValue, RDDHolder<?>> pcollections = new LinkedHashMap<>();
+  /** Holders added by {@link #setRDD} that have not yet been read through {@link #getRDD}. */
+  private final Set<RDDHolder<?>> leafRdds = new LinkedHashSet<>();
+  /** PValues already read once through {@link #getRDD}; a further read caches the RDD. */
+  private final Set<PValue> multireads = new LinkedHashSet<>();
+  /** Results already materialized by {@link #get(PValue)}. */
+  private final Map<PValue, Object> pobjects = new LinkedHashMap<>();
+  /** Values registered through {@link #setPView}. */
+  private final Map<PValue, Iterable<? extends WindowedValue<?>>> pview = new LinkedHashMap<>();
+  protected AppliedPTransform<?, ?, ?> currentTransform;
+
+  /**
+   * Creates a context bound to the given Spark context and pipeline, and builds the
+   * {@link SparkRuntimeContext} from the same two objects.
+   *
+   * @param jsc Spark context used to parallelize values and stopped by {@link #close()}.
+   * @param pipeline pipeline being evaluated.
+   */
+  public EvaluationContext(JavaSparkContext jsc, Pipeline pipeline) {
+    this.jsc = jsc;
+    this.pipeline = pipeline;
+    this.runtime = new SparkRuntimeContext(jsc, pipeline);
+  }
+
+  /**
+   * Holds an RDD or values for deferred conversion to an RDD if needed. PCollections are
+   * sometimes created from a collection of objects (using RDD parallelize) and then
+   * only used to create View objects; in which case they do not need to be
+   * converted to bytes since they are not transferred across the network until they are
+   * broadcast.
+   */
+  private class RDDHolder<T> {
+
+    private Iterable<T> values;
+    private Coder<T> coder;
+    private JavaRDDLike<WindowedValue<T>, ?> rdd;
+
+    /** Creates a holder from in-memory values; the RDD is built lazily by {@link #getRDD()}. */
+    RDDHolder(Iterable<T> values, Coder<T> coder) {
+      this.values = values;
+      this.coder = coder;
+    }
+
+    /** Creates a holder around an existing RDD; values are collected lazily on demand. */
+    RDDHolder(JavaRDDLike<WindowedValue<T>, ?> rdd) {
+      this.rdd = rdd;
+    }
+
+    /**
+     * Returns the RDD, creating it on first use by encoding the held values to bytes,
+     * parallelizing them and decoding them again on the Spark side.
+     */
+    JavaRDDLike<WindowedValue<T>, ?> getRDD() {
+      if (rdd == null) {
+        Iterable<WindowedValue<T>> windowedValues = Iterables.transform(values,
+            new Function<T, WindowedValue<T>>() {
+              @Override
+              public WindowedValue<T> apply(T t) {
+                // TODO: this is wrong if T is a TimestampedValue
+                return WindowedValue.valueInEmptyWindows(t);
+              }
+            });
+        WindowedValue.ValueOnlyWindowedValueCoder<T> windowCoder =
+            WindowedValue.getValueOnlyCoder(coder);
+        rdd = jsc.parallelize(CoderHelpers.toByteArrays(windowedValues, windowCoder))
+            .map(CoderHelpers.fromByteFunction(windowCoder));
+      }
+      return rdd;
+    }
+
+    /**
+     * Returns the values, collecting them from the RDD on first use. Elements are unwindowed
+     * and encoded to bytes with the PCollection's coder before being collected, then decoded
+     * lazily as the returned iterable is traversed.
+     */
+    Iterable<T> getValues(PCollection<T> pcollection) {
+      if (values == null) {
+        coder = pcollection.getCoder();
+        JavaRDDLike<byte[], ?> bytesRDD = rdd.map(WindowingHelpers.<T>unwindowFunction())
+            .map(CoderHelpers.toByteFunction(coder));
+        List<byte[]> clientBytes = bytesRDD.collect();
+        values = Iterables.transform(clientBytes, new Function<byte[], T>() {
+          @Override
+          public T apply(byte[] bytes) {
+            return CoderHelpers.fromByteArray(bytes, coder);
+          }
+        });
+      }
+      return values;
+    }
+
+    /**
+     * Returns the values obtained from the enclosing context's {@code get(pcollection)},
+     * each wrapped with {@link WindowedValue#valueInEmptyWindows}.
+     */
+    Iterable<WindowedValue<T>> getWindowedValues(PCollection<T> pcollection) {
+      return Iterables.transform(get(pcollection), new Function<T, WindowedValue<T>>() {
+        @Override
+        public WindowedValue<T> apply(T t) {
+          return WindowedValue.valueInEmptyWindows(t); // TODO: not the right place?
+        }
+      });
+    }
+  }
+
+  protected JavaSparkContext getSparkContext() {
+    return jsc;
+  }
+
+  protected Pipeline getPipeline() {
+    return pipeline;
+  }
+
+  protected SparkRuntimeContext getRuntimeContext() {
+    return runtime;
+  }
+
+  /** Sets the transform being evaluated; {@code null} clears it. */
+  protected void setCurrentTransform(AppliedPTransform<?, ?, ?> transform) {
+    this.currentTransform = transform;
+  }
+
+  protected AppliedPTransform<?, ?, ?> getCurrentTransform() {
+    return currentTransform;
+  }
+
+  /**
+   * Returns the input of the current transform.
+   *
+   * @throws IllegalArgumentException if {@code transform} is not the current transform.
+   */
+  protected <I extends PInput> I getInput(PTransform<I, ?> transform) {
+    checkArgument(currentTransform != null && currentTransform.getTransform() == transform,
+        "can only be called with current transform");
+    @SuppressWarnings("unchecked")
+    I input = (I) currentTransform.getInput();
+    return input;
+  }
+
+  /**
+   * Returns the output of the current transform.
+   *
+   * @throws IllegalArgumentException if {@code transform} is not the current transform.
+   */
+  protected <O extends POutput> O getOutput(PTransform<?, O> transform) {
+    checkArgument(currentTransform != null && currentTransform.getTransform() == transform,
+        "can only be called with current transform");
+    @SuppressWarnings("unchecked")
+    O output = (O) currentTransform.getOutput();
+    return output;
+  }
+
+  /** Registers {@code rdd} as the result for the current transform's output. */
+  protected <T> void setOutputRDD(PTransform<?, ?> transform,
+      JavaRDDLike<WindowedValue<T>, ?> rdd) {
+    setRDD((PValue) getOutput(transform), rdd);
+  }
+
+  /**
+   * Registers in-memory values as the result for the current transform's output. Unlike
+   * {@link #setRDD}, the holder is not added to the set of leaf RDDs.
+   */
+  protected <T> void setOutputRDDFromValues(PTransform<?, ?> transform, Iterable<T> values,
+      Coder<T> coder) {
+    pcollections.put((PValue) getOutput(transform), new RDDHolder<>(values, coder));
+  }
+
+  void setPView(PValue view, Iterable<? extends WindowedValue<?>> value) {
+    pview.put(view, value);
+  }
+
+  /** Returns whether a result has already been registered for the current transform's output. */
+  protected boolean hasOutputRDD(PTransform<? extends PInput, ?> transform) {
+    PValue pvalue = (PValue) getOutput(transform);
+    return pcollections.containsKey(pvalue);
+  }
+
+  /**
+   * Returns the RDD registered for {@code pvalue}. Reading an RDD removes its holder from the
+   * leaf set, and a second or later read of the same PValue marks the RDD as cached.
+   *
+   * @throws IllegalStateException if nothing has been registered for {@code pvalue}.
+   */
+  protected JavaRDDLike<?, ?> getRDD(PValue pvalue) {
+    RDDHolder<?> rddHolder = lookupHolder(pvalue);
+    JavaRDDLike<?, ?> rdd = rddHolder.getRDD();
+    leafRdds.remove(rddHolder);
+    if (multireads.contains(pvalue)) {
+      // Ensure the RDD is marked as cached
+      rdd.rdd().cache();
+    } else {
+      multireads.add(pvalue);
+    }
+    return rdd;
+  }
+
+  /**
+   * Registers {@code rdd} for {@code pvalue} and records it as a leaf until it is read.
+   * The RDD is named after the PValue when the PValue has a name.
+   */
+  protected <T> void setRDD(PValue pvalue, JavaRDDLike<WindowedValue<T>, ?> rdd) {
+    try {
+      rdd.rdd().setName(pvalue.getName());
+    } catch (IllegalStateException ignored) {
+      // name not set, ignore
+    }
+    RDDHolder<T> rddHolder = new RDDHolder<>(rdd);
+    pcollections.put(pvalue, rddHolder);
+    leafRdds.add(rddHolder);
+  }
+
+  /** Returns the RDD registered for the current transform's input; see {@link #getRDD}. */
+  JavaRDDLike<?, ?> getInputRDD(PTransform<? extends PInput, ?> transform) {
+    return getRDD((PValue) getInput(transform));
+  }
+
+  /** Returns the values registered for {@code view}, or {@code null} if there are none. */
+  <T> Iterable<? extends WindowedValue<?>> getPCollectionView(PCollectionView<T> view) {
+    return pview.get(view);
+  }
+
+  /**
+   * Computes the outputs for all RDDs that are leaves in the DAG and do not have any
+   * actions (like saving to a file) registered on them (i.e. they are performed for side
+   * effects).
+   */
+  protected void computeOutputs() {
+    for (RDDHolder<?> rddHolder : leafRdds) {
+      JavaRDDLike<?, ?> rdd = rddHolder.getRDD();
+      rdd.rdd().cache(); // cache so that any subsequent get() is cheap
+      rdd.count(); // force the RDD to be computed
+    }
+  }
+
+  /**
+   * Returns the single object associated with {@code value}. The first call collects the
+   * registered RDD, which must contain exactly one element, and remembers the result.
+   *
+   * @throws IllegalStateException if nothing is known about {@code value}.
+   */
+  @Override
+  public <T> T get(PValue value) {
+    if (pobjects.containsKey(value)) {
+      @SuppressWarnings("unchecked")
+      T result = (T) pobjects.get(value);
+      return result;
+    }
+    if (pcollections.containsKey(value)) {
+      JavaRDDLike<?, ?> rdd = pcollections.get(value).getRDD();
+      @SuppressWarnings("unchecked")
+      T res = (T) Iterables.getOnlyElement(rdd.collect());
+      pobjects.put(value, res);
+      return res;
+    }
+    throw new IllegalStateException("Cannot resolve un-known PObject: " + value);
+  }
+
+  @Override
+  public <T> T getAggregatorValue(String named, Class<T> resultType) {
+    return runtime.getAggregatorValue(named, resultType);
+  }
+
+  @Override
+  public <T> AggregatorValues<T> getAggregatorValues(Aggregator<?, T> aggregator)
+      throws AggregatorRetrievalException {
+    return runtime.getAggregatorValues(aggregator);
+  }
+
+  /**
+   * Returns the elements of {@code pcollection}.
+   *
+   * @throws IllegalStateException if nothing has been registered for {@code pcollection}.
+   */
+  @Override
+  public <T> Iterable<T> get(PCollection<T> pcollection) {
+    @SuppressWarnings("unchecked")
+    RDDHolder<T> rddHolder = (RDDHolder<T>) lookupHolder(pcollection);
+    return rddHolder.getValues(pcollection);
+  }
+
+  /**
+   * Returns the elements of {@code pcollection} wrapped as windowed values.
+   *
+   * @throws IllegalStateException if nothing has been registered for {@code pcollection}.
+   */
+  <T> Iterable<WindowedValue<T>> getWindowedValues(PCollection<T> pcollection) {
+    @SuppressWarnings("unchecked")
+    RDDHolder<T> rddHolder = (RDDHolder<T>) lookupHolder(pcollection);
+    return rddHolder.getWindowedValues(pcollection);
+  }
+
+  /**
+   * Looks up the holder registered for {@code pvalue}, failing with a descriptive error
+   * instead of a bare NullPointerException when there is none.
+   */
+  private RDDHolder<?> lookupHolder(PValue pvalue) {
+    RDDHolder<?> rddHolder = pcollections.get(pvalue);
+    if (rddHolder == null) {
+      throw new IllegalStateException("No RDD registered for PValue: " + pvalue);
+    }
+    return rddHolder;
+  }
+
+  /** Hands the Spark context to {@link SparkContextFactory#stopSparkContext}. */
+  @Override
+  public void close() {
+    SparkContextFactory.stopSparkContext(jsc);
+  }
+
+  /** The runner is blocking. */
+  @Override
+  public State getState() {
+    return State.DONE;
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/EvaluationResult.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/EvaluationResult.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/EvaluationResult.java
new file mode 100644
index 0000000..4de97f6
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/EvaluationResult.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+
+package org.apache.beam.runners.spark;
+
+import com.google.cloud.dataflow.sdk.PipelineResult;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PValue;
+
+/**
+ * Interface for retrieving the result(s) of running a pipeline. Allows us to translate between
+ * {@code PObject<T>}s or {@code PCollection<T>}s and Ts or collections of Ts.
+ */
+public interface EvaluationResult extends PipelineResult {
+ /**
+ * Retrieves an iterable of results associated with the PCollection passed in.
+ *
+ * @param pcollection Collection we wish to translate.
+ * @param <T> Type of elements contained in collection.
+ * @return Natively typed result associated with collection.
+ */
+ <T> Iterable<T> get(PCollection<T> pcollection);
+
+ /**
+ * Retrieves an object of type T associated with the PValue passed in.
+ *
+ * @param pval PValue to retrieve associated data for.
+ * @param <T> Type of object to return.
+ * @return Native object.
+ */
+ <T> T get(PValue pval);
+
+ /**
+ * Retrieves the final value of the aggregator.
+ *
+ * @param aggName name of aggregator.
+ * @param resultType Class of final result of aggregation.
+ * @param <T> Type of final result of aggregation.
+ * @return Result of aggregation associated with specified name.
+ */
+ <T> T getAggregatorValue(String aggName, Class<T> resultType);
+
+ /**
+ * Releases any runtime resources, including distributed-execution contexts currently held by
+ * this EvaluationResult; once close() has been called,
+ * {@link EvaluationResult#get(PCollection)} might
+ * not work for subsequent calls.
+ */
+ void close();
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/MultiDoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/MultiDoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/MultiDoFnFunction.java
new file mode 100644
index 0000000..968825b
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/MultiDoFnFunction.java
@@ -0,0 +1,116 @@
+/*
+ * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+
+package org.apache.beam.runners.spark;
+
+import java.util.Iterator;
+import java.util.Map;
+
+import com.google.cloud.dataflow.sdk.transforms.DoFn;
+import com.google.cloud.dataflow.sdk.util.WindowedValue;
+import com.google.cloud.dataflow.sdk.values.TupleTag;
+import com.google.common.base.Function;
+import com.google.common.collect.Iterators;
+import com.google.common.collect.LinkedListMultimap;
+import com.google.common.collect.Multimap;
+import org.apache.beam.runners.spark.util.BroadcastHelper;
+import org.apache.spark.api.java.function.PairFlatMapFunction;
+import org.joda.time.Instant;
+import scala.Tuple2;
+
+/**
+ * DoFunctions ignore side outputs. MultiDoFunctions deal with side outputs by enriching the
+ * underlying data with multiple TupleTags.
+ *
+ * @param <I> Input type for DoFunction.
+ * @param <O> Output type for DoFunction.
+ */
+class MultiDoFnFunction<I, O>
+    implements PairFlatMapFunction<Iterator<WindowedValue<I>>, TupleTag<?>, WindowedValue<?>> {
+  private final DoFn<I, O> mFunction;
+  private final SparkRuntimeContext mRuntimeContext;
+  private final TupleTag<O> mMainOutputTag;
+  private final Map<TupleTag<?>, BroadcastHelper<?>> mSideInputs;
+
+  /**
+   * @param fn the DoFn to run over each partition.
+   * @param runtimeContext runtime context handed to the processing context.
+   * @param mainOutputTag tag under which main-output elements are emitted.
+   * @param sideInputs side inputs, keyed by tag, handed to the processing context.
+   */
+  MultiDoFnFunction(
+      DoFn<I, O> fn,
+      SparkRuntimeContext runtimeContext,
+      TupleTag<O> mainOutputTag,
+      Map<TupleTag<?>, BroadcastHelper<?>> sideInputs) {
+    this.mFunction = fn;
+    this.mRuntimeContext = runtimeContext;
+    this.mMainOutputTag = mainOutputTag;
+    this.mSideInputs = sideInputs;
+  }
+
+  /**
+   * Builds a fresh processing context, starts the DoFn bundle, sets the context up and
+   * returns the tagged outputs produced for the elements of {@code iter}.
+   */
+  @Override
+  public Iterable<Tuple2<TupleTag<?>, WindowedValue<?>>>
+      call(Iterator<WindowedValue<I>> iter) throws Exception {
+    ProcCtxt ctxt = new ProcCtxt(mFunction, mRuntimeContext, mSideInputs);
+    mFunction.startBundle(ctxt);
+    ctxt.setup();
+    return ctxt.getOutputIterable(iter, mFunction);
+  }
+
+  /**
+   * Processing context that buffers every emitted element, keyed by its output tag, in an
+   * insertion-ordered multimap until the buffer is read or cleared.
+   */
+  private class ProcCtxt extends SparkProcessContext<I, O, Tuple2<TupleTag<?>, WindowedValue<?>>> {
+
+    private final Multimap<TupleTag<?>, WindowedValue<?>> outputs = LinkedListMultimap.create();
+
+    ProcCtxt(DoFn<I, O> fn, SparkRuntimeContext runtimeContext, Map<TupleTag<?>,
+        BroadcastHelper<?>> sideInputs) {
+      super(fn, runtimeContext, sideInputs);
+    }
+
+    @Override
+    public synchronized void output(O o) {
+      outputs.put(mMainOutputTag, windowedValue.withValue(o));
+    }
+
+    @Override
+    public synchronized void output(WindowedValue<O> o) {
+      outputs.put(mMainOutputTag, o);
+    }
+
+    @Override
+    public synchronized <T> void sideOutput(TupleTag<T> tag, T t) {
+      outputs.put(tag, windowedValue.withValue(t));
+    }
+
+    /**
+     * Emits {@code t} to the given side output with an explicit timestamp, reusing the
+     * windows and pane of the element currently being processed. Synchronized like the other
+     * emit methods, because they all append to the same non-thread-safe multimap.
+     */
+    @Override
+    public synchronized <T> void sideOutputWithTimestamp(TupleTag<T> tupleTag, T t,
+        Instant instant) {
+      outputs.put(tupleTag, WindowedValue.of(t, instant,
+          windowedValue.getWindows(), windowedValue.getPane()));
+    }
+
+    @Override
+    protected void clearOutput() {
+      outputs.clear();
+    }
+
+    /** Exposes the buffered entries as (tag, value) tuples, in insertion order. */
+    @Override
+    protected Iterator<Tuple2<TupleTag<?>, WindowedValue<?>>> getOutputIterator() {
+      return Iterators.transform(outputs.entries().iterator(),
+          new Function<Map.Entry<TupleTag<?>, WindowedValue<?>>,
+              Tuple2<TupleTag<?>, WindowedValue<?>>>() {
+            @Override
+            public Tuple2<TupleTag<?>, WindowedValue<?>> apply(Map.Entry<TupleTag<?>,
+                WindowedValue<?>> input) {
+              return new Tuple2<TupleTag<?>, WindowedValue<?>>(input.getKey(), input.getValue());
+            }
+          });
+    }
+
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkContextFactory.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkContextFactory.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkContextFactory.java
new file mode 100644
index 0000000..10b7369
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkContextFactory.java
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+
+package org.apache.beam.runners.spark;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.serializer.KryoSerializer;
+
+final class SparkContextFactory {
+
+ /**
+ * If the property {@code dataflow.spark.test.reuseSparkContext} is set to
+ * {@code true} then the Spark context will be reused for dataflow pipelines.
+ * This property should only be enabled for tests.
+ */
+ static final String TEST_REUSE_SPARK_CONTEXT =
+ "dataflow.spark.test.reuseSparkContext";
+ private static JavaSparkContext sparkContext;
+ private static String sparkMaster;
+
+ private SparkContextFactory() {
+ }
+
+ static synchronized JavaSparkContext getSparkContext(String master, String appName) {
+ if (Boolean.getBoolean(TEST_REUSE_SPARK_CONTEXT)) {
+ if (sparkContext == null) {
+ sparkContext = createSparkContext(master, appName);
+ sparkMaster = master;
+ } else if (!master.equals(sparkMaster)) {
+ throw new IllegalArgumentException(String.format("Cannot reuse spark context " +
+ "with different spark master URL. Existing: %s, requested: %s.",
+ sparkMaster, master));
+ }
+ return sparkContext;
+ } else {
+ return createSparkContext(master, appName);
+ }
+ }
+
+ static synchronized void stopSparkContext(JavaSparkContext context) {
+ if (!Boolean.getBoolean(TEST_REUSE_SPARK_CONTEXT)) {
+ context.stop();
+ }
+ }
+
+ private static JavaSparkContext createSparkContext(String master, String appName) {
+ SparkConf conf = new SparkConf();
+ conf.setMaster(master);
+ conf.setAppName(appName);
+ conf.set("spark.serializer", KryoSerializer.class.getCanonicalName());
+ return new JavaSparkContext(conf);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineEvaluator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineEvaluator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineEvaluator.java
new file mode 100644
index 0000000..913e5a1
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineEvaluator.java
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+
+package org.apache.beam.runners.spark;
+
+import com.google.cloud.dataflow.sdk.runners.TransformTreeNode;
+import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.values.PInput;
+import com.google.cloud.dataflow.sdk.values.POutput;
+
+/**
+ * Pipeline {@link SparkPipelineRunner.Evaluator} for Spark.
+ */
+public final class SparkPipelineEvaluator extends SparkPipelineRunner.Evaluator {
+
+  private final EvaluationContext ctxt;
+
+  /**
+   * @param ctxt context that receives the current transform and is passed to each evaluator.
+   * @param translator translator handed to the superclass and used to look up evaluators.
+   */
+  public SparkPipelineEvaluator(EvaluationContext ctxt, SparkPipelineTranslator translator) {
+    super(translator);
+    this.ctxt = ctxt;
+  }
+
+  /**
+   * Looks up the evaluator for the node's transform class and runs it with the node installed
+   * as the context's current transform. The current transform is cleared afterwards even when
+   * evaluation throws, so a failed evaluation does not leave a stale transform on the context.
+   */
+  @Override
+  protected <PT extends PTransform<? super PInput, POutput>> void doVisitTransform(TransformTreeNode
+      node) {
+    @SuppressWarnings("unchecked")
+    PT transform = (PT) node.getTransform();
+    @SuppressWarnings("unchecked")
+    Class<PT> transformClass = (Class<PT>) (Class<?>) transform.getClass();
+    @SuppressWarnings("unchecked") TransformEvaluator<PT> evaluator =
+        (TransformEvaluator<PT>) translator.translate(transformClass);
+    LOG.info("Evaluating {}", transform);
+    AppliedPTransform<PInput, POutput, PT> appliedTransform =
+        AppliedPTransform.of(node.getFullName(), node.getInput(), node.getOutput(), transform);
+    ctxt.setCurrentTransform(appliedTransform);
+    try {
+      evaluator.evaluate(transform, ctxt);
+    } finally {
+      ctxt.setCurrentTransform(null);
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
new file mode 100644
index 0000000..1a5093b
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+
+package org.apache.beam.runners.spark;
+
+import com.google.cloud.dataflow.sdk.options.ApplicationNameOptions;
+import com.google.cloud.dataflow.sdk.options.Default;
+import com.google.cloud.dataflow.sdk.options.Description;
+import com.google.cloud.dataflow.sdk.options.PipelineOptions;
+import com.google.cloud.dataflow.sdk.options.StreamingOptions;
+
+/**
+ * Pipeline options for the Spark runner: the Spark master to connect to, plus defaults for
+ * the inherited streaming flag and application name.
+ */
+public interface SparkPipelineOptions extends PipelineOptions, StreamingOptions,
+ ApplicationNameOptions {
+ /** Returns the Spark master URL; defaults to {@code local[1]}. */
+ @Description("The url of the spark master to connect to, (e.g. spark://host:port, local[4]).")
+ @Default.String("local[1]")
+ String getSparkMaster();
+
+ /** Sets the Spark master URL. */
+ void setSparkMaster(String master);
+
+ /** Redeclared to give the streaming flag a default of {@code false}. */
+ @Override
+ @Default.Boolean(false)
+ boolean isStreaming();
+
+ /** Redeclared to give the application name a default of "spark dataflow pipeline job". */
+ @Override
+ @Default.String("spark dataflow pipeline job")
+ String getAppName();
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptionsFactory.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptionsFactory.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptionsFactory.java
new file mode 100644
index 0000000..7b44ee4
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptionsFactory.java
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+
+package org.apache.beam.runners.spark;
+
+import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
+
+/** Static factory for {@link SparkPipelineOptions} instances. */
+public final class SparkPipelineOptionsFactory {
+  private SparkPipelineOptionsFactory() {
+    // not instantiable: static factory only
+  }
+
+  /** Creates a {@link SparkPipelineOptions} through {@code PipelineOptionsFactory.as}. */
+  public static SparkPipelineOptions create() {
+    Class<SparkPipelineOptions> optionsType = SparkPipelineOptions.class;
+    return PipelineOptionsFactory.as(optionsType);
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptionsRegistrar.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptionsRegistrar.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptionsRegistrar.java
new file mode 100644
index 0000000..9f7f8c1
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptionsRegistrar.java
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+
+package org.apache.beam.runners.spark;
+
+import com.google.cloud.dataflow.sdk.options.PipelineOptions;
+import com.google.cloud.dataflow.sdk.options.PipelineOptionsRegistrar;
+import com.google.common.collect.ImmutableList;
+
+/** Registrar that reports {@link SparkPipelineOptions} as the options of this runner. */
+public class SparkPipelineOptionsRegistrar implements PipelineOptionsRegistrar {
+  /** Returns an immutable single-element list containing {@code SparkPipelineOptions.class}. */
+  @Override
+  public Iterable<Class<? extends PipelineOptions>> getPipelineOptions() {
+    ImmutableList<Class<? extends PipelineOptions>> registered =
+        ImmutableList.<Class<? extends PipelineOptions>>of(SparkPipelineOptions.class);
+    return registered;
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java
new file mode 100644
index 0000000..429750d
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java
@@ -0,0 +1,252 @@
+/*
+ * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+
+package org.apache.beam.runners.spark;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.options.PipelineOptions;
+import com.google.cloud.dataflow.sdk.options.PipelineOptionsValidator;
+import com.google.cloud.dataflow.sdk.runners.PipelineRunner;
+import com.google.cloud.dataflow.sdk.runners.TransformTreeNode;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.values.PInput;
+import com.google.cloud.dataflow.sdk.values.POutput;
+import com.google.cloud.dataflow.sdk.values.PValue;
+
+import org.apache.beam.runners.spark.streaming.SparkStreamingPipelineOptions;
+import org.apache.beam.runners.spark.streaming.StreamingEvaluationContext;
+import org.apache.beam.runners.spark.streaming.StreamingTransformTranslator;
+import org.apache.beam.runners.spark.streaming.StreamingWindowPipelineDetector;
+
+import org.apache.spark.SparkException;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.streaming.Duration;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The SparkPipelineRunner translates operations defined on a pipeline to a representation
+ * executable by Spark, and then submits the job to Spark to be executed. If we wanted to run
+ * a dataflow pipeline with the default options of a single threaded spark instance in local mode,
+ * we would do the following:
+ *
+ * {@code
+ * Pipeline p = [logic for pipeline creation]
+ * EvaluationResult result = SparkPipelineRunner.create().run(p);
+ * }
+ *
+ * To create a pipeline runner to run against a different spark cluster, with a custom master url
+ * we would do the following:
+ *
+ * {@code
+ * Pipeline p = [logic for pipeline creation]
+ * SparkPipelineOptions options = SparkPipelineOptionsFactory.create();
+ * options.setSparkMaster("spark://host:port");
+ * EvaluationResult result = SparkPipelineRunner.create(options).run(p);
+ * }
+ *
+ * To create a Spark streaming pipeline runner use {@link SparkStreamingPipelineOptions}
+ */
+public final class SparkPipelineRunner extends PipelineRunner<EvaluationResult> {
+
+ private static final Logger LOG = LoggerFactory.getLogger(SparkPipelineRunner.class);
+ /**
+ * Options used in this pipeline runner.
+ */
+ private final SparkPipelineOptions mOptions;
+
+ /**
+ * Creates and returns a new SparkPipelineRunner with default options. In particular, against a
+ * spark instance running in local mode.
+ *
+ * @return A pipeline runner with default options.
+ */
+ public static SparkPipelineRunner create() {
+ SparkPipelineOptions options = SparkPipelineOptionsFactory.create();
+ return new SparkPipelineRunner(options);
+ }
+
+ /**
+ * Creates and returns a new SparkPipelineRunner with specified options.
+ *
+ * @param options The SparkPipelineOptions to use when executing the job.
+ * @return A pipeline runner that will execute with specified options.
+ */
+ public static SparkPipelineRunner create(SparkPipelineOptions options) {
+ return new SparkPipelineRunner(options);
+ }
+
+ /**
+ * Creates and returns a new SparkPipelineRunner with specified options.
+ *
+ * @param options The PipelineOptions to use when executing the job.
+ * @return A pipeline runner that will execute with specified options.
+ */
+ public static SparkPipelineRunner fromOptions(PipelineOptions options) {
+ SparkPipelineOptions sparkOptions =
+ PipelineOptionsValidator.validate(SparkPipelineOptions.class, options);
+ return new SparkPipelineRunner(sparkOptions);
+ }
+
+ /**
+ * Creates a runner that executes pipelines with the given options. Private: instances are
+ * obtained via {@link #create()}, {@link #create(SparkPipelineOptions)} or {@link #fromOptions}.
+ */
+ private SparkPipelineRunner(SparkPipelineOptions options) {
+ mOptions = options;
+ }
+
+
+ @Override
+ public EvaluationResult run(Pipeline pipeline) {
+ try {
+ // validate streaming configuration
+ if (mOptions.isStreaming() && !(mOptions instanceof SparkStreamingPipelineOptions)) {
+ throw new RuntimeException("A streaming job must be configured with " +
+ SparkStreamingPipelineOptions.class.getSimpleName() + ", found " +
+ mOptions.getClass().getSimpleName());
+ }
+ LOG.info("Executing pipeline using the SparkPipelineRunner.");
+ JavaSparkContext jsc = SparkContextFactory.getSparkContext(mOptions
+ .getSparkMaster(), mOptions.getAppName());
+
+ if (mOptions.isStreaming()) {
+ SparkPipelineTranslator translator =
+ new StreamingTransformTranslator.Translator(new TransformTranslator.Translator());
+ // if streaming - fixed window should be defined on all UNBOUNDED inputs
+ StreamingWindowPipelineDetector streamingWindowPipelineDetector =
+ new StreamingWindowPipelineDetector(translator);
+ pipeline.traverseTopologically(streamingWindowPipelineDetector);
+ if (!streamingWindowPipelineDetector.isWindowing()) {
+ throw new IllegalStateException("Spark streaming pipeline must be windowed!");
+ }
+
+ Duration batchInterval = streamingWindowPipelineDetector.getBatchDuration();
+ LOG.info("Setting Spark streaming batchInterval to {} msec", batchInterval.milliseconds());
+ EvaluationContext ctxt = createStreamingEvaluationContext(jsc, pipeline, batchInterval);
+
+ pipeline.traverseTopologically(new SparkPipelineEvaluator(ctxt, translator));
+ ctxt.computeOutputs();
+
+ LOG.info("Streaming pipeline construction complete. Starting execution..");
+ ((StreamingEvaluationContext) ctxt).getStreamingContext().start();
+
+ return ctxt;
+ } else {
+ EvaluationContext ctxt = new EvaluationContext(jsc, pipeline);
+ SparkPipelineTranslator translator = new TransformTranslator.Translator();
+ pipeline.traverseTopologically(new SparkPipelineEvaluator(ctxt, translator));
+ ctxt.computeOutputs();
+
+ LOG.info("Pipeline execution complete.");
+
+ return ctxt;
+ }
+ } catch (Exception e) {
+ // Scala doesn't declare checked exceptions in the bytecode, and the Java compiler
+ // won't let you catch something that is not declared, so we can't catch
+ // SparkException here. Instead we do an instanceof check.
+ // Then we find the cause by seeing if it's a user exception (wrapped by our
+ // SparkProcessException), or just use the SparkException cause.
+ if (e instanceof SparkException && e.getCause() != null) {
+ if (e.getCause() instanceof SparkProcessContext.SparkProcessException &&
+ e.getCause().getCause() != null) {
+ throw new RuntimeException(e.getCause().getCause());
+ } else {
+ throw new RuntimeException(e.getCause());
+ }
+ }
+ // otherwise just wrap in a RuntimeException
+ throw new RuntimeException(e);
+ }
+ }
+
+ private EvaluationContext
+ createStreamingEvaluationContext(JavaSparkContext jsc, Pipeline pipeline,
+ Duration batchDuration) {
+ SparkStreamingPipelineOptions streamingOptions = (SparkStreamingPipelineOptions) mOptions;
+ JavaStreamingContext jssc = new JavaStreamingContext(jsc, batchDuration);
+ return new StreamingEvaluationContext(jsc, pipeline, jssc, streamingOptions.getTimeout());
+ }
+
+ public abstract static class Evaluator implements Pipeline.PipelineVisitor {
+ protected static final Logger LOG = LoggerFactory.getLogger(Evaluator.class);
+
+ protected final SparkPipelineTranslator translator;
+
+ protected Evaluator(SparkPipelineTranslator translator) {
+ this.translator = translator;
+ }
+
+ // Set upon entering a composite node which can be directly mapped to a single
+ // TransformEvaluator.
+ private TransformTreeNode currentTranslatedCompositeNode;
+
+ /**
+ * If true, we're currently inside a subtree of a composite node which directly maps to a
+ * single
+ * TransformEvaluator; children nodes are ignored, and upon post-visiting the translated
+ * composite node, the associated TransformEvaluator will be visited.
+ */
+ private boolean inTranslatedCompositeNode() {
+ return currentTranslatedCompositeNode != null;
+ }
+
+ @Override
+ public void enterCompositeTransform(TransformTreeNode node) {
+ if (!inTranslatedCompositeNode() && node.getTransform() != null) {
+ @SuppressWarnings("unchecked")
+ Class<PTransform<?, ?>> transformClass =
+ (Class<PTransform<?, ?>>) node.getTransform().getClass();
+ if (translator.hasTranslation(transformClass)) {
+ LOG.info("Entering directly-translatable composite transform: '{}'", node.getFullName());
+ LOG.debug("Composite transform class: '{}'", transformClass);
+ currentTranslatedCompositeNode = node;
+ }
+ }
+ }
+
+ @Override
+ public void leaveCompositeTransform(TransformTreeNode node) {
+ // NB: We depend on enterCompositeTransform and leaveCompositeTransform providing 'node'
+ // objects for which Object.equals() returns true iff they are the same logical node
+ // within the tree.
+ if (inTranslatedCompositeNode() && node.equals(currentTranslatedCompositeNode)) {
+ LOG.info("Post-visiting directly-translatable composite transform: '{}'",
+ node.getFullName());
+ doVisitTransform(node);
+ currentTranslatedCompositeNode = null;
+ }
+ }
+
+ @Override
+ public void visitTransform(TransformTreeNode node) {
+ if (inTranslatedCompositeNode()) {
+ LOG.info("Skipping '{}'; already in composite transform.", node.getFullName());
+ return;
+ }
+ doVisitTransform(node);
+ }
+
+ protected abstract <PT extends PTransform<? super PInput, POutput>> void
+ doVisitTransform(TransformTreeNode node);
+
+ @Override
+ public void visitValue(PValue value, TransformTreeNode producer) {
+ }
+ }
+}
+
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunnerRegistrar.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunnerRegistrar.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunnerRegistrar.java
new file mode 100644
index 0000000..9a84370
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunnerRegistrar.java
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+
+package org.apache.beam.runners.spark;
+
+import com.google.cloud.dataflow.sdk.runners.PipelineRunner;
+import com.google.cloud.dataflow.sdk.runners.PipelineRunnerRegistrar;
+import com.google.common.collect.ImmutableList;
+
+public class SparkPipelineRunnerRegistrar implements PipelineRunnerRegistrar {
+ @Override
+ public Iterable<Class<? extends PipelineRunner<?>>> getPipelineRunners() {
+ return ImmutableList.<Class<? extends PipelineRunner<?>>>of(SparkPipelineRunner.class);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineTranslator.java
new file mode 100644
index 0000000..e45491a
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineTranslator.java
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+package org.apache.beam.runners.spark;
+
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+
+/**
+ * Translator to support translation between Dataflow transformations and Spark transformations.
+ */
+public interface SparkPipelineTranslator {
+
+ boolean hasTranslation(Class<? extends PTransform<?, ?>> clazz);
+
+ <PT extends PTransform<?, ?>> TransformEvaluator<PT> translate(Class<PT> clazz);
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkProcessContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkProcessContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkProcessContext.java
new file mode 100644
index 0000000..c634152
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkProcessContext.java
@@ -0,0 +1,257 @@
+/*
+ * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+
+package org.apache.beam.runners.spark;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.Map;
+
+import com.google.cloud.dataflow.sdk.coders.Coder;
+import com.google.cloud.dataflow.sdk.options.PipelineOptions;
+import com.google.cloud.dataflow.sdk.transforms.Aggregator;
+import com.google.cloud.dataflow.sdk.transforms.Combine;
+import com.google.cloud.dataflow.sdk.transforms.DoFn;
+import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow;
+import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo;
+import com.google.cloud.dataflow.sdk.util.TimerInternals;
+import com.google.cloud.dataflow.sdk.util.WindowedValue;
+import com.google.cloud.dataflow.sdk.util.WindowingInternals;
+import com.google.cloud.dataflow.sdk.util.state.StateInternals;
+import com.google.cloud.dataflow.sdk.values.PCollectionView;
+import com.google.cloud.dataflow.sdk.values.TupleTag;
+import com.google.common.collect.AbstractIterator;
+import com.google.common.collect.Iterables;
+
+import org.apache.beam.runners.spark.util.BroadcastHelper;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+abstract class SparkProcessContext<I, O, V> extends DoFn<I, O>.ProcessContext {
+
+ private static final Logger LOG = LoggerFactory.getLogger(SparkProcessContext.class);
+
+ private final DoFn<I, O> fn;
+ private final SparkRuntimeContext mRuntimeContext;
+ private final Map<TupleTag<?>, BroadcastHelper<?>> mSideInputs;
+
+ protected WindowedValue<I> windowedValue;
+
+ SparkProcessContext(DoFn<I, O> fn,
+ SparkRuntimeContext runtime,
+ Map<TupleTag<?>, BroadcastHelper<?>> sideInputs) {
+ fn.super();
+ this.fn = fn;
+ this.mRuntimeContext = runtime;
+ this.mSideInputs = sideInputs;
+ }
+
+ void setup() {
+ setupDelegateAggregators();
+ }
+
+ @Override
+ public PipelineOptions getPipelineOptions() {
+ return mRuntimeContext.getPipelineOptions();
+ }
+
+ @Override
+ public <T> T sideInput(PCollectionView<T> view) {
+ @SuppressWarnings("unchecked")
+ BroadcastHelper<Iterable<WindowedValue<?>>> broadcastHelper =
+ (BroadcastHelper<Iterable<WindowedValue<?>>>) mSideInputs.get(view.getTagInternal());
+ Iterable<WindowedValue<?>> contents = broadcastHelper.getValue();
+ return view.fromIterableInternal(contents);
+ }
+
+ @Override
+ public abstract void output(O output);
+
+ public abstract void output(WindowedValue<O> output);
+
+ @Override
+ public <T> void sideOutput(TupleTag<T> tupleTag, T t) {
+ String message = "sideOutput is an unsupported operation for doFunctions, use a " +
+ "MultiDoFunction instead.";
+ LOG.warn(message);
+ throw new UnsupportedOperationException(message);
+ }
+
+ @Override
+ public <T> void sideOutputWithTimestamp(TupleTag<T> tupleTag, T t, Instant instant) {
+ String message =
+ "sideOutputWithTimestamp is an unsupported operation for doFunctions, use a " +
+ "MultiDoFunction instead.";
+ LOG.warn(message);
+ throw new UnsupportedOperationException(message);
+ }
+
+ @Override
+ public <AI, AO> Aggregator<AI, AO> createAggregatorInternal(
+ String named,
+ Combine.CombineFn<AI, ?, AO> combineFn) {
+ return mRuntimeContext.createAggregator(named, combineFn);
+ }
+
+ @Override
+ public I element() {
+ return windowedValue.getValue();
+ }
+
+ @Override
+ public void outputWithTimestamp(O output, Instant timestamp) {
+ output(WindowedValue.of(output, timestamp,
+ windowedValue.getWindows(), windowedValue.getPane()));
+ }
+
+ @Override
+ public Instant timestamp() {
+ return windowedValue.getTimestamp();
+ }
+
+ @Override
+ public BoundedWindow window() {
+ if (!(fn instanceof DoFn.RequiresWindowAccess)) { // window access is opt-in per DoFn
+ throw new UnsupportedOperationException(
+ "window() is only available in the context of a DoFn marked as RequiresWindowAccess.");
+ }
+ return Iterables.getOnlyElement(windowedValue.getWindows());
+ }
+
+ @Override
+ public PaneInfo pane() {
+ return windowedValue.getPane();
+ }
+
+ @Override
+ public WindowingInternals<I, O> windowingInternals() {
+ return new WindowingInternals<I, O>() {
+
+ @Override
+ public Collection<? extends BoundedWindow> windows() {
+ return windowedValue.getWindows();
+ }
+
+ @Override
+ public void outputWindowedValue(O output, Instant timestamp, Collection<?
+ extends BoundedWindow> windows, PaneInfo paneInfo) {
+ output(WindowedValue.of(output, timestamp, windows, paneInfo));
+ }
+
+ @Override
+ public StateInternals stateInternals() {
+ throw new UnsupportedOperationException(
+ "WindowingInternals#stateInternals() is not yet supported.");
+ }
+
+ @Override
+ public TimerInternals timerInternals() {
+ throw new UnsupportedOperationException(
+ "WindowingInternals#timerInternals() is not yet supported.");
+ }
+
+ @Override
+ public PaneInfo pane() {
+ return windowedValue.getPane();
+ }
+
+ @Override
+ public <T> void writePCollectionViewData(TupleTag<?> tag,
+ Iterable<WindowedValue<T>> data, Coder<T> elemCoder) throws IOException {
+ throw new UnsupportedOperationException(
+ "WindowingInternals#writePCollectionViewData() is not yet supported.");
+ }
+
+ @Override
+ public <T> T sideInput(PCollectionView<T> view, BoundedWindow mainInputWindow) {
+ throw new UnsupportedOperationException(
+ "WindowingInternals#sideInput() is not yet supported.");
+ }
+ };
+ }
+
+ protected abstract void clearOutput();
+ protected abstract Iterator<V> getOutputIterator();
+
+ protected Iterable<V> getOutputIterable(final Iterator<WindowedValue<I>> iter,
+ final DoFn<I, O> doFn) {
+ return new Iterable<V>() {
+ @Override
+ public Iterator<V> iterator() {
+ return new ProcCtxtIterator(iter, doFn);
+ }
+ };
+ }
+
+ private class ProcCtxtIterator extends AbstractIterator<V> {
+
+ private final Iterator<WindowedValue<I>> inputIterator;
+ private final DoFn<I, O> doFn;
+ private Iterator<V> outputIterator;
+ private boolean calledFinish;
+
+ ProcCtxtIterator(Iterator<WindowedValue<I>> iterator, DoFn<I, O> doFn) {
+ this.inputIterator = iterator;
+ this.doFn = doFn;
+ this.outputIterator = getOutputIterator();
+ }
+
+ @Override
+ protected V computeNext() {
+ // Process each element from the (input) iterator, which produces zero, one or more
+ // output elements (of type V) in the output iterator. Note that the output
+ // collection (and iterator) is reset between each call to processElement, so the
+ // collection only holds the output values for each call to processElement, rather
+ // than for the whole partition (which would use too much memory).
+ while (true) {
+ if (outputIterator.hasNext()) {
+ return outputIterator.next();
+ } else if (inputIterator.hasNext()) {
+ clearOutput();
+ windowedValue = inputIterator.next();
+ try {
+ doFn.processElement(SparkProcessContext.this);
+ } catch (Exception e) {
+ throw new SparkProcessException(e);
+ }
+ outputIterator = getOutputIterator();
+ } else {
+ // no more input to consume, but finishBundle can produce more output
+ if (!calledFinish) {
+ clearOutput();
+ try {
+ calledFinish = true;
+ doFn.finishBundle(SparkProcessContext.this);
+ } catch (Exception e) {
+ throw new SparkProcessException(e);
+ }
+ outputIterator = getOutputIterator();
+ continue; // try to consume outputIterator from start of loop
+ }
+ return endOfData();
+ }
+ }
+ }
+ }
+
+ static class SparkProcessException extends RuntimeException {
+ SparkProcessException(Throwable t) {
+ super(t);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRuntimeContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRuntimeContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRuntimeContext.java
new file mode 100644
index 0000000..da48ad7
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRuntimeContext.java
@@ -0,0 +1,214 @@
+/*
+ * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+
+package org.apache.beam.runners.spark;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException;
+import com.google.cloud.dataflow.sdk.coders.Coder;
+import com.google.cloud.dataflow.sdk.coders.CoderRegistry;
+import com.google.cloud.dataflow.sdk.options.PipelineOptions;
+import com.google.cloud.dataflow.sdk.runners.AggregatorValues;
+import com.google.cloud.dataflow.sdk.transforms.Aggregator;
+import com.google.cloud.dataflow.sdk.transforms.Combine;
+import com.google.cloud.dataflow.sdk.transforms.Max;
+import com.google.cloud.dataflow.sdk.transforms.Min;
+import com.google.cloud.dataflow.sdk.transforms.Sum;
+import com.google.cloud.dataflow.sdk.values.TypeDescriptor;
+import com.google.common.collect.ImmutableList;
+
+import org.apache.beam.runners.spark.aggregators.AggAccumParam;
+import org.apache.beam.runners.spark.aggregators.NamedAggregators;
+
+import org.apache.spark.Accumulator;
+import org.apache.spark.api.java.JavaSparkContext;
+
+
+/**
+ * The SparkRuntimeContext allows us to define useful features on the client side before our
+ * data flow program is launched.
+ */
+public class SparkRuntimeContext implements Serializable {
+ /**
+ * An accumulator that is a map from names to aggregators.
+ */
+ private final Accumulator<NamedAggregators> accum;
+
+ private final String serializedPipelineOptions;
+
+ /**
+ * Map of names to dataflow aggregators.
+ */
+ private final Map<String, Aggregator<?, ?>> aggregators = new HashMap<>();
+ private transient CoderRegistry coderRegistry;
+
+ SparkRuntimeContext(JavaSparkContext jsc, Pipeline pipeline) {
+ this.accum = jsc.accumulator(new NamedAggregators(), new AggAccumParam());
+ this.serializedPipelineOptions = serializePipelineOptions(pipeline.getOptions());
+ }
+
+ private static String serializePipelineOptions(PipelineOptions pipelineOptions) {
+ try {
+ return new ObjectMapper().writeValueAsString(pipelineOptions);
+ } catch (JsonProcessingException e) {
+ throw new IllegalStateException("Failed to serialize the pipeline options.", e);
+ }
+ }
+
+ private static PipelineOptions deserializePipelineOptions(String serializedPipelineOptions) {
+ try {
+ return new ObjectMapper().readValue(serializedPipelineOptions, PipelineOptions.class);
+ } catch (IOException e) {
+ throw new IllegalStateException("Failed to deserialize the pipeline options.", e);
+ }
+ }
+
+ /**
+ * Retrieves corresponding value of an aggregator.
+ *
+ * @param aggregatorName Name of the aggregator to retrieve the value of.
+ * @param typeClass Type class of value to be retrieved.
+ * @param <T> Type of object to be returned.
+ * @return The value of the aggregator.
+ */
+ public <T> T getAggregatorValue(String aggregatorName, Class<T> typeClass) {
+ return accum.value().getValue(aggregatorName, typeClass);
+ }
+
+ public <T> AggregatorValues<T> getAggregatorValues(Aggregator<?, T> aggregator) {
+ @SuppressWarnings("unchecked")
+ Class<T> aggValueClass = (Class<T>) aggregator.getCombineFn().getOutputType().getRawType();
+ final T aggregatorValue = getAggregatorValue(aggregator.getName(), aggValueClass);
+ return new AggregatorValues<T>() {
+ @Override
+ public Collection<T> getValues() {
+ return ImmutableList.of(aggregatorValue);
+ }
+
+ @Override
+ public Map<String, T> getValuesAtSteps() {
+ throw new UnsupportedOperationException("getValuesAtSteps is not supported.");
+ }
+ };
+ }
+
+ public synchronized PipelineOptions getPipelineOptions() {
+ return deserializePipelineOptions(serializedPipelineOptions);
+ }
+
+ /**
+ * Creates an aggregator and associates it with the specified name.
+ *
+ * @param named Name of aggregator.
+ * @param combineFn Combine function used in aggregation.
+ * @param <IN> Type of inputs to aggregator.
+ * @param <INTER> Intermediate data type
+ * @param <OUT> Type of aggregator outputs.
+ * @return Specified aggregator
+ */
+ public synchronized <IN, INTER, OUT> Aggregator<IN, OUT> createAggregator(
+ String named,
+ Combine.CombineFn<? super IN, INTER, OUT> combineFn) {
+ @SuppressWarnings("unchecked")
+ Aggregator<IN, OUT> aggregator = (Aggregator<IN, OUT>) aggregators.get(named);
+ if (aggregator == null) {
+ @SuppressWarnings("unchecked")
+ NamedAggregators.CombineFunctionState<IN, INTER, OUT> state =
+ new NamedAggregators.CombineFunctionState<>(
+ (Combine.CombineFn<IN, INTER, OUT>) combineFn,
+ (Coder<IN>) getCoder(combineFn),
+ this);
+ accum.add(new NamedAggregators(named, state));
+ aggregator = new SparkAggregator<>(named, state);
+ aggregators.put(named, aggregator);
+ }
+ return aggregator;
+ }
+
+ public CoderRegistry getCoderRegistry() {
+ if (coderRegistry == null) {
+ coderRegistry = new CoderRegistry();
+ coderRegistry.registerStandardCoders();
+ }
+ return coderRegistry;
+ }
+
+ private Coder<?> getCoder(Combine.CombineFn<?, ?, ?> combiner) {
+ try {
+ if (combiner.getClass() == Sum.SumIntegerFn.class) {
+ return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Integer.class));
+ } else if (combiner.getClass() == Sum.SumLongFn.class) {
+ return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Long.class));
+ } else if (combiner.getClass() == Sum.SumDoubleFn.class) {
+ return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Double.class));
+ } else if (combiner.getClass() == Min.MinIntegerFn.class) {
+ return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Integer.class));
+ } else if (combiner.getClass() == Min.MinLongFn.class) {
+ return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Long.class));
+ } else if (combiner.getClass() == Min.MinDoubleFn.class) {
+ return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Double.class));
+ } else if (combiner.getClass() == Max.MaxIntegerFn.class) {
+ return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Integer.class));
+ } else if (combiner.getClass() == Max.MaxLongFn.class) {
+ return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Long.class));
+ } else if (combiner.getClass() == Max.MaxDoubleFn.class) {
+ return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Double.class));
+ } else {
+ throw new IllegalArgumentException("unsupported combiner in Aggregator: "
+ + combiner.getClass().getName());
+ }
+ } catch (CannotProvideCoderException e) {
+ throw new IllegalStateException("Could not determine default coder for combiner", e);
+ }
+ }
+
+ /**
+ * Initialize spark aggregators exactly once.
+ *
+ * @param <IN> Type of element fed in to aggregator.
+ */
+ private static class SparkAggregator<IN, OUT> implements Aggregator<IN, OUT>, Serializable {
+ private final String name;
+ private final NamedAggregators.State<IN, ?, OUT> state;
+
+ SparkAggregator(String name, NamedAggregators.State<IN, ?, OUT> state) {
+ this.name = name;
+ this.state = state;
+ }
+
+ @Override
+ public String getName() {
+ return name;
+ }
+
+ @Override
+ public void addValue(IN elem) {
+ state.update(elem);
+ }
+
+ @Override
+ public Combine.CombineFn<IN, ?, OUT> getCombineFn() {
+ return state.getCombineFn();
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/TransformEvaluator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TransformEvaluator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TransformEvaluator.java
new file mode 100644
index 0000000..8aaceeb
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TransformEvaluator.java
@@ -0,0 +1,24 @@
+/*
+ * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+
+package org.apache.beam.runners.spark;
+
+import java.io.Serializable;
+
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+
+public interface TransformEvaluator<PT extends PTransform<?, ?>> extends Serializable {
+ void evaluate(PT transform, EvaluationContext context);
+}