Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/8329#discussion_r37571490
  
    --- Diff: 
mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala ---
    @@ -575,6 +575,48 @@ class DistributedLDAModel private[clustering] (
         }
       }
     
    +  /**
    +   * Return the top topic for each (doc, term) pair.  I.e., for each 
document, what is the most
    +   * likely topic generating each term?
    +   *
    +   * @return RDD of (doc ID, assignment of top topic index for each term),
    +   *         where the assignment is specified via a pair of zippable 
arrays
    +   *         (term indices, topic indices).  Note that terms will be 
omitted if not present in
    +   *         the document.
    +   */
    +  lazy val topicAssignments: RDD[(Long, Array[Int], Array[Int])] = {
    +    // For reference, compare the below code with the core part of 
EMLDAOptimizer.next().
    +    val eta = topicConcentration
    +    val W = vocabSize
    +    val alpha = docConcentration(0)
    +    val N_k = globalTopicTotals
    +    val sendMsg: EdgeContext[TopicCounts, TokenCount, (Array[Int], 
Array[Int])] => Unit =
    +      (edgeContext) => {
    +        // E-STEP: Compute gamma_{wjk} (smoothed topic distributions).
    +        val scaledTopicDistribution: TopicCounts =
    +          computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, 
eta, alpha)
    +        // For this (doc j, term w), send top topic k to doc vertex.
    +        val topTopic: Int = argmax(scaledTopicDistribution)
    +        val term: Int = index2term(edgeContext.dstId)
    +        edgeContext.sendToSrc((Array(term), Array(topTopic)))
    +      }
    +    val mergeMsg: ((Array[Int], Array[Int]), (Array[Int], Array[Int])) => 
(Array[Int], Array[Int]) =
    +      (terms_topics0, terms_topics1) => {
    +        (terms_topics0._1 ++ terms_topics1._1, terms_topics0._2 ++ 
terms_topics1._2)
    +      }
    +    // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts.
    +    graph.aggregateMessages[(Array[Int], Array[Int])](sendMsg, 
mergeMsg).filter(isDocumentVertex)
    +        .map { case (docID: Long, (terms: Array[Int], topics: Array[Int])) 
=>
    +      val (sortedTerms, sortedTopics) = 
terms.zip(topics).sortBy(_._1).unzip
    --- End diff --
    
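    Please leave a TODO here, because `zip` is not efficient: `terms.zip(topics).sortBy(_._1).unzip` materializes an intermediate array of tuple objects, sorts a copy of it, and then copies it back into two arrays.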


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes to, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to