Github user WeichenXu123 commented on a diff in the pull request: https://github.com/apache/spark/pull/16158#discussion_r137542479 --- Diff: mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala --- @@ -85,6 +86,32 @@ private[ml] trait ValidatorParams extends HasSeed with Params { instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName) instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length) } + + + /** + * Summary of grid search tuning in the format of DataFrame. Each row contains one candidate + * paramMap and the corresponding metric of trained model. + */ + protected def getTuningSummaryDF(metrics: Array[Double]): DataFrame = { + val params = $(estimatorParamMaps) + require(params.nonEmpty, "estimator param maps should not be empty") + require(params.length == metrics.length, "estimator param maps number should match metrics") + val metricName = $(evaluator) match { + case b: BinaryClassificationEvaluator => b.getMetricName + case m: MulticlassClassificationEvaluator => m.getMetricName + case r: RegressionEvaluator => r.getMetricName + case _ => "metrics" + } + val spark = SparkSession.builder().getOrCreate() + val sc = spark.sparkContext + val fields = params(0).toSeq.sortBy(_.param.name).map(_.param.name) ++ Seq(metricName) + val schema = new StructType(fields.map(name => StructField(name, StringType)).toArray) + val rows = sc.parallelize(params.zip(metrics)).map { case (param, metric) => + val values = param.toSeq.sortBy(_.param.name).map(_.value.toString) ++ Seq(metric.toString) + Row.fromSeq(values) + } --- End diff -- The variable names here are a little confusing; renaming `params` ==> `paramMaps` and `case (param, metric)` ==> `case (paramMap, metric)` would be clearer.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org