You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2017/09/21 11:16:29 UTC
spark git commit: [SPARK-21780][R] Simpler Dataset.sample API in R
Repository: spark
Updated Branches:
refs/heads/master 1da5822e6 -> a8d9ec8a6
[SPARK-21780][R] Simpler Dataset.sample API in R
## What changes were proposed in this pull request?
This PR makes it possible to omit `withReplacement` in `sample(...)`; it now defaults to `FALSE`.
In short, the following calls are now allowed:
```r
> df <- createDataFrame(as.list(seq(10)))
> count(sample(df, fraction=0.5, seed=3))
[1] 4
> count(sample(df, fraction=1.0))
[1] 10
```
In addition, this PR adds some argument type checking, as shown below:
```r
> sample(df, fraction = "a")
Error in sample(df, fraction = "a") :
fraction must be numeric; however, got character
> sample(df, fraction = 1, seed = NULL)
Error in sample(df, fraction = 1, seed = NULL) :
seed must not be NULL or NA; however, got NULL
> sample(df, list(1), 1.0)
Error in sample(df, list(1), 1) :
withReplacement must be logical; however, got list
> sample(df, fraction = -1.0)
...
Error in sample : illegal argument - requirement failed: Sampling fraction (-1.0) must be on interval [0, 1] without replacement
```
## How was this patch tested?
Manually tested; unit tests were added in `R/pkg/tests/fulltests/test_sparkSQL.R`.
Author: hyukjinkwon <gu...@gmail.com>
Closes #19243 from HyukjinKwon/SPARK-21780.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a8d9ec8a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a8d9ec8a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a8d9ec8a
Branch: refs/heads/master
Commit: a8d9ec8a60f21abb520b9109b238f914d2449022
Parents: 1da5822
Author: hyukjinkwon <gu...@gmail.com>
Authored: Thu Sep 21 20:16:25 2017 +0900
Committer: hyukjinkwon <gu...@gmail.com>
Committed: Thu Sep 21 20:16:25 2017 +0900
----------------------------------------------------------------------
R/pkg/R/DataFrame.R | 40 ++++++++++++++++++++----------
R/pkg/R/generics.R | 4 +--
R/pkg/tests/fulltests/test_sparkSQL.R | 14 +++++++++++
3 files changed, 43 insertions(+), 15 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/a8d9ec8a/R/pkg/R/DataFrame.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 1b46c1e..0728141 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -986,10 +986,10 @@ setMethod("unique",
#' @param x A SparkDataFrame
#' @param withReplacement Sampling with replacement or not
#' @param fraction The (rough) sample target fraction
-#' @param seed Randomness seed value
+#' @param seed Randomness seed value. Default is a random seed.
#'
#' @family SparkDataFrame functions
-#' @aliases sample,SparkDataFrame,logical,numeric-method
+#' @aliases sample,SparkDataFrame-method
#' @rdname sample
#' @name sample
#' @export
@@ -998,33 +998,47 @@ setMethod("unique",
#' sparkR.session()
#' path <- "path/to/file.json"
#' df <- read.json(path)
+#' collect(sample(df, fraction = 0.5))
#' collect(sample(df, FALSE, 0.5))
-#' collect(sample(df, TRUE, 0.5))
+#' collect(sample(df, TRUE, 0.5, seed = 3))
#'}
#' @note sample since 1.4.0
setMethod("sample",
- signature(x = "SparkDataFrame", withReplacement = "logical",
- fraction = "numeric"),
- function(x, withReplacement, fraction, seed) {
- if (fraction < 0.0) stop(cat("Negative fraction value:", fraction))
+ signature(x = "SparkDataFrame"),
+ function(x, withReplacement = FALSE, fraction, seed) {
+ if (!is.numeric(fraction)) {
+ stop(paste("fraction must be numeric; however, got", class(fraction)))
+ }
+ if (!is.logical(withReplacement)) {
+ stop(paste("withReplacement must be logical; however, got", class(withReplacement)))
+ }
+
if (!missing(seed)) {
+ if (is.null(seed)) {
+ stop("seed must not be NULL or NA; however, got NULL")
+ }
+ if (is.na(seed)) {
+ stop("seed must not be NULL or NA; however, got NA")
+ }
+
# TODO : Figure out how to send integer as java.lang.Long to JVM so
# we can send seed as an argument through callJMethod
- sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction, as.integer(seed))
+ sdf <- handledCallJMethod(x@sdf, "sample", as.logical(withReplacement),
+ as.numeric(fraction), as.integer(seed))
} else {
- sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction)
+ sdf <- handledCallJMethod(x@sdf, "sample",
+ as.logical(withReplacement), as.numeric(fraction))
}
dataFrame(sdf)
})
#' @rdname sample
-#' @aliases sample_frac,SparkDataFrame,logical,numeric-method
+#' @aliases sample_frac,SparkDataFrame-method
#' @name sample_frac
#' @note sample_frac since 1.4.0
setMethod("sample_frac",
- signature(x = "SparkDataFrame", withReplacement = "logical",
- fraction = "numeric"),
- function(x, withReplacement, fraction, seed) {
+ signature(x = "SparkDataFrame"),
+ function(x, withReplacement = FALSE, fraction, seed) {
sample(x, withReplacement, fraction, seed)
})
http://git-wip-us.apache.org/repos/asf/spark/blob/a8d9ec8a/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 603ff4e..0fe8f04 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -645,7 +645,7 @@ setGeneric("repartition", function(x, ...) { standardGeneric("repartition") })
#' @rdname sample
#' @export
setGeneric("sample",
- function(x, withReplacement, fraction, seed) {
+ function(x, withReplacement = FALSE, fraction, seed) {
standardGeneric("sample")
})
@@ -656,7 +656,7 @@ setGeneric("rollup", function(x, ...) { standardGeneric("rollup") })
#' @rdname sample
#' @export
setGeneric("sample_frac",
- function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") })
+ function(x, withReplacement = FALSE, fraction, seed) { standardGeneric("sample_frac") })
#' @rdname sampleBy
#' @export
http://git-wip-us.apache.org/repos/asf/spark/blob/a8d9ec8a/R/pkg/tests/fulltests/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index 85a7e08..4d1010e 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -1116,6 +1116,20 @@ test_that("sample on a DataFrame", {
sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result
expect_true(count(sampled3) < 3)
+ # Different arguments
+ df <- createDataFrame(as.list(seq(10)))
+ expect_equal(count(sample(df, fraction = 0.5, seed = 3)), 4)
+ expect_equal(count(sample(df, withReplacement = TRUE, fraction = 0.5, seed = 3)), 2)
+ expect_equal(count(sample(df, fraction = 1.0)), 10)
+ expect_equal(count(sample(df, fraction = 1L)), 10)
+ expect_equal(count(sample(df, FALSE, fraction = 1.0)), 10)
+
+ expect_error(sample(df, fraction = "a"), "fraction must be numeric")
+ expect_error(sample(df, "a", fraction = 0.1), "however, got character")
+ expect_error(sample(df, fraction = 1, seed = NA), "seed must not be NULL or NA; however, got NA")
+ expect_error(sample(df, fraction = -1.0),
+ "illegal argument - requirement failed: Sampling fraction \\(-1.0\\)")
+
# nolint start
# Test base::sample is working
#expect_equal(length(sample(1:12)), 12)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org