Repository: spark
Updated Branches:
  refs/heads/master 15e301556 -> 8d29001de


[SPARK-13011] K-means wrapper in SparkR

https://issues.apache.org/jira/browse/SPARK-13011

Author: Xusen Yin <[email protected]>

Closes #11124 from yinxusen/SPARK-13011.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8d29001d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8d29001d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8d29001d

Branch: refs/heads/master
Commit: 8d29001dec5c3695721a76df3f70da50512ef28f
Parents: 15e3015
Author: Xusen Yin <[email protected]>
Authored: Tue Feb 23 15:42:58 2016 -0800
Committer: Xiangrui Meng <[email protected]>
Committed: Tue Feb 23 15:42:58 2016 -0800

----------------------------------------------------------------------
 R/pkg/NAMESPACE                                 |  4 +-
 R/pkg/R/generics.R                              |  8 +++
 R/pkg/R/mllib.R                                 | 74 ++++++++++++++++++--
 R/pkg/inst/tests/testthat/test_mllib.R          | 28 ++++++++
 .../org/apache/spark/ml/clustering/KMeans.scala | 45 +++++++++++-
 .../org/apache/spark/ml/r/SparkRWrappers.scala  | 52 +++++++++++++-
 6 files changed, 203 insertions(+), 8 deletions(-)
----------------------------------------------------------------------
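
For reference, a minimal SparkR session exercising the new API might look like
the sketch below (assumed, not part of the patch: a SparkR shell with a
sqlContext in scope; variable names are illustrative and mirror the test added
in test_mllib.R):

    newIris <- iris
    newIris$Species <- NULL                          # keep only numeric columns
    df <- suppressWarnings(createDataFrame(sqlContext, newIris))
    model <- kmeans(x = df, centers = 3, iter.max = 20, algorithm = "random")
    head(select(predict(model, df), "prediction"))   # per-row cluster indices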


http://git-wip-us.apache.org/repos/asf/spark/blob/8d29001d/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index f194a46..6a3d63f 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -13,7 +13,9 @@ export("print.jobj")
 # MLlib integration
 exportMethods("glm",
               "predict",
-              "summary")
+              "summary",
+              "kmeans",
+              "fitted")
 
 # Job group lifecycle management methods
 export("setJobGroup",

http://git-wip-us.apache.org/repos/asf/spark/blob/8d29001d/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 2dba71a..ab61bce 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1160,3 +1160,11 @@ setGeneric("predict", function(object, ...) { standardGeneric("predict") })
 #' @rdname rbind
 #' @export
 setGeneric("rbind", signature = "...")
+
+#' @rdname kmeans
+#' @export
+setGeneric("kmeans")
+
+#' @rdname fitted
+#' @export
+setGeneric("fitted")

http://git-wip-us.apache.org/repos/asf/spark/blob/8d29001d/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 8d3b438..346f33d 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -104,11 +104,11 @@ setMethod("predict", signature(object = "PipelineModel"),
 setMethod("summary", signature(object = "PipelineModel"),
           function(object, ...) {
            modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                   "getModelName", object@model)
+                                     "getModelName", object@model)
             features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                   "getModelFeatures", object@model)
+                                    "getModelFeatures", object@model)
            coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                   "getModelCoefficients", object@model)
+                                        "getModelCoefficients", object@model)
             if (modelName == "LinearRegressionModel") {
              devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
                                                "getModelDevianceResiduals", object@model)
@@ -119,10 +119,76 @@ setMethod("summary", signature(object = "PipelineModel"),
               colnames(coefficients) <- c("Estimate", "Std. Error", "t value", 
"Pr(>|t|)")
               rownames(coefficients) <- unlist(features)
              return(list(devianceResiduals = devianceResiduals, coefficients = coefficients))
-            } else {
+            } else if (modelName == "LogisticRegressionModel") {
               coefficients <- as.matrix(unlist(coefficients))
               colnames(coefficients) <- c("Estimate")
               rownames(coefficients) <- unlist(features)
               return(list(coefficients = coefficients))
+            } else if (modelName == "KMeansModel") {
+              modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                       "getKMeansModelSize", object@model)
+              cluster <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                     "getKMeansCluster", object@model, "classes")
+              k <- unlist(modelSize)[1]
+              size <- unlist(modelSize)[-1]
+              coefficients <- t(matrix(coefficients, ncol = k))
+              colnames(coefficients) <- unlist(features)
+              rownames(coefficients) <- 1:k
+              return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
+            } else {
+              stop(paste("Unsupported model", modelName, sep = " "))
+            }
+          })
+
+#' Fit a k-means model
+#'
+#' Fit a k-means model, similarly to R's kmeans().
+#'
+#' @param x DataFrame for training
+#' @param centers Number of centers
+#' @param iter.max Maximum number of iterations
+#' @param algorithm Algorithm chosen to fit the model, one of "random" or "k-means||"
+#' @return A fitted k-means model
+#' @rdname kmeans
+#' @export
+#' @examples
+#'\dontrun{
+#' model <- kmeans(x, centers = 2, algorithm = "random")
+#'}
+setMethod("kmeans", signature(x = "DataFrame"),
+          function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) {
+            columnNames <- as.array(colnames(x))
+            algorithm <- match.arg(algorithm)
+            model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitKMeans", x@sdf,
+                                 algorithm, iter.max, centers, columnNames)
+            return(new("PipelineModel", model = model))
+          })
+
+#' Get fitted result from a model
+#'
+#' Get fitted result from a model, similarly to R's fitted().
+#'
+#' @param object A fitted MLlib model
+#' @return DataFrame containing fitted values
+#' @rdname fitted
+#' @export
+#' @examples
+#'\dontrun{
+#' model <- kmeans(trainingData, 2)
+#' fitted.model <- fitted(model)
+#' showDF(fitted.model)
+#'}
+setMethod("fitted", signature(object = "PipelineModel"),
+          function(object, method = c("centers", "classes"), ...) {
+            modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                     "getModelName", object@model)
+
+            if (modelName == "KMeansModel") {
+              method <- match.arg(method)
+            fittedResult <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                        "getKMeansCluster", object@model, method)
+              return(dataFrame(fittedResult))
+            } else {
+              stop(paste("Unsupported model", modelName, sep = " "))
             }
           })
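
Taken together, a sketch of what the new methods return for a k-means fit
(variable names are illustrative; shapes follow the code above):

    model <- kmeans(df, centers = 2)
    s <- summary(model)
    s$coefficients            # k x d matrix of cluster centers, one row per cluster
    s$size                    # number of points assigned to each cluster
    s$cluster                 # DataFrame of per-row cluster assignments
    fitted(model, "classes")  # the same assignments, returned as a DataFrame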

http://git-wip-us.apache.org/repos/asf/spark/blob/8d29001d/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 08099dd..595512e 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -113,3 +113,31 @@ test_that("summary works on base GLM models", {
   baseSummary <- summary(baseModel)
   expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
 })
+
+test_that("kmeans", {
+  newIris <- iris
+  newIris$Species <- NULL
+  training <- suppressWarnings(createDataFrame(sqlContext, newIris))
+
+  # Cache the DataFrame here to work around the bug SPARK-13178.
+  cache(training)
+  take(training, 1)
+
+  model <- kmeans(x = training, centers = 2)
+  sample <- take(select(predict(model, training), "prediction"), 1)
+  expect_equal(typeof(sample$prediction), "integer")
+  expect_equal(sample$prediction, 1)
+
+  # Test stats::kmeans is working
+  statsModel <- kmeans(x = newIris, centers = 2)
+  expect_equal(unique(statsModel$cluster), c(1, 2))
+
+  # Test fitted works on KMeans
+  fitted.model <- fitted(model)
+  expect_equal(sort(collect(distinct(select(fitted.model, "prediction")))$prediction), c(0, 1))
+
+  # Test summary works on KMeans
+  summary.model <- summary(model)
+  cluster <- summary.model$cluster
+  expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
+})
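
Note that because generics.R declares setGeneric("kmeans") over the existing
base function, S4 dispatch keeps both implementations usable, as the test
above exercises; roughly (illustrative, reusing the test's variables):

    kmeans(newIris, centers = 2)   # plain data.frame -> stats::kmeans
    kmeans(training, centers = 2)  # Spark DataFrame  -> the new MLlib-backed method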

http://git-wip-us.apache.org/repos/asf/spark/blob/8d29001d/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index b2292e2..c6a3eac 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.clustering
 
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
@@ -135,6 +136,26 @@ class KMeansModel private[ml] (
 
   @Since("1.6.0")
   override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
+
+  private var trainingSummary: Option[KMeansSummary] = None
+
+  private[clustering] def setSummary(summary: KMeansSummary): this.type = {
+    this.trainingSummary = Some(summary)
+    this
+  }
+
+  /**
+   * Gets summary of model on training set. An exception is
+   * thrown if `trainingSummary == None`.
+   */
+  @Since("2.0.0")
+  def summary: KMeansSummary = trainingSummary match {
+    case Some(summ) => summ
+    case None =>
+      throw new SparkException(
+        s"No training summary available for the 
${this.getClass.getSimpleName}",
+        new NullPointerException())
+  }
 }
 
 @Since("1.6.0")
@@ -249,8 +270,9 @@ class KMeans @Since("1.5.0") (
       .setSeed($(seed))
       .setEpsilon($(tol))
     val parentModel = algo.run(rdd)
-    val model = new KMeansModel(uid, parentModel)
-    copyValues(model.setParent(this))
+    val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
+    val summary = new KMeansSummary(model.transform(dataset), $(predictionCol), $(featuresCol))
+    model.setSummary(summary)
   }
 
   @Since("1.5.0")
@@ -266,3 +288,22 @@ object KMeans extends DefaultParamsReadable[KMeans] {
   override def load(path: String): KMeans = super.load(path)
 }
 
+class KMeansSummary private[clustering] (
+    @Since("2.0.0") @transient val predictions: DataFrame,
+    @Since("2.0.0") val predictionCol: String,
+    @Since("2.0.0") val featuresCol: String) extends Serializable {
+
+  /**
+   * Cluster centers of the transformed data.
+   */
+  @Since("2.0.0")
+  @transient lazy val cluster: DataFrame = predictions.select(predictionCol)
+
+  /**
+   * Size of each cluster.
+   */
+  @Since("2.0.0")
+  lazy val size: Array[Int] = cluster.map {
+    case Row(clusterIdx: Int) => (clusterIdx, 1)
+  }.reduceByKey(_ + _).collect().sortBy(_._1).map(_._2)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8d29001d/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index 551e75d..d23e4fc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -20,7 +20,8 @@ package org.apache.spark.ml.api.r
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
-import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
+import org.apache.spark.ml.feature.{RFormula, VectorAssembler}
 import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
 import org.apache.spark.sql.DataFrame
 
@@ -51,6 +52,22 @@ private[r] object SparkRWrappers {
     pipeline.fit(df)
   }
 
+  def fitKMeans(
+      df: DataFrame,
+      initMode: String,
+      maxIter: Double,
+      k: Double,
+      columns: Array[String]): PipelineModel = {
+    val assembler = new VectorAssembler().setInputCols(columns)
+    val kMeans = new KMeans()
+      .setInitMode(initMode)
+      .setMaxIter(maxIter.toInt)
+      .setK(k.toInt)
+      .setFeaturesCol(assembler.getOutputCol)
+    val pipeline = new Pipeline().setStages(Array(assembler, kMeans))
+    pipeline.fit(df)
+  }
+
   def getModelCoefficients(model: PipelineModel): Array[Double] = {
     model.stages.last match {
       case m: LinearRegressionModel => {
@@ -72,6 +89,8 @@ private[r] object SparkRWrappers {
           m.coefficients.toArray
         }
       }
+      case m: KMeansModel =>
+        m.clusterCenters.flatMap(_.toArray)
     }
   }
 
@@ -85,6 +104,31 @@ private[r] object SparkRWrappers {
     }
   }
 
+  def getKMeansModelSize(model: PipelineModel): Array[Int] = {
+    model.stages.last match {
+      case m: KMeansModel => Array(m.getK) ++ m.summary.size
+      case other => throw new UnsupportedOperationException(
+        s"KMeansModel required but ${other.getClass.getSimpleName} found.")
+    }
+  }
+
+  def getKMeansCluster(model: PipelineModel, method: String): DataFrame = {
+    model.stages.last match {
+      case m: KMeansModel =>
+        if (method == "centers") {
+          // Drop the assembled features vector so the result prints cleanly on the R side.
+          m.summary.predictions.drop(m.summary.featuresCol)
+        } else if (method == "classes") {
+          m.summary.cluster
+        } else {
+          throw new UnsupportedOperationException(
+            s"Method (centers or classes) required but $method found.")
+        }
+      case other => throw new UnsupportedOperationException(
+        s"KMeansModel required but ${other.getClass.getSimpleName} found.")
+    }
+  }
+
   def getModelFeatures(model: PipelineModel): Array[String] = {
     model.stages.last match {
       case m: LinearRegressionModel =>
@@ -103,6 +147,10 @@ private[r] object SparkRWrappers {
         } else {
           attrs.attributes.get.map(_.name.get)
         }
+      case m: KMeansModel =>
+        val attrs = AttributeGroup.fromStructField(
+          m.summary.predictions.schema(m.summary.featuresCol))
+        attrs.attributes.get.map(_.name.get)
     }
   }
 
@@ -112,6 +160,8 @@ private[r] object SparkRWrappers {
         "LinearRegressionModel"
       case m: LogisticRegressionModel =>
         "LogisticRegressionModel"
+      case m: KMeansModel =>
+        "KMeansModel"
     }
   }
 }
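
A note on the fitKMeans signature above: the SparkR backend serializes R
numeric scalars to the JVM as Double, which is why maxIter and k are declared
as Double and narrowed with .toInt. On the R side the call is just
(illustrative, mirroring the mllib.R changes above, where x is a DataFrame):

    model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitKMeans",
                         x@sdf, "k-means||", 10, 2, as.array(colnames(x)))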

