zhengruifeng commented on a change in pull request #28710:
URL: https://github.com/apache/spark/pull/28710#discussion_r439969774
##########
File path: project/MimaExcludes.scala
##########
@@ -39,18 +39,44 @@ object MimaExcludes {
// [SPARK-31077] Remove ChiSqSelector dependency on
mllib.ChiSqSelectorModel
// private constructor
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel.this"),
+
// [SPARK-31127] Implement abstract Selector
// org.apache.spark.ml.feature.ChiSqSelectorModel type hierarchy change
// before: class ChiSqSelector extends Estimator with ChiSqSelectorParams
// after: class ChiSqSelector extends PSelector
// false positive, no binary incompatibility
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.ChiSqSelector"),
+
//[SPARK-31840] Add instance weight support in LogisticRegressionSummary
// weightCol in
org.apache.spark.ml.classification.LogisticRegressionSummary is present only in
current version
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightCol"),
+
// [SPARK-24634] Add a new metric regarding number of inputs later than
watermark plus allowed delay
-
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StateOperatorProgress.<init>$default$4")
+
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StateOperatorProgress.<init>$default$4"),
+
+ //[SPARK-31893] Add a generic ClassificationSummary trait
+
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionTrainingSummary.org$apache$spark$ml$classification$ClassificationSummary$_setter_$org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics_="),
+
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionTrainingSummary.org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics"),
+
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionTrainingSummary.weightCol"),
+
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$_setter_$org$apache$spark$ml$classification$BinaryClassificationSummary$$sparkSession_="),
+
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$_setter_$org$apache$spark$ml$classification$BinaryClassificationSummary$$binaryMetrics_="),
+
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$$binaryMetrics"),
+
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$$sparkSession"),
+
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$ClassificationSummary$_setter_$org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics_="),
+
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics"),
+
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.weightCol"),
+
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.asBinary"),
+
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$ClassificationSummary$_setter_$org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics_="),
+
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics"),
+
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightCol"),
+
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$_setter_$org$apache$spark$ml$classification$BinaryClassificationSummary$$sparkSession_="),
+
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$_setter_$org$apache$spark$ml$classification$BinaryClassificationSummary$$binaryMetrics_="),
+
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$$binaryMetrics"),
+
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$$sparkSession"),
+
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$ClassificationSummary$_setter_$org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics_="),
+
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics"),
+
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.weightCol")
)
Review comment:
Are there any MiMa exclusions that can be removed after the latest change?
##########
File path:
mllib/src/main/scala/org/apache/spark/ml/classification/ClassificationSummary.scala
##########
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics,
MulticlassMetrics}
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.DoubleType
+
+
+/**
+ * Abstraction for multiclass classification results for a given model.
+ */
+private[classification] trait ClassificationSummary extends Serializable {
+
+ /**
+ * Dataframe output by the model's `transform` method.
+ */
+ @Since("3.1.0")
+ def predictions: DataFrame
+
+ /** Field in "predictions" which gives the prediction of each class. */
+ @Since("3.1.0")
+ def predictionCol: String
+
+ /** Field in "predictions" which gives the true label of each instance (if
available). */
+ @Since("3.1.0")
+ def labelCol: String
+
+ /** Field in "predictions" which gives the weight of each instance as a
vector. */
+ @Since("3.1.0")
+ def weightCol: String
+
+ @transient private val multiclassMetrics = {
+ val weightColumn = if (predictions.schema.fieldNames.contains(weightCol)) {
+ col(weightCol).cast(DoubleType)
+ } else {
+ lit(1.0)
+ }
+ new MulticlassMetrics(
+ predictions.select(col(predictionCol), col(labelCol).cast(DoubleType),
weightColumn)
+ .rdd.map {
+ case Row(prediction: Double, label: Double, weight: Double) =>
(prediction, label, weight)
+ })
+ }
+
+ /**
+ * Returns the sequence of labels in ascending order. This order matches the
order used
+ * in metrics which are specified as arrays over labels, e.g.,
truePositiveRateByLabel.
+ *
+ * Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1},
However, if the
+ * training set is missing a label, then all of the arrays over labels
+ * (e.g., from truePositiveRateByLabel) will be of length numClasses-1
instead of the
+ * expected numClasses.
+ */
+ @Since("3.1.0")
+ def labels: Array[Double] = multiclassMetrics.labels
+
+ /** Returns true positive rate for each label (category). */
+ @Since("3.1.0")
+ def truePositiveRateByLabel: Array[Double] = recallByLabel
+
+ /** Returns false positive rate for each label (category). */
+ @Since("3.1.0")
+ def falsePositiveRateByLabel: Array[Double] = {
+ multiclassMetrics.labels.map(label =>
multiclassMetrics.falsePositiveRate(label))
+ }
+
+ /** Returns precision for each label (category). */
+ @Since("3.1.0")
+ def precisionByLabel: Array[Double] = {
+ multiclassMetrics.labels.map(label => multiclassMetrics.precision(label))
+ }
+
+ /** Returns recall for each label (category). */
+ @Since("3.1.0")
+ def recallByLabel: Array[Double] = {
+ multiclassMetrics.labels.map(label => multiclassMetrics.recall(label))
+ }
+
+ /** Returns f-measure for each label (category). */
+ @Since("3.1.0")
+ def fMeasureByLabel(beta: Double): Array[Double] = {
+ multiclassMetrics.labels.map(label => multiclassMetrics.fMeasure(label,
beta))
+ }
+
+ /** Returns f1-measure for each label (category). */
+ @Since("3.1.0")
+ def fMeasureByLabel: Array[Double] = fMeasureByLabel(1.0)
+
+ /**
+ * Returns accuracy.
+ * (equals to the total number of correctly classified instances
+ * out of the total number of instances.)
+ */
+ @Since("3.1.0")
+ def accuracy: Double = multiclassMetrics.accuracy
+
+ /**
+ * Returns weighted true positive rate.
+ * (equals to precision, recall and f-measure)
+ */
+ @Since("3.1.0")
+ def weightedTruePositiveRate: Double = weightedRecall
+
+ /** Returns weighted false positive rate. */
+ @Since("3.1.0")
+ def weightedFalsePositiveRate: Double =
multiclassMetrics.weightedFalsePositiveRate
+
+ /**
+ * Returns weighted averaged recall.
+ * (equals to precision, recall and f-measure)
+ */
+ @Since("3.1.0")
+ def weightedRecall: Double = multiclassMetrics.weightedRecall
+
+ /** Returns weighted averaged precision. */
+ @Since("3.1.0")
+ def weightedPrecision: Double = multiclassMetrics.weightedPrecision
+
+ /** Returns weighted averaged f-measure. */
+ @Since("3.1.0")
+ def weightedFMeasure(beta: Double): Double =
multiclassMetrics.weightedFMeasure(beta)
+
+ /** Returns weighted averaged f1-measure. */
+ @Since("3.1.0")
+ def weightedFMeasure: Double = multiclassMetrics.weightedFMeasure(1.0)
+
+ /**
+ * Convenient method for casting to binary classification summary.
+ * This method will throw an Exception if the summary is not a binary
summary.
+ */
+ @Since("3.1.0")
+ def asBinary: BinaryClassificationSummary = this match {
+ case b: BinaryClassificationSummary => b
+ case _ =>
+ throw new RuntimeException("Cannot cast to a binary summary.")
+ }
+}
+
+/**
+ * Abstraction for training results.
+ */
+private[classification] trait TrainingSummary {
+
+ /** objective function (scaled loss + regularization) at each iteration. */
+ @Since("3.1.0")
+ def objectiveHistory: Array[Double]
+
+ /** Number of training iterations. */
+ @Since("3.1.0")
+ def totalIterations: Int = objectiveHistory.length
+}
+
+/**
+ * Abstraction for binary classification results for a given model.
+ */
+trait BinaryClassificationSummary extends ClassificationSummary {
Review comment:
Should this be `private[classification]`?
##########
File path:
mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
##########
@@ -1396,239 +1393,34 @@ object LogisticRegressionModel extends
MLReadable[LogisticRegressionModel] {
/**
* Abstraction for logistic regression results for a given model.
*/
-sealed trait LogisticRegressionSummary extends Serializable {
-
- /**
- * Dataframe output by the model's `transform` method.
- */
- @Since("1.5.0")
- def predictions: DataFrame
+sealed trait LogisticRegressionSummary extends ClassificationSummary {
/** Field in "predictions" which gives the probability of each class as a
vector. */
@Since("1.5.0")
def probabilityCol: String
- /** Field in "predictions" which gives the prediction of each class. */
- @Since("2.3.0")
- def predictionCol: String
-
- /** Field in "predictions" which gives the true label of each instance (if
available). */
- @Since("1.5.0")
- def labelCol: String
-
/** Field in "predictions" which gives the features of each instance as a
vector. */
@Since("1.6.0")
def featuresCol: String
-
- /** Field in "predictions" which gives the weight of each instance as a
vector. */
- @Since("3.1.0")
- def weightCol: String
-
- @transient private val multiclassMetrics = {
- if (predictions.schema.fieldNames.contains(weightCol)) {
- new MulticlassMetrics(
- predictions.select(
- col(predictionCol),
- col(labelCol).cast(DoubleType),
- checkNonNegativeWeight(col(weightCol).cast(DoubleType))).rdd.map {
- case Row(prediction: Double, label: Double, weight: Double) =>
(prediction, label, weight)
- })
- } else {
- new MulticlassMetrics(
- predictions.select(
- col(predictionCol),
- col(labelCol).cast(DoubleType),
- lit(1.0)).rdd.map {
- case Row(prediction: Double, label: Double, weight: Double) =>
(prediction, label, weight)
- })
- }
- }
-
- /**
- * Returns the sequence of labels in ascending order. This order matches the
order used
- * in metrics which are specified as arrays over labels, e.g.,
truePositiveRateByLabel.
- *
- * Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1},
However, if the
- * training set is missing a label, then all of the arrays over labels
- * (e.g., from truePositiveRateByLabel) will be of length numClasses-1
instead of the
- * expected numClasses.
- */
- @Since("2.3.0")
- def labels: Array[Double] = multiclassMetrics.labels
-
- /** Returns true positive rate for each label (category). */
- @Since("2.3.0")
- def truePositiveRateByLabel: Array[Double] = recallByLabel
-
- /** Returns false positive rate for each label (category). */
- @Since("2.3.0")
- def falsePositiveRateByLabel: Array[Double] = {
- multiclassMetrics.labels.map(label =>
multiclassMetrics.falsePositiveRate(label))
- }
-
- /** Returns precision for each label (category). */
- @Since("2.3.0")
- def precisionByLabel: Array[Double] = {
- multiclassMetrics.labels.map(label => multiclassMetrics.precision(label))
- }
-
- /** Returns recall for each label (category). */
- @Since("2.3.0")
- def recallByLabel: Array[Double] = {
- multiclassMetrics.labels.map(label => multiclassMetrics.recall(label))
- }
-
- /** Returns f-measure for each label (category). */
- @Since("2.3.0")
- def fMeasureByLabel(beta: Double): Array[Double] = {
- multiclassMetrics.labels.map(label => multiclassMetrics.fMeasure(label,
beta))
- }
-
- /** Returns f1-measure for each label (category). */
- @Since("2.3.0")
- def fMeasureByLabel: Array[Double] = fMeasureByLabel(1.0)
-
- /**
- * Returns accuracy.
- * (equals to the total number of correctly classified instances
- * out of the total number of instances.)
- */
- @Since("2.3.0")
- def accuracy: Double = multiclassMetrics.accuracy
-
- /**
- * Returns weighted true positive rate.
- * (equals to precision, recall and f-measure)
- */
- @Since("2.3.0")
- def weightedTruePositiveRate: Double = weightedRecall
-
- /** Returns weighted false positive rate. */
- @Since("2.3.0")
- def weightedFalsePositiveRate: Double =
multiclassMetrics.weightedFalsePositiveRate
-
- /**
- * Returns weighted averaged recall.
- * (equals to precision, recall and f-measure)
- */
- @Since("2.3.0")
- def weightedRecall: Double = multiclassMetrics.weightedRecall
-
- /** Returns weighted averaged precision. */
- @Since("2.3.0")
- def weightedPrecision: Double = multiclassMetrics.weightedPrecision
-
- /** Returns weighted averaged f-measure. */
- @Since("2.3.0")
- def weightedFMeasure(beta: Double): Double =
multiclassMetrics.weightedFMeasure(beta)
-
- /** Returns weighted averaged f1-measure. */
- @Since("2.3.0")
- def weightedFMeasure: Double = multiclassMetrics.weightedFMeasure(1.0)
-
- /**
- * Convenient method for casting to binary logistic regression summary.
- * This method will throw an Exception if the summary is not a binary
summary.
- */
- @Since("2.3.0")
- def asBinary: BinaryLogisticRegressionSummary = this match {
- case b: BinaryLogisticRegressionSummary => b
- case _ =>
- throw new RuntimeException("Cannot cast to a binary summary.")
- }
}
/**
* Abstraction for multiclass logistic regression training results.
- * Currently, the training summary ignores the training weights except
- * for the objective trace.
*/
-sealed trait LogisticRegressionTrainingSummary extends
LogisticRegressionSummary {
-
- /** objective function (scaled loss + regularization) at each iteration. */
- @Since("1.5.0")
- def objectiveHistory: Array[Double]
-
- /** Number of training iterations. */
- @Since("1.5.0")
- def totalIterations: Int = objectiveHistory.length
-
+sealed trait LogisticRegressionTrainingSummary extends
LogisticRegressionSummary
+ with TrainingSummary {
}
/**
* Abstraction for binary logistic regression results for a given model.
*/
-sealed trait BinaryLogisticRegressionSummary extends LogisticRegressionSummary
{
-
- private val sparkSession = predictions.sparkSession
- import sparkSession.implicits._
-
- // TODO: Allow the user to vary the number of bins using a setBins method in
- // BinaryClassificationMetrics. For now the default is set to 100.
- @transient private val binaryMetrics = if
(predictions.schema.fieldNames.contains(weightCol)) {
- new BinaryClassificationMetrics(
- predictions.select(col(probabilityCol), col(labelCol).cast(DoubleType),
- checkNonNegativeWeight(col(weightCol).cast(DoubleType))).rdd.map {
- case Row(score: Vector, label: Double, weight: Double) => (score(1),
label, weight)
- }, 100
- )
- } else {
- new BinaryClassificationMetrics(
- predictions.select(col(probabilityCol), col(labelCol).cast(DoubleType),
- lit(1.0)).rdd.map {
- case Row(score: Vector, label: Double, weight: Double) => (score(1),
label, weight)
- }, 100
- )
- }
-
- /**
- * Returns the receiver operating characteristic (ROC) curve,
- * which is a Dataframe having two fields (FPR, TPR)
- * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
- * See http://en.wikipedia.org/wiki/Receiver_operating_characteristic
- */
- @Since("1.5.0")
- @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR")
+sealed trait BinaryLogisticRegressionSummary extends LogisticRegressionSummary
+ with BinaryClassificationSummary {
- /**
- * Computes the area under the receiver operating characteristic (ROC) curve.
- */
- @Since("1.5.0")
- lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC()
-
- /**
- * Returns the precision-recall curve, which is a Dataframe containing
- * two fields recall, precision with (0.0, 1.0) prepended to it.
- */
- @Since("1.5.0")
- @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall",
"precision")
-
- /**
- * Returns a dataframe with two fields (threshold, F-Measure) curve with
beta = 1.0.
- */
- @Since("1.5.0")
- @transient lazy val fMeasureByThreshold: DataFrame = {
- binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure")
- }
-
- /**
- * Returns a dataframe with two fields (threshold, precision) curve.
- * Every possible probability obtained in transforming the dataset are used
- * as thresholds used in calculating the precision.
- */
- @Since("1.5.0")
- @transient lazy val precisionByThreshold: DataFrame = {
- binaryMetrics.precisionByThreshold().toDF("threshold", "precision")
- }
-
- /**
- * Returns a dataframe with two fields (threshold, recall) curve.
- * Every possible probability obtained in transforming the dataset are used
- * as thresholds used in calculating the recall.
- */
- @Since("1.5.0")
- @transient lazy val recallByThreshold: DataFrame = {
- binaryMetrics.recallByThreshold().toDF("threshold", "recall")
+ override def scoreCol: String = if (probabilityCol.nonEmpty) {
Review comment:
Should this be `private`?
##########
File path:
mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
##########
@@ -1396,239 +1393,34 @@ object LogisticRegressionModel extends
MLReadable[LogisticRegressionModel] {
/**
* Abstraction for logistic regression results for a given model.
*/
-sealed trait LogisticRegressionSummary extends Serializable {
-
- /**
- * Dataframe output by the model's `transform` method.
- */
- @Since("1.5.0")
- def predictions: DataFrame
+sealed trait LogisticRegressionSummary extends ClassificationSummary {
/** Field in "predictions" which gives the probability of each class as a
vector. */
@Since("1.5.0")
def probabilityCol: String
- /** Field in "predictions" which gives the prediction of each class. */
- @Since("2.3.0")
- def predictionCol: String
-
- /** Field in "predictions" which gives the true label of each instance (if
available). */
- @Since("1.5.0")
- def labelCol: String
-
/** Field in "predictions" which gives the features of each instance as a
vector. */
@Since("1.6.0")
def featuresCol: String
-
- /** Field in "predictions" which gives the weight of each instance as a
vector. */
- @Since("3.1.0")
- def weightCol: String
-
- @transient private val multiclassMetrics = {
- if (predictions.schema.fieldNames.contains(weightCol)) {
- new MulticlassMetrics(
- predictions.select(
- col(predictionCol),
- col(labelCol).cast(DoubleType),
- checkNonNegativeWeight(col(weightCol).cast(DoubleType))).rdd.map {
- case Row(prediction: Double, label: Double, weight: Double) =>
(prediction, label, weight)
- })
- } else {
- new MulticlassMetrics(
- predictions.select(
- col(predictionCol),
- col(labelCol).cast(DoubleType),
- lit(1.0)).rdd.map {
- case Row(prediction: Double, label: Double, weight: Double) =>
(prediction, label, weight)
- })
- }
- }
-
- /**
- * Returns the sequence of labels in ascending order. This order matches the
order used
- * in metrics which are specified as arrays over labels, e.g.,
truePositiveRateByLabel.
- *
- * Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1},
However, if the
- * training set is missing a label, then all of the arrays over labels
- * (e.g., from truePositiveRateByLabel) will be of length numClasses-1
instead of the
- * expected numClasses.
- */
- @Since("2.3.0")
- def labels: Array[Double] = multiclassMetrics.labels
-
- /** Returns true positive rate for each label (category). */
- @Since("2.3.0")
- def truePositiveRateByLabel: Array[Double] = recallByLabel
-
- /** Returns false positive rate for each label (category). */
- @Since("2.3.0")
- def falsePositiveRateByLabel: Array[Double] = {
- multiclassMetrics.labels.map(label =>
multiclassMetrics.falsePositiveRate(label))
- }
-
- /** Returns precision for each label (category). */
- @Since("2.3.0")
- def precisionByLabel: Array[Double] = {
- multiclassMetrics.labels.map(label => multiclassMetrics.precision(label))
- }
-
- /** Returns recall for each label (category). */
- @Since("2.3.0")
- def recallByLabel: Array[Double] = {
- multiclassMetrics.labels.map(label => multiclassMetrics.recall(label))
- }
-
- /** Returns f-measure for each label (category). */
- @Since("2.3.0")
- def fMeasureByLabel(beta: Double): Array[Double] = {
- multiclassMetrics.labels.map(label => multiclassMetrics.fMeasure(label,
beta))
- }
-
- /** Returns f1-measure for each label (category). */
- @Since("2.3.0")
- def fMeasureByLabel: Array[Double] = fMeasureByLabel(1.0)
-
- /**
- * Returns accuracy.
- * (equals to the total number of correctly classified instances
- * out of the total number of instances.)
- */
- @Since("2.3.0")
- def accuracy: Double = multiclassMetrics.accuracy
-
- /**
- * Returns weighted true positive rate.
- * (equals to precision, recall and f-measure)
- */
- @Since("2.3.0")
- def weightedTruePositiveRate: Double = weightedRecall
-
- /** Returns weighted false positive rate. */
- @Since("2.3.0")
- def weightedFalsePositiveRate: Double =
multiclassMetrics.weightedFalsePositiveRate
-
- /**
- * Returns weighted averaged recall.
- * (equals to precision, recall and f-measure)
- */
- @Since("2.3.0")
- def weightedRecall: Double = multiclassMetrics.weightedRecall
-
- /** Returns weighted averaged precision. */
- @Since("2.3.0")
- def weightedPrecision: Double = multiclassMetrics.weightedPrecision
-
- /** Returns weighted averaged f-measure. */
- @Since("2.3.0")
- def weightedFMeasure(beta: Double): Double =
multiclassMetrics.weightedFMeasure(beta)
-
- /** Returns weighted averaged f1-measure. */
- @Since("2.3.0")
- def weightedFMeasure: Double = multiclassMetrics.weightedFMeasure(1.0)
-
- /**
- * Convenient method for casting to binary logistic regression summary.
- * This method will throw an Exception if the summary is not a binary
summary.
- */
- @Since("2.3.0")
- def asBinary: BinaryLogisticRegressionSummary = this match {
- case b: BinaryLogisticRegressionSummary => b
- case _ =>
- throw new RuntimeException("Cannot cast to a binary summary.")
- }
}
/**
* Abstraction for multiclass logistic regression training results.
- * Currently, the training summary ignores the training weights except
- * for the objective trace.
*/
-sealed trait LogisticRegressionTrainingSummary extends
LogisticRegressionSummary {
-
- /** objective function (scaled loss + regularization) at each iteration. */
- @Since("1.5.0")
- def objectiveHistory: Array[Double]
-
- /** Number of training iterations. */
- @Since("1.5.0")
- def totalIterations: Int = objectiveHistory.length
-
+sealed trait LogisticRegressionTrainingSummary extends
LogisticRegressionSummary
+ with TrainingSummary {
}
/**
* Abstraction for binary logistic regression results for a given model.
*/
-sealed trait BinaryLogisticRegressionSummary extends LogisticRegressionSummary
{
-
- private val sparkSession = predictions.sparkSession
- import sparkSession.implicits._
-
- // TODO: Allow the user to vary the number of bins using a setBins method in
- // BinaryClassificationMetrics. For now the default is set to 100.
- @transient private val binaryMetrics = if
(predictions.schema.fieldNames.contains(weightCol)) {
- new BinaryClassificationMetrics(
- predictions.select(col(probabilityCol), col(labelCol).cast(DoubleType),
- checkNonNegativeWeight(col(weightCol).cast(DoubleType))).rdd.map {
- case Row(score: Vector, label: Double, weight: Double) => (score(1),
label, weight)
- }, 100
- )
- } else {
- new BinaryClassificationMetrics(
- predictions.select(col(probabilityCol), col(labelCol).cast(DoubleType),
- lit(1.0)).rdd.map {
- case Row(score: Vector, label: Double, weight: Double) => (score(1),
label, weight)
- }, 100
- )
- }
-
- /**
- * Returns the receiver operating characteristic (ROC) curve,
- * which is a Dataframe having two fields (FPR, TPR)
- * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
- * See http://en.wikipedia.org/wiki/Receiver_operating_characteristic
- */
- @Since("1.5.0")
- @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR")
+sealed trait BinaryLogisticRegressionSummary extends LogisticRegressionSummary
+ with BinaryClassificationSummary {
- /**
- * Computes the area under the receiver operating characteristic (ROC) curve.
- */
- @Since("1.5.0")
- lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC()
-
- /**
- * Returns the precision-recall curve, which is a Dataframe containing
- * two fields recall, precision with (0.0, 1.0) prepended to it.
- */
- @Since("1.5.0")
- @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall",
"precision")
-
- /**
- * Returns a dataframe with two fields (threshold, F-Measure) curve with
beta = 1.0.
- */
- @Since("1.5.0")
- @transient lazy val fMeasureByThreshold: DataFrame = {
- binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure")
- }
-
- /**
- * Returns a dataframe with two fields (threshold, precision) curve.
- * Every possible probability obtained in transforming the dataset are used
- * as thresholds used in calculating the precision.
- */
- @Since("1.5.0")
- @transient lazy val precisionByThreshold: DataFrame = {
- binaryMetrics.precisionByThreshold().toDF("threshold", "precision")
- }
-
- /**
- * Returns a dataframe with two fields (threshold, recall) curve.
- * Every possible probability obtained in transforming the dataset are used
- * as thresholds used in calculating the recall.
- */
- @Since("1.5.0")
- @transient lazy val recallByThreshold: DataFrame = {
- binaryMetrics.recallByThreshold().toDF("threshold", "recall")
+ override def scoreCol: String = if (probabilityCol.nonEmpty) {
Review comment:
If `probabilityCol.isEmpty`, should we use `rawPredictionCol` instead?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]