Repository: incubator-hivemall Updated Branches: refs/heads/master 912010305 -> 8ac3165db
[HIVEMALL-219] Fixed LDA bug for single update ## What changes were proposed in this pull request? Fixed LDA bug for single update and added unit tests ## What type of PR is it? Bug Fix ## What is the Jira issue? https://issues.apache.org/jira/browse/HIVEMALL-219 ## How was this patch tested? unit tests and manual tests on EMR ## Checklist (Please remove this section if not needed; check `x` for YES, blank for NO) - [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit? - [x] Did you run system tests on Hive (or Spark)? Author: Makoto Yui <[email protected]> Closes #166 from myui/HIVEMALL-219-2. Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/8ac3165d Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/8ac3165d Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/8ac3165d Branch: refs/heads/master Commit: 8ac3165db33a63cf3da9c2598b9051c26736cdd1 Parents: 9120103 Author: Makoto Yui <[email protected]> Authored: Tue Sep 18 19:46:18 2018 +0900 Committer: Makoto Yui <[email protected]> Committed: Tue Sep 18 19:46:18 2018 +0900 ---------------------------------------------------------------------- .../main/java/hivemall/topicmodel/LDAUDTF.java | 4 +-- .../hivemall/topicmodel/OnlineLDAModel.java | 2 +- .../java/hivemall/topicmodel/LDAUDTFTest.java | 35 ++++++++++++++++++-- .../java/hivemall/topicmodel/PLSAUDTFTest.java | 35 ++++++++++++++++++-- 4 files changed, 67 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8ac3165d/core/src/main/java/hivemall/topicmodel/LDAUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java index 4cbd964..59dc2e3 100644 --- a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java +++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java @@ -45,7 +45,7 @@ public final class LDAUDTF extends ProbabilisticTopicModelBaseUDTF { this.alpha = 1.f / topics; this.eta = 1.f / topics; - this.numDocs = -1L; + this.numDocs = 0L; this.tau0 = 64.d; this.kappa = 0.7; this.delta = DEFAULT_DELTA; @@ -72,7 +72,7 @@ public final class LDAUDTF extends ProbabilisticTopicModelBaseUDTF { if (cl != null) { this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topics); this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f / topics); - this.numDocs = Primitives.parseLong(cl.getOptionValue("num_docs"), -1L); + this.numDocs = Primitives.parseLong(cl.getOptionValue("num_docs"), 0L); this.tau0 = Primitives.parseDouble(cl.getOptionValue("tau0"), 64.d); if (tau0 <= 0.d) { throw new UDFArgumentException("'-tau0' must be positive: " + tau0); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8ac3165d/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java index aed17ba..1ddf52b 100644 --- a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java +++ b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java @@ -110,7 +110,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel { this._kappa = kappa; this._delta = delta; - this._isAutoD = (_D < 0L); + this._isAutoD = (_D <= 0L); // initialize a random number generator this._gd = new GammaDistribution(SHAPE, SCALE); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8ac3165d/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java index b4636f7..ed29487 100644 --- a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java +++ b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java @@ -18,18 +18,20 @@ */ package hivemall.topicmodel; +import hivemall.TestUtils; +import hivemall.utils.lang.mutable.MutableInt; + +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.SortedMap; -import java.util.Arrays; -import hivemall.TestUtils; import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.generic.Collector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; - import org.junit.Assert; import org.junit.Test; @@ -181,6 +183,33 @@ public class LDAUDTFTest { "oranges:1")}}); } + @Test + public void testSingleRow() throws HiveException { + LDAUDTF udtf = new LDAUDTF(); + final int numTopics = 2; + ObjectInspector[] argOIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector), + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-topics " + numTopics)}; + udtf.initialize(argOIs); + + String[] doc1 = new String[] {"1", "2", "3"}; + udtf.process(new Object[] {Arrays.asList(doc1)}); + + final MutableInt cnt = new MutableInt(0); + udtf.setCollector(new Collector() { + @Override + public void collect(Object arg0) throws HiveException { + cnt.addValue(1); + } + }); + udtf.close(); + + Assert.assertEquals(doc1.length * numTopics, cnt.getValue()); + } + private static void println(String msg) { if (DEBUG) { System.out.println(msg); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8ac3165d/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java b/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java index 7f344d1..a069bcb 100644 --- a/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java +++ b/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java @@ -18,18 +18,20 @@ */ package hivemall.topicmodel; +import hivemall.TestUtils; +import hivemall.utils.lang.mutable.MutableInt; + +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.SortedMap; -import java.util.Arrays; -import hivemall.TestUtils; import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.generic.Collector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; - import org.junit.Assert; import org.junit.Test; @@ -182,6 +184,33 @@ public class PLSAUDTFTest { "oranges:1")}}); } + @Test + public void testSingleRow() throws HiveException { + PLSAUDTF udtf = new PLSAUDTF(); + final int numTopics = 2; + ObjectInspector[] argOIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector), + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-topics " + numTopics)}; + udtf.initialize(argOIs); + + String[] doc1 = new String[] {"1", "2", "3"}; + udtf.process(new Object[] {Arrays.asList(doc1)}); + + final MutableInt cnt = new MutableInt(0); + udtf.setCollector(new Collector() { + @Override + public void collect(Object arg0) throws HiveException { + cnt.addValue(1); + } + }); + udtf.close(); + + Assert.assertEquals(doc1.length * numTopics, cnt.getValue()); + } + private static void println(String msg) { if (DEBUG) { System.out.println(msg);
