This is an automated email from the ASF dual-hosted git repository. tballison pushed a commit to branch junk-detector-v6 in repository https://gitbox.apache.org/repos/asf/tika.git
commit 3efaa019ffc5e5efda6365913a7862cc3f2e817c Author: tballison <[email protected]> AuthorDate: Thu May 14 11:23:02 2026 -0400 v6 mods --- .../org/apache/tika/ml/junkdetect/F1Tables.java | 107 ++++++++++ .../apache/tika/ml/junkdetect/JunkDetector.java | 137 ++++--------- .../tika/ml/junkdetect/tools/TrainJunkModel.java | 215 +++++---------------- .../org/apache/tika/ml/junkdetect/junkdetect.bin | Bin 465105 -> 578012 bytes .../tika/ml/junkdetect/JunkDetectorSmokeTest.java | 7 - .../tika/ml/junkdetect/JunkDetectorV6Test.java | 31 +-- 6 files changed, 206 insertions(+), 291 deletions(-) diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/F1Tables.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/F1Tables.java new file mode 100644 index 0000000000..1d322dc64e --- /dev/null +++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/F1Tables.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.tika.ml.junkdetect; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** + * Carrier for the global Feature 1 (codepoint-bigram hash + Bloom filter + + * unigram-backoff) tables used by the v6 junk-detector model format. + * + * <p>Instances are created by + * {@link org.apache.tika.ml.junkdetect.tools.TrainJunkModel#trainCodepointHashTables} + * and consumed by {@link JunkDetector#computeF1MeanLogP}. Separating this + * type from {@link JunkDetector} lets the trainer import it directly without + * reaching into the inference class, and keeps the fields package-private so + * they are not part of the public API. + */ +public final class F1Tables { + + final byte[] bigramHash; + final int bigramBuckets; + final float bigramQuantMin; + final float bigramQuantMax; + final byte[] unigramHash; + final int unigramBuckets; + final float unigramQuantMin; + final float unigramQuantMax; + final long[] bloomBits; + final int bloomBitCount; + final int bloomK; + final int fnvSeed; + final float backoffAlpha; + + public F1Tables(byte[] bigramHash, int bigramBuckets, + float bigramQuantMin, float bigramQuantMax, + byte[] unigramHash, int unigramBuckets, + float unigramQuantMin, float unigramQuantMax, + long[] bloomBits, int bloomBitCount, int bloomK, + int fnvSeed, float backoffAlpha) { + this.bigramHash = bigramHash; + this.bigramBuckets = bigramBuckets; + this.bigramQuantMin = bigramQuantMin; + this.bigramQuantMax = bigramQuantMax; + this.unigramHash = unigramHash; + this.unigramBuckets = unigramBuckets; + this.unigramQuantMin = unigramQuantMin; + this.unigramQuantMax = unigramQuantMax; + this.bloomBits = bloomBits; + this.bloomBitCount = bloomBitCount; + this.bloomK = bloomK; + this.fnvSeed = fnvSeed; + this.backoffAlpha = backoffAlpha; + } + + /** + * Serializes the global F1 section to {@code dos} in the v6 model binary + * format. Called by + * {@link org.apache.tika.ml.junkdetect.tools.TrainJunkModel#saveModelV6}. + */ + public void writeTo(DataOutputStream dos) throws IOException { + dos.writeInt(fnvSeed); + dos.writeFloat(backoffAlpha); + dos.writeInt(bigramBuckets); + dos.writeFloat(bigramQuantMin); + dos.writeFloat(bigramQuantMax); + dos.write(bigramHash); + dos.writeInt(unigramBuckets); + dos.writeFloat(unigramQuantMin); + dos.writeFloat(unigramQuantMax); + dos.write(unigramHash); + dos.writeInt(bloomBitCount); + dos.writeByte(bloomK); + ByteBuffer bloomBuf = ByteBuffer.allocate(bloomBitCount / 8) + .order(ByteOrder.BIG_ENDIAN); + for (long w : bloomBits) { + bloomBuf.putLong(w); + } + dos.write(bloomBuf.array()); + } + + /** + * Returns a human-readable summary of the quantization ranges, for + * trainer progress output. + */ + public String statsString() { + return String.format( + " bigram quant range: [%.3f, %.3f]%n unigram quant range: [%.3f, %.3f]%n", + bigramQuantMin, bigramQuantMax, unigramQuantMin, unigramQuantMax); + } +} diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java index 9170415828..6513b0f29d 100644 --- a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java +++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java @@ -95,27 +95,7 @@ public final class JunkDetector implements TextQualityDetector { static final int VERSION = 6; // Feature 1 — global hashed codepoint-bigram + Bloom-gated unigram backoff - /** Quantized 8-bit log-prob per bigram bucket, size = {@link #cpBigramBuckets}. */ - private final byte[] cpBigramHash; - /** Power of 2. Index = FNV-1a(cp_a, cp_b, seed) & (buckets - 1). */ - private final int cpBigramBuckets; - private final float cpBigramQuantMin; - private final float cpBigramQuantMax; - /** Quantized 8-bit log-prob per unigram bucket, size = {@link #cpUnigramBuckets}. */ - private final byte[] cpUnigramHash; - /** Power of 2. Index = FNV-1a(cp, seed) & (buckets - 1). */ - private final int cpUnigramBuckets; - private final float cpUnigramQuantMin; - private final float cpUnigramQuantMax; - /** Bloom filter bit array. Length = bloomBitCount / 64. */ - private final long[] cpBloomBits; - private final int cpBloomBitCount; - /** Number of hash functions for Bloom filter (double-hashing). */ - private final int cpBloomK; - /** Seed for FNV-1a hash; same value at training and inference. */ - private final int cpFnvSeed; - /** Multiplier on unigram-backoff log-prob when bigram pair is unseen. */ - private final float cpBackoffAlpha; + private final F1Tables f1; /** Per-script F1 calibration on the codepoint-hash mean log-prob. */ private final Map<String, float[]> calibrations; // script → float[2] {mu, sigma} @@ -157,61 +137,7 @@ public final class JunkDetector implements TextQualityDetector { this.scriptTransitionCalibration = scriptTransitionCalibration; this.scriptBucketIndex = Collections.unmodifiableMap(scriptBucketIndex); this.numScriptBuckets = numScriptBuckets; - this.cpBigramHash = f1.bigramHash; - this.cpBigramBuckets = f1.bigramBuckets; - this.cpBigramQuantMin = f1.bigramQuantMin; - this.cpBigramQuantMax = f1.bigramQuantMax; - this.cpUnigramHash = f1.unigramHash; - this.cpUnigramBuckets = f1.unigramBuckets; - this.cpUnigramQuantMin = f1.unigramQuantMin; - this.cpUnigramQuantMax = f1.unigramQuantMax; - this.cpBloomBits = f1.bloomBits; - this.cpBloomBitCount = f1.bloomBitCount; - this.cpBloomK = f1.bloomK; - this.cpFnvSeed = f1.fnvSeed; - this.cpBackoffAlpha = f1.backoffAlpha; - } - - /** - * Carrier for Feature 1 data — global codepoint-bigram hash table, Bloom - * filter, and unigram-backoff hash table. Loaders and tests build - * instances of this and hand them to the {@link JunkDetector} constructor. - */ - static final class F1Tables { - final byte[] bigramHash; - final int bigramBuckets; - final float bigramQuantMin; - final float bigramQuantMax; - final byte[] unigramHash; - final int unigramBuckets; - final float unigramQuantMin; - final float unigramQuantMax; - final long[] bloomBits; - final int bloomBitCount; - final int bloomK; - final int fnvSeed; - final float backoffAlpha; - - F1Tables(byte[] bigramHash, int bigramBuckets, - float bigramQuantMin, float bigramQuantMax, - byte[] unigramHash, int unigramBuckets, - float unigramQuantMin, float unigramQuantMax, - long[] bloomBits, int bloomBitCount, int bloomK, - int fnvSeed, float backoffAlpha) { - this.bigramHash = bigramHash; - this.bigramBuckets = bigramBuckets; - this.bigramQuantMin = bigramQuantMin; - this.bigramQuantMax = bigramQuantMax; - this.unigramHash = unigramHash; - this.unigramBuckets = unigramBuckets; - this.unigramQuantMin = unigramQuantMin; - this.unigramQuantMax = unigramQuantMax; - this.bloomBits = bloomBits; - this.bloomBitCount = bloomBitCount; - this.bloomK = bloomK; - this.fnvSeed = fnvSeed; - this.backoffAlpha = backoffAlpha; - } + this.f1 = f1; } // ----------------------------------------------------------------------- @@ -797,18 +723,27 @@ public final class JunkDetector implements TextQualityDetector { // ----------------------------------------------------------------------- /** - * Mean log-prob over the codepoint pairs in {@code text}. + * Mean log-prob over the codepoint pairs in {@code text} using the given + * F1 tables. * * <p>For each adjacent codepoint pair {@code (a, b)}: if the Bloom filter * reports the pair was seen at training time, return the dequantized * log-prob from the bigram hash bucket. Else fall back to - * {@code alpha * (log P(a) + log P(b))} via the unigram hash table — - * captures "rare-codepoint-pair from mojibake" failure modes that bigram - * collisions can't. + * {@code alpha * (log P(a) + log P(b))} via the unigram hash table. + * + * <p>This is the single authoritative implementation of the F1 scoring + * math, shared between inference ({@link #score}) and training + * ({@link org.apache.tika.ml.junkdetect.tools.TrainJunkModel#calibrateF1PerScript}). + * Keeping one implementation eliminates the risk of train/infer drift in + * the F1 feature — the same risk that motivated making z2/z3/z4 public + * static methods. + * + * @return mean log-prob, or {@link Double#NaN} if {@code text} has fewer + * than two codepoints */ - private float computeCodepointHashMeanLogP(String text) { + public static double computeF1MeanLogP(String text, F1Tables f1) { if (text == null || text.length() < 2) { - return Float.NaN; + return Double.NaN; } double sum = 0; int n = 0; @@ -817,37 +752,43 @@ public final class JunkDetector implements TextQualityDetector { int cp = text.codePointAt(i); i += Character.charCount(cp); if (prevCp >= 0) { - sum += scorePair(prevCp, cp); + sum += scorePairF1(prevCp, cp, f1); n++; } prevCp = cp; } - return n == 0 ? Float.NaN : (float) (sum / n); + return n == 0 ? Double.NaN : sum / n; + } + + /** Thin instance wrapper around the shared static implementation. */ + private float computeCodepointHashMeanLogP(String text) { + double v = computeF1MeanLogP(text, f1); + return Double.isNaN(v) ? Float.NaN : (float) v; } - private double scorePair(int cpA, int cpB) { - if (bloomContains(cpA, cpB)) { - int bucket = (int) (fnv1aBigram(cpA, cpB, cpFnvSeed) & (cpBigramBuckets - 1)); - return dequantize(cpBigramHash[bucket], cpBigramQuantMin, cpBigramQuantMax); + private static double scorePairF1(int cpA, int cpB, F1Tables f1) { + if (bloomContainsF1(cpA, cpB, f1)) { + int bucket = (int) (fnv1aBigram(cpA, cpB, f1.fnvSeed) & (f1.bigramBuckets - 1)); + return dequantize(f1.bigramHash[bucket], f1.bigramQuantMin, f1.bigramQuantMax); } // Unigram backoff. α=1.0 gives plain independence assumption; // α<1 discounts (stupid-backoff style) — prototype showed α=1.0 is // correct for junk discrimination. double ua = dequantize( - cpUnigramHash[(int) (fnv1aUnigram(cpA, cpFnvSeed) & (cpUnigramBuckets - 1))], - cpUnigramQuantMin, cpUnigramQuantMax); + f1.unigramHash[(int) (fnv1aUnigram(cpA, f1.fnvSeed) & (f1.unigramBuckets - 1))], + f1.unigramQuantMin, f1.unigramQuantMax); double ub = dequantize( - cpUnigramHash[(int) (fnv1aUnigram(cpB, cpFnvSeed) & (cpUnigramBuckets - 1))], - cpUnigramQuantMin, cpUnigramQuantMax); - return cpBackoffAlpha * (ua + ub); + f1.unigramHash[(int) (fnv1aUnigram(cpB, f1.fnvSeed) & (f1.unigramBuckets - 1))], + f1.unigramQuantMin, f1.unigramQuantMax); + return f1.backoffAlpha * (ua + ub); } - private boolean bloomContains(int cpA, int cpB) { - long h1 = fnv1aBigram(cpA, cpB, cpFnvSeed); + private static boolean bloomContainsF1(int cpA, int cpB, F1Tables f1) { + long h1 = fnv1aBigram(cpA, cpB, f1.fnvSeed); long h2 = secondaryHash(cpA, cpB); - for (int i = 0; i < cpBloomK; i++) { - long pos = ((h1 + (long) i * h2) & 0x7FFFFFFFFFFFFFFFL) % cpBloomBitCount; - if ((cpBloomBits[(int) (pos >>> 6)] & (1L << (pos & 63))) == 0) { + for (int i = 0; i < f1.bloomK; i++) { + long pos = ((h1 + (long) i * h2) & 0x7FFFFFFFFFFFFFFFL) % f1.bloomBitCount; + if ((f1.bloomBits[(int) (pos >>> 6)] & (1L << (pos & 63))) == 0) { return false; } } diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java index 4c85bac7c5..8855ac9338 100644 --- a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java +++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java @@ -41,6 +41,9 @@ import java.util.TreeMap; import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; +import org.apache.tika.ml.junkdetect.F1Tables; +import org.apache.tika.ml.junkdetect.JunkDetector; + /** * Trains the junk detector model from per-script corpus files produced by * {@link BuildJunkTrainingData}. @@ -131,12 +134,12 @@ public class TrainJunkModel { // v6 model constants (codepoint-bigram-hash + Bloom + unigram-backoff) // ----------------------------------------------------------------------- - /** Bigram hash bucket count (power of 2). Locked at 4096 from prototype sweep. */ + /** Bigram hash bucket count (power of 2). */ static final int V6_BIGRAM_BUCKETS = 4096; /** Unigram hash bucket count (power of 2). */ static final int V6_UNIGRAM_BUCKETS = 8192; /** Default Bloom filter capacity in bits (must be a multiple of 64). */ - static final int V6_BLOOM_BITS_DEFAULT = 4 * 1024 * 1024; + static final int V6_BLOOM_BITS_DEFAULT = 16 * 1024 * 1024; /** Number of hash functions for the Bloom filter (double-hashing). */ static final int V6_BLOOM_K = 7; /** FNV-1a seed; stored in the model so training and inference stay synced. */ @@ -248,9 +251,8 @@ public class TrainJunkModel { TreeMap<String, float[]> blockCalibrations = new TreeMap<>(); TreeMap<String, float[]> controlCalibrations = new TreeMap<>(); TreeMap<String, float[]> classifierWeights = new TreeMap<>(); - TreeMap<String, Path> devFilePaths = new TreeMap<>(); + TreeMap<String, Path> trainFilePaths = new TreeMap<>(); List<Path> allTrainFiles = new ArrayList<>(); - List<Path> allDevFiles = new ArrayList<>(); List<Path> trainFiles; try (var stream = Files.list(dataDir)) { @@ -271,12 +273,9 @@ public class TrainJunkModel { System.out.println("\n--- Phase 1: global codepoint-hash tables + Bloom ---"); t0 = System.currentTimeMillis(); System.out.print(" Training global codepoint-bigram + unigram + Bloom... "); - V6F1Tables f1Tables = trainCodepointHashTables(trainFiles, bloomBits); + F1Tables f1Tables = trainCodepointHashTables(trainFiles, bloomBits); System.out.printf("done (%dms)%n", System.currentTimeMillis() - t0); - System.out.printf(" bigram quant range: [%.3f, %.3f]%n", - f1Tables.bigramQuantMin, f1Tables.bigramQuantMax); - System.out.printf(" unigram quant range: [%.3f, %.3f]%n", - f1Tables.unigramQuantMin, f1Tables.unigramQuantMax); + System.out.print(f1Tables.statsString()); // ----------------------------------------------------------------------- // Phase 1.5 — per-script F1 calibration + F2 block tables + F3 control cal @@ -286,45 +285,34 @@ public class TrainJunkModel { String filename = trainFile.getFileName().toString(); String script = filename.substring(0, filename.length() - ".train.gz".length()) .toUpperCase(); - Path devFile = trainFile.getParent().resolve( - filename.replace(".train.gz", ".dev.gz")); System.out.printf("%n [%s]%n", script); allTrainFiles.add(trainFile); t0 = System.currentTimeMillis(); - System.out.print(" Training named-block table... "); + System.out.print(" Training named-block table... "); float[] blockTable = trainBlockTable(trainFile); System.out.printf("done (%dms)%n", System.currentTimeMillis() - t0); - float[] f1Cal = new float[]{0f, 1f}; - float[] blockCal = new float[]{0f, 1f}; - float[] controlCal = new float[]{0f, 1f}; - - if (Files.exists(devFile)) { - t0 = System.currentTimeMillis(); - System.out.print(" Calibrating F1 (cp-hash) on dev... "); - f1Cal = calibrateF1PerScript(devFile, f1Tables); - System.out.printf("done — mu=%.4f sigma=%.4f (%dms)%n", - f1Cal[0], f1Cal[1], System.currentTimeMillis() - t0); - - t0 = System.currentTimeMillis(); - System.out.print(" Calibrating named blocks on dev..."); - blockCal = computeBlockCalibration(devFile, blockTable); - System.out.printf("done — mu=%.4f sigma=%.4f (%dms)%n", - blockCal[0], blockCal[1], System.currentTimeMillis() - t0); - - t0 = System.currentTimeMillis(); - System.out.print(" Calibrating control bytes on dev.."); - controlCal = computeControlByteCalibration(devFile); - System.out.printf("done — mu=%.6f sigma=%.6f (%dms)%n", - controlCal[0], controlCal[1], System.currentTimeMillis() - t0); - - devFilePaths.put(script, devFile); - allDevFiles.add(devFile); - } else { - System.out.println(" WARNING: no dev file found, using uncalibrated defaults"); - } + t0 = System.currentTimeMillis(); + System.out.print(" Calibrating F1 (cp-hash) on train.. "); + float[] f1Cal = calibrateF1PerScript(trainFile, f1Tables); + System.out.printf("done — mu=%.4f sigma=%.4f (%dms)%n", + f1Cal[0], f1Cal[1], System.currentTimeMillis() - t0); + + t0 = System.currentTimeMillis(); + System.out.print(" Calibrating named blocks on train..."); + float[] blockCal = computeBlockCalibration(trainFile, blockTable); + System.out.printf("done — mu=%.4f sigma=%.4f (%dms)%n", + blockCal[0], blockCal[1], System.currentTimeMillis() - t0); + + t0 = System.currentTimeMillis(); + System.out.print(" Calibrating control bytes on train.."); + float[] controlCal = computeControlByteCalibration(trainFile); + System.out.printf("done — mu=%.6f sigma=%.6f (%dms)%n", + controlCal[0], controlCal[1], System.currentTimeMillis() - t0); + + trainFilePaths.put(script, trainFile); f1Calibrations.put(script, f1Cal); blockTables.put(script, blockTable); @@ -353,7 +341,7 @@ public class TrainJunkModel { t0 = System.currentTimeMillis(); System.out.print(" Calibrating script transitions... "); - float[] scriptTransCal = calibrateScriptTransitions(allDevFiles, scriptTransTable, + float[] scriptTransCal = calibrateScriptTransitions(allTrainFiles, scriptTransTable, scriptBucketMap, numScriptBuckets); System.out.printf("done — mu=%.4f sigma=%.4f (%dms)%n", scriptTransCal[0], scriptTransCal[1], System.currentTimeMillis() - t0); @@ -377,14 +365,14 @@ public class TrainJunkModel { // ----------------------------------------------------------------------- System.out.println("\n--- Phase 3: per-script linear classifiers (z1,z2,z3,z4) ---"); for (String script : f1Calibrations.keySet()) { - Path devFile = devFilePaths.get(script); - if (devFile == null) { - System.out.printf(" [%s] WARNING: no dev file, keeping equal-weight defaults%n", script); + Path trainFile = trainFilePaths.get(script); + if (trainFile == null) { + System.out.printf(" [%s] WARNING: no train file, keeping equal-weight defaults%n", script); continue; } t0 = System.currentTimeMillis(); System.out.printf(" [%s] training classifier... ", script); - float[] weights = trainClassifierV6(devFile, + float[] weights = trainClassifierV6(trainFile, f1Tables, f1Calibrations.get(script), blockTables.get(script), blockCalibrations.get(script), controlCalibrations.get(script), @@ -726,50 +714,6 @@ public class TrainJunkModel { // v6 model save + carrier // ----------------------------------------------------------------------- - /** - * Carrier for v6 Feature 1 data — global hashed codepoint-bigram + Bloom- - * gated unigram-backoff replacing v5's per-script byte-bigram tables. - */ - public static final class V6F1Tables { - /** 8-bit quantized log-prob per bigram bucket. */ - public final byte[] bigramHash; - public final int bigramBuckets; - public final float bigramQuantMin; - public final float bigramQuantMax; - /** 8-bit quantized log-prob per unigram bucket. */ - public final byte[] unigramHash; - public final int unigramBuckets; - public final float unigramQuantMin; - public final float unigramQuantMax; - /** Bloom bit array; length = bloomBitCount / 64. */ - public final long[] bloomBits; - public final int bloomBitCount; - public final int bloomK; - public final int fnvSeed; - public final float backoffAlpha; - - public V6F1Tables(byte[] bigramHash, int bigramBuckets, - float bigramQuantMin, float bigramQuantMax, - byte[] unigramHash, int unigramBuckets, - float unigramQuantMin, float unigramQuantMax, - long[] bloomBits, int bloomBitCount, int bloomK, - int fnvSeed, float backoffAlpha) { - this.bigramHash = bigramHash; - this.bigramBuckets = bigramBuckets; - this.bigramQuantMin = bigramQuantMin; - this.bigramQuantMax = bigramQuantMax; - this.unigramHash = unigramHash; - this.unigramBuckets = unigramBuckets; - this.unigramQuantMin = unigramQuantMin; - this.unigramQuantMax = unigramQuantMax; - this.bloomBits = bloomBits; - this.bloomBitCount = bloomBitCount; - this.bloomK = bloomK; - this.fnvSeed = fnvSeed; - this.backoffAlpha = backoffAlpha; - } - } - // ----------------------------------------------------------------------- // v6 Phase 1: global codepoint-hash training // ----------------------------------------------------------------------- @@ -780,15 +724,16 @@ public class TrainJunkModel { * unigrams into a {@link #V6_UNIGRAM_BUCKETS}-sized table, and inserts * every observed pair into a Bloom filter sized at {@code bloomBits}. * Quantizes both log-prob tables to 8-bit unsigned. Returns a - * {@link V6F1Tables} ready to hand to {@link #saveModelV6}. + * {@link org.apache.tika.ml.junkdetect.F1Tables} ready to + * hand to {@link #saveModelV6}. * - * <p>Hash function and Bloom-filter scheme are identical to those - * used at inference time in {@link JunkDetector#computeCodepointHashMeanLogP} + * <p>Hash function and Bloom-filter scheme are identical to those used at + * inference time in {@link org.apache.tika.ml.junkdetect.JunkDetector#computeF1MeanLogP} * — same seed, same FNV-1a polynomial, same double-hashing layout — * so the trained Bloom filter accurately reflects which pairs the * scorer will treat as "seen". */ - public static V6F1Tables trainCodepointHashTables(List<Path> trainFiles, int bloomBits) + public static F1Tables trainCodepointHashTables(List<Path> trainFiles, int bloomBits) throws IOException { long[] bigramCounts = new long[V6_BIGRAM_BUCKETS]; long[] unigramCounts = new long[V6_UNIGRAM_BUCKETS]; @@ -839,7 +784,8 @@ public class TrainJunkModel { QuantizedFloats qBigram = quantizeFloats(bigramLogP); QuantizedFloats qUnigram = quantizeFloats(unigramLogP); - return new V6F1Tables(qBigram.bytes, V6_BIGRAM_BUCKETS, + return new F1Tables( + qBigram.bytes, V6_BIGRAM_BUCKETS, qBigram.min, qBigram.max, qUnigram.bytes, V6_UNIGRAM_BUCKETS, qUnigram.min, qUnigram.max, @@ -850,15 +796,16 @@ public class TrainJunkModel { /** * Computes per-script F1 calibration ({mu, sigma}) by scoring each * window in the dev file against the trained codepoint-hash tables - * and collecting the mean log-prob distribution. Same math as - * inference-time z1 computation, sans the per-script z-normalization - * (that's what this is producing). + * and collecting the mean log-prob distribution. Delegates to + * {@link org.apache.tika.ml.junkdetect.JunkDetector#computeF1MeanLogP} + * — the single authoritative F1 implementation shared between training + * and inference. */ - public static float[] calibrateF1PerScript(Path devGz, V6F1Tables f1) throws IOException { + public static float[] calibrateF1PerScript(Path devGz, F1Tables f1) throws IOException { List<String> windows = sampleSubstrings(devGz, CALIB_SAMPLES, CALIB_LENGTHS, 42); List<Double> scores = new ArrayList<>(windows.size()); for (String window : windows) { - double score = codepointHashMeanLogP(window, f1); + double score = JunkDetector.computeF1MeanLogP(window, f1); if (!Double.isNaN(score)) { scores.add(score); } @@ -867,47 +814,6 @@ public class TrainJunkModel { return muSigma(scores); } - /** - * Mirrors {@link JunkDetector#computeCodepointHashMeanLogP} but for - * use inside the trainer (where we have a {@link V6F1Tables} rather - * than a loaded {@link JunkDetector} instance). Same math by - * construction — if these two ever diverge, the v6 retrain produces - * a model whose inference scores don't match its training-time - * calibration. - */ - public static double codepointHashMeanLogP(String text, V6F1Tables f1) { - if (text == null || text.length() < 2) { - return Double.NaN; - } - double sum = 0; - int n = 0; - int prevCp = -1; - for (int i = 0; i < text.length(); ) { - int cp = text.codePointAt(i); - i += Character.charCount(cp); - if (prevCp >= 0) { - sum += scorePair(prevCp, cp, f1); - n++; - } - prevCp = cp; - } - return n == 0 ? Double.NaN : sum / n; - } - - private static double scorePair(int cpA, int cpB, V6F1Tables f1) { - if (bloomContainsV6(f1.bloomBits, f1.bloomBitCount, f1.bloomK, cpA, cpB, f1.fnvSeed)) { - int bucket = (int) (fnv1aBigramV6(cpA, cpB, f1.fnvSeed) & (f1.bigramBuckets - 1)); - return dequantize(f1.bigramHash[bucket], f1.bigramQuantMin, f1.bigramQuantMax); - } - double ua = dequantize( - f1.unigramHash[(int) (fnv1aUnigramV6(cpA, f1.fnvSeed) & (f1.unigramBuckets - 1))], - f1.unigramQuantMin, f1.unigramQuantMax); - double ub = dequantize( - f1.unigramHash[(int) (fnv1aUnigramV6(cpB, f1.fnvSeed) & (f1.unigramBuckets - 1))], - f1.unigramQuantMin, f1.unigramQuantMax); - return f1.backoffAlpha * (ua + ub); - } - /** Bloom membership check (matches JunkDetector.bloomContains semantics). */ public static boolean bloomContainsV6(long[] bloomBits, int bitCount, int k, int cpA, int cpB, int seed) { @@ -939,7 +845,7 @@ public class TrainJunkModel { * @return float[4] = {z1_cpHash, z2_block, z3_control, z4_scriptTrans} */ static float[] extractFeaturesV6(String window, - V6F1Tables f1, float[] f1Cal, + F1Tables f1, float[] f1Cal, float[] blockTable, float[] blockCal, float[] controlCal, float[] scriptTransTable, float[] scriptTransCal, @@ -949,7 +855,7 @@ public class TrainJunkModel { // z1: codepoint-hash mean log-prob, per-script-calibrated float z1 = 0f; - double rawF1 = codepointHashMeanLogP(window, f1); + double rawF1 = JunkDetector.computeF1MeanLogP(window, f1); if (!Double.isNaN(rawF1) && f1Cal != null && f1Cal[1] > 0) { z1 = ((float) rawF1 - f1Cal[0]) / f1Cal[1]; } @@ -972,7 +878,7 @@ public class TrainJunkModel { * on short windows) but uses the codepoint-hash feature extractor. */ static float[] trainClassifierV6(Path devGz, - V6F1Tables f1, float[] f1Cal, + F1Tables f1, float[] f1Cal, float[] blockTable, float[] blockCal, float[] controlCal, float[] scriptTransTable, float[] scriptTransCal, @@ -1076,7 +982,7 @@ public class TrainJunkModel { List<String> scriptBuckets, float[] scriptTransTable, float[] scriptTransCal, - V6F1Tables v6, + F1Tables v6, Path output) throws IOException { try (DataOutputStream dos = new DataOutputStream( new GZIPOutputStream(Files.newOutputStream(output)))) { @@ -1102,25 +1008,8 @@ public class TrainJunkModel { dos.writeFloat(scriptTransCal[0]); dos.writeFloat(scriptTransCal[1]); - // Global F1 section (v6+) — new - dos.writeInt(v6.fnvSeed); - dos.writeFloat(v6.backoffAlpha); - dos.writeInt(v6.bigramBuckets); - dos.writeFloat(v6.bigramQuantMin); - dos.writeFloat(v6.bigramQuantMax); - dos.write(v6.bigramHash); - dos.writeInt(v6.unigramBuckets); - dos.writeFloat(v6.unigramQuantMin); - dos.writeFloat(v6.unigramQuantMax); - dos.write(v6.unigramHash); - dos.writeInt(v6.bloomBitCount); - dos.writeByte(v6.bloomK); - ByteBuffer bloomBuf = ByteBuffer.allocate(v6.bloomBitCount / 8) - .order(ByteOrder.BIG_ENDIAN); - for (long w : v6.bloomBits) { - bloomBuf.putLong(w); - } - dos.write(bloomBuf.array()); + // Global F1 section (v6+) + v6.writeTo(dos); // Per-script section — v6 drops the per-script byte-bigram table. // mu1/sigma1 remain (calibrated on the codepoint-hash score). diff --git a/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin b/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin index feb9da112e..3d76d288fb 100644 Binary files a/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin and b/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin differ diff --git a/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorSmokeTest.java b/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorSmokeTest.java index 28c753f6d0..a277d2d79f 100644 --- a/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorSmokeTest.java +++ b/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorSmokeTest.java @@ -23,7 +23,6 @@ import java.nio.charset.StandardCharsets; import java.util.Random; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.apache.tika.quality.TextQualityComparison; @@ -33,13 +32,7 @@ import org.apache.tika.quality.TextQualityScore; * Smoke tests verifying the bundled model meets minimum quality thresholds. * Failures indicate the model needs more data or feature extraction is wrong. * - * <p><b>Disabled on this branch.</b> The bundled {@code junkdetect.bin} is - * still the previous format and is rejected by the strict - * {@link JunkDetector#load} loader. Re-enable these tests once the retrain - * lands a new bundled model in the current file format. See the planning - * doc at {@code 20260512-junkdetector-codepoint-hash-plan.md}. */ -@Disabled("Bundled junkdetect.bin is the previous format; re-enable after retrain") public class JunkDetectorSmokeTest { private static JunkDetector detector; diff --git a/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorV6Test.java b/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorV6Test.java index cff7589a1d..1b35554e40 100644 --- a/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorV6Test.java +++ b/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorV6Test.java @@ -86,7 +86,7 @@ public class JunkDetectorV6Test { long[] bloom = new long[(bloomBits + 63) >> 6]; TrainJunkModel.bloomAddV6(bloom, bloomBits, bloomK, 'A', 'B', seed); - TrainJunkModel.V6F1Tables v6Tables = new TrainJunkModel.V6F1Tables( + F1Tables v6Tables = new F1Tables( qBigram.bytes, bigramBuckets, qBigram.min, qBigram.max, qUnigram.bytes, unigramBuckets, qUnigram.min, qUnigram.max, bloom, bloomBits, bloomK, seed, 1.0f); @@ -183,7 +183,7 @@ public class JunkDetectorV6Test { TrainJunkModel.bloomAddV6(bloom, bloomBits, bloomK, 'A', 'B', seed); TrainJunkModel.bloomAddV6(bloom, bloomBits, bloomK, 'B', 'A', seed); - TrainJunkModel.V6F1Tables v6Tables = new TrainJunkModel.V6F1Tables( + F1Tables v6Tables = new F1Tables( qBigram.bytes, bigramBuckets, qBigram.min, qBigram.max, qUnigram.bytes, unigramBuckets, qUnigram.min, qUnigram.max, bloom, bloomBits, bloomK, seed, 1.0f); @@ -198,27 +198,12 @@ public class JunkDetectorV6Test { "All-seen 'ABAB' should score z ≈ +4"); } - @Test - void oldFormatModelIsRejected() { - // Strict invariant on this branch: only the current file-format - // version is accepted. The bundled junkdetect.bin from the previous - // architecture must fail to load with a clear error rather than - // silently scoring through a fallback path. - IOException ex = org.junit.jupiter.api.Assertions.assertThrows( - IOException.class, - JunkDetector::loadFromClasspath, - "Bundled previous-format model should be rejected"); - org.junit.jupiter.api.Assertions.assertTrue( - ex.getMessage().contains("Unsupported model format version"), - "Error message should mention unsupported version, was: " + ex.getMessage()); - } - // ----------------------------------------------------------------------- // Helper — minimal LATIN-only v6 model for tests that only need to // exercise scoring of LATIN text. // ----------------------------------------------------------------------- - private static void saveMinimalV6Model(TrainJunkModel.V6F1Tables v6, + private static void saveMinimalV6Model(F1Tables v6, Path modelFile) throws IOException { TreeMap<String, float[]> f1Cal = new TreeMap<>(); f1Cal.put("LATIN", new float[]{-5.0f, 1.0f}); @@ -278,7 +263,7 @@ public class JunkDetectorV6Test { // --- 2. Phase 1: train codepoint-hash tables --- // Use a small Bloom (64 KB) — the synthetic corpus has only a // few hundred unique pairs. - TrainJunkModel.V6F1Tables f1 = TrainJunkModel.trainCodepointHashTables( + F1Tables f1 = TrainJunkModel.trainCodepointHashTables( List.of(trainFile), 524288); // Sanity: Bloom should contain pairs we observed in training. @@ -293,7 +278,7 @@ public class JunkDetectorV6Test { "Bloom should contain (o, x) — appears in training"); // --- 3. F1 raw scoring sanity --- - double meanLogP = TrainJunkModel.codepointHashMeanLogP( + double meanLogP = JunkDetector.computeF1MeanLogP( "the quick brown fox", f1); assertTrue(Double.isFinite(meanLogP), "Mean log-prob on training text should be finite, got " + meanLogP); @@ -348,12 +333,12 @@ public class JunkDetectorV6Test { // --- 7. Train/infer consistency check --- // The inference path should compute the same raw F1 score as - // the trainer's codepointHashMeanLogP on the same text — if these + // JunkDetector.computeF1MeanLogP on the same text — if these // two ever disagree, the model's calibration is silently wrong. // We can verify indirectly: score same text using - // codepointHashMeanLogP and re-derive z1 manually. + // computeF1MeanLogP and re-derive z1 manually. String probe = "pack my box with five dozen liquor jugs"; - double trainerRawMean = TrainJunkModel.codepointHashMeanLogP(probe, f1); + double trainerRawMean = JunkDetector.computeF1MeanLogP(probe, f1); float expectedZ1 = (float) (trainerRawMean - f1CalLatin[0]) / f1CalLatin[1]; TextQualityScore probeScore = detector.score(probe); // logit = w1 * z1 + 0 + 0 + 0 + 0 = z1 in this test configuration.
