Did you mean to remove the javadocs on BooleanPerceptronClassifier? The new
public constructor has none, and this breaks precommit...

On Thu, Apr 30, 2015 at 7:12 AM, <[email protected]> wrote:

> Author: tommaso
> Date: Thu Apr 30 14:12:03 2015
> New Revision: 1676997
>
> URL: http://svn.apache.org/r1676997
> Log:
> LUCENE-6045 - refactor Classifier API to work better with multithreading
>
> Modified:
>
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
>
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
>
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java
>
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
>
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
>
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
>
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java
>
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java
>
> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
>
> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java
>
> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
>
> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
>
> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
>
> Modified:
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
> URL:
> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>
> ==============================================================================
> ---
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
> (original)
> +++
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
> Thu Apr 30 14:12:03 2015
> @@ -58,76 +58,14 @@ import org.apache.lucene.util.fst.Util;
>   */
>  public class BooleanPerceptronClassifier implements Classifier<Boolean> {
>
> -  private Double threshold;
> -  private final Integer batchSize;
> -  private Terms textTerms;
> -  private Analyzer analyzer;
> -  private String textFieldName;
> +  private final Double threshold;
> +  private final Terms textTerms;
> +  private final Analyzer analyzer;
> +  private final String textFieldName;
>    private FST<Long> fst;
>
> -  /**
> -   * Create a {@link BooleanPerceptronClassifier}
> -   *
> -   * @param threshold the binary threshold for perceptron output
> evaluation
> -   */
> -  public BooleanPerceptronClassifier(Double threshold, Integer batchSize)
> {
> -    this.threshold = threshold;
> -    this.batchSize = batchSize;
> -  }
> -
> -  /**
> -   * Default constructor, no batch updates of FST, perceptron threshold is
> -   * calculated via underlying index metrics during
> -   * {@link #train(org.apache.lucene.index.LeafReader, String, String,
> org.apache.lucene.analysis.Analyzer)
> -   * training}
> -   */
> -  public BooleanPerceptronClassifier() {
> -    batchSize = 1;
> -  }
> -
> -  /**
> -   * {@inheritDoc}
> -   */
> -  @Override
> -  public ClassificationResult<Boolean> assignClass(String text)
> -      throws IOException {
> -    if (textTerms == null) {
> -      throw new IOException("You must first call Classifier#train");
> -    }
> -    Long output = 0l;
> -    try (TokenStream tokenStream = analyzer.tokenStream(textFieldName,
> text)) {
> -      CharTermAttribute charTermAttribute = tokenStream
> -          .addAttribute(CharTermAttribute.class);
> -      tokenStream.reset();
> -      while (tokenStream.incrementToken()) {
> -        String s = charTermAttribute.toString();
> -        Long d = Util.get(fst, new BytesRef(s));
> -        if (d != null) {
> -          output += d;
> -        }
> -      }
> -      tokenStream.end();
> -    }
> -
> -    double score = 1 - Math.exp(-1 * Math.abs(threshold -
> output.doubleValue()) / threshold);
> -    return new ClassificationResult<>(output >= threshold, score);
> -  }
> -
> -  /**
> -   * {@inheritDoc}
> -   */
> -  @Override
> -  public void train(LeafReader leafReader, String textFieldName,
> -                    String classFieldName, Analyzer analyzer) throws
> IOException {
> -    train(leafReader, textFieldName, classFieldName, analyzer, null);
> -  }
> -
> -  /**
> -   * {@inheritDoc}
> -   */
> -  @Override
> -  public void train(LeafReader leafReader, String textFieldName,
> -                    String classFieldName, Analyzer analyzer, Query
> query) throws IOException {
> +  public BooleanPerceptronClassifier(LeafReader leafReader, String
> textFieldName, String classFieldName, Analyzer analyzer,
> +                                     Query query, Integer batchSize,
> Double threshold) throws IOException {
>      this.textTerms = MultiFields.getTerms(leafReader, textFieldName);
>
>      if (textTerms == null) {
> @@ -144,9 +82,11 @@ public class BooleanPerceptronClassifier
>          this.threshold = (double) sumDocFreq / 2d;
>        } else {
>          throw new IOException(
> -            "threshold cannot be assigned since term vectors for field "
> -                + textFieldName + " do not exist");
> +                "threshold cannot be assigned since term vectors for
> field "
> +                        + textFieldName + " do not exist");
>        }
> +    } else {
> +      this.threshold = threshold;
>      }
>
>      // TODO : remove this map as soon as we have a writable FST
> @@ -170,7 +110,7 @@ public class BooleanPerceptronClassifier
>      }
>      // run the search and use stored field values
>      for (ScoreDoc scoreDoc : indexSearcher.search(q,
> -        Integer.MAX_VALUE).scoreDocs) {
> +            Integer.MAX_VALUE).scoreDocs) {
>        StoredDocument doc = indexSearcher.doc(scoreDoc.doc);
>
>        StorableField textField = doc.getField(textFieldName);
> @@ -187,7 +127,7 @@ public class BooleanPerceptronClassifier
>          long modifier = correctClass.compareTo(assignedClass);
>          if (modifier != 0) {
>            updateWeights(leafReader, scoreDoc.doc, assignedClass,
> -                weights, modifier, batchCount % batchSize == 0);
> +                  weights, modifier, batchCount % batchSize == 0);
>          }
>          batchCount++;
>        }
> @@ -195,11 +135,6 @@ public class BooleanPerceptronClassifier
>      weights.clear(); // free memory while waiting for GC
>    }
>
> -  @Override
> -  public void train(LeafReader leafReader, String[] textFieldNames,
> String classFieldName, Analyzer analyzer, Query query) throws IOException {
> -    throw new IOException("training with multiple fields not supported by
> boolean perceptron classifier");
> -  }
> -
>    private void updateWeights(LeafReader leafReader,
>                               int docId, Boolean assignedClass,
> SortedMap<String, Double> weights,
>                               double modifier, boolean updateFST) throws
> IOException {
> @@ -210,7 +145,7 @@ public class BooleanPerceptronClassifier
>
>      if (terms == null) {
>        throw new IOException("term vectors must be stored for field "
> -          + textFieldName);
> +              + textFieldName);
>      }
>
>      TermsEnum termsEnum = terms.iterator();
> @@ -240,17 +175,46 @@ public class BooleanPerceptronClassifier
>      for (Map.Entry<String, Double> entry : weights.entrySet()) {
>        scratchBytes.copyChars(entry.getKey());
>        fstBuilder.add(Util.toIntsRef(scratchBytes.get(), scratchInts),
> entry
> -          .getValue().longValue());
> +              .getValue().longValue());
>      }
>      fst = fstBuilder.finish();
>    }
>
> +
> +  /**
> +   * {@inheritDoc}
> +   */
> +  @Override
> +  public ClassificationResult<Boolean> assignClass(String text)
> +          throws IOException {
> +    if (textTerms == null) {
> +      throw new IOException("You must first call Classifier#train");
> +    }
> +    Long output = 0l;
> +    try (TokenStream tokenStream = analyzer.tokenStream(textFieldName,
> text)) {
> +      CharTermAttribute charTermAttribute = tokenStream
> +              .addAttribute(CharTermAttribute.class);
> +      tokenStream.reset();
> +      while (tokenStream.incrementToken()) {
> +        String s = charTermAttribute.toString();
> +        Long d = Util.get(fst, new BytesRef(s));
> +        if (d != null) {
> +          output += d;
> +        }
> +      }
> +      tokenStream.end();
> +    }
> +
> +    double score = 1 - Math.exp(-1 * Math.abs(threshold -
> output.doubleValue()) / threshold);
> +    return new ClassificationResult<>(output >= threshold, score);
> +  }
> +
>    /**
>     * {@inheritDoc}
>     */
>    @Override
>    public List<ClassificationResult<Boolean>> getClasses(String text)
> -      throws IOException {
> +          throws IOException {
>      throw new RuntimeException("not implemented");
>    }
>
> @@ -259,7 +223,7 @@ public class BooleanPerceptronClassifier
>     */
>    @Override
>    public List<ClassificationResult<Boolean>> getClasses(String text, int
> max)
> -      throws IOException {
> +          throws IOException {
>      throw new RuntimeException("not implemented");
>    }
>
>
> Modified:
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
> URL:
> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>
> ==============================================================================
> ---
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
> (original)
> +++
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
> Thu Apr 30 14:12:03 2015
> @@ -49,50 +49,30 @@ import org.apache.lucene.util.BytesRef;
>   */
>  public class CachingNaiveBayesClassifier extends
> SimpleNaiveBayesClassifier {
>    //for caching classes this will be the classification class list
> -  private ArrayList<BytesRef> cclasses = new ArrayList<>();
> +  private final ArrayList<BytesRef> cclasses = new ArrayList<>();
>    // it's a term-inmap style map, where the inmap contains class-hit
> pairs to the
>    // upper term
> -  private Map<String, Map<BytesRef, Integer>> termCClassHitCache = new
> HashMap<>();
> +  private final Map<String, Map<BytesRef, Integer>> termCClassHitCache =
> new HashMap<>();
>    // the term frequency in classes
> -  private Map<BytesRef, Double> classTermFreq = new HashMap<>();
> +  private final Map<BytesRef, Double> classTermFreq = new HashMap<>();
>    private boolean justCachedTerms;
>    private int docsWithClassSize;
>
>    /**
> -   * Creates a new NaiveBayes classifier with inside caching. Note that
> you must
> -   * call {@link #train(org.apache.lucene.index.LeafReader, String,
> String, Analyzer) train()} before
> -   * you can classify any documents. If you want less memory usage you
> could
> +   * Creates a new NaiveBayes classifier with inside caching. If you want
> less memory usage you could
>     * call {@link #reInitCache(int, boolean) reInitCache()}.
>     */
> -  public CachingNaiveBayesClassifier() {
> -  }
> -
> -  /**
> -   * {@inheritDoc}
> -   */
> -  @Override
> -  public void train(LeafReader leafReader, String textFieldName, String
> classFieldName, Analyzer analyzer) throws IOException {
> -    train(leafReader, textFieldName, classFieldName, analyzer, null);
> -  }
> -
> -  /**
> -   * {@inheritDoc}
> -   */
> -  @Override
> -  public void train(LeafReader leafReader, String textFieldName, String
> classFieldName, Analyzer analyzer, Query query) throws IOException {
> -    train(leafReader, new String[]{textFieldName}, classFieldName,
> analyzer, query);
> -  }
> -
> -  /**
> -   * {@inheritDoc}
> -   */
> -  @Override
> -  public void train(LeafReader leafReader, String[] textFieldNames,
> String classFieldName, Analyzer analyzer, Query query) throws IOException {
> -    super.train(leafReader, textFieldNames, classFieldName, analyzer,
> query);
> +  public CachingNaiveBayesClassifier(LeafReader leafReader, Analyzer
> analyzer, Query query, String classFieldName, String... textFieldNames) {
> +    super(leafReader, analyzer, query, classFieldName, textFieldNames);
>      // building the cache
> -    reInitCache(0, true);
> +    try {
> +      reInitCache(0, true);
> +    } catch (IOException e) {
> +      throw new RuntimeException(e);
> +    }
>    }
>
> +
>    private List<ClassificationResult<BytesRef>>
> assignClassNormalizedList(String inputDocument) throws IOException {
>      if (leafReader == null) {
>        throw new IOException("You must first call Classifier#train");
>
> Modified:
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java
> URL:
> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>
> ==============================================================================
> ---
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java
> (original)
> +++
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java
> Thu Apr 30 14:12:03 2015
> @@ -18,17 +18,19 @@ package org.apache.lucene.classification
>
>  /**
>   * The result of a call to {@link Classifier#assignClass(String)} holding
> an assigned class of type <code>T</code> and a score.
> + *
>   * @lucene.experimental
>   */
> -public class ClassificationResult<T> implements
> Comparable<ClassificationResult<T>>{
> +public class ClassificationResult<T> implements
> Comparable<ClassificationResult<T>> {
>
>    private final T assignedClass;
>    private double score;
>
>    /**
>     * Constructor
> +   *
>     * @param assignedClass the class <code>T</code> assigned by a {@link
> Classifier}
> -   * @param score the score for the assignedClass as a <code>double</code>
> +   * @param score         the score for the assignedClass as a
> <code>double</code>
>     */
>    public ClassificationResult(T assignedClass, double score) {
>      this.assignedClass = assignedClass;
> @@ -37,6 +39,7 @@ public class ClassificationResult<T> imp
>
>    /**
>     * retrieve the result class
> +   *
>     * @return a <code>T</code> representing an assigned class
>     */
>    public T getAssignedClass() {
> @@ -45,14 +48,16 @@ public class ClassificationResult<T> imp
>
>    /**
>     * retrieve the result score
> +   *
>     * @return a <code>double</code> representing a result score
>     */
>    public double getScore() {
>      return score;
>    }
> -
> +
>    /**
>     * set the score value
> +   *
>     * @param score the score for the assignedClass as a <code>double</code>
>     */
>    public void setScore(double score) {
>
> Modified:
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
> URL:
> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>
> ==============================================================================
> ---
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
> (original)
> +++
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
> Thu Apr 30 14:12:03 2015
> @@ -22,7 +22,6 @@ import java.util.List;
>  import org.apache.lucene.analysis.Analyzer;
>  import org.apache.lucene.index.LeafReader;
>  import org.apache.lucene.search.Query;
> -import org.apache.lucene.util.BytesRef;
>
>  /**
>   * A classifier, see <code>
> http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>, which
> assign classes of type
> @@ -39,7 +38,7 @@ public interface Classifier<T> {
>     * @return a {@link ClassificationResult} holding assigned class of
> type <code>T</code> and score
>     * @throws IOException If there is a low-level I/O error.
>     */
> -  public ClassificationResult<T> assignClass(String text) throws
> IOException;
> +  ClassificationResult<T> assignClass(String text) throws IOException;
>
>    /**
>     * Get all the classes (sorted by score, descending) assigned to the
> given text String.
> @@ -48,7 +47,7 @@ public interface Classifier<T> {
>     * @return the whole list of {@link ClassificationResult}, the classes
> and scores. Returns <code>null</code> if the classifier can't make lists.
>     * @throws IOException If there is a low-level I/O error.
>     */
> -  public List<ClassificationResult<T>> getClasses(String text) throws
> IOException;
> +  List<ClassificationResult<T>> getClasses(String text) throws
> IOException;
>
>    /**
>     * Get the first <code>max</code> classes (sorted by score, descending)
> assigned to the given text String.
> @@ -58,44 +57,6 @@ public interface Classifier<T> {
>     * @return the whole list of {@link ClassificationResult}, the classes
> and scores. Cut for "max" number of elements. Returns <code>null</code> if
> the classifier can't make lists.
>     * @throws IOException If there is a low-level I/O error.
>     */
> -  public List<ClassificationResult<T>> getClasses(String text, int max)
> throws IOException;
> -
> -  /**
> -   * Train the classifier using the underlying Lucene index
> -   *
> -   * @param leafReader   the reader to use to access the Lucene index
> -   * @param textFieldName  the name of the field used to compare documents
> -   * @param classFieldName the name of the field containing the class
> assigned to documents
> -   * @param analyzer       the analyzer used to tokenize / filter the
> unseen text
> -   * @throws IOException If there is a low-level I/O error.
> -   */
> -  public void train(LeafReader leafReader, String textFieldName, String
> classFieldName, Analyzer analyzer)
> -      throws IOException;
> -
> -  /**
> -   * Train the classifier using the underlying Lucene index
> -   *
> -   * @param leafReader   the reader to use to access the Lucene index
> -   * @param textFieldName  the name of the field used to compare documents
> -   * @param classFieldName the name of the field containing the class
> assigned to documents
> -   * @param analyzer       the analyzer used to tokenize / filter the
> unseen text
> -   * @param query          the query to filter which documents use for
> training
> -   * @throws IOException If there is a low-level I/O error.
> -   */
> -  public void train(LeafReader leafReader, String textFieldName, String
> classFieldName, Analyzer analyzer, Query query)
> -      throws IOException;
> -
> -  /**
> -   * Train the classifier using the underlying Lucene index
> -   *
> -   * @param leafReader   the reader to use to access the Lucene index
> -   * @param textFieldNames the names of the fields to be used to compare
> documents
> -   * @param classFieldName the name of the field containing the class
> assigned to documents
> -   * @param analyzer       the analyzer used to tokenize / filter the
> unseen text
> -   * @param query          the query to filter which documents use for
> training
> -   * @throws IOException If there is a low-level I/O error.
> -   */
> -  public void train(LeafReader leafReader, String[] textFieldNames,
> String classFieldName, Analyzer analyzer, Query query)
> -      throws IOException;
> +  List<ClassificationResult<T>> getClasses(String text, int max) throws
> IOException;
>
>  }
>
> Modified:
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
> URL:
> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>
> ==============================================================================
> ---
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
> (original)
> +++
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
> Thu Apr 30 14:12:03 2015
> @@ -26,6 +26,7 @@ import java.util.Map;
>
>  import org.apache.lucene.analysis.Analyzer;
>  import org.apache.lucene.index.LeafReader;
> +import org.apache.lucene.index.StorableField;
>  import org.apache.lucene.index.Term;
>  import org.apache.lucene.queries.mlt.MoreLikeThis;
>  import org.apache.lucene.search.BooleanClause;
> @@ -45,37 +46,31 @@ import org.apache.lucene.util.BytesRef;
>   */
>  public class KNearestNeighborClassifier implements Classifier<BytesRef> {
>
> -  private MoreLikeThis mlt;
> -  private String[] textFieldNames;
> -  private String classFieldName;
> -  private IndexSearcher indexSearcher;
> +  private final MoreLikeThis mlt;
> +  private final String[] textFieldNames;
> +  private final String classFieldName;
> +  private final IndexSearcher indexSearcher;
>    private final int k;
> -  private Query query;
> +  private final Query query;
>
> -  private int minDocsFreq;
> -  private int minTermFreq;
> -
> -  /**
> -   * Create a {@link Classifier} using kNN algorithm
> -   *
> -   * @param k the number of neighbors to analyze as an <code>int</code>
> -   */
> -  public KNearestNeighborClassifier(int k) {
> +  public KNearestNeighborClassifier(LeafReader leafReader, Analyzer
> analyzer, Query query, int k, int minDocsFreq,
> +                                    int minTermFreq, String
> classFieldName, String... textFieldNames) {
> +    this.textFieldNames = textFieldNames;
> +    this.classFieldName = classFieldName;
> +    this.mlt = new MoreLikeThis(leafReader);
> +    this.mlt.setAnalyzer(analyzer);
> +    this.mlt.setFieldNames(textFieldNames);
> +    this.indexSearcher = new IndexSearcher(leafReader);
> +    if (minDocsFreq > 0) {
> +      mlt.setMinDocFreq(minDocsFreq);
> +    }
> +    if (minTermFreq > 0) {
> +      mlt.setMinTermFreq(minTermFreq);
> +    }
> +    this.query = query;
>      this.k = k;
>    }
>
> -  /**
> -   * Create a {@link Classifier} using kNN algorithm
> -   *
> -   * @param k           the number of neighbors to analyze as an
> <code>int</code>
> -   * @param minDocsFreq the minimum number of docs frequency for MLT to
> be set with {@link MoreLikeThis#setMinDocFreq(int)}
> -   * @param minTermFreq the minimum number of term frequency for MLT to
> be set with {@link MoreLikeThis#setMinTermFreq(int)}
> -   */
> -  public KNearestNeighborClassifier(int k, int minDocsFreq, int
> minTermFreq) {
> -    this.k = k;
> -    this.minDocsFreq = minDocsFreq;
> -    this.minTermFreq = minTermFreq;
> -  }
>
>    /**
>     * {@inheritDoc}
> @@ -136,12 +131,15 @@ public class KNearestNeighborClassifier
>    private List<ClassificationResult<BytesRef>>
> buildListFromTopDocs(TopDocs topDocs) throws IOException {
>      Map<BytesRef, Integer> classCounts = new HashMap<>();
>      for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
> -      BytesRef cl = new
> BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue());
> -      Integer count = classCounts.get(cl);
> -      if (count != null) {
> -        classCounts.put(cl, count + 1);
> -      } else {
> -        classCounts.put(cl, 1);
> +      StorableField storableField =
> indexSearcher.doc(scoreDoc.doc).getField(classFieldName);
> +      if (storableField != null) {
> +        BytesRef cl = new BytesRef(storableField.stringValue());
> +        Integer count = classCounts.get(cl);
> +        if (count != null) {
> +          classCounts.put(cl, count + 1);
> +        } else {
> +          classCounts.put(cl, 1);
> +        }
>        }
>      }
>      List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
> @@ -161,39 +159,4 @@ public class KNearestNeighborClassifier
>      return returnList;
>    }
>
> -  /**
> -   * {@inheritDoc}
> -   */
> -  @Override
> -  public void train(LeafReader leafReader, String textFieldName, String
> classFieldName, Analyzer analyzer) throws IOException {
> -    train(leafReader, textFieldName, classFieldName, analyzer, null);
> -  }
> -
> -  /**
> -   * {@inheritDoc}
> -   */
> -  @Override
> -  public void train(LeafReader leafReader, String textFieldName, String
> classFieldName, Analyzer analyzer, Query query) throws IOException {
> -    train(leafReader, new String[]{textFieldName}, classFieldName,
> analyzer, query);
> -  }
> -
> -  /**
> -   * {@inheritDoc}
> -   */
> -  @Override
> -  public void train(LeafReader leafReader, String[] textFieldNames,
> String classFieldName, Analyzer analyzer, Query query) throws IOException {
> -    this.textFieldNames = textFieldNames;
> -    this.classFieldName = classFieldName;
> -    mlt = new MoreLikeThis(leafReader);
> -    mlt.setAnalyzer(analyzer);
> -    mlt.setFieldNames(textFieldNames);
> -    indexSearcher = new IndexSearcher(leafReader);
> -    if (minDocsFreq > 0) {
> -      mlt.setMinDocFreq(minDocsFreq);
> -    }
> -    if (minTermFreq > 0) {
> -      mlt.setMinTermFreq(minTermFreq);
> -    }
> -    this.query = query;
> -  }
>  }
>
> Modified:
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
> URL:
> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>
> ==============================================================================
> ---
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
> (original)
> +++
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
> Thu Apr 30 14:12:03 2015
> @@ -51,64 +51,38 @@ public class SimpleNaiveBayesClassifier
>     * {@link org.apache.lucene.index.LeafReader} used to access the {@link
> org.apache.lucene.classification.Classifier}'s
>     * index
>     */
> -  protected LeafReader leafReader;
> +  protected final LeafReader leafReader;
>
>    /**
>     * names of the fields to be used as input text
>     */
> -  protected String[] textFieldNames;
> +  protected final String[] textFieldNames;
>
>    /**
>     * name of the field to be used as a class / category output
>     */
> -  protected String classFieldName;
> +  protected final String classFieldName;
>
>    /**
>     * {@link org.apache.lucene.analysis.Analyzer} to be used for
> tokenizing unseen input text
>     */
> -  protected Analyzer analyzer;
> +  protected final Analyzer analyzer;
>
>    /**
>     * {@link org.apache.lucene.search.IndexSearcher} to run searches on
> the index for retrieving frequencies
>     */
> -  protected IndexSearcher indexSearcher;
> +  protected final IndexSearcher indexSearcher;
>
>    /**
>     * {@link org.apache.lucene.search.Query} used to eventually filter the
> document set to be used to classify
>     */
> -  protected Query query;
> +  protected final Query query;
>
>    /**
>     * Creates a new NaiveBayes classifier.
> -   * Note that you must call {@link
> #train(org.apache.lucene.index.LeafReader, String, String, Analyzer)
> train()} before you can
>     * classify any documents.
>     */
> -  public SimpleNaiveBayesClassifier() {
> -  }
> -
> -  /**
> -   * {@inheritDoc}
> -   */
> -  @Override
> -  public void train(LeafReader leafReader, String textFieldName, String
> classFieldName, Analyzer analyzer) throws IOException {
> -    train(leafReader, textFieldName, classFieldName, analyzer, null);
> -  }
> -
> -  /**
> -   * {@inheritDoc}
> -   */
> -  @Override
> -  public void train(LeafReader leafReader, String textFieldName, String
> classFieldName, Analyzer analyzer, Query query)
> -      throws IOException {
> -    train(leafReader, new String[]{textFieldName}, classFieldName,
> analyzer, query);
> -  }
> -
> -  /**
> -   * {@inheritDoc}
> -   */
> -  @Override
> -  public void train(LeafReader leafReader, String[] textFieldNames,
> String classFieldName, Analyzer analyzer, Query query)
> -      throws IOException {
> +  public SimpleNaiveBayesClassifier(LeafReader leafReader, Analyzer
> analyzer, Query query, String classFieldName, String... textFieldNames) {
>      this.leafReader = leafReader;
>      this.indexSearcher = new IndexSearcher(this.leafReader);
>      this.textFieldNames = textFieldNames;
>
> Modified:
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java
> URL:
> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>
> ==============================================================================
> ---
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java
> (original)
> +++
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java
> Thu Apr 30 14:12:03 2015
> @@ -18,7 +18,7 @@
>  /**
>   * Uses already seen data (the indexed documents) to classify new
> documents.
>   * <p>
> - * Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest
> + * Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest
>   * Neighbor classifier and a Perceptron based classifier.
>   */
>  package org.apache.lucene.classification;
>
> Modified:
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java
> URL:
> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>
> ==============================================================================
> ---
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java
> (original)
> +++
> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java
> Thu Apr 30 14:12:03 2015
> @@ -33,7 +33,8 @@ public class DocToDoubleVectorUtils {
>
>    /**
>     * create a sparse <code>Double</code> vector given doc and field term
> vectors using local frequency of the terms in the doc
> -   * @param docTerms term vectors for a given document
> +   *
> +   * @param docTerms   term vectors for a given document
>     * @param fieldTerms field term vectors
>     * @return a sparse vector of <code>Double</code>s as an array
>     * @throws IOException in case accessing the underlying index fails
> @@ -54,8 +55,7 @@ public class DocToDoubleVectorUtils {
>          if (seekStatus.equals(TermsEnum.SeekStatus.FOUND)) {
>            long termFreqLocal = docTermsEnum.totalTermFreq(); // the total
> number of occurrences of this term in the given document
>            freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
> -        }
> -        else {
> +        } else {
>            freqVector[i] = 0d;
>          }
>          i++;
> @@ -66,6 +66,7 @@ public class DocToDoubleVectorUtils {
>
>    /**
>     * create a dense <code>Double</code> vector given doc and field term vectors using local frequency of the terms in the doc
> +   *
>     * @param docTerms term vectors for a given document
>     * @return a dense vector of <code>Double</code>s as an array
>     * @throws IOException in case accessing the underlying index fails
> @@ -73,16 +74,16 @@ public class DocToDoubleVectorUtils {
>    public static Double[] toDenseLocalFreqDoubleArray(Terms docTerms) throws IOException {
>      Double[] freqVector = null;
>      if (docTerms != null) {
> -        freqVector = new Double[(int) docTerms.size()];
> -        int i = 0;
> -        TermsEnum docTermsEnum = docTerms.iterator();
> +      freqVector = new Double[(int) docTerms.size()];
> +      int i = 0;
> +      TermsEnum docTermsEnum = docTerms.iterator();
>
> -        while (docTermsEnum.next() != null) {
> -            long termFreqLocal = docTermsEnum.totalTermFreq(); // the total number of occurrences of this term in the given document
> -            freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
> -            i++;
> -        }
> +      while (docTermsEnum.next() != null) {
> +        long termFreqLocal = docTermsEnum.totalTermFreq(); // the total number of occurrences of this term in the given document
> +        freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
> +        i++;
> +      }
>      }
>      return freqVector;
> -}
> +  }
>  }
>
> Modified:
> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
> URL:
> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>
> ==============================================================================
> ---
> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
> (original)
> +++
> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
> Thu Apr 30 14:12:03 2015
> @@ -17,6 +17,8 @@
>  package org.apache.lucene.classification;
>
>  import org.apache.lucene.analysis.MockAnalyzer;
> +import org.apache.lucene.index.LeafReader;
> +import org.apache.lucene.index.SlowCompositeReaderWrapper;
>  import org.apache.lucene.index.Term;
>  import org.apache.lucene.search.TermQuery;
>  import org.junit.Test;
> @@ -28,22 +30,45 @@ public class BooleanPerceptronClassifier
>
>    @Test
>    public void testBasicUsage() throws Exception {
> -    checkCorrectClassification(new BooleanPerceptronClassifier(),
> TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName,
> booleanFieldName);
> +    LeafReader leafReader = null;
> +    try {
> +      MockAnalyzer analyzer = new MockAnalyzer(random());
> +      leafReader = populateSampleIndex(analyzer);
> +      checkCorrectClassification(new
> BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName,
> analyzer, null, 1, null), TECHNOLOGY_INPUT, false);
> +    } finally {
> +      if (leafReader != null) {
> +        leafReader.close();
> +      }
> +    }
>    }
>
>    @Test
>    public void testExplicitThreshold() throws Exception {
> -    checkCorrectClassification(new BooleanPerceptronClassifier(100d, 1),
> TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName,
> booleanFieldName);
> +    LeafReader leafReader = null;
> +    try {
> +      MockAnalyzer analyzer = new MockAnalyzer(random());
> +      leafReader = populateSampleIndex(analyzer);
> +      checkCorrectClassification(new
> BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName,
> analyzer, null, 1, 100d), TECHNOLOGY_INPUT, false);
> +    } finally {
> +      if (leafReader != null) {
> +        leafReader.close();
> +      }
> +    }
>    }
>
>    @Test
>    public void testBasicUsageWithQuery() throws Exception {
> -    checkCorrectClassification(new BooleanPerceptronClassifier(),
> TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName,
> booleanFieldName, new TermQuery(new Term(textFieldName, "it")));
> -  }
> -
> -  @Test
> -  public void testPerformance() throws Exception {
> -    checkPerformance(new BooleanPerceptronClassifier(), new
> MockAnalyzer(random()), booleanFieldName);
> +    TermQuery query = new TermQuery(new Term(textFieldName, "it"));
> +    LeafReader leafReader = null;
> +    try {
> +      MockAnalyzer analyzer = new MockAnalyzer(random());
> +      leafReader = populateSampleIndex(analyzer);
> +      checkCorrectClassification(new
> BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName,
> analyzer, query, 1, null), TECHNOLOGY_INPUT, false);
> +    } finally {
> +      if (leafReader != null) {
> +        leafReader.close();
> +      }
> +    }
>    }
>
>  }
>
> Modified:
> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java
> URL:
> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>
> ==============================================================================
> ---
> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java
> (original)
> +++
> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java
> Thu Apr 30 14:12:03 2015
> @@ -23,6 +23,8 @@ import org.apache.lucene.analysis.Tokeni
>  import org.apache.lucene.analysis.core.KeywordTokenizer;
>  import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
>  import org.apache.lucene.analysis.reverse.ReverseStringFilter;
> +import org.apache.lucene.index.LeafReader;
> +import org.apache.lucene.index.SlowCompositeReaderWrapper;
>  import org.apache.lucene.index.Term;
>  import org.apache.lucene.search.TermQuery;
>  import org.apache.lucene.util.BytesRef;
> @@ -35,18 +37,46 @@ public class CachingNaiveBayesClassifier
>
>    @Test
>    public void testBasicUsage() throws Exception {
> -    checkCorrectClassification(new CachingNaiveBayesClassifier(),
> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()),
> textFieldName, categoryFieldName);
> -    checkCorrectClassification(new CachingNaiveBayesClassifier(),
> POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName,
> categoryFieldName);
> +    LeafReader leafReader = null;
> +    try {
> +      MockAnalyzer analyzer = new MockAnalyzer(random());
> +      leafReader = populateSampleIndex(analyzer);
> +      checkCorrectClassification(new
> CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName,
> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
> +      checkCorrectClassification(new
> CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName,
> textFieldName), POLITICS_INPUT, POLITICS_RESULT);
> +    } finally {
> +      if (leafReader != null) {
> +        leafReader.close();
> +      }
> +    }
>    }
>
>    @Test
>    public void testBasicUsageWithQuery() throws Exception {
> -    checkCorrectClassification(new CachingNaiveBayesClassifier(),
> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()),
> textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName,
> "it")));
> +    LeafReader leafReader = null;
> +    try {
> +      MockAnalyzer analyzer = new MockAnalyzer(random());
> +      leafReader = populateSampleIndex(analyzer);
> +      TermQuery query = new TermQuery(new Term(textFieldName, "it"));
> +      checkCorrectClassification(new
> CachingNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName,
> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
> +    } finally {
> +      if (leafReader != null) {
> +        leafReader.close();
> +      }
> +    }
>    }
>
>    @Test
>    public void testNGramUsage() throws Exception {
> -    checkCorrectClassification(new CachingNaiveBayesClassifier(),
> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName,
> categoryFieldName);
> +    LeafReader leafReader = null;
> +    try {
> +      NGramAnalyzer analyzer = new NGramAnalyzer();
> +      leafReader = populateSampleIndex(analyzer);
> +      checkCorrectClassification(new
> CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName,
> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
> +    } finally {
> +      if (leafReader != null) {
> +        leafReader.close();
> +      }
> +    }
>    }
>
>    private class NGramAnalyzer extends Analyzer {
> @@ -57,9 +87,4 @@ public class CachingNaiveBayesClassifier
>      }
>    }
>
> -  @Test
> -  public void testPerformance() throws Exception {
> -    checkPerformance(new CachingNaiveBayesClassifier(), new
> MockAnalyzer(random()), categoryFieldName);
> -  }
> -
>  }
>
> Modified:
> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
> URL:
> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>
> ==============================================================================
> ---
> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
> (original)
> +++
> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
> Thu Apr 30 14:12:03 2015
> @@ -41,14 +41,14 @@ import org.junit.Before;
>   */
>  public abstract class ClassificationTestBase<T> extends LuceneTestCase {
>    public final static String POLITICS_INPUT = "Here are some interesting
> questions and answers about Mitt Romney.. " +
> -      "If you don't know the answer to the question about Mitt Romney,
> then simply click on the answer below the question section.";
> +          "If you don't know the answer to the question about Mitt
> Romney, then simply click on the answer below the question section.";
>    public static final BytesRef POLITICS_RESULT = new BytesRef("politics");
>
>    public static final String TECHNOLOGY_INPUT = "Much is made of what the
> likes of Facebook, Google and Apple know about users." +
> -      " Truth is, Amazon may know more.";
> +          " Truth is, Amazon may know more.";
>    public static final BytesRef TECHNOLOGY_RESULT = new
> BytesRef("technology");
>
> -  private RandomIndexWriter indexWriter;
> +  protected RandomIndexWriter indexWriter;
>    private Directory dir;
>    private FieldType ft;
>
> @@ -79,53 +79,34 @@ public abstract class ClassificationTest
>      dir.close();
>    }
>
> -  protected void checkCorrectClassification(Classifier<T> classifier,
> String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName,
> String classFieldName) throws Exception {
> -    checkCorrectClassification(classifier, inputDoc, expectedResult,
> analyzer, textFieldName, classFieldName, null);
> +  protected void checkCorrectClassification(Classifier<T> classifier,
> String inputDoc, T expectedResult) throws Exception {
> +    ClassificationResult<T> classificationResult =
> classifier.assignClass(inputDoc);
> +    assertNotNull(classificationResult.getAssignedClass());
> +    assertEquals("got an assigned class of " +
> classificationResult.getAssignedClass(), expectedResult,
> classificationResult.getAssignedClass());
> +    double score = classificationResult.getScore();
> +    assertTrue("score should be between 0 and 1, got:" + score, score <=
> 1 && score >= 0);
>    }
>
> -  protected void checkCorrectClassification(Classifier<T> classifier,
> String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName,
> String classFieldName, Query query) throws Exception {
> -    LeafReader leafReader = null;
> -    try {
> -      populateSampleIndex(analyzer);
> -      leafReader =
> SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
> -      classifier.train(leafReader, textFieldName, classFieldName,
> analyzer, query);
> -      ClassificationResult<T> classificationResult =
> classifier.assignClass(inputDoc);
> -      assertNotNull(classificationResult.getAssignedClass());
> -      assertEquals("got an assigned class of " +
> classificationResult.getAssignedClass(), expectedResult,
> classificationResult.getAssignedClass());
> -      double score = classificationResult.getScore();
> -      assertTrue("score should be between 0 and 1, got:" + score, score
> <= 1 && score >= 0);
> -    } finally {
> -      if (leafReader != null)
> -        leafReader.close();
> -    }
> -  }
>    protected void checkOnlineClassification(Classifier<T> classifier,
> String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName,
> String classFieldName) throws Exception {
>      checkOnlineClassification(classifier, inputDoc, expectedResult,
> analyzer, textFieldName, classFieldName, null);
>    }
>
>    protected void checkOnlineClassification(Classifier<T> classifier,
> String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName,
> String classFieldName, Query query) throws Exception {
> -    LeafReader leafReader = null;
> -    try {
> -      populateSampleIndex(analyzer);
> -      leafReader =
> SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
> -      classifier.train(leafReader, textFieldName, classFieldName,
> analyzer, query);
> -      ClassificationResult<T> classificationResult =
> classifier.assignClass(inputDoc);
> -      assertNotNull(classificationResult.getAssignedClass());
> -      assertEquals("got an assigned class of " +
> classificationResult.getAssignedClass(), expectedResult,
> classificationResult.getAssignedClass());
> -      double score = classificationResult.getScore();
> -      assertTrue("score should be between 0 and 1, got: " + score, score
> <= 1 && score >= 0);
> -      updateSampleIndex();
> -      ClassificationResult<T> secondClassificationResult =
> classifier.assignClass(inputDoc);
> -      assertEquals(classificationResult.getAssignedClass(),
> secondClassificationResult.getAssignedClass());
> -      assertEquals(Double.valueOf(score),
> Double.valueOf(secondClassificationResult.getScore()));
> -
> -    } finally {
> -      if (leafReader != null)
> -        leafReader.close();
> -    }
> +    populateSampleIndex(analyzer);
> +
> +    ClassificationResult<T> classificationResult =
> classifier.assignClass(inputDoc);
> +    assertNotNull(classificationResult.getAssignedClass());
> +    assertEquals("got an assigned class of " +
> classificationResult.getAssignedClass(), expectedResult,
> classificationResult.getAssignedClass());
> +    double score = classificationResult.getScore();
> +    assertTrue("score should be between 0 and 1, got: " + score, score <=
> 1 && score >= 0);
> +    updateSampleIndex();
> +    ClassificationResult<T> secondClassificationResult =
> classifier.assignClass(inputDoc);
> +    assertEquals(classificationResult.getAssignedClass(),
> secondClassificationResult.getAssignedClass());
> +    assertEquals(Double.valueOf(score),
> Double.valueOf(secondClassificationResult.getScore()));
> +
>    }
>
> -  private void populateSampleIndex(Analyzer analyzer) throws IOException {
> +  protected LeafReader populateSampleIndex(Analyzer analyzer) throws
> IOException {
>      indexWriter.close();
>      indexWriter = new RandomIndexWriter(random(), dir,
> newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
>      indexWriter.commit();
> @@ -134,8 +115,8 @@ public abstract class ClassificationTest
>
>      Document doc = new Document();
>      text = "The traveling press secretary for Mitt Romney lost his cool
> and cursed at reporters " +
> -        "who attempted to ask questions of the Republican presidential
> candidate in a public plaza near the Tomb of " +
> -        "the Unknown Soldier in Warsaw Tuesday.";
> +            "who attempted to ask questions of the Republican
> presidential candidate in a public plaza near the Tomb of " +
> +            "the Unknown Soldier in Warsaw Tuesday.";
>      doc.add(new Field(textFieldName, text, ft));
>      doc.add(new Field(categoryFieldName, "politics", ft));
>      doc.add(new Field(booleanFieldName, "true", ft));
> @@ -144,7 +125,7 @@ public abstract class ClassificationTest
>
>      doc = new Document();
>      text = "Mitt Romney seeks to assure Israel and Iran, as well as
> Jewish voters in the United" +
> -        " States, that he will be tougher against Iran's nuclear
> ambitions than President Barack Obama.";
> +            " States, that he will be tougher against Iran's nuclear
> ambitions than President Barack Obama.";
>      doc.add(new Field(textFieldName, text, ft));
>      doc.add(new Field(categoryFieldName, "politics", ft));
>      doc.add(new Field(booleanFieldName, "true", ft));
> @@ -152,8 +133,8 @@ public abstract class ClassificationTest
>
>      doc = new Document();
>      text = "And there's a threshold question that he has to answer for
> the American people and " +
> -        "that's whether he is prepared to be commander-in-chief,\" she
> continued. \"As we look to the past events, we " +
> -        "know that this raises some questions about his preparedness and
> we'll see how the rest of his trip goes.\"";
> +            "that's whether he is prepared to be commander-in-chief,\"
> she continued. \"As we look to the past events, we " +
> +            "know that this raises some questions about his preparedness
> and we'll see how the rest of his trip goes.\"";
>      doc.add(new Field(textFieldName, text, ft));
>      doc.add(new Field(categoryFieldName, "politics", ft));
>      doc.add(new Field(booleanFieldName, "true", ft));
> @@ -161,8 +142,8 @@ public abstract class ClassificationTest
>
>      doc = new Document();
>      text = "Still, when it comes to gun policy, many congressional
> Democrats have \"decided to " +
> -        "keep quiet and not go there,\" said Alan Lizotte, dean and
> professor at the State University of New York at " +
> -        "Albany's School of Criminal Justice.";
> +            "keep quiet and not go there,\" said Alan Lizotte, dean and
> professor at the State University of New York at " +
> +            "Albany's School of Criminal Justice.";
>      doc.add(new Field(textFieldName, text, ft));
>      doc.add(new Field(categoryFieldName, "politics", ft));
>      doc.add(new Field(booleanFieldName, "true", ft));
> @@ -170,8 +151,8 @@ public abstract class ClassificationTest
>
>      doc = new Document();
>      text = "Standing amongst the thousands of people at the state
> Capitol, Jorstad, director of " +
> -        "technology at the University of Wisconsin-La Crosse, documented
> the historic moment and shared it with the " +
> -        "world through the Internet.";
> +            "technology at the University of Wisconsin-La Crosse,
> documented the historic moment and shared it with the " +
> +            "world through the Internet.";
>      doc.add(new Field(textFieldName, text, ft));
>      doc.add(new Field(categoryFieldName, "technology", ft));
>      doc.add(new Field(booleanFieldName, "false", ft));
> @@ -179,7 +160,7 @@ public abstract class ClassificationTest
>
>      doc = new Document();
>      text = "So, about all those experts and analysts who've spent the
> past year or so saying " +
> -        "Facebook was going to make a phone. A new expert has stepped
> forward to say it's not going to happen.";
> +            "Facebook was going to make a phone. A new expert has stepped
> forward to say it's not going to happen.";
>      doc.add(new Field(textFieldName, text, ft));
>      doc.add(new Field(categoryFieldName, "technology", ft));
>      doc.add(new Field(booleanFieldName, "false", ft));
> @@ -187,8 +168,8 @@ public abstract class ClassificationTest
>
>      doc = new Document();
>      text = "More than 400 million people trust Google with their e-mail,
> and 50 million store files" +
> -        " in the cloud using the Dropbox service. People manage their
> bank accounts, pay bills, trade stocks and " +
> -        "generally transfer or store huge volumes of personal data
> online.";
> +            " in the cloud using the Dropbox service. People manage their
> bank accounts, pay bills, trade stocks and " +
> +            "generally transfer or store huge volumes of personal data
> online.";
>      doc.add(new Field(textFieldName, text, ft));
>      doc.add(new Field(categoryFieldName, "technology", ft));
>      doc.add(new Field(booleanFieldName, "false", ft));
> @@ -200,22 +181,15 @@ public abstract class ClassificationTest
>      indexWriter.addDocument(doc);
>
>      indexWriter.commit();
> +    return SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
>    }
>
>    protected void checkPerformance(Classifier<T> classifier, Analyzer
> analyzer, String classFieldName) throws Exception {
> -    LeafReader leafReader = null;
>      long trainStart = System.currentTimeMillis();
> -    try {
> -      populatePerformanceIndex(analyzer);
> -      leafReader =
> SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
> -      classifier.train(leafReader, textFieldName, classFieldName,
> analyzer);
> -      long trainEnd = System.currentTimeMillis();
> -      long trainTime = trainEnd - trainStart;
> -      assertTrue("training took more than 2 mins : " + trainTime / 1000 +
> "s", trainTime < 120000);
> -    } finally {
> -      if (leafReader != null)
> -        leafReader.close();
> -    }
> +    populatePerformanceIndex(analyzer);
> +    long trainEnd = System.currentTimeMillis();
> +    long trainTime = trainEnd - trainStart;
> +    assertTrue("training took more than 2 mins : " + trainTime / 1000 +
> "s", trainTime < 120000);
>    }
>
>    private void populatePerformanceIndex(Analyzer analyzer) throws
> IOException {
>
> Modified:
> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
> URL:
> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>
> ==============================================================================
> ---
> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
> (original)
> +++
> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
> Thu Apr 30 14:12:03 2015
> @@ -17,6 +17,8 @@
>  package org.apache.lucene.classification;
>
>  import org.apache.lucene.analysis.MockAnalyzer;
> +import org.apache.lucene.index.LeafReader;
> +import org.apache.lucene.index.SlowCompositeReaderWrapper;
>  import org.apache.lucene.index.Term;
>  import org.apache.lucene.search.TermQuery;
>  import org.apache.lucene.util.BytesRef;
> @@ -29,20 +31,32 @@ public class KNearestNeighborClassifierT
>
>    @Test
>    public void testBasicUsage() throws Exception {
> -    // usage with default MLT min docs / term freq
> -    checkCorrectClassification(new KNearestNeighborClassifier(3),
> POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName,
> categoryFieldName);
> -    // usage without custom min docs / term freq for MLT
> -    checkCorrectClassification(new KNearestNeighborClassifier(3, 2, 1),
> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()),
> textFieldName, categoryFieldName);
> +    LeafReader leafReader = null;
> +    try {
> +      MockAnalyzer analyzer = new MockAnalyzer(random());
> +      leafReader = populateSampleIndex(analyzer);
> +      checkCorrectClassification(new
> KNearestNeighborClassifier(leafReader, analyzer, null, 1, 0, 0,
> categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
> +      checkCorrectClassification(new
> KNearestNeighborClassifier(leafReader, analyzer, null, 3, 2, 1,
> categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
> +    } finally {
> +      if (leafReader != null) {
> +        leafReader.close();
> +      }
> +    }
>    }
>
>    @Test
>    public void testBasicUsageWithQuery() throws Exception {
> -    checkCorrectClassification(new KNearestNeighborClassifier(1),
> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()),
> textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName,
> "it")));
> -  }
> -
> -  @Test
> -  public void testPerformance() throws Exception {
> -    checkPerformance(new KNearestNeighborClassifier(100), new
> MockAnalyzer(random()), categoryFieldName);
> +    LeafReader leafReader = null;
> +    try {
> +      MockAnalyzer analyzer = new MockAnalyzer(random());
> +      leafReader = populateSampleIndex(analyzer);
> +      TermQuery query = new TermQuery(new Term(textFieldName, "it"));
> +      checkCorrectClassification(new
> KNearestNeighborClassifier(leafReader, analyzer, query, 1, 0, 0,
> categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
> +    } finally {
> +      if (leafReader != null) {
> +        leafReader.close();
> +      }
> +    }
>    }
>
>  }
>
> Modified:
> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
> URL:
> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>
> ==============================================================================
> ---
> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
> (original)
> +++
> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
> Thu Apr 30 14:12:03 2015
> @@ -22,14 +22,13 @@ import org.apache.lucene.analysis.Tokeni
>  import org.apache.lucene.analysis.core.KeywordTokenizer;
>  import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
>  import org.apache.lucene.analysis.reverse.ReverseStringFilter;
> +import org.apache.lucene.index.LeafReader;
> +import org.apache.lucene.index.SlowCompositeReaderWrapper;
>  import org.apache.lucene.index.Term;
>  import org.apache.lucene.search.TermQuery;
>  import org.apache.lucene.util.BytesRef;
> -import org.apache.lucene.util.LuceneTestCase;
>  import org.junit.Test;
>
> -import java.io.Reader;
> -
>  /**
>   * Testcase for {@link SimpleNaiveBayesClassifier}
>   */
> @@ -37,18 +36,46 @@ public class SimpleNaiveBayesClassifierT
>
>    @Test
>    public void testBasicUsage() throws Exception {
> -    checkCorrectClassification(new SimpleNaiveBayesClassifier(),
> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()),
> textFieldName, categoryFieldName);
> -    checkCorrectClassification(new SimpleNaiveBayesClassifier(),
> POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName,
> categoryFieldName);
> +    LeafReader leafReader = null;
> +    try {
> +      MockAnalyzer analyzer = new MockAnalyzer(random());
> +      leafReader = populateSampleIndex(analyzer);
> +      checkCorrectClassification(new
> SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName,
> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
> +      checkCorrectClassification(new
> SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName,
> textFieldName), POLITICS_INPUT, POLITICS_RESULT);
> +    } finally {
> +      if (leafReader != null) {
> +        leafReader.close();
> +      }
> +    }
>    }
>
>    @Test
>    public void testBasicUsageWithQuery() throws Exception {
> -    checkCorrectClassification(new SimpleNaiveBayesClassifier(),
> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()),
> textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName,
> "it")));
> +    LeafReader leafReader = null;
> +    try {
> +      MockAnalyzer analyzer = new MockAnalyzer(random());
> +      leafReader = populateSampleIndex(analyzer);
> +      TermQuery query = new TermQuery(new Term(textFieldName, "it"));
> +      checkCorrectClassification(new
> SimpleNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName,
> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
> +    } finally {
> +      if (leafReader != null) {
> +        leafReader.close();
> +      }
> +    }
>    }
>
>    @Test
>    public void testNGramUsage() throws Exception {
> -    checkCorrectClassification(new SimpleNaiveBayesClassifier(),
> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName,
> categoryFieldName);
> +    LeafReader leafReader = null;
> +    try {
> +      Analyzer analyzer = new NGramAnalyzer();
> +      leafReader = populateSampleIndex(analyzer);
> +      checkCorrectClassification(new
> CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName,
> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
> +    } finally {
> +      if (leafReader != null) {
> +        leafReader.close();
> +      }
> +    }
>    }
>
>    private class NGramAnalyzer extends Analyzer {
> @@ -59,9 +86,4 @@ public class SimpleNaiveBayesClassifierT
>      }
>    }
>
> -  @Test
> -  public void testPerformance() throws Exception {
> -    checkPerformance(new SimpleNaiveBayesClassifier(), new
> MockAnalyzer(random()), categoryFieldName);
> -  }
> -
>  }
>
>
>

Reply via email to