Github user MLnick commented on a diff in the pull request:
https://github.com/apache/spark/pull/18748#discussion_r139141803
--- Diff: mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
---
@@ -356,6 +371,40 @@ class ALSModel private[ml] (
}
/**
+ * Returns top `numUsers` users recommended for each item id in the
input data set. Note that if
+ * there are duplicate ids in the input dataset, only one set of
recommendations per unique id
+ * will be returned.
+ * @param dataset a Dataset containing a column of item ids. The column
name must match `itemCol`.
+ * @param numUsers max number of recommendations for each item.
+ * @return a DataFrame of (itemCol: Int, recommendations), where
recommendations are
+ * stored as an array of (userCol: Int, rating: Float) Rows.
+ */
+ @Since("2.3.0")
+ def recommendForItemSubset(dataset: Dataset[_], numUsers: Int):
DataFrame = {
+ val srcFactorSubset = getSourceFactorSubset(dataset, itemFactors,
$(itemCol))
+ recommendForAll(srcFactorSubset, userFactors, $(itemCol), $(userCol),
numUsers)
+ }
+
+ /**
+ * Returns a subset of a factor DataFrame limited to only those unique
ids contained
+ * in the input dataset.
+ * @param dataset input Dataset containing id column to use to filter
factors.
+ * @param factors factor DataFrame to filter.
+ * @param column column name containing the ids in the input dataset.
+ * @return DataFrame containing factors only for those ids present in
both the input dataset and
+ * the factor DataFrame.
+ */
+ private def getSourceFactorSubset(
+ dataset: Dataset[_],
+ factors: DataFrame,
+ column: String): DataFrame = {
+ dataset.select(column)
+ .distinct()
+ .join(factors, dataset(column) === factors("id"))
+ .select(factors("id"), factors("features"))
+ }
--- End diff --
How does that eliminate the need for `distinct`?
e.g. take a look at the below:
```
scala> factors.show
+---+--------+
| id|features|
+---+--------+
| 0| [0, 1]|
| 1| [1, 2]|
| 2| [2, 3]|
| 3| [3, 4]|
+---+--------+
scala> dataset.show
+----+
|user|
+----+
| 0|
| 0|
| 3|
| 0|
+----+
scala> dataset.select("user").join(factors, dataset("user") ===
factors("id"), "left_semi").show
+----+
|user|
+----+
| 0|
| 0|
| 3|
| 0|
+----+
scala> dataset.select("user").join(factors, dataset("user") ===
factors("id"), "left_semi").select(factors("id"), factors("features"))
org.apache.spark.sql.AnalysisException: resolved attribute(s)
id#75,features#76 missing from user#17 in operator !Project [id#75,
features#76];;
!Project [id#75, features#76]
+- Join LeftSemi, (user#17 = id#75)
:- Project [user#17]
: +- Project [value#15 AS user#17]
: +- LocalRelation [value#15]
+- Project [_1#72 AS id#75, _2#73 AS features#76]
+- LocalRelation [_1#72, _2#73]
at
org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.failAnalysis(CheckAnalysis.scala:41)
at
org.apache.spark.sql.catalyst.analysis.Analyzer.failAnalysis(Analyzer.scala:89)
at
org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1.apply(CheckAnalysis.scala:276)
at
org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1.apply(CheckAnalysis.scala:80)
at
org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:127)
at
org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.checkAnalysis(CheckAnalysis.scala:80)
at
org.apache.spark.sql.catalyst.analysis.Analyzer.checkAnalysis(Analyzer.scala:89)
at
org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:53)
at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:69)
at
org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$withPlan(Dataset.scala:3103)
at org.apache.spark.sql.Dataset.select(Dataset.scala:1255)
... 48 elided
```
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]