Author: srowen
Date: Sun Jun  5 16:04:16 2011
New Revision: 1132441

URL: http://svn.apache.org/viewvc?rev=1132441&view=rev
Log:
While here, class could be streamlined and use charsets explicitly

Modified:
    mahout/trunk/integration/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java

Modified: mahout/trunk/integration/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java
URL: http://svn.apache.org/viewvc/mahout/trunk/integration/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java?rev=1132441&r1=1132440&r2=1132441&view=diff
==============================================================================
--- mahout/trunk/integration/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java (original)
+++ mahout/trunk/integration/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java Sun Jun  5 16:04:16 2011
@@ -17,6 +17,7 @@
 
 package org.apache.mahout.clustering.lda;
 
+import com.google.common.base.Charsets;
 import com.google.common.io.Closeables;
 import org.apache.commons.cli2.CommandLine;
 import org.apache.commons.cli2.Group;
@@ -37,12 +38,15 @@ import org.apache.mahout.common.iterator
 import org.apache.mahout.utils.vectors.VectorHelper;
 
 import java.io.File;
+import java.io.FileOutputStream;
 import java.io.IOException;
+import java.io.OutputStreamWriter;
 import java.io.PrintWriter;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -54,42 +58,12 @@ import java.util.Queue;
  */
 public final class LDAPrintTopics {
 
-  private LDAPrintTopics() { }  
-  
-  private static class StringDoublePair implements Comparable<StringDoublePair> {
-    private double score;
-    private final String word;
-    
-    StringDoublePair(double score, String word) {
-      this.score = score;
-      this.word = word;
-    }
-    
-    @Override
-    public int compareTo(StringDoublePair other) {
-      return Double.compare(score, other.score);
-    }
-    
-    @Override
-    public boolean equals(Object o) {
-      if (!(o instanceof StringDoublePair)) {
-        return false;
-      }
-      StringDoublePair other = (StringDoublePair) o;
-      return score == other.score && word.equals(other.word);
-    }
-    
-    @Override
-    public int hashCode() {
-      return (int) Double.doubleToLongBits(score) ^ word.hashCode();
-    }
-    
-  }
+  private LDAPrintTopics() { }
   
   // Expands the queue list to have a Queue for topic K
-  private static void ensureQueueSize(Collection<PriorityQueue<StringDoublePair>> queues, int k) {
+  private static void ensureQueueSize(Collection<Queue<Pair<String,Double>>> queues, int k) {
     for (int i = queues.size(); i <= k; ++i) {
-      queues.add(new PriorityQueue<StringDoublePair>());
+      queues.add(new PriorityQueue<Pair<String,Double>>());
     }
   }
   
@@ -154,7 +128,7 @@ public final class LDAPrintTopics {
         throw new IllegalArgumentException("Invalid dictionary format");
       }
       
-      List<PriorityQueue<StringDoublePair>> topWords = topWordsForTopics(input, config, wordList, numWords);
+      List<Queue<Pair<String,Double>>> topWords = topWordsForTopics(input, config, wordList, numWords);
 
       File output = null;
       if (cmdLine.hasOption(outOpt)) {
@@ -171,38 +145,44 @@ public final class LDAPrintTopics {
   }
   
   // Adds the word if the queue is below capacity, or the score is high enough
-  private static void maybeEnqueue(Queue<StringDoublePair> q, String word, double score, int numWordsToPrint) {
-    if (q.size() >= numWordsToPrint && score > q.peek().score) {
+  private static void maybeEnqueue(Queue<Pair<String,Double>> q, String word, double score, int numWordsToPrint) {
+    if (q.size() >= numWordsToPrint && score > q.peek().getSecond()) {
       q.poll();
     }
     if (q.size() < numWordsToPrint) {
-      q.add(new StringDoublePair(score, word));
+      q.add(new Pair<String,Double>(word, score));
     }
   }
   
-  private static void printTopWords(List<PriorityQueue<StringDoublePair>> topWords, File outputDir)
+  private static void printTopWords(List<Queue<Pair<String,Double>>> topWords, File outputDir)
     throws IOException {
     for (int i = 0; i < topWords.size(); ++i) {
-      PriorityQueue<StringDoublePair> topK = topWords.get(i);
+      Collection<Pair<String,Double>> topK = topWords.get(i);
       PrintWriter out = null;
       boolean printingToSystemOut = false;
       try {
         if (outputDir != null) {
-          out = new PrintWriter(new File(outputDir, "topic_" + i));
+          out = new PrintWriter(new OutputStreamWriter(
+              new FileOutputStream(new File(outputDir, "topic_" + i)), Charsets.UTF_8));
         } else {
-          out = new PrintWriter(System.out);
+          out = new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8));
           printingToSystemOut = true;
           out.println("Topic " + i);
           out.println("===========");
         }
-        List<StringDoublePair> topKasList = new ArrayList<StringDoublePair>(topK.size());
-        for(StringDoublePair wordWithScore : topK) {
+        List<Pair<String,Double>> topKasList = new ArrayList<Pair<String,Double>>(topK.size());
+        for(Pair<String,Double> wordWithScore : topK) {
           topKasList.add(wordWithScore);
         }
-        Collections.sort(topKasList, Collections.reverseOrder());
-        for(StringDoublePair wordWithScore : topKasList) {
-          out.println(wordWithScore.word + " [p(" + wordWithScore.word + "|topic_" + i +") = "
-           + wordWithScore.score);
+        Collections.sort(topKasList, new Comparator<Pair<String,Double>>() {
+          @Override
+          public int compare(Pair<String,Double> pair1, Pair<String,Double> pair2) {
+            return pair2.getSecond().compareTo(pair1.getSecond());
+          }
+        });
+        for(Pair<String,Double> wordWithScore : topKasList) {
+          out.println(wordWithScore.getFirst() + " [p(" + wordWithScore.getFirst() + "|topic_" + i +") = "
+           + wordWithScore.getSecond());
         }
       } finally {
         if (!printingToSystemOut) {
@@ -212,11 +192,11 @@ public final class LDAPrintTopics {
     }
   }
   
-  private static List<PriorityQueue<StringDoublePair>> topWordsForTopics(String dir,
-                                                      Configuration job,
-                                                      List<String> wordList,
-                                                      int numWordsToPrint) {
-    List<PriorityQueue<StringDoublePair>> queues = new ArrayList<PriorityQueue<StringDoublePair>>();
+  private static List<Queue<Pair<String,Double>>> topWordsForTopics(String dir,
+                                                                    Configuration job,
+                                                                    List<String> wordList,
+                                                                    int numWordsToPrint) {
+    List<Queue<Pair<String,Double>>> queues = new ArrayList<Queue<Pair<String,Double>>>();
     Map<Integer,Double> expSums = new HashMap<Integer, Double>();
     for (Pair<IntPairWritable,DoubleWritable> record :
          new SequenceFileDirIterable<IntPairWritable, DoubleWritable>(
@@ -227,19 +207,22 @@ public final class LDAPrintTopics {
       ensureQueueSize(queues, topic);
       if (word >= 0 && topic >= 0) {
         double score = record.getSecond().get();
-        if(expSums.get(topic) == null) {
-          expSums.put(topic, 0d);
+        if (expSums.get(topic) == null) {
+          expSums.put(topic, 0.0);
         }
         expSums.put(topic, expSums.get(topic) + Math.exp(score));
         String realWord = wordList.get(word);
         maybeEnqueue(queues.get(topic), realWord, score, numWordsToPrint);
       }
     }
-    for(int i=0; i<queues.size(); i++) {
-      PriorityQueue<StringDoublePair> queue = queues.get(i);
-      for(StringDoublePair pair : queue) {
-        pair.score = Math.exp(pair.score) / expSums.get(i);
+    for (int i = 0; i < queues.size(); i++) {
+      Queue<Pair<String,Double>> queue = queues.get(i);
+      Queue<Pair<String,Double>> newQueue = new PriorityQueue<Pair<String, Double>>(queue.size());
+      double norm = expSums.get(i);
+      for (Pair<String,Double> pair : queue) {
+        newQueue.add(new Pair<String,Double>(pair.getFirst(), Math.exp(pair.getSecond()) / norm));
       }
+      queues.set(i, newQueue);
     }
     return queues;
   }


Reply via email to