GitHub user yanboliang commented on a diff in the pull request:

    https://github.com/apache/spark/pull/12023#discussion_r57904716
  
    --- Diff: 
mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala 
---
    @@ -199,21 +210,71 @@ final class RandomForestRegressionModel private[ml] (
       private[ml] def toOld: OldRandomForestModel = {
         new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld))
       }
    +
    +  @Since("2.0.0")
    +  override def write: MLWriter =
    +    new RandomForestRegressionModel.RandomForestRegressionModelWriter(this)
    +
    +  @Since("2.0.0")
    +  override def read: MLReader[RandomForestRegressionModel] =
    +    new RandomForestRegressionModel.RandomForestRegressionModelReader(this)
     }
     
    -private[ml] object RandomForestRegressionModel {
    -
    -  /** (private[ml]) Convert a model from the old API */
    -  def fromOld(
    -      oldModel: OldRandomForestModel,
    -      parent: RandomForestRegressor,
    -      categoricalFeatures: Map[Int, Int],
    -      numFeatures: Int = -1): RandomForestRegressionModel = {
    -    require(oldModel.algo == OldAlgo.Regression, "Cannot convert 
RandomForestModel" +
    -      s" with algo=${oldModel.algo} (old API) to 
RandomForestRegressionModel (new API).")
    -    val newTrees = oldModel.trees.map { tree =>
    -      // parent for each tree is null since there is no good way to set 
this.
    -      DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
    +@Since("2.0.0")
    +object RandomForestRegressionModel extends 
MLReadable[RandomForestRegressionModel] {
    +
    +    @Since("2.0.0")
    +    override def load(path: String): RandomForestRegressionModel = 
super.load(path)
    +
    +    private[RandomForestRegressionModel]
    +    class RandomForestRegressionModelWriter(instance: 
RandomForestRegressionModel)
    +      extends MLWriter {
    +
    +          override protected def saveImpl(path: String): Unit = {
    +            val extraMetadata: JObject = Map(
    +                "numFeatures" -> instance.numFeatures)
    +            DefaultParamsWriter.saveMetadata(instance, path, sc, 
Some(extraMetadata))
    +            for ( treeIndex <- 1 to instance.getNumTrees) {
    --- End diff --
    
    +1


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes to enable it, or if the feature is enabled but not working,
please contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to