This is an automated email from the ASF dual-hosted git repository.

rzo1 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/opennlp.git


The following commit(s) were added to refs/heads/master by this push:
     new 59f10a29 OPENNLP-1411 Provide equals and hashCode for POSModel
59f10a29 is described below

commit 59f10a2996d2dac842906c1ce1d150b760a4b12d
Author: Martin Wiesner <[email protected]>
AuthorDate: Sun Dec 11 18:15:12 2022 +0100

    OPENNLP-1411 Provide equals and hashCode for POSModel
    
    - adds specific `equals` and `hashCode` implementations for `POSModel`,
      `GISModel`, and `PerceptronModel`.
    - improves `POSModelTest` by adding further assertions and resolving the
      open TODOs.
    - adds `@Override` annotations where applicable
---
 .../java/opennlp/tools/ml/maxent/GISModel.java     | 27 +++++++++++++
 .../tools/ml/perceptron/PerceptronModel.java       | 45 ++++++++++++++++++++++
 .../main/java/opennlp/tools/postag/POSModel.java   | 28 ++++++++++++++
 .../java/opennlp/tools/postag/POSModelTest.java    | 33 ++++++++--------
 4 files changed, 115 insertions(+), 18 deletions(-)

diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISModel.java 
b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISModel.java
index 8a0cd6d2..4034efde 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISModel.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISModel.java
@@ -17,6 +17,9 @@
 
 package opennlp.tools.ml.maxent;
 
+import java.util.Arrays;
+import java.util.Objects;
+
 import opennlp.tools.ml.ArrayMath;
 import opennlp.tools.ml.model.AbstractModel;
 import opennlp.tools.ml.model.Context;
@@ -77,14 +80,17 @@ public final class GISModel extends AbstractModel {
    *         string representation of the outcomes can be obtained from the
    *         method getOutcome(int i).
    */
+  @Override
   public final double[] eval(String[] context) {
     return (eval(context, new double[evalParams.getNumOutcomes()]));
   }
 
+  @Override
   public final double[] eval(String[] context, float[] values) {
     return (eval(context, values, new double[evalParams.getNumOutcomes()]));
   }
 
+  @Override
   public final double[] eval(String[] context, double[] outsums) {
     return eval(context, null, outsums);
   }
@@ -197,4 +203,25 @@ public final class GISModel extends AbstractModel {
     }
     return prior;
   }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(pmap, Arrays.hashCode(outcomeNames), evalParams, 
prior);
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (obj == this) {
+      return true;
+    }
+
+    if (obj instanceof GISModel) {
+      GISModel model = (GISModel) obj;
+
+      return pmap.equals(model.pmap) && Objects.deepEquals(outcomeNames, 
model.outcomeNames)
+              && Objects.equals(prior, model.prior);
+    }
+
+    return false;
+  }
 }
diff --git 
a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronModel.java 
b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronModel.java
index 74851875..118006c7 100644
--- 
a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronModel.java
+++ 
b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronModel.java
@@ -17,6 +17,9 @@
 
 package opennlp.tools.ml.perceptron;
 
+import java.util.Arrays;
+import java.util.Objects;
+
 import opennlp.tools.ml.ArrayMath;
 import opennlp.tools.ml.model.AbstractModel;
 import opennlp.tools.ml.model.Context;
@@ -29,14 +32,17 @@ public class PerceptronModel extends AbstractModel {
     modelType = ModelType.Perceptron;
   }
 
+  @Override
   public double[] eval(String[] context) {
     return eval(context,new double[evalParams.getNumOutcomes()]);
   }
 
+  @Override
   public double[] eval(String[] context, float[] values) {
     return eval(context,values,new double[evalParams.getNumOutcomes()]);
   }
 
+  @Override
   public double[] eval(String[] context, double[] probs) {
     return eval(context,null,probs);
   }
@@ -91,4 +97,43 @@ public class PerceptronModel extends AbstractModel {
     }
     return prior;
   }
+
+  @Override
+  public int hashCode() {
+    /*
+     * Note:
+     * The hashcode for 'pmap' can not be used here, as PerceptronModelWriter
+     * uses compressions during sortValues() operation, quote:
+     * "remove parameters with 0 weight and predicates with no parameters"
+     *
+     * This leads to fewer entries in 'pmap' for serialized PerceptronModel 
instances
+     * that were trained from scratch.
+     */
+    return Objects.hash(Arrays.hashCode(outcomeNames), evalParams, prior);
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (obj == this) {
+      return true;
+    }
+
+    if (obj instanceof PerceptronModel) {
+      PerceptronModel model = (PerceptronModel) obj;
+
+      /*
+       * Note:
+       * The comparison 'pmap.equals(model.pmap)' can not be made here, as 
PerceptronModelWriter
+       * uses compressions during sortValues() operation, quote:
+       * "remove parameters with 0 weight and predicates with no parameters"
+       *
+       * This leads to fewer entries in 'pmap' for serialized PerceptronModel 
instances
+       * that were trained from scratch.
+       */
+      return Objects.deepEquals(outcomeNames, model.outcomeNames)
+              && Objects.equals(prior, model.prior);
+    }
+
+    return false;
+  }
 }
diff --git a/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java 
b/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java
index 422cac56..adcf1db3 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java
@@ -22,11 +22,13 @@ import java.io.IOException;
 import java.io.InputStream;
 import java.net.URL;
 import java.nio.file.Path;
+import java.util.Arrays;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Properties;
 
 import opennlp.tools.ml.BeamSearch;
+import opennlp.tools.ml.model.AbstractModel;
 import opennlp.tools.ml.model.MaxentModel;
 import opennlp.tools.ml.model.SequenceClassificationModel;
 import opennlp.tools.util.BaseToolFactory;
@@ -229,4 +231,30 @@ public final class POSModel extends BaseModel implements 
SerializableArtifact {
   public Class<POSModelSerializer> getArtifactSerializerClass() {
     return POSModelSerializer.class;
   }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(artifactMap.get("manifest.properties"), 
artifactMap.get("pos.model"),
+            Arrays.hashCode((byte[]) artifactMap.get("generator.featuregen"))
+    );
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (obj == this) {
+      return true;
+    }
+
+    if (obj instanceof POSModel) {
+      POSModel model = (POSModel) obj;
+      Map<String, Object> artifactMapToCheck = model.artifactMap;
+      AbstractModel abstractModel = (AbstractModel) 
artifactMapToCheck.get("pos.model");
+
+      return 
artifactMap.get("manifest.properties").equals(artifactMapToCheck.get("manifest.properties"))
 &&
+              artifactMap.get("pos.model").equals(abstractModel) &&
+              Arrays.equals((byte[]) artifactMap.get("generator.featuregen"),
+                            (byte[]) 
artifactMapToCheck.get("generator.featuregen"));
+    }
+    return false;
+  }
 }
diff --git a/opennlp-tools/src/test/java/opennlp/tools/postag/POSModelTest.java 
b/opennlp-tools/src/test/java/opennlp/tools/postag/POSModelTest.java
index e9d86159..14086c16 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/postag/POSModelTest.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/postag/POSModelTest.java
@@ -21,6 +21,7 @@ import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 
+import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
 import opennlp.tools.util.model.ModelType;
@@ -30,34 +31,30 @@ public class POSModelTest {
   @Test
   void testPOSModelSerializationMaxent() throws IOException {
     POSModel posModel = POSTaggerMETest.trainPOSModel(ModelType.MAXENT);
+    Assertions.assertFalse(posModel.isLoadedFromSerialized());
 
-    ByteArrayOutputStream out = new ByteArrayOutputStream();
-
-    try {
+    try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
       posModel.serialize(out);
-    } finally {
-      out.close();
-    }
 
-    POSModel recreatedPosModel = new POSModel(new 
ByteArrayInputStream(out.toByteArray()));
-
-    // TODO: add equals to pos model
+      POSModel recreatedPosModel = new POSModel(new 
ByteArrayInputStream(out.toByteArray()));
+      Assertions.assertNotNull(recreatedPosModel);
+      Assertions.assertTrue(recreatedPosModel.isLoadedFromSerialized());
+      Assertions.assertEquals(posModel, recreatedPosModel);
+    }
   }
 
   @Test
   void testPOSModelSerializationPerceptron() throws IOException {
     POSModel posModel = POSTaggerMETest.trainPOSModel(ModelType.PERCEPTRON);
-
-    ByteArrayOutputStream out = new ByteArrayOutputStream();
-
-    try {
+    Assertions.assertFalse(posModel.isLoadedFromSerialized());
+    
+    try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
       posModel.serialize(out);
-    } finally {
-      out.close();
-    }
 
-    POSModel recreatedPosModel = new POSModel(new 
ByteArrayInputStream(out.toByteArray()));
+      POSModel recreatedPosModel = new POSModel(new 
ByteArrayInputStream(out.toByteArray()));
+      Assertions.assertTrue(recreatedPosModel.isLoadedFromSerialized());
+      Assertions.assertEquals(posModel, recreatedPosModel);
+    }
 
-    // TODO: add equals to pos model
   }
 }

Reply via email to