Repository: incubator-hivemall Updated Branches: refs/heads/master e4e1531e1 -> 6e24d3a95 (forced update)
Refactored LDA implementation Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/6e24d3a9 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/6e24d3a9 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/6e24d3a9 Branch: refs/heads/master Commit: 6e24d3a958c8e1eeb724ce980bb0c2d23d99a398 Parents: 9669c9d Author: myui <yuin...@gmail.com> Authored: Thu Apr 20 21:01:36 2017 +0900 Committer: myui <yuin...@gmail.com> Committed: Thu Apr 20 21:06:10 2017 +0900 ---------------------------------------------------------------------- .../main/java/hivemall/topicmodel/LDAUDTF.java | 7 +- .../hivemall/topicmodel/OnlineLDAModel.java | 161 +++++++++---------- .../java/hivemall/utils/math/MathUtils.java | 83 +++++++--- 3 files changed, 147 insertions(+), 104 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/6e24d3a9/core/src/main/java/hivemall/topicmodel/LDAUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java index 91ee7a2..9aa15e2 100644 --- a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java +++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java @@ -196,8 +196,8 @@ public class LDAUDTF extends UDTFWithOptions { initModel(); } - int length = wordCountsOI.getListLength(args[0]); - String[] wordCounts = new String[length]; + final int length = wordCountsOI.getListLength(args[0]); + final String[] wordCounts = new String[length]; int j = 0; for (int i = 0; i < length; i++) { Object o = wordCountsOI.getListElement(args[0], i); @@ -208,6 +208,9 @@ public class LDAUDTF extends UDTFWithOptions { wordCounts[j] = s; j++; } + if (j == 0) {// avoid empty documents + return; + } count++; if (isAutoD) { 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/6e24d3a9/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java index 890adac..8fef10c 100644 --- a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java +++ b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java @@ -90,12 +90,12 @@ public final class OnlineLDAModel { // for mini-batch @Nonnull - private final List<Map<String, Float>> _miniBatchMap; + private final List<Map<String, Float>> _miniBatchDocs; private int _miniBatchSize; // for computing perplexity private float _docRatio = 1.f; - private long _wordCount = 0L; + private double _valueSum = 0.d; public OnlineLDAModel(int K, float alpha, double delta) { // for E step only instantiation this(K, alpha, 1 / 20.f, -1L, 1020, 0.7, delta); @@ -125,15 +125,13 @@ public final class OnlineLDAModel { // initialize the parameters this._lambda = new HashMap<String, float[]>(100); - this._miniBatchMap = new ArrayList<Map<String, Float>>(); + this._miniBatchDocs = new ArrayList<Map<String, Float>>(); } /** - * In a truly online setting, total number of documents corresponds to the number of documents - * that have ever seen. In that case, users need to manually set the current max number of documents - * via this method. - * Note that, since the same set of documents could be repeatedly passed to `train()`, - * simply accumulating `_miniBatchSize`s as estimated `_D` is not sufficient. + * In a truly online setting, the total number of documents corresponds to the number of documents that have ever been seen. In that case, users need to + * manually set the current max number of documents via this method.
Note that, since the same set of documents could be repeatedly passed to + * `train()`, simply accumulating `_miniBatchSize`s as estimated `_D` is not sufficient. */ public void setNumTotalDocs(@Nonnegative long D) { this._D = D; @@ -161,34 +159,35 @@ public final class OnlineLDAModel { } private void preprocessMiniBatch(@Nonnull final String[][] miniBatch) { - initMiniBatchMap(miniBatch, _miniBatchMap); + initMiniBatch(miniBatch, _miniBatchDocs); - this._miniBatchSize = _miniBatchMap.size(); + this._miniBatchSize = _miniBatchDocs.size(); // accumulate the number of words for each documents - this._wordCount = 0L; + double valueSum = 0.d; for (int d = 0; d < _miniBatchSize; d++) { - for (float n : _miniBatchMap.get(d).values()) { - this._wordCount += n; + for (Float n : _miniBatchDocs.get(d).values()) { + valueSum += n.floatValue(); } } + this._valueSum = valueSum; this._docRatio = (float) ((double) _D / _miniBatchSize); } - private static void initMiniBatchMap(@Nonnull final String[][] miniBatch, - @Nonnull final List<Map<String, Float>> map) { - map.clear(); + private static void initMiniBatch(@Nonnull final String[][] miniBatch, + @Nonnull final List<Map<String, Float>> docs) { + docs.clear(); final FeatureValue probe = new FeatureValue(); // parse document for (final String[] e : miniBatch) { - if (e == null) { + if (e == null || e.length == 0) { continue; } - final Map<String, Float> docMap = new HashMap<String, Float>(); + final Map<String, Float> doc = new HashMap<String, Float>(); // parse features for (String fv : e) { @@ -198,10 +197,10 @@ public final class OnlineLDAModel { FeatureValue.parseFeatureAsString(fv, probe); String label = probe.getFeatureAsString(); float value = probe.getValueAsFloat(); - docMap.put(label, value); + doc.put(label, Float.valueOf(value)); } - map.add(docMap); + docs.add(doc); } } @@ -218,7 +217,7 @@ public final class OnlineLDAModel { final Map<String, float[]> phi_d = new HashMap<String, float[]>(); phi.add(phi_d); - for 
(final String label : _miniBatchMap.get(d).keySet()) { + for (final String label : _miniBatchDocs.get(d).keySet()) { phi_d.put(label, new float[_K]); if (!_lambda.containsKey(label)) { // lambda for newly observed word _lambda.put(label, ArrayUtils.newRandomFloatArray(_K, _gd)); @@ -233,19 +232,19 @@ public final class OnlineLDAModel { private void eStep() { // since lambda is invariant in the expectation step, // `digamma`s of lambda values for Elogbeta are pre-computed - final float[] lambdaSum = new float[_K]; + final double[] lambdaSum = new double[_K]; final Map<String, float[]> digamma_lambda = new HashMap<String, float[]>(); for (Map.Entry<String, float[]> e : _lambda.entrySet()) { String label = e.getKey(); float[] lambda_label = e.getValue(); // for digamma(lambdaSum) - MathUtils.add(lambdaSum, lambda_label, _K); + MathUtils.add(lambda_label, lambdaSum, _K); digamma_lambda.put(label, MathUtils.digamma(lambda_label)); } - final float[] digamma_lambdaSum = MathUtils.digamma(lambdaSum); + final double[] digamma_lambdaSum = MathUtils.digamma(lambdaSum); // for each of mini-batch documents, update gamma until convergence float[] gamma_d, gammaPrev_d; Map<String, float[]> eLogBeta_d; @@ -265,11 +264,11 @@ public final class OnlineLDAModel { @Nonnull private Map<String, float[]> computeElogBetaPerDoc(@Nonnegative final int d, @Nonnull final Map<String, float[]> digamma_lambda, - @Nonnull final float[] digamma_lambdaSum) { - // Dirichlet expectation (2d) for lambda - final Map<String, float[]> eLogBeta_d = new HashMap<String, float[]>(); - final Map<String, Float> doc = _miniBatchMap.get(d); + @Nonnull final double[] digamma_lambdaSum) { + final Map<String, Float> doc = _miniBatchDocs.get(d); + // Dirichlet expectation (2d) for lambda + final Map<String, float[]> eLogBeta_d = new HashMap<String, float[]>(doc.size()); for (final String label : doc.keySet()) { float[] eLogBeta_label = eLogBeta_d.get(label); if (eLogBeta_label == null) { @@ -278,7 +277,7 @@ public 
final class OnlineLDAModel { } final float[] digamma_lambda_label = digamma_lambda.get(label); for (int k = 0; k < _K; k++) { - eLogBeta_label[k] = digamma_lambda_label[k] - digamma_lambdaSum[k]; + eLogBeta_label[k] = (float) (digamma_lambda_label[k] - digamma_lambdaSum[k]); } } @@ -288,28 +287,27 @@ public final class OnlineLDAModel { private void updatePhiPerDoc(@Nonnegative final int d, @Nonnull final Map<String, float[]> eLogBeta_d) { // Dirichlet expectation (2d) for gamma - final float[] eLogTheta_d = new float[_K]; final float[] gamma_d = _gamma[d]; - final float digamma_gammaSum_d = (float) Gamma.digamma(MathUtils.sum(gamma_d)); + final double digamma_gammaSum_d = Gamma.digamma(MathUtils.sum(gamma_d)); + final double[] eLogTheta_d = new double[_K]; for (int k = 0; k < _K; k++) { - eLogTheta_d[k] = (float) Gamma.digamma(gamma_d[k]) - digamma_gammaSum_d; + eLogTheta_d[k] = Gamma.digamma(gamma_d[k]) - digamma_gammaSum_d; } // updating phi w/ normalization final Map<String, float[]> phi_d = _phi.get(d); - final Map<String, Float> doc = _miniBatchMap.get(d); + final Map<String, Float> doc = _miniBatchDocs.get(d); for (String label : doc.keySet()) { final float[] phi_label = phi_d.get(label); final float[] eLogBeta_label = eLogBeta_d.get(label); - float normalizer = 0.f; + double normalizer = 0.d; for (int k = 0; k < _K; k++) { float phiVal = (float) Math.exp(eLogBeta_label[k] + eLogTheta_d[k]) + 1E-20f; phi_label[k] = phiVal; normalizer += phiVal; } - // normalize for (int k = 0; k < _K; k++) { phi_label[k] /= normalizer; } @@ -317,7 +315,7 @@ public final class OnlineLDAModel { } private void updateGammaPerDoc(@Nonnegative final int d) { - final Map<String, Float> doc = _miniBatchMap.get(d); + final Map<String, Float> doc = _miniBatchDocs.get(d); final Map<String, float[]> phi_d = _phi.get(d); final float[] gamma_d = _gamma[d]; @@ -326,7 +324,7 @@ public final class OnlineLDAModel { } for (Map.Entry<String, Float> e : doc.entrySet()) { final float[] phi_label = 
phi_d.get(e.getKey()); - final float val = e.getValue(); + final float val = e.getValue().floatValue(); for (int k = 0; k < _K; k++) { gamma_d[k] += phi_label[k] * val; } @@ -347,7 +345,7 @@ public final class OnlineLDAModel { final Map<String, float[]> lambdaTilde = new HashMap<String, float[]>(); for (int d = 0; d < _miniBatchSize; d++) { final Map<String, float[]> phi_d = _phi.get(d); - for (String label : _miniBatchMap.get(d).keySet()) { + for (String label : _miniBatchDocs.get(d).keySet()) { float[] lambdaTilde_label = lambdaTilde.get(label); if (lambdaTilde_label == null) { lambdaTilde_label = ArrayUtils.newFloatArray(_K, _eta); @@ -382,73 +380,67 @@ public final class OnlineLDAModel { * Calculate approximate perplexity for the current mini-batch. */ public float computePerplexity() { - float bound = computeApproxBound(); - float perWordBound = bound / (_docRatio * _wordCount); - return (float) Math.exp(-1.f * perWordBound); + double bound = computeApproxBound(); + double perWordBound = bound / (_docRatio * _valueSum); + return (float) Math.exp(-1.d * perWordBound); } /** * Estimates the variational bound over all documents using only the documents passed as mini-batch. 
*/ - private float computeApproxBound() { + private double computeApproxBound() { // prepare - final float[] gammaSum = new float[_miniBatchSize]; + final double[] gammaSum = new double[_miniBatchSize]; for (int d = 0; d < _miniBatchSize; d++) { gammaSum[d] = MathUtils.sum(_gamma[d]); } - final float[] digamma_gammaSum = MathUtils.digamma(gammaSum); + final double[] digamma_gammaSum = MathUtils.digamma(gammaSum); - final float[] lambdaSum = new float[_K]; + final double[] lambdaSum = new double[_K]; for (float[] lambda_label : _lambda.values()) { - MathUtils.add(lambdaSum, lambda_label, _K); + MathUtils.add(lambda_label, lambdaSum, _K); } - final float[] digamma_lambdaSum = MathUtils.digamma(lambdaSum); + final double[] digamma_lambdaSum = MathUtils.digamma(lambdaSum); - final float logGamma_alpha = (float) Gamma.logGamma(_alpha); - final float logGamma_alphaSum = (float) Gamma.logGamma(_K * _alpha); + final double logGamma_alpha = Gamma.logGamma(_alpha); + final double logGamma_alphaSum = Gamma.logGamma(_K * _alpha); - float score = 0.f; + double score = 0.d; for (int d = 0; d < _miniBatchSize; d++) { - final float digamma_gammaSum_d = digamma_gammaSum[d]; + final double digamma_gammaSum_d = digamma_gammaSum[d]; + final float[] gamma_d = _gamma[d]; // E[log p(doc | theta, beta)] - for (Map.Entry<String, Float> e : _miniBatchMap.get(d).entrySet()) { + for (Map.Entry<String, Float> e : _miniBatchDocs.get(d).entrySet()) { final float[] lambda_label = _lambda.get(e.getKey()); // logsumexp( Elogthetad + Elogbetad ) - final float[] temp = new float[_K]; - float max = Float.MIN_VALUE; + final double[] temp = new double[_K]; + double max = Double.MIN_VALUE; for (int k = 0; k < _K; k++) { - final float eLogTheta_dk = (float) Gamma.digamma(_gamma[d][k]) - - digamma_gammaSum_d; - final float eLogBeta_kw = (float) Gamma.digamma(lambda_label[k]) - - digamma_lambdaSum[k]; - - temp[k] = eLogTheta_dk + eLogBeta_kw; - if (temp[k] > max) { - max = temp[k]; + double eLogTheta_dk = 
Gamma.digamma(gamma_d[k]) - digamma_gammaSum_d; + double eLogBeta_kw = Gamma.digamma(lambda_label[k]) - digamma_lambdaSum[k]; + final double tempK = eLogTheta_dk + eLogBeta_kw; + if (tempK > max) { + max = tempK; } + temp[k] = tempK; } - float logsumexp = 0.f; - for (int k = 0; k < _K; k++) { - logsumexp += (float) Math.exp(temp[k] - max); - } - logsumexp = max + (float) Math.log(logsumexp); + double logsumexp = MathUtils.logsumexp(temp, max); // sum( word count * logsumexp(...) ) - score += e.getValue() * logsumexp; + score += e.getValue().floatValue() * logsumexp; } // E[log p(theta | alpha) - log q(theta | gamma)] for (int k = 0; k < _K; k++) { - final float gamma_dk = _gamma[d][k]; + float gamma_dk = gamma_d[k]; // sum( (alpha - gammad) * Elogthetad ) - score += (_alpha - gamma_dk) - * ((float) Gamma.digamma(gamma_dk) - digamma_gammaSum_d); + score += (_alpha - gamma_dk) * (Gamma.digamma(gamma_dk) - digamma_gammaSum_d); // sum( gammaln(gammad) - gammaln(alpha) ) - score += (float) Gamma.logGamma(gamma_dk) - logGamma_alpha; + score += Gamma.logGamma(gamma_dk) - logGamma_alpha; } score += logGamma_alphaSum; // gammaln(sum(alpha)) score -= Gamma.logGamma(gammaSum[d]); // gammaln(sum(gammad)) @@ -458,25 +450,25 @@ public final class OnlineLDAModel { // (i.e., online setting); likelihood should be always roughly on the same scale score *= _docRatio; - final float logGamma_eta = (float) Gamma.logGamma(_eta); - final float logGamma_etaSum = (float) Gamma.logGamma(_eta * _lambda.size()); // vocabulary size * eta + final double logGamma_eta = Gamma.logGamma(_eta); + final double logGamma_etaSum = Gamma.logGamma(_eta * _lambda.size()); // vocabulary size * eta // E[log p(beta | eta) - log q (beta | lambda)] - for (float[] lambda_label : _lambda.values()) { + for (final float[] lambda_label : _lambda.values()) { for (int k = 0; k < _K; k++) { - final float lambda_k = lambda_label[k]; + float lambda_label_k = lambda_label[k]; // sum( (eta - lambda) * Elogbeta ) - score += 
(_eta - lambda_k) - * (float) (Gamma.digamma(lambda_k) - digamma_lambdaSum[k]); + score += (_eta - lambda_label_k) + * (Gamma.digamma(lambda_label_k) - digamma_lambdaSum[k]); // sum( gammaln(lambda) - gammaln(eta) ) - score += (float) Gamma.logGamma(lambda_k) - logGamma_eta; + score += Gamma.logGamma(lambda_label_k) - logGamma_eta; } } for (int k = 0; k < _K; k++) { // sum( gammaln(etaSum) - gammaln( lambdaSum_k ) - score += logGamma_etaSum - (float) Gamma.logGamma(lambdaSum[k]); + score += logGamma_etaSum - Gamma.logGamma(lambdaSum[k]); } return score; @@ -513,7 +505,7 @@ public final class OnlineLDAModel { @Nonnull public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k, @Nonnegative int topN) { - float lambdaSum = 0.f; + double lambdaSum = 0.d; final SortedMap<Float, List<String>> sortedLambda = new TreeMap<Float, List<String>>( Collections.reverseOrder()); @@ -535,7 +527,8 @@ public final class OnlineLDAModel { topN = Math.min(topN, _lambda.keySet().size()); int tt = 0; for (Map.Entry<Float, List<String>> e : sortedLambda.entrySet()) { - ret.put(e.getKey() / lambdaSum, e.getValue()); + float key = (float) (e.getKey().floatValue() / lambdaSum); + ret.put(Float.valueOf(key), e.getValue()); if (++tt == topN) { break; @@ -556,9 +549,9 @@ public final class OnlineLDAModel { // normalize topic distribution final float[] topicDistr = new float[_K]; final float[] gamma0 = _gamma[0]; - final float gammaSum = MathUtils.sum(gamma0); + final double gammaSum = MathUtils.sum(gamma0); for (int k = 0; k < _K; k++) { - topicDistr[k] = gamma0[k] / gammaSum; + topicDistr[k] = (float) (gamma0[k] / gammaSum); } return topicDistr; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/6e24d3a9/core/src/main/java/hivemall/utils/math/MathUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java index 
7fdea55..9b46527 100644 --- a/core/src/main/java/hivemall/utils/math/MathUtils.java +++ b/core/src/main/java/hivemall/utils/math/MathUtils.java @@ -314,44 +314,91 @@ public final class MathUtils { return perm; } - public static float sum(@Nullable final float[] a) { - if (a == null) { - return 0.f; + public static double sum(@Nullable final float[] arr) { + if (arr == null) { + return 0.d; } - float sum = 0.f; - for (float v : a) { + double sum = 0.d; + for (float v : arr) { sum += v; } return sum; } - public static float sum(@Nullable final float[] a, @Nonnegative final int size) { - if (a == null) { - return 0.f; - } - - float sum = 0.f; + public static void add(@Nonnull final float[] src, @Nonnull final float[] dst, final int size) { for (int i = 0; i < size; i++) { - sum += a[i]; + dst[i] += src[i]; } - return sum; } - public static void add(@Nonnull final float[] dst, @Nonnull final float[] toAdd, final int size) { + public static void add(@Nonnull final float[] src, @Nonnull final double[] dst, final int size) { for (int i = 0; i < size; i++) { - dst[i] += toAdd[i]; + dst[i] += src[i]; } } @Nonnull - public static float[] digamma(@Nonnull final float[] a) { - final int k = a.length; + public static float[] digamma(@Nonnull final float[] arr) { + final int k = arr.length; final float[] ret = new float[k]; for (int i = 0; i < k; i++) { - ret[i] = (float) Gamma.digamma(a[i]); + ret[i] = (float) Gamma.digamma(arr[i]); } return ret; } + @Nonnull + public static double[] digamma(@Nonnull final double[] arr) { + final int k = arr.length; + final double[] ret = new double[k]; + for (int i = 0; i < k; i++) { + ret[i] = Gamma.digamma(arr[i]); + } + return ret; + } + + public static float logsumexp(@Nonnull final float[] arr) { + if (arr.length == 0) { + return 0.f; + } + float max = 0.f; + for (final float v : arr) { + if (v > max) { + max = v; + } + } + return logsumexp(arr, max); + } + + public static float logsumexp(@Nonnull final float[] arr, final float max) { + 
double logsumexp = 0.d; + for (final float v : arr) { + logsumexp += Math.exp(v - max); + } + logsumexp = Math.log(logsumexp) + max; + return (float) logsumexp; + } + + public static double logsumexp(@Nonnull final double[] arr) { + if (arr.length == 0) { + return 0.d; + } + double max = 0.d; + for (final double v : arr) { + if (v > max) { + max = v; + } + } + return logsumexp(arr, max); + } + + public static double logsumexp(@Nonnull final double[] arr, final double max) { + double logsumexp = 0.d; + for (final double v : arr) { + logsumexp += Math.exp(v - max); + } + return Math.log(logsumexp) + max; + } + }