Author: srowen
Date: Sun Jun 5 16:04:16 2011
New Revision: 1132441
URL: http://svn.apache.org/viewvc?rev=1132441&view=rev
Log:
While here, class could be streamlined and use charsets explicitly
Modified:
mahout/trunk/integration/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java
Modified:
mahout/trunk/integration/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java
URL: http://svn.apache.org/viewvc/mahout/trunk/integration/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java?rev=1132441&r1=1132440&r2=1132441&view=diff
==============================================================================
--- mahout/trunk/integration/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java (original)
+++ mahout/trunk/integration/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java Sun Jun 5 16:04:16 2011
@@ -17,6 +17,7 @@
package org.apache.mahout.clustering.lda;
+import com.google.common.base.Charsets;
import com.google.common.io.Closeables;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
@@ -37,12 +38,15 @@ import org.apache.mahout.common.iterator
import org.apache.mahout.utils.vectors.VectorHelper;
import java.io.File;
+import java.io.FileOutputStream;
import java.io.IOException;
+import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
+import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -54,42 +58,12 @@ import java.util.Queue;
*/
public final class LDAPrintTopics {
- private LDAPrintTopics() { }
-
- private static class StringDoublePair implements Comparable<StringDoublePair> {
- private double score;
- private final String word;
-
- StringDoublePair(double score, String word) {
- this.score = score;
- this.word = word;
- }
-
- @Override
- public int compareTo(StringDoublePair other) {
- return Double.compare(score, other.score);
- }
-
- @Override
- public boolean equals(Object o) {
- if (!(o instanceof StringDoublePair)) {
- return false;
- }
- StringDoublePair other = (StringDoublePair) o;
- return score == other.score && word.equals(other.word);
- }
-
- @Override
- public int hashCode() {
- return (int) Double.doubleToLongBits(score) ^ word.hashCode();
- }
-
- }
+ private LDAPrintTopics() { }
// Expands the queue list to have a Queue for topic K
- private static void ensureQueueSize(Collection<PriorityQueue<StringDoublePair>> queues, int k) {
+ private static void ensureQueueSize(Collection<Queue<Pair<String,Double>>> queues, int k) {
for (int i = queues.size(); i <= k; ++i) {
- queues.add(new PriorityQueue<StringDoublePair>());
+ queues.add(new PriorityQueue<Pair<String,Double>>());
}
}
@@ -154,7 +128,7 @@ public final class LDAPrintTopics {
throw new IllegalArgumentException("Invalid dictionary format");
}
- List<PriorityQueue<StringDoublePair>> topWords = topWordsForTopics(input, config, wordList, numWords);
+ List<Queue<Pair<String,Double>>> topWords = topWordsForTopics(input, config, wordList, numWords);
File output = null;
if (cmdLine.hasOption(outOpt)) {
@@ -171,38 +145,44 @@ public final class LDAPrintTopics {
}
// Adds the word if the queue is below capacity, or the score is high enough
- private static void maybeEnqueue(Queue<StringDoublePair> q, String word, double score, int numWordsToPrint) {
- if (q.size() >= numWordsToPrint && score > q.peek().score) {
+ private static void maybeEnqueue(Queue<Pair<String,Double>> q, String word, double score, int numWordsToPrint) {
+ if (q.size() >= numWordsToPrint && score > q.peek().getSecond()) {
q.poll();
}
if (q.size() < numWordsToPrint) {
- q.add(new StringDoublePair(score, word));
+ q.add(new Pair<String,Double>(word, score));
}
}
- private static void printTopWords(List<PriorityQueue<StringDoublePair>> topWords, File outputDir)
+ private static void printTopWords(List<Queue<Pair<String,Double>>> topWords, File outputDir)
throws IOException {
for (int i = 0; i < topWords.size(); ++i) {
- PriorityQueue<StringDoublePair> topK = topWords.get(i);
+ Collection<Pair<String,Double>> topK = topWords.get(i);
PrintWriter out = null;
boolean printingToSystemOut = false;
try {
if (outputDir != null) {
- out = new PrintWriter(new File(outputDir, "topic_" + i));
+ out = new PrintWriter(new OutputStreamWriter(
+ new FileOutputStream(new File(outputDir, "topic_" + i)), Charsets.UTF_8));
} else {
- out = new PrintWriter(System.out);
+ out = new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8));
printingToSystemOut = true;
out.println("Topic " + i);
out.println("===========");
}
- List<StringDoublePair> topKasList = new ArrayList<StringDoublePair>(topK.size());
- for(StringDoublePair wordWithScore : topK) {
+ List<Pair<String,Double>> topKasList = new ArrayList<Pair<String,Double>>(topK.size());
+ for(Pair<String,Double> wordWithScore : topK) {
topKasList.add(wordWithScore);
}
- Collections.sort(topKasList, Collections.reverseOrder());
- for(StringDoublePair wordWithScore : topKasList) {
- out.println(wordWithScore.word + " [p(" + wordWithScore.word + "|topic_" + i +") = "
- + wordWithScore.score);
+ Collections.sort(topKasList, new Comparator<Pair<String,Double>>() {
+ @Override
+ public int compare(Pair<String,Double> pair1, Pair<String,Double> pair2) {
+ return pair2.getSecond().compareTo(pair1.getSecond());
+ }
+ });
+ for(Pair<String,Double> wordWithScore : topKasList) {
+ out.println(wordWithScore.getFirst() + " [p(" + wordWithScore.getFirst() + "|topic_" + i +") = "
+ + wordWithScore.getSecond());
}
} finally {
if (!printingToSystemOut) {
@@ -212,11 +192,11 @@ public final class LDAPrintTopics {
}
}
- private static List<PriorityQueue<StringDoublePair>> topWordsForTopics(String dir,
- Configuration job,
- List<String> wordList,
- int numWordsToPrint) {
- List<PriorityQueue<StringDoublePair>> queues = new ArrayList<PriorityQueue<StringDoublePair>>();
+ private static List<Queue<Pair<String,Double>>> topWordsForTopics(String dir,
+ Configuration job,
+ List<String> wordList,
+ int numWordsToPrint) {
+ List<Queue<Pair<String,Double>>> queues = new ArrayList<Queue<Pair<String,Double>>>();
Map<Integer,Double> expSums = new HashMap<Integer, Double>();
for (Pair<IntPairWritable,DoubleWritable> record :
new SequenceFileDirIterable<IntPairWritable, DoubleWritable>(
@@ -227,19 +207,22 @@ public final class LDAPrintTopics {
ensureQueueSize(queues, topic);
if (word >= 0 && topic >= 0) {
double score = record.getSecond().get();
- if(expSums.get(topic) == null) {
- expSums.put(topic, 0d);
+ if (expSums.get(topic) == null) {
+ expSums.put(topic, 0.0);
}
expSums.put(topic, expSums.get(topic) + Math.exp(score));
String realWord = wordList.get(word);
maybeEnqueue(queues.get(topic), realWord, score, numWordsToPrint);
}
}
- for(int i=0; i<queues.size(); i++) {
- PriorityQueue<StringDoublePair> queue = queues.get(i);
- for(StringDoublePair pair : queue) {
- pair.score = Math.exp(pair.score) / expSums.get(i);
+ for (int i = 0; i < queues.size(); i++) {
+ Queue<Pair<String,Double>> queue = queues.get(i);
+ Queue<Pair<String,Double>> newQueue = new PriorityQueue<Pair<String, Double>>(queue.size());
+ double norm = expSums.get(i);
+ for (Pair<String,Double> pair : queue) {
+ newQueue.add(new Pair<String,Double>(pair.getFirst(), Math.exp(pair.getSecond()) / norm));
}
+ queues.set(i, newQueue);
}
return queues;
}