GitHub user WeichenXu123 commented on a diff in the pull request:

    https://github.com/apache/spark/pull/16158#discussion_r137545402
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala ---
    @@ -85,6 +86,32 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
         instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName)
         instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length)
       }
    +
    +
    +  /**
    +   * Summary of grid search tuning in the format of DataFrame. Each row contains one candidate
    +   * paramMap and the corresponding metric of trained model.
    +   */
    +  protected def getTuningSummaryDF(metrics: Array[Double]): DataFrame = {
    +    val params = $(estimatorParamMaps)
    +    require(params.nonEmpty, "estimator param maps should not be empty")
    +    require(params.length == metrics.length, "estimator param maps number should match metrics")
    +    val metricName = $(evaluator) match {
    +      case b: BinaryClassificationEvaluator => b.getMetricName
    +      case m: MulticlassClassificationEvaluator => m.getMetricName
    +      case r: RegressionEvaluator => r.getMetricName
    +      case _ => "metrics"
    +    }
    +    val spark = SparkSession.builder().getOrCreate()
    +    val sc = spark.sparkContext
    +    val fields = params(0).toSeq.sortBy(_.param.name).map(_.param.name) ++ Seq(metricName)
    +    val schema = new StructType(fields.map(name => StructField(name, StringType)).toArray)
    +    val rows = sc.parallelize(params.zip(metrics)).map { case (param, metric) =>
    +      val values = param.toSeq.sortBy(_.param.name).map(_.value.toString) ++ Seq(metric.toString)
    --- End diff --
    
    There seems to be a problem here:
    Suppose `params(0)` (a `ParamMap`) contains ParamA and ParamB,
    and `params(1)` (also a `ParamMap`) contains ParamA and ParamC.
    The code here will then run into problems, because the schema is built from `params(0)` only, while each row's values are sorted by that row's own param names, with no check that every row has exactly the same params as the first row.
    I think a better way is to go through the whole `ParamMap` list, collect all params used, and sort them by name to form the DataFrame schema.
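
The suggestion above could be sketched roughly as follows. This is only an illustration, not the code in the PR: to keep it runnable without Spark, each `ParamMap` is modelled as a plain `Map[String, Any]` from param name to value, and the names `TuningSummarySketch`, `SimpleParamMap`, `collectColumns` and `buildRows` are hypothetical. Params that a candidate does not set are represented as `None`; in a DataFrame version they would presumably become nulls in nullable string columns.

```scala
// Sketch only: SimpleParamMap is a hypothetical stand-in for Spark's ParamMap,
// mapping param name -> value, so the idea can run without Spark.
object TuningSummarySketch {
  type SimpleParamMap = Map[String, Any]

  /** Union of the param names used by any candidate, sorted by name. */
  def collectColumns(paramMaps: Seq[SimpleParamMap]): Seq[String] =
    paramMaps.flatMap(_.keys).distinct.sorted

  /**
   * One row per candidate, aligned to the shared column order.
   * A param the candidate does not set becomes None; the metric is the last cell.
   */
  def buildRows(
      paramMaps: Seq[SimpleParamMap],
      metrics: Seq[Double]): Seq[Seq[Option[String]]] = {
    require(paramMaps.nonEmpty, "estimator param maps should not be empty")
    require(paramMaps.length == metrics.length,
      "estimator param maps number should match metrics")
    val columns = collectColumns(paramMaps)
    paramMaps.zip(metrics).map { case (paramMap, metric) =>
      // Look each column up by name instead of relying on per-row sort order.
      columns.map(name => paramMap.get(name).map(_.toString)) :+ Some(metric.toString)
    }
  }
}
```

With the ParamA/ParamB and ParamA/ParamC case described above, every row has the same three param columns plus the metric, and the cell for the param a candidate lacks is empty rather than shifted into the wrong column.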

