Repository: incubator-hivemall
Updated Branches:
  refs/heads/master 912010305 -> 8ac3165db


[HIVEMALL-219] Fixed LDA bug for single update

## What changes were proposed in this pull request?

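Fixed an LDA bug that occurred when the model was updated with a single document (one input row), and added unit tests for the single-row case to both `LDAUDTFTest` and `PLSAUDTFTest`. The default `num_docs` value is changed from -1 to 0, and `OnlineLDAModel` now treats any non-positive `num_docs` (`_D <= 0`) as automatic.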

## What type of PR is it?

Bug Fix

## What is the Jira issue?

https://issues.apache.org/jira/browse/HIVEMALL-219

## How was this patch tested?

Unit tests, plus manual tests on Amazon EMR.

## Checklist

(Please remove this section if not needed; check `x` for YES, blank for NO)

- [x] Did you apply the source code formatter, i.e., `./bin/format_code.sh`, to your commit?
- [x] Did you run system tests on Hive (or Spark)?

Author: Makoto Yui <[email protected]>

Closes #166 from myui/HIVEMALL-219-2.


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/8ac3165d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/8ac3165d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/8ac3165d

Branch: refs/heads/master
Commit: 8ac3165db33a63cf3da9c2598b9051c26736cdd1
Parents: 9120103
Author: Makoto Yui <[email protected]>
Authored: Tue Sep 18 19:46:18 2018 +0900
Committer: Makoto Yui <[email protected]>
Committed: Tue Sep 18 19:46:18 2018 +0900

----------------------------------------------------------------------
 .../main/java/hivemall/topicmodel/LDAUDTF.java  |  4 +--
 .../hivemall/topicmodel/OnlineLDAModel.java     |  2 +-
 .../java/hivemall/topicmodel/LDAUDTFTest.java   | 35 ++++++++++++++++++--
 .../java/hivemall/topicmodel/PLSAUDTFTest.java  | 35 ++++++++++++++++++--
 4 files changed, 67 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8ac3165d/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java 
b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
index 4cbd964..59dc2e3 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
@@ -45,7 +45,7 @@ public final class LDAUDTF extends 
ProbabilisticTopicModelBaseUDTF {
 
         this.alpha = 1.f / topics;
         this.eta = 1.f / topics;
-        this.numDocs = -1L;
+        this.numDocs = 0L;
         this.tau0 = 64.d;
         this.kappa = 0.7;
         this.delta = DEFAULT_DELTA;
@@ -72,7 +72,7 @@ public final class LDAUDTF extends 
ProbabilisticTopicModelBaseUDTF {
         if (cl != null) {
             this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f 
/ topics);
             this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f / 
topics);
-            this.numDocs = Primitives.parseLong(cl.getOptionValue("num_docs"), 
-1L);
+            this.numDocs = Primitives.parseLong(cl.getOptionValue("num_docs"), 
0L);
             this.tau0 = Primitives.parseDouble(cl.getOptionValue("tau0"), 
64.d);
             if (tau0 <= 0.d) {
                 throw new UDFArgumentException("'-tau0' must be positive: " + 
tau0);

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8ac3165d/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java 
b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
index aed17ba..1ddf52b 100644
--- a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
+++ b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
@@ -110,7 +110,7 @@ public final class OnlineLDAModel extends 
AbstractProbabilisticTopicModel {
         this._kappa = kappa;
         this._delta = delta;
 
-        this._isAutoD = (_D < 0L);
+        this._isAutoD = (_D <= 0L);
 
         // initialize a random number generator
         this._gd = new GammaDistribution(SHAPE, SCALE);

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8ac3165d/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java 
b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
index b4636f7..ed29487 100644
--- a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
@@ -18,18 +18,20 @@
  */
 package hivemall.topicmodel;
 
+import hivemall.TestUtils;
+import hivemall.utils.lang.mutable.MutableInt;
+
+import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import java.util.SortedMap;
-import java.util.Arrays;
 
-import hivemall.TestUtils;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.Collector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -181,6 +183,33 @@ public class LDAUDTFTest {
                         "oranges:1")}});
     }
 
+    @Test
+    public void testSingleRow() throws HiveException {
+        LDAUDTF udtf = new LDAUDTF();
+        final int numTopics = 2;
+        ObjectInspector[] argOIs = new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+                    "-topics " + numTopics)};
+        udtf.initialize(argOIs);
+
+        String[] doc1 = new String[] {"1", "2", "3"};
+        udtf.process(new Object[] {Arrays.asList(doc1)});
+
+        final MutableInt cnt = new MutableInt(0);
+        udtf.setCollector(new Collector() {
+            @Override
+            public void collect(Object arg0) throws HiveException {
+                cnt.addValue(1);
+            }
+        });
+        udtf.close();
+
+        Assert.assertEquals(doc1.length * numTopics, cnt.getValue());
+    }
+
     private static void println(String msg) {
         if (DEBUG) {
             System.out.println(msg);

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8ac3165d/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java 
b/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java
index 7f344d1..a069bcb 100644
--- a/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java
@@ -18,18 +18,20 @@
  */
 package hivemall.topicmodel;
 
+import hivemall.TestUtils;
+import hivemall.utils.lang.mutable.MutableInt;
+
+import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import java.util.SortedMap;
-import java.util.Arrays;
 
-import hivemall.TestUtils;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.Collector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -182,6 +184,33 @@ public class PLSAUDTFTest {
                         "oranges:1")}});
     }
 
+    @Test
+    public void testSingleRow() throws HiveException {
+        PLSAUDTF udtf = new PLSAUDTF();
+        final int numTopics = 2;
+        ObjectInspector[] argOIs = new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+                    "-topics " + numTopics)};
+        udtf.initialize(argOIs);
+
+        String[] doc1 = new String[] {"1", "2", "3"};
+        udtf.process(new Object[] {Arrays.asList(doc1)});
+
+        final MutableInt cnt = new MutableInt(0);
+        udtf.setCollector(new Collector() {
+            @Override
+            public void collect(Object arg0) throws HiveException {
+                cnt.addValue(1);
+            }
+        });
+        udtf.close();
+
+        Assert.assertEquals(doc1.length * numTopics, cnt.getValue());
+    }
+
     private static void println(String msg) {
         if (DEBUG) {
             System.out.println(msg);

Reply via email to