Repository: incubator-hivemall Updated Branches: refs/heads/master bba252ac1 -> e4e1531e1
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java b/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java new file mode 100644 index 0000000..b4810a6 --- /dev/null +++ b/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.topicmodel; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.util.List; +import java.util.ArrayList; +import java.util.Map; +import java.util.SortedMap; +import java.util.Set; +import java.util.HashSet; +import java.util.Arrays; +import java.util.StringTokenizer; +import java.util.zip.GZIPInputStream; + +import hivemall.classifier.KernelExpansionPassiveAggressiveUDTFTest; + +import org.junit.Assert; +import org.junit.Test; + +import javax.annotation.Nonnull; + +public class OnlineLDAModelTest { + private static final boolean DEBUG = false; + + @Test + public void test() { + int K = 2; + int it = 0; + float perplexityPrev; + float perplexity = Float.MAX_VALUE; + + OnlineLDAModel model = new OnlineLDAModel(K, 1.f / K, 1.f / K, 2, 80, 0.8, 1E-5d); + + String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"}; + String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", "oranges:1"}; + + do { + perplexityPrev = perplexity; + perplexity = 0.f; + + // online (i.e., one-by-one) updating + model.train(new String[][] {doc1}); + perplexity += model.computePerplexity(); + + model.train(new String[][] {doc2}); + perplexity += model.computePerplexity(); + + perplexity /= 2.f; // mean perplexity for the 2 docs + + it++; + println("Iteration " + it + ": mean perplexity = " + perplexity); + } while(Math.abs(perplexityPrev - perplexity) >= 1E-6f); + + SortedMap<Float, List<String>> topicWords; + + println("Topic 0:"); + println("========"); + topicWords = model.getTopicWords(0); + for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) { + List<String> words = e.getValue(); + for (int i = 0; i < words.size(); i++) { + println(e.getKey() + " " + words.get(i)); + } + } + println("========"); + + println("Topic 1:"); + println("========"); + topicWords = model.getTopicWords(1); + for (Map.Entry<Float, List<String>> e : 
topicWords.entrySet()) { + List<String> words = e.getValue(); + for (int i = 0; i < words.size(); i++) { + println(e.getKey() + " " + words.get(i)); + } + } + println("========"); + + int k1, k2; + float[] topicDistr = model.getTopicDistribution(doc1); + if (topicDistr[0] > topicDistr[1]) { + // topic 0 MUST represent doc#1 + k1 = 0; + k2 = 1; + } else { + k1 = 1; + k2 = 0; + } + Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), " + + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic", + model.getLambda("vegetables", k1) > model.getLambda("flu", k1)); + Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), " + + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic", + model.getLambda("avocados", k2) > model.getLambda("healthy", k2)); + } + + @Test + public void testPerplexity() { + int K = 2; + int it = 0; + float perplexityPrev; + float perplexity = Float.MAX_VALUE; + + OnlineLDAModel model = new OnlineLDAModel(K, 1.f / K, 1.f / K, 2, 80, 0.8, 1E-5d); + + String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"}; + String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", "oranges:1"}; + + do { + perplexityPrev = perplexity; + + model.train(new String[][] {doc1, doc2}); + perplexity = model.computePerplexity(); + + it++; + } while(Math.abs(perplexityPrev - perplexity) >= 1E-6f); + + println("Iterated " + it + " times, perplexity = " + perplexity); + + // For the same data and hyperparameters, + // scikit-learn Python library (implemented based on Matthew D. Hoffman's onlineldavb code) + // returns perplexity=15 in a batch setting and perplexity=22 in an online setting. + // Hivemall needs to converge to the similar perplexity. 
+ Assert.assertTrue("Perplexity SHOULD be in [12, 25]; " + + "converged perplexity is too small or large for some reasons",12.f <= perplexity && perplexity <= 25.f); + } + + @Test + public void testNews20() throws IOException { + int K = 20; + int numTotalDocs = 2000; + int miniBatchSize = 2; + + int cnt, it; + + OnlineLDAModel model = new OnlineLDAModel(K, 1.f / K, 1.f / K, numTotalDocs, 80, 0.8, 1E-3d); + + BufferedReader news20 = readFile("news20-multiclass.gz"); + + String[][] docs = new String[K][]; + + String line = news20.readLine(); + List<String> doc = new ArrayList<String>(); + + cnt = 0; + while (line != null) { + StringTokenizer tokens = new StringTokenizer(line, " "); + + int k = Integer.parseInt(tokens.nextToken()) - 1; + + while (tokens.hasMoreTokens()) { + doc.add(tokens.nextToken()); + } + + // store first document in each of K classes + if (docs[k] == null) { + docs[k] = doc.toArray(new String[doc.size()]); + cnt++; + } + + if (cnt == K) { + break; + } + + doc.clear(); + line = news20.readLine(); + } + println("Stored " + cnt + " docs. 
Start training w/ mini-batch size: " + miniBatchSize); + + float perplexityPrev; + float perplexity = Float.MAX_VALUE; + + it = 0; + do { + perplexityPrev = perplexity; + perplexity = 0.f; + + int head = 0; + cnt = 0; + while (head < K) { + int tail = head + miniBatchSize; + model.train(Arrays.copyOfRange(docs, head, tail)); + perplexity += model.computePerplexity(); + head = tail; + cnt++; + println("Processed mini-batch#" + cnt); + } + + perplexity /= cnt; + + it++; + + println("Iteration " + it + ": mean perplexity = " + perplexity); + } while(Math.abs(perplexityPrev - perplexity) >= 1E-1f); + + Set<Integer> topics = new HashSet<Integer>(); + for (int k = 0; k < K; k++) { + topics.add(findMaxTopic(model.getTopicDistribution(docs[k]))); + } + + int n = topics.size(); + Assert.assertTrue("At least 15 documents SHOULD be classified to different topics, " + + "but there are only " + n + " unique topics.", n >= 15); + } + + private static void println(String msg) { + if (DEBUG) { + System.out.println(msg); + } + } + + @Nonnull + private static BufferedReader readFile(@Nonnull String fileName) throws IOException { + // use data stored for KPA UDTF test + InputStream is = KernelExpansionPassiveAggressiveUDTFTest.class.getResourceAsStream(fileName); + if (fileName.endsWith(".gz")) { + is = new GZIPInputStream(is); + } + return new BufferedReader(new InputStreamReader(is)); + } + + @Nonnull + private static int findMaxTopic(@Nonnull float[] topicDistr) { + int maxIdx = 0; + for (int i = 1; i < topicDistr.length; i++) { + if (topicDistr[maxIdx] < topicDistr[i]) { + maxIdx = i; + } + } + return maxIdx; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/docs/gitbook/SUMMARY.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md index 4c6ed1b..78b1faa 100644 --- a/docs/gitbook/SUMMARY.md +++ b/docs/gitbook/SUMMARY.md @@ -150,7 +150,11 @@ * [Change-Point Detection 
using Singular Spectrum Transformation (SST)](anomaly/sst.md) * [ChangeFinder: Detecting Outlier and Change-Point Simultaneously](anomaly/changefinder.md) -## Part X - Hivemall on Spark +## Part X - Clustering + +* [Latent Dirichlet Allocation](clustering/lda.md) + +## Part XI - Hivemall on Spark * [Getting Started](spark/getting_started/README.md) * [Installation](spark/getting_started/installation.md) @@ -165,7 +169,7 @@ * [Top-k Join processing](spark/misc/topk_join.md) * [Other utility functions](spark/misc/functions.md) -## Part X - External References +## Part XII - External References * [Hivemall on Apache Spark](https://github.com/maropu/hivemall-spark) * [Hivemall on Apache Pig](https://github.com/daijyc/hivemall/wiki/PigHome) http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/docs/gitbook/clustering/lda.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/clustering/lda.md b/docs/gitbook/clustering/lda.md new file mode 100644 index 0000000..cc477da --- /dev/null +++ b/docs/gitbook/clustering/lda.md @@ -0,0 +1,195 @@ +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +--> + +Topic modeling is a way to analyze massive documents by clustering them into some ***topics***. 
In particular, **Latent Dirichlet Allocation** (LDA) is one of the most popular topic modeling techniques; papers introducing the method are as follows: + +- D. M. Blei, et al. [Latent Dirichlet Allocation](http://www.jmlr.org/papers/v3/blei03a.html). Journal of Machine Learning Research 3, pp. 993-1022, 2003. +- M. D. Hoffman, et al. [Online Learning for Latent Dirichlet Allocation](https://papers.nips.cc/paper/3902-online-learning-for-latent-dirichlet-allocation). NIPS 2010. + +Hivemall enables you to analyze your data such as, but not limited to, documents based on LDA. This page gives usage instructions for the feature. + +<!-- toc --> + +> #### Note +> This feature is supported from Hivemall v0.5-rc.1 or later. + +# Prepare document data + +Assume that we already have a table `docs` which contains many documents in string format: + +| docid | doc | +|:---:|:---| +| 1 | "Fruits and vegetables are healthy." | +|2 | "I like apples, oranges, and avocados. I do not like the flu or colds." | +| ... | ... | + +Hivemall has several functions which are particularly useful for text processing. More specifically, by using `tokenize()` and `is_stopword()`, you can immediately convert the documents to [bag-of-words](https://en.wikipedia.org/wiki/Bag-of-words_model)-like format: + +```sql +with word_counts as ( + select + docid, + feature(word, count(word)) as word_count + from docs t1 LATERAL VIEW explode(tokenize(doc, true)) t2 as word + where + not is_stopword(word) + group by + docid, word +) +select docid, collect_set(word_count) as feature +from word_counts +group by docid +; +``` + +| docid | feature | +|:---:|:---| +|1 | ["fruits:1","healthy:1","vegetables:1"] | +|2 | ["apples:1","avocados:1","colds:1","flu:1","like:2","oranges:1"] | + +> #### Note +> It should be noted that, as long as your data can be represented as the feature format, LDA can be applied to arbitrary data as a generic clustering technique. 
+ +# Building Topic Models and Finding Topic Words + +Each feature vector is input to the `train_lda()` function: + +```sql +with word_counts as ( + select + docid, + feature(word, count(word)) as word_count + from docs t1 LATERAL VIEW explode(tokenize(doc, true)) t2 as word + where + not is_stopword(word) + group by + docid, word +) +select + train_lda(feature, "-topic 2 -iter 20") as (label, word, lambda) +from ( + select docid, collect_set(word_count) as feature + from word_counts + group by docid + order by docid +) t +; +``` + +Here, an option `-topic 2` specifies the number of topics we assume in the set of documents. + +Notice that `order by docid` ensures building an LDA model precisely on a single node. In case you would like to launch `train_lda` in parallel, the following query should return a similar (but possibly slightly approximated) result: + +```sql +with word_counts as ( + -- same as above +) +select + label, word, avg(lambda) as lambda +from ( + select + train_lda(feature, "-topic 2 -iter 20") as (label, word, lambda) + from ( + select docid, collect_set(word_count) as feature + from word_counts + group by docid + ) t1 +) t2 +group by label, word +order by lambda desc +; +``` + +Eventually, a new table `lda_model` is generated as shown below: + +|label | word | lambda | +|:---:|:---:|:---:| +|0 | fruits | 0.33372128| +|0 | vegetables | 0.33272517| +|0 | healthy | 0.33246377| +|0 | flu | 2.3617347E-4| +|0 | apples | 2.1898883E-4| +|0 | oranges | 1.8161473E-4| +|0 | like | 1.7666373E-4| +|0 | avocados | 1.726186E-4| +|0 | colds | 1.037139E-4| +|1 | colds | 0.16622013| +|1 | avocados | 0.16618845| +|1 | oranges | 0.1661859| +|1 | like | 0.16618414| +|1 | apples | 0.16616651| +|1 | flu | 0.16615893| +|1 | healthy | 0.0012059759| +|1 | vegetables | 0.0010818697| +|1 | fruits | 6.080827E-4| + +In the table, `label` indicates a topic index, and `lambda` is a value which represents how each word is likely to characterize a topic. 
That is, we can say that, in terms of `lambda`, top-N words are the ***topic words*** of a topic. + +Obviously, we can observe that topic `0` corresponds to document `1`, and topic `1` represents words in document `2`. + +# Predicting Topic Assignments of Documents + +Once you have constructed topic models as described before, a function `lda_predict()` allows you to predict topic assignments of documents. + +For example, if we consider the `docs` table, exactly the same set of documents as used for training, the probability that a document is assigned to a topic can be computed by: + +```sql +with test as ( + select + docid, + word, + count(word) as value + from docs t1 LATERAL VIEW explode(tokenize(doc, true)) t2 as word + where + not is_stopword(word) + group by + docid, word +) +select + t.docid, + lda_predict(t.word, t.value, m.label, m.lambda, "-topic 2") as probabilities +from + test t + JOIN lda_model m ON (t.word = m.word) +group by + t.docid +; +``` + +| docid | probabilities (sorted by probabilities) | +|:---:|:---| +|1 | [{"label":0,"probability":0.875},{"label":1,"probability":0.125}]| +|2 | [{"label":1,"probability":0.9375},{"label":0,"probability":0.0625}]| + +Importantly, an option `-topic` should be set to the same value as you set for training. + +Since the probabilities are sorted in descending order, the label of the most promising topic is easily obtained as: + +```sql +select docid, probabilities[0].label +from topic +; +``` + +| docid | label | +|:---:|:---:| +| 1 | 0 | +| 2 | 1 | + +Of course, using a different set of documents for prediction is possible. Predicting topic assignments of newly observed documents would be a more realistic scenario. 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/resources/ddl/define-all-as-permanent.hive ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive index c6dda03..1eb9c82 100644 --- a/resources/ddl/define-all-as-permanent.hive +++ b/resources/ddl/define-all-as-permanent.hive @@ -616,6 +616,16 @@ CREATE FUNCTION changefinder as 'hivemall.anomaly.ChangeFinderUDF' USING JAR '${ DROP FUNCTION IF EXISTS sst; CREATE FUNCTION sst as 'hivemall.anomaly.SingularSpectrumTransformUDF' USING JAR '${hivemall_jar}'; +-------------------- +-- Topic Modeling -- +-------------------- + +DROP FUNCTION IF EXISTS train_lda; +CREATE FUNCTION train_lda as 'hivemall.topicmodel.LDAUDTF' USING JAR '${hivemall_jar}'; + +DROP FUNCTION IF EXISTS lda_predict; +CREATE FUNCTION lda_predict as 'hivemall.topicmodel.LDAPredictUDAF' USING JAR '${hivemall_jar}'; + ---------------------------- -- Smile related features -- ---------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/resources/ddl/define-all.hive ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive index 8ea16c1..b503546 100644 --- a/resources/ddl/define-all.hive +++ b/resources/ddl/define-all.hive @@ -612,6 +612,16 @@ create temporary function changefinder as 'hivemall.anomaly.ChangeFinderUDF'; drop temporary function if exists sst; create temporary function sst as 'hivemall.anomaly.SingularSpectrumTransformUDF'; +-------------------- +-- Topic Modeling -- +-------------------- + +drop temporary function if exists train_lda; +create temporary function train_lda as 'hivemall.topicmodel.LDAUDTF'; + +drop temporary function if exists lda_predict; +create temporary function lda_predict as 'hivemall.topicmodel.LDAPredictUDAF'; + ---------------------------- -- Smile related 
features -- ---------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/resources/ddl/define-all.spark ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark index 0172cc8..b5239cf 100644 --- a/resources/ddl/define-all.spark +++ b/resources/ddl/define-all.spark @@ -597,6 +597,16 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS sst") sqlContext.sql("CREATE TEMPORARY FUNCTION sst AS 'hivemall.anomaly.SingularSpectrumTransformUDF'") /** + * Topic Modeling + */ + +sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_lda") +sqlContext.sql("CREATE TEMPORARY FUNCTION train_lda AS 'hivemall.topicmodel.LDAUDTF'") + +sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS lda_predict") +sqlContext.sql("CREATE TEMPORARY FUNCTION lda_predict AS 'hivemall.topicmodel.LDAPredictUDAF'") + +/** * Smile related features */ http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/resources/ddl/define-udfs.td.hql ---------------------------------------------------------------------- diff --git a/resources/ddl/define-udfs.td.hql b/resources/ddl/define-udfs.td.hql index cff0913..28d17ff 100644 --- a/resources/ddl/define-udfs.td.hql +++ b/resources/ddl/define-udfs.td.hql @@ -158,6 +158,8 @@ create temporary function guess_attribute_types as 'hivemall.smile.tools.GuessAt -- since Hivemall v0.5-rc.1 create temporary function changefinder as 'hivemall.anomaly.ChangeFinderUDF'; create temporary function sst as 'hivemall.anomaly.SingularSpectrumTransformUDF'; +create temporary function train_lda as 'hivemall.topicmodel.LDAUDTF'; +create temporary function lda_predict as 'hivemall.topicmodel.LDAPredictUDAF'; -- NLP features create temporary function tokenize_ja as 'hivemall.nlp.tokenizer.KuromojiUDF';
