Sorry for the inconvenience: the API refactoring changed the constructors, but I forgot to update the related javadocs. It should be fixed now.
Tommaso 2015-04-30 18:28 GMT+02:00 Ryan Ernst <[email protected]>: > Did you mean to remove javadocs on BooleanPerceptionClassifier? This > breaks precommit... > > On Thu, Apr 30, 2015 at 7:12 AM, <[email protected]> wrote: > >> Author: tommaso >> Date: Thu Apr 30 14:12:03 2015 >> New Revision: 1676997 >> >> URL: http://svn.apache.org/r1676997 >> Log: >> LUCENE-6045 - refactor Classifier API to work better with multithreading >> >> Modified: >> >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java >> >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java >> >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java >> >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java >> >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java >> >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java >> >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java >> >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java >> >> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java >> >> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java >> >> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java >> >> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java >> >> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java >> >> Modified: >> 
lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java >> URL: >> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff >> >> ============================================================================== >> --- >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java >> (original) >> +++ >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java >> Thu Apr 30 14:12:03 2015 >> @@ -58,76 +58,14 @@ import org.apache.lucene.util.fst.Util; >> */ >> public class BooleanPerceptronClassifier implements Classifier<Boolean> { >> >> - private Double threshold; >> - private final Integer batchSize; >> - private Terms textTerms; >> - private Analyzer analyzer; >> - private String textFieldName; >> + private final Double threshold; >> + private final Terms textTerms; >> + private final Analyzer analyzer; >> + private final String textFieldName; >> private FST<Long> fst; >> >> - /** >> - * Create a {@link BooleanPerceptronClassifier} >> - * >> - * @param threshold the binary threshold for perceptron output >> evaluation >> - */ >> - public BooleanPerceptronClassifier(Double threshold, Integer >> batchSize) { >> - this.threshold = threshold; >> - this.batchSize = batchSize; >> - } >> - >> - /** >> - * Default constructor, no batch updates of FST, perceptron threshold >> is >> - * calculated via underlying index metrics during >> - * {@link #train(org.apache.lucene.index.LeafReader, String, String, >> org.apache.lucene.analysis.Analyzer) >> - * training} >> - */ >> - public BooleanPerceptronClassifier() { >> - batchSize = 1; >> - } >> - >> - /** >> - * {@inheritDoc} >> - */ >> - @Override >> - public ClassificationResult<Boolean> assignClass(String text) >> - throws 
IOException { >> - if (textTerms == null) { >> - throw new IOException("You must first call Classifier#train"); >> - } >> - Long output = 0l; >> - try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, >> text)) { >> - CharTermAttribute charTermAttribute = tokenStream >> - .addAttribute(CharTermAttribute.class); >> - tokenStream.reset(); >> - while (tokenStream.incrementToken()) { >> - String s = charTermAttribute.toString(); >> - Long d = Util.get(fst, new BytesRef(s)); >> - if (d != null) { >> - output += d; >> - } >> - } >> - tokenStream.end(); >> - } >> - >> - double score = 1 - Math.exp(-1 * Math.abs(threshold - >> output.doubleValue()) / threshold); >> - return new ClassificationResult<>(output >= threshold, score); >> - } >> - >> - /** >> - * {@inheritDoc} >> - */ >> - @Override >> - public void train(LeafReader leafReader, String textFieldName, >> - String classFieldName, Analyzer analyzer) throws >> IOException { >> - train(leafReader, textFieldName, classFieldName, analyzer, null); >> - } >> - >> - /** >> - * {@inheritDoc} >> - */ >> - @Override >> - public void train(LeafReader leafReader, String textFieldName, >> - String classFieldName, Analyzer analyzer, Query >> query) throws IOException { >> + public BooleanPerceptronClassifier(LeafReader leafReader, String >> textFieldName, String classFieldName, Analyzer analyzer, >> + Query query, Integer batchSize, >> Double threshold) throws IOException { >> this.textTerms = MultiFields.getTerms(leafReader, textFieldName); >> >> if (textTerms == null) { >> @@ -144,9 +82,11 @@ public class BooleanPerceptronClassifier >> this.threshold = (double) sumDocFreq / 2d; >> } else { >> throw new IOException( >> - "threshold cannot be assigned since term vectors for field " >> - + textFieldName + " do not exist"); >> + "threshold cannot be assigned since term vectors for >> field " >> + + textFieldName + " do not exist"); >> } >> + } else { >> + this.threshold = threshold; >> } >> >> // TODO : remove this map 
as soon as we have a writable FST >> @@ -170,7 +110,7 @@ public class BooleanPerceptronClassifier >> } >> // run the search and use stored field values >> for (ScoreDoc scoreDoc : indexSearcher.search(q, >> - Integer.MAX_VALUE).scoreDocs) { >> + Integer.MAX_VALUE).scoreDocs) { >> StoredDocument doc = indexSearcher.doc(scoreDoc.doc); >> >> StorableField textField = doc.getField(textFieldName); >> @@ -187,7 +127,7 @@ public class BooleanPerceptronClassifier >> long modifier = correctClass.compareTo(assignedClass); >> if (modifier != 0) { >> updateWeights(leafReader, scoreDoc.doc, assignedClass, >> - weights, modifier, batchCount % batchSize == 0); >> + weights, modifier, batchCount % batchSize == 0); >> } >> batchCount++; >> } >> @@ -195,11 +135,6 @@ public class BooleanPerceptronClassifier >> weights.clear(); // free memory while waiting for GC >> } >> >> - @Override >> - public void train(LeafReader leafReader, String[] textFieldNames, >> String classFieldName, Analyzer analyzer, Query query) throws IOException { >> - throw new IOException("training with multiple fields not supported >> by boolean perceptron classifier"); >> - } >> - >> private void updateWeights(LeafReader leafReader, >> int docId, Boolean assignedClass, >> SortedMap<String, Double> weights, >> double modifier, boolean updateFST) throws >> IOException { >> @@ -210,7 +145,7 @@ public class BooleanPerceptronClassifier >> >> if (terms == null) { >> throw new IOException("term vectors must be stored for field " >> - + textFieldName); >> + + textFieldName); >> } >> >> TermsEnum termsEnum = terms.iterator(); >> @@ -240,17 +175,46 @@ public class BooleanPerceptronClassifier >> for (Map.Entry<String, Double> entry : weights.entrySet()) { >> scratchBytes.copyChars(entry.getKey()); >> fstBuilder.add(Util.toIntsRef(scratchBytes.get(), scratchInts), >> entry >> - .getValue().longValue()); >> + .getValue().longValue()); >> } >> fst = fstBuilder.finish(); >> } >> >> + >> + /** >> + * {@inheritDoc} >> + */ >> + 
@Override >> + public ClassificationResult<Boolean> assignClass(String text) >> + throws IOException { >> + if (textTerms == null) { >> + throw new IOException("You must first call Classifier#train"); >> + } >> + Long output = 0l; >> + try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, >> text)) { >> + CharTermAttribute charTermAttribute = tokenStream >> + .addAttribute(CharTermAttribute.class); >> + tokenStream.reset(); >> + while (tokenStream.incrementToken()) { >> + String s = charTermAttribute.toString(); >> + Long d = Util.get(fst, new BytesRef(s)); >> + if (d != null) { >> + output += d; >> + } >> + } >> + tokenStream.end(); >> + } >> + >> + double score = 1 - Math.exp(-1 * Math.abs(threshold - >> output.doubleValue()) / threshold); >> + return new ClassificationResult<>(output >= threshold, score); >> + } >> + >> /** >> * {@inheritDoc} >> */ >> @Override >> public List<ClassificationResult<Boolean>> getClasses(String text) >> - throws IOException { >> + throws IOException { >> throw new RuntimeException("not implemented"); >> } >> >> @@ -259,7 +223,7 @@ public class BooleanPerceptronClassifier >> */ >> @Override >> public List<ClassificationResult<Boolean>> getClasses(String text, int >> max) >> - throws IOException { >> + throws IOException { >> throw new RuntimeException("not implemented"); >> } >> >> >> Modified: >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java >> URL: >> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff >> >> ============================================================================== >> --- >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java >> (original) >> +++ >> 
lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java >> Thu Apr 30 14:12:03 2015 >> @@ -49,50 +49,30 @@ import org.apache.lucene.util.BytesRef; >> */ >> public class CachingNaiveBayesClassifier extends >> SimpleNaiveBayesClassifier { >> //for caching classes this will be the classification class list >> - private ArrayList<BytesRef> cclasses = new ArrayList<>(); >> + private final ArrayList<BytesRef> cclasses = new ArrayList<>(); >> // it's a term-inmap style map, where the inmap contains class-hit >> pairs to the >> // upper term >> - private Map<String, Map<BytesRef, Integer>> termCClassHitCache = new >> HashMap<>(); >> + private final Map<String, Map<BytesRef, Integer>> termCClassHitCache = >> new HashMap<>(); >> // the term frequency in classes >> - private Map<BytesRef, Double> classTermFreq = new HashMap<>(); >> + private final Map<BytesRef, Double> classTermFreq = new HashMap<>(); >> private boolean justCachedTerms; >> private int docsWithClassSize; >> >> /** >> - * Creates a new NaiveBayes classifier with inside caching. Note that >> you must >> - * call {@link #train(org.apache.lucene.index.LeafReader, String, >> String, Analyzer) train()} before >> - * you can classify any documents. If you want less memory usage you >> could >> + * Creates a new NaiveBayes classifier with inside caching. If you >> want less memory usage you could >> * call {@link #reInitCache(int, boolean) reInitCache()}. 
>> */ >> - public CachingNaiveBayesClassifier() { >> - } >> - >> - /** >> - * {@inheritDoc} >> - */ >> - @Override >> - public void train(LeafReader leafReader, String textFieldName, String >> classFieldName, Analyzer analyzer) throws IOException { >> - train(leafReader, textFieldName, classFieldName, analyzer, null); >> - } >> - >> - /** >> - * {@inheritDoc} >> - */ >> - @Override >> - public void train(LeafReader leafReader, String textFieldName, String >> classFieldName, Analyzer analyzer, Query query) throws IOException { >> - train(leafReader, new String[]{textFieldName}, classFieldName, >> analyzer, query); >> - } >> - >> - /** >> - * {@inheritDoc} >> - */ >> - @Override >> - public void train(LeafReader leafReader, String[] textFieldNames, >> String classFieldName, Analyzer analyzer, Query query) throws IOException { >> - super.train(leafReader, textFieldNames, classFieldName, analyzer, >> query); >> + public CachingNaiveBayesClassifier(LeafReader leafReader, Analyzer >> analyzer, Query query, String classFieldName, String... 
textFieldNames) { >> + super(leafReader, analyzer, query, classFieldName, textFieldNames); >> // building the cache >> - reInitCache(0, true); >> + try { >> + reInitCache(0, true); >> + } catch (IOException e) { >> + throw new RuntimeException(e); >> + } >> } >> >> + >> private List<ClassificationResult<BytesRef>> >> assignClassNormalizedList(String inputDocument) throws IOException { >> if (leafReader == null) { >> throw new IOException("You must first call Classifier#train"); >> >> Modified: >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java >> URL: >> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java?rev=1676997&r1=1676996&r2=1676997&view=diff >> >> ============================================================================== >> --- >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java >> (original) >> +++ >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java >> Thu Apr 30 14:12:03 2015 >> @@ -18,17 +18,19 @@ package org.apache.lucene.classification >> >> /** >> * The result of a call to {@link Classifier#assignClass(String)} >> holding an assigned class of type <code>T</code> and a score. 
>> + * >> * @lucene.experimental >> */ >> -public class ClassificationResult<T> implements >> Comparable<ClassificationResult<T>>{ >> +public class ClassificationResult<T> implements >> Comparable<ClassificationResult<T>> { >> >> private final T assignedClass; >> private double score; >> >> /** >> * Constructor >> + * >> * @param assignedClass the class <code>T</code> assigned by a {@link >> Classifier} >> - * @param score the score for the assignedClass as a >> <code>double</code> >> + * @param score the score for the assignedClass as a >> <code>double</code> >> */ >> public ClassificationResult(T assignedClass, double score) { >> this.assignedClass = assignedClass; >> @@ -37,6 +39,7 @@ public class ClassificationResult<T> imp >> >> /** >> * retrieve the result class >> + * >> * @return a <code>T</code> representing an assigned class >> */ >> public T getAssignedClass() { >> @@ -45,14 +48,16 @@ public class ClassificationResult<T> imp >> >> /** >> * retrieve the result score >> + * >> * @return a <code>double</code> representing a result score >> */ >> public double getScore() { >> return score; >> } >> - >> + >> /** >> * set the score value >> + * >> * @param score the score for the assignedClass as a >> <code>double</code> >> */ >> public void setScore(double score) { >> >> Modified: >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java >> URL: >> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff >> >> ============================================================================== >> --- >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java >> (original) >> +++ >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java >> Thu Apr 30 14:12:03 2015 >> @@ -22,7 +22,6 @@ import java.util.List; >> import 
org.apache.lucene.analysis.Analyzer; >> import org.apache.lucene.index.LeafReader; >> import org.apache.lucene.search.Query; >> -import org.apache.lucene.util.BytesRef; >> >> /** >> * A classifier, see <code> >> http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>, which >> assign classes of type >> @@ -39,7 +38,7 @@ public interface Classifier<T> { >> * @return a {@link ClassificationResult} holding assigned class of >> type <code>T</code> and score >> * @throws IOException If there is a low-level I/O error. >> */ >> - public ClassificationResult<T> assignClass(String text) throws >> IOException; >> + ClassificationResult<T> assignClass(String text) throws IOException; >> >> /** >> * Get all the classes (sorted by score, descending) assigned to the >> given text String. >> @@ -48,7 +47,7 @@ public interface Classifier<T> { >> * @return the whole list of {@link ClassificationResult}, the classes >> and scores. Returns <code>null</code> if the classifier can't make lists. >> * @throws IOException If there is a low-level I/O error. >> */ >> - public List<ClassificationResult<T>> getClasses(String text) throws >> IOException; >> + List<ClassificationResult<T>> getClasses(String text) throws >> IOException; >> >> /** >> * Get the first <code>max</code> classes (sorted by score, >> descending) assigned to the given text String. >> @@ -58,44 +57,6 @@ public interface Classifier<T> { >> * @return the whole list of {@link ClassificationResult}, the classes >> and scores. Cut for "max" number of elements. Returns <code>null</code> if >> the classifier can't make lists. >> * @throws IOException If there is a low-level I/O error. 
>> */ >> - public List<ClassificationResult<T>> getClasses(String text, int max) >> throws IOException; >> - >> - /** >> - * Train the classifier using the underlying Lucene index >> - * >> - * @param leafReader the reader to use to access the Lucene index >> - * @param textFieldName the name of the field used to compare >> documents >> - * @param classFieldName the name of the field containing the class >> assigned to documents >> - * @param analyzer the analyzer used to tokenize / filter the >> unseen text >> - * @throws IOException If there is a low-level I/O error. >> - */ >> - public void train(LeafReader leafReader, String textFieldName, String >> classFieldName, Analyzer analyzer) >> - throws IOException; >> - >> - /** >> - * Train the classifier using the underlying Lucene index >> - * >> - * @param leafReader the reader to use to access the Lucene index >> - * @param textFieldName the name of the field used to compare >> documents >> - * @param classFieldName the name of the field containing the class >> assigned to documents >> - * @param analyzer the analyzer used to tokenize / filter the >> unseen text >> - * @param query the query to filter which documents use for >> training >> - * @throws IOException If there is a low-level I/O error. >> - */ >> - public void train(LeafReader leafReader, String textFieldName, String >> classFieldName, Analyzer analyzer, Query query) >> - throws IOException; >> - >> - /** >> - * Train the classifier using the underlying Lucene index >> - * >> - * @param leafReader the reader to use to access the Lucene index >> - * @param textFieldNames the names of the fields to be used to compare >> documents >> - * @param classFieldName the name of the field containing the class >> assigned to documents >> - * @param analyzer the analyzer used to tokenize / filter the >> unseen text >> - * @param query the query to filter which documents use for >> training >> - * @throws IOException If there is a low-level I/O error. 
>> - */ >> - public void train(LeafReader leafReader, String[] textFieldNames, >> String classFieldName, Analyzer analyzer, Query query) >> - throws IOException; >> + List<ClassificationResult<T>> getClasses(String text, int max) throws >> IOException; >> >> } >> >> Modified: >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java >> URL: >> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff >> >> ============================================================================== >> --- >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java >> (original) >> +++ >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java >> Thu Apr 30 14:12:03 2015 >> @@ -26,6 +26,7 @@ import java.util.Map; >> >> import org.apache.lucene.analysis.Analyzer; >> import org.apache.lucene.index.LeafReader; >> +import org.apache.lucene.index.StorableField; >> import org.apache.lucene.index.Term; >> import org.apache.lucene.queries.mlt.MoreLikeThis; >> import org.apache.lucene.search.BooleanClause; >> @@ -45,37 +46,31 @@ import org.apache.lucene.util.BytesRef; >> */ >> public class KNearestNeighborClassifier implements Classifier<BytesRef> { >> >> - private MoreLikeThis mlt; >> - private String[] textFieldNames; >> - private String classFieldName; >> - private IndexSearcher indexSearcher; >> + private final MoreLikeThis mlt; >> + private final String[] textFieldNames; >> + private final String classFieldName; >> + private final IndexSearcher indexSearcher; >> private final int k; >> - private Query query; >> + private final Query query; >> >> - private int minDocsFreq; >> - private int minTermFreq; >> - >> - /** >> - * Create a {@link Classifier} using kNN algorithm >> - * >> - * @param k 
the number of neighbors to analyze as an <code>int</code> >> - */ >> - public KNearestNeighborClassifier(int k) { >> + public KNearestNeighborClassifier(LeafReader leafReader, Analyzer >> analyzer, Query query, int k, int minDocsFreq, >> + int minTermFreq, String >> classFieldName, String... textFieldNames) { >> + this.textFieldNames = textFieldNames; >> + this.classFieldName = classFieldName; >> + this.mlt = new MoreLikeThis(leafReader); >> + this.mlt.setAnalyzer(analyzer); >> + this.mlt.setFieldNames(textFieldNames); >> + this.indexSearcher = new IndexSearcher(leafReader); >> + if (minDocsFreq > 0) { >> + mlt.setMinDocFreq(minDocsFreq); >> + } >> + if (minTermFreq > 0) { >> + mlt.setMinTermFreq(minTermFreq); >> + } >> + this.query = query; >> this.k = k; >> } >> >> - /** >> - * Create a {@link Classifier} using kNN algorithm >> - * >> - * @param k the number of neighbors to analyze as an >> <code>int</code> >> - * @param minDocsFreq the minimum number of docs frequency for MLT to >> be set with {@link MoreLikeThis#setMinDocFreq(int)} >> - * @param minTermFreq the minimum number of term frequency for MLT to >> be set with {@link MoreLikeThis#setMinTermFreq(int)} >> - */ >> - public KNearestNeighborClassifier(int k, int minDocsFreq, int >> minTermFreq) { >> - this.k = k; >> - this.minDocsFreq = minDocsFreq; >> - this.minTermFreq = minTermFreq; >> - } >> >> /** >> * {@inheritDoc} >> @@ -136,12 +131,15 @@ public class KNearestNeighborClassifier >> private List<ClassificationResult<BytesRef>> >> buildListFromTopDocs(TopDocs topDocs) throws IOException { >> Map<BytesRef, Integer> classCounts = new HashMap<>(); >> for (ScoreDoc scoreDoc : topDocs.scoreDocs) { >> - BytesRef cl = new >> BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue()); >> - Integer count = classCounts.get(cl); >> - if (count != null) { >> - classCounts.put(cl, count + 1); >> - } else { >> - classCounts.put(cl, 1); >> + StorableField storableField = >> 
indexSearcher.doc(scoreDoc.doc).getField(classFieldName); >> + if (storableField != null) { >> + BytesRef cl = new BytesRef(storableField.stringValue()); >> + Integer count = classCounts.get(cl); >> + if (count != null) { >> + classCounts.put(cl, count + 1); >> + } else { >> + classCounts.put(cl, 1); >> + } >> } >> } >> List<ClassificationResult<BytesRef>> returnList = new ArrayList<>(); >> @@ -161,39 +159,4 @@ public class KNearestNeighborClassifier >> return returnList; >> } >> >> - /** >> - * {@inheritDoc} >> - */ >> - @Override >> - public void train(LeafReader leafReader, String textFieldName, String >> classFieldName, Analyzer analyzer) throws IOException { >> - train(leafReader, textFieldName, classFieldName, analyzer, null); >> - } >> - >> - /** >> - * {@inheritDoc} >> - */ >> - @Override >> - public void train(LeafReader leafReader, String textFieldName, String >> classFieldName, Analyzer analyzer, Query query) throws IOException { >> - train(leafReader, new String[]{textFieldName}, classFieldName, >> analyzer, query); >> - } >> - >> - /** >> - * {@inheritDoc} >> - */ >> - @Override >> - public void train(LeafReader leafReader, String[] textFieldNames, >> String classFieldName, Analyzer analyzer, Query query) throws IOException { >> - this.textFieldNames = textFieldNames; >> - this.classFieldName = classFieldName; >> - mlt = new MoreLikeThis(leafReader); >> - mlt.setAnalyzer(analyzer); >> - mlt.setFieldNames(textFieldNames); >> - indexSearcher = new IndexSearcher(leafReader); >> - if (minDocsFreq > 0) { >> - mlt.setMinDocFreq(minDocsFreq); >> - } >> - if (minTermFreq > 0) { >> - mlt.setMinTermFreq(minTermFreq); >> - } >> - this.query = query; >> - } >> } >> >> Modified: >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java >> URL: >> 
http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff >> >> ============================================================================== >> --- >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java >> (original) >> +++ >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java >> Thu Apr 30 14:12:03 2015 >> @@ -51,64 +51,38 @@ public class SimpleNaiveBayesClassifier >> * {@link org.apache.lucene.index.LeafReader} used to access the >> {@link org.apache.lucene.classification.Classifier}'s >> * index >> */ >> - protected LeafReader leafReader; >> + protected final LeafReader leafReader; >> >> /** >> * names of the fields to be used as input text >> */ >> - protected String[] textFieldNames; >> + protected final String[] textFieldNames; >> >> /** >> * name of the field to be used as a class / category output >> */ >> - protected String classFieldName; >> + protected final String classFieldName; >> >> /** >> * {@link org.apache.lucene.analysis.Analyzer} to be used for >> tokenizing unseen input text >> */ >> - protected Analyzer analyzer; >> + protected final Analyzer analyzer; >> >> /** >> * {@link org.apache.lucene.search.IndexSearcher} to run searches on >> the index for retrieving frequencies >> */ >> - protected IndexSearcher indexSearcher; >> + protected final IndexSearcher indexSearcher; >> >> /** >> * {@link org.apache.lucene.search.Query} used to eventually filter >> the document set to be used to classify >> */ >> - protected Query query; >> + protected final Query query; >> >> /** >> * Creates a new NaiveBayes classifier. >> - * Note that you must call {@link >> #train(org.apache.lucene.index.LeafReader, String, String, Analyzer) >> train()} before you can >> * classify any documents. 
>> */ >> - public SimpleNaiveBayesClassifier() { >> - } >> - >> - /** >> - * {@inheritDoc} >> - */ >> - @Override >> - public void train(LeafReader leafReader, String textFieldName, String >> classFieldName, Analyzer analyzer) throws IOException { >> - train(leafReader, textFieldName, classFieldName, analyzer, null); >> - } >> - >> - /** >> - * {@inheritDoc} >> - */ >> - @Override >> - public void train(LeafReader leafReader, String textFieldName, String >> classFieldName, Analyzer analyzer, Query query) >> - throws IOException { >> - train(leafReader, new String[]{textFieldName}, classFieldName, >> analyzer, query); >> - } >> - >> - /** >> - * {@inheritDoc} >> - */ >> - @Override >> - public void train(LeafReader leafReader, String[] textFieldNames, >> String classFieldName, Analyzer analyzer, Query query) >> - throws IOException { >> + public SimpleNaiveBayesClassifier(LeafReader leafReader, Analyzer >> analyzer, Query query, String classFieldName, String... textFieldNames) { >> this.leafReader = leafReader; >> this.indexSearcher = new IndexSearcher(this.leafReader); >> this.textFieldNames = textFieldNames; >> >> Modified: >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java >> URL: >> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java?rev=1676997&r1=1676996&r2=1676997&view=diff >> >> ============================================================================== >> --- >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java >> (original) >> +++ >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java >> Thu Apr 30 14:12:03 2015 >> @@ -18,7 +18,7 @@ >> /** >> * Uses already seen data (the indexed documents) to classify new >> documents. 
>> * <p> >> - * Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest >> + * Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest >> * Neighbor classifier and a Perceptron based classifier. >> */ >> package org.apache.lucene.classification; >> >> Modified: >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java >> URL: >> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java?rev=1676997&r1=1676996&r2=1676997&view=diff >> >> ============================================================================== >> --- >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java >> (original) >> +++ >> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java >> Thu Apr 30 14:12:03 2015 >> @@ -33,7 +33,8 @@ public class DocToDoubleVectorUtils { >> >> /** >> * create a sparse <code>Double</code> vector given doc and field term >> vectors using local frequency of the terms in the doc >> - * @param docTerms term vectors for a given document >> + * >> + * @param docTerms term vectors for a given document >> * @param fieldTerms field term vectors >> * @return a sparse vector of <code>Double</code>s as an array >> * @throws IOException in case accessing the underlying index fails >> @@ -54,8 +55,7 @@ public class DocToDoubleVectorUtils { >> if (seekStatus.equals(TermsEnum.SeekStatus.FOUND)) { >> long termFreqLocal = docTermsEnum.totalTermFreq(); // the >> total number of occurrences of this term in the given document >> freqVector[i] = Long.valueOf(termFreqLocal).doubleValue(); >> - } >> - else { >> + } else { >> freqVector[i] = 0d; >> } >> i++; >> @@ -66,6 +66,7 @@ public class DocToDoubleVectorUtils { >> >> /** >> * create a dense <code>Double</code> vector given doc and field term >> 
vectors using local frequency of the terms in the doc >> + * >> * @param docTerms term vectors for a given document >> * @return a dense vector of <code>Double</code>s as an array >> * @throws IOException in case accessing the underlying index fails >> @@ -73,16 +74,16 @@ public class DocToDoubleVectorUtils { >> public static Double[] toDenseLocalFreqDoubleArray(Terms docTerms) >> throws IOException { >> Double[] freqVector = null; >> if (docTerms != null) { >> - freqVector = new Double[(int) docTerms.size()]; >> - int i = 0; >> - TermsEnum docTermsEnum = docTerms.iterator(); >> + freqVector = new Double[(int) docTerms.size()]; >> + int i = 0; >> + TermsEnum docTermsEnum = docTerms.iterator(); >> >> - while (docTermsEnum.next() != null) { >> - long termFreqLocal = docTermsEnum.totalTermFreq(); // the >> total number of occurrences of this term in the given document >> - freqVector[i] = Long.valueOf(termFreqLocal).doubleValue(); >> - i++; >> - } >> + while (docTermsEnum.next() != null) { >> + long termFreqLocal = docTermsEnum.totalTermFreq(); // the total >> number of occurrences of this term in the given document >> + freqVector[i] = Long.valueOf(termFreqLocal).doubleValue(); >> + i++; >> + } >> } >> return freqVector; >> -} >> + } >> } >> >> Modified: >> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java >> URL: >> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff >> >> ============================================================================== >> --- >> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java >> (original) >> +++ >> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java >> Thu Apr 30 14:12:03 2015 >> @@ -17,6 +17,8 @@ 
>> package org.apache.lucene.classification; >> >> import org.apache.lucene.analysis.MockAnalyzer; >> +import org.apache.lucene.index.LeafReader; >> +import org.apache.lucene.index.SlowCompositeReaderWrapper; >> import org.apache.lucene.index.Term; >> import org.apache.lucene.search.TermQuery; >> import org.junit.Test; >> @@ -28,22 +30,45 @@ public class BooleanPerceptronClassifier >> >> @Test >> public void testBasicUsage() throws Exception { >> - checkCorrectClassification(new BooleanPerceptronClassifier(), >> TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, >> booleanFieldName); >> + LeafReader leafReader = null; >> + try { >> + MockAnalyzer analyzer = new MockAnalyzer(random()); >> + leafReader = populateSampleIndex(analyzer); >> + checkCorrectClassification(new >> BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, >> analyzer, null, 1, null), TECHNOLOGY_INPUT, false); >> + } finally { >> + if (leafReader != null) { >> + leafReader.close(); >> + } >> + } >> } >> >> @Test >> public void testExplicitThreshold() throws Exception { >> - checkCorrectClassification(new BooleanPerceptronClassifier(100d, 1), >> TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, >> booleanFieldName); >> + LeafReader leafReader = null; >> + try { >> + MockAnalyzer analyzer = new MockAnalyzer(random()); >> + leafReader = populateSampleIndex(analyzer); >> + checkCorrectClassification(new >> BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, >> analyzer, null, 1, 100d), TECHNOLOGY_INPUT, false); >> + } finally { >> + if (leafReader != null) { >> + leafReader.close(); >> + } >> + } >> } >> >> @Test >> public void testBasicUsageWithQuery() throws Exception { >> - checkCorrectClassification(new BooleanPerceptronClassifier(), >> TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, >> booleanFieldName, new TermQuery(new Term(textFieldName, "it"))); >> - } >> - >> - @Test >> - public void 
testPerformance() throws Exception { >> - checkPerformance(new BooleanPerceptronClassifier(), new >> MockAnalyzer(random()), booleanFieldName); >> + TermQuery query = new TermQuery(new Term(textFieldName, "it")); >> + LeafReader leafReader = null; >> + try { >> + MockAnalyzer analyzer = new MockAnalyzer(random()); >> + leafReader = populateSampleIndex(analyzer); >> + checkCorrectClassification(new >> BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, >> analyzer, query, 1, null), TECHNOLOGY_INPUT, false); >> + } finally { >> + if (leafReader != null) { >> + leafReader.close(); >> + } >> + } >> } >> >> } >> >> Modified: >> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java >> URL: >> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff >> >> ============================================================================== >> --- >> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java >> (original) >> +++ >> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java >> Thu Apr 30 14:12:03 2015 >> @@ -23,6 +23,8 @@ import org.apache.lucene.analysis.Tokeni >> import org.apache.lucene.analysis.core.KeywordTokenizer; >> import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter; >> import org.apache.lucene.analysis.reverse.ReverseStringFilter; >> +import org.apache.lucene.index.LeafReader; >> +import org.apache.lucene.index.SlowCompositeReaderWrapper; >> import org.apache.lucene.index.Term; >> import org.apache.lucene.search.TermQuery; >> import org.apache.lucene.util.BytesRef; >> @@ -35,18 +37,46 @@ public class CachingNaiveBayesClassifier >> >> @Test >> public void testBasicUsage() throws Exception { >> - 
checkCorrectClassification(new CachingNaiveBayesClassifier(), >> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), >> textFieldName, categoryFieldName); >> - checkCorrectClassification(new CachingNaiveBayesClassifier(), >> POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, >> categoryFieldName); >> + LeafReader leafReader = null; >> + try { >> + MockAnalyzer analyzer = new MockAnalyzer(random()); >> + leafReader = populateSampleIndex(analyzer); >> + checkCorrectClassification(new >> CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, >> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT); >> + checkCorrectClassification(new >> CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, >> textFieldName), POLITICS_INPUT, POLITICS_RESULT); >> + } finally { >> + if (leafReader != null) { >> + leafReader.close(); >> + } >> + } >> } >> >> @Test >> public void testBasicUsageWithQuery() throws Exception { >> - checkCorrectClassification(new CachingNaiveBayesClassifier(), >> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), >> textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, >> "it"))); >> + LeafReader leafReader = null; >> + try { >> + MockAnalyzer analyzer = new MockAnalyzer(random()); >> + leafReader = populateSampleIndex(analyzer); >> + TermQuery query = new TermQuery(new Term(textFieldName, "it")); >> + checkCorrectClassification(new >> CachingNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName, >> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT); >> + } finally { >> + if (leafReader != null) { >> + leafReader.close(); >> + } >> + } >> } >> >> @Test >> public void testNGramUsage() throws Exception { >> - checkCorrectClassification(new CachingNaiveBayesClassifier(), >> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, >> categoryFieldName); >> + LeafReader leafReader = null; >> + try { >> + NGramAnalyzer 
analyzer = new NGramAnalyzer(); >> + leafReader = populateSampleIndex(analyzer); >> + checkCorrectClassification(new >> CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, >> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT); >> + } finally { >> + if (leafReader != null) { >> + leafReader.close(); >> + } >> + } >> } >> >> private class NGramAnalyzer extends Analyzer { >> @@ -57,9 +87,4 @@ public class CachingNaiveBayesClassifier >> } >> } >> >> - @Test >> - public void testPerformance() throws Exception { >> - checkPerformance(new CachingNaiveBayesClassifier(), new >> MockAnalyzer(random()), categoryFieldName); >> - } >> - >> } >> >> Modified: >> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java >> URL: >> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java?rev=1676997&r1=1676996&r2=1676997&view=diff >> >> ============================================================================== >> --- >> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java >> (original) >> +++ >> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java >> Thu Apr 30 14:12:03 2015 >> @@ -41,14 +41,14 @@ import org.junit.Before; >> */ >> public abstract class ClassificationTestBase<T> extends LuceneTestCase { >> public final static String POLITICS_INPUT = "Here are some interesting >> questions and answers about Mitt Romney.. 
" + >> - "If you don't know the answer to the question about Mitt Romney, >> then simply click on the answer below the question section."; >> + "If you don't know the answer to the question about Mitt >> Romney, then simply click on the answer below the question section."; >> public static final BytesRef POLITICS_RESULT = new >> BytesRef("politics"); >> >> public static final String TECHNOLOGY_INPUT = "Much is made of what >> the likes of Facebook, Google and Apple know about users." + >> - " Truth is, Amazon may know more."; >> + " Truth is, Amazon may know more."; >> public static final BytesRef TECHNOLOGY_RESULT = new >> BytesRef("technology"); >> >> - private RandomIndexWriter indexWriter; >> + protected RandomIndexWriter indexWriter; >> private Directory dir; >> private FieldType ft; >> >> @@ -79,53 +79,34 @@ public abstract class ClassificationTest >> dir.close(); >> } >> >> - protected void checkCorrectClassification(Classifier<T> classifier, >> String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, >> String classFieldName) throws Exception { >> - checkCorrectClassification(classifier, inputDoc, expectedResult, >> analyzer, textFieldName, classFieldName, null); >> + protected void checkCorrectClassification(Classifier<T> classifier, >> String inputDoc, T expectedResult) throws Exception { >> + ClassificationResult<T> classificationResult = >> classifier.assignClass(inputDoc); >> + assertNotNull(classificationResult.getAssignedClass()); >> + assertEquals("got an assigned class of " + >> classificationResult.getAssignedClass(), expectedResult, >> classificationResult.getAssignedClass()); >> + double score = classificationResult.getScore(); >> + assertTrue("score should be between 0 and 1, got:" + score, score <= >> 1 && score >= 0); >> } >> >> - protected void checkCorrectClassification(Classifier<T> classifier, >> String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, >> String classFieldName, Query query) throws 
Exception { >> - LeafReader leafReader = null; >> - try { >> - populateSampleIndex(analyzer); >> - leafReader = >> SlowCompositeReaderWrapper.wrap(indexWriter.getReader()); >> - classifier.train(leafReader, textFieldName, classFieldName, >> analyzer, query); >> - ClassificationResult<T> classificationResult = >> classifier.assignClass(inputDoc); >> - assertNotNull(classificationResult.getAssignedClass()); >> - assertEquals("got an assigned class of " + >> classificationResult.getAssignedClass(), expectedResult, >> classificationResult.getAssignedClass()); >> - double score = classificationResult.getScore(); >> - assertTrue("score should be between 0 and 1, got:" + score, score >> <= 1 && score >= 0); >> - } finally { >> - if (leafReader != null) >> - leafReader.close(); >> - } >> - } >> protected void checkOnlineClassification(Classifier<T> classifier, >> String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, >> String classFieldName) throws Exception { >> checkOnlineClassification(classifier, inputDoc, expectedResult, >> analyzer, textFieldName, classFieldName, null); >> } >> >> protected void checkOnlineClassification(Classifier<T> classifier, >> String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, >> String classFieldName, Query query) throws Exception { >> - LeafReader leafReader = null; >> - try { >> - populateSampleIndex(analyzer); >> - leafReader = >> SlowCompositeReaderWrapper.wrap(indexWriter.getReader()); >> - classifier.train(leafReader, textFieldName, classFieldName, >> analyzer, query); >> - ClassificationResult<T> classificationResult = >> classifier.assignClass(inputDoc); >> - assertNotNull(classificationResult.getAssignedClass()); >> - assertEquals("got an assigned class of " + >> classificationResult.getAssignedClass(), expectedResult, >> classificationResult.getAssignedClass()); >> - double score = classificationResult.getScore(); >> - assertTrue("score should be between 0 and 1, got: " + score, score >> 
<= 1 && score >= 0); >> - updateSampleIndex(); >> - ClassificationResult<T> secondClassificationResult = >> classifier.assignClass(inputDoc); >> - assertEquals(classificationResult.getAssignedClass(), >> secondClassificationResult.getAssignedClass()); >> - assertEquals(Double.valueOf(score), >> Double.valueOf(secondClassificationResult.getScore())); >> - >> - } finally { >> - if (leafReader != null) >> - leafReader.close(); >> - } >> + populateSampleIndex(analyzer); >> + >> + ClassificationResult<T> classificationResult = >> classifier.assignClass(inputDoc); >> + assertNotNull(classificationResult.getAssignedClass()); >> + assertEquals("got an assigned class of " + >> classificationResult.getAssignedClass(), expectedResult, >> classificationResult.getAssignedClass()); >> + double score = classificationResult.getScore(); >> + assertTrue("score should be between 0 and 1, got: " + score, score >> <= 1 && score >= 0); >> + updateSampleIndex(); >> + ClassificationResult<T> secondClassificationResult = >> classifier.assignClass(inputDoc); >> + assertEquals(classificationResult.getAssignedClass(), >> secondClassificationResult.getAssignedClass()); >> + assertEquals(Double.valueOf(score), >> Double.valueOf(secondClassificationResult.getScore())); >> + >> } >> >> - private void populateSampleIndex(Analyzer analyzer) throws IOException >> { >> + protected LeafReader populateSampleIndex(Analyzer analyzer) throws >> IOException { >> indexWriter.close(); >> indexWriter = new RandomIndexWriter(random(), dir, >> newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE)); >> indexWriter.commit(); >> @@ -134,8 +115,8 @@ public abstract class ClassificationTest >> >> Document doc = new Document(); >> text = "The traveling press secretary for Mitt Romney lost his cool >> and cursed at reporters " + >> - "who attempted to ask questions of the Republican presidential >> candidate in a public plaza near the Tomb of " + >> - "the Unknown Soldier in Warsaw Tuesday."; 
>> + "who attempted to ask questions of the Republican >> presidential candidate in a public plaza near the Tomb of " + >> + "the Unknown Soldier in Warsaw Tuesday."; >> doc.add(new Field(textFieldName, text, ft)); >> doc.add(new Field(categoryFieldName, "politics", ft)); >> doc.add(new Field(booleanFieldName, "true", ft)); >> @@ -144,7 +125,7 @@ public abstract class ClassificationTest >> >> doc = new Document(); >> text = "Mitt Romney seeks to assure Israel and Iran, as well as >> Jewish voters in the United" + >> - " States, that he will be tougher against Iran's nuclear >> ambitions than President Barack Obama."; >> + " States, that he will be tougher against Iran's nuclear >> ambitions than President Barack Obama."; >> doc.add(new Field(textFieldName, text, ft)); >> doc.add(new Field(categoryFieldName, "politics", ft)); >> doc.add(new Field(booleanFieldName, "true", ft)); >> @@ -152,8 +133,8 @@ public abstract class ClassificationTest >> >> doc = new Document(); >> text = "And there's a threshold question that he has to answer for >> the American people and " + >> - "that's whether he is prepared to be commander-in-chief,\" she >> continued. \"As we look to the past events, we " + >> - "know that this raises some questions about his preparedness and >> we'll see how the rest of his trip goes.\""; >> + "that's whether he is prepared to be commander-in-chief,\" >> she continued. 
\"As we look to the past events, we " + >> + "know that this raises some questions about his preparedness >> and we'll see how the rest of his trip goes.\""; >> doc.add(new Field(textFieldName, text, ft)); >> doc.add(new Field(categoryFieldName, "politics", ft)); >> doc.add(new Field(booleanFieldName, "true", ft)); >> @@ -161,8 +142,8 @@ public abstract class ClassificationTest >> >> doc = new Document(); >> text = "Still, when it comes to gun policy, many congressional >> Democrats have \"decided to " + >> - "keep quiet and not go there,\" said Alan Lizotte, dean and >> professor at the State University of New York at " + >> - "Albany's School of Criminal Justice."; >> + "keep quiet and not go there,\" said Alan Lizotte, dean and >> professor at the State University of New York at " + >> + "Albany's School of Criminal Justice."; >> doc.add(new Field(textFieldName, text, ft)); >> doc.add(new Field(categoryFieldName, "politics", ft)); >> doc.add(new Field(booleanFieldName, "true", ft)); >> @@ -170,8 +151,8 @@ public abstract class ClassificationTest >> >> doc = new Document(); >> text = "Standing amongst the thousands of people at the state >> Capitol, Jorstad, director of " + >> - "technology at the University of Wisconsin-La Crosse, documented >> the historic moment and shared it with the " + >> - "world through the Internet."; >> + "technology at the University of Wisconsin-La Crosse, >> documented the historic moment and shared it with the " + >> + "world through the Internet."; >> doc.add(new Field(textFieldName, text, ft)); >> doc.add(new Field(categoryFieldName, "technology", ft)); >> doc.add(new Field(booleanFieldName, "false", ft)); >> @@ -179,7 +160,7 @@ public abstract class ClassificationTest >> >> doc = new Document(); >> text = "So, about all those experts and analysts who've spent the >> past year or so saying " + >> - "Facebook was going to make a phone. 
A new expert has stepped >> forward to say it's not going to happen."; >> + "Facebook was going to make a phone. A new expert has >> stepped forward to say it's not going to happen."; >> doc.add(new Field(textFieldName, text, ft)); >> doc.add(new Field(categoryFieldName, "technology", ft)); >> doc.add(new Field(booleanFieldName, "false", ft)); >> @@ -187,8 +168,8 @@ public abstract class ClassificationTest >> >> doc = new Document(); >> text = "More than 400 million people trust Google with their e-mail, >> and 50 million store files" + >> - " in the cloud using the Dropbox service. People manage their >> bank accounts, pay bills, trade stocks and " + >> - "generally transfer or store huge volumes of personal data >> online."; >> + " in the cloud using the Dropbox service. People manage >> their bank accounts, pay bills, trade stocks and " + >> + "generally transfer or store huge volumes of personal data >> online."; >> doc.add(new Field(textFieldName, text, ft)); >> doc.add(new Field(categoryFieldName, "technology", ft)); >> doc.add(new Field(booleanFieldName, "false", ft)); >> @@ -200,22 +181,15 @@ public abstract class ClassificationTest >> indexWriter.addDocument(doc); >> >> indexWriter.commit(); >> + return SlowCompositeReaderWrapper.wrap(indexWriter.getReader()); >> } >> >> protected void checkPerformance(Classifier<T> classifier, Analyzer >> analyzer, String classFieldName) throws Exception { >> - LeafReader leafReader = null; >> long trainStart = System.currentTimeMillis(); >> - try { >> - populatePerformanceIndex(analyzer); >> - leafReader = >> SlowCompositeReaderWrapper.wrap(indexWriter.getReader()); >> - classifier.train(leafReader, textFieldName, classFieldName, >> analyzer); >> - long trainEnd = System.currentTimeMillis(); >> - long trainTime = trainEnd - trainStart; >> - assertTrue("training took more than 2 mins : " + trainTime / 1000 >> + "s", trainTime < 120000); >> - } finally { >> - if (leafReader != null) >> - leafReader.close(); >> - } >> + 
populatePerformanceIndex(analyzer); >> + long trainEnd = System.currentTimeMillis(); >> + long trainTime = trainEnd - trainStart; >> + assertTrue("training took more than 2 mins : " + trainTime / 1000 + >> "s", trainTime < 120000); >> } >> >> private void populatePerformanceIndex(Analyzer analyzer) throws >> IOException { >> >> Modified: >> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java >> URL: >> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff >> >> ============================================================================== >> --- >> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java >> (original) >> +++ >> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java >> Thu Apr 30 14:12:03 2015 >> @@ -17,6 +17,8 @@ >> package org.apache.lucene.classification; >> >> import org.apache.lucene.analysis.MockAnalyzer; >> +import org.apache.lucene.index.LeafReader; >> +import org.apache.lucene.index.SlowCompositeReaderWrapper; >> import org.apache.lucene.index.Term; >> import org.apache.lucene.search.TermQuery; >> import org.apache.lucene.util.BytesRef; >> @@ -29,20 +31,32 @@ public class KNearestNeighborClassifierT >> >> @Test >> public void testBasicUsage() throws Exception { >> - // usage with default MLT min docs / term freq >> - checkCorrectClassification(new KNearestNeighborClassifier(3), >> POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, >> categoryFieldName); >> - // usage without custom min docs / term freq for MLT >> - checkCorrectClassification(new KNearestNeighborClassifier(3, 2, 1), >> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), >> textFieldName, categoryFieldName); >> + 
LeafReader leafReader = null; >> + try { >> + MockAnalyzer analyzer = new MockAnalyzer(random()); >> + leafReader = populateSampleIndex(analyzer); >> + checkCorrectClassification(new >> KNearestNeighborClassifier(leafReader, analyzer, null, 1, 0, 0, >> categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT); >> + checkCorrectClassification(new >> KNearestNeighborClassifier(leafReader, analyzer, null, 3, 2, 1, >> categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT); >> + } finally { >> + if (leafReader != null) { >> + leafReader.close(); >> + } >> + } >> } >> >> @Test >> public void testBasicUsageWithQuery() throws Exception { >> - checkCorrectClassification(new KNearestNeighborClassifier(1), >> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), >> textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, >> "it"))); >> - } >> - >> - @Test >> - public void testPerformance() throws Exception { >> - checkPerformance(new KNearestNeighborClassifier(100), new >> MockAnalyzer(random()), categoryFieldName); >> + LeafReader leafReader = null; >> + try { >> + MockAnalyzer analyzer = new MockAnalyzer(random()); >> + leafReader = populateSampleIndex(analyzer); >> + TermQuery query = new TermQuery(new Term(textFieldName, "it")); >> + checkCorrectClassification(new >> KNearestNeighborClassifier(leafReader, analyzer, query, 1, 0, 0, >> categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT); >> + } finally { >> + if (leafReader != null) { >> + leafReader.close(); >> + } >> + } >> } >> >> } >> >> Modified: >> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java >> URL: >> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff >> >> 
============================================================================== >> --- >> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java >> (original) >> +++ >> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java >> Thu Apr 30 14:12:03 2015 >> @@ -22,14 +22,13 @@ import org.apache.lucene.analysis.Tokeni >> import org.apache.lucene.analysis.core.KeywordTokenizer; >> import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter; >> import org.apache.lucene.analysis.reverse.ReverseStringFilter; >> +import org.apache.lucene.index.LeafReader; >> +import org.apache.lucene.index.SlowCompositeReaderWrapper; >> import org.apache.lucene.index.Term; >> import org.apache.lucene.search.TermQuery; >> import org.apache.lucene.util.BytesRef; >> -import org.apache.lucene.util.LuceneTestCase; >> import org.junit.Test; >> >> -import java.io.Reader; >> - >> /** >> * Testcase for {@link SimpleNaiveBayesClassifier} >> */ >> @@ -37,18 +36,46 @@ public class SimpleNaiveBayesClassifierT >> >> @Test >> public void testBasicUsage() throws Exception { >> - checkCorrectClassification(new SimpleNaiveBayesClassifier(), >> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), >> textFieldName, categoryFieldName); >> - checkCorrectClassification(new SimpleNaiveBayesClassifier(), >> POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, >> categoryFieldName); >> + LeafReader leafReader = null; >> + try { >> + MockAnalyzer analyzer = new MockAnalyzer(random()); >> + leafReader = populateSampleIndex(analyzer); >> + checkCorrectClassification(new >> SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, >> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT); >> + checkCorrectClassification(new >> SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, >> textFieldName), POLITICS_INPUT, 
POLITICS_RESULT); >> + } finally { >> + if (leafReader != null) { >> + leafReader.close(); >> + } >> + } >> } >> >> @Test >> public void testBasicUsageWithQuery() throws Exception { >> - checkCorrectClassification(new SimpleNaiveBayesClassifier(), >> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), >> textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, >> "it"))); >> + LeafReader leafReader = null; >> + try { >> + MockAnalyzer analyzer = new MockAnalyzer(random()); >> + leafReader = populateSampleIndex(analyzer); >> + TermQuery query = new TermQuery(new Term(textFieldName, "it")); >> + checkCorrectClassification(new >> SimpleNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName, >> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT); >> + } finally { >> + if (leafReader != null) { >> + leafReader.close(); >> + } >> + } >> } >> >> @Test >> public void testNGramUsage() throws Exception { >> - checkCorrectClassification(new SimpleNaiveBayesClassifier(), >> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, >> categoryFieldName); >> + LeafReader leafReader = null; >> + try { >> + Analyzer analyzer = new NGramAnalyzer(); >> + leafReader = populateSampleIndex(analyzer); >> + checkCorrectClassification(new >> CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, >> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT); >> + } finally { >> + if (leafReader != null) { >> + leafReader.close(); >> + } >> + } >> } >> >> private class NGramAnalyzer extends Analyzer { >> @@ -59,9 +86,4 @@ public class SimpleNaiveBayesClassifierT >> } >> } >> >> - @Test >> - public void testPerformance() throws Exception { >> - checkPerformance(new SimpleNaiveBayesClassifier(), new >> MockAnalyzer(random()), categoryFieldName); >> - } >> - >> } >> >> >> >
