Github user MLnick commented on a diff in the pull request:

    https://github.com/apache/spark/pull/18748#discussion_r139141803
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala 
---
    @@ -356,6 +371,40 @@ class ALSModel private[ml] (
       }
     
       /**
    +   * Returns top `numUsers` users recommended for each item id in the 
input data set. Note that if
    +   * there are duplicate ids in the input dataset, only one set of 
recommendations per unique id
    +   * will be returned.
    +   * @param dataset a Dataset containing a column of item ids. The column 
name must match `itemCol`.
    +   * @param numUsers max number of recommendations for each item.
    +   * @return a DataFrame of (itemCol: Int, recommendations), where 
recommendations are
    +   *         stored as an array of (userCol: Int, rating: Float) Rows.
    +   */
    +  @Since("2.3.0")
    +  def recommendForItemSubset(dataset: Dataset[_], numUsers: Int): 
DataFrame = {
    +    val srcFactorSubset = getSourceFactorSubset(dataset, itemFactors, 
$(itemCol))
    +    recommendForAll(srcFactorSubset, userFactors, $(itemCol), $(userCol), 
numUsers)
    +  }
    +
    +  /**
    +   * Returns a subset of a factor DataFrame limited to only those unique 
ids contained
    +   * in the input dataset.
    +   * @param dataset input Dataset containing id column to use to filter 
factors.
    +   * @param factors factor DataFrame to filter.
    +   * @param column column name containing the ids in the input dataset.
    +   * @return DataFrame containing factors only for those ids present in 
both the input dataset and
    +   *         the factor DataFrame.
    +   */
    +  private def getSourceFactorSubset(
    +      dataset: Dataset[_],
    +      factors: DataFrame,
    +      column: String): DataFrame = {
    +    dataset.select(column)
    +      .distinct()
    +      .join(factors, dataset(column) === factors("id"))
    +      .select(factors("id"), factors("features"))
    +  }
    --- End diff --
    
    How does that eliminate the need for `distinct`?
    
    e.g. take a look at the below:
    
    ```
    scala> factors.show
    +---+--------+
    | id|features|
    +---+--------+
    |  0|  [0, 1]|
    |  1|  [1, 2]|
    |  2|  [2, 3]|
    |  3|  [3, 4]|
    +---+--------+
    
    
    scala> dataset.show
    +----+
    |user|
    +----+
    |   0|
    |   0|
    |   3|
    |   0|
    +----+
    
    
    scala> dataset.select("user").join(factors, dataset("user") === 
factors("id"), "left_semi").show
    +----+
    |user|
    +----+
    |   0|
    |   0|
    |   3|
    |   0|
    +----+
    
    scala> dataset.select("user").join(factors, dataset("user") === 
factors("id"), "left_semi").select(factors("id"), factors("features"))
    org.apache.spark.sql.AnalysisException: resolved attribute(s) 
id#75,features#76 missing from user#17 in operator !Project [id#75, 
features#76];;
    !Project [id#75, features#76]
    +- Join LeftSemi, (user#17 = id#75)
       :- Project [user#17]
       :  +- Project [value#15 AS user#17]
       :     +- LocalRelation [value#15]
       +- Project [_1#72 AS id#75, _2#73 AS features#76]
          +- LocalRelation [_1#72, _2#73]
    
      at 
org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.failAnalysis(CheckAnalysis.scala:41)
      at 
org.apache.spark.sql.catalyst.analysis.Analyzer.failAnalysis(Analyzer.scala:89)
      at 
org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1.apply(CheckAnalysis.scala:276)
      at 
org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1.apply(CheckAnalysis.scala:80)
      at 
org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:127)
      at 
org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.checkAnalysis(CheckAnalysis.scala:80)
      at 
org.apache.spark.sql.catalyst.analysis.Analyzer.checkAnalysis(Analyzer.scala:89)
      at 
org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:53)
      at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:69)
      at 
org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$withPlan(Dataset.scala:3103)
      at org.apache.spark.sql.Dataset.select(Dataset.scala:1255)
      ... 48 elided
    ```


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to