This is an automated email from the ASF dual-hosted git repository. tballison pushed a commit to branch junk-detector-v6 in repository https://gitbox.apache.org/repos/asf/tika.git
commit 0e08c2d80a7db538c3fcb52b47b519f7e47b7293 Author: tballison <[email protected]> AuthorDate: Thu May 14 14:45:39 2026 -0400 checkpoint --- .../org/apache/tika/ml/junkdetect/F1Tables.java | 107 ----- .../apache/tika/ml/junkdetect/JunkDetector.java | 290 ++++++------ .../org/apache/tika/ml/junkdetect/V7Tables.java | 204 ++++++++ .../tools/JunkDetectorTrainingConfig.java | 46 +- .../tika/ml/junkdetect/tools/TrainJunkModel.java | 517 ++++++++++----------- .../org/apache/tika/ml/junkdetect/junkdetect.bin | Bin 578012 -> 2810396 bytes .../tika/ml/junkdetect/JunkDetectorV6Test.java | 361 -------------- .../tika/ml/junkdetect/JunkDetectorV7Test.java | 351 ++++++++++++++ .../tools/JunkDetectorTrainingConfigTest.java | 15 +- 9 files changed, 998 insertions(+), 893 deletions(-) diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/F1Tables.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/F1Tables.java deleted file mode 100644 index 1d322dc64e..0000000000 --- a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/F1Tables.java +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.tika.ml.junkdetect; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; - -/** - * Carrier for the global Feature 1 (codepoint-bigram hash + Bloom filter + - * unigram-backoff) tables used by the v6 junk-detector model format. - * - * <p>Instances are created by - * {@link org.apache.tika.ml.junkdetect.tools.TrainJunkModel#trainCodepointHashTables} - * and consumed by {@link JunkDetector#computeF1MeanLogP}. Separating this - * type from {@link JunkDetector} lets the trainer import it directly without - * reaching into the inference class, and keeps the fields package-private so - * they are not part of the public API. - */ -public final class F1Tables { - - final byte[] bigramHash; - final int bigramBuckets; - final float bigramQuantMin; - final float bigramQuantMax; - final byte[] unigramHash; - final int unigramBuckets; - final float unigramQuantMin; - final float unigramQuantMax; - final long[] bloomBits; - final int bloomBitCount; - final int bloomK; - final int fnvSeed; - final float backoffAlpha; - - public F1Tables(byte[] bigramHash, int bigramBuckets, - float bigramQuantMin, float bigramQuantMax, - byte[] unigramHash, int unigramBuckets, - float unigramQuantMin, float unigramQuantMax, - long[] bloomBits, int bloomBitCount, int bloomK, - int fnvSeed, float backoffAlpha) { - this.bigramHash = bigramHash; - this.bigramBuckets = bigramBuckets; - this.bigramQuantMin = bigramQuantMin; - this.bigramQuantMax = bigramQuantMax; - this.unigramHash = unigramHash; - this.unigramBuckets = unigramBuckets; - this.unigramQuantMin = unigramQuantMin; - this.unigramQuantMax = unigramQuantMax; - this.bloomBits = bloomBits; - this.bloomBitCount = bloomBitCount; - this.bloomK = bloomK; - this.fnvSeed = fnvSeed; - this.backoffAlpha = backoffAlpha; - } - - /** - * Serializes the global F1 section to {@code dos} in the v6 model binary - * format. 
Called by - * {@link org.apache.tika.ml.junkdetect.tools.TrainJunkModel#saveModelV6}. - */ - public void writeTo(DataOutputStream dos) throws IOException { - dos.writeInt(fnvSeed); - dos.writeFloat(backoffAlpha); - dos.writeInt(bigramBuckets); - dos.writeFloat(bigramQuantMin); - dos.writeFloat(bigramQuantMax); - dos.write(bigramHash); - dos.writeInt(unigramBuckets); - dos.writeFloat(unigramQuantMin); - dos.writeFloat(unigramQuantMax); - dos.write(unigramHash); - dos.writeInt(bloomBitCount); - dos.writeByte(bloomK); - ByteBuffer bloomBuf = ByteBuffer.allocate(bloomBitCount / 8) - .order(ByteOrder.BIG_ENDIAN); - for (long w : bloomBits) { - bloomBuf.putLong(w); - } - dos.write(bloomBuf.array()); - } - - /** - * Returns a human-readable summary of the quantization ranges, for - * trainer progress output. - */ - public String statsString() { - return String.format( - " bigram quant range: [%.3f, %.3f]%n unigram quant range: [%.3f, %.3f]%n", - bigramQuantMin, bigramQuantMax, unigramQuantMin, unigramQuantMax); - } -} diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java index 6513b0f29d..d932da97cc 100644 --- a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java +++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java @@ -92,10 +92,11 @@ public final class JunkDetector implements TextQualityDetector { static final String MAGIC = "JUNKDET1"; /** Sole supported file-format version. Mismatch is a hard error. */ - static final int VERSION = 6; + static final int VERSION = 7; - // Feature 1 — global hashed codepoint-bigram + Bloom-gated unigram backoff - private final F1Tables f1; + // Feature 1 — per-script open-addressed codepoint-bigram tables. + // No global Bloom: empty-slot is the membership oracle. 
+ private final Map<String, V7Tables> f1TablesByScript; /** Per-script F1 calibration on the codepoint-hash mean log-prob. */ private final Map<String, float[]> calibrations; // script → float[2] {mu, sigma} @@ -127,7 +128,7 @@ public final class JunkDetector implements TextQualityDetector { float[] scriptTransitionCalibration, Map<String, Integer> scriptBucketIndex, int numScriptBuckets, - F1Tables f1) { + Map<String, V7Tables> f1TablesByScript) { this.calibrations = Collections.unmodifiableMap(calibrations); this.blockTables = Collections.unmodifiableMap(blockTables); this.blockCalibrations = Collections.unmodifiableMap(blockCalibrations); @@ -137,7 +138,7 @@ public final class JunkDetector implements TextQualityDetector { this.scriptTransitionCalibration = scriptTransitionCalibration; this.scriptBucketIndex = Collections.unmodifiableMap(scriptBucketIndex); this.numScriptBuckets = numScriptBuckets; - this.f1 = f1; + this.f1TablesByScript = Collections.unmodifiableMap(f1TablesByScript); } // ----------------------------------------------------------------------- @@ -197,7 +198,7 @@ public final class JunkDetector implements TextQualityDetector { * <p>File-format layout (gzipped): * <pre> * [8 bytes] magic "JUNKDET1" (ASCII) - * [1 byte] version (= 6) + * [1 byte] version (= 7) * [4 bytes] num_scripts (int BE) * [1 byte] block_scheme_version (must equal * {@link UnicodeBlockRanges#SCHEME_VERSION}) @@ -208,32 +209,34 @@ public final class JunkDetector implements TextQualityDetector { * [num_script_buckets² × 4 bytes] script-transition log-prob table (F4) * [4 bytes] mu4 (float32 BE) * [4 bytes] sigma4 (float32 BE) - * [4 bytes] fnv_seed (int BE) - * [4 bytes] backoff_alpha (float32 BE) - * [4 bytes] bigram_buckets (int BE, power of 2) - * [4 bytes] bigram_quant_min (float32 BE) - * [4 bytes] bigram_quant_max (float32 BE) - * [bigram_buckets bytes] bigram log-prob table (8-bit quantized) - * [4 bytes] unigram_buckets (int BE, power of 2) - * [4 bytes] 
unigram_quant_min (float32 BE) - * [4 bytes] unigram_quant_max (float32 BE) - * [unigram_buckets bytes] unigram log-prob table (8-bit quantized) - * [4 bytes] bloom_bits (int BE, multiple of 64) - * [1 byte] bloom_k - * [bloom_bits/8 bytes] Bloom bit array * for each script (sorted by name): * [2 bytes] name length * [name bytes] script name (UTF-8) - * [4 bytes] mu1 (F1 calibration, codepoint-hash mean log-prob) + * [4 bytes] mu1 (F1 calibration, codepoint-bigram mean log-prob) * [4 bytes] sigma1 + * // V7 F1 tables for this script — see {@link V7Tables#writeTo} + * [4 bytes] backoff_alpha (float32 BE) + * [4 bytes] codepoint_count + * [codepoint_count × 4 bytes] codepoint index (sorted, ascending) + * [4 bytes] bigram_slots (power of 2) + * [4 bytes] bigram_quant_min (float32 BE) + * [4 bytes] bigram_quant_max (float32 BE) + * [bigram_slots × 4 bytes] bigram open-addressing keys + * ((idxA<<16)|idxB, or {@link V7Tables#EMPTY_KEY}) + * [bigram_slots bytes] bigram values (8-bit quantized log-probs) + * [4 bytes] unigram_quant_min (float32 BE) + * [4 bytes] unigram_quant_max (float32 BE) + * [4 bytes] unigram_fallback_log_prob (float32 BE; used for + * codepoints not in index) + * [codepoint_count bytes] unigram values (8-bit quantized log-probs) + * // F2/F3/classifier (unchanged from v6 layout) * [4 bytes] mu2 (F2 calibration) * [4 bytes] sigma2 * [block_N² × 4 bytes] block-transition log-prob table (F2) * [4 bytes] mu3 (F3 calibration) * [4 bytes] sigma3 * [1 byte] num_features - * [num_features × 4 bytes] classifier weights w1..wN - * [4 bytes] bias + * [(num_features+1) × 4 bytes] classifier weights w1..wN and bias * </pre> */ public static JunkDetector load(InputStream rawIs) throws IOException { @@ -281,39 +284,22 @@ public final class JunkDetector implements TextQualityDetector { float[] scriptTransitionTable = readFloatTable(dis, numScriptBuckets * numScriptBuckets); float[] scriptTransitionCalibration = new float[]{dis.readFloat(), dis.readFloat()}; - // 
Global F1 hash + Bloom section - int fnvSeed = dis.readInt(); - float backoffAlpha = dis.readFloat(); - int bigramBuckets = dis.readInt(); - float bigramMin = dis.readFloat(); - float bigramMax = dis.readFloat(); - byte[] bigramHash = dis.readNBytes(bigramBuckets); - int unigramBuckets = dis.readInt(); - float unigramMin = dis.readFloat(); - float unigramMax = dis.readFloat(); - byte[] unigramHash = dis.readNBytes(unigramBuckets); - int bloomBits = dis.readInt(); - int bloomK = dis.readUnsignedByte(); - int bloomLongs = (bloomBits + 63) >> 6; - long[] bloom = new long[bloomLongs]; - byte[] bloomBytes = dis.readNBytes(bloomBits / 8); - ByteBuffer bloomBuf = ByteBuffer.wrap(bloomBytes).order(ByteOrder.BIG_ENDIAN); - bloomBuf.asLongBuffer().get(bloom); - F1Tables f1 = new F1Tables(bigramHash, bigramBuckets, bigramMin, bigramMax, - unigramHash, unigramBuckets, unigramMin, unigramMax, - bloom, bloomBits, bloomK, fnvSeed, backoffAlpha); - - Map<String, float[]> calibrations = new HashMap<>(numScripts * 2); - Map<String, float[]> blockTables = new HashMap<>(numScripts * 2); - Map<String, float[]> blockCalibrations = new HashMap<>(numScripts * 2); - Map<String, float[]> controlCalibrations = new HashMap<>(numScripts * 2); - Map<String, float[]> classifierWeights = new HashMap<>(numScripts * 2); + Map<String, V7Tables> f1TablesByScript = new HashMap<>(numScripts * 2); + Map<String, float[]> calibrations = new HashMap<>(numScripts * 2); + Map<String, float[]> blockTables = new HashMap<>(numScripts * 2); + Map<String, float[]> blockCalibrations = new HashMap<>(numScripts * 2); + Map<String, float[]> controlCalibrations = new HashMap<>(numScripts * 2); + Map<String, float[]> classifierWeights = new HashMap<>(numScripts * 2); for (int s = 0; s < numScripts; s++) { int nameLen = dis.readUnsignedShort(); String script = new String(dis.readNBytes(nameLen), StandardCharsets.UTF_8); calibrations.put(script, new float[]{dis.readFloat(), dis.readFloat()}); + + // Per-script V7 F1 
tables. + f1TablesByScript.put(script, V7Tables.readFrom(dis)); + blockCalibrations.put(script, new float[]{dis.readFloat(), dis.readFloat()}); blockTables.put(script, readFloatTable(dis, blockN * blockN)); controlCalibrations.put(script, new float[]{dis.readFloat(), dis.readFloat()}); @@ -330,7 +316,7 @@ public final class JunkDetector implements TextQualityDetector { blockTables, blockCalibrations, controlCalibrations, classifierWeights, scriptTransitionTable, scriptTransitionCalibration, - scriptBucketIndex, numScriptBuckets, f1); + scriptBucketIndex, numScriptBuckets, f1TablesByScript); } } @@ -582,8 +568,9 @@ public final class JunkDetector implements TextQualityDetector { * training and inference share the same math. */ private float[] computeChunkZs(byte[] utf8, String text, String script) { - // Feature 1: global hashed codepoint-bigram, calibrated per-script - float meanF1LogProb = computeCodepointHashMeanLogP(text); + // Feature 1: per-script codepoint-bigram, calibrated per-script + V7Tables tables = f1TablesByScript.get(script); + float meanF1LogProb = computeCodepointF1MeanLogP(text, tables); float[] cal1 = calibrations.get(script); float z1 = (meanF1LogProb - cal1[0]) / cal1[1]; @@ -593,6 +580,12 @@ public final class JunkDetector implements TextQualityDetector { return new float[]{z1, z2, z3}; } + private static float computeCodepointF1MeanLogP(String text, V7Tables tables) { + if (tables == null) return Float.NaN; + double v = computeF1MeanLogP(text, tables); + return Double.isNaN(v) ? Float.NaN : (float) v; + } + /** * Feature 2 — calibrated z-score for block-transition mean log-prob on * one text window. 
Returns 0 if the window has fewer than two @@ -719,80 +712,145 @@ public final class JunkDetector implements TextQualityDetector { } // ----------------------------------------------------------------------- - // Feature 1: hashed codepoint-bigram + Bloom-gated unigram-backoff + // Feature 1: per-script open-addressing codepoint-bigram lookup // ----------------------------------------------------------------------- /** * Mean log-prob over the codepoint pairs in {@code text} using the given - * F1 tables. + * script's V7 F1 tables. * - * <p>For each adjacent codepoint pair {@code (a, b)}: if the Bloom filter - * reports the pair was seen at training time, return the dequantized - * log-prob from the bigram hash bucket. Else fall back to - * {@code alpha * (log P(a) + log P(b))} via the unigram hash table. + * <p>For each adjacent codepoint pair {@code (a, b)}: + * <ol> + * <li>Binary-search both codepoints in the script's codepoint index. + * If either is absent, the pair was never seen in training; emit + * {@code α * (logP(a) + logP(b))} using each codepoint's unigram + * value (or {@link V7Tables#unigramFallbackLogProb} if the + * codepoint isn't even in the unigram index).</li> + * <li>Otherwise, look up the packed {@code (idxA<<16)|idxB} key in + * the open-addressing bigram table. Empty slot → unseen pair → + * unigram backoff (same formula). Match → dequantize the stored + * value.</li> + * </ol> * - * <p>This is the single authoritative implementation of the F1 scoring - * math, shared between inference ({@link #score}) and training - * ({@link org.apache.tika.ml.junkdetect.tools.TrainJunkModel#calibrateF1PerScript}). - * Keeping one implementation eliminates the risk of train/infer drift in - * the F1 feature — the same risk that motivated making z2/z3/z4 public - * static methods. + * <p>This is the single authoritative implementation of the V7 F1 + * scoring math, shared by inference and training. 
Keeping one + * implementation eliminates the risk of train/infer drift in the F1 + * feature. * * @return mean log-prob, or {@link Double#NaN} if {@code text} has fewer - * than two codepoints + * than two codepoints or {@code tables} is null */ - public static double computeF1MeanLogP(String text, F1Tables f1) { - if (text == null || text.length() < 2) { + public static double computeF1MeanLogP(String text, V7Tables tables) { + if (text == null || text.length() < 2 || tables == null) { return Double.NaN; } double sum = 0; int n = 0; int prevCp = -1; + int prevIdx = -1; for (int i = 0; i < text.length(); ) { int cp = text.codePointAt(i); i += Character.charCount(cp); + int curIdx = codepointToIndex(tables, cp); if (prevCp >= 0) { - sum += scorePairF1(prevCp, cp, f1); + sum += scorePairF1V7(prevCp, prevIdx, cp, curIdx, tables); n++; } prevCp = cp; + prevIdx = curIdx; } return n == 0 ? Double.NaN : sum / n; } - /** Thin instance wrapper around the shared static implementation. */ - private float computeCodepointHashMeanLogP(String text) { - double v = computeF1MeanLogP(text, f1); - return Double.isNaN(v) ? Float.NaN : (float) v; + /** + * Binary-search a codepoint in the script's index. + * + * @return the dense index (≥ 0) if found, or -1 if the codepoint + * doesn't appear in any kept bigram for this script + */ + public static int codepointToIndex(V7Tables tables, int cp) { + return java.util.Arrays.binarySearch(tables.codepointIndex, cp); + } + + /** + * Mixing function used to scatter packed (idxA, idxB) keys across + * the open-addressing table. A simple integer finalizer (splitmix32 + * style) gives good distribution for sequential index values. + * + * <p>Public so the trainer's open-addressing insertion routine uses + * the same probe order as inference — drift here would silently + * corrupt every lookup. 
+ */ + public static int mixIndexKey(int packedKey) { + int x = packedKey; + x = (x ^ (x >>> 16)) * 0x7feb352d; + x = (x ^ (x >>> 15)) * 0x846ca68b; + x = x ^ (x >>> 16); + return x; + } + + /** + * Packed bigram key for indices {@code (a, b)} where each index fits in + * {@link JunkDetectorTrainingConfig#KEY_INDEX_BITS} bits. Asserts that + * indices are non-negative; that's the caller's contract. + */ + public static int packBigramKey(int idxA, int idxB) { + return (idxA << 16) | (idxB & 0xFFFF); } - private static double scorePairF1(int cpA, int cpB, F1Tables f1) { - if (bloomContainsF1(cpA, cpB, f1)) { - int bucket = (int) (fnv1aBigram(cpA, cpB, f1.fnvSeed) & (f1.bigramBuckets - 1)); - return dequantize(f1.bigramHash[bucket], f1.bigramQuantMin, f1.bigramQuantMax); - } - // Unigram backoff. α=1.0 gives plain independence assumption; - // α<1 discounts (stupid-backoff style) — prototype showed α=1.0 is - // correct for junk discrimination. - double ua = dequantize( - f1.unigramHash[(int) (fnv1aUnigram(cpA, f1.fnvSeed) & (f1.unigramBuckets - 1))], - f1.unigramQuantMin, f1.unigramQuantMax); - double ub = dequantize( - f1.unigramHash[(int) (fnv1aUnigram(cpB, f1.fnvSeed) & (f1.unigramBuckets - 1))], - f1.unigramQuantMin, f1.unigramQuantMax); - return f1.backoffAlpha * (ua + ub); - } - - private static boolean bloomContainsF1(int cpA, int cpB, F1Tables f1) { - long h1 = fnv1aBigram(cpA, cpB, f1.fnvSeed); - long h2 = secondaryHash(cpA, cpB); - for (int i = 0; i < f1.bloomK; i++) { - long pos = ((h1 + (long) i * h2) & 0x7FFFFFFFFFFFFFFFL) % f1.bloomBitCount; - if ((f1.bloomBits[(int) (pos >>> 6)] & (1L << (pos & 63))) == 0) { - return false; + /** + * Looks up a (cpA, cpB) bigram in the script's V7 tables and returns + * its dequantized log-prob. Falls back to unigram backoff on miss. + * + * <p>{@code idxA}/{@code idxB} are the pre-computed codepoint indices + * (from {@link #codepointToIndex}); {@code -1} means the codepoint is + * not in this script's index. 
The caller is expected to compute them + * once when scanning the text (avoiding a redundant binary search per + * codepoint). + */ + private static double scorePairF1V7(int cpA, int idxA, int cpB, int idxB, + V7Tables tables) { + if (idxA >= 0 && idxB >= 0) { + int slot = lookupBigramSlot(tables, idxA, idxB); + if (slot >= 0) { + return dequantize(tables.bigramValues[slot], + tables.bigramQuantMin, tables.bigramQuantMax); } } - return true; + // Unigram backoff for unseen pair or for codepoints absent from the + // per-script index. α=1.0 = plain independence; prototype-validated. + double ua = unigramLogProb(tables, idxA); + double ub = unigramLogProb(tables, idxB); + return tables.backoffAlpha * (ua + ub); + } + + /** + * Open-addressing lookup: returns the slot index that contains the key + * for {@code (idxA, idxB)}, or {@code -1} if not present (probe hit an + * empty slot first). + * + * <p>Linear probing with the same mix-hash used at training time — + * required for the table to be readable, not just writable. + */ + static int lookupBigramSlot(V7Tables tables, int idxA, int idxB) { + int packedKey = packBigramKey(idxA, idxB); + int[] keys = tables.bigramKeys; + int mask = keys.length - 1; + int h = mixIndexKey(packedKey) & mask; + while (true) { + int k = keys[h]; + if (k == V7Tables.EMPTY_KEY) return -1; + if (k == packedKey) return h; + h = (h + 1) & mask; + } + } + + private static double unigramLogProb(V7Tables tables, int idx) { + if (idx < 0) { + return tables.unigramFallbackLogProb; + } + return dequantize(tables.unigramTable[idx], + tables.unigramQuantMin, tables.unigramQuantMax); } private static float dequantize(byte b, float min, float max) { @@ -800,44 +858,6 @@ public final class JunkDetector implements TextQualityDetector { return min + (u / 255.0f) * (max - min); } - // FNV-1a constants for codepoint hashing. Must match the values used at - // training time (TrainJunkModel). Seed is stored in the model file. 
- private static final long FNV_OFFSET = 0xcbf29ce484222325L; - private static final long FNV_PRIME = 0x100000001b3L; - - static long fnv1aBigram(int cpA, int cpB, int seed) { - long h = FNV_OFFSET; - h = (h ^ (seed & 0xFF)) * FNV_PRIME; - h = (h ^ ((cpA >>> 24) & 0xFF)) * FNV_PRIME; - h = (h ^ ((cpA >>> 16) & 0xFF)) * FNV_PRIME; - h = (h ^ ((cpA >>> 8) & 0xFF)) * FNV_PRIME; - h = (h ^ (cpA & 0xFF)) * FNV_PRIME; - h = (h ^ 0xFF) * FNV_PRIME; // separator - h = (h ^ ((cpB >>> 24) & 0xFF)) * FNV_PRIME; - h = (h ^ ((cpB >>> 16) & 0xFF)) * FNV_PRIME; - h = (h ^ ((cpB >>> 8) & 0xFF)) * FNV_PRIME; - h = (h ^ (cpB & 0xFF)) * FNV_PRIME; - return h; - } - - static long fnv1aUnigram(int cp, int seed) { - long h = FNV_OFFSET; - h = (h ^ (seed & 0xFF)) * FNV_PRIME; - h = (h ^ ((cp >>> 24) & 0xFF)) * FNV_PRIME; - h = (h ^ ((cp >>> 16) & 0xFF)) * FNV_PRIME; - h = (h ^ ((cp >>> 8) & 0xFF)) * FNV_PRIME; - h = (h ^ (cp & 0xFF)) * FNV_PRIME; - return h; - } - - static long secondaryHash(int cpA, int cpB) { - long h = 0xff51afd7ed558ccdL; - h = (h ^ Integer.reverse(cpA)) * 0xc4ceb9fe1a85ec53L; - h = (h ^ Integer.reverse(cpB)) * 0xc4ceb9fe1a85ec53L; - h ^= h >>> 33; - return h; - } - /** * Computes the global script-transition z-score for the whole input * string against this model's loaded tables. Thin wrapper around the diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/V7Tables.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/V7Tables.java new file mode 100644 index 0000000000..93a82640ca --- /dev/null +++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/V7Tables.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.tika.ml.junkdetect; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** + * Carrier for one script's v7 F1 tables. + * + * <p>The v6 design used a single global codepoint-bigram hash + Bloom + * filter shared across all scripts. We measured that this ceiling + * limits accuracy: enlarging one script's training data (e.g. HAN) hurts + * the other scripts' z-scores because they share the global hash. v7 + * gives each script its own pair of tables. + * + * <p>Per-script layout: + * + * <ul> + * <li>{@code codepointIndex} — sorted, ascending {@code int[]} of every + * codepoint that appears as either side of a kept bigram for this + * script. Codepoint → dense index is a binary search; index → + * codepoint is direct array access. Typical sizes: ~7K-15K for HAN, + * ~200-500 for most other scripts. + * <li>{@code bigramKeys} / {@code bigramValues} — parallel arrays + * implementing an open-addressed hash table with linear probing. + * Each key is a 32-bit value {@code (idxA << 16) | idxB}; key {@code + * -1} means "empty slot." Indices are bounded at 16 bits (65535), + * which is comfortably above the largest per-script codepoint count + * we observe. 
+ * <li>{@code unigramTable} — {@code byte[numCodepoints]}, quantized + * unigram log-probabilities indexed by the same codepoint→index map. + * <li>{@code bigramQuantMin/Max}, {@code unigramQuantMin/Max} — + * per-quantization ranges; dequantize by + * {@code min + (b/255) * (max - min)}. + * <li>{@code unigramFallbackLogProb} — log-prob assigned when a + * codepoint is not in {@code codepointIndex} at all. Set to the + * script's most-pessimistic unigram value (its quantization min) so + * absent codepoints don't accidentally score above legitimately-rare + * ones. + * <li>{@code backoffAlpha} — multiplier on the unigram-backoff + * independence sum, copied from v6. + * </ul> + * + * <p>Membership semantics: no Bloom filter. The empty-slot sentinel is + * the membership oracle — a pair is "seen" iff binary-search finds both + * codepoints in the index AND a probe sequence hits a matching key before + * an empty slot. Lookups are therefore exact; there is no false-positive + * backoff path as there is in v6. + * + * <p>Fields are package-private so the + * {@link org.apache.tika.ml.junkdetect.tools.TrainJunkModel} trainer can + * construct instances directly without going through accessors. + */ +public final class V7Tables { + + /** Reserved value in {@link #bigramKeys} marking an unoccupied slot. 
*/ + public static final int EMPTY_KEY = -1; + + final int[] codepointIndex; + final int[] bigramKeys; + final byte[] bigramValues; + final byte[] unigramTable; + final float bigramQuantMin; + final float bigramQuantMax; + final float unigramQuantMin; + final float unigramQuantMax; + final float unigramFallbackLogProb; + final float backoffAlpha; + + public V7Tables(int[] codepointIndex, + int[] bigramKeys, byte[] bigramValues, + byte[] unigramTable, + float bigramQuantMin, float bigramQuantMax, + float unigramQuantMin, float unigramQuantMax, + float unigramFallbackLogProb, + float backoffAlpha) { + if (bigramKeys.length != bigramValues.length) { + throw new IllegalArgumentException( + "bigramKeys and bigramValues must have equal length: " + + bigramKeys.length + " vs " + bigramValues.length); + } + if (unigramTable.length != codepointIndex.length) { + throw new IllegalArgumentException( + "unigramTable.length must equal codepointIndex.length: " + + unigramTable.length + " vs " + codepointIndex.length); + } + this.codepointIndex = codepointIndex; + this.bigramKeys = bigramKeys; + this.bigramValues = bigramValues; + this.unigramTable = unigramTable; + this.bigramQuantMin = bigramQuantMin; + this.bigramQuantMax = bigramQuantMax; + this.unigramQuantMin = unigramQuantMin; + this.unigramQuantMax = unigramQuantMax; + this.unigramFallbackLogProb = unigramFallbackLogProb; + this.backoffAlpha = backoffAlpha; + } + + /** + * Serialises this script's F1 tables. Read back via + * {@link #readFrom(DataInputStream)}. + */ + public void writeTo(DataOutputStream dos) throws IOException { + dos.writeFloat(backoffAlpha); + + // Codepoint index. + dos.writeInt(codepointIndex.length); + ByteBuffer cpBuf = ByteBuffer.allocate(codepointIndex.length * 4) + .order(ByteOrder.BIG_ENDIAN); + cpBuf.asIntBuffer().put(codepointIndex); + dos.write(cpBuf.array()); + + // Bigram open-addressing table (keys + values). 
+ dos.writeInt(bigramKeys.length); + dos.writeFloat(bigramQuantMin); + dos.writeFloat(bigramQuantMax); + ByteBuffer keyBuf = ByteBuffer.allocate(bigramKeys.length * 4) + .order(ByteOrder.BIG_ENDIAN); + keyBuf.asIntBuffer().put(bigramKeys); + dos.write(keyBuf.array()); + dos.write(bigramValues); + + // Unigram table. + dos.writeFloat(unigramQuantMin); + dos.writeFloat(unigramQuantMax); + dos.writeFloat(unigramFallbackLogProb); + dos.write(unigramTable); + } + + /** Inverse of {@link #writeTo(DataOutputStream)}. */ + public static V7Tables readFrom(DataInputStream dis) throws IOException { + float backoffAlpha = dis.readFloat(); + + int cpCount = dis.readInt(); + byte[] cpBytes = dis.readNBytes(cpCount * 4); + int[] codepoints = new int[cpCount]; + ByteBuffer.wrap(cpBytes).order(ByteOrder.BIG_ENDIAN).asIntBuffer().get(codepoints); + + int slots = dis.readInt(); + float bMin = dis.readFloat(); + float bMax = dis.readFloat(); + byte[] keyBytes = dis.readNBytes(slots * 4); + int[] keys = new int[slots]; + ByteBuffer.wrap(keyBytes).order(ByteOrder.BIG_ENDIAN).asIntBuffer().get(keys); + byte[] values = dis.readNBytes(slots); + + float uMin = dis.readFloat(); + float uMax = dis.readFloat(); + float uFallback = dis.readFloat(); + byte[] unigramTable = dis.readNBytes(cpCount); + + return new V7Tables(codepoints, keys, values, unigramTable, + bMin, bMax, uMin, uMax, uFallback, backoffAlpha); + } + + /** + * Returns a one-line summary for trainer progress output. 
+ */ + public String statsString() { + return String.format( + " cp_index=%d, bigram_slots=%d (load≈%.2f), " + + "bigram_range=[%.3f, %.3f], unigram_range=[%.3f, %.3f]", + codepointIndex.length, bigramKeys.length, + occupiedSlots() / (double) Math.max(1, bigramKeys.length), + bigramQuantMin, bigramQuantMax, + unigramQuantMin, unigramQuantMax); + } + + private int occupiedSlots() { + int n = 0; + for (int k : bigramKeys) { + if (k != EMPTY_KEY) n++; + } + return n; + } + + /** Number of codepoints in this script's index. Diagnostic. */ + public int codepointCount() { + return codepointIndex.length; + } + + /** Number of bigram-table slots (capacity). Diagnostic. */ + public int bigramSlots() { + return bigramKeys.length; + } +} diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/JunkDetectorTrainingConfig.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/JunkDetectorTrainingConfig.java index 15600a96c6..aa3761ef79 100644 --- a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/JunkDetectorTrainingConfig.java +++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/JunkDetectorTrainingConfig.java @@ -142,11 +142,15 @@ public final class JunkDetectorTrainingConfig { /** * Per-script byte-budget overrides applied on top of the entropy- - * proportional allocation. Empty in the current configuration: an - * experiment that gave HAN 60 MB instead of the entropy-derived 26 MB - * <i>worsened</i> Cohen's d for every non-HAN script (the global F1 - * hash table is the bottleneck, not the corpus), so the override - * mechanism is preserved as infrastructure but is not currently used. + * proportional allocation. Empty in the current configuration. + * + * <p>Under v6 the {@code HAN=60MB} experiment <em>worsened</em> every + * non-HAN script (the global F1 hash table was the bottleneck). 
Under + * v7's per-script tables, the same experiment correctly leaves other + * scripts untouched, but the HAN gain itself was negligible (Cohen's d + * moved 7.26 → 7.35) — the per-script HAN model is already near its + * data-saturation point with ~18 MB of training data. Override left + * empty until a more decisive HAN-coverage experiment is designed. */ public static final Map<String, Long> SCRIPT_BUDGET_OVERRIDES = Collections.emptyMap(); @@ -156,22 +160,34 @@ public final class JunkDetectorTrainingConfig { // ======================================================================= /** - * Drop F1 bigrams whose global per-pair occurrence count is below this - * threshold from the codepoint-bigram hash table and Bloom filter. - * Set to 3 on evidence that singleton and doubleton pairs are - * overwhelmingly OCR artifacts and proper-noun noise that inflate the - * clean-side score distribution tail without contributing signal. + * Drop per-script F1 bigrams whose per-pair occurrence count (within + * that script's training data) is below this threshold. Set to 3 on + * evidence that singleton and doubleton pairs are overwhelmingly OCR + * artifacts and proper-noun noise that inflate the clean-side score + * distribution tail without contributing signal. * - * <p>Set to 1 to disable the filter (legacy behavior). + * <p>Set to 1 to disable the filter (every observed pair retained). */ public static final int MIN_BIGRAM_COUNT = 3; /** - * Bloom filter capacity in bits for the F1 codepoint-bigram membership - * oracle. Must be a multiple of 64. 16 Mbit gives a comfortable false- - * positive rate at the current corpus's distinct-pair count. + * Target load factor for the per-script open-addressing F1 hash + * table. Table capacity is sized as the smallest power of two + * larger than {@code keptPairs / loadFactor}, giving an average of + * 1 / (1 - loadFactor) probes per lookup. 0.5 → ~2 probes; modestly + * wasteful in space but very cheap to probe. 
+ */ + public static final double OA_LOAD_FACTOR = 0.5; + + /** + * Bit width of each codepoint's dense index within a script's F1 + * table. Each bigram is packed as {@code (idxA << KEY_INDEX_BITS) | + * idxB}, so each side must fit in this many bits. 16 bits supports + * up to 65535 distinct codepoints per script, which is comfortably + * above the largest per-script count we have measured (HAN is the + * worst case at ~15K kept codepoints). */ - public static final int BLOOM_BITS = 16 * 1024 * 1024; + public static final int KEY_INDEX_BITS = 16; private JunkDetectorTrainingConfig() { // No instances. diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java index 229c60a5c5..cf52a9eedf 100644 --- a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java +++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java @@ -41,8 +41,8 @@ import java.util.TreeMap; import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; -import org.apache.tika.ml.junkdetect.F1Tables; import org.apache.tika.ml.junkdetect.JunkDetector; +import org.apache.tika.ml.junkdetect.V7Tables; /** * Trains the junk detector model from per-script corpus files produced by @@ -128,26 +128,16 @@ public class TrainJunkModel { static final String MAGIC = "JUNKDET1"; /** Sole supported file-format version. Matches JunkDetector.VERSION. */ - static final byte VERSION = 6; + static final byte VERSION = 7; // ----------------------------------------------------------------------- - // v6 model constants (codepoint-bigram-hash + Bloom + unigram-backoff) + // v7 model constants (per-script open-addressing codepoint-bigram tables) // ----------------------------------------------------------------------- - /** Bigram hash bucket count (power of 2). 
*/ - static final int V6_BIGRAM_BUCKETS = 4096; - /** Unigram hash bucket count (power of 2). */ - static final int V6_UNIGRAM_BUCKETS = 8192; - /** Default Bloom filter capacity in bits (must be a multiple of 64). */ - static final int V6_BLOOM_BITS_DEFAULT = 16 * 1024 * 1024; - /** Number of hash functions for the Bloom filter (double-hashing). */ - static final int V6_BLOOM_K = 7; - /** FNV-1a seed; stored in the model so training and inference stay synced. */ - public static final int V6_FNV_SEED = 0xB8E7A1F3; /** Unigram backoff multiplier. α=1.0 = plain independence; prototype validated. */ - static final float V6_BACKOFF_ALPHA = 1.0f; + static final float V7_BACKOFF_ALPHA = 1.0f; /** Additive smoothing constant for log-prob computation. */ - static final double V6_ADD_ALPHA = 0.01; + static final double V7_ADD_ALPHA = 0.01; /** Number of clean (and corrupted) windows used to train the per-script classifier. */ static final int NUM_CLASSIFIER_SAMPLES = 500; @@ -205,16 +195,21 @@ public class TrainJunkModel { // Durable training parameters live in JunkDetectorTrainingConfig; this // tool deliberately refuses CLI overrides so a built model file's // identity always matches a committed config. 
- int bloomBits = JunkDetectorTrainingConfig.BLOOM_BITS; int minBigramCount = JunkDetectorTrainingConfig.MIN_BIGRAM_COUNT; - if (bloomBits % 64 != 0) { - System.err.println("ERROR: BLOOM_BITS must be a multiple of 64"); - System.exit(1); - } + double loadFactor = JunkDetectorTrainingConfig.OA_LOAD_FACTOR; + int keyIndexBits = JunkDetectorTrainingConfig.KEY_INDEX_BITS; if (minBigramCount < 1) { System.err.println("ERROR: MIN_BIGRAM_COUNT must be >= 1"); System.exit(1); } + if (loadFactor <= 0 || loadFactor >= 1) { + System.err.println("ERROR: OA_LOAD_FACTOR must be in (0, 1), got " + loadFactor); + System.exit(1); + } + if (keyIndexBits < 1 || keyIndexBits > 16) { + System.err.println("ERROR: KEY_INDEX_BITS must be in [1, 16], got " + keyIndexBits); + System.exit(1); + } for (int i = 0; i < args.length; i++) { switch (args[i]) { @@ -240,15 +235,12 @@ public class TrainJunkModel { System.out.println("=== TrainJunkModel ==="); System.out.println(" data-dir: " + dataDir); System.out.println(" output: " + output); - System.out.println(" --- v6 format constants (TrainJunkModel) ---"); - System.out.printf( " bigram_buckets: %d%n", V6_BIGRAM_BUCKETS); - System.out.printf( " unigram_buckets: %d%n", V6_UNIGRAM_BUCKETS); - System.out.printf( " fnv_seed: 0x%08X%n", V6_FNV_SEED); - System.out.printf( " backoff_alpha: %.2f%n", V6_BACKOFF_ALPHA); + System.out.println(" --- v7 format constants (TrainJunkModel) ---"); + System.out.printf( " backoff_alpha: %.2f%n", V7_BACKOFF_ALPHA); System.out.println(" --- config (JunkDetectorTrainingConfig) ---"); - System.out.printf( " bloom_bits: %d (%d KB), k=%d%n", - bloomBits, bloomBits / 8 / 1024, V6_BLOOM_K); System.out.printf( " min_bigram_count: %d%n", minBigramCount); + System.out.printf( " oa_load_factor: %.2f%n", loadFactor); + System.out.printf( " key_index_bits: %d%n", keyIndexBits); if (!Files.isDirectory(dataDir)) { System.err.println("ERROR: data-dir not found: " + dataDir); @@ -283,19 +275,11 @@ public class TrainJunkModel { } 
// ----------------------------------------------------------------------- - // Phase 1 — global codepoint-bigram + unigram hash + Bloom filter - // ----------------------------------------------------------------------- - System.out.println("\n--- Phase 1: global codepoint-hash tables + Bloom ---"); - t0 = System.currentTimeMillis(); - System.out.print(" Training global codepoint-bigram + unigram + Bloom... "); - F1Tables f1Tables = trainCodepointHashTables(trainFiles, bloomBits, minBigramCount); - System.out.printf("done (%dms)%n", System.currentTimeMillis() - t0); - System.out.print(f1Tables.statsString()); - - // ----------------------------------------------------------------------- - // Phase 1.5 — per-script F1 calibration + F2 block tables + F3 control cal + // Phase 1 — per-script F1 tables (V7), F1 calibration, F2 block tables, + // F3 control-byte calibration // ----------------------------------------------------------------------- - System.out.println("\n--- Phase 1.5: per-script tables and calibrations ---"); + TreeMap<String, V7Tables> f1TablesByScript = new TreeMap<>(); + System.out.println("\n--- Phase 1: per-script F1 tables + calibrations ---"); for (Path trainFile : trainFiles) { String filename = trainFile.getFileName().toString(); String script = filename.substring(0, filename.length() - ".train.gz".length()) @@ -304,6 +288,14 @@ public class TrainJunkModel { System.out.printf("%n [%s]%n", script); allTrainFiles.add(trainFile); + t0 = System.currentTimeMillis(); + System.out.print(" Training V7 F1 tables (cp index + OA).."); + V7Tables v7 = trainV7TablesForScript(trainFile, minBigramCount, + loadFactor, keyIndexBits); + System.out.printf(" done (%dms)%n", System.currentTimeMillis() - t0); + System.out.println(v7.statsString()); + f1TablesByScript.put(script, v7); + t0 = System.currentTimeMillis(); System.out.print(" Training named-block table... 
"); float[] blockTable = trainBlockTable(trainFile); @@ -311,7 +303,7 @@ public class TrainJunkModel { t0 = System.currentTimeMillis(); System.out.print(" Calibrating F1 (cp-hash) on train.. "); - float[] f1Cal = calibrateF1PerScript(trainFile, f1Tables); + float[] f1Cal = calibrateF1PerScript(trainFile, v7); System.out.printf("done — mu=%.4f sigma=%.4f (%dms)%n", f1Cal[0], f1Cal[1], System.currentTimeMillis() - t0); @@ -387,8 +379,8 @@ public class TrainJunkModel { } t0 = System.currentTimeMillis(); System.out.printf(" [%s] training classifier... ", script); - float[] weights = trainClassifierV6(trainFile, - f1Tables, f1Calibrations.get(script), + float[] weights = trainClassifierV7(trainFile, + f1TablesByScript.get(script), f1Calibrations.get(script), blockTables.get(script), blockCalibrations.get(script), controlCalibrations.get(script), scriptTransTable, scriptTransCal, scriptBucketMap, numScriptBuckets, @@ -401,11 +393,11 @@ public class TrainJunkModel { System.out.printf("%nWriting model (%d scripts, blockN=%d, scriptBuckets=%d) → %s%n", f1Calibrations.size(), blockN, numScriptBuckets, output); - saveModelV6(f1Calibrations, + saveModelV7(f1TablesByScript, f1Calibrations, blockTables, blockCalibrations, controlCalibrations, classifierWeights, scriptBuckets, scriptTransTable, scriptTransCal, - f1Tables, output); + output); System.out.printf("Model size: %,d bytes (%.1f KB)%n", Files.size(output), Files.size(output) / 1024.0); System.out.println("Done."); @@ -726,159 +718,218 @@ public class TrainJunkModel { } // ----------------------------------------------------------------------- - // v6 model save + carrier - // ----------------------------------------------------------------------- - - // ----------------------------------------------------------------------- - // v6 Phase 1: global codepoint-hash training + // v7 Phase 1: per-script open-addressing F1 table training // ----------------------------------------------------------------------- /** - * Single 
global pass over all training .gz files: counts codepoint - * pairs into a {@link #V6_BIGRAM_BUCKETS}-sized hash table, codepoint - * unigrams into a {@link #V6_UNIGRAM_BUCKETS}-sized table, and inserts - * every observed pair into a Bloom filter sized at {@code bloomBits}. - * Quantizes both log-prob tables to 8-bit unsigned. Returns a - * {@link org.apache.tika.ml.junkdetect.F1Tables} ready to - * hand to {@link #saveModelV6}. + * Builds the {@link V7Tables} F1 carrier for one script's training data. * - * <p>Hash function and Bloom-filter scheme are identical to those used at - * inference time in {@link org.apache.tika.ml.junkdetect.JunkDetector#computeF1MeanLogP} - * — same seed, same FNV-1a polynomial, same double-hashing layout — - * so the trained Bloom filter accurately reflects which pairs the - * scorer will treat as "seen". - */ - public static F1Tables trainCodepointHashTables(List<Path> trainFiles, int bloomBits) - throws IOException { - return trainCodepointHashTables(trainFiles, bloomBits, 1); - } - - /** - * Same as the 2-arg overload, but only bigrams with global per-pair count - * >= {@code minBigramCount} contribute to the bigram hash table and the - * Bloom filter. Unigrams are always counted (used for backoff). When - * {@code minBigramCount == 1} this is a no-op and the single-pass code - * path runs. Otherwise a first pass tallies per-pair counts into an - * in-memory map and a second pass emits only frequent pairs. + * <p>Two-pass: + * <ol> + * <li><b>Pass 1.</b> Count every (cpA, cpB) pair occurrence and every + * cp unigram occurrence in the script's {@code *.train.gz} file. + * Pairs with count {@code < minBigramCount} are dropped at this + * step — they're typically OCR artifacts and proper-noun noise.</li> + * <li><b>Pass 2.</b> Collect every codepoint that appears in any + * kept pair (as either side), sort, assign each a dense small + * index. 
Build a power-of-two open-addressing hash table sized + * for {@code keptPairs / loadFactor}; pack each retained + * {@code (idxA, idxB)} into a 32-bit key and insert via linear + * probing. Quantize both bigram log-probs and unigram log-probs + * to 8-bit.</li> + * </ol> + * + * <p>Returned {@link V7Tables} are ready to hand to + * {@link #saveModelV7}. + * + * @param trainFile the per-script {@code *.train.gz} + * @param minBigramCount drop pairs whose count is below this + * @param loadFactor target OA table load factor (e.g. 0.5) + * @param keyIndexBits bit-width per index in the packed key + * (each side of the pair must fit) */ - public static F1Tables trainCodepointHashTables(List<Path> trainFiles, int bloomBits, - int minBigramCount) throws IOException { - long[] bigramCounts = new long[V6_BIGRAM_BUCKETS]; - long[] unigramCounts = new long[V6_UNIGRAM_BUCKETS]; + public static V7Tables trainV7TablesForScript(Path trainFile, + int minBigramCount, + double loadFactor, + int keyIndexBits) throws IOException { + // --- Pass 1: tally pair and unigram counts. 
--- + HashMap<Long, long[]> pairCounts = new HashMap<>(1 << 14); + HashMap<Integer, long[]> unigramCounts = new HashMap<>(1 << 12); long bigramTotal = 0; long unigramTotal = 0; - long[] bloomBitArr = new long[(bloomBits + 63) >> 6]; - - HashMap<Long, long[]> pairTallies = null; - if (minBigramCount > 1) { - System.out.printf(" pre-pass: tallying per-pair counts " - + "(min_bigram_count=%d)%n", minBigramCount); - pairTallies = new HashMap<>(1 << 18); - for (Path trainFile : trainFiles) { - try (BufferedReader r = openGzipped(trainFile)) { - String line; - while ((line = r.readLine()) != null) { - int prevCp = -1; - for (int i = 0; i < line.length(); ) { - int cp = line.codePointAt(i); - i += Character.charCount(cp); - if (prevCp >= 0) { - long packed = ((long) prevCp << 24) | (cp & 0xFFFFFFL); - long[] c = pairTallies.get(packed); - if (c == null) { - pairTallies.put(packed, new long[]{1L}); - } else { - c[0]++; - } - } - prevCp = cp; + + try (BufferedReader r = openGzipped(trainFile)) { + String line; + while ((line = r.readLine()) != null) { + int prevCp = -1; + for (int i = 0; i < line.length(); ) { + int cp = line.codePointAt(i); + i += Character.charCount(cp); + long[] uc = unigramCounts.get(cp); + if (uc == null) { + unigramCounts.put(cp, new long[]{1L}); + } else { + uc[0]++; + } + unigramTotal++; + if (prevCp >= 0) { + long packed = ((long) prevCp << 32) | (cp & 0xFFFFFFFFL); + long[] bc = pairCounts.get(packed); + if (bc == null) { + pairCounts.put(packed, new long[]{1L}); + } else { + bc[0]++; } + bigramTotal++; } + prevCp = cp; } } - int kept = 0; - int dropped = 0; - for (long[] c : pairTallies.values()) { - if (c[0] >= minBigramCount) kept++; - else dropped++; - } - System.out.printf(" pre-pass: distinct pairs=%,d kept=%,d dropped=%,d%n", - pairTallies.size(), kept, dropped); } - for (Path trainFile : trainFiles) { - try (BufferedReader r = openGzipped(trainFile)) { - String line; - while ((line = r.readLine()) != null) { - int prevCp = -1; - for (int i 
= 0; i < line.length(); ) { - int cp = line.codePointAt(i); - i += Character.charCount(cp); - int uBucket = (int) (fnv1aUnigramV6(cp, V6_FNV_SEED) - & (V6_UNIGRAM_BUCKETS - 1)); - unigramCounts[uBucket]++; - unigramTotal++; - if (prevCp >= 0) { - boolean accept = true; - if (pairTallies != null) { - long packed = ((long) prevCp << 24) | (cp & 0xFFFFFFL); - long[] c = pairTallies.get(packed); - accept = c != null && c[0] >= minBigramCount; - } - if (accept) { - int bBucket = (int) (fnv1aBigramV6(prevCp, cp, V6_FNV_SEED) - & (V6_BIGRAM_BUCKETS - 1)); - bigramCounts[bBucket]++; - bigramTotal++; - bloomAddV6(bloomBitArr, bloomBits, V6_BLOOM_K, - prevCp, cp, V6_FNV_SEED); - } - } - prevCp = cp; - } - } - } + // --- Filter pairs by count, collect kept-codepoint set. --- + int totalDistinct = pairCounts.size(); + int keptPairs = 0; + long keptBigramTotal = 0; + java.util.TreeSet<Integer> keptCodepoints = new java.util.TreeSet<>(); + for (Map.Entry<Long, long[]> e : pairCounts.entrySet()) { + if (e.getValue()[0] < minBigramCount) continue; + keptPairs++; + keptBigramTotal += e.getValue()[0]; + long packed = e.getKey(); + int cpA = (int) (packed >>> 32); + int cpB = (int) (packed & 0xFFFFFFFFL); + keptCodepoints.add(cpA); + keptCodepoints.add(cpB); } + int dropped = totalDistinct - keptPairs; - // Add-α smoothing → log-prob → 8-bit quantize. - float[] bigramLogP = new float[V6_BIGRAM_BUCKETS]; - double bigramDenom = bigramTotal + V6_ADD_ALPHA * V6_BIGRAM_BUCKETS; - for (int i = 0; i < V6_BIGRAM_BUCKETS; i++) { - double p = (bigramCounts[i] + V6_ADD_ALPHA) / bigramDenom; - bigramLogP[i] = (float) Math.log(p); + // --- Build sorted codepoint index. 
--- + int[] cpIndex = new int[keptCodepoints.size()]; + int idx = 0; + for (int cp : keptCodepoints) { + cpIndex[idx++] = cp; } - float[] unigramLogP = new float[V6_UNIGRAM_BUCKETS]; - double unigramDenom = unigramTotal + V6_ADD_ALPHA * V6_UNIGRAM_BUCKETS; - for (int i = 0; i < V6_UNIGRAM_BUCKETS; i++) { - double p = (unigramCounts[i] + V6_ADD_ALPHA) / unigramDenom; + // Enforce the indexable-bits contract. + int maxIndex = (1 << keyIndexBits) - 1; + if (cpIndex.length > maxIndex + 1) { + throw new IllegalStateException("Per-script codepoint count " + + cpIndex.length + " exceeds 2^KEY_INDEX_BITS (= " + + (maxIndex + 1) + "). Increase KEY_INDEX_BITS or apply" + + " a tighter pair-count filter for " + + trainFile.getFileName()); + } + + // --- Compute per-pair log-prob (add-α smoothed over kept pairs). --- + // Denominator: kept-bigram total + α × keptPairs (only pairs we store). + double bigramDenom = keptBigramTotal + V7_ADD_ALPHA * keptPairs; + // Unigram log-probs. We keep one entry per indexed codepoint; the + // denominator uses ALL unigram observations (kept pairs only would + // bias the backoff toward common pairs). + double unigramDenom = unigramTotal + V7_ADD_ALPHA * unigramCounts.size(); + float[] unigramLogP = new float[cpIndex.length]; + for (int i = 0; i < cpIndex.length; i++) { + long[] uc = unigramCounts.get(cpIndex[i]); + long count = uc != null ? uc[0] : 0L; + double p = (count + V7_ADD_ALPHA) / unigramDenom; unigramLogP[i] = (float) Math.log(p); } + // Per-script "absent codepoint" fallback: the add-α smoothed unigram + // log-prob of a codepoint with zero observations. A codepoint *not* + // in our index is treated as having count 0, so: + double fallbackP = V7_ADD_ALPHA / unigramDenom; + float unigramFallbackLogP = (float) Math.log(fallbackP); - QuantizedFloats qBigram = quantizeFloats(bigramLogP); + // Quantize unigram log-probs.
QuantizedFloats qUnigram = quantizeFloats(unigramLogP); - return new F1Tables( - qBigram.bytes, V6_BIGRAM_BUCKETS, + // --- Build the open-addressing bigram table. --- + int slots = nextPowerOfTwo((int) Math.max(2, Math.ceil(keptPairs / loadFactor))); + int[] keys = new int[slots]; + java.util.Arrays.fill(keys, V7Tables.EMPTY_KEY); + // Compute log-probs first, quantize once, then write into the table + // alongside its key. + float[] keptLogP = new float[keptPairs]; + int[] keptKeys = new int[keptPairs]; + int writeIdx = 0; + // codepoint -> index lookup helper (small map keyed by Integer) + HashMap<Integer, Integer> cpToIdx = new HashMap<>(cpIndex.length * 2); + for (int i = 0; i < cpIndex.length; i++) { + cpToIdx.put(cpIndex[i], i); + } + for (Map.Entry<Long, long[]> e : pairCounts.entrySet()) { + long count = e.getValue()[0]; + if (count < minBigramCount) continue; + long packed = e.getKey(); + int cpA = (int) (packed >>> 32); + int cpB = (int) (packed & 0xFFFFFFFFL); + int idxA = cpToIdx.get(cpA); + int idxB = cpToIdx.get(cpB); + int packedKey = JunkDetector.packBigramKey(idxA, idxB); + double p = (count + V7_ADD_ALPHA) / bigramDenom; + keptKeys[writeIdx] = packedKey; + keptLogP[writeIdx] = (float) Math.log(p); + writeIdx++; + } + // Quantize all kept log-probs together so they share min/max. 
+ QuantizedFloats qBigram = quantizeFloats(keptLogP); + byte[] values = new byte[slots]; + for (int i = 0; i < keptPairs; i++) { + insertOA(keys, values, keptKeys[i], qBigram.bytes[i]); + } + + System.out.printf( + " pair_counts: distinct=%,d, kept=%,d (>=%d), dropped=%,d " + + "cp_index=%,d slots=%,d (load=%.2f)%n", + totalDistinct, keptPairs, minBigramCount, dropped, + cpIndex.length, slots, keptPairs / (double) slots); + + return new V7Tables(cpIndex, keys, values, qUnigram.bytes, qBigram.min, qBigram.max, - qUnigram.bytes, V6_UNIGRAM_BUCKETS, qUnigram.min, qUnigram.max, - bloomBitArr, bloomBits, V6_BLOOM_K, - V6_FNV_SEED, V6_BACKOFF_ALPHA); + unigramFallbackLogP, V7_BACKOFF_ALPHA); + } + + /** + * Inserts a {@code (packedKey, value)} pair into the open-addressing + * table. The caller is responsible for sizing the table large enough + * to avoid an infinite probe (any load < 1.0 is safe). + */ + private static void insertOA(int[] keys, byte[] values, int packedKey, byte value) { + int mask = keys.length - 1; + int h = JunkDetector.mixIndexKey(packedKey) & mask; + while (keys[h] != V7Tables.EMPTY_KEY) { + if (keys[h] == packedKey) { + // Same key twice — shouldn't happen with our dedup, but be + // defensive and overwrite rather than corrupt. + values[h] = value; + return; + } + h = (h + 1) & mask; + } + keys[h] = packedKey; + values[h] = value; + } + + private static int nextPowerOfTwo(int n) { + if (n < 1) return 1; + int p = Integer.highestOneBit(n - 1) << 1; + return Math.max(1, p); } /** * Computes per-script F1 calibration ({mu, sigma}) by scoring each - * window in the dev file against the trained codepoint-hash tables - * and collecting the mean log-prob distribution. Delegates to + * window in the dev file against the trained per-script codepoint + * tables. Delegates to * {@link org.apache.tika.ml.junkdetect.JunkDetector#computeF1MeanLogP} * — the single authoritative F1 implementation shared between training * and inference. 
*/ - public static float[] calibrateF1PerScript(Path devGz, F1Tables f1) throws IOException { + public static float[] calibrateF1PerScript(Path devGz, V7Tables tables) throws IOException { List<String> windows = sampleSubstrings(devGz, CALIB_SAMPLES, CALIB_LENGTHS, 42); List<Double> scores = new ArrayList<>(windows.size()); for (String window : windows) { - double score = JunkDetector.computeF1MeanLogP(window, f1); + double score = JunkDetector.computeF1MeanLogP(window, tables); if (!Double.isNaN(score)) { scores.add(score); } @@ -887,38 +938,20 @@ public class TrainJunkModel { return muSigma(scores); } - /** Bloom membership check (matches JunkDetector.bloomContains semantics). */ - public static boolean bloomContainsV6(long[] bloomBits, int bitCount, int k, - int cpA, int cpB, int seed) { - long h1 = fnv1aBigramV6(cpA, cpB, seed); - long h2 = secondaryHashV6(cpA, cpB); - for (int i = 0; i < k; i++) { - long pos = ((h1 + (long) i * h2) & 0x7FFFFFFFFFFFFFFFL) % bitCount; - if ((bloomBits[(int) (pos >>> 6)] & (1L << (pos & 63))) == 0) { - return false; - } - } - return true; - } - - private static float dequantize(byte b, float min, float max) { - return min + ((b & 0xFF) / 255.0f) * (max - min); - } - // ----------------------------------------------------------------------- - // v6 Phase 3: classifier feature extractor + orchestrator + // v7 Phase 3: classifier feature extractor + orchestrator // ----------------------------------------------------------------------- /** * Extracts a 4-dim calibrated z-score vector for one training window - * using the codepoint-hash architecture. z2/z3/z4 delegate to the - * public helpers on {@link JunkDetector} — same math used at inference, - * no trainer/inference drift possible. + * using the v7 per-script tables. z2/z3/z4 delegate to the public + * helpers on {@link JunkDetector} — same math used at inference, no + * trainer/inference drift possible. 
* * @return float[4] = {z1_cpHash, z2_block, z3_control, z4_scriptTrans} */ - static float[] extractFeaturesV6(String window, - F1Tables f1, float[] f1Cal, + static float[] extractFeaturesV7(String window, + V7Tables tables, float[] f1Cal, float[] blockTable, float[] blockCal, float[] controlCal, float[] scriptTransTable, float[] scriptTransCal, @@ -926,9 +959,9 @@ public class TrainJunkModel { int numScriptBuckets) { byte[] utf8 = window.getBytes(StandardCharsets.UTF_8); - // z1: codepoint-hash mean log-prob, per-script-calibrated + // z1: per-script codepoint-bigram mean log-prob float z1 = 0f; - double rawF1 = JunkDetector.computeF1MeanLogP(window, f1); + double rawF1 = JunkDetector.computeF1MeanLogP(window, tables); if (!Double.isNaN(rawF1) && f1Cal != null && f1Cal[1] > 0) { z1 = ((float) rawF1 - f1Cal[0]) / f1Cal[1]; } @@ -946,12 +979,12 @@ public class TrainJunkModel { /** * Trains a per-script binary logistic regression classifier on - * (z1_cpHash, z2, z3, z4). Mirrors the original {@code trainClassifier} - * scaffolding (sample windows, corrupt half, fit LR, bias-calibrate - * on short windows) but uses the codepoint-hash feature extractor. + * (z1_cpHash, z2, z3, z4). Same scaffolding as the v6 trainer + * (sample windows, corrupt half, fit LR, bias-calibrate on short + * windows) but uses v7 per-script F1 tables. 
*/ - static float[] trainClassifierV6(Path devGz, - F1Tables f1, float[] f1Cal, + static float[] trainClassifierV7(Path devGz, + V7Tables tables, float[] f1Cal, float[] blockTable, float[] blockCal, float[] controlCal, float[] scriptTransTable, float[] scriptTransCal, @@ -995,13 +1028,13 @@ public class TrainJunkModel { List<Integer> labels = new ArrayList<>(cleanWindows.size() + corruptedWindows.size()); for (String w : cleanWindows) { - features.add(extractFeaturesV6(w, f1, f1Cal, + features.add(extractFeaturesV7(w, tables, f1Cal, blockTable, blockCal, controlCal, scriptTransTable, scriptTransCal, scriptBucketMap, numScriptBuckets)); labels.add(1); } for (String w : corruptedWindows) { - features.add(extractFeaturesV6(w, f1, f1Cal, + features.add(extractFeaturesV7(w, tables, f1Cal, blockTable, blockCal, controlCal, scriptTransTable, scriptTransCal, scriptBucketMap, numScriptBuckets)); labels.add(0); @@ -1014,7 +1047,7 @@ public class TrainJunkModel { List<Float> shortLogits = new ArrayList<>(shortWindows.size()); int nFeat = weights.length - 1; for (String w : shortWindows) { - float[] x = extractFeaturesV6(w, f1, f1Cal, + float[] x = extractFeaturesV7(w, tables, f1Cal, blockTable, blockCal, controlCal, scriptTransTable, scriptTransCal, scriptBucketMap, numScriptBuckets); float logit = weights[nFeat]; @@ -1032,22 +1065,19 @@ public class TrainJunkModel { } /** - * Writes a v6 model file (JUNKDET1 version=6 gzipped binary). + * Writes a v7 model file (JUNKDET1 version=7 gzipped binary). * - * <p>Layout differs from v5 in two ways: - * <ol> - * <li>A new global F1 section after the script-transition section, - * holding the codepoint-bigram hash + Bloom + unigram-backoff - * tables.</li> - * <li>Per-script section drops the 65,536-float byte-bigram table. - * The {@code mu1}/{@code sigma1} calibration fields remain (now - * calibrated on the codepoint-hash mean log-prob, not byte-bigram).</li> - * </ol> + * <p>Layout vs. v6: no global F1+Bloom section. 
Each per-script + * section embeds that script's {@link V7Tables} (codepoint index, + * open-addressing bigram keys+values, unigram table) directly after + * its F1 calibration, before F2. See {@link JunkDetector#load} for + * the full layout spec. * - * F2 (block transition), F3 (control byte), F4 (script transition) - * sections are unchanged from v5 — pass v5-trained tables through. + * <p>F2 (block transition), F3 (control byte), F4 (script transition) + * sections are unchanged from v6. */ - public static void saveModelV6(TreeMap<String, float[]> f1Calibrations, + public static void saveModelV7(TreeMap<String, V7Tables> f1Tables, + TreeMap<String, float[]> f1Calibrations, TreeMap<String, float[]> blockTables, TreeMap<String, float[]> blockCalibrations, TreeMap<String, float[]> controlCalibrations, @@ -1055,7 +1085,6 @@ public class TrainJunkModel { List<String> scriptBuckets, float[] scriptTransTable, float[] scriptTransCal, - F1Tables v6, Path output) throws IOException { try (DataOutputStream dos = new DataOutputStream( new GZIPOutputStream(Files.newOutputStream(output)))) { @@ -1081,15 +1110,15 @@ public class TrainJunkModel { dos.writeFloat(scriptTransCal[0]); dos.writeFloat(scriptTransCal[1]); - // Global F1 section (v6+) - v6.writeTo(dos); - - // Per-script section — v6 drops the per-script byte-bigram table. - // mu1/sigma1 remain (calibrated on the codepoint-hash score). + // Per-script sections. V7 embeds the F1 tables inline. 
int blockN = org.apache.tika.ml.junkdetect.UnicodeBlockRanges.bucketCount(); for (var entry : f1Calibrations.entrySet()) { String script = entry.getKey(); float[] f1Cal = entry.getValue(); + V7Tables tables = f1Tables.get(script); + if (tables == null) { + throw new IllegalStateException("No V7Tables for script " + script); + } float[] blockTable = blockTables.getOrDefault(script, new float[blockN * blockN]); float[] blockCal = blockCalibrations.getOrDefault(script, new float[]{0f, 1f}); float[] controlCal = controlCalibrations.getOrDefault(script, new float[]{0f, 1f}); @@ -1100,10 +1129,13 @@ public class TrainJunkModel { dos.writeShort(nameBytes.length); dos.write(nameBytes); - // F1 calibration only (no byte-bigram table in v6) + // F1 calibration dos.writeFloat(f1Cal[0]); dos.writeFloat(f1Cal[1]); + // F1 per-script tables + tables.writeTo(dos); + // F2 — block transitions dos.writeFloat(blockCal[0]); dos.writeFloat(blockCal[1]); @@ -1170,57 +1202,6 @@ public class TrainJunkModel { } } - // ----------------------------------------------------------------------- - // FNV-1a hash + Bloom utilities (v6 — must match JunkDetector inference) - // ----------------------------------------------------------------------- - - private static final long V6_FNV_OFFSET = 0xcbf29ce484222325L; - private static final long V6_FNV_PRIME = 0x100000001b3L; - - public static long fnv1aBigramV6(int cpA, int cpB, int seed) { - long h = V6_FNV_OFFSET; - h = (h ^ (seed & 0xFF)) * V6_FNV_PRIME; - h = (h ^ ((cpA >>> 24) & 0xFF)) * V6_FNV_PRIME; - h = (h ^ ((cpA >>> 16) & 0xFF)) * V6_FNV_PRIME; - h = (h ^ ((cpA >>> 8) & 0xFF)) * V6_FNV_PRIME; - h = (h ^ (cpA & 0xFF)) * V6_FNV_PRIME; - h = (h ^ 0xFF) * V6_FNV_PRIME; - h = (h ^ ((cpB >>> 24) & 0xFF)) * V6_FNV_PRIME; - h = (h ^ ((cpB >>> 16) & 0xFF)) * V6_FNV_PRIME; - h = (h ^ ((cpB >>> 8) & 0xFF)) * V6_FNV_PRIME; - h = (h ^ (cpB & 0xFF)) * V6_FNV_PRIME; - return h; - } - - public static long fnv1aUnigramV6(int cp, int seed) { - long h = 
V6_FNV_OFFSET; - h = (h ^ (seed & 0xFF)) * V6_FNV_PRIME; - h = (h ^ ((cp >>> 24) & 0xFF)) * V6_FNV_PRIME; - h = (h ^ ((cp >>> 16) & 0xFF)) * V6_FNV_PRIME; - h = (h ^ ((cp >>> 8) & 0xFF)) * V6_FNV_PRIME; - h = (h ^ (cp & 0xFF)) * V6_FNV_PRIME; - return h; - } - - public static long secondaryHashV6(int cpA, int cpB) { - long h = 0xff51afd7ed558ccdL; - h = (h ^ Integer.reverse(cpA)) * 0xc4ceb9fe1a85ec53L; - h = (h ^ Integer.reverse(cpB)) * 0xc4ceb9fe1a85ec53L; - h ^= h >>> 33; - return h; - } - - /** Adds the codepoint pair to the Bloom filter using double-hashing. */ - public static void bloomAddV6(long[] bloomBits, int bitCount, int k, - int cpA, int cpB, int seed) { - long h1 = fnv1aBigramV6(cpA, cpB, seed); - long h2 = secondaryHashV6(cpA, cpB); - for (int i = 0; i < k; i++) { - long pos = ((h1 + (long) i * h2) & 0x7FFFFFFFFFFFFFFFL) % bitCount; - bloomBits[(int) (pos >>> 6)] |= 1L << (pos & 63); - } - } - // ----------------------------------------------------------------------- // Helpers // ----------------------------------------------------------------------- diff --git a/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin b/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin index 3d76d288fb..644d46bad0 100644 Binary files a/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin and b/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin differ diff --git a/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorV6Test.java b/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorV6Test.java deleted file mode 100644 index 1b35554e40..0000000000 --- a/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorV6Test.java +++ /dev/null @@ -1,361 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * 
contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.tika.ml.junkdetect; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import java.io.BufferedWriter; -import java.io.IOException; -import java.io.OutputStreamWriter; -import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.Arrays; -import java.util.List; -import java.util.TreeMap; -import java.util.zip.GZIPOutputStream; - -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.io.TempDir; - -import org.apache.tika.ml.junkdetect.tools.TrainJunkModel; -import org.apache.tika.quality.TextQualityScore; - -/** - * Validates the v6 model file format end-to-end: a synthetic small model is - * constructed in-memory with known hash-table values, saved via - * {@link TrainJunkModel#saveModelV6}, loaded via {@link JunkDetector#load}, - * scored against known input, and the output verified against hand-computed - * expected values. - * - * <p>This is the architectural-decision validation: it confirms that the v6 - * file format spec, the trainer's save path, the loader, and the scoring - * path (hashed codepoint-bigram + Bloom + unigram backoff) all agree on the - * semantics. 
Does not require the production training corpus. - */ -public class JunkDetectorV6Test { - - @Test - void v6RoundTripSeenPairAndUnigramBackoff(@TempDir Path tmp) throws IOException { - final int seed = TrainJunkModel.V6_FNV_SEED; - - // ----------------------------------------------------------------- - // Build a tiny synthetic v6 model. - // - // Bigram table: floor at -10.0 nat, bucket for (A,B) at -1.0 nat. - // Unigram table: floor at -5.0 nat, buckets for A and B at -2.0 nat. - // Bloom: contains only (A,B). (B,A) takes the unigram-backoff path. - // ----------------------------------------------------------------- - - int bigramBuckets = 4096; - float[] bigramLog = new float[bigramBuckets]; - Arrays.fill(bigramLog, -10.0f); - int bucketAB = (int) (TrainJunkModel.fnv1aBigramV6('A', 'B', seed) - & (bigramBuckets - 1)); - bigramLog[bucketAB] = -1.0f; - TrainJunkModel.QuantizedFloats qBigram = TrainJunkModel.quantizeFloats(bigramLog); - - int unigramBuckets = 8192; - float[] unigramLog = new float[unigramBuckets]; - Arrays.fill(unigramLog, -5.0f); - int bucketA = (int) (TrainJunkModel.fnv1aUnigramV6('A', seed) - & (unigramBuckets - 1)); - int bucketB = (int) (TrainJunkModel.fnv1aUnigramV6('B', seed) - & (unigramBuckets - 1)); - unigramLog[bucketA] = -2.0f; - unigramLog[bucketB] = -2.0f; - TrainJunkModel.QuantizedFloats qUnigram = TrainJunkModel.quantizeFloats(unigramLog); - - int bloomBits = 1024; - int bloomK = 3; - long[] bloom = new long[(bloomBits + 63) >> 6]; - TrainJunkModel.bloomAddV6(bloom, bloomBits, bloomK, 'A', 'B', seed); - - F1Tables v6Tables = new F1Tables( - qBigram.bytes, bigramBuckets, qBigram.min, qBigram.max, - qUnigram.bytes, unigramBuckets, qUnigram.min, qUnigram.max, - bloom, bloomBits, bloomK, seed, 1.0f); - - // ----------------------------------------------------------------- - // Per-script F2/F3/F4 placeholders — all zeros, but with valid - // calibrations (mu=0, sigma=1). 
Classifier weights for LATIN - // make ONLY z1 contribute (w1=1, w2=w3=w4=0, bias=0), so the - // expected z-score isolates the v6 F1 codepoint-hash path. - // ----------------------------------------------------------------- - - TreeMap<String, float[]> f1Cal = new TreeMap<>(); - f1Cal.put("LATIN", new float[]{-5.0f, 1.0f}); - - int blockN = UnicodeBlockRanges.bucketCount(); - - TreeMap<String, float[]> blockTables = new TreeMap<>(); - blockTables.put("LATIN", new float[blockN * blockN]); - TreeMap<String, float[]> blockCal = new TreeMap<>(); - blockCal.put("LATIN", new float[]{0f, 1f}); - - TreeMap<String, float[]> controlCal = new TreeMap<>(); - controlCal.put("LATIN", new float[]{0f, 1f}); - - List<String> scriptBuckets = List.of("LATIN", "OTHER"); - float[] scriptTransTable = new float[scriptBuckets.size() * scriptBuckets.size()]; - float[] scriptTransCal = new float[]{0f, 1f}; - - TreeMap<String, float[]> classifierWeights = new TreeMap<>(); - classifierWeights.put("LATIN", new float[]{1.0f, 0.0f, 0.0f, 0.0f, 0.0f}); - - Path modelFile = tmp.resolve("v6-test.bin"); - TrainJunkModel.saveModelV6( - f1Cal, blockTables, blockCal, controlCal, classifierWeights, - scriptBuckets, scriptTransTable, scriptTransCal, - v6Tables, modelFile); - - assertTrue(Files.size(modelFile) > 0, "Saved model should be non-empty"); - - // ----------------------------------------------------------------- - // Load and verify version. - // ----------------------------------------------------------------- - - JunkDetector detector = JunkDetector.loadFromPath(modelFile); - assertEquals(6, detector.getModelVersion(), "Loaded model should be v6"); - - // ----------------------------------------------------------------- - // Score "ABAB". 
Expected: - // Pair (A, B): in Bloom → bigram table → -1.0 - // Pair (B, A): not in Bloom → unigram backoff = 1.0 * (-2 + -2) = -4.0 - // Pair (A, B): in Bloom → -1.0 - // mean log-prob = (-1 + -4 + -1) / 3 = -2.0 - // z1 = (-2 - (-5)) / 1 = +3.0 - // logit = 1.0 * 3.0 + 0 + 0 + 0 + 0 = +3.0 - // - // Tolerance: 8-bit quantization of bigram table [-10, -1] gives - // ~0.035 nat per level; of unigram table [-5, -2] gives ~0.012 nat - // per level. Net z-score error is bounded by ~0.1 over 3 pairs. - // Allow 0.3 tolerance to be safe. - // ----------------------------------------------------------------- - - TextQualityScore score = detector.score("ABAB"); - assertEquals("LATIN", score.getDominantScript(), "Dominant script should be LATIN"); - assertEquals(3.0f, score.getZScore(), 0.3f, - "Expected z ≈ +3.0 for 'ABAB' (seen-pair + backoff mix)"); - } - - @Test - void v6RoundTripAllSeenPairsScoreHigher(@TempDir Path tmp) throws IOException { - // Same shape as the first test but with ALL pairs in the Bloom. - // mean log-prob = -1.0, z1 = +4.0. Verifies seen-only path. 
- final int seed = TrainJunkModel.V6_FNV_SEED; - - int bigramBuckets = 4096; - float[] bigramLog = new float[bigramBuckets]; - Arrays.fill(bigramLog, -10.0f); - // Put both (A,B) and (B,A) at -1.0 - int bucketAB = (int) (TrainJunkModel.fnv1aBigramV6('A', 'B', seed) - & (bigramBuckets - 1)); - int bucketBA = (int) (TrainJunkModel.fnv1aBigramV6('B', 'A', seed) - & (bigramBuckets - 1)); - bigramLog[bucketAB] = -1.0f; - bigramLog[bucketBA] = -1.0f; - TrainJunkModel.QuantizedFloats qBigram = TrainJunkModel.quantizeFloats(bigramLog); - - int unigramBuckets = 8192; - float[] unigramLog = new float[unigramBuckets]; - Arrays.fill(unigramLog, -5.0f); - TrainJunkModel.QuantizedFloats qUnigram = TrainJunkModel.quantizeFloats(unigramLog); - - int bloomBits = 1024; - int bloomK = 3; - long[] bloom = new long[(bloomBits + 63) >> 6]; - TrainJunkModel.bloomAddV6(bloom, bloomBits, bloomK, 'A', 'B', seed); - TrainJunkModel.bloomAddV6(bloom, bloomBits, bloomK, 'B', 'A', seed); - - F1Tables v6Tables = new F1Tables( - qBigram.bytes, bigramBuckets, qBigram.min, qBigram.max, - qUnigram.bytes, unigramBuckets, qUnigram.min, qUnigram.max, - bloom, bloomBits, bloomK, seed, 1.0f); - - Path modelFile = tmp.resolve("v6-test-allseen.bin"); - saveMinimalV6Model(v6Tables, modelFile); - JunkDetector detector = JunkDetector.loadFromPath(modelFile); - - TextQualityScore score = detector.score("ABAB"); - // mean = -1.0, z1 = +4.0, logit = +4.0 - assertEquals(4.0f, score.getZScore(), 0.3f, - "All-seen 'ABAB' should score z ≈ +4"); - } - - // ----------------------------------------------------------------------- - // Helper — minimal LATIN-only v6 model for tests that only need to - // exercise scoring of LATIN text. 
- // ----------------------------------------------------------------------- - - private static void saveMinimalV6Model(F1Tables v6, - Path modelFile) throws IOException { - TreeMap<String, float[]> f1Cal = new TreeMap<>(); - f1Cal.put("LATIN", new float[]{-5.0f, 1.0f}); - - int blockN = UnicodeBlockRanges.bucketCount(); - - TreeMap<String, float[]> blockTables = new TreeMap<>(); - blockTables.put("LATIN", new float[blockN * blockN]); - TreeMap<String, float[]> blockCal = new TreeMap<>(); - blockCal.put("LATIN", new float[]{0f, 1f}); - - TreeMap<String, float[]> controlCal = new TreeMap<>(); - controlCal.put("LATIN", new float[]{0f, 1f}); - - List<String> scriptBuckets = List.of("LATIN", "OTHER"); - float[] scriptTransTable = new float[scriptBuckets.size() * scriptBuckets.size()]; - float[] scriptTransCal = new float[]{0f, 1f}; - - TreeMap<String, float[]> classifierWeights = new TreeMap<>(); - classifierWeights.put("LATIN", new float[]{1.0f, 0.0f, 0.0f, 0.0f, 0.0f}); - - TrainJunkModel.saveModelV6( - f1Cal, blockTables, blockCal, controlCal, classifierWeights, - scriptBuckets, scriptTransTable, scriptTransCal, - v6, modelFile); - } - - /** - * End-to-end trainer integration: drives - * {@link TrainJunkModel#trainCodepointHashTables} on a tiny synthetic - * corpus, calibrates F1, saves a model, loads it via - * {@link JunkDetector#load}, and scores text. Catches drift between - * trainer F1 math and inference F1 math — the Bloom-filter hash - * scheme, FNV seed, quantization range, and codepoint-pair iteration - * order all have to agree exactly, or scoring produces nonsense. - * - * <p>F2/F3/F4 are zeroed out (placeholder data) — the test isolates - * F1's trainer↔inference round-trip. The actual retrain (with real - * F2/F3/F4 training data) is the training-phase work. - */ - @Test - void trainerRoundTripIntegration(@TempDir Path tmp) throws IOException { - // --- 1. 
Build a tiny LATIN corpus on disk --- - Path trainFile = tmp.resolve("LATIN.train.gz"); - writeGzippedLines(trainFile, - "the quick brown fox jumps over the lazy dog", - "pack my box with five dozen liquor jugs", - "how vexingly quick daft zebras jump", - "the five boxing wizards jump quickly", - "sphinx of black quartz judge my vow"); - Path devFile = tmp.resolve("LATIN.dev.gz"); - writeGzippedLines(devFile, - "the rain in spain falls mainly on the plain", - "a stitch in time saves nine", - "all that glitters is not gold"); - - // --- 2. Phase 1: train codepoint-hash tables --- - // Use a small Bloom (64 KB) — the synthetic corpus has only a - // few hundred unique pairs. - F1Tables f1 = TrainJunkModel.trainCodepointHashTables( - List.of(trainFile), 524288); - - // Sanity: Bloom should contain pairs we observed in training. - // "the" → pairs (t,h) and (h,e); "fox" → (f,o), (o,x). - assertTrue(TrainJunkModel.bloomContainsV6( - f1.bloomBits, f1.bloomBitCount, f1.bloomK, - 't', 'h', f1.fnvSeed), - "Bloom should contain (t, h) — appears in training"); - assertTrue(TrainJunkModel.bloomContainsV6( - f1.bloomBits, f1.bloomBitCount, f1.bloomK, - 'o', 'x', f1.fnvSeed), - "Bloom should contain (o, x) — appears in training"); - - // --- 3. F1 raw scoring sanity --- - double meanLogP = JunkDetector.computeF1MeanLogP( - "the quick brown fox", f1); - assertTrue(Double.isFinite(meanLogP), - "Mean log-prob on training text should be finite, got " + meanLogP); - // A score in [-10, 0] is the expected range for in-distribution text. - assertTrue(meanLogP > -10 && meanLogP < 0, - "Score on training text should be sensible, got " + meanLogP); - - // --- 4. Phase 1.5: F1 calibration on dev --- - float[] f1CalLatin = TrainJunkModel.calibrateF1PerScript(devFile, f1); - assertTrue(Float.isFinite(f1CalLatin[0]), "mu1 should be finite"); - assertTrue(Float.isFinite(f1CalLatin[1]) && f1CalLatin[1] > 0, - "sigma1 should be positive finite"); - - // --- 5. 
Assemble + save a minimal v6 model --- - // F2/F3/F4 tables zeroed, classifier weights pure-F1 (w1=1, rest 0). - int blockN = UnicodeBlockRanges.bucketCount(); - TreeMap<String, float[]> blockTables = new TreeMap<>(); - blockTables.put("LATIN", new float[blockN * blockN]); - TreeMap<String, float[]> blockCal = new TreeMap<>(); - blockCal.put("LATIN", new float[]{0f, 1f}); - TreeMap<String, float[]> controlCal = new TreeMap<>(); - controlCal.put("LATIN", new float[]{0f, 1f}); - TreeMap<String, float[]> f1CalMap = new TreeMap<>(); - f1CalMap.put("LATIN", f1CalLatin); - TreeMap<String, float[]> classifierWeights = new TreeMap<>(); - classifierWeights.put("LATIN", new float[]{1f, 0f, 0f, 0f, 0f}); - - List<String> scriptBuckets = List.of("LATIN", "OTHER"); - float[] scriptTransTable = new float[scriptBuckets.size() * scriptBuckets.size()]; - float[] scriptTransCal = new float[]{0f, 1f}; - - Path modelPath = tmp.resolve("junkdetect.bin"); - TrainJunkModel.saveModelV6( - f1CalMap, blockTables, blockCal, controlCal, classifierWeights, - scriptBuckets, scriptTransTable, scriptTransCal, f1, modelPath); - - // --- 6. Load via JunkDetector and score --- - JunkDetector detector = JunkDetector.loadFromPath(modelPath); - assertEquals(6, detector.getModelVersion(), - "Loaded model should be v6"); - assertTrue(detector.knownScripts().contains("LATIN"), - "Loaded model should know LATIN"); - - // Score in-distribution text. With w1=1 and z2/z3/z4 forced to 0, - // the logit is purely z1 = (raw - mu)/sigma. A short window of - // in-distribution text should produce z1 in roughly [-2, +2] — - // not at the calibration extremes. - TextQualityScore score = detector.score("the quick brown fox jumps"); - assertEquals("LATIN", score.getDominantScript()); - assertTrue(Float.isFinite(score.getZScore()), - "Score on in-distribution text should be finite, got " + score); - - // --- 7. 
Train/infer consistency check --- - // The inference path should compute the same raw F1 score as - // JunkDetector.computeF1MeanLogP on the same text — if these - // two ever disagree, the model's calibration is silently wrong. - // We can verify indirectly: score same text using - // computeF1MeanLogP and re-derive z1 manually. - String probe = "pack my box with five dozen liquor jugs"; - double trainerRawMean = JunkDetector.computeF1MeanLogP(probe, f1); - float expectedZ1 = (float) (trainerRawMean - f1CalLatin[0]) / f1CalLatin[1]; - TextQualityScore probeScore = detector.score(probe); - // logit = w1 * z1 + 0 + 0 + 0 + 0 = z1 in this test configuration. - assertEquals(expectedZ1, probeScore.getZScore(), 0.001f, - "Inference z1 must match trainer-computed z1 " - + "(train/infer F1 math drift)"); - } - - // Writes one sentence per line, UTF-8, gzipped. - private static void writeGzippedLines(Path path, String... lines) throws IOException { - try (BufferedWriter w = new BufferedWriter(new OutputStreamWriter( - new GZIPOutputStream(Files.newOutputStream(path)), - StandardCharsets.UTF_8))) { - for (String line : lines) { - w.write(line); - w.write('\n'); - } - } - } -} diff --git a/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorV7Test.java b/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorV7Test.java new file mode 100644 index 0000000000..b846064c52 --- /dev/null +++ b/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorV7Test.java @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.tika.ml.junkdetect; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.BufferedWriter; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.List; +import java.util.TreeMap; +import java.util.zip.GZIPOutputStream; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import org.apache.tika.ml.junkdetect.tools.JunkDetectorTrainingConfig; +import org.apache.tika.ml.junkdetect.tools.TrainJunkModel; +import org.apache.tika.quality.TextQualityScore; + +/** + * Validates the v7 model file format end-to-end: a synthetic small model is + * constructed in-memory with known table values, saved via + * {@link TrainJunkModel#saveModelV7}, loaded via {@link JunkDetector#load}, + * scored against known input, and the output verified against hand-computed + * expected values. + * + * <p>This is the architectural-decision validation: it confirms that the v7 + * file format spec, the trainer's save path, the loader, and the scoring + * path (per-script open-addressing codepoint-bigram + unigram backoff) all + * agree on the semantics. Does not require the production training corpus. 
+ */ +public class JunkDetectorV7Test { + + @Test + void v7RoundTripSeenPairAndUnigramBackoff(@TempDir Path tmp) throws IOException { + // ----------------------------------------------------------------- + // Build a tiny synthetic v7 model for LATIN. + // + // codepointIndex = ['A', 'B'] (indices 0, 1) + // Pair (A, B) stored with log-prob -1.0 + // (B, A) is *not* in the bigram table — falls back to unigram. + // Unigram log-prob = -2.0 for both 'A' and 'B'. + // backoffAlpha = 1.0 → backoff sum = -4.0 + // + // Expected mean log-prob over "ABAB": + // (A,B) seen: -1.0 + // (B,A) backoff: 1.0 * (-2 + -2) = -4.0 + // (A,B) seen: -1.0 + // mean = -2.0 + // f1Cal mu=-5, sigma=1 → z1 = (-2 - -5) / 1 = +3.0 + // Classifier w1=1, rest 0, bias=0 → logit = +3.0 + // ----------------------------------------------------------------- + V7Tables tables = buildLatinTablesAB(); + + Path modelFile = tmp.resolve("v7-test.bin"); + saveMinimalV7Model(tables, modelFile); + + // Verify the file roundtrips through the loader. + JunkDetector detector = JunkDetector.loadFromPath(modelFile); + assertEquals(7, detector.getModelVersion(), "Loaded model should be v7"); + + TextQualityScore score = detector.score("ABAB"); + assertEquals("LATIN", score.getDominantScript(), "Dominant script should be LATIN"); + // 8-bit quantization: bigram [-10, -1] ≈ 0.035 nat/level, unigram [-5, -2] ≈ 0.012. + // Net z-error over 3 pairs bounded ~0.05; allow 0.3 to be safe. + assertEquals(3.0f, score.getZScore(), 0.3f, + "Expected z ≈ +3.0 for 'ABAB' (seen-pair + backoff mix)"); + } + + @Test + void v7RoundTripAllSeenPairsScoreHigher(@TempDir Path tmp) throws IOException { + // Same shape as the first test but with BOTH (A,B) and (B,A) in the + // bigram table. mean log-prob = -1.0, z1 = +4.0, logit = +4.0.
+ int[] cpIndex = new int[]{'A', 'B'}; + int[] keys = new int[4]; + Arrays.fill(keys, V7Tables.EMPTY_KEY); + byte[] values = new byte[4]; + float bMin = -10.0f; + float bMax = -1.0f; + byte b = quantizeOne(-1.0f, bMin, bMax); + insertOA(keys, values, JunkDetector.packBigramKey(0, 1), b); + insertOA(keys, values, JunkDetector.packBigramKey(1, 0), b); + + float uMin = -5.0f; + float uMax = -2.0f; + byte[] unigramBytes = new byte[]{ + quantizeOne(-2.0f, uMin, uMax), + quantizeOne(-2.0f, uMin, uMax), + }; + + V7Tables tables = new V7Tables(cpIndex, keys, values, unigramBytes, + bMin, bMax, uMin, uMax, + -10.0f, 1.0f); + + Path modelFile = tmp.resolve("v7-test-allseen.bin"); + saveMinimalV7Model(tables, modelFile); + JunkDetector detector = JunkDetector.loadFromPath(modelFile); + + TextQualityScore score = detector.score("ABAB"); + // mean = -1.0, z1 = (-1 - -5) / 1 = +4.0 + assertEquals(4.0f, score.getZScore(), 0.3f, + "All-seen 'ABAB' should score z ≈ +4"); + } + + /** + * End-to-end trainer integration: drives {@link + * TrainJunkModel#trainV7TablesForScript} on a tiny synthetic corpus, + * calibrates F1, saves a model, loads it, and scores text. Catches + * drift between trainer F1 math and inference F1 math — the FNV + * mix-hash, packed-key layout, and codepoint-pair iteration order all + * have to agree exactly, or scoring produces nonsense. + * + * <p>F2/F3/F4 are zeroed out (placeholder data) — the test isolates + * F1's trainer↔inference round-trip. + */ + @Test + void trainerRoundTripIntegration(@TempDir Path tmp) throws IOException { + // --- 1. 
Build a tiny LATIN corpus on disk --- + Path trainFile = tmp.resolve("LATIN.train.gz"); + writeGzippedLines(trainFile, + "the quick brown fox jumps over the lazy dog", + "pack my box with five dozen liquor jugs", + "how vexingly quick daft zebras jump", + "the five boxing wizards jump quickly", + "sphinx of black quartz judge my vow"); + Path devFile = tmp.resolve("LATIN.dev.gz"); + writeGzippedLines(devFile, + "the rain in spain falls mainly on the plain", + "a stitch in time saves nine", + "all that glitters is not gold"); + + // --- 2. Phase 1: train V7 F1 tables for this script --- + // Tiny corpus → min_count=1 so all pairs survive. + V7Tables tables = TrainJunkModel.trainV7TablesForScript(trainFile, + 1, JunkDetectorTrainingConfig.OA_LOAD_FACTOR, + JunkDetectorTrainingConfig.KEY_INDEX_BITS); + + // Sanity: 'h' should be in the codepoint index (appears in "the"). + assertTrue(Arrays.binarySearch(tables.codepointIndex, (int) 'h') >= 0, + "'h' should be in codepoint index — it appears in training"); + assertTrue(Arrays.binarySearch(tables.codepointIndex, (int) 'x') >= 0, + "'x' should be in codepoint index — appears in 'box', 'fox'"); + + // The pair (t, h) is in training; the OA lookup should find it. + int idxT = Arrays.binarySearch(tables.codepointIndex, (int) 't'); + int idxH = Arrays.binarySearch(tables.codepointIndex, (int) 'h'); + assertTrue(idxT >= 0 && idxH >= 0); + int slot = JunkDetector.lookupBigramSlot(tables, idxT, idxH); + assertTrue(slot >= 0, "OA lookup should find seen pair (t, h)"); + + // --- 3. F1 raw scoring sanity --- + double meanLogP = JunkDetector.computeF1MeanLogP("the quick brown fox", tables); + assertTrue(Double.isFinite(meanLogP), + "Mean log-prob on training text should be finite, got " + meanLogP); + assertTrue(meanLogP > -15 && meanLogP < 0, + "Score on training text should be sensible, got " + meanLogP); + + // --- 4. 
Phase 1.5: F1 calibration on dev --- + float[] f1CalLatin = TrainJunkModel.calibrateF1PerScript(devFile, tables); + assertTrue(Float.isFinite(f1CalLatin[0]), "mu1 should be finite"); + assertTrue(Float.isFinite(f1CalLatin[1]) && f1CalLatin[1] > 0, + "sigma1 should be positive finite"); + + // --- 5. Assemble + save a minimal v7 model --- + int blockN = UnicodeBlockRanges.bucketCount(); + TreeMap<String, V7Tables> f1Tables = new TreeMap<>(); + f1Tables.put("LATIN", tables); + TreeMap<String, float[]> blockTables = new TreeMap<>(); + blockTables.put("LATIN", new float[blockN * blockN]); + TreeMap<String, float[]> blockCal = new TreeMap<>(); + blockCal.put("LATIN", new float[]{0f, 1f}); + TreeMap<String, float[]> controlCal = new TreeMap<>(); + controlCal.put("LATIN", new float[]{0f, 1f}); + TreeMap<String, float[]> f1CalMap = new TreeMap<>(); + f1CalMap.put("LATIN", f1CalLatin); + TreeMap<String, float[]> classifierWeights = new TreeMap<>(); + classifierWeights.put("LATIN", new float[]{1f, 0f, 0f, 0f, 0f}); + + List<String> scriptBuckets = List.of("LATIN", "OTHER"); + float[] scriptTransTable = new float[scriptBuckets.size() * scriptBuckets.size()]; + float[] scriptTransCal = new float[]{0f, 1f}; + + Path modelPath = tmp.resolve("junkdetect.bin"); + TrainJunkModel.saveModelV7( + f1Tables, f1CalMap, blockTables, blockCal, controlCal, + classifierWeights, scriptBuckets, scriptTransTable, + scriptTransCal, modelPath); + + // --- 6. Load via JunkDetector and score --- + JunkDetector detector = JunkDetector.loadFromPath(modelPath); + assertEquals(7, detector.getModelVersion(), + "Loaded model should be v7"); + assertTrue(detector.knownScripts().contains("LATIN"), + "Loaded model should know LATIN"); + + TextQualityScore score = detector.score("the quick brown fox jumps"); + assertEquals("LATIN", score.getDominantScript()); + assertTrue(Float.isFinite(score.getZScore()), + "Score on in-distribution text should be finite, got " + score); + + // --- 7. 
Train/infer consistency check --- + // The inference path should compute the same raw F1 score as + // JunkDetector.computeF1MeanLogP on the same text — if these + // two ever disagree, the model's calibration is silently wrong. + String probe = "pack my box with five dozen liquor jugs"; + double trainerRawMean = JunkDetector.computeF1MeanLogP(probe, tables); + float expectedZ1 = (float) ((trainerRawMean - f1CalLatin[0]) / f1CalLatin[1]); + TextQualityScore probeScore = detector.score(probe); + // logit = w1 * z1 + 0 + 0 + 0 + 0 = z1 in this test configuration. + assertEquals(expectedZ1, probeScore.getZScore(), 0.001f, + "Inference z1 must match trainer-computed z1 " + + "(train/infer F1 math drift)"); + } + + // ----------------------------------------------------------------------- + // Helpers + // ----------------------------------------------------------------------- + + /** + * Builds a V7Tables with codepoint index ['A', 'B'], where (A,B) has a + * stored log-prob of -1.0 but (B,A) is absent (forces unigram backoff). + * Unigram log-prob = -2.0 for both A and B. + * + * <p>Bigram quant range is set explicitly to {@code [-10, -1]} so that + * the single stored value at -1.0 maps to byte 255 (avoids the + * degenerate {@code min == max} branch in + * {@link TrainJunkModel#quantizeFloats}). Same idea for the unigram + * range {@code [-5, -2]} so the (-2.0, -2.0) values map to byte 255. + */ + private static V7Tables buildLatinTablesAB() { + int[] cpIndex = new int[]{'A', 'B'}; + + // 4 slots ≈ 25% load for 1 pair. Open-addressing with linear probe. + int[] keys = new int[4]; + Arrays.fill(keys, V7Tables.EMPTY_KEY); + byte[] values = new byte[4]; + + // Manual quantization with a chosen range so we don't hit the + // degenerate single-element case. range=[-10, -1] → -1.0 → byte 255. 
+ float bMin = -10.0f; + float bMax = -1.0f; + byte b = quantizeOne(-1.0f, bMin, bMax); + insertOA(keys, values, JunkDetector.packBigramKey(0, 1), b); + + float uMin = -5.0f; + float uMax = -2.0f; + byte[] unigramBytes = new byte[]{ + quantizeOne(-2.0f, uMin, uMax), + quantizeOne(-2.0f, uMin, uMax), + }; + + return new V7Tables(cpIndex, keys, values, unigramBytes, + bMin, bMax, + uMin, uMax, + -10.0f, 1.0f); + } + + /** Quantize a single float to 8-bit unsigned using the explicit range. */ + private static byte quantizeOne(float v, float min, float max) { + float range = max - min; + int q = Math.round(((v - min) / range) * 255.0f); + if (q < 0) q = 0; + else if (q > 255) q = 255; + return (byte) q; + } + + /** + * Replica of {@code TrainJunkModel.insertOA} (package-private) for the + * test's hand-constructed tables. Uses the same mix-hash as the + * production code path. + */ + private static void insertOA(int[] keys, byte[] values, int packedKey, byte value) { + int mask = keys.length - 1; + int h = JunkDetector.mixIndexKey(packedKey) & mask; + while (keys[h] != V7Tables.EMPTY_KEY) { + if (keys[h] == packedKey) { + values[h] = value; + return; + } + h = (h + 1) & mask; + } + keys[h] = packedKey; + values[h] = value; + } + + /** + * Saves a minimal v7 model containing only LATIN, with F2/F3/F4 zeroed + * out and pure-F1 classifier weights (w1=1, rest 0, bias 0). Scoring + * a window thus reduces to z1 directly. F1 calibration: mu=-5, sigma=1. 
+ */ + private static void saveMinimalV7Model(V7Tables tables, Path modelFile) throws IOException { + TreeMap<String, V7Tables> f1Tables = new TreeMap<>(); + f1Tables.put("LATIN", tables); + + TreeMap<String, float[]> f1Cal = new TreeMap<>(); + f1Cal.put("LATIN", new float[]{-5.0f, 1.0f}); + + int blockN = UnicodeBlockRanges.bucketCount(); + + TreeMap<String, float[]> blockTables = new TreeMap<>(); + blockTables.put("LATIN", new float[blockN * blockN]); + TreeMap<String, float[]> blockCal = new TreeMap<>(); + blockCal.put("LATIN", new float[]{0f, 1f}); + + TreeMap<String, float[]> controlCal = new TreeMap<>(); + controlCal.put("LATIN", new float[]{0f, 1f}); + + List<String> scriptBuckets = List.of("LATIN", "OTHER"); + float[] scriptTransTable = new float[scriptBuckets.size() * scriptBuckets.size()]; + float[] scriptTransCal = new float[]{0f, 1f}; + + TreeMap<String, float[]> classifierWeights = new TreeMap<>(); + classifierWeights.put("LATIN", new float[]{1.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + + TrainJunkModel.saveModelV7( + f1Tables, f1Cal, blockTables, blockCal, controlCal, + classifierWeights, scriptBuckets, scriptTransTable, + scriptTransCal, modelFile); + } + + private static void writeGzippedLines(Path path, String... 
lines) throws IOException { + try (BufferedWriter w = new BufferedWriter(new OutputStreamWriter( + new GZIPOutputStream(Files.newOutputStream(path)), + StandardCharsets.UTF_8))) { + for (String line : lines) { + w.write(line); + w.write('\n'); + } + } + } +} diff --git a/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/tools/JunkDetectorTrainingConfigTest.java b/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/tools/JunkDetectorTrainingConfigTest.java index a0f975eb46..5539830719 100644 --- a/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/tools/JunkDetectorTrainingConfigTest.java +++ b/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/tools/JunkDetectorTrainingConfigTest.java @@ -73,19 +73,20 @@ class JunkDetectorTrainingConfigTest { } @Test - void scriptBudgetOverridesEmptyByDefault() { - // We tried HAN=60MB; it lowered Cohen's d for every non-HAN script - // because the global F1 hash table is the bottleneck. Keep this - // map empty until v7 (per-script F1 tables) lands. + void scriptBudgetOverridesEmpty() { + // v7 hypothesis test (HAN=60MB) ran but gave only marginal gains. + // Override map is intentionally empty pending a more decisive + // experiment. assertTrue(JunkDetectorTrainingConfig.SCRIPT_BUDGET_OVERRIDES.isEmpty()); } @Test void modelTrainValues() { assertEquals(3, JunkDetectorTrainingConfig.MIN_BIGRAM_COUNT); - assertEquals(16 * 1024 * 1024, JunkDetectorTrainingConfig.BLOOM_BITS); - assertEquals(0, JunkDetectorTrainingConfig.BLOOM_BITS % 64, - "BLOOM_BITS must be a multiple of 64"); + assertEquals(0.5, JunkDetectorTrainingConfig.OA_LOAD_FACTOR, 1e-9); + assertEquals(16, JunkDetectorTrainingConfig.KEY_INDEX_BITS); + assertTrue(JunkDetectorTrainingConfig.KEY_INDEX_BITS <= 16, + "KEY_INDEX_BITS must be <= 16 to fit packed key in an int"); } @Test
