Posted to commits@spark.apache.org by sh...@apache.org on 2016/06/16 04:42:21 UTC
spark git commit: [SPARK-12922][SPARKR][WIP] Implement gapply() on DataFrame in SparkR
Repository: spark
Updated Branches:
refs/heads/master b75f454f9 -> 7c6c69263
[SPARK-12922][SPARKR][WIP] Implement gapply() on DataFrame in SparkR
## What changes were proposed in this pull request?
gapply() applies an R function to each group of a DataFrame, where groups are defined by one or more columns, and returns a DataFrame. It is analogous to GroupedDataset.flatMapGroups() in the Dataset API.
Please let me know what you think and whether you have any ideas for improving it.
Thank you!
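
As a quick illustration, here is a minimal sketch adapted from the documentation and tests added in this patch (it assumes a running SparkR session; gapply() groups the SparkDataFrame by the listed columns and invokes the R function once per group):

df <- createDataFrame(
  list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)),
  c("a", "b", "c", "d"))
# Identity function applied per group keyed by column "a"; the output schema
# equals the input schema, so collect(df1) matches collect(df).
df1 <- gapply(df, list("a"), function(key, x) { x }, schema(df))
collect(df1)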
## How was this patch tested?
Unit tests.
1. Primitive test with different column types
2. Add a boolean column
3. Compute average by a group (see the sketch below)
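
The group-average case (3) corresponds roughly to the following sketch, reusing df from the snippet above; the schema lists the grouping columns and the averaged column, mirroring the new test in test_sparkSQL.R:

schema <- structType(structField("a", "integer"), structField("c", "string"),
                     structField("avg", "double"))
df3 <- gapply(df, list("a", "c"), function(key, x) {
  # key holds the grouping values; x is the local data.frame for this group
  data.frame(key, mean(x$b), stringsAsFactors = FALSE)
}, schema)
collect(df3)  # one row per (a, c) group, with the mean of column b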
Author: Narine Kokhlikyan <na...@gmail.com>
Author: NarineK <na...@us.ibm.com>
Closes #12836 from NarineK/gapply2.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7c6c6926
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7c6c6926
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7c6c6926
Branch: refs/heads/master
Commit: 7c6c6926376c93acc42dd56a399d816f4838f28c
Parents: b75f454
Author: Narine Kokhlikyan <na...@gmail.com>
Authored: Wed Jun 15 21:42:05 2016 -0700
Committer: Shivaram Venkataraman <sh...@cs.berkeley.edu>
Committed: Wed Jun 15 21:42:05 2016 -0700
----------------------------------------------------------------------
R/pkg/NAMESPACE | 1 +
R/pkg/R/DataFrame.R | 82 ++++++++++-
R/pkg/R/deserialize.R | 30 ++++
R/pkg/R/generics.R | 4 +
R/pkg/R/group.R | 62 +++++++++
R/pkg/inst/tests/testthat/test_sparkSQL.R | 65 +++++++++
R/pkg/inst/worker/worker.R | 138 ++++++++++++-------
.../scala/org/apache/spark/api/r/RRunner.scala | 20 ++-
.../sql/catalyst/plans/logical/object.scala | 49 +++++++
.../spark/sql/RelationalGroupedDataset.scala | 48 ++++++-
.../org/apache/spark/sql/api/r/SQLUtils.scala | 26 ++--
.../spark/sql/execution/SparkStrategies.scala | 3 +
.../apache/spark/sql/execution/objects.scala | 72 +++++++++-
.../sql/execution/r/MapPartitionsRWrapper.scala | 5 +-
14 files changed, 540 insertions(+), 65 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index a8cf53f..8db4d5c 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -62,6 +62,7 @@ exportMethods("arrange",
"filter",
"first",
"freqItems",
+ "gapply",
"group_by",
"groupBy",
"head",
http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/R/pkg/R/DataFrame.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 0ff350d..9a9b3f7 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1181,7 +1181,7 @@ dapplyInternal <- function(x, func, schema) {
#' func should have only one parameter, to which a data.frame corresponding
#' to each partition will be passed.
#' The output of func should be a data.frame.
-#' @param schema The schema of the resulting DataFrame after the function is applied.
+#' @param schema The schema of the resulting SparkDataFrame after the function is applied.
#' It must match the output of func.
#' @family SparkDataFrame functions
#' @rdname dapply
@@ -1267,6 +1267,86 @@ setMethod("dapplyCollect",
ldf
})
+#' gapply
+#'
+#' Group the SparkDataFrame using the specified columns and apply the R function to each
+#' group.
+#'
+#' @param x A SparkDataFrame
+#' @param cols Grouping columns
+#' @param func A function to be applied to each group partition specified by the grouping
+#' columns of the SparkDataFrame. The function `func` takes two arguments:
+#' a key, holding the values of the grouping columns for the group, and
+#' a local R data.frame containing the rows of that group.
+#' The output of `func` must be a local R data.frame.
+#' @param schema The schema of the resulting SparkDataFrame after the function is applied.
+#' The schema must match the output of `func`. It has to define each output
+#' column with its preferred name and corresponding data type.
+#' @family SparkDataFrame functions
+#' @rdname gapply
+#' @name gapply
+#' @export
+#' @examples
+#'
+#' \dontrun{
+#' Computes the arithmetic mean of the second column by grouping
+#' on the first and third columns. Output the grouping values and the average.
+#'
+#' df <- createDataFrame (
+#' list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)),
+#' c("a", "b", "c", "d"))
+#'
+#' Here our output contains three columns: the key, which is a combination of two
+#' columns with data types integer and string, and the mean, which is a double.
+#' schema <- structType(structField("a", "integer"), structField("c", "string"),
+#' structField("avg", "double"))
+#' df1 <- gapply(
+#' df,
+#' list("a", "c"),
+#' function(key, x) {
+#' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE)
+#' },
+#' schema)
+#' collect(df1)
+#'
+#' Result
+#' ------
+#' a c avg
+#' 3 3 3.0
+#' 1 1 1.5
+#'
+#' Fits linear models on iris dataset by grouping on the 'Species' column and
+#' using 'Sepal_Length' as a target variable, 'Sepal_Width', 'Petal_Length'
+#' and 'Petal_Width' as training features.
+#'
+#' df <- createDataFrame (iris)
+#' schema <- structType(structField("(Intercept)", "double"),
+#' structField("Sepal_Width", "double"),structField("Petal_Length", "double"),
+#' structField("Petal_Width", "double"))
+#' df1 <- gapply(
+#' df,
+#' list(df$"Species"),
+#' function(key, x) {
+#' m <- suppressWarnings(lm(Sepal_Length ~
+#' Sepal_Width + Petal_Length + Petal_Width, x))
+#' data.frame(t(coef(m)))
+#' }, schema)
+#' collect(df1)
+#'
+#' Result
+#' ---------
+#' Model (Intercept) Sepal_Width Petal_Length Petal_Width
+#' 1 0.699883 0.3303370 0.9455356 -0.1697527
+#' 2 1.895540 0.3868576 0.9083370 -0.6792238
+#' 3 2.351890 0.6548350 0.2375602 0.2521257
+#'
+#'}
+setMethod("gapply",
+ signature(x = "SparkDataFrame"),
+ function(x, cols, func, schema) {
+ grouped <- do.call("groupBy", c(x, cols))
+ gapply(grouped, func, schema)
+ })
+
############################## RDD Map Functions ##################################
# All of the following functions mirror the existing RDD map functions, #
# but allow for use with DataFrames by first converting to an RRDD before calling #
http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/R/pkg/R/deserialize.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index ce071b1..0e99b17 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -197,6 +197,36 @@ readMultipleObjects <- function(inputCon) {
data # this is a list of named lists now
}
+readMultipleObjectsWithKeys <- function(inputCon) {
+ # readMultipleObjectsWithKeys reads multiple consecutive objects from
+ # the input connection (written by a DataOutputStream on the JVM side).
+ # There is no preceding field giving the number of objects, so we read
+ # in a loop until the end of the stream. This function is used by gapply:
+ # each group of rows is followed by the grouping key for that group,
+ # which is then followed by the next group.
+ keys <- list()
+ data <- list()
+ subData <- list()
+ while (TRUE) {
+ # If reaching the end of the stream, type returned should be "".
+ type <- readType(inputCon)
+ if (type == "") {
+ break
+ } else if (type == "r") {
+ type <- readType(inputCon)
+ # A grouping boundary detected
+ key <- readTypedObject(inputCon, type)
+ index <- length(data) + 1L
+ data[[index]] <- subData
+ keys[[index]] <- key
+ subData <- list()
+ } else {
+ subData[[length(subData) + 1L]] <- readTypedObject(inputCon, type)
+ }
+ }
+ list(keys = keys, data = data) # this is a list of keys and corresponding data
+}
+
readRowList <- function(obj) {
# readRowList is meant for use inside an lapply. As a result, it is
# necessary to open a standalone connection for the row and consume
http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 50fc204..40a96d8 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -454,6 +454,10 @@ setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") })
#' @export
setGeneric("dapplyCollect", function(x, func) { standardGeneric("dapplyCollect") })
+#' @rdname gapply
+#' @export
+setGeneric("gapply", function(x, ...) { standardGeneric("gapply") })
+
#' @rdname summary
#' @export
setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })
http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/R/pkg/R/group.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 08f4a49..b704776 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -142,3 +142,65 @@ createMethods <- function() {
}
createMethods()
+
+#' gapply
+#'
+#' Applies an R function to each group in the input GroupedData
+#'
+#' @param x a GroupedData
+#' @param func A function to be applied to each group partition of the GroupedData.
+#' The function `func` takes two arguments: a key, holding the values of the
+#' grouping columns for the group, and a local R data.frame containing the
+#' rows of that group.
+#' The output of `func` must be a local R data.frame.
+#' @param schema The schema of the resulting SparkDataFrame after the function is applied.
+#' The schema must match the output of `func`. It has to define each output
+#' column with its preferred name and corresponding data type.
+#' @return a SparkDataFrame
+#' @rdname gapply
+#' @name gapply
+#' @examples
+#' \dontrun{
+#' Computes the arithmetic mean of the second column by grouping
+#' on the first and third columns. Output the grouping values and the average.
+#'
+#' df <- createDataFrame (
+#' list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)),
+#' c("a", "b", "c", "d"))
+#'
+#' Here our output contains three columns: the key, which is a combination of two
+#' columns with data types integer and string, and the mean, which is a double.
+#' schema <- structType(structField("a", "integer"), structField("c", "string"),
+#' structField("avg", "double"))
+#' df1 <- gapply(
+#' df,
+#' list("a", "c"),
+#' function(key, x) {
+#' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE)
+#' },
+#' schema)
+#' collect(df1)
+#'
+#' Result
+#' ------
+#' a c avg
+#' 3 3 3.0
+#' 1 1 1.5
+#' }
+setMethod("gapply",
+ signature(x = "GroupedData"),
+ function(x, func, schema) {
+ if (is.null(schema)) stop("schema cannot be NULL")
+ packageNamesArr <- serialize(.sparkREnv[[".packages"]],
+ connection = NULL)
+ broadcastArr <- lapply(ls(.broadcastNames),
+ function(name) { get(name, .broadcastNames) })
+ sdf <- callJStatic(
+ "org.apache.spark.sql.api.r.SQLUtils",
+ "gapply",
+ x@sgd,
+ serialize(cleanClosure(func), connection = NULL),
+ packageNamesArr,
+ broadcastArr,
+ schema$jobj)
+ dataFrame(sdf)
+ })
http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/R/pkg/inst/tests/testthat/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index d1ca3b7..c11930a 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -2146,6 +2146,71 @@ test_that("repartition by columns on DataFrame", {
expect_equal(nrow(df1), 2)
})
+test_that("gapply() on a DataFrame", {
+ df <- createDataFrame (
+ list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)),
+ c("a", "b", "c", "d"))
+ expected <- collect(df)
+ df1 <- gapply(df, list("a"), function(key, x) { x }, schema(df))
+ actual <- collect(df1)
+ expect_identical(actual, expected)
+
+ # Computes the sum of second column by grouping on the first and third columns
+ # and checks if the sum is larger than 2
+ schema <- structType(structField("a", "integer"), structField("e", "boolean"))
+ df2 <- gapply(
+ df,
+ list(df$"a", df$"c"),
+ function(key, x) {
+ y <- data.frame(key[1], sum(x$b) > 2)
+ },
+ schema)
+ actual <- collect(df2)$e
+ expected <- c(TRUE, TRUE)
+ expect_identical(actual, expected)
+
+ # Computes the arithmetic mean of the second column by grouping
+ # on the first and third columns. Output the grouping values and the average.
+ schema <- structType(structField("a", "integer"), structField("c", "string"),
+ structField("avg", "double"))
+ df3 <- gapply(
+ df,
+ list("a", "c"),
+ function(key, x) {
+ y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE)
+ },
+ schema)
+ actual <- collect(df3)
+ actual <- actual[order(actual$a), ]
+ rownames(actual) <- NULL
+ expected <- collect(select(df, "a", "b", "c"))
+ expected <- data.frame(aggregate(expected$b, by = list(expected$a, expected$c), FUN = mean))
+ colnames(expected) <- c("a", "c", "avg")
+ expected <- expected[order(expected$a), ]
+ rownames(expected) <- NULL
+ expect_identical(actual, expected)
+
+ irisDF <- suppressWarnings(createDataFrame (iris))
+ schema <- structType(structField("Sepal_Length", "double"), structField("Avg", "double"))
+ # Groups by `Sepal_Length` and computes the average for `Sepal_Width`
+ df4 <- gapply(
+ cols = list("Sepal_Length"),
+ irisDF,
+ function(key, x) {
+ y <- data.frame(key, mean(x$Sepal_Width), stringsAsFactors = FALSE)
+ },
+ schema)
+ actual <- collect(df4)
+ actual <- actual[order(actual$Sepal_Length), ]
+ rownames(actual) <- NULL
+ agg_local_df <- data.frame(aggregate(iris$Sepal.Width, by = list(iris$Sepal.Length), FUN = mean),
+ stringsAsFactors = FALSE)
+ colnames(agg_local_df) <- c("Sepal_Length", "Avg")
+ expected <- agg_local_df[order(agg_local_df$Sepal_Length), ]
+ rownames(expected) <- NULL
+ expect_identical(actual, expected)
+})
+
test_that("Window functions on a DataFrame", {
setHiveContext(sc)
df <- createDataFrame(list(list(1L, "1"), list(2L, "2"), list(1L, "1"), list(2L, "2")),
http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/R/pkg/inst/worker/worker.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
index 40cda0c..debf018 100644
--- a/R/pkg/inst/worker/worker.R
+++ b/R/pkg/inst/worker/worker.R
@@ -27,6 +27,54 @@ elapsedSecs <- function() {
proc.time()[3]
}
+compute <- function(mode, partition, serializer, deserializer, key,
+ colNames, computeFunc, inputData) {
+ if (mode > 0) {
+ if (deserializer == "row") {
+ # Transform the list of rows into a data.frame
+ # Note that the optional argument stringsAsFactors for rbind is
+ # available since R 3.2.4. So we set the global option here.
+ oldOpt <- getOption("stringsAsFactors")
+ options(stringsAsFactors = FALSE)
+ inputData <- do.call(rbind.data.frame, inputData)
+ options(stringsAsFactors = oldOpt)
+
+ names(inputData) <- colNames
+ } else {
+ # Check to see if inputData is a valid data.frame
+ stopifnot(deserializer == "byte")
+ stopifnot(class(inputData) == "data.frame")
+ }
+
+ if (mode == 2) {
+ output <- computeFunc(key, inputData)
+ } else {
+ output <- computeFunc(inputData)
+ }
+ if (serializer == "row") {
+ # Transform the result data.frame back to a list of rows
+ output <- split(output, seq(nrow(output)))
+ } else {
+ # Serialize the output to a byte array
+ stopifnot(serializer == "byte")
+ }
+ } else {
+ output <- computeFunc(partition, inputData)
+ }
+ return (output)
+}
+
+outputResult <- function(serializer, output, outputCon) {
+ if (serializer == "byte") {
+ SparkR:::writeRawSerialize(outputCon, output)
+ } else if (serializer == "row") {
+ SparkR:::writeRowSerialize(outputCon, output)
+ } else {
+ # write lines one-by-one with flag
+ lapply(output, function(line) SparkR:::writeString(outputCon, line))
+ }
+}
+
# Constants
specialLengths <- list(END_OF_STERAM = 0L, TIMING_DATA = -1L)
@@ -79,75 +127,71 @@ if (numBroadcastVars > 0) {
# Timing broadcast
broadcastElap <- elapsedSecs()
+# Initial input timing
+inputElap <- broadcastElap
# If -1: read as normal RDD; if >= 0, treat as pairwise RDD and treat the int
# as number of partitions to create.
numPartitions <- SparkR:::readInt(inputCon)
-isDataFrame <- as.logical(SparkR:::readInt(inputCon))
+# 0 - RDD mode, 1 - dapply mode, 2 - gapply mode
+mode <- SparkR:::readInt(inputCon)
-# If isDataFrame, then read column names
-if (isDataFrame) {
+if (mode > 0) {
colNames <- SparkR:::readObject(inputCon)
}
isEmpty <- SparkR:::readInt(inputCon)
+computeInputElapsDiff <- 0
+outputComputeElapsDiff <- 0
if (isEmpty != 0) {
-
if (numPartitions == -1) {
if (deserializer == "byte") {
# Now read as many characters as described in funcLen
data <- SparkR:::readDeserialize(inputCon)
} else if (deserializer == "string") {
data <- as.list(readLines(inputCon))
+ } else if (deserializer == "row" && mode == 2) {
+ dataWithKeys <- SparkR:::readMultipleObjectsWithKeys(inputCon)
+ keys <- dataWithKeys$keys
+ data <- dataWithKeys$data
} else if (deserializer == "row") {
data <- SparkR:::readMultipleObjects(inputCon)
}
+
# Timing reading input data for execution
inputElap <- elapsedSecs()
-
- if (isDataFrame) {
- if (deserializer == "row") {
- # Transform the list of rows into a data.frame
- # Note that the optional argument stringsAsFactors for rbind is
- # available since R 3.2.4. So we set the global option here.
- oldOpt <- getOption("stringsAsFactors")
- options(stringsAsFactors = FALSE)
- data <- do.call(rbind.data.frame, data)
- options(stringsAsFactors = oldOpt)
-
- names(data) <- colNames
- } else {
- # Check to see if data is a valid data.frame
- stopifnot(deserializer == "byte")
- stopifnot(class(data) == "data.frame")
- }
- output <- computeFunc(data)
- if (serializer == "row") {
- # Transform the result data.frame back to a list of rows
- output <- split(output, seq(nrow(output)))
- } else {
- # Serialize the ouput to a byte array
- stopifnot(serializer == "byte")
+ if (mode > 0) {
+ if (mode == 1) {
+ output <- compute(mode, partition, serializer, deserializer, NULL,
+ colNames, computeFunc, data)
+ } else {
+ # gapply mode
+ for (i in 1:length(data)) {
+ # Timing reading input data for execution
+ inputElap <- elapsedSecs()
+ output <- compute(mode, partition, serializer, deserializer, keys[[i]],
+ colNames, computeFunc, data[[i]])
+ computeElap <- elapsedSecs()
+ outputResult(serializer, output, outputCon)
+ outputElap <- elapsedSecs()
+ computeInputElapsDiff <- computeInputElapsDiff + (computeElap - inputElap)
+ outputComputeElapsDiff <- outputComputeElapsDiff + (outputElap - computeElap)
+ }
}
} else {
- output <- computeFunc(partition, data)
+ output <- compute(mode, partition, serializer, deserializer, NULL,
+ colNames, computeFunc, data)
}
-
- # Timing computing
- computeElap <- elapsedSecs()
-
- if (serializer == "byte") {
- SparkR:::writeRawSerialize(outputCon, output)
- } else if (serializer == "row") {
- SparkR:::writeRowSerialize(outputCon, output)
- } else {
- # write lines one-by-one with flag
- lapply(output, function(line) SparkR:::writeString(outputCon, line))
+ if (mode != 2) {
+ # Not a gapply mode
+ computeElap <- elapsedSecs()
+ outputResult(serializer, output, outputCon)
+ outputElap <- elapsedSecs()
+ computeInputElapsDiff <- computeElap - inputElap
+ outputComputeElapsDiff <- outputElap - computeElap
}
- # Timing output
- outputElap <- elapsedSecs()
} else {
if (deserializer == "byte") {
# Now read as many characters as described in funcLen
@@ -189,11 +233,9 @@ if (isEmpty != 0) {
}
# Timing output
outputElap <- elapsedSecs()
+ computeInputElapsDiff <- computeElap - inputElap
+ outputComputeElapsDiff <- outputElap - computeElap
}
-} else {
- inputElap <- broadcastElap
- computeElap <- broadcastElap
- outputElap <- broadcastElap
}
# Report timing
@@ -202,8 +244,8 @@ SparkR:::writeDouble(outputCon, bootTime)
SparkR:::writeDouble(outputCon, initElap - bootElap) # init
SparkR:::writeDouble(outputCon, broadcastElap - initElap) # broadcast
SparkR:::writeDouble(outputCon, inputElap - broadcastElap) # input
-SparkR:::writeDouble(outputCon, computeElap - inputElap) # compute
-SparkR:::writeDouble(outputCon, outputElap - computeElap) # output
+SparkR:::writeDouble(outputCon, computeInputElapsDiff) # compute
+SparkR:::writeDouble(outputCon, outputComputeElapsDiff) # output
# End of output
SparkR:::writeInt(outputCon, specialLengths$END_OF_STERAM)
http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
index 24ad689..496fdf8 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
@@ -40,7 +40,8 @@ private[spark] class RRunner[U](
broadcastVars: Array[Broadcast[Object]],
numPartitions: Int = -1,
isDataFrame: Boolean = false,
- colNames: Array[String] = null)
+ colNames: Array[String] = null,
+ mode: Int = RRunnerModes.RDD)
extends Logging {
private var bootTime: Double = _
private var dataStream: DataInputStream = _
@@ -148,8 +149,7 @@ private[spark] class RRunner[U](
}
dataOut.writeInt(numPartitions)
-
- dataOut.writeInt(if (isDataFrame) 1 else 0)
+ dataOut.writeInt(mode)
if (isDataFrame) {
SerDe.writeObject(dataOut, colNames)
@@ -180,6 +180,13 @@ private[spark] class RRunner[U](
for (elem <- iter) {
elem match {
+ case (key, innerIter: Iterator[_]) =>
+ for (innerElem <- innerIter) {
+ writeElem(innerElem)
+ }
+ // Write the key, which serves as a group boundary in gapply mode
+ dataOut.writeByte('r')
+ writeElem(key)
case (key, value) =>
writeElem(key)
writeElem(value)
@@ -187,6 +194,7 @@ private[spark] class RRunner[U](
writeElem(elem)
}
}
+
stream.flush()
} catch {
// TODO: We should propagate this error to the task thread
@@ -268,6 +276,12 @@ private object SpecialLengths {
val TIMING_DATA = -1
}
+private[spark] object RRunnerModes {
+ val RDD = 0
+ val DATAFRAME_DAPPLY = 1
+ val DATAFRAME_GAPPLY = 2
+}
+
private[r] class BufferedStreamThread(
in: InputStream,
name: String,
http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 78e8822..7beeeb4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -246,6 +246,55 @@ case class MapGroups(
outputObjAttr: Attribute,
child: LogicalPlan) extends UnaryNode with ObjectProducer
+/** Factory for constructing new `FlatMapGroupsInR` nodes. */
+object FlatMapGroupsInR {
+ def apply(
+ func: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Broadcast[Object]],
+ schema: StructType,
+ keyDeserializer: Expression,
+ valueDeserializer: Expression,
+ inputSchema: StructType,
+ groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
+ child: LogicalPlan): LogicalPlan = {
+ val mapped = FlatMapGroupsInR(
+ func,
+ packageNames,
+ broadcastVars,
+ inputSchema,
+ schema,
+ UnresolvedDeserializer(keyDeserializer, groupingAttributes),
+ UnresolvedDeserializer(valueDeserializer, dataAttributes),
+ groupingAttributes,
+ dataAttributes,
+ CatalystSerde.generateObjAttr(RowEncoder(schema)),
+ child)
+ CatalystSerde.serialize(mapped)(RowEncoder(schema))
+ }
+}
+
+case class FlatMapGroupsInR(
+ func: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Broadcast[Object]],
+ inputSchema: StructType,
+ outputSchema: StructType,
+ keyDeserializer: Expression,
+ valueDeserializer: Expression,
+ groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
+ outputObjAttr: Attribute,
+ child: LogicalPlan) extends UnaryNode with ObjectProducer {
+
+ override lazy val schema = outputSchema
+
+ override protected def stringArgs: Iterator[Any] = Iterator(inputSchema, outputSchema,
+ keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr,
+ child)
+}
+
/** Factory for constructing new `CoGroup` nodes. */
object CoGroup {
def apply[K : Encoder, L : Encoder, R : Encoder, OUT : Encoder](
http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 49b6eab..1aa5767 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -20,14 +20,18 @@ package org.apache.spark.sql
import scala.collection.JavaConverters._
import scala.language.implicitConversions
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.api.r.SQLUtils._
import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, FlatMapGroupsInR, Pivot}
import org.apache.spark.sql.catalyst.util.usePrettyExpression
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.NumericType
+import org.apache.spark.sql.types.StructType
/**
* A set of methods for aggregations on a [[DataFrame]], created by [[Dataset.groupBy]].
@@ -381,6 +385,48 @@ class RelationalGroupedDataset protected[sql](
def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = {
pivot(pivotColumn, values.asScala)
}
+
+ /**
+ * Applies the given serialized R function `func` to each group of data. For each unique group,
+ * the function will be passed the group key and an iterator that contains all of the elements in
+ * the group. The function can return an iterator containing elements of an arbitrary type which
+ * will be returned as a new [[DataFrame]].
+ *
+ * This function does not support partial aggregation, and as a result requires shuffling all
+ * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+ * key, it is best to use the reduce function or an
+ * [[org.apache.spark.sql.expressions#Aggregator Aggregator]].
+ *
+ * Internally, the implementation will spill to disk if any given group is too large to fit into
+ * memory. However, users must take care to avoid materializing the whole iterator for a group
+ * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+ * constraints of their cluster.
+ *
+ * @since 2.0.0
+ */
+ private[sql] def flatMapGroupsInR(
+ f: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Broadcast[Object]],
+ outputSchema: StructType): DataFrame = {
+ val groupingNamedExpressions = groupingExprs.map(alias)
+ val groupingCols = groupingNamedExpressions.map(Column(_))
+ val groupingDataFrame = df.select(groupingCols : _*)
+ val groupingAttributes = groupingNamedExpressions.map(_.toAttribute)
+ Dataset.ofRows(
+ df.sparkSession,
+ FlatMapGroupsInR(
+ f,
+ packageNames,
+ broadcastVars,
+ outputSchema,
+ groupingDataFrame.exprEnc.deserializer,
+ df.exprEnc.deserializer,
+ df.exprEnc.schema,
+ groupingAttributes,
+ df.logicalPlan.output,
+ df.logicalPlan))
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 486a440..fe426fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -26,7 +26,7 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.r.SerDe
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext}
+import org.apache.spark.sql.{DataFrame, RelationalGroupedDataset, Row, SaveMode, SQLContext}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._
@@ -146,16 +146,26 @@ private[sql] object SQLUtils {
packageNames: Array[Byte],
broadcastVars: Array[Object],
schema: StructType): DataFrame = {
- val bv = broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])
- val realSchema =
- if (schema == null) {
- SERIALIZED_R_DATA_SCHEMA
- } else {
- schema
- }
+ val bv = broadcastVars.map(_.asInstanceOf[Broadcast[Object]])
+ val realSchema = if (schema == null) SERIALIZED_R_DATA_SCHEMA else schema
df.mapPartitionsInR(func, packageNames, bv, realSchema)
}
+ /**
+ * The helper function for gapply() on the R side.
+ */
+ def gapply(
+ gd: RelationalGroupedDataset,
+ func: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Object],
+ schema: StructType): DataFrame = {
+ val bv = broadcastVars.map(_.asInstanceOf[Broadcast[Object]])
+ val realSchema = if (schema == null) SERIALIZED_R_DATA_SCHEMA else schema
+ gd.flatMapGroupsInR(func, packageNames, bv, realSchema)
+ }
+
+
def dfToCols(df: DataFrame): Array[Array[Any]] = {
val localDF: Array[Row] = df.collect()
val numCols = df.columns.length
http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 60466e2..8e2f2ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -337,6 +337,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.MapPartitionsInR(f, p, b, is, os, objAttr, child) =>
execution.MapPartitionsExec(
execution.r.MapPartitionsRWrapper(f, p, b, is, os), objAttr, planLater(child)) :: Nil
+ case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) =>
+ execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping,
+ data, objAttr, planLater(child)) :: Nil
case logical.MapElements(f, objAttr, child) =>
execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
case logical.AppendColumns(f, in, out, child) =>
http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 5fced94..c7e2671 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -20,13 +20,17 @@ package org.apache.spark.sql.execution
import scala.language.existentials
import org.apache.spark.api.java.function.MapFunction
+import org.apache.spark.api.r._
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.api.r.SQLUtils._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.types.{DataType, ObjectType}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
/**
@@ -325,6 +329,72 @@ case class MapGroupsExec(
}
/**
+ * Groups the input rows together and calls the R function with each group and an iterator
+ * containing all elements in the group.
+ * The result of this function is flattened before being output.
+ */
+case class FlatMapGroupsInRExec(
+ func: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Broadcast[Object]],
+ inputSchema: StructType,
+ outputSchema: StructType,
+ keyDeserializer: Expression,
+ valueDeserializer: Expression,
+ groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
+ outputObjAttr: Attribute,
+ child: SparkPlan) extends UnaryExecNode with ObjectProducerExec {
+
+ override def output: Seq[Attribute] = outputObjAttr :: Nil
+ override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+
+ override def requiredChildDistribution: Seq[Distribution] =
+ ClusteredDistribution(groupingAttributes) :: Nil
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+ Seq(groupingAttributes.map(SortOrder(_, Ascending)))
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ val isSerializedRData =
+ if (outputSchema == SERIALIZED_R_DATA_SCHEMA) true else false
+ val serializerForR = if (!isSerializedRData) {
+ SerializationFormats.ROW
+ } else {
+ SerializationFormats.BYTE
+ }
+
+ child.execute().mapPartitionsInternal { iter =>
+ val grouped = GroupedIterator(iter, groupingAttributes, child.output)
+ val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
+ val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
+ val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
+ val runner = new RRunner[Array[Byte]](
+ func, SerializationFormats.ROW, serializerForR, packageNames, broadcastVars,
+ isDataFrame = true, colNames = inputSchema.fieldNames,
+ mode = RRunnerModes.DATAFRAME_GAPPLY)
+
+ val groupedRBytes = grouped.map { case (key, rowIter) =>
+ val deserializedIter = rowIter.map(getValue)
+ val newIter =
+ deserializedIter.asInstanceOf[Iterator[Row]].map { row => rowToRBytes(row) }
+ val newKey = rowToRBytes(getKey(key).asInstanceOf[Row])
+ (newKey, newIter)
+ }
+
+ val outputIter = runner.compute(groupedRBytes, -1)
+ if (!isSerializedRData) {
+ val result = outputIter.map { bytes => bytesToRow(bytes, outputSchema) }
+ result.map(outputObject)
+ } else {
+ val result = outputIter.map { bytes => Row.fromSeq(Seq(bytes)) }
+ result.map(outputObject)
+ }
+ }
+ }
+}
+
+/**
* Co-groups the data from left and right children, and calls the function with each group and 2
* iterators containing all elements in the group from left and right side.
* The result of this function is flattened before being output.
http://git-wip-us.apache.org/repos/asf/spark/blob/7c6c6926/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
index 6c76328..70539da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.execution.r
-import org.apache.spark.api.r.RRunner
-import org.apache.spark.api.r.SerializationFormats
+import org.apache.spark.api.r._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.api.r.SQLUtils._
import org.apache.spark.sql.Row
@@ -55,7 +54,7 @@ private[sql] case class MapPartitionsRWrapper(
val runner = new RRunner[Array[Byte]](
func, deserializer, serializer, packageNames, broadcastVars,
- isDataFrame = true, colNames = colNames)
+ isDataFrame = true, colNames = colNames, mode = RRunnerModes.DATAFRAME_DAPPLY)
// Partition index is ignored. Dataset has no support for mapPartitionsWithIndex.
val outputIter = runner.compute(newIter, -1)