Repository: spark
Updated Branches:
  refs/heads/branch-2.0 087bd2799 -> 10c476fc8


[SPARK-15294][R] Add `pivot` to SparkR

## What changes were proposed in this pull request?

This PR adds the `pivot` function to SparkR for API parity. It is based on
https://github.com/apache/spark/pull/13295 by mhnatiuk, who should be credited
for that work.

## How was this patch tested?

Passes the Jenkins tests (including a new test case).

Author: Dongjoon Hyun <[email protected]>

Closes #13786 from dongjoon-hyun/SPARK-15294.

(cherry picked from commit 217db56ba11fcdf9e3a81946667d1d99ad7344ee)
Signed-off-by: Shivaram Venkataraman <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/10c476fc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/10c476fc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/10c476fc

Branch: refs/heads/branch-2.0
Commit: 10c476fc8f4780e487d8ada626f6924866f5711f
Parents: 087bd27
Author: Dongjoon Hyun <[email protected]>
Authored: Mon Jun 20 21:09:39 2016 -0700
Committer: Shivaram Venkataraman <[email protected]>
Committed: Mon Jun 20 21:09:51 2016 -0700

----------------------------------------------------------------------
 R/pkg/NAMESPACE                           |  1 +
 R/pkg/R/generics.R                        |  4 +++
 R/pkg/R/group.R                           | 43 ++++++++++++++++++++++++++
 R/pkg/inst/tests/testthat/test_sparkSQL.R | 25 +++++++++++++++
 4 files changed, 73 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/10c476fc/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 45663f4..ea42888 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -294,6 +294,7 @@ exportMethods("%in%",
 
 exportClasses("GroupedData")
 exportMethods("agg")
+exportMethods("pivot")
 
 export("as.DataFrame",
        "cacheTable",

http://git-wip-us.apache.org/repos/asf/spark/blob/10c476fc/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 3fb6370..c307de7 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -160,6 +160,10 @@ setGeneric("persist", function(x, newLevel) { standardGeneric("persist") })
 # @export
 setGeneric("pipeRDD", function(x, command, env = list()) { standardGeneric("pipeRDD")})
 
+# @rdname pivot
+# @export
+setGeneric("pivot", function(x, colname, values = list()) { standardGeneric("pivot") })
+
 # @rdname reduce
 # @export
 setGeneric("reduce", function(x, func) { standardGeneric("reduce") })

http://git-wip-us.apache.org/repos/asf/spark/blob/10c476fc/R/pkg/R/group.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 51e1516..0687f14 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -134,6 +134,49 @@ methods <- c("avg", "max", "mean", "min", "sum")
 # These are not exposed on GroupedData: "kurtosis", "skewness", "stddev", "stddev_samp", "stddev_pop",
 # "variance", "var_samp", "var_pop"
 
+#' Pivot a column of the GroupedData and perform the specified aggregation.
+#'
+#' Pivot a column of the GroupedData and perform the specified aggregation.
+#' There are two versions of pivot function: one that requires the caller to specify the list
+#' of distinct values to pivot on, and one that does not. The latter is more concise but less
+#' efficient, because Spark needs to first compute the list of distinct values internally.
+#'
+#' @param x a GroupedData object
+#' @param colname A column name
+#' @param values A value or a list/vector of distinct values for the output columns.
+#' @return GroupedData object
+#' @rdname pivot
+#' @name pivot
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- createDataFrame(data.frame(
+#'     earnings = c(10000, 10000, 11000, 15000, 12000, 20000, 21000, 22000),
+#'     course = c("R", "Python", "R", "Python", "R", "Python", "R", "Python"),
+#'     period = c("1H", "1H", "2H", "2H", "1H", "1H", "2H", "2H"),
+#'     year = c(2015, 2015, 2015, 2015, 2016, 2016, 2016, 2016)
+#' ))
+#' group_sum <- sum(pivot(groupBy(df, "year"), "course"), "earnings")
+#' group_min <- min(pivot(groupBy(df, "year"), "course", "R"), "earnings")
+#' group_max <- max(pivot(groupBy(df, "year"), "course", c("Python", "R")), "earnings")
+#' group_mean <- mean(pivot(groupBy(df, "year"), "course", list("Python", "R")), "earnings")
+#' }
+#' @note pivot since 2.0.0
+setMethod("pivot",
+          signature(x = "GroupedData", colname = "character"),
+          function(x, colname, values = list()){
+            stopifnot(length(colname) == 1)
+            if (length(values) == 0) {
+              result <- callJMethod(x@sgd, "pivot", colname)
+            } else {
+              if (length(values) > length(unique(values))) {
+                stop("Values are not unique")
+              }
+              result <- callJMethod(x@sgd, "pivot", colname, as.list(values))
+            }
+            groupedData(result)
+          })
+
 createMethod <- function(name) {
   setMethod(name,
             signature(x = "GroupedData"),

http://git-wip-us.apache.org/repos/asf/spark/blob/10c476fc/R/pkg/inst/tests/testthat/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index d53c40d..7c192fb 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1398,6 +1398,31 @@ test_that("group by, agg functions", {
   unlink(jsonPath3)
 })
 
+test_that("pivot GroupedData column", {
+  df <- createDataFrame(data.frame(
+    earnings = c(10000, 10000, 11000, 15000, 12000, 20000, 21000, 22000),
+    course = c("R", "Python", "R", "Python", "R", "Python", "R", "Python"),
+    year = c(2013, 2013, 2014, 2014, 2015, 2015, 2016, 2016)
+  ))
+  sum1 <- collect(sum(pivot(groupBy(df, "year"), "course"), "earnings"))
+  sum2 <- collect(sum(pivot(groupBy(df, "year"), "course", c("Python", "R")), "earnings"))
+  sum3 <- collect(sum(pivot(groupBy(df, "year"), "course", list("Python", "R")), "earnings"))
+  sum4 <- collect(sum(pivot(groupBy(df, "year"), "course", "R"), "earnings"))
+
+  correct_answer <- data.frame(
+    year = c(2013, 2014, 2015, 2016),
+    Python = c(10000, 15000, 20000, 22000),
+    R = c(10000, 11000, 12000, 21000)
+  )
+  expect_equal(sum1, correct_answer)
+  expect_equal(sum2, correct_answer)
+  expect_equal(sum3, correct_answer)
+  expect_equal(sum4, correct_answer[, c("year", "R")])
+
+  expect_error(collect(sum(pivot(groupBy(df, "year"), "course", c("R", "R")), "earnings")))
+  expect_error(collect(sum(pivot(groupBy(df, "year"), "course", list("R", "R")), "earnings")))
+})
+
 test_that("arrange() and orderBy() on a DataFrame", {
   df <- read.json(jsonPath)
   sorted <- arrange(df, df$age)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to