Repository: spark Updated Branches: refs/heads/master 166f34618 -> 57d994994
[SPARK-24557][ML] ClusteringEvaluator support array input ## What changes were proposed in this pull request? ClusteringEvaluator now supports array-typed input for the features column, in addition to Vector columns. ## How was this patch tested? Added tests to ClusteringEvaluatorSuite. Author: zhengruifeng <ruife...@foxmail.com> Closes #21563 from zhengruifeng/clu_eval_support_array. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/57d99499 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/57d99499 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/57d99499 Branch: refs/heads/master Commit: 57d994994d27154f57f2724924c42beb2ab2e0e7 Parents: 166f346 Author: zhengruifeng <ruife...@foxmail.com> Authored: Wed Aug 1 23:46:01 2018 -0700 Committer: Xiangrui Meng <m...@databricks.com> Committed: Wed Aug 1 23:46:01 2018 -0700 ---------------------------------------------------------------------- .../spark/ml/evaluation/ClusteringEvaluator.scala | 15 +++++++++------ .../ml/evaluation/ClusteringEvaluatorSuite.scala | 15 ++++++++++++++- 2 files changed, 23 insertions(+), 7 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/57d99499/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala index 4353c46..a6d6b4e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala @@ -21,11 +21,10 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.attribute.AttributeGroup -import org.apache.spark.ml.linalg.{BLAS, DenseVector, 
SparseVector, Vector, Vectors, VectorUDT} +import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} -import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, - SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.sql.{Column, DataFrame, Dataset} import org.apache.spark.sql.functions.{avg, col, udf} import org.apache.spark.sql.types.DoubleType @@ -107,15 +106,19 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str @Since("2.3.0") override def evaluate(dataset: Dataset[_]): Double = { - SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(dataset.schema, $(featuresCol)) SchemaUtils.checkNumericType(dataset.schema, $(predictionCol)) + val vectorCol = DatasetUtils.columnToVector(dataset, $(featuresCol)) + val df = dataset.select(col($(predictionCol)), + vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata)) + ($(metricName), $(distanceMeasure)) match { case ("silhouette", "squaredEuclidean") => SquaredEuclideanSilhouette.computeSilhouetteScore( - dataset, $(predictionCol), $(featuresCol)) + df, $(predictionCol), $(featuresCol)) case ("silhouette", "cosine") => - CosineSilhouette.computeSilhouetteScore(dataset, $(predictionCol), $(featuresCol)) + CosineSilhouette.computeSilhouetteScore(df, $(predictionCol), $(featuresCol)) } } } http://git-wip-us.apache.org/repos/asf/spark/blob/57d99499/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala index 
2c175ff..e2d7756 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -33,10 +33,17 @@ class ClusteringEvaluatorSuite import testImplicits._ @transient var irisDataset: Dataset[_] = _ + @transient var newIrisDataset: Dataset[_] = _ + @transient var newIrisDatasetD: Dataset[_] = _ + @transient var newIrisDatasetF: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() irisDataset = spark.read.format("libsvm").load("../data/mllib/iris_libsvm.txt") + val datasets = MLTestingUtils.generateArrayFeatureDataset(irisDataset) + newIrisDataset = datasets._1 + newIrisDatasetD = datasets._2 + newIrisDatasetF = datasets._3 } test("params") { @@ -66,6 +73,9 @@ class ClusteringEvaluatorSuite .setPredictionCol("label") assert(evaluator.evaluate(irisDataset) ~== 0.6564679231 relTol 1e-5) + assert(evaluator.evaluate(newIrisDataset) ~== 0.6564679231 relTol 1e-5) + assert(evaluator.evaluate(newIrisDatasetD) ~== 0.6564679231 relTol 1e-5) + assert(evaluator.evaluate(newIrisDatasetF) ~== 0.6564679231 relTol 1e-5) } /* @@ -85,6 +95,9 @@ class ClusteringEvaluatorSuite .setDistanceMeasure("cosine") assert(evaluator.evaluate(irisDataset) ~== 0.7222369298 relTol 1e-5) + assert(evaluator.evaluate(newIrisDataset) ~== 0.7222369298 relTol 1e-5) + assert(evaluator.evaluate(newIrisDatasetD) ~== 0.7222369298 relTol 1e-5) + assert(evaluator.evaluate(newIrisDatasetF) ~== 0.7222369298 
relTol 1e-5) } test("number of clusters must be greater than one") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org