This is an automated email from the ASF dual-hosted git repository.
rzo1 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/opennlp.git
The following commit(s) were added to refs/heads/master by this push:
new 59f10a29 OPENNLP-1411 Provide equals and hashCode for POSModel
59f10a29 is described below
commit 59f10a2996d2dac842906c1ce1d150b760a4b12d
Author: Martin Wiesner <[email protected]>
AuthorDate: Sun Dec 11 18:15:12 2022 +0100
OPENNLP-1411 Provide equals and hashCode for POSModel
- adds specific `equals` and `hashCode` implementations for `POSModel`,
`GISModel`, and `PerceptronModel`.
- improves `POSModelTest` by adding further assertions and resolving its
open TODOs.
- adds the `@Override` annotation where applicable
---
.../java/opennlp/tools/ml/maxent/GISModel.java | 27 +++++++++++++
.../tools/ml/perceptron/PerceptronModel.java | 45 ++++++++++++++++++++++
.../main/java/opennlp/tools/postag/POSModel.java | 28 ++++++++++++++
.../java/opennlp/tools/postag/POSModelTest.java | 33 ++++++++--------
4 files changed, 115 insertions(+), 18 deletions(-)
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISModel.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISModel.java
index 8a0cd6d2..4034efde 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISModel.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISModel.java
@@ -17,6 +17,9 @@
package opennlp.tools.ml.maxent;
+import java.util.Arrays;
+import java.util.Objects;
+
import opennlp.tools.ml.ArrayMath;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.Context;
@@ -77,14 +80,17 @@ public final class GISModel extends AbstractModel {
* string representation of the outcomes can be obtained from the
* method getOutcome(int i).
*/
+ @Override
public final double[] eval(String[] context) {
return (eval(context, new double[evalParams.getNumOutcomes()]));
}
+ @Override
public final double[] eval(String[] context, float[] values) {
return (eval(context, values, new double[evalParams.getNumOutcomes()]));
}
+ @Override
public final double[] eval(String[] context, double[] outsums) {
return eval(context, null, outsums);
}
@@ -197,4 +203,25 @@ public final class GISModel extends AbstractModel {
}
return prior;
}
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(pmap, Arrays.hashCode(outcomeNames), evalParams, prior);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == this) {
+ return true;
+ }
+
+ if (obj instanceof GISModel) {
+ GISModel model = (GISModel) obj;
+
+ return pmap.equals(model.pmap) && Objects.deepEquals(outcomeNames, model.outcomeNames)
+ && Objects.equals(prior, model.prior);
+ }
+
+ return false;
+ }
}
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronModel.java b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronModel.java
index 74851875..118006c7 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronModel.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronModel.java
@@ -17,6 +17,9 @@
package opennlp.tools.ml.perceptron;
+import java.util.Arrays;
+import java.util.Objects;
+
import opennlp.tools.ml.ArrayMath;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.Context;
@@ -29,14 +32,17 @@ public class PerceptronModel extends AbstractModel {
modelType = ModelType.Perceptron;
}
+ @Override
public double[] eval(String[] context) {
return eval(context,new double[evalParams.getNumOutcomes()]);
}
+ @Override
public double[] eval(String[] context, float[] values) {
return eval(context,values,new double[evalParams.getNumOutcomes()]);
}
+ @Override
public double[] eval(String[] context, double[] probs) {
return eval(context,null,probs);
}
@@ -91,4 +97,43 @@ public class PerceptronModel extends AbstractModel {
}
return prior;
}
+
+ @Override
+ public int hashCode() {
+ /*
+ * Note:
+ * The hashcode for 'pmap' can not be used here, as PerceptronModelWriter
+ * uses compressions during sortValues() operation, quote:
+ * "remove parameters with 0 weight and predicates with no parameters"
+ *
+ * This leads to fewer entries in 'pmap' for serialized PerceptronModel instances
+ * that were trained from scratch.
+ */
+ return Objects.hash(Arrays.hashCode(outcomeNames), evalParams, prior);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == this) {
+ return true;
+ }
+
+ if (obj instanceof PerceptronModel) {
+ PerceptronModel model = (PerceptronModel) obj;
+
+ /*
+ * Note:
+ * The comparison 'pmap.equals(model.pmap)' can not be made here, as PerceptronModelWriter
+ * uses compressions during sortValues() operation, quote:
+ * "remove parameters with 0 weight and predicates with no parameters"
+ *
+ * This leads to fewer entries in 'pmap' for serialized PerceptronModel instances
+ * that were trained from scratch.
+ */
+ return Objects.deepEquals(outcomeNames, model.outcomeNames)
+ && Objects.equals(prior, model.prior);
+ }
+
+ return false;
+ }
}
diff --git a/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java b/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java
index 422cac56..adcf1db3 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java
@@ -22,11 +22,13 @@ import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.file.Path;
+import java.util.Arrays;
import java.util.Map;
import java.util.Objects;
import java.util.Properties;
import opennlp.tools.ml.BeamSearch;
+import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.ml.model.SequenceClassificationModel;
import opennlp.tools.util.BaseToolFactory;
@@ -229,4 +231,30 @@ public final class POSModel extends BaseModel implements SerializableArtifact {
public Class<POSModelSerializer> getArtifactSerializerClass() {
return POSModelSerializer.class;
}
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(artifactMap.get("manifest.properties"), artifactMap.get("pos.model"),
+ Arrays.hashCode((byte[]) artifactMap.get("generator.featuregen"))
+ );
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == this) {
+ return true;
+ }
+
+ if (obj instanceof POSModel) {
+ POSModel model = (POSModel) obj;
+ Map<String, Object> artifactMapToCheck = model.artifactMap;
+ AbstractModel abstractModel = (AbstractModel) artifactMapToCheck.get("pos.model");
+
+ return artifactMap.get("manifest.properties").equals(artifactMapToCheck.get("manifest.properties")) &&
+ artifactMap.get("pos.model").equals(abstractModel) &&
+ Arrays.equals((byte[]) artifactMap.get("generator.featuregen"),
+ (byte[]) artifactMapToCheck.get("generator.featuregen"));
+ }
+ return false;
+ }
}
diff --git a/opennlp-tools/src/test/java/opennlp/tools/postag/POSModelTest.java b/opennlp-tools/src/test/java/opennlp/tools/postag/POSModelTest.java
index e9d86159..14086c16 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/postag/POSModelTest.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/postag/POSModelTest.java
@@ -21,6 +21,7 @@ import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
+import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import opennlp.tools.util.model.ModelType;
@@ -30,34 +31,30 @@ public class POSModelTest {
@Test
void testPOSModelSerializationMaxent() throws IOException {
POSModel posModel = POSTaggerMETest.trainPOSModel(ModelType.MAXENT);
+ Assertions.assertFalse(posModel.isLoadedFromSerialized());
- ByteArrayOutputStream out = new ByteArrayOutputStream();
-
- try {
+ try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
posModel.serialize(out);
- } finally {
- out.close();
- }
- POSModel recreatedPosModel = new POSModel(new ByteArrayInputStream(out.toByteArray()));
-
- // TODO: add equals to pos model
+ POSModel recreatedPosModel = new POSModel(new ByteArrayInputStream(out.toByteArray()));
+ Assertions.assertNotNull(recreatedPosModel);
+ Assertions.assertTrue(recreatedPosModel.isLoadedFromSerialized());
+ Assertions.assertEquals(posModel, recreatedPosModel);
+ }
}
@Test
void testPOSModelSerializationPerceptron() throws IOException {
POSModel posModel = POSTaggerMETest.trainPOSModel(ModelType.PERCEPTRON);
-
- ByteArrayOutputStream out = new ByteArrayOutputStream();
-
- try {
+ Assertions.assertFalse(posModel.isLoadedFromSerialized());
+
+ try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
posModel.serialize(out);
- } finally {
- out.close();
- }
- POSModel recreatedPosModel = new POSModel(new ByteArrayInputStream(out.toByteArray()));
+ POSModel recreatedPosModel = new POSModel(new ByteArrayInputStream(out.toByteArray()));
+ Assertions.assertTrue(recreatedPosModel.isLoadedFromSerialized());
+ Assertions.assertEquals(posModel, recreatedPosModel);
+ }
- // TODO: add equals to pos model
}
}