srowen closed pull request #16722: [SPARK-19591][ML][MLlib] Add sample weights 
to decision trees
URL: https://github.com/apache/spark/pull/16722
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git 
a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala 
b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
index 2327917e2cad7..94158bf5d6e30 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
@@ -31,7 +31,7 @@ object TestingUtils {
    * Note that if x or y is extremely close to zero, i.e., smaller than 
Double.MinPositiveValue,
    * the relative tolerance is meaningless, so the exception will be raised to 
warn users.
    */
-  private def RelativeErrorComparison(x: Double, y: Double, eps: Double): 
Boolean = {
+  private[ml] def RelativeErrorComparison(x: Double, y: Double, eps: Double): 
Boolean = {
     val absX = math.abs(x)
     val absY = math.abs(y)
     val diff = math.abs(x - y)
@@ -48,7 +48,7 @@ object TestingUtils {
   /**
    * Private helper function for comparing two values using absolute tolerance.
    */
-  private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): 
Boolean = {
+  private[ml] def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): 
Boolean = {
     math.abs(x - y) < eps
   }
 
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 9f60f0896ec52..6a9b3564d63fc 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -22,18 +22,21 @@ import org.json4s.{DefaultFormats, JObject}
 import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.{DecisionTreeModel, Node, TreeClassifierParams}
 import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
 import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => 
OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => 
OldDecisionTreeModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Dataset
-
+import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.DoubleType
 
 /**
  * Decision tree learning algorithm 
(http://en.wikipedia.org/wiki/Decision_tree_learning)
@@ -65,6 +68,9 @@ class DecisionTreeClassifier @Since("1.4.0") (
   override def setMinInstancesPerNode(value: Int): this.type = 
set(minInstancesPerNode, value)
 
   /** @group setParam */
+  @Since("2.2.0")
+  def setMinWeightFractionPerNode(value: Double): this.type = 
set(minWeightFractionPerNode, value)
+
   @Since("1.4.0")
   override def setMinInfoGain(value: Double): this.type = set(minInfoGain, 
value)
 
@@ -96,6 +102,16 @@ class DecisionTreeClassifier @Since("1.4.0") (
   @Since("1.6.0")
   override def setSeed(value: Long): this.type = set(seed, value)
 
+  /**
+   * Sets the value of param [[weightCol]].
+   * If this is not set or empty, we treat all instance weights as 1.0.
+   * Default is not set, so all instances have weight one.
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   override protected def train(dataset: Dataset[_]): 
DecisionTreeClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
@@ -106,14 +122,23 @@ class DecisionTreeClassifier @Since("1.4.0") (
         ".train() called with non-matching numClasses and thresholds.length." +
         s" numClasses=$numClasses, but thresholds has length 
${$(thresholds).length}")
     }
-
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, 
numClasses)
+    require(numClasses > 0, s"DecisionTreeClassifier (in extractLabeledPoints) 
found numClasses =" +
+      s" $numClasses, but requires numClasses > 0.")
+    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else 
col($(weightCol))
+    val instances =
+      dataset.select(col($(labelCol)).cast(DoubleType), w, 
col($(featuresCol))).rdd.map {
+        case Row(label: Double, weight: Double, features: Vector) =>
+          require(label % 1 == 0 && label >= 0 && label < numClasses, 
s"Classifier was given" +
+            s" dataset with invalid label $label.  Labels must be integers in 
range" +
+            s" [0, $numClasses).")
+          Instance(label, weight, features)
+      }
     val strategy = getOldStrategy(categoricalFeatures, numClasses)
 
-    val instr = Instrumentation.create(this, oldDataset)
+    val instr = Instrumentation.create(this, instances)
     instr.logParams(params: _*)
 
-    val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, 
featureSubsetStrategy = "all",
+    val trees = RandomForest.run(instances, strategy, numTrees = 1, 
featureSubsetStrategy = "all",
       seed = $(seed), instr = Some(instr), parentUID = Some(uid))
 
     val m = trees.head.asInstanceOf[DecisionTreeClassificationModel]
@@ -124,11 +149,12 @@ class DecisionTreeClassifier @Since("1.4.0") (
   /** (private[ml]) Train a decision tree on an RDD */
   private[ml] def train(data: RDD[LabeledPoint],
       oldStrategy: OldStrategy): DecisionTreeClassificationModel = {
-    val instr = Instrumentation.create(this, data)
-    instr.logParams(params: _*)
 
-    val trees = RandomForest.run(data, oldStrategy, numTrees = 1, 
featureSubsetStrategy = "all",
-      seed = 0L, instr = Some(instr), parentUID = Some(uid))
+    val instances = data.map(_.toInstance(1.0))
+    val instr = Instrumentation.create(this, instances)
+    instr.logParams(params: _*)
+    val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
+      featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID 
= Some(uid))
 
     val m = trees.head.asInstanceOf[DecisionTreeClassificationModel]
     instr.logSuccess(m)
@@ -176,6 +202,7 @@ class DecisionTreeClassificationModel private[ml] (
 
   /**
    * Construct a decision tree classification model.
+   *
    * @param rootNode  Root node of tree, with other nodes attached.
    */
   private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index ce834f1d17e0d..5674e48134928 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -21,19 +21,20 @@ import org.json4s.{DefaultFormats, JObject}
 import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, 
TreeEnsembleModel}
 import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => 
OldRandomForestModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.functions._
-
+import org.apache.spark.sql.functions.{col, udf}
 
 /**
 * <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a> 
learning algorithm for
@@ -126,20 +127,20 @@ class RandomForestClassifier @Since("1.4.0") (
         s" numClasses=$numClasses, but thresholds has length 
${$(thresholds).length}")
     }
 
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, 
numClasses)
+    val instances: RDD[Instance] = extractLabeledPoints(dataset, 
numClasses).map(_.toInstance(1.0))
     val strategy =
       super.getOldStrategy(categoricalFeatures, numClasses, 
OldAlgo.Classification, getOldImpurity)
 
-    val instr = Instrumentation.create(this, oldDataset)
+    val instr = Instrumentation.create(this, instances)
     instr.logParams(labelCol, featuresCol, predictionCol, probabilityCol, 
rawPredictionCol,
       impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, 
maxMemoryInMB, minInfoGain,
       minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, 
checkpointInterval)
 
     val trees = RandomForest
-      .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, 
getSeed, Some(instr))
+      .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, 
getSeed, Some(instr))
       .map(_.asInstanceOf[DecisionTreeClassificationModel])
 
-    val numFeatures = oldDataset.first().features.size
+    val numFeatures = instances.first().features.size
     val m = new RandomForestClassificationModel(trees, numFeatures, numClasses)
     instr.logSuccess(m)
     m
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala
index cce3ca45ccd8f..7e6e4c5a26e4c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala
@@ -26,4 +26,4 @@ import org.apache.spark.ml.linalg.Vector
  * @param weight The weight of this instance.
  * @param features The vector of features for this data point.
  */
-private[ml] case class Instance(label: Double, weight: Double, features: 
Vector)
+private[spark] case class Instance(label: Double, weight: Double, features: 
Vector)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala
index c5d0ec1a8d350..a19f6a88968b6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala
@@ -35,4 +35,9 @@ case class LabeledPoint(@Since("2.0.0") label: Double, 
@Since("2.0.0") features:
   override def toString: String = {
     s"($label,$features)"
   }
+
+  private[spark] def toInstance(weight: Double): Instance = {
+    Instance(label, weight, features)
+  }
+
 }
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 01c5cc1c7efa9..8205714dff462 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -23,8 +23,9 @@ import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
@@ -33,8 +34,10 @@ import org.apache.spark.ml.util._
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => 
OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => 
OldDecisionTreeModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.DoubleType
 
 
 /**
@@ -64,6 +67,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") 
override val uid: S
   override def setMinInstancesPerNode(value: Int): this.type = 
set(minInstancesPerNode, value)
 
   /** @group setParam */
+  @Since("2.2.0")
+  def setMinWeightFractionPerNode(value: Double): this.type = 
set(minWeightFractionPerNode, value)
+
   @Since("1.4.0")
   override def setMinInfoGain(value: Double): this.type = set(minInfoGain, 
value)
 
@@ -99,16 +105,31 @@ class DecisionTreeRegressor @Since("1.4.0") 
(@Since("1.4.0") override val uid: S
   @Since("2.0.0")
   def setVarianceCol(value: String): this.type = set(varianceCol, value)
 
+  /**
+   * Sets the value of param [[weightCol]].
+   * If this is not set or empty, we treat all instance weights as 1.0.
+   * Default is not set, so all instances have weight one.
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   override protected def train(dataset: Dataset[_]): 
DecisionTreeRegressionModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else 
col($(weightCol))
+    val instances =
+      dataset.select(col($(labelCol)).cast(DoubleType), w, 
col($(featuresCol))).rdd.map {
+        case Row(label: Double, weight: Double, features: Vector) =>
+          Instance(label, weight, features)
+      }
     val strategy = getOldStrategy(categoricalFeatures)
 
-    val instr = Instrumentation.create(this, oldDataset)
+    val instr = Instrumentation.create(this, instances)
     instr.logParams(params: _*)
 
-    val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, 
featureSubsetStrategy = "all",
+    val trees = RandomForest.run(instances, strategy, numTrees = 1, 
featureSubsetStrategy = "all",
       seed = $(seed), instr = Some(instr), parentUID = Some(uid))
 
     val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
@@ -122,8 +143,9 @@ class DecisionTreeRegressor @Since("1.4.0") 
(@Since("1.4.0") override val uid: S
     val instr = Instrumentation.create(this, data)
     instr.logParams(params: _*)
 
-    val trees = RandomForest.run(data, oldStrategy, numTrees = 1, 
featureSubsetStrategy = "all",
-      seed = $(seed), instr = Some(instr), parentUID = Some(uid))
+    val instances = data.map(_.toInstance(1.0))
+    val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
+      featureSubsetStrategy = "all", seed = $(seed), instr = Some(instr), 
parentUID = Some(uid))
 
     val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
     instr.logSuccess(m)
@@ -153,6 +175,7 @@ object DecisionTreeRegressor extends 
DefaultParamsReadable[DecisionTreeRegressor
  * <a href="http://en.wikipedia.org/wiki/Decision_tree_learning">
  * Decision tree (Wikipedia)</a> model for regression.
  * It supports both continuous and categorical features.
+ *
  * @param rootNode  Root of the decision tree
  */
 @Since("1.4.0")
@@ -171,6 +194,7 @@ class DecisionTreeRegressionModel private[ml] (
 
   /**
    * Construct a decision tree regression model.
+   *
    * @param rootNode  Root node of tree, with other nodes attached.
    */
   private[ml] def this(rootNode: Node, numFeatures: Int) =
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 2f524a8c5784d..0b6bcb9c81235 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -22,7 +22,6 @@ import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
@@ -31,10 +30,8 @@ import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => 
OldRandomForestModel}
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.functions._
-
+import org.apache.spark.sql.functions.{col, udf}
 
 /**
  * <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a>
@@ -117,20 +114,20 @@ class RandomForestRegressor @Since("1.4.0") 
(@Since("1.4.0") override val uid: S
   override protected def train(dataset: Dataset[_]): 
RandomForestRegressionModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+
+    val instances = extractLabeledPoints(dataset).map(_.toInstance(1.0))
     val strategy =
       super.getOldStrategy(categoricalFeatures, numClasses = 0, 
OldAlgo.Regression, getOldImpurity)
 
-    val instr = Instrumentation.create(this, oldDataset)
+    val instr = Instrumentation.create(this, instances)
     instr.logParams(labelCol, featuresCol, predictionCol, impurity, numTrees,
       featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
       minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, 
checkpointInterval)
 
     val trees = RandomForest
-      .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, 
getSeed, Some(instr))
+      .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, 
getSeed, Some(instr))
       .map(_.asInstanceOf[DecisionTreeRegressionModel])
-
-    val numFeatures = oldDataset.first().features.size
+    val numFeatures = instances.first().features.size
     val m = new RandomForestRegressionModel(trees, numFeatures)
     instr.logSuccess(m)
     m
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
index 4e372702f0c65..2bb7020232f36 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
@@ -33,13 +33,20 @@ import org.apache.spark.util.random.XORShiftRandom
  * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, 
respectively.
  *
  * @param datum  Data instance
- * @param subsampleWeights  Weight of this instance in each subsampled dataset.
- *
- * TODO: This does not currently support (Double) weighted instances.  Once 
MLlib has weighted
- *       dataset support, update.  (We store subsampleWeights as Double for 
this future extension.)
+ * @param subsampleCounts  Number of samples of this instance in each 
subsampled dataset.
+ * @param sampleWeight The weight of this instance.
  */
-private[spark] class BaggedPoint[Datum](val datum: Datum, val 
subsampleWeights: Array[Double])
-  extends Serializable
+private[spark] class BaggedPoint[Datum](
+    val datum: Datum,
+    val subsampleCounts: Array[Int],
+    val sampleWeight: Double) extends Serializable {
+
+  /**
+   * Subsample counts weighted by the sample weight.
+   */
+  def weightedCounts: Array[Double] = subsampleCounts.map(_ * sampleWeight)
+
+}
 
 private[spark] object BaggedPoint {
 
@@ -52,6 +59,7 @@ private[spark] object BaggedPoint {
    * @param subsamplingRate Fraction of the training data used for learning 
decision tree.
    * @param numSubsamples Number of subsamples of this RDD to take.
    * @param withReplacement Sampling with/without replacement.
+   * @param extractSampleWeight A function to get the sample weight of each 
datum.
    * @param seed Random seed.
    * @return BaggedPoint dataset representation.
    */
@@ -60,12 +68,14 @@ private[spark] object BaggedPoint {
       subsamplingRate: Double,
       numSubsamples: Int,
       withReplacement: Boolean,
+      extractSampleWeight: (Datum => Double) = (_: Datum) => 1.0,
       seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
+    // TODO: implement weighted bootstrapping
     if (withReplacement) {
       convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, 
numSubsamples, seed)
     } else {
       if (numSubsamples == 1 && subsamplingRate == 1.0) {
-        convertToBaggedRDDWithoutSampling(input)
+        convertToBaggedRDDWithoutSampling(input, extractSampleWeight)
       } else {
         convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, 
numSubsamples, seed)
       }
@@ -82,16 +92,16 @@ private[spark] object BaggedPoint {
       val rng = new XORShiftRandom
       rng.setSeed(seed + partitionIndex + 1)
       instances.map { instance =>
-        val subsampleWeights = new Array[Double](numSubsamples)
+        val subsampleCounts = new Array[Int](numSubsamples)
         var subsampleIndex = 0
         while (subsampleIndex < numSubsamples) {
           val x = rng.nextDouble()
-          subsampleWeights(subsampleIndex) = {
-            if (x < subsamplingRate) 1.0 else 0.0
+          subsampleCounts(subsampleIndex) = {
+            if (x < subsamplingRate) 1 else 0
           }
           subsampleIndex += 1
         }
-        new BaggedPoint(instance, subsampleWeights)
+        new BaggedPoint(instance, subsampleCounts, 1.0)
       }
     }
   }
@@ -106,20 +116,20 @@ private[spark] object BaggedPoint {
       val poisson = new PoissonDistribution(subsample)
       poisson.reseedRandomGenerator(seed + partitionIndex + 1)
       instances.map { instance =>
-        val subsampleWeights = new Array[Double](numSubsamples)
+        val subsampleCounts = new Array[Int](numSubsamples)
         var subsampleIndex = 0
         while (subsampleIndex < numSubsamples) {
-          subsampleWeights(subsampleIndex) = poisson.sample()
+          subsampleCounts(subsampleIndex) = poisson.sample()
           subsampleIndex += 1
         }
-        new BaggedPoint(instance, subsampleWeights)
+        new BaggedPoint(instance, subsampleCounts, 1.0)
       }
     }
   }
 
   private def convertToBaggedRDDWithoutSampling[Datum] (
-      input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
-    input.map(datum => new BaggedPoint(datum, Array(1.0)))
+      input: RDD[Datum],
+      extractSampleWeight: (Datum => Double)): RDD[BaggedPoint[Datum]] = {
+    input.map(datum => new BaggedPoint(datum, Array(1), 
extractSampleWeight(datum)))
   }
-
 }
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
index 61091bb803e49..3124f4ee3c103 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
@@ -104,16 +104,21 @@ private[spark] class DTStatsAggregator(
   /**
    * Update the stats for a given (feature, bin) for ordered features, using 
the given label.
    */
-  def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: 
Double): Unit = {
+  def update(
+      featureIndex: Int,
+      binIndex: Int,
+      label: Double,
+      numSamples: Int,
+      sampleWeight: Double): Unit = {
     val i = featureOffsets(featureIndex) + binIndex * statsSize
-    impurityAggregator.update(allStats, i, label, instanceWeight)
+    impurityAggregator.update(allStats, i, label, numSamples, sampleWeight)
   }
 
   /**
    * Update the parent node stats using the given label.
    */
-  def updateParent(label: Double, instanceWeight: Double): Unit = {
-    impurityAggregator.update(parentStats, 0, label, instanceWeight)
+  def updateParent(label: Double, numSamples: Int, sampleWeight: Double): Unit 
= {
+    impurityAggregator.update(parentStats, 0, label, numSamples, sampleWeight)
   }
 
   /**
@@ -127,9 +132,10 @@ private[spark] class DTStatsAggregator(
       featureOffset: Int,
       binIndex: Int,
       label: Double,
-      instanceWeight: Double): Unit = {
+      numSamples: Int,
+      sampleWeight: Double): Unit = {
     impurityAggregator.update(allStats, featureOffset + binIndex * statsSize,
-      label, instanceWeight)
+      label, numSamples, sampleWeight)
   }
 
   /**
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index 8a9dcb486b7bf..b67cfff752935 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
 import scala.util.Try
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.tree.RandomForestParams
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
@@ -32,16 +32,20 @@ import org.apache.spark.rdd.RDD
 /**
  * Learning and dataset metadata for DecisionTree.
  *
+ * @param weightedNumExamples  Weighted count of samples in the tree.
  * @param numClasses    For classification: labels can take values {0, ..., 
numClasses - 1}.
  *                      For regression: fixed at 0 (no meaning).
  * @param maxBins  Maximum number of bins, for all features.
  * @param featureArity  Map: categorical feature index to arity.
  *                      I.e., the feature takes values in {0, ..., arity - 1}.
  * @param numBins  Number of bins for each feature.
+ * @param minWeightFractionPerNode  The minimum fraction of the total sample 
weight that must be
+ *                                  present in a leaf node in order to be 
considered a valid split.
  */
 private[spark] class DecisionTreeMetadata(
     val numFeatures: Int,
     val numExamples: Long,
+    val weightedNumExamples: Double,
     val numClasses: Int,
     val maxBins: Int,
     val featureArity: Map[Int, Int],
@@ -51,6 +55,7 @@ private[spark] class DecisionTreeMetadata(
     val quantileStrategy: QuantileStrategy,
     val maxDepth: Int,
     val minInstancesPerNode: Int,
+    val minWeightFractionPerNode: Double,
     val minInfoGain: Double,
     val numTrees: Int,
     val numFeaturesPerNode: Int) extends Serializable {
@@ -67,6 +72,8 @@ private[spark] class DecisionTreeMetadata(
 
   def isContinuous(featureIndex: Int): Boolean = 
!featureArity.contains(featureIndex)
 
+  def minWeightPerNode: Double = minWeightFractionPerNode * weightedNumExamples
+
   /**
    * Number of splits for the given feature.
    * For unordered features, there is 1 bin per split.
@@ -104,7 +111,7 @@ private[spark] object DecisionTreeMetadata extends Logging {
    * as well as the number of splits and bins for each feature.
    */
   def buildMetadata(
-      input: RDD[LabeledPoint],
+      input: RDD[Instance],
       strategy: Strategy,
       numTrees: Int,
       featureSubsetStrategy: String): DecisionTreeMetadata = {
@@ -115,7 +122,10 @@ private[spark] object DecisionTreeMetadata extends Logging 
{
     }
     require(numFeatures > 0, s"DecisionTree requires number of features > 0, " 
+
       s"but was given an empty features vector")
-    val numExamples = input.count()
+    val (numExamples, weightSum) = input.aggregate((0L, 0.0))(
+      (acc, x) => (acc._1 + 1L, acc._2 + x.weight),
+      (acc1, acc2) => (acc1._1 + acc2._1, acc1._2 + acc2._2))
+
     val numClasses = strategy.algo match {
       case Classification => strategy.numClasses
       case Regression => 0
@@ -206,17 +216,18 @@ private[spark] object DecisionTreeMetadata extends 
Logging {
         }
     }
 
-    new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
-      strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
+    new DecisionTreeMetadata(numFeatures, numExamples, weightSum, numClasses,
+      numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, 
numBins,
       strategy.impurity, strategy.quantileCalculationStrategy, 
strategy.maxDepth,
-      strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, 
numFeaturesPerNode)
+      strategy.minInstancesPerNode, strategy.minWeightFractionPerNode, 
strategy.minInfoGain,
+      numTrees, numFeaturesPerNode)
   }
 
   /**
    * Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree.
    */
   def buildMetadata(
-      input: RDD[LabeledPoint],
+      input: RDD[Instance],
       strategy: Strategy): DecisionTreeMetadata = {
     buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all")
   }
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 008dd19c2498d..80b4a97a82bd3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -24,7 +24,8 @@ import scala.util.Random
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.classification.DecisionTreeClassificationModel
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.impl.Utils
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.util.Instrumentation
@@ -82,11 +83,11 @@ private[spark] object RandomForest extends Logging {
   /**
    * Train a random forest.
    *
-   * @param input Training data: RDD of `LabeledPoint`
+   * @param input Training data: RDD of 
[[org.apache.spark.ml.feature.Instance]]
    * @return an unweighted set of trees
    */
   def run(
-      input: RDD[LabeledPoint],
+      input: RDD[Instance],
       strategy: OldStrategy,
       numTrees: Int,
       featureSubsetStrategy: String,
@@ -100,9 +101,10 @@ private[spark] object RandomForest extends Logging {
 
     timer.start("init")
 
-    val retaggedInput = input.retag(classOf[LabeledPoint])
+    val retaggedInput = input.retag(classOf[Instance])
     val metadata =
       DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, 
featureSubsetStrategy)
+
     instr match {
       case Some(instrumentation) =>
         instrumentation.logNumFeatures(metadata.numFeatures)
@@ -129,7 +131,8 @@ private[spark] object RandomForest extends Logging {
     val withReplacement = numTrees > 1
 
     val baggedInput = BaggedPoint
-      .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, 
withReplacement, seed)
+      .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, 
withReplacement,
+        (tp: TreePoint) => tp.weight, seed = seed)
       .persist(StorageLevel.MEMORY_AND_DISK)
 
     // depth of the decision tree
@@ -250,19 +253,21 @@ private[spark] object RandomForest extends Logging {
    * For unordered features, bins correspond to subsets of categories; either 
the left or right bin
    * for each subset is updated.
    *
-   * @param agg  Array storing aggregate calculation, with a set of sufficient 
statistics for
-   *             each (feature, bin).
-   * @param treePoint  Data point being aggregated.
-   * @param splits possible splits indexed (numFeatures)(numSplits)
-   * @param unorderedFeatures  Set of indices of unordered features.
-   * @param instanceWeight  Weight (importance) of instance in dataset.
+   * @param agg Array storing aggregate calculation, with a set of sufficient 
statistics for
+   *            each (feature, bin).
+   * @param treePoint Data point being aggregated.
+   * @param splits Possible splits indexed (numFeatures)(numSplits)
+   * @param unorderedFeatures Set of indices of unordered features.
+   * @param numSamples Number of times this instance occurs in the sample.
+   * @param sampleWeight Weight (importance) of instance in dataset.
    */
   private def mixedBinSeqOp(
       agg: DTStatsAggregator,
       treePoint: TreePoint,
       splits: Array[Array[Split]],
       unorderedFeatures: Set[Int],
-      instanceWeight: Double,
+      numSamples: Int,
+      sampleWeight: Double,
       featuresForNode: Option[Array[Int]]): Unit = {
     val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
       // Use subsampled features
@@ -289,14 +294,15 @@ private[spark] object RandomForest extends Logging {
         var splitIndex = 0
         while (splitIndex < numSplits) {
           if (featureSplits(splitIndex).shouldGoLeft(featureValue, 
featureSplits)) {
-            agg.featureUpdate(leftNodeFeatureOffset, splitIndex, 
treePoint.label, instanceWeight)
+            agg.featureUpdate(leftNodeFeatureOffset, splitIndex, 
treePoint.label, numSamples,
+              sampleWeight)
           }
           splitIndex += 1
         }
       } else {
         // Ordered feature
         val binIndex = treePoint.binnedFeatures(featureIndex)
-        agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
+        agg.update(featureIndexIdx, binIndex, treePoint.label, numSamples, 
sampleWeight)
       }
       featureIndexIdx += 1
     }
@@ -310,12 +316,14 @@ private[spark] object RandomForest extends Logging {
    * @param agg  Array storing aggregate calculation, with a set of sufficient 
statistics for
    *             each (feature, bin).
    * @param treePoint  Data point being aggregated.
-   * @param instanceWeight  Weight (importance) of instance in dataset.
+   * @param numSamples Number of times this instance occurs in the sample.
+   * @param sampleWeight  Weight (importance) of instance in dataset.
    */
   private def orderedBinSeqOp(
       agg: DTStatsAggregator,
       treePoint: TreePoint,
-      instanceWeight: Double,
+      numSamples: Int,
+      sampleWeight: Double,
       featuresForNode: Option[Array[Int]]): Unit = {
     val label = treePoint.label
 
@@ -325,7 +333,7 @@ private[spark] object RandomForest extends Logging {
       var featureIndexIdx = 0
       while (featureIndexIdx < featuresForNode.get.length) {
         val binIndex = 
treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
-        agg.update(featureIndexIdx, binIndex, label, instanceWeight)
+        agg.update(featureIndexIdx, binIndex, label, numSamples, sampleWeight)
         featureIndexIdx += 1
       }
     } else {
@@ -334,7 +342,7 @@ private[spark] object RandomForest extends Logging {
       var featureIndex = 0
       while (featureIndex < numFeatures) {
         val binIndex = treePoint.binnedFeatures(featureIndex)
-        agg.update(featureIndex, binIndex, label, instanceWeight)
+        agg.update(featureIndex, binIndex, label, numSamples, sampleWeight)
         featureIndex += 1
       }
     }
@@ -423,14 +431,16 @@ private[spark] object RandomForest extends Logging {
       if (nodeInfo != null) {
         val aggNodeIndex = nodeInfo.nodeIndexInGroup
         val featuresForNode = nodeInfo.featureSubset
-        val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
+        val numSamples = baggedPoint.subsampleCounts(treeIndex)
+        val sampleWeight = baggedPoint.sampleWeight
         if (metadata.unorderedFeatures.isEmpty) {
-          orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, 
instanceWeight, featuresForNode)
+          orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, numSamples, 
sampleWeight,
+            featuresForNode)
         } else {
           mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
-            metadata.unorderedFeatures, instanceWeight, featuresForNode)
+            metadata.unorderedFeatures, numSamples, sampleWeight, 
featuresForNode)
         }
-        agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
+        agg(aggNodeIndex).updateParent(baggedPoint.datum.label, numSamples, 
sampleWeight)
       }
     }
 
@@ -590,8 +600,8 @@ private[spark] object RandomForest extends Logging {
         if (!isLeaf) {
           node.split = Some(split)
           val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == 
metadata.maxDepth
-          val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
-          val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
+          val leftChildIsLeaf = childIsLeaf || (math.abs(stats.leftImpurity) < 
Utils.EPSILON)
+          val rightChildIsLeaf = childIsLeaf || (math.abs(stats.rightImpurity) 
< Utils.EPSILON)
           node.leftChild = 
Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
             leftChildIsLeaf, 
ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
           node.rightChild = 
Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
@@ -655,15 +665,20 @@ private[spark] object RandomForest extends Logging {
       stats.impurity
     }
 
+    val leftRawCount = leftImpurityCalculator.rawCount
+    val rightRawCount = rightImpurityCalculator.rawCount
     val leftCount = leftImpurityCalculator.count
     val rightCount = rightImpurityCalculator.count
 
     val totalCount = leftCount + rightCount
 
-    // If left child or right child doesn't satisfy minimum instances per node,
-    // then this split is invalid, return invalid information gain stats.
-    if ((leftCount < metadata.minInstancesPerNode) ||
-      (rightCount < metadata.minInstancesPerNode)) {
+    val violatesMinInstancesPerNode = (leftRawCount < 
metadata.minInstancesPerNode) ||
+      (rightRawCount < metadata.minInstancesPerNode)
+    val violatesMinWeightPerNode = (leftCount < metadata.minWeightPerNode) ||
+      (rightCount < metadata.minWeightPerNode)
+    // If left child or right child doesn't satisfy minimum weight per node or 
minimum
+    // instances per node, then this split is invalid, return invalid 
information gain stats.
+    if (violatesMinInstancesPerNode || violatesMinWeightPerNode) {
       return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
     }
 
@@ -730,7 +745,8 @@ private[spark] object RandomForest extends Logging {
           // Find best split.
           val (bestFeatureSplitIndex, bestFeatureGainStats) =
             Range(0, numSplits).map { case splitIdx =>
-              val leftChildStats = 
binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
+              val leftChildStats =
+                binAggregates.getImpurityCalculator(nodeFeatureOffset, 
splitIdx)
               val rightChildStats =
                 binAggregates.getImpurityCalculator(nodeFeatureOffset, 
numSplits)
               rightChildStats.subtract(leftChildStats)
@@ -872,14 +888,14 @@ private[spark] object RandomForest extends Logging {
    *       and for multiclass classification with a high-arity feature,
    *       there is one bin per category.
    *
-   * @param input Training data: RDD of [[LabeledPoint]]
+   * @param input Training data: RDD of [[Instance]]
    * @param metadata Learning and dataset metadata
    * @param seed random seed
    * @return Splits, an Array of [[Split]]
    *          of size (numFeatures, numSplits)
    */
   protected[tree] def findSplits(
-      input: RDD[LabeledPoint],
+      input: RDD[Instance],
       metadata: DecisionTreeMetadata,
       seed: Long): Array[Array[Split]] = {
 
@@ -900,14 +916,14 @@ private[spark] object RandomForest extends Logging {
       logDebug("fraction of data used for calculating quantiles = " + fraction)
       input.sample(withReplacement = false, fraction, new 
XORShiftRandom(seed).nextInt())
     } else {
-      input.sparkContext.emptyRDD[LabeledPoint]
+      input.sparkContext.emptyRDD[Instance]
     }
 
     findSplitsBySorting(sampledInput, metadata, continuousFeatures)
   }
 
   private def findSplitsBySorting(
-      input: RDD[LabeledPoint],
+      input: RDD[Instance],
       metadata: DecisionTreeMetadata,
       continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {
 
@@ -918,7 +934,7 @@ private[spark] object RandomForest extends Logging {
       val numPartitions = math.min(continuousFeatures.length, 
input.partitions.length)
 
       input
-        .flatMap(point => continuousFeatures.map(idx => (idx, 
point.features(idx))))
+        .flatMap(point => continuousFeatures.map(idx => (idx, (point.weight, 
point.features(idx)))))
         .groupByKey(numPartitions)
         .map { case (idx, samples) =>
           val thresholds = findSplitsForContinuousFeature(samples, metadata, 
idx)
@@ -982,7 +998,7 @@ private[spark] object RandomForest extends Logging {
    *       could be different from the specified `numSplits`.
    *       The `numSplits` attribute in the `DecisionTreeMetadata` class will 
be set accordingly.
    *
-   * @param featureSamples feature values of each sample
+   * @param featureSamples (sample weight, feature value) pairs, one per sample
    * @param metadata decision tree metadata
    *                 NOTE: `metadata.numbins` will be changed accordingly
    *                       if there are not enough splits to be found
@@ -990,7 +1006,7 @@ private[spark] object RandomForest extends Logging {
    * @return array of split thresholds
    */
   private[tree] def findSplitsForContinuousFeature(
-      featureSamples: Iterable[Double],
+      featureSamples: Iterable[(Double, Double)],
       metadata: DecisionTreeMetadata,
       featureIndex: Int): Array[Double] = {
     require(metadata.isContinuous(featureIndex),
@@ -1002,9 +1018,9 @@ private[spark] object RandomForest extends Logging {
       val numSplits = metadata.numSplits(featureIndex)
 
       // get count for each distinct value
-      val (valueCountMap, numSamples) = 
featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
-        case ((m, cnt), x) =>
-          (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
+      val (valueCountMap, numSamples) = 
featureSamples.foldLeft((Map.empty[Double, Double], 0.0)) {
+        case ((m, cnt), (w, x)) =>
+          (m + ((x, m.getOrElse(x, 0.0) + w)), cnt + w)
       }
       // sort distinct values
       val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
index a6ac64a0463cc..16b00299594a9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.ml.tree.impl
 
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.tree.{ContinuousSplit, Split}
 import org.apache.spark.rdd.RDD
 
@@ -36,10 +36,12 @@ import org.apache.spark.rdd.RDD
  * @param label  Label from LabeledPoint
  * @param binnedFeatures  Binned feature values.
  *                        Same length as LabeledPoint.features, but values are 
bin indices.
+ * @param weight Sample weight for this TreePoint.
  */
-private[spark] class TreePoint(val label: Double, val binnedFeatures: 
Array[Int])
-  extends Serializable {
-}
+private[spark] class TreePoint(
+    val label: Double,
+    val binnedFeatures: Array[Int],
+    val weight: Double) extends Serializable
 
 private[spark] object TreePoint {
 
@@ -52,7 +54,7 @@ private[spark] object TreePoint {
    * @return  TreePoint dataset representation
    */
   def convertToTreeRDD(
-      input: RDD[LabeledPoint],
+      input: RDD[Instance],
       splits: Array[Array[Split]],
       metadata: DecisionTreeMetadata): RDD[TreePoint] = {
     // Construct arrays for featureArity for efficiency in the inner loop.
@@ -82,18 +84,18 @@ private[spark] object TreePoint {
    *                      for categorical features.
    */
   private def labeledPointToTreePoint(
-      labeledPoint: LabeledPoint,
+      instance: Instance,
       thresholds: Array[Array[Double]],
       featureArity: Array[Int]): TreePoint = {
-    val numFeatures = labeledPoint.features.size
+    val numFeatures = instance.features.size
     val arr = new Array[Int](numFeatures)
     var featureIndex = 0
     while (featureIndex < numFeatures) {
       arr(featureIndex) =
-        findBin(featureIndex, labeledPoint, featureArity(featureIndex), 
thresholds(featureIndex))
+        findBin(featureIndex, instance, featureArity(featureIndex), 
thresholds(featureIndex))
       featureIndex += 1
     }
-    new TreePoint(labeledPoint.label, arr)
+    new TreePoint(instance.label, arr, instance.weight)
   }
 
   /**
@@ -106,10 +108,10 @@ private[spark] object TreePoint {
    */
   private def findBin(
       featureIndex: Int,
-      labeledPoint: LabeledPoint,
+      instance: Instance,
       featureArity: Int,
       thresholds: Array[Double]): Int = {
-    val featureValue = labeledPoint.features(featureIndex)
+    val featureValue = instance.features(featureIndex)
 
     if (featureArity == 0) {
       val idx = java.util.Arrays.binarySearch(thresholds, featureValue)
@@ -125,7 +127,7 @@ private[spark] object TreePoint {
           s"DecisionTree given invalid data:" +
             s" Feature $featureIndex is categorical with values in 
{0,...,${featureArity - 1}," +
             s" but a data point gives it value $featureValue.\n" +
-            "  Bad data point: " + labeledPoint.toString)
+            "  Bad data point: " + instance.toString)
       }
       featureValue.toInt
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 0d6e9034e5ce4..820c42fa43b88 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -219,7 +219,7 @@ private[ml] object TreeEnsembleModel {
         importances.changeValue(feature, scaledGain, _ + scaledGain)
         computeFeatureImportance(n.leftChild, importances)
         computeFeatureImportance(n.rightChild, importances)
-      case n: LeafNode =>
+      case _: LeafNode =>
       // do nothing
     }
   }
@@ -282,6 +282,7 @@ private[ml] object DecisionTreeModelReadWrite {
    *
    * @param id  Index used for tree reconstruction.  Indices follow a 
pre-order traversal.
    * @param impurityStats  Stats array.  Impurity type is stored in metadata.
+   * @param rawCount  The unweighted number of samples falling in this node.
    * @param gain  Gain, or arbitrary value if leaf node.
    * @param leftChild  Left child index, or arbitrary value if leaf node.
    * @param rightChild  Right child index, or arbitrary value if leaf node.
@@ -292,6 +293,7 @@ private[ml] object DecisionTreeModelReadWrite {
     prediction: Double,
     impurity: Double,
     impurityStats: Array[Double],
+    rawCount: Long,
     gain: Double,
     leftChild: Int,
     rightChild: Int,
@@ -311,11 +313,12 @@ private[ml] object DecisionTreeModelReadWrite {
         val (leftNodeData, leftIdx) = build(n.leftChild, id + 1)
         val (rightNodeData, rightIdx) = build(n.rightChild, leftIdx + 1)
         val thisNodeData = NodeData(id, n.prediction, n.impurity, 
n.impurityStats.stats,
-          n.gain, leftNodeData.head.id, rightNodeData.head.id, 
SplitData(n.split))
+          n.impurityStats.rawCount, n.gain, leftNodeData.head.id, 
rightNodeData.head.id,
+          SplitData(n.split))
         (thisNodeData +: (leftNodeData ++ rightNodeData), rightIdx)
       case _: LeafNode =>
         (Seq(NodeData(id, node.prediction, node.impurity, 
node.impurityStats.stats,
-          -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))),
+          node.impurityStats.rawCount, -1.0, -1, -1, SplitData(-1, 
Array.empty[Double], -1))),
           id)
     }
   }
@@ -360,7 +363,8 @@ private[ml] object DecisionTreeModelReadWrite {
     // traversal, this guarantees that child nodes will be built before parent 
nodes.
     val finalNodes = new Array[Node](nodes.length)
     nodes.reverseIterator.foreach { case n: NodeData =>
-      val impurityStats = ImpurityCalculator.getCalculator(impurityType, 
n.impurityStats)
+      val impurityStats =
+        ImpurityCalculator.getCalculator(impurityType, n.impurityStats, 
n.rawCount)
       val node = if (n.leftChild != -1) {
         val leftChild = finalNodes(n.leftChild)
         val rightChild = finalNodes(n.rightChild)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 5eb707dfe7bc3..ff5955d409804 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{DataType, DoubleType, 
StructType}
  * Note: Marked as private and DeveloperApi since this may be made public in 
the future.
  */
 private[ml] trait DecisionTreeParams extends PredictorParams
-  with HasCheckpointInterval with HasSeed {
+  with HasCheckpointInterval with HasSeed with HasWeightCol {
 
   /**
    * Maximum depth of the tree (>= 0).
@@ -71,6 +71,21 @@ private[ml] trait DecisionTreeParams extends PredictorParams
     " child to have fewer than minInstancesPerNode, the split will be 
discarded as invalid." +
     " Should be >= 1.", ParamValidators.gtEq(1))
 
+  /**
+   * Minimum fraction of the weighted sample count that each child must have 
after split.
+   * If a split causes the fraction of the total weight in the left or right 
child to be less than
+   * minWeightFractionPerNode, the split will be discarded as invalid.
+   * Should be in the interval [0.0, 0.5).
+   * (default = 0.0)
+   * @group param
+   */
+  final val minWeightFractionPerNode: DoubleParam = new DoubleParam(this,
+    "minWeightFractionPerNode", "Minimum fraction of the weighted sample count 
that each child " +
+    "must have after split. If a split causes the fraction of the total weight 
in the left or " +
+    "right child to be less than minWeightFractionPerNode, the split will be 
discarded as " +
+    "invalid. Should be in interval [0.0, 0.5)",
+    ParamValidators.inRange(0.0, 0.5, lowerInclusive = true, upperInclusive = 
false))
+
   /**
    * Minimum information gain for a split to be considered at a tree node.
    * Should be >= 0.0.
@@ -104,8 +119,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams
     " algorithm will cache node IDs for each instance. Caching can speed up 
training of deeper" +
     " trees.")
 
-  setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, 
minInfoGain -> 0.0,
-    maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
+  setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1,
+    minWeightFractionPerNode -> 0.0, minInfoGain -> 0.0, maxMemoryInMB -> 256,
+    cacheNodeIds -> false, checkpointInterval -> 10)
 
   /**
    * @deprecated This method is deprecated and will be removed in 2.2.0.
@@ -137,6 +153,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams
   /** @group getParam */
   final def getMinInstancesPerNode: Int = $(minInstancesPerNode)
 
+  /** @group getParam */
+  final def getMinWeightFractionPerNode: Double = $(minWeightFractionPerNode)
+
   /**
    * @deprecated This method is deprecated and will be removed in 2.2.0.
    * @group setParam
@@ -196,6 +215,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams
     strategy.maxMemoryInMB = getMaxMemoryInMB
     strategy.minInfoGain = getMinInfoGain
     strategy.minInstancesPerNode = getMinInstancesPerNode
+    strategy.minWeightFractionPerNode = getMinWeightFractionPerNode
     strategy.useNodeIdCache = getCacheNodeIds
     strategy.numClasses = numClasses
     strategy.categoricalFeaturesInfo = categoricalFeatures
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index d1331a57de27b..9e1401588142e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -23,6 +23,7 @@ import scala.util.Try
 import org.apache.spark.annotation.Since
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
+import org.apache.spark.ml.feature.{Instance => NewInstance}
 import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, 
RandomForestParams => NewRFParams}
 import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest}
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -91,8 +92,11 @@ private class RandomForest (
    * @return RandomForestModel that can be used for prediction.
    */
   def run(input: RDD[LabeledPoint]): RandomForestModel = {
-    val trees: Array[NewDTModel] = NewRandomForest.run(input.map(_.asML), 
strategy, numTrees,
-      featureSubsetStrategy, seed.toLong, None)
+    val instances = input.map { case LabeledPoint(label, features) =>
+      NewInstance(label, 1.0, features.asML)
+    }
+    val trees: Array[NewDTModel] =
+      NewRandomForest.run(instances, strategy, numTrees, 
featureSubsetStrategy, seed.toLong, None)
     new RandomForestModel(strategy.algo, trees.map(_.toOld))
   }
 
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 58e8f5be7b9f0..5806741b413be 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -80,7 +80,8 @@ class Strategy @Since("1.3.0") (
     @Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256,
     @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,
     @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
-    @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends 
Serializable {
+    @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10,
+    private[spark] var minWeightFractionPerNode: Double = 0.0) extends 
Serializable {
 
   /**
    */
@@ -108,7 +109,8 @@ class Strategy @Since("1.3.0") (
       maxBins: Int,
       categoricalFeaturesInfo: java.util.Map[java.lang.Integer, 
java.lang.Integer]) {
     this(algo, impurity, maxDepth, numClasses, maxBins, Sort,
-      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, 
Int]].asScala.toMap)
+      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, 
Int]].asScala.toMap,
+      minWeightFractionPerNode = 0.0)
   }
 
   /**
@@ -171,8 +173,9 @@ class Strategy @Since("1.3.0") (
   @Since("1.2.0")
   def copy: Strategy = {
     new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
-      quantileCalculationStrategy, categoricalFeaturesInfo, 
minInstancesPerNode, minInfoGain,
-      maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval)
+      quantileCalculationStrategy, categoricalFeaturesInfo, 
minInstancesPerNode,
+      minInfoGain, maxMemoryInMB, subsamplingRate, useNodeIdCache,
+      checkpointInterval, minWeightFractionPerNode)
   }
 }
 
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index d4448da9eef51..f01a98e74886b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -83,23 +83,29 @@ object Entropy extends Impurity {
  * @param numClasses  Number of classes for label.
  */
 private[spark] class EntropyAggregator(numClasses: Int)
-  extends ImpurityAggregator(numClasses) with Serializable {
+  extends ImpurityAggregator(numClasses + 1) with Serializable {
 
   /**
    * Update stats for one (node, feature, bin) with the given label.
    * @param allStats  Flat stats array, with stats for this (node, feature, 
bin) contiguous.
    * @param offset    Start index of stats for this (node, feature, bin).
    */
-  def update(allStats: Array[Double], offset: Int, label: Double, 
instanceWeight: Double): Unit = {
-    if (label >= statsSize) {
+  def update(
+      allStats: Array[Double],
+      offset: Int,
+      label: Double,
+      numSamples: Int,
+      sampleWeight: Double): Unit = {
+    if (label >= numClasses) {
       throw new IllegalArgumentException(s"EntropyAggregator given label 
$label" +
-        s" but requires label < numClasses (= $statsSize).")
+        s" but requires label < numClasses (= ${numClasses}).")
     }
     if (label < 0) {
       throw new IllegalArgumentException(s"EntropyAggregator given label 
$label" +
         s"but requires label is non-negative.")
     }
-    allStats(offset + label.toInt) += instanceWeight
+    allStats(offset + label.toInt) += numSamples * sampleWeight
+    allStats(offset + statsSize - 1) += numSamples
   }
 
   /**
@@ -108,7 +114,8 @@ private[spark] class EntropyAggregator(numClasses: Int)
    * @param offset    Start index of stats for this (node, feature, bin).
    */
   def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = 
{
-    new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
+    new EntropyCalculator(allStats.view(offset, offset + statsSize - 
1).toArray,
+      allStats(offset + statsSize - 1).toLong)
   }
 }
 
@@ -118,12 +125,13 @@ private[spark] class EntropyAggregator(numClasses: Int)
  * (node, feature, bin).
  * @param stats  Array of sufficient statistics for a (node, feature, bin).
  */
-private[spark] class EntropyCalculator(stats: Array[Double]) extends 
ImpurityCalculator(stats) {
+private[spark] class EntropyCalculator(stats: Array[Double], var rawCount: 
Long)
+  extends ImpurityCalculator(stats) {
 
   /**
    * Make a deep copy of this [[ImpurityCalculator]].
    */
-  def copy: EntropyCalculator = new EntropyCalculator(stats.clone())
+  def copy: EntropyCalculator = new EntropyCalculator(stats.clone(), rawCount)
 
   /**
    * Calculate the impurity from the stored sufficient statistics.
@@ -131,9 +139,9 @@ private[spark] class EntropyCalculator(stats: 
Array[Double]) extends ImpurityCal
   def calculate(): Double = Entropy.calculate(stats, stats.sum)
 
   /**
-   * Number of data points accounted for in the sufficient statistics.
+   * Weighted number of data points accounted for in the sufficient statistics.
    */
-  def count: Long = stats.sum.toLong
+  def count: Double = stats.sum
 
   /**
    * Prediction which should be made based on the sufficient statistics.
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index c5e34ffa4f2e5..913ffbbb2457a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -80,23 +80,29 @@ object Gini extends Impurity {
  * @param numClasses  Number of classes for label.
  */
 private[spark] class GiniAggregator(numClasses: Int)
-  extends ImpurityAggregator(numClasses) with Serializable {
+  extends ImpurityAggregator(numClasses + 1) with Serializable {
 
   /**
    * Update stats for one (node, feature, bin) with the given label.
    * @param allStats  Flat stats array, with stats for this (node, feature, 
bin) contiguous.
    * @param offset    Start index of stats for this (node, feature, bin).
    */
-  def update(allStats: Array[Double], offset: Int, label: Double, 
instanceWeight: Double): Unit = {
-    if (label >= statsSize) {
+  def update(
+      allStats: Array[Double],
+      offset: Int,
+      label: Double,
+      numSamples: Int,
+      sampleWeight: Double): Unit = {
+    if (label >= numClasses) {
       throw new IllegalArgumentException(s"GiniAggregator given label $label" +
-        s" but requires label < numClasses (= $statsSize).")
+        s" but requires label < numClasses (= ${numClasses}).")
     }
     if (label < 0) {
       throw new IllegalArgumentException(s"GiniAggregator given label $label" +
-        s"but requires label is non-negative.")
+        s"but requires label to be non-negative.")
     }
-    allStats(offset + label.toInt) += instanceWeight
+    allStats(offset + label.toInt) += numSamples * sampleWeight
+    allStats(offset + statsSize - 1) += numSamples
   }
 
   /**
@@ -105,7 +111,8 @@ private[spark] class GiniAggregator(numClasses: Int)
    * @param offset    Start index of stats for this (node, feature, bin).
    */
   def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = {
-    new GiniCalculator(allStats.view(offset, offset + statsSize).toArray)
+    new GiniCalculator(allStats.view(offset, offset + statsSize - 1).toArray,
+      allStats(offset + statsSize - 1).toLong)
   }
 }
 
@@ -115,12 +122,13 @@ private[spark] class GiniAggregator(numClasses: Int)
  * (node, feature, bin).
  * @param stats  Array of sufficient statistics for a (node, feature, bin).
  */
-private[spark] class GiniCalculator(stats: Array[Double]) extends 
ImpurityCalculator(stats) {
+private[spark] class GiniCalculator(stats: Array[Double], var rawCount: Long)
+  extends ImpurityCalculator(stats) {
 
   /**
    * Make a deep copy of this [[ImpurityCalculator]].
    */
-  def copy: GiniCalculator = new GiniCalculator(stats.clone())
+  def copy: GiniCalculator = new GiniCalculator(stats.clone(), rawCount)
 
   /**
    * Calculate the impurity from the stored sufficient statistics.
@@ -128,9 +136,9 @@ private[spark] class GiniCalculator(stats: Array[Double]) 
extends ImpurityCalcul
   def calculate(): Double = Gini.calculate(stats, stats.sum)
 
   /**
-   * Number of data points accounted for in the sufficient statistics.
+   * Weighted number of data points accounted for in the sufficient statistics.
    */
-  def count: Long = stats.sum.toLong
+  def count: Double = stats.sum
 
   /**
    * Prediction which should be made based on the sufficient statistics.
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index a5bdc2c6d2c94..6a814e658caa5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -79,7 +79,12 @@ private[spark] abstract class ImpurityAggregator(val 
statsSize: Int) extends Ser
    * @param allStats  Flat stats array, with stats for this (node, feature, 
bin) contiguous.
    * @param offset    Start index of stats for this (node, feature, bin).
    */
-  def update(allStats: Array[Double], offset: Int, label: Double, 
instanceWeight: Double): Unit
+  def update(
+      allStats: Array[Double],
+      offset: Int,
+      label: Double,
+      numSamples: Int,
+      sampleWeight: Double): Unit
 
   /**
    * Get an [[ImpurityCalculator]] for a (node, feature, bin).
@@ -120,6 +125,7 @@ private[spark] abstract class ImpurityCalculator(val stats: 
Array[Double]) exten
       stats(i) += other.stats(i)
       i += 1
     }
+    rawCount += other.rawCount
     this
   }
 
@@ -137,13 +143,19 @@ private[spark] abstract class ImpurityCalculator(val 
stats: Array[Double]) exten
       stats(i) -= other.stats(i)
       i += 1
     }
+    rawCount -= other.rawCount
     this
   }
 
   /**
-   * Number of data points accounted for in the sufficient statistics.
+   * Weighted number of data points accounted for in the sufficient statistics.
    */
-  def count: Long
+  def count: Double
+
+  /**
+   * Raw number of data points accounted for in the sufficient statistics.
+   */
+  var rawCount: Long
 
   /**
    * Prediction which should be made based on the sufficient statistics.
@@ -183,11 +195,14 @@ private[spark] object ImpurityCalculator {
    * Create an [[ImpurityCalculator]] instance of the given impurity type and 
with
    * the given stats.
    */
-  def getCalculator(impurity: String, stats: Array[Double]): 
ImpurityCalculator = {
+  def getCalculator(
+      impurity: String,
+      stats: Array[Double],
+      rawCount: Long): ImpurityCalculator = {
     impurity match {
-      case "gini" => new GiniCalculator(stats)
-      case "entropy" => new EntropyCalculator(stats)
-      case "variance" => new VarianceCalculator(stats)
+      case "gini" => new GiniCalculator(stats, rawCount)
+      case "entropy" => new EntropyCalculator(stats, rawCount)
+      case "variance" => new VarianceCalculator(stats, rawCount)
       case _ =>
         throw new IllegalArgumentException(
           s"ImpurityCalculator builder did not recognize impurity type: 
$impurity")
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index c9bf0db4de3c2..a07b919271f71 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -66,21 +66,32 @@ object Variance extends Impurity {
 
 /**
  * Class for updating views of a vector of sufficient statistics,
- * in order to compute impurity from a sample.
+ * in order to compute impurity from a sample. For variance, we track:
+ *   - sum(w_i)
+ *   - sum(w_i * y_i)
+ *   - sum(w_i * y_i * y_i)
+ *   - count(y_i), the raw (unweighted) number of data points
  * Note: Instances of this class do not hold the data; they operate on views 
of the data.
  */
 private[spark] class VarianceAggregator()
-  extends ImpurityAggregator(statsSize = 3) with Serializable {
+  extends ImpurityAggregator(statsSize = 4) with Serializable {
 
   /**
    * Update stats for one (node, feature, bin) with the given label.
    * @param allStats  Flat stats array, with stats for this (node, feature, 
bin) contiguous.
    * @param offset    Start index of stats for this (node, feature, bin).
    */
-  def update(allStats: Array[Double], offset: Int, label: Double, 
instanceWeight: Double): Unit = {
+  def update(
+      allStats: Array[Double],
+      offset: Int,
+      label: Double,
+      numSamples: Int,
+      sampleWeight: Double): Unit = {
+    val instanceWeight = numSamples * sampleWeight
     allStats(offset) += instanceWeight
     allStats(offset + 1) += instanceWeight * label
     allStats(offset + 2) += instanceWeight * label * label
+    allStats(offset + 3) += numSamples
   }
 
   /**
@@ -89,7 +100,8 @@ private[spark] class VarianceAggregator()
    * @param offset    Start index of stats for this (node, feature, bin).
    */
   def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator 
= {
-    new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray)
+    new VarianceCalculator(allStats.view(offset, offset + statsSize - 
1).toArray,
+      allStats(offset + statsSize - 1).toLong)
   }
 }
 
@@ -99,7 +111,8 @@ private[spark] class VarianceAggregator()
  * (node, feature, bin).
  * @param stats  Array of sufficient statistics for a (node, feature, bin).
  */
-private[spark] class VarianceCalculator(stats: Array[Double]) extends 
ImpurityCalculator(stats) {
+private[spark] class VarianceCalculator(stats: Array[Double], var rawCount: 
Long)
+  extends ImpurityCalculator(stats) {
 
   require(stats.length == 3,
     s"VarianceCalculator requires sufficient statistics array stats to be of 
length 3," +
@@ -108,7 +121,7 @@ private[spark] class VarianceCalculator(stats: 
Array[Double]) extends ImpurityCa
   /**
    * Make a deep copy of this [[ImpurityCalculator]].
    */
-  def copy: VarianceCalculator = new VarianceCalculator(stats.clone())
+  def copy: VarianceCalculator = new VarianceCalculator(stats.clone(), 
rawCount)
 
   /**
    * Calculate the impurity from the stored sufficient statistics.
@@ -116,9 +129,9 @@ private[spark] class VarianceCalculator(stats: 
Array[Double]) extends ImpurityCa
   def calculate(): Double = Variance.calculate(stats(0), stats(1), stats(2))
 
   /**
-   * Number of data points accounted for in the sufficient statistics.
+   * Weighted number of data points accounted for in the sufficient statistics.
    */
-  def count: Long = stats(0).toLong
+  def count: Double = stats(0)
 
   /**
    * Prediction which should be made based on the sufficient statistics.
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index c711e7fa9dc67..a48050c4a25ff 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -21,8 +21,9 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
+import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode}
 import org.apache.spark.ml.tree.impl.TreeTests
+import org.apache.spark.ml.tree.LeafNode
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, 
DecisionTreeSuite => OldDecisionTreeSuite}
@@ -42,6 +43,9 @@ class DecisionTreeClassifierSuite
   private var categoricalDataPointsForMulticlassRDD: RDD[LabeledPoint] = _
   private var continuousDataPointsForMulticlassRDD: RDD[LabeledPoint] = _
   private var categoricalDataPointsForMulticlassForOrderedFeaturesRDD: 
RDD[LabeledPoint] = _
+  private var linearMulticlassDataset: DataFrame = _
+
+  private val seed = 42
 
   override def beforeAll() {
     super.beforeAll()
@@ -58,6 +62,20 @@ class DecisionTreeClassifierSuite
     categoricalDataPointsForMulticlassForOrderedFeaturesRDD = sc.parallelize(
       
OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures())
       .map(_.asML)
+    linearMulticlassDataset = {
+      val nPoints = 100
+      val coefficients = Array(
+        -0.57997, 0.912083, -0.371077,
+        -0.16624, -0.84355, -0.048509)
+
+      val xMean = Array(5.843, 3.057)
+      val xVariance = Array(0.6856, 0.1899)
+
+      val testData = LogisticRegressionSuite.generateMultinomialLogisticInput(
+        coefficients, xMean, xVariance, addIntercept = true, nPoints, seed)
+
+      sc.parallelize(testData, 4).toDF()
+    }
   }
 
   test("params") {
@@ -246,7 +264,8 @@ class DecisionTreeClassifierSuite
     val categoricalFeatures = Map(0 -> 3)
     val numClasses = 3
 
-    val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, 
numClasses)
+    val newData: DataFrame =
+      TreeTests.setMetadata(rdd.map(_.toInstance(1.0)), categoricalFeatures, 
numClasses)
     val newTree = dt.fit(newData)
 
     // copied model must have the same parent.
@@ -273,7 +292,7 @@ class DecisionTreeClassifierSuite
       LabeledPoint(1, Vectors.dense(0, 3, 9)),
       LabeledPoint(0, Vectors.dense(0, 2, 6))
     ))
-    val df = TreeTests.setMetadata(data, Map(0 -> 1), 2)
+    val df = TreeTests.setMetadata(data.map(_.toInstance(1.0)), Map(0 -> 1), 2)
     val dt = new DecisionTreeClassifier().setMaxDepth(3)
     dt.fit(df)
   }
@@ -295,7 +314,7 @@ class DecisionTreeClassifierSuite
       LabeledPoint(0.0, Vectors.dense(2.0)),
       LabeledPoint(1.0, Vectors.dense(2.0)))
     val data = sc.parallelize(arr)
-    val df = TreeTests.setMetadata(data, Map(0 -> 3), 2)
+    val df = TreeTests.setMetadata(data.map(_.toInstance(1.0)), Map(0 -> 3), 2)
 
     // Must set maxBins s.t. the feature will be treated as an ordered 
categorical feature.
     val dt = new DecisionTreeClassifier()
@@ -326,7 +345,7 @@ class DecisionTreeClassifierSuite
     val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
     val numFeatures = data.first().features.size
     val categoricalFeatures = (0 to numFeatures).map(i => (i, 2)).toMap
-    val df = TreeTests.setMetadata(data, categoricalFeatures, 2)
+    val df = TreeTests.setMetadata(data.map(_.toInstance(1.0)), 
categoricalFeatures, 2)
 
     val model = dt.fit(df)
 
@@ -351,6 +370,36 @@ class DecisionTreeClassifierSuite
     dt.fit(df)
   }
 
+  test("training with sample weights") {
+    // Runs the three shared MLTestingUtils sample-weight checks per setting.
+    val numClasses = 3
+    val dataset = linearMulticlassDataset
+    val labelsMatch = (a: Double, b: Double) => a == b
+    // Each setting is an (impurity, maxDepth) pair.
+    val settings = Seq(("gini", 10), ("entropy", 10), ("gini", 5))
+    settings.foreach { case (impurity, maxDepth) =>
+      val classifier = new DecisionTreeClassifier()
+        .setMaxDepth(maxDepth)
+        .setSeed(seed)
+        .setMinWeightFractionPerNode(0.049)
+        .setImpurity(impurity)
+
+      MLTestingUtils.testArbitrarilyScaledWeights[
+        DecisionTreeClassificationModel, DecisionTreeClassifier](
+        dataset.as[LabeledPoint], classifier,
+        MLTestingUtils.modelPredictionEquals(dataset, labelsMatch, 0.9))
+      MLTestingUtils.testOutliersWithSmallWeights[
+        DecisionTreeClassificationModel, DecisionTreeClassifier](
+        dataset.as[LabeledPoint], classifier, numClasses,
+        MLTestingUtils.modelPredictionEquals(dataset, labelsMatch, 0.8),
+        outlierRatio = 2)
+      MLTestingUtils.testOversamplingVsWeighting[
+        DecisionTreeClassificationModel, DecisionTreeClassifier](
+        dataset.as[LabeledPoint], classifier,
+        MLTestingUtils.modelPredictionEquals(dataset, labelsMatch, 1.0), seed)
+    }
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
@@ -371,12 +420,12 @@ class DecisionTreeClassifierSuite
 
     // Categorical splits with tree depth 2
     val categoricalData: DataFrame =
-      TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
+      TreeTests.setMetadata(rdd.map(_.toInstance(1.0)), Map(0 -> 2, 1 -> 3), 
numClasses = 2)
     testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, 
checkModelData)
 
     // Continuous splits with tree depth 2
     val continuousData: DataFrame =
-      TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+      TreeTests.setMetadata(rdd.map(_.toInstance(1.0)), Map.empty[Int, Int], 
numClasses = 2)
     testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, 
checkModelData)
 
     // Continuous splits with tree depth 0
@@ -399,7 +448,8 @@ private[ml] object DecisionTreeClassifierSuite extends 
SparkFunSuite {
     val numFeatures = data.first().features.size
     val oldStrategy = dt.getOldStrategy(categoricalFeatures, numClasses)
     val oldTree = OldDecisionTree.train(data.map(OldLabeledPoint.fromML), 
oldStrategy)
-    val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 
numClasses)
+    val newData: DataFrame =
+      TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures, 
numClasses)
     val newTree = dt.fit(newData)
     // Use parent from newTree since this is not checked anyways.
     val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 0598943c3d4be..b6dbcbff3eca6 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -251,7 +251,8 @@ class GBTClassifierSuite extends SparkFunSuite with 
MLlibTestSparkContext
     sc.setCheckpointDir(path)
 
     val categoricalFeatures = Map.empty[Int, Int]
-    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 
numClasses = 2)
+    val df: DataFrame =
+      TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures, 
numClasses = 2)
     val gbt = new GBTClassifier()
       .setMaxDepth(2)
       .setLossType("logistic")
@@ -346,7 +347,8 @@ class GBTClassifierSuite extends SparkFunSuite with 
MLlibTestSparkContext
     // In this data, feature 1 is very important.
     val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
     val categoricalFeatures = Map.empty[Int, Int]
-    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 
numClasses)
+    val df: DataFrame =
+      TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures, 
numClasses)
 
     val importances = gbt.fit(df).featureImportances
     val mostImportantFeature = importances.argmax
@@ -373,7 +375,7 @@ class GBTClassifierSuite extends SparkFunSuite with 
MLlibTestSparkContext
     val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> 
"logistic")
 
     val continuousData: DataFrame =
-      TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+      TreeTests.setMetadata(rdd.map(_.toInstance(1.0)), Map.empty[Int, Int], 
numClasses = 2)
     testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, 
checkModelData)
   }
 }
@@ -394,7 +396,8 @@ private object GBTClassifierSuite extends SparkFunSuite {
       gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
     val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt)
     val oldModel = oldGBT.run(data.map(OldLabeledPoint.fromML))
-    val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 
numClasses = 2)
+    val newData: DataFrame =
+      TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures, 
numClasses = 2)
     val newModel = gbt.fit(newData)
     // Use parent from newTree since this is not checked anyways.
     val oldModelAsNew = GBTClassificationModel.fromOld(
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
index ee2aefee7a6db..47c393bcdf7bc 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -134,7 +134,7 @@ class LinearSVCSuite extends SparkFunSuite with 
MLlibTestSparkContext with Defau
     MLTestingUtils.testArbitrarilyScaledWeights[LinearSVCModel, LinearSVC](
       dataset.as[LabeledPoint], estimator, modelEquals)
     MLTestingUtils.testOutliersWithSmallWeights[LinearSVCModel, LinearSVC](
-      dataset.as[LabeledPoint], estimator, 2, modelEquals)
+      dataset.as[LabeledPoint], estimator, 2, modelEquals, outlierRatio = 3)
     MLTestingUtils.testOversamplingVsWeighting[LinearSVCModel, LinearSVC](
       dataset.as[LabeledPoint], estimator, modelEquals, 42L)
   }
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 43547a4aafcb9..d5d8fabaafa42 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -1850,7 +1850,7 @@ class LogisticRegressionSuite
       MLTestingUtils.testArbitrarilyScaledWeights[LogisticRegressionModel, 
LogisticRegression](
         dataset.as[LabeledPoint], estimator, modelEquals)
       MLTestingUtils.testOutliersWithSmallWeights[LogisticRegressionModel, 
LogisticRegression](
-        dataset.as[LabeledPoint], estimator, numClasses, modelEquals)
+        dataset.as[LabeledPoint], estimator, numClasses, modelEquals, 
outlierRatio = 3)
       MLTestingUtils.testOversamplingVsWeighting[LogisticRegressionModel, 
LogisticRegression](
         dataset.as[LabeledPoint], estimator, modelEquals, seed)
     }
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 37d7991fe8dd8..cc68562a69750 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -178,7 +178,7 @@ class NaiveBayesSuite extends SparkFunSuite with 
MLlibTestSparkContext with Defa
       MLTestingUtils.testArbitrarilyScaledWeights[NaiveBayesModel, NaiveBayes](
         dataset.as[LabeledPoint], estimatorNoSmoothing, modelEquals)
       MLTestingUtils.testOutliersWithSmallWeights[NaiveBayesModel, NaiveBayes](
-        dataset.as[LabeledPoint], estimatorWithSmoothing, numClasses, 
modelEquals)
+        dataset.as[LabeledPoint], estimatorWithSmoothing, numClasses, 
modelEquals, outlierRatio = 3)
       MLTestingUtils.testOversamplingVsWeighting[NaiveBayesModel, NaiveBayes](
         dataset.as[LabeledPoint], estimatorWithSmoothing, modelEquals, seed)
     }
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 44e1585ee514b..f9d3322770583 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -18,17 +18,17 @@
 package org.apache.spark.ml.classification
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.tree.LeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => 
OldRandomForest}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 
@@ -129,7 +129,7 @@ class RandomForestClassifierSuite
   }
 
   test("predictRaw and predictProbability") {
-    val rdd = orderedLabeledPoints5_20
+    val rdd = orderedLabeledPoints5_20.map(_.toInstance(1.0))
     val rf = new RandomForestClassifier()
       .setImpurity("Gini")
       .setMaxDepth(3)
@@ -167,7 +167,6 @@ class RandomForestClassifierSuite
   /////////////////////////////////////////////////////////////////////////////
   // Tests of feature importance
   /////////////////////////////////////////////////////////////////////////////
-
   test("Feature importance with toy data") {
     val numClasses = 2
     val rf = new RandomForestClassifier()
@@ -179,7 +178,7 @@ class RandomForestClassifierSuite
       .setSeed(123)
 
     // In this data, feature 1 is very important.
-    val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+    val data: RDD[Instance] = 
TreeTests.featureImportanceData(sc).map(_.toInstance(1.0))
     val categoricalFeatures = Map.empty[Int, Int]
     val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 
numClasses)
 
@@ -212,7 +211,7 @@ class RandomForestClassifierSuite
     }
 
     val rf = new RandomForestClassifier().setNumTrees(2)
-    val rdd = TreeTests.getTreeReadWriteData(sc)
+    val rdd = TreeTests.getTreeReadWriteData(sc).map(_.toInstance(1.0))
 
     val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> 
"entropy")
 
@@ -239,7 +238,8 @@ private object RandomForestClassifierSuite extends 
SparkFunSuite {
     val oldModel = OldRandomForest.trainClassifier(
       data.map(OldLabeledPoint.fromML), oldStrategy, rf.getNumTrees, 
rf.getFeatureSubsetStrategy,
       rf.getSeed.toInt)
-    val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 
numClasses)
+    val newData: DataFrame =
+      TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures, 
numClasses)
     val newModel = rf.fit(newData)
     // Use parent from newTree since this is not checked anyways.
     val oldModelAsNew = RandomForestClassificationModel.fromOld(
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 15fa26e8b5272..1238b2fa8edbc 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -18,15 +18,14 @@
 package org.apache.spark.ml.regression
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
-import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
-  DecisionTreeSuite => OldDecisionTreeSuite}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, 
DecisionTreeSuite => OldDecisionTreeSuite}
+import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 
@@ -34,13 +33,20 @@ class DecisionTreeRegressorSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   import DecisionTreeRegressorSuite.compareAPIs
+  import testImplicits._
 
   private var categoricalDataPointsRDD: RDD[LabeledPoint] = _
+  private var linearRegressionData: DataFrame = _
+
+  private val seed = 42
 
   override def beforeAll() {
     super.beforeAll()
     categoricalDataPointsRDD =
       
sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints().map(_.asML))
+    linearRegressionData = 
sc.parallelize(LinearDataGenerator.generateLinearInput(
+      intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3),
+      xVariance = Array(0.7, 1.2), nPoints = 1000, seed, eps = 0.5), 
2).map(_.asML).toDF()
   }
 
   /////////////////////////////////////////////////////////////////////////////
@@ -68,7 +74,8 @@ class DecisionTreeRegressorSuite
 
   test("copied model must have the same parent") {
     val categoricalFeatures = Map(0 -> 2, 1 -> 2)
-    val df = TreeTests.setMetadata(categoricalDataPointsRDD, 
categoricalFeatures, numClasses = 0)
+    val df = 
TreeTests.setMetadata(categoricalDataPointsRDD.map(_.toInstance(1.0)),
+      categoricalFeatures, numClasses = 0)
     val model = new DecisionTreeRegressor()
       .setImpurity("variance")
       .setMaxDepth(2)
@@ -85,7 +92,8 @@ class DecisionTreeRegressorSuite
       .setVarianceCol("variance")
     val categoricalFeatures = Map(0 -> 2, 1 -> 2)
 
-    val df = TreeTests.setMetadata(categoricalDataPointsRDD, 
categoricalFeatures, numClasses = 0)
+    val df = 
TreeTests.setMetadata(categoricalDataPointsRDD.map(_.toInstance(1.0)),
+      categoricalFeatures, numClasses = 0)
     val model = dt.fit(df)
 
     val predictions = model.transform(df)
@@ -98,7 +106,7 @@ class DecisionTreeRegressorSuite
         s"Expected variance $expectedVariance but got $variance.")
     }
 
-    val varianceData: RDD[LabeledPoint] = TreeTests.varianceData(sc)
+    val varianceData: RDD[Instance] = 
TreeTests.varianceData(sc).map(_.toInstance(1.0))
     val varianceDF = TreeTests.setMetadata(varianceData, Map.empty[Int, Int], 
0)
     dt.setMaxDepth(1)
       .setMaxBins(6)
@@ -125,7 +133,7 @@ class DecisionTreeRegressorSuite
       .setSeed(123)
 
     // In this data, feature 1 is very important.
-    val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+    val data: RDD[Instance] = 
TreeTests.featureImportanceData(sc).map(_.toInstance(1.0))
     val categoricalFeatures = Map.empty[Int, Int]
     val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
 
@@ -146,6 +154,27 @@ class DecisionTreeRegressorSuite
       }
   }
 
+  test("training with sample weights") {
+    // Runs the three shared MLTestingUtils sample-weight checks per maxDepth.
+    val numClasses = 0
+    val dataset = linearRegressionData
+    def relTol(eps: Double) = (x: Double, y: Double) => RelativeErrorComparison(x, y, eps)
+    Seq(5, 10).foreach { maxDepth =>
+      val regressor = new DecisionTreeRegressor()
+        .setMaxDepth(maxDepth)
+        .setMinWeightFractionPerNode(0.05)
+      MLTestingUtils.testArbitrarilyScaledWeights[
+        DecisionTreeRegressionModel, DecisionTreeRegressor](dataset.as[LabeledPoint], regressor,
+        MLTestingUtils.modelPredictionEquals(dataset, relTol(0.05), 0.9))
+      MLTestingUtils.testOutliersWithSmallWeights[
+        DecisionTreeRegressionModel, DecisionTreeRegressor](dataset.as[LabeledPoint], regressor,
+        numClasses, MLTestingUtils.modelPredictionEquals(dataset, relTol(0.1), 0.8),
+        outlierRatio = 2)
+      MLTestingUtils.testOversamplingVsWeighting[
+        DecisionTreeRegressionModel, DecisionTreeRegressor](dataset.as[LabeledPoint], regressor,
+        MLTestingUtils.modelPredictionEquals(dataset, relTol(0.01), 1.0), seed)
+    }
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
@@ -159,7 +188,7 @@ class DecisionTreeRegressorSuite
     }
 
     val dt = new DecisionTreeRegressor()
-    val rdd = TreeTests.getTreeReadWriteData(sc)
+    val rdd = TreeTests.getTreeReadWriteData(sc).map(_.toInstance(1.0))
 
     // Categorical splits with tree depth 2
     val categoricalData: DataFrame =
@@ -192,7 +221,8 @@ private[ml] object DecisionTreeRegressorSuite extends 
SparkFunSuite {
     val numFeatures = data.first().features.size
     val oldStrategy = dt.getOldStrategy(categoricalFeatures)
     val oldTree = OldDecisionTree.train(data.map(OldLabeledPoint.fromML), 
oldStrategy)
-    val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 
numClasses = 0)
+    val newData: DataFrame =
+      TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures, 
numClasses = 0)
     val newTree = dt.fit(newData)
     // Use parent from newTree since this is not checked anyways.
     val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index dcf3f9a1ea9b2..e805e42649e61 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -157,7 +157,7 @@ class GBTRegressorSuite extends SparkFunSuite with 
MLlibTestSparkContext
     // In this data, feature 1 is very important.
     val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
     val categoricalFeatures = Map.empty[Int, Int]
-    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
+    val df: DataFrame = TreeTests.setMetadata(data.map(_.toInstance(1.0)), 
categoricalFeatures, 0)
 
     val importances = gbt.fit(df).featureImportances
     val mostImportantFeature = importances.argmax
@@ -183,7 +183,7 @@ class GBTRegressorSuite extends SparkFunSuite with 
MLlibTestSparkContext
 
     val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> 
"squared")
     val continuousData: DataFrame =
-      TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
+      TreeTests.setMetadata(rdd.map(_.toInstance(1.0)), Map.empty[Int, Int], 
numClasses = 0)
     testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, 
checkModelData)
   }
 }
@@ -203,7 +203,8 @@ private object GBTRegressorSuite extends SparkFunSuite {
     val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, 
OldAlgo.Regression)
     val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt)
     val oldModel = oldGBT.run(data.map(OldLabeledPoint.fromML))
-    val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 
numClasses = 0)
+    val newData: DataFrame =
+      TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures, 
numClasses = 0)
     val newModel = gbt.fit(newData)
     // Use parent from newTree since this is not checked anyways.
     val oldModelAsNew = GBTRegressionModel.fromOld(
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 584a1b272f6c8..df3dbd0af3f39 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -839,10 +839,12 @@ class LinearRegressionSuite
         .setStandardization(standardization)
         .setRegParam(regParam)
         .setElasticNetParam(elasticNetParam)
+        .setSolver(solver)
       MLTestingUtils.testArbitrarilyScaledWeights[LinearRegressionModel, 
LinearRegression](
         datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals)
       MLTestingUtils.testOutliersWithSmallWeights[LinearRegressionModel, 
LinearRegression](
-        datasetWithStrongNoise.as[LabeledPoint], estimator, numClasses, 
modelEquals)
+        datasetWithStrongNoise.as[LabeledPoint], estimator, numClasses, 
modelEquals,
+        outlierRatio = 3)
       MLTestingUtils.testOversamplingVsWeighting[LinearRegressionModel, 
LinearRegression](
         datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals, seed)
     }
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index c08335f9f84af..01207931f57ba 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.ml.regression
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.MLTestingUtils
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => 
OldRandomForest}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -28,6 +29,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 
+
 /**
  * Test suite for [[RandomForestRegressor]].
  */
@@ -86,7 +88,7 @@ class RandomForestRegressorSuite extends SparkFunSuite with 
MLlibTestSparkContex
     // In this data, feature 1 is very important.
     val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
     val categoricalFeatures = Map.empty[Int, Int]
-    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
+    val df: DataFrame = TreeTests.setMetadata(data.map(_.toInstance(1.0)), 
categoricalFeatures, 0)
 
     val model = rf.fit(df)
 
@@ -123,7 +125,7 @@ class RandomForestRegressorSuite extends SparkFunSuite with 
MLlibTestSparkContex
     val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> 
"variance")
 
     val continuousData: DataFrame =
-      TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
+      TreeTests.setMetadata(rdd.map(_.toInstance(1.0)), Map.empty[Int, Int], 
numClasses = 0)
     testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, 
checkModelData)
   }
 }
@@ -143,7 +145,8 @@ private object RandomForestRegressorSuite extends 
SparkFunSuite {
       rf.getOldStrategy(categoricalFeatures, numClasses = 0, 
OldAlgo.Regression, rf.getOldImpurity)
     val oldModel = 
OldRandomForest.trainRegressor(data.map(OldLabeledPoint.fromML), oldStrategy,
       rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt)
-    val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 
numClasses = 0)
+    val newData: DataFrame = TreeTests.setMetadata(data.map(_.toInstance(1.0)),
+      categoricalFeatures, numClasses = 0)
     val newModel = rf.fit(newData)
     // Use parent from newTree since this is not checked anyways.
     val oldModelAsNew = RandomForestRegressionModel.fromOld(
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
index 77ab3d8bb75f7..0b09577171c41 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.tree.impl
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.mllib.tree.EnsembleTestHelper
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 
@@ -26,12 +27,16 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
  */
 class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext  {
 
-  test("BaggedPoint RDD: without subsampling") {
-    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+  test("BaggedPoint RDD: without subsampling with weights") {
+    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map { 
lp =>
+      Instance(lp.label, 0.5, lp.features.asML)
+    }
     val rdd = sc.parallelize(arr)
-    val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42)
+    val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false,
+      (instance: Instance) => instance.weight * 4.0, seed = 42)
     baggedRDD.collect().foreach { baggedPoint =>
-      assert(baggedPoint.subsampleWeights.size == 1 && 
baggedPoint.subsampleWeights(0) == 1)
+      assert(baggedPoint.subsampleCounts.size == 1 && 
baggedPoint.subsampleCounts(0) == 1)
+      assert(baggedPoint.sampleWeight === 2.0)
     }
   }
 
@@ -40,13 +45,17 @@ class BaggedPointSuite extends SparkFunSuite with 
MLlibTestSparkContext  {
     val (expectedMean, expectedStddev) = (1.0, 1.0)
 
     val seeds = Array(123, 5354, 230, 349867, 23987)
-    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 
1000).map(_.asML)
     val rdd = sc.parallelize(arr)
     seeds.foreach { seed =>
-      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, 
true, seed)
-      val subsampleCounts: Array[Array[Double]] = 
baggedRDD.map(_.subsampleWeights).collect()
+      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, 
true,
+        (_: LabeledPoint) => 2.0, seed)
+      val subsampleCounts: Array[Array[Double]] =
+        baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
       EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, 
expectedMean,
         expectedStddev, epsilon = 0.01)
+      // should ignore weight function for now
+      assert(baggedRDD.collect().forall(_.sampleWeight === 1.0))
     }
   }
 
@@ -59,8 +68,10 @@ class BaggedPointSuite extends SparkFunSuite with 
MLlibTestSparkContext  {
     val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
     val rdd = sc.parallelize(arr)
     seeds.foreach { seed =>
-      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, 
numSubsamples, true, seed)
-      val subsampleCounts: Array[Array[Double]] = 
baggedRDD.map(_.subsampleWeights).collect()
+      val baggedRDD =
+        BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, 
seed = seed)
+      val subsampleCounts: Array[Array[Double]] =
+        baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
       EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, 
expectedMean,
         expectedStddev, epsilon = 0.01)
     }
@@ -71,13 +82,17 @@ class BaggedPointSuite extends SparkFunSuite with 
MLlibTestSparkContext  {
     val (expectedMean, expectedStddev) = (1.0, 0)
 
     val seeds = Array(123, 5354, 230, 349867, 23987)
-    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 
1000).map(_.asML)
     val rdd = sc.parallelize(arr)
     seeds.foreach { seed =>
-      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, 
false, seed)
-      val subsampleCounts: Array[Array[Double]] = 
baggedRDD.map(_.subsampleWeights).collect()
+      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, 
false,
+        (_: LabeledPoint) => 2.0, seed)
+      val subsampleCounts: Array[Array[Double]] =
+        baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
       EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, 
expectedMean,
         expectedStddev, epsilon = 0.01)
+      // should ignore weight function for now
+      assert(baggedRDD.collect().forall(_.sampleWeight === 1.0))
     }
   }
 
@@ -90,8 +105,10 @@ class BaggedPointSuite extends SparkFunSuite with 
MLlibTestSparkContext  {
     val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
     val rdd = sc.parallelize(arr)
     seeds.foreach { seed =>
-      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, 
numSubsamples, false, seed)
-      val subsampleCounts: Array[Array[Double]] = 
baggedRDD.map(_.subsampleWeights).collect()
+      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, 
numSubsamples, false,
+        seed = seed)
+      val subsampleCounts: Array[Array[Double]] =
+        baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
       EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, 
expectedMean,
         expectedStddev, epsilon = 0.01)
     }
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index e1ab7c2d6520b..775ca669d527b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -18,10 +18,11 @@
 package org.apache.spark.ml.tree.impl
 
 import scala.collection.mutable
+import scala.language.implicitConversions
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.classification.DecisionTreeClassificationModel
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.util.TestingUtils._
@@ -43,7 +44,7 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
   /////////////////////////////////////////////////////////////////////////////
 
   test("Binary classification with continuous features: split calculation") {
-    val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1().map(_.asML)
+    val arr = 
OldDTSuite.generateOrderedLabeledPointsWithLabel1().map(_.asML.toInstance(1.0))
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, 100)
@@ -55,7 +56,7 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
   }
 
   test("Binary classification with binary (ordered) categorical features: 
split calculation") {
-    val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML)
+    val arr = 
OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance(1.0))
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, 
numClasses = 2,
@@ -72,7 +73,7 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
 
   test("Binary classification with 3-ary (ordered) categorical features," +
     " with no samples for one category: split calculation") {
-    val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML)
+    val arr = 
OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance(1.0))
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, 
numClasses = 2,
@@ -90,12 +91,12 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
   test("find splits for a continuous feature") {
     // find splits for normal case
     {
-      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
         Map(), Set(),
         Array(6), Gini, QuantileStrategy.Sort,
-        0, 0, 0.0, 0, 0
+        0, 0, 0.0, 0.0, 0, 0
       )
-      val featureSamples = Array.fill(200000)(math.random)
+      val featureSamples = Array.fill(200000)((1.0, math.random))
       val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, 
fakeMetadata, 0)
       assert(splits.length === 5)
       assert(fakeMetadata.numSplits(0) === 5)
@@ -107,12 +108,12 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     // find splits should not return identical splits
     // when there are not enough split candidates, reduce the number of splits 
in metadata
     {
-      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
         Map(), Set(),
         Array(5), Gini, QuantileStrategy.Sort,
-        0, 0, 0.0, 0, 0
+        0, 0, 0.0, 0.0, 0, 0
       )
-      val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 
3).map(_.toDouble)
+      val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(x => 
(1.0, x.toDouble))
       val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, 
fakeMetadata, 0)
       assert(splits === Array(1.0, 2.0))
       // check returned splits are distinct
@@ -121,47 +122,67 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
 
     // find splits when most samples close to the minimum
     {
-      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
         Map(), Set(),
         Array(3), Gini, QuantileStrategy.Sort,
-        0, 0, 0.0, 0, 0
+        0, 0, 0.0, 0.0, 0, 0
       )
-      val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 
5).map(_.toDouble)
+      val featureSamples =
+        Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(x => (1.0, 
x.toDouble))
       val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, 
fakeMetadata, 0)
       assert(splits === Array(2.0, 3.0))
     }
 
     // find splits when most samples close to the maximum
     {
-      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
         Map(), Set(),
         Array(2), Gini, QuantileStrategy.Sort,
-        0, 0, 0.0, 0, 0
+        0, 0, 0.0, 0.0, 0, 0
       )
-      val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2).map(_.toDouble)
+      val featureSamples =
+        Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(x => (1.0, 
x.toDouble))
       val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, 
fakeMetadata, 0)
       assert(splits === Array(1.0))
     }
 
-    // find splits for constant feature
+    // find splits for arbitrarily scaled data
     {
-      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
+        Map(), Set(),
+        Array(6), Gini, QuantileStrategy.Sort,
+        0, 0, 0.0, 0.0, 0, 0
+      )
+      val featureSamplesUnitWeight = Array.fill(10)((1.0, math.random))
+      val featureSamplesSmallWeight = featureSamplesUnitWeight.map { case (w, 
x) => (w * 0.001, x)}
+      val featureSamplesLargeWeight = featureSamplesUnitWeight.map { case (w, 
x) => (w * 1000, x)}
+      val splitsUnitWeight = RandomForest
+        .findSplitsForContinuousFeature(featureSamplesUnitWeight, 
fakeMetadata, 0)
+      val splitsSmallWeight = RandomForest
+        .findSplitsForContinuousFeature(featureSamplesSmallWeight, 
fakeMetadata, 0)
+      val splitsLargeWeight = RandomForest
+        .findSplitsForContinuousFeature(featureSamplesLargeWeight, 
fakeMetadata, 0)
+      assert(splitsUnitWeight === splitsSmallWeight)
+      assert(splitsUnitWeight === splitsLargeWeight)
+    }
+
+    // find splits when most weight is close to the minimum
+    {
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
         Map(), Set(),
         Array(3), Gini, QuantileStrategy.Sort,
-        0, 0, 0.0, 0, 0
+        0, 0, 0.0, 0.0, 0, 0
       )
-      val featureSamples = Array(0, 0, 0).map(_.toDouble)
-      val featureSamplesEmpty = Array.empty[Double]
+      val featureSamples = Array((10, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 
6)).map {
+        case (w, x) => (w.toDouble, x.toDouble)
+      }
       val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, 
fakeMetadata, 0)
-      assert(splits === Array.empty[Double])
-      val splitsEmpty =
-        RandomForest.findSplitsForContinuousFeature(featureSamplesEmpty, 
fakeMetadata, 0)
-      assert(splitsEmpty === Array.empty[Double])
+      assert(splits === Array(1.0, 2.0))
     }
   }
 
   test("train with empty arrays") {
-    val lp = LabeledPoint(1.0, Vectors.dense(Array.empty[Double]))
+    val lp = LabeledPoint(1.0, 
Vectors.dense(Array.empty[Double])).toInstance(1.0)
     val data = Array.fill(5)(lp)
     val rdd = sc.parallelize(data)
 
@@ -176,8 +197,8 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
   }
 
   test("train with constant features") {
-    val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0))
-    val data = Array.fill(5)(lp)
+    val instance = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 
0.0)).toInstance(1.0)
+    val data = Array.fill(5)(instance)
     val rdd = sc.parallelize(data)
     val strategy = new OldStrategy(
           OldAlgo.Classification,
@@ -189,7 +210,7 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = 
None)
     assert(tree.rootNode.impurity === -1.0)
     assert(tree.depth === 0)
-    assert(tree.rootNode.prediction === lp.label)
+    assert(tree.rootNode.prediction === instance.label)
 
     // Test with no categorical features
     val strategy2 = new OldStrategy(
@@ -200,11 +221,11 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     val Array(tree2) = RandomForest.run(rdd, strategy2, 1, "all", 42L, instr = 
None)
     assert(tree2.rootNode.impurity === -1.0)
     assert(tree2.depth === 0)
-    assert(tree2.rootNode.prediction === lp.label)
+    assert(tree2.rootNode.prediction === instance.label)
   }
 
   test("Multiclass classification with unordered categorical features: split 
calculations") {
-    val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML)
+    val arr = 
OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance(1.0))
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new OldStrategy(
@@ -245,7 +266,8 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
   }
 
   test("Multiclass classification with ordered categorical features: split 
calculations") {
-    val arr = 
OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures().map(_.asML)
+    val arr = 
OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
+      .map(_.asML.toInstance(1.0))
     assert(arr.length === 3000)
     val rdd = sc.parallelize(arr)
     val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, 
numClasses = 100,
@@ -277,7 +299,7 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
       LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
       LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
       LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
-    val input = sc.parallelize(arr)
+    val input = sc.parallelize(arr.map(_.toInstance(1.0)))
 
     val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = 
Gini, maxDepth = 1,
       numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
@@ -319,7 +341,7 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
       LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
       LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
       LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
-    val input = sc.parallelize(arr)
+    val input = sc.parallelize(arr.map(_.toInstance(1.0)))
 
     val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = 
Gini, maxDepth = 5,
       numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
@@ -371,7 +393,7 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
       LabeledPoint(0.0, Vectors.dense(2.0)),
       LabeledPoint(0.0, Vectors.dense(2.0)),
       LabeledPoint(1.0, Vectors.dense(2.0)))
-    val input = sc.parallelize(arr)
+    val input = sc.parallelize(arr.map(_.toInstance(1.0)))
 
     // Must set maxBins s.t. the feature will be treated as an ordered 
categorical feature.
     val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = 
Gini, maxDepth = 1,
@@ -390,7 +412,7 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
   }
 
   test("Second level node building with vs. without groups") {
-    val arr = OldDTSuite.generateOrderedLabeledPoints().map(_.asML)
+    val arr = 
OldDTSuite.generateOrderedLabeledPoints().map(_.asML.toInstance(1.0))
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     // For tree with 1 group
@@ -434,7 +456,7 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
   def 
binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: 
OldStrategy) {
     val numFeatures = 50
     val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 
1000)
-    val rdd = sc.parallelize(arr).map(_.asML)
+    val rdd = sc.parallelize(arr).map(_.asML.toInstance(1.0))
 
     // Select feature subset for top nodes.  Return true if OK.
     def checkFeatureSubsetStrategy(
@@ -547,16 +569,16 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
       left2       parent
                 left  right
      */
-    val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0))
+    val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0), 6L)
     val left = new LeafNode(0.0, leftImp.calculate(), leftImp)
 
-    val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0))
+    val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0), 8L)
     val right = new LeafNode(2.0, rightImp.calculate(), rightImp)
 
     val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 
0.5))
     val parentImp = parent.impurityStats
 
-    val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0))
+    val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0), 8L)
     val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp)
 
     val grandParent = TreeTests.buildParentNode(left2, parent, new 
ContinuousSplit(1, 1.0))
@@ -602,6 +624,57 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
     assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
   }
+
+  test("weights at arbitrary scale") {
+    // Forests trained on uniformly rescaled instance weights are expected to match the
+    // forest trained on unit weights (same strategy, same seed).
+    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(3, 10)
+    val rddWithUnitWeights = sc.parallelize(arr.map(_.asML.toInstance(1.0)))
+    val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2)
+    val numTrees = 3
+    val unitWeightTrees =
+      RandomForest.run(rddWithUnitWeights, strategy, numTrees, "all", 42L, None)
+
+    // One small and one large uniform weight, each compared against the unit-weight forest.
+    Seq(0.001, 1000.0).foreach { weight =>
+      val rescaled = rddWithUnitWeights.map { inst =>
+        Instance(inst.label, weight, inst.features)
+      }
+      val rescaledTrees = RandomForest.run(rescaled, strategy, numTrees, "all", 42L, None)
+      // zip truncates to the shorter side; check the counts so no tree escapes comparison.
+      assert(rescaledTrees.length === unitWeightTrees.length)
+      unitWeightTrees.zip(rescaledTrees).foreach { case (unitTree, rescaledTree) =>
+        TreeTests.checkEqual(unitTree, rescaledTree)
+      }
+    }
+  }
+
+  test("minWeightFraction and minInstancesPerNode") {
+    // Four weight-1.0 instances of label 0 plus a single weight-0.1 instance of label 1.
+    val majority = Array.fill(4)(Instance(0.0, 1.0, Vectors.dense(0.0)))
+    val points = sc.parallelize(majority :+ Instance(1.0, 0.1, Vectors.dense(1.0)))
+    val treeStrategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2,
+      minWeightFractionPerNode = 0.5)
+
+    // Trains one tree with the strategy's current settings and returns its depth.
+    def trainedDepth(): Int = {
+      val Array(tree) = RandomForest.run(points, treeStrategy, 1, "all", 42L, None)
+      tree.depth
+    }
+
+    // minWeightFractionPerNode = 0.5: the root is expected to stay a leaf.
+    assert(trainedDepth() == 0)
+
+    // Dropping the weight-fraction constraint lets the root split once.
+    treeStrategy.minWeightFractionPerNode = 0.0
+    assert(trainedDepth() == 1)
+
+    // Requiring two instances per node blocks the split again.
+    treeStrategy.minInstancesPerNode = 2
+    assert(trainedDepth() == 0)
+
+    // With a single instance allowed per node the split comes back.
+    treeStrategy.minInstancesPerNode = 1
+    assert(trainedDepth() == 1)
+  }
+
 }
 
 private object RandomForestSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala 
b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index c90cb8ca1034c..999aa80b7750d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -18,13 +18,15 @@
 package org.apache.spark.ml.tree.impl
 
 import scala.collection.JavaConverters._
+import scala.util.Random
 
 import org.apache.spark.{SparkContext, SparkFunSuite}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, 
NumericAttribute}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.tree._
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, SparkSession}
 
@@ -32,6 +34,7 @@ private[ml] object TreeTests extends SparkFunSuite {
 
   /**
    * Convert the given data to a DataFrame, and set the features and label 
metadata.
+   *
    * @param data  Dataset.  Categorical features and labels must already have 
0-based indices.
    *              This must be non-empty.
    * @param categoricalFeatures  Map: categorical feature index -> number of 
distinct values
@@ -39,7 +42,7 @@ private[ml] object TreeTests extends SparkFunSuite {
    * @return DataFrame with metadata
    */
   def setMetadata(
-      data: RDD[LabeledPoint],
+      data: RDD[Instance],
       categoricalFeatures: Map[Int, Int],
       numClasses: Int): DataFrame = {
     val spark = SparkSession.builder()
@@ -66,7 +69,7 @@ private[ml] object TreeTests extends SparkFunSuite {
     }
     val labelMetadata = labelAttribute.toMetadata()
     df.select(df("features").as("features", featuresMetadata),
-      df("label").as("label", labelMetadata))
+      df("label").as("label", labelMetadata), df("weight"))
   }
 
   /** Java-friendly version of [[setMetadata()]] */
@@ -74,12 +77,14 @@ private[ml] object TreeTests extends SparkFunSuite {
       data: JavaRDD[LabeledPoint],
       categoricalFeatures: java.util.Map[java.lang.Integer, java.lang.Integer],
       numClasses: Int): DataFrame = {
-    setMetadata(data.rdd, categoricalFeatures.asInstanceOf[java.util.Map[Int, 
Int]].asScala.toMap,
+    setMetadata(data.rdd.map(_.toInstance(1.0)),
+      categoricalFeatures.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
       numClasses)
   }
 
   /**
    * Set label metadata (particularly the number of classes) on a DataFrame.
+   *
    * @param data  Dataset.  Categorical features and labels must already have 
0-based indices.
    *              This must be non-empty.
    * @param numClasses  Number of classes label can take. If 0, mark as 
continuous.
@@ -124,8 +129,8 @@ private[ml] object TreeTests extends SparkFunSuite {
    *       make mistakes such as creating loops of Nodes.
    */
   private def checkEqual(a: Node, b: Node): Unit = {
-    assert(a.prediction === b.prediction)
-    assert(a.impurity === b.impurity)
+    assert(a.prediction ~== b.prediction absTol 1e-8)
+    assert(a.impurity ~== b.impurity absTol 1e-8)
     (a, b) match {
       case (aye: InternalNode, bee: InternalNode) =>
         assert(aye.split === bee.split)
@@ -156,6 +161,7 @@ private[ml] object TreeTests extends SparkFunSuite {
   /**
    * Helper method for constructing a tree for testing.
    * Given left, right children, construct a parent node.
+   *
    * @param split  Split for parent node
    * @return  Parent node with children attached
    */
@@ -163,8 +169,8 @@ private[ml] object TreeTests extends SparkFunSuite {
     val leftImp = left.impurityStats
     val rightImp = right.impurityStats
     val parentImp = leftImp.copy.add(rightImp)
-    val leftWeight = leftImp.count / parentImp.count.toDouble
-    val rightWeight = rightImp.count / parentImp.count.toDouble
+    val leftWeight = leftImp.count / parentImp.count
+    val rightWeight = rightImp.count / parentImp.count
     val gain = parentImp.calculate() -
       (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate())
     val pred = parentImp.predict
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala 
b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index f1ed568d5e60a..4b1bd35313c5a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, 
HasWeightCol}
+import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.recommendation.{ALS, ALSModel}
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
@@ -246,8 +246,8 @@ object MLTestingUtils extends SparkFunSuite {
       seed: Long): Unit = {
     val (overSampledData, weightedData) = 
genEquivalentOversampledAndWeightedInstances(
       data, seed)
-    val weightedModel = estimator.set(estimator.weightCol, 
"weight").fit(weightedData)
     val overSampledModel = estimator.set(estimator.weightCol, 
"").fit(overSampledData)
+    val weightedModel = estimator.set(estimator.weightCol, 
"weight").fit(weightedData)
     modelEquals(weightedModel, overSampledModel)
   }
 
@@ -260,15 +260,17 @@ object MLTestingUtils extends SparkFunSuite {
       data: Dataset[LabeledPoint],
       estimator: E with HasWeightCol,
       numClasses: Int,
-      modelEquals: (M, M) => Unit): Unit = {
+      modelEquals: (M, M) => Unit,
+      outlierRatio: Int): Unit = {
     import data.sqlContext.implicits._
     val outlierDS = data.withColumn("weight", lit(1.0)).as[Instance].flatMap {
       case Instance(l, w, f) =>
         val outlierLabel = if (numClasses == 0) -l else numClasses - l - 1
-        List.fill(3)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, 
f))
+        List.fill(outlierRatio)(Instance(outlierLabel, 0.0001, f)) ++ 
List(Instance(l, w, f))
     }
     val trueModel = estimator.set(estimator.weightCol, "").fit(data)
-    val outlierModel = estimator.set(estimator.weightCol, 
"weight").fit(outlierDS)
+    val outlierModel = estimator.set(estimator.weightCol, "weight")
+      .fit(outlierDS)
     modelEquals(trueModel, outlierModel)
   }
 
@@ -281,10 +283,26 @@ object MLTestingUtils extends SparkFunSuite {
       estimator: E with HasWeightCol,
       modelEquals: (M, M) => Unit): Unit = {
     estimator.set(estimator.weightCol, "weight")
-    val models = Seq(0.001, 1.0, 1000.0).map { w =>
+    val models = Seq(0.01, 1.0, 1000.0).map { w =>
       val df = data.withColumn("weight", lit(w))
       estimator.fit(df)
     }
     models.sliding(2).foreach { case Seq(m1, m2) => modelEquals(m1, m2)}
   }
+
+  /**
+   * Asserts that two models agree on a sufficient fraction of their predictions.
+   *
+   * Both models transform `data`; their prediction columns are collected and paired row by
+   * row, and the fraction of pairs accepted by `compareFunc` must reach `fractionInTol`.
+   *
+   * @param data DataFrame that both models are applied to
+   * @param compareFunc returns true when two predicted values count as equal
+   * @param fractionInTol minimum fraction of prediction pairs that must satisfy `compareFunc`
+   * @param model1 first model to compare
+   * @param model2 second model to compare
+   */
+  def modelPredictionEquals[M <: PredictionModel[_, M]](
+      data: DataFrame,
+      compareFunc: (Double, Double) => Boolean,
+      fractionInTol: Double)(
+      model1: M,
+      model2: M): Unit = {
+    val pred1 = model1.transform(data).select(model1.getPredictionCol).collect()
+    val pred2 = model2.transform(data).select(model2.getPredictionCol).collect()
+    // zip would silently drop unmatched rows, and an empty input would turn the fraction
+    // below into NaN, so check both conditions explicitly before comparing.
+    assert(pred1.length === pred2.length,
+      s"models produced ${pred1.length} and ${pred2.length} predictions for the same data")
+    assert(pred1.nonEmpty, "cannot compare model predictions on an empty dataset")
+    val inTol = pred1.zip(pred2).count { case (p1, p2) =>
+      compareFunc(p1.getDouble(0), p2.getDouble(0))
+    }
+    val fraction = inTol / pred1.length.toDouble
+    assert(fraction >= fractionInTol,
+      s"only $inTol of ${pred1.length} predictions agreed ($fraction < $fractionInTol)")
+  }
 }
diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 441d0f7614bf6..f6efc84a24185 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -73,7 +73,7 @@ class DecisionTreeSuite extends SparkFunSuite with 
MLlibTestSparkContext {
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
 
-    val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), 
strategy)
+    val metadata = 
DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
     assert(!metadata.isUnordered(featureIndex = 0))
     assert(!metadata.isUnordered(featureIndex = 1))
 
@@ -100,7 +100,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
       maxDepth = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2))
-    val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
     assert(!metadata.isUnordered(featureIndex = 0))
     assert(!metadata.isUnordered(featureIndex = 1))
 
@@ -116,7 +116,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Gini, maxDepth = 3,
       numClasses = 2, maxBins = 100)
-    val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
     assert(!metadata.isUnordered(featureIndex = 0))
     assert(!metadata.isUnordered(featureIndex = 1))
 
@@ -133,7 +133,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Gini, maxDepth = 3,
       numClasses = 2, maxBins = 100)
-    val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
     assert(!metadata.isUnordered(featureIndex = 0))
     assert(!metadata.isUnordered(featureIndex = 1))
 
@@ -150,7 +150,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
       numClasses = 2, maxBins = 100)
-    val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
     assert(!metadata.isUnordered(featureIndex = 0))
     assert(!metadata.isUnordered(featureIndex = 1))
 
@@ -167,7 +167,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
       numClasses = 2, maxBins = 100)
-    val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
     assert(!metadata.isUnordered(featureIndex = 0))
     assert(!metadata.isUnordered(featureIndex = 1))
 
@@ -183,7 +183,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
       numClasses = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
-    val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
     assert(strategy.isMulticlassClassification)
     assert(metadata.isUnordered(featureIndex = 0))
     assert(metadata.isUnordered(featureIndex = 1))
@@ -240,7 +240,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
       numClasses = 3, maxBins = maxBins,
       categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
     assert(strategy.isMulticlassClassification)
-    val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
     assert(metadata.isUnordered(featureIndex = 0))
     assert(metadata.isUnordered(featureIndex = 1))
 
@@ -288,7 +288,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
       numClasses = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3))
     assert(strategy.isMulticlassClassification)
-    val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
     assert(metadata.isUnordered(featureIndex = 0))
 
     val model = DecisionTree.train(rdd, strategy)
@@ -310,7 +310,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
       numClasses = 3, maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
     assert(strategy.isMulticlassClassification)
-    val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
     assert(!metadata.isUnordered(featureIndex = 0))
     assert(!metadata.isUnordered(featureIndex = 1))
 
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
index 14152cdd63bc7..d4171cf441e96 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
@@ -18,23 +18,62 @@
 package org.apache.spark.mllib.tree
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator}
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.tree.impurity._
 
 /**
  * Test suites for [[GiniAggregator]] and [[EntropyAggregator]].
  */
 class ImpuritySuite extends SparkFunSuite {
+
+  private val seed = 42
+
   test("Gini impurity does not support negative labels") {
     val gini = new GiniAggregator(2)
     intercept[IllegalArgumentException] {
-      gini.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
+      gini.update(Array(0.0, 1.0, 2.0), 0, -1, 3, 0.0)
     }
   }
 
   test("Entropy does not support negative labels") {
     val entropy = new EntropyAggregator(2)
     intercept[IllegalArgumentException] {
-      entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
+      entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 3, 0.0)
+    }
+  }
+
+  test("Classification impurities are insensitive to scaling") {
+    val rng = new scala.util.Random(seed)
+    val weightedCounts = Array.fill(5)(rng.nextDouble())
+    val smallWeightedCounts = weightedCounts.map(_ * 0.0001)
+    val largeWeightedCounts = weightedCounts.map(_ * 10000)
+    Seq(Gini, Entropy).foreach { impurity =>
+      val impurity1 = impurity.calculate(weightedCounts, weightedCounts.sum)
+      assert(impurity.calculate(smallWeightedCounts, smallWeightedCounts.sum)
+        ~== impurity1 relTol 0.005)
+      assert(impurity.calculate(largeWeightedCounts, largeWeightedCounts.sum)
+        ~== impurity1 relTol 0.005)
     }
   }
+  test("Regression impurities are insensitive to scaling") {
+    def computeStats(samples: Seq[Double], weights: Seq[Double]): (Double, Double, Double) = {
+      samples.zip(weights).foldLeft((0.0, 0.0, 0.0)) { case ((wn, wy, wyy), (y, w)) =>
+        (wn + w, wy + w * y, wyy + w * y * y)
+      }
+    }
+    val rng = new scala.util.Random(seed)
+    val samples = Array.fill(10)(rng.nextDouble())
+    val _weights = Array.fill(10)(rng.nextDouble())
+    val smallWeights = _weights.map(_ * 0.0001)
+    val largeWeights = _weights.map(_ * 10000)
+    val (count, sum, sumSquared) = computeStats(samples, _weights)
+    Seq(Variance).foreach { impurity =>
+      val impurity1 = impurity.calculate(count, sum, sumSquared)
+      val (smallCount, smallSum, smallSumSquared) = computeStats(samples, smallWeights)
+      val (largeCount, largeSum, largeSumSquared) = computeStats(samples, largeWeights)
+      assert(impurity.calculate(smallCount, smallSum, smallSumSquared) ~== impurity1 relTol 0.005)
+      assert(impurity.calculate(largeCount, largeSum, largeSumSquared) ~== impurity1 relTol 0.005)
+    }
+  }
+
 }
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 7e6e143523387..c41b67b8f6bf0 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -36,6 +36,9 @@ object MimaExcludes {
 
   // Exclude rules for 2.2.x
   lazy val v22excludes = v21excludes ++ Seq(
+    // [SPARK-9478][ML] Add sample weights to decision trees
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.tree.configuration.Strategy.this"),
+
     // [SPARK-18663][SQL] Simplify CountMinSketch aggregate implementation
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.CountMinSketch.toByteArray"),
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to