Author: robinanil
Date: Thu Feb 18 15:12:14 2010
New Revision: 911432

URL: http://svn.apache.org/viewvc?rev=911432&view=rev
Log:
MAHOUT-296: Added the testclassifier and trainclassifier commands to the shell
script. TestClassifier now takes the correct label from the entry key instead
of deriving it from the file name.

Modified:
    lucene/mahout/trunk/bin/mahout
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java

Modified: lucene/mahout/trunk/bin/mahout
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/bin/mahout?rev=911432&r1=911431&r2=911432&view=diff
==============================================================================
--- lucene/mahout/trunk/bin/mahout (original)
+++ lucene/mahout/trunk/bin/mahout Thu Feb 18 15:12:14 2010
@@ -61,6 +61,8 @@
   echo "  kmeans                run kmeans clustering"
   echo "  lda                   run LDA clustering"
   echo "  lucenevector          generate vectors from a lucene index"
+  echo "  trainclassifier       run Bayes/CBayes classifier training job"
+  echo "  testclassifier        test Bayes/CBayes model using a pre-classified 
data"
   echo "  meanshift             run Mean Shift clustering"
   echo "  seqdirectory          generate sequence files containing the 
documents beneathe a directory"
   echo "  seqdump               dump a sequence files using the writable 
toString() method"
@@ -200,6 +202,10 @@
   CLASS=org.apache.mahout.text.WikipediaToSequenceFile
 elif [ "$COMMAND" = "seq2sparse" ]; then
   CLASS=org.apache.mahout.text.SparseVectorsFromSequenceFiles  
+elif [ "$COMMAND" = "trainclassifier" ]; then
+  CLASS=org.apache.mahout.classifier.bayes.TrainClassifier  
+elif [ "$COMMAND" = "testclassifier" ]; then
+  CLASS=org.apache.mahout.classifier.bayes.TestClassifier
 else
   CLASS=$COMMAND
 fi

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java?rev=911432&r1=911431&r2=911432&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java Thu Feb 18 15:12:14 2010
@@ -108,7 +108,7 @@
       
abuilder.withName("dataSource").withMinimum(1).withMaximum(1).create()).withDescription(
       "Location of model: hdfs|hbase Default Value: 
hdfs").withShortName("source").create();
     
-    Option methodOpt = 
obuilder.withLongName("method").withRequired(true).withArgument(
+    Option methodOpt = 
obuilder.withLongName("method").withRequired(false).withArgument(
       
abuilder.withName("method").withMinimum(1).withMaximum(1).create()).withDescription(
       "Method of Classification: sequential|mapreduce. Default Value: 
sequential").withShortName("method")
         .create();
@@ -158,7 +158,10 @@
       
       String testDirPath = (String) cmdLine.getValue(dirOpt);
       
-      String classificationMethod = (String) cmdLine.getValue(methodOpt);
+      String classificationMethod = "sequential";
+      if (cmdLine.hasOption(methodOpt)) {
+        classificationMethod = (String) cmdLine.getValue(methodOpt);
+      }
       
       params.set("verbose", Boolean.toString(verbose));
       params.set("basePath", modelBasePath);
@@ -229,9 +232,10 @@
     if (subdirs != null) {
       
       for (File file : subdirs) {
-        log.info("--------------");
-        log.info("Testing: {}", file);
-        String correctLabel = file.getName().split(".txt")[0];
+        if (verbose) {
+          log.info("--------------");
+          log.info("Testing: {}", file);
+        }
         TimingStatistics operationStats = new TimingStatistics();
         
         long lineNum = 0;
@@ -241,6 +245,7 @@
           Map<String,List<String>> document = new NGrams(line, 
Integer.parseInt(params.get("gramSize")))
               .generateNGrams();
           for (Map.Entry<String,List<String>> stringListEntry : 
document.entrySet()) {
+            String correctLabel = stringListEntry.getKey();
             List<String> strings = stringListEntry.getValue();
             TimingStatistics.Call call = operationStats.newCall();
             TimingStatistics.Call outercall = totalStatistics.newCall();
@@ -261,11 +266,14 @@
           }
           lineNum++;
         }
-       log.info("{}\t{}\t{}/{}",
+       /*log.info("{}\t{}\t{}/{}",
           new Object[] {correctLabel, 
resultAnalyzer.getConfusionMatrix().getAccuracy(correctLabel),
                         
resultAnalyzer.getConfusionMatrix().getCorrect(correctLabel),
-                        
resultAnalyzer.getConfusionMatrix().getTotal(correctLabel)});
-       log.info("{}", operationStats.toString());
+                        
resultAnalyzer.getConfusionMatrix().getTotal(correctLabel)});*/
+        log.info("Classified instances from {}", file.getName());
+        if (verbose) {
+          log.info("Performance stats {}", operationStats.toString());
+        }
       }
       
     }


Reply via email to