Repository: spark Updated Branches: refs/heads/branch-1.4 8164fbc25 -> f41be8fb3
[SPARK-7473] [MLLIB] Use reservoir sampling for RandomForest feature subsampling via the existing SamplingUtils API Author: AiHe <[email protected]> Closes #5988 from AiHe/reservoir and squashes the following commits: e7a41ac [AiHe] remove non-robust testing case 28ffb9a [AiHe] set seed as rng.nextLong 37459e1 [AiHe] set fixed seed 1e98a4c [AiHe] [MLLIB][tree] Add reservoir sample in RandomForest (cherry picked from commit deb411335a09b91eb1f75421d77e1c3686719621) Signed-off-by: Joseph K. Bradley <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f41be8fb Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f41be8fb Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f41be8fb Branch: refs/heads/branch-1.4 Commit: f41be8fb38608c79ff69a85f0715de5ebd3ae2a5 Parents: 8164fbc Author: AiHe <[email protected]> Authored: Fri May 15 20:42:35 2015 -0700 Committer: Joseph K. Bradley <[email protected]> Committed: Fri May 15 20:42:59 2015 -0700 ---------------------------------------------------------------------- .../main/scala/org/apache/spark/mllib/tree/RandomForest.scala | 6 +++--- .../scala/org/apache/spark/mllib/tree/RandomForestSuite.scala | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/f41be8fb/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 055e60c..b347c45 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import 
org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils +import org.apache.spark.util.random.SamplingUtils /** * :: Experimental :: @@ -473,9 +474,8 @@ object RandomForest extends Serializable with Logging { val (treeIndex, node) = nodeQueue.head // Choose subset of features for node (if subsampling). val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { - // TODO: Use more efficient subsampling? (use selection-and-rejection or reservoir) - Some(rng.shuffle(Range(0, metadata.numFeatures).toList) - .take(metadata.numFeaturesPerNode).toArray) + Some(SamplingUtils.reservoirSampleAndCount(Range(0, + metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1) } else { None } http://git-wip-us.apache.org/repos/asf/spark/blob/f41be8fb/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index ee3bc98..4ed6695 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -196,7 +196,6 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext { numClasses = 3, categoricalFeaturesInfo = categoricalFeaturesInfo) val model = RandomForest.trainClassifier(input, strategy, numTrees = 2, featureSubsetStrategy = "sqrt", seed = 12345) - EnsembleTestHelper.validateClassifier(model, arr, 1.0) } test("subsampling rate in RandomForest"){ --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
