You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2017/09/21 11:16:29 UTC

spark git commit: [SPARK-21780][R] Simpler Dataset.sample API in R

Repository: spark
Updated Branches:
  refs/heads/master 1da5822e6 -> a8d9ec8a6


[SPARK-21780][R] Simpler Dataset.sample API in R

## What changes were proposed in this pull request?

This PR makes it possible to omit `withReplacement` when calling `sample(...)`; it now defaults to `FALSE`.

In short, the following examples are allowed:

```r
> df <- createDataFrame(as.list(seq(10)))
> count(sample(df, fraction=0.5, seed=3))
[1] 4
> count(sample(df, fraction=1.0))
[1] 10
```

In addition, this PR adds some type-checking logic, as shown below:

```r
> sample(df, fraction = "a")
Error in sample(df, fraction = "a") :
  fraction must be numeric; however, got character
> sample(df, fraction = 1, seed = NULL)
Error in sample(df, fraction = 1, seed = NULL) :
  seed must not be NULL or NA; however, got NULL
> sample(df, list(1), 1.0)
Error in sample(df, list(1), 1) :
  withReplacement must be logical; however, got list
> sample(df, fraction = -1.0)
...
Error in sample : illegal argument - requirement failed: Sampling fraction (-1.0) must be on interval [0, 1] without replacement
```

## How was this patch tested?

Manually tested, unit tests added in `R/pkg/tests/fulltests/test_sparkSQL.R`.

Author: hyukjinkwon <gu...@gmail.com>

Closes #19243 from HyukjinKwon/SPARK-21780.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a8d9ec8a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a8d9ec8a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a8d9ec8a

Branch: refs/heads/master
Commit: a8d9ec8a60f21abb520b9109b238f914d2449022
Parents: 1da5822
Author: hyukjinkwon <gu...@gmail.com>
Authored: Thu Sep 21 20:16:25 2017 +0900
Committer: hyukjinkwon <gu...@gmail.com>
Committed: Thu Sep 21 20:16:25 2017 +0900

----------------------------------------------------------------------
 R/pkg/R/DataFrame.R                   | 40 ++++++++++++++++++++----------
 R/pkg/R/generics.R                    |  4 +--
 R/pkg/tests/fulltests/test_sparkSQL.R | 14 +++++++++++
 3 files changed, 43 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a8d9ec8a/R/pkg/R/DataFrame.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 1b46c1e..0728141 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -986,10 +986,10 @@ setMethod("unique",
 #' @param x A SparkDataFrame
 #' @param withReplacement Sampling with replacement or not
 #' @param fraction The (rough) sample target fraction
-#' @param seed Randomness seed value
+#' @param seed Randomness seed value. Default is a random seed.
 #'
 #' @family SparkDataFrame functions
-#' @aliases sample,SparkDataFrame,logical,numeric-method
+#' @aliases sample,SparkDataFrame-method
 #' @rdname sample
 #' @name sample
 #' @export
@@ -998,33 +998,47 @@ setMethod("unique",
 #' sparkR.session()
 #' path <- "path/to/file.json"
 #' df <- read.json(path)
+#' collect(sample(df, fraction = 0.5))
 #' collect(sample(df, FALSE, 0.5))
-#' collect(sample(df, TRUE, 0.5))
+#' collect(sample(df, TRUE, 0.5, seed = 3))
 #'}
 #' @note sample since 1.4.0
 setMethod("sample",
-          signature(x = "SparkDataFrame", withReplacement = "logical",
-                    fraction = "numeric"),
-          function(x, withReplacement, fraction, seed) {
-            if (fraction < 0.0) stop(cat("Negative fraction value:", fraction))
+          signature(x = "SparkDataFrame"),
+          function(x, withReplacement = FALSE, fraction, seed) {
+            if (!is.numeric(fraction)) {
+              stop(paste("fraction must be numeric; however, got", class(fraction)))
+            }
+            if (!is.logical(withReplacement)) {
+              stop(paste("withReplacement must be logical; however, got", class(withReplacement)))
+            }
+
             if (!missing(seed)) {
+              if (is.null(seed)) {
+                stop("seed must not be NULL or NA; however, got NULL")
+              }
+              if (is.na(seed)) {
+                stop("seed must not be NULL or NA; however, got NA")
+              }
+
               # TODO : Figure out how to send integer as java.lang.Long to JVM so
               # we can send seed as an argument through callJMethod
-              sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction, as.integer(seed))
+              sdf <- handledCallJMethod(x@sdf, "sample", as.logical(withReplacement),
+                                        as.numeric(fraction), as.integer(seed))
             } else {
-              sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction)
+              sdf <- handledCallJMethod(x@sdf, "sample",
+                                        as.logical(withReplacement), as.numeric(fraction))
             }
             dataFrame(sdf)
           })
 
 #' @rdname sample
-#' @aliases sample_frac,SparkDataFrame,logical,numeric-method
+#' @aliases sample_frac,SparkDataFrame-method
 #' @name sample_frac
 #' @note sample_frac since 1.4.0
 setMethod("sample_frac",
-          signature(x = "SparkDataFrame", withReplacement = "logical",
-                    fraction = "numeric"),
-          function(x, withReplacement, fraction, seed) {
+          signature(x = "SparkDataFrame"),
+          function(x, withReplacement = FALSE, fraction, seed) {
             sample(x, withReplacement, fraction, seed)
           })
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a8d9ec8a/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 603ff4e..0fe8f04 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -645,7 +645,7 @@ setGeneric("repartition", function(x, ...) { standardGeneric("repartition") })
 #' @rdname sample
 #' @export
 setGeneric("sample",
-           function(x, withReplacement, fraction, seed) {
+           function(x, withReplacement = FALSE, fraction, seed) {
              standardGeneric("sample")
            })
 
@@ -656,7 +656,7 @@ setGeneric("rollup", function(x, ...) { standardGeneric("rollup") })
 #' @rdname sample
 #' @export
 setGeneric("sample_frac",
-           function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") })
+           function(x, withReplacement = FALSE, fraction, seed) { standardGeneric("sample_frac") })
 
 #' @rdname sampleBy
 #' @export

http://git-wip-us.apache.org/repos/asf/spark/blob/a8d9ec8a/R/pkg/tests/fulltests/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index 85a7e08..4d1010e 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -1116,6 +1116,20 @@ test_that("sample on a DataFrame", {
   sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result
   expect_true(count(sampled3) < 3)
 
+  # Different arguments
+  df <- createDataFrame(as.list(seq(10)))
+  expect_equal(count(sample(df, fraction = 0.5, seed = 3)), 4)
+  expect_equal(count(sample(df, withReplacement = TRUE, fraction = 0.5, seed = 3)), 2)
+  expect_equal(count(sample(df, fraction = 1.0)), 10)
+  expect_equal(count(sample(df, fraction = 1L)), 10)
+  expect_equal(count(sample(df, FALSE, fraction = 1.0)), 10)
+
+  expect_error(sample(df, fraction = "a"), "fraction must be numeric")
+  expect_error(sample(df, "a", fraction = 0.1), "however, got character")
+  expect_error(sample(df, fraction = 1, seed = NA), "seed must not be NULL or NA; however, got NA")
+  expect_error(sample(df, fraction = -1.0),
+               "illegal argument - requirement failed: Sampling fraction \\(-1.0\\)")
+
   # nolint start
   # Test base::sample is working
   #expect_equal(length(sample(1:12)), 12)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org