You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sh...@apache.org on 2015/10/14 07:31:45 UTC
spark git commit: [SPARK-10996] [SPARKR] Implement sampleBy() in
DataFrameStatFunctions.
Repository: spark
Updated Branches:
refs/heads/master 8b3288570 -> 390b22fad
[SPARK-10996] [SPARKR] Implement sampleBy() in DataFrameStatFunctions.
Author: Sun Rui <ru...@intel.com>
Closes #9023 from sun-rui/SPARK-10996.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/390b22fa
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/390b22fa
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/390b22fa
Branch: refs/heads/master
Commit: 390b22fad69a33eb6daee25b6b858a2e768670a5
Parents: 8b32885
Author: Sun Rui <ru...@intel.com>
Authored: Tue Oct 13 22:31:23 2015 -0700
Committer: Shivaram Venkataraman <sh...@cs.berkeley.edu>
Committed: Tue Oct 13 22:31:23 2015 -0700
----------------------------------------------------------------------
R/pkg/NAMESPACE | 3 ++-
R/pkg/R/DataFrame.R | 14 ++++++--------
R/pkg/R/generics.R | 6 +++++-
R/pkg/R/sparkR.R | 12 +++---------
R/pkg/R/stats.R | 32 ++++++++++++++++++++++++++++++++
R/pkg/R/utils.R | 18 ++++++++++++++++++
R/pkg/inst/tests/test_sparkSQL.R | 10 ++++++++++
7 files changed, 76 insertions(+), 19 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/390b22fa/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index ed9cd94..52f7a01 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -65,6 +65,7 @@ exportMethods("arrange",
"repartition",
"sample",
"sample_frac",
+ "sampleBy",
"saveAsParquetFile",
"saveAsTable",
"saveDF",
@@ -254,4 +255,4 @@ export("structField",
"structType.structField",
"print.structType")
-export("as.data.frame")
\ No newline at end of file
+export("as.data.frame")
http://git-wip-us.apache.org/repos/asf/spark/blob/390b22fa/R/pkg/R/DataFrame.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index b7f5f97..993be82 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1831,17 +1831,15 @@ setMethod("fillna",
if (length(colNames) == 0 || !all(colNames != "")) {
stop("value should be an a named list with each name being a column name.")
}
-
- # Convert to the named list to an environment to be passed to JVM
- valueMap <- new.env()
- for (col in colNames) {
- # Check each item in the named list is of valid type
- v <- value[[col]]
+ # Check each item in the named list is of valid type
+ lapply(value, function(v) {
if (!(class(v) %in% c("integer", "numeric", "character"))) {
stop("Each item in value should be an integer, numeric or charactor.")
}
- valueMap[[col]] <- v
- }
+ })
+
+ # Convert to the named list to an environment to be passed to JVM
+ valueMap <- convertNamedListToEnv(value)
# When value is a named list, caller is expected not to pass in cols
if (!is.null(cols)) {
http://git-wip-us.apache.org/repos/asf/spark/blob/390b22fa/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index c106a00..4a419f7 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -509,6 +509,10 @@ setGeneric("sample",
setGeneric("sample_frac",
function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") })
+#' @rdname statfunctions
+#' @export
+setGeneric("sampleBy", function(x, col, fractions, seed) { standardGeneric("sampleBy") })
+
#' @rdname saveAsParquetFile
#' @export
setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") })
@@ -1006,4 +1010,4 @@ setGeneric("as.data.frame")
#' @rdname attach
#' @export
-setGeneric("attach")
\ No newline at end of file
+setGeneric("attach")
http://git-wip-us.apache.org/repos/asf/spark/blob/390b22fa/R/pkg/R/sparkR.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index cc47110..9cf2f1a 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -163,19 +163,13 @@ sparkR.init <- function(
sparkHome <- suppressWarnings(normalizePath(sparkHome))
}
- sparkEnvirMap <- new.env()
- for (varname in names(sparkEnvir)) {
- sparkEnvirMap[[varname]] <- sparkEnvir[[varname]]
- }
+ sparkEnvirMap <- convertNamedListToEnv(sparkEnvir)
- sparkExecutorEnvMap <- new.env()
- if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) {
+ sparkExecutorEnvMap <- convertNamedListToEnv(sparkExecutorEnv)
+ if(is.null(sparkExecutorEnvMap$LD_LIBRARY_PATH)) {
sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <-
paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH"))
}
- for (varname in names(sparkExecutorEnv)) {
- sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]]
- }
nonEmptyJars <- Filter(function(x) { x != "" }, jars)
localJarPaths <- lapply(nonEmptyJars,
http://git-wip-us.apache.org/repos/asf/spark/blob/390b22fa/R/pkg/R/stats.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R
index 4928cf4..f79329b 100644
--- a/R/pkg/R/stats.R
+++ b/R/pkg/R/stats.R
@@ -127,3 +127,35 @@ setMethod("freqItems", signature(x = "DataFrame", cols = "character"),
sct <- callJMethod(statFunctions, "freqItems", as.list(cols), support)
collect(dataFrame(sct))
})
+
#' sampleBy
#'
#' Returns a stratified sample without replacement based on the fraction given on each stratum.
#'
#' @param x A SparkSQL DataFrame
#' @param col column that defines strata
#' @param fractions A named list giving sampling fraction for each stratum. If a stratum is
#'                  not specified, we treat its fraction as zero.
#' @param seed random seed
#' @return A new DataFrame that represents the stratified sample
#'
#' @rdname statfunctions
#' @name sampleBy
#' @export
#' @examples
#'\dontrun{
#' df <- jsonFile(sqlContext, "/path/to/file.json")
#' sample <- sampleBy(df, "key", fractions, 36)
#' }
setMethod("sampleBy",
          signature(x = "DataFrame", col = "character",
                    fractions = "list", seed = "numeric"),
          function(x, col, fractions, seed) {
            # The JVM side expects the stratum -> fraction mapping as an
            # environment rather than an R named list.
            strataFractions <- convertNamedListToEnv(fractions)

            statFunctions <- callJMethod(x@sdf, "stat")
            # Seed is expected to be Long on the Scala side; convert it to an
            # integer here due to a SerDe limitation for now.
            sampled <- callJMethod(statFunctions, "sampleBy", col,
                                   strataFractions, as.integer(seed))
            dataFrame(sampled)
          })
http://git-wip-us.apache.org/repos/asf/spark/blob/390b22fa/R/pkg/R/utils.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index 94f16c7..0b9e295 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -605,3 +605,21 @@ structToList <- function(struct) {
class(struct) <- "list"
struct
}
+
# Convert a named list to an environment to be passed to JVM.
#
# namedList: a list in which every element carries a non-NA, non-empty name.
#            An unnamed list is accepted only when it is empty (it yields an
#            empty environment).
# Returns an environment mapping each name to its corresponding list value.
convertNamedListToEnv <- function(namedList) {
  # Make sure each item in the list has a usable name. Empty-string names
  # must be rejected here: env[[""]] <- value would otherwise fail later
  # with an obscure "zero-length variable name" error.
  names <- names(namedList)
  stopifnot(
    if (is.null(names)) {
      length(namedList) == 0
    } else {
      !any(is.na(names)) && all(nzchar(names))
    })

  env <- new.env()
  for (name in names) {
    env[[name]] <- namedList[[name]]
  }
  env
}
http://git-wip-us.apache.org/repos/asf/spark/blob/390b22fa/R/pkg/inst/tests/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 46cab76..e1b42b0 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -1416,6 +1416,16 @@ test_that("freqItems() on a DataFrame", {
expect_identical(result[[2]], list(list(-1, -99)))
})
test_that("sampleBy() on a DataFrame", {
  # Build 100 rows whose keys cycle through the strata "0", "1", "2".
  keys <- lapply(0:99, function(n) {
    as.character(n %% 3)
  })
  df <- createDataFrame(sqlContext, keys, "key")
  strata <- list("0" = 0.1, "1" = 0.2)
  sampled <- sampleBy(df, "key", strata, 0)
  counts <- collect(orderBy(count(groupBy(sampled, "key")), "key"))
  # With a fixed seed the per-stratum sample sizes are deterministic;
  # stratum "2" is unspecified and so is sampled at fraction zero.
  expect_identical(as.list(counts[1, ]), list(key = "0", count = 2))
  expect_identical(as.list(counts[2, ]), list(key = "1", count = 10))
})
+
test_that("SQL error message is returned from JVM", {
retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e)
expect_equal(grepl("Table Not Found: blah", retError), TRUE)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org