Posted to commits@spark.apache.org by ma...@apache.org on 2014/04/25 02:27:20 UTC

git commit: SPARK-1438 RDD.sample() make seed param optional

Repository: spark
Updated Branches:
  refs/heads/master f99af8529 -> 35e3d199f


SPARK-1438 RDD.sample() make seed param optional

Copied from previous pull request https://github.com/apache/spark/pull/462

It's probably better to let the underlying language implementation take care of the default. This was easier to do in Python, as the default value for seed in both random and numpy.random is None.

On the Scala/Java side it might mean propagating an Option or null (oh no!) down the chain to where the Random is constructed. But the convention in some other methods was to use System.nanoTime, so this patch followed that convention.
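
For illustration, a minimal self-contained sketch of where the patch ended up (per commit 8d05b1a below, System.nanoTime was later replaced by a shared Random in Utils; scala.util.Random is assumed here, since the diff does not show Utils' import list):

    import scala.util.Random

    object Utils {
      // Shared source of randomness for picking default seeds; unlike
      // System.nanoTime, two calls in the same nanosecond still differ.
      val random = new Random()
    }

    abstract class RDD[T] {
      // seed gets a default, so callers may simply omit it
      def sample(
          withReplacement: Boolean,
          fraction: Double,
          seed: Long = Utils.random.nextLong): RDD[T]
    }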

This conflicts with the overloaded method sql.SchemaRDD.sample, which also defines default params:
sample(fraction, withReplacement = true, seed = (math.random * 1000).toInt)
Scala does not allow more than one overloaded alternative of a method to define default params. I believe the author intended to override the RDD.sample method, not overload it, so this patch changes it to an override.
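
For a concrete picture of that restriction, here is a minimal standalone sketch (hypothetical Base/Child classes, not the actual Spark sources):

    class Base {
      def sample(withReplacement: Boolean, fraction: Double,
          seed: Long = 42L): Base = this
    }

    class Child extends Base {
      // Adding a second overload with its own defaults fails to compile:
      // "multiple overloaded alternatives of method sample define
      // default arguments" (scalac's wording, approximately).
      // def sample(fraction: Double, withReplacement: Boolean = false,
      //     seed: Long = 1L): Child = this

      // Overriding the inherited signature is fine; the default for
      // seed is inherited from Base.
      override def sample(withReplacement: Boolean, fraction: Double,
          seed: Long): Child = this
    }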

If backward compatibility is important, three new methods can be introduced (without default params), like this (see the sketch after the list):
sample(fraction)
sample(fraction, withReplacement)
sample(fraction, withReplacement, seed)
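
A minimal sketch of how those three could delegate (MyRDD is a hypothetical stand-in for SchemaRDD; this mirrors the overload pattern the Java API classes adopt in the diff below):

    class MyRDD {
      def sample(fraction: Double): MyRDD =
        sample(fraction, true)  // old SchemaRDD default: withReplacement = true

      def sample(fraction: Double, withReplacement: Boolean): MyRDD =
        sample(fraction, withReplacement, scala.util.Random.nextLong())

      def sample(fraction: Double, withReplacement: Boolean, seed: Long): MyRDD =
        this  // a real implementation would construct the sampled plan here
    }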

Added some tests for the Scala RDD takeSample method.

Author: Arun Ramakrishnan <sm...@gmail.com>

This patch had conflicts when merged, resolved by
Committer: Matei Zaharia <ma...@databricks.com>

Closes #477 from smartnut007/master and squashes the following commits:

07bb06e [Arun Ramakrishnan] SPARK-1438 fixing more space formatting issues
b9ebfe2 [Arun Ramakrishnan] SPARK-1438 removing redundant import of random in python rddsampler
8d05b1a [Arun Ramakrishnan] SPARK-1438 RDD . Replace System.nanoTime with a Random generated number. python: use a separate instance of Random instead of seeding language api global Random instance.
69619c6 [Arun Ramakrishnan] SPARK-1438 fix spacing issue
0c247db [Arun Ramakrishnan] SPARK-1438 RDD language apis to support optional seed in RDD methods sample/takeSample


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/35e3d199
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/35e3d199
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/35e3d199

Branch: refs/heads/master
Commit: 35e3d199f04fba3230625002a458d43b9578b2e8
Parents: f99af85
Author: Arun Ramakrishnan <sm...@gmail.com>
Authored: Thu Apr 24 17:27:16 2014 -0700
Committer: Matei Zaharia <ma...@databricks.com>
Committed: Thu Apr 24 17:27:16 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/api/java/JavaDoubleRDD.scala   |  9 +++++-
 .../org/apache/spark/api/java/JavaPairRDD.scala |  9 +++++-
 .../org/apache/spark/api/java/JavaRDD.scala     |  9 +++++-
 .../org/apache/spark/api/java/JavaRDDLike.scala |  6 +++-
 .../spark/rdd/PartitionwiseSampledRDD.scala     |  5 ++--
 .../main/scala/org/apache/spark/rdd/RDD.scala   | 11 ++++---
 .../scala/org/apache/spark/util/Utils.scala     |  2 ++
 .../scala/org/apache/spark/rdd/RDDSuite.scala   | 21 ++++++++++++-
 python/pyspark/rdd.py                           | 13 ++++----
 python/pyspark/rddsampler.py                    | 31 +++++++++-----------
 .../catalyst/plans/logical/basicOperators.scala |  2 +-
 .../scala/org/apache/spark/sql/SchemaRDD.scala  |  5 ++--
 .../spark/sql/execution/basicOperators.scala    |  6 ++--
 13 files changed, 88 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/35e3d199/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index 4330cef..a6123bd 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -30,6 +30,7 @@ import org.apache.spark.partial.{BoundedDouble, PartialResult}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.StatCounter
+import org.apache.spark.util.Utils
 
 class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, JavaDoubleRDD] {
 
@@ -133,7 +134,13 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
   /**
    * Return a sampled subset of this RDD.
    */
-  def sample(withReplacement: Boolean, fraction: JDouble, seed: Int): JavaDoubleRDD =
+  def sample(withReplacement: Boolean, fraction: JDouble): JavaDoubleRDD =
+    sample(withReplacement, fraction, Utils.random.nextLong)
+    
+  /**
+   * Return a sampled subset of this RDD.
+   */
+  def sample(withReplacement: Boolean, fraction: JDouble, seed: Long): JavaDoubleRDD =
     fromRDD(srdd.sample(withReplacement, fraction, seed))
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/35e3d199/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index b3ec270..554c065 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -39,6 +39,7 @@ import org.apache.spark.api.java.function.{Function => JFunction, Function2 => J
 import org.apache.spark.partial.{BoundedDouble, PartialResult}
 import org.apache.spark.rdd.{OrderedRDDFunctions, RDD}
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
 
 class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
                        (implicit val kClassTag: ClassTag[K], implicit val vClassTag: ClassTag[V])
@@ -119,7 +120,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   /**
    * Return a sampled subset of this RDD.
    */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] =
+  def sample(withReplacement: Boolean, fraction: Double): JavaPairRDD[K, V] =
+    sample(withReplacement, fraction, Utils.random.nextLong)
+    
+  /**
+   * Return a sampled subset of this RDD.
+   */
+  def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaPairRDD[K, V] =
     new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed))
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/35e3d199/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index 327c155..dc698de 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -24,6 +24,7 @@ import org.apache.spark._
 import org.apache.spark.api.java.function.{Function => JFunction}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
 
 class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
   extends JavaRDDLike[T, JavaRDD[T]] {
@@ -98,7 +99,13 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
   /**
    * Return a sampled subset of this RDD.
    */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
+  def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] =
+    sample(withReplacement, fraction, Utils.random.nextLong)
+    
+  /**
+   * Return a sampled subset of this RDD.
+   */
+  def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] =
     wrapRDD(rdd.sample(withReplacement, fraction, seed))
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/35e3d199/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 725c423..574a986 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -34,6 +34,7 @@ import org.apache.spark.api.java.function.{Function => JFunction, Function2 => J
 import org.apache.spark.partial.{BoundedDouble, PartialResult}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
 
 trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   def wrapRDD(rdd: RDD[T]): This
@@ -394,7 +395,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
     new java.util.ArrayList(arr)
   }
 
-  def takeSample(withReplacement: Boolean, num: Int, seed: Int): JList[T] = {
+  def takeSample(withReplacement: Boolean, num: Int): JList[T] = 
+    takeSample(withReplacement, num, Utils.random.nextLong)
+    
+  def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = {
     import scala.collection.JavaConversions._
     val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq
     new java.util.ArrayList(arr)

http://git-wip-us.apache.org/repos/asf/spark/blob/35e3d199/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
index b4e3bb5..b5b8a57 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
@@ -23,6 +23,7 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.{Partition, TaskContext}
 import org.apache.spark.util.random.RandomSampler
+import org.apache.spark.util.Utils
 
 private[spark]
 class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
@@ -38,14 +39,14 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
  *
  * @param prev RDD to be sampled
  * @param sampler a random sampler
- * @param seed random seed, default to System.nanoTime
+ * @param seed random seed
  * @tparam T input RDD item type
  * @tparam U sampled RDD item type
  */
 private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
     prev: RDD[T],
     sampler: RandomSampler[T, U],
-    @transient seed: Long = System.nanoTime)
+    @transient seed: Long = Utils.random.nextLong)
   extends RDD[U](prev) {
 
   override def getPartitions: Array[Partition] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/35e3d199/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 6c897cc..e8bbfbf 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -341,7 +341,9 @@ abstract class RDD[T: ClassTag](
   /**
    * Return a sampled subset of this RDD.
    */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = {
+  def sample(withReplacement: Boolean, 
+      fraction: Double, 
+      seed: Long = Utils.random.nextLong): RDD[T] = {
     require(fraction >= 0.0, "Invalid fraction value: " + fraction)
     if (withReplacement) {
       new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)
@@ -354,11 +356,11 @@ abstract class RDD[T: ClassTag](
    * Randomly splits this RDD with the provided weights.
    *
    * @param weights weights for splits, will be normalized if they don't sum to 1
-   * @param seed random seed, default to System.nanoTime
+   * @param seed random seed
    *
    * @return split RDDs in an array
    */
-  def randomSplit(weights: Array[Double], seed: Long = System.nanoTime): Array[RDD[T]] = {
+  def randomSplit(weights: Array[Double], seed: Long = Utils.random.nextLong): Array[RDD[T]] = {
     val sum = weights.sum
     val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
     normalizedCumWeights.sliding(2).map { x =>
@@ -366,7 +368,8 @@ abstract class RDD[T: ClassTag](
     }.toArray
   }
 
-  def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
+  def takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] =
+  {
     var fraction = 0.0
     var total = 0
     val multiplier = 3.0

http://git-wip-us.apache.org/repos/asf/spark/blob/35e3d199/core/src/main/scala/org/apache/spark/util/Utils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index d333e2a..084a71c 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -46,6 +46,8 @@ import org.apache.spark.serializer.{DeserializationStream, SerializationStream,
 private[spark] object Utils extends Logging {
 
   val osName = System.getProperty("os.name")
+  
+  val random = new Random()
 
   /** Serialize an object using Java serialization */
   def serialize[T](o: T): Array[Byte] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/35e3d199/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 2676558..8da9a0d 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -463,7 +463,13 @@ class RDDSuite extends FunSuite with SharedSparkContext {
 
   test("takeSample") {
     val data = sc.parallelize(1 to 100, 2)
-
+    
+    for (num <- List(5, 20, 100)) {
+      val sample = data.takeSample(withReplacement=false, num=num)
+      assert(sample.size === num)        // Got exactly num elements
+      assert(sample.toSet.size === num)  // Elements are distinct
+      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+    }
     for (seed <- 1 to 5) {
       val sample = data.takeSample(withReplacement=false, 20, seed)
       assert(sample.size === 20)        // Got exactly 20 elements
@@ -481,6 +487,19 @@ class RDDSuite extends FunSuite with SharedSparkContext {
       assert(sample.size === 20)        // Got exactly 20 elements
       assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
     }
+    {
+      val sample = data.takeSample(withReplacement=true, num=20)
+      assert(sample.size === 20)        // Got exactly 20 elements
+      assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements")
+      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+    }
+    {
+      val sample = data.takeSample(withReplacement=true, num=100)
+      assert(sample.size === 100)        // Got exactly 100 elements
+      // Chance of getting all distinct elements is astronomically low, so test we got < 100
+      assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+    }
     for (seed <- 1 to 5) {
       val sample = data.takeSample(withReplacement=true, 100, seed)
       assert(sample.size === 100)        // Got exactly 100 elements

http://git-wip-us.apache.org/repos/asf/spark/blob/35e3d199/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 91fc7e6..d73ab70 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -30,6 +30,7 @@ from tempfile import NamedTemporaryFile
 from threading import Thread
 import warnings
 import heapq
+from random import Random
 
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
@@ -332,7 +333,7 @@ class RDD(object):
                    .reduceByKey(lambda x, _: x) \
                    .map(lambda (x, _): x)
 
-    def sample(self, withReplacement, fraction, seed):
+    def sample(self, withReplacement, fraction, seed=None):
         """
         Return a sampled subset of this RDD (relies on numpy and falls back
         on default random generator if numpy is unavailable).
@@ -344,7 +345,7 @@ class RDD(object):
         return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
 
     # this is ported from scala/spark/RDD.scala
-    def takeSample(self, withReplacement, num, seed):
+    def takeSample(self, withReplacement, num, seed=None):
         """
         Return a fixed-size sampled subset of this RDD (currently requires numpy).
 
@@ -381,13 +382,11 @@ class RDD(object):
         # If the first sample didn't turn out large enough, keep trying to take samples;
         # this shouldn't happen often because we use a big multiplier for their initial size.
         # See: scala/spark/RDD.scala
+        rand = Random(seed)
         while len(samples) < total:
-            if seed > sys.maxint - 2:
-                seed = -1
-            seed += 1
-            samples = self.sample(withReplacement, fraction, seed).collect()
+            samples = self.sample(withReplacement, fraction, rand.randint(0, sys.maxint)).collect()
 
-        sampler = RDDSampler(withReplacement, fraction, seed+1)
+        sampler = RDDSampler(withReplacement, fraction, rand.randint(0, sys.maxint))
         sampler.shuffle(samples)
         return samples[0:total]
 

http://git-wip-us.apache.org/repos/asf/spark/blob/35e3d199/python/pyspark/rddsampler.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index aca2ef3..845a267 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -19,7 +19,7 @@ import sys
 import random
 
 class RDDSampler(object):
-    def __init__(self, withReplacement, fraction, seed):
+    def __init__(self, withReplacement, fraction, seed=None):
         try:
             import numpy
             self._use_numpy = True
@@ -27,7 +27,7 @@ class RDDSampler(object):
             print >> sys.stderr, "NumPy does not appear to be installed. Falling back to default random generator for sampling."
             self._use_numpy = False
 
-        self._seed = seed
+        self._seed = seed if seed is not None else random.randint(0, sys.maxint)
         self._withReplacement = withReplacement
         self._fraction = fraction
         self._random = None
@@ -38,17 +38,14 @@ class RDDSampler(object):
         if self._use_numpy:
             import numpy
             self._random = numpy.random.RandomState(self._seed)
-            for _ in range(0, split):
-                # discard the next few values in the sequence to have a
-                # different seed for the different splits
-                self._random.randint(sys.maxint)
         else:
-            import random
-            random.seed(self._seed)
-            for _ in range(0, split):
-                # discard the next few values in the sequence to have a
-                # different seed for the different splits
-                random.randint(0, sys.maxint)
+            self._random = random.Random(self._seed)
+
+        for _ in range(0, split):
+            # discard the next few values in the sequence to have a
+            # different seed for the different splits
+            self._random.randint(0, sys.maxint)
+
         self._split = split
         self._rand_initialized = True
 
@@ -59,7 +56,7 @@ class RDDSampler(object):
         if self._use_numpy:
             return self._random.random_sample()
         else:
-            return random.uniform(0.0, 1.0)
+            return self._random.uniform(0.0, 1.0)
 
     def getPoissonSample(self, split, mean):
         if not self._rand_initialized or split != self._split:
@@ -73,26 +70,26 @@ class RDDSampler(object):
             num_arrivals = 1
             cur_time = 0.0
 
-            cur_time += random.expovariate(mean)
+            cur_time += self._random.expovariate(mean)
 
             if cur_time > 1.0:
                 return 0
 
             while(cur_time <= 1.0):
-                cur_time += random.expovariate(mean)
+                cur_time += self._random.expovariate(mean)
                 num_arrivals += 1
 
             return (num_arrivals - 1)
     
     def shuffle(self, vals):
-        if self._random == None or split != self._split:
+        if self._random == None:
             self.initRandomGenerator(0)  # this should only ever called on the master so
             # the split does not matter
         
         if self._use_numpy:
             self._random.shuffle(vals)
         else:
-            random.shuffle(vals, self._random)
+            self._random.shuffle(vals, self._random.random)
 
     def func(self, split, iterator):
         if self._withReplacement:            

http://git-wip-us.apache.org/repos/asf/spark/blob/35e3d199/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 397473e..732708e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -168,7 +168,7 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode {
   def references = Set.empty
 }
 
-case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: LogicalPlan)
+case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
     extends UnaryNode {
 
   def output = child.output

http://git-wip-us.apache.org/repos/asf/spark/blob/35e3d199/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 6cb0e0f..ca6e0a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -256,10 +256,11 @@ class SchemaRDD(
    * @group Query
    */
   @Experimental
+  override
   def sample(
-      fraction: Double,
       withReplacement: Boolean = true,
-      seed: Int = (math.random * 1000).toInt) =
+      fraction: Double,
+      seed: Long) =
     new SchemaRDD(sqlContext, Sample(fraction, withReplacement, seed, logicalPlan))
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/35e3d199/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index e4cf202..d807187 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -57,9 +57,9 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
  * :: DeveloperApi ::
  */
 @DeveloperApi
-case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: SparkPlan)
-    extends UnaryNode {
-
+case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan)
+  extends UnaryNode
+{
   override def output = child.output
 
   // TODO: How to pick seed?