Repository: spark
Updated Branches:
  refs/heads/master 166f34618 -> 57d994994


[SPARK-24557][ML] ClusteringEvaluator: support array input

## What changes were proposed in this pull request?
Add support for array input to `ClusteringEvaluator`: in addition to `Vector` feature columns, `evaluate` now accepts `Array[Double]` and `Array[Float]` columns, converting them to vectors internally before computing the silhouette score.
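For illustration, a minimal sketch of the new behavior (the data and column names here are made up, and an active `SparkSession` named `spark` is assumed):

```scala
import org.apache.spark.ml.evaluation.ClusteringEvaluator

// Hypothetical toy data: the features column is Array[Double],
// not an ml Vector, which evaluate() previously rejected.
val df = spark.createDataFrame(Seq(
  (0, Array(1.0, 1.2)),
  (0, Array(1.1, 0.9)),
  (1, Array(9.0, 9.2)),
  (1, Array(9.1, 8.9))
)).toDF("prediction", "features")

val evaluator = new ClusteringEvaluator()
  .setFeaturesCol("features")
  .setPredictionCol("prediction")

// With this patch, the array column is converted to vectors internally.
val silhouette = evaluator.evaluate(df)
```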

## How was this patch tested?
Added tests: the existing silhouette checks are re-run on variants of the iris dataset generated by `MLTestingUtils.generateArrayFeatureDataset`.

Author: zhengruifeng <ruife...@foxmail.com>

Closes #21563 from zhengruifeng/clu_eval_support_array.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/57d99499
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/57d99499
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/57d99499

Branch: refs/heads/master
Commit: 57d994994d27154f57f2724924c42beb2ab2e0e7
Parents: 166f346
Author: zhengruifeng <ruife...@foxmail.com>
Authored: Wed Aug 1 23:46:01 2018 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Wed Aug 1 23:46:01 2018 -0700

----------------------------------------------------------------------
 .../spark/ml/evaluation/ClusteringEvaluator.scala    | 15 +++++++++------
 .../ml/evaluation/ClusteringEvaluatorSuite.scala     | 15 ++++++++++++++-
 2 files changed, 23 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/57d99499/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
index 4353c46..a6d6b4e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
@@ -21,11 +21,10 @@ import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.attribute.AttributeGroup
-import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors, VectorUDT}
+import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol}
-import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable,
-  SchemaUtils}
+import org.apache.spark.ml.util._
 import org.apache.spark.sql.{Column, DataFrame, Dataset}
 import org.apache.spark.sql.functions.{avg, col, udf}
 import org.apache.spark.sql.types.DoubleType
@@ -107,15 +106,19 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str
 
   @Since("2.3.0")
   override def evaluate(dataset: Dataset[_]): Double = {
-    SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
+    SchemaUtils.validateVectorCompatibleColumn(dataset.schema, $(featuresCol))
     SchemaUtils.checkNumericType(dataset.schema, $(predictionCol))
 
+    val vectorCol = DatasetUtils.columnToVector(dataset, $(featuresCol))
+    val df = dataset.select(col($(predictionCol)),
+      vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata))
+
     ($(metricName), $(distanceMeasure)) match {
       case ("silhouette", "squaredEuclidean") =>
         SquaredEuclideanSilhouette.computeSilhouetteScore(
-          dataset, $(predictionCol), $(featuresCol))
+          df, $(predictionCol), $(featuresCol))
       case ("silhouette", "cosine") =>
-        CosineSilhouette.computeSilhouetteScore(dataset, $(predictionCol), $(featuresCol))
+        CosineSilhouette.computeSilhouetteScore(df, $(predictionCol), $(featuresCol))
     }
   }
 }
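
The key change above swaps the strict `VectorUDT` schema check for `SchemaUtils.validateVectorCompatibleColumn` and routes the features column through `DatasetUtils.columnToVector` before scoring. Conceptually, that conversion looks something like the following simplified sketch (the real helper lives in `org.apache.spark.ml.util`; `columnToVectorSketch` is an illustrative reconstruction, not the verbatim implementation):

```scala
import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
import org.apache.spark.sql.{Column, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType}

// Pass Vector columns through unchanged; wrap double/float array
// columns in a UDF that builds a dense ml Vector from the array.
def columnToVectorSketch(dataset: Dataset[_], colName: String): Column = {
  dataset.schema(colName).dataType match {
    case _: VectorUDT =>
      col(colName)
    case ArrayType(DoubleType, _) =>
      udf { values: Seq[Double] =>
        Vectors.dense(values.toArray)
      }.apply(col(colName))
    case ArrayType(FloatType, _) =>
      udf { values: Seq[Float] =>
        Vectors.dense(values.map(_.toDouble).toArray)
      }.apply(col(colName))
    case dt =>
      throw new IllegalArgumentException(
        s"Column $colName must be Vector, Array[Double] or Array[Float], but got $dt")
  }
}
```

Note that the converted column is re-selected with the original column's metadata (`vectorCol.as($(featuresCol), ...metadata)`), so information such as an `AttributeGroup` stays attached for the downstream silhouette computation.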

http://git-wip-us.apache.org/repos/asf/spark/blob/57d99499/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
index 2c175ff..e2d7756 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Dataset
@@ -33,10 +33,17 @@ class ClusteringEvaluatorSuite
   import testImplicits._
 
   @transient var irisDataset: Dataset[_] = _
+  @transient var newIrisDataset: Dataset[_] = _
+  @transient var newIrisDatasetD: Dataset[_] = _
+  @transient var newIrisDatasetF: Dataset[_] = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()
     irisDataset = spark.read.format("libsvm").load("../data/mllib/iris_libsvm.txt")
+    val datasets = MLTestingUtils.generateArrayFeatureDataset(irisDataset)
+    newIrisDataset = datasets._1
+    newIrisDatasetD = datasets._2
+    newIrisDatasetF = datasets._3
   }
 
   test("params") {
@@ -66,6 +73,9 @@ class ClusteringEvaluatorSuite
         .setPredictionCol("label")
 
     assert(evaluator.evaluate(irisDataset) ~== 0.6564679231 relTol 1e-5)
+    assert(evaluator.evaluate(newIrisDataset) ~== 0.6564679231 relTol 1e-5)
+    assert(evaluator.evaluate(newIrisDatasetD) ~== 0.6564679231 relTol 1e-5)
+    assert(evaluator.evaluate(newIrisDatasetF) ~== 0.6564679231 relTol 1e-5)
   }
 
   /*
@@ -85,6 +95,9 @@ class ClusteringEvaluatorSuite
       .setDistanceMeasure("cosine")
 
     assert(evaluator.evaluate(irisDataset) ~== 0.7222369298 relTol 1e-5)
+    assert(evaluator.evaluate(newIrisDataset) ~== 0.7222369298 relTol 1e-5)
+    assert(evaluator.evaluate(newIrisDatasetD) ~== 0.7222369298 relTol 1e-5)
+    assert(evaluator.evaluate(newIrisDatasetF) ~== 0.7222369298 relTol 1e-5)
   }
 
   test("number of clusters must be greater than one") {

