Repository: spark
Updated Branches:
  refs/heads/master 9498e528d -> 7e7350285


[SPARK-24132][ML] Instrumentation improvement for classification

## What changes were proposed in this pull request?

- Add `OptionalInstrumentation` as an argument to `getNumClasses` in `ml.classification.Classifier` (note: `Classifier.scala` does not appear in this commit's list of changed files).

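- Change the `getNumClasses` call in `train()` of `ml.classification.DecisionTreeClassifier`, `ml.classification.RandomForestClassifier`, and `ml.classification.NaiveBayes` (for NaiveBayes, in `trainWithLabelCheck`). The `Instrumentation` is now created from the input dataset before `getNumClasses` is called, and the number of classes is logged with `instr.logNumClasses`.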

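- Modify the instrumentation creation in `ml.classification.LinearSVC`. It is now created from the input `dataset` instead of the derived `instances`. It also logs the number of examples and the `lowestLabelWeight` / `highestLabelWeight` values.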

- Change the log calls in `ml.classification.OneVsRest` and `ml.classification.LinearSVC` to go through the instrumentation instance (`instr.logWarning` / `instr.logError`).

## How was this patch tested?

Manually tested.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Lu WANG <lu.w...@databricks.com>

Closes #21204 from ludatabricks/SPARK-23686.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7e735028
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7e735028
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7e735028

Branch: refs/heads/master
Commit: 7e7350285dc22764f599671d874617c0eea093e5
Parents: 9498e52
Author: Lu WANG <lu.w...@databricks.com>
Authored: Tue May 8 21:20:58 2018 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Tue May 8 21:20:58 2018 -0700

----------------------------------------------------------------------
 .../spark/ml/classification/DecisionTreeClassifier.scala    | 9 ++++++---
 .../org/apache/spark/ml/classification/LinearSVC.scala      | 9 ++++++---
 .../org/apache/spark/ml/classification/NaiveBayes.scala     | 3 ++-
 .../org/apache/spark/ml/classification/OneVsRest.scala      | 4 ++--
 .../spark/ml/classification/RandomForestClassifier.scala    | 4 +++-
 5 files changed, 19 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7e735028/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 57797d1..c9786f1 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -97,9 +97,11 @@ class DecisionTreeClassifier @Since("1.4.0") (
   override def setSeed(value: Long): this.type = set(seed, value)
 
   override protected def train(dataset: Dataset[_]): 
DecisionTreeClassificationModel = {
+    val instr = Instrumentation.create(this, dataset)
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val numClasses: Int = getNumClasses(dataset)
+    instr.logNumClasses(numClasses)
 
     if (isDefined(thresholds)) {
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
@@ -110,8 +112,8 @@ class DecisionTreeClassifier @Since("1.4.0") (
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, 
numClasses)
     val strategy = getOldStrategy(categoricalFeatures, numClasses)
 
-    val instr = Instrumentation.create(this, oldDataset)
-    instr.logParams(params: _*)
+    instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, 
maxMemoryInMB,
+      cacheNodeIds, checkpointInterval, impurity, seed)
 
     val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, 
featureSubsetStrategy = "all",
       seed = $(seed), instr = Some(instr), parentUID = Some(uid))
@@ -125,7 +127,8 @@ class DecisionTreeClassifier @Since("1.4.0") (
   private[ml] def train(data: RDD[LabeledPoint],
       oldStrategy: OldStrategy): DecisionTreeClassificationModel = {
     val instr = Instrumentation.create(this, data)
-    instr.logParams(params: _*)
+    instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, 
maxMemoryInMB,
+      cacheNodeIds, checkpointInterval, impurity, seed)
 
     val trees = RandomForest.run(data, oldStrategy, numTrees = 1, 
featureSubsetStrategy = "all",
       seed = 0L, instr = Some(instr), parentUID = Some(uid))

http://git-wip-us.apache.org/repos/asf/spark/blob/7e735028/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index 80c537e..38eb045 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -170,7 +170,7 @@ class LinearSVC @Since("2.2.0") (
           Instance(label, weight, features)
       }
 
-    val instr = Instrumentation.create(this, instances)
+    val instr = Instrumentation.create(this, dataset)
     instr.logParams(regParam, maxIter, fitIntercept, tol, standardization, 
threshold,
       aggregationDepth)
 
@@ -187,6 +187,9 @@ class LinearSVC @Since("2.2.0") (
         (new MultivariateOnlineSummarizer, new MultiClassSummarizer)
       )(seqOp, combOp, $(aggregationDepth))
     }
+    instr.logNamedValue(Instrumentation.loggerTags.numExamples, 
summarizer.count)
+    instr.logNamedValue("lowestLabelWeight", 
labelSummarizer.histogram.min.toString)
+    instr.logNamedValue("highestLabelWeight", 
labelSummarizer.histogram.max.toString)
 
     val histogram = labelSummarizer.histogram
     val numInvalid = labelSummarizer.countInvalid
@@ -209,7 +212,7 @@ class LinearSVC @Since("2.2.0") (
       if (numInvalid != 0) {
         val msg = s"Classification labels should be in [0 to ${numClasses - 
1}]. " +
           s"Found $numInvalid invalid labels."
-        logError(msg)
+        instr.logError(msg)
         throw new SparkException(msg)
       }
 
@@ -246,7 +249,7 @@ class LinearSVC @Since("2.2.0") (
       bcFeaturesStd.destroy(blocking = false)
       if (state == null) {
         val msg = s"${optimizer.getClass.getName} failed."
-        logError(msg)
+        instr.logError(msg)
         throw new SparkException(msg)
       }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7e735028/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 45fb585..1dde18d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -126,8 +126,10 @@ class NaiveBayes @Since("1.5.0") (
   private[spark] def trainWithLabelCheck(
       dataset: Dataset[_],
       positiveLabel: Boolean): NaiveBayesModel = {
+    val instr = Instrumentation.create(this, dataset)
     if (positiveLabel && isDefined(thresholds)) {
       val numClasses = getNumClasses(dataset)
+      instr.logNumClasses(numClasses)
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
         ".train() called with non-matching numClasses and thresholds.length." +
         s" numClasses=$numClasses, but thresholds has length 
${$(thresholds).length}")
@@ -146,7 +148,6 @@ class NaiveBayes @Since("1.5.0") (
       }
     }
 
-    val instr = Instrumentation.create(this, dataset)
     instr.logParams(labelCol, featuresCol, weightCol, predictionCol, 
rawPredictionCol,
       probabilityCol, modelType, smoothing, thresholds)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7e735028/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 7df53a6..3474b61 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -366,7 +366,7 @@ final class OneVsRest @Since("1.4.0") (
     transformSchema(dataset.schema)
 
     val instr = Instrumentation.create(this, dataset)
-    instr.logParams(labelCol, featuresCol, predictionCol, parallelism)
+    instr.logParams(labelCol, featuresCol, predictionCol, parallelism, 
rawPredictionCol)
     instr.logNamedValue("classifier", $(classifier).getClass.getCanonicalName)
 
     // determine number of classes either from metadata if provided, or via 
computation.
@@ -383,7 +383,7 @@ final class OneVsRest @Since("1.4.0") (
       getClassifier match {
         case _: HasWeightCol => true
         case c =>
-          logWarning(s"weightCol is ignored, as it is not supported by $c 
now.")
+          instr.logWarning(s"weightCol is ignored, as it is not supported by 
$c now.")
           false
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/7e735028/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index f1ef26a..040db3b 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -116,6 +116,7 @@ class RandomForestClassifier @Since("1.4.0") (
     set(featureSubsetStrategy, value)
 
   override protected def train(dataset: Dataset[_]): 
RandomForestClassificationModel = {
+    val instr = Instrumentation.create(this, dataset)
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val numClasses: Int = getNumClasses(dataset)
@@ -130,7 +131,6 @@ class RandomForestClassifier @Since("1.4.0") (
     val strategy =
       super.getOldStrategy(categoricalFeatures, numClasses, 
OldAlgo.Classification, getOldImpurity)
 
-    val instr = Instrumentation.create(this, oldDataset)
     instr.logParams(labelCol, featuresCol, predictionCol, probabilityCol, 
rawPredictionCol,
       impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, 
maxMemoryInMB, minInfoGain,
       minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, 
checkpointInterval)
@@ -141,6 +141,8 @@ class RandomForestClassifier @Since("1.4.0") (
 
     val numFeatures = oldDataset.first().features.size
     val m = new RandomForestClassificationModel(uid, trees, numFeatures, 
numClasses)
+    instr.logNumClasses(numClasses)
+    instr.logNumFeatures(numFeatures)
     instr.logSuccess(m)
     m
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to