This is an automated email from the ASF dual-hosted git repository.
mawiesne pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/opennlp.git
The following commit(s) were added to refs/heads/main by this push:
new 64844b0d OPENNLP-1556 Improve speed of checksum computation in
TwoPassDataIndexer (#600)
64844b0d is described below
commit 64844b0de162122d69b3a7e5987b9b4b68618a62
Author: Martin Wiesner <[email protected]>
AuthorDate: Tue May 7 08:50:25 2024 +0200
OPENNLP-1556 Improve speed of checksum computation in TwoPassDataIndexer
(#600)
- adjusts TwoPassDataIndexer to make use of JDK's built-in
CheckedOutputStream / CheckedInputStream for checksum (CRC32c) computations
- removes untested class HashSumEventStream which is just a wrapper for
calling a slow toString() in Event to get some bytes to use for the computation
of a checksum
- provides a HashSumEventStream replacement: ChecksumEventStream which
makes use of the faster CRC32c checksum computation, avoiding cryptographic
hash functions such as MD5
- adds JUnit tests for ChecksumEventStream
---
.../opennlp/tools/ml/AbstractEventTrainer.java | 6 +-
...umEventStream.java => ChecksumEventStream.java} | 51 ++++++------
.../opennlp/tools/ml/model/TwoPassDataIndexer.java | 65 +++++++--------
.../tools/ml/model/ChecksumEventStreamTest.java | 93 ++++++++++++++++++++++
4 files changed, 155 insertions(+), 60 deletions(-)
diff --git
a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java
b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java
index d546739a..9ea5ddce 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java
@@ -20,10 +20,10 @@ package opennlp.tools.ml;
import java.io.IOException;
import opennlp.tools.ml.model.AbstractDataIndexer;
+import opennlp.tools.ml.model.ChecksumEventStream;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.DataIndexerFactory;
import opennlp.tools.ml.model.Event;
-import opennlp.tools.ml.model.HashSumEventStream;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.util.InsufficientTrainingDataException;
import opennlp.tools.util.ObjectStream;
@@ -85,10 +85,10 @@ public abstract class AbstractEventTrainer extends
AbstractTrainer implements Ev
public final MaxentModel train(ObjectStream<Event> events) throws
IOException {
validate();
- HashSumEventStream hses = new HashSumEventStream(events);
+ ChecksumEventStream hses = new ChecksumEventStream(events);
DataIndexer indexer = getDataIndexer(hses);
- addToReport("Training-Eventhash", hses.calculateHashSum().toString(16));
+ addToReport("Training-Eventhash",
String.valueOf(hses.calculateChecksum()));
return train(indexer);
}
}
diff --git
a/opennlp-tools/src/main/java/opennlp/tools/ml/model/HashSumEventStream.java
b/opennlp-tools/src/main/java/opennlp/tools/ml/model/ChecksumEventStream.java
similarity index 52%
rename from
opennlp-tools/src/main/java/opennlp/tools/ml/model/HashSumEventStream.java
rename to
opennlp-tools/src/main/java/opennlp/tools/ml/model/ChecksumEventStream.java
index 6fafb243..52af8edc 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/model/HashSumEventStream.java
+++
b/opennlp-tools/src/main/java/opennlp/tools/ml/model/ChecksumEventStream.java
@@ -18,57 +18,56 @@
package opennlp.tools.ml.model;
import java.io.IOException;
-import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
-import java.security.MessageDigest;
-import java.security.NoSuchAlgorithmException;
+import java.util.zip.CRC32C;
+import java.util.zip.Checksum;
import opennlp.tools.util.AbstractObjectStream;
import opennlp.tools.util.ObjectStream;
/**
- * A hash sum based {@link AbstractObjectStream} implementation.
+ * A {@link Checksum}-based {@link AbstractObjectStream event stream}
implementation.
+ * Computes the checksum while consuming the event stream.
+ * By default, this implementation will use {@link CRC32C} for checksum
calculations
+ * as it can make use of CPU-specific acceleration instructions at runtime.
*
* @see Event
- * @see MessageDigest
+ * @see Checksum
* @see AbstractObjectStream
*/
-public class HashSumEventStream extends AbstractObjectStream<Event> {
+public class ChecksumEventStream extends AbstractObjectStream<Event> {
- private final MessageDigest digest;
+ private final Checksum checksum;
- public HashSumEventStream(ObjectStream<Event> eventStream) {
- super(eventStream);
- try {
- digest = MessageDigest.getInstance("MD5");
- } catch (NoSuchAlgorithmException e) {
- // should never happen: do all java runtimes have md5 ?!
- throw new IllegalStateException(e);
- }
+ /**
+ * Initializes a {@link ChecksumEventStream}.
+ *
+ * @param eventStream The {@link ObjectStream} that provides the {@link
Event} samples.
+ */
+ public ChecksumEventStream(ObjectStream<Event> eventStream) {
+ super(eventStream);
+ // CRC32C supports CPU-specific acceleration instructions
+ checksum = new CRC32C();
}
@Override
public Event read() throws IOException {
Event event = super.read();
-
if (event != null) {
- digest.update(event.toString().getBytes(StandardCharsets.UTF_8));
+ checksum.update(event.toString().getBytes(StandardCharsets.UTF_8));
}
-
return event;
}
/**
- * Calculates the hash sum of the stream and wraps it into a {@link
BigInteger}.
- * Note: The method must be called after the stream is completely consumed.
+ * Calculates and returns the (current) checksum.
+ * <p>
+ * Note: This should be called once the underlying stream has been (fully)
consumed.
*
- * @return The calculated hash sum as {@link BigInteger}.
- * @throws IllegalStateException Thrown if the stream is not consumed
completely,
- * completely means that hasNext() returns {@code false}.
+ * @return The calculated checksum as {@code long}.
*/
- public BigInteger calculateHashSum() {
- return new BigInteger(1, digest.digest());
+ public long calculateChecksum() {
+ return checksum.getValue();
}
-
}
diff --git
a/opennlp-tools/src/main/java/opennlp/tools/ml/model/TwoPassDataIndexer.java
b/opennlp-tools/src/main/java/opennlp/tools/ml/model/TwoPassDataIndexer.java
index dd67dc21..0e49a4bd 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/model/TwoPassDataIndexer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/model/TwoPassDataIndexer.java
@@ -17,7 +17,6 @@
package opennlp.tools.ml.model;
-
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
@@ -26,11 +25,13 @@ import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
-import java.math.BigInteger;
import java.nio.file.Files;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.zip.CRC32C;
+import java.util.zip.CheckedInputStream;
+import java.util.zip.CheckedOutputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -63,48 +64,50 @@ public class TwoPassDataIndexer extends AbstractDataIndexer
{
int cutoff = trainingParameters.getIntParameter(CUTOFF_PARAM,
CUTOFF_DEFAULT);
boolean sort = trainingParameters.getBooleanParameter(SORT_PARAM,
SORT_DEFAULT);
- long start = System.currentTimeMillis();
-
logger.info("Indexing events with TwoPass using cutoff of {}", cutoff);
-
logger.info("Computing event counts...");
+ long start = System.currentTimeMillis();
Map<String,Integer> predicateIndex = new HashMap<>();
-
File tmp = Files.createTempFile("events", null).toFile();
tmp.deleteOnExit();
int numEvents;
- BigInteger writeHash;
- HashSumEventStream writeEventStream = new HashSumEventStream(eventStream);
// do not close.
- try (DataOutputStream dos = new DataOutputStream(new
BufferedOutputStream(new FileOutputStream(tmp)))) {
- numEvents = computeEventCounts(writeEventStream, dos, predicateIndex,
cutoff);
- }
- writeHash = writeEventStream.calculateHashSum();
+ long writeChecksum;
- logger.info("done. {} events", numEvents);
- logger.info("Indexing...");
+ try (BufferedOutputStream out = new BufferedOutputStream(new
FileOutputStream(tmp));
+ CheckedOutputStream writeStream = new CheckedOutputStream(out, new
CRC32C());
+ DataOutputStream dos = new DataOutputStream(writeStream)) {
+
+ numEvents = computeEventCounts(eventStream, dos, predicateIndex, cutoff);
+ writeChecksum = writeStream.getChecksum().getValue();
+ logger.info("done. {} events", numEvents);
+ }
List<ComparableEvent> eventsToCompare;
- BigInteger readHash = null;
- try (HashSumEventStream readStream = new HashSumEventStream(new
EventStream(tmp))) {
- eventsToCompare = index(readStream, predicateIndex);
- readHash = readStream.calculateHashSum();
+ long readChecksum;
+ try (BufferedInputStream in = new BufferedInputStream(new
FileInputStream(tmp));
+ CheckedInputStream readStream = new CheckedInputStream(in, new
CRC32C());
+ EventStream readEventsStream = new EventStream(new
DataInputStream(readStream))) {
+ logger.info("Indexing...");
+ eventsToCompare = index(readEventsStream, predicateIndex);
+ readChecksum = readStream.getChecksum().getValue();
}
tmp.delete();
- if (readHash.compareTo(writeHash) != 0)
- throw new IOException("Event hash for writing and reading events did not
match.");
+ if (readChecksum != writeChecksum) {
+ throw new IOException("Checksum for writing and reading events did not
match.");
+ } else {
+ logger.info("done.");
- logger.info("done.");
-
- if (sort) {
- logger.info("Sorting and merging events... ");
- }
- else {
- logger.info("Collecting events... ");
+ if (sort) {
+ logger.info("Sorting and merging events... ");
+ }
+ else {
+ logger.info("Collecting events... ");
+ }
+ sortAndMerge(eventsToCompare,sort);
+ logger.info(String.format("Done indexing in %.2f s.",
(System.currentTimeMillis() - start) / 1000d));
}
- sortAndMerge(eventsToCompare,sort);
- logger.info(String.format("Done indexing in %.2f s.",
(System.currentTimeMillis() - start) / 1000d));
}
/**
@@ -170,8 +173,8 @@ public class TwoPassDataIndexer extends AbstractDataIndexer
{
private final DataInputStream inputStream;
- public EventStream(File file) throws IOException {
- inputStream = new DataInputStream(new BufferedInputStream(new
FileInputStream(file)));
+ public EventStream(DataInputStream dataInputStream) {
+ this.inputStream = dataInputStream;
}
@Override
diff --git
a/opennlp-tools/src/test/java/opennlp/tools/ml/model/ChecksumEventStreamTest.java
b/opennlp-tools/src/test/java/opennlp/tools/ml/model/ChecksumEventStreamTest.java
new file mode 100644
index 00000000..95d38b29
--- /dev/null
+++
b/opennlp-tools/src/test/java/opennlp/tools/ml/model/ChecksumEventStreamTest.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.ml.model;
+
+import java.io.IOException;
+
+import org.junit.jupiter.api.Test;
+
+import opennlp.tools.util.ObjectStream;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class ChecksumEventStreamTest {
+
+ @Test
+ void testCalculateChecksumEquality() throws IOException {
+ ChecksumEventStream ces1 = new
ChecksumEventStream(createEventStreamFull());
+ ChecksumEventStream ces2 = new
ChecksumEventStream(createEventStreamFull());
+ consumeEventStream(ces1, 7);
+ consumeEventStream(ces2, 7);
+
+ long checksum1 = ces1.calculateChecksum();
+ long checksum2 = ces2.calculateChecksum();
+ assertTrue(checksum1 != 0);
+ assertTrue(checksum2 != 0);
+ assertEquals(checksum1, checksum2);
+ }
+
+ @Test
+ void testCalculateChecksumMismatch() throws IOException {
+ ChecksumEventStream ces1 = new
ChecksumEventStream(createEventStreamFull());
+ ChecksumEventStream ces2 = new
ChecksumEventStream(createEventStreamPartial());
+ consumeEventStream(ces1, 7);
+ consumeEventStream(ces2, 2);
+
+ long checksum1 = ces1.calculateChecksum();
+ long checksum2 = ces2.calculateChecksum();
+ assertTrue(checksum1 != 0);
+ assertTrue(checksum2 != 0);
+ assertNotEquals(checksum1, checksum2);
+ }
+
+ private ObjectStream<Event> createEventStreamFull() {
+ // He belongs to <START:org> Apache Software Foundation <END> .
+ return new SimpleEventStreamBuilder()
+ .add("other/w=he n1w=belongs n2w=to po=other pow=other,He
powf=other,ic ppo=other")
+ .add("other/w=belongs p1w=he n1w=to n2w=apache po=other
pow=other,belongs powf=other,lc ppo=other")
+ .add("other/w=to p1w=belongs p2w=he n1w=apache n2w=software po=other
pow=other,to" +
+ " powf=other,lc ppo=other")
+ .add("org-start/w=apache p1w=to p2w=belongs n1w=software
n2w=foundation po=other pow=other,Apache" +
+ " powf=other,ic ppo=other")
+ .add("org-cont/w=software p1w=apache p2w=to n1w=foundation n2w=.
po=org-start" +
+ " pow=org-start,Software powf=org-start,ic ppo=other")
+ .add("org-cont/w=foundation p1w=software p2w=apache n1w=. po=org-cont
pow=org-cont,Foundation" +
+ " powf=org-cont,ic ppo=org-start")
+ .add("other/w=. p1w=foundation p2w=software po=org-cont pow=org-cont,.
powf=org-cont,other" +
+ " ppo=org-cont")
+ .build();
+ }
+
+ private ObjectStream<Event> createEventStreamPartial() {
+ // He .
+ return new SimpleEventStreamBuilder()
+ .add("other/w=he n1w=belongs n2w=to po=other pow=other,He
powf=other,ic ppo=other")
+ .add("other/w=. p1w=foundation p2w=software po=org-cont pow=org-cont,.
powf=org-cont,other" +
+ " ppo=org-cont")
+ .build();
+ }
+
+ private void consumeEventStream(ObjectStream<Event> eventStream, int
eventCount) throws IOException {
+ for (int i = 0; i < eventCount; i++) {
+ assertNotNull(eventStream.read());
+ }
+ }
+}