Repository: spark
Updated Branches:
  refs/heads/master 3e770a64a -> e963070c1
[SPARK-9722] [ML] Pass random seed to spark.ml DecisionTree*

Author: Yu ISHIKAWA <yuu.ishik...@gmail.com>

Closes #9402 from yu-iskw/SPARK-9722.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e963070c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e963070c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e963070c

Branch: refs/heads/master
Commit: e963070c13f56fbc2dfaf9f5d4e69d34afd0957c
Parents: 3e770a6
Author: Yu ISHIKAWA <yuu.ishik...@gmail.com>
Authored: Sun Nov 1 23:52:50 2015 -0800
Committer: DB Tsai <d...@netflix.com>
Committed: Sun Nov 1 23:52:50 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/e963070c/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 96d5652..4a3b12d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -74,7 +74,7 @@ private[ml] object RandomForest extends Logging {
     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
     timer.start("findSplitsBins")
-    val splits = findSplits(retaggedInput, metadata)
+    val splits = findSplits(retaggedInput, metadata, seed)
     timer.stop("findSplitsBins")
     logDebug("numBins: feature: number of bins")
     logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
@@ -815,6 +815,7 @@ private[ml] object RandomForest extends Logging {
    *
    * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
    * @param metadata Learning and dataset metadata
+   * @param seed random seed
    * @return A tuple of (splits, bins).
    *         Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
    *         of size (numFeatures, numSplits).
@@ -823,7 +824,8 @@ private[ml] object RandomForest extends Logging {
    */
   protected[tree] def findSplits(
       input: RDD[LabeledPoint],
-      metadata: DecisionTreeMetadata): Array[Array[Split]] = {
+      metadata: DecisionTreeMetadata,
+      seed : Long): Array[Array[Split]] = {
 
     logDebug("isMulticlass = " + metadata.isMulticlass)
 
@@ -840,7 +842,7 @@ private[ml] object RandomForest extends Logging {
         1.0
       }
       logDebug("fraction of data used for calculating quantiles = " + fraction)
-      input.sample(withReplacement = false, fraction, new XORShiftRandom(1).nextInt()).collect()
+      input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
     } else {
       new Array[LabeledPoint](0)
     }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org