You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2016/08/27 07:42:44 UTC

spark git commit: [SPARK-15382][SQL] Fix a bug in sampling with replacement

Repository: spark
Updated Branches:
  refs/heads/master 718b6bad2 -> cd0ed31ea


[SPARK-15382][SQL] Fix a bug in sampling with replacement

## What changes were proposed in this pull request?
This PR fixes the bug shown below in sampling with replacement, where `monotonically_increasing_id` produces duplicate IDs for the sampled rows:
```
val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b")
df.sample(true, 2.0).withColumn("c", monotonically_increasing_id).select($"c").show
+---+
|  c|
+---+
|  0|
|  1|
|  1|
|  1|
|  2|
+---+
```

## How was this patch tested?
Added a test in `DataFrameSuite`.

Author: Takeshi YAMAMURO <li...@gmail.com>

Closes #14800 from maropu/FixSampleBug.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cd0ed31e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cd0ed31e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cd0ed31e

Branch: refs/heads/master
Commit: cd0ed31ea9965563a9b1ea3e8bfbeaf8347cacd9
Parents: 718b6ba
Author: Takeshi YAMAMURO <li...@gmail.com>
Authored: Sat Aug 27 08:42:41 2016 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Sat Aug 27 08:42:41 2016 +0100

----------------------------------------------------------------------
 .../apache/spark/sql/execution/basicPhysicalOperators.scala   | 1 +
 .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala  | 7 +++++++
 2 files changed, 8 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cd0ed31e/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 3562083..dd78a78 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -266,6 +266,7 @@ case class SampleExec(
     if (withReplacement) {
       val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
       val initSampler = ctx.freshName("initSampler")
+      ctx.copyResult = true
       ctx.addMutableState(s"$samplerClass<UnsafeRow>", sampler,
         s"$initSampler();")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/cd0ed31e/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index cd48577..ce0b92a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1579,4 +1579,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val df = spark.createDataFrame(rdd, StructType(schemas), false)
     assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100)
   }
+
+  test("copy results for sampling with replacement") {
+    val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b")
+    val sampleDf = df.sample(true, 2.00)
+    val d = sampleDf.withColumn("c", monotonically_increasing_id).select($"c").collect
+    assert(d.size == d.distinct.size)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org