Posted to commits@spark.apache.org by sh...@apache.org on 2015/10/14 07:31:45 UTC

spark git commit: [SPARK-10996] [SPARKR] Implement sampleBy() in DataFrameStatFunctions.

Repository: spark
Updated Branches:
  refs/heads/master 8b3288570 -> 390b22fad


[SPARK-10996] [SPARKR] Implement sampleBy() in DataFrameStatFunctions.

Author: Sun Rui <ru...@intel.com>

Closes #9023 from sun-rui/SPARK-10996.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/390b22fa
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/390b22fa
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/390b22fa

Branch: refs/heads/master
Commit: 390b22fad69a33eb6daee25b6b858a2e768670a5
Parents: 8b32885
Author: Sun Rui <ru...@intel.com>
Authored: Tue Oct 13 22:31:23 2015 -0700
Committer: Shivaram Venkataraman <sh...@cs.berkeley.edu>
Committed: Tue Oct 13 22:31:23 2015 -0700

----------------------------------------------------------------------
 R/pkg/NAMESPACE                  |  3 ++-
 R/pkg/R/DataFrame.R              | 14 ++++++--------
 R/pkg/R/generics.R               |  6 +++++-
 R/pkg/R/sparkR.R                 | 12 +++---------
 R/pkg/R/stats.R                  | 32 ++++++++++++++++++++++++++++++++
 R/pkg/R/utils.R                  | 18 ++++++++++++++++++
 R/pkg/inst/tests/test_sparkSQL.R | 10 ++++++++++
 7 files changed, 76 insertions(+), 19 deletions(-)
----------------------------------------------------------------------
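
For orientation, a hedged sketch of the API this commit adds, modeled on the
new test case at the bottom of the patch (sqlContext is assumed to come from
an already-initialized SparkR session):

  df <- createDataFrame(sqlContext, lapply(0:99, function(i) { as.character(i %% 3) }), "key")
  fractions <- list("0" = 0.1, "1" = 0.2)       # strata not listed are sampled with fraction 0
  sampled <- sampleBy(df, "key", fractions, 0)  # the last argument is the sampling seed
  collect(count(groupBy(sampled, "key")))

Strata whose key does not appear in fractions (here "2") are dropped from the
sample entirely, since their fraction is treated as zero.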


http://git-wip-us.apache.org/repos/asf/spark/blob/390b22fa/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index ed9cd94..52f7a01 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -65,6 +65,7 @@ exportMethods("arrange",
               "repartition",
               "sample",
               "sample_frac",
+              "sampleBy",
               "saveAsParquetFile",
               "saveAsTable",
               "saveDF",
@@ -254,4 +255,4 @@ export("structField",
        "structType.structField",
        "print.structType")
 
-export("as.data.frame")
\ No newline at end of file
+export("as.data.frame")

http://git-wip-us.apache.org/repos/asf/spark/blob/390b22fa/R/pkg/R/DataFrame.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index b7f5f97..993be82 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1831,17 +1831,15 @@ setMethod("fillna",
               if (length(colNames) == 0 || !all(colNames != "")) {
                 stop("value should be an a named list with each name being a column name.")
               }
-
-              # Convert to the named list to an environment to be passed to JVM
-              valueMap <- new.env()
-              for (col in colNames) {
-                # Check each item in the named list is of valid type
-                v <- value[[col]]
+              # Check each item in the named list is of valid type
+              lapply(value, function(v) {
                 if (!(class(v) %in% c("integer", "numeric", "character"))) {
                   stop("Each item in value should be an integer, numeric or charactor.")
                 }
-                valueMap[[col]] <- v
-              }
+              })
+
+              # Convert to the named list to an environment to be passed to JVM
+              valueMap <- convertNamedListToEnv(value)
 
               # When value is a named list, caller is expected not to pass in cols
               if (!is.null(cols)) {
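
The hunk above is a pure refactoring of fillna(): the per-item type check moves
into an lapply(), and the hand-rolled environment construction is replaced by
the new convertNamedListToEnv() helper added in R/pkg/R/utils.R further down.
The user-facing call is unchanged; a hedged usage sketch, assuming df is an
existing DataFrame with "age" and "name" columns:

  filled <- fillna(df, list(age = 0, name = "unknown"))
  # Each replacement value must be integer, numeric, or character; anything
  # else (e.g. a logical) still fails the check, now performed inside lapply():
  # fillna(df, list(age = TRUE))   # error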

http://git-wip-us.apache.org/repos/asf/spark/blob/390b22fa/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index c106a00..4a419f7 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -509,6 +509,10 @@ setGeneric("sample",
 setGeneric("sample_frac",
            function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") })
 
+#' @rdname statfunctions
+#' @export
+setGeneric("sampleBy", function(x, col, fractions, seed) { standardGeneric("sampleBy") })
+
 #' @rdname saveAsParquetFile
 #' @export
 setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") })
@@ -1006,4 +1010,4 @@ setGeneric("as.data.frame")
 
 #' @rdname attach
 #' @export
-setGeneric("attach")
\ No newline at end of file
+setGeneric("attach")
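
The new generic is grouped under the shared "statfunctions" Rd page via @rdname
and is implemented by the setMethod() call in R/pkg/R/stats.R below. As a
plain-R illustration of this setGeneric()/setMethod() pairing (toy function
name and classes, no Spark involved):

  library(methods)
  setGeneric("describeCol", function(x, col) { standardGeneric("describeCol") })
  setMethod("describeCol",
            signature(x = "data.frame", col = "character"),
            function(x, col) { summary(x[[col]]) })
  describeCol(mtcars, "mpg")   # dispatch selects the (data.frame, character) method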

http://git-wip-us.apache.org/repos/asf/spark/blob/390b22fa/R/pkg/R/sparkR.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index cc47110..9cf2f1a 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -163,19 +163,13 @@ sparkR.init <- function(
     sparkHome <- suppressWarnings(normalizePath(sparkHome))
   }
 
-  sparkEnvirMap <- new.env()
-  for (varname in names(sparkEnvir)) {
-    sparkEnvirMap[[varname]] <- sparkEnvir[[varname]]
-  }
+  sparkEnvirMap <- convertNamedListToEnv(sparkEnvir)
 
-  sparkExecutorEnvMap <- new.env()
-  if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) {
+  sparkExecutorEnvMap <- convertNamedListToEnv(sparkExecutorEnv)
+  if(is.null(sparkExecutorEnvMap$LD_LIBRARY_PATH)) {
     sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <-
       paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH"))
   }
-  for (varname in names(sparkExecutorEnv)) {
-    sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]]
-  }
 
   nonEmptyJars <- Filter(function(x) { x != "" }, jars)
   localJarPaths <- lapply(nonEmptyJars,
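
This hunk replaces two hand-written list-to-environment loops with the shared
helper; LD_LIBRARY_PATH is still defaulted in the executor environment only
when the caller has not set it. A hedged sketch of how these named lists are
supplied (the values are illustrative):

  sc <- sparkR.init(master = "local[2]", appName = "SampleByDemo",
                    sparkEnvir = list(spark.executor.memory = "1g"),
                    sparkExecutorEnv = list(LD_LIBRARY_PATH = "/opt/native/lib"))
  # Because LD_LIBRARY_PATH is set explicitly here, the default of
  # "$LD_LIBRARY_PATH:<driver LD_LIBRARY_PATH>" shown in the diff is not applied.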

http://git-wip-us.apache.org/repos/asf/spark/blob/390b22fa/R/pkg/R/stats.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R
index 4928cf4..f79329b 100644
--- a/R/pkg/R/stats.R
+++ b/R/pkg/R/stats.R
@@ -127,3 +127,35 @@ setMethod("freqItems", signature(x = "DataFrame", cols = "character"),
             sct <- callJMethod(statFunctions, "freqItems", as.list(cols), support)
             collect(dataFrame(sct))
           })
+
+#' sampleBy
+#'
+#' Returns a stratified sample without replacement based on the fraction given on each stratum.
+#' 
+#' @param x A SparkSQL DataFrame
+#' @param col column that defines strata
+#' @param fractions A named list giving sampling fraction for each stratum. If a stratum is
+#'                  not specified, we treat its fraction as zero.
+#' @param seed random seed
+#' @return A new DataFrame that represents the stratified sample
+#'
+#' @rdname statfunctions
+#' @name sampleBy
+#' @export
+#' @examples
+#'\dontrun{
+#' df <- jsonFile(sqlContext, "/path/to/file.json")
+#' sample <- sampleBy(df, "key", fractions, 36)
+#' }
+setMethod("sampleBy",
+          signature(x = "DataFrame", col = "character",
+                    fractions = "list", seed = "numeric"),
+          function(x, col, fractions, seed) {
+            fractionsEnv <- convertNamedListToEnv(fractions)
+
+            statFunctions <- callJMethod(x@sdf, "stat")
+            # Seed is expected to be Long on Scala side, here convert it to an integer
+            # due to SerDe limitation now.
+            sdf <- callJMethod(statFunctions, "sampleBy", col, fractionsEnv, as.integer(seed))
+            dataFrame(sdf)
+          })
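
The roxygen example above references a fractions object without defining it; a
hedged, self-contained variant (the JSON path, column name and stratum values
are placeholders):

  df <- jsonFile(sqlContext, "/path/to/file.json")
  fractions <- list("a" = 0.1, "b" = 0.2)   # names are values of the strata column;
                                            # unlisted strata get fraction 0
  sample <- sampleBy(df, "key", fractions, 36)

On the seed: the Scala-side sampleBy() expects a Long, and since the current
SerDe cannot send one from R, the method coerces the seed with as.integer(),
which effectively limits usable seeds to R's 32-bit integer range.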

http://git-wip-us.apache.org/repos/asf/spark/blob/390b22fa/R/pkg/R/utils.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index 94f16c7..0b9e295 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -605,3 +605,21 @@ structToList <- function(struct) {
   class(struct) <- "list"
   struct
 }
+
+# Convert a named list to an environment to be passed to JVM
+convertNamedListToEnv <- function(namedList) {
+  # Make sure each item in the list has a name
+  names <- names(namedList)
+  stopifnot(
+    if (is.null(names)) {
+      length(namedList) == 0
+    } else {
+      !any(is.na(names))
+    })
+
+  env <- new.env()
+  for (name in names) {
+    env[[name]] <- namedList[[name]]
+  }
+  env
+}
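
A short illustration of the new helper's behavior (plain R; the function is
internal to the package and unexported, so it is shown here as if sourced
directly):

  e <- convertNamedListToEnv(list(a = 1L, b = "x"))
  ls(e)                                 # "a" "b"
  e[["b"]]                              # "x"
  convertNamedListToEnv(list())         # an empty list yields an empty environment
  # convertNamedListToEnv(list(1, 2))   # stopifnot() fails: non-empty list without names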

http://git-wip-us.apache.org/repos/asf/spark/blob/390b22fa/R/pkg/inst/tests/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 46cab76..e1b42b0 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -1416,6 +1416,16 @@ test_that("freqItems() on a DataFrame", {
   expect_identical(result[[2]], list(list(-1, -99)))
 })
 
+test_that("sampleBy() on a DataFrame", {
+  l <- lapply(c(0:99), function(i) { as.character(i %% 3) })
+  df <- createDataFrame(sqlContext, l, "key")
+  fractions <- list("0" = 0.1, "1" = 0.2)
+  sample <- sampleBy(df, "key", fractions, 0)
+  result <- collect(orderBy(count(groupBy(sample, "key")), "key"))
+  expect_identical(as.list(result[1, ]), list(key = "0", count = 2))
+  expect_identical(as.list(result[2, ]), list(key = "1", count = 10))
+})
+
 test_that("SQL error message is returned from JVM", {
   retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e)
   expect_equal(grepl("Table Not Found: blah", retError), TRUE)
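
On the new sampleBy() test above: the generated data has roughly equal strata
(34 rows of key "0", 33 each of "1" and "2"), stratum "2" is absent from
fractions and is therefore dropped, and the fixed seed makes the draw
reproducible, which is why expect_identical() can assert exact counts (2 and
10) rather than ranges around the expected ~3.4 and ~6.6. A quick plain-R
check of the stratum sizes:

  table(as.character((0:99) %% 3))   # "0": 34, "1": 33, "2": 33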

