Did you mean to remove the javadocs on BooleanPerceptronClassifier? This breaks precommit...
On Thu, Apr 30, 2015 at 7:12 AM, <[email protected]> wrote: > Author: tommaso > Date: Thu Apr 30 14:12:03 2015 > New Revision: 1676997 > > URL: http://svn.apache.org/r1676997 > Log: > LUCENE-6045 - refactor Classifier API to work better with multithreading > > Modified: > > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java > > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java > > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java > > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java > > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java > > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java > > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java > > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java > > lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java > > lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java > > lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java > > lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java > > lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java > > Modified: > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java > URL: > 
http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff > > ============================================================================== > --- > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java > (original) > +++ > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java > Thu Apr 30 14:12:03 2015 > @@ -58,76 +58,14 @@ import org.apache.lucene.util.fst.Util; > */ > public class BooleanPerceptronClassifier implements Classifier<Boolean> { > > - private Double threshold; > - private final Integer batchSize; > - private Terms textTerms; > - private Analyzer analyzer; > - private String textFieldName; > + private final Double threshold; > + private final Terms textTerms; > + private final Analyzer analyzer; > + private final String textFieldName; > private FST<Long> fst; > > - /** > - * Create a {@link BooleanPerceptronClassifier} > - * > - * @param threshold the binary threshold for perceptron output > evaluation > - */ > - public BooleanPerceptronClassifier(Double threshold, Integer batchSize) > { > - this.threshold = threshold; > - this.batchSize = batchSize; > - } > - > - /** > - * Default constructor, no batch updates of FST, perceptron threshold is > - * calculated via underlying index metrics during > - * {@link #train(org.apache.lucene.index.LeafReader, String, String, > org.apache.lucene.analysis.Analyzer) > - * training} > - */ > - public BooleanPerceptronClassifier() { > - batchSize = 1; > - } > - > - /** > - * {@inheritDoc} > - */ > - @Override > - public ClassificationResult<Boolean> assignClass(String text) > - throws IOException { > - if (textTerms == null) { > - throw new IOException("You must first call Classifier#train"); > - } > - Long output = 0l; > - try (TokenStream tokenStream = 
analyzer.tokenStream(textFieldName, > text)) { > - CharTermAttribute charTermAttribute = tokenStream > - .addAttribute(CharTermAttribute.class); > - tokenStream.reset(); > - while (tokenStream.incrementToken()) { > - String s = charTermAttribute.toString(); > - Long d = Util.get(fst, new BytesRef(s)); > - if (d != null) { > - output += d; > - } > - } > - tokenStream.end(); > - } > - > - double score = 1 - Math.exp(-1 * Math.abs(threshold - > output.doubleValue()) / threshold); > - return new ClassificationResult<>(output >= threshold, score); > - } > - > - /** > - * {@inheritDoc} > - */ > - @Override > - public void train(LeafReader leafReader, String textFieldName, > - String classFieldName, Analyzer analyzer) throws > IOException { > - train(leafReader, textFieldName, classFieldName, analyzer, null); > - } > - > - /** > - * {@inheritDoc} > - */ > - @Override > - public void train(LeafReader leafReader, String textFieldName, > - String classFieldName, Analyzer analyzer, Query > query) throws IOException { > + public BooleanPerceptronClassifier(LeafReader leafReader, String > textFieldName, String classFieldName, Analyzer analyzer, > + Query query, Integer batchSize, > Double threshold) throws IOException { > this.textTerms = MultiFields.getTerms(leafReader, textFieldName); > > if (textTerms == null) { > @@ -144,9 +82,11 @@ public class BooleanPerceptronClassifier > this.threshold = (double) sumDocFreq / 2d; > } else { > throw new IOException( > - "threshold cannot be assigned since term vectors for field " > - + textFieldName + " do not exist"); > + "threshold cannot be assigned since term vectors for > field " > + + textFieldName + " do not exist"); > } > + } else { > + this.threshold = threshold; > } > > // TODO : remove this map as soon as we have a writable FST > @@ -170,7 +110,7 @@ public class BooleanPerceptronClassifier > } > // run the search and use stored field values > for (ScoreDoc scoreDoc : indexSearcher.search(q, > - Integer.MAX_VALUE).scoreDocs) { 
> + Integer.MAX_VALUE).scoreDocs) { > StoredDocument doc = indexSearcher.doc(scoreDoc.doc); > > StorableField textField = doc.getField(textFieldName); > @@ -187,7 +127,7 @@ public class BooleanPerceptronClassifier > long modifier = correctClass.compareTo(assignedClass); > if (modifier != 0) { > updateWeights(leafReader, scoreDoc.doc, assignedClass, > - weights, modifier, batchCount % batchSize == 0); > + weights, modifier, batchCount % batchSize == 0); > } > batchCount++; > } > @@ -195,11 +135,6 @@ public class BooleanPerceptronClassifier > weights.clear(); // free memory while waiting for GC > } > > - @Override > - public void train(LeafReader leafReader, String[] textFieldNames, > String classFieldName, Analyzer analyzer, Query query) throws IOException { > - throw new IOException("training with multiple fields not supported by > boolean perceptron classifier"); > - } > - > private void updateWeights(LeafReader leafReader, > int docId, Boolean assignedClass, > SortedMap<String, Double> weights, > double modifier, boolean updateFST) throws > IOException { > @@ -210,7 +145,7 @@ public class BooleanPerceptronClassifier > > if (terms == null) { > throw new IOException("term vectors must be stored for field " > - + textFieldName); > + + textFieldName); > } > > TermsEnum termsEnum = terms.iterator(); > @@ -240,17 +175,46 @@ public class BooleanPerceptronClassifier > for (Map.Entry<String, Double> entry : weights.entrySet()) { > scratchBytes.copyChars(entry.getKey()); > fstBuilder.add(Util.toIntsRef(scratchBytes.get(), scratchInts), > entry > - .getValue().longValue()); > + .getValue().longValue()); > } > fst = fstBuilder.finish(); > } > > + > + /** > + * {@inheritDoc} > + */ > + @Override > + public ClassificationResult<Boolean> assignClass(String text) > + throws IOException { > + if (textTerms == null) { > + throw new IOException("You must first call Classifier#train"); > + } > + Long output = 0l; > + try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, 
> text)) { > + CharTermAttribute charTermAttribute = tokenStream > + .addAttribute(CharTermAttribute.class); > + tokenStream.reset(); > + while (tokenStream.incrementToken()) { > + String s = charTermAttribute.toString(); > + Long d = Util.get(fst, new BytesRef(s)); > + if (d != null) { > + output += d; > + } > + } > + tokenStream.end(); > + } > + > + double score = 1 - Math.exp(-1 * Math.abs(threshold - > output.doubleValue()) / threshold); > + return new ClassificationResult<>(output >= threshold, score); > + } > + > /** > * {@inheritDoc} > */ > @Override > public List<ClassificationResult<Boolean>> getClasses(String text) > - throws IOException { > + throws IOException { > throw new RuntimeException("not implemented"); > } > > @@ -259,7 +223,7 @@ public class BooleanPerceptronClassifier > */ > @Override > public List<ClassificationResult<Boolean>> getClasses(String text, int > max) > - throws IOException { > + throws IOException { > throw new RuntimeException("not implemented"); > } > > > Modified: > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java > URL: > http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff > > ============================================================================== > --- > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java > (original) > +++ > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java > Thu Apr 30 14:12:03 2015 > @@ -49,50 +49,30 @@ import org.apache.lucene.util.BytesRef; > */ > public class CachingNaiveBayesClassifier extends > SimpleNaiveBayesClassifier { > //for caching classes this will be the classification class list > - private ArrayList<BytesRef> cclasses = new ArrayList<>(); > + private final 
ArrayList<BytesRef> cclasses = new ArrayList<>(); > // it's a term-inmap style map, where the inmap contains class-hit > pairs to the > // upper term > - private Map<String, Map<BytesRef, Integer>> termCClassHitCache = new > HashMap<>(); > + private final Map<String, Map<BytesRef, Integer>> termCClassHitCache = > new HashMap<>(); > // the term frequency in classes > - private Map<BytesRef, Double> classTermFreq = new HashMap<>(); > + private final Map<BytesRef, Double> classTermFreq = new HashMap<>(); > private boolean justCachedTerms; > private int docsWithClassSize; > > /** > - * Creates a new NaiveBayes classifier with inside caching. Note that > you must > - * call {@link #train(org.apache.lucene.index.LeafReader, String, > String, Analyzer) train()} before > - * you can classify any documents. If you want less memory usage you > could > + * Creates a new NaiveBayes classifier with inside caching. If you want > less memory usage you could > * call {@link #reInitCache(int, boolean) reInitCache()}. 
> */ > - public CachingNaiveBayesClassifier() { > - } > - > - /** > - * {@inheritDoc} > - */ > - @Override > - public void train(LeafReader leafReader, String textFieldName, String > classFieldName, Analyzer analyzer) throws IOException { > - train(leafReader, textFieldName, classFieldName, analyzer, null); > - } > - > - /** > - * {@inheritDoc} > - */ > - @Override > - public void train(LeafReader leafReader, String textFieldName, String > classFieldName, Analyzer analyzer, Query query) throws IOException { > - train(leafReader, new String[]{textFieldName}, classFieldName, > analyzer, query); > - } > - > - /** > - * {@inheritDoc} > - */ > - @Override > - public void train(LeafReader leafReader, String[] textFieldNames, > String classFieldName, Analyzer analyzer, Query query) throws IOException { > - super.train(leafReader, textFieldNames, classFieldName, analyzer, > query); > + public CachingNaiveBayesClassifier(LeafReader leafReader, Analyzer > analyzer, Query query, String classFieldName, String... 
textFieldNames) { > + super(leafReader, analyzer, query, classFieldName, textFieldNames); > // building the cache > - reInitCache(0, true); > + try { > + reInitCache(0, true); > + } catch (IOException e) { > + throw new RuntimeException(e); > + } > } > > + > private List<ClassificationResult<BytesRef>> > assignClassNormalizedList(String inputDocument) throws IOException { > if (leafReader == null) { > throw new IOException("You must first call Classifier#train"); > > Modified: > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java > URL: > http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java?rev=1676997&r1=1676996&r2=1676997&view=diff > > ============================================================================== > --- > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java > (original) > +++ > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java > Thu Apr 30 14:12:03 2015 > @@ -18,17 +18,19 @@ package org.apache.lucene.classification > > /** > * The result of a call to {@link Classifier#assignClass(String)} holding > an assigned class of type <code>T</code> and a score. 
> + * > * @lucene.experimental > */ > -public class ClassificationResult<T> implements > Comparable<ClassificationResult<T>>{ > +public class ClassificationResult<T> implements > Comparable<ClassificationResult<T>> { > > private final T assignedClass; > private double score; > > /** > * Constructor > + * > * @param assignedClass the class <code>T</code> assigned by a {@link > Classifier} > - * @param score the score for the assignedClass as a <code>double</code> > + * @param score the score for the assignedClass as a > <code>double</code> > */ > public ClassificationResult(T assignedClass, double score) { > this.assignedClass = assignedClass; > @@ -37,6 +39,7 @@ public class ClassificationResult<T> imp > > /** > * retrieve the result class > + * > * @return a <code>T</code> representing an assigned class > */ > public T getAssignedClass() { > @@ -45,14 +48,16 @@ public class ClassificationResult<T> imp > > /** > * retrieve the result score > + * > * @return a <code>double</code> representing a result score > */ > public double getScore() { > return score; > } > - > + > /** > * set the score value > + * > * @param score the score for the assignedClass as a <code>double</code> > */ > public void setScore(double score) { > > Modified: > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java > URL: > http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff > > ============================================================================== > --- > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java > (original) > +++ > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java > Thu Apr 30 14:12:03 2015 > @@ -22,7 +22,6 @@ import java.util.List; > import org.apache.lucene.analysis.Analyzer; > import org.apache.lucene.index.LeafReader; > 
import org.apache.lucene.search.Query; > -import org.apache.lucene.util.BytesRef; > > /** > * A classifier, see <code> > http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>, which > assign classes of type > @@ -39,7 +38,7 @@ public interface Classifier<T> { > * @return a {@link ClassificationResult} holding assigned class of > type <code>T</code> and score > * @throws IOException If there is a low-level I/O error. > */ > - public ClassificationResult<T> assignClass(String text) throws > IOException; > + ClassificationResult<T> assignClass(String text) throws IOException; > > /** > * Get all the classes (sorted by score, descending) assigned to the > given text String. > @@ -48,7 +47,7 @@ public interface Classifier<T> { > * @return the whole list of {@link ClassificationResult}, the classes > and scores. Returns <code>null</code> if the classifier can't make lists. > * @throws IOException If there is a low-level I/O error. > */ > - public List<ClassificationResult<T>> getClasses(String text) throws > IOException; > + List<ClassificationResult<T>> getClasses(String text) throws > IOException; > > /** > * Get the first <code>max</code> classes (sorted by score, descending) > assigned to the given text String. > @@ -58,44 +57,6 @@ public interface Classifier<T> { > * @return the whole list of {@link ClassificationResult}, the classes > and scores. Cut for "max" number of elements. Returns <code>null</code> if > the classifier can't make lists. > * @throws IOException If there is a low-level I/O error. 
> */ > - public List<ClassificationResult<T>> getClasses(String text, int max) > throws IOException; > - > - /** > - * Train the classifier using the underlying Lucene index > - * > - * @param leafReader the reader to use to access the Lucene index > - * @param textFieldName the name of the field used to compare documents > - * @param classFieldName the name of the field containing the class > assigned to documents > - * @param analyzer the analyzer used to tokenize / filter the > unseen text > - * @throws IOException If there is a low-level I/O error. > - */ > - public void train(LeafReader leafReader, String textFieldName, String > classFieldName, Analyzer analyzer) > - throws IOException; > - > - /** > - * Train the classifier using the underlying Lucene index > - * > - * @param leafReader the reader to use to access the Lucene index > - * @param textFieldName the name of the field used to compare documents > - * @param classFieldName the name of the field containing the class > assigned to documents > - * @param analyzer the analyzer used to tokenize / filter the > unseen text > - * @param query the query to filter which documents use for > training > - * @throws IOException If there is a low-level I/O error. > - */ > - public void train(LeafReader leafReader, String textFieldName, String > classFieldName, Analyzer analyzer, Query query) > - throws IOException; > - > - /** > - * Train the classifier using the underlying Lucene index > - * > - * @param leafReader the reader to use to access the Lucene index > - * @param textFieldNames the names of the fields to be used to compare > documents > - * @param classFieldName the name of the field containing the class > assigned to documents > - * @param analyzer the analyzer used to tokenize / filter the > unseen text > - * @param query the query to filter which documents use for > training > - * @throws IOException If there is a low-level I/O error. 
> - */ > - public void train(LeafReader leafReader, String[] textFieldNames, > String classFieldName, Analyzer analyzer, Query query) > - throws IOException; > + List<ClassificationResult<T>> getClasses(String text, int max) throws > IOException; > > } > > Modified: > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java > URL: > http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff > > ============================================================================== > --- > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java > (original) > +++ > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java > Thu Apr 30 14:12:03 2015 > @@ -26,6 +26,7 @@ import java.util.Map; > > import org.apache.lucene.analysis.Analyzer; > import org.apache.lucene.index.LeafReader; > +import org.apache.lucene.index.StorableField; > import org.apache.lucene.index.Term; > import org.apache.lucene.queries.mlt.MoreLikeThis; > import org.apache.lucene.search.BooleanClause; > @@ -45,37 +46,31 @@ import org.apache.lucene.util.BytesRef; > */ > public class KNearestNeighborClassifier implements Classifier<BytesRef> { > > - private MoreLikeThis mlt; > - private String[] textFieldNames; > - private String classFieldName; > - private IndexSearcher indexSearcher; > + private final MoreLikeThis mlt; > + private final String[] textFieldNames; > + private final String classFieldName; > + private final IndexSearcher indexSearcher; > private final int k; > - private Query query; > + private final Query query; > > - private int minDocsFreq; > - private int minTermFreq; > - > - /** > - * Create a {@link Classifier} using kNN algorithm > - * > - * @param k the number of neighbors to analyze as an 
<code>int</code> > - */ > - public KNearestNeighborClassifier(int k) { > + public KNearestNeighborClassifier(LeafReader leafReader, Analyzer > analyzer, Query query, int k, int minDocsFreq, > + int minTermFreq, String > classFieldName, String... textFieldNames) { > + this.textFieldNames = textFieldNames; > + this.classFieldName = classFieldName; > + this.mlt = new MoreLikeThis(leafReader); > + this.mlt.setAnalyzer(analyzer); > + this.mlt.setFieldNames(textFieldNames); > + this.indexSearcher = new IndexSearcher(leafReader); > + if (minDocsFreq > 0) { > + mlt.setMinDocFreq(minDocsFreq); > + } > + if (minTermFreq > 0) { > + mlt.setMinTermFreq(minTermFreq); > + } > + this.query = query; > this.k = k; > } > > - /** > - * Create a {@link Classifier} using kNN algorithm > - * > - * @param k the number of neighbors to analyze as an > <code>int</code> > - * @param minDocsFreq the minimum number of docs frequency for MLT to > be set with {@link MoreLikeThis#setMinDocFreq(int)} > - * @param minTermFreq the minimum number of term frequency for MLT to > be set with {@link MoreLikeThis#setMinTermFreq(int)} > - */ > - public KNearestNeighborClassifier(int k, int minDocsFreq, int > minTermFreq) { > - this.k = k; > - this.minDocsFreq = minDocsFreq; > - this.minTermFreq = minTermFreq; > - } > > /** > * {@inheritDoc} > @@ -136,12 +131,15 @@ public class KNearestNeighborClassifier > private List<ClassificationResult<BytesRef>> > buildListFromTopDocs(TopDocs topDocs) throws IOException { > Map<BytesRef, Integer> classCounts = new HashMap<>(); > for (ScoreDoc scoreDoc : topDocs.scoreDocs) { > - BytesRef cl = new > BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue()); > - Integer count = classCounts.get(cl); > - if (count != null) { > - classCounts.put(cl, count + 1); > - } else { > - classCounts.put(cl, 1); > + StorableField storableField = > indexSearcher.doc(scoreDoc.doc).getField(classFieldName); > + if (storableField != null) { > + BytesRef cl = new 
BytesRef(storableField.stringValue()); > + Integer count = classCounts.get(cl); > + if (count != null) { > + classCounts.put(cl, count + 1); > + } else { > + classCounts.put(cl, 1); > + } > } > } > List<ClassificationResult<BytesRef>> returnList = new ArrayList<>(); > @@ -161,39 +159,4 @@ public class KNearestNeighborClassifier > return returnList; > } > > - /** > - * {@inheritDoc} > - */ > - @Override > - public void train(LeafReader leafReader, String textFieldName, String > classFieldName, Analyzer analyzer) throws IOException { > - train(leafReader, textFieldName, classFieldName, analyzer, null); > - } > - > - /** > - * {@inheritDoc} > - */ > - @Override > - public void train(LeafReader leafReader, String textFieldName, String > classFieldName, Analyzer analyzer, Query query) throws IOException { > - train(leafReader, new String[]{textFieldName}, classFieldName, > analyzer, query); > - } > - > - /** > - * {@inheritDoc} > - */ > - @Override > - public void train(LeafReader leafReader, String[] textFieldNames, > String classFieldName, Analyzer analyzer, Query query) throws IOException { > - this.textFieldNames = textFieldNames; > - this.classFieldName = classFieldName; > - mlt = new MoreLikeThis(leafReader); > - mlt.setAnalyzer(analyzer); > - mlt.setFieldNames(textFieldNames); > - indexSearcher = new IndexSearcher(leafReader); > - if (minDocsFreq > 0) { > - mlt.setMinDocFreq(minDocsFreq); > - } > - if (minTermFreq > 0) { > - mlt.setMinTermFreq(minTermFreq); > - } > - this.query = query; > - } > } > > Modified: > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java > URL: > http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff > > ============================================================================== > --- > 
lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java > (original) > +++ > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java > Thu Apr 30 14:12:03 2015 > @@ -51,64 +51,38 @@ public class SimpleNaiveBayesClassifier > * {@link org.apache.lucene.index.LeafReader} used to access the {@link > org.apache.lucene.classification.Classifier}'s > * index > */ > - protected LeafReader leafReader; > + protected final LeafReader leafReader; > > /** > * names of the fields to be used as input text > */ > - protected String[] textFieldNames; > + protected final String[] textFieldNames; > > /** > * name of the field to be used as a class / category output > */ > - protected String classFieldName; > + protected final String classFieldName; > > /** > * {@link org.apache.lucene.analysis.Analyzer} to be used for > tokenizing unseen input text > */ > - protected Analyzer analyzer; > + protected final Analyzer analyzer; > > /** > * {@link org.apache.lucene.search.IndexSearcher} to run searches on > the index for retrieving frequencies > */ > - protected IndexSearcher indexSearcher; > + protected final IndexSearcher indexSearcher; > > /** > * {@link org.apache.lucene.search.Query} used to eventually filter the > document set to be used to classify > */ > - protected Query query; > + protected final Query query; > > /** > * Creates a new NaiveBayes classifier. > - * Note that you must call {@link > #train(org.apache.lucene.index.LeafReader, String, String, Analyzer) > train()} before you can > * classify any documents. 
> */ > - public SimpleNaiveBayesClassifier() { > - } > - > - /** > - * {@inheritDoc} > - */ > - @Override > - public void train(LeafReader leafReader, String textFieldName, String > classFieldName, Analyzer analyzer) throws IOException { > - train(leafReader, textFieldName, classFieldName, analyzer, null); > - } > - > - /** > - * {@inheritDoc} > - */ > - @Override > - public void train(LeafReader leafReader, String textFieldName, String > classFieldName, Analyzer analyzer, Query query) > - throws IOException { > - train(leafReader, new String[]{textFieldName}, classFieldName, > analyzer, query); > - } > - > - /** > - * {@inheritDoc} > - */ > - @Override > - public void train(LeafReader leafReader, String[] textFieldNames, > String classFieldName, Analyzer analyzer, Query query) > - throws IOException { > + public SimpleNaiveBayesClassifier(LeafReader leafReader, Analyzer > analyzer, Query query, String classFieldName, String... textFieldNames) { > this.leafReader = leafReader; > this.indexSearcher = new IndexSearcher(this.leafReader); > this.textFieldNames = textFieldNames; > > Modified: > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java > URL: > http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java?rev=1676997&r1=1676996&r2=1676997&view=diff > > ============================================================================== > --- > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java > (original) > +++ > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java > Thu Apr 30 14:12:03 2015 > @@ -18,7 +18,7 @@ > /** > * Uses already seen data (the indexed documents) to classify new > documents. 
> * <p> > - * Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest > + * Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest > * Neighbor classifier and a Perceptron based classifier. > */ > package org.apache.lucene.classification; > > Modified: > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java > URL: > http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java?rev=1676997&r1=1676996&r2=1676997&view=diff > > ============================================================================== > --- > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java > (original) > +++ > lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java > Thu Apr 30 14:12:03 2015 > @@ -33,7 +33,8 @@ public class DocToDoubleVectorUtils { > > /** > * create a sparse <code>Double</code> vector given doc and field term > vectors using local frequency of the terms in the doc > - * @param docTerms term vectors for a given document > + * > + * @param docTerms term vectors for a given document > * @param fieldTerms field term vectors > * @return a sparse vector of <code>Double</code>s as an array > * @throws IOException in case accessing the underlying index fails > @@ -54,8 +55,7 @@ public class DocToDoubleVectorUtils { > if (seekStatus.equals(TermsEnum.SeekStatus.FOUND)) { > long termFreqLocal = docTermsEnum.totalTermFreq(); // the total > number of occurrences of this term in the given document > freqVector[i] = Long.valueOf(termFreqLocal).doubleValue(); > - } > - else { > + } else { > freqVector[i] = 0d; > } > i++; > @@ -66,6 +66,7 @@ public class DocToDoubleVectorUtils { > > /** > * create a dense <code>Double</code> vector given doc and field term > vectors using local frequency of the terms in 
the doc > + * > * @param docTerms term vectors for a given document > * @return a dense vector of <code>Double</code>s as an array > * @throws IOException in case accessing the underlying index fails > @@ -73,16 +74,16 @@ public class DocToDoubleVectorUtils { > public static Double[] toDenseLocalFreqDoubleArray(Terms docTerms) > throws IOException { > Double[] freqVector = null; > if (docTerms != null) { > - freqVector = new Double[(int) docTerms.size()]; > - int i = 0; > - TermsEnum docTermsEnum = docTerms.iterator(); > + freqVector = new Double[(int) docTerms.size()]; > + int i = 0; > + TermsEnum docTermsEnum = docTerms.iterator(); > > - while (docTermsEnum.next() != null) { > - long termFreqLocal = docTermsEnum.totalTermFreq(); // the > total number of occurrences of this term in the given document > - freqVector[i] = Long.valueOf(termFreqLocal).doubleValue(); > - i++; > - } > + while (docTermsEnum.next() != null) { > + long termFreqLocal = docTermsEnum.totalTermFreq(); // the total > number of occurrences of this term in the given document > + freqVector[i] = Long.valueOf(termFreqLocal).doubleValue(); > + i++; > + } > } > return freqVector; > -} > + } > } > > Modified: > lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java > URL: > http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff > > ============================================================================== > --- > lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java > (original) > +++ > lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java > Thu Apr 30 14:12:03 2015 > @@ -17,6 +17,8 @@ > package org.apache.lucene.classification; > > import 
org.apache.lucene.analysis.MockAnalyzer; > +import org.apache.lucene.index.LeafReader; > +import org.apache.lucene.index.SlowCompositeReaderWrapper; > import org.apache.lucene.index.Term; > import org.apache.lucene.search.TermQuery; > import org.junit.Test; > @@ -28,22 +30,45 @@ public class BooleanPerceptronClassifier > > @Test > public void testBasicUsage() throws Exception { > - checkCorrectClassification(new BooleanPerceptronClassifier(), > TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, > booleanFieldName); > + LeafReader leafReader = null; > + try { > + MockAnalyzer analyzer = new MockAnalyzer(random()); > + leafReader = populateSampleIndex(analyzer); > + checkCorrectClassification(new > BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, > analyzer, null, 1, null), TECHNOLOGY_INPUT, false); > + } finally { > + if (leafReader != null) { > + leafReader.close(); > + } > + } > } > > @Test > public void testExplicitThreshold() throws Exception { > - checkCorrectClassification(new BooleanPerceptronClassifier(100d, 1), > TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, > booleanFieldName); > + LeafReader leafReader = null; > + try { > + MockAnalyzer analyzer = new MockAnalyzer(random()); > + leafReader = populateSampleIndex(analyzer); > + checkCorrectClassification(new > BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, > analyzer, null, 1, 100d), TECHNOLOGY_INPUT, false); > + } finally { > + if (leafReader != null) { > + leafReader.close(); > + } > + } > } > > @Test > public void testBasicUsageWithQuery() throws Exception { > - checkCorrectClassification(new BooleanPerceptronClassifier(), > TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, > booleanFieldName, new TermQuery(new Term(textFieldName, "it"))); > - } > - > - @Test > - public void testPerformance() throws Exception { > - checkPerformance(new BooleanPerceptronClassifier(), new > MockAnalyzer(random()), 
booleanFieldName); > + TermQuery query = new TermQuery(new Term(textFieldName, "it")); > + LeafReader leafReader = null; > + try { > + MockAnalyzer analyzer = new MockAnalyzer(random()); > + leafReader = populateSampleIndex(analyzer); > + checkCorrectClassification(new > BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, > analyzer, query, 1, null), TECHNOLOGY_INPUT, false); > + } finally { > + if (leafReader != null) { > + leafReader.close(); > + } > + } > } > > } > > Modified: > lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java > URL: > http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff > > ============================================================================== > --- > lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java > (original) > +++ > lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java > Thu Apr 30 14:12:03 2015 > @@ -23,6 +23,8 @@ import org.apache.lucene.analysis.Tokeni > import org.apache.lucene.analysis.core.KeywordTokenizer; > import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter; > import org.apache.lucene.analysis.reverse.ReverseStringFilter; > +import org.apache.lucene.index.LeafReader; > +import org.apache.lucene.index.SlowCompositeReaderWrapper; > import org.apache.lucene.index.Term; > import org.apache.lucene.search.TermQuery; > import org.apache.lucene.util.BytesRef; > @@ -35,18 +37,46 @@ public class CachingNaiveBayesClassifier > > @Test > public void testBasicUsage() throws Exception { > - checkCorrectClassification(new CachingNaiveBayesClassifier(), > TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), > textFieldName, categoryFieldName); > - 
checkCorrectClassification(new CachingNaiveBayesClassifier(), > POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, > categoryFieldName); > + LeafReader leafReader = null; > + try { > + MockAnalyzer analyzer = new MockAnalyzer(random()); > + leafReader = populateSampleIndex(analyzer); > + checkCorrectClassification(new > CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, > textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT); > + checkCorrectClassification(new > CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, > textFieldName), POLITICS_INPUT, POLITICS_RESULT); > + } finally { > + if (leafReader != null) { > + leafReader.close(); > + } > + } > } > > @Test > public void testBasicUsageWithQuery() throws Exception { > - checkCorrectClassification(new CachingNaiveBayesClassifier(), > TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), > textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, > "it"))); > + LeafReader leafReader = null; > + try { > + MockAnalyzer analyzer = new MockAnalyzer(random()); > + leafReader = populateSampleIndex(analyzer); > + TermQuery query = new TermQuery(new Term(textFieldName, "it")); > + checkCorrectClassification(new > CachingNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName, > textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT); > + } finally { > + if (leafReader != null) { > + leafReader.close(); > + } > + } > } > > @Test > public void testNGramUsage() throws Exception { > - checkCorrectClassification(new CachingNaiveBayesClassifier(), > TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, > categoryFieldName); > + LeafReader leafReader = null; > + try { > + NGramAnalyzer analyzer = new NGramAnalyzer(); > + leafReader = populateSampleIndex(analyzer); > + checkCorrectClassification(new > CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, > textFieldName), TECHNOLOGY_INPUT, 
TECHNOLOGY_RESULT); > + } finally { > + if (leafReader != null) { > + leafReader.close(); > + } > + } > } > > private class NGramAnalyzer extends Analyzer { > @@ -57,9 +87,4 @@ public class CachingNaiveBayesClassifier > } > } > > - @Test > - public void testPerformance() throws Exception { > - checkPerformance(new CachingNaiveBayesClassifier(), new > MockAnalyzer(random()), categoryFieldName); > - } > - > } > > Modified: > lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java > URL: > http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java?rev=1676997&r1=1676996&r2=1676997&view=diff > > ============================================================================== > --- > lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java > (original) > +++ > lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java > Thu Apr 30 14:12:03 2015 > @@ -41,14 +41,14 @@ import org.junit.Before; > */ > public abstract class ClassificationTestBase<T> extends LuceneTestCase { > public final static String POLITICS_INPUT = "Here are some interesting > questions and answers about Mitt Romney.. " + > - "If you don't know the answer to the question about Mitt Romney, > then simply click on the answer below the question section."; > + "If you don't know the answer to the question about Mitt > Romney, then simply click on the answer below the question section."; > public static final BytesRef POLITICS_RESULT = new BytesRef("politics"); > > public static final String TECHNOLOGY_INPUT = "Much is made of what the > likes of Facebook, Google and Apple know about users." 
+ > - " Truth is, Amazon may know more."; > + " Truth is, Amazon may know more."; > public static final BytesRef TECHNOLOGY_RESULT = new > BytesRef("technology"); > > - private RandomIndexWriter indexWriter; > + protected RandomIndexWriter indexWriter; > private Directory dir; > private FieldType ft; > > @@ -79,53 +79,34 @@ public abstract class ClassificationTest > dir.close(); > } > > - protected void checkCorrectClassification(Classifier<T> classifier, > String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, > String classFieldName) throws Exception { > - checkCorrectClassification(classifier, inputDoc, expectedResult, > analyzer, textFieldName, classFieldName, null); > + protected void checkCorrectClassification(Classifier<T> classifier, > String inputDoc, T expectedResult) throws Exception { > + ClassificationResult<T> classificationResult = > classifier.assignClass(inputDoc); > + assertNotNull(classificationResult.getAssignedClass()); > + assertEquals("got an assigned class of " + > classificationResult.getAssignedClass(), expectedResult, > classificationResult.getAssignedClass()); > + double score = classificationResult.getScore(); > + assertTrue("score should be between 0 and 1, got:" + score, score <= > 1 && score >= 0); > } > > - protected void checkCorrectClassification(Classifier<T> classifier, > String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, > String classFieldName, Query query) throws Exception { > - LeafReader leafReader = null; > - try { > - populateSampleIndex(analyzer); > - leafReader = > SlowCompositeReaderWrapper.wrap(indexWriter.getReader()); > - classifier.train(leafReader, textFieldName, classFieldName, > analyzer, query); > - ClassificationResult<T> classificationResult = > classifier.assignClass(inputDoc); > - assertNotNull(classificationResult.getAssignedClass()); > - assertEquals("got an assigned class of " + > classificationResult.getAssignedClass(), expectedResult, > 
classificationResult.getAssignedClass()); > - double score = classificationResult.getScore(); > - assertTrue("score should be between 0 and 1, got:" + score, score > <= 1 && score >= 0); > - } finally { > - if (leafReader != null) > - leafReader.close(); > - } > - } > protected void checkOnlineClassification(Classifier<T> classifier, > String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, > String classFieldName) throws Exception { > checkOnlineClassification(classifier, inputDoc, expectedResult, > analyzer, textFieldName, classFieldName, null); > } > > protected void checkOnlineClassification(Classifier<T> classifier, > String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, > String classFieldName, Query query) throws Exception { > - LeafReader leafReader = null; > - try { > - populateSampleIndex(analyzer); > - leafReader = > SlowCompositeReaderWrapper.wrap(indexWriter.getReader()); > - classifier.train(leafReader, textFieldName, classFieldName, > analyzer, query); > - ClassificationResult<T> classificationResult = > classifier.assignClass(inputDoc); > - assertNotNull(classificationResult.getAssignedClass()); > - assertEquals("got an assigned class of " + > classificationResult.getAssignedClass(), expectedResult, > classificationResult.getAssignedClass()); > - double score = classificationResult.getScore(); > - assertTrue("score should be between 0 and 1, got: " + score, score > <= 1 && score >= 0); > - updateSampleIndex(); > - ClassificationResult<T> secondClassificationResult = > classifier.assignClass(inputDoc); > - assertEquals(classificationResult.getAssignedClass(), > secondClassificationResult.getAssignedClass()); > - assertEquals(Double.valueOf(score), > Double.valueOf(secondClassificationResult.getScore())); > - > - } finally { > - if (leafReader != null) > - leafReader.close(); > - } > + populateSampleIndex(analyzer); > + > + ClassificationResult<T> classificationResult = > classifier.assignClass(inputDoc); > + 
assertNotNull(classificationResult.getAssignedClass()); > + assertEquals("got an assigned class of " + > classificationResult.getAssignedClass(), expectedResult, > classificationResult.getAssignedClass()); > + double score = classificationResult.getScore(); > + assertTrue("score should be between 0 and 1, got: " + score, score <= > 1 && score >= 0); > + updateSampleIndex(); > + ClassificationResult<T> secondClassificationResult = > classifier.assignClass(inputDoc); > + assertEquals(classificationResult.getAssignedClass(), > secondClassificationResult.getAssignedClass()); > + assertEquals(Double.valueOf(score), > Double.valueOf(secondClassificationResult.getScore())); > + > } > > - private void populateSampleIndex(Analyzer analyzer) throws IOException { > + protected LeafReader populateSampleIndex(Analyzer analyzer) throws > IOException { > indexWriter.close(); > indexWriter = new RandomIndexWriter(random(), dir, > newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE)); > indexWriter.commit(); > @@ -134,8 +115,8 @@ public abstract class ClassificationTest > > Document doc = new Document(); > text = "The traveling press secretary for Mitt Romney lost his cool > and cursed at reporters " + > - "who attempted to ask questions of the Republican presidential > candidate in a public plaza near the Tomb of " + > - "the Unknown Soldier in Warsaw Tuesday."; > + "who attempted to ask questions of the Republican > presidential candidate in a public plaza near the Tomb of " + > + "the Unknown Soldier in Warsaw Tuesday."; > doc.add(new Field(textFieldName, text, ft)); > doc.add(new Field(categoryFieldName, "politics", ft)); > doc.add(new Field(booleanFieldName, "true", ft)); > @@ -144,7 +125,7 @@ public abstract class ClassificationTest > > doc = new Document(); > text = "Mitt Romney seeks to assure Israel and Iran, as well as > Jewish voters in the United" + > - " States, that he will be tougher against Iran's nuclear > ambitions than President Barack 
Obama."; > + " States, that he will be tougher against Iran's nuclear > ambitions than President Barack Obama."; > doc.add(new Field(textFieldName, text, ft)); > doc.add(new Field(categoryFieldName, "politics", ft)); > doc.add(new Field(booleanFieldName, "true", ft)); > @@ -152,8 +133,8 @@ public abstract class ClassificationTest > > doc = new Document(); > text = "And there's a threshold question that he has to answer for > the American people and " + > - "that's whether he is prepared to be commander-in-chief,\" she > continued. \"As we look to the past events, we " + > - "know that this raises some questions about his preparedness and > we'll see how the rest of his trip goes.\""; > + "that's whether he is prepared to be commander-in-chief,\" > she continued. \"As we look to the past events, we " + > + "know that this raises some questions about his preparedness > and we'll see how the rest of his trip goes.\""; > doc.add(new Field(textFieldName, text, ft)); > doc.add(new Field(categoryFieldName, "politics", ft)); > doc.add(new Field(booleanFieldName, "true", ft)); > @@ -161,8 +142,8 @@ public abstract class ClassificationTest > > doc = new Document(); > text = "Still, when it comes to gun policy, many congressional > Democrats have \"decided to " + > - "keep quiet and not go there,\" said Alan Lizotte, dean and > professor at the State University of New York at " + > - "Albany's School of Criminal Justice."; > + "keep quiet and not go there,\" said Alan Lizotte, dean and > professor at the State University of New York at " + > + "Albany's School of Criminal Justice."; > doc.add(new Field(textFieldName, text, ft)); > doc.add(new Field(categoryFieldName, "politics", ft)); > doc.add(new Field(booleanFieldName, "true", ft)); > @@ -170,8 +151,8 @@ public abstract class ClassificationTest > > doc = new Document(); > text = "Standing amongst the thousands of people at the state > Capitol, Jorstad, director of " + > - "technology at the University of Wisconsin-La 
Crosse, documented > the historic moment and shared it with the " + > - "world through the Internet."; > + "technology at the University of Wisconsin-La Crosse, > documented the historic moment and shared it with the " + > + "world through the Internet."; > doc.add(new Field(textFieldName, text, ft)); > doc.add(new Field(categoryFieldName, "technology", ft)); > doc.add(new Field(booleanFieldName, "false", ft)); > @@ -179,7 +160,7 @@ public abstract class ClassificationTest > > doc = new Document(); > text = "So, about all those experts and analysts who've spent the > past year or so saying " + > - "Facebook was going to make a phone. A new expert has stepped > forward to say it's not going to happen."; > + "Facebook was going to make a phone. A new expert has stepped > forward to say it's not going to happen."; > doc.add(new Field(textFieldName, text, ft)); > doc.add(new Field(categoryFieldName, "technology", ft)); > doc.add(new Field(booleanFieldName, "false", ft)); > @@ -187,8 +168,8 @@ public abstract class ClassificationTest > > doc = new Document(); > text = "More than 400 million people trust Google with their e-mail, > and 50 million store files" + > - " in the cloud using the Dropbox service. People manage their > bank accounts, pay bills, trade stocks and " + > - "generally transfer or store huge volumes of personal data > online."; > + " in the cloud using the Dropbox service. 
People manage their > bank accounts, pay bills, trade stocks and " + > + "generally transfer or store huge volumes of personal data > online."; > doc.add(new Field(textFieldName, text, ft)); > doc.add(new Field(categoryFieldName, "technology", ft)); > doc.add(new Field(booleanFieldName, "false", ft)); > @@ -200,22 +181,15 @@ public abstract class ClassificationTest > indexWriter.addDocument(doc); > > indexWriter.commit(); > + return SlowCompositeReaderWrapper.wrap(indexWriter.getReader()); > } > > protected void checkPerformance(Classifier<T> classifier, Analyzer > analyzer, String classFieldName) throws Exception { > - LeafReader leafReader = null; > long trainStart = System.currentTimeMillis(); > - try { > - populatePerformanceIndex(analyzer); > - leafReader = > SlowCompositeReaderWrapper.wrap(indexWriter.getReader()); > - classifier.train(leafReader, textFieldName, classFieldName, > analyzer); > - long trainEnd = System.currentTimeMillis(); > - long trainTime = trainEnd - trainStart; > - assertTrue("training took more than 2 mins : " + trainTime / 1000 + > "s", trainTime < 120000); > - } finally { > - if (leafReader != null) > - leafReader.close(); > - } > + populatePerformanceIndex(analyzer); > + long trainEnd = System.currentTimeMillis(); > + long trainTime = trainEnd - trainStart; > + assertTrue("training took more than 2 mins : " + trainTime / 1000 + > "s", trainTime < 120000); > } > > private void populatePerformanceIndex(Analyzer analyzer) throws > IOException { > > Modified: > lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java > URL: > http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff > > ============================================================================== > --- > 
lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java > (original) > +++ > lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java > Thu Apr 30 14:12:03 2015 > @@ -17,6 +17,8 @@ > package org.apache.lucene.classification; > > import org.apache.lucene.analysis.MockAnalyzer; > +import org.apache.lucene.index.LeafReader; > +import org.apache.lucene.index.SlowCompositeReaderWrapper; > import org.apache.lucene.index.Term; > import org.apache.lucene.search.TermQuery; > import org.apache.lucene.util.BytesRef; > @@ -29,20 +31,32 @@ public class KNearestNeighborClassifierT > > @Test > public void testBasicUsage() throws Exception { > - // usage with default MLT min docs / term freq > - checkCorrectClassification(new KNearestNeighborClassifier(3), > POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, > categoryFieldName); > - // usage without custom min docs / term freq for MLT > - checkCorrectClassification(new KNearestNeighborClassifier(3, 2, 1), > TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), > textFieldName, categoryFieldName); > + LeafReader leafReader = null; > + try { > + MockAnalyzer analyzer = new MockAnalyzer(random()); > + leafReader = populateSampleIndex(analyzer); > + checkCorrectClassification(new > KNearestNeighborClassifier(leafReader, analyzer, null, 1, 0, 0, > categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT); > + checkCorrectClassification(new > KNearestNeighborClassifier(leafReader, analyzer, null, 3, 2, 1, > categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT); > + } finally { > + if (leafReader != null) { > + leafReader.close(); > + } > + } > } > > @Test > public void testBasicUsageWithQuery() throws Exception { > - checkCorrectClassification(new KNearestNeighborClassifier(1), > TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), > 
textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, > "it"))); > - } > - > - @Test > - public void testPerformance() throws Exception { > - checkPerformance(new KNearestNeighborClassifier(100), new > MockAnalyzer(random()), categoryFieldName); > + LeafReader leafReader = null; > + try { > + MockAnalyzer analyzer = new MockAnalyzer(random()); > + leafReader = populateSampleIndex(analyzer); > + TermQuery query = new TermQuery(new Term(textFieldName, "it")); > + checkCorrectClassification(new > KNearestNeighborClassifier(leafReader, analyzer, query, 1, 0, 0, > categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT); > + } finally { > + if (leafReader != null) { > + leafReader.close(); > + } > + } > } > > } > > Modified: > lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java > URL: > http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff > > ============================================================================== > --- > lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java > (original) > +++ > lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java > Thu Apr 30 14:12:03 2015 > @@ -22,14 +22,13 @@ import org.apache.lucene.analysis.Tokeni > import org.apache.lucene.analysis.core.KeywordTokenizer; > import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter; > import org.apache.lucene.analysis.reverse.ReverseStringFilter; > +import org.apache.lucene.index.LeafReader; > +import org.apache.lucene.index.SlowCompositeReaderWrapper; > import org.apache.lucene.index.Term; > import org.apache.lucene.search.TermQuery; > import org.apache.lucene.util.BytesRef; > -import org.apache.lucene.util.LuceneTestCase; > import 
org.junit.Test; > > -import java.io.Reader; > - > /** > * Testcase for {@link SimpleNaiveBayesClassifier} > */ > @@ -37,18 +36,46 @@ public class SimpleNaiveBayesClassifierT > > @Test > public void testBasicUsage() throws Exception { > - checkCorrectClassification(new SimpleNaiveBayesClassifier(), > TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), > textFieldName, categoryFieldName); > - checkCorrectClassification(new SimpleNaiveBayesClassifier(), > POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, > categoryFieldName); > + LeafReader leafReader = null; > + try { > + MockAnalyzer analyzer = new MockAnalyzer(random()); > + leafReader = populateSampleIndex(analyzer); > + checkCorrectClassification(new > SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, > textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT); > + checkCorrectClassification(new > SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, > textFieldName), POLITICS_INPUT, POLITICS_RESULT); > + } finally { > + if (leafReader != null) { > + leafReader.close(); > + } > + } > } > > @Test > public void testBasicUsageWithQuery() throws Exception { > - checkCorrectClassification(new SimpleNaiveBayesClassifier(), > TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), > textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, > "it"))); > + LeafReader leafReader = null; > + try { > + MockAnalyzer analyzer = new MockAnalyzer(random()); > + leafReader = populateSampleIndex(analyzer); > + TermQuery query = new TermQuery(new Term(textFieldName, "it")); > + checkCorrectClassification(new > SimpleNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName, > textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT); > + } finally { > + if (leafReader != null) { > + leafReader.close(); > + } > + } > } > > @Test > public void testNGramUsage() throws Exception { > - checkCorrectClassification(new 
SimpleNaiveBayesClassifier(), > TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, > categoryFieldName); > + LeafReader leafReader = null; > + try { > + Analyzer analyzer = new NGramAnalyzer(); > + leafReader = populateSampleIndex(analyzer); > + checkCorrectClassification(new > CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, > textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT); > + } finally { > + if (leafReader != null) { > + leafReader.close(); > + } > + } > } > > private class NGramAnalyzer extends Analyzer { > @@ -59,9 +86,4 @@ public class SimpleNaiveBayesClassifierT > } > } > > - @Test > - public void testPerformance() throws Exception { > - checkPerformance(new SimpleNaiveBayesClassifier(), new > MockAnalyzer(random()), categoryFieldName); > - } > - > } > > >
