srowen closed pull request #16722: [SPARK-19591][ML][MLlib] Add sample weights to decision trees
URL: https://github.com/apache/spark/pull/16722
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
diff --git
a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
index 2327917e2cad7..94158bf5d6e30 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
@@ -31,7 +31,7 @@ object TestingUtils {
* Note that if x or y is extremely close to zero, i.e., smaller than
Double.MinPositiveValue,
* the relative tolerance is meaningless, so the exception will be raised to
warn users.
*/
- private def RelativeErrorComparison(x: Double, y: Double, eps: Double):
Boolean = {
+ private[ml] def RelativeErrorComparison(x: Double, y: Double, eps: Double):
Boolean = {
val absX = math.abs(x)
val absY = math.abs(y)
val diff = math.abs(x - y)
@@ -48,7 +48,7 @@ object TestingUtils {
/**
* Private helper function for comparing two values using absolute tolerance.
*/
- private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double):
Boolean = {
+ private[ml] def AbsoluteErrorComparison(x: Double, y: Double, eps: Double):
Boolean = {
math.abs(x - y) < eps
}
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 9f60f0896ec52..6a9b3564d63fc 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -22,18 +22,21 @@ import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.{DecisionTreeModel, Node, TreeClassifierParams}
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util._
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy =>
OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel =>
OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Dataset
-
+import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.DoubleType
/**
* Decision tree learning algorithm
(http://en.wikipedia.org/wiki/Decision_tree_learning)
@@ -65,6 +68,9 @@ class DecisionTreeClassifier @Since("1.4.0") (
override def setMinInstancesPerNode(value: Int): this.type =
set(minInstancesPerNode, value)
/** @group setParam */
+ @Since("2.2.0")
+ def setMinWeightFractionPerNode(value: Double): this.type =
set(minWeightFractionPerNode, value)
+
@Since("1.4.0")
override def setMinInfoGain(value: Double): this.type = set(minInfoGain,
value)
@@ -96,6 +102,16 @@ class DecisionTreeClassifier @Since("1.4.0") (
@Since("1.6.0")
override def setSeed(value: Long): this.type = set(seed, value)
+ /**
+ * Sets the value of param [[weightCol]].
+ * If this is not set or empty, we treat all instance weights as 1.0.
+ * Default is not set, so all instances have weight one.
+ *
+ * @group setParam
+ */
+ @Since("2.2.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
override protected def train(dataset: Dataset[_]):
DecisionTreeClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
@@ -106,14 +122,23 @@ class DecisionTreeClassifier @Since("1.4.0") (
".train() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length
${$(thresholds).length}")
}
-
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset,
numClasses)
+ require(numClasses > 0, s"DecisionTreeClassifier (in extractLabeledPoints)
found numClasses =" +
+ s" $numClasses, but requires numClasses > 0.")
+ val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else
col($(weightCol))
+ val instances =
+ dataset.select(col($(labelCol)).cast(DoubleType), w,
col($(featuresCol))).rdd.map {
+ case Row(label: Double, weight: Double, features: Vector) =>
+ require(label % 1 == 0 && label >= 0 && label < numClasses,
s"Classifier was given" +
+ s" dataset with invalid label $label. Labels must be integers in
range" +
+ s" [0, $numClasses).")
+ Instance(label, weight, features)
+ }
val strategy = getOldStrategy(categoricalFeatures, numClasses)
- val instr = Instrumentation.create(this, oldDataset)
+ val instr = Instrumentation.create(this, instances)
instr.logParams(params: _*)
- val trees = RandomForest.run(oldDataset, strategy, numTrees = 1,
featureSubsetStrategy = "all",
+ val trees = RandomForest.run(instances, strategy, numTrees = 1,
featureSubsetStrategy = "all",
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
val m = trees.head.asInstanceOf[DecisionTreeClassificationModel]
@@ -124,11 +149,12 @@ class DecisionTreeClassifier @Since("1.4.0") (
/** (private[ml]) Train a decision tree on an RDD */
private[ml] def train(data: RDD[LabeledPoint],
oldStrategy: OldStrategy): DecisionTreeClassificationModel = {
- val instr = Instrumentation.create(this, data)
- instr.logParams(params: _*)
- val trees = RandomForest.run(data, oldStrategy, numTrees = 1,
featureSubsetStrategy = "all",
- seed = 0L, instr = Some(instr), parentUID = Some(uid))
+ val instances = data.map(_.toInstance(1.0))
+ val instr = Instrumentation.create(this, instances)
+ instr.logParams(params: _*)
+ val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
+ featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID
= Some(uid))
val m = trees.head.asInstanceOf[DecisionTreeClassificationModel]
instr.logSuccess(m)
@@ -176,6 +202,7 @@ class DecisionTreeClassificationModel private[ml] (
/**
* Construct a decision tree classification model.
+ *
* @param rootNode Root node of tree, with other nodes attached.
*/
private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
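A minimal usage sketch of the weight-related setters added above; the DataFrame "training" and its "weight" column are assumptions for illustration, not part of this patch:

    import org.apache.spark.ml.classification.DecisionTreeClassifier

    // weightCol: per-instance weights (unset or empty => every instance weighs 1.0)
    // minWeightFractionPerNode: each child of a split must keep at least this
    // fraction of the total sample weight, here 5%
    val dtc = new DecisionTreeClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features")
      .setWeightCol("weight")
      .setMinWeightFractionPerNode(0.05)
    val model = dtc.fit(training)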
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index ce834f1d17e0d..5674e48134928 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -21,19 +21,20 @@ import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams,
TreeEnsembleModel}
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel =>
OldRandomForestModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.functions._
-
+import org.apache.spark.sql.functions.{col, udf}
/**
* <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a>
learning algorithm for
@@ -126,20 +127,20 @@ class RandomForestClassifier @Since("1.4.0") (
s" numClasses=$numClasses, but thresholds has length
${$(thresholds).length}")
}
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset,
numClasses)
+ val instances: RDD[Instance] = extractLabeledPoints(dataset,
numClasses).map(_.toInstance(1.0))
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses,
OldAlgo.Classification, getOldImpurity)
- val instr = Instrumentation.create(this, oldDataset)
+ val instr = Instrumentation.create(this, instances)
instr.logParams(labelCol, featuresCol, predictionCol, probabilityCol,
rawPredictionCol,
impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins,
maxMemoryInMB, minInfoGain,
minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds,
checkpointInterval)
val trees = RandomForest
- .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy,
getSeed, Some(instr))
+ .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy,
getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeClassificationModel])
- val numFeatures = oldDataset.first().features.size
+ val numFeatures = instances.first().features.size
val m = new RandomForestClassificationModel(trees, numFeatures, numClasses)
instr.logSuccess(m)
m
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala
index cce3ca45ccd8f..7e6e4c5a26e4c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala
@@ -26,4 +26,4 @@ import org.apache.spark.ml.linalg.Vector
* @param weight The weight of this instance.
* @param features The vector of features for this data point.
*/
-private[ml] case class Instance(label: Double, weight: Double, features:
Vector)
+private[spark] case class Instance(label: Double, weight: Double, features:
Vector)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala
index c5d0ec1a8d350..a19f6a88968b6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala
@@ -35,4 +35,9 @@ case class LabeledPoint(@Since("2.0.0") label: Double,
@Since("2.0.0") features:
override def toString: String = {
s"($label,$features)"
}
+
+ private[spark] def toInstance(weight: Double): Instance = {
+ Instance(label, weight, features)
+ }
+
}
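A tiny sketch of the new helper (note it is private[spark], so callable only from Spark's own code); the label, weight, and feature values are arbitrary:

    import org.apache.spark.ml.feature.{Instance, LabeledPoint}
    import org.apache.spark.ml.linalg.Vectors

    // an unweighted LabeledPoint becomes an Instance carrying an explicit weight
    val lp = LabeledPoint(1.0, Vectors.dense(0.0, 2.5))
    val inst: Instance = lp.toInstance(2.0)   // Instance(1.0, 2.0, [0.0, 2.5])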
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 01c5cc1c7efa9..8205714dff462 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -23,8 +23,9 @@ import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
@@ -33,8 +34,10 @@ import org.apache.spark.ml.util._
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy =>
OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel =>
OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.DoubleType
/**
@@ -64,6 +67,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0")
override val uid: S
override def setMinInstancesPerNode(value: Int): this.type =
set(minInstancesPerNode, value)
/** @group setParam */
+ @Since("2.2.0")
+ def setMinWeightFractionPerNode(value: Double): this.type =
set(minWeightFractionPerNode, value)
+
@Since("1.4.0")
override def setMinInfoGain(value: Double): this.type = set(minInfoGain,
value)
@@ -99,16 +105,31 @@ class DecisionTreeRegressor @Since("1.4.0")
(@Since("1.4.0") override val uid: S
@Since("2.0.0")
def setVarianceCol(value: String): this.type = set(varianceCol, value)
+ /**
+ * Sets the value of param [[weightCol]].
+ * If this is not set or empty, we treat all instance weights as 1.0.
+ * Default is not set, so all instances have weight one.
+ *
+ * @group setParam
+ */
+ @Since("2.2.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
override protected def train(dataset: Dataset[_]):
DecisionTreeRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+ val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else
col($(weightCol))
+ val instances =
+ dataset.select(col($(labelCol)).cast(DoubleType), w,
col($(featuresCol))).rdd.map {
+ case Row(label: Double, weight: Double, features: Vector) =>
+ Instance(label, weight, features)
+ }
val strategy = getOldStrategy(categoricalFeatures)
- val instr = Instrumentation.create(this, oldDataset)
+ val instr = Instrumentation.create(this, instances)
instr.logParams(params: _*)
- val trees = RandomForest.run(oldDataset, strategy, numTrees = 1,
featureSubsetStrategy = "all",
+ val trees = RandomForest.run(instances, strategy, numTrees = 1,
featureSubsetStrategy = "all",
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
@@ -122,8 +143,9 @@ class DecisionTreeRegressor @Since("1.4.0")
(@Since("1.4.0") override val uid: S
val instr = Instrumentation.create(this, data)
instr.logParams(params: _*)
- val trees = RandomForest.run(data, oldStrategy, numTrees = 1,
featureSubsetStrategy = "all",
- seed = $(seed), instr = Some(instr), parentUID = Some(uid))
+ val instances = data.map(_.toInstance(1.0))
+ val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
+ featureSubsetStrategy = "all", seed = $(seed), instr = Some(instr),
parentUID = Some(uid))
val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
instr.logSuccess(m)
@@ -153,6 +175,7 @@ object DecisionTreeRegressor extends
DefaultParamsReadable[DecisionTreeRegressor
* <a href="http://en.wikipedia.org/wiki/Decision_tree_learning">
* Decision tree (Wikipedia)</a> model for regression.
* It supports both continuous and categorical features.
+ *
* @param rootNode Root of the decision tree
*/
@Since("1.4.0")
@@ -171,6 +194,7 @@ class DecisionTreeRegressionModel private[ml] (
/**
* Construct a decision tree regression model.
+ *
* @param rootNode Root node of tree, with other nodes attached.
*/
private[ml] def this(rootNode: Node, numFeatures: Int) =
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 2f524a8c5784d..0b6bcb9c81235 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -22,7 +22,6 @@ import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
@@ -31,10 +30,8 @@ import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel =>
OldRandomForestModel}
-import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.functions._
-
+import org.apache.spark.sql.functions.{col, udf}
/**
* <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a>
@@ -117,20 +114,20 @@ class RandomForestRegressor @Since("1.4.0")
(@Since("1.4.0") override val uid: S
override protected def train(dataset: Dataset[_]):
RandomForestRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+
+ val instances = extractLabeledPoints(dataset).map(_.toInstance(1.0))
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses = 0,
OldAlgo.Regression, getOldImpurity)
- val instr = Instrumentation.create(this, oldDataset)
+ val instr = Instrumentation.create(this, instances)
instr.logParams(labelCol, featuresCol, predictionCol, impurity, numTrees,
featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
minInstancesPerNode, seed, subsamplingRate, cacheNodeIds,
checkpointInterval)
val trees = RandomForest
- .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy,
getSeed, Some(instr))
+ .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy,
getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeRegressionModel])
-
- val numFeatures = oldDataset.first().features.size
+ val numFeatures = instances.first().features.size
val m = new RandomForestRegressionModel(trees, numFeatures)
instr.logSuccess(m)
m
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
index 4e372702f0c65..2bb7020232f36 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
@@ -33,13 +33,20 @@ import org.apache.spark.util.random.XORShiftRandom
* this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples,
respectively.
*
* @param datum Data instance
- * @param subsampleWeights Weight of this instance in each subsampled dataset.
- *
- * TODO: This does not currently support (Double) weighted instances. Once
MLlib has weighted
- * dataset support, update. (We store subsampleWeights as Double for
this future extension.)
+ * @param subsampleCounts Number of samples of this instance in each
subsampled dataset.
+ * @param sampleWeight The weight of this instance.
*/
-private[spark] class BaggedPoint[Datum](val datum: Datum, val
subsampleWeights: Array[Double])
- extends Serializable
+private[spark] class BaggedPoint[Datum](
+ val datum: Datum,
+ val subsampleCounts: Array[Int],
+ val sampleWeight: Double) extends Serializable {
+
+ /**
+ * Subsample counts weighted by the sample weight.
+ */
+ def weightedCounts: Array[Double] = subsampleCounts.map(_ * sampleWeight)
+
+}
private[spark] object BaggedPoint {
@@ -52,6 +59,7 @@ private[spark] object BaggedPoint {
* @param subsamplingRate Fraction of the training data used for learning
decision tree.
* @param numSubsamples Number of subsamples of this RDD to take.
* @param withReplacement Sampling with/without replacement.
+ * @param extractSampleWeight A function to get the sample weight of each
datum.
* @param seed Random seed.
* @return BaggedPoint dataset representation.
*/
@@ -60,12 +68,14 @@ private[spark] object BaggedPoint {
subsamplingRate: Double,
numSubsamples: Int,
withReplacement: Boolean,
+ extractSampleWeight: (Datum => Double) = (_: Datum) => 1.0,
seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
+ // TODO: implement weighted bootstrapping
if (withReplacement) {
convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate,
numSubsamples, seed)
} else {
if (numSubsamples == 1 && subsamplingRate == 1.0) {
- convertToBaggedRDDWithoutSampling(input)
+ convertToBaggedRDDWithoutSampling(input, extractSampleWeight)
} else {
convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate,
numSubsamples, seed)
}
@@ -82,16 +92,16 @@ private[spark] object BaggedPoint {
val rng = new XORShiftRandom
rng.setSeed(seed + partitionIndex + 1)
instances.map { instance =>
- val subsampleWeights = new Array[Double](numSubsamples)
+ val subsampleCounts = new Array[Int](numSubsamples)
var subsampleIndex = 0
while (subsampleIndex < numSubsamples) {
val x = rng.nextDouble()
- subsampleWeights(subsampleIndex) = {
- if (x < subsamplingRate) 1.0 else 0.0
+ subsampleCounts(subsampleIndex) = {
+ if (x < subsamplingRate) 1 else 0
}
subsampleIndex += 1
}
- new BaggedPoint(instance, subsampleWeights)
+ new BaggedPoint(instance, subsampleCounts, 1.0)
}
}
}
@@ -106,20 +116,20 @@ private[spark] object BaggedPoint {
val poisson = new PoissonDistribution(subsample)
poisson.reseedRandomGenerator(seed + partitionIndex + 1)
instances.map { instance =>
- val subsampleWeights = new Array[Double](numSubsamples)
+ val subsampleCounts = new Array[Int](numSubsamples)
var subsampleIndex = 0
while (subsampleIndex < numSubsamples) {
- subsampleWeights(subsampleIndex) = poisson.sample()
+ subsampleCounts(subsampleIndex) = poisson.sample()
subsampleIndex += 1
}
- new BaggedPoint(instance, subsampleWeights)
+ new BaggedPoint(instance, subsampleCounts, 1.0)
}
}
}
private def convertToBaggedRDDWithoutSampling[Datum] (
- input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
- input.map(datum => new BaggedPoint(datum, Array(1.0)))
+ input: RDD[Datum],
+ extractSampleWeight: (Datum => Double)): RDD[BaggedPoint[Datum]] = {
+ input.map(datum => new BaggedPoint(datum, Array(1),
extractSampleWeight(datum)))
}
-
}
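To illustrate the reworked BaggedPoint (also private[spark]), consider a hypothetical point sampled three times for tree 0, zero times for tree 1, and once for tree 2, with sample weight 2.5:

    import org.apache.spark.ml.tree.impl.BaggedPoint

    // subsampleCounts are integer resampling counts per tree;
    // weightedCounts scales each count by the instance's sample weight
    val point = new BaggedPoint(datum = "x", subsampleCounts = Array(3, 0, 1), sampleWeight = 2.5)
    point.weightedCounts   // Array(7.5, 0.0, 2.5)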
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
index 61091bb803e49..3124f4ee3c103 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
@@ -104,16 +104,21 @@ private[spark] class DTStatsAggregator(
/**
* Update the stats for a given (feature, bin) for ordered features, using
the given label.
*/
- def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight:
Double): Unit = {
+ def update(
+ featureIndex: Int,
+ binIndex: Int,
+ label: Double,
+ numSamples: Int,
+ sampleWeight: Double): Unit = {
val i = featureOffsets(featureIndex) + binIndex * statsSize
- impurityAggregator.update(allStats, i, label, instanceWeight)
+ impurityAggregator.update(allStats, i, label, numSamples, sampleWeight)
}
/**
* Update the parent node stats using the given label.
*/
- def updateParent(label: Double, instanceWeight: Double): Unit = {
- impurityAggregator.update(parentStats, 0, label, instanceWeight)
+ def updateParent(label: Double, numSamples: Int, sampleWeight: Double): Unit
= {
+ impurityAggregator.update(parentStats, 0, label, numSamples, sampleWeight)
}
/**
@@ -127,9 +132,10 @@ private[spark] class DTStatsAggregator(
featureOffset: Int,
binIndex: Int,
label: Double,
- instanceWeight: Double): Unit = {
+ numSamples: Int,
+ sampleWeight: Double): Unit = {
impurityAggregator.update(allStats, featureOffset + binIndex * statsSize,
- label, instanceWeight)
+ label, numSamples, sampleWeight)
}
/**
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index 8a9dcb486b7bf..b67cfff752935 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import scala.util.Try
import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.tree.RandomForestParams
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
@@ -32,16 +32,20 @@ import org.apache.spark.rdd.RDD
/**
* Learning and dataset metadata for DecisionTree.
*
+ * @param weightedNumExamples Weighted count of samples in the tree.
* @param numClasses For classification: labels can take values {0, ...,
numClasses - 1}.
* For regression: fixed at 0 (no meaning).
* @param maxBins Maximum number of bins, for all features.
* @param featureArity Map: categorical feature index to arity.
* I.e., the feature takes values in {0, ..., arity - 1}.
* @param numBins Number of bins for each feature.
+ * @param minWeightFractionPerNode The minimum fraction of the total sample
weight that must be
+ * present in a leaf node in order to be
considered a valid split.
*/
private[spark] class DecisionTreeMetadata(
val numFeatures: Int,
val numExamples: Long,
+ val weightedNumExamples: Double,
val numClasses: Int,
val maxBins: Int,
val featureArity: Map[Int, Int],
@@ -51,6 +55,7 @@ private[spark] class DecisionTreeMetadata(
val quantileStrategy: QuantileStrategy,
val maxDepth: Int,
val minInstancesPerNode: Int,
+ val minWeightFractionPerNode: Double,
val minInfoGain: Double,
val numTrees: Int,
val numFeaturesPerNode: Int) extends Serializable {
@@ -67,6 +72,8 @@ private[spark] class DecisionTreeMetadata(
def isContinuous(featureIndex: Int): Boolean =
!featureArity.contains(featureIndex)
+ def minWeightPerNode: Double = minWeightFractionPerNode * weightedNumExamples
+
/**
* Number of splits for the given feature.
* For unordered features, there is 1 bin per split.
@@ -104,7 +111,7 @@ private[spark] object DecisionTreeMetadata extends Logging {
* as well as the number of splits and bins for each feature.
*/
def buildMetadata(
- input: RDD[LabeledPoint],
+ input: RDD[Instance],
strategy: Strategy,
numTrees: Int,
featureSubsetStrategy: String): DecisionTreeMetadata = {
@@ -115,7 +122,10 @@ private[spark] object DecisionTreeMetadata extends Logging
{
}
require(numFeatures > 0, s"DecisionTree requires number of features > 0, "
+
s"but was given an empty features vector")
- val numExamples = input.count()
+ val (numExamples, weightSum) = input.aggregate((0L, 0.0))(
+ (acc, x) => (acc._1 + 1L, acc._2 + x.weight),
+ (acc1, acc2) => (acc1._1 + acc2._1, acc1._2 + acc2._2))
+
val numClasses = strategy.algo match {
case Classification => strategy.numClasses
case Regression => 0
@@ -206,17 +216,18 @@ private[spark] object DecisionTreeMetadata extends
Logging {
}
}
- new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
- strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
+ new DecisionTreeMetadata(numFeatures, numExamples, weightSum, numClasses,
+ numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet,
numBins,
strategy.impurity, strategy.quantileCalculationStrategy,
strategy.maxDepth,
- strategy.minInstancesPerNode, strategy.minInfoGain, numTrees,
numFeaturesPerNode)
+ strategy.minInstancesPerNode, strategy.minWeightFractionPerNode,
strategy.minInfoGain,
+ numTrees, numFeaturesPerNode)
}
/**
* Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree.
*/
def buildMetadata(
- input: RDD[LabeledPoint],
+ input: RDD[Instance],
strategy: Strategy): DecisionTreeMetadata = {
buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all")
}
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 008dd19c2498d..80b4a97a82bd3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -24,7 +24,8 @@ import scala.util.Random
import org.apache.spark.internal.Logging
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.impl.Utils
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
import org.apache.spark.ml.util.Instrumentation
@@ -82,11 +83,11 @@ private[spark] object RandomForest extends Logging {
/**
* Train a random forest.
*
- * @param input Training data: RDD of `LabeledPoint`
+ * @param input Training data: RDD of
[[org.apache.spark.ml.feature.Instance]]
* @return an unweighted set of trees
*/
def run(
- input: RDD[LabeledPoint],
+ input: RDD[Instance],
strategy: OldStrategy,
numTrees: Int,
featureSubsetStrategy: String,
@@ -100,9 +101,10 @@ private[spark] object RandomForest extends Logging {
timer.start("init")
- val retaggedInput = input.retag(classOf[LabeledPoint])
+ val retaggedInput = input.retag(classOf[Instance])
val metadata =
DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees,
featureSubsetStrategy)
+
instr match {
case Some(instrumentation) =>
instrumentation.logNumFeatures(metadata.numFeatures)
@@ -129,7 +131,8 @@ private[spark] object RandomForest extends Logging {
val withReplacement = numTrees > 1
val baggedInput = BaggedPoint
- .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees,
withReplacement, seed)
+ .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees,
withReplacement,
+ (tp: TreePoint) => tp.weight, seed = seed)
.persist(StorageLevel.MEMORY_AND_DISK)
// depth of the decision tree
@@ -250,19 +253,21 @@ private[spark] object RandomForest extends Logging {
* For unordered features, bins correspond to subsets of categories; either
the left or right bin
* for each subset is updated.
*
- * @param agg Array storing aggregate calculation, with a set of sufficient
statistics for
- * each (feature, bin).
- * @param treePoint Data point being aggregated.
- * @param splits possible splits indexed (numFeatures)(numSplits)
- * @param unorderedFeatures Set of indices of unordered features.
- * @param instanceWeight Weight (importance) of instance in dataset.
+ * @param agg Array storing aggregate calculation, with a set of sufficient
statistics for
+ * each (feature, bin).
+ * @param treePoint Data point being aggregated.
+ * @param splits Possible splits indexed (numFeatures)(numSplits)
+ * @param unorderedFeatures Set of indices of unordered features.
+ * @param numSamples Number of times this instance occurs in the sample.
+ * @param sampleWeight Weight (importance) of instance in dataset.
*/
private def mixedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
splits: Array[Array[Split]],
unorderedFeatures: Set[Int],
- instanceWeight: Double,
+ numSamples: Int,
+ sampleWeight: Double,
featuresForNode: Option[Array[Int]]): Unit = {
val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
// Use subsampled features
@@ -289,14 +294,15 @@ private[spark] object RandomForest extends Logging {
var splitIndex = 0
while (splitIndex < numSplits) {
if (featureSplits(splitIndex).shouldGoLeft(featureValue,
featureSplits)) {
- agg.featureUpdate(leftNodeFeatureOffset, splitIndex,
treePoint.label, instanceWeight)
+ agg.featureUpdate(leftNodeFeatureOffset, splitIndex,
treePoint.label, numSamples,
+ sampleWeight)
}
splitIndex += 1
}
} else {
// Ordered feature
val binIndex = treePoint.binnedFeatures(featureIndex)
- agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
+ agg.update(featureIndexIdx, binIndex, treePoint.label, numSamples,
sampleWeight)
}
featureIndexIdx += 1
}
@@ -310,12 +316,14 @@ private[spark] object RandomForest extends Logging {
* @param agg Array storing aggregate calculation, with a set of sufficient
statistics for
* each (feature, bin).
* @param treePoint Data point being aggregated.
- * @param instanceWeight Weight (importance) of instance in dataset.
+ * @param numSamples Number of times this instance occurs in the sample.
+ * @param sampleWeight Weight (importance) of instance in dataset.
*/
private def orderedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
- instanceWeight: Double,
+ numSamples: Int,
+ sampleWeight: Double,
featuresForNode: Option[Array[Int]]): Unit = {
val label = treePoint.label
@@ -325,7 +333,7 @@ private[spark] object RandomForest extends Logging {
var featureIndexIdx = 0
while (featureIndexIdx < featuresForNode.get.length) {
val binIndex =
treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
- agg.update(featureIndexIdx, binIndex, label, instanceWeight)
+ agg.update(featureIndexIdx, binIndex, label, numSamples, sampleWeight)
featureIndexIdx += 1
}
} else {
@@ -334,7 +342,7 @@ private[spark] object RandomForest extends Logging {
var featureIndex = 0
while (featureIndex < numFeatures) {
val binIndex = treePoint.binnedFeatures(featureIndex)
- agg.update(featureIndex, binIndex, label, instanceWeight)
+ agg.update(featureIndex, binIndex, label, numSamples, sampleWeight)
featureIndex += 1
}
}
@@ -423,14 +431,16 @@ private[spark] object RandomForest extends Logging {
if (nodeInfo != null) {
val aggNodeIndex = nodeInfo.nodeIndexInGroup
val featuresForNode = nodeInfo.featureSubset
- val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
+ val numSamples = baggedPoint.subsampleCounts(treeIndex)
+ val sampleWeight = baggedPoint.sampleWeight
if (metadata.unorderedFeatures.isEmpty) {
- orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum,
instanceWeight, featuresForNode)
+ orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, numSamples,
sampleWeight,
+ featuresForNode)
} else {
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
- metadata.unorderedFeatures, instanceWeight, featuresForNode)
+ metadata.unorderedFeatures, numSamples, sampleWeight,
featuresForNode)
}
- agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
+ agg(aggNodeIndex).updateParent(baggedPoint.datum.label, numSamples,
sampleWeight)
}
}
@@ -590,8 +600,8 @@ private[spark] object RandomForest extends Logging {
if (!isLeaf) {
node.split = Some(split)
val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) ==
metadata.maxDepth
- val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
- val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
+ val leftChildIsLeaf = childIsLeaf || (math.abs(stats.leftImpurity) <
Utils.EPSILON)
+ val rightChildIsLeaf = childIsLeaf || (math.abs(stats.rightImpurity)
< Utils.EPSILON)
node.leftChild =
Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
leftChildIsLeaf,
ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
node.rightChild =
Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
@@ -655,15 +665,20 @@ private[spark] object RandomForest extends Logging {
stats.impurity
}
+ val leftRawCount = leftImpurityCalculator.rawCount
+ val rightRawCount = rightImpurityCalculator.rawCount
val leftCount = leftImpurityCalculator.count
val rightCount = rightImpurityCalculator.count
val totalCount = leftCount + rightCount
- // If left child or right child doesn't satisfy minimum instances per node,
- // then this split is invalid, return invalid information gain stats.
- if ((leftCount < metadata.minInstancesPerNode) ||
- (rightCount < metadata.minInstancesPerNode)) {
+ val violatesMinInstancesPerNode = (leftRawCount <
metadata.minInstancesPerNode) ||
+ (rightRawCount < metadata.minInstancesPerNode)
+ val violatesMinWeightPerNode = (leftCount < metadata.minWeightPerNode) ||
+ (rightCount < metadata.minWeightPerNode)
+ // If left child or right child doesn't satisfy minimum weight per node or
minimum
+ // instances per node, then this split is invalid, return invalid
information gain stats.
+ if (violatesMinInstancesPerNode || violatesMinWeightPerNode) {
return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
}
@@ -730,7 +745,8 @@ private[spark] object RandomForest extends Logging {
// Find best split.
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { case splitIdx =>
- val leftChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
+ val leftChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset,
splitIdx)
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset,
numSplits)
rightChildStats.subtract(leftChildStats)
@@ -872,14 +888,14 @@ private[spark] object RandomForest extends Logging {
* and for multiclass classification with a high-arity feature,
* there is one bin per category.
*
- * @param input Training data: RDD of [[LabeledPoint]]
+ * @param input Training data: RDD of [[Instance]]
* @param metadata Learning and dataset metadata
* @param seed random seed
* @return Splits, an Array of [[Split]]
* of size (numFeatures, numSplits)
*/
protected[tree] def findSplits(
- input: RDD[LabeledPoint],
+ input: RDD[Instance],
metadata: DecisionTreeMetadata,
seed: Long): Array[Array[Split]] = {
@@ -900,14 +916,14 @@ private[spark] object RandomForest extends Logging {
logDebug("fraction of data used for calculating quantiles = " + fraction)
input.sample(withReplacement = false, fraction, new
XORShiftRandom(seed).nextInt())
} else {
- input.sparkContext.emptyRDD[LabeledPoint]
+ input.sparkContext.emptyRDD[Instance]
}
findSplitsBySorting(sampledInput, metadata, continuousFeatures)
}
private def findSplitsBySorting(
- input: RDD[LabeledPoint],
+ input: RDD[Instance],
metadata: DecisionTreeMetadata,
continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {
@@ -918,7 +934,7 @@ private[spark] object RandomForest extends Logging {
val numPartitions = math.min(continuousFeatures.length,
input.partitions.length)
input
- .flatMap(point => continuousFeatures.map(idx => (idx,
point.features(idx))))
+ .flatMap(point => continuousFeatures.map(idx => (idx, (point.weight,
point.features(idx)))))
.groupByKey(numPartitions)
.map { case (idx, samples) =>
val thresholds = findSplitsForContinuousFeature(samples, metadata,
idx)
@@ -982,7 +998,7 @@ private[spark] object RandomForest extends Logging {
* could be different from the specified `numSplits`.
* The `numSplits` attribute in the `DecisionTreeMetadata` class will
be set accordingly.
*
- * @param featureSamples feature values of each sample
+ * @param featureSamples feature values and sample weights of each sample
* @param metadata decision tree metadata
* NOTE: `metadata.numbins` will be changed accordingly
* if there are not enough splits to be found
@@ -990,7 +1006,7 @@ private[spark] object RandomForest extends Logging {
* @return array of split thresholds
*/
private[tree] def findSplitsForContinuousFeature(
- featureSamples: Iterable[Double],
+ featureSamples: Iterable[(Double, Double)],
metadata: DecisionTreeMetadata,
featureIndex: Int): Array[Double] = {
require(metadata.isContinuous(featureIndex),
@@ -1002,9 +1018,9 @@ private[spark] object RandomForest extends Logging {
val numSplits = metadata.numSplits(featureIndex)
// get count for each distinct value
- val (valueCountMap, numSamples) =
featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
- case ((m, cnt), x) =>
- (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
+ val (valueCountMap, numSamples) =
featureSamples.foldLeft((Map.empty[Double, Double], 0.0)) {
+ case ((m, cnt), (w, x)) =>
+ (m + ((x, m.getOrElse(x, 0.0) + w)), cnt + w)
}
// sort distinct values
val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
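A standalone sketch of the weighted distinct-value counting that findSplitsForContinuousFeature now performs; the (weight, value) pairs are made up:

    // weights 1.0, 2.0, 1.0 attached to feature values 0.5, 0.5, 0.7
    val featureSamples: Iterable[(Double, Double)] = Seq((1.0, 0.5), (2.0, 0.5), (1.0, 0.7))
    val (valueCountMap, numSamples) =
      featureSamples.foldLeft((Map.empty[Double, Double], 0.0)) {
        case ((m, cnt), (w, x)) => (m + ((x, m.getOrElse(x, 0.0) + w)), cnt + w)
      }
    // valueCountMap == Map(0.5 -> 3.0, 0.7 -> 1.0), numSamples == 4.0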
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
index a6ac64a0463cc..16b00299594a9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.tree.impl
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.tree.{ContinuousSplit, Split}
import org.apache.spark.rdd.RDD
@@ -36,10 +36,12 @@ import org.apache.spark.rdd.RDD
* @param label Label from LabeledPoint
* @param binnedFeatures Binned feature values.
* Same length as LabeledPoint.features, but values are
bin indices.
+ * @param weight Sample weight for this TreePoint.
*/
-private[spark] class TreePoint(val label: Double, val binnedFeatures:
Array[Int])
- extends Serializable {
-}
+private[spark] class TreePoint(
+ val label: Double,
+ val binnedFeatures: Array[Int],
+ val weight: Double) extends Serializable
private[spark] object TreePoint {
@@ -52,7 +54,7 @@ private[spark] object TreePoint {
* @return TreePoint dataset representation
*/
def convertToTreeRDD(
- input: RDD[LabeledPoint],
+ input: RDD[Instance],
splits: Array[Array[Split]],
metadata: DecisionTreeMetadata): RDD[TreePoint] = {
// Construct arrays for featureArity for efficiency in the inner loop.
@@ -82,18 +84,18 @@ private[spark] object TreePoint {
* for categorical features.
*/
private def labeledPointToTreePoint(
- labeledPoint: LabeledPoint,
+ instance: Instance,
thresholds: Array[Array[Double]],
featureArity: Array[Int]): TreePoint = {
- val numFeatures = labeledPoint.features.size
+ val numFeatures = instance.features.size
val arr = new Array[Int](numFeatures)
var featureIndex = 0
while (featureIndex < numFeatures) {
arr(featureIndex) =
- findBin(featureIndex, labeledPoint, featureArity(featureIndex),
thresholds(featureIndex))
+ findBin(featureIndex, instance, featureArity(featureIndex),
thresholds(featureIndex))
featureIndex += 1
}
- new TreePoint(labeledPoint.label, arr)
+ new TreePoint(instance.label, arr, instance.weight)
}
/**
@@ -106,10 +108,10 @@ private[spark] object TreePoint {
*/
private def findBin(
featureIndex: Int,
- labeledPoint: LabeledPoint,
+ instance: Instance,
featureArity: Int,
thresholds: Array[Double]): Int = {
- val featureValue = labeledPoint.features(featureIndex)
+ val featureValue = instance.features(featureIndex)
if (featureArity == 0) {
val idx = java.util.Arrays.binarySearch(thresholds, featureValue)
@@ -125,7 +127,7 @@ private[spark] object TreePoint {
s"DecisionTree given invalid data:" +
s" Feature $featureIndex is categorical with values in
{0,...,${featureArity - 1}," +
s" but a data point gives it value $featureValue.\n" +
- " Bad data point: " + labeledPoint.toString)
+ " Bad data point: " + instance.toString)
}
featureValue.toInt
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 0d6e9034e5ce4..820c42fa43b88 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -219,7 +219,7 @@ private[ml] object TreeEnsembleModel {
importances.changeValue(feature, scaledGain, _ + scaledGain)
computeFeatureImportance(n.leftChild, importances)
computeFeatureImportance(n.rightChild, importances)
- case n: LeafNode =>
+ case _: LeafNode =>
// do nothing
}
}
@@ -282,6 +282,7 @@ private[ml] object DecisionTreeModelReadWrite {
*
* @param id Index used for tree reconstruction. Indices follow a
pre-order traversal.
* @param impurityStats Stats array. Impurity type is stored in metadata.
+ * @param rawCount The unweighted number of samples falling in this node.
* @param gain Gain, or arbitrary value if leaf node.
* @param leftChild Left child index, or arbitrary value if leaf node.
* @param rightChild Right child index, or arbitrary value if leaf node.
@@ -292,6 +293,7 @@ private[ml] object DecisionTreeModelReadWrite {
prediction: Double,
impurity: Double,
impurityStats: Array[Double],
+ rawCount: Long,
gain: Double,
leftChild: Int,
rightChild: Int,
@@ -311,11 +313,12 @@ private[ml] object DecisionTreeModelReadWrite {
val (leftNodeData, leftIdx) = build(n.leftChild, id + 1)
val (rightNodeData, rightIdx) = build(n.rightChild, leftIdx + 1)
val thisNodeData = NodeData(id, n.prediction, n.impurity,
n.impurityStats.stats,
- n.gain, leftNodeData.head.id, rightNodeData.head.id,
SplitData(n.split))
+ n.impurityStats.rawCount, n.gain, leftNodeData.head.id,
rightNodeData.head.id,
+ SplitData(n.split))
(thisNodeData +: (leftNodeData ++ rightNodeData), rightIdx)
case _: LeafNode =>
(Seq(NodeData(id, node.prediction, node.impurity,
node.impurityStats.stats,
- -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))),
+ node.impurityStats.rawCount, -1.0, -1, -1, SplitData(-1,
Array.empty[Double], -1))),
id)
}
}
@@ -360,7 +363,8 @@ private[ml] object DecisionTreeModelReadWrite {
// traversal, this guarantees that child nodes will be built before parent
nodes.
val finalNodes = new Array[Node](nodes.length)
nodes.reverseIterator.foreach { case n: NodeData =>
- val impurityStats = ImpurityCalculator.getCalculator(impurityType,
n.impurityStats)
+ val impurityStats =
+ ImpurityCalculator.getCalculator(impurityType, n.impurityStats,
n.rawCount)
val node = if (n.leftChild != -1) {
val leftChild = finalNodes(n.leftChild)
val rightChild = finalNodes(n.rightChild)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 5eb707dfe7bc3..ff5955d409804 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{DataType, DoubleType,
StructType}
* Note: Marked as private and DeveloperApi since this may be made public in
the future.
*/
private[ml] trait DecisionTreeParams extends PredictorParams
- with HasCheckpointInterval with HasSeed {
+ with HasCheckpointInterval with HasSeed with HasWeightCol {
/**
* Maximum depth of the tree (>= 0).
@@ -71,6 +71,21 @@ private[ml] trait DecisionTreeParams extends PredictorParams
" child to have fewer than minInstancesPerNode, the split will be
discarded as invalid." +
" Should be >= 1.", ParamValidators.gtEq(1))
+ /**
+ * Minimum fraction of the weighted sample count that each child must have
after split.
+ * If a split causes the fraction of the total weight in the left or right
child to be less than
+ * minWeightFractionPerNode, the split will be discarded as invalid.
+ * Should be in the interval [0.0, 0.5).
+ * (default = 0.0)
+ * @group param
+ */
+ final val minWeightFractionPerNode: DoubleParam = new DoubleParam(this,
+ "minWeightFractionPerNode", "Minimum fraction of the weighted sample count
that each child " +
+ "must have after split. If a split causes the fraction of the total weight
in the left or " +
+ "right child to be less than minWeightFractionPerNode, the split will be
discarded as " +
+ "invalid. Should be in interval [0.0, 0.5)",
+ ParamValidators.inRange(0.0, 0.5, lowerInclusive = true, upperInclusive =
false))
+
/**
* Minimum information gain for a split to be considered at a tree node.
* Should be >= 0.0.
@@ -104,8 +119,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams
" algorithm will cache node IDs for each instance. Caching can speed up
training of deeper" +
" trees.")
- setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1,
minInfoGain -> 0.0,
- maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
+ setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1,
+ minWeightFractionPerNode -> 0.0, minInfoGain -> 0.0, maxMemoryInMB -> 256,
+ cacheNodeIds -> false, checkpointInterval -> 10)
/**
* @deprecated This method is deprecated and will be removed in 2.2.0.
@@ -137,6 +153,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams
/** @group getParam */
final def getMinInstancesPerNode: Int = $(minInstancesPerNode)
+ /** @group getParam */
+ final def getMinWeightFractionPerNode: Double = $(minWeightFractionPerNode)
+
/**
* @deprecated This method is deprecated and will be removed in 2.2.0.
* @group setParam
@@ -196,6 +215,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams
strategy.maxMemoryInMB = getMaxMemoryInMB
strategy.minInfoGain = getMinInfoGain
strategy.minInstancesPerNode = getMinInstancesPerNode
+ strategy.minWeightFractionPerNode = getMinWeightFractionPerNode
strategy.useNodeIdCache = getCacheNodeIds
strategy.numClasses = numClasses
strategy.categoricalFeaturesInfo = categoricalFeatures
diff --git
a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index d1331a57de27b..9e1401588142e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -23,6 +23,7 @@ import scala.util.Try
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
+import org.apache.spark.ml.feature.{Instance => NewInstance}
import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel,
RandomForestParams => NewRFParams}
import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest}
import org.apache.spark.mllib.regression.LabeledPoint
@@ -91,8 +92,11 @@ private class RandomForest (
* @return RandomForestModel that can be used for prediction.
*/
def run(input: RDD[LabeledPoint]): RandomForestModel = {
- val trees: Array[NewDTModel] = NewRandomForest.run(input.map(_.asML),
strategy, numTrees,
- featureSubsetStrategy, seed.toLong, None)
+ val instances = input.map { case LabeledPoint(label, features) =>
+ NewInstance(label, 1.0, features.asML)
+ }
+ val trees: Array[NewDTModel] =
+ NewRandomForest.run(instances, strategy, numTrees,
featureSubsetStrategy, seed.toLong, None)
new RandomForestModel(strategy.algo, trees.map(_.toOld))
}
diff --git
a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 58e8f5be7b9f0..5806741b413be 100644
---
a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++
b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -80,7 +80,8 @@ class Strategy @Since("1.3.0") (
@Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256,
@Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,
@Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
- @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends
Serializable {
+ @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10,
+ private[spark] var minWeightFractionPerNode: Double = 0.0) extends
Serializable {
/**
*/
@@ -108,7 +109,8 @@ class Strategy @Since("1.3.0") (
maxBins: Int,
categoricalFeaturesInfo: java.util.Map[java.lang.Integer,
java.lang.Integer]) {
this(algo, impurity, maxDepth, numClasses, maxBins, Sort,
- categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int,
Int]].asScala.toMap)
+ categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int,
Int]].asScala.toMap,
+ minWeightFractionPerNode = 0.0)
}
/**
@@ -171,8 +173,9 @@ class Strategy @Since("1.3.0") (
@Since("1.2.0")
def copy: Strategy = {
new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
- quantileCalculationStrategy, categoricalFeaturesInfo,
minInstancesPerNode, minInfoGain,
- maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval)
+ quantileCalculationStrategy, categoricalFeaturesInfo,
minInstancesPerNode,
+ minInfoGain, maxMemoryInMB, subsamplingRate, useNodeIdCache,
+ checkpointInterval, minWeightFractionPerNode)
}
}
diff --git
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index d4448da9eef51..f01a98e74886b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -83,23 +83,29 @@ object Entropy extends Impurity {
* @param numClasses Number of classes for label.
*/
private[spark] class EntropyAggregator(numClasses: Int)
- extends ImpurityAggregator(numClasses) with Serializable {
+ extends ImpurityAggregator(numClasses + 1) with Serializable {
/**
* Update stats for one (node, feature, bin) with the given label.
* @param allStats Flat stats array, with stats for this (node, feature,
bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin).
*/
- def update(allStats: Array[Double], offset: Int, label: Double,
instanceWeight: Double): Unit = {
- if (label >= statsSize) {
+ def update(
+ allStats: Array[Double],
+ offset: Int,
+ label: Double,
+ numSamples: Int,
+ sampleWeight: Double): Unit = {
+ if (label >= numClasses) {
throw new IllegalArgumentException(s"EntropyAggregator given label
$label" +
- s" but requires label < numClasses (= $statsSize).")
+ s" but requires label < numClasses (= ${numClasses}).")
}
if (label < 0) {
throw new IllegalArgumentException(s"EntropyAggregator given label
$label" +
s"but requires label is non-negative.")
}
- allStats(offset + label.toInt) += instanceWeight
+ allStats(offset + label.toInt) += numSamples * sampleWeight
+ allStats(offset + statsSize - 1) += numSamples
}
/**
@@ -108,7 +114,8 @@ private[spark] class EntropyAggregator(numClasses: Int)
* @param offset Start index of stats for this (node, feature, bin).
*/
def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator =
{
- new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
+ new EntropyCalculator(allStats.view(offset, offset + statsSize -
1).toArray,
+ allStats(offset + statsSize - 1).toLong)
}
}
@@ -118,12 +125,13 @@ private[spark] class EntropyAggregator(numClasses: Int)
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
-private[spark] class EntropyCalculator(stats: Array[Double]) extends
ImpurityCalculator(stats) {
+private[spark] class EntropyCalculator(stats: Array[Double], var rawCount:
Long)
+ extends ImpurityCalculator(stats) {
/**
* Make a deep copy of this [[ImpurityCalculator]].
*/
- def copy: EntropyCalculator = new EntropyCalculator(stats.clone())
+ def copy: EntropyCalculator = new EntropyCalculator(stats.clone(), rawCount)
/**
* Calculate the impurity from the stored sufficient statistics.
@@ -131,9 +139,9 @@ private[spark] class EntropyCalculator(stats:
Array[Double]) extends ImpurityCal
def calculate(): Double = Entropy.calculate(stats, stats.sum)
/**
- * Number of data points accounted for in the sufficient statistics.
+ * Weighted number of data points accounted for in the sufficient statistics.
*/
- def count: Long = stats.sum.toLong
+ def count: Double = stats.sum
/**
* Prediction which should be made based on the sufficient statistics.
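For clarity, a sketch of the per-bin sufficient-statistics layout the classification aggregators now use (assuming numClasses = 3): numClasses + 1 slots hold the weighted class counts followed by the unweighted raw count, mirroring the update logic above.

    val numClasses = 3
    val binStats = new Array[Double](numClasses + 1)
    // one instance of class 1, seen twice in the subsample, with sample weight 0.5
    val (label, numSamples, sampleWeight) = (1.0, 2, 0.5)
    binStats(label.toInt) += numSamples * sampleWeight   // weighted count for class 1: 1.0
    binStats(numClasses) += numSamples                   // raw count slot: 2.0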
diff --git
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index c5e34ffa4f2e5..913ffbbb2457a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -80,23 +80,29 @@ object Gini extends Impurity {
* @param numClasses Number of classes for label.
*/
private[spark] class GiniAggregator(numClasses: Int)
- extends ImpurityAggregator(numClasses) with Serializable {
+ extends ImpurityAggregator(numClasses + 1) with Serializable {
/**
* Update stats for one (node, feature, bin) with the given label.
* @param allStats Flat stats array, with stats for this (node, feature,
bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin).
*/
- def update(allStats: Array[Double], offset: Int, label: Double,
instanceWeight: Double): Unit = {
- if (label >= statsSize) {
+ def update(
+ allStats: Array[Double],
+ offset: Int,
+ label: Double,
+ numSamples: Int,
+ sampleWeight: Double): Unit = {
+ if (label >= numClasses) {
throw new IllegalArgumentException(s"GiniAggregator given label $label" +
- s" but requires label < numClasses (= $statsSize).")
+ s" but requires label < numClasses (= ${numClasses}).")
}
if (label < 0) {
throw new IllegalArgumentException(s"GiniAggregator given label $label" +
- s"but requires label is non-negative.")
+ s"but requires label to be non-negative.")
}
- allStats(offset + label.toInt) += instanceWeight
+ allStats(offset + label.toInt) += numSamples * sampleWeight
+ allStats(offset + statsSize - 1) += numSamples
}
/**
@@ -105,7 +111,8 @@ private[spark] class GiniAggregator(numClasses: Int)
* @param offset Start index of stats for this (node, feature, bin).
*/
def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = {
- new GiniCalculator(allStats.view(offset, offset + statsSize).toArray)
+ new GiniCalculator(allStats.view(offset, offset + statsSize - 1).toArray,
+ allStats(offset + statsSize - 1).toLong)
}
}
@@ -115,12 +122,13 @@ private[spark] class GiniAggregator(numClasses: Int)
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
-private[spark] class GiniCalculator(stats: Array[Double]) extends
ImpurityCalculator(stats) {
+private[spark] class GiniCalculator(stats: Array[Double], var rawCount: Long)
+ extends ImpurityCalculator(stats) {
/**
* Make a deep copy of this [[ImpurityCalculator]].
*/
- def copy: GiniCalculator = new GiniCalculator(stats.clone())
+ def copy: GiniCalculator = new GiniCalculator(stats.clone(), rawCount)
/**
* Calculate the impurity from the stored sufficient statistics.
@@ -128,9 +136,9 @@ private[spark] class GiniCalculator(stats: Array[Double])
extends ImpurityCalcul
def calculate(): Double = Gini.calculate(stats, stats.sum)
/**
- * Number of data points accounted for in the sufficient statistics.
+ * Weighted number of data points accounted for in the sufficient statistics.
*/
- def count: Long = stats.sum.toLong
+ def count: Double = stats.sum
/**
* Prediction which should be made based on the sufficient statistics.
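
A rough sketch of the stats layout implied by the GiniAggregator change above: with numClasses classes, each (node, feature, bin) block now holds numClasses weighted class counts plus one trailing slot for the unweighted sample count. The slot layout and names below are illustrative only, not the Spark implementation.

object GiniStatsLayoutSketch {
  val numClasses = 3
  val statsSize = numClasses + 1
  // One (node, feature, bin) block at offset 0; a real aggregator packs many such blocks.
  val allStats = new Array[Double](statsSize)

  def update(offset: Int, label: Double, numSamples: Int, sampleWeight: Double): Unit = {
    allStats(offset + label.toInt) += numSamples * sampleWeight // weighted class count
    allStats(offset + statsSize - 1) += numSamples              // raw (unweighted) count
  }

  def main(args: Array[String]): Unit = {
    update(0, label = 1.0, numSamples = 2, sampleWeight = 0.5)
    update(0, label = 2.0, numSamples = 1, sampleWeight = 3.0)
    println(allStats.mkString(", ")) // 0.0, 1.0, 3.0, 3.0
  }
}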
diff --git
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index a5bdc2c6d2c94..6a814e658caa5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -79,7 +79,12 @@ private[spark] abstract class ImpurityAggregator(val
statsSize: Int) extends Ser
* @param allStats Flat stats array, with stats for this (node, feature,
bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin).
*/
- def update(allStats: Array[Double], offset: Int, label: Double,
instanceWeight: Double): Unit
+ def update(
+ allStats: Array[Double],
+ offset: Int,
+ label: Double,
+ numSamples: Int,
+ sampleWeight: Double): Unit
/**
* Get an [[ImpurityCalculator]] for a (node, feature, bin).
@@ -120,6 +125,7 @@ private[spark] abstract class ImpurityCalculator(val stats:
Array[Double]) exten
stats(i) += other.stats(i)
i += 1
}
+ rawCount += other.rawCount
this
}
@@ -137,13 +143,19 @@ private[spark] abstract class ImpurityCalculator(val
stats: Array[Double]) exten
stats(i) -= other.stats(i)
i += 1
}
+ rawCount -= other.rawCount
this
}
/**
- * Number of data points accounted for in the sufficient statistics.
+ * Weighted number of data points accounted for in the sufficient statistics.
*/
- def count: Long
+ def count: Double
+
+ /**
+ * Raw number of data points accounted for in the sufficient statistics.
+ */
+ var rawCount: Long
/**
* Prediction which should be made based on the sufficient statistics.
@@ -183,11 +195,14 @@ private[spark] object ImpurityCalculator {
* Create an [[ImpurityCalculator]] instance of the given impurity type and
with
* the given stats.
*/
- def getCalculator(impurity: String, stats: Array[Double]):
ImpurityCalculator = {
+ def getCalculator(
+ impurity: String,
+ stats: Array[Double],
+ rawCount: Long): ImpurityCalculator = {
impurity match {
- case "gini" => new GiniCalculator(stats)
- case "entropy" => new EntropyCalculator(stats)
- case "variance" => new VarianceCalculator(stats)
+ case "gini" => new GiniCalculator(stats, rawCount)
+ case "entropy" => new EntropyCalculator(stats, rawCount)
+ case "variance" => new VarianceCalculator(stats, rawCount)
case _ =>
throw new IllegalArgumentException(
s"ImpurityCalculator builder did not recognize impurity type:
$impurity")
diff --git
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index c9bf0db4de3c2..a07b919271f71 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -66,21 +66,32 @@ object Variance extends Impurity {
/**
* Class for updating views of a vector of sufficient statistics,
- * in order to compute impurity from a sample.
+ * in order to compute impurity from a sample. For variance, we track:
+ * - sum(w_i)
+ * - sum(w_i * y_i)
+ * - sum(w_i * y_i * y_i)
+ * - count(y_i)
* Note: Instances of this class do not hold the data; they operate on views
of the data.
*/
private[spark] class VarianceAggregator()
- extends ImpurityAggregator(statsSize = 3) with Serializable {
+ extends ImpurityAggregator(statsSize = 4) with Serializable {
/**
* Update stats for one (node, feature, bin) with the given label.
* @param allStats Flat stats array, with stats for this (node, feature,
bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin).
*/
- def update(allStats: Array[Double], offset: Int, label: Double,
instanceWeight: Double): Unit = {
+ def update(
+ allStats: Array[Double],
+ offset: Int,
+ label: Double,
+ numSamples: Int,
+ sampleWeight: Double): Unit = {
+ val instanceWeight = numSamples * sampleWeight
allStats(offset) += instanceWeight
allStats(offset + 1) += instanceWeight * label
allStats(offset + 2) += instanceWeight * label * label
+ allStats(offset + 3) += numSamples
}
/**
@@ -89,7 +100,8 @@ private[spark] class VarianceAggregator()
* @param offset Start index of stats for this (node, feature, bin).
*/
def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator
= {
- new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray)
+ new VarianceCalculator(allStats.view(offset, offset + statsSize - 1).toArray,
+ allStats(offset + statsSize - 1).toLong)
}
}
@@ -99,7 +111,8 @@ private[spark] class VarianceAggregator()
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
-private[spark] class VarianceCalculator(stats: Array[Double]) extends
ImpurityCalculator(stats) {
+private[spark] class VarianceCalculator(stats: Array[Double], var rawCount: Long)
+ extends ImpurityCalculator(stats) {
require(stats.length == 3,
s"VarianceCalculator requires sufficient statistics array stats to be of
length 3," +
@@ -108,7 +121,7 @@ private[spark] class VarianceCalculator(stats:
Array[Double]) extends ImpurityCa
/**
* Make a deep copy of this [[ImpurityCalculator]].
*/
- def copy: VarianceCalculator = new VarianceCalculator(stats.clone())
+ def copy: VarianceCalculator = new VarianceCalculator(stats.clone(),
rawCount)
/**
* Calculate the impurity from the stored sufficient statistics.
@@ -116,9 +129,9 @@ private[spark] class VarianceCalculator(stats:
Array[Double]) extends ImpurityCa
def calculate(): Double = Variance.calculate(stats(0), stats(1), stats(2))
/**
- * Number of data points accounted for in the sufficient statistics.
+ * Weighted number of data points accounted for in the sufficient statistics.
*/
- def count: Long = stats(0).toLong
+ def count: Double = stats(0)
/**
* Prediction which should be made based on the sufficient statistics.
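
From the three weighted moments listed in the VarianceAggregator comment above (sum(w_i), sum(w_i * y_i), sum(w_i * y_i * y_i)), the weighted variance follows from the usual E[y^2] - E[y]^2 identity. A short standalone sketch of that arithmetic, not a quote of Spark's implementation:

object WeightedVarianceSketch {
  def variance(sumW: Double, sumWY: Double, sumWYY: Double): Double = {
    if (sumW == 0.0) 0.0
    else {
      val mean = sumWY / sumW
      sumWYY / sumW - mean * mean
    }
  }

  def main(args: Array[String]): Unit = {
    // Labels 1.0 and 3.0, each with weight 2.0:
    // sumW = 4, sumWY = 8, sumWYY = 20 -> mean 2.0, variance 1.0
    println(variance(4.0, 8.0, 20.0))
  }
}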
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index c711e7fa9dc67..a48050c4a25ff 100644
---
a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++
b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -21,8 +21,9 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
+import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode}
import org.apache.spark.ml.tree.impl.TreeTests
+import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
DecisionTreeSuite => OldDecisionTreeSuite}
@@ -42,6 +43,9 @@ class DecisionTreeClassifierSuite
private var categoricalDataPointsForMulticlassRDD: RDD[LabeledPoint] = _
private var continuousDataPointsForMulticlassRDD: RDD[LabeledPoint] = _
private var categoricalDataPointsForMulticlassForOrderedFeaturesRDD:
RDD[LabeledPoint] = _
+ private var linearMulticlassDataset: DataFrame = _
+
+ private val seed = 42
override def beforeAll() {
super.beforeAll()
@@ -58,6 +62,20 @@ class DecisionTreeClassifierSuite
categoricalDataPointsForMulticlassForOrderedFeaturesRDD = sc.parallelize(
OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures())
.map(_.asML)
+ linearMulticlassDataset = {
+ val nPoints = 100
+ val coefficients = Array(
+ -0.57997, 0.912083, -0.371077,
+ -0.16624, -0.84355, -0.048509)
+
+ val xMean = Array(5.843, 3.057)
+ val xVariance = Array(0.6856, 0.1899)
+
+ val testData = LogisticRegressionSuite.generateMultinomialLogisticInput(
+ coefficients, xMean, xVariance, addIntercept = true, nPoints, seed)
+
+ sc.parallelize(testData, 4).toDF()
+ }
}
test("params") {
@@ -246,7 +264,8 @@ class DecisionTreeClassifierSuite
val categoricalFeatures = Map(0 -> 3)
val numClasses = 3
- val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures,
numClasses)
+ val newData: DataFrame =
+ TreeTests.setMetadata(rdd.map(_.toInstance(1.0)), categoricalFeatures,
numClasses)
val newTree = dt.fit(newData)
// copied model must have the same parent.
@@ -273,7 +292,7 @@ class DecisionTreeClassifierSuite
LabeledPoint(1, Vectors.dense(0, 3, 9)),
LabeledPoint(0, Vectors.dense(0, 2, 6))
))
- val df = TreeTests.setMetadata(data, Map(0 -> 1), 2)
+ val df = TreeTests.setMetadata(data.map(_.toInstance(1.0)), Map(0 -> 1), 2)
val dt = new DecisionTreeClassifier().setMaxDepth(3)
dt.fit(df)
}
@@ -295,7 +314,7 @@ class DecisionTreeClassifierSuite
LabeledPoint(0.0, Vectors.dense(2.0)),
LabeledPoint(1.0, Vectors.dense(2.0)))
val data = sc.parallelize(arr)
- val df = TreeTests.setMetadata(data, Map(0 -> 3), 2)
+ val df = TreeTests.setMetadata(data.map(_.toInstance(1.0)), Map(0 -> 3), 2)
// Must set maxBins s.t. the feature will be treated as an ordered
categorical feature.
val dt = new DecisionTreeClassifier()
@@ -326,7 +345,7 @@ class DecisionTreeClassifierSuite
val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
val numFeatures = data.first().features.size
val categoricalFeatures = (0 to numFeatures).map(i => (i, 2)).toMap
- val df = TreeTests.setMetadata(data, categoricalFeatures, 2)
+ val df = TreeTests.setMetadata(data.map(_.toInstance(1.0)),
categoricalFeatures, 2)
val model = dt.fit(df)
@@ -351,6 +370,36 @@ class DecisionTreeClassifierSuite
dt.fit(df)
}
+ test("training with sample weights") {
+ val df = linearMulticlassDataset
+ val numClasses = 3
+ val predEquals = (x: Double, y: Double) => x == y
+ // (impurity, maxDepth)
+ val testParams = Seq(
+ ("gini", 10),
+ ("entropy", 10),
+ ("gini", 5)
+ )
+ for ((impurity, maxDepth) <- testParams) {
+ val estimator = new DecisionTreeClassifier()
+ .setMaxDepth(maxDepth)
+ .setSeed(seed)
+ .setMinWeightFractionPerNode(0.049)
+ .setImpurity(impurity)
+
+ MLTestingUtils.testArbitrarilyScaledWeights[DecisionTreeClassificationModel,
+ DecisionTreeClassifier](df.as[LabeledPoint], estimator,
+ MLTestingUtils.modelPredictionEquals(df, predEquals, 0.9))
+ MLTestingUtils.testOutliersWithSmallWeights[DecisionTreeClassificationModel,
+ DecisionTreeClassifier](df.as[LabeledPoint], estimator,
+ numClasses, MLTestingUtils.modelPredictionEquals(df, predEquals, 0.8),
+ outlierRatio = 2)
+ MLTestingUtils.testOversamplingVsWeighting[DecisionTreeClassificationModel,
+ DecisionTreeClassifier](df.as[LabeledPoint], estimator,
+ MLTestingUtils.modelPredictionEquals(df, predEquals, 1.0), seed)
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
@@ -371,12 +420,12 @@ class DecisionTreeClassifierSuite
// Categorical splits with tree depth 2
val categoricalData: DataFrame =
- TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
+ TreeTests.setMetadata(rdd.map(_.toInstance(1.0)), Map(0 -> 2, 1 -> 3),
numClasses = 2)
testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings,
checkModelData)
// Continuous splits with tree depth 2
val continuousData: DataFrame =
- TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+ TreeTests.setMetadata(rdd.map(_.toInstance(1.0)), Map.empty[Int, Int],
numClasses = 2)
testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings,
checkModelData)
// Continuous splits with tree depth 0
@@ -399,7 +448,8 @@ private[ml] object DecisionTreeClassifierSuite extends
SparkFunSuite {
val numFeatures = data.first().features.size
val oldStrategy = dt.getOldStrategy(categoricalFeatures, numClasses)
val oldTree = OldDecisionTree.train(data.map(OldLabeledPoint.fromML),
oldStrategy)
- val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures,
numClasses)
+ val newData: DataFrame =
+ TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures,
numClasses)
val newTree = dt.fit(newData)
// Use parent from newTree since this is not checked anyways.
val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(
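
For context on what these suite changes exercise, here is a hedged end-to-end usage sketch with this patch applied. It assumes DecisionTreeClassifier now exposes the weightCol param (via HasWeightCol) together with the minWeightFractionPerNode param added in this PR; the data and column names are arbitrary.

import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object WeightedTreeUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("weighted-dt").getOrCreate()
    import spark.implicits._

    // (label, weight, features): the last row is heavily down-weighted.
    val train = Seq(
      (0.0, 1.0, Vectors.dense(0.0)),
      (0.0, 1.0, Vectors.dense(0.0)),
      (1.0, 1.0, Vectors.dense(1.0)),
      (1.0, 0.1, Vectors.dense(0.0))
    ).toDF("label", "weight", "features")

    val dt = new DecisionTreeClassifier()
      .setMinWeightFractionPerNode(0.05)
    // weightCol is set the same way the test utilities above set it.
    val model = dt.set(dt.weightCol, "weight").fit(train)
    println(model.toDebugString)
    spark.stop()
  }
}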
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 0598943c3d4be..b6dbcbff3eca6 100644
---
a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++
b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -251,7 +251,8 @@ class GBTClassifierSuite extends SparkFunSuite with
MLlibTestSparkContext
sc.setCheckpointDir(path)
val categoricalFeatures = Map.empty[Int, Int]
- val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures,
numClasses = 2)
+ val df: DataFrame =
+ TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures,
numClasses = 2)
val gbt = new GBTClassifier()
.setMaxDepth(2)
.setLossType("logistic")
@@ -346,7 +347,8 @@ class GBTClassifierSuite extends SparkFunSuite with
MLlibTestSparkContext
// In this data, feature 1 is very important.
val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
val categoricalFeatures = Map.empty[Int, Int]
- val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures,
numClasses)
+ val df: DataFrame =
+ TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures,
numClasses)
val importances = gbt.fit(df).featureImportances
val mostImportantFeature = importances.argmax
@@ -373,7 +375,7 @@ class GBTClassifierSuite extends SparkFunSuite with
MLlibTestSparkContext
val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" ->
"logistic")
val continuousData: DataFrame =
- TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+ TreeTests.setMetadata(rdd.map(_.toInstance(1.0)), Map.empty[Int, Int],
numClasses = 2)
testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings,
checkModelData)
}
}
@@ -394,7 +396,8 @@ private object GBTClassifierSuite extends SparkFunSuite {
gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt)
val oldModel = oldGBT.run(data.map(OldLabeledPoint.fromML))
- val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures,
numClasses = 2)
+ val newData: DataFrame =
+ TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures,
numClasses = 2)
val newModel = gbt.fit(newData)
// Use parent from newTree since this is not checked anyways.
val oldModelAsNew = GBTClassificationModel.fromOld(
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
index ee2aefee7a6db..47c393bcdf7bc 100644
---
a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
+++
b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -134,7 +134,7 @@ class LinearSVCSuite extends SparkFunSuite with
MLlibTestSparkContext with Defau
MLTestingUtils.testArbitrarilyScaledWeights[LinearSVCModel, LinearSVC](
dataset.as[LabeledPoint], estimator, modelEquals)
MLTestingUtils.testOutliersWithSmallWeights[LinearSVCModel, LinearSVC](
- dataset.as[LabeledPoint], estimator, 2, modelEquals)
+ dataset.as[LabeledPoint], estimator, 2, modelEquals, outlierRatio = 3)
MLTestingUtils.testOversamplingVsWeighting[LinearSVCModel, LinearSVC](
dataset.as[LabeledPoint], estimator, modelEquals, 42L)
}
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 43547a4aafcb9..d5d8fabaafa42 100644
---
a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -1850,7 +1850,7 @@ class LogisticRegressionSuite
MLTestingUtils.testArbitrarilyScaledWeights[LogisticRegressionModel,
LogisticRegression](
dataset.as[LabeledPoint], estimator, modelEquals)
MLTestingUtils.testOutliersWithSmallWeights[LogisticRegressionModel,
LogisticRegression](
- dataset.as[LabeledPoint], estimator, numClasses, modelEquals)
+ dataset.as[LabeledPoint], estimator, numClasses, modelEquals,
outlierRatio = 3)
MLTestingUtils.testOversamplingVsWeighting[LogisticRegressionModel,
LogisticRegression](
dataset.as[LabeledPoint], estimator, modelEquals, seed)
}
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 37d7991fe8dd8..cc68562a69750 100644
---
a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++
b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -178,7 +178,7 @@ class NaiveBayesSuite extends SparkFunSuite with
MLlibTestSparkContext with Defa
MLTestingUtils.testArbitrarilyScaledWeights[NaiveBayesModel, NaiveBayes](
dataset.as[LabeledPoint], estimatorNoSmoothing, modelEquals)
MLTestingUtils.testOutliersWithSmallWeights[NaiveBayesModel, NaiveBayes](
- dataset.as[LabeledPoint], estimatorWithSmoothing, numClasses,
modelEquals)
+ dataset.as[LabeledPoint], estimatorWithSmoothing, numClasses,
modelEquals, outlierRatio = 3)
MLTestingUtils.testOversamplingVsWeighting[NaiveBayesModel, NaiveBayes](
dataset.as[LabeledPoint], estimatorWithSmoothing, modelEquals, seed)
}
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 44e1585ee514b..f9d3322770583 100644
---
a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++
b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -18,17 +18,17 @@
package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest =>
OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
@@ -129,7 +129,7 @@ class RandomForestClassifierSuite
}
test("predictRaw and predictProbability") {
- val rdd = orderedLabeledPoints5_20
+ val rdd = orderedLabeledPoints5_20.map(_.toInstance(1.0))
val rf = new RandomForestClassifier()
.setImpurity("Gini")
.setMaxDepth(3)
@@ -167,7 +167,6 @@ class RandomForestClassifierSuite
/////////////////////////////////////////////////////////////////////////////
// Tests of feature importance
/////////////////////////////////////////////////////////////////////////////
-
test("Feature importance with toy data") {
val numClasses = 2
val rf = new RandomForestClassifier()
@@ -179,7 +178,7 @@ class RandomForestClassifierSuite
.setSeed(123)
// In this data, feature 1 is very important.
- val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+ val data: RDD[Instance] =
TreeTests.featureImportanceData(sc).map(_.toInstance(1.0))
val categoricalFeatures = Map.empty[Int, Int]
val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures,
numClasses)
@@ -212,7 +211,7 @@ class RandomForestClassifierSuite
}
val rf = new RandomForestClassifier().setNumTrees(2)
- val rdd = TreeTests.getTreeReadWriteData(sc)
+ val rdd = TreeTests.getTreeReadWriteData(sc).map(_.toInstance(1.0))
val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" ->
"entropy")
@@ -239,7 +238,8 @@ private object RandomForestClassifierSuite extends
SparkFunSuite {
val oldModel = OldRandomForest.trainClassifier(
data.map(OldLabeledPoint.fromML), oldStrategy, rf.getNumTrees,
rf.getFeatureSubsetStrategy,
rf.getSeed.toInt)
- val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures,
numClasses)
+ val newData: DataFrame =
+ TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures,
numClasses)
val newModel = rf.fit(newData)
// Use parent from newTree since this is not checked anyways.
val oldModelAsNew = RandomForestClassificationModel.fromOld(
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 15fa26e8b5272..1238b2fa8edbc 100644
---
a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++
b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -18,15 +18,14 @@
package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
-import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
- DecisionTreeSuite => OldDecisionTreeSuite}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
DecisionTreeSuite => OldDecisionTreeSuite}
+import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
@@ -34,13 +33,20 @@ class DecisionTreeRegressorSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
import DecisionTreeRegressorSuite.compareAPIs
+ import testImplicits._
private var categoricalDataPointsRDD: RDD[LabeledPoint] = _
+ private var linearRegressionData: DataFrame = _
+
+ private val seed = 42
override def beforeAll() {
super.beforeAll()
categoricalDataPointsRDD =
sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints().map(_.asML))
+ linearRegressionData =
sc.parallelize(LinearDataGenerator.generateLinearInput(
+ intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3),
+ xVariance = Array(0.7, 1.2), nPoints = 1000, seed, eps = 0.5),
2).map(_.asML).toDF()
}
/////////////////////////////////////////////////////////////////////////////
@@ -68,7 +74,8 @@ class DecisionTreeRegressorSuite
test("copied model must have the same parent") {
val categoricalFeatures = Map(0 -> 2, 1 -> 2)
- val df = TreeTests.setMetadata(categoricalDataPointsRDD,
categoricalFeatures, numClasses = 0)
+ val df =
TreeTests.setMetadata(categoricalDataPointsRDD.map(_.toInstance(1.0)),
+ categoricalFeatures, numClasses = 0)
val model = new DecisionTreeRegressor()
.setImpurity("variance")
.setMaxDepth(2)
@@ -85,7 +92,8 @@ class DecisionTreeRegressorSuite
.setVarianceCol("variance")
val categoricalFeatures = Map(0 -> 2, 1 -> 2)
- val df = TreeTests.setMetadata(categoricalDataPointsRDD,
categoricalFeatures, numClasses = 0)
+ val df =
TreeTests.setMetadata(categoricalDataPointsRDD.map(_.toInstance(1.0)),
+ categoricalFeatures, numClasses = 0)
val model = dt.fit(df)
val predictions = model.transform(df)
@@ -98,7 +106,7 @@ class DecisionTreeRegressorSuite
s"Expected variance $expectedVariance but got $variance.")
}
- val varianceData: RDD[LabeledPoint] = TreeTests.varianceData(sc)
+ val varianceData: RDD[Instance] =
TreeTests.varianceData(sc).map(_.toInstance(1.0))
val varianceDF = TreeTests.setMetadata(varianceData, Map.empty[Int, Int],
0)
dt.setMaxDepth(1)
.setMaxBins(6)
@@ -125,7 +133,7 @@ class DecisionTreeRegressorSuite
.setSeed(123)
// In this data, feature 1 is very important.
- val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+ val data: RDD[Instance] =
TreeTests.featureImportanceData(sc).map(_.toInstance(1.0))
val categoricalFeatures = Map.empty[Int, Int]
val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
@@ -146,6 +154,27 @@ class DecisionTreeRegressorSuite
}
}
+ test("training with sample weights") {
+ val df = linearRegressionData
+ val numClasses = 0
+ val testParams = Seq(5, 10)
+ for (maxDepth <- testParams) {
+ val estimator = new DecisionTreeRegressor()
+ .setMaxDepth(maxDepth)
+ .setMinWeightFractionPerNode(0.05)
+ MLTestingUtils.testArbitrarilyScaledWeights[DecisionTreeRegressionModel,
+ DecisionTreeRegressor](df.as[LabeledPoint], estimator,
+ MLTestingUtils.modelPredictionEquals(df, RelativeErrorComparison(_, _, 0.05), 0.9))
+ MLTestingUtils.testOutliersWithSmallWeights[DecisionTreeRegressionModel,
+ DecisionTreeRegressor](df.as[LabeledPoint], estimator, numClasses,
+ MLTestingUtils.modelPredictionEquals(df, RelativeErrorComparison(_, _, 0.1), 0.8),
+ outlierRatio = 2)
+ MLTestingUtils.testOversamplingVsWeighting[DecisionTreeRegressionModel,
+ DecisionTreeRegressor](df.as[LabeledPoint], estimator,
+ MLTestingUtils.modelPredictionEquals(df, RelativeErrorComparison(_, _, 0.01), 1.0), seed)
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
@@ -159,7 +188,7 @@ class DecisionTreeRegressorSuite
}
val dt = new DecisionTreeRegressor()
- val rdd = TreeTests.getTreeReadWriteData(sc)
+ val rdd = TreeTests.getTreeReadWriteData(sc).map(_.toInstance(1.0))
// Categorical splits with tree depth 2
val categoricalData: DataFrame =
@@ -192,7 +221,8 @@ private[ml] object DecisionTreeRegressorSuite extends
SparkFunSuite {
val numFeatures = data.first().features.size
val oldStrategy = dt.getOldStrategy(categoricalFeatures)
val oldTree = OldDecisionTree.train(data.map(OldLabeledPoint.fromML),
oldStrategy)
- val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures,
numClasses = 0)
+ val newData: DataFrame =
+ TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures,
numClasses = 0)
val newTree = dt.fit(newData)
// Use parent from newTree since this is not checked anyways.
val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index dcf3f9a1ea9b2..e805e42649e61 100644
---
a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++
b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -157,7 +157,7 @@ class GBTRegressorSuite extends SparkFunSuite with
MLlibTestSparkContext
// In this data, feature 1 is very important.
val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
val categoricalFeatures = Map.empty[Int, Int]
- val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
+ val df: DataFrame = TreeTests.setMetadata(data.map(_.toInstance(1.0)),
categoricalFeatures, 0)
val importances = gbt.fit(df).featureImportances
val mostImportantFeature = importances.argmax
@@ -183,7 +183,7 @@ class GBTRegressorSuite extends SparkFunSuite with
MLlibTestSparkContext
val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" ->
"squared")
val continuousData: DataFrame =
- TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
+ TreeTests.setMetadata(rdd.map(_.toInstance(1.0)), Map.empty[Int, Int],
numClasses = 0)
testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings,
checkModelData)
}
}
@@ -203,7 +203,8 @@ private object GBTRegressorSuite extends SparkFunSuite {
val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures,
OldAlgo.Regression)
val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt)
val oldModel = oldGBT.run(data.map(OldLabeledPoint.fromML))
- val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures,
numClasses = 0)
+ val newData: DataFrame =
+ TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures,
numClasses = 0)
val newModel = gbt.fit(newData)
// Use parent from newTree since this is not checked anyways.
val oldModelAsNew = GBTRegressionModel.fromOld(
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 584a1b272f6c8..df3dbd0af3f39 100644
---
a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++
b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -839,10 +839,12 @@ class LinearRegressionSuite
.setStandardization(standardization)
.setRegParam(regParam)
.setElasticNetParam(elasticNetParam)
+ .setSolver(solver)
MLTestingUtils.testArbitrarilyScaledWeights[LinearRegressionModel,
LinearRegression](
datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals)
MLTestingUtils.testOutliersWithSmallWeights[LinearRegressionModel,
LinearRegression](
- datasetWithStrongNoise.as[LabeledPoint], estimator, numClasses,
modelEquals)
+ datasetWithStrongNoise.as[LabeledPoint], estimator, numClasses,
modelEquals,
+ outlierRatio = 3)
MLTestingUtils.testOversamplingVsWeighting[LinearRegressionModel,
LinearRegression](
datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals, seed)
}
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index c08335f9f84af..01207931f57ba 100644
---
a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++
b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest =>
OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -28,6 +29,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
+
/**
* Test suite for [[RandomForestRegressor]].
*/
@@ -86,7 +88,7 @@ class RandomForestRegressorSuite extends SparkFunSuite with
MLlibTestSparkContex
// In this data, feature 1 is very important.
val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
val categoricalFeatures = Map.empty[Int, Int]
- val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
+ val df: DataFrame = TreeTests.setMetadata(data.map(_.toInstance(1.0)),
categoricalFeatures, 0)
val model = rf.fit(df)
@@ -123,7 +125,7 @@ class RandomForestRegressorSuite extends SparkFunSuite with
MLlibTestSparkContex
val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" ->
"variance")
val continuousData: DataFrame =
- TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
+ TreeTests.setMetadata(rdd.map(_.toInstance(1.0)), Map.empty[Int, Int],
numClasses = 0)
testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings,
checkModelData)
}
}
@@ -143,7 +145,8 @@ private object RandomForestRegressorSuite extends
SparkFunSuite {
rf.getOldStrategy(categoricalFeatures, numClasses = 0,
OldAlgo.Regression, rf.getOldImpurity)
val oldModel =
OldRandomForest.trainRegressor(data.map(OldLabeledPoint.fromML), oldStrategy,
rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt)
- val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures,
numClasses = 0)
+ val newData: DataFrame = TreeTests.setMetadata(data.map(_.toInstance(1.0)),
+ categoricalFeatures, numClasses = 0)
val newModel = rf.fit(newData)
// Use parent from newTree since this is not checked anyways.
val oldModelAsNew = RandomForestRegressionModel.fromOld(
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
index 77ab3d8bb75f7..0b09577171c41 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.tree.impl
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.mllib.tree.EnsembleTestHelper
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -26,12 +27,16 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
*/
class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
- test("BaggedPoint RDD: without subsampling") {
- val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ test("BaggedPoint RDD: without subsampling with weights") {
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map {
lp =>
+ Instance(lp.label, 0.5, lp.features.asML)
+ }
val rdd = sc.parallelize(arr)
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42)
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false,
+ (instance: Instance) => instance.weight * 4.0, seed = 42)
baggedRDD.collect().foreach { baggedPoint =>
- assert(baggedPoint.subsampleWeights.size == 1 &&
baggedPoint.subsampleWeights(0) == 1)
+ assert(baggedPoint.subsampleCounts.size == 1 &&
baggedPoint.subsampleCounts(0) == 1)
+ assert(baggedPoint.sampleWeight === 2.0)
}
}
@@ -40,13 +45,17 @@ class BaggedPointSuite extends SparkFunSuite with
MLlibTestSparkContext {
val (expectedMean, expectedStddev) = (1.0, 1.0)
val seeds = Array(123, 5354, 230, 349867, 23987)
- val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1,
1000).map(_.asML)
val rdd = sc.parallelize(arr)
seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples,
true, seed)
- val subsampleCounts: Array[Array[Double]] =
baggedRDD.map(_.subsampleWeights).collect()
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples,
true,
+ (_: LabeledPoint) => 2.0, seed)
+ val subsampleCounts: Array[Array[Double]] =
+ baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples,
expectedMean,
expectedStddev, epsilon = 0.01)
+ // should ignore weight function for now
+ assert(baggedRDD.collect().forall(_.sampleWeight === 1.0))
}
}
@@ -59,8 +68,10 @@ class BaggedPointSuite extends SparkFunSuite with
MLlibTestSparkContext {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
val rdd = sc.parallelize(arr)
seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample,
numSubsamples, true, seed)
- val subsampleCounts: Array[Array[Double]] =
baggedRDD.map(_.subsampleWeights).collect()
+ val baggedRDD =
+ BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true,
seed = seed)
+ val subsampleCounts: Array[Array[Double]] =
+ baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples,
expectedMean,
expectedStddev, epsilon = 0.01)
}
@@ -71,13 +82,17 @@ class BaggedPointSuite extends SparkFunSuite with
MLlibTestSparkContext {
val (expectedMean, expectedStddev) = (1.0, 0)
val seeds = Array(123, 5354, 230, 349867, 23987)
- val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1,
1000).map(_.asML)
val rdd = sc.parallelize(arr)
seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples,
false, seed)
- val subsampleCounts: Array[Array[Double]] =
baggedRDD.map(_.subsampleWeights).collect()
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples,
false,
+ (_: LabeledPoint) => 2.0, seed)
+ val subsampleCounts: Array[Array[Double]] =
+ baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples,
expectedMean,
expectedStddev, epsilon = 0.01)
+ // should ignore weight function for now
+ assert(baggedRDD.collect().forall(_.sampleWeight === 1.0))
}
}
@@ -90,8 +105,10 @@ class BaggedPointSuite extends SparkFunSuite with
MLlibTestSparkContext {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
val rdd = sc.parallelize(arr)
seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample,
numSubsamples, false, seed)
- val subsampleCounts: Array[Array[Double]] =
baggedRDD.map(_.subsampleWeights).collect()
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample,
numSubsamples, false,
+ seed = seed)
+ val subsampleCounts: Array[Array[Double]] =
+ baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples,
expectedMean,
expectedStddev, epsilon = 0.01)
}
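
The assertions above distinguish two quantities that were previously conflated: the integer subsample counts produced by bagging and the instance's sample weight, optionally transformed by the weight-extraction function passed to convertToBaggedRDD. A purely illustrative data shape, not the actual BaggedPoint class:

final case class BaggedPointSketch[D](
    datum: D,
    subsampleCounts: Array[Int], // one bootstrap replication count per tree
    sampleWeight: Double)        // instance weight carried into impurity updates

object BaggedPointSketchDemo {
  def main(args: Array[String]): Unit = {
    val p = BaggedPointSketch("instance-0", subsampleCounts = Array(1, 0, 2), sampleWeight = 0.5)
    // Effective contribution to tree 2's statistics is count * weight.
    println(p.subsampleCounts(2) * p.sampleWeight) // 1.0
  }
}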
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index e1ab7c2d6520b..775ca669d527b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -18,10 +18,11 @@
package org.apache.spark.ml.tree.impl
import scala.collection.mutable
+import scala.language.implicitConversions
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.tree._
import org.apache.spark.ml.util.TestingUtils._
@@ -43,7 +44,7 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
/////////////////////////////////////////////////////////////////////////////
test("Binary classification with continuous features: split calculation") {
- val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1().map(_.asML)
+ val arr =
OldDTSuite.generateOrderedLabeledPointsWithLabel1().map(_.asML.toInstance(1.0))
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, 100)
@@ -55,7 +56,7 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
}
test("Binary classification with binary (ordered) categorical features:
split calculation") {
- val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML)
+ val arr =
OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance(1.0))
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2,
numClasses = 2,
@@ -72,7 +73,7 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
test("Binary classification with 3-ary (ordered) categorical features," +
" with no samples for one category: split calculation") {
- val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML)
+ val arr =
OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance(1.0))
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2,
numClasses = 2,
@@ -90,12 +91,12 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
test("find splits for a continuous feature") {
// find splits for normal case
{
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
Map(), Set(),
Array(6), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
- val featureSamples = Array.fill(200000)(math.random)
+ val featureSamples = Array.fill(200000)((1.0, math.random))
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples,
fakeMetadata, 0)
assert(splits.length === 5)
assert(fakeMetadata.numSplits(0) === 5)
@@ -107,12 +108,12 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
// find splits should not return identical splits
// when there are not enough split candidates, reduce the number of splits
in metadata
{
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
Map(), Set(),
Array(5), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
- val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
+ val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(x => (1.0, x.toDouble))
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples,
fakeMetadata, 0)
assert(splits === Array(1.0, 2.0))
// check returned splits are distinct
@@ -121,47 +122,67 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
// find splits when most samples close to the minimum
{
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
- val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
+ val featureSamples =
+ Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(x => (1.0, x.toDouble))
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples,
fakeMetadata, 0)
assert(splits === Array(2.0, 3.0))
}
// find splits when most samples close to the maximum
{
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
Map(), Set(),
Array(2), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
- val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
+ val featureSamples =
+ Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(x => (1.0, x.toDouble))
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples,
fakeMetadata, 0)
assert(splits === Array(1.0))
}
- // find splits for constant feature
+ // find splits for arbitrarily scaled data
{
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
+ Map(), Set(),
+ Array(6), Gini, QuantileStrategy.Sort,
+ 0, 0, 0.0, 0.0, 0, 0
+ )
+ val featureSamplesUnitWeight = Array.fill(10)((1.0, math.random))
+ val featureSamplesSmallWeight = featureSamplesUnitWeight.map { case (w, x) => (w * 0.001, x)}
+ val featureSamplesLargeWeight = featureSamplesUnitWeight.map { case (w, x) => (w * 1000, x)}
+ val splitsUnitWeight = RandomForest
+ .findSplitsForContinuousFeature(featureSamplesUnitWeight,
fakeMetadata, 0)
+ val splitsSmallWeight = RandomForest
+ .findSplitsForContinuousFeature(featureSamplesSmallWeight,
fakeMetadata, 0)
+ val splitsLargeWeight = RandomForest
+ .findSplitsForContinuousFeature(featureSamplesLargeWeight,
fakeMetadata, 0)
+ assert(splitsUnitWeight === splitsSmallWeight)
+ assert(splitsUnitWeight === splitsLargeWeight)
+ }
+
+ // find splits when most weight is close to the minimum
+ {
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
- val featureSamples = Array(0, 0, 0).map(_.toDouble)
- val featureSamplesEmpty = Array.empty[Double]
+ val featureSamples = Array((10, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6)).map {
+ case (w, x) => (w.toDouble, x.toDouble)
+ }
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples,
fakeMetadata, 0)
- assert(splits === Array.empty[Double])
- val splitsEmpty =
- RandomForest.findSplitsForContinuousFeature(featureSamplesEmpty,
fakeMetadata, 0)
- assert(splitsEmpty === Array.empty[Double])
+ assert(splits === Array(1.0, 2.0))
}
}
test("train with empty arrays") {
- val lp = LabeledPoint(1.0, Vectors.dense(Array.empty[Double]))
+ val lp = LabeledPoint(1.0,
Vectors.dense(Array.empty[Double])).toInstance(1.0)
val data = Array.fill(5)(lp)
val rdd = sc.parallelize(data)
@@ -176,8 +197,8 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
}
test("train with constant features") {
- val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0))
- val data = Array.fill(5)(lp)
+ val instance = LabeledPoint(1.0, Vectors.dense(0.0, 0.0,
0.0)).toInstance(1.0)
+ val data = Array.fill(5)(instance)
val rdd = sc.parallelize(data)
val strategy = new OldStrategy(
OldAlgo.Classification,
@@ -189,7 +210,7 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr =
None)
assert(tree.rootNode.impurity === -1.0)
assert(tree.depth === 0)
- assert(tree.rootNode.prediction === lp.label)
+ assert(tree.rootNode.prediction === instance.label)
// Test with no categorical features
val strategy2 = new OldStrategy(
@@ -200,11 +221,11 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
val Array(tree2) = RandomForest.run(rdd, strategy2, 1, "all", 42L, instr =
None)
assert(tree2.rootNode.impurity === -1.0)
assert(tree2.depth === 0)
- assert(tree2.rootNode.prediction === lp.label)
+ assert(tree2.rootNode.prediction === instance.label)
}
test("Multiclass classification with unordered categorical features: split
calculations") {
- val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML)
+ val arr =
OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance(1.0))
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new OldStrategy(
@@ -245,7 +266,8 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
}
test("Multiclass classification with ordered categorical features: split
calculations") {
- val arr =
OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures().map(_.asML)
+ val arr =
OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
+ .map(_.asML.toInstance(1.0))
assert(arr.length === 3000)
val rdd = sc.parallelize(arr)
val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2,
numClasses = 100,
@@ -277,7 +299,7 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
- val input = sc.parallelize(arr)
+ val input = sc.parallelize(arr.map(_.toInstance(1.0)))
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity =
Gini, maxDepth = 1,
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
@@ -319,7 +341,7 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
- val input = sc.parallelize(arr)
+ val input = sc.parallelize(arr.map(_.toInstance(1.0)))
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity =
Gini, maxDepth = 5,
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
@@ -371,7 +393,7 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
LabeledPoint(0.0, Vectors.dense(2.0)),
LabeledPoint(0.0, Vectors.dense(2.0)),
LabeledPoint(1.0, Vectors.dense(2.0)))
- val input = sc.parallelize(arr)
+ val input = sc.parallelize(arr.map(_.toInstance(1.0)))
// Must set maxBins s.t. the feature will be treated as an ordered
categorical feature.
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity =
Gini, maxDepth = 1,
@@ -390,7 +412,7 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
}
test("Second level node building with vs. without groups") {
- val arr = OldDTSuite.generateOrderedLabeledPoints().map(_.asML)
+ val arr =
OldDTSuite.generateOrderedLabeledPoints().map(_.asML.toInstance(1.0))
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
// For tree with 1 group
@@ -434,7 +456,7 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
def
binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy:
OldStrategy) {
val numFeatures = 50
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures,
1000)
- val rdd = sc.parallelize(arr).map(_.asML)
+ val rdd = sc.parallelize(arr).map(_.asML.toInstance(1.0))
// Select feature subset for top nodes. Return true if OK.
def checkFeatureSubsetStrategy(
@@ -547,16 +569,16 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
left2 parent
left right
*/
- val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0))
+ val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0), 6L)
val left = new LeafNode(0.0, leftImp.calculate(), leftImp)
- val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0))
+ val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0), 8L)
val right = new LeafNode(2.0, rightImp.calculate(), rightImp)
val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0,
0.5))
val parentImp = parent.impurityStats
- val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0))
+ val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0), 8L)
val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp)
val grandParent = TreeTests.buildParentNode(left2, parent, new
ContinuousSplit(1, 1.0))
@@ -602,6 +624,57 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
}
+
+ test("weights at arbitrary scale") {
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(3, 10)
+ val rddWithUnitWeights = sc.parallelize(arr.map(_.asML.toInstance(1.0)))
+ val rddWithSmallWeights = rddWithUnitWeights.map { inst =>
+ Instance(inst.label, 0.001, inst.features)
+ }
+ val rddWithBigWeights = rddWithUnitWeights.map { inst =>
+ Instance(inst.label, 1000, inst.features)
+ }
+ val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2)
+ val unitWeightTrees = RandomForest.run(rddWithUnitWeights, strategy, 3,
"all", 42L, None)
+
+ val smallWeightTrees = RandomForest.run(rddWithSmallWeights, strategy, 3,
"all", 42L, None)
+ unitWeightTrees.zip(smallWeightTrees).foreach { case (unitTree,
smallWeightTree) =>
+ TreeTests.checkEqual(unitTree, smallWeightTree)
+ }
+
+ val bigWeightTrees = RandomForest.run(rddWithBigWeights, strategy, 3,
"all", 42L, None)
+ unitWeightTrees.zip(bigWeightTrees).foreach { case (unitTree,
bigWeightTree) =>
+ TreeTests.checkEqual(unitTree, bigWeightTree)
+ }
+ }
+
+ test("minWeightFraction and minInstancesPerNode") {
+ val data = Array(
+ Instance(0.0, 1.0, Vectors.dense(0.0)),
+ Instance(0.0, 1.0, Vectors.dense(0.0)),
+ Instance(0.0, 1.0, Vectors.dense(0.0)),
+ Instance(0.0, 1.0, Vectors.dense(0.0)),
+ Instance(1.0, 0.1, Vectors.dense(1.0))
+ )
+ val rdd = sc.parallelize(data)
+ val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2,
+ minWeightFractionPerNode = 0.5)
+ val Array(tree1) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
+ assert(tree1.depth == 0)
+
+ strategy.minWeightFractionPerNode = 0.0
+ val Array(tree2) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
+ assert(tree2.depth == 1)
+
+ strategy.minInstancesPerNode = 2
+ val Array(tree3) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
+ assert(tree3.depth == 0)
+
+ strategy.minInstancesPerNode = 1
+ val Array(tree4) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
+ assert(tree4.depth == 1)
+ }
+
}
private object RandomForestSuite {
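
The minWeightFraction test above hinges on a simple check: a split is usable only when each child keeps at least minWeightFractionPerNode of the total weighted count. A standalone sketch of that check using the weights from the test data (the object and method names here are invented for illustration):

object MinWeightFractionSketch {
  def splitAllowed(
      leftWeight: Double,
      rightWeight: Double,
      totalWeight: Double,
      minWeightFractionPerNode: Double): Boolean = {
    val minWeight = minWeightFractionPerNode * totalWeight
    leftWeight >= minWeight && rightWeight >= minWeight
  }

  def main(args: Array[String]): Unit = {
    // Four instances of weight 1.0 and one of weight 0.1, as in the test data.
    val totalWeight = 4 * 1.0 + 0.1
    // Separating the 0.1-weight instance is rejected at fraction 0.5 (tree stays at depth 0)...
    println(splitAllowed(4.0, 0.1, totalWeight, 0.5)) // false
    // ...and allowed once the fraction is relaxed to 0.0 (tree can reach depth 1).
    println(splitAllowed(4.0, 0.1, totalWeight, 0.0)) // true
  }
}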
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index c90cb8ca1034c..999aa80b7750d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -18,13 +18,15 @@
package org.apache.spark.ml.tree.impl
import scala.collection.JavaConverters._
+import scala.util.Random
import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute,
NumericAttribute}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.tree._
+import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}
@@ -32,6 +34,7 @@ private[ml] object TreeTests extends SparkFunSuite {
/**
* Convert the given data to a DataFrame, and set the features and label
metadata.
+ *
* @param data Dataset. Categorical features and labels must already have
0-based indices.
* This must be non-empty.
* @param categoricalFeatures Map: categorical feature index -> number of
distinct values
@@ -39,7 +42,7 @@ private[ml] object TreeTests extends SparkFunSuite {
* @return DataFrame with metadata
*/
def setMetadata(
- data: RDD[LabeledPoint],
+ data: RDD[Instance],
categoricalFeatures: Map[Int, Int],
numClasses: Int): DataFrame = {
val spark = SparkSession.builder()
@@ -66,7 +69,7 @@ private[ml] object TreeTests extends SparkFunSuite {
}
val labelMetadata = labelAttribute.toMetadata()
df.select(df("features").as("features", featuresMetadata),
- df("label").as("label", labelMetadata))
+ df("label").as("label", labelMetadata), df("weight"))
}
/** Java-friendly version of [[setMetadata()]] */
@@ -74,12 +77,14 @@ private[ml] object TreeTests extends SparkFunSuite {
data: JavaRDD[LabeledPoint],
categoricalFeatures: java.util.Map[java.lang.Integer, java.lang.Integer],
numClasses: Int): DataFrame = {
- setMetadata(data.rdd, categoricalFeatures.asInstanceOf[java.util.Map[Int,
Int]].asScala.toMap,
+ setMetadata(data.rdd.map(_.toInstance(1.0)),
+ categoricalFeatures.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
numClasses)
}
/**
* Set label metadata (particularly the number of classes) on a DataFrame.
+ *
* @param data Dataset. Categorical features and labels must already have
0-based indices.
* This must be non-empty.
* @param numClasses Number of classes label can take. If 0, mark as
continuous.
@@ -124,8 +129,8 @@ private[ml] object TreeTests extends SparkFunSuite {
* make mistakes such as creating loops of Nodes.
*/
private def checkEqual(a: Node, b: Node): Unit = {
- assert(a.prediction === b.prediction)
- assert(a.impurity === b.impurity)
+ assert(a.prediction ~== b.prediction absTol 1e-8)
+ assert(a.impurity ~== b.impurity absTol 1e-8)
(a, b) match {
case (aye: InternalNode, bee: InternalNode) =>
assert(aye.split === bee.split)
@@ -156,6 +161,7 @@ private[ml] object TreeTests extends SparkFunSuite {
/**
* Helper method for constructing a tree for testing.
* Given left, right children, construct a parent node.
+ *
* @param split Split for parent node
* @return Parent node with children attached
*/
@@ -163,8 +169,8 @@ private[ml] object TreeTests extends SparkFunSuite {
val leftImp = left.impurityStats
val rightImp = right.impurityStats
val parentImp = leftImp.copy.add(rightImp)
- val leftWeight = leftImp.count / parentImp.count.toDouble
- val rightWeight = rightImp.count / parentImp.count.toDouble
+ val leftWeight = leftImp.count / parentImp.count
+ val rightWeight = rightImp.count / parentImp.count
val gain = parentImp.calculate() -
(leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate())
val pred = parentImp.predict
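For orientation only, not part of the diff: the buildParentNode change above now divides each child's weighted count by the parent's weighted count (both are Doubles after this PR) and uses those fractions to compute the parent's information gain. A minimal standalone sketch of that arithmetic, with all names and numbers hypothetical:

  // Standalone illustration of weighted information gain (not Spark code).
  // leftWeight and rightWeight are weighted-count fractions that sum to 1.0.
  object WeightedGainSketch {
    def gain(parentImpurity: Double,
             leftImpurity: Double, leftWeight: Double,
             rightImpurity: Double, rightWeight: Double): Double =
      parentImpurity - (leftWeight * leftImpurity + rightWeight * rightImpurity)

    def main(args: Array[String]): Unit = {
      val leftCount = 3.5    // weighted instance count in the left child
      val rightCount = 6.5   // weighted instance count in the right child
      val parentCount = leftCount + rightCount
      val g = gain(parentImpurity = 0.48,
        leftImpurity = 0.30, leftWeight = leftCount / parentCount,
        rightImpurity = 0.20, rightWeight = rightCount / parentCount)
      println(f"gain = $g%.3f")  // 0.48 - (0.35 * 0.30 + 0.65 * 0.20) = 0.245
    }
  }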
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index f1ed568d5e60a..4b1bd35313c5a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol}
+import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.recommendation.{ALS, ALSModel}
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
@@ -246,8 +246,8 @@ object MLTestingUtils extends SparkFunSuite {
seed: Long): Unit = {
 val (overSampledData, weightedData) = genEquivalentOversampledAndWeightedInstances(data, seed)
- val weightedModel = estimator.set(estimator.weightCol, "weight").fit(weightedData)
 val overSampledModel = estimator.set(estimator.weightCol, "").fit(overSampledData)
+ val weightedModel = estimator.set(estimator.weightCol, "weight").fit(weightedData)
modelEquals(weightedModel, overSampledModel)
}
@@ -260,15 +260,17 @@ object MLTestingUtils extends SparkFunSuite {
data: Dataset[LabeledPoint],
estimator: E with HasWeightCol,
numClasses: Int,
- modelEquals: (M, M) => Unit): Unit = {
+ modelEquals: (M, M) => Unit,
+ outlierRatio: Int): Unit = {
import data.sqlContext.implicits._
val outlierDS = data.withColumn("weight", lit(1.0)).as[Instance].flatMap {
case Instance(l, w, f) =>
val outlierLabel = if (numClasses == 0) -l else numClasses - l - 1
- List.fill(3)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f))
+ List.fill(outlierRatio)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f))
}
val trueModel = estimator.set(estimator.weightCol, "").fit(data)
- val outlierModel = estimator.set(estimator.weightCol, "weight").fit(outlierDS)
+ val outlierModel = estimator.set(estimator.weightCol, "weight")
+ .fit(outlierDS)
modelEquals(trueModel, outlierModel)
}
@@ -281,10 +283,26 @@ object MLTestingUtils extends SparkFunSuite {
estimator: E with HasWeightCol,
modelEquals: (M, M) => Unit): Unit = {
estimator.set(estimator.weightCol, "weight")
- val models = Seq(0.001, 1.0, 1000.0).map { w =>
+ val models = Seq(0.01, 1.0, 1000.0).map { w =>
val df = data.withColumn("weight", lit(w))
estimator.fit(df)
}
models.sliding(2).foreach { case Seq(m1, m2) => modelEquals(m1, m2)}
}
+
+ def modelPredictionEquals[M <: PredictionModel[_, M]](
+ data: DataFrame,
+ compareFunc: (Double, Double) => Boolean,
+ fractionInTol: Double)(
+ model1: M,
+ model2: M): Unit = {
+ val pred1 = model1.transform(data).select(model1.getPredictionCol).collect()
+ val pred2 = model2.transform(data).select(model2.getPredictionCol).collect()
+ val inTol = pred1.zip(pred2).count { case (p1, p2) =>
+ val x = p1.getDouble(0)
+ val y = p2.getDouble(0)
+ compareFunc(x, y)
+ }
+ assert(inTol / pred1.length.toDouble >= fractionInTol)
+ }
}
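For orientation only, not part of the diff: the new modelPredictionEquals helper above transforms the given DataFrame with both models, collects their prediction columns, and asserts that at least fractionInTol of the row-wise pairs satisfy compareFunc. A self-contained sketch of the same counting logic, with hypothetical names and no Spark dependency:

  // Illustrative only: the fraction-in-tolerance check performed by the new helper.
  object PredictionAgreementSketch {
    def fractionInTolerance(pred1: Seq[Double], pred2: Seq[Double],
        compareFunc: (Double, Double) => Boolean): Double = {
      require(pred1.nonEmpty && pred1.length == pred2.length)
      val inTol = pred1.zip(pred2).count { case (x, y) => compareFunc(x, y) }
      inTol / pred1.length.toDouble
    }

    def main(args: Array[String]): Unit = {
      val a = Seq(1.0, 0.0, 1.0, 1.0, 0.0)
      val b = Seq(1.0, 0.0, 0.0, 1.0, 0.0)
      val frac = fractionInTolerance(a, b, (x, y) => math.abs(x - y) < 1e-8)
      println(frac)        // 0.8: four of the five prediction pairs agree
      assert(frac >= 0.8)  // analogous to passing fractionInTol = 0.8 to the helper
    }
  }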
diff --git
a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 441d0f7614bf6..f6efc84a24185 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -73,7 +73,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
 maxBins = 100,
 categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -100,7 +100,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
 maxDepth = 2,
 maxBins = 100,
 categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2))
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -116,7 +116,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
 val rdd = sc.parallelize(arr)
 val strategy = new Strategy(Classification, Gini, maxDepth = 3, numClasses = 2, maxBins = 100)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -133,7 +133,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
 val rdd = sc.parallelize(arr)
 val strategy = new Strategy(Classification, Gini, maxDepth = 3, numClasses = 2, maxBins = 100)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -150,7 +150,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
 val rdd = sc.parallelize(arr)
 val strategy = new Strategy(Classification, Entropy, maxDepth = 3, numClasses = 2, maxBins = 100)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -167,7 +167,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
 val rdd = sc.parallelize(arr)
 val strategy = new Strategy(Classification, Entropy, maxDepth = 3, numClasses = 2, maxBins = 100)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -183,7 +183,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
 val rdd = sc.parallelize(arr)
 val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
 numClasses = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
assert(strategy.isMulticlassClassification)
assert(metadata.isUnordered(featureIndex = 0))
assert(metadata.isUnordered(featureIndex = 1))
@@ -240,7 +240,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
 numClasses = 3, maxBins = maxBins,
 categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
 assert(strategy.isMulticlassClassification)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
assert(metadata.isUnordered(featureIndex = 0))
assert(metadata.isUnordered(featureIndex = 1))
@@ -288,7 +288,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
 val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
 numClasses = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3))
 assert(strategy.isMulticlassClassification)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
assert(metadata.isUnordered(featureIndex = 0))
val model = DecisionTree.train(rdd, strategy)
@@ -310,7 +310,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
 numClasses = 3, maxBins = 100,
 categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
 assert(strategy.isMulticlassClassification)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
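A side note, not part of the diff: the recurring rdd.map(_.asML.toInstance(1.0)) above converts the old mllib LabeledPoints into ml Instances that carry an explicit sample weight of 1.0 (toInstance is introduced by this PR). A small sketch of the equivalent manual construction, assuming Instance's fields are (label, weight, features) as used elsewhere in this diff:

  import org.apache.spark.ml.feature.{Instance, LabeledPoint}
  import org.apache.spark.ml.linalg.Vectors

  // One labeled point, wrapped as an unweighted Instance (weight = 1.0) ...
  val lp = LabeledPoint(1.0, Vectors.dense(0.5, 2.0))
  val unweighted = Instance(lp.label, 1.0, lp.features)
  // ... and the same point carrying a non-trivial sample weight, e.g. 0.3.
  val weighted = Instance(lp.label, 0.3, lp.features)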
diff --git
a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
index 14152cdd63bc7..d4171cf441e96 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
@@ -18,23 +18,62 @@
package org.apache.spark.mllib.tree
import org.apache.spark.SparkFunSuite
-import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator}
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.tree.impurity._
/**
* Test suites for [[GiniAggregator]] and [[EntropyAggregator]].
*/
class ImpuritySuite extends SparkFunSuite {
+
+ private val seed = 42
+
test("Gini impurity does not support negative labels") {
val gini = new GiniAggregator(2)
intercept[IllegalArgumentException] {
- gini.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
+ gini.update(Array(0.0, 1.0, 2.0), 0, -1, 3, 0.0)
}
}
test("Entropy does not support negative labels") {
val entropy = new EntropyAggregator(2)
intercept[IllegalArgumentException] {
- entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
+ entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 3, 0.0)
+ }
+ }
+
+ test("Classification impurities are insensitive to scaling") {
+ val rng = new scala.util.Random(seed)
+ val weightedCounts = Array.fill(5)(rng.nextDouble())
+ val smallWeightedCounts = weightedCounts.map(_ * 0.0001)
+ val largeWeightedCounts = weightedCounts.map(_ * 10000)
+ Seq(Gini, Entropy).foreach { impurity =>
+ val impurity1 = impurity.calculate(weightedCounts, weightedCounts.sum)
+ assert(impurity.calculate(smallWeightedCounts, smallWeightedCounts.sum)
+ ~== impurity1 relTol 0.005)
+ assert(impurity.calculate(largeWeightedCounts, largeWeightedCounts.sum)
+ ~== impurity1 relTol 0.005)
}
}
+ test("Regression impurities are insensitive to scaling") {
+ def computeStats(samples: Seq[Double], weights: Seq[Double]): (Double, Double, Double) = {
+ samples.zip(weights).foldLeft((0.0, 0.0, 0.0)) { case ((wn, wy, wyy), (y, w)) =>
+ (wn + w, wy + w * y, wyy + w * y * y)
+ }
+ }
+ val rng = new scala.util.Random(seed)
+ val samples = Array.fill(10)(rng.nextDouble())
+ val _weights = Array.fill(10)(rng.nextDouble())
+ val smallWeights = _weights.map(_ * 0.0001)
+ val largeWeights = _weights.map(_ * 10000)
+ val (count, sum, sumSquared) = computeStats(samples, _weights)
+ Seq(Variance).foreach { impurity =>
+ val impurity1 = impurity.calculate(count, sum, sumSquared)
+ val (smallCount, smallSum, smallSumSquared) = computeStats(samples, smallWeights)
+ val (largeCount, largeSum, largeSumSquared) = computeStats(samples, largeWeights)
+ assert(impurity.calculate(smallCount, smallSum, smallSumSquared) ~== impurity1 relTol 0.005)
+ assert(impurity.calculate(largeCount, largeSum, largeSumSquared) ~== impurity1 relTol 0.005)
+ }
+ }
+
}
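For intuition, not part of the diff: the scaling tests above should hold because Gini and entropy depend only on class proportions, so multiplying every weighted count by the same positive constant leaves the impurity unchanged up to floating-point error. A standalone sketch of that invariance for Gini:

  // Illustrative only (not the Spark Gini implementation).
  object GiniScalingSketch {
    def gini(weightedCounts: Seq[Double]): Double = {
      val total = weightedCounts.sum
      1.0 - weightedCounts.map { c => val p = c / total; p * p }.sum
    }

    def main(args: Array[String]): Unit = {
      val counts = Seq(2.0, 3.0, 5.0)
      val scaled = counts.map(_ * 10000)
      println(gini(counts))  // 0.62
      println(gini(scaled))  // 0.62 as well: proportions, and hence Gini, are unchanged
    }
  }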
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 7e6e143523387..c41b67b8f6bf0 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -36,6 +36,9 @@ object MimaExcludes {
// Exclude rules for 2.2.x
lazy val v22excludes = v21excludes ++ Seq(
+ // [SPARK-9478][ML] Add sample weights to decision trees
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.tree.configuration.Strategy.this"),
+
// [SPARK-18663][SQL] Simplify CountMinSketch aggregate implementation
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.CountMinSketch.toByteArray"),