This is an automated email from the ASF dual-hosted git repository.

tallison pushed a commit to branch chardet-work
in repository https://gitbox.apache.org/repos/asf/tika.git

commit 40ad000bfcf13a6ab5f64b0575dab853dcc52013
Author: tballison <[email protected]>
AuthorDate: Fri Feb 27 16:39:46 2026 -0500

    TIKA-4662: add probe-length sweep and confusion matrix to EvalCharsetDetectors
    
    --lengths 20,50,100,200,full  runs evaluation at each truncated probe size,
    showing how accuracy degrades on short inputs (ZIP filenames, headers, etc.)
    
    --confusion  prints top-5 error targets per charset for the ML-All detector,
    making it easy to see which charsets are confused with which.
    
    Key findings from test-set sweep:
    - Accuracy stabilises at ~100-200 bytes (same pattern as language detection)
    - At 20B: ML-All 51.1% vs ICU4J 26.2% -- we hold up better on short probes
    - GB18030 strict is low by design (GB2312/GBK/GB18030 are confusables, soft=99.9%)
    - Big5/Big5-HKSCS confusion dominates Big5 errors (need more Traditional Chinese data)
    - IBM424-ltr vs IBM424-rtl errors are within the confusable group
---
 .../ml/chardetect/tools/EvalCharsetDetectors.java  | 379 +++++++++++++++++++++
 1 file changed, 379 insertions(+)

diff --git a/tika-ml/tika-ml-chardetect/src/test/java/org/apache/tika/ml/chardetect/tools/EvalCharsetDetectors.java b/tika-ml/tika-ml-chardetect/src/test/java/org/apache/tika/ml/chardetect/tools/EvalCharsetDetectors.java
new file mode 100644
index 0000000000..f529c8fc53
--- /dev/null
+++ b/tika-ml/tika-ml-chardetect/src/test/java/org/apache/tika/ml/chardetect/tools/EvalCharsetDetectors.java
@@ -0,0 +1,379 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tika.ml.chardetect.tools;
+
+import java.io.DataInputStream;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.EnumSet;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.zip.GZIPInputStream;
+
+import org.apache.tika.detect.EncodingDetector;
+import org.apache.tika.detect.EncodingResult;
+import org.apache.tika.io.TikaInputStream;
+import org.apache.tika.metadata.Metadata;
+import org.apache.tika.ml.chardetect.CharsetConfusables;
+import org.apache.tika.ml.chardetect.MlEncodingDetector;
+import org.apache.tika.ml.chardetect.MlEncodingDetector.Rule;
+import org.apache.tika.parser.ParseContext;
+import org.apache.tika.parser.txt.Icu4jEncodingDetector;
+import org.apache.tika.parser.txt.UniversalEncodingDetector;
+
+/**
+ * Compares {@link MlEncodingDetector} against ICU4J and juniversalchardet.
+ *
+ * <p>Supports:
+ * <ul>
+ *   <li>{@code --lengths 20,50,100,200,full} — per-probe-length accuracy 
sweep</li>
+ *   <li>{@code --confusion} — top-confusion report for the ML-All 
detector</li>
+ * </ul>
+ *
+ * <p>Usage:
+ * <pre>
+ *   java EvalCharsetDetectors \
+ *     [--model /path/to/chardetect.bin] \
+ *     --data  /path/to/test-dir \
+ *     [--lengths 20,50,100,200,full] \
+ *     [--confusion]
+ * </pre>
+ */
+public class EvalCharsetDetectors {
+
+    private static final String NULL_LABEL = "(null)";
+    private static final int FULL_LENGTH = Integer.MAX_VALUE;
+
+    private static final double OOV_THRESHOLD_CJK  = 0.80;
+    private static final double OOV_THRESHOLD_SBCS = 0.98;
+    private static final Set<String> CJK_CHARSETS = Set.of(
+            "Big5", "Big5-HKSCS", "EUC-JP", "EUC-KR", "EUC-TW",
+            "GB18030", "GB2312", "GBK", "Shift_JIS"
+    );
+    private static final Set<String> OOV_EXEMPT = Set.of(
+            "US-ASCII", "UTF-16-LE", "UTF-16-BE", "UTF-32-LE", "UTF-32-BE",
+            "ISO-2022-JP", "ISO-2022-KR", "ISO-2022-CN", "HZ"
+    );
+
+    private static final String[] COL_NAMES = {"Stat", "+ISO", "+CJK", "All", 
"ICU4J", "juniv"};
+    private static final int NUM_DETECTORS = COL_NAMES.length;
+    // Index of the "All" detector — used for confusion matrix
+    private static final int IDX_ALL = 3;
+
+    public static void main(String[] args) throws Exception {
+        Path modelPath = null;
+        Path dataDir   = null;
+        int[] probeLengths = {FULL_LENGTH};
+        boolean showConfusion = false;
+
+        for (int i = 0; i < args.length; i++) {
+            switch (args[i]) {
+                case "--model":
+                    modelPath = Paths.get(args[++i]);
+                    break;
+                case "--data":
+                    dataDir = Paths.get(args[++i]);
+                    break;
+                case "--lengths":
+                    probeLengths = parseLengths(args[++i]);
+                    break;
+                case "--confusion":
+                    showConfusion = true;
+                    break;
+                default:
+                    System.err.println("Unknown argument: " + args[i]);
+                    System.exit(1);
+            }
+        }
+        if (dataDir == null) {
+            System.err.println(
+                    "Usage: EvalCharsetDetectors [--model <path>] --data <dir>"
+                    + " [--lengths 20,50,100,full] [--confusion]");
+            System.exit(1);
+        }
+
+        MlEncodingDetector base = modelPath != null
+                ? new MlEncodingDetector(modelPath)
+                : new MlEncodingDetector();
+
+        EncodingDetector[] detectors = {
+            base.withRules(EnumSet.noneOf(Rule.class)),
+            base.withRules(EnumSet.of(Rule.STRUCTURAL_GATES, 
Rule.ISO_TO_WINDOWS)),
+            base.withRules(EnumSet.of(Rule.STRUCTURAL_GATES, 
Rule.CJK_GRAMMAR)),
+            base,
+            new Icu4jEncodingDetector(),
+            new UniversalEncodingDetector()
+        };
+
+        List<Path> testFiles = Files.list(dataDir)
+                .filter(p -> p.getFileName().toString().endsWith(".bin.gz"))
+                .sorted()
+                .collect(Collectors.toList());
+
+        if (testFiles.isEmpty()) {
+            System.err.println("No .bin.gz files found in: " + dataDir);
+            System.exit(1);
+        }
+
+        // Load all samples once; truncation happens per probe-length sweep
+        List<String> charsets = new ArrayList<>();
+        List<List<byte[]>> allSamplesPerCharset = new ArrayList<>();
+        for (Path f : testFiles) {
+            String cs = f.getFileName().toString().replaceAll("\\.bin\\.gz$", 
"");
+            List<byte[]> samples = loadSamples(f);
+            if (!samples.isEmpty()) {
+                charsets.add(cs);
+                allSamplesPerCharset.add(samples);
+            }
+        }
+
+        // One pass per probe length
+        for (int probeLen : probeLengths) {
+            String lenLabel = probeLen == FULL_LENGTH ? "full" : probeLen + 
"B";
+            System.out.println("\n=== Probe length: " + lenLabel + " ===");
+            printHeader();
+
+            long totalN = 0;
+            long[][] totals = new long[NUM_DETECTORS][2];
+            // confusion[trueIdx][predLabel] = count  (only for IDX_ALL)
+            List<Map<String, Integer>> confusion = new ArrayList<>();
+            for (int ci = 0; ci < charsets.size(); ci++) {
+                confusion.add(new HashMap<>());
+            }
+
+            for (int ci = 0; ci < charsets.size(); ci++) {
+                String charset = charsets.get(ci);
+                List<byte[]> samples = truncate(allSamplesPerCharset.get(ci), 
probeLen);
+                int n = samples.size();
+                if (n == 0) {
+                    continue;
+                }
+
+                int[][] counts = new int[NUM_DETECTORS][2];
+                for (byte[] sample : samples) {
+                    for (int d = 0; d < NUM_DETECTORS; d++) {
+                        String pred = predict(detectors[d], sample);
+                        if (isStrict(charset, pred)) {
+                            counts[d][0]++;
+                        }
+                        if (isSoft(charset, pred)) {
+                            counts[d][1]++;
+                        }
+                        if (d == IDX_ALL && !isStrict(charset, pred)) {
+                            confusion.get(ci).merge(pred, 1, Integer::sum);
+                        }
+                    }
+                }
+
+                printRow(charset, n, counts);
+                totalN += n;
+                for (int d = 0; d < NUM_DETECTORS; d++) {
+                    totals[d][0] += counts[d][0];
+                    totals[d][1] += counts[d][1];
+                }
+            }
+
+            printFooter(totalN, totals);
+
+            if (showConfusion) {
+                printConfusion(charsets, allSamplesPerCharset, confusion, 
probeLen, lenLabel);
+            }
+        }
+    }
+
+    // -----------------------------------------------------------------------
+    //  Confusion matrix
+    // -----------------------------------------------------------------------
+
+    private static void printConfusion(List<String> charsets,
+                                       List<List<byte[]>> allSamplesPerCharset,
+                                       List<Map<String, Integer>> confusion,
+                                       int probeLen, String lenLabel) {
+        System.out.println("\n--- Confusion (ML-All, " + lenLabel
+                + ", top errors per charset) ---");
+
+        boolean anyError = false;
+        for (int ci = 0; ci < charsets.size(); ci++) {
+            Map<String, Integer> errors = confusion.get(ci);
+            if (errors.isEmpty()) {
+                continue;
+            }
+            anyError = true;
+            int total = truncate(allSamplesPerCharset.get(ci), 
probeLen).size();
+            int totalErrors = 
errors.values().stream().mapToInt(Integer::intValue).sum();
+            double errPct = 100.0 * totalErrors / total;
+
+            // Sort errors by count descending
+            List<Map.Entry<String, Integer>> sorted = new 
ArrayList<>(errors.entrySet());
+            sorted.sort((a, b) -> b.getValue() - a.getValue());
+
+            StringBuilder sb = new StringBuilder();
+            sb.append(String.format(Locale.ROOT, "  %-22s %5.1f%% wrong → ", 
charsets.get(ci), errPct));
+            int shown = 0;
+            for (Map.Entry<String, Integer> e : sorted) {
+                if (shown > 0) {
+                    sb.append(", ");
+                }
+                sb.append(String.format(Locale.ROOT, "%s:%.1f%%",
+                        e.getKey(), 100.0 * e.getValue() / total));
+                if (++shown >= 5) {
+                    break;
+                }
+            }
+            System.out.println(sb);
+        }
+        if (!anyError) {
+            System.out.println("  (no errors)");
+        }
+    }
+
+    // -----------------------------------------------------------------------
+    //  Table formatting
+    // -----------------------------------------------------------------------
+
+    private static void printHeader() {
+        StringBuilder sb1 = new StringBuilder();
+        sb1.append(String.format(Locale.ROOT, "%-22s  %5s  ", "", "N"));
+        sb1.append("| --- ML ablation --------------------------------- ");
+        sb1.append("| --- Baselines ----------------------- |");
+        System.out.println(sb1);
+
+        StringBuilder sb2 = new StringBuilder();
+        sb2.append(String.format(Locale.ROOT, "%-22s  %5s  ", "Charset", ""));
+        for (String name : COL_NAMES) {
+            sb2.append(String.format(Locale.ROOT, "| %-4s R%%   S%%  ", name));
+        }
+        sb2.append("|");
+        System.out.println(sb2);
+        System.out.println("-".repeat(sb2.length()));
+    }
+
+    private static void printRow(String charset, int n, int[][] counts) {
+        StringBuilder sb = new StringBuilder();
+        sb.append(String.format(Locale.ROOT, "%-22s  %5d  ", charset, n));
+        for (int d = 0; d < NUM_DETECTORS; d++) {
+            sb.append(String.format(Locale.ROOT, "| %5.1f %5.1f  ",
+                    pct(counts[d][0], n), pct(counts[d][1], n)));
+        }
+        sb.append("|");
+        System.out.println(sb);
+    }
+
+    private static void printFooter(long totalN, long[][] totals) {
+        System.out.println("-".repeat(120));
+        StringBuilder sb = new StringBuilder();
+        sb.append(String.format(Locale.ROOT, "%-22s  %5d  ", "OVERALL", 
totalN));
+        for (int d = 0; d < NUM_DETECTORS; d++) {
+            sb.append(String.format(Locale.ROOT, "| %5.1f %5.1f  ",
+                    pct(totals[d][0], totalN), pct(totals[d][1], totalN)));
+        }
+        sb.append("|");
+        System.out.println(sb);
+        System.out.println("  Stat=model only | +ISO=+C1-correction | 
+CJK=+grammar | "
+                + "All=production | R%=strict | S%=soft");
+    }
+
+    // -----------------------------------------------------------------------
+    //  Helpers
+    // -----------------------------------------------------------------------
+
+    private static int[] parseLengths(String spec) {
+        String[] parts = spec.split(",");
+        int[] result = new int[parts.length];
+        for (int i = 0; i < parts.length; i++) {
+            result[i] = "full".equalsIgnoreCase(parts[i].trim())
+                    ? FULL_LENGTH : Integer.parseInt(parts[i].trim());
+        }
+        return result;
+    }
+
+    /** Returns samples truncated to at most {@code maxLen} bytes. */
+    private static List<byte[]> truncate(List<byte[]> samples, int maxLen) {
+        if (maxLen == FULL_LENGTH) {
+            return samples;
+        }
+        List<byte[]> out = new ArrayList<>(samples.size());
+        for (byte[] s : samples) {
+            out.add(s.length <= maxLen ? s : Arrays.copyOf(s, maxLen));
+        }
+        return out;
+    }
+
+    private static String predict(EncodingDetector detector, byte[] sample) {
+        try (TikaInputStream tis = TikaInputStream.get(sample)) {
+            List<EncodingResult> results =
+                    detector.detect(tis, new Metadata(), new ParseContext());
+            return results.isEmpty() ? NULL_LABEL : 
results.get(0).getCharset().name();
+        } catch (Exception e) {
+            return NULL_LABEL;
+        }
+    }
+
+    private static boolean isStrict(String actual, String predicted) {
+        return !NULL_LABEL.equals(predicted)
+                && normalize(actual).equals(normalize(predicted));
+    }
+
+    private static boolean isSoft(String actual, String predicted) {
+        if (NULL_LABEL.equals(predicted)) {
+            return false;
+        }
+        if (normalize(actual).equals(normalize(predicted))) {
+            return true;
+        }
+        return CharsetConfusables.isLenientMatch(actual, predicted)
+                || CharsetConfusables.isLenientMatch(predicted, actual);
+    }
+
+    private static String normalize(String name) {
+        return name.toLowerCase(Locale.ROOT).replace("-", "").replace("_", "");
+    }
+
+    private static double pct(long correct, long total) {
+        return total == 0 ? 0.0 : 100.0 * correct / total;
+    }
+
+    private static List<byte[]> loadSamples(Path file) throws IOException {
+        List<byte[]> out = new ArrayList<>();
+        try (InputStream fis = new FileInputStream(file.toFile());
+             GZIPInputStream gis = new GZIPInputStream(fis);
+             DataInputStream dis = new DataInputStream(gis)) {
+            while (true) {
+                int len;
+                try {
+                    len = dis.readUnsignedShort();
+                } catch (java.io.EOFException e) {
+                    break;
+                }
+                byte[] chunk = new byte[len];
+                dis.readFully(chunk);
+                out.add(chunk);
+            }
+        }
+        return out;
+    }
+}

Reply via email to