Github user felixcheung commented on a diff in the pull request:

    https://github.com/apache/spark/pull/15746#discussion_r86929250
  
    --- Diff: R/pkg/R/mllib.R ---
    @@ -1863,5 +1889,209 @@ print.summary.RandomForestRegressionModel <- function(x, ...) {
     #' @export
     #' @note print.summary.RandomForestClassificationModel since 2.1.0
     print.summary.RandomForestClassificationModel <- function(x, ...) {
    -  print.summary.randomForest(x)
    +  print.summary.treeEnsemble(x)
    +}
    +
    +#' Gradient Boosted Tree Model for Regression and Classification
    +#'
    +#' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a
    +#' SparkDataFrame. Users can call \code{summary} to get a summary of the fitted
    +#' Gradient Boosted Tree model, \code{predict} to make predictions on new data, and
    +#' \code{write.ml}/\code{read.ml} to save/load fitted models.
    +#' For more details, see
    +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-regression}{
    +#' GBT Regression} and
    +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-classifier}{
    +#' GBT Classification}
    +#'
    +#' @param data a SparkDataFrame for training.
    +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
    +#'                operators are supported, including '~', ':', '+', and '-'.
    +#' @param type type of model, one of "regression" or "classification", to fit
    +#' @param maxDepth Maximum depth of the tree (>= 0).
    +#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing
    +#'                how to split on features at each node. More bins give higher granularity. Must be
    +#'                >= 2 and >= number of categories in any categorical feature.
    +#' @param maxIter Param for maximum number of iterations (>= 0).
    +#' @param stepSize Param for Step size to be used for each iteration of optimization.
    +#' @param lossType Loss function which GBT tries to minimize.
    +#'                 For classification, must be "logistic". For regression, must be one of
    +#'                 "squared" (L2) and "absolute" (L1), default is "squared".
    +#' @param seed integer seed for random number generation.
    +#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in
    +#'                        range (0, 1].
    +#' @param minInstancesPerNode Minimum number of instances each child must have after split. If a
    +#'                            split causes the left or right child to have fewer than
    +#'                            minInstancesPerNode, the split will be discarded as invalid. Should be
    +#'                            >= 1.
    +#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
    +#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1).
    +#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation.
    +#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with
    +#'                     nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
    +#'                     can speed up training of deeper trees. Users can set how often the cache
    +#'                     should be checkpointed or disable it by setting checkpointInterval.
    +#' @param ... additional arguments passed to the method.
    +#' @aliases spark.gbt,SparkDataFrame,formula-method
    +#' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model.
    +#' @rdname spark.gbt
    +#' @name spark.gbt
    +#' @export
    +#' @examples
    +#' \dontrun{
    +#' # fit a Gradient Boosted Tree Regression Model
    +#' df <- createDataFrame(longley)
    +#' model <- spark.gbt(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16)
    +#'
    +#' # get the summary of the model
    +#' summary(model)
    +#'
    +#' # make predictions
    +#' predictions <- predict(model, df)
    +#'
    +#' # save and load the model
    +#' path <- "path/to/model"
    +#' write.ml(model, path)
    +#' savedModel <- read.ml(path)
    +#' summary(savedModel)
    +#'
    +#' # fit a Gradient Boosted Tree Classification Model
    +#' # label must be binary - Only binary classification is supported for GBT.
    +#' df <- createDataFrame(iris[iris$Species != "virginica", ])
    +#' model <- spark.gbt(df, Species ~ Petal_Length + Petal_Width, "classification")
    +#'
    +#' # numeric label is also supported
    +#' iris2 <- iris[iris$Species != "virginica", ]
    +#' iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1)
    +#' df <- createDataFrame(iris2)
    +#' model <- spark.gbt(df, NumericSpecies ~ ., type = "classification")
    +#' }
    +#' @note spark.gbt since 2.1.0
    +setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = 
"formula"),
    +          function(data, formula, type = c("regression", "classification"),
    +                   maxDepth = 5, maxBins = 32, maxIter = 20, stepSize = 0.1, lossType = NULL,
    +                   seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0,
    +                   checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE) {
    +            type <- match.arg(type)
    +            formula <- paste(deparse(formula), collapse = "")
    +            if (!is.null(seed)) {
    +              seed <- as.character(as.integer(seed))
    +            }
    +            switch(type,
    +                   regression = {
    +                     if (is.null(lossType)) lossType <- "squared"
    +                     lossType <- match.arg(lossType, c("squared", "absolute"))
    +                     jobj <- callJStatic("org.apache.spark.ml.r.GBTRegressorWrapper",
    +                                         "fit", data@sdf, formula, as.integer(maxDepth),
    +                                         as.integer(maxBins), as.integer(maxIter),
    +                                         as.numeric(stepSize), as.integer(minInstancesPerNode),
    +                                         as.numeric(minInfoGain), as.integer(checkpointInterval),
    +                                         lossType, seed, as.numeric(subsamplingRate),
    +                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
    +                     new("GBTRegressionModel", jobj = jobj)
    +                   },
    +                   classification = {
    +                     if (is.null(lossType)) lossType <- "logistic"
    +                     lossType <- match.arg(lossType, "logistic")
    +                     jobj <- callJStatic("org.apache.spark.ml.r.GBTClassifierWrapper",
    +                                         "fit", data@sdf, formula, as.integer(maxDepth),
    +                                         as.integer(maxBins), as.integer(maxIter),
    +                                         as.numeric(stepSize), as.integer(minInstancesPerNode),
    +                                         as.numeric(minInfoGain), as.integer(checkpointInterval),
    +                                         lossType, seed, as.numeric(subsamplingRate),
    +                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
    +                     new("GBTClassificationModel", jobj = jobj)
    +                   }
    +            )
    +          })
    +
    +# Makes predictions from a Gradient Boosted Tree Regression model or Classification model
    +
    +#' @param newData a SparkDataFrame for testing.
    +#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
    +#' "prediction"
    +#' @rdname spark.gbt
    +#' @aliases predict,GBTRegressionModel-method
    +#' @export
    +#' @note predict(GBTRegressionModel) since 2.1.0
    +setMethod("predict", signature(object = "GBTRegressionModel"),
    +          function(object, newData) {
    +            predict_internal(object, newData)
    +          })
    +
    +#' @rdname spark.gbt
    +#' @aliases predict,GBTClassificationModel-method
    +#' @export
    +#' @note predict(GBTClassificationModel) since 2.1.0
    +setMethod("predict", signature(object = "GBTClassificationModel"),
    +          function(object, newData) {
    +            predict_internal(object, newData)
    +          })
    +
    +# Save the Gradient Boosted Tree Regression or Classification model to the input path.
    +
    +#' @param object A fitted Gradient Boosted Tree regression model or classification model
    +#' @param path The directory where the model is saved
    +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
    +#'                  which means throw exception if the output path exists.
    +#' @aliases write.ml,GBTRegressionModel,character-method
    +#' @rdname spark.gbt
    +#' @export
    +#' @note write.ml(GBTRegressionModel, character) since 2.1.0
    +setMethod("write.ml", signature(object = "GBTRegressionModel", path = 
"character"),
    +          function(object, path, overwrite = FALSE) {
    +            write_internal(object, path, overwrite)
    +          })
    +
    +#' @aliases write.ml,GBTClassificationModel,character-method
    +#' @rdname spark.gbt
    +#' @export
    +#' @note write.ml(GBTClassificationModel, character) since 2.1.0
    +setMethod("write.ml", signature(object = "GBTClassificationModel", path = 
"character"),
    +          function(object, path, overwrite = FALSE) {
    +            write_internal(object, path, overwrite)
    +          })
    +
    +#' @return \code{summary} returns the model's features as lists, depth and number of nodes
    +#'                        or number of classes.
    --- End diff --
    
    opened SPARK-18349
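    
    For reference, a minimal usage sketch of how the type/lossType handling in the diff above plays out. This is only an illustration, assuming sparkR.session() is already running and reusing the longley/iris data from the @examples block:
    
        df <- createDataFrame(longley)
        
        # regression: lossType is NULL by default and resolves to "squared";
        # "absolute" (L1) can be requested explicitly
        m1 <- spark.gbt(df, Employed ~ ., type = "regression")
        m2 <- spark.gbt(df, Employed ~ ., type = "regression", lossType = "absolute")
        
        # classification: only "logistic" is accepted, anything else fails in match.arg
        bin <- createDataFrame(iris[iris$Species != "virginica", ])
        m3 <- spark.gbt(bin, Species ~ Petal_Length + Petal_Width, type = "classification")
        
        # seed is coerced via as.character(as.integer(seed)) before reaching the JVM wrapper
        m4 <- spark.gbt(df, Employed ~ ., type = "regression", seed = 123)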

