Posted to commits@spark.apache.org by sh...@apache.org on 2016/06/16 04:42:21 UTC

spark git commit: [SPARK-12922][SPARKR][WIP] Implement gapply() on DataFrame in SparkR

Repository: spark
Updated Branches:
  refs/heads/master b75f454f9 -> 7c6c69263


[SPARK-12922][SPARKR][WIP] Implement gapply() on DataFrame in SparkR

## What changes were proposed in this pull request?

gapply() applies an R function to each group of a DataFrame, where groups are defined by one or more of its columns, and returns a DataFrame. It is similar to GroupedDataset.flatMapGroups() in the Dataset API.
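
As a quick usage illustration (condensed from the Roxygen example added in this patch), the call below groups a SparkDataFrame by two columns and averages a third; the schema argument describes the columns returned by the user function:

    df <- createDataFrame(
      list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)),
      c("a", "b", "c", "d"))
    schema <- structType(structField("a", "integer"), structField("c", "string"),
                         structField("avg", "double"))
    df1 <- gapply(df, list("a", "c"),
                  function(key, x) {
                    data.frame(key, mean(x$b), stringsAsFactors = FALSE)
                  },
                  schema)
    collect(df1)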

Please let me know what you think and whether you have any ideas to improve it.

Thank you!

## How was this patch tested?
Unit tests.
1. Primitive test with different column types
2. Add a boolean column
3. Compute average by a group

Author: Narine Kokhlikyan <na...@gmail.com>
Author: NarineK <na...@us.ibm.com>

Closes #12836 from NarineK/gapply2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7c6c6926
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7c6c6926
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7c6c6926

Branch: refs/heads/master
Commit: 7c6c6926376c93acc42dd56a399d816f4838f28c
Parents: b75f454
Author: Narine Kokhlikyan <na...@gmail.com>
Authored: Wed Jun 15 21:42:05 2016 -0700
Committer: Shivaram Venkataraman <sh...@cs.berkeley.edu>
Committed: Wed Jun 15 21:42:05 2016 -0700

----------------------------------------------------------------------
 R/pkg/NAMESPACE                                 |   1 +
 R/pkg/R/DataFrame.R                             |  82 ++++++++++-
 R/pkg/R/deserialize.R                           |  30 ++++
 R/pkg/R/generics.R                              |   4 +
 R/pkg/R/group.R                                 |  62 +++++++++
 R/pkg/inst/tests/testthat/test_sparkSQL.R       |  65 +++++++++
 R/pkg/inst/worker/worker.R                      | 138 ++++++++++++-------
 .../scala/org/apache/spark/api/r/RRunner.scala  |  20 ++-
 .../sql/catalyst/plans/logical/object.scala     |  49 +++++++
 .../spark/sql/RelationalGroupedDataset.scala    |  48 ++++++-
 .../org/apache/spark/sql/api/r/SQLUtils.scala   |  26 ++--
 .../spark/sql/execution/SparkStrategies.scala   |   3 +
 .../apache/spark/sql/execution/objects.scala    |  72 +++++++++-
 .../sql/execution/r/MapPartitionsRWrapper.scala |   5 +-
 14 files changed, 540 insertions(+), 65 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index a8cf53f..8db4d5c 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -62,6 +62,7 @@ exportMethods("arrange",
               "filter",
               "first",
               "freqItems",
+              "gapply",
               "group_by",
               "groupBy",
               "head",

http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/R/pkg/R/DataFrame.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 0ff350d..9a9b3f7 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1181,7 +1181,7 @@ dapplyInternal <- function(x, func, schema) {
 #'             func should have only one parameter, to which a data.frame corresponds
 #'             to each partition will be passed.
 #'             The output of func should be a data.frame.
-#' @param schema The schema of the resulting DataFrame after the function is applied.
+#' @param schema The schema of the resulting SparkDataFrame after the function is applied.
 #'               It must match the output of func.
 #' @family SparkDataFrame functions
 #' @rdname dapply
@@ -1267,6 +1267,86 @@ setMethod("dapplyCollect",
             ldf
           })
 
+#' gapply
+#'
+#' Group the SparkDataFrame using the specified columns and apply the R function to each
+#' group.
+#'
+#' @param x A SparkDataFrame
+#' @param cols Grouping columns
+#' @param func A function to be applied to each group of the SparkDataFrame, as specified
+#'             by the grouping columns. The function `func` takes two arguments: a key (the
+#'             values of the grouping columns for that group) and a local R data.frame with
+#'             that group's rows. The output of `func` must be a local R data.frame.
+#' @param schema The schema of the resulting SparkDataFrame after the function is applied.
+#'               The schema must match the output of `func`. It has to be defined for each
+#'               output column, with the preferred output column name and corresponding data type.
+#' @family SparkDataFrame functions
+#' @rdname gapply
+#' @name gapply
+#' @export
+#' @examples
+#' 
+#' \dontrun{
+#' Computes the arithmetic mean of the second column by grouping
+#' on the first and third columns. Output the grouping values and the average.
+#'
+#' df <- createDataFrame (
+#' list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)),
+#'   c("a", "b", "c", "d"))
+#'
+#' Here our output contains three columns: the key, which is a combination of two
+#' columns with data types integer and string, and the mean, which is a double.
+#' schema <-  structType(structField("a", "integer"), structField("c", "string"),
+#'   structField("avg", "double"))
+#' df1 <- gapply(
+#'   df,
+#'   list("a", "c"),
+#'   function(key, x) {
+#'     y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE)
+#'   },
+#' schema)
+#' collect(df1)
+#'
+#' Result
+#' ------
+#' a c avg
+#' 3 3 3.0
+#' 1 1 1.5
+#'
+#' Fits linear models on the iris dataset by grouping on the 'Species' column and
+#' using 'Sepal_Length' as a target variable, 'Sepal_Width', 'Petal_Length'
+#' and 'Petal_Width' as training features.
+#' 
+#' df <- createDataFrame (iris)
+#' schema <- structType(structField("(Intercept)", "double"),
+#'   structField("Sepal_Width", "double"),structField("Petal_Length", "double"),
+#'   structField("Petal_Width", "double"))
+#' df1 <- gapply(
+#'   df,
+#'   list(df$"Species"),
+#'   function(key, x) {
+#'     m <- suppressWarnings(lm(Sepal_Length ~
+#'     Sepal_Width + Petal_Length + Petal_Width, x))
+#'     data.frame(t(coef(m)))
+#'   }, schema)
+#' collect(df1)
+#'
+#'Result
+#'---------
+#' Model  (Intercept)  Sepal_Width  Petal_Length  Petal_Width
+#' 1        0.699883    0.3303370    0.9455356    -0.1697527
+#' 2        1.895540    0.3868576    0.9083370    -0.6792238
+#' 3        2.351890    0.6548350    0.2375602     0.2521257
+#'
+#'}
+setMethod("gapply",
+          signature(x = "SparkDataFrame"),
+          function(x, cols, func, schema) {
+            grouped <- do.call("groupBy", c(x, cols))
+            gapply(grouped, func, schema)
+          })
+
 ############################## RDD Map Functions ##################################
 # All of the following functions mirror the existing RDD map functions,           #
 # but allow for use with DataFrames by first converting to an RRDD before calling #

http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/R/pkg/R/deserialize.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index ce071b1..0e99b17 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -197,6 +197,36 @@ readMultipleObjects <- function(inputCon) {
   data # this is a list of named lists now
 }
 
+readMultipleObjectsWithKeys <- function(inputCon) {
+  # readMultipleObjectsWithKeys will read multiple continuous objects from
+  # a DataOutputStream. There is no preceding field telling the count
+  # of the objects, so the number of objects varies; we read objects in
+  # a loop until the end of the stream. This function is for use by
+  # gapply. Each group of rows is followed by the grouping key for that
+  # group, which is then followed by the next group.
+  keys <- list()
+  data <- list()
+  subData <- list()
+  while (TRUE) {
+    # If reaching the end of the stream, type returned should be "".
+    type <- readType(inputCon)
+    if (type == "") {
+      break
+    } else if (type == "r") {
+      type <- readType(inputCon)
+      # A grouping boundary detected
+      key <- readTypedObject(inputCon, type)
+      index <- length(data) + 1L
+      data[[index]] <- subData
+      keys[[index]] <- key
+      subData <- list()
+    } else {
+      subData[[length(subData) + 1L]] <- readTypedObject(inputCon, type)
+    }
+  }
+  list(keys = keys, data = data) # this is a list of keys and corresponding data
+}
+
 readRowList <- function(obj) {
   # readRowList is meant for use inside an lapply. As a result, it is
   # necessary to open a standalone connection for the row and consume

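To make the contract of readMultipleObjectsWithKeys() above concrete, here is a hedged sketch (plain R with made-up values, not part of the patch) of the shape it returns for a stream that carried two groups keyed on a column "a": keys and data stay index-aligned, with data[[i]] holding the rows read before the i-th key marker.

    # Illustrative only: the actual keys/rows are whatever the JVM side wrote.
    res <- list(
      keys = list(list(a = 1L), list(a = 3L)),                # one key per group
      data = list(list(list(1L, 1, "1"), list(1L, 2, "1")),   # rows for keys[[1]]
                  list(list(3L, 3, "3"))))                    # rows for keys[[2]]
    stopifnot(length(res$keys) == length(res$data))           # always aligned
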
http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 50fc204..40a96d8 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -454,6 +454,10 @@ setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") })
 #' @export
 setGeneric("dapplyCollect", function(x, func) { standardGeneric("dapplyCollect") })
 
+#' @rdname gapply
+#' @export
+setGeneric("gapply", function(x, ...) { standardGeneric("gapply") })
+
 #' @rdname summary
 #' @export
 setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })

http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/R/pkg/R/group.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 08f4a49..b704776 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -142,3 +142,65 @@ createMethods <- function() {
 }
 
 createMethods()
+
+#' gapply
+#'
+#' Applies an R function to each group in the input GroupedData
+#'
+#' @param x a GroupedData
+#' @param func A function to be applied to each group of the GroupedData.
+#'             The function `func` takes two arguments: a key (the values of the
+#'             grouping columns for that group) and a local R data.frame with that
+#'             group's rows. The output of `func` must be a local R data.frame.
+#' @param schema The schema of the resulting SparkDataFrame after the function is applied.
+#'               The schema must match the output of `func`. It has to be defined for each
+#'               output column, with the preferred output column name and corresponding data type.
+#' @return a SparkDataFrame
+#' @rdname gapply
+#' @name gapply
+#' @examples
+#' \dontrun{
+#' Computes the arithmetic mean of the second column by grouping
+#' on the first and third columns. Output the grouping values and the average.
+#'
+#' df <- createDataFrame (
+#' list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)),
+#'   c("a", "b", "c", "d"))
+#'
+#' Here our output contains three columns: the key, which is a combination of two
+#' columns with data types integer and string, and the mean, which is a double.
+#' schema <-  structType(structField("a", "integer"), structField("c", "string"),
+#'   structField("avg", "double"))
+#' df1 <- gapply(
+#'   df,
+#'   list("a", "c"),
+#'   function(key, x) {
+#'     y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE)
+#'   },
+#' schema)
+#' collect(df1)
+#'
+#' Result
+#' ------
+#' a c avg
+#' 3 3 3.0
+#' 1 1 1.5
+#' }
+setMethod("gapply",
+          signature(x = "GroupedData"),
+          function(x, func, schema) {
+            if (is.null(schema)) stop("schema cannot be NULL")
+            packageNamesArr <- serialize(.sparkREnv[[".packages"]],
+                                 connection = NULL)
+            broadcastArr <- lapply(ls(.broadcastNames),
+                              function(name) { get(name, .broadcastNames) })
+            sdf <- callJStatic(
+                     "org.apache.spark.sql.api.r.SQLUtils",
+                     "gapply",
+                     x@sgd,
+                     serialize(cleanClosure(func), connection = NULL),
+                     packageNamesArr,
+                     broadcastArr,
+                     schema$jobj)
+            dataFrame(sdf)
+          })

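Since the SparkDataFrame method shown earlier in DataFrame.R just builds a GroupedData via do.call("groupBy", ...) and delegates to this method, gapply() can also be invoked on a GroupedData directly. A hedged sketch, reusing the df and schema objects from the example above:

    grouped <- groupBy(df, "a", "c")
    df1 <- gapply(grouped, function(key, x) {
      data.frame(key, mean(x$b), stringsAsFactors = FALSE)
    }, schema)
    collect(df1)
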
http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/R/pkg/inst/tests/testthat/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index d1ca3b7..c11930a 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -2146,6 +2146,71 @@ test_that("repartition by columns on DataFrame", {
   expect_equal(nrow(df1), 2)
 })
 
+test_that("gapply() on a DataFrame", {
+  df <- createDataFrame (
+    list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)),
+    c("a", "b", "c", "d"))
+  expected <- collect(df)
+  df1 <- gapply(df, list("a"), function(key, x) { x }, schema(df))
+  actual <- collect(df1)
+  expect_identical(actual, expected)
+
+  # Computes the sum of the second column by grouping on the first and third columns
+  # and checks if the sum is larger than 2
+  schema <- structType(structField("a", "integer"), structField("e", "boolean"))
+  df2 <- gapply(
+    df,
+    list(df$"a", df$"c"),
+    function(key, x) {
+      y <- data.frame(key[1], sum(x$b) > 2)
+    },
+    schema)
+  actual <- collect(df2)$e
+  expected <- c(TRUE, TRUE)
+  expect_identical(actual, expected)
+
+  # Computes the arithmetic mean of the second column by grouping
+  # on the first and third columns. Output the grouping values and the average.
+  schema <-  structType(structField("a", "integer"), structField("c", "string"),
+               structField("avg", "double"))
+  df3 <- gapply(
+    df,
+    list("a", "c"),
+    function(key, x) {
+      y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE)
+    },
+    schema)
+  actual <- collect(df3)
+  actual <-  actual[order(actual$a), ]
+  rownames(actual) <- NULL
+  expected <- collect(select(df, "a", "b", "c"))
+  expected <- data.frame(aggregate(expected$b, by = list(expected$a, expected$c), FUN = mean))
+  colnames(expected) <- c("a", "c", "avg")
+  expected <-  expected[order(expected$a), ]
+  rownames(expected) <- NULL
+  expect_identical(actual, expected)
+
+  irisDF <- suppressWarnings(createDataFrame (iris))
+  schema <-  structType(structField("Sepal_Length", "double"), structField("Avg", "double"))
+  # Groups by `Sepal_Length` and computes the average for `Sepal_Width`
+  df4 <- gapply(
+    cols = list("Sepal_Length"),
+    irisDF,
+    function(key, x) {
+      y <- data.frame(key, mean(x$Sepal_Width), stringsAsFactors = FALSE)
+    },
+    schema)
+  actual <- collect(df4)
+  actual <- actual[order(actual$Sepal_Length), ]
+  rownames(actual) <- NULL
+  agg_local_df <- data.frame(aggregate(iris$Sepal.Width, by = list(iris$Sepal.Length), FUN = mean),
+                    stringsAsFactors = FALSE)
+  colnames(agg_local_df) <- c("Sepal_Length", "Avg")
+  expected <-  agg_local_df[order(agg_local_df$Sepal_Length), ]
+  rownames(expected) <- NULL
+  expect_identical(actual, expected)
+})
+
 test_that("Window functions on a DataFrame", {
   setHiveContext(sc)
   df <- createDataFrame(list(list(1L, "1"), list(2L, "2"), list(1L, "1"), list(2L, "2")),

http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/R/pkg/inst/worker/worker.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
index 40cda0c..debf018 100644
--- a/R/pkg/inst/worker/worker.R
+++ b/R/pkg/inst/worker/worker.R
@@ -27,6 +27,54 @@ elapsedSecs <- function() {
   proc.time()[3]
 }
 
+compute <- function(mode, partition, serializer, deserializer, key,
+             colNames, computeFunc, inputData) {
+  if (mode > 0) {
+    if (deserializer == "row") {
+      # Transform the list of rows into a data.frame
+      # Note that the optional argument stringsAsFactors for rbind is
+      # available since R 3.2.4. So we set the global option here.
+      oldOpt <- getOption("stringsAsFactors")
+      options(stringsAsFactors = FALSE)
+      inputData <- do.call(rbind.data.frame, inputData)
+      options(stringsAsFactors = oldOpt)
+
+      names(inputData) <- colNames
+    } else {
+      # Check to see if inputData is a valid data.frame
+      stopifnot(deserializer == "byte")
+      stopifnot(class(inputData) == "data.frame")
+    }
+
+    if (mode == 2) {
+      output <- computeFunc(key, inputData)
+    } else {
+      output <- computeFunc(inputData)
+    }
+    if (serializer == "row") {
+      # Transform the result data.frame back to a list of rows
+      output <- split(output, seq(nrow(output)))
+    } else {
+      # Serialize the output to a byte array
+      stopifnot(serializer == "byte")
+    }
+  } else {
+    output <- computeFunc(partition, inputData)
+  }
+  return (output)
+}
+
+outputResult <- function(serializer, output, outputCon) {
+  if (serializer == "byte") {
+    SparkR:::writeRawSerialize(outputCon, output)
+  } else if (serializer == "row") {
+    SparkR:::writeRowSerialize(outputCon, output)
+  } else {
+    # write lines one-by-one with flag
+    lapply(output, function(line) SparkR:::writeString(outputCon, line))
+  }
+}
+
 # Constants
 specialLengths <- list(END_OF_STERAM = 0L, TIMING_DATA = -1L)
 
@@ -79,75 +127,71 @@ if (numBroadcastVars > 0) {
 
 # Timing broadcast
 broadcastElap <- elapsedSecs()
+# Initial input timing
+inputElap <- broadcastElap
 
 # If -1: read as normal RDD; if >= 0, treat as pairwise RDD and treat the int
 # as number of partitions to create.
 numPartitions <- SparkR:::readInt(inputCon)
 
-isDataFrame <- as.logical(SparkR:::readInt(inputCon))
+# 0 - RDD mode, 1 - dapply mode, 2 - gapply mode
+mode <- SparkR:::readInt(inputCon)
 
-# If isDataFrame, then read column names
-if (isDataFrame) {
+if (mode > 0) {
   colNames <- SparkR:::readObject(inputCon)
 }
 
 isEmpty <- SparkR:::readInt(inputCon)
+computeInputElapsDiff <- 0
+outputComputeElapsDiff <- 0
 
 if (isEmpty != 0) {
-
   if (numPartitions == -1) {
     if (deserializer == "byte") {
       # Now read as many characters as described in funcLen
       data <- SparkR:::readDeserialize(inputCon)
     } else if (deserializer == "string") {
       data <- as.list(readLines(inputCon))
+    } else if (deserializer == "row" && mode == 2) {
+      dataWithKeys <- SparkR:::readMultipleObjectsWithKeys(inputCon)
+      keys <- dataWithKeys$keys
+      data <- dataWithKeys$data
     } else if (deserializer == "row") {
       data <- SparkR:::readMultipleObjects(inputCon)
     }
+
     # Timing reading input data for execution
     inputElap <- elapsedSecs()
-
-    if (isDataFrame) {
-      if (deserializer == "row") {
-        # Transform the list of rows into a data.frame
-        # Note that the optional argument stringsAsFactors for rbind is
-        # available since R 3.2.4. So we set the global option here.
-        oldOpt <- getOption("stringsAsFactors")
-        options(stringsAsFactors = FALSE)
-        data <- do.call(rbind.data.frame, data)
-        options(stringsAsFactors = oldOpt)
-
-        names(data) <- colNames
-      } else {
-        # Check to see if data is a valid data.frame
-        stopifnot(deserializer == "byte")
-        stopifnot(class(data) == "data.frame")
-      }
-      output <- computeFunc(data)
-      if (serializer == "row") {
-        # Transform the result data.frame back to a list of rows
-        output <- split(output, seq(nrow(output)))
-      } else {
-        # Serialize the ouput to a byte array
-        stopifnot(serializer == "byte")
+    if (mode > 0) {
+      if (mode == 1) {
+        output <- compute(mode, partition, serializer, deserializer, NULL,
+                    colNames, computeFunc, data)
+       } else {
+        # gapply mode
+        for (i in 1:length(data)) {
+          # Timing reading input data for execution
+          inputElap <- elapsedSecs()
+          output <- compute(mode, partition, serializer, deserializer, keys[[i]],
+                      colNames, computeFunc, data[[i]])
+          computeElap <- elapsedSecs()
+          outputResult(serializer, output, outputCon)
+          outputElap <- elapsedSecs()
+          computeInputElapsDiff <-  computeInputElapsDiff + (computeElap - inputElap)
+          outputComputeElapsDiff <- outputComputeElapsDiff + (outputElap - computeElap)
+        }
       }
     } else {
-      output <- computeFunc(partition, data)
+      output <- compute(mode, partition, serializer, deserializer, NULL,
+                  colNames, computeFunc, data)
     }
-
-    # Timing computing
-    computeElap <- elapsedSecs()
-
-    if (serializer == "byte") {
-      SparkR:::writeRawSerialize(outputCon, output)
-    } else if (serializer == "row") {
-      SparkR:::writeRowSerialize(outputCon, output)
-    } else {
-      # write lines one-by-one with flag
-      lapply(output, function(line) SparkR:::writeString(outputCon, line))
+    if (mode != 2) {
+      # Not a gapply mode
+      computeElap <- elapsedSecs()
+      outputResult(serializer, output, outputCon)
+      outputElap <- elapsedSecs()
+      computeInputElapsDiff <- computeElap - inputElap
+      outputComputeElapsDiff <- outputElap - computeElap
     }
-    # Timing output
-    outputElap <- elapsedSecs()
   } else {
     if (deserializer == "byte") {
       # Now read as many characters as described in funcLen
@@ -189,11 +233,9 @@ if (isEmpty != 0) {
     }
     # Timing output
     outputElap <- elapsedSecs()
+    computeInputElapsDiff <- computeElap - inputElap
+    outputComputeElapsDiff <- outputElap - computeElap
   }
-} else {
-  inputElap <- broadcastElap
-  computeElap <- broadcastElap
-  outputElap <- broadcastElap
 }
 
 # Report timing
@@ -202,8 +244,8 @@ SparkR:::writeDouble(outputCon, bootTime)
 SparkR:::writeDouble(outputCon, initElap - bootElap)        # init
 SparkR:::writeDouble(outputCon, broadcastElap - initElap)   # broadcast
 SparkR:::writeDouble(outputCon, inputElap - broadcastElap)  # input
-SparkR:::writeDouble(outputCon, computeElap - inputElap)    # compute
-SparkR:::writeDouble(outputCon, outputElap - computeElap)   # output
+SparkR:::writeDouble(outputCon, computeInputElapsDiff)    # compute
+SparkR:::writeDouble(outputCon, outputComputeElapsDiff)   # output
 
 # End of output
 SparkR:::writeInt(outputCon, specialLengths$END_OF_STERAM)

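In gapply mode the worker loop above handles one group at a time: apply the user function to that group's rows and key, write the group's result, and accumulate the compute and output timings across groups. A stripped-down, hedged sketch of that loop shape in standalone R (computeGroup and emit are stand-ins, not the worker's real compute()/outputResult()):

    # Hypothetical stand-ins for the worker's per-group compute and output steps.
    computeGroup <- function(key, rows) data.frame(key = key, n = length(rows))
    emit <- function(result) invisible(result)

    keys <- list(1L, 3L)
    data <- list(list("r1", "r2"), list("r3"))
    computeInputElapsDiff <- 0
    outputComputeElapsDiff <- 0
    for (i in seq_along(data)) {
      inputElap <- proc.time()[3]
      output <- computeGroup(keys[[i]], data[[i]])
      computeElap <- proc.time()[3]
      emit(output)
      outputElap <- proc.time()[3]
      computeInputElapsDiff <- computeInputElapsDiff + (computeElap - inputElap)
      outputComputeElapsDiff <- outputComputeElapsDiff + (outputElap - computeElap)
    }
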
http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
index 24ad689..496fdf8 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
@@ -40,7 +40,8 @@ private[spark] class RRunner[U](
     broadcastVars: Array[Broadcast[Object]],
     numPartitions: Int = -1,
     isDataFrame: Boolean = false,
-    colNames: Array[String] = null)
+    colNames: Array[String] = null,
+    mode: Int = RRunnerModes.RDD)
   extends Logging {
   private var bootTime: Double = _
   private var dataStream: DataInputStream = _
@@ -148,8 +149,7 @@ private[spark] class RRunner[U](
           }
 
           dataOut.writeInt(numPartitions)
-
-          dataOut.writeInt(if (isDataFrame) 1 else 0)
+          dataOut.writeInt(mode)
 
           if (isDataFrame) {
             SerDe.writeObject(dataOut, colNames)
@@ -180,6 +180,13 @@ private[spark] class RRunner[U](
 
           for (elem <- iter) {
             elem match {
+              case (key, innerIter: Iterator[_]) =>
+                for (innerElem <- innerIter) {
+                  writeElem(innerElem)
+                }
+                // Writes key which can be used as a boundary in group-aggregate
+                // Writes the key, which can be used as a group boundary in group-aggregate
+                writeElem(key)
               case (key, value) =>
                 writeElem(key)
                 writeElem(value)
@@ -187,6 +194,7 @@ private[spark] class RRunner[U](
                 writeElem(elem)
             }
           }
+
           stream.flush()
         } catch {
           // TODO: We should propagate this error to the task thread
@@ -268,6 +276,12 @@ private object SpecialLengths {
   val TIMING_DATA = -1
 }
 
+private[spark] object RRunnerModes {
+  val RDD = 0
+  val DATAFRAME_DAPPLY = 1
+  val DATAFRAME_GAPPLY = 2
+}
+
 private[r] class BufferedStreamThread(
     in: InputStream,
     name: String,

http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 78e8822..7beeeb4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -246,6 +246,55 @@ case class MapGroups(
     outputObjAttr: Attribute,
     child: LogicalPlan) extends UnaryNode with ObjectProducer
 
+/** Factory for constructing new `FlatMapGroupsInR` nodes. */
+object FlatMapGroupsInR {
+  def apply(
+      func: Array[Byte],
+      packageNames: Array[Byte],
+      broadcastVars: Array[Broadcast[Object]],
+      schema: StructType,
+      keyDeserializer: Expression,
+      valueDeserializer: Expression,
+      inputSchema: StructType,
+      groupingAttributes: Seq[Attribute],
+      dataAttributes: Seq[Attribute],
+      child: LogicalPlan): LogicalPlan = {
+    val mapped = FlatMapGroupsInR(
+      func,
+      packageNames,
+      broadcastVars,
+      inputSchema,
+      schema,
+      UnresolvedDeserializer(keyDeserializer, groupingAttributes),
+      UnresolvedDeserializer(valueDeserializer, dataAttributes),
+      groupingAttributes,
+      dataAttributes,
+      CatalystSerde.generateObjAttr(RowEncoder(schema)),
+      child)
+    CatalystSerde.serialize(mapped)(RowEncoder(schema))
+  }
+}
+
+case class FlatMapGroupsInR(
+    func: Array[Byte],
+    packageNames: Array[Byte],
+    broadcastVars: Array[Broadcast[Object]],
+    inputSchema: StructType,
+    outputSchema: StructType,
+    keyDeserializer: Expression,
+    valueDeserializer: Expression,
+    groupingAttributes: Seq[Attribute],
+    dataAttributes: Seq[Attribute],
+    outputObjAttr: Attribute,
+    child: LogicalPlan) extends UnaryNode with ObjectProducer {
+
+  override lazy val schema = outputSchema
+
+  override protected def stringArgs: Iterator[Any] = Iterator(inputSchema, outputSchema,
+    keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr,
+    child)
+}
+
 /** Factory for constructing new `CoGroup` nodes. */
 object CoGroup {
   def apply[K : Encoder, L : Encoder, R : Encoder, OUT : Encoder](

http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 49b6eab..1aa5767 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -20,14 +20,18 @@ package org.apache.spark.sql
 import scala.collection.JavaConverters._
 import scala.language.implicitConversions
 
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.api.r.SQLUtils._
 import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, FlatMapGroupsInR, Pivot}
 import org.apache.spark.sql.catalyst.util.usePrettyExpression
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.NumericType
+import org.apache.spark.sql.types.StructType
 
 /**
  * A set of methods for aggregations on a [[DataFrame]], created by [[Dataset.groupBy]].
@@ -381,6 +385,48 @@ class RelationalGroupedDataset protected[sql](
   def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = {
     pivot(pivotColumn, values.asScala)
   }
+
+  /**
+   * Applies the given serialized R function `func` to each group of data. For each unique group,
+   * the function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an iterator containing elements of an arbitrary type which
+   * will be returned as a new [[DataFrame]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * [[org.apache.spark.sql.expressions#Aggregator Aggregator]].
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory.  However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+   * constraints of their cluster.
+   *
+   * @since 2.0.0
+   */
+  private[sql] def flatMapGroupsInR(
+      f: Array[Byte],
+      packageNames: Array[Byte],
+      broadcastVars: Array[Broadcast[Object]],
+      outputSchema: StructType): DataFrame = {
+      val groupingNamedExpressions = groupingExprs.map(alias)
+      val groupingCols = groupingNamedExpressions.map(Column(_))
+      val groupingDataFrame = df.select(groupingCols : _*)
+      val groupingAttributes = groupingNamedExpressions.map(_.toAttribute)
+      Dataset.ofRows(
+        df.sparkSession,
+        FlatMapGroupsInR(
+          f,
+          packageNames,
+          broadcastVars,
+          outputSchema,
+          groupingDataFrame.exprEnc.deserializer,
+          df.exprEnc.deserializer,
+          df.exprEnc.schema,
+          groupingAttributes,
+          df.logicalPlan.output,
+          df.logicalPlan))
+  }
 }
 
 

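As the doc comment above notes, flatMapGroupsInR (and therefore gapply) performs no partial aggregation and shuffles every row of each group to an R worker, so when a built-in aggregate can express the computation it is cheaper to stay inside Catalyst. A hedged SparkR sketch of the equivalent per-group mean without the R round trip, assuming the same df as in the earlier examples:

    # Same per-group average computed with built-in aggregate functions;
    # no group data is serialized to an R worker process.
    df2 <- agg(groupBy(df, "a", "c"), avg = mean(df$b))
    collect(df2)
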
http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 486a440..fe426fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -26,7 +26,7 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.api.r.SerDe
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext}
+import org.apache.spark.sql.{DataFrame, RelationalGroupedDataset, Row, SaveMode, SQLContext}
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.types._
 
@@ -146,16 +146,26 @@ private[sql] object SQLUtils {
       packageNames: Array[Byte],
       broadcastVars: Array[Object],
       schema: StructType): DataFrame = {
-    val bv = broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])
-    val realSchema =
-      if (schema == null) {
-        SERIALIZED_R_DATA_SCHEMA
-      } else {
-        schema
-      }
+    val bv = broadcastVars.map(_.asInstanceOf[Broadcast[Object]])
+    val realSchema = if (schema == null) SERIALIZED_R_DATA_SCHEMA else schema
     df.mapPartitionsInR(func, packageNames, bv, realSchema)
   }
 
+  /**
+   * The helper function for gapply() on R side.
+   */
+  def gapply(
+      gd: RelationalGroupedDataset,
+      func: Array[Byte],
+      packageNames: Array[Byte],
+      broadcastVars: Array[Object],
+      schema: StructType): DataFrame = {
+    val bv = broadcastVars.map(_.asInstanceOf[Broadcast[Object]])
+    val realSchema = if (schema == null) SERIALIZED_R_DATA_SCHEMA else schema
+    gd.flatMapGroupsInR(func, packageNames, bv, realSchema)
+  }
+
+
   def dfToCols(df: DataFrame): Array[Array[Any]] = {
     val localDF: Array[Row] = df.collect()
     val numCols = df.columns.length

http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 60466e2..8e2f2ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -337,6 +337,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.MapPartitionsInR(f, p, b, is, os, objAttr, child) =>
         execution.MapPartitionsExec(
           execution.r.MapPartitionsRWrapper(f, p, b, is, os), objAttr, planLater(child)) :: Nil
+      case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) =>
+        execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping,
+          data, objAttr, planLater(child)) :: Nil
       case logical.MapElements(f, objAttr, child) =>
         execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
       case logical.AppendColumns(f, in, out, child) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 5fced94..c7e2671 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -20,13 +20,17 @@ package org.apache.spark.sql.execution
 import scala.language.existentials
 
 import org.apache.spark.api.java.function.MapFunction
+import org.apache.spark.api.r._
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.api.r.SQLUtils._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.types.{DataType, ObjectType}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
 
 
 /**
@@ -325,6 +329,72 @@ case class MapGroupsExec(
 }
 
 /**
+ * Groups the input rows together and calls the R function with each group and an iterator
+ * containing all elements in the group.
+ * The result of this function is flattened before being output.
+ */
+case class FlatMapGroupsInRExec(
+    func: Array[Byte],
+    packageNames: Array[Byte],
+    broadcastVars: Array[Broadcast[Object]],
+    inputSchema: StructType,
+    outputSchema: StructType,
+    keyDeserializer: Expression,
+    valueDeserializer: Expression,
+    groupingAttributes: Seq[Attribute],
+    dataAttributes: Seq[Attribute],
+    outputObjAttr: Attribute,
+    child: SparkPlan) extends UnaryExecNode with ObjectProducerExec {
+
+  override def output: Seq[Attribute] = outputObjAttr :: Nil
+  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    ClusteredDistribution(groupingAttributes) :: Nil
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+    Seq(groupingAttributes.map(SortOrder(_, Ascending)))
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    val isSerializedRData =
+      if (outputSchema == SERIALIZED_R_DATA_SCHEMA) true else false
+    val serializerForR = if (!isSerializedRData) {
+      SerializationFormats.ROW
+    } else {
+      SerializationFormats.BYTE
+    }
+
+    child.execute().mapPartitionsInternal { iter =>
+      val grouped = GroupedIterator(iter, groupingAttributes, child.output)
+      val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
+      val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
+      val runner = new RRunner[Array[Byte]](
+        func, SerializationFormats.ROW, serializerForR, packageNames, broadcastVars,
+        isDataFrame = true, colNames = inputSchema.fieldNames,
+        mode = RRunnerModes.DATAFRAME_GAPPLY)
+
+      val groupedRBytes = grouped.map { case (key, rowIter) =>
+        val deserializedIter = rowIter.map(getValue)
+        val newIter =
+          deserializedIter.asInstanceOf[Iterator[Row]].map { row => rowToRBytes(row) }
+        val newKey = rowToRBytes(getKey(key).asInstanceOf[Row])
+        (newKey, newIter)
+      }
+
+      val outputIter = runner.compute(groupedRBytes, -1)
+      if (!isSerializedRData) {
+        val result = outputIter.map { bytes => bytesToRow(bytes, outputSchema) }
+        result.map(outputObject)
+      } else {
+        val result = outputIter.map { bytes => Row.fromSeq(Seq(bytes)) }
+        result.map(outputObject)
+      }
+    }
+  }
+}
+
+/**
  * Co-groups the data from left and right children, and calls the function with each group and 2
  * iterators containing all elements in the group from left and right side.
  * The result of this function is flattened before being output.

http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
index 6c76328..70539da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
@@ -17,8 +17,7 @@
 
 package org.apache.spark.sql.execution.r
 
-import org.apache.spark.api.r.RRunner
-import org.apache.spark.api.r.SerializationFormats
+import org.apache.spark.api.r._
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.api.r.SQLUtils._
 import org.apache.spark.sql.Row
@@ -55,7 +54,7 @@ private[sql] case class MapPartitionsRWrapper(
 
     val runner = new RRunner[Array[Byte]](
       func, deserializer, serializer, packageNames, broadcastVars,
-      isDataFrame = true, colNames = colNames)
+      isDataFrame = true, colNames = colNames, mode = RRunnerModes.DATAFRAME_DAPPLY)
     // Partition index is ignored. Dataset has no support for mapPartitionsWithIndex.
     val outputIter = runner.compute(newIter, -1)
 

