This is an automated email from the ASF dual-hosted git repository.

mawiesne pushed a commit to branch 
OPENNLP-1556-Improve-speed-of-checksum-computation-in-TwoPassDataIndexer
in repository https://gitbox.apache.org/repos/asf/opennlp.git

commit 42d1b3418f20f6143e3f6201a1b4ce6335923957
Author: Martin Wiesner <[email protected]>
AuthorDate: Sun May 5 10:18:54 2024 +0200

    OPENNLP-1556 Improve speed of checksum computation in TwoPassDataIndexer
    - adjusts TwoPassDataIndexer to make use of JDK's built-in 
CheckedOutputStream / CheckedInputStream for checksum (CRC32c) computations
    - removes the untested class HashSumEventStream, which was just a wrapper 
that called a slow toString() on Event to obtain the bytes used for computing 
a checksum
    - provides ChecksumEventStream as a replacement for HashSumEventStream; it 
makes use of the faster CRC32c checksum computation, avoiding cryptographic 
hash functions such as MD5
    - adds JUnit tests for ChecksumEventStream
---
 .../opennlp/tools/ml/AbstractEventTrainer.java     |  6 +-
 ...umEventStream.java => ChecksumEventStream.java} | 43 +++-------
 .../opennlp/tools/ml/model/TwoPassDataIndexer.java | 65 +++++++--------
 .../tools/ml/model/ChecksumEventStreamTest.java    | 93 ++++++++++++++++++++++
 4 files changed, 142 insertions(+), 65 deletions(-)

diff --git 
a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java 
b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java
index d546739a..9ea5ddce 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java
@@ -20,10 +20,10 @@ package opennlp.tools.ml;
 import java.io.IOException;
 
 import opennlp.tools.ml.model.AbstractDataIndexer;
+import opennlp.tools.ml.model.ChecksumEventStream;
 import opennlp.tools.ml.model.DataIndexer;
 import opennlp.tools.ml.model.DataIndexerFactory;
 import opennlp.tools.ml.model.Event;
-import opennlp.tools.ml.model.HashSumEventStream;
 import opennlp.tools.ml.model.MaxentModel;
 import opennlp.tools.util.InsufficientTrainingDataException;
 import opennlp.tools.util.ObjectStream;
@@ -85,10 +85,10 @@ public abstract class AbstractEventTrainer extends 
AbstractTrainer implements Ev
   public final MaxentModel train(ObjectStream<Event> events) throws 
IOException {
     validate();
 
-    HashSumEventStream hses = new HashSumEventStream(events);
+    ChecksumEventStream hses = new ChecksumEventStream(events);
     DataIndexer indexer = getDataIndexer(hses);
 
-    addToReport("Training-Eventhash", hses.calculateHashSum().toString(16));
+    addToReport("Training-Eventhash", 
String.valueOf(hses.calculateChecksum()));
     return train(indexer);
   }
 }
diff --git 
a/opennlp-tools/src/main/java/opennlp/tools/ml/model/HashSumEventStream.java 
b/opennlp-tools/src/main/java/opennlp/tools/ml/model/ChecksumEventStream.java
similarity index 50%
rename from 
opennlp-tools/src/main/java/opennlp/tools/ml/model/HashSumEventStream.java
rename to 
opennlp-tools/src/main/java/opennlp/tools/ml/model/ChecksumEventStream.java
index 6fafb243..41c8caa4 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/model/HashSumEventStream.java
+++ 
b/opennlp-tools/src/main/java/opennlp/tools/ml/model/ChecksumEventStream.java
@@ -18,57 +18,38 @@
 package opennlp.tools.ml.model;
 
 import java.io.IOException;
-import java.math.BigInteger;
 import java.nio.charset.StandardCharsets;
-import java.security.MessageDigest;
-import java.security.NoSuchAlgorithmException;
+import java.util.zip.CRC32C;
+import java.util.zip.Checksum;
 
 import opennlp.tools.util.AbstractObjectStream;
 import opennlp.tools.util.ObjectStream;
 
-/**
- * A hash sum based {@link AbstractObjectStream} implementation.
- *
- * @see Event
- * @see MessageDigest
- * @see AbstractObjectStream
- */
-public class HashSumEventStream extends AbstractObjectStream<Event> {
+public class ChecksumEventStream extends AbstractObjectStream<Event> {
 
-  private final MessageDigest digest;
+  private final Checksum checksum;
 
-  public HashSumEventStream(ObjectStream<Event> eventStream) {
+  public ChecksumEventStream(ObjectStream<Event> eventStream) {
     super(eventStream);
-
-    try {
-      digest = MessageDigest.getInstance("MD5");
-    } catch (NoSuchAlgorithmException e) {
-      // should never happen: do all java runtimes have md5 ?!
-      throw new IllegalStateException(e);
-    }
+    // CRC32C supports CPU-specific acceleration instructions
+    checksum = new CRC32C();
   }
 
   @Override
   public Event read() throws IOException {
     Event event = super.read();
-
     if (event != null) {
-      digest.update(event.toString().getBytes(StandardCharsets.UTF_8));
+      checksum.update(event.toString().getBytes(StandardCharsets.UTF_8));
     }
-
     return event;
   }
 
   /**
-   * Calculates the hash sum of the stream and wraps it into a {@link 
BigInteger}.
-   * Note: The method must be called after the stream is completely consumed.
+   * Calculates the check sum of the stream.
    *
-   * @return The calculated hash sum as {@link BigInteger}.
-   * @throws IllegalStateException Thrown if the stream is not consumed 
completely,
-   *     completely means that hasNext() returns {@code false}.
+   * @return The calculated check sum as {@code long}.
    */
-  public BigInteger calculateHashSum() {
-    return new BigInteger(1, digest.digest());
+  public long calculateChecksum()  {
+    return checksum.getValue();
   }
-
 }
diff --git 
a/opennlp-tools/src/main/java/opennlp/tools/ml/model/TwoPassDataIndexer.java 
b/opennlp-tools/src/main/java/opennlp/tools/ml/model/TwoPassDataIndexer.java
index dd67dc21..0e49a4bd 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/model/TwoPassDataIndexer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/model/TwoPassDataIndexer.java
@@ -17,7 +17,6 @@
 
 package opennlp.tools.ml.model;
 
-
 import java.io.BufferedInputStream;
 import java.io.BufferedOutputStream;
 import java.io.DataInputStream;
@@ -26,11 +25,13 @@ import java.io.File;
 import java.io.FileInputStream;
 import java.io.FileOutputStream;
 import java.io.IOException;
-import java.math.BigInteger;
 import java.nio.file.Files;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.zip.CRC32C;
+import java.util.zip.CheckedInputStream;
+import java.util.zip.CheckedOutputStream;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -63,48 +64,50 @@ public class TwoPassDataIndexer extends AbstractDataIndexer 
{
     int cutoff = trainingParameters.getIntParameter(CUTOFF_PARAM, 
CUTOFF_DEFAULT);
     boolean sort = trainingParameters.getBooleanParameter(SORT_PARAM, 
SORT_DEFAULT);
 
-    long start = System.currentTimeMillis();
-
     logger.info("Indexing events with TwoPass using cutoff of {}", cutoff);
-
     logger.info("Computing event counts...");
 
+    long start = System.currentTimeMillis();
     Map<String,Integer> predicateIndex = new HashMap<>();
-
     File tmp = Files.createTempFile("events", null).toFile();
     tmp.deleteOnExit();
     int numEvents;
-    BigInteger writeHash;
-    HashSumEventStream writeEventStream = new HashSumEventStream(eventStream); 
 // do not close.
-    try (DataOutputStream dos = new DataOutputStream(new 
BufferedOutputStream(new FileOutputStream(tmp)))) {
-      numEvents = computeEventCounts(writeEventStream, dos, predicateIndex, 
cutoff);
-    }
-    writeHash = writeEventStream.calculateHashSum();
+    long writeChecksum;
 
-    logger.info("done. {} events", numEvents);
-    logger.info("Indexing...");
+    try (BufferedOutputStream out = new BufferedOutputStream(new 
FileOutputStream(tmp));
+        CheckedOutputStream writeStream = new CheckedOutputStream(out, new 
CRC32C());
+        DataOutputStream dos = new DataOutputStream(writeStream)) {
+
+      numEvents = computeEventCounts(eventStream, dos, predicateIndex, cutoff);
+      writeChecksum = writeStream.getChecksum().getValue();
+      logger.info("done. {} events", numEvents);
+    }
 
     List<ComparableEvent> eventsToCompare;
-    BigInteger readHash = null;
-    try (HashSumEventStream readStream = new HashSumEventStream(new 
EventStream(tmp))) {
-      eventsToCompare = index(readStream, predicateIndex);
-      readHash = readStream.calculateHashSum();
+    long readChecksum;
+    try (BufferedInputStream in = new BufferedInputStream(new 
FileInputStream(tmp));
+         CheckedInputStream readStream = new CheckedInputStream(in, new 
CRC32C());
+         EventStream readEventsStream = new EventStream(new 
DataInputStream(readStream))) {
+      logger.info("Indexing...");
+      eventsToCompare = index(readEventsStream, predicateIndex);
+      readChecksum = readStream.getChecksum().getValue();
     }
     tmp.delete();
 
-    if (readHash.compareTo(writeHash) != 0)
-      throw new IOException("Event hash for writing and reading events did not 
match.");
+    if (readChecksum != writeChecksum) {
+      throw new IOException("Checksum for writing and reading events did not 
match.");
+    } else {
+      logger.info("done.");
 
-    logger.info("done.");
-
-    if (sort) {
-      logger.info("Sorting and merging events... ");
-    }
-    else {
-      logger.info("Collecting events... ");
+      if (sort) {
+        logger.info("Sorting and merging events... ");
+      }
+      else {
+        logger.info("Collecting events... ");
+      }
+      sortAndMerge(eventsToCompare,sort);
+      logger.info(String.format("Done indexing in %.2f s.", 
(System.currentTimeMillis() - start) / 1000d));
     }
-    sortAndMerge(eventsToCompare,sort);
-    logger.info(String.format("Done indexing in %.2f s.", 
(System.currentTimeMillis() - start) / 1000d));
   }
 
   /**
@@ -170,8 +173,8 @@ public class TwoPassDataIndexer extends AbstractDataIndexer 
{
 
     private final DataInputStream inputStream;
 
-    public EventStream(File file) throws IOException {
-      inputStream = new DataInputStream(new BufferedInputStream(new 
FileInputStream(file)));
+    public EventStream(DataInputStream dataInputStream) {
+      this.inputStream = dataInputStream;
     }
 
     @Override
diff --git 
a/opennlp-tools/src/test/java/opennlp/tools/ml/model/ChecksumEventStreamTest.java
 
b/opennlp-tools/src/test/java/opennlp/tools/ml/model/ChecksumEventStreamTest.java
new file mode 100644
index 00000000..95d38b29
--- /dev/null
+++ 
b/opennlp-tools/src/test/java/opennlp/tools/ml/model/ChecksumEventStreamTest.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.ml.model;
+
+import java.io.IOException;
+
+import org.junit.jupiter.api.Test;
+
+import opennlp.tools.util.ObjectStream;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class ChecksumEventStreamTest {
+
+  @Test
+  void testCalculateChecksumEquality() throws IOException {
+    ChecksumEventStream ces1 = new 
ChecksumEventStream(createEventStreamFull());
+    ChecksumEventStream ces2 = new 
ChecksumEventStream(createEventStreamFull());
+    consumeEventStream(ces1, 7);
+    consumeEventStream(ces2, 7);
+    
+    long checksum1 = ces1.calculateChecksum();
+    long checksum2 = ces2.calculateChecksum();
+    assertTrue(checksum1 != 0);
+    assertTrue(checksum2 != 0);
+    assertEquals(checksum1, checksum2);
+  }
+
+  @Test
+  void testCalculateChecksumMismatch() throws IOException {
+    ChecksumEventStream ces1 = new 
ChecksumEventStream(createEventStreamFull());
+    ChecksumEventStream ces2 = new 
ChecksumEventStream(createEventStreamPartial());
+    consumeEventStream(ces1, 7);
+    consumeEventStream(ces2, 2);
+
+    long checksum1 = ces1.calculateChecksum();
+    long checksum2 = ces2.calculateChecksum();
+    assertTrue(checksum1 != 0);
+    assertTrue(checksum2 != 0);
+    assertNotEquals(checksum1, checksum2);
+  }
+
+  private ObjectStream<Event> createEventStreamFull() {
+    // He belongs to <START:org> Apache Software Foundation <END> .
+    return new SimpleEventStreamBuilder()
+        .add("other/w=he n1w=belongs n2w=to po=other pow=other,He 
powf=other,ic ppo=other")
+        .add("other/w=belongs p1w=he n1w=to n2w=apache po=other 
pow=other,belongs powf=other,lc ppo=other")
+        .add("other/w=to p1w=belongs p2w=he n1w=apache n2w=software po=other 
pow=other,to" +
+                " powf=other,lc ppo=other")
+        .add("org-start/w=apache p1w=to p2w=belongs n1w=software 
n2w=foundation po=other pow=other,Apache" +
+                " powf=other,ic ppo=other")
+        .add("org-cont/w=software p1w=apache p2w=to n1w=foundation n2w=. 
po=org-start" +
+                " pow=org-start,Software powf=org-start,ic ppo=other")
+        .add("org-cont/w=foundation p1w=software p2w=apache n1w=. po=org-cont 
pow=org-cont,Foundation" +
+                " powf=org-cont,ic ppo=org-start")
+        .add("other/w=. p1w=foundation p2w=software po=org-cont pow=org-cont,. 
powf=org-cont,other" +
+                " ppo=org-cont")
+        .build();
+  }
+
+  private ObjectStream<Event> createEventStreamPartial() {
+    // He .
+    return new SimpleEventStreamBuilder()
+        .add("other/w=he n1w=belongs n2w=to po=other pow=other,He 
powf=other,ic ppo=other")
+        .add("other/w=. p1w=foundation p2w=software po=org-cont pow=org-cont,. 
powf=org-cont,other" +
+                " ppo=org-cont")
+        .build();
+  }
+
+  private void consumeEventStream(ObjectStream<Event> eventStream, int 
eventCount) throws IOException {
+    for (int i = 0; i < eventCount; i++) {
+      assertNotNull(eventStream.read());
+    }
+  }
+}

Reply via email to