You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sh...@apache.org on 2015/05/09 03:30:05 UTC

spark git commit: [SPARK-7231] [SPARKR] Changes to make SparkR DataFrame dplyr friendly.

Repository: spark
Updated Branches:
  refs/heads/master b6c797b08 -> 0a901dd3a


[SPARK-7231] [SPARKR] Changes to make SparkR DataFrame dplyr friendly.

Changes include
1. Rename sortDF to arrange
2. Add new aliases `group_by` and `sample_frac`, `summarize`
3. Add more user friendly column addition (mutate), rename
4. Support mean as an alias for avg in Scala and also support n_distinct, n as in dplyr

Using these changes we can pretty much run the examples as described in http://cran.rstudio.com/web/packages/dplyr/vignettes/introduction.html with the same syntax

The only thing missing in SparkR is auto-resolving column names when they are used in an expression, i.e., something like `select(flights, delay)` works in dplyr, but right now we need `select(flights, flights$delay)` or `select(flights, "delay")`. This is a complicated change, so I'll file a new issue for it.

cc sun-rui rxin

Author: Shivaram Venkataraman <sh...@cs.berkeley.edu>

Closes #6005 from shivaram/sparkr-df-api and squashes the following commits:

5e0716a [Shivaram Venkataraman] Fix some roxygen bugs
1254953 [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into sparkr-df-api
0521149 [Shivaram Venkataraman] Changes to make SparkR DataFrame dplyr friendly. Changes include 1. Rename sortDF to arrange 2. Add new aliases `group_by` and `sample_frac`, `summarize` 3. Add more user friendly column addition (mutate), rename 4. Support mean as an alias for avg in Scala and also support n_distinct, n as in dplyr


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0a901dd3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0a901dd3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0a901dd3

Branch: refs/heads/master
Commit: 0a901dd3a1eb3fd459d45b771ce4ad2cfef2a944
Parents: b6c797b
Author: Shivaram Venkataraman <sh...@cs.berkeley.edu>
Authored: Fri May 8 18:29:57 2015 -0700
Committer: Shivaram Venkataraman <sh...@cs.berkeley.edu>
Committed: Fri May 8 18:29:57 2015 -0700

----------------------------------------------------------------------
 R/pkg/NAMESPACE                                 |  11 +-
 R/pkg/R/DataFrame.R                             | 127 +++++++++++++++++--
 R/pkg/R/column.R                                |  32 ++++-
 R/pkg/R/generics.R                              |  41 +++++-
 R/pkg/R/group.R                                 |  10 +-
 R/pkg/inst/tests/test_sparkSQL.R                |  36 +++++-
 .../scala/org/apache/spark/sql/functions.scala  |  16 +++
 .../org/apache/spark/sql/DataFrameSuite.scala   |   5 +
 8 files changed, 249 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0a901dd3/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 7611f47..819e9a2 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -9,7 +9,8 @@ export("print.jobj")
 
 exportClasses("DataFrame")
 
-exportMethods("cache",
+exportMethods("arrange",
+              "cache",
               "collect",
               "columns",
               "count",
@@ -20,6 +21,7 @@ exportMethods("cache",
               "explain",
               "filter",
               "first",
+              "group_by",
               "groupBy",
               "head",
               "insertInto",
@@ -28,12 +30,15 @@ exportMethods("cache",
               "join",
               "limit",
               "orderBy",
+              "mutate",
               "names",
               "persist",
               "printSchema",
               "registerTempTable",
+              "rename",
               "repartition",
               "sampleDF",
+              "sample_frac",
               "saveAsParquetFile",
               "saveAsTable",
               "saveDF",
@@ -42,7 +47,7 @@ exportMethods("cache",
               "selectExpr",
               "show",
               "showDF",
-              "sortDF",
+              "summarize",
               "take",
               "unionAll",
               "unpersist",
@@ -72,6 +77,8 @@ exportMethods("abs",
               "max",
               "mean",
               "min",
+              "n",
+              "n_distinct",
               "rlike",
               "sqrt",
               "startsWith",

http://git-wip-us.apache.org/repos/asf/spark/blob/0a901dd3/R/pkg/R/DataFrame.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 354642e..8a9d2dd 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -480,6 +480,7 @@ setMethod("distinct",
 #' @param withReplacement Sampling with replacement or not
 #' @param fraction The (rough) sample target fraction
 #' @rdname sampleDF
+#' @aliases sample_frac
 #' @export
 #' @examples
 #'\dontrun{
@@ -501,6 +502,15 @@ setMethod("sampleDF",
             dataFrame(sdf)
           })
 
+#' @rdname sampleDF
+#' @aliases sampleDF
+setMethod("sample_frac",
+          signature(x = "DataFrame", withReplacement = "logical",
+                    fraction = "numeric"),
+          function(x, withReplacement, fraction) {
+            sampleDF(x, withReplacement, fraction)
+          })
+
 #' Count
 #' 
 #' Returns the number of rows in a DataFrame
@@ -682,7 +692,8 @@ setMethod("toRDD",
 #' @param x a DataFrame
 #' @return a GroupedData
 #' @seealso GroupedData
-#' @rdname DataFrame
+#' @aliases group_by
+#' @rdname groupBy
 #' @export
 #' @examples
 #' \dontrun{
@@ -705,12 +716,21 @@ setMethod("groupBy",
              groupedData(sgd)
            })
 
-#' Agg
+#' @rdname groupBy
+#' @aliases group_by
+setMethod("group_by",
+          signature(x = "DataFrame"),
+          function(x, ...) {
+            groupBy(x, ...)
+          })
+
+#' Summarize data across columns
 #'
 #' Compute aggregates by specifying a list of columns
 #'
 #' @param x a DataFrame
 #' @rdname DataFrame
+#' @aliases summarize
 #' @export
 setMethod("agg",
           signature(x = "DataFrame"),
@@ -718,6 +738,14 @@ setMethod("agg",
             agg(groupBy(x), ...)
           })
 
+#' @rdname DataFrame
+#' @aliases agg
+setMethod("summarize",
+          signature(x = "DataFrame"),
+          function(x, ...) {
+            agg(x, ...)
+          })
+
 
 ############################## RDD Map Functions ##################################
 # All of the following functions mirror the existing RDD map functions,           #
@@ -886,7 +914,7 @@ setMethod("select",
           signature(x = "DataFrame", col = "list"),
           function(x, col) {
             cols <- lapply(col, function(c) {
-              if (class(c)== "Column") {
+              if (class(c) == "Column") {
                 c@jc
               } else {
                 col(c)@jc
@@ -946,6 +974,42 @@ setMethod("withColumn",
             select(x, x$"*", alias(col, colName))
           })
 
+#' Mutate
+#'
+#' Return a new DataFrame with the specified columns added.
+#'
+#' @param x A DataFrame
+#' @param ... One or more named arguments of the form name = col
+#' @return A new DataFrame with the new columns added.
+#' @rdname withColumn
+#' @aliases withColumn
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2)
+#' names(newDF) # Will contain newCol, newCol2
+#' }
+setMethod("mutate",
+          signature(x = "DataFrame"),
+          function(x, ...) {
+            cols <- list(...)
+            stopifnot(length(cols) > 0)
+            stopifnot(class(cols[[1]]) == "Column")
+            ns <- names(cols)
+            if (!is.null(ns)) {
+              for (n in ns) {
+                if (n != "") {
+                  cols[[n]] <- alias(cols[[n]], n)
+                }
+              }
+            }
+            do.call(select, c(x, x$"*", cols))
+          })
+
 #' WithColumnRenamed
 #'
 #' Rename an existing column in a DataFrame.
@@ -977,9 +1041,47 @@ setMethod("withColumnRenamed",
             select(x, cols)
           })
 
+#' Rename
+#'
+#' Rename an existing column in a DataFrame.
+#'
+#' @param x A DataFrame
+#' @param ... One or more named pairs of the form new_column_name = existing_column
+#' @return A DataFrame with the column name changed.
+#' @rdname withColumnRenamed
+#' @aliases withColumnRenamed
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' newDF <- rename(df, newCol1 = df$col1)
+#' }
+setMethod("rename",
+          signature(x = "DataFrame"),
+          function(x, ...) {
+            renameCols <- list(...)
+            stopifnot(length(renameCols) > 0)
+            stopifnot(class(renameCols[[1]]) == "Column")
+            newNames <- names(renameCols)
+            oldNames <- lapply(renameCols, function(col) {
+              callJMethod(col@jc, "toString")
+            })
+            cols <- lapply(columns(x), function(c) {
+              if (c %in% oldNames) {
+                alias(col(c), newNames[[match(c, oldNames)]])
+              } else {
+                col(c)
+              }
+            })
+            select(x, cols)
+          })
+
 setClassUnion("characterOrColumn", c("character", "Column"))
 
-#' SortDF 
+#' Arrange
 #'
 #' Sort a DataFrame by the specified column(s).
 #'
@@ -987,7 +1089,7 @@ setClassUnion("characterOrColumn", c("character", "Column"))
 #' @param col Either a Column object or character vector indicating the field to sort on
 #' @param ... Additional sorting fields
 #' @return A DataFrame where all elements are sorted.
-#' @rdname sortDF
+#' @rdname arrange
 #' @export
 #' @examples
 #'\dontrun{
@@ -995,11 +1097,11 @@ setClassUnion("characterOrColumn", c("character", "Column"))
 #' sqlCtx <- sparkRSQL.init(sc)
 #' path <- "path/to/file.json"
 #' df <- jsonFile(sqlCtx, path)
-#' sortDF(df, df$col1)
-#' sortDF(df, "col1")
-#' sortDF(df, asc(df$col1), desc(abs(df$col2)))
+#' arrange(df, df$col1)
+#' arrange(df, "col1")
+#' arrange(df, asc(df$col1), desc(abs(df$col2)))
 #' }
-setMethod("sortDF",
+setMethod("arrange",
           signature(x = "DataFrame", col = "characterOrColumn"),
           function(x, col, ...) {
             if (class(col) == "character") {
@@ -1013,12 +1115,12 @@ setMethod("sortDF",
             dataFrame(sdf)
           })
 
-#' @rdname sortDF
+#' @rdname arrange
 #' @aliases orderBy,DataFrame,function-method
 setMethod("orderBy",
           signature(x = "DataFrame", col = "characterOrColumn"),
           function(x, col) {
-            sortDF(x, col)
+            arrange(x, col)
           })
 
 #' Filter
@@ -1026,7 +1128,7 @@ setMethod("orderBy",
 #' Filter the rows of a DataFrame according to a given condition.
 #'
 #' @param x A DataFrame to be sorted.
-#' @param condition The condition to sort on. This may either be a Column expression
+#' @param condition The condition to filter on. This may either be a Column expression
 #' or a string containing a SQL statement
 #' @return A DataFrame containing only the rows that meet the condition.
 #' @rdname filter
@@ -1106,6 +1208,7 @@ setMethod("join",
 #'
 #' Return a new DataFrame containing the union of rows in this DataFrame
 #' and another DataFrame. This is equivalent to `UNION ALL` in SQL.
+#' Note that this does not remove duplicate rows across the two DataFrames.
 #'
 #' @param x A Spark DataFrame
 #' @param y A Spark DataFrame

http://git-wip-us.apache.org/repos/asf/spark/blob/0a901dd3/R/pkg/R/column.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R
index 95fb9ff..9a68445 100644
--- a/R/pkg/R/column.R
+++ b/R/pkg/R/column.R
@@ -131,6 +131,8 @@ createMethods()
 #' alias
 #'
 #' Set a new name for a column
+#'
+#' @rdname column
 setMethod("alias",
           signature(object = "Column"),
           function(object, data) {
@@ -141,8 +143,12 @@ setMethod("alias",
             }
           })
 
+#' substr
+#'
 #' An expression that returns a substring.
 #'
+#' @rdname column
+#'
 #' @param start starting position
 #' @param stop ending position
 setMethod("substr", signature(x = "Column"),
@@ -152,6 +158,9 @@ setMethod("substr", signature(x = "Column"),
           })
 
 #' Casts the column to a different data type.
+#'
+#' @rdname column
+#'
 #' @examples
 #' \dontrun{
 #'   cast(df$age, "string")
@@ -173,8 +182,8 @@ setMethod("cast",
 
 #' Approx Count Distinct
 #'
-#' Returns the approximate number of distinct items in a group.
-#'
+#' @rdname column
+#' @return the approximate number of distinct items in a group.
 setMethod("approxCountDistinct",
           signature(x = "Column"),
           function(x, rsd = 0.95) {
@@ -184,8 +193,8 @@ setMethod("approxCountDistinct",
 
 #' Count Distinct
 #'
-#' returns the number of distinct items in a group.
-#'
+#' @rdname column
+#' @return the number of distinct items in a group.
 setMethod("countDistinct",
           signature(x = "Column"),
           function(x, ...) {
@@ -197,3 +206,18 @@ setMethod("countDistinct",
             column(jc)
           })
 
+#' @rdname column
+#' @aliases countDistinct
+setMethod("n_distinct",
+          signature(x = "Column"),
+          function(x, ...) {
+            countDistinct(x, ...)
+          })
+
+#' @rdname column
+#' @aliases count
+setMethod("n",
+          signature(x = "Column"),
+          function(x) {
+            count(x)
+          })

http://git-wip-us.apache.org/repos/asf/spark/blob/0a901dd3/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 380e8eb..557128a 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -380,6 +380,14 @@ setGeneric("value", function(bcast) { standardGeneric("value") })
 
 ####################  DataFrame Methods ########################
 
+#' @rdname agg
+#' @export
+setGeneric("agg", function (x, ...) { standardGeneric("agg") })
+
+#' @rdname arrange
+#' @export
+setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") })
+
 #' @rdname schema
 #' @export
 setGeneric("columns", function(x) {standardGeneric("columns") })
@@ -404,6 +412,10 @@ setGeneric("except", function(x, y) { standardGeneric("except") })
 #' @export
 setGeneric("filter", function(x, condition) { standardGeneric("filter") })
 
+#' @rdname groupBy
+#' @export
+setGeneric("group_by", function(x, ...) { standardGeneric("group_by") })
+
 #' @rdname DataFrame
 #' @export
 setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") })
@@ -424,7 +436,11 @@ setGeneric("isLocal", function(x) { standardGeneric("isLocal") })
 #' @export
 setGeneric("limit", function(x, num) {standardGeneric("limit") })
 
-#' @rdname sortDF
+#' @rdname withColumn
+#' @export
+setGeneric("mutate", function(x, ...) {standardGeneric("mutate") })
+
+#' @rdname arrange
 #' @export
 setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") })
 
@@ -432,12 +448,23 @@ setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") })
 #' @export
 setGeneric("printSchema", function(x) { standardGeneric("printSchema") })
 
+#' @rdname withColumnRenamed
+#' @export
+setGeneric("rename", function(x, ...) { standardGeneric("rename") })
+
 #' @rdname registerTempTable
 #' @export
 setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") })
 
 #' @rdname sampleDF
 #' @export
+setGeneric("sample_frac",
+           function(x, withReplacement, fraction, seed) {
+             standardGeneric("sample_frac")
+          })
+
+#' @rdname sampleDF
+#' @export
 setGeneric("sampleDF",
            function(x, withReplacement, fraction, seed) {
              standardGeneric("sampleDF")
@@ -473,9 +500,9 @@ setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr")
 #' @export
 setGeneric("showDF", function(x,...) { standardGeneric("showDF") })
 
-#' @rdname sortDF
+#' @rdname agg
 #' @export
-setGeneric("sortDF", function(x, col, ...) { standardGeneric("sortDF") })
+setGeneric("summarize", function(x,...) { standardGeneric("summarize") })
 
 # @rdname tojson
 # @export
@@ -566,6 +593,14 @@ setGeneric("lower", function(x) { standardGeneric("lower") })
 
 #' @rdname column
 #' @export
+setGeneric("n", function(x) { standardGeneric("n") })
+
+#' @rdname column
+#' @export
+setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") })
+
+#' @rdname column
+#' @export
 setGeneric("rlike", function(x, ...) { standardGeneric("rlike") })
 
 #' @rdname column

http://git-wip-us.apache.org/repos/asf/spark/blob/0a901dd3/R/pkg/R/group.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 02237b3..5a7a8a2 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -56,6 +56,7 @@ setMethod("show", "GroupedData",
 #'
 #' @param x a GroupedData
 #' @return a DataFrame
+#' @rdname agg
 #' @export
 #' @examples
 #' \dontrun{
@@ -83,8 +84,6 @@ setMethod("count",
 #'  df2 <- agg(df, age = "sum")  # new column name will be created as 'SUM(age#0)'
 #'  df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum
 #' }
-setGeneric("agg", function (x, ...) { standardGeneric("agg") })
-
 setMethod("agg",
           signature(x = "GroupedData"),
           function(x, ...) {
@@ -112,6 +111,13 @@ setMethod("agg",
             dataFrame(sdf)
           })
 
+#' @rdname agg
+#' @aliases agg
+setMethod("summarize",
+          signature(x = "GroupedData"),
+          function(x, ...) {
+            agg(x, ...)
+          })
 
 # sum/mean/avg/min/max
 methods <- c("sum", "mean", "avg", "min", "max")

http://git-wip-us.apache.org/repos/asf/spark/blob/0a901dd3/R/pkg/inst/tests/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 7a42e28..dbb535e 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -428,6 +428,10 @@ test_that("sampleDF on a DataFrame", {
   expect_true(inherits(sampled, "DataFrame"))
   sampled2 <- sampleDF(df, FALSE, 0.1)
   expect_true(count(sampled2) < 3)
+
+  # Also test sample_frac
+  sampled3 <- sample_frac(df, FALSE, 0.1)
+  expect_true(count(sampled3) < 3)
 })
 
 test_that("select operators", {
@@ -533,6 +537,7 @@ test_that("column functions", {
   c2 <- min(c) + max(c) + sum(c) + avg(c) + count(c) + abs(c) + sqrt(c)
   c3 <- lower(c) + upper(c) + first(c) + last(c)
   c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string")
+  c5 <- n(c) + n_distinct(c)
 })
 
 test_that("string operators", {
@@ -557,6 +562,13 @@ test_that("group by", {
   expect_true(inherits(df2, "DataFrame"))
   expect_true(3 == count(df2))
 
+  # Also test group_by, summarize, mean
+  gd1 <- group_by(df, "name")
+  expect_true(inherits(gd1, "GroupedData"))
+  df_summarized <- summarize(gd, mean_age = mean(df$age))
+  expect_true(inherits(df_summarized, "DataFrame"))
+  expect_true(3 == count(df_summarized))
+
   df3 <- agg(gd, age = "sum")
   expect_true(inherits(df3, "DataFrame"))
   expect_true(3 == count(df3))
@@ -573,12 +585,12 @@ test_that("group by", {
   expect_true(3 == count(max(gd, "age")))
 })
 
-test_that("sortDF() and orderBy() on a DataFrame", {
+test_that("arrange() and orderBy() on a DataFrame", {
   df <- jsonFile(sqlCtx, jsonPath)
-  sorted <- sortDF(df, df$age)
+  sorted <- arrange(df, df$age)
   expect_true(collect(sorted)[1,2] == "Michael")
 
-  sorted2 <- sortDF(df, "name")
+  sorted2 <- arrange(df, "name")
   expect_true(collect(sorted2)[2,"age"] == 19)
 
   sorted3 <- orderBy(df, asc(df$age))
@@ -659,17 +671,17 @@ test_that("unionAll(), except(), and intersect() on a DataFrame", {
   writeLines(lines, jsonPath2)
   df2 <- loadDF(sqlCtx, jsonPath2, "json")
 
-  unioned <- sortDF(unionAll(df, df2), df$age)
+  unioned <- arrange(unionAll(df, df2), df$age)
   expect_true(inherits(unioned, "DataFrame"))
   expect_true(count(unioned) == 6)
   expect_true(first(unioned)$name == "Michael")
 
-  excepted <- sortDF(except(df, df2), desc(df$age))
+  excepted <- arrange(except(df, df2), desc(df$age))
   expect_true(inherits(unioned, "DataFrame"))
   expect_true(count(excepted) == 2)
   expect_true(first(excepted)$name == "Justin")
 
-  intersected <- sortDF(intersect(df, df2), df$age)
+  intersected <- arrange(intersect(df, df2), df$age)
   expect_true(inherits(unioned, "DataFrame"))
   expect_true(count(intersected) == 1)
   expect_true(first(intersected)$name == "Andy")
@@ -687,6 +699,18 @@ test_that("withColumn() and withColumnRenamed()", {
   expect_true(columns(newDF2)[1] == "newerAge")
 })
 
+test_that("mutate() and rename()", {
+  df <- jsonFile(sqlCtx, jsonPath)
+  newDF <- mutate(df, newAge = df$age + 2)
+  expect_true(length(columns(newDF)) == 3)
+  expect_true(columns(newDF)[3] == "newAge")
+  expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32)
+
+  newDF2 <- rename(df, newerAge = df$age)
+  expect_true(length(columns(newDF2)) == 2)
+  expect_true(columns(newDF2)[1] == "newerAge")
+})
+
 test_that("saveDF() on DataFrame and works with parquetFile", {
   df <- jsonFile(sqlCtx, jsonPath)
   saveDF(df, parquetPath, "parquet", mode="overwrite")

http://git-wip-us.apache.org/repos/asf/spark/blob/0a901dd3/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 1728b0b..fae4bd0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -247,6 +247,22 @@ object functions {
   def last(columnName: String): Column = last(Column(columnName))
 
   /**
+   * Aggregate function: returns the average of the values in a group.
+   * Alias for avg.
+   *
+   * @group agg_funcs
+   */
+  def mean(e: Column): Column = avg(e)
+
+  /**
+   * Aggregate function: returns the average of the values in a group.
+   * Alias for avg.
+   *
+   * @group agg_funcs
+   */
+  def mean(columnName: String): Column = avg(columnName)
+
+  /**
    * Aggregate function: returns the minimum value of the expression in a group.
    *
    * @group agg_funcs

http://git-wip-us.apache.org/repos/asf/spark/blob/0a901dd3/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index d2ca8dc..cf590cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -308,6 +308,11 @@ class DataFrameSuite extends QueryTest {
       testData2.agg(avg('a)),
       Row(2.0))
 
+    // Also check mean
+    checkAnswer(
+      testData2.agg(mean('a)),
+      Row(2.0))
+
     checkAnswer(
       testData2.agg(avg('a), sumDistinct('a)), // non-partial
       Row(2.0, 6.0) :: Nil)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org