Refactored LDA implementation

Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/1f98970b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/1f98970b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/1f98970b

Branch: refs/heads/master
Commit: 1f98970bb5f010936bdee7a9610a156e20473696
Parents: 9b2ddcc
Author: myui <[email protected]>
Authored: Thu Apr 20 17:14:37 2017 +0900
Committer: myui <[email protected]>
Committed: Thu Apr 20 17:14:37 2017 +0900

----------------------------------------------------------------------
 .../hivemall/topicmodel/OnlineLDAModel.java     | 86 +++++++++++---------
 .../java/hivemall/utils/lang/ArrayUtils.java    |  4 +-
 2 files changed, 51 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1f98970b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
index 3e7ad10..890adac 100644
--- a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
+++ b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
@@ -39,6 +39,9 @@ import org.apache.commons.math3.special.Gamma;
 
 public final class OnlineLDAModel {
 
+    // ---------------------------------
+    // HyperParameters
+
     // number of topics
     private final int _K;
 
@@ -52,25 +55,25 @@ public final class OnlineLDAModel {
     // in the truly online setting, this can be an estimate of the maximum 
number of documents that could ever seen
     private long _D = -1L;
 
-    // defined by (tau0 + updateCount)^(-kappa_)
-    // controls how much old lambda is forgotten
-    private double _rhot;
-
     // positive value which downweights early iterations
     @Nonnegative
     private final double _tau0;
 
     // exponential decay rate (i.e., learning rate) which must be in (0.5, 1] 
to guarantee convergence
+    @Nonnegative
     private final double _kappa;
 
+    // check convergence in the expectation (E) step
+    private final double _delta;
+
+    // ---------------------------------
+
     // how many times EM steps are launched; later EM steps do not drastically 
forget old lambda
     private long _updateCount = 0L;
 
-    // random number generator
-    @Nonnull
-    private final GammaDistribution _gd;
-    private static final double SHAPE = 100.d;
-    private static final double SCALE = 1.d / SHAPE;
+    // defined by (tau0 + updateCount)^(-kappa_)
+    // controls how much old lambda is forgotten
+    private double _rhot;
 
     // parameters
     @Nonnull
@@ -79,9 +82,13 @@ public final class OnlineLDAModel {
     @Nonnull
     private final Map<String, float[]> _lambda;
 
-    // check convergence in the expectation (E) step
-    private final double _delta;
+    // random number generator
+    @Nonnull
+    private final GammaDistribution _gd;
+    private static final double SHAPE = 100.d;
+    private static final double SCALE = 1.d / SHAPE;
 
+    // for mini-batch
     @Nonnull
     private final List<Map<String, Float>> _miniBatchMap;
     private int _miniBatchSize;
@@ -134,7 +141,8 @@ public final class OnlineLDAModel {
 
     public void train(@Nonnull final String[][] miniBatch) {
         if (_D <= 0L) {
-            throw new RuntimeException("Total number of documents MUST be set 
via `setNumTotalDocs()`");
+            throw new IllegalStateException(
+                "Total number of documents MUST be set via 
`setNumTotalDocs()`");
         }
 
         preprocessMiniBatch(miniBatch);
@@ -165,7 +173,7 @@ public final class OnlineLDAModel {
             }
         }
 
-        this._docRatio = (float)((double) _D / _miniBatchSize);
+        this._docRatio = (float) ((double) _D / _miniBatchSize);
     }
 
     private static void initMiniBatchMap(@Nonnull final String[][] miniBatch,
@@ -197,26 +205,29 @@ public final class OnlineLDAModel {
         }
     }
 
-    private void initParams(boolean gammaWithRandom) {
-        _phi = new ArrayList<Map<String, float[]>>();
-        _gamma = new float[_miniBatchSize][];
+    private void initParams(final boolean gammaWithRandom) {
+        final List<Map<String, float[]>> phi = new ArrayList<Map<String, 
float[]>>();
+        final float[][] gamma = new float[_miniBatchSize][];
 
         for (int d = 0; d < _miniBatchSize; d++) {
             if (gammaWithRandom) {
-                _gamma[d] = ArrayUtils.newRandomFloatArray(_K, _gd);
+                gamma[d] = ArrayUtils.newRandomFloatArray(_K, _gd);
             } else {
-                _gamma[d] = ArrayUtils.newInstance(_K, 1.f);
+                gamma[d] = ArrayUtils.newFloatArray(_K, 1.f);
             }
 
             final Map<String, float[]> phi_d = new HashMap<String, float[]>();
-            _phi.add(phi_d);
-            for (String label : _miniBatchMap.get(d).keySet()) {
+            phi.add(phi_d);
+            for (final String label : _miniBatchMap.get(d).keySet()) {
                 phi_d.put(label, new float[_K]);
                 if (!_lambda.containsKey(label)) { // lambda for newly 
observed word
                     _lambda.put(label, ArrayUtils.newRandomFloatArray(_K, 
_gd));
                 }
             }
         }
+
+        this._phi = phi;
+        this._gamma = gamma;
     }
 
     private void eStep() {
@@ -231,22 +242,19 @@ public final class OnlineLDAModel {
             // for digamma(lambdaSum)
             MathUtils.add(lambdaSum, lambda_label, _K);
 
-            float[] digamma_lambda_label = new float[_K];
             digamma_lambda.put(label, MathUtils.digamma(lambda_label));
         }
-        final float[] digamma_lambdaSum = MathUtils.digamma(lambdaSum);
 
+        final float[] digamma_lambdaSum = MathUtils.digamma(lambdaSum);
+        // for each of mini-batch documents, update gamma until convergence
         float[] gamma_d, gammaPrev_d;
         Map<String, float[]> eLogBeta_d;
-
-        // for each of mini-batch documents, update gamma until convergence
         for (int d = 0; d < _miniBatchSize; d++) {
             gamma_d = _gamma[d];
             eLogBeta_d = computeElogBetaPerDoc(d, digamma_lambda, 
digamma_lambdaSum);
 
             do {
-                // (deep) copy the last gamma values
-                gammaPrev_d = gamma_d.clone();
+                gammaPrev_d = gamma_d.clone(); // deep copy the last gamma 
values
 
                 updatePhiPerDoc(d, eLogBeta_d);
                 updateGammaPerDoc(d);
@@ -256,12 +264,13 @@ public final class OnlineLDAModel {
 
     @Nonnull
     private Map<String, float[]> computeElogBetaPerDoc(@Nonnegative final int 
d,
-            @Nonnull Map<String, float[]> digamma_lambda, @Nonnull float[] 
digamma_lambdaSum) {
+            @Nonnull final Map<String, float[]> digamma_lambda,
+            @Nonnull final float[] digamma_lambdaSum) {
         // Dirichlet expectation (2d) for lambda
         final Map<String, float[]> eLogBeta_d = new HashMap<String, float[]>();
         final Map<String, Float> doc = _miniBatchMap.get(d);
 
-        for (String label : doc.keySet()) {
+        for (final String label : doc.keySet()) {
             float[] eLogBeta_label = eLogBeta_d.get(label);
             if (eLogBeta_label == null) {
                 eLogBeta_label = new float[_K];
@@ -276,7 +285,8 @@ public final class OnlineLDAModel {
         return eLogBeta_d;
     }
 
-    private void updatePhiPerDoc(@Nonnegative final int d, @Nonnull 
Map<String, float[]> eLogBeta_d) {
+    private void updatePhiPerDoc(@Nonnegative final int d,
+            @Nonnull final Map<String, float[]> eLogBeta_d) {
         // Dirichlet expectation (2d) for gamma
         final float[] eLogTheta_d = new float[_K];
         final float[] gamma_d = _gamma[d];
@@ -288,7 +298,7 @@ public final class OnlineLDAModel {
         // updating phi w/ normalization
         final Map<String, float[]> phi_d = _phi.get(d);
         final Map<String, Float> doc = _miniBatchMap.get(d);
-        for (String label :  doc.keySet()) {
+        for (String label : doc.keySet()) {
             final float[] phi_label = phi_d.get(label);
             final float[] eLogBeta_label = eLogBeta_d.get(label);
 
@@ -340,7 +350,7 @@ public final class OnlineLDAModel {
             for (String label : _miniBatchMap.get(d).keySet()) {
                 float[] lambdaTilde_label = lambdaTilde.get(label);
                 if (lambdaTilde_label == null) {
-                    lambdaTilde_label = ArrayUtils.newInstance(_K, _eta);
+                    lambdaTilde_label = ArrayUtils.newFloatArray(_K, _eta);
                     lambdaTilde.put(label, lambdaTilde_label);
                 }
 
@@ -358,7 +368,7 @@ public final class OnlineLDAModel {
 
             float[] lambdaTilde_label = lambdaTilde.get(label);
             if (lambdaTilde_label == null) {
-                lambdaTilde_label = ArrayUtils.newInstance(_K, _eta);
+                lambdaTilde_label = ArrayUtils.newFloatArray(_K, _eta);
             }
 
             for (int k = 0; k < _K; k++) {
@@ -381,8 +391,6 @@ public final class OnlineLDAModel {
      * Estimates the variational bound over all documents using only the 
documents passed as mini-batch.
      */
     private float computeApproxBound() {
-        float score = 0.f;
-
         // prepare
         final float[] gammaSum = new float[_miniBatchSize];
         for (int d = 0; d < _miniBatchSize; d++) {
@@ -399,6 +407,7 @@ public final class OnlineLDAModel {
         final float logGamma_alpha = (float) Gamma.logGamma(_alpha);
         final float logGamma_alphaSum = (float) Gamma.logGamma(_K * _alpha);
 
+        float score = 0.f;
         for (int d = 0; d < _miniBatchSize; d++) {
             final float digamma_gammaSum_d = digamma_gammaSum[d];
 
@@ -410,8 +419,10 @@ public final class OnlineLDAModel {
                 final float[] temp = new float[_K];
                 float max = Float.MIN_VALUE;
                 for (int k = 0; k < _K; k++) {
-                    final float eLogTheta_dk = (float) 
Gamma.digamma(_gamma[d][k]) - digamma_gammaSum_d;
-                    final float eLogBeta_kw = (float) 
Gamma.digamma(lambda_label[k]) - digamma_lambdaSum[k];
+                    final float eLogTheta_dk = (float) 
Gamma.digamma(_gamma[d][k])
+                            - digamma_gammaSum_d;
+                    final float eLogBeta_kw = (float) 
Gamma.digamma(lambda_label[k])
+                            - digamma_lambdaSum[k];
 
                     temp[k] = eLogTheta_dk + eLogBeta_kw;
                     if (temp[k] > max) {
@@ -484,7 +495,8 @@ public final class OnlineLDAModel {
         return lambda_label[k];
     }
 
-    public void setLambda(@Nonnull final String label, @Nonnegative final int 
k, final float lambda_k) {
+    public void setLambda(@Nonnull final String label, @Nonnegative final int 
k,
+            final float lambda_k) {
         float[] lambda_label = _lambda.get(label);
         if (lambda_label == null) {
             lambda_label = ArrayUtils.newRandomFloatArray(_K, _gd);

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1f98970b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
index 711aac7..c20c363 100644
--- a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
+++ b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
@@ -719,12 +719,12 @@ public final class ArrayUtils {
     }
 
     @Nonnull
-    public static float[] newInstance(@Nonnegative int size, float 
filledValue) {
+    public static float[] newFloatArray(@Nonnegative int size, float 
filledValue) {
         final float[] a = new float[size];
         Arrays.fill(a, filledValue);
         return a;
     }
-    
+
     @Nonnull
     public static float[] newRandomFloatArray(@Nonnegative final int size,
             @Nonnull final GammaDistribution gd) {

Reply via email to