Refactored LDA implementation
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/1f98970b Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/1f98970b Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/1f98970b Branch: refs/heads/master Commit: 1f98970bb5f010936bdee7a9610a156e20473696 Parents: 9b2ddcc Author: myui <[email protected]> Authored: Thu Apr 20 17:14:37 2017 +0900 Committer: myui <[email protected]> Committed: Thu Apr 20 17:14:37 2017 +0900 ---------------------------------------------------------------------- .../hivemall/topicmodel/OnlineLDAModel.java | 86 +++++++++++--------- .../java/hivemall/utils/lang/ArrayUtils.java | 4 +- 2 files changed, 51 insertions(+), 39 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1f98970b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java index 3e7ad10..890adac 100644 --- a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java +++ b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java @@ -39,6 +39,9 @@ import org.apache.commons.math3.special.Gamma; public final class OnlineLDAModel { + // --------------------------------- + // HyperParameters + // number of topics private final int _K; @@ -52,25 +55,25 @@ public final class OnlineLDAModel { // in the truly online setting, this can be an estimate of the maximum number of documents that could ever seen private long _D = -1L; - // defined by (tau0 + updateCount)^(-kappa_) - // controls how much old lambda is forgotten - private double _rhot; - // positive value which downweights early iterations @Nonnegative private final double _tau0; // exponential decay rate (i.e., learning rate) which must be in (0.5, 1] to guarantee convergence + @Nonnegative private final double _kappa; + // check convergence in the expectation (E) step + private final double _delta; + + // --------------------------------- + // how many times EM steps are launched; later EM steps do not drastically forget old lambda private long _updateCount = 0L; - // random number generator - @Nonnull - private final GammaDistribution _gd; - private static final double SHAPE = 100.d; - private static final double SCALE = 1.d / SHAPE; + // defined by (tau0 + updateCount)^(-kappa_) + // controls how much old lambda is forgotten + private double _rhot; // parameters @Nonnull @@ -79,9 +82,13 @@ public final class OnlineLDAModel { @Nonnull private final Map<String, float[]> _lambda; - // check convergence in the expectation (E) step - private final double _delta; + // random number generator + @Nonnull + private final GammaDistribution _gd; + private static final double SHAPE = 100.d; + private static final double SCALE = 1.d / SHAPE; + // for mini-batch @Nonnull private final List<Map<String, Float>> _miniBatchMap; private int _miniBatchSize; @@ -134,7 +141,8 @@ public final class OnlineLDAModel { public void train(@Nonnull final String[][] miniBatch) { if (_D <= 0L) { - throw new RuntimeException("Total number of documents MUST be set via `setNumTotalDocs()`"); + throw new IllegalStateException( + "Total number of documents MUST be set via `setNumTotalDocs()`"); } preprocessMiniBatch(miniBatch); @@ -165,7 +173,7 @@ public final class OnlineLDAModel { } } - this._docRatio = (float)((double) _D / _miniBatchSize); + this._docRatio = (float) ((double) _D / _miniBatchSize); } private static void initMiniBatchMap(@Nonnull final String[][] miniBatch, @@ -197,26 +205,29 @@ public final class OnlineLDAModel { } } - private void initParams(boolean gammaWithRandom) { - _phi = new ArrayList<Map<String, float[]>>(); - _gamma = new float[_miniBatchSize][]; + private void initParams(final boolean gammaWithRandom) { + final List<Map<String, float[]>> phi = new ArrayList<Map<String, float[]>>(); + final float[][] gamma = new float[_miniBatchSize][]; for (int d = 0; d < _miniBatchSize; d++) { if (gammaWithRandom) { - _gamma[d] = ArrayUtils.newRandomFloatArray(_K, _gd); + gamma[d] = ArrayUtils.newRandomFloatArray(_K, _gd); } else { - _gamma[d] = ArrayUtils.newInstance(_K, 1.f); + gamma[d] = ArrayUtils.newFloatArray(_K, 1.f); } final Map<String, float[]> phi_d = new HashMap<String, float[]>(); - _phi.add(phi_d); - for (String label : _miniBatchMap.get(d).keySet()) { + phi.add(phi_d); + for (final String label : _miniBatchMap.get(d).keySet()) { phi_d.put(label, new float[_K]); if (!_lambda.containsKey(label)) { // lambda for newly observed word _lambda.put(label, ArrayUtils.newRandomFloatArray(_K, _gd)); } } } + + this._phi = phi; + this._gamma = gamma; } private void eStep() { @@ -231,22 +242,19 @@ public final class OnlineLDAModel { // for digamma(lambdaSum) MathUtils.add(lambdaSum, lambda_label, _K); - float[] digamma_lambda_label = new float[_K]; digamma_lambda.put(label, MathUtils.digamma(lambda_label)); } - final float[] digamma_lambdaSum = MathUtils.digamma(lambdaSum); + final float[] digamma_lambdaSum = MathUtils.digamma(lambdaSum); + // for each of mini-batch documents, update gamma until convergence float[] gamma_d, gammaPrev_d; Map<String, float[]> eLogBeta_d; - - // for each of mini-batch documents, update gamma until convergence for (int d = 0; d < _miniBatchSize; d++) { gamma_d = _gamma[d]; eLogBeta_d = computeElogBetaPerDoc(d, digamma_lambda, digamma_lambdaSum); do { - // (deep) copy the last gamma values - gammaPrev_d = gamma_d.clone(); + gammaPrev_d = gamma_d.clone(); // deep copy the last gamma values updatePhiPerDoc(d, eLogBeta_d); updateGammaPerDoc(d); @@ -256,12 +264,13 @@ public final class OnlineLDAModel { @Nonnull private Map<String, float[]> computeElogBetaPerDoc(@Nonnegative final int d, - @Nonnull Map<String, float[]> digamma_lambda, @Nonnull float[] digamma_lambdaSum) { + @Nonnull final Map<String, float[]> digamma_lambda, + @Nonnull final float[] digamma_lambdaSum) { // Dirichlet expectation (2d) for lambda final Map<String, float[]> eLogBeta_d = new HashMap<String, float[]>(); final Map<String, Float> doc = _miniBatchMap.get(d); - for (String label : doc.keySet()) { + for (final String label : doc.keySet()) { float[] eLogBeta_label = eLogBeta_d.get(label); if (eLogBeta_label == null) { eLogBeta_label = new float[_K]; @@ -276,7 +285,8 @@ public final class OnlineLDAModel { return eLogBeta_d; } - private void updatePhiPerDoc(@Nonnegative final int d, @Nonnull Map<String, float[]> eLogBeta_d) { + private void updatePhiPerDoc(@Nonnegative final int d, + @Nonnull final Map<String, float[]> eLogBeta_d) { // Dirichlet expectation (2d) for gamma final float[] eLogTheta_d = new float[_K]; final float[] gamma_d = _gamma[d]; @@ -288,7 +298,7 @@ public final class OnlineLDAModel { // updating phi w/ normalization final Map<String, float[]> phi_d = _phi.get(d); final Map<String, Float> doc = _miniBatchMap.get(d); - for (String label : doc.keySet()) { + for (String label : doc.keySet()) { final float[] phi_label = phi_d.get(label); final float[] eLogBeta_label = eLogBeta_d.get(label); @@ -340,7 +350,7 @@ public final class OnlineLDAModel { for (String label : _miniBatchMap.get(d).keySet()) { float[] lambdaTilde_label = lambdaTilde.get(label); if (lambdaTilde_label == null) { - lambdaTilde_label = ArrayUtils.newInstance(_K, _eta); + lambdaTilde_label = ArrayUtils.newFloatArray(_K, _eta); lambdaTilde.put(label, lambdaTilde_label); } @@ -358,7 +368,7 @@ public final class OnlineLDAModel { float[] lambdaTilde_label = lambdaTilde.get(label); if (lambdaTilde_label == null) { - lambdaTilde_label = ArrayUtils.newInstance(_K, _eta); + lambdaTilde_label = ArrayUtils.newFloatArray(_K, _eta); } for (int k = 0; k < _K; k++) { @@ -381,8 +391,6 @@ public final class OnlineLDAModel { * Estimates the variational bound over all documents using only the documents passed as mini-batch. */ private float computeApproxBound() { - float score = 0.f; - // prepare final float[] gammaSum = new float[_miniBatchSize]; for (int d = 0; d < _miniBatchSize; d++) { @@ -399,6 +407,7 @@ public final class OnlineLDAModel { final float logGamma_alpha = (float) Gamma.logGamma(_alpha); final float logGamma_alphaSum = (float) Gamma.logGamma(_K * _alpha); + float score = 0.f; for (int d = 0; d < _miniBatchSize; d++) { final float digamma_gammaSum_d = digamma_gammaSum[d]; @@ -410,8 +419,10 @@ public final class OnlineLDAModel { final float[] temp = new float[_K]; float max = Float.MIN_VALUE; for (int k = 0; k < _K; k++) { - final float eLogTheta_dk = (float) Gamma.digamma(_gamma[d][k]) - digamma_gammaSum_d; - final float eLogBeta_kw = (float) Gamma.digamma(lambda_label[k]) - digamma_lambdaSum[k]; + final float eLogTheta_dk = (float) Gamma.digamma(_gamma[d][k]) + - digamma_gammaSum_d; + final float eLogBeta_kw = (float) Gamma.digamma(lambda_label[k]) + - digamma_lambdaSum[k]; temp[k] = eLogTheta_dk + eLogBeta_kw; if (temp[k] > max) { @@ -484,7 +495,8 @@ public final class OnlineLDAModel { return lambda_label[k]; } - public void setLambda(@Nonnull final String label, @Nonnegative final int k, final float lambda_k) { + public void setLambda(@Nonnull final String label, @Nonnegative final int k, + final float lambda_k) { float[] lambda_label = _lambda.get(label); if (lambda_label == null) { lambda_label = ArrayUtils.newRandomFloatArray(_K, _gd); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1f98970b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java index 711aac7..c20c363 100644 --- a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java +++ b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java @@ -719,12 +719,12 @@ public final class ArrayUtils { } @Nonnull - public static float[] newInstance(@Nonnegative int size, float filledValue) { + public static float[] newFloatArray(@Nonnegative int size, float filledValue) { final float[] a = new float[size]; Arrays.fill(a, filledValue); return a; } - + @Nonnull public static float[] newRandomFloatArray(@Nonnegative final int size, @Nonnull final GammaDistribution gd) {
