zhengruifeng commented on a change in pull request #27103: [SPARK-30381][ML] 
Refactor GBT to reuse treePoints for all trees
URL: https://github.com/apache/spark/pull/27103#discussion_r363574943
 
 

 ##########
 File path: 
mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
 ##########
 @@ -313,29 +335,70 @@ private[spark] object GradientBoostedTrees extends 
Logging {
 
     // Initialize tree
     timer.start("building tree 0")
-    val metadata = DecisionTreeMetadata.buildMetadata(
-      input.retag(classOf[Instance]), treeStrategy, numTrees = 1,
-      featureSubsetStrategy)
-    val firstTreeModel = RandomForest.runWithMetadata(input, metadata, 
treeStrategy,
-      numTrees = 1, featureSubsetStrategy, seed = seed, instr = instr,
+    val retaggedInput = input.retag(classOf[Instance])
+    timer.start("buildMetadata")
+    val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, 
treeStrategy,
+      numTrees = 1, featureSubsetStrategy)
+    timer.stop("buildMetadata")
+
+    timer.start("findSplits")
+    val splits = RandomForest.findSplits(retaggedInput, metadata, seed)
+    val bcSplits = sc.broadcast(splits)
+    timer.stop("findSplits")
+
+    // Bin feature values (TreePoint representation).
+    // Cache input RDD for speedup during multiple passes.
+    val treePoints = TreePoint.convertToTreeRDD(
+      retaggedInput, splits, metadata)
+      .setName("binned tree points")
+      .persist(StorageLevel.MEMORY_AND_DISK)
+    val validationTreePoints = if (validate) {
+      TreePoint.convertToTreeRDD(
+        validationInput.retag(classOf[Instance]), splits, metadata)
+        .persist(StorageLevel.MEMORY_AND_DISK)
+    } else sc.emptyRDD[TreePoint]
+
+    val firstCounts = BaggedPoint
+      .convertToBaggedRDD(treePoints, treeStrategy.subsamplingRate, 
numSubsamples = 1,
+        withReplacement = false, (tp: TreePoint) => tp.weight, seed = seed)
+      .map { bagged =>
+        require(bagged.subsampleCounts.length == 1)
+        require(bagged.sampleWeight == bagged.datum.weight)
+        bagged.subsampleCounts.head
+      }.setName("firstCounts for iter=0")
+      .persist(StorageLevel.MEMORY_AND_DISK)
+
+    val firstBagged = treePoints.zip(firstCounts)
+      .map { case (treePoint, count) =>
+        // according to current design, treePoint.weight == 
baggedPoint.sampleWeight
+        new BaggedPoint[TreePoint](treePoint, Array(count), treePoint.weight)
+    }
+
+    val firstTreeModel = RandomForest.runBagged(baggedInput = firstBagged,
+      metadata = metadata, splits = splits, strategy = treeStrategy, numTrees 
= 1,
+      featureSubsetStrategy = featureSubsetStrategy, seed = seed, instr = None,
       parentUID = None)
       .head.asInstanceOf[DecisionTreeRegressionModel]
+
+    firstCounts.unpersist()
+
     val firstTreeWeight = 1.0
     baseLearners(0) = firstTreeModel
     baseLearnerWeights(0) = firstTreeWeight
 
-    var predError = computeInitialPredictionAndError(input, firstTreeWeight, 
firstTreeModel, loss)
+    var predError = computeInitialPredictionAndError(
+      treePoints, firstTreeWeight, firstTreeModel, loss, bcSplits)
     predErrorCheckpointer.update(predError)
-    logDebug("error of gbt = " + computeWeightedError(input, predError))
+    logDebug("error of gbt = " + computeWeightedError(treePoints, predError))
 
 Review comment:
   I checked the convergence with `gbtm.evaluateEachIteration(df, "xxx")` and 
the curves are similar (they may differ slightly due to different `splits`)

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to