Repository: incubator-hivemall Updated Branches: refs/heads/master 10e7d450f -> 609e48016
Close #83 #82: [HIVEMALL-109][HIVEMALL-112] Fix topic model and tokenize UDFs Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/92e85ad1 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/92e85ad1 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/92e85ad1 Branch: refs/heads/master Commit: 92e85ad1411d36ff1fdc09188fb40820adc6106b Parents: 10e7d45 Author: Takuya Kitazawa <[email protected]> Authored: Tue Jun 6 14:19:07 2017 +0900 Committer: myui <[email protected]> Committed: Tue Jun 6 14:19:07 2017 +0900 ---------------------------------------------------------------------- .../java/hivemall/tools/text/TokenizeUDF.java | 3 + .../topicmodel/IncrementalPLSAModel.java | 5 +- .../hivemall/topicmodel/LDAPredictUDAF.java | 8 +- .../main/java/hivemall/topicmodel/LDAUDTF.java | 49 ++++++----- .../hivemall/topicmodel/PLSAPredictUDAF.java | 11 ++- .../main/java/hivemall/topicmodel/PLSAUDTF.java | 48 ++++++----- .../hivemall/topicmodel/LDAPredictUDAFTest.java | 70 +++++++--------- .../java/hivemall/topicmodel/LDAUDTFTest.java | 87 +++++++++++++++++--- .../hivemall/topicmodel/OnlineLDAModelTest.java | 19 +++-- .../topicmodel/PLSAPredictUDAFTest.java | 4 +- .../java/hivemall/topicmodel/PLSAUDTFTest.java | 74 +++++++++++++++-- 11 files changed, 263 insertions(+), 115 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/92e85ad1/core/src/main/java/hivemall/tools/text/TokenizeUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/text/TokenizeUDF.java b/core/src/main/java/hivemall/tools/text/TokenizeUDF.java index a5c8777..d306fa2 100644 --- a/core/src/main/java/hivemall/tools/text/TokenizeUDF.java +++ b/core/src/main/java/hivemall/tools/text/TokenizeUDF.java @@ -38,6 +38,9 @@ public final 
class TokenizeUDF extends UDF { } public List<Text> evaluate(Text input, boolean toLowerCase) { + if (input == null) { + return null; + } final List<Text> tokens = new ArrayList<Text>(); final StringTokenizer tokenizer = new StringTokenizer(input.toString(), DELIM); while (tokenizer.hasMoreElements()) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/92e85ad1/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java index a75febb..6eef23e 100644 --- a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java +++ b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java @@ -264,8 +264,9 @@ public final class IncrementalPLSAModel { } if (p_dw == 0.d) { - throw new IllegalStateException("Perplexity would be Infinity. " - + "Try different mini-batch size `-s`, larger `-delta` and/or larger `-alpha`."); + throw new IllegalStateException( + "Perplexity would be Infinity. 
" + + "Try different mini-batch size `-s`, larger `-delta` and/or larger `-alpha`."); } numer += w_value * Math.log(p_dw); denom += w_value; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/92e85ad1/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java index 8d1edd8..03779b0 100644 --- a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java +++ b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java @@ -179,14 +179,16 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver { String rawArgs = HiveUtils.getConstString(argOIs[4]); cl = parseOptions(rawArgs); - this.topics = Primitives.parseInt(cl.getOptionValue("topics"), LDAUDTF.DEFAULT_TOPICS); + this.topics = Primitives.parseInt(cl.getOptionValue("topics"), + LDAUDTF.DEFAULT_TOPICS); if (topics < 1) { throw new UDFArgumentException( - "A positive integer MUST be set to an option `-topics`: " + topics); + "A positive integer MUST be set to an option `-topics`: " + topics); } this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topics); - this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), LDAUDTF.DEFAULT_DELTA); + this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), + LDAUDTF.DEFAULT_DELTA); } else { this.topics = LDAUDTF.DEFAULT_TOPICS; this.alpha = 1.f / topics; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/92e85ad1/core/src/main/java/hivemall/topicmodel/LDAUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java index 1cec875..daec7ea 100644 --- a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java +++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java @@ -22,6 +22,7 @@ 
import hivemall.UDTFWithOptions; import hivemall.annotations.VisibleForTesting; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.io.FileUtils; +import hivemall.utils.io.NIOUtils; import hivemall.utils.io.NioStatefullSegment; import hivemall.utils.lang.NumberUtils; import hivemall.utils.lang.Primitives; @@ -260,28 +261,26 @@ public class LDAUDTF extends UDTFWithOptions { this.fileIO = dst = new NioStatefullSegment(file, false); } - int wcLength = 0; + // requiredRecordBytes, wordCounts length, wc1 length, wc1 string, wc2 length, wc2 string, ... + int wcLengthTotal = 0; for (String wc : wordCounts) { if (wc == null) { continue; } - wcLength += wc.getBytes().length; + wcLengthTotal += wc.length(); } - // recordBytes, wordCounts length, wc1 length, wc1 string, wc2 length, wc2 string, ... - int recordBytes = (Integer.SIZE * 2 + Integer.SIZE * wcLength) / 8 + wcLength; + int requiredRecordBytes = SizeOf.INT * 2 + SizeOf.INT * wordCounts.length + wcLengthTotal + * SizeOf.CHAR; + int remain = buf.remaining(); - if (remain < recordBytes) { + if (remain < requiredRecordBytes) { writeBuffer(buf, dst); } - buf.putInt(recordBytes); + buf.putInt(requiredRecordBytes); buf.putInt(wordCounts.length); for (String wc : wordCounts) { - if (wc == null) { - continue; - } - buf.putInt(wc.length()); - buf.put(wc.getBytes()); + NIOUtils.putString(wc, buf); } } @@ -351,10 +350,7 @@ public class LDAUDTF extends UDTFWithOptions { int wcLength = buf.getInt(); final String[] wordCounts = new String[wcLength]; for (int j = 0; j < wcLength; j++) { - int len = buf.getInt(); - byte[] bytes = new byte[len]; - buf.get(bytes); - wordCounts[j] = new String(bytes); + wordCounts[j] = NIOUtils.getString(buf); } miniBatch[miniBatchCount] = wordCounts; @@ -393,7 +389,6 @@ public class LDAUDTF extends UDTFWithOptions { + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations)) + " training updates in total) "); } else {// read training examples in the temporary file and invoke 
train for each example - // write training examples in buffer to a temporary file if (buf.remaining() > 0) { writeBuffer(buf, dst); @@ -452,7 +447,7 @@ public class LDAUDTF extends UDTFWithOptions { } while (remain >= SizeOf.INT) { int pos = buf.position(); - int recordBytes = buf.getInt(); + int recordBytes = buf.getInt() - SizeOf.INT; remain -= SizeOf.INT; if (remain < recordBytes) { buf.position(pos); @@ -462,10 +457,7 @@ public class LDAUDTF extends UDTFWithOptions { int wcLength = buf.getInt(); final String[] wordCounts = new String[wcLength]; for (int j = 0; j < wcLength; j++) { - int len = buf.getInt(); - byte[] bytes = new byte[len]; - buf.get(bytes); - wordCounts[j] = new String(bytes); + wordCounts[j] = NIOUtils.getString(buf); } miniBatch[miniBatchCount] = wordCounts; @@ -554,6 +546,21 @@ public class LDAUDTF extends UDTFWithOptions { */ @VisibleForTesting + public void closeWithoutModelReset() throws HiveException { + // launch close(), but not forward & clear model + if (count == 0) { + this.model = null; + return; + } + if (miniBatchCount > 0) { // update for remaining samples + model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount)); + } + if (iterations > 1) { + runIterativeTraining(iterations); + } + } + + @VisibleForTesting double getLambda(String label, int k) { return model.getLambda(label, k); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/92e85ad1/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java index 08febb4..ff29236 100644 --- a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java +++ b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java @@ -180,14 +180,17 @@ public final class PLSAPredictUDAF extends AbstractGenericUDAFResolver { String rawArgs = HiveUtils.getConstString(argOIs[4]); cl = 
parseOptions(rawArgs); - this.topics = Primitives.parseInt(cl.getOptionValue("topics"), PLSAUDTF.DEFAULT_TOPICS); + this.topics = Primitives.parseInt(cl.getOptionValue("topics"), + PLSAUDTF.DEFAULT_TOPICS); if (topics < 1) { throw new UDFArgumentException( - "A positive integer MUST be set to an option `-topics`: " + topics); + "A positive integer MUST be set to an option `-topics`: " + topics); } - this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), PLSAUDTF.DEFAULT_ALPHA); - this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), PLSAUDTF.DEFAULT_DELTA); + this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), + PLSAUDTF.DEFAULT_ALPHA); + this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), + PLSAUDTF.DEFAULT_DELTA); } else { this.topics = PLSAUDTF.DEFAULT_TOPICS; this.alpha = PLSAUDTF.DEFAULT_ALPHA; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/92e85ad1/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java index 014356e..46f731f 100644 --- a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java +++ b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java @@ -22,6 +22,7 @@ import hivemall.UDTFWithOptions; import hivemall.annotations.VisibleForTesting; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.io.FileUtils; +import hivemall.utils.io.NIOUtils; import hivemall.utils.io.NioStatefullSegment; import hivemall.utils.lang.NumberUtils; import hivemall.utils.lang.Primitives; @@ -230,28 +231,26 @@ public class PLSAUDTF extends UDTFWithOptions { this.fileIO = dst = new NioStatefullSegment(file, false); } - int wcLength = 0; + // requiredRecordBytes, wordCounts length, wc1 length, wc1 string, wc2 length, wc2 string, ... 
+ int wcLengthTotal = 0; for (String wc : wordCounts) { if (wc == null) { continue; } - wcLength += wc.getBytes().length; + wcLengthTotal += wc.length(); } - // recordBytes, wordCounts length, wc1 length, wc1 string, wc2 length, wc2 string, ... - int recordBytes = (Integer.SIZE * 2 + Integer.SIZE * wcLength) / 8 + wcLength; + int requiredRecordBytes = SizeOf.INT * 2 + SizeOf.INT * wordCounts.length + wcLengthTotal + * SizeOf.CHAR; + int remain = buf.remaining(); - if (remain < recordBytes) { + if (remain < requiredRecordBytes) { writeBuffer(buf, dst); } - buf.putInt(recordBytes); + buf.putInt(requiredRecordBytes); buf.putInt(wordCounts.length); for (String wc : wordCounts) { - if (wc == null) { - continue; - } - buf.putInt(wc.length()); - buf.put(wc.getBytes()); + NIOUtils.putString(wc, buf); } } @@ -321,10 +320,7 @@ public class PLSAUDTF extends UDTFWithOptions { int wcLength = buf.getInt(); final String[] wordCounts = new String[wcLength]; for (int j = 0; j < wcLength; j++) { - int len = buf.getInt(); - byte[] bytes = new byte[len]; - buf.get(bytes); - wordCounts[j] = new String(bytes); + wordCounts[j] = NIOUtils.getString(buf); } miniBatch[miniBatchCount] = wordCounts; @@ -422,7 +418,7 @@ public class PLSAUDTF extends UDTFWithOptions { } while (remain >= SizeOf.INT) { int pos = buf.position(); - int recordBytes = buf.getInt(); + int recordBytes = buf.getInt() - SizeOf.INT; remain -= SizeOf.INT; if (remain < recordBytes) { buf.position(pos); @@ -432,10 +428,7 @@ public class PLSAUDTF extends UDTFWithOptions { int wcLength = buf.getInt(); final String[] wordCounts = new String[wcLength]; for (int j = 0; j < wcLength; j++) { - int len = buf.getInt(); - byte[] bytes = new byte[len]; - buf.get(bytes); - wordCounts[j] = new String(bytes); + wordCounts[j] = NIOUtils.getString(buf); } miniBatch[miniBatchCount] = wordCounts; @@ -524,6 +517,21 @@ public class PLSAUDTF extends UDTFWithOptions { */ @VisibleForTesting + public void closeWithoutModelReset() throws 
HiveException { + // launch close(), but not forward & clear model + if (count == 0) { + this.model = null; + return; + } + if (miniBatchCount > 0) { // update for remaining samples + model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount)); + } + if (iterations > 1) { + runIterativeTraining(iterations); + } + } + + @VisibleForTesting double getProbability(String label, int k) { return model.getProbability(label, k); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/92e85ad1/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java index e09e57e..5099c66 100644 --- a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java +++ b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java @@ -53,14 +53,12 @@ public class LDAPredictUDAFTest { ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); fieldNames.add("wcList"); - fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector( - PrimitiveObjectInspectorFactory.javaStringObjectInspector)); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector)); fieldNames.add("lambdaMap"); fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector( - PrimitiveObjectInspectorFactory.javaStringObjectInspector, - ObjectInspectorFactory.getStandardListObjectInspector( - PrimitiveObjectInspectorFactory.javaFloatObjectInspector))); + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaFloatObjectInspector))); fieldNames.add("topics"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); @@ -74,11 +72,14 @@ public class LDAPredictUDAFTest { partialOI = new 
ObjectInspector[4]; partialOI[0] = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); - words = new String[] {"fruits", "vegetables", "healthy", "flu", "apples", "oranges", "like", "avocados", "colds", - "colds", "avocados", "oranges", "like", "apples", "flu", "healthy", "vegetables", "fruits"}; + words = new String[] {"fruits", "vegetables", "healthy", "flu", "apples", "oranges", + "like", "avocados", "colds", "colds", "avocados", "oranges", "like", "apples", + "flu", "healthy", "vegetables", "fruits"}; labels = new int[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1}; - lambdas = new float[] {0.3339331f, 0.3324783f, 0.33209667f, 3.2804057E-4f, 3.0303953E-4f, 2.4860457E-4f, 2.41481E-4f, 2.3554532E-4f, 1.352576E-4f, - 0.1660153f, 0.16596903f, 0.1659654f, 0.1659627f, 0.16593699f, 0.1659259f, 0.0017611005f, 0.0015791848f, 8.84464E-4f}; + lambdas = new float[] {0.3339331f, 0.3324783f, 0.33209667f, 3.2804057E-4f, 3.0303953E-4f, + 2.4860457E-4f, 2.41481E-4f, 2.3554532E-4f, 1.352576E-4f, 0.1660153f, 0.16596903f, + 0.1659654f, 0.1659627f, 0.16593699f, 0.1659259f, 0.0017611005f, 0.0015791848f, + 8.84464E-4f}; } @Test @@ -86,16 +87,12 @@ public class LDAPredictUDAFTest { udaf = new LDAPredictUDAF(); inputOIs = new ObjectInspector[] { - PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( - PrimitiveObjectInspector.PrimitiveCategory.STRING), - PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( - PrimitiveObjectInspector.PrimitiveCategory.FLOAT), - PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( - PrimitiveObjectInspector.PrimitiveCategory.INT), - PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( - PrimitiveObjectInspector.PrimitiveCategory.FLOAT), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT), + 
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT), ObjectInspectorUtils.getConstantObjectInspector( - PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")}; + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")}; evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false)); @@ -140,16 +137,12 @@ public class LDAPredictUDAFTest { udaf = new LDAPredictUDAF(); inputOIs = new ObjectInspector[] { - PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( - PrimitiveObjectInspector.PrimitiveCategory.STRING), - PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( - PrimitiveObjectInspector.PrimitiveCategory.FLOAT), - PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( - PrimitiveObjectInspector.PrimitiveCategory.INT), - PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( - PrimitiveObjectInspector.PrimitiveCategory.FLOAT), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT), ObjectInspectorUtils.getConstantObjectInspector( - PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")}; + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")}; evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false)); @@ -169,7 +162,8 @@ public class LDAPredictUDAFTest { evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); 
evaluator.reset(agg); for (int i = 0; i < 6; i++) { - evaluator.iterate(agg, new Object[]{words[i], doc.get(words[i]), labels[i], lambdas[i]}); + evaluator.iterate(agg, + new Object[] {words[i], doc.get(words[i]), labels[i], lambdas[i]}); } partials[0] = evaluator.terminatePartial(agg); @@ -177,7 +171,8 @@ public class LDAPredictUDAFTest { evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); evaluator.reset(agg); for (int i = 6; i < 12; i++) { - evaluator.iterate(agg, new Object[]{words[i], doc.get(words[i]), labels[i], lambdas[i]}); + evaluator.iterate(agg, + new Object[] {words[i], doc.get(words[i]), labels[i], lambdas[i]}); } partials[1] = evaluator.terminatePartial(agg); @@ -185,13 +180,14 @@ public class LDAPredictUDAFTest { evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); evaluator.reset(agg); for (int i = 12; i < 18; i++) { - evaluator.iterate(agg, new Object[]{words[i], doc.get(words[i]), labels[i], lambdas[i]}); + evaluator.iterate(agg, + new Object[] {words[i], doc.get(words[i]), labels[i], lambdas[i]}); } partials[2] = evaluator.terminatePartial(agg); // merge in a different order - final int[][] orders = new int[][] {{0, 1, 2}, {1, 0, 2}, {1, 2, 0}, {2, 1, 0}}; + final int[][] orders = new int[][] { {0, 1, 2}, {1, 0, 2}, {1, 2, 0}, {2, 1, 0}}; for (int i = 0; i < orders.length; i++) { evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, partialOI); evaluator.reset(agg); @@ -210,14 +206,10 @@ public class LDAPredictUDAFTest { udaf = new LDAPredictUDAF(); inputOIs = new ObjectInspector[] { - PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( - PrimitiveObjectInspector.PrimitiveCategory.STRING), - PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( - PrimitiveObjectInspector.PrimitiveCategory.FLOAT), - PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( - PrimitiveObjectInspector.PrimitiveCategory.INT), - PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( - 
PrimitiveObjectInspector.PrimitiveCategory.FLOAT)}; + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT)}; evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false)); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/92e85ad1/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java index a5881d4..a934ba3 100644 --- a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java +++ b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java @@ -40,18 +40,20 @@ public class LDAUDTFTest { LDAUDTF udtf = new LDAUDTF(); ObjectInspector[] argOIs = new ObjectInspector[] { - ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector), - ObjectInspectorUtils.getConstantObjectInspector( - PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2 -num_docs 2 -s 1")}; + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector), + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-topics 2 -num_docs 2 -s 1 -iter 32 -eps 1e-3")}; udtf.initialize(argOIs); - String[] doc1 = new String[]{"fruits:1", "healthy:1", "vegetables:1"}; - String[] doc2 = new String[]{"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", "oranges:1"}; - for (int it = 0; it < 5; it++) { - 
udtf.process(new Object[]{ Arrays.asList(doc1) }); - udtf.process(new Object[]{ Arrays.asList(doc2) }); - } + String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"}; + String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", + "oranges:1"}; + udtf.process(new Object[] {Arrays.asList(doc1)}); + udtf.process(new Object[] {Arrays.asList(doc2)}); + + udtf.closeWithoutModelReset(); SortedMap<Float, List<String>> topicWords; @@ -89,13 +91,76 @@ } Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), " - + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic", + + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic", udtf.getLambda("vegetables", k1) > udtf.getLambda("flu", k1)); Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), " - + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic", + + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic", udtf.getLambda("avocados", k2) > udtf.getLambda("healthy", k2)); } + @Test + public void testMultiBytes() throws HiveException { + LDAUDTF udtf = new LDAUDTF(); + + ObjectInspector[] argOIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector), + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-topics 2 -num_docs 2 -s 1 -iter 32 -eps 1e-3")}; + + udtf.initialize(argOIs); + + String[] doc1 = new String[] {"果物:1", "健康:1", "野菜:1"}; + String[] doc2 = new String[] {"りんご:1", "アボカド:1", "風邪:1", "インフルエンザ:1", "好き:2", "みかん:1"}; + + udtf.process(new Object[] {Arrays.asList(doc1)}); + udtf.process(new Object[] {Arrays.asList(doc2)}); + + udtf.closeWithoutModelReset(); + + SortedMap<Float, List<String>> topicWords; + + println("Topic 0:"); + 
println("========"); + topicWords = udtf.getTopicWords(0); + for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) { + List<String> words = e.getValue(); + for (int i = 0; i < words.size(); i++) { + println(e.getKey() + " " + words.get(i)); + } + } + println("========"); + + println("Topic 1:"); + println("========"); + topicWords = udtf.getTopicWords(1); + for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) { + List<String> words = e.getValue(); + for (int i = 0; i < words.size(); i++) { + println(e.getKey() + " " + words.get(i)); + } + } + println("========"); + + int k1, k2; + float[] topicDistr = udtf.getTopicDistribution(doc1); + if (topicDistr[0] > topicDistr[1]) { + // topic 0 MUST represent doc#1 + k1 = 0; + k2 = 1; + } else { + k1 = 1; + k2 = 0; + } + + Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), " + + "and `野菜` SHOULD be more suitable topic word than `インフルエンザ` in the topic", + udtf.getLambda("野菜", k1) > udtf.getLambda("インフルエンザ", k1)); + Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), " + + "and `アボカド` SHOULD be more suitable topic word than `健康` in the topic", + udtf.getLambda("アボカド", k2) > udtf.getLambda("健康", k2)); + } + private static void println(String msg) { if (DEBUG) { System.out.println(msg); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/92e85ad1/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java b/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java index b4810a6..5b0a8c2 100644 --- a/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java +++ b/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java @@ -52,7 +52,8 @@ public class OnlineLDAModelTest { OnlineLDAModel model = new OnlineLDAModel(K, 1.f / K, 1.f / K, 2, 80, 0.8, 1E-5d); 
String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"}; - String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", "oranges:1"}; + String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", + "oranges:1"}; do { perplexityPrev = perplexity; @@ -69,7 +70,7 @@ public class OnlineLDAModelTest { it++; println("Iteration " + it + ": mean perplexity = " + perplexity); - } while(Math.abs(perplexityPrev - perplexity) >= 1E-6f); + } while (Math.abs(perplexityPrev - perplexity) >= 1E-6f); SortedMap<Float, List<String>> topicWords; @@ -106,10 +107,10 @@ public class OnlineLDAModelTest { k2 = 0; } Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), " - + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic", + + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic", model.getLambda("vegetables", k1) > model.getLambda("flu", k1)); Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), " - + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic", + + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic", model.getLambda("avocados", k2) > model.getLambda("healthy", k2)); } @@ -123,7 +124,8 @@ public class OnlineLDAModelTest { OnlineLDAModel model = new OnlineLDAModel(K, 1.f / K, 1.f / K, 2, 80, 0.8, 1E-5d); String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"}; - String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", "oranges:1"}; + String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", + "oranges:1"}; do { perplexityPrev = perplexity; @@ -132,7 +134,7 @@ public class OnlineLDAModelTest { perplexity = model.computePerplexity(); it++; - } while(Math.abs(perplexityPrev - perplexity) >= 1E-6f); + } while (Math.abs(perplexityPrev - perplexity) >= 1E-6f); println("Iterated " + it + " 
times, perplexity = " + perplexity); @@ -141,7 +143,8 @@ public class OnlineLDAModelTest { // returns perplexity=15 in a batch setting and perplexity=22 in an online setting. // Hivemall needs to converge to the similar perplexity. Assert.assertTrue("Perplexity SHOULD be in [12, 25]; " - + "converged perplexity is too small or large for some reasons",12.f <= perplexity && perplexity <= 25.f); + + "converged perplexity is too small or large for some reasons", 12.f <= perplexity + && perplexity <= 25.f); } @Test @@ -210,7 +213,7 @@ public class OnlineLDAModelTest { it++; println("Iteration " + it + ": mean perplexity = " + perplexity); - } while(Math.abs(perplexityPrev - perplexity) >= 1E-1f); + } while (Math.abs(perplexityPrev - perplexity) >= 1E-1f); Set<Integer> topics = new HashSet<Integer>(); for (int k = 0; k < K; k++) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/92e85ad1/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java b/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java index 2be48e1..78f4a62 100644 --- a/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java +++ b/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java @@ -92,7 +92,7 @@ public class PLSAPredictUDAFTest { PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT), PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT), ObjectInspectorUtils.getConstantObjectInspector( - PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")}; + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")}; evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false)); @@ -142,7 +142,7 @@ public class PLSAPredictUDAFTest { 
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT), PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT), ObjectInspectorUtils.getConstantObjectInspector( - PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")}; + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")}; evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false)); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/92e85ad1/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java b/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java index 76795bc..addacbc 100644 --- a/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java +++ b/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java @@ -43,17 +43,18 @@ public class PLSAUDTFTest { ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector), ObjectInspectorUtils.getConstantObjectInspector( PrimitiveObjectInspectorFactory.javaStringObjectInspector, - "-topics 2 -alpha 0.1 -delta 0.00001")}; + "-topics 2 -alpha 0.1 -delta 0.00001 -iter 10000")}; udtf.initialize(argOIs); String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"}; String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", "oranges:1"}; - for (int it = 0; it < 10000; it++) { - udtf.process(new Object[] {Arrays.asList(doc1)}); - udtf.process(new Object[] {Arrays.asList(doc2)}); - } + + udtf.process(new Object[] {Arrays.asList(doc1)}); + udtf.process(new Object[] {Arrays.asList(doc2)}); + + udtf.closeWithoutModelReset(); SortedMap<Float, List<String>> topicWords; @@ -98,6 +99,69 @@ public class PLSAUDTFTest { udtf.getProbability("avocados", k2) > 
udtf.getProbability("healthy", k2)); } + @Test + public void testMultiBytes() throws HiveException { + PLSAUDTF udtf = new PLSAUDTF(); + + ObjectInspector[] argOIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector), + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-topics 2 -alpha 0.1 -delta 0.00001 -iter 10000")}; + + udtf.initialize(argOIs); + + String[] doc1 = new String[] {"果物:1", "健康:1", "野菜:1"}; + String[] doc2 = new String[] {"りんご:1", "アボカド:1", "風邪:1", "インフルエンザ:1", "好き:2", "みかん:1"}; + + udtf.process(new Object[] {Arrays.asList(doc1)}); + udtf.process(new Object[] {Arrays.asList(doc2)}); + + udtf.closeWithoutModelReset(); + + SortedMap<Float, List<String>> topicWords; + + println("Topic 0:"); + println("========"); + topicWords = udtf.getTopicWords(0); + for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) { + List<String> words = e.getValue(); + for (int i = 0; i < words.size(); i++) { + println(e.getKey() + " " + words.get(i)); + } + } + println("========"); + + println("Topic 1:"); + println("========"); + topicWords = udtf.getTopicWords(1); + for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) { + List<String> words = e.getValue(); + for (int i = 0; i < words.size(); i++) { + println(e.getKey() + " " + words.get(i)); + } + } + println("========"); + + int k1, k2; + float[] topicDistr = udtf.getTopicDistribution(doc1); + if (topicDistr[0] > topicDistr[1]) { + // topic 0 MUST represent doc#1 + k1 = 0; + k2 = 1; + } else { + k1 = 1; + k2 = 0; + } + + Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), " + + "and `野菜` SHOULD be more suitable topic word than `インフルエンザ` in the topic", + udtf.getProbability("野菜", k1) > udtf.getProbability("インフルエンザ", k1)); + Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + 
"%), " + + "and `アボカド` SHOULD be more suitable topic word than `健康` in the topic", + udtf.getProbability("アボカド", k2) > udtf.getProbability("健康", k2)); + } + private static void println(String msg) { if (DEBUG) { System.out.println(msg);
