huaxingao commented on pull request #28710:
URL: https://github.com/apache/spark/pull/28710#issuecomment-642907472


   @zhengruifeng Thanks for your comments.
   I think the traits should look like this (leaving out weightCol etc. for simplicity):
   ```
   trait ClassificationSummary {
     def predictions: DataFrame
     def predictionCol: String
     def labelCol: String
     val multiclassMetrics =
       new MulticlassMetrics(
         predictions.select(col(predictionCol), col(labelCol).cast(DoubleType))
           /* ... */
       )
   }

   trait BinaryClassificationSummary extends ClassificationSummary {
     def rawPredictionCol: String
     val binaryMetrics =
       new BinaryClassificationMetrics(
         predictions.select(col(rawPredictionCol), col(labelCol).cast(DoubleType))
           /* ... */
       )
   }
   ```
   However, BinaryLogisticRegressionSummary currently uses probabilityCol for BinaryClassificationMetrics, and this probabilityCol is defined in LogisticRegressionSummary rather than in BinaryLogisticRegressionSummary. To avoid breaking existing code, I need to make two changes to the above traits:
   1. Rename rawPredictionCol to scoreCol. I can't use rawPredictionCol, because LogisticRegression currently uses probabilityCol, and I can't use probabilityCol, because LinearSVC doesn't have a probabilityCol.
   2. Put scoreCol in ClassificationSummary, because probabilityCol currently lives in LogisticRegressionSummary rather than in BinaryLogisticRegressionSummary.
   That's how I arrived at the current traits:
   ```
   trait ClassificationSummary {
     def predictions: DataFrame
     def scoreCol: String
     def predictionCol: String
     def labelCol: String
     val multiclassMetrics =
       new MulticlassMetrics(
         predictions.select(col(predictionCol), col(labelCol).cast(DoubleType))
           /* ... */
       )
   }

   trait BinaryClassificationSummary extends ClassificationSummary {
     val binaryMetrics =
       new BinaryClassificationMetrics(
         predictions.select(col(scoreCol), col(labelCol).cast(DoubleType))
           /* ... */
       )
   }
   ```
   To implement summaries for other classifiers:
   ```
   trait LinearSVCSummary extends BinaryClassificationSummary     // uses rawPredictionCol as scoreCol
   trait FMClassifierSummary extends BinaryClassificationSummary  // uses probabilityCol as scoreCol
   ```
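   A minimal sketch, in plain Scala without Spark, of why the neutral `scoreCol` name works: `BinaryClassificationSummary` only reads whatever column `scoreCol` names, and each concrete summary decides which column that is. The `Row` case class stands in for a `DataFrame` row, and `scoreAndLabels`, `predictionAndLabels`, `sampleRows`, and the `*Impl` classes are hypothetical names for illustration. Taking element 1 of the score vector as the positive-class score is an assumption here, as is usual for a two-element raw-prediction or probability vector.

```scala
object ScoreColSketch {
  // Hypothetical stand-in for one DataFrame row: scalar and vector columns, looked up by name.
  final case class Row(scalars: Map[String, Double], vectors: Map[String, Vector[Double]])

  trait ClassificationSummary {
    def predictions: Seq[Row]
    def scoreCol: String
    def predictionCol: String
    def labelCol: String

    // (prediction, label) pairs: the shape MulticlassMetrics consumes.
    lazy val predictionAndLabels: Seq[(Double, Double)] =
      predictions.map(r => (r.scalars(predictionCol), r.scalars(labelCol)))
  }

  trait BinaryClassificationSummary extends ClassificationSummary {
    // (score, label) pairs: the shape BinaryClassificationMetrics consumes.
    // The score is element 1 (assumed positive class) of whatever vector column scoreCol names.
    lazy val scoreAndLabels: Seq[(Double, Double)] =
      predictions.map(r => (r.vectors(scoreCol)(1), r.scalars(labelCol)))
  }

  // LinearSVC has no probability column, so it binds scoreCol to rawPrediction.
  final class LinearSVCSummaryImpl(val predictions: Seq[Row]) extends BinaryClassificationSummary {
    val scoreCol = "rawPrediction"
    val predictionCol = "prediction"
    val labelCol = "label"
  }

  // FMClassifier has a probability column and binds scoreCol to it.
  final class FMClassifierSummaryImpl(val predictions: Seq[Row]) extends BinaryClassificationSummary {
    val scoreCol = "probability"
    val predictionCol = "prediction"
    val labelCol = "label"
  }

  val sampleRows: Seq[Row] = Seq(
    Row(Map("prediction" -> 1.0, "label" -> 1.0),
        Map("rawPrediction" -> Vector(-2.0, 2.0), "probability" -> Vector(0.1, 0.9))),
    Row(Map("prediction" -> 0.0, "label" -> 0.0),
        Map("rawPrediction" -> Vector(1.5, -1.5), "probability" -> Vector(0.8, 0.2)))
  )

  def main(args: Array[String]): Unit = {
    println(new LinearSVCSummaryImpl(sampleRows).scoreAndLabels)   // List((2.0,1.0), (-1.5,0.0))
    println(new FMClassifierSummaryImpl(sampleRows).scoreAndLabels) // List((0.9,1.0), (0.2,0.0))
  }
}
```

   Only the column binding differs between the two classes; the metric-building code in the shared trait stays the same, which is what renaming rawPredictionCol to scoreCol buys.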
   For RandomForestClassifier (and likewise DecisionTreeClassifier and GBTClassifier):
   ```
   trait RandomForestSummary extends ClassificationSummary
   trait BinaryRandomForestSummary extends BinaryClassificationSummary  // uses probabilityCol as scoreCol

   if (numOfClass == 2)
     summary = BinaryRandomForestSummary
   else
     summary = RandomForestSummary
   ```
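   The `numOfClass == 2` branch can be sketched the same way, again in plain Scala with hypothetical names (`summaryFor`, `describe`, `labels`, and the `*Impl` classes). Callers receive a `ClassificationSummary` and pattern-match to reach the binary-only members, which exist only when the model has two classes.

```scala
object RandomForestSummarySketch {
  trait ClassificationSummary {
    def labels: Seq[Double] // hypothetical: the class labels the summary covers
  }

  trait BinaryClassificationSummary extends ClassificationSummary {
    def scoreCol: String
  }

  trait RandomForestSummary extends ClassificationSummary
  trait BinaryRandomForestSummary extends BinaryClassificationSummary

  private final class RandomForestSummaryImpl(val labels: Seq[Double]) extends RandomForestSummary

  private final class BinaryRandomForestSummaryImpl(val labels: Seq[Double])
      extends BinaryRandomForestSummary {
    val scoreCol = "probability" // the binary random forest summary scores with probabilityCol
  }

  // Build the binary summary only for two-class models.
  def summaryFor(numOfClass: Int): ClassificationSummary = {
    val labels = (0 until numOfClass).map(_.toDouble)
    if (numOfClass == 2) new BinaryRandomForestSummaryImpl(labels)
    else new RandomForestSummaryImpl(labels)
  }

  // Binary-only members are reachable only through a pattern match.
  def describe(s: ClassificationSummary): String = s match {
    case b: BinaryClassificationSummary => s"binary summary scored on ${b.scoreCol}"
    case other => s"multiclass summary over ${other.labels.size} labels"
  }

  def main(args: Array[String]): Unit = {
    println(describe(summaryFor(2))) // binary summary scored on probability
    println(describe(summaryFor(3))) // multiclass summary over 3 labels
  }
}
```

   Note that in this layout BinaryRandomForestSummary is not a subtype of RandomForestSummary, so code that needs the binary metrics has to match on BinaryClassificationSummary (or BinaryRandomForestSummary) rather than on the random-forest trait.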
   

