Repository: spark
Updated Branches:
  refs/heads/branch-2.0 d4bb9a3ff -> ca0802fd5


[SPARK-16005][R] Add `randomSplit` to SparkR

## What changes were proposed in this pull request?

This PR adds `randomSplit` to SparkR for API parity.

## How was this patch tested?

Pass the Jenkins tests (with a new test case).

Author: Dongjoon Hyun <dongj...@apache.org>

Closes #13721 from dongjoon-hyun/SPARK-16005.

(cherry picked from commit 7d65a0db4a231882200513836f2720f59b35f364)
Signed-off-by: Shivaram Venkataraman <shiva...@cs.berkeley.edu>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ca0802fd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ca0802fd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ca0802fd

Branch: refs/heads/branch-2.0
Commit: ca0802fd55f42fdcdd98533ee515d40d9f04a4b3
Parents: d4bb9a3
Author: Dongjoon Hyun <dongj...@apache.org>
Authored: Fri Jun 17 16:07:33 2016 -0700
Committer: Shivaram Venkataraman <shiva...@cs.berkeley.edu>
Committed: Fri Jun 17 16:07:41 2016 -0700

----------------------------------------------------------------------
 R/pkg/NAMESPACE                           |  1 +
 R/pkg/R/DataFrame.R                       | 37 ++++++++++++++++++++++++++
 R/pkg/R/generics.R                        |  4 +++
 R/pkg/inst/tests/testthat/test_sparkSQL.R | 18 +++++++++++++
 4 files changed, 60 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ca0802fd/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 5db43ae..9412ec3 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -81,6 +81,7 @@ exportMethods("arrange",
               "orderBy",
               "persist",
               "printSchema",
+              "randomSplit",
               "rbind",
               "registerTempTable",
               "rename",

http://git-wip-us.apache.org/repos/asf/spark/blob/ca0802fd/R/pkg/R/DataFrame.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 231e4f0..4e04456 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -2934,3 +2934,55 @@ setMethod("write.jdbc",
             write <- callJMethod(write, "mode", jmode)
             invisible(callJMethod(write, "jdbc", url, tableName, jprops))
           })
+
+#' randomSplit
+#'
+#' Return a list of randomly split dataframes with the provided weights.
+#'
+#' @param x A SparkDataFrame
+#' @param weights A vector of finite, non-negative weights for splits; they will be
+#'                normalized if they don't sum to 1. Their sum must be positive.
+#' @param seed A seed to use for random split; if missing, no seed is passed
+#' @return A list of SparkDataFrames, one per element of \code{weights}.
+#'
+#' @family SparkDataFrame functions
+#' @rdname randomSplit
+#' @name randomSplit
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlContext <- sparkRSQL.init(sc)
+#' df <- createDataFrame(data.frame(id = 1:1000))
+#' df_list <- randomSplit(df, c(2, 3, 5), 0)
+#' # df_list contains 3 SparkDataFrames with each having about 200, 300 and
+#' # 500 rows respectively
+#' sapply(df_list, count)
+#' }
+#' @note since 2.0.0
+setMethod("randomSplit",
+          signature(x = "SparkDataFrame", weights = "numeric"),
+          function(x, weights, seed) {
+            # Validate in R so callers get a clear error rather than an
+            # NA-in-if failure or NaN weights being sent to the JVM.
+            if (!all(is.finite(weights))) {
+              stop("all weight values should be finite")
+            }
+            if (any(weights < 0)) {
+              stop("all weight values should not be negative")
+            }
+            total <- sum(weights)
+            if (total <= 0) {
+              stop("sum of weight values should be positive")
+            }
+            normalized_list <- as.list(weights / total)
+            if (!missing(seed)) {
+              sdfs <- callJMethod(x@sdf, "randomSplit", normalized_list,
+                                  as.integer(seed))
+            } else {
+              sdfs <- callJMethod(x@sdf, "randomSplit", normalized_list)
+            }
+            # lapply (not sapply) so the result is always a plain list of
+            # SparkDataFrames, as the @return doc states.
+            lapply(sdfs, dataFrame)
+          })

http://git-wip-us.apache.org/repos/asf/spark/blob/ca0802fd/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 594bf2e..6e754af 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -679,6 +679,10 @@ setGeneric("withColumnRenamed",
 #' @export
 setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") })
 
+#' @rdname randomSplit
+#' @export
+setGeneric("randomSplit", function(x, weights, seed) { standardGeneric("randomSplit") })
+
 ###################### Column Methods ##########################
 
 #' @rdname column

http://git-wip-us.apache.org/repos/asf/spark/blob/ca0802fd/R/pkg/inst/tests/testthat/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 7aa03a9..607bd9c 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -2280,6 +2280,27 @@ test_that("createDataFrame sqlContext parameter backward compatibility", {
   expect_equal(collect(before), collect(after))
 })
 
+test_that("randomSplit", {
+  num <- 4000
+  df <- createDataFrame(data.frame(id = 1:num))
+  weights <- c(2, 3, 5)
+  expected_ratios <- weights / sum(weights)
+
+  # The pieces' sizes must add up to the input row count, and each piece's
+  # share of rows must be within 0.05 of its normalized weight.
+  check_split <- function(df_list) {
+    expect_equal(length(df_list), length(weights))
+    counts <- sapply(df_list, count)
+    expect_equal(sum(counts), num)
+    expect_true(all(abs(counts / num - expected_ratios) < 0.05))
+  }
+
+  check_split(randomSplit(df, weights))
+  check_split(randomSplit(df, weights, 0))
+
+  expect_error(randomSplit(df, c(-1, 2, 3)), "should not be negative")
+})
+
 unlink(parquetPath)
 unlink(jsonPath)
 unlink(jsonPathNa)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to