Applied refactoring for topicmodel module

Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/9f01ebf2
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/9f01ebf2
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/9f01ebf2
Branch: refs/heads/master Commit: 9f01ebf20c74559be8a50d459103118a51c229bf Parents: 0495ffa Author: Makoto Yui <[email protected]> Authored: Tue Jun 27 15:44:31 2017 +0900 Committer: Makoto Yui <[email protected]> Committed: Tue Jun 27 15:44:31 2017 +0900 ---------------------------------------------------------------------- .../AbstractProbabilisticTopicModel.java | 26 +++++++++++++------- .../topicmodel/IncrementalPLSAModel.java | 16 ++++++------ .../main/java/hivemall/topicmodel/LDAUDTF.java | 5 +--- .../hivemall/topicmodel/OnlineLDAModel.java | 18 +++++++------- .../main/java/hivemall/topicmodel/PLSAUDTF.java | 5 +--- .../ProbabilisticTopicModelBaseUDTF.java | 25 ++++++++++++------- 6 files changed, 52 insertions(+), 43 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java b/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java index 3c097e2..1b7f3e8 100644 --- a/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java +++ b/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java @@ -21,9 +21,14 @@ package hivemall.topicmodel; import hivemall.annotations.VisibleForTesting; import hivemall.model.FeatureValue; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.SortedMap; + import javax.annotation.Nonnegative; import javax.annotation.Nonnull; -import java.util.*; public abstract class AbstractProbabilisticTopicModel { @@ -31,6 +36,7 @@ public abstract class AbstractProbabilisticTopicModel { protected final int _K; // total number of documents + @Nonnegative protected long _D; // for mini-batch @@ -38,7 
+44,7 @@ public abstract class AbstractProbabilisticTopicModel { protected final List<Map<String, Float>> _miniBatchDocs; protected int _miniBatchSize; - public AbstractProbabilisticTopicModel(int K) { + public AbstractProbabilisticTopicModel(@Nonnegative int K) { this._K = K; this._D = 0L; this._miniBatchDocs = new ArrayList<Map<String, Float>>(); @@ -73,26 +79,28 @@ public abstract class AbstractProbabilisticTopicModel { } } - public void accumulateDocCount() { + protected void accumulateDocCount() { this._D += 1; } - public long getDocCount() { + @Nonnegative + protected long getDocCount() { return _D; } - public abstract void train(@Nonnull final String[][] miniBatch); + protected abstract void train(@Nonnull final String[][] miniBatch); - public abstract float computePerplexity(); + protected abstract float computePerplexity(); @Nonnull - public abstract SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k); + protected abstract SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k); @Nonnull - public abstract float[] getTopicDistribution(@Nonnull final String[] doc); + protected abstract float[] getTopicDistribution(@Nonnull final String[] doc); @VisibleForTesting abstract float getWordScore(@Nonnull final String word, @Nonnegative final int topic); - public abstract void setWordScore(@Nonnull final String word, @Nonnegative final int topic, final float score); + protected abstract void setWordScore(@Nonnull final String word, @Nonnegative final int topic, + final float score); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java index b99e670..6419664 100644 --- a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java 
+++ b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java @@ -20,7 +20,6 @@ package hivemall.topicmodel; import static hivemall.utils.lang.ArrayUtils.newRandomFloatArray; import static hivemall.utils.math.MathUtils.l1normalize; - import hivemall.annotations.VisibleForTesting; import hivemall.math.random.PRNG; import hivemall.math.random.RandomNumberGeneratorFactory; @@ -59,9 +58,10 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel private List<Map<String, float[]>> _p_dwz; // P(z|d,w) probability of topics for each document-word (i.e., instance-feature) pair // optimized in the M step - @Nonnull private List<float[]> _p_dz; // P(z|d) probability of topics for documents - private Map<String, float[]> _p_zw; // P(w|z) probability of words for each topic + + @Nonnull + private final Map<String, float[]> _p_zw; // P(w|z) probability of words for each topic public IncrementalPLSAModel(int K, float alpha, double delta) { super(K); @@ -74,7 +74,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel this._p_zw = new HashMap<String, float[]>(); } - public void train(@Nonnull final String[][] miniBatch) { + protected void train(@Nonnull final String[][] miniBatch) { initMiniBatch(miniBatch, _miniBatchDocs); this._miniBatchSize = _miniBatchDocs.size(); @@ -211,7 +211,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel return (diff / _K) < _delta; } - public float computePerplexity() { + protected float computePerplexity() { double numer = 0.d; double denom = 0.d; @@ -241,7 +241,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel } @Nonnull - public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int z) { + protected SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int z) { final SortedMap<Float, List<String>> res = new TreeMap<Float, List<String>>( Collections.reverseOrder()); @@ -261,7 +261,7 @@ public 
final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel } @Nonnull - public float[] getTopicDistribution(@Nonnull final String[] doc) { + protected float[] getTopicDistribution(@Nonnull final String[] doc) { train(new String[][] {doc}); return _p_dz.get(0); } @@ -271,7 +271,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel return _p_zw.get(w)[z]; } - public void setWordScore(@Nonnull final String w, @Nonnegative final int z, final float prob) { + protected void setWordScore(@Nonnull final String w, @Nonnegative final int z, final float prob) { float[] prob_label = _p_zw.get(w); if (prob_label == null) { prob_label = newRandomFloatArray(_K, _rnd); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/LDAUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java index 41386a4..9bac908 100644 --- a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java +++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java @@ -22,16 +22,13 @@ import hivemall.utils.lang.Primitives; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; @Description(name = "train_lda", value = "_FUNC_(array<string> words[, const string options])" + " - Returns a relation consists of <int topic, string word, float score>") -public class LDAUDTF extends ProbabilisticTopicModelBaseUDTF { - private static final Log logger = LogFactory.getLog(LDAUDTF.class); +public final class LDAUDTF extends ProbabilisticTopicModelBaseUDTF { public static final double DEFAULT_DELTA = 
1E-3d; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java index 4a7531c..6a8d6db 100644 --- a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java +++ b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java @@ -38,6 +38,9 @@ import org.apache.commons.math3.special.Gamma; public final class OnlineLDAModel extends AbstractProbabilisticTopicModel { + private static final double SHAPE = 100.d; + private static final double SCALE = 1.d / SHAPE; + // --------------------------------- // HyperParameters @@ -72,7 +75,6 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel { private final boolean _isAutoD; // parameters - @Nonnull private List<Map<String, float[]>> _phi; private float[][] _gamma; @Nonnull @@ -81,8 +83,6 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel { // random number generator @Nonnull private final GammaDistribution _gd; - private static final double SHAPE = 100.d; - private static final double SCALE = 1.d / SHAPE; // for computing perplexity private float _docRatio = 1.f; @@ -121,7 +121,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel { } @Override - public void accumulateDocCount() { + protected void accumulateDocCount() { /* * In a truly online setting, total number of documents equals to the number of documents that have ever seen. * In that case, users need to manually set the current max number of documents via this method. 
@@ -133,7 +133,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel { } } - public void train(@Nonnull final String[][] miniBatch) { + protected void train(@Nonnull final String[][] miniBatch) { preprocessMiniBatch(miniBatch); initParams(true); @@ -341,7 +341,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel { /** * Calculate approximate perplexity for the current mini-batch. */ - public float computePerplexity() { + protected float computePerplexity() { double bound = computeApproxBound(); double perWordBound = bound / (_docRatio * _valueSum); return (float) Math.exp(-1.d * perWordBound); @@ -449,7 +449,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel { return lambda_label[k]; } - public void setWordScore(@Nonnull final String label, @Nonnegative final int k, + protected void setWordScore(@Nonnull final String label, @Nonnegative final int k, final float lambda_k) { float[] lambda_label = _lambda.get(label); if (lambda_label == null) { @@ -460,7 +460,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel { } @Nonnull - public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k) { + protected SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k) { return getTopicWords(k, _lambda.keySet().size()); } @@ -501,7 +501,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel { } @Nonnull - public float[] getTopicDistribution(@Nonnull final String[] doc) { + protected float[] getTopicDistribution(@Nonnull final String[] doc) { preprocessMiniBatch(new String[][] {doc}); initParams(false); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java index 
e1d8797..9c5a0ea 100644 --- a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java +++ b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java @@ -22,16 +22,13 @@ import hivemall.utils.lang.Primitives; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; @Description(name = "train_plsa", value = "_FUNC_(array<string> words[, const string options])" + " - Returns a relation consists of <int topic, string word, float score>") -public class PLSAUDTF extends ProbabilisticTopicModelBaseUDTF { - private static final Log logger = LogFactory.getLog(PLSAUDTF.class); +public final class PLSAUDTF extends ProbabilisticTopicModelBaseUDTF { public static final float DEFAULT_ALPHA = 0.5f; public static final double DEFAULT_DELTA = 1E-3d; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java b/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java index cff076e..c3dab89 100644 --- a/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java +++ b/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java @@ -27,6 +27,19 @@ import hivemall.utils.io.NioStatefullSegment; import hivemall.utils.lang.NumberUtils; import hivemall.utils.lang.Primitives; import hivemall.utils.lang.SizeOf; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import 
java.util.SortedMap; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; import org.apache.commons.logging.Log; @@ -44,13 +57,6 @@ import org.apache.hadoop.io.Text; import org.apache.hadoop.mapred.Counters; import org.apache.hadoop.mapred.Reporter; -import javax.annotation.Nonnegative; -import javax.annotation.Nonnull; -import java.io.File; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.*; - public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions { private static final Log logger = LogFactory.getLog(ProbabilisticTopicModelBaseUDTF.class); @@ -143,6 +149,7 @@ public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions { return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); } + @Nonnull protected abstract AbstractProbabilisticTopicModel createModel(); @Override @@ -157,7 +164,7 @@ public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions { for (int i = 0; i < length; i++) { Object o = wordCountsOI.getListElement(args[0], i); if (o == null) { - throw new HiveException("Given feature vector contains invalid elements"); + throw new HiveException("Given feature vector contains invalid null elements"); } String s = o.toString(); wordCounts[j] = s; @@ -167,7 +174,7 @@ public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions { return; } - model.accumulateDocCount();; + model.accumulateDocCount(); update(wordCounts);
