Repository: incubator-hivemall
Updated Branches:
  refs/heads/master 211c28036 -> e27307898


Close #76: [HIVEMALL-74-2][HIVEMALL-91-2] Revise topic model UDFs


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/e2730789
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/e2730789
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/e2730789

Branch: refs/heads/master
Commit: e273078982261d7e9eb5cd93cebe1ca8a3c1c5e9
Parents: 211c280
Author: Takuya Kitazawa <[email protected]>
Authored: Tue May 9 17:05:14 2017 +0900
Committer: myui <[email protected]>
Committed: Tue May 9 17:05:14 2017 +0900

----------------------------------------------------------------------
 .../topicmodel/IncrementalPLSAModel.java        |  21 ++-
 .../hivemall/topicmodel/LDAPredictUDAF.java     |  82 +++++-----
 .../main/java/hivemall/topicmodel/LDAUDTF.java  |  13 +-
 .../hivemall/topicmodel/PLSAPredictUDAF.java    |  36 +++--
 .../main/java/hivemall/topicmodel/PLSAUDTF.java |  11 +-
 .../java/hivemall/utils/lang/ArrayUtils.java    |   6 +-
 .../java/hivemall/utils/math/MathUtils.java     |   8 +-
 .../topicmodel/IncrementalPLSAModelTest.java    |   6 +-
 .../hivemall/topicmodel/LDAPredictUDAFTest.java | 151 +++++++++++--------
 .../topicmodel/PLSAPredictUDAFTest.java         | 129 ++++++++++------
 docs/gitbook/clustering/lda.md                  |  10 +-
 docs/gitbook/clustering/plsa.md                 |  16 +-
 12 files changed, 304 insertions(+), 185 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java 
b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
index 745e510..a75febb 100644
--- a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
+++ b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
@@ -20,6 +20,8 @@ package hivemall.topicmodel;
 
 import static hivemall.utils.lang.ArrayUtils.newRandomFloatArray;
 import static hivemall.utils.math.MathUtils.l1normalize;
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
 import hivemall.model.FeatureValue;
 import hivemall.utils.math.MathUtils;
 
@@ -29,7 +31,6 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.Random;
 import java.util.SortedMap;
 import java.util.TreeMap;
 
@@ -54,7 +55,7 @@ public final class IncrementalPLSAModel {
 
     // random number generator
     @Nonnull
-    private final Random _rnd;
+    private final PRNG _rnd;
 
     // optimized in the E step
     private List<Map<String, float[]>> _p_dwz; // P(z|d,w) probability of 
topics for each document-word (i.e., instance-feature) pair
@@ -73,7 +74,7 @@ public final class IncrementalPLSAModel {
         this._alpha = alpha;
         this._delta = delta;
 
-        this._rnd = new Random(1001);
+        this._rnd = RandomNumberGeneratorFactory.createPRNG(1001);
 
         this._p_zw = new HashMap<String, float[]>();
 
@@ -92,7 +93,9 @@ public final class IncrementalPLSAModel {
         for (int d = 0; d < _miniBatchSize; d++) {
             do {
                 pPrev_dz.clear();
-                pPrev_dz.addAll(_p_dz);
+                for (float[] p_dz_d : _p_dz) { // deep copy
+                    pPrev_dz.add(p_dz_d.clone());
+                }
 
                 // Expectation
                 eStep(d);
@@ -216,6 +219,10 @@ public final class IncrementalPLSAModel {
                 for (int z = 0; z < _K; z++) {
                     p_zw_w[z] = n * p_dwz_dw[z] + _alpha * p_zw_w[z];
                 }
+            } else { // others
+                for (int z = 0; z < _K; z++) {
+                    p_zw_w[z] = _alpha * p_zw_w[z];
+                }
             }
 
             MathUtils.add(p_zw_w, sums, _K);
@@ -223,7 +230,7 @@ public final class IncrementalPLSAModel {
         // normalize to ensure \sum_w P(w|z) = 1
         for (float[] p_zw_w : _p_zw.values()) {
             for (int z = 0; z < _K; z++) {
-                p_zw_w[z] /= sums[z];
+                p_zw_w[z] = (float) (p_zw_w[z] / sums[z]);
             }
         }
     }
@@ -256,6 +263,10 @@ public final class IncrementalPLSAModel {
                     p_dw += (double) p_zw_w[z] * p_dz_d[z];
                 }
 
+                if (p_dw == 0.d) {
+                    throw new IllegalStateException("Perplexity would be 
Infinity. "
+                            + "Try different mini-batch size `-s`, larger 
`-delta` and/or larger `-alpha`.");
+                }
                 numer += w_value * Math.log(p_dw);
                 denom += w_value;
             }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java 
b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
index a4076b6..8d1edd8 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
@@ -120,7 +120,7 @@ public final class LDAPredictUDAF extends 
AbstractGenericUDAFResolver {
         private StructObjectInspector internalMergeOI;
         private StructField wcListField;
         private StructField lambdaMapField;
-        private StructField topicOptionField;
+        private StructField topicsOptionField;
         private StructField alphaOptionField;
         private StructField deltaOptionField;
         private PrimitiveObjectInspector wcListElemOI;
@@ -134,7 +134,7 @@ public final class LDAPredictUDAF extends 
AbstractGenericUDAFResolver {
 
         protected Options getOptions() {
             Options opts = new Options();
-            opts.addOption("k", "topics", true, "The number of topics 
[required]");
+            opts.addOption("k", "topics", true, "The number of topics 
[default: 10]");
             opts.addOption("alpha", true, "The hyperparameter for theta 
[default: 1/k]");
             opts.addOption("delta", true,
                 "Check convergence in the expectation step [default: 1E-5]");
@@ -175,22 +175,24 @@ public final class LDAPredictUDAF extends 
AbstractGenericUDAFResolver {
         protected CommandLine processOptions(ObjectInspector[] argOIs) throws 
UDFArgumentException {
             CommandLine cl = null;
 
-            if (argOIs.length != 5) {
-                throw new UDFArgumentException("At least 1 option `-topics` 
MUST be specified");
-            }
+            if (argOIs.length >= 5) {
+                String rawArgs = HiveUtils.getConstString(argOIs[4]);
+                cl = parseOptions(rawArgs);
 
-            String rawArgs = HiveUtils.getConstString(argOIs[4]);
-            cl = parseOptions(rawArgs);
+                this.topics = Primitives.parseInt(cl.getOptionValue("topics"), 
LDAUDTF.DEFAULT_TOPICS);
+                if (topics < 1) {
+                    throw new UDFArgumentException(
+                            "A positive integer MUST be set to an option 
`-topics`: " + topics);
+                }
 
-            this.topics = Primitives.parseInt(cl.getOptionValue("topics"), 0);
-            if (topics < 1) {
-                throw new UDFArgumentException(
-                    "A positive integer MUST be set to an option `-topics`: " 
+ topics);
+                this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 
1.f / topics);
+                this.delta = 
Primitives.parseDouble(cl.getOptionValue("delta"), LDAUDTF.DEFAULT_DELTA);
+            } else {
+                this.topics = LDAUDTF.DEFAULT_TOPICS;
+                this.alpha = 1.f / topics;
+                this.delta = LDAUDTF.DEFAULT_DELTA;
             }
 
-            this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f 
/ topics);
-            this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), 
1E-5d);
-
             return cl;
         }
 
@@ -211,7 +213,7 @@ public final class LDAPredictUDAF extends 
AbstractGenericUDAFResolver {
                 this.internalMergeOI = soi;
                 this.wcListField = soi.getStructFieldRef("wcList");
                 this.lambdaMapField = soi.getStructFieldRef("lambdaMap");
-                this.topicOptionField = soi.getStructFieldRef("topics");
+                this.topicsOptionField = soi.getStructFieldRef("topics");
                 this.alphaOptionField = soi.getStructFieldRef("alpha");
                 this.deltaOptionField = soi.getStructFieldRef("delta");
                 this.wcListElemOI = 
PrimitiveObjectInspectorFactory.javaStringObjectInspector;
@@ -310,7 +312,7 @@ public final class LDAPredictUDAF extends 
AbstractGenericUDAFResolver {
             Object[] partialResult = new Object[5];
             partialResult[0] = myAggr.wcList;
             partialResult[1] = myAggr.lambdaMap;
-            partialResult[2] = new IntWritable(myAggr.topic);
+            partialResult[2] = new IntWritable(myAggr.topics);
             partialResult[3] = new FloatWritable(myAggr.alpha);
             partialResult[4] = new DoubleWritable(myAggr.delta);
 
@@ -358,8 +360,8 @@ public final class LDAPredictUDAF extends 
AbstractGenericUDAFResolver {
             }
 
             // restore options from partial result
-            Object topicObj = internalMergeOI.getStructFieldData(partial, 
topicOptionField);
-            this.topics = 
PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(topicObj);
+            Object topicsObj = internalMergeOI.getStructFieldData(partial, 
topicsOptionField);
+            this.topics = 
PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(topicsObj);
 
             Object alphaObj = internalMergeOI.getStructFieldData(partial, 
alphaOptionField);
             this.alpha = 
PrimitiveObjectInspectorFactory.writableFloatObjectInspector.get(alphaObj);
@@ -402,7 +404,7 @@ public final class LDAPredictUDAF extends 
AbstractGenericUDAFResolver {
         private List<String> wcList;
         private Map<String, List<Float>> lambdaMap;
 
-        private int topic;
+        private int topics;
         private float alpha;
         private double delta;
 
@@ -410,8 +412,8 @@ public final class LDAPredictUDAF extends 
AbstractGenericUDAFResolver {
             super();
         }
 
-        void setOptions(int topic, float alpha, double delta) {
-            this.topic = topic;
+        void setOptions(int topics, float alpha, double delta) {
+            this.topics = topics;
             this.alpha = alpha;
             this.delta = delta;
         }
@@ -424,17 +426,16 @@ public final class LDAPredictUDAF extends 
AbstractGenericUDAFResolver {
         void iterate(String word, float value, int label, float lambda) {
             wcList.add(word + ":" + value);
 
+            List<Float> lambda_word = lambdaMap.get(word);
+
             // for an unforeseen word, initialize its lambdas w/ -1s
-            if (!lambdaMap.containsKey(word)) {
-                List<Float> lambdaEmpty_word = new ArrayList<Float>(
-                    Collections.nCopies(topic, -1.f));
-                lambdaMap.put(word, lambdaEmpty_word);
+            if (lambda_word == null) {
+                lambda_word = new ArrayList<Float>(Collections.nCopies(topics, 
-1.f));
+                lambdaMap.put(word, lambda_word);
             }
 
             // set the given lambda value
-            List<Float> lambda_word = lambdaMap.get(word);
             lambda_word.set(label, lambda);
-            lambdaMap.put(word, lambda_word);
         }
 
         void merge(List<String> o_wcList, Map<String, List<Float>> 
o_lambdaMap) {
@@ -444,13 +445,14 @@ public final class LDAPredictUDAF extends 
AbstractGenericUDAFResolver {
                 String o_word = e.getKey();
                 List<Float> o_lambda_word = e.getValue();
 
-                if (!lambdaMap.containsKey(o_word)) { // for an unforeseen word
+                final List<Float> lambda_word = lambdaMap.get(o_word);
+                if (lambda_word == null) { // for an unforeseen word
                     lambdaMap.put(o_word, o_lambda_word);
                 } else { // for a partially observed word
-                    List<Float> lambda_word = lambdaMap.get(o_word);
-                    for (int k = 0; k < topic; k++) {
-                        if (o_lambda_word.get(k) != -1.f) { // not default 
value
-                            lambda_word.set(k, o_lambda_word.get(k)); // set 
the partial lambda value
+                    for (int k = 0; k < topics; k++) {
+                        final float lambda_k = 
o_lambda_word.get(k).floatValue();
+                        if (lambda_k != -1.f) { // not default value
+                            lambda_word.set(k, lambda_k); // set the partial 
lambda value
                         }
                     }
                     lambdaMap.put(o_word, lambda_word);
@@ -459,12 +461,16 @@ public final class LDAPredictUDAF extends 
AbstractGenericUDAFResolver {
         }
 
         float[] get() {
-            OnlineLDAModel model = new OnlineLDAModel(topic, alpha, delta);
-
-            for (String word : lambdaMap.keySet()) {
-                List<Float> lambda_word = lambdaMap.get(word);
-                for (int k = 0; k < topic; k++) {
-                    model.setLambda(word, k, lambda_word.get(k));
+            OnlineLDAModel model = new OnlineLDAModel(topics, alpha, delta);
+
+            for (Map.Entry<String, List<Float>> e : lambdaMap.entrySet()) {
+                final String word = e.getKey();
+                final List<Float> lambda_word = e.getValue();
+                for (int k = 0; k < topics; k++) {
+                    final float lambda_k = lambda_word.get(k).floatValue();
+                    if (lambda_k != -1.f) {
+                        model.setLambda(word, k, lambda_k);
+                    }
                 }
             }
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java 
b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
index 1e28a30..1cec875 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
@@ -62,6 +62,9 @@ import org.apache.hadoop.mapred.Reporter;
 public class LDAUDTF extends UDTFWithOptions {
     private static final Log logger = LogFactory.getLog(LDAUDTF.class);
 
+    public static final int DEFAULT_TOPICS = 10;
+    public static final double DEFAULT_DELTA = 1E-3d;
+
     // Options
     protected int topics;
     protected float alpha;
@@ -93,14 +96,14 @@ public class LDAUDTF extends UDTFWithOptions {
     protected ByteBuffer inputBuf;
 
     public LDAUDTF() {
-        this.topics = 10;
+        this.topics = DEFAULT_TOPICS;
         this.alpha = 1.f / topics;
         this.eta = 1.f / topics;
         this.numDocs = -1L;
         this.tau0 = 64.d;
         this.kappa = 0.7;
         this.iterations = 10;
-        this.delta = 1E-3d;
+        this.delta = DEFAULT_DELTA;
         this.eps = 1E-1d;
         this.miniBatchSize = 128; // if 1, truly online setting
     }
@@ -131,7 +134,7 @@ public class LDAUDTF extends UDTFWithOptions {
         if (argOIs.length >= 2) {
             String rawArgs = HiveUtils.getConstString(argOIs[1]);
             cl = parseOptions(rawArgs);
-            this.topics = Primitives.parseInt(cl.getOptionValue("topics"), 10);
+            this.topics = Primitives.parseInt(cl.getOptionValue("topics"), 
DEFAULT_TOPICS);
             this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f 
/ topics);
             this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f / 
topics);
             this.numDocs = Primitives.parseLong(cl.getOptionValue("num_docs"), 
-1L);
@@ -148,7 +151,7 @@ public class LDAUDTF extends UDTFWithOptions {
                 throw new UDFArgumentException(
                     "'-iterations' must be greater than or equals to 1: " + 
iterations);
             }
-            this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), 
1E-3d);
+            this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), 
DEFAULT_DELTA);
             this.eps = Primitives.parseDouble(cl.getOptionValue("epsilon"), 
1E-1d);
             this.miniBatchSize = 
Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 128);
         }
@@ -504,6 +507,8 @@ public class LDAUDTF extends UDTFWithOptions {
                         + NumberUtils.formatNumber(numTrainingExamples * 
Math.min(iter, iterations))
                         + " training updates in total)");
             }
+        } catch (Throwable e) {
+            throw new HiveException("Exception caused in the iterative 
training", e);
         } finally {
             // delete the temporary file and release resources
             try {

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java 
b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
index c0b60fc..08febb4 100644
--- a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
+++ b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
@@ -120,7 +120,7 @@ public final class PLSAPredictUDAF extends 
AbstractGenericUDAFResolver {
         private StructObjectInspector internalMergeOI;
         private StructField wcListField;
         private StructField probMapField;
-        private StructField topicOptionField;
+        private StructField topicsOptionField;
         private StructField alphaOptionField;
         private StructField deltaOptionField;
         private PrimitiveObjectInspector wcListElemOI;
@@ -174,21 +174,25 @@ public final class PLSAPredictUDAF extends 
AbstractGenericUDAFResolver {
 
         @Nullable
         protected CommandLine processOptions(ObjectInspector[] argOIs) throws 
UDFArgumentException {
-            if (argOIs.length != 5) {
-                return null;
-            }
+            CommandLine cl = null;
 
-            String rawArgs = HiveUtils.getConstString(argOIs[4]);
-            CommandLine cl = parseOptions(rawArgs);
+            if (argOIs.length >= 5) {
+                String rawArgs = HiveUtils.getConstString(argOIs[4]);
+                cl = parseOptions(rawArgs);
 
-            this.topics = Primitives.parseInt(cl.getOptionValue("topics"), 
PLSAUDTF.DEFAULT_TOPICS);
-            if (topics < 1) {
-                throw new UDFArgumentException(
-                    "A positive integer MUST be set to an option `-topics`: " 
+ topics);
-            }
+                this.topics = Primitives.parseInt(cl.getOptionValue("topics"), 
PLSAUDTF.DEFAULT_TOPICS);
+                if (topics < 1) {
+                    throw new UDFArgumentException(
+                            "A positive integer MUST be set to an option 
`-topics`: " + topics);
+                }
 
-            this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 
PLSAUDTF.DEFAULT_ALPHA);
-            this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), 
PLSAUDTF.DEFAULT_DELTA);
+                this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 
PLSAUDTF.DEFAULT_ALPHA);
+                this.delta = 
Primitives.parseDouble(cl.getOptionValue("delta"), PLSAUDTF.DEFAULT_DELTA);
+            } else {
+                this.topics = PLSAUDTF.DEFAULT_TOPICS;
+                this.alpha = PLSAUDTF.DEFAULT_ALPHA;
+                this.delta = PLSAUDTF.DEFAULT_DELTA;
+            }
 
             return cl;
         }
@@ -210,7 +214,7 @@ public final class PLSAPredictUDAF extends 
AbstractGenericUDAFResolver {
                 this.internalMergeOI = soi;
                 this.wcListField = soi.getStructFieldRef("wcList");
                 this.probMapField = soi.getStructFieldRef("probMap");
-                this.topicOptionField = soi.getStructFieldRef("topics");
+                this.topicsOptionField = soi.getStructFieldRef("topics");
                 this.alphaOptionField = soi.getStructFieldRef("alpha");
                 this.deltaOptionField = soi.getStructFieldRef("delta");
                 this.wcListElemOI = 
PrimitiveObjectInspectorFactory.javaStringObjectInspector;
@@ -356,8 +360,8 @@ public final class PLSAPredictUDAF extends 
AbstractGenericUDAFResolver {
             }
 
             // restore options from partial result
-            Object topicObj = internalMergeOI.getStructFieldData(partial, 
topicOptionField);
-            this.topics = 
PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(topicObj);
+            Object topicsObj = internalMergeOI.getStructFieldData(partial, 
topicsOptionField);
+            this.topics = 
PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(topicsObj);
 
             Object alphaObj = internalMergeOI.getStructFieldData(partial, 
alphaOptionField);
             this.alpha = 
PrimitiveObjectInspectorFactory.writableFloatObjectInspector.get(alphaObj);

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java 
b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
index 2616133..014356e 100644
--- a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
@@ -46,7 +46,10 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.serde2.objectinspector.*;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
 import org.apache.hadoop.io.FloatWritable;
 import org.apache.hadoop.io.IntWritable;
@@ -58,11 +61,11 @@ import org.apache.hadoop.mapred.Reporter;
         + " - Returns a relation consists of <int topic, string word, float 
score>")
 public class PLSAUDTF extends UDTFWithOptions {
     private static final Log logger = LogFactory.getLog(PLSAUDTF.class);
-    
+
     public static final int DEFAULT_TOPICS = 10;
     public static final float DEFAULT_ALPHA = 0.5f;
     public static final double DEFAULT_DELTA = 1E-3d;
-    
+
     // Options
     protected int topics;
     protected float alpha;
@@ -474,6 +477,8 @@ public class PLSAUDTF extends UDTFWithOptions {
                         + NumberUtils.formatNumber(numTrainingExamples * 
Math.min(iter, iterations))
                         + " training updates in total)");
             }
+        } catch (Throwable e) {
+            throw new HiveException("Exception caused in the iterative 
training", e);
         } finally {
             // delete the temporary file and release resources
             try {

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java 
b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
index 4177d70..540f1c6 100644
--- a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
+++ b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
@@ -18,6 +18,8 @@
  */
 package hivemall.utils.lang;
 
+import hivemall.math.random.PRNG;
+
 import java.lang.reflect.Array;
 import java.util.Arrays;
 import java.util.List;
@@ -737,10 +739,10 @@ public final class ArrayUtils {
 
     @Nonnull
     public static float[] newRandomFloatArray(@Nonnegative final int size,
-            @Nonnull final Random rnd) {
+            @Nonnull final PRNG rnd) {
         final float[] ret = new float[size];
         for (int i = 0; i < size; i++) {
-            ret[i] = rnd.nextFloat();
+            ret[i] = (float) rnd.nextDouble();
         }
         return ret;
     }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/core/src/main/java/hivemall/utils/math/MathUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java 
b/core/src/main/java/hivemall/utils/math/MathUtils.java
index 8ffb89c..56c4f89 100644
--- a/core/src/main/java/hivemall/utils/math/MathUtils.java
+++ b/core/src/main/java/hivemall/utils/math/MathUtils.java
@@ -411,10 +411,14 @@ public final class MathUtils {
     @Nonnull
     public static float[] l1normalize(@Nonnull final float[] arr) {
         double sum = 0.d;
-        for (int i = 0; i < arr.length; i++) {
+        int size = arr.length;
+        for (int i = 0; i < size; i++) {
             sum += Math.abs(arr[i]);
         }
-        for (int i = 0; i < arr.length; i++) {
+        if (sum == 0.d) {
+            return new float[size];
+        }
+        for (int i = 0; i < size; i++) {
             arr[i] /= sum;
         }
         return arr;

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
----------------------------------------------------------------------
diff --git 
a/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java 
b/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
index db34a38..79be3a7 100644
--- a/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
+++ b/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
@@ -50,7 +50,7 @@ public class IncrementalPLSAModelTest {
         float perplexityPrev;
         float perplexity = Float.MAX_VALUE;
 
-        IncrementalPLSAModel model = new IncrementalPLSAModel(K, 0.f, 1E-5d);
+        IncrementalPLSAModel model = new IncrementalPLSAModel(K, 0.5f, 1E-5d);
 
         String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"};
         String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", 
"flu:1", "like:2",
@@ -124,7 +124,7 @@ public class IncrementalPLSAModelTest {
         float perplexityPrev;
         float perplexity = Float.MAX_VALUE;
 
-        IncrementalPLSAModel model = new IncrementalPLSAModel(K, 0.f, 1E-5d);
+        IncrementalPLSAModel model = new IncrementalPLSAModel(K, 0.5f, 1E-5d);
 
         String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"};
         String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", 
"flu:1", "like:2",
@@ -191,7 +191,7 @@ public class IncrementalPLSAModelTest {
         int cnt, it;
         int maxIter = 64;
 
-        IncrementalPLSAModel model = new IncrementalPLSAModel(K, 0.8f, 1E-5d);
+        IncrementalPLSAModel model = new IncrementalPLSAModel(K, 100.f, 1E-3d);
 
         BufferedReader news20 = readFile("news20-multiclass.gz");
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java 
b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
index 2c08560..e09e57e 100644
--- a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
@@ -18,7 +18,8 @@
  */
 package hivemall.topicmodel;
 
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import hivemall.utils.math.MathUtils;
+
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
 import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@@ -46,64 +47,8 @@ public class LDAPredictUDAFTest {
     int[] labels;
     float[] lambdas;
 
-    @Test(expected=UDFArgumentException.class)
-    public void testWithoutOption() throws Exception {
-        udaf = new LDAPredictUDAF();
-
-        inputOIs = new ObjectInspector[] {
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.STRING),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.INT),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT)};
-
-        evaluator = udaf.getEvaluator(new 
SimpleGenericUDAFParameterInfo(inputOIs, false, false));
-
-        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
-    }
-
-    @Test(expected=UDFArgumentException.class)
-    public void testWithoutTopicOption() throws Exception {
-        udaf = new LDAPredictUDAF();
-
-        inputOIs = new ObjectInspector[] {
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.STRING),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.INT),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                ObjectInspectorUtils.getConstantObjectInspector(
-                        
PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-alpha 0.1")};
-
-        evaluator = udaf.getEvaluator(new 
SimpleGenericUDAFParameterInfo(inputOIs, false, false));
-
-        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
-    }
-
     @Before
     public void setUp() throws Exception {
-        udaf = new LDAPredictUDAF();
-
-        inputOIs = new ObjectInspector[] {
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.STRING),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.INT),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                ObjectInspectorUtils.getConstantObjectInspector(
-                        
PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
-
-        evaluator = udaf.getEvaluator(new 
SimpleGenericUDAFParameterInfo(inputOIs, false, false));
-
         ArrayList<String> fieldNames = new ArrayList<String>();
         ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
 
@@ -129,8 +74,6 @@ public class LDAPredictUDAFTest {
         partialOI = new ObjectInspector[4];
         partialOI[0] = 
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
 
-        agg = (LDAPredictUDAF.OnlineLDAPredictAggregationBuffer) 
evaluator.getNewAggregationBuffer();
-
         words = new String[] {"fruits", "vegetables", "healthy", "flu", 
"apples", "oranges", "like", "avocados", "colds",
             "colds", "avocados", "oranges", "like", "apples", "flu", 
"healthy", "vegetables", "fruits"};
         labels = new int[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 
1};
@@ -140,6 +83,24 @@ public class LDAPredictUDAFTest {
 
     @Test
     public void test() throws Exception {
+        udaf = new LDAPredictUDAF();
+
+        inputOIs = new ObjectInspector[] {
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.INT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                        
PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
+
+        evaluator = udaf.getEvaluator(new 
SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        agg = (LDAPredictUDAF.OnlineLDAPredictAggregationBuffer) 
evaluator.getNewAggregationBuffer();
+
         final Map<String, Float> doc1 = new HashMap<String, Float>();
         doc1.put("fruits", 1.f);
         doc1.put("healthy", 1.f);
@@ -176,6 +137,24 @@ public class LDAPredictUDAFTest {
 
     @Test
     public void testMerge() throws Exception {
+        udaf = new LDAPredictUDAF();
+
+        inputOIs = new ObjectInspector[] {
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.INT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                        
PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
+
+        evaluator = udaf.getEvaluator(new 
SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        agg = (LDAPredictUDAF.OnlineLDAPredictAggregationBuffer) 
evaluator.getNewAggregationBuffer();
+
         final Map<String, Float> doc = new HashMap<String, Float>();
         doc.put("apples", 1.f);
         doc.put("avocados", 1.f);
@@ -225,4 +204,58 @@ public class LDAPredictUDAFTest {
             Assert.assertTrue(distr[0] < distr[1]);
         }
     }
+
+    @Test
+    public void testUnmatchedTopics() throws Exception {
+        udaf = new LDAPredictUDAF();
+
+        inputOIs = new ObjectInspector[] {
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.INT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT)};
+
+        evaluator = udaf.getEvaluator(new 
SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        agg = (LDAPredictUDAF.OnlineLDAPredictAggregationBuffer) 
evaluator.getNewAggregationBuffer();
+
+        final Map<String, Float> doc1 = new HashMap<String, Float>();
+        doc1.put("fruits", 1.f);
+        doc1.put("healthy", 1.f);
+        doc1.put("vegetables", 1.f);
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 0; i < words.length; i++) {
+            String word = words[i];
+            evaluator.iterate(agg, new Object[] {word, doc1.get(word), 
labels[i], lambdas[i]});
+        }
+        float[] doc1Distr = agg.get();
+
+        final Map<String, Float> doc2 = new HashMap<String, Float>();
+        doc2.put("apples", 1.f);
+        doc2.put("avocados", 1.f);
+        doc2.put("colds", 1.f);
+        doc2.put("flu", 1.f);
+        doc2.put("like", 2.f);
+        doc2.put("oranges", 1.f);
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 0; i < words.length; i++) {
+            String word = words[i];
+            evaluator.iterate(agg, new Object[] {word, doc2.get(word), 
labels[i], lambdas[i]});
+        }
+        float[] doc2Distr = agg.get();
+
+        Assert.assertEquals(LDAUDTF.DEFAULT_TOPICS, doc1Distr.length);
+        Assert.assertEquals(1.d, MathUtils.sum(doc1Distr), 1E-5d);
+
+        Assert.assertEquals(LDAUDTF.DEFAULT_TOPICS, doc2Distr.length);
+        Assert.assertEquals(1.d, MathUtils.sum(doc2Distr), 1E-5d);
+    }
 }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java 
b/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
index 456dd1d..2be48e1 100644
--- a/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
@@ -18,7 +18,8 @@
  */
 package hivemall.topicmodel;
 
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import hivemall.utils.math.MathUtils;
+
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
 import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@@ -46,52 +47,8 @@ public class PLSAPredictUDAFTest {
     int[] labels;
     float[] probs;
 
-    @Test(expected = UDFArgumentException.class)
-    public void testWithoutOption() throws Exception {
-        udaf = new PLSAPredictUDAF();
-
-        inputOIs = new ObjectInspector[] {
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT)};
-
-        evaluator = udaf.getEvaluator(new 
SimpleGenericUDAFParameterInfo(inputOIs, false, false));
-
-        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
-    }
-
-    @Test(expected = UDFArgumentException.class)
-    public void testWithoutTopicOption() throws Exception {
-        udaf = new PLSAPredictUDAF();
-
-        inputOIs = new ObjectInspector[] {
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                ObjectInspectorUtils.getConstantObjectInspector(
-                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, 
"-alpha 0.1")};
-
-        evaluator = udaf.getEvaluator(new 
SimpleGenericUDAFParameterInfo(inputOIs, false, false));
-
-        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
-    }
-
     @Before
     public void setUp() throws Exception {
-        udaf = new PLSAPredictUDAF();
-
-        inputOIs = new ObjectInspector[] {
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                ObjectInspectorUtils.getConstantObjectInspector(
-                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, 
"-topics 2")};
-
-        evaluator = udaf.getEvaluator(new 
SimpleGenericUDAFParameterInfo(inputOIs, false, false));
-
         ArrayList<String> fieldNames = new ArrayList<String>();
         ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
 
@@ -115,8 +72,6 @@ public class PLSAPredictUDAFTest {
         partialOI = new ObjectInspector[4];
         partialOI[0] = 
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
 
-        agg = (PLSAPredictUDAF.PLSAPredictAggregationBuffer) 
evaluator.getNewAggregationBuffer();
-
         words = new String[] {"fruits", "vegetables", "healthy", "flu", 
"apples", "oranges",
                 "like", "avocados", "colds", "colds", "avocados", "oranges", 
"like", "apples",
                 "flu", "healthy", "vegetables", "fruits"};
@@ -129,6 +84,20 @@ public class PLSAPredictUDAFTest {
 
     @Test
     public void test() throws Exception {
+        udaf = new PLSAPredictUDAF();
+
+        inputOIs = new ObjectInspector[] {
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                        
PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
+
+        evaluator = udaf.getEvaluator(new 
SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        agg = (PLSAPredictUDAF.PLSAPredictAggregationBuffer) 
evaluator.getNewAggregationBuffer();
+
         final Map<String, Float> doc1 = new HashMap<String, Float>();
         doc1.put("fruits", 1.f);
         doc1.put("healthy", 1.f);
@@ -165,6 +134,20 @@ public class PLSAPredictUDAFTest {
 
     @Test
     public void testMerge() throws Exception {
+        udaf = new PLSAPredictUDAF();
+
+        inputOIs = new ObjectInspector[] {
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                        
PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
+
+        evaluator = udaf.getEvaluator(new 
SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        agg = (PLSAPredictUDAF.PLSAPredictAggregationBuffer) 
evaluator.getNewAggregationBuffer();
+
         final Map<String, Float> doc = new HashMap<String, Float>();
         doc.put("apples", 1.f);
         doc.put("avocados", 1.f);
@@ -214,4 +197,56 @@ public class PLSAPredictUDAFTest {
             Assert.assertTrue(distr[0] < distr[1]);
         }
     }
+
+    @Test
+    public void testUnmatchedTopics() throws Exception {
+        udaf = new PLSAPredictUDAF();
+
+        // pre-defined topic model only has two topics, but prediction is 
launched with -topics=10 (default value)
+        inputOIs = new ObjectInspector[] {
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT)};
+
+        evaluator = udaf.getEvaluator(new 
SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        agg = (PLSAPredictUDAF.PLSAPredictAggregationBuffer) 
evaluator.getNewAggregationBuffer();
+
+        final Map<String, Float> doc1 = new HashMap<String, Float>();
+        doc1.put("fruits", 1.f);
+        doc1.put("healthy", 1.f);
+        doc1.put("vegetables", 1.f);
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+
+        for (int i = 0; i < words.length; i++) {
+            String word = words[i];
+            evaluator.iterate(agg, new Object[] {word, doc1.get(word), 
labels[i], probs[i]});
+        }
+        float[] doc1Distr = agg.get();
+
+        final Map<String, Float> doc2 = new HashMap<String, Float>();
+        doc2.put("apples", 1.f);
+        doc2.put("avocados", 1.f);
+        doc2.put("colds", 1.f);
+        doc2.put("flu", 1.f);
+        doc2.put("like", 2.f);
+        doc2.put("oranges", 1.f);
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 0; i < words.length; i++) {
+            String word = words[i];
+            evaluator.iterate(agg, new Object[] {word, doc2.get(word), 
labels[i], probs[i]});
+        }
+        float[] doc2Distr = agg.get();
+
+        Assert.assertEquals(PLSAUDTF.DEFAULT_TOPICS, doc1Distr.length);
+        Assert.assertEquals(1.d, MathUtils.sum(doc1Distr), 1E-5d);
+
+        Assert.assertEquals(PLSAUDTF.DEFAULT_TOPICS, doc2Distr.length);
+        Assert.assertEquals(1.d, MathUtils.sum(doc2Distr), 1E-5d);
+    }
 }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/docs/gitbook/clustering/lda.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/clustering/lda.md b/docs/gitbook/clustering/lda.md
index cc477da..8b8e5f5 100644
--- a/docs/gitbook/clustering/lda.md
+++ b/docs/gitbook/clustering/lda.md
@@ -82,7 +82,7 @@ with word_counts as (
     docid, word
 )
 select
-  train_lda(feature, "-topic 2 -iter 20") as (label, word, lambda)
+  train_lda(feature, "-topics 2 -iter 20") as (label, word, lambda)
 from (
   select docid, collect_set(word_count) as feature
   from word_counts
@@ -92,7 +92,7 @@ from (
 ;
 ```
 
-Here, an option `-topic 2` specifies the number of topics we assume in the set 
of documents.
+Here, an option `-topics 2` specifies the number of topics we assume in the 
set of documents.
 
 Notice that `order by docid` ensures building a LDA model precisely in a 
single node. In case that you like to launch `train_lda` in parallel, following 
query hopefully returns similar (but might be slightly approximated) result:
 
@@ -104,7 +104,7 @@ select
   label, word, avg(lambda) as lambda
 from (
   select
-    train_lda(feature, "-topic 2 -iter 20") as (label, word, lambda)
+    train_lda(feature, "-topics 2 -iter 20") as (label, word, lambda)
   from (
     select docid, collect_set(f) as feature
     from word_counts
@@ -163,7 +163,7 @@ with test as (
 )
 select
   t.docid,
-  lda_predict(t.word, t.value, m.label, m.lambda, "-topic 2") as probabilities
+  lda_predict(t.word, t.value, m.label, m.lambda, "-topics 2") as probabilities
 from
   test t
   JOIN lda_model m ON (t.word = m.word)
@@ -177,7 +177,7 @@ group by
 |1  | [{"label":0,"probability":0.875},{"label":1,"probability":0.125}]|
 |2  | [{"label":1,"probability":0.9375},{"label":0,"probability":0.0625}]|
 
-Importantly, an option `-topic` should be set to the same value as you set for 
training.
+Importantly, the `-topics` option is expected to be set to the same value as 
you used for training.
 
 Since the probabilities are sorted in descending order, a label of the most 
promising topic is easily obtained as:
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/docs/gitbook/clustering/plsa.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/clustering/plsa.md b/docs/gitbook/clustering/plsa.md
index 456dfe7..7cd3a9d 100644
--- a/docs/gitbook/clustering/plsa.md
+++ b/docs/gitbook/clustering/plsa.md
@@ -151,4 +151,18 @@ This value controls **how much iterative model update is 
affected by the old res
 
 From an algorithmic point of view, training pLSA (and LDA) iteratively repeats 
certain operations and updates the target value (i.e., probability obtained as 
a result of `train_plsa()`). This iterative procedure gradually makes the 
probabilities more accurate. What `alpha` does is to control the degree of the 
change of probabilities in each step.
 
-Normally, `alpha` is set to a small value from 0.0 to 0.5 (default is 0.5).
\ No newline at end of file
+Importantly, pLSA is likely to overfit a single mini-batch. As a result, 
$$P(w|z)$$ could take particularly bad values (i.e., $$P(w|z) = 0$$), and 
`train_plsa()` sometimes fails with an exception like:
+
+```
+Perplexity would be Infinity. Try different mini-batch size `-s`, larger 
`-delta` and/or larger `-alpha`.
+```
+
+In that case, you need to try different hyper-parameters to avoid overfitting 
as the exception suggests.
+
+For instance, [20 newsgroups dataset](http://qwone.com/~jason/20Newsgroups/) 
which consists of 10906 realistic documents empirically requires the following 
options:
+
+```sql
+SELECT train_plsa(features, "-topics 20 -iter 10 -s 128 -delta 0.01 -alpha 512 
-eps 0.1")
+```
+
+Clearly, `alpha` is much larger than the `0.01` which was used for the dummy 
data above. Keep in mind that an appropriate value of `alpha` highly depends 
on the number of documents and mini-batch size.
\ No newline at end of file

Reply via email to