Repository: spark
Updated Branches:
  refs/heads/branch-2.0 0d24fe09a -> 54c04aa5d


[SPARK-15202][SPARKR] add dapplyCollect() method for DataFrame in SparkR.

## What changes were proposed in this pull request?

dapplyCollect() applies an R function on each partition of a SparkDataFrame and 
collects the result back to R as a data.frame.
```
dapplyCollect(df, function(ldf) {...})
```

## How was this patch tested?
SparkR unit tests.

Author: Sun Rui <sunrui2...@gmail.com>

Closes #12989 from sun-rui/SPARK-15202.

(cherry picked from commit b3930f74a0929b2cdcbbe5cbe34f0b1d35eb01cc)
Signed-off-by: Shivaram Venkataraman <shiva...@cs.berkeley.edu>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/54c04aa5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/54c04aa5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/54c04aa5

Branch: refs/heads/branch-2.0
Commit: 54c04aa5d0a6012eb58efd0e7cf6d1d287818fa8
Parents: 0d24fe0
Author: Sun Rui <sunrui2...@gmail.com>
Authored: Thu May 12 17:50:55 2016 -0700
Committer: Shivaram Venkataraman <shiva...@cs.berkeley.edu>
Committed: Thu May 12 17:51:02 2016 -0700

----------------------------------------------------------------------
 R/pkg/NAMESPACE                           |  1 +
 R/pkg/R/DataFrame.R                       | 86 +++++++++++++++++++++-----
 R/pkg/R/generics.R                        |  4 ++
 R/pkg/inst/tests/testthat/test_sparkSQL.R | 21 ++++++-
 4 files changed, 95 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/54c04aa5/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 1432ab8..239ad06 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -47,6 +47,7 @@ exportMethods("arrange",
               "covar_pop",
               "crosstab",
               "dapply",
+              "dapplyCollect",
               "describe",
               "dim",
               "distinct",

http://git-wip-us.apache.org/repos/asf/spark/blob/54c04aa5/R/pkg/R/DataFrame.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 43c46b8..0c2a194 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1153,9 +1153,27 @@ setMethod("summarize",
             agg(x, ...)
           })
 
+dapplyInternal <- function(x, func, schema) {
+  packageNamesArr <- serialize(.sparkREnv[[".packages"]],
+                               connection = NULL)
+
+  broadcastArr <- lapply(ls(.broadcastNames),
+                         function(name) { get(name, .broadcastNames) })
+
+  sdf <- callJStatic(
+           "org.apache.spark.sql.api.r.SQLUtils",
+           "dapply",
+           x@sdf,
+           serialize(cleanClosure(func), connection = NULL),
+           packageNamesArr,
+           broadcastArr,
+           if (is.null(schema)) { schema } else { schema$jobj })
+  dataFrame(sdf)
+}
+
 #' dapply
 #'
-#' Apply a function to each partition of a DataFrame.
+#' Apply a function to each partition of a SparkDataFrame.
 #'
 #' @param x A SparkDataFrame
 #' @param func A function to be applied to each partition of the 
SparkDataFrame.
@@ -1197,21 +1215,57 @@ setMethod("summarize",
 setMethod("dapply",
           signature(x = "SparkDataFrame", func = "function", schema = 
"structType"),
           function(x, func, schema) {
-            packageNamesArr <- serialize(.sparkREnv[[".packages"]],
-                                         connection = NULL)
-
-            broadcastArr <- lapply(ls(.broadcastNames),
-                                   function(name) { get(name, .broadcastNames) 
})
-
-            sdf <- callJStatic(
-                     "org.apache.spark.sql.api.r.SQLUtils",
-                     "dapply",
-                     x@sdf,
-                     serialize(cleanClosure(func), connection = NULL),
-                     packageNamesArr,
-                     broadcastArr,
-                     schema$jobj)
-            dataFrame(sdf)
+            dapplyInternal(x, func, schema)
+          })
+
+#' dapplyCollect
+#'
+#' Apply a function to each partition of a SparkDataFrame and collect the 
result back
+#' to R as a data.frame.
+#'
+#' @param x A SparkDataFrame
+#' @param func A function to be applied to each partition of the 
SparkDataFrame.
+#'             func should have only one parameter, to which a data.frame 
corresponding
+#'             to each partition will be passed.
+#'             The output of func should be a data.frame.
+#' @family SparkDataFrame functions
+#' @rdname dapply
+#' @name dapplyCollect
+#' @export
+#' @examples
+#' \dontrun{
+#'   df <- createDataFrame (sqlContext, iris)
+#'   ldf <- dapplyCollect(df, function(x) { x })
+#'
+#'   # filter and add a column
+#'   df <- createDataFrame (
+#'           sqlContext, 
+#'           list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")),
+#'           c("a", "b", "c"))
+#'   ldf <- dapplyCollect(
+#'            df,
+#'            function(x) {
+#'              y <- x[x[1] > 1, ]
+#'              y <- cbind(y, y[1] + 1L)
+#'            })
+#'   # the result
+#'   #       a b c d
+#'   #       2 2 2 3
+#'   #       3 3 3 4
+#' }
+setMethod("dapplyCollect",
+          signature(x = "SparkDataFrame", func = "function"),
+          function(x, func) {
+            df <- dapplyInternal(x, func, NULL)
+
+            content <- callJMethod(df@sdf, "collect")
+            # content is a list of items of struct type. Each item has a 
single field
+            # which is a serialized data.frame corresponding to one partition of 
the
+            # SparkDataFrame.
+            ldfs <- lapply(content, function(x) { unserialize(x[[1]]) })
+            ldf <- do.call(rbind, ldfs)
+            row.names(ldf) <- NULL
+            ldf
           })
 
 ############################## RDD Map Functions 
##################################

http://git-wip-us.apache.org/repos/asf/spark/blob/54c04aa5/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 8563be1..ed76ad6 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -450,6 +450,10 @@ setGeneric("covar_pop", function(col1, col2) 
{standardGeneric("covar_pop") })
 #' @export
 setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") })
 
+#' @rdname dapply
+#' @export
+setGeneric("dapplyCollect", function(x, func) { 
standardGeneric("dapplyCollect") })
+
 #' @rdname summary
 #' @export
 setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })

http://git-wip-us.apache.org/repos/asf/spark/blob/54c04aa5/R/pkg/inst/tests/testthat/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R 
b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 0f67bc2..6a99b43 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -2043,7 +2043,7 @@ test_that("Histogram", {
   expect_equal(histogram(df, "x")$counts, c(4, 0, 0, 0, 0, 0, 0, 0, 0, 1))
 })
 
-test_that("dapply() on a DataFrame", {
+test_that("dapply() and dapplyCollect() on a DataFrame", {
   df <- createDataFrame (
           sqlContext,
           list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")),
@@ -2053,6 +2053,8 @@ test_that("dapply() on a DataFrame", {
   result <- collect(df1)
   expect_identical(ldf, result)
 
+  result <- dapplyCollect(df, function(x) { x })
+  expect_identical(ldf, result)
 
   # Filter and add a column
   schema <- structType(structField("a", "integer"), structField("b", "double"),
@@ -2070,6 +2072,16 @@ test_that("dapply() on a DataFrame", {
   rownames(expected) <- NULL
   expect_identical(expected, result)
 
+  result <- dapplyCollect(
+              df,
+              function(x) {
+                y <- x[x$a > 1, ]
+                y <- cbind(y, y$a + 1L)
+              })
+  expected1 <- expected
+  names(expected1) <- names(result)
+  expect_identical(expected1, result)
+
   # Remove the added column
   df2 <- dapply(
            df1,
@@ -2080,6 +2092,13 @@ test_that("dapply() on a DataFrame", {
   result <- collect(df2)
   expected <- expected[, c("a", "b", "c")]
   expect_identical(expected, result)
+
+  result <- dapplyCollect(
+              df1,
+              function(x) {
+               x[, c("a", "b", "c")]
+              })
+  expect_identical(expected, result)
 })
 
 test_that("repartition by columns on DataFrame", {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to