Author: robinanil Date: Thu Feb 18 15:12:14 2010 New Revision: 911432 URL: http://svn.apache.org/viewvc?rev=911432&view=rev Log: MAHOUT-296: Added the trainclassifier and testclassifier commands to the bin/mahout shell script. TestClassifier now takes the correct label from each generated entry's key instead of deriving it from the file name.
Modified: lucene/mahout/trunk/bin/mahout lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java Modified: lucene/mahout/trunk/bin/mahout URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/bin/mahout?rev=911432&r1=911431&r2=911432&view=diff ============================================================================== --- lucene/mahout/trunk/bin/mahout (original) +++ lucene/mahout/trunk/bin/mahout Thu Feb 18 15:12:14 2010 @@ -61,6 +61,8 @@ echo " kmeans run kmeans clustering" echo " lda run LDA clustering" echo " lucenevector generate vectors from a lucene index" + echo " trainclassifier run Bayes/CBayes classifier training job" + echo " testclassifier test Bayes/CBayes model using a pre-classified data" echo " meanshift run Mean Shift clustering" echo " seqdirectory generate sequence files containing the documents beneathe a directory" echo " seqdump dump a sequence files using the writable toString() method" @@ -200,6 +202,10 @@ CLASS=org.apache.mahout.text.WikipediaToSequenceFile elif [ "$COMMAND" = "seq2sparse" ]; then CLASS=org.apache.mahout.text.SparseVectorsFromSequenceFiles +elif [ "$COMMAND" = "trainclassifier" ]; then + CLASS=org.apache.mahout.classifier.bayes.TrainClassifier +elif [ "$COMMAND" = "testclassifier" ]; then + CLASS=org.apache.mahout.classifier.bayes.TestClassifier else CLASS=$COMMAND fi Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java?rev=911432&r1=911431&r2=911432&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java Thu Feb 18 15:12:14 2010 @@ -108,7 +108,7 @@ 
abuilder.withName("dataSource").withMinimum(1).withMaximum(1).create()).withDescription( "Location of model: hdfs|hbase Default Value: hdfs").withShortName("source").create(); - Option methodOpt = obuilder.withLongName("method").withRequired(true).withArgument( + Option methodOpt = obuilder.withLongName("method").withRequired(false).withArgument( abuilder.withName("method").withMinimum(1).withMaximum(1).create()).withDescription( "Method of Classification: sequential|mapreduce. Default Value: sequential").withShortName("method") .create(); @@ -158,7 +158,10 @@ String testDirPath = (String) cmdLine.getValue(dirOpt); - String classificationMethod = (String) cmdLine.getValue(methodOpt); + String classificationMethod = "sequential"; + if (cmdLine.hasOption(methodOpt)) { + classificationMethod = (String) cmdLine.getValue(methodOpt); + } params.set("verbose", Boolean.toString(verbose)); params.set("basePath", modelBasePath); @@ -229,9 +232,10 @@ if (subdirs != null) { for (File file : subdirs) { - log.info("--------------"); - log.info("Testing: {}", file); - String correctLabel = file.getName().split(".txt")[0]; + if (verbose) { + log.info("--------------"); + log.info("Testing: {}", file); + } TimingStatistics operationStats = new TimingStatistics(); long lineNum = 0; @@ -241,6 +245,7 @@ Map<String,List<String>> document = new NGrams(line, Integer.parseInt(params.get("gramSize"))) .generateNGrams(); for (Map.Entry<String,List<String>> stringListEntry : document.entrySet()) { + String correctLabel = stringListEntry.getKey(); List<String> strings = stringListEntry.getValue(); TimingStatistics.Call call = operationStats.newCall(); TimingStatistics.Call outercall = totalStatistics.newCall(); @@ -261,11 +266,14 @@ } lineNum++; } - log.info("{}\t{}\t{}/{}", + /*log.info("{}\t{}\t{}/{}", new Object[] {correctLabel, resultAnalyzer.getConfusionMatrix().getAccuracy(correctLabel), resultAnalyzer.getConfusionMatrix().getCorrect(correctLabel), - 
resultAnalyzer.getConfusionMatrix().getTotal(correctLabel)}); - log.info("{}", operationStats.toString()); + resultAnalyzer.getConfusionMatrix().getTotal(correctLabel)});*/ + log.info("Classified instances from {}", file.getName()); + if (verbose) { + log.info("Performance stats {}", operationStats.toString()); + } } }