Repository: spark Updated Branches: refs/heads/master 4d5a005b0 -> 2931e89f0
[SPARK-10736] [ML] Use 1 for all ratings if $(ratingCol) = "" For some implicit datasets, ratings may not exist in the training data. In this case, we can assume all observed pairs to be positive and treat their ratings as 1. This should happen when users set `ratingCol` to an empty string. Author: Yanbo Liang <[email protected]> Closes #8937 from yanboliang/spark-10736. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2931e89f Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2931e89f Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2931e89f Branch: refs/heads/master Commit: 2931e89f0c54248d87f1f84c81137a5a91e142e9 Parents: 4d5a005 Author: Yanbo Liang <[email protected]> Authored: Tue Sep 29 23:58:32 2015 -0700 Committer: Xiangrui Meng <[email protected]> Committed: Tue Sep 29 23:58:32 2015 -0700 ---------------------------------------------------------------------- .../src/main/scala/org/apache/spark/ml/recommendation/ALS.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/2931e89f/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 9a56a75..f6f5281 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -315,9 +315,9 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { override def fit(dataset: DataFrame): ALSModel = { import dataset.sqlContext.implicits._ + val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f) val ratings = dataset - 
.select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), - col($(ratingCol)).cast(FloatType)) + .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), r) .map { row => Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
