Author: tommaso
Date: Tue Aug 11 15:58:53 2015
New Revision: 1695334
URL: http://svn.apache.org/r1695334
Log:
OPENNLP-777 - naive bayes classifier (patch from Cohan Sujay Carlos)
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/doccat/DocumentCategorizerNB.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/BinaryNaiveBayesModelReader.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/BinaryNaiveBayesModelWriter.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/LogProbabilities.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/LogProbability.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesEvalParameters.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModel.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelReader.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelWriter.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesTrainer.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/PlainTextNaiveBayesModelReader.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/PlainTextNaiveBayesModelWriter.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/Probabilities.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/Probability.java
opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/doccat/DocumentCategorizerNBTest.java
opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/
opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesCorrectnessTest.java
opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesPrepAttachTest.java
Modified:
opennlp/trunk/opennlp-tools/pom.xml
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/AbstractModel.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/GenericModelReader.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/GenericModelWriter.java
Modified: opennlp/trunk/opennlp-tools/pom.xml
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/pom.xml?rev=1695334&r1=1695333&r2=1695334&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/pom.xml (original)
+++ opennlp/trunk/opennlp-tools/pom.xml Tue Aug 11 15:58:53 2015
@@ -155,6 +155,7 @@
<exclude>src/test/resources/data/ppa/devset</exclude>
<exclude>src/test/resources/data/ppa/test</exclude>
<exclude>src/test/resources/data/ppa/training</exclude>
+ <exclude>src/test/resources/data/ppa/NOTICE</exclude>
<exclude>src/test/resources/opennlp/tools/doccat/DoccatSample.txt</exclude>
<exclude>src/test/resources/opennlp/tools/formats/brat/voa-with-entities.ann</exclude>
<exclude>src/test/resources/opennlp/tools/formats/brat/voa-with-entities.txt</exclude>
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/doccat/DocumentCategorizerNB.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/doccat/DocumentCategorizerNB.java?rev=1695334&view=auto
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/doccat/DocumentCategorizerNB.java
(added)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/doccat/DocumentCategorizerNB.java
Tue Aug 11 15:58:53 2015
@@ -0,0 +1,250 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.doccat;
+
+import java.io.IOException;
+import java.io.ObjectStreamException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import opennlp.tools.ml.AbstractTrainer;
+import opennlp.tools.ml.EventTrainer;
+import opennlp.tools.ml.TrainerFactory;
+import opennlp.tools.ml.model.MaxentModel;
+import opennlp.tools.ml.naivebayes.NaiveBayesModel;
+import opennlp.tools.ml.naivebayes.NaiveBayesTrainer;
+import opennlp.tools.tokenize.SimpleTokenizer;
+import opennlp.tools.tokenize.Tokenizer;
+import opennlp.tools.util.ObjectStream;
+import opennlp.tools.util.TrainingParameters;
+import opennlp.tools.util.model.ModelUtil;
+
+/**
+ * Naive Bayes implementation of {@link DocumentCategorizer}.
+ */
+public class DocumentCategorizerNB implements DocumentCategorizer {
+
+ /**
+ * Shared default thread safe feature generator.
+ */
+ private static FeatureGenerator defaultFeatureGenerator = new
BagOfWordsFeatureGenerator();
+
+ private DoccatModel model;
+ private DocumentCategorizerContextGenerator mContextGenerator;
+
+ /**
+   * Initializes the current instance with a doccat model and custom feature
+ * generation. The feature generation must be identical to the configuration
+ * at training time.
+ *
+ * @param model
+ * @param featureGenerators
+ * @deprecated train a {@link DoccatModel} with a specific
+ * {@link DoccatFactory} to customize the {@link FeatureGenerator}s
+ */
+ public DocumentCategorizerNB(DoccatModel model, FeatureGenerator...
featureGenerators) {
+ this.model = model;
+ this.mContextGenerator = new
DocumentCategorizerContextGenerator(featureGenerators);
+ }
+
+ /**
+ * Initializes the current instance with a doccat model. Default feature
+ * generation is used.
+ *
+ * @param model
+ */
+ public DocumentCategorizerNB(DoccatModel model) {
+ this.model = model;
+ this.mContextGenerator = new DocumentCategorizerContextGenerator(this.model
+ .getFactory().getFeatureGenerators());
+ }
+
+ @Override
+ public double[] categorize(String[] text, Map<String, Object>
extraInformation) {
+ return model.getMaxentModel().eval(
+ mContextGenerator.getContext(text, extraInformation));
+ }
+
+ /**
+ * Categorizes the given text.
+ *
+ * @param text
+ */
+ public double[] categorize(String text[]) {
+ return this.categorize(text, Collections.<String, Object>emptyMap());
+ }
+
+ /**
+ * Categorizes the given text. The Tokenizer is obtained from
+ * {@link DoccatFactory#getTokenizer()} and defaults to
+ * {@link SimpleTokenizer}.
+ */
+ @Override
+ public double[] categorize(String documentText,
+ Map<String, Object> extraInformation) {
+ Tokenizer tokenizer = model.getFactory().getTokenizer();
+ return categorize(tokenizer.tokenize(documentText), extraInformation);
+ }
+
+ /**
+ * Categorizes the given text. The text is tokenized with the SimpleTokenizer
+ * before it is passed to the feature generation.
+ */
+ public double[] categorize(String documentText) {
+ Tokenizer tokenizer = model.getFactory().getTokenizer();
+ return categorize(tokenizer.tokenize(documentText),
+ Collections.<String, Object>emptyMap());
+ }
+
+ /**
+ * Returns a map in which the key is the category name and the value is the
score
+ *
+ * @param text the input text to classify
+   * @return a map in which each category name maps to its score
+ */
+ public Map<String, Double> scoreMap(String text) {
+ Map<String, Double> probDist = new HashMap<String, Double>();
+
+ double[] categorize = categorize(text);
+ int catSize = getNumberOfCategories();
+ for (int i = 0; i < catSize; i++) {
+ String category = getCategory(i);
+ probDist.put(category, categorize[getIndex(category)]);
+ }
+ return probDist;
+
+ }
+
+ /**
+   * Returns a map with the score as a key in ascending order. The value is a
Set of categories with the score.
+ * Many categories can have the same score, hence the Set as value
+ *
+ * @param text the input text to classify
+   * @return a sorted map from score to the set of categories having that score
+ */
+ public SortedMap<Double, Set<String>> sortedScoreMap(String text) {
+ SortedMap<Double, Set<String>> descendingMap = new TreeMap<Double,
Set<String>>();
+ double[] categorize = categorize(text);
+ int catSize = getNumberOfCategories();
+ for (int i = 0; i < catSize; i++) {
+ String category = getCategory(i);
+ double score = categorize[getIndex(category)];
+ if (descendingMap.containsKey(score)) {
+ descendingMap.get(score).add(category);
+ } else {
+ Set<String> newset = new HashSet<String>();
+ newset.add(category);
+ descendingMap.put(score, newset);
+ }
+ }
+ return descendingMap;
+ }
+
+ public String getBestCategory(double[] outcome) {
+ return model.getMaxentModel().getBestOutcome(outcome);
+ }
+
+ public int getIndex(String category) {
+ return model.getMaxentModel().getIndex(category);
+ }
+
+ public String getCategory(int index) {
+ return model.getMaxentModel().getOutcome(index);
+ }
+
+ public int getNumberOfCategories() {
+ return model.getMaxentModel().getNumOutcomes();
+ }
+
+ public String getAllResults(double results[]) {
+ return model.getMaxentModel().getAllOutcomes(results);
+ }
+
+ /**
+ * @deprecated Use
+ * {@link #train(String, ObjectStream, TrainingParameters, DoccatFactory)}
+ * instead.
+ */
+ public static DoccatModel train(String languageCode,
ObjectStream<DocumentSample> samples,
+ TrainingParameters mlParams,
FeatureGenerator... featureGenerators)
+ throws IOException {
+
+ if (featureGenerators.length == 0) {
+ featureGenerators = new FeatureGenerator[]{defaultFeatureGenerator};
+ }
+
+ Map<String, String> manifestInfoEntries = new HashMap<String, String>();
+
+ mlParams.put(AbstractTrainer.ALGORITHM_PARAM,
NaiveBayesTrainer.NAIVE_BAYES_VALUE);
+
+ NaiveBayesModel nbModel = getTrainedInnerModel(samples, mlParams,
manifestInfoEntries, featureGenerators);
+
+ return new DoccatModel(languageCode, nbModel, manifestInfoEntries);
+ }
+
+ public static DoccatModel train(String languageCode,
ObjectStream<DocumentSample> samples,
+ TrainingParameters mlParams, DoccatFactory
factory)
+ throws IOException {
+
+ Map<String, String> manifestInfoEntries = new HashMap<String, String>();
+
+ mlParams.put(AbstractTrainer.ALGORITHM_PARAM,
NaiveBayesTrainer.NAIVE_BAYES_VALUE);
+
+ NaiveBayesModel nbModel = getTrainedInnerModel(samples, mlParams,
manifestInfoEntries, factory.getFeatureGenerators());
+
+ return new DoccatModel(languageCode, nbModel, manifestInfoEntries,
factory);
+ }
+
+ protected static NaiveBayesModel getTrainedInnerModel(
+ ObjectStream<DocumentSample> samples, TrainingParameters mlParams,
+ Map<String, String> manifestInfoEntries,
+ FeatureGenerator... featureGenerators) throws IOException {
+ if (!TrainerFactory.isSupportEvent(mlParams.getSettings())) {
+ throw new IllegalArgumentException("EventTrain is not supported");
+ }
+ EventTrainer trainer =
TrainerFactory.getEventTrainer(mlParams.getSettings(), manifestInfoEntries);
+ MaxentModel model = trainer.train(new
DocumentCategorizerEventStream(samples, featureGenerators));
+
+ NaiveBayesModel nbModel = null;
+ if (model instanceof NaiveBayesModel) {
+ nbModel = (NaiveBayesModel) model;
+ }
+ return nbModel;
+ }
+
+ /**
+ * Trains a doccat model with default feature generation.
+ *
+ * @param languageCode
+ * @param samples
+ * @return the trained doccat model
+ * @throws IOException
+ * @throws ObjectStreamException
+ * @deprecated Use
+ * {@link #train(String, ObjectStream, TrainingParameters, DoccatFactory)}
+ * instead.
+ */
+ public static DoccatModel train(String languageCode,
ObjectStream<DocumentSample> samples) throws IOException {
+ return train(languageCode, samples,
ModelUtil.createDefaultTrainingParameters(), defaultFeatureGenerator);
+ }
+}
Modified:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java?rev=1695334&r1=1695333&r2=1695334&view=diff
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
(original)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
Tue Aug 11 15:58:53 2015
@@ -24,6 +24,7 @@ import java.util.Map;
import opennlp.tools.ml.maxent.GIS;
import opennlp.tools.ml.maxent.quasinewton.QNTrainer;
+import opennlp.tools.ml.naivebayes.NaiveBayesTrainer;
import opennlp.tools.ml.perceptron.PerceptronTrainer;
import opennlp.tools.ml.perceptron.SimplePerceptronSequenceTrainer;
import opennlp.tools.util.ext.ExtensionLoader;
@@ -47,6 +48,7 @@ public class TrainerFactory {
_trainers.put(PerceptronTrainer.PERCEPTRON_VALUE, PerceptronTrainer.class);
_trainers.put(SimplePerceptronSequenceTrainer.PERCEPTRON_SEQUENCE_VALUE,
SimplePerceptronSequenceTrainer.class);
+ _trainers.put(NaiveBayesTrainer.NAIVE_BAYES_VALUE,
NaiveBayesTrainer.class);
BUILTIN_TRAINERS = Collections.unmodifiableMap(_trainers);
}
Modified:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/AbstractModel.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/AbstractModel.java?rev=1695334&r1=1695333&r2=1695334&view=diff
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/AbstractModel.java
(original)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/AbstractModel.java
Tue Aug 11 15:58:53 2015
@@ -32,7 +32,7 @@ public abstract class AbstractModel impl
/** Prior distribution for this model. */
protected Prior prior;
- public enum ModelType {Maxent,Perceptron,MaxentQn};
+ public enum ModelType {Maxent,Perceptron,MaxentQn,NaiveBayes};
/** The type of the model. */
protected ModelType modelType;
@@ -165,4 +165,4 @@ public abstract class AbstractModel impl
data[4] = evalParams.getCorrectionParam();
return data;
}
-}
+}
\ No newline at end of file
Modified:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/GenericModelReader.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/GenericModelReader.java?rev=1695334&r1=1695333&r2=1695334&view=diff
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/GenericModelReader.java
(original)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/GenericModelReader.java
Tue Aug 11 15:58:53 2015
@@ -24,6 +24,7 @@ import java.io.IOException;
import opennlp.tools.ml.maxent.io.GISModelReader;
import opennlp.tools.ml.maxent.io.QNModelReader;
+import opennlp.tools.ml.naivebayes.NaiveBayesModelReader;
import opennlp.tools.ml.perceptron.PerceptronModelReader;
public class GenericModelReader extends AbstractModelReader {
@@ -49,6 +50,9 @@ public class GenericModelReader extends
else if (modelType.equals("QN")) {
delegateModelReader = new QNModelReader(this.dataReader);
}
+ else if (modelType.equals("NaiveBayes")) {
+ delegateModelReader = new NaiveBayesModelReader(this.dataReader);
+ }
else {
throw new IOException("Unknown model format: "+modelType);
}
@@ -63,4 +67,4 @@ public class GenericModelReader extends
AbstractModel m = new GenericModelReader(new File(args[0])).getModel();
new GenericModelWriter( m, new File(args[1])).persist();
}
-}
+}
\ No newline at end of file
Modified:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/GenericModelWriter.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/GenericModelWriter.java?rev=1695334&r1=1695333&r2=1695334&view=diff
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/GenericModelWriter.java
(original)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/GenericModelWriter.java
Tue Aug 11 15:58:53 2015
@@ -32,6 +32,8 @@ import opennlp.tools.ml.maxent.io.Binary
import opennlp.tools.ml.maxent.io.BinaryQNModelWriter;
import opennlp.tools.ml.maxent.io.PlainTextGISModelWriter;
import opennlp.tools.ml.model.AbstractModel.ModelType;
+import opennlp.tools.ml.naivebayes.BinaryNaiveBayesModelWriter;
+import opennlp.tools.ml.naivebayes.PlainTextNaiveBayesModelWriter;
import opennlp.tools.ml.perceptron.BinaryPerceptronModelWriter;
import opennlp.tools.ml.perceptron.PlainTextPerceptronModelWriter;
@@ -45,43 +47,44 @@ public class GenericModelWriter extends
// handle the zipped/not zipped distinction
if (filename.endsWith(".gz")) {
os = new GZIPOutputStream(new FileOutputStream(file));
- filename = filename.substring(0,filename.length()-3);
- }
- else {
+ filename = filename.substring(0, filename.length() - 3);
+ } else {
os = new FileOutputStream(file);
}
// handle the different formats
if (filename.endsWith(".bin")) {
- init(model,new DataOutputStream(os));
- }
- else { // filename ends with ".txt"
- init(model,new BufferedWriter(new OutputStreamWriter(os)));
+ init(model, new DataOutputStream(os));
+ } else { // filename ends with ".txt"
+ init(model, new BufferedWriter(new OutputStreamWriter(os)));
}
}
public GenericModelWriter(AbstractModel model, DataOutputStream dos) {
- init(model,dos);
+ init(model, dos);
}
private void init(AbstractModel model, DataOutputStream dos) {
if (model.getModelType() == ModelType.Perceptron) {
- delegateWriter = new BinaryPerceptronModelWriter(model,dos);
- }
- else if (model.getModelType() == ModelType.Maxent) {
- delegateWriter = new BinaryGISModelWriter(model,dos);
+ delegateWriter = new BinaryPerceptronModelWriter(model, dos);
+ } else if (model.getModelType() == ModelType.Maxent) {
+ delegateWriter = new BinaryGISModelWriter(model, dos);
+ } else if (model.getModelType() == ModelType.MaxentQn) {
+ delegateWriter = new BinaryQNModelWriter(model, dos);
}
- else if (model.getModelType() == ModelType.MaxentQn) {
- delegateWriter = new BinaryQNModelWriter(model,dos);
+ if (model.getModelType() == ModelType.NaiveBayes) {
+ delegateWriter = new BinaryNaiveBayesModelWriter(model, dos);
}
}
private void init(AbstractModel model, BufferedWriter bw) {
if (model.getModelType() == ModelType.Perceptron) {
- delegateWriter = new PlainTextPerceptronModelWriter(model,bw);
+ delegateWriter = new PlainTextPerceptronModelWriter(model, bw);
+ } else if (model.getModelType() == ModelType.Maxent) {
+ delegateWriter = new PlainTextGISModelWriter(model, bw);
}
- else if (model.getModelType() == ModelType.Maxent) {
- delegateWriter = new PlainTextGISModelWriter(model,bw);
+ if (model.getModelType() == ModelType.NaiveBayes) {
+ delegateWriter = new PlainTextNaiveBayesModelWriter(model, bw);
}
}
@@ -109,4 +112,4 @@ public class GenericModelWriter extends
public void writeUTF(String s) throws IOException {
delegateWriter.writeUTF(s);
}
-}
+}
\ No newline at end of file
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/BinaryNaiveBayesModelReader.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/BinaryNaiveBayesModelReader.java?rev=1695334&view=auto
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/BinaryNaiveBayesModelReader.java
(added)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/BinaryNaiveBayesModelReader.java
Tue Aug 11 15:58:53 2015
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package opennlp.tools.ml.naivebayes;
+
+import java.io.DataInputStream;
+import java.io.File;
+import java.io.IOException;
+
+import opennlp.tools.ml.model.BinaryFileDataReader;
+
+public class BinaryNaiveBayesModelReader extends NaiveBayesModelReader {
+
+
+ /**
+ * Constructor which directly instantiates the DataInputStream containing
+ * the model contents.
+ *
+ * @param dis The DataInputStream containing the model information.
+ */
+ public BinaryNaiveBayesModelReader(DataInputStream dis) {
+ super(new BinaryFileDataReader(dis));
+ }
+
+ /**
+ * Constructor which takes a File and creates a reader for it. Detects
+ * whether the file is gzipped or not based on whether the suffix contains
+ * ".gz"
+ *
+ * @param f The File in which the model is stored.
+ */
+ public BinaryNaiveBayesModelReader(File f) throws IOException {
+ super(f);
+ }
+}
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/BinaryNaiveBayesModelWriter.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/BinaryNaiveBayesModelWriter.java?rev=1695334&view=auto
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/BinaryNaiveBayesModelWriter.java
(added)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/BinaryNaiveBayesModelWriter.java
Tue Aug 11 15:58:53 2015
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package opennlp.tools.ml.naivebayes;
+
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.zip.GZIPOutputStream;
+
+import opennlp.tools.ml.model.AbstractModel;
+
+/**
+ * Model writer that saves models in binary format.
+ */
+public class BinaryNaiveBayesModelWriter extends NaiveBayesModelWriter {
+ DataOutputStream output;
+
+ /**
+ * Constructor which takes a NaiveBayesModel and a File and prepares itself
to
+ * write the model to that file. Detects whether the file is gzipped or not
+ * based on whether the suffix contains ".gz".
+ *
+ * @param model The NaiveBayesModel which is to be persisted.
+ * @param f The File in which the model is to be persisted.
+ */
+ public BinaryNaiveBayesModelWriter(AbstractModel model, File f) throws
IOException {
+
+ super(model);
+
+ if (f.getName().endsWith(".gz")) {
+ output = new DataOutputStream(
+ new GZIPOutputStream(new FileOutputStream(f)));
+ } else {
+ output = new DataOutputStream(new FileOutputStream(f));
+ }
+ }
+
+ /**
+ * Constructor which takes a NaiveBayesModel and a DataOutputStream and
prepares
+ * itself to write the model to that stream.
+ *
+ * @param model The NaiveBayesModel which is to be persisted.
+ * @param dos The stream which will be used to persist the model.
+ */
+ public BinaryNaiveBayesModelWriter(AbstractModel model, DataOutputStream
dos) {
+ super(model);
+ output = dos;
+ }
+
+ public void writeUTF(String s) throws java.io.IOException {
+ output.writeUTF(s);
+ }
+
+ public void writeInt(int i) throws java.io.IOException {
+ output.writeInt(i);
+ }
+
+ public void writeDouble(double d) throws java.io.IOException {
+ output.writeDouble(d);
+ }
+
+ public void close() throws java.io.IOException {
+ output.flush();
+ output.close();
+ }
+
+}
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/LogProbabilities.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/LogProbabilities.java?rev=1695334&view=auto
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/LogProbabilities.java
(added)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/LogProbabilities.java
Tue Aug 11 15:58:53 2015
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.ml.naivebayes;
+
+import java.util.ArrayList;
+import java.util.Map;
+
+/**
+ * Class implementing the probability distribution over labels returned by a
classifier as a log of probabilities.
+ * This is necessary because floating point precision in Java does not allow
for high-accuracy representation of very low probabilities
+ * such as would occur in a text categorizer.
+ *
+ * @param <T> the label (category) class
+ *
+ */
+public class LogProbabilities<T> extends Probabilities<T> {
+
+ /**
+ * Assigns a probability to a label, discarding any previously assigned
probability.
+ *
+ * @param t the label to which the probability is being assigned
+ * @param probability the probability to assign
+ */
+ public void set(T t, double probability) {
+ isNormalised = false;
+ map.put(t, log(probability));
+ }
+
+ /**
+ * Assigns a probability to a label, discarding any previously assigned
probability.
+ *
+ * @param t the label to which the probability is being assigned
+ * @param probability the probability to assign
+ */
+ public void set(T t, Probability<T> probability) {
+ isNormalised = false;
+ map.put(t, probability.getLog());
+ }
+
+ /**
+ * Assigns a probability to a label, discarding any previously assigned
probability, if the new probability is greater than the old one.
+ *
+ * @param t the label to which the probability is being assigned
+ * @param probability the probability to assign
+ */
+ public void setIfLarger(T t, double probability) {
+ double logProbability = log(probability);
+ Double p = map.get(t);
+ if (p == null || logProbability > p) {
+ isNormalised = false;
+ map.put(t, logProbability);
+ }
+ }
+
+ /**
+ * Assigns a log probability to a label, discarding any previously assigned
probability.
+ *
+ * @param t the label to which the log probability is being
assigned
+ * @param probability the log probability to assign
+ */
+ public void setLog(T t, double probability) {
+ isNormalised = false;
+ map.put(t, probability);
+ }
+
+ /**
+ * Compounds the existing probability mass on the label with the new
probability passed in to the method.
+ *
+ * @param t the label whose probability mass is being updated
+ * @param probability the probability weight to add
+ * @param count the amplifying factor for the probability compounding
+ */
+ public void addIn(T t, double probability, int count) {
+ isNormalised = false;
+ Double p = map.get(t);
+ if (p == null)
+ p = 0.0;
+ probability = log(probability) * count;
+ map.put(t, p + probability);
+ }
+
+ private Map<T, Double> normalize() {
+ if (isNormalised)
+ return normalised;
+ Map<T, Double> temp = createMapDataStructure();
+ double highestLogProbability = Double.NEGATIVE_INFINITY;
+ for (T t : map.keySet()) {
+ Double p = map.get(t);
+ if (p != null && p > highestLogProbability) {
+ highestLogProbability = p;
+ }
+ }
+ double sum = 0;
+ for (T t : map.keySet()) {
+ Double p = map.get(t);
+ if (p != null) {
+ double temp_p = Math.exp(p - highestLogProbability);
+ if (!Double.isNaN(temp_p)) {
+ sum += temp_p;
+ temp.put(t, temp_p);
+ }
+ }
+ }
+ for (T t : temp.keySet()) {
+ Double p = temp.get(t);
+ if (p != null && sum > Double.MIN_VALUE) {
+ temp.put(t, p / sum);
+ }
+ }
+ normalised = temp;
+ isNormalised = true;
+ return temp;
+ }
+
+ private double log(double prob) {
+ return Math.log(prob);
+ }
+
+ /**
+ * Returns the probability associated with a label
+ *
+ * @param t the label whose probability needs to be returned
+ * @return the probability associated with the label
+ */
+ public Double get(T t) {
+ Double d = normalize().get(t);
+ if (d == null)
+ return 0.0;
+ return d;
+ }
+
+ /**
+ * Returns the log probability associated with a label
+ *
+ * @param t the label whose log probability needs to be returned
+ * @return the log probability associated with the label
+ */
+ public Double getLog(T t) {
+ Double d = map.get(t);
+ if (d == null)
+ return Double.NEGATIVE_INFINITY;
+ return d;
+ }
+
+ public void discardCountsBelow(double i) {
+ i = Math.log(i);
+ ArrayList<T> labelsToRemove = new ArrayList<T>();
+ for (T label : map.keySet()) {
+ Double sum = map.get(label);
+ if (sum == null) sum = Double.NEGATIVE_INFINITY;
+ if (sum < i)
+ labelsToRemove.add(label);
+ }
+ for (T label : labelsToRemove) {
+ map.remove(label);
+ }
+ }
+
+ /**
+ * Returns the probabilities associated with all labels
+ *
+ * @return the HashMap of labels and their probabilities
+ */
+ public Map<T, Double> getAll() {
+ return normalize();
+ }
+
+ /**
+ * Returns the most likely label
+ *
+ * @return the label that has the highest associated probability
+ */
+ public T getMax() {
+ double max = Double.NEGATIVE_INFINITY;
+ T maxT = null;
+ for (T t : map.keySet()) {
+ Double temp = map.get(t);
+ if (temp >= max) {
+ max = temp;
+ maxT = t;
+ }
+ }
+ return maxT;
+ }
+}
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/LogProbability.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/LogProbability.java?rev=1695334&view=auto
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/LogProbability.java
(added)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/LogProbability.java
Tue Aug 11 15:58:53 2015
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.ml.naivebayes;
+
+/**
+ * Class implementing the probability for a label.
+ *
+ * @param <T> the label (category) class
+ *
+ */
+public class LogProbability<T> extends Probability<T> {
+
+ public LogProbability(T label) {
+ super(label);
+ set(1.0);
+ }
+
+ /**
+ * Assigns a probability to a label, discarding any previously assigned
probability.
+ *
+ * @param probability the probability to assign
+ */
+ public void set(double probability) {
+ this.probability = Math.log(probability);
+ }
+
+ /**
+ * Assigns a probability to a label, discarding any previously assigned
probability.
+ *
+ * @param probability the probability to assign
+ */
+ public void set(Probability probability) {
+ this.probability = probability.getLog();
+ }
+
+ /**
+ * Assigns a probability to a label, discarding any previously assigned
probability, if the new probability is greater than the old one.
+ *
+ * @param probability the probability to assign
+ */
+ public void setIfLarger(double probability) {
+ double logP = Math.log(probability);
+ if (this.probability < logP) {
+ this.probability = logP;
+ }
+ }
+
+ /**
+ * Assigns a probability to a label, discarding any previously assigned
probability, if the new probability is greater than the old one.
+ *
+ * @param probability the probability to assign
+ */
+ public void setIfLarger(Probability probability) {
+ if (this.probability < probability.getLog()) {
+ this.probability = probability.getLog();
+ }
+ }
+
+ /**
+ * Checks if a probability is greater than the old one.
+ *
+ * @param probability the probability to assign
+ */
+ public boolean isLarger(Probability probability) {
+ return this.probability < probability.getLog();
+ }
+
+ /**
+ * Assigns a log probability to a label, discarding any previously assigned
probability.
+ *
+ * @param probability the log probability to assign
+ */
+ public void setLog(double probability) {
+ this.probability = probability;
+ }
+
+ /**
+ * Compounds the existing probability mass on the label with the new
probability passed in to the method.
+ *
+ * @param probability the probability weight to add
+ */
+ public void addIn(double probability) {
+ setLog(this.probability + Math.log(probability));
+ }
+
+ /**
+ * Returns the probability associated with a label
+ *
+ * @return the probability associated with the label
+ */
+ public Double get() {
+ return Math.exp(probability);
+ }
+
+ /**
+ * Returns the log probability associated with a label
+ *
+ * @return the log probability associated with the label
+ */
+ public Double getLog() {
+ return probability;
+ }
+
+ /**
+ * Returns the probabilities associated with all labels
+ *
+ * @return the HashMap of labels and their probabilities
+ */
+ public T getLabel() {
+ return label;
+ }
+
+ public String toString() {
+ return label.toString() + ":" + probability;
+ }
+}
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesEvalParameters.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesEvalParameters.java?rev=1695334&view=auto
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesEvalParameters.java
(added)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesEvalParameters.java
Tue Aug 11 15:58:53 2015
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package opennlp.tools.ml.naivebayes;
+
+import opennlp.tools.ml.model.Context;
+import opennlp.tools.ml.model.EvalParameters;
+
/**
 * {@link EvalParameters} variant carrying the extra data the naive Bayes
 * evaluation needs: per-outcome count totals and the feature-vocabulary size.
 */
public class NaiveBayesEvalParameters extends EvalParameters {

  // Total accumulated count per outcome index.
  protected double[] outcomeTotals;
  // Number of distinct predicates (features); used as the smoothing vocabulary.
  protected long vocabulary;

  /**
   * Creates evaluation parameters for a naive Bayes model.
   *
   * @param params        per-predicate contexts
   * @param numOutcomes   number of outcomes
   * @param outcomeTotals total count per outcome index
   * @param vocabulary    number of distinct predicates
   */
  public NaiveBayesEvalParameters(Context[] params, int numOutcomes, double[] outcomeTotals, long vocabulary) {
    super(params, 0, 0, numOutcomes); // correction constant/param are zero: unused by naive Bayes
    this.outcomeTotals = outcomeTotals;
    this.vocabulary = vocabulary;
  }

  /** @return the total count per outcome index (the internal array, not a copy) */
  public double[] getOutcomeTotals() {
    return outcomeTotals;
  }

  /** @return the number of distinct predicates (features) */
  public long getVocabulary() {
    return vocabulary;
  }

}
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModel.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModel.java?rev=1695334&view=auto
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModel.java
(added)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModel.java
Tue Aug 11 15:58:53 2015
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package opennlp.tools.ml.naivebayes;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.InputStreamReader;
+import java.text.DecimalFormat;
+import java.util.Map;
+
+import opennlp.tools.ml.model.AbstractModel;
+import opennlp.tools.ml.model.Context;
+import opennlp.tools.ml.model.EvalParameters;
+import opennlp.tools.ml.model.IndexHashTable;
+
/**
 * Class implementing the multinomial Naive Bayes classifier model.
 */
public class NaiveBayesModel extends AbstractModel {

  // Total accumulated count per outcome index; denominator of the likelihood terms.
  protected double[] outcomeTotals;
  // Number of distinct predicates; used as the vocabulary size when smoothing.
  protected long vocabulary;
  // NOTE(review): a static mutable flag shared by ALL model instances.
  private static boolean isSmoothed = true; // Turn this off only for testing/validation

  public NaiveBayesModel(Context[] params, String[] predLabels, IndexHashTable<String> pmap, String[] outcomeNames) {
    super(params, predLabels, pmap, outcomeNames);
    outcomeTotals = initOutcomeTotals(outcomeNames, params);
    this.evalParams = new NaiveBayesEvalParameters(params, outcomeNames.length, outcomeTotals, predLabels.length);
    modelType = ModelType.NaiveBayes;
  }

  /**
   * @deprecated use the constructor with the {@link IndexHashTable} instead!
   */
  @Deprecated
  public NaiveBayesModel(Context[] params, String[] predLabels, Map<String, Integer> pmap, String[] outcomeNames) {
    // NOTE(review): pmap is accepted but never used here — the two-argument
    // super constructor is called instead; confirm this is intentional.
    super(params, predLabels, outcomeNames);
    outcomeTotals = initOutcomeTotals(outcomeNames, params);
    this.evalParams = new NaiveBayesEvalParameters(params, outcomeNames.length, outcomeTotals, predLabels.length);
    modelType = ModelType.NaiveBayes;
  }

  public NaiveBayesModel(Context[] params, String[] predLabels, String[] outcomeNames) {
    super(params, predLabels, outcomeNames);
    outcomeTotals = initOutcomeTotals(outcomeNames, params);
    this.evalParams = new NaiveBayesEvalParameters(params, outcomeNames.length, outcomeTotals, predLabels.length);
    modelType = ModelType.NaiveBayes;
  }

  /**
   * Sums, over all predicates, the count each predicate contributes to
   * every outcome.
   *
   * @param outcomeNames outcome labels (only the length is used)
   * @param params       per-predicate contexts holding outcome/count pairs
   * @return per-outcome totals indexed by outcome id
   */
  protected double[] initOutcomeTotals(String[] outcomeNames, Context[] params) {
    double[] outcomeTotals = new double[outcomeNames.length];
    for (int i = 0; i < params.length; ++i) {
      Context context = params[i];
      for (int j = 0; j < context.getOutcomes().length; ++j) {
        int outcome = context.getOutcomes()[j];
        double count = context.getParameters()[j];
        outcomeTotals[outcome] += count;
      }
    }
    return outcomeTotals;
  }

  /** Evaluates a context of string predicates with unit values. */
  public double[] eval(String[] context) {
    return eval(context, new double[evalParams.getNumOutcomes()]);
  }

  /** Evaluates a context of string predicates with the given per-predicate values. */
  public double[] eval(String[] context, float[] values) {
    return eval(context, values, new double[evalParams.getNumOutcomes()]);
  }

  /** Evaluates a context, writing the distribution into the supplied array. */
  public double[] eval(String[] context, double[] probs) {
    return eval(context, null, probs);
  }

  /**
   * Maps each string predicate to its integer index (-1 if unknown) and
   * delegates to the static integer-based evaluation.
   *
   * @param context predicate names active in this event
   * @param values  optional per-predicate weights; assumed 1 when null
   * @param outsums output array; receives the outcome distribution
   * @return the outsums array filled with outcome probabilities
   */
  public double[] eval(String[] context, float[] values, double[] outsums) {
    int[] scontexts = new int[context.length];
    java.util.Arrays.fill(outsums, 0);
    for (int i = 0; i < context.length; i++) {
      Integer ci = pmap.get(context[i]);
      scontexts[i] = ci == null ? -1 : ci;
    }
    return eval(scontexts, values, outsums, evalParams, true);
  }

  /** Static evaluation entry point with unit values. */
  public static double[] eval(int[] context, double[] prior, EvalParameters model) {
    return eval(context, null, prior, model, true);
  }

  /**
   * Core naive Bayes evaluation: multiplies (in log space, via
   * {@link LogProbabilities}) the smoothed per-feature likelihoods and the
   * outcome prior, then writes the normalized posterior into {@code prior}.
   *
   * @param context   indices of active predicates; entries &lt; 0 are skipped
   * @param values    optional per-predicate weights; assumed 1 when null
   * @param prior     output array; receives the posterior distribution
   * @param model     evaluation parameters; extra naive Bayes data is only
   *                  available when this is a {@link NaiveBayesEvalParameters}
   * @param normalize unused by this implementation; kept for signature
   *                  compatibility with the other model evaluators
   * @return the prior array filled with the posterior distribution
   */
  public static double[] eval(int[] context, float[] values, double[] prior, EvalParameters model, boolean normalize) {
    Probabilities<Integer> probabilities = new LogProbabilities<Integer>();
    Context[] params = model.getParams();
    double[] outcomeTotals = model instanceof NaiveBayesEvalParameters ? ((NaiveBayesEvalParameters) model).getOutcomeTotals() : new double[prior.length];
    long vocabulary = model instanceof NaiveBayesEvalParameters ? ((NaiveBayesEvalParameters) model).getVocabulary() : 0;
    double[] activeParameters;
    int[] activeOutcomes;
    double value = 1;
    for (int ci = 0; ci < context.length; ci++) {
      if (context[ci] >= 0) {
        Context predParams = params[context[ci]];
        activeOutcomes = predParams.getOutcomes();
        activeParameters = predParams.getParameters();
        if (values != null) {
          value = values[ci];
        }
        int ai = 0;
        // Walk every outcome; ai advances only when the sparse outcome list
        // matches, so absent outcomes contribute a zero numerator (which the
        // smoothing turns into a small non-zero likelihood).
        for (int i = 0; i < outcomeTotals.length && ai < activeOutcomes.length; ++i) {
          int oid = activeOutcomes[ai];
          double numerator = oid == i ? activeParameters[ai++] * value : 0;
          double denominator = outcomeTotals[i];
          probabilities.addIn(i, getProbability(numerator, denominator, vocabulary), 1);
        }
      }
    }
    // Multiply in the outcome prior: count(outcome) / count(all outcomes).
    double total = 0;
    for (int i = 0; i < outcomeTotals.length; ++i) {
      total += outcomeTotals[i];
    }
    for (int i = 0; i < outcomeTotals.length; ++i) {
      double numerator = outcomeTotals[i];
      double denominator = total;
      probabilities.addIn(i, numerator / denominator, 1);
    }
    for (int i = 0; i < outcomeTotals.length; ++i) {
      prior[i] = probabilities.get(i);
    }
    return prior;
  }

  /**
   * Likelihood of a feature given an outcome; smoothed unless the static
   * flag has been turned off for testing.
   */
  private static double getProbability(double numerator, double denominator, double vocabulary) {
    if (isSmoothed)
      return getSmoothedProbability(numerator, denominator, vocabulary);
    else if (denominator == 0 || denominator < Double.MIN_VALUE)
      return 0;
    else
      return 1.0 * (numerator) / (denominator);
  }

  static void setSmoothed(boolean flag) {
    isSmoothed = flag;
  }

  static boolean isSmoothed() {
    return isSmoothed;
  }

  /** Additive (Lidstone) smoothing with delta = 0.05. */
  private static double getSmoothedProbability(double numerator, double denominator, double vocabulary) {
    final double delta = 0.05; // Lidstone smoothing
    final double featureVocabularySize = vocabulary;

    return 1.0 * (numerator + delta) / (denominator + delta * featureVocabularySize);
  }

  /**
   * Command-line helper: loads the model named in args[0] and prints the
   * outcome distribution for each whitespace-separated context line on stdin.
   */
  public static void main(String[] args) throws java.io.IOException {
    if (args.length == 0) {
      System.err.println("Usage: NaiveBayesModel modelname < contexts");
      System.exit(1);
    }
    AbstractModel m = new NaiveBayesModelReader(new File(args[0])).getModel();
    BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
    DecimalFormat df = new java.text.DecimalFormat(".###");
    for (String line = in.readLine(); line != null; line = in.readLine()) {
      String[] context = line.split(" ");
      double[] dist = m.eval(context);
      for (int oi = 0; oi < dist.length; oi++) {
        System.out.print("[" + m.getOutcome(oi) + " " + df.format(dist[oi]) + "] ");
      }
      System.out.println();
    }
  }
}
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelReader.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelReader.java?rev=1695334&view=auto
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelReader.java
(added)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelReader.java
Tue Aug 11 15:58:53 2015
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package opennlp.tools.ml.naivebayes;
+
+import java.io.File;
+import java.io.IOException;
+
+import opennlp.tools.ml.model.AbstractModel;
+import opennlp.tools.ml.model.AbstractModelReader;
+import opennlp.tools.ml.model.Context;
+import opennlp.tools.ml.model.DataReader;
+
/**
 * Reader for {@link NaiveBayesModel}s. Subclasses bind it to a concrete
 * storage format (binary, plain text, ...) via the {@link DataReader}.
 */
public class NaiveBayesModelReader extends AbstractModelReader {

  public NaiveBayesModelReader(File file) throws IOException {
    super(file);
  }

  public NaiveBayesModelReader(DataReader dataReader) {
    super(dataReader);
  }

  /**
   * Retrieve a model from disk. It assumes that models are saved in the
   * following sequence:
   * <p/>
   * <br>NaiveBayes (model type identifier)
   * <br>1. # of parameters (int)
   * <br>2. # of outcomes (int)
   * <br>   * list of outcome names (String)
   * <br>3. # of different types of outcome patterns (int)
   * <br>   * list of (int int[])
   * <br>   [# of predicates for which outcome pattern is true] [outcome pattern]
   * <br>4. # of predicates (int)
   * <br>   * list of predicate names (String)
   * <p/>
   * <p>If you are creating a reader for a format which won't work with this
   * (perhaps a database or xml file), override this method and ignore the
   * other methods provided in this abstract class.
   *
   * @return The NaiveBayesModel stored in the format and location specified to
   *         this NaiveBayesModelReader (usually via its the constructor).
   */
  public AbstractModel constructModel() throws IOException {
    String[] outcomeLabels = getOutcomes();
    int[][] outcomePatterns = getOutcomePatterns();
    String[] predLabels = getPredicates();
    Context[] params = getParameters(outcomePatterns);

    return new NaiveBayesModel(params,
        predLabels,
        outcomeLabels);
  }

  public void checkModelType() throws java.io.IOException {
    String modelType = readUTF();
    // NOTE(review): a type mismatch only prints a warning and loading
    // continues; it may fail later with a less helpful error.
    if (!modelType.equals("NaiveBayes"))
      System.out.println("Error: attempting to load a " + modelType +
          " model as a NaiveBayes model." +
          " You should expect problems.");
  }
}
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelWriter.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelWriter.java?rev=1695334&view=auto
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelWriter.java
(added)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelWriter.java
Tue Aug 11 15:58:53 2015
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package opennlp.tools.ml.naivebayes;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import opennlp.tools.ml.model.AbstractModel;
+import opennlp.tools.ml.model.AbstractModelWriter;
+import opennlp.tools.ml.model.ComparablePredicate;
+import opennlp.tools.ml.model.Context;
+import opennlp.tools.ml.model.IndexHashTable;
+
+/**
+ * Abstract parent class for NaiveBayes writers. It provides the persist
method
+ * which takes care of the structure of a stored document, and requires an
+ * extending class to define precisely how the data should be stored.
+ */
+public abstract class NaiveBayesModelWriter extends AbstractModelWriter {
+ protected Context[] PARAMS;
+ protected String[] OUTCOME_LABELS;
+ protected String[] PRED_LABELS;
+ int numOutcomes;
+
+ public NaiveBayesModelWriter(AbstractModel model) {
+
+ Object[] data = model.getDataStructures();
+ this.numOutcomes = model.getNumOutcomes();
+ PARAMS = (Context[]) data[0];
+ IndexHashTable<String> pmap = (IndexHashTable<String>) data[1];
+ OUTCOME_LABELS = (String[]) data[2];
+
+ PRED_LABELS = new String[pmap.size()];
+ pmap.toArray(PRED_LABELS);
+ }
+
+ protected ComparablePredicate[] sortValues() {
+ ComparablePredicate[] sortPreds;
+ ComparablePredicate[] tmpPreds = new ComparablePredicate[PARAMS.length];
+ int[] tmpOutcomes = new int[numOutcomes];
+ double[] tmpParams = new double[numOutcomes];
+ int numPreds = 0;
+ //remove parameters with 0 weight and predicates with no parameters
+ for (int pid = 0; pid < PARAMS.length; pid++) {
+ int numParams = 0;
+ double[] predParams = PARAMS[pid].getParameters();
+ int[] outcomePattern = PARAMS[pid].getOutcomes();
+ for (int pi = 0; pi < predParams.length; pi++) {
+ if (predParams[pi] != 0d) {
+ tmpOutcomes[numParams] = outcomePattern[pi];
+ tmpParams[numParams] = predParams[pi];
+ numParams++;
+ }
+ }
+
+ int[] activeOutcomes = new int[numParams];
+ double[] activeParams = new double[numParams];
+
+ for (int pi = 0; pi < numParams; pi++) {
+ activeOutcomes[pi] = tmpOutcomes[pi];
+ activeParams[pi] = tmpParams[pi];
+ }
+ if (numParams != 0) {
+ tmpPreds[numPreds] = new ComparablePredicate(PRED_LABELS[pid],
activeOutcomes, activeParams);
+ numPreds++;
+ }
+ }
+ System.err.println("Compressed " + PARAMS.length + " parameters to " +
numPreds);
+ sortPreds = new ComparablePredicate[numPreds];
+ System.arraycopy(tmpPreds, 0, sortPreds, 0, numPreds);
+ Arrays.sort(sortPreds);
+ return sortPreds;
+ }
+
+
+ protected List<List<ComparablePredicate>>
computeOutcomePatterns(ComparablePredicate[] sorted) {
+ ComparablePredicate cp = sorted[0];
+ List<List<ComparablePredicate>> outcomePatterns = new
ArrayList<List<ComparablePredicate>>();
+ List<ComparablePredicate> newGroup = new ArrayList<ComparablePredicate>();
+ for (ComparablePredicate predicate : sorted) {
+ if (cp.compareTo(predicate) == 0) {
+ newGroup.add(predicate);
+ } else {
+ cp = predicate;
+ outcomePatterns.add(newGroup);
+ newGroup = new ArrayList<ComparablePredicate>();
+ newGroup.add(predicate);
+ }
+ }
+ outcomePatterns.add(newGroup);
+ System.err.println(outcomePatterns.size() + " outcome patterns");
+ return outcomePatterns;
+ }
+
+ /**
+ * Writes the model to disk, using the <code>writeX()</code> methods
+ * provided by extending classes.
+ * <p/>
+ * <p>If you wish to create a NaiveBayesModelWriter which uses a different
+ * structure, it will be necessary to override the persist method in
+ * addition to implementing the <code>writeX()</code> methods.
+ */
+ public void persist() throws IOException {
+
+ // the type of model (NaiveBayes)
+ writeUTF("NaiveBayes");
+
+ // the mapping from outcomes to their integer indexes
+ writeInt(OUTCOME_LABELS.length);
+
+ for (String label : OUTCOME_LABELS) {
+ writeUTF(label);
+ }
+
+ // the mapping from predicates to the outcomes they contributed to.
+ // The sorting is done so that we actually can write this out more
+ // compactly than as the entire list.
+ ComparablePredicate[] sorted = sortValues();
+ List<List<ComparablePredicate>> compressed =
computeOutcomePatterns(sorted);
+
+ writeInt(compressed.size());
+
+ for (List<ComparablePredicate> a : compressed) {
+ writeUTF(a.size() + a.get(0).toString());
+ }
+
+ // the mapping from predicate names to their integer indexes
+ writeInt(sorted.length);
+
+ for (ComparablePredicate s : sorted) {
+ writeUTF(s.name);
+ }
+
+ // write out the parameters
+ for (int i = 0; i < sorted.length; i++)
+ for (int j = 0; j < sorted[i].params.length; j++)
+ writeDouble(sorted[i].params[j]);
+
+ close();
+ }
+}
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesTrainer.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesTrainer.java?rev=1695334&view=auto
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesTrainer.java
(added)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesTrainer.java
Tue Aug 11 15:58:53 2015
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package opennlp.tools.ml.naivebayes;
+
+import java.io.IOException;
+
+import opennlp.tools.ml.AbstractEventTrainer;
+import opennlp.tools.ml.model.AbstractModel;
+import opennlp.tools.ml.model.DataIndexer;
+import opennlp.tools.ml.model.EvalParameters;
+import opennlp.tools.ml.model.MutableContext;
+
/**
 * Trains multinomial Naive Bayes models. For every predicate, training simply
 * accumulates the (optionally value-weighted) count of each outcome the
 * predicate occurs with; {@link NaiveBayesModel} turns those counts into
 * smoothed probabilities at evaluation time.
 */
public class NaiveBayesTrainer extends AbstractEventTrainer {

  public static final String NAIVE_BAYES_VALUE = "NAIVEBAYES";

  /**
   * Number of unique events which occurred in the event set.
   */
  private int numUniqueEvents;
  /**
   * Number of events in the event set.
   */
  private int numEvents;

  /**
   * Number of predicates.
   */
  private int numPreds;
  /**
   * Number of outcomes.
   */
  private int numOutcomes;
  /**
   * Records the array of predicates seen in each event.
   */
  private int[][] contexts;

  /**
   * The value associated with each context. If null then context values are assumed to be 1.
   */
  private float[][] values;

  /**
   * List of outcomes for each event i, in context[i].
   */
  private int[] outcomeList;

  /**
   * Records the num of times an event has been seen for each event i, in context[i].
   */
  private int[] numTimesEventsSeen;

  /**
   * Stores the String names of the outcomes. The NaiveBayes only tracks outcomes
   * as ints, and so this array is needed to save the model to disk and
   * thereby allow users to know what the outcome was in human
   * understandable terms.
   */
  private String[] outcomeLabels;

  /**
   * Stores the String names of the predicates. The NaiveBayes only tracks
   * predicates as ints, and so this array is needed to save the model to
   * disk and thereby allow users to know what the outcome was in human
   * understandable terms.
   */
  private String[] predLabels;

  // When true, progress and statistics are printed to stdout.
  private boolean printMessages = true;

  public NaiveBayesTrainer() {
  }

  public boolean isSortAndMerge() {
    return false;
  }

  /**
   * Validates the training parameters, then trains a model on the indexed data.
   *
   * @param indexer the indexed training events
   * @return the trained model
   * @throws IllegalArgumentException if the training parameters are not valid
   */
  public AbstractModel doTrain(DataIndexer indexer) throws IOException {
    if (!isValid()) {
      throw new IllegalArgumentException("trainParams are not valid!");
    }

    return this.trainModel(indexer);
  }

  /**
   * Trains a naive Bayes model from the indexed data.
   *
   * @param di the indexed training events
   * @return the trained {@link NaiveBayesModel}
   */
  public AbstractModel trainModel(DataIndexer di) {
    display("Incorporating indexed data for training... \n");
    contexts = di.getContexts();
    values = di.getValues();
    numTimesEventsSeen = di.getNumTimesEventsSeen();
    numEvents = di.getNumEvents();
    numUniqueEvents = contexts.length;

    outcomeLabels = di.getOutcomeLabels();
    outcomeList = di.getOutcomeList();

    predLabels = di.getPredLabels();
    numPreds = predLabels.length;
    numOutcomes = outcomeLabels.length;

    display("done.\n");

    display("\tNumber of Event Tokens: " + numUniqueEvents + "\n");
    display("\t Number of Outcomes: " + numOutcomes + "\n");
    display("\t Number of Predicates: " + numPreds + "\n");

    display("Computing model parameters...\n");

    MutableContext[] finalParameters = findParameters();

    display("...done.\n");

    /*************** Create and return the model ******************/
    return new NaiveBayesModel(finalParameters, predLabels, outcomeLabels);
  }

  /**
   * Computes the model parameters: for each predicate, the count of each
   * outcome it occurs with, weighted by the per-context value when values
   * are present, and repeated for as many times as the event was seen.
   */
  private MutableContext[] findParameters() {

    int[] allOutcomesPattern = new int[numOutcomes];
    for (int oi = 0; oi < numOutcomes; oi++)
      allOutcomesPattern[oi] = oi;

    // Estimated parameter value of each predicate, accumulated below.
    MutableContext[] params = new MutableContext[numPreds];
    for (int pi = 0; pi < numPreds; pi++) {
      params[pi] = new MutableContext(allOutcomesPattern, new double[numOutcomes]);
      for (int aoi = 0; aoi < numOutcomes; aoi++)
        params[pi].setParameter(aoi, 0.0);
    }

    EvalParameters evalParams = new EvalParameters(params, numOutcomes);

    double stepsize = 1;

    for (int ei = 0; ei < numUniqueEvents; ei++) {
      int targetOutcome = outcomeList[ei];
      for (int ni = 0; ni < this.numTimesEventsSeen[ei]; ni++) {
        for (int ci = 0; ci < contexts[ei].length; ci++) {
          int pi = contexts[ei][ci];
          if (values == null) {
            params[pi].updateParameter(targetOutcome, stepsize);
          } else {
            params[pi].updateParameter(targetOutcome, stepsize * values[ei][ci]);
          }
        }
      }
    }

    // Output the final training stats.
    trainingStats(evalParams);

    return params;

  }

  /**
   * Computes and prints the accuracy of the trained parameters on the
   * training data itself.
   *
   * @param evalParams the parameters to evaluate
   * @return the training-set accuracy
   */
  private double trainingStats(EvalParameters evalParams) {
    int numCorrect = 0;

    for (int ei = 0; ei < numUniqueEvents; ei++) {
      for (int ni = 0; ni < this.numTimesEventsSeen[ei]; ni++) {

        double[] modelDistribution = new double[numOutcomes];

        if (values != null)
          NaiveBayesModel.eval(contexts[ei], values[ei], modelDistribution, evalParams, false);
        else
          NaiveBayesModel.eval(contexts[ei], null, modelDistribution, evalParams, false);

        int max = maxIndex(modelDistribution);
        if (max == outcomeList[ei])
          numCorrect++;
      }
    }
    double trainingAccuracy = (double) numCorrect / numEvents;
    display("Stats: (" + numCorrect + "/" + numEvents + ") " + trainingAccuracy + "\n");
    return trainingAccuracy;
  }

  /** Returns the index of the largest value; the lowest index wins ties. */
  private int maxIndex(double[] values) {
    int max = 0;
    for (int i = 1; i < values.length; i++)
      if (values[i] > values[max])
        max = i;
    return max;
  }

  /** Prints s to stdout when printMessages is enabled. */
  private void display(String s) {
    if (printMessages)
      System.out.print(s);
  }
}
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/PlainTextNaiveBayesModelReader.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/PlainTextNaiveBayesModelReader.java?rev=1695334&view=auto
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/PlainTextNaiveBayesModelReader.java
(added)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/PlainTextNaiveBayesModelReader.java
Tue Aug 11 15:58:53 2015
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package opennlp.tools.ml.naivebayes;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+
+import opennlp.tools.ml.model.PlainTextFileDataReader;
+
/**
 * Reads {@link NaiveBayesModel}s stored in plain text format.
 */
public class PlainTextNaiveBayesModelReader extends NaiveBayesModelReader {

  /**
   * Constructor which directly instantiates the BufferedReader containing
   * the model contents.
   *
   * @param br The BufferedReader containing the model information.
   */
  public PlainTextNaiveBayesModelReader(BufferedReader br) {
    super(new PlainTextFileDataReader(br));
  }

  /**
   * Constructor which takes a File and creates a reader for it. Detects
   * whether the file is gzipped or not based on whether the suffix contains
   * ".gz".
   *
   * @param f The File in which the model is stored.
   */
  public PlainTextNaiveBayesModelReader(File f) throws IOException {
    super(f);
  }
}
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/PlainTextNaiveBayesModelWriter.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/PlainTextNaiveBayesModelWriter.java?rev=1695334&view=auto
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/PlainTextNaiveBayesModelWriter.java
(added)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/PlainTextNaiveBayesModelWriter.java
Tue Aug 11 15:58:53 2015
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package opennlp.tools.ml.naivebayes;
+
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.util.zip.GZIPOutputStream;
+
+import opennlp.tools.ml.model.AbstractModel;
+
+/**
+ * Model writer that saves models in plain text format.
+ */
+public class PlainTextNaiveBayesModelWriter extends NaiveBayesModelWriter {
+ BufferedWriter output;
+
+ /**
+ * Constructor which takes a NaiveBayesModel and a File and prepares itself
to
+ * write the model to that file. Detects whether the file is gzipped or not
+ * based on whether the suffix contains ".gz".
+ *
+ * @param model The NaiveBayesModel which is to be persisted.
+ * @param f The File in which the model is to be persisted.
+ */
+ public PlainTextNaiveBayesModelWriter(AbstractModel model, File f)
+ throws IOException, FileNotFoundException {
+
+ super(model);
+ if (f.getName().endsWith(".gz")) {
+ output = new BufferedWriter(new OutputStreamWriter(
+ new GZIPOutputStream(new FileOutputStream(f))));
+ } else {
+ output = new BufferedWriter(new FileWriter(f));
+ }
+ }
+
+ /**
+ * Constructor which takes a NaiveBayesModel and a BufferedWriter and
prepares
+ * itself to write the model to that writer.
+ *
+ * @param model The NaiveBayesModel which is to be persisted.
+ * @param bw The BufferedWriter which will be used to persist the model.
+ */
+ public PlainTextNaiveBayesModelWriter(AbstractModel model, BufferedWriter
bw) {
+ super(model);
+ output = bw;
+ }
+
+ public void writeUTF(String s) throws java.io.IOException {
+ output.write(s);
+ output.newLine();
+ }
+
+ public void writeInt(int i) throws java.io.IOException {
+ output.write(Integer.toString(i));
+ output.newLine();
+ }
+
+ public void writeDouble(double d) throws java.io.IOException {
+ output.write(Double.toString(d));
+ output.newLine();
+ }
+
+ public void close() throws java.io.IOException {
+ output.flush();
+ output.close();
+ }
+
+}
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/Probabilities.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/Probabilities.java?rev=1695334&view=auto
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/Probabilities.java
(added)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/Probabilities.java
Tue Aug 11 15:58:53 2015
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.ml.naivebayes;
+
+import java.util.*;
+
+/**
+ * Class implementing the probability distribution over labels returned by a
classifier.
+ *
+ * @param <T> the label (category) class
+ *
+ */
+public abstract class Probabilities<T> {
+ protected HashMap<T, Double> map = new HashMap<T, Double>();
+
+ protected transient boolean isNormalised = false;
+ protected Map<T, Double> normalised;
+
+ protected double confidence = 0.0;
+
+ /**
+ * Assigns a probability to a label, discarding any previously assigned
probability.
+ *
+ * @param t the label to which the probability is being assigned
+ * @param probability the probability to assign
+ */
+ public void set(T t, double probability) {
+ isNormalised = false;
+ map.put(t, probability);
+ }
+
+ /**
+ * Assigns a probability to a label, discarding any previously assigned
probability.
+ *
+ * @param t the label to which the probability is being assigned
+ * @param probability the probability to assign
+ */
+ public void set(T t, Probability<T> probability) {
+ isNormalised = false;
+ map.put(t, probability.get());
+ }
+
+ /**
+ * Assigns a probability to a label, discarding any previously assigned
probability, if the new probability is greater than the old one.
+ *
+ * @param t the label to which the probability is being assigned
+ * @param probability the probability to assign
+ */
+ public void setIfLarger(T t, double probability) {
+ Double p = map.get(t);
+ if (p == null || probability > p) {
+ isNormalised = false;
+ map.put(t, probability);
+ }
+ }
+
+ /**
+ * Assigns a log probability to a label, discarding any previously assigned
probability.
+ *
+ * @param t the label to which the log probability is being
assigned
+ * @param probability the log probability to assign
+ */
+ public void setLog(T t, double probability) {
+ set(t, Math.exp(probability));
+ }
+
+ /**
+ * Compounds the existing probability mass on the label with the new
probability passed in to the method.
+ *
+ * @param t the label whose probability mass is being updated
+ * @param probability the probability weight to add
+ * @param count the amplifying factor for the probability compounding
+ */
+ public void addIn(T t, double probability, int count) {
+ isNormalised = false;
+ Double p = map.get(t);
+ if (p == null)
+ p = 1.0;
+ probability = Math.pow(probability, count);
+ map.put(t, p * probability);
+ }
+
+ /**
+ * Returns the probability associated with a label
+ *
+ * @param t the label whose probability needs to be returned
+ * @return the probability associated with the label
+ */
+ public Double get(T t) {
+ Double d = normalize().get(t);
+ if (d == null)
+ return 0.0;
+ return d;
+ }
+
+ /**
+ * Returns the log probability associated with a label
+ *
+ * @param t the label whose log probability needs to be returned
+ * @return the log probability associated with the label
+ */
+ public Double getLog(T t) {
+ return Math.log(get(t));
+ }
+
+ /**
+ * Returns the probabilities associated with all labels
+ *
+ * @return the HashMap of labels and their probabilities
+ */
+ public Set<T> getKeys() {
+ return map.keySet();
+ }
+
+ /**
+ * Returns the probabilities associated with all labels
+ *
+ * @return the HashMap of labels and their probabilities
+ */
+ public Map<T, Double> getAll() {
+ return normalize();
+ }
+
+ private Map<T, Double> normalize() {
+ if (isNormalised)
+ return normalised;
+ Map<T, Double> temp = createMapDataStructure();
+ double sum = 0;
+ for (T t : map.keySet()) {
+ Double p = map.get(t);
+ if (p != null) {
+ sum += p;
+ }
+ }
+ for (T t : temp.keySet()) {
+ Double p = temp.get(t);
+ if (p != null) {
+ temp.put(t, p / sum);
+ }
+ }
+ normalised = temp;
+ isNormalised = true;
+ return temp;
+ }
+
+ protected Map<T, Double> createMapDataStructure() {
+ return new HashMap<T, Double>();
+ }
+
+ /**
+ * Returns the most likely label
+ *
+ * @return the label that has the highest associated probability
+ */
+ public T getMax() {
+ double max = 0;
+ T maxT = null;
+ for (T t : map.keySet()) {
+ Double temp = map.get(t);
+ if (temp >= max) {
+ max = temp;
+ maxT = t;
+ }
+ }
+ return maxT;
+ }
+
+ /**
+ * Returns the probability of the most likely label
+ *
+ * @return the highest probability
+ */
+ public double getMaxValue() {
+ return get(getMax());
+ }
+
+ public void discardCountsBelow(double i) {
+ ArrayList<T> labelsToRemove = new ArrayList<T>();
+ for (T label : map.keySet()) {
+ Double sum = map.get(label);
+ if (sum == null) sum = 0.0;
+ if (sum < i)
+ labelsToRemove.add(label);
+ }
+ for (T label : labelsToRemove) {
+ map.remove(label);
+ }
+ }
+
+ /**
+ * Returns the best confidence with which this set of probabilities has been
calculated.
+ * This is a function of the amount of data that supports the assertion.
+ * It is also a measure of the accuracy of the estimator of the probability.
+ *
+ * @return the best confidence of the probabilities
+ */
+ public double getConfidence() {
+ return confidence;
+ }
+
+ /**
+ * Sets the best confidence with which this set of probabilities has been
calculated.
+ * This is a function of the amount of data that supports the assertion.
+ * It is also a measure of the accuracy of the estimator of the probability.
+ *
+ * @param confidence the confidence in the probabilities
+ */
+ public void setConfidence(double confidence) {
+ this.confidence = confidence;
+ }
+
+ public String toString() {
+ return getAll().toString();
+ }
+
+}
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/Probability.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/Probability.java?rev=1695334&view=auto
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/Probability.java
(added)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/Probability.java
Tue Aug 11 15:58:53 2015
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.ml.naivebayes;
+
/**
 * Class implementing the probability for a label.
 *
 * @param <T> the label (category) class
 *
 */
public class Probability<T> {
  protected T label;
  // Starts at the multiplicative identity so addIn() can be called
  // repeatedly without prior initialisation.
  protected double probability = 1.0;

  public Probability(T label) {
    this.label = label;
  }

  /**
   * Assigns a probability to a label, discarding any previously assigned probability.
   *
   * @param probability the probability to assign
   */
  public void set(double probability) {
    this.probability = probability;
  }

  /**
   * Assigns a probability to a label, discarding any previously assigned probability.
   *
   * @param probability the probability to assign
   */
  public void set(Probability<T> probability) {
    this.probability = probability.get();
  }

  /**
   * Assigns a probability to a label, discarding any previously assigned
   * probability, if the new probability is greater than the old one.
   *
   * @param probability the probability to assign
   */
  public void setIfLarger(double probability) {
    if (this.probability < probability) {
      this.probability = probability;
    }
  }

  /**
   * Assigns a probability to a label, discarding any previously assigned
   * probability, if the new probability is greater than the old one.
   *
   * @param probability the probability to assign
   */
  public void setIfLarger(Probability<T> probability) {
    if (this.probability < probability.get()) {
      this.probability = probability.get();
    }
  }

  /**
   * Checks whether the given probability is greater than this one.
   *
   * @param probability the probability to compare against
   * @return true if the given probability is strictly greater than this one
   */
  public boolean isLarger(Probability<T> probability) {
    return this.probability < probability.get();
  }

  /**
   * Assigns a log probability to a label, discarding any previously assigned probability.
   *
   * @param probability the log probability to assign
   */
  public void setLog(double probability) {
    set(Math.exp(probability));
  }

  /**
   * Compounds the existing probability mass on the label with the new
   * probability passed in to the method.
   *
   * @param probability the probability weight to add
   */
  public void addIn(double probability) {
    set(this.probability * probability);
  }

  /**
   * Returns the probability associated with a label.
   *
   * @return the probability associated with the label
   */
  public Double get() {
    return probability;
  }

  /**
   * Returns the log probability associated with a label.
   *
   * @return the log probability associated with the label
   */
  public Double getLog() {
    return Math.log(get());
  }

  /**
   * Returns the label associated with this probability.
   *
   * @return the label
   */
  public T getLabel() {
    return label;
  }

  public String toString() {
    return label == null ? "" + probability : label.toString() + ":" + probability;
  }
}
Added:
opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/doccat/DocumentCategorizerNBTest.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/doccat/DocumentCategorizerNBTest.java?rev=1695334&view=auto
==============================================================================
---
opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/doccat/DocumentCategorizerNBTest.java
(added)
+++
opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/doccat/DocumentCategorizerNBTest.java
Tue Aug 11 15:58:53 2015
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package opennlp.tools.doccat;
+
+import static org.junit.Assert.assertEquals;
+
+import java.io.IOException;
+import java.util.Set;
+import java.util.SortedMap;
+
+import opennlp.tools.util.ObjectStream;
+import opennlp.tools.util.ObjectStreamUtils;
+import opennlp.tools.util.TrainingParameters;
+
+import org.junit.Test;
+
+public class DocumentCategorizerNBTest {
+
+ @Test
+ public void testSimpleTraining() throws IOException {
+
+ ObjectStream<DocumentSample> samples =
ObjectStreamUtils.createObjectStream(new DocumentSample("1", new String[]{"a",
"b", "c"}),
+ new DocumentSample("1", new String[]{"a", "b", "c", "1", "2"}),
+ new DocumentSample("1", new String[]{"a", "b", "c", "3", "4"}),
+ new DocumentSample("0", new String[]{"x", "y", "z"}),
+ new DocumentSample("0", new String[]{"x", "y", "z", "5", "6"}),
+ new DocumentSample("0", new String[]{"x", "y", "z", "7", "8"}));
+
+ TrainingParameters params = new TrainingParameters();
+ params.put(TrainingParameters.ITERATIONS_PARAM, Integer.toString(100));
+ params.put(TrainingParameters.CUTOFF_PARAM, Integer.toString(0));
+
+ DoccatModel model = DocumentCategorizerNB.train("x-unspecified", samples,
+ params, new BagOfWordsFeatureGenerator());
+
+ DocumentCategorizer doccat = new DocumentCategorizerNB(model);
+
+ double aProbs[] = doccat.categorize("a");
+ assertEquals("1", doccat.getBestCategory(aProbs));
+
+ double bProbs[] = doccat.categorize("x");
+ assertEquals("0", doccat.getBestCategory(bProbs));
+
+ //test to make sure sorted map's last key is cat 1 because it has the
highest score.
+ SortedMap<Double, Set<String>> sortedScoreMap = doccat.sortedScoreMap("a");
+ for (String cat : sortedScoreMap.get(sortedScoreMap.lastKey())) {
+ assertEquals("1", cat);
+ break;
+ }
+ System.out.println("");
+
+ }
+}