You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2013/09/01 23:58:55 UTC
[11/69] [abbrv] [partial] Initial work to rename package to
org.apache.spark
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/test/scala/org/apache/spark/JavaAPISuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
new file mode 100644
index 0000000..8a869c9
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -0,0 +1,865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.*;
+
+import com.google.common.base.Optional;
+import scala.Tuple2;
+
+import com.google.common.base.Charsets;
+import org.apache.hadoop.io.compress.DefaultCodec;
+import com.google.common.io.Files;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.SequenceFileInputFormat;
+import org.apache.hadoop.mapred.SequenceFileOutputFormat;
+import org.apache.hadoop.mapreduce.Job;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaDoubleRDD;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.*;
+import org.apache.spark.partial.BoundedDouble;
+import org.apache.spark.partial.PartialResult;
+import org.apache.spark.storage.StorageLevel;
+import org.apache.spark.util.StatCounter;
+
+
+// The test suite itself is Serializable so that anonymous Function implementations can be
+// serialized, as an alternative to converting these anonymous classes to static inner classes;
+// see http://stackoverflow.com/questions/758570/.
+public class JavaAPISuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ /** Creates a fresh context with master "local" before every test. */
+ @Before
+ public void setUp() {
+   final String master = "local";
+   final String appName = "JavaAPISuite";
+   sc = new JavaSparkContext(master, appName);
+ }
+
+ /**
+  * Stops the context created in setUp (if there is one) and always clears the driver port
+  * property, even when stopping the context fails.
+  */
+ @After
+ public void tearDown() {
+   try {
+     // JUnit runs @After methods even when @Before threw, so sc may never have been assigned.
+     if (sc != null) {
+       sc.stop();
+     }
+   } finally {
+     sc = null;
+     // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+     System.clearProperty("spark.driver.port");
+   }
+ }
+
+ /** Orders integers from largest to smallest; serializable so it can be passed to sortByKey. */
+ static class ReverseIntComparator implements Comparator<Integer>, Serializable {
+
+   @Override
+   public int compare(Integer a, Integer b) {
+     // Swapping the operands of the natural ordering reverses it.
+     return b.compareTo(a);
+   }
+ }
+
+ /** Exercises JavaSparkContext.union on plain, double and pair RDDs. */
+ @Test
+ public void sparkContextUnion() {
+   // Plain JavaRDDs: the varargs overload first, then the List overload.
+   List<String> words = Arrays.asList("Hello", "World");
+   JavaRDD<String> left = sc.parallelize(words);
+   JavaRDD<String> right = sc.parallelize(words);
+   JavaRDD<String> stringUnion = sc.union(left, right);
+   Assert.assertEquals(4, stringUnion.count());
+   List<JavaRDD<String>> others = new ArrayList<JavaRDD<String>>();
+   others.add(right);
+   stringUnion = sc.union(left, others);
+   Assert.assertEquals(4, stringUnion.count());
+
+   // JavaDoubleRDDs
+   List<Double> values = Arrays.asList(1.0, 2.0);
+   JavaDoubleRDD doubleUnion =
+       sc.union(sc.parallelizeDoubles(values), sc.parallelizeDoubles(values));
+   Assert.assertEquals(4, doubleUnion.count());
+
+   // JavaPairRDDs
+   List<Tuple2<Integer, Integer>> tuples = Arrays.asList(
+       new Tuple2<Integer, Integer>(1, 2),
+       new Tuple2<Integer, Integer>(3, 4));
+   JavaPairRDD<Integer, Integer> pairUnion =
+       sc.union(sc.parallelizePairs(tuples), sc.parallelizePairs(tuples));
+   Assert.assertEquals(4, pairUnion.count());
+ }
+
+ /** Sorts a small pair RDD by key, once with the default ordering and once with a comparator. */
+ @Test
+ public void sortByKey() {
+   Tuple2<Integer, Integer> lowest = new Tuple2<Integer, Integer>(-1, 1);
+   Tuple2<Integer, Integer> middle = new Tuple2<Integer, Integer>(0, 4);
+   Tuple2<Integer, Integer> highest = new Tuple2<Integer, Integer>(3, 2);
+   // The input is deliberately not in key order.
+   List<Tuple2<Integer, Integer>> input = Arrays.asList(middle, highest, lowest);
+   JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(input);
+
+   // Default comparator
+   JavaPairRDD<Integer, Integer> sorted = rdd.sortByKey();
+   Assert.assertEquals(lowest, sorted.first());
+   List<Tuple2<Integer, Integer>> collected = sorted.collect();
+   Assert.assertEquals(middle, collected.get(1));
+   Assert.assertEquals(highest, collected.get(2));
+
+   // Custom comparator: the reversed comparator combined with the `false` flag is expected to
+   // produce the same order as the default sort above.
+   sorted = rdd.sortByKey(new ReverseIntComparator(), false);
+   Assert.assertEquals(lowest, sorted.first());
+   collected = sorted.collect();
+   Assert.assertEquals(middle, collected.get(1));
+   Assert.assertEquals(highest, collected.get(2));
+ }
+
+ // Incremented by the function handed to foreach() in the test below.
+ static int foreachCalls = 0;
+
+ /** Checks that foreach invokes the supplied function once per element. */
+ @Test
+ public void foreach() {
+   foreachCalls = 0;
+   List<String> words = Arrays.asList("Hello", "World");
+   VoidFunction<String> countInvocation = new VoidFunction<String>() {
+     @Override
+     public void call(String ignored) {
+       foreachCalls += 1;
+     }
+   };
+   sc.parallelize(words).foreach(countInvocation);
+   Assert.assertEquals(2, foreachCalls);
+ }
+
+ @Test
+ public void lookup() {
+ JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList(
+ new Tuple2<String, String>("Apples", "Fruit"),
+ new Tuple2<String, String>("Oranges", "Fruit"),
+ new Tuple2<String, String>("Oranges", "Citrus")
+ ));
+ Assert.assertEquals(2, categories.lookup("Oranges").size());
+ Assert.assertEquals(2, categories.groupByKey().lookup("Oranges").get(0).size());
+ }
+
+ /** Groups integers by parity, with and without an explicit second argument to groupBy. */
+ @Test
+ public void groupBy() {
+   JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13));
+   // Returns true for even numbers. (Previously misnamed "isOdd" although it tested x % 2 == 0.)
+   Function<Integer, Boolean> isEven = new Function<Integer, Boolean>() {
+     @Override
+     public Boolean call(Integer x) {
+       return x % 2 == 0;
+     }
+   };
+   JavaPairRDD<Boolean, List<Integer>> oddsAndEvens = rdd.groupBy(isEven);
+   Assert.assertEquals(2, oddsAndEvens.count());
+   Assert.assertEquals(2, oddsAndEvens.lookup(true).get(0).size()); // Evens
+   Assert.assertEquals(5, oddsAndEvens.lookup(false).get(0).size()); // Odds
+
+   oddsAndEvens = rdd.groupBy(isEven, 1);
+   Assert.assertEquals(2, oddsAndEvens.count());
+   Assert.assertEquals(2, oddsAndEvens.lookup(true).get(0).size()); // Evens
+   Assert.assertEquals(5, oddsAndEvens.lookup(false).get(0).size()); // Odds
+ }
+
+ /** Cogroups categories with prices and inspects both sides for one key. */
+ @Test
+ public void cogroup() {
+   List<Tuple2<String, String>> categoryEntries = Arrays.asList(
+       new Tuple2<String, String>("Apples", "Fruit"),
+       new Tuple2<String, String>("Oranges", "Fruit"),
+       new Tuple2<String, String>("Oranges", "Citrus"));
+   List<Tuple2<String, Integer>> priceEntries = Arrays.asList(
+       new Tuple2<String, Integer>("Oranges", 2),
+       new Tuple2<String, Integer>("Apples", 3));
+   JavaPairRDD<String, String> categories = sc.parallelizePairs(categoryEntries);
+   JavaPairRDD<String, Integer> prices = sc.parallelizePairs(priceEntries);
+
+   JavaPairRDD<String, Tuple2<List<String>, List<Integer>>> cogrouped = categories.cogroup(prices);
+   String orangeCategories = cogrouped.lookup("Oranges").get(0)._1().toString();
+   Assert.assertEquals("[Fruit, Citrus]", orangeCategories);
+   String orangePrices = cogrouped.lookup("Oranges").get(0)._2().toString();
+   Assert.assertEquals("[2]", orangePrices);
+
+   cogrouped.collect();
+ }
+
+ /** Left-outer-joins two pair RDDs and checks the row count and the unmatched left key. */
+ @Test
+ public void leftOuterJoin() {
+   List<Tuple2<Integer, Integer>> leftEntries = Arrays.asList(
+       new Tuple2<Integer, Integer>(1, 1),
+       new Tuple2<Integer, Integer>(1, 2),
+       new Tuple2<Integer, Integer>(2, 1),
+       new Tuple2<Integer, Integer>(3, 1));
+   List<Tuple2<Integer, Character>> rightEntries = Arrays.asList(
+       new Tuple2<Integer, Character>(1, 'x'),
+       new Tuple2<Integer, Character>(2, 'y'),
+       new Tuple2<Integer, Character>(2, 'z'),
+       new Tuple2<Integer, Character>(4, 'w'));
+   JavaPairRDD<Integer, Integer> rdd1 = sc.parallelizePairs(leftEntries);
+   JavaPairRDD<Integer, Character> rdd2 = sc.parallelizePairs(rightEntries);
+
+   List<Tuple2<Integer,Tuple2<Integer,Optional<Character>>>> joined =
+       rdd1.leftOuterJoin(rdd2).collect();
+   Assert.assertEquals(5, joined.size());
+
+   // Keeps only rows whose right-hand side is absent.
+   Function<Tuple2<Integer, Tuple2<Integer, Optional<Character>>>, Boolean> rightSideMissing =
+       new Function<Tuple2<Integer, Tuple2<Integer, Optional<Character>>>, Boolean>() {
+         @Override
+         public Boolean call(Tuple2<Integer, Tuple2<Integer, Optional<Character>>> row)
+             throws Exception {
+           Optional<Character> right = row._2()._2();
+           return !right.isPresent();
+         }
+       };
+   Tuple2<Integer,Tuple2<Integer,Optional<Character>>> firstUnmatched =
+       rdd1.leftOuterJoin(rdd2).filter(rightSideMissing).first();
+   Assert.assertEquals(3, firstUnmatched._1().intValue());
+ }
+
+ /** Sums the same RDD with fold (zero value 0) and with reduce. */
+ @Test
+ public void foldReduce() {
+   List<Integer> numbers = Arrays.asList(1, 1, 2, 3, 5, 8, 13);
+   JavaRDD<Integer> rdd = sc.parallelize(numbers);
+   Function2<Integer, Integer, Integer> plus = new Function2<Integer, Integer, Integer>() {
+     @Override
+     public Integer call(Integer left, Integer right) {
+       return left + right;
+     }
+   };
+
+   int folded = rdd.fold(0, plus);
+   Assert.assertEquals(33, folded);
+
+   int reduced = rdd.reduce(plus);
+   Assert.assertEquals(33, reduced);
+ }
+
+ /** Sums values per key with foldByKey and checks each key's total. */
+ @Test
+ public void foldByKey() {
+   JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(Arrays.asList(
+       new Tuple2<Integer, Integer>(2, 1),
+       new Tuple2<Integer, Integer>(2, 1),
+       new Tuple2<Integer, Integer>(1, 1),
+       new Tuple2<Integer, Integer>(3, 2),
+       new Tuple2<Integer, Integer>(3, 1)));
+   Function2<Integer, Integer, Integer> plus = new Function2<Integer, Integer, Integer>() {
+     @Override
+     public Integer call(Integer left, Integer right) {
+       return left + right;
+     }
+   };
+   JavaPairRDD<Integer, Integer> sums = rdd.foldByKey(0, plus);
+   // With this input every key's total equals the key itself.
+   for (int key = 1; key <= 3; key++) {
+     Assert.assertEquals(key, sums.lookup(key).get(0).intValue());
+   }
+ }
+
+ /** Sums values per key via reduceByKey, collectAsMap and reduceByKeyLocally. */
+ @Test
+ public void reduceByKey() {
+   JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(Arrays.asList(
+       new Tuple2<Integer, Integer>(2, 1),
+       new Tuple2<Integer, Integer>(2, 1),
+       new Tuple2<Integer, Integer>(1, 1),
+       new Tuple2<Integer, Integer>(3, 2),
+       new Tuple2<Integer, Integer>(3, 1)));
+
+   // Distributed reduction, checked through lookup.
+   JavaPairRDD<Integer, Integer> counts = rdd.reduceByKey(
+       new Function2<Integer, Integer, Integer>() {
+         @Override
+         public Integer call(Integer left, Integer right) {
+           return left + right;
+         }
+       });
+   Assert.assertEquals(1, counts.lookup(1).get(0).intValue());
+   Assert.assertEquals(2, counts.lookup(2).get(0).intValue());
+   Assert.assertEquals(3, counts.lookup(3).get(0).intValue());
+
+   // The same result pulled back as a map.
+   Map<Integer, Integer> asMap = counts.collectAsMap();
+   Assert.assertEquals(1, asMap.get(1).intValue());
+   Assert.assertEquals(2, asMap.get(2).intValue());
+   Assert.assertEquals(3, asMap.get(3).intValue());
+
+   // Reduction that returns a map directly.
+   Map<Integer, Integer> reducedLocally = rdd.reduceByKeyLocally(
+       new Function2<Integer, Integer, Integer>() {
+         @Override
+         public Integer call(Integer left, Integer right) {
+           return left + right;
+         }
+       });
+   Assert.assertEquals(1, reducedLocally.get(1).intValue());
+   Assert.assertEquals(2, reducedLocally.get(2).intValue());
+   Assert.assertEquals(3, reducedLocally.get(3).intValue());
+ }
+
+ /** Compares exact countByValue results with the final value of countByValueApprox. */
+ @Test
+ public void approximateResults() {
+   List<Integer> numbers = Arrays.asList(1, 1, 2, 3, 5, 8, 13);
+   JavaRDD<Integer> rdd = sc.parallelize(numbers);
+
+   Map<Integer, Long> exact = rdd.countByValue();
+   Assert.assertEquals(2, exact.get(1).longValue());
+   Assert.assertEquals(1, exact.get(13).longValue());
+
+   PartialResult<Map<Integer, BoundedDouble>> partial = rdd.countByValueApprox(1);
+   Map<Integer, BoundedDouble> settled = partial.getFinalValue();
+   Assert.assertEquals(2.0, settled.get(1).mean(), 0.01);
+   Assert.assertEquals(1.0, settled.get(13).mean(), 0.01);
+ }
+
+ /** Checks first(), take() and takeSample() on a small RDD. */
+ @Test
+ public void take() {
+   JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13));
+   Assert.assertEquals(1, rdd.first().intValue());
+   // Previously the results below were computed but never checked.
+   List<Integer> firstTwo = rdd.take(2);
+   Assert.assertEquals(Arrays.asList(1, 1), firstTwo);
+   List<Integer> sample = rdd.takeSample(false, 2, 42);
+   Assert.assertEquals(2, sample.size());
+ }
+
+ /** Builds the cartesian product of a string RDD and a double RDD and checks its first pair. */
+ @Test
+ public void cartesian() {
+   JavaRDD<String> words = sc.parallelize(Arrays.asList("Hello", "World"));
+   JavaDoubleRDD numbers = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
+   JavaPairRDD<String, Double> product = words.cartesian(numbers);
+   Tuple2<String, Double> expectedFirst = new Tuple2<String, Double>("Hello", 1.0);
+   Assert.assertEquals(expectedFirst, product.first());
+ }
+
+ /** Exercises distinct, filter, union, cache and the statistics methods of JavaDoubleRDD. */
+ @Test
+ public void javaDoubleRDD() {
+   JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
+   JavaDoubleRDD distinct = rdd.distinct();
+   Assert.assertEquals(5, distinct.count());
+   JavaDoubleRDD filter = rdd.filter(new Function<Double, Boolean>() {
+     @Override
+     public Boolean call(Double x) {
+       return x > 2.0;
+     }
+   });
+   Assert.assertEquals(3, filter.count());
+   JavaDoubleRDD union = rdd.union(rdd);
+   Assert.assertEquals(12, union.count());
+   union = union.cache();
+   Assert.assertEquals(12, union.count());
+
+   Assert.assertEquals(20, rdd.sum(), 0.01);
+   StatCounter stats = rdd.stats();
+   Assert.assertEquals(20, stats.sum(), 0.01);
+   Assert.assertEquals(20/6.0, rdd.mean(), 0.01);
+   // The original asserted rdd.mean() twice; the second check is meant for the StatCounter.
+   Assert.assertEquals(20/6.0, stats.mean(), 0.01);
+   Assert.assertEquals(6.22222, rdd.variance(), 0.01);
+   Assert.assertEquals(7.46667, rdd.sampleVariance(), 0.01);
+   Assert.assertEquals(2.49444, rdd.stdev(), 0.01);
+   Assert.assertEquals(2.73252, rdd.sampleStdev(), 0.01);
+
+   // Previously assigned to unused locals; now the results are actually checked.
+   Double first = rdd.first();
+   Assert.assertEquals(1.0, first.doubleValue(), 0.01);
+   List<Double> take = rdd.take(5);
+   Assert.assertEquals(5, take.size());
+ }
+
+ /**
+  * Maps an RDD through each of the three function kinds (double, pair, plain) and collects the
+  * results so that the functions are actually executed.
+  */
+ @Test
+ public void map() {
+   JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+   JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() {
+     @Override
+     public Double call(Integer x) {
+       return 1.0 * x;
+     }
+   }).cache();
+   // Without an action nothing above would run, so the test would only check compilation.
+   Assert.assertEquals(5, doubles.collect().size());
+
+   JavaPairRDD<Integer, Integer> pairs = rdd.map(new PairFunction<Integer, Integer, Integer>() {
+     @Override
+     public Tuple2<Integer, Integer> call(Integer x) {
+       return new Tuple2<Integer, Integer>(x, x);
+     }
+   }).cache();
+   Assert.assertEquals(5, pairs.collect().size());
+
+   JavaRDD<String> strings = rdd.map(new Function<Integer, String>() {
+     @Override
+     public String call(Integer x) {
+       return x.toString();
+     }
+   }).cache();
+   Assert.assertEquals(Arrays.asList("1", "2", "3", "4", "5"), strings.collect());
+ }
+
+ /** Splits sentences into words with the plain, pair and double variants of flatMap. */
+ @Test
+ public void flatMap() {
+   JavaRDD<String> rdd = sc.parallelize(Arrays.asList("Hello World!",
+     "The quick brown fox jumps over the lazy dog."));
+   JavaRDD<String> words = rdd.flatMap(new FlatMapFunction<String, String>() {
+     @Override
+     public Iterable<String> call(String x) {
+       return Arrays.asList(x.split(" "));
+     }
+   });
+   Assert.assertEquals("Hello", words.first());
+   Assert.assertEquals(11, words.count());
+
+   JavaPairRDD<String, String> pairs = rdd.flatMap(
+     new PairFlatMapFunction<String, String, String>() {
+
+       @Override
+       public Iterable<Tuple2<String, String>> call(String s) {
+         List<Tuple2<String, String>> pairs = new LinkedList<Tuple2<String, String>>();
+         for (String word : s.split(" ")) pairs.add(new Tuple2<String, String>(word, word));
+         return pairs;
+       }
+     }
+   );
+   Assert.assertEquals(new Tuple2<String, String>("Hello", "Hello"), pairs.first());
+   Assert.assertEquals(11, pairs.count());
+
+   JavaDoubleRDD doubles = rdd.flatMap(new DoubleFlatMapFunction<String>() {
+     @Override
+     public Iterable<Double> call(String s) {
+       List<Double> lengths = new LinkedList<Double>();
+       for (String word : s.split(" ")) lengths.add(word.length() * 1.0);
+       return lengths;
+     }
+   });
+   Assert.assertEquals(5.0, doubles.first().doubleValue(), 0.01);
+   // The original re-checked pairs.count() here (copy-paste), leaving doubles.count() untested.
+   Assert.assertEquals(11, doubles.count());
+ }
+
+ /** Swaps keys and values of a pair RDD through both flatMap and map. */
+ @Test
+ public void mapsFromPairsToPairs() {
+   JavaPairRDD<Integer, String> pairRDD = sc.parallelizePairs(Arrays.asList(
+       new Tuple2<Integer, String>(1, "a"),
+       new Tuple2<Integer, String>(2, "aa"),
+       new Tuple2<Integer, String>(3, "aaa")));
+
+   // Regression test for SPARK-668:
+   PairFlatMapFunction<Tuple2<Integer, String>, String, Integer> swapViaFlatMap =
+       new PairFlatMapFunction<Tuple2<Integer, String>, String, Integer>() {
+         @Override
+         public Iterable<Tuple2<String, Integer>> call(Tuple2<Integer, String> entry)
+             throws Exception {
+           return Collections.singletonList(entry.swap());
+         }
+       };
+   JavaPairRDD<String, Integer> swapped = pairRDD.flatMap(swapViaFlatMap);
+   swapped.collect();
+
+   // There was never a bug here, but it's worth testing:
+   PairFunction<Tuple2<Integer, String>, String, Integer> swapViaMap =
+       new PairFunction<Tuple2<Integer, String>, String, Integer>() {
+         @Override
+         public Tuple2<String, Integer> call(Tuple2<Integer, String> entry) throws Exception {
+           return entry.swap();
+         }
+       };
+   pairRDD.map(swapViaMap).collect();
+ }
+
+ @Test
+ public void mapPartitions() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
+ JavaRDD<Integer> partitionSums = rdd.mapPartitions(
+ new FlatMapFunction<Iterator<Integer>, Integer>() {
+ @Override
+ public Iterable<Integer> call(Iterator<Integer> iter) {
+ int sum = 0;
+ while (iter.hasNext()) {
+ sum += iter.next();
+ }
+ return Collections.singletonList(sum);
+ }
+ });
+ Assert.assertEquals("[3, 7]", partitionSums.collect().toString());
+ }
+
+ /** Persists a double, a pair and a plain RDD with DISK_ONLY and reads each one back. */
+ @Test
+ public void persist() {
+   // Double RDD
+   JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0))
+       .persist(StorageLevel.DISK_ONLY());
+   Assert.assertEquals(20, doubleRDD.sum(), 0.1);
+
+   // Pair RDD
+   List<Tuple2<Integer, String>> entries = Arrays.asList(
+       new Tuple2<Integer, String>(1, "a"),
+       new Tuple2<Integer, String>(2, "aa"),
+       new Tuple2<Integer, String>(3, "aaa"));
+   JavaPairRDD<Integer, String> pairRDD =
+       sc.parallelizePairs(entries).persist(StorageLevel.DISK_ONLY());
+   Assert.assertEquals("a", pairRDD.first()._2());
+
+   // Plain RDD
+   JavaRDD<Integer> rdd =
+       sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)).persist(StorageLevel.DISK_ONLY());
+   Assert.assertEquals(1, rdd.first().intValue());
+ }
+
+ @Test
+ public void iterator() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
+ TaskContext context = new TaskContext(0, 0, 0, null);
+ Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
+ }
+
+ @Test
+ public void glom() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
+ Assert.assertEquals("[1, 2]", rdd.glom().first().toString());
+ }
+
+ // File input / output tests are largely adapted from FileSuite:
+
+ /** Saves an RDD as text, checks the raw part file, then reads it back as a text-file RDD. */
+ @Test
+ public void textFiles() throws IOException {
+   String outputDir = new File(Files.createTempDir(), "output").getAbsolutePath();
+   sc.parallelize(Arrays.asList(1, 2, 3, 4)).saveAsTextFile(outputDir);
+
+   // Read the plain text file and check it's OK
+   File partFile = new File(outputDir, "part-00000");
+   String rawContent = Files.toString(partFile, Charsets.UTF_8);
+   Assert.assertEquals("1\n2\n3\n4\n", rawContent);
+
+   // Also try reading it in as a text file RDD
+   List<String> linesRead = sc.textFile(outputDir).collect();
+   Assert.assertEquals(Arrays.asList("1", "2", "3", "4"), linesRead);
+ }
+
+ /** Saves an RDD as text with DefaultCodec and reads it back as a text-file RDD. */
+ @Test
+ public void textFilesCompressed() throws IOException {
+   String outputDir = new File(Files.createTempDir(), "output").getAbsolutePath();
+   sc.parallelize(Arrays.asList(1, 2, 3, 4)).saveAsTextFile(outputDir, DefaultCodec.class);
+
+   // Try reading it in as a text file RDD
+   List<String> linesRead = sc.textFile(outputDir).collect();
+   Assert.assertEquals(Arrays.asList("1", "2", "3", "4"), linesRead);
+ }
+
+ /** Writes pairs as a Hadoop sequence file and reads them back through sc.sequenceFile. */
+ @Test
+ public void sequenceFile() {
+   String outputDir = new File(Files.createTempDir(), "output").getAbsolutePath();
+   List<Tuple2<Integer, String>> pairs = Arrays.asList(
+       new Tuple2<Integer, String>(1, "a"),
+       new Tuple2<Integer, String>(2, "aa"),
+       new Tuple2<Integer, String>(3, "aaa"));
+
+   // Converts plain Java values into Hadoop writables for saving.
+   PairFunction<Tuple2<Integer, String>, IntWritable, Text> toWritables =
+       new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+         @Override
+         public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> entry) {
+           IntWritable key = new IntWritable(entry._1());
+           Text value = new Text(entry._2());
+           return new Tuple2<IntWritable, Text>(key, value);
+         }
+       };
+   sc.parallelizePairs(pairs).map(toWritables)
+       .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class);
+
+   // Read the output back as a sequence file and unwrap the writables again.
+   PairFunction<Tuple2<IntWritable, Text>, Integer, String> fromWritables =
+       new PairFunction<Tuple2<IntWritable, Text>, Integer, String>() {
+         @Override
+         public Tuple2<Integer, String> call(Tuple2<IntWritable, Text> entry) {
+           return new Tuple2<Integer, String>(entry._1().get(), entry._2().toString());
+         }
+       };
+   JavaPairRDD<IntWritable, Text> stored =
+       sc.sequenceFile(outputDir, IntWritable.class, Text.class);
+   JavaPairRDD<Integer, String> readRDD = stored.map(fromWritables);
+   Assert.assertEquals(pairs, readRDD.collect());
+ }
+
+ /** Writes pairs with saveAsNewAPIHadoopFile and reads them back with sc.sequenceFile. */
+ @Test
+ public void writeWithNewAPIHadoopFile() {
+   String outputDir = new File(Files.createTempDir(), "output").getAbsolutePath();
+   List<Tuple2<Integer, String>> pairs = Arrays.asList(
+       new Tuple2<Integer, String>(1, "a"),
+       new Tuple2<Integer, String>(2, "aa"),
+       new Tuple2<Integer, String>(3, "aaa"));
+
+   // Converts plain Java values into Hadoop writables for saving.
+   PairFunction<Tuple2<Integer, String>, IntWritable, Text> toWritables =
+       new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+         @Override
+         public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> entry) {
+           IntWritable key = new IntWritable(entry._1());
+           Text value = new Text(entry._2());
+           return new Tuple2<IntWritable, Text>(key, value);
+         }
+       };
+   sc.parallelizePairs(pairs).map(toWritables).saveAsNewAPIHadoopFile(
+       outputDir, IntWritable.class, Text.class,
+       org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class);
+
+   // Compare through string renderings of the stored pairs.
+   Function<Tuple2<IntWritable, Text>, String> render =
+       new Function<Tuple2<IntWritable, Text>, String>() {
+         @Override
+         public String call(Tuple2<IntWritable, Text> entry) {
+           return entry.toString();
+         }
+       };
+   JavaPairRDD<IntWritable, Text> output =
+       sc.sequenceFile(outputDir, IntWritable.class, Text.class);
+   String actual = output.map(render).collect().toString();
+   Assert.assertEquals(pairs.toString(), actual);
+ }
+
+ /** Writes pairs with saveAsHadoopFile and reads them back with sc.newAPIHadoopFile. */
+ @Test
+ public void readWithNewAPIHadoopFile() throws IOException {
+   String outputDir = new File(Files.createTempDir(), "output").getAbsolutePath();
+   List<Tuple2<Integer, String>> pairs = Arrays.asList(
+       new Tuple2<Integer, String>(1, "a"),
+       new Tuple2<Integer, String>(2, "aa"),
+       new Tuple2<Integer, String>(3, "aaa"));
+
+   // Converts plain Java values into Hadoop writables for saving.
+   PairFunction<Tuple2<Integer, String>, IntWritable, Text> toWritables =
+       new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+         @Override
+         public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> entry) {
+           IntWritable key = new IntWritable(entry._1());
+           Text value = new Text(entry._2());
+           return new Tuple2<IntWritable, Text>(key, value);
+         }
+       };
+   sc.parallelizePairs(pairs).map(toWritables)
+       .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class);
+
+   // Compare through string renderings of the stored pairs.
+   Function<Tuple2<IntWritable, Text>, String> render =
+       new Function<Tuple2<IntWritable, Text>, String>() {
+         @Override
+         public String call(Tuple2<IntWritable, Text> entry) {
+           return entry.toString();
+         }
+       };
+   JavaPairRDD<IntWritable, Text> output = sc.newAPIHadoopFile(outputDir,
+       org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, IntWritable.class,
+       Text.class, new Job().getConfiguration());
+   String actual = output.map(render).collect().toString();
+   Assert.assertEquals(pairs.toString(), actual);
+ }
+
+ /** Saves integers as an object file and reads them back. */
+ @Test
+ public void objectFilesOfInts() {
+   String outputDir = new File(Files.createTempDir(), "output").getAbsolutePath();
+   List<Integer> numbers = Arrays.asList(1, 2, 3, 4);
+   sc.parallelize(numbers).saveAsObjectFile(outputDir);
+
+   // Try reading the output back as an object file
+   JavaRDD<Integer> readRDD = sc.objectFile(outputDir);
+   Assert.assertEquals(Arrays.asList(1, 2, 3, 4), readRDD.collect());
+ }
+
+ /** Saves tuples as an object file and reads them back. */
+ @Test
+ public void objectFilesOfComplexTypes() {
+   String outputDir = new File(Files.createTempDir(), "output").getAbsolutePath();
+   List<Tuple2<Integer, String>> pairs = Arrays.asList(
+       new Tuple2<Integer, String>(1, "a"),
+       new Tuple2<Integer, String>(2, "aa"),
+       new Tuple2<Integer, String>(3, "aaa"));
+   sc.parallelizePairs(pairs).saveAsObjectFile(outputDir);
+
+   // Try reading the output back as an object file
+   JavaRDD<Tuple2<Integer, String>> readRDD = sc.objectFile(outputDir);
+   Assert.assertEquals(pairs, readRDD.collect());
+ }
+
+ /** Writes pairs with saveAsHadoopFile and reads them back with sc.hadoopFile. */
+ @Test
+ public void hadoopFile() {
+   String outputDir = new File(Files.createTempDir(), "output").getAbsolutePath();
+   List<Tuple2<Integer, String>> pairs = Arrays.asList(
+       new Tuple2<Integer, String>(1, "a"),
+       new Tuple2<Integer, String>(2, "aa"),
+       new Tuple2<Integer, String>(3, "aaa"));
+
+   // Converts plain Java values into Hadoop writables for saving.
+   PairFunction<Tuple2<Integer, String>, IntWritable, Text> toWritables =
+       new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+         @Override
+         public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> entry) {
+           IntWritable key = new IntWritable(entry._1());
+           Text value = new Text(entry._2());
+           return new Tuple2<IntWritable, Text>(key, value);
+         }
+       };
+   sc.parallelizePairs(pairs).map(toWritables)
+       .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class);
+
+   // Compare through string renderings of the stored pairs.
+   Function<Tuple2<IntWritable, Text>, String> render =
+       new Function<Tuple2<IntWritable, Text>, String>() {
+         @Override
+         public String call(Tuple2<IntWritable, Text> entry) {
+           return entry.toString();
+         }
+       };
+   JavaPairRDD<IntWritable, Text> output = sc.hadoopFile(outputDir,
+       SequenceFileInputFormat.class, IntWritable.class, Text.class);
+   String actual = output.map(render).collect().toString();
+   Assert.assertEquals(pairs.toString(), actual);
+ }
+
+ /** Writes pairs with saveAsHadoopFile using DefaultCodec and reads them via sc.hadoopFile. */
+ @Test
+ public void hadoopFileCompressed() {
+   String outputDir = new File(Files.createTempDir(), "output_compressed").getAbsolutePath();
+   List<Tuple2<Integer, String>> pairs = Arrays.asList(
+       new Tuple2<Integer, String>(1, "a"),
+       new Tuple2<Integer, String>(2, "aa"),
+       new Tuple2<Integer, String>(3, "aaa"));
+
+   // Converts plain Java values into Hadoop writables for saving.
+   PairFunction<Tuple2<Integer, String>, IntWritable, Text> toWritables =
+       new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+         @Override
+         public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> entry) {
+           IntWritable key = new IntWritable(entry._1());
+           Text value = new Text(entry._2());
+           return new Tuple2<IntWritable, Text>(key, value);
+         }
+       };
+   sc.parallelizePairs(pairs).map(toWritables).saveAsHadoopFile(
+       outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class,
+       DefaultCodec.class);
+
+   // Compare through string renderings of the stored pairs.
+   Function<Tuple2<IntWritable, Text>, String> render =
+       new Function<Tuple2<IntWritable, Text>, String>() {
+         @Override
+         public String call(Tuple2<IntWritable, Text> entry) {
+           return entry.toString();
+         }
+       };
+   JavaPairRDD<IntWritable, Text> output = sc.hadoopFile(outputDir,
+       SequenceFileInputFormat.class, IntWritable.class, Text.class);
+   String actual = output.map(render).collect().toString();
+   Assert.assertEquals(pairs.toString(), actual);
+ }
+
+ /** Zips an integer RDD with a double RDD derived from it and checks the element count. */
+ @Test
+ public void zip() {
+   JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+   JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() {
+     @Override
+     public Double call(Integer x) {
+       return 1.0 * x;
+     }
+   });
+   JavaPairRDD<Integer, Double> zipped = rdd.zip(doubles);
+   // The count was previously computed but never checked.
+   Assert.assertEquals(5, zipped.count());
+ }
+
+ /** Zips the partitions of two RDDs and emits, per partition, the size of each side. */
+ @Test
+ public void zipPartitions() {
+   JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2);
+   JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("1", "2", "3", "4"), 2);
+   FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer> sizesFn =
+       new FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer>() {
+         @Override
+         public Iterable<Integer> call(Iterator<Integer> ints, Iterator<String> strings) {
+           int intCount = countRemaining(ints);
+           int stringCount = countRemaining(strings);
+           return Arrays.asList(intCount, stringCount);
+         }
+
+         // Exhausts the iterator and reports how many elements it produced.
+         private int countRemaining(Iterator<?> elements) {
+           int seen = 0;
+           for (; elements.hasNext(); elements.next()) {
+             seen++;
+           }
+           return seen;
+         }
+       };
+
+   JavaRDD<Integer> sizes = rdd1.zipPartitions(rdd2, sizesFn);
+   String rendered = sizes.collect().toString();
+   Assert.assertEquals("[3, 2, 3, 2]", rendered);
+ }
+
+ /** Adds the elements 1..5 into int, double and custom float accumulators that start at 10. */
+ @Test
+ public void accumulators() {
+   JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+
+   // Built-in integer accumulator
+   final Accumulator<Integer> intAccum = sc.intAccumulator(10);
+   VoidFunction<Integer> addToInt = new VoidFunction<Integer>() {
+     public void call(Integer element) {
+       intAccum.add(element);
+     }
+   };
+   rdd.foreach(addToInt);
+   Assert.assertEquals((Integer) 25, intAccum.value());
+
+   // Built-in double accumulator
+   final Accumulator<Double> doubleAccum = sc.doubleAccumulator(10.0);
+   VoidFunction<Integer> addToDouble = new VoidFunction<Integer>() {
+     public void call(Integer element) {
+       doubleAccum.add((double) element);
+     }
+   };
+   rdd.foreach(addToDouble);
+   Assert.assertEquals((Double) 25.0, doubleAccum.value());
+
+   // Try a custom accumulator type
+   AccumulatorParam<Float> floatParam = new AccumulatorParam<Float>() {
+     public Float addInPlace(Float left, Float right) {
+       return left + right;
+     }
+
+     public Float addAccumulator(Float left, Float right) {
+       return left + right;
+     }
+
+     public Float zero(Float initialValue) {
+       return 0.0f;
+     }
+   };
+   final Accumulator<Float> floatAccum = sc.accumulator((Float) 10.0f, floatParam);
+   VoidFunction<Integer> addToFloat = new VoidFunction<Integer>() {
+     public void call(Integer element) {
+       floatAccum.add((float) element);
+     }
+   };
+   rdd.foreach(addToFloat);
+   Assert.assertEquals((Float) 25.0f, floatAccum.value());
+
+   // Test the setValue method
+   floatAccum.setValue(5.0f);
+   Assert.assertEquals((Float) 5.0f, floatAccum.value());
+ }
+
+ /** Keys each integer by its string form and checks the resulting pairs. */
+ @Test
+ public void keyBy() {
+   JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2));
+   Function<Integer, String> asString = new Function<Integer, String>() {
+     public String call(Integer value) throws Exception {
+       return value.toString();
+     }
+   };
+   List<Tuple2<String, Integer>> keyed = rdd.keyBy(asString).collect();
+   Assert.assertEquals(new Tuple2<String, Integer>("1", 1), keyed.get(0));
+   Assert.assertEquals(new Tuple2<String, Integer>("2", 2), keyed.get(1));
+ }
+
+ /** Checkpoints an RDD and verifies it still yields the same elements afterwards. */
+ @Test
+ public void checkpointAndComputation() {
+   String checkpointDir = Files.createTempDir().getAbsolutePath();
+   List<Integer> numbers = Arrays.asList(1, 2, 3, 4, 5);
+   JavaRDD<Integer> rdd = sc.parallelize(numbers);
+   sc.setCheckpointDir(checkpointDir, true);
+   Assert.assertFalse(rdd.isCheckpointed());
+   rdd.checkpoint();
+   rdd.count(); // Forces the DAG to cause a checkpoint
+   Assert.assertTrue(rdd.isCheckpointed());
+   Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), rdd.collect());
+ }
+
+ /** Checkpoints an RDD and rebuilds an equal RDD from the checkpoint file. */
+ @Test
+ public void checkpointAndRestore() {
+   String checkpointDir = Files.createTempDir().getAbsolutePath();
+   List<Integer> numbers = Arrays.asList(1, 2, 3, 4, 5);
+   JavaRDD<Integer> rdd = sc.parallelize(numbers);
+   sc.setCheckpointDir(checkpointDir, true);
+   Assert.assertFalse(rdd.isCheckpointed());
+   rdd.checkpoint();
+   rdd.count(); // Forces the DAG to cause a checkpoint
+   Assert.assertTrue(rdd.isCheckpointed());
+
+   Assert.assertTrue(rdd.getCheckpointFile().isPresent());
+   String checkpointFile = rdd.getCheckpointFile().get();
+   JavaRDD<Integer> recovered = sc.checkpointFile(checkpointFile);
+   Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect());
+ }
+
+ /** Maps an RDD to (i, i % 2) pairs, then maps that pair RDD to the swapped pairs. */
+ @Test
+ public void mapOnPairRDD() {
+   JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1,2,3,4));
+
+   // Pairs each number with its remainder modulo 2.
+   PairFunction<Integer, Integer, Integer> withParity =
+       new PairFunction<Integer, Integer, Integer>() {
+         @Override
+         public Tuple2<Integer, Integer> call(Integer number) throws Exception {
+           return new Tuple2<Integer, Integer>(number, number % 2);
+         }
+       };
+   JavaPairRDD<Integer, Integer> rdd2 = rdd1.map(withParity);
+
+   // Exchanges the two components of each pair.
+   PairFunction<Tuple2<Integer, Integer>, Integer, Integer> flip =
+       new PairFunction<Tuple2<Integer, Integer>, Integer, Integer>() {
+         @Override
+         public Tuple2<Integer, Integer> call(Tuple2<Integer, Integer> pair) throws Exception {
+           return new Tuple2<Integer, Integer>(pair._2(), pair._1());
+         }
+       };
+   JavaPairRDD<Integer, Integer> rdd3 = rdd2.map(flip);
+
+   List<Tuple2<Integer, Integer>> expected = Arrays.asList(
+       new Tuple2<Integer, Integer>(1, 1),
+       new Tuple2<Integer, Integer>(0, 2),
+       new Tuple2<Integer, Integer>(1, 3),
+       new Tuple2<Integer, Integer>(0, 4));
+   Assert.assertEquals(expected, rdd3.collect());
+
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/test/scala/org/apache/spark/KryoSerializerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/KryoSerializerSuite.scala
new file mode 100644
index 0000000..d7b23c9
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/KryoSerializerSuite.scala
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.collection.mutable
+
+import org.scalatest.FunSuite
+import com.esotericsoftware.kryo._
+
+import KryoTest._
+
+class KryoSerializerSuite extends FunSuite with SharedSparkContext {
+ test("basic types") {
+ val ser = (new KryoSerializer).newInstance()
+ def check[T](t: T) {
+ assert(ser.deserialize[T](ser.serialize(t)) === t)
+ }
+ check(1)
+ check(1L)
+ check(1.0f)
+ check(1.0)
+ check(1.toByte)
+ check(1.toShort)
+ check("")
+ check("hello")
+ check(Integer.MAX_VALUE)
+ check(Integer.MIN_VALUE)
+ check(java.lang.Long.MAX_VALUE)
+ check(java.lang.Long.MIN_VALUE)
+ check[String](null)
+ check(Array(1, 2, 3))
+ check(Array(1L, 2L, 3L))
+ check(Array(1.0, 2.0, 3.0))
+ check(Array(1.0f, 2.9f, 3.9f))
+ check(Array("aaa", "bbb", "ccc"))
+ check(Array("aaa", "bbb", null))
+ check(Array(true, false, true))
+ check(Array('a', 'b', 'c'))
+ check(Array[Int]())
+ check(Array(Array("1", "2"), Array("1", "2", "3", "4")))
+ }
+
+ test("pairs") {
+   val ser = (new KryoSerializer).newInstance()
+   // Round-trips a value through the serializer and checks that it comes back equal.
+   def check[T](t: T) {
+     assert(ser.deserialize[T](ser.serialize(t)) === t)
+   }
+   // Every ordered combination of Int, Long and Double, each exactly once
+   check((1, 1))
+   check((1, 1L))
+   check((1L, 1))
+   check((1L, 1L))
+   check((1.0, 1))
+   check((1, 1.0))
+   check((1.0, 1.0))
+   check((1.0, 1L))
+   check((1L, 1.0))
+   // A String paired with each numeric type, in both positions, and with itself
+   check(("x", 1))
+   check(("x", 1.0))
+   check(("x", 1L))
+   check((1, "x"))
+   check((1.0, "x"))
+   check((1L, "x"))
+   check(("x", "x"))
+ }
+
+ test("Scala data structures") {
+ val ser = (new KryoSerializer).newInstance()
+ def check[T](t: T) {
+ assert(ser.deserialize[T](ser.serialize(t)) === t)
+ }
+ check(List[Int]())
+ check(List[Int](1, 2, 3))
+ check(List[String]())
+ check(List[String]("x", "y", "z"))
+ check(None)
+ check(Some(1))
+ check(Some("hi"))
+ check(mutable.ArrayBuffer(1, 2, 3))
+ check(mutable.ArrayBuffer("1", "2", "3"))
+ check(mutable.Map())
+ check(mutable.Map(1 -> "one", 2 -> "two"))
+ check(mutable.Map("one" -> 1, "two" -> 2))
+ check(mutable.HashMap(1 -> "one", 2 -> "two"))
+ check(mutable.HashMap("one" -> 1, "two" -> 2))
+ check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4))))
+ check(List(mutable.HashMap("one" -> 1, "two" -> 2),mutable.HashMap(1->"one",2->"two",3->"three")))
+ }
+
+ test("custom registrator") {
+   // Install the registrator for this test and put the previous value back afterwards,
+   // even when an assertion fails. Restoring (rather than clearing) keeps whatever
+   // beforeAll() configured for the remaining tests of the suite.
+   val registratorKey = "spark.kryo.registrator"
+   val previous = System.getProperty(registratorKey)
+   System.setProperty(registratorKey, classOf[MyRegistrator].getName)
+   try {
+     val ser = (new KryoSerializer).newInstance()
+     // Round-trips a value through the serializer and checks that it comes back equal.
+     def check[T](t: T) {
+       assert(ser.deserialize[T](ser.serialize(t)) === t)
+     }
+
+     check(CaseClass(17, "hello"))
+
+     val c1 = new ClassWithNoArgConstructor
+     c1.x = 32
+     check(c1)
+
+     val c2 = new ClassWithoutNoArgConstructor(47)
+     check(c2)
+
+     val hashMap = new java.util.HashMap[String, String]
+     hashMap.put("foo", "bar")
+     check(hashMap)
+   } finally {
+     if (previous == null) {
+       System.clearProperty(registratorKey)
+     } else {
+       System.setProperty(registratorKey, previous)
+     }
+   }
+ }
+
+ test("kryo with collect") {
+ val control = 1 :: 2 :: Nil
+ val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)).collect().map(_.x)
+ assert(control === result.toSeq)
+ }
+
+ test("kryo with parallelize") {
+ val control = 1 :: 2 :: Nil
+ val result = sc.parallelize(control.map(new ClassWithoutNoArgConstructor(_))).map(_.x).collect()
+ assert (control === result.toSeq)
+ }
+
+ test("kryo with parallelize for specialized tuples") {
+ assert (sc.parallelize( Array((1, 11), (2, 22), (3, 33)) ).count === 3)
+ }
+
+ test("kryo with parallelize for primitive arrays") {
+ assert (sc.parallelize( Array(1, 2, 3) ).count === 3)
+ }
+
+ test("kryo with collect for specialized tuples") {
+ assert (sc.parallelize( Array((1, 11), (2, 22), (3, 33)) ).collect().head === (1, 11))
+ }
+
+ test("kryo with reduce") {
+ val control = 1 :: 2 :: Nil
+ val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_))
+ .reduce((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x
+ assert(control.sum === result)
+ }
+
+ // TODO: this still doesn't work
+ ignore("kryo with fold") {
+ val control = 1 :: 2 :: Nil
+ val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_))
+ .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x
+ assert(10 + control.sum === result)
+ }
+
+ // Selects Kryo as the serializer and installs the test registrator through system
+ // properties, then delegates to super.beforeAll(). The properties are set first,
+ // presumably so that the shared SparkContext picks them up when it is created
+ // (TODO confirm against SharedSparkContext).
+ override def beforeAll() {
+ System.setProperty("spark.serializer", "org.apache.spark.KryoSerializer")
+ System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName)
+ super.beforeAll()
+ }
+
+ // Runs the inherited teardown first, then removes the two system properties that
+ // beforeAll() set so they are no longer present once this suite has finished.
+ override def afterAll() {
+ super.afterAll()
+ System.clearProperty("spark.kryo.registrator")
+ System.clearProperty("spark.serializer")
+ }
+}
+
+/** Fixture classes and the Kryo registrator used by KryoSerializerSuite. */
+object KryoTest {
+  case class CaseClass(i: Int, s: String) {}
+
+  /** Mutable holder with only the default no-argument constructor; equality is by `x`. */
+  class ClassWithNoArgConstructor {
+    var x: Int = 0
+    override def equals(other: Any) = other match {
+      case c: ClassWithNoArgConstructor => x == c.x
+      case _ => false
+    }
+    // Must agree with equals: instances that compare equal have the same hash code.
+    override def hashCode(): Int = x
+  }
+
+  /** Immutable holder whose only constructor takes an argument; equality is by `x`. */
+  class ClassWithoutNoArgConstructor(val x: Int) {
+    override def equals(other: Any) = other match {
+      case c: ClassWithoutNoArgConstructor => x == c.x
+      case _ => false
+    }
+    // Must agree with equals: instances that compare equal have the same hash code.
+    override def hashCode(): Int = x
+  }
+
+  /** Registers the fixture classes above, plus java.util.HashMap, with Kryo. */
+  class MyRegistrator extends KryoRegistrator {
+    override def registerClasses(k: Kryo) {
+      k.register(classOf[CaseClass])
+      k.register(classOf[ClassWithNoArgConstructor])
+      k.register(classOf[ClassWithoutNoArgConstructor])
+      k.register(classOf[java.util.HashMap[_, _]])
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
new file mode 100644
index 0000000..6ec124d
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.Suite
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.BeforeAndAfterAll
+
+import org.jboss.netty.logging.InternalLoggerFactory
+import org.jboss.netty.logging.Slf4JLoggerFactory
+
+/**
+ * Manages a local `sc` [[SparkContext]] variable for a test suite: whatever context a test
+ * stored in `sc` is stopped and the variable reset to null after each test.
+ */
+trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite =>
+
+  @transient var sc: SparkContext = _
+
+  /** Routes Netty's internal logging through SLF4J before the suite starts. */
+  override def beforeAll() {
+    InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory())
+    super.beforeAll()
+  }
+
+  /** Stops the context used by the test that just finished, then runs the inherited hook. */
+  override def afterEach() {
+    resetSparkContext()
+    super.afterEach()
+  }
+
+  /** Stops `sc` if one is set and leaves the variable null. */
+  def resetSparkContext() = {
+    Option(sc).foreach(LocalSparkContext.stop)
+    sc = null
+  }
+
+}
+
+object LocalSparkContext {
+  /** Stops `sc` and clears the port-related system properties. */
+  def stop(sc: SparkContext) {
+    sc.stop()
+    // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+    for (key <- Seq("spark.driver.port", "spark.hostPort")) {
+      System.clearProperty(key)
+    }
+  }
+
+  /** Runs `f` by passing in `sc` and ensures that `sc` is stopped, even if `f` throws. */
+  def withSpark[T](sc: SparkContext)(f: SparkContext => T) = {
+    try f(sc) finally stop(sc)
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
new file mode 100644
index 0000000..6013320
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+
+import akka.actor._
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.AkkaUtils
+
+class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
+
+ // Pins MapOutputTracker.compressSize to known outputs for a range of input sizes.
+ test("compressSize") {
+ assert(MapOutputTracker.compressSize(0L) === 0)
+ assert(MapOutputTracker.compressSize(1L) === 1)
+ assert(MapOutputTracker.compressSize(2L) === 8)
+ assert(MapOutputTracker.compressSize(10L) === 25)
+ // From here on the result is masked with 0xFF before comparing, presumably because
+ // compressSize returns a signed Byte and the expected values exceed 127 (TODO confirm).
+ assert((MapOutputTracker.compressSize(1000000L) & 0xFF) === 145)
+ assert((MapOutputTracker.compressSize(1000000000L) & 0xFF) === 218)
+ // This last size is bigger than we can encode in a byte, so check that we just return 255
+ assert((MapOutputTracker.compressSize(1000000000000000000L) & 0xFF) === 255)
+ }
+
+ test("decompressSize") {
+   assert(MapOutputTracker.decompressSize(0) === 0)
+   // A compress/decompress round trip must land within [0.99, 1.11] times the original size.
+   Seq(2L, 10L, 100L, 50000L, 1000000L, 1000000000L).foreach { original =>
+     val roundTripped = MapOutputTracker.decompressSize(MapOutputTracker.compressSize(original))
+     val withinBounds = roundTripped >= 0.99 * original && roundTripped <= 1.11 * original
+     assert(withinBounds,
+       "size " + original + " decompressed to " + roundTripped + ", which is out of range")
+   }
+ }
+
+ test("master start and stop") {
+   val actorSystem = ActorSystem("test")
+   try {
+     val tracker = new MapOutputTracker()
+     tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+     tracker.stop()
+   } finally {
+     // Shut the actor system down so it does not outlive the test.
+     actorSystem.shutdown()
+   }
+ }
+
+ test("master register and fetch") {
+   val actorSystem = ActorSystem("test")
+   try {
+     val tracker = new MapOutputTracker()
+     tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+     tracker.registerShuffle(10, 2)
+     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+     val compressedSize10000 = MapOutputTracker.compressSize(10000L)
+     // Sizes are compared after a compress/decompress round trip, not as raw values.
+     val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+     val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
+     tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
+       Array(compressedSize1000, compressedSize10000)))
+     tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
+       Array(compressedSize10000, compressedSize1000)))
+     // Reduce partition 0 reads the first size of each registered map output.
+     val statuses = tracker.getServerStatuses(10, 0)
+     assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000, 0), size1000),
+       (BlockManagerId("b", "hostB", 1000, 0), size10000)))
+     tracker.stop()
+   } finally {
+     // Shut the actor system down so it does not outlive the test.
+     actorSystem.shutdown()
+   }
+ }
+
+ test("master register and unregister and fetch") {
+   val actorSystem = ActorSystem("test")
+   try {
+     val tracker = new MapOutputTracker()
+     tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+     tracker.registerShuffle(10, 2)
+     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+     val compressedSize10000 = MapOutputTracker.compressSize(10000L)
+     val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+     val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
+     tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
+       Array(compressedSize1000, compressedSize1000, compressedSize1000)))
+     tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
+       Array(compressedSize10000, compressedSize1000, compressedSize1000)))
+
+     // As if we had two simultaneous fetch failures
+     tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
+     tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
+
+     // The remaining reduce task might try to grab the output despite the shuffle failure;
+     // this should cause it to fail, and the scheduler will ignore the failure due to the
+     // stage already being aborted.
+     intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) }
+
+     // Stop the tracker like the other master tests do.
+     tracker.stop()
+   } finally {
+     // Shut the actor system down so it does not outlive the test.
+     actorSystem.shutdown()
+   }
+ }
+
+ test("remote fetch") {
+   val hostname = "localhost"
+   val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0)
+   try {
+     // LocalSparkContext only clears these properties when a SparkContext was stored in
+     // `sc`, which this test never does, so they are cleared explicitly in the finally below.
+     System.setProperty("spark.driver.port", boundPort.toString)
+     System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+
+     val masterTracker = new MapOutputTracker()
+     masterTracker.trackerActor = actorSystem.actorOf(
+       Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker")
+
+     val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0)
+     try {
+       // The slave tracker talks to the master's actor through a remote reference.
+       val slaveTracker = new MapOutputTracker()
+       slaveTracker.trackerActor = slaveSystem.actorFor(
+         "akka://spark@localhost:" + boundPort + "/user/MapOutputTracker")
+
+       // Shuffle registered but no map output yet: fetching must fail.
+       masterTracker.registerShuffle(10, 1)
+       masterTracker.incrementEpoch()
+       slaveTracker.updateEpoch(masterTracker.getEpoch)
+       intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+
+       // Once the output is registered, the slave sees it after the epoch update.
+       val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+       val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+       masterTracker.registerMapOutput(10, 0, new MapStatus(
+         BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000)))
+       masterTracker.incrementEpoch()
+       slaveTracker.updateEpoch(masterTracker.getEpoch)
+       assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
+         Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+
+       // After unregistering, fetching must fail again.
+       masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
+       masterTracker.incrementEpoch()
+       slaveTracker.updateEpoch(masterTracker.getEpoch)
+       intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+
+       // failure should be cached
+       intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+     } finally {
+       slaveSystem.shutdown()
+     }
+   } finally {
+     actorSystem.shutdown()
+     System.clearProperty("spark.driver.port")
+     System.clearProperty("spark.hostPort")
+   }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/test/scala/org/apache/spark/PairRDDFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/PairRDDFunctionsSuite.scala
new file mode 100644
index 0000000..f79752b
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/PairRDDFunctionsSuite.scala
@@ -0,0 +1,299 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashSet
+
+import org.scalatest.FunSuite
+
+import com.google.common.io.Files
+import org.apache.spark.SparkContext._
+
+
+class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
+ test("groupByKey") {
+   val input = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
+   val grouped = input.groupByKey().collect()
+   // Sorted values collected for one key; fails if the key is absent.
+   def sortedValuesFor(key: Int) = grouped.find(_._1 == key).get._2.toList.sorted
+   assert(grouped.size === 2)
+   assert(sortedValuesFor(1) === List(1, 2, 3))
+   assert(sortedValuesFor(2) === List(1))
+ }
+
+ test("groupByKey with duplicates") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val groups = pairs.groupByKey().collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("groupByKey with negative key hash codes") {
+ val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1)))
+ val groups = pairs.groupByKey().collect()
+ assert(groups.size === 2)
+ val valuesForMinus1 = groups.find(_._1 == -1).get._2
+ assert(valuesForMinus1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("groupByKey with many output partitions") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
+ val groups = pairs.groupByKey(10).collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("reduceByKey") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.reduceByKey(_+_).collect()
+ assert(sums.toSet === Set((1, 7), (2, 1)))
+ }
+
+ test("reduceByKey with collectAsMap") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.reduceByKey(_+_).collectAsMap()
+ assert(sums.size === 2)
+ assert(sums(1) === 7)
+ assert(sums(2) === 1)
+ }
+
+ test("reduceByKey with many output partitons") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.reduceByKey(_+_, 10).collect()
+ assert(sums.toSet === Set((1, 7), (2, 1)))
+ }
+
+ test("reduceByKey with partitioner") {
+   // Two partitions; each Int key is used directly as its partition index
+   val keyAsPartition = new Partitioner() {
+     def numPartitions = 2
+     def getPartition(key: Any) = key.asInstanceOf[Int]
+   }
+   val input = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(keyAsPartition)
+   val sums = input.reduceByKey(_+_)
+   assert(sums.collect().toSet === Set((1, 4), (0, 1)))
+   assert(sums.partitioner === Some(keyAsPartition))
+   // count the dependencies to make sure there is only 1 ShuffledRDD
+   val ancestors = new HashSet[RDD[_]]()
+   // Walks the lineage recursively, recording every RDD reachable through dependencies.
+   def collectAncestors(rdd: RDD[_]) {
+     rdd.dependencies.foreach { dep =>
+       ancestors += dep.rdd
+       collectAncestors(dep.rdd)
+     }
+   }
+   collectAncestors(sums)
+   assert(ancestors.size === 2) // ShuffledRDD, ParallelCollection
+ }
+
+ test("join") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.join(rdd2).collect()
+ assert(joined.size === 4)
+ assert(joined.toSet === Set(
+ (1, (1, 'x')),
+ (1, (2, 'x')),
+ (2, (1, 'y')),
+ (2, (1, 'z'))
+ ))
+ }
+
+ test("join all-to-all") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y')))
+ val joined = rdd1.join(rdd2).collect()
+ assert(joined.size === 6)
+ assert(joined.toSet === Set(
+ (1, (1, 'x')),
+ (1, (1, 'y')),
+ (1, (2, 'x')),
+ (1, (2, 'y')),
+ (1, (3, 'x')),
+ (1, (3, 'y'))
+ ))
+ }
+
+ test("leftOuterJoin") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.leftOuterJoin(rdd2).collect()
+ assert(joined.size === 5)
+ assert(joined.toSet === Set(
+ (1, (1, Some('x'))),
+ (1, (2, Some('x'))),
+ (2, (1, Some('y'))),
+ (2, (1, Some('z'))),
+ (3, (1, None))
+ ))
+ }
+
+ test("rightOuterJoin") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.rightOuterJoin(rdd2).collect()
+ assert(joined.size === 5)
+ assert(joined.toSet === Set(
+ (1, (Some(1), 'x')),
+ (1, (Some(2), 'x')),
+ (2, (Some(1), 'y')),
+ (2, (Some(1), 'z')),
+ (4, (None, 'w'))
+ ))
+ }
+
+ test("join with no matches") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
+ val joined = rdd1.join(rdd2).collect()
+ assert(joined.size === 0)
+ }
+
+ test("join with many output partitions") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.join(rdd2, 10).collect()
+ assert(joined.size === 4)
+ assert(joined.toSet === Set(
+ (1, (1, 'x')),
+ (1, (2, 'x')),
+ (2, (1, 'y')),
+ (2, (1, 'z'))
+ ))
+ }
+
+ test("groupWith") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.groupWith(rdd2).collect()
+ assert(joined.size === 4)
+ assert(joined.toSet === Set(
+ (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))),
+ (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))),
+ (3, (ArrayBuffer(1), ArrayBuffer())),
+ (4, (ArrayBuffer(), ArrayBuffer('w')))
+ ))
+ }
+
+ test("zero-partition RDD") {
+   val emptyDir = Files.createTempDir()
+   // The directory stays empty, so it can simply be removed when the JVM exits.
+   emptyDir.deleteOnExit()
+   val file = sc.textFile(emptyDir.getAbsolutePath)
+   // === (like the rest of this suite) reports both values when the assertion fails
+   assert(file.partitions.size === 0)
+   assert(file.collect().toList === Nil)
+   // Test that a shuffle on the file works, because this used to be a bug
+   assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
+ }
+
+ test("keys and values") {
+ val rdd = sc.parallelize(Array((1, "a"), (2, "b")))
+ assert(rdd.keys.collect().toList === List(1, 2))
+ assert(rdd.values.collect().toList === List("a", "b"))
+ }
+
+ test("default partitioner uses partition size") {
+   val numPartitions = 2000
+   // start from an RDD with 2000 partitions
+   val numbers = sc.makeRDD(Array(1, 2, 3, 4), numPartitions)
+   // do a map, which loses the partitioner
+   val pairs = numbers.map(n => (n, (n * 2).toString))
+   // then a group by, and see we didn't revert to 2 partitions
+   val grouped = pairs.groupByKey()
+   assert(grouped.partitions.size === numPartitions)
+ }
+
+ test("default partitioner uses largest partitioner") {
+ val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2)
+ val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000)
+ val c = a.join(b)
+ assert(c.partitions.size === 2000)
+ }
+
+ test("subtract") {
+ val a = sc.parallelize(Array(1, 2, 3), 2)
+ val b = sc.parallelize(Array(2, 3, 4), 4)
+ val c = a.subtract(b)
+ assert(c.collect().toSet === Set(1))
+ assert(c.partitions.size === a.partitions.size)
+ }
+
+ test("subtract with narrow dependency") {
+ // use a deterministic partitioner
+ val p = new Partitioner() {
+ def numPartitions = 5
+ def getPartition(key: Any) = key.asInstanceOf[Int]
+ }
+ // partitionBy so we have a narrow dependency
+ val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
+ // more partitions/no partitioner so a shuffle dependency
+ val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
+ val c = a.subtract(b)
+ assert(c.collect().toSet === Set((1, "a"), (3, "c")))
+ // Ideally we could keep the original partitioner...
+ assert(c.partitioner === None)
+ }
+
+ test("subtractByKey") {
+ val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2)
+ val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4)
+ val c = a.subtractByKey(b)
+ assert(c.collect().toSet === Set((1, "a"), (1, "a")))
+ assert(c.partitions.size === a.partitions.size)
+ }
+
+ test("subtractByKey with narrow dependency") {
+ // use a deterministic partitioner
+ val p = new Partitioner() {
+ def numPartitions = 5
+ def getPartition(key: Any) = key.asInstanceOf[Int]
+ }
+ // partitionBy so we have a narrow dependency
+ val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
+ // more partitions/no partitioner so a shuffle dependency
+ val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
+ val c = a.subtractByKey(b)
+ assert(c.collect().toSet === Set((1, "a"), (1, "a")))
+ assert(c.partitioner.get === p)
+ }
+
+ test("foldByKey") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.foldByKey(0)(_+_).collect()
+ assert(sums.toSet === Set((1, 7), (2, 1)))
+ }
+
+ test("foldByKey with mutable result type") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val bufs = pairs.mapValues(v => ArrayBuffer(v)).cache()
+ // Fold the values using in-place mutation
+ val sums = bufs.foldByKey(new ArrayBuffer[Int])(_ ++= _).collect()
+ assert(sums.toSet === Set((1, ArrayBuffer(1, 2, 3, 1)), (2, ArrayBuffer(1))))
+ // Check that the mutable objects in the original RDD were not changed
+ assert(bufs.collect().toSet === Set(
+ (1, ArrayBuffer(1)),
+ (1, ArrayBuffer(2)),
+ (1, ArrayBuffer(3)),
+ (1, ArrayBuffer(1)),
+ (2, ArrayBuffer(1))))
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala
new file mode 100644
index 0000000..adbe805
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala
@@ -0,0 +1,28 @@
+package org.apache.spark
+
+import org.scalatest.FunSuite
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd.PartitionPruningRDD
+
+
+class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext {
+
+  test("Pruned Partitions inherit locality prefs correctly") {
+    // Minimal partition whose index is whatever it was constructed with.
+    class TestPartition(i: Int) extends Partition {
+      def index = i
+    }
+    // Parent RDD with three partitions (indices 1, 2, 3) and no data.
+    val parent = new RDD[Int](sc, Nil) {
+      override protected def getPartitions = {
+        Array[Partition](
+          new TestPartition(1),
+          new TestPartition(2),
+          new TestPartition(3))
+      }
+      def compute(split: Partition, context: TaskContext) = Iterator()
+    }
+    // Keep only what the filter accepts for the value 2.
+    val pruned = PartitionPruningRDD.create(parent, _ == 2)
+    val first = pruned.partitions(0)
+    assert(first.index == 2)
+    assert(pruned.partitions.length == 1)
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
new file mode 100644
index 0000000..7669cf6
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+import scala.collection.mutable.ArrayBuffer
+import SparkContext._
+import org.apache.spark.util.StatCounter
+import scala.math.abs
+
+class PartitioningSuite extends FunSuite with SharedSparkContext {
+
+ test("HashPartitioner equality") {
+   val two = new HashPartitioner(2)
+   val four = new HashPartitioner(4)
+   val fourAgain = new HashPartitioner(4)
+   // Reflexive
+   assert(two === two)
+   assert(four === four)
+   // Different partition counts are unequal, in both directions
+   assert(two != four)
+   assert(four != two)
+   // Distinct instances with the same partition count are equal, in both directions
+   assert(four === fourAgain)
+   assert(fourAgain === four)
+ }
+
+ test("RangePartitioner equality") {
+ // Make an RDD where all the elements are the same so that the partition range bounds
+ // are deterministically all the same.
+ val rdd = sc.parallelize(Seq(1, 1, 1, 1)).map(x => (x, x))
+
+ val p2 = new RangePartitioner(2, rdd)
+ val p4 = new RangePartitioner(4, rdd)
+ val anotherP4 = new RangePartitioner(4, rdd)
+ val descendingP2 = new RangePartitioner(2, rdd, false)
+ val descendingP4 = new RangePartitioner(4, rdd, false)
+
+ assert(p2 === p2)
+ assert(p4 === p4)
+ assert(p2 != p4)
+ assert(p4 != p2)
+ assert(p4 === anotherP4)
+ assert(anotherP4 === p4)
+ assert(descendingP2 === descendingP2)
+ assert(descendingP4 === descendingP4)
+ assert(descendingP2 != descendingP4)
+ assert(descendingP4 != descendingP2)
+ assert(p2 != descendingP2)
+ assert(p4 != descendingP4)
+ assert(descendingP2 != p2)
+ assert(descendingP4 != p4)
+ }
+
+ test("HashPartitioner not equal to RangePartitioner") {
+ val rdd = sc.parallelize(1 to 10).map(x => (x, x))
+ val rangeP2 = new RangePartitioner(2, rdd)
+ val hashP2 = new HashPartitioner(2)
+ assert(rangeP2 === rangeP2)
+ assert(hashP2 === hashP2)
+ assert(hashP2 != rangeP2)
+ assert(rangeP2 != hashP2)
+ }
+
+ test("partitioner preservation") {
+ val rdd = sc.parallelize(1 to 10, 4).map(x => (x, x))
+
+ val grouped2 = rdd.groupByKey(2)
+ val grouped4 = rdd.groupByKey(4)
+ val reduced2 = rdd.reduceByKey(_ + _, 2)
+ val reduced4 = rdd.reduceByKey(_ + _, 4)
+
+ assert(rdd.partitioner === None)
+
+ assert(grouped2.partitioner === Some(new HashPartitioner(2)))
+ assert(grouped4.partitioner === Some(new HashPartitioner(4)))
+ assert(reduced2.partitioner === Some(new HashPartitioner(2)))
+ assert(reduced4.partitioner === Some(new HashPartitioner(4)))
+
+ assert(grouped2.groupByKey().partitioner === grouped2.partitioner)
+ assert(grouped2.groupByKey(3).partitioner != grouped2.partitioner)
+ assert(grouped2.groupByKey(2).partitioner === grouped2.partitioner)
+ assert(grouped4.groupByKey().partitioner === grouped4.partitioner)
+ assert(grouped4.groupByKey(3).partitioner != grouped4.partitioner)
+ assert(grouped4.groupByKey(4).partitioner === grouped4.partitioner)
+
+ assert(grouped2.join(grouped4).partitioner === grouped4.partitioner)
+ assert(grouped2.leftOuterJoin(grouped4).partitioner === grouped4.partitioner)
+ assert(grouped2.rightOuterJoin(grouped4).partitioner === grouped4.partitioner)
+ assert(grouped2.cogroup(grouped4).partitioner === grouped4.partitioner)
+
+ assert(grouped2.join(reduced2).partitioner === grouped2.partitioner)
+ assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner)
+ assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner)
+ assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner)
+
+ assert(grouped2.map(_ => 1).partitioner === None)
+ assert(grouped2.mapValues(_ => 1).partitioner === grouped2.partitioner)
+ assert(grouped2.flatMapValues(_ => Seq(1)).partitioner === grouped2.partitioner)
+ assert(grouped2.filter(_._1 > 4).partitioner === grouped2.partitioner)
+ }
+
+ test("partitioning Java arrays should fail") {
+   val arrs: RDD[Array[Int]] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => Array(x))
+   val arrPairs: RDD[(Array[Int], Int)] =
+     sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x))
+
+   // Runs `op` and requires a SparkException whose message mentions "array".
+   def assertRejectsArrays(op: => Any) {
+     assert(intercept[SparkException]{ op }.getMessage.contains("array"))
+   }
+
+   assertRejectsArrays(arrs.distinct())
+   // We can't catch all usages of arrays, since they might occur inside other collections:
+   //assert(fails { arrPairs.distinct() })
+   assertRejectsArrays(arrPairs.partitionBy(new HashPartitioner(2)))
+   assertRejectsArrays(arrPairs.join(arrPairs))
+   assertRejectsArrays(arrPairs.leftOuterJoin(arrPairs))
+   assertRejectsArrays(arrPairs.rightOuterJoin(arrPairs))
+   assertRejectsArrays(arrPairs.groupByKey())
+   assertRejectsArrays(arrPairs.countByKey())
+   assertRejectsArrays(arrPairs.countByKeyApprox(1))
+   assertRejectsArrays(arrPairs.cogroup(arrPairs))
+   assertRejectsArrays(arrPairs.reduceByKeyLocally(_ + _))
+   assertRejectsArrays(arrPairs.reduceByKey(_ + _))
+ }
+
+ test("zero-length partitions should be correctly handled") {
+   // Create RDD with some consecutive empty partitions (including the "first" one)
+   val rdd: RDD[Double] = sc
+     .parallelize(Array(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8)
+     .filter(_ >= 0.0)
+
+   // True when `actual` is within 0.01 of `expected`.
+   def near(expected: Double, actual: Double) = abs(expected - actual) < 0.01
+
+   // Run the partitions, including the consecutive empty ones, through StatCounter
+   val stats: StatCounter = rdd.stats()
+   assert(near(6.0, stats.sum))
+   assert(near(6.0/2, rdd.mean))
+   assert(near(1.0, rdd.variance))
+   assert(near(1.0, rdd.stdev))
+
+   // Add other tests here for classes that should be able to handle empty partitions correctly
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
new file mode 100644
index 0000000..2e851d8
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+import SparkContext._
+
+/**
+ * Tests for `RDD.pipe`, which hands the elements of an RDD to an external command and
+ * returns the lines that command prints.
+ *
+ * The tests launch `cat`, `printenv` and `sh`, so they require a Unix-like environment.
+ * Every test uses two partitions; the expected outputs below show one command invocation
+ * per partition.
+ */
+class PipedRDDSuite extends FunSuite with SharedSparkContext {
+
+  test("basic pipe") {
+    val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+
+    // `cat` echoes its input, so every element must come back as its string form, in order.
+    val output = nums.pipe(Seq("cat")).collect()
+
+    assert(output.toList === List("1", "2", "3", "4"))
+  }
+
+  test("advanced pipe") {
+    val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+    val bl = sc.broadcast(List("0"))
+
+    // Header writer shared by both pipelines below: emits each broadcast value followed by a
+    // "\u0001" separator. The expected outputs show it runs once per partition, before that
+    // partition's elements.
+    val writeHeader = (emit: String => Unit) => {
+      bl.value.foreach(emit) // foreach, not map: the calls are made purely for their side effect
+      emit("\u0001")
+    }
+
+    // Custom element writer: each Int is written with a trailing underscore.
+    val piped = nums.pipe(
+      Seq("cat"),
+      Map[String, String](),
+      writeHeader,
+      (elem: Int, emit: String => Unit) => emit(elem + "_"))
+
+    val c = piped.collect()
+
+    // Two partitions, each contributing: header ("0", "\u0001") followed by its two elements.
+    assert(c.toList === List(
+      "0", "\u0001", "1_", "2_",
+      "0", "\u0001", "3_", "4_"))
+
+    // Same pipeline over grouped data: every value of a group is written on its own line.
+    val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2)
+    val d = nums1
+      .groupBy(str => str.split("\t")(0))
+      .pipe(
+        Seq("cat"),
+        Map[String, String](),
+        writeHeader,
+        (group: (String, Seq[String]), emit: String => Unit) =>
+          group._2.foreach(value => emit(value + "_")))
+      .collect()
+
+    // The "b" group is expected before the "a" group; presumably this follows from how the
+    // keys are assigned to the two partitions by groupBy.
+    assert(d.toList === List(
+      "0", "\u0001", "b\t2_", "b\t4_",
+      "0", "\u0001", "a\t1_", "a\t3_"))
+  }
+
+  test("pipe with env variable") {
+    val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+
+    // The second argument supplies extra environment variables for the child process;
+    // `printenv` prints the variable once per invocation, i.e. once per partition.
+    val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
+
+    assert(piped.collect().toList === List("LALALA", "LALALA"))
+  }
+
+  test("pipe with non-zero exit status") {
+    val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+
+    // The command Seq is an argument vector, not a shell command line: a single element
+    // "cat nonexistent_file" would name a program that does not exist (a launch failure,
+    // not a non-zero exit), and "2>" would be passed as a literal argument instead of being
+    // treated as a redirect. Running the line through `sh -c` makes `cat` actually start,
+    // fail on the missing file, and exit non-zero with its stderr discarded.
+    val piped = nums.pipe(Seq("sh", "-c", "cat nonexistent_file 2> /dev/null"))
+
+    intercept[SparkException] {
+      piped.collect()
+    }
+  }
+
+}