Repository: spark
Updated Branches:
  refs/heads/master 55c4ca88a -> 2a24c481d


[SPARK-23975][ML] Allow Clustering to take Arrays of Double as input features

## What changes were proposed in this pull request?

- Multiple possible input types are now accepted in validateAndTransformSchema() and computeCost() when checking the column type

- Add an if statement in transform() to support an array type as featuresCol

- Add a case statement in fit() when selecting columns from the dataset

These changes are applied to KMeans first and will then be extended to other clustering methods; a minimal usage sketch follows below.
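
A minimal sketch of the intended usage after this change (the SparkSession `spark`, the sample data, and the feature column name are illustrative, not taken from the patch):

```scala
import org.apache.spark.ml.clustering.KMeans

// An array<double> column can now be passed directly as featuresCol,
// without converting it to a Vector column first.
val df = spark.createDataFrame(Seq(
  Tuple1(Array(0.0, 0.0)),
  Tuple1(Array(1.0, 1.0)),
  Tuple1(Array(9.0, 8.0)),
  Tuple1(Array(8.0, 9.0))
)).toDF("features")

val model = new KMeans()
  .setK(2)
  .setFeaturesCol("features")  // an array<double> column, not a Vector column
  .setSeed(1L)
  .fit(df)

model.transform(df).show()
```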

## How was this patch tested?

A unit test was added.

Please review http://spark.apache.org/contributing.html before opening a pull 
request.

Author: Lu WANG <lu.w...@databricks.com>

Closes #21081 from ludatabricks/SPARK-23975.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2a24c481
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2a24c481
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2a24c481

Branch: refs/heads/master
Commit: 2a24c481da3f30b510deb62e5cf21c9463cf250c
Parents: 55c4ca8
Author: Lu WANG <lu.w...@databricks.com>
Authored: Tue Apr 24 09:25:41 2018 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Tue Apr 24 09:25:41 2018 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/clustering/KMeans.scala | 32 +++++++---
 .../org/apache/spark/ml/util/DatasetUtils.scala | 63 ++++++++++++++++++++
 .../spark/ml/clustering/KMeansSuite.scala       | 38 ++++++++++++
 3 files changed, 126 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2a24c481/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 1ad157a..d475c72 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -33,8 +33,8 @@ import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
-import org.apache.spark.sql.functions.{col, udf}
-import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.functions.udf
+import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.VersionUtils.majorVersion
 
@@ -87,12 +87,23 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
   def getInitSteps: Int = $(initSteps)
 
   /**
+   * Validates the input schema.
+   * @param schema input schema
+   */
+  private[clustering] def validateSchema(schema: StructType): Unit = {
+    val typeCandidates = List( new VectorUDT,
+      new ArrayType(DoubleType, false),
+      new ArrayType(FloatType, false))
+
+    SchemaUtils.checkColumnTypes(schema, $(featuresCol), typeCandidates)
+  }
+  /**
    * Validates and transforms the input schema.
    * @param schema input schema
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+    validateSchema(schema)
     SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
   }
 }
@@ -125,8 +136,11 @@ class KMeansModel private[ml] (
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
+
     val predictUDF = udf((vector: Vector) => predict(vector))
-    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+
+    dataset.withColumn($(predictionCol),
+      predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)))
   }
 
   @Since("1.5.0")
@@ -146,8 +160,10 @@ class KMeansModel private[ml] (
  // TODO: Replace the temp fix when we have proper evaluators defined for clustering.
   @Since("2.0.0")
   def computeCost(dataset: Dataset[_]): Double = {
-    SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
-    val data: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
+    validateSchema(dataset.schema)
+
+    val data: RDD[OldVector] = dataset.select(DatasetUtils.columnToVector(dataset, getFeaturesCol))
+      .rdd.map {
       case Row(point: Vector) => OldVectors.fromML(point)
     }
     parentModel.computeCost(data)
@@ -335,7 +351,9 @@ class KMeans @Since("1.5.0") (
     transformSchema(dataset.schema, logging = true)
 
     val handlePersistence = dataset.storageLevel == StorageLevel.NONE
-    val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
+    val instances: RDD[OldVector] = dataset.select(
+      DatasetUtils.columnToVector(dataset, getFeaturesCol))
+      .rdd.map {
       case Row(point: Vector) => OldVectors.fromML(point)
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2a24c481/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
new file mode 100644
index 0000000..52619cb
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
+import org.apache.spark.sql.{Column, Dataset}
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType}
+
+
+private[spark] object DatasetUtils {
+
+  /**
+   * Cast a column in a Dataset to Vector type.
+   *
+   * The supported data types of the input column are
+   * - Vector
+   * - float/double type Array.
+   *
+   * Note: The returned column does not have Metadata.
+   *
+   * @param dataset input DataFrame
+   * @param colName column name.
+   * @return Vector column
+   */
+  def columnToVector(dataset: Dataset[_], colName: String): Column = {
+    val columnDataType = dataset.schema(colName).dataType
+    columnDataType match {
+      case _: VectorUDT => col(colName)
+      case fdt: ArrayType =>
+        val transferUDF = fdt.elementType match {
+          case _: FloatType => udf(f = (vector: Seq[Float]) => {
+            val inputArray = Array.fill[Double](vector.size)(0.0)
+            vector.indices.foreach(idx => inputArray(idx) = vector(idx).toDouble)
+            Vectors.dense(inputArray)
+          })
+          case _: DoubleType => udf((vector: Seq[Double]) => {
+            Vectors.dense(vector.toArray)
+          })
+          case other =>
+            throw new IllegalArgumentException(s"Array[$other] column cannot be cast to Vector")
+        }
+        transferUDF(col(colName))
+      case other =>
+        throw new IllegalArgumentException(s"$other column cannot be cast to Vector")
+    }
+  }
+}
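
For reference, a hedged usage sketch of the helper added above (the DataFrame and column name are illustrative; DatasetUtils is private[spark], so it is only reachable from Spark-internal code such as the KMeans changes in this commit):

```scala
import org.apache.spark.ml.util.DatasetUtils

// Illustrative only: assumes a SparkSession `spark` is in scope.
// An array<double> column named "feats" is converted to a Vector column.
val df = spark.createDataFrame(Seq(
  Tuple1(Array(1.0, 2.0)),
  Tuple1(Array(3.0, 4.0))
)).toDF("feats")

val vectorized = df.select(DatasetUtils.columnToVector(df, "feats").as("features"))
vectorized.printSchema()  // "features" is now Vector-typed; metadata is not carried over
```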

http://git-wip-us.apache.org/repos/asf/spark/blob/2a24c481/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 77c9d48..5445ebe 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -30,6 +30,8 @@ import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans
 import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType}
 
 private[clustering] case class TestRow(features: Vector)
 
@@ -199,6 +201,42 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     assert(e.getCause.getMessage.contains("Cosine distance is not defined"))
   }
 
+  test("KMean with Array input") {
+    val featuresColNameD = "array_double_features"
+    val featuresColNameF = "array_float_features"
+
+    val doubleUDF = udf { (features: Vector) =>
+      val featureArray = Array.fill[Double](features.size)(0.0)
+      features.foreachActive((idx, value) => featureArray(idx) = value)
+      featureArray
+    }
+    val floatUDF = udf { (features: Vector) =>
+      val featureArray = Array.fill[Float](features.size)(0.0f)
+      features.foreachActive((idx, value) => featureArray(idx) = value.toFloat)
+      featureArray
+    }
+
+    val newdatasetD = dataset.withColumn(featuresColNameD, doubleUDF(col("features")))
+      .drop("features")
+    val newdatasetF = dataset.withColumn(featuresColNameF, floatUDF(col("features")))
+      .drop("features")
+    assert(newdatasetD.schema(featuresColNameD).dataType.equals(new ArrayType(DoubleType, false)))
+    assert(newdatasetF.schema(featuresColNameF).dataType.equals(new ArrayType(FloatType, false)))
+
+    val kmeansD = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameD).setSeed(1)
+    val kmeansF = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameF).setSeed(1)
+    val modelD = kmeansD.fit(newdatasetD)
+    val modelF = kmeansF.fit(newdatasetF)
+    val transformedD = modelD.transform(newdatasetD)
+    val transformedF = modelF.transform(newdatasetF)
+
+    val predictDifference = transformedD.select("prediction")
+      .except(transformedF.select("prediction"))
+    assert(predictDifference.count() == 0)
+    assert(modelD.computeCost(newdatasetD) == modelF.computeCost(newdatasetF))
+  }
+
+
   test("read/write") {
     def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
       assert(model.clusterCenters === model2.clusterCenters)

