Github user akopich commented on a diff in the pull request:
https://github.com/apache/spark/pull/18924#discussion_r142625490
--- Diff: mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala ---
@@ -462,36 +462,55 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
     val expElogbetaBc = batch.sparkContext.broadcast(expElogbeta)
     val alpha = this.alpha.asBreeze
     val gammaShape = this.gammaShape
+    val optimizeDocConcentration = this.optimizeDocConcentration
+    // We calculate logphat in the same pass as other statistics, but we only need
+    // it if we are optimizing docConcentration
+    val logphatPartOptionBase = () => if (optimizeDocConcentration) Some(BDV.zeros[Double](k))
+                                      else None
-    val stats: RDD[(BDM[Double], List[BDV[Double]])] = batch.mapPartitions { docs =>
+    val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitions { docs =>
       val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0)
       val stat = BDM.zeros[Double](k, vocabSize)
-      var gammaPart = List[BDV[Double]]()
+      val logphatPartOption = logphatPartOptionBase()
+      var nonEmptyDocCount : Long = 0L
       nonEmptyDocs.foreach { case (_, termCounts: Vector) =>
+        nonEmptyDocCount += 1
         val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference(
           termCounts, expElogbetaBc.value, alpha, gammaShape, k)
-        stat(::, ids) := stat(::, ids).toDenseMatrix + sstats
-        gammaPart = gammad :: gammaPart
+        stat(::, ids) := stat(::, ids) + sstats
+        logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad))
       }
-      Iterator((stat, gammaPart))
-    }.persist(StorageLevel.MEMORY_AND_DISK)
-    val statsSum: BDM[Double] = stats.map(_._1).treeAggregate(BDM.zeros[Double](k, vocabSize))(
-      _ += _, _ += _)
-    val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat(
-      stats.map(_._2).flatMap(list => list).collect().map(_.toDenseMatrix): _*)
-    stats.unpersist()
-    expElogbetaBc.destroy(false)
-    val batchResult = statsSum *:* expElogbeta.t
+      Iterator((stat, logphatPartOption, nonEmptyDocCount))
+    }
+
+    val elementWiseSum = (u : (BDM[Double], Option[BDV[Double]], Long),
+                          v : (BDM[Double], Option[BDV[Double]], Long)) => {
+      u._1 += v._1
+      u._2.foreach(_ += v._2.get)
+      (u._1, u._2, u._3 + v._3)
+    }
+
+    val (statsSum: BDM[Double], logphatOption: Option[BDV[Double]], nonEmptyDocsN : Long) = stats
+      .treeAggregate((BDM.zeros[Double](k, vocabSize), logphatPartOptionBase(), 0L))(
+        elementWiseSum, elementWiseSum
+      )
+    val batchResult = statsSum *:* expElogbeta.t
     // Note that this is an optimization to avoid batch.count
-    updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt)
-    if (optimizeDocConcentration) updateAlpha(gammat)
+    val batchSize = (miniBatchFraction * corpusSize).ceil.toInt
+    updateLambda(batchResult, batchSize)
+
+    logphatOption.foreach(_ /= batchSize.toDouble)
+    logphatOption.foreach(updateAlpha(_, nonEmptyDocsN))
+
+    expElogbetaBc.destroy(false)
+
     this
   }

   /**
-   * Update lambda based on the batch submitted. batchSize can be different for each iteration.
+   * Update lambda based on the batch submitted. nonEmptyDocsN can be different for each iteration.
--- End diff --
Thanks. Comment reverted.
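
For readers following the thread, below is a minimal, self-contained sketch of the single-pass
aggregation pattern the diff introduces: per-partition sufficient statistics, an optional
logphat-style vector, and a non-empty document count are combined with one treeAggregate and an
element-wise-sum combiner, instead of persisting the RDD and collecting gamma rows on the driver.
The identifiers and data are illustrative only, not the actual LDAOptimizer code, and it assumes
spark-sql plus spark-mllib (for Breeze) on the classpath.

    import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV}
    import org.apache.spark.sql.SparkSession

    object OnePassAggregationSketch {
      // (sufficient statistics, optional logphat accumulator, non-empty doc count)
      type Stats = (BDM[Double], Option[BDV[Double]], Long)

      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
        val sc = spark.sparkContext
        val (k, vocabSize) = (4, 10)
        val optimizeDocConcentration = true

        // Stand-in for the per-document contributions produced inside mapPartitions.
        val perDoc = sc.parallelize(1 to 1000).map { _ =>
          val stat = BDM.ones[Double](k, vocabSize)
          val logphatPart = if (optimizeDocConcentration) Some(BDV.ones[Double](k)) else None
          (stat, logphatPart, 1L)
        }

        // Element-wise sum used as both seqOp and combOp, like elementWiseSum in the diff.
        val elementWiseSum = (u: Stats, v: Stats) => {
          u._1 += v._1
          u._2.foreach(_ += v._2.get)
          (u._1, u._2, u._3 + v._3)
        }

        val zero: Stats = (BDM.zeros[Double](k, vocabSize),
          if (optimizeDocConcentration) Some(BDV.zeros[Double](k)) else None, 0L)
        val (statsSum, logphatOption, nonEmptyDocsN) =
          perDoc.treeAggregate(zero)(elementWiseSum, elementWiseSum)

        println(s"docs=$nonEmptyDocsN statsSum(0,0)=${statsSum(0, 0)} " +
          s"logphat(0)=${logphatOption.map(_(0))}")
        spark.stop()
      }
    }

The in-place += on the accumulators is safe because treeAggregate gives each task its own
deserialized copy of the zero value, mirroring how the diff mutates stat and logphatPartOption
per partition.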