You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2015/10/07 18:46:43 UTC

spark git commit: [SPARK-10752] [SPARKR] Implement corr() and cov in DataFrameStatFunctions.

Repository: spark
Updated Branches:
  refs/heads/master 27cdde2ff -> f57c63d4c


[SPARK-10752] [SPARKR] Implement corr() and cov in DataFrameStatFunctions.

Author: Sun Rui <ru...@intel.com>

Closes #8869 from sun-rui/SPARK-10752.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f57c63d4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f57c63d4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f57c63d4

Branch: refs/heads/master
Commit: f57c63d4c30d092a320c72f8c7181f2fa711ec30
Parents: 27cdde2
Author: Sun Rui <ru...@intel.com>
Authored: Wed Oct 7 09:46:37 2015 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Wed Oct 7 09:46:37 2015 -0700

----------------------------------------------------------------------
 R/pkg/DESCRIPTION                |   1 +
 R/pkg/NAMESPACE                  |   2 +
 R/pkg/R/DataFrame.R              |  33 +----------
 R/pkg/R/generics.R               |  10 +++-
 R/pkg/R/stats.R                  | 102 ++++++++++++++++++++++++++++++++++
 R/pkg/inst/tests/test_sparkSQL.R |  12 ++++
 6 files changed, 127 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f57c63d4/R/pkg/DESCRIPTION
----------------------------------------------------------------------
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index a3a16c4..3d6edb7 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -33,4 +33,5 @@ Collate:
     'mllib.R'
     'serialize.R'
     'sparkR.R'
+    'stats.R'
     'utils.R'

http://git-wip-us.apache.org/repos/asf/spark/blob/f57c63d4/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index c28c47d..9aad354 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -27,6 +27,8 @@ exportMethods("arrange",
               "collect",
               "columns",
               "count",
+              "cov",
+              "corr",
               "crosstab",
               "describe",
               "dim",

http://git-wip-us.apache.org/repos/asf/spark/blob/f57c63d4/R/pkg/R/DataFrame.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 14aea92..85db3a5 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1828,36 +1828,6 @@ setMethod("fillna",
             dataFrame(sdf)
           })
 
-#' crosstab
-#'
-#' Computes a pair-wise frequency table of the given columns. Also known as a contingency
-#' table. The number of distinct values for each column should be less than 1e4. At most 1e6
-#' non-zero pair frequencies will be returned.
-#'
-#' @param col1 name of the first column. Distinct items will make the first item of each row.
-#' @param col2 name of the second column. Distinct items will make the column names of the output.
-#' @return a local R data.frame representing the contingency table. The first column of each row
-#'         will be the distinct values of `col1` and the column names will be the distinct values
-#'         of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no
-#'         occurrences will have zero as their counts.
-#'
-#' @rdname statfunctions
-#' @name crosstab
-#' @export
-#' @examples
-#' \dontrun{
-#' df <- jsonFile(sqlCtx, "/path/to/file.json")
-#' ct = crosstab(df, "title", "gender")
-#' }
-setMethod("crosstab",
-          signature(x = "DataFrame", col1 = "character", col2 = "character"),
-          function(x, col1, col2) {
-            statFunctions <- callJMethod(x@sdf, "stat")
-            sct <- callJMethod(statFunctions, "crosstab", col1, col2)
-            collect(dataFrame(sct))
-          })
-
-
 #' This function downloads the contents of a DataFrame into an R's data.frame.
 #' Since data.frames are held in memory, ensure that you have enough memory
 #' in your system to accommodate the contents.
@@ -1879,5 +1849,4 @@ setMethod("as.data.frame",
               stop(paste("Unused argument(s): ", paste(list(...), collapse=", ")))
             }
             collect(x)
-          }
-)
+          })

http://git-wip-us.apache.org/repos/asf/spark/blob/f57c63d4/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 3db41e0..e9086fd 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -399,6 +399,14 @@ setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") })
 #' @export
 setGeneric("columns", function(x) {standardGeneric("columns") })
 
+#' @rdname statfunctions
+#' @export
+setGeneric("cov", function(x, col1, col2) {standardGeneric("cov") })
+
+#' @rdname statfunctions
+#' @export
+setGeneric("corr", function(x, col1, col2, method = "pearson") {standardGeneric("corr") })
+
 #' @rdname describe
 #' @export
 setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })
@@ -986,4 +994,4 @@ setGeneric("rbind", signature = "...")
 
 #' @rdname as.data.frame
 #' @export
-setGeneric("as.data.frame")
\ No newline at end of file
+setGeneric("as.data.frame")

http://git-wip-us.apache.org/repos/asf/spark/blob/f57c63d4/R/pkg/R/stats.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R
new file mode 100644
index 0000000..06382d5
--- /dev/null
+++ b/R/pkg/R/stats.R
@@ -0,0 +1,102 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# stats.R - Statistic functions for DataFrames.
+
+setOldClass("jobj")
+
+#' crosstab
+#'
+#' Computes a pair-wise frequency table of the given columns. Also known as a contingency
+#' table. The number of distinct values for each column should be less than 1e4. At most 1e6
+#' non-zero pair frequencies will be returned.
+#'
+#' @param col1 name of the first column. Distinct items will make the first item of each row.
+#' @param col2 name of the second column. Distinct items will make the column names of the output.
+#' @return a local R data.frame representing the contingency table. The first column of each row
+#'         will be the distinct values of `col1` and the column names will be the distinct values
+#'         of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no
+#'         occurrences will have zero as their counts.
+#'
+#' @rdname statfunctions
+#' @name crosstab
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- jsonFile(sqlContext, "/path/to/file.json")
+#' ct <- crosstab(df, "title", "gender")
+#' }
+setMethod("crosstab",
+          signature(x = "DataFrame", col1 = "character", col2 = "character"),
+          function(x, col1, col2) {
+            statFunctions <- callJMethod(x@sdf, "stat")  # result of stat() on the backing sdf
+            sct <- callJMethod(statFunctions, "crosstab", col1, col2)
+            collect(dataFrame(sct))  # wrap as a DataFrame, then collect it locally
+          })
+
+#' cov
+#'
+#' Calculate the sample covariance of two numerical columns of a DataFrame.
+#'
+#' @param x A SparkSQL DataFrame
+#' @param col1 the name of the first column
+#' @param col2 the name of the second column
+#' @return the sample covariance of the two columns.
+#'
+#' @rdname statfunctions
+#' @name cov
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- jsonFile(sqlContext, "/path/to/file.json")
+#' covariance <- cov(df, "age", "height")
+#' }
+setMethod("cov",
+          signature(x = "DataFrame", col1 = "character", col2 = "character"),
+          function(x, col1, col2) {
+            statFunctions <- callJMethod(x@sdf, "stat")  # result of stat() on the backing sdf
+            callJMethod(statFunctions, "cov", col1, col2)
+          })
+
+#' corr
+#'
+#' Calculates the correlation of two columns of a DataFrame.
+#' Currently only supports the Pearson Correlation Coefficient.
+#' For Spearman Correlation, consider using RDD methods found in MLlib's Statistics.
+#'
+#' @param x A SparkSQL DataFrame
+#' @param col1 the name of the first column
+#' @param col2 the name of the second column
+#' @param method Optional. A character specifying the method for calculating the correlation.
+#'               Only "pearson" is allowed now.
+#' @return The Pearson Correlation Coefficient as a Double.
+#'
+#' @rdname statfunctions
+#' @name corr
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- jsonFile(sqlContext, "/path/to/file.json")
+#' correlation <- corr(df, "age", "height")
+#' correlation <- corr(df, "age", "height", method = "pearson")
+#' }
+setMethod("corr",
+          signature(x = "DataFrame", col1 = "character", col2 = "character"),
+          function(x, col1, col2, method = "pearson") {
+            statFunctions <- callJMethod(x@sdf, "stat")  # result of stat() on the backing sdf
+            callJMethod(statFunctions, "corr", col1, col2, method)  # method is passed through as-is
+          })

http://git-wip-us.apache.org/repos/asf/spark/blob/f57c63d4/R/pkg/inst/tests/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index faf42b7..bcf52b8 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -1329,6 +1329,18 @@ test_that("crosstab() on a DataFrame", {
   expect_identical(expected, ordered)
 })
 
+test_that("cov() and corr() on a DataFrame", {
+  l <- lapply(c(0:9), function(x) { list(x, x * 2.0) })  # rows (x, 2 * x) for x in 0..9
+  df <- createDataFrame(sqlContext, l, c("singles", "doubles"))
+  result <- cov(df, "singles", "doubles")
+  expect_true(abs(result - 55.0 / 3) < 1e-12)  # cov(x, 2 * x) = 2 * var(0:9) = 55 / 3
+
+  result <- corr(df, "singles", "doubles")  # default method
+  expect_true(abs(result - 1.0) < 1e-12)  # doubles is exactly 2 * singles, so r = 1
+  result <- corr(df, "singles", "doubles", "pearson")  # method given explicitly
+  expect_true(abs(result - 1.0) < 1e-12)
+})
+
 test_that("SQL error message is returned from JVM", {
   retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e)
   expect_equal(grepl("Table Not Found: blah", retError), TRUE)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org