Repository: spark
Updated Branches:
  refs/heads/master 9b8eca65d -> b40546651


[SPARK-14489][ML][PYSPARK] ALS unknown user/item prediction strategy

This PR adds a param to `ALS`/`ALSModel` that sets the strategy used in 
`transform` when unknown users or items are encountered at prediction time. This 
can occur in two scenarios: (a) production scoring, and (b) cross-validation and 
evaluation.

Currently, `transform` produces a `NaN` prediction if a user/item is unknown. In 
scenario (b), this can easily occur when using `CrossValidator` or 
`TrainValidationSplit`, since some users/items may occur only in the test set and 
not in the training set. In this case, the evaluator returns `NaN` for all 
metrics, making model selection impossible.

The new param, `coldStartStrategy`, defaults to `nan` (the current behavior). 
The other option supported initially is `drop`, which drops all rows with `NaN` 
predictions. This option allows users to use `ALS` in cross-validation settings. 
It is an expert param. It is a string rather than a boolean so that the set of 
strategies can be extended in the future (some options are discussed in 
[SPARK-14489](https://issues.apache.org/jira/browse/SPARK-14489)).
## How was this patch tested?

Tested with new unit tests, plus manual "before and after" tests for Scala and 
Python using the MovieLens `ml-latest-small` dataset as example data. With the 
default param setting, `CrossValidator` and `TrainValidationSplit` produce 
metrics that are all `NaN`. Setting `coldStartStrategy` to `drop` produces valid 
metrics.

Author: Nick Pentreath <ni...@za.ibm.com>

Closes #12896 from MLnick/SPARK-14489-als-nan.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b4054665
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b4054665
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b4054665

Branch: refs/heads/master
Commit: b405466513bcc02cadf1477b6b682ace95d81658
Parents: 9b8eca6
Author: Nick Pentreath <ni...@za.ibm.com>
Authored: Tue Feb 28 16:17:35 2017 +0200
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Tue Feb 28 16:17:35 2017 +0200

----------------------------------------------------------------------
 .../apache/spark/ml/recommendation/ALS.scala    | 44 ++++++++++++++++-
 .../spark/ml/recommendation/ALSSuite.scala      | 51 +++++++++++++++++++-
 python/pyspark/ml/recommendation.py             | 30 ++++++++++--
 3 files changed, 116 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b4054665/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala 
b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 97c8655..af00762 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -90,6 +90,27 @@ private[recommendation] trait ALSModelParams extends Params 
with HasPredictionCo
       n.toInt
     }
   }
+
+  /**
+   * Param for strategy for dealing with unknown or new users/items at 
prediction time.
+   * This may be useful in cross-validation or production scenarios, for 
handling user/item ids
+   * the model has not seen in the training data.
+   * Supported values:
+   * - "nan":  predicted value for unknown ids will be NaN.
+   * - "drop": rows in the input DataFrame containing unknown ids will be 
dropped from
+   *           the output DataFrame containing predictions.
+   * Default: "nan".
+   * @group expertParam
+   */
+  val coldStartStrategy = new Param[String](this, "coldStartStrategy",
+    "strategy for dealing with unknown or new users/items at prediction time. 
This may be " +
+    "useful in cross-validation or production scenarios, for handling 
user/item ids the model " +
+    "has not seen in the training data. Supported values: " +
+    s"${ALSModel.supportedColdStartStrategies.mkString(",")}.",
+    (s: String) => 
ALSModel.supportedColdStartStrategies.contains(s.toLowerCase))
+
+  /** @group expertGetParam */
+  def getColdStartStrategy: String = $(coldStartStrategy).toLowerCase
 }
 
 /**
@@ -203,7 +224,8 @@ private[recommendation] trait ALSParams extends 
ALSModelParams with HasMaxIter w
   setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, 
numItemBlocks -> 10,
     implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
     ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10,
-    intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> 
"MEMORY_AND_DISK")
+    intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> 
"MEMORY_AND_DISK",
+    coldStartStrategy -> "nan")
 
   /**
    * Validates and transforms the input schema.
@@ -248,6 +270,10 @@ class ALSModel private[ml] (
   @Since("1.3.0")
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
+  /** @group expertSetParam */
+  @Since("2.2.0")
+  def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, 
value)
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema)
@@ -260,13 +286,19 @@ class ALSModel private[ml] (
         Float.NaN
       }
     }
-    dataset
+    val predictions = dataset
       .join(userFactors,
         checkedCast(dataset($(userCol)).cast(DoubleType)) === 
userFactors("id"), "left")
       .join(itemFactors,
         checkedCast(dataset($(itemCol)).cast(DoubleType)) === 
itemFactors("id"), "left")
       .select(dataset("*"),
         predict(userFactors("features"), 
itemFactors("features")).as($(predictionCol)))
+    getColdStartStrategy match {
+      case ALSModel.Drop =>
+        predictions.na.drop("all", Seq($(predictionCol)))
+      case ALSModel.NaN =>
+        predictions
+    }
   }
 
   @Since("1.3.0")
@@ -290,6 +322,10 @@ class ALSModel private[ml] (
 @Since("1.6.0")
 object ALSModel extends MLReadable[ALSModel] {
 
+  private val NaN = "nan"
+  private val Drop = "drop"
+  private[recommendation] final val supportedColdStartStrategies = Array(NaN, 
Drop)
+
   @Since("1.6.0")
   override def read: MLReader[ALSModel] = new ALSModelReader
 
@@ -432,6 +468,10 @@ class ALS(@Since("1.4.0") override val uid: String) 
extends Estimator[ALSModel]
   @Since("2.0.0")
   def setFinalStorageLevel(value: String): this.type = set(finalStorageLevel, 
value)
 
+  /** @group expertSetParam */
+  @Since("2.2.0")
+  def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, 
value)
+
   /**
    * Sets both numUserBlocks and numItemBlocks to the specific value.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/b4054665/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index b923bac..c9e7b50 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -498,8 +498,8 @@ class ALSSuite
           (ex, act) =>
             ex.userFactors.first().getSeq[Float](1) === 
act.userFactors.first.getSeq[Float](1)
         } { (ex, act, _) =>
-          ex.transform(_: DataFrame).select("prediction").first.getFloat(0) ~==
-            act.transform(_: DataFrame).select("prediction").first.getFloat(0) 
absTol 1e-6
+          ex.transform(_: DataFrame).select("prediction").first.getDouble(0) 
~==
+            act.transform(_: 
DataFrame).select("prediction").first.getDouble(0) absTol 1e-6
         }
     }
     // check user/item ids falling outside of Int range
@@ -547,6 +547,53 @@ class ALSSuite
       ALS.train(ratings)
     }
   }
+
+  test("ALS cold start user/item prediction strategy") {
+    val spark = this.spark
+    import spark.implicits._
+    import org.apache.spark.sql.functions._
+
+    val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 
1)
+    val data = ratings.toDF
+    val knownUser = data.select(max("user")).as[Int].first()
+    val unknownUser = knownUser + 10
+    val knownItem = data.select(max("item")).as[Int].first()
+    val unknownItem = knownItem + 20
+    val test = Seq(
+      (unknownUser, unknownItem),
+      (knownUser, unknownItem),
+      (unknownUser, knownItem),
+      (knownUser, knownItem)
+    ).toDF("user", "item")
+
+    val als = new ALS().setMaxIter(1).setRank(1)
+    // default is 'nan'
+    val defaultModel = als.fit(data)
+    val defaultPredictions = 
defaultModel.transform(test).select("prediction").as[Float].collect()
+    assert(defaultPredictions.length == 4)
+    assert(defaultPredictions.slice(0, 3).forall(_.isNaN))
+    assert(!defaultPredictions.last.isNaN)
+
+    // check 'drop' strategy should filter out rows with unknown users/items
+    val dropPredictions = defaultModel
+      .setColdStartStrategy("drop")
+      .transform(test)
+      .select("prediction").as[Float].collect()
+    assert(dropPredictions.length == 1)
+    assert(!dropPredictions.head.isNaN)
+    assert(dropPredictions.head ~== defaultPredictions.last relTol 1e-14)
+  }
+
+  test("case insensitive cold start param value") {
+    val spark = this.spark
+    import spark.implicits._
+    val (ratings, _) = genExplicitTestData(numUsers = 2, numItems = 2, rank = 
1)
+    val data = ratings.toDF
+    val model = new ALS().fit(data)
+    Seq("nan", "NaN", "Nan", "drop", "DROP", "Drop").foreach { s =>
+      model.setColdStartStrategy(s).transform(data)
+    }
+  }
 }
 
 class ALSCleanerSuite extends SparkFunSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/b4054665/python/pyspark/ml/recommendation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/recommendation.py 
b/python/pyspark/ml/recommendation.py
index e28d38b..43f82da 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -125,19 +125,25 @@ class ALS(JavaEstimator, HasCheckpointInterval, 
HasMaxIter, HasPredictionCol, Ha
     finalStorageLevel = Param(Params._dummy(), "finalStorageLevel",
                               "StorageLevel for ALS model factors.",
                               typeConverter=TypeConverters.toString)
+    coldStartStrategy = Param(Params._dummy(), "coldStartStrategy", "strategy 
for dealing with " +
+                              "unknown or new users/items at prediction time. 
This may be useful " +
+                              "in cross-validation or production scenarios, 
for handling " +
+                              "user/item ids the model has not seen in the 
training data. " +
+                              "Supported values: 'nan', 'drop'.",
+                              typeConverter=TypeConverters.toString)
 
     @keyword_only
     def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, 
numItemBlocks=10,
                  implicitPrefs=False, alpha=1.0, userCol="user", 
itemCol="item", seed=None,
                  ratingCol="rating", nonnegative=False, checkpointInterval=10,
                  intermediateStorageLevel="MEMORY_AND_DISK",
-                 finalStorageLevel="MEMORY_AND_DISK"):
+                 finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan"):
         """
         __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, 
numItemBlocks=10, \
                  implicitPrefs=false, alpha=1.0, userCol="user", 
itemCol="item", seed=None, \
                  ratingCol="rating", nonnegative=false, checkpointInterval=10, 
\
                  intermediateStorageLevel="MEMORY_AND_DISK", \
-                 finalStorageLevel="MEMORY_AND_DISK")
+                 finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan")
         """
         super(ALS, self).__init__()
         self._java_obj = 
self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid)
@@ -145,7 +151,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, 
HasPredictionCol, Ha
                          implicitPrefs=False, alpha=1.0, userCol="user", 
itemCol="item",
                          ratingCol="rating", nonnegative=False, 
checkpointInterval=10,
                          intermediateStorageLevel="MEMORY_AND_DISK",
-                         finalStorageLevel="MEMORY_AND_DISK")
+                         finalStorageLevel="MEMORY_AND_DISK", 
coldStartStrategy="nan")
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
 
@@ -155,13 +161,13 @@ class ALS(JavaEstimator, HasCheckpointInterval, 
HasMaxIter, HasPredictionCol, Ha
                   implicitPrefs=False, alpha=1.0, userCol="user", 
itemCol="item", seed=None,
                   ratingCol="rating", nonnegative=False, checkpointInterval=10,
                   intermediateStorageLevel="MEMORY_AND_DISK",
-                  finalStorageLevel="MEMORY_AND_DISK"):
+                  finalStorageLevel="MEMORY_AND_DISK", 
coldStartStrategy="nan"):
         """
         setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, 
numItemBlocks=10, \
                  implicitPrefs=False, alpha=1.0, userCol="user", 
itemCol="item", seed=None, \
                  ratingCol="rating", nonnegative=False, checkpointInterval=10, 
\
                  intermediateStorageLevel="MEMORY_AND_DISK", \
-                 finalStorageLevel="MEMORY_AND_DISK")
+                 finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan")
         Sets params for ALS.
         """
         kwargs = self.setParams._input_kwargs
@@ -332,6 +338,20 @@ class ALS(JavaEstimator, HasCheckpointInterval, 
HasMaxIter, HasPredictionCol, Ha
         """
         return self.getOrDefault(self.finalStorageLevel)
 
+    @since("2.2.0")
+    def setColdStartStrategy(self, value):
+        """
+        Sets the value of :py:attr:`coldStartStrategy`.
+        """
+        return self._set(coldStartStrategy=value)
+
+    @since("2.2.0")
+    def getColdStartStrategy(self):
+        """
+        Gets the value of coldStartStrategy or its default value.
+        """
+        return self.getOrDefault(self.coldStartStrategy)
+
 
 class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable):
     """


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to