huaxingao commented on pull request #28710:
URL: https://github.com/apache/spark/pull/28710#issuecomment-642907472
@zhengruifeng Thanks for your comments.
I think the traits should look like this (weightCol etc. omitted for
simplicity):
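```
import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

trait ClassificationSummary {
  def predictions: DataFrame
  def predictionCol: String
  def labelCol: String

  val multiclassMetrics =
    new MulticlassMetrics(
      predictions.select(col(predictionCol), col(labelCol).cast(DoubleType))
      // ...
    )
}

trait BinaryClassificationSummary extends ClassificationSummary {
  def rawPredictionCol: String

  val binaryMetrics =
    new BinaryClassificationMetrics(
      predictions.select(col(rawPredictionCol), col(labelCol).cast(DoubleType))
      // ...
    )
}
```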
However, BinaryLogisticRegressionSummary currently uses probabilityCol for
BinaryClassificationMetrics, and that probabilityCol is defined in
LogisticRegressionSummary rather than in BinaryLogisticRegressionSummary. To
avoid breaking existing code, I need to make two changes to the traits above:
1. Change rawPredictionCol to scoreCol. I can't use rawPredictionCol, since
LogisticRegression currently uses probabilityCol, and I can't use
probabilityCol, since LinearSVC doesn't have a probabilityCol.
2. Put scoreCol in ClassificationSummary, since probabilityCol currently lives
in LogisticRegressionSummary rather than BinaryLogisticRegressionSummary.
That's how I arrived at the current traits:
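```
import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

trait ClassificationSummary {
  def predictions: DataFrame
  // rawPredictionCol or probabilityCol, depending on the classifier
  def scoreCol: String
  def predictionCol: String
  def labelCol: String

  val multiclassMetrics =
    new MulticlassMetrics(
      predictions.select(col(predictionCol), col(labelCol).cast(DoubleType))
      // ...
    )
}

trait BinaryClassificationSummary extends ClassificationSummary {
  val binaryMetrics =
    new BinaryClassificationMetrics(
      predictions.select(col(scoreCol), col(labelCol).cast(DoubleType))
      // ...
    )
}
```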
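The following is a minimal Spark-free sketch of this layering that can be run without a cluster. It rests on several assumptions. A hypothetical Row case class stands in for a DataFrame. accuracy and a rank-based areaUnderROC are computed directly, as stand-ins for MulticlassMetrics and BinaryClassificationMetrics. Element 1 of the score vector is taken as the positive-class score. None of these names are Spark's API. The point it illustrates is that BinaryClassificationSummary only ever reads scoreCol, so each concrete summary decides which column backs it.

```scala
// Hypothetical stand-in for a DataFrame row: scalar and vector columns looked up by name.
final case class Row(scalars: Map[String, Double], vectors: Map[String, Vector[Double]])

trait ClassificationSummary {
  def predictions: Seq[Row]
  def scoreCol: String // e.g. "rawPrediction" for LinearSVC, "probability" for LR
  def predictionCol: String
  def labelCol: String

  // Stand-in for MulticlassMetrics: fraction of rows whose prediction equals the label.
  lazy val accuracy: Double =
    predictions.count(r => r.scalars(predictionCol) == r.scalars(labelCol)).toDouble /
      predictions.size
}

trait BinaryClassificationSummary extends ClassificationSummary {
  // Stand-in for BinaryClassificationMetrics.areaUnderROC: the probability that a random
  // positive row scores higher than a random negative row (ties count as 1/2).
  lazy val areaUnderROC: Double = {
    // Assumption: element 1 of the score vector is the positive-class score.
    val scored = predictions.map(r => (r.vectors(scoreCol)(1), r.scalars(labelCol)))
    val pos = scored.collect { case (s, 1.0) => s }
    val neg = scored.collect { case (s, 0.0) => s }
    val pairs = for (p <- pos; n <- neg) yield
      if (p > n) 1.0 else if (p == n) 0.5 else 0.0
    pairs.sum / (pos.size * neg.size)
  }
}

val rows = Seq(
  Row(Map("prediction" -> 1.0, "label" -> 1.0), Map("probability" -> Vector(0.2, 0.8))),
  Row(Map("prediction" -> 0.0, "label" -> 0.0), Map("probability" -> Vector(0.7, 0.3))),
  Row(Map("prediction" -> 1.0, "label" -> 0.0), Map("probability" -> Vector(0.4, 0.6))),
  Row(Map("prediction" -> 0.0, "label" -> 1.0), Map("probability" -> Vector(0.6, 0.4)))
)

// A logistic-regression-like summary whose score column is "probability".
val lrSummary = new BinaryClassificationSummary {
  val predictions = rows
  val scoreCol = "probability"
  val predictionCol = "prediction"
  val labelCol = "label"
}

println(s"accuracy = ${lrSummary.accuracy}, areaUnderROC = ${lrSummary.areaUnderROC}")
```

The metrics are lazy vals in this sketch for a reason. A strict val in a trait is initialized before vals defined in the body of the implementing class. With a plain val, the anonymous subclass above would hand null column names to the metric code.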
To implement summaries for other classifiers:
```
trait LinearSVCSummary extends BinaryClassificationSummary    // scoreCol = rawPredictionCol
trait FMClassifierSummary extends BinaryClassificationSummary // scoreCol = probabilityCol
```
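Here is a small sketch of the scoreCol indirection. Each concrete summary maps scoreCol onto whichever column its classifier actually produces, and code written against the trait never needs to know which one it is. The class names follow the proposal above. The constructors and the describe helper are hypothetical, and this is not Spark's implementation.

```scala
// Minimal, hypothetical stand-in for the proposed trait: only the column wiring is modeled.
trait BinaryClassificationSummary {
  def scoreCol: String
  def labelCol: String
}

// LinearSVC has no probability column, so its score is taken from rawPrediction.
class LinearSVCSummary(rawPredictionCol: String, val labelCol: String)
    extends BinaryClassificationSummary {
  override def scoreCol: String = rawPredictionCol
}

// FMClassifier outputs probabilities, so its score is taken from probability.
class FMClassifierSummary(probabilityCol: String, val labelCol: String)
    extends BinaryClassificationSummary {
  override def scoreCol: String = probabilityCol
}

// Generic code sees only scoreCol, whatever column backs it.
def describe(s: BinaryClassificationSummary): String =
  s"score=${s.scoreCol}, label=${s.labelCol}"

val svcSummary = new LinearSVCSummary("rawPrediction", "label")
val fmSummary = new FMClassifierSummary("probability", "label")

println(describe(svcSummary)) // prints: score=rawPrediction, label=label
println(describe(fmSummary))  // prints: score=probability, label=label
```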
For RandomForestClassifier (and likewise DecisionTreeClassifier and GBTClassifier):
```
trait RandomForestSummary extends ClassificationSummary
trait BinaryRandomForestSummary extends BinaryClassificationSummary // scoreCol = probabilityCol

// pick the summary type based on the number of classes
if (numOfClass == 2)
  summary = BinaryRandomForestSummary
else
  summary = RandomForestSummary
```
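The class-count dispatch could be sketched as follows. This is a simplified, hypothetical model: makeSummary and the case-class summaries are illustrative names only, and only the fields the dispatch needs are modeled. The binary variant is returned only for two-class problems, so binary-only metrics such as area under ROC are exposed only where they are meaningful.

```scala
// Hypothetical, simplified summary types: only what the dispatch needs is modeled.
trait ClassificationSummary { def numClasses: Int }
trait BinaryClassificationSummary extends ClassificationSummary { def scoreCol: String }

final case class RandomForestSummary(numClasses: Int) extends ClassificationSummary

// The binary variant always has two classes and scores rows by the probability column.
final case class BinaryRandomForestSummary(scoreCol: String)
    extends BinaryClassificationSummary {
  val numClasses: Int = 2
}

// Choose the binary summary only when there are exactly two classes.
def makeSummary(numClasses: Int, probabilityCol: String): ClassificationSummary =
  if (numClasses == 2) BinaryRandomForestSummary(scoreCol = probabilityCol)
  else RandomForestSummary(numClasses)

makeSummary(2, "probability") match {
  case b: BinaryClassificationSummary => println(s"binary summary, scoreCol = ${b.scoreCol}")
  case s => println(s"multiclass summary with ${s.numClasses} classes")
}
// prints: binary summary, scoreCol = probability
```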