Github user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/16158#discussion_r137542479
--- Diff:
mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala ---
@@ -85,6 +86,32 @@ private[ml] trait ValidatorParams extends HasSeed with
Params {
instrumentation.logNamedValue("evaluator",
$(evaluator).getClass.getCanonicalName)
instrumentation.logNamedValue("estimatorParamMapsLength",
$(estimatorParamMaps).length)
}
+
+
+ /**
+ * Summary of grid search tuning in the format of DataFrame. Each row
contains one candidate
+ * paramMap and the corresponding metric of trained model.
+ */
+ protected def getTuningSummaryDF(metrics: Array[Double]): DataFrame = {
+ val params = $(estimatorParamMaps)
+ require(params.nonEmpty, "estimator param maps should not be empty")
+ require(params.length == metrics.length, "estimator param maps number
should match metrics")
+ val metricName = $(evaluator) match {
+ case b: BinaryClassificationEvaluator => b.getMetricName
+ case m: MulticlassClassificationEvaluator => m.getMetricName
+ case r: RegressionEvaluator => r.getMetricName
+ case _ => "metrics"
+ }
+ val spark = SparkSession.builder().getOrCreate()
+ val sc = spark.sparkContext
+ val fields = params(0).toSeq.sortBy(_.param.name).map(_.param.name) ++
Seq(metricName)
+ val schema = new StructType(fields.map(name => StructField(name,
StringType)).toArray)
+ val rows = sc.parallelize(params.zip(metrics)).map { case (param,
metric) =>
+ val values = param.toSeq.sortBy(_.param.name).map(_.value.toString)
++ Seq(metric.toString)
+ Row.fromSeq(values)
+ }
--- End diff --
The variable names here are a little confusing. Renaming them as follows
would make the code clearer:
`params` ==> `paramMaps`
`case (param, metric)` ==> `case (paramMap, metric)`
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]