You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/04/30 00:34:10 UTC

spark git commit: [SPARK-7156][SQL] support RandomSplit in DataFrames

Repository: spark
Updated Branches:
  refs/heads/master c9d530e2e -> d7dbce8f7


[SPARK-7156][SQL] support RandomSplit in DataFrames

This is built on top of kaka1992's PR #5711, using logical plans.

Author: Burak Yavuz <br...@gmail.com>

Closes #5761 from brkyvz/random-sample and squashes the following commits:

a1fb0aa [Burak Yavuz] remove unrelated file
69669c3 [Burak Yavuz] fix broken test
1ddb3da [Burak Yavuz] copy base
6000328 [Burak Yavuz] added python api and fixed test
3c11d1b [Burak Yavuz] fixed broken test
f400ade [Burak Yavuz] fix build errors
2384266 [Burak Yavuz] addressed comments v0.1
e98ebac [Burak Yavuz] [SPARK-7156][SQL] support RandomSplit in DataFrames


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d7dbce8f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d7dbce8f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d7dbce8f

Branch: refs/heads/master
Commit: d7dbce8f7da8a7fd01df6633a6043f51161b7d18
Parents: c9d530e
Author: Burak Yavuz <br...@gmail.com>
Authored: Wed Apr 29 15:34:05 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Wed Apr 29 15:34:05 2015 -0700

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/rdd/RDD.scala   | 19 ++++++++--
 .../java/org/apache/spark/JavaAPISuite.java     |  8 ++---
 python/pyspark/sql/dataframe.py                 | 18 +++++++++-
 .../apache/spark/sql/catalyst/dsl/package.scala |  6 ----
 .../catalyst/plans/logical/basicOperators.scala | 18 ++++++++--
 .../scala/org/apache/spark/sql/DataFrame.scala  | 38 +++++++++++++++++++-
 .../spark/sql/execution/SparkStrategies.scala   |  4 +--
 .../spark/sql/execution/basicOperators.scala    | 20 +++++++++--
 .../org/apache/spark/sql/DataFrameSuite.scala   | 17 +++++++++
 .../org/apache/spark/sql/hive/HiveQl.scala      |  4 +--
 10 files changed, 130 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d7dbce8f/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index d80d94a..330255f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -407,12 +407,27 @@ abstract class RDD[T: ClassTag](
     val sum = weights.sum
     val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
     normalizedCumWeights.sliding(2).map { x =>
-      new PartitionwiseSampledRDD[T, T](
-        this, new BernoulliCellSampler[T](x(0), x(1)), true, seed)
+      randomSampleWithRange(x(0), x(1), seed)
     }.toArray
   }
 
   /**
+   * Internal method exposed for Random Splits in DataFrames. Samples an RDD given a probability
+   * range.
+   * @param lb lower bound to use for the Bernoulli sampler
+   * @param ub upper bound to use for the Bernoulli sampler
+   * @param seed the seed for the Random number generator
+   * @return A random sub-sample of the RDD without replacement.
+   */
+  private[spark] def randomSampleWithRange(lb: Double, ub: Double, seed: Long): RDD[T] = {
+    this.mapPartitionsWithIndex { case (index, partition) =>
+      val sampler = new BernoulliCellSampler[T](lb, ub)
+      sampler.setSeed(seed + index)
+      sampler.sample(partition)
+    }
+  }
+
+  /**
    * Return a fixed-size sampled subset of this RDD in an array
    *
    * @param withReplacement whether sampling is done with replacement

http://git-wip-us.apache.org/repos/asf/spark/blob/d7dbce8f/core/src/test/java/org/apache/spark/JavaAPISuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 34ac936..c2089b0 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -157,11 +157,11 @@ public class JavaAPISuite implements Serializable {
   public void randomSplit() {
     List<Integer> ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
     JavaRDD<Integer> rdd = sc.parallelize(ints);
-    JavaRDD<Integer>[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 11);
+    JavaRDD<Integer>[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 31);
     Assert.assertEquals(3, splits.length);
-    Assert.assertEquals(2, splits[0].count());
-    Assert.assertEquals(3, splits[1].count());
-    Assert.assertEquals(5, splits[2].count());
+    Assert.assertEquals(1, splits[0].count());
+    Assert.assertEquals(2, splits[1].count());
+    Assert.assertEquals(7, splits[2].count());
   }
 
   @Test

http://git-wip-us.apache.org/repos/asf/spark/blob/d7dbce8f/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index d9cbbc6..3074af3 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -426,7 +426,7 @@ class DataFrame(object):
     def sample(self, withReplacement, fraction, seed=None):
         """Returns a sampled subset of this :class:`DataFrame`.
 
-        >>> df.sample(False, 0.5, 97).count()
+        >>> df.sample(False, 0.5, 42).count()
         1
         """
         assert fraction >= 0.0, "Negative fraction value: %s" % fraction
@@ -434,6 +434,22 @@ class DataFrame(object):
         rdd = self._jdf.sample(withReplacement, fraction, long(seed))
         return DataFrame(rdd, self.sql_ctx)
 
+    def randomSplit(self, weights, seed=None):
+        """Randomly splits this :class:`DataFrame` with the provided weights.
+
+        >>> splits = df4.randomSplit([1.0, 2.0], 24)
+        >>> splits[0].count()
+        1
+
+        >>> splits[1].count()
+        3
+        """
+        for w in weights:
+            assert w >= 0.0, "Negative weight value: %s" % w
+        seed = seed if seed is not None else random.randint(0, sys.maxsize)
+        rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed))
+        return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]
+
     @property
     def dtypes(self):
         """Returns all column names and their data types as a list.

http://git-wip-us.apache.org/repos/asf/spark/blob/d7dbce8f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 5d5aba9..fa6cc7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -278,12 +278,6 @@ package object dsl {
     def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean): LogicalPlan =
       Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)
 
-    def sample(
-        fraction: Double,
-        withReplacement: Boolean = true,
-        seed: Int = (math.random * 1000).toInt): LogicalPlan =
-      Sample(fraction, withReplacement, seed, logicalPlan)
-
     // TODO specify the output column names
     def generate(
         generator: Generator,

http://git-wip-us.apache.org/repos/asf/spark/blob/d7dbce8f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 608e272..21208c8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -300,8 +300,22 @@ case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil))
 }
 
-case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
-    extends UnaryNode {
+/**
+ * Sample the dataset.
+ *
+ * @param lowerBound Lower-bound of the sampling probability (usually 0.0)
+ * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
+ *                   will be ub - lb.
+ * @param withReplacement Whether to sample with replacement.
+ * @param seed the random seed
+ * @param child the LogicalPlan
+ */
+case class Sample(
+    lowerBound: Double,
+    upperBound: Double,
+    withReplacement: Boolean,
+    seed: Long,
+    child: LogicalPlan) extends UnaryNode {
 
   override def output: Seq[Attribute] = child.output
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d7dbce8f/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 0e896e5..0d02e14 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -706,7 +706,7 @@ class DataFrame private[sql](
    * @group dfops
    */
   def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
-    Sample(fraction, withReplacement, seed, logicalPlan)
+    Sample(0.0, fraction, withReplacement, seed, logicalPlan)
   }
 
   /**
@@ -721,6 +721,42 @@ class DataFrame private[sql](
   }
 
   /**
+   * Randomly splits this [[DataFrame]] with the provided weights.
+   *
+   * @param weights weights for splits, will be normalized if they don't sum to 1.
+   * @param seed Seed for sampling.
+   * @group dfops
+   */
+  def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = {
+    val sum = weights.sum
+    val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
+    normalizedCumWeights.sliding(2).map { x =>
+      new DataFrame(sqlContext, Sample(x(0), x(1), false, seed, logicalPlan))
+    }.toArray
+  }
+
+  /**
+   * Randomly splits this [[DataFrame]] with the provided weights.
+   *
+   * @param weights weights for splits, will be normalized if they don't sum to 1.
+   * @group dfops
+   */
+  def randomSplit(weights: Array[Double]): Array[DataFrame] = {
+    randomSplit(weights, Utils.random.nextLong)
+  }
+
+  /**
+   * Randomly splits this [[DataFrame]] with the provided weights. Provided for the Python Api.
+   *
+   * @param weights weights for splits, will be normalized if they don't sum to 1.
+   * @param seed Seed for sampling.
+   * @group dfops
+   */
+  def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = {
+    randomSplit(weights.toArray, seed)
+  }
+
+  /**
    * (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more
    * rows by the provided function.  This is similar to a `LATERAL VIEW` in HiveQL. The columns of
    * the input row are implicitly joined with each row that is output by the function.

http://git-wip-us.apache.org/repos/asf/spark/blob/d7dbce8f/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index af58911..326e8ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -303,8 +303,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.Expand(projections, output, planLater(child)) :: Nil
       case logical.Aggregate(group, agg, child) =>
         execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
-      case logical.Sample(fraction, withReplacement, seed, child) =>
-        execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
+      case logical.Sample(lb, ub, withReplacement, seed, child) =>
+        execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
       case logical.LocalRelation(output, data) =>
         LocalTableScan(output, data) :: Nil
       case logical.Limit(IntegerLiteral(limit), child) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/d7dbce8f/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 1afdb40..5ca11e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -63,16 +63,32 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
 
 /**
  * :: DeveloperApi ::
+ * Sample the dataset.
+ * @param lowerBound Lower-bound of the sampling probability (usually 0.0)
+ * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled 
+ *                   will be ub - lb.
+ * @param withReplacement Whether to sample with replacement.
+ * @param seed the random seed
+ * @param child the QueryPlan
  */
 @DeveloperApi
-case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan)
+case class Sample(
+    lowerBound: Double,
+    upperBound: Double,
+    withReplacement: Boolean,
+    seed: Long,
+    child: SparkPlan)
   extends UnaryNode
 {
   override def output: Seq[Attribute] = child.output
 
   // TODO: How to pick seed?
   override def execute(): RDD[Row] = {
-    child.execute().map(_.copy()).sample(withReplacement, fraction, seed)
+    if (withReplacement) {
+      child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed)
+    } else {
+      child.execute().map(_.copy()).randomSampleWithRange(lowerBound, upperBound, seed)
+    }
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d7dbce8f/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 5ec06d4..b70e127b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -510,6 +510,23 @@ class DataFrameSuite extends QueryTest {
     assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
   }
 
+  test("randomSplit") {
+    val n = 600
+    val data = TestSQLContext.sparkContext.parallelize(1 to n, 2).toDF("id")
+    for (seed <- 1 to 5) {
+      val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
+      assert(splits.length == 3, "wrong number of splits")
+
+      assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
+        data.collect().toList, "incomplete or wrong split")
+
+      val s = splits.map(_.count())
+      assert(math.abs(s(0) - 100) < 50) // std =  9.13
+      assert(math.abs(s(1) - 200) < 50) // std = 11.55
+      assert(math.abs(s(2) - 300) < 50) // std = 12.25
+    }
+  }
+
   test("describe") {
     val describeTestData = Seq(
       ("Bob",   16, 176),

http://git-wip-us.apache.org/repos/asf/spark/blob/d7dbce8f/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 2dc6463..0a86519 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -887,13 +887,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
             fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon)
               && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon),
             s"Sampling fraction ($fraction) must be on interval [0, 100]")
-          Sample(fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt,
+          Sample(0.0, fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt,
             relation)
         case Token("TOK_TABLEBUCKETSAMPLE",
                Token(numerator, Nil) ::
                Token(denominator, Nil) :: Nil) =>
           val fraction = numerator.toDouble / denominator.toDouble
-          Sample(fraction, withReplacement = false, (math.random * 1000).toInt, relation)
+          Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)
         case a: ASTNode =>
           throw new NotImplementedError(
             s"""No parse rules for sampling clause: ${a.getType}, text: ${a.getText} :


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org