Repository: incubator-hivemall
Updated Branches:
  refs/heads/master 609e48016 -> 1dac1a62f (forced update)


Close #83, Close #82: [HIVEMALL-109][HIVEMALL-112] Fix topic model and tokenize UDFs


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/78acc428
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/78acc428
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/78acc428

Branch: refs/heads/master
Commit: 78acc428e57aa23e1d89ad757cb0693f9155a277
Parents: 10e7d45
Author: Takuya Kitazawa <[email protected]>
Authored: Tue Jun 6 14:19:07 2017 +0900
Committer: myui <[email protected]>
Committed: Tue Jun 6 15:31:20 2017 +0900

----------------------------------------------------------------------
 .../java/hivemall/tools/text/TokenizeUDF.java   |  3 +
 .../topicmodel/IncrementalPLSAModel.java        |  5 +-
 .../hivemall/topicmodel/LDAPredictUDAF.java     |  8 +-
 .../main/java/hivemall/topicmodel/LDAUDTF.java  | 49 ++++++-----
 .../hivemall/topicmodel/PLSAPredictUDAF.java    | 11 ++-
 .../main/java/hivemall/topicmodel/PLSAUDTF.java | 48 ++++++-----
 .../hivemall/topicmodel/LDAPredictUDAFTest.java | 70 +++++++---------
 .../java/hivemall/topicmodel/LDAUDTFTest.java   | 87 +++++++++++++++++---
 .../hivemall/topicmodel/OnlineLDAModelTest.java | 19 +++--
 .../topicmodel/PLSAPredictUDAFTest.java         |  4 +-
 .../java/hivemall/topicmodel/PLSAUDTFTest.java  | 74 +++++++++++++++--
 11 files changed, 263 insertions(+), 115 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/78acc428/core/src/main/java/hivemall/tools/text/TokenizeUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/text/TokenizeUDF.java 
b/core/src/main/java/hivemall/tools/text/TokenizeUDF.java
index a5c8777..d306fa2 100644
--- a/core/src/main/java/hivemall/tools/text/TokenizeUDF.java
+++ b/core/src/main/java/hivemall/tools/text/TokenizeUDF.java
@@ -38,6 +38,9 @@ public final class TokenizeUDF extends UDF {
     }
 
     public List<Text> evaluate(Text input, boolean toLowerCase) {
+        if (input == null) {
+            return null;
+        }
         final List<Text> tokens = new ArrayList<Text>();
         final StringTokenizer tokenizer = new 
StringTokenizer(input.toString(), DELIM);
         while (tokenizer.hasMoreElements()) {

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/78acc428/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java 
b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
index a75febb..6eef23e 100644
--- a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
+++ b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
@@ -264,8 +264,9 @@ public final class IncrementalPLSAModel {
                 }
 
                 if (p_dw == 0.d) {
-                    throw new IllegalStateException("Perplexity would be 
Infinity. "
-                            + "Try different mini-batch size `-s`, larger 
`-delta` and/or larger `-alpha`.");
+                    throw new IllegalStateException(
+                        "Perplexity would be Infinity. "
+                                + "Try different mini-batch size `-s`, larger 
`-delta` and/or larger `-alpha`.");
                 }
                 numer += w_value * Math.log(p_dw);
                 denom += w_value;

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/78acc428/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java 
b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
index 8d1edd8..03779b0 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
@@ -179,14 +179,16 @@ public final class LDAPredictUDAF extends 
AbstractGenericUDAFResolver {
                 String rawArgs = HiveUtils.getConstString(argOIs[4]);
                 cl = parseOptions(rawArgs);
 
-                this.topics = Primitives.parseInt(cl.getOptionValue("topics"), 
LDAUDTF.DEFAULT_TOPICS);
+                this.topics = Primitives.parseInt(cl.getOptionValue("topics"),
+                    LDAUDTF.DEFAULT_TOPICS);
                 if (topics < 1) {
                     throw new UDFArgumentException(
-                            "A positive integer MUST be set to an option 
`-topics`: " + topics);
+                        "A positive integer MUST be set to an option 
`-topics`: " + topics);
                 }
 
                 this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 
1.f / topics);
-                this.delta = 
Primitives.parseDouble(cl.getOptionValue("delta"), LDAUDTF.DEFAULT_DELTA);
+                this.delta = Primitives.parseDouble(cl.getOptionValue("delta"),
+                    LDAUDTF.DEFAULT_DELTA);
             } else {
                 this.topics = LDAUDTF.DEFAULT_TOPICS;
                 this.alpha = 1.f / topics;

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/78acc428/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java 
b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
index 1cec875..daec7ea 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
@@ -22,6 +22,7 @@ import hivemall.UDTFWithOptions;
 import hivemall.annotations.VisibleForTesting;
 import hivemall.utils.hadoop.HiveUtils;
 import hivemall.utils.io.FileUtils;
+import hivemall.utils.io.NIOUtils;
 import hivemall.utils.io.NioStatefullSegment;
 import hivemall.utils.lang.NumberUtils;
 import hivemall.utils.lang.Primitives;
@@ -260,28 +261,26 @@ public class LDAUDTF extends UDTFWithOptions {
             this.fileIO = dst = new NioStatefullSegment(file, false);
         }
 
-        int wcLength = 0;
+        // requiredRecordBytes, wordCounts length, wc1 length, wc1 string, wc2 
length, wc2 string, ...
+        int wcLengthTotal = 0;
         for (String wc : wordCounts) {
             if (wc == null) {
                 continue;
             }
-            wcLength += wc.getBytes().length;
+            wcLengthTotal += wc.length();
         }
-        // recordBytes, wordCounts length, wc1 length, wc1 string, wc2 length, 
wc2 string, ...
-        int recordBytes = (Integer.SIZE * 2 + Integer.SIZE * wcLength) / 8 + 
wcLength;
+        int requiredRecordBytes = SizeOf.INT * 2 + SizeOf.INT * 
wordCounts.length + wcLengthTotal
+                * SizeOf.CHAR;
+
         int remain = buf.remaining();
-        if (remain < recordBytes) {
+        if (remain < requiredRecordBytes) {
             writeBuffer(buf, dst);
         }
 
-        buf.putInt(recordBytes);
+        buf.putInt(requiredRecordBytes);
         buf.putInt(wordCounts.length);
         for (String wc : wordCounts) {
-            if (wc == null) {
-                continue;
-            }
-            buf.putInt(wc.length());
-            buf.put(wc.getBytes());
+            NIOUtils.putString(wc, buf);
         }
     }
 
@@ -351,10 +350,7 @@ public class LDAUDTF extends UDTFWithOptions {
                         int wcLength = buf.getInt();
                         final String[] wordCounts = new String[wcLength];
                         for (int j = 0; j < wcLength; j++) {
-                            int len = buf.getInt();
-                            byte[] bytes = new byte[len];
-                            buf.get(bytes);
-                            wordCounts[j] = new String(bytes);
+                            wordCounts[j] = NIOUtils.getString(buf);
                         }
 
                         miniBatch[miniBatchCount] = wordCounts;
@@ -393,7 +389,6 @@ public class LDAUDTF extends UDTFWithOptions {
                         + NumberUtils.formatNumber(numTrainingExamples * 
Math.min(iter, iterations))
                         + " training updates in total) ");
             } else {// read training examples in the temporary file and invoke 
train for each example
-
                 // write training examples in buffer to a temporary file
                 if (buf.remaining() > 0) {
                     writeBuffer(buf, dst);
@@ -452,7 +447,7 @@ public class LDAUDTF extends UDTFWithOptions {
                         }
                         while (remain >= SizeOf.INT) {
                             int pos = buf.position();
-                            int recordBytes = buf.getInt();
+                            int recordBytes = buf.getInt() - SizeOf.INT;
                             remain -= SizeOf.INT;
                             if (remain < recordBytes) {
                                 buf.position(pos);
@@ -462,10 +457,7 @@ public class LDAUDTF extends UDTFWithOptions {
                             int wcLength = buf.getInt();
                             final String[] wordCounts = new String[wcLength];
                             for (int j = 0; j < wcLength; j++) {
-                                int len = buf.getInt();
-                                byte[] bytes = new byte[len];
-                                buf.get(bytes);
-                                wordCounts[j] = new String(bytes);
+                                wordCounts[j] = NIOUtils.getString(buf);
                             }
 
                             miniBatch[miniBatchCount] = wordCounts;
@@ -554,6 +546,21 @@ public class LDAUDTF extends UDTFWithOptions {
      */
 
     @VisibleForTesting
+    public void closeWithoutModelReset() throws HiveException {
+        // launch close(), but not forward & clear model
+        if (count == 0) {
+            this.model = null;
+            return;
+        }
+        if (miniBatchCount > 0) { // update for remaining samples
+            model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
+        }
+        if (iterations > 1) {
+            runIterativeTraining(iterations);
+        }
+    }
+
+    @VisibleForTesting
     double getLambda(String label, int k) {
         return model.getLambda(label, k);
     }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/78acc428/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java 
b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
index 08febb4..ff29236 100644
--- a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
+++ b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
@@ -180,14 +180,17 @@ public final class PLSAPredictUDAF extends 
AbstractGenericUDAFResolver {
                 String rawArgs = HiveUtils.getConstString(argOIs[4]);
                 cl = parseOptions(rawArgs);
 
-                this.topics = Primitives.parseInt(cl.getOptionValue("topics"), 
PLSAUDTF.DEFAULT_TOPICS);
+                this.topics = Primitives.parseInt(cl.getOptionValue("topics"),
+                    PLSAUDTF.DEFAULT_TOPICS);
                 if (topics < 1) {
                     throw new UDFArgumentException(
-                            "A positive integer MUST be set to an option 
`-topics`: " + topics);
+                        "A positive integer MUST be set to an option 
`-topics`: " + topics);
                 }
 
-                this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 
PLSAUDTF.DEFAULT_ALPHA);
-                this.delta = 
Primitives.parseDouble(cl.getOptionValue("delta"), PLSAUDTF.DEFAULT_DELTA);
+                this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"),
+                    PLSAUDTF.DEFAULT_ALPHA);
+                this.delta = Primitives.parseDouble(cl.getOptionValue("delta"),
+                    PLSAUDTF.DEFAULT_DELTA);
             } else {
                 this.topics = PLSAUDTF.DEFAULT_TOPICS;
                 this.alpha = PLSAUDTF.DEFAULT_ALPHA;

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/78acc428/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java 
b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
index 014356e..46f731f 100644
--- a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
@@ -22,6 +22,7 @@ import hivemall.UDTFWithOptions;
 import hivemall.annotations.VisibleForTesting;
 import hivemall.utils.hadoop.HiveUtils;
 import hivemall.utils.io.FileUtils;
+import hivemall.utils.io.NIOUtils;
 import hivemall.utils.io.NioStatefullSegment;
 import hivemall.utils.lang.NumberUtils;
 import hivemall.utils.lang.Primitives;
@@ -230,28 +231,26 @@ public class PLSAUDTF extends UDTFWithOptions {
             this.fileIO = dst = new NioStatefullSegment(file, false);
         }
 
-        int wcLength = 0;
+        // requiredRecordBytes, wordCounts length, wc1 length, wc1 string, wc2 
length, wc2 string, ...
+        int wcLengthTotal = 0;
         for (String wc : wordCounts) {
             if (wc == null) {
                 continue;
             }
-            wcLength += wc.getBytes().length;
+            wcLengthTotal += wc.length();
         }
-        // recordBytes, wordCounts length, wc1 length, wc1 string, wc2 length, 
wc2 string, ...
-        int recordBytes = (Integer.SIZE * 2 + Integer.SIZE * wcLength) / 8 + 
wcLength;
+        int requiredRecordBytes = SizeOf.INT * 2 + SizeOf.INT * 
wordCounts.length + wcLengthTotal
+                * SizeOf.CHAR;
+
         int remain = buf.remaining();
-        if (remain < recordBytes) {
+        if (remain < requiredRecordBytes) {
             writeBuffer(buf, dst);
         }
 
-        buf.putInt(recordBytes);
+        buf.putInt(requiredRecordBytes);
         buf.putInt(wordCounts.length);
         for (String wc : wordCounts) {
-            if (wc == null) {
-                continue;
-            }
-            buf.putInt(wc.length());
-            buf.put(wc.getBytes());
+            NIOUtils.putString(wc, buf);
         }
     }
 
@@ -321,10 +320,7 @@ public class PLSAUDTF extends UDTFWithOptions {
                         int wcLength = buf.getInt();
                         final String[] wordCounts = new String[wcLength];
                         for (int j = 0; j < wcLength; j++) {
-                            int len = buf.getInt();
-                            byte[] bytes = new byte[len];
-                            buf.get(bytes);
-                            wordCounts[j] = new String(bytes);
+                            wordCounts[j] = NIOUtils.getString(buf);
                         }
 
                         miniBatch[miniBatchCount] = wordCounts;
@@ -422,7 +418,7 @@ public class PLSAUDTF extends UDTFWithOptions {
                         }
                         while (remain >= SizeOf.INT) {
                             int pos = buf.position();
-                            int recordBytes = buf.getInt();
+                            int recordBytes = buf.getInt() - SizeOf.INT;
                             remain -= SizeOf.INT;
                             if (remain < recordBytes) {
                                 buf.position(pos);
@@ -432,10 +428,7 @@ public class PLSAUDTF extends UDTFWithOptions {
                             int wcLength = buf.getInt();
                             final String[] wordCounts = new String[wcLength];
                             for (int j = 0; j < wcLength; j++) {
-                                int len = buf.getInt();
-                                byte[] bytes = new byte[len];
-                                buf.get(bytes);
-                                wordCounts[j] = new String(bytes);
+                                wordCounts[j] = NIOUtils.getString(buf);
                             }
 
                             miniBatch[miniBatchCount] = wordCounts;
@@ -524,6 +517,21 @@ public class PLSAUDTF extends UDTFWithOptions {
      */
 
     @VisibleForTesting
+    public void closeWithoutModelReset() throws HiveException {
+        // launch close(), but not forward & clear model
+        if (count == 0) {
+            this.model = null;
+            return;
+        }
+        if (miniBatchCount > 0) { // update for remaining samples
+            model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
+        }
+        if (iterations > 1) {
+            runIterativeTraining(iterations);
+        }
+    }
+
+    @VisibleForTesting
     double getProbability(String label, int k) {
         return model.getProbability(label, k);
     }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/78acc428/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java 
b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
index e09e57e..5099c66 100644
--- a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
@@ -53,14 +53,12 @@ public class LDAPredictUDAFTest {
         ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
 
         fieldNames.add("wcList");
-        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
-                PrimitiveObjectInspectorFactory.javaStringObjectInspector));
+        
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector));
 
         fieldNames.add("lambdaMap");
         fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(
-                PrimitiveObjectInspectorFactory.javaStringObjectInspector,
-                ObjectInspectorFactory.getStandardListObjectInspector(
-                        
PrimitiveObjectInspectorFactory.javaFloatObjectInspector)));
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+            
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaFloatObjectInspector)));
 
         fieldNames.add("topics");
         
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
@@ -74,11 +72,14 @@ public class LDAPredictUDAFTest {
         partialOI = new ObjectInspector[4];
         partialOI[0] = 
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
 
-        words = new String[] {"fruits", "vegetables", "healthy", "flu", 
"apples", "oranges", "like", "avocados", "colds",
-            "colds", "avocados", "oranges", "like", "apples", "flu", 
"healthy", "vegetables", "fruits"};
+        words = new String[] {"fruits", "vegetables", "healthy", "flu", 
"apples", "oranges",
+                "like", "avocados", "colds", "colds", "avocados", "oranges", 
"like", "apples",
+                "flu", "healthy", "vegetables", "fruits"};
         labels = new int[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 
1};
-        lambdas = new float[] {0.3339331f, 0.3324783f, 0.33209667f, 
3.2804057E-4f, 3.0303953E-4f, 2.4860457E-4f, 2.41481E-4f, 2.3554532E-4f, 
1.352576E-4f,
-            0.1660153f, 0.16596903f, 0.1659654f, 0.1659627f, 0.16593699f, 
0.1659259f, 0.0017611005f, 0.0015791848f, 8.84464E-4f};
+        lambdas = new float[] {0.3339331f, 0.3324783f, 0.33209667f, 
3.2804057E-4f, 3.0303953E-4f,
+                2.4860457E-4f, 2.41481E-4f, 2.3554532E-4f, 1.352576E-4f, 
0.1660153f, 0.16596903f,
+                0.1659654f, 0.1659627f, 0.16593699f, 0.1659259f, 
0.0017611005f, 0.0015791848f,
+                8.84464E-4f};
     }
 
     @Test
@@ -86,16 +87,12 @@ public class LDAPredictUDAFTest {
         udaf = new LDAPredictUDAF();
 
         inputOIs = new ObjectInspector[] {
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.STRING),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.INT),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
                 ObjectInspectorUtils.getConstantObjectInspector(
-                        
PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, 
"-topics 2")};
 
         evaluator = udaf.getEvaluator(new 
SimpleGenericUDAFParameterInfo(inputOIs, false, false));
 
@@ -140,16 +137,12 @@ public class LDAPredictUDAFTest {
         udaf = new LDAPredictUDAF();
 
         inputOIs = new ObjectInspector[] {
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.STRING),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.INT),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
                 ObjectInspectorUtils.getConstantObjectInspector(
-                        
PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, 
"-topics 2")};
 
         evaluator = udaf.getEvaluator(new 
SimpleGenericUDAFParameterInfo(inputOIs, false, false));
 
@@ -169,7 +162,8 @@ public class LDAPredictUDAFTest {
         evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
         evaluator.reset(agg);
         for (int i = 0; i < 6; i++) {
-            evaluator.iterate(agg, new Object[]{words[i], doc.get(words[i]), 
labels[i], lambdas[i]});
+            evaluator.iterate(agg,
+                new Object[] {words[i], doc.get(words[i]), labels[i], 
lambdas[i]});
         }
         partials[0] = evaluator.terminatePartial(agg);
 
@@ -177,7 +171,8 @@ public class LDAPredictUDAFTest {
         evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
         evaluator.reset(agg);
         for (int i = 6; i < 12; i++) {
-            evaluator.iterate(agg, new Object[]{words[i], doc.get(words[i]), 
labels[i], lambdas[i]});
+            evaluator.iterate(agg,
+                new Object[] {words[i], doc.get(words[i]), labels[i], 
lambdas[i]});
         }
         partials[1] = evaluator.terminatePartial(agg);
 
@@ -185,13 +180,14 @@ public class LDAPredictUDAFTest {
         evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
         evaluator.reset(agg);
         for (int i = 12; i < 18; i++) {
-            evaluator.iterate(agg, new Object[]{words[i], doc.get(words[i]), 
labels[i], lambdas[i]});
+            evaluator.iterate(agg,
+                new Object[] {words[i], doc.get(words[i]), labels[i], 
lambdas[i]});
         }
 
         partials[2] = evaluator.terminatePartial(agg);
 
         // merge in a different order
-        final int[][] orders = new int[][] {{0, 1, 2}, {1, 0, 2}, {1, 2, 0}, 
{2, 1, 0}};
+        final int[][] orders = new int[][] { {0, 1, 2}, {1, 0, 2}, {1, 2, 0}, 
{2, 1, 0}};
         for (int i = 0; i < orders.length; i++) {
             evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, partialOI);
             evaluator.reset(agg);
@@ -210,14 +206,10 @@ public class LDAPredictUDAFTest {
         udaf = new LDAPredictUDAF();
 
         inputOIs = new ObjectInspector[] {
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.STRING),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.INT),
-                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT)};
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT)};
 
         evaluator = udaf.getEvaluator(new 
SimpleGenericUDAFParameterInfo(inputOIs, false, false));
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/78acc428/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java 
b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
index a5881d4..a934ba3 100644
--- a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
@@ -40,18 +40,20 @@ public class LDAUDTFTest {
         LDAUDTF udtf = new LDAUDTF();
 
         ObjectInspector[] argOIs = new ObjectInspector[] {
-            
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
-            ObjectInspectorUtils.getConstantObjectInspector(
-                PrimitiveObjectInspectorFactory.javaStringObjectInspector, 
"-topics 2 -num_docs 2 -s 1")};
+                
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+                    "-topics 2 -num_docs 2 -s 1 -iter 32 -eps 1e-3")};
 
         udtf.initialize(argOIs);
 
-        String[] doc1 = new String[]{"fruits:1", "healthy:1", "vegetables:1"};
-        String[] doc2 = new String[]{"apples:1", "avocados:1", "colds:1", 
"flu:1", "like:2", "oranges:1"};
-        for (int it = 0; it < 5; it++) {
-            udtf.process(new Object[]{ Arrays.asList(doc1) });
-            udtf.process(new Object[]{ Arrays.asList(doc2) });
-        }
+        String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"};
+        String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", 
"flu:1", "like:2",
+                "oranges:1"};
+        udtf.process(new Object[] {Arrays.asList(doc1)});
+        udtf.process(new Object[] {Arrays.asList(doc2)});
+
+        udtf.closeWithoutModelReset();
 
         SortedMap<Float, List<String>> topicWords;
 
@@ -89,13 +91,76 @@ public class LDAUDTFTest {
         }
 
         Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 
100) + "%), "
-            + "and `vegetables` SHOULD be more suitable topic word than `flu` 
in the topic",
+                + "and `vegetables` SHOULD be more suitable topic word than 
`flu` in the topic",
             udtf.getLambda("vegetables", k1) > udtf.getLambda("flu", k1));
         Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 
100) + "%), "
-            + "and `avocados` SHOULD be more suitable topic word than 
`healthy` in the topic",
+                + "and `avocados` SHOULD be more suitable topic word than 
`healthy` in the topic",
             udtf.getLambda("avocados", k2) > udtf.getLambda("healthy", k2));
     }
 
+    @Test
+    public void testMultiBytes() throws HiveException {
+        LDAUDTF udtf = new LDAUDTF();
+
+        ObjectInspector[] argOIs = new ObjectInspector[] {
+                
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+                    "-topics 2 -num_docs 2 -s 1 -iter 32 -eps 1e-3")};
+
+        udtf.initialize(argOIs);
+
+        String[] doc1 = new String[] {"果物:1", "健康:1", "野菜:1"};
+        String[] doc2 = new String[] {"りんご:1", "アボカド:1", 
"風邪:1", "インフルエンザ:1", "好き:2", "みかん:1"};
+
+        udtf.process(new Object[] {Arrays.asList(doc1)});
+        udtf.process(new Object[] {Arrays.asList(doc2)});
+
+        udtf.closeWithoutModelReset();
+
+        SortedMap<Float, List<String>> topicWords;
+
+        println("Topic 0:");
+        println("========");
+        topicWords = udtf.getTopicWords(0);
+        for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
+            List<String> words = e.getValue();
+            for (int i = 0; i < words.size(); i++) {
+                println(e.getKey() + " " + words.get(i));
+            }
+        }
+        println("========");
+
+        println("Topic 1:");
+        println("========");
+        topicWords = udtf.getTopicWords(1);
+        for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
+            List<String> words = e.getValue();
+            for (int i = 0; i < words.size(); i++) {
+                println(e.getKey() + " " + words.get(i));
+            }
+        }
+        println("========");
+
+        int k1, k2;
+        float[] topicDistr = udtf.getTopicDistribution(doc1);
+        if (topicDistr[0] > topicDistr[1]) {
+            // topic 0 MUST represent doc#1
+            k1 = 0;
+            k2 = 1;
+        } else {
+            k1 = 1;
+            k2 = 0;
+        }
+
+        Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 
100) + "%), "
+                + "and `野菜` SHOULD be more suitable topic word than 
`インフルエンザ` in the topic",
+            udtf.getLambda("野菜", k1) > 
udtf.getLambda("インフルエンザ", k1));
+        Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 
100) + "%), "
+                + "and `アボカド` SHOULD be more suitable topic word than 
`健康` in the topic",
+            udtf.getLambda("アボカド", k2) > udtf.getLambda("健康", k2));
+    }
+
     private static void println(String msg) {
         if (DEBUG) {
             System.out.println(msg);

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/78acc428/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java 
b/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java
index b4810a6..5b0a8c2 100644
--- a/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java
+++ b/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java
@@ -52,7 +52,8 @@ public class OnlineLDAModelTest {
         OnlineLDAModel model = new OnlineLDAModel(K, 1.f / K, 1.f / K, 2, 80, 0.8, 1E-5d);
 
         String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"};
-        String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", "oranges:1"};
+        String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2",
+                "oranges:1"};
 
         do {
             perplexityPrev = perplexity;
@@ -69,7 +70,7 @@ public class OnlineLDAModelTest {
 
             it++;
             println("Iteration " + it + ": mean perplexity = " + perplexity);
-        } while(Math.abs(perplexityPrev - perplexity) >= 1E-6f);
+        } while (Math.abs(perplexityPrev - perplexity) >= 1E-6f);
 
         SortedMap<Float, List<String>> topicWords;
 
@@ -106,10 +107,10 @@ public class OnlineLDAModelTest {
             k2 = 0;
         }
         Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
-            + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic",
+                + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic",
             model.getLambda("vegetables", k1) > model.getLambda("flu", k1));
         Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
-            + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic",
+                + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic",
             model.getLambda("avocados", k2) > model.getLambda("healthy", k2));
     }
 
@@ -123,7 +124,8 @@ public class OnlineLDAModelTest {
         OnlineLDAModel model = new OnlineLDAModel(K, 1.f / K, 1.f / K, 2, 80, 0.8, 1E-5d);
 
         String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"};
-        String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", "oranges:1"};
+        String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2",
+                "oranges:1"};
 
         do {
             perplexityPrev = perplexity;
@@ -132,7 +134,7 @@ public class OnlineLDAModelTest {
             perplexity = model.computePerplexity();
 
             it++;
-        } while(Math.abs(perplexityPrev - perplexity) >= 1E-6f);
+        } while (Math.abs(perplexityPrev - perplexity) >= 1E-6f);
 
         println("Iterated " + it + " times, perplexity = " + perplexity);
 
@@ -141,7 +143,8 @@ public class OnlineLDAModelTest {
         // returns perplexity=15 in a batch setting and perplexity=22 in an online setting.
         // Hivemall needs to converge to the similar perplexity.
         Assert.assertTrue("Perplexity SHOULD be in [12, 25]; "
-            + "converged perplexity is too small or large for some reasons",12.f <= perplexity && perplexity <= 25.f);
+                + "converged perplexity is too small or large for some reasons", 12.f <= perplexity
+                && perplexity <= 25.f);
     }
 
     @Test
@@ -210,7 +213,7 @@ public class OnlineLDAModelTest {
             it++;
 
             println("Iteration " + it + ": mean perplexity = " + perplexity);
-        } while(Math.abs(perplexityPrev - perplexity) >= 1E-1f);
+        } while (Math.abs(perplexityPrev - perplexity) >= 1E-1f);
 
         Set<Integer> topics = new HashSet<Integer>();
         for (int k = 0; k < K; k++) {

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/78acc428/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java b/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
index 2be48e1..78f4a62 100644
--- a/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
@@ -92,7 +92,7 @@ public class PLSAPredictUDAFTest {
                 PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
                 PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
                 ObjectInspectorUtils.getConstantObjectInspector(
-                        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
 
         evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
 
@@ -142,7 +142,7 @@ public class PLSAPredictUDAFTest {
                 PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
                 PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
                 ObjectInspectorUtils.getConstantObjectInspector(
-                        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
 
         evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/78acc428/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java b/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java
index 76795bc..addacbc 100644
--- a/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java
@@ -43,17 +43,18 @@ public class PLSAUDTFTest {
                 ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
                 ObjectInspectorUtils.getConstantObjectInspector(
                     PrimitiveObjectInspectorFactory.javaStringObjectInspector,
-                    "-topics 2 -alpha 0.1 -delta 0.00001")};
+                    "-topics 2 -alpha 0.1 -delta 0.00001 -iter 10000")};
 
         udtf.initialize(argOIs);
 
         String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"};
         String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2",
                 "oranges:1"};
-        for (int it = 0; it < 10000; it++) {
-            udtf.process(new Object[] {Arrays.asList(doc1)});
-            udtf.process(new Object[] {Arrays.asList(doc2)});
-        }
+
+        udtf.process(new Object[] {Arrays.asList(doc1)});
+        udtf.process(new Object[] {Arrays.asList(doc2)});
+
+        udtf.closeWithoutModelReset();
 
         SortedMap<Float, List<String>> topicWords;
 
@@ -98,6 +99,69 @@ public class PLSAUDTFTest {
             udtf.getProbability("avocados", k2) > udtf.getProbability("healthy", k2));
     }
 
+    @Test
+    public void testMultiBytes() throws HiveException {
+        PLSAUDTF udtf = new PLSAUDTF();
+
+        ObjectInspector[] argOIs = new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+                    "-topics 2 -alpha 0.1 -delta 0.00001 -iter 10000")};
+
+        udtf.initialize(argOIs);
+
+        String[] doc1 = new String[] {"果物:1", "健康:1", "野菜:1"};
+        String[] doc2 = new String[] {"りんご:1", "アボカド:1", "風邪:1", "インフルエンザ:1", "好き:2", "みかん:1"};
+
+        udtf.process(new Object[] {Arrays.asList(doc1)});
+        udtf.process(new Object[] {Arrays.asList(doc2)});
+
+        udtf.closeWithoutModelReset();
+
+        SortedMap<Float, List<String>> topicWords;
+
+        println("Topic 0:");
+        println("========");
+        topicWords = udtf.getTopicWords(0);
+        for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
+            List<String> words = e.getValue();
+            for (int i = 0; i < words.size(); i++) {
+                println(e.getKey() + " " + words.get(i));
+            }
+        }
+        println("========");
+
+        println("Topic 1:");
+        println("========");
+        topicWords = udtf.getTopicWords(1);
+        for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
+            List<String> words = e.getValue();
+            for (int i = 0; i < words.size(); i++) {
+                println(e.getKey() + " " + words.get(i));
+            }
+        }
+        println("========");
+
+        int k1, k2;
+        float[] topicDistr = udtf.getTopicDistribution(doc1);
+        if (topicDistr[0] > topicDistr[1]) {
+            // topic 0 MUST represent doc#1
+            k1 = 0;
+            k2 = 1;
+        } else {
+            k1 = 1;
+            k2 = 0;
+        }
+
+        Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
+                + "and `野菜` SHOULD be more suitable topic word than `インフルエンザ` in the topic",
+            udtf.getProbability("野菜", k1) > udtf.getProbability("インフルエンザ", k1));
+        Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
+                + "and `アボカド` SHOULD be more suitable topic word than `健康` in the topic",
+            udtf.getProbability("アボカド", k2) > udtf.getProbability("健康", k2));
+    }
+
     private static void println(String msg) {
         if (DEBUG) {
             System.out.println(msg);


Reply via email to