Github user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/19204#discussion_r138763970
--- Diff: python/pyspark/ml/evaluation.py ---
@@ -328,6 +329,87 @@ def setParams(self, predictionCol="prediction",
labelCol="label",
kwargs = self._input_kwargs
return self._set(**kwargs)
+
+@inherit_doc
+class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol,
+ JavaMLReadable, JavaMLWritable):
+ """
+ .. note:: Experimental
+
+ Evaluator for Clustering results, which expects two input
+ columns: prediction and features.
+
+ >>> from sklearn import datasets
+ >>> from pyspark.sql.types import *
+ >>> from pyspark.ml.linalg import Vectors, VectorUDT
+ >>> from pyspark.ml.evaluation import ClusteringEvaluator
+ ...
+ >>> iris = datasets.load_iris()
+ >>> iris_rows = [(Vectors.dense(x), int(iris.target[i]))
+ ... for i, x in enumerate(iris.data)]
+ >>> schema = StructType([
+ ... StructField("features", VectorUDT(), True),
+ ... StructField("cluster_id", IntegerType(), True)])
+ >>> rdd = spark.sparkContext.parallelize(iris_rows)
+ >>> dataset = spark.createDataFrame(rdd, schema)
+ ...
+ >>> evaluator = ClusteringEvaluator(predictionCol="cluster_id")
+ >>> evaluator.evaluate(dataset)
+ 0.656...
+ >>> ce_path = temp_path + "/ce"
+ >>> evaluator.save(ce_path)
+ >>> evaluator2 = ClusteringEvaluator.load(ce_path)
+ >>> str(evaluator2.getPredictionCol())
+ 'cluster_id'
+
+ .. versionadded:: 2.3.0
+ """
+ metricName = Param(Params._dummy(), "metricName",
+ "metric name in evaluation "
+ "(silhouette)",
--- End diff --
The string spans multiple lines, so we should use """ instead of "". Otherwise,
move the pieces onto the same line.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]