You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by am...@apache.org on 2016/03/15 19:48:12 UTC
[15/23] incubator-beam git commit: [BEAM-11] second iteration of
package reorganisation
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/KafkaStreamingTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/KafkaStreamingTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/KafkaStreamingTest.java
deleted file mode 100644
index 05340d6..0000000
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/KafkaStreamingTest.java
+++ /dev/null
@@ -1,139 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.streaming;
-
-import com.google.cloud.dataflow.sdk.Pipeline;
-import com.google.cloud.dataflow.sdk.coders.KvCoder;
-import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
-import com.google.cloud.dataflow.sdk.testing.DataflowAssert;
-import com.google.cloud.dataflow.sdk.transforms.DoFn;
-import com.google.cloud.dataflow.sdk.transforms.ParDo;
-import com.google.cloud.dataflow.sdk.transforms.View;
-import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows;
-import com.google.cloud.dataflow.sdk.transforms.windowing.Window;
-import com.google.cloud.dataflow.sdk.values.KV;
-import com.google.cloud.dataflow.sdk.values.PCollection;
-
-import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.ImmutableSet;
-import org.apache.beam.runners.spark.io.KafkaIO;
-import org.apache.beam.runners.spark.EvaluationResult;
-import org.apache.beam.runners.spark.SparkPipelineRunner;
-import org.apache.beam.runners.spark.streaming.utils.DataflowAssertStreaming;
-import org.apache.beam.runners.spark.streaming.utils.EmbeddedKafkaCluster;
-
-import org.apache.kafka.clients.producer.KafkaProducer;
-import org.apache.kafka.clients.producer.ProducerRecord;
-import org.apache.kafka.common.serialization.Serializer;
-import org.apache.kafka.common.serialization.StringSerializer;
-import org.joda.time.Duration;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
-import org.junit.Test;
-
-import java.io.IOException;
-import java.util.Collections;
-import java.util.Map;
-import java.util.Properties;
-import java.util.Set;
-
-import kafka.serializer.StringDecoder;
-
-/**
- * Test Kafka as input.
- */
-public class KafkaStreamingTest {
- private static final EmbeddedKafkaCluster.EmbeddedZookeeper EMBEDDED_ZOOKEEPER =
- new EmbeddedKafkaCluster.EmbeddedZookeeper(17001);
- private static final EmbeddedKafkaCluster EMBEDDED_KAFKA_CLUSTER =
- new EmbeddedKafkaCluster(EMBEDDED_ZOOKEEPER.getConnection(),
- new Properties(), Collections.singletonList(6667));
- private static final String TOPIC = "kafka_dataflow_test_topic";
- private static final Map<String, String> KAFKA_MESSAGES = ImmutableMap.of(
- "k1", "v1", "k2", "v2", "k3", "v3", "k4", "v4"
- );
- private static final Set<String> EXPECTED = ImmutableSet.of(
- "k1,v1", "k2,v2", "k3,v3", "k4,v4"
- );
- private static final long TEST_TIMEOUT_MSEC = 1000L;
-
- @BeforeClass
- public static void init() throws IOException {
- EMBEDDED_ZOOKEEPER.startup();
- EMBEDDED_KAFKA_CLUSTER.startup();
-
- // write to Kafka
- Properties producerProps = new Properties();
- producerProps.putAll(EMBEDDED_KAFKA_CLUSTER.getProps());
- producerProps.put("request.required.acks", 1);
- producerProps.put("bootstrap.servers", EMBEDDED_KAFKA_CLUSTER.getBrokerList());
- Serializer<String> stringSerializer = new StringSerializer();
- try (@SuppressWarnings("unchecked") KafkaProducer<String, String> kafkaProducer =
- new KafkaProducer(producerProps, stringSerializer, stringSerializer)) {
- for (Map.Entry<String, String> en : KAFKA_MESSAGES.entrySet()) {
- kafkaProducer.send(new ProducerRecord<>(TOPIC, en.getKey(), en.getValue()));
- }
- }
- }
-
- @Test
- public void testRun() throws Exception {
- // test read from Kafka
- SparkStreamingPipelineOptions options = SparkStreamingPipelineOptionsFactory.create();
- options.setAppName(this.getClass().getSimpleName());
- options.setRunner(SparkPipelineRunner.class);
- options.setTimeout(TEST_TIMEOUT_MSEC);// run for one interval
- Pipeline p = Pipeline.create(options);
-
- Map<String, String> kafkaParams = ImmutableMap.of(
- "metadata.broker.list", EMBEDDED_KAFKA_CLUSTER.getBrokerList(),
- "auto.offset.reset", "smallest"
- );
-
- PCollection<KV<String, String>> kafkaInput = p.apply(KafkaIO.Read.from(StringDecoder.class,
- StringDecoder.class, String.class, String.class, Collections.singleton(TOPIC),
- kafkaParams))
- .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()));
- PCollection<KV<String, String>> windowedWords = kafkaInput
- .apply(Window.<KV<String, String>>into(FixedWindows.of(Duration.standardSeconds(1))));
-
- PCollection<String> formattedKV = windowedWords.apply(ParDo.of(new FormatKVFn()));
-
- DataflowAssert.thatIterable(formattedKV.apply(View.<String>asIterable()))
- .containsInAnyOrder(EXPECTED);
-
- EvaluationResult res = SparkPipelineRunner.create(options).run(p);
- res.close();
-
- DataflowAssertStreaming.assertNoFailures(res);
- }
-
- @AfterClass
- public static void tearDown() {
- EMBEDDED_KAFKA_CLUSTER.shutdown();
- EMBEDDED_ZOOKEEPER.shutdown();
- }
-
- private static class FormatKVFn extends DoFn<KV<String, String>, String> {
- @Override
- public void processElement(ProcessContext c) {
- c.output(c.element().getKey() + "," + c.element().getValue());
- }
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/SimpleStreamingWordCountTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/SimpleStreamingWordCountTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/SimpleStreamingWordCountTest.java
deleted file mode 100644
index 16b145a..0000000
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/SimpleStreamingWordCountTest.java
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.streaming;
-
-import com.google.cloud.dataflow.sdk.Pipeline;
-import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
-import com.google.cloud.dataflow.sdk.testing.DataflowAssert;
-import com.google.cloud.dataflow.sdk.transforms.View;
-import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows;
-import com.google.cloud.dataflow.sdk.transforms.windowing.Window;
-import com.google.cloud.dataflow.sdk.values.PCollection;
-import com.google.common.collect.ImmutableSet;
-
-import org.apache.beam.runners.spark.io.CreateStream;
-import org.apache.beam.runners.spark.EvaluationResult;
-import org.apache.beam.runners.spark.SimpleWordCountTest;
-import org.apache.beam.runners.spark.SparkPipelineRunner;
-import org.apache.beam.runners.spark.streaming.utils.DataflowAssertStreaming;
-
-import org.joda.time.Duration;
-import org.junit.Test;
-
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-import java.util.Set;
-
-public class SimpleStreamingWordCountTest {
-
- private static final String[] WORDS_ARRAY = {
- "hi there", "hi", "hi sue bob", "hi sue", "", "bob hi"};
- private static final List<Iterable<String>> WORDS_QUEUE =
- Collections.<Iterable<String>>singletonList(Arrays.asList(WORDS_ARRAY));
- private static final Set<String> EXPECTED_COUNT_SET =
- ImmutableSet.of("hi: 5", "there: 1", "sue: 2", "bob: 2");
- private static final long TEST_TIMEOUT_MSEC = 1000L;
-
- @Test
- public void testRun() throws Exception {
- SparkStreamingPipelineOptions options = SparkStreamingPipelineOptionsFactory.create();
- options.setAppName(this.getClass().getSimpleName());
- options.setRunner(SparkPipelineRunner.class);
- options.setTimeout(TEST_TIMEOUT_MSEC);// run for one interval
- Pipeline p = Pipeline.create(options);
-
- PCollection<String> inputWords =
- p.apply(CreateStream.fromQueue(WORDS_QUEUE)).setCoder(StringUtf8Coder.of());
- PCollection<String> windowedWords = inputWords
- .apply(Window.<String>into(FixedWindows.of(Duration.standardSeconds(1))));
-
- PCollection<String> output = windowedWords.apply(new SimpleWordCountTest.CountWords());
-
- DataflowAssert.thatIterable(output.apply(View.<String>asIterable()))
- .containsInAnyOrder(EXPECTED_COUNT_SET);
-
- EvaluationResult res = SparkPipelineRunner.create(options).run(p);
- res.close();
-
- DataflowAssertStreaming.assertNoFailures(res);
- }
-}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/utils/DataflowAssertStreaming.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/utils/DataflowAssertStreaming.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/utils/DataflowAssertStreaming.java
deleted file mode 100644
index 367a062..0000000
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/utils/DataflowAssertStreaming.java
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.streaming.utils;
-
-import org.apache.beam.runners.spark.EvaluationResult;
-
-import org.junit.Assert;
-
-/**
- * Since DataflowAssert doesn't propagate assert exceptions, use Aggregators to assert streaming
- * success/failure counters.
- */
-public final class DataflowAssertStreaming {
- /**
- * Copied aggregator names from {@link com.google.cloud.dataflow.sdk.testing.DataflowAssert}
- */
- static final String SUCCESS_COUNTER = "DataflowAssertSuccess";
- static final String FAILURE_COUNTER = "DataflowAssertFailure";
-
- private DataflowAssertStreaming() {
- }
-
- public static void assertNoFailures(EvaluationResult res) {
- int failures = res.getAggregatorValue(FAILURE_COUNTER, Integer.class);
- Assert.assertEquals("Found " + failures + " failures, see the log for details", 0, failures);
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/utils/EmbeddedKafkaCluster.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/utils/EmbeddedKafkaCluster.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/utils/EmbeddedKafkaCluster.java
deleted file mode 100644
index 8273684..0000000
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/utils/EmbeddedKafkaCluster.java
+++ /dev/null
@@ -1,317 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.streaming.utils;
-
-import org.apache.zookeeper.server.NIOServerCnxnFactory;
-import org.apache.zookeeper.server.ServerCnxnFactory;
-import org.apache.zookeeper.server.ZooKeeperServer;
-
-import java.io.File;
-import java.io.FileNotFoundException;
-import java.io.IOException;
-import java.net.InetSocketAddress;
-import java.net.ServerSocket;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.Properties;
-import java.util.Random;
-
-import kafka.server.KafkaConfig;
-import kafka.server.KafkaServer;
-import kafka.utils.Time;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/**
- * https://gist.github.com/fjavieralba/7930018
- */
-public class EmbeddedKafkaCluster {
-
- private static final Logger LOG = LoggerFactory.getLogger(EmbeddedKafkaCluster.class);
-
- private final List<Integer> ports;
- private final String zkConnection;
- private final Properties baseProperties;
-
- private final String brokerList;
-
- private final List<KafkaServer> brokers;
- private final List<File> logDirs;
-
- public EmbeddedKafkaCluster(String zkConnection) {
- this(zkConnection, new Properties());
- }
-
- public EmbeddedKafkaCluster(String zkConnection, Properties baseProperties) {
- this(zkConnection, baseProperties, Collections.singletonList(-1));
- }
-
- public EmbeddedKafkaCluster(String zkConnection, Properties baseProperties, List<Integer> ports) {
- this.zkConnection = zkConnection;
- this.ports = resolvePorts(ports);
- this.baseProperties = baseProperties;
-
- this.brokers = new ArrayList<>();
- this.logDirs = new ArrayList<>();
-
- this.brokerList = constructBrokerList(this.ports);
- }
-
- private static List<Integer> resolvePorts(List<Integer> ports) {
- List<Integer> resolvedPorts = new ArrayList<>();
- for (Integer port : ports) {
- resolvedPorts.add(resolvePort(port));
- }
- return resolvedPorts;
- }
-
- private static int resolvePort(int port) {
- if (port == -1) {
- return TestUtils.getAvailablePort();
- }
- return port;
- }
-
- private static String constructBrokerList(List<Integer> ports) {
- StringBuilder sb = new StringBuilder();
- for (Integer port : ports) {
- if (sb.length() > 0) {
- sb.append(",");
- }
- sb.append("localhost:").append(port);
- }
- return sb.toString();
- }
-
- public void startup() {
- for (int i = 0; i < ports.size(); i++) {
- Integer port = ports.get(i);
- File logDir = TestUtils.constructTempDir("kafka-local");
-
- Properties properties = new Properties();
- properties.putAll(baseProperties);
- properties.setProperty("zookeeper.connect", zkConnection);
- properties.setProperty("broker.id", String.valueOf(i + 1));
- properties.setProperty("host.name", "localhost");
- properties.setProperty("port", Integer.toString(port));
- properties.setProperty("log.dir", logDir.getAbsolutePath());
- properties.setProperty("log.flush.interval.messages", String.valueOf(1));
-
- KafkaServer broker = startBroker(properties);
-
- brokers.add(broker);
- logDirs.add(logDir);
- }
- }
-
-
- private static KafkaServer startBroker(Properties props) {
- KafkaServer server = new KafkaServer(new KafkaConfig(props), new SystemTime());
- server.startup();
- return server;
- }
-
- public Properties getProps() {
- Properties props = new Properties();
- props.putAll(baseProperties);
- props.put("metadata.broker.list", brokerList);
- props.put("zookeeper.connect", zkConnection);
- return props;
- }
-
- public String getBrokerList() {
- return brokerList;
- }
-
- public List<Integer> getPorts() {
- return ports;
- }
-
- public String getZkConnection() {
- return zkConnection;
- }
-
- public void shutdown() {
- for (KafkaServer broker : brokers) {
- try {
- broker.shutdown();
- } catch (Exception e) {
- LOG.warn("{}", e.getMessage(), e);
- }
- }
- for (File logDir : logDirs) {
- try {
- TestUtils.deleteFile(logDir);
- } catch (FileNotFoundException e) {
- LOG.warn("{}", e.getMessage(), e);
- }
- }
- }
-
- @Override
- public String toString() {
- return "EmbeddedKafkaCluster{" + "brokerList='" + brokerList + "'}";
- }
-
- public static class EmbeddedZookeeper {
- private int port = -1;
- private int tickTime = 500;
-
- private ServerCnxnFactory factory;
- private File snapshotDir;
- private File logDir;
-
- public EmbeddedZookeeper() {
- this(-1);
- }
-
- public EmbeddedZookeeper(int port) {
- this(port, 500);
- }
-
- public EmbeddedZookeeper(int port, int tickTime) {
- this.port = resolvePort(port);
- this.tickTime = tickTime;
- }
-
- private static int resolvePort(int port) {
- if (port == -1) {
- return TestUtils.getAvailablePort();
- }
- return port;
- }
-
- public void startup() throws IOException {
- if (this.port == -1) {
- this.port = TestUtils.getAvailablePort();
- }
- this.factory = NIOServerCnxnFactory.createFactory(new InetSocketAddress("localhost", port),
- 1024);
- this.snapshotDir = TestUtils.constructTempDir("embedded-zk/snapshot");
- this.logDir = TestUtils.constructTempDir("embedded-zk/log");
-
- try {
- factory.startup(new ZooKeeperServer(snapshotDir, logDir, tickTime));
- } catch (InterruptedException e) {
- throw new IOException(e);
- }
- }
-
-
- public void shutdown() {
- factory.shutdown();
- try {
- TestUtils.deleteFile(snapshotDir);
- } catch (FileNotFoundException e) {
- // ignore
- }
- try {
- TestUtils.deleteFile(logDir);
- } catch (FileNotFoundException e) {
- // ignore
- }
- }
-
- public String getConnection() {
- return "localhost:" + port;
- }
-
- public void setPort(int port) {
- this.port = port;
- }
-
- public void setTickTime(int tickTime) {
- this.tickTime = tickTime;
- }
-
- public int getPort() {
- return port;
- }
-
- public int getTickTime() {
- return tickTime;
- }
-
- @Override
- public String toString() {
- return "EmbeddedZookeeper{" + "connection=" + getConnection() + "}";
- }
- }
-
- static class SystemTime implements Time {
- @Override
- public long milliseconds() {
- return System.currentTimeMillis();
- }
-
- @Override
- public long nanoseconds() {
- return System.nanoTime();
- }
-
- @Override
- public void sleep(long ms) {
- try {
- Thread.sleep(ms);
- } catch (InterruptedException e) {
- // Ignore
- }
- }
- }
-
- static final class TestUtils {
- private static final Random RANDOM = new Random();
-
- private TestUtils() {
- }
-
- static File constructTempDir(String dirPrefix) {
- File file = new File(System.getProperty("java.io.tmpdir"), dirPrefix + RANDOM.nextInt
- (10000000));
- if (!file.mkdirs()) {
- throw new RuntimeException("could not create temp directory: " + file.getAbsolutePath());
- }
- file.deleteOnExit();
- return file;
- }
-
- static int getAvailablePort() {
- try {
- try (ServerSocket socket = new ServerSocket(0)) {
- return socket.getLocalPort();
- }
- } catch (IOException e) {
- throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e);
- }
- }
-
- static boolean deleteFile(File path) throws FileNotFoundException {
- if (!path.exists()) {
- throw new FileNotFoundException(path.getAbsolutePath());
- }
- boolean ret = true;
- if (path.isDirectory()) {
- for (File f : path.listFiles()) {
- ret = ret && deleteFile(f);
- }
- }
- return ret && path.delete();
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/CombineGloballyTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/CombineGloballyTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/CombineGloballyTest.java
new file mode 100644
index 0000000..6945d68
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/CombineGloballyTest.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.translation;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
+import com.google.cloud.dataflow.sdk.transforms.Combine;
+import com.google.cloud.dataflow.sdk.transforms.Create;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.common.collect.Iterables;
+import org.apache.beam.runners.spark.EvaluationResult;
+import org.apache.beam.runners.spark.SparkPipelineOptions;
+import org.apache.beam.runners.spark.SparkPipelineRunner;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+public class CombineGloballyTest {
+
+ private static final String[] WORDS_ARRAY = {
+ "hi there", "hi", "hi sue bob",
+ "hi sue", "", "bob hi"};
+ private static final List<String> WORDS = Arrays.asList(WORDS_ARRAY);
+
+ @Test
+ public void test() throws Exception {
+ SparkPipelineOptions options = SparkPipelineOptionsFactory.create();
+ Pipeline p = Pipeline.create(options);
+ PCollection<String> inputWords = p.apply(Create.of(WORDS)).setCoder(StringUtf8Coder.of());
+ PCollection<String> output = inputWords.apply(Combine.globally(new WordMerger()));
+
+ EvaluationResult res = SparkPipelineRunner.create().run(p);
+ assertEquals("hi there,hi,hi sue bob,hi sue,,bob hi", Iterables.getOnlyElement(res.get(output)));
+ res.close();
+ }
+
+ public static class WordMerger extends Combine.CombineFn<String, StringBuilder, String> {
+
+ @Override
+ public StringBuilder createAccumulator() {
+ // return null to differentiate from an empty string
+ return null;
+ }
+
+ @Override
+ public StringBuilder addInput(StringBuilder accumulator, String input) {
+ return combine(accumulator, input);
+ }
+
+ @Override
+ public StringBuilder mergeAccumulators(Iterable<StringBuilder> accumulators) {
+ StringBuilder sb = new StringBuilder();
+ for (StringBuilder accum : accumulators) {
+ if (accum != null) {
+ sb.append(accum);
+ }
+ }
+ return sb;
+ }
+
+ @Override
+ public String extractOutput(StringBuilder accumulator) {
+ return accumulator != null ? accumulator.toString(): "";
+ }
+
+ private static StringBuilder combine(StringBuilder accum, String datum) {
+ if (accum == null) {
+ return new StringBuilder(datum);
+ } else {
+ accum.append(",").append(datum);
+ return accum;
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/CombinePerKeyTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/CombinePerKeyTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/CombinePerKeyTest.java
new file mode 100644
index 0000000..0373968
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/CombinePerKeyTest.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.translation;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.coders.KvCoder;
+import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
+import com.google.cloud.dataflow.sdk.coders.VarLongCoder;
+import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
+import com.google.cloud.dataflow.sdk.transforms.*;
+import com.google.cloud.dataflow.sdk.values.KV;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.common.collect.ImmutableList;
+import org.apache.beam.runners.spark.EvaluationResult;
+import org.apache.beam.runners.spark.SparkPipelineRunner;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class CombinePerKeyTest {
+
+ private static final List<String> WORDS =
+ ImmutableList.of("the", "quick", "brown", "fox", "jumped", "over", "the", "lazy", "dog");
+ @Test
+ public void testRun() {
+ Pipeline p = Pipeline.create(PipelineOptionsFactory.create());
+ PCollection<String> inputWords = p.apply(Create.of(WORDS)).setCoder(StringUtf8Coder.of());
+ PCollection<KV<String, Long>> cnts = inputWords.apply(new SumPerKey<String>());
+ EvaluationResult res = SparkPipelineRunner.create().run(p);
+ Map<String, Long> actualCnts = new HashMap<>();
+ for (KV<String, Long> kv : res.get(cnts)) {
+ actualCnts.put(kv.getKey(), kv.getValue());
+ }
+ res.close();
+ Assert.assertEquals(8, actualCnts.size());
+ Assert.assertEquals(Long.valueOf(2L), actualCnts.get("the"));
+ }
+
+ private static class SumPerKey<T> extends PTransform<PCollection<T>, PCollection<KV<T, Long>>> {
+ @Override
+ public PCollection<KV<T, Long>> apply(PCollection<T> pcol) {
+ PCollection<KV<T, Long>> withLongs = pcol.apply(ParDo.of(new DoFn<T, KV<T, Long>>() {
+ @Override
+ public void processElement(ProcessContext processContext) throws Exception {
+ processContext.output(KV.of(processContext.element(), 1L));
+ }
+ })).setCoder(KvCoder.of(pcol.getCoder(), VarLongCoder.of()));
+ return withLongs.apply(Sum.<T>longsPerKey());
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/DoFnOutputTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/DoFnOutputTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/DoFnOutputTest.java
new file mode 100644
index 0000000..a9779e6
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/DoFnOutputTest.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.translation;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.testing.DataflowAssert;
+import com.google.cloud.dataflow.sdk.transforms.Create;
+import com.google.cloud.dataflow.sdk.transforms.DoFn;
+import com.google.cloud.dataflow.sdk.transforms.ParDo;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import org.apache.beam.runners.spark.EvaluationResult;
+import org.apache.beam.runners.spark.SparkPipelineOptions;
+import org.apache.beam.runners.spark.SparkPipelineRunner;
+import org.junit.Test;
+
+import java.io.Serializable;
+
+public class DoFnOutputTest implements Serializable {
+ @Test
+ public void test() throws Exception {
+ SparkPipelineOptions options = SparkPipelineOptionsFactory.create();
+ options.setRunner(SparkPipelineRunner.class);
+ Pipeline pipeline = Pipeline.create(options);
+
+ PCollection<String> strings = pipeline.apply(Create.of("a"));
+ // Test that values written from startBundle() and finishBundle() are written to
+ // the output
+ PCollection<String> output = strings.apply(ParDo.of(new DoFn<String, String>() {
+ @Override
+ public void startBundle(Context c) throws Exception {
+ c.output("start");
+ }
+ @Override
+ public void processElement(ProcessContext c) throws Exception {
+ c.output(c.element());
+ }
+ @Override
+ public void finishBundle(Context c) throws Exception {
+ c.output("finish");
+ }
+ }));
+
+ DataflowAssert.that(output).containsInAnyOrder("start", "a", "finish");
+
+ EvaluationResult res = SparkPipelineRunner.create().run(pipeline);
+ res.close();
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/MultiOutputWordCountTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/MultiOutputWordCountTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/MultiOutputWordCountTest.java
new file mode 100644
index 0000000..8ab3798
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/MultiOutputWordCountTest.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.translation;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
+import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
+import com.google.cloud.dataflow.sdk.runners.AggregatorValues;
+import com.google.cloud.dataflow.sdk.transforms.*;
+import com.google.cloud.dataflow.sdk.values.*;
+import com.google.common.collect.Iterables;
+import org.apache.beam.runners.spark.EvaluationResult;
+import org.apache.beam.runners.spark.SparkPipelineRunner;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class MultiOutputWordCountTest {
+ // Tags for the two ParDo outputs (main: lower, side: upper) and for the per-word counts derived from them.
+ private static final TupleTag<String> upper = new TupleTag<>();
+ private static final TupleTag<String> lower = new TupleTag<>();
+ private static final TupleTag<KV<String, Long>> lowerCnts = new TupleTag<>();
+ private static final TupleTag<KV<String, Long>> upperCnts = new TupleTag<>();
+
+ @Test
+ public void testRun() throws Exception {
+ Pipeline p = Pipeline.create(PipelineOptionsFactory.create());
+ PCollection<String> regex = p.apply(Create.of("[^a-zA-Z']+"));
+ PCollection<String> w1 = p.apply(Create.of("Here are some words to count", "and some others"));
+ PCollection<String> w2 = p.apply(Create.of("Here are some more words", "and even more words"));
+ PCollectionList<String> list = PCollectionList.of(w1).and(w2);
+ // Merge both inputs, and expose the split regex to the DoFn as a singleton side input.
+ PCollection<String> union = list.apply(Flatten.<String>pCollections());
+ PCollectionView<String> regexView = regex.apply(View.<String>asSingleton());
+ CountWords countWords = new CountWords(regexView);
+ PCollectionTuple luc = union.apply(countWords);
+ PCollection<Long> unique = luc.get(lowerCnts).apply(
+ ApproximateUnique.<KV<String, Long>>globally(16));
+ // Run on Spark, then read the materialized outputs and aggregator values back from the result.
+ EvaluationResult res = SparkPipelineRunner.create().run(p);
+ Iterable<KV<String, Long>> actualLower = res.get(luc.get(lowerCnts));
+ Assert.assertEquals("are", actualLower.iterator().next().getKey()); // NOTE(review): relies on iteration order - confirm it is deterministic
+ Iterable<KV<String, Long>> actualUpper = res.get(luc.get(upperCnts));
+ Assert.assertEquals("Here", actualUpper.iterator().next().getKey());
+ Iterable<Long> actualUniqCount = res.get(unique);
+ Assert.assertEquals(9, (long) actualUniqCount.iterator().next()); // nine distinct lower-case words in the input
+ int actualTotalWords = res.getAggregatorValue("totalWords", Integer.class);
+ Assert.assertEquals(18, actualTotalWords); // 18 split tokens across the four input lines
+ int actualMaxWordLength = res.getAggregatorValue("maxWordLength", Integer.class);
+ Assert.assertEquals(6, actualMaxWordLength); // longest word is "others"
+ AggregatorValues<Integer> aggregatorValues = res.getAggregatorValues(countWords
+ .getTotalWordsAggregator());
+ Assert.assertEquals(18, Iterables.getOnlyElement(aggregatorValues.getValues()).intValue());
+
+ res.close();
+ }
+
+ /**
+ * Splits lines on the side-input regex, routing words to the main or {@code upper} output and updating two aggregators.
+ */
+ static class ExtractWordsFn extends DoFn<String, String> {
+ // Aggregators: number of split tokens (empty ones included) and the longest word length seen.
+ private final Aggregator<Integer, Integer> totalWords = createAggregator("totalWords",
+ new Sum.SumIntegerFn());
+ private final Aggregator<Integer, Integer> maxWordLength = createAggregator("maxWordLength",
+ new Max.MaxIntegerFn());
+ private final PCollectionView<String> regex;
+
+ ExtractWordsFn(PCollectionView<String> regex) {
+ this.regex = regex;
+ }
+
+ @Override
+ public void processElement(ProcessContext c) {
+ String[] words = c.element().split(c.sideInput(regex)); // the delimiter regex arrives as a side input
+ for (String word : words) {
+ totalWords.addValue(1);
+ if (!word.isEmpty()) {
+ maxWordLength.addValue(word.length());
+ if (Character.isLowerCase(word.charAt(0))) {
+ c.output(word);
+ } else {
+ c.sideOutput(upper, word);
+ }
+ }
+ }
+ }
+ }
+ /** Composite transform: applies ExtractWordsFn and counts the words of its main and side outputs separately. */
+ public static class CountWords extends PTransform<PCollection<String>, PCollectionTuple> {
+
+ private final PCollectionView<String> regex;
+ private final ExtractWordsFn extractWordsFn;
+
+ public CountWords(PCollectionView<String> regex) {
+ this.regex = regex;
+ this.extractWordsFn = new ExtractWordsFn(regex);
+ }
+
+ @Override
+ public PCollectionTuple apply(PCollection<String> lines) {
+ // Split lines into words; lower is the main output tag and upper the side output tag.
+ PCollectionTuple lowerUpper = lines
+ .apply(ParDo.of(extractWordsFn)
+ .withSideInputs(regex)
+ .withOutputTags(lower, TupleTagList.of(upper)));
+ lowerUpper.get(lower).setCoder(StringUtf8Coder.of());
+ lowerUpper.get(upper).setCoder(StringUtf8Coder.of());
+ PCollection<KV<String, Long>> lowerCounts = lowerUpper.get(lower).apply(Count
+ .<String>perElement());
+ PCollection<KV<String, Long>> upperCounts = lowerUpper.get(upper).apply(Count
+ .<String>perElement());
+ return PCollectionTuple
+ .of(lowerCnts, lowerCounts)
+ .and(upperCnts, upperCounts);
+ }
+ /** Exposes the DoFn's "totalWords" aggregator so the test can query its values from the result. */
+ Aggregator<Integer, Integer> getTotalWordsAggregator() {
+ return extractWordsFn.totalWords;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SerializationTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SerializationTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SerializationTest.java
new file mode 100644
index 0000000..b378795
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SerializationTest.java
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.translation;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.coders.AtomicCoder;
+import com.google.cloud.dataflow.sdk.coders.Coder;
+import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
+import com.google.cloud.dataflow.sdk.testing.DataflowAssert;
+import com.google.cloud.dataflow.sdk.transforms.*;
+import com.google.cloud.dataflow.sdk.values.KV;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.common.base.Function;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Lists;
+import org.apache.beam.runners.spark.EvaluationResult;
+import org.apache.beam.runners.spark.SparkPipelineOptions;
+import org.apache.beam.runners.spark.SparkPipelineRunner;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Set;
+import java.util.regex.Pattern;
+
+public class SerializationTest {
+ /** Plain String wrapper that does not implement Serializable; the pipeline below supplies a coder for it. */
+ public static class StringHolder { // not serializable
+ private final String string;
+
+ public StringHolder(String string) {
+ this.string = string;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ StringHolder that = (StringHolder) o;
+ return string.equals(that.string);
+ }
+
+ @Override
+ public int hashCode() {
+ return string.hashCode();
+ }
+
+ @Override
+ public String toString() {
+ return string;
+ }
+ }
+ /** Coder for StringHolder that delegates encoding and decoding of the wrapped string to StringUtf8Coder. */
+ public static class StringHolderUtf8Coder extends AtomicCoder<StringHolder> {
+
+ private final StringUtf8Coder stringUtf8Coder = StringUtf8Coder.of();
+
+ @Override
+ public void encode(StringHolder value, OutputStream outStream, Context context) throws IOException {
+ stringUtf8Coder.encode(value.toString(), outStream, context);
+ }
+
+ @Override
+ public StringHolder decode(InputStream inStream, Context context) throws IOException {
+ return new StringHolder(stringUtf8Coder.decode(inStream, context));
+ }
+
+ public static Coder<StringHolder> of() {
+ return new StringHolderUtf8Coder();
+ }
+ }
+ // Input lines (each wrapped in a StringHolder) and the formatted word counts expected for them.
+ private static final String[] WORDS_ARRAY = {
+ "hi there", "hi", "hi sue bob",
+ "hi sue", "", "bob hi"};
+ private static final List<StringHolder> WORDS = Lists.transform(
+ Arrays.asList(WORDS_ARRAY), new Function<String, StringHolder>() {
+ @Override public StringHolder apply(String s) {
+ return new StringHolder(s);
+ }
+ });
+ private static final Set<StringHolder> EXPECTED_COUNT_SET =
+ ImmutableSet.copyOf(Lists.transform(
+ Arrays.asList("hi: 5", "there: 1", "sue: 2", "bob: 2"),
+ new Function<String, StringHolder>() {
+ @Override
+ public StringHolder apply(String s) {
+ return new StringHolder(s);
+ }
+ }));
+
+ @Test
+ public void testRun() throws Exception {
+ SparkPipelineOptions options = SparkPipelineOptionsFactory.create();
+ options.setRunner(SparkPipelineRunner.class);
+ Pipeline p = Pipeline.create(options);
+ PCollection<StringHolder> inputWords =
+ p.apply(Create.of(WORDS).withCoder(StringHolderUtf8Coder.of()));
+ PCollection<StringHolder> output = inputWords.apply(new CountWords());
+ // Compare the formatted counts with the expected set, ignoring order.
+ DataflowAssert.that(output).containsInAnyOrder(EXPECTED_COUNT_SET);
+ // Run on the Spark runner and close the evaluation result.
+ EvaluationResult res = SparkPipelineRunner.create().run(p);
+ res.close();
+ }
+
+ /**
+ * A DoFn that splits each StringHolder line into words and emits every non-empty word as a new StringHolder.
+ */
+ static class ExtractWordsFn extends DoFn<StringHolder, StringHolder> {
+ private static final Pattern WORD_BOUNDARY = Pattern.compile("[^a-zA-Z']+");
+ private final Aggregator<Long, Long> emptyLines =
+ createAggregator("emptyLines", new Sum.SumLongFn());
+
+ @Override
+ public void processElement(ProcessContext c) {
+ // Split the line into words.
+ String[] words = WORD_BOUNDARY.split(c.element().toString());
+
+ // Keep track of the number of lines without any words encountered while tokenizing.
+ // This aggregator is visible in the monitoring UI when run using DataflowPipelineRunner.
+ if (words.length == 0) {
+ emptyLines.addValue(1L);
+ }
+
+ // Output each word encountered into the output PCollection.
+ for (String word : words) {
+ if (!word.isEmpty()) {
+ c.output(new StringHolder(word));
+ }
+ }
+ }
+ }
+
+ /**
+ * A DoFn that formats each word/count pair as "word: count", wrapped in a StringHolder.
+ */
+ private static class FormatCountsFn extends DoFn<KV<StringHolder, Long>, StringHolder> {
+ @Override
+ public void processElement(ProcessContext c) {
+ c.output(new StringHolder(c.element().getKey() + ": " + c.element().getValue()));
+ }
+ }
+ /** Composite transform: extracts words, counts occurrences per word, and formats each count. */
+ private static class CountWords extends PTransform<PCollection<StringHolder>, PCollection<StringHolder>> {
+ @Override
+ public PCollection<StringHolder> apply(PCollection<StringHolder> lines) {
+
+ // Convert lines of text into individual words.
+ PCollection<StringHolder> words = lines.apply(
+ ParDo.of(new ExtractWordsFn()));
+
+ // Count the number of times each word occurs.
+ PCollection<KV<StringHolder, Long>> wordCounts =
+ words.apply(Count.<StringHolder>perElement());
+
+ // Format each word and count into a printable string.
+
+ return wordCounts.apply(ParDo.of(new FormatCountsFn()));
+ }
+
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SideEffectsTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SideEffectsTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SideEffectsTest.java
new file mode 100644
index 0000000..fc14fc7
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SideEffectsTest.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.translation;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.coders.StringDelegateCoder;
+import com.google.cloud.dataflow.sdk.transforms.Create;
+import com.google.cloud.dataflow.sdk.transforms.DoFn;
+import com.google.cloud.dataflow.sdk.transforms.ParDo;
+import org.apache.beam.runners.spark.SparkPipelineOptions;
+import org.apache.beam.runners.spark.SparkPipelineRunner;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.Serializable;
+import java.net.URI;
+
+import static org.junit.Assert.*;
+
+public class SideEffectsTest implements Serializable {
+ /** Marker exception thrown by the user DoFn; the test checks that it surfaces as the cause of the failure. */
+ static class UserException extends RuntimeException {
+ }
+
+ @Test
+ public void test() throws Exception {
+ SparkPipelineOptions options = SparkPipelineOptionsFactory.create();
+ options.setRunner(SparkPipelineRunner.class);
+ Pipeline pipeline = Pipeline.create(options);
+
+ pipeline.getCoderRegistry().registerCoder(URI.class, StringDelegateCoder.of(URI.class));
+ // A DoFn that always fails, so running the pipeline must raise an exception.
+ pipeline.apply(Create.of("a")).apply(ParDo.of(new DoFn<String, String>() {
+ @Override
+ public void processElement(ProcessContext c) throws Exception {
+ throw new UserException();
+ }
+ }));
+
+ try {
+ pipeline.run();
+ fail("Run should have thrown an exception");
+ } catch (RuntimeException e) {
+ assertNotNull(e.getCause());
+
+ // TODO: remove the version check (and the setup and teardown methods) when we no
+ // longer support Spark 1.3 or 1.4
+ String version = SparkContextFactory.getSparkContext(options.getSparkMaster(), options.getAppName()).version();
+ if (!version.startsWith("1.3.") && !version.startsWith("1.4.")) {
+ assertTrue(e.getCause() instanceof UserException);
+ }
+ }
+ }
+ // Enables Spark context reuse; per the TODO in test(), this is only needed for the Spark version check.
+ @Before
+ public void setup() {
+ System.setProperty(SparkContextFactory.TEST_REUSE_SPARK_CONTEXT, "true");
+ }
+
+ @After
+ public void teardown() {
+ System.setProperty(SparkContextFactory.TEST_REUSE_SPARK_CONTEXT, "false");
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TestSparkPipelineOptionsFactory.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TestSparkPipelineOptionsFactory.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TestSparkPipelineOptionsFactory.java
new file mode 100644
index 0000000..9cace83
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TestSparkPipelineOptionsFactory.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.translation;
+
+import org.apache.beam.runners.spark.SparkPipelineOptions;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestSparkPipelineOptionsFactory { // tests for SparkPipelineOptionsFactory.create()
+ @Test
+ public void testDefaultCreateMethod() { // freshly created options are expected to report "local[1]" as Spark master
+ SparkPipelineOptions actualOptions = SparkPipelineOptionsFactory.create();
+ Assert.assertEquals("local[1]", actualOptions.getSparkMaster());
+ }
+ // A Spark master set on the options must be returned unchanged by the getter.
+ @Test
+ public void testSettingCustomOptions() {
+ SparkPipelineOptions actualOptions = SparkPipelineOptionsFactory.create();
+ actualOptions.setSparkMaster("spark://207.184.161.138:7077");
+ Assert.assertEquals("spark://207.184.161.138:7077", actualOptions.getSparkMaster());
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TransformTranslatorTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TransformTranslatorTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TransformTranslatorTest.java
new file mode 100644
index 0000000..da30321
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TransformTranslatorTest.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.translation;
+
+import com.google.api.client.repackaged.com.google.common.base.Joiner;
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.io.TextIO;
+import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
+import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner;
+import com.google.cloud.dataflow.sdk.runners.PipelineRunner;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.common.base.Charsets;
+import org.apache.beam.runners.spark.SparkPipelineRunner;
+import org.apache.commons.io.FileUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TestName;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * A test for the transforms registered in TransformTranslator.
+ * Builds a regular Dataflow pipeline with each of the mapped
+ * transforms, and makes sure that they work when the pipeline is
+ * executed in Spark.
+ */
+public class TransformTranslatorTest {
+
+ @Rule
+ public TestName name = new TestName();
+ // Runners under comparison and the per-test output directory; all three are set in init().
+ private DirectPipelineRunner directRunner;
+ private SparkPipelineRunner sparkRunner;
+ private String testDataDirName;
+ /** Creates both runners and an empty output directory target/test-data/(test method name)/. */
+ @Before public void init() throws IOException {
+ sparkRunner = SparkPipelineRunner.create();
+ directRunner = DirectPipelineRunner.createForTest();
+ testDataDirName = Joiner.on(File.separator).join("target", "test-data", name.getMethodName())
+ + File.separator;
+ FileUtils.deleteDirectory(new File(testDataDirName));
+ new File(testDataDirName).mkdirs();
+ }
+
+ /**
+ * Builds a simple pipeline with TextIO.Read and TextIO.Write, runs the pipeline
+ * in DirectPipelineRunner and on SparkPipelineRunner, with the mapped dataflow-to-spark
+ * transforms. Finally it makes sure that the results are the same for both runs.
+ */
+ @Test
+ public void testTextIOReadAndWriteTransforms() throws IOException {
+ String directOut = runPipeline("direct", directRunner);
+ String sparkOut = runPipeline("spark", sparkRunner);
+
+ List<String> directOutput =
+ Files.readAllLines(Paths.get(directOut + "-00000-of-00001"), Charsets.UTF_8);
+
+ List<String> sparkOutput =
+ Files.readAllLines(Paths.get(sparkOut + "-00000-of-00001"), Charsets.UTF_8);
+
+ // sort output to get a stable result (PCollections are not ordered)
+ Collections.sort(directOutput);
+ Collections.sort(sparkOutput);
+
+ Assert.assertArrayEquals(directOutput.toArray(), sparkOutput.toArray());
+ }
+ /** Copies the test text file to a runner-specific prefix in the test directory using the given runner; returns the prefix. */
+ private String runPipeline(String name, PipelineRunner<?> runner) {
+ Pipeline p = Pipeline.create(PipelineOptionsFactory.create());
+ String outFile = testDataDirName + "test_text_out_" + name; // testDataDirName already ends with a separator
+ PCollection<String> lines = p.apply(TextIO.Read.from("src/test/resources/test_text.txt"));
+ lines.apply(TextIO.Write.to(outFile));
+ runner.run(p);
+ return outFile;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/WindowedWordCountTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/WindowedWordCountTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/WindowedWordCountTest.java
new file mode 100644
index 0000000..9f29a37
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/WindowedWordCountTest.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.translation;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
+import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
+import com.google.cloud.dataflow.sdk.testing.DataflowAssert;
+import com.google.cloud.dataflow.sdk.transforms.Create;
+import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows;
+import com.google.cloud.dataflow.sdk.transforms.windowing.Window;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.common.collect.ImmutableList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.beam.runners.spark.EvaluationResult;
+import org.apache.beam.runners.spark.SimpleWordCountTest;
+import org.apache.beam.runners.spark.SparkPipelineOptions;
+import org.apache.beam.runners.spark.SparkPipelineRunner;
+import org.joda.time.Duration;
+import org.junit.Test;
+
+public class WindowedWordCountTest { // word count over one-minute fixed windows on the Spark runner
+ private static final String[] WORDS_ARRAY = {
+ "hi there", "hi", "hi sue bob",
+ "hi sue", "", "bob hi"};
+ private static final Long[] TIMESTAMPS_ARRAY = {
+ 60000L, 60000L, 60000L,
+ 120000L, 120000L, 120000L};
+ private static final List<String> WORDS = Arrays.asList(WORDS_ARRAY);
+ private static final List<Long> TIMESTAMPS = Arrays.asList(TIMESTAMPS_ARRAY);
+ private static final List<String> EXPECTED_COUNT_SET =
+ ImmutableList.of("hi: 3", "there: 1", "sue: 1", "bob: 1",
+ "hi: 2", "sue: 1", "bob: 1");
+ // The first three lines are stamped 60000 and the last three 120000, so each group is counted in its own window.
+ @Test
+ public void testRun() throws Exception {
+ SparkPipelineOptions options = SparkPipelineOptionsFactory.create();
+ options.setRunner(SparkPipelineRunner.class);
+ Pipeline p = Pipeline.create(options); // build the pipeline from the Spark options configured above
+ PCollection<String> inputWords = p.apply(Create.timestamped(WORDS, TIMESTAMPS))
+ .setCoder(StringUtf8Coder.of());
+ PCollection<String> windowedWords = inputWords
+ .apply(Window.<String>into(FixedWindows.of(Duration.standardMinutes(1))));
+ // Reuse the word-count transform defined in SimpleWordCountTest on the windowed input.
+ PCollection<String> output = windowedWords.apply(new SimpleWordCountTest.CountWords());
+
+ DataflowAssert.that(output).containsInAnyOrder(EXPECTED_COUNT_SET);
+ // Run on the Spark runner and close the evaluation result.
+ EvaluationResult res = SparkPipelineRunner.create().run(p);
+ res.close();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/FlattenStreamingTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/FlattenStreamingTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/FlattenStreamingTest.java
new file mode 100644
index 0000000..a3eb301
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/FlattenStreamingTest.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.translation.streaming;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
+import com.google.cloud.dataflow.sdk.testing.DataflowAssert;
+import com.google.cloud.dataflow.sdk.transforms.Flatten;
+import com.google.cloud.dataflow.sdk.transforms.View;
+import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows;
+import com.google.cloud.dataflow.sdk.transforms.windowing.Window;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PCollectionList;
+
+import org.apache.beam.runners.spark.SparkStreamingPipelineOptions;
+import org.apache.beam.runners.spark.io.CreateStream;
+import org.apache.beam.runners.spark.EvaluationResult;
+import org.apache.beam.runners.spark.SparkPipelineRunner;
+import org.apache.beam.runners.spark.translation.streaming.utils.DataflowAssertStreaming;
+
+import org.joda.time.Duration;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Test Flatten (union) implementation for streaming.
+ */
+public class FlattenStreamingTest {
+ // Two input queues holding one batch of words each, and the union expected after flattening them.
+ private static final String[] WORDS_ARRAY_1 = {
+ "one", "two", "three", "four"};
+ private static final List<Iterable<String>> WORDS_QUEUE_1 =
+ Collections.<Iterable<String>>singletonList(Arrays.asList(WORDS_ARRAY_1));
+ private static final String[] WORDS_ARRAY_2 = {
+ "five", "six", "seven", "eight"};
+ private static final List<Iterable<String>> WORDS_QUEUE_2 =
+ Collections.<Iterable<String>>singletonList(Arrays.asList(WORDS_ARRAY_2));
+ private static final String[] EXPECTED_UNION = {
+ "one", "two", "three", "four", "five", "six", "seven", "eight"};
+ private static final long TEST_TIMEOUT_MSEC = 1000L;
+
+ @Test
+ public void testRun() throws Exception {
+ SparkStreamingPipelineOptions options = SparkStreamingPipelineOptionsFactory.create();
+ options.setAppName(this.getClass().getSimpleName());
+ options.setRunner(SparkPipelineRunner.class);
+ options.setTimeout(TEST_TIMEOUT_MSEC);// run for one interval
+ Pipeline p = Pipeline.create(options);
+ // Each queue becomes a stream that is put into one-second fixed windows before flattening.
+ PCollection<String> w1 =
+ p.apply(CreateStream.fromQueue(WORDS_QUEUE_1)).setCoder(StringUtf8Coder.of());
+ PCollection<String> windowedW1 =
+ w1.apply(Window.<String>into(FixedWindows.of(Duration.standardSeconds(1))));
+ PCollection<String> w2 =
+ p.apply(CreateStream.fromQueue(WORDS_QUEUE_2)).setCoder(StringUtf8Coder.of());
+ PCollection<String> windowedW2 =
+ w2.apply(Window.<String>into(FixedWindows.of(Duration.standardSeconds(1))));
+ PCollectionList<String> list = PCollectionList.of(windowedW1).and(windowedW2);
+ PCollection<String> union = list.apply(Flatten.<String>pCollections());
+ // Assert on the whole union, viewed as a single iterable, ignoring element order.
+ DataflowAssert.thatIterable(union.apply(View.<String>asIterable()))
+ .containsInAnyOrder(EXPECTED_UNION);
+
+ EvaluationResult res = SparkPipelineRunner.create(options).run(p);
+ res.close();
+ // NOTE(review): failures are checked after close(); presumably close() waits for the run to end - confirm.
+ DataflowAssertStreaming.assertNoFailures(res);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/KafkaStreamingTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/KafkaStreamingTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/KafkaStreamingTest.java
new file mode 100644
index 0000000..628fe86
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/KafkaStreamingTest.java
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.translation.streaming;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.coders.KvCoder;
+import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
+import com.google.cloud.dataflow.sdk.testing.DataflowAssert;
+import com.google.cloud.dataflow.sdk.transforms.DoFn;
+import com.google.cloud.dataflow.sdk.transforms.ParDo;
+import com.google.cloud.dataflow.sdk.transforms.View;
+import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows;
+import com.google.cloud.dataflow.sdk.transforms.windowing.Window;
+import com.google.cloud.dataflow.sdk.values.KV;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import org.apache.beam.runners.spark.SparkStreamingPipelineOptions;
+import org.apache.beam.runners.spark.io.KafkaIO;
+import org.apache.beam.runners.spark.EvaluationResult;
+import org.apache.beam.runners.spark.SparkPipelineRunner;
+import org.apache.beam.runners.spark.translation.streaming.utils.DataflowAssertStreaming;
+import org.apache.beam.runners.spark.translation.streaming.utils.EmbeddedKafkaCluster;
+
+import org.apache.kafka.clients.producer.KafkaProducer;
+import org.apache.kafka.clients.producer.ProducerRecord;
+import org.apache.kafka.common.serialization.Serializer;
+import org.apache.kafka.common.serialization.StringSerializer;
+import org.joda.time.Duration;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.Map;
+import java.util.Properties;
+import java.util.Set;
+
+import kafka.serializer.StringDecoder;
+
+/**
+ * Test Kafka as input.
+ */
+public class KafkaStreamingTest {
+ private static final EmbeddedKafkaCluster.EmbeddedZookeeper EMBEDDED_ZOOKEEPER =
+ new EmbeddedKafkaCluster.EmbeddedZookeeper(17001);
+ private static final EmbeddedKafkaCluster EMBEDDED_KAFKA_CLUSTER =
+ new EmbeddedKafkaCluster(EMBEDDED_ZOOKEEPER.getConnection(),
+ new Properties(), Collections.singletonList(6667));
+ private static final String TOPIC = "kafka_dataflow_test_topic";
+ private static final Map<String, String> KAFKA_MESSAGES = ImmutableMap.of(
+ "k1", "v1", "k2", "v2", "k3", "v3", "k4", "v4"
+ );
+ private static final Set<String> EXPECTED = ImmutableSet.of(
+ "k1,v1", "k2,v2", "k3,v3", "k4,v4"
+ );
+ private static final long TEST_TIMEOUT_MSEC = 1000L;
+
+ @BeforeClass
+ public static void init() throws IOException {
+ EMBEDDED_ZOOKEEPER.startup();
+ EMBEDDED_KAFKA_CLUSTER.startup();
+
+ // write to Kafka
+ Properties producerProps = new Properties();
+ producerProps.putAll(EMBEDDED_KAFKA_CLUSTER.getProps());
+ producerProps.put("request.required.acks", 1);
+ producerProps.put("bootstrap.servers", EMBEDDED_KAFKA_CLUSTER.getBrokerList());
+ Serializer<String> stringSerializer = new StringSerializer();
+ try (@SuppressWarnings("unchecked") KafkaProducer<String, String> kafkaProducer =
+ new KafkaProducer(producerProps, stringSerializer, stringSerializer)) {
+ for (Map.Entry<String, String> en : KAFKA_MESSAGES.entrySet()) {
+ kafkaProducer.send(new ProducerRecord<>(TOPIC, en.getKey(), en.getValue()));
+ }
+ }
+ }
+
+ @Test
+ public void testRun() throws Exception {
+ // test read from Kafka
+ SparkStreamingPipelineOptions options = SparkStreamingPipelineOptionsFactory.create();
+ options.setAppName(this.getClass().getSimpleName());
+ options.setRunner(SparkPipelineRunner.class);
+ options.setTimeout(TEST_TIMEOUT_MSEC);// run for one interval
+ Pipeline p = Pipeline.create(options);
+
+ Map<String, String> kafkaParams = ImmutableMap.of(
+ "metadata.broker.list", EMBEDDED_KAFKA_CLUSTER.getBrokerList(),
+ "auto.offset.reset", "smallest"
+ );
+
+ PCollection<KV<String, String>> kafkaInput = p.apply(KafkaIO.Read.from(StringDecoder.class,
+ StringDecoder.class, String.class, String.class, Collections.singleton(TOPIC),
+ kafkaParams))
+ .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()));
+ PCollection<KV<String, String>> windowedWords = kafkaInput
+ .apply(Window.<KV<String, String>>into(FixedWindows.of(Duration.standardSeconds(1))));
+
+ PCollection<String> formattedKV = windowedWords.apply(ParDo.of(new FormatKVFn()));
+
+ DataflowAssert.thatIterable(formattedKV.apply(View.<String>asIterable()))
+ .containsInAnyOrder(EXPECTED);
+
+ EvaluationResult res = SparkPipelineRunner.create(options).run(p);
+ res.close();
+
+ DataflowAssertStreaming.assertNoFailures(res);
+ }
+
+ @AfterClass
+ public static void tearDown() {
+ EMBEDDED_KAFKA_CLUSTER.shutdown();
+ EMBEDDED_ZOOKEEPER.shutdown();
+ }
+
+ private static class FormatKVFn extends DoFn<KV<String, String>, String> {
+ @Override
+ public void processElement(ProcessContext c) {
+ c.output(c.element().getKey() + "," + c.element().getValue());
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/SimpleStreamingWordCountTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/SimpleStreamingWordCountTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/SimpleStreamingWordCountTest.java
new file mode 100644
index 0000000..b591510
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/SimpleStreamingWordCountTest.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.translation.streaming;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
+import com.google.cloud.dataflow.sdk.testing.DataflowAssert;
+import com.google.cloud.dataflow.sdk.transforms.View;
+import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows;
+import com.google.cloud.dataflow.sdk.transforms.windowing.Window;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.common.collect.ImmutableSet;
+
+import org.apache.beam.runners.spark.SparkStreamingPipelineOptions;
+import org.apache.beam.runners.spark.io.CreateStream;
+import org.apache.beam.runners.spark.EvaluationResult;
+import org.apache.beam.runners.spark.SimpleWordCountTest;
+import org.apache.beam.runners.spark.SparkPipelineRunner;
+import org.apache.beam.runners.spark.translation.streaming.utils.DataflowAssertStreaming;
+
+import org.joda.time.Duration;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+
+ /**
+  * Runs {@link SimpleWordCountTest.CountWords} over a single queued batch of lines in streaming
+  * mode and checks the formatted per-word counts.
+  */
+ public class SimpleStreamingWordCountTest {
+
+   private static final long TEST_TIMEOUT_MSEC = 1000L;
+   private static final String[] WORDS_ARRAY = {
+       "hi there", "hi", "hi sue bob", "hi sue", "", "bob hi"};
+   // A queue holding exactly one batch: all of WORDS_ARRAY.
+   private static final List<Iterable<String>> WORDS_QUEUE =
+       Collections.<Iterable<String>>singletonList(Arrays.asList(WORDS_ARRAY));
+   private static final Set<String> EXPECTED_COUNT_SET =
+       ImmutableSet.of("hi: 5", "there: 1", "sue: 2", "bob: 2");
+
+   @Test
+   public void testRun() throws Exception {
+     SparkStreamingPipelineOptions options = SparkStreamingPipelineOptionsFactory.create();
+     options.setAppName(this.getClass().getSimpleName());
+     options.setRunner(SparkPipelineRunner.class);
+     // run for one interval
+     options.setTimeout(TEST_TIMEOUT_MSEC);
+     Pipeline pipeline = Pipeline.create(options);
+
+     // Queue-backed input, placed into one-second fixed windows, then counted per word.
+     PCollection<String> counts = pipeline
+         .apply(CreateStream.fromQueue(WORDS_QUEUE)).setCoder(StringUtf8Coder.of())
+         .apply(Window.<String>into(FixedWindows.of(Duration.standardSeconds(1))))
+         .apply(new SimpleWordCountTest.CountWords());
+
+     DataflowAssert.thatIterable(counts.apply(View.<String>asIterable()))
+         .containsInAnyOrder(EXPECTED_COUNT_SET);
+
+     EvaluationResult result = SparkPipelineRunner.create(options).run(pipeline);
+     result.close();
+
+     DataflowAssertStreaming.assertNoFailures(result);
+   }
+ }
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/DataflowAssertStreaming.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/DataflowAssertStreaming.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/DataflowAssertStreaming.java
new file mode 100644
index 0000000..30673dd
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/DataflowAssertStreaming.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.translation.streaming.utils;
+
+import org.apache.beam.runners.spark.EvaluationResult;
+
+import org.junit.Assert;
+
+/**
+ * Since DataflowAssert doesn't propagate assert exceptions, use Aggregators to assert streaming
+ * success/failure counters.
+ */
+public final class DataflowAssertStreaming {
+ /**
+ * Copied aggregator names from {@link com.google.cloud.dataflow.sdk.testing.DataflowAssert}
+ */
+ static final String SUCCESS_COUNTER = "DataflowAssertSuccess";
+ static final String FAILURE_COUNTER = "DataflowAssertFailure";
+
+ private DataflowAssertStreaming() {
+ }
+
+ public static void assertNoFailures(EvaluationResult res) {
+ int failures = res.getAggregatorValue(FAILURE_COUNTER, Integer.class);
+ Assert.assertEquals("Found " + failures + " failures, see the log for details", 0, failures);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/EmbeddedKafkaCluster.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/EmbeddedKafkaCluster.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/EmbeddedKafkaCluster.java
new file mode 100644
index 0000000..e967cdb
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/EmbeddedKafkaCluster.java
@@ -0,0 +1,317 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.translation.streaming.utils;
+
+import org.apache.zookeeper.server.NIOServerCnxnFactory;
+import org.apache.zookeeper.server.ServerCnxnFactory;
+import org.apache.zookeeper.server.ZooKeeperServer;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.net.ServerSocket;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Properties;
+import java.util.Random;
+
+import kafka.server.KafkaConfig;
+import kafka.server.KafkaServer;
+import kafka.utils.Time;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * https://gist.github.com/fjavieralba/7930018
+ */
+public class EmbeddedKafkaCluster {
+
+ private static final Logger LOG = LoggerFactory.getLogger(EmbeddedKafkaCluster.class);
+
+ private final List<Integer> ports;
+ private final String zkConnection;
+ private final Properties baseProperties;
+
+ private final String brokerList;
+
+ private final List<KafkaServer> brokers;
+ private final List<File> logDirs;
+
+ public EmbeddedKafkaCluster(String zkConnection) {
+ this(zkConnection, new Properties());
+ }
+
+ public EmbeddedKafkaCluster(String zkConnection, Properties baseProperties) {
+ this(zkConnection, baseProperties, Collections.singletonList(-1));
+ }
+
+ public EmbeddedKafkaCluster(String zkConnection, Properties baseProperties, List<Integer> ports) {
+ this.zkConnection = zkConnection;
+ this.ports = resolvePorts(ports);
+ this.baseProperties = baseProperties;
+
+ this.brokers = new ArrayList<>();
+ this.logDirs = new ArrayList<>();
+
+ this.brokerList = constructBrokerList(this.ports);
+ }
+
+ private static List<Integer> resolvePorts(List<Integer> ports) {
+ List<Integer> resolvedPorts = new ArrayList<>();
+ for (Integer port : ports) {
+ resolvedPorts.add(resolvePort(port));
+ }
+ return resolvedPorts;
+ }
+
+ private static int resolvePort(int port) {
+ if (port == -1) {
+ return TestUtils.getAvailablePort();
+ }
+ return port;
+ }
+
+ private static String constructBrokerList(List<Integer> ports) {
+ StringBuilder sb = new StringBuilder();
+ for (Integer port : ports) {
+ if (sb.length() > 0) {
+ sb.append(",");
+ }
+ sb.append("localhost:").append(port);
+ }
+ return sb.toString();
+ }
+
+ public void startup() {
+ for (int i = 0; i < ports.size(); i++) {
+ Integer port = ports.get(i);
+ File logDir = TestUtils.constructTempDir("kafka-local");
+
+ Properties properties = new Properties();
+ properties.putAll(baseProperties);
+ properties.setProperty("zookeeper.connect", zkConnection);
+ properties.setProperty("broker.id", String.valueOf(i + 1));
+ properties.setProperty("host.name", "localhost");
+ properties.setProperty("port", Integer.toString(port));
+ properties.setProperty("log.dir", logDir.getAbsolutePath());
+ properties.setProperty("log.flush.interval.messages", String.valueOf(1));
+
+ KafkaServer broker = startBroker(properties);
+
+ brokers.add(broker);
+ logDirs.add(logDir);
+ }
+ }
+
+
+ private static KafkaServer startBroker(Properties props) {
+ KafkaServer server = new KafkaServer(new KafkaConfig(props), new SystemTime());
+ server.startup();
+ return server;
+ }
+
+ public Properties getProps() {
+ Properties props = new Properties();
+ props.putAll(baseProperties);
+ props.put("metadata.broker.list", brokerList);
+ props.put("zookeeper.connect", zkConnection);
+ return props;
+ }
+
+ public String getBrokerList() {
+ return brokerList;
+ }
+
+ public List<Integer> getPorts() {
+ return ports;
+ }
+
+ public String getZkConnection() {
+ return zkConnection;
+ }
+
+ public void shutdown() {
+ for (KafkaServer broker : brokers) {
+ try {
+ broker.shutdown();
+ } catch (Exception e) {
+ LOG.warn("{}", e.getMessage(), e);
+ }
+ }
+ for (File logDir : logDirs) {
+ try {
+ TestUtils.deleteFile(logDir);
+ } catch (FileNotFoundException e) {
+ LOG.warn("{}", e.getMessage(), e);
+ }
+ }
+ }
+
+ @Override
+ public String toString() {
+ return "EmbeddedKafkaCluster{" + "brokerList='" + brokerList + "'}";
+ }
+
+ public static class EmbeddedZookeeper {
+ private int port = -1;
+ private int tickTime = 500;
+
+ private ServerCnxnFactory factory;
+ private File snapshotDir;
+ private File logDir;
+
+ public EmbeddedZookeeper() {
+ this(-1);
+ }
+
+ public EmbeddedZookeeper(int port) {
+ this(port, 500);
+ }
+
+ public EmbeddedZookeeper(int port, int tickTime) {
+ this.port = resolvePort(port);
+ this.tickTime = tickTime;
+ }
+
+ private static int resolvePort(int port) {
+ if (port == -1) {
+ return TestUtils.getAvailablePort();
+ }
+ return port;
+ }
+
+ public void startup() throws IOException {
+ if (this.port == -1) {
+ this.port = TestUtils.getAvailablePort();
+ }
+ this.factory = NIOServerCnxnFactory.createFactory(new InetSocketAddress("localhost", port),
+ 1024);
+ this.snapshotDir = TestUtils.constructTempDir("embedded-zk/snapshot");
+ this.logDir = TestUtils.constructTempDir("embedded-zk/log");
+
+ try {
+ factory.startup(new ZooKeeperServer(snapshotDir, logDir, tickTime));
+ } catch (InterruptedException e) {
+ throw new IOException(e);
+ }
+ }
+
+
+ public void shutdown() {
+ factory.shutdown();
+ try {
+ TestUtils.deleteFile(snapshotDir);
+ } catch (FileNotFoundException e) {
+ // ignore
+ }
+ try {
+ TestUtils.deleteFile(logDir);
+ } catch (FileNotFoundException e) {
+ // ignore
+ }
+ }
+
+ public String getConnection() {
+ return "localhost:" + port;
+ }
+
+ public void setPort(int port) {
+ this.port = port;
+ }
+
+ public void setTickTime(int tickTime) {
+ this.tickTime = tickTime;
+ }
+
+ public int getPort() {
+ return port;
+ }
+
+ public int getTickTime() {
+ return tickTime;
+ }
+
+ @Override
+ public String toString() {
+ return "EmbeddedZookeeper{" + "connection=" + getConnection() + "}";
+ }
+ }
+
+ static class SystemTime implements Time {
+ @Override
+ public long milliseconds() {
+ return System.currentTimeMillis();
+ }
+
+ @Override
+ public long nanoseconds() {
+ return System.nanoTime();
+ }
+
+ @Override
+ public void sleep(long ms) {
+ try {
+ Thread.sleep(ms);
+ } catch (InterruptedException e) {
+ // Ignore
+ }
+ }
+ }
+
+ static final class TestUtils {
+ private static final Random RANDOM = new Random();
+
+ private TestUtils() {
+ }
+
+ static File constructTempDir(String dirPrefix) {
+ File file = new File(System.getProperty("java.io.tmpdir"), dirPrefix + RANDOM.nextInt
+ (10000000));
+ if (!file.mkdirs()) {
+ throw new RuntimeException("could not create temp directory: " + file.getAbsolutePath());
+ }
+ file.deleteOnExit();
+ return file;
+ }
+
+ static int getAvailablePort() {
+ try {
+ try (ServerSocket socket = new ServerSocket(0)) {
+ return socket.getLocalPort();
+ }
+ } catch (IOException e) {
+ throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e);
+ }
+ }
+
+ static boolean deleteFile(File path) throws FileNotFoundException {
+ if (!path.exists()) {
+ throw new FileNotFoundException(path.getAbsolutePath());
+ }
+ boolean ret = true;
+ if (path.isDirectory()) {
+ for (File f : path.listFiles()) {
+ ret = ret && deleteFile(f);
+ }
+ }
+ return ret && path.delete();
+ }
+ }
+}