Repository: spark Updated Branches: refs/heads/branch-2.0 d4bb9a3ff -> ca0802fd5
[SPARK-16005][R] Add `randomSplit` to SparkR ## What changes were proposed in this pull request? This PR adds `randomSplit` to SparkR for API parity. ## How was this patch tested? Pass the Jenkins tests (with new testcase.) Author: Dongjoon Hyun <dongj...@apache.org> Closes #13721 from dongjoon-hyun/SPARK-16005. (cherry picked from commit 7d65a0db4a231882200513836f2720f59b35f364) Signed-off-by: Shivaram Venkataraman <shiva...@cs.berkeley.edu> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ca0802fd Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ca0802fd Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ca0802fd Branch: refs/heads/branch-2.0 Commit: ca0802fd55f42fdcdd98533ee515d40d9f04a4b3 Parents: d4bb9a3 Author: Dongjoon Hyun <dongj...@apache.org> Authored: Fri Jun 17 16:07:33 2016 -0700 Committer: Shivaram Venkataraman <shiva...@cs.berkeley.edu> Committed: Fri Jun 17 16:07:41 2016 -0700 ---------------------------------------------------------------------- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 37 ++++++++++++++++++++++++++ R/pkg/R/generics.R | 4 +++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 18 +++++++++++++ 4 files changed, 60 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/ca0802fd/R/pkg/NAMESPACE ---------------------------------------------------------------------- diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 5db43ae..9412ec3 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -81,6 +81,7 @@ exportMethods("arrange", "orderBy", "persist", "printSchema", + "randomSplit", "rbind", "registerTempTable", "rename", http://git-wip-us.apache.org/repos/asf/spark/blob/ca0802fd/R/pkg/R/DataFrame.R ---------------------------------------------------------------------- diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 231e4f0..4e04456 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2934,3 +2934,40 @@ setMethod("write.jdbc", write <- callJMethod(write, "mode", jmode) invisible(callJMethod(write, "jdbc", url, tableName, jprops)) }) + +#' randomSplit +#' +#' Return a list of randomly split dataframes with the provided weights. +#' +#' @param x A SparkDataFrame +#' @param weights A vector of weights for splits, will be normalized if they don't sum to 1 +#' @param seed A seed to use for random split +#' +#' @family SparkDataFrame functions +#' @rdname randomSplit +#' @name randomSplit +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' df <- createDataFrame(data.frame(id = 1:1000)) +#' df_list <- randomSplit(df, c(2, 3, 5), 0) +#' # df_list contains 3 SparkDataFrames with each having about 200, 300 and 500 rows respectively +#' sapply(df_list, count) +#' } +#' @note since 2.0.0 +setMethod("randomSplit", + signature(x = "SparkDataFrame", weights = "numeric"), + function(x, weights, seed) { + if (!all(sapply(weights, function(c) { c >= 0 }))) { + stop("all weight values should not be negative") + } + normalized_list <- as.list(weights / sum(weights)) + if (!missing(seed)) { + sdfs <- callJMethod(x@sdf, "randomSplit", normalized_list, as.integer(seed)) + } else { + sdfs <- callJMethod(x@sdf, "randomSplit", normalized_list) + } + sapply(sdfs, dataFrame) + }) http://git-wip-us.apache.org/repos/asf/spark/blob/ca0802fd/R/pkg/R/generics.R ---------------------------------------------------------------------- diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 594bf2e..6e754af 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -679,6 +679,10 @@ setGeneric("withColumnRenamed", #' @export setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) +#' @rdname randomSplit +#' @export +setGeneric("randomSplit", function(x, weights, seed) { standardGeneric("randomSplit") }) + ###################### Column Methods ########################## #' @rdname column http://git-wip-us.apache.org/repos/asf/spark/blob/ca0802fd/R/pkg/inst/tests/testthat/test_sparkSQL.R ---------------------------------------------------------------------- diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 7aa03a9..607bd9c 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2280,6 +2280,24 @@ test_that("createDataFrame sqlContext parameter backward compatibility", { expect_equal(collect(before), collect(after)) }) +test_that("randomSplit", { + num <- 4000 + df <- createDataFrame(data.frame(id = 1:num)) + + weights <- c(2, 3, 5) + df_list <- randomSplit(df, weights) + expect_equal(length(weights), length(df_list)) + counts <- sapply(df_list, count) + expect_equal(num, sum(counts)) + expect_true(all(sapply(abs(counts / num - weights / sum(weights)), function(e) { e < 0.05 }))) + + df_list <- randomSplit(df, weights, 0) + expect_equal(length(weights), length(df_list)) + counts <- sapply(df_list, count) + expect_equal(num, sum(counts)) + expect_true(all(sapply(abs(counts / num - weights / sum(weights)), function(e) { e < 0.05 }))) +}) + unlink(parquetPath) unlink(jsonPath) unlink(jsonPathNa) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org