Repository: spark
Updated Branches:
  refs/heads/master cf0cce903 -> acac7a508


[SPARK-16443][SPARKR] Alternating Least Squares (ALS) wrapper

## What changes were proposed in this pull request?

Add an Alternating Least Squares (ALS) wrapper in SparkR. Unit tests have been updated.
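
For reference, a minimal usage sketch, distilled from the roxygen example added to `R/pkg/R/mllib.R` below (assumes an active SparkR session):

```r
# Minimal sketch, assuming an active SparkR session (sparkR.session()).
ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0),
                list(2, 1, 1.0), list(2, 2, 5.0))
df <- createDataFrame(ratings, c("user", "item", "rating"))

# Fit the model; ratingCol, userCol and itemCol are passed positionally here.
model <- spark.als(df, "rating", "user", "item", rank = 10, reg = 0.1, maxIter = 5)

# Inspect latent factors and make predictions.
stats <- summary(model)
head(collect(stats$userFactors))
showDF(predict(model, df))
```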

## How was this patch tested?

SparkR unit tests.
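
The new coverage lives in the `test_that("spark.als", ...)` block added to `R/pkg/inst/tests/testthat/test_mllib.R`; a condensed sketch of what it checks (assuming an active SparkR session and `testthat` loaded):

```r
# Condensed from the new spark.als test: fit, summarize, and round-trip save/load.
df <- createDataFrame(list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0),
                           list(1, 2, 4.0), list(2, 1, 1.0), list(2, 2, 5.0)),
                      c("user", "item", "score"))
model <- spark.als(df, ratingCol = "score", userCol = "user", itemCol = "item",
                   rank = 10, maxIter = 5, seed = 0, reg = 0.1)
expect_equal(summary(model)$rank, 10)

# Round-trip through write.ml/read.ml; the rating column name should survive.
modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp")
write.ml(model, modelPath)
expect_equal(summary(read.ml(modelPath))$rating, "score")
unlink(modelPath)
```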


![screen shot 2016-07-27 at 3 50 31 pm](https://cloud.githubusercontent.com/assets/15318264/17195347/f7a6352a-5411-11e6-8e21-61a48070192a.png)
![screen shot 2016-07-27 at 3 50 46 pm](https://cloud.githubusercontent.com/assets/15318264/17195348/f7a7d452-5411-11e6-845f-6d292283bc28.png)

Author: Junyang Qian <junya...@databricks.com>

Closes #14384 from junyangq/SPARK-16443.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/acac7a50
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/acac7a50
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/acac7a50

Branch: refs/heads/master
Commit: acac7a508a29d0f75d86ee2e4ca83ebf01a36cf8
Parents: cf0cce9
Author: Junyang Qian <junya...@databricks.com>
Authored: Fri Aug 19 14:24:09 2016 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Fri Aug 19 14:24:09 2016 -0700

----------------------------------------------------------------------
 R/pkg/NAMESPACE                                 |   3 +-
 R/pkg/R/generics.R                              |   4 +
 R/pkg/R/mllib.R                                 | 159 ++++++++++++++++++-
 R/pkg/inst/tests/testthat/test_mllib.R          |  40 +++++
 .../org/apache/spark/ml/r/ALSWrapper.scala      | 119 ++++++++++++++
 .../scala/org/apache/spark/ml/r/RWrappers.scala |   2 +
 6 files changed, 322 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/acac7a50/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 4404cff..e1b87b2 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -29,7 +29,8 @@ exportMethods("glm",
               "spark.posterior",
               "spark.perplexity",
               "spark.isoreg",
-              "spark.gaussianMixture")
+              "spark.gaussianMixture",
+              "spark.als")
 
 # Job group lifecycle management methods
 export("setJobGroup",

http://git-wip-us.apache.org/repos/asf/spark/blob/acac7a50/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index fe04bcf..693aa31 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1332,3 +1332,7 @@ setGeneric("spark.gaussianMixture",
 #' @rdname write.ml
 #' @export
 setGeneric("write.ml", function(object, path, ...) { 
standardGeneric("write.ml") })
+
+#' @rdname spark.als
+#' @export
+setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") })

http://git-wip-us.apache.org/repos/asf/spark/blob/acac7a50/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index b952741..36f38fc 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -74,6 +74,13 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
 #' @note GaussianMixtureModel since 2.1.0
 setClass("GaussianMixtureModel", representation(jobj = "jobj"))
 
+#' S4 class that represents an ALSModel
+#'
+#' @param jobj a Java object reference to the backing Scala ALSWrapper
+#' @export
+#' @note ALSModel since 2.1.0
+setClass("ALSModel", representation(jobj = "jobj"))
+
 #' Saves the MLlib model to the input path
 #'
 #' Saves the MLlib model to the input path. For more information, see the specific
@@ -82,8 +89,8 @@ setClass("GaussianMixtureModel", representation(jobj = "jobj"))
 #' @name write.ml
 #' @export
 #' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
-#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}, \link{spark.lda}
-#' @seealso \link{spark.isoreg}
+#' @seealso \link{spark.als}, \link{spark.kmeans}, \link{spark.lda}, \link{spark.naiveBayes}
+#' @seealso \link{spark.survreg}, \link{spark.isoreg}
 #' @seealso \link{read.ml}
 NULL
 
@@ -95,10 +102,11 @@ NULL
 #' @name predict
 #' @export
 #' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
-#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
+#' @seealso \link{spark.als}, \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
 #' @seealso \link{spark.isoreg}
 NULL
 
+
 #' Generalized Linear Models
 #'
 #' Fits generalized linear model against a Spark DataFrame.
@@ -801,6 +809,8 @@ read.ml <- function(path) {
       return(new("IsotonicRegressionModel", jobj = jobj))
  } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) {
       return(new("GaussianMixtureModel", jobj = jobj))
+  } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) {
+      return(new("ALSModel", jobj = jobj))
   } else {
     stop(paste("Unsupported model: ", jobj))
   }
@@ -1053,4 +1063,145 @@ setMethod("summary", signature(object = "GaussianMixtureModel"),
 setMethod("predict", signature(object = "GaussianMixtureModel"),
           function(object, newData) {
            return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
-          })
\ No newline at end of file
+          })
+
+#' Alternating Least Squares (ALS) for Collaborative Filtering
+#'
+#' \code{spark.als} learns latent factors in collaborative filtering via alternating least
+#' squares. Users can call \code{summary} to obtain fitted latent factors, \code{predict}
+#' to make predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
+#'
+#' For more details, see
+#' \href{http://spark.apache.org/docs/latest/ml-collaborative-filtering.html}{MLlib:
+#' Collaborative Filtering}.
+#'
+#' @param data a SparkDataFrame for training.
+#' @param ratingCol column name for ratings.
+#' @param userCol column name for user ids. Ids must be (or can be coerced into) integers.
+#' @param itemCol column name for item ids. Ids must be (or can be coerced into) integers.
+#' @param rank rank of the matrix factorization (> 0).
+#' @param reg regularization parameter (>= 0).
+#' @param maxIter maximum number of iterations (>= 0).
+#' @param nonnegative logical value indicating whether to apply nonnegativity constraints.
+#' @param implicitPrefs logical value indicating whether to use implicit preference.
+#' @param alpha alpha parameter in the implicit preference formulation (>= 0).
+#' @param seed integer seed for random number generation.
+#' @param numUserBlocks number of user blocks used to parallelize computation (> 0).
+#' @param numItemBlocks number of item blocks used to parallelize computation (> 0).
+#' @param checkpointInterval number of checkpoint intervals (>= 1) or disable checkpoint (-1).
+#'
+#' @return \code{spark.als} returns a fitted ALS model
+#' @rdname spark.als
+#' @aliases spark.als,SparkDataFrame-method
+#' @name spark.als
+#' @export
+#' @examples
+#' \dontrun{
+#' ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0),
+#'                 list(2, 1, 1.0), list(2, 2, 5.0))
+#' df <- createDataFrame(ratings, c("user", "item", "rating"))
+#' model <- spark.als(df, "rating", "user", "item")
+#'
+#' # extract latent factors
+#' stats <- summary(model)
+#' userFactors <- stats$userFactors
+#' itemFactors <- stats$itemFactors
+#'
+#' # make predictions
+#' predicted <- predict(model, df)
+#' showDF(predicted)
+#'
+#' # save and load the model
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#'
+#' # set other arguments
+#' modelS <- spark.als(df, "rating", "user", "item", rank = 20,
+#'                     reg = 0.1, nonnegative = TRUE)
+#' statsS <- summary(modelS)
+#' }
+#' @note spark.als since 2.1.0
+setMethod("spark.als", signature(data = "SparkDataFrame"),
+          function(data, ratingCol = "rating", userCol = "user", itemCol = "item",
+                   rank = 10, reg = 1.0, maxIter = 10, nonnegative = FALSE,
+                   implicitPrefs = FALSE, alpha = 1.0, numUserBlocks = 10, numItemBlocks = 10,
+                   checkpointInterval = 10, seed = 0) {
+
+            if (!is.numeric(rank) || rank <= 0) {
+              stop("rank should be a positive number.")
+            }
+            if (!is.numeric(reg) || reg < 0) {
+              stop("reg should be a nonnegative number.")
+            }
+            if (!is.numeric(maxIter) || maxIter <= 0) {
+              stop("maxIter should be a positive number.")
+            }
+
+            jobj <- callJStatic("org.apache.spark.ml.r.ALSWrapper",
+                                "fit", data@sdf, ratingCol, userCol, itemCol, as.integer(rank),
+                                reg, as.integer(maxIter), implicitPrefs, alpha, nonnegative,
+                                as.integer(numUserBlocks), as.integer(numItemBlocks),
+                                as.integer(checkpointInterval), as.integer(seed))
+            return(new("ALSModel", jobj = jobj))
+          })
+
+# Returns a summary of the ALS model produced by spark.als.
+
+#' @param object a fitted ALS model.
+#' @return \code{summary} returns a list containing the names of the user column,
+#'         the item column and the rating column, the estimated user and item factors,
+#'         rank, regularization parameter and maximum number of iterations used in training.
+#' @rdname spark.als
+#' @aliases summary,ALSModel-method
+#' @export
+#' @note summary(ALSModel) since 2.1.0
+setMethod("summary", signature(object = "ALSModel"),
+function(object, ...) {
+    jobj <- object@jobj
+    user <- callJMethod(jobj, "userCol")
+    item <- callJMethod(jobj, "itemCol")
+    rating <- callJMethod(jobj, "ratingCol")
+    userFactors <- dataFrame(callJMethod(jobj, "userFactors"))
+    itemFactors <- dataFrame(callJMethod(jobj, "itemFactors"))
+    rank <- callJMethod(jobj, "rank")
+    return(list(user = user, item = item, rating = rating, userFactors = userFactors,
+                itemFactors = itemFactors, rank = rank))
+})
+
+
+# Makes predictions from an ALS model or a model produced by spark.als.
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing predicted values.
+#' @rdname spark.als
+#' @aliases predict,ALSModel-method
+#' @export
+#' @note predict(ALSModel) since 2.1.0
+setMethod("predict", signature(object = "ALSModel"),
+function(object, newData) {
+    return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+})
+
+
+# Saves the ALS model to the input path.
+
+#' @param path the directory where the model is saved.
+#' @param overwrite logical value indicating whether to overwrite if the output path
+#'                  already exists. Default is FALSE which means throw exception
+#'                  if the output path exists.
+#'
+#' @rdname spark.als
+#' @aliases write.ml,ALSModel,character-method
+#' @export
+#' @seealso \link{read.ml}
+#' @note write.ml(ALSModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "ALSModel", path = "character"),
+function(object, path, overwrite = FALSE) {
+    writer <- callJMethod(object@jobj, "write")
+    if (overwrite) {
+        writer <- callJMethod(writer, "overwrite")
+    }
+    invisible(callJMethod(writer, "save", path))
+})

http://git-wip-us.apache.org/repos/asf/spark/blob/acac7a50/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index dfb7a18..67a3099 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -657,4 +657,44 @@ test_that("spark.posterior and spark.perplexity", {
   expect_equal(length(local.posterior), sum(unlist(local.posterior)))
 })
 
+test_that("spark.als", {
+  data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0),
+  list(2, 1, 1.0), list(2, 2, 5.0))
+  df <- createDataFrame(data, c("user", "item", "score"))
+  model <- spark.als(df, ratingCol = "score", userCol = "user", itemCol = "item",
+  rank = 10, maxIter = 5, seed = 0, reg = 0.1)
+  stats <- summary(model)
+  expect_equal(stats$rank, 10)
+  test <- createDataFrame(list(list(0, 2), list(1, 0), list(2, 0)), c("user", "item"))
+  predictions <- collect(predict(model, test))
+
+  expect_equal(predictions$prediction, c(-0.1380762, 2.6258414, -1.5018409),
+  tolerance = 1e-4)
+
+  # Test model save/load
+  modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  stats2 <- summary(model2)
+  expect_equal(stats2$rating, "score")
+  userFactors <- collect(stats$userFactors)
+  itemFactors <- collect(stats$itemFactors)
+  userFactors2 <- collect(stats2$userFactors)
+  itemFactors2 <- collect(stats2$itemFactors)
+
+  orderUser <- order(userFactors$id)
+  orderUser2 <- order(userFactors2$id)
+  expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2])
+  expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2])
+
+  orderItem <- order(itemFactors$id)
+  orderItem2 <- order(itemFactors2$id)
+  expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2])
+  expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2])
+
+  unlink(modelPath)
+})
+
 sparkR.session.stop()

http://git-wip-us.apache.org/repos/asf/spark/blob/acac7a50/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala
new file mode 100644
index 0000000..ad13cce
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.recommendation.{ALS, ALSModel}
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class ALSWrapper private (
+    val alsModel: ALSModel,
+    val ratingCol: String) extends MLWritable {
+
+  lazy val userCol: String = alsModel.getUserCol
+  lazy val itemCol: String = alsModel.getItemCol
+  lazy val userFactors: DataFrame = alsModel.userFactors
+  lazy val itemFactors: DataFrame = alsModel.itemFactors
+  lazy val rank: Int = alsModel.rank
+
+  def transform(dataset: Dataset[_]): DataFrame = {
+    alsModel.transform(dataset)
+  }
+
+  override def write: MLWriter = new ALSWrapper.ALSWrapperWriter(this)
+}
+
+private[r] object ALSWrapper extends MLReadable[ALSWrapper] {
+
+  def fit(  // scalastyle:ignore
+      data: DataFrame,
+      ratingCol: String,
+      userCol: String,
+      itemCol: String,
+      rank: Int,
+      regParam: Double,
+      maxIter: Int,
+      implicitPrefs: Boolean,
+      alpha: Double,
+      nonnegative: Boolean,
+      numUserBlocks: Int,
+      numItemBlocks: Int,
+      checkpointInterval: Int,
+      seed: Int): ALSWrapper = {
+
+    val als = new ALS()
+      .setRatingCol(ratingCol)
+      .setUserCol(userCol)
+      .setItemCol(itemCol)
+      .setRank(rank)
+      .setRegParam(regParam)
+      .setMaxIter(maxIter)
+      .setImplicitPrefs(implicitPrefs)
+      .setAlpha(alpha)
+      .setNonnegative(nonnegative)
+      .setNumBlocks(numUserBlocks)
+      .setNumItemBlocks(numItemBlocks)
+      .setCheckpointInterval(checkpointInterval)
+      .setSeed(seed.toLong)
+
+    val alsModel: ALSModel = als.fit(data)
+
+    new ALSWrapper(alsModel, ratingCol)
+  }
+
+  override def read: MLReader[ALSWrapper] = new ALSWrapperReader
+
+  override def load(path: String): ALSWrapper = super.load(path)
+
+  class ALSWrapperWriter(instance: ALSWrapper) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val modelPath = new Path(path, "model").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("ratingCol" -> instance.ratingCol)
+      val rMetadataJson: String = compact(render(rMetadata))
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+      instance.alsModel.save(modelPath)
+    }
+  }
+
+  class ALSWrapperReader extends MLReader[ALSWrapper] {
+
+    override def load(path: String): ALSWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val modelPath = new Path(path, "model").toString
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val ratingCol = (rMetadata \ "ratingCol").extract[String]
+      val alsModel = ALSModel.load(modelPath)
+
+      new ALSWrapper(alsModel, ratingCol)
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/acac7a50/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index e23af51..51a65f7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -50,6 +50,8 @@ private[r] object RWrappers extends MLReader[Object] {
         IsotonicRegressionWrapper.load(path)
       case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>
         GaussianMixtureWrapper.load(path)
+      case "org.apache.spark.ml.r.ALSWrapper" =>
+        ALSWrapper.load(path)
       case _ =>
        throw new SparkException(s"SparkR read.ml does not support load $className")
     }

