Repository: spark
Updated Branches:
  refs/heads/master fedf6961b -> 5ac96854c
[SPARK-21981][PYTHON][ML] Added Python interface for ClusteringEvaluator ## What changes were proposed in this pull request? Added Python interface for ClusteringEvaluator ## How was this patch tested? Manual test, eg. the example Python code in the comments. cc yanboliang Author: Marco Gaido <[email protected]> Author: Marco Gaido <[email protected]> Closes #19204 from mgaido91/SPARK-21981. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5ac96854 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5ac96854 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5ac96854 Branch: refs/heads/master Commit: 5ac96854cc6186fa2dad602d0906ff2705e3f610 Parents: fedf696 Author: Marco Gaido <[email protected]> Authored: Fri Sep 22 13:12:33 2017 +0800 Committer: Yanbo Liang <[email protected]> Committed: Fri Sep 22 13:12:33 2017 +0800 ---------------------------------------------------------------------- python/pyspark/ml/evaluation.py | 76 +++++++++++++++++++++++++++++++++++- 1 file changed, 74 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/5ac96854/python/pyspark/ml/evaluation.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 09cdf9b..aa8dbe7 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -20,12 +20,13 @@ from abc import abstractmethod, ABCMeta from pyspark import since, keyword_only from pyspark.ml.wrapper import JavaParams from pyspark.ml.param import Param, Params, TypeConverters -from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol +from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol, \ + HasFeaturesCol from pyspark.ml.common import inherit_doc from pyspark.ml.util import 
JavaMLReadable, JavaMLWritable __all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator', - 'MulticlassClassificationEvaluator'] + 'MulticlassClassificationEvaluator', 'ClusteringEvaluator'] @inherit_doc @@ -325,6 +326,77 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio kwargs = self._input_kwargs return self._set(**kwargs) + +@inherit_doc +class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, + JavaMLReadable, JavaMLWritable): + """ + .. note:: Experimental + + Evaluator for Clustering results, which expects two input + columns: prediction and features. + + >>> from pyspark.ml.linalg import Vectors + >>> featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]), + ... [([0.0, 0.5], 0.0), ([0.5, 0.0], 0.0), ([10.0, 11.0], 1.0), + ... ([10.5, 11.5], 1.0), ([1.0, 1.0], 0.0), ([8.0, 6.0], 1.0)]) + >>> dataset = spark.createDataFrame(featureAndPredictions, ["features", "prediction"]) + ... + >>> evaluator = ClusteringEvaluator(predictionCol="prediction") + >>> evaluator.evaluate(dataset) + 0.9079... + >>> ce_path = temp_path + "/ce" + >>> evaluator.save(ce_path) + >>> evaluator2 = ClusteringEvaluator.load(ce_path) + >>> str(evaluator2.getPredictionCol()) + 'prediction' + + .. 
versionadded:: 2.3.0 + """ + metricName = Param(Params._dummy(), "metricName", + "metric name in evaluation (silhouette)", + typeConverter=TypeConverters.toString) + + @keyword_only + def __init__(self, predictionCol="prediction", featuresCol="features", + metricName="silhouette"): + """ + __init__(self, predictionCol="prediction", featuresCol="features", \ + metricName="silhouette") + """ + super(ClusteringEvaluator, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.evaluation.ClusteringEvaluator", self.uid) + self._setDefault(metricName="silhouette") + kwargs = self._input_kwargs + self._set(**kwargs) + + @since("2.3.0") + def setMetricName(self, value): + """ + Sets the value of :py:attr:`metricName`. + """ + return self._set(metricName=value) + + @since("2.3.0") + def getMetricName(self): + """ + Gets the value of metricName or its default value. + """ + return self.getOrDefault(self.metricName) + + @keyword_only + @since("2.3.0") + def setParams(self, predictionCol="prediction", featuresCol="features", + metricName="silhouette"): + """ + setParams(self, predictionCol="prediction", featuresCol="features", \ + metricName="silhouette") + Sets params for clustering evaluator. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + if __name__ == "__main__": import doctest import tempfile --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
