Close #89: [HIVEMALL-120] Refactor on LDA/pLSA's mini-batch & buffered iteration logic
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/0495ffad Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/0495ffad Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/0495ffad Branch: refs/heads/master Commit: 0495ffadbc42bffa36cb583622708ae1fa65a44e Parents: bfc5b75 Author: Takuya Kitazawa <[email protected]> Authored: Tue Jun 27 13:53:45 2017 +0900 Committer: Makoto Yui <[email protected]> Committed: Tue Jun 27 13:53:45 2017 +0900 ---------------------------------------------------------------------- .../AbstractProbabilisticTopicModel.java | 98 ++++ .../topicmodel/IncrementalPLSAModel.java | 51 +- .../hivemall/topicmodel/LDAPredictUDAF.java | 2 +- .../main/java/hivemall/topicmodel/LDAUDTF.java | 503 +------------------ .../hivemall/topicmodel/OnlineLDAModel.java | 82 +-- .../hivemall/topicmodel/PLSAPredictUDAF.java | 2 +- .../main/java/hivemall/topicmodel/PLSAUDTF.java | 490 +----------------- .../ProbabilisticTopicModelBaseUDTF.java | 487 ++++++++++++++++++ .../topicmodel/IncrementalPLSAModelTest.java | 8 +- .../java/hivemall/topicmodel/LDAUDTFTest.java | 14 +- .../hivemall/topicmodel/OnlineLDAModelTest.java | 4 +- .../java/hivemall/topicmodel/PLSAUDTFTest.java | 14 +- 12 files changed, 653 insertions(+), 1102 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java b/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java new file mode 100644 index 0000000..3c097e2 --- /dev/null +++ b/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java @@ -0,0 +1,98 @@ +/* 
+ * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.topicmodel; + +import hivemall.annotations.VisibleForTesting; +import hivemall.model.FeatureValue; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.util.*; + +public abstract class AbstractProbabilisticTopicModel { + + // number of topics + protected final int _K; + + // total number of documents + protected long _D; + + // for mini-batch + @Nonnull + protected final List<Map<String, Float>> _miniBatchDocs; + protected int _miniBatchSize; + + public AbstractProbabilisticTopicModel(int K) { + this._K = K; + this._D = 0L; + this._miniBatchDocs = new ArrayList<Map<String, Float>>(); + } + + protected static void initMiniBatch(@Nonnull final String[][] miniBatch, + @Nonnull final List<Map<String, Float>> docs) { + docs.clear(); + + final FeatureValue probe = new FeatureValue(); + + // parse document + for (final String[] e : miniBatch) { + if (e == null || e.length == 0) { + continue; + } + + final Map<String, Float> doc = new HashMap<String, Float>(); + + // parse features + for (String fv : e) { + if (fv == null) { + continue; + } + FeatureValue.parseFeatureAsString(fv, probe); + String label = 
probe.getFeatureAsString(); + float value = probe.getValueAsFloat(); + doc.put(label, Float.valueOf(value)); + } + + docs.add(doc); + } + } + + public void accumulateDocCount() { + this._D += 1; + } + + public long getDocCount() { + return _D; + } + + public abstract void train(@Nonnull final String[][] miniBatch); + + public abstract float computePerplexity(); + + @Nonnull + public abstract SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k); + + @Nonnull + public abstract float[] getTopicDistribution(@Nonnull final String[] doc); + + @VisibleForTesting + abstract float getWordScore(@Nonnull final String word, @Nonnegative final int topic); + + public abstract void setWordScore(@Nonnull final String word, @Nonnegative final int topic, final float score); +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java index 6eef23e..b99e670 100644 --- a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java +++ b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java @@ -20,9 +20,10 @@ package hivemall.topicmodel; import static hivemall.utils.lang.ArrayUtils.newRandomFloatArray; import static hivemall.utils.math.MathUtils.l1normalize; + +import hivemall.annotations.VisibleForTesting; import hivemall.math.random.PRNG; import hivemall.math.random.RandomNumberGeneratorFactory; -import hivemall.model.FeatureValue; import hivemall.utils.math.MathUtils; import java.util.ArrayList; @@ -37,14 +38,11 @@ import java.util.TreeMap; import javax.annotation.Nonnegative; import javax.annotation.Nonnull; -public final class IncrementalPLSAModel { +public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel { // 
--------------------------------- // HyperParameters - // number of topics - private final int _K; - // control how much P(w|z) update is affected by the last value private final float _alpha; @@ -65,20 +63,15 @@ public final class IncrementalPLSAModel { private List<float[]> _p_dz; // P(z|d) probability of topics for documents private Map<String, float[]> _p_zw; // P(w|z) probability of words for each topic - @Nonnull - private final List<Map<String, Float>> _miniBatchDocs; - private int _miniBatchSize; - public IncrementalPLSAModel(int K, float alpha, double delta) { - this._K = K; + super(K); + this._alpha = alpha; this._delta = delta; this._rnd = RandomNumberGeneratorFactory.createPRNG(1001); this._p_zw = new HashMap<String, float[]>(); - - this._miniBatchDocs = new ArrayList<Map<String, Float>>(); } public void train(@Nonnull final String[][] miniBatch) { @@ -106,35 +99,6 @@ public final class IncrementalPLSAModel { } } - private static void initMiniBatch(@Nonnull final String[][] miniBatch, - @Nonnull final List<Map<String, Float>> docs) { - docs.clear(); - - final FeatureValue probe = new FeatureValue(); - - // parse document - for (final String[] e : miniBatch) { - if (e == null || e.length == 0) { - continue; - } - - final Map<String, Float> doc = new HashMap<String, Float>(); - - // parse features - for (String fv : e) { - if (fv == null) { - continue; - } - FeatureValue.parseFeatureAsString(fv, probe); - String word = probe.getFeatureAsString(); - float value = probe.getValueAsFloat(); - doc.put(word, Float.valueOf(value)); - } - - docs.add(doc); - } - } - private void initParams() { final List<float[]> p_dz = new ArrayList<float[]>(); final List<Map<String, float[]>> p_dwz = new ArrayList<Map<String, float[]>>(); @@ -302,11 +266,12 @@ public final class IncrementalPLSAModel { return _p_dz.get(0); } - public float getProbability(@Nonnull final String w, @Nonnegative final int z) { + @VisibleForTesting + float getWordScore(@Nonnull final String w, 
@Nonnegative final int z) { return _p_zw.get(w)[z]; } - public void setProbability(@Nonnull final String w, @Nonnegative final int z, final float prob) { + public void setWordScore(@Nonnull final String w, @Nonnegative final int z, final float prob) { float[] prob_label = _p_zw.get(w); if (prob_label == null) { prob_label = newRandomFloatArray(_K, _rnd); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java index 03779b0..94d510a 100644 --- a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java +++ b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java @@ -471,7 +471,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver { for (int k = 0; k < topics; k++) { final float lambda_k = lambda_word.get(k).floatValue(); if (lambda_k != -1.f) { - model.setLambda(word, k, lambda_k); + model.setWordScore(word, k, lambda_k); } } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/LDAUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java index de57518..41386a4 100644 --- a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java +++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java @@ -18,27 +18,7 @@ */ package hivemall.topicmodel; -import hivemall.UDTFWithOptions; -import hivemall.annotations.VisibleForTesting; -import hivemall.utils.hadoop.HiveUtils; -import hivemall.utils.io.FileUtils; -import hivemall.utils.io.NIOUtils; -import hivemall.utils.io.NioStatefullSegment; -import hivemall.utils.lang.NumberUtils; import hivemall.utils.lang.Primitives; -import 
hivemall.utils.lang.SizeOf; - -import java.io.File; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.SortedMap; - -import javax.annotation.Nonnegative; -import javax.annotation.Nonnull; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; @@ -46,96 +26,52 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; -import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; -import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; -import org.apache.hadoop.io.FloatWritable; -import org.apache.hadoop.io.IntWritable; -import org.apache.hadoop.io.Text; -import org.apache.hadoop.mapred.Counters; -import org.apache.hadoop.mapred.Reporter; @Description(name = "train_lda", value = "_FUNC_(array<string> words[, const string options])" + " - Returns a relation consists of <int topic, string word, float score>") -public class LDAUDTF extends UDTFWithOptions { +public class LDAUDTF extends ProbabilisticTopicModelBaseUDTF { private static final Log logger = LogFactory.getLog(LDAUDTF.class); - public static final int DEFAULT_TOPICS = 10; public static final double DEFAULT_DELTA = 1E-3d; // Options - protected int topics; protected float alpha; protected float eta; protected long numDocs; protected double tau0; protected double kappa; - protected int iterations; protected double delta; - protected double eps; - protected int miniBatchSize; - - // if 
`num_docs` option is not given, this flag will be true - // in that case, UDTF automatically sets `count` value to the _D parameter in an online LDA model - protected boolean isAutoD; - - // number of proceeded training samples - protected long count; - - protected String[][] miniBatch; - protected int miniBatchCount; - - protected transient OnlineLDAModel model; - - protected ListObjectInspector wordCountsOI; - - // for iterations - protected NioStatefullSegment fileIO; - protected ByteBuffer inputBuf; public LDAUDTF() { - this.topics = DEFAULT_TOPICS; + super(); + this.alpha = 1.f / topics; this.eta = 1.f / topics; this.numDocs = -1L; this.tau0 = 64.d; this.kappa = 0.7; - this.iterations = 10; this.delta = DEFAULT_DELTA; - this.eps = 1E-1d; - this.miniBatchSize = 128; // if 1, truly online setting } @Override protected Options getOptions() { - Options opts = new Options(); - opts.addOption("k", "topics", true, "The number of topics [default: 10]"); + Options opts = super.getOptions(); opts.addOption("alpha", true, "The hyperparameter for theta [default: 1/k]"); opts.addOption("eta", true, "The hyperparameter for beta [default: 1/k]"); opts.addOption("d", "num_docs", true, "The total number of documents [default: auto]"); opts.addOption("tau", "tau0", true, "The parameter which downweights early iterations [default: 64.0]"); opts.addOption("kappa", true, "Exponential decay rate (i.e., learning rate) [default: 0.7]"); - opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 10]"); opts.addOption("delta", true, "Check convergence in the expectation step [default: 1E-3]"); - opts.addOption("eps", "epsilon", true, - "Check convergence based on the difference of perplexity [default: 1E-1]"); - opts.addOption("s", "mini_batch_size", true, - "Repeat model updating per mini-batch [default: 128]"); return opts; } @Override protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { - CommandLine cl = null; + 
CommandLine cl = super.processOptions(argOIs); - if (argOIs.length >= 2) { - String rawArgs = HiveUtils.getConstString(argOIs[1]); - cl = parseOptions(rawArgs); - this.topics = Primitives.parseInt(cl.getOptionValue("topics"), DEFAULT_TOPICS); + if (cl != null) { this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topics); this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f / topics); this.numDocs = Primitives.parseLong(cl.getOptionValue("num_docs"), -1L); @@ -147,436 +83,13 @@ public class LDAUDTF extends UDTFWithOptions { if (kappa <= 0.5 || kappa > 1.d) { throw new UDFArgumentException("'-kappa' must be in (0.5, 1.0]: " + kappa); } - this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 10); - if (iterations < 1) { - throw new UDFArgumentException( - "'-iterations' must be greater than or equals to 1: " + iterations); - } this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), DEFAULT_DELTA); - this.eps = Primitives.parseDouble(cl.getOptionValue("epsilon"), 1E-1d); - this.miniBatchSize = Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 128); } return cl; } - @Override - public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { - if (argOIs.length < 1) { - throw new UDFArgumentException( - "_FUNC_ takes 1 arguments: array<string> words [, const string options]"); - } - - this.wordCountsOI = HiveUtils.asListOI(argOIs[0]); - HiveUtils.validateFeatureOI(wordCountsOI.getListElementObjectInspector()); - - processOptions(argOIs); - - this.model = null; - this.count = 0L; - this.isAutoD = (numDocs < 0L); - this.miniBatch = new String[miniBatchSize][]; - this.miniBatchCount = 0; - - ArrayList<String> fieldNames = new ArrayList<String>(); - ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); - fieldNames.add("topic"); - fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); - fieldNames.add("word"); - 
fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); - fieldNames.add("score"); - fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); - - return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); - } - - protected void initModel() { - this.model = new OnlineLDAModel(topics, alpha, eta, numDocs, tau0, kappa, delta); - } - - @Override - public void process(Object[] args) throws HiveException { - if (model == null) { - initModel(); - } - - final int length = wordCountsOI.getListLength(args[0]); - final String[] wordCounts = new String[length]; - int j = 0; - for (int i = 0; i < length; i++) { - Object o = wordCountsOI.getListElement(args[0], i); - if (o == null) { - throw new HiveException("Given feature vector contains invalid elements"); - } - String s = o.toString(); - wordCounts[j] = s; - j++; - } - if (j == 0) {// avoid empty documents - return; - } - - count++; - if (isAutoD) { - model.setNumTotalDocs(count); - } - - recordTrainSampleToTempFile(wordCounts); - - miniBatch[miniBatchCount] = wordCounts; - miniBatchCount++; - - if (miniBatchCount == miniBatchSize) { - model.train(miniBatch); - Arrays.fill(miniBatch, null); // clear - miniBatchCount = 0; - } - } - - protected void recordTrainSampleToTempFile(@Nonnull final String[] wordCounts) - throws HiveException { - if (iterations == 1) { - return; - } - - ByteBuffer buf = inputBuf; - NioStatefullSegment dst = fileIO; - - if (buf == null) { - final File file; - try { - file = File.createTempFile("hivemall_lda", ".sgmt"); - file.deleteOnExit(); - if (!file.canWrite()) { - throw new UDFArgumentException("Cannot write a temporary file: " - + file.getAbsolutePath()); - } - logger.info("Record training samples to a file: " + file.getAbsolutePath()); - } catch (IOException ioe) { - throw new UDFArgumentException(ioe); - } catch (Throwable e) { - throw new UDFArgumentException(e); - } - this.inputBuf = buf = ByteBuffer.allocateDirect(1024 * 
1024); // 1 MB - this.fileIO = dst = new NioStatefullSegment(file, false); - } - - // requiredRecordBytes, wordCounts length, wc1 length, wc1 string, wc2 length, wc2 string, ... - int wcLengthTotal = 0; - for (String wc : wordCounts) { - if (wc == null) { - continue; - } - wcLengthTotal += wc.length(); - } - int requiredRecordBytes = SizeOf.INT * 2 + SizeOf.INT * wordCounts.length + wcLengthTotal - * SizeOf.CHAR; - - int remain = buf.remaining(); - if (remain < requiredRecordBytes) { - writeBuffer(buf, dst); - } - - buf.putInt(requiredRecordBytes); - buf.putInt(wordCounts.length); - for (String wc : wordCounts) { - NIOUtils.putString(wc, buf); - } - } - - private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefullSegment dst) - throws HiveException { - srcBuf.flip(); - try { - dst.write(srcBuf); - } catch (IOException e) { - throw new HiveException("Exception causes while writing a buffer to file", e); - } - srcBuf.clear(); - } - - @Override - public void close() throws HiveException { - if (count == 0) { - this.model = null; - return; - } - if (miniBatchCount > 0) { // update for remaining samples - model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount)); - } - if (iterations > 1) { - runIterativeTraining(iterations); - } - forwardModel(); - this.model = null; - } - - protected final void runIterativeTraining(@Nonnegative final int iterations) - throws HiveException { - final ByteBuffer buf = this.inputBuf; - final NioStatefullSegment dst = this.fileIO; - assert (buf != null); - assert (dst != null); - final long numTrainingExamples = count; - - final Reporter reporter = getReporter(); - final Counters.Counter iterCounter = (reporter == null) ? 
null : reporter.getCounter( - "hivemall.lda.OnlineLDA$Counter", "iteration"); - - try { - if (dst.getPosition() == 0L) {// run iterations w/o temporary file - if (buf.position() == 0) { - return; // no training example - } - buf.flip(); - - int iter = 2; - float perplexityPrev = Float.MAX_VALUE; - float perplexity; - int numTrain; - for (; iter <= iterations; iter++) { - perplexity = 0.f; - numTrain = 0; - - reportProgress(reporter); - setCounterValue(iterCounter, iter); - - Arrays.fill(miniBatch, null); // clear - miniBatchCount = 0; - - while (buf.remaining() > 0) { - int recordBytes = buf.getInt(); - assert (recordBytes > 0) : recordBytes; - int wcLength = buf.getInt(); - final String[] wordCounts = new String[wcLength]; - for (int j = 0; j < wcLength; j++) { - wordCounts[j] = NIOUtils.getString(buf); - } - - miniBatch[miniBatchCount] = wordCounts; - miniBatchCount++; - - if (miniBatchCount == miniBatchSize) { - model.train(miniBatch); - perplexity += model.computePerplexity(); - numTrain++; - - Arrays.fill(miniBatch, null); // clear - miniBatchCount = 0; - } - } - buf.rewind(); - - // update for remaining samples - if (miniBatchCount > 0) { // update for remaining samples - model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount)); - perplexity += model.computePerplexity(); - numTrain++; - } - - logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain); - perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches - if (Math.abs(perplexityPrev - perplexity) < eps) { - break; - } - perplexityPrev = perplexity; - } - logger.info("Performed " - + Math.min(iter, iterations) - + " iterations of " - + NumberUtils.formatNumber(numTrainingExamples) - + " training examples on memory (thus " - + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations)) - + " training updates in total) "); - } else {// read training examples in the temporary file and invoke train for each example - // write training examples in buffer to a 
temporary file - if (buf.remaining() > 0) { - writeBuffer(buf, dst); - } - try { - dst.flush(); - } catch (IOException e) { - throw new HiveException("Failed to flush a file: " - + dst.getFile().getAbsolutePath(), e); - } - if (logger.isInfoEnabled()) { - File tmpFile = dst.getFile(); - logger.info("Wrote " + numTrainingExamples - + " records to a temporary file for iterative training: " - + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile) - + ")"); - } - - // run iterations - int iter = 2; - float perplexityPrev = Float.MAX_VALUE; - float perplexity; - int numTrain; - for (; iter <= iterations; iter++) { - perplexity = 0.f; - numTrain = 0; - - Arrays.fill(miniBatch, null); // clear - miniBatchCount = 0; - - setCounterValue(iterCounter, iter); - - buf.clear(); - dst.resetPosition(); - while (true) { - reportProgress(reporter); - // TODO prefetch - // writes training examples to a buffer in the temporary file - final int bytesRead; - try { - bytesRead = dst.read(buf); - } catch (IOException e) { - throw new HiveException("Failed to read a file: " - + dst.getFile().getAbsolutePath(), e); - } - if (bytesRead == 0) { // reached file EOF - break; - } - assert (bytesRead > 0) : bytesRead; - - // reads training examples from a buffer - buf.flip(); - int remain = buf.remaining(); - if (remain < SizeOf.INT) { - throw new HiveException("Illegal file format was detected"); - } - while (remain >= SizeOf.INT) { - int pos = buf.position(); - int recordBytes = buf.getInt() - SizeOf.INT; - remain -= SizeOf.INT; - if (remain < recordBytes) { - buf.position(pos); - break; - } - - int wcLength = buf.getInt(); - final String[] wordCounts = new String[wcLength]; - for (int j = 0; j < wcLength; j++) { - wordCounts[j] = NIOUtils.getString(buf); - } - - miniBatch[miniBatchCount] = wordCounts; - miniBatchCount++; - - if (miniBatchCount == miniBatchSize) { - model.train(miniBatch); - perplexity += model.computePerplexity(); - numTrain++; - - Arrays.fill(miniBatch, null); 
// clear - miniBatchCount = 0; - } - - remain -= recordBytes; - } - buf.compact(); - } - - // update for remaining samples - if (miniBatchCount > 0) { // update for remaining samples - model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount)); - perplexity += model.computePerplexity(); - numTrain++; - } - - logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain); - perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches - if (Math.abs(perplexityPrev - perplexity) < eps) { - break; - } - perplexityPrev = perplexity; - } - logger.info("Performed " - + Math.min(iter, iterations) - + " iterations of " - + NumberUtils.formatNumber(numTrainingExamples) - + " training examples on a secondary storage (thus " - + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations)) - + " training updates in total)"); - } - } catch (Throwable e) { - throw new HiveException("Exception caused in the iterative training", e); - } finally { - // delete the temporary file and release resources - try { - dst.close(true); - } catch (IOException e) { - throw new HiveException("Failed to close a file: " - + dst.getFile().getAbsolutePath(), e); - } - this.inputBuf = null; - this.fileIO = null; - } - } - - protected void forwardModel() throws HiveException { - final IntWritable topicIdx = new IntWritable(); - final Text word = new Text(); - final FloatWritable score = new FloatWritable(); - - final Object[] forwardObjs = new Object[3]; - forwardObjs[0] = topicIdx; - forwardObjs[1] = word; - forwardObjs[2] = score; - - for (int k = 0; k < topics; k++) { - topicIdx.set(k); - - final SortedMap<Float, List<String>> topicWords = model.getTopicWords(k); - for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) { - score.set(e.getKey()); - List<String> words = e.getValue(); - for (int i = 0; i < words.size(); i++) { - word.set(words.get(i)); - forward(forwardObjs); - } - } - } - - logger.info("Forwarded topic words each of " + topics + " topics"); 
- } - - /* - * For testing: - */ - - @VisibleForTesting - void closeWithoutModelReset() throws HiveException { - // launch close(), but not forward & clear model - if (count == 0) { - this.model = null; - return; - } - if (miniBatchCount > 0) { // update for remaining samples - model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount)); - } - if (iterations > 1) { - runIterativeTraining(iterations); - } - } - - @VisibleForTesting - double getLambda(String label, int k) { - return model.getLambda(label, k); - } - - @VisibleForTesting - SortedMap<Float, List<String>> getTopicWords(int k) { - return model.getTopicWords(k); - } - - @VisibleForTesting - SortedMap<Float, List<String>> getTopicWords(int k, int topN) { - return model.getTopicWords(k, topN); - } - - @VisibleForTesting - float[] getTopicDistribution(@Nonnull String[] doc) { - return model.getTopicDistribution(doc); + protected AbstractProbabilisticTopicModel createModel() { + return new OnlineLDAModel(topics, alpha, eta, numDocs, tau0, kappa, delta); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java index 8fef10c..4a7531c 100644 --- a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java +++ b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java @@ -19,7 +19,6 @@ package hivemall.topicmodel; import hivemall.annotations.VisibleForTesting; -import hivemall.model.FeatureValue; import hivemall.utils.lang.ArrayUtils; import hivemall.utils.math.MathUtils; @@ -37,24 +36,17 @@ import javax.annotation.Nonnull; import org.apache.commons.math3.distribution.GammaDistribution; import org.apache.commons.math3.special.Gamma; -public final class OnlineLDAModel { +public final class OnlineLDAModel extends 
AbstractProbabilisticTopicModel { // --------------------------------- // HyperParameters - // number of topics - private final int _K; - // prior on weight vectors "theta ~ Dir(alpha_)" private final float _alpha; // prior on topics "beta" private final float _eta; - // total number of documents - // in the truly online setting, this can be an estimate of the maximum number of documents that could ever seen - private long _D = -1L; - // positive value which downweights early iterations @Nonnegative private final double _tau0; @@ -75,6 +67,10 @@ public final class OnlineLDAModel { // controls how much old lambda is forgotten private double _rhot; + // if `num_docs` option is not given, this flag will be true + // in that case, UDTF automatically sets `count` value to the _D parameter in an online LDA model + private final boolean _isAutoD; + // parameters @Nonnull private List<Map<String, float[]>> _phi; @@ -88,11 +84,6 @@ public final class OnlineLDAModel { private static final double SHAPE = 100.d; private static final double SCALE = 1.d / SHAPE; - // for mini-batch - @Nonnull - private final List<Map<String, Float>> _miniBatchDocs; - private int _miniBatchSize; - // for computing perplexity private float _docRatio = 1.f; private double _valueSum = 0.d; @@ -103,6 +94,8 @@ public final class OnlineLDAModel { public OnlineLDAModel(int K, float alpha, float eta, long D, double tau0, double kappa, double delta) { + super(K); + if (tau0 < 0.d) { throw new IllegalArgumentException("tau0 MUST be positive: " + tau0); } @@ -110,7 +103,6 @@ public final class OnlineLDAModel { throw new IllegalArgumentException("kappa MUST be in (0.5, 1.0]: " + kappa); } - this._K = K; this._alpha = alpha; this._eta = eta; this._D = D; @@ -118,31 +110,30 @@ public final class OnlineLDAModel { this._kappa = kappa; this._delta = delta; + this._isAutoD = (_D < 0L); + // initialize a random number generator this._gd = new GammaDistribution(SHAPE, SCALE); _gd.reseedRandomGenerator(1001); // 
initialize the parameters this._lambda = new HashMap<String, float[]>(100); - - this._miniBatchDocs = new ArrayList<Map<String, Float>>(); } - /** - * In a truly online setting, total number of documents corresponds to the number of documents that have ever seen. In that case, users need to - * manually set the current max number of documents via this method. Note that, since the same set of documents could be repeatedly passed to - * `train()`, simply accumulating `_miniBatchSize`s as estimated `_D` is not sufficient. - */ - public void setNumTotalDocs(@Nonnegative long D) { - this._D = D; + @Override + public void accumulateDocCount() { + /* + * In a truly online setting, total number of documents equals to the number of documents that have ever seen. + * In that case, users need to manually set the current max number of documents via this method. + * Note that, since the same set of documents could be repeatedly passed to `train()`, + * simply accumulating `_miniBatchSize`s as estimated `_D` is not sufficient. 
+ */ + if (_isAutoD) { + this._D += 1; + } } public void train(@Nonnull final String[][] miniBatch) { - if (_D <= 0L) { - throw new IllegalStateException( - "Total number of documents MUST be set via `setNumTotalDocs()`"); - } - preprocessMiniBatch(miniBatch); initParams(true); @@ -175,35 +166,6 @@ public final class OnlineLDAModel { this._docRatio = (float) ((double) _D / _miniBatchSize); } - private static void initMiniBatch(@Nonnull final String[][] miniBatch, - @Nonnull final List<Map<String, Float>> docs) { - docs.clear(); - - final FeatureValue probe = new FeatureValue(); - - // parse document - for (final String[] e : miniBatch) { - if (e == null || e.length == 0) { - continue; - } - - final Map<String, Float> doc = new HashMap<String, Float>(); - - // parse features - for (String fv : e) { - if (fv == null) { - continue; - } - FeatureValue.parseFeatureAsString(fv, probe); - String label = probe.getFeatureAsString(); - float value = probe.getValueAsFloat(); - doc.put(label, Float.valueOf(value)); - } - - docs.add(doc); - } - } - private void initParams(final boolean gammaWithRandom) { final List<Map<String, float[]>> phi = new ArrayList<Map<String, float[]>>(); final float[][] gamma = new float[_miniBatchSize][]; @@ -475,7 +437,7 @@ public final class OnlineLDAModel { } @VisibleForTesting - double getLambda(@Nonnull final String label, @Nonnegative final int k) { + float getWordScore(@Nonnull final String label, @Nonnegative final int k) { final float[] lambda_label = _lambda.get(label); if (lambda_label == null) { throw new IllegalArgumentException("Word `" + label + "` is not in the corpus."); @@ -487,7 +449,7 @@ public final class OnlineLDAModel { return lambda_label[k]; } - public void setLambda(@Nonnull final String label, @Nonnegative final int k, + public void setWordScore(@Nonnull final String label, @Nonnegative final int k, final float lambda_k) { float[] lambda_label = _lambda.get(label); if (lambda_label == null) { 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java index ff29236..7702945 100644 --- a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java +++ b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java @@ -474,7 +474,7 @@ public final class PLSAPredictUDAF extends AbstractGenericUDAFResolver { for (int k = 0; k < topics; k++) { final float prob_k = prob_word.get(k).floatValue(); if (prob_k != -1.f) { - model.setProbability(word, k, prob_k); + model.setWordScore(word, k, prob_k); } } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java index 46f731f..e1d8797 100644 --- a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java +++ b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java @@ -18,27 +18,7 @@ */ package hivemall.topicmodel; -import hivemall.UDTFWithOptions; -import hivemall.annotations.VisibleForTesting; -import hivemall.utils.hadoop.HiveUtils; -import hivemall.utils.io.FileUtils; -import hivemall.utils.io.NIOUtils; -import hivemall.utils.io.NioStatefullSegment; -import hivemall.utils.lang.NumberUtils; import hivemall.utils.lang.Primitives; -import hivemall.utils.lang.SizeOf; - -import java.io.File; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.SortedMap; - -import javax.annotation.Nonnegative; -import javax.annotation.Nonnull; import org.apache.commons.cli.CommandLine; import 
org.apache.commons.cli.Options; @@ -46,503 +26,49 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; -import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; -import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; -import org.apache.hadoop.io.FloatWritable; -import org.apache.hadoop.io.IntWritable; -import org.apache.hadoop.io.Text; -import org.apache.hadoop.mapred.Counters; -import org.apache.hadoop.mapred.Reporter; @Description(name = "train_plsa", value = "_FUNC_(array<string> words[, const string options])" + " - Returns a relation consists of <int topic, string word, float score>") -public class PLSAUDTF extends UDTFWithOptions { +public class PLSAUDTF extends ProbabilisticTopicModelBaseUDTF { private static final Log logger = LogFactory.getLog(PLSAUDTF.class); - public static final int DEFAULT_TOPICS = 10; public static final float DEFAULT_ALPHA = 0.5f; public static final double DEFAULT_DELTA = 1E-3d; // Options - protected int topics; protected float alpha; - protected int iterations; protected double delta; - protected double eps; - protected int miniBatchSize; - - // number of proceeded training samples - protected long count; - - protected String[][] miniBatch; - protected int miniBatchCount; - - protected transient IncrementalPLSAModel model; - - protected ListObjectInspector wordCountsOI; - - // for iterations - protected NioStatefullSegment fileIO; - protected ByteBuffer inputBuf; public PLSAUDTF() { - this.topics = DEFAULT_TOPICS; + super(); + this.alpha = 
DEFAULT_ALPHA; - this.iterations = 10; this.delta = DEFAULT_DELTA; - this.eps = 1E-1d; - this.miniBatchSize = 128; } @Override protected Options getOptions() { - Options opts = new Options(); - opts.addOption("k", "topics", true, "The number of topics [default: 10]"); + Options opts = super.getOptions(); opts.addOption("alpha", true, "The hyperparameter for P(w|z) update [default: 0.5]"); - opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 10]"); opts.addOption("delta", true, "Check convergence in the expectation step [default: 1E-3]"); - opts.addOption("eps", "epsilon", true, - "Check convergence based on the difference of perplexity [default: 1E-1]"); - opts.addOption("s", "mini_batch_size", true, - "Repeat model updating per mini-batch [default: 128]"); return opts; } @Override protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { - CommandLine cl = null; + CommandLine cl = super.processOptions(argOIs); - if (argOIs.length >= 2) { - String rawArgs = HiveUtils.getConstString(argOIs[1]); - cl = parseOptions(rawArgs); - this.topics = Primitives.parseInt(cl.getOptionValue("topics"), DEFAULT_TOPICS); + if (cl != null) { this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), DEFAULT_ALPHA); - this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 10); - if (iterations < 1) { - throw new UDFArgumentException( - "'-iterations' must be greater than or equals to 1: " + iterations); - } this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), DEFAULT_DELTA); - this.eps = Primitives.parseDouble(cl.getOptionValue("epsilon"), 1E-1d); - this.miniBatchSize = Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 128); } return cl; } - @Override - public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { - if (argOIs.length < 1) { - throw new UDFArgumentException( - "_FUNC_ takes 1 arguments: array<string> words [, const string 
options]"); - } - - this.wordCountsOI = HiveUtils.asListOI(argOIs[0]); - HiveUtils.validateFeatureOI(wordCountsOI.getListElementObjectInspector()); - - processOptions(argOIs); - - this.model = null; - this.count = 0L; - this.miniBatch = new String[miniBatchSize][]; - this.miniBatchCount = 0; - - ArrayList<String> fieldNames = new ArrayList<String>(); - ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); - fieldNames.add("topic"); - fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); - fieldNames.add("word"); - fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); - fieldNames.add("score"); - fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); - - return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); - } - - protected void initModel() { - this.model = new IncrementalPLSAModel(topics, alpha, delta); - } - - @Override - public void process(Object[] args) throws HiveException { - if (model == null) { - initModel(); - } - - int length = wordCountsOI.getListLength(args[0]); - String[] wordCounts = new String[length]; - int j = 0; - for (int i = 0; i < length; i++) { - Object o = wordCountsOI.getListElement(args[0], i); - if (o == null) { - throw new HiveException("Given feature vector contains invalid elements"); - } - String s = o.toString(); - wordCounts[j] = s; - j++; - } - if (j == 0) {// avoid empty documents - return; - } - - count++; - - recordTrainSampleToTempFile(wordCounts); - - miniBatch[miniBatchCount] = wordCounts; - miniBatchCount++; - - if (miniBatchCount == miniBatchSize) { - model.train(miniBatch); - Arrays.fill(miniBatch, null); // clear - miniBatchCount = 0; - } - } - - protected void recordTrainSampleToTempFile(@Nonnull final String[] wordCounts) - throws HiveException { - if (iterations == 1) { - return; - } - - ByteBuffer buf = inputBuf; - NioStatefullSegment dst = fileIO; - - if (buf == null) { - final File file; - try { - file = 
File.createTempFile("hivemall_plsa", ".sgmt"); - file.deleteOnExit(); - if (!file.canWrite()) { - throw new UDFArgumentException("Cannot write a temporary file: " - + file.getAbsolutePath()); - } - logger.info("Record training samples to a file: " + file.getAbsolutePath()); - } catch (IOException ioe) { - throw new UDFArgumentException(ioe); - } catch (Throwable e) { - throw new UDFArgumentException(e); - } - this.inputBuf = buf = ByteBuffer.allocateDirect(1024 * 1024); // 1 MB - this.fileIO = dst = new NioStatefullSegment(file, false); - } - - // requiredRecordBytes, wordCounts length, wc1 length, wc1 string, wc2 length, wc2 string, ... - int wcLengthTotal = 0; - for (String wc : wordCounts) { - if (wc == null) { - continue; - } - wcLengthTotal += wc.length(); - } - int requiredRecordBytes = SizeOf.INT * 2 + SizeOf.INT * wordCounts.length + wcLengthTotal - * SizeOf.CHAR; - - int remain = buf.remaining(); - if (remain < requiredRecordBytes) { - writeBuffer(buf, dst); - } - - buf.putInt(requiredRecordBytes); - buf.putInt(wordCounts.length); - for (String wc : wordCounts) { - NIOUtils.putString(wc, buf); - } - } - - private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefullSegment dst) - throws HiveException { - srcBuf.flip(); - try { - dst.write(srcBuf); - } catch (IOException e) { - throw new HiveException("Exception causes while writing a buffer to file", e); - } - srcBuf.clear(); + protected AbstractProbabilisticTopicModel createModel() { + return new IncrementalPLSAModel(topics, alpha, delta); } - @Override - public void close() throws HiveException { - if (count == 0) { - this.model = null; - return; - } - if (miniBatchCount > 0) { // update for remaining samples - model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount)); - } - if (iterations > 1) { - runIterativeTraining(iterations); - } - forwardModel(); - this.model = null; - } - - protected final void runIterativeTraining(@Nonnegative final int iterations) - throws HiveException 
{ - final ByteBuffer buf = this.inputBuf; - final NioStatefullSegment dst = this.fileIO; - assert (buf != null); - assert (dst != null); - final long numTrainingExamples = count; - - final Reporter reporter = getReporter(); - final Counters.Counter iterCounter = (reporter == null) ? null : reporter.getCounter( - "hivemall.plsa.IncrementalPLSA$Counter", "iteration"); - - try { - if (dst.getPosition() == 0L) {// run iterations w/o temporary file - if (buf.position() == 0) { - return; // no training example - } - buf.flip(); - - int iter = 2; - float perplexityPrev = Float.MAX_VALUE; - float perplexity; - int numTrain; - for (; iter <= iterations; iter++) { - perplexity = 0.f; - numTrain = 0; - - reportProgress(reporter); - setCounterValue(iterCounter, iter); - - Arrays.fill(miniBatch, null); // clear - miniBatchCount = 0; - - while (buf.remaining() > 0) { - int recordBytes = buf.getInt(); - assert (recordBytes > 0) : recordBytes; - int wcLength = buf.getInt(); - final String[] wordCounts = new String[wcLength]; - for (int j = 0; j < wcLength; j++) { - wordCounts[j] = NIOUtils.getString(buf); - } - - miniBatch[miniBatchCount] = wordCounts; - miniBatchCount++; - - if (miniBatchCount == miniBatchSize) { - model.train(miniBatch); - perplexity += model.computePerplexity(); - numTrain++; - - Arrays.fill(miniBatch, null); // clear - miniBatchCount = 0; - } - } - buf.rewind(); - - // update for remaining samples - if (miniBatchCount > 0) { // update for remaining samples - model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount)); - perplexity += model.computePerplexity(); - numTrain++; - } - - logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain); - perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches - if (Math.abs(perplexityPrev - perplexity) < eps) { - break; - } - perplexityPrev = perplexity; - } - logger.info("Performed " - + Math.min(iter, iterations) - + " iterations of " - + NumberUtils.formatNumber(numTrainingExamples) - + " 
training examples on memory (thus " - + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations)) - + " training updates in total) "); - } else {// read training examples in the temporary file and invoke train for each example - - // write training examples in buffer to a temporary file - if (buf.remaining() > 0) { - writeBuffer(buf, dst); - } - try { - dst.flush(); - } catch (IOException e) { - throw new HiveException("Failed to flush a file: " - + dst.getFile().getAbsolutePath(), e); - } - if (logger.isInfoEnabled()) { - File tmpFile = dst.getFile(); - logger.info("Wrote " + numTrainingExamples - + " records to a temporary file for iterative training: " - + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile) - + ")"); - } - - // run iterations - int iter = 2; - float perplexityPrev = Float.MAX_VALUE; - float perplexity; - int numTrain; - for (; iter <= iterations; iter++) { - perplexity = 0.f; - numTrain = 0; - - Arrays.fill(miniBatch, null); // clear - miniBatchCount = 0; - - setCounterValue(iterCounter, iter); - - buf.clear(); - dst.resetPosition(); - while (true) { - reportProgress(reporter); - // TODO prefetch - // writes training examples to a buffer in the temporary file - final int bytesRead; - try { - bytesRead = dst.read(buf); - } catch (IOException e) { - throw new HiveException("Failed to read a file: " - + dst.getFile().getAbsolutePath(), e); - } - if (bytesRead == 0) { // reached file EOF - break; - } - assert (bytesRead > 0) : bytesRead; - - // reads training examples from a buffer - buf.flip(); - int remain = buf.remaining(); - if (remain < SizeOf.INT) { - throw new HiveException("Illegal file format was detected"); - } - while (remain >= SizeOf.INT) { - int pos = buf.position(); - int recordBytes = buf.getInt() - SizeOf.INT; - remain -= SizeOf.INT; - if (remain < recordBytes) { - buf.position(pos); - break; - } - - int wcLength = buf.getInt(); - final String[] wordCounts = new String[wcLength]; - for (int j = 0; j 
< wcLength; j++) { - wordCounts[j] = NIOUtils.getString(buf); - } - - miniBatch[miniBatchCount] = wordCounts; - miniBatchCount++; - - if (miniBatchCount == miniBatchSize) { - model.train(miniBatch); - perplexity += model.computePerplexity(); - numTrain++; - - Arrays.fill(miniBatch, null); // clear - miniBatchCount = 0; - } - - remain -= recordBytes; - } - buf.compact(); - } - - // update for remaining samples - if (miniBatchCount > 0) { // update for remaining samples - model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount)); - perplexity += model.computePerplexity(); - numTrain++; - } - - logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain); - perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches - if (Math.abs(perplexityPrev - perplexity) < eps) { - break; - } - perplexityPrev = perplexity; - } - logger.info("Performed " - + Math.min(iter, iterations) - + " iterations of " - + NumberUtils.formatNumber(numTrainingExamples) - + " training examples on a secondary storage (thus " - + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations)) - + " training updates in total)"); - } - } catch (Throwable e) { - throw new HiveException("Exception caused in the iterative training", e); - } finally { - // delete the temporary file and release resources - try { - dst.close(true); - } catch (IOException e) { - throw new HiveException("Failed to close a file: " - + dst.getFile().getAbsolutePath(), e); - } - this.inputBuf = null; - this.fileIO = null; - } - } - - protected void forwardModel() throws HiveException { - final IntWritable topicIdx = new IntWritable(); - final Text word = new Text(); - final FloatWritable score = new FloatWritable(); - - final Object[] forwardObjs = new Object[3]; - forwardObjs[0] = topicIdx; - forwardObjs[1] = word; - forwardObjs[2] = score; - - for (int k = 0; k < topics; k++) { - topicIdx.set(k); - - final SortedMap<Float, List<String>> topicWords = model.getTopicWords(k); - for 
(Map.Entry<Float, List<String>> e : topicWords.entrySet()) { - score.set(e.getKey()); - List<String> words = e.getValue(); - for (int i = 0; i < words.size(); i++) { - word.set(words.get(i)); - forward(forwardObjs); - } - } - } - - logger.info("Forwarded topic words each of " + topics + " topics"); - } - - /* - * For testing: - */ - - @VisibleForTesting - public void closeWithoutModelReset() throws HiveException { - // launch close(), but not forward & clear model - if (count == 0) { - this.model = null; - return; - } - if (miniBatchCount > 0) { // update for remaining samples - model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount)); - } - if (iterations > 1) { - runIterativeTraining(iterations); - } - } - - @VisibleForTesting - double getProbability(String label, int k) { - return model.getProbability(label, k); - } - - @VisibleForTesting - SortedMap<Float, List<String>> getTopicWords(int k) { - return model.getTopicWords(k); - } - - @VisibleForTesting - float[] getTopicDistribution(@Nonnull String[] doc) { - return model.getTopicDistribution(doc); - } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java b/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java new file mode 100644 index 0000000..cff076e --- /dev/null +++ b/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java @@ -0,0 +1,487 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.topicmodel; + +import hivemall.UDTFWithOptions; +import hivemall.annotations.VisibleForTesting; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.io.FileUtils; +import hivemall.utils.io.NIOUtils; +import hivemall.utils.io.NioStatefullSegment; +import hivemall.utils.lang.NumberUtils; +import hivemall.utils.lang.Primitives; +import hivemall.utils.lang.SizeOf; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapred.Counters; +import org.apache.hadoop.mapred.Reporter; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.io.File; +import 
java.io.IOException; +import java.nio.ByteBuffer; +import java.util.*; + +public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions { + private static final Log logger = LogFactory.getLog(ProbabilisticTopicModelBaseUDTF.class); + + public static final int DEFAULT_TOPICS = 10; + + // Options + protected int topics; + protected int iterations; + protected double eps; + protected int miniBatchSize; + + protected String[][] miniBatch; + protected int miniBatchCount; + + protected transient AbstractProbabilisticTopicModel model; + + protected ListObjectInspector wordCountsOI; + + // for iterations + protected NioStatefullSegment fileIO; + protected ByteBuffer inputBuf; + + private float cumPerplexity; + + public ProbabilisticTopicModelBaseUDTF() { + this.topics = DEFAULT_TOPICS; + this.iterations = 10; + this.eps = 1E-1d; + this.miniBatchSize = 128; // if 1, truly online setting + } + + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("k", "topics", true, "The number of topics [default: 10]"); + opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 10]"); + opts.addOption("eps", "epsilon", true, + "Check convergence based on the difference of perplexity [default: 1E-1]"); + opts.addOption("s", "mini_batch_size", true, + "Repeat model updating per mini-batch [default: 128]"); + return opts; + } + + @Override + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + CommandLine cl = null; + + if (argOIs.length >= 2) { + String rawArgs = HiveUtils.getConstString(argOIs[1]); + cl = parseOptions(rawArgs); + this.topics = Primitives.parseInt(cl.getOptionValue("topics"), DEFAULT_TOPICS); + this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 10); + if (iterations < 1) { + throw new UDFArgumentException( + "'-iterations' must be greater than or equals to 1: " + iterations); + } + this.eps = 
Primitives.parseDouble(cl.getOptionValue("epsilon"), 1E-1d); + this.miniBatchSize = Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 128); + } + + return cl; + } + + @Override + public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + if (argOIs.length < 1) { + throw new UDFArgumentException( + "_FUNC_ takes 1 arguments: array<string> words [, const string options]"); + } + + this.wordCountsOI = HiveUtils.asListOI(argOIs[0]); + HiveUtils.validateFeatureOI(wordCountsOI.getListElementObjectInspector()); + + processOptions(argOIs); + + this.model = null; + this.miniBatch = new String[miniBatchSize][]; + this.miniBatchCount = 0; + this.cumPerplexity = 0.f; + + ArrayList<String> fieldNames = new ArrayList<String>(); + ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + fieldNames.add("topic"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("word"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); + fieldNames.add("score"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + protected abstract AbstractProbabilisticTopicModel createModel(); + + @Override + public void process(Object[] args) throws HiveException { + if (model == null) { + this.model = createModel(); + } + + final int length = wordCountsOI.getListLength(args[0]); + final String[] wordCounts = new String[length]; + int j = 0; + for (int i = 0; i < length; i++) { + Object o = wordCountsOI.getListElement(args[0], i); + if (o == null) { + throw new HiveException("Given feature vector contains invalid elements"); + } + String s = o.toString(); + wordCounts[j] = s; + j++; + } + if (j == 0) {// avoid empty documents + return; + } + + model.accumulateDocCount();; + + update(wordCounts); + + recordTrainSampleToTempFile(wordCounts); + } + + protected 
void recordTrainSampleToTempFile(@Nonnull final String[] wordCounts) + throws HiveException { + if (iterations == 1) { + return; + } + + ByteBuffer buf = inputBuf; + NioStatefullSegment dst = fileIO; + + if (buf == null) { + final File file; + try { + file = File.createTempFile("hivemall_topicmodel", ".sgmt"); + file.deleteOnExit(); + if (!file.canWrite()) { + throw new UDFArgumentException("Cannot write a temporary file: " + + file.getAbsolutePath()); + } + logger.info("Record training samples to a file: " + file.getAbsolutePath()); + } catch (IOException ioe) { + throw new UDFArgumentException(ioe); + } catch (Throwable e) { + throw new UDFArgumentException(e); + } + this.inputBuf = buf = ByteBuffer.allocateDirect(1024 * 1024); // 1 MB + this.fileIO = dst = new NioStatefullSegment(file, false); + } + + // wordCounts length, wc1 length, wc1 string, wc2 length, wc2 string, ... + int wcLengthTotal = 0; + for (String wc : wordCounts) { + if (wc == null) { + continue; + } + wcLengthTotal += wc.length(); + } + int recordBytes = SizeOf.INT + SizeOf.INT * wordCounts.length + wcLengthTotal * SizeOf.CHAR; + int requiredBytes = SizeOf.INT + recordBytes; // need to allocate space for "recordBytes" itself + + int remain = buf.remaining(); + if (remain < requiredBytes) { + writeBuffer(buf, dst); + } + + buf.putInt(recordBytes); + buf.putInt(wordCounts.length); + for (String wc : wordCounts) { + NIOUtils.putString(wc, buf); + } + } + + private void update(@Nonnull final String[] wordCounts) { + miniBatch[miniBatchCount] = wordCounts; + miniBatchCount++; + + if (miniBatchCount == miniBatchSize) { + train(); + } + } + + protected void train() { + if (miniBatchCount == 0) { + return; + } + + model.train(miniBatch); + + this.cumPerplexity += model.computePerplexity(); + + Arrays.fill(miniBatch, null); // clear + miniBatchCount = 0; + } + + private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefullSegment dst) + throws HiveException { + srcBuf.flip(); + try { 
+ dst.write(srcBuf); + } catch (IOException e) { + throw new HiveException("Exception causes while writing a buffer to file", e); + } + srcBuf.clear(); + } + + @Override + public void close() throws HiveException { + finalizeTraining(); + forwardModel(); + this.model = null; + } + + @VisibleForTesting + void finalizeTraining() throws HiveException { + if (model.getDocCount() == 0L) { + this.model = null; + return; + } + if (miniBatchCount > 0) { // update for remaining samples + model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount)); + } + if (iterations > 1) { + runIterativeTraining(iterations); + } + } + + protected final void runIterativeTraining(@Nonnegative final int iterations) + throws HiveException { + final ByteBuffer buf = this.inputBuf; + final NioStatefullSegment dst = this.fileIO; + assert (buf != null); + assert (dst != null); + final long numTrainingExamples = model.getDocCount(); + + long numTrain = numTrainingExamples / miniBatchSize; + if (numTrainingExamples % miniBatchSize != 0L) { + numTrain++; + } + + final Reporter reporter = getReporter(); + final Counters.Counter iterCounter = (reporter == null) ? 
null : reporter.getCounter( + "hivemall.topicmodel.ProbabilisticTopicModel$Counter", "iteration"); + + try { + if (dst.getPosition() == 0L) {// run iterations w/o temporary file + if (buf.position() == 0) { + return; // no training example + } + buf.flip(); + + int iter = 2; + float perplexity = cumPerplexity / numTrain; + float perplexityPrev; + for (; iter <= iterations; iter++) { + perplexityPrev = perplexity; + cumPerplexity = 0.f; + + reportProgress(reporter); + setCounterValue(iterCounter, iter); + + while (buf.remaining() > 0) { + int recordBytes = buf.getInt(); + assert (recordBytes > 0) : recordBytes; + int wcLength = buf.getInt(); + final String[] wordCounts = new String[wcLength]; + for (int j = 0; j < wcLength; j++) { + wordCounts[j] = NIOUtils.getString(buf); + } + update(wordCounts); + } + buf.rewind(); + + // mean perplexity over `numTrain` mini-batches + perplexity = cumPerplexity / numTrain; + logger.info("Mean perplexity over mini-batches: " + perplexity); + if (Math.abs(perplexityPrev - perplexity) < eps) { + break; + } + } + logger.info("Performed " + + Math.min(iter, iterations) + + " iterations of " + + NumberUtils.formatNumber(numTrainingExamples) + + " training examples on memory (thus " + + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations)) + + " training updates in total) "); + } else {// read training examples in the temporary file and invoke train for each example + // write training examples in buffer to a temporary file + if (buf.remaining() > 0) { + writeBuffer(buf, dst); + } + try { + dst.flush(); + } catch (IOException e) { + throw new HiveException("Failed to flush a file: " + + dst.getFile().getAbsolutePath(), e); + } + if (logger.isInfoEnabled()) { + File tmpFile = dst.getFile(); + logger.info("Wrote " + numTrainingExamples + + " records to a temporary file for iterative training: " + + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile) + + ")"); + } + + // run iterations + int iter = 2; 
+ float perplexity = cumPerplexity / numTrain; + float perplexityPrev; + for (; iter <= iterations; iter++) { + perplexityPrev = perplexity; + cumPerplexity = 0.f; + + setCounterValue(iterCounter, iter); + + buf.clear(); + dst.resetPosition(); + while (true) { + reportProgress(reporter); + // TODO prefetch + // writes training examples to a buffer in the temporary file + final int bytesRead; + try { + bytesRead = dst.read(buf); + } catch (IOException e) { + throw new HiveException("Failed to read a file: " + + dst.getFile().getAbsolutePath(), e); + } + if (bytesRead == 0) { // reached file EOF + break; + } + assert (bytesRead > 0) : bytesRead; + + // reads training examples from a buffer + buf.flip(); + int remain = buf.remaining(); + if (remain < SizeOf.INT) { + throw new HiveException("Illegal file format was detected"); + } + while (remain >= SizeOf.INT) { + int pos = buf.position(); + int recordBytes = buf.getInt(); + remain -= SizeOf.INT; + if (remain < recordBytes) { + buf.position(pos); + break; + } + + int wcLength = buf.getInt(); + final String[] wordCounts = new String[wcLength]; + for (int j = 0; j < wcLength; j++) { + wordCounts[j] = NIOUtils.getString(buf); + } + update(wordCounts); + + remain -= recordBytes; + } + buf.compact(); + } + + // mean perplexity over `numTrain` mini-batches + perplexity = cumPerplexity / numTrain; + logger.info("Mean perplexity over mini-batches: " + perplexity); + if (Math.abs(perplexityPrev - perplexity) < eps) { + break; + } + } + logger.info("Performed " + + Math.min(iter, iterations) + + " iterations of " + + NumberUtils.formatNumber(numTrainingExamples) + + " training examples on a secondary storage (thus " + + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations)) + + " training updates in total)"); + } + } catch (Throwable e) { + throw new HiveException("Exception caused in the iterative training", e); + } finally { + // delete the temporary file and release resources + try { + dst.close(true); + 
} catch (IOException e) { + throw new HiveException("Failed to close a file: " + + dst.getFile().getAbsolutePath(), e); + } + this.inputBuf = null; + this.fileIO = null; + } + } + + protected void forwardModel() throws HiveException { + final IntWritable topicIdx = new IntWritable(); + final Text word = new Text(); + final FloatWritable score = new FloatWritable(); + + final Object[] forwardObjs = new Object[3]; + forwardObjs[0] = topicIdx; + forwardObjs[1] = word; + forwardObjs[2] = score; + + for (int k = 0; k < topics; k++) { + topicIdx.set(k); + + final SortedMap<Float, List<String>> topicWords = model.getTopicWords(k); + for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) { + score.set(e.getKey()); + List<String> words = e.getValue(); + for (int i = 0; i < words.size(); i++) { + word.set(words.get(i)); + forward(forwardObjs); + } + } + } + + logger.info("Forwarded topic words each of " + topics + " topics"); + } + + @VisibleForTesting + float getWordScore(String label, int k) { + return model.getWordScore(label, k); + } + + @VisibleForTesting + SortedMap<Float, List<String>> getTopicWords(int k) { + return model.getTopicWords(k); + } + + @VisibleForTesting + float[] getTopicDistribution(@Nonnull String[] doc) { + return model.getTopicDistribution(doc); + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java b/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java index 79be3a7..96bbe64 100644 --- a/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java +++ b/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java @@ -110,10 +110,10 @@ public class IncrementalPLSAModelTest { } Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), " + "and 
`vegetables` SHOULD be more suitable topic word than `flu` in the topic", - model.getProbability("vegetables", k1) > model.getProbability("flu", k1)); + model.getWordScore("vegetables", k1) > model.getWordScore("flu", k1)); Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), " + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic", - model.getProbability("avocados", k2) > model.getProbability("healthy", k2)); + model.getWordScore("avocados", k2) > model.getWordScore("healthy", k2)); } @Test @@ -177,10 +177,10 @@ public class IncrementalPLSAModelTest { } Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), " + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic", - model.getProbability("vegetables", k1) > model.getProbability("flu", k1)); + model.getWordScore("vegetables", k1) > model.getWordScore("flu", k1)); Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), " + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic", - model.getProbability("avocados", k2) > model.getProbability("healthy", k2)); + model.getWordScore("avocados", k2) > model.getWordScore("healthy", k2)); } @Test http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java index a934ba3..4cbb668 100644 --- a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java +++ b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java @@ -53,7 +53,7 @@ public class LDAUDTFTest { udtf.process(new Object[] {Arrays.asList(doc1)}); udtf.process(new Object[] {Arrays.asList(doc2)}); - udtf.closeWithoutModelReset(); + udtf.finalizeTraining(); SortedMap<Float, List<String>> topicWords; @@ -92,10 
+92,10 @@ public class LDAUDTFTest { Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), " + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic", - udtf.getLambda("vegetables", k1) > udtf.getLambda("flu", k1)); + udtf.getWordScore("vegetables", k1) > udtf.getWordScore("flu", k1)); Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), " + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic", - udtf.getLambda("avocados", k2) > udtf.getLambda("healthy", k2)); + udtf.getWordScore("avocados", k2) > udtf.getWordScore("healthy", k2)); } @Test @@ -106,7 +106,7 @@ public class LDAUDTFTest { ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector), ObjectInspectorUtils.getConstantObjectInspector( PrimitiveObjectInspectorFactory.javaStringObjectInspector, - "-topics 2 -num_docs 2 -s 1 -iter 32 -eps 1e-3")}; + "-topics 2 -num_docs 2 -s 1 -iter 32 -eps 1e-3 -mini_batch_size 1")}; udtf.initialize(argOIs); @@ -116,7 +116,7 @@ public class LDAUDTFTest { udtf.process(new Object[] {Arrays.asList(doc1)}); udtf.process(new Object[] {Arrays.asList(doc2)}); - udtf.closeWithoutModelReset(); + udtf.finalizeTraining(); SortedMap<Float, List<String>> topicWords; @@ -155,10 +155,10 @@ public class LDAUDTFTest { Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), " + "and `野菜` SHOULD be more suitable topic word than `インフルエンザ` in the topic", - udtf.getLambda("野菜", k1) > udtf.getLambda("インフルエンザ", k1)); + udtf.getWordScore("野菜", k1) > udtf.getWordScore("インフルエンザ", k1)); Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), " + "and `アボカド` SHOULD be more suitable topic word than `健康` in the topic", - udtf.getLambda("アボカド", k2) > udtf.getLambda("健康", k2)); + udtf.getWordScore("アボカド", k2) > udtf.getWordScore("健康", k2)); } private static void
println(String msg) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java b/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java index 5b0a8c2..68f251a 100644 --- a/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java +++ b/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java @@ -108,10 +108,10 @@ public class OnlineLDAModelTest { } Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), " + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic", - model.getLambda("vegetables", k1) > model.getLambda("flu", k1)); + model.getWordScore("vegetables", k1) > model.getWordScore("flu", k1)); Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), " + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic", - model.getLambda("avocados", k2) > model.getLambda("healthy", k2)); + model.getWordScore("avocados", k2) > model.getWordScore("healthy", k2)); } @Test
